From 42690f0a63682bb637fe44cd2ec5e1898def9f14 Mon Sep 17 00:00:00 2001 From: fjetter Date: Fri, 4 Dec 2020 17:04:22 +0100 Subject: [PATCH] Fix deadlocks and infinite loops in worker --- distributed/core.py | 13 +- distributed/scheduler.py | 16 +- distributed/tests/test_client.py | 6 +- distributed/tests/test_failed_workers.py | 49 ++- distributed/tests/test_resources.py | 6 +- distributed/tests/test_stress.py | 2 +- distributed/worker.py | 465 ++++++++++++----------- 7 files changed, 312 insertions(+), 245 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index 1d98241bb1..f74f21fadd 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -623,7 +623,6 @@ async def send_recv(comm, reply=True, serializers=None, deserializers=None, **kw msg = kwargs msg["reply"] = reply please_close = kwargs.get("close") - force_close = False if deserializers is None: deserializers = serializers if deserializers is not None: @@ -635,15 +634,15 @@ async def send_recv(comm, reply=True, serializers=None, deserializers=None, **kw response = await comm.read(deserializers=deserializers) else: response = None - except EnvironmentError: - # On communication errors, we should simply close the communication - force_close = True - raise + except Exception as exc: + # If an exception occured we will need to close the comm, if possible. + # Otherwise the other end might wait for a reply while this end is + # reusing the comm for something else. + comm.abort() + raise exc finally: if please_close: await comm.close() - elif force_close: - comm.abort() if isinstance(response, dict) and response.get("status") == "uncaught-error": if comm.deserialize: diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 77a9be2f6d..2b3d90d40d 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -2738,7 +2738,14 @@ def stimulus_task_finished(self, key=None, worker=None, **kwargs): ts._who_has, ) if ws not in ts._who_has: - self.worker_send(worker, {"op": "release-task", "key": key}) + self.worker_send( + worker, + { + "op": "release-task", + "key": key, + "reason": "stimulus task finished", + }, + ) recommendations = {} return recommendations @@ -5223,7 +5230,12 @@ def transition_processing_released(self, key): assert self.tasks[key].state == "processing" self._remove_from_processing( - ts, send_worker_msg={"op": "release-task", "key": key} + ts, + send_worker_msg={ + "op": "release-task", + "key": key, + "reason": "transition released", + }, ) ts.state = "released" diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 724a738777..1870005064 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -1943,18 +1943,16 @@ def __setstate__(self, state): async def test_badly_serialized_input(c, s, a, b): o = BadlySerializedObject() - future = c.submit(inc, o) + future = c.submit(inc, o, key="broken") futures = c.map(inc, range(10)) L = await c.gather(futures) assert list(L) == list(map(inc, range(10))) assert future.status == "error" - with pytest.raises(Exception) as info: + with pytest.raises(TypeError, match="hello!"): await future - assert "hello!" in str(info.value) - @pytest.mark.skipif("True", reason="") async def test_badly_serialized_input_stderr(capsys, c): diff --git a/distributed/tests/test_failed_workers.py b/distributed/tests/test_failed_workers.py index 32addfa68b..1704005fc1 100644 --- a/distributed/tests/test_failed_workers.py +++ b/distributed/tests/test_failed_workers.py @@ -386,7 +386,7 @@ async def test_worker_who_has_clears_after_failed_connection(c, s, a, b): result = await c.submit(sum, futures, workers=a.address) deps = [dep for dep in a.tasks.values() if dep.key not in a.data_needed] for dep in deps: - a.release_key(dep.key, report=True) + a.release_key(dep.key, report=True, reason="test") n_worker_address = n.worker_address with suppress(CommClosedError): @@ -404,6 +404,21 @@ async def test_worker_who_has_clears_after_failed_connection(c, s, a, b): await n.close() +@gen_cluster(client=True) +async def test_worker_release_key_recovers(c, s, a, b): + futs = c.map(slowinc, range(2), delay=0.1) + + await wait(futs) + + assert len(a.tasks) > 0 + for ts in list(a.tasks.values()): + a.release_key(ts.key, reason="test") + + res = await c.submit(sum, futs) + + assert res == 3 + + @pytest.mark.slow @gen_cluster(client=True, timeout=60, Worker=Nanny, nthreads=[("127.0.0.1", 1)]) async def test_restart_timeout_on_long_running_task(c, s, a): @@ -416,7 +431,6 @@ async def test_restart_timeout_on_long_running_task(c, s, a): assert "timeout" not in text.lower() -@pytest.mark.slow @gen_cluster(client=True, scheduler_kwargs={"worker_ttl": "500ms"}) async def test_worker_time_to_live(c, s, a, b): from distributed.scheduler import heartbeat_interval @@ -434,4 +448,33 @@ async def test_worker_time_to_live(c, s, a, b): await asyncio.sleep(interval) assert time() < start + interval + 0.1 - set(s.workers) == {b.address} + assert set(s.workers) == {b.address} + + +@gen_cluster(client=True, Worker=Nanny) +async def test_get_data_faulty_dep(c, s, a, b): + """This test creates a broken dependency and forces serialization by + requiring it to be submitted to another worker. The computation should + eventually finish by flagging the dep as bad and raise an appropriate + exception. + """ + + class BrokenDeserialization: + def __setstate__(self, *state): + raise AttributeError() + + def __getstate__(self, *args): + return "" + + def create(): + return BrokenDeserialization() + + def collect(*args): + return args + + fut1 = c.submit(create, workers=[a.name]) + + fut2 = c.submit(collect, fut1, workers=[b.name]) + + with pytest.raises(RuntimeError, match="Could not find dependencies for collect-"): + await fut2.result() diff --git a/distributed/tests/test_resources.py b/distributed/tests/test_resources.py index ac7c06f07e..aaba48b99f 100644 --- a/distributed/tests/test_resources.py +++ b/distributed/tests/test_resources.py @@ -249,7 +249,11 @@ async def test_minimum_resource(c, s, a): assert a.total_resources == a.available_resources -@gen_cluster(client=True, nthreads=[("127.0.0.1", 2, {"resources": {"A": 1}})]) +@gen_cluster( + client=True, + nthreads=[("127.0.0.1", 2, {"resources": {"A": 1}})], + active_rpc_timeout=10, +) async def test_prefer_constrained(c, s, a): futures = c.map(slowinc, range(1000), delay=0.1) constrained = c.map(inc, range(10), resources={"A": 1}) diff --git a/distributed/tests/test_stress.py b/distributed/tests/test_stress.py index d699ac9452..755f5f97ef 100644 --- a/distributed/tests/test_stress.py +++ b/distributed/tests/test_stress.py @@ -79,7 +79,7 @@ async def test_cancel_stress(c, s, *workers): def test_cancel_stress_sync(loop): da = pytest.importorskip("dask.array") x = da.random.random((50, 50), chunks=(2, 2)) - with cluster(active_rpc_timeout=10) as (s, [a, b]): + with cluster(active_rpc_timeout=10, disconnect_timeout=10) as (s, [a, b]): with Client(s["address"], loop=loop) as c: x = c.persist(x) y = (x.sum(axis=0) + x.sum(axis=1) + 1).std() diff --git a/distributed/worker.py b/distributed/worker.py index 87ee7fe368..e871bcfcd7 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -77,7 +77,6 @@ no_value = "--no-value-sentinel--" -IN_PLAY = ("waiting", "ready", "executing", "long-running") PENDING = ("waiting", "ready", "constrained") PROCESSING = ("waiting", "ready", "constrained", "executing", "long-running") READY = ("ready", "constrained") @@ -151,10 +150,10 @@ class TaskState: """ - def __init__(self, key, runspec=None): + def __init__(self, key): assert key is not None self.key = key - self.runspec = runspec + self.runspec = None self.dependencies = set() self.dependents = set() self.duration = None @@ -405,7 +404,7 @@ def __init__( self.data_needed = deque() # TODO: replace with heap? self.in_flight_tasks = 0 - self.in_flight_workers = dict() + self.in_flight_workers = defaultdict(set) self.total_out_connections = dask.config.get( "distributed.worker.connections.outgoing" ) @@ -414,7 +413,6 @@ def __init__( ) self.total_comm_nbytes = 10e6 self.comm_nbytes = 0 - self._missing_dep_flight = set() self.threads = dict() @@ -450,6 +448,7 @@ def __init__( ("waiting", "flight"): self.transition_waiting_flight, ("ready", "executing"): self.transition_ready_executing, ("ready", "memory"): self.transition_ready_memory, + ("ready", "error"): self.transition_ready_error, ("constrained", "executing"): self.transition_constrained_executing, ("executing", "memory"): self.transition_executing_done, ("executing", "error"): self.transition_executing_done, @@ -459,7 +458,7 @@ def __init__( ("long-running", "memory"): self.transition_executing_done, ("long-running", "rescheduled"): self.transition_executing_done, ("flight", "memory"): self.transition_flight_memory, - ("flight", "ready"): self.transition_flight_memory, + ("flight", "ready"): self.transition_flight_ready, ("flight", "waiting"): self.transition_flight_waiting, } @@ -867,7 +866,14 @@ async def _register_with_scheduler(self): keys=list(self.data), nthreads=self.nthreads, name=self.name, - nbytes={ts.key: ts.get_nbytes() for ts in self.tasks.values()}, + nbytes={ + # FIXME: Tasks and data should not be allowed to + # diverge. Data should always be a subset of tasks. + # For all practical purposes this should only happen + # if someone modifies the dicts manually + k: self.tasks.get(k, TaskState(k)).get_nbytes() + for k in self.data.keys() + }, types={k: typename(v) for k, v in self.data.items()}, now=time(), resources=self.total_resources, @@ -1336,6 +1342,11 @@ async def get_data( from .actor import Actor data[k] = Actor(type(self.actors[k]), self.address, k) + if len(data) != len(keys): + logger.debug( + "Data request from %s but keys %s are not available." + % (who, set(keys) - set(data)) + ) msg = {"status": "OK", "data": {k: to_serialize(v) for k, v in data.items()}} nbytes = {k: self.tasks[k].nbytes for k in data if k in self.tasks} @@ -1348,6 +1359,16 @@ async def get_data( compressed = await comm.write(msg, serializers=serializers) response = await comm.read(deserializers=serializers) assert response == "OK", response + + except CommClosedError: + logger.exception( + "Other end hung up during get_data with %s -> %s", + self.address, + who, + exc_info=True, + ) + comm.abort() + raise except EnvironmentError: logger.exception( "failed during get data with %s -> %s", self.address, who, exc_info=True @@ -1406,7 +1427,7 @@ def delete_data(self, comm=None, keys=None, report=True): if keys: for key in list(keys): self.log.append((key, "delete")) - self.release_key(key) + self.release_key(key, reason="delete") logger.debug("Worker %s -- Deleted %d keys", self.name, len(keys)) return "OK" @@ -1444,9 +1465,11 @@ def add_task( actor=False, **kwargs2, ): + runspec = SerializedTask(function, args, kwargs, task) try: if key in self.tasks: ts = self.tasks[key] + ts.runspec = runspec if ts.state == "memory": assert key in self.data or key in self.actors logger.debug( @@ -1454,65 +1477,54 @@ def add_task( ) self.send_task_state_to_scheduler(ts) return - if ts.state in IN_PLAY: - return - if ts.state == "erred": - ts.exception = None - ts.traceback = None - else: - ts.state = "waiting" else: self.log.append((key, "new")) - self.tasks[key] = ts = TaskState( - key=key, runspec=SerializedTask(function, args, kwargs, task) - ) + self.tasks[key] = ts = TaskState(key=key) ts.state = "waiting" - if priority is not None: - priority = tuple(priority) + (self.generation,) - self.generation -= 1 - if actor: self.actors[ts.key] = None ts.priority = priority ts.duration = duration + ts.runspec = runspec if resource_restrictions: ts.resource_restrictions = resource_restrictions who_has = who_has or {} for dependency, workers in who_has.items(): - assert workers if dependency not in self.tasks: self.tasks[dependency] = dep_ts = TaskState(key=dependency) - dep_ts.state = ( - "waiting" if dependency not in self.data else "memory" - ) + dep_ts.state = "waiting" + logger.debug(f"New dep {dep_ts}") + self.log.append((dependency, "new-dep", dep_ts.state)) + else: + dep_ts = self.tasks[dependency] + logger.debug(f"Known dep {dep_ts}") + self.log.append((dependency, "known-dep", dep_ts.state)) - dep_ts = self.tasks[dependency] - self.log.append((dependency, "new-dep", dep_ts.state)) + if self.address in workers and dep_ts.state != "memory": + logger.debug( + f"Who has claims Worker {self.name} would own data of {dependency} but this is false." + ) - if dep_ts.state != "memory": + if dep_ts.state not in ("memory",) and dep_ts.runspec is None: ts.waiting_for_data.add(dep_ts.key) self.waiting_for_data_count += 1 - dep_ts.who_has.update(workers) - ts.dependencies.add(dep_ts) dep_ts.dependents.add(ts) - for worker in workers: - self.has_what[worker].add(dep_ts.key) - if dep_ts.state != "memory": - self.pending_data_per_worker[worker].append(dep_ts.key) + self.update_who_has(who_has) if nbytes is not None: for key, value in nbytes.items(): self.tasks[key].nbytes = value if ts.waiting_for_data: - self.data_needed.append(ts.key) + if ts.key not in self.data_needed: + self.data_needed.append(ts.key) else: self.transition(ts, "ready") if self.validate: @@ -1540,17 +1552,16 @@ def transition(self, ts, finish, **kwargs): state = func(ts, **kwargs) self.log.append((ts.key, start, state or finish)) ts.state = state or finish - if self.validate: - self.validate_task(ts) self._notify_plugins("transition", ts.key, start, state or finish, **kwargs) def transition_waiting_flight(self, ts, worker=None): try: if self.validate: - assert ts.state != "flight" + assert ts.state == "waiting" assert ts.dependents ts.coming_from = worker + self.in_flight_workers[worker].add(ts.key) self.in_flight_tasks += 1 except Exception as e: logger.exception(e) @@ -1560,33 +1571,22 @@ def transition_waiting_flight(self, ts, worker=None): pdb.set_trace() raise - def transition_flight_waiting(self, ts, worker=None, remove=True): + def transition_flight_waiting(self, ts, worker=None, remove=True, runspec=None): try: if self.validate: assert ts.state == "flight" self.in_flight_tasks -= 1 ts.coming_from = None - if remove: - try: - ts.who_has.remove(worker) - self.has_what[worker].remove(ts.key) - except KeyError: - pass + ts.runspec = runspec or ts.runspec - if not ts.who_has: - if ts.key not in self._missing_dep_flight: - self._missing_dep_flight.add(ts.key) - self.loop.add_callback(self.handle_missing_dep, ts) for dependent in ts.dependents: if dependent.state == "waiting": - if remove: # try a new worker immediately - self.data_needed.appendleft(dependent.key) - else: # worker was probably busy, wait a while - self.data_needed.append(dependent.key) - - if not ts.dependents: - self.release_key(ts.key) + if dependent.key not in self.data_needed: + if remove: # try a new worker immediately + self.data_needed.appendleft(dependent.key) + else: # worker was probably busy, wait a while + self.data_needed.append(dependent.key) except Exception as e: logger.exception(e) if LOG_PDB: @@ -1602,19 +1602,39 @@ def transition_flight_memory(self, ts, value=None): self.in_flight_tasks -= 1 ts.coming_from = None - if ts.dependents: - self.put_key_in_memory(ts, value) - for dependent in ts.dependents: - try: - dependent.waiting_for_data.remove(ts.key) - self.waiting_for_data_count -= 1 - except KeyError: - pass + self.put_key_in_memory(ts, value) + for dependent in ts.dependents: + dependent.waiting_for_data.discard(ts.key) + self.waiting_for_data_count -= 1 - self.batched_stream.send({"op": "add-keys", "keys": [ts.key]}) - else: - self.release_key(ts.key) + self.batched_stream.send({"op": "add-keys", "keys": [ts.key]}) + + except Exception as e: + logger.exception(e) + if LOG_PDB: + import pdb + pdb.set_trace() + raise + + def transition_flight_ready(self, ts): + try: + if self.validate: + assert ts.state == "flight" + assert not ts.waiting_for_data + assert all( + dep.key in self.data or dep.key in self.actors + for dep in ts.dependencies + ) + assert all(dep.state == "memory" for dep in ts.dependencies) + assert ts.key not in self.ready + assert ts.runspec is not None + ts.coming_from = None + if ts.resource_restrictions is not None: + self.constrained.append(ts.key) + return "constrained" + else: + heapq.heappush(self.ready, (ts.priority, ts.key)) except Exception as e: logger.exception(e) if LOG_PDB: @@ -1634,6 +1654,7 @@ def transition_waiting_ready(self, ts): ) assert all(dep.state == "memory" for dep in ts.dependencies) assert ts.key not in self.ready + assert ts.runspec is not None ts.waiting_for_data.clear() @@ -1691,6 +1712,12 @@ def transition_ready_executing(self, ts): pdb.set_trace() raise + def transition_ready_error(self, ts): + if self.validate: + assert ts.exception is not None + assert ts.traceback is not None + self.send_task_state_to_scheduler(ts) + def transition_ready_memory(self, ts, value=None): self.send_task_state_to_scheduler(ts) @@ -1731,11 +1758,6 @@ def transition_executing_done(self, ts, value=no_value, report=True): ts.state = "error" out = "error" - # Don't release the dependency keys, but do remove them from `dependents` - for dependency in ts.dependencies: - dependency.dependents.discard(ts) - ts.dependencies.clear() - if report and self.batched_stream and self.status == Status.running: self.send_task_state_to_scheduler(ts) else: @@ -1838,21 +1860,6 @@ def ensure_communicating(self): assert all(dep.key in self.tasks for dep in deps) deps = [dep for dep in deps if dep.state == "waiting"] - - missing_deps = {dep for dep in deps if not dep.who_has} - if missing_deps: - logger.info("Can't find dependencies for key %s", key) - missing_deps2 = { - dep - for dep in missing_deps - if dep.key not in self._missing_dep_flight - } - for dep in missing_deps2: - self._missing_dep_flight.add(dep.key) - self.loop.add_callback(self.handle_missing_dep, *missing_deps2) - - deps = [dep for dep in deps if dep not in missing_deps] - self.log.append(("gather-dependencies", key, deps)) in_flight = False @@ -1864,8 +1871,6 @@ def ensure_communicating(self): dep = deps.pop() if dep.state != "waiting": continue - if not dep.who_has: - continue workers = [ w for w in dep.who_has if w not in self.in_flight_workers ] @@ -1882,11 +1887,10 @@ def ensure_communicating(self): worker, dep.key ) self.comm_nbytes += total_nbytes - self.in_flight_workers[worker] = to_gather for d in to_gather: self.transition(self.tasks[d], "flight", worker=worker) self.loop.add_callback( - self.gather_dep, worker, dep, to_gather, total_nbytes, cause=key + self.gather_dep, worker, to_gather, total_nbytes, cause=key ) changed = True @@ -1998,15 +2002,18 @@ def select_keys_for_gather(self, worker, dep): return deps, total_bytes - async def gather_dep(self, worker, dep, deps, total_nbytes, cause=None): + async def gather_dep(self, worker, deps, total_nbytes, cause): """Gather dependencies for a task from a worker who has them Parameters ---------- worker : str address of worker to gather dependency from - dep : TaskState - task we want to gather dependencies for + cause : str + Task key the dependencies are gathered for + total_nbytes: int + The sum of bytes to be submited when gathering deps. After + finishing, this will be subtracted from self.comm_nbytes deps : list keys of dependencies to gather from worker -- this is not necessarily equivalent to the full list of dependencies of ``dep`` @@ -2022,11 +2029,11 @@ async def gather_dep(self, worker, dep, deps, total_nbytes, cause=None): # dep states may have changed before gather_dep runs # if a dep is no longer in-flight then don't fetch it - deps_ts = [self.tasks.get(key, None) or TaskState(key) for key in deps] + deps_ts = [self.tasks[key] for key in deps] deps_ts = tuple(ts for ts in deps_ts if ts.state == "flight") deps = [d.key for d in deps_ts] - self.log.append(("request-dep", dep.key, worker, deps)) + self.log.append(("request-dep", cause, worker, deps)) logger.debug("Request %d keys", len(deps)) start = time() @@ -2035,6 +2042,9 @@ async def gather_dep(self, worker, dep, deps, total_nbytes, cause=None): ) stop = time() + # FIXME: We have multiple busy handlings. There is also one in + # finally. May diverge. Alternative to return: raise Busy / + # except Busy: pass; if response["status"] == "busy": self.log.append(("busy-gather", worker, deps)) for ts in deps_ts: @@ -2096,39 +2106,52 @@ async def gather_dep(self, worker, dep, deps, total_nbytes, cause=None): self.log.append(("receive-dep", worker, list(response["data"]))) except EnvironmentError as e: logger.exception("Worker stream died during communication: %s", worker) - self.log.append(("receive-dep-failed", worker)) - for d in self.has_what.pop(worker): - self.tasks[d].who_has.remove(worker) - + self.log.append(("receive-dep-failed", worker, deps)) except Exception as e: - logger.exception(e) + logger.exception( + "Unexpected exception while gathering dependencies from %s", worker + ) + self.log.append(("receive-dep-failed", worker, deps)) if self.batched_stream and LOG_PDB: import pdb pdb.set_trace() + ts = self.tasks[cause] + self.bad_dep(ts) raise finally: self.comm_nbytes -= total_nbytes busy = response.get("status", "") == "busy" data = response.get("data", {}) + if self.validate: + self.validate_state() + + missing = set() + for d in self.in_flight_workers.pop(worker): + if d not in self.tasks: + logger.debug("Task %s already forgotten." % d) + continue - ts = self.tasks.get(d) + ts = self.tasks[d] - if not busy and d in data: + if ts.key in data: self.transition(ts, "memory", value=data[d]) - elif ts is None or ts.state == "executing": - self.release_key(d) - continue - elif ts.state not in ("ready", "memory"): - self.transition(ts, "waiting", worker=worker, remove=not busy) - if not busy and d not in data and ts.dependents: - self.log.append(("missing-dep", d)) - self.batched_stream.send( - {"op": "missing-data", "errant_worker": worker, "key": d} - ) + if ts.state != "memory": + self.transition(ts, "waiting", worker=worker) + missing.add(ts) + + if self.validate: + self.validate_state() + if missing: + await self.handle_missing_dep(*missing, worker=worker) + if cause not in self.data_needed: + if self.validate: + assert isinstance(cause, str) + assert cause in self.tasks + self.data_needed.append(cause) if self.validate: self.validate_state() @@ -2137,83 +2160,68 @@ async def gather_dep(self, worker, dep, deps, total_nbytes, cause=None): if not busy: self.repetitively_busy = 0 - self.ensure_communicating() else: # Exponential backoff to avoid hammering scheduler/worker self.repetitively_busy += 1 await asyncio.sleep(0.100 * 1.5 ** self.repetitively_busy) + self.ensure_communicating() - # See if anyone new has the data - await self.query_who_has(dep.key) - self.ensure_communicating() - - def bad_dep(self, dep): - exc = ValueError( - "Could not find dependent %s. Check worker logs" % str(dep.key) + def bad_dep(self, ts): + exc = RuntimeError( + "Could not find dependencies for %s. Check worker logs" % str(ts.key) ) - for ts in dep.dependents: - msg = error_message(exc) - ts.exception = msg["exception"] - ts.traceback = msg["traceback"] - self.transition(ts, "error") - self.release_key(dep.key) - - async def handle_missing_dep(self, *deps, **kwargs): + msg = error_message(exc) + ts.exception = msg["exception"] + ts.traceback = msg["traceback"] + self.transition(ts, "error") + + async def handle_missing_dep(self, *deps, worker): + """In case a dependency is missing we'll sync up with the scheduler to + confirm that our view of the world is accurate. If the scheudler has new + information, we'll just try again. If the scheduler doesn't have new + information, this means something is wrong with the world and we'll escalate to + the scheduler and trigger a key state reset to recompute the key""" self.log.append(("handle-missing", deps)) - try: - deps = {dep for dep in deps if dep.dependents} - if not deps: - return - - for dep in deps: - if dep.suspicious_count > 5: - deps.remove(dep) - self.bad_dep(dep) - if not deps: - return - for dep in deps: - logger.info( - "Dependent not found: %s %s . Asking scheduler", - dep.key, - dep.suspicious_count, - ) + deps2 = { + dep for dep in deps if dep.dependents and dep.state in ("flight", "waiting") + } + logger.debug( + "Worker %s - %s - Handle missing %s from worker %s." + % (self.name, self.address, deps2, worker) + ) + if not deps2: + return - who_has = await retry_operation( - self.scheduler.who_has, keys=list(dep.key for dep in deps) + for dep in deps2: + logger.info( + "Dependent not found: %s %s . Asking scheduler", + dep.key, + dep.suspicious_count, ) - who_has = {k: v for k, v in who_has.items() if v} - self.update_who_has(who_has) - for dep in deps: - dep.suspicious_count += 1 - - if not who_has.get(dep.key): - self.log.append((dep.key, "no workers found", dep.dependents)) - self.release_key(dep.key) - else: - self.log.append((dep.key, "new workers found")) - for dependent in dep.dependents: - if dependent.key in dep.waiting_for_data: - self.data_needed.append(dependent.key) + who_has = await retry_operation( + self.scheduler.who_has, keys={d.key for d in deps2} + ) + who_has_2 = {k: v for k, v in who_has.items() if v} - except Exception: - logger.error("Handle missing dep failed, retrying", exc_info=True) - retries = kwargs.get("retries", 5) - self.log.append(("handle-missing-failed", retries, deps)) - if retries > 0: - await self.handle_missing_dep(*deps, retries=retries - 1) - else: - raise - finally: - try: - for dep in deps: - self._missing_dep_flight.remove(dep.key) - except KeyError: - pass + for task_key, task_who_has in who_has_2.items(): + ts = self.tasks[task_key] + deps2.remove(ts) - self.ensure_communicating() + if ts.who_has == set(task_who_has): + self.batched_stream.send( + {"op": "missing-data", "errant_worker": worker, "key": task_key} + ) + self.transition(ts, "waiting", worker=worker) + self.update_who_has(who_has) + for dep in deps2: + self.batched_stream.send( + {"op": "missing-data", "errant_worker": worker, "key": dep.key} + ) + self.transition(dep, "waiting", worker=worker) async def query_who_has(self, *deps): + # FIXME: If this is improperly called, the query fails. We should ensure that deps is a flat list here with log_errors(): response = await retry_operation(self.scheduler.who_has, keys=deps) self.update_who_has(response) @@ -2225,10 +2233,17 @@ def update_who_has(self, who_has): if not workers: continue - self.tasks[dep].who_has.update(workers) + dep_ts = self.tasks[dep] + old = self.tasks[dep].who_has + self.tasks[dep].who_has = set(workers) for worker in workers: self.has_what[worker].add(dep) + if dep_ts.state != "memory": + self.pending_data_per_worker[worker].append(dep_ts.key) + for worker in old - set(workers): + self.has_what[worker].discard(dep) + except Exception as e: logger.exception(e) if LOG_PDB: @@ -2250,32 +2265,30 @@ def steal_request(self, key): self.batched_stream.send(response) if state in ("ready", "waiting", "constrained"): - self.release_key(key) + self.release_key(key, reason="steal_request") - def release_key(self, key, cause=None, reason=None, report=True): + def release_key(self, key, reason=None, cause=None, report=True): try: + + if self.validate: + self.validate_state() ts = self.tasks.get(key, TaskState(key=key)) + logger.debug( + f"Worker {self.name} - {self.address} - Release key {ts} - Cause {cause} - Reason {reason}" + ) if cause: self.log.append((key, "release-key", {"cause": cause})) else: self.log.append((key, "release-key")) - if key in self.data and not ts.dependents: + + if key in self.data: try: del self.data[key] except FileNotFoundError: logger.error("Tried to delete %s but no file found", exc_info=True) - if key in self.actors and not ts.dependents: + if key in self.actors: del self.actors[key] - # for any dependencies of key we are releasing remove task as dependent - for dependency in ts.dependencies: - dependency.dependents.discard(ts) - if not dependency.dependents and dependency.state in ( - "waiting", - "flight", - ): - self.release_key(dependency.key) - for worker in ts.who_has: self.has_what[worker].discard(ts.key) @@ -2297,6 +2310,8 @@ def release_key(self, key, cause=None, reason=None, report=True): if key in self.tasks: self.tasks.pop(key) del ts + if self.validate: + self.validate_state() except CommClosedError: pass except Exception as e: @@ -2307,32 +2322,6 @@ def release_key(self, key, cause=None, reason=None, report=True): pdb.set_trace() raise - def rescind_key(self, key): - try: - if self.tasks[key].state not in PENDING: - return - - ts = self.tasks.pop(key) - - # Task has been rescinded - # For every task that it required - for dependency in ts.dependencies: - # Remove it as a dependent - dependency.dependents.remove(key) - # If the dependent is now without purpose (no dependencies), remove it - if not dependency.dependents: - self.release_key( - dependency.key, reason="All dependent keys rescinded" - ) - - except Exception as e: - logger.exception(e) - if LOG_PDB: - import pdb - - pdb.set_trace() - raise - ################ # Execute Task # ################ @@ -2452,6 +2441,7 @@ def meets_resource_constraints(self, key): return True async def _maybe_deserialize_task(self, ts): + assert ts.runspec is not None if not isinstance(ts.runspec, SerializedTask): return ts.runspec try: @@ -2470,11 +2460,11 @@ async def _maybe_deserialize_task(self, ts): return function, args, kwargs except Exception as e: logger.warning("Could not deserialize task", exc_info=True) - emsg = error_message(e) - emsg["key"] = ts.key - emsg["op"] = "task-erred" - self.batched_stream.send(emsg) + msg = error_message(e) + ts.exception = msg["exception"] + ts.traceback = msg["traceback"] self.log.append((ts.key, "deserialize-error")) + self.transition(ts, "error") raise async def ensure_computing(self): @@ -2489,8 +2479,8 @@ async def ensure_computing(self): continue if self.meets_resource_constraints(key): self.constrained.popleft() + # Ensure task is deserialized prior to execution try: - # Ensure task is deserialized prior to execution ts.runspec = await self._maybe_deserialize_task(ts) except Exception: continue @@ -2508,8 +2498,8 @@ async def ensure_computing(self): elif ts.key in self.data: self.transition(ts, "memory") elif ts.state in READY: + # Ensure task is deserialized prior to execution try: - # Ensure task is deserialized prior to execution ts.runspec = await self._maybe_deserialize_task(ts) except Exception: continue @@ -2530,7 +2520,15 @@ async def execute(self, key, report=False): if key not in self.tasks: return ts = self.tasks[key] - if ts.state != "executing" or ts.runspec is None: + if ts.state != "executing": + # This might happen if keys are canceled + logger.debug( + "Trying to execute a task %s which is not in executing state anymore" + % ts + ) + return + if ts.runspec is None: + logger.critical("No runspec available for task %s." % ts) return if self.validate: assert not ts.waiting_for_data @@ -2601,7 +2599,9 @@ async def execute(self, key, report=False): if isinstance(result.pop("actual-exception"), Reschedule): self.batched_stream.send({"op": "reschedule", "key": ts.key}) self.transition(ts, "rescheduled", report=False) - self.release_key(ts.key, report=False) + self.release_key( + ts.key, report=False, reason="exception during execution" + ) else: ts.exception = result["exception"] ts.traceback = result["traceback"] @@ -2878,11 +2878,13 @@ def validate_task_memory(self, ts): assert not ts.waiting_for_data assert ts.key not in self.ready assert ts.state == "memory" + assert ts.coming_from is None def validate_task_executing(self, ts): assert ts.state == "executing" assert ts.key not in self.data assert not ts.waiting_for_data + assert ts.runspec is not None assert all( dep.key in self.data or dep.key in self.actors for dep in ts.dependencies ) @@ -2905,9 +2907,11 @@ def validate_task_waiting(self, ts): def validate_task_flight(self, ts): assert ts.key not in self.data assert not any(dep.key in self.ready for dep in ts.dependents) + assert ts.coming_from in self.in_flight_workers assert ts.key in self.in_flight_workers[ts.coming_from] def validate_task(self, ts): + assert ts.key in self.tasks try: if ts.state == "memory": self.validate_task_memory(ts) @@ -2931,6 +2935,8 @@ def validate_state(self): if self.status != Status.running: return try: + assert len(self.data_needed) == len(set(self.data_needed)) + waiting_keys = set() for ts in self.tasks.values(): assert ts.state is not None # check that worker has task @@ -2950,14 +2956,19 @@ def validate_state(self): assert ( ts_wait.state == "flight" or ts_wait.state == "waiting" - or ts.wait.key in self._missing_dep_flight or ts_wait.who_has.issubset(self.in_flight_workers) ) + waiting_keys.add(key) if ts.state == "memory": assert isinstance(ts.nbytes, int) assert not ts.waiting_for_data assert ts.key in self.data or ts.key in self.actors + # FIXME: Tracking of waiting_for_data_count is broken. Most likely + # since tasks may be required/waited on by multiple tasks but the + # counter implementation doesn't reflect this, see Worker.transition_waiting_done + # assert self.waiting_for_data_count >= 0 + # assert len(waiting_keys) == self.waiting_for_data_count for worker, keys in self.has_what.items(): for k in keys: assert worker in self.tasks[k].who_has @@ -3261,7 +3272,7 @@ async def _get_data(): except KeyError: raise ValueError("Unexpected response", response) else: - if status == "OK": + if not comm.closed() and status == "OK": await comm.write("OK") return response finally: