Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable clip broadcasting, with apply_ufunc #5184

Merged
merged 11 commits into from
Apr 21, 2021
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ New Features
By `Richard Kleijn <https://github.com/rhkleijn>`_ .
- Add a ``combine_attrs`` parameter to :py:func:`open_mfdataset` (:pull:`4971`).
By `Justus Magin <https://github.com/keewis>`_.
- Enable passing arrays with a subset of dimensions to
:py:meth:`DataArray.clip` & :py:meth:`Dataset.clip`; these methods now use
:py:func:`xarray.apply_ufunc` (:pull:`5184`).
By `Maximilian Roos <https://github.com/max-sixty>`_.
- Disable the `cfgrib` backend if the `eccodes` library is not installed (:pull:`5083`). By `Baudouin Raoult <https://github.com/b8raoult>`_.
- Added :py:meth:`DataArray.curvefit` and :py:meth:`Dataset.curvefit` for general curve fitting applications. (:issue:`4300`, :pull:`4849`)
By `Sam Levang <https://github.com/slevang>`_.
Expand Down
12 changes: 0 additions & 12 deletions xarray/core/_typed_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,6 @@ def round(self, *args, **kwargs):
def argsort(self, *args, **kwargs):
return self._unary_op(ops.argsort, *args, **kwargs)

def clip(self, *args, **kwargs):
return self._unary_op(ops.clip, *args, **kwargs)

def conj(self, *args, **kwargs):
return self._unary_op(ops.conj, *args, **kwargs)

Expand Down Expand Up @@ -195,7 +192,6 @@ def conjugate(self, *args, **kwargs):
__invert__.__doc__ = operator.invert.__doc__
round.__doc__ = ops.round_.__doc__
argsort.__doc__ = ops.argsort.__doc__
clip.__doc__ = ops.clip.__doc__
conj.__doc__ = ops.conj.__doc__
conjugate.__doc__ = ops.conjugate.__doc__

Expand Down Expand Up @@ -338,9 +334,6 @@ def round(self, *args, **kwargs):
def argsort(self, *args, **kwargs):
return self._unary_op(ops.argsort, *args, **kwargs)

def clip(self, *args, **kwargs):
return self._unary_op(ops.clip, *args, **kwargs)

def conj(self, *args, **kwargs):
return self._unary_op(ops.conj, *args, **kwargs)

Expand Down Expand Up @@ -389,7 +382,6 @@ def conjugate(self, *args, **kwargs):
__invert__.__doc__ = operator.invert.__doc__
round.__doc__ = ops.round_.__doc__
argsort.__doc__ = ops.argsort.__doc__
clip.__doc__ = ops.clip.__doc__
conj.__doc__ = ops.conj.__doc__
conjugate.__doc__ = ops.conjugate.__doc__

Expand Down Expand Up @@ -532,9 +524,6 @@ def round(self, *args, **kwargs):
def argsort(self, *args, **kwargs):
return self._unary_op(ops.argsort, *args, **kwargs)

def clip(self, *args, **kwargs):
return self._unary_op(ops.clip, *args, **kwargs)

def conj(self, *args, **kwargs):
return self._unary_op(ops.conj, *args, **kwargs)

Expand Down Expand Up @@ -583,7 +572,6 @@ def conjugate(self, *args, **kwargs):
__invert__.__doc__ = operator.invert.__doc__
round.__doc__ = ops.round_.__doc__
argsort.__doc__ = ops.argsort.__doc__
clip.__doc__ = ops.clip.__doc__
conj.__doc__ = ops.conj.__doc__
conjugate.__doc__ = ops.conjugate.__doc__

Expand Down
3 changes: 0 additions & 3 deletions xarray/core/_typed_ops.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ class DatasetOpsMixin:
def __invert__(self: T_Dataset) -> T_Dataset: ...
def round(self: T_Dataset, *args, **kwargs) -> T_Dataset: ...
def argsort(self: T_Dataset, *args, **kwargs) -> T_Dataset: ...
def clip(self: T_Dataset, *args, **kwargs) -> T_Dataset: ...
def conj(self: T_Dataset, *args, **kwargs) -> T_Dataset: ...
def conjugate(self: T_Dataset, *args, **kwargs) -> T_Dataset: ...

