diff --git a/distributed/worker.py b/distributed/worker.py index 4ba5d55133..e47d7bbf43 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -850,9 +850,7 @@ def __init__( pc = PeriodicCallback(self.heartbeat, self.heartbeat_interval * 1000) self.periodic_callbacks["heartbeat"] = pc - pc = PeriodicCallback( - lambda: self.batched_stream.send({"op": "keep-alive"}), 60000 - ) + pc = PeriodicCallback(lambda: self.batched_send({"op": "keep-alive"}), 60000) self.periodic_callbacks["keep-alive"] = pc pc = PeriodicCallback(self.find_missing, 1000) @@ -957,22 +955,15 @@ def logs(self): return self._deque_handler.deque def log_event(self, topic, msg): - if ( - not self.batched_stream - or not self.batched_stream.comm - or self.batched_stream.comm.closed() - ): - return # pragma: nocover - full_msg = { "op": "log-event", "topic": topic, "msg": msg, } if self.thread_id == threading.get_ident(): - self.batched_stream.send(full_msg) + self.batched_send(full_msg) else: - self.loop.add_callback(self.batched_stream.send, full_msg) + self.loop.add_callback(self.batched_send, full_msg) @property def executing_count(self) -> int: @@ -1004,20 +995,13 @@ def status(self, value): self.handle_stimulus(UnpauseEvent(stimulus_id=stimulus_id)) def _send_worker_status_change(self, stimulus_id: str) -> None: - if ( - self.batched_stream - and self.batched_stream.comm - and not self.batched_stream.comm.closed() - ): - self.batched_stream.send( - { - "op": "worker-status-change", - "status": self._status.name, - "stimulus_id": stimulus_id, - }, - ) - elif self._status != Status.closed: - self.loop.call_later(0.05, self._send_worker_status_change, stimulus_id) + self.batched_send( + { + "op": "worker-status-change", + "status": self._status.name, + "stimulus_id": stimulus_id, + }, + ) async def get_metrics(self) -> dict: try: @@ -1120,6 +1104,25 @@ def _to_dict(self, *, exclude: Container[str] = ()) -> dict: # External Services # ##################### + def batched_send(self, msg: 
dict[str, Any]) -> None: + """Send a fire-and-forget message to the scheduler through bulk comms. + + If we're not currently connected to the scheduler, the message will be silently + dropped! + + Parameters + ---------- + msg: dict + msgpack-serializable message to send to the scheduler. + Must have an 'op' key which is registered in Scheduler.stream_handlers. + """ + if ( + self.batched_stream + and self.batched_stream.comm + and not self.batched_stream.comm.closed() + ): + self.batched_stream.send(msg) + async def _register_with_scheduler(self): self.periodic_callbacks["keep-alive"].stop() self.periodic_callbacks["heartbeat"].stop() @@ -1571,12 +1574,7 @@ async def close( if self._protocol == "ucx": # pragma: no cover await asyncio.sleep(0.2) - if ( - self.batched_stream - and self.batched_stream.comm - and not self.batched_stream.comm.closed() - ): - self.batched_stream.send({"op": "close-stream"}) + self.batched_send({"op": "close-stream"}) if self.batched_stream: with suppress(TimeoutError): @@ -2879,13 +2877,7 @@ def transitions(self, recommendations: Recs, *, stimulus_id: str) -> None: for ts in tasks: self.validate_task(ts) - if self.batched_stream.closed(): - logger.debug( - "BatchedSend closed while transitioning tasks. 
%d tasks not sent.", - len(instructions), - ) - else: - self._handle_instructions(instructions) + self._handle_instructions(instructions) @fail_hard @log_errors @@ -2918,7 +2910,7 @@ def _handle_instructions(self, instructions: Instructions) -> None: task: asyncio.Task | None = None if isinstance(inst, SendMessageToScheduler): - self.batched_stream.send(inst.to_dict()) + self.batched_send(inst.to_dict()) elif isinstance(inst, EnsureCommunicatingAfterTransitions): # A single compute-task or acquire-replicas command may cause @@ -3496,7 +3488,7 @@ def handle_steal_request(self, key: str, stimulus_id: str) -> None: "state": state, "stimulus_id": stimulus_id, } - self.batched_stream.send(response) + self.batched_send(response) if state in READY | {"waiting"}: assert ts