Skip to content

Commit

Permalink
feat: add one_hot function of paddle frontend(#24153)
Browse files Browse the repository at this point in the history
  • Loading branch information
SamruddhiNavale committed Mar 28, 2024
1 parent 74db123 commit 05899d3
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 0 deletions.
10 changes: 10 additions & 0 deletions ivy/functional/frontends/paddle/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,16 @@ def index_add(x, index, axis, value, *, name=None):
return ret


@with_supported_dtypes({"2.5.1 and below": ("int32", "int64")}, "paddle")
@to_ivy_arrays_and_back
def one_hot(x, num_classes, name=None):
if not isinstance(num_classes, int) or num_classes <= 0:
raise ValueError("num_classes must be a positive integer.")

one_hot_tensor = ivy.one_hot(x, num_classes)
return one_hot_tensor.astype(ivy.float32)


@to_ivy_arrays_and_back
def put_along_axis(arr, indices, values, axis, reduce="assign"):
result = ivy.put_along_axis(arr, indices, values, axis)
Expand Down
33 changes: 33 additions & 0 deletions ivy_tests/test_ivy/test_frontends/test_paddle/test_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,39 @@ def dtypes_x_reshape(draw):
# ------------ #


@handle_frontend_test(
fn_tree="paddle.one_hot",
dtype_and_x=helpers.dtype_and_values(
available_dtypes=["int64"],
min_value=0,
min_num_dims=1,
max_num_dims=5,
),
num_classes=helpers.ints(min_value=1),
)
def test_one_hot(
*,
dtype_and_x,
num_classes,
on_device,
fn_tree,
frontend,
test_flags,
backend_fw,
):
input_dtype, x = dtype_and_x
helpers.test_frontend_function(
input_dtypes=input_dtype,
frontend=frontend,
test_flags=test_flags,
fn_tree=fn_tree,
on_device=on_device,
x=x[0],
num_classes=num_classes,
backend_to_test=backend_fw,
)


# abs
@handle_frontend_test(
fn_tree="paddle.abs",
Expand Down

0 comments on commit 05899d3

Please sign in to comment.