Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure distributed.comm.core.connect can always be cancelled #6064

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion distributed/comm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from distributed.metrics import time
from distributed.protocol import pickle
from distributed.protocol.compression import get_default_compression
from distributed.utils import ensure_cancellation

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -286,8 +287,11 @@ def time_left():
active_exception = None
while time_left() > 0:
try:
task = ensure_cancellation(
connector.connect(loc, deserialize=deserialize, **connection_args)
)
comm = await asyncio.wait_for(
connector.connect(loc, deserialize=deserialize, **connection_args),
task,
timeout=min(intermediate_cap, time_left()),
)
break
Expand Down
24 changes: 4 additions & 20 deletions distributed/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1086,6 +1086,7 @@ async def _connect(self, addr, timeout=None):
deserialize=self.deserialize,
**self.connection_args,
)

comm.name = "ConnectionPool"
comm._pool = weakref.ref(self)
comm.allow_offload = self.allow_offload
Expand All @@ -1099,8 +1100,6 @@ async def _connect(self, addr, timeout=None):
raise
finally:
self._connecting_count -= 1
except asyncio.CancelledError:
raise CommClosedError("ConnectionPool closing.")
finally:
self._pending_count -= 1

Expand All @@ -1121,30 +1120,15 @@ async def connect(self, addr, timeout=None):
if self.semaphore.locked():
self.collect()

# This construction is there to ensure that cancellation requests from
# the outside can be distinguished from cancellations of our own.
# Once the CommPool closes, we'll cancel the connect_attempt which will
# raise an OSError
# If the ``connect`` is cancelled from the outside, the Event.wait will
# be cancelled instead which we'll reraise as a CancelledError and allow
# it to propagate
connect_attempt = asyncio.create_task(self._connect(addr, timeout))
done = asyncio.Event()
self._connecting.add(connect_attempt)
connect_attempt.add_done_callback(lambda _: done.set())
connect_attempt.add_done_callback(self._connecting.discard)

try:
await done.wait()
return await connect_attempt
except asyncio.CancelledError:
# This is an outside cancel attempt
connect_attempt.cancel()
try:
await connect_attempt
except CommClosedError:
pass
if self.status == Status.closed:
raise CommClosedError("ConnectionPool closed.")
raise
return await connect_attempt

