Skip to content

Commit

Permalink
Allow apply_ufunc to ignore missing core dims (#8138)
Browse files Browse the repository at this point in the history
* Allow `apply_ufunc` to ignore missing core dims

* Display data returned in ufunc error message

This makes debugging much easier!

* Add tests for multiple var dataset

* Update xarray/core/computation.py

Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>

* Update xarray/core/computation.py

Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>

* more error message changes

---------

Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>
  • Loading branch information
max-sixty and dcherian authored Sep 17, 2023
1 parent b6d4bf7 commit 6b0d47c
Show file tree
Hide file tree
Showing 3 changed files with 229 additions and 14 deletions.
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ New Features
By `Martin Raspaud <https://github.com/mraspaud>`_.
- Improved static typing of reduction methods (:pull:`6746`).
By `Richard Kleijn <https://github.com/rhkleijn>`_.
- Added `on_missing_core_dims` to :py:meth:`apply_ufunc` to allow for copying or
dropping a :py:class:`Dataset`'s variables with missing core dimensions.
(:pull:`8138`)
By `Maximilian Roos <https://github.com/max-sixty>`_.

Breaking changes
~~~~~~~~~~~~~~~~
Expand Down
77 changes: 63 additions & 14 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
from xarray.core.dataset import Dataset
from xarray.core.types import CombineAttrsOptions, JoinOptions

MissingCoreDimOptions = Literal["raise", "copy", "drop"]

_NO_FILL_VALUE = utils.ReprObject("<no-fill-value>")
_DEFAULT_NAME = utils.ReprObject("<default-name>")
_JOINS_WITHOUT_FILL_VALUES = frozenset({"inner", "exact"})
Expand All @@ -42,6 +44,7 @@ def _first_of_type(args, kind):
for arg in args:
if isinstance(arg, kind):
return arg

raise ValueError("This should be unreachable.")


Expand Down Expand Up @@ -347,7 +350,7 @@ def assert_and_return_exact_match(all_keys):
if keys != first_keys:
raise ValueError(
"exact match required for all data variable names, "
f"but {keys!r} != {first_keys!r}"
f"but {list(keys)} != {list(first_keys)}: {set(keys) ^ set(first_keys)} are not in both."
)
return first_keys

Expand Down Expand Up @@ -376,7 +379,7 @@ def collect_dict_values(
]


def _as_variables_or_variable(arg):
def _as_variables_or_variable(arg) -> Variable | tuple[Variable]:
try:
return arg.variables
except AttributeError:
Expand All @@ -396,8 +399,39 @@ def _unpack_dict_tuples(
return out


def _check_core_dims(signature, variable_args, name):
"""
Chcek if an arg has all the core dims required by the signature.
Slightly awkward design, of returning the error message. But we want to
give a detailed error message, which requires inspecting the variable in
the inner loop.
"""
missing = []
for i, (core_dims, variable_arg) in enumerate(
zip(signature.input_core_dims, variable_args)
):
# Check whether all the dims are on the variable. Note that we need the
# `hasattr` to check for a dims property, to protect against the case where
# a numpy array is passed in.
if hasattr(variable_arg, "dims") and set(core_dims) - set(variable_arg.dims):
missing += [[i, variable_arg, core_dims]]
if missing:
message = ""
for i, variable_arg, core_dims in missing:
message += f"Missing core dims {set(core_dims) - set(variable_arg.dims)} from arg number {i + 1} on a variable named `{name}`:\n{variable_arg}\n\n"
message += "Either add the core dimension, or if passing a dataset alternatively pass `on_missing_core_dim` as `copy` or `drop`. "
return message
return True


def apply_dict_of_variables_vfunc(
func, *args, signature: _UFuncSignature, join="inner", fill_value=None
func,
*args,
signature: _UFuncSignature,
join="inner",
fill_value=None,
on_missing_core_dim: MissingCoreDimOptions = "raise",
):
"""Apply a variable level function over dicts of DataArray, DataArray,
Variable and ndarray objects.
Expand All @@ -408,7 +442,20 @@ def apply_dict_of_variables_vfunc(

result_vars = {}
for name, variable_args in zip(names, grouped_by_name):
result_vars[name] = func(*variable_args)
core_dim_present = _check_core_dims(signature, variable_args, name)
if core_dim_present is True:
result_vars[name] = func(*variable_args)
else:
if on_missing_core_dim == "raise":
raise ValueError(core_dim_present)
elif on_missing_core_dim == "copy":
result_vars[name] = variable_args[0]
elif on_missing_core_dim == "drop":
pass
else:
raise ValueError(
f"Invalid value for `on_missing_core_dim`: {on_missing_core_dim!r}"
)

if signature.num_outputs > 1:
return _unpack_dict_tuples(result_vars, signature.num_outputs)
Expand Down Expand Up @@ -441,6 +488,7 @@ def apply_dataset_vfunc(
fill_value=_NO_FILL_VALUE,
exclude_dims=frozenset(),
keep_attrs="override",
on_missing_core_dim: MissingCoreDimOptions = "raise",
) -> Dataset | tuple[Dataset, ...]:
"""Apply a variable level function over Dataset, dict of DataArray,
DataArray, Variable and/or ndarray objects.
Expand All @@ -467,7 +515,12 @@ def apply_dataset_vfunc(
args = tuple(getattr(arg, "data_vars", arg) for arg in args)

result_vars = apply_dict_of_variables_vfunc(
func, *args, signature=signature, join=dataset_join, fill_value=fill_value
func,
*args,
signature=signature,
join=dataset_join,
fill_value=fill_value,
on_missing_core_dim=on_missing_core_dim,
)

out: Dataset | tuple[Dataset, ...]
Expand Down Expand Up @@ -595,17 +648,9 @@ def broadcast_compat_data(
return data

set_old_dims = set(old_dims)
missing_core_dims = [d for d in core_dims if d not in set_old_dims]
if missing_core_dims:
raise ValueError(
"operand to apply_ufunc has required core dimensions {}, but "
"some of these dimensions are absent on an input variable: {}".format(
list(core_dims), missing_core_dims
)
)

set_new_dims = set(new_dims)
unexpected_dims = [d for d in old_dims if d not in set_new_dims]

if unexpected_dims:
raise ValueError(
"operand to apply_ufunc encountered unexpected "
Expand Down Expand Up @@ -851,6 +896,7 @@ def apply_ufunc(
output_sizes: Mapping[Any, int] | None = None,
meta: Any = None,
dask_gufunc_kwargs: dict[str, Any] | None = None,
on_missing_core_dim: MissingCoreDimOptions = "raise",
) -> Any:
"""Apply a vectorized function for unlabeled arrays on xarray objects.
Expand Down Expand Up @@ -964,6 +1010,8 @@ def apply_ufunc(
:py:func:`dask.array.apply_gufunc`. ``meta`` should be given in the
``dask_gufunc_kwargs`` parameter . It will be removed as direct parameter
a future version.
on_missing_core_dim : {"raise", "copy", "drop"}, default: "raise"
How to handle missing core dimensions on input variables.
Returns
-------
Expand Down Expand Up @@ -1192,6 +1240,7 @@ def apply_ufunc(
dataset_join=dataset_join,
fill_value=dataset_fill_value,
keep_attrs=keep_attrs,
on_missing_core_dim=on_missing_core_dim,
)
# feed DataArray apply_variable_ufunc through apply_dataarray_vfunc
elif any(isinstance(a, DataArray) for a in args):
Expand Down
162 changes: 162 additions & 0 deletions xarray/tests/test_computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,168 @@ def func(x):
assert_identical(out1, dataset)


def test_apply_missing_dims() -> None:
## Single arg

def add_one(a, core_dims, on_missing_core_dim):
return apply_ufunc(
lambda x: x + 1,
a,
input_core_dims=core_dims,
output_core_dims=core_dims,
on_missing_core_dim=on_missing_core_dim,
)

array = np.arange(6).reshape(2, 3)
variable = xr.Variable(["x", "y"], array)
variable_no_y = xr.Variable(["x", "z"], array)

ds = xr.Dataset({"x_y": variable, "x_z": variable_no_y})

# Check the standard stuff works OK
assert_identical(
add_one(ds[["x_y"]], core_dims=[["y"]], on_missing_core_dim="raise"),
ds[["x_y"]] + 1,
)

# `raise` — should raise on a missing dim
with pytest.raises(ValueError):
add_one(ds, core_dims=[["y"]], on_missing_core_dim="raise"),

# `drop` — should drop the var with the missing dim
assert_identical(
add_one(ds, core_dims=[["y"]], on_missing_core_dim="drop"),
(ds + 1).drop_vars("x_z"),
)

# `copy` — should not add one to the missing with `copy`
copy_result = add_one(ds, core_dims=[["y"]], on_missing_core_dim="copy")
assert_identical(copy_result["x_y"], (ds + 1)["x_y"])
assert_identical(copy_result["x_z"], ds["x_z"])

## Multiple args

def sum_add(a, b, core_dims, on_missing_core_dim):
return apply_ufunc(
lambda a, b, axis=None: a.sum(axis) + b.sum(axis),
a,
b,
input_core_dims=core_dims,
on_missing_core_dim=on_missing_core_dim,
)

# Check the standard stuff works OK
assert_identical(
sum_add(
ds[["x_y"]],
ds[["x_y"]],
core_dims=[["x", "y"], ["x", "y"]],
on_missing_core_dim="raise",
),
ds[["x_y"]].sum() * 2,
)

# `raise` — should raise on a missing dim
with pytest.raises(
ValueError,
match=r".*Missing core dims \{'y'\} from arg number 1 on a variable named `x_z`:\n.*<xarray.Variable \(x: 2, z: ",
):
sum_add(
ds[["x_z"]],
ds[["x_z"]],
core_dims=[["x", "y"], ["x", "z"]],
on_missing_core_dim="raise",
)

# `raise` on a missing dim on a non-first arg
with pytest.raises(
ValueError,
match=r".*Missing core dims \{'y'\} from arg number 2 on a variable named `x_z`:\n.*<xarray.Variable \(x: 2, z: ",
):
sum_add(
ds[["x_z"]],
ds[["x_z"]],
core_dims=[["x", "z"], ["x", "y"]],
on_missing_core_dim="raise",
)

# `drop` — should drop the var with the missing dim
assert_identical(
sum_add(
ds[["x_z"]],
ds[["x_z"]],
core_dims=[["x", "y"], ["x", "y"]],
on_missing_core_dim="drop",
),
ds[[]],
)

# `copy` — should drop the var with the missing dim
assert_identical(
sum_add(
ds[["x_z"]],
ds[["x_z"]],
core_dims=[["x", "y"], ["x", "y"]],
on_missing_core_dim="copy",
),
ds[["x_z"]],
)

## Multiple vars per arg
assert_identical(
sum_add(
ds[["x_y", "x_y"]],
ds[["x_y", "x_y"]],
core_dims=[["x", "y"], ["x", "y"]],
on_missing_core_dim="raise",
),
ds[["x_y", "x_y"]].sum() * 2,
)

assert_identical(
sum_add(
ds,
ds,
core_dims=[["x", "y"], ["x", "y"]],
on_missing_core_dim="drop",
),
ds[["x_y"]].sum() * 2,
)

assert_identical(
sum_add(
# The first one has the wrong dims — using `z` for the `x_t` var
ds.assign(x_t=ds["x_z"])[["x_y", "x_t"]],
ds.assign(x_t=ds["x_y"])[["x_y", "x_t"]],
core_dims=[["x", "y"], ["x", "y"]],
on_missing_core_dim="drop",
),
ds[["x_y"]].sum() * 2,
)

assert_identical(
sum_add(
ds.assign(x_t=ds["x_y"])[["x_y", "x_t"]],
# The second one has the wrong dims — using `z` for the `x_t` var (seems
# duplicative but this was a bug in the initial impl...)
ds.assign(x_t=ds["x_z"])[["x_y", "x_t"]],
core_dims=[["x", "y"], ["x", "y"]],
on_missing_core_dim="drop",
),
ds[["x_y"]].sum() * 2,
)

assert_identical(
sum_add(
ds.assign(x_t=ds["x_y"])[["x_y", "x_t"]],
ds.assign(x_t=ds["x_z"])[["x_y", "x_t"]],
core_dims=[["x", "y"], ["x", "y"]],
on_missing_core_dim="copy",
),
ds.drop_vars("x_z").assign(x_y=30, x_t=ds["x_y"]),
)


@requires_dask
def test_apply_dask_parallelized_two_outputs() -> None:
data_array = xr.DataArray([[0, 1, 2], [1, 2, 3]], dims=("x", "y"))
Expand Down

0 comments on commit 6b0d47c

Please sign in to comment.