Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: replace coerces incorrect dtype #12780

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/whatsnew/v0.20.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,7 @@ Bug Fixes

- Bug in ``pd.read_msgpack()`` in which ``Series`` categoricals were being improperly processed (:issue:`14901`)
- Bug in ``Series.ffill()`` with mixed dtypes containing tz-aware datetimes. (:issue:`14956`)
- Bug in ``.replace()`` which could result in incorrect dtypes. (:issue:`12747`)



Expand Down
20 changes: 17 additions & 3 deletions pandas/core/internals.py
Original file line number Diff line number Diff line change
Expand Up @@ -1890,8 +1890,11 @@ def convert(self, *args, **kwargs):
blocks.append(newb)

else:
values = fn(
self.values.ravel(), **fn_kwargs).reshape(self.values.shape)
values = fn(self.values.ravel(), **fn_kwargs)
try:
values = values.reshape(self.values.shape)
except NotImplementedError:
pass
blocks.append(make_block(values, ndim=self.ndim,
placement=self.mgr_locs))

Expand Down Expand Up @@ -3233,6 +3236,16 @@ def comp(s):
return _possibly_compare(values, getattr(s, 'asm8', s),
operator.eq)

def _cast_scalar(block, scalar):
    """
    Reconcile ``block`` with the dtype inferred from ``scalar``.

    Returns a ``(block, value)`` pair: the block (recast to a common dtype
    when its dtype differs from the scalar's inferred dtype) and the value
    that should be written into it.
    """
    inferred, val = _infer_dtype_from_scalar(scalar, pandas_dtype=True)
    if is_dtype_equal(block.dtype, inferred):
        # dtypes already agree: keep the block and the inferred value
        return block, val

    # dtypes differ: recast the block to the common dtype and hand back
    # the original scalar rather than the inferred value
    common = _find_common_type([block.dtype, inferred])
    return block.astype(common), scalar

masks = [comp(s) for i, s in enumerate(src_list)]

result_blocks = []
Expand All @@ -3255,7 +3268,8 @@ def comp(s):
# particular block
m = masks[i][b.mgr_locs.indexer]
if m.any():
new_rb.extend(b.putmask(m, d, inplace=True))
b, val = _cast_scalar(b, d)
new_rb.extend(b.putmask(m, val, inplace=True))
else:
new_rb.append(b)
rb = new_rb
Expand Down
50 changes: 38 additions & 12 deletions pandas/tests/indexing/test_coercion.py
Original file line number Diff line number Diff line change
Expand Up @@ -1155,12 +1155,27 @@ def setUp(self):
self.rep['float64'] = [1.1, 2.2]
self.rep['complex128'] = [1 + 1j, 2 + 2j]
self.rep['bool'] = [True, False]
self.rep['datetime64[ns]'] = [pd.Timestamp('2011-01-01'),
pd.Timestamp('2011-01-03')]

for tz in ['UTC', 'US/Eastern']:
# to test tz => different tz replacement
key = 'datetime64[ns, {0}]'.format(tz)
self.rep[key] = [pd.Timestamp('2011-01-01', tz=tz),
pd.Timestamp('2011-01-03', tz=tz)]

self.rep['timedelta64[ns]'] = [pd.Timedelta('1 day'),
pd.Timedelta('2 day')]

def _assert_replace_conversion(self, from_key, to_key, how):
index = pd.Index([3, 4], name='xxx')
obj = pd.Series(self.rep[from_key], index=index, name='yyy')
self.assertEqual(obj.dtype, from_key)

if (from_key.startswith('datetime') and to_key.startswith('datetime')):
# different tz, currently mask_missing raises SystemError
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this for another issue?

Copy link
Member Author

@sinhrks sinhrks Jan 21, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I opened #15183 for it. Let me fix it in a separate PR.

return

if how == 'dict':
replacer = dict(zip(self.rep[from_key], self.rep[to_key]))
elif how == 'series':
Expand All @@ -1177,17 +1192,10 @@ def _assert_replace_conversion(self, from_key, to_key, how):
raise nose.SkipTest("windows platform buggy: {0} -> {1}".format
(from_key, to_key))

if ((from_key == 'float64' and
to_key in ('bool', 'int64')) or

if ((from_key == 'float64' and to_key in ('bool', 'int64')) or
(from_key == 'complex128' and
to_key in ('bool', 'int64', 'float64')) or

(from_key == 'int64' and
to_key in ('bool')) or

# TODO_GH12747 The result must be int?
(from_key == 'bool' and to_key == 'int64')):
(from_key == 'int64' and to_key in ('bool'))):

