-
-
Notifications
You must be signed in to change notification settings - Fork 718
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Speculative task assignment #4264
base: main
Are you sure you want to change the base?
Changes from all commits
acf4705
fc5e70e
307d3ef
4f72097
0f49e7a
8cbd0ce
98f741f
756bf1b
3283c53
5400f20
74cfd1a
cebe850
ee023a8
0f9b767
3776dc5
1156345
b3b19c0
2e1b5e6
cd66c56
fc2d746
b609e8e
c2ced34
465740a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -128,7 +128,15 @@ def cclass(cls): | |
EventExtension, | ||
] | ||
|
||
ALL_TASK_STATES = {"released", "waiting", "no-worker", "processing", "erred", "memory"} | ||
ALL_TASK_STATES = { | ||
"released", | ||
"waiting", | ||
"no-worker", | ||
"processing", | ||
"erred", | ||
"memory", | ||
"speculative", | ||
} | ||
|
||
|
||
@cclass | ||
|
@@ -1401,6 +1409,34 @@ def __repr__(self): | |
return "%s(%s)" % (self.__class__, set(self)) | ||
|
||
|
||
def _recommend_speculative_assignment(ts): | ||
""" | ||
Recommend speculative assignment for dependent (child) task IFF: | ||
|
||
- Current task only has a single dependent task | ||
- All dependencies of that child task are present / processing on the same worker | ||
|
||
This function is called from two transition functions: | ||
- transition_waiting_processing | ||
- transition_waiting_speculative | ||
|
||
As each task passes through either of these (nearly every task goes through | ||
waiting -> processing), this function returns a recommendation (or no | ||
recommendation) for how to process the dependent task. | ||
|
||
""" | ||
|
||
if ( | ||
len(ts.dependents) == 1 | ||
and len({dts.processing_on for dts in list(ts.dependents)[0].dependencies}) == 1 | ||
and not ts.actor | ||
): | ||
return {list(ts.dependents)[0].key: "speculative"} | ||
|
||
else: | ||
return {} | ||
|
||
|
||
def _legacy_task_key_set(tasks): | ||
""" | ||
Transform a set of task states into a set of task keys. | ||
|
@@ -1808,9 +1844,13 @@ def __init__( | |
|
||
self._transitions = { | ||
("released", "waiting"): self.transition_released_waiting, | ||
("released", "speculative"): self.transition_waiting_speculative, | ||
("waiting", "released"): self.transition_waiting_released, | ||
("waiting", "processing"): self.transition_waiting_processing, | ||
("waiting", "memory"): self.transition_waiting_memory, | ||
("waiting", "speculative"): self.transition_waiting_speculative, | ||
("speculative", "processing"): self.transition_speculative_processing, | ||
("speculative", "released"): self.transition_speculative_released, | ||
("processing", "released"): self.transition_processing_released, | ||
("processing", "memory"): self.transition_processing_memory, | ||
("processing", "erred"): self.transition_processing_erred, | ||
|
@@ -2953,6 +2993,12 @@ def validate_waiting(self, key): | |
assert bool(dts._who_has) + (dts in ts._waiting_on) == 1 | ||
assert ts in dts._waiters # XXX even if dts._who_has? | ||
|
||
def validate_speculative(self, key): | ||
ts = self.tasks[key] | ||
assert len({dts.processing_on for dts in ts.waiting_on}) == 1 | ||
ws = ts.processing_on | ||
assert ws | ||
|
||
def validate_processing(self, key): | ||
ts: TaskState = self.tasks[key] | ||
dts: TaskState | ||
|
@@ -3202,14 +3248,20 @@ def send_task_to_worker(self, worker, ts: TaskState, duration=None): | |
msg["resource_restrictions"] = ts._resource_restrictions | ||
if ts._actor: | ||
msg["actor"] = True | ||
if ts.state == "speculative": | ||
msg["speculative"] = True | ||
|
||
deps: set = ts._dependencies | ||
if deps: | ||
if deps and not ts.state == "speculative": | ||
msg["who_has"] = { | ||
dts._key: [ws._address for ws in dts._who_has] for dts in deps | ||
} | ||
msg["nbytes"] = {dts._key: dts._nbytes for dts in deps} | ||
|
||
if self.validate: | ||
assert all(msg["who_has"].values()) | ||
elif ts.state == "speculative": | ||
msg["who_has"] = {dts._key: dts._processing_on.address for dts in deps} | ||
msg["nbytes"] = {dts._key: dts._nbytes for dts in deps} | ||
if self.validate: | ||
assert all(msg["who_has"].values()) | ||
|
||
|
def get_comm_cost(self, ts: TaskState, ws: WorkerState):
    """
    Get the estimated communication cost (in s.) to compute the task
    on the given worker.
    """
    # TODO: How is it possible for nbytes to be None when there's a getter
    # that is supposed to stop that from happening?
    dts: TaskState
    # Only dependencies not already resident on ``ws`` need transferring.
    missing: set = ts._dependencies - ws._has_what
    total_bytes: Py_ssize_t = 0
    for dts in missing:
        # Fall back to a default estimate when ``_nbytes`` is unset.
        total_bytes += dts._nbytes or DEFAULT_DATA_SIZE
    return total_bytes / self.bandwidth
|
||
def get_task_duration(self, ts: TaskState, default=None): | ||
|
@@ -4696,6 +4750,98 @@ def decide_worker(self, ts: TaskState) -> WorkerState: | |
|
||
return ws | ||
|
||
def transition_waiting_speculative(self, key):
    """
    Speculatively assign *key* to the worker that is processing its
    dependencies, before those dependencies have finished.

    The task is booked onto the worker (occupancy, resources, processing
    map) exactly as in ``transition_waiting_processing``, except that no
    communication cost is added: the dependency data will be produced on
    the same worker.

    Returns
    -------
    dict
        Transition recommendations — possibly chaining speculation to this
        task's sole dependent via ``_recommend_speculative_assignment``.
    """
    try:
        ts = self.tasks[key]

        if self.validate:
            workers = {dts.processing_on for dts in ts.dependencies}
            # All dependencies are on the same worker...
            assert len(workers) == 1
            # ...and that worker is real, i.e. at least one dependency is
            # actually in a processing-like state (reviewer-requested check:
            # {None} would otherwise pass the length test).
            assert None not in workers

        ws = next(iter(ts.dependencies)).processing_on

        duration = self.get_task_duration(ts)

        # There are no comm costs if this is speculative.
        ws.processing[ts] = duration
        ts.processing_on = ws
        ws.occupancy += duration
        self.total_occupancy += duration
        ts.state = "speculative"
        self.consume_resources(ts, ws)
        self.check_idle_saturated(ws)
        self.n_tasks += 1

        # The dependencies have not completed yet, so the task still
        # waits on them even though it has been shipped to the worker.
        for dts in ts.dependencies:
            ts.waiting_on.add(dts)
            dts.waiters.add(ts)

        self.send_task_to_worker(ws.address, ts)

        # Chain speculation: maybe assign this task's sole dependent too.
        # (Previously an unreachable ``return {}`` followed this return.)
        return _recommend_speculative_assignment(ts)
    except Exception as e:
        logger.exception(e)
        if LOG_PDB:
            import pdb

            pdb.set_trace()
        raise
|
||
def transition_speculative_processing(self, key):
    """
    Promote a speculatively-assigned task to plain "processing".

    Nothing needs to be sent to the worker — it already holds the task —
    so this is purely a scheduler-side state flip.
    """
    try:
        ts = self.tasks[key]

        if self.validate:
            assert ts.state == "speculative"
            assert not ts.waiting_on

        ts.state = "processing"
        return {}
    except Exception as e:
        logger.exception(e)
        if LOG_PDB:
            import pdb

            pdb.set_trace()
        raise
|
||
def transition_speculative_released(self, key):
    """
    Release a speculatively-assigned task.

    Detaches the task from its dependency bookkeeping, tells the worker to
    drop it, and recommends releasing any dependents that were erred or
    speculatively chained behind it.
    """
    try:
        ts = self.tasks[key]

        if self.validate:
            assert ts.state == "speculative"
            assert ts.processing_on

        ts.waiters.clear()
        ts.waiting_on.clear()

        # Free the worker's processing slot and instruct it to drop the task.
        self._remove_from_processing(
            ts, send_worker_msg={"op": "release-task", "key": key}
        )
        ts.state = "released"

        # Dependents that erred or were speculatively chained behind this
        # task can no longer proceed and must be released too.
        return {
            dts.key: "released"
            for dts in ts.dependents
            if dts.state in ("erred", "speculative")
        }
    except Exception as e:
        logger.exception(e)
        if LOG_PDB:
            import pdb

            pdb.set_trace()
        raise
|
||
def transition_waiting_processing(self, key): | ||
try: | ||
ts: TaskState = self.tasks[key] | ||
|
@@ -4735,7 +4881,8 @@ def transition_waiting_processing(self, key): | |
|
||
self.send_task_to_worker(worker, ts, duration) | ||
|
||
return {} | ||
return _recommend_speculative_assignment(ts) | ||
|
||
except Exception as e: | ||
logger.exception(e) | ||
if LOG_PDB: | ||
|
@@ -6176,7 +6323,7 @@ def validate_task_state(ts: TaskState): | |
assert dts.state != "forgotten" | ||
|
||
for dts in ts._waiters: | ||
assert dts.state in ("waiting", "processing"), ( | ||
assert dts.state in ("waiting", "processing", "speculative"), ( | ||
"waiter not in play", | ||
str(ts), | ||
str(dts), | ||
|
@@ -6190,11 +6337,16 @@ def validate_task_state(ts: TaskState): | |
) | ||
assert dts.state != "forgotten" | ||
|
||
assert (ts._processing_on is not None) == (ts.state == "processing") | ||
assert (ts._processing_on is not None) == ( | ||
ts.state in ("processing", "speculative") | ||
) | ||
assert bool(ts._who_has) == (ts.state == "memory"), (ts, ts._who_has) | ||
|
||
if ts.state == "processing": | ||
assert all([dts._who_has for dts in ts._dependencies]), ( | ||
assert ( | ||
all(dts._who_has for dts in ts._dependencies) | ||
or len({dts._processing_on for dts in ts._dependencies}) == 1 | ||
), ( | ||
"task processing without all deps", | ||
str(ts), | ||
str(ts._dependencies), | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
import asyncio | ||
|
||
from dask import delayed | ||
|
||
from distributed.utils_test import gen_cluster, inc, dec, slowinc | ||
|
||
|
||
@gen_cluster(client=True, config={"dask.optimization.fuse.active": False})
async def test_speculative_assignment_simple(c, s, a, b):
    """A linear chain should be shipped to one worker, tail speculatively."""
    x = delayed(slowinc)(1)
    y = delayed(inc)(x)
    z = delayed(dec)(y)

    result = c.compute(z)
    # The scheduler may place the chain on either worker; the original test
    # assumed worker ``a`` and would spin until timeout if ``b`` was chosen.
    while not (a.tasks or b.tasks):
        await asyncio.sleep(0.001)
    w = a if a.tasks else b

    # The whole chain lands on the chosen worker...
    assert x.key in w.tasks
    assert y.key in w.tasks
    assert z.key in w.tasks

    # ...with the dependents speculatively assigned.
    assert w.tasks[y.key].state == "speculative"
    assert w.tasks[z.key].state == "speculative"

    assert await result == 2
|
||
|
||
@gen_cluster(client=True, config={"dask.optimization.fuse.active": False})
async def test_spec_assign_all_dependencies(c, s, a, b):
    """Tasks whose dependencies all sit on one worker get speculatively assigned there."""
    data1 = await c.scatter([1], workers=a.address)
    data2 = await c.scatter([2], workers=a.address)

    # Not spec assigned: their input data is already present in memory.
    x1 = delayed(inc)(data1[0])
    x2 = delayed(inc)(data2[0])
    # Spec assigned (two dependencies on the same worker).
    x1x2 = x1 + x2
    # Spec assigned.
    z = delayed(dec)(x1x2)

    assert await c.compute(z) == 4

    story = a.story
    assert (x1.key, "waiting", "ready") in story(x1.key)
    assert (x2.key, "waiting", "ready") in story(x2.key)
    assert (x1x2.key, "speculative", "ready") in story(x1x2.key)
    assert (z.key, "speculative", "ready") in story(z.key)
|
||
|
||
@gen_cluster(client=True, config={"dask.optimization.fuse.active": False})
async def test_spec_assign_intermittent(c, s, a, b):
    """
    Two independent chains hanging off one piece of scattered data::

          d
         / \\
        e   h   # no spec
        |   |
        f   i   # both spec
        |   |
        g   j   # both spec
         \\ /
          k     # spec
    """
    d = await c.scatter([1])

    e = delayed(inc)(d[0])
    f = delayed(inc)(e)
    g = delayed(inc)(f)

    h = delayed(dec)(d[0])
    i = delayed(dec)(h)
    j = delayed(dec)(i)

    k = j + g

    assert await c.compute(k) == 2

    # The dec-chain ran on whichever worker recorded its story.
    dec_worker = b if b.story(h.key) else a
    assert (h.key, "waiting", "ready") in dec_worker.story(h.key)
    assert (i.key, "speculative", "ready") in dec_worker.story(i.key)
    assert (j.key, "speculative", "ready") in dec_worker.story(j.key)

    # Likewise for the inc-chain.
    inc_worker = a if a.story(e.key) else b
    assert (e.key, "waiting", "ready") in inc_worker.story(e.key)
    assert (f.key, "speculative", "ready") in inc_worker.story(f.key)
    assert (g.key, "speculative", "ready") in inc_worker.story(g.key)

    # ``k`` may land on either worker; wherever it ran it was speculative.
    if a.story(k.key):
        assert (k.key, "speculative", "ready") in a.story(k.key)
    if b.story(k.key):
        assert (k.key, "speculative", "ready") in b.story(k.key)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We'll eventually want to see what happens when a speculative task gets pushed back to released/waiting again, such as if a worker goes down. The catch-all behavior is, I think, to move things back to released. If you handle that, things should be OK, though we may also want a speculative->waiting transition. The same applies to errors: I suspect that when a processing task errs it will mark all of its dependents as errored, and so the scheduler will try to enact a speculative->erred transition.
In some of these it may be that the current
self.transition_processing_*
methods will work without modification, but we'll probably have to verify this.