Commit

Implement idxmax and idxmin functions (pydata#3871)
* drop numpy 1.12 compat code that can hide other errors

* deep copy _indexes (pydata#3899)

* implement idxmax and idxmin
toddrjen authored Mar 29, 2020
1 parent ca6bb85 commit 1416d5a
Showing 8 changed files with 1,277 additions and 14 deletions.
4 changes: 4 additions & 0 deletions doc/api.rst
@@ -180,6 +180,8 @@ Computation
:py:attr:`~Dataset.any`
:py:attr:`~Dataset.argmax`
:py:attr:`~Dataset.argmin`
:py:attr:`~Dataset.idxmax`
:py:attr:`~Dataset.idxmin`
:py:attr:`~Dataset.max`
:py:attr:`~Dataset.mean`
:py:attr:`~Dataset.median`
@@ -362,6 +364,8 @@ Computation
:py:attr:`~DataArray.any`
:py:attr:`~DataArray.argmax`
:py:attr:`~DataArray.argmin`
:py:attr:`~DataArray.idxmax`
:py:attr:`~DataArray.idxmin`
:py:attr:`~DataArray.max`
:py:attr:`~DataArray.mean`
:py:attr:`~DataArray.median`
6 changes: 6 additions & 0 deletions doc/whats-new.rst
@@ -38,10 +38,16 @@ New Features
- Limited the length of array items with long string reprs to a
reasonable width (:pull:`3900`)
By `Maximilian Roos <https://github.com/max-sixty>`_
- Implement :py:meth:`DataArray.idxmax`, :py:meth:`DataArray.idxmin`,
:py:meth:`Dataset.idxmax`, :py:meth:`Dataset.idxmin`. (:issue:`60`, :pull:`3871`)
By `Todd Jennings <https://github.com/toddrjen>`_


Bug fixes
~~~~~~~~~
- Fix a regression where deleting a coordinate from a copied :py:class:`DataArray`
  could affect the original :py:class:`DataArray`. (:issue:`3899`, :pull:`3871`)
By `Todd Jennings <https://github.com/toddrjen>`_
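A minimal sketch of the regression this entry describes (hypothetical data; the actual trigger is in :issue:`3899`):

    import xarray as xr

    a = xr.DataArray([1, 2, 3], dims="x", coords={"x": [10, 20, 30]})
    b = a.copy(deep=True)
    del b.coords["x"]
    # Before the fix, the copy shared its ``_indexes`` mapping with ``a``, so
    # deleting the coordinate on ``b`` could also affect the original ``a``.
    # With ``_indexes`` deep-copied (see the change to ``DataArray.copy`` below),
    # ``a`` keeps its "x" coordinate and index intact.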


Documentation
66 changes: 65 additions & 1 deletion xarray/core/computation.py
@@ -23,9 +23,10 @@

import numpy as np

from . import duck_array_ops, utils
from . import dtypes, duck_array_ops, utils
from .alignment import deep_align
from .merge import merge_coordinates_without_align
from .nanops import dask_array
from .options import OPTIONS
from .pycompat import dask_array_type
from .utils import is_dict_like
@@ -1338,3 +1339,66 @@ def polyval(coord, coeffs, degree_dim="degree"):
coords={coord.name: coord, degree_dim: np.arange(deg_coord.max() + 1)[::-1]},
)
return (lhs * coeffs).sum(degree_dim)


def _calc_idxminmax(
*,
array,
func: Callable,
dim: Hashable = None,
skipna: bool = None,
fill_value: Any = dtypes.NA,
keep_attrs: bool = None,
):
"""Apply common operations for idxmin and idxmax."""
# This function doesn't make sense for scalars so don't try
if not array.ndim:
raise ValueError("This function does not apply for scalars")

if dim is not None:
pass # Use the dim if available
elif array.ndim == 1:
# it is okay to guess the dim if there is only 1
dim = array.dims[0]
else:
# The dim is not specified and ambiguous. Don't guess.
raise ValueError("Must supply 'dim' argument for multidimensional arrays")

if dim not in array.dims:
        raise KeyError(f'Dimension "{dim}" not found in array dimensions {array.dims}')
if dim not in array.coords:
raise KeyError(f'Dimension "{dim}" does not have coordinates')

    # dtype kinds that can contain NaN, which argmin and argmax cannot handle directly
    na_dtypes = "cfO"

if skipna or (skipna is None and array.dtype.kind in na_dtypes):
# Need to skip NaN values since argmin and argmax can't handle them
allna = array.isnull().all(dim)
array = array.where(~allna, 0)
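        # For example, a slice that is entirely NaN becomes all zeros here, so the
        # argmin/argmax call below returns position 0 for it; that placeholder
        # result is overwritten with ``fill_value`` after the labels are looked up.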

# This will run argmin or argmax.
indx = func(array, dim=dim, axis=None, keep_attrs=keep_attrs, skipna=skipna)

# Get the coordinate we want.
coordarray = array[dim]

    # Handle dask arrays: index the coordinate lazily with the argmin/argmax result.
    if isinstance(array.data, dask_array_type):
        chunks = dict(zip(array.dims, array.chunks))
        dask_coord = dask_array.from_array(coordarray.data, chunks=chunks[dim])
        res = indx.copy(data=dask_coord[indx.data.ravel()].reshape(indx.shape))
        # Reattach the dim name lost by the low-level indexing.
        res.name = dim
    else:
        res = coordarray[(indx,)]
        # The dim is gone, but we need to remove the corresponding coordinate.
        del res.coords[dim]

    if skipna or (skipna is None and array.dtype.kind in na_dtypes):
        # Put the NaN values back in after removing them
        res = res.where(~allna, fill_value)

# Copy attributes from argmin/argmax, if any
res.attrs = indx.attrs

return res
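For a numpy-backed ``DataArray`` the helper above essentially reduces to "take ``argmin``/``argmax`` along ``dim``, then use the result to index the ``dim`` coordinate". A minimal sketch of that core path (hypothetical data; the skipna and dask handling above are ignored):

    import xarray as xr

    da = xr.DataArray(
        [[3.0, 1.0, 2.0], [0.0, 5.0, 4.0]],
        dims=["y", "x"],
        coords={"y": [10, 20], "x": [0.0, 0.5, 1.0]},
    )
    indx = da.argmin(dim="x")    # integer positions along "x": [1, 0]
    labels = da["x"][(indx,)]    # coordinate labels at those positions: [0.5, 0.0]
    # da.idxmin(dim="x") returns these labels as a DataArray named "x", with the
    # now-redundant "x" coordinate removed from the result.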
193 changes: 192 additions & 1 deletion xarray/core/dataarray.py
@@ -930,7 +930,10 @@ def copy(self, deep: bool = True, data: Any = None) -> "DataArray":
"""
variable = self.variable.copy(deep=deep, data=data)
coords = {k: v.copy(deep=deep) for k, v in self._coords.items()}
indexes = self._indexes
if self._indexes is None:
indexes = self._indexes
else:
indexes = {k: v.copy(deep=deep) for k, v in self._indexes.items()}
return self._replace(variable, coords, indexes=indexes)

def __copy__(self) -> "DataArray":
@@ -3505,6 +3508,194 @@ def pad(
)
return self._from_temp_dataset(ds)

def idxmin(
self,
dim: Hashable = None,
skipna: bool = None,
fill_value: Any = dtypes.NA,
keep_attrs: bool = None,
) -> "DataArray":
"""Return the coordinate label of the minimum value along a dimension.
Returns a new `DataArray` named after the dimension with the values of
the coordinate labels along that dimension corresponding to minimum
values along that dimension.
In comparison to :py:meth:`~DataArray.argmin`, this returns the
coordinate label while :py:meth:`~DataArray.argmin` returns the index.
Parameters
----------
dim : str, optional
Dimension over which to apply `idxmin`. This is optional for 1D
arrays, but required for arrays with 2 or more dimensions.
skipna : bool or None, default None
If True, skip missing values (as marked by NaN). By default, only
skips missing values for ``float``, ``complex``, and ``object``
dtypes; other dtypes either do not have a sentinel missing value
(``int``) or ``skipna=True`` has not been implemented
(``datetime64`` or ``timedelta64``).
fill_value : Any, default NaN
Value to be filled in case all of the values along a dimension are
null. By default this is NaN. The fill value and result are
automatically converted to a compatible dtype if possible.
Ignored if ``skipna`` is False.
keep_attrs : bool, default False
If True, the attributes (``attrs``) will be copied from the
original object to the new one. If False (default), the new object
will be returned without attributes.
Returns
-------
reduced : DataArray
New `DataArray` object with `idxmin` applied to its data and the
indicated dimension removed.
See also
--------
Dataset.idxmin, DataArray.idxmax, DataArray.min, DataArray.argmin
Examples
--------
>>> array = xr.DataArray([0, 2, 1, 0, -2], dims="x",
... coords={"x": ['a', 'b', 'c', 'd', 'e']})
>>> array.min()
<xarray.DataArray ()>
array(-2)
>>> array.argmin()
<xarray.DataArray ()>
array(4)
>>> array.idxmin()
<xarray.DataArray 'x' ()>
array('e', dtype='<U1')
>>> array = xr.DataArray([[2.0, 1.0, 2.0, 0.0, -2.0],
... [-4.0, np.NaN, 2.0, np.NaN, -2.0],
... [np.NaN, np.NaN, 1., np.NaN, np.NaN]],
... dims=["y", "x"],
... coords={"y": [-1, 0, 1],
... "x": np.arange(5.)**2}
... )
>>> array.min(dim="x")
<xarray.DataArray (y: 3)>
array([-2., -4., 1.])
Coordinates:
* y (y) int64 -1 0 1
>>> array.argmin(dim="x")
<xarray.DataArray (y: 3)>
array([4, 0, 2])
Coordinates:
* y (y) int64 -1 0 1
>>> array.idxmin(dim="x")
<xarray.DataArray 'x' (y: 3)>
array([16., 0., 4.])
Coordinates:
* y (y) int64 -1 0 1
"""
return computation._calc_idxminmax(
array=self,
func=lambda x, *args, **kwargs: x.argmin(*args, **kwargs),
dim=dim,
skipna=skipna,
fill_value=fill_value,
keep_attrs=keep_attrs,
)
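A quick sketch of the ``fill_value`` behaviour described above, using hypothetical data (expected outputs follow the documented semantics):

    import numpy as np
    import xarray as xr

    arr = xr.DataArray(
        [[np.nan, np.nan], [1.0, 0.0]],
        dims=["y", "x"],
        coords={"y": [0, 1], "x": [10.0, 20.0]},
    )
    arr.idxmin(dim="x")                   # all-NaN first row -> [nan, 20.0]
    arr.idxmin(dim="x", fill_value=-1.0)  # all-NaN rows get the fill value -> [-1.0, 20.0]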

def idxmax(
self,
dim: Hashable = None,
skipna: bool = None,
fill_value: Any = dtypes.NA,
keep_attrs: bool = None,
) -> "DataArray":
"""Return the coordinate label of the maximum value along a dimension.
Returns a new `DataArray` named after the dimension with the values of
the coordinate labels along that dimension corresponding to maximum
values along that dimension.
In comparison to :py:meth:`~DataArray.argmax`, this returns the
coordinate label while :py:meth:`~DataArray.argmax` returns the index.
Parameters
----------
dim : str, optional
Dimension over which to apply `idxmax`. This is optional for 1D
arrays, but required for arrays with 2 or more dimensions.
skipna : bool or None, default None
If True, skip missing values (as marked by NaN). By default, only
skips missing values for ``float``, ``complex``, and ``object``
dtypes; other dtypes either do not have a sentinel missing value
(``int``) or ``skipna=True`` has not been implemented
(``datetime64`` or ``timedelta64``).
fill_value : Any, default NaN
Value to be filled in case all of the values along a dimension are
null. By default this is NaN. The fill value and result are
automatically converted to a compatible dtype if possible.
Ignored if ``skipna`` is False.
keep_attrs : bool, default False
If True, the attributes (``attrs``) will be copied from the
original object to the new one. If False (default), the new object
will be returned without attributes.
Returns
-------
reduced : DataArray
New `DataArray` object with `idxmax` applied to its data and the
indicated dimension removed.
See also
--------
Dataset.idxmax, DataArray.idxmin, DataArray.max, DataArray.argmax
Examples
--------
>>> array = xr.DataArray([0, 2, 1, 0, -2], dims="x",
... coords={"x": ['a', 'b', 'c', 'd', 'e']})
>>> array.max()
<xarray.DataArray ()>
array(2)
>>> array.argmax()
<xarray.DataArray ()>
array(1)
>>> array.idxmax()
<xarray.DataArray 'x' ()>
array('b', dtype='<U1')
>>> array = xr.DataArray([[2.0, 1.0, 2.0, 0.0, -2.0],
... [-4.0, np.NaN, 2.0, np.NaN, -2.0],
... [np.NaN, np.NaN, 1., np.NaN, np.NaN]],
... dims=["y", "x"],
... coords={"y": [-1, 0, 1],
... "x": np.arange(5.)**2}
... )
>>> array.max(dim="x")
<xarray.DataArray (y: 3)>
array([2., 2., 1.])
Coordinates:
* y (y) int64 -1 0 1
>>> array.argmax(dim="x")
<xarray.DataArray (y: 3)>
array([0, 2, 2])
Coordinates:
* y (y) int64 -1 0 1
>>> array.idxmax(dim="x")
<xarray.DataArray 'x' (y: 3)>
array([0., 4., 4.])
Coordinates:
* y (y) int64 -1 0 1
"""
return computation._calc_idxminmax(
array=self,
func=lambda x, *args, **kwargs: x.argmax(*args, **kwargs),
dim=dim,
skipna=skipna,
fill_value=fill_value,
keep_attrs=keep_attrs,
)

# this needs to be at the end, or mypy will confuse with `str`
# https://mypy.readthedocs.io/en/latest/common_issues.html#dealing-with-conflicting-names
str = property(StringAccessor)
(Diffs for the remaining changed files are not shown here.)
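Among the files not shown are the :py:class:`Dataset` counterparts; per the whats-new entry and ``doc/api.rst`` above, :py:meth:`Dataset.idxmax` and :py:meth:`Dataset.idxmin` were added in the same PR. A hedged sketch of the expected per-variable behaviour (hypothetical data):

    import xarray as xr

    ds = xr.Dataset(
        {
            "temperature": ("time", [12.0, 19.0, 15.0]),
            "pressure": ("time", [1010.0, 995.0, 1002.0]),
        },
        coords={"time": ["2020-01-01", "2020-01-02", "2020-01-03"]},
    )
    ds.idxmax(dim="time")
    # Expected: "temperature" -> "2020-01-02" (label of the maximum temperature),
    # "pressure" -> "2020-01-01"; i.e. idxmax is applied to each data variable.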
