Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Safe wait for #6235

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@
no_default,
sync,
thread_state,
wait_for,
)
from distributed.utils_comm import (
WrappedKey,
Expand Down Expand Up @@ -596,7 +597,7 @@ async def wait(self, timeout=None):
Time in seconds after which to raise a
``dask.distributed.TimeoutError``
"""
await asyncio.wait_for(self._get_event().wait(), timeout)
await wait_for(self._get_event().wait(), timeout)

def __repr__(self):
return f"<{self.__class__.__name__}: {self.status}>"
Expand Down Expand Up @@ -1258,7 +1259,7 @@ async def _ensure_connected(self, timeout=None):
)
comm.name = "Client->Scheduler"
if timeout is not None:
await asyncio.wait_for(self._update_scheduler_info(), timeout)
await wait_for(self._update_scheduler_info(), timeout)
else:
await self._update_scheduler_info()
await comm.write(
Expand All @@ -1277,7 +1278,7 @@ async def _ensure_connected(self, timeout=None):
finally:
self._connecting_to_scheduler = False
if timeout is not None:
msg = await asyncio.wait_for(comm.read(), timeout)
msg = await wait_for(comm.read(), timeout)
else:
msg = await comm.read()
assert len(msg) == 1
Expand Down Expand Up @@ -1522,7 +1523,7 @@ async def _close(self, fast=False):
and handle_report_task is not current_task
):
with suppress(asyncio.CancelledError, TimeoutError):
await asyncio.wait_for(asyncio.shield(handle_report_task), 0.1)
await wait_for(asyncio.shield(handle_report_task), 0.1)

if (
self.scheduler_comm
Expand Down Expand Up @@ -1550,7 +1551,7 @@ async def _close(self, fast=False):
and handle_report_task is not current_task
):
with suppress(TimeoutError, asyncio.CancelledError):
await asyncio.wait_for(handle_report_task, 0 if fast else 2)
await wait_for(handle_report_task, 0 if fast else 2)

with suppress(AttributeError):
await self.scheduler.close_rpc()
Expand Down Expand Up @@ -1594,7 +1595,7 @@ def close(self, timeout=no_default):
if self.asynchronous:
coro = self._close()
if timeout:
coro = asyncio.wait_for(coro, timeout)
coro = wait_for(coro, timeout)
return coro

if self._start_arg is None:
Expand Down Expand Up @@ -3354,7 +3355,7 @@ async def _restart(self, timeout=no_default):
self._send_to_scheduler({"op": "restart", "timeout": timeout})
self._restart_event = asyncio.Event()
try:
await asyncio.wait_for(self._restart_event.wait(), timeout)
await wait_for(self._restart_event.wait(), timeout)
except TimeoutError:
logger.error("Restart timed out after %.2f seconds", timeout)

Expand Down Expand Up @@ -4681,7 +4682,7 @@ async def _wait(fs, timeout=None, return_when=ALL_COMPLETED):

future = wait_for({f._state.wait() for f in fs})
if timeout is not None:
future = asyncio.wait_for(future, timeout)
future = wait_for(future, timeout)
await future

done, not_done = (
Expand Down
11 changes: 6 additions & 5 deletions distributed/comm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from distributed.metrics import time
from distributed.protocol import pickle
from distributed.protocol.compression import get_default_compression
from distributed.utils import wait_for

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -223,8 +224,8 @@ async def on_connection(self, comm: Comm, handshake_overrides=None):
# The timeout ensures that we eventually terminate connections.
# The connector side employs smaller timeouts, so we should only
# reach this one if the comm is dead anyhow.
await asyncio.wait_for(comm.write(local_info), timeout=timeout)
handshake = await asyncio.wait_for(comm.read(), timeout=timeout)
await wait_for(comm.write(local_info), timeout=timeout)
handshake = await wait_for(comm.read(), timeout=timeout)
# This would be better, but connections leak if the worker is closed quickly
# write, handshake = await asyncio.gather(comm.write(local_info), comm.read())
except Exception as e:
Expand Down Expand Up @@ -286,7 +287,7 @@ def time_left():
active_exception = None
while time_left() > 0:
try:
comm = await asyncio.wait_for(
comm = await wait_for(
connector.connect(loc, deserialize=deserialize, **connection_args),
timeout=min(intermediate_cap, time_left()),
)
Expand Down Expand Up @@ -323,8 +324,8 @@ def time_left():
try:
# This would be better, but connections leak if the worker is closed quickly
# write, handshake = await asyncio.gather(comm.write(local_info), comm.read())
handshake = await asyncio.wait_for(comm.read(), time_left())
await asyncio.wait_for(comm.write(local_info), time_left())
handshake = await wait_for(comm.read(), time_left())
await wait_for(comm.write(local_info), time_left())
except Exception as exc:
with suppress(Exception):
await comm.close()
Expand Down
27 changes: 6 additions & 21 deletions distributed/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
is_coroutine_function,
recursive_to_dict,
truncate_exception,
wait_for,
)


Expand Down Expand Up @@ -296,7 +297,7 @@ async def _():
return self
if timeout:
try:
await asyncio.wait_for(self.start(), timeout=timeout)
await wait_for(self.start(), timeout=timeout)
self.status = Status.running
except Exception:
await self.close(timeout=1)
Expand Down Expand Up @@ -1086,6 +1087,7 @@ async def _connect(self, addr, timeout=None):
deserialize=self.deserialize,
**self.connection_args,
)

comm.name = "ConnectionPool"
comm._pool = weakref.ref(self)
comm.allow_offload = self.allow_offload
Expand All @@ -1099,8 +1101,6 @@ async def _connect(self, addr, timeout=None):
raise
finally:
self._connecting_count -= 1
except asyncio.CancelledError:
raise CommClosedError("ConnectionPool closing.")
finally:
self._pending_count -= 1

Expand All @@ -1121,30 +1121,15 @@ async def connect(self, addr, timeout=None):
if self.semaphore.locked():
self.collect()

# This construction is there to ensure that cancellation requests from
# the outside can be distinguished from cancellations of our own.
# Once the CommPool closes, we'll cancel the connect_attempt which will
# raise an OSError
# If the ``connect`` is cancelled from the outside, the Event.wait will
# be cancelled instead which we'll reraise as a CancelledError and allow
# it to propagate
connect_attempt = asyncio.create_task(self._connect(addr, timeout))
done = asyncio.Event()
self._connecting.add(connect_attempt)
connect_attempt.add_done_callback(lambda _: done.set())
connect_attempt.add_done_callback(self._connecting.discard)

try:
await done.wait()
return await connect_attempt
except asyncio.CancelledError:
# This is an outside cancel attempt
connect_attempt.cancel()
try:
await connect_attempt
except CommClosedError:
pass
if self.status == Status.closed:
raise CommClosedError("ConnectionPool closed.")
raise
return await connect_attempt

def reuse(self, addr, comm):
"""
Expand Down
10 changes: 6 additions & 4 deletions distributed/deploy/tests/test_adaptive.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
)
from distributed.compatibility import LINUX, MACOS, WINDOWS
from distributed.metrics import time
from distributed.utils_test import async_wait_for, clean, gen_test, slowinc
from distributed.utils_test import async_wait_for_condition, clean, gen_test, slowinc