Expand Down Expand Up @@ -235,7 +234,6 @@ class DataArrayOpsMixin:
def __invert__(self: T_DataArray) -> T_DataArray: ...
def round(self: T_DataArray, *args, **kwargs) -> T_DataArray: ...
def argsort(self: T_DataArray, *args, **kwargs) -> T_DataArray: ...
def clip(self: T_DataArray, *args, **kwargs) -> T_DataArray: ...
def conj(self: T_DataArray, *args, **kwargs) -> T_DataArray: ...
def conjugate(self: T_DataArray, *args, **kwargs) -> T_DataArray: ...

Expand Down Expand Up @@ -406,7 +404,6 @@ class VariableOpsMixin:
def __invert__(self: T_Variable) -> T_Variable: ...
def round(self: T_Variable, *args, **kwargs) -> T_Variable: ...
def argsort(self: T_Variable, *args, **kwargs) -> T_Variable: ...
def clip(self: T_Variable, *args, **kwargs) -> T_Variable: ...
def conj(self: T_Variable, *args, **kwargs) -> T_Variable: ...
def conjugate(self: T_Variable, *args, **kwargs) -> T_Variable: ...

Expand Down
22 changes: 22 additions & 0 deletions xarray/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,28 @@ def squeeze(
dims = get_squeeze_dims(self, dim, axis)
return self.isel(drop=drop, **{d: 0 for d in dims})

def clip(self, min=None, max=None, *, keep_attrs: bool = None):
    """
    Limit the values of this object to the interval ``[min, max]``.

    The work is delegated to `numpy.clip` through ``apply_ufunc``, so refer
    to `numpy.clip` for full documentation. At least one of ``min`` or
    ``max`` must be given.

    Parameters
    ----------
    min : optional
        Lower bound, forwarded to `numpy.clip`.
    max : optional
        Upper bound, forwarded to `numpy.clip`.
    keep_attrs : bool, optional
        Forwarded to ``apply_ufunc``. When None, the value comes from
        ``_get_keep_attrs(default=True)``.

    See Also
    --------
    numpy.clip : equivalent function
    """
    from .computation import apply_ufunc

    # clip used to be a unary op whose attrs default was True; an explicit
    # argument still wins, otherwise that historical default is retained.
    resolved_keep_attrs = (
        _get_keep_attrs(default=True) if keep_attrs is None else keep_attrs
    )

    return apply_ufunc(
        np.clip,
        self,
        min,
        max,
        keep_attrs=resolved_keep_attrs,
        dask="allowed",
    )

def get_index(self, key: Hashable) -> pd.Index:
"""Get an index for a dimension, with fall-back to a default RangeIndex"""
if key not in self.dims:
Expand Down
1 change: 0 additions & 1 deletion xarray/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,6 @@ def inplace_to_noninplace_op(f):

# _typed_ops.py uses the following wrapped functions as a kind of unary operator
argsort = _method_wrapper("argsort")
clip = _method_wrapper("clip")
conj = _method_wrapper("conj")
conjugate = _method_wrapper("conjugate")
round_ = _func_slash_method_wrapper(duck_array_ops.around, name="round")
Expand Down
15 changes: 15 additions & 0 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1665,6 +1665,21 @@ def fillna(self, value):
def where(self, cond, other=dtypes.NA):
return ops.where_method(self, cond, other)

def clip(self, min=None, max=None):
Copy link
Collaborator Author

@max-sixty max-sixty Apr 19, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder whether there should be a class that DataArray & Variable inherit from (but not AbstractArray) where they share methods.

I don't think there's enough overlap to justify it at the moment, but worth considering. That may mean we could unify the tests, which are currently highly duplicated for clip for example.

"""
Return an array whose values are limited to ``[min, max]``.
At least one of max or min must be given.

Refer to `numpy.clip` for full documentation.

See Also
--------
numpy.clip : equivalent function
"""
from .computation import apply_ufunc

return apply_ufunc(np.clip, self, min, max, dask="allowed")

def reduce(
self,
func,
Expand Down
67 changes: 54 additions & 13 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -6434,28 +6434,43 @@ def test_idxminmax_dask(self, op, ndim):
assert_equal(getattr(ar0_dsk, op)(dim="x"), getattr(ar0_raw, op)(dim="x"))


@pytest.fixture(params=["numpy", pytest.param("dask", marks=requires_dask)])
def backend(request):
    """Return the array backend name for this run: ``"numpy"`` or ``"dask"``.

    The ``"dask"`` parameter carries the ``requires_dask`` mark.
    """
    return request.param


@pytest.fixture(params=[1])
def da(request):
def da(request, backend):
if request.param == 1:
times = pd.date_range("2000-01-01", freq="1D", periods=21)
values = np.random.random((3, 21, 4))
da = DataArray(values, dims=("a", "time", "x"))
da["time"] = times
return da
da = DataArray(
np.random.random((3, 21, 4)),
dims=("a", "time", "x"),
coords=dict(time=times),
)

if request.param == 2:
return DataArray([0, np.nan, 1, 2, np.nan, 3, 4, 5, np.nan, 6, 7], dims="time")
da = DataArray([0, np.nan, 1, 2, np.nan, 3, 4, 5, np.nan, 6, 7], dims="time")

if request.param == "repeating_ints":
return DataArray(
da = DataArray(
np.tile(np.arange(12), 5).reshape(5, 4, 3),
coords={"x": list("abc"), "y": list("defg")},
dims=list("zyx"),
)

if backend == "dask":
return da.chunk()
elif backend == "numpy":
return da
else:
raise ValueError


@pytest.fixture
def da_dask(seed=123):
# TODO: if possible, use the `da` fixture parameterized with backends rather than
# this.
pytest.importorskip("dask.array")
rs = np.random.RandomState(seed)
times = pd.date_range("2000-01-01", freq="1D", periods=21)
Expand Down Expand Up @@ -6596,6 +6611,7 @@ def test_rolling_properties(da):
@pytest.mark.parametrize("name", ("sum", "mean", "std", "min", "max", "median"))
@pytest.mark.parametrize("center", (True, False, None))
@pytest.mark.parametrize("min_periods", (1, None))
@pytest.mark.parametrize("backend", ["numpy"], indirect=True)
def test_rolling_wrapped_bottleneck(da, name, center, min_periods):
bn = pytest.importorskip("bottleneck", minversion="1.1")

Expand Down Expand Up @@ -6898,12 +6914,7 @@ def test_rolling_keep_attrs_deprecated():
data = np.linspace(10, 15, 100)
coords = np.linspace(1, 10, 100)

da = DataArray(
data,
dims=("coord"),
coords={"coord": coords},
attrs=attrs_da,
)
da = DataArray(data, dims=("coord"), coords={"coord": coords}, attrs=attrs_da)

# deprecated option
with pytest.warns(
Expand Down Expand Up @@ -7206,6 +7217,7 @@ def test_fallback_to_iris_AuxCoord(self, coord_values):
@pytest.mark.parametrize(
"window_type, window", [["span", 5], ["alpha", 0.5], ["com", 0.5], ["halflife", 5]]
)
@pytest.mark.parametrize("backend", ["numpy"], indirect=True)
def test_rolling_exp(da, dim, window_type, window):
da = da.isel(a=0)
da = da.where(da > 0.2)
Expand All @@ -7225,6 +7237,7 @@ def test_rolling_exp(da, dim, window_type, window):


@requires_numbagg
@pytest.mark.parametrize("backend", ["numpy"], indirect=True)
def test_rolling_exp_keep_attrs(da):
attrs = {"attrs": "da"}
da.attrs = attrs
Expand Down Expand Up @@ -7313,3 +7326,31 @@ def test_deepcopy_obj_array():
x0 = DataArray(np.array([object()]))
x1 = deepcopy(x0)
assert x0.values[0] is not x1.values[0]


def test_clip(da):
dcherian marked this conversation as resolved.
Show resolved Hide resolved
result = da.clip(min=0.5)
max-sixty marked this conversation as resolved.
Show resolved Hide resolved
assert result.min(...) >= 0.5

result = da.clip(max=0.5)
assert result.max(...) <= 0.5

result = da.clip(min=0.25, max=0.75)
assert result.min(...) >= 0.25
assert result.max(...) <= 0.75

result = da.clip(min=da.mean("x"), max=da.mean("a"))
assert result.dims == da.dims
max-sixty marked this conversation as resolved.
Show resolved Hide resolved
assert_array_equal(
result.data,
np.clip(da.data, da.mean("x").data[:, :, np.newaxis], da.mean("a").data),
)

with_nans = da.isel(time=[0, 1]).reindex_like(da)
result = da.clip(with_nans)
# The values should be the same where there were NaNs.
assert_array_equal(result.isel(time=[0, 1]), with_nans.isel(time=[0, 1]))

# Unclear whether we want this to work; OK to adjust the test when we have decided.
with pytest.raises(ValueError, match="arguments without labels along dimension"):
result = da.clip(min=da.mean("x"), max=da.mean("a").isel(x=[0, 1]))
35 changes: 25 additions & 10 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6069,16 +6069,16 @@ def test_dir_unicode(data_set):
def ds(request):
if request.param == 1:
return Dataset(
{
"z1": (["y", "x"], np.random.randn(2, 8)),
"z2": (["time", "y"], np.random.randn(10, 2)),
},
{
"x": ("x", np.linspace(0, 1.0, 8)),
"time": ("time", np.linspace(0, 1.0, 10)),
"c": ("y", ["a", "b"]),
"y": range(2),
},
dict(
z1=(["y", "x"], np.random.randn(2, 8)),
z2=(["time", "y"], np.random.randn(10, 2)),
),
dict(
x=("x", np.linspace(0, 1.0, 8)),
time=("time", np.linspace(0, 1.0, 10)),
c=("y", ["a", "b"]),
y=range(2),
),
)

if request.param == 2:
Expand Down Expand Up @@ -6845,3 +6845,18 @@ def test_deepcopy_obj_array():
x0 = Dataset(dict(foo=DataArray(np.array([object()]))))
x1 = deepcopy(x0)
assert x0["foo"].values[0] is not x1["foo"].values[0]


def test_clip(ds):
    """Check ``Dataset.clip`` with scalar bounds and with Dataset bounds.

    Comparing a reduced Dataset with a scalar yields another Dataset, and the
    truth value of a Dataset only says whether it has data variables. The
    comparison result of every data variable is therefore checked
    individually with ``all(... .values())``.
    """
    result = ds.clip(min=0.5)
    assert all((result.min(...) >= 0.5).values())

    result = ds.clip(max=0.5)
    assert all((result.max(...) <= 0.5).values())

    result = ds.clip(min=0.25, max=0.75)
    assert all((result.min(...) >= 0.25).values())
    assert all((result.max(...) <= 0.75).values())

    # Bounds that lack the "y" dimension must broadcast back to the full dims.
    result = ds.clip(min=ds.mean("y"), max=ds.mean("y"))
    assert result.dims == ds.dims
29 changes: 29 additions & 0 deletions xarray/tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@
]


@pytest.fixture
def var():
    """Return a Variable with dims ``x``, ``y``, ``z`` and random data.

    The data comes from ``np.random.rand(3, 4, 5)``.
    """
    return Variable(dims=list("xyz"), data=np.random.rand(3, 4, 5))


class VariableSubclassobjects:
def test_properties(self):
data = 0.5 * np.arange(10)
Expand Down Expand Up @@ -2510,3 +2515,27 @@ def test_DaskIndexingAdapter(self):
v = Variable(dims=("x", "y"), data=CopyOnWriteArray(DaskIndexingAdapter(da)))
self.check_orthogonal_indexing(v)
self.check_vectorized_indexing(v)


def test_clip(var):
    """Check ``Variable.clip`` with scalar bounds and with Variable bounds."""
    # Copied from test_dataarray (would there be a way to combine the tests?)
    clipped_below = var.clip(min=0.5)
    assert clipped_below.min(...) >= 0.5

    clipped_above = var.clip(max=0.5)
    assert clipped_above.max(...) <= 0.5

    clipped_both = var.clip(min=0.25, max=0.75)
    assert clipped_both.min(...) >= 0.25
    assert clipped_both.max(...) <= 0.75

    # Each bound is missing one of the dimensions of ``var``.
    lower = var.mean("x")
    upper = var.mean("z")
    result = var.clip(min=lower, max=upper)
    assert result.dims == var.dims

    # Re-insert the reduced axis by hand so plain numpy can broadcast.
    expected = np.clip(
        var.data,
        lower.data[np.newaxis, :, :],
        upper.data[:, :, np.newaxis],
    )
    assert_array_equal(result.data, expected)
1 change: 0 additions & 1 deletion xarray/util/generate_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@
OTHER_UNARY_METHODS = (
("round", "ops.round_"),
("argsort", "ops.argsort"),
("clip", "ops.clip"),
("conj", "ops.conj"),
("conjugate", "ops.conjugate"),
)
Expand Down