Skip to content

Commit

Permalink
handle resheduled tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter committed Aug 4, 2021
1 parent 51ca4d9 commit c3ceffd
Show file tree
Hide file tree
Showing 17 changed files with 1,780 additions and 910 deletions.
4 changes: 4 additions & 0 deletions distributed/cfexecutor.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,10 @@ def map(self, fn, *iterables, **kwargs):
raise TypeError("unexpected arguments to map(): %s" % sorted(kwargs))

fs = self._client.map(fn, *iterables, **self._kwargs)
if isinstance(fs, list):
# Below iterator relies on this being a generator to cancel
# remaining futures
fs = (val for val in fs)

# Yield must be hidden in closure so that the tasks are submitted
# before the first iterator value is required.
Expand Down
18 changes: 0 additions & 18 deletions distributed/diagnostics/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,24 +157,6 @@ def transition(self, key, start, finish, **kwargs):
kwargs : More options passed when transitioning
"""

def release_key(self, key, state, cause, reason, report):
"""
Called when the worker releases a task.
Parameters
----------
key : string
state : string
State of the released task.
One of waiting, ready, executing, long-running, memory, error.
cause : string or None
Additional information on what triggered the release of the task.
reason : None
Not used.
report : bool
Whether the worker should report the released task to the scheduler.
"""


class NannyPlugin:
"""Interface to extend the Nanny
Expand Down
75 changes: 62 additions & 13 deletions distributed/diagnostics/tests/test_worker_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,6 @@ def transition(self, key, start, finish, **kwargs):
{"key": key, "start": start, "finish": finish}
)

def release_key(self, key, state, cause, reason, report):
self.observed_notifications.append({"key": key, "state": state})


@gen_cluster(client=True, nthreads=[])
async def test_create_with_client(c, s):
Expand Down Expand Up @@ -107,11 +104,12 @@ async def test_create_on_construction(c, s, a, b):
@gen_cluster(nthreads=[("127.0.0.1", 1)], client=True)
async def test_normal_task_transitions_called(c, s, w):
expected_notifications = [
{"key": "task", "start": "new", "finish": "waiting"},
{"key": "task", "start": "released", "finish": "waiting"},
{"key": "task", "start": "waiting", "finish": "ready"},
{"key": "task", "start": "ready", "finish": "executing"},
{"key": "task", "start": "executing", "finish": "memory"},
{"key": "task", "state": "memory"},
{"key": "task", "start": "memory", "finish": "released"},
{"key": "task", "start": "released", "finish": "forgotten"},
]

plugin = MyPlugin(1, expected_notifications=expected_notifications)
Expand All @@ -127,11 +125,12 @@ def failing(x):
raise Exception()

expected_notifications = [
{"key": "task", "start": "new", "finish": "waiting"},
{"key": "task", "start": "released", "finish": "waiting"},
{"key": "task", "start": "waiting", "finish": "ready"},
{"key": "task", "start": "ready", "finish": "executing"},
{"key": "task", "start": "executing", "finish": "error"},
{"key": "task", "state": "error"},
{"key": "task", "start": "error", "finish": "released"},
{"key": "task", "start": "released", "finish": "forgotten"},
]

plugin = MyPlugin(1, expected_notifications=expected_notifications)
Expand All @@ -147,11 +146,12 @@ def failing(x):
)
async def test_superseding_task_transitions_called(c, s, w):
expected_notifications = [
{"key": "task", "start": "new", "finish": "waiting"},
{"key": "task", "start": "released", "finish": "waiting"},
{"key": "task", "start": "waiting", "finish": "constrained"},
{"key": "task", "start": "constrained", "finish": "executing"},
{"key": "task", "start": "executing", "finish": "memory"},
{"key": "task", "state": "memory"},
{"key": "task", "start": "memory", "finish": "released"},
{"key": "task", "start": "released", "finish": "forgotten"},
]

plugin = MyPlugin(1, expected_notifications=expected_notifications)
Expand All @@ -166,16 +166,18 @@ async def test_dependent_tasks(c, s, w):
dsk = {"dep": 1, "task": (inc, "dep")}

