Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unable visualize just the real data (or just the synthetic data) in a multi-table setting #2169

Merged
merged 3 commits into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions sdv/evaluation/multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ def get_column_plot(real_data, synthetic_data, metadata, table_name, column_name
1D marginal distribution plot (i.e. a histogram) of the columns.
"""
metadata = metadata.tables[table_name]
real_data = real_data[table_name]
synthetic_data = synthetic_data[table_name]
real_data = real_data[table_name] if real_data else None
synthetic_data = synthetic_data[table_name] if synthetic_data else None
return single_table_visualization.get_column_plot(
real_data,
synthetic_data,
Expand Down Expand Up @@ -118,8 +118,8 @@ def get_column_pair_plot(
2D bivariate distribution plot (i.e. a scatterplot) of the columns.
"""
metadata = metadata.tables[table_name]
real_data = real_data[table_name]
synthetic_data = synthetic_data[table_name]
real_data = real_data[table_name] if real_data else None
synthetic_data = synthetic_data[table_name] if synthetic_data else None
return single_table_visualization.get_column_pair_plot(
real_data, synthetic_data, metadata, column_names, sample_size, plot_type
)
Expand Down
46 changes: 45 additions & 1 deletion tests/unit/evaluation/test_multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,31 @@ def test_get_column_plot(mock_plot):
assert plot == 'plot'


@patch('sdv.evaluation.single_table.get_column_plot')
def test_get_column_plot_only_real_or_synthetic(mock_plot):
"""Test that ``get_column_plot`` works when only real or synthetic data is provided."""
# Setup
table1 = pd.DataFrame({'col': [1, 2, 3]})
data1 = {'table': table1}
metadata = MultiTableMetadata()
metadata.detect_table_from_dataframe('table', table1)
mock_plot.return_value = 'plot'

# Run
get_column_plot(data1, None, metadata, 'table', 'col')
get_column_plot(None, data1, metadata, 'table', 'col')

# Assert
call_metadata = metadata.tables['table']
mock_plot.assert_has_calls([
((table1, None, call_metadata, 'col', None), {}),
((None, table1, call_metadata, 'col', None), {}),
])


@patch('sdv.evaluation.single_table.get_column_pair_plot')
def test_get_column_pair_plot(mock_plot):
"""Test that ``get_column_pair`` plot is being called with the expected objects."""
"""Test that ``get_column_pair_plot`` is being called with the expected objects."""
# Setup
table1 = pd.DataFrame({'col1': [1, 2, 3], 'col2': [3, 2, 1]})
table2 = pd.DataFrame({'col1': [2, 1, 3], 'col2': [1, 2, 3]})
Expand All @@ -94,6 +116,28 @@ def test_get_column_pair_plot(mock_plot):
assert plot == 'plot'


@patch('sdv.evaluation.single_table.get_column_pair_plot')
def test_get_column_pair_plot_only_real_or_synthetic(mock_plot):
"""Test that ``get_column_pair_plot`` works when only real or synthetic data is provided."""
# Setup
table1 = pd.DataFrame({'col1': [1, 2, 3], 'col2': [3, 2, 1]})
data1 = {'table': table1}
metadata = MultiTableMetadata()
metadata.detect_table_from_dataframe('table', table1)
mock_plot.return_value = 'plot'

# Run
get_column_pair_plot(data1, None, metadata, 'table', ['col1', 'col2'], 2)
get_column_pair_plot(None, data1, metadata, 'table', ['col1', 'col2'], 2)

# Assert
call_metadata = metadata.tables['table']
mock_plot.assert_has_calls([
((table1, None, call_metadata, ['col1', 'col2'], None, 2), {}),
((None, table1, call_metadata, ['col1', 'col2'], None, 2), {}),
])


@patch('sdmetrics.visualization.get_cardinality_plot')
def test_get_cardinality_plot(mock_plot):
"""Test it calls ``get_column_cardinality_plot`` in sdmetrics with the parent primary key."""
Expand Down
Loading