Skip to content

Commit

Permalink
Fixes #79
Browse files Browse the repository at this point in the history
  • Loading branch information
ml31415 committed Oct 3, 2023
1 parent c292eee commit 162fdb9
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 4 deletions.
3 changes: 3 additions & 0 deletions numpy_groupies/aggregate_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ def _nancumsum(group_idx, a, size, fill_value=None, dtype=None):
generic=_generic_callable,
)
_impl_dict.update(("nan" + k, v) for k, v in list(_impl_dict.items()) if k not in funcs_no_separate_nan)
_impl_dict["nancumsum"] = _nancumsum


def _aggregate_base(
Expand Down Expand Up @@ -308,6 +309,8 @@ def _aggregate_base(
if "nan" in func:
if "arg" in func:
kwargs["_nansqueeze"] = True
elif "cum" in func:
pass
else:
good = ~np.isnan(a)
if "len" not in func or is_pandas:
Expand Down
1 change: 0 additions & 1 deletion numpy_groupies/tests/test_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
may throw NotImplementedError in order to show missing functionality without throwing
test errors.
"""
import sys
from itertools import product

import numpy as np
Expand Down
17 changes: 17 additions & 0 deletions numpy_groupies/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ def _deselect_purepy(aggregate_all, *args, **kwargs):
return aggregate_all.__name__.endswith("purepy")


def _deselect_purepy_and_pandas(aggregate_all, *args, **kwargs):
# purepy and pandas implementation handle some nan cases differently.
# So they need to be excluded from several tests."""
return aggregate_all.__name__.endswith(("pandas", "purepy"))


def _deselect_purepy_and_invalid_axis(aggregate_all, size, axis, *args, **kwargs):
if axis >= len(size):
return True
Expand Down Expand Up @@ -358,6 +364,17 @@ def test_cumsum(aggregate_all):
np.testing.assert_array_equal(res, ref)


@pytest.mark.deselect_if(func=_deselect_purepy_and_pandas)
def test_nancumsum(aggregate_all):
# https://github.com/ml31415/numpy-groupies/issues/79
group_idx = [0, 0, 0, 1, 1, 0, 0]
a = [2, 2, np.nan, 2, 2, 2, 2]
ref = [2., 4., 4., 2., 4., 6., 8.]

res = aggregate_all(group_idx, a, func="nancumsum")
np.testing.assert_array_equal(res, ref)


def test_cummax(aggregate_all):
group_idx = np.array([4, 3, 3, 4, 4, 1, 1, 1, 7, 8, 7, 4, 3, 3, 1, 1])
a = np.array([3, 4, 1, 3, 9, 9, 6, 7, 7, 0, 8, 2, 1, 8, 9, 8])
Expand Down
7 changes: 4 additions & 3 deletions numpy_groupies/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@
np.array: "array",
np.asarray: "array",
np.sort: "sort",
np.cumsum: "cumsum",
np.cumprod: "cumprod",
np.nansum: "nansum",
np.nanprod: "nanprod",
np.nanmean: "nanmean",
Expand All @@ -126,8 +128,7 @@
np.nanstd: "nanstd",
np.nanargmax: "nanargmax",
np.nanargmin: "nanargmin",
np.cumsum: "cumsum",
np.cumprod: "cumprod",
np.nancumsum: "nancumsum",
}


Expand All @@ -150,7 +151,7 @@ def get_aliasing(*extra):
alias.update((k, k) for k in set(alias.values()))
# Treat nan-functions as firstclass member and add them directly
for key in set(alias.values()):
if key not in funcs_no_separate_nan:
if key not in funcs_no_separate_nan and not key.startswith("nan"):
key = "nan" + key
alias[key] = key
return alias
Expand Down

0 comments on commit 162fdb9

Please sign in to comment.