From 4072ed537beb17373b2fde6a885818a8a8bea943 Mon Sep 17 00:00:00 2001 From: Pulkit Maloo Date: Mon, 19 Nov 2018 20:13:56 -0500 Subject: [PATCH] BUG: fixed .str.contains(..., na=False) for categorical series (#22170) --- doc/source/whatsnew/v0.24.0.rst | 2 +- pandas/core/strings.py | 9 +++++---- pandas/tests/test_strings.py | 30 ++++++++++++++++++++++++------ 3 files changed, 30 insertions(+), 11 deletions(-) diff --git a/doc/source/whatsnew/v0.24.0.rst b/doc/source/whatsnew/v0.24.0.rst index 3ed5c91141b16a..7d8ee975ba02c8 100644 --- a/doc/source/whatsnew/v0.24.0.rst +++ b/doc/source/whatsnew/v0.24.0.rst @@ -1280,7 +1280,7 @@ Strings - Bug in :meth:`Index.str.partition` was not nan-safe (:issue:`23558`). - Bug in :meth:`Index.str.split` was not nan-safe (:issue:`23677`). -- +- Bug in :func:`Series.str.contains` not respecting the ``na`` argument for a ``Categorical`` dtype ``Series`` (:issue:`22158`). Interval ^^^^^^^^ diff --git a/pandas/core/strings.py b/pandas/core/strings.py index 1c4317d56f82bb..6c21318c935978 100644 --- a/pandas/core/strings.py +++ b/pandas/core/strings.py @@ -1857,7 +1857,7 @@ def __iter__(self): g = self.get(i) def _wrap_result(self, result, use_codes=True, - name=None, expand=None): + name=None, expand=None, fill_value=np.nan): from pandas.core.index import Index, MultiIndex @@ -1867,7 +1867,8 @@ def _wrap_result(self, result, use_codes=True, # so make it possible to skip this step as the method already did this # before the transformation... 
if use_codes and self._is_categorical: - result = take_1d(result, self._orig.cat.codes) + result = take_1d(result, self._orig.cat.codes, + fill_value=fill_value) if not hasattr(result, 'ndim') or not hasattr(result, 'dtype'): return result @@ -2520,12 +2521,12 @@ def join(self, sep): def contains(self, pat, case=True, flags=0, na=np.nan, regex=True): result = str_contains(self._parent, pat, case=case, flags=flags, na=na, regex=regex) - return self._wrap_result(result) + return self._wrap_result(result, fill_value=na) @copy(str_match) def match(self, pat, case=True, flags=0, na=np.nan): result = str_match(self._parent, pat, case=case, flags=flags, na=na) - return self._wrap_result(result) + return self._wrap_result(result, fill_value=na) @copy(str_replace) def replace(self, pat, repl, n=-1, case=None, flags=0, regex=True): diff --git a/pandas/tests/test_strings.py b/pandas/tests/test_strings.py index 7b4e330ca6e3d6..c0aab5d25e3fef 100644 --- a/pandas/tests/test_strings.py +++ b/pandas/tests/test_strings.py @@ -512,10 +512,28 @@ def test_contains(self): assert result.dtype == np.bool_ tm.assert_numpy_array_equal(result, expected) - # na - values = Series(['om', 'foo', np.nan]) - res = values.str.contains('foo', na="foo") - assert res.loc[2] == "foo" + def test_contains_for_object_category(self): + # gh 22158 + + # na for category + values = Series(["a", "b", "c", "a", np.nan], dtype="category") + result = values.str.contains('a', na=True) + expected = Series([True, False, False, True, True]) + tm.assert_series_equal(result, expected) + + result = values.str.contains('a', na=False) + expected = Series([True, False, False, True, False]) + tm.assert_series_equal(result, expected) + + # na for objects + values = Series(["a", "b", "c", "a", np.nan]) + result = values.str.contains('a', na=True) + expected = Series([True, False, False, True, True]) + tm.assert_series_equal(result, expected) + + result = values.str.contains('a', na=False) + expected = Series([True, False, 
False, True, False]) + tm.assert_series_equal(result, expected) def test_startswith(self): values = Series(['om', NA, 'foo_nom', 'nom', 'bar_foo', NA, 'foo']) @@ -2893,7 +2911,7 @@ def test_get_complex_nested(self, to_type): expected = Series([np.nan]) tm.assert_series_equal(result, expected) - def test_more_contains(self): + def test_contains_moar(self): # PR #1179 s = Series(['A', 'B', 'C', 'Aaba', 'Baca', '', NA, 'CABA', 'dog', 'cat']) @@ -2943,7 +2961,7 @@ def test_contains_nan(self): expected = Series([np.nan, np.nan, np.nan], dtype=np.object_) assert_series_equal(result, expected) - def test_more_replace(self): + def test_replace_moar(self): # PR #1179 s = Series(['A', 'B', 'C', 'Aaba', 'Baca', '', NA, 'CABA', 'dog', 'cat'])