From f79985a32e226c1948b28a0664ec2f959e39d663 Mon Sep 17 00:00:00 2001
From: Gabe Joseph
Date: Thu, 18 Nov 2021 01:51:13 -0700
Subject: [PATCH] Light docs

---
 distributed/scheduler.py                 |  2 +-
 distributed/shuffle/shuffle_scheduler.py | 12 ++++-
 distributed/shuffle/shuffle_worker.py    | 58 +++++++++++++++++++++---
 3 files changed, 62 insertions(+), 10 deletions(-)

diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index 134a486cca..aa2d80775a 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -188,7 +188,7 @@ def nogil(func):
     ActiveMemoryManagerExtension,
     MemorySamplerExtension,
 ]
-DEFAULT_PLUGINS: tuple[SchedulerPlugin, ...] = (
+DEFAULT_PLUGINS: "tuple[SchedulerPlugin, ...]" = (
     (shuffle.ShuffleSchedulerPlugin(),) if shuffle.SHUFFLE_AVAILABLE else ()
 )
 # ^ TODO this assumes one Scheduler per process; probably a bad idea.
diff --git a/distributed/shuffle/shuffle_scheduler.py b/distributed/shuffle/shuffle_scheduler.py
index 774bf03d3c..2b7538d6dd 100644
--- a/distributed/shuffle/shuffle_scheduler.py
+++ b/distributed/shuffle/shuffle_scheduler.py
@@ -38,6 +38,7 @@ async def start(self, scheduler: Scheduler) -> None:
         self.scheduler = scheduler

     def transfer(self, id: ShuffleId, key: str) -> None:
+        "Handle a `transfer` task for a shuffle being scheduled"
         state = self.shuffles.get(id, None)
         if state:
             assert (
@@ -69,6 +70,7 @@ def transfer(self, id: ShuffleId, key: str) -> None:
         )

     def barrier(self, id: ShuffleId, key: str) -> None:
+        "Handle a `barrier` task for a shuffle being scheduled"
         state = self.shuffles[id]
         assert (
             not state.barrier_reached
@@ -108,6 +110,7 @@ def barrier(self, id: ShuffleId, key: str) -> None:
             self.output_keys[dts.key] = id

     def unpack(self, id: ShuffleId, key: str) -> None:
+        "Handle an `unpack` task for a shuffle completing"
         # Check if all output keys are done

         # NOTE: we don't actually need this `unpack` step or tracking output keys;
@@ -137,6 +140,7 @@ def unpack(self, id: ShuffleId, key: str) -> None:
             del self.shuffles[id]

     def erred(self, id: ShuffleId, key: str) -> None:
+        "Handle any task for a shuffle erroring"
         try:
             state = self.shuffles.pop(id)
         except KeyError:
@@ -151,6 +155,7 @@ def erred(self, id: ShuffleId, key: str) -> None:
                 del self.output_keys[k]

     def transition(self, key: str, start: str, finish: str, *args, **kwargs):
+        "Watch transitions for keys we care about"
         parts = parse_key(key)
         if parts and len(parts) == 3:
             prefix, group, id = parts
@@ -184,8 +189,10 @@ def transition(self, key: str, start: str, finish: str, *args, **kwargs):

     def worker_for_key(self, key: str, npartitions: int, workers: list[str]) -> str:
         "Worker address this task should be assigned to"
-        # Infer which output partition number this task is fetching by parsing its key
-        # FIXME this is so brittle.
+        # Infer which output partition number this task is fetching by parsing its key.
+        # We have to parse keys, instead of generating the list of expected keys, because
+        # blockwise fusion means they won't just be `shuffle-unpack-abcde`.
+        # FIXME this feels very hacky/brittle.
         # For example, after `df.set_index(...).to_delayed()`, you could create
         # keys that don't have indices in them, and get fused (because they should!).
         m = re.match(r"\(.+, (\d+)\)$", key)
@@ -202,6 +209,7 @@ def worker_for_key(self, key: str, npartitions: int, workers: list[str]) -> str:


 def parse_key(key: str) -> list[str] | None:
+    "Split a shuffle key into its prefix, group, and shuffle ID, or None if not a shuffle key."
     if TASK_PREFIX in key[: len(TASK_PREFIX) + 2]:
         if key[0] == "(":
             key = key_split_group(key)
diff --git a/distributed/shuffle/shuffle_worker.py b/distributed/shuffle/shuffle_worker.py
index 20df2ed041..ab35c4b7f5 100644
--- a/distributed/shuffle/shuffle_worker.py
+++ b/distributed/shuffle/shuffle_worker.py
@@ -49,6 +49,11 @@ def __init__(self, worker: Worker) -> None:
     ##########

     def shuffle_init(self, id: ShuffleId, workers: list[str], n_out_tasks: int) -> None:
+        """
+        Handler: initialize a shuffle. Called by the scheduler on all workers.
+
+        Must be called exactly once per ID.
+        """
         if id in self.shuffles:
             raise ValueError(
                 f"Shuffle {id!r} is already registered on worker {self.worker.address}"
@@ -71,6 +76,12 @@ def shuffle_receive(
         output_partition: int,
         data: pd.DataFrame,
     ) -> None:
+        """
+        Handler: receive data from a peer.
+
+        The shuffle ID may be unknown (`shuffle_init` may not have run here yet).
+        Calling after the barrier task is an error.
+        """
         try:
             state = self.shuffles[id]
         except KeyError:
@@ -89,6 +100,13 @@ def shuffle_receive(
         self.output_data[id][output_partition].append(data)

     async def shuffle_inputs_done(self, comm: object, id: ShuffleId) -> None:
+        """
+        Handler: note that the barrier task has been reached. Called by the peer running the `barrier` task.
+
+        The shuffle will be removed if this worker holds no output partitions for it.
+
+        Must be called exactly once per ID. Blocks until `shuffle_init` has been called.
+        """
         state = await self.get_shuffle(id)
         assert not state.barrier_reached, f"`inputs_done` called again for {id}"
         state.barrier_reached = True
@@ -105,6 +123,15 @@ async def shuffle_inputs_done(self, comm: object, id: ShuffleId) -> None:
     async def add_partition(
         self, data: pd.DataFrame, id: ShuffleId, npartitions: int, column: str
     ) -> None:
+        """
+        Task: hand off an input partition to the extension.
+
+        This will block until the extension is ready to receive another input partition.
+        Also blocks until `shuffle_init` has been called.
+
+        Using an unknown ``shuffle_id`` is an error.
+        Calling after the barrier task is an error.
+        """
         # Block until scheduler has called init
         state = await self.get_shuffle(id)
         assert not state.barrier_reached, f"`add_partition` for {id} after barrier"
@@ -121,9 +148,14 @@ async def add_partition(
         await self.send_partition(data, column, id, npartitions, state.workers)

     async def barrier(self, id: ShuffleId) -> None:
-        # NOTE: requires workers list. This is guaranteed because it depends on `add_partition`,
-        # which got the workers list from the scheduler. So this task must run on a worker where
-        # `add_partition` has already run.
+        """
+        Task: mark that the barrier has been reached (`add_partition` has been called for all input partitions).
+
+        Using an unknown ``shuffle_id`` is an error.
+        Must be called exactly once per ID.
+        Blocks until `shuffle_init` has been called (on all workers).
+        Calling this before all partitions have been added will cause `add_partition` to fail.
+        """
         state = await self.get_shuffle(id)
         assert not state.barrier_reached, f"`barrier` for {id} called multiple times"

@@ -140,13 +172,20 @@ async def barrier(self, id: ShuffleId) -> None:
     async def get_output_partition(
         self, id: ShuffleId, i: int, empty: pd.DataFrame
     ) -> pd.DataFrame:
-        state = self.shuffles[id]
-        # ^ Don't need to `get_shuffle`; `shuffle_inputs_done` has run already and guarantees it's there
+        """
+        Task: retrieve a shuffled output partition from the extension.
+
+        After calling on the final output partition remaining on this worker, the shuffle will be cleaned up.
+
+        Using an unknown ``shuffle_id`` is an error.
+        Requesting a partition which doesn't belong on this worker, or has already been retrieved, is an error.
+        """
+        state = await self.get_shuffle(id)  # should never have to wait
         assert state.barrier_reached, f"`get_output_partition` for {id} before barrier"
         assert (
             state.out_parts_left > 0
         ), f"No outputs remaining, but requested output partition {i} on {self.worker.address} for {id}."
-        # ^ Note: this is impossible with our cleanup-on-empty
+        # ^ Note: impossible with our cleanup-on-empty

         worker = worker_for(i, state.npartitions, state.workers)
         assert worker == self.worker.address, (
@@ -172,6 +211,7 @@ async def get_output_partition(
     #########

     def remove(self, id: ShuffleId) -> None:
+        "Remove state for this shuffle. The shuffle must be complete and in a valid state."
         state = self.shuffles.pop(id)
         assert state.barrier_reached, f"Removed {id} before barrier"
         assert (
@@ -187,7 +227,8 @@ def remove(self, id: ShuffleId) -> None:
             not data
         ), f"Removed {id}, which still has data for output partitions {list(data)}"

-    async def get_shuffle(self, id: ShuffleId):
+    async def get_shuffle(self, id: ShuffleId) -> ShuffleState:
+        "Get the `ShuffleState`, blocking until it's been received from the scheduler."
         try:
             return self.shuffles[id]
         except KeyError:
@@ -212,6 +253,7 @@ async def send_partition(
         self,
         data: pd.DataFrame,
         column: str,
         id: ShuffleId,
         npartitions: int,
         workers: list[str],
     ) -> None:
+        "Split up an input partition and send its parts to peers."
         tasks = []
         # TODO grouping is blocking, should it be offloaded to a thread?
         # It mostly doesn't release the GIL though, so may not make much difference.
@@ -231,9 +273,11 @@

     @property
     def loop(self) -> asyncio.AbstractEventLoop:
+        "The asyncio event loop for the worker"
         return self.worker.loop.asyncio_loop  # type: ignore

     def sync(self, coro: Coroutine[object, object, T]) -> T:
+        "Run an async function on the worker's event loop, synchronously from another thread."
         # Is it a bad idea not to use `distributed.utils.sync`?
         # It's much nicer to use asyncio, because among other things it gives us typechecking.
         return asyncio.run_coroutine_threadsafe(coro, self.loop).result()
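
A note for reviewers on the key parsing in `worker_for_key`: all it has to recover from a (possibly blockwise-fused) tuple-style key is the trailing output-partition number. A minimal standalone sketch of that regex, with made-up example keys:

import re

# Same pattern as in `worker_for_key`: capture the output-partition index
# at the end of a tuple-style key. The keys below are hypothetical examples.
pattern = r"\(.+, (\d+)\)$"

for key in ["('shuffle-unpack-abcde', 3)", "('fused-op-shuffle-unpack-abcde', 41)"]:
    m = re.match(pattern, key)
    assert m is not None, f"not a tuple-style key: {key}"
    print(int(m.group(1)))  # -> 3, then 41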
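
`get_output_partition` and `worker_for_key` both lean on `worker_for(i, npartitions, workers)`, whose definition is outside this diff. If it follows a contiguous-block assignment (an assumption here, not something this patch shows), it would look roughly like:

def worker_for(output_partition: int, npartitions: int, workers: list[str]) -> str:
    # Assumed sketch, not the definition from this branch: give each worker
    # an equal, contiguous slice of the output partitions.
    i = len(workers) * output_partition // npartitions
    return workers[i]

# Made-up addresses: 6 output partitions over 3 workers -> 2 partitions each.
workers = ["tcp://a:8786", "tcp://b:8786", "tcp://c:8786"]
assert [worker_for(p, 6, workers) for p in range(6)] == [
    "tcp://a:8786", "tcp://a:8786",
    "tcp://b:8786", "tcp://b:8786",
    "tcp://c:8786", "tcp://c:8786",
]

Whatever the exact mapping, it has to be a pure function of (partition, npartitions, workers), since the scheduler plugin and every worker must compute identical assignments independently.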
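
And on `sync`: `asyncio.run_coroutine_threadsafe(coro, loop).result()` is the standard way to run a coroutine on an event loop owned by another thread and block for its result, which is what task code running in the worker's thread pool needs here. A self-contained demo of the same pattern:

import asyncio
import threading

loop = asyncio.new_event_loop()
threading.Thread(target=loop.run_forever, daemon=True).start()

async def add(a: int, b: int) -> int:
    return a + b

# Submit the coroutine to the loop in the other thread; block on the
# concurrent.futures.Future it returns -- the same shape as `sync`.
assert asyncio.run_coroutine_threadsafe(add(1, 2), loop).result() == 3
loop.call_soon_threadsafe(loop.stop)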