Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Forget erred tasks // Fix deadlocks on worker #4784

Merged
merged 11 commits into from
Jun 11, 2021
1 change: 1 addition & 0 deletions distributed/diagnostics/tests/test_worker_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def failing(x):
{"key": "task", "start": "waiting", "finish": "ready"},
{"key": "task", "start": "ready", "finish": "executing"},
{"key": "task", "start": "executing", "finish": "error"},
{"key": "task", "state": "error"},
]

plugin = MyPlugin(1, expected_notifications=expected_notifications)
Expand Down
64 changes: 50 additions & 14 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1205,6 +1205,10 @@ class TaskState:
failed task is stored here (possibly itself). Otherwise this
is ``None``.

.. attribute:: erred_on: set(str)

Worker addresses on which errors appeared, causing this task to be in an error state.

.. attribute:: suspicious: int

The number of times this task has been involved in a worker death.
Expand Down Expand Up @@ -1293,6 +1297,7 @@ class TaskState:
_exception: object
_traceback: object
_exception_blame: object
_erred_on: set
_suspicious: Py_ssize_t
_host_restrictions: set
_worker_restrictions: set
Expand Down Expand Up @@ -1343,6 +1348,7 @@ class TaskState:
"_who_wants",
"_exception",
"_traceback",
"_erred_on",
"_exception_blame",
"_suspicious",
"_retries",
Expand Down Expand Up @@ -1381,6 +1387,7 @@ def __init__(self, key: str, run_spec: object):
self._group = None
self._metadata = {}
self._annotations = {}
self._erred_on = set()

def __hash__(self):
return self._hash
Expand Down Expand Up @@ -1528,6 +1535,10 @@ def group_key(self):
def prefix_key(self):
return self._prefix._name

@property
def erred_on(self):
return self._erred_on
Comment on lines +1538 to +1540
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One of the reasons why the erred task state is never forgotten is that keys are only "remembered" in processing_on (processing) or who_has (memory), depending on the state of the task. This would be the equivalent for erred tasks to allow us to tell the worker to forget the task.

@jakirkham you have been involved in the scheduler state machine a lot recently. Just pinging in case you have thoughts about adding more state here or if you see other options


