Skip to content

Commit

Permalink
BUG: divmod return type
Browse files Browse the repository at this point in the history
  • Loading branch information
TomAugspurger committed Oct 1, 2018
1 parent a277e4a commit 52538fa
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 7 deletions.
15 changes: 12 additions & 3 deletions doc/source/extending.rst
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,18 @@ your ``MyExtensionArray`` class, as follows:
MyExtensionArray._add_arithmetic_ops()
MyExtensionArray._add_comparison_ops()
Note that since ``pandas`` automatically calls the underlying operator on each
element one-by-one, this might not be as performant as implementing your own
version of the associated operators directly on the ``ExtensionArray``.
.. note::

Since ``pandas`` automatically calls the underlying operator on each
element one-by-one, this might not be as performant as implementing your own
version of the associated operators directly on the ``ExtensionArray``.

This implementation will try to reconstruct a new ``ExtensionArray`` with the
result of the element-wise operation. Whether or not that succeeds depends on
whether the operation returns a result that's valid for the ``ExtensionArray``.
If an ``ExtensionArray`` cannot be reconstructed, a list containing the scalars
returned instead.

.. _extending.extension.testing:

Expand Down
16 changes: 12 additions & 4 deletions pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,10 +775,18 @@ def convert_values(param):
res = [op(a, b) for (a, b) in zip(lvalues, rvalues)]

if coerce_to_dtype:
try:
res = self._from_sequence(res)
except TypeError:
pass
if op.__name__ in {'divmod', 'rdivmod'}:
try:
a, b = zip(*res)
res = (self._from_sequence(a),
self._from_sequence(b))
except TypeError:
pass
else:
try:
res = self._from_sequence(res)
except TypeError:
pass

return res

Expand Down
4 changes: 4 additions & 0 deletions pandas/tests/extension/decimal/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,5 +138,9 @@ def _concat_same_type(cls, to_concat):
return cls(np.concatenate([x._data for x in to_concat]))


def to_decimal(values, context=None):
return DecimalArray([decimal.Decimal(x) for x in values], context=context)


DecimalArray._add_arithmetic_ops()
DecimalArray._add_comparison_ops()
22 changes: 22 additions & 0 deletions pandas/tests/extension/test_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import pytest

from pandas.tests.extension.decimal.array import to_decimal
import pandas.util.testing as tm


@pytest.mark.parametrize("reverse, expected_div, expected_mod", [
(False, [0, 1, 1, 2], [1, 0, 1, 0]),
(True, [2, 1, 0, 0], [0, 0, 2, 2]),
])
def test_divmod(reverse, expected_div, expected_mod):
# https://github.com/pandas-dev/pandas/issues/22930
arr = to_decimal([1, 2, 3, 4])
if reverse:
div, mod = divmod(2, arr)
else:
div, mod = divmod(arr, 2)
expected_div = to_decimal(expected_div)
expected_mod = to_decimal(expected_mod)

tm.assert_extension_array_equal(div, expected_div)
tm.assert_extension_array_equal(mod, expected_mod)

0 comments on commit 52538fa

Please sign in to comment.