diff --git a/xarray/core/common.py b/xarray/core/common.py index a69ba03a7a4..c5836c68759 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -3,6 +3,7 @@ from html import escape from textwrap import dedent from typing import ( + TYPE_CHECKING, Any, Callable, Dict, @@ -32,6 +33,12 @@ ALL_DIMS = ... +if TYPE_CHECKING: + from .dataarray import DataArray + from .weighted import Weighted + +T_DataWithCoords = TypeVar("T_DataWithCoords", bound="DataWithCoords") + C = TypeVar("C") T = TypeVar("T") @@ -772,7 +779,9 @@ def groupby_bins( }, ) - def weighted(self, weights): + def weighted( + self: T_DataWithCoords, weights: "DataArray" + ) -> "Weighted[T_DataWithCoords]": """ Weighted operations. diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index dbd4e1ad103..449a7200ee7 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -1,13 +1,16 @@ -from typing import TYPE_CHECKING, Hashable, Iterable, Optional, Union, overload +from typing import TYPE_CHECKING, Generic, Hashable, Iterable, Optional, TypeVar, Union from . import duck_array_ops from .computation import dot -from .options import _get_keep_attrs from .pycompat import is_duck_dask_array if TYPE_CHECKING: + from .common import DataWithCoords # noqa: F401 from .dataarray import DataArray, Dataset +T_DataWithCoords = TypeVar("T_DataWithCoords", bound="DataWithCoords") + + _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE = """ Reduce this {cls}'s data by a weighted ``{fcn}`` along some dimension(s). @@ -56,7 +59,7 @@ """ -class Weighted: +class Weighted(Generic[T_DataWithCoords]): """An object that implements weighted operations. You should create a Weighted object by using the ``DataArray.weighted`` or @@ -70,15 +73,7 @@ class Weighted: __slots__ = ("obj", "weights") - @overload - def __init__(self, obj: "DataArray", weights: "DataArray") -> None: - ... - - @overload - def __init__(self, obj: "Dataset", weights: "DataArray") -> None: - ... - - def __init__(self, obj, weights): + def __init__(self, obj: T_DataWithCoords, weights: "DataArray"): """ Create a Weighted object @@ -121,8 +116,8 @@ def _weight_check(w): else: _weight_check(weights.data) - self.obj = obj - self.weights = weights + self.obj: T_DataWithCoords = obj + self.weights: "DataArray" = weights @staticmethod def _reduce( @@ -146,7 +141,6 @@ def _reduce( # `dot` does not broadcast arrays, so this avoids creating a large # DataArray (if `weights` has additional dimensions) - # maybe add fasttrack (`(da * weights).sum(dims=dim, skipna=skipna)`) return dot(da, weights, dims=dim) def _sum_of_weights( @@ -203,7 +197,7 @@ def sum_of_weights( self, dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, keep_attrs: Optional[bool] = None, - ) -> Union["DataArray", "Dataset"]: + ) -> T_DataWithCoords: return self._implementation( self._sum_of_weights, dim=dim, keep_attrs=keep_attrs @@ -214,7 +208,7 @@ def sum( dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, skipna: Optional[bool] = None, keep_attrs: Optional[bool] = None, - ) -> Union["DataArray", "Dataset"]: + ) -> T_DataWithCoords: return self._implementation( self._weighted_sum, dim=dim, skipna=skipna, keep_attrs=keep_attrs @@ -225,7 +219,7 @@ def mean( dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, skipna: Optional[bool] = None, keep_attrs: Optional[bool] = None, - ) -> Union["DataArray", "Dataset"]: + ) -> T_DataWithCoords: return self._implementation( self._weighted_mean, dim=dim, skipna=skipna, keep_attrs=keep_attrs @@ -239,22 +233,15 @@ def __repr__(self): return f"{klass} with weights along dimensions: {weight_dims}" -class DataArrayWeighted(Weighted): - def _implementation(self, func, dim, **kwargs): - - keep_attrs = kwargs.pop("keep_attrs") - if keep_attrs is None: - keep_attrs = _get_keep_attrs(default=False) - - weighted = func(self.obj, dim=dim, **kwargs) - - if keep_attrs: - weighted.attrs = self.obj.attrs +class DataArrayWeighted(Weighted["DataArray"]): + def _implementation(self, func, dim, **kwargs) -> "DataArray": - return weighted + dataset = self.obj._to_temp_dataset() + dataset = dataset.map(func, dim=dim, **kwargs) + return self.obj._from_temp_dataset(dataset) -class DatasetWeighted(Weighted): +class DatasetWeighted(Weighted["Dataset"]): def _implementation(self, func, dim, **kwargs) -> "Dataset": return self.obj.map(func, dim=dim, **kwargs)