Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

combine keep_attrs and combine_attrs in apply_ufunc #5041

Merged
merged 22 commits into from
May 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
35c1b5b
add tests for most keep_attrs values for apply_ufunc for variables
keewis Mar 15, 2021
c13b080
implement keep_attrs for apply_variable_ufunc
keewis Mar 15, 2021
2ee7076
move the keep_attrs preprocessing to apply_ufunc
keewis Mar 15, 2021
68e09ec
implement keep_attrs for DataArray objects
keewis Mar 15, 2021
b0a3054
implement keep_attrs for Dataset
keewis Mar 15, 2021
da1a239
don't include the definition of the assert func in the traceback
keewis Mar 15, 2021
92d6ce5
check that keep_attrs is applied recursively
keewis Mar 15, 2021
8fccdd1
skip the recursion tests
keewis Mar 15, 2021
5a16add
don't skip the variable tests
keewis Mar 28, 2021
373fcf3
Merge branch 'master' into keep_attrs
keewis Mar 28, 2021
c995d65
Merge branch 'master' into keep_attrs
keewis Apr 5, 2021
fdb7ff3
Merge branch 'master' into keep_attrs
keewis May 5, 2021
d95cdf1
add a combine_attrs parameter to merge_coordinates_without_align
keewis May 5, 2021
6cdfd09
propagate keep_attrs
keewis May 5, 2021
8e23728
Merge branch 'master' into keep_attrs
keewis May 5, 2021
a38dffb
update whats-new.rst [skip-ci]
keewis May 5, 2021
085052b
Merge branch 'master' into keep_attrs
keewis May 8, 2021
b257ad7
change the default
keewis May 8, 2021
bef9064
change the default values to make sure keep_attrs is always a str
keewis May 8, 2021
9c07044
test data, dims and coords separately
keewis May 8, 2021
551b77b
move the whats-new.rst entry [skip-ci]
keewis May 8, 2021
616f7e9
Merge branch 'master' into keep_attrs
dcherian May 13, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ v0.18.1 (unreleased)

New Features
~~~~~~~~~~~~
- allow passing ``combine_attrs`` strategy names to the ``keep_attrs`` parameter of
:py:func:`apply_ufunc` (:pull:`5041`)
By `Justus Magin <https://github.com/keewis>`_.
- :py:meth:`Dataset.interp` now allows interpolation with non-numerical datatypes,
such as booleans, instead of dropping them. (:issue:`4761` :pull:`5008`).
By `Jimmy Westling <https://github.com/illviljan>`_.
Expand Down
92 changes: 59 additions & 33 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@

from . import dtypes, duck_array_ops, utils
from .alignment import align, deep_align
from .merge import merge_coordinates_without_align
from .options import OPTIONS
from .merge import merge_attrs, merge_coordinates_without_align
from .options import OPTIONS, _get_keep_attrs
from .pycompat import is_duck_dask_array
from .utils import is_dict_like
from .variable import Variable
Expand All @@ -50,6 +50,11 @@ def _first_of_type(args, kind):
raise ValueError("This should be unreachable.")


def _all_of_type(args, kind):
"""Return all objects of type 'kind'"""
return [arg for arg in args if isinstance(arg, kind)]


class _UFuncSignature:
"""Core dimensions signature for a given function.

Expand Down Expand Up @@ -202,7 +207,10 @@ def _get_coords_list(args) -> List["Coordinates"]:


def build_output_coords(
args: list, signature: _UFuncSignature, exclude_dims: AbstractSet = frozenset()
args: list,
signature: _UFuncSignature,
exclude_dims: AbstractSet = frozenset(),
combine_attrs: str = "override",
) -> "List[Dict[Any, Variable]]":
"""Build output coordinates for an operation.