def test_adaptive_local_cluster(loop):
Expand Down Expand Up @@ -460,7 +460,7 @@ async def _():
await client.gather(futures)

del futures
await async_wait_for(lambda: not cluster.workers, 10)
await async_wait_for_condition(lambda: not cluster.workers, 10)


@gen_test()
Expand All @@ -475,8 +475,10 @@ async def test_adaptive_stopped():
instance = cluster.adapt(interval="10ms")
assert instance.periodic_callback is not None

await async_wait_for(lambda: instance.periodic_callback.is_running(), timeout=5)
await async_wait_for_condition(
lambda: instance.periodic_callback.is_running(), timeout=5
)

pc = instance.periodic_callback

await async_wait_for(lambda: not pc.is_running(), timeout=5)
await async_wait_for_condition(lambda: not pc.is_running(), timeout=5)
8 changes: 4 additions & 4 deletions distributed/diagnostics/tests/test_worker_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest

from distributed import Worker, WorkerPlugin
from distributed.utils_test import async_wait_for, gen_cluster, inc
from distributed.utils_test import async_wait_for_condition, gen_cluster, inc


class MyPlugin(WorkerPlugin):
Expand Down Expand Up @@ -141,7 +141,7 @@ async def test_normal_task_transitions_called(c, s, w):

await c.register_worker_plugin(plugin)
await c.submit(lambda x: x, 1, key="task")
await async_wait_for(lambda: not w.tasks, timeout=10)
await async_wait_for_condition(lambda: not w.tasks, timeout=10)


@gen_cluster(nthreads=[("127.0.0.1", 1)], client=True)
Expand Down Expand Up @@ -183,7 +183,7 @@ async def test_superseding_task_transitions_called(c, s, w):

await c.register_worker_plugin(plugin)
await c.submit(lambda x: x, 1, key="task", resources={"X": 1})
await async_wait_for(lambda: not w.tasks, timeout=10)
await async_wait_for_condition(lambda: not w.tasks, timeout=10)


