Skip to content

Commit

Permalink
Fix co-assignment logic to consider queued tasks
Browse files Browse the repository at this point in the history
When there were multiple root task groups, we were just re-using the last worker for every batch because it had nothing processing on it.

Unintentionally this also fixes #6597 in some cases (because the first task goes to processing, but we measure queued, so we pick the same worker for both task groups)
  • Loading branch information
gjoseph92 committed Jun 18, 2022
1 parent 590aa5b commit b2e7924
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 2 deletions.
17 changes: 15 additions & 2 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1805,10 +1805,23 @@ def decide_worker(self, ts: TaskState) -> WorkerState | None:

if not (ws and tg.last_worker_tasks_left and ws.address in self.workers):
# Last-used worker is full or unknown; pick a new worker for the next few tasks

# We just pick the worker with the shortest queue (or if queuing is disabled,
# the fewest processing tasks). We've already decided dependencies are unimportant,
# so we don't care to schedule near them.
backlog = operator.attrgetter(
"processing" if math.isinf(self.WORKER_OVERSATURATION) else "queued"
)
ws = min(
(self.idle or self.workers).values(),
key=partial(self.worker_objective, ts),
self.workers.values(), key=lambda ws: len(backlog(ws)) / ws.nthreads
)
if self.validate:
assert ws is not tg.last_worker, (
f"Colocation reused worker {ws} for {tg}, "
f"idle: {list(self.idle.values())}, "
f"workers: {list(self.workers.values())}"
)

tg.last_worker_tasks_left = math.floor(
(len(tg) / self.total_nthreads) * ws.nthreads
)
Expand Down
29 changes: 29 additions & 0 deletions distributed/tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,35 @@ async def _test_oversaturation_factor(c, s, a, b):
_test_oversaturation_factor()


@pytest.mark.parametrize(
"saturation_factor",
[
0.0,
1.0,
pytest.param(
float("inf"),
marks=pytest.mark.skip("https://github.com/dask/distributed/issues/6597"),
),
],
)
@gen_cluster(
client=True,
nthreads=[("", 2), ("", 1)],
)
async def test_oversaturation_multiple_task_groups(c, s, a, b, saturation_factor):
s.WORKER_OVERSATURATION = saturation_factor
xs = [delayed(i, name=f"x-{i}") for i in range(9)]
ys = [delayed(i, name=f"y-{i}") for i in range(9)]
zs = [x + y for x, y in zip(xs, ys)]

await c.gather(c.compute(zs))

assert not a.incoming_transfer_log, [l["keys"] for l in a.incoming_transfer_log]
assert not b.incoming_transfer_log, [l["keys"] for l in b.incoming_transfer_log]
assert len(a.tasks) == 18
assert len(b.tasks) == 9


@gen_cluster(
client=True,
nthreads=[("", 2)] * 2,
Expand Down

0 comments on commit b2e7924

Please sign in to comment.