Expand Down Expand Up @@ -230,7 +238,7 @@ def build_output_coords(
else:
# TODO: save these merged indexes, instead of re-computing them later
merged_vars, unused_indexes = merge_coordinates_without_align(
coords_list, exclude_dims=exclude_dims
coords_list, exclude_dims=exclude_dims, combine_attrs=combine_attrs
)

output_coords = []
Expand All @@ -248,7 +256,12 @@ def build_output_coords(


def apply_dataarray_vfunc(
func, *args, signature, join="inner", exclude_dims=frozenset(), keep_attrs=False
func,
*args,
signature,
join="inner",
exclude_dims=frozenset(),
keep_attrs="override",
):
"""Apply a variable level function over DataArray, Variable and/or ndarray
objects.
Expand All @@ -260,12 +273,16 @@ def apply_dataarray_vfunc(
args, join=join, copy=False, exclude=exclude_dims, raise_on_invalid=False
)

if keep_attrs:
objs = _all_of_type(args, DataArray)

if keep_attrs == "drop":
name = result_name(args)
else:
first_obj = _first_of_type(args, DataArray)
name = first_obj.name
else:
name = result_name(args)
result_coords = build_output_coords(args, signature, exclude_dims)
result_coords = build_output_coords(
args, signature, exclude_dims, combine_attrs=keep_attrs
)

data_vars = [getattr(a, "variable", a) for a in args]
result_var = func(*data_vars)
Expand All @@ -279,13 +296,12 @@ def apply_dataarray_vfunc(
(coords,) = result_coords
out = DataArray(result_var, coords, name=name, fastpath=True)

if keep_attrs:
if isinstance(out, tuple):
for da in out:
# This is adding attrs in place
da._copy_attrs_from(first_obj)
else:
out._copy_attrs_from(first_obj)
attrs = merge_attrs([x.attrs for x in objs], combine_attrs=keep_attrs)
if isinstance(out, tuple):
for da in out:
da.attrs = attrs
else:
out.attrs = attrs

return out

Expand Down Expand Up @@ -400,7 +416,7 @@ def apply_dataset_vfunc(
dataset_join="exact",
fill_value=_NO_FILL_VALUE,
exclude_dims=frozenset(),
keep_attrs=False,
keep_attrs="override",
):
"""Apply a variable level function over Dataset, dict of DataArray,
DataArray, Variable and/or ndarray objects.
Expand All @@ -414,15 +430,16 @@ def apply_dataset_vfunc(
"dataset_fill_value argument."
)

if keep_attrs:
first_obj = _first_of_type(args, Dataset)
objs = _all_of_type(args, Dataset)

if len(args) > 1:
args = deep_align(
args, join=join, copy=False, exclude=exclude_dims, raise_on_invalid=False
)

list_of_coords = build_output_coords(args, signature, exclude_dims)
list_of_coords = build_output_coords(
args, signature, exclude_dims, combine_attrs=keep_attrs
)
args = [getattr(arg, "data_vars", arg) for arg in args]

result_vars = apply_dict_of_variables_vfunc(
Expand All @@ -435,13 +452,13 @@ def apply_dataset_vfunc(
(coord_vars,) = list_of_coords
out = _fast_dataset(result_vars, coord_vars)

if keep_attrs:
if isinstance(out, tuple):
for ds in out:
# This is adding attrs in place
ds._copy_attrs_from(first_obj)
else:
out._copy_attrs_from(first_obj)
attrs = merge_attrs([x.attrs for x in objs], combine_attrs=keep_attrs)
if isinstance(out, tuple):
for ds in out:
ds.attrs = attrs
else:
out.attrs = attrs

return out


Expand Down Expand Up @@ -609,14 +626,12 @@ def apply_variable_ufunc(
dask="forbidden",
output_dtypes=None,
vectorize=False,
keep_attrs=False,
keep_attrs="override",
dask_gufunc_kwargs=None,
):
"""Apply a ndarray level function over Variable and/or ndarray objects."""
from .variable import Variable, as_compatible_data

first_obj = _first_of_type(args, Variable)

dim_sizes = unified_dim_sizes(
(a for a in args if hasattr(a, "dims")), exclude_dims=exclude_dims
)
Expand Down Expand Up @@ -736,6 +751,12 @@ def func(*arrays):
)
)

objs = _all_of_type(args, Variable)
attrs = merge_attrs(
[obj.attrs for obj in objs],
combine_attrs=keep_attrs,
)

output = []
for dims, data in zip(output_dims, result_data):
data = as_compatible_data(data)
Expand All @@ -758,8 +779,7 @@ def func(*arrays):
)
)

if keep_attrs:
var.attrs.update(first_obj.attrs)
var.attrs = attrs
output.append(var)

if signature.num_outputs == 1:
Expand Down Expand Up @@ -801,7 +821,7 @@ def apply_ufunc(
join: str = "exact",
dataset_join: str = "exact",
dataset_fill_value: object = _NO_FILL_VALUE,
keep_attrs: bool = False,
keep_attrs: Union[bool, str] = None,
kwargs: Mapping = None,
dask: str = "forbidden",
output_dtypes: Sequence = None,
Expand Down Expand Up @@ -1098,6 +1118,12 @@ def apply_ufunc(
if kwargs:
func = functools.partial(func, **kwargs)

if keep_attrs is None:
keep_attrs = _get_keep_attrs(default=False)

if isinstance(keep_attrs, bool):
keep_attrs = "override" if keep_attrs else "drop"

variables_vfunc = functools.partial(
apply_variable_ufunc,
func,
Expand Down
3 changes: 2 additions & 1 deletion xarray/core/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@ def merge_coordinates_without_align(
objects: "List[Coordinates]",
prioritized: Mapping[Hashable, MergeElement] = None,
exclude_dims: AbstractSet = frozenset(),
combine_attrs: str = "override",
) -> Tuple[Dict[Hashable, Variable], Dict[Hashable, Index]]:
"""Merge variables/indexes from coordinates without automatic alignments.

Expand All @@ -335,7 +336,7 @@ def merge_coordinates_without_align(
else:
filtered = collected

return merge_collected(filtered, prioritized)
return merge_collected(filtered, prioritized, combine_attrs=combine_attrs)


def determine_coords(
Expand Down
Loading