Skip to content

Commit

Permalink
Merge pull request #47 from hitsz-ids/feature-cli-and-plugin-system
Browse files Browse the repository at this point in the history
Adding CLI and Plugin system
  • Loading branch information
MooooCat committed Dec 4, 2023
2 parents 8c63460 + b0e69d6 commit d6a109e
Show file tree
Hide file tree
Showing 14 changed files with 317 additions and 1 deletion.
1 change: 1 addition & 0 deletions example/extension/dummymodel/dummymodel/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
__version__ = "0.1.0"
15 changes: 15 additions & 0 deletions example/extension/dummymodel/dummymodel/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from __future__ import annotations

from sdgx.models.base import BaseSynthesizerModel


class MyOwnModel(BaseSynthesizerModel):
...


from sdgx.models.extension import hookimpl


@hookimpl
def register(manager):
manager.register("DummyModel", MyOwnModel)
17 changes: 17 additions & 0 deletions example/extension/dummymodel/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Build with hatch, you can use any build tool you like.
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[project]
name = "sdgx-dummymodel"

dependencies = ["sdgx"]
dynamic = ["version"]

# This is the entry point for the FilterManager to find the Filter.
[project.entry-points."sdgx.model"]
dummymodel = "dummymodel.model"

[tool.hatch.version]
path = "dummymodel/__init__.py"
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ dependencies = [
"seaborn",
"table-evaluator",
"copulas",
"click",
"pluggy",
"loguru",
]
dynamic = ["version"]
classifiers = [
Expand All @@ -36,6 +39,8 @@ classifiers = [
test = ["pytest", "pytest-cov"]
docs = ["Sphinx<=7.2.4", "sphinx-rtd-theme", "sphinx-click", "autodoc_pydantic"]

[project.scripts]
sdgx = "sdgx.cli.main:cli"

[[project.authors]]
name = "hitsz-ids"
Expand Down
Empty file added sdgx/cli/__init__.py
Empty file.
97 changes: 97 additions & 0 deletions sdgx/cli/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import json
from pathlib import Path

import click
import pandas

from sdgx.models.manager import ModelManager


@click.command()
@click.option(
"--model",
help="Name of model, use `sdgx list-models` to list all available models.",
required=True,
)
@click.option(
"--model_params",
default="{}",
help="[Json-string] Parameters for model.",
)
@click.option(
"--input_path",
help="Path of input data.",
required=True,
)
@click.option(
"--input_type",
default="csv",
help="Type of input data, will be used as `pandas.read_{input_type}`.",
)
@click.option(
"--read_params",
default="{}",
help="[Json-string] Parameters for `pandas.read_{input_type}`.",
)
@click.option(
"--fit_params",
default="{}",
help="[Json-string] Parameters for `model.fit`.",
)
@click.option(
"--output_path",
help="Path to save the model.",
required=True,
)
def fit(
model,
model_params,
input_path,
input_type,
read_params,
fit_params,
output_path,
):
model_params = json.loads(model_params)
read_params = json.loads(read_params)
fit_params = json.loads(fit_params)

model = ModelManager().init_model(model, **model_params)
input_method = getattr(pandas, f"read_{input_type}")
if not input_method:
raise NotImplementedError(f"Pandas not support read_{input_type}")
df = input_method(input_path, **read_params)
model.fit(df, **fit_params)

Path(output_path).mkdir(parents=True, exist_ok=True)
model.save(output_path)


@click.command()
def sample(
model_path,
output_path,
write_type,
write_param,
):
model = ModelManager.load(model_path)
# TODO: Model not have `sample` in Base Class yet
# sampled_data = model.sample()
# write_method = getattr(sampled_data, f"to_{write_type}")
# write_method(output_path, write_param)


@click.command()
def list_models():
for model_name, model_cls in ModelManager().registed_model.items():
print(f"{model_name} is registed as class: {model_cls}.")


@click.group()
def cli():
pass


cli.add_command(fit)
cli.add_command(sample)
cli.add_command(list_models)
12 changes: 12 additions & 0 deletions sdgx/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,15 @@ class SdgxError(Exception):

class NonParametricError(Exception):
"""Exception to indicate that a model is not parametric."""


class ModelNotFoundError(SdgxError):
pass


class ModelRegisterError(SdgxError):
pass


class ModelInitializationError(SdgxError):
pass
7 changes: 7 additions & 0 deletions sdgx/log.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import os

USER_DEFINED_LOG_LEVEL = os.getenv("SDGX_LOG_LEVEL", "INFO")

os.environ["LOGURU_LEVEL"] = USER_DEFINED_LOG_LEVEL

from loguru import logger
1 change: 1 addition & 0 deletions sdgx/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from . import multi_tables, single_table
40 changes: 40 additions & 0 deletions sdgx/models/extension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from __future__ import annotations

from typing import Any

import pluggy

project_name = "sdgx.model"
hookimpl = pluggy.HookimplMarker(project_name)
hookspec = pluggy.HookspecMarker(project_name)


@hookspec
def register(manager):
"""
For more information about this function, please check the ``ModelManager``
We provided an example package for you in {project_root}/example/extension/dummymodel.
Example:
.. code-block:: python
class MyOwnModel(BaseSynthesizerModel):
...
from sdgx.models.extension import hookimpl
@hookimpl
def register(manager):
manager.register("DummyModel", MyOwnModel)
Config ``project.entry-points`` so that we can find it
.. code-block:: toml
[project.entry-points."sdgx.model"]
< whatever-name > = "<package>.<path>.<to>.<model-file>"
You can verify it by `sdgx list-models`.
"""
78 changes: 78 additions & 0 deletions sdgx/models/manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from __future__ import annotations

import glob
import importlib
from os.path import basename, dirname, isfile, join
from typing import Any

import pluggy

from sdgx import models
from sdgx.errors import ModelInitializationError, ModelNotFoundError, ModelRegisterError
from sdgx.log import logger
from sdgx.models import extension
from sdgx.models.base import BaseSynthesizerModel
from sdgx.models.extension import project_name as PROJECT_NAME
from sdgx.utils.utils import Singleton


class ModelManager(metaclass=Singleton):
def __init__(self):
self.pm = pluggy.PluginManager(PROJECT_NAME)
self.pm.add_hookspecs(extension)
self._registed_model: dict[str, type[BaseSynthesizerModel]] = {}

# Load all
self.pm.load_setuptools_entrypoints(PROJECT_NAME)

# Load all local model
self.load_all_local_model()

@property
def registed_model(self):
# Lazy load when query registed_model
if self._registed_model:
return self._registed_model
for f in self.pm.hook.register(manager=self):
try:
f()
except Exception as e:
logger.exception(ModelRegisterError(e))
continue
return self._registed_model

def load_all_local_model(self):
# self.pm.register(sdgx/models/single_table/*)
self._load_dir(models.single_table)
# self.pm.register(sdgx/models/multi_tables/*)
self._load_dir(models.multi_tables)

def _load_dir(self, module):
modules = glob.glob(join(dirname(module.__file__), "*.py"))
sub_packages = (
basename(f)[:-3] for f in modules if isfile(f) and not f.endswith("__init__.py")
)
packages = (str(module.__package__) + "." + i for i in sub_packages)
for p in packages:
self.pm.register(importlib.import_module(p))

def _normalize_name(self, model_name: str) -> str:
return model_name.strip().lower()

def register(self, model_name, model_cls: type[BaseSynthesizerModel]):
model_name = self._normalize_name(model_name)
logger.info(f"Register for new model: {model_name}")
self._registed_model[model_name] = model_cls

def init_model(self, model_name, **kwargs: dict[str, Any]) -> BaseSynthesizerModel:
model_name = self._normalize_name(model_name)
if not model_name in self.registed_model:
raise ModelNotFoundError
try:
return self.registed_model[model_name](**kwargs)
except Exception as e:
raise ModelInitializationError(e)

@staticmethod
def load(model_path) -> BaseSynthesizerModel:
return BaseSynthesizerModel.load(model_path)
8 changes: 7 additions & 1 deletion sdgx/models/single_table/ctgan.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from typing import List, Optional
from typing import Any, List, Optional

import numpy as np
import pandas as pd
Expand All @@ -18,6 +18,7 @@

# base 类已拆分,挪到 base.py
from sdgx.models.base import BaseSynthesizerModel
from sdgx.models.extension import hookimpl

# transformer 以及 sampler 已经拆分,挪到 transform/ 目录中
from sdgx.transform.sampler import DataSamplerCTGAN
Expand Down Expand Up @@ -577,3 +578,8 @@ def sample(self, n, condition_column=None, condition_value=None):
data = data[:n]

return self._transformer.inverse_transform(data)


@hookimpl
def register(manager):
manager.register("CTGAN", CTGAN)
13 changes: 13 additions & 0 deletions sdgx/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import contextlib
import logging
import threading

import numpy as np
import torch
Expand Down Expand Up @@ -241,3 +242,15 @@ def validate_numerical_distributions(numerical_distributions, metadata_columns):
# f'{invalid_columns}. The column names you provide must be present '
# 'in the metadata.'
# )


class Singleton(type):
_instances = {}
_lock = threading.Lock()

def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
with cls._lock:
if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
return cls._instances[cls]
24 changes: 24 additions & 0 deletions tests/test_model_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import pytest

from sdgx.models.manager import ModelManager
from sdgx.utils.io.csv_utils import get_demo_single_table


@pytest.fixture
def model_manager():
yield ModelManager()


def test_model_manager(model_manager: ModelManager):
assert "ctgan" in model_manager.registed_model

model = model_manager.init_model("ctgan", epochs=1)
demo_data, discrete_cols = get_demo_single_table()
model.fit(demo_data, discrete_cols)

# 生成合成数据
sampled_data = model.sample(1000)


if __name__ == "__main__":
pytest.main(["-vv", "-s", __file__])

0 comments on commit d6a109e

Please sign in to comment.