Skip to content

Commit

Permalink
REF: simplify _cython_operation return (#38253)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored Dec 3, 2020
1 parent 8ac84fa commit 0cfb5e9
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 19 deletions.
2 changes: 1 addition & 1 deletion pandas/core/groupby/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1087,7 +1087,7 @@ def py_fallback(bvalues: ArrayLike) -> ArrayLike:
def blk_func(bvalues: ArrayLike) -> ArrayLike:

try:
- result, _ = self.grouper._cython_operation(
+ result = self.grouper._cython_operation(
"aggregate", bvalues, how, axis=1, min_count=min_count
)
except NotImplementedError:
Expand Down
7 changes: 4 additions & 3 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -994,7 +994,7 @@ def _cython_transform(
continue

try:
- result, _ = self.grouper._cython_operation(
+ result = self.grouper._cython_operation(
"transform", obj._values, how, axis, **kwargs
)
except NotImplementedError:
Expand Down Expand Up @@ -1069,12 +1069,13 @@ def _cython_agg_general(
if numeric_only and not is_numeric:
continue

- result, agg_names = self.grouper._cython_operation(
+ result = self.grouper._cython_operation(
"aggregate", obj._values, how, axis=0, min_count=min_count
)

- if agg_names:
+ if how == "ohlc":
# e.g. ohlc
+ agg_names = ["open", "high", "low", "close"]
assert len(agg_names) == result.shape[1]
for result_column, result_name in zip(result.T, agg_names):
key = base.OutputKey(label=result_name, position=idx)
Expand Down
23 changes: 8 additions & 15 deletions pandas/core/groupby/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,8 +370,6 @@ def get_group_levels(self) -> List[Index]:

_cython_arity = {"ohlc": 4} # OHLC

- _name_functions = {"ohlc": ["open", "high", "low", "close"]}

def _is_builtin_func(self, arg):
"""
if we define a builtin function for this argument, return it,
Expand Down Expand Up @@ -492,36 +490,33 @@ def _ea_wrap_cython_operation(
# All of the functions implemented here are ordinal, so we can
# operate on the tz-naive equivalents
values = values.view("M8[ns]")
- res_values, names = self._cython_operation(
+ res_values = self._cython_operation(
kind, values, how, axis, min_count, **kwargs
)
if how in ["rank"]:
# preserve float64 dtype
- return res_values, names
+ return res_values

res_values = res_values.astype("i8", copy=False)
result = type(orig_values)._simple_new(res_values, dtype=orig_values.dtype)
- return result, names
+ return result

elif is_integer_dtype(values.dtype) or is_bool_dtype(values.dtype):
# IntegerArray or BooleanArray
values = ensure_int_or_float(values)
- res_values, names = self._cython_operation(
+ res_values = self._cython_operation(
kind, values, how, axis, min_count, **kwargs
)
result = maybe_cast_result(result=res_values, obj=orig_values, how=how)
- return result, names
+ return result

raise NotImplementedError(values.dtype)

def _cython_operation(
self, kind: str, values, how: str, axis: int, min_count: int = -1, **kwargs
- ) -> Tuple[np.ndarray, Optional[List[str]]]:
+ ) -> np.ndarray:
"""
- Returns the values of a cython operation as a Tuple of [data, names].
- Names is only useful when dealing with 2D results, like ohlc
- (see self._name_functions).
+ Returns the values of a cython operation.
"""
orig_values = values
assert kind in ["transform", "aggregate"]
Expand Down Expand Up @@ -619,8 +614,6 @@ def _cython_operation(
if vdim == 1 and arity == 1:
result = result[:, 0]

- names: Optional[List[str]] = self._name_functions.get(how, None)

if swapped:
result = result.swapaxes(0, axis)

Expand All @@ -630,7 +623,7 @@ def _cython_operation(
dtype = maybe_cast_result_dtype(orig_values.dtype, how)
result = maybe_downcast_to_dtype(result, dtype)

- return result, names
+ return result

def _aggregate(
self, result, counts, values, comp_ids, agg_func, min_count: int = -1
Expand Down

0 comments on commit 0cfb5e9

Please sign in to comment.