From 676bb1883621a11b285faf041f8e5d357cb73ee8 Mon Sep 17 00:00:00 2001 From: fjetter Date: Mon, 4 Apr 2022 18:32:55 +0200 Subject: [PATCH] Ensure Cancellation of distributed.comm.core.connect always raises CancelledError --- distributed/comm/core.py | 6 +- distributed/core.py | 24 ++----- distributed/tests/test_client.py | 13 ++-- distributed/tests/test_core.py | 111 ++++++++++++++++++++----------- distributed/tests/test_utils.py | 45 +++++++++++++ distributed/tests/test_worker.py | 23 +++---- distributed/utils.py | 28 +++++++- 7 files changed, 169 insertions(+), 81 deletions(-) diff --git a/distributed/comm/core.py b/distributed/comm/core.py index 5f9235a26b..cd578f60d4 100644 --- a/distributed/comm/core.py +++ b/distributed/comm/core.py @@ -18,6 +18,7 @@ from distributed.metrics import time from distributed.protocol import pickle from distributed.protocol.compression import get_default_compression +from distributed.utils import ensure_cancellation logger = logging.getLogger(__name__) @@ -286,8 +287,11 @@ def time_left(): active_exception = None while time_left() > 0: try: + task = ensure_cancellation( + connector.connect(loc, deserialize=deserialize, **connection_args) + ) comm = await asyncio.wait_for( - connector.connect(loc, deserialize=deserialize, **connection_args), + task, timeout=min(intermediate_cap, time_left()), ) break diff --git a/distributed/core.py b/distributed/core.py index df3a50fd81..666f95af7a 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -1086,6 +1086,7 @@ async def _connect(self, addr, timeout=None): deserialize=self.deserialize, **self.connection_args, ) + comm.name = "ConnectionPool" comm._pool = weakref.ref(self) comm.allow_offload = self.allow_offload @@ -1099,8 +1100,6 @@ async def _connect(self, addr, timeout=None): raise finally: self._connecting_count -= 1 - except asyncio.CancelledError: - raise CommClosedError("ConnectionPool closing.") finally: self._pending_count -= 1 @@ -1121,30 +1120,15 @@ async def connect(self, addr, timeout=None): if self.semaphore.locked(): self.collect() - # This construction is there to ensure that cancellation requests from - # the outside can be distinguished from cancellations of our own. 
- # Once the CommPool closes, we'll cancel the connect_attempt which will - # raise an OSError - # If the ``connect`` is cancelled from the outside, the Event.wait will - # be cancelled instead which we'll reraise as a CancelledError and allow - # it to propagate connect_attempt = asyncio.create_task(self._connect(addr, timeout)) - done = asyncio.Event() self._connecting.add(connect_attempt) - connect_attempt.add_done_callback(lambda _: done.set()) connect_attempt.add_done_callback(self._connecting.discard) - try: - await done.wait() + return await connect_attempt except asyncio.CancelledError: - # This is an outside cancel attempt - connect_attempt.cancel() - try: - await connect_attempt - except CommClosedError: - pass + if self.status == Status.closed: + raise CommClosedError("ConnectionPool closed.") raise - return await connect_attempt def reuse(self, addr, comm): """ diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 5b62db623d..b83c7ee098 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -5743,12 +5743,13 @@ async def test_client_active_bad_port(): application = tornado.web.Application([(r"/", tornado.web.RequestHandler)]) http_server = tornado.httpserver.HTTPServer(application) http_server.listen(8080) - with dask.config.set({"distributed.comm.timeouts.connect": "10ms"}): - c = Client("127.0.0.1:8080", asynchronous=True) - with pytest.raises((TimeoutError, IOError)): - await c - await c._close(fast=True) - http_server.stop() + try: + with dask.config.set({"distributed.comm.timeouts.connect": "10ms"}): + with pytest.raises((TimeoutError, IOError)): + async with Client("127.0.0.1:8080", asynchronous=True) as c: + pass + finally: + http_server.stop() @pytest.mark.parametrize("direct", [True, False]) diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index 3792338076..0bfdf7c2ce 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -11,6 +11,7 @@ import dask from distributed.comm.core import CommClosedError +from distributed.comm.tcp import TCPBackend, TCPConnector from distributed.core import ( ConnectionPool, Server, @@ -591,26 +592,50 @@ async def ping(comm, delay=0.1): await asyncio.gather(*[server.close() for server in servers]) -@gen_test() -async def test_connection_pool_close_while_connecting(monkeypatch): - """ - Ensure a closed connection pool guarantees to have no connections left open - even if it is closed mid-connecting - """ - from distributed.comm.registry import backends - from distributed.comm.tcp import TCPBackend, TCPConnector +class WrongCancelConnector(TCPConnector): + async def connect(self, address, deserialize, **connection_args): + try: + await asyncio.sleep(10000) + except asyncio.CancelledError: + raise OSError("muhaha") - class SlowConnector(TCPConnector): - async def connect(self, address, deserialize, **connection_args): + +class WrongCancelBackend(TCPBackend): + _connector_class = WrongCancelConnector + + +class SlowConnector(TCPConnector): + async def connect(self, address, deserialize, **connection_args): + try: await asyncio.sleep(10000) - return await super().connect( - address, deserialize=deserialize, **connection_args - ) + except BaseException: + raise - class SlowBackend(TCPBackend): - _connector_class = SlowConnector - monkeypatch.setitem(backends, "tcp", SlowBackend()) +class SlowBackend(TCPBackend): + _connector_class = SlowConnector + + +@pytest.mark.parametrize( + "backend", + [ + SlowBackend, + 
WrongCancelBackend, + ], +) +@pytest.mark.parametrize( + "closing", + [ + True, + False, + ], +) +@gen_test() +async def test_connection_pool_cancellation(monkeypatch, closing, backend): + # Ensure cancellation errors are properly reraised + from distributed.comm.registry import backends + + monkeypatch.setitem(backends, "tcp", backend()) async with Server({}) as server: await server.listen("tcp://") @@ -623,53 +648,59 @@ async def connect_to_server(): # #tasks > limit tasks = [asyncio.create_task(connect_to_server()) for _ in range(5)] - - while not pool._connecting: + # Ensure the pool is saturated and some connection attempts are pending to + # connect + while pool._pending_count != len(tasks): await asyncio.sleep(0.01) - await pool.close() - for t in tasks: - with pytest.raises(CommClosedError): - await t + if closing: + await pool.close() + for t in tasks: + with pytest.raises(CommClosedError): + await t + else: + for t in tasks: + t.cancel() + await asyncio.wait(tasks) + assert all(t.cancelled() for t in tasks) + assert not pool.open assert not pool._n_connecting @gen_test() -async def test_connection_pool_outside_cancellation(monkeypatch): - # Ensure cancellation errors are properly reraised - from distributed.comm.registry import backends - from distributed.comm.tcp import TCPBackend, TCPConnector +async def test_connect_properly_raising(monkeypatch): + _connecting = 0 class SlowConnector(TCPConnector): async def connect(self, address, deserialize, **connection_args): - await asyncio.sleep(10000) - return await super().connect( - address, deserialize=deserialize, **connection_args - ) + try: + nonlocal _connecting + _connecting += 1 + await asyncio.sleep(10000) + except BaseException: + raise OSError class SlowBackend(TCPBackend): _connector_class = SlowConnector + # Ensure cancellation errors are properly reraised + from distributed.comm.registry import backends + monkeypatch.setitem(backends, "tcp", SlowBackend()) async with Server({}) as server: await server.listen("tcp://") - pool = await ConnectionPool(limit=2) - - async def connect_to_server(): - comm = await pool.connect(server.address) - pool.reuse(server.address, comm) # #tasks > limit - tasks = [asyncio.create_task(connect_to_server()) for _ in range(5)] - while not pool._connecting: - await asyncio.sleep(0.01) + tasks = [asyncio.create_task(connect(server.address)) for _ in range(5)] + + while _connecting != len(tasks): + await asyncio.sleep(0.1) for t in tasks: t.cancel() - - done, _ = await asyncio.wait(tasks) + await asyncio.wait(tasks) assert all(t.cancelled() for t in tasks) diff --git a/distributed/tests/test_utils.py b/distributed/tests/test_utils.py index c281d418ec..2bb729af32 100644 --- a/distributed/tests/test_utils.py +++ b/distributed/tests/test_utils.py @@ -27,6 +27,7 @@ TimeoutError, _maybe_complex, ensure_bytes, + ensure_cancellation, ensure_ip, format_dashboard_link, get_ip_interface, @@ -782,3 +783,47 @@ def __repr__(self): ], } assert recursive_to_dict(info) == expect + + +def test_ensure_cancellation(): + # Do not use gen_test to allow us to test on CancelledErrors + async def _(): + ev = asyncio.Event() + + async def f(): + await asyncio.sleep(0) + ev.set() + raise ValueError("foo") + + async def g(): + ev.set() + await asyncio.sleep(1000000) + + task = asyncio.create_task(f()) + await ev.wait() + await asyncio.sleep(0) + with pytest.raises(ValueError, match="foo"): + await task + ev.clear() + + task = asyncio.create_task(ensure_cancellation(f())) + await ev.wait() + await asyncio.sleep(0) + 
task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + ev.clear() + task = asyncio.create_task(ensure_cancellation(g())) + await ev.wait() + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + async def h(): + await asyncio.sleep(0) + return 1 + + assert await ensure_cancellation(h()) == 1 + + asyncio.run(_()) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 152320aa26..44c4311859 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -321,18 +321,17 @@ async def test_worker_port_range(s): pass -@pytest.mark.slow -@gen_test(timeout=60) -async def test_worker_waits_for_scheduler(): - w = Worker("127.0.0.1:8724") - try: - await asyncio.wait_for(w, 3) - except TimeoutError: - pass - else: - assert False - assert w.status not in (Status.closed, Status.running, Status.paused) - await w.close(timeout=0.1) +@pytest.mark.parametrize("connect_timeout", ["1s", "5s"]) +@gen_test() +async def test_worker_waits_for_scheduler(connect_timeout): + with dask.config.set({"distributed.comm.timeouts.connect": connect_timeout}): + w = Worker("127.0.0.1:8724") + + with pytest.raises(TimeoutError): + await asyncio.wait_for(w, 3) + + assert w.status not in (Status.closed, Status.running, Status.paused) + await w.close() @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)]) diff --git a/distributed/utils.py b/distributed/utils.py index 0aae248588..481e1003a2 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -27,9 +27,9 @@ from hashlib import md5 from importlib.util import cache_from_source from time import sleep -from types import ModuleType +from types import CoroutineType, ModuleType from typing import Any as AnyType -from typing import ClassVar +from typing import ClassVar, TypeVar import click import tblib.pickling_support @@ -1621,3 +1621,27 @@ def is_python_shutting_down() -> bool: from distributed import _python_shutting_down return _python_shutting_down + + +T = TypeVar("T") + + +async def ensure_cancellation(coro: CoroutineType[None, None, T]) -> T: + """Ensure that the wrapped coro will raise a CancelledError even if its + result is already set. + + See https://github.com/python/cpython/issues/86296 + """ + watcher = asyncio.Event() + + task = asyncio.create_task(coro) + task.add_done_callback(lambda _: watcher.set()) + + try: + await watcher.wait() + except asyncio.CancelledError: + task.cancel() + await watcher.wait() + raise + + return task.result()
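
Note on the new helper: ensure_cancellation works around https://github.com/python/cpython/issues/86296, where cancelling a task whose wrapped result is already set can silently drop the cancellation. The following is a minimal sketch of the guarantee that test_ensure_cancellation asserts; the coroutine produce() is illustrative, only ensure_cancellation itself comes from this patch:

    import asyncio

    from distributed.utils import ensure_cancellation

    async def main():
        async def produce():
            return 1

        # Without cancellation the wrapper is transparent.
        assert await ensure_cancellation(produce()) == 1

        task = asyncio.create_task(ensure_cancellation(produce()))
        await asyncio.sleep(0)  # start the wrapper
        await asyncio.sleep(0)  # let the wrapped coroutine finish
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            print("cancellation propagated although the result was already set")

    asyncio.run(main())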
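
Note on misbehaving connectors: the WrongCancelConnector case in test_connection_pool_cancellation covers connectors that intercept the cancellation and turn it into something else; ensure_cancellation re-raises CancelledError regardless, so the cancellation cannot be masked. A rough sketch of that behaviour; stubborn_connect() is illustrative and only similar in spirit to WrongCancelConnector:

    import asyncio

    from distributed.utils import ensure_cancellation

    async def main():
        async def stubborn_connect():
            # Swallows the cancellation instead of letting it propagate.
            try:
                await asyncio.sleep(10000)
            except asyncio.CancelledError:
                return "connected anyway"

        task = asyncio.create_task(ensure_cancellation(stubborn_connect()))
        await asyncio.sleep(0)  # let the wrapper start the inner task
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            print("the caller still observes CancelledError; the bogus result is discarded")

    asyncio.run(main())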