Commit: Light docs

gjoseph92 committed Nov 19, 2021
1 parent 9955b60 commit f79985a
Showing 3 changed files with 62 additions and 10 deletions.
distributed/scheduler.py (1 addition & 1 deletion)
@@ -188,7 +188,7 @@ def nogil(func):
ActiveMemoryManagerExtension,
MemorySamplerExtension,
]
DEFAULT_PLUGINS: tuple[SchedulerPlugin, ...] = (
DEFAULT_PLUGINS: "tuple[SchedulerPlugin, ...]" = (
(shuffle.ShuffleSchedulerPlugin(),) if shuffle.SHUFFLE_AVAILABLE else ()
)
# ^ TODO this assumes one Scheduler per process; probably a bad idea.
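(Aside, not part of the diff: quoting this annotation looks like a runtime-compatibility fix rather than documentation. A module-level annotation is evaluated at import time, and the builtin `tuple` is only subscriptable on Python 3.9+ (or with `from __future__ import annotations`), so the unquoted form would raise at import on 3.8; as a string it stays un-evaluated at runtime while type checkers still parse it. A hypothetical illustration:)

# On Python 3.8 this module-level annotation fails at import time:
#     COUNTS: tuple[int, ...] = ()      # TypeError: 'type' object is not subscriptable
# while the quoted form is just a string at runtime:
#     COUNTS: "tuple[int, ...]" = ()    # fine; type checkers still read it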
distributed/shuffle/shuffle_scheduler.py (10 additions & 2 deletions)
@@ -38,6 +38,7 @@ async def start(self, scheduler: Scheduler) -> None:
self.scheduler = scheduler

def transfer(self, id: ShuffleId, key: str) -> None:
"Handle a `transfer` task for a shuffle being scheduled"
state = self.shuffles.get(id, None)
if state:
assert (
@@ -69,6 +70,7 @@ def transfer(self, id: ShuffleId, key: str) -> None:
)

def barrier(self, id: ShuffleId, key: str) -> None:
"Handle a `barrier` task for a shuffle being scheduled"
state = self.shuffles[id]
assert (
not state.barrier_reached
@@ -108,6 +110,7 @@ def barrier(self, id: ShuffleId, key: str) -> None:
self.output_keys[dts.key] = id

def unpack(self, id: ShuffleId, key: str) -> None:
"Handle an `unpack` task for a shuffle completing"
# Check if all output keys are done

# NOTE: we don't actually need this `unpack` step or tracking output keys;
@@ -137,6 +140,7 @@ def unpack(self, id: ShuffleId, key: str) -> None:
del self.shuffles[id]

def erred(self, id: ShuffleId, key: str) -> None:
"Handle any task for a shuffle erroring"
try:
state = self.shuffles.pop(id)
except KeyError:
@@ -151,6 +155,7 @@ def erred(self, id: ShuffleId, key: str) -> None:
del self.output_keys[k]

def transition(self, key: str, start: str, finish: str, *args, **kwargs):
"Watch transitions for keys we care about"
parts = parse_key(key)
if parts and len(parts) == 3:
prefix, group, id = parts
@@ -184,8 +189,10 @@ def transition(self, key: str, start: str, finish: str, *args, **kwargs):

def worker_for_key(self, key: str, npartitions: int, workers: list[str]) -> str:
"Worker address this task should be assigned to"
# Infer which output partition number this task is fetching by parsing its key
# FIXME this is so brittle.
# Infer which output partition number this task is fetching by parsing its key.
# We have to parse keys, instead of generating the list of expected keys, because
# blockwise fusion means they won't just be `shuffle-unpack-abcde`.
# FIXME this feels very hacky/brittle.
# For example, after `df.set_index(...).to_delayed()`, you could create
# keys that don't have indices in them, and get fused (because they should!).
m = re.match(r"\(.+, (\d+)\)$", key)
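For illustration (not part of the commit), here is how that regex recovers the output partition number from a fused key; the key string below is made up:

import re

key = "('getitem-shuffle-unpack-8a1c2f3d', 3)"   # hypothetical blockwise-fused output key
m = re.match(r"\(.+, (\d+)\)$", key)
assert m is not None
partition = int(m.group(1))   # -> 3, which is then mapped to a worker via `worker_for`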
@@ -202,6 +209,7 @@ def worker_for_key(self, key: str, npartitions: int, workers: list[str]) -> str:


def parse_key(key: str) -> list[str] | None:
"Split a shuffle key into its prefix, group, and shuffle ID, or None if not a shuffle key."
if TASK_PREFIX in key[: len(TASK_PREFIX) + 2]:
if key[0] == "(":
key = key_split_group(key)
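A hypothetical walk-through of the parsing above (the actual split and the value of `TASK_PREFIX` sit outside this hunk, so the shapes below are assumptions):

#   "('shuffle-unpack-8a1c2f3d', 3)"  --key_split_group-->  "shuffle-unpack-8a1c2f3d"
#   which would yield prefix="shuffle", group="unpack", id="8a1c2f3d",
#   matching the `prefix, group, id = parts` unpacking in `transition` above.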
distributed/shuffle/shuffle_worker.py (51 additions & 7 deletions)
@@ -49,6 +49,11 @@ def __init__(self, worker: Worker) -> None:
##########

def shuffle_init(self, id: ShuffleId, workers: list[str], n_out_tasks: int) -> None:
"""
Handler: initialize a shuffle. Called by scheduler on all workers.
Must be called exactly once per ID.
"""
if id in self.shuffles:
raise ValueError(
f"Shuffle {id!r} is already registered on worker {self.worker.address}"
@@ -71,6 +76,12 @@ def shuffle_receive(
output_partition: int,
data: pd.DataFrame,
) -> None:
"""
Handler: receive data from a peer.
The shuffle ID can be unknown.
Calling after the barrier task is an error.
"""
try:
state = self.shuffles[id]
except KeyError:
@@ -89,6 +100,13 @@ def shuffle_receive(
self.output_data[id][output_partition].append(data)
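Not visible in this hunk: `self.output_data` is presumably a nested receive buffer keyed by shuffle ID and output partition, something like:

from collections import defaultdict

# assumed shape; its real initialization is not part of this diff
# shuffle ID -> output partition number -> list of DataFrame shards received so far
output_data = defaultdict(lambda: defaultdict(list))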

async def shuffle_inputs_done(self, comm: object, id: ShuffleId) -> None:
"""
Handler: note that the barrier task has been reached. Called by a peer.
The shuffle will be removed if this worker holds no output partitions for it.
Must be called exactly once per ID. Blocks until `shuffle_init` has been called.
"""
state = await self.get_shuffle(id)
assert not state.barrier_reached, f"`inputs_done` called again for {id}"
state.barrier_reached = True
@@ -105,6 +123,15 @@ async def shuffle_inputs_done(self, comm: object, id: ShuffleId) -> None:
async def add_partition(
self, data: pd.DataFrame, id: ShuffleId, npartitions: int, column: str
) -> None:
"""
Task: Hand off an input partition to the extension.
This will block until the extension is ready to receive another input partition.
Also blocks until `shuffle_init` has been called.
Using an unknown ``shuffle_id`` is an error.
Calling after the barrier task is an error.
"""
# Block until scheduler has called init
state = await self.get_shuffle(id)
assert not state.barrier_reached, f"`add_partition` for {id} after barrier"
@@ -121,9 +148,14 @@ async def add_partition(
await self.send_partition(data, column, id, npartitions, state.workers)

async def barrier(self, id: ShuffleId) -> None:
# NOTE: requires workers list. This is guaranteed because it depends on `add_partition`,
# which got the workers list from the scheduler. So this task must run on a worker where
# `add_partition` has already run.
"""
Task: Note that the barrier task has been reached (`add_partition` called for all input partitions)
Using an unknown ``shuffle_id`` is an error.
Must be called exactly once per ID.
Blocks until `shuffle_init` has been called (on all workers).
Calling this before all partitions have been added will cause `add_partition` to fail.
"""
state = await self.get_shuffle(id)
assert not state.barrier_reached, f"`barrier` for {id} called multiple times"

@@ -140,13 +172,20 @@ async def barrier(self, id: ShuffleId) -> None:
async def get_output_partition(
self, id: ShuffleId, i: int, empty: pd.DataFrame
) -> pd.DataFrame:
state = self.shuffles[id]
# ^ Don't need to `get_shuffle`; `shuffle_inputs_done` has run already and guarantees it's there
"""
Task: Retrieve a shuffled output partition from the extension.
After calling on the final output partition remaining on this worker, the shuffle will be cleaned up.
Using an unknown ``shuffle_id`` is an error.
Requesting a partition which doesn't belong on this worker, or has already been retrieved, is an error.
"""
state = await self.get_shuffle(id) # should never have to wait
assert state.barrier_reached, f"`get_output_partition` for {id} before barrier"
assert (
state.out_parts_left > 0
), f"No outputs remaining, but requested output partition {i} on {self.worker.address} for {id}."
# ^ Note: this is impossible with our cleanup-on-empty
# ^ Note: impossible with our cleanup-on-empty

worker = worker_for(i, state.npartitions, state.workers)
assert worker == self.worker.address, (
@@ -172,6 +211,7 @@ async def get_output_partition(
#########

def remove(self, id: ShuffleId) -> None:
"Remove state for this shuffle. The shuffle must be complete and in a valid state."
state = self.shuffles.pop(id)
assert state.barrier_reached, f"Removed {id} before barrier"
assert (
@@ -187,7 +227,8 @@ def remove(self, id: ShuffleId) -> None:
not data
), f"Removed {id}, which still has data for output partitions {list(data)}"

async def get_shuffle(self, id: ShuffleId):
async def get_shuffle(self, id: ShuffleId) -> ShuffleState:
"Get the `ShuffleState`, blocking until it's been received from the scheduler."
try:
return self.shuffles[id]
except KeyError:
@@ -212,6 +253,7 @@ async def send_partition(
npartitions: int,
workers: list[str],
) -> None:
"Split up an input partition and send its parts to peers."
tasks = []
# TODO grouping is blocking, should it be offloaded to a thread?
# It mostly doesn't release the GIL though, so may not make much difference.
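The grouping-and-sending body is collapsed above; a rough sketch of the idea, assuming (as dask's shuffle graphs arrange) that `column` holds each row's target output partition and that peers are reached through the worker's rpc pool. The function name and call details here are assumptions, not the actual implementation:

from __future__ import annotations
import asyncio
import pandas as pd

async def send_partition_sketch(worker, data: pd.DataFrame, column: str,
                                id: str, npartitions: int, workers: list[str]) -> None:
    tasks = []
    for out_partition, shard in data.groupby(column):
        addr = worker_for(int(out_partition), npartitions, workers)  # helper from this module
        tasks.append(
            worker.rpc(addr).shuffle_receive(
                id=id, output_partition=int(out_partition), data=shard
            )
        )
    await asyncio.gather(*tasks)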
@@ -231,9 +273,11 @@

@property
def loop(self) -> asyncio.AbstractEventLoop:
"The asyncio event loop for the worker"
return self.worker.loop.asyncio_loop # type: ignore

def sync(self, coro: Coroutine[object, object, T]) -> T:
"Run an async function on the worker's event loop, synchronously from another thread."
# Is it a bad idea not to use `distributed.utils.sync`?
# It's much nicer to use asyncio, because among other things it gives us typechecking.
return asyncio.run_coroutine_threadsafe(coro, self.loop).result()
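For orientation (not part of the commit), this is roughly how the shuffle graph's tasks would drive the extension from a worker's task thread using `sync`; the extension attribute name and the task function names are assumptions:

from distributed import get_worker

def shuffle_transfer(data, id, npartitions, column):
    # hand one input partition to the extension; blocks until it has been sent
    ext = get_worker().extensions["shuffle"]   # attribute name assumed
    ext.sync(ext.add_partition(data, id, npartitions, column))

def shuffle_unpack(id, i, empty):
    # fetch output partition `i` on the worker it was assigned to
    ext = get_worker().extensions["shuffle"]
    return ext.sync(ext.get_output_partition(id, i, empty))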
