Skip to content

Commit

Permalink
Add a utility function get_random_sequence_subset (#2098)
Browse files Browse the repository at this point in the history
  • Loading branch information
amontanez24 authored Jul 2, 2024
1 parent cb49b38 commit d8962df
Show file tree
Hide file tree
Showing 5 changed files with 282 additions and 5 deletions.
4 changes: 2 additions & 2 deletions sdv/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Utils module."""

from sdv.utils.utils import drop_unknown_references
from sdv.utils.utils import drop_unknown_references, get_random_sequence_subset

__all__ = ('drop_unknown_references',)
__all__ = ('drop_unknown_references', 'get_random_sequence_subset')
69 changes: 69 additions & 0 deletions sdv/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sys
from copy import deepcopy

import numpy as np
import pandas as pd

from sdv._utils import _validate_foreign_keys_not_null
Expand Down Expand Up @@ -60,3 +61,71 @@ def drop_unknown_references(data, metadata, drop_missing_values=True, verbose=Tr
sys.stdout.write('\n'.join([success_message, '', summary_table.to_string(index=False)]))

return result


def get_random_sequence_subset(
data,
metadata,
num_sequences,
max_sequence_length=None,
long_sequence_subsampling_method='first_rows',
):
"""Subsample sequential data based on a number of sequences.
Args:
data (pandas.DataFrame):
The sequential data.
metadata (SingleTableMetadata):
A SingleTableMetadata object describing the data.
num_sequences (int):
The number of sequences to subsample.
max_sequence_length (int):
The maximum length each subsampled sequence is allowed to be. Defaults to None. If
None, do not enforce any max length, meaning that entire sequences will be sampled.
If provided all subsampled sequences must be <= the provided length.
long_sequence_subsampling_method (str):
The method to use when a selected sequence is too long. Options are:
- first_rows (default): Keep the first n rows of the sequence, where n is the max
sequence length.
- last_rows: Keep the last n rows of the sequence, where n is the max sequence length.
- random: Randomly choose n rows to keep within the sequence. It is important to keep
the randomly chosen rows in the same order as they appear in the original data.
"""
if long_sequence_subsampling_method not in ['first_rows', 'last_rows', 'random']:
raise ValueError(
'long_sequence_subsampling_method must be one of "first_rows", "last_rows" or "random"'
)

sequence_key = metadata.sequence_key
if not sequence_key:
raise ValueError(
'Your metadata does not include a sequence key. A sequence key must be provided to '
'subset the sequential data.'
)

if sequence_key not in data.columns:
raise ValueError(
'Your provided sequence key is not in the data. This is required to get a subset.'
)

selected_sequences = np.random.permutation(data[sequence_key])[:num_sequences]
subset = data[data[sequence_key].isin(selected_sequences)].reset_index(drop=True)
if max_sequence_length:
grouped_sequences = subset.groupby(sequence_key)
if long_sequence_subsampling_method == 'first_rows':
return grouped_sequences.head(max_sequence_length).reset_index(drop=True)
elif long_sequence_subsampling_method == 'last_rows':
return grouped_sequences.tail(max_sequence_length).reset_index(drop=True)
else:
subsetted_sequences = []
for _, group in grouped_sequences:
if len(group) > max_sequence_length:
idx = np.random.permutation(len(group))[:max_sequence_length]
idx.sort()
subsetted_sequences.append(group.iloc[idx])
else:
subsetted_sequences.append(group)

return pd.concat(subsetted_sequences, ignore_index=True)

return subset
54 changes: 53 additions & 1 deletion tests/integration/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
import pandas as pd
import pytest

from sdv.datasets.demo import download_demo
from sdv.errors import InvalidDataError
from sdv.metadata import MultiTableMetadata
from sdv.utils import drop_unknown_references
from sdv.utils import drop_unknown_references, get_random_sequence_subset


@pytest.fixture
Expand Down Expand Up @@ -140,3 +141,54 @@ def test_drop_unknown_references_not_drop_missing_values(metadata, data):
pd.testing.assert_frame_equal(cleaned_data['child'], data['child'].iloc[:4])
assert pd.isna(cleaned_data['child']['parent_id']).any()
assert len(cleaned_data['child']) == 4


def test_get_random_sequence_subset():
"""Test that the sequences are subsetted and properly clipped."""
# Setup
data, metadata = download_demo(modality='sequential', dataset_name='nasdaq100_2019')

# Run
subset = get_random_sequence_subset(data, metadata, num_sequences=3, max_sequence_length=5)

# Assert
selected_sequences = subset[metadata.sequence_key].unique()
assert len(selected_sequences) == 3
for sequence_key in selected_sequences:
pd.testing.assert_frame_equal(
subset[subset[metadata.sequence_key] == sequence_key].reset_index(drop=True),
data[data[metadata.sequence_key] == sequence_key].head(5).reset_index(drop=True),
)


def test_get_random_sequence_subset_random_clipping():
"""Test that the sequences are subsetted and properly clipped.
If the long_sequence_sampling_method is set to 'random', the selected sequences should be
subsampled randomly, but maintain the same order.
"""
# Setup
data, metadata = download_demo(modality='sequential', dataset_name='nasdaq100_2019')

# Run
subset = get_random_sequence_subset(
data,
metadata,
num_sequences=3,
max_sequence_length=5,
long_sequence_subsampling_method='random',
)

# Assert
selected_sequences = subset[metadata.sequence_key].unique()
assert len(selected_sequences) == 3
for sequence_key in selected_sequences:
selected_sequence = subset[subset[metadata.sequence_key] == sequence_key]
assert len(selected_sequence) <= 5
subset_data = data[
data['Date'].isin(selected_sequence['Date'])
& data['Symbol'].isin(selected_sequence['Symbol'])
]
pd.testing.assert_frame_equal(
subset_data.reset_index(drop=True), selected_sequence.reset_index(drop=True)
)
6 changes: 5 additions & 1 deletion tests/unit/utils/test_poc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@
from sdv.errors import InvalidDataError
from sdv.metadata import MultiTableMetadata
from sdv.metadata.errors import InvalidMetadataError
from sdv.utils.poc import drop_unknown_references, get_random_subset, simplify_schema
from sdv.utils.poc import (
drop_unknown_references,
get_random_subset,
simplify_schema,
)


@patch('sdv.utils.poc.utils_drop_unknown_references')
Expand Down
154 changes: 153 additions & 1 deletion tests/unit/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import pytest

from sdv.errors import InvalidDataError
from sdv.utils.utils import drop_unknown_references
from sdv.metadata import SingleTableMetadata
from sdv.utils.utils import drop_unknown_references, get_random_sequence_subset


@patch('sdv.utils.utils._drop_rows')
Expand Down Expand Up @@ -355,3 +356,154 @@ def test_drop_unknown_references_drop_all_rows(mock_get_rows_to_drop):
)
with pytest.raises(InvalidDataError, match=expected_message):
drop_unknown_references(data, metadata)


def test_get_random_sequence_subset_no_sequence_key():
"""Test that an error is raised if no sequence_key is provided in the metadata."""
# Setup
metadata = Mock(spec=SingleTableMetadata)
metadata.sequence_key = None

# Run and Assert
error_message = (
'Your metadata does not include a sequence key. A sequence key must be provided to subset'
' the sequential data.'
)
with pytest.raises(ValueError, match=error_message):
get_random_sequence_subset(pd.DataFrame(), metadata, 3)


def test_get_random_sequence_subset_sequence_key_not_in_data():
"""Test that an error is raised if the data doesn't contain the sequence_key."""
# Setup
metadata = Mock(spec=SingleTableMetadata)
metadata.sequence_key = 'key'

# Run and Assert
error_message = (
'Your provided sequence key is not in the data. This is required to get a subset.'
)
with pytest.raises(ValueError, match=error_message):
get_random_sequence_subset(pd.DataFrame(), metadata, 3)


def test_get_random_sequence_subset_bad_long_sequence_subsampling_method():
"""Test that an error is raised if the long_sequence_subsampling_method is invalid."""
# Setup
metadata = Mock(spec=SingleTableMetadata)
metadata.sequence_key = 'key'

# Run and Assert
error_message = (
'long_sequence_subsampling_method must be one of "first_rows", "last_rows" or "random"'
)
with pytest.raises(ValueError, match=error_message):
get_random_sequence_subset(pd.DataFrame(), metadata, 3, 10, 'blah')


@patch('sdv.utils.utils.np')
def test_get_random_sequence_subset_no_max_sequence_length(mock_np):
"""Test that the sequences are subsetted but each sequence is full."""
# Setup
data = pd.DataFrame({'key': ['a'] * 10 + ['b'] * 7 + ['c'] * 9 + ['d'] * 4, 'value': range(30)})
metadata = Mock(spec=SingleTableMetadata)
metadata.sequence_key = 'key'
mock_np.random.permutation.return_value = np.array(['a', 'd'])

# Run
subset = get_random_sequence_subset(data, metadata, num_sequences=2)

# Assert
expected = pd.DataFrame({
'key': ['a'] * 10 + ['d'] * 4,
'value': list(range(10)) + [26, 27, 28, 29],
})
pd.testing.assert_frame_equal(expected, subset)


@patch('sdv.utils.utils.np')
def test_get_random_sequence_subset_use_first_rows(mock_np):
"""Test that the sequences are subsetted and subsampled properly.
If 'long_sequence_subsampling_method' isn't set, the sequences should be clipped using the
first 'max_sequence_length' rows.
"""
# Setup
data = pd.DataFrame({'key': ['a'] * 10 + ['b'] * 7 + ['c'] * 9 + ['d'] * 4, 'value': range(30)})
metadata = Mock(spec=SingleTableMetadata)
metadata.sequence_key = 'key'
mock_np.random.permutation.return_value = np.array(['a', 'b', 'd'])

# Run
subset = get_random_sequence_subset(data, metadata, num_sequences=3, max_sequence_length=6)

# Assert
expected = pd.DataFrame({
'key': ['a'] * 6 + ['b'] * 6 + ['d'] * 4,
'value': [0, 1, 2, 3, 4, 5, 10, 11, 12, 13, 14, 15, 26, 27, 28, 29],
})
pd.testing.assert_frame_equal(expected, subset)


@patch('sdv.utils.utils.np')
def test_get_random_sequence_subset_use_last_rows(mock_np):
"""Test that the sequences are subsetted and subsampled properly.
If 'long_sequence_subsampling_method' isn't set, the sequences should be clipped using the
last 'max_sequence_length' rows.
"""
# Setup
data = pd.DataFrame({'key': ['a'] * 10 + ['b'] * 7 + ['c'] * 9 + ['d'] * 4, 'value': range(30)})
metadata = Mock(spec=SingleTableMetadata)
metadata.sequence_key = 'key'
mock_np.random.permutation.return_value = np.array(['a', 'b', 'd'])

# Run
subset = get_random_sequence_subset(
data,
metadata,
num_sequences=3,
max_sequence_length=6,
long_sequence_subsampling_method='last_rows',
)

# Assert
expected = pd.DataFrame({
'key': ['a'] * 6 + ['b'] * 6 + ['d'] * 4,
'value': [4, 5, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16, 26, 27, 28, 29],
})
pd.testing.assert_frame_equal(expected, subset)


@patch('sdv.utils.utils.np')
def test_get_random_sequence_subset_use_random_rows(mock_np):
"""Test that the sequences are subsetted and subsampled properly.
If 'long_sequence_subsampling_method' isn't set, the sequences should be clipped using random
'max_sequence_length' rows.
"""
# Setup
data = pd.DataFrame({'key': ['a'] * 10 + ['b'] * 7 + ['c'] * 9 + ['d'] * 4, 'value': range(30)})
metadata = Mock(spec=SingleTableMetadata)
metadata.sequence_key = 'key'
mock_np.random.permutation.side_effect = [
np.array(['a', 'b', 'd']),
np.array([0, 2, 4, 5, 7, 9]),
np.array([6, 5, 1, 2, 4, 0]),
]

# Run
subset = get_random_sequence_subset(
data,
metadata,
num_sequences=3,
max_sequence_length=6,
long_sequence_subsampling_method='random',
)

# Assert
expected = pd.DataFrame({
'key': ['a'] * 6 + ['b'] * 6 + ['d'] * 4,
'value': [0, 2, 4, 5, 7, 9, 10, 11, 12, 14, 15, 16, 26, 27, 28, 29],
})
pd.testing.assert_frame_equal(expected, subset)

0 comments on commit d8962df

Please sign in to comment.