Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

API: Infer dtype scalar #16408

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions doc/source/whatsnew/v0.21.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,20 @@ Backwards incompatible API changes
- Accessing a non-existent attribute on a closed :class:`HDFStore` will now
raise an ``AttributeError`` rather than a ``ClosedFileError`` (:issue:`16301`)

.. _whatsnew_0210.dtype_conversions:

Dtype Conversions
^^^^^^^^^^^^^^^^^

Example about setitem / where with bools.



- Inconsistent behavior in ``.where()`` with datetimelikes which would raise rather than coerce to ``object`` (:issue:`16402`)
- Bug in assignment against datetime-like data with ``int`` may incorrectly convert to datetime-like (:issue:`14145`)
- Bug in assignment against ``int64`` data with ``np.ndarray`` with ``float64`` dtype may keep ``int64`` dtype (:issue:`14001`)


.. _whatsnew_0210.api:

Other API Changes
Expand Down
26 changes: 20 additions & 6 deletions pandas/_libs/index.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ cimport tslib
from hashtable cimport *
from pandas._libs import tslib, algos, hashtable as _hash
from pandas._libs.tslib import Timestamp, Timedelta
from datetime import datetime, timedelta

from datetime cimport (get_datetime64_value, _pydatetime_to_dts,
pandas_datetimestruct)
Expand Down Expand Up @@ -507,24 +508,37 @@ cdef class TimedeltaEngine(DatetimeEngine):
return 'm8[ns]'

cpdef convert_scalar(ndarray arr, object value):
# we don't turn integers
# into datetimes/timedeltas

# we don't turn bools into int/float/complex

if arr.descr.type_num == NPY_DATETIME:
if isinstance(value, np.ndarray):
pass
elif isinstance(value, Timestamp):
return value.value
elif isinstance(value, datetime):
return Timestamp(value).value
elif value is None or value != value:
return iNaT
else:
elif util.is_string_object(value):
return Timestamp(value).value
raise ValueError("cannot set a Timestamp with a non-timestamp")

elif arr.descr.type_num == NPY_TIMEDELTA:
if isinstance(value, np.ndarray):
pass
elif isinstance(value, Timedelta):
return value.value
elif isinstance(value, timedelta):
return Timedelta(value).value
elif value is None or value != value:
return iNaT
else:
elif util.is_string_object(value):
return Timedelta(value).value
raise ValueError("cannot set a Timedelta with a non-timedelta")

if (issubclass(arr.dtype.type, (np.integer, np.floating, np.complex)) and not
issubclass(arr.dtype.type, np.bool_)):
if util.is_bool_object(value):
raise ValueError('Cannot assign bool to float/integer series')

if issubclass(arr.dtype.type, (np.integer, np.bool_)):
if util.is_float_object(value) and value != value:
Expand Down
75 changes: 66 additions & 9 deletions pandas/core/dtypes/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import warnings

