Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG/API: make setitem-inplace preserve dtype when possible with PandasArray, IntegerArray, FloatingArray #39044

Closed
wants to merge 23 commits into from
Closed
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
6fffb02
BUG/API: make setitem-inplace preserve dtype when possible with Panda…
jbrockmendel Jan 8, 2021
160f3f7
Merge branch 'master' of https://github.com/pandas-dev/pandas into bu…
jbrockmendel Jan 12, 2021
de10708
do patching inside test_numpy
jbrockmendel Jan 13, 2021
ba98a99
Merge branch 'master' of https://github.com/pandas-dev/pandas into bu…
jbrockmendel Jan 15, 2021
84261a7
Merge branch 'master' of https://github.com/pandas-dev/pandas into bu…
jbrockmendel Jan 15, 2021
284f36a
move kludge to test_numpy
jbrockmendel Jan 15, 2021
2639b5c
Merge branch 'master' of https://github.com/pandas-dev/pandas into bu…
jbrockmendel Jan 16, 2021
b2aa366
staticmethod -> function
jbrockmendel Jan 16, 2021
7aeb2b5
typo fixup+test
jbrockmendel Jan 20, 2021
715a602
Merge branch 'master' into bug-38896
jbrockmendel Jan 25, 2021
b55155b
Merge branch 'master' of https://github.com/pandas-dev/pandas into bu…
jbrockmendel Jan 25, 2021
c72e566
Merge branch 'master' of https://github.com/pandas-dev/pandas into bu…
jbrockmendel Jan 25, 2021
0b9f343
Merge branch 'master' of https://github.com/pandas-dev/pandas into bu…
jbrockmendel Jan 27, 2021
cd6adbe
Merge branch 'master' of https://github.com/pandas-dev/pandas into bu…
jbrockmendel Jan 28, 2021
071ab1b
Merge branch 'master' of https://github.com/pandas-dev/pandas into bu…
jbrockmendel Jan 28, 2021
6ab44af
docstring
jbrockmendel Jan 28, 2021
daacff8
Merge branch 'master' of https://github.com/pandas-dev/pandas into bu…
jbrockmendel Jan 29, 2021
450bf73
comment, revert floating
jbrockmendel Jan 29, 2021
dba7c11
Merge branch 'master' into bug-38896
jbrockmendel Feb 2, 2021
9258cbb
Merge branch 'master' into bug-38896
jbrockmendel Feb 3, 2021
88309ab
Merge branch 'master' of https://github.com/pandas-dev/pandas into bu…
jbrockmendel Feb 3, 2021
1847209
Merge branch 'master' into bug-38896
jbrockmendel Feb 7, 2021
6b8cc31
Merge branch 'master' into bug-38896
jbrockmendel Feb 15, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions pandas/core/dtypes/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,14 +506,28 @@ def array_equals(left: ArrayLike, right: ArrayLike) -> bool:
return array_equivalent(left, right, dtype_equal=True)


def infer_fill_value(val):
def infer_fill_value(val, length: int):
"""
infer the fill value for the nan/NaT from the provided
scalar/ndarray/list-like if we are a NaT, return the correct dtyped
element to provide proper block construction
`val` is going to be inserted as (part of) a new column in a DataFrame
with the given length. If val cannot be made to fit exactly,
find an appropriately-dtyped NA value to construct a complete column from,
which we will later set `val` into.
"""
if not is_list_like(val):
val = [val]

if is_extension_array_dtype(val):
# We cannot use dtype._na_value bc pd.NA/pd.NaT do not preserve dtype
if len(val) == length:
jreback marked this conversation as resolved.
Show resolved Hide resolved
# TODO: in this case see if we can avoid making a copy later on
return val
if length == 0:
return val[:0].copy()

dtype = val.dtype
cls = dtype.construct_array_type()
return cls._from_sequence([dtype.na_value], dtype=dtype).repeat(length)

val = np.array(val, copy=False)
if needs_i8_conversion(val.dtype):
jreback marked this conversation as resolved.
Show resolved Hide resolved
return np.array("NaT", dtype=val.dtype)
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1605,7 +1605,7 @@ def _setitem_with_indexer(self, indexer, value, name="iloc"):
# We are setting an entire column
self.obj[key] = value
else:
self.obj[key] = infer_fill_value(value)
self.obj[key] = infer_fill_value(value, len(self.obj))

new_indexer = convert_from_missing_indexer_tuple(
indexer, self.obj.axes
Expand Down
65 changes: 45 additions & 20 deletions pandas/core/internals/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@
Categorical,
DatetimeArray,
ExtensionArray,
FloatingArray,
IntegerArray,
PandasArray,
TimedeltaArray,
)
Expand Down Expand Up @@ -620,10 +622,17 @@ def astype(self, dtype, copy: bool = False, errors: str = "raise"):
)
raise TypeError(msg)

