Skip to content

Commit

Permalink
Add error message if sequence key present in context columns (#2108)
Browse files Browse the repository at this point in the history
  • Loading branch information
gsheni authored Jul 8, 2024
1 parent b68ddd3 commit b34b088
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 10 deletions.
16 changes: 16 additions & 0 deletions sdv/sequential/par.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def __init__(

self._sequence_index = self.metadata.sequence_index
self.context_columns = context_columns or []
self._validate_sequence_key_and_context_columns()
self._extra_context_columns = {}
self.extended_columns = {}
self.segment_size = segment_size
Expand Down Expand Up @@ -194,6 +195,21 @@ def add_custom_constraint_class(self, class_object, class_name):
"""Error that tells the user custom constraints can't be used in the ``PARSynthesizer``."""
raise SynthesizerInputError('The PARSynthesizer cannot accommodate custom constraints.')

def _validate_sequence_key_and_context_columns(self):
"""Check that the sequence key is not present in the context colums.
Args:
sequence_key (list[str]):
A list of column that identify which row(s) belong to which sequences.
context_columns (list[str]):
A list of strings, representing the columns that do not vary in a sequence.
"""
if set(self._sequence_key).intersection(set(self.context_columns)):
raise SynthesizerInputError(
f'The sequence key {self._sequence_key} cannot be a context column. '
'To proceed, please remove the sequence key from the context_columns parameter.'
)

def _validate_context_columns(self, data):
errors = []
if self.context_columns:
Expand Down
19 changes: 19 additions & 0 deletions tests/integration/sequential/test_par.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,3 +385,22 @@ def test_par_sequence_index_is_numerical():
s1.fit(data)
sample = s1.sample(2, 5)
assert sample.columns.to_list() == data.columns.to_list()


def test_init_error_sequence_key_in_context():
# Setup
metadata_dict = {
'columns': {
'A': {'sdtype': 'id'},
'B': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'},
},
'sequence_key': 'A',
}
metadata = SingleTableMetadata.load_from_dict(metadata_dict)
sequence_key_context_column_error_msg = re.escape(
"The sequence key ['A'] cannot be a context column. "
'To proceed, please remove the sequence key from the context_columns parameter.'
)
# Run and Assert
with pytest.raises(SynthesizerInputError, match=sequence_key_context_column_error_msg):
PARSynthesizer(metadata, context_columns=['A'])
31 changes: 21 additions & 10 deletions tests/unit/sequential/test_par.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,8 @@ def test___init___no_sequence_key(self):
def test_add_constraints(self):
"""Test that that only simple constraints can be added to PARSynthesizer."""
# Setup
metadata = self.get_metadata()
synthesizer = PARSynthesizer(metadata=metadata, context_columns=['name', 'measurement'])
name_constraint = {
'constraint_class': 'Mock',
'constraint_parameters': {'column_name': 'name'},
}
metadata = self.get_metadata(add_sequence_key=True)
synthesizer = PARSynthesizer(metadata=metadata, context_columns=['gender', 'measurement'])
measurement_constraint = {
'constraint_class': 'Mock',
'constraint_parameters': {'column_name': 'measurement'},
Expand All @@ -130,7 +126,7 @@ def test_add_constraints(self):
}
multi_constraint = {
'constraint_class': 'Mock',
'constraint_parameters': {'column_names': ['name', 'time']},
'constraint_parameters': {'column_names': ['gender', 'time']},
}
overlapping_error_msg = re.escape(
'The PARSynthesizer cannot accommodate multiple constraints '
Expand All @@ -143,7 +139,7 @@ def test_add_constraints(self):

# Run and Assert
with pytest.raises(SynthesizerInputError, match=mixed_constraint_error_msg):
synthesizer.add_constraints([name_constraint, gender_constraint])
synthesizer.add_constraints([time_constraint, gender_constraint])

with pytest.raises(SynthesizerInputError, match=mixed_constraint_error_msg):
synthesizer.add_constraints([time_constraint, measurement_constraint])
Expand All @@ -152,10 +148,10 @@ def test_add_constraints(self):
synthesizer.add_constraints([multi_constraint])

with pytest.raises(SynthesizerInputError, match=overlapping_error_msg):
synthesizer.add_constraints([multi_constraint, name_constraint])
synthesizer.add_constraints([multi_constraint, gender_constraint])

with pytest.raises(SynthesizerInputError, match=overlapping_error_msg):
synthesizer.add_constraints([name_constraint, name_constraint])
synthesizer.add_constraints([gender_constraint, gender_constraint])

with pytest.raises(SynthesizerInputError, match=overlapping_error_msg):
synthesizer.add_constraints([gender_constraint, gender_constraint])
Expand Down Expand Up @@ -935,3 +931,18 @@ def test_load(self, mock_file, cloudpickle_mock):
mock_file.assert_called_once_with('synth.pkl', 'rb')
cloudpickle_mock.load.assert_called_once_with(mock_file.return_value)
assert loaded_instance == synthesizer_mock

def test___init___error_sequence_key_in_context(self):
"""Test that the sequence_key is not a context column"""
# Setup
metadata = self.get_metadata(add_sequence_key=True)
sequence_key_context_column_error_msg = re.escape(
"The sequence key ['name'] cannot be a context column. "
'To proceed, please remove the sequence key from the context_columns parameter.'
)
# Run and Assert
with pytest.raises(SynthesizerInputError, match=sequence_key_context_column_error_msg):
PARSynthesizer(
metadata=metadata,
context_columns=['name'],
)

0 comments on commit b34b088

Please sign in to comment.