Compatibility with dask 2021.02.0 (pydata#4884)
* Compatibility with dask 2021.02.0

* Rework postpersist and postcompute
crusaderky authored Feb 11, 2021
1 parent 10f0227 commit 2a34bfb
Showing 3 changed files with 26 additions and 36 deletions.
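
For orientation, the dask hooks reworked in xarray/core/dataset.py below are part of dask's custom-collection protocol: dask.compute goes through __dask_postcompute__ and dask.persist through __dask_postpersist__. A minimal usage sketch, not part of the commit, assuming dask and a recent xarray are installed:

```python
import numpy as np
import xarray as xr

# Any chunked Dataset exercises the hooks changed in this commit.
ds = xr.Dataset({"a": ("x", np.arange(6))}).chunk({"x": 3})

computed = ds.compute()   # dask.compute -> __dask_postcompute__ / _dask_postcompute
persisted = ds.persist()  # dask.persist -> __dask_postpersist__ / _dask_postpersist
```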
2 changes: 1 addition & 1 deletion ci/requirements/environment-windows.yml
@@ -8,7 +8,7 @@ dependencies:
   # - cdms2 # Not available on Windows
   # - cfgrib # Causes Python interpreter crash on Windows: https://github.com/pydata/xarray/pull/3340
   - cftime
-  - dask<2021.02.0
+  - dask
   - distributed
   - h5netcdf
   - h5py=2
2 changes: 1 addition & 1 deletion ci/requirements/environment.yml
@@ -9,7 +9,7 @@ dependencies:
   - cdms2
   - cfgrib
   - cftime
-  - dask<2021.02.0
+  - dask
   - distributed
   - h5netcdf
   - h5py=2
58 changes: 24 additions & 34 deletions xarray/core/dataset.py
@@ -866,79 +866,69 @@ def __dask_postcompute__(self):
         import dask
 
         info = [
-            (True, k, v.__dask_postcompute__())
+            (k, None) + v.__dask_postcompute__()
             if dask.is_dask_collection(v)
-            else (False, k, v)
+            else (k, v, None, None)
             for k, v in self._variables.items()
         ]
-        args = (
-            info,
+        construct_direct_args = (
             self._coord_names,
             self._dims,
             self._attrs,
             self._indexes,
             self._encoding,
             self._close,
         )
-        return self._dask_postcompute, args
+        return self._dask_postcompute, (info, construct_direct_args)
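
To make the new return shape concrete, here is roughly what dask.compute does with the (finalize, args) pair returned above. This is a simplified sketch, not xarray's or dask's actual code: a single collection, no graph optimization, using the threaded scheduler that dask ships with.

```python
import numpy as np
import xarray as xr
from dask.threaded import get  # dask's built-in threaded scheduler

ds = xr.Dataset({"a": ("x", np.arange(4))}).chunk({"x": 2})

# dask.compute collects (finalize, extra_args) from __dask_postcompute__, runs
# the graph, then calls finalize(results, *extra_args) to rebuild the result.
# With the change above, extra_args == (info, construct_direct_args).
finalize, extra_args = ds.__dask_postcompute__()
results = get(ds.__dask_graph__(), ds.__dask_keys__())
restored = finalize(results, *extra_args)  # a plain in-memory xarray.Dataset
```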

     def __dask_postpersist__(self):
         import dask
 
         info = [
-            (True, k, v.__dask_postpersist__())
+            (k, None, v.__dask_keys__()) + v.__dask_postpersist__()
             if dask.is_dask_collection(v)
-            else (False, k, v)
+            else (k, v, None, None, None)
             for k, v in self._variables.items()
         ]
-        args = (
-            info,
+        construct_direct_args = (
             self._coord_names,
             self._dims,
             self._attrs,
             self._indexes,
             self._encoding,
             self._close,
         )
-        return self._dask_postpersist, args
+        return self._dask_postpersist, (info, construct_direct_args)
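
The postpersist pair is consumed analogously, except that dask hands back a graph of already-computed chunks rather than concrete results. A simplified sketch of the single-machine persist path; the real logic lives in dask.base.persist, `flatten` and the threaded `get` are real dask helpers, and everything else here is illustrative:

```python
import numpy as np
import xarray as xr
from dask.core import flatten
from dask.threaded import get

ds = xr.Dataset({"a": ("x", np.arange(4))}).chunk({"x": 2})

rebuild, extra_args = ds.__dask_postpersist__()

# Compute every chunk, then build a graph mapping each key to its result...
keys = list(flatten(ds.__dask_keys__()))
values = get(ds.__dask_graph__(), keys)
persisted_graph = dict(zip(keys, values))

# ...and let the Dataset rebuild itself on top of that materialized graph.
persisted = rebuild(persisted_graph, *extra_args)
```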

     @staticmethod
-    def _dask_postcompute(results, info, *args):
+    def _dask_postcompute(results, info, construct_direct_args):
         variables = {}
-        results2 = list(results[::-1])
-        for is_dask, k, v in info:
-            if is_dask:
-                func, args2 = v
-                r = results2.pop()
-                result = func(r, *args2)
+        results_iter = iter(results)
+        for k, v, rebuild, rebuild_args in info:
+            if v is None:
+                variables[k] = rebuild(next(results_iter), *rebuild_args)
             else:
-                result = v
-            variables[k] = result
+                variables[k] = v
 
-        final = Dataset._construct_direct(variables, *args)
+        final = Dataset._construct_direct(variables, *construct_direct_args)
         return final
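
The pairing logic above in miniature: `results` holds one entry per dask-backed variable, in the same order those variables appear in `info`, so a single forward iterator lines them up without the reversed-list/pop bookkeeping of the old code. A standalone toy with hypothetical `info` and `results` values, not xarray data:

```python
# Toy stand-in for the loop in _dask_postcompute: entries with v=None are
# dask-backed and consume the next computed result; the rest pass through.
info = [
    ("a", None, list, ()),   # dask-backed: rebuild with list(...)
    ("b", 42, None, None),   # already in memory: keep as-is
    ("c", None, tuple, ()),  # dask-backed: rebuild with tuple(...)
]
results = [(1, 2), (3, 4)]   # computed chunks for "a" and "c", in order

results_iter = iter(results)
variables = {
    k: rebuild(next(results_iter), *rebuild_args) if v is None else v
    for k, v, rebuild, rebuild_args in info
}
assert variables == {"a": [1, 2], "b": 42, "c": (3, 4)}
```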

     @staticmethod
-    def _dask_postpersist(dsk, info, *args):
+    def _dask_postpersist(dsk, info, construct_direct_args):
+        from dask.optimization import cull
+
         variables = {}
-        # postpersist is called in both dask.optimize and dask.persist
-        # When persisting, we want to filter out unrelated keys for
-        # each Variable's task graph.
-        is_persist = len(dsk) == len(info)
-        for is_dask, k, v in info:
-            if is_dask:
-                func, args2 = v
-                if is_persist:
-                    name = args2[1][0]
-                    dsk2 = {k: v for k, v in dsk.items() if k[0] == name}
-                else:
-                    dsk2 = dsk
-                result = func(dsk2, *args2)
+        for k, v, dask_keys, rebuild, rebuild_args in info:
+            if v is None:
+                dsk2, _ = cull(dsk, dask_keys)
+                variables[k] = rebuild(dsk2, *rebuild_args)
             else:
-                result = v
-            variables[k] = result
+                variables[k] = v
 
-        return Dataset._construct_direct(variables, *args)
+        return Dataset._construct_direct(variables, *construct_direct_args)
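
`cull` replaces the old hand-rolled key filtering: it keeps only the tasks reachable from a given set of keys, so each variable is rebuilt from its own subgraph instead of the whole persisted graph. A self-contained toy showing what `dask.optimization.cull` returns; the graph below is made up for illustration:

```python
from operator import add

from dask.optimization import cull

dsk = {
    ("a", 0): 1,
    ("a", 1): (add, ("a", 0), 10),
    ("b", 0): 2,  # unrelated to "a"; dropped by cull
}

culled, dependencies = cull(dsk, [("a", 1)])
assert culled == {("a", 0): 1, ("a", 1): (add, ("a", 0), 10)}
assert ("b", 0) not in culled
```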

     def compute(self, **kwargs) -> "Dataset":
         """Manually trigger loading and/or computation of this dataset's data