from pandas._libs import tslib, lib
from pandas._libs.tslib import iNaT
from pandas._libs.tslib import iNaT, Timestamp
from pandas.compat import string_types, text_type, PY3
from .common import (_ensure_object, is_bool, is_integer, is_float,
is_complex, is_datetimetz, is_categorical_dtype,
Expand Down Expand Up @@ -272,7 +272,7 @@ def maybe_promote(dtype, fill_value=np.nan):
else:
if issubclass(dtype.type, np.datetime64):
try:
fill_value = lib.Timestamp(fill_value).value
fill_value = Timestamp(fill_value).value
except:
# the proper thing to do here would probably be to upcast
# to object (but numpy 1.6.1 doesn't do this properly)
Expand Down Expand Up @@ -333,6 +333,23 @@ def maybe_promote(dtype, fill_value=np.nan):
return dtype, fill_value


def infer_dtype_from(val, pandas_dtype=False):
    """
    Infer the dtype of a scalar or of an array.

    Convenience dispatcher: a scalar is handed to
    ``infer_dtype_from_scalar`` and anything else is handed to
    ``infer_dtype_from_array``.

    Parameters
    ----------
    val : scalar or array
        the object whose dtype is inferred
    pandas_dtype : bool, default False
        whether to infer dtype including pandas extension types.
        If False, a scalar/array belonging to a pandas extension type is
        inferred as object

    Returns
    -------
    the result of whichever helper was selected for ``val``
    """
    # pick the helper first, then make a single call with the same keyword
    inferrer = (infer_dtype_from_scalar if is_scalar(val)
                else infer_dtype_from_array)
    return inferrer(val, pandas_dtype=pandas_dtype)


def infer_dtype_from_scalar(val, pandas_dtype=False):
"""
interpret the dtype from a scalar
Expand All @@ -349,9 +366,9 @@ def infer_dtype_from_scalar(val, pandas_dtype=False):

# a 1-element ndarray
if isinstance(val, np.ndarray):
msg = "invalid ndarray passed to _infer_dtype_from_scalar"
if val.ndim != 0:
raise ValueError(
"invalid ndarray passed to _infer_dtype_from_scalar")
raise ValueError(msg)

dtype = val.dtype
val = val.item()
Expand Down Expand Up @@ -408,24 +425,32 @@ def infer_dtype_from_scalar(val, pandas_dtype=False):
return dtype, val


def infer_dtype_from_array(arr):
def infer_dtype_from_array(arr, pandas_dtype=False):
"""
infer the dtype from a scalar or array

Parameters
----------
arr : scalar or array
pandas_dtype : bool, default False
whether to infer dtype including pandas extension types.
If False, array belongs to pandas extension types
is inferred as object

Returns
-------
tuple (numpy-compat dtype, array)
tuple (numpy-compat/pandas-compat dtype, array)

Notes
-----
These infer to numpy dtypes exactly
with the exception that mixed / object dtypes

if pandas_dtype=False. these infer to numpy dtypes
exactly with the exception that mixed / object dtypes
are not coerced by stringifying or conversion

if pandas_dtype=True. datetime64tz-aware/categorical
types will retain their character.

Examples
--------
>>> np.asarray([1, '1'])
Expand All @@ -442,6 +467,10 @@ def infer_dtype_from_array(arr):
if not is_list_like(arr):
arr = [arr]

if pandas_dtype and (is_categorical_dtype(arr) or
is_datetime64tz_dtype(arr)):
return arr.dtype, arr

# don't force numpy coerce with nan's
inferred = lib.infer_dtype(arr)
if inferred in ['string', 'bytes', 'unicode',
Expand Down Expand Up @@ -552,7 +581,7 @@ def conv(r, dtype):
if isnull(r):
pass
elif dtype == _NS_DTYPE:
r = lib.Timestamp(r)
r = Timestamp(r)
elif dtype == _TD_DTYPE:
r = _coerce_scalar_to_timedelta_type(r)
elif dtype == np.bool_:
Expand Down Expand Up @@ -1026,3 +1055,31 @@ def find_common_type(types):
return np.object

return np.find_common_type(types, [])


def cast_scalar_to_array(shape, value, dtype=None):
    """
    Build an np.ndarray of the requested shape with every element set
    to ``value``.

    Parameters
    ----------
    shape : tuple
    value : scalar value
    dtype : np.dtype, optional
        dtype to coerce to; when omitted it is inferred from ``value``

    Returns
    -------
    ndarray of shape, filled with value, of specified / inferred dtype

    """
    fill_value = value
    if dtype is None:
        # no explicit dtype: infer it from the scalar and fill with the
        # value handed back by the inference helper
        dtype, fill_value = infer_dtype_from_scalar(value)

    result = np.empty(shape, dtype=dtype)
    result.fill(fill_value)
    return result
14 changes: 13 additions & 1 deletion pandas/core/dtypes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
ExtensionDtype)
from .generic import (ABCCategorical, ABCPeriodIndex,
ABCDatetimeIndex, ABCSeries,
ABCSparseArray, ABCSparseSeries)
ABCSparseArray, ABCSparseSeries,
ABCIndexClass)
from .inference import is_string_like
from .inference import * # noqa

Expand Down Expand Up @@ -1535,11 +1536,22 @@ def is_bool_dtype(arr_or_dtype):

if arr_or_dtype is None:
return False

try:
tipo = _get_dtype_type(arr_or_dtype)
except ValueError:
# this isn't even a dtype
return False

if isinstance(arr_or_dtype, ABCIndexClass):

# TODO(jreback)
# we don't have a boolean Index class
# so it's object, we need to infer to
# guess this
return (arr_or_dtype.is_object and
arr_or_dtype.inferred_type == 'boolean')

