Skip to content

Commit

Permalink
Introduce inspect_level in inspector and metadata (#113)
Browse files Browse the repository at this point in the history
* add inspect_level

* update metadata check

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add pii in inspector

* update pii and inspect level in metadata creation

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update metadata.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add func dump

* add test cases

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add inspect_level limitation

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* simply __init__ code

* add more test case of inspect_level

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update sdgx/data_models/metadata.py

Co-authored-by: Zhongsheng Ji <9573586@qq.com>

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Zhongsheng Ji <9573586@qq.com>
  • Loading branch information
3 people committed Jan 20, 2024
1 parent 1bbf3f5 commit f2c12a2
Show file tree
Hide file tree
Showing 11 changed files with 210 additions and 11 deletions.
31 changes: 30 additions & 1 deletion sdgx/data_models/inspectors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from sdgx.data_models.metadata import Metadata

from sdgx.data_models.relationship import Relationship
from sdgx.exceptions import DataModelError


class Inspector:
Expand All @@ -20,8 +21,36 @@ class Inspector:
ready (bool): Ready to inspect, maybe all fields are fitted, or indicate if there is more data, inspector will be more precise.
"""

def __init__(self, *args, **kwargs):
pii = False
"""
PII indicates whether a column contains private or sensitive information.
"""

_inspect_level: int = 10
"""
Inspect level is a concept newly introduced in version 0.1.5. A single column in the table may be marked by different inspectors at the same time (for example, an email column may be recognized as email, as an id column, and also as a discrete column by different inspectors), which causes confusion in subsequent processing. The inspect_level is therefore used when determining the specific type of a column.
We preset different inspect levels for different inspectors: usually more specific inspectors get higher levels, and general inspectors (like discrete) have a lower inspect_level.
The value of inspect_level is limited to 1-100. In the base class and in the bool, discrete and numeric inspectors, the inspect_level is set to 10. For the datetime and id inspectors, the inspect_level is set to 20. When a column is marked multiple times, the mark of the inspector with the higher inspect_level prevails. This scheme also makes it easier for developers to insert a custom inspector at an intermediate level.
"""

@property
def inspect_level(self):
    """Return the inspect level currently stored on this inspector."""
    return self._inspect_level

@inspect_level.setter
def inspect_level(self, value: int):
    """Store a new inspect level after checking its range.

    Args:
        value (int): The new level; accepted only when it is greater than 0
            and not greater than 100.

    Raises:
        DataModelError: If ``value`` is outside that range.
    """
    within_bounds = value > 0 and value <= 100
    if not within_bounds:
        raise DataModelError("The inspect_level should be set in [1, 100].")
    self._inspect_level = value

def __init__(self, inspect_level=None, *args, **kwargs):
    """Initialize the inspector.

    Args:
        inspect_level (int, optional): Level to assign to this instance via
            the ``inspect_level`` attribute. When left as ``None`` no
            assignment is made here.
    """
    self.ready: bool = False
    # Compare against None instead of relying on truthiness: an explicit 0 is
    # a provided (out-of-range) value and must reach the attribute assignment
    # rather than being silently ignored.
    if inspect_level is not None:
        self.inspect_level = inspect_level

def fit(self, raw_data: pd.DataFrame, *args, **kwargs):
"""Fit the inspector.
Expand Down
7 changes: 7 additions & 0 deletions sdgx/data_models/inspectors/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@


class DatetimeInspector(Inspector):
_inspect_level = 20
"""
The inspect_level of DatetimeInspector is higher than DiscreteInspector.
Often, difficult-to-recognize date or datetime objects are also recognized as discrete types by DiscreteInspector, causing the column to be marked repeatedly.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.datetime_columns: set[str] = set()
Expand Down
7 changes: 7 additions & 0 deletions sdgx/data_models/inspectors/i_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@


class IDInspector(Inspector):
_inspect_level = 20
"""
The inspect_level of IDInspector is higher than NumericInspector.
Often, some columns, especially int-type id columns, can also be recognized as numeric types by NumericInspector, causing the column to be marked repeatedly.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.ID_columns: set[str] = set()
Expand Down
95 changes: 89 additions & 6 deletions sdgx/data_models/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,16 @@ class Metadata(BaseModel):
column_list is used to store all columns' name
"""

column_inspect_level: Dict[str, int] = defaultdict(lambda: 10)
"""
column_inspect_level is used to store every inspector's level, to specify the true type of each column.
"""

pii_columns: Set[set] = set()
"""
pii_columns is used to store all PII columns' name
"""

# other columns lists are used to store column information
# here are 5 basic data types
id_columns: Set[str] = set()
Expand Down Expand Up @@ -254,7 +264,17 @@ def from_dataloader(

metadata = Metadata(primary_keys=primary_keys, column_list=set(dataloader.columns()))
for inspector in inspectors:
metadata.update(inspector.inspect())
inspect_res = inspector.inspect()
# update column type
metadata.update(inspect_res)
# update pii column
if inspector.pii:
for each_key in inspect_res:
metadata.update({"pii_columns": inspect_res[each_key]})
# update inspect level
for each_key in inspect_res:
metadata.column_inspect_level[each_key] = inspector.inspect_level

if not primary_keys:
metadata.update_primary_key(metadata.id_columns)

Expand Down Expand Up @@ -296,7 +316,17 @@ def from_dataframe(

metadata = Metadata(primary_keys=[df.columns[0]], column_list=set(df.columns))
for inspector in inspectors:
metadata.update(inspector.inspect())
inspect_res = inspector.inspect()
# update column type
metadata.update(inspect_res)
# update pii column
if inspector.pii:
for each_key in inspect_res:
metadata.update({"pii_columns": inspect_res[each_key]})
# update inspect level
for each_key in inspect_res:
metadata.column_inspect_level[each_key] = inspector.inspect_level

if check:
metadata.check()
return metadata
Expand Down Expand Up @@ -341,8 +371,6 @@ def check_single_primary_key(self, input_key: str):

if input_key not in self.column_list:
raise MetadataInvalidError(f"Primary Key {input_key} not Exist in columns.")
if input_key not in self.id_columns:
raise MetadataInvalidError(f"Primary Key {input_key} should has ID DataType.")

def get_all_data_type_columns(self):
"""Get all column names from `self.xxx_columns`.
Expand All @@ -360,7 +388,6 @@ def get_all_data_type_columns(self):
if each_key.endswith("_columns"):
column_names = self.get(each_key)
all_dtype_cols = all_dtype_cols.union(set(column_names))

return all_dtype_cols

def check(self):
Expand All @@ -371,10 +398,16 @@ def check(self):
-Is there any missing definition of each column in table.
-Are there any unknown columns that have been incorrectly updated.
"""
# check primary key in column_list and has ID data type
# check primary key in column_list
for each_key in self.primary_keys:
self.check_single_primary_key(each_key)

# for single primary key, it should has ID type
if len(self.primary_keys) == 1 and list(self.primary_keys)[0] not in self.id_columns:
raise MetadataInvalidError(
f"Primary Key {self.primary_keys[0]} should has ID DataType."
)

all_dtype_columns = self.get_all_data_type_columns()

# check missing columns
Expand Down Expand Up @@ -410,3 +443,53 @@ def update_primary_key(self, primary_keys: Iterable[str] | str):
self.primary_keys = primary_keys

logger.info(f"Primary Key updated: {primary_keys}.")

def dump(self):
    """Export the model as a dict for downstream steps such as processors.

    Returns:
        dict: The ``model_dump()`` result with an extra ``column_data_type``
            entry that maps every name in ``column_list`` to the value
            returned by ``get_column_data_type`` for it.
    """
    dumped = self.model_dump()
    dumped["column_data_type"] = {
        col_name: self.get_column_data_type(col_name) for col_name in self.column_list
    }
    return dumped

def get_column_data_type(self, column_name: str):
    """Get the exact data type of a specific column.

    A column may be present in several ``*_columns`` collections at once.
    The collection whose inspect level is highest decides the type; on a tie
    the first key encountered wins. ``pii_columns`` is never treated as a
    data type.

    Args:
        column_name(str): The queried column name.

    Returns:
        str: The winning key with its ``_columns`` suffix removed.

    Raises:
        MetadataInvalidError: If the column is not in ``column_list``, or if
            no data type collection contains it.
    """
    if column_name not in self.column_list:
        raise MetadataInvalidError(f"Column {column_name} not exists in metadata.")
    # Level used for keys without a recorded inspect level; 10 matches the
    # default factory of the ``column_inspect_level`` field.
    default_level = 10
    current_type = None
    current_level = 0
    # find the dtype with the highest inspector level
    for each_key in list(self.model_fields.keys()) + list(self._extend.keys()):
        if each_key == "pii_columns" or not each_key.endswith("_columns"):
            continue
        if column_name not in self.get(each_key):
            continue
        # ``.get`` keeps this lookup read-only and also works when the
        # mapping is a plain dict; indexing would insert keys on a
        # defaultdict and raise KeyError on a plain dict.
        key_level = self.column_inspect_level.get(each_key, default_level)
        if current_level < key_level:
            current_level = key_level
            current_type = each_key
    if not current_type:
        raise MetadataInvalidError(f"Column {column_name} has no data type.")
    return current_type.split("_columns")[0]

def get_column_pii(self, column_name: str):
    """Return whether a column is a PII column.

    Args:
        column_name(str): The queried column name.

    Returns:
        bool: ``True`` if the column is in ``pii_columns``, otherwise ``False``.

    Raises:
        MetadataInvalidError: If the column is not in ``column_list``.
    """
    if column_name not in self.column_list:
        raise MetadataInvalidError(f"Column {column_name} not exists in metadata.")
    return column_name in self.pii_columns
19 changes: 15 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,20 +162,31 @@ def demo_multi_table_data_loader(demo_multi_table_data_connector, cacher_kwargs)
demo_multi_table_data_connector[each_table].finalize()


@pytest.fixture
def demo_multi_data_parent_matadata(demo_multi_table_data_loader):
    """Provide the Metadata built from the ``store`` entry of the multi-table loader."""
    store_loader = demo_multi_table_data_loader["store"]
    yield Metadata.from_dataloader(store_loader)


@pytest.fixture
def demo_multi_data_child_matadata(demo_multi_table_data_loader):
    """Provide the Metadata built from the ``train`` entry of the multi-table loader."""
    train_loader = demo_multi_table_data_loader["train"]
    yield Metadata.from_dataloader(train_loader)


@pytest.fixture
def demo_multi_data_relationship():
    """Provide the store -> train Relationship linked through the ``Store`` foreign key."""
    store_to_train = Relationship.build(
        parent_table="store", child_table="train", foreign_keys=["Store"]
    )
    yield store_to_train


@pytest.fixture
def demo_multi_table_data_metadata_combiner(
demo_multi_table_data_loader, demo_multi_data_relationship
demo_multi_data_parent_matadata: Metadata,
demo_multi_data_child_matadata: Metadata,
demo_multi_data_relationship: Relationship,
):
# 1. get metadata
metadata_dict = {}
for each_table_name in demo_multi_table_data_loader:
each_metadata = Metadata.from_dataloader(demo_multi_table_data_loader[each_table_name])
metadata_dict[each_table_name] = each_metadata
metadata_dict["store"] = demo_multi_data_parent_matadata
metadata_dict["train"] = demo_multi_data_child_matadata
# 2. define relationship - already defined
# 3. define combiner
m = MetadataCombiner(named_metadata=metadata_dict, relationships=[demo_multi_data_relationship])
Expand Down
12 changes: 12 additions & 0 deletions tests/data_models/inspector/test_bool.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest

from sdgx.data_models.inspectors.bool import BoolInspector
from sdgx.exceptions import DataModelError


@pytest.fixture
Expand Down Expand Up @@ -37,13 +38,24 @@ def test_inspector_demo_data(inspector: BoolInspector, raw_data):
# should be empty set
assert not inspector.bool_columns
assert sorted(inspector.inspect()["bool_columns"]) == sorted([])
assert inspector.inspect_level == 10
# test inspect_level.setter
try:
inspector.inspect_level = 120
except Exception as e:
assert type(e) == DataModelError


def test_inspector_generated_data(inspector: BoolInspector, bool_test_df: pd.DataFrame):
    """Check bool detection on generated data and the inspect_level lower bound."""
    # use generated bool data
    inspector.fit(bool_test_df)
    assert inspector.bool_columns
    assert sorted(inspector.inspect()["bool_columns"]) == sorted(["bool_random"])
    assert inspector.inspect_level == 10
    # pytest.raises fails the test when nothing is raised; the previous
    # try/except form passed silently in that case.
    with pytest.raises(DataModelError):
        inspector.inspect_level = 0


if __name__ == "__main__":
Expand Down
2 changes: 2 additions & 0 deletions tests/data_models/inspector/test_datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def test_inspector_demo_data(inspector: DatetimeInspector, raw_data):
# should be empty set
assert not inspector.datetime_columns
assert sorted(inspector.inspect()["datetime_columns"]) == sorted([])
assert inspector.inspect_level == 20


def test_inspector_generated_data(inspector: DatetimeInspector, datetime_test_df: pd.DataFrame):
Expand All @@ -83,6 +84,7 @@ def test_inspector_generated_data(inspector: DatetimeInspector, datetime_test_df
assert sorted(inspector.inspect()["datetime_columns"]) == sorted(
["simple_datetime", "simple_datetime_2", "date_with_time"]
)
assert inspector.inspect_level == 20


if __name__ == "__main__":
Expand Down
1 change: 1 addition & 0 deletions tests/data_models/inspector/test_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def test_inspector(inspector: DiscreteInspector, raw_data):
"income",
]
)
assert inspector.inspect_level == 10


if __name__ == "__main__":
Expand Down
1 change: 1 addition & 0 deletions tests/data_models/inspector/test_i_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def test_inspector_demo_data(inspector: IDInspector, raw_data):
# should be empty set
assert not inspector.ID_columns
assert sorted(inspector.inspect()["id_columns"]) == sorted([])
assert inspector.inspect_level == 20


def test_inspector_generated_data(inspector: IDInspector, id_test_df: pd.DataFrame):
Expand Down
1 change: 1 addition & 0 deletions tests/data_models/inspector/test_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def test_inspector(inspector: NumericInspector, raw_data):
assert sorted(inspector.inspect()["numeric_columns"]) == sorted(
["educational-num", "fnlwgt", "hours-per-week", "age", "capital-gain", "capital-loss"]
)
assert inspector.inspect_level == 10


if __name__ == "__main__":
Expand Down
45 changes: 45 additions & 0 deletions tests/data_models/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,5 +59,50 @@ def test_metadata_check(metadata: Metadata):
metadata.check()


def test_demo_multi_table_data_metadata_parent(demo_multi_data_parent_matadata):
    """Check resolved column types, PII flags and dump output of the parent metadata."""
    # self check
    demo_multi_data_parent_matadata.check()
    # check each col's data type
    expected_types = {
        "Store": "id",
        "StoreType": "discrete",
        "Assortment": "discrete",
        "CompetitionDistance": "numeric",
        "CompetitionOpenSinceMonth": "numeric",
        "Promo2": "numeric",
        "Promo2SinceWeek": "numeric",
        "Promo2SinceYear": "numeric",
        "PromoInterval": "discrete",
    }
    for col_name, expected in expected_types.items():
        assert demo_multi_data_parent_matadata.get_column_data_type(col_name) == expected
    # check pii
    for each_col in demo_multi_data_parent_matadata.column_list:
        assert demo_multi_data_parent_matadata.get_column_pii(each_col) is False
    # compare by value; `is` against an int literal tests identity
    assert len(demo_multi_data_parent_matadata.pii_columns) == 0
    # check dump
    assert "column_data_type" in demo_multi_data_parent_matadata.dump()


def test_demo_multi_table_data_metadata_child(demo_multi_data_child_matadata):
    """Check resolved column types, PII flags and dump output of the child metadata."""
    # self check
    demo_multi_data_child_matadata.check()
    # check each col's data type
    expected_types = {
        "Store": "numeric",
        "Date": "datetime",
        "Customers": "numeric",
        "StateHoliday": "numeric",
        "Sales": "numeric",
        "Promo": "numeric",
        "DayOfWeek": "numeric",
        "Open": "numeric",
        "SchoolHoliday": "numeric",
    }
    for col_name, expected in expected_types.items():
        assert demo_multi_data_child_matadata.get_column_data_type(col_name) == expected
    # check pii
    for each_col in demo_multi_data_child_matadata.column_list:
        assert demo_multi_data_child_matadata.get_column_pii(each_col) is False
    # compare by value; `is` against an int literal tests identity
    assert len(demo_multi_data_child_matadata.pii_columns) == 0
    # check dump
    assert "column_data_type" in demo_multi_data_child_matadata.dump()


if __name__ == "__main__":
pytest.main(["-vv", "-s", __file__])

0 comments on commit f2c12a2

Please sign in to comment.