Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/weighted #2922

Merged
merged 60 commits into from
Mar 19, 2020
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
0f2da8e
weighted for DataArray
mathause Apr 26, 2019
5f64492
remove some commented code
mathause Apr 26, 2019
685e5c4
pep8 and faulty import tests
mathause Apr 26, 2019
c9d612d
add weighted sum, replace 0s in sum_of_wgt
mathause Apr 30, 2019
a20a4cf
weighted: overhaul tests
mathause Apr 30, 2019
26c24b6
weighted: pep8
mathause Apr 30, 2019
f3c6758
weighted: pep8 lines
mathause Apr 30, 2019
25c3c29
weighted update docs
mathause May 2, 2019
5d37d11
weighted: fix typo
mathause May 2, 2019
b1c572b
weighted: pep8
mathause May 8, 2019
d1d1f2c
undo changes to avoid merge conflict
mathause Oct 17, 2019
6be1414
Merge branch 'master' into feature/weighted
mathause Oct 17, 2019
059263c
add weighted to dataarray again
mathause Oct 17, 2019
8b1904b
remove super
mathause Oct 17, 2019
8cad145
overhaul core/weighted.py
mathause Oct 17, 2019
49d4e43
add DatasetWeighted class
mathause Oct 17, 2019
527256e
_maybe_get_all_dims return sorted tuple
mathause Oct 17, 2019
739568f
work on: test_weighted
mathause Oct 17, 2019
f01305d
black and flake8
mathause Oct 17, 2019
2e3880d
Apply suggestions from code review (docs)
mathause Oct 17, 2019
ae8d048
restructure interim
mathause Oct 18, 2019
dc7f605
restructure classes
mathause Oct 18, 2019
c646568
Merge branch 'master' into feature/weighted
mathause Dec 4, 2019
e2ad69e
update weighted.py
mathause Dec 4, 2019
bd4f048
black
mathause Dec 4, 2019
3c7695a
use map; add keep_attrs
mathause Dec 4, 2019
ef07edd
implement expected_weighted; update tests
mathause Dec 4, 2019
064b5a9
add whats new
mathause Dec 4, 2019
fec1a35
Merge branch 'master' into feature/weighted
mathause Dec 4, 2019
72c7942
undo changes to whats-new
mathause Dec 4, 2019
0e91411
F811: noqa where?
mathause Dec 4, 2019
1eb2913
api.rst
mathause Dec 5, 2019
118dfed
add to computation
mathause Dec 5, 2019
e08c921
small updates
mathause Dec 5, 2019
0fafe0b
add example to gallery
mathause Dec 5, 2019
a8d330d
typo
mathause Dec 5, 2019
ae0012f
another typo
mathause Dec 5, 2019
111259b
correct docstring in core/common.py
mathause Dec 5, 2019
5afc6f3
Merge branch 'master' into feature/weighted
mathause Jan 14, 2020
668b54b
typos
mathause Jan 14, 2020
d877022
adjust review
mathause Jan 14, 2020
ead681e
clean tests
mathause Jan 14, 2020
c4598ba
add test nonequal coords
mathause Jan 14, 2020
866fba5
comment on use of dot
mathause Jan 14, 2020
3cc00c1
fix erroneous merge
mathause Jan 14, 2020
8f34167
Merge branch 'master' into feature/weighted
mathause Jan 21, 2020
9f0a8cd
update tests
mathause Jan 21, 2020
98929f1
Merge branch 'master' into feature/weighted
mathause Mar 5, 2020
62c43e6
move example to notebook
mathause Mar 5, 2020
2e8aba2
move whats-new entry to 15.1
mathause Mar 5, 2020
d14f668
some doc updates
mathause Mar 5, 2020
7fa78ae
dot to own function
mathause Mar 5, 2020
3ebb9d4
simplify some tests
mathause Mar 5, 2020
f01d47a
Doc updates
dcherian Mar 17, 2020
4b184f6
very minor changes.
dcherian Mar 17, 2020
1e06adc
fix & add references
dcherian Mar 17, 2020
706579a
doc: return 0/NaN on 0 weights
mathause Mar 17, 2020
b2718db
Merge branch 'feature/weighted' of https://github.com/mathause/xarray…
mathause Mar 17, 2020
4c17108
Merge branch 'master' into feature/weighted
mathause Mar 17, 2020
8acc78e
Update xarray/core/common.py
dcherian Mar 18, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions xarray/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,25 @@ def groupby_bins(self, group, bins, right: bool = True, labels=None,
'precision': precision,
'include_lowest': include_lowest})

