Skip to content

Commit

Permalink
REF: de-duplicate groupby_helper code (#28934)
Browse files (browse the repository at this point in the history)
  • Loading branch information
jbrockmendel authored and WillAyd committed Oct 16, 2019
1 parent b63f829 commit bff90a3
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 88 deletions.
3 changes: 2 additions & 1 deletion pandas/_libs/groupby.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,8 @@ def group_any_all(uint8_t[:] out,
const uint8_t[:] mask,
object val_test,
bint skipna):
"""Aggregated boolean values to show truthfulness of group elements
"""
Aggregated boolean values to show truthfulness of group elements.
Parameters
----------
Expand Down
139 changes: 52 additions & 87 deletions pandas/_libs/groupby_helper.pxi.in
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,18 @@ ctypedef fused rank_t:
object


cdef inline bint _treat_as_na(rank_t val, bint is_datetimelike) nogil:
    """
    Report whether ``val`` should be treated as a missing value.

    The check is resolved per fused-type specialization of ``rank_t``:

    * ``int64_t``: missing only when ``is_datetimelike`` is set and ``val``
      equals the ``NPY_NAT`` sentinel.
    * ``object``: not supported; raises ``NotImplementedError``.
    * any other specialization: missing when ``val`` does not compare equal
      to itself.
    """
    if rank_t is int64_t:
        # Only the NPY_NAT sentinel of datetimelike data counts as missing
        # for this specialization.
        return is_datetimelike and val == NPY_NAT

    elif rank_t is object:
        # Should never be used.  The explicit branch exists so that the
        # `val != val` comparison below is not emitted for the object
        # specialization, where cython would raise about gil acquisition.
        raise NotImplementedError

    else:
        # Self-inequality test for the remaining specializations.
        return val != val


@cython.wraparound(False)
@cython.boundscheck(False)
def group_last(rank_t[:, :] out,
Expand Down Expand Up @@ -61,24 +73,16 @@ def group_last(rank_t[:, :] out,
for j in range(K):
val = values[i, j]

# not nan
if rank_t is int64_t:
# need a special notna check
if val != NPY_NAT:
nobs[lab, j] += 1
resx[lab, j] = val
else:
if val == val:
nobs[lab, j] += 1
resx[lab, j] = val
if val == val:
# NB: use _treat_as_na here once
# conditional-nogil is available.
nobs[lab, j] += 1
resx[lab, j] = val

for i in range(ncounts):
for j in range(K):
if nobs[i, j] == 0:
if rank_t is int64_t:
out[i, j] = NPY_NAT
else:
out[i, j] = NAN
out[i, j] = NAN
else:
out[i, j] = resx[i, j]
else:
Expand All @@ -92,16 +96,10 @@ def group_last(rank_t[:, :] out,
for j in range(K):
val = values[i, j]

# not nan
if rank_t is int64_t:
# need a special notna check
if val != NPY_NAT:
nobs[lab, j] += 1
resx[lab, j] = val
else:
if val == val:
nobs[lab, j] += 1
resx[lab, j] = val
if not _treat_as_na(val, True):
# TODO: Sure we always want is_datetimelike=True?
nobs[lab, j] += 1
resx[lab, j] = val

for i in range(ncounts):
for j in range(K):
Expand All @@ -113,6 +111,7 @@ def group_last(rank_t[:, :] out,
break
else:
out[i, j] = NAN

else:
out[i, j] = resx[i, j]

Expand All @@ -121,7 +120,6 @@ def group_last(rank_t[:, :] out,
# block.
raise RuntimeError("empty group with uint64_t")


group_last_float64 = group_last["float64_t"]
group_last_float32 = group_last["float32_t"]
group_last_int64 = group_last["int64_t"]
Expand Down Expand Up @@ -169,8 +167,9 @@ def group_nth(rank_t[:, :] out,
for j in range(K):
val = values[i, j]

# not nan
if val == val:
# NB: use _treat_as_na here once
# conditional-nogil is available.
nobs[lab, j] += 1
if nobs[lab, j] == rank:
resx[lab, j] = val
Expand All @@ -193,18 +192,11 @@ def group_nth(rank_t[:, :] out,
for j in range(K):
val = values[i, j]

# not nan
if rank_t is int64_t:
# need a special notna check
if val != NPY_NAT:
nobs[lab, j] += 1
if nobs[lab, j] == rank:
resx[lab, j] = val
else:
if val == val:
nobs[lab, j] += 1
if nobs[lab, j] == rank:
resx[lab, j] = val
if not _treat_as_na(val, True):
# TODO: Sure we always want is_datetimelike=True?
nobs[lab, j] += 1
if nobs[lab, j] == rank:
resx[lab, j] = val

for i in range(ncounts):
for j in range(K):
Expand Down Expand Up @@ -487,17 +479,11 @@ def group_max(groupby_t[:, :] out,
for j in range(K):
val = values[i, j]

# not nan
if groupby_t is int64_t:
if val != nan_val:
nobs[lab, j] += 1
if val > maxx[lab, j]:
maxx[lab, j] = val
else:
if val == val:
nobs[lab, j] += 1
if val > maxx[lab, j]:
maxx[lab, j] = val
if not _treat_as_na(val, True):
# TODO: Sure we always want is_datetimelike=True?
nobs[lab, j] += 1
if val > maxx[lab, j]:
maxx[lab, j] = val

for i in range(ncounts):
for j in range(K):
Expand Down Expand Up @@ -563,17 +549,11 @@ def group_min(groupby_t[:, :] out,
for j in range(K):
val = values[i, j]

# not nan
if groupby_t is int64_t:
if val != nan_val:
nobs[lab, j] += 1
if val < minx[lab, j]:
minx[lab, j] = val
else:
if val == val:
nobs[lab, j] += 1
if val < minx[lab, j]:
minx[lab, j] = val
if not _treat_as_na(val, True):
# TODO: Sure we always want is_datetimelike=True?
nobs[lab, j] += 1
if val < minx[lab, j]:
minx[lab, j] = val

for i in range(ncounts):
for j in range(K):
Expand Down Expand Up @@ -643,21 +623,13 @@ def group_cummin(groupby_t[:, :] out,
for j in range(K):
val = values[i, j]

# val = nan
if groupby_t is int64_t:
if is_datetimelike and val == NPY_NAT:
out[i, j] = NPY_NAT
else:
mval = accum[lab, j]
if val < mval:
accum[lab, j] = mval = val
out[i, j] = mval
if _treat_as_na(val, is_datetimelike):
out[i, j] = val
else:
if val == val:
mval = accum[lab, j]
if val < mval:
accum[lab, j] = mval = val
out[i, j] = mval
mval = accum[lab, j]
if val < mval:
accum[lab, j] = mval = val
out[i, j] = mval


@cython.boundscheck(False)
Expand Down Expand Up @@ -712,17 +684,10 @@ def group_cummax(groupby_t[:, :] out,
for j in range(K):
val = values[i, j]

if groupby_t is int64_t:
if is_datetimelike and val == NPY_NAT:
out[i, j] = NPY_NAT
else:
mval = accum[lab, j]
if val > mval:
accum[lab, j] = mval = val
out[i, j] = mval
if _treat_as_na(val, is_datetimelike):
out[i, j] = val
else:
if val == val:
mval = accum[lab, j]
if val > mval:
accum[lab, j] = mval = val
out[i, j] = mval
mval = accum[lab, j]
if val > mval:
accum[lab, j] = mval = val
out[i, j] = mval

0 comments on commit bff90a3

Please sign in to comment.