# buggy on 32-bit
if tm.is_platform_32bit():
Expand Down Expand Up @@ -1250,13 +1258,31 @@ def test_replace_series_bool(self):
self._assert_replace_conversion(from_key, to_key, how='series')

def test_replace_series_datetime64(self):
    # Check replacement from tz-naive datetimes into every dtype in
    # self.rep: all targets via the dict replacer first, then all
    # targets via the Series replacer.
    from_key = 'datetime64[ns]'
    for how in ('dict', 'series'):
        for to_key in self.rep:
            self._assert_replace_conversion(from_key, to_key, how=how)

def test_replace_series_datetime64tz(self):
    # Check replacement from tz-aware (US/Eastern) datetimes into every
    # dtype in self.rep: all targets via the dict replacer first, then
    # all targets via the Series replacer.
    from_key = 'datetime64[ns, US/Eastern]'
    for how in ('dict', 'series'):
        for to_key in self.rep:
            self._assert_replace_conversion(from_key, to_key, how=how)

def test_replace_series_timedelta64(self):
    # Check replacement from timedeltas into every dtype in self.rep:
    # all targets via the dict replacer first, then all targets via the
    # Series replacer.
    from_key = 'timedelta64[ns]'
    for how in ('dict', 'series'):
        for to_key in self.rep:
            self._assert_replace_conversion(from_key, to_key, how=how)

def test_replace_series_period(self):
pass
4 changes: 2 additions & 2 deletions pandas/tests/series/test_replace.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,8 @@ def check_replace(to_rep, val, expected):
tm.assert_series_equal(expected, r)
tm.assert_series_equal(expected, sc)

# should NOT upcast to float
e = pd.Series([0, 1, 2, 3, 4])
# MUST upcast to float
e = pd.Series([0., 1., 2., 3., 4.])
tr, v = [3], [3.0]
check_replace(tr, v, e)

Expand Down
37 changes: 29 additions & 8 deletions pandas/types/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
_ensure_int32, _ensure_int64,
_NS_DTYPE, _TD_DTYPE, _INT64_DTYPE,
_POSSIBLY_CAST_DTYPES)
from .dtypes import ExtensionDtype
from .dtypes import ExtensionDtype, DatetimeTZDtype, PeriodDtype
from .generic import ABCDatetimeIndex, ABCPeriodIndex, ABCSeries
from .missing import isnull, notnull
from .inference import is_list_like
Expand Down Expand Up @@ -310,8 +310,17 @@ def _maybe_promote(dtype, fill_value=np.nan):
return dtype, fill_value


def _infer_dtype_from_scalar(val):
""" interpret the dtype from a scalar """
def _infer_dtype_from_scalar(val, pandas_dtype=False):
"""
interpret the dtype from a scalar

Parameters
----------
pandas_dtype : bool, default False
whether to infer dtype including pandas extension types.
If False, a scalar that belongs to a pandas extension type is
inferred as object
"""

dtype = np.object_

Expand All @@ -334,13 +343,20 @@ def _infer_dtype_from_scalar(val):

dtype = np.object_

elif isinstance(val, (np.datetime64,
datetime)) and getattr(val, 'tzinfo', None) is None:
val = lib.Timestamp(val).value
dtype = np.dtype('M8[ns]')
elif isinstance(val, (np.datetime64, datetime)):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess category should raise NotImplementedError?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the value must be a scalar, it can't have category dtype.

val = tslib.Timestamp(val)
if val is tslib.NaT or val.tz is None:
dtype = np.dtype('M8[ns]')
else:
if pandas_dtype:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

e.g. here, is this a back-compat issue?

dtype = DatetimeTZDtype(unit='ns', tz=val.tz)
else:
# return datetimetz as object
return np.object_, val
val = val.value

elif isinstance(val, (np.timedelta64, timedelta)):
val = lib.Timedelta(val).value
val = tslib.Timedelta(val).value
dtype = np.dtype('m8[ns]')

elif is_bool(val):
Expand All @@ -361,6 +377,11 @@ def _infer_dtype_from_scalar(val):
elif is_complex(val):
dtype = np.complex_

elif pandas_dtype:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this parameter necessary? Why wouldn't we always want to infer a pandas dtype if it's there (e.g. Period)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not needed for the PR itself. The same func is required for #14145.

if lib.is_period(val):
dtype = PeriodDtype(freq=val.freq)
val = val.ordinal

return dtype, val


Expand Down