Skip to content

Commit

Permalink
Merge pull request #38 from hitsz-ids/bugfix-requirements
Browse files Browse the repository at this point in the history
Update `requirements.txt`
  • Loading branch information
MooooCat committed Oct 31, 2023
2 parents ad066b3 + 1edddab commit 6e3abb1
Show file tree
Hide file tree
Showing 17 changed files with 263 additions and 251 deletions.
7 changes: 3 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@

Synthetic Data Generator (SDG) is a framework focused on quickly generating high-quality structured tabular data. It supports more than 10 single-table and multi-table data synthesis algorithms, achieving up to 120 times performance improvement, and supports differential privacy and other methods to enhance the security of synthesized data.

Synthetic data is generated by machines based on real data and algorithms, it does not contain sensitive information, but can retain the characteristics of real data.
There is no correspondence between synthetic data and real data, and it is not subject to privacy regulations such as GDPR and ADPPA.
In practical applications, there is no need to worry about the risk of privacy leakage.
Synthetic data is generated by machines based on real data and algorithms; it does not contain sensitive information, but can retain the characteristics of real data.
There is no correspondence between synthetic data and real data, and it is not subject to privacy regulations such as GDPR and ADPPA.
In practical applications, there is no need to worry about the risk of privacy leakage.
High-quality synthetic data can also be used in various fields such as data opening, model training and debugging, system development and testing, etc.


Expand Down Expand Up @@ -138,4 +138,3 @@ The SDG project was initiated by **Institute of Data Security, Harbin Institute
## 📄 License

The SDG open source project uses the Apache-2.0 license; please refer to the [LICENSE](https://github.com/hitsz-ids/synthetic-data-generator/blob/main/LICENSE).

2 changes: 1 addition & 1 deletion docs/develop/single_table_GaussianCopula.md
Original file line number Diff line number Diff line change
@@ -1 +1 @@
# SDG 统计学模型开发文档
# SDG 统计学模型开发文档
1 change: 0 additions & 1 deletion docs/readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -160,4 +160,3 @@ sampled = model.generate(num_rows=10)
[10 rows x 10 columns]}}
```

1 change: 1 addition & 0 deletions example/1_ctgan_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# 并查看 sampled_data 变量

from sdgx.models.single_table.ctgan import CTGAN

# from sdgx.transform.sampler import DataSamplerCTGAN
# from sdgx.transform.transformer import DataTransformerCTGAN
from sdgx.utils.io.csv_utils import *
Expand Down
23 changes: 12 additions & 11 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
# using python 3.9
setproctitle==1.2.3
PyMySQL==1.0.2
pandas>=1.2.4
copulas==0.9.1
dython==0.5.1
joblib>=1.2.0
numpy>=1.21.5
scikit-learn==1.2.2
torch==1.12.0
torchvision==0.13.0
pandas>=1.2.4
PyMySQL>=1.0.2
pytest>=7.4.3
rdt==1.6.0
joblib==1.2.0
dython==0.5.1
scikit-learn==1.2.2
seaborn==0.11.1
table-evaluator==1.4.2
copulas==0.9.1
# using python 3.8 3.9, 3.10
setproctitle>=1.2.3
table-evaluator>=1.4.2
torch>=2.1.0
torchvision
9 changes: 6 additions & 3 deletions sdgx/errors.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
# TBD
# TBD
# 主要用于存放 sdg 中特有的的报错信息


class SdgxError(Exception):
"""Base class for exceptions in this module."""
pass

pass


class NonParametricError(Exception):
"""Exception to indicate that a model is not parametric."""
"""Exception to indicate that a model is not parametric."""
36 changes: 17 additions & 19 deletions sdgx/models/single_table/ctgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,8 @@ class CTGAN(BaseSynthesizerModel):
设计上这个参数需要是 str 类型或者 bool 类型 可以是如下数值:
- True:根据实际情况进行自动化内存优化,保证在能跑完的情况下,尽量多地利用内存,同时不报错
- False:不采用任何内存优化策略
- None:同 False
该参数默认设为 False
- None:同 False
该参数默认设为 False
"""

def __init__(
Expand All @@ -180,7 +180,7 @@ def __init__(
epochs=300,
pac=10,
cuda=True,
memory_optimize = False
memory_optimize=False,
):
assert batch_size % 2 == 0

Expand Down Expand Up @@ -216,7 +216,6 @@ def __init__(
# - False:不采用任何内存优化策略
self.memory_optimize = memory_optimize


@staticmethod
def _gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
"""Deals with the instability of the gumbel_softmax for older versions of torch.
Expand Down Expand Up @@ -315,14 +314,11 @@ def _validate_discrete_columns(self, train_data, discrete_columns):

if invalid_columns:
raise ValueError(f"Invalid columns found: {invalid_columns}")



# OPTIMIZE 新增方法,完成在内存受限情况下的 CTGAN 训练
@random_state
def fit_optimze(self, train_data_iterator,\
discrete_columns: Optional[List] = [],\
epoches = 10):
''' OPTIMIZE 新增方法,在内存受限情况下完成 CTGAN 训练
@random_state
def fit_optimze(self, train_data_iterator, discrete_columns: Optional[List] = [], epoches=10):
"""OPTIMIZE 新增方法,在内存受限情况下完成 CTGAN 训练
参数列表:
train_data_iterator (iterator):
Expand All @@ -337,18 +333,20 @@ def fit_optimze(self, train_data_iterator,\
epoches(int-like):
CTGAN 模型的迭代次数,
选填参数,默认为10,但严肃应用中应必须填写。
'''
"""
if epochs is None:
epochs = self._epochs


pass

@random_state
def fit(self, train_data,
discrete_columns: Optional[List] = None,\
train_data_iterator = None,
epochs=None):
def fit(
self,
train_data,
discrete_columns: Optional[List] = None,
train_data_iterator=None,
epochs=None,
):
"""Fit the CTGAN Synthesizer models to the training data.
Args:
Expand All @@ -367,13 +365,13 @@ def fit(self, train_data,
if not discrete_columns:
discrete_columns = []

# OPTIMIZE 检查 optimize 以及 train_data_iterator
# OPTIMIZE 检查 optimize 以及 train_data_iterator
if self.memory_optimize and train_data_iterator is None:
raise ValueError("train_data_iterator should not be None.")
# 如果符合 optimize 的需求,则转到新实现的 optimize 方法,这样也不干扰老方法的顺利执行
if self.memory_optimize and train_data_iterator is not None:
self.fit_optimze(train_data_iterator, discrete_columns, epochs)
return
return
# OPTIMIZE 改动结束

# 以下为原始的 fit 方法
Expand Down
Loading

0 comments on commit 6e3abb1

Please sign in to comment.