def reuse(self, addr, comm):
"""
Expand Down
13 changes: 7 additions & 6 deletions distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5743,12 +5743,13 @@ async def test_client_active_bad_port():
application = tornado.web.Application([(r"/", tornado.web.RequestHandler)])
http_server = tornado.httpserver.HTTPServer(application)
http_server.listen(8080)
with dask.config.set({"distributed.comm.timeouts.connect": "10ms"}):
c = Client("127.0.0.1:8080", asynchronous=True)
with pytest.raises((TimeoutError, IOError)):
await c
await c._close(fast=True)
http_server.stop()
try:
with dask.config.set({"distributed.comm.timeouts.connect": "10ms"}):
with pytest.raises((TimeoutError, IOError)):
async with Client("127.0.0.1:8080", asynchronous=True) as c:
pass
finally:
http_server.stop()


@pytest.mark.parametrize("direct", [True, False])
Expand Down
111 changes: 71 additions & 40 deletions distributed/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import dask

from distributed.comm.core import CommClosedError
from distributed.comm.tcp import TCPBackend, TCPConnector
from distributed.core import (
ConnectionPool,
Server,
Expand Down Expand Up @@ -591,26 +592,50 @@ async def ping(comm, delay=0.1):
await asyncio.gather(*[server.close() for server in servers])


@gen_test()
async def test_connection_pool_close_while_connecting(monkeypatch):
"""
Ensure a closed connection pool guarantees to have no connections left open
even if it is closed mid-connecting
"""
from distributed.comm.registry import backends
from distributed.comm.tcp import TCPBackend, TCPConnector
class WrongCancelConnector(TCPConnector):
async def connect(self, address, deserialize, **connection_args):
try:
await asyncio.sleep(10000)
except asyncio.CancelledError:
raise OSError("muhaha")

class SlowConnector(TCPConnector):
async def connect(self, address, deserialize, **connection_args):

class WrongCancelBackend(TCPBackend):
_connector_class = WrongCancelConnector


class SlowConnector(TCPConnector):
async def connect(self, address, deserialize, **connection_args):
try:
await asyncio.sleep(10000)
return await super().connect(
address, deserialize=deserialize, **connection_args
)
except BaseException:
raise

class SlowBackend(TCPBackend):
_connector_class = SlowConnector

monkeypatch.setitem(backends, "tcp", SlowBackend())
class SlowBackend(TCPBackend):
_connector_class = SlowConnector


@pytest.mark.parametrize(
"backend",
[
SlowBackend,
WrongCancelBackend,
],
)
@pytest.mark.parametrize(
"closing",
[
True,
False,
],
)
@gen_test()
async def test_connection_pool_cancellation(monkeypatch, closing, backend):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These tests do not capture this race condition but the rewrite is much cleaner and covers more ground so I'd like to keep it

# Ensure cancellation errors are properly reraised
from distributed.comm.registry import backends

monkeypatch.setitem(backends, "tcp", backend())

async with Server({}) as server:
await server.listen("tcp://")
Expand All @@ -623,53 +648,59 @@ async def connect_to_server():

# #tasks > limit
tasks = [asyncio.create_task(connect_to_server()) for _ in range(5)]

while not pool._connecting:
# Ensure the pool is saturated and some connection attempts are pending to
# connect
while pool._pending_count != len(tasks):
await asyncio.sleep(0.01)

await pool.close()
for t in tasks:
with pytest.raises(CommClosedError):
await t
if closing:
await pool.close()
for t in tasks:
with pytest.raises(CommClosedError):
await t
else:
for t in tasks:
t.cancel()
await asyncio.wait(tasks)
assert all(t.cancelled() for t in tasks)

assert not pool.open
assert not pool._n_connecting


@gen_test()
async def test_connection_pool_outside_cancellation(monkeypatch):
# Ensure cancellation errors are properly reraised
from distributed.comm.registry import backends
from distributed.comm.tcp import TCPBackend, TCPConnector
async def test_connect_properly_raising(monkeypatch):
_connecting = 0

class SlowConnector(TCPConnector):
async def connect(self, address, deserialize, **connection_args):
await asyncio.sleep(10000)
return await super().connect(
address, deserialize=deserialize, **connection_args
)
try:
nonlocal _connecting
_connecting += 1
await asyncio.sleep(10000)
except BaseException:
raise OSError

class SlowBackend(TCPBackend):
_connector_class = SlowConnector

# Ensure cancellation errors are properly reraised
from distributed.comm.registry import backends

monkeypatch.setitem(backends, "tcp", SlowBackend())

async with Server({}) as server:
await server.listen("tcp://")
pool = await ConnectionPool(limit=2)

async def connect_to_server():
comm = await pool.connect(server.address)
pool.reuse(server.address, comm)

# #tasks > limit
tasks = [asyncio.create_task(connect_to_server()) for _ in range(5)]
while not pool._connecting:
await asyncio.sleep(0.01)
tasks = [asyncio.create_task(connect(server.address)) for _ in range(5)]

while _connecting != len(tasks):
await asyncio.sleep(0.1)

for t in tasks:
t.cancel()

done, _ = await asyncio.wait(tasks)
await asyncio.wait(tasks)
assert all(t.cancelled() for t in tasks)


Expand Down
45 changes: 45 additions & 0 deletions distributed/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
TimeoutError,
_maybe_complex,
ensure_bytes,
ensure_cancellation,
ensure_ip,
format_dashboard_link,
get_ip_interface,
Expand Down Expand Up @@ -782,3 +783,47 @@ def __repr__(self):
],
}
assert recursive_to_dict(info) == expect


def test_ensure_cancellation():
# Do not use gen_test to allow us to test on CancelledErrors
async def _():
ev = asyncio.Event()

async def f():
await asyncio.sleep(0)
ev.set()
raise ValueError("foo")

async def g():
ev.set()
await asyncio.sleep(1000000)

task = asyncio.create_task(f())
await ev.wait()
await asyncio.sleep(0)
with pytest.raises(ValueError, match="foo"):
await task
ev.clear()

task = asyncio.create_task(ensure_cancellation(f()))
await ev.wait()
await asyncio.sleep(0)
task.cancel()
with pytest.raises(asyncio.CancelledError):
await task

ev.clear()
task = asyncio.create_task(ensure_cancellation(g()))
await ev.wait()
task.cancel()
with pytest.raises(asyncio.CancelledError):
await task

async def h():
await asyncio.sleep(0)
return 1

assert await ensure_cancellation(h()) == 1

asyncio.run(_())
23 changes: 11 additions & 12 deletions distributed/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,18 +321,17 @@ async def test_worker_port_range(s):
pass


@pytest.mark.slow
@gen_test(timeout=60)
async def test_worker_waits_for_scheduler():
w = Worker("127.0.0.1:8724")
try:
await asyncio.wait_for(w, 3)
except TimeoutError:
pass
else:
assert False
assert w.status not in (Status.closed, Status.running, Status.paused)
await w.close(timeout=0.1)
@pytest.mark.parametrize("connect_timeout", ["1s", "5s"])
@gen_test()
async def test_worker_waits_for_scheduler(connect_timeout):
with dask.config.set({"distributed.comm.timeouts.connect": connect_timeout}):
w = Worker("127.0.0.1:8724")

with pytest.raises(TimeoutError):
await asyncio.wait_for(w, 3)

assert w.status not in (Status.closed, Status.running, Status.paused)
await w.close()


@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)])
Expand Down
28 changes: 26 additions & 2 deletions distributed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@
from hashlib import md5
from importlib.util import cache_from_source
from time import sleep
from types import ModuleType
from types import CoroutineType, ModuleType
from typing import Any as AnyType
from typing import ClassVar
from typing import ClassVar, TypeVar

import click
import tblib.pickling_support
Expand Down Expand Up @@ -1621,3 +1621,27 @@ def is_python_shutting_down() -> bool:
from distributed import _python_shutting_down

return _python_shutting_down


T = TypeVar("T")


async def ensure_cancellation(coro: CoroutineType[None, None, T]) -> T:
"""Ensure that the wrapped coro will raise a CancelledError even if its
result is already set.

See https://github.com/python/cpython/issues/86296
"""
watcher = asyncio.Event()

task = asyncio.create_task(coro)
task.add_done_callback(lambda _: watcher.set())

try:
await watcher.wait()
except asyncio.CancelledError:
task.cancel()
await watcher.wait()
raise
Comment on lines +1643 to +1645
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not entirely sure if I need to await the task itself here. There are no warnings

cc @graingert


return task.result()