Skip to content

Commit

Permalink
BUG: replace coerces incorrect dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
sinhrks committed Nov 21, 2016
1 parent f26b049 commit 081d2e9
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 25 deletions.
20 changes: 17 additions & 3 deletions pandas/core/internals.py
Original file line number Diff line number Diff line change
Expand Up @@ -1873,8 +1873,11 @@ def convert(self, *args, **kwargs):
blocks.append(newb)

else:
values = fn(
self.values.ravel(), **fn_kwargs).reshape(self.values.shape)
values = fn(self.values.ravel(), **fn_kwargs)
try:
values = values.reshape(self.values.shape)
except NotImplementedError:
pass
blocks.append(make_block(values, ndim=self.ndim,
placement=self.mgr_locs))

Expand Down Expand Up @@ -3211,6 +3214,16 @@ def comp(s):
return _possibly_compare(values, getattr(s, 'asm8', s),
operator.eq)

def _cast(block, scalar):
dtype, val = _infer_dtype_from_scalar(scalar, pandas_dtype=True)
if not is_dtype_equal(block.dtype, dtype):
dtype = _find_common_type([block.dtype, dtype])
block = block.astype(dtype)
# use original value
val = scalar

return block, val

masks = [comp(s) for i, s in enumerate(src_list)]

result_blocks = []
Expand All @@ -3231,7 +3244,8 @@ def comp(s):
# particular block
m = masks[i][b.mgr_locs.indexer]
if m.any():
new_rb.extend(b.putmask(m, d, inplace=True))
b, val = _cast(b, d)
new_rb.extend(b.putmask(m, val, inplace=True))
else:
new_rb.append(b)
rb = new_rb
Expand Down
4 changes: 4 additions & 0 deletions pandas/core/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ def mask_missing(arr, values_to_mask):
# numpy elementwise comparison warning
if is_numeric_v_string_like(arr, x):
mask = False
# elif is_object_dtype(arr):
# mask = lib.scalar_compare(arr, x, operator.eq)
else:
mask = arr == x

Expand All @@ -51,6 +53,8 @@ def mask_missing(arr, values_to_mask):
# numpy elementwise comparison warning
if is_numeric_v_string_like(arr, x):
mask |= False
# elif is_object_dtype(arr):
# mask |= lib.scalar_compare(arr, x, operator.eq)
else:
mask |= arr == x

Expand Down
50 changes: 38 additions & 12 deletions pandas/tests/indexing/test_coercion.py
Original file line number Diff line number Diff line change
Expand Up @@ -1155,12 +1155,27 @@ def setUp(self):
self.rep['float64'] = [1.1, 2.2]
self.rep['complex128'] = [1 + 1j, 2 + 2j]
self.rep['bool'] = [True, False]
self.rep['datetime64[ns]'] = [pd.Timestamp('2011-01-01'),
pd.Timestamp('2011-01-03')]

for tz in ['UTC', 'US/Eastern']:
# to test tz => different tz replacement
key = 'datetime64[ns, {0}]'.format(tz)
self.rep[key] = [pd.Timestamp('2011-01-01', tz=tz),
pd.Timestamp('2011-01-03', tz=tz)]

self.rep['timedelta64[ns]'] = [pd.Timedelta('1 day'),
pd.Timedelta('2 day')]

def _assert_replace_conversion(self, from_key, to_key, how):
index = pd.Index([3, 4], name='xxx')
obj = pd.Series(self.rep[from_key], index=index, name='yyy')
self.assertEqual(obj.dtype, from_key)

if (from_key.startswith('datetime') and to_key.startswith('datetime')):
# different tz, currently mask_missing raises SystemError
return

if how == 'dict':
replacer = dict(zip(self.rep[from_key], self.rep[to_key]))
elif how == 'series':
Expand All @@ -1177,17 +1192,10 @@ def _assert_replace_conversion(self, from_key, to_key, how):
raise nose.SkipTest("windows platform buggy: {0} -> {1}".format
(from_key, to_key))

if ((from_key == 'float64' and
to_key in ('bool', 'int64')) or

if ((from_key == 'float64' and to_key in ('bool', 'int64')) or
(from_key == 'complex128' and
to_key in ('bool', 'int64', 'float64')) or

(from_key == 'int64' and
to_key in ('bool')) or

# TODO_GH12747 The result must be int?
(from_key == 'bool' and to_key == 'int64')):
(from_key == 'int64' and to_key in ('bool'))):

