Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Add base model for multi-table statistic model, change single-table base class location #102

Merged
merged 27 commits into from
Jan 16, 2024
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
12fd97a
Create base.py for multi-table statistic models
MooooCat Jan 10, 2024
df2e570
Update base.py
MooooCat Jan 10, 2024
f5fd7cb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 11, 2024
3108dee
update statistic single-table base class
MooooCat Jan 11, 2024
d5b8e10
update multi-table base class
MooooCat Jan 11, 2024
eab3143
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 11, 2024
980941e
Merge branch 'refactoring-base-model-partitial' of github.com:hitsz-i…
MooooCat Jan 12, 2024
689fe28
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 12, 2024
00c6f42
Merge branch 'main' into refactoring-base-model-partitial
MooooCat Jan 12, 2024
9819db2
add functions (still draft)
MooooCat Jan 12, 2024
535bde1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 12, 2024
dc7a265
Update base.py
MooooCat Jan 13, 2024
870d7db
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 13, 2024
9bbec2c
fix dict typo
MooooCat Jan 13, 2024
a72748c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 13, 2024
157d387
fix type hint typo
MooooCat Jan 13, 2024
47125f6
Update base.py
MooooCat Jan 15, 2024
cc8bcf9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 15, 2024
8de02cc
add multi-table test fixture
MooooCat Jan 15, 2024
7cc304c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 15, 2024
7dd91fd
modify check settings in metadata
MooooCat Jan 15, 2024
39cc1df
update multi-table base class
MooooCat Jan 15, 2024
0603baf
add test cases
MooooCat Jan 15, 2024
f0caac8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 15, 2024
4ec72ee
Apply reviewer's suggestions.
MooooCat Jan 16, 2024
71380ed
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2024
01b1749
Merge branch 'main' into refactoring-base-model-partitial
MooooCat Jan 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion sdgx/data_models/combiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ class MetadataCombiner(BaseModel):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.check()

def check(self):
"""Do necessary checks:
Expand Down
8 changes: 6 additions & 2 deletions sdgx/data_models/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ def from_dataloader(
include_inspectors: Iterable[str] | None = None,
exclude_inspectors: Iterable[str] | None = None,
inspector_init_kwargs: dict[str, Any] | None = None,
check: bool = False,
) -> "Metadata":
"""Initialize a metadata from DataLoader and Inspectors

Expand Down Expand Up @@ -257,7 +258,8 @@ def from_dataloader(
if not primary_keys:
metadata.update_primary_key(metadata.id_columns)

metadata.check()
if check:
metadata.check()
return metadata

@classmethod
Expand All @@ -267,6 +269,7 @@ def from_dataframe(
include_inspectors: list[str] | None = None,
exclude_inspectors: list[str] | None = None,
inspector_init_kwargs: dict[str, Any] | None = None,
check: bool = False,
) -> "Metadata":
"""Initialize a metadata from DataFrame and Inspectors

Expand Down Expand Up @@ -294,7 +297,8 @@ def from_dataframe(
metadata = Metadata(primary_keys=[df.columns[0]], column_list=set(df.columns))
for inspector in inspectors:
metadata.update(inspector.inspect())
metadata.check()
if check:
metadata.check()
return metadata

def _dump_json(self):
Expand Down
151 changes: 151 additions & 0 deletions sdgx/models/statistics/multi_tables/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
from __future__ import annotations

import time
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List

import pandas as pd
from pydantic import BaseModel

from sdgx.data_loader import DataLoader
from sdgx.data_models.combiner import MetadataCombiner
from sdgx.log import logger
from sdgx.utils import DataAccessType


class MultiTableSynthesizerModel(BaseModel):
    """Base class of multi-table statistic synthesizer models.

    Holds every table's metadata and the inter-table relationships via
    ``metadata_combiner``, plus per-table data (as DataFrames or DataLoaders)
    and per-table synthesizers. Concrete models implement ``fit`` / ``sample``.
    """

    data_access_method: DataAccessType = DataAccessType.pd_data_frame
    """
    The type of the data access, now support pandas.DataFrame or sdgx.DataLoader.
    """

    metadata_combiner: MetadataCombiner = None
    """
    metadata_combiner is a sdgx builtin class, it stores all tables' metadata and relationships.
    """

    tables_data_frame: Dict[str, Any] = defaultdict()
    """
    tables_data_frame is a dict contains every table's csv data frame.
    For a small amount of data, this scheme can be used.
    """

    tables_data_loader: Dict[str, Any] = defaultdict()
    """
    tables_data_loader is a dict contains every table's data loader.
    """

    _parent_id: List = []
    """
    _parent_id is used to store all parent table's primary keys in list.
    """

    _table_synthesizers: Dict[str, Any] = {}
    """
    _table_synthesizers is a dict to store model for each table.
    """

    parent_map: Dict = defaultdict()
    """
    The mapping from all child tables to their parent table.
    """

    child_map: Dict = defaultdict()
    """
    The mapping from all parent tables to their child table.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Derive parent_map / child_map from the combiner's relationships,
        # then run the (currently unimplemented) sanity checks.
        self._calculate_parent_and_child_map()

        self.check()

    def _calculate_parent_and_child_map(self):
        """Populate ``self.parent_map`` and ``self.child_map``.

        - ``parent_map``: child table name (str) -> parent table name (str);
        - ``child_map``: parent table name (str) -> child table name (str).

        NOTE(review): each map keeps a single entry per key, so a table that
        appears in several relationships only retains the last one seen —
        confirm this is the intended behavior for multi-parent schemas.
        """
        relationships = self.metadata_combiner.relationships
        for each_relationship in relationships:
            parent_table = each_relationship.parent_table
            child_table = each_relationship.child_table
            self.parent_map[child_table] = parent_table
            self.child_map[parent_table] = child_table

    def _get_foreign_keys(self, parent_table, child_table):
        """Return the foreign key list of the relationship between
        ``parent_table`` and ``child_table``, or ``[]`` when no such
        relationship exists."""
        relationships = self.metadata_combiner.relationships
        for each_relationship in relationships:
            # find the exact relationship and return foreign keys
            if (
                each_relationship.parent_table == parent_table
                and each_relationship.child_table == child_table
            ):
                return each_relationship.foreign_keys
        return []

    def _get_all_foreign_keys(self, child_table):
        """Given a child table, return the foreign keys of ALL relationships
        in which it is the child (a list of foreign-key lists)."""
        all_foreign_keys = []
        relationships = self.metadata_combiner.relationships
        for each_relationship in relationships:
            # collect the foreign keys of every relationship with this child
            if each_relationship.child_table == child_table:
                all_foreign_keys.append(each_relationship.foreign_keys)

        return all_foreign_keys

    def _finalize(self):
        """Finalize the model after training. Subclasses must override."""
        raise NotImplementedError

    def check(self, check_circular=True):
        """Execute necessary checks

        - validate circular relationships
        - validate child map_circular relationship
        - validate all tables connect relationship
        - validate column relationships foreign keys

        TODO(review): the checks listed above are not implemented yet.
        """

        pass

    def fit(self, dataloader: DataLoader, *args, **kwargs):
        """
        Fit the model using the given dataloader.

        Args:
            dataloader (DataLoader): The dataloader to use.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError

    def sample(self, count: int, *args, **kwargs) -> pd.DataFrame:
        """
        Sample data from the model.

        Args:
            count (int): The number of samples to generate.

        Returns:
            pd.DataFrame: The generated data.

        Raises:
            NotImplementedError: always; subclasses must override.
        """

        raise NotImplementedError

    def save(self, save_dir: str | Path):
        """Persist the model to ``save_dir``. Not implemented yet."""
        pass

    @classmethod
    def load(cls, target_path: str | Path):
        """Load a model from ``target_path``. Not implemented yet.

        Fix: ``cls`` was missing from this classmethod's signature, so
        ``MultiTableSynthesizerModel.load(path)`` would have bound the class
        object to ``target_path`` and rejected the actual path argument.
        """
        pass
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@ class SynthesizerModel:
random_states = None

def __init__(self, transformer=None, sampler=None) -> None:
# 以下几个变量都需要在初始化 model 时进行更改
self.model = None # 存放模型
self.model = None
self.status = "UNFINED"
self.model_type = "MODEL_TYPE_UNDEFINED"
# self.epochs = epochs
Expand Down
2 changes: 1 addition & 1 deletion sdgx/models/statistics/single_table/copula.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
unflatten_dict,
validate_numerical_distributions,
)
from sdgx.models.statistics.base import SynthesizerModel
from sdgx.models.statistics.single_table.base import SynthesizerModel

LOGGER = logging.getLogger(__name__)

Expand Down
10 changes: 10 additions & 0 deletions sdgx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import urllib.request
import warnings
from contextlib import closing
from enum import Enum
from pathlib import Path
from typing import Callable

Expand Down Expand Up @@ -41,6 +42,15 @@
}


class DataAccessType(Enum):
    """
    Type of data access.

    Used by multi-table synthesizer models to select how table data is
    supplied: ``pd_data_frame`` for in-memory pandas DataFrames,
    ``sdgx_data_loader`` for sdgx DataLoader objects.
    """

    pd_data_frame = 1
    sdgx_data_loader = 2


def find_free_port():
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
s.bind(("", 0))
Expand Down
23 changes: 23 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@

from sdgx.data_connectors.csv_connector import CsvConnector
from sdgx.data_loader import DataLoader
from sdgx.data_models.combiner import MetadataCombiner
from sdgx.data_models.metadata import Metadata
from sdgx.data_models.relationship import Relationship
from sdgx.utils import download_demo_data, download_multi_table_demo_data

_HERE = os.path.dirname(__file__)
Expand Down Expand Up @@ -158,3 +160,24 @@ def demo_multi_table_data_loader(demo_multi_table_data_connector, cacher_kwargs)
yield loader_dict
for each_table in demo_multi_table_data_connector.keys():
demo_multi_table_data_connector[each_table].finalize()


@pytest.fixture
def demo_multi_data_relationship():
    # Single store -> train relationship joined on the "Store" column.
    relationship = Relationship.build(
        parent_table="store", child_table="train", foreign_keys=["Store"]
    )
    yield relationship


@pytest.fixture
def demo_multi_table_data_metadata_combiner(
    demo_multi_table_data_loader, demo_multi_data_relationship
):
    # 1. build one Metadata per table from its data loader
    metadata_dict = {
        table_name: Metadata.from_dataloader(demo_multi_table_data_loader[table_name])
        for table_name in demo_multi_table_data_loader
    }
    # 2. the relationship is supplied pre-built by the sibling fixture
    # 3. combine metadata and relationship into one combiner
    combiner = MetadataCombiner(
        named_metadata=metadata_dict, relationships=[demo_multi_data_relationship]
    )

    yield combiner
33 changes: 33 additions & 0 deletions tests/models/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from __future__ import annotations

from collections import defaultdict, namedtuple

import pytest

from sdgx.models.statistics.multi_tables.base import MultiTableSynthesizerModel
from sdgx.utils import DataAccessType


@pytest.fixture
def demo_base_multi_table_synthesizer(
    demo_multi_table_data_metadata_combiner, demo_multi_table_data_loader
):
    # Bare base-class instance wired to the demo multi-table fixtures,
    # accessing data through sdgx DataLoaders.
    model = MultiTableSynthesizerModel(
        metadata_combiner=demo_multi_table_data_metadata_combiner,
        data_access_method=DataAccessType.sdgx_data_loader,
        tables_data_loader=demo_multi_table_data_loader,
    )
    yield model


def test_base_multi_table_synthesizer(demo_base_multi_table_synthesizer):
    """The base synthesizer derives its maps and foreign keys from the
    store/train relationship declared in the metadata combiner."""
    synthesizer = demo_base_multi_table_synthesizer
    KeyTuple = namedtuple("KeyTuple", ["parent", "child"])

    expected_parent_map = defaultdict(None, {"train": "store"})
    expected_child_map = defaultdict(None, {"store": "train"})
    assert synthesizer.parent_map == expected_parent_map
    assert synthesizer.child_map == expected_child_map

    first_foreign_key = synthesizer._get_all_foreign_keys("train")[0][0]
    assert first_foreign_key == KeyTuple(parent="Store", child="Store")


# Allow running this test module directly: python tests/models/test_base.py
if __name__ == "__main__":
    pytest.main(["-vv", "-s", __file__])