diff --git a/distributed/stories.py b/distributed/_stories.py
similarity index 100%
rename from distributed/stories.py
rename to distributed/_stories.py
diff --git a/distributed/client.py b/distributed/client.py
index 9ca8c0b2d2..af9bb32567 100644
--- a/distributed/client.py
+++ b/distributed/client.py
@@ -4276,6 +4276,31 @@ def collections_to_dsk(collections, *args, **kwargs):
"""Convert many collections into a single dask graph, after optimization"""
return collections_to_dsk(collections, *args, **kwargs)
+ async def _story(self, keys=(), on_error="raise"):
+ assert on_error in ("raise", "ignore")
+
+ try:
+ flat_stories = await self.scheduler.get_story(keys=keys)
+ flat_stories = [("scheduler", *msg) for msg in flat_stories]
+ except Exception:
+ if on_error == "raise":
+ raise
+ elif on_error == "ignore":
+ flat_stories = []
+ else:
+ raise ValueError(f"on_error not in {'raise', 'ignore'}")
+
+ responses = await self.scheduler.broadcast(
+ msg={"op": "get_story", "keys": keys}, on_error=on_error
+ )
+ for worker, stories in responses.items():
+ flat_stories.extend((worker, *msg) for msg in stories)
+ return flat_stories
+
+ def story(self, *keys_or_stimulus_ids, on_error="raise"):
+ """Returns a cluster-wide story for the given keys or simtulus_id's"""
+ return self.sync(self._story, keys=keys_or_stimulus_ids, on_error=on_error)
+
def get_task_stream(
self,
start=None,
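
The new `Client._story` / `Client.story` pair above collects the scheduler story via `scheduler.get_story` and every worker's story via `scheduler.broadcast`, prefixing each event with its origin. A minimal usage sketch, not part of the patch; the local cluster and the key name "sum-example" are illustrative assumptions:

    # Sketch only: assumes a locally started cluster; "sum-example" is an illustrative key.
    from distributed import Client

    client = Client()  # starts a local cluster when no scheduler address is given
    fut = client.submit(sum, [1, 2, 3], key="sum-example")
    fut.result()

    # Each returned event is prefixed with its origin: "scheduler" or a worker address.
    for origin, *event in client.story("sum-example", on_error="ignore"):
        print(origin, event)

    client.close()
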
diff --git a/distributed/cluster_dump.py b/distributed/cluster_dump.py
index 161f9091e7..5c77dc14ac 100644
--- a/distributed/cluster_dump.py
+++ b/distributed/cluster_dump.py
@@ -10,9 +10,9 @@
import fsspec
import msgpack
+from distributed._stories import scheduler_story as _scheduler_story
+from distributed._stories import worker_story as _worker_story
from distributed.compatibility import to_thread
-from distributed.stories import scheduler_story as _scheduler_story
-from distributed.stories import worker_story as _worker_story
DEFAULT_CLUSTER_DUMP_FORMAT: Literal["msgpack" | "yaml"] = "msgpack"
DEFAULT_CLUSTER_DUMP_EXCLUDE: Collection[str] = ("run_spec",)
diff --git a/distributed/http/templates/task.html b/distributed/http/templates/task.html
index 0b5c10695e..f10aaad560 100644
--- a/distributed/http/templates/task.html
+++ b/distributed/http/templates/task.html
@@ -118,16 +118,18 @@
Transition Log
<th>Key</th>
<th>Start</th>
<th>Finish</th>
+ <th>Stimulus ID</th>
<th>Recommended Key</th>
<th>Recommended Action</th>
- {% for key, start, finish, recommendations, transition_time in scheduler.story(Task) %}
+ {% for key, start, finish, recommendations, stimulus_id, transition_time in scheduler.story(Task) %}
<td>{{ fromtimestamp(transition_time) }}</td>
<td>{{key}}</td>
<td>{{ start }}</td>
<td>{{ finish }}</td>
+ <td>{{ stimulus_id }}</td>
<td></td>
<td></td>
@@ -137,6 +139,7 @@ Transition Log
<td></td>
<td></td>
<td></td>
+ <td></td>
<td>{{key2}}</td>
<td>{{ rec }}</td>
diff --git a/distributed/nanny.py b/distributed/nanny.py
index 826c3766e1..20778a5197 100644
--- a/distributed/nanny.py
+++ b/distributed/nanny.py
@@ -289,7 +289,10 @@ async def _unregister(self, timeout=10):
allowed_errors = (TimeoutError, CommClosedError, EnvironmentError, RPCClosed)
with suppress(allowed_errors):
await asyncio.wait_for(
- self.scheduler.unregister(address=self.worker_address), timeout
+ self.scheduler.unregister(
+ address=self.worker_address, stimulus_id=f"nanny-close-{time()}"
+ ),
+ timeout,
)
@property
diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index f65e34aeea..151ebaf7c8 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -54,6 +54,7 @@
from distributed import cluster_dump, preloading, profile
from distributed import versions as version_module
+from distributed._stories import scheduler_story
from distributed.active_memory_manager import ActiveMemoryManagerExtension, RetireWorker
from distributed.batched import BatchedSend
from distributed.comm import (
@@ -83,7 +84,6 @@
from distributed.security import Security
from distributed.semaphore import SemaphoreExtension
from distributed.stealing import WorkStealing
-from distributed.stories import scheduler_story
from distributed.utils import (
All,
TimeoutError,
@@ -109,6 +109,7 @@
DEFAULT_DATA_SIZE = parse_bytes(
dask.config.get("distributed.scheduler.default-data-size")
)
+STIMULUS_ID_UNSET = ""
DEFAULT_EXTENSIONS = {
"locks": LockExtension,
@@ -1527,7 +1528,7 @@ def new_task(
# State Transitions #
#####################
- def _transition(self, key, finish: str, *args, **kwargs):
+ def _transition(self, key, finish: str, stimulus_id: str, *args, **kwargs):
"""Transition a key from its current state to the finish state
Examples
@@ -1562,14 +1563,16 @@ def _transition(self, key, finish: str, *args, **kwargs):
start_finish = (start, finish)
func = self.transitions_table.get(start_finish)
if func is not None:
- recommendations, client_msgs, worker_msgs = func(key, *args, **kwargs) # type: ignore
+ recommendations, client_msgs, worker_msgs = func(
+ key, stimulus_id, *args, **kwargs
+ ) # type: ignore
self.transition_counter += 1
elif "released" not in start_finish:
assert not args and not kwargs, (args, kwargs, start_finish)
a_recs: dict
a_cmsgs: dict
a_wmsgs: dict
- a: tuple = self._transition(key, "released")
+ a: tuple = self._transition(key, "released", stimulus_id)
a_recs, a_cmsgs, a_wmsgs = a
v = a_recs.get(key, finish)
@@ -1577,7 +1580,7 @@ def _transition(self, key, finish: str, *args, **kwargs):
b_recs: dict
b_cmsgs: dict
b_wmsgs: dict
- b: tuple = func(key) # type: ignore
+ b: tuple = func(key, stimulus_id) # type: ignore
b_recs, b_cmsgs, b_wmsgs = b
recommendations.update(a_recs)
@@ -1612,13 +1615,20 @@ def _transition(self, key, finish: str, *args, **kwargs):
else:
raise RuntimeError("Impossible transition from %r to %r" % start_finish)
+ if not stimulus_id:
+ stimulus_id = STIMULUS_ID_UNSET
+
finish2 = ts._state
# FIXME downcast antipattern
scheduler = cast(Scheduler, self)
scheduler.transition_log.append(
- (key, start, finish2, recommendations, time())
+ (key, start, finish2, recommendations, stimulus_id, time())
)
if self.validate:
+ if stimulus_id == STIMULUS_ID_UNSET:
+ raise RuntimeError(
+ "stimulus_id not set during Scheduler transition"
+ )
logger.debug(
"Transitioned %r %s->%s (actual: %s). Consequence: %s",
key,
@@ -1662,7 +1672,13 @@ def _transition(self, key, finish: str, *args, **kwargs):
pdb.set_trace()
raise
- def _transitions(self, recommendations: dict, client_msgs: dict, worker_msgs: dict):
+ def _transitions(
+ self,
+ recommendations: dict,
+ client_msgs: dict,
+ worker_msgs: dict,
+ stimulus_id: str,
+ ):
"""Process transitions until none are left
This includes feedback from previous transitions and continues until we
@@ -1680,7 +1696,7 @@ def _transitions(self, recommendations: dict, client_msgs: dict, worker_msgs: di
key, finish = recommendations.popitem()
keys.add(key)
- new = self._transition(key, finish)
+ new = self._transition(key, finish, stimulus_id)
new_recs, new_cmsgs, new_wmsgs = new
recommendations.update(new_recs)
@@ -1703,7 +1719,7 @@ def _transitions(self, recommendations: dict, client_msgs: dict, worker_msgs: di
for key in keys:
scheduler.validate_key(key)
- def transition_released_waiting(self, key):
+ def transition_released_waiting(self, key, stimulus_id):
try:
ts: TaskState = self.tasks[key]
dts: TaskState
@@ -1758,7 +1774,7 @@ def transition_released_waiting(self, key):
pdb.set_trace()
raise
- def transition_no_worker_waiting(self, key):
+ def transition_no_worker_waiting(self, key, stimulus_id):
try:
ts: TaskState = self.tasks[key]
dts: TaskState
@@ -1806,7 +1822,13 @@ def transition_no_worker_waiting(self, key):
raise
def transition_no_worker_memory(
- self, key, nbytes=None, type=None, typename: str = None, worker=None
+ self,
+ key,
+ stimulus_id,
+ nbytes=None,
+ type=None,
+ typename: str = None,
+ worker=None,
):
try:
ws: WorkerState = self.workers[worker]
@@ -1961,7 +1983,7 @@ def set_duration_estimate(self, ts: TaskState, ws: WorkerState) -> float:
return total_duration
- def transition_waiting_processing(self, key):
+ def transition_waiting_processing(self, key, stimulus_id):
try:
ts: TaskState = self.tasks[key]
dts: TaskState
@@ -2006,7 +2028,14 @@ def transition_waiting_processing(self, key):
raise
def transition_waiting_memory(
- self, key, nbytes=None, type=None, typename: str = None, worker=None, **kwargs
+ self,
+ key,
+ stimulus_id,
+ nbytes=None,
+ type=None,
+ typename: str = None,
+ worker=None,
+ **kwargs,
):
try:
ws: WorkerState = self.workers[worker]
@@ -2048,6 +2077,7 @@ def transition_waiting_memory(
def transition_processing_memory(
self,
key,
+ stimulus_id,
nbytes=None,
type=None,
typename: str = None,
@@ -2091,7 +2121,7 @@ def transition_processing_memory(
{
"op": "cancel-compute",
"key": key,
- "stimulus_id": f"processing-memory-{time()}",
+ "stimulus_id": stimulus_id,
}
]
@@ -2141,7 +2171,7 @@ def transition_processing_memory(
pdb.set_trace()
raise
- def transition_memory_released(self, key, safe: bool = False):
+ def transition_memory_released(self, key, stimulus_id, safe: bool = False):
ws: WorkerState
try:
ts: TaskState = self.tasks[key]
@@ -2179,7 +2209,7 @@ def transition_memory_released(self, key, safe: bool = False):
worker_msg = {
"op": "free-keys",
"keys": [key],
- "stimulus_id": f"memory-released-{time()}",
+ "stimulus_id": stimulus_id,
}
for ws in ts.who_has:
worker_msgs[ws.address] = [worker_msg]
@@ -2211,7 +2241,7 @@ def transition_memory_released(self, key, safe: bool = False):
pdb.set_trace()
raise
- def transition_released_erred(self, key):
+ def transition_released_erred(self, key, stimulus_id):
try:
ts: TaskState = self.tasks[key]
dts: TaskState
@@ -2256,7 +2286,7 @@ def transition_released_erred(self, key):
pdb.set_trace()
raise
- def transition_erred_released(self, key):
+ def transition_erred_released(self, key, stimulus_id):
try:
ts: TaskState = self.tasks[key]
dts: TaskState
@@ -2282,7 +2312,7 @@ def transition_erred_released(self, key):
w_msg = {
"op": "free-keys",
"keys": [key],
- "stimulus_id": f"erred-released-{time()}",
+ "stimulus_id": stimulus_id,
}
for ws_addr in ts.erred_on:
worker_msgs[ws_addr] = [w_msg]
@@ -2304,7 +2334,7 @@ def transition_erred_released(self, key):
pdb.set_trace()
raise
- def transition_waiting_released(self, key):
+ def transition_waiting_released(self, key, stimulus_id):
try:
ts: TaskState = self.tasks[key]
recommendations: dict = {}
@@ -2341,7 +2371,7 @@ def transition_waiting_released(self, key):
pdb.set_trace()
raise
- def transition_processing_released(self, key):
+ def transition_processing_released(self, key, stimulus_id):
try:
ts: TaskState = self.tasks[key]
dts: TaskState
@@ -2361,7 +2391,7 @@ def transition_processing_released(self, key):
{
"op": "free-keys",
"keys": [key],
- "stimulus_id": f"processing-released-{time()}",
+ "stimulus_id": stimulus_id,
}
]
@@ -2395,6 +2425,7 @@ def transition_processing_released(self, key):
def transition_processing_erred(
self,
key: str,
+ stimulus_id: str,
cause: str = None,
exception=None,
traceback=None,
@@ -2481,7 +2512,7 @@ def transition_processing_erred(
pdb.set_trace()
raise
- def transition_no_worker_released(self, key):
+ def transition_no_worker_released(self, key, stimulus_id):
try:
ts: TaskState = self.tasks[key]
dts: TaskState
@@ -2523,7 +2554,7 @@ def remove_key(self, key):
ts.exception_blame = ts.exception = ts.traceback = None
self.task_metadata.pop(key, None)
- def transition_memory_forgotten(self, key):
+ def transition_memory_forgotten(self, key, stimulus_id):
ws: WorkerState
try:
ts: TaskState = self.tasks[key]
@@ -2551,7 +2582,7 @@ def transition_memory_forgotten(self, key):
for ws in ts.who_has:
ws.actors.discard(ts)
- _propagate_forgotten(self, ts, recommendations, worker_msgs)
+ _propagate_forgotten(self, ts, recommendations, worker_msgs, stimulus_id)
client_msgs = _task_to_client_msgs(self, ts)
self.remove_key(key)
@@ -2565,7 +2596,7 @@ def transition_memory_forgotten(self, key):
pdb.set_trace()
raise
- def transition_released_forgotten(self, key):
+ def transition_released_forgotten(self, key, stimulus_id):
try:
ts: TaskState = self.tasks[key]
recommendations: dict = {}
@@ -2589,7 +2620,7 @@ def transition_released_forgotten(self, key):
else:
assert 0, (ts,)
- _propagate_forgotten(self, ts, recommendations, worker_msgs)
+ _propagate_forgotten(self, ts, recommendations, worker_msgs, stimulus_id)
client_msgs = _task_to_client_msgs(self, ts)
self.remove_key(key)
@@ -3154,6 +3185,7 @@ def __init__(
"get_cluster_state": self.get_cluster_state,
"dump_cluster_state_to_url": self.dump_cluster_state_to_url,
"benchmark_hardware": self.benchmark_hardware,
+ "get_story": self.get_story,
}
connection_limit = get_fileno_limit() / 2
@@ -3480,7 +3512,7 @@ async def close(self, fast=False, close_workers=False):
disable_gc_diagnosis()
@log_errors
- async def close_worker(self, worker: str, safe: bool = False):
+ async def close_worker(self, worker: str, stimulus_id: str, safe: bool = False):
"""Remove a worker from the cluster
This both removes the worker from our local state and also sends a
@@ -3491,7 +3523,7 @@ async def close_worker(self, worker: str, safe: bool = False):
self.log_event(worker, {"action": "close-worker"})
# FIXME: This does not handle nannies
self.worker_send(worker, {"op": "close", "report": False})
- await self.remove_worker(address=worker, safe=safe)
+ await self.remove_worker(address=worker, safe=safe, stimulus_id=stimulus_id)
###########
# Stimuli #
@@ -3629,6 +3661,7 @@ async def add_worker(
versions: dict[str, Any] | None = None,
nanny=None,
extra=None,
+ stimulus_id=None,
):
"""Add a new worker to the cluster"""
address = self.coerce_address(address, resolve_address)
@@ -3727,12 +3760,15 @@ async def add_worker(
t: tuple = self._transition(
key,
"memory",
+ stimulus_id,
worker=address,
nbytes=nbytes[key],
typename=types[key],
)
recommendations, client_msgs, worker_msgs = t
- self._transitions(recommendations, client_msgs, worker_msgs)
+ self._transitions(
+ recommendations, client_msgs, worker_msgs, stimulus_id
+ )
recommendations = {}
else:
already_released_keys.append(key)
@@ -3743,7 +3779,7 @@ async def add_worker(
{
"op": "remove-replicas",
"keys": already_released_keys,
- "stimulus_id": f"reconnect-already-released-{time()}",
+ "stimulus_id": stimulus_id,
}
)
@@ -3751,7 +3787,7 @@ async def add_worker(
recommendations.update(self.bulk_schedule_after_adding_worker(ws))
if recommendations:
- self._transitions(recommendations, client_msgs, worker_msgs)
+ self._transitions(recommendations, client_msgs, worker_msgs, stimulus_id)
self.send_all(client_msgs, worker_msgs)
@@ -3779,7 +3815,7 @@ async def add_worker(
if comm:
await comm.write(msg)
- await self.handle_worker(comm=comm, worker=address)
+ await self.handle_worker(comm=comm, worker=address, stimulus_id=stimulus_id)
async def add_nanny(self, comm):
msg = {
@@ -3843,6 +3879,7 @@ def update_graph_hlg(
fifo_timeout,
annotations,
code=code,
+ stimulus_id=f"update-graph-{time()}",
)
def update_graph(
@@ -3862,12 +3899,14 @@ def update_graph(
fifo_timeout=0,
annotations=None,
code=None,
+ stimulus_id=None,
):
"""
Add new computations to the internal dask graph
This happens whenever the Client calls submit, map, get, or compute.
"""
+ stimulus_id = stimulus_id or f"update-graph-{time()}"
start = time()
fifo_timeout = parse_timedelta(fifo_timeout)
keys = set(keys)
@@ -3906,7 +3945,9 @@ def update_graph(
if k in keys:
keys.remove(k)
self.report({"op": "cancelled-key", "key": k}, client=client)
- self.client_releases_keys(keys=[k], client=client)
+ self.client_releases_keys(
+ keys=[k], client=client, stimulus_id=stimulus_id
+ )
# Avoid computation that is already finished
already_in_memory = set() # tasks that are already done
@@ -4121,7 +4162,7 @@ def update_graph(
except Exception as e:
logger.exception(e)
- self.transitions(recommendations)
+ self.transitions(recommendations, stimulus_id)
for ts in touched_tasks:
if ts.state in ("memory", "erred"):
@@ -4133,7 +4174,7 @@ def update_graph(
# TODO: balance workers
- def stimulus_task_finished(self, key=None, worker=None, **kwargs):
+ def stimulus_task_finished(self, key=None, worker=None, stimulus_id=None, **kwargs):
"""Mark that a task has finished execution on a particular worker"""
logger.debug("Stimulus task finished %s, %s", key, worker)
@@ -4156,14 +4197,16 @@ def stimulus_task_finished(self, key=None, worker=None, **kwargs):
{
"op": "free-keys",
"keys": [key],
- "stimulus_id": f"already-released-or-forgotten-{time()}",
+ "stimulus_id": stimulus_id,
}
]
elif ts.state == "memory":
self.add_keys(worker=worker, keys=[key])
else:
ts.metadata.update(kwargs["metadata"])
- r: tuple = self._transition(key, "memory", worker=worker, **kwargs)
+ r: tuple = self._transition(
+ key, "memory", stimulus_id, worker=worker, **kwargs
+ )
recommendations, client_msgs, worker_msgs = r
if ts.state == "memory":
@@ -4171,7 +4214,13 @@ def stimulus_task_finished(self, key=None, worker=None, **kwargs):
return recommendations, client_msgs, worker_msgs
def stimulus_task_erred(
- self, key=None, worker=None, exception=None, traceback=None, **kwargs
+ self,
+ key=None,
+ worker=None,
+ exception=None,
+ stimulus_id=None,
+ traceback=None,
+ **kwargs,
):
"""Mark that a task has erred on a particular worker"""
logger.debug("Stimulus task erred %s, %s", key, worker)
@@ -4182,11 +4231,12 @@ def stimulus_task_erred(
if ts.retries > 0:
ts.retries -= 1
- return self._transition(key, "waiting")
+ return self._transition(key, "waiting", stimulus_id)
else:
return self._transition(
key,
"erred",
+ stimulus_id,
cause=key,
exception=exception,
traceback=traceback,
@@ -4215,7 +4265,7 @@ def stimulus_retry(self, keys, client=None):
roots.append(key)
recommendations: dict = {key: "waiting" for key in roots}
- self.transitions(recommendations)
+ self.transitions(recommendations, f"stimulus-retry-{time()}")
if self.validate:
for key in seen:
@@ -4224,7 +4274,7 @@ def stimulus_retry(self, keys, client=None):
return tuple(seen)
@log_errors
- async def remove_worker(self, address, safe=False, close=True):
+ async def remove_worker(self, address, stimulus_id, safe=False, close=True):
"""
Remove worker from cluster
@@ -4291,7 +4341,9 @@ async def remove_worker(self, address, safe=False, close=True):
e = pickle.dumps(
KilledWorker(task=k, last_worker=ws.clean()), protocol=4
)
- r = self.transition(k, "erred", exception=e, cause=k)
+ r = self.transition(
+ k, "erred", exception=e, cause=k, stimulus_id=stimulus_id
+ )
recommendations.update(r)
logger.info(
"Task %s marked as failed because %d workers died"
@@ -4308,7 +4360,7 @@ async def remove_worker(self, address, safe=False, close=True):
else: # pure data
recommendations[ts.key] = "forgotten"
- self.transitions(recommendations)
+ self.transitions(recommendations, stimulus_id=stimulus_id)
for plugin in list(self.plugins.values()):
try:
@@ -4370,7 +4422,9 @@ def cancel_key(self, key, client, retries=5, force=False):
self.report({"op": "cancelled-key", "key": key})
clients = list(ts.who_wants) if force else [cs]
for cs in clients:
- self.client_releases_keys(keys=[key], client=cs.client_key)
+ self.client_releases_keys(
+ keys=[key], client=cs.client_key, stimulus_id=f"cancel-key-{time()}"
+ )
def client_desires_keys(self, keys=None, client=None):
cs: ClientState = self.clients.get(client)
@@ -4388,15 +4442,16 @@ def client_desires_keys(self, keys=None, client=None):
if ts.state in ("memory", "erred"):
self.report_on_key(ts=ts, client=client)
- def client_releases_keys(self, keys=None, client=None):
+ def client_releases_keys(self, keys=None, client=None, stimulus_id=None):
"""Remove keys from client desired list"""
+ stimulus_id = stimulus_id or f"client-releases-keys-{time()}"
if not isinstance(keys, list):
keys = list(keys)
cs: ClientState = self.clients[client]
recommendations: dict = {}
_client_releases_keys(self, keys=keys, cs=cs, recommendations=recommendations)
- self.transitions(recommendations)
+ self.transitions(recommendations, stimulus_id)
def client_heartbeat(self, client=None):
"""Handle heartbeats from Client"""
@@ -4625,7 +4680,7 @@ async def add_client(
try:
await self.handle_stream(comm=comm, extra={"client": client})
finally:
- self.remove_client(client=client)
+ self.remove_client(client=client, stimulus_id=f"remove-client-{time()}")
logger.debug("Finished handling client %s", client)
finally:
if not comm.closed():
@@ -4639,8 +4694,9 @@ async def add_client(
except TypeError: # comm becomes None during GC
pass
- def remove_client(self, client: str) -> None:
+ def remove_client(self, client: str, stimulus_id: str = None) -> None:
"""Remove client from network"""
+ stimulus_id = stimulus_id or f"remove-client-{time()}"
if self.status == Status.running:
logger.info("Remove client %s", client)
self.log_event(["all", client], {"action": "remove-client", "client": client})
@@ -4651,7 +4707,9 @@ def remove_client(self, client: str) -> None:
pass
else:
self.client_releases_keys(
- keys=[ts.key for ts in cs.wants_what], client=cs.client_key
+ keys=[ts.key for ts in cs.wants_what],
+ client=cs.client_key,
+ stimulus_id=stimulus_id,
)
del self.clients[client]
@@ -4687,24 +4745,30 @@ def send_task_to_worker(self, worker, ts: TaskState, duration: float = -1):
def handle_uncaught_error(self, **msg):
logger.exception(clean_exception(**msg)[1])
- def handle_task_finished(self, key=None, worker=None, **msg):
+ def handle_task_finished(
+ self, key: str, worker: str, stimulus_id: str, **msg
+ ) -> None:
if worker not in self.workers:
return
validate_key(key)
- r: tuple = self.stimulus_task_finished(key=key, worker=worker, **msg)
+
+ r: tuple = self.stimulus_task_finished(
+ key=key, worker=worker, stimulus_id=stimulus_id, **msg
+ )
recommendations, client_msgs, worker_msgs = r
- self._transitions(recommendations, client_msgs, worker_msgs)
+ self._transitions(recommendations, client_msgs, worker_msgs, stimulus_id)
self.send_all(client_msgs, worker_msgs)
- def handle_task_erred(self, key=None, **msg):
- r: tuple = self.stimulus_task_erred(key=key, **msg)
+ def handle_task_erred(self, key: str, stimulus_id: str, **msg) -> None:
+ r: tuple = self.stimulus_task_erred(key=key, stimulus_id=stimulus_id, **msg)
recommendations, client_msgs, worker_msgs = r
- self._transitions(recommendations, client_msgs, worker_msgs)
-
+ self._transitions(recommendations, client_msgs, worker_msgs, stimulus_id)
self.send_all(client_msgs, worker_msgs)
- def handle_missing_data(self, key=None, errant_worker=None, **kwargs):
+ def handle_missing_data(
+ self, key: str, errant_worker: str, stimulus_id: str, **kwargs
+ ) -> None:
"""Signal that `errant_worker` does not hold `key`
This may either indicate that `errant_worker` is dead or that we may be
@@ -4723,20 +4787,22 @@ def handle_missing_data(self, key=None, errant_worker=None, **kwargs):
"""
logger.debug("handle missing data key=%s worker=%s", key, errant_worker)
self.log_event(errant_worker, {"action": "missing-data", "key": key})
- ts: TaskState = self.tasks.get(key)
- if ts is None:
+
+ if key not in self.tasks:
return
+
+ ts: TaskState = self.tasks[key]
ws: WorkerState = self.workers.get(errant_worker)
if ws is not None and ws in ts.who_has:
self.remove_replica(ts, ws)
if ts.state == "memory" and not ts.who_has:
if ts.run_spec:
- self.transitions({key: "released"})
+ self.transitions({key: "released"}, stimulus_id)
else:
- self.transitions({key: "forgotten"})
+ self.transitions({key: "forgotten"}, stimulus_id)
- def release_worker_data(self, key, worker):
+ def release_worker_data(self, key, worker, stimulus_id):
ws: WorkerState = self.workers.get(worker)
ts: TaskState = self.tasks.get(key)
if not ws or not ts:
@@ -4747,7 +4813,7 @@ def release_worker_data(self, key, worker):
if not ts.who_has:
recommendations[ts.key] = "released"
if recommendations:
- self.transitions(recommendations)
+ self.transitions(recommendations, stimulus_id)
def handle_long_running(self, key=None, worker=None, compute_duration=None):
"""A task has seceded from the thread pool
@@ -4789,7 +4855,9 @@ def handle_long_running(self, key=None, worker=None, compute_duration=None):
ws.long_running.add(ts)
self.check_idle_saturated(ws)
- def handle_worker_status_change(self, status: str, worker: str) -> None:
+ def handle_worker_status_change(
+ self, status: str, worker: str, stimulus_id: str
+ ) -> None:
ws: WorkerState = self.workers.get(worker) # type: ignore
if not ws:
return
@@ -4813,13 +4881,13 @@ def handle_worker_status_change(self, status: str, worker: str) -> None:
if recs:
client_msgs: dict = {}
worker_msgs: dict = {}
- self._transitions(recs, client_msgs, worker_msgs)
+ self._transitions(recs, client_msgs, worker_msgs, stimulus_id)
self.send_all(client_msgs, worker_msgs)
else:
self.running.discard(ws)
- async def handle_worker(self, comm=None, worker=None):
+ async def handle_worker(self, comm=None, worker=None, stimulus_id=None):
"""
Listen to responses from a single worker
@@ -4829,6 +4897,7 @@ async def handle_worker(self, comm=None, worker=None):
--------
Scheduler.handle_client: Equivalent coroutine for clients
"""
+ assert stimulus_id
comm.name = "Scheduler connection to worker"
worker_comm = self.stream_comms[worker]
worker_comm.start(comm)
@@ -4838,7 +4907,7 @@ async def handle_worker(self, comm=None, worker=None):
finally:
if worker in self.stream_comms:
worker_comm.abort()
- await self.remove_worker(address=worker)
+ await self.remove_worker(address=worker, stimulus_id=stimulus_id)
def add_plugin(
self,
@@ -4932,7 +5001,11 @@ def worker_send(self, worker, msg):
try:
stream_comms[worker].send(msg)
except (CommClosedError, AttributeError):
- self.loop.add_callback(self.remove_worker, address=worker)
+ self.loop.add_callback(
+ self.remove_worker,
+ address=worker,
+ stimulus_id=f"worker-send-comm-fail-{time()}",
+ )
def client_send(self, client, msg):
"""Send message to client"""
@@ -4977,7 +5050,11 @@ def send_all(self, client_msgs: dict, worker_msgs: dict):
# worker already gone
pass
except (CommClosedError, AttributeError):
- self.loop.add_callback(self.remove_worker, address=worker)
+ self.loop.add_callback(
+ self.remove_worker,
+ address=worker,
+ stimulus_id=f"send-all-comm-fail-{time()}",
+ )
############################
# Less common interactions #
@@ -5034,6 +5111,7 @@ async def scatter(
async def gather(self, keys, serializers=None):
"""Collect data from workers to the scheduler"""
+ stimulus_id = f"gather-{time()}"
keys = list(keys)
who_has = {}
for key in keys:
@@ -5065,7 +5143,9 @@ async def gather(self, keys, serializers=None):
# reconnect.
await asyncio.gather(
*(
- self.remove_worker(address=worker, close=False)
+ self.remove_worker(
+ address=worker, close=False, stimulus_id=stimulus_id
+ )
for worker in missing_workers
)
)
@@ -5087,8 +5167,14 @@ async def gather(self, keys, serializers=None):
for worker in workers:
ws = self.workers.get(worker)
if ws is not None and ws in ts.who_has:
+ # FIXME: This code path is not tested
self.remove_replica(ts, ws)
- self._transitions(recommendations, client_msgs, worker_msgs)
+ self._transitions(
+ recommendations,
+ client_msgs,
+ worker_msgs,
+ stimulus_id=stimulus_id,
+ )
self.send_all(client_msgs, worker_msgs)
self.log_event("all", {"action": "gather", "count": len(keys)})
@@ -5104,12 +5190,15 @@ def clear_task_state(self):
@log_errors
async def restart(self, client=None, timeout=30):
"""Restart all workers. Reset local state."""
+ stimulus_id = f"restart-{time()}"
n_workers = len(self.workers)
logger.info("Send lost future signal to clients")
for cs in self.clients.values():
self.client_releases_keys(
- keys=[ts.key for ts in cs.wants_what], client=cs.client_key
+ keys=[ts.key for ts in cs.wants_what],
+ client=cs.client_key,
+ stimulus_id=stimulus_id,
)
nannies = {addr: ws.nanny for addr, ws in self.workers.items()}
@@ -5118,7 +5207,9 @@ async def restart(self, client=None, timeout=30):
try:
# Ask the worker to close if it doesn't have a nanny,
# otherwise the nanny will kill it anyway
- await self.remove_worker(address=addr, close=addr not in nannies)
+ await self.remove_worker(
+ address=addr, close=addr not in nannies, stimulus_id=stimulus_id
+ )
except Exception:
logger.info(
"Exception while restarting. This is normal", exc_info=True
@@ -5307,7 +5398,7 @@ async def gather_on_worker(
return keys_failed
async def delete_worker_data(
- self, worker_address: str, keys: "Collection[str]"
+ self, worker_address: str, keys: "Collection[str]", stimulus_id: str
) -> None:
"""Delete data from a worker and update the corresponding worker/task states
@@ -5344,7 +5435,7 @@ async def delete_worker_data(
self.remove_replica(ts, ws)
if not ts.who_has:
# Last copy deleted
- self.transitions({key: "released"})
+ self.transitions({key: "released"}, stimulus_id)
self.log_event(ws.address, {"action": "remove-worker-data", "keys": keys})
@@ -5354,6 +5445,7 @@ async def rebalance(
comm=None,
keys: "Iterable[Hashable]" = None,
workers: "Iterable[str]" = None,
+ stimulus_id: str = None,
) -> dict:
"""Rebalance keys so that each worker ends up with roughly the same process
memory (managed+unmanaged).
@@ -5420,6 +5512,7 @@ async def rebalance(
All other workers will be ignored. The mean cluster occupancy will be
calculated only using the allowed workers.
"""
+ stimulus_id = stimulus_id or f"rebalance-{time()}"
if workers is not None:
wss = [self.workers[w] for w in workers]
else:
@@ -5443,7 +5536,7 @@ async def rebalance(
return {"status": "OK"}
async with self._lock:
- result = await self._rebalance_move_data(msgs)
+ result = await self._rebalance_move_data(msgs, stimulus_id)
if result["status"] == "partial-fail" and keys is None:
# Only return failed keys if the client explicitly asked for them
result = {"status": "OK"}
@@ -5640,7 +5733,7 @@ def _rebalance_find_msgs(
return msgs
async def _rebalance_move_data(
- self, msgs: "list[tuple[WorkerState, WorkerState, TaskState]]"
+ self, msgs: "list[tuple[WorkerState, WorkerState, TaskState]]", stimulus_id: str
) -> dict:
"""Perform the actual transfer of data across the network in rebalance().
Takes in input the output of _rebalance_find_msgs(), that is a list of tuples:
@@ -5676,7 +5769,7 @@ async def _rebalance_move_data(
# Note: this never raises exceptions
await asyncio.gather(
- *(self.delete_worker_data(r, v) for r, v in to_senders.items())
+ *(self.delete_worker_data(r, v, stimulus_id) for r, v in to_senders.items())
)
for r, v in to_recipients.items():
@@ -5706,6 +5799,7 @@ async def replicate(
branching_factor=2,
delete=True,
lock=True,
+ stimulus_id=None,
):
"""Replicate data throughout cluster
@@ -5728,6 +5822,7 @@ async def replicate(
--------
Scheduler.rebalance
"""
+ stimulus_id = stimulus_id or f"replicate-{time()}"
assert branching_factor > 0
async with self._lock if lock else empty_context:
if workers is not None:
@@ -5762,7 +5857,9 @@ async def replicate(
# Note: this never raises exceptions
await asyncio.gather(
*[
- self.delete_worker_data(ws.address, [t.key for t in tasks])
+ self.delete_worker_data(
+ ws.address, [t.key for t in tasks], stimulus_id
+ )
for ws, tasks in del_worker_tasks.items()
]
)
@@ -5951,6 +6048,7 @@ async def retire_workers(
names: "list | None" = None,
close_workers: bool = False,
remove: bool = True,
+ stimulus_id: str = None,
**kwargs,
) -> dict:
"""Gracefully retire workers from cluster
@@ -5984,6 +6082,7 @@ async def retire_workers(
--------
Scheduler.workers_to_close
"""
+ stimulus_id = stimulus_id or f"retire-workers-{time()}"
# This lock makes retire_workers, rebalance, and replicate mutually
# exclusive and will no longer be necessary once rebalance and replicate are
# migrated to the Active Memory Manager.
@@ -6034,7 +6133,11 @@ async def retire_workers(
ws.status = Status.closing_gracefully
self.running.discard(ws)
self.stream_comms[ws.address].send(
- {"op": "worker-status-change", "status": ws.status.name}
+ {
+ "op": "worker-status-change",
+ "status": ws.status.name,
+ "stimulus_id": stimulus_id,
+ }
)
coros.append(
@@ -6044,6 +6147,7 @@ async def retire_workers(
prev_status=prev_status,
close_workers=close_workers,
remove=remove,
+ stimulus_id=stimulus_id,
)
)
@@ -6070,6 +6174,7 @@ async def _track_retire_worker(
prev_status: Status,
close_workers: bool,
remove: bool,
+ stimulus_id: str,
) -> tuple: # tuple[str | None, dict]
while not policy.done():
if policy.no_recipients:
@@ -6077,7 +6182,11 @@ async def _track_retire_worker(
# conditions and we can wait for a scheduler->worker->scheduler
# round-trip.
self.stream_comms[ws.address].send(
- {"op": "worker-status-change", "status": prev_status.name}
+ {
+ "op": "worker-status-change",
+ "status": prev_status.name,
+ "stimulus_id": stimulus_id,
+ }
)
return None, {}
@@ -6091,9 +6200,13 @@ async def _track_retire_worker(
)
if close_workers and ws.address in self.workers:
- await self.close_worker(worker=ws.address, safe=True)
+ await self.close_worker(
+ worker=ws.address, safe=True, stimulus_id=stimulus_id
+ )
if remove:
- await self.remove_worker(address=ws.address, safe=True)
+ await self.remove_worker(
+ address=ws.address, safe=True, stimulus_id=stimulus_id
+ )
logger.info("Retired worker %s", ws.address)
return ws.address, ws.identity()
@@ -6524,7 +6637,7 @@ async def unregister_nanny_plugin(self, comm, name):
)
return responses
- def transition(self, key, finish: str, *args, **kwargs):
+ def transition(self, key, finish: str, *args, stimulus_id: str, **kwargs):
"""Transition a key from its current state to the finish state
Examples
@@ -6540,12 +6653,12 @@ def transition(self, key, finish: str, *args, **kwargs):
--------
Scheduler.transitions: transitive version of this function
"""
- a: tuple = self._transition(key, finish, *args, **kwargs)
+ a: tuple = self._transition(key, finish, stimulus_id, *args, **kwargs)
recommendations, client_msgs, worker_msgs = a
self.send_all(client_msgs, worker_msgs)
return recommendations
- def transitions(self, recommendations: dict):
+ def transitions(self, recommendations: dict, stimulus_id: str):
"""Process transitions until none are left
This includes feedback from previous transitions and continues until we
@@ -6553,7 +6666,7 @@ def transitions(self, recommendations: dict):
"""
client_msgs: dict = {}
worker_msgs: dict = {}
- self._transitions(recommendations, client_msgs, worker_msgs)
+ self._transitions(recommendations, client_msgs, worker_msgs, stimulus_id)
self.send_all(client_msgs, worker_msgs)
def story(self, *keys):
@@ -6561,6 +6674,9 @@ def story(self, *keys):
keys = {key.key if isinstance(key, TaskState) else key for key in keys}
return scheduler_story(keys, self.transition_log)
+ async def get_story(self, keys=()):
+ return self.story(*keys)
+
transition_story = story
def reschedule(self, key=None, worker=None):
@@ -6581,7 +6697,7 @@ def reschedule(self, key=None, worker=None):
return
if worker and ts.processing_on.address != worker:
return
- self.transitions({key: "released"})
+ self.transitions({key: "released"}, f"reschedule-{time()}")
#####################
# Utility functions #
@@ -7005,6 +7121,7 @@ def reevaluate_occupancy(self, worker_index: int = 0):
async def check_worker_ttl(self):
now = time()
+ stimulus_id = f"check-worker-ttl-{now}"
for ws in self.workers.values():
if (ws.last_seen < now - self.worker_ttl) and (
ws.last_seen < now - 10 * heartbeat_interval(len(self.workers))
@@ -7014,7 +7131,7 @@ async def check_worker_ttl(self):
self.worker_ttl,
ws,
)
- await self.remove_worker(address=ws.address)
+ await self.remove_worker(address=ws.address, stimulus_id=stimulus_id)
def check_idle(self):
if any([ws.processing for ws in self.workers.values()]) or self.unrunnable:
@@ -7223,7 +7340,11 @@ def _add_to_memory(
def _propagate_forgotten(
- state: SchedulerState, ts: TaskState, recommendations: dict, worker_msgs: dict
+ state: SchedulerState,
+ ts: TaskState,
+ recommendations: dict,
+ worker_msgs: dict,
+ stimulus_id: str,
):
ts.state = "forgotten"
key: str = ts.key
@@ -7254,7 +7375,7 @@ def _propagate_forgotten(
{
"op": "free-keys",
"keys": [key],
- "stimulus_id": f"propagate-forgotten-{time()}",
+ "stimulus_id": stimulus_id,
}
]
state.remove_all_replicas(ts)
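
With the scheduler changes above, every entry appended to `Scheduler.transition_log` now carries the stimulus_id as its fifth field, and `STIMULUS_ID_UNSET` marks transitions recorded without one. A short sketch of consuming the extended log, not part of the patch; `scheduler` stands in for any Scheduler instance:

    # Sketch only: transition_log entries are now 6-tuples instead of 5-tuples.
    for key, start, finish, recommendations, stimulus_id, timestamp in scheduler.transition_log:
        if not stimulus_id:  # STIMULUS_ID_UNSET, see the definition above
            print(f"{key}: {start} -> {finish} recorded without a stimulus_id")
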
diff --git a/distributed/tests/test_cancelled_state.py b/distributed/tests/test_cancelled_state.py
index e7015cf1f3..dd1adc070f 100644
--- a/distributed/tests/test_cancelled_state.py
+++ b/distributed/tests/test_cancelled_state.py
@@ -4,7 +4,7 @@
from distributed import Event, Worker
from distributed.utils_test import (
_LockedCommPool,
- assert_worker_story,
+ assert_story,
gen_cluster,
inc,
slowinc,
@@ -80,7 +80,7 @@ def f(ev):
while "f1" in a.tasks:
await asyncio.sleep(0.01)
- assert_worker_story(
+ assert_story(
a.story("f1"),
[
("f1", "compute-task"),
@@ -156,7 +156,7 @@ async def get_data(self, comm, *args, **kwargs):
await wait_for_state(fut1.key, "flight", b)
# Close in scheduler to ensure we transition and reschedule task properly
- await s.close_worker(worker=a.address)
+ await s.close_worker(worker=a.address, stimulus_id="test")
await wait_for_state(fut1.key, "resumed", b)
lock.release()
diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py
index 9fc8982717..00783cb924 100644
--- a/distributed/tests/test_client.py
+++ b/distributed/tests/test_client.py
@@ -4462,7 +4462,7 @@ def test_auto_normalize_collection_sync(c):
def assert_no_data_loss(scheduler):
- for key, start, finish, recommendations, _ in scheduler.transition_log:
+ for key, start, finish, recommendations, _, _ in scheduler.transition_log:
if start == "memory" and finish == "released":
for k, v in recommendations.items():
assert not (k == key and v == "waiting")
diff --git a/distributed/tests/test_cluster_dump.py b/distributed/tests/test_cluster_dump.py
index c3912116c4..540e0abb4e 100644
--- a/distributed/tests/test_cluster_dump.py
+++ b/distributed/tests/test_cluster_dump.py
@@ -8,7 +8,7 @@
import distributed
from distributed.cluster_dump import DumpArtefact, _tuple_to_list, write_state
-from distributed.utils_test import assert_worker_story, gen_cluster, gen_test, inc
+from distributed.utils_test import assert_story, gen_cluster, gen_test, inc
@pytest.mark.parametrize(
@@ -126,31 +126,12 @@ async def test_cluster_dump_story(c, s, a, b, tmp_path):
assert story.keys() == {"f1", "f2"}
for k, task_story in story.items():
- expected = [
- (k, "released", "waiting", {k: "processing"}),
- (k, "waiting", "processing", {}),
- (k, "processing", "memory", {}),
- ]
-
- for event, expected_event in zip(task_story, expected):
- for e1, e2 in zip(event, expected_event):
- assert e1 == e2
+ assert_story(task_story, s.story(k))
story = dump.worker_story("f1", "f2")
assert story.keys() == {"f1", "f2"}
-
for k, task_story in story.items():
- assert_worker_story(
- task_story,
- [
- (k, "compute-task"),
- (k, "released", "waiting", "waiting", {k: "ready"}),
- (k, "waiting", "ready", "ready", {k: "executing"}),
- (k, "ready", "executing", "executing", {}),
- (k, "put-in-memory"),
- (k, "executing", "memory", "memory", {}),
- ],
- )
+ assert_story(task_story, a.story(k) + b.story(k))
@gen_cluster(client=True)
diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py
index 1f0d335192..ff76102475 100644
--- a/distributed/tests/test_scheduler.py
+++ b/distributed/tests/test_scheduler.py
@@ -313,7 +313,7 @@ async def test_remove_worker_from_scheduler(s, a, b):
)
assert a.address in s.stream_comms
- await s.remove_worker(address=a.address)
+ await s.remove_worker(address=a.address, stimulus_id="test")
assert a.address not in s.workers
assert len(s.workers[b.address].processing) == len(dsk) # b owns everything
@@ -321,9 +321,12 @@ async def test_remove_worker_from_scheduler(s, a, b):
@gen_cluster()
async def test_remove_worker_by_name_from_scheduler(s, a, b):
assert a.address in s.stream_comms
- assert await s.remove_worker(address=a.name) == "OK"
+ assert await s.remove_worker(address=a.name, stimulus_id="test") == "OK"
assert a.address not in s.workers
- assert await s.remove_worker(address=a.address) == "already-removed"
+ assert (
+ await s.remove_worker(address=a.address, stimulus_id="test")
+ == "already-removed"
+ )
@gen_cluster(config={"distributed.scheduler.events-cleanup-delay": "10 ms"})
@@ -333,7 +336,7 @@ async def test_clear_events_worker_removal(s, a, b):
assert b.address in s.events
assert b.address in s.workers
- await s.remove_worker(address=a.address)
+ await s.remove_worker(address=a.address, stimulus_id="test")
# Shortly after removal, the events should still be there
assert a.address in s.events
assert a.address not in s.workers
@@ -596,7 +599,7 @@ async def test_ready_remove_worker(s, a, b):
assert all(len(w.processing) > w.nthreads for w in s.workers.values())
- await s.remove_worker(address=a.address)
+ await s.remove_worker(address=a.address, stimulus_id="test")
assert set(s.workers) == {b.address}
assert all(len(w.processing) > w.nthreads for w in s.workers.values())
@@ -789,7 +792,7 @@ async def test_story(c, s, a, b):
story = s.story(x.key)
assert all(line in s.transition_log for line in story)
assert len(story) < len(s.transition_log)
- assert all(x.key == line[0] or x.key in line[-2] for line in story)
+ assert all(x.key == line[0] or x.key in line[3] for line in story)
assert len(s.story(x.key, y.key)) > len(story)
@@ -1212,7 +1215,7 @@ def f(dask_scheduler=None):
async def test_close_worker(c, s, a, b):
assert len(s.workers) == 2
- await s.close_worker(worker=a.address)
+ await s.close_worker(worker=a.address, stimulus_id="test")
assert len(s.workers) == 1
assert a.address not in s.workers
@@ -1230,7 +1233,7 @@ async def test_close_nanny(c, s, a, b):
assert a.process.is_alive()
a_worker_address = a.worker_address
start = time()
- await s.close_worker(worker=a_worker_address)
+ await s.close_worker(worker=a_worker_address, stimulus_id="test")
assert len(s.workers) == 1
assert a_worker_address not in s.workers
@@ -3094,7 +3097,9 @@ async def test_rebalance_dead_recipient(client, s, a, b, c):
await c.close()
assert s.workers.keys() == {a.address, b.address}
- out = await s._rebalance_move_data([(a_ws, b_ws, x_ts), (a_ws, c_ws, y_ts)])
+ out = await s._rebalance_move_data(
+ [(a_ws, b_ws, x_ts), (a_ws, c_ws, y_ts)], stimulus_id="test"
+ )
assert out == {"status": "partial-fail", "keys": [y.key]}
assert a.data == {y.key: "y"}
assert b.data == {x.key: "x"}
@@ -3113,7 +3118,7 @@ async def test_delete_worker_data(c, s, a, b):
assert b.data == {y.key: "y"}
assert s.tasks.keys() == {x.key, y.key, z.key}
- await s.delete_worker_data(a.address, [x.key, y.key])
+ await s.delete_worker_data(a.address, [x.key, y.key], stimulus_id="test")
assert a.data == {z.key: "z"}
assert b.data == {y.key: "y"}
assert s.tasks.keys() == {y.key, z.key}
@@ -3127,8 +3132,8 @@ async def test_delete_worker_data_double_delete(c, s, a):
"""
x, y = await c.scatter(["x", "y"])
await asyncio.gather(
- s.delete_worker_data(a.address, [x.key]),
- s.delete_worker_data(a.address, [x.key]),
+ s.delete_worker_data(a.address, [x.key], stimulus_id="test"),
+ s.delete_worker_data(a.address, [x.key], stimulus_id="test"),
)
assert a.data == {y.key: "y"}
a_ws = s.workers[a.address]
@@ -3143,7 +3148,7 @@ async def test_delete_worker_data_bad_worker(s, a, b):
"""
await a.close()
assert s.workers.keys() == {b.address}
- await s.delete_worker_data(a.address, ["x"])
+ await s.delete_worker_data(a.address, ["x"], stimulus_id="test")
@pytest.mark.parametrize("bad_first", [False, True])
@@ -3158,7 +3163,7 @@ async def test_delete_worker_data_bad_task(c, s, a, bad_first):
assert s.tasks.keys() == {x.key, y.key}
keys = ["notexist", x.key] if bad_first else [x.key, "notexist"]
- await s.delete_worker_data(a.address, keys)
+ await s.delete_worker_data(a.address, keys, stimulus_id="test")
assert a.data == {y.key: "y"}
assert s.tasks.keys() == {y.key}
assert s.workers[a.address].nbytes == s.tasks[y.key].nbytes
@@ -3243,13 +3248,13 @@ async def test_worker_reconnect_task_memory(c, s, a):
while not a.executing_count and not a.data:
await asyncio.sleep(0.001)
- await s.remove_worker(address=a.address, close=False)
+ await s.remove_worker(address=a.address, close=False, stimulus_id="test")
while not res.done():
await a.heartbeat()
await res
assert ("no-worker", "memory") in {
- (start, finish) for (_, start, finish, _, _) in s.transition_log
+ (start, finish) for (_, start, finish, _, _, _) in s.transition_log
}
@@ -3267,13 +3272,13 @@ async def test_worker_reconnect_task_memory_with_resources(c, s, a):
while not b.executing_count and not b.data:
await asyncio.sleep(0.001)
- await s.remove_worker(address=b.address, close=False)
+ await s.remove_worker(address=b.address, close=False, stimulus_id="test")
while not res.done():
await b.heartbeat()
await res
assert ("no-worker", "memory") in {
- (start, finish) for (_, start, finish, _, _) in s.transition_log
+ (start, finish) for (_, start, finish, _, _, _) in s.transition_log
}
diff --git a/distributed/tests/test_stories.py b/distributed/tests/test_stories.py
new file mode 100644
index 0000000000..1616422f25
--- /dev/null
+++ b/distributed/tests/test_stories.py
@@ -0,0 +1,170 @@
+import pytest
+
+import dask
+
+from distributed import Worker
+from distributed.comm import CommClosedError
+from distributed.utils_test import assert_story, assert_valid_story, gen_cluster, inc
+
+
+@gen_cluster(client=True, nthreads=[("", 1)])
+async def test_scheduler_story_stimulus_success(c, s, a):
+ f = c.submit(inc, 1)
+ key = f.key
+
+ await f
+
+ stories = s.story(key)
+
+ stimulus_ids = {s[-2] for s in stories}
+ # Two events
+ # - Compute
+ # - Success
+ assert len(stimulus_ids) == 2
+ assert_story(
+ stories,
+ [
+ (key, "released", "waiting", {key: "processing"}),
+ (key, "waiting", "processing", {}),
+ (key, "processing", "memory", {}),
+ ],
+ )
+
+ await c.close()
+
+ stories_after_close = s.story(key)
+ assert len(stories_after_close) > len(stories)
+
+ stimulus_ids = {s[-2] for s in stories_after_close}
+ # One more event
+ # - Forget / Release / Free since client closed
+ assert len(stimulus_ids) == 3
+
+
+@gen_cluster(client=True, nthreads=[("", 1)])
+async def test_scheduler_story_stimulus_retry(c, s, a):
+ def task():
+ assert dask.config.get("foo")
+
+ with dask.config.set(foo=False):
+ f = c.submit(task)
+ with pytest.raises(AssertionError):
+ await f
+
+ with dask.config.set(foo=True):
+ await f.retry()
+ await f
+
+ story = s.story(f.key)
+ stimulus_ids = {s[-2] for s in story}
+ # Four events
+ # - Compute
+ # - Erred
+ # - Compute / Retry
+ # - Success
+ assert len(stimulus_ids) == 4
+
+ assert_story(
+ story,
+ [
+ # Erred transitions via released
+ (f.key, "processing", "erred", {}),
+ (f.key, "erred", "released", {}),
+ (f.key, "released", "waiting", {f.key: "processing"}),
+ ],
+ )
+
+
+@gen_cluster(client=True, nthreads=[("", 1)])
+async def test_client_story(c, s, a):
+ f = c.submit(inc, 1)
+ assert await f == 2
+ story = await c.story(f.key)
+
+ # Every event should be prefixed with its origin
+ # This changes the format compared to default scheduler / worker stories
+ prefixes = set()
+ stripped_story = list()
+ for msg in story:
+ prefixes.add(msg[0])
+ stripped_story.append(msg[1:])
+ assert prefixes == {"scheduler", a.address}
+
+ assert_valid_story(stripped_story, ordered_timestamps=False)
+
+ # If it's a well-formed story, we can sort by the last element, which is a
+ # timestamp, and compare the two lists.
+ assert sorted(stripped_story, key=lambda msg: msg[-1]) == sorted(
+ s.story(f.key) + a.story(f.key), key=lambda msg: msg[-1]
+ )
+
+
+class WorkerBrokenStory(Worker):
+ async def get_story(self, *args, **kw):
+ raise CommClosedError
+
+
+@gen_cluster(client=True, Worker=WorkerBrokenStory)
+@pytest.mark.parametrize("on_error", ["ignore", "raise"])
+async def test_client_story_failed_worker(c, s, a, b, on_error):
+ f = c.submit(inc, 1)
+ coro = c.story(f.key, on_error=on_error)
+ await f
+
+ if on_error == "raise":
+ with pytest.raises(CommClosedError):
+ await coro
+ elif on_error == "ignore":
+ story = await coro
+ assert story
+ assert len(story) > 1
+ else:
+ raise ValueError(on_error)
+
+
+@gen_cluster(client=True)
+async def test_worker_story_with_deps(c, s, a, b):
+ """
+ Assert that the structure of the story does not change unintentionally and
+ expected subfields are actually filled
+ """
+ dep = c.submit(inc, 1, workers=[a.address], key="dep")
+ res = c.submit(inc, dep, workers=[b.address], key="res")
+ await res
+
+ story = a.story("res")
+ assert story == []
+ story = b.story("res")
+
+ # Story now includes randomized stimulus_ids and timestamps.
+ stimulus_ids = {ev[-2] for ev in story}
+ # Compute dep
+ # Success dep
+ # Compute res
+ assert len(stimulus_ids) == 3
+
+ # This is a simple transition log
+ expected = [
+ ("res", "compute-task"),
+ ("res", "released", "waiting", "waiting", {"dep": "fetch"}),
+ ("res", "waiting", "ready", "ready", {"res": "executing"}),
+ ("res", "ready", "executing", "executing", {}),
+ ("res", "put-in-memory"),
+ ("res", "executing", "memory", "memory", {}),
+ ]
+ assert_story(story, expected, strict=True)
+
+ story = b.story("dep")
+ stimulus_ids = {ev[-2] for ev in story}
+ assert len(stimulus_ids) == 2, stimulus_ids
+ expected = [
+ ("dep", "ensure-task-exists", "released"),
+ ("dep", "released", "fetch", "fetch", {}),
+ ("gather-dependencies", a.address, {"dep"}),
+ ("dep", "fetch", "flight", "flight", {}),
+ ("request-dep", a.address, {"dep"}),
+ ("receive-dep", a.address, {"dep"}),
+ ("dep", "put-in-memory"),
+ ("dep", "flight", "memory", "memory", {"res": "ready"}),
+ ]
+ assert_story(story, expected, strict=True)
diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py
index 3d4c7aee4d..c7770079b3 100755
--- a/distributed/tests/test_utils_test.py
+++ b/distributed/tests/test_utils_test.py
@@ -21,7 +21,7 @@
from distributed.utils_test import (
_LockedCommPool,
_UnhashableCallable,
- assert_worker_story,
+ assert_story,
check_process_leak,
cluster,
dump_cluster_state,
@@ -407,7 +407,7 @@ async def inner_test(c, s, a, b):
assert "workers" in state
-def test_assert_worker_story():
+def test_assert_story():
now = time()
story = [
("foo", "id1", now - 600),
@@ -415,38 +415,38 @@ def test_assert_worker_story():
("baz", {1: 2}, "id2", now),
]
# strict=False
- assert_worker_story(story, [("foo",), ("bar",), ("baz", {1: 2})])
- assert_worker_story(story, [])
- assert_worker_story(story, [("foo",)])
- assert_worker_story(story, [("foo",), ("bar",)])
- assert_worker_story(story, [("baz", lambda d: d[1] == 2)])
+ assert_story(story, [("foo",), ("bar",), ("baz", {1: 2})])
+ assert_story(story, [])
+ assert_story(story, [("foo",)])
+ assert_story(story, [("foo",), ("bar",)])
+ assert_story(story, [("baz", lambda d: d[1] == 2)])
with pytest.raises(AssertionError):
- assert_worker_story(story, [("foo", "nomatch")])
+ assert_story(story, [("foo", "nomatch")])
with pytest.raises(AssertionError):
- assert_worker_story(story, [("baz",)])
+ assert_story(story, [("baz",)])
with pytest.raises(AssertionError):
- assert_worker_story(story, [("baz", {1: 3})])
+ assert_story(story, [("baz", {1: 3})])
with pytest.raises(AssertionError):
- assert_worker_story(story, [("foo",), ("bar",), ("baz", "extra"), ("+1",)])
+ assert_story(story, [("foo",), ("bar",), ("baz", "extra"), ("+1",)])
with pytest.raises(AssertionError):
- assert_worker_story(story, [("baz", lambda d: d[1] == 3)])
+ assert_story(story, [("baz", lambda d: d[1] == 3)])
with pytest.raises(KeyError): # Faulty lambda
- assert_worker_story(story, [("baz", lambda d: d[2] == 1)])
- assert_worker_story([], [])
- assert_worker_story([("foo", "id1", now)], [("foo",)])
+ assert_story(story, [("baz", lambda d: d[2] == 1)])
+ assert_story([], [])
+ assert_story([("foo", "id1", now)], [("foo",)])
with pytest.raises(AssertionError):
- assert_worker_story([], [("foo",)])
+ assert_story([], [("foo",)])
# strict=True
- assert_worker_story([], [], strict=True)
- assert_worker_story([("foo", "id1", now)], [("foo",)])
- assert_worker_story(story, [("foo",), ("bar",), ("baz", {1: 2})], strict=True)
+ assert_story([], [], strict=True)
+ assert_story([("foo", "id1", now)], [("foo",)])
+ assert_story(story, [("foo",), ("bar",), ("baz", {1: 2})], strict=True)
with pytest.raises(AssertionError):
- assert_worker_story(story, [("foo",), ("bar",)], strict=True)
+ assert_story(story, [("foo",), ("bar",)], strict=True)
with pytest.raises(AssertionError):
- assert_worker_story(story, [("foo",), ("baz", {1: 2})], strict=True)
+ assert_story(story, [("foo",), ("baz", {1: 2})], strict=True)
with pytest.raises(AssertionError):
- assert_worker_story(story, [], strict=True)
+ assert_story(story, [], strict=True)
@pytest.mark.parametrize(
@@ -467,11 +467,29 @@ def test_assert_worker_story():
),
],
)
-def test_assert_worker_story_malformed_story(story_factory):
+def test_assert_story_malformed_story(story_factory):
# defer the calls to time() to when the test runs rather than collection
story = story_factory()
with pytest.raises(AssertionError, match="Malformed story event"):
- assert_worker_story(story, [])
+ assert_story(story, [])
+
+
+@pytest.mark.parametrize("strict", [True, False])
+@gen_cluster(client=True, nthreads=[("", 1)])
+async def test_assert_story_identity(c, s, a, strict):
+ f1 = c.submit(inc, 1, key="f1")
+ f2 = c.submit(inc, f1, key="f2")
+ assert await f2 == 3
+ scheduler_story = s.story(f2.key)
+ assert scheduler_story
+ worker_story = a.story(f2.key)
+ assert worker_story
+ assert_story(worker_story, worker_story, strict=strict)
+ assert_story(scheduler_story, scheduler_story, strict=strict)
+ with pytest.raises(AssertionError):
+ assert_story(scheduler_story, worker_story, strict=strict)
+ with pytest.raises(AssertionError):
+ assert_story(worker_story, scheduler_story, strict=strict)
@gen_cluster()
diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py
index 7f796c015b..8fff62dced 100644
--- a/distributed/tests/test_worker.py
+++ b/distributed/tests/test_worker.py
@@ -48,7 +48,7 @@
from distributed.utils_test import (
TaskStateMetadataPlugin,
_LockedCommPool,
- assert_worker_story,
+ assert_story,
captured_logger,
dec,
div,
@@ -1410,7 +1410,7 @@ def assert_amm_transfer_story(key: str, w_from: Worker, w_to: Worker) -> None:
"""Test that an in-memory key was transferred from worker w_from to worker w_to by
the Active Memory Manager and it was not recalculated on w_to
"""
- assert_worker_story(
+ assert_story(
w_to.story(key),
[
(key, "ensure-task-exists", "released"),
@@ -1733,50 +1733,6 @@ async def test_story(c, s, w):
assert w.story(ts) == w.story(ts.key)
-@gen_cluster(client=True)
-async def test_story_with_deps(c, s, a, b):
- """
- Assert that the structure of the story does not change unintentionally and
- expected subfields are actually filled
- """
- dep = c.submit(inc, 1, workers=[a.address], key="dep")
- res = c.submit(inc, dep, workers=[b.address], key="res")
- await res
-
- story = a.story("res")
- assert story == []
- story = b.story("res")
-
- # Story now includes randomized stimulus_ids and timestamps.
- stimulus_ids = {ev[-2] for ev in story}
- assert len(stimulus_ids) == 2, stimulus_ids
- # This is a simple transition log
- expected = [
- ("res", "compute-task"),
- ("res", "released", "waiting", "waiting", {"dep": "fetch"}),
- ("res", "waiting", "ready", "ready", {"res": "executing"}),
- ("res", "ready", "executing", "executing", {}),
- ("res", "put-in-memory"),
- ("res", "executing", "memory", "memory", {}),
- ]
- assert_worker_story(story, expected, strict=True)
-
- story = b.story("dep")
- stimulus_ids = {ev[-2] for ev in story}
- assert len(stimulus_ids) == 2, stimulus_ids
- expected = [
- ("dep", "ensure-task-exists", "released"),
- ("dep", "released", "fetch", "fetch", {}),
- ("gather-dependencies", a.address, {"dep"}),
- ("dep", "fetch", "flight", "flight", {}),
- ("request-dep", a.address, {"dep"}),
- ("receive-dep", a.address, {"dep"}),
- ("dep", "put-in-memory"),
- ("dep", "flight", "memory", "memory", {"res": "ready"}),
- ]
- assert_worker_story(story, expected, strict=True)
-
-
@gen_cluster(client=True, nthreads=[("", 1)])
async def test_stimulus_story(c, s, a):
class C:
@@ -2712,7 +2668,7 @@ async def test_gather_dep_exception_one_task_2(c, s, a, b):
while fut1.key not in b.tasks or b.tasks[fut1.key].state == "flight":
await asyncio.sleep(0)
- s.handle_missing_data(key="f1", errant_worker=a.address)
+ s.handle_missing_data(key="f1", errant_worker=a.address, stimulus_id="test")
await fut2
@@ -2757,7 +2713,7 @@ async def test_acquire_replicas_same_channel(c, s, a, b):
# same communication channel
for fut in (futA, futB):
- assert_worker_story(
+ assert_story(
b.story(fut.key),
[
("gather-dependencies", a.address, {fut.key}),
@@ -2816,7 +2772,7 @@ def __getstate__(self):
assert await y == 123
story = await c.run(lambda dask_worker: dask_worker.story("x"))
- assert_worker_story(
+ assert_story(
story[b],
[
("x", "ensure-task-exists", "released"),
@@ -2993,7 +2949,7 @@ async def test_who_has_consistent_remove_replicas(c, s, *workers):
await f2
- assert_worker_story(a.story(f1.key), [(f1.key, "missing-dep")])
+ assert_story(a.story(f1.key), [(f1.key, "missing-dep")])
assert a.tasks[f1.key].suspicious_count == 0
assert s.tasks[f1.key].suspicious == 0
@@ -3074,7 +3030,7 @@ async def test_missing_released_zombie_tasks_2(c, s, b):
while b.tasks:
await asyncio.sleep(0.01)
- assert_worker_story(
+ assert_story(
b.story(ts),
[("f1", "missing", "released", "released", {"f1": "forgotten"})],
)
@@ -3186,7 +3142,7 @@ async def test_task_flight_compute_oserror(c, s, a, b):
("f1", "put-in-memory"),
("f1", "executing", "memory", "memory", {}),
]
- assert_worker_story(sum_story, expected_sum_story, strict=True)
+ assert_story(sum_story, expected_sum_story, strict=True)
@gen_cluster(client=True, nthreads=[])
@@ -3406,7 +3362,9 @@ async def gather_dep(self, *args, **kwargs):
await in_gather_dep.wait()
- await s.remove_worker(address=x.address, safe=True, close=close_worker)
+ await s.remove_worker(
+ address=x.address, safe=True, close=close_worker, stimulus_id="test"
+ )
await _wait_for_state(fut2_key, b, intermediate_state)
diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py
index d8337ace8e..feb99ee995 100644
--- a/distributed/tests/test_worker_state_machine.py
+++ b/distributed/tests/test_worker_state_machine.py
@@ -115,15 +115,19 @@ def test_slots(cls):
def test_sendmsg_to_dict():
# Arbitrary sample class
- smsg = ReleaseWorkerDataMsg(key="x")
- assert smsg.to_dict() == {"op": "release-worker-data", "key": "x"}
+ smsg = ReleaseWorkerDataMsg(key="x", stimulus_id="test")
+ assert smsg.to_dict() == {
+ "op": "release-worker-data",
+ "key": "x",
+ "stimulus_id": "test",
+ }
def test_merge_recs_instructions():
x = TaskState("x")
y = TaskState("y")
- instr1 = RescheduleMsg(key="foo", worker="a")
- instr2 = RescheduleMsg(key="bar", worker="b")
+ instr1 = RescheduleMsg(key="foo", worker="a", stimulus_id="test")
+ instr2 = RescheduleMsg(key="bar", worker="b", stimulus_id="test")
assert merge_recs_instructions(
({x: "memory"}, [instr1]),
({y: "released"}, [instr2]),
diff --git a/distributed/utils_test.py b/distributed/utils_test.py
index b7b79fd8c8..836a7fb179 100644
--- a/distributed/utils_test.py
+++ b/distributed/utils_test.py
@@ -1891,17 +1891,72 @@ def xfail_ssl_issue5601():
raise
-def assert_worker_story(
- story: list[tuple], expect: list[tuple], *, strict: bool = False
+def assert_valid_story(story, ordered_timestamps=True):
+ """Test that a story is well formed.
+
+ Parameters
+ ----------
+ story: list[tuple]
+ Output of Worker.story
+ ordered_timestamps: bool, optional
+        If False, timestamps are not required to be monotonically increasing.
+ Useful for asserting stories composed from the scheduler and
+ multiple workers
+ """
+
+ now = time()
+ prev_ts = 0.0
+ for ev in story:
+ try:
+            assert isinstance(ev, tuple), "not a tuple"
+            assert len(ev) > 2, "too short"
+ assert isinstance(ev[-2], str) and ev[-2], "stimulus_id not a string"
+ assert isinstance(ev[-1], float), "Timestamp is not a float"
+ if ordered_timestamps:
+ assert prev_ts <= ev[-1], "Timestamps are not monotonically ascending"
+ # Timestamps are within the last hour. It's been observed that a
+ # timestamp generated in a Nanny process can be a few milliseconds
+ # in the future.
+            assert now - 3600 < ev[-1] <= now + 1, "Timestamp is too old or in the future"
+ prev_ts = ev[-1]
+ except AssertionError as err:
+ raise AssertionError(
+ f"Malformed story event: {ev}\nProblem: {err}.\nin story:\n{_format_story(story)}"
+ )
+
+
+def assert_story(
+ story: list[tuple],
+ expect: list[tuple],
+ *,
+ strict: bool = False,
+ ordered_timestamps: bool = True,
) -> None:
- """Test the output of ``Worker.story``
+ """Test the output of ``Worker.story`` or ``Scheduler.story``
+
+ Warning
+ =======
+
+ Tests with overly verbose stories introduce maintenance cost and should
+ therefore be used with caution. This should only be used for very specific
+ unit tests where the exact order of events is crucial and there is no other
+ practical way to assert or observe what happened.
+ A typical use case involves testing for race conditions where subtle changes
+ of event ordering would cause harm.
Parameters
==========
story: list[tuple]
Output of Worker.story
expect: list[tuple]
-        Expected events. Each expected event must contain exactly 2 less fields than the
-        story (the last two fields are always the stimulus_id and the timestamp).
+        Expected events.
+        Each expected event must either match a story event exactly or omit the
+        trailing stimulus_id and timestamp, e.g.
+        `("log", "entry", "stim-id-9876", 1234.0)`
+        is matched by
+        `("log", "entry")`
+
Elements of the expect tuples can be
@@ -1922,24 +1977,17 @@ def assert_worker_story(
If True, the story must contain exactly as many events as expect.
If False (the default), the story may contain more events than expect; extra
events are ignored.
+ ordered_timestamps: bool, optional
+        If False, timestamps are not required to be monotonically increasing.
+ Useful for asserting stories composed from the scheduler and
+ multiple workers
"""
- now = time()
- prev_ts = 0.0
- for ev in story:
- try:
- assert len(ev) > 2
- assert isinstance(ev, tuple)
- assert isinstance(ev[-2], str) and ev[-2] # stimulus_id
- assert isinstance(ev[-1], float) # timestamp
- assert prev_ts <= ev[-1] # Timestamps are monotonic ascending
- # Timestamps are within the last hour. It's been observed that a timestamp
- # generated in a Nanny process can be a few milliseconds in the future.
- assert now - 3600 < ev[-1] <= now + 1
- prev_ts = ev[-1]
- except AssertionError:
- raise AssertionError(
- f"Malformed story event: {ev}\nin story:\n{_format_story(story)}"
- )
+ assert_valid_story(story, ordered_timestamps=ordered_timestamps)
+
+ def _valid_event(event, ev_expect):
+ return len(event) == len(ev_expect) and all(
+ ex(ev) if callable(ex) else ev == ex for ev, ex in zip(event, ev_expect)
+ )
try:
if strict and len(story) != len(expect):
@@ -1948,16 +1996,16 @@ def assert_worker_story(
for ev_expect in expect:
while True:
event = next(story_it)
- # Ignore (stimulus_id, timestamp)
- event = event[:-2]
- if len(event) == len(ev_expect) and all(
- ex(ev) if callable(ex) else ev == ex
- for ev, ex in zip(event, ev_expect)
+
+ if (
+ _valid_event(event, ev_expect)
+ # Ignore (stimulus_id, timestamp)
+ or _valid_event(event[:-2], ev_expect)
):
break
except StopIteration:
raise AssertionError(
- f"assert_worker_story({strict=}) failed\n"
+ f"assert_story({strict=}) failed\n"
f"story:\n{_format_story(story)}\n"
f"expect:\n{_format_story(expect)}"
) from None
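
As a usage sketch (not part of the diff): once this branch is installed, a test can validate a story with the same helper whether it came from a worker or the scheduler, spelling out the trailing stimulus_id/timestamp only when it matters. The story literal below is illustrative rather than captured from a real run.

    from time import time

    from distributed.utils_test import assert_story

    story = [
        ("x", "compute-task", "compute-task-123", time()),
        ("x", "ready", "executing", "executing", {}, "compute-task-123", time()),
    ]
    # Expected events may omit the trailing (stimulus_id, timestamp) pair entirely...
    assert_story(
        story,
        [("x", "compute-task"), ("x", "ready", "executing", "executing", {})],
    )
    # ...or spell them out, including callables for fields that are not known up front.
    assert_story(
        story,
        [
            (
                "x",
                "compute-task",
                lambda sid: sid.startswith("compute-task"),
                lambda ts: isinstance(ts, float),
            )
        ],
    )
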
diff --git a/distributed/worker.py b/distributed/worker.py
index 4bc2fab5c6..3e50a679bc 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -48,6 +48,7 @@
)
from distributed import comm, preloading, profile, utils
+from distributed._stories import worker_story
from distributed.batched import BatchedSend
from distributed.comm import connect, get_address_host
from distributed.comm.addressing import address_from_user_args, parse_address
@@ -74,7 +75,6 @@
from distributed.security import Security
from distributed.shuffle import ShuffleWorkerExtension
from distributed.sizeof import safe_sizeof as sizeof
-from distributed.stories import worker_story
from distributed.threadpoolexecutor import ThreadPoolExecutor
from distributed.threadpoolexecutor import secede as tpe_secede
from distributed.utils import (
@@ -754,6 +754,7 @@ def __init__(
"benchmark_disk": self.benchmark_disk,
"benchmark_memory": self.benchmark_memory,
"benchmark_network": self.benchmark_network,
+ "get_story": self.get_story,
}
stream_handlers = {
@@ -927,21 +928,26 @@ def status(self, value):
"""
prev_status = self.status
ServerNode.status.__set__(self, value)
- self._send_worker_status_change()
+ stimulus_id = f"worker-status-change-{time()}"
+ self._send_worker_status_change(stimulus_id)
if prev_status == Status.paused and value == Status.running:
- self.handle_stimulus(UnpauseEvent(stimulus_id=f"set-status-{time()}"))
+ self.handle_stimulus(UnpauseEvent(stimulus_id=stimulus_id))
- def _send_worker_status_change(self) -> None:
+ def _send_worker_status_change(self, stimulus_id: str) -> None:
if (
self.batched_stream
and self.batched_stream.comm
and not self.batched_stream.comm.closed()
):
self.batched_stream.send(
- {"op": "worker-status-change", "status": self._status.name}
+ {
+ "op": "worker-status-change",
+ "status": self._status.name,
+ "stimulus_id": stimulus_id,
+ },
)
elif self._status != Status.closed:
- self.loop.call_later(0.05, self._send_worker_status_change)
+ self.loop.call_later(0.05, self._send_worker_status_change, stimulus_id)
async def get_metrics(self) -> dict:
try:
@@ -1083,6 +1089,7 @@ async def _register_with_scheduler(self):
versions=get_versions(),
metrics=await self.get_metrics(),
extra=await self.get_startup_information(),
+ stimulus_id=f"worker-connect-{time()}",
),
serializers=["msgpack"],
)
@@ -1479,7 +1486,11 @@ async def close(
with suppress(EnvironmentError, TimeoutError):
if report and self.contact_address is not None:
await asyncio.wait_for(
- self.scheduler.unregister(address=self.contact_address, safe=safe),
+ self.scheduler.unregister(
+ address=self.contact_address,
+ safe=safe,
+ stimulus_id=f"worker-close-{time()}",
+ ),
timeout,
)
await self.scheduler.close_rpc()
@@ -1543,7 +1554,10 @@ async def close_gracefully(self, restart=None):
# Scheduler.retire_workers will set the status to closing_gracefully and push it
# back to this worker.
await self.scheduler.retire_workers(
- workers=[self.address], close_workers=False, remove=False
+ workers=[self.address],
+ close_workers=False,
+ remove=False,
+ stimulus_id=f"worker-close-gracefully-{time()}",
)
await self.close(safe=True, nanny=not restart)
@@ -1894,7 +1908,9 @@ def handle_compute_task(
pass
elif ts.state == "memory":
recommendations[ts] = "memory"
- instructions.append(self._get_task_finished_msg(ts))
+ instructions.append(
+ self._get_task_finished_msg(ts, stimulus_id=stimulus_id)
+ )
elif ts.state in {
"released",
"fetch",
@@ -2033,7 +2049,7 @@ def transition_memory_released(
recs, instructions = self.transition_generic_released(
ts, stimulus_id=stimulus_id
)
- instructions.append(ReleaseWorkerDataMsg(ts.key))
+ instructions.append(ReleaseWorkerDataMsg(key=ts.key, stimulus_id=stimulus_id))
return recs, instructions
def transition_waiting_constrained(
@@ -2056,7 +2072,7 @@ def transition_long_running_rescheduled(
self, ts: TaskState, *, stimulus_id: str
) -> RecsInstrs:
recs: Recs = {ts: "released"}
- smsg = RescheduleMsg(key=ts.key, worker=self.address)
+ smsg = RescheduleMsg(key=ts.key, worker=self.address, stimulus_id=stimulus_id)
return recs, [smsg]
def transition_executing_rescheduled(
@@ -2067,7 +2083,14 @@ def transition_executing_rescheduled(
self._executing.discard(ts)
return merge_recs_instructions(
- ({ts: "released"}, [RescheduleMsg(key=ts.key, worker=self.address)]),
+ (
+ {ts: "released"},
+ [
+ RescheduleMsg(
+ key=ts.key, worker=self.address, stimulus_id=stimulus_id
+ )
+ ],
+ ),
self._ensure_computing(),
)
@@ -2145,6 +2168,7 @@ def transition_generic_error(
traceback_text=traceback_text,
thread=self.threads.get(ts.key),
startstops=ts.startstops,
+ stimulus_id=stimulus_id,
)
return {}, [smsg]
@@ -2334,7 +2358,9 @@ def transition_generic_memory(
else:
if self.validate:
assert ts.key in self.data or ts.key in self.actors
- instructions.append(self._get_task_finished_msg(ts))
+ instructions.append(
+ self._get_task_finished_msg(ts, stimulus_id=stimulus_id)
+ )
return recs, instructions
@@ -2453,7 +2479,10 @@ def transition_executing_long_running(
self.long_running.add(ts.key)
return merge_recs_instructions(
- ({}, [LongRunningMsg(key=ts.key, compute_duration=compute_duration)]),
+ (
+ {},
+ [LongRunningMsg(key=ts.key, compute_duration=compute_duration)],
+ ),
self._ensure_computing(),
)
@@ -2700,6 +2729,9 @@ def story(self, *keys_or_tasks: str | TaskState) -> list[tuple]:
keys = {e.key if isinstance(e, TaskState) else e for e in keys_or_tasks}
return worker_story(keys, self.log)
+    async def get_story(self, keys=()):
+ return self.story(*keys)
+
def stimulus_story(
self, *keys_or_tasks: str | TaskState
) -> list[StateMachineEvent]:
@@ -2768,7 +2800,9 @@ def ensure_communicating(self) -> None:
for el in skipped_worker_in_flight:
self.data_needed.push(el)
- def _get_task_finished_msg(self, ts: TaskState) -> TaskFinishedMsg:
+ def _get_task_finished_msg(
+ self, ts: TaskState, stimulus_id: str
+ ) -> TaskFinishedMsg:
if ts.key not in self.data and ts.key not in self.actors:
raise RuntimeError(f"Task {ts} not ready")
typ = ts.type
@@ -2794,6 +2828,7 @@ def _get_task_finished_msg(self, ts: TaskState) -> TaskFinishedMsg:
metadata=ts.metadata,
thread=self.threads.get(ts.key),
startstops=ts.startstops,
+ stimulus_id=stimulus_id,
)
def _put_key_in_memory(self, ts: TaskState, value, *, stimulus_id: str) -> Recs:
@@ -3101,7 +3136,12 @@ async def gather_dep(
self.has_what[worker].discard(ts.key)
self.log.append((d, "missing-dep", stimulus_id, time()))
self.batched_stream.send(
- {"op": "missing-data", "errant_worker": worker, "key": d}
+ {
+ "op": "missing-data",
+ "errant_worker": worker,
+ "key": d,
+ "stimulus_id": stimulus_id,
+ }
)
if ts.who_has:
recommendations[ts] = "fetch"
@@ -3207,7 +3247,7 @@ def handle_steal_request(self, key: str, stimulus_id: str) -> None:
# `transition_constrained_executing`
self.transition(ts, "released", stimulus_id=stimulus_id)
- def handle_worker_status_change(self, status: str) -> None:
+ def handle_worker_status_change(self, status: str, stimulus_id: str) -> None:
new_status = Status.lookup[status] # type: ignore
if (
@@ -3218,7 +3258,7 @@ def handle_worker_status_change(self, status: str) -> None:
"Invalid Worker.status transition: %s -> %s", self._status, new_status
)
# Reiterate the current status to the scheduler to restore sync
- self._send_worker_status_change()
+ self._send_worker_status_change(stimulus_id)
else:
# Update status and send confirmation to the Scheduler (see status.setter)
self.status = new_status
@@ -3572,7 +3612,7 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | No
stop=result["stop"],
nbytes=result["nbytes"],
type=result["type"],
- stimulus_id=stimulus_id,
+ stimulus_id=f"task-finished-{time()}",
)
if isinstance(result["actual-exception"], Reschedule):
@@ -3599,7 +3639,7 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | No
traceback=result["traceback"],
exception_text=result["exception_text"],
traceback_text=result["traceback_text"],
- stimulus_id=stimulus_id,
+ stimulus_id=f"task-erred-{time()}",
)
except Exception as exc:
@@ -3613,7 +3653,7 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | No
traceback=msg["traceback"],
exception_text=msg["exception_text"],
traceback_text=msg["traceback_text"],
- stimulus_id=stimulus_id,
+ stimulus_id=f"task-erred-{time()}",
)
@functools.singledispatchmethod
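
For orientation (not part of the diff): ``Worker.story`` and the new ``get_story`` handler both come down to filtering ``Worker.log`` by key. The helper below is a rough approximation of ``distributed._stories.worker_story``, not its exact code, and the sample log is made up.

    def worker_story_sketch(keys: set, log: list) -> list:
        # Keep every logged event that mentions one of the keys, either as a
        # top-level field or inside a nested collection (e.g. the {"dep"} set
        # of a "gather-dependencies" event).
        return [
            msg
            for msg in log
            if any(key in msg for key in keys)
            or any(
                key in field
                for field in msg
                if isinstance(field, (tuple, list, set, dict))
                for key in keys
            )
        ]

    log = [
        ("x", "compute-task", "stim-1", 1.1),
        ("gather-dependencies", "tcp://w1:1234", {"y"}, "stim-2", 2.2),
    ]
    assert worker_story_sketch({"y"}, log) == [log[1]]
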
diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py
index 3e74eb8cc8..abdbc12108 100644
--- a/distributed/worker_state_machine.py
+++ b/distributed/worker_state_machine.py
@@ -303,6 +303,7 @@ class TaskFinishedMsg(SendMessageToScheduler):
metadata: dict
thread: int | None
startstops: list[StartStop]
+ stimulus_id: str
__slots__ = tuple(__annotations__) # type: ignore
def to_dict(self) -> dict[str, Any]:
@@ -322,6 +323,7 @@ class TaskErredMsg(SendMessageToScheduler):
traceback_text: str
thread: int | None
startstops: list[StartStop]
+ stimulus_id: str
__slots__ = tuple(__annotations__) # type: ignore
def to_dict(self) -> dict[str, Any]:
@@ -334,8 +336,9 @@ def to_dict(self) -> dict[str, Any]:
class ReleaseWorkerDataMsg(SendMessageToScheduler):
op = "release-worker-data"
- __slots__ = ("key",)
+ __slots__ = ("key", "stimulus_id")
key: str
+ stimulus_id: str
# Not to be confused with RescheduleEvent below or the distributed.Reschedule Exception
@@ -343,9 +346,10 @@ class ReleaseWorkerDataMsg(SendMessageToScheduler):
class RescheduleMsg(SendMessageToScheduler):
op = "reschedule"
- __slots__ = ("key", "worker")
+ __slots__ = ("key", "worker", "stimulus_id")
key: str
worker: str
+ stimulus_id: str
@dataclass
@@ -438,6 +442,7 @@ class ExecuteSuccessEvent(StateMachineEvent):
stop: float
nbytes: int
type: type | None
+ stimulus_id: str
__slots__ = tuple(__annotations__) # type: ignore
def to_loggable(self, *, handled: float) -> StateMachineEvent:
@@ -460,6 +465,7 @@ class ExecuteFailureEvent(StateMachineEvent):
traceback: Serialize | None
exception_text: str
traceback_text: str
+ stimulus_id: str
__slots__ = tuple(__annotations__) # type: ignore
def _after_from_dict(self) -> None:
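
To make the shape of these messages concrete, here is a standalone sketch (not the real class, which derives from ``SendMessageToScheduler`` and uses ``__slots__``) of what ``ReleaseWorkerDataMsg`` now carries and serialises to:

    from dataclasses import dataclass

    @dataclass
    class ReleaseWorkerDataMsgSketch:
        key: str
        stimulus_id: str

        # Class attribute mirroring the real message's op name
        op = "release-worker-data"

        def to_dict(self) -> dict:
            # Every scheduler-bound message now ships the stimulus_id that
            # triggered it alongside its payload.
            return {"op": self.op, "key": self.key, "stimulus_id": self.stimulus_id}

    msg = ReleaseWorkerDataMsgSketch(key="x", stimulus_id="test")
    assert msg.to_dict() == {"op": "release-worker-data", "key": "x", "stimulus_id": "test"}

``RescheduleMsg`` follows the same pattern with an extra ``worker`` field.
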