Merge pull request #35 from hitsz-ids/optimize-CTGAN
Optimize the Transformer module
MooooCat committed Oct 30, 2023
2 parents 2c711ee + fbaca68 commit a1620d1
Showing 9 changed files with 568 additions and 26 deletions.
7 changes: 3 additions & 4 deletions example/1_ctgan_example.py
@@ -3,16 +3,15 @@
# and inspect the sampled_data variable

from sdgx.models.single_table.ctgan import CTGAN
from sdgx.transform.sampler import DataSamplerCTGAN
from sdgx.transform.transformer import DataTransformerCTGAN
# from sdgx.transform.sampler import DataSamplerCTGAN
# from sdgx.transform.transformer import DataTransformerCTGAN
from sdgx.utils.io.csv_utils import *

# For small-scale data in csv format
# Currently a DataFrame (df) is used as the input data format
demo_data, discrete_cols = get_demo_single_table()


model = CTGAN(epochs=10, transformer=DataTransformerCTGAN, sampler=DataSamplerCTGAN)
model = CTGAN(epochs=10)
model.fit(demo_data, discrete_cols)

# sampled
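The example diff is truncated here. The remainder presumably fills the sampled_data variable mentioned at the top of the file; a minimal sketch of that step, assuming the sdgx CTGAN model keeps the upstream CTGAN sample(n) interface (an assumption, not something shown in this diff):

# Hypothetical continuation of the example (not part of this commit); assumes
# the sdgx CTGAN model exposes the upstream CTGAN sample(n) method.
sampled_data = model.sample(1000)  # generate 1000 synthetic rows
print(sampled_data.head())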
4 changes: 2 additions & 2 deletions requirements.txt.old
@@ -1,8 +1,8 @@
# using python 3.9
setproctitle==1.2.3
PyMySQL==1.0.2
pandas==1.2.4
numpy==1.21.5
pandas>=1.2.4
numpy>=1.21.5
scikit-learn==1.2.2
torch==1.12.0
torchvision==0.13.0
72 changes: 63 additions & 9 deletions sdgx/models/single_table/ctgan.py
@@ -109,7 +109,8 @@ def forward(self, input_):
return data


# Performance optimization will be needed later based on actual usage
# Memory usage is currently optimized for larger datasets
# The overall idea is to load the data into memory in batches to reduce the memory footprint
class CTGAN(BaseSynthesizerModel):
"""Conditional Table GAN Synthesizer.
@@ -155,6 +156,12 @@ class CTGAN(BaseSynthesizerModel):
Whether to attempt to use cuda for GPU computation.
If this is False or CUDA is not available, CPU will be used.
Defaults to ``True``.
memory_optimize(bool):
By design this parameter may be a str or a bool, with the following values:
- True: apply automatic memory optimization, using as much memory as possible while still finishing training without errors
- False: do not apply any memory optimization strategy
- None: same as False
Defaults to ``False``.
"""

def __init__(
@@ -173,8 +180,7 @@ def __init__(
epochs=300,
pac=10,
cuda=True,
transformer=None,
sampler=None,
memory_optimize=False
):
assert batch_size % 2 == 0

@@ -202,12 +208,15 @@ def __init__(
device = "cuda"
self._device = torch.device(device)

# self._transformer = None
# self._data_sampler = None
self._transformer = transformer
self._data_sampler = sampler
self._generator = None

# Whether to enable memory optimization; the initial plan is to enforce a memory limit
# By design this parameter may be a str or a bool, with the following values:
# - True: apply automatic memory optimization, using as much memory as possible while still finishing without errors
# - False: do not apply any memory optimization strategy
self.memory_optimize = memory_optimize


@staticmethod
def _gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
"""Deals with the instability of the gumbel_softmax for older versions of torch.
@@ -306,23 +315,68 @@ def _validate_discrete_columns(self, train_data, discrete_columns):

if invalid_columns:
raise ValueError(f"Invalid columns found: {invalid_columns}")


# OPTIMIZE: new method that trains CTGAN under memory constraints
@random_state
def fit_optimze(self, train_data_iterator,
discrete_columns: Optional[List] = [],
epochs=None):
''' OPTIMIZE: new method that trains CTGAN when memory is constrained.
Args:
train_data_iterator (iterator):
An iterator over the training data that reads the csv file and yields the training data batch by batch.
Required.
discrete_columns (list-like):
A python list describing the discrete columns.
If the training data iterator carries no column names, the list should contain the column indices;
otherwise it should contain the column names (in any order).
Optional; defaults to an empty python list.
epochs (int):
Number of training epochs for the CTGAN model.
Optional; defaults to None, in which case ``self._epochs`` is used, but it should be set explicitly in serious applications.
'''
if epochs is None:
epochs = self._epochs


pass
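The body of fit_optimze is still a stub in this commit. For reference, the kind of train_data_iterator it expects can be produced with pandas' chunked csv reading; a minimal sketch under that assumption (the path and chunk size are placeholders):

import pandas as pd

def csv_chunk_iterator(csv_path, chunk_size=10_000):
    # Yield the training data chunk by chunk instead of loading the whole
    # csv into memory at once, matching the batched-loading idea above.
    for chunk in pd.read_csv(csv_path, chunksize=chunk_size):
        yield chunk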

@random_state
def fit(self, train_data, discrete_columns: Optional[List] = None, epochs=None):
def fit(self, train_data,
discrete_columns: Optional[List] = None,
train_data_iterator=None,
epochs=None):
"""Fit the CTGAN Synthesizer models to the training data.
Args:
train_data (numpy.ndarray or pandas.DataFrame):
Training Data. It must be a 2-dimensional numpy array or a pandas.DataFrame.
train_data_iterator (iterator):
An iterator over the training data that reads the csv file and yields the training data batch by batch.
discrete_columns (list-like):
List of discrete columns to be used to generate the Conditional
Vector. If ``train_data`` is a Numpy array, this list should
contain the integer indices of the columns. Otherwise, if it is
a ``pandas.DataFrame``, this list should contain the column names.
"""
# set discrete_columns
if not discrete_columns:
discrete_columns = []
# discrete column check

# OPTIMIZE: check memory_optimize and train_data_iterator
if self.memory_optimize and train_data_iterator is None:
raise ValueError("train_data_iterator must be provided when memory_optimize is enabled.")
# If the optimize path applies, delegate to the newly implemented optimize method so the original method is left undisturbed
if self.memory_optimize and train_data_iterator is not None:
self.fit_optimze(train_data_iterator, discrete_columns, epochs)
return
# OPTIMIZE: end of changes

# The original fit method follows
self._validate_discrete_columns(train_data, discrete_columns)

# parameter check
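Taken together, the new fit signature routes training to fit_optimze whenever memory_optimize is enabled. A minimal usage sketch of that path, assuming a chunked csv iterator such as the one sketched above (the file path, chunk size, and column name are placeholders, and the optimized training loop itself is still a stub in this commit):

import pandas as pd
from sdgx.models.single_table.ctgan import CTGAN

model = CTGAN(epochs=10, memory_optimize=True)
chunks = pd.read_csv("demo_table.csv", chunksize=10_000)
# With memory_optimize=True, fit delegates to fit_optimze and consumes the iterator;
# train_data itself is not used on this path.
model.fit(None, discrete_columns=["gender"], train_data_iterator=chunks, epochs=10)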
2 changes: 2 additions & 0 deletions sdgx/transform/transformer.py
@@ -282,3 +282,5 @@ def convert_column_name_value_to_id(self, column_name, value):
"column_id": column_id,
"value_id": np.argmax(one_hot),
}

