Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[0.1.0] Intro NDArryLoader #75

Merged
merged 6 commits into from
Dec 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -254,3 +254,4 @@ cython_debug/
*.log

.sdgx_cache
.ndarry_cache
101 changes: 60 additions & 41 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,61 +70,80 @@ pip install git+https://github.com/hitsz-ids/synthetic-data-generator.git
pip install sdgx
```

### Quick Demo of Single Table Data Generation
### Quick Demo of Single Table Data Generation and Metric

#### Demo code

```python
# Import modules
from sdgx.models.single_table.ctgan import CTGAN
from sdgx.utils.io.csv_utils import *
from sdgx.data_connectors.csv_connector import CsvConnector
from sdgx.models.ml.single_table.ctgan import CTGANSynthesizerModel
from sdgx.synthesizer import Synthesizer
from sdgx.utils import download_demo_data

# Read data from demo
demo_data, discrete_cols = get_demo_single_table()
```
dataset_csv = download_demo_data()
data_connector = CsvConnector(path=dataset_csv)
synthesizer = Synthesizer(
model=CTGANSynthesizerModel(epochs=1), # For quick demo
data_connector=data_connector,
)

Real data are as follows:
synthesizer.fit()
sampled_data = synthesizer.sample(1000)
synthesizer.cleanup() # Clean all cache

```
age workclass fnlwgt ... hours-per-week native-country class
0 27 Private 177119 ... 44 United-States <=50K
1 27 Private 216481 ... 40 United-States <=50K
2 25 Private 256263 ... 40 United-States <=50K
3 46 Private 147640 ... 40 United-States <=50K
4 45 Private 172822 ... 76 United-States >50K
... ... ... ... ... ... ... ...
32556 43 Local-gov 33331 ... 40 United-States >50K
32557 44 Private 98466 ... 35 United-States <=50K
32558 23 Private 45317 ... 40 United-States <=50K
32559 45 Local-gov 215862 ... 45 United-States >50K
32560 25 Private 186925 ... 48 United-States <=50K

[32561 rows x 15 columns]
# Optional, use JSD for mectics
from sdgx.metrics.column.jsd import JSD

JSD = JSD()

selected_columns = ["workclass"]
isDiscrete = True
metrics = JSD.calculate(data_connector.read(), sampled_data, selected_columns, isDiscrete)

print("JSD metric of column %s: %g" % (selected_columns[0], metrics))
```

#### Comparison

Real data are as follows:

```python
# Define model
model = CTGAN(epochs=10)
# Model training
model.fit(demo_data, discrete_cols)
>>> data_connector.read()
age workclass fnlwgt education ... capitalloss hoursperweek native-country class
0 2 State-gov 77516 Bachelors ... 0 2 United-States <=50K
1 3 Self-emp-not-inc 83311 Bachelors ... 0 0 United-States <=50K
2 2 Private 215646 HS-grad ... 0 2 United-States <=50K
3 3 Private 234721 11th ... 0 2 United-States <=50K
4 1 Private 338409 Bachelors ... 0 2 Cuba <=50K
... ... ... ... ... ... ... ... ... ...
48837 2 Private 215419 Bachelors ... 0 2 United-States <=50K
48838 4 NaN 321403 HS-grad ... 0 2 United-States <=50K
48839 2 Private 374983 Bachelors ... 0 3 United-States <=50K
48840 2 Private 83891 Bachelors ... 0 2 United-States <=50K
48841 1 Self-emp-inc 182148 Bachelors ... 0 3 United-States >50K

[48842 rows x 15 columns]

# Generate synthetic data
sampled_data = model.generate(1000)
```

Synthetic data are as follows:

```
age workclass fnlwgt ... hours-per-week native-country class
0 33 Private 276389 ... 41 United-States >50K
1 33 Self-emp-not-inc 296948 ... 54 United-States <=50K
2 67 Without-pay 266913 ... 51 Columbia <=50K
3 49 Private 423018 ... 41 United-States >50K
4 22 Private 295325 ... 39 United-States >50K
5 63 Private 234140 ... 65 United-States <=50K
6 42 Private 243623 ... 52 United-States <=50K
7 75 Private 247679 ... 41 United-States <=50K
8 79 Private 332237 ... 41 United-States >50K
9 28 State-gov 837932 ... 99 United-States <=50K
```python
>>> sampled_data
age workclass fnlwgt education ... capitalloss hoursperweek native-country class
0 1 NaN 28219 Some-college ... 0 2 Puerto-Rico <=50K
1 2 Private 250166 HS-grad ... 0 2 United-States >50K
2 2 Private 50304 HS-grad ... 0 2 United-States <=50K
3 4 Private 89318 Bachelors ... 0 2 Puerto-Rico >50K
4 1 Private 172149 Bachelors ... 0 3 United-States <=50K
.. ... ... ... ... ... ... ... ... ...
995 2 NaN 208938 Bachelors ... 0 1 United-States <=50K
996 2 Private 166416 Bachelors ... 2 2 United-States <=50K
997 2 NaN 336022 HS-grad ... 0 1 United-States <=50K
998 3 Private 198051 Masters ... 0 2 United-States >50K
999 1 NaN 41973 HS-grad ... 0 2 United-States <=50K

[1000 rows x 15 columns]
```

