Skip to content

Commit

Permalink
Add mock data and testing for multi tables' related imp (#97)
Browse files Browse the repository at this point in the history
  • Loading branch information
Wh1isper committed Jan 4, 2024
1 parent 160f6a2 commit 9bf075a
Show file tree
Hide file tree
Showing 15 changed files with 243 additions and 48 deletions.
14 changes: 8 additions & 6 deletions sdgx/data_models/combiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ class MetadataCombiner(BaseModel):
Combine different tables with relationship, used for describing the relationship between tables.
Args:
named_metadata (Dict[str, Any]): Name of the table: Metadata
relationships (List[Any])
version (str): version
named_metadata (Dict[str, Any]): pairs of table name and metadata
relationships (List[Any]): list of relationships
"""

version: str = "1.0"
Expand Down Expand Up @@ -106,7 +106,7 @@ def from_dataloader(
)
for d in dataloaders:
for i, chunk in enumerate(d.iter()):
inspector.fit(chunk)
inspector.fit(chunk, name=d.identity)
if inspector.ready or i > max_chunk:
break
relationships = inspector.inspect()["relationships"]
Expand Down Expand Up @@ -156,8 +156,8 @@ def from_dataframe(
inspector = InspectorManager().init(
relationshipe_inspector, **relationships_inspector_kwargs
)
for d in dataframes:
inspector.fit(d)
for n, d in zip(names, dataframes):
inspector.fit(d, name=n)
relationships = inspector.inspect()["relationships"]

return cls(named_metadata=named_metadata, relationships=relationships)
Expand All @@ -182,6 +182,8 @@ def save(
relationship_subdir (str): subdirectory for relationship, default is "relationship"
"""
save_dir = Path(save_dir).expanduser().resolve()
save_dir.mkdir(parents=True, exist_ok=True)

version_file = save_dir / "version"
version_file.write_text(self.version)

Expand Down
4 changes: 2 additions & 2 deletions sdgx/data_models/inspectors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ class Inspector:
def __init__(self, *args, **kwargs):
self.ready: bool = False

def fit(self, raw_data: pd.DataFrame):
def fit(self, raw_data: pd.DataFrame, *args, **kwargs):
"""Fit the inspector.
Args:
raw_data (pd.DataFrame): Raw data
"""
return

def inspect(self) -> dict[str, Any]:
def inspect(self, *args, **kwargs) -> dict[str, Any]:
"""Inspect raw data and generate metadata."""
4 changes: 2 additions & 2 deletions sdgx/data_models/inspectors/bool.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.bool_columns: set[str] = set()

def fit(self, raw_data: pd.DataFrame):
def fit(self, raw_data: pd.DataFrame, *args, **kwargs):
"""Fit the inspector.
Gets the list of discrete columns from the raw data.
Expand All @@ -28,7 +28,7 @@ def fit(self, raw_data: pd.DataFrame):

self.ready = True

def inspect(self) -> dict[str, Any]:
def inspect(self, *args, **kwargs) -> dict[str, Any]:
"""Inspect raw data and generate metadata."""

return {"bool_columns": list(self.bool_columns)}
Expand Down
6 changes: 4 additions & 2 deletions sdgx/data_models/inspectors/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from sdgx.data_models.inspectors.base import Inspector
from sdgx.data_models.inspectors.extension import hookimpl
from sdgx.utils import ignore_warnings


class DatetimeInspector(Inspector):
Expand All @@ -15,6 +16,7 @@ def __init__(self, *args, **kwargs):
self.datetime_columns: set[str] = set()

@classmethod
@ignore_warnings(category=UserWarning)
def can_convert_to_datetime(cls, input_col: pd.Series):
"""Whether a df column can be converted to datetime.
Expand All @@ -30,7 +32,7 @@ def can_convert_to_datetime(cls, input_col: pd.Series):
except:
return False

def fit(self, raw_data: pd.DataFrame):
def fit(self, raw_data: pd.DataFrame, *args, **kwargs):
"""Fit the inspector.
Gets the list of discrete columns from the raw data.
Expand All @@ -52,7 +54,7 @@ def fit(self, raw_data: pd.DataFrame):

self.ready = True

def inspect(self) -> dict[str, Any]:
def inspect(self, *args, **kwargs) -> dict[str, Any]:
"""Inspect raw data and generate metadata."""

return {"datetime_columns": list(self.datetime_columns)}
Expand Down
4 changes: 2 additions & 2 deletions sdgx/data_models/inspectors/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.discrete_columns: set[str] = set()

def fit(self, raw_data: pd.DataFrame):
def fit(self, raw_data: pd.DataFrame, *args, **kwargs):
"""Fit the inspector.
Gets the list of discrete columns from the raw data.
Expand All @@ -27,7 +27,7 @@ def fit(self, raw_data: pd.DataFrame):
)
self.ready = True

def inspect(self) -> dict[str, Any]:
def inspect(self, *args, **kwargs) -> dict[str, Any]:
"""Inspect raw data and generate metadata."""

