Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

REF: move Block.astype implementation to dtypes/cast.py #40141

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 117 additions & 0 deletions pandas/core/array_algos/cast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import inspect

import numpy as np

from pandas._typing import (
ArrayLike,
DtypeObj,
)

from pandas.core.dtypes.cast import (
astype_dt64_to_dt64tz,
astype_nansafe,
)
from pandas.core.dtypes.common import (
is_datetime64_dtype,
is_datetime64tz_dtype,
is_dtype_equal,
pandas_dtype,
)
from pandas.core.dtypes.dtypes import ExtensionDtype

from pandas.core.arrays import ExtensionArray


def astype_array(values: ArrayLike, dtype: DtypeObj, copy: bool = False):
jorisvandenbossche marked this conversation as resolved.
Show resolved Hide resolved
"""
Cast array to the new dtype.

Parameters
----------
values : ndarray or ExtensionArray
dtype : dtype object
copy : bool, default False
copy if indicated

Returns
-------
ndarray or ExtensionArray
"""
if (
values.dtype.kind in ["m", "M"]
and dtype.kind in ["i", "u"]
and isinstance(dtype, np.dtype)
and dtype.itemsize != 8
):
# TODO(2.0) remove special case once deprecation on DTA/TDA is enforced
msg = rf"cannot astype a datetimelike from [{values.dtype}] to [{dtype}]"
raise TypeError(msg)

if is_datetime64tz_dtype(dtype) and is_datetime64_dtype(values.dtype):
return astype_dt64_to_dt64tz(values, dtype, copy, via_utc=True)

if is_dtype_equal(values.dtype, dtype):
if copy:
return values.copy()
return values

if isinstance(values, ExtensionArray):
values = values.astype(dtype, copy=copy)

else:
values = astype_nansafe(values, dtype, copy=copy)

# now in ObjectBlock._maybe_coerce_values(cls, values):
if isinstance(dtype, np.dtype) and issubclass(values.dtype.type, str):
values = np.array(values, dtype=object)

return values


def astype_array_safe(values, dtype, copy: bool = False, errors: str = "raise"):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

values: ArrayLike; dtype: DtypeObj?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mypy isn't smart enough for that (we also didn't annotate dtype in the code from where I copied this). But will add values: ArrayLike

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make sense. if we move the pandas_dtype call up into the caller we can do it once instead of (block|array)-wise

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if we move the pandas_dtype call up into the caller we can do it once instead of (block|array)-wise

Then also the dtype validation needs to be moved up (requiring some more changes to avoid duplication of that part). At the moment, it's a rather straightfoward cut and paste from blocks.py to cast.py, so I would maybe prefer to keep it that way for this PR.

"""
Cast array to the new dtype.

Parameters
----------
values : ndarray or ExtensionArray
dtype : str, dtype convertible
copy : bool, default False
copy if indicated
errors : str, {'raise', 'ignore'}, default 'raise'
- ``raise`` : allow exceptions to be raised
- ``ignore`` : suppress exceptions. On error return original object

Returns
-------
ndarray or ExtensionArray
"""
errors_legal_values = ("raise", "ignore")

if errors not in errors_legal_values:
invalid_arg = (
"Expected value of kwarg 'errors' to be one of "
f"{list(errors_legal_values)}. Supplied value is '{errors}'"
)
raise ValueError(invalid_arg)

if inspect.isclass(dtype) and issubclass(dtype, ExtensionDtype):
msg = (
f"Expected an instance of {dtype.__name__}, "
"but got the class instead. Try instantiating 'dtype'."
)
raise TypeError(msg)

dtype = pandas_dtype(dtype)

try:
new_values = astype_array(values, dtype, copy=copy)
except (ValueError, TypeError):
# e.g. astype_nansafe can fail on object-dtype of strings
# trying to convert to float
if errors == "ignore":
new_values = values
else:
raise

return new_values
3 changes: 2 additions & 1 deletion pandas/core/internals/array_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
)