values = self.values
dtype = pandas_dtype(dtype)
if isinstance(dtype, ExtensionDtype) and self.values.ndim == 2:
# TODO(EA2D): kludge not needed with 2D EAs (astype_nansafe would raise)
# note DataFrame.astype has special handling to avoid getting here
if self.shape[0] != 1:
raise NotImplementedError("Need 2D EAs!")
values = values[0]

try:
new_values = self._astype(dtype, copy=copy)
new_values = astype_block_compat(values, dtype, copy=copy)
except (ValueError, TypeError):
# e.g. astype_nansafe can fail on object-dtype of strings
# trying to convert to float
Expand All @@ -642,25 +651,6 @@ def astype(self, dtype, copy: bool = False, errors: str = "raise"):
)
return newb

def _astype(self, dtype: DtypeObj, copy: bool) -> ArrayLike:
values = self.values

if is_datetime64tz_dtype(dtype) and is_datetime64_dtype(values.dtype):
return astype_dt64_to_dt64tz(values, dtype, copy, via_utc=True)

if is_dtype_equal(values.dtype, dtype):
if copy:
return values.copy()
return values

if isinstance(values, ExtensionArray):
values = values.astype(dtype, copy=copy)

else:
values = astype_nansafe(values, dtype, copy=copy)

return values

