-
Notifications
You must be signed in to change notification settings - Fork 541
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
15 changed files
with
222 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
dataset/ | ||
.ndarry_cache | ||
.sdgx_cache |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# WIP | ||
|
||
Please help us to improve our benchmark: https://github.com/hitsz-ids/synthetic-data-generator/issues/82 | ||
|
||
## Benchmarks | ||
|
||
Benchmarks aim to measure the performance of the library. | ||
|
||
- Performance: processing time, training time of the model, sampling rate...
- Memory Consumption | ||
- Others, like cache hit rate... | ||
|
||
Now we provide a simple benchmark of our CTGAN implementation against the original one: fit both with a large random dataset and compare their memory consumption.
|
||
### Setup | ||
|
||
```bash | ||
# Clone and install latest version | ||
# You can also use our latest image: docker pull idsteam/sdgx:latest | ||
git clone https://github.com/hitsz-ids/synthetic-data-generator.git | ||
cd synthetic-data-generator && pip install -e ./ | ||
# Setup benchmark | ||
cd benchmarks | ||
pip install -r requirements.txt | ||
``` | ||
|
||
Generate a dataset with `python generate_dataset.py`, you can use `python generate_dataset.py --help` to see the usage. | ||
|
||
### Benchmark our implementation | ||
|
||
We use [memory_profiler](https://github.com/pythonprofilers/memory_profiler) to benchmark our implementation. | ||
|
||
```bash | ||
mprof run python ./sdgx_ctgan.py | ||
``` | ||
|
||
Plot the results with `mprof plot` or `mprof plot --output=sdgx_ctgan.png` to save the plot. | ||
|
||
### Benchmark original implementation | ||
|
||
```bash | ||
pip install ctgan | ||
mprof run python ./sdv_ctgan.py | ||
``` | ||
|
||
Plot the results with `mprof plot` or `mprof plot --output=sdv_ctgan.png` to save the plot. | ||
|
||
## Results | ||
|
||
With default settings, our implementation can fit a 1,000,000 × 50 dataset on a 32 GB machine (with nearly 20 GB usable), while the original implementation needs more than 20 GB of memory and crashed during training.
|
||
![succeed-memory-sdgx-ctgan](./img/succeed-memory-sdgx-ctgan.png) | ||
|
||
![failed-memory-sdv_ctgan](./img/failed-memory-sdv_ctgan.png) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
import datetime | ||
import itertools | ||
import random | ||
import string | ||
from pathlib import Path | ||
|
||
import click | ||
|
||
_HERE = Path(__file__).parent | ||
|
||
|
||
def random_string(length):
    """Return a random string of ``length`` lowercase ASCII letters."""
    alphabet = string.ascii_lowercase
    return "".join(random.choice(alphabet) for _ in range(length))
|
||
|
||
def random_float():
    """Return a uniform random float in the half-open range [0.0, 1000.0)."""
    scale = 1000
    return scale * random.random()
|
||
|
||
def random_int():
    """Return a uniform random integer in the inclusive range [0, 1000]."""
    # randrange(1001) draws from 0..1000, identical to randint(0, 1000).
    return random.randrange(1001)
|
||
|
||
def random_timestamp():
    """Return the current POSIX timestamp shifted forward by 0-1000 seconds."""
    offset_seconds = random_int()
    return datetime.datetime.now().timestamp() + offset_seconds
|
||
|
||
def random_datetime():
    """Return a datetime up to ~1000 seconds after now (see random_timestamp)."""
    ts = random_timestamp()
    return datetime.datetime.fromtimestamp(ts)
|
||
|
||
@click.option(
    "--output_file",
    default=(_HERE / "dataset/benchmark.csv").as_posix(),
)
@click.option("--num_rows", default=1_000_000)
@click.option("--int_cols", default=15)
@click.option("--float_cols", default=15)
@click.option("--string_cols", default=10)
@click.option("--string_discrete_nums", default=50)
@click.option("--timestamp_cols", default=10)
@click.option("--datetime_cols", default=0)
@click.command()
def generate_dateset(  # NOTE(review): "dateset" typo kept — renaming would break existing invocations
    output_file,
    num_rows,
    int_cols,
    float_cols,
    string_cols,
    string_discrete_nums,
    timestamp_cols,
    datetime_cols,
):
    """Generate a random benchmark CSV dataset.

    Writes ``num_rows`` data rows (plus one header row) to ``output_file``
    with the requested numbers of int/float/string/timestamp/datetime
    columns.  String columns draw from a fixed pool of
    ``string_discrete_nums`` distinct 25-character values so they behave as
    discrete (categorical) columns.
    """
    headers = itertools.chain.from_iterable(
        [
            (f"int_col{i}" for i in range(int_cols)),
            (f"float_col{i}" for i in range(float_cols)),
            (f"string_col{i}" for i in range(string_cols)),
            (f"timestamp_col{i}" for i in range(timestamp_cols)),
            (f"datetime_col{i}" for i in range(datetime_cols)),
        ]
    )
    output_file = Path(output_file).expanduser().resolve()
    output_file.parent.mkdir(parents=True, exist_ok=True)

    # Discrete pool of values shared by every string column.
    random_str_list = [random_string(25) for i in range(string_discrete_nums)]

    def _generate_one_line():
        # One CSV data row; column order matches ``headers``.
        return ",".join(
            map(
                str,
                itertools.chain(
                    (random_int() for _ in range(int_cols)),
                    (random_float() for _ in range(float_cols)),
                    (random.choice(random_str_list) for _ in range(string_cols)),
                    (random_timestamp() for _ in range(timestamp_cols)),
                    (random_datetime() for _ in range(datetime_cols)),
                ),
            )
        )

    with output_file.open("w") as f:
        f.write(",".join(headers) + "\n")

        # Write in fixed-size chunks to bound the join buffer.  The final
        # chunk is clamped so exactly ``num_rows`` rows are emitted — the
        # original wrote a full chunk every iteration and overshot whenever
        # num_rows was not a multiple of chunk_size.
        chunk_size = 1000
        for start in range(0, num_rows, chunk_size):
            rows = min(chunk_size, num_rows - start)
            f.write("\n".join(_generate_one_line() for _ in range(rows)))
            f.write("\n")


if __name__ == "__main__":
    generate_dateset()
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
click | ||
memory_profiler | ||
psutil |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# Benchmark script: fit sdgx's CTGAN implementation on the generated dataset.
# Run under memory_profiler (``mprof run python ./sdgx_ctgan.py``) to record
# memory consumption; see the benchmarks README.
import os

# Must be set *before* any sdgx import so the library picks up the DEBUG log
# level when it initializes — do not move this below the imports.
os.environ["SDGX_LOG_LEVEL"] = "DEBUG"


from pathlib import Path

from sdgx.data_connectors.csv_connector import CsvConnector
from sdgx.models.ml.single_table.ctgan import CTGANSynthesizerModel
from sdgx.synthesizer import Synthesizer

_HERE = Path(__file__).parent

# Dataset produced by generate_dataset.py (see benchmarks README for setup).
dataset_csv = (_HERE / "dataset/benchmark.csv").expanduser().resolve()
data_connector = CsvConnector(path=dataset_csv)
# One CPU epoch: enough for a memory-consumption comparison without GPUs.
synthesizer = Synthesizer(
    model=CTGANSynthesizerModel,
    data_connector=data_connector,
    model_kwargs={"epochs": 1, "device": "cpu"},
)
synthesizer.fit()
# sampled_data = synthesizer.sample(1000)
# synthesizer.cleanup()  # Clean all cache
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# Benchmark script: fit the original (SDV) CTGAN implementation on the same
# generated dataset, for memory comparison with sdgx_ctgan.py.  Requires
# ``pip install ctgan``; run under memory_profiler (``mprof run ...``).
from pathlib import Path

import pandas as pd

_HERE = Path(__file__).parent

# Dataset produced by generate_dataset.py.  Unlike the sdgx benchmark, the
# whole CSV is loaded into memory up front — that difference is part of what
# the memory profile is meant to show.
dataset_csv = (_HERE / "dataset/benchmark.csv").expanduser().resolve()
df = pd.read_csv(dataset_csv)

# "string_col*" columns are the discrete/categorical ones (see the column
# naming in generate_dataset.py).
discrete_columns = [s for s in df.columns if s.startswith("string")]


# NOTE(review): imported mid-file in the original — presumably so the memory
# profile separates dataset loading from the ctgan import; confirm before
# hoisting to the top of the file.
from ctgan import CTGAN

# One CPU epoch to mirror the sdgx benchmark settings.
ctgan = CTGAN(epochs=1, cuda=False)
ctgan.fit(df, discrete_columns)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters