From 9315e4c78a3722c936d5db0fc3918d63f12a6ce4 Mon Sep 17 00:00:00 2001
From: Gabe Joseph
Date: Fri, 19 Nov 2021 15:34:29 -0700
Subject: [PATCH] Remove waiting for shuffle_init

See https://github.com/dask/distributed/pull/5524#discussion_r752647896.

Since messages from scheduler to workers remain ordered in `BatchedSend`
(and TCP preserves ordering), we should be able to count on the
`shuffle_init` always hitting the worker before the `add_partition` does,
so long as we trust the transition logic of our plugin.
---
 distributed/shuffle/graph.py            | 13 ++++----
 distributed/shuffle/shuffle_worker.py   | 42 +++++++------------------
 distributed/shuffle/tests/test_graph.py |  1 -
 3 files changed, 17 insertions(+), 39 deletions(-)

diff --git a/distributed/shuffle/graph.py b/distributed/shuffle/graph.py
index db7ab7efde..69a797fe9d 100644
--- a/distributed/shuffle/graph.py
+++ b/distributed/shuffle/graph.py
@@ -41,18 +41,17 @@ def shuffle_transfer(
     ext.sync(ext.add_partition(data, id, npartitions, column))
 
 
-def shuffle_unpack(
-    id: ShuffleId, i: int, empty: pd.DataFrame, barrier=None
-) -> pd.DataFrame:
-    ext = get_shuffle_extension()
-    return ext.sync(ext.get_output_partition(id, i, empty))
-
-
 def shuffle_barrier(id: ShuffleId, transfers: list[None]) -> None:
     ext = get_shuffle_extension()
     ext.sync(ext.barrier(id))
 
 
+def shuffle_unpack(
+    id: ShuffleId, i: int, empty: pd.DataFrame, barrier=None
+) -> pd.DataFrame:
+    return get_shuffle_extension().get_output_partition(id, i, empty)
+
+
 def rearrange_by_column_p2p(
     df: DataFrame,
     column: str,
diff --git a/distributed/shuffle/shuffle_worker.py b/distributed/shuffle/shuffle_worker.py
index 9ded0a294e..d93de519e3 100644
--- a/distributed/shuffle/shuffle_worker.py
+++ b/distributed/shuffle/shuffle_worker.py
@@ -29,7 +29,6 @@ class ShuffleWorkerExtension:
     "Extend the Worker with routes and state for peer-to-peer shuffles"
     worker: Worker
     shuffles: dict[ShuffleId, ShuffleState]
-    waiting_for_metadata: dict[ShuffleId, asyncio.Event]
     output_data: defaultdict[ShuffleId, defaultdict[int, list[pd.DataFrame]]]
 
     def __init__(self, worker: Worker) -> None:
@@ -42,7 +41,6 @@ def __init__(self, worker: Worker) -> None:
         # Initialize
         self.worker: Worker = worker
         self.shuffles = {}
-        self.waiting_for_metadata = {}
         self.output_data = defaultdict(lambda: defaultdict(list))
 
         # Handlers
@@ -63,11 +61,6 @@ def shuffle_init(self, id: ShuffleId, workers: list[str], n_out_tasks: int) -> None:
             n_out_tasks,
             npartitions_for(self.worker.address, n_out_tasks, workers),
         )
-        try:
-            # Invariant: if `waiting_for_metadata` event is set, key is already in `shuffles`
-            self.waiting_for_metadata[id].set()
-        except KeyError:
-            pass
 
     def shuffle_receive(
         self,
@@ -99,7 +92,7 @@ def shuffle_receive(
 
         self.output_data[id][output_partition].append(data)
 
-    async def shuffle_inputs_done(self, comm: object, id: ShuffleId) -> None:
+    def shuffle_inputs_done(self, comm: object, id: ShuffleId) -> None:
         """
         Handler: note that the barrier task has been reached. Called by a peer.
 
@@ -107,7 +100,7 @@ async def shuffle_inputs_done(self, comm: object, id: ShuffleId) -> None:
 
         Must be called exactly once per ID. Blocks until `shuffle_init` has been called.
         """
-        state = await self.get_shuffle(id)
+        state = self.get_shuffle(id)
         assert not state.barrier_reached, f"`inputs_done` called again for {id}"
         state.barrier_reached = True
 
@@ -132,7 +125,7 @@ async def add_partition(
         Calling after the barrier task is an error.
""" # Block until scheduler has called init - state = await self.get_shuffle(id) + state = self.get_shuffle(id) assert not state.barrier_reached, f"`add_partition` for {id} after barrier" if npartitions != state.npartitions: @@ -155,7 +148,7 @@ async def barrier(self, id: ShuffleId) -> None: Blocks until `shuffle_init` has been called (on all workers). Calling this before all partitions have been added will cause `add_partition` to fail. """ - state = await self.get_shuffle(id) + state = self.get_shuffle(id) assert not state.barrier_reached, f"`barrier` for {id} called multiple times" # Call `shuffle_inputs_done` on peers. @@ -168,7 +161,7 @@ async def barrier(self, id: ShuffleId) -> None: ), ) - async def get_output_partition( + def get_output_partition( self, id: ShuffleId, i: int, empty: pd.DataFrame ) -> pd.DataFrame: """ @@ -179,7 +172,7 @@ async def get_output_partition( Using an unknown ``shuffle_id`` is an error. Requesting a partition which doesn't belong on this worker, or has already been retrieved, is an error. """ - state = await self.get_shuffle(id) # should never have to wait + state = self.get_shuffle(id) assert state.barrier_reached, f"`get_output_partition` for {id} before barrier" assert ( state.out_parts_left > 0 @@ -217,32 +210,19 @@ def remove(self, id: ShuffleId) -> None: not state.out_parts_left ), f"Removed {id} with {state.out_parts_left} outputs left" - event = self.waiting_for_metadata.pop(id, None) - if event: - assert event.is_set(), f"Removed {id} while still waiting for metadata" - data = self.output_data.pop(id, None) assert ( not data ), f"Removed {id}, which still has data for output partitions {list(data)}" - async def get_shuffle(self, id: ShuffleId) -> ShuffleState: - "Get the `ShuffleState`, blocking until it's been received from the scheduler." + def get_shuffle(self, id: ShuffleId) -> ShuffleState: + "Get the `ShuffleState` by ID, raise ValueError if it's not registered." try: return self.shuffles[id] except KeyError: - event = self.waiting_for_metadata.setdefault(id, asyncio.Event()) - try: - await asyncio.wait_for(event.wait(), timeout=5) # TODO config - except TimeoutError: - raise TimeoutError( - f"Timed out waiting for scheduler to start shuffle {id}" - ) from None - # Invariant: once `waiting_for_metadata` event is set, key is already in `shuffles`. - # And once key is in `shuffles`, no `get_shuffle` will create a new event. - # So we can safely remove the event now. - self.waiting_for_metadata.pop(id, None) - return self.shuffles[id] + raise ValueError( + f"Shuffle {id!r} is not registered on worker {self.worker.address}" + ) from None async def send_partition( self, diff --git a/distributed/shuffle/tests/test_graph.py b/distributed/shuffle/tests/test_graph.py index 98e1c99f21..048ea94978 100644 --- a/distributed/shuffle/tests/test_graph.py +++ b/distributed/shuffle/tests/test_graph.py @@ -93,7 +93,6 @@ async def test_basic_state(c: Client, s: Scheduler, *workers: Worker): for ext in exts: assert not ext.shuffles assert not ext.output_data - assert not ext.waiting_for_metadata plugin = s.plugins[ShuffleSchedulerPlugin.name] assert isinstance(plugin, ShuffleSchedulerPlugin)