-
Notifications
You must be signed in to change notification settings - Fork 541
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #47 from hitsz-ids/feature-cli-and-plugin-system
Adding CLI and Plugin system
- Loading branch information
Showing
14 changed files
with
317 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
__version__ = "0.1.0" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from __future__ import annotations | ||
|
||
from sdgx.models.base import BaseSynthesizerModel | ||
|
||
|
||
class MyOwnModel(BaseSynthesizerModel): | ||
... | ||
|
||
|
||
from sdgx.models.extension import hookimpl | ||
|
||
|
||
@hookimpl | ||
def register(manager): | ||
manager.register("DummyModel", MyOwnModel) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# Build with hatch, you can use any build tool you like. | ||
[build-system] | ||
requires = ["hatchling"] | ||
build-backend = "hatchling.build" | ||
|
||
[project] | ||
name = "sdgx-dummymodel" | ||
|
||
dependencies = ["sdgx"] | ||
dynamic = ["version"] | ||
|
||
# This is the entry point for the FilterManager to find the Filter. | ||
[project.entry-points."sdgx.model"] | ||
dummymodel = "dummymodel.model" | ||
|
||
[tool.hatch.version] | ||
path = "dummymodel/__init__.py" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
import json | ||
from pathlib import Path | ||
|
||
import click | ||
import pandas | ||
|
||
from sdgx.models.manager import ModelManager | ||
|
||
|
||
@click.command() | ||
@click.option( | ||
"--model", | ||
help="Name of model, use `sdgx list-models` to list all available models.", | ||
required=True, | ||
) | ||
@click.option( | ||
"--model_params", | ||
default="{}", | ||
help="[Json-string] Parameters for model.", | ||
) | ||
@click.option( | ||
"--input_path", | ||
help="Path of input data.", | ||
required=True, | ||
) | ||
@click.option( | ||
"--input_type", | ||
default="csv", | ||
help="Type of input data, will be used as `pandas.read_{input_type}`.", | ||
) | ||
@click.option( | ||
"--read_params", | ||
default="{}", | ||
help="[Json-string] Parameters for `pandas.read_{input_type}`.", | ||
) | ||
@click.option( | ||
"--fit_params", | ||
default="{}", | ||
help="[Json-string] Parameters for `model.fit`.", | ||
) | ||
@click.option( | ||
"--output_path", | ||
help="Path to save the model.", | ||
required=True, | ||
) | ||
def fit( | ||
model, | ||
model_params, | ||
input_path, | ||
input_type, | ||
read_params, | ||
fit_params, | ||
output_path, | ||
): | ||
model_params = json.loads(model_params) | ||
read_params = json.loads(read_params) | ||
fit_params = json.loads(fit_params) | ||
|
||
model = ModelManager().init_model(model, **model_params) | ||
input_method = getattr(pandas, f"read_{input_type}") | ||
if not input_method: | ||
raise NotImplementedError(f"Pandas not support read_{input_type}") | ||
df = input_method(input_path, **read_params) | ||
model.fit(df, **fit_params) | ||
|
||
Path(output_path).mkdir(parents=True, exist_ok=True) | ||
model.save(output_path) | ||
|
||
|
||
@click.command() | ||
def sample( | ||
model_path, | ||
output_path, | ||
write_type, | ||
write_param, | ||
): | ||
model = ModelManager.load(model_path) | ||
# TODO: Model not have `sample` in Base Class yet | ||
# sampled_data = model.sample() | ||
# write_method = getattr(sampled_data, f"to_{write_type}") | ||
# write_method(output_path, write_param) | ||
|
||
|
||
@click.command() | ||
def list_models(): | ||
for model_name, model_cls in ModelManager().registed_model.items(): | ||
print(f"{model_name} is registed as class: {model_cls}.") | ||
|
||
|
||
@click.group() | ||
def cli(): | ||
pass | ||
|
||
|
||
cli.add_command(fit) | ||
cli.add_command(sample) | ||
cli.add_command(list_models) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
import os | ||
|
||
USER_DEFINED_LOG_LEVEL = os.getenv("SDGX_LOG_LEVEL", "INFO") | ||
|
||
os.environ["LOGURU_LEVEL"] = USER_DEFINED_LOG_LEVEL | ||
|
||
from loguru import logger |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from . import multi_tables, single_table |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Any | ||
|
||
import pluggy | ||
|
||
project_name = "sdgx.model" | ||
hookimpl = pluggy.HookimplMarker(project_name) | ||
hookspec = pluggy.HookspecMarker(project_name) | ||
|
||
|
||
@hookspec | ||
def register(manager): | ||
""" | ||
For more information about this function, please check the ``ModelManager`` | ||
We provided an example package for you in {project_root}/example/extension/dummymodel. | ||
Example: | ||
.. code-block:: python | ||
class MyOwnModel(BaseSynthesizerModel): | ||
... | ||
from sdgx.models.extension import hookimpl | ||
@hookimpl | ||
def register(manager): | ||
manager.register("DummyModel", MyOwnModel) | ||
Config ``project.entry-points`` so that we can find it | ||
.. code-block:: toml | ||
[project.entry-points."sdgx.model"] | ||
< whatever-name > = "<package>.<path>.<to>.<model-file>" | ||
You can verify it by `sdgx list-models`. | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
from __future__ import annotations | ||
|
||
import glob | ||
import importlib | ||
from os.path import basename, dirname, isfile, join | ||
from typing import Any | ||
|
||
import pluggy | ||
|
||
from sdgx import models | ||
from sdgx.errors import ModelInitializationError, ModelNotFoundError, ModelRegisterError | ||
from sdgx.log import logger | ||
from sdgx.models import extension | ||
from sdgx.models.base import BaseSynthesizerModel | ||
from sdgx.models.extension import project_name as PROJECT_NAME | ||
from sdgx.utils.utils import Singleton | ||
|
||
|
||
class ModelManager(metaclass=Singleton): | ||
def __init__(self): | ||
self.pm = pluggy.PluginManager(PROJECT_NAME) | ||
self.pm.add_hookspecs(extension) | ||
self._registed_model: dict[str, type[BaseSynthesizerModel]] = {} | ||
|
||
# Load all | ||
self.pm.load_setuptools_entrypoints(PROJECT_NAME) | ||
|
||
# Load all local model | ||
self.load_all_local_model() | ||
|
||
@property | ||
def registed_model(self): | ||
# Lazy load when query registed_model | ||
if self._registed_model: | ||
return self._registed_model | ||
for f in self.pm.hook.register(manager=self): | ||
try: | ||
f() | ||
except Exception as e: | ||
logger.exception(ModelRegisterError(e)) | ||
continue | ||
return self._registed_model | ||
|
||
def load_all_local_model(self): | ||
# self.pm.register(sdgx/models/single_table/*) | ||
self._load_dir(models.single_table) | ||
# self.pm.register(sdgx/models/multi_tables/*) | ||
self._load_dir(models.multi_tables) | ||
|
||
def _load_dir(self, module): | ||
modules = glob.glob(join(dirname(module.__file__), "*.py")) | ||
sub_packages = ( | ||
basename(f)[:-3] for f in modules if isfile(f) and not f.endswith("__init__.py") | ||
) | ||
packages = (str(module.__package__) + "." + i for i in sub_packages) | ||
for p in packages: | ||
self.pm.register(importlib.import_module(p)) | ||
|
||
def _normalize_name(self, model_name: str) -> str: | ||
return model_name.strip().lower() | ||
|
||
def register(self, model_name, model_cls: type[BaseSynthesizerModel]): | ||
model_name = self._normalize_name(model_name) | ||
logger.info(f"Register for new model: {model_name}") | ||
self._registed_model[model_name] = model_cls | ||
|
||
def init_model(self, model_name, **kwargs: dict[str, Any]) -> BaseSynthesizerModel: | ||
model_name = self._normalize_name(model_name) | ||
if not model_name in self.registed_model: | ||
raise ModelNotFoundError | ||
try: | ||
return self.registed_model[model_name](**kwargs) | ||
except Exception as e: | ||
raise ModelInitializationError(e) | ||
|
||
@staticmethod | ||
def load(model_path) -> BaseSynthesizerModel: | ||
return BaseSynthesizerModel.load(model_path) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import pytest | ||
|
||
from sdgx.models.manager import ModelManager | ||
from sdgx.utils.io.csv_utils import get_demo_single_table | ||
|
||
|
||
@pytest.fixture | ||
def model_manager(): | ||
yield ModelManager() | ||
|
||
|
||
def test_model_manager(model_manager: ModelManager): | ||
assert "ctgan" in model_manager.registed_model | ||
|
||
model = model_manager.init_model("ctgan", epochs=1) | ||
demo_data, discrete_cols = get_demo_single_table() | ||
model.fit(demo_data, discrete_cols) | ||
|
||
# 生成合成数据 | ||
sampled_data = model.sample(1000) | ||
|
||
|
||
if __name__ == "__main__": | ||
pytest.main(["-vv", "-s", __file__]) |