Merge branch 'main' into feature/metadata
lajohn4747 committed Aug 5, 2024
2 parents 9adc130 + a904634 commit fb2c2d5
Showing 16 changed files with 477 additions and 25 deletions.
2 changes: 1 addition & 1 deletion .github/auto_assign.yml
@@ -1,2 +1,2 @@
 # Set to true to add assignees to pull requests
-addAssignees: true
+addAssignees: author
4 changes: 2 additions & 2 deletions latest_requirements.txt
@@ -7,5 +7,5 @@ numpy==1.26.4
 pandas==2.2.2
 platformdirs==4.2.2
 rdt==1.12.2
-sdmetrics==0.14.1
-tqdm==4.66.4
+sdmetrics==0.15.0
+tqdm==4.66.5
41 changes: 41 additions & 0 deletions sdv/metadata/multi_table.py
@@ -914,6 +914,47 @@ def get_table_metadata(self, table_name):
         self._validate_table_exists(table_name)
         return deepcopy(self.tables[table_name])
 
+    def anonymize(self):
+        """Anonymize metadata by obfuscating column names.
+
+        Returns:
+            MultiTableMetadata:
+                An anonymized MultiTableMetadata instance.
+        """
+        anonymized_metadata = {'tables': {}, 'relationships': []}
+        anonymized_table_map = {}
+        counter = 1
+        for table, table_metadata in self.tables.items():
+            anonymized_table_name = f'table{counter}'
+            anonymized_table_map[table] = anonymized_table_name
+
+            anonymized_metadata['tables'][anonymized_table_name] = (
+                table_metadata.anonymize().to_dict()
+            )
+            counter += 1
+
+        for relationship in self.relationships:
+            parent_table = relationship['parent_table_name']
+            anonymized_parent_table = anonymized_table_map[parent_table]
+
+            child_table = relationship['child_table_name']
+            anonymized_child_table = anonymized_table_map[child_table]
+
+            foreign_key = relationship['child_foreign_key']
+            anonymized_foreign_key = self.tables[child_table]._anonymized_column_map[foreign_key]
+
+            primary_key = relationship['parent_primary_key']
+            anonymized_primary_key = self.tables[parent_table]._anonymized_column_map[primary_key]
+
+            anonymized_metadata['relationships'].append({
+                'parent_table_name': anonymized_parent_table,
+                'child_table_name': anonymized_child_table,
+                'child_foreign_key': anonymized_foreign_key,
+                'parent_primary_key': anonymized_primary_key,
+            })
+
+        return MultiTableMetadata.load_from_dict(anonymized_metadata)
+
     def visualize(
         self, show_table_details='full', show_relationship_labels=True, output_filepath=None
     ):
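For orientation, here is a minimal sketch of how the new method might be used. The table and column names below are hypothetical and not part of this commit:

import pprint

from sdv.metadata import MultiTableMetadata

# Hypothetical metadata with identifying table and column names.
metadata = MultiTableMetadata.load_from_dict({
    'tables': {
        'patients': {
            'columns': {
                'patient_id': {'sdtype': 'id'},
                'diagnosis': {'sdtype': 'categorical'},
            },
            'primary_key': 'patient_id',
        },
        'visits': {
            'columns': {
                'visit_email': {'sdtype': 'email'},
                'patient_ref': {'sdtype': 'id'},
            },
            'primary_key': 'visit_email',
        },
    },
    'relationships': [{
        'parent_table_name': 'patients',
        'parent_primary_key': 'patient_id',
        'child_table_name': 'visits',
        'child_foreign_key': 'patient_ref',
    }],
})

# Tables are renamed table1, table2, ... and columns col1, col2, ...,
# while sdtypes and relationships are preserved.
anonymized = metadata.anonymize()
pprint.pprint(anonymized.to_dict())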
111 changes: 94 additions & 17 deletions sdv/metadata/single_table.py
@@ -538,20 +538,72 @@ def _determine_sdtype_for_objects(self, data):
 
         return sdtype
 
+    def _detect_primary_key(self, data):
+        """Detect the table's primary key.
+
+        This method will loop through the columns and select the first column that was detected as
+        an id. If there are none of those, it will pick the first unique pii column. If there is
+        still no primary key, it will return None. All other id columns after the first will be
+        reassigned to the 'unknown' sdtype.
+
+        Args:
+            data (pandas.DataFrame):
+                The data to be analyzed.
+
+        Returns:
+            str:
+                The column name of the selected primary key.
+
+        Raises:
+            RuntimeError:
+                If the sdtypes for all columns haven't been detected or set yet.
+        """
+        original_columns = data.columns
+        stringified_columns = data.columns.astype(str)
+        data.columns = stringified_columns
+        for column in data.columns:
+            if not self.columns.get(column, {}).get('sdtype'):
+                raise RuntimeError(
+                    'All columns must have sdtypes detected or set manually to detect the primary '
+                    'key.'
+                )
+
+        candidates = []
+        first_pii_field = None
+        for column, column_meta in self.columns.items():
+            sdtype = column_meta['sdtype']
+            column_data = data[column]
+            has_nan = column_data.isna().any()
+            valid_potential_primary_key = column_data.is_unique and not has_nan
+            sdtype_in_reference = sdtype in self._REFERENCE_TO_SDTYPE.values()
+            if sdtype == 'id':
+                candidates.append(column)
+                if len(candidates) > 1:
+                    self.columns[column]['sdtype'] = 'unknown'
+                    self.columns[column]['pii'] = True
+
+            elif sdtype_in_reference and first_pii_field is None and valid_potential_primary_key:
+                first_pii_field = column
+
+        data.columns = original_columns
+        if candidates:
+            return candidates[0]
+        if first_pii_field:
+            return first_pii_field
+
+        return None
+
     def _detect_columns(self, data):
-        """Detect the columns' sdtype and the primary key from the data.
+        """Detect the columns' sdtypes from the data.
 
         Args:
             data (pandas.DataFrame):
                 The data to be analyzed.
         """
         old_columns = data.columns
         data.columns = data.columns.astype(str)
-        first_pii_field = None
         for field in data:
             column_data = data[field]
-            has_nan = column_data.isna().any()
-            valid_potential_primary_key = column_data.is_unique and not has_nan
             clean_data = column_data.dropna()
             dtype = clean_data.infer_objects().dtype.kind
 
@@ -571,30 +623,19 @@ def _detect_columns(self, data):
                 "The valid data types are: 'object', 'int', 'float', 'datetime', 'bool'."
             )
 
-            # Set the first ID column we detect to be the primary key
-            if sdtype == 'id':
-                if self.primary_key is None and valid_potential_primary_key:
-                    self.primary_key = field
-                else:
-                    sdtype = 'unknown'
-
             column_dict = {'sdtype': sdtype}
             sdtype_in_reference = sdtype in self._REFERENCE_TO_SDTYPE.values()
 
             if sdtype_in_reference or sdtype == 'unknown':
                 column_dict['pii'] = True
-            if sdtype_in_reference and first_pii_field is None and not has_nan:
-                first_pii_field = field
 
             if sdtype == 'datetime' and dtype == 'O':
                 datetime_format = _get_datetime_format(column_data.iloc[:100])
                 column_dict['datetime_format'] = datetime_format
 
             self.columns[field] = deepcopy(column_dict)
 
-        # When no primary key column was set, choose the first pii field
-        if self.primary_key is None and first_pii_field and valid_potential_primary_key:
-            self.primary_key = first_pii_field
+        self.primary_key = self._detect_primary_key(data)
         self._updated = True
         data.columns = old_columns

@@ -1177,6 +1218,41 @@ def validate_data(self, data, sdtype_warnings=None):
         if errors:
             raise InvalidDataError(errors)
 
+    def anonymize(self):
+        """Anonymize metadata by obfuscating column names.
+
+        Returns:
+            SingleTableMetadata:
+                An anonymized SingleTableMetadata instance.
+        """
+        anonymized_metadata = {'columns': {}}
+
+        self._anonymized_column_map = {}
+        counter = 1
+        for column, column_metadata in self.columns.items():
+            anonymized_column = f'col{counter}'
+            self._anonymized_column_map[column] = anonymized_column
+            anonymized_metadata['columns'][anonymized_column] = column_metadata
+            counter += 1
+
+        if self.primary_key:
+            anonymized_metadata['primary_key'] = self._anonymized_column_map[self.primary_key]
+
+        if self.alternate_keys:
+            anonymized_alternate_keys = []
+            for alternate_key in self.alternate_keys:
+                anonymized_alternate_keys.append(self._anonymized_column_map[alternate_key])
+
+            anonymized_metadata['alternate_keys'] = anonymized_alternate_keys
+
+        if self.sequence_key:
+            anonymized_metadata['sequence_key'] = self._anonymized_column_map[self.sequence_key]
+
+        if self.sequence_index:
+            anonymized_metadata['sequence_index'] = self._anonymized_column_map[self.sequence_index]
+
+        return SingleTableMetadata.load_from_dict(anonymized_metadata)
+
     def visualize(self, show_table_details='full', output_filepath=None):
         """Create a visualization of the single-table dataset.
 
@@ -1273,6 +1349,7 @@ def load_from_dict(cls, metadata_dict):
             }
             setattr(instance, f'{key}', value)
 
+        instance._primary_key_candidates = None
         return instance
 
     @classmethod
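Taken together, the refactor moves primary-key selection out of column detection and into `_detect_primary_key`. A small sketch of the intended selection order follows; the column names are illustrative, and whether a given column is detected as 'id' or PII depends on SDV's detection heuristics:

import pandas as pd

from sdv.metadata import SingleTableMetadata

# Illustrative frame: an id-style column plus a unique PII-style column.
data = pd.DataFrame({
    'user_id': ['ID_001', 'ID_002', 'ID_003'],
    'email': ['a@example.com', 'b@example.com', 'c@example.com'],
    'amount': [10.0, 20.0, 30.0],
})

metadata = SingleTableMetadata()
metadata.detect_from_dataframe(data)

# Under the new logic: the first column detected as 'id' becomes the primary
# key, later 'id' columns are demoted to 'unknown', and if no 'id' column
# exists, the first unique NaN-free PII column is used instead.
print(metadata.primary_key)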
10 changes: 10 additions & 0 deletions sdv/multi_table/base.py
@@ -667,13 +667,23 @@ def get_info(self):
 
         return info
 
+    def _validate_fit_before_save(self):
+        """Validate that the synthesizer has been fitted before saving."""
+        if not self._fitted:
+            warnings.warn(
+                'You are saving a synthesizer that has not yet been fitted. You will not be able '
+                'to sample synthetic data without fitting. We recommend fitting the synthesizer '
+                'first and then saving.'
+            )
+
     def save(self, filepath):
         """Save this instance to the given path using cloudpickle.
 
         Args:
             filepath (str):
                 Path where the instance will be serialized.
         """
+        self._validate_fit_before_save()
         synthesizer_id = getattr(self, '_synthesizer_id', None)
         SYNTHESIZER_LOGGER.info({
             'EVENT': 'Save',
6 changes: 6 additions & 0 deletions sdv/sequential/par.py
@@ -344,6 +344,12 @@ def _fit_context_model(self, transformed):
             context[constant_column] = 0
             context_metadata.add_column(constant_column, sdtype='numerical')
 
+        for column in self.context_columns:
+            # Context datetime SDTypes for PAR have already been converted to float timestamp
+            if context_metadata.columns[column]['sdtype'] == 'datetime':
+                if pd.api.types.is_numeric_dtype(context[column]):
+                    context_metadata.update_column(column, sdtype='numerical')
+
         self._context_synthesizer = GaussianCopulaSynthesizer(
             context_metadata,
             enforce_min_max_values=self._context_synthesizer.enforce_min_max_values,
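The guard above exists because PAR's preprocessing converts datetime context columns to float timestamps before the context model is fitted. A standalone sketch of the same check, using a stand-in dict rather than PAR's real metadata object:

import pandas as pd

# After preprocessing, a datetime context column arrives as float timestamps.
context = pd.DataFrame({'signup_date': [1.5913e18, 1.5916e18]})

# Stand-in for the metadata lookup used in _fit_context_model.
column_sdtypes = {'signup_date': 'datetime'}

for column, sdtype in column_sdtypes.items():
    # Only relabel when the values really are numeric already.
    if sdtype == 'datetime' and pd.api.types.is_numeric_dtype(context[column]):
        column_sdtypes[column] = 'numerical'

print(column_sdtypes)  # {'signup_date': 'numerical'}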
10 changes: 10 additions & 0 deletions sdv/single_table/base.py
@@ -485,13 +485,23 @@ def fit(self, data):
         processed_data = self.preprocess(data)
         self.fit_processed_data(processed_data)
 
+    def _validate_fit_before_save(self):
+        """Validate that the synthesizer has been fitted before saving."""
+        if not self._fitted:
+            warnings.warn(
+                'You are saving a synthesizer that has not yet been fitted. You will not be able '
+                'to sample synthetic data without fitting. We recommend fitting the synthesizer '
+                'first and then saving.'
+            )
+
     def save(self, filepath):
         """Save this model instance to the given path using cloudpickle.
 
         Args:
             filepath (str):
                 Path where the synthesizer instance will be serialized.
         """
+        self._validate_fit_before_save()
         synthesizer_id = getattr(self, '_synthesizer_id', None)
         SYNTHESIZER_LOGGER.info({
             'EVENT': 'Save',
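Both base classes now route save() through this check, so saving an unfitted synthesizer emits a warning. A sketch of observing the new behavior; the synthesizer, metadata, and file path are illustrative:

import warnings

from sdv.metadata import SingleTableMetadata
from sdv.single_table import GaussianCopulaSynthesizer

metadata = SingleTableMetadata.load_from_dict({
    'columns': {'value': {'sdtype': 'numerical'}},
})
synthesizer = GaussianCopulaSynthesizer(metadata)

# The synthesizer was never fitted, so save() should warn.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    synthesizer.save('unfitted_synthesizer.pkl')

for warning in caught:
    print(warning.message)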
57 changes: 57 additions & 0 deletions tests/integration/metadata/test_multi_table.py
@@ -375,3 +375,60 @@ def test_get_table_metadata():
         'column_relationships': [{'type': 'gps', 'column_names': ['latitude', 'longitude']}],
     }
     assert table_metadata.to_dict() == expected_metadata
+
+
+def test_anonymize():
+    """Test the ``anonymize`` method."""
+    # Setup
+    metadata_dict = {
+        'tables': {
+            'real_table1': {
+                'columns': {
+                    'table1_primary_key': {'sdtype': 'id', 'regex_format': 'ID_[0-9]{3}'},
+                    'table1_column2': {'sdtype': 'categorical'},
+                },
+                'primary_key': 'table1_primary_key',
+            },
+            'real_table2': {
+                'columns': {
+                    'table2_primary_key': {'sdtype': 'email'},
+                    'table2_foreign_key': {'sdtype': 'id', 'regex_format': 'ID_[0-9]{3}'},
+                },
+                'primary_key': 'table2_primary_key',
+            },
+        },
+        'relationships': [
+            {
+                'parent_table_name': 'real_table1',
+                'parent_primary_key': 'table1_primary_key',
+                'child_table_name': 'real_table2',
+                'child_foreign_key': 'table2_foreign_key',
+            }
+        ],
+    }
+    metadata = MultiTableMetadata.load_from_dict(metadata_dict)
+    table1_metadata = metadata.tables['real_table1']
+    table2_metadata = metadata.tables['real_table2']
+    metadata.validate()
+
+    # Run
+    anonymized = metadata.anonymize()
+
+    # Assert
+    anonymized.validate()
+
+    assert anonymized.tables.keys() == {'table1', 'table2'}
+    assert len(anonymized.relationships) == len(metadata.relationships)
+    assert anonymized.relationships[0]['parent_table_name'] == 'table1'
+    assert anonymized.relationships[0]['child_table_name'] == 'table2'
+    assert anonymized.relationships[0]['parent_primary_key'] == 'col1'
+    assert anonymized.relationships[0]['child_foreign_key'] == 'col2'
+
+    anon_primary_key_metadata = anonymized.tables['table1'].columns['col1']
+    assert anon_primary_key_metadata == table1_metadata.columns['table1_primary_key']
+
+    anon_foreign_key_metadata = anonymized.tables['table2'].columns['col2']
+    assert anon_foreign_key_metadata == table2_metadata.columns['table2_foreign_key']
+
+    assert anonymized.tables['table1'].to_dict() == table1_metadata.anonymize().to_dict()
+    assert anonymized.tables['table2'].to_dict() == table2_metadata.anonymize().to_dict()
46 changes: 46 additions & 0 deletions tests/integration/metadata/test_single_table.py
@@ -551,3 +551,46 @@ def test_metadata_set_same_sequence_primary():
     )
     with pytest.raises(InvalidMetadataError, match=error_msg_sequence):
         metadata_primary.set_sequence_key('A')
+
+
+def test_anonymize():
+    """Test the ``anonymize`` method."""
+    # Setup
+    metadata_dict = {
+        'columns': {
+            'primary_key': {'sdtype': 'id', 'regex_format': 'ID_[0-9]{3}'},
+            'sequence_index': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'},
+            'sequence_key': {'sdtype': 'id'},
+            'alternate_id1': {'sdtype': 'email', 'pii': True},
+            'alternate_id2': {'sdtype': 'name', 'pii': True},
+            'numerical': {'sdtype': 'numerical', 'computer_representation': 'Float'},
+            'categorical': {'sdtype': 'categorical'},
+        },
+        'primary_key': 'primary_key',
+        'sequence_index': 'sequence_index',
+        'sequence_key': 'sequence_key',
+        'alternate_keys': ['alternate_id1', 'alternate_id2'],
+    }
+    metadata = SingleTableMetadata.load_from_dict(metadata_dict)
+    metadata.validate()
+
+    # Run
+    anonymized = metadata.anonymize()
+
+    # Assert
+    anonymized.validate()
+
+    assert all(original_col not in anonymized.columns for original_col in metadata.columns)
+    for original_col, anonymized_col in metadata._anonymized_column_map.items():
+        assert metadata.columns[original_col] == anonymized.columns[anonymized_col]
+
+    anon_primary_key = anonymized.primary_key
+    assert anonymized.columns[anon_primary_key] == metadata.columns['primary_key']
+
+    anon_alternate_keys = anonymized.alternate_keys
+    assert len(anon_alternate_keys) == len(metadata.alternate_keys)
+    assert anonymized.columns[anon_alternate_keys[0]] == metadata.columns['alternate_id1']
+    assert anonymized.columns[anon_alternate_keys[1]] == metadata.columns['alternate_id2']
+
+    anon_sequence_index = anonymized.sequence_index
+    assert anonymized.columns[anon_sequence_index] == metadata.columns['sequence_index']
+
+    anon_sequence_key = anonymized.sequence_key
+    assert anonymized.columns[anon_sequence_key] == metadata.columns['sequence_key']