def weighted(self, weights):
"""
Weighted operations.

Parameters
mathause marked this conversation as resolved.
Show resolved Hide resolved
----------
weights : DataArray
An array of weights associated with the values in this Dataset.
Each value in the data contributes to the reduction operation
according to its associated weight.
mathause marked this conversation as resolved.
Show resolved Hide resolved

Note
----
Missing values in the weights are treated as 0 (i.e. no weight).

"""

return self._weighted_cls(self, weights)

def rolling(self, dim: Optional[Mapping[Hashable, int]] = None,
min_periods: Optional[int] = None, center: bool = False,
**dim_kwargs: int):
Expand Down
4 changes: 3 additions & 1 deletion xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

from ..plot.plot import _PlotMethods
from . import (
computation, dtypes, groupby, indexing, ops, resample, rolling, utils)
computation, dtypes, groupby, indexing, ops, resample, rolling, utils,
weighted)
from .accessors import DatetimeAccessor
from .alignment import align, reindex_like_indexers
from .common import AbstractArray, DataWithCoords
Expand Down Expand Up @@ -160,6 +161,7 @@ class DataArray(AbstractArray, DataWithCoords):
_rolling_cls = rolling.DataArrayRolling
_coarsen_cls = rolling.DataArrayCoarsen
_resample_cls = resample.DataArrayResample
_weighted_cls = weighted.DataArrayWeighted

dt = property(DatetimeAccessor)

Expand Down
120 changes: 120 additions & 0 deletions xarray/core/weighted.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@


_doc_ = """
Reduce this DataArray's data by a weighted `{fcn}` along some dimension(s).

Parameters
----------
dim : str or sequence of str, optional
Dimension(s) over which to apply the weighted `{fcn}`.
dcherian marked this conversation as resolved.
Show resolved Hide resolved
axis : int or sequence of int, optional
Axis(es) over which to apply the weighted `{fcn}`. Only one of the
'dim' and 'axis' arguments can be supplied. If neither are supplied,
then the weighted `{fcn}` is calculated over all axes.
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe remove as xr.dot does not provide this

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree re removing (or am I missing something—why was this here?)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For consistency it would be nice to offer axis here. Originally sum was implemented as (weights * da).sum(...) and we got axis for free. With dot it is not as straightforward any more. Honestly, I never use axis with xarray, so I suppose it is fine to only implement it if anyone would ever request it...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK. Unless I'm missing something, let's remove axis

skipna : bool, optional
mathause marked this conversation as resolved.
Show resolved Hide resolved
If True, skip missing values (as marked by NaN). By default, only
skips missing values for float dtypes; other dtypes either do not
have a sentinel missing value (int) or skipna=True has not been
implemented (object, datetime64 or timedelta64).
Note: Missing values in the weights are replaced with 0 (i.e. no
weight).
keep_attrs : bool, optional
If True, the attributes (`attrs`) will be copied from the original
mathause marked this conversation as resolved.
Show resolved Hide resolved
object to the new one. If False (default), the new object will be
returned without attributes.
**kwargs : dict
Additional keyword arguments passed on to the appropriate array
function for calculating `{fcn}` on this object's data.

Returns
-------
reduced : DataArray
New DataArray object with weighted `{fcn}` applied to its data and
the indicated dimension(s) removed.
"""