# buggy on 32-bit
if tm.is_platform_32bit():
Expand Down Expand Up @@ -1250,13 +1258,31 @@ def test_replace_series_bool(self):
self._assert_replace_conversion(from_key, to_key, how='series')

def test_replace_series_datetime64(self):
pass
from_key = 'datetime64[ns]'
for to_key in self.rep:
self._assert_replace_conversion(from_key, to_key, how='dict')

from_key = 'datetime64[ns]'
for to_key in self.rep:
self._assert_replace_conversion(from_key, to_key, how='series')

def test_replace_series_datetime64tz(self):
pass
from_key = 'datetime64[ns, US/Eastern]'
for to_key in self.rep:
self._assert_replace_conversion(from_key, to_key, how='dict')

from_key = 'datetime64[ns, US/Eastern]'
for to_key in self.rep:
self._assert_replace_conversion(from_key, to_key, how='series')

def test_replace_series_timedelta64(self):
pass
from_key = 'timedelta64[ns]'
for to_key in self.rep:
self._assert_replace_conversion(from_key, to_key, how='dict')

from_key = 'timedelta64[ns]'
for to_key in self.rep:
self._assert_replace_conversion(from_key, to_key, how='series')

def test_replace_series_period(self):
pass
4 changes: 2 additions & 2 deletions pandas/tests/series/test_replace.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,8 @@ def check_replace(to_rep, val, expected):
tm.assert_series_equal(expected, r)
tm.assert_series_equal(expected, sc)

# should NOT upcast to float
e = pd.Series([0, 1, 2, 3, 4])
# MUST upcast to float
e = pd.Series([0., 1., 2., 3., 4.])
tr, v = [3], [3.0]
check_replace(tr, v, e)

Expand Down
42 changes: 34 additions & 8 deletions pandas/types/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
_ensure_int32, _ensure_int64,
_NS_DTYPE, _TD_DTYPE, _INT64_DTYPE,
_DATELIKE_DTYPES, _POSSIBLY_CAST_DTYPES)
from .dtypes import ExtensionDtype
from .dtypes import ExtensionDtype, DatetimeTZDtype, PeriodDtype
from .generic import ABCDatetimeIndex, ABCPeriodIndex, ABCSeries
from .missing import isnull, notnull
from .inference import is_list_like
Expand Down Expand Up @@ -309,8 +309,17 @@ def _maybe_promote(dtype, fill_value=np.nan):
return dtype, fill_value


def _infer_dtype_from_scalar(val):
""" interpret the dtype from a scalar """
def _infer_dtype_from_scalar(val, pandas_dtype=False):
"""
interpret the dtype from a scalar
Parameters
----------
pandas_dtype : bool, default False
whether to infer dtype including pandas extension types.
If False, scalar belongs to pandas extension types is inferred as
object
"""

dtype = np.object_

Expand All @@ -333,13 +342,23 @@ def _infer_dtype_from_scalar(val):

dtype = np.object_

elif isinstance(val, (np.datetime64,
datetime)) and getattr(val, 'tzinfo', None) is None:
val = lib.Timestamp(val).value
dtype = np.dtype('M8[ns]')
elif isinstance(val, (np.datetime64, datetime)):
val = tslib.Timestamp(val)
if val is tslib.NaT or val.tz is None:
dtype = np.dtype('M8[ns]')
else:
if pandas_dtype:
dtype = DatetimeTZDtype(unit='ns', tz=val.tz)
# ToDo: This localization is not needed if
# DatetimeTZBlock doesn't localize internal values
val = val.tz_localize(None)
else:
# return datetimetz as object
return np.object_, val
val = val.value

elif isinstance(val, (np.timedelta64, timedelta)):
val = lib.Timedelta(val).value
val = tslib.Timedelta(val).value
dtype = np.dtype('m8[ns]')

elif is_bool(val):
Expand All @@ -360,6 +379,13 @@ def _infer_dtype_from_scalar(val):
elif is_complex(val):
dtype = np.complex_

elif pandas_dtype:
# to do use util
from pandas.tseries.period import Period
if isinstance(val, Period):
dtype = PeriodDtype(freq=val.freq)
val = val.ordinal

return dtype, val


Expand Down

0 comments on commit 081d2e9

Please sign in to comment.