diff --git a/sdv/sampling/independent_sampler.py b/sdv/sampling/independent_sampler.py index 46bdc6d41..a86b0ddc0 100644 --- a/sdv/sampling/independent_sampler.py +++ b/sdv/sampling/independent_sampler.py @@ -1,6 +1,7 @@ """Independent Samplers.""" import logging +import warnings LOGGER = logging.getLogger(__name__) @@ -144,8 +145,12 @@ def _sample(self, scale=1.0): sampled data tables as ``pandas.DataFrame``. """ sampled_data = {} + send_min_sample_warning = False for table in self.metadata.tables: num_rows = int(self._table_sizes[table] * scale) + if num_rows <= 0: + send_min_sample_warning = True + num_rows = 1 synthesizer = self._table_synthesizers[table] self._sample_table( synthesizer=synthesizer, @@ -154,5 +159,12 @@ def _sample(self, scale=1.0): sampled_data=sampled_data, ) + if send_min_sample_warning: + warn_msg = ( + "The 'scale' parameter is too small. Some tables may have 1 row." + ' For better quality data, please choose a larger scale.' + ) + warnings.warn(warn_msg) + self._connect_tables(sampled_data) return self._finalize(sampled_data) diff --git a/tests/unit/sampling/test_independent_sampler.py b/tests/unit/sampling/test_independent_sampler.py index f45215ba2..2cda624a6 100644 --- a/tests/unit/sampling/test_independent_sampler.py +++ b/tests/unit/sampling/test_independent_sampler.py @@ -1,3 +1,4 @@ +import re from unittest.mock import Mock, call, patch import numpy as np @@ -404,3 +405,36 @@ def _connect_tables(sampled_data): 'transactions': DataFrameMatcher(connected_transactions), }) assert result == instance._finalize.return_value + + def test__sample_too_small(self): + instance = Mock() + metadata = Mock() + metadata.tables = { + 'guests': { + 'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1', + 'primary_key': 'key_1', + 'columns': { + 'key_1': { + 'sdtype': 'numerical', + } + }, + }, + 'hotels': { + 'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1', + 'primary_key': 'key_2', + 'columns': { + 'key_2': { + 'sdtype': 'numerical', + } + }, + }, + } + instance._table_sizes = {'hotels': 10, 'guests': 658} + instance._table_synthesizers = {'hotels': Mock(), 'guests': Mock()} + instance.metadata = metadata + warning_msg = re.escape( + "The 'scale' parameter is too small. Some tables may have 1 row." + ' For better quality data, please choose a larger scale.' + ) + with pytest.warns(Warning, match=warning_msg): + BaseIndependentSampler._sample(instance, 0.01)