class DataArrayWeighted(object):
def __init__(self, obj, weights):
"""
Weighted operations for DataArray.
mathause marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
obj : DataArray
Object over which the weighted reduction operation is applied.
weights : DataArray
An array of weights associated with the values in this Dataset.
Each value in the DataArray contributes to the reduction operation
according to its associated weight.

Note
----
Missing values in the weights are replaced with 0 (i.e. no weight).

"""

super(DataArrayWeighted, self).__init__()
mathause marked this conversation as resolved.
Show resolved Hide resolved

from .dataarray import DataArray

msg = "'weights' must be a DataArray"
assert isinstance(weights, DataArray), msg

self.obj = obj
self.weights = weights.fillna(0)

def sum_of_weights(self, dim=None, axis=None):
"""
Calcualte the sum of weights, accounting for missing values

Parameters
----------
dim : str or sequence of str, optional
Dimension(s) over which to sum the weights.
axis : int or sequence of int, optional
Axis(es) over which to sum the weights. Only one of the 'dim' and
'axis' arguments can be supplied. If neither are supplied, then
the weights are summed over all axes.

"""

# we need to mask DATA values that are nan; else the weights are wrong
masked_weights = self.weights.where(self.obj.notnull())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If weights has an additional dimension to those of self.obj, masked_weight may consume large memory.
Can we avoid this for example by separating mask from weights?

mask = xr.where(self.obj.notnull(), 1, 0)  # binary mask
sum_of_weights = xr.dot(mask, weights)

Do you think if the above is worth doing?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done - I do not have any performance tests (memory or speed)


sum_of_weights = masked_weights.sum(dim=dim, axis=axis, skipna=True)

# find all weights that are valid (not 0)
valid_weights = sum_of_weights != 0.

# set invalid weights to nan
return sum_of_weights.where(valid_weights)

def sum(self, dim=None, axis=None, skipna=None, **kwargs):

# calculate weighted sum
return (self.obj * self.weights).sum(dim, axis=axis, skipna=skipna,
**kwargs)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the same reason above,

xr.where(self.obj, self.obj, 0).dot(self.weights)

may work if skipna is True?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done - I do not have any performance tests (memory or speed)


def mean(self, dim=None, axis=None, skipna=None, **kwargs):

# get the sum of weights
sum_of_weights = self.sum_of_weights(dim=dim, axis=axis)

# get weighted sum
weighted_sum = self.sum(dim=dim, axis=axis, skipna=skipna, **kwargs)

# calculate weighted mean
return weighted_sum / sum_of_weights

def __repr__(self):
"""provide a nice str repr of our weighted object"""

msg = "{klass} with weights along dimensions: {weight_dims}"
return msg.format(klass=self.__class__.__name__,
weight_dims=", ".join(self.weights.dims))


# add docstrings
DataArrayWeighted.mean.__doc__ = _doc_.format(fcn='mean')
DataArrayWeighted.sum.__doc__ = _doc_.format(fcn='sum')
157 changes: 157 additions & 0 deletions xarray/tests/test_weighted.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
import pytest

import numpy as np

import xarray as xr
from xarray import (
DataArray,)

from xarray.tests import assert_equal, raises_regex


def test_weigted_non_DataArray_weights():

da = DataArray([1, 2])
with raises_regex(AssertionError, "'weights' must be a DataArray"):
da.weighted([1, 2])


@pytest.mark.parametrize('weights', ([1, 2], [np.nan, 2], [np.nan, np.nan]))
def test_weighted_weights_nan_replaced(weights):
# make sure nans are removed from weights

da = DataArray([1, 2])

expected = DataArray(weights).fillna(0.)
result = da.weighted(DataArray(weights)).weights

assert_equal(expected, result)


@pytest.mark.parametrize(('weights', 'expected'), (([1, 2], 3),
([0, 2], 2),
([0, 0], np.nan),
([-1, 1], np.nan)))
def test_weigted_sum_of_weights_no_nan(weights, expected):

