From 4948fbfa2e0b4365d991f134966c2414e013fe6c Mon Sep 17 00:00:00 2001 From: Brock Date: Wed, 2 Dec 2020 14:21:37 -0800 Subject: [PATCH 1/2] ENH: implement _should_compare/_is_comparable_dtype for all Index subclasses --- pandas/core/indexes/base.py | 31 +++++++++++++++++++++++-------- pandas/core/indexes/category.py | 3 +++ pandas/core/indexes/interval.py | 20 +++++++++++++++++--- pandas/core/indexes/multi.py | 15 ++++++--------- pandas/core/indexes/numeric.py | 6 +++++- pandas/tests/indexes/test_base.py | 8 ++++---- 6 files changed, 58 insertions(+), 25 deletions(-) diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index 40fcc824992b7..874ece8742451 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -4904,16 +4904,31 @@ def get_indexer_non_unique(self, target): # Treat boolean labels passed to a numeric index as not found. Without # this fix False and True would be treated as 0 and 1 respectively. # (GH #16877) - no_matches = -1 * np.ones(self.shape, dtype=np.intp) - return no_matches, no_matches + return self._get_indexer_non_comparable(target, method=None, unique=False) pself, ptarget = self._maybe_promote(target) if pself is not self or ptarget is not target: return pself.get_indexer_non_unique(ptarget) - if not self._is_comparable_dtype(target.dtype): - no_matches = -1 * np.ones(self.shape, dtype=np.intp) - return no_matches, no_matches + if not self._should_compare(target): + return self._get_indexer_non_comparable(target, method=None, unique=False) + + if not is_dtype_equal(self.dtype, target.dtype): + # TODO: if object, could use infer_dtype to pre-empt costly + # conversion if still non-comparable? + dtype = find_common_type([self.dtype, target.dtype]) + if ( + dtype.kind in ["i", "u"] + and is_categorical_dtype(target.dtype) + and target.hasnans + ): + # FIXME: find_common_type incorrect with Categorical GH#38240 + # FIXME: some cases where float64 cast can be lossy? + dtype = np.dtype(np.float64) + + this = self.astype(dtype, copy=False) + that = target.astype(dtype, copy=False) + return this.get_indexer_non_unique(that) if is_categorical_dtype(target.dtype): tgt_values = np.asarray(target) @@ -4966,7 +4981,7 @@ def _get_indexer_non_comparable(self, target: "Index", method, unique: bool = Tr If doing an inequality check, i.e. method is not None. """ if method is not None: - other = _unpack_nested_dtype(target) + other = unpack_nested_dtype(target) raise TypeError(f"Cannot compare dtypes {self.dtype} and {other.dtype}") no_matches = -1 * np.ones(target.shape, dtype=np.intp) @@ -5017,7 +5032,7 @@ def _should_compare(self, other: "Index") -> bool: """ Check if `self == other` can ever have non-False entries. """ - other = _unpack_nested_dtype(other) + other = unpack_nested_dtype(other) dtype = other.dtype return self._is_comparable_dtype(dtype) or is_object_dtype(dtype) @@ -6170,7 +6185,7 @@ def get_unanimous_names(*indexes: Index) -> Tuple[Label, ...]: return names -def _unpack_nested_dtype(other: Index) -> Index: +def unpack_nested_dtype(other: Index) -> Index: """ When checking if our dtype is comparable with another, we need to unpack CategoricalDtype to look at its categories.dtype. diff --git a/pandas/core/indexes/category.py b/pandas/core/indexes/category.py index abf70fd150345..89f1647a299dc 100644 --- a/pandas/core/indexes/category.py +++ b/pandas/core/indexes/category.py @@ -554,6 +554,9 @@ def _maybe_cast_slice_bound(self, label, side: str, kind): # -------------------------------------------------------------------- + def _is_comparable_dtype(self, dtype): + return self._categories._is_comparable_dtype(dtype) + def take_nd(self, *args, **kwargs): """Alias for `take`""" warnings.warn( diff --git a/pandas/core/indexes/interval.py b/pandas/core/indexes/interval.py index 18be4bf225da5..049391ed1ad57 100644 --- a/pandas/core/indexes/interval.py +++ b/pandas/core/indexes/interval.py @@ -11,7 +11,7 @@ from pandas._libs import lib from pandas._libs.interval import Interval, IntervalMixin, IntervalTree from pandas._libs.tslibs import BaseOffset, Timedelta, Timestamp, to_offset -from pandas._typing import AnyArrayLike, Label +from pandas._typing import AnyArrayLike, DtypeObj, Label from pandas.errors import InvalidIndexError from pandas.util._decorators import Appender, Substitution, cache_readonly from pandas.util._exceptions import rewrite_exception @@ -38,6 +38,7 @@ is_object_dtype, is_scalar, ) +from pandas.core.dtypes.dtypes import IntervalDtype from pandas.core.algorithms import take_1d from pandas.core.arrays.interval import IntervalArray, _interval_shared_docs @@ -50,6 +51,7 @@ default_pprint, ensure_index, maybe_extract_name, + unpack_nested_dtype, ) from pandas.core.indexes.datetimes import DatetimeIndex, date_range from pandas.core.indexes.extension import ExtensionIndex, inherit_names @@ -807,6 +809,19 @@ def _convert_list_indexer(self, keyarr): return locs + def _is_comparable_dtype(self, dtype: DtypeObj) -> bool: + if not isinstance(dtype, IntervalDtype): + return False + common_subtype = find_common_type([self.dtype.subtype, dtype.subtype]) + return not is_object_dtype(common_subtype) + + def _should_compare(self, other) -> bool: + if not super()._should_compare(other): + return False + other = unpack_nested_dtype(other) + return other.closed == self.closed + + # TODO: use should_compare and get rid of _is_non_comparable_own_type def _is_non_comparable_own_type(self, other: "IntervalIndex") -> bool: # different closed or incompatible subtype -> no matches @@ -814,8 +829,7 @@ def _is_non_comparable_own_type(self, other: "IntervalIndex") -> bool: # is_comparable_dtype GH#19371 if self.closed != other.closed: return True - common_subtype = find_common_type([self.dtype.subtype, other.dtype.subtype]) - return is_object_dtype(common_subtype) + return not self._is_comparable_dtype(other.dtype) # -------------------------------------------------------------------- diff --git a/pandas/core/indexes/multi.py b/pandas/core/indexes/multi.py index 6af6555007c2f..0287f6f282685 100644 --- a/pandas/core/indexes/multi.py +++ b/pandas/core/indexes/multi.py @@ -20,7 +20,7 @@ from pandas._libs import algos as libalgos, index as libindex, lib from pandas._libs.hashtable import duplicated_int64 -from pandas._typing import AnyArrayLike, Label, Scalar, Shape +from pandas._typing import AnyArrayLike, DtypeObj, Label, Scalar, Shape from pandas.compat.numpy import function as nv from pandas.errors import InvalidIndexError, PerformanceWarning, UnsortedIndexError from pandas.util._decorators import Appender, cache_readonly, doc @@ -3582,6 +3582,9 @@ def union(self, other, sort=None): zip(*uniq_tuples), sortorder=0, names=result_names ) + def _is_comparable_dtype(self, dtype: DtypeObj) -> bool: + return is_object_dtype(dtype) + def intersection(self, other, sort=False): """ Form the intersection of two MultiIndex objects. @@ -3617,15 +3620,9 @@ def intersection(self, other, sort=False): def _intersection(self, other, sort=False): other, result_names = self._convert_can_do_setop(other) - if not is_object_dtype(other.dtype): + if not self._is_comparable_dtype(other.dtype): # The intersection is empty - # TODO: we have no tests that get here - return MultiIndex( - levels=self.levels, - codes=[[]] * self.nlevels, - names=result_names, - verify_integrity=False, - ) + return self[:0].rename(result_names) lvals = self._values rvals = other._values diff --git a/pandas/core/indexes/numeric.py b/pandas/core/indexes/numeric.py index 12f61fc44582d..d4562162f7c10 100644 --- a/pandas/core/indexes/numeric.py +++ b/pandas/core/indexes/numeric.py @@ -4,7 +4,7 @@ import numpy as np from pandas._libs import index as libindex, lib -from pandas._typing import Dtype, Label +from pandas._typing import Dtype, DtypeObj, Label from pandas.util._decorators import doc from pandas.core.dtypes.cast import astype_nansafe @@ -148,6 +148,10 @@ def _convert_tolerance(self, tolerance, target): ) return tolerance + def _is_comparable_dtype(self, dtype: DtypeObj) -> bool: + # If we ever have BoolIndex or ComplexIndex, this may need to be tightened + return is_numeric_dtype(dtype) + @classmethod def _assert_safe_casting(cls, data, subarr): """ diff --git a/pandas/tests/indexes/test_base.py b/pandas/tests/indexes/test_base.py index ba49c51c9db8e..27ee370d1c036 100644 --- a/pandas/tests/indexes/test_base.py +++ b/pandas/tests/indexes/test_base.py @@ -1249,10 +1249,9 @@ def test_get_indexer_numeric_index_boolean_target(self, method, idx_class): if method == "get_indexer": tm.assert_numpy_array_equal(result, expected) else: - expected = np.array([-1, -1, -1, -1], dtype=np.intp) - + missing = np.arange(3, dtype=np.intp) tm.assert_numpy_array_equal(result[0], expected) - tm.assert_numpy_array_equal(result[1], expected) + tm.assert_numpy_array_equal(result[1], missing) def test_get_indexer_with_NA_values( self, unique_nulls_fixture, unique_nulls_fixture2 @@ -2346,5 +2345,6 @@ def construct(dtype): else: no_matches = np.array([-1] * 6, dtype=np.intp) + missing = np.arange(6, dtype=np.intp) tm.assert_numpy_array_equal(result[0], no_matches) - tm.assert_numpy_array_equal(result[1], no_matches) + tm.assert_numpy_array_equal(result[1], missing) From 1f53cd87a9fe8bed3bfa19309e9380c5bf0ee11f Mon Sep 17 00:00:00 2001 From: Brock Date: Wed, 2 Dec 2020 15:07:07 -0800 Subject: [PATCH 2/2] mypy fixup --- pandas/core/indexes/category.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/indexes/category.py b/pandas/core/indexes/category.py index 89f1647a299dc..377fff5f85e92 100644 --- a/pandas/core/indexes/category.py +++ b/pandas/core/indexes/category.py @@ -555,7 +555,7 @@ def _maybe_cast_slice_bound(self, label, side: str, kind): # -------------------------------------------------------------------- def _is_comparable_dtype(self, dtype): - return self._categories._is_comparable_dtype(dtype) + return self.categories._is_comparable_dtype(dtype) def take_nd(self, *args, **kwargs): """Alias for `take`"""