@gen_cluster(nthreads=[("127.0.0.1", 1)], client=True)
Expand All @@ -209,7 +209,7 @@ async def test_dependent_tasks(c, s, w):

await c.register_worker_plugin(plugin)
await c.get(dsk, "task", sync=False)
await async_wait_for(lambda: not w.tasks, timeout=10)
await async_wait_for_condition(lambda: not w.tasks, timeout=10)


@gen_cluster(nthreads=[("127.0.0.1", 1)], client=True)
Expand Down
4 changes: 2 additions & 2 deletions distributed/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from dask.utils import parse_timedelta

from distributed.client import Client
from distributed.utils import TimeoutError, log_errors
from distributed.utils import TimeoutError, log_errors, wait_for
from distributed.worker import get_worker

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -69,7 +69,7 @@ async def event_wait(self, name=None, timeout=None):
event = self._events[name]
future = event.wait()
if timeout is not None:
future = asyncio.wait_for(future, timeout)
future = wait_for(future, timeout)

self._waiter_count[name] += 1
try:
Expand Down
4 changes: 2 additions & 2 deletions distributed/lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dask.utils import parse_timedelta

from distributed.client import Client
from distributed.utils import TimeoutError, log_errors
from distributed.utils import TimeoutError, log_errors, wait_for
from distributed.worker import get_worker

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -42,7 +42,7 @@ async def acquire(self, name=None, id=None, timeout=None):
self.events[name].append(event)
future = event.wait()
if timeout is not None:
future = asyncio.wait_for(future, timeout)
future = wait_for(future, timeout)
try:
await future
except TimeoutError:
Expand Down
4 changes: 2 additions & 2 deletions distributed/multi_lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from dask.utils import parse_timedelta

from distributed.client import Client
from distributed.utils import TimeoutError, log_errors
from distributed.utils import TimeoutError, log_errors, wait_for
from distributed.worker import get_worker

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -118,7 +118,7 @@ async def acquire(self, locks=None, id=None, timeout=None, num_locks=None):
self.events[id] = event
future = event.wait()
if timeout is not None:
future = asyncio.wait_for(future, timeout)
future = wait_for(future, timeout)
try:
await future
except TimeoutError:
Expand Down
9 changes: 4 additions & 5 deletions distributed/nanny.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
mp_context,
parse_ports,
silence_logging,
wait_for,
)
from distributed.worker import Worker, run
from distributed.worker_memory import (
Expand Down Expand Up @@ -288,7 +289,7 @@ async def _unregister(self, timeout=10):

allowed_errors = (TimeoutError, CommClosedError, EnvironmentError, RPCClosed)
with suppress(allowed_errors):
await asyncio.wait_for(
await wait_for(
self.scheduler.unregister(address=self.worker_address), timeout
)

Expand Down Expand Up @@ -407,9 +408,7 @@ async def instantiate(self) -> Status:

if self.death_timeout:
try:
result = await asyncio.wait_for(
self.process.start(), self.death_timeout
)
result = await wait_for(self.process.start(), self.death_timeout)
except TimeoutError:
logger.error(
"Timed out connecting Nanny '%s' to scheduler '%s'",
Expand Down Expand Up @@ -476,7 +475,7 @@ async def _():
await self.instantiate()

try:
await asyncio.wait_for(_(), timeout)
await wait_for(_(), timeout)
except TimeoutError:
logger.error(
f"Restart timed out after {timeout}s; returning before finished"
Expand Down
4 changes: 2 additions & 2 deletions distributed/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import dask

from distributed.utils import TimeoutError, mp_context
from distributed.utils import TimeoutError, mp_context, wait_for

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -267,7 +267,7 @@ async def join(self, timeout=None):
try:
# Shield otherwise the timeout cancels the future and our
# on_exit callback will try to set a result on a canceled future
await asyncio.wait_for(asyncio.shield(self._exit_future), timeout)
await wait_for(asyncio.shield(self._exit_future), timeout)
except TimeoutError:
pass

Expand Down
4 changes: 2 additions & 2 deletions distributed/pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from distributed.core import CommClosedError
from distributed.metrics import time
from distributed.protocol.serialize import to_serialize
from distributed.utils import TimeoutError, sync
from distributed.utils import TimeoutError, sync, wait_for

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -417,7 +417,7 @@ async def _():
await self.condition.wait()

try:
await asyncio.wait_for(_(), timeout2)
await wait_for(_(), timeout2)
finally:
self.condition.release()

Expand Down
Loading