Skip to content

Commit

Permalink
No commit message
Browse files Browse the repository at this point in the history
  • Loading branch information
max-sixty committed Sep 12, 2023
1 parent db84bda commit ebe707b
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 25 deletions.
47 changes: 28 additions & 19 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,25 +416,34 @@ def apply_dict_of_variables_vfunc(

result_vars = {}
for name, variable_args in zip(names, grouped_by_name):
# TODO: is this a reasonable way to check for missing core dims? We check for a dims
# property to protect against the case where a numpy array is passed in.
if hasattr(variable_args[0], "dims") and set(
signature.all_input_core_dims
) - set(variable_args[0].dims):
if missing_core_dim == "raise":
raise ValueError(
f"Missing core dimension on {name!r}. Either add the core dimension, or set `missing_core_dim` to `copy` or `drop`."
)
elif missing_core_dim == "copy":
# TODO: is it correct to copy the first variable here?
result_vars[name] = variable_args[0]
elif missing_core_dim == "drop":
pass
else:
raise ValueError(
f"Invalid value for `missing_core_dim`: {missing_core_dim!r}"
)
else:
for core_dims, variable_arg in zip(signature.input_core_dims, variable_args):
# (if there's a more elegant way to do this than using a temporary, that'd
# be nice. But we need to have context of the specific object failing in
# order to produce a good error message; and need to copy the variable if
# necessary, skipping the intended evaluation of `func`.)
all_vars_have_all_core_dims = True
# Check whether all the dims are on the variable. Note that we need the
# `hasattr` to check for a dims property, to protect against the case where
# a numpy array is passed in.
if hasattr(variable_arg, "dims") and set(core_dims) - set(
variable_arg.dims
):
if missing_core_dim == "raise":
raise ValueError(
f"Missing core dimension(s) {set(core_dims) - set(variable_arg.dims)} on `{name}` (object below). "
"Either add the core dimension, or set `missing_core_dim` to `copy` or `drop`."
f"\n\n{variable_arg}"
)
elif missing_core_dim == "copy":
result_vars[name] = variable_args[0]
all_vars_have_all_core_dims = False
elif missing_core_dim == "drop":
all_vars_have_all_core_dims = False
else:
raise ValueError(
f"Invalid value for `missing_core_dim`: {missing_core_dim!r}"
)
if all_vars_have_all_core_dims:
result_vars[name] = func(*variable_args)

if signature.num_outputs > 1:
Expand Down
80 changes: 74 additions & 6 deletions xarray/tests/test_computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,8 @@ def func(x):


def test_apply_missing_dims() -> None:
## Single arg

def add_one(a, core_dims, missing_core_dim):
return apply_ufunc(
lambda x: x + 1,
Expand All @@ -271,12 +273,12 @@ def add_one(a, core_dims, missing_core_dim):
variable = xr.Variable(["x", "y"], array)
variable_no_y = xr.Variable(["x", "z"], array)

dataset = xr.Dataset({"matching": variable, "missing": variable_no_y})
dataset = xr.Dataset({"x_y": variable, "x_z": variable_no_y})

# Check the standard stuff works OK
assert_identical(
add_one(dataset[["matching"]], core_dims=[["y"]], missing_core_dim="raise"),
dataset[["matching"]] + 1,
add_one(dataset[["x_y"]], core_dims=[["y"]], missing_core_dim="raise"),
dataset[["x_y"]] + 1,
)

# `raise` — should raise on a missing dim
Expand All @@ -286,13 +288,79 @@ def add_one(a, core_dims, missing_core_dim):
# `drop` — should drop the var with the missing dim
assert_identical(
add_one(dataset, core_dims=[["y"]], missing_core_dim="drop"),
(dataset + 1).drop_vars("missing"),
(dataset + 1).drop_vars("x_z"),
)

# `copy` — should leave the var with the missing dim unchanged (no one added)
copy_result = add_one(dataset, core_dims=[["y"]], missing_core_dim="copy")
assert_identical(copy_result["matching"], (dataset + 1)["matching"])
assert_identical(copy_result["missing"], dataset["missing"])
assert_identical(copy_result["x_y"], (dataset + 1)["x_y"])
assert_identical(copy_result["x_z"], dataset["x_z"])

## Multiple args

# Helper for the multi-argument cases: calls `apply_ufunc` with a function that
# sums each of its two inputs and adds the two totals together.
# `core_dims` is forwarded as `input_core_dims`, and `missing_core_dim` is passed
# through unchanged.
def sum_add(a, b, core_dims, missing_core_dim):
return apply_ufunc(
lambda a, b: a.sum() + b.sum(),
a,
b,
input_core_dims=core_dims,
missing_core_dim=missing_core_dim,
)

# Check the standard stuff works OK
assert_identical(
sum_add(
dataset[["x_y"]],
dataset[["x_y"]],
core_dims=[["x", "y"], ["x", "y"]],
missing_core_dim="raise",
),
dataset[["x_y"]].sum() * 2,
)

# `raise` — should raise on a missing dim
with pytest.raises(
ValueError, match=r".*Missing core dimension\(s\) \{'y'\} on `x_z`.*"
):
sum_add(
dataset[["x_z"]],
dataset[["x_z"]],
core_dims=[["x", "y"], ["x", "z"]],
missing_core_dim="raise",
)

# `raise` on a missing dim on a non-first arg
with pytest.raises(
ValueError, match=r".*Missing core dimension\(s\) \{'y'\} on `x_z`.*"
):
sum_add(
dataset[["x_z"]],
dataset[["x_z"]],
core_dims=[["x", "z"], ["x", "y"]],
missing_core_dim="raise",
)

# `drop` — should drop the var with the missing dim
assert_identical(
sum_add(
dataset[["x_z"]],
dataset[["x_z"]],
core_dims=[["x", "y"], ["x", "y"]],
missing_core_dim="drop",
),
dataset[[]],
)

# `copy` — should copy the var with the missing dim through unchanged
assert_identical(
sum_add(
dataset[["x_z"]],
dataset[["x_z"]],
core_dims=[["x", "y"], ["x", "y"]],
missing_core_dim="copy",
),
dataset[["x_z"]],
)


@requires_dask
Expand Down

0 comments on commit ebe707b

Please sign in to comment.