diff --git a/distributed/stories.py b/distributed/_stories.py similarity index 100% rename from distributed/stories.py rename to distributed/_stories.py diff --git a/distributed/client.py b/distributed/client.py index 9ca8c0b2d2..af9bb32567 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -4276,6 +4276,31 @@ def collections_to_dsk(collections, *args, **kwargs): """Convert many collections into a single dask graph, after optimization""" return collections_to_dsk(collections, *args, **kwargs) + async def _story(self, keys=(), on_error="raise"): + assert on_error in ("raise", "ignore") + + try: + flat_stories = await self.scheduler.get_story(keys=keys) + flat_stories = [("scheduler", *msg) for msg in flat_stories] + except Exception: + if on_error == "raise": + raise + elif on_error == "ignore": + flat_stories = [] + else: + raise ValueError(f"on_error not in {'raise', 'ignore'}") + + responses = await self.scheduler.broadcast( + msg={"op": "get_story", "keys": keys}, on_error=on_error + ) + for worker, stories in responses.items(): + flat_stories.extend((worker, *msg) for msg in stories) + return flat_stories + + def story(self, *keys_or_stimulus_ids, on_error="raise"): + """Returns a cluster-wide story for the given keys or stimulus_ids""" + return self.sync(self._story, keys=keys_or_stimulus_ids, on_error=on_error) + def get_task_stream( self, start=None, diff --git a/distributed/cluster_dump.py b/distributed/cluster_dump.py index 161f9091e7..5c77dc14ac 100644 --- a/distributed/cluster_dump.py +++ b/distributed/cluster_dump.py @@ -10,9 +10,9 @@ import fsspec import msgpack +from distributed._stories import scheduler_story as _scheduler_story +from distributed._stories import worker_story as _worker_story from distributed.compatibility import to_thread -from distributed.stories import scheduler_story as _scheduler_story -from distributed.stories import worker_story as _worker_story DEFAULT_CLUSTER_DUMP_FORMAT: Literal["msgpack" | "yaml"] = "msgpack" DEFAULT_CLUSTER_DUMP_EXCLUDE: Collection[str] = ("run_spec",) diff --git a/distributed/http/templates/task.html b/distributed/http/templates/task.html index 0b5c10695e..f10aaad560 100644 --- a/distributed/http/templates/task.html +++ b/distributed/http/templates/task.html @@ -118,16 +118,18 @@

Transition Log

Key Start Finish + Stimulus ID Recommended Key Recommended Action - {% for key, start, finish, recommendations, transition_time in scheduler.story(Task) %} + {% for key, start, finish, recommendations, stimulus_id, transition_time in scheduler.story(Task) %} {{ fromtimestamp(transition_time) }} {{key}} {{ start }} {{ finish }} + {{ stimulus_id }} @@ -137,6 +139,7 @@

Transition Log

+ {{key2}} {{ rec }} diff --git a/distributed/nanny.py b/distributed/nanny.py index 826c3766e1..20778a5197 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -289,7 +289,10 @@ async def _unregister(self, timeout=10): allowed_errors = (TimeoutError, CommClosedError, EnvironmentError, RPCClosed) with suppress(allowed_errors): await asyncio.wait_for( - self.scheduler.unregister(address=self.worker_address), timeout + self.scheduler.unregister( + address=self.worker_address, stimulus_id=f"nanny-close-{time()}" + ), + timeout, ) @property diff --git a/distributed/scheduler.py b/distributed/scheduler.py index f65e34aeea..151ebaf7c8 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -54,6 +54,7 @@ from distributed import cluster_dump, preloading, profile from distributed import versions as version_module +from distributed._stories import scheduler_story from distributed.active_memory_manager import ActiveMemoryManagerExtension, RetireWorker from distributed.batched import BatchedSend from distributed.comm import ( @@ -83,7 +84,6 @@ from distributed.security import Security from distributed.semaphore import SemaphoreExtension from distributed.stealing import WorkStealing -from distributed.stories import scheduler_story from distributed.utils import ( All, TimeoutError, @@ -109,6 +109,7 @@ DEFAULT_DATA_SIZE = parse_bytes( dask.config.get("distributed.scheduler.default-data-size") ) +STIMULUS_ID_UNSET = "" DEFAULT_EXTENSIONS = { "locks": LockExtension, @@ -1527,7 +1528,7 @@ def new_task( # State Transitions # ##################### - def _transition(self, key, finish: str, *args, **kwargs): + def _transition(self, key, finish: str, stimulus_id: str, *args, **kwargs): """Transition a key from its current state to the finish state Examples @@ -1562,14 +1563,16 @@ def _transition(self, key, finish: str, *args, **kwargs): start_finish = (start, finish) func = self.transitions_table.get(start_finish) if func is not None: - recommendations, client_msgs, worker_msgs = func(key, *args, **kwargs) # type: ignore + recommendations, client_msgs, worker_msgs = func( + key, stimulus_id, *args, **kwargs + ) # type: ignore self.transition_counter += 1 elif "released" not in start_finish: assert not args and not kwargs, (args, kwargs, start_finish) a_recs: dict a_cmsgs: dict a_wmsgs: dict - a: tuple = self._transition(key, "released") + a: tuple = self._transition(key, "released", stimulus_id) a_recs, a_cmsgs, a_wmsgs = a v = a_recs.get(key, finish) @@ -1577,7 +1580,7 @@ def _transition(self, key, finish: str, *args, **kwargs): b_recs: dict b_cmsgs: dict b_wmsgs: dict - b: tuple = func(key) # type: ignore + b: tuple = func(key, stimulus_id) # type: ignore b_recs, b_cmsgs, b_wmsgs = b recommendations.update(a_recs) @@ -1612,13 +1615,20 @@ def _transition(self, key, finish: str, *args, **kwargs): else: raise RuntimeError("Impossible transition from %r to %r" % start_finish) + if not stimulus_id: + stimulus_id = STIMULUS_ID_UNSET + finish2 = ts._state # FIXME downcast antipattern scheduler = cast(Scheduler, self) scheduler.transition_log.append( - (key, start, finish2, recommendations, time()) + (key, start, finish2, recommendations, stimulus_id, time()) ) if self.validate: + if stimulus_id == STIMULUS_ID_UNSET: + raise RuntimeError( + "stimulus_id not set during Scheduler transition" + ) logger.debug( "Transitioned %r %s->%s (actual: %s). 
Consequence: %s", key, @@ -1662,7 +1672,13 @@ def _transition(self, key, finish: str, *args, **kwargs): pdb.set_trace() raise - def _transitions(self, recommendations: dict, client_msgs: dict, worker_msgs: dict): + def _transitions( + self, + recommendations: dict, + client_msgs: dict, + worker_msgs: dict, + stimulus_id: str, + ): """Process transitions until none are left This includes feedback from previous transitions and continues until we @@ -1680,7 +1696,7 @@ def _transitions(self, recommendations: dict, client_msgs: dict, worker_msgs: di key, finish = recommendations.popitem() keys.add(key) - new = self._transition(key, finish) + new = self._transition(key, finish, stimulus_id) new_recs, new_cmsgs, new_wmsgs = new recommendations.update(new_recs) @@ -1703,7 +1719,7 @@ def _transitions(self, recommendations: dict, client_msgs: dict, worker_msgs: di for key in keys: scheduler.validate_key(key) - def transition_released_waiting(self, key): + def transition_released_waiting(self, key, stimulus_id): try: ts: TaskState = self.tasks[key] dts: TaskState @@ -1758,7 +1774,7 @@ def transition_released_waiting(self, key): pdb.set_trace() raise - def transition_no_worker_waiting(self, key): + def transition_no_worker_waiting(self, key, stimulus_id): try: ts: TaskState = self.tasks[key] dts: TaskState @@ -1806,7 +1822,13 @@ def transition_no_worker_waiting(self, key): raise def transition_no_worker_memory( - self, key, nbytes=None, type=None, typename: str = None, worker=None + self, + key, + stimulus_id, + nbytes=None, + type=None, + typename: str = None, + worker=None, ): try: ws: WorkerState = self.workers[worker] @@ -1961,7 +1983,7 @@ def set_duration_estimate(self, ts: TaskState, ws: WorkerState) -> float: return total_duration - def transition_waiting_processing(self, key): + def transition_waiting_processing(self, key, stimulus_id): try: ts: TaskState = self.tasks[key] dts: TaskState @@ -2006,7 +2028,14 @@ def transition_waiting_processing(self, key): raise def transition_waiting_memory( - self, key, nbytes=None, type=None, typename: str = None, worker=None, **kwargs + self, + key, + stimulus_id, + nbytes=None, + type=None, + typename: str = None, + worker=None, + **kwargs, ): try: ws: WorkerState = self.workers[worker] @@ -2048,6 +2077,7 @@ def transition_waiting_memory( def transition_processing_memory( self, key, + stimulus_id, nbytes=None, type=None, typename: str = None, @@ -2091,7 +2121,7 @@ def transition_processing_memory( { "op": "cancel-compute", "key": key, - "stimulus_id": f"processing-memory-{time()}", + "stimulus_id": stimulus_id, } ] @@ -2141,7 +2171,7 @@ def transition_processing_memory( pdb.set_trace() raise - def transition_memory_released(self, key, safe: bool = False): + def transition_memory_released(self, key, stimulus_id, safe: bool = False): ws: WorkerState try: ts: TaskState = self.tasks[key] @@ -2179,7 +2209,7 @@ def transition_memory_released(self, key, safe: bool = False): worker_msg = { "op": "free-keys", "keys": [key], - "stimulus_id": f"memory-released-{time()}", + "stimulus_id": stimulus_id, } for ws in ts.who_has: worker_msgs[ws.address] = [worker_msg] @@ -2211,7 +2241,7 @@ def transition_memory_released(self, key, safe: bool = False): pdb.set_trace() raise - def transition_released_erred(self, key): + def transition_released_erred(self, key, stimulus_id): try: ts: TaskState = self.tasks[key] dts: TaskState @@ -2256,7 +2286,7 @@ def transition_released_erred(self, key): pdb.set_trace() raise - def transition_erred_released(self, key): + def 
transition_erred_released(self, key, stimulus_id): try: ts: TaskState = self.tasks[key] dts: TaskState @@ -2282,7 +2312,7 @@ def transition_erred_released(self, key): w_msg = { "op": "free-keys", "keys": [key], - "stimulus_id": f"erred-released-{time()}", + "stimulus_id": stimulus_id, } for ws_addr in ts.erred_on: worker_msgs[ws_addr] = [w_msg] @@ -2304,7 +2334,7 @@ def transition_erred_released(self, key): pdb.set_trace() raise - def transition_waiting_released(self, key): + def transition_waiting_released(self, key, stimulus_id): try: ts: TaskState = self.tasks[key] recommendations: dict = {} @@ -2341,7 +2371,7 @@ def transition_waiting_released(self, key): pdb.set_trace() raise - def transition_processing_released(self, key): + def transition_processing_released(self, key, stimulus_id): try: ts: TaskState = self.tasks[key] dts: TaskState @@ -2361,7 +2391,7 @@ def transition_processing_released(self, key): { "op": "free-keys", "keys": [key], - "stimulus_id": f"processing-released-{time()}", + "stimulus_id": stimulus_id, } ] @@ -2395,6 +2425,7 @@ def transition_processing_released(self, key): def transition_processing_erred( self, key: str, + stimulus_id: str, cause: str = None, exception=None, traceback=None, @@ -2481,7 +2512,7 @@ def transition_processing_erred( pdb.set_trace() raise - def transition_no_worker_released(self, key): + def transition_no_worker_released(self, key, stimulus_id): try: ts: TaskState = self.tasks[key] dts: TaskState @@ -2523,7 +2554,7 @@ def remove_key(self, key): ts.exception_blame = ts.exception = ts.traceback = None self.task_metadata.pop(key, None) - def transition_memory_forgotten(self, key): + def transition_memory_forgotten(self, key, stimulus_id): ws: WorkerState try: ts: TaskState = self.tasks[key] @@ -2551,7 +2582,7 @@ def transition_memory_forgotten(self, key): for ws in ts.who_has: ws.actors.discard(ts) - _propagate_forgotten(self, ts, recommendations, worker_msgs) + _propagate_forgotten(self, ts, recommendations, worker_msgs, stimulus_id) client_msgs = _task_to_client_msgs(self, ts) self.remove_key(key) @@ -2565,7 +2596,7 @@ def transition_memory_forgotten(self, key): pdb.set_trace() raise - def transition_released_forgotten(self, key): + def transition_released_forgotten(self, key, stimulus_id): try: ts: TaskState = self.tasks[key] recommendations: dict = {} @@ -2589,7 +2620,7 @@ def transition_released_forgotten(self, key): else: assert 0, (ts,) - _propagate_forgotten(self, ts, recommendations, worker_msgs) + _propagate_forgotten(self, ts, recommendations, worker_msgs, stimulus_id) client_msgs = _task_to_client_msgs(self, ts) self.remove_key(key) @@ -3154,6 +3185,7 @@ def __init__( "get_cluster_state": self.get_cluster_state, "dump_cluster_state_to_url": self.dump_cluster_state_to_url, "benchmark_hardware": self.benchmark_hardware, + "get_story": self.get_story, } connection_limit = get_fileno_limit() / 2 @@ -3480,7 +3512,7 @@ async def close(self, fast=False, close_workers=False): disable_gc_diagnosis() @log_errors - async def close_worker(self, worker: str, safe: bool = False): + async def close_worker(self, worker: str, stimulus_id: str, safe: bool = False): """Remove a worker from the cluster This both removes the worker from our local state and also sends a @@ -3491,7 +3523,7 @@ async def close_worker(self, worker: str, safe: bool = False): self.log_event(worker, {"action": "close-worker"}) # FIXME: This does not handle nannies self.worker_send(worker, {"op": "close", "report": False}) - await self.remove_worker(address=worker, safe=safe) + 
await self.remove_worker(address=worker, safe=safe, stimulus_id=stimulus_id) ########### # Stimuli # @@ -3629,6 +3661,7 @@ async def add_worker( versions: dict[str, Any] | None = None, nanny=None, extra=None, + stimulus_id=None, ): """Add a new worker to the cluster""" address = self.coerce_address(address, resolve_address) @@ -3727,12 +3760,15 @@ async def add_worker( t: tuple = self._transition( key, "memory", + stimulus_id, worker=address, nbytes=nbytes[key], typename=types[key], ) recommendations, client_msgs, worker_msgs = t - self._transitions(recommendations, client_msgs, worker_msgs) + self._transitions( + recommendations, client_msgs, worker_msgs, stimulus_id + ) recommendations = {} else: already_released_keys.append(key) @@ -3743,7 +3779,7 @@ async def add_worker( { "op": "remove-replicas", "keys": already_released_keys, - "stimulus_id": f"reconnect-already-released-{time()}", + "stimulus_id": stimulus_id, } ) @@ -3751,7 +3787,7 @@ async def add_worker( recommendations.update(self.bulk_schedule_after_adding_worker(ws)) if recommendations: - self._transitions(recommendations, client_msgs, worker_msgs) + self._transitions(recommendations, client_msgs, worker_msgs, stimulus_id) self.send_all(client_msgs, worker_msgs) @@ -3779,7 +3815,7 @@ async def add_worker( if comm: await comm.write(msg) - await self.handle_worker(comm=comm, worker=address) + await self.handle_worker(comm=comm, worker=address, stimulus_id=stimulus_id) async def add_nanny(self, comm): msg = { @@ -3843,6 +3879,7 @@ def update_graph_hlg( fifo_timeout, annotations, code=code, + stimulus_id=f"update-graph-{time()}", ) def update_graph( @@ -3862,12 +3899,14 @@ def update_graph( fifo_timeout=0, annotations=None, code=None, + stimulus_id=None, ): """ Add new computations to the internal dask graph This happens whenever the Client calls submit, map, get, or compute. 
""" + stimulus_id = stimulus_id or f"update-graph-{time()}" start = time() fifo_timeout = parse_timedelta(fifo_timeout) keys = set(keys) @@ -3906,7 +3945,9 @@ def update_graph( if k in keys: keys.remove(k) self.report({"op": "cancelled-key", "key": k}, client=client) - self.client_releases_keys(keys=[k], client=client) + self.client_releases_keys( + keys=[k], client=client, stimulus_id=stimulus_id + ) # Avoid computation that is already finished already_in_memory = set() # tasks that are already done @@ -4121,7 +4162,7 @@ def update_graph( except Exception as e: logger.exception(e) - self.transitions(recommendations) + self.transitions(recommendations, stimulus_id) for ts in touched_tasks: if ts.state in ("memory", "erred"): @@ -4133,7 +4174,7 @@ def update_graph( # TODO: balance workers - def stimulus_task_finished(self, key=None, worker=None, **kwargs): + def stimulus_task_finished(self, key=None, worker=None, stimulus_id=None, **kwargs): """Mark that a task has finished execution on a particular worker""" logger.debug("Stimulus task finished %s, %s", key, worker) @@ -4156,14 +4197,16 @@ def stimulus_task_finished(self, key=None, worker=None, **kwargs): { "op": "free-keys", "keys": [key], - "stimulus_id": f"already-released-or-forgotten-{time()}", + "stimulus_id": stimulus_id, } ] elif ts.state == "memory": self.add_keys(worker=worker, keys=[key]) else: ts.metadata.update(kwargs["metadata"]) - r: tuple = self._transition(key, "memory", worker=worker, **kwargs) + r: tuple = self._transition( + key, "memory", stimulus_id, worker=worker, **kwargs + ) recommendations, client_msgs, worker_msgs = r if ts.state == "memory": @@ -4171,7 +4214,13 @@ def stimulus_task_finished(self, key=None, worker=None, **kwargs): return recommendations, client_msgs, worker_msgs def stimulus_task_erred( - self, key=None, worker=None, exception=None, traceback=None, **kwargs + self, + key=None, + worker=None, + exception=None, + stimulus_id=None, + traceback=None, + **kwargs, ): """Mark that a task has erred on a particular worker""" logger.debug("Stimulus task erred %s, %s", key, worker) @@ -4182,11 +4231,12 @@ def stimulus_task_erred( if ts.retries > 0: ts.retries -= 1 - return self._transition(key, "waiting") + return self._transition(key, "waiting", stimulus_id) else: return self._transition( key, "erred", + stimulus_id, cause=key, exception=exception, traceback=traceback, @@ -4215,7 +4265,7 @@ def stimulus_retry(self, keys, client=None): roots.append(key) recommendations: dict = {key: "waiting" for key in roots} - self.transitions(recommendations) + self.transitions(recommendations, f"stimulus-retry-{time()}") if self.validate: for key in seen: @@ -4224,7 +4274,7 @@ def stimulus_retry(self, keys, client=None): return tuple(seen) @log_errors - async def remove_worker(self, address, safe=False, close=True): + async def remove_worker(self, address, stimulus_id, safe=False, close=True): """ Remove worker from cluster @@ -4291,7 +4341,9 @@ async def remove_worker(self, address, safe=False, close=True): e = pickle.dumps( KilledWorker(task=k, last_worker=ws.clean()), protocol=4 ) - r = self.transition(k, "erred", exception=e, cause=k) + r = self.transition( + k, "erred", exception=e, cause=k, stimulus_id=stimulus_id + ) recommendations.update(r) logger.info( "Task %s marked as failed because %d workers died" @@ -4308,7 +4360,7 @@ async def remove_worker(self, address, safe=False, close=True): else: # pure data recommendations[ts.key] = "forgotten" - self.transitions(recommendations) + 
self.transitions(recommendations, stimulus_id=stimulus_id) for plugin in list(self.plugins.values()): try: @@ -4370,7 +4422,9 @@ def cancel_key(self, key, client, retries=5, force=False): self.report({"op": "cancelled-key", "key": key}) clients = list(ts.who_wants) if force else [cs] for cs in clients: - self.client_releases_keys(keys=[key], client=cs.client_key) + self.client_releases_keys( + keys=[key], client=cs.client_key, stimulus_id=f"cancel-key-{time()}" + ) def client_desires_keys(self, keys=None, client=None): cs: ClientState = self.clients.get(client) @@ -4388,15 +4442,16 @@ def client_desires_keys(self, keys=None, client=None): if ts.state in ("memory", "erred"): self.report_on_key(ts=ts, client=client) - def client_releases_keys(self, keys=None, client=None): + def client_releases_keys(self, keys=None, client=None, stimulus_id=None): """Remove keys from client desired list""" + stimulus_id = stimulus_id or f"client-releases-keys-{time()}" if not isinstance(keys, list): keys = list(keys) cs: ClientState = self.clients[client] recommendations: dict = {} _client_releases_keys(self, keys=keys, cs=cs, recommendations=recommendations) - self.transitions(recommendations) + self.transitions(recommendations, stimulus_id) def client_heartbeat(self, client=None): """Handle heartbeats from Client""" @@ -4625,7 +4680,7 @@ async def add_client( try: await self.handle_stream(comm=comm, extra={"client": client}) finally: - self.remove_client(client=client) + self.remove_client(client=client, stimulus_id=f"remove-client-{time()}") logger.debug("Finished handling client %s", client) finally: if not comm.closed(): @@ -4639,8 +4694,9 @@ async def add_client( except TypeError: # comm becomes None during GC pass - def remove_client(self, client: str) -> None: + def remove_client(self, client: str, stimulus_id: str = None) -> None: """Remove client from network""" + stimulus_id = stimulus_id or f"remove-client-{time()}" if self.status == Status.running: logger.info("Remove client %s", client) self.log_event(["all", client], {"action": "remove-client", "client": client}) @@ -4651,7 +4707,9 @@ def remove_client(self, client: str) -> None: pass else: self.client_releases_keys( - keys=[ts.key for ts in cs.wants_what], client=cs.client_key + keys=[ts.key for ts in cs.wants_what], + client=cs.client_key, + stimulus_id=stimulus_id, ) del self.clients[client] @@ -4687,24 +4745,30 @@ def send_task_to_worker(self, worker, ts: TaskState, duration: float = -1): def handle_uncaught_error(self, **msg): logger.exception(clean_exception(**msg)[1]) - def handle_task_finished(self, key=None, worker=None, **msg): + def handle_task_finished( + self, key: str, worker: str, stimulus_id: str, **msg + ) -> None: if worker not in self.workers: return validate_key(key) - r: tuple = self.stimulus_task_finished(key=key, worker=worker, **msg) + + r: tuple = self.stimulus_task_finished( + key=key, worker=worker, stimulus_id=stimulus_id, **msg + ) recommendations, client_msgs, worker_msgs = r - self._transitions(recommendations, client_msgs, worker_msgs) + self._transitions(recommendations, client_msgs, worker_msgs, stimulus_id) self.send_all(client_msgs, worker_msgs) - def handle_task_erred(self, key=None, **msg): - r: tuple = self.stimulus_task_erred(key=key, **msg) + def handle_task_erred(self, key: str, stimulus_id: str, **msg) -> None: + r: tuple = self.stimulus_task_erred(key=key, stimulus_id=stimulus_id, **msg) recommendations, client_msgs, worker_msgs = r - self._transitions(recommendations, client_msgs, worker_msgs) - + 
self._transitions(recommendations, client_msgs, worker_msgs, stimulus_id) self.send_all(client_msgs, worker_msgs) - def handle_missing_data(self, key=None, errant_worker=None, **kwargs): + def handle_missing_data( + self, key: str, errant_worker: str, stimulus_id: str, **kwargs + ) -> None: """Signal that `errant_worker` does not hold `key` This may either indicate that `errant_worker` is dead or that we may be @@ -4723,20 +4787,22 @@ def handle_missing_data(self, key=None, errant_worker=None, **kwargs): """ logger.debug("handle missing data key=%s worker=%s", key, errant_worker) self.log_event(errant_worker, {"action": "missing-data", "key": key}) - ts: TaskState = self.tasks.get(key) - if ts is None: + + if key not in self.tasks: return + + ts: TaskState = self.tasks[key] ws: WorkerState = self.workers.get(errant_worker) if ws is not None and ws in ts.who_has: self.remove_replica(ts, ws) if ts.state == "memory" and not ts.who_has: if ts.run_spec: - self.transitions({key: "released"}) + self.transitions({key: "released"}, stimulus_id) else: - self.transitions({key: "forgotten"}) + self.transitions({key: "forgotten"}, stimulus_id) - def release_worker_data(self, key, worker): + def release_worker_data(self, key, worker, stimulus_id): ws: WorkerState = self.workers.get(worker) ts: TaskState = self.tasks.get(key) if not ws or not ts: @@ -4747,7 +4813,7 @@ def release_worker_data(self, key, worker): if not ts.who_has: recommendations[ts.key] = "released" if recommendations: - self.transitions(recommendations) + self.transitions(recommendations, stimulus_id) def handle_long_running(self, key=None, worker=None, compute_duration=None): """A task has seceded from the thread pool @@ -4789,7 +4855,9 @@ def handle_long_running(self, key=None, worker=None, compute_duration=None): ws.long_running.add(ts) self.check_idle_saturated(ws) - def handle_worker_status_change(self, status: str, worker: str) -> None: + def handle_worker_status_change( + self, status: str, worker: str, stimulus_id: str + ) -> None: ws: WorkerState = self.workers.get(worker) # type: ignore if not ws: return @@ -4813,13 +4881,13 @@ def handle_worker_status_change(self, status: str, worker: str) -> None: if recs: client_msgs: dict = {} worker_msgs: dict = {} - self._transitions(recs, client_msgs, worker_msgs) + self._transitions(recs, client_msgs, worker_msgs, stimulus_id) self.send_all(client_msgs, worker_msgs) else: self.running.discard(ws) - async def handle_worker(self, comm=None, worker=None): + async def handle_worker(self, comm=None, worker=None, stimulus_id=None): """ Listen to responses from a single worker @@ -4829,6 +4897,7 @@ async def handle_worker(self, comm=None, worker=None): -------- Scheduler.handle_client: Equivalent coroutine for clients """ + assert stimulus_id comm.name = "Scheduler connection to worker" worker_comm = self.stream_comms[worker] worker_comm.start(comm) @@ -4838,7 +4907,7 @@ async def handle_worker(self, comm=None, worker=None): finally: if worker in self.stream_comms: worker_comm.abort() - await self.remove_worker(address=worker) + await self.remove_worker(address=worker, stimulus_id=stimulus_id) def add_plugin( self, @@ -4932,7 +5001,11 @@ def worker_send(self, worker, msg): try: stream_comms[worker].send(msg) except (CommClosedError, AttributeError): - self.loop.add_callback(self.remove_worker, address=worker) + self.loop.add_callback( + self.remove_worker, + address=worker, + stimulus_id=f"worker-send-comm-fail-{time()}", + ) def client_send(self, client, msg): """Send message to client""" @@ 
-4977,7 +5050,11 @@ def send_all(self, client_msgs: dict, worker_msgs: dict): # worker already gone pass except (CommClosedError, AttributeError): - self.loop.add_callback(self.remove_worker, address=worker) + self.loop.add_callback( + self.remove_worker, + address=worker, + stimulus_id=f"send-all-comm-fail-{time()}", + ) ############################ # Less common interactions # @@ -5034,6 +5111,7 @@ async def scatter( async def gather(self, keys, serializers=None): """Collect data from workers to the scheduler""" + stimulus_id = f"gather-{time()}" keys = list(keys) who_has = {} for key in keys: @@ -5065,7 +5143,9 @@ async def gather(self, keys, serializers=None): # reconnect. await asyncio.gather( *( - self.remove_worker(address=worker, close=False) + self.remove_worker( + address=worker, close=False, stimulus_id=stimulus_id + ) for worker in missing_workers ) ) @@ -5087,8 +5167,14 @@ async def gather(self, keys, serializers=None): for worker in workers: ws = self.workers.get(worker) if ws is not None and ws in ts.who_has: + # FIXME: This code path is not tested self.remove_replica(ts, ws) - self._transitions(recommendations, client_msgs, worker_msgs) + self._transitions( + recommendations, + client_msgs, + worker_msgs, + stimulus_id=stimulus_id, + ) self.send_all(client_msgs, worker_msgs) self.log_event("all", {"action": "gather", "count": len(keys)}) @@ -5104,12 +5190,15 @@ def clear_task_state(self): @log_errors async def restart(self, client=None, timeout=30): """Restart all workers. Reset local state.""" + stimulus_id = f"restart-{time()}" n_workers = len(self.workers) logger.info("Send lost future signal to clients") for cs in self.clients.values(): self.client_releases_keys( - keys=[ts.key for ts in cs.wants_what], client=cs.client_key + keys=[ts.key for ts in cs.wants_what], + client=cs.client_key, + stimulus_id=stimulus_id, ) nannies = {addr: ws.nanny for addr, ws in self.workers.items()} @@ -5118,7 +5207,9 @@ async def restart(self, client=None, timeout=30): try: # Ask the worker to close if it doesn't have a nanny, # otherwise the nanny will kill it anyway - await self.remove_worker(address=addr, close=addr not in nannies) + await self.remove_worker( + address=addr, close=addr not in nannies, stimulus_id=stimulus_id + ) except Exception: logger.info( "Exception while restarting. This is normal", exc_info=True @@ -5307,7 +5398,7 @@ async def gather_on_worker( return keys_failed async def delete_worker_data( - self, worker_address: str, keys: "Collection[str]" + self, worker_address: str, keys: "Collection[str]", stimulus_id: str ) -> None: """Delete data from a worker and update the corresponding worker/task states @@ -5344,7 +5435,7 @@ async def delete_worker_data( self.remove_replica(ts, ws) if not ts.who_has: # Last copy deleted - self.transitions({key: "released"}) + self.transitions({key: "released"}, stimulus_id) self.log_event(ws.address, {"action": "remove-worker-data", "keys": keys}) @@ -5354,6 +5445,7 @@ async def rebalance( comm=None, keys: "Iterable[Hashable]" = None, workers: "Iterable[str]" = None, + stimulus_id: str = None, ) -> dict: """Rebalance keys so that each worker ends up with roughly the same process memory (managed+unmanaged). @@ -5420,6 +5512,7 @@ async def rebalance( All other workers will be ignored. The mean cluster occupancy will be calculated only using the allowed workers. 
""" + stimulus_id = stimulus_id or f"rebalance-{time()}" if workers is not None: wss = [self.workers[w] for w in workers] else: @@ -5443,7 +5536,7 @@ async def rebalance( return {"status": "OK"} async with self._lock: - result = await self._rebalance_move_data(msgs) + result = await self._rebalance_move_data(msgs, stimulus_id) if result["status"] == "partial-fail" and keys is None: # Only return failed keys if the client explicitly asked for them result = {"status": "OK"} @@ -5640,7 +5733,7 @@ def _rebalance_find_msgs( return msgs async def _rebalance_move_data( - self, msgs: "list[tuple[WorkerState, WorkerState, TaskState]]" + self, msgs: "list[tuple[WorkerState, WorkerState, TaskState]]", stimulus_id: str ) -> dict: """Perform the actual transfer of data across the network in rebalance(). Takes in input the output of _rebalance_find_msgs(), that is a list of tuples: @@ -5676,7 +5769,7 @@ async def _rebalance_move_data( # Note: this never raises exceptions await asyncio.gather( - *(self.delete_worker_data(r, v) for r, v in to_senders.items()) + *(self.delete_worker_data(r, v, stimulus_id) for r, v in to_senders.items()) ) for r, v in to_recipients.items(): @@ -5706,6 +5799,7 @@ async def replicate( branching_factor=2, delete=True, lock=True, + stimulus_id=None, ): """Replicate data throughout cluster @@ -5728,6 +5822,7 @@ async def replicate( -------- Scheduler.rebalance """ + stimulus_id = stimulus_id or f"replicate-{time()}" assert branching_factor > 0 async with self._lock if lock else empty_context: if workers is not None: @@ -5762,7 +5857,9 @@ async def replicate( # Note: this never raises exceptions await asyncio.gather( *[ - self.delete_worker_data(ws.address, [t.key for t in tasks]) + self.delete_worker_data( + ws.address, [t.key for t in tasks], stimulus_id + ) for ws, tasks in del_worker_tasks.items() ] ) @@ -5951,6 +6048,7 @@ async def retire_workers( names: "list | None" = None, close_workers: bool = False, remove: bool = True, + stimulus_id: str = None, **kwargs, ) -> dict: """Gracefully retire workers from cluster @@ -5984,6 +6082,7 @@ async def retire_workers( -------- Scheduler.workers_to_close """ + stimulus_id = stimulus_id or f"retire-workers-{time()}" # This lock makes retire_workers, rebalance, and replicate mutually # exclusive and will no longer be necessary once rebalance and replicate are # migrated to the Active Memory Manager. @@ -6034,7 +6133,11 @@ async def retire_workers( ws.status = Status.closing_gracefully self.running.discard(ws) self.stream_comms[ws.address].send( - {"op": "worker-status-change", "status": ws.status.name} + { + "op": "worker-status-change", + "status": ws.status.name, + "stimulus_id": stimulus_id, + } ) coros.append( @@ -6044,6 +6147,7 @@ async def retire_workers( prev_status=prev_status, close_workers=close_workers, remove=remove, + stimulus_id=stimulus_id, ) ) @@ -6070,6 +6174,7 @@ async def _track_retire_worker( prev_status: Status, close_workers: bool, remove: bool, + stimulus_id: str, ) -> tuple: # tuple[str | None, dict] while not policy.done(): if policy.no_recipients: @@ -6077,7 +6182,11 @@ async def _track_retire_worker( # conditions and we can wait for a scheduler->worker->scheduler # round-trip. 
self.stream_comms[ws.address].send( - {"op": "worker-status-change", "status": prev_status.name} + { + "op": "worker-status-change", + "status": prev_status.name, + "stimulus_id": stimulus_id, + } ) return None, {} @@ -6091,9 +6200,13 @@ async def _track_retire_worker( ) if close_workers and ws.address in self.workers: - await self.close_worker(worker=ws.address, safe=True) + await self.close_worker( + worker=ws.address, safe=True, stimulus_id=stimulus_id + ) if remove: - await self.remove_worker(address=ws.address, safe=True) + await self.remove_worker( + address=ws.address, safe=True, stimulus_id=stimulus_id + ) logger.info("Retired worker %s", ws.address) return ws.address, ws.identity() @@ -6524,7 +6637,7 @@ async def unregister_nanny_plugin(self, comm, name): ) return responses - def transition(self, key, finish: str, *args, **kwargs): + def transition(self, key, finish: str, *args, stimulus_id: str, **kwargs): """Transition a key from its current state to the finish state Examples @@ -6540,12 +6653,12 @@ def transition(self, key, finish: str, *args, **kwargs): -------- Scheduler.transitions: transitive version of this function """ - a: tuple = self._transition(key, finish, *args, **kwargs) + a: tuple = self._transition(key, finish, stimulus_id, *args, **kwargs) recommendations, client_msgs, worker_msgs = a self.send_all(client_msgs, worker_msgs) return recommendations - def transitions(self, recommendations: dict): + def transitions(self, recommendations: dict, stimulus_id: str): """Process transitions until none are left This includes feedback from previous transitions and continues until we @@ -6553,7 +6666,7 @@ def transitions(self, recommendations: dict): """ client_msgs: dict = {} worker_msgs: dict = {} - self._transitions(recommendations, client_msgs, worker_msgs) + self._transitions(recommendations, client_msgs, worker_msgs, stimulus_id) self.send_all(client_msgs, worker_msgs) def story(self, *keys): @@ -6561,6 +6674,9 @@ def story(self, *keys): keys = {key.key if isinstance(key, TaskState) else key for key in keys} return scheduler_story(keys, self.transition_log) + async def get_story(self, keys=()): + return self.story(*keys) + transition_story = story def reschedule(self, key=None, worker=None): @@ -6581,7 +6697,7 @@ def reschedule(self, key=None, worker=None): return if worker and ts.processing_on.address != worker: return - self.transitions({key: "released"}) + self.transitions({key: "released"}, f"reschedule-{time()}") ##################### # Utility functions # @@ -7005,6 +7121,7 @@ def reevaluate_occupancy(self, worker_index: int = 0): async def check_worker_ttl(self): now = time() + stimulus_id = f"check-worker-ttl-{now}" for ws in self.workers.values(): if (ws.last_seen < now - self.worker_ttl) and ( ws.last_seen < now - 10 * heartbeat_interval(len(self.workers)) @@ -7014,7 +7131,7 @@ async def check_worker_ttl(self): self.worker_ttl, ws, ) - await self.remove_worker(address=ws.address) + await self.remove_worker(address=ws.address, stimulus_id=stimulus_id) def check_idle(self): if any([ws.processing for ws in self.workers.values()]) or self.unrunnable: @@ -7223,7 +7340,11 @@ def _add_to_memory( def _propagate_forgotten( - state: SchedulerState, ts: TaskState, recommendations: dict, worker_msgs: dict + state: SchedulerState, + ts: TaskState, + recommendations: dict, + worker_msgs: dict, + stimulus_id: str, ): ts.state = "forgotten" key: str = ts.key @@ -7254,7 +7375,7 @@ def _propagate_forgotten( { "op": "free-keys", "keys": [key], - "stimulus_id": 
f"propagate-forgotten-{time()}", + "stimulus_id": stimulus_id, } ] state.remove_all_replicas(ts) diff --git a/distributed/tests/test_cancelled_state.py b/distributed/tests/test_cancelled_state.py index e7015cf1f3..dd1adc070f 100644 --- a/distributed/tests/test_cancelled_state.py +++ b/distributed/tests/test_cancelled_state.py @@ -4,7 +4,7 @@ from distributed import Event, Worker from distributed.utils_test import ( _LockedCommPool, - assert_worker_story, + assert_story, gen_cluster, inc, slowinc, @@ -80,7 +80,7 @@ def f(ev): while "f1" in a.tasks: await asyncio.sleep(0.01) - assert_worker_story( + assert_story( a.story("f1"), [ ("f1", "compute-task"), @@ -156,7 +156,7 @@ async def get_data(self, comm, *args, **kwargs): await wait_for_state(fut1.key, "flight", b) # Close in scheduler to ensure we transition and reschedule task properly - await s.close_worker(worker=a.address) + await s.close_worker(worker=a.address, stimulus_id="test") await wait_for_state(fut1.key, "resumed", b) lock.release() diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 9fc8982717..00783cb924 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -4462,7 +4462,7 @@ def test_auto_normalize_collection_sync(c): def assert_no_data_loss(scheduler): - for key, start, finish, recommendations, _ in scheduler.transition_log: + for key, start, finish, recommendations, _, _ in scheduler.transition_log: if start == "memory" and finish == "released": for k, v in recommendations.items(): assert not (k == key and v == "waiting") diff --git a/distributed/tests/test_cluster_dump.py b/distributed/tests/test_cluster_dump.py index c3912116c4..540e0abb4e 100644 --- a/distributed/tests/test_cluster_dump.py +++ b/distributed/tests/test_cluster_dump.py @@ -8,7 +8,7 @@ import distributed from distributed.cluster_dump import DumpArtefact, _tuple_to_list, write_state -from distributed.utils_test import assert_worker_story, gen_cluster, gen_test, inc +from distributed.utils_test import assert_story, gen_cluster, gen_test, inc @pytest.mark.parametrize( @@ -126,31 +126,12 @@ async def test_cluster_dump_story(c, s, a, b, tmp_path): assert story.keys() == {"f1", "f2"} for k, task_story in story.items(): - expected = [ - (k, "released", "waiting", {k: "processing"}), - (k, "waiting", "processing", {}), - (k, "processing", "memory", {}), - ] - - for event, expected_event in zip(task_story, expected): - for e1, e2 in zip(event, expected_event): - assert e1 == e2 + assert_story(task_story, s.story(k)) story = dump.worker_story("f1", "f2") assert story.keys() == {"f1", "f2"} - for k, task_story in story.items(): - assert_worker_story( - task_story, - [ - (k, "compute-task"), - (k, "released", "waiting", "waiting", {k: "ready"}), - (k, "waiting", "ready", "ready", {k: "executing"}), - (k, "ready", "executing", "executing", {}), - (k, "put-in-memory"), - (k, "executing", "memory", "memory", {}), - ], - ) + assert_story(task_story, a.story(k) + b.story(k)) @gen_cluster(client=True) diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 1f0d335192..ff76102475 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -313,7 +313,7 @@ async def test_remove_worker_from_scheduler(s, a, b): ) assert a.address in s.stream_comms - await s.remove_worker(address=a.address) + await s.remove_worker(address=a.address, stimulus_id="test") assert a.address not in s.workers assert len(s.workers[b.address].processing) == len(dsk) # 
b owns everything @@ -321,9 +321,12 @@ async def test_remove_worker_from_scheduler(s, a, b): @gen_cluster() async def test_remove_worker_by_name_from_scheduler(s, a, b): assert a.address in s.stream_comms - assert await s.remove_worker(address=a.name) == "OK" + assert await s.remove_worker(address=a.name, stimulus_id="test") == "OK" assert a.address not in s.workers - assert await s.remove_worker(address=a.address) == "already-removed" + assert ( + await s.remove_worker(address=a.address, stimulus_id="test") + == "already-removed" + ) @gen_cluster(config={"distributed.scheduler.events-cleanup-delay": "10 ms"}) @@ -333,7 +336,7 @@ async def test_clear_events_worker_removal(s, a, b): assert b.address in s.events assert b.address in s.workers - await s.remove_worker(address=a.address) + await s.remove_worker(address=a.address, stimulus_id="test") # Shortly after removal, the events should still be there assert a.address in s.events assert a.address not in s.workers @@ -596,7 +599,7 @@ async def test_ready_remove_worker(s, a, b): assert all(len(w.processing) > w.nthreads for w in s.workers.values()) - await s.remove_worker(address=a.address) + await s.remove_worker(address=a.address, stimulus_id="test") assert set(s.workers) == {b.address} assert all(len(w.processing) > w.nthreads for w in s.workers.values()) @@ -789,7 +792,7 @@ async def test_story(c, s, a, b): story = s.story(x.key) assert all(line in s.transition_log for line in story) assert len(story) < len(s.transition_log) - assert all(x.key == line[0] or x.key in line[-2] for line in story) + assert all(x.key == line[0] or x.key in line[3] for line in story) assert len(s.story(x.key, y.key)) > len(story) @@ -1212,7 +1215,7 @@ def f(dask_scheduler=None): async def test_close_worker(c, s, a, b): assert len(s.workers) == 2 - await s.close_worker(worker=a.address) + await s.close_worker(worker=a.address, stimulus_id="test") assert len(s.workers) == 1 assert a.address not in s.workers @@ -1230,7 +1233,7 @@ async def test_close_nanny(c, s, a, b): assert a.process.is_alive() a_worker_address = a.worker_address start = time() - await s.close_worker(worker=a_worker_address) + await s.close_worker(worker=a_worker_address, stimulus_id="test") assert len(s.workers) == 1 assert a_worker_address not in s.workers @@ -3094,7 +3097,9 @@ async def test_rebalance_dead_recipient(client, s, a, b, c): await c.close() assert s.workers.keys() == {a.address, b.address} - out = await s._rebalance_move_data([(a_ws, b_ws, x_ts), (a_ws, c_ws, y_ts)]) + out = await s._rebalance_move_data( + [(a_ws, b_ws, x_ts), (a_ws, c_ws, y_ts)], stimulus_id="test" + ) assert out == {"status": "partial-fail", "keys": [y.key]} assert a.data == {y.key: "y"} assert b.data == {x.key: "x"} @@ -3113,7 +3118,7 @@ async def test_delete_worker_data(c, s, a, b): assert b.data == {y.key: "y"} assert s.tasks.keys() == {x.key, y.key, z.key} - await s.delete_worker_data(a.address, [x.key, y.key]) + await s.delete_worker_data(a.address, [x.key, y.key], stimulus_id="test") assert a.data == {z.key: "z"} assert b.data == {y.key: "y"} assert s.tasks.keys() == {y.key, z.key} @@ -3127,8 +3132,8 @@ async def test_delete_worker_data_double_delete(c, s, a): """ x, y = await c.scatter(["x", "y"]) await asyncio.gather( - s.delete_worker_data(a.address, [x.key]), - s.delete_worker_data(a.address, [x.key]), + s.delete_worker_data(a.address, [x.key], stimulus_id="test"), + s.delete_worker_data(a.address, [x.key], stimulus_id="test"), ) assert a.data == {y.key: "y"} a_ws = s.workers[a.address] @@ -3143,7 
+3148,7 @@ async def test_delete_worker_data_bad_worker(s, a, b): """ await a.close() assert s.workers.keys() == {b.address} - await s.delete_worker_data(a.address, ["x"]) + await s.delete_worker_data(a.address, ["x"], stimulus_id="test") @pytest.mark.parametrize("bad_first", [False, True]) @@ -3158,7 +3163,7 @@ async def test_delete_worker_data_bad_task(c, s, a, bad_first): assert s.tasks.keys() == {x.key, y.key} keys = ["notexist", x.key] if bad_first else [x.key, "notexist"] - await s.delete_worker_data(a.address, keys) + await s.delete_worker_data(a.address, keys, stimulus_id="test") assert a.data == {y.key: "y"} assert s.tasks.keys() == {y.key} assert s.workers[a.address].nbytes == s.tasks[y.key].nbytes @@ -3243,13 +3248,13 @@ async def test_worker_reconnect_task_memory(c, s, a): while not a.executing_count and not a.data: await asyncio.sleep(0.001) - await s.remove_worker(address=a.address, close=False) + await s.remove_worker(address=a.address, close=False, stimulus_id="test") while not res.done(): await a.heartbeat() await res assert ("no-worker", "memory") in { - (start, finish) for (_, start, finish, _, _) in s.transition_log + (start, finish) for (_, start, finish, _, _, _) in s.transition_log } @@ -3267,13 +3272,13 @@ async def test_worker_reconnect_task_memory_with_resources(c, s, a): while not b.executing_count and not b.data: await asyncio.sleep(0.001) - await s.remove_worker(address=b.address, close=False) + await s.remove_worker(address=b.address, close=False, stimulus_id="test") while not res.done(): await b.heartbeat() await res assert ("no-worker", "memory") in { - (start, finish) for (_, start, finish, _, _) in s.transition_log + (start, finish) for (_, start, finish, _, _, _) in s.transition_log } diff --git a/distributed/tests/test_stories.py b/distributed/tests/test_stories.py new file mode 100644 index 0000000000..1616422f25 --- /dev/null +++ b/distributed/tests/test_stories.py @@ -0,0 +1,170 @@ +import pytest + +import dask + +from distributed import Worker +from distributed.comm import CommClosedError +from distributed.utils_test import assert_story, assert_valid_story, gen_cluster, inc + + +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_scheduler_story_stimulus_success(c, s, a): + f = c.submit(inc, 1) + key = f.key + + await f + + stories = s.story(key) + + stimulus_ids = {s[-2] for s in stories} + # Two events + # - Compute + # - Success + assert len(stimulus_ids) == 2 + assert_story( + stories, + [ + (key, "released", "waiting", {key: "processing"}), + (key, "waiting", "processing", {}), + (key, "processing", "memory", {}), + ], + ) + + await c.close() + + stories_after_close = s.story(key) + assert len(stories_after_close) > len(stories) + + stimulus_ids = {s[-2] for s in stories_after_close} + # One more event + # - Forget / Release / Free since client closed + assert len(stimulus_ids) == 3 + + +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_scheduler_story_stimulus_retry(c, s, a): + def task(): + assert dask.config.get("foo") + + with dask.config.set(foo=False): + f = c.submit(task) + with pytest.raises(AssertionError): + await f + + with dask.config.set(foo=True): + await f.retry() + await f + + story = s.story(f.key) + stimulus_ids = {s[-2] for s in story} + # Four events + # - Compute + # - Erred + # - Compute / Retry + # - Success + assert len(stimulus_ids) == 4 + + assert_story( + story, + [ + # Erred transitions via released + (f.key, "processing", "erred", {}), + (f.key, "erred", "released", {}), + (f.key, "released", 
"waiting", {f.key: "processing"}), + ], + ) + + +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_client_story(c, s, a): + f = c.submit(inc, 1) + assert await f == 2 + story = await c.story(f.key) + + # Every event should be prefixed with it's origin + # This changes the format compared to default scheduler / worker stories + prefixes = set() + stripped_story = list() + for msg in story: + prefixes.add(msg[0]) + stripped_story.append(msg[1:]) + assert prefixes == {"scheduler", a.address} + + assert_valid_story(stripped_story, ordered_timestamps=False) + + # If it's a well formed story, we can sort by the last element which is a + # timestamp and compare the two lists. + assert sorted(stripped_story, key=lambda msg: msg[-1]) == sorted( + s.story(f.key) + a.story(f.key), key=lambda msg: msg[-1] + ) + + +class WorkerBrokenStory(Worker): + async def get_story(self, *args, **kw): + raise CommClosedError + + +@gen_cluster(client=True, Worker=WorkerBrokenStory) +@pytest.mark.parametrize("on_error", ["ignore", "raise"]) +async def test_client_story_failed_worker(c, s, a, b, on_error): + f = c.submit(inc, 1) + coro = c.story(f.key, on_error=on_error) + await f + + if on_error == "raise": + with pytest.raises(CommClosedError): + await coro + elif on_error == "ignore": + story = await coro + assert story + assert len(story) > 1 + else: + raise ValueError(on_error) + + +@gen_cluster(client=True) +async def test_worker_story_with_deps(c, s, a, b): + """ + Assert that the structure of the story does not change unintentionally and + expected subfields are actually filled + """ + dep = c.submit(inc, 1, workers=[a.address], key="dep") + res = c.submit(inc, dep, workers=[b.address], key="res") + await res + + story = a.story("res") + assert story == [] + story = b.story("res") + + # Story now includes randomized stimulus_ids and timestamps. 
+ stimulus_ids = {ev[-2] for ev in story} + # Compute dep + # Success dep + # Compute res + assert len(stimulus_ids) == 3 + + # This is a simple transition log + expected = [ + ("res", "compute-task"), + ("res", "released", "waiting", "waiting", {"dep": "fetch"}), + ("res", "waiting", "ready", "ready", {"res": "executing"}), + ("res", "ready", "executing", "executing", {}), + ("res", "put-in-memory"), + ("res", "executing", "memory", "memory", {}), + ] + assert_story(story, expected, strict=True) + + story = b.story("dep") + stimulus_ids = {ev[-2] for ev in story} + assert len(stimulus_ids) == 2, stimulus_ids + expected = [ + ("dep", "ensure-task-exists", "released"), + ("dep", "released", "fetch", "fetch", {}), + ("gather-dependencies", a.address, {"dep"}), + ("dep", "fetch", "flight", "flight", {}), + ("request-dep", a.address, {"dep"}), + ("receive-dep", a.address, {"dep"}), + ("dep", "put-in-memory"), + ("dep", "flight", "memory", "memory", {"res": "ready"}), + ] + assert_story(story, expected, strict=True) diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py index 3d4c7aee4d..c7770079b3 100755 --- a/distributed/tests/test_utils_test.py +++ b/distributed/tests/test_utils_test.py @@ -21,7 +21,7 @@ from distributed.utils_test import ( _LockedCommPool, _UnhashableCallable, - assert_worker_story, + assert_story, check_process_leak, cluster, dump_cluster_state, @@ -407,7 +407,7 @@ async def inner_test(c, s, a, b): assert "workers" in state -def test_assert_worker_story(): +def test_assert_story(): now = time() story = [ ("foo", "id1", now - 600), @@ -415,38 +415,38 @@ def test_assert_worker_story(): ("baz", {1: 2}, "id2", now), ] # strict=False - assert_worker_story(story, [("foo",), ("bar",), ("baz", {1: 2})]) - assert_worker_story(story, []) - assert_worker_story(story, [("foo",)]) - assert_worker_story(story, [("foo",), ("bar",)]) - assert_worker_story(story, [("baz", lambda d: d[1] == 2)]) + assert_story(story, [("foo",), ("bar",), ("baz", {1: 2})]) + assert_story(story, []) + assert_story(story, [("foo",)]) + assert_story(story, [("foo",), ("bar",)]) + assert_story(story, [("baz", lambda d: d[1] == 2)]) with pytest.raises(AssertionError): - assert_worker_story(story, [("foo", "nomatch")]) + assert_story(story, [("foo", "nomatch")]) with pytest.raises(AssertionError): - assert_worker_story(story, [("baz",)]) + assert_story(story, [("baz",)]) with pytest.raises(AssertionError): - assert_worker_story(story, [("baz", {1: 3})]) + assert_story(story, [("baz", {1: 3})]) with pytest.raises(AssertionError): - assert_worker_story(story, [("foo",), ("bar",), ("baz", "extra"), ("+1",)]) + assert_story(story, [("foo",), ("bar",), ("baz", "extra"), ("+1",)]) with pytest.raises(AssertionError): - assert_worker_story(story, [("baz", lambda d: d[1] == 3)]) + assert_story(story, [("baz", lambda d: d[1] == 3)]) with pytest.raises(KeyError): # Faulty lambda - assert_worker_story(story, [("baz", lambda d: d[2] == 1)]) - assert_worker_story([], []) - assert_worker_story([("foo", "id1", now)], [("foo",)]) + assert_story(story, [("baz", lambda d: d[2] == 1)]) + assert_story([], []) + assert_story([("foo", "id1", now)], [("foo",)]) with pytest.raises(AssertionError): - assert_worker_story([], [("foo",)]) + assert_story([], [("foo",)]) # strict=True - assert_worker_story([], [], strict=True) - assert_worker_story([("foo", "id1", now)], [("foo",)]) - assert_worker_story(story, [("foo",), ("bar",), ("baz", {1: 2})], strict=True) + assert_story([], [], strict=True) + 
assert_story([("foo", "id1", now)], [("foo",)]) + assert_story(story, [("foo",), ("bar",), ("baz", {1: 2})], strict=True) with pytest.raises(AssertionError): - assert_worker_story(story, [("foo",), ("bar",)], strict=True) + assert_story(story, [("foo",), ("bar",)], strict=True) with pytest.raises(AssertionError): - assert_worker_story(story, [("foo",), ("baz", {1: 2})], strict=True) + assert_story(story, [("foo",), ("baz", {1: 2})], strict=True) with pytest.raises(AssertionError): - assert_worker_story(story, [], strict=True) + assert_story(story, [], strict=True) @pytest.mark.parametrize( @@ -467,11 +467,29 @@ def test_assert_worker_story(): ), ], ) -def test_assert_worker_story_malformed_story(story_factory): +def test_assert_story_malformed_story(story_factory): # defer the calls to time() to when the test runs rather than collection story = story_factory() with pytest.raises(AssertionError, match="Malformed story event"): - assert_worker_story(story, []) + assert_story(story, []) + + +@pytest.mark.parametrize("strict", [True, False]) +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_assert_story_identity(c, s, a, strict): + f1 = c.submit(inc, 1, key="f1") + f2 = c.submit(inc, f1, key="f2") + assert await f2 == 3 + scheduler_story = s.story(f2.key) + assert scheduler_story + worker_story = a.story(f2.key) + assert worker_story + assert_story(worker_story, worker_story, strict=strict) + assert_story(scheduler_story, scheduler_story, strict=strict) + with pytest.raises(AssertionError): + assert_story(scheduler_story, worker_story, strict=strict) + with pytest.raises(AssertionError): + assert_story(worker_story, scheduler_story, strict=strict) @gen_cluster() diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 7f796c015b..8fff62dced 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -48,7 +48,7 @@ from distributed.utils_test import ( TaskStateMetadataPlugin, _LockedCommPool, - assert_worker_story, + assert_story, captured_logger, dec, div, @@ -1410,7 +1410,7 @@ def assert_amm_transfer_story(key: str, w_from: Worker, w_to: Worker) -> None: """Test that an in-memory key was transferred from worker w_from to worker w_to by the Active Memory Manager and it was not recalculated on w_to """ - assert_worker_story( + assert_story( w_to.story(key), [ (key, "ensure-task-exists", "released"), @@ -1733,50 +1733,6 @@ async def test_story(c, s, w): assert w.story(ts) == w.story(ts.key) -@gen_cluster(client=True) -async def test_story_with_deps(c, s, a, b): - """ - Assert that the structure of the story does not change unintentionally and - expected subfields are actually filled - """ - dep = c.submit(inc, 1, workers=[a.address], key="dep") - res = c.submit(inc, dep, workers=[b.address], key="res") - await res - - story = a.story("res") - assert story == [] - story = b.story("res") - - # Story now includes randomized stimulus_ids and timestamps. 
- stimulus_ids = {ev[-2] for ev in story} - assert len(stimulus_ids) == 2, stimulus_ids - # This is a simple transition log - expected = [ - ("res", "compute-task"), - ("res", "released", "waiting", "waiting", {"dep": "fetch"}), - ("res", "waiting", "ready", "ready", {"res": "executing"}), - ("res", "ready", "executing", "executing", {}), - ("res", "put-in-memory"), - ("res", "executing", "memory", "memory", {}), - ] - assert_worker_story(story, expected, strict=True) - - story = b.story("dep") - stimulus_ids = {ev[-2] for ev in story} - assert len(stimulus_ids) == 2, stimulus_ids - expected = [ - ("dep", "ensure-task-exists", "released"), - ("dep", "released", "fetch", "fetch", {}), - ("gather-dependencies", a.address, {"dep"}), - ("dep", "fetch", "flight", "flight", {}), - ("request-dep", a.address, {"dep"}), - ("receive-dep", a.address, {"dep"}), - ("dep", "put-in-memory"), - ("dep", "flight", "memory", "memory", {"res": "ready"}), - ] - assert_worker_story(story, expected, strict=True) - - @gen_cluster(client=True, nthreads=[("", 1)]) async def test_stimulus_story(c, s, a): class C: @@ -2712,7 +2668,7 @@ async def test_gather_dep_exception_one_task_2(c, s, a, b): while fut1.key not in b.tasks or b.tasks[fut1.key].state == "flight": await asyncio.sleep(0) - s.handle_missing_data(key="f1", errant_worker=a.address) + s.handle_missing_data(key="f1", errant_worker=a.address, stimulus_id="test") await fut2 @@ -2757,7 +2713,7 @@ async def test_acquire_replicas_same_channel(c, s, a, b): # same communication channel for fut in (futA, futB): - assert_worker_story( + assert_story( b.story(fut.key), [ ("gather-dependencies", a.address, {fut.key}), @@ -2816,7 +2772,7 @@ def __getstate__(self): assert await y == 123 story = await c.run(lambda dask_worker: dask_worker.story("x")) - assert_worker_story( + assert_story( story[b], [ ("x", "ensure-task-exists", "released"), @@ -2993,7 +2949,7 @@ async def test_who_has_consistent_remove_replicas(c, s, *workers): await f2 - assert_worker_story(a.story(f1.key), [(f1.key, "missing-dep")]) + assert_story(a.story(f1.key), [(f1.key, "missing-dep")]) assert a.tasks[f1.key].suspicious_count == 0 assert s.tasks[f1.key].suspicious == 0 @@ -3074,7 +3030,7 @@ async def test_missing_released_zombie_tasks_2(c, s, b): while b.tasks: await asyncio.sleep(0.01) - assert_worker_story( + assert_story( b.story(ts), [("f1", "missing", "released", "released", {"f1": "forgotten"})], ) @@ -3186,7 +3142,7 @@ async def test_task_flight_compute_oserror(c, s, a, b): ("f1", "put-in-memory"), ("f1", "executing", "memory", "memory", {}), ] - assert_worker_story(sum_story, expected_sum_story, strict=True) + assert_story(sum_story, expected_sum_story, strict=True) @gen_cluster(client=True, nthreads=[]) @@ -3406,7 +3362,9 @@ async def gather_dep(self, *args, **kwargs): await in_gather_dep.wait() - await s.remove_worker(address=x.address, safe=True, close=close_worker) + await s.remove_worker( + address=x.address, safe=True, close=close_worker, stimulus_id="test" + ) await _wait_for_state(fut2_key, b, intermediate_state) diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py index d8337ace8e..feb99ee995 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -115,15 +115,19 @@ def test_slots(cls): def test_sendmsg_to_dict(): # Arbitrary sample class - smsg = ReleaseWorkerDataMsg(key="x") - assert smsg.to_dict() == {"op": "release-worker-data", "key": "x"} + smsg = 
ReleaseWorkerDataMsg(key="x", stimulus_id="test")
+    assert smsg.to_dict() == {
+        "op": "release-worker-data",
+        "key": "x",
+        "stimulus_id": "test",
+    }
 
 
 def test_merge_recs_instructions():
     x = TaskState("x")
     y = TaskState("y")
-    instr1 = RescheduleMsg(key="foo", worker="a")
-    instr2 = RescheduleMsg(key="bar", worker="b")
+    instr1 = RescheduleMsg(key="foo", worker="a", stimulus_id="test")
+    instr2 = RescheduleMsg(key="bar", worker="b", stimulus_id="test")
     assert merge_recs_instructions(
         ({x: "memory"}, [instr1]),
         ({y: "released"}, [instr2]),
diff --git a/distributed/utils_test.py b/distributed/utils_test.py
index b7b79fd8c8..836a7fb179 100644
--- a/distributed/utils_test.py
+++ b/distributed/utils_test.py
@@ -1891,17 +1891,72 @@ def xfail_ssl_issue5601():
             raise
 
 
-def assert_worker_story(
-    story: list[tuple], expect: list[tuple], *, strict: bool = False
+def assert_valid_story(story, ordered_timestamps=True):
+    """Test that a story is well formed.
+
+    Parameters
+    ----------
+    story: list[tuple]
+        Output of Worker.story
+    ordered_timestamps: bool, optional
+        If False, timestamps are not required to be monotonically increasing.
+        Useful for asserting stories composed from the scheduler and
+        multiple workers
+    """
+
+    now = time()
+    prev_ts = 0.0
+    for ev in story:
+        try:
+            assert len(ev) > 2, "too short"
+            assert isinstance(ev, tuple), "not a tuple"
+            assert isinstance(ev[-2], str) and ev[-2], "stimulus_id not a string"
+            assert isinstance(ev[-1], float), "Timestamp is not a float"
+            if ordered_timestamps:
+                assert prev_ts <= ev[-1], "Timestamps are not monotonically ascending"
+            # Timestamps are within the last hour. It's been observed that a
+            # timestamp generated in a Nanny process can be a few milliseconds
+            # in the future.
+            assert now - 3600 < ev[-1] <= now + 1, "Timestamp is too old"
+            prev_ts = ev[-1]
+        except AssertionError as err:
+            raise AssertionError(
+                f"Malformed story event: {ev}\nProblem: {err}.\nin story:\n{_format_story(story)}"
+            )
+
+
+def assert_story(
+    story: list[tuple],
+    expect: list[tuple],
+    *,
+    strict: bool = False,
+    ordered_timestamps: bool = True,
 ) -> None:
-    """Test the output of ``Worker.story``
+    """Test the output of ``Worker.story`` or ``Scheduler.story``
+
+    Warning
+    =======
+
+    Tests with overly verbose stories introduce maintenance cost and should
+    therefore be used with caution. This should only be used for very specific
+    unit tests where the exact order of events is crucial and there is no other
+    practical way to assert or observe what happened.
+    A typical use case involves testing for race conditions where subtle changes
+    of event ordering would cause harm.
 
     Parameters
     ==========
     story: list[tuple]
         Output of Worker.story
     expect: list[tuple]
-        Expected events. Each expected event must contain exactly 2 less fields than the
+        Expected events.
+        Each expected event must either match the story event exactly or
+        omit the trailing stimulus_id and timestamp.
+        e.g.
+        `("log", "entry", "stim-id-9876", 1234)`
+        is equivalent to
+        `("log", "entry")`
+
         story (the last two fields are always the stimulus_id and the timestamp).
 
         Elements of the expect tuples can be
@@ -1922,24 +1977,17 @@ def assert_worker_story(
         If True, the story must contain exactly as many events as expect.
         If False (the default), the story may contain more events than expect;
         extra events are ignored.
+    ordered_timestamps: bool, optional
+        If False, timestamps are not required to be monotonically increasing.
+ Useful for asserting stories composed from the scheduler and + multiple workers """ - now = time() - prev_ts = 0.0 - for ev in story: - try: - assert len(ev) > 2 - assert isinstance(ev, tuple) - assert isinstance(ev[-2], str) and ev[-2] # stimulus_id - assert isinstance(ev[-1], float) # timestamp - assert prev_ts <= ev[-1] # Timestamps are monotonic ascending - # Timestamps are within the last hour. It's been observed that a timestamp - # generated in a Nanny process can be a few milliseconds in the future. - assert now - 3600 < ev[-1] <= now + 1 - prev_ts = ev[-1] - except AssertionError: - raise AssertionError( - f"Malformed story event: {ev}\nin story:\n{_format_story(story)}" - ) + assert_valid_story(story, ordered_timestamps=ordered_timestamps) + + def _valid_event(event, ev_expect): + return len(event) == len(ev_expect) and all( + ex(ev) if callable(ex) else ev == ex for ev, ex in zip(event, ev_expect) + ) try: if strict and len(story) != len(expect): @@ -1948,16 +1996,16 @@ def assert_worker_story( for ev_expect in expect: while True: event = next(story_it) - # Ignore (stimulus_id, timestamp) - event = event[:-2] - if len(event) == len(ev_expect) and all( - ex(ev) if callable(ex) else ev == ex - for ev, ex in zip(event, ev_expect) + + if ( + _valid_event(event, ev_expect) + # Ignore (stimulus_id, timestamp) + or _valid_event(event[:-2], ev_expect) ): break except StopIteration: raise AssertionError( - f"assert_worker_story({strict=}) failed\n" + f"assert_story({strict=}) failed\n" f"story:\n{_format_story(story)}\n" f"expect:\n{_format_story(expect)}" ) from None diff --git a/distributed/worker.py b/distributed/worker.py index 4bc2fab5c6..3e50a679bc 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -48,6 +48,7 @@ ) from distributed import comm, preloading, profile, utils +from distributed._stories import worker_story from distributed.batched import BatchedSend from distributed.comm import connect, get_address_host from distributed.comm.addressing import address_from_user_args, parse_address @@ -74,7 +75,6 @@ from distributed.security import Security from distributed.shuffle import ShuffleWorkerExtension from distributed.sizeof import safe_sizeof as sizeof -from distributed.stories import worker_story from distributed.threadpoolexecutor import ThreadPoolExecutor from distributed.threadpoolexecutor import secede as tpe_secede from distributed.utils import ( @@ -754,6 +754,7 @@ def __init__( "benchmark_disk": self.benchmark_disk, "benchmark_memory": self.benchmark_memory, "benchmark_network": self.benchmark_network, + "get_story": self.get_story, } stream_handlers = { @@ -927,21 +928,26 @@ def status(self, value): """ prev_status = self.status ServerNode.status.__set__(self, value) - self._send_worker_status_change() + stimulus_id = f"worker-status-change-{time()}" + self._send_worker_status_change(stimulus_id) if prev_status == Status.paused and value == Status.running: - self.handle_stimulus(UnpauseEvent(stimulus_id=f"set-status-{time()}")) + self.handle_stimulus(UnpauseEvent(stimulus_id=stimulus_id)) - def _send_worker_status_change(self) -> None: + def _send_worker_status_change(self, stimulus_id: str) -> None: if ( self.batched_stream and self.batched_stream.comm and not self.batched_stream.comm.closed() ): self.batched_stream.send( - {"op": "worker-status-change", "status": self._status.name} + { + "op": "worker-status-change", + "status": self._status.name, + "stimulus_id": stimulus_id, + }, ) elif self._status != Status.closed: - self.loop.call_later(0.05, 
self._send_worker_status_change) + self.loop.call_later(0.05, self._send_worker_status_change, stimulus_id) async def get_metrics(self) -> dict: try: @@ -1083,6 +1089,7 @@ async def _register_with_scheduler(self): versions=get_versions(), metrics=await self.get_metrics(), extra=await self.get_startup_information(), + stimulus_id=f"worker-connect-{time()}", ), serializers=["msgpack"], ) @@ -1479,7 +1486,11 @@ async def close( with suppress(EnvironmentError, TimeoutError): if report and self.contact_address is not None: await asyncio.wait_for( - self.scheduler.unregister(address=self.contact_address, safe=safe), + self.scheduler.unregister( + address=self.contact_address, + safe=safe, + stimulus_id=f"worker-close-{time()}", + ), timeout, ) await self.scheduler.close_rpc() @@ -1543,7 +1554,10 @@ async def close_gracefully(self, restart=None): # Scheduler.retire_workers will set the status to closing_gracefully and push it # back to this worker. await self.scheduler.retire_workers( - workers=[self.address], close_workers=False, remove=False + workers=[self.address], + close_workers=False, + remove=False, + stimulus_id=f"worker-close-gracefully-{time()}", ) await self.close(safe=True, nanny=not restart) @@ -1894,7 +1908,9 @@ def handle_compute_task( pass elif ts.state == "memory": recommendations[ts] = "memory" - instructions.append(self._get_task_finished_msg(ts)) + instructions.append( + self._get_task_finished_msg(ts, stimulus_id=stimulus_id) + ) elif ts.state in { "released", "fetch", @@ -2033,7 +2049,7 @@ def transition_memory_released( recs, instructions = self.transition_generic_released( ts, stimulus_id=stimulus_id ) - instructions.append(ReleaseWorkerDataMsg(ts.key)) + instructions.append(ReleaseWorkerDataMsg(key=ts.key, stimulus_id=stimulus_id)) return recs, instructions def transition_waiting_constrained( @@ -2056,7 +2072,7 @@ def transition_long_running_rescheduled( self, ts: TaskState, *, stimulus_id: str ) -> RecsInstrs: recs: Recs = {ts: "released"} - smsg = RescheduleMsg(key=ts.key, worker=self.address) + smsg = RescheduleMsg(key=ts.key, worker=self.address, stimulus_id=stimulus_id) return recs, [smsg] def transition_executing_rescheduled( @@ -2067,7 +2083,14 @@ def transition_executing_rescheduled( self._executing.discard(ts) return merge_recs_instructions( - ({ts: "released"}, [RescheduleMsg(key=ts.key, worker=self.address)]), + ( + {ts: "released"}, + [ + RescheduleMsg( + key=ts.key, worker=self.address, stimulus_id=stimulus_id + ) + ], + ), self._ensure_computing(), ) @@ -2145,6 +2168,7 @@ def transition_generic_error( traceback_text=traceback_text, thread=self.threads.get(ts.key), startstops=ts.startstops, + stimulus_id=stimulus_id, ) return {}, [smsg] @@ -2334,7 +2358,9 @@ def transition_generic_memory( else: if self.validate: assert ts.key in self.data or ts.key in self.actors - instructions.append(self._get_task_finished_msg(ts)) + instructions.append( + self._get_task_finished_msg(ts, stimulus_id=stimulus_id) + ) return recs, instructions @@ -2453,7 +2479,10 @@ def transition_executing_long_running( self.long_running.add(ts.key) return merge_recs_instructions( - ({}, [LongRunningMsg(key=ts.key, compute_duration=compute_duration)]), + ( + {}, + [LongRunningMsg(key=ts.key, compute_duration=compute_duration)], + ), self._ensure_computing(), ) @@ -2700,6 +2729,9 @@ def story(self, *keys_or_tasks: str | TaskState) -> list[tuple]: keys = {e.key if isinstance(e, TaskState) else e for e in keys_or_tasks} return worker_story(keys, self.log) + async def get_story(self, 
keys=None): + return self.story(*keys) + def stimulus_story( self, *keys_or_tasks: str | TaskState ) -> list[StateMachineEvent]: @@ -2768,7 +2800,9 @@ def ensure_communicating(self) -> None: for el in skipped_worker_in_flight: self.data_needed.push(el) - def _get_task_finished_msg(self, ts: TaskState) -> TaskFinishedMsg: + def _get_task_finished_msg( + self, ts: TaskState, stimulus_id: str + ) -> TaskFinishedMsg: if ts.key not in self.data and ts.key not in self.actors: raise RuntimeError(f"Task {ts} not ready") typ = ts.type @@ -2794,6 +2828,7 @@ def _get_task_finished_msg(self, ts: TaskState) -> TaskFinishedMsg: metadata=ts.metadata, thread=self.threads.get(ts.key), startstops=ts.startstops, + stimulus_id=stimulus_id, ) def _put_key_in_memory(self, ts: TaskState, value, *, stimulus_id: str) -> Recs: @@ -3101,7 +3136,12 @@ async def gather_dep( self.has_what[worker].discard(ts.key) self.log.append((d, "missing-dep", stimulus_id, time())) self.batched_stream.send( - {"op": "missing-data", "errant_worker": worker, "key": d} + { + "op": "missing-data", + "errant_worker": worker, + "key": d, + "stimulus_id": stimulus_id, + } ) if ts.who_has: recommendations[ts] = "fetch" @@ -3207,7 +3247,7 @@ def handle_steal_request(self, key: str, stimulus_id: str) -> None: # `transition_constrained_executing` self.transition(ts, "released", stimulus_id=stimulus_id) - def handle_worker_status_change(self, status: str) -> None: + def handle_worker_status_change(self, status: str, stimulus_id: str) -> None: new_status = Status.lookup[status] # type: ignore if ( @@ -3218,7 +3258,7 @@ def handle_worker_status_change(self, status: str) -> None: "Invalid Worker.status transition: %s -> %s", self._status, new_status ) # Reiterate the current status to the scheduler to restore sync - self._send_worker_status_change() + self._send_worker_status_change(stimulus_id) else: # Update status and send confirmation to the Scheduler (see status.setter) self.status = new_status @@ -3572,7 +3612,7 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | No stop=result["stop"], nbytes=result["nbytes"], type=result["type"], - stimulus_id=stimulus_id, + stimulus_id=f"task-finished-{time()}", ) if isinstance(result["actual-exception"], Reschedule): @@ -3599,7 +3639,7 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | No traceback=result["traceback"], exception_text=result["exception_text"], traceback_text=result["traceback_text"], - stimulus_id=stimulus_id, + stimulus_id=f"task-erred-{time()}", ) except Exception as exc: @@ -3613,7 +3653,7 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | No traceback=msg["traceback"], exception_text=msg["exception_text"], traceback_text=msg["traceback_text"], - stimulus_id=stimulus_id, + stimulus_id=f"task-erred-{time()}", ) @functools.singledispatchmethod diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index 3e74eb8cc8..abdbc12108 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -303,6 +303,7 @@ class TaskFinishedMsg(SendMessageToScheduler): metadata: dict thread: int | None startstops: list[StartStop] + stimulus_id: str __slots__ = tuple(__annotations__) # type: ignore def to_dict(self) -> dict[str, Any]: @@ -322,6 +323,7 @@ class TaskErredMsg(SendMessageToScheduler): traceback_text: str thread: int | None startstops: list[StartStop] + stimulus_id: str __slots__ = tuple(__annotations__) # type: ignore def to_dict(self) -> 
dict[str, Any]: @@ -334,8 +336,9 @@ def to_dict(self) -> dict[str, Any]: class ReleaseWorkerDataMsg(SendMessageToScheduler): op = "release-worker-data" - __slots__ = ("key",) + __slots__ = ("key", "stimulus_id") key: str + stimulus_id: str # Not to be confused with RescheduleEvent below or the distributed.Reschedule Exception @@ -343,9 +346,10 @@ class ReleaseWorkerDataMsg(SendMessageToScheduler): class RescheduleMsg(SendMessageToScheduler): op = "reschedule" - __slots__ = ("key", "worker") + __slots__ = ("key", "worker", "stimulus_id") key: str worker: str + stimulus_id: str @dataclass @@ -438,6 +442,7 @@ class ExecuteSuccessEvent(StateMachineEvent): stop: float nbytes: int type: type | None + stimulus_id: str __slots__ = tuple(__annotations__) # type: ignore def to_loggable(self, *, handled: float) -> StateMachineEvent: @@ -460,6 +465,7 @@ class ExecuteFailureEvent(StateMachineEvent): traceback: Serialize | None exception_text: str traceback_text: str + stimulus_id: str __slots__ = tuple(__annotations__) # type: ignore def _after_from_dict(self) -> None:
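
Usage reference (not part of the diff itself): a minimal, illustrative sketch of how the renamed
assert_story helper and the per-node story methods touched above are typically driven from a test.
The fixtures (gen_cluster, inc) come from distributed.utils_test; the test name and the exact
expected events below are assumptions for illustration, not prescribed by this patch.

    from distributed.utils_test import assert_story, gen_cluster, inc


    @gen_cluster(client=True, nthreads=[("", 1)])
    async def test_assert_story_sketch(c, s, a):  # hypothetical test, for illustration only
        fut = c.submit(inc, 1, key="x")
        assert await fut == 2

        # Expected events may omit the trailing (stimulus_id, timestamp) pair;
        # individual fields may also be callables evaluated against the event.
        assert_story(
            a.story("x"),
            [
                ("x", "compute-task"),
                ("x", "ready", "executing", "executing", {}),
            ],
        )

        # A story stitched together from the scheduler and a worker carries no
        # guarantee of monotonically increasing timestamps, hence
        # ordered_timestamps=False.
        assert_story(
            s.story("x") + a.story("x"),
            [("x", "compute-task")],
            ordered_timestamps=False,
        )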