Skip to content

Commit

Permalink
make sure dtypes are correct
Browse files Browse the repository at this point in the history
  • Loading branch information
rjzamora committed Jul 29, 2024
1 parent 3862601 commit 9bd271b
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions tests/unit/test_dask_nvt.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@ def test_dask_workflow_api_dlrm(

@pytest.mark.parametrize("part_mem_fraction", [0.01])
def test_dask_groupby_stats(client, tmpdir, datasets, part_mem_fraction):
from nvtabular.ops.join_groupby import AGG_DTYPES

set_dask_client(client=client)

engine = "parquet"
Expand Down Expand Up @@ -175,12 +177,14 @@ def test_dask_groupby_stats(client, tmpdir, datasets, part_mem_fraction):
gb_e = expect.groupby("name-cat").aggregate({"name-cat": "count", "x": ["sum", "min", "std"]})
gb_e.columns = ["count", "sum", "min", "std"]
df_check = got.merge(gb_e, left_on="name-cat", right_index=True, how="left")
# Names and dtypes don't need to match (just values)
options = {"check_names": False, "check_dtype": False}
assert_eq(df_check["name-cat_count"], df_check["count"], **options)
assert_eq(df_check["name-cat_x_sum"], df_check["sum"], **options)
assert_eq(df_check["name-cat_x_min"], df_check["min"], **options)
assert_eq(df_check["name-cat_x_std"], df_check["std"], **options)
assert_eq(
df_check["name-cat_count"], df_check["count"].astype(AGG_DTYPES["count"]), check_names=False
)
assert_eq(df_check["name-cat_x_sum"], df_check["sum"], check_names=False)
assert_eq(df_check["name-cat_x_min"], df_check["min"], check_names=False)
assert_eq(
df_check["name-cat_x_std"], df_check["std"].astype(AGG_DTYPES["std"]), check_names=False
)


@pytest.mark.parametrize("part_mem_fraction", [0.01])
Expand Down

0 comments on commit 9bd271b

Please sign in to comment.