Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

REF: back IntervalArray by a single ndarray #37047

Merged
merged 9 commits into from
Oct 12, 2020
263 changes: 152 additions & 111 deletions pandas/core/arrays/interval.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from operator import le, lt
import textwrap
from typing import TYPE_CHECKING, Optional, Tuple, Union, cast

import numpy as np

Expand All @@ -11,14 +12,17 @@
IntervalMixin,
intervals_to_interval_bounds,
)
from pandas._typing import ArrayLike, Dtype
from pandas.compat.numpy import function as nv
from pandas.util._decorators import Appender

from pandas.core.dtypes.cast import maybe_convert_platform
from pandas.core.dtypes.common import (
is_categorical_dtype,
is_datetime64_any_dtype,
is_dtype_equal,
is_float_dtype,
is_integer,
is_integer_dtype,
is_interval_dtype,
is_list_like,
Expand All @@ -45,6 +49,10 @@
from pandas.core.indexers import check_array_indexer
from pandas.core.indexes.base import ensure_index

if TYPE_CHECKING:
from pandas import Index
from pandas.core.arrays import DatetimeArray, TimedeltaArray

_interval_shared_docs = {}

_shared_docs_kwargs = dict(
Expand Down Expand Up @@ -169,6 +177,17 @@ def __new__(
left = data._left
right = data._right
closed = closed or data.closed

if dtype is None or data.dtype == dtype:
# This path will preserve id(result._combined)
# TODO: could also validate dtype before going to simple_new
combined = data._combined
if copy:
combined = combined.copy()
result = cls._simple_new(combined, closed=closed)
if verify_integrity:
result._validate()
return result
else:

# don't allow scalars
Expand All @@ -186,83 +205,22 @@ def __new__(
)
closed = closed or infer_closed

return cls._simple_new(
left,
right,
closed,
copy=copy,
dtype=dtype,
verify_integrity=verify_integrity,
)
closed = closed or "right"
left, right = _maybe_cast_inputs(left, right, copy, dtype)
combined = _get_combined_data(left, right)
result = cls._simple_new(combined, closed=closed)
if verify_integrity:
result._validate()
return result

@classmethod
def _simple_new(cls, data, closed="right"):
    """
    Construct a new IntervalArray from an already-combined bounds array.

    No validation is performed here; callers that need it invoke
    ``result._validate()`` afterwards.

    Parameters
    ----------
    data : ndarray, DatetimeArray, or TimedeltaArray
        2D array of shape (n, 2): column 0 holds the left bounds,
        column 1 the right bounds.
    closed : str, default "right"
        Whether the intervals are closed on the left-side, right-side,
        both, or neither.

    Returns
    -------
    IntervalArray
    """
    result = IntervalMixin.__new__(cls)

    result._combined = data
    # _left/_right are column views into _combined, so in-place edits of
    # the combined array are reflected in both.
    result._left = data[:, 0]
    result._right = data[:, 1]
    result._closed = closed
    return result

@classmethod
Expand Down Expand Up @@ -397,10 +355,16 @@ def from_breaks(cls, breaks, closed="right", copy=False, dtype=None):
def from_arrays(cls, left, right, closed="right", copy=False, dtype=None):
    left = maybe_convert_platform_interval(left)
    right = maybe_convert_platform_interval(right)
    if len(left) != len(right):
        raise ValueError("left and right must have the same length")

    closed = closed or "right"
    # Cast/validate the two sides to a matching supported dtype, then
    # stack them into a single (n, 2) combined array.
    left, right = _maybe_cast_inputs(left, right, copy, dtype)
    combined = _get_combined_data(left, right)

    result = cls._simple_new(combined, closed)
    # _simple_new no longer validates; enforce left <= right etc. here.
    result._validate()
    return result

_interval_shared_docs["from_tuples"] = textwrap.dedent(
"""
Expand Down Expand Up @@ -506,19 +470,6 @@ def _validate(self):
msg = "left side of interval must be <= right side"
raise ValueError(msg)

def _shallow_copy(self, left, right):
"""
Return a new IntervalArray with the replacement attributes

Parameters
----------
left : Index
Values to be used for the left-side of the intervals.
right : Index
Values to be used for the right-side of the intervals.
"""
return self._simple_new(left, right, closed=self.closed, verify_integrity=False)

# ---------------------------------------------------------------------
# Descriptive

Expand Down Expand Up @@ -546,18 +497,20 @@ def __len__(self) -> int:

def __getitem__(self, key):
    key = check_array_indexer(self, key)
    result = self._combined[key]

    if is_integer(key):
        # Scalar indexing: unpack the (left, right) row into an Interval,
        # or return the NA fill value for a missing entry.
        left, right = result[0], result[1]
        if isna(left):
            return self._fill_value
        return Interval(left, right, self.closed)

    # TODO: need to watch out for incorrectly-reducing getitem
    if np.ndim(result) > 2:
        # GH#30588 multi-dimensional indexer disallowed
        raise ValueError("multi-dimensional indexing not allowed")
    return type(self)._simple_new(result, closed=self.closed)

def __setitem__(self, key, value):
value_left, value_right = self._validate_setitem_value(value)
Expand Down Expand Up @@ -651,7 +604,8 @@ def fillna(self, value=None, method=None, limit=None):

left = self.left.fillna(value=value_left)
right = self.right.fillna(value=value_right)
return self._shallow_copy(left, right)
combined = _get_combined_data(left, right)
return type(self)._simple_new(combined, closed=self.closed)

def astype(self, dtype, copy=True):
"""
Expand Down Expand Up @@ -693,7 +647,9 @@ def astype(self, dtype, copy=True):
f"Cannot convert {self.dtype} to {dtype}; subtypes are incompatible"
)
raise TypeError(msg) from err
return self._shallow_copy(new_left, new_right)
# TODO: do astype directly on self._combined
combined = _get_combined_data(new_left, new_right)
return type(self)._simple_new(combined, closed=self.closed)
elif is_categorical_dtype(dtype):
return Categorical(np.asarray(self))
elif isinstance(dtype, StringDtype):
Expand Down Expand Up @@ -734,9 +690,11 @@ def _concat_same_type(cls, to_concat):
raise ValueError("Intervals must all be closed on the same side.")
closed = closed.pop()

# TODO: will this mess up on dt64tz?
left = np.concatenate([interval.left for interval in to_concat])
right = np.concatenate([interval.right for interval in to_concat])
return cls._simple_new(left, right, closed=closed, copy=False)
combined = _get_combined_data(left, right) # TODO: 1-stage concat
return cls._simple_new(combined, closed=closed)

def copy(self):
    """
    Return a copy of the array.

    Returns
    -------
    IntervalArray
    """
    # Copying the combined array copies both bounds in one step; the
    # _left/_right views are rebuilt by _simple_new.
    combined = self._combined.copy()
    return type(self)._simple_new(combined, closed=self.closed)

def isna(self) -> np.ndarray:
return isna(self._left)
Expand Down Expand Up @@ -843,7 +798,8 @@ def take(self, indices, allow_fill=False, fill_value=None, axis=None, **kwargs):
self._right, indices, allow_fill=allow_fill, fill_value=fill_right
)

return self._shallow_copy(left_take, right_take)
combined = _get_combined_data(left_take, right_take)
return type(self)._simple_new(combined, closed=self.closed)

def _validate_listlike(self, value):
# list-like of intervals
Expand Down Expand Up @@ -1170,10 +1126,7 @@ def set_closed(self, closed):
if closed not in VALID_CLOSED:
msg = f"invalid option for 'closed': {closed}"
raise ValueError(msg)

return type(self)._simple_new(
left=self._left, right=self._right, closed=closed, verify_integrity=False
)
return type(self)._simple_new(self._combined, closed=closed)

_interval_shared_docs[
"is_non_overlapping_monotonic"
Expand Down Expand Up @@ -1314,9 +1267,8 @@ def to_tuples(self, na_tuple=True):
@Appender(_extension_array_shared_docs["repeat"] % _shared_docs_kwargs)
def repeat(self, repeats, axis=None):
    nv.validate_repeat(tuple(), dict(axis=axis))
    # Repeat along axis 0 (the length axis); axis 1 stays (left, right).
    combined = self._combined.repeat(repeats, 0)
    return type(self)._simple_new(combined, closed=self.closed)

_interval_shared_docs["contains"] = textwrap.dedent(
"""
Expand Down Expand Up @@ -1399,3 +1351,92 @@ def maybe_convert_platform_interval(values):
values = np.asarray(values)

return maybe_convert_platform(values)


def _maybe_cast_inputs(
left_orig: Union["Index", ArrayLike],
right_orig: Union["Index", ArrayLike],
copy: bool,
dtype: Optional[Dtype],
) -> Tuple["Index", "Index"]:
left = ensure_index(left_orig, copy=copy)
right = ensure_index(right_orig, copy=copy)

if dtype is not None:
# GH#19262: dtype must be an IntervalDtype to override inferred
dtype = pandas_dtype(dtype)
if not is_interval_dtype(dtype):
msg = f"dtype must be an IntervalDtype, got {dtype}"
raise TypeError(msg)
dtype = cast(IntervalDtype, dtype)
if dtype.subtype is not None:
left = left.astype(dtype.subtype)
right = right.astype(dtype.subtype)

# coerce dtypes to match if needed
if is_float_dtype(left) and is_integer_dtype(right):
right = right.astype(left.dtype)
elif is_float_dtype(right) and is_integer_dtype(left):
left = left.astype(right.dtype)

if type(left) != type(right):
msg = (
f"must not have differing left [{type(left).__name__}] and "
f"right [{type(right).__name__}] types"
)
raise ValueError(msg)
elif is_categorical_dtype(left.dtype) or is_string_dtype(left.dtype):
# GH#19016
msg = (
"category, object, and string subtypes are not supported "
"for IntervalArray"
)
raise TypeError(msg)
elif isinstance(left, ABCPeriodIndex):
msg = "Period dtypes are not supported, use a PeriodIndex instead"
raise ValueError(msg)
elif isinstance(left, ABCDatetimeIndex) and not is_dtype_equal(
left.dtype, right.dtype
):
left_arr = cast("DatetimeArray", left._data)
right_arr = cast("DatetimeArray", right._data)
msg = (
"left and right must have the same time zone, got "
f"'{left_arr.tz}' and '{right_arr.tz}'"
)
raise ValueError(msg)

return left, right


def _get_combined_data(
    left: Union["Index", ArrayLike], right: Union["Index", ArrayLike]
) -> Union[np.ndarray, "DatetimeArray", "TimedeltaArray"]:
    """
    Stack the left/right bounds into a single 2D array of shape (n, 2).

    Column 0 holds the left bounds, column 1 the right bounds.

    Parameters
    ----------
    left : Index or array-like
    right : Index or array-like
        Already cast to matching dtypes (see ``_maybe_cast_inputs``).

    Returns
    -------
    ndarray, DatetimeArray, or TimedeltaArray
    """
    # For dt64/td64 we want DatetimeArray/TimedeltaArray instead of ndarray
    from pandas.core.ops.array_ops import maybe_upcast_datetimelike_array

    left = maybe_upcast_datetimelike_array(left)
    left = extract_array(left, extract_numpy=True)
    right = maybe_upcast_datetimelike_array(right)
    right = extract_array(right, extract_numpy=True)

    lbase = getattr(left, "_ndarray", left).base
    rbase = getattr(right, "_ndarray", right).base
    if lbase is not None and lbase is rbase:
        # If these share data, then setitem could corrupt our IA
        right = right.copy()

    if isinstance(left, np.ndarray):
        assert isinstance(right, np.ndarray)  # for mypy
        combined = np.concatenate(
            [left.reshape(-1, 1), right.reshape(-1, 1)],
            axis=1,
        )
    else:
        left = cast(Union["DatetimeArray", "TimedeltaArray"], left)
        right = cast(Union["DatetimeArray", "TimedeltaArray"], right)
        combined = type(left)._concat_same_type(
            [left.reshape(-1, 1), right.reshape(-1, 1)],
            axis=1,
        )
    return combined
Loading