Skip to content

Commit

Permalink
Ensure TDA.__init__ validates freq (pandas-dev#24666)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored and jreback committed Jan 9, 2019
1 parent 46a31c9 commit decc8ce
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 64 deletions.
91 changes: 30 additions & 61 deletions pandas/core/arrays/timedeltas.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from pandas.util._decorators import Appender

from pandas.core.dtypes.common import (
_NS_DTYPE, _TD_DTYPE, ensure_int64, is_datetime64_dtype, is_float_dtype,
is_integer_dtype, is_list_like, is_object_dtype, is_scalar,
_NS_DTYPE, _TD_DTYPE, ensure_int64, is_datetime64_dtype, is_dtype_equal,
is_float_dtype, is_integer_dtype, is_list_like, is_object_dtype, is_scalar,
is_string_dtype, is_timedelta64_dtype, is_timedelta64_ns_dtype,
pandas_dtype)
from pandas.core.dtypes.dtypes import DatetimeTZDtype
Expand Down Expand Up @@ -134,55 +134,39 @@ def dtype(self):
_attributes = ["freq"]

def __init__(self, values, dtype=_TD_DTYPE, freq=None, copy=False):
if isinstance(values, (ABCSeries, ABCIndexClass)):
values = values._values

if isinstance(values, type(self)):
values, freq, freq_infer = extract_values_freq(values, freq)

if not isinstance(values, np.ndarray):
msg = (
if not hasattr(values, "dtype"):
raise ValueError(
"Unexpected type '{}'. 'values' must be a TimedeltaArray "
"ndarray, or Series or Index containing one of those."
)
raise ValueError(msg.format(type(values).__name__))

if values.dtype == 'i8':
# for compat with datetime/timedelta/period shared methods,
# we can sometimes get here with int64 values. These represent
# nanosecond UTC (or tz-naive) unix timestamps
values = values.view(_TD_DTYPE)

if values.dtype != _TD_DTYPE:
raise TypeError(_BAD_DTYPE.format(dtype=values.dtype))

try:
dtype_mismatch = dtype != _TD_DTYPE
except TypeError:
raise TypeError(_BAD_DTYPE.format(dtype=dtype))
else:
if dtype_mismatch:
raise TypeError(_BAD_DTYPE.format(dtype=dtype))

.format(type(values).__name__))
if freq == "infer":
msg = (
raise ValueError(
"Frequency inference not allowed in TimedeltaArray.__init__. "
"Use 'pd.array()' instead."
)
raise ValueError(msg)
"Use 'pd.array()' instead.")

if copy:
values = values.copy()
if freq:
freq = to_offset(freq)
if dtype is not None and not is_dtype_equal(dtype, _TD_DTYPE):
raise TypeError("dtype {dtype} cannot be converted to "
"timedelta64[ns]".format(dtype=dtype))

if values.dtype == 'i8':
values = values.view('timedelta64[ns]')

self._data = values
self._dtype = dtype
self._freq = freq
result = type(self)._from_sequence(values, dtype=dtype,
copy=copy, freq=freq)
self._data = result._data
self._freq = result._freq
self._dtype = result._dtype

@classmethod
def _simple_new(cls, values, freq=None, dtype=_TD_DTYPE):
return cls(values, dtype=dtype, freq=freq)
assert dtype == _TD_DTYPE, dtype
assert isinstance(values, np.ndarray), type(values)

result = object.__new__(cls)
result._data = values.view(_TD_DTYPE)
result._freq = to_offset(freq)
result._dtype = _TD_DTYPE
return result

@classmethod
def _from_sequence(cls, data, dtype=_TD_DTYPE, copy=False,
Expand Down Expand Up @@ -860,17 +844,17 @@ def sequence_to_td64ns(data, copy=False, unit="ns", errors="raise"):
data = data._data

# Convert whatever we have into timedelta64[ns] dtype
if is_object_dtype(data) or is_string_dtype(data):
if is_object_dtype(data.dtype) or is_string_dtype(data.dtype):
# no need to make a copy, need to convert if string-dtyped
data = objects_to_td64ns(data, unit=unit, errors=errors)
copy = False

elif is_integer_dtype(data):
elif is_integer_dtype(data.dtype):
# treat as multiples of the given unit
data, copy_made = ints_to_td64ns(data, unit=unit)
copy = copy and not copy_made

elif is_float_dtype(data):
elif is_float_dtype(data.dtype):
# treat as multiples of the given unit. If after converting to nanos,
# there are fractional components left, these are truncated
# (i.e. NOT rounded)
Expand All @@ -880,7 +864,7 @@ def sequence_to_td64ns(data, copy=False, unit="ns", errors="raise"):
data[mask] = iNaT
copy = False

elif is_timedelta64_dtype(data):
elif is_timedelta64_dtype(data.dtype):
if data.dtype != _TD_DTYPE:
# non-nano unit
# TODO: watch out for overflows
Expand Down Expand Up @@ -998,18 +982,3 @@ def _generate_regular_range(start, end, periods, offset):

data = np.arange(b, e, stride, dtype=np.int64)
return data


def extract_values_freq(arr, freq):
    # type: (TimedeltaArray, Offset) -> Tuple[ndarray, Offset, bool]
    """
    Unpack the underlying data of ``arr`` and reconcile ``freq`` with it.

    Parameters
    ----------
    arr : TimedeltaArray
        Array whose ``_data`` and ``freq`` attributes are read.
    freq : Offset or None
        Frequency requested by the caller.  ``None`` means "use whatever
        ``arr`` already carries".

    Returns
    -------
    values
        ``arr._data``, returned as-is (no copy is made here).
    freq
        ``arr.freq`` when the passed ``freq`` is None; otherwise the passed
        ``freq``, normalised through ``to_offset`` and checked against
        ``arr.freq`` when both are set.
    freq_infer : bool
        False unless ``dtl.validate_inferred_freq`` returns something else.
    """
    freq_infer = False

    if freq is None:
        # Nothing requested: inherit the frequency already on the array.
        freq = arr.freq
    elif freq and arr.freq:
        # Both sides carry a frequency, so they have to be reconciled.
        # NOTE(review): presumably validate_inferred_freq raises when the
        # two disagree -- confirm against its definition in ``dtl``.
        freq = to_offset(freq)
        freq, freq_infer = dtl.validate_inferred_freq(
            freq, arr.freq,
            freq_infer=False
        )
    # Remaining case (freq passed, arr.freq unset): ``freq`` is handed back
    # exactly as received, without conversion.

    values = arr._data
    return values, freq, freq_infer
6 changes: 4 additions & 2 deletions pandas/core/indexes/timedeltas.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,12 +233,14 @@ def _simple_new(cls, values, name=None, freq=None, dtype=_TD_DTYPE):
if not isinstance(values, TimedeltaArray):
values = TimedeltaArray._simple_new(values, dtype=dtype,
freq=freq)
else:
if freq is None:
freq = values.freq
assert isinstance(values, TimedeltaArray), type(values)
assert dtype == _TD_DTYPE, dtype
assert values.dtype == 'm8[ns]', values.dtype

freq = to_offset(freq)
tdarr = TimedeltaArray._simple_new(values, freq=freq)
tdarr = TimedeltaArray._simple_new(values._data, freq=freq)
result = object.__new__(cls)
result._data = tdarr
result.name = name
Expand Down
11 changes: 10 additions & 1 deletion pandas/tests/arrays/test_timedeltas.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,15 @@


class TestTimedeltaArrayConstructor(object):
def test_freq_validation(self):
    """
    The public constructor must reject a ``freq`` the values do not fit.

    The data are 0h, 0h and 1h expressed in nanoseconds, and the test
    expects passing ``freq="D"`` to raise a ValueError with the message
    below.
    """
    # ensure that the public constructor cannot create an invalid instance
    arr = np.array([0, 0, 1], dtype=np.int64) * 3600 * 10**9

    msg = ("Inferred frequency None from passed values does not "
           "conform to passed frequency D")
    with pytest.raises(ValueError, match=msg):
        TimedeltaArray(arr.view('timedelta64[ns]'), freq="D")

def test_non_array_raises(self):
    """
    A plain Python list is not accepted by the constructor.

    The ValueError is expected to mention the offending type, hence the
    ``match='list'`` check.
    """
    with pytest.raises(ValueError, match='list'):
        TimedeltaArray([1, 2, 3])
Expand All @@ -34,7 +43,7 @@ def test_incorrect_dtype_raises(self):
def test_copy(self):
data = np.array([1, 2, 3], dtype='m8[ns]')
arr = TimedeltaArray(data, copy=False)
assert arr._data is data
assert arr._data.base is data

arr = TimedeltaArray(data, copy=True)
assert arr._data is not data
Expand Down

0 comments on commit decc8ce

Please sign in to comment.