[0.1.0] Intro NDArryLoader (#75)
* Intro NDArryLoader
* Update docstring and README
Wh1isper committed Dec 19, 2023
1 parent 33cd199 commit 482f7d5
Showing 27 changed files with 604 additions and 171 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -254,3 +254,4 @@ cython_debug/
*.log

.sdgx_cache
.ndarry_cache
101 changes: 60 additions & 41 deletions README.md
@@ -70,61 +70,80 @@ pip install git+https://github.com/hitsz-ids/synthetic-data-generator.git
pip install sdgx
```

### Quick Demo of Single Table Data Generation
### Quick Demo of Single Table Data Generation and Metric

#### Demo code

```python
# Import modules
from sdgx.models.single_table.ctgan import CTGAN
from sdgx.utils.io.csv_utils import *
from sdgx.data_connectors.csv_connector import CsvConnector
from sdgx.models.ml.single_table.ctgan import CTGANSynthesizerModel
from sdgx.synthesizer import Synthesizer
from sdgx.utils import download_demo_data

# Read data from demo
demo_data, discrete_cols = get_demo_single_table()
```
dataset_csv = download_demo_data()
data_connector = CsvConnector(path=dataset_csv)
synthesizer = Synthesizer(
    model=CTGANSynthesizerModel(epochs=1),  # For quick demo
    data_connector=data_connector,
)

Real data are as follows:
synthesizer.fit()
sampled_data = synthesizer.sample(1000)
synthesizer.cleanup() # Clean all cache

```
age workclass fnlwgt ... hours-per-week native-country class
0 27 Private 177119 ... 44 United-States <=50K
1 27 Private 216481 ... 40 United-States <=50K
2 25 Private 256263 ... 40 United-States <=50K
3 46 Private 147640 ... 40 United-States <=50K
4 45 Private 172822 ... 76 United-States >50K
... ... ... ... ... ... ... ...
32556 43 Local-gov 33331 ... 40 United-States >50K
32557 44 Private 98466 ... 35 United-States <=50K
32558 23 Private 45317 ... 40 United-States <=50K
32559 45 Local-gov 215862 ... 45 United-States >50K
32560 25 Private 186925 ... 48 United-States <=50K
[32561 rows x 15 columns]
# Optional: use JSD as a metric
from sdgx.metrics.column.jsd import JSD

JSD = JSD()

selected_columns = ["workclass"]
isDiscrete = True
metrics = JSD.calculate(data_connector.read(), sampled_data, selected_columns, isDiscrete)

print("JSD metric of column %s: %g" % (selected_columns[0], metrics))
```
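The `JSD` metric in the demo reports how far the `workclass` distribution of the sampled data is from that of the real data. As a rough illustration of what such a score measures, the sketch below computes the textbook Jensen-Shannon divergence (base 2) between two categorical columns with pandas and NumPy. The helper name `discrete_jsd` is hypothetical, and this is not necessarily how `sdgx.metrics.column.jsd.JSD` is implemented; it may normalize or scale the value differently.

```python
import numpy as np
import pandas as pd


def discrete_jsd(real: pd.Series, synthetic: pd.Series) -> float:
    """Textbook Jensen-Shannon divergence (log base 2) of two categorical columns.

    Hypothetical helper for illustration only; not part of sdgx.
    """
    # Relative frequency of each category (NaN values are ignored).
    p = real.value_counts(normalize=True)
    q = synthetic.value_counts(normalize=True)

    # Align both distributions on the union of observed categories.
    support = p.index.union(q.index)
    p = p.reindex(support, fill_value=0.0).to_numpy()
    q = q.reindex(support, fill_value=0.0).to_numpy()
    m = 0.5 * (p + q)  # mixture distribution

    def kl(a: np.ndarray, b: np.ndarray) -> float:
        # Kullback-Leibler divergence; terms with a == 0 contribute nothing.
        mask = a > 0
        return float(np.sum(a[mask] * np.log2(a[mask] / b[mask])))

    return 0.5 * kl(p, m) + 0.5 * kl(q, m)
```

With base-2 logarithms the value lies between 0 (identical distributions) and 1 (no shared categories), so a lower score for a column suggests the synthetic data reproduces that column's category frequencies more closely.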

#### Comparison

Real data are as follows:

```python
# Define model
model = CTGAN(epochs=10)
# Model training
model.fit(demo_data, discrete_cols)
>>> data_connector.read()
age workclass fnlwgt education ... capitalloss hoursperweek native-country class
0 2 State-gov 77516 Bachelors ... 0 2 United-States <=50K
1 3 Self-emp-not-inc 83311 Bachelors ... 0 0 United-States <=50K
2 2 Private 215646 HS-grad ... 0 2 United-States <=50K
3 3 Private 234721 11th ... 0 2 United-States <=50K
4 1 Private 338409 Bachelors ... 0 2 Cuba <=50K
... ... ... ... ... ... ... ... ... ...
48837 2 Private 215419 Bachelors ... 0 2 United-States <=50K
48838 4 NaN 321403 HS-grad ... 0 2 United-States <=50K
48839 2 Private 374983 Bachelors ... 0 3 United-States <=50K
48840 2 Private 83891 Bachelors ... 0 2 United-States <=50K
48841 1 Self-emp-inc 182148 Bachelors ... 0 3 United-States >50K

[48842 rows x 15 columns]

# Generate synthetic data
sampled_data = model.generate(1000)
```

Synthetic data are as follows:

```
age workclass fnlwgt ... hours-per-week native-country class
0 33 Private 276389 ... 41 United-States >50K
1 33 Self-emp-not-inc 296948 ... 54 United-States <=50K
2 67 Without-pay 266913 ... 51 Columbia <=50K
3 49 Private 423018 ... 41 United-States >50K
4 22 Private 295325 ... 39 United-States >50K
5 63 Private 234140 ... 65 United-States <=50K
6 42 Private 243623 ... 52 United-States <=50K
7 75 Private 247679 ... 41 United-States <=50K
8 79 Private 332237 ... 41 United-States >50K
9 28 State-gov 837932 ... 99 United-States <=50K
```python
>>> sampled_data
age workclass fnlwgt education ... capitalloss hoursperweek native-country class
0 1 NaN 28219 Some-college ... 0 2 Puerto-Rico <=50K
1 2 Private 250166 HS-grad ... 0 2 United-States >50K
2 2 Private 50304 HS-grad ... 0 2 United-States <=50K
3 4 Private 89318 Bachelors ... 0 2 Puerto-Rico >50K
4 1 Private 172149 Bachelors ... 0 3 United-States <=50K
.. ... ... ... ... ... ... ... ... ...
995 2 NaN 208938 Bachelors ... 0 1 United-States <=50K
996 2 Private 166416 Bachelors ... 2 2 United-States <=50K
997 2 NaN 336022 HS-grad ... 0 1 United-States <=50K
998 3 Private 198051 Masters ... 0 2 United-States >50K
999 1 NaN 41973 HS-grad ... 0 2 United-States <=50K

[1000 rows x 15 columns]
```

## 🤝 Join Community
101 changes: 60 additions & 41 deletions README_ZH_CN.md
@@ -71,59 +71,78 @@ pip install sdgx

### Quick Demo of Single Table Data Generation

#### Demo code

```python
# Import modules
from sdgx.models.single_table.ctgan import CTGAN
from sdgx.utils.io.csv_utils import *
from sdgx.data_connectors.csv_connector import CsvConnector
from sdgx.models.ml.single_table.ctgan import CTGANSynthesizerModel
from sdgx.synthesizer import Synthesizer
from sdgx.utils import download_demo_data

# Read data
demo_data, discrete_cols = get_demo_single_table()
```
dataset_csv = download_demo_data()
data_connector = CsvConnector(path=dataset_csv)
synthesizer = Synthesizer(
    model=CTGANSynthesizerModel(epochs=1),  # For quick demo
    data_connector=data_connector,
)

Real data are as follows:
synthesizer.fit()
sampled_data = synthesizer.sample(1000)
synthesizer.cleanup() # Clean all cache

```
age workclass fnlwgt ... hours-per-week native-country class
0 27 Private 177119 ... 44 United-States <=50K
1 27 Private 216481 ... 40 United-States <=50K
2 25 Private 256263 ... 40 United-States <=50K
3 46 Private 147640 ... 40 United-States <=50K
4 45 Private 172822 ... 76 United-States >50K
... ... ... ... ... ... ... ...
32556 43 Local-gov 33331 ... 40 United-States >50K
32557 44 Private 98466 ... 35 United-States <=50K
32558 23 Private 45317 ... 40 United-States <=50K
32559 45 Local-gov 215862 ... 45 United-States >50K
32560 25 Private 186925 ... 48 United-States <=50K
[32561 rows x 15 columns]
# Optional: use JSD as a metric
from sdgx.metrics.column.jsd import JSD

JSD = JSD()

selected_columns = ["workclass"]
isDiscrete = True
metrics = JSD.calculate(data_connector.read(), sampled_data, selected_columns, isDiscrete)

print("JSD metric of column %s: %g" % (selected_columns[0], metrics))
```

#### Comparison

Real data:

```python
# Define model
model = CTGAN(epochs=10)
# Train model
model.fit(demo_data, discrete_cols)
>>> data_connector.read()
age workclass fnlwgt education ... capitalloss hoursperweek native-country class
0 2 State-gov 77516 Bachelors ... 0 2 United-States <=50K
1 3 Self-emp-not-inc 83311 Bachelors ... 0 0 United-States <=50K
2 2 Private 215646 HS-grad ... 0 2 United-States <=50K
3 3 Private 234721 11th ... 0 2 United-States <=50K
4 1 Private 338409 Bachelors ... 0 2 Cuba <=50K
... ... ... ... ... ... ... ... ... ...
48837 2 Private 215419 Bachelors ... 0 2 United-States <=50K
48838 4 NaN 321403 HS-grad ... 0 2 United-States <=50K
48839 2 Private 374983 Bachelors ... 0 3 United-States <=50K
48840 2 Private 83891 Bachelors ... 0 2 United-States <=50K
48841 1 Self-emp-inc 182148 Bachelors ... 0 3 United-States >50K

[48842 rows x 15 columns]

# Generate synthetic data
sampled_data = model.generate(1000)
```

Synthetic data are as follows
Simulated data

```
age workclass fnlwgt ... hours-per-week native-country class
0 33 Private 276389 ... 41 United-States >50K
1 33 Self-emp-not-inc 296948 ... 54 United-States <=50K
2 67 Without-pay 266913 ... 51 Columbia <=50K
3 49 Private 423018 ... 41 United-States >50K
4 22 Private 295325 ... 39 United-States >50K
5 63 Private 234140 ... 65 United-States <=50K
6 42 Private 243623 ... 52 United-States <=50K
7 75 Private 247679 ... 41 United-States <=50K
8 79 Private 332237 ... 41 United-States >50K
9 28 State-gov 837932 ... 99 United-States <=50K
```python
>>> sampled_data
age workclass fnlwgt education ... capitalloss hoursperweek native-country class
0 1 NaN 28219 Some-college ... 0 2 Puerto-Rico <=50K
1 2 Private 250166 HS-grad ... 0 2 United-States >50K
2 2 Private 50304 HS-grad ... 0 2 United-States <=50K
3 4 Private 89318 Bachelors ... 0 2 Puerto-Rico >50K
4 1 Private 172149 Bachelors ... 0 3 United-States <=50K
.. ... ... ... ... ... ... ... ... ...
995 2 NaN 208938 Bachelors ... 0 1 United-States <=50K
996 2 Private 166416 Bachelors ... 2 2 United-States <=50K
997 2 NaN 336022 HS-grad ... 0 1 United-States <=50K
998 3 Private 198051 Masters ... 0 2 United-States >50K
999 1 NaN 41973 HS-grad ... 0 2 United-States <=50K

[1000 rows x 15 columns]
```

## 🤝 How to contribute
1 change: 1 addition & 0 deletions docs/source/api_reference/index.rst
@@ -15,3 +15,4 @@ API Reference
Metadata and Inspectors <data_models/index>
Manager <manager>
Exceptions <exceptions>
Utils <utils>
4 changes: 4 additions & 0 deletions docs/source/api_reference/models/index.rst
@@ -13,13 +13,17 @@ Built-in ML Models
.. toctree::
:maxdepth: 2

ML Models <ml/index>


Built-in Statistical Models
-----------------------------

.. toctree::
:maxdepth: 2

Statistics Models <statistics/index>


Custom Models Relevant
-----------------------------
8 changes: 8 additions & 0 deletions docs/source/api_reference/models/ml/ctgan.rst
@@ -0,0 +1,8 @@
CTGANSynthesizerModel
============================

.. automodule:: sdgx.models.ml.single_table.ctgan
:members:
:undoc-members:
:private-members:
:show-inheritance:
18 changes: 18 additions & 0 deletions docs/source/api_reference/models/ml/index.rst
@@ -0,0 +1,18 @@
Built-in ML Models
============================

Model for single table
----------------------------

.. toctree::
:maxdepth: 2

CTGAN <ctgan>



Model for multi table
----------------------------

.. toctree::
:maxdepth: 2
16 changes: 16 additions & 0 deletions docs/source/api_reference/models/statistics/index.rst
@@ -0,0 +1,16 @@
Built-in Statistics Models
============================

Model for single table
----------------------------

.. toctree::
:maxdepth: 2



Model for multi table
----------------------------

.. toctree::
:maxdepth: 2
7 changes: 7 additions & 0 deletions docs/source/api_reference/utils.rst
@@ -0,0 +1,7 @@
Utils
============================

.. automodule:: sdgx.utils
:members:
:undoc-members:
:inherited-members:
29 changes: 29 additions & 0 deletions docs/source/user_guides/library.rst
@@ -1,2 +1,31 @@
Use Synthetic Data Generator as a library
==================================================

.. code-block:: python

    from sdgx.data_connectors.csv_connector import CsvConnector
    from sdgx.models.ml.single_table.ctgan import CTGANSynthesizerModel
    from sdgx.synthesizer import Synthesizer
    from sdgx.utils import download_demo_data

    dataset_csv = download_demo_data()
    data_connector = CsvConnector(path=dataset_csv)
    synthesizer = Synthesizer(
        model=CTGANSynthesizerModel(epochs=1),  # For quick demo
        data_connector=data_connector,
    )
    synthesizer.fit()
    sampled_data = synthesizer.sample(1000)
    synthesizer.cleanup()  # Clean all cache

    from sdgx.metrics.column.jsd import JSD

    JSD = JSD()

    selected_columns = ["workclass"]
    isDiscrete = True
    metrics = JSD.calculate(data_connector.read(), sampled_data, selected_columns, isDiscrete)

    print("JSD metric of column %s: %g" % (selected_columns[0], metrics))
30 changes: 30 additions & 0 deletions docs/source/user_guides/single_table.rst
@@ -1,2 +1,32 @@
Synthetic single-table data
==========================================


.. code-block:: python

    from sdgx.data_connectors.csv_connector import CsvConnector
    from sdgx.models.ml.single_table.ctgan import CTGANSynthesizerModel
    from sdgx.synthesizer import Synthesizer
    from sdgx.utils import download_demo_data

    dataset_csv = download_demo_data()
    data_connector = CsvConnector(path=dataset_csv)
    synthesizer = Synthesizer(
        model=CTGANSynthesizerModel(epochs=1),  # For quick demo
        data_connector=data_connector,
    )
    synthesizer.fit()
    sampled_data = synthesizer.sample(1000)
    synthesizer.cleanup()  # Clean all cache

    from sdgx.metrics.column.jsd import JSD

    JSD = JSD()

    selected_columns = ["workclass"]
    isDiscrete = True
    metrics = JSD.calculate(data_connector.read(), sampled_data, selected_columns, isDiscrete)

    print("JSD metric of column %s: %g" % (selected_columns[0], metrics))