diff --git a/pandas/core/arrays/interval.py b/pandas/core/arrays/interval.py index a06e0c74ec03b..d943fe3df88c5 100644 --- a/pandas/core/arrays/interval.py +++ b/pandas/core/arrays/interval.py @@ -1,5 +1,6 @@ from operator import le, lt import textwrap +from typing import TYPE_CHECKING, Optional, Tuple, Union, cast import numpy as np @@ -11,6 +12,7 @@ IntervalMixin, intervals_to_interval_bounds, ) +from pandas._typing import ArrayLike, Dtype from pandas.compat.numpy import function as nv from pandas.util._decorators import Appender @@ -18,7 +20,9 @@ from pandas.core.dtypes.common import ( is_categorical_dtype, is_datetime64_any_dtype, + is_dtype_equal, is_float_dtype, + is_integer, is_integer_dtype, is_interval_dtype, is_list_like, @@ -45,6 +49,10 @@ from pandas.core.indexers import check_array_indexer from pandas.core.indexes.base import ensure_index +if TYPE_CHECKING: + from pandas import Index + from pandas.core.arrays import DatetimeArray, TimedeltaArray + _interval_shared_docs = {} _shared_docs_kwargs = dict( @@ -169,6 +177,17 @@ def __new__( left = data._left right = data._right closed = closed or data.closed + + if dtype is None or data.dtype == dtype: + # This path will preserve id(result._combined) + # TODO: could also validate dtype before going to simple_new + combined = data._combined + if copy: + combined = combined.copy() + result = cls._simple_new(combined, closed=closed) + if verify_integrity: + result._validate() + return result else: # don't allow scalars @@ -186,83 +205,22 @@ def __new__( ) closed = closed or infer_closed - return cls._simple_new( - left, - right, - closed, - copy=copy, - dtype=dtype, - verify_integrity=verify_integrity, - ) + closed = closed or "right" + left, right = _maybe_cast_inputs(left, right, copy, dtype) + combined = _get_combined_data(left, right) + result = cls._simple_new(combined, closed=closed) + if verify_integrity: + result._validate() + return result @classmethod - def _simple_new( - cls, left, right, closed=None, copy=False, dtype=None, verify_integrity=True - ): + def _simple_new(cls, data, closed="right"): result = IntervalMixin.__new__(cls) - closed = closed or "right" - left = ensure_index(left, copy=copy) - right = ensure_index(right, copy=copy) - - if dtype is not None: - # GH 19262: dtype must be an IntervalDtype to override inferred - dtype = pandas_dtype(dtype) - if not is_interval_dtype(dtype): - msg = f"dtype must be an IntervalDtype, got {dtype}" - raise TypeError(msg) - elif dtype.subtype is not None: - left = left.astype(dtype.subtype) - right = right.astype(dtype.subtype) - - # coerce dtypes to match if needed - if is_float_dtype(left) and is_integer_dtype(right): - right = right.astype(left.dtype) - elif is_float_dtype(right) and is_integer_dtype(left): - left = left.astype(right.dtype) - - if type(left) != type(right): - msg = ( - f"must not have differing left [{type(left).__name__}] and " - f"right [{type(right).__name__}] types" - ) - raise ValueError(msg) - elif is_categorical_dtype(left.dtype) or is_string_dtype(left.dtype): - # GH 19016 - msg = ( - "category, object, and string subtypes are not supported " - "for IntervalArray" - ) - raise TypeError(msg) - elif isinstance(left, ABCPeriodIndex): - msg = "Period dtypes are not supported, use a PeriodIndex instead" - raise ValueError(msg) - elif isinstance(left, ABCDatetimeIndex) and str(left.tz) != str(right.tz): - msg = ( - "left and right must have the same time zone, got " - f"'{left.tz}' and '{right.tz}'" - ) - raise ValueError(msg) - - # For dt64/td64 we want DatetimeArray/TimedeltaArray instead of ndarray - from pandas.core.ops.array_ops import maybe_upcast_datetimelike_array - - left = maybe_upcast_datetimelike_array(left) - left = extract_array(left, extract_numpy=True) - right = maybe_upcast_datetimelike_array(right) - right = extract_array(right, extract_numpy=True) - - lbase = getattr(left, "_ndarray", left).base - rbase = getattr(right, "_ndarray", right).base - if lbase is not None and lbase is rbase: - # If these share data, then setitem could corrupt our IA - right = right.copy() - - result._left = left - result._right = right + result._combined = data + result._left = data[:, 0] + result._right = data[:, 1] result._closed = closed - if verify_integrity: - result._validate() return result @classmethod @@ -397,10 +355,16 @@ def from_breaks(cls, breaks, closed="right", copy=False, dtype=None): def from_arrays(cls, left, right, closed="right", copy=False, dtype=None): left = maybe_convert_platform_interval(left) right = maybe_convert_platform_interval(right) + if len(left) != len(right): + raise ValueError("left and right must have the same length") - return cls._simple_new( - left, right, closed, copy=copy, dtype=dtype, verify_integrity=True - ) + closed = closed or "right" + left, right = _maybe_cast_inputs(left, right, copy, dtype) + combined = _get_combined_data(left, right) + + result = cls._simple_new(combined, closed) + result._validate() + return result _interval_shared_docs["from_tuples"] = textwrap.dedent( """ @@ -506,19 +470,6 @@ def _validate(self): msg = "left side of interval must be <= right side" raise ValueError(msg) - def _shallow_copy(self, left, right): - """ - Return a new IntervalArray with the replacement attributes - - Parameters - ---------- - left : Index - Values to be used for the left-side of the intervals. - right : Index - Values to be used for the right-side of the intervals. - """ - return self._simple_new(left, right, closed=self.closed, verify_integrity=False) - # --------------------------------------------------------------------- # Descriptive @@ -546,18 +497,20 @@ def __len__(self) -> int: def __getitem__(self, key): key = check_array_indexer(self, key) - left = self._left[key] - right = self._right[key] - if not isinstance(left, (np.ndarray, ExtensionArray)): - # scalar - if is_scalar(left) and isna(left): + result = self._combined[key] + + if is_integer(key): + left, right = result[0], result[1] + if isna(left): return self._fill_value return Interval(left, right, self.closed) - if np.ndim(left) > 1: + + # TODO: need to watch out for incorrectly-reducing getitem + if np.ndim(result) > 2: # GH#30588 multi-dimensional indexer disallowed raise ValueError("multi-dimensional indexing not allowed") - return self._shallow_copy(left, right) + return type(self)._simple_new(result, closed=self.closed) def __setitem__(self, key, value): value_left, value_right = self._validate_setitem_value(value) @@ -651,7 +604,8 @@ def fillna(self, value=None, method=None, limit=None): left = self.left.fillna(value=value_left) right = self.right.fillna(value=value_right) - return self._shallow_copy(left, right) + combined = _get_combined_data(left, right) + return type(self)._simple_new(combined, closed=self.closed) def astype(self, dtype, copy=True): """ @@ -693,7 +647,9 @@ def astype(self, dtype, copy=True): f"Cannot convert {self.dtype} to {dtype}; subtypes are incompatible" ) raise TypeError(msg) from err - return self._shallow_copy(new_left, new_right) + # TODO: do astype directly on self._combined + combined = _get_combined_data(new_left, new_right) + return type(self)._simple_new(combined, closed=self.closed) elif is_categorical_dtype(dtype): return Categorical(np.asarray(self)) elif isinstance(dtype, StringDtype): @@ -734,9 +690,11 @@ def _concat_same_type(cls, to_concat): raise ValueError("Intervals must all be closed on the same side.") closed = closed.pop() + # TODO: will this mess up on dt64tz? left = np.concatenate([interval.left for interval in to_concat]) right = np.concatenate([interval.right for interval in to_concat]) - return cls._simple_new(left, right, closed=closed, copy=False) + combined = _get_combined_data(left, right) # TODO: 1-stage concat + return cls._simple_new(combined, closed=closed) def copy(self): """ @@ -746,11 +704,8 @@ def copy(self): ------- IntervalArray """ - left = self._left.copy() - right = self._right.copy() - closed = self.closed - # TODO: Could skip verify_integrity here. - return type(self).from_arrays(left, right, closed=closed) + combined = self._combined.copy() + return type(self)._simple_new(combined, closed=self.closed) def isna(self) -> np.ndarray: return isna(self._left) @@ -843,7 +798,8 @@ def take(self, indices, allow_fill=False, fill_value=None, axis=None, **kwargs): self._right, indices, allow_fill=allow_fill, fill_value=fill_right ) - return self._shallow_copy(left_take, right_take) + combined = _get_combined_data(left_take, right_take) + return type(self)._simple_new(combined, closed=self.closed) def _validate_listlike(self, value): # list-like of intervals @@ -1170,10 +1126,7 @@ def set_closed(self, closed): if closed not in VALID_CLOSED: msg = f"invalid option for 'closed': {closed}" raise ValueError(msg) - - return type(self)._simple_new( - left=self._left, right=self._right, closed=closed, verify_integrity=False - ) + return type(self)._simple_new(self._combined, closed=closed) _interval_shared_docs[ "is_non_overlapping_monotonic" @@ -1314,9 +1267,8 @@ def to_tuples(self, na_tuple=True): @Appender(_extension_array_shared_docs["repeat"] % _shared_docs_kwargs) def repeat(self, repeats, axis=None): nv.validate_repeat(tuple(), dict(axis=axis)) - left_repeat = self.left.repeat(repeats) - right_repeat = self.right.repeat(repeats) - return self._shallow_copy(left=left_repeat, right=right_repeat) + combined = self._combined.repeat(repeats, 0) + return type(self)._simple_new(combined, closed=self.closed) _interval_shared_docs["contains"] = textwrap.dedent( """ @@ -1399,3 +1351,92 @@ def maybe_convert_platform_interval(values): values = np.asarray(values) return maybe_convert_platform(values) + + +def _maybe_cast_inputs( + left_orig: Union["Index", ArrayLike], + right_orig: Union["Index", ArrayLike], + copy: bool, + dtype: Optional[Dtype], +) -> Tuple["Index", "Index"]: + left = ensure_index(left_orig, copy=copy) + right = ensure_index(right_orig, copy=copy) + + if dtype is not None: + # GH#19262: dtype must be an IntervalDtype to override inferred + dtype = pandas_dtype(dtype) + if not is_interval_dtype(dtype): + msg = f"dtype must be an IntervalDtype, got {dtype}" + raise TypeError(msg) + dtype = cast(IntervalDtype, dtype) + if dtype.subtype is not None: + left = left.astype(dtype.subtype) + right = right.astype(dtype.subtype) + + # coerce dtypes to match if needed + if is_float_dtype(left) and is_integer_dtype(right): + right = right.astype(left.dtype) + elif is_float_dtype(right) and is_integer_dtype(left): + left = left.astype(right.dtype) + + if type(left) != type(right): + msg = ( + f"must not have differing left [{type(left).__name__}] and " + f"right [{type(right).__name__}] types" + ) + raise ValueError(msg) + elif is_categorical_dtype(left.dtype) or is_string_dtype(left.dtype): + # GH#19016 + msg = ( + "category, object, and string subtypes are not supported " + "for IntervalArray" + ) + raise TypeError(msg) + elif isinstance(left, ABCPeriodIndex): + msg = "Period dtypes are not supported, use a PeriodIndex instead" + raise ValueError(msg) + elif isinstance(left, ABCDatetimeIndex) and not is_dtype_equal( + left.dtype, right.dtype + ): + left_arr = cast("DatetimeArray", left._data) + right_arr = cast("DatetimeArray", right._data) + msg = ( + "left and right must have the same time zone, got " + f"'{left_arr.tz}' and '{right_arr.tz}'" + ) + raise ValueError(msg) + + return left, right + + +def _get_combined_data( + left: Union["Index", ArrayLike], right: Union["Index", ArrayLike] +) -> Union[np.ndarray, "DatetimeArray", "TimedeltaArray"]: + # For dt64/td64 we want DatetimeArray/TimedeltaArray instead of ndarray + from pandas.core.ops.array_ops import maybe_upcast_datetimelike_array + + left = maybe_upcast_datetimelike_array(left) + left = extract_array(left, extract_numpy=True) + right = maybe_upcast_datetimelike_array(right) + right = extract_array(right, extract_numpy=True) + + lbase = getattr(left, "_ndarray", left).base + rbase = getattr(right, "_ndarray", right).base + if lbase is not None and lbase is rbase: + # If these share data, then setitem could corrupt our IA + right = right.copy() + + if isinstance(left, np.ndarray): + assert isinstance(right, np.ndarray) # for mypy + combined = np.concatenate( + [left.reshape(-1, 1), right.reshape(-1, 1)], + axis=1, + ) + else: + left = cast(Union["DatetimeArray", "TimedeltaArray"], left) + right = cast(Union["DatetimeArray", "TimedeltaArray"], right) + combined = type(left)._concat_same_type( + [left.reshape(-1, 1), right.reshape(-1, 1)], + axis=1, + ) + return combined diff --git a/pandas/core/indexes/interval.py b/pandas/core/indexes/interval.py index cc47740dba5f2..cb25ef1241ce0 100644 --- a/pandas/core/indexes/interval.py +++ b/pandas/core/indexes/interval.py @@ -896,7 +896,7 @@ def delete(self, loc): """ new_left = self.left.delete(loc) new_right = self.right.delete(loc) - result = self._data._shallow_copy(new_left, new_right) + result = IntervalArray.from_arrays(new_left, new_right, closed=self.closed) return self._shallow_copy(result) def insert(self, loc, item): @@ -918,7 +918,7 @@ def insert(self, loc, item): new_left = self.left.insert(loc, left_insert) new_right = self.right.insert(loc, right_insert) - result = self._data._shallow_copy(new_left, new_right) + result = IntervalArray.from_arrays(new_left, new_right, closed=self.closed) return self._shallow_copy(result) @Appender(_index_shared_docs["take"] % _index_doc_kwargs) diff --git a/pandas/tests/base/test_conversion.py b/pandas/tests/base/test_conversion.py index b5595ba220a15..26ad6fc1c6572 100644 --- a/pandas/tests/base/test_conversion.py +++ b/pandas/tests/base/test_conversion.py @@ -241,7 +241,7 @@ def test_numpy_array_all_dtypes(any_numpy_dtype): (pd.Categorical(["a", "b"]), "_codes"), (pd.core.arrays.period_array(["2000", "2001"], freq="D"), "_data"), (pd.core.arrays.integer_array([0, np.nan]), "_data"), - (IntervalArray.from_breaks([0, 1]), "_left"), + (IntervalArray.from_breaks([0, 1]), "_combined"), (SparseArray([0, 1]), "_sparse_values"), (DatetimeArray(np.array([1, 2], dtype="datetime64[ns]")), "_data"), # tz-aware Datetime diff --git a/pandas/tests/indexes/interval/test_constructors.py b/pandas/tests/indexes/interval/test_constructors.py index aec7de549744f..c0ca0b415ba8e 100644 --- a/pandas/tests/indexes/interval/test_constructors.py +++ b/pandas/tests/indexes/interval/test_constructors.py @@ -266,7 +266,11 @@ def test_left_right_dont_share_data(self): # GH#36310 breaks = np.arange(5) result = IntervalIndex.from_breaks(breaks)._data - assert result._left.base is None or result._left.base is not result._right.base + left = result._left + right = result._right + + left[:] = 10000 + assert not (right == 10000).any() class TestFromTuples(Base):