From 96b71d1626b02ef8d5236d07590a8e168f77f25c Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Wed, 18 Dec 2019 02:30:26 -0600 Subject: [PATCH] StringArray comparisons return BooleanArray (#30231) xref https://github.com/pandas-dev/pandas/issues/29556 --- doc/source/user_guide/text.rst | 6 +++- pandas/core/arrays/string_.py | 31 +++++++++++++++----- pandas/core/ops/__init__.py | 33 +++++++++++++++++++++- pandas/tests/arrays/string_/test_string.py | 31 ++++++++++++++++++++ pandas/tests/extension/test_string.py | 2 +- 5 files changed, 93 insertions(+), 10 deletions(-) diff --git a/doc/source/user_guide/text.rst b/doc/source/user_guide/text.rst index 072871f89bdae0..ff0474dbecbb45 100644 --- a/doc/source/user_guide/text.rst +++ b/doc/source/user_guide/text.rst @@ -94,7 +94,11 @@ l. For ``StringDtype``, :ref:`string accessor methods` 2. Some string methods, like :meth:`Series.str.decode` are not available on ``StringArray`` because ``StringArray`` only holds strings, not bytes. - +3. In comparison operations, :class:`arrays.StringArray` and ``Series`` backed + by a ``StringArray`` will return an object with :class:`BooleanDtype`, + rather than a ``bool`` dtype object. Missing values in a ``StringArray`` + will propagate in comparison operations, rather than always comparing + unequal like :attr:`numpy.nan`. Everything else that follows in the rest of this document applies equally to ``string`` and ``object`` dtype. diff --git a/pandas/core/arrays/string_.py b/pandas/core/arrays/string_.py index 2de19a3319cc57..3bad7f0162f445 100644 --- a/pandas/core/arrays/string_.py +++ b/pandas/core/arrays/string_.py @@ -134,6 +134,10 @@ class StringArray(PandasArray): The string methods are available on Series backed by a StringArray. + Notes + ----- + StringArray returns a BooleanArray for comparison methods. 
+ Examples -------- >>> pd.array(['This is', 'some text', None, 'data.'], dtype="string") @@ -148,6 +152,13 @@ class StringArray(PandasArray): Traceback (most recent call last): ... ValueError: StringArray requires an object-dtype ndarray of strings. + + For comparison methods, this returns a :class:`pandas.BooleanArray` + + >>> pd.array(["a", None, "c"], dtype="string") == "a" + <BooleanArray> + [True, NA, False] + Length: 3, dtype: boolean """ # undo the PandasArray hack @@ -255,7 +266,12 @@ def value_counts(self, dropna=False): # Overrride parent because we have different return types. @classmethod def _create_arithmetic_method(cls, op): + # Note: this handles both arithmetic and comparison methods. def method(self, other): + from pandas.arrays import BooleanArray + + assert op.__name__ in ops.ARITHMETIC_BINOPS | ops.COMPARISON_BINOPS + if isinstance(other, (ABCIndexClass, ABCSeries, ABCDataFrame)): return NotImplemented @@ -275,15 +291,16 @@ def method(self, other): other = np.asarray(other) other = other[valid] - result = np.empty_like(self._ndarray, dtype="object") - result[mask] = StringDtype.na_value - result[valid] = op(self._ndarray[valid], other) - - if op.__name__ in {"add", "radd", "mul", "rmul"}: + if op.__name__ in ops.ARITHMETIC_BINOPS: + result = np.empty_like(self._ndarray, dtype="object") + result[mask] = StringDtype.na_value + result[valid] = op(self._ndarray[valid], other) return StringArray(result) else: - dtype = "object" if mask.any() else "bool" - return np.asarray(result, dtype=dtype) + # logical + result = np.zeros(len(self._ndarray), dtype="bool") + result[valid] = op(self._ndarray[valid], other) + return BooleanArray(result, mask) return compat.set_function_name(method, f"__{op.__name__}__", cls) diff --git a/pandas/core/ops/__init__.py b/pandas/core/ops/__init__.py index ffa38cbc3d6581..14705f4d22e9b8 100644 --- a/pandas/core/ops/__init__.py +++ b/pandas/core/ops/__init__.py @@ -5,7 +5,7 @@ """ import datetime import operator -from typing import 
Tuple, Union +from typing import Set, Tuple, Union import numpy as np @@ -59,6 +59,37 @@ rxor, ) +# ----------------------------------------------------------------------------- +# constants +ARITHMETIC_BINOPS: Set[str] = { + "add", + "sub", + "mul", + "pow", + "mod", + "floordiv", + "truediv", + "divmod", + "radd", + "rsub", + "rmul", + "rpow", + "rmod", + "rfloordiv", + "rtruediv", + "rdivmod", +} + + +COMPARISON_BINOPS: Set[str] = { + "eq", + "ne", + "lt", + "gt", + "le", + "ge", +} + # ----------------------------------------------------------------------------- # Ops Wrapping Utilities diff --git a/pandas/tests/arrays/string_/test_string.py b/pandas/tests/arrays/string_/test_string.py index 0dfd75a2042b06..0544ee40028909 100644 --- a/pandas/tests/arrays/string_/test_string.py +++ b/pandas/tests/arrays/string_/test_string.py @@ -154,6 +154,37 @@ def test_add_frame(): tm.assert_frame_equal(result, expected) +def test_comparison_methods_scalar(all_compare_operators): + op_name = all_compare_operators + + a = pd.array(["a", None, "c"], dtype="string") + other = "a" + result = getattr(a, op_name)(other) + expected = np.array([getattr(item, op_name)(other) for item in a], dtype=object) + expected = pd.array(expected, dtype="boolean") + tm.assert_extension_array_equal(result, expected) + + result = getattr(a, op_name)(pd.NA) + expected = pd.array([None, None, None], dtype="boolean") + tm.assert_extension_array_equal(result, expected) + + +def test_comparison_methods_array(all_compare_operators): + op_name = all_compare_operators + + a = pd.array(["a", None, "c"], dtype="string") + other = [None, None, "c"] + result = getattr(a, op_name)(other) + expected = np.empty_like(a, dtype="object") + expected[-1] = getattr(other[-1], op_name)(a[-1]) + expected = pd.array(expected, dtype="boolean") + tm.assert_extension_array_equal(result, expected) + + result = getattr(a, op_name)(pd.NA) + expected = pd.array([None, None, None], dtype="boolean") + 
tm.assert_extension_array_equal(result, expected) + + def test_constructor_raises(): with pytest.raises(ValueError, match="sequence of strings"): pd.arrays.StringArray(np.array(["a", "b"], dtype="S1")) diff --git a/pandas/tests/extension/test_string.py b/pandas/tests/extension/test_string.py index 471a1b79d23bc1..8519c2999ade34 100644 --- a/pandas/tests/extension/test_string.py +++ b/pandas/tests/extension/test_string.py @@ -91,7 +91,7 @@ class TestCasting(base.BaseCastingTests): class TestComparisonOps(base.BaseComparisonOpsTests): def _compare_other(self, s, data, op_name, other): result = getattr(s, op_name)(other) - expected = getattr(s.astype(object), op_name)(other) + expected = getattr(s.astype(object), op_name)(other).astype("boolean") self.assert_series_equal(result, expected) def test_compare_scalar(self, data, all_compare_operators):