From 4ea93b0bbbd96a6f743e4c260f5b717062a0f0bc Mon Sep 17 00:00:00 2001
From: Zhongsheng Ji <9573586@qq.com>
Date: Sat, 23 Dec 2023 10:29:26 +0800
Subject: [PATCH] CLI for single table synthesizer (#86)

- Introduce `Data Exporter` for exporting sampled data to data sources
- CLI updates for synthesizer

---
 .github/workflows/extension.yml               |   1 +
 .../api_reference/data_exporters/base.rst     |   9 +
 .../data_exporters/csv_exporter.rst           |  10 +
 .../data_exporters/extension.rst              |  11 +
 .../api_reference/data_exporters/index.rst    |  24 ++
 .../api_reference/data_exporters/manager.rst  |   9 +
 docs/source/api_reference/index.rst           |   1 +
 .../developer_guides/extension/index.rst      |  20 +-
 docs/source/user_guides/cli.rst               |  22 +
 .../dummyexporter/dummyexporter/__init__.py   |   1 +
 .../dummyexporter/dummyexporter.py            |  15 +
 .../extension/dummyexporter/pyproject.toml    |  27 ++
 .../tests/test_registed_exporter.py           |  16 +
 sdgx/cli/main.py                              | 398 +++++++++++++++---
 sdgx/cli/message.py                           |  48 +++
 sdgx/cli/models.py                            |   0
 sdgx/cli/utils.py                             |  28 ++
 .../__init__.py}                              |   0
 sdgx/data_exporters/base.py                   |  10 +
 sdgx/data_exporters/csv_exporter.py           |  41 ++
 sdgx/data_exporters/extension.py              |  57 +++
 sdgx/data_exporters/manager.py                |  25 ++
 sdgx/exceptions.py                            |  48 ++-
 sdgx/log.py                                   |   9 +-
 sdgx/models/manager.py                        |   4 +-
 sdgx/models/ml/single_table/ctgan.py          |   4 +-
 sdgx/synthesizer.py                           |  58 +--
 tests/cli/test_cli.py                         |  97 +++++
 tests/cli/test_message.py                     |  41 ++
 tests/conftest.py                             |   5 +-
 tests/dataloader/conftest.py                  |   0
 tests/dataloader/test_cacher.py               |   2 -
 tests/manager/test_exporter.py                |  22 +
 tests/models/test_copula.py                   |   4 +-
 tests/test_csv_exporter.py                    |  37 ++
 tests/test_synthesizer.py                     |   2 +-
 36 files changed, 1018 insertions(+), 88 deletions(-)
 create mode 100644 docs/source/api_reference/data_exporters/base.rst
 create mode 100644 docs/source/api_reference/data_exporters/csv_exporter.rst
 create mode 100644 docs/source/api_reference/data_exporters/extension.rst
 create mode 100644 docs/source/api_reference/data_exporters/index.rst
 create mode 100644 docs/source/api_reference/data_exporters/manager.rst
 create mode 100644 example/extension/dummyexporter/dummyexporter/__init__.py
 create mode 100644 example/extension/dummyexporter/dummyexporter/dummyexporter.py
 create mode 100644 example/extension/dummyexporter/pyproject.toml
 create mode 100644 example/extension/dummyexporter/tests/test_registed_exporter.py
 create mode 100644 sdgx/cli/message.py
 delete mode 100644 sdgx/cli/models.py
 create mode 100644 sdgx/cli/utils.py
 rename sdgx/{cli/exporter.py => data_exporters/__init__.py} (100%)
 create mode 100644 sdgx/data_exporters/base.py
 create mode 100644 sdgx/data_exporters/csv_exporter.py
 create mode 100644 sdgx/data_exporters/extension.py
 create mode 100644 sdgx/data_exporters/manager.py
 create mode 100644 tests/cli/test_cli.py
 create mode 100644 tests/cli/test_message.py
 delete mode 100644 tests/dataloader/conftest.py
 create mode 100644 tests/manager/test_exporter.py
 create mode 100644 tests/test_csv_exporter.py

diff --git a/.github/workflows/extension.yml b/.github/workflows/extension.yml
index e2b9c37b..3fd2ab78 100644
--- a/.github/workflows/extension.yml
+++ b/.github/workflows/extension.yml
@@ -27,6 +27,7 @@ jobs:
         python -m pip install -e .[test]
     - name: Install all packages in example/extension
       run: |
+        python -m pip install -e example/extension/dummyexporter[test]
         python -m pip install -e example/extension/dummymetadatainspector[test]
         python -m pip install -e example/extension/dummycache[test]
         python -m pip install -e example/extension/dummydataconnector[test]
diff --git a/docs/source/api_reference/data_exporters/base.rst b/docs/source/api_reference/data_exporters/base.rst
new file mode 100644
index 00000000..0ac564dc
--- /dev/null
+++ b/docs/source/api_reference/data_exporters/base.rst
@@ -0,0 +1,9 @@
+Base Class for DataExporter
+===========================
+
+.. autoclass:: sdgx.data_exporters.base.DataExporter
+   :members:
+   :undoc-members:
+   :inherited-members:
+   :show-inheritance:
+   :private-members:
diff --git a/docs/source/api_reference/data_exporters/csv_exporter.rst b/docs/source/api_reference/data_exporters/csv_exporter.rst
new file mode 100644
index 00000000..1c1e2725
--- /dev/null
+++ b/docs/source/api_reference/data_exporters/csv_exporter.rst
@@ -0,0 +1,10 @@
+CsvExporter
+=====================================
+
+
+.. autoclass:: sdgx.data_exporters.csv_exporter.CsvExporter
+   :members:
+   :undoc-members:
+   :inherited-members:
+   :show-inheritance:
+   :private-members:
diff --git a/docs/source/api_reference/data_exporters/extension.rst b/docs/source/api_reference/data_exporters/extension.rst
new file mode 100644
index 00000000..65b29e46
--- /dev/null
+++ b/docs/source/api_reference/data_exporters/extension.rst
@@ -0,0 +1,11 @@
+.. _api_reference/data-exporters-extension:
+
+Extension hookspec
+============================
+
+.. automodule:: sdgx.data_exporters.extension
+   :members:
+   :undoc-members:
+   :inherited-members:
+   :show-inheritance:
+   :private-members:
diff --git a/docs/source/api_reference/data_exporters/index.rst b/docs/source/api_reference/data_exporters/index.rst
new file mode 100644
index 00000000..ffd27472
--- /dev/null
+++ b/docs/source/api_reference/data_exporters/index.rst
@@ -0,0 +1,24 @@
+Data Exporter
+========================================================
+
+.. toctree::
+    :maxdepth: 1
+
+    Base Class for DataExporter <base>
+
+Built-in DataExporter
+-----------------------------
+
+.. toctree::
+    :maxdepth: 2
+
+    CsvExporter <csv_exporter>
+
+Custom DataExporter Relevant
+-----------------------------
+
+.. toctree::
+    :maxdepth: 2
+
+    Extension hookspec <extension>
+    DataExporterManager <manager>
diff --git a/docs/source/api_reference/data_exporters/manager.rst b/docs/source/api_reference/data_exporters/manager.rst
new file mode 100644
index 00000000..ab96c2ac
--- /dev/null
+++ b/docs/source/api_reference/data_exporters/manager.rst
@@ -0,0 +1,9 @@
+DataExporterManager
+=================================
+
+.. autoclass:: sdgx.data_exporters.manager.DataExporterManager
+   :members:
+   :undoc-members:
+   :inherited-members:
+   :show-inheritance:
+   :private-members:
diff --git a/docs/source/api_reference/index.rst b/docs/source/api_reference/index.rst
index e08737a8..b234122f 100644
--- a/docs/source/api_reference/index.rst
+++ b/docs/source/api_reference/index.rst
@@ -14,6 +14,7 @@ API Reference
    Data Processor
    Models
    Metadata and Inspectors
+   Data Exporter <data_exporters/index>
    Manager
    Exceptions
    Utils
diff --git a/docs/source/developer_guides/extension/index.rst b/docs/source/developer_guides/extension/index.rst
index e028ddd2..885c6053 100644
--- a/docs/source/developer_guides/extension/index.rst
+++ b/docs/source/developer_guides/extension/index.rst
@@ -23,8 +23,18 @@ View latest extension example on `GitHub `
 
-- :ref:`Data Connector `
-- :ref:`Data Processor `
-- :ref:`Inspector for Metadata `
-- :ref:`Model `
+- :ref:`API Reference for extended Data Connector `:
+  :ref:`Data Connector ` is used to connect to data sources.
+- :ref:`API Reference for extended Cacher for DataLoader `:
+  :ref:`Cacher ` is used for improving performance,
+  reducing network overhead, and supporting large datasets.
+- :ref:`API Reference for extended Data Processor `:
+  :ref:`Data Processor ` is used to pre-process and post-process data.
+  It is useful for business logic.
+- :ref:`API Reference for extended Inspector for Metadata `:
+  :ref:`Inspector ` is used to extract metadata such as patterns, types, etc. from raw data.
+- :ref:`API Reference for extended Model `:
+  :ref:`Model ` is fitted on processed data and used to generate synthetic data.
+- :ref:`API Reference for extended Data Exporter `:
+  :ref:`Data Exporter ` is used to export data to a destination.
+  Use it from the CLI or as a library to save your processed or synthetic data.
diff --git a/docs/source/user_guides/cli.rst b/docs/source/user_guides/cli.rst
index 2c91d74a..9001cd0d 100644
--- a/docs/source/user_guides/cli.rst
+++ b/docs/source/user_guides/cli.rst
@@ -1,2 +1,24 @@
 Command Line Interface
 ==================================================
+
+The Command Line Interface (CLI) is designed to simplify the usage of SDG and to enable other programs to use SDG in a more convenient way.
+
+There are two main commands in the CLI:
+
+- ``fit``: Fit, finetune, or retrain the model, then save the final model to a specified path.
+- ``sample``: Load an existing model and sample synthetic data.
+
+As SDG supports a plug-in system, users can list all available components via the ``list-{component}`` commands.
+
+.. Note::
+
+   If you want to use SDG as a library, please refer to :ref:`Use Synthetic Data Generator as a library `.
+
+   If you want to extend SDG with your own components, please refer to :ref:`Developer guides for extension `.
+
+CLI for synthetic single-table data
+--------------------------------------------------
+
+.. click:: sdgx.cli.main:cli
+   :prog: sdgx
+   :nested: full
diff --git a/example/extension/dummyexporter/dummyexporter/__init__.py b/example/extension/dummyexporter/dummyexporter/__init__.py
new file mode 100644
index 00000000..3dc1f76b
--- /dev/null
+++ b/example/extension/dummyexporter/dummyexporter/__init__.py
@@ -0,0 +1 @@
+__version__ = "0.1.0"
diff --git a/example/extension/dummyexporter/dummyexporter/dummyexporter.py b/example/extension/dummyexporter/dummyexporter/dummyexporter.py
new file mode 100644
index 00000000..a9b2e9e9
--- /dev/null
+++ b/example/extension/dummyexporter/dummyexporter/dummyexporter.py
@@ -0,0 +1,15 @@
+from __future__ import annotations
+
+from sdgx.data_exporters.base import DataExporter
+
+
+class MyOwnExporter(DataExporter):
+    ...
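+    # NOTE: a real exporter would override ``write`` from
+    # ``sdgx.data_exporters.base.DataExporter``; the ``...`` body is enough
+    # here because this dummy only exercises plug-in registration.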
+ + +from sdgx.data_exporters.extension import hookimpl + + +@hookimpl +def register(manager): + manager.register("MyOwnExporter", MyOwnExporter) diff --git a/example/extension/dummyexporter/pyproject.toml b/example/extension/dummyexporter/pyproject.toml new file mode 100644 index 00000000..fd6d50bc --- /dev/null +++ b/example/extension/dummyexporter/pyproject.toml @@ -0,0 +1,27 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "dummyexporter" +dependencies = ["sdgx"] +dynamic = ["version"] +requires-python = ">=3.8" +classifiers = [ + "Programming Language :: Python :: 3", + 'Programming Language :: Python :: 3.8', + 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3.10', + 'Programming Language :: Python :: 3.11', +] +[project.optional-dependencies] +test = ["pytest"] + +[tool.check-manifest] +ignore = [".*"] + +[tool.hatch.version] +path = "dummyexporter/__init__.py" + +[project.entry-points."sdgx.data_exporter"] +dummyexporter = "dummyexporter.dummyexporter" diff --git a/example/extension/dummyexporter/tests/test_registed_exporter.py b/example/extension/dummyexporter/tests/test_registed_exporter.py new file mode 100644 index 00000000..3b25827c --- /dev/null +++ b/example/extension/dummyexporter/tests/test_registed_exporter.py @@ -0,0 +1,16 @@ +import pytest + +from sdgx.data_exporters.manager import DataExporterManager + + +@pytest.fixture +def manager(): + yield DataExporterManager() + + +def test_registed_exporter(manager: DataExporterManager): + assert manager._normalize_name("MyOwnExporter") in manager.registed_exporters + + +if __name__ == "__main__": + pytest.main(["-vv", "-s", __file__]) diff --git a/sdgx/cli/main.py b/sdgx/cli/main.py index ee53eedd..5888f06b 100644 --- a/sdgx/cli/main.py +++ b/sdgx/cli/main.py @@ -1,96 +1,390 @@ +from __future__ import annotations + import json +import time from pathlib import Path import click -import pandas from sdgx.cachers.manager import CacherManager +from sdgx.cli.utils import cli_wrapper from sdgx.data_connectors.manager import DataConnectorManager +from sdgx.data_exporters.manager import DataExporterManager from sdgx.data_processors.manager import DataProcessorManager +from sdgx.log import logger from sdgx.models.manager import ModelManager +from sdgx.synthesizer import Synthesizer @click.command() +@click.option( + "--save_dir", + type=str, + required=True, + default="", + help="The directory to save the synthesizer", +) @click.option( "--model", - help="Name of model, use `sdgx list-models` to list all available models.", + type=str, required=True, + help="The name of the model.", ) @click.option( - "--model_params", - default="{}", - help="[Json-string] Parameters for model.", + "--model_path", + type=str, + default=None, + help="The path of the model to load", ) @click.option( - "--input_path", - help="Path of input data.", - required=True, + "--model_kwargs", + type=str, + default=None, + help="[Json String] The kwargs of the model for initialization", ) @click.option( - "--input_type", - default="csv", - help="Type of input data, will be used as `pandas.read_{input_type}`.", + "--load_dir", + type=str, + default=None, + help="The directory to load the synthesizer, if it is specified, ``model_path`` will be ignored.", ) @click.option( - "--read_params", - default="{}", - help="[Json-string] Parameters for `pandas.read_{input_type}`.", + "--metadata_path", + type=str, + default=None, + help="The path of the metadata to load", ) @click.option( - 
"--fit_params", - default="{}", - help="[Json-string] Parameters for `model.fit`.", + "--data_connector", + type=str, + default=None, + help="The name of the data connector to use", ) @click.option( - "--output_path", - help="Path to save the model.", - required=True, + "--data_connector_kwargs", + type=str, + default=None, + help="[Json String] The kwargs of the data connector to use", +) +@click.option( + "--raw_data_loaders_kwargs", + type=str, + default=None, + help="[Json String] The kwargs of the raw data loader to use", +) +@click.option( + "--processed_data_loaders_kwargs", + type=str, + default=None, + help="[Json String] The kwargs of the processed data loader to use", ) +@click.option( + "--data_processors", + type=str, + default=None, + help="[Comma separated list] The name of the data processors to use, e.g. 'processor_x,processor_y'", +) +@click.option( + "--data_processors_kwargs", + type=str, + default=None, + help="[Json String] The kwargs of the data processors to use", +) +@click.option( + "--inspector_max_chunk", + type=int, + default=None, + help="The max chunk of the inspector to load", +) +@click.option( + "--metadata_include_inspectors", + type=str, + default=None, + help="[Comma separated list] The name of the inspectors to include, e.g. 'inspector_x,inspector_y'", +) +@click.option( + "--metadata_exclude_inspectors", + type=str, + default=None, + help="[Comma separated list] The name of the inspectors to exclude, e.g. 'inspector_x,inspector_y'", +) +@click.option( + "--inspector_init_kwargs", + type=str, + default=None, + help="[Json String] The kwargs of the inspector to use", +) +@click.option( + "--model_fit_kwargs", + type=str, + default=None, + help="[Json String] The kwargs of the model fit method", +) +@click.option( + "--dry_run", + type=bool, + default=False, + help="Only init the synthesizer without fitting and save.", +) +@cli_wrapper def fit( - model, - model_params, - input_path, - input_type, - read_params, - fit_params, - output_path, + save_dir: str, + model: str, + model_path: str | None, + model_kwargs: str | None, + load_dir: str | None, + metadata_path: str | None, + data_connector: str | None, + data_connector_kwargs: str | None, + raw_data_loaders_kwargs: str | None, + processed_data_loaders_kwargs: str | None, + data_processors: str | None, + data_processors_kwargs: str | None, + # ``fit`` args + inspector_max_chunk: int | None, + metadata_include_inspectors: str | None, + metadata_exclude_inspectors: str | None = None, + inspector_init_kwargs: str | None = None, + model_fit_kwargs: str | None = None, + # Others + dry_run: bool = False, ): - model_params = json.loads(model_params) - read_params = json.loads(read_params) - fit_params = json.loads(fit_params) + """ + Fit the synthesizer or load a synthesizer for fitnuning/retraining/continue training... 
+ """ + if data_processors is not None: + data_processors = data_processors.strip().split(",") + + if model_kwargs is not None: + model_kwargs = json.loads(model_kwargs) + if data_connector_kwargs is not None: + data_connector_kwargs = json.loads(data_connector_kwargs) + if raw_data_loaders_kwargs is not None: + raw_data_loaders_kwargs = json.loads(raw_data_loaders_kwargs) + if processed_data_loaders_kwargs is not None: + processed_data_loaders_kwargs = json.loads(processed_data_loaders_kwargs) + if data_processors_kwargs is not None: + data_processors_kwargs = json.loads(data_processors_kwargs) + + fit_kwargs = {} + if inspector_max_chunk is not None: + fit_kwargs["inspector_max_chunk"] = inspector_max_chunk + if metadata_include_inspectors is not None: + fit_kwargs["metadata_include_inspectors"] = metadata_include_inspectors.strip().split(",") + if metadata_exclude_inspectors is not None: + fit_kwargs["metadata_exclude_inspectors"] = metadata_exclude_inspectors.strip().split(",") + if inspector_init_kwargs is not None: + fit_kwargs["inspector_init_kwargs"] = json.loads(inspector_init_kwargs) + if model_fit_kwargs is not None: + fit_kwargs["model_fit_kwargs"] = json.loads(model_fit_kwargs) + if not save_dir: + save_dir = Path(f"./sdgx-fit-model-{model}-{time.time()}") + else: + save_dir = Path(save_dir).expanduser().resolve() + save_dir.mkdir(parents=True, exist_ok=True) - model = ModelManager().init_model(model, **model_params) - input_method = getattr(pandas, f"read_{input_type}") - if not input_method: - raise NotImplementedError(f"Pandas not support read_{input_type}") - df = input_method(input_path, **read_params) - model.fit(df, **fit_params) + if load_dir: + if model_path: + logger.warning( + "Both ``model_path`` and ``load_dir`` are specified, ``model_path`` will be ignored." + ) + synthesizer = Synthesizer.load( + load_dir=load_dir, + metadata_path=metadata_path, + data_connector=data_connector, + data_connector_kwargs=data_connector_kwargs, + raw_data_loaders_kwargs=raw_data_loaders_kwargs, + processed_data_loaders_kwargs=processed_data_loaders_kwargs, + data_processors=data_processors, + data_processors_kwargs=data_processors_kwargs, + ) + else: + if model_kwargs and model_path: + logger.warning( + "Both ``model_kwargs`` and ``model_path`` are specified, ``model_kwargs`` will be ignored." 
+ ) + synthesizer = Synthesizer( + model=model, + model_kwargs=model_kwargs, + model_path=model_path, + metadata_path=metadata_path, + data_connector=data_connector, + data_connector_kwargs=data_connector_kwargs, + raw_data_loaders_kwargs=raw_data_loaders_kwargs, + processed_data_loaders_kwargs=processed_data_loaders_kwargs, + data_processors=data_processors, + data_processors_kwargs=data_processors_kwargs, + ) - Path(output_path).mkdir(parents=True, exist_ok=True) - model.save(output_path) + if dry_run: + return + + synthesizer.fit(**fit_kwargs) + save_dir = synthesizer.save(save_dir) + + return save_dir.absolute().as_posix() @click.command() +@click.option( + "--load_dir", + type=str, + required=True, + help="The directory to load the synthesizer.", +) +@click.option( + "--model", + type=str, + required=True, + help="The name of the model.", +) +@click.option( + "--raw_data_loaders_kwargs", + type=str, + default=None, + help="[Json String] The kwargs of the raw data loaders.", +) +@click.option( + "--processed_data_loaders_kwargs", + type=str, + default=None, + help="[Json String] The kwargs of the processed data loaders.", +) +@click.option( + "--data_processors", + type=str, + default=None, + help="[Comma separated list] The name of the data processors, e.g. 'data_processor_1,data_processor_2'.", +) +@click.option( + "--data_processors_kwargs", + type=str, + default=None, + help="[Json String] The kwargs of the data processors.", +) +@click.option( + "--count", + type=int, + default=100, + help="The number of samples to generate.", +) +@click.option( + "--chunksize", + type=int, + default=None, + help="The size of each chunk. If count is very large, chunksize is recommended.", +) +@click.option( + "--model_sample_args", + type=str, + default=None, + help="[Json String] The kwargs of the model.sample.", +) +@click.option( + "--data_exporter", + type=str, + default="CsvExporter", + required=True, + help="The name of the data exporter.", +) +@click.option( + "--data_exporter_kwargs", + type=str, + default=None, + help="[Json String] The kwargs of the data exporter.", +) +@click.option( + "--export_dst", + type=str, + default=None, + help="The destination of the exported data.", +) +@click.option( + "--dry_run", + type=bool, + default=False, + help="Dry run. Only initialize the synthesizer without sampling.", +) +@cli_wrapper def sample( - model_path, - output_path, - write_type, - write_param, + load_dir: str, + model: str, + raw_data_loaders_kwargs: str | None, + processed_data_loaders_kwargs: str | None, + data_processors: str | None, + data_processors_kwargs: str | None, + # ``sample`` args + count: int, + chunksize: int | None, + model_sample_args: str | None, + # ``exporter`` args + data_exporter: str, + data_exporter_kwargs: str | None, + export_dst: str | None, + dry_run: bool, ): - model = ModelManager.load(model_path) - # TODO: Model not have `sample` in Base Class yet - # sampled_data = model.sample() - # write_method = getattr(sampled_data, f"to_{write_type}") - # write_method(output_path, write_param) + """ + Load a synthesizer and sample. + + ``load_dir`` should contain model and metadata. Please check :ref:`Synthesizer `'s `load` method for more details. 
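+
+    Example (a minimal sketch; the paths are hypothetical and mirror the
+    ``fit`` example above, relying on the default ``CsvExporter``):
+
+    .. code-block:: shell
+
+        sdgx sample --model CTGAN \
+            --load_dir ./my-synthesizer \
+            --count 1000 \
+            --export_dst ./sample-data.csv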
+ """ + if data_processors is not None: + data_processors = data_processors.strip().split(",") + + if raw_data_loaders_kwargs is not None: + raw_data_loaders_kwargs = json.loads(raw_data_loaders_kwargs) + if processed_data_loaders_kwargs is not None: + processed_data_loaders_kwargs = json.loads(processed_data_loaders_kwargs) + if data_processors_kwargs is not None: + data_processors_kwargs = json.loads(data_processors_kwargs) + + if model_sample_args is not None: + model_sample_args = json.loads(model_sample_args) + + if data_exporter_kwargs is not None: + data_exporter_kwargs = json.loads(data_exporter_kwargs) + else: + data_exporter_kwargs = {} + + if not export_dst: + # Assume export to current directory + export_dst = Path(f"./sdgx-{model}-{time.time()}/sample-data.csv").expanduser().resolve() + + synthesizer = Synthesizer.load( + load_dir=load_dir, + model=model, + raw_data_loaders_kwargs=raw_data_loaders_kwargs, + processed_data_loaders_kwargs=processed_data_loaders_kwargs, + data_processors=data_processors, + data_processors_kwargs=data_processors_kwargs, + ) + + exporter = DataExporterManager().init_exporter(data_exporter, **data_exporter_kwargs) + + if dry_run: + return + + exporter.write( + export_dst, + synthesizer.sample( + count=count, + chunksize=chunksize, + model_sample_args=model_sample_args, + ), + ) + + return export_dst @click.command() +@cli_wrapper def list_models(): for model_name, model_cls in ModelManager().registed_models.items(): print(f"{model_name} is registed as class: {model_cls}.") @click.command() +@cli_wrapper def list_data_connectors(): for ( model_name, @@ -100,6 +394,7 @@ def list_data_connectors(): @click.command() +@cli_wrapper def list_data_processors(): for ( model_name, @@ -109,11 +404,19 @@ def list_data_processors(): @click.command() +@cli_wrapper def list_cachers(): for model_name, model_cls in CacherManager().registed_cachers.items(): print(f"{model_name} is registed as class: {model_cls}.") +@click.command() +@cli_wrapper +def list_data_exporters(): + for model_name, model_cls in DataExporterManager().registed_exporters.items(): + print(f"{model_name} is registed as class: {model_cls}.") + + @click.group() def cli(): pass @@ -125,6 +428,7 @@ def cli(): cli.add_command(list_data_connectors) cli.add_command(list_data_processors) cli.add_command(list_cachers) +cli.add_command(list_data_exporters) if __name__ == "__main__": diff --git a/sdgx/cli/message.py b/sdgx/cli/message.py new file mode 100644 index 00000000..1fabad46 --- /dev/null +++ b/sdgx/cli/message.py @@ -0,0 +1,48 @@ +from pydantic import BaseModel + +from sdgx.exceptions import SdgxError + + +class ExitMessage(BaseModel): + code: int + msg: str + payload: dict = {} + + def _dumo_json(self) -> str: + return self.model_dump_json() + + def send(self): + print(self._dumo_json(), flush=True, end="") + + +class NormalMessage(ExitMessage): + code: int = 0 + msg: str = "Success" + + @classmethod + def from_return_val(cls, return_val) -> "NormalMessage": + if isinstance(return_val, dict): + payload = return_val + else: + payload = {"return_val": return_val} + return cls(code=0, msg="Success", payload=payload) + + +class ExceptionMessage(ExitMessage): + @classmethod + def from_exception(cls, e: Exception) -> "ExceptionMessage": + if isinstance(e, SdgxError): + return cls( + code=e.ERROR_CODE, + msg=str(e), + payload={ + "details": "Synthetic Data Generator Error, please check logs and raise an issue on https://github.com/hitsz-ids/synthetic-data-generator." 
+                },
+            )
+        return cls(
+            code=-1,
+            msg=str(e),
+            payload={
+                "details": "Unknown Exceptions, please check logs and raise an issue on https://github.com/hitsz-ids/synthetic-data-generator."
+            },
+        )
diff --git a/sdgx/cli/models.py b/sdgx/cli/models.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/sdgx/cli/utils.py b/sdgx/cli/utils.py
new file mode 100644
index 00000000..7379546f
--- /dev/null
+++ b/sdgx/cli/utils.py
@@ -0,0 +1,28 @@
+from functools import wraps
+
+import click
+
+from sdgx.cli.message import ExceptionMessage, NormalMessage
+from sdgx.log import LOG_TO_FILE, add_log_file_handler, logger
+
+
+def cli_wrapper(func):
+    @click.option("--json_output", type=bool, default=False, help="Exit with json output.")
+    @click.option("--log_to_file", type=bool, default=False, help="Log to file.")
+    @wraps(func)
+    def wrapper(json_output, log_to_file, *args, **kwargs):
+        if log_to_file and not LOG_TO_FILE:
+            add_log_file_handler()
+        try:
+            func(*args, **kwargs)
+        except Exception as e:
+            logger.exception(e)
+            if json_output:
+                ExceptionMessage.from_exception(e).send()
+            exit(getattr(e, "EXIT_CODE", -1))
+        else:
+            if json_output:
+                NormalMessage().send()
+            exit(0)
+
+    return wrapper
diff --git a/sdgx/cli/exporter.py b/sdgx/data_exporters/__init__.py
similarity index 100%
rename from sdgx/cli/exporter.py
rename to sdgx/data_exporters/__init__.py
diff --git a/sdgx/data_exporters/base.py b/sdgx/data_exporters/base.py
new file mode 100644
index 00000000..c88cf524
--- /dev/null
+++ b/sdgx/data_exporters/base.py
@@ -0,0 +1,10 @@
+from __future__ import annotations
+
+from typing import Any, Generator
+
+import pandas as pd
+
+
+class DataExporter:
+    def write(self, dst: Any, data: pd.DataFrame | Generator[pd.DataFrame, None, None]) -> None:
+        raise NotImplementedError
diff --git a/sdgx/data_exporters/csv_exporter.py b/sdgx/data_exporters/csv_exporter.py
new file mode 100644
index 00000000..f00a9906
--- /dev/null
+++ b/sdgx/data_exporters/csv_exporter.py
@@ -0,0 +1,41 @@
+from __future__ import annotations
+
+from functools import partial
+from pathlib import Path
+from typing import Generator
+
+import pandas as pd
+
+from sdgx.data_exporters.base import DataExporter
+from sdgx.exceptions import CannotExportError
+
+
+class CsvExporter(DataExporter):
+    def __init__(self, **kwargs):
+        self.to_csv_kwargs = kwargs
+        if "header" in self.to_csv_kwargs:
+            self.to_csv_kwargs.pop("header")
+        if "index" in self.to_csv_kwargs:
+            self.to_csv_kwargs.pop("index")
+
+    def write(
+        self,
+        dst: str | Path,
+        data: pd.DataFrame | Generator[pd.DataFrame, None, None],
+    ) -> None:
+        if isinstance(data, pd.DataFrame):
+            data.to_csv(dst, index=False, **self.to_csv_kwargs)
+        elif isinstance(data, Generator):
+            with open(dst, "a") as file:
+                for df in data:
+                    df.to_csv(file, header=file.tell() == 0, index=False, **self.to_csv_kwargs)
+        else:
+            raise CannotExportError(f"Cannot export data of type {type(data)} to csv")
+
+
+from sdgx.data_exporters.extension import hookimpl
+
+
+@hookimpl
+def register(manager):
+    manager.register("CsvExporter", CsvExporter)
diff --git a/sdgx/data_exporters/extension.py b/sdgx/data_exporters/extension.py
new file mode 100644
index 00000000..1d6d205a
--- /dev/null
+++ b/sdgx/data_exporters/extension.py
@@ -0,0 +1,57 @@
+from __future__ import annotations
+
+import pluggy
+
+project_name = "sdgx.data_exporter"
+"""
+The entry-point name of this extension.
+
+Should be used in ``pyproject.toml`` as ``[project.entry-points."{project_name}"]``
+"""
+hookimpl = pluggy.HookimplMarker(project_name)
+"""
+Hookimpl marker for this extension; extension modules should use this marker.
+
+Example:
+
+    .. code-block:: python
+
+        @hookimpl
+        def register(manager):
+            ...
+"""
+
+hookspec = pluggy.HookspecMarker(project_name)
+
+
+@hookspec
+def register(manager):
+    """
+    For more information about this function, please check the :ref:`manager`
+
+    We provide an example package for you in ``{project_root}/example/extension/dummyexporter``.
+
+    Example:
+
+        .. code-block:: python
+
+            class MyOwnExporter(DataExporter):
+                ...
+
+            from sdgx.data_exporters.extension import hookimpl
+
+            @hookimpl
+            def register(manager):
+                manager.register("MyOwnExporter", MyOwnExporter)
+
+
+    Configure ``project.entry-points`` so that we can find it:
+
+    .. code-block:: toml
+
+        [project.entry-points."sdgx.data_exporter"]
+        {whatever-name} = "{package}.{path}.{to}.{file-with-hookimpl-function}"
+
+
+    You can verify it by `sdgx list-data-exporters`.
+    """
diff --git a/sdgx/data_exporters/manager.py b/sdgx/data_exporters/manager.py
new file mode 100644
index 00000000..592d56a3
--- /dev/null
+++ b/sdgx/data_exporters/manager.py
@@ -0,0 +1,25 @@
+from __future__ import annotations
+
+from typing import Any
+
+from sdgx import data_exporters
+from sdgx.data_exporters import extension
+from sdgx.data_exporters.base import DataExporter
+from sdgx.data_exporters.extension import project_name as PROJECT_NAME
+from sdgx.manager import Manager
+
+
+class DataExporterManager(Manager):
+    register_type = DataExporter
+    project_name = PROJECT_NAME
+    hookspecs_model = extension
+
+    @property
+    def registed_exporters(self):
+        return self.registed_cls
+
+    def load_all_local_model(self):
+        self._load_dir(data_exporters)
+
+    def init_exporter(self, exporter_name, **kwargs: dict[str, Any]) -> DataExporter:
+        return self.init(exporter_name, **kwargs)
diff --git a/sdgx/exceptions.py b/sdgx/exceptions.py
index b73ae4e3..03dabb6f 100644
--- a/sdgx/exceptions.py
+++ b/sdgx/exceptions.py
@@ -3,62 +3,108 @@ class SdgxError(Exception):
     Base class for exceptions in this module.
     """
 
+    EXIT_CODE = 100
+    ERROR_CODE = 1001
+
 
 class NonParametricError(Exception):
     """
     Exception to indicate that a model is not parametric.
     """
 
+    EXIT_CODE = 101
+    ERROR_CODE = 2001
+
 
 class ManagerError(SdgxError):
     """
     Exception to indicate that exception when using manager.
     """
 
+    EXIT_CODE = 102
+    ERROR_CODE = 3000
+
 
 class NotFoundError(ManagerError):
     """
     Exception to indicate that a model is not found.
     """
 
+    ERROR_CODE = 3001
+
 
 class RegisterError(ManagerError):
     """
     Exception to indicate that exception when registering.
     """
 
+    ERROR_CODE = 3002
+
 
 class InitializationError(ManagerError):
     """
     Exception to indicate that exception when initializing model.
     """
 
+    ERROR_CODE = 3003
+
 
 class ManagerLoadModelError(ManagerError):
     """
     Exception to indicate that exception when loading model for :ref:`ModelManager`.
     """
 
+    ERROR_CODE = 3004
+
 
-class SynthesizerInitError(ManagerError):
+class SynthesizerError(SdgxError):
     """
     Exception to indicate that exception when synthesizing model.
     """
 
+    EXIT_CODE = 103
+    ERROR_CODE = 4000
+
+
+class SynthesizerInitError(SynthesizerError):
+    ERROR_CODE = 4001
+
+
+class SynthesizerSampleError(SynthesizerError):
+    ERROR_CODE = 4002
+
 
 class CacheError(SdgxError):
     """
     Exception to indicate that exception when using cache.
""" + EXIT_CODE = 104 + ERROR_CODE = 5001 + class MetadataInitError(SdgxError): """ Exception to indicate that exception when initializing metadata. """ + EXIT_CODE = 105 + ERROR_CODE = 6001 + class DataLoaderInitError(SdgxError): """ Exception to indicate that exception when initializing dataloader. """ + + EXIT_CODE = 106 + ERROR_CODE = 7001 + + +class CannotExportError(SdgxError): + """ + Exception to indicate that exception when exporting data. + """ + + EXIT_CODE = 107 + ERROR_CODE = 8001 diff --git a/sdgx/log.py b/sdgx/log.py index 0c569a9d..c9133063 100644 --- a/sdgx/log.py +++ b/sdgx/log.py @@ -7,10 +7,15 @@ from loguru import logger -if LOG_TO_FILE: + +def add_log_file_handler(): logger.add( "sdgx-{time}.log", rotation="10 MB", ) -__all__ = ["logger"] + +if LOG_TO_FILE: + add_log_file_handler() + +__all__ = ["logger", "LOG_TO_FILE", "add_log_file_handler", "USER_DEFINED_LOG_LEVEL"] diff --git a/sdgx/models/manager.py b/sdgx/models/manager.py index c7c55696..320c9303 100644 --- a/sdgx/models/manager.py +++ b/sdgx/models/manager.py @@ -36,10 +36,12 @@ def init_model(self, model_name, **kwargs: dict[str, Any]) -> SynthesizerModel: return self.init(model_name, **kwargs) def load(self, model: type[SynthesizerModel] | str, model_path) -> SynthesizerModel: - if not isinstance(model, type) or isinstance(model, str): + if not (isinstance(model, type) or isinstance(model, str)): raise ManagerLoadModelError( "model must be type of SynthesizerModel or str for model_name" ) + if isinstance(model, str): + model = self._normalize_name(model) if isinstance(model, str) and model not in self.registed_models: raise ManagerLoadModelError(f"{model} is not registered.") diff --git a/sdgx/models/ml/single_table/ctgan.py b/sdgx/models/ml/single_table/ctgan.py index fca76524..4d166809 100644 --- a/sdgx/models/ml/single_table/ctgan.py +++ b/sdgx/models/ml/single_table/ctgan.py @@ -204,8 +204,10 @@ def __init__( self._generator = None self._ndarry_loader = None - def fit(self, metadata: Metadata, dataloader: DataLoader, *args, **kwargs): + def fit(self, metadata: Metadata, dataloader: DataLoader, epochs=None, *args, **kwargs): discrete_columns = metadata.get("discrete_columns", []) + if epochs is not None: + self._epochs = epochs self._pre_fit(dataloader, discrete_columns) self._fit(len(self._ndarry_loader)) diff --git a/sdgx/synthesizer.py b/sdgx/synthesizer.py index ecd5dc99..5b49ebe5 100644 --- a/sdgx/synthesizer.py +++ b/sdgx/synthesizer.py @@ -13,7 +13,7 @@ from sdgx.data_models.metadata import Metadata from sdgx.data_processors.base import DataProcessor from sdgx.data_processors.manager import DataProcessorManager -from sdgx.exceptions import SynthesizerInitError +from sdgx.exceptions import SynthesizerInitError, SynthesizerSampleError from sdgx.log import logger from sdgx.models.base import SynthesizerModel from sdgx.models.manager import ModelManager @@ -35,9 +35,9 @@ class Synthesizer: metadata_path (str | Path, optional): The path to the metadata file. Defaults to None. Used to load the metadata if ``metadata`` is None. data_connector (DataConnector | type[DataConnector] | str, optional): The data connector to use. Defaults to None. When data_connector is a string, it must be registered in :class:`~sdgx.data_connectors.manager.DataConnectorManager`. - data_connectors_kwargs (dict[str, Any], optional): The keyword arguments for data connectors. Defaults to None. + data_connector_kwargs (dict[str, Any], optional): The keyword arguments for data connectors. Defaults to None. 
        raw_data_loaders_kwargs (dict[str, Any], optional): The keyword arguments for raw data loaders. Defaults to None.
-        processored_data_loaders_kwargs (dict[str, Any], optional): The keyword arguments for processed data loaders. Defaults to None.
+        processed_data_loaders_kwargs (dict[str, Any], optional): The keyword arguments for processed data loaders. Defaults to None.
         data_processors (list[str | DataProcessor | type[DataProcessor]], optional): The data processors to use. Defaults to None.
             When data_processor is a string, it must be registered in :class:`~sdgx.data_processors.manager.DataProcessorManager`.
         data_processors_kwargs (dict[str, dict[str, Any]], optional): The keyword arguments for data processors. Defaults to None.
@@ -78,16 +78,16 @@ def __init__(
         metadata: None | Metadata = None,
         metadata_path: None | str | Path = None,
         data_connector: None | str | DataConnector | type[DataConnector] = None,
-        data_connectors_kwargs: None | dict[str, Any] = None,
+        data_connector_kwargs: None | dict[str, Any] = None,
         raw_data_loaders_kwargs: None | dict[str, Any] = None,
-        processored_data_loaders_kwargs: None | dict[str, Any] = None,
+        processed_data_loaders_kwargs: None | dict[str, Any] = None,
         data_processors: None | list[str | DataProcessor | type[DataProcessor]] = None,
-        data_processors_kwargs: None | dict[str, dict[str, Any]] = None,
+        data_processors_kwargs: None | dict[str, Any] = None,
     ):
         # Init data connectors
         if isinstance(data_connector, str) or isinstance(data_connector, type):
             data_connector = DataConnectorManager().init_data_connector(
-                data_connector, **(data_connectors_kwargs or {})
+                data_connector, **(data_connector_kwargs or {})
             )
         if data_connector:
             self.dataloader = DataLoader(
@@ -149,7 +149,7 @@ def __init__(
             raise SynthesizerInitError("model or model_path must be specified")
 
         # Other arguments
-        self.processored_data_loaders_kwargs = processored_data_loaders_kwargs or {}
+        self.processed_data_loaders_kwargs = processed_data_loaders_kwargs or {}
 
     def save(self, save_dir: str | Path) -> Path:
         """
@@ -163,6 +163,7 @@ def save(self, save_dir: str | Path) -> Path:
         """
         save_dir = Path(save_dir).expanduser().resolve()
         save_dir.mkdir(parents=True, exist_ok=True)
+        logger.info(f"Saving synthesizer to {save_dir}")
 
         if self.metadata:
             self.metadata.save(save_dir / self.METADATA_SAVE_NAME)
@@ -179,15 +180,17 @@ def load(
         model: str | type[SynthesizerModel],
         metadata: None | Metadata = None,
         data_connector: None | str | DataConnector | type[DataConnector] = None,
-        data_connectors_kwargs: None | dict[str, Any] = None,
+        data_connector_kwargs: None | dict[str, Any] = None,
         raw_data_loaders_kwargs: None | dict[str, Any] = None,
-        processored_data_loaders_kwargs: None | dict[str, Any] = None,
+        processed_data_loaders_kwargs: None | dict[str, Any] = None,
         data_processors: None | list[str | DataProcessor | type[DataProcessor]] = None,
         data_processors_kwargs: None | dict[str, dict[str, Any]] = None,
     ) -> "Synthesizer":
         """
         Load metadata and model, allow rebuilding Synthesizer for finetuning or other use cases.
 
+        We need ``model`` because not every model supports saving and loading via *pickle*.
+
         Args:
             load_dir (str | Path): The directory to load the model.
             model (str | type[SynthesizerModel]): The name of the model or the model itself. Type of model must be :class:`~sdgx.models.base.SynthesizerModel`.
             metadata (Metadata, optional): The metadata to use. Defaults to None.
             data_connector (DataConnector | type[DataConnector] | str, optional): The data connector to use.
Defaults to None. When data_connector is a string, it must be registered in :class:`~sdgx.data_connectors.manager.DataConnectorManager`. - data_connectors_kwargs (dict[str, Any], optional): The keyword arguments for data connectors. Defaults to None. + data_connector_kwargs (dict[str, Any], optional): The keyword arguments for data connectors. Defaults to None. raw_data_loaders_kwargs (dict[str, Any], optional): The keyword arguments for raw data loaders. Defaults to None. - processored_data_loaders_kwargs (dict[str, Any], optional): The keyword arguments for processed data loaders. Defaults to None. + processed_data_loaders_kwargs (dict[str, Any], optional): The keyword arguments for processed data loaders. Defaults to None. data_processors (list[str | DataProcessor | type[DataProcessor]], optional): The data processors to use. Defaults to None. When data_processor is a string, it must be registered in :class:`~sdgx.data_processors.manager.DataProcessorManager`. data_processors_kwargs (dict[str, dict[str, Any]], optional): The keyword arguments for data processors. Defaults to None. @@ -207,6 +210,7 @@ def load( """ load_dir = Path(load_dir).expanduser().resolve() + logger.info(f"Loading synthesizer from {load_dir}") if not load_dir.exists(): raise SynthesizerInitError(f"{load_dir.as_posix()} does not exist") @@ -226,9 +230,9 @@ def load( metadata=metadata, metadata_path=metadata_path, data_connector=data_connector, - data_connectors_kwargs=data_connectors_kwargs, + data_connector_kwargs=data_connector_kwargs, raw_data_loaders_kwargs=raw_data_loaders_kwargs, - processored_data_loaders_kwargs=processored_data_loaders_kwargs, + processed_data_loaders_kwargs=processed_data_loaders_kwargs, data_processors=data_processors, data_processors_kwargs=data_processors_kwargs, ) @@ -295,7 +299,7 @@ def chunk_generator() -> Generator[pd.DataFrame, None, None]: start_time = time.time() processed_dataloader = DataLoader( GeneratorConnector(chunk_generator), - **self.processored_data_loaders_kwargs, + **self.processed_data_loaders_kwargs, ) logger.info(f"Initialized processed data loader in {time.time() - start_time}s") try: @@ -309,7 +313,7 @@ def sample( count: int, chunksize: None | int = None, metadata: None | Metadata = None, - model_fit_kwargs: None | dict[str, Any] = None, + model_sample_args: None | dict[str, Any] = None, ) -> pd.DataFrame | Generator[pd.DataFrame, None, None]: """ Sample data from the synthesizer. @@ -319,7 +323,7 @@ def sample( chunksize (int, optional): The chunksize to use. Defaults to None. If is not None, the data will be sampled in chunks. And will return a generator that yields chunks of samples. metadata (Metadata, optional): The metadata to use. Defaults to None. If None, will use the metadata in fit first. - model_fit_kwargs (dict[str, Any], optional): The keyword arguments for model.fit. Defaults to None. + model_sample_args (dict[str, Any], optional): The keyword arguments for model.sample. Defaults to None. Returns: pd.DataFrame | typing.Generator[pd.DataFrame, None, None]: The sampled data. When chunksize is not None, it will be a generator. 
@@ -330,22 +334,25 @@ def sample( if metadata: for d in self.data_processors: d.fit(metadata) - if not model_fit_kwargs: - model_fit_kwargs = {} + if not model_sample_args: + model_sample_args = {} if chunksize is None: - return self._sample_once(count, model_fit_kwargs) + return self._sample_once(count, model_sample_args) + + if chunksize > count: + raise SynthesizerSampleError("chunksize must be less than or equal to count") def generator_sample_caller(): sample_times = count // chunksize for _ in range(sample_times): - sample_data = self._sample_once(chunksize, model_fit_kwargs) + sample_data = self._sample_once(chunksize, model_sample_args) for d in self.data_processors: sample_data = d.reverse_convert(sample_data) yield sample_data if count % chunksize > 0: - sample_data = self._sample_once(count % chunksize, model_fit_kwargs) + sample_data = self._sample_once(count % chunksize, model_sample_args) for d in self.data_processors: sample_data = d.reverse_convert(sample_data) yield sample_data @@ -353,7 +360,7 @@ def generator_sample_caller(): return generator_sample_caller() def _sample_once( - self, count: int, model_fit_kwargs: None | dict[str, Any] = None + self, count: int, model_sample_args: None | dict[str, Any] = None ) -> pd.DataFrame: """ Sample data once. @@ -370,7 +377,7 @@ def _sample_once( max_trails = 5 sample_data_list = [] while missing_count > 0 and max_trails > 0: - sample_data = self.model.sample(int(missing_count * 1.2), **model_fit_kwargs) + sample_data = self.model.sample(int(missing_count * 1.2), **model_sample_args) for d in self.data_processors: sample_data = d.reverse_convert(sample_data) sample_data_list.append(sample_data) @@ -389,7 +396,8 @@ def cleanup(self): if self.dataloader: self.dataloader.finalize(clear_cache=True) # Release resources - del self.model + if hasattr(self, "model"): + del self.model def __del__(self): self.cleanup() diff --git a/tests/cli/test_cli.py b/tests/cli/test_cli.py new file mode 100644 index 00000000..1fe938d0 --- /dev/null +++ b/tests/cli/test_cli.py @@ -0,0 +1,97 @@ +import json + +import pytest +from click.testing import CliRunner + +from sdgx.cli.main import ( + fit, + list_cachers, + list_data_connectors, + list_data_exporters, + list_data_processors, + list_models, + sample, +) +from sdgx.cli.message import NormalMessage + + +@pytest.mark.parametrize("json_output", [True, False]) +@pytest.mark.parametrize( + "command", + [ + list_cachers, + list_data_connectors, + list_data_processors, + list_data_exporters, + list_models, + ], +) +def test_list_extension_api(command, json_output): + runner = CliRunner() + result = runner.invoke(command, ["--json_output", json_output]) + + assert result.exit_code == 0 + if json_output: + assert NormalMessage()._dumo_json() in result.output + assert NormalMessage()._dumo_json() == result.output.strip().split("\n")[-1] + else: + assert NormalMessage()._dumo_json() not in result.output + + +@pytest.mark.parametrize("model", ["CTGAN"]) +@pytest.mark.parametrize("json_output", [True, False]) +def test_fit_save_load_sample(model, demo_single_table_path, cacher_kwargs, json_output, tmp_path): + runner = CliRunner() + save_dir = tmp_path / f"unittest-{model}" + result = runner.invoke( + fit, + [ + "--save_dir", + save_dir, + "--model", + model, + "--model_kwargs", + json.dumps({"epochs": 1}), + "--data_connector", + "csvconnector", + "--data_connector_kwargs", + json.dumps({"path": demo_single_table_path}), + "--raw_data_loaders_kwargs", + json.dumps({"cacher_kwargs": cacher_kwargs}), + 
"--processed_data_loaders_kwargs", + json.dumps({"cacher_kwargs": cacher_kwargs}), + "--json_output", + json_output, + ], + ) + + assert result.exit_code == 0 + assert save_dir.exists() + assert len(list(save_dir.iterdir())) > 0 + + if json_output: + assert json.loads(result.output.strip().split("\n")[-1]) + + export_dst = tmp_path / "exported.csv" + result = runner.invoke( + sample, + [ + "--load_dir", + save_dir, + "--model", + model, + "--json_output", + json_output, + "--export_dst", + export_dst.as_posix(), + ], + ) + + assert result.exit_code == 0 + assert export_dst.exists() + if json_output: + assert json.loads(result.output.strip().split("\n")[-1]) + + +if __name__ == "__main__": + pytest.main(["-vv", "-s", __file__]) diff --git a/tests/cli/test_message.py b/tests/cli/test_message.py new file mode 100644 index 00000000..bdb5b46c --- /dev/null +++ b/tests/cli/test_message.py @@ -0,0 +1,41 @@ +import json + +import pytest + +from sdgx.cli.message import ExceptionMessage, NormalMessage +from sdgx.exceptions import SdgxError + + +@pytest.mark.parametrize("return_val", [0, "123", [1, 2, 3], {"a": 1, "b": 2}]) +def test_normal_message(return_val): + NormalMessage.from_return_val(return_val)._dumo_json == json.dumps( + { + "code": 0, + "msg": "Success", + "payload": return_val if isinstance(return_val, dict) else {"return_val": return_val}, + } + ) + + +def unknown_exception(): + raise Exception + + +def sdgx_exception(): + raise SdgxError + + +@pytest.mark.parametrize("exception_caller", [unknown_exception, sdgx_exception]) +def test_exception_message(exception_caller): + try: + exception_caller() + except Exception as e: + msg = ExceptionMessage.from_exception(e) + assert msg._dumo_json() + assert msg.code != 0 + assert msg.payload + assert "details" in msg.payload + + +if __name__ == "__main__": + pytest.main(["-vv", "-s", __file__]) diff --git a/tests/conftest.py b/tests/conftest.py index 734fa5a7..78321ebb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,9 @@ import os -from functools import partial os.environ["SDG_NDARRAY_CACHE_ROOT"] = "/tmp/sdgx/ndarray_cache" + import shutil +from functools import partial import pytest @@ -19,7 +20,7 @@ @pytest.fixture def demo_single_table_path(): - yield download_demo_data(DATA_DIR) + yield download_demo_data(DATA_DIR).as_posix() @pytest.fixture diff --git a/tests/dataloader/conftest.py b/tests/dataloader/conftest.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/dataloader/test_cacher.py b/tests/dataloader/test_cacher.py index 18511622..4d3ded57 100644 --- a/tests/dataloader/test_cacher.py +++ b/tests/dataloader/test_cacher.py @@ -1,7 +1,5 @@ from __future__ import annotations -import shutil -from pathlib import Path from typing import Generator import pandas as pd diff --git a/tests/manager/test_exporter.py b/tests/manager/test_exporter.py new file mode 100644 index 00000000..15dd6301 --- /dev/null +++ b/tests/manager/test_exporter.py @@ -0,0 +1,22 @@ +import pytest + +from sdgx.data_exporters.manager import DataExporterManager + + +@pytest.fixture +def manager(): + yield DataExporterManager() + + +@pytest.mark.parametrize( + "supportd_exporter", + [ + "CsvExporter", + ], +) +def test_manager(supportd_exporter, manager: DataExporterManager): + assert manager._normalize_name(supportd_exporter) in manager.registed_exporters + + +if __name__ == "__main__": + pytest.main(["-vv", "-s", __file__]) diff --git a/tests/models/test_copula.py b/tests/models/test_copula.py index da665354..4b313ffb 100644 --- 
a/tests/models/test_copula.py +++ b/tests/models/test_copula.py @@ -1,9 +1,11 @@ +from pathlib import Path + from sdgx.models.statistics.single_table.copula import GaussianCopulaSynthesizer from sdgx.utils import get_demo_single_table def test_gaussian_copula(demo_single_table_path): - demo_data, discrete_cols = get_demo_single_table(demo_single_table_path.parent) + demo_data, discrete_cols = get_demo_single_table(Path(demo_single_table_path).parent) model = GaussianCopulaSynthesizer(discrete_cols) model.fit(demo_data) diff --git a/tests/test_csv_exporter.py b/tests/test_csv_exporter.py new file mode 100644 index 00000000..bf7e74eb --- /dev/null +++ b/tests/test_csv_exporter.py @@ -0,0 +1,37 @@ +import pandas as pd +import pytest + +from sdgx.data_exporters.csv_exporter import CsvExporter + + +@pytest.fixture +def csv_exporter(): + yield CsvExporter() + + +@pytest.fixture +def export_dst(tmp_path): + filename = tmp_path / "csv-exported.csv" + filename.unlink(missing_ok=True) + yield filename + # filename.unlink(missing_ok=True) + + +def test_csv_exporter_df(csv_exporter: CsvExporter, export_dst): + df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + csv_exporter.write(export_dst, df) + pd.testing.assert_frame_equal(df, pd.read_csv(export_dst)) + + +def test_csv_exporter_generator(csv_exporter: CsvExporter, export_dst): + def generator(): + yield pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + yield pd.DataFrame({"a": [7, 8, 9], "b": [10, 11, 12]}) + + df_all = pd.concat(generator(), ignore_index=True) + csv_exporter.write(export_dst, generator()) + pd.testing.assert_frame_equal(df_all, pd.read_csv(export_dst)) + + +if __name__ == "__main__": + pytest.main(["-vv", "-s", __file__]) diff --git a/tests/test_synthesizer.py b/tests/test_synthesizer.py index b35ae9db..c6f1e5ab 100644 --- a/tests/test_synthesizer.py +++ b/tests/test_synthesizer.py @@ -61,7 +61,7 @@ def synthesizer(cacher_kwargs): data_connector=MockDataConnector(), raw_data_loaders_kwargs={"cacher_kwargs": cacher_kwargs}, data_processors=[MockDataProcessor()], - processored_data_loaders_kwargs={"cacher_kwargs": cacher_kwargs}, + processed_data_loaders_kwargs={"cacher_kwargs": cacher_kwargs}, metadata=Metadata(), )