From 565557160682e6851b807148293ba8ea22894ef1 Mon Sep 17 00:00:00 2001 From: Keewis Date: Sat, 30 Jan 2021 12:07:39 +0100 Subject: [PATCH 01/26] add a apply_to_dataset method --- xarray/core/common.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/xarray/core/common.py b/xarray/core/common.py index a69ba03a7a4..053f6b9d509 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -635,6 +635,21 @@ def pipe( else: return func(self, *args, **kwargs) + def apply_to_dataset(self, f): + from .dataarray import DataArray + + if isinstance(self, DataArray): + ds = self._to_temp_dataset() + else: + ds = self + + result = f(ds) + + if isinstance(self, DataArray): + return self._from_temp_dataset(result, name=self.name) + else: + return result + def groupby(self, group, squeeze: bool = True, restore_coord_dims: bool = None): """Returns a GroupBy object for performing grouped operations. From 90f8d55288fd5251cfb4b72195da1c4e57ec3805 Mon Sep 17 00:00:00 2001 From: Keewis Date: Sat, 30 Jan 2021 12:24:07 +0100 Subject: [PATCH 02/26] write a test for apply_to_dataset on a DataArray --- xarray/tests/test_dataarray.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index fc84687511e..f09556c9b71 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -2617,6 +2617,17 @@ def test_fillna(self): actual = a.groupby("b").fillna(DataArray([0, 2], dims="b")) assert_identical(expected, actual) + def test_apply_to_dataset(self): + def func(ds): + return Dataset(ds.data_vars, coords=ds.coords) + + da = DataArray( + [[0, 1], [2, 3], [4, 5]], + dims=("x", "y"), + name="abc", + ) + assert_identical(da, da.apply_to_dataset(func)) + def test_groupby_iter(self): for ((act_x, act_dv), (exp_x, exp_ds)) in zip( self.dv.groupby("y"), self.ds.groupby("y") From fd2c897f4ffc217a7eb5f594194af22aba924a34 Mon Sep 17 00:00:00 2001 From: Keewis Date: Sat, 30 Jan 2021 12:28:03 +0100 Subject: [PATCH 03/26] also add a test for dataset --- xarray/tests/test_dataset.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index fed9098701b..e62e0ee5732 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -5028,6 +5028,13 @@ def test_count(self): actual = ds.count() assert_identical(expected, actual) + def test_apply_to_dataset(self): + def func(ds): + return Dataset(ds.data_vars, coords=ds.coords) + + ds = create_test_data() + assert_identical(ds, ds.apply_to_dataset(func)) + def test_map(self): data = create_test_data() data.attrs["foo"] = "bar" From 857c783d0ed3debbff0bd7570654574bfa58c91e Mon Sep 17 00:00:00 2001 From: Keewis Date: Fri, 5 Feb 2021 02:04:44 +0100 Subject: [PATCH 04/26] convert apply_to_dataset to a top-level function --- xarray/__init__.py | 11 +++++++- xarray/core/common.py | 15 ----------- xarray/core/computation.py | 36 ++++++++++++++++++++++++++ xarray/tests/test_computation.py | 43 ++++++++++++++++++++++++++++++++ xarray/tests/test_dataarray.py | 11 -------- xarray/tests/test_dataset.py | 7 ------ 6 files changed, 89 insertions(+), 34 deletions(-) diff --git a/xarray/__init__.py b/xarray/__init__.py index 3886edc60e6..2d7fc8a9be4 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -18,7 +18,15 @@ from .core.alignment import align, broadcast from .core.combine import combine_by_coords, combine_nested from .core.common import ALL_DIMS, full_like, ones_like, zeros_like -from .core.computation import apply_ufunc, corr, cov, dot, polyval, where +from .core.computation import ( + apply_to_dataset, + apply_ufunc, + corr, + cov, + dot, + polyval, + where, +) from .core.concat import concat from .core.dataarray import DataArray from .core.dataset import Dataset @@ -46,6 +54,7 @@ # Top-level functions "align", "apply_ufunc", + "apply_to_dataset", "as_variable", "broadcast", "cftime_range", diff --git a/xarray/core/common.py b/xarray/core/common.py index 053f6b9d509..a69ba03a7a4 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -635,21 +635,6 @@ def pipe( else: return func(self, *args, **kwargs) - def apply_to_dataset(self, f): - from .dataarray import DataArray - - if isinstance(self, DataArray): - ds = self._to_temp_dataset() - else: - ds = self - - result = f(ds) - - if isinstance(self, DataArray): - return self._from_temp_dataset(result, name=self.name) - else: - return result - def groupby(self, group, squeeze: bool = True, restore_coord_dims: bool = None): """Returns a GroupBy object for performing grouped operations. diff --git a/xarray/core/computation.py b/xarray/core/computation.py index e0d9ff4b218..239ead3f286 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1142,6 +1142,42 @@ def earth_mover_distance(first_samples, return apply_array_ufunc(func, *args, dask=dask) +def apply_to_dataset(func, obj, *args, **kwargs): + """apply a function expecting a Dataset to a xarray object + + Parameters + ---------- + func : callable + A function expecting a Dataset as its first parameter. + obj : DataArray or Dataset + The dataset to apply ``func`` to. If a ``DataArray``, convert it to a single + variable ``Dataset`` first. + *args, **kwargs + Additional arguments to ``func`` + + Returns + ------- + DataArray or Dataset + The result of ``func(obj, *args, **kwargs)`` with the same type as ``obj``. + + Notes + ----- + If a ``DataArray``, result will have the same name as ``obj`` but the single data + variable in the temporary ``Dataset`` will always have a generic name. + """ + from .dataarray import DataArray + + ds = obj._to_temp_dataset() if isinstance(obj, DataArray) else obj + + result = func(ds, *args, **kwargs) + + return ( + obj._from_temp_dataset(result, name=obj.name) + if isinstance(obj, DataArray) + else result + ) + + def cov(da_a, da_b, dim=None, ddof=1): """ Compute covariance between two DataArray objects along a shared dimension. diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index 4890536a5d7..e362cf07589 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -468,6 +468,49 @@ def test_apply_groupby_add(): add(data_array.groupby("y"), data_array.groupby("x")) +@pytest.mark.parametrize( + ["obj", "expected"], + ( + pytest.param( + xr.DataArray( + [0, 1], + coords={ + "x": ("x", [-1, 1], {"a": 1, "b": 2}), + "u": ("x", [2, 3], {"c": 3}), + }, + dims="x", + attrs={"d": 4, "e": 5}, + ), + xr.DataArray([0, 1], coords={"x": [-1, 1], "u": ("x", [2, 3])}, dims="x"), + id="DataArray", + ), + pytest.param( + xr.Dataset( + {"a": ("x", [1, 2], {"a": 1, "b": 2}), "b": ("x", [0, 1], {"c": 3})}, + coords={ + "x": ("x", [-1, 1], {"d": 4, "e": 5}), + "u": ("x", [2, 3], {"f": 6}), + }, + ), + xr.Dataset( + {"a": ("x", [1, 2]), "b": ("x", [0, 1])}, + coords={"x": [-1, 1], "u": ("x", [2, 3])}, + ), + id="Dataset", + ), + ), +) +def test_apply_to_dataset(obj, expected): + def clear_all_attrs(ds): + new_ds = ds.copy() + for var in new_ds.variables.values(): + var.attrs.clear() + new_ds.attrs.clear() + return new_ds + + assert_identical(expected, xr.apply_to_dataset(clear_all_attrs, obj)) + + def test_unified_dim_sizes(): assert unified_dim_sizes([xr.Variable((), 0)]) == {} assert unified_dim_sizes([xr.Variable("x", [1]), xr.Variable("x", [1])]) == {"x": 1} diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index f09556c9b71..fc84687511e 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -2617,17 +2617,6 @@ def test_fillna(self): actual = a.groupby("b").fillna(DataArray([0, 2], dims="b")) assert_identical(expected, actual) - def test_apply_to_dataset(self): - def func(ds): - return Dataset(ds.data_vars, coords=ds.coords) - - da = DataArray( - [[0, 1], [2, 3], [4, 5]], - dims=("x", "y"), - name="abc", - ) - assert_identical(da, da.apply_to_dataset(func)) - def test_groupby_iter(self): for ((act_x, act_dv), (exp_x, exp_ds)) in zip( self.dv.groupby("y"), self.ds.groupby("y") diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index e62e0ee5732..fed9098701b 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -5028,13 +5028,6 @@ def test_count(self): actual = ds.count() assert_identical(expected, actual) - def test_apply_to_dataset(self): - def func(ds): - return Dataset(ds.data_vars, coords=ds.coords) - - ds = create_test_data() - assert_identical(ds, ds.apply_to_dataset(func)) - def test_map(self): data = create_test_data() data.attrs["foo"] = "bar" From 57e94b64b3d6e3a5cffd714837cae3c3fb58d8a9 Mon Sep 17 00:00:00 2001 From: Keewis Date: Fri, 5 Feb 2021 22:22:00 +0100 Subject: [PATCH 05/26] update whats-new.rst --- doc/whats-new.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 488d8baa650..403d7abff3a 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -55,6 +55,9 @@ New Features - :py:meth:`DataArray.swap_dims` & :py:meth:`Dataset.swap_dims` now accept dims in the form of kwargs as well as a dict, like most similar methods. By `Maximilian Roos `_. +- Add :py:func:`apply_to_dataset` as a way to apply functions expecting + :py:class:`Dataset` objects to :py:class:`DataArray` objects (:issue:`4837`, :pull:`4863`). + By `Justus Magin `_. Bug fixes ~~~~~~~~~ From cdb0f3d89d3776371083ff8353781b3333bb2492 Mon Sep 17 00:00:00 2001 From: Keewis Date: Fri, 5 Feb 2021 22:22:30 +0100 Subject: [PATCH 06/26] add the new function to api.rst [skip-ci] --- doc/api.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/api.rst b/doc/api.rst index 9cb02441d37..ff2954a2d9c 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -17,6 +17,7 @@ Top-level functions :toctree: generated/ apply_ufunc + apply_to_dataset align broadcast concat From 1d81a49c39fa5f21772787bbf315ac497e88431c Mon Sep 17 00:00:00 2001 From: Keewis Date: Fri, 5 Feb 2021 23:52:09 +0100 Subject: [PATCH 07/26] rephrase the note [skip-ci] --- xarray/core/computation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 239ead3f286..8a50f35852e 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1162,8 +1162,8 @@ def apply_to_dataset(func, obj, *args, **kwargs): Notes ----- - If a ``DataArray``, result will have the same name as ``obj`` but the single data - variable in the temporary ``Dataset`` will always have a generic name. + If a ``DataArray``, the data variable of the temporary ``Dataset`` will have a + generic name. The original name will be restored for the result of the call. """ from .dataarray import DataArray From 88fe863b942165c1aab6550c7301e711808a1d4d Mon Sep 17 00:00:00 2001 From: Keewis Date: Fri, 5 Feb 2021 23:55:26 +0100 Subject: [PATCH 08/26] add a see also section [skip-ci] --- xarray/core/computation.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 8a50f35852e..98883fcd6b5 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1164,6 +1164,12 @@ def apply_to_dataset(func, obj, *args, **kwargs): ----- If a ``DataArray``, the data variable of the temporary ``Dataset`` will have a generic name. The original name will be restored for the result of the call. + + See Also + -------- + Dataset.map + Dataset.pipe + DataArray.pipe """ from .dataarray import DataArray From 0daf42dd0267025b58d11670bf7c75a3482d2b74 Mon Sep 17 00:00:00 2001 From: Keewis Date: Sat, 6 Feb 2021 00:03:26 +0100 Subject: [PATCH 09/26] add examples [skip-ci] --- xarray/core/computation.py | 43 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 98883fcd6b5..08c55a1a775 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1170,6 +1170,49 @@ def apply_to_dataset(func, obj, *args, **kwargs): Dataset.map Dataset.pipe DataArray.pipe + + Examples + -------- + >>> def f(ds): + ... return xr.Dataset( + ... { + ... name: var * var.attrs.get("scale", 1) + ... for name, var in ds.data_vars.items() + ... }, + ... coords=ds.coords, + ... attrs=ds.attrs, + ... ) + ... + >>> ds = xr.Dataset( + ... {"a": ("x", [3, 4], {"scale": 0.5}), "b": ("x", [-1, 1], {"scale": 1.5})}, + ... coords={"x": [0, 1]}, + ... attrs={"attr": "value"}, + ... ) + >>> ds + + Dimensions: (x: 2) + Coordinates: + * x (x) int64 0 1 + Data variables: + a (x) int64 3 4 + b (x) int64 -1 1 + Attributes: + attr: value + >>> xr.apply_to_dataset(f, ds) + + Dimensions: (x: 2) + Coordinates: + * x (x) int64 0 1 + Data variables: + a (x) float64 1.5 2.0 + b (x) float64 -1.5 1.5 + Attributes: + attr: value + >>> xr.apply_to_dataset(f, ds.a) + + array([1.5, 2. ]) + Coordinates: + * x (x) int64 0 1 """ from .dataarray import DataArray From 559d8ef6f943f66a2b768e7eb4b838a638d64a25 Mon Sep 17 00:00:00 2001 From: Keewis Date: Mon, 15 Mar 2021 17:21:26 +0100 Subject: [PATCH 10/26] rename to call_on_dataset --- xarray/__init__.py | 4 ++-- xarray/core/computation.py | 6 +++--- xarray/tests/test_computation.py | 5 +++-- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/xarray/__init__.py b/xarray/__init__.py index 2d7fc8a9be4..1ef962eb479 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -19,8 +19,8 @@ from .core.combine import combine_by_coords, combine_nested from .core.common import ALL_DIMS, full_like, ones_like, zeros_like from .core.computation import ( - apply_to_dataset, apply_ufunc, + call_on_dataset, corr, cov, dot, @@ -54,7 +54,7 @@ # Top-level functions "align", "apply_ufunc", - "apply_to_dataset", + "call_on_dataset", "as_variable", "broadcast", "cftime_range", diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 4a4945b44ad..eb405b585c7 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1141,7 +1141,7 @@ def earth_mover_distance(first_samples, return apply_array_ufunc(func, *args, dask=dask) -def apply_to_dataset(func, obj, *args, **kwargs): +def call_on_dataset(func, obj, *args, **kwargs): """apply a function expecting a Dataset to a xarray object Parameters @@ -1197,7 +1197,7 @@ def apply_to_dataset(func, obj, *args, **kwargs): b (x) int64 -1 1 Attributes: attr: value - >>> xr.apply_to_dataset(f, ds) + >>> xr.call_on_dataset(f, ds) Dimensions: (x: 2) Coordinates: @@ -1207,7 +1207,7 @@ def apply_to_dataset(func, obj, *args, **kwargs): b (x) float64 -1.5 1.5 Attributes: attr: value - >>> xr.apply_to_dataset(f, ds.a) + >>> xr.call_on_dataset(f, ds.a) array([1.5, 2. ]) Coordinates: diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index e362cf07589..8d8a186e3d1 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -500,7 +500,7 @@ def test_apply_groupby_add(): ), ), ) -def test_apply_to_dataset(obj, expected): +def test_call_on_dataset(obj, expected): def clear_all_attrs(ds): new_ds = ds.copy() for var in new_ds.variables.values(): @@ -508,7 +508,8 @@ def clear_all_attrs(ds): new_ds.attrs.clear() return new_ds - assert_identical(expected, xr.apply_to_dataset(clear_all_attrs, obj)) + actual = xr.call_on_dataset(clear_all_attrs, obj) + assert_identical(actual, expected) def test_unified_dim_sizes(): From 0c424bf97ebeca1690b23014d51ce000babecc79 Mon Sep 17 00:00:00 2001 From: Keewis Date: Mon, 15 Mar 2021 17:21:59 +0100 Subject: [PATCH 11/26] preserve the name as much as possible --- xarray/core/computation.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index eb405b585c7..7701b899ed6 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1213,17 +1213,24 @@ def call_on_dataset(func, obj, *args, **kwargs): Coordinates: * x (x) int64 0 1 """ - from .dataarray import DataArray + from .dataarray import _THIS_ARRAY, DataArray + from .parallel import dataarray_to_dataset, dataset_to_dataarray - ds = obj._to_temp_dataset() if isinstance(obj, DataArray) else obj + if isinstance(obj, DataArray): + ds = dataarray_to_dataset(obj) + if obj.name is None: + ds = ds.rename({_THIS_ARRAY: None}) + else: + ds = obj result = func(ds, *args, **kwargs) - return ( - obj._from_temp_dataset(result, name=obj.name) - if isinstance(obj, DataArray) - else result - ) + if isinstance(obj, DataArray): + result = dataset_to_dataarray(result) + if obj.name is None: + result = result.rename(None) + + return result def cov(da_a, da_b, dim=None, ddof=1): From 8db9e7e50cbbe12caef9c584cc34790ef8ac19af Mon Sep 17 00:00:00 2001 From: Keewis Date: Mon, 15 Mar 2021 17:22:58 +0100 Subject: [PATCH 12/26] update api.rst --- doc/api.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/api.rst b/doc/api.rst index ff2954a2d9c..564202c3273 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -17,9 +17,9 @@ Top-level functions :toctree: generated/ apply_ufunc - apply_to_dataset align broadcast + call_on_dataset concat merge combine_by_coords From 43bf70dbe550087754ddec50ff1c2e582eb600bb Mon Sep 17 00:00:00 2001 From: Keewis Date: Mon, 15 Mar 2021 17:25:15 +0100 Subject: [PATCH 13/26] update whats-new.rst --- doc/whats-new.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 5832af2bba8..183e31a70ff 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -55,6 +55,9 @@ New Features :py:class:`~core.groupby.DataArrayGroupBy`, inspired by pandas' :py:meth:`~pandas.core.groupby.GroupBy.get_group`. By `Deepak Cherian `_. +- Add :py:func:`call_on_dataset` as a way to apply functions expecting + :py:class:`Dataset` objects to :py:class:`DataArray` objects (:issue:`4837`, :pull:`4863`). + By `Justus Magin `_. Breaking changes @@ -201,9 +204,6 @@ New Features - :py:meth:`DataArray.swap_dims` & :py:meth:`Dataset.swap_dims` now accept dims in the form of kwargs as well as a dict, like most similar methods. By `Maximilian Roos `_. -- Add :py:func:`apply_to_dataset` as a way to apply functions expecting - :py:class:`Dataset` objects to :py:class:`DataArray` objects (:issue:`4837`, :pull:`4863`). - By `Justus Magin `_. Bug fixes ~~~~~~~~~ From 31645e55013647380c8e4e56a6baa6029abd065c Mon Sep 17 00:00:00 2001 From: Keewis Date: Mon, 15 Mar 2021 17:25:22 +0100 Subject: [PATCH 14/26] remove the notes --- xarray/core/computation.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 81692d1b5ff..ec3f6617e95 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1169,11 +1169,6 @@ def call_on_dataset(func, obj, *args, **kwargs): DataArray or Dataset The result of ``func(obj, *args, **kwargs)`` with the same type as ``obj``. - Notes - ----- - If a ``DataArray``, the data variable of the temporary ``Dataset`` will have a - generic name. The original name will be restored for the result of the call. - See Also -------- Dataset.map From 293d9c16f16c2178c8ba85bd325aa0c983e49f84 Mon Sep 17 00:00:00 2001 From: Keewis Date: Mon, 15 Mar 2021 17:27:22 +0100 Subject: [PATCH 15/26] remove the no-op --- xarray/core/computation.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index ec3f6617e95..29f384c19b1 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1232,8 +1232,6 @@ def call_on_dataset(func, obj, *args, **kwargs): if isinstance(obj, DataArray): result = dataset_to_dataarray(result) - if obj.name is None: - result = result.rename(None) return result From d0de1ca73f8dac79f2f06dfa10e881dd97860577 Mon Sep 17 00:00:00 2001 From: Keewis Date: Mon, 15 Mar 2021 17:30:17 +0100 Subject: [PATCH 16/26] don't rename to None --- xarray/core/computation.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 29f384c19b1..495ccb8f42d 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1218,13 +1218,11 @@ def call_on_dataset(func, obj, *args, **kwargs): Coordinates: * x (x) int64 0 1 """ - from .dataarray import _THIS_ARRAY, DataArray + from .dataarray import DataArray from .parallel import dataarray_to_dataset, dataset_to_dataarray if isinstance(obj, DataArray): ds = dataarray_to_dataset(obj) - if obj.name is None: - ds = ds.rename({_THIS_ARRAY: None}) else: ds = obj From a82223269d36fdf98f3c8864f9f3b93b4d7a1fad Mon Sep 17 00:00:00 2001 From: Keewis Date: Mon, 15 Mar 2021 17:36:03 +0100 Subject: [PATCH 17/26] rename to "" --- xarray/core/computation.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 495ccb8f42d..9850877c9bb 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1169,6 +1169,12 @@ def call_on_dataset(func, obj, *args, **kwargs): DataArray or Dataset The result of ``func(obj, *args, **kwargs)`` with the same type as ``obj``. + Notes + ----- + DataArray objects without a name (or named ``None``) will be renamed to + ``""`` before being passed to ``func``. The name will be restored for + the result of the call. + See Also -------- Dataset.map @@ -1218,17 +1224,21 @@ def call_on_dataset(func, obj, *args, **kwargs): Coordinates: * x (x) int64 0 1 """ - from .dataarray import DataArray + from .dataarray import _THIS_ARRAY, DataArray from .parallel import dataarray_to_dataset, dataset_to_dataarray if isinstance(obj, DataArray): ds = dataarray_to_dataset(obj) + if obj.name is None: + ds = ds.rename({_THIS_ARRAY: ""}) else: ds = obj result = func(ds, *args, **kwargs) if isinstance(obj, DataArray): + if obj.name is None: + result = result.rename({"": _THIS_ARRAY}) result = dataset_to_dataarray(result) return result From d278919eb1ce3d17a7224ad8c897e08c06ff3201 Mon Sep 17 00:00:00 2001 From: Keewis Date: Mon, 15 Mar 2021 17:36:37 +0100 Subject: [PATCH 18/26] rewrite [skip-ci] --- xarray/core/computation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 9850877c9bb..fe65a162181 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1172,8 +1172,8 @@ def call_on_dataset(func, obj, *args, **kwargs): Notes ----- DataArray objects without a name (or named ``None``) will be renamed to - ``""`` before being passed to ``func``. The name will be restored for - the result of the call. + ``""`` before being passed to ``func``. The empty name will be restored + for the result of the call. See Also -------- From 97d43387ddad15bb8a9f6650d39bb9bb0f40fbac Mon Sep 17 00:00:00 2001 From: Keewis Date: Mon, 15 Mar 2021 19:42:19 +0100 Subject: [PATCH 19/26] rename back to None --- xarray/core/computation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index fe65a162181..0642c1ef915 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1238,7 +1238,7 @@ def call_on_dataset(func, obj, *args, **kwargs): if isinstance(obj, DataArray): if obj.name is None: - result = result.rename({"": _THIS_ARRAY}) + result = result.rename({"": None}) result = dataset_to_dataarray(result) return result From 371f5094fa04349b226cf906a37e734e2e0dae6a Mon Sep 17 00:00:00 2001 From: Keewis Date: Tue, 11 May 2021 01:04:04 +0200 Subject: [PATCH 20/26] introduce a mandatory name parameter to use as a name for the data variable --- xarray/core/computation.py | 23 ++++++++--------------- 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 0642c1ef915..90564ed6a59 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1151,7 +1151,7 @@ def earth_mover_distance(first_samples, return apply_array_ufunc(func, *args, dask=dask) -def call_on_dataset(func, obj, *args, **kwargs): +def call_on_dataset(func, obj, name, *args, **kwargs): """apply a function expecting a Dataset to a xarray object Parameters @@ -1161,6 +1161,9 @@ def call_on_dataset(func, obj, *args, **kwargs): obj : DataArray or Dataset The dataset to apply ``func`` to. If a ``DataArray``, convert it to a single variable ``Dataset`` first. + name : hashable + A intermediate name to use as the name of the data variable. If the DataArray + already had a name, it will be restored after converting back. *args, **kwargs Additional arguments to ``func`` @@ -1169,12 +1172,6 @@ def call_on_dataset(func, obj, *args, **kwargs): DataArray or Dataset The result of ``func(obj, *args, **kwargs)`` with the same type as ``obj``. - Notes - ----- - DataArray objects without a name (or named ``None``) will be renamed to - ``""`` before being passed to ``func``. The empty name will be restored - for the result of the call. - See Also -------- Dataset.map @@ -1218,28 +1215,24 @@ def call_on_dataset(func, obj, *args, **kwargs): b (x) float64 -1.5 1.5 Attributes: attr: value - >>> xr.call_on_dataset(f, ds.a) + >>> xr.call_on_dataset(f, ds.a, name="") array([1.5, 2. ]) Coordinates: * x (x) int64 0 1 """ - from .dataarray import _THIS_ARRAY, DataArray + from .dataarray import DataArray from .parallel import dataarray_to_dataset, dataset_to_dataarray if isinstance(obj, DataArray): - ds = dataarray_to_dataset(obj) - if obj.name is None: - ds = ds.rename({_THIS_ARRAY: ""}) + ds = dataarray_to_dataset(obj.rename(name)) else: ds = obj result = func(ds, *args, **kwargs) if isinstance(obj, DataArray): - if obj.name is None: - result = result.rename({"": None}) - result = dataset_to_dataarray(result) + result = dataset_to_dataarray(result).rename({name: obj.name}) return result From c9459f7972253f2d2f2ccf476566a15a9233ed69 Mon Sep 17 00:00:00 2001 From: Keewis Date: Tue, 11 May 2021 01:06:03 +0200 Subject: [PATCH 21/26] move to the new section in whats-new.rst --- doc/whats-new.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index c618b47d584..4765efca969 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -21,6 +21,9 @@ v0.18.1 (unreleased) New Features ~~~~~~~~~~~~ +- Add :py:func:`call_on_dataset` as a way to apply functions expecting + :py:class:`Dataset` objects to :py:class:`DataArray` objects (:issue:`4837`, :pull:`4863`). + By `Justus Magin `_. Breaking changes @@ -124,9 +127,6 @@ New Features :py:class:`~core.groupby.DataArrayGroupBy`, inspired by pandas' :py:meth:`~pandas.core.groupby.GroupBy.get_group`. By `Deepak Cherian `_. -- Add :py:func:`call_on_dataset` as a way to apply functions expecting - :py:class:`Dataset` objects to :py:class:`DataArray` objects (:issue:`4837`, :pull:`4863`). - By `Justus Magin `_. - Switch the tutorial functions to use `pooch `_ (which is now a optional dependency) and add :py:func:`tutorial.open_rasterio` as a way to open example rasterio files (:issue:`3986`, :pull:`4102`, :pull:`5074`). From 021ad36e24251652650635bc6d855713b4437cb8 Mon Sep 17 00:00:00 2001 From: Keewis Date: Tue, 11 May 2021 14:33:03 +0200 Subject: [PATCH 22/26] fix the tests --- xarray/core/computation.py | 2 +- xarray/tests/test_computation.py | 47 ++++++++++++++++++++++---------- 2 files changed, 33 insertions(+), 16 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index ce5c0589731..0784bd1bd10 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1239,7 +1239,7 @@ def call_on_dataset(func, obj, name, *args, **kwargs): result = func(ds, *args, **kwargs) if isinstance(obj, DataArray): - result = dataset_to_dataarray(result).rename({name: obj.name}) + result = dataset_to_dataarray(result.rename({name: obj.name})) return result diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index eebc7868ab7..afa4cb207c2 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -469,29 +469,36 @@ def test_apply_groupby_add(): @pytest.mark.parametrize( - ["obj", "expected"], + ["obj", "attrs", "expected"], ( pytest.param( xr.DataArray( [0, 1], - coords={ - "x": ("x", [-1, 1], {"a": 1, "b": 2}), - "u": ("x", [2, 3], {"c": 3}), - }, dims="x", - attrs={"d": 4, "e": 5}, ), - xr.DataArray([0, 1], coords={"x": [-1, 1], "u": ("x", [2, 3])}, dims="x"), - id="DataArray", + {None: {"a": 1}}, + xr.DataArray([0, 1], dims="x", attrs={"a": 1}), + id="unnamed DataArray", + ), + pytest.param( + xr.DataArray( + [0, 1], + dims="x", + name="b", + ), + {None: {"a": 1}}, + xr.DataArray([0, 1], dims="x", attrs={"a": 1}), + id="named DataArray", ), pytest.param( xr.Dataset( - {"a": ("x", [1, 2], {"a": 1, "b": 2}), "b": ("x", [0, 1], {"c": 3})}, + {"a": ("x", [1, 2]), "b": ("x", [0, 1])}, coords={ "x": ("x", [-1, 1], {"d": 4, "e": 5}), "u": ("x", [2, 3], {"f": 6}), }, ), + {"a": {"a": 1}}, xr.Dataset( {"a": ("x", [1, 2]), "b": ("x", [0, 1])}, coords={"x": [-1, 1], "u": ("x", [2, 3])}, @@ -500,15 +507,25 @@ def test_apply_groupby_add(): ), ), ) -def test_call_on_dataset(obj, expected): - def clear_all_attrs(ds): +def test_call_on_dataset(obj, attrs, expected): + temporary_name = "" + + def attach_attrs(ds, attrs): new_ds = ds.copy() - for var in new_ds.variables.values(): - var.attrs.clear() - new_ds.attrs.clear() + for n, v in new_ds.variables.items(): + if n == temporary_name: + n = None + + if n not in attrs: + continue + + v.attrs.update(attrs[n]) + return new_ds - actual = xr.call_on_dataset(clear_all_attrs, obj) + actual = xr.call_on_dataset( + lambda ds: attach_attrs(ds, attrs), obj, name=temporary_name + ) assert_identical(actual, expected) From 7081e150cc3e8425d48f32a3a32167822e448118 Mon Sep 17 00:00:00 2001 From: Keewis Date: Mon, 31 May 2021 17:41:04 +0200 Subject: [PATCH 23/26] update the input and expected values --- xarray/tests/test_computation.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index 73c28ac7f4a..7631e82f393 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -488,21 +488,21 @@ def test_apply_groupby_add(): name="b", ), {None: {"a": 1}}, - xr.DataArray([0, 1], dims="x", attrs={"a": 1}), + xr.DataArray([0, 1], dims="x", attrs={"a": 1}, name="b"), id="named DataArray", ), pytest.param( xr.Dataset( {"a": ("x", [1, 2]), "b": ("x", [0, 1])}, coords={ - "x": ("x", [-1, 1], {"d": 4, "e": 5}), - "u": ("x", [2, 3], {"f": 6}), + "x": ("x", [-1, 1]), + "u": ("x", [2, 3]), }, ), - {"a": {"a": 1}}, + {"a": {"a": 1}, "u": {"b": 2}}, xr.Dataset( - {"a": ("x", [1, 2]), "b": ("x", [0, 1])}, - coords={"x": [-1, 1], "u": ("x", [2, 3])}, + {"a": ("x", [1, 2], {"a": 1}), "b": ("x", [0, 1])}, + coords={"x": [-1, 1], "u": ("x", [2, 3], {"b": 2})}, ), id="Dataset", ), From 52a39f37fa1db704abc384c9afdb50b6aad8c056 Mon Sep 17 00:00:00 2001 From: Keewis Date: Mon, 31 May 2021 17:53:26 +0200 Subject: [PATCH 24/26] add the missing name for the dataset call --- xarray/core/computation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index a0a939f2452..a0d503dbd2b 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1238,7 +1238,7 @@ def call_on_dataset(func, obj, name, *args, **kwargs): b (x) int64 -1 1 Attributes: attr: value - >>> xr.call_on_dataset(f, ds) + >>> xr.call_on_dataset(f, ds, name="") Dimensions: (x: 2) Coordinates: From fcfaaa5d6aaae1d1772efdb8ed11fc3d65290315 Mon Sep 17 00:00:00 2001 From: Keewis Date: Mon, 31 May 2021 18:47:10 +0200 Subject: [PATCH 25/26] use DataArray.to_dataset instead --- xarray/core/computation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index a0d503dbd2b..5e21b2388b1 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1255,10 +1255,10 @@ def call_on_dataset(func, obj, name, *args, **kwargs): * x (x) int64 0 1 """ from .dataarray import DataArray - from .parallel import dataarray_to_dataset, dataset_to_dataarray + from .parallel import dataset_to_dataarray if isinstance(obj, DataArray): - ds = dataarray_to_dataset(obj.rename(name)) + ds = obj.to_dataset(name=name) else: ds = obj From f2d2880785181f66ea7062d20b8a81fae13047b3 Mon Sep 17 00:00:00 2001 From: Keewis Date: Mon, 31 May 2021 18:47:42 +0200 Subject: [PATCH 26/26] only convert if the result is a Dataset --- xarray/core/computation.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 5e21b2388b1..e12c5e23e2e 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1253,8 +1253,10 @@ def call_on_dataset(func, obj, name, *args, **kwargs): array([1.5, 2. ]) Coordinates: * x (x) int64 0 1 + >>> xr.call_on_dataset(lambda ds: list(ds.variables.keys()), ds.a, name="data") + ['x', 'data'] """ - from .dataarray import DataArray + from .dataarray import DataArray, Dataset from .parallel import dataset_to_dataarray if isinstance(obj, DataArray): @@ -1264,8 +1266,8 @@ def call_on_dataset(func, obj, name, *args, **kwargs): result = func(ds, *args, **kwargs) - if isinstance(obj, DataArray): - result = dataset_to_dataarray(result.rename({name: obj.name})) + if isinstance(obj, DataArray) and isinstance(result, Dataset): + result = dataset_to_dataarray(result).rename(obj.name) return result