[0.1.0] Refactoring CTGAN for DataLoader (#72)
* Rewrite ctgan based on MIT-licensed code and imp Synthesizer

* Add cov setting

* Fixing covrc

* Fixing cov

* Change cov command

* Dropping TorchSynthesizerModel

* Support random access in dataloader

* Switch to optimized ctgan

* Fix annotations

* Fix testing

* Improve col slice performance

* Fix missing read_csv_kwargs
Wh1isper committed Dec 18, 2023
1 parent 0814523 commit 2dab1e1
Showing 30 changed files with 530 additions and 496 deletions.
4 changes: 4 additions & 0 deletions .coveragerc
@@ -0,0 +1,4 @@
[run]
omit =
*/tests/*
*/sdgx/models/components/sdv_*
2 changes: 1 addition & 1 deletion .github/workflows/python-package.yml
@@ -27,7 +27,7 @@ jobs:
python -m pip install -e .[test]
- name: Test with pytest
run: |
pytest -vv --cov=sdgx tests
pytest -vv --cov-config=.coveragerc --cov=sdgx/ tests
- name: Install dependencies for building
run: |
pip install build twine hatch
3 changes: 1 addition & 2 deletions CONTRIBUTING.md
@@ -30,8 +30,7 @@ pip install -e .[test]
We use pytest to write unit tests, and use pytest-cov to generate coverage reports

```bash
pytest -v
pytest --cov=sdgx # Generate coverage reports
pytest -vv --cov-config=.coveragerc --cov=sdgx/ tests
```

Run unit-test before PR, **ensure that new features are covered by unit tests**
2 changes: 1 addition & 1 deletion README.md
@@ -38,7 +38,7 @@ High-quality synthetic data can also be used in various fields such as data open
- Provide distributed training support for deep learning models with frameworks such as torch.
- Privacy enhancements:
- SDG supports differential privacy, anonymization and other methods to enhance the security of synthetic data.
- Easy to Extend
- Easy to extend
- Supports expansion of models, data processing, data connectors, etc. in the form of plug-in packages

Read [the latest API docs](https://synthetic-data-generator.readthedocs.io/en/latest/) for more details.
3 changes: 1 addition & 2 deletions docs/source/developer_guides/contributing.rst
@@ -40,8 +40,7 @@ coverage reports

.. code:: bash
pytest -v
pytest --cov=sdgx # Generate coverage reports
pytest -vv --cov-config=.coveragerc --cov=sdgx/ tests # Generate coverage reports
Run unit-test before PR, **ensure that new features are covered by unit
tests**
6 changes: 4 additions & 2 deletions sdgx/cachers/base.py
@@ -37,8 +37,10 @@ def load_all(self, data_connector: DataConnector) -> pd.DataFrame:
"""
Load all data from data_connector or cache
"""

raise NotImplementedError
return pd.concat(
self.iter(chunksize=self.blocksize, data_connector=data_connector),
ignore_index=True,
)

def clear_cache(self):
"""
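With `load_all` now implemented on the base `Cacher` in terms of `iter`, the duplicate `load_all` methods in the disk and memory cachers below can simply be dropped; a subclass only has to supply the chunk-level `load`/`iter` methods. A minimal sketch of a custom cacher under this change (the class name and its behaviour are illustrative, not part of the commit):

```python
from __future__ import annotations

from typing import Generator

import pandas as pd

from sdgx.cachers.base import Cacher
from sdgx.data_connectors.base import DataConnector


class PassthroughCache(Cacher):
    """Hypothetical cacher that stores nothing and always re-reads the connector."""

    def load(self, offset: int, chunksize: int, data_connector: DataConnector) -> pd.DataFrame:
        # Read a single chunk straight from the data connector.
        return data_connector.read(offset=offset, limit=chunksize)

    def iter(
        self, chunksize: int, data_connector: DataConnector
    ) -> Generator[pd.DataFrame, None, None]:
        # Delegate chunked iteration to the connector.
        yield from data_connector.iter(chunksize=chunksize)

    # load_all is inherited from Cacher: it concatenates the chunks from self.iter().
```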
11 changes: 2 additions & 9 deletions sdgx/cachers/disk_cache.py
@@ -1,5 +1,6 @@
from __future__ import annotations

from functools import lru_cache
from pathlib import Path
from typing import Generator

@@ -81,6 +82,7 @@ def _refresh(self, offset: int, data: pd.DataFrame) -> None:
else:
data.to_parquet(self._get_cache_filename(offset))

@lru_cache(maxsize=64)
def load(self, offset: int, chunksize: int, data_connector: DataConnector) -> pd.DataFrame:
"""
Load data from data_connector or cache
@@ -106,15 +108,6 @@ def load(self, offset: int, chunksize: int, data_connector: DataConnector) -> pd
return data
return data[:chunksize]

def load_all(self, data_connector: DataConnector) -> pd.DataFrame:
"""
Load all data from data_connector or cache
"""
return pd.concat(
self.iter(chunksize=self.blocksize, data_connector=data_connector),
ignore_index=True,
)

def iter(
self, chunksize: int, data_connector: DataConnector
) -> Generator[pd.DataFrame, None, None]:
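The `@lru_cache(maxsize=64)` added to `DiskCache.load` memoizes chunks per `(self, offset, chunksize, data_connector)`, which is what makes the repeated chunk-aligned loads from `DataLoader.__getitem__` cheap. The trade-off of this pattern is that all arguments (and `self`) must be hashable and that up to 64 returned DataFrames stay referenced by the cache. A toy, self-contained illustration of the same pattern (not taken from the commit):

```python
from functools import lru_cache


class ChunkReader:
    """Stand-in for DiskCache: load() is memoized per (self, offset, chunksize)."""

    def __init__(self) -> None:
        self.disk_reads = 0

    @lru_cache(maxsize=64)
    def load(self, offset: int, chunksize: int) -> list:
        self.disk_reads += 1  # pretend this is an expensive parquet read
        return list(range(offset, offset + chunksize))


reader = ChunkReader()
reader.load(0, 100)
reader.load(0, 100)  # second call is served from the LRU cache
assert reader.disk_reads == 1
```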
7 changes: 0 additions & 7 deletions sdgx/cachers/memory_cache.py
@@ -64,13 +64,6 @@ def load(self, offset: int, chunksize: int, data_connector: DataConnector) -> pd
return data
return data[:chunksize]

def load_all(self, data_connector: DataConnector) -> pd.DataFrame:
# Concat all dataframe
return pd.concat(
self.iter(chunksize=self.blocksize, data_connector=data_connector),
ignore_index=True,
)

def iter(
self, chunksize: int, data_connector: DataConnector
) -> Generator[pd.DataFrame, None, None]:
8 changes: 4 additions & 4 deletions sdgx/data_connectors/base.py
@@ -17,7 +17,7 @@ class DataConnector:
Identity of data source, e.g. table name, hash of content
"""

def _read(self, offset=0, limit=None) -> pd.DataFrame:
def _read(self, offset: int = 0, limit: int | None = None) -> pd.DataFrame:
"""
Subclass must implement this for reading data.
@@ -33,15 +33,15 @@ def _columns(self) -> list[str]:
"""
raise NotImplementedError

def _iter(self, offset=0, chunksize=0) -> Generator[pd.DataFrame, None, None]:
def _iter(self, offset: int = 0, chunksize: int = 0) -> Generator[pd.DataFrame, None, None]:
"""
Subclass should implement this for reading data in chunk.
See ``iter`` for more details.
"""
raise NotImplementedError

def iter(self, offset=0, chunksize=0) -> Generator[pd.DataFrame, None, None]:
def iter(self, offset: int = 0, chunksize: int = 0) -> Generator[pd.DataFrame, None, None]:
"""
Interface for reading data in chunk.
@@ -54,7 +54,7 @@ def iter(self, offset=0, chunksize=0) -> Generator[pd.DataFrame, None, None]:
"""
return self._iter(offset, chunksize)

def read(self, offset=0, limit=None) -> pd.DataFrame:
def read(self, offset: int = 0, limit: int | None = None) -> pd.DataFrame:
"""
Interface for reading data.
9 changes: 5 additions & 4 deletions sdgx/data_connectors/csv_connector.py
@@ -52,13 +52,13 @@ def __init__(
self.header = header
self.read_csv_kwargs = read_csv_kwargs

def _read(self, offset=0, limit=None) -> pd.DataFrame:
def _read(self, offset: int = 0, limit: int | None = None) -> pd.DataFrame:
""" """
return pd.read_csv(
self.path,
sep=self.sep,
header=self.header,
skiprows=offset,
skiprows=range(1, offset),
nrows=limit,
**self.read_csv_kwargs,
)
@@ -69,10 +69,11 @@ def _columns(self) -> list[str]:
sep=self.sep,
header=self.header,
nrows=0,
**self.read_csv_kwargs,
).columns.tolist()
return d

def _iter(self, offset=0, chunksize=1000) -> Generator[pd.DataFrame, None, None]:
def _iter(self, offset: int = 0, chunksize: int = 1000) -> Generator[pd.DataFrame, None, None]:
if chunksize is None:
yield self._read(offset=offset)
return
@@ -81,7 +82,7 @@ def _iter(self, offset=0, chunksize=1000) -> Generator[pd.DataFrame, None, None]
self.path,
sep=self.sep,
header=self.header,
skiprows=offset,
skiprows=range(1, offset),
chunksize=chunksize,
**self.read_csv_kwargs,
):
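The `skiprows` change is what keeps the header intact when reading from an offset: an integer `skiprows` drops the first `offset` physical lines including the header, while `skiprows=range(1, offset)` keeps line 0 (the header) and skips lines 1 through `offset - 1`, so reading resumes at physical line `offset` with the original column names. A small illustration (the CSV content is made up):

```python
import io

import pandas as pd

csv_text = "a,b\n1,10\n2,20\n3,30\n4,40\n"

# An integer skiprows drops the header line itself once the offset is non-zero.
bad = pd.read_csv(io.StringIO(csv_text), skiprows=2)
print(list(bad.columns))      # ['2', '20'] -- a data row became the header

# A range keeps line 0 (the header) and skips the following data lines.
good = pd.read_csv(io.StringIO(csv_text), skiprows=range(1, 3))
print(list(good.columns))     # ['a', 'b']
print(good.iloc[0].tolist())  # [3, 30] -- reading resumes at physical line 3
```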
52 changes: 52 additions & 0 deletions sdgx/data_loader.py
@@ -1,12 +1,14 @@
from __future__ import annotations

from functools import cached_property
from typing import Any, Generator

import pandas as pd

from sdgx.cachers.base import Cacher
from sdgx.cachers.manager import CacherManager
from sdgx.data_connectors.base import DataConnector
from sdgx.utils import cache


class DataLoader:
@@ -78,3 +80,53 @@ def finalize(self, clear_cache=False) -> None:
self.data_connector.finalize()
if clear_cache:
self.cacher.clear_cache()

def __getitem__(self, key: list | slice | tuple) -> pd.DataFrame:
"""
Support get data by index and slice
        Warning:
            This is very tricky when using :ref:`GeneratorConnector` with a :ref:`Cacher`.
            Calling ``len`` iterates the generator and stores all data in the cache,
            after which ``load`` can read it back from the cache, so indexed access is correct.
            If a :ref:`GeneratorConnector` is used with :ref:`NoCache`, the indices will be wrong
            and random access may be completely broken.
"""
if isinstance(key, list):
sli = None
rows = key
else:
sli = key
rows = None

if not sli:
return pd.concat((d[rows] for d in self.iter()), ignore_index=True)

start = sli.start or 0
stop = sli.stop or len(self)
step = sli.step or 1

offset = (start // self.chunksize) * self.chunksize
n_iter = ((stop - start) // self.chunksize) + 1

tables = (
self.cacher.load(
offset=offset + i * self.chunksize,
chunksize=self.chunksize,
data_connector=self.data_connector,
)
for i in range(n_iter)
)

return pd.concat(tables, ignore_index=True)[start - offset : stop - offset : step]

@cache
def __len__(self):
return sum(len(l) for l in self.iter())

@cached_property
def shape(self):
return (len(self), len(self.columns()))
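Together these additions give `DataLoader` NumPy-style random access on top of chunked caching: a slice is mapped onto chunk-aligned `cacher.load` calls and trimmed afterwards, a list key selects columns across all chunks, `__len__` is computed once and cached, and `shape` follows the pandas `(rows, columns)` convention. A hedged usage sketch; the connector class name, constructor arguments, and column names below are assumptions for illustration, not taken from this commit:

```python
from pathlib import Path

from sdgx.data_connectors.csv_connector import CsvConnector  # class name assumed
from sdgx.data_loader import DataLoader

# Hypothetical CSV file with a header row.
connector = CsvConnector(path=Path("demo.csv"))
loader = DataLoader(connector)  # assuming the connector is the first argument

print(loader.shape)   # (n_rows, n_columns), computed lazily and then cached
print(len(loader))    # row count, summed over chunks on first use

head = loader[0:10]               # slice -> chunk-aligned cache loads, then trimmed
cols = loader[["age", "income"]]  # hypothetical column names; list key selects columns
```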
12 changes: 12 additions & 0 deletions sdgx/data_models/metadata.py
@@ -1,6 +1,8 @@
from __future__ import annotations

import json
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List

import pandas as pd
@@ -88,3 +90,13 @@ def from_dataframe(
metadata.update(inspector.inspect())

return metadata

def save(self, path: str | Path):
with path.open("w") as f:
f.write(self.model_dump_json())

@classmethod
def load(cls, path: str | Path) -> "Metadata":
path = Path(path).expanduser().resolve()
attributes = json.load(path.open("r"))
return Metadata().update(attributes)
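The new `save`/`load` pair round-trips the metadata through JSON via pydantic's `model_dump_json` and `Metadata().update(...)`. Note that `save` as written calls `path.open`, so it expects a `Path`-like object, while `load` normalizes string paths itself. A hedged usage sketch (the DataFrame content is made up, and `from_dataframe` is assumed to accept the frame as its first argument):

```python
from pathlib import Path

import pandas as pd

from sdgx.data_models.metadata import Metadata

# Toy data purely for illustration.
df = pd.DataFrame({"age": [25, 32, 47], "city": ["Berlin", "Lyon", "Oslo"]})

metadata = Metadata.from_dataframe(df)      # inspect the frame and build metadata
metadata.save(Path("metadata.json"))        # pass a Path: save() calls path.open()
restored = Metadata.load("metadata.json")   # load() resolves str paths on its own
```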
13 changes: 9 additions & 4 deletions sdgx/manager.py
@@ -114,11 +114,16 @@ def init(self, cls_name, **kwargs: dict[str, Any]):
NotFoundError: if cls_name is not registered
InitializationError: if failed to initialize
"""
cls_name = self._normalize_name(cls_name)
if not cls_name in self.registed_cls:
raise NotFoundError
if isinstance(cls_name, type):
cls_type = cls_name
else:
cls_name = self._normalize_name(cls_name)

if not cls_name in self.registed_cls:
raise NotFoundError
cls_type = self.registed_cls[cls_name]
try:
instance = self.registed_cls[cls_name](**kwargs)
instance = cls_type(**kwargs)
if not isinstance(instance, self.register_type):
raise InitializationError(f"{cls_name} is not a subclass of {self.register_type}.")
return instance
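After this change `Manager.init` accepts either a registered name (normalized before the registry lookup) or a class object, which skips the registry entirely. A hedged sketch of the two call forms; the manager instantiation, the registered name, the cacher class name, and the `blocksize` keyword are assumptions, not taken from this commit:

```python
from sdgx.cachers.manager import CacherManager

manager = CacherManager()

# 1. By registered name: normalized, then looked up in the registry.
cacher = manager.init("DiskCache", blocksize=1000)

# 2. By class object: the registry lookup is skipped.
from sdgx.cachers.disk_cache import DiskCache  # class name assumed

cacher = manager.init(DiskCache, blocksize=1000)
```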
2 changes: 1 addition & 1 deletion sdgx/models/base.py
@@ -9,7 +9,7 @@


class SynthesizerModel:
def fit(metadata: Metadata, dataloader: DataLoader, *args, **kwargs):
def fit(self, metadata: Metadata, dataloader: DataLoader, *args, **kwargs):
raise NotImplementedError

def sample(self, count: int, *args, **kwargs) -> pd.DataFrame:
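Adding `self` to `fit` turns `SynthesizerModel` into a conventional instance-method interface: a concrete model overrides `fit` and `sample`. A minimal hedged sketch of a subclass (the model logic is a deliberately trivial placeholder, not from the commit):

```python
import pandas as pd

from sdgx.data_loader import DataLoader
from sdgx.data_models.metadata import Metadata
from sdgx.models.base import SynthesizerModel


class ConstantModel(SynthesizerModel):
    """Toy model: memorizes the first row and repeats it when sampling."""

    def fit(self, metadata: Metadata, dataloader: DataLoader, *args, **kwargs):
        # Uses the new DataLoader slicing support to grab one row.
        self._row = dataloader[0:1]

    def sample(self, count: int, *args, **kwargs) -> pd.DataFrame:
        return pd.concat([self._row] * count, ignore_index=True)
```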
Empty file.
Empty file.
@@ -1,16 +1,16 @@
"""
Refer CTGAN Version 0.6.0: https://github.com/sdv-dev/CTGAN@a40570e321cb46d798a823f350e1010a0270d804
Which is licensed under the MIT License
"""
"""DataSampler module."""
from __future__ import annotations

import numpy as np

from sdgx.data_loader import DataLoader

class DataSamplerCTGAN:

class DataSampler(object):
"""DataSampler samples the conditional vector and corresponding data for CTGAN."""

def __init__(self, data, output_info, log_frequency):
self._data = data
def __init__(self, dataloader: DataLoader | np.ndarray, output_info, log_frequency):
self._data: DataLoader | np.ndarray = dataloader

def is_discrete_column(column_info):
return len(column_info) == 1 and column_info[0].activation_fn == "softmax"
@@ -35,12 +35,12 @@ def is_discrete_column(column_info):

rid_by_cat = []
for j in range(span_info.dim):
rid_by_cat.append(np.nonzero(data[:, st + j])[0])
rid_by_cat.append(np.nonzero(dataloader[:, st + j])[0])
self._rid_by_cat_cols.append(rid_by_cat)
st = ed
else:
st += sum([span_info.dim for span_info in column_info])
assert st == data.shape[1]
assert st == dataloader.shape[1]

# Prepare an interval matrix for efficiently sample conditional vector
max_category = max(
@@ -63,7 +63,7 @@ if is_discrete_column(column_info):
if is_discrete_column(column_info):
span_info = column_info[0]
ed = st + span_info.dim
category_freq = np.sum(data[:, st:ed], axis=0)
category_freq = np.sum(dataloader[:, st:ed], axis=0)
if log_frequency:
category_freq = np.log(category_freq + 1)
category_prob = category_freq / np.sum(category_freq)