Skip to content

Commit

Permalink
Merge pull request #11 from hitsz-ids/model_base_class
Browse files Browse the repository at this point in the history
Merge model base class
  • Loading branch information
MooooCat committed Aug 17, 2023
2 parents 32459a0 + c1da311 commit df0fe2e
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 161 deletions.
6 changes: 3 additions & 3 deletions example/1_ctgan_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# ipython -i example/1_ctgan_example.py
# 并查看 sampled_data 变量

from sdgx.models.single_table.ctgan import GeneratorCTGAN
from sdgx.models.single_table.ctgan import CTGAN
from sdgx.transform.sampler import DataSamplerCTGAN
from sdgx.transform.transformer import DataTransformerCTGAN
from sdgx.utils.io.csv_utils import *
Expand All @@ -12,8 +12,8 @@
demo_data, discrete_cols = get_demo_single_table()


model = GeneratorCTGAN(epochs=10, transformer=DataTransformerCTGAN, sampler=DataSamplerCTGAN)
model = CTGAN(epochs=10, transformer=DataTransformerCTGAN, sampler=DataSamplerCTGAN)
model.fit(demo_data, discrete_cols)

# sampled
sampled_data = model.generate(1000)
sampled_data = model.sample(1000)
85 changes: 65 additions & 20 deletions sdgx/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,35 +2,80 @@
import torch


class BaseGeneratorModel:
class BaseSynthesizerModel:
random_states = None

def __init__(self, transformer=None, sampler=None) -> None:
# 以下几个变量都需要在初始化 model 时进行更改
self.model = None # 存放模型
self.status = "UNFINED"
self.model_type = "MODEL_TYPE_UNDEFINED"
# self.epochs = epochs
self._device = "CPU"

# 目前使用CPU计算,后续扩展使用 GPU 及其 Proxy
self.device = "CPU"

# fit 模型
def fit(self):
    """Train the model.

    Placeholder only: this method must be overridden.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError
def set_device(self, device):
    """Set the device to be used and move the generator onto it.

    Args:
        device: Target device. It is stored as-is in ``self._device`` and
            passed to the generator's ``to()`` method when a generator
            exists.
    """
    self._device = device
    # ``_generator`` may not have been assigned yet, so look it up
    # defensively instead of raising ``AttributeError``.
    generator = getattr(self, "_generator", None)
    if generator is not None:
        generator.to(self._device)

def generate(self, n_rows=100):
    """Generate ``n_rows`` rows of data.

    Placeholder only: this method must be overridden.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError
def __getstate__(self):
device_backup = self._device
self.set_device(torch.device("cpu"))
state = self.__dict__.copy()
self.set_device(device_backup)
if (
isinstance(self.random_states, tuple)
and isinstance(self.random_states[0], np.random.RandomState)
and isinstance(self.random_states[1], torch.Generator)
):
state["_numpy_random_state"] = self.random_states[0].get_state()
state["_torch_random_state"] = self.random_states[1].get_state()
state.pop("random_states")
return state

def fit(self):
    """Train the model.

    Placeholder only: this method must be overridden.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError
def __setstate__(self, state):
if "_numpy_random_state" in state and "_torch_random_state" in state:
np_state = state.pop("_numpy_random_state")
torch_state = state.pop("_torch_random_state")
current_torch_state = torch.Generator()
current_torch_state.set_state(torch_state)
current_numpy_state = np.random.RandomState()
current_numpy_state.set_state(np_state)
state["random_states"] = (current_numpy_state, current_torch_state)
self.__dict__ = state
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self.set_device(device)

def load_from_disk(self, model_path=""):
    """Load a model from ``model_path``.

    Placeholder only: this method must be overridden.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError
def save(self, path):
    """Save the model to ``path`` using ``torch.save``.

    The model is moved to the ``cpu`` device for serialization and is moved
    back to its previous device afterwards, even when saving fails.

    Args:
        path: Destination passed unchanged to ``torch.save``.
    """
    device_backup = self._device
    try:
        self.set_device(torch.device("cpu"))
        torch.save(self, path)
    finally:
        # Restore the original device even if serialization raised, so a
        # failed save does not leave the model stranded on the CPU.
        self.set_device(device_backup)

