From 8a2f29b8c0548963c7d88398c081faea5b6388db Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 31 Oct 2023 18:43:57 +0100 Subject: [PATCH] Fix sparse typing (#8387) * Fix sparse typing * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update variable.py * typos * Update variable.py * Update variable.py * Update variable.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update variable.py * Update variable.py --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- xarray/core/variable.py | 22 ++++++++++++++++++++++ xarray/namedarray/_typing.py | 9 +++++++-- xarray/namedarray/core.py | 18 ++++++++---------- 3 files changed, 37 insertions(+), 12 deletions(-) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index f18c4044f40..db109a40454 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -2595,6 +2595,28 @@ def argmax( """ return self._unravel_argminmax("argmax", dim, axis, keep_attrs, skipna) + def _as_sparse(self, sparse_format=_default, fill_value=_default) -> Variable: + """ + Use sparse-array as backend. + """ + from xarray.namedarray.utils import _default as _default_named + + if sparse_format is _default: + sparse_format = _default_named + + if fill_value is _default: + fill_value = _default_named + + out = super()._as_sparse(sparse_format, fill_value) + return cast("Variable", out) + + def _to_dense(self) -> Variable: + """ + Change backend from sparse to np.array. + """ + out = super()._to_dense() + return cast("Variable", out) + class IndexVariable(Variable): """Wrapper for accommodating a pandas.Index in an xarray.Variable. diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index 8cfc6931431..7de44240530 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -248,7 +248,7 @@ class _sparsearrayfunction( Corresponds to np.ndarray. """ - def todense(self) -> NDArray[_ScalarType_co]: + def todense(self) -> np.ndarray[Any, _DType_co]: ... @@ -262,9 +262,14 @@ class _sparsearrayapi( Corresponds to np.ndarray. """ - def todense(self) -> NDArray[_ScalarType_co]: + def todense(self) -> np.ndarray[Any, _DType_co]: ... # NamedArray can most likely use both __array_function__ and __array_namespace__: _sparsearrayfunction_or_api = (_sparsearrayfunction, _sparsearrayapi) + +sparseduckarray = Union[ + _sparsearrayfunction[_ShapeType_co, _DType_co], + _sparsearrayapi[_ShapeType_co, _DType_co], +] diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index feff052101b..2fef1cad3db 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -29,6 +29,7 @@ _DType_co, _ScalarType_co, _ShapeType_co, + _sparsearrayfunction_or_api, _SupportsImag, _SupportsReal, ) @@ -810,9 +811,9 @@ def _as_sparse( self, sparse_format: Literal["coo"] | Default = _default, fill_value: ArrayLike | Default = _default, - ) -> Self: + ) -> NamedArray[Any, _DType_co]: """ - use sparse-array as backend. + Use sparse-array as backend. """ import sparse @@ -832,18 +833,15 @@ def _as_sparse( raise ValueError(f"{sparse_format} is not a valid sparse format") from exc data = as_sparse(astype(self, dtype).data, fill_value=fill_value) - return self._replace(data=data) + return self._new(data=data) - def _to_dense(self) -> Self: + def _to_dense(self) -> NamedArray[Any, _DType_co]: """ - Change backend from sparse to np.array + Change backend from sparse to np.array. """ - from xarray.namedarray._typing import _sparsearrayfunction_or_api - if isinstance(self._data, _sparsearrayfunction_or_api): - # return self._replace(data=self._data.todense()) - data_: np.ndarray[Any, Any] = self._data.todense() - return self._replace(data=data_) + data_dense: np.ndarray[Any, _DType_co] = self._data.todense() + return self._new(data=data_dense) else: raise TypeError("self.data is not a sparse array")