expected_notifications = [
{"key": "dep", "start": "new", "finish": "waiting"},
{"key": "dep", "start": "released", "finish": "waiting"},
{"key": "dep", "start": "waiting", "finish": "ready"},
{"key": "dep", "start": "ready", "finish": "executing"},
{"key": "dep", "start": "executing", "finish": "memory"},
{"key": "task", "start": "new", "finish": "waiting"},
{"key": "task", "start": "released", "finish": "waiting"},
{"key": "task", "start": "waiting", "finish": "ready"},
{"key": "task", "start": "ready", "finish": "executing"},
{"key": "task", "start": "executing", "finish": "memory"},
{"key": "dep", "state": "memory"},
{"key": "task", "state": "memory"},
{"key": "dep", "start": "memory", "finish": "released"},
{"key": "dep", "start": "released", "finish": "forgotten"},
{"key": "task", "start": "memory", "finish": "released"},
{"key": "task", "start": "released", "finish": "forgotten"},
]

plugin = MyPlugin(1, expected_notifications=expected_notifications)
Expand Down Expand Up @@ -218,3 +220,50 @@ class MyCustomPlugin(WorkerPlugin):
await c.register_worker_plugin(MyCustomPlugin())
assert len(w.plugins) == 1
assert next(iter(w.plugins)).startswith("MyCustomPlugin-")


def test_release_key_deprecated():
class ReleaseKeyDeprecated(WorkerPlugin):
def __init__(self):
self._called = False

def release_key(self, key, state, cause, reason, report):
# Ensure that the handler still works
self._called = True
assert state == "memory"
assert key == "task"

def teardown(self, worker):
assert self._called
return super().teardown(worker)

@gen_cluster(client=True, nthreads=[("", 1)])
async def test(c, s, a):

await c.register_worker_plugin(ReleaseKeyDeprecated())
fut = await c.submit(inc, 1, key="task")
assert fut == 2

with pytest.deprecated_call(
match="The `WorkerPlugin.release_key` hook is depreacted"
):
test()


def test_assert_no_warning_no_overload():
"""Assert we do not receive a deprecation warning if we do not overload any
methods
"""

class Dummy(WorkerPlugin):
pass

@gen_cluster(client=True, nthreads=[("", 1)])
async def test(c, s, a):

await c.register_worker_plugin(Dummy())
fut = await c.submit(inc, 1, key="task")
assert fut == 2

