diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index 4f2d0b366df..c5552bd365f 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -92,6 +92,7 @@
 from distributed.multi_lock import MultiLockExtension
 from distributed.node import ServerNode
 from distributed.proctitle import setproctitle
+from distributed.protocol import deserialize
 from distributed.protocol.pickle import dumps, loads
 from distributed.protocol.serialize import Serialized, ToPickle, serialize
 from distributed.publish import PublishExtension
@@ -112,6 +113,7 @@
     key_split_group,
     log_errors,
     no_default,
+    offload,
     recursive_to_dict,
     validate_key,
     wait_for,
@@ -124,11 +126,14 @@
 )
 from distributed.utils_perf import disable_gc_diagnosis, enable_gc_diagnosis
 from distributed.variable import VariableExtension
+from distributed.worker import dumps_task
 
 if TYPE_CHECKING:
     # TODO import from typing (requires Python >=3.10)
     from typing_extensions import TypeAlias
 
+    from dask.highlevelgraph import HighLevelGraph
+
     # Not to be confused with distributed.worker_state_machine.TaskStateState
     TaskStateState: TypeAlias = Literal[
         "released",
@@ -3761,6 +3766,7 @@ def __init__(
         setproctitle("dask scheduler [not started]")
         Scheduler._instances.add(self)
         self.rpc.allow_offload = False
+        self._update_graph_lock = asyncio.Lock()
 
     ##################
     # Administration #
@@ -4316,93 +4322,14 @@ async def add_nanny(self) -> dict[str, Any]:
         }
         return msg
 
-    @staticmethod
-    def _materialize_graph(
-        graph_header: dict, graph_frames: list[bytes], global_annotations: dict
-    ) -> tuple[dict, dict, dict]:
-        try:
-            from distributed.protocol import deserialize
-
-            graph = deserialize(graph_header, graph_frames).data
-            del graph_header, graph_frames
-        except Exception as e:
-            msg = """\
-                Error during deserialization of the task graph. This frequently
-                occurs if the Scheduler and Client have different environments.
-                For more information, see
-                https://docs.dask.org/en/stable/deployment-considerations.html#consistent-software-environments
-            """
-            raise RuntimeError(textwrap.dedent(msg)) from e
-
-        from distributed.worker import dumps_task
-
-        dsk = dask.utils.ensure_dict(graph)
-
-        annotations_by_type: defaultdict[str, dict[str, Any]] = defaultdict(dict)
-        for annotations_type, value in global_annotations.items():
-            annotations_by_type[annotations_type].update(
-                {k: (value(k) if callable(value) else value) for k in dsk}
-            )
-
-        for layer in graph.layers.values():
-            if layer.annotations:
-                annot = layer.annotations
-                for annot_type, value in annot.items():
-                    annotations_by_type[annot_type].update(
-                        {
-                            stringify(k): (value(k) if callable(value) else value)
-                            for k in layer
-                        }
-                    )
-
-        dependencies, _ = get_deps(dsk)
-
-        # Remove `Future` objects from graph and note any future dependencies
-        dsk2 = {}
-        fut_deps = {}
-        for k, v in dsk.items():
-            dsk2[k], futs = unpack_remotedata(v, byte_keys=True)
-            if futs:
-                fut_deps[k] = futs
-        dsk = dsk2
-
-        # - Add in deps for any tasks that depend on futures
-        for k, futures in fut_deps.items():
-            dependencies[k].update(f.key for f in futures)
-        new_dsk = {}
-        # Annotation callables are evaluated on the non-stringified version of
-        # the keys
-        exclusive = set(graph)
-        for k, v in dsk.items():
-            new_k = stringify(k)
-            new_dsk[new_k] = stringify(v, exclusive=exclusive)
-        dsk = new_dsk
-        dependencies = {
-            stringify(k): {stringify(dep) for dep in deps}
-            for k, deps in dependencies.items()
-        }
-
-        # Remove any self-dependencies (happens on test_publish_bag() and others)
-        for k, v in dependencies.items():
-            deps = set(v)
-            if k in deps:
-                deps.remove(k)
-            dependencies[k] = deps
-
-        # Remove aliases
-        for k in list(dsk):
-            if dsk[k] is k:
-                del dsk[k]
-        dsk = valmap(dumps_task, dsk)
-
-        return dsk, dependencies, annotations_by_type
-
-    @staticmethod
-    def _match_graph_with_tasks(known_tasks, dsk, dependencies, keys):
+    def _match_graph_with_tasks(self, dsk, dependencies, keys):
         n = 0
         lost_keys = set()
         while len(dsk) != n:  # walk through new tasks, cancel any bad deps
             n = len(dsk)
             for k, deps in list(dependencies.items()):
                 if any(
-                    dep not in known_tasks and dep not in dsk for dep in deps
+                    dep not in self.tasks and dep not in dsk for dep in deps
                 ):  # bad key
                     lost_keys.add(k)
                     logger.info("User asked for computation on lost data, %s", k)
@@ -4414,8 +4341,8 @@ def _match_graph_with_tasks(known_tasks, dsk, dependencies, keys):
 
         # Avoid computation that is already finished
         already_in_memory = set()  # tasks that are already done
         for k, v in dependencies.items():
-            if v and k in known_tasks:
-                ts = known_tasks[k]
+            if v and k in self.tasks:
+                ts = self.tasks[k]
                 if ts.state in ("memory", "erred"):
                     already_in_memory.add(k)
@@ -4428,16 +4355,16 @@ def _match_graph_with_tasks(known_tasks, dsk, dependencies, keys):
             try:
                 deps = dependencies[key]
             except KeyError:
-                deps = known_tasks[key].dependencies
+                deps = self.tasks[key].dependencies
             for dep in deps:
                 if dep in dependents:
                     child_deps = dependents[dep]
-                elif dep in known_tasks:
-                    child_deps = known_tasks[dep].dependencies
+                elif dep in self.tasks:
+                    child_deps = self.tasks[dep].dependencies
                 else:
                     child_deps = set()
                 if all(d in done for d in child_deps):
-                    if dep in known_tasks and dep not in done:
+                    if dep in self.tasks and dep not in done:
                         done.add(dep)
                         stack.append(dep)
         for anc in done:
@@ -4445,61 +4372,37 @@ def _match_graph_with_tasks(known_tasks, dsk, dependencies, keys):
                 dependencies.pop(anc, None)
         return lost_keys
 
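+    # Synchronous core of graph submission: all scheduler state for a new
+    # graph is created below, after `update_graph` has deserialized and
+    # materialized the graph off the event loop.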
-    def update_graph(
+    def _create_taskstate_from_graph(
         self,
-        client: str,
-        graph_header: dict,
-        graph_frames: list[bytes],
+        *,
+        start: float,
+        dsk: dict,
+        dependencies: dict,
         keys: set[str],
-        internal_priority: dict[str, int] | None,
+        ordered: dict[str, int],
+        client: str,
+        annotations_by_type: dict,
+        global_annotations: dict | None,
+        stimulus_id: str,
         submitting_task: str | None,
         user_priority: int | dict[str, int] = 0,
         actors: bool | list[str] | None = None,
         fifo_timeout: float = 0.0,
         code: tuple[SourceCode, ...] = (),
-        annotations: dict | None = None,
-        stimulus_id: str | None = None,
     ) -> None:
-        start = time()
-        annotations = annotations or {}
-        if isinstance(annotations, ToPickle):  # type: ignore
-            # FIXME: what the heck?
-            annotations = annotations.data  # type: ignore
-        try:
-            (
-                dsk,
-                dependencies,
-                annotations_by_type,
-            ) = self._materialize_graph(graph_header, graph_frames, annotations)
-            del graph_header, graph_frames
-        except RuntimeError as e:
-            err = error_message(e)
-            for key in keys:
-                self.report(
-                    {
-                        "op": "task-erred",
-                        "key": key,
-                        "exception": err["exception"],
-                        "traceback": err["traceback"],
-                    }
-                )
-        keys = set(keys)
-        lost_keys = self._match_graph_with_tasks(self.tasks, dsk, dependencies, keys)
-        ordered: dict = {}
-        if not internal_priority:
-            # Removing all non-local keys before calling order()
-            dsk_keys = set(dsk)  # intersection() of sets is much faster than dict_keys
-            stripped_deps = {
-                k: v.intersection(dsk_keys)
-                for k, v in dependencies.items()
-                if k in dsk_keys
-            }
-            ordered = dask.order.order(dsk, dependencies=stripped_deps)
-            assert ordered
+        """
+        Take a low level graph and create the necessary scheduler state to
+        compute it.
 
-        stimulus_id = stimulus_id or f"update-graph-{time()}"
+        WARNING
+        -------
+        This method must not be made async since nothing here is concurrency
+        safe. All interactions with TaskState objects here should be happening
+        in the same event loop tick.
+        """
+
+        lost_keys = self._match_graph_with_tasks(dsk, dependencies, keys)
 
-        # FIXME: How can I log this cleanly?
         if len(dsk) > 1:
             self.log_event(
                 ["all", client], {"action": "update_graph", "count": len(dsk)}
@@ -4520,10 +4423,11 @@ def update_graph(
 
         if code:  # add new code blocks
             computation.code.add(code)
-        if annotations:
+        if global_annotations:
             # FIXME: This is kind of inconsistent since it only includes global
             # annotations.
-            computation.annotations.update(annotations)
+            computation.annotations.update(global_annotations)
+            del global_annotations
 
         runnable, touched_tasks, new_tasks = self._generate_taskstates(
             keys=keys,
@@ -4538,7 +4442,7 @@ def update_graph(
         )
 
         self._set_priorities(
-            internal_priority=internal_priority or ordered,
+            internal_priority=ordered,
             submitting_task=submitting_task,
             user_priority=user_priority,
             fifo_timeout=fifo_timeout,
@@ -4616,6 +4520,90 @@ def update_graph(
         if ts.state in ("memory", "erred"):
             self.report_on_key(ts=ts, client=client)
 
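+    # Async entry point for graph submission from the client. Graph
+    # materialization and dask.order.order are CPU-bound for large graphs, so
+    # both run in the offload thread pool; scheduler state is then created in
+    # a single event loop tick by `_create_taskstate_from_graph` above.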
+    @log_errors
+    async def update_graph(
+        self,
+        client: str,
+        graph_header: dict,
+        graph_frames: list[bytes],
+        keys: set[str],
+        internal_priority: dict[str, int] | None,
+        submitting_task: str | None,
+        user_priority: int | dict[str, int] = 0,
+        actors: bool | list[str] | None = None,
+        fifo_timeout: float = 0.0,
+        code: tuple[SourceCode, ...] = (),
+        annotations: dict | None = None,
+        stimulus_id: str | None = None,
+    ) -> None:
+        # FIXME: Apparently empty dicts arrive as a ToPickle object
+        if isinstance(annotations, ToPickle):
+            annotations = annotations.data  # type: ignore[unreachable]
+        start = time()
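+        # The `await offload(...)` calls below yield the event loop, so take a
+        # lock to keep concurrent update_graph calls from interleaving their
+        # scheduler state updates.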
+        async with self._update_graph_lock:
+            try:
+                try:
+                    graph = deserialize(graph_header, graph_frames).data
+                    del graph_header, graph_frames
+                except Exception as e:
+                    msg = """\
+                        Error during deserialization of the task graph. This frequently
+                        occurs if the Scheduler and Client have different environments.
+                        For more information, see
+                        https://docs.dask.org/en/stable/deployment-considerations.html#consistent-software-environments
+                    """
+                    raise RuntimeError(textwrap.dedent(msg)) from e
+
+            except RuntimeError as e:
+                err = error_message(e)
+                for key in keys:
+                    self.report(
+                        {
+                            "op": "task-erred",
+                            "key": key,
+                            "exception": err["exception"],
+                            "traceback": err["traceback"],
+                        }
+                    )
+            else:
+                (
+                    dsk,
+                    dependencies,
+                    annotations_by_type,
+                ) = await offload(
+                    _materialize_graph,
+                    graph=graph,
+                    global_annotations=annotations or {},
+                )
+                del graph
+                if not internal_priority:
+                    # Removing all non-local keys before calling order()
+                    dsk_keys = set(
+                        dsk
+                    )  # intersection() of sets is much faster than dict_keys
+                    stripped_deps = {
+                        k: v.intersection(dsk_keys)
+                        for k, v in dependencies.items()
+                        if k in dsk_keys
+                    }
+                    internal_priority = await offload(
+                        dask.order.order, dsk=dsk, dependencies=stripped_deps
+                    )
+
+                self._create_taskstate_from_graph(
+                    dsk=dsk,
+                    client=client,
+                    dependencies=dependencies,
+                    keys=set(keys),
+                    ordered=internal_priority or {},
+                    submitting_task=submitting_task,
+                    user_priority=user_priority,
+                    actors=actors,
+                    fifo_timeout=fifo_timeout,
+                    code=code,
+                    annotations_by_type=annotations_by_type,
+                    # FIXME: This is just used to attach to Computation objects.
+                    # This should be removed
+                    global_annotations=annotations,
+                    start=start,
+                    stimulus_id=stimulus_id or f"update-graph-{start}",
+                )
         end = time()
         self.digest_metric("update-graph-duration", end - start)
 
@@ -8490,3 +8478,66 @@ def transition(
         self.metadata[key] = ts.metadata
         self.state[key] = finish
         self.keys.discard(key)
+
+
+def _materialize_graph(
+    graph: HighLevelGraph, global_annotations: dict
+) -> tuple[dict, dict, dict]:
+    dsk = dask.utils.ensure_dict(graph)
+    annotations_by_type: defaultdict[str, dict[str, Any]] = defaultdict(dict)
+    for annotations_type, value in global_annotations.items():
+        annotations_by_type[annotations_type].update(
+            {stringify(k): (value(k) if callable(value) else value) for k in dsk}
+        )
+
+    for layer in graph.layers.values():
+        if layer.annotations:
+            annot = layer.annotations
+            for annot_type, value in annot.items():
+                annotations_by_type[annot_type].update(
+                    {
+                        stringify(k): (value(k) if callable(value) else value)
+                        for k in layer
+                    }
+                )
+    dependencies, _ = get_deps(dsk)
+
+    # Remove `Future` objects from graph and note any future dependencies
+    dsk2 = {}
+    fut_deps = {}
+    for k, v in dsk.items():
+        dsk2[k], futs = unpack_remotedata(v, byte_keys=True)
+        if futs:
+            fut_deps[k] = futs
+    dsk = dsk2
+
+    # - Add in deps for any tasks that depend on futures
+    for k, futures in fut_deps.items():
+        dependencies[k].update(f.key for f in futures)
+    new_dsk = {}
+    # Annotation callables are evaluated on the non-stringified version of
+    # the keys
+    exclusive = set(graph)
+    for k, v in dsk.items():
+        new_k = stringify(k)
+        new_dsk[new_k] = stringify(v, exclusive=exclusive)
+    dsk = new_dsk
+    dependencies = {
+        stringify(k): {stringify(dep) for dep in deps}
+        for k, deps in dependencies.items()
+    }
+
+    # Remove any self-dependencies (happens on test_publish_bag() and others)
+    for k, v in dependencies.items():
+        deps = set(v)
+        if k in deps:
+            deps.remove(k)
+        dependencies[k] = deps
+
+    # Remove aliases
+    for k in list(dsk):
+        if dsk[k] is k:
+            del dsk[k]
+    dsk = valmap(dumps_task, dsk)
+
+    return dsk, dependencies, annotations_by_type
diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py
index 5a9c5e8014d..3409be2963c 100644
--- a/distributed/tests/test_client.py
+++ b/distributed/tests/test_client.py
@@ -3719,7 +3719,6 @@ async def test_scatter_raises_if_no_workers(c, s):
         await c.scatter(1, timeout=0.5)
 
 
-@pytest.mark.slow
 @gen_test()
 async def test_reconnect():
     port = open_port()
diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py
index 275148f6059..9a78fd2a5fa 100644
--- a/distributed/tests/test_scheduler.py
+++ b/distributed/tests/test_scheduler.py
@@ -1358,7 +1358,7 @@ async def test_update_graph_culls(s, a, b):
     )
     header, frames = serialize(ToPickle(dsk), on_error="raise")
-    s.update_graph(
+    await s.update_graph(
         graph_header=header,
         graph_frames=frames,
         keys=["y"],