diff --git a/distributed/diagnostics/tests/test_worker_plugin.py b/distributed/diagnostics/tests/test_worker_plugin.py index cc41bd79d09..86eb11e49ed 100644 --- a/distributed/diagnostics/tests/test_worker_plugin.py +++ b/distributed/diagnostics/tests/test_worker_plugin.py @@ -131,6 +131,7 @@ def failing(x): {"key": "task", "start": "waiting", "finish": "ready"}, {"key": "task", "start": "ready", "finish": "executing"}, {"key": "task", "start": "executing", "finish": "error"}, + {"key": "task", "state": "error"}, ] plugin = MyPlugin(1, expected_notifications=expected_notifications) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 7cfe8497e24..7e47df52879 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1209,6 +1209,10 @@ class TaskState: failed task is stored here (possibly itself). Otherwise this is ``None``. + .. attribute:: erred_on: set(str) + + Worker addresses on which errors appeared causing this task to be in an error state. + .. attribute:: suspicious: int The number of times this task has been involved in a worker death. @@ -1297,6 +1301,7 @@ class TaskState: _exception: object _traceback: object _exception_blame: object + _erred_on: set _suspicious: Py_ssize_t _host_restrictions: set _worker_restrictions: set @@ -1347,6 +1352,7 @@ class TaskState: "_who_wants", "_exception", "_traceback", + "_erred_on", "_exception_blame", "_suspicious", "_retries", @@ -1385,6 +1391,7 @@ def __init__(self, key: str, run_spec: object): self._group = None self._metadata = {} self._annotations = {} + self._erred_on = set() def __hash__(self): return self._hash @@ -1532,6 +1539,10 @@ def group_key(self): def prefix_key(self): return self._prefix._name + @property + def erred_on(self): + return self._erred_on + @ccall def add_dependency(self, other: "TaskState"): """Add another task as a dependency of this task""" @@ -1838,7 +1849,6 @@ def __init__( ("no-worker", "waiting"): self.transition_no_worker_waiting, ("released", "forgotten"): self.transition_released_forgotten, ("memory", "forgotten"): self.transition_memory_forgotten, - ("erred", "forgotten"): self.transition_released_forgotten, ("erred", "released"): self.transition_erred_released, ("memory", "released"): self.transition_memory_released, ("released", "erred"): self.transition_released_erred, @@ -2696,7 +2706,6 @@ def transition_erred_released(self, key): if self._validate: with log_errors(pdb=LOG_PDB): - assert all([dts._state != "erred" for dts in ts._dependencies]) assert ts._exception_blame assert not ts._who_has assert not ts._waiting_on @@ -2710,6 +2719,11 @@ def transition_erred_released(self, key): if dts._state == "erred": recommendations[dts._key] = "waiting" + w_msg = {"op": "release-task", "key": key} + for w in ts._erred_on: + worker_msgs[w] = [w_msg] + ts._erred_on.clear() + report_msg = {"op": "task-retried", "key": key} cs: ClientState for cs in ts._who_wants: @@ -2809,7 +2823,7 @@ def transition_processing_released(self, key): raise def transition_processing_erred( - self, key, cause=None, exception=None, traceback=None, **kwargs + self, key, cause=None, exception=None, traceback=None, worker=None, **kwargs ): ws: WorkerState try: @@ -2830,8 +2844,9 @@ def transition_processing_erred( ws = ts._processing_on ws._actors.remove(ts) - _remove_from_processing(self, ts) + w = _remove_from_processing(self, ts) + ts._erred_on.add(w or worker) if exception is not None: ts._exception = exception if traceback is not None: diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index ccd7ce31dd7..d5706334726 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -1893,11 +1893,25 @@ async def test_task_groups(c, s, a, b): assert tg.states["released"] == 5 assert tp.states["memory"] == 0 assert tp.states["released"] == 5 + assert tp.groups == [tg] assert tg.prefix is tp - assert tg in tp.groups + # these must be true since in this simple case there is a 1to1 mapping + # between prefix and group assert tg.duration == tp.duration assert tg.nbytes_in_memory == tp.nbytes_in_memory assert tg.nbytes_total == tp.nbytes_total + # It should map down to individual tasks + assert tg.nbytes_total == sum( + [ts.get_nbytes() for ts in s.tasks.values() if ts.group is tg] + ) + in_memory_ts = sum( + [ + ts.get_nbytes() + for ts in s.tasks.values() + if ts.group is tg and ts.state == "memory" + ] + ) + assert tg.nbytes_in_memory == in_memory_ts tg = s.task_groups[y.name] assert tg.states["memory"] == 5 @@ -1905,6 +1919,7 @@ async def test_task_groups(c, s, a, b): assert s.task_groups[y.name].dependencies == {s.task_groups[x.name]} await c.replicate(y) + # TODO: Are we supposed to track repliacted memory here? See also Scheduelr.add_keys assert tg.nbytes_in_memory == y.nbytes assert "array" in str(tg.types) assert "array" in str(tp.types) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index b1b5eb8b48a..87c54dd9658 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -1897,3 +1897,313 @@ async def test_gather_dep_one_worker_always_busy(c, s, a, b): # naturally while any(["Worker.gather_dep" in str(t) for t in asyncio.all_tasks()]): await asyncio.sleep(0.05) + + +def assert_task_states_on_worker(expected, worker): + for dep_key, expected_state in expected.items(): + assert dep_key in worker.tasks, (worker.name, dep_key, worker.tasks) + dep_ts = worker.tasks[dep_key] + assert dep_ts.state == expected_state, (worker.name, dep_ts, expected_state) + assert set(expected) == set(worker.tasks) + + +@gen_cluster(client=True) +async def test_worker_state_error_release_error_last(c, s, a, b): + """ + Create a chain of tasks and err one of them. Then release tasks in a certain + order and ensure the tasks are released and/or kept in memory as appropriate + + F -- RES (error) + / + / + G + + Free error last + """ + + def raise_exc(*args): + raise RuntimeError() + + f = c.submit(inc, 1, workers=[a.address], key="f") + g = c.submit(inc, 1, workers=[b.address], key="g") + res = c.submit(raise_exc, f, g, workers=[a.address]) + + with pytest.raises(RuntimeError): + await res.result() + + # Nothing bad happened on B, therefore B should hold on to G + assert len(b.tasks) == 1 + assert g.key in b.tasks + + # A raised the exception therefore we should hold on to the erroneous task + assert res.key in a.tasks + ts = a.tasks[res.key] + assert ts.state == "error" + + expected_states = { + # A was instructed to compute this result and we're still holding a ref via `f` + f.key: "memory", + # This was fetched from another worker. While we hold a ref via `g`, the + # scheduler only instructed to compute this on B + g.key: "memory", + res.key: "error", + } + assert_task_states_on_worker(expected_states, a) + # Expected states after we release references to the futures + f.release() + g.release() + + # We no longer hold any refs to f or g and B didn't have any erros. It + # releases everything as expected + while b.tasks: + await asyncio.sleep(0.01) + + expected_states = { + # We currently don't have a good way to actually release this memory as + # long as the tasks still have a dependent. We'll need to live with this + # memory for now + f.key: "memory", + g.key: "memory", + res.key: "error", + } + + assert_task_states_on_worker(expected_states, a) + + res.release() + + # We no longer hold any refs. Cluster should reset completely + # This is not happening + for server in [s, a, b]: + while server.tasks: + await asyncio.sleep(0.01) + + +@gen_cluster(client=True) +async def test_worker_state_error_release_error_first(c, s, a, b): + """ + Create a chain of tasks and err one of them. Then release tasks in a certain + order and ensure the tasks are released and/or kept in memory as appropriate + + F -- RES (error) + / + / + G + + Free error first + """ + + def raise_exc(*args): + raise RuntimeError() + + f = c.submit(inc, 1, workers=[a.address], key="f") + g = c.submit(inc, 1, workers=[b.address], key="g") + res = c.submit(raise_exc, f, g, workers=[a.address]) + + with pytest.raises(RuntimeError): + await res.result() + + # Nothing bad happened on B, therefore B should hold on to G + assert len(b.tasks) == 1 + assert g.key in b.tasks + + # A raised the exception therefore we should hold on to the erroneous task + assert res.key in a.tasks + ts = a.tasks[res.key] + assert ts.state == "error" + + expected_states = { + # A was instructed to compute this result and we're still holding a ref + # via `f` + f.key: "memory", + # This was fetched from another worker. While we hold a ref via `g`, the + # scheduler only instructed to compute this on B + g.key: "memory", + res.key: "error", + } + assert_task_states_on_worker(expected_states, a) + # Expected states after we release references to the futures + + res.release() + # We no longer hold any refs to f or g and B didn't have any erros. It + # releases everything as expected + while res.key in a.tasks: + await asyncio.sleep(0.01) + + expected_states = { + f.key: "memory", + } + + assert_task_states_on_worker(expected_states, a) + + f.release() + g.release() + + # We no longer hold any refs. Cluster should reset completely + # This is not happening + for server in [s, a, b]: + while server.tasks: + await asyncio.sleep(0.01) + + +@gen_cluster(client=True) +async def test_worker_state_error_release_error_int(c, s, a, b): + """ + Create a chain of tasks and err one of them. Then release tasks in a certain + order and ensure the tasks are released and/or kept in memory as appropriate + + F -- RES (error) + / + / + G + + Free one successful task, then error, then last task + """ + + def raise_exc(*args): + raise RuntimeError() + + f = c.submit(inc, 1, workers=[a.address], key="f") + g = c.submit(inc, 1, workers=[b.address], key="g") + res = c.submit(raise_exc, f, g, workers=[a.address]) + + with pytest.raises(RuntimeError): + await res.result() + + # Nothing bad happened on B, therefore B should hold on to G + assert len(b.tasks) == 1 + assert g.key in b.tasks + + # A raised the exception therefore we should hold on to the erroneous task + assert res.key in a.tasks + ts = a.tasks[res.key] + assert ts.state == "error" + + expected_states = { + # A was instructed to compute this result and we're still holding a ref via `f` + f.key: "memory", + # This was fetched from another worker. While we hold a ref via `g`, the + # scheduler only instructed to compute this on B + g.key: "memory", + res.key: "error", + } + assert_task_states_on_worker(expected_states, a) + # Expected states after we release references to the futures + + f.release() + res.release() + # We no longer hold any refs to f or g and B didn't have any erros. It + # releases everything as expected + while a.tasks: + await asyncio.sleep(0.01) + + expected_states = { + g.key: "memory", + } + + assert_task_states_on_worker(expected_states, b) + + g.release() + + # We no longer hold any refs. Cluster should reset completely + for server in [s, a, b]: + while server.tasks: + await asyncio.sleep(0.01) + + +@gen_cluster(client=True) +async def test_worker_state_error_long_chain(c, s, a, b): + def raise_exc(*args): + raise RuntimeError() + + # f (A) --------> res (B) + # / + # g (B) -> h (A) + + f = c.submit(inc, 1, workers=[a.address], key="f", allow_other_workers=False) + g = c.submit(inc, 1, workers=[b.address], key="g", allow_other_workers=False) + h = c.submit(inc, g, workers=[a.address], key="h", allow_other_workers=False) + res = c.submit( + raise_exc, f, h, workers=[b.address], allow_other_workers=False, key="res" + ) + + with pytest.raises(RuntimeError): + await res.result() + + expected_states_A = { + f.key: "memory", + g.key: "memory", + h.key: "memory", + } + await asyncio.sleep(0.05) + assert_task_states_on_worker(expected_states_A, a) + + expected_states_B = { + f.key: "memory", + g.key: "memory", + h.key: "memory", + res.key: "error", + } + await asyncio.sleep(0.05) + assert_task_states_on_worker(expected_states_B, b) + + f.release() + + expected_states_A = { + g.key: "memory", + h.key: "memory", + } + await asyncio.sleep(0.05) + assert_task_states_on_worker(expected_states_A, a) + + expected_states_B = { + f.key: "memory", + g.key: "memory", + h.key: "memory", + res.key: "error", + } + await asyncio.sleep(0.05) + assert_task_states_on_worker(expected_states_B, b) + + g.release() + + expected_states_A = { + h.key: "memory", + } + await asyncio.sleep(0.05) + assert_task_states_on_worker(expected_states_A, a) + + # B must not forget a task since all have a still valid dependent + expected_states_B = { + f.key: "memory", + # We actually cannot hold on to G even though the graph would suggest + # otherwise. This is because H was only introduced as a dependency and + # the scheduler never told the worker how H fits into the big picture. + # Therefore, it thinks that G does not have any dependents anymore and + # releases it. Too bad. Once we have speculative task assignments this + # should be more exact since we should always tell the worker what's + # going on + # g.key: released, + h.key: "memory", + res.key: "error", + } + assert_task_states_on_worker(expected_states_B, b) + h.release() + await asyncio.sleep(0.05) + + expected_states_A = {} + assert_task_states_on_worker(expected_states_A, a) + expected_states_B = { + f.key: "memory", + # See above + # g.key: released, + h.key: "memory", + res.key: "error", + } + + assert_task_states_on_worker(expected_states_B, b) + res.release() + + # We no longer hold any refs. Cluster should reset completely + for server in [s, a, b]: + while server.tasks: + await asyncio.sleep(0.01) diff --git a/distributed/worker.py b/distributed/worker.py index 494df279d75..1dfe4a9e2b4 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -13,7 +13,6 @@ from collections.abc import MutableMapping from contextlib import suppress from datetime import timedelta -from functools import partial from inspect import isawaitable from pickle import PicklingError from typing import Iterable @@ -188,6 +187,7 @@ def __init__(self, key, runspec=None): self.stop_time = None self.metadata = {} self.nbytes = None + self.scheduler_holds_ref = False def __repr__(self): return "" % (self.key, self.state) @@ -688,7 +688,7 @@ def __init__( stream_handlers = { "close": self.close, "compute-task": self.add_task, - "release-task": partial(self.release_key, report=False), + "release-task": self.release_task, "delete-data": self.delete_data, "steal-request": self.steal_request, } @@ -1431,6 +1431,7 @@ def update_data(self, comm=None, data=None, report=True, serializers=None): self.put_key_in_memory(ts, value) ts.priority = None ts.duration = None + ts.scheduler_holds_ref = True self.log.append((key, "receive-from-scatter")) @@ -1439,9 +1440,18 @@ def update_data(self, comm=None, data=None, report=True, serializers=None): info = {"nbytes": {k: sizeof(v) for k, v in data.items()}, "status": "OK"} return info + def release_task(self, key): + ts = self.tasks.get(key) + if ts: + ts.scheduler_holds_ref = False + self.release_key(key, report=False) + def delete_data(self, comm=None, keys=None, report=True): if keys: for key in list(keys): + ts = self.tasks.get(key) + if ts: + ts.scheduler_holds_ref = False self.log.append((key, "delete")) self.release_key(key, cause="delete data") @@ -1485,6 +1495,7 @@ def add_task( runspec = SerializedTask(function, args, kwargs, task) if key in self.tasks: ts = self.tasks[key] + ts.scheduler_holds_ref = True if ts.state == "memory": assert key in self.data or key in self.actors logger.debug( @@ -1508,7 +1519,6 @@ def add_task( key=key, runspec=SerializedTask(function, args, kwargs, task) ) self.transition(ts, "waiting") - # TODO: move transition of `ts` to end of `add_task` # This will require a chained recommendation transition system like # the scheduler @@ -1520,6 +1530,7 @@ def add_task( if actor: self.actors[ts.key] = None + ts.scheduler_holds_ref = True ts.runspec = runspec ts.priority = priority ts.duration = duration @@ -2456,12 +2467,23 @@ def release_key(self, key, cause=None, reason=None, report=True): # for any dependencies of key we are releasing remove task as dependent for dependency in ts.dependencies: dependency.dependents.discard(ts) - # don't boot keys that are in flight - # we don't know if they're already queued up for transit - # in a gather_dep callback - if not dependency.dependents and dependency.state in ( - "waiting", - "fetch", + + if ( + not dependency.dependents + and dependency.state + not in ( + # don't boot keys that are in flight + # we don't know if they're already queued up for transit + # in a gather_dep callback + "flight", + # The same is true for already executing keys. + "executing", + ) + # If the scheduler holds a reference which is usually the + # case when it instructed the task to be computed here or if + # data was scattered we must not release it unless the + # scheduler allow us to. See also handle_delete_data and + and not dependency.scheduler_holds_ref ): self.release_key(dependency.key, cause=f"Dependent {ts} released")