import pandas.core.algorithms as algos
from pandas.core.array_algos.cast import astype_array_safe
from pandas.core.arrays import ExtensionArray
from pandas.core.arrays.sparse import SparseDtype
from pandas.core.construction import (
Expand Down Expand Up @@ -499,7 +500,7 @@ def downcast(self) -> ArrayManager:
return self.apply_with_block("downcast")

def astype(self, dtype, copy: bool = False, errors: str = "raise") -> ArrayManager:
return self.apply("astype", dtype=dtype, copy=copy) # , errors=errors)
return self.apply(astype_array_safe, dtype=dtype, copy=copy, errors=errors)

def convert(
self,
Expand Down
66 changes: 5 additions & 61 deletions pandas/core/internals/blocks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import inspect
import re
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -36,8 +35,6 @@
from pandas.util._validators import validate_bool_kwarg

from pandas.core.dtypes.cast import (
astype_dt64_to_dt64tz,
astype_nansafe,
can_hold_element,
find_common_type,
infer_dtype_from,
Expand All @@ -49,7 +46,6 @@
)
from pandas.core.dtypes.common import (
is_categorical_dtype,
is_datetime64_dtype,
is_datetime64tz_dtype,
is_dtype_equal,
is_extension_array_dtype,
Expand All @@ -76,6 +72,7 @@
)

import pandas.core.algorithms as algos
from pandas.core.array_algos.cast import astype_array_safe
from pandas.core.array_algos.putmask import (
extract_bool_array,
putmask_inplace,
Expand Down Expand Up @@ -652,33 +649,11 @@ def astype(self, dtype, copy: bool = False, errors: str = "raise"):
-------
Block
"""
errors_legal_values = ("raise", "ignore")

if errors not in errors_legal_values:
invalid_arg = (
"Expected value of kwarg 'errors' to be one of "
f"{list(errors_legal_values)}. Supplied value is '{errors}'"
)
raise ValueError(invalid_arg)

if inspect.isclass(dtype) and issubclass(dtype, ExtensionDtype):
msg = (
f"Expected an instance of {dtype.__name__}, "
"but got the class instead. Try instantiating 'dtype'."
)
raise TypeError(msg)

dtype = pandas_dtype(dtype)
values = self.values
if values.dtype.kind in ["m", "M"]:
values = self.array_values()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could move this into astype_array_safe and use ensure_wrapped_if_datetimelike; would make it robust to AM/BM (though i think both AM and BM now have PRs to make the arrays EAs to begin with)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since ArrayManager already stores it as EAs (after this array), I would prefer to leave it here (then your PR changing to store EAs in BlockManager as well can remove those two lines)


try:
new_values = self._astype(dtype, copy=copy)
except (ValueError, TypeError):
# e.g. astype_nansafe can fail on object-dtype of strings
# trying to convert to float
if errors == "ignore":
new_values = self.values
else:
raise
new_values = astype_array_safe(values, dtype, copy=copy, errors=errors)

newb = self.make_block(new_values)
if newb.shape != self.shape:
Expand All @@ -689,37 +664,6 @@ def astype(self, dtype, copy: bool = False, errors: str = "raise"):
)
return newb

def _astype(self, dtype: DtypeObj, copy: bool) -> ArrayLike:
values = self.values
if values.dtype.kind in ["m", "M"]:
values = self.array_values()

if (
values.dtype.kind in ["m", "M"]
and dtype.kind in ["i", "u"]
and isinstance(dtype, np.dtype)
and dtype.itemsize != 8
):
# TODO(2.0) remove special case once deprecation on DTA/TDA is enforced
msg = rf"cannot astype a datetimelike from [{values.dtype}] to [{dtype}]"
raise TypeError(msg)

if is_datetime64tz_dtype(dtype) and is_datetime64_dtype(values.dtype):
return astype_dt64_to_dt64tz(values, dtype, copy, via_utc=True)

if is_dtype_equal(values.dtype, dtype):
if copy:
return values.copy()
return values

if isinstance(values, ExtensionArray):
values = values.astype(dtype, copy=copy)

else:
values = astype_nansafe(values, dtype, copy=copy)

return values

def convert(
self,
copy: bool = True,
Expand Down
8 changes: 0 additions & 8 deletions pandas/tests/frame/methods/test_astype.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
import numpy as np
import pytest

import pandas.util._test_decorators as td

import pandas as pd
from pandas import (
Categorical,
Expand Down Expand Up @@ -92,7 +90,6 @@ def test_astype_mixed_type(self, mixed_type_frame):
casted = mn.astype("O")
_check_cast(casted, "object")

@td.skip_array_manager_not_yet_implemented
def test_astype_with_exclude_string(self, float_frame):
df = float_frame.copy()
expected = float_frame.astype(int)
Expand Down Expand Up @@ -127,7 +124,6 @@ def test_astype_with_view_mixed_float(self, mixed_float_frame):
casted = tf.astype(np.int64)
casted = tf.astype(np.float32) # noqa

@td.skip_array_manager_not_yet_implemented
@pytest.mark.parametrize("dtype", [np.int32, np.int64])
@pytest.mark.parametrize("val", [np.nan, np.inf])
def test_astype_cast_nan_inf_int(self, val, dtype):
Expand Down Expand Up @@ -386,7 +382,6 @@ def test_astype_to_datetimelike_unit(self, arr_dtype, dtype, unit):

tm.assert_frame_equal(result, expected)

@td.skip_array_manager_not_yet_implemented
@pytest.mark.parametrize("unit", ["ns", "us", "ms", "s", "h", "m", "D"])
def test_astype_to_datetime_unit(self, unit):
# tests all units from datetime origination
Expand All @@ -411,7 +406,6 @@ def test_astype_to_timedelta_unit_ns(self, unit):

tm.assert_frame_equal(result, expected)

@td.skip_array_manager_not_yet_implemented
@pytest.mark.parametrize("unit", ["us", "ms", "s", "h", "m", "D"])
def test_astype_to_timedelta_unit(self, unit):
# coerce to float
Expand Down Expand Up @@ -441,7 +435,6 @@ def test_astype_to_incorrect_datetimelike(self, unit):
with pytest.raises(TypeError, match=msg):
df.astype(dtype)

@td.skip_array_manager_not_yet_implemented
def test_astype_arg_for_errors(self):
# GH#14878

Expand Down Expand Up @@ -570,7 +563,6 @@ def test_astype_empty_dtype_dict(self):
tm.assert_frame_equal(result, df)
assert result is not df

@td.skip_array_manager_not_yet_implemented # TODO(ArrayManager) ignore keyword
@pytest.mark.parametrize(
"df",
[
Expand Down
2 changes: 1 addition & 1 deletion pandas/util/_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def find_stack_level() -> int:
if stack[n].function == "astype":
break

while stack[n].function in ["astype", "apply", "_astype"]:
while stack[n].function in ["astype", "apply", "astype_array_safe", "astype_array"]:
# e.g.
# bump up Block.astype -> BlockManager.astype -> NDFrame.astype
# bump up Datetime.Array.astype -> DatetimeIndex.astype
Expand Down