From e99e5ab32c4e831e7bbac0346189f4d6d86a6225 Mon Sep 17 00:00:00 2001 From: patrick <61934744+phofl@users.noreply.github.com> Date: Sun, 29 Nov 2020 18:21:52 +0100 Subject: [PATCH] BUG: Fix duplicates in intersection of multiindexes (#36927) --- doc/source/whatsnew/v1.1.5.rst | 1 + pandas/core/indexes/base.py | 9 +++++--- pandas/core/indexes/multi.py | 8 +++++-- pandas/core/ops/__init__.py | 5 +++- pandas/core/reshape/merge.py | 9 ++++++-- .../tests/indexes/base_class/test_setops.py | 2 +- pandas/tests/indexes/multi/test_setops.py | 23 +++++++++++++++++++ pandas/tests/indexes/test_setops.py | 10 ++++++++ pandas/tests/reshape/merge/test_merge.py | 2 +- 9 files changed, 59 insertions(+), 10 deletions(-) diff --git a/doc/source/whatsnew/v1.1.5.rst b/doc/source/whatsnew/v1.1.5.rst index 46c4ad4f35fe4..edc2f7327abfc 100644 --- a/doc/source/whatsnew/v1.1.5.rst +++ b/doc/source/whatsnew/v1.1.5.rst @@ -23,6 +23,7 @@ Fixed regressions - Fixed regression in :meth:`DataFrame.groupby` aggregation with out-of-bounds datetime objects in an object-dtype column (:issue:`36003`) - Fixed regression in ``df.groupby(..).rolling(..)`` with the resulting :class:`MultiIndex` when grouping by a label that is in the index (:issue:`37641`) - Fixed regression in :meth:`DataFrame.fillna` not filling ``NaN`` after other operations such as :meth:`DataFrame.pivot` (:issue:`36495`). +- Fixed regression in :meth:`MultiIndex.intersection` returning duplicates when at least one of the indexes had duplicates (:issue:`36915`) .. --------------------------------------------------------------------------- diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index c86652acbcd0f..3f89b0619e600 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -2822,7 +2822,7 @@ def intersection(self, other, sort=False): self._assert_can_do_setop(other) other = ensure_index(other) - if self.equals(other): + if self.equals(other) and not self.has_duplicates: return self._get_reconciled_name_object(other) if not is_dtype_equal(self.dtype, other.dtype): @@ -2847,7 +2847,7 @@ def _intersection(self, other, sort=False): except TypeError: pass else: - return result + return algos.unique1d(result) try: indexer = Index(rvals).get_indexer(lvals) @@ -2858,11 +2858,14 @@ def _intersection(self, other, sort=False): indexer = algos.unique1d(Index(rvals).get_indexer_non_unique(lvals)[0]) indexer = indexer[indexer != -1] - result = other.take(indexer)._values + result = other.take(indexer).unique()._values if sort is None: result = algos.safe_sort(result) + # Intersection has to be unique + assert algos.unique(result).shape == result.shape + return result def difference(self, other, sort=None): diff --git a/pandas/core/indexes/multi.py b/pandas/core/indexes/multi.py index 46846209f315b..589da4a6c4ceb 100644 --- a/pandas/core/indexes/multi.py +++ b/pandas/core/indexes/multi.py @@ -3601,6 +3601,8 @@ def intersection(self, other, sort=False): other, result_names = self._convert_can_do_setop(other) if self.equals(other): + if self.has_duplicates: + return self.unique().rename(result_names) return self.rename(result_names) if not is_object_dtype(other.dtype): @@ -3619,10 +3621,12 @@ def intersection(self, other, sort=False): uniq_tuples = None # flag whether _inner_indexer was successful if self.is_monotonic and other.is_monotonic: try: - uniq_tuples = self._inner_indexer(lvals, rvals)[0] - sort = False # uniq_tuples is already sorted + inner_tuples = self._inner_indexer(lvals, rvals)[0] + sort = False # inner_tuples is already sorted except TypeError: pass + else: + uniq_tuples = algos.unique(inner_tuples) if uniq_tuples is None: other_uniq = set(rvals) diff --git a/pandas/core/ops/__init__.py b/pandas/core/ops/__init__.py index 2b159c607b0a0..d8b5dba424cbf 100644 --- a/pandas/core/ops/__init__.py +++ b/pandas/core/ops/__init__.py @@ -311,7 +311,10 @@ def should_reindex_frame_op( # TODO: any other cases we should handle here? cols = left.columns.intersection(right.columns) - if len(cols) and not (cols.equals(left.columns) and cols.equals(right.columns)): + # Intersection is always unique so we have to check the unique columns + left_uniques = left.columns.unique() + right_uniques = right.columns.unique() + if len(cols) and not (cols.equals(left_uniques) and cols.equals(right_uniques)): # TODO: is there a shortcut available when len(cols) == 0? return True diff --git a/pandas/core/reshape/merge.py b/pandas/core/reshape/merge.py index 3b755c40721fb..9bb1add309407 100644 --- a/pandas/core/reshape/merge.py +++ b/pandas/core/reshape/merge.py @@ -1271,7 +1271,9 @@ def _validate_specification(self): raise MergeError("Must pass left_on or left_index=True") else: # use the common columns - common_cols = self.left.columns.intersection(self.right.columns) + left_cols = self.left.columns + right_cols = self.right.columns + common_cols = left_cols.intersection(right_cols) if len(common_cols) == 0: raise MergeError( "No common columns to perform merge on. " @@ -1280,7 +1282,10 @@ def _validate_specification(self): f"left_index={self.left_index}, " f"right_index={self.right_index}" ) - if not common_cols.is_unique: + if ( + not left_cols.join(common_cols, how="inner").is_unique + or not right_cols.join(common_cols, how="inner").is_unique + ): raise MergeError(f"Data columns not unique: {repr(common_cols)}") self.left_on = self.right_on = common_cols elif self.on is not None: diff --git a/pandas/tests/indexes/base_class/test_setops.py b/pandas/tests/indexes/base_class/test_setops.py index 6413b110dff2e..ddcb3c5b87ebc 100644 --- a/pandas/tests/indexes/base_class/test_setops.py +++ b/pandas/tests/indexes/base_class/test_setops.py @@ -141,7 +141,7 @@ def test_intersection_str_dates(self, sort): @pytest.mark.parametrize( "index2,expected_arr", - [(Index(["B", "D"]), ["B"]), (Index(["B", "D", "A"]), ["A", "B", "A"])], + [(Index(["B", "D"]), ["B"]), (Index(["B", "D", "A"]), ["A", "B"])], ) def test_intersection_non_monotonic_non_unique(self, index2, expected_arr, sort): # non-monotonic non-unique diff --git a/pandas/tests/indexes/multi/test_setops.py b/pandas/tests/indexes/multi/test_setops.py index 4ac9a27069a3f..2ac57f1befd57 100644 --- a/pandas/tests/indexes/multi/test_setops.py +++ b/pandas/tests/indexes/multi/test_setops.py @@ -378,3 +378,26 @@ def test_setops_disallow_true(method): with pytest.raises(ValueError, match="The 'sort' keyword only takes"): getattr(idx1, method)(idx2, sort=True) + + +@pytest.mark.parametrize( + ("tuples", "exp_tuples"), + [ + ([("val1", "test1")], [("val1", "test1")]), + ([("val1", "test1"), ("val1", "test1")], [("val1", "test1")]), + ( + [("val2", "test2"), ("val1", "test1")], + [("val2", "test2"), ("val1", "test1")], + ), + ], +) +def test_intersect_with_duplicates(tuples, exp_tuples): + # GH#36915 + left = MultiIndex.from_tuples(tuples, names=["first", "second"]) + right = MultiIndex.from_tuples( + [("val1", "test1"), ("val1", "test1"), ("val2", "test2")], + names=["first", "second"], + ) + result = left.intersection(right) + expected = MultiIndex.from_tuples(exp_tuples, names=["first", "second"]) + tm.assert_index_equal(result, expected) diff --git a/pandas/tests/indexes/test_setops.py b/pandas/tests/indexes/test_setops.py index 0973cef7cfdc1..2675c4569a8e9 100644 --- a/pandas/tests/indexes/test_setops.py +++ b/pandas/tests/indexes/test_setops.py @@ -120,6 +120,16 @@ def test_dunder_inplace_setops_deprecated(index): index ^= index +@pytest.mark.parametrize("values", [[1, 2, 2, 3], [3, 3]]) +def test_intersection_duplicates(values): + # GH#31326 + a = pd.Index(values) + b = pd.Index([3, 3]) + result = a.intersection(b) + expected = pd.Index([3]) + tm.assert_index_equal(result, expected) + + class TestSetOps: # Set operation tests shared by all indexes in the `index` fixture @pytest.mark.parametrize("case", [0.5, "xxx"]) diff --git a/pandas/tests/reshape/merge/test_merge.py b/pandas/tests/reshape/merge/test_merge.py index f44909b61ff7a..40ba62a27aa68 100644 --- a/pandas/tests/reshape/merge/test_merge.py +++ b/pandas/tests/reshape/merge/test_merge.py @@ -753,7 +753,7 @@ def test_overlapping_columns_error_message(self): # #2649, #10639 df2.columns = ["key1", "foo", "foo"] - msg = r"Data columns not unique: Index\(\['foo', 'foo'\], dtype='object'\)" + msg = r"Data columns not unique: Index\(\['foo'\], dtype='object'\)" with pytest.raises(MergeError, match=msg): merge(df, df2)