Skip to content

Commit

Permalink
improve Categorify inference testing
Browse files Browse the repository at this point in the history
  • Loading branch information
rjzamora committed Apr 10, 2024
1 parent 9f0ba33 commit 337e904
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions tests/unit/ops/test_categorify.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,3 +734,8 @@ def test_categorify_inference():
output_tensors = inference_op.transform(cats.input_columns, input_tensors)
for key in input_tensors:
assert output_tensors[key].dtype == np.dtype("int64")

# Check results are consistent with python code path
expect = workflow.transform(df)
got = pd.DataFrame(output_tensors)
assert_eq(expect, got)

0 comments on commit 337e904

Please sign in to comment.