Skip to content

Commit

Permalink
Independent Sampler always create a sample size of 1 no matter how sm…
Browse files Browse the repository at this point in the history
…all the scale (#2102)
  • Loading branch information
lajohn4747 authored Jul 3, 2024
1 parent 462812d commit 6a57b3f
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 0 deletions.
12 changes: 12 additions & 0 deletions sdv/sampling/independent_sampler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Independent Samplers."""

import logging
import warnings

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -144,8 +145,12 @@ def _sample(self, scale=1.0):
sampled data tables as ``pandas.DataFrame``.
"""
sampled_data = {}
send_min_sample_warning = False
for table in self.metadata.tables:
num_rows = int(self._table_sizes[table] * scale)
if num_rows <= 0:
send_min_sample_warning = True
num_rows = 1
synthesizer = self._table_synthesizers[table]
self._sample_table(
synthesizer=synthesizer,
Expand All @@ -154,5 +159,12 @@ def _sample(self, scale=1.0):
sampled_data=sampled_data,
)

if send_min_sample_warning:
warn_msg = (
"The 'scale' parameter is too small. Some tables may have 1 row."
' For better quality data, please choose a larger scale.'
)
warnings.warn(warn_msg)

self._connect_tables(sampled_data)
return self._finalize(sampled_data)
34 changes: 34 additions & 0 deletions tests/unit/sampling/test_independent_sampler.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
from unittest.mock import Mock, call, patch

import numpy as np
Expand Down Expand Up @@ -404,3 +405,36 @@ def _connect_tables(sampled_data):
'transactions': DataFrameMatcher(connected_transactions),
})
assert result == instance._finalize.return_value

def test__sample_too_small(self):
instance = Mock()
metadata = Mock()
metadata.tables = {
'guests': {
'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1',
'primary_key': 'key_1',
'columns': {
'key_1': {
'sdtype': 'numerical',
}
},
},
'hotels': {
'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1',
'primary_key': 'key_2',
'columns': {
'key_2': {
'sdtype': 'numerical',
}
},
},
}
instance._table_sizes = {'hotels': 10, 'guests': 658}
instance._table_synthesizers = {'hotels': Mock(), 'guests': Mock()}
instance.metadata = metadata
warning_msg = re.escape(
"The 'scale' parameter is too small. Some tables may have 1 row."
' For better quality data, please choose a larger scale.'
)
with pytest.warns(Warning, match=warning_msg):
BaseIndependentSampler._sample(instance, 0.01)

0 comments on commit 6a57b3f

Please sign in to comment.