
[0.1.0] Refactoring CTGAN for DataLoader #72

Merged · 13 commits · Dec 18, 2023
4 changes: 4 additions & 0 deletions .coveragerc
@@ -0,0 +1,4 @@
[run]
omit =
*/tests/*
*/sdgx/models/components/sdv_*
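
These omit patterns keep the tests themselves and the vendored `sdv_*` components out of the coverage report; the workflow and contributing-guide changes below point pytest-cov at this file via `--cov-config=.coveragerc`.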
2 changes: 1 addition & 1 deletion .github/workflows/python-package.yml
@@ -27,7 +27,7 @@ jobs:
python -m pip install -e .[test]
- name: Test with pytest
run: |
pytest -vv --cov=sdgx tests
pytest -vv --cov-config=.coveragerc --cov=sdgx/ tests
- name: Install dependencies for building
run: |
pip install build twine hatch
3 changes: 1 addition & 2 deletions CONTRIBUTING.md
@@ -30,8 +30,7 @@ pip install -e .[test]
We use pytest to write unit tests and pytest-cov to generate coverage reports.

```bash
pytest -v
pytest --cov=sdgx # Generate coverage reports
pytest -vv --cov-config=.coveragerc --cov=sdgx/ tests
```

Run the unit tests before opening a PR, and **ensure that new features are covered by unit tests**.
2 changes: 1 addition & 1 deletion README.md
@@ -38,7 +38,7 @@ High-quality synthetic data can also be used in various fields such as data open
- Provide distributed training support for deep learning models with frameworks such as torch.
- Privacy enhancements:
- SDG supports differential privacy, anonymization and other methods to enhance the security of synthetic data.
- Easy to Extend
- Easy to extend
- Supports expansion of models, data processing, data connectors, etc. in the form of plug-in packages

Read [the latest API docs](https://synthetic-data-generator.readthedocs.io/en/latest/) for more details.
3 changes: 1 addition & 2 deletions docs/source/developer_guides/contributing.rst
@@ -40,8 +40,7 @@ coverage reports

.. code:: bash

pytest -v
pytest --cov=sdgx # Generate coverage reports
pytest -vv --cov-config=.coveragerc --cov=sdgx/ tests # Generate coverage reports

Run the unit tests before opening a PR, and **ensure that new features are covered by unit tests**
6 changes: 4 additions & 2 deletions sdgx/cachers/base.py
@@ -37,8 +37,10 @@ def load_all(self, data_connector: DataConnector) -> pd.DataFrame:
"""
Load all data from data_connector or cache
"""

raise NotImplementedError
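# Default shared implementation: concatenate every chunk yielded by iter()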
return pd.concat(
self.iter(chunksize=self.blocksize, data_connector=data_connector),
ignore_index=True,
)

def clear_cache(self):
"""
11 changes: 2 additions & 9 deletions sdgx/cachers/disk_cache.py
@@ -1,5 +1,6 @@
from __future__ import annotations

from functools import lru_cache
from pathlib import Path
from typing import Generator

@@ -81,6 +82,7 @@ def _refresh(self, offset: int, data: pd.DataFrame) -> None:
else:
data.to_parquet(self._get_cache_filename(offset))

@lru_cache(maxsize=64)
def load(self, offset: int, chunksize: int, data_connector: DataConnector) -> pd.DataFrame:
"""
Load data from data_connector or cache
@@ -106,15 +108,6 @@ def load(self, offset: int, chunksize: int, data_connector: DataConnector) -> pd
return data
return data[:chunksize]

def load_all(self, data_connector: DataConnector) -> pd.DataFrame:
"""
Load all data from data_connector or cache
"""
return pd.concat(
self.iter(chunksize=self.blocksize, data_connector=data_connector),
ignore_index=True,
)

def iter(
self, chunksize: int, data_connector: DataConnector
) -> Generator[pd.DataFrame, None, None]:
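Hoisting `load_all` into the `Cacher` base class (above) lets `DiskCache` drop its own copy, and the new `lru_cache` on `load` means the repeated chunk reads issued by `DataLoader.__getitem__` can be served from memory instead of re-read from parquet, keyed on `offset`, `chunksize`, and the connector instance.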
7 changes: 0 additions & 7 deletions sdgx/cachers/memory_cache.py
@@ -64,13 +64,6 @@ def load(self, offset: int, chunksize: int, data_connector: DataConnector) -> pd
return data
return data[:chunksize]

def load_all(self, data_connector: DataConnector) -> pd.DataFrame:
# Concat all dataframe
return pd.concat(
self.iter(chunksize=self.blocksize, data_connector=data_connector),
ignore_index=True,
)

def iter(
self, chunksize: int, data_connector: DataConnector
) -> Generator[pd.DataFrame, None, None]:
8 changes: 4 additions & 4 deletions sdgx/data_connectors/base.py
@@ -17,7 +17,7 @@ class DataConnector:
Identity of data source, e.g. table name, hash of content
"""

def _read(self, offset=0, limit=None) -> pd.DataFrame:
def _read(self, offset: int = 0, limit: int | None = None) -> pd.DataFrame:
"""
Subclasses must implement this to read data.

@@ -33,15 +33,15 @@ def _columns(self) -> list[str]:
"""
raise NotImplementedError

def _iter(self, offset=0, chunksize=0) -> Generator[pd.DataFrame, None, None]:
def _iter(self, offset: int = 0, chunksize: int = 0) -> Generator[pd.DataFrame, None, None]:
"""
Subclasses should implement this to read data in chunks.

See ``iter`` for more details.
"""
raise NotImplementedError

def iter(self, offset=0, chunksize=0) -> Generator[pd.DataFrame, None, None]:
def iter(self, offset: int = 0, chunksize: int = 0) -> Generator[pd.DataFrame, None, None]:
"""
Interface for reading data in chunks.

@@ -54,7 +54,7 @@ def iter(self, offset=0, chunksize=0) -> Generator[pd.DataFrame, None, None]:
"""
return self._iter(offset, chunksize)

def read(self, offset=0, limit=None) -> pd.DataFrame:
def read(self, offset: int = 0, limit: int | None = None) -> pd.DataFrame:
"""
Interface for reading data.

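
The three hooks above (`_read`, `_columns`, `_iter`) are the whole connector contract. As a sketch only, a hypothetical in-memory connector could look like this (the class name and the guard against the base-class `chunksize=0` default are illustrative, not part of this PR):

```python
from __future__ import annotations

from typing import Generator

import pandas as pd

from sdgx.data_connectors.base import DataConnector


class InMemoryConnector(DataConnector):
    """Hypothetical connector serving a DataFrame already held in memory."""

    def __init__(self, df: pd.DataFrame):
        self.df = df

    def _read(self, offset: int = 0, limit: int | None = None) -> pd.DataFrame:
        # limit=None means "read to the end", mirroring the base docstring
        end = None if limit is None else offset + limit
        return self.df.iloc[offset:end]

    def _columns(self) -> list[str]:
        return self.df.columns.tolist()

    def _iter(self, offset: int = 0, chunksize: int = 1000) -> Generator[pd.DataFrame, None, None]:
        # Treat chunksize=0 (the base-class default) as "one single chunk"
        chunksize = chunksize or len(self.df) or 1
        for start in range(offset, len(self.df), chunksize):
            yield self.df.iloc[start : start + chunksize]
```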
9 changes: 5 additions & 4 deletions sdgx/data_connectors/csv_connector.py
@@ -52,13 +52,13 @@ def __init__(
self.header = header
self.read_csv_kwargs = read_csv_kwargs

def _read(self, offset=0, limit=None) -> pd.DataFrame:
def _read(self, offset: int = 0, limit: int | None = None) -> pd.DataFrame:
""" """
return pd.read_csv(
self.path,
sep=self.sep,
header=self.header,
skiprows=offset,
skiprows=range(1, offset + 1),
nrows=limit,
**self.read_csv_kwargs,
)
@@ -69,10 +69,11 @@ def _columns(self) -> list[str]:
sep=self.sep,
header=self.header,
nrows=0,
**self.read_csv_kwargs,
).columns.tolist()
return d

def _iter(self, offset=0, chunksize=1000) -> Generator[pd.DataFrame, None, None]:
def _iter(self, offset: int = 0, chunksize: int = 1000) -> Generator[pd.DataFrame, None, None]:
if chunksize is None:
yield self._read(offset=offset)
return
@@ -81,7 +82,7 @@ def _iter(self, offset=0, chunksize=1000) -> Generator[pd.DataFrame, None, None]
self.path,
sep=self.sep,
header=self.header,
skiprows=offset,
skiprows=range(1, offset + 1),
chunksize=chunksize,
**self.read_csv_kwargs,
):
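
The `skiprows` change above is the behavioral fix in this file: an integer `skiprows` discards the first N physical lines *including the header*, so offset reads lost their column names. A range starting at 1 preserves line 0 (the header) and skips only data rows. A standalone sketch of the difference (the inline CSV is illustrative):

```python
import io

import pandas as pd

CSV = "a,b\n1,2\n3,4\n5,6\n"
offset = 2

# Integer: skips the first 2 physical lines, header included -> columns become "3","4"
broken = pd.read_csv(io.StringIO(CSV), skiprows=offset)

# Range starting at 1: the header survives, only the first 2 data rows are skipped
fixed = pd.read_csv(io.StringIO(CSV), skiprows=range(1, offset + 1))
print(fixed.columns.tolist())  # ['a', 'b']
```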
52 changes: 52 additions & 0 deletions sdgx/data_loader.py
@@ -1,12 +1,14 @@
from __future__ import annotations

from functools import cached_property
from typing import Any, Generator

import pandas as pd

from sdgx.cachers.base import Cacher
from sdgx.cachers.manager import CacherManager
from sdgx.data_connectors.base import DataConnector
from sdgx.utils import cache


class DataLoader:
@@ -78,3 +80,53 @@ def finalize(self, clear_cache=False) -> None:
self.data_connector.finalize()
if clear_cache:
self.cacher.clear_cache()

def __getitem__(self, key: list | slice | tuple) -> pd.DataFrame:
"""
Support get data by index and slice

Warning:

This is very tricky when using :ref:`GeneratorConnector` with a :ref:`Cacher`.
When calling ``len``, will iterate and store all data in cache.
Then we can ``load`` the data from cache. This makes accessing data in correct index.

If using :ref:`GeneratorConnector` with :ref:`NoCache`, the index will be wrong
and this may totally broken.

"""
if isinstance(key, list):
sli = None
rows = key
else:
sli = key
rows = None

if not sli:
return pd.concat((d[rows] for d in self.iter()), ignore_index=True)

start = sli.start or 0
stop = sli.stop or len(self)
step = sli.step or 1

offset = (start // self.chunksize) * self.chunksize
n_iter = ((stop - start) // self.chunksize) + 1

tables = (
self.cacher.load(
offset=offset + i * self.chunksize,
chunksize=self.chunksize,
data_connector=self.data_connector,
)
for i in range(n_iter)
)

return pd.concat(tables, ignore_index=True)[start - offset : stop - offset : step]

@cache
def __len__(self):
return sum(len(l) for l in self.iter())

@cached_property
def shape(self):
return (len(self), len(self.columns()))
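
Together, `__getitem__`, `__len__`, and `shape` give the loader an ndarray-like surface. A minimal usage sketch, assuming a CSV-backed loader (the file path, column names, and default constructor arguments are assumptions, not part of this diff):

```python
from sdgx.data_connectors.csv_connector import CsvConnector
from sdgx.data_loader import DataLoader

dataloader = DataLoader(CsvConnector(path="data.csv"))

n_rows = len(dataloader)       # one full iteration; memoized by @cache afterwards
n_rows, n_cols = dataloader.shape
head = dataloader[0:10]        # chunk-aligned cacher loads, then a pandas slice
cols = dataloader[["age", "income"]]  # a list key is forwarded to each chunk, i.e. column selection
```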
12 changes: 12 additions & 0 deletions sdgx/data_models/metadata.py
@@ -1,6 +1,8 @@
from __future__ import annotations

import json
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List

import pandas as pd
@@ -88,3 +90,13 @@ def from_dataframe(
metadata.update(inspector.inspect())

return metadata

def save(self, path: str | Path):
# Coerce to Path so the annotated str type actually works
path = Path(path).expanduser().resolve()
with path.open("w") as f:
f.write(self.model_dump_json())

@classmethod
def load(cls, path: str | Path) -> "Metadata":
path = Path(path).expanduser().resolve()
with path.open("r") as f:
attributes = json.load(f)
return Metadata().update(attributes)
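
A round-trip sketch for the new persistence helpers (the DataFrame content and file name are illustrative):

```python
import pandas as pd

from sdgx.data_models.metadata import Metadata

df = pd.DataFrame({"age": [20, 30], "name": ["a", "b"]})
metadata = Metadata.from_dataframe(df)

metadata.save("metadata.json")            # serialized via pydantic's model_dump_json
restored = Metadata.load("metadata.json")
```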
13 changes: 9 additions & 4 deletions sdgx/manager.py
@@ -114,11 +114,16 @@ def init(self, cls_name, **kwargs: dict[str, Any]):
NotFoundError: if cls_name is not registered
InitializationError: if failed to initialize
"""
cls_name = self._normalize_name(cls_name)
if not cls_name in self.registed_cls:
raise NotFoundError
if isinstance(cls_name, type):
cls_type = cls_name
else:
cls_name = self._normalize_name(cls_name)

if cls_name not in self.registed_cls:
raise NotFoundError
cls_type = self.registed_cls[cls_name]
try:
instance = self.registed_cls[cls_name](**kwargs)
instance = cls_type(**kwargs)
if not isinstance(instance, self.register_type):
raise InitializationError(f"{cls_name} is not a subclass of {self.register_type}.")
return instance
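
With this change, `init` accepts either a registered name (normalized as before) or a class object directly. A sketch, assuming `MemoryCache` is registered with `CacherManager` and takes a `blocksize` keyword (both assumptions, not shown in this diff):

```python
from sdgx.cachers.manager import CacherManager
from sdgx.cachers.memory_cache import MemoryCache

manager = CacherManager()

# Lookup by registered name, as before
cacher = manager.init("MemoryCache", blocksize=1000)

# New: pass the class itself, skipping name normalization and registry lookup
cacher = manager.init(MemoryCache, blocksize=1000)
```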
2 changes: 1 addition & 1 deletion sdgx/models/base.py
@@ -9,7 +9,7 @@


class SynthesizerModel:
def fit(metadata: Metadata, dataloader: DataLoader, *args, **kwargs):
def fit(self, metadata: Metadata, dataloader: DataLoader, *args, **kwargs):
raise NotImplementedError

def sample(self, count: int, *args, **kwargs) -> pd.DataFrame:
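
Without `self`, the bound call `model.fit(metadata, dataloader)` silently shifted every argument by one: the `metadata` parameter received the model instance and `dataloader` received the metadata.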
Empty file.
Empty file.
@@ -1,16 +1,16 @@
"""
Refer CTGAN Version 0.6.0: https://github.com/sdv-dev/CTGAN@a40570e321cb46d798a823f350e1010a0270d804
Which is Lincensed by MIT License
"""
"""DataSampler module."""
from __future__ import annotations

import numpy as np

from sdgx.data_loader import DataLoader

class DataSamplerCTGAN:

class DataSampler(object):
"""DataSampler samples the conditional vector and corresponding data for CTGAN."""

def __init__(self, data, output_info, log_frequency):
self._data = data
def __init__(self, dataloader: DataLoader | np.ndarray, output_info, log_frequency):
self._data: DataLoader | np.ndarray = dataloader

def is_discrete_column(column_info):
return len(column_info) == 1 and column_info[0].activation_fn == "softmax"
@@ -35,12 +35,12 @@ def is_discrete_column(column_info):

rid_by_cat = []
for j in range(span_info.dim):
rid_by_cat.append(np.nonzero(data[:, st + j])[0])
rid_by_cat.append(np.nonzero(dataloader[:, st + j])[0])
self._rid_by_cat_cols.append(rid_by_cat)
st = ed
else:
st += sum([span_info.dim for span_info in column_info])
assert st == data.shape[1]
assert st == dataloader.shape[1]

# Prepare an interval matrix for efficiently sample conditional vector
max_category = max(
@@ -63,7 +63,7 @@ def is_discrete_column(column_info):
if is_discrete_column(column_info):
span_info = column_info[0]
ed = st + span_info.dim
category_freq = np.sum(data[:, st:ed], axis=0)
category_freq = np.sum(dataloader[:, st:ed], axis=0)
if log_frequency:
category_freq = np.log(category_freq + 1)
category_prob = category_freq / np.sum(category_freq)
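
This signature change ties back to the new `DataLoader` interface: with `__getitem__`, `shape`, and `__len__` in place, the sampler is meant to keep its ndarray-style indexing (`dataloader[:, st + j]`, `dataloader.shape[1]`) whether it receives a `DataLoader` or a raw `np.ndarray`.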