diff --git a/sdv/metadata/multi_table.py b/sdv/metadata/multi_table.py index a8ef5e174..f9c3b6893 100644 --- a/sdv/metadata/multi_table.py +++ b/sdv/metadata/multi_table.py @@ -513,14 +513,6 @@ def _detect_relationships(self): sdtype=original_foreign_key_sdtype) continue - try: - self._validate_all_tables_connected(self._get_parent_map(), self._get_child_map()) - except InvalidMetadataError as invalid_error: - warning_msg = ( - f'Could not automatically add relationships for all tables. {str(invalid_error)}' - ) - warnings.warn(warning_msg) - def detect_table_from_dataframe(self, table_name, data): """Detect the metadata for a table from a dataframe. @@ -739,14 +731,10 @@ def validate(self): for relation in self.relationships: self._append_relationships_errors(errors, self._validate_relationship, **relation) - parent_map = self._get_parent_map() child_map = self._get_child_map() self._append_relationships_errors( errors, self._validate_child_map_circular_relationship, child_map) - self._append_relationships_errors( - errors, self._validate_all_tables_connected, parent_map, child_map) - if errors: raise InvalidMetadataError( 'The metadata is not valid' + '\n'.join(str(e) for e in errors) diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index 78f9c8c21..4be8bd53f 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -1818,3 +1818,24 @@ def test_table_name_logging(caplog): # Assert for msg in caplog.messages: assert 'table parent_data' in msg or 'table child_data' in msg + + +def test_disjointed_tables(): + """Test to see if synthesizer works with disjointed tables.""" + # Setup + real_data, metadata = download_demo('multi_table', 'Bupa_v1') + + # Delete Some Relationships to make it disjointed + remove_some_dict = metadata.to_dict() + half_list = remove_some_dict['relationships'][1::2] + remove_some_dict['relationships'] = half_list + disjoined_metadata = 
MultiTableMetadata.load_from_dict(remove_some_dict) + + # Run + disjoin_synthesizer = HMASynthesizer(disjoined_metadata) + disjoin_synthesizer.fit(real_data) + disjoin_synthetic_data = disjoin_synthesizer.sample(1.0) + + # Assert + for table in real_data: + assert list(real_data[table].columns) == list(disjoin_synthetic_data[table].columns) diff --git a/tests/unit/metadata/test_multi_table.py b/tests/unit/metadata/test_multi_table.py index c32ed7185..de30ae4a8 100644 --- a/tests/unit/metadata/test_multi_table.py +++ b/tests/unit/metadata/test_multi_table.py @@ -1203,14 +1203,42 @@ def test_validate_raises_errors(self): "Please use 'set_primary_key' in order to set one." "\nRelationship between tables ('sessions', 'transactions') is invalid. " 'The primary and foreign key columns are not the same type.' - "\nThe relationships in the dataset are disjointed. Table ['payments'] " - 'is not connected to any of the other tables.' ) # Run and Assert with pytest.raises(InvalidMetadataError, match=error_msg): instance.validate() + def test__validate_all_tables_connected_raises_errors(self): + """Test the method ``_validate_all_tables_connected``. + + Test that `_validate_all_tables_connected` raises an error for a disjointed table. + + Setup: + - Instance of ``MultiTableMetadata`` with all valid tables and + missing relationships. + """ + # Setup + instance = self.get_metadata() + instance.tables['users'].primary_key = None + instance.tables['transactions'].columns['session_id']['sdtype'] = 'datetime' + instance.tables['payments'].columns['date']['sdtype'] = 'id' + instance.tables['payments'].columns['date']['regex_format'] = '[A-z{' + instance.relationships.pop(-1) + + # Run + error_msg = re.escape( + 'The relationships in the dataset are disjointed. ' + "Table ['payments'] is not connected to any of the other tables."
+ ) + + # Run and Assert + with pytest.raises(InvalidMetadataError, match=error_msg): + instance._validate_all_tables_connected( + instance._get_parent_map(), + instance._get_child_map() + ) + def test_validate_child_key_is_primary_key(self): """Test it crashes if the child key is a primary key.""" # Setup @@ -2324,44 +2352,6 @@ def test__detect_relationships(self): assert instance.relationships == expected_relationships assert instance.tables['sessions'].columns['user_id']['sdtype'] == 'id' - @patch('sdv.metadata.multi_table.warnings') - def test__detect_relationships_disconnected_warning(self, warnings_mock): - """Test that ``_detect_relationships`` warns about tables it could not connect.""" - # Setup - parent_table = Mock() - parent_table.primary_key = 'id' - parent_table.columns = { - 'id': {'sdtype': 'id'}, - 'user_name': {'sdtype': 'categorical'}, - 'transactions': {'sdtype': 'numerical'}, - } - - child_table = SingleTableMetadata() - child_table.primary_key = 'session_id' - child_table.columns = { - 'user_id': {'sdtype': 'categorical'}, - 'session_id': {'sdtype': 'numerical'}, - 'timestamp': {'sdtype': 'datetime'}, - } - - instance = MultiTableMetadata() - instance.tables = { - 'users': parent_table, - 'sessions': child_table, - } - - # Run - instance._detect_relationships() - - # Assert - expected_warning = ( - 'Could not automatically add relationships for all tables. The relationships in ' - "the dataset are disjointed. Tables ['users', 'sessions'] are not connected to " - 'any of the other tables.' - ) - warnings_mock.warn.assert_called_once_with(expected_warning) - assert instance.relationships == [] - def test__detect_relationships_circular(self): """Test that relationships that invalidate the metadata are not added.""" # Setup