Skip to content

Commit

Permalink
Deprecate MultiTableMetadata in favor of using Metadata, add compat t…
Browse files Browse the repository at this point in the history
…ests
  • Loading branch information
lajohn4747 committed Aug 5, 2024
1 parent fb2c2d5 commit f22d815
Show file tree
Hide file tree
Showing 4 changed files with 214 additions and 8 deletions.
11 changes: 9 additions & 2 deletions sdv/multi_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,14 @@
)
from sdv.logging import disable_single_table_logger, get_sdv_logger
from sdv.metadata.metadata import Metadata
from sdv.metadata.multi_table import MultiTableMetadata
from sdv.single_table.copulas import GaussianCopulaSynthesizer

# Logger obtained from ``get_sdv_logger`` under the name 'MultiTableSynthesizer'.
SYNTHESIZER_LOGGER = get_sdv_logger('MultiTableSynthesizer')
# Text of the ``FutureWarning`` issued by ``BaseMultiTableSynthesizer.__init__`` when the
# metadata passed in is exactly a ``MultiTableMetadata`` instance.
DEPRECATION_MSG = (
"The 'MultiTableMetadata' is deprecated. Please use the new "
"'Metadata' class for synthesizers."
)


class BaseMultiTableSynthesizer:
Expand Down Expand Up @@ -99,6 +104,8 @@ def _check_metadata_updated(self):

def __init__(self, metadata, locales=['en_US'], synthesizer_kwargs=None):
self.metadata = metadata
if type(metadata) is MultiTableMetadata:
warnings.warn(DEPRECATION_MSG, FutureWarning)
with warnings.catch_warnings():
warnings.filterwarnings('ignore', message=r'.*column relationship.*')
self.metadata.validate()
Expand Down Expand Up @@ -206,8 +213,8 @@ def set_table_parameters(self, table_name, table_parameters):
self._table_parameters[table_name].update(deepcopy(table_parameters))

def get_metadata(self):
"""Return the ``MultiTableMetadata`` for this synthesizer."""
return self.metadata
"""Return the ``Metadata`` for this synthesizer."""
return Metadata.load_from_dict(self.metadata.to_dict())

def _validate_all_tables(self, data):
"""Validate every table of the data has a valid table/metadata pair."""
Expand Down
154 changes: 154 additions & 0 deletions tests/integration/metadata/test_metadata.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,16 @@
import os
import re

import pytest

from sdv.datasets.demo import download_demo
from sdv.metadata.metadata import Metadata
from sdv.metadata.multi_table import MultiTableMetadata
from sdv.metadata.single_table import SingleTableMetadata
from sdv.multi_table.hma import HMASynthesizer
from sdv.single_table.copulas import GaussianCopulaSynthesizer

DEFAULT_TABLE_NAME = 'default_table_name'


def test_metadata():
Expand Down Expand Up @@ -216,3 +227,146 @@ def test_detect_table_from_csv(tmp_path):
}

assert metadata.to_dict() == expected_metadata


