Skip to content

Commit

Permalink
Intro dummy table for speedup models case (#92)
Browse files Browse the repository at this point in the history
  • Loading branch information
Wh1isper committed Dec 28, 2023
1 parent f447d8d commit 76e3eb0
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 8 deletions.
46 changes: 46 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@

os.environ["SDG_NDARRAY_CACHE_ROOT"] = "/tmp/sdgx/ndarray_cache"

import random
import shutil
import string
from functools import partial

import pandas as pd
import pytest

from sdgx.data_connectors.csv_connector import CsvConnector
Expand All @@ -18,6 +21,49 @@
DATA_DIR = os.path.join(_HERE, "dataset")


def ramdon_str():
return "".join(random.choice(string.ascii_letters) for _ in range(10))


@pytest.fixture
def dummy_single_table_path(tmp_path):
dummy_size = 10
role_set = ["admin", "user", "guest"]

df = pd.DataFrame(
{
"role": [random.choice(role_set) for _ in range(dummy_size)],
"name": [ramdon_str() for _ in range(dummy_size)],
"feature_x": [random.random() for _ in range(dummy_size)],
"feature_y": [random.random() for _ in range(dummy_size)],
"feature_z": [random.random() for _ in range(dummy_size)],
}
)
save_path = tmp_path / "dummy.csv"
df.to_csv(save_path, index=False, header=True)
yield save_path
save_path.unlink()


@pytest.fixture
def dummy_single_table_data_connector(dummy_single_table_path):
yield CsvConnector(
path=dummy_single_table_path,
)


@pytest.fixture
def dummy_single_table_data_loader(dummy_single_table_data_connector, cacher_kwargs):
d = DataLoader(dummy_single_table_data_connector, cacher_kwargs=cacher_kwargs)
yield d
d.finalize()


@pytest.fixture
def dummy_single_table_metadata(dummy_single_table_data_loader):
yield Metadata.from_dataloader(dummy_single_table_data_loader)


@pytest.fixture
def demo_single_table_path():
yield download_demo_data(DATA_DIR).as_posix()
Expand Down
19 changes: 16 additions & 3 deletions tests/models/test_copula.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,26 @@
from pathlib import Path

import pandas as pd
import pytest

from sdgx.models.statistics.single_table.copula import GaussianCopulaSynthesizer
from sdgx.utils import get_demo_single_table


def test_gaussian_copula(demo_single_table_path):
demo_data, discrete_cols = get_demo_single_table(Path(demo_single_table_path).parent)
@pytest.fixture
def dummy_data(dummy_single_table_path):
yield pd.read_csv(dummy_single_table_path)


@pytest.fixture
def discrete_cols(dummy_data):
yield [col for col in dummy_data.columns if not col.startswith("feature")]


def test_gaussian_copula(dummy_data, discrete_cols):
model = GaussianCopulaSynthesizer(discrete_cols)
model.fit(demo_data)
model.fit(dummy_data)

sampled_data = model.sample(10)
assert len(sampled_data) == 10
assert sampled_data.columns.tolist() == dummy_data.columns.tolist()
20 changes: 15 additions & 5 deletions tests/models/test_ctgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,27 @@ def save_model_dir(tmp_path):
shutil.rmtree(dirname, ignore_errors=True)


def assert_sampled_data(dummy_single_table_data_loader, sampled_data, count):
assert len(sampled_data) == count
assert sampled_data.columns.tolist() == dummy_single_table_data_loader.columns()


def test_ctgan(
ctgan: CTGANSynthesizerModel,
demo_single_table_metadata,
demo_single_table_data_loader,
dummy_single_table_metadata,
dummy_single_table_data_loader,
save_model_dir,
):
ctgan.fit(demo_single_table_metadata, demo_single_table_data_loader)
ctgan.sample(10)
ctgan.fit(dummy_single_table_metadata, dummy_single_table_data_loader)
sampled_data = ctgan.sample(10)
assert_sampled_data(dummy_single_table_data_loader, sampled_data, 10)

ctgan.save(save_model_dir)
assert save_model_dir.exists()

model = CTGANSynthesizerModel.load(save_model_dir)
model.sample(10)
sampled_data = model.sample(10)
assert_sampled_data(dummy_single_table_data_loader, sampled_data, 10)


if __name__ == "__main__":
Expand Down

0 comments on commit 76e3eb0

Please sign in to comment.