Add demo data for multi-table scenario #98

Merged
merged 10 commits on Jan 8, 2024
88 changes: 87 additions & 1 deletion sdgx/utils.py
@@ -18,7 +18,27 @@
except ImportError:
    from functools import lru_cache as cache

__all__ = ["download_demo_data", "get_demo_single_table", "cache", "Singleton", "find_free_port"]
__all__ = [
    "download_demo_data",
    "get_demo_single_table",
    "cache",
    "Singleton",
    "find_free_port",
    "download_multi_table_demo_data",
    "get_demo_multi_table",
]

MULTI_TABLE_DEMO_DATA = {
    "rossman": {
        "parent_table": "store",
        "child_table": "train",
        "parent_url": "https://github.com/raw/juniorcl/rossman-store-sales/main/databases/store.csv",
        "child_url": "https://github.com/raw/juniorcl/rossman-store-sales/main/databases/train.csv",
        "parent_primary_keys": ["Store"],
        "child_primary_keys": ["Store", "Date"],
        "foreign_keys": ["Store"],
    }
}
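The registry above is pure metadata: it names the two tables, where to fetch them, and which columns act as primary and foreign keys. As a minimal sketch (not part of this PR), the `foreign_keys` entry is what lets a consumer join child rows back to their parent store:

```python
import pandas as pd

from sdgx.utils import MULTI_TABLE_DEMO_DATA

# Hypothetical illustration of what the registry entry encodes; pandas can
# read the CSVs straight from the URLs listed above.
info = MULTI_TABLE_DEMO_DATA["rossman"]
store = pd.read_csv(info["parent_url"])  # parent table, one row per store
train = pd.read_csv(info["child_url"])   # child table, one row per store/date

# Each child row references a parent row through the foreign key "Store".
joined = train.merge(store, on=info["foreign_keys"], how="left")
```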


def find_free_port():
@@ -95,6 +115,72 @@ def __call__(cls, *args, **kwargs):
        return cls._instances[cls]


def download_multi_table_demo_data(
    data_dir: str | Path = "./dataset", dataset_name: str = "rossman"
) -> dict[str, Path]:
    """
    Download the multi-table demo data ("Rossmann Store Sales", registered as "rossman")
    if it does not already exist.

    Args:
        data_dir(str | Path): directory to store the downloaded CSV files
        dataset_name(str): name of the dataset in ``MULTI_TABLE_DEMO_DATA``

    Returns:
        dict[str, pathlib.Path]: dict whose key is the table name and whose value is the demo data path
    """
    demo_data_info = MULTI_TABLE_DEMO_DATA[dataset_name]
    data_dir = Path(data_dir).expanduser().resolve()
    parent_file_name = dataset_name + "_" + demo_data_info["parent_table"] + ".csv"
    child_file_name = dataset_name + "_" + demo_data_info["child_table"] + ".csv"
    demo_data_path_parent = data_dir / parent_file_name
    demo_data_path_child = data_dir / child_file_name
    # For now, it's OK to hardcode the URL for each dataset.
    # In the future we can consider using our own S3 bucket or providing more datasets through sdg.idslab.io.
    if not demo_data_path_parent.exists():
        demo_data_path_parent.parent.mkdir(parents=True, exist_ok=True)
        # download the parent table from the GitHub link
        logger.info("Downloading parent table from github to {}".format(demo_data_path_parent))
        parent_url = demo_data_info["parent_url"]
        urllib.request.urlretrieve(parent_url, demo_data_path_parent)
    # then the child table
    if not demo_data_path_child.exists():
        demo_data_path_child.parent.mkdir(parents=True, exist_ok=True)
        # download the child table from the GitHub link
        logger.info("Downloading child table from github to {}".format(demo_data_path_child))
        child_url = demo_data_info["child_url"]
        urllib.request.urlretrieve(child_url, demo_data_path_child)

    return {
        demo_data_info["parent_table"]: demo_data_path_parent,
        demo_data_info["child_table"]: demo_data_path_child,
    }
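A minimal usage sketch of the downloader (assuming the default "rossman" dataset): it is idempotent, so repeated calls just return the paths of the already-downloaded files.

```python
from sdgx.utils import download_multi_table_demo_data

# First call downloads; later calls reuse the cached CSV files.
paths = download_multi_table_demo_data("./dataset")
# e.g. {"store": Path(".../rossman_store.csv"), "train": Path(".../rossman_train.csv")}
```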


def get_demo_multi_table(
    data_dir: str | Path = "./dataset", dataset_name: str = "rossman"
) -> dict[str, pd.DataFrame]:
    """
    Get the multi-table demo data as DataFrames, keyed by table name.

    Args:
        data_dir(str | Path): directory holding (or receiving) the demo CSV files
        dataset_name(str): name of the dataset in ``MULTI_TABLE_DEMO_DATA``

    Returns:
        dict[str, pd.DataFrame]: multi-table data dict whose key is the table name and whose value is the DataFrame.
    """
    multi_table_dict = {}
    # download the tables if they do not exist yet
    demo_data_dict = download_multi_table_demo_data(data_dir, dataset_name)
    # read each table from its path
    for table_name, each_path in demo_data_dict.items():
        multi_table_dict[table_name] = pd.read_csv(each_path)

    return multi_table_dict
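And a matching sketch for the loader, which builds on the downloader and returns ready-to-use DataFrames:

```python
from sdgx.utils import get_demo_multi_table

tables = get_demo_multi_table("./dataset")
store, train = tables["store"], tables["train"]
print(store.shape, train.shape)
```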


def ignore_warnings(category: Warning):
    def ignore_warnings_decorator(func: Callable):
        @functools.wraps(func)
28 changes: 27 additions & 1 deletion tests/conftest.py
@@ -13,7 +13,7 @@
from sdgx.data_connectors.csv_connector import CsvConnector
from sdgx.data_loader import DataLoader
from sdgx.data_models.metadata import Metadata
from sdgx.utils import download_demo_data
from sdgx.utils import download_demo_data, download_multi_table_demo_data

_HERE = os.path.dirname(__file__)

@@ -132,3 +132,29 @@ def demo_single_table_data_loader(demo_single_table_data_connector, cacher_kwargs):
@pytest.fixture
def demo_single_table_metadata(demo_single_table_data_loader):
    yield Metadata.from_dataloader(demo_single_table_data_loader)


@pytest.fixture
def demo_multi_table_path():
    yield download_multi_table_demo_data(DATA_DIR)


@pytest.fixture
def demo_multi_table_data_connector(demo_multi_table_path):
    connector_dict = {}
    # build one CsvConnector per demo table
    for each_table, each_path in demo_multi_table_path.items():
        connector_dict[each_table] = CsvConnector(path=each_path)
    yield connector_dict


@pytest.fixture
def demo_multi_table_data_loader(demo_multi_table_data_connector, cacher_kwargs):
    loader_dict = {}
    # wrap each connector in a DataLoader
    for each_table, each_connector in demo_multi_table_data_connector.items():
        loader_dict[each_table] = DataLoader(each_connector, cacher_kwargs=cacher_kwargs)
    yield loader_dict
    # release connector resources after the test finishes
    for each_connector in demo_multi_table_data_connector.values():
        each_connector.finalize()
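For reference, a hypothetical test (not part of this PR) consuming the new fixture could look like the sketch below; it only assumes the fixture's dict keys, which are the Rossmann table names.

```python
def test_demo_multi_table_loader(demo_multi_table_data_loader):
    # Both Rossmann tables should be wrapped in DataLoaders keyed by name.
    assert set(demo_multi_table_data_loader) == {"store", "train"}
```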