Skip to content

Commit

Permalink
Merge pull request #48 from hitsz-ids/feature-Data_Processor
Browse files Browse the repository at this point in the history
Update SDG's New Data Processor Structure
  • Loading branch information
MooooCat committed Dec 5, 2023
2 parents c212879 + c7ab510 commit 86f7710
Show file tree
Hide file tree
Showing 19 changed files with 66 additions and 27 deletions.
4 changes: 2 additions & 2 deletions docs/develop/single_table_GAN.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@
functional,
)
from sdgx.models.base import BaseSynthesizerModel
from sdgx.transform.sampler import DataSamplerCTGAN
from sdgx.transform.transformer import DataTransformerCTGAN
from sdgx.data_process.sampling.sampler import DataSamplerCTGAN
from sdgx.data_process.transform.transform import DataTransformer
```

- 完成您的模块中的 `__init__` 函数,并定义相应的类变量,以CTGAN为例:
Expand Down
6 changes: 3 additions & 3 deletions docs/readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
```python
# 导入相关模块
from sdgx.models.single_table.ctgan import CTGAN
from sdgx.transform.sampler import DataSamplerCTGAN
from sdgx.transform.transformer import DataTransformerCTGAN
from sdgx.data_process.sampling.sampler import DataSamplerCTGAN
from sdgx.data_process.transform.transform import DataTransformer
from sdgx.utils.io.csv_utils import *

