Skip to content

Commit

Permalink
Numerical unknowns should not be converted to sdv-pii-???? (#2089)
Browse files Browse the repository at this point in the history
  • Loading branch information
lajohn4747 authored Jul 3, 2024
1 parent d8962df commit 462812d
Show file tree
Hide file tree
Showing 5 changed files with 168 additions and 6 deletions.
18 changes: 14 additions & 4 deletions sdv/data_processing/data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,13 +587,23 @@ def _create_config(self, data, columns_created_by_constraints):

elif sdtype == 'unknown':
sdtypes[column] = 'pii'
transformers[column] = AnonymizedFaker(
function_name='bothify',
)
transformers[column].function_kwargs = {
function_name = 'bothify'
function_kwargs = {
'text': 'sdv-pii-?????',
'letters': '0123456789abcdefghijklmnopqrstuvwxyz',
}
if pd.api.types.is_numeric_dtype(data[column]):
max_digits = len(str(abs(max(data[column]))))
min_digits = len(str(abs(min(data[column]))))
text = ('!' * (max_digits - min_digits)) + '%' + ('#' * (min_digits - 1))
function_name = 'numerify'
function_kwargs = {
'text': text,
}
transformers[column] = AnonymizedFaker(
function_name=function_name,
)
transformers[column].function_kwargs = function_kwargs

elif pii:
sdtypes[column] = 'pii'
Expand Down
8 changes: 7 additions & 1 deletion sdv/sampling/hierarchical_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,13 @@ def _finalize(self, sampled_data):
synthesizer = self._table_synthesizers.get(table_name)
dtypes = synthesizer._data_processor._dtypes
for name, dtype in dtypes.items():
table_rows[name] = table_rows[name].dropna().astype(dtype)
try:
table_rows[name] = table_rows[name].dropna().astype(dtype)
except Exception:
LOGGER.info(
"Could not cast back to column's original dtype, keeping original typing."
)
table_rows[name] = table_rows[name].dropna()

final_data[table_name] = table_rows[list(dtypes.keys())]

Expand Down
50 changes: 50 additions & 0 deletions tests/integration/multi_table/test_hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import re
import warnings

import faker
import numpy as np
import pandas as pd
import pytest
Expand Down Expand Up @@ -1381,6 +1382,55 @@ def test_null_foreign_keys(self):
with pytest.raises(SynthesizerInputError, match=err_msg):
synthesizer.fit(data)

def test_sampling_with_unknown_sdtype_numerical_column(self):
    """Test that numerical columns detected as ``unknown`` stay numeric when sampled.

    Metadata is auto-detected from the dataframes. Fitting and sampling must not
    fail, the numerical columns must come back with a numeric dtype and the text
    columns must come back as ``object``.
    """
    # Setup
    fake = faker.Faker()
    num_rows = 20

    table1 = pd.DataFrame({
        'name': [fake.name() for _ in range(num_rows)],
        'salary': np.random.randint(20_000, 250_000, num_rows),
        'age': np.random.randint(18, 70, num_rows),
        'address': [fake.address() for _ in range(num_rows)],
    })
    table2 = pd.DataFrame({
        'company': [fake.company() for _ in range(num_rows)],
        'employee_count': np.random.randint(15, 4000, num_rows),
        # Pass the size explicitly: without it ``randint`` returns one scalar
        # that pandas broadcasts, so every row would share the same revenue.
        'revenue': np.random.randint(100_000, 1_000_000_000, num_rows),
    })

    tables_dict = {'people': table1, 'company': table2}

    metadata = MultiTableMetadata()
    metadata.detect_from_dataframes(tables_dict)

    # Run
    synth = HMASynthesizer(metadata)
    synth.fit(tables_dict)
    sample_data = synth.sample(1)

    # Assert
    people_sample = sample_data['people']
    company_sample = sample_data['company']

    # Since these values are inferred, windows and mac may have different int types
    # so check if it is numeric
    numeric_dtypes = [
        people_sample['salary'].dtype,
        people_sample['age'].dtype,
        company_sample['employee_count'].dtype,
        company_sample['revenue'].dtype,
    ]
    object_dtypes = [
        people_sample['name'].dtype,
        people_sample['address'].dtype,
        company_sample['company'].dtype,
    ]
    assert all(pd.api.types.is_numeric_dtype(dtype) for dtype in numeric_dtypes)
    assert all(dtype == 'object' for dtype in object_dtypes)


@pytest.mark.parametrize('num_rows', [(10), (1000)])
def test_hma_0_1_child(num_rows):
Expand Down
18 changes: 18 additions & 0 deletions tests/unit/data_processing/test_data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1419,6 +1419,24 @@ class CustomTransformer:
assert isinstance(config['transformers']['phone_number'], AnonymizedFaker)
assert isinstance(config['transformers']['email'], CustomTransformer)

def test__create_config_with_unknown_numerical_data(self):
    """Test ``_create_config`` on a numerical column whose sdtype is ``unknown``.

    The column holds values with between 3 and 5 digits, and the transformer
    configured for it is expected to carry the text pattern ``'!!%##'``.
    """
    # Setup
    metadata = SingleTableMetadata().load_from_dict({
        'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1',
        'columns': {'numerical_column': {'sdtype': 'unknown', 'pii': True}},
    })
    processor = DataProcessor(metadata)
    frame = pd.DataFrame({
        'numerical_column': [12321, 198, 1958],
    })

    # Run
    config = processor._create_config(frame, set())

    # Assert
    transformer = config['transformers']['numerical_column']
    assert transformer.function_kwargs['text'] == '!!%##'

def test_update_transformers_not_fitted(self):
"""Test when ``self._hyper_transformer`` is ``None`` raises a ``NotFittedError``."""
# Setup
Expand Down
80 changes: 79 additions & 1 deletion tests/unit/sampling/test_hierarchical_sampler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import defaultdict
from unittest.mock import MagicMock, Mock, call
from unittest.mock import MagicMock, Mock, call, patch

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -378,6 +378,84 @@ def test__finalize(self):
for result_frame, expected_frame in zip(result.values(), expected_result.values()):
pd.testing.assert_frame_equal(result_frame, expected_frame)

@patch('sdv.sampling.hierarchical_sampler.LOGGER')
def test__finalize_no_matching_dtype(self, mock_logging):
    """Test that ``_finalize`` keeps the sampled dtype when the cast back fails.

    The ``sessions`` data processor is deliberately given ``np.int64`` as the
    dtype of ``session_id`` although the sampled values are strings. The column
    is expected to come back unchanged and an info message to be logged once.
    Columns that are not listed in the stored dtypes are still dropped.
    """
    # Setup
    instance = Mock()
    metadata = Mock()
    metadata._get_parent_map.return_value = {
        'sessions': ['users'],
        'transactions': ['sessions'],
    }
    instance.metadata = metadata

    sampled_data = {
        'users': pd.DataFrame({
            'user_id': pd.Series([0, 1, 2], dtype=np.int64),
            'name': pd.Series(['John', 'Doe', 'Johanna'], dtype=object),
            'additional_column': pd.Series([0.1, 0.2, 0.3], dtype=float),
            'another_additional_column': pd.Series([0.1, 0.2, 0.5], dtype=float),
        }),
        'sessions': pd.DataFrame({
            'user_id': pd.Series([1, 2, 1], dtype=np.int64),
            'session_id': pd.Series(['a', 'b', 'c'], dtype=object),
            'os': pd.Series(['linux', 'mac', 'win'], dtype=object),
            'country': pd.Series(['us', 'us', 'es'], dtype=object),
        }),
        'transactions': pd.DataFrame({
            'transaction_id': pd.Series([1, 2, 3], dtype=np.int64),
            'session_id': pd.Series(['a', 'a', 'b'], dtype=object),
        }),
    }

    users_synth = Mock()
    users_synth._data_processor._dtypes = {'user_id': np.int64, 'name': str}
    sessions_synth = Mock()
    # Incorrectly label data_processor type
    sessions_synth._data_processor._dtypes = {
        'user_id': np.int64,
        'session_id': np.int64,  # Should be str
        'os': str,
        'country': str,
    }
    transactions_synth = Mock()
    transactions_synth._data_processor._dtypes = {'transaction_id': np.int64, 'session_id': str}

    instance._table_synthesizers = {
        'users': users_synth,
        'sessions': sessions_synth,
        'transactions': transactions_synth,
    }

    # Run
    result = BaseHierarchicalSampler._finalize(instance, sampled_data)

    # Assert
    expected_result = {
        'users': pd.DataFrame({
            'user_id': pd.Series([0, 1, 2], dtype=np.int64),
            'name': pd.Series(['John', 'Doe', 'Johanna'], dtype=object),
        }),
        'sessions': pd.DataFrame({
            'user_id': pd.Series([1, 2, 1], dtype=np.int64),
            'session_id': pd.Series(['a', 'b', 'c'], dtype=object),
            'os': pd.Series(['linux', 'mac', 'win'], dtype=object),
            'country': pd.Series(['us', 'us', 'es'], dtype=object),
        }),
        'transactions': pd.DataFrame({
            'transaction_id': pd.Series([1, 2, 3], dtype=np.int64),
            'session_id': pd.Series(['a', 'a', 'b'], dtype=object),
        }),
    }
    # Look tables up by name instead of zipping the two dicts: ``zip`` stops at
    # the shorter input, so a table missing from ``result`` would go unnoticed.
    for table_name, expected_frame in expected_result.items():
        pd.testing.assert_frame_equal(result[table_name], expected_frame)

    # Confirm log was called
    mock_logging.info.assert_called_once_with(
        "Could not cast back to column's original dtype, keeping original typing."
    )

def test__sample(self):
"""Test that the whole dataset is sampled.
Expand Down

0 comments on commit 462812d

Please sign in to comment.