Skip to content

Commit

Permalink
Merge pull request #20 from hitsz-ids/feature-Gaussian_Copula
Browse files Browse the repository at this point in the history
Merge GaussianCopula Model
  • Loading branch information
joeyscave committed Oct 24, 2023
2 parents 99e2d37 + c8108b9 commit 2c711ee
Show file tree
Hide file tree
Showing 9 changed files with 1,021 additions and 1 deletion.
2 changes: 1 addition & 1 deletion docs/develop/single_table_GAN.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SDG 模型开发文档
# SDG 神经网络模型开发文档

## 为 SDG 开发可运行的单表模型模块

Expand Down
1 change: 1 addition & 0 deletions docs/develop/single_table_GaussianCopula.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# SDG 统计学模型开发文档
19 changes: 19 additions & 0 deletions example/2_guassian_copula_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# 运行该例子,可使用:
# ipython -i example/2_guassian_copula_example.py
# 并查看 sampled_data 变量

from sdgx.statistics.single_table.copula import GaussianCopulaSynthesizer
from sdgx.utils.io.csv_utils import *

# 针对 csv 格式的小规模数据
# 目前我们以 df 作为输入的数据的格式
demo_data, discrete_cols = get_demo_single_table()
# print(demo_data)
# print(discrete_cols)

model = GaussianCopulaSynthesizer(discrete_cols)
model.fit(demo_data)

# sampled
sampled_data = model.sample(10)
print(sampled_data)
9 changes: 9 additions & 0 deletions sdgx/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# TBD
# 主要用于存放 sdg 中特有的的报错信息

class SdgxError(Exception):
"""Base class for exceptions in this module."""
pass

class NonParametricError(Exception):
"""Exception to indicate that a model is not parametric."""
86 changes: 86 additions & 0 deletions sdgx/statistics/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from typing import List, Optional

import numpy as np
import torch


class BaseSynthesizerModel:
random_states = None

def __init__(self, transformer=None, sampler=None) -> None:
# 以下几个变量都需要在初始化 model 时进行更改
self.model = None # 存放模型
self.status = "UNFINED"
self.model_type = "MODEL_TYPE_UNDEFINED"
# self.epochs = epochs
self._device = "CPU"

def fit(self, input_df, discrete_cols: Optional[List] = None):
raise NotImplementedError

def set_device(self, device):
"""Set the `device` to be used ('GPU' or 'CPU')."""
self._device = device
if self._generator is not None:
self._generator.to(self._device)

def __getstate__(self):
device_backup = self._device
self.set_device(torch.device("cpu"))
state = self.__dict__.copy()
self.set_device(device_backup)
if (
isinstance(self.random_states, tuple)
and isinstance(self.random_states[0], np.random.RandomState)
and isinstance(self.random_states[1], torch.Generator)
):
state["_numpy_random_state"] = self.random_states[0].get_state()
state["_torch_random_state"] = self.random_states[1].get_state()
state.pop("random_states")
return state

def __setstate__(self, state):
if "_numpy_random_state" in state and "_torch_random_state" in state:
np_state = state.pop("_numpy_random_state")
torch_state = state.pop("_torch_random_state")
current_torch_state = torch.Generator()
current_torch_state.set_state(torch_state)
current_numpy_state = np.random.RandomState()
current_numpy_state.set_state(np_state)
state["random_states"] = (current_numpy_state, current_torch_state)
self.__dict__ = state
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self.set_device(device)

def save(self, path):
device_backup = self._device
self.set_device(torch.device("cpu"))
torch.save(self, path)
self.set_device(device_backup)

@classmethod
def load(cls, path):
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = torch.load(path)
model.set_device(device)
return model

def set_random_state(self, random_state):
if random_state is None:
self.random_states = random_state
elif isinstance(random_state, int):
self.random_states = (
np.random.RandomState(seed=random_state),
torch.Generator().manual_seed(random_state),
)
elif (
isinstance(random_state, tuple)
and isinstance(random_state[0], np.random.RandomState)
and isinstance(random_state[1], torch.Generator)
):
self.random_states = random_state
else:
raise TypeError(
f"`random_state` {random_state} expected to be an int or a tuple of "
"(`np.random.RandomState`, `torch.Generator`)"
)
Loading

0 comments on commit 2c711ee

Please sign in to comment.