Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
PyTorch Tabular integrations (#1559)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
  • Loading branch information
3 people committed May 14, 2023
1 parent 59dab9b commit 14c2755
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 15 deletions.
1 change: 1 addition & 0 deletions src/flash/core/integrations/pytorch_tabular/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def from_task(
"categorical_dim": len(categorical_fields),
"continuous_dim": num_features - len(categorical_fields),
"output_dim": output_dim,
"embedded_cat_dim": sum([embd_dim for _, embd_dim in embedding_sizes]),
}
return cls(
task_type,
Expand Down
4 changes: 3 additions & 1 deletion src/flash/core/integrations/pytorch_tabular/backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
AutoIntConfig,
CategoryEmbeddingModelConfig,
FTTransformerConfig,
GatedAdditiveTreeEnsembleConfig,
NodeConfig,
TabNetModelConfig,
TabTransformerConfig,
Expand Down Expand Up @@ -88,8 +89,9 @@ def load_pytorch_tabular(
AutoIntConfig,
NodeConfig,
CategoryEmbeddingModelConfig,
GatedAdditiveTreeEnsembleConfig,
],
["tabnet", "tabtransformer", "fttransformer", "autoint", "node", "category_embedding"],
["tabnet", "tabtransformer", "fttransformer", "autoint", "node", "category_embedding", "gate"],
):
PYTORCH_TABULAR_BACKBONES(
functools.partial(load_pytorch_tabular, model_config_class),
Expand Down
3 changes: 1 addition & 2 deletions tests/tabular/classification/test_data_model_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@
("fttransformer", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
("autoint", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
("node", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
# ("category_embedding", # todo: seems to be bug in tabular
# {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
("category_embedding", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
# No categorical / numerical fields
("tabnet", {"categorical_fields": ["category"]}),
("tabnet", {"numerical_fields": ["scalar_a", "scalar_b"]}),
Expand Down
6 changes: 3 additions & 3 deletions tests/tabular/classification/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class TestTabularClassifier(TaskTester):
{"backbone": "fttransformer"},
{"backbone": "autoint"},
{"backbone": "node"},
# {"backbone": "category_embedding"}, # todo: seems to be bug in tabular
{"backbone": "category_embedding"},
],
)
],
Expand All @@ -68,7 +68,7 @@ class TestTabularClassifier(TaskTester):
{"backbone": "fttransformer"},
{"backbone": "autoint"},
{"backbone": "node"},
# {"backbone": "category_embedding"}, # todo: seems to be bug in tabular
{"backbone": "category_embedding"},
],
)
],
Expand All @@ -81,7 +81,7 @@ class TestTabularClassifier(TaskTester):
{"backbone": "fttransformer"},
{"backbone": "autoint"},
{"backbone": "node"},
# {"backbone": "category_embedding"}, # todo: seems to be bug in tabular
{"backbone": "category_embedding"},
],
)
],
Expand Down
9 changes: 3 additions & 6 deletions tests/tabular/regression/test_data_model_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@
("fttransformer", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
("autoint", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
("node", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
# ("category_embedding", # todo: seems to be bug in tabular
# {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
("category_embedding", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
# No categorical / numerical fields
("tabnet", {"categorical_fields": ["category"]}),
("tabnet", {"numerical_fields": ["scalar_a", "scalar_b"]}),
Expand Down Expand Up @@ -82,8 +81,7 @@ def test_regression_data_frame(backbone, fields, tmpdir):
("fttransformer", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
("autoint", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
("node", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
# ("category_embedding", # todo: seems to be bug in tabular
# {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
("category_embedding", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
# No categorical / numerical fields
("tabnet", {"categorical_fields": ["category"]}),
("tabnet", {"numerical_fields": ["scalar_a", "scalar_b"]}),
Expand Down Expand Up @@ -113,8 +111,7 @@ def test_regression_dicts(backbone, fields, tmpdir):
("fttransformer", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
("autoint", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
("node", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
# ("category_embedding", # todo: seems to be bug in tabular
# {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
("category_embedding", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
# No categorical / numerical fields
("tabnet", {"categorical_fields": ["category"]}),
("tabnet", {"numerical_fields": ["scalar_a", "scalar_b"]}),
Expand Down
6 changes: 3 additions & 3 deletions tests/tabular/regression/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class TestTabularRegressor(TaskTester):
{"backbone": "fttransformer"},
{"backbone": "autoint"},
{"backbone": "node"},
# {"backbone": "category_embedding"}, # todo: seems to be bug in tabular
{"backbone": "category_embedding"},
],
)
],
Expand All @@ -66,7 +66,7 @@ class TestTabularRegressor(TaskTester):
{"backbone": "fttransformer"},
{"backbone": "autoint"},
{"backbone": "node"},
# {"backbone": "category_embedding"}, # todo: seems to be bug in tabular
{"backbone": "category_embedding"},
],
)
],
Expand All @@ -79,7 +79,7 @@ class TestTabularRegressor(TaskTester):
{"backbone": "fttransformer"},
{"backbone": "autoint"},
{"backbone": "node"},
# {"backbone": "category_embedding"}, # todo: seems to be bug in tabular
{"backbone": "category_embedding"},
],
)
],
Expand Down

0 comments on commit 14c2755

Please sign in to comment.