Encapsulate Worker.batched_stream.send() (#6475)
crusaderky authored Jun 2, 2022
1 parent 7e49d88 commit 69b798d
Showing 1 changed file with 33 additions and 41 deletions.
distributed/worker.py: 74 changes (33 additions & 41 deletions)
@@ -850,9 +850,7 @@ def __init__(
         pc = PeriodicCallback(self.heartbeat, self.heartbeat_interval * 1000)
         self.periodic_callbacks["heartbeat"] = pc

-        pc = PeriodicCallback(
-            lambda: self.batched_stream.send({"op": "keep-alive"}), 60000
-        )
+        pc = PeriodicCallback(lambda: self.batched_send({"op": "keep-alive"}), 60000)
         self.periodic_callbacks["keep-alive"] = pc

         pc = PeriodicCallback(self.find_missing, 1000)
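
For reference, tornado's PeriodicCallback (which distributed uses here) takes a zero-argument callable plus an interval in milliseconds and fires it on the running IOLoop. A minimal sketch of the keep-alive pattern above; batched_send is a hypothetical stand-in for the Worker method this commit introduces:

    from tornado.ioloop import IOLoop, PeriodicCallback

    def batched_send(msg: dict) -> None:
        # Hypothetical stand-in for Worker.batched_send, for illustration.
        print(f"would enqueue {msg} on the batched stream")

    # The interval is in milliseconds: fire a keep-alive once a minute.
    pc = PeriodicCallback(lambda: batched_send({"op": "keep-alive"}), 60_000)
    pc.start()
    # IOLoop.current().start()  # callbacks only fire while the loop runs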
@@ -957,22 +955,15 @@ def logs(self):
         return self._deque_handler.deque

     def log_event(self, topic, msg):
-        if (
-            not self.batched_stream
-            or not self.batched_stream.comm
-            or self.batched_stream.comm.closed()
-        ):
-            return  # pragma: nocover
-
         full_msg = {
             "op": "log-event",
             "topic": topic,
             "msg": msg,
         }
         if self.thread_id == threading.get_ident():
-            self.batched_stream.send(full_msg)
+            self.batched_send(full_msg)
         else:
-            self.loop.add_callback(self.batched_stream.send, full_msg)
+            self.loop.add_callback(self.batched_send, full_msg)

     @property
     def executing_count(self) -> int:
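
log_event keeps its thread-dispatch logic: when the caller is already on the event-loop thread it sends inline, otherwise it hands the call to the loop via add_callback, tornado's thread-safe entry point. A standalone sketch of that pattern, with illustrative names rather than the Worker's real attributes:

    import threading

    from tornado.ioloop import IOLoop

    class Sender:
        """Illustrative only; mirrors the dispatch shape of log_event."""

        def __init__(self, loop: IOLoop):
            self.loop = loop
            self.thread_id = threading.get_ident()  # thread running the loop

        def send_threadsafe(self, msg: dict) -> None:
            if self.thread_id == threading.get_ident():
                self._send(msg)  # already on the loop thread: send inline
            else:
                # add_callback is safe to call from any thread
                self.loop.add_callback(self._send, msg)

        def _send(self, msg: dict) -> None:
            print("sending", msg)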
@@ -1004,20 +995,13 @@ def status(self, value):
             self.handle_stimulus(UnpauseEvent(stimulus_id=stimulus_id))

     def _send_worker_status_change(self, stimulus_id: str) -> None:
-        if (
-            self.batched_stream
-            and self.batched_stream.comm
-            and not self.batched_stream.comm.closed()
-        ):
-            self.batched_stream.send(
-                {
-                    "op": "worker-status-change",
-                    "status": self._status.name,
-                    "stimulus_id": stimulus_id,
-                },
-            )
-        elif self._status != Status.closed:
-            self.loop.call_later(0.05, self._send_worker_status_change, stimulus_id)
+        self.batched_send(
+            {
+                "op": "worker-status-change",
+                "status": self._status.name,
+                "stimulus_id": stimulus_id,
+            },
+        )

     async def get_metrics(self) -> dict:
         try:
@@ -1120,6 +1104,25 @@ def _to_dict(self, *, exclude: Container[str] = ()) -> dict:
     # External Services #
     #####################

+    def batched_send(self, msg: dict[str, Any]) -> None:
+        """Send a fire-and-forget message to the scheduler through bulk comms.
+
+        If we're not currently connected to the scheduler, the message will be
+        silently dropped!
+
+        Parameters
+        ----------
+        msg: dict
+            msgpack-serializable message to send to the scheduler.
+            Must have an 'op' key which is registered in Scheduler.stream_handlers.
+        """
+        if (
+            self.batched_stream
+            and self.batched_stream.comm
+            and not self.batched_stream.comm.closed()
+        ):
+            self.batched_stream.send(msg)
+
     async def _register_with_scheduler(self):
         self.periodic_callbacks["keep-alive"].stop()
         self.periodic_callbacks["heartbeat"].stop()
@@ -1571,12 +1574,7 @@ async def close(
         if self._protocol == "ucx":  # pragma: no cover
             await asyncio.sleep(0.2)

-        if (
-            self.batched_stream
-            and self.batched_stream.comm
-            and not self.batched_stream.comm.closed()
-        ):
-            self.batched_stream.send({"op": "close-stream"})
+        self.batched_send({"op": "close-stream"})

         if self.batched_stream:
             with suppress(TimeoutError):
@@ -2879,13 +2877,7 @@ def transitions(self, recommendations: Recs, *, stimulus_id: str) -> None:
             for ts in tasks:
                 self.validate_task(ts)

-        if self.batched_stream.closed():
-            logger.debug(
-                "BatchedSend closed while transitioning tasks. %d tasks not sent.",
-                len(instructions),
-            )
-        else:
-            self._handle_instructions(instructions)
+        self._handle_instructions(instructions)

     @fail_hard
     @log_errors
@@ -2918,7 +2910,7 @@ def _handle_instructions(self, instructions: Instructions) -> None:
             task: asyncio.Task | None = None

             if isinstance(inst, SendMessageToScheduler):
-                self.batched_stream.send(inst.to_dict())
+                self.batched_send(inst.to_dict())

             elif isinstance(inst, EnsureCommunicatingAfterTransitions):
                 # A single compute-task or acquire-replicas command may cause
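
_handle_instructions dispatches on instruction type, and SendMessageToScheduler instructions now go through the guarded helper after serialization with to_dict(). A loose sketch of that dispatch shape (the dataclass body is illustrative, not distributed's real definition):

    from dataclasses import dataclass

    @dataclass
    class SendMessageToScheduler:  # illustrative stand-in
        op: str

        def to_dict(self) -> dict:
            return {"op": self.op}

    def handle_instructions(worker, instructions: list) -> None:
        for inst in instructions:
            if isinstance(inst, SendMessageToScheduler):
                worker.batched_send(inst.to_dict())  # guarded, fire-and-forget
            else:
                pass  # other instruction kinds spawn asyncio tasks, etc.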
@@ -3496,7 +3488,7 @@ def handle_steal_request(self, key: str, stimulus_id: str) -> None:
             "state": state,
             "stimulus_id": stimulus_id,
         }
-        self.batched_stream.send(response)
+        self.batched_send(response)

         if state in READY | {"waiting"}:
             assert ts
