From 0cfb5e978fccc6e8a79e2586867b196749facdb7 Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Thu, 3 Dec 2020 06:29:28 -0800 Subject: [PATCH] REF: simplify _cython_operation return (#38253) --- pandas/core/groupby/generic.py | 2 +- pandas/core/groupby/groupby.py | 7 ++++--- pandas/core/groupby/ops.py | 23 ++++++++--------------- 3 files changed, 13 insertions(+), 19 deletions(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 00561e5441e00..5e78fe8fea00c 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -1087,7 +1087,7 @@ def py_fallback(bvalues: ArrayLike) -> ArrayLike: def blk_func(bvalues: ArrayLike) -> ArrayLike: try: - result, _ = self.grouper._cython_operation( + result = self.grouper._cython_operation( "aggregate", bvalues, how, axis=1, min_count=min_count ) except NotImplementedError: diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 583fc6bf8ddb7..798c0742f03e5 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -994,7 +994,7 @@ def _cython_transform( continue try: - result, _ = self.grouper._cython_operation( + result = self.grouper._cython_operation( "transform", obj._values, how, axis, **kwargs ) except NotImplementedError: @@ -1069,12 +1069,13 @@ def _cython_agg_general( if numeric_only and not is_numeric: continue - result, agg_names = self.grouper._cython_operation( + result = self.grouper._cython_operation( "aggregate", obj._values, how, axis=0, min_count=min_count ) - if agg_names: + if how == "ohlc": # e.g. ohlc + agg_names = ["open", "high", "low", "close"] assert len(agg_names) == result.shape[1] for result_column, result_name in zip(result.T, agg_names): key = base.OutputKey(label=result_name, position=idx) diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py index 8046be669ea51..d98c55755042e 100644 --- a/pandas/core/groupby/ops.py +++ b/pandas/core/groupby/ops.py @@ -370,8 +370,6 @@ def get_group_levels(self) -> List[Index]: _cython_arity = {"ohlc": 4} # OHLC - _name_functions = {"ohlc": ["open", "high", "low", "close"]} - def _is_builtin_func(self, arg): """ if we define a builtin function for this argument, return it, @@ -492,36 +490,33 @@ def _ea_wrap_cython_operation( # All of the functions implemented here are ordinal, so we can # operate on the tz-naive equivalents values = values.view("M8[ns]") - res_values, names = self._cython_operation( + res_values = self._cython_operation( kind, values, how, axis, min_count, **kwargs ) if how in ["rank"]: # preserve float64 dtype - return res_values, names + return res_values res_values = res_values.astype("i8", copy=False) result = type(orig_values)._simple_new(res_values, dtype=orig_values.dtype) - return result, names + return result elif is_integer_dtype(values.dtype) or is_bool_dtype(values.dtype): # IntegerArray or BooleanArray values = ensure_int_or_float(values) - res_values, names = self._cython_operation( + res_values = self._cython_operation( kind, values, how, axis, min_count, **kwargs ) result = maybe_cast_result(result=res_values, obj=orig_values, how=how) - return result, names + return result raise NotImplementedError(values.dtype) def _cython_operation( self, kind: str, values, how: str, axis: int, min_count: int = -1, **kwargs - ) -> Tuple[np.ndarray, Optional[List[str]]]: + ) -> np.ndarray: """ - Returns the values of a cython operation as a Tuple of [data, names]. - - Names is only useful when dealing with 2D results, like ohlc - (see self._name_functions). + Returns the values of a cython operation. """ orig_values = values assert kind in ["transform", "aggregate"] @@ -619,8 +614,6 @@ def _cython_operation( if vdim == 1 and arity == 1: result = result[:, 0] - names: Optional[List[str]] = self._name_functions.get(how, None) - if swapped: result = result.swapaxes(0, axis) @@ -630,7 +623,7 @@ def _cython_operation( dtype = maybe_cast_result_dtype(orig_values.dtype, how) result = maybe_downcast_to_dtype(result, dtype) - return result, names + return result def _aggregate( self, result, counts, values, comp_ids, agg_func, min_count: int = -1