def test_single_table_compatibility(tmp_path):
    """Test that ``SingleTableMetadata`` still works with single table synthesizers.

    A synthesizer built from the deprecated ``SingleTableMetadata`` must emit a
    ``FutureWarning``, survive a save/load round trip, and sample the same columns
    as a synthesizer built from the equivalent ``Metadata`` object.

    Args:
        tmp_path (pathlib.Path):
            Temporary directory provided by pytest, used to store the pickled synthesizer.
    """
    # Setup
    data, _ = download_demo('single_table', 'fake_hotel_guests')
    # ``pytest.warns(match=...)`` runs a regex search, so escape the literal message.
    warn_msg = re.escape(
        "The 'SingleTableMetadata' is deprecated. Please use the new "
        "'Metadata' class for synthesizers."
    )

    single_table_metadata_dict = {
        'primary_key': 'guest_email',
        'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1',
        'columns': {
            'guest_email': {'sdtype': 'email', 'pii': True},
            'has_rewards': {'sdtype': 'boolean'},
            'room_type': {'sdtype': 'categorical'},
            'amenities_fee': {'sdtype': 'numerical', 'computer_representation': 'Float'},
            'checkin_date': {'sdtype': 'datetime', 'datetime_format': '%d %b %Y'},
            'checkout_date': {'sdtype': 'datetime', 'datetime_format': '%d %b %Y'},
            'room_rate': {'sdtype': 'numerical', 'computer_representation': 'Float'},
            'billing_address': {'sdtype': 'address', 'pii': True},
            'credit_card_number': {'sdtype': 'credit_card_number', 'pii': True},
        },
    }
    metadata = SingleTableMetadata.load_from_dict(single_table_metadata_dict)
    assert isinstance(metadata, SingleTableMetadata)

    # Run
    with pytest.warns(FutureWarning, match=warn_msg):
        synthesizer = GaussianCopulaSynthesizer(metadata)

    synthesizer.fit(data)
    model_path = os.path.join(tmp_path, 'synthesizer.pkl')
    synthesizer.save(model_path)

    # Assert
    assert os.path.exists(model_path)
    assert os.path.isfile(model_path)
    loaded_synthesizer = GaussianCopulaSynthesizer.load(model_path)
    # Check the object that came back from disk, not the one that was saved.
    assert isinstance(loaded_synthesizer, GaussianCopulaSynthesizer)
    assert loaded_synthesizer.get_info() == synthesizer.get_info()
    assert loaded_synthesizer.metadata.to_dict() == metadata.to_dict()
    loaded_sample = loaded_synthesizer.sample(10)
    synthesizer.validate(loaded_sample)

    # Run against Metadata
    synthesizer_2 = GaussianCopulaSynthesizer(Metadata.load_from_dict(metadata.to_dict()))
    synthesizer_2.fit(data)
    # Sample from the ``Metadata``-based synthesizer so the comparison below is meaningful.
    metadata_sample = synthesizer_2.sample(10)
    assert loaded_synthesizer.metadata.to_dict() == synthesizer_2.metadata.to_dict()
    assert metadata_sample.columns.to_list() == loaded_sample.columns.to_list()


