Skip to content

Commit

Permalink
Allow for disconnected tables (#1979)
Browse files Browse the repository at this point in the history
  • Loading branch information
lajohn4747 committed Jun 11, 2024
1 parent fbbc65a commit 1286211
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 52 deletions.
12 changes: 0 additions & 12 deletions sdv/metadata/multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,14 +513,6 @@ def _detect_relationships(self):
sdtype=original_foreign_key_sdtype)
continue

try:
self._validate_all_tables_connected(self._get_parent_map(), self._get_child_map())
except InvalidMetadataError as invalid_error:
warning_msg = (
f'Could not automatically add relationships for all tables. {str(invalid_error)}'
)
warnings.warn(warning_msg)

def detect_table_from_dataframe(self, table_name, data):
"""Detect the metadata for a table from a dataframe.
Expand Down Expand Up @@ -739,14 +731,10 @@ def validate(self):
for relation in self.relationships:
self._append_relationships_errors(errors, self._validate_relationship, **relation)

parent_map = self._get_parent_map()
child_map = self._get_child_map()

self._append_relationships_errors(
errors, self._validate_child_map_circular_relationship, child_map)
self._append_relationships_errors(
errors, self._validate_all_tables_connected, parent_map, child_map)

if errors:
raise InvalidMetadataError(
'The metadata is not valid' + '\n'.join(str(e) for e in errors)
Expand Down
21 changes: 21 additions & 0 deletions tests/integration/multi_table/test_hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -1818,3 +1818,24 @@ def test_table_name_logging(caplog):
# Assert
for msg in caplog.messages:
assert 'table parent_data' in msg or 'table child_data' in msg


def test_disjointed_tables():
    """Check that ``HMASynthesizer`` can fit and sample with disjointed tables."""
    # Setup
    real_data, metadata = download_demo('multi_table', 'Bupa_v1')

    # Keep only every other relationship so that the metadata becomes disjointed
    metadata_dict = metadata.to_dict()
    kept_relationships = metadata_dict['relationships'][1::2]
    metadata_dict['relationships'] = kept_relationships
    disjointed_metadata = MultiTableMetadata.load_from_dict(metadata_dict)

    # Run
    synthesizer = HMASynthesizer(disjointed_metadata)
    synthesizer.fit(real_data)
    synthetic_data = synthesizer.sample(1.0)

    # Assert: every synthetic table has the same columns, in the same order, as the real one
    for table_name in real_data:
        real_columns = list(real_data[table_name].columns)
        synthetic_columns = list(synthetic_data[table_name].columns)
        assert real_columns == synthetic_columns
70 changes: 30 additions & 40 deletions tests/unit/metadata/test_multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1203,14 +1203,42 @@ def test_validate_raises_errors(self):
"Please use 'set_primary_key' in order to set one."
"\nRelationship between tables ('sessions', 'transactions') is invalid. "
'The primary and foreign key columns are not the same type.'
"\nThe relationships in the dataset are disjointed. Table ['payments'] "
'is not connected to any of the other tables.'
)

# Run and Assert
with pytest.raises(InvalidMetadataError, match=error_msg):
instance.validate()

def test__validate_all_tables_connected_raises_errors(self):
    """Test that ``_validate_all_tables_connected`` raises an error for disjointed tables.

    Setup:
        - Instance of ``MultiTableMetadata`` whose last relationship has been removed.

    Expected:
        - An ``InvalidMetadataError`` reporting that ``payments`` is not connected
          to any of the other tables.
    """
    # Setup
    metadata = self.get_metadata()
    tables = metadata.tables
    tables['users'].primary_key = None
    tables['transactions'].columns['session_id']['sdtype'] = 'datetime'
    payments_date = tables['payments'].columns['date']
    payments_date['sdtype'] = 'id'
    payments_date['regex_format'] = '[A-z{'
    metadata.relationships.pop(-1)

    # The message is escaped because it contains regex metacharacters ('[', ']', '.')
    expected_message = re.escape(
        'The relationships in the dataset are disjointed. '
        "Table ['payments'] is not connected to any of the other tables."
    )

    # Run and Assert
    with pytest.raises(InvalidMetadataError, match=expected_message):
        parent_map = metadata._get_parent_map()
        child_map = metadata._get_child_map()
        metadata._validate_all_tables_connected(parent_map, child_map)

def test_validate_child_key_is_primary_key(self):
"""Test it crashes if the child key is a primary key."""
# Setup
Expand Down Expand Up @@ -2324,44 +2352,6 @@ def test__detect_relationships(self):
assert instance.relationships == expected_relationships
assert instance.tables['sessions'].columns['user_id']['sdtype'] == 'id'

@patch('sdv.metadata.multi_table.warnings')
def test__detect_relationships_disconnected_warning(self, warnings_mock):
    """Test that ``_detect_relationships`` warns when tables are left unconnected.

    The ``warnings`` module used by ``sdv.metadata.multi_table`` is patched so the
    emitted message can be checked exactly.
    """
    # Setup
    users_table = Mock()
    users_table.primary_key = 'id'
    users_table.columns = {
        'id': {'sdtype': 'id'},
        'user_name': {'sdtype': 'categorical'},
        'transactions': {'sdtype': 'numerical'},
    }

    sessions_table = SingleTableMetadata()
    sessions_table.primary_key = 'session_id'
    sessions_table.columns = {
        'user_id': {'sdtype': 'categorical'},
        'session_id': {'sdtype': 'numerical'},
        'timestamp': {'sdtype': 'datetime'},
    }

    metadata = MultiTableMetadata()
    metadata.tables = {
        'users': users_table,
        'sessions': sessions_table,
    }

    expected_warning = (
        'Could not automatically add relationships for all tables. The relationships in '
        "the dataset are disjointed. Tables ['users', 'sessions'] are not connected to "
        'any of the other tables.'
    )

    # Run
    metadata._detect_relationships()

    # Assert
    warnings_mock.warn.assert_called_once_with(expected_warning)
    assert metadata.relationships == []

def test__detect_relationships_circular(self):
"""Test that relationships that invalidate the metadata are not added."""
# Setup
Expand Down

0 comments on commit 1286211

Please sign in to comment.