## 🤝 Join Community
Expand Down
101 changes: 60 additions & 41 deletions README_ZH_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,59 +71,78 @@ pip install sdgx

### 单表数据快速合成示例

#### 演示代码

```python
# 导入相关模块
from sdgx.models.single_table.ctgan import CTGAN
from sdgx.utils.io.csv_utils import *
from sdgx.data_connectors.csv_connector import CsvConnector
from sdgx.models.ml.single_table.ctgan import CTGANSynthesizerModel
from sdgx.synthesizer import Synthesizer
from sdgx.utils import download_demo_data

# 读取数据
demo_data, discrete_cols = get_demo_single_table()
```
dataset_csv = download_demo_data()
data_connector = CsvConnector(path=dataset_csv)
synthesizer = Synthesizer(
model=CTGANSynthesizerModel(epochs=1), # For quick demo
data_connector=data_connector,
)

真实数据示例如下:
synthesizer.fit()
sampled_data = synthesizer.sample(1000)
synthesizer.cleanup() # Clean all cache

```
age workclass fnlwgt ... hours-per-week native-country class
0 27 Private 177119 ... 44 United-States <=50K
1 27 Private 216481 ... 40 United-States <=50K
2 25 Private 256263 ... 40 United-States <=50K
3 46 Private 147640 ... 40 United-States <=50K
4 45 Private 172822 ... 76 United-States >50K
... ... ... ... ... ... ... ...
32556 43 Local-gov 33331 ... 40 United-States >50K
32557 44 Private 98466 ... 35 United-States <=50K
32558 23 Private 45317 ... 40 United-States <=50K
32559 45 Local-gov 215862 ... 45 United-States >50K
32560 25 Private 186925 ... 48 United-States <=50K

[32561 rows x 15 columns]
# Optional, use JSD for mectics
from sdgx.metrics.column.jsd import JSD

JSD = JSD()

selected_columns = ["workclass"]
isDiscrete = True
metrics = JSD.calculate(data_connector.read(), sampled_data, selected_columns, isDiscrete)

print("JSD metric of column %s: %g" % (selected_columns[0], metrics))
```

#### 对比

真实数据:

```python
#定义模型
model = CTGAN(epochs=10)
# 训练模型
model.fit(demo_data, discrete_cols)
>>> data_connector.read()
age workclass fnlwgt education ... capitalloss hoursperweek native-country class
0 2 State-gov 77516 Bachelors ... 0 2 United-States <=50K
1 3 Self-emp-not-inc 83311 Bachelors ... 0 0 United-States <=50K
2 2 Private 215646 HS-grad ... 0 2 United-States <=50K
3 3 Private 234721 11th ... 0 2 United-States <=50K
4 1 Private 338409 Bachelors ... 0 2 Cuba <=50K
... ... ... ... ... ... ... ... ... ...
48837 2 Private 215419 Bachelors ... 0 2 United-States <=50K
48838 4 NaN 321403 HS-grad ... 0 2 United-States <=50K
48839 2 Private 374983 Bachelors ... 0 3 United-States <=50K
48840 2 Private 83891 Bachelors ... 0 2 United-States <=50K
48841 1 Self-emp-inc 182148 Bachelors ... 0 3 United-States >50K

[48842 rows x 15 columns]

# 生成合成数据
sampled_data = model.generate(1000)
```

合成数据如下
仿真数据

```
age workclass fnlwgt ... hours-per-week native-country class
0 33 Private 276389 ... 41 United-States >50K
1 33 Self-emp-not-inc 296948 ... 54 United-States <=50K
2 67 Without-pay 266913 ... 51 Columbia <=50K
3 49 Private 423018 ... 41 United-States >50K
4 22 Private 295325 ... 39 United-States >50K
5 63 Private 234140 ... 65 United-States <=50K
6 42 Private 243623 ... 52 United-States <=50K
7 75 Private 247679 ... 41 United-States <=50K
8 79 Private 332237 ... 41 United-States >50K
9 28 State-gov 837932 ... 99 United-States <=50K
```python
>>> sampled_data
age workclass fnlwgt education ... capitalloss hoursperweek native-country class
0 1 NaN 28219 Some-college ... 0 2 Puerto-Rico <=50K
1 2 Private 250166 HS-grad ... 0 2 United-States >50K
2 2 Private 50304 HS-grad ... 0 2 United-States <=50K
3 4 Private 89318 Bachelors ... 0 2 Puerto-Rico >50K
4 1 Private 172149 Bachelors ... 0 3 United-States <=50K
.. ... ... ... ... ... ... ... ... ...
995 2 NaN 208938 Bachelors ... 0 1 United-States <=50K
996 2 Private 166416 Bachelors ... 2 2 United-States <=50K
997 2 NaN 336022 HS-grad ... 0 1 United-States <=50K
998 3 Private 198051 Masters ... 0 2 United-States >50K
999 1 NaN 41973 HS-grad ... 0 2 United-States <=50K

