Skip to content

Commit

Permalink
Accept lambda for other param (#8256)
Browse files Browse the repository at this point in the history
* Accept `lambda` for `other` param
  • Loading branch information
max-sixty authored Sep 30, 2023
1 parent d8c166b commit f8ab40c
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 12 deletions.
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ v2023.09.1 (unreleased)
New Features
~~~~~~~~~~~~

- :py:meth:`DataArray.where` & :py:meth:`Dataset.where` accept a callable for
the ``other`` parameter, passing the object as the first argument. Previously,
this was only valid for the ``cond`` parameter. (:issue:`8255`)
By `Maximilian Roos <https://github.com/max-sixty>`_.

Breaking changes
~~~~~~~~~~~~~~~~
Expand Down
28 changes: 16 additions & 12 deletions xarray/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1074,9 +1074,10 @@ def where(self, cond: Any, other: Any = dtypes.NA, drop: bool = False) -> Self:
cond : DataArray, Dataset, or callable
Locations at which to preserve this object's values. dtype must be `bool`.
If a callable, it must expect this object as its only parameter.
other : scalar, DataArray or Dataset, optional
other : scalar, DataArray, Dataset, or callable, optional
Value to use for locations in this object where ``cond`` is False.
By default, these locations filled with NA.
By default, these locations are filled with NA. If a callable, it must
expect this object as its only parameter.
drop : bool, default: False
If True, coordinate labels that only correspond to False values of
the condition are dropped from the result.
Expand Down Expand Up @@ -1124,22 +1125,23 @@ def where(self, cond: Any, other: Any = dtypes.NA, drop: bool = False) -> Self:
[15., nan, nan, nan]])
Dimensions without coordinates: x, y
>>> a.where(lambda x: x.x + x.y < 4, drop=True)
>>> a.where(lambda x: x.x + x.y < 4, lambda x: -x)
<xarray.DataArray (x: 5, y: 5)>
array([[ 0, 1, 2, 3, -4],
[ 5, 6, 7, -8, -9],
[ 10, 11, -12, -13, -14],
[ 15, -16, -17, -18, -19],
[-20, -21, -22, -23, -24]])
Dimensions without coordinates: x, y
>>> a.where(a.x + a.y < 4, drop=True)
<xarray.DataArray (x: 4, y: 4)>
array([[ 0., 1., 2., 3.],
[ 5., 6., 7., nan],
[10., 11., nan, nan],
[15., nan, nan, nan]])
Dimensions without coordinates: x, y
>>> a.where(a.x + a.y < 4, -1, drop=True)
<xarray.DataArray (x: 4, y: 4)>
array([[ 0, 1, 2, 3],
[ 5, 6, 7, -1],
[10, 11, -1, -1],
[15, -1, -1, -1]])
Dimensions without coordinates: x, y
See Also
--------
numpy.where : corresponding numpy function
Expand All @@ -1151,11 +1153,13 @@ def where(self, cond: Any, other: Any = dtypes.NA, drop: bool = False) -> Self:

if callable(cond):
cond = cond(self)
if callable(other):
other = other(self)

if drop:
if not isinstance(cond, (Dataset, DataArray)):
raise TypeError(
f"cond argument is {cond!r} but must be a {Dataset!r} or {DataArray!r}"
f"cond argument is {cond!r} but must be a {Dataset!r} or {DataArray!r} (or a callable than returns one)."
)

self, cond = align(self, cond) # type: ignore[assignment]
Expand Down
8 changes: 8 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2717,6 +2717,14 @@ def test_where_lambda(self) -> None:
actual = arr.where(lambda x: x.y < 2, drop=True)
assert_identical(actual, expected)

def test_where_other_lambda(self) -> None:
arr = DataArray(np.arange(4), dims="y")
expected = xr.concat(
[arr.sel(y=slice(2)), arr.sel(y=slice(2, None)) + 1], dim="y"
)
actual = arr.where(lambda x: x.y < 2, lambda x: x + 1)
assert_identical(actual, expected)

def test_where_string(self) -> None:
array = DataArray(["a", "b"])
expected = DataArray(np.array(["a", np.nan], dtype=object))
Expand Down

0 comments on commit f8ab40c

Please sign in to comment.