def convert(
self,
copy: bool = True,
Expand Down Expand Up @@ -905,6 +895,15 @@ def setitem(self, indexer, value):
# current dtype cannot store value, coerce to common dtype
return self.coerce_to_target_dtype(value).setitem(indexer, value)

value = extract_array(value, extract_numpy=True)

if isinstance(value, (IntegerArray, FloatingArray)) and not value._mask.any():
# GH#38896
value = value.to_numpy(value.dtype.numpy_dtype)
if self.ndim == 2 and value.ndim == 1:
# TODO(EA2D): special case not needed with 2D EAs
value = np.atleast_2d(value).T

if self.dtype.kind in ["m", "M"]:
arr = self.array_values().T
arr[indexer] = value
Expand Down Expand Up @@ -1901,6 +1900,11 @@ class NumericBlock(Block):
is_numeric = True

def _can_hold_element(self, element: Any) -> bool:
if isinstance(element, (IntegerArray, FloatingArray)):
# GH#38896
if element._mask.any():
return False

return can_hold_element(self.dtype, element)

@property
Expand Down Expand Up @@ -2533,3 +2537,24 @@ def _extract_bool_array(mask: ArrayLike) -> np.ndarray:
assert isinstance(mask, np.ndarray), type(mask)
assert mask.dtype == bool, mask.dtype
return mask


def astype_block_compat(values: ArrayLike, dtype: DtypeObj, copy: bool) -> ArrayLike:
"""
Series/DataFrame implementation of .astype
"""
if is_datetime64tz_dtype(dtype) and is_datetime64_dtype(values.dtype):
return astype_dt64_to_dt64tz(values, dtype, copy, via_utc=True)

if is_dtype_equal(values.dtype, dtype):
if copy:
return values.copy()
return values

if isinstance(values, ExtensionArray):
values = values.astype(dtype, copy=copy)

else:
values = astype_nansafe(values, dtype, copy=copy)

return values
30 changes: 24 additions & 6 deletions pandas/tests/extension/test_integer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,11 @@
from pandas.tests.extension import base


def make_data():
return list(range(1, 9)) + [pd.NA] + list(range(10, 98)) + [pd.NA] + [99, 100]
def make_data(with_nas: bool = True):
if with_nas:
return list(range(1, 9)) + [pd.NA] + list(range(10, 98)) + [pd.NA] + [99, 100]

return list(range(1, 101))


@pytest.fixture(
Expand All @@ -52,9 +55,10 @@ def dtype(request):
return request.param()


@pytest.fixture
def data(dtype):
return pd.array(make_data(), dtype=dtype)
@pytest.fixture(params=[True, False])
def data(dtype, request):
with_nas = request.param
return pd.array(make_data(with_nas), dtype=dtype)


@pytest.fixture
Expand Down Expand Up @@ -193,7 +197,21 @@ class TestGetitem(base.BaseGetitemTests):


class TestSetitem(base.BaseSetitemTests):
pass
def test_setitem_series(self, data, full_indexer):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you indicate here why it is overriding the base class?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

comment added

# https://github.com/pandas-dev/pandas/issues/32395
# overriden because we have a different `expected` in some cases
ser = expected = pd.Series(data, name="data")
result = pd.Series(index=ser.index, dtype=object, name="data")

key = full_indexer(ser)
result.loc[key] = ser

if not data._mask.any():
# GH#38896 like we do with ndarray, we set the values inplace
# but cast to the new numpy dtype
expected = pd.Series(data.to_numpy(data.dtype.numpy_dtype), name="data")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we converting to the numpy dtype here? That's also not the original dtype?


self.assert_series_equal(result, expected)


class TestMissing(base.BaseMissingTests):
Expand Down
67 changes: 66 additions & 1 deletion pandas/tests/extension/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@
import pytest

from pandas.core.dtypes.dtypes import ExtensionDtype, PandasDtype
from pandas.core.dtypes.missing import infer_fill_value as infer_fill_value_orig

import pandas as pd
import pandas._testing as tm
from pandas.core.arrays.numpy_ import PandasArray
from pandas.core.arrays import PandasArray, StringArray
from pandas.core.construction import extract_array
from pandas.tests.extension import base


Expand All @@ -29,6 +31,31 @@ def dtype(request):
return PandasDtype(np.dtype(request.param))


orig_setitem = pd.core.internals.Block.setitem
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you use monkeypatch instead?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this does use monkeypatch. the monkeypatched method calls the original method



def setitem(self, indexer, value):
# patch Block.setitem
value = extract_array(value, extract_numpy=True)
if isinstance(value, PandasArray) and not isinstance(value, StringArray):
value = value.to_numpy()
if self.ndim == 2 and value.ndim == 1:
# TODO(EA2D): special case not needed with 2D EAs
value = np.atleast_2d(value)

return orig_setitem(self, indexer, value)


def infer_fill_value(val, length: int):
# GH#39044 we have to patch core.dtypes.missing.infer_fill_value
# to unwrap PandasArray bc it won't recognize PandasArray with
# is_extension_dtype
if isinstance(val, PandasArray):
val = val.to_numpy()

return infer_fill_value_orig(val, length)


@pytest.fixture
def allow_in_pandas(monkeypatch):
"""
Expand All @@ -48,6 +75,8 @@ def allow_in_pandas(monkeypatch):
"""
with monkeypatch.context() as m:
m.setattr(PandasArray, "_typ", "extension")
m.setattr(pd.core.indexing, "infer_fill_value", infer_fill_value)
m.setattr(pd.core.internals.Block, "setitem", setitem)
yield


Expand Down Expand Up @@ -457,6 +486,42 @@ def test_setitem_slice(self, data, box_in_series):
def test_setitem_loc_iloc_slice(self, data):
super().test_setitem_loc_iloc_slice(data)

def test_setitem_with_expansion_dataframe_column(self, data, full_indexer, request):
# https://github.com/pandas-dev/pandas/issues/32395
df = pd.DataFrame({"data": pd.Series(data)})
result = pd.DataFrame(index=df.index)

key = full_indexer(df)
result.loc[key, "data"] = df["data"]._values

expected = pd.DataFrame({"data": data})
if data.dtype.numpy_dtype != object:
# For PandasArray we expect to get unboxed to numpy
expected = pd.DataFrame({"data": data.to_numpy()})

if isinstance(key, slice) and (
key == slice(None) and data.dtype.numpy_dtype != object
):
mark = pytest.mark.xfail(
reason="This case goes through a different code path"
)
# Other cases go through Block.setitem
request.node.add_marker(mark)

self.assert_frame_equal(result, expected)

def test_setitem_series(self, data, full_indexer):
# https://github.com/pandas-dev/pandas/issues/32395
ser = pd.Series(data, name="data")
result = pd.Series(index=ser.index, dtype=object, name="data")

key = full_indexer(ser)
result.loc[key] = ser

# For PandasArray we expect to get unboxed to numpy
expected = pd.Series(data.to_numpy(), name="data")
self.assert_series_equal(result, expected)


@skip_nested
class TestParsing(BaseNumPyTests, base.BaseParsingTests):
Expand Down
14 changes: 14 additions & 0 deletions pandas/tests/indexing/test_loc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Index,
IndexSlice,
MultiIndex,
NaT,
Series,
SparseDtype,
Timedelta,
Expand Down Expand Up @@ -1348,6 +1349,19 @@ def test_loc_setitem_categorical_column_retains_dtype(self, ordered):
expected = DataFrame({"A": [1], "B": Categorical(["b"], ordered=ordered)})
tm.assert_frame_equal(result, expected)

def test_loc_setitem_ea_not_full_column(self):
# GH#39163
df = DataFrame({"A": range(5)})

val = date_range("2016-01-01", periods=3, tz="US/Pacific")

df.loc[[0, 1, 2], "B"] = val

bex = val.append(DatetimeIndex([NaT, NaT], dtype=val.dtype))
expected = DataFrame({"A": range(5), "B": bex})
assert expected.dtypes["B"] == val.dtype
tm.assert_frame_equal(df, expected)


class TestLocCallable:
def test_frame_loc_getitem_callable(self):
Expand Down
2 changes: 1 addition & 1 deletion pandas/util/_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def find_stack_level() -> int:
if stack[n].function == "astype":
break

while stack[n].function in ["astype", "apply", "_astype"]:
while stack[n].function in ["astype", "apply", "astype_block_compat"]:
# e.g.
# bump up Block.astype -> BlockManager.astype -> NDFrame.astype
# bump up Datetime.Array.astype -> DatetimeIndex.astype
Expand Down