diff --git a/distributed/core.py b/distributed/core.py index accc976f72..6c86db301b 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -1006,10 +1006,7 @@ async def handle_stream(self, comm, extra=None): break handler = self.stream_handlers[op] if iscoroutinefunction(handler): - self._ongoing_background_tasks.call_soon( - handler, **merge(extra, msg) - ) - await asyncio.sleep(0) + await handler(**merge(extra, msg)) else: handler(**merge(extra, msg)) else: @@ -1521,6 +1518,14 @@ async def _() -> Self: return _().__await__() + async def __aenter__(self): + await self + return self + + async def __aexit__(self, *args): + await self.close() + return + async def start(self) -> None: # Invariant: semaphore._value == limit - open - _n_connecting self.semaphore = asyncio.Semaphore(self.limit) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 57cb266674..0f96df074f 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -5650,12 +5650,12 @@ def handle_worker_status_change( self.idle_task_count.discard(ws) self.saturated.discard(ws) - async def handle_request_refresh_who_has( + def handle_request_refresh_who_has( self, keys: Iterable[str], worker: str, stimulus_id: str ) -> None: - """Asynchronous request (through bulk comms) from a Worker to refresh the - who_has for some keys. Not to be confused with scheduler.who_has, which is a - synchronous RPC request from a Client. + """Request from a Worker to refresh the + who_has for some keys. Not to be confused with scheduler.who_has, which + is a dedicated comm RPC request from a Client. """ who_has = {} free_keys = [] diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index 193c1554f0..5eb05dcb1f 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -3,6 +3,7 @@ import asyncio import contextlib import os +import random import socket import sys import threading @@ -15,6 +16,7 @@ import dask +from distributed.batched import BatchedSend from distributed.comm.core import CommClosedError from distributed.comm.registry import backends from distributed.comm.tcp import TCPBackend, TCPListener @@ -1380,3 +1382,78 @@ async def test_async_listener_stop(monkeypatch): async with Server({}) as s: await s.listen(0) assert s.listeners + + +@gen_test() +async def test_messages_are_ordered_bsend(): + ledger = [] + + async def async_handler(val): + await asyncio.sleep(0.01 * random.random()) + ledger.append(val) + + def sync_handler(val): + ledger.append(val) + + async with Server( + {}, + stream_handlers={ + "sync_handler": sync_handler, + "async_handler": async_handler, + }, + ) as s: + await s.listen() + comm = await connect(s.address) + try: + b = BatchedSend(interval=10) + try: + await comm.write({"op": "connection_stream"}) + b.start(comm) + n = 100 + for ix in range(n): + if ix % 2: + b.send({"op": "sync_handler", "val": ix}) + else: + b.send({"op": "async_handler", "val": ix}) + while not len(ledger) == n: + await asyncio.sleep(0.01) + assert ledger == list(range(n)) + finally: + await b.close() + finally: + await comm.close() + + +@gen_test() +async def test_messages_are_ordered_raw(): + ledger = [] + + async def async_handler(val): + await asyncio.sleep(0.01 * random.random()) + ledger.append(val) + + def sync_handler(val): + ledger.append(val) + + async with Server( + {}, + stream_handlers={ + "sync_handler": sync_handler, + "async_handler": async_handler, + }, + ) as s: + await s.listen() + comm = await connect(s.address) + try: + await comm.write({"op": "connection_stream"}) + n = 100 + for ix in range(n): + if ix % 2: + await comm.write({"op": "sync_handler", "val": ix}) + else: + await comm.write({"op": "async_handler", "val": ix}) + while not len(ledger) == n: + await asyncio.sleep(0.01) + assert ledger == list(range(n)) + finally: + await comm.close()