Skip to content

Commit

Permalink
Reinstate get_table_parameters for the multi-table synthesizers (#1830
Browse files Browse the repository at this point in the history
)
  • Loading branch information
fealho committed Mar 6, 2024
1 parent 0bd8eaa commit 6222922
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 17 deletions.
13 changes: 8 additions & 5 deletions sdv/multi_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,16 +139,19 @@ def get_table_parameters(self, table_name):
Returns:
parameters (dict):
A dictionary representing the parameters that will be used to instantiate the
table's synthesizer.
A dictionary with the following structure:
{
'synthesizer_name': the string name of the synthesizer for that table,
'synthesizer_parameters': the parameters used to instantiate the synthesizer
}
"""
table_synthesizer = self._table_synthesizers.get(table_name)
if not table_synthesizer:
table_params = {'table_synthesizer': None, 'table_parameters': {}}
table_params = {'synthesizer_name': None, 'synthesizer_parameters': {}}
else:
table_params = {
'table_synthesizer': type(table_synthesizer).__name__,
'table_parameters': table_synthesizer.get_parameters()
'synthesizer_name': type(table_synthesizer).__name__,
'synthesizer_parameters': table_synthesizer.get_parameters()
}

return table_params
Expand Down
2 changes: 1 addition & 1 deletion sdv/multi_table/hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def _get_num_extended_columns(self, table_name, parent_table, columns_per_table)
if num_data_columns == 0:
return num_rows_columns

table_parameters = self.get_table_parameters(table_name)['table_parameters']
table_parameters = self.get_table_parameters(table_name)['synthesizer_parameters']
distribution = table_parameters['default_distribution']
num_parameters_columns = num_rows_columns * num_data_columns
if distribution in {'beta', 'truncnorm'}:
Expand Down
12 changes: 6 additions & 6 deletions tests/integration/multi_table/test_hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,26 +140,26 @@ def test_hma_set_table_parameters(self):

# Assert
character_params = hmasynthesizer.get_table_parameters('characters')
assert character_params['table_synthesizer'] == 'GaussianCopulaSynthesizer'
assert character_params['table_parameters'] == {
assert character_params['synthesizer_name'] == 'GaussianCopulaSynthesizer'
assert character_params['synthesizer_parameters'] == {
'default_distribution': 'gamma',
'enforce_min_max_values': True,
'enforce_rounding': True,
'locales': ['en_US'],
'numerical_distributions': {}
}
families_params = hmasynthesizer.get_table_parameters('families')
assert families_params['table_synthesizer'] == 'GaussianCopulaSynthesizer'
assert families_params['table_parameters'] == {
assert families_params['synthesizer_name'] == 'GaussianCopulaSynthesizer'
assert families_params['synthesizer_parameters'] == {
'default_distribution': 'uniform',
'enforce_min_max_values': True,
'enforce_rounding': True,
'locales': ['en_US'],
'numerical_distributions': {}
}
char_families_params = hmasynthesizer.get_table_parameters('character_families')
assert char_families_params['table_synthesizer'] == 'GaussianCopulaSynthesizer'
assert char_families_params['table_parameters'] == {
assert char_families_params['synthesizer_name'] == 'GaussianCopulaSynthesizer'
assert char_families_params['synthesizer_parameters'] == {
'default_distribution': 'norm',
'enforce_min_max_values': True,
'enforce_rounding': True,
Expand Down
10 changes: 5 additions & 5 deletions tests/unit/multi_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,8 +256,8 @@ def test_get_table_parameters_empty(self):

# Assert
assert result == {
'table_synthesizer': 'GaussianCopulaSynthesizer',
'table_parameters': {
'synthesizer_name': 'GaussianCopulaSynthesizer',
'synthesizer_parameters': {
'default_distribution': 'beta',
'enforce_min_max_values': True,
'enforce_rounding': True,
Expand All @@ -277,7 +277,7 @@ def test_get_table_parameters_has_parameters(self):
result = instance.get_table_parameters('oseba')

# Assert
assert result['table_parameters'] == {
assert result['synthesizer_parameters'] == {
'default_distribution': 'gamma',
'enforce_min_max_values': True,
'enforce_rounding': True,
Expand Down Expand Up @@ -314,8 +314,8 @@ def test_set_table_parameters(self):
# Assert
table_parameters = instance.get_table_parameters('oseba')
assert instance._table_parameters['oseba'] == {'default_distribution': 'gamma'}
assert table_parameters['table_synthesizer'] == 'GaussianCopulaSynthesizer'
assert table_parameters['table_parameters'] == {
assert table_parameters['synthesizer_name'] == 'GaussianCopulaSynthesizer'
assert table_parameters['synthesizer_parameters'] == {
'default_distribution': 'gamma',
'enforce_min_max_values': True,
'locales': ['en_US'],
Expand Down

0 comments on commit 6222922

Please sign in to comment.