Skip to content

Commit

Permalink
Weighted quantile (#6059)
Browse files Browse the repository at this point in the history
* Add weighted quantile

* Add weighted quantile to documentation

* Apply suggestions from code review

Co-authored-by: Mathias Hauser <mathause@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Mathias Hauser <mathause@users.noreply.github.com>

* Improve _weighted_quantile_type7_1d ufunc with suggestions

* Expand scope of invalid q value test

* Fix weighted quantile with zero weights

* Replace np.ones by xr.ones_like in weighted quantile test

* Process weighted quantile data with all nans

* Fix operator precedence bug

* Used effective sample size. Generalize to different quantile types supporting weighted quantiles (4-9, but only 7 is exposed and tested). Fixed unit tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Apply suggestions from code review

Co-authored-by: Mathias Hauser <mathause@users.noreply.github.com>
Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* added missing Typing hints

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update what's new and pep8 fixes

* add docstring paragraph discussing weight interpretation

* recognize numpy names for quantile interpolation methods

* tweak to avoid warning with all nans data. simplify test

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove integers from quantile interpolation available methods

* remove merge artifacts

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [skip-ci] fix bad merge in whats-new

* Add references

* renamed htype argument to method in private functions

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update xarray/core/weighted.py

Co-authored-by: Abel Aoun <aoun.abel@gmail.com>

* Add skipped test to verify equal weights quantile with methods

* Apply suggestions from code review

Co-authored-by: Mathias Hauser <mathause@users.noreply.github.com>

* Update xarray/core/weighted.py

Co-authored-by: Mathias Hauser <mathause@users.noreply.github.com>

* modifications suggested by review: comments, remove align, clarify test logic

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Apply suggestions from code review

Co-authored-by: Mathias Hauser <mathause@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* use broadcast

* move whatsnew entry

* Apply suggestions from code review

* switch skipna determination

* use align and broadcast

Co-authored-by: Mathias Hauser <mathause@users.noreply.github.com>
Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com>
Co-authored-by: David Huard <huard.david@ouranos.ca>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>
Co-authored-by: Abel Aoun <aoun.abel@gmail.com>
Co-authored-by: Mathias Hauser <mathias.hauser@env.ethz.ch>
  • Loading branch information
8 people authored Mar 27, 2022
1 parent 8f42bfd commit 8a2fbb8
Show file tree
Hide file tree
Showing 5 changed files with 453 additions and 20 deletions.
2 changes: 2 additions & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -944,6 +944,7 @@ Dataset

DatasetWeighted
DatasetWeighted.mean
DatasetWeighted.quantile
DatasetWeighted.sum
DatasetWeighted.std
DatasetWeighted.var
Expand All @@ -958,6 +959,7 @@ DataArray

DataArrayWeighted
DataArrayWeighted.mean
DataArrayWeighted.quantile
DataArrayWeighted.sum
DataArrayWeighted.std
DataArrayWeighted.var
Expand Down
8 changes: 7 additions & 1 deletion doc/user-guide/computation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ Weighted array reductions

:py:class:`DataArray` and :py:class:`Dataset` objects include :py:meth:`DataArray.weighted`
and :py:meth:`Dataset.weighted` array reduction methods. They currently
support weighted ``sum``, ``mean``, ``std`` and ``var``.
support weighted ``sum``, ``mean``, ``std``, ``var`` and ``quantile``.

.. ipython:: python
Expand Down Expand Up @@ -293,6 +293,12 @@ Calculate the weighted mean:
weighted_prec.mean(dim="month")
Calculate the weighted quantile:

.. ipython:: python
weighted_prec.quantile(q=0.5, dim="month")
The weighted sum corresponds to:

.. ipython:: python
Expand Down
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ v2022.03.1 (unreleased)
New Features
~~~~~~~~~~~~

- Add a weighted ``quantile`` method to :py:class:`~core.weighted.DatasetWeighted` and
:py:class:`~core.weighted.DataArrayWeighted` (:pull:`6059`). By
`Christian Jauvin <https://github.com/cjauvin>`_ and `David Huard <https://github.com/huard>`_.
- Add a ``create_index=True`` parameter to :py:meth:`Dataset.stack` and
:py:meth:`DataArray.stack` so that the creation of multi-indexes is optional
(:pull:`5692`). By `Benoît Bovy <https://github.com/benbovy>`_.
Expand Down
223 changes: 220 additions & 3 deletions xarray/core/weighted.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,26 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Generic, Hashable, Iterable, cast
from typing import TYPE_CHECKING, Generic, Hashable, Iterable, Literal, Sequence, cast

import numpy as np

from . import duck_array_ops
from .computation import dot
from . import duck_array_ops, utils
from .alignment import align, broadcast
from .computation import apply_ufunc, dot
from .npcompat import ArrayLike
from .pycompat import is_duck_dask_array
from .types import T_Xarray

# Weighted quantile methods are a subset of the numpy supported quantile methods.
QUANTILE_METHODS = Literal[
"linear",
"interpolated_inverted_cdf",
"hazen",
"weibull",
"median_unbiased",
"normal_unbiased",
]

_WEIGHTED_REDUCE_DOCSTRING_TEMPLATE = """
Reduce this {cls}'s data by a weighted ``{fcn}`` along some dimension(s).
Expand Down Expand Up @@ -56,6 +68,61 @@
New {cls} object with the sum of the weights over the given dimension.
"""

_WEIGHTED_QUANTILE_DOCSTRING_TEMPLATE = """
Apply a weighted ``quantile`` to this {cls}'s data along some dimension(s).
Weights are interpreted as *sampling weights* (or probability weights) and
describe how a sample is scaled to the whole population [1]_. There are
other possible interpretations for weights, *precision weights* describing the
precision of observations, or *frequency weights* counting the number of identical
observations, however, they are not implemented here.
For compatibility with NumPy's non-weighted ``quantile`` (which is used by
``DataArray.quantile`` and ``Dataset.quantile``), the only interpolation
method supported by this weighted version corresponds to the default "linear"
option of ``numpy.quantile``. This is "Type 7" option, described in Hyndman
and Fan (1996) [2]_. The implementation is largely inspired by a blog post
from A. Akinshin's [3]_.
Parameters
----------
q : float or sequence of float
Quantile to compute, which must be between 0 and 1 inclusive.
dim : str or sequence of str, optional
Dimension(s) over which to apply the weighted ``quantile``.
skipna : bool, optional
If True, skip missing values (as marked by NaN). By default, only
skips missing values for float dtypes; other dtypes either do not
have a sentinel missing value (int) or skipna=True has not been
implemented (object, datetime64 or timedelta64).
keep_attrs : bool, optional
If True, the attributes (``attrs``) will be copied from the original
object to the new one. If False (default), the new object will be
returned without attributes.
Returns
-------
quantiles : {cls}
New {cls} object with weighted ``quantile`` applied to its data and
the indicated dimension(s) removed.
See Also
--------
numpy.nanquantile, pandas.Series.quantile, Dataset.quantile, DataArray.quantile
Notes
-----
Returns NaN if the ``weights`` sum to 0.0 along the reduced
dimension(s).
References
----------
.. [1] https://notstatschat.rbind.io/2020/08/04/weights-in-statistics/
.. [2] Hyndman, R. J. & Fan, Y. (1996). Sample Quantiles in Statistical Packages.
The American Statistician, 50(4), 361–365. https://doi.org/10.2307/2684934
.. [3] https://aakinshin.net/posts/weighted-quantiles
"""


if TYPE_CHECKING:
from .dataarray import DataArray
Expand Down Expand Up @@ -241,6 +308,141 @@ def _weighted_std(

return cast("DataArray", np.sqrt(self._weighted_var(da, dim, skipna)))

def _weighted_quantile(
self,
da: DataArray,
q: ArrayLike,
dim: Hashable | Iterable[Hashable] | None = None,
skipna: bool = None,
) -> DataArray:
"""Apply a weighted ``quantile`` to a DataArray along some dimension(s)."""

def _get_h(n: float, q: np.ndarray, method: QUANTILE_METHODS) -> np.ndarray:
"""Return the interpolation parameter."""
# Note that options are not yet exposed in the public API.
if method == "linear":
h = (n - 1) * q + 1
elif method == "interpolated_inverted_cdf":
h = n * q
elif method == "hazen":
h = n * q + 0.5
elif method == "weibull":
h = (n + 1) * q
elif method == "median_unbiased":
h = (n + 1 / 3) * q + 1 / 3
elif method == "normal_unbiased":
h = (n + 1 / 4) * q + 3 / 8
else:
raise ValueError(f"Invalid method: {method}.")
return h.clip(1, n)

def _weighted_quantile_1d(
data: np.ndarray,
weights: np.ndarray,
q: np.ndarray,
skipna: bool,
method: QUANTILE_METHODS = "linear",
) -> np.ndarray:

# This algorithm has been adapted from:
# https://aakinshin.net/posts/weighted-quantiles/#reference-implementation
is_nan = np.isnan(data)
if skipna:
# Remove nans from data and weights
not_nan = ~is_nan
data = data[not_nan]
weights = weights[not_nan]
elif is_nan.any():
# Return nan if data contains any nan
return np.full(q.size, np.nan)

# Filter out data (and weights) associated with zero weights, which also flattens them
nonzero_weights = weights != 0
data = data[nonzero_weights]
weights = weights[nonzero_weights]
n = data.size

if n == 0:
# Possibly empty after nan or zero weight filtering above
return np.full(q.size, np.nan)

# Kish's effective sample size
nw = weights.sum() ** 2 / (weights**2).sum()

# Sort data and weights
sorter = np.argsort(data)
data = data[sorter]
weights = weights[sorter]

# Normalize and sum the weights
weights = weights / weights.sum()
weights_cum = np.append(0, weights.cumsum())

# Vectorize the computation by transposing q with respect to weights
q = np.atleast_2d(q).T

# Get the interpolation parameter for each q
h = _get_h(nw, q, method)

# Find the samples contributing to the quantile computation (at *positions* between (h-1)/nw and h/nw)
u = np.maximum((h - 1) / nw, np.minimum(h / nw, weights_cum))

# Compute their relative weight
v = u * nw - h + 1
w = np.diff(v)

# Apply the weights
return (data * w).sum(axis=1)

if skipna is None and da.dtype.kind in "cfO":
skipna = True

q = np.atleast_1d(np.asarray(q, dtype=np.float64))

if q.ndim > 1:
raise ValueError("q must be a scalar or 1d")

if np.any((q < 0) | (q > 1)):
raise ValueError("q values must be between 0 and 1")

if dim is None:
dim = da.dims

if utils.is_scalar(dim):
dim = [dim]

# To satisfy mypy
dim = cast(Sequence, dim)

# need to align *and* broadcast
# - `_weighted_quantile_1d` requires arrays with the same shape
# - broadcast does an outer join, which can introduce NaN to weights
# - therefore we first need to do align(..., join="inner")

# TODO: use broadcast(..., join="inner") once available
# see https://github.com/pydata/xarray/issues/6304

da, weights = align(da, self.weights, join="inner")
da, weights = broadcast(da, weights)

result = apply_ufunc(
_weighted_quantile_1d,
da,
weights,
input_core_dims=[dim, dim],
output_core_dims=[["quantile"]],
output_dtypes=[np.float64],
dask_gufunc_kwargs=dict(output_sizes={"quantile": len(q)}),
dask="parallelized",
vectorize=True,
kwargs={"q": q, "skipna": skipna},
)

result = result.transpose("quantile", ...)
result = result.assign_coords(quantile=q).squeeze()

return result

def _implementation(self, func, dim, **kwargs):

raise NotImplementedError("Use `Dataset.weighted` or `DataArray.weighted`")
Expand Down Expand Up @@ -310,6 +512,19 @@ def std(
self._weighted_std, dim=dim, skipna=skipna, keep_attrs=keep_attrs
)

def quantile(
self,
q: ArrayLike,
*,
dim: Hashable | Sequence[Hashable] | None = None,
keep_attrs: bool = None,
skipna: bool = True,
) -> T_Xarray:

return self._implementation(
self._weighted_quantile, q=q, dim=dim, skipna=skipna, keep_attrs=keep_attrs
)

def __repr__(self):
"""provide a nice str repr of our Weighted object"""

Expand Down Expand Up @@ -360,6 +575,8 @@ def _inject_docstring(cls, cls_name):
cls=cls_name, fcn="std", on_zero="NaN"
)

cls.quantile.__doc__ = _WEIGHTED_QUANTILE_DOCSTRING_TEMPLATE.format(cls=cls_name)


_inject_docstring(DataArrayWeighted, "DataArray")
_inject_docstring(DatasetWeighted, "Dataset")
Loading

0 comments on commit 8a2fbb8

Please sign in to comment.