FIX: correct dask array handling in _calc_idxminmax #3922

Merged: 23 commits, May 13, 2020

Changes from all commits

Commits (23)
12fe069
FIX: correct dask array handling in _calc_idxminmax
kmuehlbauer Mar 31, 2020
59a74d6
FIX: remove unneeded import, reformat via black
kmuehlbauer Mar 31, 2020
7369007
fix idxmax, idxmin with dask arrays
dcherian Mar 31, 2020
9a862fc
FIX: use array[dim].data in `_calc_idxminmax` as per @keewis suggesti…
kmuehlbauer Apr 3, 2020
b154c8a
ADD: add dask tests to `idxmin`/`idxmax` dataarray tests
kmuehlbauer Apr 3, 2020
bf5530b
FIX: add back fixture line removed by accident
kmuehlbauer Apr 3, 2020
a60ae89
ADD: complete dask handling in `idxmin`/`idxmax` tests in test_dataar…
kmuehlbauer Apr 3, 2020
e7c42c7
ADD: add "support dask handling for idxmin/idxmax" in whats-new.rst
kmuehlbauer Apr 3, 2020
8b2ac75
Merge branch 'master' into fix-idxminmax-dask-issues
kmuehlbauer Apr 6, 2020
36dc8ee
Merge branch 'master' into fix-idxminmax-dask-issues
kmuehlbauer Apr 14, 2020
5c23327
MIN: reintroduce changes added by #3953
kmuehlbauer Apr 14, 2020
081ff6f
MIN: change if-clause to use `and` instead of `&` as per review-comment
kmuehlbauer Apr 14, 2020
26cc8f3
MIN: change if-clause to use `and` instead of `&` as per review-comment
kmuehlbauer Apr 14, 2020
613f4e6
WIP: remove dask handling entirely for debugging purposes
kmuehlbauer Apr 14, 2020
8a95492
Test for dask computes
dcherian Apr 15, 2020
8cf4f3e
WIP: re-add dask handling (map_blocks-approach), add `with raise_if_d…
kmuehlbauer Apr 16, 2020
4417a35
Use dask indexing instead of map_blocks.
dcherian Apr 18, 2020
43384c3
Better chunk choice.
dcherian Apr 18, 2020
58901b9
Return -1 for _nan_argminmax_object if all NaNs along dim
dcherian May 9, 2020
cfc98da
Revert "Return -1 for _nan_argminmax_object if all NaNs along dim"
dcherian May 9, 2020
c1324bf
Raise error for object arrays
dcherian May 9, 2020
525118b
No error for object arrays. Instead expect 1 compute in tests.
dcherian May 9, 2020
20ccf5a
Merge remote-tracking branch 'upstream/master' into fix-idxminmax-das…
dcherian May 9, 2020
3 changes: 3 additions & 0 deletions doc/whats-new.rst
@@ -53,6 +53,9 @@ New Features
 - Implement :py:meth:`DataArray.idxmax`, :py:meth:`DataArray.idxmin`,
   :py:meth:`Dataset.idxmax`, :py:meth:`Dataset.idxmin`. (:issue:`60`, :pull:`3871`)
   By `Todd Jennings <https://github.com/toddrjen>`_
+- Support dask handling for :py:meth:`DataArray.idxmax`, :py:meth:`DataArray.idxmin`,
+  :py:meth:`Dataset.idxmax`, :py:meth:`Dataset.idxmin`. (:pull:`3922`)
+  By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.
 - More support for unit aware arrays with pint (:pull:`3643`)
   By `Justus Magin <https://github.com/keewis>`_.
 - Support overriding existing variables in ``to_zarr()`` with ``mode='a'`` even
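As a quick illustration of the changelog entry above (a minimal sketch with made-up data, not taken from the PR): after this change, idxmin/idxmax on a dask-backed DataArray return a lazy, dask-backed result instead of failing on the previously broken dask code path.

    import numpy as np
    import xarray as xr

    da = xr.DataArray(
        np.array([[3.0, 1.0, 2.0], [0.5, 4.0, 0.1]]),
        dims=["y", "x"],
        coords={"x": [0, 4, 8]},
    ).chunk({"y": 1})

    lazy = da.idxmin(dim="x")  # result stays dask-backed; nothing is computed yet
    print(lazy.compute())      # coordinate labels of the row minima: 4 and 8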
23 changes: 11 additions & 12 deletions xarray/core/computation.py
@@ -26,7 +26,6 @@
 from . import dtypes, duck_array_ops, utils
 from .alignment import deep_align
 from .merge import merge_coordinates_without_align
-from .nanops import dask_array
 from .options import OPTIONS
 from .pycompat import dask_array_type
 from .utils import is_dict_like
@@ -1380,24 +1379,24 @@ def _calc_idxminmax(
     # This will run argmin or argmax.
     indx = func(array, dim=dim, axis=None, keep_attrs=keep_attrs, skipna=skipna)
 
-    # Get the coordinate we want.
-    coordarray = array[dim]
-
     # Handle dask arrays.
-    if isinstance(array, dask_array_type):
-        res = dask_array.map_blocks(coordarray, indx, dtype=indx.dtype)
+    if isinstance(array.data, dask_array_type):
+        import dask.array
+
+        chunks = dict(zip(array.dims, array.chunks))
+        dask_coord = dask.array.from_array(array[dim].data, chunks=chunks[dim])
+        res = indx.copy(data=dask_coord[(indx.data,)])
         # we need to attach back the dim name
         res.name = dim
     else:
-        res = coordarray[
-            indx,
-        ]
+        res = array[dim][(indx,)]
+        # The dim is gone but we need to remove the corresponding coordinate.
+        del res.coords[dim]
 
     if skipna or (skipna is None and array.dtype.kind in na_dtypes):
         # Put the NaN values back in after removing them
         res = res.where(~allna, fill_value)
 
-    # The dim is gone but we need to remove the corresponding coordinate.
-    del res.coords[dim]
-
     # Copy attributes from argmin/argmax, if any
     res.attrs = indx.attrs
 
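To make the new dask branch above easier to follow, here is a small self-contained sketch of the indexing approach it uses (toy data, plain dask outside xarray): the coordinate values are wrapped in a dask array and indexed with the lazy argmin result, which dask supports for one-dimensional integer dask-array indices, so no compute is triggered. The earlier map_blocks call appears to pass the coordinate array where map_blocks expects a callable, which is why idxmin/idxmax broke on dask-backed data.

    import numpy as np
    import dask.array

    values = dask.array.from_array(
        np.array([[3.0, 1.0, 2.0], [0.5, 4.0, 0.1]]), chunks=(1, 3)
    )
    coord = np.array([0, 4, 8])  # coordinate labels along the reduced axis

    indx = values.argmin(axis=-1)                        # lazy integer positions
    dask_coord = dask.array.from_array(coord, chunks=3)  # chunked like the reduced axis
    res = dask_coord[(indx,)]                            # still lazy

    print(res.compute())  # [4 8] -- coordinate of the minimum in each row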
120 changes: 96 additions & 24 deletions xarray/tests/test_dataarray.py
@@ -34,6 +34,8 @@
     source_ndarray,
 )
 
+from .test_dask import raise_if_dask_computes
+
 
 class TestDataArray:
     @pytest.fixture(autouse=True)
@@ -4524,11 +4526,21 @@ def test_argmax(self, x, minindex, maxindex, nanindex):
 
         assert_identical(result2, expected2)
 
-    def test_idxmin(self, x, minindex, maxindex, nanindex):
-        ar0 = xr.DataArray(
+    @pytest.mark.parametrize("use_dask", [True, False])
+    def test_idxmin(self, x, minindex, maxindex, nanindex, use_dask):
+        if use_dask and not has_dask:
+            pytest.skip("requires dask")
+        if use_dask and x.dtype.kind == "M":
+            pytest.xfail("dask operation 'argmin' breaks when dtype is datetime64 (M)")
+        ar0_raw = xr.DataArray(
             x, dims=["x"], coords={"x": np.arange(x.size) * 4}, attrs=self.attrs
         )
 
+        if use_dask:
+            ar0 = ar0_raw.chunk({})
+        else:
+            ar0 = ar0_raw
+
         # dim doesn't exist
         with pytest.raises(KeyError):
             ar0.idxmin(dim="spam")
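A note on the pattern introduced above (and reused in the other tests): DataArray.chunk({}) converts the underlying data to a dask array with a single chunk spanning each dimension, so the existing parametrized test values also exercise the dask code path. A tiny illustration with a toy array (not from the test suite):

    import numpy as np
    import xarray as xr

    arr = xr.DataArray(np.arange(6.0), dims=["x"])
    chunked = arr.chunk({})      # dask-backed, one chunk covering the whole dimension
    print(type(chunked.data))    # a dask.array.Array
    print(chunked.chunks)        # ((6,),)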
@@ -4620,11 +4632,21 @@ def test_idxmin(self, x, minindex, maxindex, nanindex):
         result7 = ar0.idxmin(fill_value=-1j)
         assert_identical(result7, expected7)
 
-    def test_idxmax(self, x, minindex, maxindex, nanindex):
-        ar0 = xr.DataArray(
+    @pytest.mark.parametrize("use_dask", [True, False])
+    def test_idxmax(self, x, minindex, maxindex, nanindex, use_dask):
+        if use_dask and not has_dask:
+            pytest.skip("requires dask")
+        if use_dask and x.dtype.kind == "M":
+            pytest.xfail("dask operation 'argmax' breaks when dtype is datetime64 (M)")
+        ar0_raw = xr.DataArray(
             x, dims=["x"], coords={"x": np.arange(x.size) * 4}, attrs=self.attrs
         )
 
+        if use_dask:
+            ar0 = ar0_raw.chunk({})
+        else:
+            ar0 = ar0_raw
+
         # dim doesn't exist
         with pytest.raises(KeyError):
             ar0.idxmax(dim="spam")
@@ -4944,14 +4966,31 @@ def test_argmax(self, x, minindex, maxindex, nanindex):
 
         assert_identical(result3, expected2)
 
-    def test_idxmin(self, x, minindex, maxindex, nanindex):
-        ar0 = xr.DataArray(
+    @pytest.mark.parametrize("use_dask", [True, False])
+    def test_idxmin(self, x, minindex, maxindex, nanindex, use_dask):
+        if use_dask and not has_dask:
+            pytest.skip("requires dask")
+        if use_dask and x.dtype.kind == "M":
+            pytest.xfail("dask operation 'argmin' breaks when dtype is datetime64 (M)")
+
+        if x.dtype.kind == "O":
+            # TODO: nanops._nan_argminmax_object computes once to check for all-NaN slices.
+            max_computes = 1
+        else:
+            max_computes = 0
+
+        ar0_raw = xr.DataArray(
             x,
             dims=["y", "x"],
             coords={"x": np.arange(x.shape[1]) * 4, "y": 1 - np.arange(x.shape[0])},
             attrs=self.attrs,
         )
 
+        if use_dask:
+            ar0 = ar0_raw.chunk({})
+        else:
+            ar0 = ar0_raw
+
         assert_identical(ar0, ar0)
 
         # No dimension specified
@@ -4982,15 +5021,18 @@ def test_idxmin(self, x, minindex, maxindex, nanindex):
         expected0.name = "x"
 
         # Default fill value (NaN)
-        result0 = ar0.idxmin(dim="x")
+        with raise_if_dask_computes(max_computes=max_computes):
+            result0 = ar0.idxmin(dim="x")
         assert_identical(result0, expected0)
 
         # Manually specify NaN fill_value
-        result1 = ar0.idxmin(dim="x", fill_value=np.NaN)
+        with raise_if_dask_computes(max_computes=max_computes):
+            result1 = ar0.idxmin(dim="x", fill_value=np.NaN)
         assert_identical(result1, expected0)
 
         # keep_attrs
-        result2 = ar0.idxmin(dim="x", keep_attrs=True)
+        with raise_if_dask_computes(max_computes=max_computes):
+            result2 = ar0.idxmin(dim="x", keep_attrs=True)
         expected2 = expected0.copy()
         expected2.attrs = self.attrs
         assert_identical(result2, expected2)
@@ -5008,11 +5050,13 @@ def test_idxmin(self, x, minindex, maxindex, nanindex):
         expected3.name = "x"
         expected3.attrs = {}
 
-        result3 = ar0.idxmin(dim="x", skipna=False)
+        with raise_if_dask_computes(max_computes=max_computes):
+            result3 = ar0.idxmin(dim="x", skipna=False)
         assert_identical(result3, expected3)
 
         # fill_value should be ignored with skipna=False
-        result4 = ar0.idxmin(dim="x", skipna=False, fill_value=-100j)
+        with raise_if_dask_computes(max_computes=max_computes):
+            result4 = ar0.idxmin(dim="x", skipna=False, fill_value=-100j)
         assert_identical(result4, expected3)
 
         # Float fill_value
@@ -5024,7 +5068,8 @@ def test_idxmin(self, x, minindex, maxindex, nanindex):
         expected5 = xr.concat(expected5, dim="y")
         expected5.name = "x"
 
-        result5 = ar0.idxmin(dim="x", fill_value=-1.1)
+        with raise_if_dask_computes(max_computes=max_computes):
+            result5 = ar0.idxmin(dim="x", fill_value=-1.1)
         assert_identical(result5, expected5)
 
         # Integer fill_value
@@ -5036,7 +5081,8 @@ def test_idxmin(self, x, minindex, maxindex, nanindex):
         expected6 = xr.concat(expected6, dim="y")
         expected6.name = "x"
 
-        result6 = ar0.idxmin(dim="x", fill_value=-1)
+        with raise_if_dask_computes(max_computes=max_computes):
+            result6 = ar0.idxmin(dim="x", fill_value=-1)
         assert_identical(result6, expected6)
 
         # Complex fill_value
@@ -5048,17 +5094,35 @@ def test_idxmin(self, x, minindex, maxindex, nanindex):
         expected7 = xr.concat(expected7, dim="y")
         expected7.name = "x"
 
-        result7 = ar0.idxmin(dim="x", fill_value=-5j)
+        with raise_if_dask_computes(max_computes=max_computes):
+            result7 = ar0.idxmin(dim="x", fill_value=-5j)
         assert_identical(result7, expected7)
 
-    def test_idxmax(self, x, minindex, maxindex, nanindex):
-        ar0 = xr.DataArray(
+    @pytest.mark.parametrize("use_dask", [True, False])
+    def test_idxmax(self, x, minindex, maxindex, nanindex, use_dask):
+        if use_dask and not has_dask:
+            pytest.skip("requires dask")
+        if use_dask and x.dtype.kind == "M":
+            pytest.xfail("dask operation 'argmax' breaks when dtype is datetime64 (M)")
+
+        if x.dtype.kind == "O":
+            # TODO: nanops._nan_argminmax_object computes once to check for all-NaN slices.
+            max_computes = 1
+        else:
+            max_computes = 0
+
+        ar0_raw = xr.DataArray(
             x,
             dims=["y", "x"],
             coords={"x": np.arange(x.shape[1]) * 4, "y": 1 - np.arange(x.shape[0])},
             attrs=self.attrs,
         )
 
+        if use_dask:
+            ar0 = ar0_raw.chunk({})
+        else:
+            ar0 = ar0_raw
+
         # No dimension specified
         with pytest.raises(ValueError):
             ar0.idxmax()
@@ -5090,15 +5154,18 @@ def test_idxmax(self, x, minindex, maxindex, nanindex):
         expected0.name = "x"
 
         # Default fill value (NaN)
-        result0 = ar0.idxmax(dim="x")
+        with raise_if_dask_computes(max_computes=max_computes):
+            result0 = ar0.idxmax(dim="x")
         assert_identical(result0, expected0)
 
         # Manually specify NaN fill_value
-        result1 = ar0.idxmax(dim="x", fill_value=np.NaN)
+        with raise_if_dask_computes(max_computes=max_computes):
+            result1 = ar0.idxmax(dim="x", fill_value=np.NaN)
         assert_identical(result1, expected0)
 
         # keep_attrs
-        result2 = ar0.idxmax(dim="x", keep_attrs=True)
+        with raise_if_dask_computes(max_computes=max_computes):
+            result2 = ar0.idxmax(dim="x", keep_attrs=True)
         expected2 = expected0.copy()
         expected2.attrs = self.attrs
         assert_identical(result2, expected2)
@@ -5116,11 +5183,13 @@ def test_idxmax(self, x, minindex, maxindex, nanindex):
         expected3.name = "x"
         expected3.attrs = {}
 
-        result3 = ar0.idxmax(dim="x", skipna=False)
+        with raise_if_dask_computes(max_computes=max_computes):
+            result3 = ar0.idxmax(dim="x", skipna=False)
         assert_identical(result3, expected3)
 
         # fill_value should be ignored with skipna=False
-        result4 = ar0.idxmax(dim="x", skipna=False, fill_value=-100j)
+        with raise_if_dask_computes(max_computes=max_computes):
+            result4 = ar0.idxmax(dim="x", skipna=False, fill_value=-100j)
         assert_identical(result4, expected3)
 
         # Float fill_value
@@ -5132,7 +5201,8 @@ def test_idxmax(self, x, minindex, maxindex, nanindex):
         expected5 = xr.concat(expected5, dim="y")
         expected5.name = "x"
 
-        result5 = ar0.idxmax(dim="x", fill_value=-1.1)
+        with raise_if_dask_computes(max_computes=max_computes):
+            result5 = ar0.idxmax(dim="x", fill_value=-1.1)
         assert_identical(result5, expected5)
 
         # Integer fill_value
@@ -5144,7 +5214,8 @@ def test_idxmax(self, x, minindex, maxindex, nanindex):
         expected6 = xr.concat(expected6, dim="y")
         expected6.name = "x"
 
-        result6 = ar0.idxmax(dim="x", fill_value=-1)
+        with raise_if_dask_computes(max_computes=max_computes):
+            result6 = ar0.idxmax(dim="x", fill_value=-1)
         assert_identical(result6, expected6)
 
         # Complex fill_value
@@ -5156,7 +5227,8 @@ def test_idxmax(self, x, minindex, maxindex, nanindex):
         expected7 = xr.concat(expected7, dim="y")
         expected7.name = "x"
 
-        result7 = ar0.idxmax(dim="x", fill_value=-5j)
+        with raise_if_dask_computes(max_computes=max_computes):
+            result7 = ar0.idxmax(dim="x", fill_value=-5j)
         assert_identical(result7, expected7)


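The tests above rely on raise_if_dask_computes, imported from xarray/tests/test_dask.py, to assert that idxmin/idxmax stay lazy. For orientation, here is a sketch of how such a guard can be built from a compute-counting scheduler; details of the real helper may differ.

    import dask


    class CountingScheduler:
        """Dask scheduler wrapper that counts computes and fails past a limit."""

        def __init__(self, max_computes=0):
            self.total_computes = 0
            self.max_computes = max_computes

        def __call__(self, dsk, keys, **kwargs):
            self.total_computes += 1
            if self.total_computes > self.max_computes:
                raise RuntimeError(
                    "Too many computes. Total: %d > max: %d."
                    % (self.total_computes, self.max_computes)
                )
            # Delegate to the synchronous scheduler once within the budget.
            return dask.get(dsk, keys, **kwargs)


    def raise_if_dask_computes(max_computes=0):
        # dask.config.set(...) is itself a context manager, so callers can write
        # `with raise_if_dask_computes(max_computes=1): ...`
        return dask.config.set(scheduler=CountingScheduler(max_computes))

With max_computes=0 (the default) any graph execution raises, which is how the non-object-dtype cases assert that idxmin/idxmax remain fully lazy; the object-dtype cases allow one compute because, per the TODO comment in the tests, nanops._nan_argminmax_object materializes the data once to check for all-NaN slices.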