return {"discrete_columns": list(self.discrete_columns)}
Expand Down
4 changes: 2 additions & 2 deletions sdgx/data_models/inspectors/i_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.ID_columns: set[str] = set()

def fit(self, raw_data: pd.DataFrame):
def fit(self, raw_data: pd.DataFrame, *args, **kwargs):
"""Fit the inspector.
Gets the list of discrete columns from the raw data.
Expand All @@ -33,7 +33,7 @@ def fit(self, raw_data: pd.DataFrame):

self.ready = True

def inspect(self) -> dict[str, Any]:
def inspect(self, *args, **kwargs) -> dict[str, Any]:
"""Inspect raw data and generate metadata."""

return {"id_columns": list(self.ID_columns)}
Expand Down
4 changes: 2 additions & 2 deletions sdgx/data_models/inspectors/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.numeric_columns: set[str] = set()

def fit(self, raw_data: pd.DataFrame):
def fit(self, raw_data: pd.DataFrame, *args, **kwargs):
"""Fit the inspector.
Gets the list of discrete columns from the raw data.
Expand All @@ -27,7 +27,7 @@ def fit(self, raw_data: pd.DataFrame):
)
self.ready = True

def inspect(self) -> dict[str, Any]:
def inspect(self, *args, **kwargs) -> dict[str, Any]:
"""Inspect raw data and generate metadata."""

return {"numeric_columns": list(self.numeric_columns)}
Expand Down
7 changes: 6 additions & 1 deletion sdgx/data_models/inspectors/relationship.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from typing import Any

import pandas as pd

from sdgx.data_models.inspectors.base import Inspector
from sdgx.data_models.inspectors.extension import hookimpl
from sdgx.data_models.relationship import Relationship
Expand All @@ -11,7 +13,10 @@ class RelationshipInspector(Inspector):
def _build_relationship(self) -> list[Relationship]:
return []

def inspect(self) -> dict[str, Any]:
def fit(self, raw_data: pd.DataFrame, name: str | None = None, *args, **kwargs):
    """Accept one table's data together with an optional name; does nothing here.

    Args:
        raw_data (pd.DataFrame): Raw data.
        name (str | None): Optional name supplied with the data. Defaults to None.
            NOTE(review): presumably the name of the table the data belongs to -- confirm against callers.
    """
    # Intentionally a no-op: no state is recorded from raw_data or name.
    pass

def inspect(self, *args, **kwargs) -> dict[str, Any]:
"""Inspect raw data and generate metadata."""
return {"relationships": self._build_relationship()}

Expand Down
21 changes: 19 additions & 2 deletions sdgx/data_models/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from sdgx.data_loader import DataLoader
from sdgx.data_models.inspectors.manager import InspectorManager
from sdgx.data_models.inspectors.relationship import RelationshipInspector
from sdgx.exceptions import MetadataInitError, MetadataInvalidError
from sdgx.utils import logger

