diff --git a/test/test_expression.py b/test/test_expression.py index cbbed2a35..1b973e9a6 100644 --- a/test/test_expression.py +++ b/test/test_expression.py @@ -642,6 +642,81 @@ def test_np_bool_handling(ctx_factory): assert out.get().item() is True +@pytest.mark.parametrize("target", [lp.PyOpenCLTarget, lp.ExecutableCTarget]) +def test_complex_functions_with_real_args(ctx_factory, target): + # Reported by David Ham. See + t_unit = lp.make_kernel( + "{[i]: 0<=i<10}", + """ + y1[i] = abs(c64[i]) + y2[i] = real(c64[i]) + y3[i] = imag(c64[i]) + y4[i] = conj(c64[i]) + + y5[i] = abs(c128[i]) + y6[i] = real(c128[i]) + y7[i] = imag(c128[i]) + y8[i] = conj(c128[i]) + + + y9[i] = abs(f32[i]) + y10[i] = real(f32[i]) + y11[i] = imag(f32[i]) + y12[i] = conj(f32[i]) + + y13[i] = abs(f64[i]) + y14[i] = real(f64[i]) + y15[i] = imag(f64[i]) + y16[i] = conj(f64[i]) + """, + target=target()) + + t_unit = lp.add_dtypes(t_unit, + {"y9,y10,y11,y12": np.complex64, + "y13,y14,y15,y16": np.complex128, + "c64": np.complex64, + "c128": np.complex128, + "f64": np.float64, + "f32": np.float32}) + t_unit = lp.set_options(t_unit, return_dict=True) + + from numpy.random import default_rng + rng = default_rng(0) + c64 = (rng.random(10, dtype=np.float32) + + np.csingle(1j)*rng.random(10, dtype=np.float32)) + c128 = (rng.random(10, dtype=np.float64) + + np.cdouble(1j)*rng.random(10, dtype=np.float64)) + f32 = rng.random(10, dtype=np.float32) + f64 = rng.random(10, dtype=np.float64) + + if target == lp.PyOpenCLTarget: + cl_ctx = ctx_factory() + with cl.CommandQueue(cl_ctx) as queue: + evt, out = t_unit(queue, c64=c64, c128=c128, f32=f32, f64=f64) + elif target == lp.ExecutableCTarget: + t_unit = lp.set_options(t_unit, build_options=["-Werror"]) + evt, out = t_unit(c64=c64, c128=c128, f32=f32, f64=f64) + else: + raise NotImplementedError("unsupported target") + + np.testing.assert_allclose(out["y1"], np.abs(c64), rtol=1e-6) + np.testing.assert_allclose(out["y2"], np.real(c64), rtol=1e-6) + np.testing.assert_allclose(out["y3"], np.imag(c64), rtol=1e-6) + np.testing.assert_allclose(out["y4"], np.conj(c64), rtol=1e-6) + np.testing.assert_allclose(out["y5"], np.abs(c128), rtol=1e-6) + np.testing.assert_allclose(out["y6"], np.real(c128), rtol=1e-6) + np.testing.assert_allclose(out["y7"], np.imag(c128), rtol=1e-6) + np.testing.assert_allclose(out["y8"], np.conj(c128), rtol=1e-6) + np.testing.assert_allclose(out["y9"], np.abs(f32), rtol=1e-6) + np.testing.assert_allclose(out["y10"], np.real(f32), rtol=1e-6) + np.testing.assert_allclose(out["y11"], np.imag(f32), rtol=1e-6) + np.testing.assert_allclose(out["y12"], np.conj(f32), rtol=1e-6) + np.testing.assert_allclose(out["y13"], np.abs(f64), rtol=1e-6) + np.testing.assert_allclose(out["y14"], np.real(f64), rtol=1e-6) + np.testing.assert_allclose(out["y15"], np.imag(f64), rtol=1e-6) + np.testing.assert_allclose(out["y16"], np.conj(f64), rtol=1e-6) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])