From 47d65701378cc18972ff53569d7039647cc9de2c Mon Sep 17 00:00:00 2001 From: Marc Romeyn Date: Mon, 12 Jul 2021 14:29:52 +0200 Subject: [PATCH 1/9] Allow to pass in vocabs in Categorify to fix make_feature_column_workflow --- .gitignore | 2 + nvtabular/dispatch.py | 4 + .../tensorflow/feature_column_utils.py | 23 +---- nvtabular/ops/categorify.py | 83 ++++++++++++++----- tests/unit/test_ops.py | 10 +-- tests/unit/test_tf_feature_columns.py | 24 ++++++ 6 files changed, 99 insertions(+), 47 deletions(-) create mode 100644 tests/unit/test_tf_feature_columns.py diff --git a/.gitignore b/.gitignore index 3594bfc2f58..422a3064cec 100644 --- a/.gitignore +++ b/.gitignore @@ -71,3 +71,5 @@ ipython_config.py .dmypy.json dmypy.json +# PyCharm +.idea \ No newline at end of file diff --git a/nvtabular/dispatch.py b/nvtabular/dispatch.py index 43aa8b4b8b3..25814a07ae9 100644 --- a/nvtabular/dispatch.py +++ b/nvtabular/dispatch.py @@ -101,6 +101,10 @@ def _is_cpu_object(x): return isinstance(x, (pd.DataFrame, pd.Series)) +def is_series_or_dataframe_object(maybe_series_or_df): + return _is_series_object(maybe_series_or_df) or _is_dataframe_object(maybe_series_or_df) + + def _hex_to_int(s, dtype=None): def _pd_convert_hex(x): if pd.isnull(x): diff --git a/nvtabular/framework_utils/tensorflow/feature_column_utils.py b/nvtabular/framework_utils/tensorflow/feature_column_utils.py index 1b9658169c0..6099ab66340 100644 --- a/nvtabular/framework_utils/tensorflow/feature_column_utils.py +++ b/nvtabular/framework_utils/tensorflow/feature_column_utils.py @@ -13,10 +13,9 @@ # limitations under the License. # -import os import warnings -import cudf +import pandas as pd import tensorflow as tf from tensorflow.python.feature_column import feature_column_v2 as fc @@ -227,7 +226,7 @@ def _get_parents(column): features += features_replaced_buckets if len(categorifies) > 0: - features += categorifies.keys() >> Categorify() + features += categorifies.keys() >> Categorify(vocabs=pd.DataFrame(categorifies)) if len(hashes) > 0: features += hashes.keys() >> HashBucket(hashes) @@ -282,22 +281,4 @@ def _get_parents(column): workflow = nvt.Workflow(features) - # create stats for Categorify op if we need it - if len(categorifies) > 0: - if category_dir is None: - category_dir = "/tmp/categories" # nosec - if not os.path.exists(category_dir): - os.makedirs(category_dir) - - stats = {"categories": {}} - for feature_name, categories in categorifies.items(): - categories.insert(0, None) - df = cudf.DataFrame({feature_name: categories}) - - save_path = os.path.join(category_dir, f"unique.{feature_name}.parquet") - df.to_parquet(save_path) - stats["categories"][feature_name] = save_path - - workflow.stats = stats - return workflow, numeric_columns + new_feature_columns diff --git a/nvtabular/ops/categorify.py b/nvtabular/ops/categorify.py index efe6f09b116..e720332455e 100644 --- a/nvtabular/ops/categorify.py +++ b/nvtabular/ops/categorify.py @@ -15,6 +15,7 @@ import os import warnings +from copy import deepcopy from dataclasses import dataclass from operator import getitem from typing import Optional, Union @@ -39,6 +40,7 @@ _flatten_list_column, _from_host, _hash_series, + _is_dataframe_object, _is_list_dtype, _make_df, _parquet_writer_dispatch, @@ -199,6 +201,7 @@ def __init__( name_sep="_", search_sorted=False, num_buckets=None, + vocabs=None, max_size=0, ): @@ -239,7 +242,7 @@ def __init__( if encode_type not in ("joint", "combo"): raise ValueError(f"encode_type={encode_type} not supported.") - # Other self-explanatory intialization + # Other self-explanatory initialization super().__init__() self.freq_threshold = freq_threshold or 0 self.out_path = out_path or "./" @@ -250,7 +253,6 @@ def __init__( self.cat_cache = cat_cache self.encode_type = encode_type self.search_sorted = search_sorted - self.categories = {} if self.search_sorted and self.freq_threshold: raise ValueError( @@ -285,6 +287,9 @@ def __init__( "with this num_buckets setting!" ) + self.vocabs = self.process_vocabs(vocabs if vocabs is not None else {}) + self.categories = deepcopy(self.vocabs) + @annotate("Categorify_fit", color="darkgreen", domain="nvt_python") def fit(self, columns: ColumnNames, ddf: dd.DataFrame): # User passed in a list of column groups. We need to figure out @@ -320,23 +325,11 @@ def fit(self, columns: ColumnNames, ddf: dd.DataFrame): warnings.warn("Cannot use `search_sorted=True` for pandas-backed data.") # convert tuples to lists - columns = [list(c) if isinstance(c, tuple) else c for c in columns] - dsk, key = _category_stats( - ddf, - FitOptions( - columns, - [], - [], - self.out_path, - self.freq_threshold, - self.tree_width, - self.on_host, - concat_groups=self.encode_type == "joint", - name_sep=self.name_sep, - max_size=self.max_size, - num_buckets=self.num_buckets, - ), - ) + cols_with_vocabs = list(self.categories.keys()) + columns = [ + list(c) if isinstance(c, tuple) else c for c in columns if c not in cols_with_vocabs + ] + dsk, key = _category_stats(ddf, self._create_fit_options_from_columns(columns)) return Delayed(key, dsk) def fit_finalize(self, categories): @@ -344,7 +337,57 @@ def fit_finalize(self, categories): self.categories[col] = categories[col] def clear(self): - self.categories = {} + self.categories = deepcopy(self.vocabs) + + def process_vocabs(self, vocabs): + categories = {} + + if _is_dataframe_object(vocabs): + fit_options = self._create_fit_options_from_columns(list(vocabs.columns)) + base_path = os.path.join(self.out_path, fit_options.stat_name) + os.makedirs(base_path, exist_ok=True) + for col in list(vocabs.columns): + col_df = vocabs[[col]] + if col_df[col].iloc[0] is not None: + if isinstance(col_df, pd.DataFrame): + col_df = pd.DataFrame( + {col: pd.concat([pd.Series([None]), col_df[col]]).reset_index()[0]} + ) + else: + import cudf + + col_df = cudf.DataFrame( + {col: cudf.concat([cudf.Series([None]), col_df[col]]).reset_index()[0]} + ) + + save_path = os.path.join(base_path, f"unique.{col}.parquet") + col_df.to_parquet(save_path) + categories[col] = save_path + elif isinstance(vocabs, dict) and all(isinstance(v, str) for v in vocabs.values()): + categories = vocabs + else: + error = """Unrecognized vocab type, + please provide either a dictionary with paths to a parquet files + or a DataFrame that contains the vocabulary per column. + """ + raise ValueError(error) + + return categories + + def _create_fit_options_from_columns(self, columns) -> "FitOptions": + return FitOptions( + columns, + [], + [], + self.out_path, + self.freq_threshold, + self.tree_width, + self.on_host, + concat_groups=self.encode_type == "joint", + name_sep=self.name_sep, + max_size=self.max_size, + num_buckets=self.num_buckets, + ) def set_storage_path(self, new_path, copy=False): self.categories = _copy_storage(self.categories, self.out_path, new_path, copy=copy) diff --git a/tests/unit/test_ops.py b/tests/unit/test_ops.py index 85f1af06cf3..c50f6c0c4f8 100644 --- a/tests/unit/test_ops.py +++ b/tests/unit/test_ops.py @@ -107,7 +107,6 @@ def test_target_encode(tmpdir, cat_groups, kfold, fold_seed, cpu): @pytest.mark.parametrize("npartitions", [1, 2]) @pytest.mark.parametrize("cpu", [True, False]) def test_target_encode_multi(tmpdir, npartitions, cpu): - cat_1 = np.asarray(["baaaa"] * 12) cat_2 = np.asarray(["baaaa"] * 6 + ["bbaaa"] * 3 + ["bcaaa"] * 3) num_1 = np.asarray([1, 1, 2, 2, 2, 1, 1, 5, 4, 4, 4, 4]) @@ -445,7 +444,8 @@ def test_lambdaop_misalign(cpu): @pytest.mark.parametrize("freq_threshold", [0, 1, 2]) @pytest.mark.parametrize("cpu", [False, True]) @pytest.mark.parametrize("dtype", [None, np.int32, np.int64]) -def test_categorify_lists(tmpdir, freq_threshold, cpu, dtype): +@pytest.mark.parametrize("vocabs", [None, pd.DataFrame({"Authors": [f"User_{x}" for x in "ABCE"]})]) +def test_categorify_lists(tmpdir, freq_threshold, cpu, dtype, vocabs): df = cudf.DataFrame( { "Authors": [["User_A"], ["User_A", "User_E"], ["User_B", "User_C"], ["User_C"]], @@ -457,7 +457,7 @@ def test_categorify_lists(tmpdir, freq_threshold, cpu, dtype): label_name = ["Post"] cat_features = cat_names >> ops.Categorify( - out_path=str(tmpdir), freq_threshold=freq_threshold, dtype=dtype + out_path=str(tmpdir), freq_threshold=freq_threshold, dtype=dtype, vocabs=vocabs ) workflow = nvt.Workflow(cat_features + label_name) @@ -471,7 +471,7 @@ def test_categorify_lists(tmpdir, freq_threshold, cpu, dtype): assert df_out["Authors"].dtype == cudf.core.dtypes.ListDtype(dtype if dtype else "int64") compare = df_out["Authors"].to_arrow().to_pylist() - if freq_threshold < 2: + if freq_threshold < 2 or vocabs is not None: assert compare == [[1], [1, 4], [2, 3], [3]] else: assert compare == [[1], [1, 0], [0, 2], [2]] @@ -767,7 +767,6 @@ def test_joingroupby_dependency(tmpdir, cpu): @pytest.mark.parametrize("cpu", [True, False]) @pytest.mark.parametrize("groups", [[["Author", "Engaging-User"]], "Author"]) def test_joingroupby_multi(tmpdir, groups, cpu): - df = pd.DataFrame( { "Author": ["User_A", "User_A", "User_A", "User_B"], @@ -820,7 +819,6 @@ def test_joingroupby_multi(tmpdir, groups, cpu): @pytest.mark.parametrize("cpu", [True, False]) @pytest.mark.parametrize("drop_duplicates", [True, False]) def test_join_external(tmpdir, df, dataset, engine, kind_ext, cache, how, cpu, drop_duplicates): - # Define "external" table shift = 100 df_ext = df[["id"]].copy().sort_values("id") diff --git a/tests/unit/test_tf_feature_columns.py b/tests/unit/test_tf_feature_columns.py new file mode 100644 index 00000000000..985217f81df --- /dev/null +++ b/tests/unit/test_tf_feature_columns.py @@ -0,0 +1,24 @@ +import pytest + +tf = pytest.importorskip("tensorflow") +nvtf = pytest.importorskip("nvtabular.framework_utils.tensorflow") + + +def test_feature_column_utils(): + cols = [ + tf.feature_column.embedding_column( + tf.feature_column.categorical_column_with_vocabulary_list( + "vocab_1", ["a", "b", "c", "d"] + ), + 16, + ), + tf.feature_column.embedding_column( + tf.feature_column.categorical_column_with_vocabulary_list( + "vocab_2", ["1", "2", "3", "4"] + ), + 32, + ), + ] + + workflow, _ = nvtf.make_feature_column_workflow(cols, "target") + assert workflow.column_group.columns == ["target", "vocab_1", "vocab_2"] From 110d1df66c251a8ccff8ee2899d4ccb42a140e53 Mon Sep 17 00:00:00 2001 From: Marc Romeyn Date: Wed, 14 Jul 2021 10:29:06 +0200 Subject: [PATCH 2/9] Update .gitignore Co-authored-by: Karl Higley --- .gitignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 422a3064cec..2a3ef1fd806 100644 --- a/.gitignore +++ b/.gitignore @@ -72,4 +72,4 @@ ipython_config.py dmypy.json # PyCharm -.idea \ No newline at end of file +.idea From 4d32576175eab491f12be8bd7592166fc9dcbaf8 Mon Sep 17 00:00:00 2001 From: Marc Romeyn Date: Wed, 14 Jul 2021 11:01:23 +0200 Subject: [PATCH 3/9] Addressing PR comments --- nvtabular/dispatch.py | 17 ++++++++++ nvtabular/ops/categorify.py | 67 ++++++++++++++----------------------- 2 files changed, 42 insertions(+), 42 deletions(-) diff --git a/nvtabular/dispatch.py b/nvtabular/dispatch.py index 25814a07ae9..ffadef20f8e 100644 --- a/nvtabular/dispatch.py +++ b/nvtabular/dispatch.py @@ -320,11 +320,28 @@ def _make_df(_like_df=None, device=None): return pd.DataFrame(_like_df) elif isinstance(_like_df, (cudf.DataFrame, cudf.Series)): return cudf.DataFrame(_like_df) + elif isinstance(_like_df, dict) and len(_like_df) > 0: + is_pandas = all(isinstance(v, pd.Series) for v in _like_df.values()) + + return pd.DataFrame(_like_df) if is_pandas else cudf.DataFrame(_like_df) if device == "cpu": return pd.DataFrame() return cudf.DataFrame() +def _add_to_series(series, to_add, prepend=True): + if isinstance(series, pd.Series): + series_2 = pd.Series(to_add) + elif isinstance(series, cudf.Series): + series_2 = cudf.Series(to_add) + else: + raise ValueError("Unrecognized series, please provide either a pandas a cudf series") + + to_concat = [series, series_2] if prepend else [series_2, series] + + return _concat(to_concat) + + def _detect_format(data): """Utility to detect the format of `data`""" from nvtabular import Dataset diff --git a/nvtabular/ops/categorify.py b/nvtabular/ops/categorify.py index e720332455e..68ba153ffd2 100644 --- a/nvtabular/ops/categorify.py +++ b/nvtabular/ops/categorify.py @@ -33,21 +33,8 @@ from fsspec.core import get_fs_token_paths from pyarrow import parquet as pq -from nvtabular.dispatch import ( - DataFrameType, - _arange, - _encode_list_column, - _flatten_list_column, - _from_host, - _hash_series, - _is_dataframe_object, - _is_list_dtype, - _make_df, - _parquet_writer_dispatch, - _read_parquet_dispatch, - _series_has_nulls, - annotate, -) +from nvtabular import dispatch +from nvtabular.dispatch import DataFrameType, annotate from nvtabular.worker import fetch_table_data, get_worker_cache from .operator import ColumnNames, Operator @@ -241,6 +228,8 @@ def __init__( # Only support two kinds of multi-column encoding if encode_type not in ("joint", "combo"): raise ValueError(f"encode_type={encode_type} not supported.") + if encode_type == "joint" and vocabs: + raise ValueError("Passing in vocabs is not supported with a joint encoding.") # Other self-explanatory initialization super().__init__() @@ -287,7 +276,9 @@ def __init__( "with this num_buckets setting!" ) - self.vocabs = self.process_vocabs(vocabs if vocabs is not None else {}) + self.vocabs = {} + if vocabs is not None: + self.vocabs = self.process_vocabs(vocabs) self.categories = deepcopy(self.vocabs) @annotate("Categorify_fit", color="darkgreen", domain="nvt_python") @@ -342,23 +333,15 @@ def clear(self): def process_vocabs(self, vocabs): categories = {} - if _is_dataframe_object(vocabs): + if dispatch._is_dataframe_object(vocabs): fit_options = self._create_fit_options_from_columns(list(vocabs.columns)) base_path = os.path.join(self.out_path, fit_options.stat_name) os.makedirs(base_path, exist_ok=True) for col in list(vocabs.columns): col_df = vocabs[[col]] if col_df[col].iloc[0] is not None: - if isinstance(col_df, pd.DataFrame): - col_df = pd.DataFrame( - {col: pd.concat([pd.Series([None]), col_df[col]]).reset_index()[0]} - ) - else: - import cudf - - col_df = cudf.DataFrame( - {col: cudf.concat([cudf.Series([None]), col_df[col]]).reset_index()[0]} - ) + vals = {col: dispatch._add_to_series(col_df[col], [None]).reset_index()[0]} + col_df = dispatch._make_df(vals) save_path = os.path.join(base_path, f"unique.{col}.parquet") col_df.to_parquet(save_path) @@ -468,7 +451,7 @@ def get_embedding_sizes(self, columns): def inference_initialize(self, columns: ColumnNames, model_config: dict) -> Optional[Operator]: # on the first transform call we load up categories from disk, which can # take multiple seconds. preload this data by running an empty dataframe through - df = _make_df() + df = dispatch._make_df() for column in columns: df[column] = [] self.transform(columns, df) @@ -529,7 +512,7 @@ def get_embedding_sizes(source, output_dtypes=None): for column in output: dtype = output_dtypes.get(column) - if dtype and _is_list_dtype(dtype): + if dtype and dispatch._is_list_dtype(dtype): # multi hot so remove from output and add to multihot multihot_columns.add(column) # TODO: returning differnt return types like this (based off the presence @@ -673,7 +656,7 @@ def _top_level_groupby(df, options: FitOptions): # (flattening provides better cudf/pd support) if _is_list_col(cat_col_group, df_gb): # handle list columns by encoding the list values - df_gb = _flatten_list_column(df_gb[cat_col_group[0]]) + df_gb = dispatch._flatten_list_column(df_gb[cat_col_group[0]]) # NOTE: groupby(..., dropna=False) requires pandas>=1.1.0 gb = df_gb.groupby(cat_col_group, dropna=False).agg(agg_dict) @@ -714,7 +697,7 @@ def _mid_level_groupby(dfs, col_group, freq_limit_val, options: FitOptions): # Construct gpu DataFrame from pyarrow data. # `on_host=True` implies gpu-backed data. df = pa.concat_tables(dfs, promote=True) - df = _from_host(df) + df = dispatch._from_host(df) else: df = _concat(dfs, ignore_index=True) groups = df.groupby(col_group, dropna=False) @@ -796,7 +779,7 @@ def _write_gb_stats(dfs, base_path, col_group, options: FitOptions): if not options.on_host and len(dfs): # Want first non-empty df for schema (if there are any) _d = next((df for df in dfs if len(df)), dfs[0]) - pwriter = _parquet_writer_dispatch(_d, path=path, compression=None) + pwriter = dispatch._parquet_writer_dispatch(_d, path=path, compression=None) # Loop over dfs and append to file # TODO: For high-cardinality columns, should support @@ -837,7 +820,7 @@ def _write_uniques(dfs, base_path, col_group, options): # Construct gpu DataFrame from pyarrow data. # `on_host=True` implies gpu-backed data. df = pa.concat_tables(dfs, promote=True) - df = _from_host(df) + df = dispatch._from_host(df) else: df = _concat(dfs, ignore_index=True) rel_path = "unique.%s.parquet" % (_make_name(*col_group, sep=options.name_sep)) @@ -866,7 +849,7 @@ def _write_uniques(dfs, base_path, col_group, options): if nlargest < len(df): df = df.nlargest(n=nlargest, columns=name_count) - if not _series_has_nulls(df[col]): + if not dispatch._series_has_nulls(df[col]): nulls_missing = True new_cols[col] = _concat( [df._constructor_sliced([None], dtype=df[col].dtype), df[col]], @@ -1023,7 +1006,7 @@ def _encode( selection_r = name if isinstance(name, list) else [storage_name] list_col = _is_list_col(selection_l, df) if path: - read_pq_func = _read_parquet_dispatch(df) + read_pq_func = dispatch._read_parquet_dispatch(df) if cat_cache is not None: cat_cache = ( cat_cache if isinstance(cat_cache, str) else cat_cache.get(storage_name, "disk") @@ -1055,10 +1038,10 @@ def _encode( if not search_sorted: if list_col: - codes = _flatten_list_column(df[selection_l[0]]) - codes["order"] = _arange(len(codes), like_df=df) + codes = dispatch._flatten_list_column(df[selection_l[0]]) + codes["order"] = dispatch._arange(len(codes), like_df=df) else: - codes = type(df)({"order": _arange(len(df), like_df=df)}, index=df.index) + codes = type(df)({"order": dispatch._arange(len(df), like_df=df)}, index=df.index) for c in selection_l: codes[c] = df[c].copy() if buckets and storage_name in buckets: @@ -1098,7 +1081,7 @@ def _encode( labels[labels >= len(value[selection_r])] = na_sentinel if list_col: - labels = _encode_list_column(df[selection_l[0]], labels, dtype=dtype) + labels = dispatch._encode_list_column(df[selection_l[0]], labels, dtype=dtype) elif dtype: labels = labels.astype(dtype, copy=False) @@ -1131,7 +1114,7 @@ def _get_multicolumn_names(column_groups, df_columns, name_sep): def _is_list_col(column_group, df): - has_lists = any(_is_list_dtype(df[col]) for col in column_group) + has_lists = any(dispatch._is_list_dtype(df[col]) for col in column_group) if has_lists and len(column_group) != 1: raise ValueError("Can't categorical encode multiple list columns") return has_lists @@ -1140,7 +1123,7 @@ def _is_list_col(column_group, df): def _hash_bucket(df, num_buckets, col, encode_type="joint"): if encode_type == "joint": nb = num_buckets[col[0]] - encoded = _hash_series(df[col[0]]) % nb + encoded = dispatch._hash_series(df[col[0]]) % nb elif encode_type == "combo": if len(col) > 1: name = _make_name(*tuple(col), sep="_") @@ -1149,7 +1132,7 @@ def _hash_bucket(df, num_buckets, col, encode_type="joint"): nb = num_buckets[name] val = 0 for column in col: - val ^= _hash_series(df[column]) # or however we want to do this aggregation + val ^= dispatch._hash_series(df[column]) # or however we want to do this aggregation val = val % nb encoded = val return encoded From 3ffc73f7038fc7fdc32657d0fb34e1167516e8aa Mon Sep 17 00:00:00 2001 From: Marc Romeyn Date: Fri, 16 Jul 2021 10:02:32 +0200 Subject: [PATCH 4/9] Quick fix to try to make the tests pass --- nvtabular/ops/categorify.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nvtabular/ops/categorify.py b/nvtabular/ops/categorify.py index 68ba153ffd2..55ba2903d40 100644 --- a/nvtabular/ops/categorify.py +++ b/nvtabular/ops/categorify.py @@ -228,7 +228,7 @@ def __init__( # Only support two kinds of multi-column encoding if encode_type not in ("joint", "combo"): raise ValueError(f"encode_type={encode_type} not supported.") - if encode_type == "joint" and vocabs: + if encode_type == "joint" and vocabs is not None: raise ValueError("Passing in vocabs is not supported with a joint encoding.") # Other self-explanatory initialization From c5a82bef60fe1d123f389440f5c06ba383c03b9a Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Fri, 16 Jul 2021 15:05:07 -0700 Subject: [PATCH 5/9] Update cpu-ci.yml --- .github/workflows/cpu-ci.yml | 92 ++++++++++++++++++------------------ 1 file changed, 46 insertions(+), 46 deletions(-) diff --git a/.github/workflows/cpu-ci.yml b/.github/workflows/cpu-ci.yml index a4466bbe3c1..2accb428cfe 100644 --- a/.github/workflows/cpu-ci.yml +++ b/.github/workflows/cpu-ci.yml @@ -1,46 +1,46 @@ -name: CPU CI - -on: - push: - branches: [ main ] - pull_request: - branches: [ main ] - -jobs: - build: - runs-on: ${{ matrix.os }} - strategy: - matrix: - python-version: [3.8] - os: [ubuntu-latest] - - steps: - - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - name: Install Ubuntu packages - run: | - sudo apt-get update -y - sudo apt-get install -y protobuf-compiler - - name: Install dependencies - run: | - python -m pip install --upgrade pip setuptools wheel - python -m pip install -r requirements.txt - python -m pip install -r requirements-dev.txt - - name: Lint with flake8 - run: | - flake8 . - - name: Lint with black - run: | - black --check . - - name: Lint with isort - run: | - isort -c . - - name: Build - run: | - python -m pip install -e . - - name: Run unittests - run: | - python -m pytest -svv tests/unit/test_cpu_workflow.py +name: CPU CI + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +jobs: + build: + runs-on: ${{ matrix.os }} + strategy: + matrix: + python-version: [3.8] + os: [ubuntu-latest] + + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install Ubuntu packages + run: | + sudo apt-get update -y + sudo apt-get install -y protobuf-compiler + - name: Install dependencies + run: | + python -m pip install --upgrade pip setuptools wheel + python -m pip install -r requirements.txt + python -m pip install -r requirements-dev.txt + - name: Lint with flake8 + run: | + flake8 . + - name: Lint with black + run: | + black --check . + - name: Lint with isort + run: | + isort -c . + - name: Build + run: | + python setup.py develop + - name: Run unittests + run: | + python -m pytest -svv tests/unit/test_cpu_workflow.py From 1ef8e4361074cc0a5a3faaaf8ac5efb0326b38f7 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Mon, 19 Jul 2021 11:48:52 -0700 Subject: [PATCH 6/9] Update nvtabular/ops/categorify.py --- nvtabular/ops/categorify.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nvtabular/ops/categorify.py b/nvtabular/ops/categorify.py index 55ba2903d40..1450aba2598 100644 --- a/nvtabular/ops/categorify.py +++ b/nvtabular/ops/categorify.py @@ -228,8 +228,8 @@ def __init__( # Only support two kinds of multi-column encoding if encode_type not in ("joint", "combo"): raise ValueError(f"encode_type={encode_type} not supported.") - if encode_type == "joint" and vocabs is not None: - raise ValueError("Passing in vocabs is not supported with a joint encoding.") + if encode_type == "combo" and vocabs is not None: + raise ValueError("Passing in vocabs is not supported with a combo encoding.") # Other self-explanatory initialization super().__init__() From f6183acc644a7d041a9bb2ebccd730fb4f81ef9c Mon Sep 17 00:00:00 2001 From: Marc Romeyn Date: Tue, 20 Jul 2021 16:47:32 +0200 Subject: [PATCH 7/9] Fixing prepend in dispatch._add_to_series --- nvtabular/dispatch.py | 8 ++++---- nvtabular/ops/categorify.py | 3 ++- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/nvtabular/dispatch.py b/nvtabular/dispatch.py index d8fb9a0c5de..bf37f309f1e 100644 --- a/nvtabular/dispatch.py +++ b/nvtabular/dispatch.py @@ -335,15 +335,15 @@ def _make_df(_like_df=None, device=None): def _add_to_series(series, to_add, prepend=True): if isinstance(series, pd.Series): - series_2 = pd.Series(to_add) + series_to_add = pd.Series(to_add) elif isinstance(series, cudf.Series): - series_2 = cudf.Series(to_add) + series_to_add = cudf.Series(to_add) else: raise ValueError("Unrecognized series, please provide either a pandas a cudf series") - to_concat = [series, series_2] if prepend else [series_2, series] + series_to_concat = [series_to_add, series] if prepend else [series, series_to_add] - return _concat(to_concat) + return _concat(series_to_concat) def _detect_format(data): diff --git a/nvtabular/ops/categorify.py b/nvtabular/ops/categorify.py index 1450aba2598..fdc79506b8a 100644 --- a/nvtabular/ops/categorify.py +++ b/nvtabular/ops/categorify.py @@ -340,7 +340,8 @@ def process_vocabs(self, vocabs): for col in list(vocabs.columns): col_df = vocabs[[col]] if col_df[col].iloc[0] is not None: - vals = {col: dispatch._add_to_series(col_df[col], [None]).reset_index()[0]} + with_empty = dispatch._add_to_series(col_df[col], [None]).reset_index()[0] + vals = {col: with_empty} col_df = dispatch._make_df(vals) save_path = os.path.join(base_path, f"unique.{col}.parquet") From 419b07f18b99926b9227a14a483b21758dbb4cae Mon Sep 17 00:00:00 2001 From: Marc Romeyn Date: Tue, 20 Jul 2021 17:02:41 +0200 Subject: [PATCH 8/9] Fixing flake8 --- nvtabular/ops/categorify.py | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/nvtabular/ops/categorify.py b/nvtabular/ops/categorify.py index 2208eafbd31..eb672710f6d 100644 --- a/nvtabular/ops/categorify.py +++ b/nvtabular/ops/categorify.py @@ -34,19 +34,7 @@ from pyarrow import parquet as pq from nvtabular import dispatch -from nvtabular.dispatch import ( - DataFrameType, - _arange, - _encode_list_column, - _flatten_list_column, - _from_host, - _hash_series, - _is_list_dtype, - _parquet_writer_dispatch, - _read_parquet_dispatch, - _series_has_nulls, - annotate, -) +from nvtabular.dispatch import DataFrameType, annotate from nvtabular.worker import fetch_table_data, get_worker_cache from .operator import ColumnNames, Operator From 11f5b0379ff6074590a18147d6732f1644d1ca21 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Tue, 20 Jul 2021 14:59:24 -0700 Subject: [PATCH 9/9] Fix to match frequency sorting categorify changes --- tests/unit/test_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_ops.py b/tests/unit/test_ops.py index d3d7c8a21e1..0ef32660ab8 100644 --- a/tests/unit/test_ops.py +++ b/tests/unit/test_ops.py @@ -444,7 +444,7 @@ def test_lambdaop_misalign(cpu): @pytest.mark.parametrize("freq_threshold", [0, 1, 2]) @pytest.mark.parametrize("cpu", [False, True]) @pytest.mark.parametrize("dtype", [None, np.int32, np.int64]) -@pytest.mark.parametrize("vocabs", [None, pd.DataFrame({"Authors": [f"User_{x}" for x in "ABCE"]})]) +@pytest.mark.parametrize("vocabs", [None, pd.DataFrame({"Authors": [f"User_{x}" for x in "ACBE"]})]) def test_categorify_lists(tmpdir, freq_threshold, cpu, dtype, vocabs): df = cudf.DataFrame( {