Skip to content

Commit

Permalink
REF: Refactor assert_index_equal (#41980)
Browse files Browse the repository at this point in the history
  • Loading branch information
topper-123 authored Jun 18, 2021
1 parent fce7f9e commit 50e55e6
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 22 deletions.
35 changes: 17 additions & 18 deletions pandas/_testing/asserters.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,18 +314,16 @@ def _check_types(left, right, obj="Index") -> None:
return

assert_class_equal(left, right, exact=exact, obj=obj)
assert_attr_equal("inferred_type", left, right, obj=obj)

# Skip exact dtype checking when `check_categorical` is False
if check_categorical:
assert_attr_equal("dtype", left, right, obj=obj)
if is_categorical_dtype(left.dtype) and is_categorical_dtype(right.dtype):
if is_categorical_dtype(left.dtype) and is_categorical_dtype(right.dtype):
if check_categorical:
assert_attr_equal("dtype", left, right, obj=obj)
assert_index_equal(left.categories, right.categories, exact=exact)
return

# allow string-like to have different inferred_types
if left.inferred_type in ("string"):
assert right.inferred_type in ("string")
else:
assert_attr_equal("inferred_type", left, right, obj=obj)
assert_attr_equal("dtype", left, right, obj=obj)

def _get_ilevel_values(index, level):
# accept level number only
Expand Down Expand Up @@ -437,6 +435,8 @@ def assert_class_equal(left, right, exact: bool | str = True, obj="Input"):
"""
Checks classes are equal.
"""
from pandas.core.indexes.numeric import NumericIndex

__tracebackhide__ = True

def repr_class(x):
Expand All @@ -446,17 +446,16 @@ def repr_class(x):

return type(x).__name__

if type(left) == type(right):
return

if exact == "equiv":
if type(left) != type(right):
# allow equivalence of Int64Index/RangeIndex
types = {type(left).__name__, type(right).__name__}
if len(types - {"Int64Index", "RangeIndex"}):
msg = f"{obj} classes are not equivalent"
raise_assert_detail(obj, msg, repr_class(left), repr_class(right))
elif exact:
if type(left) != type(right):
msg = f"{obj} classes are different"
raise_assert_detail(obj, msg, repr_class(left), repr_class(right))
# accept equivalence of NumericIndex (sub-)classes
if isinstance(left, NumericIndex) and isinstance(right, NumericIndex):
return

msg = f"{obj} classes are different"
raise_assert_detail(obj, msg, repr_class(left), repr_class(right))


def assert_attr_equal(attr: str, left, right, obj: str = "Attributes"):
Expand Down
23 changes: 19 additions & 4 deletions pandas/tests/util/test_assert_index_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,30 @@ def test_index_equal_length_mismatch(check_exact):
tm.assert_index_equal(idx1, idx2, check_exact=check_exact)


def test_index_equal_class_mismatch(check_exact):
msg = """Index are different
@pytest.mark.parametrize("exact", [False, "equiv"])
def test_index_equal_class(exact):
idx1 = Index([0, 1, 2])
idx2 = RangeIndex(3)

tm.assert_index_equal(idx1, idx2, exact=exact)


@pytest.mark.parametrize(
"idx_values, msg_str",
[
[[1, 2, 3.0], "Float64Index\\(\\[1\\.0, 2\\.0, 3\\.0\\], dtype='float64'\\)"],
[range(3), "RangeIndex\\(start=0, stop=3, step=1\\)"],
],
)
def test_index_equal_class_mismatch(check_exact, idx_values, msg_str):
msg = f"""Index are different
Index classes are different
\\[left\\]: Int64Index\\(\\[1, 2, 3\\], dtype='int64'\\)
\\[right\\]: Float64Index\\(\\[1\\.0, 2\\.0, 3\\.0\\], dtype='float64'\\)"""
\\[right\\]: {msg_str}"""

idx1 = Index([1, 2, 3])
idx2 = Index([1, 2, 3.0])
idx2 = Index(idx_values)

with pytest.raises(AssertionError, match=msg):
tm.assert_index_equal(idx1, idx2, exact=True, check_exact=check_exact)
Expand Down

0 comments on commit 50e55e6

Please sign in to comment.