da = DataArray([1, 2])
weights = DataArray(weights)
result = da.weighted(weights).sum_of_weights()

expected = DataArray(expected)

assert_equal(expected, result)


@pytest.mark.parametrize(('weights', 'expected'), (([1, 2], 2),
([0, 2], 2),
([0, 0], np.nan),
([-1, 1], 1)))
def test_weigted_sum_of_weights_nan(weights, expected):
mathause marked this conversation as resolved.
Show resolved Hide resolved

da = DataArray([np.nan, 2])
weights = DataArray(weights)
result = da.weighted(weights).sum_of_weights()

expected = DataArray(expected)

assert_equal(expected, result)


@pytest.mark.parametrize('da', ([1, 2], [1, np.nan], [np.nan, np.nan]))
@pytest.mark.parametrize('factor', [0, 1, 2, 3.14])
@pytest.mark.parametrize('skipna', (True, False))
def test_weighted_sum_equal_weights(da, factor, skipna):
# if all weights are 'f'; weighted sum is f times the ordinary sum

da = DataArray(da)
weights = xr.zeros_like(da) + factor
dcherian marked this conversation as resolved.
Show resolved Hide resolved

expected = da.sum(skipna=skipna) * factor
result = da.weighted(weights).sum(skipna=skipna)

assert_equal(expected, result)


@pytest.mark.parametrize(('weights', 'expected'), (([1, 2], 5),
([0, 2], 4),
([0, 0], 0)))
def test_weighted_sum_no_nan(weights, expected):
da = DataArray([1, 2])

weights = DataArray(weights)
result = da.weighted(weights).sum()
expected = DataArray(expected)

assert_equal(expected, result)


@pytest.mark.parametrize(('weights', 'expected'), (([1, 2], 4),
([0, 2], 4),
([1, 0], 0),
([0, 0], 0)))
@pytest.mark.parametrize('skipna', (True, False))
def test_weighted_sum_nan(weights, expected, skipna):
dcherian marked this conversation as resolved.
Show resolved Hide resolved
da = DataArray([np.nan, 2])

weights = DataArray(weights)
result = da.weighted(weights).sum(skipna=skipna)

if skipna:
expected = DataArray(expected)
else:
expected = DataArray(np.nan)

assert_equal(expected, result)


@pytest.mark.filterwarnings("ignore:Mean of empty slice")
@pytest.mark.parametrize('da', ([1, 2], [1, np.nan], [np.nan, np.nan]))
@pytest.mark.parametrize('skipna', (True, False))
def test_weigted_mean_equal_weights(da, skipna):
# if all weights are equal, should yield the same result as mean

da = DataArray(da)

# all weights as 1.
weights = xr.zeros_like(da) + 1

expected = da.mean(skipna=skipna)
result = da.weighted(weights).mean(skipna=skipna)

assert_equal(expected, result)


@pytest.mark.parametrize(('weights', 'expected'), (([4, 6], 1.6),
([0, 1], 2.0),
([0, 2], 2.0),
([0, 0], np.nan)))
def test_weigted_mean_no_nan(weights, expected):
mathause marked this conversation as resolved.
Show resolved Hide resolved

da = DataArray([1, 2])
weights = DataArray(weights)
expected = DataArray(expected)

result = da.weighted(weights).mean()

assert_equal(expected, result)


@pytest.mark.parametrize(('weights', 'expected'), (([4, 6], 2.0),
([0, 1], 2.0),
([0, 2], 2.0),
([0, 0], np.nan)))
@pytest.mark.parametrize('skipna', (True, False))
def test_weigted_mean_nan(weights, expected, skipna):

da = DataArray([np.nan, 2])
weights = DataArray(weights)

if skipna:
expected = DataArray(expected)
else:
expected = DataArray(np.nan)

result = da.weighted(weights).mean(skipna=skipna)

assert_equal(expected, result)