Skip to content

Commit

Permalink
Fix deadlocks and infinite loops in worker
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter committed Jan 8, 2021
1 parent 658fccc commit 42690f0
Show file tree
Hide file tree
Showing 7 changed files with 312 additions and 245 deletions.
13 changes: 6 additions & 7 deletions distributed/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,7 +623,6 @@ async def send_recv(comm, reply=True, serializers=None, deserializers=None, **kw
msg = kwargs
msg["reply"] = reply
please_close = kwargs.get("close")
force_close = False
if deserializers is None:
deserializers = serializers
if deserializers is not None:
Expand All @@ -635,15 +634,15 @@ async def send_recv(comm, reply=True, serializers=None, deserializers=None, **kw
response = await comm.read(deserializers=deserializers)
else:
response = None
except EnvironmentError:
# On communication errors, we should simply close the communication
force_close = True
raise
except Exception as exc:
# If an exception occured we will need to close the comm, if possible.
# Otherwise the other end might wait for a reply while this end is
# reusing the comm for something else.
comm.abort()
raise exc
finally:
if please_close:
await comm.close()
elif force_close:
comm.abort()

if isinstance(response, dict) and response.get("status") == "uncaught-error":
if comm.deserialize:
Expand Down
16 changes: 14 additions & 2 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2738,7 +2738,14 @@ def stimulus_task_finished(self, key=None, worker=None, **kwargs):
ts._who_has,
)
if ws not in ts._who_has:
self.worker_send(worker, {"op": "release-task", "key": key})
self.worker_send(
worker,
{
"op": "release-task",
"key": key,
"reason": "stimulus task finished",
},
)
recommendations = {}

return recommendations
Expand Down Expand Up @@ -5223,7 +5230,12 @@ def transition_processing_released(self, key):
assert self.tasks[key].state == "processing"

self._remove_from_processing(
ts, send_worker_msg={"op": "release-task", "key": key}
ts,
send_worker_msg={
"op": "release-task",
"key": key,
"reason": "transition released",
},
)

ts.state = "released"
Expand Down
6 changes: 2 additions & 4 deletions distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1943,18 +1943,16 @@ def __setstate__(self, state):
async def test_badly_serialized_input(c, s, a, b):
o = BadlySerializedObject()

future = c.submit(inc, o)
future = c.submit(inc, o, key="broken")
futures = c.map(inc, range(10))

L = await c.gather(futures)
assert list(L) == list(map(inc, range(10)))
assert future.status == "error"

with pytest.raises(Exception) as info:
with pytest.raises(TypeError, match="hello!"):
await future

assert "hello!" in str(info.value)


@pytest.mark.skipif("True", reason="")
async def test_badly_serialized_input_stderr(capsys, c):
Expand Down
49 changes: 46 additions & 3 deletions distributed/tests/test_failed_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ async def test_worker_who_has_clears_after_failed_connection(c, s, a, b):
result = await c.submit(sum, futures, workers=a.address)
deps = [dep for dep in a.tasks.values() if dep.key not in a.data_needed]
for dep in deps:
a.release_key(dep.key, report=True)
a.release_key(dep.key, report=True, reason="test")

n_worker_address = n.worker_address
with suppress(CommClosedError):
Expand All @@ -404,6 +404,21 @@ async def test_worker_who_has_clears_after_failed_connection(c, s, a, b):
await n.close()


@gen_cluster(client=True)
async def test_worker_release_key_recovers(c, s, a, b):
futs = c.map(slowinc, range(2), delay=0.1)

await wait(futs)

assert len(a.tasks) > 0
for ts in list(a.tasks.values()):
a.release_key(ts.key, reason="test")

res = await c.submit(sum, futs)

assert res == 3


@pytest.mark.slow
@gen_cluster(client=True, timeout=60, Worker=Nanny, nthreads=[("127.0.0.1", 1)])
async def test_restart_timeout_on_long_running_task(c, s, a):
Expand All @@ -416,7 +431,6 @@ async def test_restart_timeout_on_long_running_task(c, s, a):
assert "timeout" not in text.lower()


@pytest.mark.slow
@gen_cluster(client=True, scheduler_kwargs={"worker_ttl": "500ms"})
async def test_worker_time_to_live(c, s, a, b):
from distributed.scheduler import heartbeat_interval
Expand All @@ -434,4 +448,33 @@ async def test_worker_time_to_live(c, s, a, b):
await asyncio.sleep(interval)
assert time() < start + interval + 0.1

set(s.workers) == {b.address}
assert set(s.workers) == {b.address}


@gen_cluster(client=True, Worker=Nanny)
async def test_get_data_faulty_dep(c, s, a, b):
"""This test creates a broken dependency and forces serialization by
requiring it to be submitted to another worker. The computation should
eventually finish by flagging the dep as bad and raise an appropriate
exception.
"""

class BrokenDeserialization:
def __setstate__(self, *state):
raise AttributeError()

def __getstate__(self, *args):
return ""

def create():
return BrokenDeserialization()

def collect(*args):
return args

fut1 = c.submit(create, workers=[a.name])

fut2 = c.submit(collect, fut1, workers=[b.name])

with pytest.raises(RuntimeError, match="Could not find dependencies for collect-"):
await fut2.result()
6 changes: 5 additions & 1 deletion distributed/tests/test_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,11 @@ async def test_minimum_resource(c, s, a):
assert a.total_resources == a.available_resources


@gen_cluster(client=True, nthreads=[("127.0.0.1", 2, {"resources": {"A": 1}})])
@gen_cluster(
client=True,
nthreads=[("127.0.0.1", 2, {"resources": {"A": 1}})],
active_rpc_timeout=10,
)
async def test_prefer_constrained(c, s, a):
futures = c.map(slowinc, range(1000), delay=0.1)
constrained = c.map(inc, range(10), resources={"A": 1})
Expand Down
2 changes: 1 addition & 1 deletion distributed/tests/test_stress.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ async def test_cancel_stress(c, s, *workers):
def test_cancel_stress_sync(loop):
da = pytest.importorskip("dask.array")
x = da.random.random((50, 50), chunks=(2, 2))
with cluster(active_rpc_timeout=10) as (s, [a, b]):
with cluster(active_rpc_timeout=10, disconnect_timeout=10) as (s, [a, b]):
with Client(s["address"], loop=loop) as c:
x = c.persist(x)
y = (x.sum(axis=0) + x.sum(axis=1) + 1).std()
Expand Down
Loading

0 comments on commit 42690f0

Please sign in to comment.