@ccall
def add_dependency(self, other: "TaskState"):
"""Add another task as a dependency of this task"""
Expand Down Expand Up @@ -1842,7 +1853,6 @@ def __init__(
("no-worker", "waiting"): self.transition_no_worker_waiting,
("released", "forgotten"): self.transition_released_forgotten,
("memory", "forgotten"): self.transition_memory_forgotten,
("erred", "forgotten"): self.transition_released_forgotten,
("erred", "released"): self.transition_erred_released,
("memory", "released"): self.transition_memory_released,
("released", "erred"): self.transition_released_erred,
Expand Down Expand Up @@ -2629,9 +2639,9 @@ def transition_memory_released(self, key, safe: bint = False):
# XXX factor this out?
ts_nbytes: Py_ssize_t = ts.get_nbytes()
worker_msg = {
"op": "delete-data",
"op": "free-keys",
"keys": [key],
"report": False,
"reason": f"Memory->Released {key}",
}
for ws in ts._who_has:
del ws._has_what[ts]
Expand Down Expand Up @@ -2722,7 +2732,6 @@ def transition_erred_released(self, key):

if self._validate:
with log_errors(pdb=LOG_PDB):
assert all([dts._state != "erred" for dts in ts._dependencies])
assert ts._exception_blame
assert not ts._who_has
assert not ts._waiting_on
Expand All @@ -2736,6 +2745,11 @@ def transition_erred_released(self, key):
if dts._state == "erred":
recommendations[dts._key] = "waiting"

w_msg = {"op": "free-keys", "keys": [key], "reason": "Erred->Released"}
for w in ts._erred_on:
worker_msgs[w] = [w_msg]
fjetter marked this conversation as resolved.
Show resolved Hide resolved
ts._erred_on.clear()

report_msg = {"op": "task-retried", "key": key}
cs: ClientState
for cs in ts._who_wants:
Expand Down Expand Up @@ -2805,7 +2819,9 @@ def transition_processing_released(self, key):

w: str = _remove_from_processing(self, ts)
if w:
worker_msgs[w] = [{"op": "release-task", "key": key}]
worker_msgs[w] = [
{"op": "free-keys", "keys": [key], "reason": "Processing->Released"}
]

ts.state = "released"

Expand Down Expand Up @@ -2835,7 +2851,7 @@ def transition_processing_released(self, key):
raise

def transition_processing_erred(
self, key, cause=None, exception=None, traceback=None, **kwargs
self, key, cause=None, exception=None, traceback=None, worker=None, **kwargs
):
ws: WorkerState
try:
Expand All @@ -2856,8 +2872,9 @@ def transition_processing_erred(
ws = ts._processing_on
ws._actors.remove(ts)

_remove_from_processing(self, ts)
w = _remove_from_processing(self, ts)

ts._erred_on.add(w or worker)
if exception is not None:
ts._exception = exception
if traceback is not None:
Expand Down Expand Up @@ -4456,7 +4473,9 @@ def stimulus_task_finished(self, key=None, worker=None, **kwargs):
ts._who_has,
)
if ws not in ts._who_has:
worker_msgs[worker] = [{"op": "release-task", "key": key}]
worker_msgs[worker] = [
{"op": "free-keys", "keys": [key], "reason": "Stimulus Finished"}
]

return recommendations, client_msgs, worker_msgs

Expand Down Expand Up @@ -5113,7 +5132,7 @@ def handle_missing_data(self, key=None, errant_worker=None, **kwargs):
def release_worker_data(self, comm=None, keys=None, worker=None):
parent: SchedulerState = cast(SchedulerState, self)
ws: WorkerState = parent._workers_dv[worker]
tasks: set = {parent._tasks[k] for k in keys}
tasks: set = {parent._tasks[k] for k in keys if k in parent._tasks}
removed_tasks: set = tasks.intersection(ws._has_what)

ts: TaskState
Expand Down Expand Up @@ -5519,8 +5538,11 @@ async def _delete_worker_data(self, worker_address, keys):
List of keys to delete on the specified worker
"""
parent: SchedulerState = cast(SchedulerState, self)

await retry_operation(
self.rpc(addr=worker_address).delete_data, keys=list(keys), report=False
self.rpc(addr=worker_address).free_keys,
keys=list(keys),
reason="rebalance/replicate",
)

ws: WorkerState = parent._workers_dv[worker_address]
Expand Down Expand Up @@ -6271,6 +6293,7 @@ def add_keys(self, comm=None, worker=None, keys=()):
if worker not in parent._workers_dv:
return "not found"
ws: WorkerState = parent._workers_dv[worker]
superfluous_data = []
for key in keys:
ts: TaskState = parent._tasks.get(key)
if ts is not None and ts._state == "memory":
Expand All @@ -6279,9 +6302,16 @@ def add_keys(self, comm=None, worker=None, keys=()):
ws._has_what[ts] = None
ts._who_has.add(ws)
else:
self.worker_send(
worker, {"op": "delete-data", "keys": [key], "report": False}
)
superfluous_data.append(key)
if superfluous_data:
self.worker_send(
worker,
{
"op": "superfluous-data",
"keys": superfluous_data,
"reason": f"Add keys which are not in-memory {superfluous_data}",
},
)

return "OK"

Expand Down Expand Up @@ -7308,7 +7338,13 @@ def _propagate_forgotten(
ws._nbytes -= ts_nbytes
w: str = ws._address
if w in state._workers_dv: # in case worker has died
worker_msgs[w] = [{"op": "delete-data", "keys": [key], "report": False}]
worker_msgs[w] = [
{
"op": "free-keys",
"keys": [key],
"reason": f"propagate-forgotten {ts.key}",
}
]
ts._who_has.clear()


Expand Down
48 changes: 38 additions & 10 deletions distributed/tests/test_failed_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,26 @@ async def test_restart_during_computation(c, s, a, b):
assert not s.tasks


@gen_cluster(client=True, timeout=60)
class SlowTransmitData:
    """Test payload whose pickling is artificially slow.

    Serializing an instance sleeps for ``delay`` seconds, simulating a
    worker that is slow to transmit data over the comm.
    """

    def __init__(self, data, delay=0.1):
        self.delay = delay
        self.data = data

    def __reduce__(self):
        import time

        # Simulate slow serialization/transmission.
        time.sleep(self.delay)
        # Reconstruct faithfully: the original returned ``(self.delay,)``,
        # which dropped the payload and misassigned the delay to ``data``.
        return (SlowTransmitData, (self.data, self.delay))

    def __sizeof__(self) -> int:
        # Report a size just above the offload threshold so that
        # (de)serialization is offloaded and does not block the event loop.
        import dask
        from dask.utils import parse_bytes

        return parse_bytes(dask.config.get("distributed.comm.offload")) + 1


@gen_cluster(client=True)
async def test_worker_who_has_clears_after_failed_connection(c, s, a, b):
n = await Nanny(s.address, nthreads=2, loop=s.loop)

Expand All @@ -393,23 +412,32 @@ async def test_worker_who_has_clears_after_failed_connection(c, s, a, b):
await asyncio.sleep(0.01)
assert time() < start + 5

futures = c.map(slowinc, range(20), delay=0.01, key=["f%d" % i for i in range(20)])
await wait(futures)

result = await c.submit(sum, futures, workers=a.address)
deps = [dep for dep in a.tasks.values() if dep.key not in a.data_needed]
for dep in deps:
a.release_key(dep.key, report=True)
def slow_ser(x, delay):
return SlowTransmitData(x, delay=delay)

n_worker_address = n.worker_address
futures = c.map(
slow_ser,
range(20),
delay=0.1,
key=["f%d" % i for i in range(20)],
workers=[n_worker_address],
allow_other_workers=True,
)

def sink(*args):
pass

await wait(futures)
result_fut = c.submit(sink, futures, workers=a.address)

with suppress(CommClosedError):
await c._run(os._exit, 1, workers=[n_worker_address])

while len(s.workers) > 2:
await asyncio.sleep(0.01)

total = c.submit(sum, futures, workers=a.address)
await total
await result_fut

assert not a.has_what.get(n_worker_address)
assert not any(n_worker_address in s for ts in a.tasks.values() for s in ts.who_has)
Expand Down
17 changes: 16 additions & 1 deletion distributed/tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1890,18 +1890,33 @@ async def test_task_groups(c, s, a, b):
assert tg.states["released"] == 5
assert tp.states["memory"] == 0
assert tp.states["released"] == 5
assert tp.groups == [tg]
assert tg.prefix is tp
assert tg in tp.groups
# these must be true since in this simple case there is a 1to1 mapping
# between prefix and group
assert tg.duration == tp.duration
assert tg.nbytes_in_memory == tp.nbytes_in_memory
assert tg.nbytes_total == tp.nbytes_total
# It should map down to individual tasks
assert tg.nbytes_total == sum(
[ts.get_nbytes() for ts in s.tasks.values() if ts.group is tg]
)
in_memory_ts = sum(
[
ts.get_nbytes()
for ts in s.tasks.values()
if ts.group is tg and ts.state == "memory"
]
)
assert tg.nbytes_in_memory == in_memory_ts

tg = s.task_groups[y.name]
assert tg.states["memory"] == 5

assert s.task_groups[y.name].dependencies == {s.task_groups[x.name]}

await c.replicate(y)
# TODO: Are we supposed to track replicated memory here? See also Scheduler.add_keys
assert tg.nbytes_in_memory == y.nbytes
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I ran into some issues with the nbytes counting and wondered whether this is correct. While the test states this explicitly, I'm a bit surprised, and I believe there are code areas where this is treated differently.

assert "array" in str(tg.types)
assert "array" in str(tp.types)
Expand Down
Loading