Expand Down Expand Up @@ -229,7 +230,15 @@ def from_dataloader(
inspector_init_kwargs(dict): inspector args.
"""
logger.info("Inspecting metadata...")
inspectors = InspectorManager().init_inspcetors(
im = InspectorManager()
exclude_inspectors = exclude_inspectors or []
exclude_inspectors.extend(
name
for name, inspector_type in im.registed_inspectors.items()
if issubclass(inspector_type, RelationshipInspector)
)

inspectors = im.init_inspcetors(
include_inspectors, exclude_inspectors, **(inspector_init_kwargs or {})
)
for i, chunk in enumerate(dataloader.iter()):
Expand Down Expand Up @@ -267,7 +276,15 @@ def from_dataframe(
inspector_init_kwargs(dict): inspector args.
"""

inspectors = InspectorManager().init_inspcetors(
im = InspectorManager()
exclude_inspectors = exclude_inspectors or []
exclude_inspectors.extend(
name
for name, inspector_type in im.registed_inspectors.items()
if issubclass(inspector_type, RelationshipInspector)
)

inspectors = im.init_inspcetors(
include_inspectors, exclude_inspectors, **(inspector_init_kwargs or {})
)
for inspector in inspectors:
Expand Down
20 changes: 14 additions & 6 deletions sdgx/data_models/relationship.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import json
from pathlib import Path
from typing import Any, Iterable, Set
from typing import Any, Iterable, List, Set, Tuple, Union

from pydantic import BaseModel

Expand All @@ -24,32 +24,40 @@ class Relationship(BaseModel):
parent_table: str
child_table: str

foreign_keys: Set[str]
foreign_keys: List[Union[str, Tuple[str, str]]]
"""
foreign keys.
If key is a tuple, the first element is parent column name and the second element is child column name
"""

@classmethod
def build(
cls, parent_table: str, child_table: str, foreign_keys: Iterable[str]
cls,
parent_table: str,
child_table: str,
foreign_keys: Iterable[str | tuple[str, str]],
) -> "Relationship":
"""
Build relationship from parent table, child table and foreign keys
Args:
parent_table (str): parent table
child_table (str): child table
foreign_keys (Iterable[str]): foreign keys
foreign_keys (Iterable[str | tuple[str, str]]): foreign keys. If key is a tuple, the first element is parent column name and the second element is child column name
"""

if not parent_table:
raise RelationshipInitError("parent table cannot be empty")
if not child_table:
raise RelationshipInitError("child table cannot be empty")

foreign_keys = list(foreign_keys)
if not foreign_keys:
raise RelationshipInitError("foreign keys cannot be empty")
if parent_table == child_table:
raise RelationshipInitError("child table and parent table cannot be the same")

foreign_keys = set(foreign_keys)

return cls(
parent_table=parent_table,
child_table=child_table,
Expand Down
17 changes: 10 additions & 7 deletions sdgx/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,26 +106,29 @@ def register(self, cls_name, cls: type):
return
self._registed_cls[cls_name] = cls

def init(self, cls_name, **kwargs: dict[str, Any]):
def init(self, c, **kwargs: dict[str, Any]):
"""
Init a new subclass of self.register_type.
Raises:
NotFoundError: if cls_name is not registered
InitializationError: if failed to initialize
"""
if isinstance(cls_name, type):
cls_type = cls_name
if isinstance(c, self.register_type):
return c

if isinstance(c, type):
cls_type = c
else:
cls_name = self._normalize_name(cls_name)
c = self._normalize_name(c)

if not cls_name in self.registed_cls:
if not c in self.registed_cls:
raise NotFoundError
cls_type = self.registed_cls[cls_name]
cls_type = self.registed_cls[c]
try:
instance = cls_type(**kwargs)
if not isinstance(instance, self.register_type):
raise InitializationError(f"{cls_name} is not a subclass of {self.register_type}.")
raise InitializationError(f"{c} is not a subclass of {self.register_type}.")
return instance
except Exception as e:
raise InitializationError(e)
16 changes: 16 additions & 0 deletions sdgx/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from __future__ import annotations

import functools
import socket
import threading
import urllib.request
import warnings
from contextlib import closing
from pathlib import Path
from typing import Callable

import pandas as pd

Expand Down Expand Up @@ -90,3 +93,16 @@ def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
return cls._instances[cls]


def ignore_warnings(category: type[Warning]):
    """Build a decorator that silences warnings of ``category`` during each call.

    Args:
        category: Warning class (e.g. ``UserWarning``) to ignore while the
            decorated function runs. It is passed to ``warnings.simplefilter``,
            so it must be a class, not an instance.

    Returns:
        A decorator. The function it produces forwards all positional and
        keyword arguments to the wrapped function and returns its result.
    """

    def ignore_warnings_decorator(func: Callable):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # catch_warnings restores the previous filter state on exit, so the
            # "ignore" filter is only in effect for the duration of this call.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", category=category)
                return func(*args, **kwargs)

        return wrapper

    return ignore_warnings_decorator
34 changes: 34 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,40 @@ def demo_single_table_path():
yield download_demo_data(DATA_DIR).as_posix()


@pytest.fixture
def demo_relational_table_path(tmp_path):
    """Write two small related CSV tables into ``tmp_path``.

    Table A has an ``id`` column plus random role/name/feature columns;
    table B has a ``foreign_id`` column whose values are a subset of A's ids,
    plus random feature columns.

    Returns:
        A tuple ``(path_to_A, path_to_B, key_pairs)`` where ``key_pairs`` is
        ``[("id", "foreign_id")]``.
    """
    dummy_size = 10
    role_set = ["admin", "user", "guest"]

    def _random_floats(count):
        # One fresh list of random floats per feature column.
        return [random.random() for _ in range(count)]

    # Parent table: one row per id.
    parent_columns = {
        "id": list(range(dummy_size)),
        "role": [random.choice(role_set) for _ in range(dummy_size)],
        "name": [ramdon_str() for _ in range(dummy_size)],
        "feature_x": _random_floats(dummy_size),
        "feature_y": _random_floats(dummy_size),
        "feature_z": _random_floats(dummy_size),
    }
    save_path_a = tmp_path / "dummy_relation_A.csv"
    pd.DataFrame(parent_columns).to_csv(save_path_a, index=False, header=True)

    # Child table: its keys must all exist in the parent table.
    sub_size = 5
    assert dummy_size >= sub_size
    child_columns = {
        "foreign_id": list(range(sub_size)),
        "feature_i": _random_floats(sub_size),
        "feature_j": _random_floats(sub_size),
        "feature_k": _random_floats(sub_size),
    }
    save_path_b = tmp_path / "dummy_relation_B.csv"
    pd.DataFrame(child_columns).to_csv(save_path_b, index=False, header=True)

    foreign_key_pairs = [("id", "foreign_id")]
    return save_path_a, save_path_b, foreign_key_pairs


@pytest.fixture
def cacher_kwargs(tmp_path):
cache_dir = tmp_path / "cache"
Expand Down
Loading

0 comments on commit 9bf075a

Please sign in to comment.