From 0f2da8e2fd768a3061afaae05c96effa72353ef5 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Fri, 26 Apr 2019 18:39:40 +0200 Subject: [PATCH 01/52] weighted for DataArray --- xarray/core/common.py | 18 ++++++ xarray/core/dataarray.py | 4 +- xarray/core/weighted.py | 110 ++++++++++++++++++++++++++++++++++ xarray/tests/test_weighted.py | 82 +++++++++++++++++++++++++ 4 files changed, 213 insertions(+), 1 deletion(-) create mode 100644 xarray/core/weighted.py create mode 100644 xarray/tests/test_weighted.py diff --git a/xarray/core/common.py b/xarray/core/common.py index b518e8431fd..ea9f782c2cb 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -539,6 +539,24 @@ def groupby_bins(self, group, bins, right: bool = True, labels=None, cut_kwargs={'right': right, 'labels': labels, 'precision': precision, 'include_lowest': include_lowest}) + def weighted(self, weights): + """ + Weighted operations. + + Parameters + ---------- + weights : DataArray + An array of weights associated with the values in this Dataset. + Each value in a contributes to the average according to its + associated weight. + + Note + ---- + Missing values in the weights are treated as 0 (i.e. no weight). + + """ + + return self._weighted_cls(self, weights) def rolling(self, dim: Optional[Mapping[Hashable, int]] = None, min_periods: Optional[int] = None, center: bool = False, diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 39e9fc048e3..3c1c5f64419 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -8,7 +8,8 @@ from ..plot.plot import _PlotMethods from . import ( - computation, dtypes, groupby, indexing, ops, resample, rolling, utils) + computation, dtypes, groupby, indexing, ops, resample, rolling, utils, + weighted) from .accessors import DatetimeAccessor from .alignment import align, reindex_like_indexers from .common import AbstractArray, DataWithCoords @@ -160,6 +161,7 @@ class DataArray(AbstractArray, DataWithCoords): _rolling_cls = rolling.DataArrayRolling _coarsen_cls = rolling.DataArrayCoarsen _resample_cls = resample.DataArrayResample + _weighted_cls = weighted.DataArrayWeighted dt = property(DatetimeAccessor) diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py new file mode 100644 index 00000000000..231efe23b68 --- /dev/null +++ b/xarray/core/weighted.py @@ -0,0 +1,110 @@ + +class DataArrayWeighted(object): + def __init__(self, obj, weights): + """ + Weighted operations for DataArray. + + Parameters + ---------- + obj : DataArray + Object to window. + weights : DataArray + An array of weights associated with the values in this Dataset. + Each value in a contributes to the average according to its + associated weight. + + Note + ---- + Missing values in the weights are treated as 0 (i.e. no weight). + + """ + + super(DataArrayWeighted, self).__init__() + + from .dataarray import DataArray + + msg = "'weights' must be a DataArray" + assert isinstance(weights, DataArray) + + self.obj = obj + self.weights = weights + + def sum_of_weights(self, dim=None, axis=None): + """ + Calcualte the sum of weights accounting for missing values + + Parameters + ---------- + dim : str or sequence of str, optional + Dimension(s) over which to sum the weights. + axis : int or sequence of int, optional + Axis(es) over which to sum the weights. Only one of the 'dim' and + 'axis' arguments can be supplied. If neither are supplied, then + the weights are summed over all axes. + + """ + + # we need to mask values that are nan; else the weights are wrong + notnull = self.obj.notnull() + + return self.weights.where(notnull).sum(dim=dim, axis=axis, skipna=True) + + + def mean(self, dim=None, axis=None, skipna=None, **kwargs): + """ + Reduce this DataArray's data by a weighted `mean` along some dimension(s). + + Parameters + ---------- + dim : str or sequence of str, optional + Dimension(s) over which to apply the weighted `mean`. + axis : int or sequence of int, optional + Axis(es) over which to apply the weighted `mean`. Only one of the + 'dim'and 'axis' arguments can be supplied. If neither are supplied, + then the weighted `mean` is calculated over all axes. + skipna : bool, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64). + Note: Missing values in the weights are always skipped. + keep_attrs : bool, optional + If True, the attributes (`attrs`) will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + **kwargs : dict + Additional keyword arguments passed on to the appropriate array + function for calculating `mean` on this object's data. + + Returns + ------- + reduced : DataArray + New DataArray object with weighted `mean` applied to its data and + the indicated dimension(s) removed. + """ + + # get the sum of weights of the dims + sum_of_weights = self.sum_of_weights(dim=dim, axis=axis) + + # normalize weights to 1 + w = self.weights / sum_of_weights + + obj = self.obj + + # check if invalid values are masked by weights that are 0 + # e.g. values = [1 NaN]; weights = [1, 0], should return 1 + # if not skipna: + # # w = w.fillna(0) + # sel = ((w.isnull()) & (obj.isnull())) + # if sel.any(): + # obj = obj.where(sel, 0) + + + w = w.fillna(0) + + # calculate weighted mean + weighted = (obj * w).sum(dim, axis=axis, skipna=skipna, **kwargs) + + # set to NaN if sum_of_weights is zero + invalid_weights = sum_of_weights == 0 + return weighted.where(~ invalid_weights) diff --git a/xarray/tests/test_weighted.py b/xarray/tests/test_weighted.py new file mode 100644 index 00000000000..975e7e2b541 --- /dev/null +++ b/xarray/tests/test_weighted.py @@ -0,0 +1,82 @@ +import pytest + +import numpy as np + +import xarray as xr +from xarray import ( + DataArray,) + +from xarray.tests import assert_allclose + +def test_weigted_non_DataArray_weights(): + + da = DataArray([1, 2]) + with pytest.raises(AssertionError): + da.weighted([1, 2]) + +@pytest.mark.parametrize('da', ([1, 2], [1, np.nan])) +@pytest.mark.parametrize('weights', ([1, 2], [np.nan, 2])) +def test_weigted_sum_of_weights(da, weights): + + da = DataArray(da) + weights = DataArray(weights) + + expected = weights.where(~ da.isnull()).sum() + result = da.weighted(weights).sum_of_weights() + + assert_equal(expected, result) + + +@pytest.mark.parametrize('da', ([1, 2], [1, np.nan])) +@pytest.mark.parametrize('skipna', (True, False)) +def test_weigted_mean_equal_weights(da, skipna): + # if all weights are equal, should yield the same result as mean + + da = DataArray(da) + + weights = xr.zeros_like(da) + 1 + + expected = da.mean(skipna=skipna) + result = da.weighted(weights).mean(skipna=skipna) + + assert_equal(expected, result) + + + +def expected_weighted(da, weights, skipna): + np.warnings.filterwarnings('ignore') + + # all NaN's in weights are replaced + weights = np.nan_to_num(weights) + + # + if np.all(np.isnan(da)): + expected = np.nan + elif skipna: + da = np.ma.masked_invalid(da) + expected = np.ma.average(da, weights=weights) + else: + expected = np.ma.average(da, weights=weights) + + expected = np.asarray(expected) + + expected[np.isinf(expected)] = np.nan + + return DataArray(expected) + + +@pytest.mark.parametrize('da', ([1, 2], [1, np.nan], [np.nan, np.nan])) +@pytest.mark.parametrize('weights', ([4, 6], [-1, 1], [1, 0], [0, 1], [1, np.nan], [np.nan, np.nan])) +@pytest.mark.parametrize('skipna', (True, False)) +def test_weigted_mean(da, weights, skipna): + + expected = expected_weighted(da, weights, skipna) + + da = DataArray(da) + weights = DataArray(weights) + + result = da.weighted(weights).mean(skipna=skipna) + + assert_equal(expected, result) + + From 5f64492b47474ee4f555e2b1ecc48d01c7444e43 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Fri, 26 Apr 2019 19:08:02 +0200 Subject: [PATCH 02/52] remove some commented code --- xarray/core/weighted.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index 231efe23b68..2e782d9fc7c 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -91,15 +91,6 @@ def mean(self, dim=None, axis=None, skipna=None, **kwargs): obj = self.obj - # check if invalid values are masked by weights that are 0 - # e.g. values = [1 NaN]; weights = [1, 0], should return 1 - # if not skipna: - # # w = w.fillna(0) - # sel = ((w.isnull()) & (obj.isnull())) - # if sel.any(): - # obj = obj.where(sel, 0) - - w = w.fillna(0) # calculate weighted mean From 685e5c45a0a38336d51f9073cfda7e170aec6454 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Fri, 26 Apr 2019 19:15:03 +0200 Subject: [PATCH 03/52] pep8 and faulty import tests --- xarray/core/common.py | 1 + xarray/core/weighted.py | 14 +++++++------- xarray/tests/test_weighted.py | 18 ++++++++---------- 3 files changed, 16 insertions(+), 17 deletions(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index ea9f782c2cb..400ef6965a7 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -539,6 +539,7 @@ def groupby_bins(self, group, bins, right: bool = True, labels=None, cut_kwargs={'right': right, 'labels': labels, 'precision': precision, 'include_lowest': include_lowest}) + def weighted(self, weights): """ Weighted operations. diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index 2e782d9fc7c..e9ac1d75827 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -1,5 +1,5 @@ -class DataArrayWeighted(object): +class DataArrayWeighted(object): def __init__(self, obj, weights): """ Weighted operations for DataArray. @@ -18,9 +18,9 @@ def __init__(self, obj, weights): Missing values in the weights are treated as 0 (i.e. no weight). """ - + super(DataArrayWeighted, self).__init__() - + from .dataarray import DataArray msg = "'weights' must be a DataArray" @@ -46,9 +46,8 @@ def sum_of_weights(self, dim=None, axis=None): # we need to mask values that are nan; else the weights are wrong notnull = self.obj.notnull() - + return self.weights.where(notnull).sum(dim=dim, axis=axis, skipna=True) - def mean(self, dim=None, axis=None, skipna=None, **kwargs): """ @@ -88,13 +87,14 @@ def mean(self, dim=None, axis=None, skipna=None, **kwargs): # normalize weights to 1 w = self.weights / sum_of_weights - + obj = self.obj w = w.fillna(0) # calculate weighted mean - weighted = (obj * w).sum(dim, axis=axis, skipna=skipna, **kwargs) + weighted = (obj * w).sum(dim, axis=axis, skipna=skipna, + **kwargs) # set to NaN if sum_of_weights is zero invalid_weights = sum_of_weights == 0 diff --git a/xarray/tests/test_weighted.py b/xarray/tests/test_weighted.py index 975e7e2b541..3a94765888f 100644 --- a/xarray/tests/test_weighted.py +++ b/xarray/tests/test_weighted.py @@ -6,7 +6,8 @@ from xarray import ( DataArray,) -from xarray.tests import assert_allclose +from xarray.tests import assert_equal + def test_weigted_non_DataArray_weights(): @@ -14,6 +15,7 @@ def test_weigted_non_DataArray_weights(): with pytest.raises(AssertionError): da.weighted([1, 2]) + @pytest.mark.parametrize('da', ([1, 2], [1, np.nan])) @pytest.mark.parametrize('weights', ([1, 2], [np.nan, 2])) def test_weigted_sum_of_weights(da, weights): @@ -41,15 +43,12 @@ def test_weigted_mean_equal_weights(da, skipna): assert_equal(expected, result) - - def expected_weighted(da, weights, skipna): np.warnings.filterwarnings('ignore') - + # all NaN's in weights are replaced weights = np.nan_to_num(weights) - - # + if np.all(np.isnan(da)): expected = np.nan elif skipna: @@ -57,7 +56,7 @@ def expected_weighted(da, weights, skipna): expected = np.ma.average(da, weights=weights) else: expected = np.ma.average(da, weights=weights) - + expected = np.asarray(expected) expected[np.isinf(expected)] = np.nan @@ -66,7 +65,8 @@ def expected_weighted(da, weights, skipna): @pytest.mark.parametrize('da', ([1, 2], [1, np.nan], [np.nan, np.nan])) -@pytest.mark.parametrize('weights', ([4, 6], [-1, 1], [1, 0], [0, 1], [1, np.nan], [np.nan, np.nan])) +@pytest.mark.parametrize('weights', ([4, 6], [-1, 1], [1, 0], [0, 1], + [1, np.nan], [np.nan, np.nan])) @pytest.mark.parametrize('skipna', (True, False)) def test_weigted_mean(da, weights, skipna): @@ -78,5 +78,3 @@ def test_weigted_mean(da, weights, skipna): result = da.weighted(weights).mean(skipna=skipna) assert_equal(expected, result) - - From c9d612dd38a805f6a3b8a54e2a3ba8d87ee862be Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Tue, 30 Apr 2019 18:28:18 +0200 Subject: [PATCH 04/52] add weighted sum, replace 0s in sum_of_wgt --- xarray/core/weighted.py | 115 +++++++++++++++++++++++----------------- 1 file changed, 66 insertions(+), 49 deletions(-) diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index e9ac1d75827..26fcfe84a0c 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -1,4 +1,37 @@ + +_doc_ = """ + Reduce this DataArray's data by a weighted `{fcn}` along some dimension(s). + + Parameters + ---------- + dim : str or sequence of str, optional + Dimension(s) over which to apply the weighted `{fcn}`. + axis : int or sequence of int, optional + Axis(es) over which to apply the weighted `{fcn}`. Only one of the + 'dim' and 'axis' arguments can be supplied. If neither are supplied, + then the weighted `{fcn}` is calculated over all axes. + skipna : bool, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64). + Note: Missing values in the weights are always skipped. + keep_attrs : bool, optional + If True, the attributes (`attrs`) will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + **kwargs : dict + Additional keyword arguments passed on to the appropriate array + function for calculating `{fcn}` on this object's data. + + Returns + ------- + reduced : DataArray + New DataArray object with weighted `{fcn}` applied to its data and + the indicated dimension(s) removed. + """ + class DataArrayWeighted(object): def __init__(self, obj, weights): """ @@ -15,7 +48,7 @@ def __init__(self, obj, weights): Note ---- - Missing values in the weights are treated as 0 (i.e. no weight). + Missing values in the weights are replaced with 0 (i.e. no weight). """ @@ -24,14 +57,14 @@ def __init__(self, obj, weights): from .dataarray import DataArray msg = "'weights' must be a DataArray" - assert isinstance(weights, DataArray) + assert isinstance(weights, DataArray), msg self.obj = obj - self.weights = weights + self.weights = weights.fillna(0) def sum_of_weights(self, dim=None, axis=None): """ - Calcualte the sum of weights accounting for missing values + Calcualte the sum of weights, accounting for missing values Parameters ---------- @@ -44,58 +77,42 @@ def sum_of_weights(self, dim=None, axis=None): """ - # we need to mask values that are nan; else the weights are wrong - notnull = self.obj.notnull() + # we need to mask DATA values that are nan; else the weights are wrong + masked_weights = self.weights.where(self.obj.notnull()) - return self.weights.where(notnull).sum(dim=dim, axis=axis, skipna=True) + sum_of_weights = masked_weights.sum(dim=dim, axis=axis, skipna=True) + + # find all weights that are valid (not 0) + valid_weights = sum_of_weights != 0. - def mean(self, dim=None, axis=None, skipna=None, **kwargs): - """ - Reduce this DataArray's data by a weighted `mean` along some dimension(s). + # set invalid weights to nan + return sum_of_weights.where(valid_weights) - Parameters - ---------- - dim : str or sequence of str, optional - Dimension(s) over which to apply the weighted `mean`. - axis : int or sequence of int, optional - Axis(es) over which to apply the weighted `mean`. Only one of the - 'dim'and 'axis' arguments can be supplied. If neither are supplied, - then the weighted `mean` is calculated over all axes. - skipna : bool, optional - If True, skip missing values (as marked by NaN). By default, only - skips missing values for float dtypes; other dtypes either do not - have a sentinel missing value (int) or skipna=True has not been - implemented (object, datetime64 or timedelta64). - Note: Missing values in the weights are always skipped. - keep_attrs : bool, optional - If True, the attributes (`attrs`) will be copied from the original - object to the new one. If False (default), the new object will be - returned without attributes. - **kwargs : dict - Additional keyword arguments passed on to the appropriate array - function for calculating `mean` on this object's data. - - Returns - ------- - reduced : DataArray - New DataArray object with weighted `mean` applied to its data and - the indicated dimension(s) removed. - """ + def sum(self, dim=None, axis=None, skipna=None, **kwargs): + + # calculate weighted sum + return (self.obj * self.weights).sum(dim, axis=axis, skipna=skipna, + **kwargs) + + def mean(self, dim=None, axis=None, skipna=None, **kwargs): - # get the sum of weights of the dims + # get the sum of weights sum_of_weights = self.sum_of_weights(dim=dim, axis=axis) - # normalize weights to 1 - w = self.weights / sum_of_weights + # get weighted sum + weighted_sum = self.sum(dim=dim, axis=axis, skipna=skipna, **kwargs) - obj = self.obj + # calculate weighted mean + return weighted_sum / sum_of_weights - w = w.fillna(0) - # calculate weighted mean - weighted = (obj * w).sum(dim, axis=axis, skipna=skipna, - **kwargs) + def __repr__(self): + """provide a nice str repr of our weighted object""" + + msg = "{klass} with weights along dimensions: {weight_dims}" + return msg.format(klass=self.__class__.__name__, + weight_dims=", ".join(self.weights.dims)) - # set to NaN if sum_of_weights is zero - invalid_weights = sum_of_weights == 0 - return weighted.where(~ invalid_weights) +# add docstrings +DataArrayWeighted.mean.__doc__ = _doc_.format(fcn='mean') +DataArrayWeighted.sum.__doc__ = _doc_.format(fcn='sum') From a20a4cf6501be54e70b0402553dc666f93f386f8 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Tue, 30 Apr 2019 18:28:44 +0200 Subject: [PATCH 05/52] weighted: overhaul tests --- xarray/tests/test_weighted.py | 134 ++++++++++++++++++++++++++-------- 1 file changed, 102 insertions(+), 32 deletions(-) diff --git a/xarray/tests/test_weighted.py b/xarray/tests/test_weighted.py index 3a94765888f..67abe55b1ba 100644 --- a/xarray/tests/test_weighted.py +++ b/xarray/tests/test_weighted.py @@ -6,36 +6,108 @@ from xarray import ( DataArray,) -from xarray.tests import assert_equal - +from xarray.tests import assert_equal, raises_regex def test_weigted_non_DataArray_weights(): da = DataArray([1, 2]) - with pytest.raises(AssertionError): + with raises_regex(AssertionError, "'weights' must be a DataArray"): da.weighted([1, 2]) -@pytest.mark.parametrize('da', ([1, 2], [1, np.nan])) -@pytest.mark.parametrize('weights', ([1, 2], [np.nan, 2])) -def test_weigted_sum_of_weights(da, weights): +@pytest.mark.parametrize('weights', ([1, 2], [np.nan, 2], [np.nan, np.nan])) +def test_weighted_weights_nan_replaced(weights): + # make sure nans are removed from weights - da = DataArray(da) + da = DataArray([1, 2]) + + expected = DataArray(weights).fillna(0.) + result = da.weighted(DataArray(weights)).weights + + assert_equal(expected, result) + + +@pytest.mark.parametrize(('weights', 'expected'), (([1, 2], 3), + ([0, 2], 2), + ([0, 0], np.nan), + ([-1, 1], np.nan))) +def test_weigted_sum_of_weights_no_nan(weights, expected): + + da = DataArray([1, 2]) weights = DataArray(weights) + result = da.weighted(weights).sum_of_weights() + + expected = DataArray(expected) - expected = weights.where(~ da.isnull()).sum() + assert_equal(expected, result) + +@pytest.mark.parametrize(('weights', 'expected'), (([1, 2], 2), + ([0, 2], 2), + ([0, 0], np.nan), + ([-1, 1], 1))) +def test_weigted_sum_of_weights_nan(weights, expected): + + da = DataArray([np.nan, 2]) + weights = DataArray(weights) result = da.weighted(weights).sum_of_weights() + expected = DataArray(expected) + + assert_equal(expected, result) + +@pytest.mark.parametrize('da', ([1, 2], [1, np.nan], [np.nan, np.nan])) +@pytest.mark.parametrize('factor', [0, 1, 2, 3.14]) +@pytest.mark.parametrize('skipna', (True, False)) +def test_weighted_sum_equal_weights(da, factor, skipna): + # if all weights are 'f'; weighted sum is f times the ordinary sum + + da = DataArray(da) + weights = xr.zeros_like(da) + factor + + expected = da.sum(skipna=skipna) * factor + result = da.weighted(weights).sum(skipna=skipna) + assert_equal(expected, result) +@pytest.mark.parametrize(('weights', 'expected'), (([1, 2], 5), + ([0, 2], 4), + ([0, 0], 0))) +def test_weighted_sum_no_nan(weights, expected): + da = DataArray([1, 2]) + + weights = DataArray(weights) + result = da.weighted(weights).sum() + expected = DataArray(expected) -@pytest.mark.parametrize('da', ([1, 2], [1, np.nan])) + assert_equal(expected, result) + +@pytest.mark.parametrize(('weights', 'expected'), (([1, 2], 4), + ([0, 2], 4), + ([1, 0], 0), + ([0, 0], 0))) +@pytest.mark.parametrize('skipna', (True, False)) +def test_weighted_sum_nan(weights, expected, skipna): + da = DataArray([np.nan, 2]) + + weights = DataArray(weights) + result = da.weighted(weights).sum(skipna=skipna) + + if skipna: + expected = DataArray(expected) + else: + expected = DataArray(np.nan) + + assert_equal(expected, result) + +@pytest.mark.filterwarnings("ignore:Mean of empty slice") +@pytest.mark.parametrize('da', ([1, 2], [1, np.nan], [np.nan, np.nan])) @pytest.mark.parametrize('skipna', (True, False)) def test_weigted_mean_equal_weights(da, skipna): # if all weights are equal, should yield the same result as mean da = DataArray(da) + # all weights as 1. weights = xr.zeros_like(da) + 1 expected = da.mean(skipna=skipna) @@ -43,38 +115,36 @@ def test_weigted_mean_equal_weights(da, skipna): assert_equal(expected, result) -def expected_weighted(da, weights, skipna): - np.warnings.filterwarnings('ignore') - - # all NaN's in weights are replaced - weights = np.nan_to_num(weights) - - if np.all(np.isnan(da)): - expected = np.nan - elif skipna: - da = np.ma.masked_invalid(da) - expected = np.ma.average(da, weights=weights) - else: - expected = np.ma.average(da, weights=weights) +@pytest.mark.parametrize(('weights', 'expected'), (([4, 6], 1.6), + ([0, 1], 2.0), + ([0, 2], 2.0), + ([0, 0], np.nan))) +def test_weigted_mean_no_nan(weights, expected): - expected = np.asarray(expected) + da = DataArray([1, 2]) + weights = DataArray(weights) + expected = DataArray(expected) - expected[np.isinf(expected)] = np.nan + result = da.weighted(weights).mean() - return DataArray(expected) + assert_equal(expected, result) +@pytest.mark.parametrize(('weights', 'expected'), (([4, 6], 2.0), + ([0, 1], 2.0), + ([0, 2], 2.0), + ([0, 0], np.nan))) -@pytest.mark.parametrize('da', ([1, 2], [1, np.nan], [np.nan, np.nan])) -@pytest.mark.parametrize('weights', ([4, 6], [-1, 1], [1, 0], [0, 1], - [1, np.nan], [np.nan, np.nan])) @pytest.mark.parametrize('skipna', (True, False)) -def test_weigted_mean(da, weights, skipna): - - expected = expected_weighted(da, weights, skipna) +def test_weigted_mean_nan(weights, expected, skipna): - da = DataArray(da) + da = DataArray([np.nan, 2]) weights = DataArray(weights) + if skipna: + expected = DataArray(expected) + else: + expected = DataArray(np.nan) + result = da.weighted(weights).mean(skipna=skipna) assert_equal(expected, result) From 26c24b60b2a687d80e36b1e582811509428e3483 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Tue, 30 Apr 2019 18:34:37 +0200 Subject: [PATCH 06/52] weighted: pep8 --- xarray/core/common.py | 2 +- xarray/core/weighted.py | 9 +++++---- xarray/tests/test_weighted.py | 13 ++++++++++--- 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index 400ef6965a7..cb5f04ce59a 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -539,7 +539,7 @@ def groupby_bins(self, group, bins, right: bool = True, labels=None, cut_kwargs={'right': right, 'labels': labels, 'precision': precision, 'include_lowest': include_lowest}) - + def weighted(self, weights): """ Weighted operations. diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index 26fcfe84a0c..49a62b8ef74 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -32,6 +32,7 @@ the indicated dimension(s) removed. """ + class DataArrayWeighted(object): def __init__(self, obj, weights): """ @@ -81,7 +82,7 @@ def sum_of_weights(self, dim=None, axis=None): masked_weights = self.weights.where(self.obj.notnull()) sum_of_weights = masked_weights.sum(dim=dim, axis=axis, skipna=True) - + # find all weights that are valid (not 0) valid_weights = sum_of_weights != 0. @@ -89,7 +90,7 @@ def sum_of_weights(self, dim=None, axis=None): return sum_of_weights.where(valid_weights) def sum(self, dim=None, axis=None, skipna=None, **kwargs): - + # calculate weighted sum return (self.obj * self.weights).sum(dim, axis=axis, skipna=skipna, **kwargs) @@ -105,14 +106,14 @@ def mean(self, dim=None, axis=None, skipna=None, **kwargs): # calculate weighted mean return weighted_sum / sum_of_weights - def __repr__(self): """provide a nice str repr of our weighted object""" msg = "{klass} with weights along dimensions: {weight_dims}" - return msg.format(klass=self.__class__.__name__, + return msg.format(klass=self.__class__.__name__, weight_dims=", ".join(self.weights.dims)) # add docstrings DataArrayWeighted.mean.__doc__ = _doc_.format(fcn='mean') DataArrayWeighted.sum.__doc__ = _doc_.format(fcn='sum') + diff --git a/xarray/tests/test_weighted.py b/xarray/tests/test_weighted.py index 67abe55b1ba..ca403914599 100644 --- a/xarray/tests/test_weighted.py +++ b/xarray/tests/test_weighted.py @@ -8,6 +8,7 @@ from xarray.tests import assert_equal, raises_regex + def test_weigted_non_DataArray_weights(): da = DataArray([1, 2]) @@ -41,6 +42,7 @@ def test_weigted_sum_of_weights_no_nan(weights, expected): assert_equal(expected, result) + @pytest.mark.parametrize(('weights', 'expected'), (([1, 2], 2), ([0, 2], 2), ([0, 0], np.nan), @@ -55,6 +57,7 @@ def test_weigted_sum_of_weights_nan(weights, expected): assert_equal(expected, result) + @pytest.mark.parametrize('da', ([1, 2], [1, np.nan], [np.nan, np.nan])) @pytest.mark.parametrize('factor', [0, 1, 2, 3.14]) @pytest.mark.parametrize('skipna', (True, False)) @@ -69,8 +72,9 @@ def test_weighted_sum_equal_weights(da, factor, skipna): assert_equal(expected, result) -@pytest.mark.parametrize(('weights', 'expected'), (([1, 2], 5), - ([0, 2], 4), + +@pytest.mark.parametrize(('weights', 'expected'), (([1, 2], 5), + ([0, 2], 4), ([0, 0], 0))) def test_weighted_sum_no_nan(weights, expected): da = DataArray([1, 2]) @@ -81,6 +85,7 @@ def test_weighted_sum_no_nan(weights, expected): assert_equal(expected, result) + @pytest.mark.parametrize(('weights', 'expected'), (([1, 2], 4), ([0, 2], 4), ([1, 0], 0), @@ -99,6 +104,7 @@ def test_weighted_sum_nan(weights, expected, skipna): assert_equal(expected, result) + @pytest.mark.filterwarnings("ignore:Mean of empty slice") @pytest.mark.parametrize('da', ([1, 2], [1, np.nan], [np.nan, np.nan])) @pytest.mark.parametrize('skipna', (True, False)) @@ -115,6 +121,7 @@ def test_weigted_mean_equal_weights(da, skipna): assert_equal(expected, result) + @pytest.mark.parametrize(('weights', 'expected'), (([4, 6], 1.6), ([0, 1], 2.0), ([0, 2], 2.0), @@ -129,11 +136,11 @@ def test_weigted_mean_no_nan(weights, expected): assert_equal(expected, result) + @pytest.mark.parametrize(('weights', 'expected'), (([4, 6], 2.0), ([0, 1], 2.0), ([0, 2], 2.0), ([0, 0], np.nan))) - @pytest.mark.parametrize('skipna', (True, False)) def test_weigted_mean_nan(weights, expected, skipna): From f3c6758de7a5eabee2a3cbf7205fb40275d4ff05 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Tue, 30 Apr 2019 18:36:01 +0200 Subject: [PATCH 07/52] weighted: pep8 lines --- xarray/core/weighted.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index 49a62b8ef74..9cf644aef68 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -113,7 +113,7 @@ def __repr__(self): return msg.format(klass=self.__class__.__name__, weight_dims=", ".join(self.weights.dims)) + # add docstrings DataArrayWeighted.mean.__doc__ = _doc_.format(fcn='mean') DataArrayWeighted.sum.__doc__ = _doc_.format(fcn='sum') - From 25c3c29c1ce22daf57cfce428eac63b1ec1d115a Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Thu, 2 May 2019 20:01:45 +0200 Subject: [PATCH 08/52] weighted update docs --- xarray/core/common.py | 4 ++-- xarray/core/weighted.py | 9 +++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index cb5f04ce59a..5db09e4b595 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -548,8 +548,8 @@ def weighted(self, weights): ---------- weights : DataArray An array of weights associated with the values in this Dataset. - Each value in a contributes to the average according to its - associated weight. + Each value in the data contributes to the reduction operation + according to its associated weight. Note ---- diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index 9cf644aef68..df00f70241f 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -16,7 +16,8 @@ skips missing values for float dtypes; other dtypes either do not have a sentinel missing value (int) or skipna=True has not been implemented (object, datetime64 or timedelta64). - Note: Missing values in the weights are always skipped. + Note: Missing values in the weights are replaced with 0 (i.e. no + weight). keep_attrs : bool, optional If True, the attributes (`attrs`) will be copied from the original object to the new one. If False (default), the new object will be @@ -41,11 +42,11 @@ def __init__(self, obj, weights): Parameters ---------- obj : DataArray - Object to window. + Object over which the weighted reduction operation is applied. weights : DataArray An array of weights associated with the values in this Dataset. - Each value in a contributes to the average according to its - associated weight. + Each value in the DataArray contributes to the reduction operation + according to its associated weight. Note ---- From 5d37d11d0e391c83bb28618e4e5f82fddb43cdd6 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Thu, 2 May 2019 20:58:26 +0200 Subject: [PATCH 09/52] weighted: fix typo --- xarray/tests/test_weighted.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_weighted.py b/xarray/tests/test_weighted.py index ca403914599..65357ee2440 100644 --- a/xarray/tests/test_weighted.py +++ b/xarray/tests/test_weighted.py @@ -126,7 +126,7 @@ def test_weigted_mean_equal_weights(da, skipna): ([0, 1], 2.0), ([0, 2], 2.0), ([0, 0], np.nan))) -def test_weigted_mean_no_nan(weights, expected): +def test_weighted_mean_no_nan(weights, expected): da = DataArray([1, 2]) weights = DataArray(weights) From b1c572b0aeff1e24735cbd52a4f7168df7bcad9c Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Wed, 8 May 2019 10:30:08 +0200 Subject: [PATCH 10/52] weighted: pep8 --- xarray/core/weighted.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index df00f70241f..d64506dd479 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -16,7 +16,7 @@ skips missing values for float dtypes; other dtypes either do not have a sentinel missing value (int) or skipna=True has not been implemented (object, datetime64 or timedelta64). - Note: Missing values in the weights are replaced with 0 (i.e. no + Note: Missing values in the weights are replaced with 0 (i.e. no weight). keep_attrs : bool, optional If True, the attributes (`attrs`) will be copied from the original From d1d1f2c6dc4f97d6d37194c702b8611bf54a64c9 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Thu, 17 Oct 2019 16:36:18 +0200 Subject: [PATCH 11/52] undo changes to avoid merge conflict --- xarray/core/common.py | 19 ------------------- xarray/core/dataarray.py | 4 +--- 2 files changed, 1 insertion(+), 22 deletions(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index 5db09e4b595..b518e8431fd 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -540,25 +540,6 @@ def groupby_bins(self, group, bins, right: bool = True, labels=None, 'precision': precision, 'include_lowest': include_lowest}) - def weighted(self, weights): - """ - Weighted operations. - - Parameters - ---------- - weights : DataArray - An array of weights associated with the values in this Dataset. - Each value in the data contributes to the reduction operation - according to its associated weight. - - Note - ---- - Missing values in the weights are treated as 0 (i.e. no weight). - - """ - - return self._weighted_cls(self, weights) - def rolling(self, dim: Optional[Mapping[Hashable, int]] = None, min_periods: Optional[int] = None, center: bool = False, **dim_kwargs: int): diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 3c1c5f64419..39e9fc048e3 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -8,8 +8,7 @@ from ..plot.plot import _PlotMethods from . import ( - computation, dtypes, groupby, indexing, ops, resample, rolling, utils, - weighted) + computation, dtypes, groupby, indexing, ops, resample, rolling, utils) from .accessors import DatetimeAccessor from .alignment import align, reindex_like_indexers from .common import AbstractArray, DataWithCoords @@ -161,7 +160,6 @@ class DataArray(AbstractArray, DataWithCoords): _rolling_cls = rolling.DataArrayRolling _coarsen_cls = rolling.DataArrayCoarsen _resample_cls = resample.DataArrayResample - _weighted_cls = weighted.DataArrayWeighted dt = property(DatetimeAccessor) From 059263c8a36e4bde5ed98bf8b0da1bd6488bf5ca Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Thu, 17 Oct 2019 17:06:25 +0200 Subject: [PATCH 12/52] add weighted to dataarray again --- xarray/core/common.py | 17 +++++++++++++++++ xarray/core/dataarray.py | 4 +++- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index a762f7fbed9..3e423988a03 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -730,6 +730,23 @@ def groupby_bins( }, ) + def weighted(self, weights): + """ + Weighted operations. + Parameters + ---------- + weights : DataArray + An array of weights associated with the values in this Dataset. + Each value in the data contributes to the reduction operation + according to its associated weight. + Note + ---- + Missing values in the weights are treated as 0 (i.e. no weight). + """ + + return self._weighted_cls(self, weights) + + def rolling( self, dim: Mapping[Hashable, int] = None, diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 4a48f13b86d..deb0a9c0e1d 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -33,6 +33,7 @@ resample, rolling, utils, + weighted, ) from .accessor_dt import DatetimeAccessor from .accessor_str import StringAccessor @@ -269,7 +270,8 @@ class DataArray(AbstractArray, DataWithCoords): _rolling_cls = rolling.DataArrayRolling _coarsen_cls = rolling.DataArrayCoarsen _resample_cls = resample.DataArrayResample - + _weighted_cls = weighted.DataArrayWeighted + __default = ReprObject("") dt = property(DatetimeAccessor) From 8b1904bca3dbd8a9c5146fbfe2bfc6c8cd573322 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Thu, 17 Oct 2019 17:06:41 +0200 Subject: [PATCH 13/52] remove super --- xarray/core/weighted.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index d64506dd479..d8ebb11da3f 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -54,8 +54,6 @@ def __init__(self, obj, weights): """ - super(DataArrayWeighted, self).__init__() - from .dataarray import DataArray msg = "'weights' must be a DataArray" From 8cad145f75f23fa609c4d82fd7e6f41e0e29b90a Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Thu, 17 Oct 2019 22:03:24 +0200 Subject: [PATCH 14/52] overhaul core/weighted.py --- xarray/core/weighted.py | 281 +++++++++++++++++++++++++++++++++------- 1 file changed, 234 insertions(+), 47 deletions(-) diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index d8ebb11da3f..79f9c88b763 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -1,7 +1,11 @@ +from .computation import where, dot +from typing import TYPE_CHECKING, Hashable, Iterable, Optional, Tuple, Union, overload +if TYPE_CHECKING: + from .dataarray import DataArray, Dataset -_doc_ = """ - Reduce this DataArray's data by a weighted `{fcn}` along some dimension(s). +_WEIGHTED_REDUCE_DOCSTRING_TEMPLATE = """ + Reduce this {cls}'s data by a weighted `{fcn}` along some dimension(s). Parameters ---------- @@ -28,24 +32,154 @@ Returns ------- - reduced : DataArray - New DataArray object with weighted `{fcn}` applied to its data and + reduced : {cls} + New {cls} object with weighted `{fcn}` applied to its data and the indicated dimension(s) removed. """ +_SUM_OF_WEIGHTS_DOCSTRING = """ + Calcualte the sum of weights, accounting for missing values -class DataArrayWeighted(object): - def __init__(self, obj, weights): + Parameters + ---------- + dim : str or sequence of str, optional + Dimension(s) over which to sum the weights. + axis : int or sequence of int, optional + Axis(es) over which to sum the weights. Only one of the 'dim' and + 'axis' arguments can be supplied. If neither are supplied, then + the weights are summed over all axes. + + Returns + ------- + reduced : {cls} + New {cls} object with the sum of the weights over the given dimension. + + + """ + + +# functions for weighted operations for one DataArray +# NOTE: weights must not contain missing values (this is taken care of in the +# DataArrayWeighted and DatasetWeighted cls) + + +def _maybe_get_all_dims( + dims: Optional[Union[Hashable, Iterable[Hashable]]], dims1: Tuple[Hashable, ...], dims2: Tuple[Hashable, ...] +): + """ the union of all dimensions + + `dims=None` behaves differently in `dot` and `sum`, so we have to apply + `dot` over the union of the dimensions + + """ + + if dims is None: + dims = set(dims1) | set(dims2) + + return dims + + +def _sum_of_weights( + da: "DataArray", + weights: "DataArray", + dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, + axis=None, +) -> "DataArray": + """ Calcualte the sum of weights, accounting for missing values """ + + # we need to mask DATA values that are nan; else the weights are wrong + mask = where(da.notnull(), 1, 0) # binary mask + + # need to infer dims as we use `dot` + dims = _maybe_get_all_dims(dim, da.dims, weights.dims) + + # use `dot` to avoid creating large da's + sum_of_weights = dot(mask, weights, dims=dims) + + # find all weights that are valid (not 0) + valid_weights = sum_of_weights != 0.0 + + # set invalid weights to nan + return sum_of_weights.where(valid_weights) + + +def _weighted_sum( + da: "DataArray", + weights: "DataArray", + dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, + axis=None, + skipna: Optional[bool] = None, + **kwargs +) -> "DataArray": + """Reduce a DataArray by a by a weighted `sum` along some dimension(s).""" + + # need to infer dims as we use `dot` + dims = _maybe_get_all_dims(dim, da.dims, weights.dims) + + # use `dot` to avoid creating large da's + + # need to mask invalid DATA as dot does not implement skipna + if skipna or skipna is None: + return where(da.isnull(), 0.0, da).dot(weights, dims=dims) + + return dot(da, weights, dims=dims) + + +def _weighted_mean( + da: "DataArray", + weights: "DataArray", + dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, + axis=None, + skipna: Optional[bool] = None, + **kwargs +) -> "DataArray": + """Reduce a DataArray by a weighted `mean` along some dimension(s).""" + + # get weighted sum + weighted_sum = _weighted_sum( + da, weights, dim=dim, axis=axis, skipna=skipna, **kwargs + ) + + # get the sum of weights + sum_of_weights = _sum_of_weights(da, weights, dim=dim, axis=axis) + + # calculate weighted mean + return weighted_sum / sum_of_weights + + +class Weighted: + """A object that implements weighted operations. + + You should create a Weighted object by using the `DataArray.weighted` or + `Dataset.weighted` methods. + + See Also + -------- + Dataset.weighted + DataArray.weighted + """ + + __slots__ = ("obj", "weights") + + @overload + def __init__(self, obj: "DataArray", weights: "DataArray") -> None: + ... + + @overload + def __init__(self, obj: "Dataset", weights: "DataArray") -> None: + ... + + def __init__(self, obj, weights) -> None: """ Weighted operations for DataArray. Parameters ---------- - obj : DataArray + obj : DataArray or Dataset Object over which the weighted reduction operation is applied. weights : DataArray - An array of weights associated with the values in this Dataset. - Each value in the DataArray contributes to the reduction operation + An array of weights associated with the values in this obj. + Each value in the obj contributes to the reduction operation according to its associated weight. Note @@ -62,57 +196,110 @@ def __init__(self, obj, weights): self.obj = obj self.weights = weights.fillna(0) - def sum_of_weights(self, dim=None, axis=None): - """ - Calcualte the sum of weights, accounting for missing values + def __repr__(self): + """provide a nice str repr of our weighted object""" - Parameters - ---------- - dim : str or sequence of str, optional - Dimension(s) over which to sum the weights. - axis : int or sequence of int, optional - Axis(es) over which to sum the weights. Only one of the 'dim' and - 'axis' arguments can be supplied. If neither are supplied, then - the weights are summed over all axes. + msg = "{klass} with weights along dimensions: {weight_dims}" + return msg.format( + klass=self.__class__.__name__, weight_dims=", ".join(self.weights.dims) + ) - """ - # we need to mask DATA values that are nan; else the weights are wrong - masked_weights = self.weights.where(self.obj.notnull()) +class DataArrayWeighted(Weighted): + def sum_of_weights( + self, dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, axis=None + ) -> "DataArray": - sum_of_weights = masked_weights.sum(dim=dim, axis=axis, skipna=True) + return _sum_of_weights(self.obj, self.weights, dim=dim, axis=axis) - # find all weights that are valid (not 0) - valid_weights = sum_of_weights != 0. + def sum( + self, + dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, + axis=None, + skipna: Optional[bool] = None, + **kwargs + ) -> "DataArray": - # set invalid weights to nan - return sum_of_weights.where(valid_weights) + return _weighted_sum( + self.obj, self.weights, dim=dim, axis=axis, skipna=skipna, **kwargs + ) - def sum(self, dim=None, axis=None, skipna=None, **kwargs): + def mean( + self, + dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, + axis=None, + skipna: Optional[bool] = None, + **kwargs + ) -> "DataArray": - # calculate weighted sum - return (self.obj * self.weights).sum(dim, axis=axis, skipna=skipna, - **kwargs) + return _weighted_mean( + self.obj, self.weights, dim=dim, axis=axis, skipna=skipna, **kwargs + ) - def mean(self, dim=None, axis=None, skipna=None, **kwargs): - # get the sum of weights - sum_of_weights = self.sum_of_weights(dim=dim, axis=axis) +# add docstrings +DataArrayWeighted.sum_of_weights.__doc__ = _SUM_OF_WEIGHTS_DOCSTRING.format( + cls="DataArray" +) +DataArrayWeighted.mean.__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format( + cls="DataArray", fcn="mean" +) +DataArrayWeighted.sum.__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format( + cls="DataArray", fcn="sum" +) - # get weighted sum - weighted_sum = self.sum(dim=dim, axis=axis, skipna=skipna, **kwargs) - # calculate weighted mean - return weighted_sum / sum_of_weights +class DatasetWeighted(Weighted): + def _dataset_implementation(self, func, **kwargs) -> "Dataset": + + from .dataset import Dataset - def __repr__(self): - """provide a nice str repr of our weighted object""" + weighted = {} + for key, da in self.obj.data_vars.items(): - msg = "{klass} with weights along dimensions: {weight_dims}" - return msg.format(klass=self.__class__.__name__, - weight_dims=", ".join(self.weights.dims)) + weighted[key] = func(da, self.weights, **kwargs) + return Dataset(weighted, coords=self.obj.coords) -# add docstrings -DataArrayWeighted.mean.__doc__ = _doc_.format(fcn='mean') -DataArrayWeighted.sum.__doc__ = _doc_.format(fcn='sum') + def sum_of_weights( + self, + dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, + axis=None, + skipna: Optional[bool] = None, + ) -> "Dataset": + + return self._dataset_implementation(_sum_of_weights, dim=dim, axis=axis) + + def sum( + self, + dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, + axis=None, + skipna: Optional[bool] = None, + **kwargs + ) -> "Dataset": + + return self._dataset_implementation( + _weighted_sum, dim=dim, axis=axis, skipna=skipna, **kwargs + ) + + def mean( + self, + dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, + axis=None, + skipna: Optional[bool] = None, + **kwargs + ) -> "Dataset": + + return self._dataset_implementation( + _weighted_mean, dim=dim, axis=axis, skipna=skipna, **kwargs + ) + + +# add docstring +DatasetWeighted.sum_of_weights.__doc__ = _SUM_OF_WEIGHTS_DOCSTRING.format(cls="Dataset") +DatasetWeighted.mean.__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format( + cls="Dataset", fcn="mean" +) +DatasetWeighted.sum.__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format( + cls="Dataset", fcn="sum" +) From 49d4e43b1cc5a1e1469fd935b5f3f34e0ce172d5 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Thu, 17 Oct 2019 22:03:49 +0200 Subject: [PATCH 15/52] add DatasetWeighted class --- xarray/core/dataset.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 6123b42b77e..d58ac8bfcb3 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -44,6 +44,7 @@ resample, rolling, utils, + weighted, ) from .alignment import _broadcast_helper, _get_broadcast_dims_map_common_coords, align from .common import ( @@ -432,6 +433,7 @@ class Dataset(Mapping, ImplementsDatasetReduce, DataWithCoords): _rolling_cls = rolling.DatasetRolling _coarsen_cls = rolling.DatasetCoarsen _resample_cls = resample.DatasetResample + _weighted_cls = weighted.DatasetWeighted def __init__( self, From 527256eb10e0e834178e60a45d922b40b346afee Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Thu, 17 Oct 2019 22:36:20 +0200 Subject: [PATCH 16/52] _maybe_get_all_dims return sorted tuple --- xarray/core/weighted.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index 79f9c88b763..0132b60c897 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -74,7 +74,7 @@ def _maybe_get_all_dims( """ if dims is None: - dims = set(dims1) | set(dims2) + dims = tuple(sorted(set(dims1) | set(dims2))) return dims From 739568f4e1a7f798c7cd4f3c8ce017ac81f12558 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Thu, 17 Oct 2019 22:37:16 +0200 Subject: [PATCH 17/52] work on: test_weighted --- xarray/tests/test_weighted.py | 133 +++++++++++++++++++++++++--------- 1 file changed, 98 insertions(+), 35 deletions(-) diff --git a/xarray/tests/test_weighted.py b/xarray/tests/test_weighted.py index 65357ee2440..0d75cd09432 100644 --- a/xarray/tests/test_weighted.py +++ b/xarray/tests/test_weighted.py @@ -3,10 +3,61 @@ import numpy as np import xarray as xr -from xarray import ( - DataArray,) +from xarray import DataArray, Dataset +from xarray.core import weighted -from xarray.tests import assert_equal, raises_regex +from xarray.tests import assert_equal, assert_allclose, raises_regex + + +@pytest.mark.parametrize("dims", (None, "a", ("a", "b"))) +def test_weighted_maybe_get_all_dims(dims): + + d1 = ("x", "y") + d2 = ("y", "z") + + expected = ("x", "y", "z") if dims is None else dims + + result = weighted._maybe_get_all_dims(dims, d1, d2) + + assert result == expected + + +@pytest.mark.parametrize("size", (1, 5, 100)) +def test_weighted__sum_of_weights_1D(size): + + data = np.zeros(size) + # make sure weights is not 0 + weights = np.arange(1, size + 1) + + da = DataArray(data) + weights = DataArray(weights) + + expected = weights.sum() + + result = da.weighted(weights).sum_of_weights() + + assert_equal(expected, result) + + +@pytest.mark.parametrize("shape", ((2, 2), (2, 5), (10, 10))) +@pytest.mark.parametrize("dim", (None, "dim_0", "dim_1", ("dim_0", "dim_1"))) +def test_weighted__sum_of_weights_2D(shape, dim): + + np.random.seed(0) + + data = np.zeros(shape) + # make sure all weights are positive to avoid summing to 0 + weights = np.abs(np.random.randn(*shape)) + + da = DataArray(data) + weights = DataArray(weights) + + weighted = da.weighted(weights) + + expected = weights.sum(dim=dim) + result = weighted.sum_of_weights(dim=dim) + + assert_allclose(expected, result) def test_weigted_non_DataArray_weights(): @@ -15,23 +66,27 @@ def test_weigted_non_DataArray_weights(): with raises_regex(AssertionError, "'weights' must be a DataArray"): da.weighted([1, 2]) + ds = Dataset(dict(data=[1, 2])) + with raises_regex(AssertionError, "'weights' must be a DataArray"): + da.weighted([1, 2]) -@pytest.mark.parametrize('weights', ([1, 2], [np.nan, 2], [np.nan, np.nan])) + +@pytest.mark.parametrize("weights", ([1, 2], [np.nan, 2], [np.nan, np.nan])) def test_weighted_weights_nan_replaced(weights): # make sure nans are removed from weights da = DataArray([1, 2]) - expected = DataArray(weights).fillna(0.) + expected = DataArray(weights).fillna(0.0) result = da.weighted(DataArray(weights)).weights assert_equal(expected, result) -@pytest.mark.parametrize(('weights', 'expected'), (([1, 2], 3), - ([0, 2], 2), - ([0, 0], np.nan), - ([-1, 1], np.nan))) +@pytest.mark.parametrize( + ("weights", "expected"), + (([1, 2], 3), ([0, 2], 2), ([0, 0], np.nan), ([-1, 1], np.nan)), +) def test_weigted_sum_of_weights_no_nan(weights, expected): da = DataArray([1, 2]) @@ -43,10 +98,9 @@ def test_weigted_sum_of_weights_no_nan(weights, expected): assert_equal(expected, result) -@pytest.mark.parametrize(('weights', 'expected'), (([1, 2], 2), - ([0, 2], 2), - ([0, 0], np.nan), - ([-1, 1], 1))) +@pytest.mark.parametrize( + ("weights", "expected"), (([1, 2], 2), ([0, 2], 2), ([0, 0], np.nan), ([-1, 1], 1)) +) def test_weigted_sum_of_weights_nan(weights, expected): da = DataArray([np.nan, 2]) @@ -58,9 +112,9 @@ def test_weigted_sum_of_weights_nan(weights, expected): assert_equal(expected, result) -@pytest.mark.parametrize('da', ([1, 2], [1, np.nan], [np.nan, np.nan])) -@pytest.mark.parametrize('factor', [0, 1, 2, 3.14]) -@pytest.mark.parametrize('skipna', (True, False)) +@pytest.mark.parametrize("da", ([1, 2], [1, np.nan], [np.nan, np.nan])) +@pytest.mark.parametrize("factor", [0, 1, 2, 3.14]) +@pytest.mark.parametrize("skipna", (True, False)) def test_weighted_sum_equal_weights(da, factor, skipna): # if all weights are 'f'; weighted sum is f times the ordinary sum @@ -73,9 +127,9 @@ def test_weighted_sum_equal_weights(da, factor, skipna): assert_equal(expected, result) -@pytest.mark.parametrize(('weights', 'expected'), (([1, 2], 5), - ([0, 2], 4), - ([0, 0], 0))) +@pytest.mark.parametrize( + ("weights", "expected"), (([1, 2], 5), ([0, 2], 4), ([0, 0], 0)) +) def test_weighted_sum_no_nan(weights, expected): da = DataArray([1, 2]) @@ -86,11 +140,10 @@ def test_weighted_sum_no_nan(weights, expected): assert_equal(expected, result) -@pytest.mark.parametrize(('weights', 'expected'), (([1, 2], 4), - ([0, 2], 4), - ([1, 0], 0), - ([0, 0], 0))) -@pytest.mark.parametrize('skipna', (True, False)) +@pytest.mark.parametrize( + ("weights", "expected"), (([1, 2], 4), ([0, 2], 4), ([1, 0], 0), ([0, 0], 0)) +) +@pytest.mark.parametrize("skipna", (True, False)) def test_weighted_sum_nan(weights, expected, skipna): da = DataArray([np.nan, 2]) @@ -106,8 +159,8 @@ def test_weighted_sum_nan(weights, expected, skipna): @pytest.mark.filterwarnings("ignore:Mean of empty slice") -@pytest.mark.parametrize('da', ([1, 2], [1, np.nan], [np.nan, np.nan])) -@pytest.mark.parametrize('skipna', (True, False)) +@pytest.mark.parametrize("da", ([1, 2], [1, np.nan], [np.nan, np.nan])) +@pytest.mark.parametrize("skipna", (True, False)) def test_weigted_mean_equal_weights(da, skipna): # if all weights are equal, should yield the same result as mean @@ -122,10 +175,10 @@ def test_weigted_mean_equal_weights(da, skipna): assert_equal(expected, result) -@pytest.mark.parametrize(('weights', 'expected'), (([4, 6], 1.6), - ([0, 1], 2.0), - ([0, 2], 2.0), - ([0, 0], np.nan))) +@pytest.mark.parametrize( + ("weights", "expected"), + (([4, 6], 1.6), ([0, 1], 2.0), ([0, 2], 2.0), ([0, 0], np.nan)), +) def test_weighted_mean_no_nan(weights, expected): da = DataArray([1, 2]) @@ -137,11 +190,11 @@ def test_weighted_mean_no_nan(weights, expected): assert_equal(expected, result) -@pytest.mark.parametrize(('weights', 'expected'), (([4, 6], 2.0), - ([0, 1], 2.0), - ([0, 2], 2.0), - ([0, 0], np.nan))) -@pytest.mark.parametrize('skipna', (True, False)) +@pytest.mark.parametrize( + ("weights", "expected"), + (([4, 6], 2.0), ([0, 1], 2.0), ([0, 2], 2.0), ([0, 0], np.nan)), +) +@pytest.mark.parametrize("skipna", (True, False)) def test_weigted_mean_nan(weights, expected, skipna): da = DataArray([np.nan, 2]) @@ -155,3 +208,13 @@ def test_weigted_mean_nan(weights, expected, skipna): result = da.weighted(weights).mean(skipna=skipna) assert_equal(expected, result) + + +# def expected_weighted(da, weights, operation, ): + + +# weighted_sum = (da * weights).sum(dim, axis=axis, skipna=skipna, +# **kwargs) + +# if operation == "sum": +# return weighted_sum From f01305d213af272e1859bb1f128dcb96be4c0f98 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Thu, 17 Oct 2019 22:43:47 +0200 Subject: [PATCH 18/52] black and flake8 --- xarray/core/common.py | 1 - xarray/core/dataarray.py | 2 +- xarray/core/weighted.py | 8 +++++--- xarray/tests/test_weighted.py | 2 +- 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index 3e423988a03..21ba1c2bae5 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -746,7 +746,6 @@ def weighted(self, weights): return self._weighted_cls(self, weights) - def rolling( self, dim: Mapping[Hashable, int] = None, diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index deb0a9c0e1d..eaf649b9a73 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -271,7 +271,7 @@ class DataArray(AbstractArray, DataWithCoords): _coarsen_cls = rolling.DataArrayCoarsen _resample_cls = resample.DataArrayResample _weighted_cls = weighted.DataArrayWeighted - + __default = ReprObject("") dt = property(DatetimeAccessor) diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index 0132b60c897..104417fb11e 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -48,7 +48,7 @@ Axis(es) over which to sum the weights. Only one of the 'dim' and 'axis' arguments can be supplied. If neither are supplied, then the weights are summed over all axes. - + Returns ------- reduced : {cls} @@ -64,7 +64,9 @@ def _maybe_get_all_dims( - dims: Optional[Union[Hashable, Iterable[Hashable]]], dims1: Tuple[Hashable, ...], dims2: Tuple[Hashable, ...] + dims: Optional[Union[Hashable, Iterable[Hashable]]], + dims1: Tuple[Hashable, ...], + dims2: Tuple[Hashable, ...], ): """ the union of all dimensions @@ -251,7 +253,7 @@ def mean( class DatasetWeighted(Weighted): def _dataset_implementation(self, func, **kwargs) -> "Dataset": - + from .dataset import Dataset weighted = {} diff --git a/xarray/tests/test_weighted.py b/xarray/tests/test_weighted.py index 0d75cd09432..0a62db8bb7a 100644 --- a/xarray/tests/test_weighted.py +++ b/xarray/tests/test_weighted.py @@ -68,7 +68,7 @@ def test_weigted_non_DataArray_weights(): ds = Dataset(dict(data=[1, 2])) with raises_regex(AssertionError, "'weights' must be a DataArray"): - da.weighted([1, 2]) + ds.weighted([1, 2]) @pytest.mark.parametrize("weights", ([1, 2], [np.nan, 2], [np.nan, np.nan])) From 2e3880da3b2448e35b371ea425ceeccd50c3bc85 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Thu, 17 Oct 2019 23:23:22 +0200 Subject: [PATCH 19/52] Apply suggestions from code review (docs) --- xarray/core/weighted.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index 104417fb11e..e1a4a154331 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -38,7 +38,7 @@ """ _SUM_OF_WEIGHTS_DOCSTRING = """ - Calcualte the sum of weights, accounting for missing values + Calculate the sum of weights, accounting for missing values Parameters ---------- @@ -68,7 +68,7 @@ def _maybe_get_all_dims( dims1: Tuple[Hashable, ...], dims2: Tuple[Hashable, ...], ): - """ the union of all dimensions + """ the union of dims1 and dims2 if dims is None `dims=None` behaves differently in `dot` and `sum`, so we have to apply `dot` over the union of the dimensions @@ -95,7 +95,7 @@ def _sum_of_weights( # need to infer dims as we use `dot` dims = _maybe_get_all_dims(dim, da.dims, weights.dims) - # use `dot` to avoid creating large da's + # use `dot` to avoid creating large DataArrays (if da and weights do not share all dims) sum_of_weights = dot(mask, weights, dims=dims) # find all weights that are valid (not 0) @@ -173,7 +173,7 @@ def __init__(self, obj: "Dataset", weights: "DataArray") -> None: def __init__(self, obj, weights) -> None: """ - Weighted operations for DataArray. + Create a Weighted object Parameters ---------- @@ -199,7 +199,7 @@ def __init__(self, obj, weights) -> None: self.weights = weights.fillna(0) def __repr__(self): - """provide a nice str repr of our weighted object""" + """provide a nice str repr of our Weighted object""" msg = "{klass} with weights along dimensions: {weight_dims}" return msg.format( From ae8d0483bc8fda24937bf734003c524272fd0876 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Fri, 18 Oct 2019 12:40:01 +0200 Subject: [PATCH 20/52] restructure interim --- xarray/core/weighted.py | 235 ++++++++++++++++++++-------------------- 1 file changed, 120 insertions(+), 115 deletions(-) diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index e1a4a154331..f3dd0d1b789 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -11,10 +11,6 @@ ---------- dim : str or sequence of str, optional Dimension(s) over which to apply the weighted `{fcn}`. - axis : int or sequence of int, optional - Axis(es) over which to apply the weighted `{fcn}`. Only one of the - 'dim' and 'axis' arguments can be supplied. If neither are supplied, - then the weighted `{fcn}` is calculated over all axes. skipna : bool, optional If True, skip missing values (as marked by NaN). By default, only skips missing values for float dtypes; other dtypes either do not @@ -44,10 +40,6 @@ ---------- dim : str or sequence of str, optional Dimension(s) over which to sum the weights. - axis : int or sequence of int, optional - Axis(es) over which to sum the weights. Only one of the 'dim' and - 'axis' arguments can be supplied. If neither are supplied, then - the weights are summed over all axes. Returns ------- @@ -58,11 +50,6 @@ """ -# functions for weighted operations for one DataArray -# NOTE: weights must not contain missing values (this is taken care of in the -# DataArrayWeighted and DatasetWeighted cls) - - def _maybe_get_all_dims( dims: Optional[Union[Hashable, Iterable[Hashable]]], dims1: Tuple[Hashable, ...], @@ -80,75 +67,6 @@ def _maybe_get_all_dims( return dims - -def _sum_of_weights( - da: "DataArray", - weights: "DataArray", - dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, - axis=None, -) -> "DataArray": - """ Calcualte the sum of weights, accounting for missing values """ - - # we need to mask DATA values that are nan; else the weights are wrong - mask = where(da.notnull(), 1, 0) # binary mask - - # need to infer dims as we use `dot` - dims = _maybe_get_all_dims(dim, da.dims, weights.dims) - - # use `dot` to avoid creating large DataArrays (if da and weights do not share all dims) - sum_of_weights = dot(mask, weights, dims=dims) - - # find all weights that are valid (not 0) - valid_weights = sum_of_weights != 0.0 - - # set invalid weights to nan - return sum_of_weights.where(valid_weights) - - -def _weighted_sum( - da: "DataArray", - weights: "DataArray", - dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, - axis=None, - skipna: Optional[bool] = None, - **kwargs -) -> "DataArray": - """Reduce a DataArray by a by a weighted `sum` along some dimension(s).""" - - # need to infer dims as we use `dot` - dims = _maybe_get_all_dims(dim, da.dims, weights.dims) - - # use `dot` to avoid creating large da's - - # need to mask invalid DATA as dot does not implement skipna - if skipna or skipna is None: - return where(da.isnull(), 0.0, da).dot(weights, dims=dims) - - return dot(da, weights, dims=dims) - - -def _weighted_mean( - da: "DataArray", - weights: "DataArray", - dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, - axis=None, - skipna: Optional[bool] = None, - **kwargs -) -> "DataArray": - """Reduce a DataArray by a weighted `mean` along some dimension(s).""" - - # get weighted sum - weighted_sum = _weighted_sum( - da, weights, dim=dim, axis=axis, skipna=skipna, **kwargs - ) - - # get the sum of weights - sum_of_weights = _sum_of_weights(da, weights, dim=dim, axis=axis) - - # calculate weighted mean - return weighted_sum / sum_of_weights - - class Weighted: """A object that implements weighted operations. @@ -168,10 +86,14 @@ def __init__(self, obj: "DataArray", weights: "DataArray") -> None: ... @overload - def __init__(self, obj: "Dataset", weights: "DataArray") -> None: + def __init__( + self, obj: "Dataset", weights: "DataArray" + ) -> None: # noqa: F811 TODO: remove once pyflakes/ flake8 on azure is updated ... - def __init__(self, obj, weights) -> None: + def __init__( + self, obj, weights + ) -> None: # noqa: F811 TODO: remove once pyflakes/ flake8 on azure is updated """ Create a Weighted object @@ -198,6 +120,76 @@ def __init__(self, obj, weights) -> None: self.obj = obj self.weights = weights.fillna(0) + def _sum_of_weights( + self, da: "DataArray", dim: Optional[Union[Hashable, Iterable[Hashable]]] = None + ) -> "DataArray": + """ Calculate the sum of weights, accounting for missing values """ + + # we need to mask DATA values that are nan; else the weights are wrong + mask = da.isnull() + + # need to infer dims as we use `dot` + dims = _maybe_get_all_dims(dim, da.dims, self.weights.dims) + + # use `dot` to avoid creating large DataArrays (if da and weights do not share all dims) + sum_of_weights = dot(mask, self.weights, dims=dims) + + # find all weights that are valid (not 0) + valid_weights = sum_of_weights != 0.0 + + # set invalid weights to nan + return sum_of_weights.where(valid_weights) + + def _weighted_sum( + self, + da: "DataArray", + dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, + skipna: Optional[bool] = None, + **kwargs + ) -> "DataArray": + """Reduce a DataArray by a by a weighted `sum` along some dimension(s).""" + + # need to infer dims as we use `dot` + dims = _maybe_get_all_dims(dim, da.dims, self.weights.dims) + + # use `dot` to avoid creating large da's + + # need to mask invalid DATA as dot does not implement skipna + if skipna or skipna is None: + return da.fillna(0.).dot(self.weights, dims=dims) + + return dot(da, self.weights, dims=dims) + + def _weighted_mean( + self, + da: "DataArray", + dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, + skipna: Optional[bool] = None, + **kwargs + ) -> "DataArray": + """Reduce a DataArray by a weighted `mean` along some dimension(s).""" + + # get weighted sum + weighted_sum = self._weighted_sum(da, dim=dim, skipna=skipna, **kwargs) + + # get the sum of weights + sum_of_weights = self._sum_of_weights(da, dim=dim) + + # calculate weighted mean + return weighted_sum / sum_of_weights + + # def _implementation(self, func, **kwargs): + + # msg = "Use 'Dataset.weighted' or 'DataArray.weighted'" + + # raise NotImplementedError(msg) + + + def sum(self, dim=None, skipna=None): + + return self._implementation(self._weighted_sum, dim=None, skipna=None) + + def __repr__(self): """provide a nice str repr of our Weighted object""" @@ -208,35 +200,35 @@ def __repr__(self): class DataArrayWeighted(Weighted): + + + def _implementation(self, func, **kwargs): + + return func(self.obj, **kwargs) + def sum_of_weights( - self, dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, axis=None + self, dim: Optional[Union[Hashable, Iterable[Hashable]]] = None ) -> "DataArray": - return _sum_of_weights(self.obj, self.weights, dim=dim, axis=axis) + return self._sum_of_weights(self.obj, dim=dim) - def sum( - self, - dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, - axis=None, - skipna: Optional[bool] = None, - **kwargs - ) -> "DataArray": + # def sum( + # self, + # dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, + # skipna: Optional[bool] = None, + # **kwargs + # ) -> "DataArray": - return _weighted_sum( - self.obj, self.weights, dim=dim, axis=axis, skipna=skipna, **kwargs - ) + # return self._weighted_sum(self.obj, dim=dim, skipna=skipna, **kwargs) def mean( self, dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, - axis=None, skipna: Optional[bool] = None, **kwargs ) -> "DataArray": - return _weighted_mean( - self.obj, self.weights, dim=dim, axis=axis, skipna=skipna, **kwargs - ) + return self._weighted_mean(self.obj, dim=dim, skipna=skipna, **kwargs) # add docstrings @@ -259,41 +251,54 @@ def _dataset_implementation(self, func, **kwargs) -> "Dataset": weighted = {} for key, da in self.obj.data_vars.items(): - weighted[key] = func(da, self.weights, **kwargs) + weighted[key] = func(da, **kwargs) return Dataset(weighted, coords=self.obj.coords) + + def _implementation(self, func, **kwargs) -> "Dataset": + + from .dataset import Dataset + + weighted = {} + for key, da in self.obj.data_vars.items(): + + weighted[key] = func(da, **kwargs) + + return Dataset(weighted, coords=self.obj.coords) + + + + + def sum_of_weights( self, dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, - axis=None, skipna: Optional[bool] = None, ) -> "Dataset": - return self._dataset_implementation(_sum_of_weights, dim=dim, axis=axis) + return self._dataset_implementation(self._sum_of_weights, dim=dim) - def sum( - self, - dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, - axis=None, - skipna: Optional[bool] = None, - **kwargs - ) -> "Dataset": + # def sum( + # self, + # dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, + # skipna: Optional[bool] = None, + # **kwargs + # ) -> "Dataset": - return self._dataset_implementation( - _weighted_sum, dim=dim, axis=axis, skipna=skipna, **kwargs - ) + # return self._dataset_implementation( + # self._weighted_sum, dim=dim, skipna=skipna, **kwargs + # ) def mean( self, dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, - axis=None, skipna: Optional[bool] = None, **kwargs ) -> "Dataset": return self._dataset_implementation( - _weighted_mean, dim=dim, axis=axis, skipna=skipna, **kwargs + self._weighted_mean, dim=dim, skipna=skipna, **kwargs ) From dc7f6057c1a2533569ff6b5995d9a3c3e4c1fd85 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Fri, 18 Oct 2019 16:02:53 +0200 Subject: [PATCH 21/52] restructure classes --- xarray/core/weighted.py | 124 +++++++++++----------------------------- 1 file changed, 32 insertions(+), 92 deletions(-) diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index f3dd0d1b789..8b2857bb44b 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -1,4 +1,4 @@ -from .computation import where, dot +from .computation import dot from typing import TYPE_CHECKING, Hashable, Iterable, Optional, Tuple, Union, overload if TYPE_CHECKING: @@ -67,6 +67,7 @@ def _maybe_get_all_dims( return dims + class Weighted: """A object that implements weighted operations. @@ -85,15 +86,13 @@ class Weighted: def __init__(self, obj: "DataArray", weights: "DataArray") -> None: ... - @overload - def __init__( - self, obj: "Dataset", weights: "DataArray" - ) -> None: # noqa: F811 TODO: remove once pyflakes/ flake8 on azure is updated + @overload # noqa: F811 TODO: remove once pyflakes/ flake8 on azure is updated + def __init__(self, obj: "Dataset", weights: "DataArray") -> None: ... - def __init__( + def __init__( # noqa: F811 TODO: remove once pyflakes/ flake8 on azure is updated self, obj, weights - ) -> None: # noqa: F811 TODO: remove once pyflakes/ flake8 on azure is updated + ): """ Create a Weighted object @@ -102,13 +101,13 @@ def __init__( obj : DataArray or Dataset Object over which the weighted reduction operation is applied. weights : DataArray - An array of weights associated with the values in this obj. + An array of weights associated with the values in the obj. Each value in the obj contributes to the reduction operation according to its associated weight. Note ---- - Missing values in the weights are replaced with 0 (i.e. no weight). + Missing values in the weights are replaced with 0. (i.e. no weight). """ @@ -126,7 +125,7 @@ def _sum_of_weights( """ Calculate the sum of weights, accounting for missing values """ # we need to mask DATA values that are nan; else the weights are wrong - mask = da.isnull() + mask = ~da.isnull() # need to infer dims as we use `dot` dims = _maybe_get_all_dims(dim, da.dims, self.weights.dims) @@ -145,18 +144,17 @@ def _weighted_sum( da: "DataArray", dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, skipna: Optional[bool] = None, - **kwargs ) -> "DataArray": """Reduce a DataArray by a by a weighted `sum` along some dimension(s).""" # need to infer dims as we use `dot` dims = _maybe_get_all_dims(dim, da.dims, self.weights.dims) - # use `dot` to avoid creating large da's + # use `dot` to avoid creating large DataArrays # need to mask invalid DATA as dot does not implement skipna if skipna or skipna is None: - return da.fillna(0.).dot(self.weights, dims=dims) + return dot(da.fillna(0.0), self.weights, dims=dims) return dot(da, self.weights, dims=dims) @@ -165,12 +163,11 @@ def _weighted_mean( da: "DataArray", dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, skipna: Optional[bool] = None, - **kwargs ) -> "DataArray": """Reduce a DataArray by a weighted `mean` along some dimension(s).""" # get weighted sum - weighted_sum = self._weighted_sum(da, dim=dim, skipna=skipna, **kwargs) + weighted_sum = self._weighted_sum(da, dim=dim, skipna=skipna) # get the sum of weights sum_of_weights = self._sum_of_weights(da, dim=dim) @@ -178,17 +175,32 @@ def _weighted_mean( # calculate weighted mean return weighted_sum / sum_of_weights - # def _implementation(self, func, **kwargs): + def _implementation(self, func, **kwargs): - # msg = "Use 'Dataset.weighted' or 'DataArray.weighted'" + msg = "Use 'Dataset.weighted' or 'DataArray.weighted'" + raise NotImplementedError(msg) - # raise NotImplementedError(msg) + def sum_of_weights( + self, dim: Optional[Union[Hashable, Iterable[Hashable]]] = None + ) -> Union["DataArray", "Dataset"]: + return self._implementation(self._sum_of_weights, dim=dim) - def sum(self, dim=None, skipna=None): + def sum( + self, + dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, + skipna: Optional[bool] = None, + ) -> Union["DataArray", "Dataset"]: - return self._implementation(self._weighted_sum, dim=None, skipna=None) + return self._implementation(self._weighted_sum, dim=dim, skipna=skipna) + def mean( + self, + dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, + skipna: Optional[bool] = None, + ) -> Union["DataArray", "Dataset"]: + + return self._implementation(self._weighted_mean, dim=dim, skipna=skipna) def __repr__(self): """provide a nice str repr of our Weighted object""" @@ -200,36 +212,10 @@ def __repr__(self): class DataArrayWeighted(Weighted): - - def _implementation(self, func, **kwargs): return func(self.obj, **kwargs) - def sum_of_weights( - self, dim: Optional[Union[Hashable, Iterable[Hashable]]] = None - ) -> "DataArray": - - return self._sum_of_weights(self.obj, dim=dim) - - # def sum( - # self, - # dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, - # skipna: Optional[bool] = None, - # **kwargs - # ) -> "DataArray": - - # return self._weighted_sum(self.obj, dim=dim, skipna=skipna, **kwargs) - - def mean( - self, - dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, - skipna: Optional[bool] = None, - **kwargs - ) -> "DataArray": - - return self._weighted_mean(self.obj, dim=dim, skipna=skipna, **kwargs) - # add docstrings DataArrayWeighted.sum_of_weights.__doc__ = _SUM_OF_WEIGHTS_DOCSTRING.format( @@ -244,18 +230,6 @@ def mean( class DatasetWeighted(Weighted): - def _dataset_implementation(self, func, **kwargs) -> "Dataset": - - from .dataset import Dataset - - weighted = {} - for key, da in self.obj.data_vars.items(): - - weighted[key] = func(da, **kwargs) - - return Dataset(weighted, coords=self.obj.coords) - - def _implementation(self, func, **kwargs) -> "Dataset": from .dataset import Dataset @@ -268,40 +242,6 @@ def _implementation(self, func, **kwargs) -> "Dataset": return Dataset(weighted, coords=self.obj.coords) - - - - def sum_of_weights( - self, - dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, - skipna: Optional[bool] = None, - ) -> "Dataset": - - return self._dataset_implementation(self._sum_of_weights, dim=dim) - - # def sum( - # self, - # dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, - # skipna: Optional[bool] = None, - # **kwargs - # ) -> "Dataset": - - # return self._dataset_implementation( - # self._weighted_sum, dim=dim, skipna=skipna, **kwargs - # ) - - def mean( - self, - dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, - skipna: Optional[bool] = None, - **kwargs - ) -> "Dataset": - - return self._dataset_implementation( - self._weighted_mean, dim=dim, skipna=skipna, **kwargs - ) - - # add docstring DatasetWeighted.sum_of_weights.__doc__ = _SUM_OF_WEIGHTS_DOCSTRING.format(cls="Dataset") DatasetWeighted.mean.__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format( From e2ad69ebb03f613922dac769874be88e21bc1278 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Wed, 4 Dec 2019 16:14:40 +0100 Subject: [PATCH 22/52] update weighted.py --- xarray/core/weighted.py | 91 ++++++++++++++++------------------------- 1 file changed, 36 insertions(+), 55 deletions(-) diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index 8b2857bb44b..023c6610da4 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -1,5 +1,6 @@ +from typing import TYPE_CHECKING, Hashable, Iterable, Optional, Union, overload + from .computation import dot -from typing import TYPE_CHECKING, Hashable, Iterable, Optional, Tuple, Union, overload if TYPE_CHECKING: from .dataarray import DataArray, Dataset @@ -50,24 +51,6 @@ """ -def _maybe_get_all_dims( - dims: Optional[Union[Hashable, Iterable[Hashable]]], - dims1: Tuple[Hashable, ...], - dims2: Tuple[Hashable, ...], -): - """ the union of dims1 and dims2 if dims is None - - `dims=None` behaves differently in `dot` and `sum`, so we have to apply - `dot` over the union of the dimensions - - """ - - if dims is None: - dims = tuple(sorted(set(dims1) | set(dims2))) - - return dims - - class Weighted: """A object that implements weighted operations. @@ -86,13 +69,11 @@ class Weighted: def __init__(self, obj: "DataArray", weights: "DataArray") -> None: ... - @overload # noqa: F811 TODO: remove once pyflakes/ flake8 on azure is updated + @overload # noqa: F811 def __init__(self, obj: "Dataset", weights: "DataArray") -> None: ... - def __init__( # noqa: F811 TODO: remove once pyflakes/ flake8 on azure is updated - self, obj, weights - ): + def __init__(self, obj, weights): # noqa: F811 """ Create a Weighted object @@ -107,7 +88,7 @@ def __init__( # noqa: F811 TODO: remove once pyflakes/ flake8 on azure is updat Note ---- - Missing values in the weights are replaced with 0. (i.e. no weight). + Weights can not contain missing values. """ @@ -117,21 +98,28 @@ def __init__( # noqa: F811 TODO: remove once pyflakes/ flake8 on azure is updat assert isinstance(weights, DataArray), msg self.obj = obj - self.weights = weights.fillna(0) + + if weights.isnull().any(): + raise ValueError("`weights` cannot contain missing values.") + + self.weights = weights def _sum_of_weights( - self, da: "DataArray", dim: Optional[Union[Hashable, Iterable[Hashable]]] = None + self, + da: "DataArray", + dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, ) -> "DataArray": """ Calculate the sum of weights, accounting for missing values """ - # we need to mask DATA values that are nan; else the weights are wrong - mask = ~da.isnull() + # we need to mask data values that are nan; else the weights are wrong + mask = da.notnull() # need to infer dims as we use `dot` - dims = _maybe_get_all_dims(dim, da.dims, self.weights.dims) + if dim is None: + dim = ... # use `dot` to avoid creating large DataArrays (if da and weights do not share all dims) - sum_of_weights = dot(mask, self.weights, dims=dims) + sum_of_weights = dot(mask, self.weights, dims=dim) # find all weights that are valid (not 0) valid_weights = sum_of_weights != 0.0 @@ -148,15 +136,16 @@ def _weighted_sum( """Reduce a DataArray by a by a weighted `sum` along some dimension(s).""" # need to infer dims as we use `dot` - dims = _maybe_get_all_dims(dim, da.dims, self.weights.dims) + if dim is None: + dim = ... # use `dot` to avoid creating large DataArrays # need to mask invalid DATA as dot does not implement skipna - if skipna or skipna is None: - return dot(da.fillna(0.0), self.weights, dims=dims) + if skipna or (skipna is None and da.dtype.kind in "cfO"): + return dot(da.fillna(0.0), self.weights, dims=dim) - return dot(da, self.weights, dims=dims) + return dot(da, self.weights, dims=dim) def _weighted_mean( self, @@ -207,7 +196,7 @@ def __repr__(self): msg = "{klass} with weights along dimensions: {weight_dims}" return msg.format( - klass=self.__class__.__name__, weight_dims=", ".join(self.weights.dims) + klass=self.__class__.__name__, weight_dims=", ".join(self.weights.dims), ) @@ -217,18 +206,6 @@ def _implementation(self, func, **kwargs): return func(self.obj, **kwargs) -# add docstrings -DataArrayWeighted.sum_of_weights.__doc__ = _SUM_OF_WEIGHTS_DOCSTRING.format( - cls="DataArray" -) -DataArrayWeighted.mean.__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format( - cls="DataArray", fcn="mean" -) -DataArrayWeighted.sum.__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format( - cls="DataArray", fcn="sum" -) - - class DatasetWeighted(Weighted): def _implementation(self, func, **kwargs) -> "Dataset": @@ -242,11 +219,15 @@ def _implementation(self, func, **kwargs) -> "Dataset": return Dataset(weighted, coords=self.obj.coords) -# add docstring -DatasetWeighted.sum_of_weights.__doc__ = _SUM_OF_WEIGHTS_DOCSTRING.format(cls="Dataset") -DatasetWeighted.mean.__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format( - cls="Dataset", fcn="mean" -) -DatasetWeighted.sum.__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format( - cls="Dataset", fcn="sum" -) +def _inject_docstring(cls, cls_name): + + cls.sum_of_weights.__doc__ = _SUM_OF_WEIGHTS_DOCSTRING.format(cls=cls_name) + + for operator in ["sum", "mean"]: + getattr(cls, operator).__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format( + cls=cls_name, fcn=operator + ) + + +_inject_docstring(DataArrayWeighted, "DataArray") +_inject_docstring(DatasetWeighted, "Dataset") From bd4f048bdb5a5a356a5603904d96a676037d1b6e Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Wed, 4 Dec 2019 16:46:33 +0100 Subject: [PATCH 23/52] black --- xarray/core/weighted.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index 023c6610da4..d2dc309ad10 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -105,9 +105,7 @@ def __init__(self, obj, weights): # noqa: F811 self.weights = weights def _sum_of_weights( - self, - da: "DataArray", - dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, + self, da: "DataArray", dim: Optional[Union[Hashable, Iterable[Hashable]]] = None ) -> "DataArray": """ Calculate the sum of weights, accounting for missing values """ @@ -196,7 +194,7 @@ def __repr__(self): msg = "{klass} with weights along dimensions: {weight_dims}" return msg.format( - klass=self.__class__.__name__, weight_dims=", ".join(self.weights.dims), + klass=self.__class__.__name__, weight_dims=", ".join(self.weights.dims) ) From 3c7695a5345b59dcc6096ceb97c7f53dd087faa4 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Wed, 4 Dec 2019 21:07:46 +0100 Subject: [PATCH 24/52] use map; add keep_attrs --- xarray/core/weighted.py | 53 +++++++++++++++++++++++++---------------- 1 file changed, 32 insertions(+), 21 deletions(-) diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index d2dc309ad10..bca1e067d18 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -1,6 +1,7 @@ from typing import TYPE_CHECKING, Hashable, Iterable, Optional, Union, overload from .computation import dot +from .options import _get_keep_attrs if TYPE_CHECKING: from .dataarray import DataArray, Dataset @@ -23,9 +24,6 @@ If True, the attributes (`attrs`) will be copied from the original object to the new one. If False (default), the new object will be returned without attributes. - **kwargs : dict - Additional keyword arguments passed on to the appropriate array - function for calculating `{fcn}` on this object's data. Returns ------- @@ -41,13 +39,15 @@ ---------- dim : str or sequence of str, optional Dimension(s) over which to sum the weights. + keep_attrs : bool, optional + If True, the attributes (`attrs`) will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. Returns ------- reduced : {cls} New {cls} object with the sum of the weights over the given dimension. - - """ @@ -94,8 +94,7 @@ def __init__(self, obj, weights): # noqa: F811 from .dataarray import DataArray - msg = "'weights' must be a DataArray" - assert isinstance(weights, DataArray), msg + assert isinstance(weights, DataArray), "'weights' must be a DataArray" self.obj = obj @@ -162,32 +161,42 @@ def _weighted_mean( # calculate weighted mean return weighted_sum / sum_of_weights - def _implementation(self, func, **kwargs): + def _implementation(self, func, dim, **kwargs): msg = "Use 'Dataset.weighted' or 'DataArray.weighted'" raise NotImplementedError(msg) def sum_of_weights( - self, dim: Optional[Union[Hashable, Iterable[Hashable]]] = None + self, + dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, + keep_attrs: Optional[bool] = None, ) -> Union["DataArray", "Dataset"]: - return self._implementation(self._sum_of_weights, dim=dim) + return self._implementation( + self._sum_of_weights, dim=dim, keep_attrs=keep_attrs + ) def sum( self, dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, skipna: Optional[bool] = None, + keep_attrs: Optional[bool] = None, ) -> Union["DataArray", "Dataset"]: - return self._implementation(self._weighted_sum, dim=dim, skipna=skipna) + return self._implementation( + self._weighted_sum, dim=dim, skipna=skipna, keep_attrs=keep_attrs + ) def mean( self, dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, skipna: Optional[bool] = None, + keep_attrs: Optional[bool] = None, ) -> Union["DataArray", "Dataset"]: - return self._implementation(self._weighted_mean, dim=dim, skipna=skipna) + return self._implementation( + self._weighted_mean, dim=dim, skipna=skipna, keep_attrs=keep_attrs + ) def __repr__(self): """provide a nice str repr of our Weighted object""" @@ -199,22 +208,24 @@ def __repr__(self): class DataArrayWeighted(Weighted): - def _implementation(self, func, **kwargs): + def _implementation(self, func, dim, **kwargs): - return func(self.obj, **kwargs) + keep_attrs = kwargs.pop("keep_attrs") + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=False) + weighted = func(self.obj, dim=dim, **kwargs) -class DatasetWeighted(Weighted): - def _implementation(self, func, **kwargs) -> "Dataset": + if keep_attrs: + weighted.attrs = self.obj.attrs - from .dataset import Dataset + return weighted - weighted = {} - for key, da in self.obj.data_vars.items(): - weighted[key] = func(da, **kwargs) +class DatasetWeighted(Weighted): + def _implementation(self, func, dim, **kwargs) -> "Dataset": - return Dataset(weighted, coords=self.obj.coords) + return self.obj.map(func, dim=dim, **kwargs) def _inject_docstring(cls, cls_name): From ef07edd43491d23f0e1da106dc1bb640c7863c48 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Wed, 4 Dec 2019 21:19:19 +0100 Subject: [PATCH 25/52] implement expected_weighted; update tests --- xarray/tests/test_weighted.py | 176 +++++++++++++++++++++++++++------- 1 file changed, 140 insertions(+), 36 deletions(-) diff --git a/xarray/tests/test_weighted.py b/xarray/tests/test_weighted.py index 0a62db8bb7a..1b240738efc 100644 --- a/xarray/tests/test_weighted.py +++ b/xarray/tests/test_weighted.py @@ -1,25 +1,9 @@ -import pytest - import numpy as np +import pytest import xarray as xr from xarray import DataArray, Dataset -from xarray.core import weighted - -from xarray.tests import assert_equal, assert_allclose, raises_regex - - -@pytest.mark.parametrize("dims", (None, "a", ("a", "b"))) -def test_weighted_maybe_get_all_dims(dims): - - d1 = ("x", "y") - d2 = ("y", "z") - - expected = ("x", "y", "z") if dims is None else dims - - result = weighted._maybe_get_all_dims(dims, d1, d2) - - assert result == expected +from xarray.tests import assert_allclose, assert_equal, raises_regex @pytest.mark.parametrize("size", (1, 5, 100)) @@ -71,16 +55,12 @@ def test_weigted_non_DataArray_weights(): ds.weighted([1, 2]) -@pytest.mark.parametrize("weights", ([1, 2], [np.nan, 2], [np.nan, np.nan])) -def test_weighted_weights_nan_replaced(weights): - # make sure nans are removed from weights +@pytest.mark.parametrize("weights", ([np.nan, 2], [np.nan, np.nan])) +def test_weighted_weights_nan_raises(weights): + # make sure NaNs in weights raise - da = DataArray([1, 2]) - - expected = DataArray(weights).fillna(0.0) - result = da.weighted(DataArray(weights)).weights - - assert_equal(expected, result) + with pytest.raises(ValueError, match="`weights` cannot contain missing values."): + DataArray([1, 2]).weighted(DataArray(weights)) @pytest.mark.parametrize( @@ -99,7 +79,8 @@ def test_weigted_sum_of_weights_no_nan(weights, expected): @pytest.mark.parametrize( - ("weights", "expected"), (([1, 2], 2), ([0, 2], 2), ([0, 0], np.nan), ([-1, 1], 1)) + ("weights", "expected"), + (([1, 2], 2), ([0, 2], 2), ([0, 0], np.nan), ([-1, 1], 1), ([2, 0], np.nan)), ) def test_weigted_sum_of_weights_nan(weights, expected): @@ -161,13 +142,14 @@ def test_weighted_sum_nan(weights, expected, skipna): @pytest.mark.filterwarnings("ignore:Mean of empty slice") @pytest.mark.parametrize("da", ([1, 2], [1, np.nan], [np.nan, np.nan])) @pytest.mark.parametrize("skipna", (True, False)) -def test_weigted_mean_equal_weights(da, skipna): - # if all weights are equal, should yield the same result as mean +@pytest.mark.parametrize("factor", [1, 2, 3.14]) +def test_weigted_mean_equal_weights(da, skipna, factor): + # if all weights are equal (!= 0), should yield the same result as mean da = DataArray(da) # all weights as 1. - weights = xr.zeros_like(da) + 1 + weights = xr.zeros_like(da) + factor expected = da.mean(skipna=skipna) result = da.weighted(weights).mean(skipna=skipna) @@ -210,11 +192,133 @@ def test_weigted_mean_nan(weights, expected, skipna): assert_equal(expected, result) -# def expected_weighted(da, weights, operation, ): +def expected_weighted(da, weights, dim, skipna, operation): + """ operations implemented via `*` and `sum`; da.Weighted uses `dot` + """ + + weighted_sum = (da * weights).sum(dim=dim, skipna=skipna) + + if operation == "sum": + return weighted_sum + + masked_weights = weights.where(da.notnull()) + sum_of_weights = masked_weights.sum(dim=dim, skipna=True) + valid_weights = sum_of_weights != 0 + sum_of_weights = sum_of_weights.where(valid_weights) + + if operation == "sum_of_weights": + return sum_of_weights + + weighted_mean = weighted_sum / sum_of_weights + + if operation == "mean": + return weighted_mean + + +@pytest.mark.parametrize("dim", ("a", "b", "c", ("a", "b"), ("a", "b", "c"), None)) +@pytest.mark.parametrize("operator", ("sum_of_weights", "sum", "mean")) +@pytest.mark.parametrize("add_nans", (True, False)) +@pytest.mark.parametrize("skipna", (None, True, False)) +@pytest.mark.parametrize("as_dataset", (True, False)) +def test_weighted_operations_3D(dim, operator, add_nans, skipna, as_dataset): + + dims = ("a", "b", "c") + coords = dict(a=[0, 1, 2, 3], b=[0, 1, 2, 3], c=[0, 1, 2, 3]) + + weights = DataArray(np.random.randn(4, 4, 4), dims=dims, coords=coords) + + data = np.random.randn(4, 4, 4) + + # add approximately 25 % NaNs (https://stackoverflow.com/a/32182680/3010700) + if add_nans: + c = int(data.size * 0.25) + data.ravel()[np.random.choice(data.size, c, replace=False)] = np.NaN + + data = DataArray(data, dims=dims, coords=coords) + + if as_dataset: + data = data.to_dataset(name="data") + + if operator == "sum_of_weights": + result = getattr(data.weighted(weights), operator)(dim) + else: + result = getattr(data.weighted(weights), operator)(dim, skipna=skipna) + + expected = expected_weighted(data, weights, dim, skipna, operator) + + assert_allclose(expected, result) + + +@pytest.mark.parametrize("dim", ("dim_0", None)) +@pytest.mark.parametrize("shape_data", ((4,), (4, 4), (4, 4, 4))) +@pytest.mark.parametrize("shape_weights", ((4,), (4, 4), (4, 4, 4))) +@pytest.mark.parametrize("operator", ("sum_of_weights", "sum", "mean")) +@pytest.mark.parametrize("add_nans", (True, False)) +@pytest.mark.parametrize("skipna", (None, True, False)) +@pytest.mark.parametrize("as_dataset", (True, False)) +def test_weighted_operations_different_shapes( + dim, shape_data, shape_weights, operator, add_nans, skipna, as_dataset +): + + weights = DataArray(np.random.randn(*shape_weights)) + + data = np.random.randn(*shape_data) + + # add approximately 25 % NaNs + if add_nans: + c = int(data.size * 0.25) + data.ravel()[np.random.choice(data.size, c, replace=False)] = np.NaN + + data = DataArray(data) + + if as_dataset: + data = data.to_dataset(name="data") + + if operator == "sum_of_weights": + result = getattr(data.weighted(weights), operator)(dim) + else: + result = getattr(data.weighted(weights), operator)(dim, skipna=skipna) + + expected = expected_weighted(data, weights, dim, skipna, operator) + + assert_allclose(expected, result) + + +@pytest.mark.parametrize("operator", ("sum_of_weights", "sum", "mean")) +@pytest.mark.parametrize("as_dataset", (True, False)) +@pytest.mark.parametrize("keep_attrs", (True, False, None)) +def test_weighted_operations_keep_attr(operator, as_dataset, keep_attrs): + + weights = DataArray(np.random.randn(2, 2), attrs=dict(attr="weights")) + data = DataArray(np.random.randn(2, 2)) + + if as_dataset: + data = data.to_dataset(name="data") + + data.attrs = dict(attr="weights") + + result = getattr(data.weighted(weights), operator)(keep_attrs=True) + + if operator == "sum_of_weights": + assert weights.attrs == result.attrs + else: + assert data.attrs == result.attrs + + result = getattr(data.weighted(weights), operator)(keep_attrs=None) + assert not result.attrs + + result = getattr(data.weighted(weights), operator)(keep_attrs=False) + assert not result.attrs + + +@pytest.mark.xfail(reason="xr.Dataset.map does not copy attrs of DataArrays GH: 3595") +@pytest.mark.parametrize("operator", ("sum", "mean")) +def test_weighted_operations_keep_attr_da_in_ds(operator): + # GH #3595 + weights = DataArray(np.random.randn(2, 2), attrs=dict(attr="weights")) + data = DataArray(np.random.randn(4, 4, 4), attrs=dict(attr="data")) -# weighted_sum = (da * weights).sum(dim, axis=axis, skipna=skipna, -# **kwargs) + result = getattr(data.weighted(weights), operator)(keep_attrs=True) -# if operation == "sum": -# return weighted_sum + assert data.data.attrs == result.attrs From 064b5a9a43be4514a6d16641a3d8670951f539b3 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Wed, 4 Dec 2019 21:19:37 +0100 Subject: [PATCH 26/52] add whats new --- doc/whats-new.rst | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 884c3cef91c..a5d03120803 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -28,7 +28,9 @@ New Features - :py:meth:`Dataset.quantile`, :py:meth:`DataArray.quantile` and ``GroupBy.quantile`` now work with dask Variables. By `Deepak Cherian `_. - +- Weighted array reductions are now supported via the new :py:meth:`DataArray.weighted` + and :py:meth:`Dataset.weighted` methods. By `Mathias Hauser `_ + (:issue:`422`). Bug fixes ~~~~~~~~~ @@ -56,7 +58,7 @@ Internal Changes ~~~~~~~~~~~~~~~~ -- Removed internal method ``Dataset._from_vars_and_coord_names``, +- Removed internal method ``Dataset._from_vars_and_coord_names``, which was dominated by ``Dataset._construct_direct``. (:pull:`3565`) By `Maximilian Roos `_ @@ -83,8 +85,8 @@ Breaking changes New Features ~~~~~~~~~~~~ -- Added the ``sparse`` option to :py:meth:`~xarray.DataArray.unstack`, - :py:meth:`~xarray.Dataset.unstack`, :py:meth:`~xarray.DataArray.reindex`, +- Added the ``sparse`` option to :py:meth:`~xarray.DataArray.unstack`, + :py:meth:`~xarray.Dataset.unstack`, :py:meth:`~xarray.DataArray.reindex`, :py:meth:`~xarray.Dataset.reindex` (:issue:`3518`). By `Keisuke Fujii `_. - Added the ``fill_value`` option to :py:meth:`DataArray.unstack` and @@ -94,13 +96,13 @@ New Features :py:meth:`~xarray.Dataset.interpolate_na`. This controls the maximum size of the data gap that will be filled by interpolation. By `Deepak Cherian `_. - Added :py:meth:`Dataset.drop_sel` & :py:meth:`DataArray.drop_sel` for dropping labels. - :py:meth:`Dataset.drop_vars` & :py:meth:`DataArray.drop_vars` have been added for + :py:meth:`Dataset.drop_vars` & :py:meth:`DataArray.drop_vars` have been added for dropping variables (including coordinates). The existing :py:meth:`Dataset.drop` & :py:meth:`DataArray.drop` methods remain as a backward compatible option for dropping either labels or variables, but using the more specific methods is encouraged. (:pull:`3475`) By `Maximilian Roos `_ -- Added :py:meth:`Dataset.map` & :py:meth:`GroupBy.map` & :py:meth:`Resample.map` for +- Added :py:meth:`Dataset.map` & :py:meth:`GroupBy.map` & :py:meth:`Resample.map` for mapping / applying a function over each item in the collection, reflecting the widely used and least surprising name for this operation. The existing ``apply`` methods remain for backward compatibility, though using the ``map`` @@ -137,13 +139,13 @@ New Features Bug fixes ~~~~~~~~~ -- Ensure an index of type ``CFTimeIndex`` is not converted to a ``DatetimeIndex`` when +- Ensure an index of type ``CFTimeIndex`` is not converted to a ``DatetimeIndex`` when calling :py:meth:`Dataset.rename`, :py:meth:`Dataset.rename_dims` and :py:meth:`Dataset.rename_vars`. By `Mathias Hauser `_. (:issue:`3522`). - Fix a bug in :py:meth:`DataArray.set_index` in case that an existing dimension becomes a level variable of MultiIndex. (:pull:`3520`). By `Keisuke Fujii `_. - Harmonize ``_FillValue``, ``missing_value`` during encoding and decoding steps. (:pull:`3502`) - By `Anderson Banihirwe `_. + By `Anderson Banihirwe `_. - Fix regression introduced in v0.14.0 that would cause a crash if dask is installed but cloudpickle isn't (:issue:`3401`) by `Rhys Doyle `_ - Fix grouping over variables with NaNs. (:issue:`2383`, :pull:`3406`). From 72c79429b482b8177fb4f7c7ec32bf8e2ccc7fa9 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Wed, 4 Dec 2019 21:31:09 +0100 Subject: [PATCH 27/52] undo changes to whats-new --- doc/whats-new.rst | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index d55226c0eb8..f49ddd0b682 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -63,7 +63,7 @@ Internal Changes ~~~~~~~~~~~~~~~~ -- Removed internal method ``Dataset._from_vars_and_coord_names``, +- Removed internal method ``Dataset._from_vars_and_coord_names``, which was dominated by ``Dataset._construct_direct``. (:pull:`3565`) By `Maximilian Roos `_ @@ -90,8 +90,8 @@ Breaking changes New Features ~~~~~~~~~~~~ -- Added the ``sparse`` option to :py:meth:`~xarray.DataArray.unstack`, - :py:meth:`~xarray.Dataset.unstack`, :py:meth:`~xarray.DataArray.reindex`, +- Added the ``sparse`` option to :py:meth:`~xarray.DataArray.unstack`, + :py:meth:`~xarray.Dataset.unstack`, :py:meth:`~xarray.DataArray.reindex`, :py:meth:`~xarray.Dataset.reindex` (:issue:`3518`). By `Keisuke Fujii `_. - Added the ``fill_value`` option to :py:meth:`DataArray.unstack` and @@ -101,13 +101,13 @@ New Features :py:meth:`~xarray.Dataset.interpolate_na`. This controls the maximum size of the data gap that will be filled by interpolation. By `Deepak Cherian `_. - Added :py:meth:`Dataset.drop_sel` & :py:meth:`DataArray.drop_sel` for dropping labels. - :py:meth:`Dataset.drop_vars` & :py:meth:`DataArray.drop_vars` have been added for + :py:meth:`Dataset.drop_vars` & :py:meth:`DataArray.drop_vars` have been added for dropping variables (including coordinates). The existing :py:meth:`Dataset.drop` & :py:meth:`DataArray.drop` methods remain as a backward compatible option for dropping either labels or variables, but using the more specific methods is encouraged. (:pull:`3475`) By `Maximilian Roos `_ -- Added :py:meth:`Dataset.map` & :py:meth:`GroupBy.map` & :py:meth:`Resample.map` for +- Added :py:meth:`Dataset.map` & :py:meth:`GroupBy.map` & :py:meth:`Resample.map` for mapping / applying a function over each item in the collection, reflecting the widely used and least surprising name for this operation. The existing ``apply`` methods remain for backward compatibility, though using the ``map`` @@ -144,13 +144,13 @@ New Features Bug fixes ~~~~~~~~~ -- Ensure an index of type ``CFTimeIndex`` is not converted to a ``DatetimeIndex`` when +- Ensure an index of type ``CFTimeIndex`` is not converted to a ``DatetimeIndex`` when calling :py:meth:`Dataset.rename`, :py:meth:`Dataset.rename_dims` and :py:meth:`Dataset.rename_vars`. By `Mathias Hauser `_. (:issue:`3522`). - Fix a bug in :py:meth:`DataArray.set_index` in case that an existing dimension becomes a level variable of MultiIndex. (:pull:`3520`). By `Keisuke Fujii `_. - Harmonize ``_FillValue``, ``missing_value`` during encoding and decoding steps. (:pull:`3502`) - By `Anderson Banihirwe `_. + By `Anderson Banihirwe `_. - Fix regression introduced in v0.14.0 that would cause a crash if dask is installed but cloudpickle isn't (:issue:`3401`) by `Rhys Doyle `_ - Fix grouping over variables with NaNs. (:issue:`2383`, :pull:`3406`). From 0e914118cb797fd873258fe3f86a5449a60b4d74 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Wed, 4 Dec 2019 21:34:35 +0100 Subject: [PATCH 28/52] F811: noqa where? --- xarray/core/weighted.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index bca1e067d18..718de2a2406 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -70,7 +70,7 @@ def __init__(self, obj: "DataArray", weights: "DataArray") -> None: ... @overload # noqa: F811 - def __init__(self, obj: "Dataset", weights: "DataArray") -> None: + def __init__(self, obj: "Dataset", weights: "DataArray") -> None: # noqa: F811 ... def __init__(self, obj, weights): # noqa: F811 From 1eb2913f20438bdaebcbae8eff624da0e37793c9 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Thu, 5 Dec 2019 16:59:49 +0100 Subject: [PATCH 29/52] api.rst --- doc/api.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/doc/api.rst b/doc/api.rst index a1fae3deb03..771151c875e 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -161,6 +161,7 @@ Computation Dataset.groupby_bins Dataset.rolling Dataset.rolling_exp + Dataset.weighted Dataset.coarsen Dataset.resample Dataset.diff @@ -336,6 +337,7 @@ Computation DataArray.groupby_bins DataArray.rolling DataArray.rolling_exp + DataArray.weighted DataArray.coarsen DataArray.dt DataArray.resample From 118dfed3e6cd076d0fca9e6ba279d5fb5292fd14 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Thu, 5 Dec 2019 17:00:31 +0100 Subject: [PATCH 30/52] add to computation --- doc/computation.rst | 63 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/doc/computation.rst b/doc/computation.rst index 1ac30f55ee7..4667bb2b377 100644 --- a/doc/computation.rst +++ b/doc/computation.rst @@ -243,6 +243,69 @@ You can also use ``construct`` to compute a weighted rolling sum: .. _comput.coarsen: +Weighted array reductions +========================= + +``DataArray`` and ``Dataset`` objects include :py:meth:`~xarray.DataArray.weighted` +and :py:meth:`~xarray.Dataset.weighted` array reduction methods. They currently +support weighted ``sum`` and weighted ``mean``. + +.. ipython:: python + + coords = dict(month=('month', [1, 2, 3])) + + prec = xr.DataArray([1.1, 1.0, 0.9], dims=('month', ), coords=coords) + weights = xr.DataArray([31, 28, 31], dims=('month', ), coords=coords) + +Create a weighted object: + +.. ipython:: python + + weighted_prec = prec.weighted(weights) + weighted_prec + +Calculate the weighted sum: + +.. ipython:: python + + weighted_prec.sum() + +Calculate the weighted mean: + +.. ipython:: python + + weighted_prec.mean(dim="month") + +The weighted sum corresponds to: + +.. ipython:: python + + weighted_sum = (prec * weights).sum() + weighted_sum + +and the weighted mean to: + +.. ipython:: python + + weighted_mean = weighted_sum / weights.sum() + weighted_mean + +However, the functions also take missing values in the data into account: + +.. ipython:: python + + data = xr.DataArray([np.NaN, 2, 4]) + weights = xr.DataArray([8, 1, 1]) + + data.weighted(weights).mean() + +Using ``(data * weights).sum() / weights.sum()`` would (incorrectly) result +in 0.6. + +.. note:: + ``weights`` must be a ``DataArray`` and cannot contain missing values. + Missing values can be replaced manually by `weights.fillna(0)`. + Coarsen large arrays ==================== From e08c9213cc82103f25417f6f6adf9939cc2e1ddf Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Thu, 5 Dec 2019 17:03:12 +0100 Subject: [PATCH 31/52] small updates --- xarray/core/weighted.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index 718de2a2406..1c0df0aa928 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -18,8 +18,6 @@ skips missing values for float dtypes; other dtypes either do not have a sentinel missing value (int) or skipna=True has not been implemented (object, datetime64 or timedelta64). - Note: Missing values in the weights are replaced with 0 (i.e. no - weight). keep_attrs : bool, optional If True, the attributes (`attrs`) will be copied from the original object to the new one. If False (default), the new object will be @@ -33,7 +31,7 @@ """ _SUM_OF_WEIGHTS_DOCSTRING = """ - Calculate the sum of weights, accounting for missing values + Calculate the sum of weights, accounting for missing values in the data Parameters ---------- @@ -88,8 +86,8 @@ def __init__(self, obj, weights): # noqa: F811 Note ---- - Weights can not contain missing values. - + ``weights`` must be a ``DataArray`` and cannot contain missing values. + Missing values can be replaced by `weights.fillna(0)`. """ from .dataarray import DataArray @@ -136,12 +134,11 @@ def _weighted_sum( if dim is None: dim = ... - # use `dot` to avoid creating large DataArrays - - # need to mask invalid DATA as dot does not implement skipna + # need to mask invalid DATA as `dot` does not implement skipna if skipna or (skipna is None and da.dtype.kind in "cfO"): - return dot(da.fillna(0.0), self.weights, dims=dim) + da = da.fillna(0.0) + # use `dot` to avoid creating large DataArrays return dot(da, self.weights, dims=dim) def _weighted_mean( @@ -163,7 +160,7 @@ def _weighted_mean( def _implementation(self, func, dim, **kwargs): - msg = "Use 'Dataset.weighted' or 'DataArray.weighted'" + msg = "Use `Dataset.weighted` or `DataArray.weighted`" raise NotImplementedError(msg) def sum_of_weights( From 0fafe0b71cbd52149c774279d642dbf02cfbb565 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Thu, 5 Dec 2019 17:04:16 +0100 Subject: [PATCH 32/52] add example to gallery --- doc/gallery/area_weighted_temperature.py | 43 ++++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 doc/gallery/area_weighted_temperature.py diff --git a/doc/gallery/area_weighted_temperature.py b/doc/gallery/area_weighted_temperature.py new file mode 100644 index 00000000000..d0de6d0aa45 --- /dev/null +++ b/doc/gallery/area_weighted_temperature.py @@ -0,0 +1,43 @@ +""" +================================================ +Compare weighted and unweighted mean temperature +================================================ + + +Use ``air.weighted(weights).mean()`` to calculate the area weighted temperature +for the air_temperature example dataset. This dataset has a regular latitude/ longitude +grid, thus the gridcell area decreases towards the pole. For this grid we can use the +cosine of the latitude as proxy for the grid cell area. Note how the weighted mean +temperature is higher than the unweighted, because high. + + +""" + +import matplotlib.pyplot as plt +import numpy as np + +import xarray as xr + +# Load the data +ds = xr.tutorial.load_dataset("air_temperature") +air = ds.air - 273.15 # to celsius + +# resample from 6-hourly to daily values +air = air.resample(time="D").mean() + +# the cosine of the latitude is proportional to the grid cell area (for a rectangular grid) +weights = np.cos(np.deg2rad(air.lat)) + +mean_air = air.weighted(weights).mean(("lat", "lon")) + +# Prepare the figure +f, ax = plt.subplots(1, 1) + +mean_air.plot(label="Area weighted mean") +air.mean(("lat", "lon")).plot(label="Unweighted mean") + +ax.legend() + +# Show +plt.tight_layout() +plt.show() From a8d330da0ddc69c96b3fc1aaaf8d373959645054 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Thu, 5 Dec 2019 17:08:14 +0100 Subject: [PATCH 33/52] typo --- doc/gallery/area_weighted_temperature.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/gallery/area_weighted_temperature.py b/doc/gallery/area_weighted_temperature.py index d0de6d0aa45..b9a24c93e1b 100644 --- a/doc/gallery/area_weighted_temperature.py +++ b/doc/gallery/area_weighted_temperature.py @@ -8,7 +8,7 @@ for the air_temperature example dataset. This dataset has a regular latitude/ longitude grid, thus the gridcell area decreases towards the pole. For this grid we can use the cosine of the latitude as proxy for the grid cell area. Note how the weighted mean -temperature is higher than the unweighted, because high. +temperature is higher than the unweighted mean. """ From ae0012fad6e54e7af521bc8ccdbb00b6e2e5cdaa Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Thu, 5 Dec 2019 17:09:18 +0100 Subject: [PATCH 34/52] another typo --- doc/gallery/area_weighted_temperature.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/gallery/area_weighted_temperature.py b/doc/gallery/area_weighted_temperature.py index b9a24c93e1b..d659574034c 100644 --- a/doc/gallery/area_weighted_temperature.py +++ b/doc/gallery/area_weighted_temperature.py @@ -8,7 +8,7 @@ for the air_temperature example dataset. This dataset has a regular latitude/ longitude grid, thus the gridcell area decreases towards the pole. For this grid we can use the cosine of the latitude as proxy for the grid cell area. Note how the weighted mean -temperature is higher than the unweighted mean. +temperature is higher than the unweighted. """ From 111259b32e626d2ef06199262e9f77f66b1f8603 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Thu, 5 Dec 2019 17:12:07 +0100 Subject: [PATCH 35/52] correct docstring in core/common.py --- xarray/core/common.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index b68f861134e..9f256f59210 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -748,7 +748,8 @@ def weighted(self, weights): according to its associated weight. Note ---- - Missing values in the weights are treated as 0 (i.e. no weight). + ``weights`` must be a ``DataArray`` and cannot contain missing values. + Missing values can be replaced by `weights.fillna(0)`. """ return self._weighted_cls(self, weights) From 668b54b43bc67cdbf38463914db09d25bb47e62c Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Tue, 14 Jan 2020 18:54:42 +0100 Subject: [PATCH 36/52] typos --- doc/gallery/area_weighted_temperature.py | 2 +- doc/whats-new.rst | 2 +- xarray/core/common.py | 4 +++- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/doc/gallery/area_weighted_temperature.py b/doc/gallery/area_weighted_temperature.py index d659574034c..3400b9203fb 100644 --- a/doc/gallery/area_weighted_temperature.py +++ b/doc/gallery/area_weighted_temperature.py @@ -4,7 +4,7 @@ ================================================ -Use ``air.weighted(weights).mean()`` to calculate the area weighted temperature +Use ``air.weighted(weights).mean()`` to calculate the area-weighted temperature for the air_temperature example dataset. This dataset has a regular latitude/ longitude grid, thus the gridcell area decreases towards the pole. For this grid we can use the cosine of the latitude as proxy for the grid cell area. Note how the weighted mean diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 3e7daf66e88..2adfdab8a3e 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -34,7 +34,7 @@ New Features - :py:meth:`Dataset.quantile`, :py:meth:`DataArray.quantile` and ``GroupBy.quantile`` now work with dask Variables. By `Deepak Cherian `_. -- Weighted array reductions are now supported via the new :py:meth:`DataArray.weighted` +- Weighted array reductions are now supported via the new :py:meth:`DataArray.weighted` and :py:meth:`Dataset.weighted` methods. By `Mathias Hauser `_ (:issue:`422`). - Added the :py:meth:`count` reduction method to both :py:class:`DatasetCoarsen` diff --git a/xarray/core/common.py b/xarray/core/common.py index 9df0e4dd585..17831fee046 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -740,16 +740,18 @@ def groupby_bins( def weighted(self, weights): """ Weighted operations. + Parameters ---------- weights : DataArray An array of weights associated with the values in this Dataset. Each value in the data contributes to the reduction operation according to its associated weight. + Note ---- ``weights`` must be a ``DataArray`` and cannot contain missing values. - Missing values can be replaced by `weights.fillna(0)`. + Missing values can be replaced by ``weights.fillna(0)``. """ return self._weighted_cls(self, weights) From d8770226b3f143d81bfdd71a82b9554205eda3ec Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Tue, 14 Jan 2020 18:57:13 +0100 Subject: [PATCH 37/52] adjust review --- xarray/core/weighted.py | 34 ++++++++++++++++++---------------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index 1c0df0aa928..80721ffcafb 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -7,26 +7,26 @@ from .dataarray import DataArray, Dataset _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE = """ - Reduce this {cls}'s data by a weighted `{fcn}` along some dimension(s). + Reduce this {cls}'s data by a weighted ``{fcn}`` along some dimension(s). Parameters ---------- dim : str or sequence of str, optional - Dimension(s) over which to apply the weighted `{fcn}`. + Dimension(s) over which to apply the weighted ``{fcn}``. skipna : bool, optional If True, skip missing values (as marked by NaN). By default, only skips missing values for float dtypes; other dtypes either do not have a sentinel missing value (int) or skipna=True has not been implemented (object, datetime64 or timedelta64). keep_attrs : bool, optional - If True, the attributes (`attrs`) will be copied from the original + If True, the attributes (``attrs``) will be copied from the original object to the new one. If False (default), the new object will be returned without attributes. Returns ------- reduced : {cls} - New {cls} object with weighted `{fcn}` applied to its data and + New {cls} object with weighted ``{fcn}`` applied to its data and the indicated dimension(s) removed. """ @@ -38,7 +38,7 @@ dim : str or sequence of str, optional Dimension(s) over which to sum the weights. keep_attrs : bool, optional - If True, the attributes (`attrs`) will be copied from the original + If True, the attributes (``attrs``) will be copied from the original object to the new one. If False (default), the new object will be returned without attributes. @@ -52,8 +52,8 @@ class Weighted: """A object that implements weighted operations. - You should create a Weighted object by using the `DataArray.weighted` or - `Dataset.weighted` methods. + You should create a Weighted object by using the ``DataArray.weighted`` or + ``Dataset.weighted`` methods. See Also -------- @@ -87,18 +87,21 @@ def __init__(self, obj, weights): # noqa: F811 Note ---- ``weights`` must be a ``DataArray`` and cannot contain missing values. - Missing values can be replaced by `weights.fillna(0)`. + Missing values can be replaced by ``weights.fillna(0)``. """ from .dataarray import DataArray - assert isinstance(weights, DataArray), "'weights' must be a DataArray" - - self.obj = obj + if not isinstance(weights, DataArray): + raise ValueError("`weights` must be a DataArray") if weights.isnull().any(): - raise ValueError("`weights` cannot contain missing values.") + raise ValueError( + "`weights` cannot contain missing values. " + "Missing values can be replaced by `weights.fillna(0)`." + ) + self.obj = obj self.weights = weights def _sum_of_weights( @@ -128,7 +131,7 @@ def _weighted_sum( dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, skipna: Optional[bool] = None, ) -> "DataArray": - """Reduce a DataArray by a by a weighted `sum` along some dimension(s).""" + """Reduce a DataArray by a by a weighted ``sum`` along some dimension(s).""" # need to infer dims as we use `dot` if dim is None: @@ -147,7 +150,7 @@ def _weighted_mean( dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, skipna: Optional[bool] = None, ) -> "DataArray": - """Reduce a DataArray by a weighted `mean` along some dimension(s).""" + """Reduce a DataArray by a weighted ``mean`` along some dimension(s).""" # get weighted sum weighted_sum = self._weighted_sum(da, dim=dim, skipna=skipna) @@ -160,8 +163,7 @@ def _weighted_mean( def _implementation(self, func, dim, **kwargs): - msg = "Use `Dataset.weighted` or `DataArray.weighted`" - raise NotImplementedError(msg) + raise NotImplementedError("Use `Dataset.weighted` or `DataArray.weighted`") def sum_of_weights( self, From ead681e12e2b415f2a8e21b37c030cd3f09295cc Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Tue, 14 Jan 2020 19:08:06 +0100 Subject: [PATCH 38/52] clean tests --- xarray/tests/test_weighted.py | 87 +++++++++-------------------------- 1 file changed, 23 insertions(+), 64 deletions(-) diff --git a/xarray/tests/test_weighted.py b/xarray/tests/test_weighted.py index 1b240738efc..885160d609c 100644 --- a/xarray/tests/test_weighted.py +++ b/xarray/tests/test_weighted.py @@ -6,68 +6,26 @@ from xarray.tests import assert_allclose, assert_equal, raises_regex -@pytest.mark.parametrize("size", (1, 5, 100)) -def test_weighted__sum_of_weights_1D(size): +@pytest.mark.parametrize("data", (DataArray([1, 2]), Dataset(dict(data=[1, 2])))) +def test_weighted_non_DataArray_weights(data): - data = np.zeros(size) - # make sure weights is not 0 - weights = np.arange(1, size + 1) - - da = DataArray(data) - weights = DataArray(weights) - - expected = weights.sum() - - result = da.weighted(weights).sum_of_weights() - - assert_equal(expected, result) - - -@pytest.mark.parametrize("shape", ((2, 2), (2, 5), (10, 10))) -@pytest.mark.parametrize("dim", (None, "dim_0", "dim_1", ("dim_0", "dim_1"))) -def test_weighted__sum_of_weights_2D(shape, dim): - - np.random.seed(0) - - data = np.zeros(shape) - # make sure all weights are positive to avoid summing to 0 - weights = np.abs(np.random.randn(*shape)) - - da = DataArray(data) - weights = DataArray(weights) - - weighted = da.weighted(weights) - - expected = weights.sum(dim=dim) - result = weighted.sum_of_weights(dim=dim) - - assert_allclose(expected, result) - - -def test_weigted_non_DataArray_weights(): - - da = DataArray([1, 2]) - with raises_regex(AssertionError, "'weights' must be a DataArray"): - da.weighted([1, 2]) - - ds = Dataset(dict(data=[1, 2])) - with raises_regex(AssertionError, "'weights' must be a DataArray"): - ds.weighted([1, 2]) + with raises_regex(ValueError, "`weights` must be a DataArray"): + data.weighted([1, 2]) +@pytest.mark.parametrize("data", (DataArray([1, 2]), Dataset(dict(data=[1, 2])))) @pytest.mark.parametrize("weights", ([np.nan, 2], [np.nan, np.nan])) -def test_weighted_weights_nan_raises(weights): - # make sure NaNs in weights raise +def test_weighted_weights_nan_raises(data, weights): with pytest.raises(ValueError, match="`weights` cannot contain missing values."): - DataArray([1, 2]).weighted(DataArray(weights)) + data.weighted(DataArray(weights)) @pytest.mark.parametrize( ("weights", "expected"), - (([1, 2], 3), ([0, 2], 2), ([0, 0], np.nan), ([-1, 1], np.nan)), + (([1, 2], 3), ([2, 0], 2), ([0, 0], np.nan), ([-1, 1], np.nan)), ) -def test_weigted_sum_of_weights_no_nan(weights, expected): +def test_weighted_sum_of_weights_no_nan(weights, expected): da = DataArray([1, 2]) weights = DataArray(weights) @@ -80,9 +38,9 @@ def test_weigted_sum_of_weights_no_nan(weights, expected): @pytest.mark.parametrize( ("weights", "expected"), - (([1, 2], 2), ([0, 2], 2), ([0, 0], np.nan), ([-1, 1], 1), ([2, 0], np.nan)), + (([1, 2], 2), ([2, 0], np.nan), ([0, 0], np.nan), ([-1, 1], 1)), ) -def test_weigted_sum_of_weights_nan(weights, expected): +def test_weighted_sum_of_weights_nan(weights, expected): da = DataArray([np.nan, 2]) weights = DataArray(weights) @@ -93,14 +51,14 @@ def test_weigted_sum_of_weights_nan(weights, expected): assert_equal(expected, result) -@pytest.mark.parametrize("da", ([1, 2], [1, np.nan], [np.nan, np.nan])) -@pytest.mark.parametrize("factor", [0, 1, 2, 3.14]) +@pytest.mark.parametrize("da", ([1.0, 2], [1, np.nan], [np.nan, np.nan])) +@pytest.mark.parametrize("factor", [0, 1, 3.14]) @pytest.mark.parametrize("skipna", (True, False)) def test_weighted_sum_equal_weights(da, factor, skipna): # if all weights are 'f'; weighted sum is f times the ordinary sum da = DataArray(da) - weights = xr.zeros_like(da) + factor + weights = xr.full_like(da, factor) expected = da.sum(skipna=skipna) * factor result = da.weighted(weights).sum(skipna=skipna) @@ -140,16 +98,16 @@ def test_weighted_sum_nan(weights, expected, skipna): @pytest.mark.filterwarnings("ignore:Mean of empty slice") -@pytest.mark.parametrize("da", ([1, 2], [1, np.nan], [np.nan, np.nan])) +@pytest.mark.parametrize("da", ([1.0, 2], [1, np.nan], [np.nan, np.nan])) @pytest.mark.parametrize("skipna", (True, False)) @pytest.mark.parametrize("factor", [1, 2, 3.14]) -def test_weigted_mean_equal_weights(da, skipna, factor): +def test_weighted_mean_equal_weights(da, skipna, factor): # if all weights are equal (!= 0), should yield the same result as mean da = DataArray(da) # all weights as 1. - weights = xr.zeros_like(da) + factor + weights = xr.full_like(da, factor) expected = da.mean(skipna=skipna) result = da.weighted(weights).mean(skipna=skipna) @@ -177,7 +135,7 @@ def test_weighted_mean_no_nan(weights, expected): (([4, 6], 2.0), ([0, 1], 2.0), ([0, 2], 2.0), ([0, 0], np.nan)), ) @pytest.mark.parametrize("skipna", (True, False)) -def test_weigted_mean_nan(weights, expected, skipna): +def test_weighted_mean_nan(weights, expected, skipna): da = DataArray([np.nan, 2]) weights = DataArray(weights) @@ -193,7 +151,7 @@ def test_weigted_mean_nan(weights, expected, skipna): def expected_weighted(da, weights, dim, skipna, operation): - """ operations implemented via `*` and `sum`; da.Weighted uses `dot` + """ operations implemented via ``*`` and ``sum``; da.Weighted uses ``dot`` """ weighted_sum = (da * weights).sum(dim=dim, skipna=skipna) @@ -316,9 +274,10 @@ def test_weighted_operations_keep_attr(operator, as_dataset, keep_attrs): def test_weighted_operations_keep_attr_da_in_ds(operator): # GH #3595 - weights = DataArray(np.random.randn(2, 2), attrs=dict(attr="weights")) - data = DataArray(np.random.randn(4, 4, 4), attrs=dict(attr="data")) + weights = DataArray(np.random.randn(2, 2)) + data = DataArray(np.random.randn(2, 2), attrs=dict(attr="data")) + data = data.to_dataset(name="a") result = getattr(data.weighted(weights), operator)(keep_attrs=True) - assert data.data.attrs == result.attrs + assert data.a.attrs == result.a.attrs From c4598bafe092c9225b58ac450e7edbc14c036778 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Tue, 14 Jan 2020 19:11:16 +0100 Subject: [PATCH 39/52] add test nonequal coords --- xarray/tests/test_weighted.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/xarray/tests/test_weighted.py b/xarray/tests/test_weighted.py index 885160d609c..90f96adb402 100644 --- a/xarray/tests/test_weighted.py +++ b/xarray/tests/test_weighted.py @@ -207,6 +207,23 @@ def test_weighted_operations_3D(dim, operator, add_nans, skipna, as_dataset): assert_allclose(expected, result) +@pytest.mark.xfail(reason="GH: 3694") +@pytest.mark.parametrize("operator", ("sum_of_weights", "sum", "mean")) +@pytest.mark.parametrize("as_dataset", (True, False)) +def test_weighted_operations_nonequal_coords(operator, as_dataset): + + weights = DataArray(np.random.randn(4), dims=("a",), coords=dict(a=[0, 1, 2, 3])) + data = DataArray(np.random.randn(4), dims=("a",), coords=dict(a=[1, 2, 3, 4])) + + if as_dataset: + data = data.to_dataset(name="data") + + expected = expected_weighted(data, weights, dim="a", skipna=None, operator=operator) + result = getattr(data.weighted(weights), operator)(dim="a") + + assert_allclose(expected, result) + + @pytest.mark.parametrize("dim", ("dim_0", None)) @pytest.mark.parametrize("shape_data", ((4,), (4, 4), (4, 4, 4))) @pytest.mark.parametrize("shape_weights", ((4,), (4, 4), (4, 4, 4))) From 866fba54f8ca621265c1ef82339814a5ffd7b7c7 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Tue, 14 Jan 2020 19:31:09 +0100 Subject: [PATCH 40/52] comment on use of dot --- xarray/core/weighted.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index 80721ffcafb..d28b25f9978 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -116,7 +116,9 @@ def _sum_of_weights( if dim is None: dim = ... - # use `dot` to avoid creating large DataArrays (if da and weights do not share all dims) + # `dot` does not broadcast arrays, so this avoids creating a large + # DataArray (if `weights` has additional dimensions) + # TODO: maybe add fasttrack (`(mask * weights).sum(dims=dim, skipna=skipna)`) sum_of_weights = dot(mask, self.weights, dims=dim) # find all weights that are valid (not 0) @@ -141,7 +143,9 @@ def _weighted_sum( if skipna or (skipna is None and da.dtype.kind in "cfO"): da = da.fillna(0.0) - # use `dot` to avoid creating large DataArrays + # `dot` does not broadcast arrays, so this avoids creating a large + # DataArray (if `weights` has additional dimensions) + # TODO: maybe add fasttrack (`(da * weights).sum(dims=dim, skipna=skipna)`) return dot(da, self.weights, dims=dim) def _weighted_mean( From 3cc00c121931b5dc0247b4ea5809c4051a4dacdd Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Tue, 14 Jan 2020 19:55:20 +0100 Subject: [PATCH 41/52] fix erroneous merge --- doc/whats-new.rst | 3 --- 1 file changed, 3 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 2adfdab8a3e..7a6ce4f00b0 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -37,9 +37,6 @@ New Features - Weighted array reductions are now supported via the new :py:meth:`DataArray.weighted` and :py:meth:`Dataset.weighted` methods. By `Mathias Hauser `_ (:issue:`422`). -- Added the :py:meth:`count` reduction method to both :py:class:`DatasetCoarsen` - and :py:class:`DataArrayCoarsen` objects. (:pull:`3500`) - By `Deepak Cherian `_ - Added the ``count`` reduction method to both :py:class:`~core.rolling.DatasetCoarsen` and :py:class:`~core.rolling.DataArrayCoarsen` objects. (:pull:`3500`) By `Deepak Cherian `_ From 9f0a8cd2239086ecc6d237daeac7acd139e1b964 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Tue, 21 Jan 2020 15:47:10 +0100 Subject: [PATCH 42/52] update tests --- xarray/tests/test_weighted.py | 53 ++++++++++++++++++----------------- 1 file changed, 27 insertions(+), 26 deletions(-) diff --git a/xarray/tests/test_weighted.py b/xarray/tests/test_weighted.py index 90f96adb402..c279be43d14 100644 --- a/xarray/tests/test_weighted.py +++ b/xarray/tests/test_weighted.py @@ -174,11 +174,11 @@ def expected_weighted(da, weights, dim, skipna, operation): @pytest.mark.parametrize("dim", ("a", "b", "c", ("a", "b"), ("a", "b", "c"), None)) -@pytest.mark.parametrize("operator", ("sum_of_weights", "sum", "mean")) +@pytest.mark.parametrize("operation", ("sum_of_weights", "sum", "mean")) @pytest.mark.parametrize("add_nans", (True, False)) @pytest.mark.parametrize("skipna", (None, True, False)) @pytest.mark.parametrize("as_dataset", (True, False)) -def test_weighted_operations_3D(dim, operator, add_nans, skipna, as_dataset): +def test_weighted_operations_3D(dim, operation, add_nans, skipna, as_dataset): dims = ("a", "b", "c") coords = dict(a=[0, 1, 2, 3], b=[0, 1, 2, 3], c=[0, 1, 2, 3]) @@ -197,20 +197,19 @@ def test_weighted_operations_3D(dim, operator, add_nans, skipna, as_dataset): if as_dataset: data = data.to_dataset(name="data") - if operator == "sum_of_weights": - result = getattr(data.weighted(weights), operator)(dim) + if operation == "sum_of_weights": + result = getattr(data.weighted(weights), operation)(dim) else: - result = getattr(data.weighted(weights), operator)(dim, skipna=skipna) + result = getattr(data.weighted(weights), operation)(dim, skipna=skipna) - expected = expected_weighted(data, weights, dim, skipna, operator) + expected = expected_weighted(data, weights, dim, skipna, operation) assert_allclose(expected, result) -@pytest.mark.xfail(reason="GH: 3694") -@pytest.mark.parametrize("operator", ("sum_of_weights", "sum", "mean")) +@pytest.mark.parametrize("operation", ("sum_of_weights", "sum", "mean")) @pytest.mark.parametrize("as_dataset", (True, False)) -def test_weighted_operations_nonequal_coords(operator, as_dataset): +def test_weighted_operations_nonequal_coords(operation, as_dataset): weights = DataArray(np.random.randn(4), dims=("a",), coords=dict(a=[0, 1, 2, 3])) data = DataArray(np.random.randn(4), dims=("a",), coords=dict(a=[1, 2, 3, 4])) @@ -218,8 +217,10 @@ def test_weighted_operations_nonequal_coords(operator, as_dataset): if as_dataset: data = data.to_dataset(name="data") - expected = expected_weighted(data, weights, dim="a", skipna=None, operator=operator) - result = getattr(data.weighted(weights), operator)(dim="a") + expected = expected_weighted( + data, weights, dim="a", skipna=None, operation=operation + ) + result = getattr(data.weighted(weights), operation)(dim="a") assert_allclose(expected, result) @@ -227,12 +228,12 @@ def test_weighted_operations_nonequal_coords(operator, as_dataset): @pytest.mark.parametrize("dim", ("dim_0", None)) @pytest.mark.parametrize("shape_data", ((4,), (4, 4), (4, 4, 4))) @pytest.mark.parametrize("shape_weights", ((4,), (4, 4), (4, 4, 4))) -@pytest.mark.parametrize("operator", ("sum_of_weights", "sum", "mean")) +@pytest.mark.parametrize("operation", ("sum_of_weights", "sum", "mean")) @pytest.mark.parametrize("add_nans", (True, False)) @pytest.mark.parametrize("skipna", (None, True, False)) @pytest.mark.parametrize("as_dataset", (True, False)) def test_weighted_operations_different_shapes( - dim, shape_data, shape_weights, operator, add_nans, skipna, as_dataset + dim, shape_data, shape_weights, operation, add_nans, skipna, as_dataset ): weights = DataArray(np.random.randn(*shape_weights)) @@ -249,20 +250,20 @@ def test_weighted_operations_different_shapes( if as_dataset: data = data.to_dataset(name="data") - if operator == "sum_of_weights": - result = getattr(data.weighted(weights), operator)(dim) + if operation == "sum_of_weights": + result = getattr(data.weighted(weights), operation)(dim) else: - result = getattr(data.weighted(weights), operator)(dim, skipna=skipna) + result = getattr(data.weighted(weights), operation)(dim, skipna=skipna) - expected = expected_weighted(data, weights, dim, skipna, operator) + expected = expected_weighted(data, weights, dim, skipna, operation) assert_allclose(expected, result) -@pytest.mark.parametrize("operator", ("sum_of_weights", "sum", "mean")) +@pytest.mark.parametrize("operation", ("sum_of_weights", "sum", "mean")) @pytest.mark.parametrize("as_dataset", (True, False)) @pytest.mark.parametrize("keep_attrs", (True, False, None)) -def test_weighted_operations_keep_attr(operator, as_dataset, keep_attrs): +def test_weighted_operations_keep_attr(operation, as_dataset, keep_attrs): weights = DataArray(np.random.randn(2, 2), attrs=dict(attr="weights")) data = DataArray(np.random.randn(2, 2)) @@ -272,29 +273,29 @@ def test_weighted_operations_keep_attr(operator, as_dataset, keep_attrs): data.attrs = dict(attr="weights") - result = getattr(data.weighted(weights), operator)(keep_attrs=True) + result = getattr(data.weighted(weights), operation)(keep_attrs=True) - if operator == "sum_of_weights": + if operation == "sum_of_weights": assert weights.attrs == result.attrs else: assert data.attrs == result.attrs - result = getattr(data.weighted(weights), operator)(keep_attrs=None) + result = getattr(data.weighted(weights), operation)(keep_attrs=None) assert not result.attrs - result = getattr(data.weighted(weights), operator)(keep_attrs=False) + result = getattr(data.weighted(weights), operation)(keep_attrs=False) assert not result.attrs @pytest.mark.xfail(reason="xr.Dataset.map does not copy attrs of DataArrays GH: 3595") -@pytest.mark.parametrize("operator", ("sum", "mean")) -def test_weighted_operations_keep_attr_da_in_ds(operator): +@pytest.mark.parametrize("operation", ("sum", "mean")) +def test_weighted_operations_keep_attr_da_in_ds(operation): # GH #3595 weights = DataArray(np.random.randn(2, 2)) data = DataArray(np.random.randn(2, 2), attrs=dict(attr="data")) data = data.to_dataset(name="a") - result = getattr(data.weighted(weights), operator)(keep_attrs=True) + result = getattr(data.weighted(weights), operation)(keep_attrs=True) assert data.a.attrs == result.a.attrs From 62c43e6bf3e8785e98f3d77330f8427d34b66a3a Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Thu, 5 Mar 2020 16:35:56 +0100 Subject: [PATCH 43/52] move example to notebook --- doc/examples.rst | 1 + doc/examples/area_weighted_temperature.ipynb | 163 +++++++++++++++++++ doc/gallery/area_weighted_temperature.py | 43 ----- 3 files changed, 164 insertions(+), 43 deletions(-) create mode 100644 doc/examples/area_weighted_temperature.ipynb delete mode 100644 doc/gallery/area_weighted_temperature.py diff --git a/doc/examples.rst b/doc/examples.rst index 805395808e0..1d48d29bcc5 100644 --- a/doc/examples.rst +++ b/doc/examples.rst @@ -6,6 +6,7 @@ Examples examples/weather-data examples/monthly-means + examples/area_weighted_temperature examples/multidimensional-coords examples/visualization_gallery examples/ROMS_ocean_model diff --git a/doc/examples/area_weighted_temperature.ipynb b/doc/examples/area_weighted_temperature.ipynb new file mode 100644 index 00000000000..656e9ed7dbc --- /dev/null +++ b/doc/examples/area_weighted_temperature.ipynb @@ -0,0 +1,163 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Compare weighted and unweighted mean temperature\n", + "\n", + "\n", + "Author: [Mathias Hauser](https://github.com/mathause/)\n", + "\n", + "The data used for this example can be found in the [xarray-data](https://github.com/pydata/xarray-data) repository. You may need to change the path to `air_temperature` below.\n", + "\n", + "We use the air_temperature example dataset to calculate the area-weighted temperature over its domain. This dataset has a regular latitude/ longitude grid, thus the gridcell area decreases towards the pole. For this grid we can use the cosine of the latitude as proxy for the grid cell area.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "\n", + "import cartopy.crs as ccrs\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "\n", + "import xarray as xr" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Data\n", + "\n", + "Load the data, convert to celsius, and resample to daily values" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ds = xr.tutorial.load_dataset(\"air_temperature\")\n", + "\n", + "# to celsius\n", + "air = ds.air - 273.15\n", + "\n", + "# resample from 6-hourly to daily values\n", + "air = air.resample(time=\"D\").mean()\n", + "\n", + "air" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Plot the first timestep:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "projection = ccrs.LambertConformal(central_longitude=-95, central_latitude=45)\n", + "\n", + "f, ax = plt.subplots(subplot_kw=dict(projection=projection))\n", + "\n", + "air.isel(time=0).plot(transform=ccrs.PlateCarree(), cbar_kwargs=dict(shrink=0.7))\n", + "ax.coastlines()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Creating weights\n", + "\n", + "For a for a rectangular grid the cosine of the latitude is proportional to the grid cell area." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "weights = np.cos(np.deg2rad(air.lat))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Weighted mean" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "air_weighted = air.weighted(weights).mean((\"lon\", \"lat\"))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Plot: comparison with unweighted mean\n", + "\n", + "Note how the weighted mean temperature is higher than the unweighted." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "air_weighted.plot(label=\"weighted\")\n", + "air.mean((\"lon\", \"lat\")).plot(label=\"unweighted\")\n", + "\n", + "plt.legend()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/doc/gallery/area_weighted_temperature.py b/doc/gallery/area_weighted_temperature.py deleted file mode 100644 index 3400b9203fb..00000000000 --- a/doc/gallery/area_weighted_temperature.py +++ /dev/null @@ -1,43 +0,0 @@ -""" -================================================ -Compare weighted and unweighted mean temperature -================================================ - - -Use ``air.weighted(weights).mean()`` to calculate the area-weighted temperature -for the air_temperature example dataset. This dataset has a regular latitude/ longitude -grid, thus the gridcell area decreases towards the pole. For this grid we can use the -cosine of the latitude as proxy for the grid cell area. Note how the weighted mean -temperature is higher than the unweighted. - - -""" - -import matplotlib.pyplot as plt -import numpy as np - -import xarray as xr - -# Load the data -ds = xr.tutorial.load_dataset("air_temperature") -air = ds.air - 273.15 # to celsius - -# resample from 6-hourly to daily values -air = air.resample(time="D").mean() - -# the cosine of the latitude is proportional to the grid cell area (for a rectangular grid) -weights = np.cos(np.deg2rad(air.lat)) - -mean_air = air.weighted(weights).mean(("lat", "lon")) - -# Prepare the figure -f, ax = plt.subplots(1, 1) - -mean_air.plot(label="Area weighted mean") -air.mean(("lat", "lon")).plot(label="Unweighted mean") - -ax.legend() - -# Show -plt.tight_layout() -plt.show() From 2e8aba2712ad7d146f2d907b158aee244b07369c Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Thu, 5 Mar 2020 16:36:23 +0100 Subject: [PATCH 44/52] move whats-new entry to 15.1 --- doc/whats-new.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 6e5203c4a90..a9f69824122 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -25,6 +25,9 @@ Breaking changes New Features ~~~~~~~~~~~~ +- Weighted array reductions are now supported via the new :py:meth:`DataArray.weighted` + and :py:meth:`Dataset.weighted` methods. By `Mathias Hauser `_ + (:issue:`422`). - Added support for :py:class:`pandas.DatetimeIndex`-style rounding of ``cftime.datetime`` objects directly via a :py:class:`CFTimeIndex` or via the :py:class:`~core.accessor_dt.DatetimeAccessor`. @@ -137,9 +140,6 @@ New Features - :py:meth:`Dataset.quantile`, :py:meth:`DataArray.quantile` and ``GroupBy.quantile`` now work with dask Variables. By `Deepak Cherian `_. -- Weighted array reductions are now supported via the new :py:meth:`DataArray.weighted` - and :py:meth:`Dataset.weighted` methods. By `Mathias Hauser `_ - (:issue:`422`). - Added the ``count`` reduction method to both :py:class:`~core.rolling.DatasetCoarsen` and :py:class:`~core.rolling.DataArrayCoarsen` objects. (:pull:`3500`) By `Deepak Cherian `_ From d14f668c6017efb96a73af7bfcad103513088987 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Thu, 5 Mar 2020 16:59:23 +0100 Subject: [PATCH 45/52] some doc updates --- xarray/core/weighted.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index d28b25f9978..60d7ce38aec 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -50,7 +50,7 @@ class Weighted: - """A object that implements weighted operations. + """An object that implements weighted operations. You should create a Weighted object by using the ``DataArray.weighted`` or ``Dataset.weighted`` methods. @@ -118,13 +118,12 @@ def _sum_of_weights( # `dot` does not broadcast arrays, so this avoids creating a large # DataArray (if `weights` has additional dimensions) - # TODO: maybe add fasttrack (`(mask * weights).sum(dims=dim, skipna=skipna)`) + # TODO: add fasttrack (`(mask * weights).sum(dims=dim, skipna=skipna)`) sum_of_weights = dot(mask, self.weights, dims=dim) - # find all weights that are valid (not 0) + # 0-weights are not valid valid_weights = sum_of_weights != 0.0 - # set invalid weights to nan return sum_of_weights.where(valid_weights) def _weighted_sum( @@ -145,7 +144,7 @@ def _weighted_sum( # `dot` does not broadcast arrays, so this avoids creating a large # DataArray (if `weights` has additional dimensions) - # TODO: maybe add fasttrack (`(da * weights).sum(dims=dim, skipna=skipna)`) + # maybe add fasttrack (`(da * weights).sum(dims=dim, skipna=skipna)`) return dot(da, self.weights, dims=dim) def _weighted_mean( @@ -156,13 +155,10 @@ def _weighted_mean( ) -> "DataArray": """Reduce a DataArray by a weighted ``mean`` along some dimension(s).""" - # get weighted sum weighted_sum = self._weighted_sum(da, dim=dim, skipna=skipna) - # get the sum of weights sum_of_weights = self._sum_of_weights(da, dim=dim) - # calculate weighted mean return weighted_sum / sum_of_weights def _implementation(self, func, dim, **kwargs): From 7fa78ae218160532a5c03205757bdf47255510c7 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Thu, 5 Mar 2020 17:34:01 +0100 Subject: [PATCH 46/52] dot to own function --- xarray/core/weighted.py | 45 ++++++++++++++++++++++++----------------- 1 file changed, 26 insertions(+), 19 deletions(-) diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index 60d7ce38aec..af0b17ed753 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -104,22 +104,40 @@ def __init__(self, obj, weights): # noqa: F811 self.obj = obj self.weights = weights - def _sum_of_weights( - self, da: "DataArray", dim: Optional[Union[Hashable, Iterable[Hashable]]] = None + @staticmethod + def _reduce( + da: "DataArray", + weights: "DataArray", + dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, + skipna: Optional[bool] = None, ) -> "DataArray": - """ Calculate the sum of weights, accounting for missing values """ + """reduce using dot; equivalent to (da * weights).sum(dim, skipna) - # we need to mask data values that are nan; else the weights are wrong - mask = da.notnull() + for internal use only + """ # need to infer dims as we use `dot` if dim is None: dim = ... + # need to mask invalid values in da, as `dot` does not implement skipna + if skipna or (skipna is None and da.dtype.kind in "cfO"): + da = da.fillna(0.0) + # `dot` does not broadcast arrays, so this avoids creating a large # DataArray (if `weights` has additional dimensions) - # TODO: add fasttrack (`(mask * weights).sum(dims=dim, skipna=skipna)`) - sum_of_weights = dot(mask, self.weights, dims=dim) + # maybe add fasttrack (`(da * weights).sum(dims=dim, skipna=skipna)`) + return dot(da, weights, dims=dim) + + def _sum_of_weights( + self, da: "DataArray", dim: Optional[Union[Hashable, Iterable[Hashable]]] = None + ) -> "DataArray": + """ Calculate the sum of weights, accounting for missing values """ + + # we need to mask data values that are nan; else the weights are wrong + mask = da.notnull() + + sum_of_weights = self._reduce(mask, self.weights, dim=dim, skipna=False) # 0-weights are not valid valid_weights = sum_of_weights != 0.0 @@ -134,18 +152,7 @@ def _weighted_sum( ) -> "DataArray": """Reduce a DataArray by a by a weighted ``sum`` along some dimension(s).""" - # need to infer dims as we use `dot` - if dim is None: - dim = ... - - # need to mask invalid DATA as `dot` does not implement skipna - if skipna or (skipna is None and da.dtype.kind in "cfO"): - da = da.fillna(0.0) - - # `dot` does not broadcast arrays, so this avoids creating a large - # DataArray (if `weights` has additional dimensions) - # maybe add fasttrack (`(da * weights).sum(dims=dim, skipna=skipna)`) - return dot(da, self.weights, dims=dim) + return self._reduce(da, self.weights, dim=dim, skipna=skipna) def _weighted_mean( self, From 3ebb9d4337fe7d226bcbdb75c010e7f4f6a6fa9e Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Thu, 5 Mar 2020 19:15:05 +0100 Subject: [PATCH 47/52] simplify some tests --- xarray/tests/test_weighted.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/xarray/tests/test_weighted.py b/xarray/tests/test_weighted.py index c279be43d14..f622a88d31b 100644 --- a/xarray/tests/test_weighted.py +++ b/xarray/tests/test_weighted.py @@ -2,20 +2,28 @@ import pytest import xarray as xr -from xarray import DataArray, Dataset +from xarray import DataArray from xarray.tests import assert_allclose, assert_equal, raises_regex -@pytest.mark.parametrize("data", (DataArray([1, 2]), Dataset(dict(data=[1, 2])))) -def test_weighted_non_DataArray_weights(data): +@pytest.mark.parametrize("as_dataset", (True, False)) +def test_weighted_non_DataArray_weights(as_dataset): + + data = DataArray([1, 2]) + if as_dataset: + data = data.to_dataset(name="data") with raises_regex(ValueError, "`weights` must be a DataArray"): data.weighted([1, 2]) -@pytest.mark.parametrize("data", (DataArray([1, 2]), Dataset(dict(data=[1, 2])))) +@pytest.mark.parametrize("as_dataset", (True, False)) @pytest.mark.parametrize("weights", ([np.nan, 2], [np.nan, np.nan])) -def test_weighted_weights_nan_raises(data, weights): +def test_weighted_weights_nan_raises(as_dataset, weights): + + data = DataArray([1, 2]) + if as_dataset: + data = data.to_dataset(name="data") with pytest.raises(ValueError, match="`weights` cannot contain missing values."): data.weighted(DataArray(weights)) @@ -70,6 +78,7 @@ def test_weighted_sum_equal_weights(da, factor, skipna): ("weights", "expected"), (([1, 2], 5), ([0, 2], 4), ([0, 0], 0)) ) def test_weighted_sum_no_nan(weights, expected): + da = DataArray([1, 2]) weights = DataArray(weights) @@ -84,6 +93,7 @@ def test_weighted_sum_no_nan(weights, expected): ) @pytest.mark.parametrize("skipna", (True, False)) def test_weighted_sum_nan(weights, expected, skipna): + da = DataArray([np.nan, 2]) weights = DataArray(weights) @@ -116,8 +126,7 @@ def test_weighted_mean_equal_weights(da, skipna, factor): @pytest.mark.parametrize( - ("weights", "expected"), - (([4, 6], 1.6), ([0, 1], 2.0), ([0, 2], 2.0), ([0, 0], np.nan)), + ("weights", "expected"), (([4, 6], 1.6), ([1, 0], 1.0), ([0, 0], np.nan)), ) def test_weighted_mean_no_nan(weights, expected): @@ -131,8 +140,7 @@ def test_weighted_mean_no_nan(weights, expected): @pytest.mark.parametrize( - ("weights", "expected"), - (([4, 6], 2.0), ([0, 1], 2.0), ([0, 2], 2.0), ([0, 0], np.nan)), + ("weights", "expected"), (([4, 6], 2.0), ([1, 0], np.nan), ([0, 0], np.nan)), ) @pytest.mark.parametrize("skipna", (True, False)) def test_weighted_mean_nan(weights, expected, skipna): @@ -198,7 +206,7 @@ def test_weighted_operations_3D(dim, operation, add_nans, skipna, as_dataset): data = data.to_dataset(name="data") if operation == "sum_of_weights": - result = getattr(data.weighted(weights), operation)(dim) + result = data.weighted(weights).sum_of_weights(dim) else: result = getattr(data.weighted(weights), operation)(dim, skipna=skipna) From f01d47a90f83585485633101dcf14f59d6271793 Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 17 Mar 2020 08:45:30 -0600 Subject: [PATCH 48/52] Doc updates --- doc/api.rst | 16 ++++ doc/computation.rst | 12 ++- doc/examples/area_weighted_temperature.ipynb | 99 ++++++++++++++++---- xarray/core/common.py | 6 +- xarray/core/weighted.py | 4 +- 5 files changed, 109 insertions(+), 28 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index 56743420f47..43a9cf53ead 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -579,6 +579,22 @@ Rolling objects core.rolling.DatasetRolling.reduce core.rolling_exp.RollingExp +Weighted objects +================ + +.. autosummary:: + :toctree: generated/ + + core.weighted.DataArrayWeighted + core.weighted.DataArrayWeighted.mean + core.weighted.DataArrayWeighted.sum + core.weighted.DataArrayWeighted.sum_of_weights + core.weighted.DatasetWeighted + core.weighted.DatasetWeighted.mean + core.weighted.DatasetWeighted.sum + core.weighted.DatasetWeighted.sum_of_weights + + Coarsen objects =============== diff --git a/doc/computation.rst b/doc/computation.rst index 4667bb2b377..b2b625c29c2 100644 --- a/doc/computation.rst +++ b/doc/computation.rst @@ -1,3 +1,5 @@ +.. currentmodule:: xarray + .. _comput: ########### @@ -246,8 +248,8 @@ You can also use ``construct`` to compute a weighted rolling sum: Weighted array reductions ========================= -``DataArray`` and ``Dataset`` objects include :py:meth:`~xarray.DataArray.weighted` -and :py:meth:`~xarray.Dataset.weighted` array reduction methods. They currently +:py:class:`DataArray` and :py:class:`Dataset` objects include :py:meth:`DataArray.weighted` +and :py:meth:`Dataset.weighted` array reduction methods. They currently support weighted ``sum`` and weighted ``mean``. .. ipython:: python @@ -303,13 +305,13 @@ Using ``(data * weights).sum() / weights.sum()`` would (incorrectly) result in 0.6. .. note:: - ``weights`` must be a ``DataArray`` and cannot contain missing values. - Missing values can be replaced manually by `weights.fillna(0)`. + ``weights`` must be a :py:class:`DataArray` and cannot contain missing values. + Missing values can be replaced manually by ``weights.fillna(0)``. Coarsen large arrays ==================== -``DataArray`` and ``Dataset`` objects include a +:py:class:`DataArray` and :py:class:`Dataset` objects include a :py:meth:`~xarray.DataArray.coarsen` and :py:meth:`~xarray.Dataset.coarsen` methods. This supports the block aggregation along multiple dimensions, diff --git a/doc/examples/area_weighted_temperature.ipynb b/doc/examples/area_weighted_temperature.ipynb index 656e9ed7dbc..72876e3fc29 100644 --- a/doc/examples/area_weighted_temperature.ipynb +++ b/doc/examples/area_weighted_temperature.ipynb @@ -1,5 +1,15 @@ { "cells": [ + { + "cell_type": "markdown", + "metadata": { + "toc": true + }, + "source": [ + "

