Skip to content

Commit

Permalink
Use high-level graphs
Browse files Browse the repository at this point in the history
  • Loading branch information
TomAugspurger committed Dec 14, 2018
1 parent 1588e12 commit e211338
Showing 1 changed file with 13 additions and 1 deletion.
14 changes: 13 additions & 1 deletion dask_ml/_partial.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,19 @@ def fit(model, x, y, compute=True, shuffle_blocks=True, random_state=None, **kwa
}
)

new_dsk = dask.sharedict.merge((name, dsk), x.dask, getattr(y, "dask", {}))
graphs = {x.name: x.__dask_graph__()}
if hasattr(y, "__dask_graph__"):
graphs[y.name] = y.__dask_graph__()

try:
from dask.highlevelgraph import HighLevelGraph

new_dsk = HighLevelGraph.merge(*graphs.values())
except ImportError:
from dask import sharedict

new_dsk = sharedict.merge(graphs.values)

value = Delayed((name, nblocks - 1), new_dsk)

if compute:
Expand Down

0 comments on commit e211338

Please sign in to comment.