From 0498d9d3da36b7b4ac23d7385fcddc10fabe67ce Mon Sep 17 00:00:00 2001 From: samthakur587 Date: Thu, 18 Jan 2024 02:12:57 +0530 Subject: [PATCH 1/2] fix: fixed unsupported complex dtype at jax and torch backend --- ivy/functional/backends/jax/elementwise.py | 1 + ivy/functional/backends/torch/elementwise.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/ivy/functional/backends/jax/elementwise.py b/ivy/functional/backends/jax/elementwise.py index 056996b8cc710..eb55534b074fa 100644 --- a/ivy/functional/backends/jax/elementwise.py +++ b/ivy/functional/backends/jax/elementwise.py @@ -28,6 +28,7 @@ def abs( return jnp.where(x != 0, jnp.absolute(x), 0) +@with_unsupported_dtypes({"0.4.23 and below": ("complex",)}, backend_version) def acos(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray: return jnp.arccos(x) diff --git a/ivy/functional/backends/torch/elementwise.py b/ivy/functional/backends/torch/elementwise.py index 7aec8d25c8a0c..b9ab20789a2de 100644 --- a/ivy/functional/backends/torch/elementwise.py +++ b/ivy/functional/backends/torch/elementwise.py @@ -403,7 +403,7 @@ def greater_equal( greater_equal.support_native_out = True -@with_unsupported_dtypes({"2.1.2 and below": ("float16",)}, backend_version) +@with_unsupported_dtypes({"2.1.2 and below": ("complex",)}, backend_version) @handle_numpy_arrays_in_specific_backend def acos(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Tensor: x = _cast_for_unary_op(x) From a6ab47ae660ad6c79567c7b87f9501e6036ca4b8 Mon Sep 17 00:00:00 2001 From: samthakur587 Date: Thu, 18 Jan 2024 02:22:41 +0530 Subject: [PATCH 2/2] fixed the test function --- .../test_ivy/test_functional/test_core/test_elementwise.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py b/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py index cea111ff23d2e..d5a9f9522c5dc 100644 --- a/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py +++ b/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py @@ -177,7 +177,7 @@ def test_abs(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): @handle_test( fn_tree="functional.ivy.acos", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + available_dtypes=helpers.get_dtypes("float_and_complex"), large_abs_safety_factor=4, small_abs_safety_factor=4, safety_factor_scale="log",