Skip to content

Commit

Permalink
Remove waiting for shuffle_init
Browse files Browse the repository at this point in the history
See #5524 (comment). Since messages from scheduler to workers remain ordered in `BatchedSend` (and TCP preserves ordering), we should be able to count on the `shuffle_init` always hitting the worker before the `add_partition` does, so long as we trust the transition logic of our plugin.
  • Loading branch information
gjoseph92 committed Nov 19, 2021
1 parent e3170ae commit 9315e4c
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 39 deletions.
13 changes: 6 additions & 7 deletions distributed/shuffle/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,17 @@ def shuffle_transfer(
ext.sync(ext.add_partition(data, id, npartitions, column))


def shuffle_unpack(
id: ShuffleId, i: int, empty: pd.DataFrame, barrier=None
) -> pd.DataFrame:
ext = get_shuffle_extension()
return ext.sync(ext.get_output_partition(id, i, empty))


def shuffle_barrier(id: ShuffleId, transfers: list[None]) -> None:
ext = get_shuffle_extension()
ext.sync(ext.barrier(id))


def shuffle_unpack(
id: ShuffleId, i: int, empty: pd.DataFrame, barrier=None
) -> pd.DataFrame:
return get_shuffle_extension().get_output_partition(id, i, empty)


def rearrange_by_column_p2p(
df: DataFrame,
column: str,
Expand Down
42 changes: 11 additions & 31 deletions distributed/shuffle/shuffle_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ class ShuffleWorkerExtension:
"Extend the Worker with routes and state for peer-to-peer shuffles"
worker: Worker
shuffles: dict[ShuffleId, ShuffleState]
waiting_for_metadata: dict[ShuffleId, asyncio.Event]
output_data: defaultdict[ShuffleId, defaultdict[int, list[pd.DataFrame]]]

def __init__(self, worker: Worker) -> None:
Expand All @@ -42,7 +41,6 @@ def __init__(self, worker: Worker) -> None:
# Initialize
self.worker: Worker = worker
self.shuffles = {}
self.waiting_for_metadata = {}
self.output_data = defaultdict(lambda: defaultdict(list))

# Handlers
Expand All @@ -63,11 +61,6 @@ def shuffle_init(self, id: ShuffleId, workers: list[str], n_out_tasks: int) -> N
n_out_tasks,
npartitions_for(self.worker.address, n_out_tasks, workers),
)
try:
# Invariant: if `waiting_for_metadata` event is set, key is already in `shuffles`
self.waiting_for_metadata[id].set()
except KeyError:
pass

def shuffle_receive(
self,
Expand Down Expand Up @@ -99,15 +92,15 @@ def shuffle_receive(

self.output_data[id][output_partition].append(data)

async def shuffle_inputs_done(self, comm: object, id: ShuffleId) -> None:
def shuffle_inputs_done(self, comm: object, id: ShuffleId) -> None:
"""
Handler: note that the barrier task has been reached. Called by a peer.
The shuffle will be removed if this worker holds no output partitions for it.
Must be called exactly once per ID. Blocks until `shuffle_init` has been called.
"""
state = await self.get_shuffle(id)
state = self.get_shuffle(id)
assert not state.barrier_reached, f"`inputs_done` called again for {id}"
state.barrier_reached = True

Expand All @@ -132,7 +125,7 @@ async def add_partition(
Calling after the barrier task is an error.
"""
# Block until scheduler has called init
state = await self.get_shuffle(id)
state = self.get_shuffle(id)
assert not state.barrier_reached, f"`add_partition` for {id} after barrier"

if npartitions != state.npartitions:
Expand All @@ -155,7 +148,7 @@ async def barrier(self, id: ShuffleId) -> None:
Blocks until `shuffle_init` has been called (on all workers).
Calling this before all partitions have been added will cause `add_partition` to fail.
"""
state = await self.get_shuffle(id)
state = self.get_shuffle(id)
assert not state.barrier_reached, f"`barrier` for {id} called multiple times"

# Call `shuffle_inputs_done` on peers.
Expand All @@ -168,7 +161,7 @@ async def barrier(self, id: ShuffleId) -> None:
),
)

async def get_output_partition(
def get_output_partition(
self, id: ShuffleId, i: int, empty: pd.DataFrame
) -> pd.DataFrame:
"""
Expand All @@ -179,7 +172,7 @@ async def get_output_partition(
Using an unknown ``shuffle_id`` is an error.
Requesting a partition which doesn't belong on this worker, or has already been retrieved, is an error.
"""
state = await self.get_shuffle(id) # should never have to wait
state = self.get_shuffle(id)
assert state.barrier_reached, f"`get_output_partition` for {id} before barrier"
assert (
state.out_parts_left > 0
Expand Down Expand Up @@ -217,32 +210,19 @@ def remove(self, id: ShuffleId) -> None:
not state.out_parts_left
), f"Removed {id} with {state.out_parts_left} outputs left"

event = self.waiting_for_metadata.pop(id, None)
if event:
assert event.is_set(), f"Removed {id} while still waiting for metadata"

data = self.output_data.pop(id, None)
assert (
not data
), f"Removed {id}, which still has data for output partitions {list(data)}"

async def get_shuffle(self, id: ShuffleId) -> ShuffleState:
"Get the `ShuffleState`, blocking until it's been received from the scheduler."
def get_shuffle(self, id: ShuffleId) -> ShuffleState:
"Get the `ShuffleState` by ID, raise ValueError if it's not registered."
try:
return self.shuffles[id]
except KeyError:
event = self.waiting_for_metadata.setdefault(id, asyncio.Event())
try:
await asyncio.wait_for(event.wait(), timeout=5) # TODO config
except TimeoutError:
raise TimeoutError(
f"Timed out waiting for scheduler to start shuffle {id}"
) from None
# Invariant: once `waiting_for_metadata` event is set, key is already in `shuffles`.
# And once key is in `shuffles`, no `get_shuffle` will create a new event.
# So we can safely remove the event now.
self.waiting_for_metadata.pop(id, None)
return self.shuffles[id]
raise ValueError(
f"Shuffle {id!r} is not registered on worker {self.worker.address}"
) from None

async def send_partition(
self,
Expand Down
1 change: 0 additions & 1 deletion distributed/shuffle/tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ async def test_basic_state(c: Client, s: Scheduler, *workers: Worker):
for ext in exts:
assert not ext.shuffles
assert not ext.output_data
assert not ext.waiting_for_metadata

plugin = s.plugins[ShuffleSchedulerPlugin.name]
assert isinstance(plugin, ShuffleSchedulerPlugin)
Expand Down

0 comments on commit 9315e4c

Please sign in to comment.