From 8a2fbb89ea6880ef43362281792dce2005b6d08a Mon Sep 17 00:00:00 2001 From: Christian Jauvin Date: Sun, 27 Mar 2022 16:36:22 -0400 Subject: [PATCH] Weighted quantile (#6059) * Add weighted quantile * Add weighted quantile to documentation * Apply suggestions from code review Co-authored-by: Mathias Hauser * Apply suggestions from code review Co-authored-by: Mathias Hauser * Improve _weighted_quantile_type7_1d ufunc with suggestions * Expand scope of invalid q value test * Fix weighted quantile with zero weights * Replace np.ones by xr.ones_like in weighted quantile test * Process weighted quantile data with all nans * Fix operator precedence bug * Used effective sample size. Generalize to different quantile types supporting weighted quantiles (4-9, but only 7 is exposed and tested). Fixed unit tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Apply suggestions from code review Co-authored-by: Mathias Hauser Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * added missing Typing hints * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update what's new and pep8 fixes * add docstring paragraph discussing weight interpretation * recognize numpy names for quantile interpolation methods * tweak to avoid warning with all nans data. simplify test * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove integers from quantile interpolation available methods * remove merge artifacts * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [skip-ci] fix bad merge in whats-new * Add references * renamed htype argument to method in private functions * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update xarray/core/weighted.py Co-authored-by: Abel Aoun * Add skipped test to verify equal weights quantile with methods * Apply suggestions from code review Co-authored-by: Mathias Hauser * Update xarray/core/weighted.py Co-authored-by: Mathias Hauser * modifications suggested by review: comments, remove align, clarify test logic * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Apply suggestions from code review Co-authored-by: Mathias Hauser * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * use broadcast * move whatsnew entry * Apply suggestions from code review * switch skipna determination * use align and broadcast Co-authored-by: Mathias Hauser Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> Co-authored-by: David Huard Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Deepak Cherian Co-authored-by: Abel Aoun Co-authored-by: Mathias Hauser --- doc/api.rst | 2 + doc/user-guide/computation.rst | 8 +- doc/whats-new.rst | 3 + xarray/core/weighted.py | 223 ++++++++++++++++++++++++++++++- xarray/tests/test_weighted.py | 237 ++++++++++++++++++++++++++++++--- 5 files changed, 453 insertions(+), 20 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index d2c222da4db..7fdd775e168 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -944,6 +944,7 @@ Dataset DatasetWeighted DatasetWeighted.mean + DatasetWeighted.quantile DatasetWeighted.sum DatasetWeighted.std DatasetWeighted.var @@ -958,6 +959,7 @@ DataArray DataArrayWeighted DataArrayWeighted.mean + DataArrayWeighted.quantile DataArrayWeighted.sum DataArrayWeighted.std DataArrayWeighted.var diff --git a/doc/user-guide/computation.rst b/doc/user-guide/computation.rst index de2afa9060c..dc9748af80b 100644 --- a/doc/user-guide/computation.rst +++ b/doc/user-guide/computation.rst @@ -265,7 +265,7 @@ Weighted array reductions :py:class:`DataArray` and :py:class:`Dataset` objects include :py:meth:`DataArray.weighted` and :py:meth:`Dataset.weighted` array reduction methods. They currently -support weighted ``sum``, ``mean``, ``std`` and ``var``. +support weighted ``sum``, ``mean``, ``std``, ``var`` and ``quantile``. .. ipython:: python @@ -293,6 +293,12 @@ Calculate the weighted mean: weighted_prec.mean(dim="month") +Calculate the weighted quantile: + +.. ipython:: python + + weighted_prec.quantile(q=0.5, dim="month") + The weighted sum corresponds to: .. ipython:: python diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 37cf3af85b9..a15618e9d1f 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -22,6 +22,9 @@ v2022.03.1 (unreleased) New Features ~~~~~~~~~~~~ +- Add a weighted ``quantile`` method to :py:class:`~core.weighted.DatasetWeighted` and + :py:class:`~core.weighted.DataArrayWeighted` (:pull:`6059`). By + `Christian Jauvin `_ and `David Huard `_. - Add a ``create_index=True`` parameter to :py:meth:`Dataset.stack` and :py:meth:`DataArray.stack` so that the creation of multi-indexes is optional (:pull:`5692`). By `Benoît Bovy `_. diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index 83ce36bcb35..2e944eab1e0 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -1,14 +1,26 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Generic, Hashable, Iterable, cast +from typing import TYPE_CHECKING, Generic, Hashable, Iterable, Literal, Sequence, cast import numpy as np -from . import duck_array_ops -from .computation import dot +from . import duck_array_ops, utils +from .alignment import align, broadcast +from .computation import apply_ufunc, dot +from .npcompat import ArrayLike from .pycompat import is_duck_dask_array from .types import T_Xarray +# Weighted quantile methods are a subset of the numpy supported quantile methods. +QUANTILE_METHODS = Literal[ + "linear", + "interpolated_inverted_cdf", + "hazen", + "weibull", + "median_unbiased", + "normal_unbiased", +] + _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE = """ Reduce this {cls}'s data by a weighted ``{fcn}`` along some dimension(s). @@ -56,6 +68,61 @@ New {cls} object with the sum of the weights over the given dimension. """ +_WEIGHTED_QUANTILE_DOCSTRING_TEMPLATE = """ + Apply a weighted ``quantile`` to this {cls}'s data along some dimension(s). + + Weights are interpreted as *sampling weights* (or probability weights) and + describe how a sample is scaled to the whole population [1]_. There are + other possible interpretations for weights, *precision weights* describing the + precision of observations, or *frequency weights* counting the number of identical + observations, however, they are not implemented here. + + For compatibility with NumPy's non-weighted ``quantile`` (which is used by + ``DataArray.quantile`` and ``Dataset.quantile``), the only interpolation + method supported by this weighted version corresponds to the default "linear" + option of ``numpy.quantile``. This is "Type 7" option, described in Hyndman + and Fan (1996) [2]_. The implementation is largely inspired by a blog post + from A. Akinshin's [3]_. + + Parameters + ---------- + q : float or sequence of float + Quantile to compute, which must be between 0 and 1 inclusive. + dim : str or sequence of str, optional + Dimension(s) over which to apply the weighted ``quantile``. + skipna : bool, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64). + keep_attrs : bool, optional + If True, the attributes (``attrs``) will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + + Returns + ------- + quantiles : {cls} + New {cls} object with weighted ``quantile`` applied to its data and + the indicated dimension(s) removed. + + See Also + -------- + numpy.nanquantile, pandas.Series.quantile, Dataset.quantile, DataArray.quantile + + Notes + ----- + Returns NaN if the ``weights`` sum to 0.0 along the reduced + dimension(s). + + References + ---------- + .. [1] https://notstatschat.rbind.io/2020/08/04/weights-in-statistics/ + .. [2] Hyndman, R. J. & Fan, Y. (1996). Sample Quantiles in Statistical Packages. + The American Statistician, 50(4), 361–365. https://doi.org/10.2307/2684934 + .. [3] https://aakinshin.net/posts/weighted-quantiles + """ + if TYPE_CHECKING: from .dataarray import DataArray @@ -241,6 +308,141 @@ def _weighted_std( return cast("DataArray", np.sqrt(self._weighted_var(da, dim, skipna))) + def _weighted_quantile( + self, + da: DataArray, + q: ArrayLike, + dim: Hashable | Iterable[Hashable] | None = None, + skipna: bool = None, + ) -> DataArray: + """Apply a weighted ``quantile`` to a DataArray along some dimension(s).""" + + def _get_h(n: float, q: np.ndarray, method: QUANTILE_METHODS) -> np.ndarray: + """Return the interpolation parameter.""" + # Note that options are not yet exposed in the public API. + if method == "linear": + h = (n - 1) * q + 1 + elif method == "interpolated_inverted_cdf": + h = n * q + elif method == "hazen": + h = n * q + 0.5 + elif method == "weibull": + h = (n + 1) * q + elif method == "median_unbiased": + h = (n + 1 / 3) * q + 1 / 3 + elif method == "normal_unbiased": + h = (n + 1 / 4) * q + 3 / 8 + else: + raise ValueError(f"Invalid method: {method}.") + return h.clip(1, n) + + def _weighted_quantile_1d( + data: np.ndarray, + weights: np.ndarray, + q: np.ndarray, + skipna: bool, + method: QUANTILE_METHODS = "linear", + ) -> np.ndarray: + + # This algorithm has been adapted from: + # https://aakinshin.net/posts/weighted-quantiles/#reference-implementation + is_nan = np.isnan(data) + if skipna: + # Remove nans from data and weights + not_nan = ~is_nan + data = data[not_nan] + weights = weights[not_nan] + elif is_nan.any(): + # Return nan if data contains any nan + return np.full(q.size, np.nan) + + # Filter out data (and weights) associated with zero weights, which also flattens them + nonzero_weights = weights != 0 + data = data[nonzero_weights] + weights = weights[nonzero_weights] + n = data.size + + if n == 0: + # Possibly empty after nan or zero weight filtering above + return np.full(q.size, np.nan) + + # Kish's effective sample size + nw = weights.sum() ** 2 / (weights**2).sum() + + # Sort data and weights + sorter = np.argsort(data) + data = data[sorter] + weights = weights[sorter] + + # Normalize and sum the weights + weights = weights / weights.sum() + weights_cum = np.append(0, weights.cumsum()) + + # Vectorize the computation by transposing q with respect to weights + q = np.atleast_2d(q).T + + # Get the interpolation parameter for each q + h = _get_h(nw, q, method) + + # Find the samples contributing to the quantile computation (at *positions* between (h-1)/nw and h/nw) + u = np.maximum((h - 1) / nw, np.minimum(h / nw, weights_cum)) + + # Compute their relative weight + v = u * nw - h + 1 + w = np.diff(v) + + # Apply the weights + return (data * w).sum(axis=1) + + if skipna is None and da.dtype.kind in "cfO": + skipna = True + + q = np.atleast_1d(np.asarray(q, dtype=np.float64)) + + if q.ndim > 1: + raise ValueError("q must be a scalar or 1d") + + if np.any((q < 0) | (q > 1)): + raise ValueError("q values must be between 0 and 1") + + if dim is None: + dim = da.dims + + if utils.is_scalar(dim): + dim = [dim] + + # To satisfy mypy + dim = cast(Sequence, dim) + + # need to align *and* broadcast + # - `_weighted_quantile_1d` requires arrays with the same shape + # - broadcast does an outer join, which can introduce NaN to weights + # - therefore we first need to do align(..., join="inner") + + # TODO: use broadcast(..., join="inner") once available + # see https://github.com/pydata/xarray/issues/6304 + + da, weights = align(da, self.weights, join="inner") + da, weights = broadcast(da, weights) + + result = apply_ufunc( + _weighted_quantile_1d, + da, + weights, + input_core_dims=[dim, dim], + output_core_dims=[["quantile"]], + output_dtypes=[np.float64], + dask_gufunc_kwargs=dict(output_sizes={"quantile": len(q)}), + dask="parallelized", + vectorize=True, + kwargs={"q": q, "skipna": skipna}, + ) + + result = result.transpose("quantile", ...) + result = result.assign_coords(quantile=q).squeeze() + + return result + def _implementation(self, func, dim, **kwargs): raise NotImplementedError("Use `Dataset.weighted` or `DataArray.weighted`") @@ -310,6 +512,19 @@ def std( self._weighted_std, dim=dim, skipna=skipna, keep_attrs=keep_attrs ) + def quantile( + self, + q: ArrayLike, + *, + dim: Hashable | Sequence[Hashable] | None = None, + keep_attrs: bool = None, + skipna: bool = True, + ) -> T_Xarray: + + return self._implementation( + self._weighted_quantile, q=q, dim=dim, skipna=skipna, keep_attrs=keep_attrs + ) + def __repr__(self): """provide a nice str repr of our Weighted object""" @@ -360,6 +575,8 @@ def _inject_docstring(cls, cls_name): cls=cls_name, fcn="std", on_zero="NaN" ) + cls.quantile.__doc__ = _WEIGHTED_QUANTILE_DOCSTRING_TEMPLATE.format(cls=cls_name) + _inject_docstring(DataArrayWeighted, "DataArray") _inject_docstring(DatasetWeighted, "Dataset") diff --git a/xarray/tests/test_weighted.py b/xarray/tests/test_weighted.py index 1f065228bc4..63dd1ec0c94 100644 --- a/xarray/tests/test_weighted.py +++ b/xarray/tests/test_weighted.py @@ -194,6 +194,160 @@ def test_weighted_mean_no_nan(weights, expected): assert_equal(expected, result) +@pytest.mark.parametrize( + ("weights", "expected"), + ( + ( + [0.25, 0.05, 0.15, 0.25, 0.15, 0.1, 0.05], + [1.554595, 2.463784, 3.000000, 3.518378], + ), + ( + [0.05, 0.05, 0.1, 0.15, 0.15, 0.25, 0.25], + [2.840000, 3.632973, 4.076216, 4.523243], + ), + ), +) +def test_weighted_quantile_no_nan(weights, expected): + # Expected values were calculated by running the reference implementation + # proposed in https://aakinshin.net/posts/weighted-quantiles/ + + da = DataArray([1, 1.9, 2.2, 3, 3.7, 4.1, 5]) + q = [0.2, 0.4, 0.6, 0.8] + weights = DataArray(weights) + + expected = DataArray(expected, coords={"quantile": q}) + result = da.weighted(weights).quantile(q) + + assert_allclose(expected, result) + + +def test_weighted_quantile_zero_weights(): + + da = DataArray([0, 1, 2, 3]) + weights = DataArray([1, 0, 1, 0]) + q = 0.75 + + result = da.weighted(weights).quantile(q) + expected = DataArray([0, 2]).quantile(0.75) + + assert_allclose(expected, result) + + +def test_weighted_quantile_simple(): + # Check that weighted quantiles return the same value as numpy quantiles + da = DataArray([0, 1, 2, 3]) + w = DataArray([1, 0, 1, 0]) + + w_eps = DataArray([1, 0.0001, 1, 0.0001]) + q = 0.75 + + expected = DataArray(np.quantile([0, 2], q), coords={"quantile": q}) # 1.5 + + assert_equal(expected, da.weighted(w).quantile(q)) + assert_allclose(expected, da.weighted(w_eps).quantile(q), rtol=0.001) + + +@pytest.mark.parametrize("skipna", (True, False)) +def test_weighted_quantile_nan(skipna): + # Check skipna behavior + da = DataArray([0, 1, 2, 3, np.nan]) + w = DataArray([1, 0, 1, 0, 1]) + q = [0.5, 0.75] + + result = da.weighted(w).quantile(q, skipna=skipna) + + if skipna: + expected = DataArray(np.quantile([0, 2], q), coords={"quantile": q}) + else: + expected = DataArray(np.full(len(q), np.nan), coords={"quantile": q}) + + assert_allclose(expected, result) + + +@pytest.mark.parametrize( + "da", + ( + [1, 1.9, 2.2, 3, 3.7, 4.1, 5], + [1, 1.9, 2.2, 3, 3.7, 4.1, np.nan], + [np.nan, np.nan, np.nan], + ), +) +@pytest.mark.parametrize("q", (0.5, (0.2, 0.8))) +@pytest.mark.parametrize("skipna", (True, False)) +@pytest.mark.parametrize("factor", [1, 3.14]) +def test_weighted_quantile_equal_weights(da, q, skipna, factor): + # if all weights are equal (!= 0), should yield the same result as quantile + + da = DataArray(da) + weights = xr.full_like(da, factor) + + expected = da.quantile(q, skipna=skipna) + result = da.weighted(weights).quantile(q, skipna=skipna) + + assert_allclose(expected, result) + + +@pytest.mark.skip(reason="`method` argument is not currently exposed") +@pytest.mark.parametrize( + "da", + ( + [1, 1.9, 2.2, 3, 3.7, 4.1, 5], + [1, 1.9, 2.2, 3, 3.7, 4.1, np.nan], + [np.nan, np.nan, np.nan], + ), +) +@pytest.mark.parametrize("q", (0.5, (0.2, 0.8))) +@pytest.mark.parametrize("skipna", (True, False)) +@pytest.mark.parametrize( + "method", + [ + "linear", + "interpolated_inverted_cdf", + "hazen", + "weibull", + "median_unbiased", + "normal_unbiased2", + ], +) +def test_weighted_quantile_equal_weights_all_methods(da, q, skipna, factor, method): + # If all weights are equal (!= 0), should yield the same result as numpy quantile + + da = DataArray(da) + weights = xr.full_like(da, 3.14) + + expected = da.quantile(q, skipna=skipna, method=method) + result = da.weighted(weights).quantile(q, skipna=skipna, method=method) + + assert_allclose(expected, result) + + +def test_weighted_quantile_bool(): + # https://github.com/pydata/xarray/issues/4074 + da = DataArray([1, 1]) + weights = DataArray([True, True]) + q = 0.5 + + expected = DataArray([1], coords={"quantile": [q]}).squeeze() + result = da.weighted(weights).quantile(q) + + assert_equal(expected, result) + + +@pytest.mark.parametrize("q", (-1, 1.1, (0.5, 1.1), ((0.2, 0.4), (0.6, 0.8)))) +def test_weighted_quantile_with_invalid_q(q): + + da = DataArray([1, 1.9, 2.2, 3, 3.7, 4.1, 5]) + q = np.asarray(q) + weights = xr.ones_like(da) + + if q.ndim <= 1: + with pytest.raises(ValueError, match="q values must be between 0 and 1"): + da.weighted(weights).quantile(q) + else: + with pytest.raises(ValueError, match="q must be a scalar or 1d"): + da.weighted(weights).quantile(q) + + @pytest.mark.parametrize( ("weights", "expected"), (([4, 6], 2.0), ([1, 0], np.nan), ([0, 0], np.nan)) ) @@ -466,16 +620,56 @@ def test_weighted_operations_3D(dim, add_nans, skipna): check_weighted_operations(data, weights, dim, skipna) -def test_weighted_operations_nonequal_coords(): +@pytest.mark.parametrize("dim", ("a", "b", "c", ("a", "b"), ("a", "b", "c"), None)) +@pytest.mark.parametrize("q", (0.5, (0.1, 0.9), (0.2, 0.4, 0.6, 0.8))) +@pytest.mark.parametrize("add_nans", (True, False)) +@pytest.mark.parametrize("skipna", (None, True, False)) +def test_weighted_quantile_3D(dim, q, add_nans, skipna): + + dims = ("a", "b", "c") + coords = dict(a=[0, 1, 2], b=[0, 1, 2, 3], c=[0, 1, 2, 3, 4]) + data = np.arange(60).reshape(3, 4, 5).astype(float) + + # add approximately 25 % NaNs (https://stackoverflow.com/a/32182680/3010700) + if add_nans: + c = int(data.size * 0.25) + data.ravel()[np.random.choice(data.size, c, replace=False)] = np.NaN + + da = DataArray(data, dims=dims, coords=coords) + + # Weights are all ones, because we will compare against DataArray.quantile (non-weighted) + weights = xr.ones_like(da) + + result = da.weighted(weights).quantile(q, dim=dim, skipna=skipna) + expected = da.quantile(q, dim=dim, skipna=skipna) + + assert_allclose(expected, result) + + ds = da.to_dataset(name="data") + result2 = ds.weighted(weights).quantile(q, dim=dim, skipna=skipna) + + assert_allclose(expected, result2.data) + + +def test_weighted_operations_nonequal_coords(): + # There are no weights for a == 4, so that data point is ignored. weights = DataArray(np.random.randn(4), dims=("a",), coords=dict(a=[0, 1, 2, 3])) data = DataArray(np.random.randn(4), dims=("a",), coords=dict(a=[1, 2, 3, 4])) - check_weighted_operations(data, weights, dim="a", skipna=None) + q = 0.5 + result = data.weighted(weights).quantile(q, dim="a") + # Expected value computed using code from https://aakinshin.net/posts/weighted-quantiles/ with values at a=1,2,3 + expected = DataArray([0.9308707], coords={"quantile": [q]}).squeeze() + assert_allclose(result, expected) + data = data.to_dataset(name="data") check_weighted_operations(data, weights, dim="a", skipna=None) + result = data.weighted(weights).quantile(q, dim="a") + assert_allclose(result, expected.to_dataset(name="data")) + @pytest.mark.parametrize("shape_data", ((4,), (4, 4), (4, 4, 4))) @pytest.mark.parametrize("shape_weights", ((4,), (4, 4), (4, 4, 4))) @@ -506,7 +700,8 @@ def test_weighted_operations_different_shapes( @pytest.mark.parametrize( - "operation", ("sum_of_weights", "sum", "mean", "sum_of_squares", "var", "std") + "operation", + ("sum_of_weights", "sum", "mean", "sum_of_squares", "var", "std", "quantile"), ) @pytest.mark.parametrize("as_dataset", (True, False)) @pytest.mark.parametrize("keep_attrs", (True, False, None)) @@ -520,22 +715,23 @@ def test_weighted_operations_keep_attr(operation, as_dataset, keep_attrs): data.attrs = dict(attr="weights") - result = getattr(data.weighted(weights), operation)(keep_attrs=True) + kwargs = {"keep_attrs": keep_attrs} + if operation == "quantile": + kwargs["q"] = 0.5 + + result = getattr(data.weighted(weights), operation)(**kwargs) if operation == "sum_of_weights": - assert weights.attrs == result.attrs + assert result.attrs == (weights.attrs if keep_attrs else {}) + assert result.attrs == (weights.attrs if keep_attrs else {}) else: - assert data.attrs == result.attrs - - result = getattr(data.weighted(weights), operation)(keep_attrs=None) - assert not result.attrs - - result = getattr(data.weighted(weights), operation)(keep_attrs=False) - assert not result.attrs + assert result.attrs == (weights.attrs if keep_attrs else {}) + assert result.attrs == (data.attrs if keep_attrs else {}) @pytest.mark.parametrize( - "operation", ("sum_of_weights", "sum", "mean", "sum_of_squares", "var", "std") + "operation", + ("sum_of_weights", "sum", "mean", "sum_of_squares", "var", "std", "quantile"), ) def test_weighted_operations_keep_attr_da_in_ds(operation): # GH #3595 @@ -544,22 +740,31 @@ def test_weighted_operations_keep_attr_da_in_ds(operation): data = DataArray(np.random.randn(2, 2), attrs=dict(attr="data")) data = data.to_dataset(name="a") - result = getattr(data.weighted(weights), operation)(keep_attrs=True) + kwargs = {"keep_attrs": True} + if operation == "quantile": + kwargs["q"] = 0.5 + + result = getattr(data.weighted(weights), operation)(**kwargs) assert data.a.attrs == result.a.attrs +@pytest.mark.parametrize("operation", ("sum_of_weights", "sum", "mean", "quantile")) @pytest.mark.parametrize("as_dataset", (True, False)) -def test_weighted_bad_dim(as_dataset): +def test_weighted_bad_dim(operation, as_dataset): data = DataArray(np.random.randn(2, 2)) weights = xr.ones_like(data) if as_dataset: data = data.to_dataset(name="data") + kwargs = {"dim": "bad_dim"} + if operation == "quantile": + kwargs["q"] = 0.5 + error_msg = ( f"{data.__class__.__name__}Weighted" " does not contain the dimensions: {'bad_dim'}" ) with pytest.raises(ValueError, match=error_msg): - data.weighted(weights).mean("bad_dim") + getattr(data.weighted(weights), operation)(**kwargs)