Skip to content

Commit

Permalink
Improve tests, add dask tests
Browse files Browse the repository at this point in the history
  • Loading branch information
max-sixty committed Apr 19, 2021
1 parent 1569293 commit b9d0fc5
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 8 deletions.
4 changes: 3 additions & 1 deletion xarray/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,9 @@ def clip(self, min=None, max=None, *, keep_attrs=None):
# default.
keep_attrs = _get_keep_attrs(default=True)

return apply_ufunc(np.clip, self, min, max, keep_attrs=keep_attrs)
return apply_ufunc(
np.clip, self, min, max, keep_attrs=keep_attrs, dask="allowed"
)

def get_index(self, key: Hashable) -> pd.Index:
"""Get an index for a dimension, with fall-back to a default RangeIndex"""
Expand Down
30 changes: 23 additions & 7 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -6434,28 +6434,43 @@ def test_idxminmax_dask(self, op, ndim):
assert_equal(getattr(ar0_dsk, op)(dim="x"), getattr(ar0_raw, op)(dim="x"))


@pytest.fixture(params=["numpy", pytest.param("dask", marks=requires_dask)])
def backend(request):
return request.param


@pytest.fixture(params=[1])
def da(request):
def da(request, backend):
if request.param == 1:
times = pd.date_range("2000-01-01", freq="1D", periods=21)
values = np.random.random((3, 21, 4))
da = DataArray(values, dims=("a", "time", "x"))
da["time"] = times
return da
da = DataArray(
np.random.random((3, 21, 4)),
dims=("a", "time", "x"),
coords=dict(time=times),
)

if request.param == 2:
return DataArray([0, np.nan, 1, 2, np.nan, 3, 4, 5, np.nan, 6, 7], dims="time")
da = DataArray([0, np.nan, 1, 2, np.nan, 3, 4, 5, np.nan, 6, 7], dims="time")

if request.param == "repeating_ints":
return DataArray(
da = DataArray(
np.tile(np.arange(12), 5).reshape(5, 4, 3),
coords={"x": list("abc"), "y": list("defg")},
dims=list("zyx"),
)

if backend == "dask":
return da.chunk()
elif backend == "numpy":
return da
else:
raise ValueError


@pytest.fixture
def da_dask(seed=123):
# TODO: if possible, use the `da` fixture parameterized with backends rather than
# this.
pytest.importorskip("dask.array")
rs = np.random.RandomState(seed)
times = pd.date_range("2000-01-01", freq="1D", periods=21)
Expand Down Expand Up @@ -7328,3 +7343,4 @@ def test_clip(da):

result = da.clip(min=da.mean("x"), max=da.mean("a"))
assert result.dims == da.dims
assert_array_equal(result.data, np.clip(da.data, min=da.mean("x").data, max=da.mean("a").data))

0 comments on commit b9d0fc5

Please sign in to comment.