Skip to content

Commit

Permalink
Separate primary key detection functionality (#2132)
Browse files Browse the repository at this point in the history
  • Loading branch information
amontanez24 committed Jul 17, 2024
1 parent fa01804 commit eac39a8
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 20 deletions.
76 changes: 59 additions & 17 deletions sdv/metadata/single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,20 +538,72 @@ def _determine_sdtype_for_objects(self, data):

return sdtype

def _detect_primary_key(self, data):
"""Detect the table's primary key.
This method will loop through the columns and select the first column that was detected as
an id. If there are none of those, it will pick the first unique pii column. If there is
still no primary key, it will return None. All other id columns after the first will be
reassigned to the 'unknown' sdtype.
Args:
data (pandas.DataFrame):
The data to be analyzed.
Returns:
str:
The column name of the selected primary key.
Raises:
RuntimeError:
If the sdtypes for all columns haven't been detected or set yet.
"""
original_columns = data.columns
stringified_columns = data.columns.astype(str)
data.columns = stringified_columns
for column in data.columns:
if not self.columns.get(column, {}).get('sdtype'):
raise RuntimeError(
'All columns must have sdtypes detected or set manually to detect the primary '
'key.'
)

candidates = []
first_pii_field = None
for column, column_meta in self.columns.items():
sdtype = column_meta['sdtype']
column_data = data[column]
has_nan = column_data.isna().any()
valid_potential_primary_key = column_data.is_unique and not has_nan
sdtype_in_reference = sdtype in self._REFERENCE_TO_SDTYPE.values()
if sdtype == 'id':
candidates.append(column)
if len(candidates) > 1:
self.columns[column]['sdtype'] = 'unknown'
self.columns[column]['pii'] = True

elif sdtype_in_reference and first_pii_field is None and valid_potential_primary_key:
first_pii_field = column

data.columns = original_columns
if candidates:
return candidates[0]
if first_pii_field:
return first_pii_field

return None

def _detect_columns(self, data):
"""Detect the columns' sdtype and the primary key from the data.
"""Detect the columns' sdtypes from the data.
Args:
data (pandas.DataFrame):
The data to be analyzed.
"""
old_columns = data.columns
data.columns = data.columns.astype(str)
first_pii_field = None
for field in data:
column_data = data[field]
has_nan = column_data.isna().any()
valid_potential_primary_key = column_data.is_unique and not has_nan
clean_data = column_data.dropna()
dtype = clean_data.infer_objects().dtype.kind

Expand All @@ -571,30 +623,19 @@ def _detect_columns(self, data):
"The valid data types are: 'object', 'int', 'float', 'datetime', 'bool'."
)

# Set the first ID column we detect to be the primary key
if sdtype == 'id':
if self.primary_key is None and valid_potential_primary_key:
self.primary_key = field
else:
sdtype = 'unknown'

column_dict = {'sdtype': sdtype}
sdtype_in_reference = sdtype in self._REFERENCE_TO_SDTYPE.values()

if sdtype_in_reference or sdtype == 'unknown':
column_dict['pii'] = True
if sdtype_in_reference and first_pii_field is None and not has_nan:
first_pii_field = field

if sdtype == 'datetime' and dtype == 'O':
datetime_format = _get_datetime_format(column_data.iloc[:100])
column_dict['datetime_format'] = datetime_format

self.columns[field] = deepcopy(column_dict)

# When no primary key column was set, choose the first pii field
if self.primary_key is None and first_pii_field and valid_potential_primary_key:
self.primary_key = first_pii_field

self.primary_key = self._detect_primary_key(data)
self._updated = True
data.columns = old_columns

Expand Down Expand Up @@ -1273,6 +1314,7 @@ def load_from_dict(cls, metadata_dict):
}
setattr(instance, f'{key}', value)

instance._primary_key_candidates = None
return instance

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/data_processing/test_data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,8 +620,8 @@ def test__validate_constraint_dict_columns_in_relationships(self):
}

metadata = SingleTableMetadata()
metadata.add_column('country_column', sdtype='country_code')
metadata.add_column('city_column', sdtype='city')
metadata.columns['country_column'] = {'sdtype': 'country_code', 'pii': True}
metadata.columns['city_column'] = {'sdtype': 'city', 'pii': True}
custom_constraint = Mock()

dp = DataProcessor(metadata)
Expand Down
19 changes: 18 additions & 1 deletion tests/unit/metadata/test_single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,7 +794,7 @@ def test_update_columns_sdtype_in_kwargs_error(self):
with pytest.raises(InvalidMetadataError, match=error_msg):
instance.update_columns(['col_1', 'col_2'], sdtype='numerical', pii=True)

def test_update_columns_multiple_erros(self):
def test_update_columns_multiple_errors(self):
"""Test the ``update_columns`` method.
Test that ``update_columns`` with multiple errors.
Expand Down Expand Up @@ -1249,6 +1249,23 @@ def test__detect_columns_with_error(self, mock__get_datetime_format):
instance._determine_sdtype_for_objects.assert_called_once()
mock__get_datetime_format.assert_called_once()

def test__detect_primary_key_missing_sdtypes(self):
"""The method should raise an error if not all sdtypes were detected."""
# Setup
data = pd.DataFrame({
'string_id': ['1', '2', '3', '4', '5', '6'],
'num_id': [1, 2, 3, 4, 5, 6],
})
metadata = SingleTableMetadata()
metadata.columns = {'string_id': {'sdtype': 'id'}}

# Run and Assert
message = (
'All columns must have sdtypes detected or set manually to detect the primary key.'
)
with pytest.raises(RuntimeError, match=message):
metadata._detect_primary_key(data)

def test_detect_from_dataframe_raises_error(self):
"""Test the ``detect_from_dataframe`` method.
Expand Down

0 comments on commit eac39a8

Please sign in to comment.