Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge model base class #11

Merged
merged 3 commits into from
Aug 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions example/1_ctgan_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# ipython -i example/1_ctgan_example.py
# 并查看 sampled_data 变量

from sdgx.models.single_table.ctgan import GeneratorCTGAN
from sdgx.models.single_table.ctgan import CTGAN
from sdgx.transform.sampler import DataSamplerCTGAN
from sdgx.transform.transformer import DataTransformerCTGAN
from sdgx.utils.io.csv_utils import *
Expand All @@ -12,8 +12,8 @@
demo_data, discrete_cols = get_demo_single_table()


model = GeneratorCTGAN(epochs=10, transformer=DataTransformerCTGAN, sampler=DataSamplerCTGAN)
model = CTGAN(epochs=10, transformer=DataTransformerCTGAN, sampler=DataSamplerCTGAN)
model.fit(demo_data, discrete_cols)

# sampled
sampled_data = model.generate(1000)
sampled_data = model.sample(1000)
85 changes: 65 additions & 20 deletions sdgx/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,35 +2,80 @@
import torch


class BaseGeneratorModel:
class BaseSynthesizerModel:
random_states = None

def __init__(self, transformer=None, sampler=None) -> None:
# 以下几个变量都需要在初始化 model 时进行更改
self.model = None # 存放模型
self.status = "UNFINED"
self.model_type = "MODEL_TYPE_UNDEFINED"
# self.epochs = epochs
self._device = "CPU"

# 目前使用CPU计算,后续扩展使用 GPU 及其 Proxy
self.device = "CPU"

# fit 模型
def fit(self):
# 需要覆写该方法
raise NotImplementedError
def set_device(self, device):
"""Set the `device` to be used ('GPU' or 'CPU')."""
self._device = device
if self._generator is not None:
self._generator.to(self._device)

def generate(self, n_rows=100):
# 需要覆写该方法
raise NotImplementedError
def __getstate__(self):
device_backup = self._device
self.set_device(torch.device("cpu"))
state = self.__dict__.copy()
self.set_device(device_backup)
if (
isinstance(self.random_states, tuple)
and isinstance(self.random_states[0], np.random.RandomState)
and isinstance(self.random_states[1], torch.Generator)
):
state["_numpy_random_state"] = self.random_states[0].get_state()
state["_torch_random_state"] = self.random_states[1].get_state()
state.pop("random_states")
return state

def fit(self):
# 需要覆写该方法
raise NotImplementedError
def __setstate__(self, state):
if "_numpy_random_state" in state and "_torch_random_state" in state:
np_state = state.pop("_numpy_random_state")
torch_state = state.pop("_torch_random_state")
current_torch_state = torch.Generator()
current_torch_state.set_state(torch_state)
current_numpy_state = np.random.RandomState()
current_numpy_state.set_state(np_state)
state["random_states"] = (current_numpy_state, current_torch_state)
self.__dict__ = state
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self.set_device(device)

def load_from_disk(self, model_path=""):
# 需要覆写该方法
raise NotImplementedError
def save(self, path):
device_backup = self._device
self.set_device(torch.device("cpu"))
torch.save(self, path)
self.set_device(device_backup)

def dump_to_disk(self, output_path=""):
raise NotImplementedError
@classmethod
def load(cls, path):
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = torch.load(path)
model.set_device(device)
return model

pass
def set_random_state(self, random_state):
if random_state is None:
self.random_states = random_state
elif isinstance(random_state, int):
self.random_states = (
np.random.RandomState(seed=random_state),
torch.Generator().manual_seed(random_state),
)
elif (
isinstance(random_state, tuple)
and isinstance(random_state[0], np.random.RandomState)
and isinstance(random_state[1], torch.Generator)
):
self.random_states = random_state
else:
raise TypeError(
f"`random_state` {random_state} expected to be an int or a tuple of "
"(`np.random.RandomState`, `torch.Generator`)"
)
144 changes: 6 additions & 138 deletions sdgx/models/single_table/ctgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,47 +15,17 @@
functional,
)

# base
from sdgx.models.base import BaseGeneratorModel
# base 类已拆分,挪到 base.py
from sdgx.models.base import BaseSynthesizerModel

# transformer 以及 sampler 已经拆分,单独挪到了 transform/ 目录中
# transformer 以及 sampler 已经拆分,挪到 transform/ 目录中
from sdgx.transform.sampler import DataSamplerCTGAN
from sdgx.transform.transformer import DataTransformerCTGAN