def test_multi_table_compatibility(tmp_path):
    """Test that ``MultiTableMetadata`` still works with multi table synthesizers.

    A synthesizer built from the deprecated ``MultiTableMetadata`` must emit a
    ``FutureWarning``, survive a save/load round trip, and sample the same columns
    per table as a synthesizer built from the equivalent ``Metadata`` object.

    Args:
        tmp_path (pathlib.Path):
            Temporary directory provided by pytest, used to store the pickled synthesizer.
    """
    # Setup
    data, _ = download_demo('multi_table', 'fake_hotels')
    warn_msg = re.escape(
        "The 'MultiTableMetadata' is deprecated. Please use the new "
        "'Metadata' class for synthesizers."
    )

    multi_dict = {
        'tables': {
            'guests': {
                'primary_key': 'guest_email',
                'columns': {
                    'guest_email': {'sdtype': 'email', 'pii': True},
                    'hotel_id': {'sdtype': 'id', 'regex_format': '[A-Za-z]{5}'},
                    'has_rewards': {'sdtype': 'boolean'},
                    'room_type': {'sdtype': 'categorical'},
                    'amenities_fee': {'sdtype': 'numerical', 'computer_representation': 'Float'},
                    'checkin_date': {'sdtype': 'datetime', 'datetime_format': '%d %b %Y'},
                    'checkout_date': {'sdtype': 'datetime', 'datetime_format': '%d %b %Y'},
                    'room_rate': {'sdtype': 'numerical', 'computer_representation': 'Float'},
                    'billing_address': {'sdtype': 'address', 'pii': True},
                    'credit_card_number': {'sdtype': 'credit_card_number', 'pii': True},
                },
            },
            'hotels': {
                'primary_key': 'hotel_id',
                'columns': {
                    'hotel_id': {'sdtype': 'id', 'regex_format': 'HID_[0-9]{3}'},
                    'city': {'sdtype': 'categorical'},
                    'state': {'sdtype': 'categorical'},
                    'rating': {'sdtype': 'numerical', 'computer_representation': 'Float'},
                    'classification': {'sdtype': 'categorical'},
                },
            },
        },
        'relationships': [
            {
                'parent_table_name': 'hotels',
                'parent_primary_key': 'hotel_id',
                'child_table_name': 'guests',
                'child_foreign_key': 'hotel_id',
            }
        ],
        'METADATA_SPEC_VERSION': 'MULTI_TABLE_V1',
    }
    metadata = MultiTableMetadata.load_from_dict(multi_dict)
    assert type(metadata) is MultiTableMetadata

    # Run
    with pytest.warns(FutureWarning, match=warn_msg):
        synthesizer = HMASynthesizer(metadata)

    synthesizer.fit(data)
    model_path = os.path.join(tmp_path, 'synthesizer.pkl')
    synthesizer.save(model_path)

    # Assert
    assert os.path.exists(model_path)
    assert os.path.isfile(model_path)

    # Load HMASynthesizer
    loaded_synthesizer = HMASynthesizer.load(model_path)

    # Check the object that came back from disk, not the one that was saved.
    assert isinstance(loaded_synthesizer, HMASynthesizer)
    assert loaded_synthesizer.get_info() == synthesizer.get_info()

    # The loaded synthesizer must carry the metadata it was created with.
    expected_metadata = metadata.to_dict()
    assert loaded_synthesizer.metadata.to_dict() == expected_metadata

    # Sample from loaded synthesizer
    loaded_sample = loaded_synthesizer.sample(10)
    synthesizer.validate(loaded_sample)

    # Run against Metadata
    synthesizer_2 = HMASynthesizer(Metadata.load_from_dict(metadata.to_dict()))
    synthesizer_2.fit(data)
    # Sample from the ``Metadata``-based synthesizer so the comparison below is meaningful.
    metadata_sample = synthesizer_2.sample(10)
    expected_metadata = loaded_synthesizer.metadata.to_dict()
    # Only the spec version is expected to differ between the two metadata dictionaries.
    expected_metadata['METADATA_SPEC_VERSION'] = 'V1'
    assert expected_metadata == synthesizer_2.metadata.to_dict()
    for table in metadata_sample:
        assert metadata_sample[table].columns.to_list() == loaded_sample[table].columns.to_list()
34 changes: 31 additions & 3 deletions tests/integration/multi_table/test_hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from sdv.datasets.local import load_csvs
from sdv.errors import SamplingError, SynthesizerInputError, VersionError
from sdv.evaluation.multi_table import evaluate_quality, get_column_pair_plot, get_column_plot
from sdv.metadata.metadata import Metadata
from sdv.metadata.multi_table import MultiTableMetadata
from sdv.multi_table import HMASynthesizer
from tests.integration.single_table.custom_constraints import MyConstraint
Expand Down Expand Up @@ -50,6 +51,32 @@ def test_hma(self):
for normal_table, increased_table in zip(normal_sample.values(), increased_sample.values()):
assert increased_table.size > normal_table.size

def test_hma_metadata(self):
    """End to end integration test of ``HMASynthesizer`` using the ``Metadata`` class.

    The demo metadata is converted into a ``Metadata`` object through its dictionary
    form and used to fit a ``HMASynthesizer``. Two samples are then generated, one
    with a 0.5 scale and one with a 1.5 scale.
    """
    # Setup
    data, legacy_metadata = download_demo('multi_table', 'got_families')
    converted_metadata = Metadata.load_from_dict(legacy_metadata.to_dict())
    synthesizer = HMASynthesizer(converted_metadata)

    # Run
    synthesizer.fit(data)
    small_sample = synthesizer.sample(0.5)
    large_sample = synthesizer.sample(1.5)

    # Assert
    assert set(small_sample) == {'characters', 'character_families', 'families'}
    assert set(large_sample) == {'characters', 'character_families', 'families'}
    for name, sampled_table in small_sample.items():
        assert all(sampled_table.columns == data[name].columns)

    # The larger scale must yield strictly bigger tables, compared pairwise.
    table_pairs = zip(small_sample.values(), large_sample.values())
    for small_table, large_table in table_pairs:
        assert large_table.size > small_table.size

