Skip to content

Commit

Permalink
BUG: Fix duplicates in intersection of multiindexes (#36927)
Browse files Browse the repository at this point in the history
  • Loading branch information
phofl authored Nov 29, 2020
1 parent 7b400b3 commit e99e5ab
Show file tree
Hide file tree
Showing 9 changed files with 59 additions and 10 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.1.5.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ Fixed regressions
- Fixed regression in :meth:`DataFrame.groupby` aggregation with out-of-bounds datetime objects in an object-dtype column (:issue:`36003`)
- Fixed regression in ``df.groupby(..).rolling(..)`` with the resulting :class:`MultiIndex` when grouping by a label that is in the index (:issue:`37641`)
- Fixed regression in :meth:`DataFrame.fillna` not filling ``NaN`` after other operations such as :meth:`DataFrame.pivot` (:issue:`36495`).
- Fixed regression in :meth:`MultiIndex.intersection` returning duplicates when at least one of the indexes had duplicates (:issue:`36915`)

.. ---------------------------------------------------------------------------
Expand Down
9 changes: 6 additions & 3 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2822,7 +2822,7 @@ def intersection(self, other, sort=False):
self._assert_can_do_setop(other)
other = ensure_index(other)

if self.equals(other):
if self.equals(other) and not self.has_duplicates:
return self._get_reconciled_name_object(other)

if not is_dtype_equal(self.dtype, other.dtype):
Expand All @@ -2847,7 +2847,7 @@ def _intersection(self, other, sort=False):
except TypeError:
pass
else:
return result
return algos.unique1d(result)

try:
indexer = Index(rvals).get_indexer(lvals)
Expand All @@ -2858,11 +2858,14 @@ def _intersection(self, other, sort=False):
indexer = algos.unique1d(Index(rvals).get_indexer_non_unique(lvals)[0])
indexer = indexer[indexer != -1]

result = other.take(indexer)._values
result = other.take(indexer).unique()._values

if sort is None:
result = algos.safe_sort(result)

# Intersection has to be unique
assert algos.unique(result).shape == result.shape

return result

def difference(self, other, sort=None):
Expand Down
8 changes: 6 additions & 2 deletions pandas/core/indexes/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -3601,6 +3601,8 @@ def intersection(self, other, sort=False):
other, result_names = self._convert_can_do_setop(other)

if self.equals(other):
if self.has_duplicates:
return self.unique().rename(result_names)
return self.rename(result_names)

if not is_object_dtype(other.dtype):
Expand All @@ -3619,10 +3621,12 @@ def intersection(self, other, sort=False):
uniq_tuples = None # flag whether _inner_indexer was successful
if self.is_monotonic and other.is_monotonic:
try:
uniq_tuples = self._inner_indexer(lvals, rvals)[0]
sort = False # uniq_tuples is already sorted
inner_tuples = self._inner_indexer(lvals, rvals)[0]
sort = False # inner_tuples is already sorted
except TypeError:
pass
else:
uniq_tuples = algos.unique(inner_tuples)

if uniq_tuples is None:
other_uniq = set(rvals)
Expand Down
5 changes: 4 additions & 1 deletion pandas/core/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,10 @@ def should_reindex_frame_op(
# TODO: any other cases we should handle here?
cols = left.columns.intersection(right.columns)

if len(cols) and not (cols.equals(left.columns) and cols.equals(right.columns)):
# Intersection is always unique so we have to check the unique columns
left_uniques = left.columns.unique()
right_uniques = right.columns.unique()
if len(cols) and not (cols.equals(left_uniques) and cols.equals(right_uniques)):
# TODO: is there a shortcut available when len(cols) == 0?
return True

Expand Down
9 changes: 7 additions & 2 deletions pandas/core/reshape/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -1271,7 +1271,9 @@ def _validate_specification(self):
raise MergeError("Must pass left_on or left_index=True")
else:
# use the common columns
common_cols = self.left.columns.intersection(self.right.columns)
left_cols = self.left.columns
right_cols = self.right.columns
common_cols = left_cols.intersection(right_cols)
if len(common_cols) == 0:
raise MergeError(
"No common columns to perform merge on. "
Expand All @@ -1280,7 +1282,10 @@ def _validate_specification(self):
f"left_index={self.left_index}, "
f"right_index={self.right_index}"
)
if not common_cols.is_unique:
if (
not left_cols.join(common_cols, how="inner").is_unique
or not right_cols.join(common_cols, how="inner").is_unique
):
raise MergeError(f"Data columns not unique: {repr(common_cols)}")
self.left_on = self.right_on = common_cols
elif self.on is not None:
Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/indexes/base_class/test_setops.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def test_intersection_str_dates(self, sort):

@pytest.mark.parametrize(
"index2,expected_arr",
[(Index(["B", "D"]), ["B"]), (Index(["B", "D", "A"]), ["A", "B", "A"])],
[(Index(["B", "D"]), ["B"]), (Index(["B", "D", "A"]), ["A", "B"])],
)
def test_intersection_non_monotonic_non_unique(self, index2, expected_arr, sort):
# non-monotonic non-unique
Expand Down
23 changes: 23 additions & 0 deletions pandas/tests/indexes/multi/test_setops.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,3 +378,26 @@ def test_setops_disallow_true(method):

with pytest.raises(ValueError, match="The 'sort' keyword only takes"):
getattr(idx1, method)(idx2, sort=True)


@pytest.mark.parametrize(
("tuples", "exp_tuples"),
[
([("val1", "test1")], [("val1", "test1")]),
([("val1", "test1"), ("val1", "test1")], [("val1", "test1")]),
(
[("val2", "test2"), ("val1", "test1")],
[("val2", "test2"), ("val1", "test1")],
),
],
)
def test_intersect_with_duplicates(tuples, exp_tuples):
# GH#36915
left = MultiIndex.from_tuples(tuples, names=["first", "second"])
right = MultiIndex.from_tuples(
[("val1", "test1"), ("val1", "test1"), ("val2", "test2")],
names=["first", "second"],
)
result = left.intersection(right)
expected = MultiIndex.from_tuples(exp_tuples, names=["first", "second"])
tm.assert_index_equal(result, expected)
10 changes: 10 additions & 0 deletions pandas/tests/indexes/test_setops.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,16 @@ def test_dunder_inplace_setops_deprecated(index):
index ^= index


@pytest.mark.parametrize("values", [[1, 2, 2, 3], [3, 3]])
def test_intersection_duplicates(values):
# GH#31326
a = pd.Index(values)
b = pd.Index([3, 3])
result = a.intersection(b)
expected = pd.Index([3])
tm.assert_index_equal(result, expected)


class TestSetOps:
# Set operation tests shared by all indexes in the `index` fixture
@pytest.mark.parametrize("case", [0.5, "xxx"])
Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/reshape/merge/test_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,7 @@ def test_overlapping_columns_error_message(self):

# #2649, #10639
df2.columns = ["key1", "foo", "foo"]
msg = r"Data columns not unique: Index\(\['foo', 'foo'\], dtype='object'\)"
msg = r"Data columns not unique: Index\(\['foo'\], dtype='object'\)"
with pytest.raises(MergeError, match=msg):
merge(df, df2)

Expand Down

0 comments on commit e99e5ab

Please sign in to comment.