Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge GaussianCopula Model #20

Merged
merged 3 commits into from
Oct 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/develop/single_table_GAN.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SDG 模型开发文档
# SDG 神经网络模型开发文档

## 为 SDG 开发可运行的单表模型模块

Expand Down
1 change: 1 addition & 0 deletions docs/develop/single_table_GaussianCopula.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# SDG 统计学模型开发文档
19 changes: 19 additions & 0 deletions example/2_guassian_copula_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# 运行该例子,可使用:
# ipython -i example/2_guassian_copula_example.py
# 并查看 sampled_data 变量

from sdgx.statistics.single_table.copula import GaussianCopulaSynthesizer
from sdgx.utils.io.csv_utils import *

# 针对 csv 格式的小规模数据
# 目前我们以 df 作为输入的数据的格式
demo_data, discrete_cols = get_demo_single_table()
# print(demo_data)
# print(discrete_cols)

model = GaussianCopulaSynthesizer(discrete_cols)
model.fit(demo_data)

# sampled
sampled_data = model.sample(10)
print(sampled_data)
9 changes: 9 additions & 0 deletions sdgx/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# TBD
# 主要用于存放 sdg 中特有的的报错信息

class SdgxError(Exception):
"""Base class for exceptions in this module."""
pass

class NonParametricError(Exception):
"""Exception to indicate that a model is not parametric."""
86 changes: 86 additions & 0 deletions sdgx/statistics/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from typing import List, Optional

import numpy as np
import torch


class BaseSynthesizerModel:
random_states = None

def __init__(self, transformer=None, sampler=None) -> None:
# 以下几个变量都需要在初始化 model 时进行更改
self.model = None # 存放模型
self.status = "UNFINED"
self.model_type = "MODEL_TYPE_UNDEFINED"
# self.epochs = epochs
self._device = "CPU"

def fit(self, input_df, discrete_cols: Optional[List] = None):
raise NotImplementedError

def set_device(self, device):
"""Set the `device` to be used ('GPU' or 'CPU')."""
self._device = device
if self._generator is not None:
self._generator.to(self._device)

def __getstate__(self):
device_backup = self._device
self.set_device(torch.device("cpu"))
state = self.__dict__.copy()
self.set_device(device_backup)
if (
isinstance(self.random_states, tuple)
and isinstance(self.random_states[0], np.random.RandomState)
and isinstance(self.random_states[1], torch.Generator)
):
state["_numpy_random_state"] = self.random_states[0].get_state()
state["_torch_random_state"] = self.random_states[1].get_state()
state.pop("random_states")
return state

def __setstate__(self, state):
if "_numpy_random_state" in state and "_torch_random_state" in state:
np_state = state.pop("_numpy_random_state")
torch_state = state.pop("_torch_random_state")
current_torch_state = torch.Generator()
current_torch_state.set_state(torch_state)
current_numpy_state = np.random.RandomState()
current_numpy_state.set_state(np_state)
state["random_states"] = (current_numpy_state, current_torch_state)
self.__dict__ = state
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self.set_device(device)

def save(self, path):
device_backup = self._device
self.set_device(torch.device("cpu"))
torch.save(self, path)
self.set_device(device_backup)

@classmethod
def load(cls, path):
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = torch.load(path)
model.set_device(device)
return model

def set_random_state(self, random_state):
if random_state is None:
self.random_states = random_state
elif isinstance(random_state, int):
self.random_states = (
np.random.RandomState(seed=random_state),
torch.Generator().manual_seed(random_state),
)
elif (
isinstance(random_state, tuple)
and isinstance(random_state[0], np.random.RandomState)
and isinstance(random_state[1], torch.Generator)
):
self.random_states = random_state
else:
raise TypeError(
f"`random_state` {random_state} expected to be an int or a tuple of "
"(`np.random.RandomState`, `torch.Generator`)"
)
Loading
Loading