diff --git a/ivy/functional/frontends/paddle/manipulation.py b/ivy/functional/frontends/paddle/manipulation.py index 6c2c8d6a90adc..fb486570f5adc 100644 --- a/ivy/functional/frontends/paddle/manipulation.py +++ b/ivy/functional/frontends/paddle/manipulation.py @@ -117,6 +117,16 @@ def index_add(x, index, axis, value, *, name=None): return ret +@with_supported_dtypes({"2.5.1 and below": ("int32", "int64")}, "paddle") +@to_ivy_arrays_and_back +def one_hot(x, num_classes, name=None): + if not isinstance(num_classes, int) or num_classes <= 0: + raise ValueError("num_classes must be a positive integer.") + + one_hot_tensor = ivy.one_hot(x, num_classes) + return one_hot_tensor.astype(ivy.float32) + + @to_ivy_arrays_and_back def put_along_axis(arr, indices, values, axis, reduce="assign"): result = ivy.put_along_axis(arr, indices, values, axis) diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_manipulation.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_manipulation.py index c3515fd426518..435af95f08a43 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_manipulation.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_manipulation.py @@ -297,6 +297,39 @@ def dtypes_x_reshape(draw): # ------------ # +@handle_frontend_test( + fn_tree="paddle.one_hot", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=["int64"], + min_value=0, + min_num_dims=1, + max_num_dims=5, + ), + num_classes=helpers.ints(min_value=1), +) +def test_one_hot( + *, + dtype_and_x, + num_classes, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + num_classes=num_classes, + backend_to_test=backend_fw, + ) + + # abs @handle_frontend_test( fn_tree="paddle.abs",