def test_hma_reset_sampling(self):
"""End to end integration test that uses ``reset_sampling``.
Expand Down Expand Up @@ -1241,7 +1268,8 @@ def test_metadata_updated_no_warning(self, tmp_path):
initialization, but is saved to a file before fitting.
"""
# Setup
data, metadata = download_demo('multi_table', 'got_families')
data, multi_metadata = download_demo('multi_table', 'got_families')
metadata = Metadata.load_from_dict(multi_metadata.to_dict())

# Run 1
with warnings.catch_warnings(record=True) as captured_warnings:
Expand All @@ -1258,7 +1286,7 @@ def test_metadata_updated_no_warning(self, tmp_path):
assert len(captured_warnings) == 0

# Run 2
metadata_detect = MultiTableMetadata()
metadata_detect = Metadata()
metadata_detect.detect_from_dataframes(data)

metadata_detect.relationships = metadata.relationships
Expand Down Expand Up @@ -1298,7 +1326,7 @@ def test_metadata_updated_warning_detect(self):
"""
# Setup
data, metadata = download_demo('multi_table', 'got_families')
metadata_detect = MultiTableMetadata()
metadata_detect = Metadata()
metadata_detect.detect_from_dataframes(data)

metadata_detect.relationships = metadata.relationships
Expand Down
23 changes: 20 additions & 3 deletions tests/unit/multi_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def test___init__(
mock_generate_synthesizer_id.return_value = synthesizer_id
mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183'
metadata = get_multi_table_metadata()
metadata.validate = Mock()
metadata.validate = Mock(spec=Metadata)

# Run
with catch_sdv_logs(caplog, logging.INFO, 'MultiTableSynthesizer'):
Expand All @@ -144,6 +144,21 @@ def test___init__(
'SYNTHESIZER ID': 'BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5',
})

def test___init___deprecated(self):
    """Test that init with the old ``MultiTableMetadata`` raises a ``FutureWarning``."""
    # Setup
    legacy_metadata = get_multi_table_metadata()
    legacy_metadata.validate = Mock()
    expected_warning = re.escape(
        "The 'MultiTableMetadata' is deprecated. Please use the new "
        "'Metadata' class for synthesizers."
    )

    # Run and Assert
    with pytest.warns(FutureWarning, match=expected_warning):
        BaseMultiTableSynthesizer(legacy_metadata)

def test__init__column_relationship_warning(self):
"""Test that a warning is raised only once when the metadata has column relationships."""
# Setup
Expand Down Expand Up @@ -382,7 +397,9 @@ def test_get_metadata(self):
result = instance.get_metadata()

# Assert
assert metadata == result
expected_metadata = Metadata.load_from_dict(metadata.to_dict())
assert type(result) is Metadata
assert expected_metadata.to_dict() == result.to_dict()

def test_validate(self):
"""Test that no error is being raised when the data is valid."""
Expand Down Expand Up @@ -868,7 +885,7 @@ def test_preprocess_int_columns(self):
def test_preprocess_warning(self, mock_warnings):
"""Test that ``preprocess`` warns the user if the model has already been fitted."""
# Setup
metadata = get_multi_table_metadata()
metadata = Metadata.load_from_dict(get_multi_table_metadata().to_dict())
instance = BaseMultiTableSynthesizer(metadata)
instance.validate = Mock()
data = {
Expand Down

0 comments on commit f22d815

Please sign in to comment.