Skip to content

Commit

Permalink
Array API fixes for astype (#7847)
Browse files Browse the repository at this point in the history
* array API fixes for astype

* whatsnew
  • Loading branch information
TomNicholas authored May 19, 2023
1 parent 97a2032 commit 05c7888
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 33 deletions.
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ Documentation
Internal Changes
~~~~~~~~~~~~~~~~

- Minor improvements to support of the python `array api standard <https://data-apis.org/array-api/latest/>`_,
internally using the function ``xp.astype()`` instead of the method ``arr.astype()``, as the latter is not in the standard.
(:pull:`7847`) By `Tom Nicholas <https://github.com/TomNicholas>`_.

.. _whats-new.2023.05.0:

Expand Down
73 changes: 43 additions & 30 deletions xarray/core/accessor_str.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@

import numpy as np

from xarray.core import duck_array_ops
from xarray.core.computation import apply_ufunc
from xarray.core.types import T_DataArray

Expand Down Expand Up @@ -2085,13 +2086,16 @@ def _get_res_multi(val, pat):
else:
# dtype MUST be object or strings can be truncated
# See: https://github.com/numpy/numpy/issues/8352
return self._apply(
func=_get_res_multi,
func_args=(pat,),
dtype=np.object_,
output_core_dims=[[dim]],
output_sizes={dim: maxgroups},
).astype(self._obj.dtype.kind)
return duck_array_ops.astype(
self._apply(
func=_get_res_multi,
func_args=(pat,),
dtype=np.object_,
output_core_dims=[[dim]],
output_sizes={dim: maxgroups},
),
self._obj.dtype.kind,
)

def extractall(
self,
Expand Down Expand Up @@ -2258,15 +2262,18 @@ def _get_res(val, ipat, imaxcount=maxcount, dtype=self._obj.dtype):

return res

return self._apply(
# dtype MUST be object or strings can be truncated
# See: https://github.com/numpy/numpy/issues/8352
func=_get_res,
func_args=(pat,),
dtype=np.object_,
output_core_dims=[[group_dim, match_dim]],
output_sizes={group_dim: maxgroups, match_dim: maxcount},
).astype(self._obj.dtype.kind)
return duck_array_ops.astype(
self._apply(
# dtype MUST be object or strings can be truncated
# See: https://github.com/numpy/numpy/issues/8352
func=_get_res,
func_args=(pat,),
dtype=np.object_,
output_core_dims=[[group_dim, match_dim]],
output_sizes={group_dim: maxgroups, match_dim: maxcount},
),
self._obj.dtype.kind,
)

def findall(
self,
Expand Down Expand Up @@ -2385,13 +2392,16 @@ def _partitioner(

# dtype MUST be object or strings can be truncated
# See: https://github.com/numpy/numpy/issues/8352
return self._apply(
func=arrfunc,
func_args=(sep,),
dtype=np.object_,
output_core_dims=[[dim]],
output_sizes={dim: 3},
).astype(self._obj.dtype.kind)
return duck_array_ops.astype(
self._apply(
func=arrfunc,
func_args=(sep,),
dtype=np.object_,
output_core_dims=[[dim]],
output_sizes={dim: 3},
),
self._obj.dtype.kind,
)

def partition(
self,
Expand Down Expand Up @@ -2510,13 +2520,16 @@ def _dosplit(mystr, sep, maxsplit=maxsplit, dtype=self._obj.dtype):

# dtype MUST be object or strings can be truncated
# See: https://github.com/numpy/numpy/issues/8352
return self._apply(
func=_dosplit,
func_args=(sep,),
dtype=np.object_,
output_core_dims=[[dim]],
output_sizes={dim: maxsplit},
).astype(self._obj.dtype.kind)
return duck_array_ops.astype(
self._apply(
func=_dosplit,
func_args=(sep,),
dtype=np.object_,
output_core_dims=[[dim]],
output_sizes={dim: maxsplit},
),
self._obj.dtype.kind,
)

def split(
self,
Expand Down
6 changes: 3 additions & 3 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1421,7 +1421,7 @@ def _shift_one_dim(self, dim, count, fill_value=dtypes.NA):
pads = [(0, 0) if d != dim else dim_pad for d in self.dims]

data = np.pad(
trimmed_data.astype(dtype),
duck_array_ops.astype(trimmed_data, dtype),
pads,
mode="constant",
constant_values=fill_value,
Expand Down Expand Up @@ -1570,7 +1570,7 @@ def pad(
pad_option_kwargs["reflect_type"] = reflect_type

array = np.pad(
self.data.astype(dtype, copy=False),
duck_array_ops.astype(self.data, dtype, copy=False),
pad_width_by_index,
mode=mode,
**pad_option_kwargs,
Expand Down Expand Up @@ -2438,7 +2438,7 @@ def rolling_window(
"""
if fill_value is dtypes.NA: # np.nan is passed
dtype, fill_value = dtypes.maybe_promote(self.dtype)
var = self.astype(dtype, copy=False)
var = duck_array_ops.astype(self, dtype, copy=False)
else:
dtype = self.dtype
var = self
Expand Down

0 comments on commit 05c7888

Please sign in to comment.