with pytest.warns(None):
test()
60 changes: 53 additions & 7 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1969,6 +1969,8 @@ def __init__(
("processing", "erred"): self.transition_processing_erred,
("no-worker", "released"): self.transition_no_worker_released,
("no-worker", "waiting"): self.transition_no_worker_waiting,
# TODO: Write a test. Worker disconnects -> no-worker -> reconnect with task to memory. Triggered every few hundred times by test_handle_superfluous_data
("no-worker", "memory"): self.transition_no_worker_memory,
("released", "forgotten"): self.transition_released_forgotten,
("memory", "forgotten"): self.transition_memory_forgotten,
("erred", "released"): self.transition_erred_released,
Expand Down Expand Up @@ -2215,7 +2217,7 @@ def _transition(self, key, finish: str, *args, **kwargs):
self._transition_counter += 1
recommendations, client_msgs, worker_msgs = a
elif "released" not in start_finish:
assert not args and not kwargs
assert not args and not kwargs, start_finish
a_recs: dict
a_cmsgs: dict
a_wmsgs: dict
Expand Down Expand Up @@ -2614,6 +2616,42 @@ def transition_waiting_processing(self, key):
pdb.set_trace()
raise

def transition_no_worker_memory(
self, key, nbytes=None, type=None, typename: str = None, worker=None, **kwargs
):
try:
ws: WorkerState = self._workers_dv[worker]
ts: TaskState = self._tasks[key]
recommendations: dict = {}
client_msgs: dict = {}
worker_msgs: dict = {}

if self._validate:
assert not ts._processing_on
assert not ts._waiting_on
assert ts._state == "no-worker"

self._unrunnable.remove(ts)

if nbytes is not None:
ts.set_nbytes(nbytes)

self.check_idle_saturated(ws)

_add_to_memory(
self, ts, ws, recommendations, client_msgs, type=type, typename=typename
)
ts.state = "memory"

return recommendations, client_msgs, worker_msgs
except Exception as e:
logger.exception(e)
if LOG_PDB:
import pdb

pdb.set_trace()
raise

def transition_waiting_memory(
self, key, nbytes=None, type=None, typename: str = None, worker=None, **kwargs
):
Expand Down Expand Up @@ -5353,6 +5391,8 @@ def handle_missing_data(self, key=None, errant_worker=None, **kwargs):

def release_worker_data(self, comm=None, keys=None, worker=None):
parent: SchedulerState = cast(SchedulerState, self)
if worker not in parent._workers_dv:
return
ws: WorkerState = parent._workers_dv[worker]
tasks: set = {parent._tasks[k] for k in keys if k in parent._tasks}
removed_tasks: set = tasks.intersection(ws._has_what)
Expand Down Expand Up @@ -6610,7 +6650,7 @@ def add_keys(self, comm=None, worker=None, keys=()):
if worker not in parent._workers_dv:
return "not found"
ws: WorkerState = parent._workers_dv[worker]
superfluous_data = []
redundant_replicas = []
for key in keys:
ts: TaskState = parent._tasks.get(key)
if ts is not None and ts._state == "memory":
Expand All @@ -6619,14 +6659,15 @@ def add_keys(self, comm=None, worker=None, keys=()):
ws._has_what[ts] = None
ts._who_has.add(ws)
else:
superfluous_data.append(key)
if superfluous_data:
redundant_replicas.append(key)

if redundant_replicas:
self.worker_send(
worker,
{
"op": "superfluous-data",
"keys": superfluous_data,
"reason": f"Add keys which are not in-memory {superfluous_data}",
"op": "remove-replicas",
"keys": redundant_replicas,
"stimulus_id": f"redundant-replicas-{time()}",
},
)

Expand Down Expand Up @@ -7728,12 +7769,15 @@ def _task_to_msg(state: SchedulerState, ts: TaskState, duration: double = -1) ->

if duration < 0:
duration = state.get_task_duration(ts)
import uuid

msg: dict = {
"op": "compute-task",
"key": ts._key,
"priority": ts._priority,
"duration": duration,
"stimulus_id": f"compute-task-{uuid.uuid4()}",
"who_has": {},
}
if ts._resource_restrictions:
msg["resource_restrictions"] = ts._resource_restrictions
Expand All @@ -7758,6 +7802,8 @@ def _task_to_msg(state: SchedulerState, ts: TaskState, duration: double = -1) ->

if ts._annotations:
msg["annotations"] = ts._annotations

assert "stimulus_id" in msg
return msg


Expand Down
12 changes: 10 additions & 2 deletions distributed/stealing.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,15 @@ async def move_task_confirm(self, key=None, worker=None, state=None):
return

# Victim had already started execution, reverse stealing
if state in ("memory", "executing", "long-running", None):
if state in (
"memory",
"executing",
"long-running",
"released",
"cancelled",
"resumed",
None,
):
self.log(("already-computing", key, victim.address, thief.address))
self.scheduler.check_idle_saturated(thief)
self.scheduler.check_idle_saturated(victim)
Expand All @@ -256,7 +264,7 @@ async def move_task_confirm(self, key=None, worker=None, state=None):
await self.scheduler.remove_worker(thief.address)
self.log(("confirm", key, victim.address, thief.address))
else:
raise ValueError("Unexpected task state: %s" % state)
raise ValueError(f"Unexpected task state: {ts}")
except Exception as e:
logger.exception(e)
if LOG_PDB:
Expand Down
Loading

0 comments on commit c3ceffd

Please sign in to comment.