diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 9e7fc3689f..9e1b16fd76 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -92,6 +92,7 @@ from distributed.multi_lock import MultiLockExtension from distributed.node import ServerNode from distributed.proctitle import setproctitle +from distributed.protocol import deserialize from distributed.protocol.pickle import dumps, loads from distributed.protocol.serialize import Serialized, ToPickle, serialize from distributed.publish import PublishExtension @@ -112,6 +113,7 @@ key_split_group, log_errors, no_default, + offload, recursive_to_dict, validate_key, wait_for, @@ -124,6 +126,7 @@ ) from distributed.utils_perf import disable_gc_diagnosis, enable_gc_diagnosis from distributed.variable import VariableExtension +from distributed.worker import dumps_task if TYPE_CHECKING: # TODO import from typing (requires Python >=3.10) @@ -4318,88 +4321,94 @@ async def add_nanny(self) -> dict[str, Any]: } return msg - def update_graph( + def _match_graph_with_tasks( + self, dsk: dict[str, Any], dependencies: dict[str, set[str]], keys: set[str] + ) -> set[str]: + n = 0 + lost_keys = set() + while len(dsk) != n: # walk through new tasks, cancel any bad deps + n = len(dsk) + for k, deps in list(dependencies.items()): + if any( + dep not in self.tasks and dep not in dsk for dep in deps + ): # bad key + lost_keys.add(k) + logger.info("User asked for computation on lost data, %s", k) + del dsk[k] + del dependencies[k] + if k in keys: + keys.remove(k) + del deps + # Avoid computation that is already finished + done = set() # tasks that are already done + for k, v in dependencies.items(): + if v and k in self.tasks: + ts = self.tasks[k] + if ts.state in ("memory", "erred"): + done.add(k) + + if done: + dependents = dask.core.reverse_dict(dependencies) + stack = list(done) + while stack: # remove unnecessary dependencies + key = stack.pop() + try: + deps = dependencies[key] + except KeyError: + deps = {ts.key for ts in self.tasks[key].dependencies} + for dep in deps: + if dep in dependents: + child_deps = dependents[dep] + elif dep in self.tasks: + child_deps = {ts.key for ts in self.tasks[key].dependencies} + else: + child_deps = set() + if all(d in done for d in child_deps): + if dep in self.tasks and dep not in done: + done.add(dep) + stack.append(dep) + for anc in done: + dsk.pop(anc, None) + dependencies.pop(anc, None) + return lost_keys + + def _create_taskstate_from_graph( self, + *, + start: float, + dsk: dict, + dependencies: dict, + keys: set[str], + ordered: dict[str, int], client: str, - graph_header: dict, - graph_frames: list[bytes], - keys: list[str], - internal_priority: dict[str, int] | None, + annotations_by_type: dict, + global_annotations: dict | None, + stimulus_id: str, submitting_task: str | None, user_priority: int | dict[str, int] = 0, actors: bool | list[str] | None = None, fifo_timeout: float = 0.0, code: tuple[SourceCode, ...] = (), - annotations: dict | None = None, - stimulus_id: str | None = None, ) -> None: - start = time() - try: - # TODO: deserialization + materialization should be offloaded to a - # thread since this is non-trivial compute time that blocks the - # event loop. 
This likely requires us to use a lock since we need to - # guarantee ordering of update_graph calls (as long as there is just - # a single offload thread, this is not a problem) - from distributed.protocol import deserialize - - graph = deserialize(graph_header, graph_frames).data - except Exception as e: - msg = """\ - Error during deserialization of the task graph. This frequently occurs if the Scheduler and Client have different environments. For more information, see https://docs.dask.org/en/stable/deployment-considerations.html#consistent-software-environments - """ - try: - raise RuntimeError(textwrap.dedent(msg)) from e - except RuntimeError as e: - err = error_message(e) - for key in keys: - self.report( - { - "op": "task-erred", - "key": key, - "exception": err["exception"], - "traceback": err["traceback"], - } - ) + """ + Take a low level graph and create the necessary scheduler state to + compute it. - return - annotations = annotations or {} - if isinstance(annotations, ToPickle): # type: ignore - # FIXME: what the heck? - annotations = annotations.data # type: ignore - - stimulus_id = stimulus_id or f"update-graph-{time()}" - ( - dsk, - dependencies, - annotations_by_type, - ) = self.materialize_graph(graph, annotations) - - if internal_priority is None: - # Removing all non-local keys before calling order() - dsk_keys = set(dsk) # intersection() of sets is much faster than dict_keys - stripped_deps = { - k: v.intersection(dsk_keys) - for k, v in dependencies.items() - if k in dsk_keys - } - internal_priority = dask.order.order(dsk, dependencies=stripped_deps) + WARNING + ------- + This method must not be made async since nothing here is concurrency + safe. All interactions with TaskState objects here should be happening + in the same event loop tick. + """ + + lost_keys = self._match_graph_with_tasks(dsk, dependencies, keys) - requested_keys = set(keys) - del keys if len(dsk) > 1: self.log_event( ["all", client], {"action": "update_graph", "count": len(dsk)} ) - self._pop_known_tasks( - known_tasks=self.tasks, dsk=dsk, dependencies=dependencies - ) - if lost_keys := self._pop_lost_tasks( - dsk=dsk, - known_tasks=self.tasks, - dependencies=dependencies, - keys=requested_keys, - ): + if lost_keys: self.report({"op": "cancelled-keys", "keys": lost_keys}, client=client) self.client_releases_keys( keys=lost_keys, client=client, stimulus_id=stimulus_id @@ -4414,13 +4423,14 @@ def update_graph( if code: # add new code blocks computation.code.add(code) - if annotations: + if global_annotations: # FIXME: This is kind of inconsistent since it only includes global # annotations. 
- computation.annotations.update(annotations) + computation.annotations.update(global_annotations) + del global_annotations runnable, touched_tasks, new_tasks = self._generate_taskstates( - keys=requested_keys, + keys=keys, dsk=dsk, dependencies=dependencies, computation=computation, @@ -4432,7 +4442,7 @@ def update_graph( ) self._set_priorities( - internal_priority=internal_priority, + internal_priority=ordered, submitting_task=submitting_task, user_priority=user_priority, fifo_timeout=fifo_timeout, @@ -4440,11 +4450,11 @@ def update_graph( tasks=runnable, ) - self.client_desires_keys(keys=requested_keys, client=client) + self.client_desires_keys(keys=keys, client=client) # Add actors if actors is True: - actors = list(requested_keys) + actors = list(keys) for actor in actors or []: ts = self.tasks[actor] ts.actor = True @@ -4496,7 +4506,7 @@ def update_graph( self, client=client, tasks=[ts.key for ts in touched_tasks], - keys=requested_keys, + keys=keys, dependencies=dependencies, annotations=dict(annotations_for_plugin), priority=priority, @@ -4510,6 +4520,88 @@ def update_graph( if ts.state in ("memory", "erred"): self.report_on_key(ts=ts, client=client) + @log_errors + async def update_graph( + self, + client: str, + graph_header: dict, + graph_frames: list[bytes], + keys: set[str], + internal_priority: dict[str, int] | None, + submitting_task: str | None, + user_priority: int | dict[str, int] = 0, + actors: bool | list[str] | None = None, + fifo_timeout: float = 0.0, + code: tuple[SourceCode, ...] = (), + annotations: dict | None = None, + stimulus_id: str | None = None, + ) -> None: + # FIXME: Apparently empty dicts arrive as a ToPickle object + if isinstance(annotations, ToPickle): + annotations = annotations.data # type: ignore[unreachable] + start = time() + try: + try: + graph = deserialize(graph_header, graph_frames).data + del graph_header, graph_frames + except Exception as e: + msg = """\ + Error during deserialization of the task graph. This frequently occurs if the Scheduler and Client have different environments. For more information, see https://docs.dask.org/en/stable/deployment-considerations.html#consistent-software-environments + """ + raise RuntimeError(textwrap.dedent(msg)) from e + ( + dsk, + dependencies, + annotations_by_type, + ) = await offload( + _materialize_graph, + graph=graph, + global_annotations=annotations or {}, + ) + del graph + if not internal_priority: + # Removing all non-local keys before calling order() + dsk_keys = set( + dsk + ) # intersection() of sets is much faster than dict_keys + stripped_deps = { + k: v.intersection(dsk_keys) + for k, v in dependencies.items() + if k in dsk_keys + } + internal_priority = await offload( + dask.order.order, dsk=dsk, dependencies=stripped_deps + ) + + self._create_taskstate_from_graph( + dsk=dsk, + client=client, + dependencies=dependencies, + keys=set(keys), + ordered=internal_priority or {}, + submitting_task=submitting_task, + user_priority=user_priority, + actors=actors, + fifo_timeout=fifo_timeout, + code=code, + annotations_by_type=annotations_by_type, + # FIXME: This is just used to attach to Computation + # objects. 
This should be removed + global_annotations=annotations, + start=start, + stimulus_id=stimulus_id or f"update-graph-{start}", + ) + except RuntimeError as e: + err = error_message(e) + for key in keys: + self.report( + { + "op": "task-erred", + "key": key, + "exception": err["exception"], + "traceback": err["traceback"], + } + ) end = time() self.digest_metric("update-graph-duration", end - start) @@ -4670,127 +4762,6 @@ def _set_priorities( isinstance(el, (int, float)) for el in ts.priority ) - @staticmethod - def _pop_lost_tasks( - dsk: dict, keys: set[str], known_tasks: dict[str, TaskState], dependencies: dict - ) -> set[str]: - n = 0 - out = set() - while len(dsk) != n: # walk through new tasks, cancel any bad deps - n = len(dsk) - for k, deps in list(dependencies.items()): - if any( - dep not in known_tasks and dep not in dsk for dep in deps - ): # bad key - out.add(k) - logger.info("User asked for computation on lost data, %s", k) - del dsk[k] - del dependencies[k] - if k in keys: - keys.remove(k) - return out - - @staticmethod - def _pop_known_tasks( - known_tasks: dict[str, TaskState], dsk: dict, dependencies: dict - ) -> set[str]: - # Avoid computation that is already finished - already_in_memory = set() # tasks that are already done - for k, v in dependencies.items(): - if v and k in known_tasks: - ts = known_tasks[k] - if ts.state in ("memory", "erred"): - already_in_memory.add(k) - - done = set(already_in_memory) - if already_in_memory: - dependents = dask.core.reverse_dict(dependencies) - stack = list(already_in_memory) - while stack: # remove unnecessary dependencies - key = stack.pop() - try: - deps = dependencies[key] - except KeyError: - deps = known_tasks[key].dependencies - for dep in deps: - if dep in dependents: - child_deps = dependents[dep] - elif dep in known_tasks: - child_deps = known_tasks[dep].dependencies - else: - child_deps = set() - if all(d in done for d in child_deps): - if dep in known_tasks and dep not in done: - done.add(dep) - stack.append(dep) - for anc in done: - dsk.pop(anc, None) - dependencies.pop(anc, None) - return done - - @staticmethod - def materialize_graph( - hlg: HighLevelGraph, global_annotations: dict - ) -> tuple[dict, dict, dict]: - from distributed.worker import dumps_task - - dsk = dask.utils.ensure_dict(hlg) - - annotations_by_type: defaultdict[str, dict[str, Any]] = defaultdict(dict) - for type_, value in global_annotations.items(): - annotations_by_type[type_].update( - {stringify(k): (value(k) if callable(value) else value) for k in dsk} - ) - for layer in hlg.layers.values(): - if layer.annotations: - annot = layer.annotations - for annot_type, value in annot.items(): - annotations_by_type[annot_type].update( - { - stringify(k): (value(k) if callable(value) else value) - for k in layer - } - ) - - dependencies, _ = get_deps(dsk) - - # Remove `Future` objects from graph and note any future dependencies - dsk2 = {} - fut_deps = {} - for k, v in dsk.items(): - dsk2[k], futs = unpack_remotedata(v, byte_keys=True) - if futs: - fut_deps[k] = futs - dsk = dsk2 - - # - Add in deps for any tasks that depend on futures - for k, futures in fut_deps.items(): - dependencies[k].update(f.key for f in futures) - new_dsk = {} - exclusive = set(hlg) - for k, v in dsk.items(): - new_k = stringify(k) - new_dsk[new_k] = stringify(v, exclusive=exclusive) - dsk = new_dsk - dependencies = { - stringify(k): {stringify(dep) for dep in deps} - for k, deps in dependencies.items() - } - - # Remove any self-dependencies (happens on test_publish_bag() and 
others) - for k, v in dependencies.items(): - deps = set(v) - if k in deps: - deps.remove(k) - dependencies[k] = deps - - # Remove aliases - for k in list(dsk): - if dsk[k] is k: - del dsk[k] - dsk = valmap(dumps_task, dsk) - return dsk, dependencies, dict(annotations_by_type) - def stimulus_queue_slots_maybe_opened(self, *, stimulus_id: str) -> None: """Respond to an event which may have opened spots on worker threadpools @@ -8505,3 +8476,66 @@ def transition( self.metadata[key] = ts.metadata self.state[key] = finish self.keys.discard(key) + + +def _materialize_graph( + graph: HighLevelGraph, global_annotations: dict +) -> tuple[dict, dict, dict]: + dsk = dask.utils.ensure_dict(graph) + annotations_by_type: defaultdict[str, dict[str, Any]] = defaultdict(dict) + for annotations_type, value in global_annotations.items(): + annotations_by_type[annotations_type].update( + {stringify(k): (value(k) if callable(value) else value) for k in dsk} + ) + + for layer in graph.layers.values(): + if layer.annotations: + annot = layer.annotations + for annot_type, value in annot.items(): + annotations_by_type[annot_type].update( + { + stringify(k): (value(k) if callable(value) else value) + for k in layer + } + ) + dependencies, _ = get_deps(dsk) + + # Remove `Future` objects from graph and note any future dependencies + dsk2 = {} + fut_deps = {} + for k, v in dsk.items(): + dsk2[k], futs = unpack_remotedata(v, byte_keys=True) + if futs: + fut_deps[k] = futs + dsk = dsk2 + + # - Add in deps for any tasks that depend on futures + for k, futures in fut_deps.items(): + dependencies[k].update(f.key for f in futures) + new_dsk = {} + # Annotation callables are evaluated on the non-stringified version of + # the keys + exclusive = set(graph) + for k, v in dsk.items(): + new_k = stringify(k) + new_dsk[new_k] = stringify(v, exclusive=exclusive) + dsk = new_dsk + dependencies = { + stringify(k): {stringify(dep) for dep in deps} + for k, deps in dependencies.items() + } + + # Remove any self-dependencies (happens on test_publish_bag() and others) + for k, v in dependencies.items(): + deps = set(v) + if k in deps: + deps.remove(k) + dependencies[k] = deps + + # Remove aliases + for k in list(dsk): + if dsk[k] is k: + del dsk[k] + dsk = valmap(dumps_task, dsk) + + return dsk, dependencies, annotations_by_type diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 275148f605..8c5bb3e75c 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -303,6 +303,15 @@ async def test_decide_worker_rootish_while_last_worker_is_retiring(c, s, a): while a.state.executing_count != 1 or b.state.executing_count != 1: await asyncio.sleep(0.01) + # Rootish is a dynamic property as it is defined right now. Since the + # above submit calls are individual update_graph calls, waiting for + # tasks to be in executing state on the worker is not sufficient to + # guarantee that all the y tasks are already on the scheduler. 
Only + # after at least 5 have been registered, will the task be flagged as + # rootish + while "y-2" not in s.tasks or not s.is_rootish(s.tasks["y-2"]): + await asyncio.sleep(0.01) + # - y-2 has no restrictions # - TaskGroup(y) has more than 4 tasks (total_nthreads * 2) # - TaskGroup(y) has less than 5 dependency groups @@ -1358,7 +1367,7 @@ async def test_update_graph_culls(s, a, b): ) header, frames = serialize(ToPickle(dsk), on_error="raise") - s.update_graph( + await s.update_graph( graph_header=header, graph_frames=frames, keys=["y"], diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index a0dadfcc32..10569265ed 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -641,10 +641,11 @@ async def test_steal_when_more_tasks(c, s, a, *rest): "slowidentity": 0.2, "slow2": 1, }, - "distributed.scheduler.work-stealing-interval": "20ms", }, ) async def test_steal_more_attractive_tasks(c, s, a, *rest): + ext = s.extensions["stealing"] + def slow2(x): sleep(1) return x @@ -652,9 +653,27 @@ def slow2(x): x = c.submit(mul, b"0", 100000000, workers=a.address) # 100 MB await wait(x) + # The submits below are all individual update_graph calls which are very + # likely submitted in the same batch. + # Prior to https://github.com/dask/distributed/pull/8049, the entire batch + # would be processed by the scheduler in the same event loop tick. + # Therefore, the first PC `stealing.balance` call would be guaranteed to see + # all the tasks and make the correct decision. + # After the PR, the batch is processed in multiple event loop ticks, so the + # first PC `stealing.balance` call would potentially only see the first + # tasks and would try to rebalance them instead of the slow and heavy one. + # To guarantee that the stealing extension sees all tasks, we're stopping + # the callback and are calling balance ourselves once we are certain the + # tasks are all on the scheduler. + # Related https://github.com/dask/distributed/pull/5443 + await ext.stop() futures = [c.submit(slowidentity, x, pure=False, delay=0.2) for i in range(10)] future = c.submit(slow2, x, priority=-1) + while future.key not in s.tasks: + await asyncio.sleep(0.01) + # Now call it once explicitly to move the heavy task + ext.balance() while not any(w.state.tasks for w in rest): await asyncio.sleep(0.01)
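
A minimal sketch (not part of the patch) of the pattern this diff applies to Scheduler.update_graph: the CPU-heavy steps (graph deserialization, _materialize_graph, dask.order.order) are pushed off the event loop, while TaskState creation stays synchronous so it completes within a single event loop tick. In the sketch, asyncio.to_thread stands in for distributed.utils.offload, and the helper names materialize_blocking and apply_scheduler_state are illustrative only.

import asyncio


def materialize_blocking(graph_frames: list[bytes]) -> dict:
    # Stand-in for the expensive work: deserialize(...), _materialize_graph()
    # and dask.order.order(). None of this touches scheduler state.
    return {"x": 1, "y": 2}


def apply_scheduler_state(dsk: dict) -> None:
    # Stand-in for _create_taskstate_from_graph(): mutates shared state, so it
    # must run without awaiting in between (i.e. within one event loop tick).
    print(f"registered {len(dsk)} tasks")


async def update_graph(graph_frames: list[bytes]) -> None:
    # Offload the blocking compute to a thread so the event loop keeps
    # serving heartbeats and other comms while the graph is materialized.
    dsk = await asyncio.to_thread(materialize_blocking, graph_frames)
    # Back on the event loop: build task state synchronously, so no other
    # coroutine can observe half-initialized tasks.
    apply_scheduler_state(dsk)


asyncio.run(update_graph([b"frame"]))

As the test changes above reflect, a batch of update_graph calls is no longer guaranteed to be processed in a single event loop tick, so the tests now wait for tasks to appear in s.tasks (or stop the stealing periodic callback and call balance() explicitly) instead of relying on batch timing.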