[1000 rows x 15 columns]
```

## 🤝 如何贡献
Expand Down
1 change: 1 addition & 0 deletions docs/source/api_reference/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ API Reference
Metadata and Inspectors <data_models/index>
Manager <manager>
Exceptions <exceptions>
Utils <utils>
4 changes: 4 additions & 0 deletions docs/source/api_reference/models/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,17 @@ Built-in ML Models
.. toctree::
:maxdepth: 2

ML Models <ml/index>


Build-in Statistical Models
-----------------------------

.. toctree::
:maxdepth: 2

Statistics Models <statistics/index>


Custom Models Relevant
-----------------------------
Expand Down
8 changes: 8 additions & 0 deletions docs/source/api_reference/models/ml/ctgan.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
CTGANSynthesizerModel
============================

.. automodule:: sdgx.models.ml.single_table.ctgan
:members:
:undoc-members:
:private-members:
:show-inheritance:
18 changes: 18 additions & 0 deletions docs/source/api_reference/models/ml/index.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
Built-in ML Models
============================

Model for single table
----------------------------

.. toctree::
:maxdepth: 2

CTGAN <ctgan>



Model for multi table
----------------------------

.. toctree::
:maxdepth: 2
16 changes: 16 additions & 0 deletions docs/source/api_reference/models/statistics/index.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
Built-in Statistics Models
============================

Model for single table
----------------------------

.. toctree::
:maxdepth: 2



Model for multi table
----------------------------

.. toctree::
:maxdepth: 2
7 changes: 7 additions & 0 deletions docs/source/api_reference/utils.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Utils
============================

.. automodule:: sdgx.utils
:members:
:undoc-members:
:inherited-members:
29 changes: 29 additions & 0 deletions docs/source/user_guides/library.rst
Original file line number Diff line number Diff line change
@@ -1,2 +1,31 @@
Use Synthetic Data Generator as a library
==================================================

.. code-block:: python

from sdgx.data_connectors.csv_connector import CsvConnector
from sdgx.models.ml.single_table.ctgan import CTGANSynthesizerModel
from sdgx.synthesizer import Synthesizer
from sdgx.utils import download_demo_data

dataset_csv = download_demo_data()
data_connector = CsvConnector(path=dataset_csv)
synthesizer = Synthesizer(
model=CTGANSynthesizerModel(epochs=1), # For quick demo
data_connector=data_connector,
)
synthesizer.fit()
sampled_data = synthesizer.sample(1000)
synthesizer.cleanup() # Clean all cache


from sdgx.metrics.column.jsd import JSD

JSD = JSD()


selected_columns = ["workclass"]
isDiscrete = True
metrics = JSD.calculate(data_connector.read(), sampled_data, selected_columns, isDiscrete)

print("JSD metric of column %s: %g" % (selected_columns[0], metrics))
30 changes: 30 additions & 0 deletions docs/source/user_guides/single_table.rst
Original file line number Diff line number Diff line change
@@ -1,2 +1,32 @@
Synthetic single-table data
==========================================


.. code-block:: python

from sdgx.data_connectors.csv_connector import CsvConnector
from sdgx.models.ml.single_table.ctgan import CTGANSynthesizerModel
from sdgx.synthesizer import Synthesizer
from sdgx.utils import download_demo_data

dataset_csv = download_demo_data()
data_connector = CsvConnector(path=dataset_csv)
synthesizer = Synthesizer(
model=CTGANSynthesizerModel(epochs=1), # For quick demo
data_connector=data_connector,
)
synthesizer.fit()
sampled_data = synthesizer.sample(1000)
synthesizer.cleanup() # Clean all cache


from sdgx.metrics.column.jsd import JSD

JSD = JSD()


selected_columns = ["workclass"]
isDiscrete = True
metrics = JSD.calculate(data_connector.read(), sampled_data, selected_columns, isDiscrete)

print("JSD metric of column %s: %g" % (selected_columns[0], metrics))
Loading