def dump_to_disk(self, output_path=""):
    """Write the model to ``output_path``.

    Placeholder only: this base implementation always raises.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError
@classmethod
def load(cls, path):
    """Load the model stored at ``path`` and place it on the current device.

    The device is ``cuda:0`` when CUDA is available, otherwise ``cpu``.

    Args:
        path: Source passed unchanged to ``torch.load``.

    Returns:
        The object returned by ``torch.load``, after ``set_device`` has been
        called on it.
    """
    if torch.cuda.is_available():
        target = torch.device("cuda:0")
    else:
        target = torch.device("cpu")
    # NOTE(review): ``torch.load`` unpickles the file; only load files from
    # trusted sources.
    restored = torch.load(path)
    restored.set_device(target)
    return restored

pass
def set_random_state(self, random_state):
    """Set ``random_states`` from a seed, a generator pair, or ``None``.

    Args:
        random_state (int, tuple, or None):
            ``None`` clears the random states; an int seeds a new
            ``np.random.RandomState`` and a new ``torch.Generator``; a tuple
            of ``(np.random.RandomState, torch.Generator)`` is stored as-is.

    Raises:
        TypeError: If ``random_state`` is none of the accepted forms.
    """
    if random_state is None:
        self.random_states = None
        return

    if isinstance(random_state, int):
        numpy_rng = np.random.RandomState(seed=random_state)
        torch_rng = torch.Generator().manual_seed(random_state)
        self.random_states = (numpy_rng, torch_rng)
        return

    is_generator_pair = (
        isinstance(random_state, tuple)
        and isinstance(random_state[0], np.random.RandomState)
        and isinstance(random_state[1], torch.Generator)
    )
    if is_generator_pair:
        self.random_states = random_state
        return

    raise TypeError(
        f"`random_state` {random_state} expected to be an int or a tuple of "
        "(`np.random.RandomState`, `torch.Generator`)"
    )
144 changes: 6 additions & 138 deletions sdgx/models/single_table/ctgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,47 +15,17 @@
functional,
)

# base
from sdgx.models.base import BaseGeneratorModel
# base 类已拆分,挪到 base.py
from sdgx.models.base import BaseSynthesizerModel

# transformer 以及 sampler 已经拆分,单独挪到了 transform/ 目录中
# transformer 以及 sampler 已经拆分,挪到 transform/ 目录中
from sdgx.transform.sampler import DataSamplerCTGAN
from sdgx.transform.transformer import DataTransformerCTGAN

# 一些辅助的函数
# 其他函数
from sdgx.utils.utils import random_state


class GeneratorCTGAN(BaseGeneratorModel):
    """Wrapper exposing a ``CTGAN`` model through ``fit`` and ``generate``."""

    def __init__(self, epochs, transformer=None, sampler=None) -> None:
        """Create the wrapped ``CTGAN`` model.

        Args:
            epochs: Number of epochs, forwarded to ``CTGAN``.
            transformer: Transformer forwarded to ``CTGAN``.
            sampler: Sampler forwarded to ``CTGAN``.
        """
        # NOTE: the base-class ``__init__`` is not called here; the call was
        # commented out in the original code and that behavior is kept.

        # CTGAN parameter that has to be defined up front.
        self.epochs = epochs

        # The transformer and sampler are split out as separate components
        # and can be customized by the caller.
        self.model = CTGAN(
            epochs=self.epochs,
            transformer=transformer,
            sampler=sampler,
        )
        self.model_type = "CTGAN"
        self.status = "ready"

    def fit(self, input_df, discrete_cols=None):
        """Train the wrapped model.

        Args:
            input_df: Training data, forwarded to the model's ``fit``.
            discrete_cols: Discrete-column specification forwarded to the
                model's ``fit``. Defaults to an empty list.
        """
        # Avoid a shared mutable default: build a fresh list on every call.
        if discrete_cols is None:
            discrete_cols = []
        self.model.fit(input_df, discrete_cols)

    def generate(self, n_rows=100):
        """Return ``n_rows`` rows sampled from the wrapped model.

        Args:
            n_rows: Number of rows, forwarded to the model's ``sample``.

        Returns:
            Whatever the wrapped model's ``sample`` returns.
        """
        return self.model.sample(n_rows)


class Discriminator(Module):
"""Discriminator for the CTGAN."""

Expand Down Expand Up @@ -138,104 +108,8 @@ def forward(self, input_):
return data


# Imported from ctgan; to be adapted later as needed.
class BaseSynthesizer:
    """Base class for all default synthesizers of ``CTGAN``.

    The helpers below use ``self._device`` and ``self.set_device``, neither
    of which is defined in this class; subclasses are expected to provide
    them.
    """

    random_states = None

    def __getstate__(self):
        """Improve pickling state for ``BaseSynthesizer``.

        Convert to ``cpu`` device before starting the pickling process in
        order to be able to load the model even when used from an external
        tool such as ``SDV``. Also, if ``random_states`` are set, store the
        states returned by ``get_state()`` rather than the generators
        themselves.

        Returns:
            dict:
                Python dict representing the object.
        """
        device_backup = self._device
        self.set_device(torch.device("cpu"))
        state = self.__dict__.copy()
        self.set_device(device_backup)
        if (
            isinstance(self.random_states, tuple)
            and isinstance(self.random_states[0], np.random.RandomState)
            and isinstance(self.random_states[1], torch.Generator)
        ):
            state["_numpy_random_state"] = self.random_states[0].get_state()
            state["_torch_random_state"] = self.random_states[1].get_state()
            state.pop("random_states")

        return state

    def __setstate__(self, state):
        """Restore the state of a ``BaseSynthesizer``.

        Restore the ``random_states`` from the state dict if those are
        present and then set the device according to the current hardware.

        Args:
            state (dict):
                State dict; modified in place and installed as ``__dict__``.
        """
        if "_numpy_random_state" in state and "_torch_random_state" in state:
            np_state = state.pop("_numpy_random_state")
            torch_state = state.pop("_torch_random_state")

            current_torch_state = torch.Generator()
            current_torch_state.set_state(torch_state)

            current_numpy_state = np.random.RandomState()
            current_numpy_state.set_state(np_state)
            state["random_states"] = (current_numpy_state, current_torch_state)

        self.__dict__ = state
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.set_device(device)

    def save(self, path):
        """Save the model in the passed `path`.

        The model is moved to ``cpu`` for serialization and moved back to
        its previous device afterwards, even when saving fails.
        """
        device_backup = self._device
        try:
            self.set_device(torch.device("cpu"))
            torch.save(self, path)
        finally:
            # Restore the device even if ``torch.save`` raised, so a failed
            # save does not leave the model on the CPU.
            self.set_device(device_backup)

    @classmethod
    def load(cls, path):
        """Load the model stored in the passed `path`.

        The loaded model is moved to ``cuda:0`` when CUDA is available,
        otherwise to ``cpu``.
        """
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        # NOTE(review): ``torch.load`` unpickles the file; only load files
        # from trusted sources.
        model = torch.load(path)
        model.set_device(device)
        return model

    def set_random_state(self, random_state):
        """Set the random state.

        Args:
            random_state (int, tuple, or None):
                Either a tuple containing the (numpy.random.RandomState,
                torch.Generator) or an int representing the random seed to
                use for both random states.

        Raises:
            TypeError: If ``random_state`` has any other form.
        """
        if random_state is None:
            self.random_states = random_state
        elif isinstance(random_state, int):
            self.random_states = (
                np.random.RandomState(seed=random_state),
                torch.Generator().manual_seed(random_state),
            )
        elif (
            isinstance(random_state, tuple)
            and isinstance(random_state[0], np.random.RandomState)
            and isinstance(random_state[1], torch.Generator)
        ):
            self.random_states = random_state
        else:
            raise TypeError(
                f"`random_state` {random_state} expected to be an int or a tuple of "
                "(`np.random.RandomState`, `torch.Generator`)"
            )


# CTGAN model
class CTGAN(BaseSynthesizer):
# 后续需要根据实际情况做性能优化
class CTGAN(BaseSynthesizerModel):
"""Conditional Table GAN Synthesizer.
This is the core class of the CTGAN project, where the different components
Expand Down Expand Up @@ -643,9 +517,3 @@ def sample(self, n, condition_column=None, condition_value=None):
data = data[:n]

return self._transformer.inverse_transform(data)

def set_device(self, device):
    """Set the `device` to be used ('GPU' or 'CPU').

    The value is stored in ``self._device``; if a generator exists, it is
    moved to that device through its ``to()`` method.
    """
    self._device = device
    generator = self._generator
    if generator is None:
        return
    generator.to(self._device)

0 comments on commit df0fe2e

Please sign in to comment.