Implement the Categorify() start_index feature request of #1074
This commit implements the feature requested in issue #1074: add a
start_index argument to Categorify that offsets the integer values that
vocabulary items are translated to.

We update nvtabular/ops/categorify.py to add a start_index argument to
Categorify(), store it on the operator, and pass it through to the internal
_encode() call used during transform. We also add docstrings to the _encode() and
_write_uniques() functions for improved readability.

We also update the test_categorify_lists_with_start_index() test
in tests/unit/ops/test_ops.py to cover several start_index
values.
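
For reference, here is a minimal usage sketch of the new argument. The ops.Categorify(start_index=...), the >> operator, and nvt.Workflow(...) mirror the updated unit test; the sample data, the nvt.Dataset/fit_transform/to_ddf().compute() calls, and the codes in the comments are illustrative, following the example in the updated docstring and test assertions rather than code in this diff.

import nvtabular as nvt
from nvtabular import dispatch, ops

# Illustrative list-column data, consistent with the codes asserted in the test below.
df = dispatch._make_df(
    {
        "Authors": [["User_A"], ["User_A", "User_E"], ["User_B", "User_C"], ["User_C"]],
        "Post": [1, 2, 3, 4],
    }
)

# With the default start_index of 1, "Authors" is encoded as [[1], [1, 4], [3, 2], [2]];
# with start_index=16 the same entries become [[16], [16, 19], [18, 17], [17]].
cat_features = ["Authors"] >> ops.Categorify(start_index=16)
workflow = nvt.Workflow(cat_features + ["Post"])
df_out = workflow.fit_transform(nvt.Dataset(df)).to_ddf().compute()
print(df_out["Authors"])

The offset itself reduces to adding (start_index - 1) to the existing 1-based codes, which is what the new logic at the end of _encode() implements.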
Adam Lesnikowski committed Sep 3, 2021
1 parent 642e5f5 commit 8cfdd83
Showing 2 changed files with 50 additions and 6 deletions.
48 changes: 46 additions & 2 deletions nvtabular/ops/categorify.py
@@ -174,7 +174,7 @@ class Categorify(StatOperator):
value will be `max_size - num_buckets -1`. Setting the max_size param means that
freq_threshold should not be given. If the num_buckets parameter is set, it must be
smaller than the max_size value.
start_index: int, default 0
start_index: int, default 1
The start index where Categorify will begin to translate dataframe entries
into integer values. For instance, if our original translated dataframe entries appear
as [[1], [1, 4], [3, 2], [2]], then with a start_index of 16, Categorify will now be
translating these same entries to [[16], [16, 19], [18, 17], [17]].
@@ -197,7 +197,7 @@ def __init__(
num_buckets=None,
vocabs=None,
max_size=0,
start_index=0,
start_index=1,
):

# We need to handle three types of encoding here:
@@ -250,6 +250,7 @@ def __init__(
self.cat_cache = cat_cache
self.encode_type = encode_type
self.search_sorted = search_sorted
self.start_index = start_index

if self.search_sorted and self.freq_threshold:
raise ValueError(
@@ -445,6 +446,7 @@ def transform(self, col_selector: ColumnSelector, df: DataFrameType) -> DataFrameType:
cat_names=cat_names,
max_size=self.max_size,
dtype=self.dtype,
start_index=self.start_index,
)
except Exception as e:
raise RuntimeError(f"Failed to categorical encode column {name}") from e
@@ -859,6 +861,20 @@ def _write_gb_stats(dfs, base_path, col_selector: ColumnSelector, options: FitOptions):

@annotate("write_uniques", color="green", domain="nvt_python")
def _write_uniques(dfs, base_path, col_selector: ColumnSelector, options: FitOptions):
"""Writes out a dataframe to a parquet file.
Args:
dfs (DataFrame): [description]
base_path (string): [description]
col_selector (ColumnSelector): [description]
options (FitOptions): [description]
Raises:
ValueError: [description]
Returns:
string: the path to the output parquet file.
"""
if options.concat_groups and len(col_selector) > 1:
col_selector = ColumnSelector([_make_name(*col_selector.names, sep=options.name_sep)])

@@ -1052,7 +1068,30 @@ def _encode(
cat_names=None,
max_size=0,
dtype=None,
start_index=1
):
"""The _encode method is responsible for transforming a dataframe (taking the written
out vocabulary file and looking up values to translate from say string inputs to numeric
outputs)
Args:
name ([type]): [description]
storage_name ([type]): [description]
path ([type]): [description]
df ([type]): [description]
cat_cache ([type]): [description]
na_sentinel (int, optional): [description]. Defaults to -1.
freq_threshold (int, optional): [description]. Defaults to 0.
search_sorted (bool, optional): [description]. Defaults to False.
buckets ([type], optional): [description]. Defaults to None.
encode_type (str, optional): [description]. Defaults to "joint".
cat_names ([type], optional): [description]. Defaults to None.
max_size (int, optional): [description]. Defaults to 0.
dtype ([type], optional): [description]. Defaults to None.
Returns:
[type]: labels
"""
if isinstance(buckets, int):
buckets = {name: buckets for name in cat_names}
# this is to apply freq_hashing logic
@@ -1142,6 +1181,11 @@ def _encode(
elif dtype:
labels = labels.astype(dtype, copy=False)

# Apply the start_index offset. The base encoding starts at 1, so shifting by
# (start_index - 1) leaves labels unchanged for the default start_index of 1;
# the element-wise add preserves the array/Series type of `labels`.
labels = labels + (start_index - 1)
return labels


8 changes: 4 additions & 4 deletions tests/unit/ops/test_ops.py
@@ -507,7 +507,7 @@ def test_categorify_lists(tmpdir, freq_threshold, cpu, dtype, vocabs):
@pytest.mark.parametrize("cpu", _CPU)
@pytest.mark.parametrize("dtype", [None, np.int32, np.int64])
@pytest.mark.parametrize("vocabs", [None, pd.DataFrame({"Authors": [f"User_{x}" for x in "ACBE"]})])
@pytest.mark.parametrize("start_index", [0, 2, 16])
@pytest.mark.parametrize("start_index", [1, 2, 16])
def test_categorify_lists_with_start_index(tmpdir, cpu, dtype, vocabs, start_index):
df = dispatch._make_df(
{
@@ -520,7 +520,7 @@ def test_categorify_lists_with_start_index(tmpdir, cpu, dtype, vocabs, start_index):
label_name = ["Post"]

cat_features = cat_names >> ops.Categorify(
out_path=str(tmpdir), dtype=dtype, vocabs=vocabs
out_path=str(tmpdir), dtype=dtype, vocabs=vocabs, start_index=start_index
)

workflow = nvt.Workflow(cat_features + label_name)
@@ -531,11 +531,11 @@ def test_categorify_lists_with_start_index(tmpdir, cpu, dtype, vocabs, start_index):
else:
compare = df_out["Authors"].to_arrow().to_pylist()

if start_index == 0:
if start_index == 1:
assert compare == [[1], [1, 4], [3, 2], [2]]

if start_index == 2:
assert compare == [[3], [3, 6], [5, 4], [4]]
assert compare == [[2], [2, 5], [4, 3], [3]]

if start_index == 16:
assert compare == [[16], [16, 19], [18, 17], [17]]