Table of Contents

\n", + "" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -9,15 +19,19 @@ "\n", "Author: [Mathias Hauser](https://github.com/mathause/)\n", "\n", - "The data used for this example can be found in the [xarray-data](https://github.com/pydata/xarray-data) repository. You may need to change the path to `air_temperature` below.\n", "\n", - "We use the air_temperature example dataset to calculate the area-weighted temperature over its domain. This dataset has a regular latitude/ longitude grid, thus the gridcell area decreases towards the pole. For this grid we can use the cosine of the latitude as proxy for the grid cell area.\n" + "We use the `air_temperature` example dataset to calculate the area-weighted temperature over its domain. This dataset has a regular latitude/ longitude grid, thus the gridcell area decreases towards the pole. For this grid we can use the cosine of the latitude as proxy for the grid cell area.\n" ] }, { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2020-03-17T14:43:57.222351Z", + "start_time": "2020-03-17T14:43:56.147541Z" + } + }, "outputs": [], "source": [ "%matplotlib inline\n", @@ -41,7 +55,12 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2020-03-17T14:43:57.831734Z", + "start_time": "2020-03-17T14:43:57.651845Z" + } + }, "outputs": [], "source": [ "ds = xr.tutorial.load_dataset(\"air_temperature\")\n", @@ -65,7 +84,12 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2020-03-17T14:43:59.887120Z", + "start_time": "2020-03-17T14:43:59.582894Z" + } + }, "outputs": [], "source": [ "projection = ccrs.LambertConformal(central_longitude=-95, central_latitude=45)\n", @@ -88,10 +112,17 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2020-03-17T14:44:18.777092Z", + "start_time": "2020-03-17T14:44:18.736587Z" + } + }, "outputs": [], "source": [ - "weights = np.cos(np.deg2rad(air.lat))" + "weights = np.cos(np.deg2rad(air.lat))\n", + "weights.name = \"weights\"\n", + "weights" ] }, { @@ -104,10 +135,31 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2020-03-17T14:44:52.607120Z", + "start_time": "2020-03-17T14:44:52.564674Z" + } + }, "outputs": [], "source": [ - "air_weighted = air.weighted(weights).mean((\"lon\", \"lat\"))" + "air_weighted = air.weighted(weights)\n", + "air_weighted" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-03-17T14:44:54.334279Z", + "start_time": "2020-03-17T14:44:54.280022Z" + } + }, + "outputs": [], + "source": [ + "weighted_mean = air_weighted.mean((\"lon\", \"lat\"))\n", + "weighted_mean" ] }, { @@ -122,21 +174,19 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2020-03-17T14:45:08.877307Z", + "start_time": "2020-03-17T14:45:08.673383Z" + } + }, "outputs": [], "source": [ - "air_weighted.plot(label=\"weighted\")\n", + "weighted_mean.plot(label=\"weighted\")\n", "air.mean((\"lon\", \"lat\")).plot(label=\"unweighted\")\n", "\n", "plt.legend()" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { @@ -156,6 +206,19 @@ "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": true, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": true } }, "nbformat": 4, diff --git a/xarray/core/common.py b/xarray/core/common.py index cee1719c32d..be59167e93f 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -759,9 +759,9 @@ def weighted(self, weights): Each value in the data contributes to the reduction operation according to its associated weight. - Note - ---- - ``weights`` must be a ``DataArray`` and cannot contain missing values. + Notes + ----- + ``weights`` must be a :py:class:`DataArray` and cannot contain missing values. Missing values can be replaced by ``weights.fillna(0)``. """ diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index af0b17ed753..33f6f20b2dc 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -84,8 +84,8 @@ def __init__(self, obj, weights): # noqa: F811 Each value in the obj contributes to the reduction operation according to its associated weight. - Note - ---- + Notes + ----- ``weights`` must be a ``DataArray`` and cannot contain missing values. Missing values can be replaced by ``weights.fillna(0)``. """ From 4b184f6b9a8f798454c6133e9e941d9206b48442 Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 17 Mar 2020 10:30:39 -0600 Subject: [PATCH 49/52] very minor changes. --- xarray/core/weighted.py | 7 +++---- xarray/tests/test_weighted.py | 8 +++++--- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index 33f6f20b2dc..ee900e8a38b 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -207,10 +207,9 @@ def mean( def __repr__(self): """provide a nice str repr of our Weighted object""" - msg = "{klass} with weights along dimensions: {weight_dims}" - return msg.format( - klass=self.__class__.__name__, weight_dims=", ".join(self.weights.dims) - ) + klass = self.__class__.__name__ + weight_dims = ", ".join(self.weights.dims) + return f"{klass} with weights along dimensions: {weight_dims}" class DataArrayWeighted(Weighted): diff --git a/xarray/tests/test_weighted.py b/xarray/tests/test_weighted.py index f622a88d31b..24531215dfb 100644 --- a/xarray/tests/test_weighted.py +++ b/xarray/tests/test_weighted.py @@ -126,7 +126,7 @@ def test_weighted_mean_equal_weights(da, skipna, factor): @pytest.mark.parametrize( - ("weights", "expected"), (([4, 6], 1.6), ([1, 0], 1.0), ([0, 0], np.nan)), + ("weights", "expected"), (([4, 6], 1.6), ([1, 0], 1.0), ([0, 0], np.nan)) ) def test_weighted_mean_no_nan(weights, expected): @@ -140,7 +140,7 @@ def test_weighted_mean_no_nan(weights, expected): @pytest.mark.parametrize( - ("weights", "expected"), (([4, 6], 2.0), ([1, 0], np.nan), ([0, 0], np.nan)), + ("weights", "expected"), (([4, 6], 2.0), ([1, 0], np.nan), ([0, 0], np.nan)) ) @pytest.mark.parametrize("skipna", (True, False)) def test_weighted_mean_nan(weights, expected, skipna): @@ -159,7 +159,9 @@ def test_weighted_mean_nan(weights, expected, skipna): def expected_weighted(da, weights, dim, skipna, operation): - """ operations implemented via ``*`` and ``sum``; da.Weighted uses ``dot`` + """ + Generate expected result using ``*`` and ``sum``. This is checked against + the result of da.weighted which uses ``dot`` """ weighted_sum = (da * weights).sum(dim=dim, skipna=skipna) From 1e06adc6c069746c25f3bfc3ec7a5397255fd263 Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 17 Mar 2020 10:37:17 -0600 Subject: [PATCH 50/52] fix & add references --- doc/computation.rst | 4 +++- doc/whats-new.rst | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/doc/computation.rst b/doc/computation.rst index b2b625c29c2..9c0dee91e6d 100644 --- a/doc/computation.rst +++ b/doc/computation.rst @@ -243,7 +243,7 @@ You can also use ``construct`` to compute a weighted rolling sum: To avoid this, use ``skipna=False`` as the above example. -.. _comput.coarsen: +.. _comput.weighted: Weighted array reductions ========================= @@ -308,6 +308,8 @@ in 0.6. ``weights`` must be a :py:class:`DataArray` and cannot contain missing values. Missing values can be replaced manually by ``weights.fillna(0)``. +.. _comput.coarsen: + Coarsen large arrays ==================== diff --git a/doc/whats-new.rst b/doc/whats-new.rst index a9f69824122..6acc7b1eab1 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -26,8 +26,8 @@ New Features ~~~~~~~~~~~~ - Weighted array reductions are now supported via the new :py:meth:`DataArray.weighted` - and :py:meth:`Dataset.weighted` methods. By `Mathias Hauser `_ - (:issue:`422`). + and :py:meth:`Dataset.weighted` methods. See :ref:`comput.weighted`. (:issue:`422`, :pull:`2922`). + By `Mathias Hauser `_ - Added support for :py:class:`pandas.DatetimeIndex`-style rounding of ``cftime.datetime`` objects directly via a :py:class:`CFTimeIndex` or via the :py:class:`~core.accessor_dt.DatetimeAccessor`. From 706579ab748cc71498c4f69a1434d209ba25c1b0 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Tue, 17 Mar 2020 18:49:32 +0100 Subject: [PATCH 51/52] doc: return 0/NaN on 0 weights --- doc/computation.rst | 17 +++++++++++++++++ xarray/core/weighted.py | 16 ++++++++++++---- 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/doc/computation.rst b/doc/computation.rst index 4667bb2b377..c5a99651fc8 100644 --- a/doc/computation.rst +++ b/doc/computation.rst @@ -302,6 +302,23 @@ However, the functions also take missing values in the data into account: Using ``(data * weights).sum() / weights.sum()`` would (incorrectly) result in 0.6. + +If the weights add up to to 0, ``sum`` returns 0: + +.. ipython:: python + + data = xr.DataArray([1.0, 1.0]) + weights = xr.DataArray([-1.0, 1.0]) + + data.weighted(weights).sum() + +and ``mean`` returns ``NaN``: + +.. ipython:: python + + data.weighted(weights).mean() + + .. note:: ``weights`` must be a ``DataArray`` and cannot contain missing values. Missing values can be replaced manually by `weights.fillna(0)`. diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index af0b17ed753..04fcc9d497c 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -28,6 +28,11 @@ reduced : {cls} New {cls} object with weighted ``{fcn}`` applied to its data and the indicated dimension(s) removed. + + Notes + ----- + Returns {on_zero} if the ``weights`` sum to 0.0 along the reduced + dimension(s). """ _SUM_OF_WEIGHTS_DOCSTRING = """ @@ -238,10 +243,13 @@ def _inject_docstring(cls, cls_name): cls.sum_of_weights.__doc__ = _SUM_OF_WEIGHTS_DOCSTRING.format(cls=cls_name) - for operator in ["sum", "mean"]: - getattr(cls, operator).__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format( - cls=cls_name, fcn=operator - ) + cls.sum.__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format( + cls=cls_name, fcn="sum", on_zero="0" + ) + + cls.mean.__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format( + cls=cls_name, fcn="mean", on_zero="NaN" + ) _inject_docstring(DataArrayWeighted, "DataArray") From 8acc78ef91ec2e808eac852a526dd90fe859f39b Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 17 Mar 2020 19:41:56 -0600 Subject: [PATCH 52/52] Update xarray/core/common.py --- xarray/core/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index d850536e20d..a003642076f 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -758,7 +758,7 @@ def weighted(self, weights): Notes ----- - ``weights`` must be a :py:class:`DataArray` and cannot contain missing values. + ``weights`` must be a DataArray and cannot contain missing values. Missing values can be replaced by ``weights.fillna(0)``. """