diff --git a/xarray/core/common.py b/xarray/core/common.py index 5200761a479..ae1e69d7a54 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -388,7 +388,9 @@ def clip(self, min=None, max=None, *, keep_attrs=None): # default. keep_attrs = _get_keep_attrs(default=True) - return apply_ufunc(np.clip, self, min, max, keep_attrs=keep_attrs) + return apply_ufunc( + np.clip, self, min, max, keep_attrs=keep_attrs, dask="allowed" + ) def get_index(self, key: Hashable) -> pd.Index: """Get an index for a dimension, with fall-back to a default RangeIndex""" diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index f338f9b11da..6e0f05fa6de 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -6434,28 +6434,43 @@ def test_idxminmax_dask(self, op, ndim): assert_equal(getattr(ar0_dsk, op)(dim="x"), getattr(ar0_raw, op)(dim="x")) +@pytest.fixture(params=["numpy", pytest.param("dask", marks=requires_dask)]) +def backend(request): + return request.param + + @pytest.fixture(params=[1]) -def da(request): +def da(request, backend): if request.param == 1: times = pd.date_range("2000-01-01", freq="1D", periods=21) - values = np.random.random((3, 21, 4)) - da = DataArray(values, dims=("a", "time", "x")) - da["time"] = times - return da + da = DataArray( + np.random.random((3, 21, 4)), + dims=("a", "time", "x"), + coords=dict(time=times), + ) if request.param == 2: - return DataArray([0, np.nan, 1, 2, np.nan, 3, 4, 5, np.nan, 6, 7], dims="time") + da = DataArray([0, np.nan, 1, 2, np.nan, 3, 4, 5, np.nan, 6, 7], dims="time") if request.param == "repeating_ints": - return DataArray( + da = DataArray( np.tile(np.arange(12), 5).reshape(5, 4, 3), coords={"x": list("abc"), "y": list("defg")}, dims=list("zyx"), ) + if backend == "dask": + return da.chunk() + elif backend == "numpy": + return da + else: + raise ValueError + @pytest.fixture def da_dask(seed=123): + # TODO: if possible, use the `da` fixture parameterized with backends rather than + # this. pytest.importorskip("dask.array") rs = np.random.RandomState(seed) times = pd.date_range("2000-01-01", freq="1D", periods=21) @@ -7328,3 +7343,4 @@ def test_clip(da): result = da.clip(min=da.mean("x"), max=da.mean("a")) assert result.dims == da.dims + assert_array_equal(result.data, np.clip(da.data, min=da.mean("x").data, max=da.mean("a").data))