diff --git a/doc/api.rst b/doc/api.rst index 4492d882355..43a9cf53ead 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -165,6 +165,7 @@ Computation Dataset.groupby_bins Dataset.rolling Dataset.rolling_exp + Dataset.weighted Dataset.coarsen Dataset.resample Dataset.diff @@ -340,6 +341,7 @@ Computation DataArray.groupby_bins DataArray.rolling DataArray.rolling_exp + DataArray.weighted DataArray.coarsen DataArray.dt DataArray.resample @@ -577,6 +579,22 @@ Rolling objects core.rolling.DatasetRolling.reduce core.rolling_exp.RollingExp +Weighted objects +================ + +.. autosummary:: + :toctree: generated/ + + core.weighted.DataArrayWeighted + core.weighted.DataArrayWeighted.mean + core.weighted.DataArrayWeighted.sum + core.weighted.DataArrayWeighted.sum_of_weights + core.weighted.DatasetWeighted + core.weighted.DatasetWeighted.mean + core.weighted.DatasetWeighted.sum + core.weighted.DatasetWeighted.sum_of_weights + + Coarsen objects =============== diff --git a/doc/computation.rst b/doc/computation.rst index 1ac30f55ee7..5309f27e9b6 100644 --- a/doc/computation.rst +++ b/doc/computation.rst @@ -1,3 +1,5 @@ +.. currentmodule:: xarray + .. _comput: ########### @@ -241,12 +243,94 @@ You can also use ``construct`` to compute a weighted rolling sum: To avoid this, use ``skipna=False`` as in the above example. +.. _comput.weighted: + +Weighted array reductions +========================= + +:py:class:`DataArray` and :py:class:`Dataset` objects include :py:meth:`DataArray.weighted` +and :py:meth:`Dataset.weighted` array reduction methods. They currently +support weighted ``sum`` and weighted ``mean``. + +.. ipython:: python + + coords = dict(month=('month', [1, 2, 3])) + + prec = xr.DataArray([1.1, 1.0, 0.9], dims=('month', ), coords=coords) + weights = xr.DataArray([31, 28, 31], dims=('month', ), coords=coords) + +Create a weighted object: + +.. ipython:: python + + weighted_prec = prec.weighted(weights) + weighted_prec + +Calculate the weighted sum: + +.. ipython:: python + + weighted_prec.sum() + +Calculate the weighted mean: + +.. ipython:: python + + weighted_prec.mean(dim="month") + +The weighted sum corresponds to: + +.. ipython:: python + + weighted_sum = (prec * weights).sum() + weighted_sum + +and the weighted mean to: + +.. ipython:: python + + weighted_mean = weighted_sum / weights.sum() + weighted_mean + +However, these methods also take missing values in the data into account: + +.. ipython:: python + + data = xr.DataArray([np.NaN, 2, 4]) + weights = xr.DataArray([8, 1, 1]) + + data.weighted(weights).mean() + +Using ``(data * weights).sum() / weights.sum()`` would (incorrectly) result +in 0.6. + + +If the weights add up to 0, ``sum`` returns 0: + +.. ipython:: python + + data = xr.DataArray([1.0, 1.0]) + weights = xr.DataArray([-1.0, 1.0]) + + data.weighted(weights).sum() + +and ``mean`` returns ``NaN``: + +.. ipython:: python + + data.weighted(weights).mean() + + +.. note:: + ``weights`` must be a :py:class:`DataArray` and cannot contain missing values. + Missing values can be replaced manually by ``weights.fillna(0)``. + .. _comput.coarsen: Coarsen large arrays ==================== -``DataArray`` and ``Dataset`` objects include a +:py:class:`DataArray` and :py:class:`Dataset` objects include :py:meth:`~xarray.DataArray.coarsen` and :py:meth:`~xarray.Dataset.coarsen` methods.
This supports the block aggregation along multiple dimensions, diff --git a/doc/examples.rst b/doc/examples.rst index 805395808e0..1d48d29bcc5 100644 --- a/doc/examples.rst +++ b/doc/examples.rst @@ -6,6 +6,7 @@ Examples examples/weather-data examples/monthly-means + examples/area_weighted_temperature examples/multidimensional-coords examples/visualization_gallery examples/ROMS_ocean_model diff --git a/doc/examples/area_weighted_temperature.ipynb b/doc/examples/area_weighted_temperature.ipynb new file mode 100644 index 00000000000..72876e3fc29 --- /dev/null +++ b/doc/examples/area_weighted_temperature.ipynb @@ -0,0 +1,226 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "toc": true + }, + "source": [ + "
<h1>Table of Contents<span class=\"tocSkip\"></span></h1>\n", + "<div class=\"toc\"><ul class=\"toc-item\"></ul></div>
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Compare weighted and unweighted mean temperature\n", + "\n", + "\n", + "Author: [Mathias Hauser](https://github.com/mathause/)\n", + "\n", + "\n", + "We use the `air_temperature` example dataset to calculate the area-weighted temperature over its domain. This dataset has a regular latitude/ longitude grid, thus the gridcell area decreases towards the pole. For this grid we can use the cosine of the latitude as proxy for the grid cell area.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-03-17T14:43:57.222351Z", + "start_time": "2020-03-17T14:43:56.147541Z" + } + }, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "\n", + "import cartopy.crs as ccrs\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "\n", + "import xarray as xr" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Data\n", + "\n", + "Load the data, convert to celsius, and resample to daily values" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-03-17T14:43:57.831734Z", + "start_time": "2020-03-17T14:43:57.651845Z" + } + }, + "outputs": [], + "source": [ + "ds = xr.tutorial.load_dataset(\"air_temperature\")\n", + "\n", + "# to celsius\n", + "air = ds.air - 273.15\n", + "\n", + "# resample from 6-hourly to daily values\n", + "air = air.resample(time=\"D\").mean()\n", + "\n", + "air" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Plot the first timestep:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-03-17T14:43:59.887120Z", + "start_time": "2020-03-17T14:43:59.582894Z" + } + }, + "outputs": [], + "source": [ + "projection = ccrs.LambertConformal(central_longitude=-95, central_latitude=45)\n", + "\n", + "f, ax = plt.subplots(subplot_kw=dict(projection=projection))\n", + "\n", + "air.isel(time=0).plot(transform=ccrs.PlateCarree(), cbar_kwargs=dict(shrink=0.7))\n", + "ax.coastlines()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Creating weights\n", + "\n", + "For a for a rectangular grid the cosine of the latitude is proportional to the grid cell area." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-03-17T14:44:18.777092Z", + "start_time": "2020-03-17T14:44:18.736587Z" + } + }, + "outputs": [], + "source": [ + "weights = np.cos(np.deg2rad(air.lat))\n", + "weights.name = \"weights\"\n", + "weights" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Weighted mean" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-03-17T14:44:52.607120Z", + "start_time": "2020-03-17T14:44:52.564674Z" + } + }, + "outputs": [], + "source": [ + "air_weighted = air.weighted(weights)\n", + "air_weighted" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-03-17T14:44:54.334279Z", + "start_time": "2020-03-17T14:44:54.280022Z" + } + }, + "outputs": [], + "source": [ + "weighted_mean = air_weighted.mean((\"lon\", \"lat\"))\n", + "weighted_mean" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Plot: comparison with unweighted mean\n", + "\n", + "Note how the weighted mean temperature is higher than the unweighted one." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-03-17T14:45:08.877307Z", + "start_time": "2020-03-17T14:45:08.673383Z" + } + }, + "outputs": [], + "source": [ + "weighted_mean.plot(label=\"weighted\")\n", + "air.mean((\"lon\", \"lat\")).plot(label=\"unweighted\")\n", + "\n", + "plt.legend()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.6" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": true, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": true + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/doc/whats-new.rst b/doc/whats-new.rst index aad0e083a8c..5640e872bea 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -25,6 +25,9 @@ Breaking changes New Features ~~~~~~~~~~~~ +- Weighted array reductions are now supported via the new :py:meth:`DataArray.weighted` + and :py:meth:`Dataset.weighted` methods. See :ref:`comput.weighted`. (:issue:`422`, :pull:`2922`). + By `Mathias Hauser <https://github.com/mathause>`_. - Added support for :py:class:`pandas.DatetimeIndex`-style rounding of ``cftime.datetime`` objects directly via a :py:class:`CFTimeIndex` or via the :py:class:`~core.accessor_dt.DatetimeAccessor`. diff --git a/xarray/core/common.py b/xarray/core/common.py index 39aa7982091..a003642076f 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -745,6 +745,25 @@ def groupby_bins( }, ) + def weighted(self, weights): + """ + Weighted operations. + + Parameters + ---------- + weights : DataArray + An array of weights associated with the values of this object. + Each value in the data contributes to the reduction operation + according to its associated weight. + + Notes + ----- + ``weights`` must be a DataArray and cannot contain missing values. + Missing values can be replaced by ``weights.fillna(0)``.
+ """ + + return self._weighted_cls(self, weights) + def rolling( self, dim: Mapping[Hashable, int] = None, diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index b335eeb293b..4b3ecb2744c 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -33,6 +33,7 @@ resample, rolling, utils, + weighted, ) from .accessor_dt import CombinedDatetimelikeAccessor from .accessor_str import StringAccessor @@ -258,6 +259,7 @@ class DataArray(AbstractArray, DataWithCoords): _rolling_cls = rolling.DataArrayRolling _coarsen_cls = rolling.DataArrayCoarsen _resample_cls = resample.DataArrayResample + _weighted_cls = weighted.DataArrayWeighted dt = property(CombinedDatetimelikeAccessor) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index d5ad1123a54..c10447f6d11 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -46,6 +46,7 @@ resample, rolling, utils, + weighted, ) from .alignment import _broadcast_helper, _get_broadcast_dims_map_common_coords, align from .common import ( @@ -457,6 +458,7 @@ class Dataset(Mapping, ImplementsDatasetReduce, DataWithCoords): _rolling_cls = rolling.DatasetRolling _coarsen_cls = rolling.DatasetCoarsen _resample_cls = resample.DatasetResample + _weighted_cls = weighted.DatasetWeighted def __init__( self, diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py new file mode 100644 index 00000000000..996d2e4c43e --- /dev/null +++ b/xarray/core/weighted.py @@ -0,0 +1,255 @@ +from typing import TYPE_CHECKING, Hashable, Iterable, Optional, Union, overload + +from .computation import dot +from .options import _get_keep_attrs + +if TYPE_CHECKING: + from .dataarray import DataArray, Dataset + +_WEIGHTED_REDUCE_DOCSTRING_TEMPLATE = """ + Reduce this {cls}'s data by a weighted ``{fcn}`` along some dimension(s). + + Parameters + ---------- + dim : str or sequence of str, optional + Dimension(s) over which to apply the weighted ``{fcn}``. + skipna : bool, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64). + keep_attrs : bool, optional + If True, the attributes (``attrs``) will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + + Returns + ------- + reduced : {cls} + New {cls} object with weighted ``{fcn}`` applied to its data and + the indicated dimension(s) removed. + + Notes + ----- + Returns {on_zero} if the ``weights`` sum to 0.0 along the reduced + dimension(s). + """ + +_SUM_OF_WEIGHTS_DOCSTRING = """ + Calculate the sum of weights, accounting for missing values in the data + + Parameters + ---------- + dim : str or sequence of str, optional + Dimension(s) over which to sum the weights. + keep_attrs : bool, optional + If True, the attributes (``attrs``) will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + + Returns + ------- + reduced : {cls} + New {cls} object with the sum of the weights over the given dimension. + """ + + +class Weighted: + """An object that implements weighted operations. + + You should create a Weighted object by using the ``DataArray.weighted`` or + ``Dataset.weighted`` methods. 
+ + See Also + -------- + Dataset.weighted + DataArray.weighted + """ + + __slots__ = ("obj", "weights") + + @overload + def __init__(self, obj: "DataArray", weights: "DataArray") -> None: + ... + + @overload  # noqa: F811 + def __init__(self, obj: "Dataset", weights: "DataArray") -> None:  # noqa: F811 + ... + + def __init__(self, obj, weights):  # noqa: F811 + """ + Create a Weighted object + + Parameters + ---------- + obj : DataArray or Dataset + Object over which the weighted reduction operation is applied. + weights : DataArray + An array of weights associated with the values in ``obj``. + Each value in ``obj`` contributes to the reduction operation + according to its associated weight. + + Notes + ----- + ``weights`` must be a ``DataArray`` and cannot contain missing values. + Missing values can be replaced by ``weights.fillna(0)``. + """ + + from .dataarray import DataArray + + if not isinstance(weights, DataArray): + raise ValueError("`weights` must be a DataArray") + + if weights.isnull().any(): + raise ValueError( + "`weights` cannot contain missing values. " + "Missing values can be replaced by `weights.fillna(0)`." + ) + + self.obj = obj + self.weights = weights + + @staticmethod + def _reduce( + da: "DataArray", + weights: "DataArray", + dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, + skipna: Optional[bool] = None, + ) -> "DataArray": + """reduce using dot; equivalent to (da * weights).sum(dim, skipna) + + for internal use only + """ + + # need to infer dims as we use `dot` + if dim is None: + dim = ... + + # need to mask invalid values in da, as `dot` does not implement skipna + if skipna or (skipna is None and da.dtype.kind in "cfO"): + da = da.fillna(0.0) + + # `dot` does not broadcast arrays, so this avoids creating a large + # DataArray (if `weights` has additional dimensions) + # maybe add fasttrack (`(da * weights).sum(dim=dim, skipna=skipna)`) + return dot(da, weights, dims=dim) + + def _sum_of_weights( + self, da: "DataArray", dim: Optional[Union[Hashable, Iterable[Hashable]]] = None + ) -> "DataArray": + """ Calculate the sum of weights, accounting for missing values """ + + # we need to mask data values that are nan; else the weights are wrong + mask = da.notnull() + + sum_of_weights = self._reduce(mask, self.weights, dim=dim, skipna=False) + + # 0-weights are not valid + valid_weights = sum_of_weights != 0.0 + + return sum_of_weights.where(valid_weights) + + def _weighted_sum( + self, + da: "DataArray", + dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, + skipna: Optional[bool] = None, + ) -> "DataArray": + """Reduce a DataArray by a weighted ``sum`` along some dimension(s).""" + + return self._reduce(da, self.weights, dim=dim, skipna=skipna) + + def _weighted_mean( + self, + da: "DataArray", + dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, + skipna: Optional[bool] = None, + ) -> "DataArray": + """Reduce a DataArray by a weighted ``mean`` along some dimension(s).""" + + weighted_sum = self._weighted_sum(da, dim=dim, skipna=skipna) + + sum_of_weights = self._sum_of_weights(da, dim=dim) + + return weighted_sum / sum_of_weights + + def _implementation(self, func, dim, **kwargs): + + raise NotImplementedError("Use `Dataset.weighted` or `DataArray.weighted`") + + def sum_of_weights( + self, + dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, + keep_attrs: Optional[bool] = None, + ) -> Union["DataArray", "Dataset"]: + + return self._implementation( + self._sum_of_weights, dim=dim, keep_attrs=keep_attrs + ) + + def sum(
self, + dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, + skipna: Optional[bool] = None, + keep_attrs: Optional[bool] = None, + ) -> Union["DataArray", "Dataset"]: + + return self._implementation( + self._weighted_sum, dim=dim, skipna=skipna, keep_attrs=keep_attrs + ) + + def mean( + self, + dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, + skipna: Optional[bool] = None, + keep_attrs: Optional[bool] = None, + ) -> Union["DataArray", "Dataset"]: + + return self._implementation( + self._weighted_mean, dim=dim, skipna=skipna, keep_attrs=keep_attrs + ) + + def __repr__(self): + """provide a nice str repr of our Weighted object""" + + klass = self.__class__.__name__ + weight_dims = ", ".join(self.weights.dims) + return f"{klass} with weights along dimensions: {weight_dims}" + + +class DataArrayWeighted(Weighted): + def _implementation(self, func, dim, **kwargs): + + keep_attrs = kwargs.pop("keep_attrs") + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=False) + + weighted = func(self.obj, dim=dim, **kwargs) + + if keep_attrs: + weighted.attrs = self.obj.attrs + + return weighted + + +class DatasetWeighted(Weighted): + def _implementation(self, func, dim, **kwargs) -> "Dataset": + + return self.obj.map(func, dim=dim, **kwargs) + + +def _inject_docstring(cls, cls_name): + + cls.sum_of_weights.__doc__ = _SUM_OF_WEIGHTS_DOCSTRING.format(cls=cls_name) + + cls.sum.__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format( + cls=cls_name, fcn="sum", on_zero="0" + ) + + cls.mean.__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format( + cls=cls_name, fcn="mean", on_zero="NaN" + ) + + +_inject_docstring(DataArrayWeighted, "DataArray") +_inject_docstring(DatasetWeighted, "Dataset") diff --git a/xarray/tests/test_weighted.py b/xarray/tests/test_weighted.py new file mode 100644 index 00000000000..24531215dfb --- /dev/null +++ b/xarray/tests/test_weighted.py @@ -0,0 +1,311 @@ +import numpy as np +import pytest + +import xarray as xr +from xarray import DataArray +from xarray.tests import assert_allclose, assert_equal, raises_regex + + +@pytest.mark.parametrize("as_dataset", (True, False)) +def test_weighted_non_DataArray_weights(as_dataset): + + data = DataArray([1, 2]) + if as_dataset: + data = data.to_dataset(name="data") + + with raises_regex(ValueError, "`weights` must be a DataArray"): + data.weighted([1, 2]) + + +@pytest.mark.parametrize("as_dataset", (True, False)) +@pytest.mark.parametrize("weights", ([np.nan, 2], [np.nan, np.nan])) +def test_weighted_weights_nan_raises(as_dataset, weights): + + data = DataArray([1, 2]) + if as_dataset: + data = data.to_dataset(name="data") + + with pytest.raises(ValueError, match="`weights` cannot contain missing values."): + data.weighted(DataArray(weights)) + + +@pytest.mark.parametrize( + ("weights", "expected"), + (([1, 2], 3), ([2, 0], 2), ([0, 0], np.nan), ([-1, 1], np.nan)), +) +def test_weighted_sum_of_weights_no_nan(weights, expected): + + da = DataArray([1, 2]) + weights = DataArray(weights) + result = da.weighted(weights).sum_of_weights() + + expected = DataArray(expected) + + assert_equal(expected, result) + + +@pytest.mark.parametrize( + ("weights", "expected"), + (([1, 2], 2), ([2, 0], np.nan), ([0, 0], np.nan), ([-1, 1], 1)), +) +def test_weighted_sum_of_weights_nan(weights, expected): + + da = DataArray([np.nan, 2]) + weights = DataArray(weights) + result = da.weighted(weights).sum_of_weights() + + expected = DataArray(expected) + + assert_equal(expected, result) + + +@pytest.mark.parametrize("da", ([1.0, 2], [1, np.nan], 
[np.nan, np.nan])) +@pytest.mark.parametrize("factor", [0, 1, 3.14]) +@pytest.mark.parametrize("skipna", (True, False)) +def test_weighted_sum_equal_weights(da, factor, skipna): + # if all weights are equal to `factor`, the weighted sum is `factor` times the ordinary sum + + da = DataArray(da) + weights = xr.full_like(da, factor) + + expected = da.sum(skipna=skipna) * factor + result = da.weighted(weights).sum(skipna=skipna) + + assert_equal(expected, result) + + +@pytest.mark.parametrize( + ("weights", "expected"), (([1, 2], 5), ([0, 2], 4), ([0, 0], 0)) +) +def test_weighted_sum_no_nan(weights, expected): + + da = DataArray([1, 2]) + + weights = DataArray(weights) + result = da.weighted(weights).sum() + expected = DataArray(expected) + + assert_equal(expected, result) + + +@pytest.mark.parametrize( + ("weights", "expected"), (([1, 2], 4), ([0, 2], 4), ([1, 0], 0), ([0, 0], 0)) +) +@pytest.mark.parametrize("skipna", (True, False)) +def test_weighted_sum_nan(weights, expected, skipna): + + da = DataArray([np.nan, 2]) + + weights = DataArray(weights) + result = da.weighted(weights).sum(skipna=skipna) + + if skipna: + expected = DataArray(expected) + else: + expected = DataArray(np.nan) + + assert_equal(expected, result) + + +@pytest.mark.filterwarnings("ignore:Mean of empty slice") +@pytest.mark.parametrize("da", ([1.0, 2], [1, np.nan], [np.nan, np.nan])) +@pytest.mark.parametrize("skipna", (True, False)) +@pytest.mark.parametrize("factor", [1, 2, 3.14]) +def test_weighted_mean_equal_weights(da, skipna, factor): + # if all weights are equal (!= 0), should yield the same result as mean + + da = DataArray(da) + + # all weights equal to `factor` + weights = xr.full_like(da, factor) + + expected = da.mean(skipna=skipna) + result = da.weighted(weights).mean(skipna=skipna) + + assert_equal(expected, result) + + +@pytest.mark.parametrize( + ("weights", "expected"), (([4, 6], 1.6), ([1, 0], 1.0), ([0, 0], np.nan)) +) +def test_weighted_mean_no_nan(weights, expected): + + da = DataArray([1, 2]) + weights = DataArray(weights) + expected = DataArray(expected) + + result = da.weighted(weights).mean() + + assert_equal(expected, result) + + +@pytest.mark.parametrize( + ("weights", "expected"), (([4, 6], 2.0), ([1, 0], np.nan), ([0, 0], np.nan)) +) +@pytest.mark.parametrize("skipna", (True, False)) +def test_weighted_mean_nan(weights, expected, skipna): + + da = DataArray([np.nan, 2]) + weights = DataArray(weights) + + if skipna: + expected = DataArray(expected) + else: + expected = DataArray(np.nan) + + result = da.weighted(weights).mean(skipna=skipna) + + assert_equal(expected, result) + + +def expected_weighted(da, weights, dim, skipna, operation): + """ + Generate expected result using ``*`` and ``sum``.
This is checked against + the result of da.weighted which uses ``dot`` + """ + + weighted_sum = (da * weights).sum(dim=dim, skipna=skipna) + + if operation == "sum": + return weighted_sum + + masked_weights = weights.where(da.notnull()) + sum_of_weights = masked_weights.sum(dim=dim, skipna=True) + valid_weights = sum_of_weights != 0 + sum_of_weights = sum_of_weights.where(valid_weights) + + if operation == "sum_of_weights": + return sum_of_weights + + weighted_mean = weighted_sum / sum_of_weights + + if operation == "mean": + return weighted_mean + + +@pytest.mark.parametrize("dim", ("a", "b", "c", ("a", "b"), ("a", "b", "c"), None)) +@pytest.mark.parametrize("operation", ("sum_of_weights", "sum", "mean")) +@pytest.mark.parametrize("add_nans", (True, False)) +@pytest.mark.parametrize("skipna", (None, True, False)) +@pytest.mark.parametrize("as_dataset", (True, False)) +def test_weighted_operations_3D(dim, operation, add_nans, skipna, as_dataset): + + dims = ("a", "b", "c") + coords = dict(a=[0, 1, 2, 3], b=[0, 1, 2, 3], c=[0, 1, 2, 3]) + + weights = DataArray(np.random.randn(4, 4, 4), dims=dims, coords=coords) + + data = np.random.randn(4, 4, 4) + + # add approximately 25 % NaNs (https://stackoverflow.com/a/32182680/3010700) + if add_nans: + c = int(data.size * 0.25) + data.ravel()[np.random.choice(data.size, c, replace=False)] = np.NaN + + data = DataArray(data, dims=dims, coords=coords) + + if as_dataset: + data = data.to_dataset(name="data") + + if operation == "sum_of_weights": + result = data.weighted(weights).sum_of_weights(dim) + else: + result = getattr(data.weighted(weights), operation)(dim, skipna=skipna) + + expected = expected_weighted(data, weights, dim, skipna, operation) + + assert_allclose(expected, result) + + +@pytest.mark.parametrize("operation", ("sum_of_weights", "sum", "mean")) +@pytest.mark.parametrize("as_dataset", (True, False)) +def test_weighted_operations_nonequal_coords(operation, as_dataset): + + weights = DataArray(np.random.randn(4), dims=("a",), coords=dict(a=[0, 1, 2, 3])) + data = DataArray(np.random.randn(4), dims=("a",), coords=dict(a=[1, 2, 3, 4])) + + if as_dataset: + data = data.to_dataset(name="data") + + expected = expected_weighted( + data, weights, dim="a", skipna=None, operation=operation + ) + result = getattr(data.weighted(weights), operation)(dim="a") + + assert_allclose(expected, result) + + +@pytest.mark.parametrize("dim", ("dim_0", None)) +@pytest.mark.parametrize("shape_data", ((4,), (4, 4), (4, 4, 4))) +@pytest.mark.parametrize("shape_weights", ((4,), (4, 4), (4, 4, 4))) +@pytest.mark.parametrize("operation", ("sum_of_weights", "sum", "mean")) +@pytest.mark.parametrize("add_nans", (True, False)) +@pytest.mark.parametrize("skipna", (None, True, False)) +@pytest.mark.parametrize("as_dataset", (True, False)) +def test_weighted_operations_different_shapes( + dim, shape_data, shape_weights, operation, add_nans, skipna, as_dataset +): + + weights = DataArray(np.random.randn(*shape_weights)) + + data = np.random.randn(*shape_data) + + # add approximately 25 % NaNs + if add_nans: + c = int(data.size * 0.25) + data.ravel()[np.random.choice(data.size, c, replace=False)] = np.NaN + + data = DataArray(data) + + if as_dataset: + data = data.to_dataset(name="data") + + if operation == "sum_of_weights": + result = getattr(data.weighted(weights), operation)(dim) + else: + result = getattr(data.weighted(weights), operation)(dim, skipna=skipna) + + expected = expected_weighted(data, weights, dim, skipna, operation) + + assert_allclose(expected, result) + 
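The tests above cross-check the ``dot``-based implementation against ``expected_weighted``, which rebuilds the same reductions from ``*`` and ``sum``. As a minimal standalone sketch of that equivalence (the values are made up and not part of the diff; they mirror the 0.6-vs-3.0 example in computation.rst):

```python
import numpy as np
import xarray as xr

da = xr.DataArray([np.nan, 2.0, 4.0], dims="x")
weights = xr.DataArray([8.0, 1.0, 1.0], dims="x")

# naive product-then-sum; NaNs in ``da`` are skipped by default
weighted_sum = (da * weights).sum(dim="x")

# the weights must be masked where the data is NaN, otherwise the
# denominator counts the weight 8 of a value that never contributed
masked_weights = weights.where(da.notnull())
sum_of_weights = masked_weights.sum(dim="x")

print(float(weighted_sum / sum_of_weights))   # 3.0
print(float(da.weighted(weights).mean("x")))  # 3.0, via the dot-based path
```

Dividing by the unmasked ``weights.sum()`` instead would give 6 / 10 = 0.6, the pitfall called out in the computation.rst changes above.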
+ +@pytest.mark.parametrize("operation", ("sum_of_weights", "sum", "mean")) +@pytest.mark.parametrize("as_dataset", (True, False)) +def test_weighted_operations_keep_attr(operation, as_dataset): + + weights = DataArray(np.random.randn(2, 2), attrs=dict(attr="weights")) + data = DataArray(np.random.randn(2, 2)) + + if as_dataset: + data = data.to_dataset(name="data") + + data.attrs = dict(attr="weights") + + result = getattr(data.weighted(weights), operation)(keep_attrs=True) + + if operation == "sum_of_weights": + assert weights.attrs == result.attrs + else: + assert data.attrs == result.attrs + + result = getattr(data.weighted(weights), operation)(keep_attrs=None) + assert not result.attrs + + result = getattr(data.weighted(weights), operation)(keep_attrs=False) + assert not result.attrs + + +@pytest.mark.xfail(reason="xr.Dataset.map does not copy attrs of DataArrays GH: 3595") +@pytest.mark.parametrize("operation", ("sum", "mean")) +def test_weighted_operations_keep_attr_da_in_ds(operation): + # GH #3595 + + weights = DataArray(np.random.randn(2, 2)) + data = DataArray(np.random.randn(2, 2), attrs=dict(attr="data")) + data = data.to_dataset(name="a") + + result = getattr(data.weighted(weights), operation)(keep_attrs=True) + + assert data.a.attrs == result.a.attrs
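End to end, the feature added by this PR is used like the notebook above demonstrates; a short usage sketch, assuming this branch is installed (array values and names are illustrative only):

```python
import numpy as np
import xarray as xr

# hypothetical field on a regular latitude/longitude grid
air = xr.DataArray(
    np.random.randn(3, 4),
    dims=("lat", "lon"),
    coords={"lat": [10.0, 45.0, 80.0]},
)

# on such a grid, cos(lat) is proportional to the grid cell area
weights = np.cos(np.deg2rad(air.lat))

weighted = air.weighted(weights)       # a DataArrayWeighted object

print(weighted.mean(("lat", "lon")))   # area-weighted mean
print(weighted.sum("lat"))             # weighted sum along one dimension
print(weighted.sum_of_weights("lat"))  # the denominator ``mean`` divides by
```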