From f8ab40c5fc1424f9c66206ba9f00dc21735890af Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Sat, 30 Sep 2023 11:50:33 -0700 Subject: [PATCH] Accept `lambda` for `other` param (#8256) * Accept `lambda` for `other` param --- doc/whats-new.rst | 4 ++++ xarray/core/common.py | 28 ++++++++++++++++------------ xarray/tests/test_dataarray.py | 8 ++++++++ 3 files changed, 28 insertions(+), 12 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 17744288aef..e485b24bf3e 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -22,6 +22,10 @@ v2023.09.1 (unreleased) New Features ~~~~~~~~~~~~ +- :py:meth:`DataArray.where` & :py:meth:`Dataset.where` accept a callable for + the ``other`` parameter, passing the object as the first argument. Previously, + this was only valid for the ``cond`` parameter. (:issue:`8255`) + By `Maximilian Roos `_. Breaking changes ~~~~~~~~~~~~~~~~ diff --git a/xarray/core/common.py b/xarray/core/common.py index db9b2aead23..2a4c4c200d4 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -1074,9 +1074,10 @@ def where(self, cond: Any, other: Any = dtypes.NA, drop: bool = False) -> Self: cond : DataArray, Dataset, or callable Locations at which to preserve this object's values. dtype must be `bool`. If a callable, it must expect this object as its only parameter. - other : scalar, DataArray or Dataset, optional + other : scalar, DataArray, Dataset, or callable, optional Value to use for locations in this object where ``cond`` is False. - By default, these locations filled with NA. + By default, these locations are filled with NA. If a callable, it must + expect this object as its only parameter. drop : bool, default: False If True, coordinate labels that only correspond to False values of the condition are dropped from the result. @@ -1124,7 +1125,16 @@ def where(self, cond: Any, other: Any = dtypes.NA, drop: bool = False) -> Self: [15., nan, nan, nan]]) Dimensions without coordinates: x, y - >>> a.where(lambda x: x.x + x.y < 4, drop=True) + >>> a.where(lambda x: x.x + x.y < 4, lambda x: -x) + + array([[ 0, 1, 2, 3, -4], + [ 5, 6, 7, -8, -9], + [ 10, 11, -12, -13, -14], + [ 15, -16, -17, -18, -19], + [-20, -21, -22, -23, -24]]) + Dimensions without coordinates: x, y + + >>> a.where(a.x + a.y < 4, drop=True) array([[ 0., 1., 2., 3.], [ 5., 6., 7., nan], @@ -1132,14 +1142,6 @@ def where(self, cond: Any, other: Any = dtypes.NA, drop: bool = False) -> Self: [15., nan, nan, nan]]) Dimensions without coordinates: x, y - >>> a.where(a.x + a.y < 4, -1, drop=True) - - array([[ 0, 1, 2, 3], - [ 5, 6, 7, -1], - [10, 11, -1, -1], - [15, -1, -1, -1]]) - Dimensions without coordinates: x, y - See Also -------- numpy.where : corresponding numpy function @@ -1151,11 +1153,13 @@ def where(self, cond: Any, other: Any = dtypes.NA, drop: bool = False) -> Self: if callable(cond): cond = cond(self) + if callable(other): + other = other(self) if drop: if not isinstance(cond, (Dataset, DataArray)): raise TypeError( - f"cond argument is {cond!r} but must be a {Dataset!r} or {DataArray!r}" + f"cond argument is {cond!r} but must be a {Dataset!r} or {DataArray!r} (or a callable than returns one)." ) self, cond = align(self, cond) # type: ignore[assignment] diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 11ebc4da347..63175f2be40 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -2717,6 +2717,14 @@ def test_where_lambda(self) -> None: actual = arr.where(lambda x: x.y < 2, drop=True) assert_identical(actual, expected) + def test_where_other_lambda(self) -> None: + arr = DataArray(np.arange(4), dims="y") + expected = xr.concat( + [arr.sel(y=slice(2)), arr.sel(y=slice(2, None)) + 1], dim="y" + ) + actual = arr.where(lambda x: x.y < 2, lambda x: x + 1) + assert_identical(actual, expected) + def test_where_string(self) -> None: array = DataArray(["a", "b"]) expected = DataArray(np.array(["a", np.nan], dtype=object))