[0.1.0] Intro NDArryLoader #75

Merged 6 commits on Dec 19, 2023
Changes from 4 commits
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -254,3 +254,4 @@ cython_debug/
*.log

.sdgx_cache
.ndarry_cache
101 changes: 60 additions & 41 deletions README.md
@@ -70,61 +70,80 @@ pip install git+https://github.com/hitsz-ids/synthetic-data-generator.git
pip install sdgx
```

### Quick Demo of Single Table Data Generation
### Quick Demo of Single Table Data Generation and Metric

#### Demo code

```python
# Import modules
from sdgx.models.single_table.ctgan import CTGAN
from sdgx.utils.io.csv_utils import *
from sdgx.data_connectors.csv_connector import CsvConnector
from sdgx.models.ml.single_table.ctgan import CTGANSynthesizerModel
from sdgx.synthesizer import Synthesizer
from sdgx.utils import download_demo_data

# Read data from demo
demo_data, discrete_cols = get_demo_single_table()
```
dataset_csv = download_demo_data()
data_connector = CsvConnector(path=dataset_csv)
synthesizer = Synthesizer(
model=CTGANSynthesizerModel(epochs=1), # For quick demo
data_connector=data_connector,
)

Real data are as follows:
synthesizer.fit()
sampled_data = synthesizer.sample(1000)
synthesizer.cleanup() # Clean all cache

```
age workclass fnlwgt ... hours-per-week native-country class
0 27 Private 177119 ... 44 United-States <=50K
1 27 Private 216481 ... 40 United-States <=50K
2 25 Private 256263 ... 40 United-States <=50K
3 46 Private 147640 ... 40 United-States <=50K
4 45 Private 172822 ... 76 United-States >50K
... ... ... ... ... ... ... ...
32556 43 Local-gov 33331 ... 40 United-States >50K
32557 44 Private 98466 ... 35 United-States <=50K
32558 23 Private 45317 ... 40 United-States <=50K
32559 45 Local-gov 215862 ... 45 United-States >50K
32560 25 Private 186925 ... 48 United-States <=50K

[32561 rows x 15 columns]
# Optional: use JSD as the metric
from sdgx.metrics.column.jsd import JSD

JSD = JSD()

selected_columns = ["workclass"]
isDiscrete = True
metrics = JSD.calculate(data_connector.read(), sampled_data, selected_columns, isDiscrete)

print("JSD metric of column %s: %g" % (selected_columns[0], metrics))
```
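
For readers unfamiliar with the metric, the Jensen-Shannon divergence of a discrete column can be sketched with the standard library alone. This illustrates the general formula, not the internals of sdgx's `JSD` class, which are not shown here; the helper names below are hypothetical, and missing values are assumed to have been replaced by a sentinel string beforehand.

```python
import math
from collections import Counter


def discrete_distribution(values, categories):
    """Return the relative frequency of each category in `values`."""
    counts = Counter(values)
    total = len(values)
    return [counts.get(c, 0) / total for c in categories]


def kl_divergence(p, q):
    """Kullback-Leibler divergence in bits; terms with p_i == 0 contribute 0."""
    return sum(pi * math.log2(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def jensen_shannon_divergence(real_values, synthetic_values):
    """Jensen-Shannon divergence (base 2) between two discrete samples."""
    categories = list(set(real_values) | set(synthetic_values))
    p = discrete_distribution(real_values, categories)
    q = discrete_distribution(synthetic_values, categories)
    # The mixture m is strictly positive wherever p or q is, so no division by zero.
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)
```

With base-2 logarithms the value lies between 0 (identical distributions) and 1 (disjoint supports), so a lower score means the synthetic column is closer to the real one.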

#### Comparison

Real data are as follows:

```python
# Define model
model = CTGAN(epochs=10)
# Model training
model.fit(demo_data, discrete_cols)
>>> data_connector.read()
age workclass fnlwgt education ... capitalloss hoursperweek native-country class
0 2 State-gov 77516 Bachelors ... 0 2 United-States <=50K
1 3 Self-emp-not-inc 83311 Bachelors ... 0 0 United-States <=50K
2 2 Private 215646 HS-grad ... 0 2 United-States <=50K
3 3 Private 234721 11th ... 0 2 United-States <=50K
4 1 Private 338409 Bachelors ... 0 2 Cuba <=50K
... ... ... ... ... ... ... ... ... ...
48837 2 Private 215419 Bachelors ... 0 2 United-States <=50K
48838 4 NaN 321403 HS-grad ... 0 2 United-States <=50K
48839 2 Private 374983 Bachelors ... 0 3 United-States <=50K
48840 2 Private 83891 Bachelors ... 0 2 United-States <=50K
48841 1 Self-emp-inc 182148 Bachelors ... 0 3 United-States >50K

[48842 rows x 15 columns]

# Generate synthetic data
sampled_data = model.generate(1000)
```

Synthetic data are as follows:

```
age workclass fnlwgt ... hours-per-week native-country class
0 33 Private 276389 ... 41 United-States >50K
1 33 Self-emp-not-inc 296948 ... 54 United-States <=50K
2 67 Without-pay 266913 ... 51 Columbia <=50K
3 49 Private 423018 ... 41 United-States >50K
4 22 Private 295325 ... 39 United-States >50K
5 63 Private 234140 ... 65 United-States <=50K
6 42 Private 243623 ... 52 United-States <=50K
7 75 Private 247679 ... 41 United-States <=50K
8 79 Private 332237 ... 41 United-States >50K
9 28 State-gov 837932 ... 99 United-States <=50K
```python
>>> sampled_data
age workclass fnlwgt education ... capitalloss hoursperweek native-country class
0 1 NaN 28219 Some-college ... 0 2 Puerto-Rico <=50K
1 2 Private 250166 HS-grad ... 0 2 United-States >50K
2 2 Private 50304 HS-grad ... 0 2 United-States <=50K
3 4 Private 89318 Bachelors ... 0 2 Puerto-Rico >50K
4 1 Private 172149 Bachelors ... 0 3 United-States <=50K
.. ... ... ... ... ... ... ... ... ...
995 2 NaN 208938 Bachelors ... 0 1 United-States <=50K
996 2 Private 166416 Bachelors ... 2 2 United-States <=50K
997 2 NaN 336022 HS-grad ... 0 1 United-States <=50K
998 3 Private 198051 Masters ... 0 2 United-States >50K
999 1 NaN 41973 HS-grad ... 0 2 United-States <=50K

[1000 rows x 15 columns]
```

## 🤝 Join Community
101 changes: 60 additions & 41 deletions README_ZH_CN.md
@@ -71,59 +71,78 @@ pip install sdgx

### Quick Demo of Single Table Data Synthesis

#### Demo code

```python
# Import the relevant modules
from sdgx.models.single_table.ctgan import CTGAN
from sdgx.utils.io.csv_utils import *
from sdgx.data_connectors.csv_connector import CsvConnector
from sdgx.models.ml.single_table.ctgan import CTGANSynthesizerModel
from sdgx.synthesizer import Synthesizer
from sdgx.utils import download_demo_data

# Read the data
demo_data, discrete_cols = get_demo_single_table()
```
dataset_csv = download_demo_data()
data_connector = CsvConnector(path=dataset_csv)
synthesizer = Synthesizer(
model=CTGANSynthesizerModel(epochs=1), # For quick demo
data_connector=data_connector,
)

Real data example:
synthesizer.fit()
sampled_data = synthesizer.sample(1000)
synthesizer.cleanup() # Clean all cache

```
age workclass fnlwgt ... hours-per-week native-country class
0 27 Private 177119 ... 44 United-States <=50K
1 27 Private 216481 ... 40 United-States <=50K
2 25 Private 256263 ... 40 United-States <=50K
3 46 Private 147640 ... 40 United-States <=50K
4 45 Private 172822 ... 76 United-States >50K
... ... ... ... ... ... ... ...
32556 43 Local-gov 33331 ... 40 United-States >50K
32557 44 Private 98466 ... 35 United-States <=50K
32558 23 Private 45317 ... 40 United-States <=50K
32559 45 Local-gov 215862 ... 45 United-States >50K
32560 25 Private 186925 ... 48 United-States <=50K

[32561 rows x 15 columns]
# Optional: use JSD as the metric
from sdgx.metrics.column.jsd import JSD

JSD = JSD()

selected_columns = ["workclass"]
isDiscrete = True
metrics = JSD.calculate(data_connector.read(), sampled_data, selected_columns, isDiscrete)

print("JSD metric of column %s: %g" % (selected_columns[0], metrics))
```

#### Comparison

Real data:

```python
# Define the model
model = CTGAN(epochs=10)
# Train the model
model.fit(demo_data, discrete_cols)
>>> data_connector.read()
age workclass fnlwgt education ... capitalloss hoursperweek native-country class
0 2 State-gov 77516 Bachelors ... 0 2 United-States <=50K
1 3 Self-emp-not-inc 83311 Bachelors ... 0 0 United-States <=50K
2 2 Private 215646 HS-grad ... 0 2 United-States <=50K
3 3 Private 234721 11th ... 0 2 United-States <=50K
4 1 Private 338409 Bachelors ... 0 2 Cuba <=50K
... ... ... ... ... ... ... ... ... ...
48837 2 Private 215419 Bachelors ... 0 2 United-States <=50K
48838 4 NaN 321403 HS-grad ... 0 2 United-States <=50K
48839 2 Private 374983 Bachelors ... 0 3 United-States <=50K
48840 2 Private 83891 Bachelors ... 0 2 United-States <=50K
48841 1 Self-emp-inc 182148 Bachelors ... 0 3 United-States >50K

[48842 rows x 15 columns]

# Generate synthetic data
sampled_data = model.generate(1000)
```

Synthetic data are as follows:
Simulated data:

```
age workclass fnlwgt ... hours-per-week native-country class
0 33 Private 276389 ... 41 United-States >50K
1 33 Self-emp-not-inc 296948 ... 54 United-States <=50K
2 67 Without-pay 266913 ... 51 Columbia <=50K
3 49 Private 423018 ... 41 United-States >50K
4 22 Private 295325 ... 39 United-States >50K
5 63 Private 234140 ... 65 United-States <=50K
6 42 Private 243623 ... 52 United-States <=50K
7 75 Private 247679 ... 41 United-States <=50K
8 79 Private 332237 ... 41 United-States >50K
9 28 State-gov 837932 ... 99 United-States <=50K
```python
>>> sampled_data
age workclass fnlwgt education ... capitalloss hoursperweek native-country class
0 1 NaN 28219 Some-college ... 0 2 Puerto-Rico <=50K
1 2 Private 250166 HS-grad ... 0 2 United-States >50K
2 2 Private 50304 HS-grad ... 0 2 United-States <=50K
3 4 Private 89318 Bachelors ... 0 2 Puerto-Rico >50K
4 1 Private 172149 Bachelors ... 0 3 United-States <=50K
.. ... ... ... ... ... ... ... ... ...
995 2 NaN 208938 Bachelors ... 0 1 United-States <=50K
996 2 Private 166416 Bachelors ... 2 2 United-States <=50K
997 2 NaN 336022 HS-grad ... 0 1 United-States <=50K
998 3 Private 198051 Masters ... 0 2 United-States >50K
999 1 NaN 41973 HS-grad ... 0 2 United-States <=50K

[1000 rows x 15 columns]
```

## 🤝 How to Contribute
31 changes: 15 additions & 16 deletions example/1_ctgan_example.py
@@ -1,27 +1,26 @@
# To run this example, you can use:
# ipython -i example/1_ctgan_example.py
# then view the sampled_data
from sdgx.data_connectors.csv_connector import CsvConnector
from sdgx.models.ml.single_table.ctgan import CTGANSynthesizerModel
from sdgx.synthesizer import Synthesizer
from sdgx.utils import download_demo_data

dataset_csv = download_demo_data()
data_connector = CsvConnector(path=dataset_csv)
synthesizer = Synthesizer(
model=CTGANSynthesizerModel(epochs=1), # For quick demo
data_connector=data_connector,
)
synthesizer.fit()
sampled_data = synthesizer.sample(1000)
synthesizer.cleanup() # Clean all cache

import numpy as np

from sdgx.metrics.column.jsd import JSD
from sdgx.models.single_table.ctgan import CTGAN
from sdgx.utils.io.csv_utils import *

# For small-scale data in CSV format
# Currently we use a DataFrame as the input data format
demo_data, discrete_cols = get_demo_single_table()
JSD = JSD()

model = CTGAN(epochs=10)
model.fit(demo_data, discrete_cols)

sampled_data = model.sample(1000)

# selected_columns = ["education-num", "fnlwgt"]
# isDiscrete = False
selected_columns = ["workclass"]
isDiscrete = True
metrics = JSD.calculate(demo_data, sampled_data, selected_columns, isDiscrete)
metrics = JSD.calculate(data_connector.read(), sampled_data, selected_columns, isDiscrete)

print("JSD metric of column %s: %g" % (selected_columns[0], metrics))
9 changes: 7 additions & 2 deletions sdgx/cachers/disk_cache.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import shutil
from functools import lru_cache
from pathlib import Path
from typing import Generator
@@ -43,12 +44,13 @@ def __init__(
self.cache_dir = Path(cache_dir)
self.cache_dir.mkdir(parents=True, exist_ok=True)

def clear_cache(self):
def clear_cache(self) -> None:
"""
Clear all cache in cache_dir.
"""
for f in self.cache_dir.glob("*.parquet"):
f.unlink()
shutil.rmtree(self.cache_dir, ignore_errors=True)
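
The change above replaces per-file deletion with removal of the whole cache directory, which is why `_refresh` has to recreate the directory before writing. A minimal sketch of that interaction, assuming a hypothetical `TinyDiskCache` class that only mimics the two relevant steps (it is not the sdgx implementation and writes raw bytes instead of parquet):

```python
import shutil
import tempfile
from pathlib import Path


class TinyDiskCache:
    """Hypothetical stand-in for a disk cache; not the sdgx class."""

    def __init__(self, cache_dir):
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    def clear_cache(self) -> None:
        # Remove the whole directory; ignore_errors avoids raising if it is already gone.
        shutil.rmtree(self.cache_dir, ignore_errors=True)

    def write(self, name: str, payload: bytes) -> Path:
        # clear_cache may have removed the directory, so recreate it before writing.
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        target = self.cache_dir / name
        target.write_bytes(payload)
        return target


with tempfile.TemporaryDirectory() as tmp:
    cache = TinyDiskCache(Path(tmp) / "cache")
    cache.write("0.bin", b"abc")
    cache.clear_cache()
    exists_after_clear = cache.cache_dir.exists()
    recreated = cache.write("1.bin", b"def").exists()
```

Without the `mkdir` call in `write`, the second write would fail with `FileNotFoundError` because the parent directory no longer exists.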

def clear_invalid_cache(self):
"""
@@ -74,6 +76,7 @@ def _refresh(self, offset: int, data: pd.DataFrame) -> None:
"""
Refresh cache, will write data to cache file in parquet format.
"""
self.cache_dir.mkdir(parents=True, exist_ok=True)
if len(data) < self.blocksize:
data.to_parquet(self._get_cache_filename(offset))
elif len(data) > self.blocksize:
@@ -103,6 +106,8 @@ def load(self, offset: int, chunksize: int, data_connector: DataConnector) -> pd
return cached_data[:chunksize]
return cached_data
data = data_connector.read(offset=offset, limit=max(self.blocksize, chunksize))
if data is None:
return data
self._refresh(offset, data)
if len(data) < chunksize:
return data
@@ -117,7 +122,7 @@ def iter(
offset = 0
while True:
data = self.load(offset, chunksize, data_connector)
if len(data) == 0:
if data is None or len(data) == 0:
break
yield data
offset += len(data)
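
The termination condition in this loop can be illustrated in isolation. The sketch below uses a hypothetical `ListConnector` over plain Python lists rather than pandas DataFrames, and assumes the connector signals exhaustion by returning `None`; both are illustrative assumptions, not the sdgx classes.

```python
from typing import Iterator, List, Optional


class ListConnector:
    """Hypothetical connector over an in-memory list of rows."""

    def __init__(self, rows):
        self.rows = rows

    def read(self, offset: int = 0, limit: Optional[int] = None) -> Optional[List[int]]:
        # Signal exhaustion with None once the offset reaches the end of the data.
        if offset >= len(self.rows):
            return None
        end = len(self.rows) if limit is None else offset + limit
        return self.rows[offset:end]


def iter_chunks(connector, chunksize: int) -> Iterator[List[int]]:
    offset = 0
    while True:
        data = connector.read(offset=offset, limit=chunksize)
        # Check for None first: len(None) would raise TypeError.
        if data is None or len(data) == 0:
            break
        yield data
        offset += len(data)


chunks = list(iter_chunks(ListConnector(list(range(5))), chunksize=2))
```

The order of the two checks matters: `or` short-circuits, so `len(data)` is only evaluated when `data` is not `None`.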
4 changes: 3 additions & 1 deletion sdgx/cachers/memory_cache.py
@@ -59,6 +59,8 @@ def load(self, offset: int, chunksize: int, data_connector: DataConnector) -> pd
return cached_data

data = data_connector.read(offset=offset, limit=max(self.blocksize, chunksize))
if data is None:
return data
self._refresh(offset, data)
if len(data) < chunksize:
return data
@@ -70,7 +72,7 @@ def iter(
offset = 0
while True:
data = self.load(offset, chunksize, data_connector)
if len(data) == 0:
if data is None or len(data) == 0:
break
yield data
offset += len(data)
4 changes: 2 additions & 2 deletions sdgx/data_connectors/base.py
@@ -17,7 +17,7 @@ class DataConnector:
Identity of data source, e.g. table name, hash of content
"""

def _read(self, offset: int = 0, limit: int | None = None) -> pd.DataFrame:
def _read(self, offset: int = 0, limit: int | None = None) -> pd.DataFrame | None:
"""
Subclass must implement this for reading data.

@@ -54,7 +54,7 @@ def iter(self, offset: int = 0, chunksize: int = 0) -> Generator[pd.DataFrame, N
"""
return self._iter(offset, chunksize)

def read(self, offset: int = 0, limit: int | None = None) -> pd.DataFrame:
def read(self, offset: int = 0, limit: int | None = None) -> pd.DataFrame | None:
"""
Interface for reading data.

2 changes: 1 addition & 1 deletion sdgx/data_connectors/csv_connector.py
@@ -52,7 +52,7 @@ def __init__(
self.header = header
self.read_csv_kwargs = read_csv_kwargs

def _read(self, offset: int = 0, limit: int | None = None) -> pd.DataFrame:
def _read(self, offset: int = 0, limit: int | None = None) -> pd.DataFrame | None:
""" """
return pd.read_csv(
self.path,
2 changes: 1 addition & 1 deletion sdgx/data_connectors/generator_connector.py
@@ -37,7 +37,7 @@ def __init__(
self.generator_caller = generator_caller
self._generator = self.generator_caller()

def _read(self, offset=0, limit=None) -> pd.DataFrame:
def _read(self, offset: int = 0, limit: int | None = None) -> pd.DataFrame | None:
"""
Ignore limit and allow sequential reading.
"""
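
A generator-backed reader that ignores `offset` and `limit` can be sketched as follows. The `SequentialConnector` class is hypothetical and uses plain lists in place of DataFrames; returning `None` once the generator is exhausted is an assumption suggested by the `pd.DataFrame | None` return type, not something this hunk confirms.

```python
from typing import Callable, Iterator, List, Optional


class SequentialConnector:
    """Hypothetical generator-backed reader; offset and limit are accepted but ignored."""

    def __init__(self, generator_caller: Callable[[], Iterator[List[int]]]):
        self.generator_caller = generator_caller
        self._generator = self.generator_caller()

    def read(self, offset: int = 0, limit: Optional[int] = None) -> Optional[List[int]]:
        # Each call returns the next chunk; None marks an exhausted generator.
        return next(self._generator, None)


def make_chunks() -> Iterator[List[int]]:
    yield [1, 2]
    yield [3]


connector = SequentialConnector(make_chunks)
results = [connector.read(), connector.read(offset=99, limit=1), connector.read()]
```

Because reads are strictly sequential, callers that need random access would have to go through a cache layer rather than rely on `offset`.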