diff --git a/ci/requirements/environment-windows.yml b/ci/requirements/environment-windows.yml
index 9455ef2f127..6de2bc8dc64 100644
--- a/ci/requirements/environment-windows.yml
+++ b/ci/requirements/environment-windows.yml
@@ -8,7 +8,7 @@ dependencies:
   # - cdms2  # Not available on Windows
   # - cfgrib  # Causes Python interpreter crash on Windows: https://github.com/pydata/xarray/pull/3340
   - cftime
-  - dask<2021.02.0
+  - dask
   - distributed
   - h5netcdf
   - h5py=2
diff --git a/ci/requirements/environment.yml b/ci/requirements/environment.yml
index 7261b5b6954..0f59d9570c8 100644
--- a/ci/requirements/environment.yml
+++ b/ci/requirements/environment.yml
@@ -9,7 +9,7 @@ dependencies:
   - cdms2
   - cfgrib
   - cftime
-  - dask<2021.02.0
+  - dask
   - distributed
   - h5netcdf
   - h5py=2
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index 7d51adb5244..066a2f690b0 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -866,13 +866,12 @@ def __dask_postcompute__(self):
         import dask

         info = [
-            (True, k, v.__dask_postcompute__())
+            (k, None) + v.__dask_postcompute__()
             if dask.is_dask_collection(v)
-            else (False, k, v)
+            else (k, v, None, None)
             for k, v in self._variables.items()
         ]
-        args = (
-            info,
+        construct_direct_args = (
             self._coord_names,
             self._dims,
             self._attrs,
@@ -880,19 +879,18 @@
             self._encoding,
             self._close,
         )
-        return self._dask_postcompute, args
+        return self._dask_postcompute, (info, construct_direct_args)

     def __dask_postpersist__(self):
         import dask

         info = [
-            (True, k, v.__dask_postpersist__())
+            (k, None, v.__dask_keys__()) + v.__dask_postpersist__()
             if dask.is_dask_collection(v)
-            else (False, k, v)
+            else (k, v, None, None, None)
             for k, v in self._variables.items()
         ]
-        args = (
-            info,
+        construct_direct_args = (
             self._coord_names,
             self._dims,
             self._attrs,
@@ -900,45 +898,37 @@
             self._encoding,
             self._close,
         )
-        return self._dask_postpersist, args
+        return self._dask_postpersist, (info, construct_direct_args)

     @staticmethod
-    def _dask_postcompute(results, info, *args):
+    def _dask_postcompute(results, info, construct_direct_args):
         variables = {}
-        results2 = list(results[::-1])
-        for is_dask, k, v in info:
-            if is_dask:
-                func, args2 = v
-                r = results2.pop()
-                result = func(r, *args2)
+        results_iter = iter(results)
+        for k, v, rebuild, rebuild_args in info:
+            if v is None:
+                variables[k] = rebuild(next(results_iter), *rebuild_args)
             else:
-                result = v
-            variables[k] = result
+                variables[k] = v

-        final = Dataset._construct_direct(variables, *args)
+        final = Dataset._construct_direct(variables, *construct_direct_args)
         return final

     @staticmethod
-    def _dask_postpersist(dsk, info, *args):
+    def _dask_postpersist(dsk, info, construct_direct_args):
+        from dask.optimization import cull
+
         variables = {}
         # postpersist is called in both dask.optimize and dask.persist
         # When persisting, we want to filter out unrelated keys for
         # each Variable's task graph.
-        is_persist = len(dsk) == len(info)
-        for is_dask, k, v in info:
-            if is_dask:
-                func, args2 = v
-                if is_persist:
-                    name = args2[1][0]
-                    dsk2 = {k: v for k, v in dsk.items() if k[0] == name}
-                else:
-                    dsk2 = dsk
-                result = func(dsk2, *args2)
+        for k, v, dask_keys, rebuild, rebuild_args in info:
+            if v is None:
+                dsk2, _ = cull(dsk, dask_keys)
+                variables[k] = rebuild(dsk2, *rebuild_args)
             else:
-                result = v
-            variables[k] = result
+                variables[k] = v

-        return Dataset._construct_direct(variables, *args)
+        return Dataset._construct_direct(variables, *construct_direct_args)

     def compute(self, **kwargs) -> "Dataset":
         """Manually trigger loading and/or computation of this dataset's data