# 一些辅助的函数
# 其他函数
from sdgx.utils.utils import random_state


class GeneratorCTGAN(BaseGeneratorModel):
def __init__(self, epochs, transformer=None, sampler=None) -> None:
# super().__init__()

# ctgan 参数,需要预先定义
self.epochs = epochs

# 模型相关的其他参数
self.model = CTGAN(
epochs=self.epochs, # 以下为拆分的 transformer 与 sampler
# 本组件可以自定义这两个内容
transformer=transformer,
sampler=sampler,
)
self.model_type = "CTGAN"
self.status = "ready"

def fit(self, input_df, discrete_cols=[]):
# 模型训练
self.model.fit(input_df, discrete_cols)
return

def generate(self, n_rows=100):
# 使用模型 generate 数据
generated_data = self.model.sample(n_rows)
return generated_data

pass


class Discriminator(Module):
"""Discriminator for the CTGAN."""

Expand Down Expand Up @@ -138,104 +108,8 @@ def forward(self, input_):
return data


# 从 ctgan中引入,后续根据时
class BaseSynthesizer:
"""Base class for all default synthesizers of ``CTGAN``."""

random_states = None

def __getstate__(self):
"""Improve pickling state for ``BaseSynthesizer``.

Convert to ``cpu`` device before starting the pickling process in order to be able to
load the model even when used from an external tool such as ``SDV``. Also, if
``random_states`` are set, store their states as dictionaries rather than generators.

Returns:
dict:
Python dict representing the object.
"""
device_backup = self._device
self.set_device(torch.device("cpu"))
state = self.__dict__.copy()
self.set_device(device_backup)
if (
isinstance(self.random_states, tuple)
and isinstance(self.random_states[0], np.random.RandomState)
and isinstance(self.random_states[1], torch.Generator)
):
state["_numpy_random_state"] = self.random_states[0].get_state()
state["_torch_random_state"] = self.random_states[1].get_state()
state.pop("random_states")

return state

def __setstate__(self, state):
"""Restore the state of a ``BaseSynthesizer``.

Restore the ``random_states`` from the state dict if those are present and then
set the device according to the current hardware.
"""
if "_numpy_random_state" in state and "_torch_random_state" in state:
np_state = state.pop("_numpy_random_state")
torch_state = state.pop("_torch_random_state")

current_torch_state = torch.Generator()
current_torch_state.set_state(torch_state)

current_numpy_state = np.random.RandomState()
current_numpy_state.set_state(np_state)
state["random_states"] = (current_numpy_state, current_torch_state)

self.__dict__ = state
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self.set_device(device)

def save(self, path):
"""Save the model in the passed `path`."""
device_backup = self._device
self.set_device(torch.device("cpu"))
torch.save(self, path)
self.set_device(device_backup)

@classmethod
def load(cls, path):
"""Load the model stored in the passed `path`."""
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = torch.load(path)
model.set_device(device)
return model

def set_random_state(self, random_state):
"""Set the random state.

Args:
random_state (int, tuple, or None):
Either a tuple containing the (numpy.random.RandomState, torch.Generator)
or an int representing the random seed to use for both random states.
"""
if random_state is None:
self.random_states = random_state
elif isinstance(random_state, int):
self.random_states = (
np.random.RandomState(seed=random_state),
torch.Generator().manual_seed(random_state),
)
elif (
isinstance(random_state, tuple)
and isinstance(random_state[0], np.random.RandomState)
and isinstance(random_state[1], torch.Generator)
):
self.random_states = random_state
else:
raise TypeError(
f"`random_state` {random_state} expected to be an int or a tuple of "
"(`np.random.RandomState`, `torch.Generator`)"
)


# CTGAN model
class CTGAN(BaseSynthesizer):
# 后续需要根据实际情况做性能优化
class CTGAN(BaseSynthesizerModel):
"""Conditional Table GAN Synthesizer.

This is the core class of the CTGAN project, where the different components
Expand Down Expand Up @@ -643,9 +517,3 @@ def sample(self, n, condition_column=None, condition_value=None):
data = data[:n]

return self._transformer.inverse_transform(data)

def set_device(self, device):
"""Set the `device` to be used ('GPU' or 'CPU)."""
self._device = device
if self._generator is not None:
self._generator.to(self._device)