Skip to content

Commit

Permalink
REF: casting in _python_agg_general (#38235)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored Dec 2, 2020
1 parent bda4bc3 commit daa5942
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 10 deletions.
20 changes: 11 additions & 9 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class providing the base-class of operations.
from pandas.errors import AbstractMethodError
from pandas.util._decorators import Appender, Substitution, cache_readonly, doc

from pandas.core.dtypes.cast import maybe_cast_result
from pandas.core.dtypes.cast import maybe_cast_result, maybe_downcast_to_dtype
from pandas.core.dtypes.common import (
ensure_float,
is_bool_dtype,
Expand Down Expand Up @@ -1185,22 +1185,24 @@ def _python_agg_general(self, func, *args, **kwargs):

assert result is not None
key = base.OutputKey(label=name, position=idx)
output[key] = maybe_cast_result(result, obj, numeric_only=True)

if not output:
return self._python_apply_general(f, self._selected_obj)
if is_numeric_dtype(obj.dtype):
result = maybe_downcast_to_dtype(result, obj.dtype)

if self.grouper._filter_empty_groups:

mask = counts.ravel() > 0
for key, result in output.items():
if self.grouper._filter_empty_groups:
mask = counts.ravel() > 0

# since we are masking, make sure that we have a float object
values = result
if is_numeric_dtype(values.dtype):
values = ensure_float(values)

output[key] = maybe_cast_result(values[mask], result)
result = maybe_downcast_to_dtype(values[mask], result.dtype)

output[key] = result

if not output:
return self._python_apply_general(f, self._selected_obj)

return self._wrap_aggregated_output(output, index=self.grouper.result_index)

Expand Down
2 changes: 1 addition & 1 deletion pandas/core/groupby/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,7 @@ def _aggregate_series_pure_python(self, obj: Series, func: F):
result[label] = res

result = lib.maybe_convert_objects(result, try_float=0)
# TODO: maybe_cast_to_extension_array?
result = maybe_cast_result(result, obj, numeric_only=True)

return result, counts

Expand Down

0 comments on commit daa5942

Please sign in to comment.