Skip to content

Commit

Permalink
✅ tests: update tsdownsample tests to support nan downsamplers
Browse files Browse the repository at this point in the history
  • Loading branch information
NielsPraet committed Aug 3, 2023
1 parent acf1ff3 commit 585a889
Showing 1 changed file with 33 additions and 11 deletions.
44 changes: 33 additions & 11 deletions tests/test_tsdownsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@

import numpy as np
import pytest
from test_config import supported_dtypes_x, supported_dtypes_y
from test_config import supported_dtypes_x, supported_dtypes_y, supported_dtypes_y_nan

from tsdownsample import ( # MeanDownsampler,; MedianDownsampler,
EveryNthDownsampler,
LTTBDownsampler,
M4Downsampler,
MinMaxDownsampler,
MinMaxLTTBDownsampler,
NaNM4Downsampler,
NanMinMaxDownsampler,
NaNMinMaxLTTBDownsampler,
)
from tsdownsample.downsampling_interface import AbstractDownsampler

Expand All @@ -24,18 +27,32 @@
MinMaxLTTBDownsampler(),
]

OTHER_DOWNSAMPLERS = [EveryNthDownsampler()]
# Instances of the NaN-named downsamplers imported from tsdownsample.
# is_nan_downsampler() matches objects against the class names of these
# instances, and the generate_*_downsamplers() helpers append this list to
# RUST_DOWNSAMPLERS when parametrizing tests.
RUST_NAN_DOWNSAMPLERS = [
NanMinMaxDownsampler(),
NaNM4Downsampler(),
NaNMinMaxLTTBDownsampler()
]

OTHER_DOWNSAMPLERS = [EveryNthDownsampler()]

def generate_rust_downsamplers() -> Iterable[AbstractDownsampler]:
for downsampler in RUST_DOWNSAMPLERS:
for downsampler in RUST_DOWNSAMPLERS + RUST_NAN_DOWNSAMPLERS:
yield downsampler


def generate_all_downsamplers() -> Iterable[AbstractDownsampler]:
for downsampler in RUST_DOWNSAMPLERS + OTHER_DOWNSAMPLERS:
for downsampler in RUST_DOWNSAMPLERS + RUST_NAN_DOWNSAMPLERS + OTHER_DOWNSAMPLERS:
yield downsampler

def is_nan_downsampler(obj):
    """Tell whether ``obj`` shares a class name with a RUST_NAN_DOWNSAMPLERS entry.

    The comparison is made on the class *name* (a string), not on the class
    object itself, so any object whose class carries one of those names
    matches.

    Returns:
        bool: True if the class name of ``obj`` equals the class name of at
        least one instance in ``RUST_NAN_DOWNSAMPLERS``, False otherwise.
    """
    candidate_name = obj.__class__.__name__
    return any(
        candidate_name == nan_downsampler.__class__.__name__
        for nan_downsampler in RUST_NAN_DOWNSAMPLERS
    )

def generate_datapoints(obj):
    """Build the 10_000-point ``np.arange`` array used as test input for ``obj``.

    When ``is_nan_downsampler(obj)`` is true the array is created with
    ``dtype=np.float64``; otherwise ``np.arange`` picks its default dtype.

    Returns:
        np.ndarray: the values ``0 .. 9_999`` in increasing order.
    """
    n_datapoints = 10_000
    # dtype=None is np.arange's default, i.e. identical to omitting the argument.
    dtype = np.float64 if is_nan_downsampler(obj) else None
    return np.arange(n_datapoints, dtype=dtype)

@pytest.mark.parametrize("downsampler", generate_all_downsamplers())
def test_serialization_copy(downsampler: AbstractDownsampler):
Expand All @@ -45,7 +62,8 @@ def test_serialization_copy(downsampler: AbstractDownsampler):
dc = copy(downsampler)
ddc = deepcopy(downsampler)

arr = np.arange(10_000)
arr = generate_datapoints(downsampler)

