From d757215dba93c6bbe41f71e54d3559fff655a277 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Sat, 11 Jun 2022 22:00:07 +0100 Subject: [PATCH 1/3] Trivial cut-paste changes to worker and worker_state_machine --- distributed/worker.py | 2113 +-------------------------- distributed/worker_state_machine.py | 2023 ++++++++++++++++++++++++- 2 files changed, 2083 insertions(+), 2053 deletions(-) diff --git a/distributed/worker.py b/distributed/worker.py index 6a09d56cae..5686da06d4 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -5,7 +5,6 @@ import builtins import errno import functools -import heapq import logging import operator import os @@ -21,7 +20,6 @@ Collection, Container, Iterable, - Iterator, Mapping, MutableMapping, ) @@ -29,10 +27,9 @@ from contextlib import suppress from datetime import timedelta from inspect import isawaitable -from pickle import PicklingError from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast -from tlz import first, keymap, merge, peekn, pluck # noqa: F401 +from tlz import first, keymap, pluck from tornado.ioloop import IOLoop, PeriodicCallback import dask @@ -50,7 +47,6 @@ ) from distributed import comm, preloading, profile, utils -from distributed._stories import worker_story from distributed.batched import BatchedSend from distributed.collections import LRU, HeapSet from distributed.comm import connect, get_address_host @@ -74,7 +70,7 @@ from distributed.metrics import time from distributed.node import ServerNode from distributed.proctitle import setproctitle -from distributed.protocol import Serialize, pickle, to_serialize +from distributed.protocol import pickle, to_serialize from distributed.pubsub import PubSubWorkerExtension from distributed.security import Security from distributed.shuffle import ShuffleWorkerExtension @@ -109,57 +105,31 @@ ) from distributed.worker_state_machine import ( NO_VALUE, - PROCESSING, - READY, AcquireReplicasEvent, - AddKeysMsg, AlreadyCancelledEvent, CancelComputeEvent, ComputeTaskEvent, - EnsureCommunicatingAfterTransitions, - Execute, ExecuteFailureEvent, ExecuteSuccessEvent, FindMissingEvent, FreeKeysEvent, - GatherDep, GatherDepBusyEvent, - GatherDepDoneEvent, GatherDepFailureEvent, GatherDepNetworkFailureEvent, GatherDepSuccessEvent, - Instructions, - InvalidTaskState, - InvalidTransition, - LongRunningMsg, - RecommendationsConflict, - Recs, - RecsInstrs, RefreshWhoHasEvent, - ReleaseWorkerDataMsg, RemoveReplicasEvent, - RequestRefreshWhoHasMsg, RescheduleEvent, - RescheduleMsg, RetryBusyWorkerEvent, - RetryBusyWorkerLater, SecedeEvent, - SendMessageToScheduler, StateMachineEvent, StealRequestEvent, - StealResponseMsg, - TaskErredMsg, - TaskFinishedMsg, TaskState, - TaskStateState, - TransitionCounterMaxExceeded, UnpauseEvent, UpdateDataEvent, - merge_recs_instructions, ) if TYPE_CHECKING: - from distributed.actor import Actor from distributed.client import Client from distributed.diagnostics.plugin import WorkerPlugin from distributed.nanny import Nanny @@ -430,20 +400,9 @@ class Worker(ServerNode): _instances: ClassVar[weakref.WeakSet[Worker]] = weakref.WeakSet() _initialized_clients: ClassVar[weakref.WeakSet[Client]] = weakref.WeakSet() - tasks: dict[str, TaskState] - waiting_for_data_count: int - has_what: defaultdict[str, set[str]] # {worker address: {ts.key, ...} - data_needed: HeapSet[TaskState] - data_needed_per_worker: defaultdict[str, HeapSet[TaskState]] nanny: Nanny | None _lock: threading.Lock - in_flight_workers: dict[str, set[str]] # {worker address: {ts.key, ...}} - busy_workers: set[str] - total_out_connections: int total_in_connections: int - comm_threshold_bytes: int - comm_nbytes: int - _missing_dep_flight: set[TaskState] threads: dict[str, int] # {ts.key: thread ID} active_threads_lock: threading.Lock active_threads: dict[int, str] # {thread ID: ts.key} @@ -452,21 +411,9 @@ class Worker(ServerNode): profile_keys_history: deque[tuple[float, dict[str, dict[str, Any]]]] profile_recent: dict[str, Any] profile_history: deque[tuple[float, dict[str, Any]]] - generation: int - ready: list[tuple[tuple[int, ...], str]] # heapq [(priority, key), ...] - constrained: deque[str] - _executing: set[TaskState] - _in_flight_tasks: set[TaskState] - executed_count: int - long_running: set[str] - log: deque[tuple] # [(..., stimulus_id: str | None, timestamp: float), ...] - stimulus_log: deque[StateMachineEvent] incoming_transfer_log: deque[dict[str, Any]] outgoing_transfer_log: deque[dict[str, Any]] - target_message_size: int validate: bool - transition_counter: int - transition_counter_max: int | Literal[False] incoming_count: int outgoing_count: int outgoing_current_count: int @@ -488,9 +435,7 @@ class Worker(ServerNode): _dashboard_address: str | None _dashboard: bool _http_prefix: str - nthreads: int total_resources: dict[str, float] - available_resources: dict[str, float] death_timeout: float | None lifetime: float | None lifetime_stagger: float | None @@ -498,7 +443,6 @@ class Worker(ServerNode): extensions: dict security: Security connection_args: dict[str, Any] - actors: dict[str, Actor | None] loop: IOLoop executors: dict[str, Executor] batched_stream: BatchedSend @@ -516,7 +460,6 @@ class Worker(ServerNode): execution_state: dict[str, Any] plugins: dict[str, WorkerPlugin] _pending_plugins: tuple[WorkerPlugin, ...] - _async_instructions: set[asyncio.Task] def __init__( self, @@ -1792,10 +1735,6 @@ async def get_data( # Local Execution # ################### - @functools.singledispatchmethod - def _handle_event(self, ev: StateMachineEvent) -> RecsInstrs: - raise TypeError(ev) # pragma: nocover - def update_data( self, data: dict[str, object], @@ -1811,99 +1750,6 @@ def update_data( ) return {"nbytes": {k: sizeof(v) for k, v in data.items()}, "status": "OK"} - @_handle_event.register - def _handle_update_data(self, ev: UpdateDataEvent) -> RecsInstrs: - recommendations: Recs = {} - instructions: Instructions = [] - for key, value in ev.data.items(): - try: - ts = self.tasks[key] - recommendations[ts] = ("memory", value) - except KeyError: - self.tasks[key] = ts = TaskState(key) - - try: - recs = self._put_key_in_memory( - ts, value, stimulus_id=ev.stimulus_id - ) - except Exception as e: - msg = error_message(e) - recommendations = {ts: tuple(msg.values())} - else: - recommendations.update(recs) - - self.log.append((key, "receive-from-scatter", ev.stimulus_id, time())) - - if ev.report: - instructions.append( - AddKeysMsg(keys=list(ev.data), stimulus_id=ev.stimulus_id) - ) - - return recommendations, instructions - - @_handle_event.register - def _handle_free_keys(self, ev: FreeKeysEvent) -> RecsInstrs: - """Handler to be called by the scheduler. - - The given keys are no longer referred to and required by the scheduler. - The worker is now allowed to release the key, if applicable. - - This does not guarantee that the memory is released since the worker may - still decide to hold on to the data and task since it is required by an - upstream dependency. - """ - self.log.append(("free-keys", ev.keys, ev.stimulus_id, time())) - recommendations: Recs = {} - for key in ev.keys: - ts = self.tasks.get(key) - if ts: - recommendations[ts] = "released" - return recommendations, [] - - @_handle_event.register - def _handle_remove_replicas(self, ev: RemoveReplicasEvent) -> RecsInstrs: - """Stream handler notifying the worker that it might be holding unreferenced, - superfluous data. - - This should not actually happen during ordinary operations and is only intended - to correct any erroneous state. An example where this is necessary is if a - worker fetches data for a downstream task but that task is released before the - data arrives. In this case, the scheduler will notify the worker that it may be - holding this unnecessary data, if the worker hasn't released the data itself, - already. - - This handler does not guarantee the task nor the data to be actually - released but only asks the worker to release the data on a best effort - guarantee. This protects from race conditions where the given keys may - already have been rescheduled for compute in which case the compute - would win and this handler is ignored. - - For stronger guarantees, see handler free_keys - """ - recommendations: Recs = {} - instructions: Instructions = [] - - rejected = [] - for key in ev.keys: - ts = self.tasks.get(key) - if ts is None or ts.state != "memory": - continue - if not ts.is_protected(): - self.log.append( - (ts.key, "remove-replica-confirmed", ev.stimulus_id, time()) - ) - recommendations[ts] = "released" - else: - rejected.append(key) - - if rejected: - self.log.append( - ("remove-replica-rejected", rejected, ev.stimulus_id, time()) - ) - instructions.append(AddKeysMsg(keys=rejected, stimulus_id=ev.stimulus_id)) - - return recommendations, instructions - async def set_resources(self, **resources) -> None: for r, quantity in resources.items(): if r in self.total_resources: @@ -1918,950 +1764,87 @@ async def set_resources(self, **resources) -> None: worker=self.contact_address, ) - ################### - # Task Management # - ################### - - @fail_hard - def _handle_remote_stimulus( - self, cls: type[StateMachineEvent] - ) -> Callable[..., None]: - def _(**kwargs): - event = cls(**kwargs) - self.handle_stimulus(event) - - _.__name__ = f"_handle_remote_stimulus({cls.__name__})" - return _ - - @_handle_event.register - def _handle_acquire_replicas(self, ev: AcquireReplicasEvent) -> RecsInstrs: - if self.validate: - assert all(ev.who_has.values()) - - recommendations: Recs = {} - for key in ev.who_has: - ts = self._ensure_task_exists( - key=key, - # Transfer this data after all dependency tasks of computations with - # default or explicitly high (>0) user priority and before all - # computations with low priority (<0). Note that the priority= parameter - # of compute() is multiplied by -1 before it reaches TaskState.priority. - priority=(1,), - stimulus_id=ev.stimulus_id, - ) - if ts.state != "memory": - recommendations[ts] = "fetch" - - self._update_who_has(ev.who_has) - return recommendations, [] - - def _ensure_task_exists( - self, key: str, *, priority: tuple[int, ...], stimulus_id: str - ) -> TaskState: - try: - ts = self.tasks[key] - logger.debug("Data task %s already known (stimulus_id=%s)", ts, stimulus_id) - except KeyError: - self.tasks[key] = ts = TaskState(key) - if not ts.priority: - assert priority - ts.priority = priority - - self.log.append((key, "ensure-task-exists", ts.state, stimulus_id, time())) - return ts - - @_handle_event.register - def _handle_compute_task(self, ev: ComputeTaskEvent) -> RecsInstrs: - try: - ts = self.tasks[ev.key] - logger.debug( - "Asked to compute an already known task %s", - {"task": ts, "stimulus_id": ev.stimulus_id}, - ) - except KeyError: - self.tasks[ev.key] = ts = TaskState(ev.key) - self.log.append((ev.key, "compute-task", ts.state, ev.stimulus_id, time())) - - recommendations: Recs = {} - instructions: Instructions = [] - - if ts.state in READY | { - "executing", - "long-running", - "waiting", - }: - pass - elif ts.state == "memory": - instructions.append( - self._get_task_finished_msg(ts, stimulus_id=ev.stimulus_id) - ) - elif ts.state == "error": - instructions.append(TaskErredMsg.from_task(ts, stimulus_id=ev.stimulus_id)) - elif ts.state in { - "released", - "fetch", - "flight", - "missing", - "cancelled", - "resumed", - }: - recommendations[ts] = "waiting" - - ts.run_spec = ev.run_spec - - priority = ev.priority + (self.generation,) - self.generation -= 1 - - if ev.actor: - self.actors[ts.key] = None - - ts.exception = None - ts.traceback = None - ts.exception_text = "" - ts.traceback_text = "" - ts.priority = priority - ts.duration = ev.duration - ts.resource_restrictions = ev.resource_restrictions - ts.annotations = ev.annotations - - if self.validate: - assert ev.who_has.keys() == ev.nbytes.keys() - assert all(ev.who_has.values()) - - for dep_key, dep_workers in ev.who_has.items(): - dep_ts = self._ensure_task_exists( - key=dep_key, - priority=priority, - stimulus_id=ev.stimulus_id, - ) - # link up to child / parents - ts.dependencies.add(dep_ts) - dep_ts.dependents.add(ts) - - for dep_key, value in ev.nbytes.items(): - self.tasks[dep_key].nbytes = value - - self._update_who_has(ev.who_has) - else: - raise RuntimeError( # pragma: nocover - f"Unexpected task state encountered for {ts}; " - f"stimulus_id={ev.stimulus_id}; story={self.story(ts)}" - ) - - return recommendations, instructions - - ######################## - # Worker State Machine # - ######################## - - def _transition_generic_fetch(self, ts: TaskState, stimulus_id: str) -> RecsInstrs: - if not ts.who_has: - return {ts: "missing"}, [] - - ts.state = "fetch" - ts.done = False - assert ts.priority - self.data_needed.add(ts) - for w in ts.who_has: - self.data_needed_per_worker[w].add(ts) - - # This is the same as `return self._ensure_communicating()`, except that when - # many tasks transition to fetch at the same time, e.g. from a single - # compute-task or acquire-replicas command from the scheduler, it allows - # clustering the transfers into less GatherDep instructions; see - # _select_keys_for_gather(). - return {}, [EnsureCommunicatingAfterTransitions(stimulus_id=stimulus_id)] - - def _transition_missing_waiting( - self, ts: TaskState, *, stimulus_id: str - ) -> RecsInstrs: - self._missing_dep_flight.discard(ts) - self._purge_state(ts) - return self._transition_released_waiting(ts, stimulus_id=stimulus_id) - - def _transition_missing_fetch( - self, ts: TaskState, *, stimulus_id: str - ) -> RecsInstrs: - if self.validate: - assert ts.state == "missing" - - if not ts.who_has: - return {}, [] - - self._missing_dep_flight.discard(ts) - return self._transition_generic_fetch(ts, stimulus_id=stimulus_id) - - def _transition_missing_released( - self, ts: TaskState, *, stimulus_id: str - ) -> RecsInstrs: - self._missing_dep_flight.discard(ts) - recs, instructions = self._transition_generic_released( - ts, stimulus_id=stimulus_id - ) - assert ts.key in self.tasks - return recs, instructions - - def _transition_flight_missing( - self, ts: TaskState, *, stimulus_id: str - ) -> RecsInstrs: - assert ts.done - return self._transition_generic_missing(ts, stimulus_id=stimulus_id) - - def _transition_generic_missing( - self, ts: TaskState, *, stimulus_id: str - ) -> RecsInstrs: - if self.validate: - assert not ts.who_has - - ts.state = "missing" - self._missing_dep_flight.add(ts) - ts.done = False - return {}, [] - - def _transition_released_fetch( - self, ts: TaskState, *, stimulus_id: str - ) -> RecsInstrs: - if self.validate: - assert ts.state == "released" - return self._transition_generic_fetch(ts, stimulus_id=stimulus_id) - - def _transition_generic_released( - self, ts: TaskState, *, stimulus_id: str - ) -> RecsInstrs: - self._purge_state(ts) - recs: Recs = {} - for dependency in ts.dependencies: - if ( - not dependency.waiters - and dependency.state not in READY | PROCESSING | {"memory"} - ): - recs[dependency] = "released" - - ts.state = "released" - if not ts.dependents: - recs[ts] = "forgotten" - - return merge_recs_instructions( - (recs, []), - self._ensure_computing(), - ) - - def _transition_released_waiting( - self, ts: TaskState, *, stimulus_id: str - ) -> RecsInstrs: - if self.validate: - assert all(d.key in self.tasks for d in ts.dependencies) - - recommendations: Recs = {} - ts.waiting_for_data.clear() - for dep_ts in ts.dependencies: - if dep_ts.state != "memory": - ts.waiting_for_data.add(dep_ts) - dep_ts.waiters.add(ts) - recommendations[dep_ts] = "fetch" - - if ts.waiting_for_data: - self.waiting_for_data_count += 1 - elif ts.resource_restrictions: - recommendations[ts] = "constrained" - else: - recommendations[ts] = "ready" - - ts.state = "waiting" - return recommendations, [] - - def _transition_fetch_flight( - self, ts: TaskState, worker: str, *, stimulus_id: str - ) -> RecsInstrs: - if self.validate: - assert ts.state == "fetch" - assert ts.who_has - # The task has already been removed by _ensure_communicating - assert ts not in self.data_needed - for w in ts.who_has: - assert ts not in self.data_needed_per_worker[w] - - ts.done = False - ts.state = "flight" - ts.coming_from = worker - self._in_flight_tasks.add(ts) - return {}, [] - - def _transition_fetch_missing( - self, ts: TaskState, *, stimulus_id: str - ) -> RecsInstrs: - # _ensure_communicating could have just popped this task out of data_needed - self.data_needed.discard(ts) - return self._transition_generic_missing(ts, stimulus_id=stimulus_id) - - def _transition_memory_released( - self, ts: TaskState, *, stimulus_id: str - ) -> RecsInstrs: - recs, instructions = self._transition_generic_released( - ts, stimulus_id=stimulus_id - ) - instructions.append(ReleaseWorkerDataMsg(key=ts.key, stimulus_id=stimulus_id)) - return recs, instructions - - def _transition_waiting_constrained( - self, ts: TaskState, *, stimulus_id: str - ) -> RecsInstrs: - if self.validate: - assert ts.state == "waiting" - assert not ts.waiting_for_data - assert all( - dep.key in self.data or dep.key in self.actors - for dep in ts.dependencies - ) - assert all(dep.state == "memory" for dep in ts.dependencies) - assert ts.key not in self.ready - ts.state = "constrained" - self.constrained.append(ts.key) - return self._ensure_computing() - - def _transition_long_running_rescheduled( - self, ts: TaskState, *, stimulus_id: str - ) -> RecsInstrs: - recs: Recs = {ts: "released"} - smsg = RescheduleMsg(key=ts.key, stimulus_id=stimulus_id) - return recs, [smsg] - - def _transition_executing_rescheduled( - self, ts: TaskState, *, stimulus_id: str - ) -> RecsInstrs: - for resource, quantity in ts.resource_restrictions.items(): - self.available_resources[resource] += quantity - self._executing.discard(ts) - - return merge_recs_instructions( - ( - {ts: "released"}, - [RescheduleMsg(key=ts.key, stimulus_id=stimulus_id)], - ), - self._ensure_computing(), - ) - - def _transition_waiting_ready( - self, ts: TaskState, *, stimulus_id: str - ) -> RecsInstrs: - if self.validate: - assert ts.state == "waiting" - assert ts.key not in self.ready - assert not ts.waiting_for_data - for dep in ts.dependencies: - assert dep.key in self.data or dep.key in self.actors - assert dep.state == "memory" - - ts.state = "ready" - assert ts.priority is not None - heapq.heappush(self.ready, (ts.priority, ts.key)) - - return self._ensure_computing() - - def _transition_cancelled_error( - self, - ts: TaskState, - exception: Serialize, - traceback: Serialize | None, - exception_text: str, - traceback_text: str, - *, - stimulus_id: str, - ) -> RecsInstrs: - assert ts._previous == "executing" or ts.key in self.long_running - recs, instructions = self._transition_executing_error( - ts, - exception, - traceback, - exception_text, - traceback_text, - stimulus_id=stimulus_id, - ) - # We'll ignore instructions, i.e. we choose to not submit the failure - # message to the scheduler since from the schedulers POV it already - # released this task - if self.validate: - assert len(instructions) == 1 - assert isinstance(instructions[0], TaskErredMsg) - assert instructions[0].key == ts.key - instructions.clear() - # Workers should never "retry" tasks. A transition to error should, by - # default, be the end. Since cancelled indicates that the scheduler lost - # interest, we can transition straight to released - assert ts not in recs - recs[ts] = "released" - return recs, instructions - - def _transition_generic_error( - self, - ts: TaskState, - exception: Serialize, - traceback: Serialize | None, - exception_text: str, - traceback_text: str, - *, - stimulus_id: str, - ) -> RecsInstrs: - ts.exception = exception - ts.traceback = traceback - ts.exception_text = exception_text - ts.traceback_text = traceback_text - ts.state = "error" - smsg = TaskErredMsg.from_task( - ts, - stimulus_id=stimulus_id, - thread=self.threads.get(ts.key), - ) - - return {}, [smsg] - - def _transition_executing_error( + @log_errors + async def plugin_add( self, - ts: TaskState, - exception: Serialize, - traceback: Serialize | None, - exception_text: str, - traceback_text: str, - *, - stimulus_id: str, - ) -> RecsInstrs: - for resource, quantity in ts.resource_restrictions.items(): - self.available_resources[resource] += quantity - self._executing.discard(ts) - - return merge_recs_instructions( - self._transition_generic_error( - ts, - exception, - traceback, - exception_text, - traceback_text, - stimulus_id=stimulus_id, - ), - self._ensure_computing(), - ) - - def _transition_from_resumed( - self, ts: TaskState, finish: TaskStateState, *args, stimulus_id: str - ) -> RecsInstrs: - """`resumed` is an intermediate degenerate state which splits further up - into two states depending on what the last signal / next state is - intended to be. There are only two viable choices depending on whether - the task is required to be fetched from another worker `resumed(fetch)` - or the task shall be computed on this worker `resumed(waiting)`. - - The only viable state transitions ending up here are - - flight -> cancelled -> resumed(waiting) - - or - - executing -> cancelled -> resumed(fetch) + plugin: WorkerPlugin | bytes, + name: str | None = None, + catch_errors: bool = True, + ) -> dict[str, Any]: + if isinstance(plugin, bytes): + # Note: historically we have accepted duck-typed classes that don't + # inherit from WorkerPlugin. Don't do `assert isinstance`. + plugin = cast("WorkerPlugin", pickle.loads(plugin)) - depending on the origin. Equally, only `fetch`, `waiting`, or `released` - are allowed output states. + if name is None: + name = _get_plugin_name(plugin) - See also `_transition_resumed_waiting` - """ - recs: Recs = {} - instructions: Instructions = [] - - if ts._previous == finish: - # We're back where we started. We should forget about the entire - # cancellation attempt - ts.state = finish - ts._next = None - ts._previous = None - elif not ts.done: - # If we're not done, yet, just remember where we want to be next - ts._next = finish - else: - # Flight/executing finished unsuccessfully, i.e. not in memory - assert finish != "memory" - next_state = ts._next - assert next_state in {"waiting", "fetch"}, next_state - assert ts._previous in {"executing", "flight"}, ts._previous - - if next_state != finish: - recs, instructions = self._transition_generic_released( - ts, stimulus_id=stimulus_id - ) - recs[ts] = next_state + assert name - return recs, instructions + if name in self.plugins: + await self.plugin_remove(name=name) - def _transition_resumed_fetch( - self, ts: TaskState, *, stimulus_id: str - ) -> RecsInstrs: - """See Worker._transition_from_resumed""" - recs, instructions = self._transition_from_resumed( - ts, "fetch", stimulus_id=stimulus_id - ) - if self.validate: - # This would only be possible in a fetch->cancelled->resumed->fetch loop, - # but there are no transitions from fetch which set the state to cancelled. - # If this assertion failed, we' need to call _ensure_communicating like in - # the other transitions that set ts.status = "fetch". - assert ts.state != "fetch" - return recs, instructions - - def _transition_resumed_missing( - self, ts: TaskState, *, stimulus_id: str - ) -> RecsInstrs: - """See Worker._transition_from_resumed""" - return self._transition_from_resumed(ts, "missing", stimulus_id=stimulus_id) - - def _transition_resumed_released( - self, ts: TaskState, *, stimulus_id: str - ) -> RecsInstrs: - if not ts.done: - ts.state = "cancelled" - ts._next = None - return {}, [] - else: - return self._transition_generic_released(ts, stimulus_id=stimulus_id) - - def _transition_resumed_waiting( - self, ts: TaskState, *, stimulus_id: str - ) -> RecsInstrs: - """See Worker._transition_from_resumed""" - return self._transition_from_resumed(ts, "waiting", stimulus_id=stimulus_id) - - def _transition_cancelled_fetch( - self, ts: TaskState, *, stimulus_id: str - ) -> RecsInstrs: - if ts.done: - return {ts: "released"}, [] - elif ts._previous == "flight": - ts.state = ts._previous - return {}, [] - else: - assert ts._previous == "executing" - ts.state = "resumed" - ts._next = "fetch" - return {}, [] - - def _transition_cancelled_waiting( - self, ts: TaskState, *, stimulus_id: str - ) -> RecsInstrs: - if ts.done: - return {ts: "released"}, [] - elif ts._previous == "executing": - ts.state = ts._previous - return {}, [] - else: - assert ts._previous == "flight" - ts.state = "resumed" - ts._next = "waiting" - return {}, [] - - def _transition_cancelled_forgotten( - self, ts: TaskState, *, stimulus_id: str - ) -> RecsInstrs: - ts._next = "forgotten" - if not ts.done: - return {}, [] - return {ts: "released"}, [] - - def _transition_cancelled_released( - self, ts: TaskState, *, stimulus_id: str - ) -> RecsInstrs: - if not ts.done: - return {}, [] - self._executing.discard(ts) - self._in_flight_tasks.discard(ts) - - for resource, quantity in ts.resource_restrictions.items(): - self.available_resources[resource] += quantity - - return self._transition_generic_released(ts, stimulus_id=stimulus_id) - - def _transition_executing_released( - self, ts: TaskState, *, stimulus_id: str - ) -> RecsInstrs: - ts._previous = ts.state - ts._next = None - # See https://github.com/dask/distributed/pull/5046#discussion_r685093940 - ts.state = "cancelled" - ts.done = False - return self._ensure_computing() - - def _transition_long_running_memory( - self, ts: TaskState, value=NO_VALUE, *, stimulus_id: str - ) -> RecsInstrs: - self.executed_count += 1 - return self._transition_generic_memory(ts, value=value, stimulus_id=stimulus_id) - - def _transition_generic_memory( - self, ts: TaskState, value=NO_VALUE, *, stimulus_id: str - ) -> RecsInstrs: - if value is NO_VALUE and ts.key not in self.data: - raise RuntimeError( - f"Tried to transition task {ts} to `memory` without data available" - ) + self.plugins[name] = plugin - if ts.resource_restrictions is not None: - for resource, quantity in ts.resource_restrictions.items(): - self.available_resources[resource] += quantity + logger.info("Starting Worker plugin %s" % name) + if hasattr(plugin, "setup"): + try: + result = plugin.setup(worker=self) + if isawaitable(result): + result = await result + except Exception as e: + if not catch_errors: + raise + msg = error_message(e) + return cast("dict[str, Any]", msg) - self._executing.discard(ts) - self._in_flight_tasks.discard(ts) - ts.coming_from = None + return {"status": "OK"} - instructions: Instructions = [] + @log_errors + async def plugin_remove(self, name: str) -> dict[str, Any]: + logger.info(f"Removing Worker plugin {name}") try: - recs = self._put_key_in_memory(ts, value, stimulus_id=stimulus_id) + plugin = self.plugins.pop(name) + if hasattr(plugin, "teardown"): + result = plugin.teardown(worker=self) + if isawaitable(result): + result = await result except Exception as e: msg = error_message(e) - recs = {ts: tuple(msg.values())} - else: - if self.validate: - assert ts.key in self.data or ts.key in self.actors - instructions.append( - self._get_task_finished_msg(ts, stimulus_id=stimulus_id) - ) - - return recs, instructions - - def _transition_executing_memory( - self, ts: TaskState, value=NO_VALUE, *, stimulus_id: str - ) -> RecsInstrs: - if self.validate: - assert ts.state == "executing" or ts.key in self.long_running - assert not ts.waiting_for_data - assert ts.key not in self.ready - - self._executing.discard(ts) - self.executed_count += 1 - return merge_recs_instructions( - self._transition_generic_memory(ts, value=value, stimulus_id=stimulus_id), - self._ensure_computing(), - ) - - def _transition_constrained_executing( - self, ts: TaskState, *, stimulus_id: str - ) -> RecsInstrs: - if self.validate: - assert not ts.waiting_for_data - assert ts.key not in self.data - assert ts.state in READY - assert ts.key not in self.ready - for dep in ts.dependencies: - assert dep.key in self.data or dep.key in self.actors - - ts.state = "executing" - instr = Execute(key=ts.key, stimulus_id=stimulus_id) - return {}, [instr] - - def _transition_ready_executing( - self, ts: TaskState, *, stimulus_id: str - ) -> RecsInstrs: - if self.validate: - assert not ts.waiting_for_data - assert ts.key not in self.data - assert ts.state in READY - assert ts.key not in self.ready - assert all( - dep.key in self.data or dep.key in self.actors - for dep in ts.dependencies - ) - - ts.state = "executing" - instr = Execute(key=ts.key, stimulus_id=stimulus_id) - return {}, [instr] - - def _transition_flight_fetch( - self, ts: TaskState, *, stimulus_id: str - ) -> RecsInstrs: - # If this transition is called after the flight coroutine has finished, - # we can reset the task and transition to fetch again. If it is not yet - # finished, this should be a no-op - if not ts.done: - return {}, [] - - ts.coming_from = None - return self._transition_generic_fetch(ts, stimulus_id=stimulus_id) - - def _transition_flight_error( - self, - ts: TaskState, - exception: Serialize, - traceback: Serialize | None, - exception_text: str, - traceback_text: str, - *, - stimulus_id: str, - ) -> RecsInstrs: - self._in_flight_tasks.discard(ts) - ts.coming_from = None - return self._transition_generic_error( - ts, - exception, - traceback, - exception_text, - traceback_text, - stimulus_id=stimulus_id, - ) + return cast("dict[str, Any]", msg) - def _transition_flight_released( - self, ts: TaskState, *, stimulus_id: str - ) -> RecsInstrs: - if ts.done: - # FIXME: Is this even possible? Would an assert instead be more - # sensible? - return self._transition_generic_released(ts, stimulus_id=stimulus_id) - else: - ts._previous = "flight" - ts._next = None - # See https://github.com/dask/distributed/pull/5046#discussion_r685093940 - ts.state = "cancelled" - return {}, [] - - def _transition_cancelled_memory(self, ts, value, *, stimulus_id): - # We only need this because the to-memory signatures require a value but - # we do not want to store a cancelled result and want to release - # immediately - assert ts.done - - return self._transition_cancelled_released(ts, stimulus_id=stimulus_id) - - def _transition_executing_long_running( - self, ts: TaskState, compute_duration: float, *, stimulus_id: str - ) -> RecsInstrs: - ts.state = "long-running" - self._executing.discard(ts) - self.long_running.add(ts.key) - - smsg = LongRunningMsg( - key=ts.key, compute_duration=compute_duration, stimulus_id=stimulus_id - ) - return merge_recs_instructions( - ({}, [smsg]), - self._ensure_computing(), - ) + return {"status": "OK"} - def _transition_released_memory( - self, ts: TaskState, value, *, stimulus_id: str - ) -> RecsInstrs: - try: - recs = self._put_key_in_memory(ts, value, stimulus_id=stimulus_id) - except Exception as e: - msg = error_message(e) - recs = {ts: tuple(msg.values())} - return recs, [] - smsg = AddKeysMsg(keys=[ts.key], stimulus_id=stimulus_id) - return recs, [smsg] - - def _transition_flight_memory( - self, ts: TaskState, value, *, stimulus_id: str - ) -> RecsInstrs: - self._in_flight_tasks.discard(ts) - ts.coming_from = None - try: - recs = self._put_key_in_memory(ts, value, stimulus_id=stimulus_id) - except Exception as e: - msg = error_message(e) - recs = {ts: tuple(msg.values())} - return recs, [] - smsg = AddKeysMsg(keys=[ts.key], stimulus_id=stimulus_id) - return recs, [smsg] - - def _transition_released_forgotten( - self, ts: TaskState, *, stimulus_id: str - ) -> RecsInstrs: - recommendations: Recs = {} - # Dependents _should_ be released by the scheduler before this - if self.validate: - assert not any(d.state != "forgotten" for d in ts.dependents) - for dep in ts.dependencies: - dep.dependents.discard(ts) - if dep.state == "released" and not dep.dependents: - recommendations[dep] = "forgotten" - self._purge_state(ts) - # Mark state as forgotten in case it is still referenced - ts.state = "forgotten" - self.tasks.pop(ts.key, None) - return recommendations, [] - - # { - # (start, finish): - # transition__( - # self, ts: TaskState, *args, stimulus_id: str - # ) -> (recommendations, instructions) - # } - _TRANSITIONS_TABLE: ClassVar[ - Mapping[tuple[TaskStateState, TaskStateState], Callable[..., RecsInstrs]] - ] = { - ("cancelled", "fetch"): _transition_cancelled_fetch, - ("cancelled", "released"): _transition_cancelled_released, - ("cancelled", "missing"): _transition_cancelled_released, - ("cancelled", "waiting"): _transition_cancelled_waiting, - ("cancelled", "forgotten"): _transition_cancelled_forgotten, - ("cancelled", "memory"): _transition_cancelled_memory, - ("cancelled", "error"): _transition_cancelled_error, - ("resumed", "memory"): _transition_generic_memory, - ("resumed", "error"): _transition_generic_error, - ("resumed", "released"): _transition_resumed_released, - ("resumed", "waiting"): _transition_resumed_waiting, - ("resumed", "fetch"): _transition_resumed_fetch, - ("resumed", "missing"): _transition_resumed_missing, - ("constrained", "executing"): _transition_constrained_executing, - ("constrained", "released"): _transition_generic_released, - ("error", "released"): _transition_generic_released, - ("executing", "error"): _transition_executing_error, - ("executing", "long-running"): _transition_executing_long_running, - ("executing", "memory"): _transition_executing_memory, - ("executing", "released"): _transition_executing_released, - ("executing", "rescheduled"): _transition_executing_rescheduled, - ("fetch", "flight"): _transition_fetch_flight, - ("fetch", "missing"): _transition_fetch_missing, - ("fetch", "released"): _transition_generic_released, - ("flight", "error"): _transition_flight_error, - ("flight", "fetch"): _transition_flight_fetch, - ("flight", "memory"): _transition_flight_memory, - ("flight", "missing"): _transition_flight_missing, - ("flight", "released"): _transition_flight_released, - ("long-running", "error"): _transition_generic_error, - ("long-running", "memory"): _transition_long_running_memory, - ("long-running", "rescheduled"): _transition_executing_rescheduled, - ("long-running", "released"): _transition_executing_released, - ("memory", "released"): _transition_memory_released, - ("missing", "fetch"): _transition_missing_fetch, - ("missing", "released"): _transition_missing_released, - ("missing", "error"): _transition_generic_error, - ("missing", "waiting"): _transition_missing_waiting, - ("ready", "error"): _transition_generic_error, - ("ready", "executing"): _transition_ready_executing, - ("ready", "released"): _transition_generic_released, - ("released", "error"): _transition_generic_error, - ("released", "fetch"): _transition_released_fetch, - ("released", "missing"): _transition_generic_missing, - ("released", "forgotten"): _transition_released_forgotten, - ("released", "memory"): _transition_released_memory, - ("released", "waiting"): _transition_released_waiting, - ("waiting", "constrained"): _transition_waiting_constrained, - ("waiting", "ready"): _transition_waiting_ready, - ("waiting", "released"): _transition_generic_released, - } - - def _transition( - self, - ts: TaskState, - finish: TaskStateState | tuple, - *args, - stimulus_id: str, - **kwargs, - ) -> RecsInstrs: - """Transition a key from its current state to the finish state + def handle_worker_status_change(self, status: str, stimulus_id: str) -> None: + new_status = Status.lookup[status] # type: ignore - See Also - -------- - Worker.transitions: wrapper around this method - """ - if isinstance(finish, tuple): - # the concatenated transition path might need to access the tuple - assert not args - args = finish[1:] - finish = cast(TaskStateState, finish[0]) - - if ts.state == finish: - return {}, [] - - start = ts.state - func = self._TRANSITIONS_TABLE.get((start, finish)) - - # Notes: - # - in case of transition through released, this counter is incremented by 2 - # - this increase happens before the actual transitions, so that it can - # catch potential infinite recursions - self.transition_counter += 1 if ( - self.transition_counter_max - and self.transition_counter >= self.transition_counter_max + new_status == Status.closing_gracefully + and self._status not in WORKER_ANY_RUNNING ): - raise TransitionCounterMaxExceeded(ts.key, start, finish, self.story(ts)) - - if func is not None: - recs, instructions = func( - self, ts, *args, stimulus_id=stimulus_id, **kwargs + logger.error( + "Invalid Worker.status transition: %s -> %s", self._status, new_status ) - self._notify_plugins("transition", ts.key, start, finish, **kwargs) - - elif "released" not in (start, finish): - # start -> "released" -> finish - try: - recs, instructions = self._transition( - ts, "released", stimulus_id=stimulus_id - ) - v_state: TaskStateState - v_args: list | tuple - while v := recs.pop(ts, None): - if isinstance(v, tuple): - v_state, *v_args = v - else: - v_state, v_args = v, () - if v_state == "forgotten": - # We do not want to forget. The purpose of this - # transition path is to get to `finish` - continue - recs, instructions = merge_recs_instructions( - (recs, instructions), - self._transition(ts, v_state, *v_args, stimulus_id=stimulus_id), - ) - recs, instructions = merge_recs_instructions( - (recs, instructions), - self._transition(ts, finish, *args, stimulus_id=stimulus_id), - ) - except (InvalidTransition, RecommendationsConflict) as e: - raise InvalidTransition(ts.key, start, finish, self.story(ts)) from e - + # Reiterate the current status to the scheduler to restore sync + self._send_worker_status_change(stimulus_id) else: - raise InvalidTransition(ts.key, start, finish, self.story(ts)) - - self.log.append( - ( - # key - ts.key, - # initial - start, - # recommended - finish, - # final - ts.state, - # new recommendations - {ts.key: new for ts, new in recs.items()}, - stimulus_id, - time(), - ) - ) - return recs, instructions - - def _transitions(self, recommendations: Recs, *, stimulus_id: str) -> None: - """Process transitions until none are left - - This includes feedback from previous transitions and continues until we - reach a steady state - """ - instructions = [] - - remaining_recs = recommendations.copy() - tasks = set() - while remaining_recs: - ts, finish = remaining_recs.popitem() - tasks.add(ts) - a_recs, a_instructions = self._transition( - ts, finish, stimulus_id=stimulus_id - ) + # Update status and send confirmation to the Scheduler (see status.setter) + self.status = new_status - remaining_recs.update(a_recs) - instructions += a_instructions + ################### + # Task Management # + ################### - if self.validate: - # Full state validation is very expensive - for ts in tasks: - self.validate_task(ts) + @fail_hard + def _handle_remote_stimulus( + self, cls: type[StateMachineEvent] + ) -> Callable[..., None]: + def _(**kwargs): + event = cls(**kwargs) + self.handle_stimulus(event) - self._handle_instructions(instructions) + _.__name__ = f"_handle_remote_stimulus({cls.__name__})" + return _ @fail_hard @log_errors @@ -2878,94 +1861,6 @@ def handle_stimulus(self, stim: StateMachineEvent) -> None: self.log_event(topic, msg) raise - @fail_hard - @log_errors - def _handle_stimulus_from_task( - self, task: asyncio.Task[StateMachineEvent | None] - ) -> None: - self._async_instructions.remove(task) - try: - # This *should* never raise any other exceptions - stim = task.result() - except asyncio.CancelledError: - return - if stim: - self.handle_stimulus(stim) - - @fail_hard - def _handle_instructions(self, instructions: Instructions) -> None: - while instructions: - ensure_communicating: EnsureCommunicatingAfterTransitions | None = None - for inst in instructions: - task: asyncio.Task | None = None - - if isinstance(inst, SendMessageToScheduler): - self.batched_send(inst.to_dict()) - - elif isinstance(inst, EnsureCommunicatingAfterTransitions): - # A single compute-task or acquire-replicas command may cause - # multiple tasks to transition to fetch; this in turn means that we - # will receive multiple instances of this instruction. - # _ensure_communicating is a no-op if it runs twice in a row; we're - # not calling it inside the for loop to avoid a O(n^2) condition - # when - # 1. there are many fetches queued because all workers are in flight - # 2. a single compute-task or acquire-replicas command just sent - # many dependencies to fetch at once. - ensure_communicating = inst - - elif isinstance(inst, GatherDep): - assert inst.to_gather - keys_str = ", ".join(peekn(27, inst.to_gather)[0]) - if len(keys_str) > 80: - keys_str = keys_str[:77] + "..." - task = asyncio.create_task( - self.gather_dep( - inst.worker, - inst.to_gather, - total_nbytes=inst.total_nbytes, - stimulus_id=inst.stimulus_id, - ), - name=f"gather_dep({inst.worker}, {{{keys_str}}})", - ) - - elif isinstance(inst, Execute): - task = asyncio.create_task( - self.execute(inst.key, stimulus_id=inst.stimulus_id), - name=f"execute({inst.key})", - ) - - elif isinstance(inst, RetryBusyWorkerLater): - task = asyncio.create_task( - self.retry_busy_worker_later(inst.worker), - name=f"retry_busy_worker_later({inst.worker})", - ) - - else: - raise TypeError(inst) # pragma: nocover - - if task is not None: - self._async_instructions.add(task) - task.add_done_callback(self._handle_stimulus_from_task) - - if ensure_communicating: - # Potentially re-fill instructions, causing a second iteration of `while - # instructions` at the top of this method - recs, instructions = self._ensure_communicating( - stimulus_id=ensure_communicating.stimulus_id - ) - self._transitions(recs, stimulus_id=ensure_communicating.stimulus_id) - else: - instructions = [] - - @_handle_event.register - def _handle_secede(self, ev: SecedeEvent) -> RecsInstrs: - ts = self.tasks.get(ev.key) - if ts and ts.state == "executing": - return {ts: ("long-running", ev.compute_duration)}, [] - else: - return {}, [] - def stateof(self, key: str) -> dict[str, Any]: ts = self.tasks[key] return { @@ -2975,217 +1870,9 @@ def stateof(self, key: str) -> dict[str, Any]: "data": key in self.data, } - def story(self, *keys_or_tasks: str | TaskState) -> list[tuple]: - """Return all transitions involving one or more tasks""" - keys = {e.key if isinstance(e, TaskState) else e for e in keys_or_tasks} - return worker_story(keys, self.log) - async def get_story(self, keys=None): return self.story(*keys) - def stimulus_story( - self, *keys_or_tasks: str | TaskState - ) -> list[StateMachineEvent]: - """Return all state machine events involving one or more tasks""" - keys = {e.key if isinstance(e, TaskState) else e for e in keys_or_tasks} - return [ev for ev in self.stimulus_log if getattr(ev, "key", None) in keys] - - def _ensure_communicating(self, *, stimulus_id: str) -> RecsInstrs: - if self.status != Status.running: - return {}, [] - - skipped_worker_in_flight_or_busy = [] - - recommendations: Recs = {} - instructions: Instructions = [] - - while self.data_needed and ( - len(self.in_flight_workers) < self.total_out_connections - or self.comm_nbytes < self.comm_threshold_bytes - ): - logger.debug( - "Ensure communicating. Pending: %d. Connections: %d/%d. Busy: %d", - len(self.data_needed), - len(self.in_flight_workers), - self.total_out_connections, - len(self.busy_workers), - ) - - ts = self.data_needed.pop() - - if self.validate: - assert ts.state == "fetch" - assert self.address not in ts.who_has - - if not ts.who_has: - recommendations[ts] = "missing" - continue - - workers = [ - w - for w in ts.who_has - if w not in self.in_flight_workers and w not in self.busy_workers - ] - if not workers: - skipped_worker_in_flight_or_busy.append(ts) - continue - - for w in ts.who_has: - self.data_needed_per_worker[w].remove(ts) - - host = get_address_host(self.address) - local = [w for w in workers if get_address_host(w) == host] - worker = random.choice(local or workers) - - to_gather_tasks, total_nbytes = self._select_keys_for_gather(worker, ts) - to_gather_keys = {ts.key for ts in to_gather_tasks} - - self.log.append( - ("gather-dependencies", worker, to_gather_keys, stimulus_id, time()) - ) - - self.comm_nbytes += total_nbytes - self.in_flight_workers[worker] = to_gather_keys - for d_ts in to_gather_tasks: - if self.validate: - assert d_ts.state == "fetch" - assert d_ts not in recommendations - recommendations[d_ts] = ("flight", worker) - - # A single invocation of _ensure_communicating may generate up to one - # GatherDep instruction per worker. Multiple tasks from the same worker may - # be clustered in the same instruction by _select_keys_for_gather. But once - # a worker has been selected for a GatherDep and added to in_flight_workers, - # it won't be selected again until the gather completes. - instructions.append( - GatherDep( - worker=worker, - to_gather=to_gather_keys, - total_nbytes=total_nbytes, - stimulus_id=stimulus_id, - ) - ) - - for ts in skipped_worker_in_flight_or_busy: - self.data_needed.add(ts) - - return recommendations, instructions - - def _get_task_finished_msg( - self, ts: TaskState, stimulus_id: str - ) -> TaskFinishedMsg: - if ts.key not in self.data and ts.key not in self.actors: - raise RuntimeError(f"Task {ts} not ready") - typ = ts.type - if ts.nbytes is None or typ is None: - try: - value = self.data[ts.key] - except KeyError: - value = self.actors[ts.key] - ts.nbytes = sizeof(value) - typ = ts.type = type(value) - del value - try: - typ_serialized = dumps_function(typ) - except PicklingError: - # Some types fail pickling (example: _thread.lock objects), - # send their name as a best effort. - typ_serialized = pickle.dumps(typ.__name__, protocol=4) - return TaskFinishedMsg( - key=ts.key, - nbytes=ts.nbytes, - type=typ_serialized, - typename=typename(typ), - metadata=ts.metadata, - thread=self.threads.get(ts.key), - startstops=ts.startstops, - stimulus_id=stimulus_id, - ) - - def _put_key_in_memory(self, ts: TaskState, value, *, stimulus_id: str) -> Recs: - """ - Put a key into memory and set data related task state attributes. - On success, generate recommendations for dependents. - - This method does not generate any scheduler messages since this method - cannot distinguish whether it has to be an `add-task` or a - `task-finished` signal. The caller is required to generate this message - on success. - - Raises - ------ - Exception: - In case the data is put into the in memory buffer and a serialization error - occurs during spilling, this raises that error. This has to be handled by - the caller since most callers generate scheduler messages on success (see - comment above) but we need to signal that this was not successful. - - Can only trigger if distributed.worker.memory.target is enabled, the value - is individually larger than target * memory_limit, and the task is not an - actor. - """ - if ts.key in self.data: - ts.state = "memory" - return {} - - recommendations: Recs = {} - if ts.key in self.actors: - self.actors[ts.key] = value - else: - start = time() - self.data[ts.key] = value - stop = time() - if stop - start > 0.020: - ts.startstops.append( - {"action": "disk-write", "start": start, "stop": stop} - ) - - ts.state = "memory" - if ts.nbytes is None: - ts.nbytes = sizeof(value) - - ts.type = type(value) - - for dep in ts.dependents: - dep.waiting_for_data.discard(ts) - if not dep.waiting_for_data and dep.state == "waiting": - self.waiting_for_data_count -= 1 - recommendations[dep] = "ready" - - self.log.append((ts.key, "put-in-memory", stimulus_id, time())) - return recommendations - - def _select_keys_for_gather( - self, worker: str, ts: TaskState - ) -> tuple[set[TaskState], int]: - """``_ensure_communicating`` decided to fetch a single task from a worker, - following priority. In order to minimise overhead, request fetching other tasks - from the same worker within the message, following priority for the single - worker but ignoring higher priority tasks from other workers, up to - ``target_message_size``. - """ - tss = {ts} - total_bytes = ts.get_nbytes() - tasks = self.data_needed_per_worker[worker] - - while tasks: - ts = tasks.peek() - if self.validate: - assert ts.state == "fetch" - assert worker in ts.who_has - if total_bytes + ts.get_nbytes() > self.target_message_size: - break - tasks.pop() - self.data_needed.remove(ts) - for other_worker in ts.who_has: - if other_worker != worker: - self.data_needed_per_worker[other_worker].remove(ts) - - tss.add(ts) - total_bytes += ts.get_nbytes() - - return tss, total_bytes - @property def total_comm_bytes(self): warnings.warn( @@ -3195,6 +1882,10 @@ def total_comm_bytes(self): ) return self.comm_threshold_bytes + ########################## + # Dependencies gathering # + ########################## + def _get_cause(self, keys: Iterable[str]) -> TaskState: """For diagnostics, we want to attach a transfer to a single task. This task is typically the next to be executed but since we're fetching tasks for potentially @@ -3352,124 +2043,6 @@ async def gather_dep( stimulus_id=f"gather-dep-failure-{time()}", ) - def _gather_dep_done_common(self, ev: GatherDepDoneEvent) -> Iterator[TaskState]: - """Common code for all subclasses of GatherDepDoneEvent. - - Yields the tasks that need to transition out of flight. - """ - self.comm_nbytes -= ev.total_nbytes - keys = self.in_flight_workers.pop(ev.worker) - for key in keys: - ts = self.tasks[key] - ts.done = True - yield ts - - @_handle_event.register - def _handle_gather_dep_success(self, ev: GatherDepSuccessEvent) -> RecsInstrs: - """gather_dep terminated successfully. - The response may contain less keys than the request. - """ - recommendations: Recs = {} - for ts in self._gather_dep_done_common(ev): - if ts.key in ev.data: - recommendations[ts] = ("memory", ev.data[ts.key]) - else: - self.log.append((ts.key, "missing-dep", ev.stimulus_id, time())) - if self.validate: - assert ts.state != "fetch" - assert ts not in self.data_needed_per_worker[ev.worker] - ts.who_has.discard(ev.worker) - self.has_what[ev.worker].discard(ts.key) - recommendations[ts] = "fetch" - - return merge_recs_instructions( - (recommendations, []), - self._ensure_communicating(stimulus_id=ev.stimulus_id), - ) - - @_handle_event.register - def _handle_gather_dep_busy(self, ev: GatherDepBusyEvent) -> RecsInstrs: - """gather_dep terminated: remote worker is busy""" - # Avoid hammering the worker. If there are multiple replicas - # available, immediately try fetching from a different worker. - self.busy_workers.add(ev.worker) - - recommendations: Recs = {} - refresh_who_has = [] - for ts in self._gather_dep_done_common(ev): - recommendations[ts] = "fetch" - if not ts.who_has - self.busy_workers: - refresh_who_has.append(ts.key) - - instructions: Instructions = [ - RetryBusyWorkerLater(worker=ev.worker, stimulus_id=ev.stimulus_id), - ] - - if refresh_who_has: - # All workers that hold known replicas of our tasks are busy. - # Try querying the scheduler for unknown ones. - instructions.append( - RequestRefreshWhoHasMsg( - keys=refresh_who_has, stimulus_id=ev.stimulus_id - ) - ) - - return merge_recs_instructions( - (recommendations, instructions), - self._ensure_communicating(stimulus_id=ev.stimulus_id), - ) - - @_handle_event.register - def _handle_gather_dep_network_failure( - self, ev: GatherDepNetworkFailureEvent - ) -> RecsInstrs: - """gather_dep terminated: network failure while trying to - communicate with remote worker - - Though the network failure could be transient, we assume it is not, and - preemptively act as though the other worker has died (including removing all - keys from it, even ones we did not fetch). - - This optimization leads to faster completion of the fetch, since we immediately - either retry a different worker, or ask the scheduler to inform us of a new - worker if no other worker is available. - """ - self.data_needed_per_worker.pop(ev.worker) - for key in self.has_what.pop(ev.worker): - ts = self.tasks[key] - ts.who_has.discard(ev.worker) - - recommendations: Recs = {} - for ts in self._gather_dep_done_common(ev): - self.log.append((ts.key, "missing-dep", ev.stimulus_id, time())) - recommendations[ts] = "fetch" - - return merge_recs_instructions( - (recommendations, []), - self._ensure_communicating(stimulus_id=ev.stimulus_id), - ) - - @_handle_event.register - def _handle_gather_dep_failure(self, ev: GatherDepFailureEvent) -> RecsInstrs: - """gather_dep terminated: generic error raised (not a network failure); - e.g. data failed to deserialize. - """ - recommendations: Recs = { - ts: ( - "error", - ev.exception, - ev.traceback, - ev.exception_text, - ev.traceback_text, - ) - for ts in self._gather_dep_done_common(ev) - } - - return merge_recs_instructions( - (recommendations, []), - self._ensure_communicating(stimulus_id=ev.stimulus_id), - ) - async def retry_busy_worker_later(self, worker: str) -> StateMachineEvent | None: await asyncio.sleep(0.15) return RetryBusyWorkerEvent( @@ -3485,113 +2058,6 @@ def find_missing(self) -> None: "heartbeat" ].callback_time - def _update_who_has(self, who_has: Mapping[str, Collection[str]]) -> None: - for key, workers in who_has.items(): - ts = self.tasks.get(key) - if not ts: - # The worker sent a refresh-who-has request to the scheduler but, by the - # time the answer comes back, some of the keys have been forgotten. - continue - workers = set(workers) - - if self.address in workers: - workers.remove(self.address) - # This can only happen if rebalance() recently asked to release a key, - # but the RPC call hasn't returned yet. rebalance() is flagged as not - # being safe to run while the cluster is not at rest and has already - # been penned in to be redesigned on top of the AMM. - # It is not necessary to send a message back to the - # scheduler here, because it is guaranteed that there's already a - # release-worker-data message in transit to it. - if ts.state != "memory": - logger.debug( # pragma: nocover - "Scheduler claims worker %s holds data for task %s, " - "which is not true.", - self.address, - ts, - ) - - if ts.who_has == workers: - continue - - for worker in ts.who_has - workers: - self.has_what[worker].remove(key) - if ts.state == "fetch": - self.data_needed_per_worker[worker].remove(ts) - - for worker in workers - ts.who_has: - self.has_what[worker].add(key) - if ts.state == "fetch": - self.data_needed_per_worker[worker].add(ts) - - ts.who_has = workers - - @_handle_event.register - def _handle_steal_request(self, ev: StealRequestEvent) -> RecsInstrs: - # There may be a race condition between stealing and releasing a task. - # In this case the self.tasks is already cleared. The `None` will be - # registered as `already-computing` on the other end - ts = self.tasks.get(ev.key) - state = ts.state if ts is not None else None - smsg = StealResponseMsg(key=ev.key, state=state, stimulus_id=ev.stimulus_id) - - if state in READY | {"waiting"}: - # If task is marked as "constrained" we haven't yet assigned it an - # `available_resources` to run on, that happens in - # `_transition_constrained_executing` - assert ts - return {ts: "released"}, [smsg] - else: - return {}, [smsg] - - def handle_worker_status_change(self, status: str, stimulus_id: str) -> None: - new_status = Status.lookup[status] # type: ignore - - if ( - new_status == Status.closing_gracefully - and self._status not in WORKER_ANY_RUNNING - ): - logger.error( - "Invalid Worker.status transition: %s -> %s", self._status, new_status - ) - # Reiterate the current status to the scheduler to restore sync - self._send_worker_status_change(stimulus_id) - else: - # Update status and send confirmation to the Scheduler (see status.setter) - self.status = new_status - - def _purge_state(self, ts: TaskState) -> None: - """Ensure that TaskState attributes are reset to a neutral default and - Worker-level state associated to the provided key is cleared (e.g. - who_has) - This is idempotent - """ - key = ts.key - logger.debug("Purge task key: %s state: %s; stimulus_id=%s", ts.key, ts.state) - self.data.pop(key, None) - self.actors.pop(key, None) - - for worker in ts.who_has: - self.has_what[worker].discard(ts.key) - self.data_needed_per_worker[worker].discard(ts) - ts.who_has.clear() - self.data_needed.discard(ts) - - self.threads.pop(key, None) - - for d in ts.dependencies: - ts.waiting_for_data.discard(d) - d.waiters.discard(ts) - - ts.waiting_for_data.clear() - ts.nbytes = None - ts._previous = None - ts._next = None - ts.done = False - - self._executing.discard(ts) - self._in_flight_tasks.discard(ts) - ################ # Execute Task # ################ @@ -3602,57 +2068,6 @@ def run(self, comm, function, args=(), wait=True, kwargs=None): def run_coroutine(self, comm, function, args=(), kwargs=None, wait=True): return run(self, comm, function=function, args=args, kwargs=kwargs, wait=wait) - @log_errors - async def plugin_add( - self, - plugin: WorkerPlugin | bytes, - name: str | None = None, - catch_errors: bool = True, - ) -> dict[str, Any]: - if isinstance(plugin, bytes): - # Note: historically we have accepted duck-typed classes that don't - # inherit from WorkerPlugin. Don't do `assert isinstance`. - plugin = cast("WorkerPlugin", pickle.loads(plugin)) - - if name is None: - name = _get_plugin_name(plugin) - - assert name - - if name in self.plugins: - await self.plugin_remove(name=name) - - self.plugins[name] = plugin - - logger.info("Starting Worker plugin %s" % name) - if hasattr(plugin, "setup"): - try: - result = plugin.setup(worker=self) - if isawaitable(result): - result = await result - except Exception as e: - if not catch_errors: - raise - msg = error_message(e) - return cast("dict[str, Any]", msg) - - return {"status": "OK"} - - @log_errors - async def plugin_remove(self, name: str) -> dict[str, Any]: - logger.info(f"Removing Worker plugin {name}") - try: - plugin = self.plugins.pop(name) - if hasattr(plugin, "teardown"): - result = plugin.teardown(worker=self) - if isawaitable(result): - result = await result - except Exception as e: - msg = error_message(e) - return cast("dict[str, Any]", msg) - - return {"status": "OK"} - async def actor_execute( self, actor=None, @@ -3713,63 +2128,6 @@ async def _maybe_deserialize_task( ) return function, args, kwargs - def _ensure_computing(self) -> RecsInstrs: - if self.status != Status.running: - return {}, [] - - recs: Recs = {} - while self.constrained and len(self._executing) < self.nthreads: - key = self.constrained[0] - ts = self.tasks.get(key, None) - if ts is None or ts.state != "constrained": - self.constrained.popleft() - continue - - # There may be duplicates in the self.constrained and self.ready queues. - # This happens if a task: - # 1. is assigned to a Worker and transitioned to ready (heappush) - # 2. is stolen (no way to pop from heap, the task stays there) - # 3. is assigned to the worker again (heappush again) - if ts in recs: - continue - - if any( - self.available_resources[resource] < needed - for resource, needed in ts.resource_restrictions.items() - ): - break - - self.constrained.popleft() - for resource, needed in ts.resource_restrictions.items(): - self.available_resources[resource] -= needed - - recs[ts] = "executing" - self._executing.add(ts) - - while self.ready and len(self._executing) < self.nthreads: - _, key = heapq.heappop(self.ready) - ts = self.tasks.get(key) - if ts is None: - # It is possible for tasks to be released while still remaining on - # `ready`. The scheduler might have re-routed to a new worker and - # told this worker to release. If the task has "disappeared", just - # continue through the heap. - continue - - if key in self.data: - # See comment above about duplicates - if self.validate: - assert ts not in recs or recs[ts] == "memory" - recs[ts] = "memory" - elif ts.state in READY: - # See comment above about duplicates - if self.validate: - assert ts not in recs or recs[ts] == "executing" - recs[ts] = "executing" - self._executing.add(ts) - - return recs, [] - @fail_hard async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | None: if self.status in {Status.closing, Status.closed, Status.closing_gracefully}: @@ -3896,136 +2254,6 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | No stimulus_id=f"execute-unknown-error-{time()}", ) - @_handle_event.register - def _handle_unpause(self, ev: UnpauseEvent) -> RecsInstrs: - """Emerge from paused status. Do not send this event directly. Instead, just set - Worker.status back to running. - """ - assert self.status == Status.running - return merge_recs_instructions( - self._ensure_computing(), - self._ensure_communicating(stimulus_id=ev.stimulus_id), - ) - - @_handle_event.register - def _handle_retry_busy_worker(self, ev: RetryBusyWorkerEvent) -> RecsInstrs: - self.busy_workers.discard(ev.worker) - return self._ensure_communicating(stimulus_id=ev.stimulus_id) - - @_handle_event.register - def _handle_cancel_compute(self, ev: CancelComputeEvent) -> RecsInstrs: - """Cancel a task on a best-effort basis. This is only possible while a task - is in state `waiting` or `ready`; nothing will happen otherwise. - """ - ts = self.tasks.get(ev.key) - if not ts or ts.state not in READY | {"waiting"}: - return {}, [] - - self.log.append((ev.key, "cancel-compute", ev.stimulus_id, time())) - # All possible dependents of ts should not be in state Processing on - # scheduler side and therefore should not be assigned to a worker, yet. - assert not ts.dependents - return {ts: "released"}, [] - - @_handle_event.register - def _handle_already_cancelled(self, ev: AlreadyCancelledEvent) -> RecsInstrs: - """Task is already cancelled by the time execute() runs""" - # key *must* be still in tasks. Releasing it directly is forbidden - # without going through cancelled - ts = self.tasks.get(ev.key) - assert ts, self.story(ev.key) - ts.done = True - return {ts: "released"}, [] - - @_handle_event.register - def _handle_execute_success(self, ev: ExecuteSuccessEvent) -> RecsInstrs: - """Task completed successfully""" - # key *must* be still in tasks. Releasing it directly is forbidden - # without going through cancelled - ts = self.tasks.get(ev.key) - assert ts, self.story(ev.key) - - ts.done = True - ts.startstops.append({"action": "compute", "start": ev.start, "stop": ev.stop}) - ts.nbytes = ev.nbytes - ts.type = ev.type - return {ts: ("memory", ev.value)}, [] - - @_handle_event.register - def _handle_execute_failure(self, ev: ExecuteFailureEvent) -> RecsInstrs: - """Task execution failed""" - # key *must* be still in tasks. Releasing it directly is forbidden - # without going through cancelled - ts = self.tasks.get(ev.key) - assert ts, self.story(ev.key) - - ts.done = True - if ev.start is not None and ev.stop is not None: - ts.startstops.append( - {"action": "compute", "start": ev.start, "stop": ev.stop} - ) - - return { - ts: ( - "error", - ev.exception, - ev.traceback, - ev.exception_text, - ev.traceback_text, - ) - }, [] - - @_handle_event.register - def _handle_reschedule(self, ev: RescheduleEvent) -> RecsInstrs: - """Task raised Reschedule exception while it was running""" - # key *must* be still in tasks. Releasing it directly is forbidden - # without going through cancelled - ts = self.tasks.get(ev.key) - assert ts, self.story(ev.key) - return {ts: "rescheduled"}, [] - - @_handle_event.register - def _handle_find_missing(self, ev: FindMissingEvent) -> RecsInstrs: - if not self._missing_dep_flight: - return {}, [] - - if self.validate: - for ts in self._missing_dep_flight: - assert not ts.who_has, self.story(ts) - - smsg = RequestRefreshWhoHasMsg( - keys=[ts.key for ts in self._missing_dep_flight], - stimulus_id=ev.stimulus_id, - ) - return {}, [smsg] - - @_handle_event.register - def _handle_refresh_who_has(self, ev: RefreshWhoHasEvent) -> RecsInstrs: - self._update_who_has(ev.who_has) - recommendations: Recs = {} - instructions: Instructions = [] - - for key in ev.who_has: - ts = self.tasks.get(key) - if not ts: - continue - - if ts.who_has and ts.state == "missing": - recommendations[ts] = "fetch" - elif ts.who_has and ts.state == "fetch": - # We potentially just acquired new replicas whereas all previously known - # workers are in flight or busy. We're deliberately not testing the - # minute use cases here for the sake of simplicity; instead we rely on - # _ensure_communicating to be a no-op when there's nothing to do. - recommendations, instructions = merge_recs_instructions( - (recommendations, instructions), - self._ensure_communicating(stimulus_id=ev.stimulus_id), - ) - elif not ts.who_has and ts.state == "fetch": - recommendations[ts] = "missing" - - return recommendations, instructions - def _prepare_args_for_execution( self, ts: TaskState, args: tuple, kwargs: dict[str, Any] ) -> tuple[tuple, dict[str, Any]]: @@ -4171,16 +2399,6 @@ def get_call_stack(self, keys: Collection[str] | None = None) -> dict[str, Any]: return {key: profile.call_stack(frame) for key, frame in frames.items()} - def _notify_plugins(self, method_name, *args, **kwargs): - for name, plugin in self.plugins.items(): - if hasattr(plugin, method_name): - try: - getattr(plugin, method_name)(*args, **kwargs) - except Exception: - logger.info( - "Plugin '%s' failed with exception", name, exc_info=True - ) - async def benchmark_disk(self) -> dict[str, float]: return await self.loop.run_in_executor( self.executor, benchmark_disk, self.local_directory @@ -4192,205 +2410,6 @@ async def benchmark_memory(self) -> dict[str, float]: async def benchmark_network(self, address: str) -> dict[str, float]: return await benchmark_network(rpc=self.rpc, address=address) - ############## - # Validation # - ############## - - def _validate_task_memory(self, ts): - assert ts.key in self.data or ts.key in self.actors - assert isinstance(ts.nbytes, int) - assert not ts.waiting_for_data - assert ts.key not in self.ready - assert ts.state == "memory" - - def _validate_task_executing(self, ts): - assert ts.state == "executing" - assert ts.run_spec is not None - assert ts.key not in self.data - assert not ts.waiting_for_data - for dep in ts.dependencies: - assert dep.state == "memory", self.story(dep) - assert dep.key in self.data or dep.key in self.actors - - def _validate_task_ready(self, ts): - assert ts.key in pluck(1, self.ready) - assert ts.key not in self.data - assert ts.state != "executing" - assert not ts.done - assert not ts.waiting_for_data - assert all( - dep.key in self.data or dep.key in self.actors for dep in ts.dependencies - ) - - def _validate_task_waiting(self, ts): - assert ts.key not in self.data - assert ts.state == "waiting" - assert not ts.done - if ts.dependencies and ts.run_spec: - assert not all(dep.key in self.data for dep in ts.dependencies) - - def _validate_task_flight(self, ts): - assert ts.key not in self.data - assert ts in self._in_flight_tasks - assert not any(dep.key in self.ready for dep in ts.dependents) - assert ts.coming_from - assert ts.coming_from in self.in_flight_workers - assert ts.key in self.in_flight_workers[ts.coming_from] - - def _validate_task_fetch(self, ts): - assert ts.key not in self.data - assert self.address not in ts.who_has - assert not ts.done - assert ts in self.data_needed - # Note: ts.who_has may be empty; see GatherDepNetworkFailureEvent - for w in ts.who_has: - assert ts.key in self.has_what[w] - assert ts in self.data_needed_per_worker[w] - - def _validate_task_missing(self, ts): - assert ts.key not in self.data - assert not ts.who_has - assert not ts.done - assert not any(ts.key in has_what for has_what in self.has_what.values()) - assert ts in self._missing_dep_flight - - def _validate_task_cancelled(self, ts): - assert ts.key not in self.data - assert ts._previous in {"long-running", "executing", "flight"} - # We'll always transition to released after it is done - assert ts._next is None, (ts.key, ts._next, self.story(ts)) - - def _validate_task_resumed(self, ts): - assert ts.key not in self.data - assert ts._next - assert ts._previous in {"long-running", "executing", "flight"} - - def _validate_task_released(self, ts): - assert ts.key not in self.data - assert not ts._next - assert not ts._previous - assert ts not in self.data_needed - for tss in self.data_needed_per_worker.values(): - assert ts not in tss - assert ts not in self._executing - assert ts not in self._in_flight_tasks - assert ts not in self._missing_dep_flight - - # FIXME the below assert statement is true most of the time. If a task - # performs the transition flight->cancel->waiting, its dependencies are - # normally in released state. However, the compute-task call for their - # previous dependent provided them with who_has, such that this assert - # is no longer true. - # assert not any(ts.key in has_what for has_what in self.has_what.values()) - - assert not ts.waiting_for_data - assert not ts.done - assert not ts.exception - assert not ts.traceback - - def validate_task(self, ts): - try: - if ts.key in self.tasks: - assert self.tasks[ts.key] == ts - if ts.state == "memory": - self._validate_task_memory(ts) - elif ts.state == "waiting": - self._validate_task_waiting(ts) - elif ts.state == "missing": - self._validate_task_missing(ts) - elif ts.state == "cancelled": - self._validate_task_cancelled(ts) - elif ts.state == "resumed": - self._validate_task_resumed(ts) - elif ts.state == "ready": - self._validate_task_ready(ts) - elif ts.state == "executing": - self._validate_task_executing(ts) - elif ts.state == "flight": - self._validate_task_flight(ts) - elif ts.state == "fetch": - self._validate_task_fetch(ts) - elif ts.state == "released": - self._validate_task_released(ts) - except Exception as e: - logger.exception(e) - if LOG_PDB: - import pdb - - pdb.set_trace() - - raise InvalidTaskState( - key=ts.key, state=ts.state, story=self.story(ts) - ) from e - - def validate_state(self): - try: - assert self.executing_count >= 0 - waiting_for_data_count = 0 - for ts in self.tasks.values(): - assert ts.state is not None - # check that worker has task - for worker in ts.who_has: - assert worker != self.address - assert ts.key in self.has_what[worker] - # check that deps have a set state and that dependency<->dependent links - # are there - for dep in ts.dependencies: - # self.tasks was just a dict of tasks - # and this check was originally that the key was in `task_state` - # so we may have popped the key out of `self.tasks` but the - # dependency can still be in `memory` before GC grabs it...? - # Might need better bookkeeping - assert dep.state is not None - assert ts in dep.dependents, ts - if ts.waiting_for_data: - waiting_for_data_count += 1 - for ts_wait in ts.waiting_for_data: - assert ts_wait.key in self.tasks - assert ( - ts_wait.state - in READY | {"executing", "flight", "fetch", "missing"} - or ts_wait in self._missing_dep_flight - or ts_wait.who_has.issubset(self.in_flight_workers) - ), (ts, ts_wait, self.story(ts), self.story(ts_wait)) - # FIXME https://github.com/dask/distributed/issues/6319 - # assert self.waiting_for_data_count == waiting_for_data_count - for worker, keys in self.has_what.items(): - assert worker != self.address - for k in keys: - assert k in self.tasks, self.story(k) - assert worker in self.tasks[k].who_has - - for ts in self.data_needed: - assert ts.state == "fetch", self.story(ts) - assert self.tasks[ts.key] is ts - for worker, tss in self.data_needed_per_worker.items(): - for ts in tss: - assert ts.state == "fetch" - assert self.tasks[ts.key] is ts - assert ts in self.data_needed - assert worker in ts.who_has - - for ts in self.tasks.values(): - self.validate_task(ts) - - if self.transition_counter_max: - assert self.transition_counter < self.transition_counter_max - - except Exception as e: - logger.error("Validate state failed", exc_info=e) - logger.exception(e) - if LOG_PDB: - import pdb - - pdb.set_trace() - - if hasattr(e, "to_event"): - topic, msg = e.to_event() - self.log_event(topic, msg) - - raise - ####################################### # Worker Clients (advanced workloads) # ####################################### diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index b6fa61f580..21b2043d05 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -1,23 +1,55 @@ from __future__ import annotations +import abc +import asyncio +import functools +import heapq +import logging +import random import sys -from collections.abc import Collection, Container +from collections import defaultdict, deque +from collections.abc import ( + Callable, + Collection, + Container, + Iterator, + Mapping, + MutableMapping, +) from copy import copy from dataclasses import dataclass, field from functools import lru_cache -from typing import TYPE_CHECKING, Any, ClassVar, Literal, NamedTuple, TypedDict +from pickle import PicklingError +from typing import TYPE_CHECKING, Any, ClassVar, Literal, NamedTuple, TypedDict, cast -import dask -from dask.utils import parse_bytes +from tlz import peekn, pluck -from distributed.core import ErrorMessage, error_message +import dask +from dask.utils import parse_bytes, typename + +from distributed._stories import worker_story +from distributed.collections import HeapSet +from distributed.comm import get_address_host +from distributed.core import ErrorMessage, Status, error_message +from distributed.metrics import time +from distributed.protocol import pickle from distributed.protocol.serialize import Serialize +from distributed.sizeof import safe_sizeof as sizeof from distributed.utils import recursive_to_dict +logger = logging.getLogger(__name__) + +LOG_PDB = dask.config.get("distributed.admin.pdb-on-err") + if TYPE_CHECKING: - # TODO move to typing and get out of TYPE_CHECKING (requires Python >=3.10) + # TODO import from typing (requires Python >=3.10) from typing_extensions import TypeAlias + # Circular imports + from distributed.actor import Actor + from distributed.diagnostics.plugin import WorkerPlugin + + # TODO move out of TYPE_CHECKING (requires Python >=3.10) TaskStateState: TypeAlias = Literal[ "cancelled", "constrained", @@ -819,3 +851,1982 @@ def merge_recs_instructions(*args: RecsInstrs) -> RecsInstrs: recs[ts] = finish instr += instr_i return recs, instr + + +class WorkerState: + address: str + data: MutableMapping[str, Any] + threads: dict[str, int] # {ts.key: thread ID} + plugins: dict[str, WorkerPlugin] + + tasks: dict[str, TaskState] + waiting_for_data_count: int + has_what: defaultdict[str, set[str]] # {worker address: {ts.key, ...} + data_needed: HeapSet[TaskState] + data_needed_per_worker: defaultdict[str, HeapSet[TaskState]] + in_flight_workers: dict[str, set[str]] # {worker address: {ts.key, ...}} + busy_workers: set[str] + total_out_connections: int + comm_threshold_bytes: int + comm_nbytes: int + _missing_dep_flight: set[TaskState] + generation: int + ready: list[tuple[tuple[int, ...], str]] # heapq [(priority, key), ...] + constrained: deque[str] + nthreads: int + available_resources: dict[str, float] + _executing: set[TaskState] + _in_flight_tasks: set[TaskState] + executed_count: int + long_running: set[str] + actors: dict[str, Actor | None] + log: deque[tuple] # [(..., stimulus_id: str | None, timestamp: float), ...] + stimulus_log: deque[StateMachineEvent] + target_message_size: int + transition_counter: int + transition_counter_max: int | Literal[False] + validate: bool + + ######################### + # Shared helper methods # + ######################### + + def _ensure_task_exists( + self, key: str, *, priority: tuple[int, ...], stimulus_id: str + ) -> TaskState: + try: + ts = self.tasks[key] + logger.debug("Data task %s already known (stimulus_id=%s)", ts, stimulus_id) + except KeyError: + self.tasks[key] = ts = TaskState(key) + if not ts.priority: + assert priority + ts.priority = priority + + self.log.append((key, "ensure-task-exists", ts.state, stimulus_id, time())) + return ts + + def _update_who_has(self, who_has: Mapping[str, Collection[str]]) -> None: + for key, workers in who_has.items(): + ts = self.tasks.get(key) + if not ts: + # The worker sent a refresh-who-has request to the scheduler but, by the + # time the answer comes back, some of the keys have been forgotten. + continue + workers = set(workers) + + if self.address in workers: + workers.remove(self.address) + # This can only happen if rebalance() recently asked to release a key, + # but the RPC call hasn't returned yet. rebalance() is flagged as not + # being safe to run while the cluster is not at rest and has already + # been penned in to be redesigned on top of the AMM. + # It is not necessary to send a message back to the + # scheduler here, because it is guaranteed that there's already a + # release-worker-data message in transit to it. + if ts.state != "memory": + logger.debug( # pragma: nocover + "Scheduler claims worker %s holds data for task %s, " + "which is not true.", + self.address, + ts, + ) + + if ts.who_has == workers: + continue + + for worker in ts.who_has - workers: + self.has_what[worker].remove(key) + if ts.state == "fetch": + self.data_needed_per_worker[worker].remove(ts) + + for worker in workers - ts.who_has: + self.has_what[worker].add(key) + if ts.state == "fetch": + self.data_needed_per_worker[worker].add(ts) + + ts.who_has = workers + + def _purge_state(self, ts: TaskState) -> None: + """Ensure that TaskState attributes are reset to a neutral default and + Worker-level state associated to the provided key is cleared (e.g. + who_has) + This is idempotent + """ + key = ts.key + logger.debug("Purge task key: %s state: %s; stimulus_id=%s", ts.key, ts.state) + self.data.pop(key, None) + self.actors.pop(key, None) + + for worker in ts.who_has: + self.has_what[worker].discard(ts.key) + self.data_needed_per_worker[worker].discard(ts) + ts.who_has.clear() + self.data_needed.discard(ts) + + self.threads.pop(key, None) + + for d in ts.dependencies: + ts.waiting_for_data.discard(d) + d.waiters.discard(ts) + + ts.waiting_for_data.clear() + ts.nbytes = None + ts._previous = None + ts._next = None + ts.done = False + + self._executing.discard(ts) + self._in_flight_tasks.discard(ts) + + def _ensure_communicating(self, *, stimulus_id: str) -> RecsInstrs: + if self.status != Status.running: + return {}, [] + + skipped_worker_in_flight_or_busy = [] + + recommendations: Recs = {} + instructions: Instructions = [] + + while self.data_needed and ( + len(self.in_flight_workers) < self.total_out_connections + or self.comm_nbytes < self.comm_threshold_bytes + ): + logger.debug( + "Ensure communicating. Pending: %d. Connections: %d/%d. Busy: %d", + len(self.data_needed), + len(self.in_flight_workers), + self.total_out_connections, + len(self.busy_workers), + ) + + ts = self.data_needed.pop() + + if self.validate: + assert ts.state == "fetch" + assert self.address not in ts.who_has + + if not ts.who_has: + recommendations[ts] = "missing" + continue + + workers = [ + w + for w in ts.who_has + if w not in self.in_flight_workers and w not in self.busy_workers + ] + if not workers: + skipped_worker_in_flight_or_busy.append(ts) + continue + + for w in ts.who_has: + self.data_needed_per_worker[w].remove(ts) + + host = get_address_host(self.address) + local = [w for w in workers if get_address_host(w) == host] + worker = random.choice(local or workers) + + to_gather_tasks, total_nbytes = self._select_keys_for_gather(worker, ts) + to_gather_keys = {ts.key for ts in to_gather_tasks} + + self.log.append( + ("gather-dependencies", worker, to_gather_keys, stimulus_id, time()) + ) + + self.comm_nbytes += total_nbytes + self.in_flight_workers[worker] = to_gather_keys + for d_ts in to_gather_tasks: + if self.validate: + assert d_ts.state == "fetch" + assert d_ts not in recommendations + recommendations[d_ts] = ("flight", worker) + + # A single invocation of _ensure_communicating may generate up to one + # GatherDep instruction per worker. Multiple tasks from the same worker may + # be clustered in the same instruction by _select_keys_for_gather. But once + # a worker has been selected for a GatherDep and added to in_flight_workers, + # it won't be selected again until the gather completes. + instructions.append( + GatherDep( + worker=worker, + to_gather=to_gather_keys, + total_nbytes=total_nbytes, + stimulus_id=stimulus_id, + ) + ) + + for ts in skipped_worker_in_flight_or_busy: + self.data_needed.add(ts) + + return recommendations, instructions + + def _ensure_computing(self) -> RecsInstrs: + if self.status != Status.running: + return {}, [] + + recs: Recs = {} + while self.constrained and len(self._executing) < self.nthreads: + key = self.constrained[0] + ts = self.tasks.get(key, None) + if ts is None or ts.state != "constrained": + self.constrained.popleft() + continue + + # There may be duplicates in the self.constrained and self.ready queues. + # This happens if a task: + # 1. is assigned to a Worker and transitioned to ready (heappush) + # 2. is stolen (no way to pop from heap, the task stays there) + # 3. is assigned to the worker again (heappush again) + if ts in recs: + continue + + if any( + self.available_resources[resource] < needed + for resource, needed in ts.resource_restrictions.items() + ): + break + + self.constrained.popleft() + for resource, needed in ts.resource_restrictions.items(): + self.available_resources[resource] -= needed + + recs[ts] = "executing" + self._executing.add(ts) + + while self.ready and len(self._executing) < self.nthreads: + _, key = heapq.heappop(self.ready) + ts = self.tasks.get(key) + if ts is None: + # It is possible for tasks to be released while still remaining on + # `ready`. The scheduler might have re-routed to a new worker and + # told this worker to release. If the task has "disappeared", just + # continue through the heap. + continue + + if key in self.data: + # See comment above about duplicates + if self.validate: + assert ts not in recs or recs[ts] == "memory" + recs[ts] = "memory" + elif ts.state in READY: + # See comment above about duplicates + if self.validate: + assert ts not in recs or recs[ts] == "executing" + recs[ts] = "executing" + self._executing.add(ts) + + return recs, [] + + def _get_task_finished_msg( + self, ts: TaskState, stimulus_id: str + ) -> TaskFinishedMsg: + if ts.key not in self.data and ts.key not in self.actors: + raise RuntimeError(f"Task {ts} not ready") + typ = ts.type + if ts.nbytes is None or typ is None: + try: + value = self.data[ts.key] + except KeyError: + value = self.actors[ts.key] + ts.nbytes = sizeof(value) + typ = ts.type = type(value) + del value + try: + typ_serialized = dumps_function(typ) + except PicklingError: + # Some types fail pickling (example: _thread.lock objects), + # send their name as a best effort. + typ_serialized = pickle.dumps(typ.__name__, protocol=4) + return TaskFinishedMsg( + key=ts.key, + nbytes=ts.nbytes, + type=typ_serialized, + typename=typename(typ), + metadata=ts.metadata, + thread=self.threads.get(ts.key), + startstops=ts.startstops, + stimulus_id=stimulus_id, + ) + + def _put_key_in_memory(self, ts: TaskState, value, *, stimulus_id: str) -> Recs: + """ + Put a key into memory and set data related task state attributes. + On success, generate recommendations for dependents. + + This method does not generate any scheduler messages since this method + cannot distinguish whether it has to be an `add-task` or a + `task-finished` signal. The caller is required to generate this message + on success. + + Raises + ------ + Exception: + In case the data is put into the in memory buffer and a serialization error + occurs during spilling, this raises that error. This has to be handled by + the caller since most callers generate scheduler messages on success (see + comment above) but we need to signal that this was not successful. + + Can only trigger if distributed.worker.memory.target is enabled, the value + is individually larger than target * memory_limit, and the task is not an + actor. + """ + if ts.key in self.data: + ts.state = "memory" + return {} + + recommendations: Recs = {} + if ts.key in self.actors: + self.actors[ts.key] = value + else: + start = time() + self.data[ts.key] = value + stop = time() + if stop - start > 0.020: + ts.startstops.append( + {"action": "disk-write", "start": start, "stop": stop} + ) + + ts.state = "memory" + if ts.nbytes is None: + ts.nbytes = sizeof(value) + + ts.type = type(value) + + for dep in ts.dependents: + dep.waiting_for_data.discard(ts) + if not dep.waiting_for_data and dep.state == "waiting": + self.waiting_for_data_count -= 1 + recommendations[dep] = "ready" + + self.log.append((ts.key, "put-in-memory", stimulus_id, time())) + return recommendations + + def _select_keys_for_gather( + self, worker: str, ts: TaskState + ) -> tuple[set[TaskState], int]: + """``_ensure_communicating`` decided to fetch a single task from a worker, + following priority. In order to minimise overhead, request fetching other tasks + from the same worker within the message, following priority for the single + worker but ignoring higher priority tasks from other workers, up to + ``target_message_size``. + """ + tss = {ts} + total_bytes = ts.get_nbytes() + tasks = self.data_needed_per_worker[worker] + + while tasks: + ts = tasks.peek() + if self.validate: + assert ts.state == "fetch" + assert worker in ts.who_has + if total_bytes + ts.get_nbytes() > self.target_message_size: + break + tasks.pop() + self.data_needed.remove(ts) + for other_worker in ts.who_has: + if other_worker != worker: + self.data_needed_per_worker[other_worker].remove(ts) + + tss.add(ts) + total_bytes += ts.get_nbytes() + + return tss, total_bytes + + ############### + # Transitions # + ############### + + def _transition_generic_fetch(self, ts: TaskState, stimulus_id: str) -> RecsInstrs: + if not ts.who_has: + return {ts: "missing"}, [] + + ts.state = "fetch" + ts.done = False + assert ts.priority + self.data_needed.add(ts) + for w in ts.who_has: + self.data_needed_per_worker[w].add(ts) + + # This is the same as `return self._ensure_communicating()`, except that when + # many tasks transition to fetch at the same time, e.g. from a single + # compute-task or acquire-replicas command from the scheduler, it allows + # clustering the transfers into less GatherDep instructions; see + # _select_keys_for_gather(). + return {}, [EnsureCommunicatingAfterTransitions(stimulus_id=stimulus_id)] + + def _transition_missing_waiting( + self, ts: TaskState, *, stimulus_id: str + ) -> RecsInstrs: + self._missing_dep_flight.discard(ts) + self._purge_state(ts) + return self._transition_released_waiting(ts, stimulus_id=stimulus_id) + + def _transition_missing_fetch( + self, ts: TaskState, *, stimulus_id: str + ) -> RecsInstrs: + if self.validate: + assert ts.state == "missing" + + if not ts.who_has: + return {}, [] + + self._missing_dep_flight.discard(ts) + return self._transition_generic_fetch(ts, stimulus_id=stimulus_id) + + def _transition_missing_released( + self, ts: TaskState, *, stimulus_id: str + ) -> RecsInstrs: + self._missing_dep_flight.discard(ts) + recs, instructions = self._transition_generic_released( + ts, stimulus_id=stimulus_id + ) + assert ts.key in self.tasks + return recs, instructions + + def _transition_flight_missing( + self, ts: TaskState, *, stimulus_id: str + ) -> RecsInstrs: + assert ts.done + return self._transition_generic_missing(ts, stimulus_id=stimulus_id) + + def _transition_generic_missing( + self, ts: TaskState, *, stimulus_id: str + ) -> RecsInstrs: + if self.validate: + assert not ts.who_has + + ts.state = "missing" + self._missing_dep_flight.add(ts) + ts.done = False + return {}, [] + + def _transition_released_fetch( + self, ts: TaskState, *, stimulus_id: str + ) -> RecsInstrs: + if self.validate: + assert ts.state == "released" + return self._transition_generic_fetch(ts, stimulus_id=stimulus_id) + + def _transition_generic_released( + self, ts: TaskState, *, stimulus_id: str + ) -> RecsInstrs: + self._purge_state(ts) + recs: Recs = {} + for dependency in ts.dependencies: + if ( + not dependency.waiters + and dependency.state not in READY | PROCESSING | {"memory"} + ): + recs[dependency] = "released" + + ts.state = "released" + if not ts.dependents: + recs[ts] = "forgotten" + + return merge_recs_instructions( + (recs, []), + self._ensure_computing(), + ) + + def _transition_released_waiting( + self, ts: TaskState, *, stimulus_id: str + ) -> RecsInstrs: + if self.validate: + assert all(d.key in self.tasks for d in ts.dependencies) + + recommendations: Recs = {} + ts.waiting_for_data.clear() + for dep_ts in ts.dependencies: + if dep_ts.state != "memory": + ts.waiting_for_data.add(dep_ts) + dep_ts.waiters.add(ts) + recommendations[dep_ts] = "fetch" + + if ts.waiting_for_data: + self.waiting_for_data_count += 1 + elif ts.resource_restrictions: + recommendations[ts] = "constrained" + else: + recommendations[ts] = "ready" + + ts.state = "waiting" + return recommendations, [] + + def _transition_fetch_flight( + self, ts: TaskState, worker: str, *, stimulus_id: str + ) -> RecsInstrs: + if self.validate: + assert ts.state == "fetch" + assert ts.who_has + # The task has already been removed by _ensure_communicating + assert ts not in self.data_needed + for w in ts.who_has: + assert ts not in self.data_needed_per_worker[w] + + ts.done = False + ts.state = "flight" + ts.coming_from = worker + self._in_flight_tasks.add(ts) + return {}, [] + + def _transition_fetch_missing( + self, ts: TaskState, *, stimulus_id: str + ) -> RecsInstrs: + # _ensure_communicating could have just popped this task out of data_needed + self.data_needed.discard(ts) + return self._transition_generic_missing(ts, stimulus_id=stimulus_id) + + def _transition_memory_released( + self, ts: TaskState, *, stimulus_id: str + ) -> RecsInstrs: + recs, instructions = self._transition_generic_released( + ts, stimulus_id=stimulus_id + ) + instructions.append(ReleaseWorkerDataMsg(key=ts.key, stimulus_id=stimulus_id)) + return recs, instructions + + def _transition_waiting_constrained( + self, ts: TaskState, *, stimulus_id: str + ) -> RecsInstrs: + if self.validate: + assert ts.state == "waiting" + assert not ts.waiting_for_data + assert all( + dep.key in self.data or dep.key in self.actors + for dep in ts.dependencies + ) + assert all(dep.state == "memory" for dep in ts.dependencies) + assert ts.key not in self.ready + ts.state = "constrained" + self.constrained.append(ts.key) + return self._ensure_computing() + + def _transition_long_running_rescheduled( + self, ts: TaskState, *, stimulus_id: str + ) -> RecsInstrs: + recs: Recs = {ts: "released"} + smsg = RescheduleMsg(key=ts.key, stimulus_id=stimulus_id) + return recs, [smsg] + + def _transition_executing_rescheduled( + self, ts: TaskState, *, stimulus_id: str + ) -> RecsInstrs: + for resource, quantity in ts.resource_restrictions.items(): + self.available_resources[resource] += quantity + self._executing.discard(ts) + + return merge_recs_instructions( + ( + {ts: "released"}, + [RescheduleMsg(key=ts.key, stimulus_id=stimulus_id)], + ), + self._ensure_computing(), + ) + + def _transition_waiting_ready( + self, ts: TaskState, *, stimulus_id: str + ) -> RecsInstrs: + if self.validate: + assert ts.state == "waiting" + assert ts.key not in self.ready + assert not ts.waiting_for_data + for dep in ts.dependencies: + assert dep.key in self.data or dep.key in self.actors + assert dep.state == "memory" + + ts.state = "ready" + assert ts.priority is not None + heapq.heappush(self.ready, (ts.priority, ts.key)) + + return self._ensure_computing() + + def _transition_cancelled_error( + self, + ts: TaskState, + exception: Serialize, + traceback: Serialize | None, + exception_text: str, + traceback_text: str, + *, + stimulus_id: str, + ) -> RecsInstrs: + assert ts._previous == "executing" or ts.key in self.long_running + recs, instructions = self._transition_executing_error( + ts, + exception, + traceback, + exception_text, + traceback_text, + stimulus_id=stimulus_id, + ) + # We'll ignore instructions, i.e. we choose to not submit the failure + # message to the scheduler since from the schedulers POV it already + # released this task + if self.validate: + assert len(instructions) == 1 + assert isinstance(instructions[0], TaskErredMsg) + assert instructions[0].key == ts.key + instructions.clear() + # Workers should never "retry" tasks. A transition to error should, by + # default, be the end. Since cancelled indicates that the scheduler lost + # interest, we can transition straight to released + assert ts not in recs + recs[ts] = "released" + return recs, instructions + + def _transition_generic_error( + self, + ts: TaskState, + exception: Serialize, + traceback: Serialize | None, + exception_text: str, + traceback_text: str, + *, + stimulus_id: str, + ) -> RecsInstrs: + ts.exception = exception + ts.traceback = traceback + ts.exception_text = exception_text + ts.traceback_text = traceback_text + ts.state = "error" + smsg = TaskErredMsg.from_task( + ts, + stimulus_id=stimulus_id, + thread=self.threads.get(ts.key), + ) + + return {}, [smsg] + + def _transition_executing_error( + self, + ts: TaskState, + exception: Serialize, + traceback: Serialize | None, + exception_text: str, + traceback_text: str, + *, + stimulus_id: str, + ) -> RecsInstrs: + for resource, quantity in ts.resource_restrictions.items(): + self.available_resources[resource] += quantity + self._executing.discard(ts) + + return merge_recs_instructions( + self._transition_generic_error( + ts, + exception, + traceback, + exception_text, + traceback_text, + stimulus_id=stimulus_id, + ), + self._ensure_computing(), + ) + + def _transition_from_resumed( + self, ts: TaskState, finish: TaskStateState, *args, stimulus_id: str + ) -> RecsInstrs: + """`resumed` is an intermediate degenerate state which splits further up + into two states depending on what the last signal / next state is + intended to be. There are only two viable choices depending on whether + the task is required to be fetched from another worker `resumed(fetch)` + or the task shall be computed on this worker `resumed(waiting)`. + + The only viable state transitions ending up here are + + flight -> cancelled -> resumed(waiting) + + or + + executing -> cancelled -> resumed(fetch) + + depending on the origin. Equally, only `fetch`, `waiting`, or `released` + are allowed output states. + + See also `_transition_resumed_waiting` + """ + recs: Recs = {} + instructions: Instructions = [] + + if ts._previous == finish: + # We're back where we started. We should forget about the entire + # cancellation attempt + ts.state = finish + ts._next = None + ts._previous = None + elif not ts.done: + # If we're not done, yet, just remember where we want to be next + ts._next = finish + else: + # Flight/executing finished unsuccessfully, i.e. not in memory + assert finish != "memory" + next_state = ts._next + assert next_state in {"waiting", "fetch"}, next_state + assert ts._previous in {"executing", "flight"}, ts._previous + + if next_state != finish: + recs, instructions = self._transition_generic_released( + ts, stimulus_id=stimulus_id + ) + recs[ts] = next_state + + return recs, instructions + + def _transition_resumed_fetch( + self, ts: TaskState, *, stimulus_id: str + ) -> RecsInstrs: + """See Worker._transition_from_resumed""" + recs, instructions = self._transition_from_resumed( + ts, "fetch", stimulus_id=stimulus_id + ) + if self.validate: + # This would only be possible in a fetch->cancelled->resumed->fetch loop, + # but there are no transitions from fetch which set the state to cancelled. + # If this assertion failed, we' need to call _ensure_communicating like in + # the other transitions that set ts.status = "fetch". + assert ts.state != "fetch" + return recs, instructions + + def _transition_resumed_missing( + self, ts: TaskState, *, stimulus_id: str + ) -> RecsInstrs: + """See Worker._transition_from_resumed""" + return self._transition_from_resumed(ts, "missing", stimulus_id=stimulus_id) + + def _transition_resumed_released( + self, ts: TaskState, *, stimulus_id: str + ) -> RecsInstrs: + if not ts.done: + ts.state = "cancelled" + ts._next = None + return {}, [] + else: + return self._transition_generic_released(ts, stimulus_id=stimulus_id) + + def _transition_resumed_waiting( + self, ts: TaskState, *, stimulus_id: str + ) -> RecsInstrs: + """See Worker._transition_from_resumed""" + return self._transition_from_resumed(ts, "waiting", stimulus_id=stimulus_id) + + def _transition_cancelled_fetch( + self, ts: TaskState, *, stimulus_id: str + ) -> RecsInstrs: + if ts.done: + return {ts: "released"}, [] + elif ts._previous == "flight": + ts.state = ts._previous + return {}, [] + else: + assert ts._previous == "executing" + ts.state = "resumed" + ts._next = "fetch" + return {}, [] + + def _transition_cancelled_waiting( + self, ts: TaskState, *, stimulus_id: str + ) -> RecsInstrs: + if ts.done: + return {ts: "released"}, [] + elif ts._previous == "executing": + ts.state = ts._previous + return {}, [] + else: + assert ts._previous == "flight" + ts.state = "resumed" + ts._next = "waiting" + return {}, [] + + def _transition_cancelled_forgotten( + self, ts: TaskState, *, stimulus_id: str + ) -> RecsInstrs: + ts._next = "forgotten" + if not ts.done: + return {}, [] + return {ts: "released"}, [] + + def _transition_cancelled_released( + self, ts: TaskState, *, stimulus_id: str + ) -> RecsInstrs: + if not ts.done: + return {}, [] + self._executing.discard(ts) + self._in_flight_tasks.discard(ts) + + for resource, quantity in ts.resource_restrictions.items(): + self.available_resources[resource] += quantity + + return self._transition_generic_released(ts, stimulus_id=stimulus_id) + + def _transition_executing_released( + self, ts: TaskState, *, stimulus_id: str + ) -> RecsInstrs: + ts._previous = ts.state + ts._next = None + # See https://github.com/dask/distributed/pull/5046#discussion_r685093940 + ts.state = "cancelled" + ts.done = False + return self._ensure_computing() + + def _transition_long_running_memory( + self, ts: TaskState, value=NO_VALUE, *, stimulus_id: str + ) -> RecsInstrs: + self.executed_count += 1 + return self._transition_generic_memory(ts, value=value, stimulus_id=stimulus_id) + + def _transition_generic_memory( + self, ts: TaskState, value=NO_VALUE, *, stimulus_id: str + ) -> RecsInstrs: + if value is NO_VALUE and ts.key not in self.data: + raise RuntimeError( + f"Tried to transition task {ts} to `memory` without data available" + ) + + if ts.resource_restrictions is not None: + for resource, quantity in ts.resource_restrictions.items(): + self.available_resources[resource] += quantity + + self._executing.discard(ts) + self._in_flight_tasks.discard(ts) + ts.coming_from = None + + instructions: Instructions = [] + try: + recs = self._put_key_in_memory(ts, value, stimulus_id=stimulus_id) + except Exception as e: + msg = error_message(e) + recs = {ts: tuple(msg.values())} + else: + if self.validate: + assert ts.key in self.data or ts.key in self.actors + instructions.append( + self._get_task_finished_msg(ts, stimulus_id=stimulus_id) + ) + + return recs, instructions + + def _transition_executing_memory( + self, ts: TaskState, value=NO_VALUE, *, stimulus_id: str + ) -> RecsInstrs: + if self.validate: + assert ts.state == "executing" or ts.key in self.long_running + assert not ts.waiting_for_data + assert ts.key not in self.ready + + self._executing.discard(ts) + self.executed_count += 1 + return merge_recs_instructions( + self._transition_generic_memory(ts, value=value, stimulus_id=stimulus_id), + self._ensure_computing(), + ) + + def _transition_constrained_executing( + self, ts: TaskState, *, stimulus_id: str + ) -> RecsInstrs: + if self.validate: + assert not ts.waiting_for_data + assert ts.key not in self.data + assert ts.state in READY + assert ts.key not in self.ready + for dep in ts.dependencies: + assert dep.key in self.data or dep.key in self.actors + + ts.state = "executing" + instr = Execute(key=ts.key, stimulus_id=stimulus_id) + return {}, [instr] + + def _transition_ready_executing( + self, ts: TaskState, *, stimulus_id: str + ) -> RecsInstrs: + if self.validate: + assert not ts.waiting_for_data + assert ts.key not in self.data + assert ts.state in READY + assert ts.key not in self.ready + assert all( + dep.key in self.data or dep.key in self.actors + for dep in ts.dependencies + ) + + ts.state = "executing" + instr = Execute(key=ts.key, stimulus_id=stimulus_id) + return {}, [instr] + + def _transition_flight_fetch( + self, ts: TaskState, *, stimulus_id: str + ) -> RecsInstrs: + # If this transition is called after the flight coroutine has finished, + # we can reset the task and transition to fetch again. If it is not yet + # finished, this should be a no-op + if not ts.done: + return {}, [] + + ts.coming_from = None + return self._transition_generic_fetch(ts, stimulus_id=stimulus_id) + + def _transition_flight_error( + self, + ts: TaskState, + exception: Serialize, + traceback: Serialize | None, + exception_text: str, + traceback_text: str, + *, + stimulus_id: str, + ) -> RecsInstrs: + self._in_flight_tasks.discard(ts) + ts.coming_from = None + return self._transition_generic_error( + ts, + exception, + traceback, + exception_text, + traceback_text, + stimulus_id=stimulus_id, + ) + + def _transition_flight_released( + self, ts: TaskState, *, stimulus_id: str + ) -> RecsInstrs: + if ts.done: + # FIXME: Is this even possible? Would an assert instead be more + # sensible? + return self._transition_generic_released(ts, stimulus_id=stimulus_id) + else: + ts._previous = "flight" + ts._next = None + # See https://github.com/dask/distributed/pull/5046#discussion_r685093940 + ts.state = "cancelled" + return {}, [] + + def _transition_cancelled_memory(self, ts, value, *, stimulus_id): + # We only need this because the to-memory signatures require a value but + # we do not want to store a cancelled result and want to release + # immediately + assert ts.done + + return self._transition_cancelled_released(ts, stimulus_id=stimulus_id) + + def _transition_executing_long_running( + self, ts: TaskState, compute_duration: float, *, stimulus_id: str + ) -> RecsInstrs: + ts.state = "long-running" + self._executing.discard(ts) + self.long_running.add(ts.key) + + smsg = LongRunningMsg( + key=ts.key, compute_duration=compute_duration, stimulus_id=stimulus_id + ) + return merge_recs_instructions( + ({}, [smsg]), + self._ensure_computing(), + ) + + def _transition_released_memory( + self, ts: TaskState, value, *, stimulus_id: str + ) -> RecsInstrs: + try: + recs = self._put_key_in_memory(ts, value, stimulus_id=stimulus_id) + except Exception as e: + msg = error_message(e) + recs = {ts: tuple(msg.values())} + return recs, [] + smsg = AddKeysMsg(keys=[ts.key], stimulus_id=stimulus_id) + return recs, [smsg] + + def _transition_flight_memory( + self, ts: TaskState, value, *, stimulus_id: str + ) -> RecsInstrs: + self._in_flight_tasks.discard(ts) + ts.coming_from = None + try: + recs = self._put_key_in_memory(ts, value, stimulus_id=stimulus_id) + except Exception as e: + msg = error_message(e) + recs = {ts: tuple(msg.values())} + return recs, [] + smsg = AddKeysMsg(keys=[ts.key], stimulus_id=stimulus_id) + return recs, [smsg] + + def _transition_released_forgotten( + self, ts: TaskState, *, stimulus_id: str + ) -> RecsInstrs: + recommendations: Recs = {} + # Dependents _should_ be released by the scheduler before this + if self.validate: + assert not any(d.state != "forgotten" for d in ts.dependents) + for dep in ts.dependencies: + dep.dependents.discard(ts) + if dep.state == "released" and not dep.dependents: + recommendations[dep] = "forgotten" + self._purge_state(ts) + # Mark state as forgotten in case it is still referenced + ts.state = "forgotten" + self.tasks.pop(ts.key, None) + return recommendations, [] + + # { + # (start, finish): + # transition__( + # self, ts: TaskState, *args, stimulus_id: str + # ) -> (recommendations, instructions) + # } + _TRANSITIONS_TABLE: ClassVar[ + Mapping[tuple[TaskStateState, TaskStateState], Callable[..., RecsInstrs]] + ] = { + ("cancelled", "fetch"): _transition_cancelled_fetch, + ("cancelled", "released"): _transition_cancelled_released, + ("cancelled", "missing"): _transition_cancelled_released, + ("cancelled", "waiting"): _transition_cancelled_waiting, + ("cancelled", "forgotten"): _transition_cancelled_forgotten, + ("cancelled", "memory"): _transition_cancelled_memory, + ("cancelled", "error"): _transition_cancelled_error, + ("resumed", "memory"): _transition_generic_memory, + ("resumed", "error"): _transition_generic_error, + ("resumed", "released"): _transition_resumed_released, + ("resumed", "waiting"): _transition_resumed_waiting, + ("resumed", "fetch"): _transition_resumed_fetch, + ("resumed", "missing"): _transition_resumed_missing, + ("constrained", "executing"): _transition_constrained_executing, + ("constrained", "released"): _transition_generic_released, + ("error", "released"): _transition_generic_released, + ("executing", "error"): _transition_executing_error, + ("executing", "long-running"): _transition_executing_long_running, + ("executing", "memory"): _transition_executing_memory, + ("executing", "released"): _transition_executing_released, + ("executing", "rescheduled"): _transition_executing_rescheduled, + ("fetch", "flight"): _transition_fetch_flight, + ("fetch", "missing"): _transition_fetch_missing, + ("fetch", "released"): _transition_generic_released, + ("flight", "error"): _transition_flight_error, + ("flight", "fetch"): _transition_flight_fetch, + ("flight", "memory"): _transition_flight_memory, + ("flight", "missing"): _transition_flight_missing, + ("flight", "released"): _transition_flight_released, + ("long-running", "error"): _transition_generic_error, + ("long-running", "memory"): _transition_long_running_memory, + ("long-running", "rescheduled"): _transition_executing_rescheduled, + ("long-running", "released"): _transition_executing_released, + ("memory", "released"): _transition_memory_released, + ("missing", "fetch"): _transition_missing_fetch, + ("missing", "released"): _transition_missing_released, + ("missing", "error"): _transition_generic_error, + ("missing", "waiting"): _transition_missing_waiting, + ("ready", "error"): _transition_generic_error, + ("ready", "executing"): _transition_ready_executing, + ("ready", "released"): _transition_generic_released, + ("released", "error"): _transition_generic_error, + ("released", "fetch"): _transition_released_fetch, + ("released", "missing"): _transition_generic_missing, + ("released", "forgotten"): _transition_released_forgotten, + ("released", "memory"): _transition_released_memory, + ("released", "waiting"): _transition_released_waiting, + ("waiting", "constrained"): _transition_waiting_constrained, + ("waiting", "ready"): _transition_waiting_ready, + ("waiting", "released"): _transition_generic_released, + } + + def _notify_plugins(self, method_name, *args, **kwargs): + for name, plugin in self.plugins.items(): + if hasattr(plugin, method_name): + try: + getattr(plugin, method_name)(*args, **kwargs) + except Exception: + logger.info( + "Plugin '%s' failed with exception", name, exc_info=True + ) + + def _transition( + self, + ts: TaskState, + finish: TaskStateState | tuple, + *args, + stimulus_id: str, + **kwargs, + ) -> RecsInstrs: + """Transition a key from its current state to the finish state + + See Also + -------- + Worker.transitions: wrapper around this method + """ + if isinstance(finish, tuple): + # the concatenated transition path might need to access the tuple + assert not args + args = finish[1:] + finish = cast(TaskStateState, finish[0]) + + if ts.state == finish: + return {}, [] + + start = ts.state + func = self._TRANSITIONS_TABLE.get((start, finish)) + + # Notes: + # - in case of transition through released, this counter is incremented by 2 + # - this increase happens before the actual transitions, so that it can + # catch potential infinite recursions + self.transition_counter += 1 + if ( + self.transition_counter_max + and self.transition_counter >= self.transition_counter_max + ): + raise TransitionCounterMaxExceeded(ts.key, start, finish, self.story(ts)) + + if func is not None: + recs, instructions = func( + self, ts, *args, stimulus_id=stimulus_id, **kwargs + ) + self._notify_plugins("transition", ts.key, start, finish, **kwargs) + + elif "released" not in (start, finish): + # start -> "released" -> finish + try: + recs, instructions = self._transition( + ts, "released", stimulus_id=stimulus_id + ) + v_state: TaskStateState + v_args: list | tuple + while v := recs.pop(ts, None): + if isinstance(v, tuple): + v_state, *v_args = v + else: + v_state, v_args = v, () + if v_state == "forgotten": + # We do not want to forget. The purpose of this + # transition path is to get to `finish` + continue + recs, instructions = merge_recs_instructions( + (recs, instructions), + self._transition(ts, v_state, *v_args, stimulus_id=stimulus_id), + ) + recs, instructions = merge_recs_instructions( + (recs, instructions), + self._transition(ts, finish, *args, stimulus_id=stimulus_id), + ) + except (InvalidTransition, RecommendationsConflict) as e: + raise InvalidTransition(ts.key, start, finish, self.story(ts)) from e + + else: + raise InvalidTransition(ts.key, start, finish, self.story(ts)) + + self.log.append( + ( + # key + ts.key, + # initial + start, + # recommended + finish, + # final + ts.state, + # new recommendations + {ts.key: new for ts, new in recs.items()}, + stimulus_id, + time(), + ) + ) + return recs, instructions + + def _transitions(self, recommendations: Recs, *, stimulus_id: str) -> None: + """Process transitions until none are left + + This includes feedback from previous transitions and continues until we + reach a steady state + """ + instructions = [] + + remaining_recs = recommendations.copy() + tasks = set() + while remaining_recs: + ts, finish = remaining_recs.popitem() + tasks.add(ts) + a_recs, a_instructions = self._transition( + ts, finish, stimulus_id=stimulus_id + ) + + remaining_recs.update(a_recs) + instructions += a_instructions + + if self.validate: + # Full state validation is very expensive + for ts in tasks: + self.validate_task(ts) + + self._handle_instructions(instructions) + + ########## + # Events # + ########## + + @functools.singledispatchmethod + def _handle_event(self, ev: StateMachineEvent) -> RecsInstrs: + raise TypeError(ev) # pragma: nocover + + @_handle_event.register + def _handle_update_data(self, ev: UpdateDataEvent) -> RecsInstrs: + recommendations: Recs = {} + instructions: Instructions = [] + for key, value in ev.data.items(): + try: + ts = self.tasks[key] + recommendations[ts] = ("memory", value) + except KeyError: + self.tasks[key] = ts = TaskState(key) + + try: + recs = self._put_key_in_memory( + ts, value, stimulus_id=ev.stimulus_id + ) + except Exception as e: + msg = error_message(e) + recommendations = {ts: tuple(msg.values())} + else: + recommendations.update(recs) + + self.log.append((key, "receive-from-scatter", ev.stimulus_id, time())) + + if ev.report: + instructions.append( + AddKeysMsg(keys=list(ev.data), stimulus_id=ev.stimulus_id) + ) + + return recommendations, instructions + + @_handle_event.register + def _handle_free_keys(self, ev: FreeKeysEvent) -> RecsInstrs: + """Handler to be called by the scheduler. + + The given keys are no longer referred to and required by the scheduler. + The worker is now allowed to release the key, if applicable. + + This does not guarantee that the memory is released since the worker may + still decide to hold on to the data and task since it is required by an + upstream dependency. + """ + self.log.append(("free-keys", ev.keys, ev.stimulus_id, time())) + recommendations: Recs = {} + for key in ev.keys: + ts = self.tasks.get(key) + if ts: + recommendations[ts] = "released" + return recommendations, [] + + @_handle_event.register + def _handle_remove_replicas(self, ev: RemoveReplicasEvent) -> RecsInstrs: + """Stream handler notifying the worker that it might be holding unreferenced, + superfluous data. + + This should not actually happen during ordinary operations and is only intended + to correct any erroneous state. An example where this is necessary is if a + worker fetches data for a downstream task but that task is released before the + data arrives. In this case, the scheduler will notify the worker that it may be + holding this unnecessary data, if the worker hasn't released the data itself, + already. + + This handler does not guarantee the task nor the data to be actually + released but only asks the worker to release the data on a best effort + guarantee. This protects from race conditions where the given keys may + already have been rescheduled for compute in which case the compute + would win and this handler is ignored. + + For stronger guarantees, see handler free_keys + """ + recommendations: Recs = {} + instructions: Instructions = [] + + rejected = [] + for key in ev.keys: + ts = self.tasks.get(key) + if ts is None or ts.state != "memory": + continue + if not ts.is_protected(): + self.log.append( + (ts.key, "remove-replica-confirmed", ev.stimulus_id, time()) + ) + recommendations[ts] = "released" + else: + rejected.append(key) + + if rejected: + self.log.append( + ("remove-replica-rejected", rejected, ev.stimulus_id, time()) + ) + instructions.append(AddKeysMsg(keys=rejected, stimulus_id=ev.stimulus_id)) + + return recommendations, instructions + + @_handle_event.register + def _handle_acquire_replicas(self, ev: AcquireReplicasEvent) -> RecsInstrs: + if self.validate: + assert all(ev.who_has.values()) + + recommendations: Recs = {} + for key in ev.who_has: + ts = self._ensure_task_exists( + key=key, + # Transfer this data after all dependency tasks of computations with + # default or explicitly high (>0) user priority and before all + # computations with low priority (<0). Note that the priority= parameter + # of compute() is multiplied by -1 before it reaches TaskState.priority. + priority=(1,), + stimulus_id=ev.stimulus_id, + ) + if ts.state != "memory": + recommendations[ts] = "fetch" + + self._update_who_has(ev.who_has) + return recommendations, [] + + @_handle_event.register + def _handle_compute_task(self, ev: ComputeTaskEvent) -> RecsInstrs: + try: + ts = self.tasks[ev.key] + logger.debug( + "Asked to compute an already known task %s", + {"task": ts, "stimulus_id": ev.stimulus_id}, + ) + except KeyError: + self.tasks[ev.key] = ts = TaskState(ev.key) + self.log.append((ev.key, "compute-task", ts.state, ev.stimulus_id, time())) + + recommendations: Recs = {} + instructions: Instructions = [] + + if ts.state in READY | { + "executing", + "long-running", + "waiting", + }: + pass + elif ts.state == "memory": + instructions.append( + self._get_task_finished_msg(ts, stimulus_id=ev.stimulus_id) + ) + elif ts.state == "error": + instructions.append(TaskErredMsg.from_task(ts, stimulus_id=ev.stimulus_id)) + elif ts.state in { + "released", + "fetch", + "flight", + "missing", + "cancelled", + "resumed", + }: + recommendations[ts] = "waiting" + + ts.run_spec = ev.run_spec + + priority = ev.priority + (self.generation,) + self.generation -= 1 + + if ev.actor: + self.actors[ts.key] = None + + ts.exception = None + ts.traceback = None + ts.exception_text = "" + ts.traceback_text = "" + ts.priority = priority + ts.duration = ev.duration + ts.resource_restrictions = ev.resource_restrictions + ts.annotations = ev.annotations + + if self.validate: + assert ev.who_has.keys() == ev.nbytes.keys() + assert all(ev.who_has.values()) + + for dep_key, dep_workers in ev.who_has.items(): + dep_ts = self._ensure_task_exists( + key=dep_key, + priority=priority, + stimulus_id=ev.stimulus_id, + ) + # link up to child / parents + ts.dependencies.add(dep_ts) + dep_ts.dependents.add(ts) + + for dep_key, value in ev.nbytes.items(): + self.tasks[dep_key].nbytes = value + + self._update_who_has(ev.who_has) + else: + raise RuntimeError( # pragma: nocover + f"Unexpected task state encountered for {ts}; " + f"stimulus_id={ev.stimulus_id}; story={self.story(ts)}" + ) + + return recommendations, instructions + + def _gather_dep_done_common(self, ev: GatherDepDoneEvent) -> Iterator[TaskState]: + """Common code for the handlers of all subclasses of GatherDepDoneEvent. + + Yields the tasks that need to transition out of flight. + """ + self.comm_nbytes -= ev.total_nbytes + keys = self.in_flight_workers.pop(ev.worker) + for key in keys: + ts = self.tasks[key] + ts.done = True + yield ts + + @_handle_event.register + def _handle_gather_dep_success(self, ev: GatherDepSuccessEvent) -> RecsInstrs: + """gather_dep terminated successfully. + The response may contain less keys than the request. + """ + recommendations: Recs = {} + for ts in self._gather_dep_done_common(ev): + if ts.key in ev.data: + recommendations[ts] = ("memory", ev.data[ts.key]) + else: + self.log.append((ts.key, "missing-dep", ev.stimulus_id, time())) + if self.validate: + assert ts.state != "fetch" + assert ts not in self.data_needed_per_worker[ev.worker] + ts.who_has.discard(ev.worker) + self.has_what[ev.worker].discard(ts.key) + recommendations[ts] = "fetch" + + return merge_recs_instructions( + (recommendations, []), + self._ensure_communicating(stimulus_id=ev.stimulus_id), + ) + + @_handle_event.register + def _handle_gather_dep_busy(self, ev: GatherDepBusyEvent) -> RecsInstrs: + """gather_dep terminated: remote worker is busy""" + # Avoid hammering the worker. If there are multiple replicas + # available, immediately try fetching from a different worker. + self.busy_workers.add(ev.worker) + + recommendations: Recs = {} + refresh_who_has = [] + for ts in self._gather_dep_done_common(ev): + recommendations[ts] = "fetch" + if not ts.who_has - self.busy_workers: + refresh_who_has.append(ts.key) + + instructions: Instructions = [ + RetryBusyWorkerLater(worker=ev.worker, stimulus_id=ev.stimulus_id), + ] + + if refresh_who_has: + # All workers that hold known replicas of our tasks are busy. + # Try querying the scheduler for unknown ones. + instructions.append( + RequestRefreshWhoHasMsg( + keys=refresh_who_has, stimulus_id=ev.stimulus_id + ) + ) + + return merge_recs_instructions( + (recommendations, instructions), + self._ensure_communicating(stimulus_id=ev.stimulus_id), + ) + + @_handle_event.register + def _handle_gather_dep_network_failure( + self, ev: GatherDepNetworkFailureEvent + ) -> RecsInstrs: + """gather_dep terminated: network failure while trying to + communicate with remote worker + + Though the network failure could be transient, we assume it is not, and + preemptively act as though the other worker has died (including removing all + keys from it, even ones we did not fetch). + + This optimization leads to faster completion of the fetch, since we immediately + either retry a different worker, or ask the scheduler to inform us of a new + worker if no other worker is available. + """ + self.data_needed_per_worker.pop(ev.worker) + for key in self.has_what.pop(ev.worker): + ts = self.tasks[key] + ts.who_has.discard(ev.worker) + + recommendations: Recs = {} + for ts in self._gather_dep_done_common(ev): + self.log.append((ts.key, "missing-dep", ev.stimulus_id, time())) + recommendations[ts] = "fetch" + + return merge_recs_instructions( + (recommendations, []), + self._ensure_communicating(stimulus_id=ev.stimulus_id), + ) + + @_handle_event.register + def _handle_gather_dep_failure(self, ev: GatherDepFailureEvent) -> RecsInstrs: + """gather_dep terminated: generic error raised (not a network failure); + e.g. data failed to deserialize. + """ + recommendations: Recs = { + ts: ( + "error", + ev.exception, + ev.traceback, + ev.exception_text, + ev.traceback_text, + ) + for ts in self._gather_dep_done_common(ev) + } + + return merge_recs_instructions( + (recommendations, []), + self._ensure_communicating(stimulus_id=ev.stimulus_id), + ) + + @_handle_event.register + def _handle_secede(self, ev: SecedeEvent) -> RecsInstrs: + ts = self.tasks.get(ev.key) + if ts and ts.state == "executing": + return {ts: ("long-running", ev.compute_duration)}, [] + else: + return {}, [] + + @_handle_event.register + def _handle_steal_request(self, ev: StealRequestEvent) -> RecsInstrs: + # There may be a race condition between stealing and releasing a task. + # In this case the self.tasks is already cleared. The `None` will be + # registered as `already-computing` on the other end + ts = self.tasks.get(ev.key) + state = ts.state if ts is not None else None + smsg = StealResponseMsg(key=ev.key, state=state, stimulus_id=ev.stimulus_id) + + if state in READY | {"waiting"}: + # If task is marked as "constrained" we haven't yet assigned it an + # `available_resources` to run on, that happens in + # `_transition_constrained_executing` + assert ts + return {ts: "released"}, [smsg] + else: + return {}, [smsg] + + @_handle_event.register + def _handle_unpause(self, ev: UnpauseEvent) -> RecsInstrs: + """Emerge from paused status. Do not send this event directly. Instead, just set + Worker.status back to running. + """ + assert self.status == Status.running + return merge_recs_instructions( + self._ensure_computing(), + self._ensure_communicating(stimulus_id=ev.stimulus_id), + ) + + @_handle_event.register + def _handle_retry_busy_worker(self, ev: RetryBusyWorkerEvent) -> RecsInstrs: + self.busy_workers.discard(ev.worker) + return self._ensure_communicating(stimulus_id=ev.stimulus_id) + + @_handle_event.register + def _handle_cancel_compute(self, ev: CancelComputeEvent) -> RecsInstrs: + """Cancel a task on a best-effort basis. This is only possible while a task + is in state `waiting` or `ready`; nothing will happen otherwise. + """ + ts = self.tasks.get(ev.key) + if not ts or ts.state not in READY | {"waiting"}: + return {}, [] + + self.log.append((ev.key, "cancel-compute", ev.stimulus_id, time())) + # All possible dependents of ts should not be in state Processing on + # scheduler side and therefore should not be assigned to a worker, yet. + assert not ts.dependents + return {ts: "released"}, [] + + @_handle_event.register + def _handle_already_cancelled(self, ev: AlreadyCancelledEvent) -> RecsInstrs: + """Task is already cancelled by the time execute() runs""" + # key *must* be still in tasks. Releasing it directly is forbidden + # without going through cancelled + ts = self.tasks.get(ev.key) + assert ts, self.story(ev.key) + ts.done = True + return {ts: "released"}, [] + + @_handle_event.register + def _handle_execute_success(self, ev: ExecuteSuccessEvent) -> RecsInstrs: + """Task completed successfully""" + # key *must* be still in tasks. Releasing it directly is forbidden + # without going through cancelled + ts = self.tasks.get(ev.key) + assert ts, self.story(ev.key) + + ts.done = True + ts.startstops.append({"action": "compute", "start": ev.start, "stop": ev.stop}) + ts.nbytes = ev.nbytes + ts.type = ev.type + return {ts: ("memory", ev.value)}, [] + + @_handle_event.register + def _handle_execute_failure(self, ev: ExecuteFailureEvent) -> RecsInstrs: + """Task execution failed""" + # key *must* be still in tasks. Releasing it directly is forbidden + # without going through cancelled + ts = self.tasks.get(ev.key) + assert ts, self.story(ev.key) + + ts.done = True + if ev.start is not None and ev.stop is not None: + ts.startstops.append( + {"action": "compute", "start": ev.start, "stop": ev.stop} + ) + + return { + ts: ( + "error", + ev.exception, + ev.traceback, + ev.exception_text, + ev.traceback_text, + ) + }, [] + + @_handle_event.register + def _handle_reschedule(self, ev: RescheduleEvent) -> RecsInstrs: + """Task raised Reschedule exception while it was running""" + # key *must* be still in tasks. Releasing it directly is forbidden + # without going through cancelled + ts = self.tasks.get(ev.key) + assert ts, self.story(ev.key) + return {ts: "rescheduled"}, [] + + @_handle_event.register + def _handle_find_missing(self, ev: FindMissingEvent) -> RecsInstrs: + if not self._missing_dep_flight: + return {}, [] + + if self.validate: + for ts in self._missing_dep_flight: + assert not ts.who_has, self.story(ts) + + smsg = RequestRefreshWhoHasMsg( + keys=[ts.key for ts in self._missing_dep_flight], + stimulus_id=ev.stimulus_id, + ) + return {}, [smsg] + + @_handle_event.register + def _handle_refresh_who_has(self, ev: RefreshWhoHasEvent) -> RecsInstrs: + self._update_who_has(ev.who_has) + recommendations: Recs = {} + instructions: Instructions = [] + + for key in ev.who_has: + ts = self.tasks.get(key) + if not ts: + continue + + if ts.who_has and ts.state == "missing": + recommendations[ts] = "fetch" + elif ts.who_has and ts.state == "fetch": + # We potentially just acquired new replicas whereas all previously known + # workers are in flight or busy. We're deliberately not testing the + # minute use cases here for the sake of simplicity; instead we rely on + # _ensure_communicating to be a no-op when there's nothing to do. + recommendations, instructions = merge_recs_instructions( + (recommendations, instructions), + self._ensure_communicating(stimulus_id=ev.stimulus_id), + ) + elif not ts.who_has and ts.state == "fetch": + recommendations[ts] = "missing" + + return recommendations, instructions + + ############### + # Diagnostics # + ############### + + def story(self, *keys_or_tasks: str | TaskState) -> list[tuple]: + """Return all transitions involving one or more tasks""" + keys = {e.key if isinstance(e, TaskState) else e for e in keys_or_tasks} + return worker_story(keys, self.log) + + def stimulus_story( + self, *keys_or_tasks: str | TaskState + ) -> list[StateMachineEvent]: + """Return all state machine events involving one or more tasks""" + keys = {e.key if isinstance(e, TaskState) else e for e in keys_or_tasks} + return [ev for ev in self.stimulus_log if getattr(ev, "key", None) in keys] + + ############## + # Validation # + ############## + + def _validate_task_memory(self, ts): + assert ts.key in self.data or ts.key in self.actors + assert isinstance(ts.nbytes, int) + assert not ts.waiting_for_data + assert ts.key not in self.ready + assert ts.state == "memory" + + def _validate_task_executing(self, ts): + assert ts.state == "executing" + assert ts.run_spec is not None + assert ts.key not in self.data + assert not ts.waiting_for_data + for dep in ts.dependencies: + assert dep.state == "memory", self.story(dep) + assert dep.key in self.data or dep.key in self.actors + + def _validate_task_ready(self, ts): + assert ts.key in pluck(1, self.ready) + assert ts.key not in self.data + assert ts.state != "executing" + assert not ts.done + assert not ts.waiting_for_data + assert all( + dep.key in self.data or dep.key in self.actors for dep in ts.dependencies + ) + + def _validate_task_waiting(self, ts): + assert ts.key not in self.data + assert ts.state == "waiting" + assert not ts.done + if ts.dependencies and ts.run_spec: + assert not all(dep.key in self.data for dep in ts.dependencies) + + def _validate_task_flight(self, ts): + assert ts.key not in self.data + assert ts in self._in_flight_tasks + assert not any(dep.key in self.ready for dep in ts.dependents) + assert ts.coming_from + assert ts.coming_from in self.in_flight_workers + assert ts.key in self.in_flight_workers[ts.coming_from] + + def _validate_task_fetch(self, ts): + assert ts.key not in self.data + assert self.address not in ts.who_has + assert not ts.done + assert ts in self.data_needed + # Note: ts.who_has may be empty; see GatherDepNetworkFailureEvent + for w in ts.who_has: + assert ts.key in self.has_what[w] + assert ts in self.data_needed_per_worker[w] + + def _validate_task_missing(self, ts): + assert ts.key not in self.data + assert not ts.who_has + assert not ts.done + assert not any(ts.key in has_what for has_what in self.has_what.values()) + assert ts in self._missing_dep_flight + + def _validate_task_cancelled(self, ts): + assert ts.key not in self.data + assert ts._previous in {"long-running", "executing", "flight"} + # We'll always transition to released after it is done + assert ts._next is None, (ts.key, ts._next, self.story(ts)) + + def _validate_task_resumed(self, ts): + assert ts.key not in self.data + assert ts._next + assert ts._previous in {"long-running", "executing", "flight"} + + def _validate_task_released(self, ts): + assert ts.key not in self.data + assert not ts._next + assert not ts._previous + assert ts not in self.data_needed + for tss in self.data_needed_per_worker.values(): + assert ts not in tss + assert ts not in self._executing + assert ts not in self._in_flight_tasks + assert ts not in self._missing_dep_flight + + # FIXME the below assert statement is true most of the time. If a task + # performs the transition flight->cancel->waiting, its dependencies are + # normally in released state. However, the compute-task call for their + # previous dependent provided them with who_has, such that this assert + # is no longer true. + # assert not any(ts.key in has_what for has_what in self.has_what.values()) + + assert not ts.waiting_for_data + assert not ts.done + assert not ts.exception + assert not ts.traceback + + def validate_task(self, ts): + try: + if ts.key in self.tasks: + assert self.tasks[ts.key] == ts + if ts.state == "memory": + self._validate_task_memory(ts) + elif ts.state == "waiting": + self._validate_task_waiting(ts) + elif ts.state == "missing": + self._validate_task_missing(ts) + elif ts.state == "cancelled": + self._validate_task_cancelled(ts) + elif ts.state == "resumed": + self._validate_task_resumed(ts) + elif ts.state == "ready": + self._validate_task_ready(ts) + elif ts.state == "executing": + self._validate_task_executing(ts) + elif ts.state == "flight": + self._validate_task_flight(ts) + elif ts.state == "fetch": + self._validate_task_fetch(ts) + elif ts.state == "released": + self._validate_task_released(ts) + except Exception as e: + logger.exception(e) + if LOG_PDB: + import pdb + + pdb.set_trace() + + raise InvalidTaskState( + key=ts.key, state=ts.state, story=self.story(ts) + ) from e + + def validate_state(self): + try: + assert self.executing_count >= 0 + waiting_for_data_count = 0 + for ts in self.tasks.values(): + assert ts.state is not None + # check that worker has task + for worker in ts.who_has: + assert worker != self.address + assert ts.key in self.has_what[worker] + # check that deps have a set state and that dependency<->dependent links + # are there + for dep in ts.dependencies: + # self.tasks was just a dict of tasks + # and this check was originally that the key was in `task_state` + # so we may have popped the key out of `self.tasks` but the + # dependency can still be in `memory` before GC grabs it...? + # Might need better bookkeeping + assert dep.state is not None + assert ts in dep.dependents, ts + if ts.waiting_for_data: + waiting_for_data_count += 1 + for ts_wait in ts.waiting_for_data: + assert ts_wait.key in self.tasks + assert ( + ts_wait.state + in READY | {"executing", "flight", "fetch", "missing"} + or ts_wait in self._missing_dep_flight + or ts_wait.who_has.issubset(self.in_flight_workers) + ), (ts, ts_wait, self.story(ts), self.story(ts_wait)) + # FIXME https://github.com/dask/distributed/issues/6319 + # assert self.waiting_for_data_count == waiting_for_data_count + for worker, keys in self.has_what.items(): + assert worker != self.address + for k in keys: + assert k in self.tasks, self.story(k) + assert worker in self.tasks[k].who_has + + for ts in self.data_needed: + assert ts.state == "fetch", self.story(ts) + assert self.tasks[ts.key] is ts + for worker, tss in self.data_needed_per_worker.items(): + for ts in tss: + assert ts.state == "fetch" + assert self.tasks[ts.key] is ts + assert ts in self.data_needed + assert worker in ts.who_has + + for ts in self.tasks.values(): + self.validate_task(ts) + + if self.transition_counter_max: + assert self.transition_counter < self.transition_counter_max + + except Exception as e: + logger.error("Validate state failed", exc_info=e) + logger.exception(e) + if LOG_PDB: + import pdb + + pdb.set_trace() + + if hasattr(e, "to_event"): + topic, msg = e.to_event() + self.log_event(topic, msg) + + raise + + +class BaseWorker(abc.ABC): + _async_instructions: set[asyncio.Task] + + @fail_hard + @log_errors + def _handle_stimulus_from_task( + self, task: asyncio.Task[StateMachineEvent | None] + ) -> None: + self._async_instructions.remove(task) + try: + # This *should* never raise any other exceptions + stim = task.result() + except asyncio.CancelledError: + return + if stim: + self.handle_stimulus(stim) + + @fail_hard + def _handle_instructions(self, instructions: Instructions) -> None: + while instructions: + ensure_communicating: EnsureCommunicatingAfterTransitions | None = None + for inst in instructions: + task: asyncio.Task | None = None + + if isinstance(inst, SendMessageToScheduler): + self.batched_send(inst.to_dict()) + + elif isinstance(inst, EnsureCommunicatingAfterTransitions): + # A single compute-task or acquire-replicas command may cause + # multiple tasks to transition to fetch; this in turn means that we + # will receive multiple instances of this instruction. + # _ensure_communicating is a no-op if it runs twice in a row; we're + # not calling it inside the for loop to avoid a O(n^2) condition + # when + # 1. there are many fetches queued because all workers are in flight + # 2. a single compute-task or acquire-replicas command just sent + # many dependencies to fetch at once. + ensure_communicating = inst + + elif isinstance(inst, GatherDep): + assert inst.to_gather + keys_str = ", ".join(peekn(27, inst.to_gather)[0]) + if len(keys_str) > 80: + keys_str = keys_str[:77] + "..." + task = asyncio.create_task( + self.gather_dep( + inst.worker, + inst.to_gather, + total_nbytes=inst.total_nbytes, + stimulus_id=inst.stimulus_id, + ), + name=f"gather_dep({inst.worker}, {{{keys_str}}})", + ) + + elif isinstance(inst, Execute): + task = asyncio.create_task( + self.execute(inst.key, stimulus_id=inst.stimulus_id), + name=f"execute({inst.key})", + ) + + elif isinstance(inst, RetryBusyWorkerLater): + task = asyncio.create_task( + self.retry_busy_worker_later(inst.worker), + name=f"retry_busy_worker_later({inst.worker})", + ) + + else: + raise TypeError(inst) # pragma: nocover + + if task is not None: + self._async_instructions.add(task) + task.add_done_callback(self._handle_stimulus_from_task) + + if ensure_communicating: + # Potentially re-fill instructions, causing a second iteration of `while + # instructions` at the top of this method + recs, instructions = self._ensure_communicating( + stimulus_id=ensure_communicating.stimulus_id + ) + self._transitions(recs, stimulus_id=ensure_communicating.stimulus_id) + else: + instructions = [] From c2e94488ca2892243ab3644d4235a7181fd0d3f4 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Mon, 13 Jun 2022 15:28:39 +0100 Subject: [PATCH 2/3] Non-trivial changes to worker and worker_state_machine --- distributed/worker.py | 442 +++++++++--------- distributed/worker_state_machine.py | 669 +++++++++++++++++++++------- 2 files changed, 732 insertions(+), 379 deletions(-) diff --git a/distributed/worker.py b/distributed/worker.py index 5686da06d4..d39decd416 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -6,7 +6,6 @@ import errno import functools import logging -import operator import os import pathlib import random @@ -46,11 +45,12 @@ typename, ) -from distributed import comm, preloading, profile, utils +from distributed import preloading, profile, utils from distributed.batched import BatchedSend -from distributed.collections import LRU, HeapSet -from distributed.comm import connect, get_address_host -from distributed.comm.addressing import address_from_user_args, parse_address +from distributed.collections import LRU +from distributed.comm import Comm, connect, get_address_host, parse_address +from distributed.comm import resolve_address as comm_resolve_address +from distributed.comm.addressing import address_from_user_args from distributed.comm.utils import OFFLOAD_THRESHOLD from distributed.compatibility import randbytes, to_thread from distributed.core import ( @@ -107,8 +107,10 @@ NO_VALUE, AcquireReplicasEvent, AlreadyCancelledEvent, + BaseWorker, CancelComputeEvent, ComputeTaskEvent, + DeprecatedWorkerStateAttribute, ExecuteFailureEvent, ExecuteSuccessEvent, FindMissingEvent, @@ -117,6 +119,7 @@ GatherDepFailureEvent, GatherDepNetworkFailureEvent, GatherDepSuccessEvent, + PauseEvent, RefreshWhoHasEvent, RemoveReplicasEvent, RescheduleEvent, @@ -127,7 +130,9 @@ TaskState, UnpauseEvent, UpdateDataEvent, + WorkerState, ) +from distributed.worker_state_machine import logger as wsm_logger if TYPE_CHECKING: from distributed.client import Client @@ -203,7 +208,7 @@ async def _force_close(self): # Worker is in a very broken state if closing fails. We need to shut down # immediately, to ensure things don't get even worse and this worker potentially # deadlocks the cluster. - if self.validate and not self.nanny: + if self.state.validate and not self.nanny: # We're likely in a unit test. Don't kill the whole test suite! raise @@ -217,7 +222,7 @@ async def _force_close(self): os._exit(1) -class Worker(ServerNode): +class Worker(BaseWorker, ServerNode): """Worker node in a Dask distributed cluster Workers perform two functions: @@ -259,14 +264,10 @@ class Worker(ServerNode): * **services:** ``{str: Server}``: Auxiliary web servers running on this worker * **service_ports:** ``{str: port}``: - * **total_out_connections**: ``int`` - The maximum number of concurrent outgoing requests for data * **total_in_connections**: ``int`` - The maximum number of concurrent incoming requests for data - * **comm_threshold_bytes**: ``int`` - As long as the total number of bytes in flight is below this threshold - we will not limit the number of outgoing connections for a single tasks - dependency fetch. + The maximum number of concurrent incoming requests for data. + See also + :attr:`distributed.worker_state_machine.WorkerState.total_out_connections`. * **batched_stream**: ``BatchedSend`` A batched stream along which we communicate to the scheduler * **log**: ``[(message)]`` @@ -279,50 +280,14 @@ class Worker(ServerNode): we want to compute and ``dep`` is the name of a piece of dependent data that we want to collect from others. - * **tasks**: ``{key: TaskState}`` - The tasks currently executing on this worker (and any dependencies of those tasks) - * **data_needed**: HeapSet[TaskState] - The tasks which still require data in order to execute and are in memory on at - least another worker, prioritized as a heap - * **data_needed_per_worker**: ``{worker: HeapSet[TaskState]}`` - Same as data_needed, split by worker - * **ready**: [keys] - Keys that are ready to run. Stored in a LIFO stack - * **constrained**: [keys] - Keys for which we have the data to run, but are waiting on abstract - resources like GPUs. Stored in a FIFO deque - * **executing_count**: ``int`` - A count of tasks currently executing on this worker - * **executed_count**: int - A number of tasks that this worker has run in its lifetime - * **long_running**: {keys} - A set of keys of tasks that are running and have started their own - long-running clients. - * **has_what**: ``{worker: {deps}}`` - The data that we care about that we think a worker has - * **in_flight_tasks**: ``int`` - A count of the number of tasks that are coming to us in current - peer-to-peer connections - * **in_flight_workers**: ``{worker: {task}}`` - The workers from which we are currently gathering data and the - dependencies we expect from those connections. Workers in this dict won't be - asked for additional dependencies until the current query returns. - * **busy_workers**: ``{worker}`` - Workers that recently returned a busy status. Workers in this set won't be - asked for additional dependencies for some time. - * **comm_bytes**: ``int`` - The total number of bytes in flight * **threads**: ``{key: int}`` The ID of the thread on which the task ran * **active_threads**: ``{int: key}`` The keys currently running on active threads - * **waiting_for_data_count**: ``int`` - A count of how many tasks are currently waiting for data - * **generation**: ``int`` - Counter that decreases every time the compute-task handler is invoked by the - Scheduler. It is appended to TaskState.priority and acts as a tie-breaker - between tasks that have the same priority on the Scheduler, determining a - last-in-first-out order between them. + * **state**: ``WorkerState`` + Encapsulated state machine. See + :class:`~distributed.worker_state_machine.BaseWorker` and + :class:`~distributed.worker_state_machine.WorkerState` Parameters ---------- @@ -413,7 +378,6 @@ class Worker(ServerNode): profile_history: deque[tuple[float, dict[str, Any]]] incoming_transfer_log: deque[dict[str, Any]] outgoing_transfer_log: deque[dict[str, Any]] - validate: bool incoming_count: int outgoing_count: int outgoing_current_count: int @@ -540,27 +504,15 @@ def __init__( DeprecationWarning, stacklevel=2, ) - self.tasks = {} - self.waiting_for_data_count = 0 - self.has_what = defaultdict(set) - self.data_needed = HeapSet(key=operator.attrgetter("priority")) - self.data_needed_per_worker = defaultdict( - lambda: HeapSet(key=operator.attrgetter("priority")) - ) self.nanny = nanny self._lock = threading.Lock() - self.in_flight_workers = {} - self.busy_workers = set() - self.total_out_connections = dask.config.get( + total_out_connections = dask.config.get( "distributed.worker.connections.outgoing" ) self.total_in_connections = dask.config.get( "distributed.worker.connections.incoming" ) - self.comm_threshold_bytes = int(10e6) - self.comm_nbytes = 0 - self._missing_dep_flight = set() self.threads = {} @@ -572,25 +524,9 @@ def __init__( self.profile_recent = profile.create() self.profile_history = deque(maxlen=3600) - self.generation = 0 - - self.ready = [] - self.constrained = deque() - self._executing = set() - self._in_flight_tasks = set() - self.executed_count = 0 - self.long_running = set() - - self.target_message_size = int(50e6) # 50 MB - - self.log = deque(maxlen=100_000) - self.stimulus_log = deque(maxlen=10_000) if validate is None: validate = dask.config.get("distributed.scheduler.validate") - self.validate = validate - self.transition_counter = 0 - self.transition_counter_max = transition_counter_max self.incoming_transfer_log = deque(maxlen=100000) self.incoming_count = 0 self.outgoing_transfer_log = deque(maxlen=100000) @@ -609,7 +545,7 @@ def __init__( profile_cycle_interval = parse_timedelta(profile_cycle_interval, default="ms") assert profile_cycle_interval - self._setup_logging(logger) + self._setup_logging(logger, wsm_logger) if local_dir is not None: warnings.warn("The local_dir keyword has moved to local_directory") @@ -672,13 +608,12 @@ def __init__( self._interface = interface self._protocol = protocol - self.nthreads = nthreads or CPU_COUNT + nthreads = nthreads or CPU_COUNT if resources is None: - resources = dask.config.get("distributed.worker.resources", None) + resources = dask.config.get("distributed.worker.resources") assert isinstance(resources, dict) + self.total_resources = resources.copy() - self.total_resources = resources or {} - self.available_resources = (resources or {}).copy() self.death_timeout = parse_timedelta(death_timeout) self.extensions = {} @@ -691,7 +626,6 @@ def __init__( assert isinstance(self.security, Security) self.connection_args = self.security.get_connection_args("worker") - self.actors = {} self.loop = self.io_loop = IOLoop.current() # Common executors always available @@ -713,7 +647,7 @@ def __init__( self.executors["default"] = executor if "default" not in self.executors: self.executors["default"] = ThreadPoolExecutor( - self.nthreads, thread_name_prefix="Dask-Default-Threads" + nthreads, thread_name_prefix="Dask-Default-Threads" ) self.batched_stream = BatchedSend(interval="2ms", loop=self.loop) @@ -725,6 +659,9 @@ def __init__( if self.local_directory not in sys.path: sys.path.insert(0, self.local_directory) + self.plugins = {} + self._pending_plugins = plugins + self.services = {} self.service_specs = services or {} @@ -780,12 +717,33 @@ def __init__( "worker-status-change": self.handle_worker_status_change, } - super().__init__( + ServerNode.__init__( + self, handlers=handlers, stream_handlers=stream_handlers, connection_args=self.connection_args, **kwargs, ) + self.memory_manager = WorkerMemoryManager( + self, + data=data, + nthreads=nthreads, + memory_limit=memory_limit, + memory_target_fraction=memory_target_fraction, + memory_spill_fraction=memory_spill_fraction, + memory_pause_fraction=memory_pause_fraction, + ) + state = WorkerState( + nthreads=nthreads, + data=self.memory_manager.data, + threads=self.threads, + plugins=self.plugins, + resources=self.total_resources, + total_out_connections=total_out_connections, + validate=validate, + transition_counter_max=transition_counter_max, + ) + BaseWorker.__init__(self, state) self.scheduler = self.rpc(scheduler_addr) self.execution_state = { @@ -795,7 +753,8 @@ def __init__( } self.heartbeat_interval = parse_timedelta(heartbeat_interval, default="ms") - pc = PeriodicCallback(self.heartbeat, self.heartbeat_interval * 1000) + # FIXME https://github.com/tornadoweb/tornado/issues/3117 + pc = PeriodicCallback(self.heartbeat, self.heartbeat_interval * 1000) # type: ignore self.periodic_callbacks["heartbeat"] = pc pc = PeriodicCallback(lambda: self.batched_send({"op": "keep-alive"}), 60000) @@ -812,15 +771,6 @@ def __init__( name: extension(self) for name, extension in extensions.items() } - self.memory_manager = WorkerMemoryManager( - self, - data=data, - memory_limit=memory_limit, - memory_target_fraction=memory_target_fraction, - memory_spill_fraction=memory_spill_fraction, - memory_pause_fraction=memory_pause_fraction, - ) - setproctitle("dask-worker [not started]") if dask.config.get("distributed.worker.profile.enabled"): @@ -833,9 +783,6 @@ def __init__( pc = PeriodicCallback(self.cycle_profile, profile_cycle_interval * 1000) self.periodic_callbacks["profile-cycle"] = pc - self.plugins = {} - self._pending_plugins = plugins - if lifetime is None: lifetime = dask.config.get("distributed.worker.lifetime.duration") lifetime = parse_timedelta(lifetime) @@ -853,8 +800,6 @@ def __init__( self.io_loop.call_later(lifetime, self.close_gracefully) self.lifetime = lifetime - self._async_instructions = set() - Worker._instances.add(self) ################ @@ -864,8 +809,8 @@ def __init__( @property def data(self) -> MutableMapping[str, Any]: - """{task key: task payload} of all completed tasks, whether they were computed on - this Worker or computed somewhere else and then transferred here over the + """{task key: task payload} of all completed tasks, whether they were computed + on this Worker or computed somewhere else and then transferred here over the network. When using the default configuration, this is a zict buffer that automatically @@ -874,6 +819,10 @@ def data(self) -> MutableMapping[str, Any]: It could also be a user-defined arbitrary dict-like passed when initialising the Worker or the Nanny. Worker logic should treat this opaquely and stick to the MutableMapping API. + + .. note:: + This same collection is also available at ``self.state.data`` and + ``self.memory_manager.data``. """ return self.memory_manager.data @@ -884,6 +833,41 @@ def data(self) -> MutableMapping[str, Any]: memory_pause_fraction = DeprecatedMemoryManagerAttribute() memory_monitor = DeprecatedMemoryMonitor() + ########################### + # State machine accessors # + ########################### + + # Deprecated attributes moved to self.state. + actors = DeprecatedWorkerStateAttribute() + available_resources = DeprecatedWorkerStateAttribute() + busy_workers = DeprecatedWorkerStateAttribute() + comm_nbytes = DeprecatedWorkerStateAttribute() + comm_threshold_bytes = DeprecatedWorkerStateAttribute() + constrained = DeprecatedWorkerStateAttribute() + data_needed = DeprecatedWorkerStateAttribute() + data_needed_per_worker = DeprecatedWorkerStateAttribute() + executed_count = DeprecatedWorkerStateAttribute() + executing_count = DeprecatedWorkerStateAttribute() + generation = DeprecatedWorkerStateAttribute() + has_what = DeprecatedWorkerStateAttribute() + in_flight_tasks = DeprecatedWorkerStateAttribute(target="in_flight_tasks_count") + in_flight_workers = DeprecatedWorkerStateAttribute() + log = DeprecatedWorkerStateAttribute() + long_running = DeprecatedWorkerStateAttribute() + nthreads = DeprecatedWorkerStateAttribute() + stimulus_log = DeprecatedWorkerStateAttribute() + stimulus_story = DeprecatedWorkerStateAttribute() + story = DeprecatedWorkerStateAttribute() + ready = DeprecatedWorkerStateAttribute() + tasks = DeprecatedWorkerStateAttribute() + target_message_size = DeprecatedWorkerStateAttribute() + total_out_connections = DeprecatedWorkerStateAttribute() + transition_counter = DeprecatedWorkerStateAttribute() + transition_counter_max = DeprecatedWorkerStateAttribute() + validate = DeprecatedWorkerStateAttribute() + validate_task = DeprecatedWorkerStateAttribute() + waiting_for_data_count = DeprecatedWorkerStateAttribute() + ################## # Administrative # ################## @@ -894,10 +878,10 @@ def __repr__(self): f"<{self.__class__.__name__} {self.address_safe!r}{name}, " f"status: {self.status.name}, " f"stored: {len(self.data)}, " - f"running: {self.executing_count}/{self.nthreads}, " - f"ready: {len(self.ready)}, " - f"comm: {self.in_flight_tasks}, " - f"waiting: {self.waiting_for_data_count}>" + f"running: {self.state.executing_count}/{self.state.nthreads}, " + f"ready: {len(self.state.ready)}, " + f"comm: {self.state.in_flight_tasks_count}, " + f"waiting: {self.state.waiting_for_data_count}>" ) @property @@ -915,14 +899,6 @@ def log_event(self, topic: str | Collection[str], msg: Any) -> None: else: self.loop.add_callback(self.batched_send, full_msg) - @property - def executing_count(self) -> int: - return len(self._executing) - - @property - def in_flight_tasks(self) -> int: - return len(self._in_flight_tasks) - @property def worker_address(self): """For API compatibility with Nanny""" @@ -935,13 +911,15 @@ def executor(self): @ServerNode.status.setter # type: ignore def status(self, value): """Override Server.status to notify the Scheduler of status changes. - Also handles unpausing. + Also handles pausing/unpausing. """ prev_status = self.status ServerNode.status.__set__(self, value) stimulus_id = f"worker-status-change-{time()}" self._send_worker_status_change(stimulus_id) - if prev_status == Status.paused and value == Status.running: + if prev_status == Status.running: + self.handle_stimulus(PauseEvent(stimulus_id=stimulus_id)) + elif value == Status.running: self.handle_stimulus(UnpauseEvent(stimulus_id=stimulus_id)) def _send_worker_status_change(self, stimulus_id: str) -> None: @@ -961,10 +939,10 @@ async def get_metrics(self) -> dict: spilled_memory, spilled_disk = 0, 0 out = dict( - executing=self.executing_count, + executing=self.state.executing_count, in_memory=len(self.data), - ready=len(self.ready), - in_flight=self.in_flight_tasks, + ready=len(self.state.ready), + in_flight=self.state.in_flight_tasks_count, bandwidth={ "total": self.bandwidth, "workers": dict(self.bandwidth_workers), @@ -1008,7 +986,7 @@ def identity(self): "type": type(self).__name__, "id": self.id, "scheduler": self.scheduler.address, - "nthreads": self.nthreads, + "nthreads": self.state.nthreads, "memory_limit": self.memory_manager.memory_limit, } @@ -1025,29 +1003,15 @@ def _to_dict(self, *, exclude: Container[str] = ()) -> dict: info = super()._to_dict(exclude=exclude) extra = { "status": self.status, - "ready": self.ready, - "constrained": self.constrained, - "data_needed": list(self.data_needed.sorted()), - "data_needed_per_worker": { - w: list(v.sorted()) for w, v in self.data_needed_per_worker.items() - }, - "long_running": self.long_running, - "executing_count": self.executing_count, - "in_flight_tasks": self.in_flight_tasks, - "in_flight_workers": self.in_flight_workers, - "busy_workers": self.busy_workers, - "log": self.log, - "stimulus_log": self.stimulus_log, - "transition_counter": self.transition_counter, - "tasks": self.tasks, "logs": self.get_logs(), "config": dask.config.config, "incoming_transfer_log": self.incoming_transfer_log, "outgoing_transfer_log": self.outgoing_transfer_log, } + extra = {k: v for k, v in extra.items() if k not in exclude} info.update(extra) + info.update(self.state._to_dict(exclude=exclude)) info.update(self.memory_manager._to_dict(exclude=exclude)) - info = {k: v for k, v in info.items() if k not in exclude} return recursive_to_dict(info, exclude=exclude) ##################### @@ -1055,16 +1019,16 @@ def _to_dict(self, *, exclude: Container[str] = ()) -> dict: ##################### def batched_send(self, msg: dict[str, Any]) -> None: - """Send a fire-and-forget message to the scheduler through bulk comms. + """Implements BaseWorker abstract method. + + Send a fire-and-forget message to the scheduler through bulk comms. If we're not currently connected to the scheduler, the message will be silently dropped! - Parameters - ---------- - msg: dict - msgpack-serializable message to send to the scheduler. - Must have a 'op' key which is registered in Scheduler.stream_handlers. + See also + -------- + distributed.worker_state_machine.BaseWorker.batched_send """ if ( self.batched_stream @@ -1073,7 +1037,7 @@ def batched_send(self, msg: dict[str, Any]) -> None: ): self.batched_stream.send(msg) - async def _register_with_scheduler(self): + async def _register_with_scheduler(self) -> None: self.periodic_callbacks["keep-alive"].stop() self.periodic_callbacks["heartbeat"].stop() start = time() @@ -1093,11 +1057,11 @@ async def _register_with_scheduler(self): address=self.contact_address, status=self.status.name, keys=list(self.data), - nthreads=self.nthreads, + nthreads=self.state.nthreads, name=self.name, nbytes={ ts.key: ts.get_nbytes() - for ts in self.tasks.values() + for ts in self.state.tasks.values() # Only if the task is in memory this is a sensible # result since otherwise it simply submits the # default value @@ -1155,12 +1119,12 @@ async def _register_with_scheduler(self): self.periodic_callbacks["heartbeat"].start() self.loop.add_callback(self.handle_scheduler, comm) - def _update_latency(self, latency): + def _update_latency(self, latency) -> None: self.latency = latency * 0.05 + self.latency * 0.95 if self.digests is not None: self.digests["latency"].add(latency) - async def heartbeat(self): + async def heartbeat(self) -> None: if self.heartbeat_active: logger.debug("Heartbeat skipped: channel busy") return @@ -1174,9 +1138,9 @@ async def heartbeat(self): now=start, metrics=await self.get_metrics(), executing={ - key: start - self.tasks[key].start_time + key: start - self.state.tasks[key].start_time for key in self.active_keys - if key in self.tasks + if key in self.state.tasks }, extensions={ name: extension.heartbeat() @@ -1219,7 +1183,7 @@ async def heartbeat(self): self.heartbeat_active = False @fail_hard - async def handle_scheduler(self, comm): + async def handle_scheduler(self, comm: Comm) -> None: await self.handle_stream(comm) logger.info( "Connection to scheduler broken. Closing without reporting. ID: %s Address %s Status: %s", @@ -1229,7 +1193,9 @@ async def handle_scheduler(self, comm): ) await self.close() - async def upload_file(self, comm, filename=None, data=None, load=True): + async def upload_file( + self, filename: str, data: str | bytes, load: bool = True + ) -> dict[str, Any]: out_filename = os.path.join(self.local_directory, filename) def func(data): @@ -1384,7 +1350,7 @@ async def start_unsafe(self): logger.info(" {:>16} at: {:>26}".format(k, self.ip + ":" + str(v))) logger.info("Waiting to connect to: %26s", self.scheduler.address) logger.info("-" * 49) - logger.info(" Threads: %26d", self.nthreads) + logger.info(" Threads: %26d", self.state.nthreads) if self.memory_manager.memory_limit: logger.info( " Memory: %26s", @@ -1412,14 +1378,13 @@ async def start_unsafe(self): raise plugins_exceptions[0] self._pending_plugins = () - + self.state.address = self.address await self._register_with_scheduler() - self.start_periodic_callbacks() return self @log_errors - async def close( + async def close( # type: ignore self, timeout: float = 30, executor_wait: bool = True, @@ -1427,7 +1392,8 @@ async def close( ) -> str | None: """Close the worker - Close asynchronous operations running on the worker, stop all executors and comms. If requested, this also closes the nanny. + Close asynchronous operations running on the worker, stop all executors and + comms. If requested, this also closes the nanny. Parameters ---------- @@ -1479,16 +1445,8 @@ async def close( for pc in self.periodic_callbacks.values(): pc.stop() - if self._async_instructions: - for task in self._async_instructions: - task.cancel() - # async tasks can handle cancellation and could take an arbitrary amount - # of time to terminate - _, pending = await asyncio.wait(self._async_instructions, timeout=timeout) - for task in pending: - logger.error( - f"Failed to cancel asyncio task after {timeout} seconds: {task}" - ) + # Cancel async instructions + await BaseWorker.close(self, timeout=timeout) for preload in self.preloads: try: @@ -1580,7 +1538,7 @@ def _close(): await self.rpc.close() self.status = Status.closed - await super().close() + await ServerNode.close(self) setproctitle("dask-worker [closed]") return "OK" @@ -1641,7 +1599,7 @@ async def batched_send_connect(): async def get_data( self, comm, keys=None, who=None, serializers=None, max_connections=None - ): + ) -> dict | Status: start = time() if max_connections is None: @@ -1681,13 +1639,15 @@ async def get_data( if len(data) < len(keys): for k in set(keys) - set(data): - if k in self.actors: + if k in self.state.actors: from distributed.actor import Actor - data[k] = Actor(type(self.actors[k]), self.address, k, worker=self) + data[k] = Actor( + type(self.state.actors[k]), self.address, k, worker=self + ) msg = {"status": "OK", "data": {k: to_serialize(v) for k, v in data.items()}} - nbytes = {k: self.tasks[k].nbytes for k in data if k in self.tasks} + nbytes = {k: self.state.tasks[k].nbytes for k in data if k in self.state.tasks} stop = time() if self.digests is not None: self.digests["get-data-load-duration"].add(stop - start) @@ -1753,9 +1713,9 @@ def update_data( async def set_resources(self, **resources) -> None: for r, quantity in resources.items(): if r in self.total_resources: - self.available_resources[r] += quantity - self.total_resources[r] + self.state.available_resources[r] += quantity - self.total_resources[r] else: - self.available_resources[r] = quantity + self.state.available_resources[r] = quantity self.total_resources[r] = quantity await retry_operation( @@ -1835,7 +1795,6 @@ def handle_worker_status_change(self, status: str, stimulus_id: str) -> None: # Task Management # ################### - @fail_hard def _handle_remote_stimulus( self, cls: type[StateMachineEvent] ) -> Callable[..., None]: @@ -1847,14 +1806,28 @@ def _(**kwargs): return _ @fail_hard - @log_errors + def _handle_stimulus_from_task( + self, task: asyncio.Task[StateMachineEvent | None] + ) -> None: + """Override BaseWorker method for added validation + + See also + -------- + distributed.worker_state_machine.BaseWorker._handle_stimulus_from_task + """ + super()._handle_stimulus_from_task(task) + + @fail_hard def handle_stimulus(self, stim: StateMachineEvent) -> None: - if not isinstance(stim, FindMissingEvent): - self.stimulus_log.append(stim.to_loggable(handled=time())) + """Override BaseWorker method for added validation + + See also + -------- + distributed.worker_state_machine.BaseWorker.handle_stimulus + distributed.worker_state_machine.WorkerState.handle_stimulus + """ try: - recs, instructions = self._handle_event(stim) - self._transitions(recs, stimulus_id=stim.stimulus_id) - self._handle_instructions(instructions) + super().handle_stimulus(stim) except Exception as e: if hasattr(e, "to_event"): topic, msg = e.to_event() # type: ignore @@ -1862,26 +1835,17 @@ def handle_stimulus(self, stim: StateMachineEvent) -> None: raise def stateof(self, key: str) -> dict[str, Any]: - ts = self.tasks[key] + ts = self.state.tasks[key] return { "executing": ts.state == "executing", "waiting_for_data": bool(ts.waiting_for_data), - "heap": key in pluck(1, self.ready), + "heap": key in pluck(1, self.state.ready), "data": key in self.data, } async def get_story(self, keys=None): return self.story(*keys) - @property - def total_comm_bytes(self): - warnings.warn( - "The attribute `Worker.total_comm_bytes` has been renamed to `comm_threshold_bytes`. " - "Future versions will only support the new name.", - FutureWarning, - ) - return self.comm_threshold_bytes - ########################## # Dependencies gathering # ########################## @@ -1898,7 +1862,7 @@ def _get_cause(self, keys: Iterable[str]) -> TaskState: """ cause = None for key in keys: - ts = self.tasks[key] + ts = self.state.tasks[key] if ts.dependents: return next(iter(ts.dependents)) cause = ts @@ -1914,7 +1878,7 @@ def _update_metrics_received_data( worker: str, ) -> None: - total_bytes = sum(self.tasks[key].get_nbytes() for key in data) + total_bytes = sum(self.state.tasks[key].get_nbytes() for key in data) cause.startstops.append( { @@ -1932,7 +1896,7 @@ def _update_metrics_received_data( "stop": stop + self.scheduler_delay, "middle": (start + stop) / 2.0 + self.scheduler_delay, "duration": duration, - "keys": {key: self.tasks[key].nbytes for key in data}, + "keys": {key: self.state.tasks[key].nbytes for key in data}, "total": total_bytes, "bandwidth": bandwidth, "who": worker, @@ -1956,7 +1920,6 @@ def _update_metrics_received_data( self.incoming_count += 1 @fail_hard - @log_errors async def gather_dep( self, worker: str, @@ -1965,24 +1928,19 @@ async def gather_dep( *, stimulus_id: str, ) -> StateMachineEvent | None: - """Gather dependencies for a task from a worker who has them + """Implements BaseWorker abstract method - Parameters - ---------- - worker : str - Address of worker to gather dependencies from - to_gather : list - Keys of dependencies to gather from worker -- this is not - necessarily equivalent to the full list of dependencies of ``dep`` - as some dependencies may already be present on this worker. - total_nbytes : int - Total number of bytes for all the dependencies in to_gather combined + See also + -------- + distributed.worker_state_machine.BaseWorker.gather_dep """ if self.status not in WORKER_ANY_RUNNING: return None try: - self.log.append(("request-dep", worker, to_gather, stimulus_id, time())) + self.state.log.append( + ("request-dep", worker, to_gather, stimulus_id, time()) + ) logger.debug("Request %d keys from %s", len(to_gather), worker) start = time() @@ -1991,7 +1949,9 @@ async def gather_dep( ) stop = time() if response["status"] == "busy": - self.log.append(("busy-gather", worker, to_gather, stimulus_id, time())) + self.state.log.append( + ("busy-gather", worker, to_gather, stimulus_id, time()) + ) return GatherDepBusyEvent( worker=worker, total_nbytes=total_nbytes, @@ -2007,7 +1967,7 @@ async def gather_dep( cause=cause, worker=worker, ) - self.log.append( + self.state.log.append( ("receive-dep", worker, set(response["data"]), stimulus_id, time()) ) return GatherDepSuccessEvent( @@ -2019,7 +1979,7 @@ async def gather_dep( except OSError: logger.exception("Worker stream died during communication: %s", worker) - self.log.append( + self.state.log.append( ("receive-dep-failed", worker, to_gather, stimulus_id, time()) ) return GatherDepNetworkFailureEvent( @@ -2044,6 +2004,13 @@ async def gather_dep( ) async def retry_busy_worker_later(self, worker: str) -> StateMachineEvent | None: + """Wait some time, then take a peer worker out of busy state. + Implements BaseWorker abstract method. + + See Also + -------- + distributed.worker_state_machine.BaseWorker.retry_busy_worker_later + """ await asyncio.sleep(0.15) return RetryBusyWorkerEvent( worker=worker, stimulus_id=f"retry-busy-worker-{time()}" @@ -2078,7 +2045,7 @@ async def actor_execute( kwargs = kwargs or {} separate_thread = kwargs.pop("separate_thread", True) key = actor - actor = self.actors[key] + actor = self.state.actors[key] func = getattr(actor, function) name = key_split(key) + "." + function @@ -2105,7 +2072,7 @@ async def actor_execute( def actor_attribute(self, actor=None, attribute=None) -> dict[str, Any]: try: - value = getattr(self.actors[actor], attribute) + value = getattr(self.state.actors[actor], attribute) return {"status": "OK", "result": to_serialize(value)} except Exception as ex: return {"status": "error", "exception": to_serialize(ex)} @@ -2130,9 +2097,15 @@ async def _maybe_deserialize_task( @fail_hard async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | None: + """Execute a task. Implements BaseWorker abstract method. + + See also + -------- + distributed.worker_state_machine.BaseWorker.execute + """ if self.status in {Status.closing, Status.closed, Status.closing_gracefully}: return None - ts = self.tasks.get(key) + ts = self.state.tasks.get(key) if not ts: return None if ts.state == "cancelled": @@ -2153,17 +2126,15 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | No ) try: - if self.validate: + if self.state.validate: assert not ts.waiting_for_data assert ts.state == "executing" assert ts.run_spec is not None args2, kwargs2 = self._prepare_args_for_execution(ts, args, kwargs) - try: - executor = ts.annotations["executor"] # type: ignore - except (TypeError, KeyError): - executor = "default" + assert ts.annotations is not None + executor = ts.annotations.get("executor", "default") try: e = self.executors[executor] except KeyError: @@ -2266,7 +2237,7 @@ def _prepare_args_for_execution( except KeyError: from distributed.actor import Actor # TODO: create local actor - data[k] = Actor(type(self.actors[k]), self.address, k, self) + data[k] = Actor(type(self.state.actors[k]), self.address, k, self) args2 = pack_data(args, data, key_types=(bytes, str)) kwargs2 = pack_data(kwargs, data, key_types=(bytes, str)) stop = time() @@ -2503,6 +2474,23 @@ def get_current_task(self) -> str: """ return self.active_threads[threading.get_ident()] + def validate_state(self) -> None: + try: + self.state.validate_state() + except Exception as e: + logger.error("Validate state failed", exc_info=e) + logger.exception(e) + if LOG_PDB: + import pdb + + pdb.set_trace() + + if hasattr(e, "to_event"): + topic, msg = e.to_event() # type: ignore + self.log_event(topic, msg) + + raise + def get_worker() -> Worker: """Get the worker currently running this task @@ -2576,7 +2564,7 @@ def get_client(address=None, timeout=None, resolve_address=True) -> Client: timeout = parse_timedelta(timeout, "s") if address and resolve_address: - address = comm.resolve_address(address) + address = comm_resolve_address(address) try: worker = get_worker() except ValueError: # could not find worker diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index 21b2043d05..3cca984cc7 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -5,6 +5,7 @@ import functools import heapq import logging +import operator import random import sys from collections import defaultdict, deque @@ -19,7 +20,6 @@ from copy import copy from dataclasses import dataclass, field from functools import lru_cache -from pickle import PicklingError from typing import TYPE_CHECKING, Any, ClassVar, Literal, NamedTuple, TypedDict, cast from tlz import peekn, pluck @@ -30,7 +30,7 @@ from distributed._stories import worker_story from distributed.collections import HeapSet from distributed.comm import get_address_host -from distributed.core import ErrorMessage, Status, error_message +from distributed.core import ErrorMessage, error_message from distributed.metrics import time from distributed.protocol import pickle from distributed.protocol.serialize import Serialize @@ -39,8 +39,6 @@ logger = logging.getLogger(__name__) -LOG_PDB = dask.config.get("distributed.admin.pdb-on-err") - if TYPE_CHECKING: # TODO import from typing (requires Python >=3.10) from typing_extensions import TypeAlias @@ -48,6 +46,7 @@ # Circular imports from distributed.actor import Actor from distributed.diagnostics.plugin import WorkerPlugin + from distributed.worker import Worker # TODO move out of TYPE_CHECKING (requires Python >=3.10) TaskStateState: TypeAlias = Literal[ @@ -542,6 +541,11 @@ def _after_from_dict(self) -> None: """Optional post-processing after an instance is created by ``from_dict``""" +@dataclass +class PauseEvent(StateMachineEvent): + __slots__ = () + + @dataclass class UnpauseEvent(StateMachineEvent): __slots__ = () @@ -854,38 +858,266 @@ def merge_recs_instructions(*args: RecsInstrs) -> RecsInstrs: class WorkerState: + """State machine encapsulating the lifetime of all tasks on a worker. + + Not to be confused with :class:`distributed.scheduler.WorkerState`. + + .. note:: + The data attributes of this class are implementation details and may be + changed without a deprecation cycle. + + .. warning:: + The attributes of this class are all heavily correlated with each other. + *Do not* modify them directly, *ever*, as it is extremely easy to obtain a broken + state this way, which in turn will likely result in cluster-wide deadlocks. + + The state should be exclusively mutated through :meth:`handle_stimulus`. + """ + + #: Worker :. This is used in decision-making by the state machine, + #: e.g. to determine if a peer worker is running on the same host or not. + #: This attribute may not be known when the WorkerState is initialised. It *must* be + #: set before the first call to :meth:`handle_stimulus`. address: str + + #: ``{key: TaskState}``. The tasks currently executing on this worker (and any + #: dependencies of those tasks) + tasks: dict[str, TaskState] + + #: ``{ts.key: thread ID}``. This collection is shared by reference between + #: :class:`~distributed.worker.Worker` and this class. While the WorkerState is + #: thread-agnostic, it still needs access to this information in some cases. + #: This collection is populated by :meth:`distributed.worker.Worker.execute`. + #: It does not *need* to be populated for the WorkerState to work. + threads: dict[str, int] + + #: In-memory tasks data. This collection is shared by reference between + #: :class:`~distributed.worker.Worker`, + #: :class:`~distributed.worker_memory.WorkerMemoryManager`, and this class. data: MutableMapping[str, Any] - threads: dict[str, int] # {ts.key: thread ID} + + #: ``{name: worker plugin}``. This collection is shared by reference between + #: :class:`~distributed.worker.Worker` and this class. The Worker managed adding and + #: removing plugins, while the WorkerState invokes the ``WorkerPlugin.transition`` + #: method, is available. plugins: dict[str, WorkerPlugin] - tasks: dict[str, TaskState] + # heapq ``[(priority, key), ...]``. Keys that are ready to run. + ready: list[tuple[tuple[int, ...], str]] + + #: Keys for which we have the data to run, but are waiting on abstract resources + #: like GPUs. Stored in a FIFO deque. + #: See :attr:`available_resources` and :doc:`resources`. + constrained: deque[str] + + #: Number of tasks that can be executing in parallel. + #: At any given time, :meth:`executing_count` <= nthreads. + nthreads: int + + #: True if the state machine should start executing more tasks and fetch + #: dependencies whenever a slot is available. This property must be kept aligned + #: with the Worker: ``WorkerState.running == (Worker.status is Status.running)``. + running: bool + + #: A count of how many tasks are currently waiting for data waiting_for_data_count: int - has_what: defaultdict[str, set[str]] # {worker address: {ts.key, ...} + + #: ``{worker address: {ts.key, ...}``. + #: The data that we care about that we think a worker has + has_what: defaultdict[str, set[str]] + + #: The tasks which still require data in order to execute and are in memory on at + #: least another worker, prioritized as a heap. All and only tasks with + #: ``TaskState.state == 'fetch'`` are in this collection. data_needed: HeapSet[TaskState] + + #: Same as :attr:`data_needed`, individually for every peer worker. A + #: :class:`TaskState` with multiple entries in :attr:`~TaskState.who_has` will + #: appear multiple times here. data_needed_per_worker: defaultdict[str, HeapSet[TaskState]] - in_flight_workers: dict[str, set[str]] # {worker address: {ts.key, ...}} - busy_workers: set[str] + + #: Number of bytes to fetch from the same worker in a single call to + #: :meth:`BaseWorker.gather_dep`. Multiple small tasks that can be fetched from the + #: same worker will be clustered in a single instruction as long as their combined + #: size doesn't exceed this value. + target_message_size: int + + #: All and only tasks with ``TaskState.state == 'missing'``. + missing_dep_flight: set[TaskState] + + #: Which tasks that are coming to us in current peer-to-peer connections. + #: All and only tasks with TaskState.state == 'flight'. + #: See also :meth:`in_flight_tasks_count`. + in_flight_tasks: set[TaskState] + + #: ``{worker address: {ts.key, ...}}`` + #: The workers from which we are currently gathering data and the dependencies we + #: expect from those connections. Workers in this dict won't be asked for additional + #: dependencies until the current query returns. + in_flight_workers: dict[str, set[str]] + + #: The total number of bytes in flight + comm_nbytes: int + + #: The maximum number of concurrent incoming requests for data. + #: See also :attr:`distributed.worker.Worker.total_in_connections`. total_out_connections: int + + #: Ignore :attr:`total_out_connections` as long as :attr:`comm_nbytes` is + #: less than this value. comm_threshold_bytes: int - comm_nbytes: int - _missing_dep_flight: set[TaskState] + + #: Peer workers that recently returned a busy status. Workers in this set won't be + #: asked for additional dependencies for some time. + busy_workers: set[str] + + #: Counter that decreases every time the compute-task handler is invoked by the + #: Scheduler. It is appended to :attr:`TaskState.priority` and acts as a + #: tie-breaker between tasks that have the same priority on the Scheduler, + #: determining a last-in-first-out order between them. generation: int - ready: list[tuple[tuple[int, ...], str]] # heapq [(priority, key), ...] - constrained: deque[str] - nthreads: int + + #: ``{resource name: amount}``. Current resources that aren't being currently + #: consumed by task execution. Always less or equal to ``Worker.total_resources``. + #: See :doc:`resources`. available_resources: dict[str, float] - _executing: set[TaskState] - _in_flight_tasks: set[TaskState] - executed_count: int + + #: Set of tasks that are currently running. + #: See also :meth:`executing_count` and :attr:`long_runing`. + executing: set[TaskState] + + #: Set of keys of tasks that are currently running and have called + #: :func:`~distributed.secede`. + #: These tasks do not appear in the :attr:`executing` set. long_running: set[str] + + #: A number of tasks that this worker has run in its lifetime. + #: See also :meth:`executing_count`. + executed_count: int + + #: Actor tasks. See :doc:`actors`. actors: dict[str, Actor | None] - log: deque[tuple] # [(..., stimulus_id: str | None, timestamp: float), ...] + + #: Transition log: ``[(..., stimulus_id: str | None, timestamp: float), ...]`` + #: The number of stimuli logged is capped. + #: See also :meth:`story` and :attr:`stimulus_log`. + log: deque[tuple] + + #: Log of all stimuli received by :meth:`handle_stimulus`. + #: The number of events logged is capped. + #: See also :attr:`log` and :meth:`stimulus_story`. stimulus_log: deque[StateMachineEvent] - target_message_size: int + + #: If True, enable expensive internal consistency check. + #: Typically disabled in production. + validate: bool + + #: Total number of state transitions so far. + #: See also :attr:`log` and :attr:`transition_counter_max`. transition_counter: int + + #: Raise an error if the :attr:`transition_counter` ever reaches this value. + #: This is meant for debugging only, to catch infinite recursion loops. + #: In production, it should always be set to False. transition_counter_max: int | Literal[False] - validate: bool + + __slots__ = tuple(__annotations__) # type: ignore + + def __init__( + self, + nthreads: int, + *, + address: str | None = None, + data: MutableMapping[str, Any] = None, + threads: dict[str, int] | None = None, + plugins: dict[str, WorkerPlugin] | None = None, + resources: Mapping[str, float] | None = None, + total_out_connections: int = 9999, + validate: bool = True, + transition_counter_max: int | Literal[False] = False, + ): + self.nthreads = nthreads + + # address may not be known yet when the State Machine is initialised. + # Raise AttributeError if a method tries reading it before it's been set. + if address: + self.address = address + + # These collections are normally passed by reference by the Worker. + # For the sake of convenience, create independent ones during unit tests. + self.data = data if data is not None else {} + self.threads = threads if threads is not None else {} + self.plugins = plugins if plugins is not None else {} + self.available_resources = dict(resources) if resources is not None else {} + + self.validate = validate + self.tasks = {} + self.running = True + self.waiting_for_data_count = 0 + self.has_what = defaultdict(set) + self.data_needed = HeapSet(key=operator.attrgetter("priority")) + self.data_needed_per_worker = defaultdict( + lambda: HeapSet(key=operator.attrgetter("priority")) + ) + self.in_flight_workers = {} + self.busy_workers = set() + self.total_out_connections = total_out_connections + self.comm_threshold_bytes = int(10e6) + self.comm_nbytes = 0 + self.missing_dep_flight = set() + self.generation = 0 + self.ready = [] + self.constrained = deque() + self.executing = set() + self.in_flight_tasks = set() + self.executed_count = 0 + self.long_running = set() + self.target_message_size = int(50e6) # 50 MB + self.log = deque(maxlen=100_000) + self.stimulus_log = deque(maxlen=10_000) + self.transition_counter = 0 + self.transition_counter_max = transition_counter_max + self.actors = {} + + def handle_stimulus(self, stim: StateMachineEvent) -> Instructions: + """Process an external event, transition relevant tasks to new states, and + return a list of instructions to be executed as a consequence. + + See also + -------- + BaseWorker.handle_stimulus + """ + if not isinstance(stim, FindMissingEvent): + self.stimulus_log.append(stim.to_loggable(handled=time())) + recs, instructions = self._handle_event(stim) + instructions += self._transitions(recs, stimulus_id=stim.stimulus_id) + return instructions + + ############# + # Accessors # + ############# + + @property + def executing_count(self) -> int: + """Count of tasks currently executing on this worker. + + See also + -------- + WorkerState.executing + WorkerState.executed_count + WorkerState.nthreads + """ + return len(self.executing) + + @property + def in_flight_tasks_count(self) -> int: + """Count of tasks currently being replicated from other workers to this one. + + See also + -------- + WorkerState.in_flight_tasks + """ + return len(self.in_flight_tasks) ######################### # Shared helper methods # @@ -976,11 +1208,11 @@ def _purge_state(self, ts: TaskState) -> None: ts._next = None ts.done = False - self._executing.discard(ts) - self._in_flight_tasks.discard(ts) + self.executing.discard(ts) + self.in_flight_tasks.discard(ts) def _ensure_communicating(self, *, stimulus_id: str) -> RecsInstrs: - if self.status != Status.running: + if not self.running: return {}, [] skipped_worker_in_flight_or_busy = [] @@ -1061,11 +1293,11 @@ def _ensure_communicating(self, *, stimulus_id: str) -> RecsInstrs: return recommendations, instructions def _ensure_computing(self) -> RecsInstrs: - if self.status != Status.running: + if not self.running: return {}, [] recs: Recs = {} - while self.constrained and len(self._executing) < self.nthreads: + while self.constrained and len(self.executing) < self.nthreads: key = self.constrained[0] ts = self.tasks.get(key, None) if ts is None or ts.state != "constrained": @@ -1091,9 +1323,9 @@ def _ensure_computing(self) -> RecsInstrs: self.available_resources[resource] -= needed recs[ts] = "executing" - self._executing.add(ts) + self.executing.add(ts) - while self.ready and len(self._executing) < self.nthreads: + while self.ready and len(self.executing) < self.nthreads: _, key = heapq.heappop(self.ready) ts = self.tasks.get(key) if ts is None: @@ -1113,7 +1345,7 @@ def _ensure_computing(self) -> RecsInstrs: if self.validate: assert ts not in recs or recs[ts] == "executing" recs[ts] = "executing" - self._executing.add(ts) + self.executing.add(ts) return recs, [] @@ -1132,8 +1364,8 @@ def _get_task_finished_msg( typ = ts.type = type(value) del value try: - typ_serialized = dumps_function(typ) - except PicklingError: + typ_serialized = pickle.dumps(typ, protocol=4) + except Exception: # Some types fail pickling (example: _thread.lock objects), # send their name as a best effort. typ_serialized = pickle.dumps(typ.__name__, protocol=4) @@ -1257,7 +1489,7 @@ def _transition_generic_fetch(self, ts: TaskState, stimulus_id: str) -> RecsInst def _transition_missing_waiting( self, ts: TaskState, *, stimulus_id: str ) -> RecsInstrs: - self._missing_dep_flight.discard(ts) + self.missing_dep_flight.discard(ts) self._purge_state(ts) return self._transition_released_waiting(ts, stimulus_id=stimulus_id) @@ -1270,13 +1502,13 @@ def _transition_missing_fetch( if not ts.who_has: return {}, [] - self._missing_dep_flight.discard(ts) + self.missing_dep_flight.discard(ts) return self._transition_generic_fetch(ts, stimulus_id=stimulus_id) def _transition_missing_released( self, ts: TaskState, *, stimulus_id: str ) -> RecsInstrs: - self._missing_dep_flight.discard(ts) + self.missing_dep_flight.discard(ts) recs, instructions = self._transition_generic_released( ts, stimulus_id=stimulus_id ) @@ -1296,7 +1528,7 @@ def _transition_generic_missing( assert not ts.who_has ts.state = "missing" - self._missing_dep_flight.add(ts) + self.missing_dep_flight.add(ts) ts.done = False return {}, [] @@ -1366,7 +1598,7 @@ def _transition_fetch_flight( ts.done = False ts.state = "flight" ts.coming_from = worker - self._in_flight_tasks.add(ts) + self.in_flight_tasks.add(ts) return {}, [] def _transition_fetch_missing( @@ -1413,7 +1645,7 @@ def _transition_executing_rescheduled( ) -> RecsInstrs: for resource, quantity in ts.resource_restrictions.items(): self.available_resources[resource] += quantity - self._executing.discard(ts) + self.executing.discard(ts) return merge_recs_instructions( ( @@ -1509,7 +1741,7 @@ def _transition_executing_error( ) -> RecsInstrs: for resource, quantity in ts.resource_restrictions.items(): self.available_resources[resource] += quantity - self._executing.discard(ts) + self.executing.discard(ts) return merge_recs_instructions( self._transition_generic_error( @@ -1524,7 +1756,7 @@ def _transition_executing_error( ) def _transition_from_resumed( - self, ts: TaskState, finish: TaskStateState, *args, stimulus_id: str + self, ts: TaskState, finish: TaskStateState, stimulus_id: str ) -> RecsInstrs: """`resumed` is an intermediate degenerate state which splits further up into two states depending on what the last signal / next state is @@ -1650,8 +1882,8 @@ def _transition_cancelled_released( ) -> RecsInstrs: if not ts.done: return {}, [] - self._executing.discard(ts) - self._in_flight_tasks.discard(ts) + self.executing.discard(ts) + self.in_flight_tasks.discard(ts) for resource, quantity in ts.resource_restrictions.items(): self.available_resources[resource] += quantity @@ -1686,8 +1918,8 @@ def _transition_generic_memory( for resource, quantity in ts.resource_restrictions.items(): self.available_resources[resource] += quantity - self._executing.discard(ts) - self._in_flight_tasks.discard(ts) + self.executing.discard(ts) + self.in_flight_tasks.discard(ts) ts.coming_from = None instructions: Instructions = [] @@ -1713,7 +1945,7 @@ def _transition_executing_memory( assert not ts.waiting_for_data assert ts.key not in self.ready - self._executing.discard(ts) + self.executing.discard(ts) self.executed_count += 1 return merge_recs_instructions( self._transition_generic_memory(ts, value=value, stimulus_id=stimulus_id), @@ -1774,7 +2006,7 @@ def _transition_flight_error( *, stimulus_id: str, ) -> RecsInstrs: - self._in_flight_tasks.discard(ts) + self.in_flight_tasks.discard(ts) ts.coming_from = None return self._transition_generic_error( ts, @@ -1811,7 +2043,7 @@ def _transition_executing_long_running( self, ts: TaskState, compute_duration: float, *, stimulus_id: str ) -> RecsInstrs: ts.state = "long-running" - self._executing.discard(ts) + self.executing.discard(ts) self.long_running.add(ts.key) smsg = LongRunningMsg( @@ -1837,7 +2069,7 @@ def _transition_released_memory( def _transition_flight_memory( self, ts: TaskState, value, *, stimulus_id: str ) -> RecsInstrs: - self._in_flight_tasks.discard(ts) + self.in_flight_tasks.discard(ts) ts.coming_from = None try: recs = self._put_key_in_memory(ts, value, stimulus_id=stimulus_id) @@ -2028,7 +2260,7 @@ def _transition( ) return recs, instructions - def _transitions(self, recommendations: Recs, *, stimulus_id: str) -> None: + def _transitions(self, recommendations: Recs, *, stimulus_id: str) -> Instructions: """Process transitions until none are left This includes feedback from previous transitions and continues until we @@ -2053,7 +2285,7 @@ def _transitions(self, recommendations: Recs, *, stimulus_id: str) -> None: for ts in tasks: self.validate_task(ts) - self._handle_instructions(instructions) + return instructions ########## # Events # @@ -2403,11 +2635,17 @@ def _handle_steal_request(self, ev: StealRequestEvent) -> RecsInstrs: return {}, [smsg] @_handle_event.register - def _handle_unpause(self, ev: UnpauseEvent) -> RecsInstrs: - """Emerge from paused status. Do not send this event directly. Instead, just set - Worker.status back to running. + def _handle_pause(self, ev: PauseEvent) -> RecsInstrs: + """Prevent any further tasks to be executed or gathered. Tasks that are + currently executing or in flight will continue to progress. """ - assert self.status == Status.running + self.running = False + return {}, [] + + @_handle_event.register + def _handle_unpause(self, ev: UnpauseEvent) -> RecsInstrs: + """Emerge from paused status""" + self.running = True return merge_recs_instructions( self._ensure_computing(), self._ensure_communicating(stimulus_id=ev.stimulus_id), @@ -2492,15 +2730,15 @@ def _handle_reschedule(self, ev: RescheduleEvent) -> RecsInstrs: @_handle_event.register def _handle_find_missing(self, ev: FindMissingEvent) -> RecsInstrs: - if not self._missing_dep_flight: + if not self.missing_dep_flight: return {}, [] if self.validate: - for ts in self._missing_dep_flight: + for ts in self.missing_dep_flight: assert not ts.who_has, self.story(ts) smsg = RequestRefreshWhoHasMsg( - keys=[ts.key for ts in self._missing_dep_flight], + keys=[ts.key for ts in self.missing_dep_flight], stimulus_id=ev.stimulus_id, ) return {}, [smsg] @@ -2548,18 +2786,52 @@ def stimulus_story( keys = {e.key if isinstance(e, TaskState) else e for e in keys_or_tasks} return [ev for ev in self.stimulus_log if getattr(ev, "key", None) in keys] + def _to_dict(self, *, exclude: Container[str] = ()) -> dict: + """Dictionary representation for debugging purposes. + Not type stable and not intended for roundtrips. + + See also + -------- + Client.dump_cluster_state + distributed.utils.recursive_to_dict + """ + info = { + "address": self.address, + "nthreads": self.nthreads, + "running": self.running, + "ready": self.ready, + "constrained": self.constrained, + "data": dict.fromkeys(self.data), + "data_needed": [ts.key for ts in self.data_needed.sorted()], + "data_needed_per_worker": { + w: [ts.key for ts in tss.sorted()] + for w, tss in self.data_needed_per_worker.items() + }, + "executing": {ts.key for ts in self.executing}, + "long_running": self.long_running, + "in_flight_tasks": {ts.key for ts in self.in_flight_tasks}, + "in_flight_workers": self.in_flight_workers, + "busy_workers": self.busy_workers, + "log": self.log, + "stimulus_log": self.stimulus_log, + "transition_counter": self.transition_counter, + "tasks": self.tasks, + } + info = {k: v for k, v in info.items() if k not in exclude} + return recursive_to_dict(info, exclude=exclude) + ############## # Validation # ############## - def _validate_task_memory(self, ts): + def _validate_task_memory(self, ts: TaskState) -> None: assert ts.key in self.data or ts.key in self.actors assert isinstance(ts.nbytes, int) assert not ts.waiting_for_data assert ts.key not in self.ready assert ts.state == "memory" - def _validate_task_executing(self, ts): + def _validate_task_executing(self, ts: TaskState) -> None: assert ts.state == "executing" assert ts.run_spec is not None assert ts.key not in self.data @@ -2568,7 +2840,7 @@ def _validate_task_executing(self, ts): assert dep.state == "memory", self.story(dep) assert dep.key in self.data or dep.key in self.actors - def _validate_task_ready(self, ts): + def _validate_task_ready(self, ts: TaskState) -> None: assert ts.key in pluck(1, self.ready) assert ts.key not in self.data assert ts.state != "executing" @@ -2578,22 +2850,22 @@ def _validate_task_ready(self, ts): dep.key in self.data or dep.key in self.actors for dep in ts.dependencies ) - def _validate_task_waiting(self, ts): + def _validate_task_waiting(self, ts: TaskState) -> None: assert ts.key not in self.data assert ts.state == "waiting" assert not ts.done if ts.dependencies and ts.run_spec: assert not all(dep.key in self.data for dep in ts.dependencies) - def _validate_task_flight(self, ts): + def _validate_task_flight(self, ts: TaskState) -> None: assert ts.key not in self.data - assert ts in self._in_flight_tasks + assert ts in self.in_flight_tasks assert not any(dep.key in self.ready for dep in ts.dependents) assert ts.coming_from assert ts.coming_from in self.in_flight_workers assert ts.key in self.in_flight_workers[ts.coming_from] - def _validate_task_fetch(self, ts): + def _validate_task_fetch(self, ts: TaskState) -> None: assert ts.key not in self.data assert self.address not in ts.who_has assert not ts.done @@ -2603,40 +2875,40 @@ def _validate_task_fetch(self, ts): assert ts.key in self.has_what[w] assert ts in self.data_needed_per_worker[w] - def _validate_task_missing(self, ts): + def _validate_task_missing(self, ts: TaskState) -> None: assert ts.key not in self.data assert not ts.who_has assert not ts.done assert not any(ts.key in has_what for has_what in self.has_what.values()) - assert ts in self._missing_dep_flight + assert ts in self.missing_dep_flight - def _validate_task_cancelled(self, ts): + def _validate_task_cancelled(self, ts: TaskState) -> None: assert ts.key not in self.data assert ts._previous in {"long-running", "executing", "flight"} # We'll always transition to released after it is done assert ts._next is None, (ts.key, ts._next, self.story(ts)) - def _validate_task_resumed(self, ts): + def _validate_task_resumed(self, ts: TaskState) -> None: assert ts.key not in self.data assert ts._next assert ts._previous in {"long-running", "executing", "flight"} - def _validate_task_released(self, ts): + def _validate_task_released(self, ts: TaskState) -> None: assert ts.key not in self.data assert not ts._next assert not ts._previous assert ts not in self.data_needed for tss in self.data_needed_per_worker.values(): assert ts not in tss - assert ts not in self._executing - assert ts not in self._in_flight_tasks - assert ts not in self._missing_dep_flight - - # FIXME the below assert statement is true most of the time. If a task - # performs the transition flight->cancel->waiting, its dependencies are - # normally in released state. However, the compute-task call for their - # previous dependent provided them with who_has, such that this assert - # is no longer true. + assert ts not in self.executing + assert ts not in self.in_flight_tasks + assert ts not in self.missing_dep_flight + + # The below assert statement is true most of the time. If a task performs the + # transition flight->cancel->waiting, its dependencies are normally in released + # state. However, the compute-task call for their previous dependent provided + # them with who_has, such that this assert is no longer true. + # # assert not any(ts.key in has_what for has_what in self.has_what.values()) assert not ts.waiting_for_data @@ -2644,7 +2916,7 @@ def _validate_task_released(self, ts): assert not ts.exception assert not ts.traceback - def validate_task(self, ts): + def validate_task(self, ts: TaskState) -> None: try: if ts.key in self.tasks: assert self.tasks[ts.key] == ts @@ -2670,92 +2942,80 @@ def validate_task(self, ts): self._validate_task_released(ts) except Exception as e: logger.exception(e) - if LOG_PDB: - import pdb - - pdb.set_trace() - raise InvalidTaskState( key=ts.key, state=ts.state, story=self.story(ts) ) from e - def validate_state(self): - try: - assert self.executing_count >= 0 - waiting_for_data_count = 0 - for ts in self.tasks.values(): - assert ts.state is not None - # check that worker has task - for worker in ts.who_has: - assert worker != self.address - assert ts.key in self.has_what[worker] - # check that deps have a set state and that dependency<->dependent links - # are there - for dep in ts.dependencies: - # self.tasks was just a dict of tasks - # and this check was originally that the key was in `task_state` - # so we may have popped the key out of `self.tasks` but the - # dependency can still be in `memory` before GC grabs it...? - # Might need better bookkeeping - assert dep.state is not None - assert ts in dep.dependents, ts - if ts.waiting_for_data: - waiting_for_data_count += 1 - for ts_wait in ts.waiting_for_data: - assert ts_wait.key in self.tasks - assert ( - ts_wait.state - in READY | {"executing", "flight", "fetch", "missing"} - or ts_wait in self._missing_dep_flight - or ts_wait.who_has.issubset(self.in_flight_workers) - ), (ts, ts_wait, self.story(ts), self.story(ts_wait)) - # FIXME https://github.com/dask/distributed/issues/6319 - # assert self.waiting_for_data_count == waiting_for_data_count - for worker, keys in self.has_what.items(): + def validate_state(self) -> None: + assert len(self.executing) >= 0 + waiting_for_data_count = 0 + for ts in self.tasks.values(): + assert ts.state is not None + # check that worker has task + for worker in ts.who_has: assert worker != self.address - for k in keys: - assert k in self.tasks, self.story(k) - assert worker in self.tasks[k].who_has - - for ts in self.data_needed: - assert ts.state == "fetch", self.story(ts) + assert ts.key in self.has_what[worker] + # check that deps have a set state and that dependency<->dependent links + # are there + for dep in ts.dependencies: + # self.tasks was just a dict of tasks + # and this check was originally that the key was in `task_state` + # so we may have popped the key out of `self.tasks` but the + # dependency can still be in `memory` before GC grabs it...? + # Might need better bookkeeping + assert dep.state is not None + assert ts in dep.dependents, ts + if ts.waiting_for_data: + waiting_for_data_count += 1 + for ts_wait in ts.waiting_for_data: + assert ts_wait.key in self.tasks + assert ( + ts_wait.state in READY | {"executing", "flight", "fetch", "missing"} + or ts_wait in self.missing_dep_flight + or ts_wait.who_has.issubset(self.in_flight_workers) + ), (ts, ts_wait, self.story(ts), self.story(ts_wait)) + # FIXME https://github.com/dask/distributed/issues/6319 + # assert self.waiting_for_data_count == waiting_for_data_count + for worker, keys in self.has_what.items(): + assert worker != self.address + for k in keys: + assert k in self.tasks, self.story(k) + assert worker in self.tasks[k].who_has + + for ts in self.data_needed: + assert ts.state == "fetch", self.story(ts) + assert self.tasks[ts.key] is ts + for worker, tss in self.data_needed_per_worker.items(): + for ts in tss: + assert ts.state == "fetch" assert self.tasks[ts.key] is ts - for worker, tss in self.data_needed_per_worker.items(): - for ts in tss: - assert ts.state == "fetch" - assert self.tasks[ts.key] is ts - assert ts in self.data_needed - assert worker in ts.who_has - - for ts in self.tasks.values(): - self.validate_task(ts) - - if self.transition_counter_max: - assert self.transition_counter < self.transition_counter_max - - except Exception as e: - logger.error("Validate state failed", exc_info=e) - logger.exception(e) - if LOG_PDB: - import pdb - - pdb.set_trace() + assert ts in self.data_needed + assert worker in ts.who_has - if hasattr(e, "to_event"): - topic, msg = e.to_event() - self.log_event(topic, msg) + for ts in self.tasks.values(): + self.validate_task(ts) - raise + if self.transition_counter_max: + assert self.transition_counter < self.transition_counter_max class BaseWorker(abc.ABC): + """Wrapper around the :class:`WorkerState` that implements instructions handling. + This is an abstract class with several ``@abc.abstractmethod`` methods, to be + subclassed by :class:`~distributed.worker.Worker` and by unit test mock-ups. + """ + + state: WorkerState _async_instructions: set[asyncio.Task] - @fail_hard - @log_errors + def __init__(self, state: WorkerState): + self.state = state + self._async_instructions = set() + def _handle_stimulus_from_task( self, task: asyncio.Task[StateMachineEvent | None] ) -> None: + """An asynchronous instruction just completed; process the returned stimulus.""" self._async_instructions.remove(task) try: # This *should* never raise any other exceptions @@ -2765,8 +3025,19 @@ def _handle_stimulus_from_task( if stim: self.handle_stimulus(stim) - @fail_hard - def _handle_instructions(self, instructions: Instructions) -> None: + def handle_stimulus(self, stim: StateMachineEvent) -> None: + """Forward an external stimulus to :meth:`WorkerState.handle_stimulus` and + process the returned instructions, invoking the relevant Worker callbacks + (``@abc.abstractmethod`` methods below). + + Spawn asyncio tasks for all asynchronous instructions and start tracking them. + + See also + -------- + WorkerState.handle_stimulus + """ + instructions = self.state.handle_stimulus(stim) + while instructions: ensure_communicating: EnsureCommunicatingAfterTransitions | None = None for inst in instructions: @@ -2824,9 +3095,103 @@ def _handle_instructions(self, instructions: Instructions) -> None: if ensure_communicating: # Potentially re-fill instructions, causing a second iteration of `while # instructions` at the top of this method - recs, instructions = self._ensure_communicating( + # FIXME access to private methods + # https://github.com/dask/distributed/issues/6497 + recs, instructions = self.state._ensure_communicating( stimulus_id=ensure_communicating.stimulus_id ) - self._transitions(recs, stimulus_id=ensure_communicating.stimulus_id) + instructions += self.state._transitions( + recs, stimulus_id=ensure_communicating.stimulus_id + ) else: - instructions = [] + return + + async def close(self, timeout: float = 30) -> None: + """Cancel all asynchronous instructions""" + if not self._async_instructions: + return + for task in self._async_instructions: + task.cancel() + # async tasks can handle cancellation and could take an arbitrary amount + # of time to terminate + _, pending = await asyncio.wait(self._async_instructions, timeout=timeout) + for task in pending: + logger.error( + f"Failed to cancel asyncio task after {timeout} seconds: {task}" + ) + + @abc.abstractmethod + def batched_send(self, msg: dict[str, Any]) -> None: + """Send a fire-and-forget message to the scheduler through bulk comms. + + Parameters + ---------- + msg: dict + msgpack-serializable message to send to the scheduler. + Must have a 'op' key which is registered in Scheduler.stream_handlers. + """ + ... + + @abc.abstractmethod + async def gather_dep( + self, + worker: str, + to_gather: Collection[str], + total_nbytes: int, + *, + stimulus_id: str, + ) -> StateMachineEvent | None: + """Gather dependencies for a task from a worker who has them + + Parameters + ---------- + worker : str + Address of worker to gather dependencies from + to_gather : list + Keys of dependencies to gather from worker -- this is not + necessarily equivalent to the full list of dependencies of ``dep`` + as some dependencies may already be present on this worker. + total_nbytes : int + Total number of bytes for all the dependencies in to_gather combined + """ + ... + + @abc.abstractmethod + async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | None: + """Execute a task""" + ... + + @abc.abstractmethod + async def retry_busy_worker_later(self, worker: str) -> StateMachineEvent | None: + """Wait some time, then take a peer worker out of busy state""" + ... + + +class DeprecatedWorkerStateAttribute: + name: str + target: str | None + + def __init__(self, target: str | None = None): + self.target = target + + def __set_name__(self, owner: type, name: str) -> None: + self.name = name + + def _warn_deprecated(self) -> None: + pass + # warnings.warn( + # f"The `Worker.{self.name}` attribute has been moved to " + # f"`Worker.state.{self.target or self.name}", + # FutureWarning, + # ) + + def __get__(self, instance: Worker | None, _): + if instance is None: + # This is triggered by Sphinx + return None # pragma: nocover + self._warn_deprecated() + return getattr(instance.state, self.target or self.name) + + def __set__(self, instance: Worker, value) -> None: + self._warn_deprecated() + setattr(instance.state, self.target or self.name, value) From 82bcf8b6c2cd1bbbbd3466c10ec43f40492f379f Mon Sep 17 00:00:00 2001 From: crusaderky Date: Mon, 13 Jun 2022 14:29:34 +0100 Subject: [PATCH 3/3] Everything else --- distributed/active_memory_manager.py | 3 +- distributed/diagnostics/plugin.py | 2 +- distributed/node.py | 7 +- distributed/scheduler.py | 5 +- distributed/shuffle/shuffle_extension.py | 2 +- .../tests/test_active_memory_manager.py | 6 +- distributed/tests/test_client.py | 3 +- distributed/tests/test_utils_test.py | 25 +++---- distributed/tests/test_worker.py | 42 ++++++----- .../tests/test_worker_state_machine.py | 69 +++++++++++++++++++ distributed/utils_test.py | 21 +++--- distributed/worker_memory.py | 11 ++- docs/source/worker.rst | 10 +++ 13 files changed, 146 insertions(+), 60 deletions(-) diff --git a/distributed/active_memory_manager.py b/distributed/active_memory_manager.py index 6af1fbe3be..7fef5c85ee 100644 --- a/distributed/active_memory_manager.py +++ b/distributed/active_memory_manager.py @@ -416,8 +416,9 @@ def run( ) -> SuggestionGenerator: """This method is invoked by the ActiveMemoryManager every few seconds, or whenever the user invokes ``client.amm.run_once``. + It is an iterator that must emit - :class:`~distributed.active_memory_manager.Suggestion`s: + :class:`~distributed.active_memory_manager.Suggestion` objects: - ``Suggestion("replicate", )`` - ``Suggestion("replicate", , {subset of potential workers to replicate to})`` diff --git a/distributed/diagnostics/plugin.py b/distributed/diagnostics/plugin.py index 2f431da3ae..bec5f72c7a 100644 --- a/distributed/diagnostics/plugin.py +++ b/distributed/diagnostics/plugin.py @@ -334,7 +334,7 @@ def __init__(self, filepath): async def setup(self, worker): response = await worker.upload_file( - comm=None, filename=self.filename, data=self.data, load=True + filename=self.filename, data=self.data, load=True ) assert len(self.data) == response["nbytes"] diff --git a/distributed/node.py b/distributed/node.py index 6fedd1b8ac..922125d887 100644 --- a/distributed/node.py +++ b/distributed/node.py @@ -77,15 +77,16 @@ def stop_services(self): def service_ports(self): return {k: v.port for k, v in self.services.items()} - def _setup_logging(self, logger): + def _setup_logging(self, *loggers): self._deque_handler = DequeHandler( n=dask.config.get("distributed.admin.log-length") ) self._deque_handler.setFormatter( logging.Formatter(dask.config.get("distributed.admin.log-format")) ) - logger.addHandler(self._deque_handler) - weakref.finalize(self, logger.removeHandler, self._deque_handler) + for logger in loggers: + logger.addHandler(self._deque_handler) + weakref.finalize(self, logger.removeHandler, self._deque_handler) def get_logs(self, start=0, n=None, timestamps=False): """ diff --git a/distributed/scheduler.py b/distributed/scheduler.py index a51144d5b2..26c5146d38 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -347,7 +347,10 @@ def _to_dict(self, *, exclude: Container[str] = ()) -> dict: class WorkerState: - """A simple object holding information about a worker.""" + """A simple object holding information about a worker. + + Not to be confused with :class:`distributed.worker_state_machine.WorkerState`. + """ #: This worker's unique key. This can be its connected address #: (such as ``"tcp://127.0.0.1:8891"``) or an alias (such as ``"alice"``). diff --git a/distributed/shuffle/shuffle_extension.py b/distributed/shuffle/shuffle_extension.py index dbf6946001..3f26c595a8 100644 --- a/distributed/shuffle/shuffle_extension.py +++ b/distributed/shuffle/shuffle_extension.py @@ -230,7 +230,7 @@ def __init__(self, worker: Worker) -> None: # Initialize self.worker: Worker = worker self.shuffles: dict[ShuffleId, Shuffle] = {} - self.executor = ThreadPoolExecutor(worker.nthreads) + self.executor = ThreadPoolExecutor(worker.state.nthreads) # Handlers ########## diff --git a/distributed/tests/test_active_memory_manager.py b/distributed/tests/test_active_memory_manager.py index 8fcc3f31ce..27768854bc 100644 --- a/distributed/tests/test_active_memory_manager.py +++ b/distributed/tests/test_active_memory_manager.py @@ -866,8 +866,8 @@ async def test_RetireWorker_no_recipients(c, s, w1, w2, w3, w4): assert set(out) in ({w1.address, w3.address}, {w1.address, w4.address}) assert not s.extensions["amm"].policies assert set(s.workers) in ({w2.address, w3.address}, {w2.address, w4.address}) - # After a Scheduler -> Worker -> WorkerState roundtrip, workers that failed to - # retire went back from closing_gracefully to running and can run tasks + # After a Scheduler -> Worker -> Scheduler roundtrip, workers that failed to retire + # went back from closing_gracefully to running and can run tasks while any(ws.status != Status.running for ws in s.workers.values()): await asyncio.sleep(0.01) assert await c.submit(inc, 1) == 2 @@ -896,7 +896,7 @@ async def test_RetireWorker_all_recipients_are_paused(c, s, a, b): assert not s.extensions["amm"].policies assert set(s.workers) == {a.address, b.address} - # After a Scheduler -> Worker -> WorkerState roundtrip, workers that failed to + # After a Scheduler -> Worker -> Scheduler roundtrip, workers that failed to # retire went back from closing_gracefully to running and can run tasks while ws_a.status != Status.running: await asyncio.sleep(0.01) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 28d490f15b..f1d5dba234 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -1612,6 +1612,7 @@ def g(): os.remove("myfile.zip") +@pytest.mark.slow @gen_cluster(client=True) async def test_upload_file_egg(c, s, a, b): pytest.importorskip("setuptools") @@ -6810,7 +6811,7 @@ async def test_workers_collection_restriction(c, s, a, b): assert a.data and not b.data -@gen_cluster(client=True, nthreads=[("127.0.0.1", 0)]) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)]) async def test_get_client_functions_spawn_clusters(c, s, a): # see gh4565 diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py index 585fbad081..78523561d9 100755 --- a/distributed/tests/test_utils_test.py +++ b/distributed/tests/test_utils_test.py @@ -8,6 +8,7 @@ import threading from contextlib import contextmanager from time import sleep +from unittest import mock import pytest import yaml @@ -44,7 +45,8 @@ from distributed.worker_state_machine import ( InvalidTaskState, InvalidTransition, - StateMachineEvent, + PauseEvent, + WorkerState, ) @@ -656,22 +658,17 @@ def test_start_failure_scheduler(): def test_invalid_transitions(capsys): - class BrokenEvent(StateMachineEvent): - pass - - class MyWorker(Worker): - @Worker._handle_event.register - def _(self, ev: BrokenEvent): - ts = next(iter(self.tasks.values())) - return {ts: "foo"}, [] - - @gen_cluster(client=True, Worker=MyWorker, nthreads=[("", 1)]) + @gen_cluster(client=True, nthreads=[("", 1)]) async def test_log_invalid_transitions(c, s, a): x = c.submit(inc, 1, key="task-name") await x - - with pytest.raises(InvalidTransition): - a.handle_stimulus(BrokenEvent(stimulus_id="test")) + ts = a.tasks["task-name"] + ev = PauseEvent(stimulus_id="test") + with mock.patch.object( + WorkerState, "_handle_event", return_value=({ts: "foo"}, []) + ): + with pytest.raises(InvalidTransition): + a.handle_stimulus(ev) while not s.events["invalid-worker-transition"]: await asyncio.sleep(0.01) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 32cee6209b..0f1c29d95f 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -1557,7 +1557,9 @@ async def f(ev): task for task in asyncio.all_tasks() if "execute(f1)" in task.get_name() ) start = time() - with captured_logger("distributed.worker", level=logging.ERROR) as logger: + with captured_logger( + "distributed.worker_state_machine", level=logging.ERROR + ) as logger: await a.close(timeout=1) assert "Failed to cancel asyncio task" in logger.getvalue() assert time() - start < 5 @@ -2030,7 +2032,7 @@ async def test_gather_dep_from_remote_workers_if_all_local_workers_are_busy( assert_story(a.story("receive-dep"), [("receive-dep", rw.address, {"f"})]) -@gen_cluster(client=True, nthreads=[("127.0.0.1", 0)]) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)]) async def test_worker_client_uses_default_no_close(c, s, a): """ If a default client is available in the process, the worker will pick this @@ -2057,7 +2059,7 @@ def get_worker_client_id(): assert c is c_def -@gen_cluster(nthreads=[("127.0.0.1", 0)]) +@gen_cluster(nthreads=[("127.0.0.1", 1)]) async def test_worker_client_closes_if_created_on_worker_one_worker(s, a): async with Client(s.address, set_as_default=False, asynchronous=True) as c: with pytest.raises(ValueError): @@ -2542,7 +2544,7 @@ def raise_exc(*args): await asyncio.sleep(0.01) -@gen_cluster(client=True, nthreads=[("127.0.0.1", x) for x in range(4)]) +@gen_cluster(client=True, nthreads=[("", x) for x in (1, 2, 3, 4)]) async def test_hold_on_to_replicas(c, s, *workers): f1 = c.submit(inc, 1, workers=[workers[0].address], key="f1") f2 = c.submit(inc, 2, workers=[workers[1].address], key="f2") @@ -3283,14 +3285,28 @@ async def test_Worker__to_dict(c, s, a): "type", "id", "scheduler", - "nthreads", "address", "status", "thread_id", + "logs", + "config", + "incoming_transfer_log", + "outgoing_transfer_log", + # Attributes of WorkerMemoryManager + "data", + "max_spill", + "memory_limit", + "memory_monitor_interval", + "memory_pause_fraction", + "memory_spill_fraction", + "memory_target_fraction", + # Attributes of WorkerState + "nthreads", + "running", "ready", "constrained", + "executing", "long_running", - "executing_count", "in_flight_tasks", "in_flight_workers", "busy_workers", @@ -3298,23 +3314,11 @@ async def test_Worker__to_dict(c, s, a): "stimulus_log", "transition_counter", "tasks", - "logs", - "config", - "incoming_transfer_log", - "outgoing_transfer_log", "data_needed", "data_needed_per_worker", - # attributes of WorkerMemoryManager - "data", - "max_spill", - "memory_limit", - "memory_monitor_interval", - "memory_pause_fraction", - "memory_spill_fraction", - "memory_target_fraction", } assert d["tasks"]["x"]["key"] == "x" - assert d["data"] == ["x"] + assert d["data"] == {"x": None} @gen_cluster(nthreads=[]) diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py index 4fa937581c..6172083408 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -32,6 +32,7 @@ TaskState, TaskStateState, UpdateDataEvent, + WorkerState, merge_recs_instructions, ) @@ -72,6 +73,74 @@ def test_TaskState__to_dict(): ] +def test_WorkerState__to_dict(): + ws = WorkerState(8) + ws.address = "127.0.0.1.1234" + ws.handle_stimulus( + AcquireReplicasEvent(who_has={"x": ["127.0.0.1:1235"]}, stimulus_id="s1") + ) + ws.handle_stimulus( + UpdateDataEvent(data={"y": object()}, report=False, stimulus_id="s2") + ) + + actual = recursive_to_dict(ws) + # Remove timestamps + for ev in actual["log"]: + del ev[-1] + for stim in actual["stimulus_log"]: + del stim["handled"] + + expect = { + "address": "127.0.0.1.1234", + "busy_workers": [], + "constrained": [], + "data": {"y": None}, + "data_needed": ["x"], + "data_needed_per_worker": {"127.0.0.1:1235": ["x"]}, + "executing": [], + "in_flight_tasks": [], + "in_flight_workers": {}, + "log": [ + ["x", "ensure-task-exists", "released", "s1"], + ["x", "released", "fetch", "fetch", {}, "s1"], + ["y", "put-in-memory", "s2"], + ["y", "receive-from-scatter", "s2"], + ], + "long_running": [], + "nthreads": 8, + "ready": [], + "running": True, + "stimulus_log": [ + { + "cls": "AcquireReplicasEvent", + "stimulus_id": "s1", + "who_has": {"x": ["127.0.0.1:1235"]}, + }, + { + "cls": "UpdateDataEvent", + "data": {"y": None}, + "report": False, + "stimulus_id": "s2", + }, + ], + "tasks": { + "x": { + "key": "x", + "priority": [1], + "state": "fetch", + "who_has": ["127.0.0.1:1235"], + }, + "y": { + "key": "y", + "nbytes": 16, + "state": "memory", + }, + }, + "transition_counter": 1, + } + assert actual == expect + + def traverse_subclasses(cls: type) -> Iterator[type]: yield cls for subcls in cls.__subclasses__(): diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 3495657128..42e90de0eb 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -70,7 +70,8 @@ reset_logger_locks, sync, ) -from distributed.worker import WORKER_ANY_RUNNING, InvalidTransition, Worker +from distributed.worker import WORKER_ANY_RUNNING, Worker +from distributed.worker_state_machine import InvalidTransition try: import ssl @@ -1271,8 +1272,10 @@ def validate_state(*servers: Scheduler | Worker | Nanny) -> None: Excludes workers wrapped by Nannies and workers manually started by the test. """ for s in servers: - if s.validate and hasattr(s, "validate_state"): - s.validate_state() # type: ignore + if isinstance(s, Scheduler) and s.validate: + s.validate_state() + elif isinstance(s, Worker) and s.state.validate: + s.validate_state() def raises(func, exc=Exception): @@ -2322,13 +2325,13 @@ def freeze_data_fetching(w: Worker, *, jump_start: bool = False): If True, trigger ensure_communicating on exit; this simulates e.g. an unrelated worker moving out of in_flight_workers. """ - old_out_connections = w.total_out_connections - old_comm_threshold = w.comm_threshold_bytes - w.total_out_connections = 0 - w.comm_threshold_bytes = 0 + old_out_connections = w.state.total_out_connections + old_comm_threshold = w.state.comm_threshold_bytes + w.state.total_out_connections = 0 + w.state.comm_threshold_bytes = 0 yield - w.total_out_connections = old_out_connections - w.comm_threshold_bytes = old_comm_threshold + w.state.total_out_connections = old_out_connections + w.state.comm_threshold_bytes = old_comm_threshold if jump_start: w.status = Status.paused w.status = Status.running diff --git a/distributed/worker_memory.py b/distributed/worker_memory.py index 5132afb2a3..e3aaad21b2 100644 --- a/distributed/worker_memory.py +++ b/distributed/worker_memory.py @@ -68,6 +68,7 @@ def __init__( self, worker: Worker, *, + nthreads: int, memory_limit: str | float = "auto", # This should be None most of the times, short of a power user replacing the # SpillBuffer with their own custom dict-like @@ -84,7 +85,7 @@ def __init__( memory_spill_fraction: float | Literal[False] | None = None, memory_pause_fraction: float | Literal[False] | None = None, ): - self.memory_limit = parse_memory_limit(memory_limit, worker.nthreads) + self.memory_limit = parse_memory_limit(memory_limit, nthreads) self.memory_target_fraction = _parse_threshold( "distributed.worker.memory.target", @@ -293,12 +294,8 @@ async def _maybe_spill(self, worker: Worker, memory: int) -> None: ) def _to_dict(self, *, exclude: Container[str] = ()) -> dict: - info = { - k: v - for k, v in self.__dict__.items() - if not k.startswith("_") and k != "data" and k not in exclude - } - info["data"] = list(self.data) + info = {k: v for k, v in self.__dict__.items() if not k.startswith("_")} + info["data"] = dict.fromkeys(self.data) return info diff --git a/docs/source/worker.rst b/docs/source/worker.rst index 91cac09947..b3d9f8c0cb 100644 --- a/docs/source/worker.rst +++ b/docs/source/worker.rst @@ -162,8 +162,18 @@ process. API Documentation ----------------- +.. currentmodule:: distributed.worker_state_machine + .. autoclass:: distributed.worker_state_machine.TaskState :members: +.. autoclass:: distributed.worker_state_machine.WorkerState + :members: + +.. autoclass:: distributed.worker_state_machine.BaseWorker + :members: + +.. currentmodule:: distributed.worker + .. autoclass:: distributed.worker.Worker :members: