Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow apply_ufunc to ignore missing core dims #8138

Merged
merged 19 commits into from
Sep 17, 2023
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 53 additions & 14 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
from xarray.core.dataset import Dataset
from xarray.core.types import CombineAttrsOptions, JoinOptions

MissingCoreDimOptions = Literal["raise", "copy", "drop"]

_NO_FILL_VALUE = utils.ReprObject("<no-fill-value>")
_DEFAULT_NAME = utils.ReprObject("<default-name>")
_JOINS_WITHOUT_FILL_VALUES = frozenset({"inner", "exact"})
Expand All @@ -42,6 +44,7 @@ def _first_of_type(args, kind):
for arg in args:
if isinstance(arg, kind):
return arg

raise ValueError("This should be unreachable.")


Expand Down Expand Up @@ -347,7 +350,7 @@ def assert_and_return_exact_match(all_keys):
if keys != first_keys:
raise ValueError(
"exact match required for all data variable names, "
f"but {keys!r} != {first_keys!r}"
f"but {list(keys)} != {list(first_keys)}: {set(keys) ^ set(first_keys)} are not in both."
)
return first_keys

Expand Down Expand Up @@ -376,7 +379,7 @@ def collect_dict_values(
]


def _as_variables_or_variable(arg):
def _as_variables_or_variable(arg) -> Variable | tuple[Variable]:
try:
return arg.variables
except AttributeError:
Expand All @@ -396,8 +399,31 @@ def _unpack_dict_tuples(
return out


def _check_core_dims(signature, variable_args, name):
for core_dims, variable_arg in zip(signature.input_core_dims, variable_args):
# Check whether all the dims are on the variable. Note that we need the
# `hasattr` to check for a dims property, to protect against the case where
# a numpy array is passed in.
if hasattr(variable_arg, "dims") and set(core_dims) - set(variable_arg.dims):
# Slightly awkward design, of returning the error message. But we want to
# give a detailed error message, which requires inspecting the variable in
# the inner loop.
return (
f"Missing core dimension(s) {set(core_dims) - set(variable_arg.dims)} on `{name}`. "
"Either add the core dimension, or set `missing_core_dim` to `copy` or `drop`. "
max-sixty marked this conversation as resolved.
Show resolved Hide resolved
"The object:"
f"\n\n{variable_arg}"
)
return True


def apply_dict_of_variables_vfunc(
func, *args, signature: _UFuncSignature, join="inner", fill_value=None
func,
*args,
signature: _UFuncSignature,
join="inner",
fill_value=None,
missing_core_dim: MissingCoreDimOptions = "raise",
):
"""Apply a variable level function over dicts of DataArray, DataArray,
Variable and ndarray objects.
Expand All @@ -408,7 +434,20 @@ def apply_dict_of_variables_vfunc(

result_vars = {}
for name, variable_args in zip(names, grouped_by_name):
result_vars[name] = func(*variable_args)
core_dim_check = _check_core_dims(signature, variable_args, name)
if core_dim_check is True:
max-sixty marked this conversation as resolved.
Show resolved Hide resolved
result_vars[name] = func(*variable_args)
else:
if missing_core_dim == "raise":
raise ValueError(core_dim_check)
elif missing_core_dim == "copy":
result_vars[name] = variable_args[0]
max-sixty marked this conversation as resolved.
Show resolved Hide resolved
elif missing_core_dim == "drop":
pass
else:
raise ValueError(
dcherian marked this conversation as resolved.
Show resolved Hide resolved
f"Invalid value for `missing_core_dim`: {missing_core_dim!r}"
)

if signature.num_outputs > 1:
return _unpack_dict_tuples(result_vars, signature.num_outputs)
Expand Down Expand Up @@ -441,6 +480,7 @@ def apply_dataset_vfunc(
fill_value=_NO_FILL_VALUE,
exclude_dims=frozenset(),
keep_attrs="override",
missing_core_dim: MissingCoreDimOptions = "raise",
) -> Dataset | tuple[Dataset, ...]:
"""Apply a variable level function over Dataset, dict of DataArray,
DataArray, Variable and/or ndarray objects.
Expand All @@ -467,7 +507,12 @@ def apply_dataset_vfunc(
args = tuple(getattr(arg, "data_vars", arg) for arg in args)

result_vars = apply_dict_of_variables_vfunc(
func, *args, signature=signature, join=dataset_join, fill_value=fill_value
func,
*args,
signature=signature,
join=dataset_join,
fill_value=fill_value,
missing_core_dim=missing_core_dim,
)

out: Dataset | tuple[Dataset, ...]
Expand Down Expand Up @@ -595,17 +640,9 @@ def broadcast_compat_data(
return data

set_old_dims = set(old_dims)
missing_core_dims = [d for d in core_dims if d not in set_old_dims]
if missing_core_dims:
raise ValueError(
"operand to apply_ufunc has required core dimensions {}, but "
"some of these dimensions are absent on an input variable: {}".format(
list(core_dims), missing_core_dims
)
)

set_new_dims = set(new_dims)
unexpected_dims = [d for d in old_dims if d not in set_new_dims]

if unexpected_dims:
raise ValueError(
"operand to apply_ufunc encountered unexpected "
Expand Down Expand Up @@ -851,6 +888,7 @@ def apply_ufunc(
output_sizes: Mapping[Any, int] | None = None,
meta: Any = None,
dask_gufunc_kwargs: dict[str, Any] | None = None,
missing_core_dim: MissingCoreDimOptions = "raise",
max-sixty marked this conversation as resolved.
Show resolved Hide resolved
) -> Any:
"""Apply a vectorized function for unlabeled arrays on xarray objects.

Expand Down Expand Up @@ -1192,6 +1230,7 @@ def apply_ufunc(
dataset_join=dataset_join,
fill_value=dataset_fill_value,
keep_attrs=keep_attrs,
missing_core_dim=missing_core_dim,
)
# feed DataArray apply_variable_ufunc through apply_dataarray_vfunc
elif any(isinstance(a, DataArray) for a in args):
Expand Down
160 changes: 160 additions & 0 deletions xarray/tests/test_computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,166 @@ def func(x):
assert_identical(out1, dataset)


def test_apply_missing_dims() -> None:
    """Check the ``missing_core_dim`` option of ``apply_ufunc`` on Datasets.

    Covers the three modes (``"raise"``, ``"drop"``, ``"copy"``) for a
    single-argument function, a two-argument function, and Datasets holding
    several variables of which only some lack a core dimension.
    """
    ## Single arg

    def add_one(a, core_dims, missing_core_dim):
        return apply_ufunc(
            lambda x: x + 1,
            a,
            input_core_dims=core_dims,
            output_core_dims=core_dims,
            missing_core_dim=missing_core_dim,
        )

    array = np.arange(6).reshape(2, 3)
    variable = xr.Variable(["x", "y"], array)
    variable_no_y = xr.Variable(["x", "z"], array)

    ds = xr.Dataset({"x_y": variable, "x_z": variable_no_y})

    # Check the standard stuff works OK
    assert_identical(
        add_one(ds[["x_y"]], core_dims=[["y"]], missing_core_dim="raise"),
        ds[["x_y"]] + 1,
    )

    # `raise` — should raise on a missing dim
    with pytest.raises(
        ValueError, match=r".*Missing core dimension\(s\) \{'y'\} on `x_z`.*"
    ):
        add_one(ds, core_dims=[["y"]], missing_core_dim="raise")

    # `drop` — should drop the var with the missing dim
    assert_identical(
        add_one(ds, core_dims=[["y"]], missing_core_dim="drop"),
        (ds + 1).drop_vars("x_z"),
    )

    # `copy` — should not add one to the missing with `copy`
    copy_result = add_one(ds, core_dims=[["y"]], missing_core_dim="copy")
    assert_identical(copy_result["x_y"], (ds + 1)["x_y"])
    assert_identical(copy_result["x_z"], ds["x_z"])

    ## Multiple args

    def sum_add(a, b, core_dims, missing_core_dim):
        return apply_ufunc(
            lambda a, b, axis=None: a.sum(axis) + b.sum(axis),
            a,
            b,
            input_core_dims=core_dims,
            missing_core_dim=missing_core_dim,
        )

    # Check the standard stuff works OK
    assert_identical(
        sum_add(
            ds[["x_y"]],
            ds[["x_y"]],
            core_dims=[["x", "y"], ["x", "y"]],
            missing_core_dim="raise",
        ),
        ds[["x_y"]].sum() * 2,
    )

    # `raise` — should raise on a missing dim
    with pytest.raises(
        ValueError, match=r".*Missing core dimension\(s\) \{'y'\} on `x_z`.*"
    ):
        sum_add(
            ds[["x_z"]],
            ds[["x_z"]],
            core_dims=[["x", "y"], ["x", "z"]],
            missing_core_dim="raise",
        )

    # `raise` on a missing dim on a non-first arg
    with pytest.raises(
        ValueError, match=r".*Missing core dimension\(s\) \{'y'\} on `x_z`.*"
    ):
        sum_add(
            ds[["x_z"]],
            ds[["x_z"]],
            core_dims=[["x", "z"], ["x", "y"]],
            missing_core_dim="raise",
        )

    # `drop` — should drop the var with the missing dim
    assert_identical(
        sum_add(
            ds[["x_z"]],
            ds[["x_z"]],
            core_dims=[["x", "y"], ["x", "y"]],
            missing_core_dim="drop",
        ),
        ds[[]],
    )

    # `copy` — should pass the var with the missing dim through unchanged
    assert_identical(
        sum_add(
            ds[["x_z"]],
            ds[["x_z"]],
            core_dims=[["x", "y"], ["x", "y"]],
            missing_core_dim="copy",
        ),
        ds[["x_z"]],
    )

    ## Multiple vars per arg
    assert_identical(
        sum_add(
            ds[["x_y", "x_y"]],
            ds[["x_y", "x_y"]],
            core_dims=[["x", "y"], ["x", "y"]],
            missing_core_dim="raise",
        ),
        ds[["x_y", "x_y"]].sum() * 2,
    )

    assert_identical(
        sum_add(
            ds,
            ds,
            core_dims=[["x", "y"], ["x", "y"]],
            missing_core_dim="drop",
        ),
        ds[["x_y"]].sum() * 2,
    )

    assert_identical(
        sum_add(
            # The first one has the wrong dims — using `z` for the `x_t` var
            ds.assign(x_t=ds["x_z"])[["x_y", "x_t"]],
            ds.assign(x_t=ds["x_y"])[["x_y", "x_t"]],
            core_dims=[["x", "y"], ["x", "y"]],
            missing_core_dim="drop",
        ),
        ds[["x_y"]].sum() * 2,
    )

    assert_identical(
        sum_add(
            ds.assign(x_t=ds["x_y"])[["x_y", "x_t"]],
            # The second one has the wrong dims — using `z` for the `x_t` var (seems
            # duplicative but this was a bug in the initial impl...)
            ds.assign(x_t=ds["x_z"])[["x_y", "x_t"]],
            core_dims=[["x", "y"], ["x", "y"]],
            missing_core_dim="drop",
        ),
        ds[["x_y"]].sum() * 2,
    )

    # With `copy`, the variable whose second argument lacks `y` (`x_t`) is
    # taken unchanged from the first argument.
    assert_identical(
        sum_add(
            ds.assign(x_t=ds["x_y"])[["x_y", "x_t"]],
            ds.assign(x_t=ds["x_z"])[["x_y", "x_t"]],
            core_dims=[["x", "y"], ["x", "y"]],
            missing_core_dim="copy",
        ),
        ds.drop_vars("x_z").assign(x_y=30, x_t=ds["x_y"]),
    )


@requires_dask
def test_apply_dask_parallelized_two_outputs() -> None:
data_array = xr.DataArray([[0, 1, 2], [1, 2, 3]], dims=("x", "y"))
Expand Down
Loading