orig_downsampled = downsampler.downsample(arr, n_out=100)
dc_downsampled = dc.downsample(arr, n_out=100)
ddc_downsampled = ddc.downsample(arr, n_out=100)
Expand All @@ -60,7 +78,7 @@ def test_serialization_pickle(downsampler: AbstractDownsampler):

dc = pickle.loads(pickle.dumps(downsampler))

arr = np.arange(10_000)
arr = generate_datapoints(downsampler)
orig_downsampled = downsampler.downsample(arr, n_out=100)
dc_downsampled = dc.downsample(arr, n_out=100)
assert np.all(orig_downsampled == dc_downsampled)
Expand All @@ -69,7 +87,7 @@ def test_serialization_pickle(downsampler: AbstractDownsampler):
@pytest.mark.parametrize("downsampler", generate_rust_downsamplers())
def test_rust_downsampler(downsampler: AbstractDownsampler):
"""Test the Rust downsamplers."""
arr = np.arange(10_000)
arr = generate_datapoints(downsampler)
s_downsampled = downsampler.downsample(arr, n_out=100)
assert s_downsampled[0] == 0
assert s_downsampled[-1] == len(arr) - 1
Expand Down Expand Up @@ -134,7 +152,8 @@ def test_downsampling_different_dtypes(downsampler: AbstractDownsampler):
"""Test downsampling with different data types."""
arr_orig = np.random.randint(0, 100, size=10_000)
res = []
for dtype_y in supported_dtypes_y:
y_dtypes = supported_dtypes_y_nan if is_nan_downsampler(downsampler) else supported_dtypes_y
for dtype_y in y_dtypes:
arr = arr_orig.astype(dtype_y)
s_downsampled = downsampler.downsample(arr, n_out=100)
if dtype_y is not np.bool_:
Expand All @@ -148,10 +167,11 @@ def test_downsampling_different_dtypes_with_x(downsampler: AbstractDownsampler):
"""Test downsampling with x with different data types."""
arr_orig = np.random.randint(0, 100, size=10_000)
idx_orig = np.arange(len(arr_orig))
y_dtypes = supported_dtypes_y_nan if is_nan_downsampler(downsampler) else supported_dtypes_y
for dtype_x in supported_dtypes_x:
res = []
idx = idx_orig.astype(dtype_x)
for dtype_y in supported_dtypes_y:
for dtype_y in y_dtypes:
arr = arr_orig.astype(dtype_y)
s_downsampled = downsampler.downsample(idx, arr, n_out=100)
if dtype_y is not np.bool_:
Expand All @@ -167,7 +187,8 @@ def test_downsampling_no_out_of_bounds_different_dtypes(
"""Test no out of bounds issues when downsampling with different data types."""
arr_orig = np.random.randint(0, 100, size=100)
res = []
for dtype in supported_dtypes_y:
y_dtypes = supported_dtypes_y_nan if is_nan_downsampler(downsampler) else supported_dtypes_y
for dtype in y_dtypes:
arr = arr_orig.astype(dtype)
s_downsampled = downsampler.downsample(arr, n_out=76)
s_downsampled_p = downsampler.downsample(arr, n_out=76, n_threads=2)
Expand All @@ -185,10 +206,11 @@ def test_downsampling_no_out_of_bounds_different_dtypes_with_x(
"""Test no out of bounds issues when downsampling with different data types."""
arr_orig = np.random.randint(0, 100, size=100)
idx_orig = np.arange(len(arr_orig))
y_dtypes = supported_dtypes_y_nan if is_nan_downsampler(downsampler) else supported_dtypes_y
for dtype_x in supported_dtypes_x:
res = []
idx = idx_orig.astype(dtype_x)
for dtype_y in supported_dtypes_y:
for dtype_y in y_dtypes:
arr = arr_orig.astype(dtype_y)
s_downsampled = downsampler.downsample(idx, arr, n_out=76)
s_downsampled_p = downsampler.downsample(idx, arr, n_out=76, n_threads=2)
Expand Down

0 comments on commit 585a889

Please sign in to comment.