diff --git a/doc/api.rst b/doc/api.rst
index bddccb20e37..bb3a99bfbb0 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -36,6 +36,7 @@ Top-level functions
    map_blocks
    show_versions
    set_options
+   unify_chunks

 Dataset
 =======
diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index ff67ea20073..f592f55a627 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -21,6 +21,8 @@ v0.18.3 (unreleased)
 New Features
 ~~~~~~~~~~~~

+- New top-level function :py:func:`unify_chunks`.
+  By `Mattia Almansi <https://github.com/malmans2>`_.
 - Allow assigning values to a subset of a dataset using positional or label-based
   indexing (:issue:`3015`, :pull:`5362`).
   By `Matthias Göbel `_.
diff --git a/xarray/__init__.py b/xarray/__init__.py
index b4b7135091a..8321aba4b46 100644
--- a/xarray/__init__.py
+++ b/xarray/__init__.py
@@ -18,7 +18,7 @@
 from .core.alignment import align, broadcast
 from .core.combine import combine_by_coords, combine_nested
 from .core.common import ALL_DIMS, full_like, ones_like, zeros_like
-from .core.computation import apply_ufunc, corr, cov, dot, polyval, where
+from .core.computation import apply_ufunc, corr, cov, dot, polyval, unify_chunks, where
 from .core.concat import concat
 from .core.dataarray import DataArray
 from .core.dataset import Dataset
@@ -74,6 +74,7 @@
     "save_mfdataset",
     "set_options",
     "show_versions",
+    "unify_chunks",
     "where",
     "zeros_like",
     # Classes
diff --git a/xarray/core/computation.py b/xarray/core/computation.py
index fe6d8ab35e9..474a72180e1 100644
--- a/xarray/core/computation.py
+++ b/xarray/core/computation.py
@@ -1,6 +1,8 @@
 """
 Functions for applying functions that act on arrays to xarray's labeled data.
 """
+from __future__ import annotations
+
 import functools
 import itertools
 import operator
@@ -19,6 +21,7 @@
     Optional,
     Sequence,
     Tuple,
+    TypeVar,
     Union,
 )
@@ -34,8 +37,11 @@
 if TYPE_CHECKING:
     from .coordinates import Coordinates  # noqa
+    from .dataarray import DataArray
     from .dataset import Dataset

+    T_DSorDA = TypeVar("T_DSorDA", DataArray, Dataset)
+
 _NO_FILL_VALUE = utils.ReprObject("<no-fill-value>")
 _DEFAULT_NAME = utils.ReprObject("<default-name>")
 _JOINS_WITHOUT_FILL_VALUES = frozenset({"inner", "exact"})
@@ -1721,3 +1727,61 @@ def _calc_idxminmax(
         res.attrs = indx.attrs

     return res
+
+
+def unify_chunks(*objects: T_DSorDA) -> Tuple[T_DSorDA, ...]:
+    """
+    Given any number of Dataset and/or DataArray objects, returns
+    new objects with unified chunk sizes along all chunked dimensions.
+
+    Returns
+    -------
+    unified : tuple of DataArray or Dataset
+        Objects with the same type as ``*objects``, with consistent chunk sizes for all dask-array variables.
+
+    See Also
+    --------
+    dask.array.core.unify_chunks
+    """
+    from .dataarray import DataArray
+
+    # Convert all objects to datasets
+    datasets = [
+        obj._to_temp_dataset() if isinstance(obj, DataArray) else obj.copy()
+        for obj in objects
+    ]
+
+    # Get arguments to pass into dask.array.core.unify_chunks
+    unify_chunks_args = []
+    sizes: dict[Hashable, int] = {}
+    for ds in datasets:
+        for v in ds._variables.values():
+            if v.chunks is not None:
+                # Check that sizes match across different datasets
+                for dim, size in v.sizes.items():
+                    try:
+                        if sizes[dim] != size:
+                            raise ValueError(
+                                f"Dimension {dim!r} size mismatch: {sizes[dim]} != {size}"
+                            )
+                    except KeyError:
+                        sizes[dim] = size
+                unify_chunks_args += [v._data, v._dims]
+
+    # No dask arrays: return inputs unchanged
+    if not unify_chunks_args:
+        return objects
+
+    # Run dask.array.core.unify_chunks
+    from dask.array.core import unify_chunks
+
+    _, dask_data = unify_chunks(*unify_chunks_args)
+    dask_data_iter = iter(dask_data)
+    out = []
+    for obj, ds in zip(objects, datasets):
+        for k, v in ds._variables.items():
+            if v.chunks is not None:
+                ds._variables[k] = v.copy(data=next(dask_data_iter))
+        out.append(obj._from_temp_dataset(ds) if isinstance(obj, DataArray) else ds)
+
+    return tuple(out)
diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py
index dffe97f7391..eab4413d5ce 100644
--- a/xarray/core/dataarray.py
+++ b/xarray/core/dataarray.py
@@ -44,6 +44,7 @@
 )
 from .arithmetic import DataArrayArithmetic
 from .common import AbstractArray, DataWithCoords
+from .computation import unify_chunks
 from .coordinates import (
     DataArrayCoordinates,
     assert_coordinate_consistent,
@@ -3686,8 +3687,8 @@ def unify_chunks(self) -> "DataArray":
         --------
         dask.array.core.unify_chunks
         """
-        ds = self._to_temp_dataset().unify_chunks()
-        return self._from_temp_dataset(ds)
+
+        return unify_chunks(self)[0]

     def map_blocks(
         self,
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index 466b27dd3a1..48925b70b66 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -53,6 +53,7 @@
 from .alignment import _broadcast_helper, _get_broadcast_dims_map_common_coords, align
 from .arithmetic import DatasetArithmetic
 from .common import DataWithCoords, _contains_datetime_like_objects
+from .computation import unify_chunks
 from .coordinates import (
     DatasetCoordinates,
     assert_coordinate_consistent,
@@ -6566,37 +6567,7 @@ def unify_chunks(self) -> "Dataset":
         dask.array.core.unify_chunks
         """

-        try:
-            self.chunks
-        except ValueError:  # "inconsistent chunks"
-            pass
-        else:
-            # No variables with dask backend, or all chunks are already aligned
-            return self.copy()
-
-        # import dask is placed after the quick exit test above to allow
-        # running this method if dask isn't installed and there are no chunks
-        import dask.array
-
-        ds = self.copy()
-
-        dims_pos_map = {dim: index for index, dim in enumerate(ds.dims)}
-
-        dask_array_names = []
-        dask_unify_args = []
-        for name, variable in ds.variables.items():
-            if isinstance(variable.data, dask.array.Array):
-                dims_tuple = [dims_pos_map[dim] for dim in variable.dims]
-                dask_array_names.append(name)
-                dask_unify_args.append(variable.data)
-                dask_unify_args.append(dims_tuple)
-
-        _, rechunked_arrays = dask.array.core.unify_chunks(*dask_unify_args)
-
-        for name, new_array in zip(dask_array_names, rechunked_arrays):
-            ds.variables[name]._data = new_array
-
-        return ds
+        return unify_chunks(self)[0]

     def map_blocks(
         self,
diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py
index e80cc2f4ffe..f790587efa9 100644
--- a/xarray/tests/test_dask.py
+++ b/xarray/tests/test_dask.py
@@ -1069,12 +1069,26 @@ def test_unify_chunks(map_ds):
     with pytest.raises(ValueError, match=r"inconsistent chunks"):
         ds_copy.chunks

-    expected_chunks = {"x": (4, 4, 2), "y": (5, 5, 5, 5), "z": (4,)}
+    expected_chunks = {"x": (4, 4, 2), "y": (5, 5, 5, 5)}
     with raise_if_dask_computes():
         actual_chunks = ds_copy.unify_chunks().chunks
-    expected_chunks == actual_chunks
+    assert actual_chunks == expected_chunks
     assert_identical(map_ds, ds_copy.unify_chunks())

+    out_a, out_b = xr.unify_chunks(ds_copy.cxy, ds_copy.drop_vars("cxy"))
+    assert out_a.chunks == ((4, 4, 2), (5, 5, 5, 5))
+    assert out_b.chunks == expected_chunks
+
+    # Test unordered dims
+    da = ds_copy["cxy"]
+    out_a, out_b = xr.unify_chunks(da.chunk({"x": -1}), da.T.chunk({"y": -1}))
+    assert out_a.chunks == ((4, 4, 2), (5, 5, 5, 5))
+    assert out_b.chunks == ((5, 5, 5, 5), (4, 4, 2))
+
+    # Test mismatch
+    with pytest.raises(ValueError, match=r"Dimension 'x' size mismatch: 10 != 2"):
+        xr.unify_chunks(da, da.isel(x=slice(2)))
+

 @pytest.mark.parametrize("obj", [make_ds(), make_da()])
 @pytest.mark.parametrize(
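
For reference, here is a minimal usage sketch of the new top-level function. Only xr.unify_chunks itself comes from this patch; the array values, dimension name, and chunk sizes below are made up for illustration, and running the snippet requires dask to be installed.

    import numpy as np
    import xarray as xr

    # Two DataArrays share dimension "x" but are chunked incompatibly:
    # "a" has chunks (4, 4, 2), "b" has chunks (5, 5).
    a = xr.DataArray(np.arange(10), dims="x", name="a").chunk({"x": 4})
    b = xr.DataArray(np.arange(10.0), dims="x", name="b").chunk({"x": 5})

    # unify_chunks lazily rechunks all inputs to a common chunking
    # (no computation is triggered) and returns them with their
    # original types, here two DataArrays.
    a2, b2 = xr.unify_chunks(a, b)
    assert a2.chunks == b2.chunks  # both (4, 1, 3, 2) along "x"

    # Mismatched sizes along a shared dimension raise the new error,
    # e.g. "Dimension 'x' size mismatch: 10 != 2":
    try:
        xr.unify_chunks(a, b.isel(x=slice(2)))
    except ValueError as err:
        print(err)

As the dataarray.py and dataset.py hunks show, DataArray.unify_chunks and Dataset.unify_chunks now delegate to this same helper via unify_chunks(self)[0], so the method and function forms stay consistent.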