Skip to content

Commit

Permalink
Add a test about expected task states in an exception case
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter committed May 25, 2021
1 parent 833c5f6 commit 05fb3df
Show file tree
Hide file tree
Showing 5 changed files with 377 additions and 14 deletions.
1 change: 1 addition & 0 deletions distributed/diagnostics/tests/test_worker_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def failing(x):
{"key": "task", "start": "waiting", "finish": "ready"},
{"key": "task", "start": "ready", "finish": "executing"},
{"key": "task", "start": "executing", "finish": "error"},
{"key": "task", "state": "error"},
]

plugin = MyPlugin(1, expected_notifications=expected_notifications)
Expand Down
23 changes: 19 additions & 4 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1209,6 +1209,10 @@ class TaskState:
failed task is stored here (possibly itself). Otherwise this
is ``None``.
.. attribute:: erred_on: set(str)
Worker addresses on which errors appeared causing this task to be in an error state.
.. attribute:: suspicious: int
The number of times this task has been involved in a worker death.
Expand Down Expand Up @@ -1297,6 +1301,7 @@ class TaskState:
_exception: object
_traceback: object
_exception_blame: object
_erred_on: set
_suspicious: Py_ssize_t
_host_restrictions: set
_worker_restrictions: set
Expand Down Expand Up @@ -1347,6 +1352,7 @@ class TaskState:
"_who_wants",
"_exception",
"_traceback",
"_erred_on",
"_exception_blame",
"_suspicious",
"_retries",
Expand Down Expand Up @@ -1385,6 +1391,7 @@ def __init__(self, key: str, run_spec: object):
self._group = None
self._metadata = {}
self._annotations = {}
self._erred_on = set()

def __hash__(self):
return self._hash
Expand Down Expand Up @@ -1532,6 +1539,10 @@ def group_key(self):
def prefix_key(self):
return self._prefix._name

@property
def erred_on(self):
return self._erred_on

@ccall
def add_dependency(self, other: "TaskState"):
"""Add another task as a dependency of this task"""
Expand Down Expand Up @@ -1838,7 +1849,6 @@ def __init__(
("no-worker", "waiting"): self.transition_no_worker_waiting,
("released", "forgotten"): self.transition_released_forgotten,
("memory", "forgotten"): self.transition_memory_forgotten,
("erred", "forgotten"): self.transition_released_forgotten,
("erred", "released"): self.transition_erred_released,
("memory", "released"): self.transition_memory_released,
("released", "erred"): self.transition_released_erred,
Expand Down Expand Up @@ -2696,7 +2706,6 @@ def transition_erred_released(self, key):

if self._validate:
with log_errors(pdb=LOG_PDB):
assert all([dts._state != "erred" for dts in ts._dependencies])
assert ts._exception_blame
assert not ts._who_has
assert not ts._waiting_on
Expand All @@ -2710,6 +2719,11 @@ def transition_erred_released(self, key):
if dts._state == "erred":
recommendations[dts._key] = "waiting"

w_msg = {"op": "release-task", "key": key}
for w in ts._erred_on:
worker_msgs[w] = [w_msg]
ts._erred_on.clear()

report_msg = {"op": "task-retried", "key": key}
cs: ClientState
for cs in ts._who_wants:
Expand Down Expand Up @@ -2809,7 +2823,7 @@ def transition_processing_released(self, key):
raise

def transition_processing_erred(
self, key, cause=None, exception=None, traceback=None, **kwargs
self, key, cause=None, exception=None, traceback=None, worker=None, **kwargs
):
ws: WorkerState
try:
Expand All @@ -2830,8 +2844,9 @@ def transition_processing_erred(
ws = ts._processing_on
ws._actors.remove(ts)

_remove_from_processing(self, ts)
w = _remove_from_processing(self, ts)

ts._erred_on.add(w or worker)
if exception is not None:
ts._exception = exception
if traceback is not None:
Expand Down
17 changes: 16 additions & 1 deletion distributed/tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1893,18 +1893,33 @@ async def test_task_groups(c, s, a, b):
assert tg.states["released"] == 5
assert tp.states["memory"] == 0
assert tp.states["released"] == 5
assert tp.groups == [tg]
assert tg.prefix is tp
assert tg in tp.groups
# these must be true since in this simple case there is a 1to1 mapping
# between prefix and group
assert tg.duration == tp.duration
assert tg.nbytes_in_memory == tp.nbytes_in_memory
assert tg.nbytes_total == tp.nbytes_total
# It should map down to individual tasks
assert tg.nbytes_total == sum(
[ts.get_nbytes() for ts in s.tasks.values() if ts.group is tg]
)
in_memory_ts = sum(
[
ts.get_nbytes()
for ts in s.tasks.values()
if ts.group is tg and ts.state == "memory"
]
)
assert tg.nbytes_in_memory == in_memory_ts

tg = s.task_groups[y.name]
assert tg.states["memory"] == 5

assert s.task_groups[y.name].dependencies == {s.task_groups[x.name]}

await c.replicate(y)
# TODO: Are we supposed to track repliacted memory here? See also Scheduelr.add_keys
assert tg.nbytes_in_memory == y.nbytes
assert "array" in str(tg.types)
assert "array" in str(tp.types)
Expand Down
Loading

0 comments on commit 05fb3df

Please sign in to comment.