return issubclass(tipo, np.bool_)


Expand Down
24 changes: 10 additions & 14 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,15 @@
import numpy.ma as ma

from pandas.core.dtypes.cast import (
maybe_upcast, infer_dtype_from_scalar,
maybe_upcast,
maybe_cast_to_datetime,
maybe_infer_to_datetimelike,
maybe_convert_platform,
maybe_downcast_to_dtype,
invalidate_string_dtypes,
coerce_to_dtypes,
maybe_upcast_putmask,
cast_scalar_to_array,
find_common_type)
from pandas.core.dtypes.common import (
is_categorical_dtype,
Expand All @@ -59,6 +60,7 @@
is_named_tuple)
from pandas.core.dtypes.missing import isnull, notnull


from pandas.core.common import (_try_sort,
_default_index,
_values_from_object,
Expand Down Expand Up @@ -355,15 +357,10 @@ def __init__(self, data=None, index=None, columns=None, dtype=None,
raise_with_traceback(exc)

if arr.ndim == 0 and index is not None and columns is not None:
if isinstance(data, compat.string_types) and dtype is None:
dtype = np.object_
if dtype is None:
dtype, data = infer_dtype_from_scalar(data)

values = np.empty((len(index), len(columns)), dtype=dtype)
values.fill(data)
mgr = self._init_ndarray(values, index, columns, dtype=dtype,
copy=False)
values = cast_scalar_to_array((len(index), len(columns)),
data, dtype=dtype)
mgr = self._init_ndarray(values, index, columns,
dtype=values.dtype, copy=False)
else:
raise ValueError('DataFrame constructor not properly called!')

Expand Down Expand Up @@ -477,7 +474,7 @@ def _get_axes(N, K, index=index, columns=columns):
values = _prep_ndarray(values, copy=copy)

if dtype is not None:
if values.dtype != dtype:
if not is_dtype_equal(values.dtype, dtype):
try:
values = values.astype(dtype)
except Exception as orig:
Expand Down Expand Up @@ -2653,9 +2650,8 @@ def reindexer(value):

else:
# upcast the scalar
dtype, value = infer_dtype_from_scalar(value)
value = np.repeat(value, len(self.index)).astype(dtype)
value = maybe_cast_to_datetime(value, dtype)
value = cast_scalar_to_array(len(self.index), value)
value = maybe_cast_to_datetime(value, value.dtype)

# return internal types directly
if is_extension_type(value):
Expand Down
45 changes: 1 addition & 44 deletions pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from pandas.core.dtypes.common import (
_ensure_int64,
_ensure_object,
needs_i8_conversion,
is_scalar,
is_number,
is_integer, is_bool,
Expand Down Expand Up @@ -5301,48 +5300,6 @@ def _where(self, cond, other=np.nan, inplace=False, axis=None, level=None,
raise NotImplemented("cannot align with a higher dimensional "
"NDFrame")

elif is_list_like(other):

if self.ndim == 1:

# try to set the same dtype as ourselves
try:
new_other = np.array(other, dtype=self.dtype)
except ValueError:
new_other = np.array(other)
except TypeError:
new_other = other

# we can end up comparing integers and m8[ns]
# which is a numpy no no
is_i8 = needs_i8_conversion(self.dtype)
if is_i8:
matches = False
else:
matches = (new_other == np.array(other))

if matches is False or not matches.all():

# coerce other to a common dtype if we can
if needs_i8_conversion(self.dtype):
try:
other = np.array(other, dtype=self.dtype)
except:
other = np.array(other)
else:
other = np.asarray(other)
other = np.asarray(other,
dtype=np.common_type(other,
new_other))

# we need to use the new dtype
try_quick = False
else:
other = new_other
else:

other = np.array(other)

if isinstance(other, np.ndarray):

if other.shape != self.shape:
Expand Down Expand Up @@ -5407,7 +5364,7 @@ def _where(self, cond, other=np.nan, inplace=False, axis=None, level=None,
# reconstruct the block manager

self._check_inplace_setting(other)
new_data = self._data.putmask(mask=cond, new=other, align=align,
new_data = self._data.putmask(mask=cond, other=other, align=align,
inplace=True, axis=block_axis,
transpose=self._AXIS_REVERSED)
self._update_inplace(new_data)
Expand Down
Loading