Cythonized GroupBy Quantile (#20405)
WillAyd authored and jreback committed Feb 28, 2019
1 parent e52f063 commit 64e5612
Showing 7 changed files with 258 additions and 19 deletions.
7 changes: 4 additions & 3 deletions asv_bench/benchmarks/groupby.py
@@ -14,7 +14,7 @@
method_blacklist = {
'object': {'median', 'prod', 'sem', 'cumsum', 'sum', 'cummin', 'mean',
'max', 'skew', 'cumprod', 'cummax', 'rank', 'pct_change', 'min',
'var', 'mad', 'describe', 'std'},
'var', 'mad', 'describe', 'std', 'quantile'},
'datetime': {'median', 'prod', 'sem', 'cumsum', 'sum', 'mean', 'skew',
'cumprod', 'cummax', 'pct_change', 'var', 'mad', 'describe',
'std'}
@@ -316,8 +316,9 @@ class GroupByMethods(object):
['all', 'any', 'bfill', 'count', 'cumcount', 'cummax', 'cummin',
'cumprod', 'cumsum', 'describe', 'ffill', 'first', 'head',
'last', 'mad', 'max', 'min', 'median', 'mean', 'nunique',
'pct_change', 'prod', 'rank', 'sem', 'shift', 'size', 'skew',
'std', 'sum', 'tail', 'unique', 'value_counts', 'var'],
'pct_change', 'prod', 'quantile', 'rank', 'sem', 'shift',
'size', 'skew', 'std', 'sum', 'tail', 'unique', 'value_counts',
'var'],
['direct', 'transformation']]

def setup(self, dtype, method, application):
1 change: 1 addition & 0 deletions doc/source/whatsnew/v0.25.0.rst
@@ -112,6 +112,7 @@ Performance Improvements
- :meth:`DataFrame.to_stata` is now faster when outputting data with any string or non-native endian columns (:issue:`25045`)
- Improved performance of :meth:`Series.searchsorted`. The speedup is especially large when the dtype is
int8/int16/int32 and the searched key is within the integer bounds for the dtype (:issue:`22034`)
- Improved performance of :meth:`pandas.core.groupby.GroupBy.quantile` (:issue:`20405`)


.. _whatsnew_0250.bug_fixes:
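A rough way to observe the improvement the whatsnew entry above refers to is to time the new cythonized path against a per-group Python fallback. This is only a sketch with made-up sizes; absolute numbers depend on machine and build:

import timeit

import numpy as np
import pandas as pd

df = pd.DataFrame({'key': np.random.randint(0, 1000, size=1000000),
                   'val': np.random.randn(1000000)})
grouped = df.groupby('key')

# Cythonized kernel added by this commit
print(timeit.timeit(lambda: grouped.quantile(0.75), number=5))

# Per-group Python path, for comparison
print(timeit.timeit(lambda: grouped.agg(lambda x: x.quantile(0.75)),
                    number=5))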
6 changes: 6 additions & 0 deletions pandas/_libs/groupby.pxd
@@ -0,0 +1,6 @@
cdef enum InterpolationEnumType:
INTERPOLATION_LINEAR,
INTERPOLATION_LOWER,
INTERPOLATION_HIGHER,
INTERPOLATION_NEAREST,
INTERPOLATION_MIDPOINT
101 changes: 101 additions & 0 deletions pandas/_libs/groupby.pyx
@@ -644,5 +644,106 @@ def _group_ohlc(floating[:, :] out,
group_ohlc_float32 = _group_ohlc['float']
group_ohlc_float64 = _group_ohlc['double']


@cython.boundscheck(False)
@cython.wraparound(False)
def group_quantile(ndarray[float64_t] out,
ndarray[int64_t] labels,
numeric[:] values,
ndarray[uint8_t] mask,
float64_t q,
object interpolation):
"""
Calculate the quantile per group.
Parameters
----------
out : ndarray
Array of aggregated values that will be written to.
labels : ndarray
Array containing the unique group labels.
values : ndarray
Array containing the values to apply the function against.
q : float
The quantile value to search for.
Notes
-----
Rather than explicitly returning a value, this function modifies the
provided `out` parameter.
"""
cdef:
Py_ssize_t i, N=len(labels), ngroups, grp_sz, non_na_sz
Py_ssize_t grp_start=0, idx=0
int64_t lab
uint8_t interp
float64_t q_idx, frac, val, next_val
ndarray[int64_t] counts, non_na_counts, sort_arr

assert values.shape[0] == N
inter_methods = {
'linear': INTERPOLATION_LINEAR,
'lower': INTERPOLATION_LOWER,
'higher': INTERPOLATION_HIGHER,
'nearest': INTERPOLATION_NEAREST,
'midpoint': INTERPOLATION_MIDPOINT,
}
interp = inter_methods[interpolation]

counts = np.zeros_like(out, dtype=np.int64)
non_na_counts = np.zeros_like(out, dtype=np.int64)
ngroups = len(counts)

# First figure out the size of every group
with nogil:
for i in range(N):
lab = labels[i]
counts[lab] += 1
if not mask[i]:
non_na_counts[lab] += 1

# Get an index of values sorted by labels and then values
order = (values, labels)
sort_arr = np.lexsort(order).astype(np.int64, copy=False)

with nogil:
for i in range(ngroups):
# Figure out how many group elements there are
grp_sz = counts[i]
non_na_sz = non_na_counts[i]

if non_na_sz == 0:
out[i] = NaN
else:
# Calculate where to retrieve the desired value
# Casting to int will intentionally truncate the result
idx = grp_start + <int64_t>(q * <float64_t>(non_na_sz - 1))

val = values[sort_arr[idx]]
# If requested quantile falls evenly on a particular index
# then write that index's value out. Otherwise interpolate
q_idx = q * (non_na_sz - 1)
frac = q_idx % 1

if frac == 0.0 or interp == INTERPOLATION_LOWER:
out[i] = val
else:
next_val = values[sort_arr[idx + 1]]
if interp == INTERPOLATION_LINEAR:
out[i] = val + (next_val - val) * frac
elif interp == INTERPOLATION_HIGHER:
out[i] = next_val
elif interp == INTERPOLATION_MIDPOINT:
out[i] = (val + next_val) / 2.0
elif interp == INTERPOLATION_NEAREST:
if frac > .5 or (frac == .5 and q > .5):  # break .5 ties toward q
out[i] = next_val
else:
out[i] = val

# Advance the starting index into sort_arr for the next group
grp_start += grp_sz


# generated from template
include "groupby_helper.pxi"
103 changes: 92 additions & 11 deletions pandas/core/groupby/groupby.py
@@ -29,6 +29,8 @@ class providing the base-class of operations.
ensure_float, is_extension_array_dtype, is_numeric_dtype, is_scalar)
from pandas.core.dtypes.missing import isna, notna

from pandas.api.types import (
is_datetime64_dtype, is_integer_dtype, is_object_dtype)
import pandas.core.algorithms as algorithms
from pandas.core.base import (
DataError, GroupByError, PandasObject, SelectionMixin, SpecificationError)
@@ -1024,15 +1026,17 @@ def _bool_agg(self, val_test, skipna):
"""

def objs_to_bool(vals):
try:
vals = vals.astype(np.bool)
except ValueError: # for objects
# type: np.ndarray -> (np.ndarray, typing.Type)
if is_object_dtype(vals):
vals = np.array([bool(x) for x in vals])
else:
vals = vals.astype(np.bool)

return vals.view(np.uint8)
return vals.view(np.uint8), np.bool

def result_to_bool(result):
return result.astype(np.bool, copy=False)
def result_to_bool(result, inference):
# type: (np.ndarray, typing.Type) -> np.ndarray
return result.astype(inference, copy=False)

return self._get_cythonized_result('group_any_all', self.grouper,
aggregate=True,
@@ -1688,6 +1692,75 @@ def nth(self, n, dropna=None):

return result

def quantile(self, q=0.5, interpolation='linear'):
"""
Return group values at the given quantile, a la numpy.percentile.
Parameters
----------
q : float or array-like, default 0.5 (50% quantile)
Value(s) between 0 and 1 providing the quantile(s) to compute.
interpolation : {'linear', 'lower', 'higher', 'midpoint', 'nearest'}
Method to use when the desired quantile falls between two points.
Returns
-------
Series or DataFrame
Return type determined by caller of GroupBy object.
See Also
--------
Series.quantile : Similar method for Series.
DataFrame.quantile : Similar method for DataFrame.
numpy.percentile : NumPy method to compute qth percentile.
Examples
--------
>>> df = pd.DataFrame([
... ['a', 1], ['a', 2], ['a', 3],
... ['b', 1], ['b', 3], ['b', 5]
... ], columns=['key', 'val'])
>>> df.groupby('key').quantile()
val
key
a 2.0
b 3.0
"""

def pre_processor(vals):
# type: np.ndarray -> (np.ndarray, Optional[typing.Type])
if is_object_dtype(vals):
raise TypeError("'quantile' cannot be performed against "
"'object' dtypes!")

inference = None
if is_integer_dtype(vals):
inference = np.int64
elif is_datetime64_dtype(vals):
inference = 'datetime64[ns]'
vals = vals.astype(np.float)

return vals, inference

def post_processor(vals, inference):
# type: (np.ndarray, Optional[typing.Type]) -> np.ndarray
if inference:
# Check for edge case
if not (is_integer_dtype(inference) and
interpolation in {'linear', 'midpoint'}):
vals = vals.astype(inference)

return vals

return self._get_cythonized_result('group_quantile', self.grouper,
aggregate=True,
needs_values=True,
needs_mask=True,
cython_dtype=np.float64,
pre_processing=pre_processor,
post_processing=post_processor,
q=q, interpolation=interpolation)

@Substitution(name='groupby')
def ngroup(self, ascending=True):
"""
@@ -1924,10 +1997,16 @@ def _get_cythonized_result(self, how, grouper, aggregate=False,
Whether the result of the Cython operation is an index of
values to be retrieved, instead of the actual values themselves
pre_processing : function, default None
Function to be applied to `values` prior to passing to Cython
Raises if `needs_values` is False
Function to be applied to `values` prior to passing to Cython.
Function should return a tuple where the first element is the
values to be passed to Cython and the second element is an optional
type which the values should be converted to after being returned
by the Cython operation. Raises if `needs_values` is False.
post_processing : function, default None
Function to be applied to result of Cython function
Function to be applied to result of Cython function. Should accept
an array of values as the first argument and type inferences as its
second argument, i.e. the signature should be
(ndarray, typing.Type).
**kwargs : dict
Extra arguments to be passed through to the Cython function
@@ -1963,10 +2042,12 @@ def _get_cythonized_result(self, how, grouper, aggregate=False,

result = np.zeros(result_sz, dtype=cython_dtype)
func = partial(base_func, result, labels)
inferences = None

if needs_values:
vals = obj.values
if pre_processing:
vals = pre_processing(vals)
vals, inferences = pre_processing(vals)
func = partial(func, vals)

if needs_mask:
@@ -1982,7 +2063,7 @@ def _get_cythonized_result(self, how, grouper, aggregate=False,
result = algorithms.take_nd(obj.values, result)

if post_processing:
result = post_processing(result)
result = post_processing(result, inferences)

output[name] = result

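The pre_processor/post_processor pair above is what makes dtypes round-trip through the float64 kernel: under the rules in the diff, a quantile of an int64 column stays float64 for 'linear' and 'midpoint' (the result can fall between integers) but is cast back to int64 for the exact-value modes. A small usage sketch (output written from reading the code, not captured from a run):

import pandas as pd

df = pd.DataFrame({'key': ['a', 'a', 'b', 'b'],
                   'val': [1, 2, 3, 4]})  # int64 column

# 'linear' may land between integers, so the result stays float64
print(df.groupby('key').quantile(0.5))
#      val
# key
# a    1.5
# b    3.5

# 'lower' always returns an existing value, so int64 is restored
print(df.groupby('key').quantile(0.5, interpolation='lower'))
#      val
# key
# a      1
# b      3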
49 changes: 49 additions & 0 deletions pandas/tests/groupby/test_function.py
@@ -1069,6 +1069,55 @@ def test_size(df):
tm.assert_series_equal(df.groupby('A').size(), out)


# quantile
# --------------------------------
@pytest.mark.parametrize("interpolation", [
"linear", "lower", "higher", "nearest", "midpoint"])
@pytest.mark.parametrize("a_vals,b_vals", [
# Ints
([1, 2, 3, 4, 5], [5, 4, 3, 2, 1]),
([1, 2, 3, 4], [4, 3, 2, 1]),
([1, 2, 3, 4, 5], [4, 3, 2, 1]),
# Floats
([1., 2., 3., 4., 5.], [5., 4., 3., 2., 1.]),
# Missing data
([1., np.nan, 3., np.nan, 5.], [5., np.nan, 3., np.nan, 1.]),
([np.nan, 4., np.nan, 2., np.nan], [np.nan, 4., np.nan, 2., np.nan]),
# Timestamps
([x for x in pd.date_range('1/1/18', freq='D', periods=5)],
[x for x in pd.date_range('1/1/18', freq='D', periods=5)][::-1]),
# All NA
([np.nan] * 5, [np.nan] * 5),
])
@pytest.mark.parametrize('q', [0, .25, .5, .75, 1])
def test_quantile(interpolation, a_vals, b_vals, q):
if interpolation == 'nearest' and q == 0.5 and b_vals == [4, 3, 2, 1]:
pytest.skip("Unclear numpy expectation for nearest result with "
"equidistant data")

a_expected = pd.Series(a_vals).quantile(q, interpolation=interpolation)
b_expected = pd.Series(b_vals).quantile(q, interpolation=interpolation)

df = DataFrame({
'key': ['a'] * len(a_vals) + ['b'] * len(b_vals),
'val': a_vals + b_vals})

expected = DataFrame([a_expected, b_expected], columns=['val'],
index=Index(['a', 'b'], name='key'))
result = df.groupby('key').quantile(q, interpolation=interpolation)

tm.assert_frame_equal(result, expected)


def test_quantile_raises():
df = pd.DataFrame([
['foo', 'a'], ['foo', 'b'], ['foo', 'c']], columns=['key', 'val'])

with pytest.raises(TypeError, match="cannot be performed against "
"'object' dtypes"):
df.groupby('key').quantile()


# pipe
# --------------------------------

10 changes: 5 additions & 5 deletions pandas/tests/groupby/test_groupby.py
@@ -208,7 +208,7 @@ def f(x, q=None, axis=0):
trans_expected = ts_grouped.transform(g)

assert_series_equal(apply_result, agg_expected)
assert_series_equal(agg_result, agg_expected, check_names=False)
assert_series_equal(agg_result, agg_expected)
assert_series_equal(trans_result, trans_expected)

agg_result = ts_grouped.agg(f, q=80)
@@ -223,13 +223,13 @@ def f(x, q=None, axis=0):
agg_result = df_grouped.agg(np.percentile, 80, axis=0)
apply_result = df_grouped.apply(DataFrame.quantile, .8)
expected = df_grouped.quantile(.8)
assert_frame_equal(apply_result, expected)
assert_frame_equal(agg_result, expected, check_names=False)
assert_frame_equal(apply_result, expected, check_names=False)
assert_frame_equal(agg_result, expected)

agg_result = df_grouped.agg(f, q=80)
apply_result = df_grouped.apply(DataFrame.quantile, q=.8)
assert_frame_equal(agg_result, expected, check_names=False)
assert_frame_equal(apply_result, expected)
assert_frame_equal(agg_result, expected)
assert_frame_equal(apply_result, expected, check_names=False)


def test_len():