# 读取数据
Expand Down Expand Up @@ -41,7 +41,7 @@ demo_data, discrete_cols = get_demo_single_table()
```python
#定义模型
model = GeneratorCTGAN(epochs=10,\
transformer= DataTransformerCTGAN,\
transformer= DataTransformer,\
sampler=DataSamplerCTGAN)

#训练模型
Expand Down
4 changes: 2 additions & 2 deletions example/1_ctgan_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

from sdgx.models.single_table.ctgan import CTGAN

# from sdgx.transform.sampler import DataSamplerCTGAN
# from sdgx.transform.transformer import DataTransformerCTGAN
# from sdgx.data_process.sampling.sampler import DataSamplerCTGAN
# from sdgx.data_process.transform.transform import DataTransformer
from sdgx.utils.io.csv_utils import *

# 针对 csv 格式的小规模数据
Expand Down
File renamed without changes.
5 changes: 5 additions & 0 deletions sdgx/data_process/custom_logic/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# 设计上有参考sdv,但我们想做的更进一步
# 主要描述以下几类信息:
# - 列与列之间的限制关系
# - 列与数值之间的关系
# - 列与其他规则之间的关系,例如:我们将首先支持正则表达式
18 changes: 18 additions & 0 deletions sdgx/data_process/formatter/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Formatter: 列格式转换工具,基本描述如下:
# - 针对不同类型的列,实现解析能力,例如:DataTime 搞成时间戳形式;
# - 针对不同类型的列,提供格式上的转换能力
# - 输入和输出均为【列】数据
#
# 同时也在此说明与 transform 的区别:
# - 涉及到【单列】作为输入的,涉及【格式转换】问题,使用 formatter
# - 涉及到【整张表】作为输入的进行转换,使用 data transformer
# - 通常,在 Data Transformer 的实现中,针对列的情况,调用不同的 formatter
# - 提供 extract 方法
#


class BaseFormatter(object):
# def extract_xxx(self):
# pass

pass
7 changes: 7 additions & 0 deletions sdgx/data_process/metadata/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# 顾名思义,Metadata 用于记录表的元数据信息,在第一阶段,主要描述如下:
# - 我们会参考 sdv 中的 metadata 管理方法,但不会照搬;
# - 我们则主要提供表元数据的描述;
# - 我们提供表于表之间元数据的描述;
# - 我们提供一些必要的问题检测,例如:DAG检测;
# - 我们提供足够的人工接口,用于人工修改和配置元数据;
# - 未来会提供更多实用功能
File renamed without changes.
Empty file.
Empty file.
6 changes: 6 additions & 0 deletions sdgx/data_process/pii_generator/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# PII Generator 模块专门用于针对 PII 类型的列进行 【生成】
# 随机生成是一种简单粗暴且有效的方法,但我们会更进一步
# 该模块主要负责:
# - 针对不同类型的 PII 对象(列),提供针对列的批量生成方法
# - 针对不同类型的 PII 对象,提供随机化的生成方法
# - 以地域、归属地等限制条件为输入,生成 PII 对象
4 changes: 4 additions & 0 deletions sdgx/data_process/sampling/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Sampler 模块主要用于针对下列情况:
# - 大规模数据库时候的情况;
# - 其他必要的情况,包括:csv, xls 等;
# - 个别模型模型所需的 sampler ;
File renamed without changes.
Empty file.
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
"""
DataTransformer 模块:
目前使用了CTGAN开源项目中的代码
后续还会根据实际业务需求进一步进行改写
以及进行一些性能优化
DataTransform 模块:
将把该模块列入 Data Process 中
"""

from collections import namedtuple
Expand All @@ -21,7 +18,7 @@
)


class DataTransformerCTGAN(object):
class DataTransformer(object):
"""Data Transformer.
Model continuous columns with a BayesianGMM and normalized to a scalar [0, 1] and a vector.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
class DataTransformer(object):
"""OPTIMIZE SDG 重写的 Data Transformer
应对大数据(数据 > 内存)情况下的 Transformer 解决方案
应对大数据(数据 > 内存)情况下的 Transformer 解决方案(试行)
- 对于连续列:使用 BayesianGMM 对连续列建模并标准化为标量 [0, 1] 和向量。
- 对于离散列:使用 scikit-learn OneHotEncoder 进行编码。
Expand Down
10 changes: 5 additions & 5 deletions sdgx/models/single_table/ctgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@
functional,
)

# transformer 以及 sampler 已经拆分,挪到 transform/ 目录中
from sdgx.data_process.sampling.sampler import DataSamplerCTGAN
from sdgx.data_process.transform.transform import DataTransformer

# base 类已拆分,挪到 base.py
from sdgx.models.base import BaseSynthesizerModel
from sdgx.models.extension import hookimpl

# transformer 以及 sampler 已经拆分,挪到 transform/ 目录中
from sdgx.transform.sampler import DataSamplerCTGAN
from sdgx.transform.transformer import DataTransformerCTGAN

# 其他函数
from sdgx.utils.utils import random_state

Expand Down Expand Up @@ -391,7 +391,7 @@ def fit(
)

# 载入 transformer
self._transformer = DataTransformerCTGAN()
self._transformer = DataTransformer()
self._transformer.fit(train_data, discrete_columns)

# 使用 transformer 处理数据
Expand Down
9 changes: 4 additions & 5 deletions sdgx/statistics/single_table/copula.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,11 @@
from copulas import multivariate
from rdt.transformers import OneHotEncoder

# transformer 以及 sampler 已经拆分,挪到 transform/ 目录中
# from sdgx.data_process.sampling.sampler import DataSamplerCTGAN
from sdgx.data_process.transform.transform import DataTransformer
from sdgx.errors import NonParametricError
from sdgx.statistics.base import BaseSynthesizerModel

# transformer 以及 sampler 已经拆分,挪到 transform/ 目录中
# from sdgx.transform.sampler import DataSamplerCTGAN
from sdgx.transform.transformer import DataTransformerCTGAN
from sdgx.utils.utils import (
flatten_dict,
log_numerical_distributions_error,
Expand Down Expand Up @@ -149,7 +148,7 @@ def __init__(
# 2. 增加
def fit(self, processed_data):
# 载入 transformer
self._transformer = DataTransformerCTGAN()
self._transformer = DataTransformer()
self._transformer.fit(processed_data, self.metadata[0])

# 使用 transformer 处理数据
Expand Down
9 changes: 6 additions & 3 deletions tests/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,20 @@
_HERE = os.path.dirname(__file__)
sys.path.append(os.getcwd())

from sdgx.transform.transformer import DataTransformerCTGAN
from sdgx.transform.transformer_opt import DataTransformer
from sdgx.data_process.transform.transform import DataTransformer

# from sdgx.data_process.transform.transformer_opt import DataTransformer
from sdgx.utils.io.csv_utils import *


def test_transformer_original():
demo_data, discrete_cols = get_demo_single_table()
ctgan_transformer = DataTransformerCTGAN()
ctgan_transformer = DataTransformer()
ctgan_transformer.fit(demo_data, discrete_cols)
transformed_data = ctgan_transformer.transform(demo_data)


"""
def test_transformer_opt():
# 测试经过内存优化之后的 transformer
demo_data_path = "./dataset/adult.csv"
Expand All @@ -40,6 +42,7 @@ def test_transformer_opt():
shutil.os.remove("inverse_tmp.csv")
shutil.os.remove("output_tmp.csv")
pass
"""


if __name__ == "__main__":
Expand Down

0 comments on commit 86f7710

Please sign in to comment.