Allow worker to prioritize tasks based on memory production/consumption #5251

Open · wants to merge 3 commits into base: main
2 changes: 1 addition & 1 deletion distributed/scheduler.py
@@ -4591,7 +4591,7 @@ def update_graph(
        runnables = [ts for ts in touched_tasks if ts._run_spec]
        for ts in runnables:
            if ts._priority is None and ts._run_spec:
-               ts._priority = (self.generation, 0)
+               ts._priority = (0, generation, 0)

        if restrictions:
            # *restrictions* is a dict keying task ids to lists of
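For context, not part of the diff: the default priority assigned here grows from two elements to three so that it matches the three-field unpack the worker performs later in this PR (`user, scheduler_generation, graph = priority`). A minimal sketch with invented values:

# Illustrative only; field names follow the worker-side unpack shown below.
generation = 3
old_default = (generation, 0)     # (generation, graph)
new_default = (0, generation, 0)  # (user, generation, graph)
user, scheduler_generation, graph = new_default  # now unpacks cleanly on the worker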
60 changes: 60 additions & 0 deletions distributed/tests/test_worker.py
@@ -19,6 +19,7 @@

import dask
from dask import delayed
from dask.sizeof import sizeof
from dask.system import CPU_COUNT

import distributed
@@ -2688,3 +2689,62 @@ async def test_gather_dep_exception_one_task_2(c, s, a, b):
    s.handle_missing_data(key="f1", errant_worker=a.address)

    await fut2


@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)])
async def test_TaskPrefix(c, s, w):
    x = c.submit(inc, 0)
    y = c.submit(dec, x)
    z = c.submit(dec, y)

    await z

    assert w.prefixes["inc"].bytes_consumed == 0
    assert w.prefixes["inc"].bytes_produced == sizeof(1)

    assert w.prefixes["dec"].bytes_consumed == sizeof(1) + sizeof(0)
    assert w.prefixes["dec"].bytes_produced == sizeof(0) + sizeof(-1)


@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)])
async def test_memory_prioritization(c, s, w):
    np = pytest.importorskip("numpy")
    # learn memory use
    x = c.submit(np.arange, 1000000)
    y = c.submit(np.sum, x)
    await y
    del x, y

    arrays = c.map(np.arange, range(1_000_000, 1_000_010))
    sums = c.map(np.sum, arrays)
    await wait(sums)

    assert min(w.tasks[future.key].priority for future in arrays) > max(
        w.tasks[future.key].priority for future in sums
    )


@gen_cluster(
    nthreads=[("127.0.0.1", 1)],
    client=True,
    worker_kwargs={
        "memory_spill_fraction": False,  # don't spill
        "memory_target_fraction": False,
        "memory_pause_fraction": False,
    },
)
async def test_pause_memory_producing_computations(c, s, a):
    memory = psutil.Process().memory_info().rss
    a.memory_limit = memory + 400_000_000
    np = pytest.importorskip("numpy")

    def f(_):
        x = np.ones(int(300_000_000), dtype="u1")
        assert not any(k.startswith("f") for k in get_worker().data)
        get_worker().monitor.update()
        return x

    data = c.map(f, range(10))
    results = c.map(np.sum, data)
    del data
    await c.gather(results)
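A note on why the pause test can work at all (reasoning not stated in the diff): each call to `f` materialises roughly 300 MB while the worker is only given about 400 MB of headroom, so the worker has to let the `np.sum` consumers drain earlier results instead of running all producers first. In rough numbers:

# Numbers copied from the test above; nothing here is measured.
chunk = 300_000_000     # bytes produced by each call to f
headroom = 400_000_000  # memory_limit is the baseline RSS plus this
# Two unconsumed chunks would exceed the limit, so the assertion inside f only holds
# if producer tasks yield to the np.sum consumers.
assert 2 * chunk > headroom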
120 changes: 116 additions & 4 deletions distributed/worker.py
@@ -105,6 +105,37 @@
SerializedTask = namedtuple("SerializedTask", ["function", "args", "kwargs", "task"])


class TaskPrefix:
    def __init__(self, key):
        self.key = key
        self.bytes_consumed = 0
        self.bytes_produced = 0
        self.n_processed = 0

    def __repr__(self):
        return (
            f"<TaskPrefix: {self.key} -- {self.n_processed} tasks converted"
            f" {format_bytes(self.bytes_consumed)} into"
            f" {format_bytes(self.bytes_produced)}>"
        )

    @property
    def memory_producing(self):
        return (
            self.n_processed
            and self.bytes_produced / self.n_processed > 1_000_000
            and self.bytes_produced > 5 * self.bytes_consumed
        )

    @property
    def memory_consuming(self):
        return (
            self.n_processed
            and self.bytes_consumed / self.n_processed > 1_000_000
            and self.bytes_consumed > 5 * self.bytes_produced
        )
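# Illustrative usage, not part of this PR: a prefix counts as memory-producing once it
# averages more than 1 MB of output per task and has produced over five times what it
# consumed; memory-consuming is the mirror image. The numbers below are invented.
_tp = TaskPrefix("make_chunk")
_tp.n_processed, _tp.bytes_consumed, _tp.bytes_produced = 4, 400, 8_000_000
assert _tp.memory_producing and not _tp.memory_consuming

_tp = TaskPrefix("reduce_chunk")
_tp.n_processed, _tp.bytes_consumed, _tp.bytes_produced = 4, 8_000_000, 400
assert _tp.memory_consuming and not _tp.memory_producing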


class TaskState:
"""Holds volatile state relating to an individual Dask task

@@ -188,6 +219,7 @@ def __init__(self, key, runspec=None):
        self.nbytes = None
        self.annotations = None
        self.scheduler_holds_ref = False
        self.prefix = None

    def __repr__(self):
        return f"<Task {self.key!r} {self.state}>"
@@ -421,6 +453,7 @@ def __init__(
        **kwargs,
    ):
        self.tasks = dict()
        self.prefixes = dict()
        self.waiting_for_data_count = 0
        self.has_what = defaultdict(set)
        self.pending_data_per_worker = defaultdict(deque)
@@ -892,6 +925,19 @@ def identity(self, comm=None):
"memory_limit": self.memory_limit,
}

def get_prefix(self, key: str) -> TaskPrefix:
prefix = key_split(key)
try:
return self.prefixes[prefix]
except KeyError:
tp = TaskPrefix(prefix)
self.prefixes[prefix] = tp
return tp

@property
def rss_memory(self):
return self.monitor.memory[-1]

#####################
# External Services #
#####################
@@ -1633,8 +1679,18 @@ def add_task(
            # This will require a chained recommendation transition system like
            # the scheduler

            tp = self.get_prefix(ts.key)

            if priority is not None:
                priority = tuple(priority) + (self.generation,)
                user, scheduler_generation, graph = priority
                if tp.memory_consuming:
                    memory = -1
                elif tp.memory_producing:
                    memory = 1
                else:
                    memory = 0
                priority = (user, scheduler_generation, memory, graph, self.generation)

                self.generation -= 1

            if actor:
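A side note on the resulting ordering (not spelled out in the PR): the worker-side priority now has five elements, and since the ready heap is a min-heap, a memory term of -1/0/+1 makes memory-consuming prefixes run before neutral ones and memory-producing prefixes run last, all else being equal. A small self-contained illustration with made-up keys and numbers:

import heapq

ready = []  # entries: ((user, scheduler_generation, memory, graph, worker_generation), key)
heapq.heappush(ready, ((0, 3, 1, 7, -1), "make_chunk-a"))     # memory-producing prefix
heapq.heappush(ready, ((0, 3, -1, 9, -2), "reduce_chunk-a"))  # memory-consuming prefix
heapq.heappush(ready, ((0, 3, 0, 8, -3), "other-a"))          # neutral prefix

order = []
while ready:
    order.append(heapq.heappop(ready)[1])
assert order == ["reduce_chunk-a", "other-a", "make_chunk-a"]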
@@ -1724,6 +1780,8 @@ def add_task(
                pdb.set_trace()
            raise

        self.ensure_computing()

    def transition(self, ts, finish, **kwargs):
        if ts is None:
            return
@@ -2027,6 +2085,13 @@ def transition_executing_done(self, ts, value=no_value, report=True):
            for d in ts.dependents:
                d.waiting_for_data.add(ts.key)

            if ts.nbytes is not None:
                tp = self.get_prefix(ts.key)
                tp.n_processed += 1
                tp.bytes_produced += ts.nbytes
                for dep in ts.dependencies:
                    tp.bytes_consumed += dep.nbytes

            if report and self.batched_stream and self.status == Status.running:
                self.send_task_state_to_scheduler(ts)
            else:
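For a concrete feel of the accounting above, take the `dec` tasks from `test_TaskPrefix`: a finished task credits its own measured `nbytes` to its prefix and debits the `nbytes` of each dependency. A back-of-the-envelope version, using the same `sizeof` helper the test imports:

from dask.sizeof import sizeof

# dec(1) -> 0: the task produced the int 0 and consumed its single dependency, the int 1
bytes_produced = sizeof(0)  # what tp.bytes_produced gains via ts.nbytes
bytes_consumed = sizeof(1)  # what tp.bytes_consumed gains via dep.nbytes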
@@ -2882,8 +2947,17 @@ async def _maybe_deserialize_task(self, ts):
            raise

    def ensure_computing(self):
        if self.paused:
        if not self.ready and not self.constrained:
            return

        if self.paused:
            priority, key = heapq.heappop(self.ready)
            heapq.heappush(self.ready, (priority, key))
            if priority[2] >= 0:
                return
            else:
                self.paused = False  # memory releasing tasks at the front, unpause

        try:
            while self.constrained and self.executing_count < self.nthreads:
                key = self.constrained[0]
@@ -2896,18 +2970,56 @@ def ensure_computing(self):
                    self.transition(ts, "executing")
                else:
                    break
            while self.ready and self.executing_count < self.nthreads:
            while (
                self.ready and self.executing_count < self.nthreads and not self.paused
            ):
                priority, key = heapq.heappop(self.ready)
                ts = self.tasks.get(key)

                if ts is None:
                    # It is possible for tasks to be released while still remaining on `ready`
                    # The scheduler might have re-routed to a new worker and told this worker
                    # to release. If the task has "disappeared" just continue through the heap
                    continue
                elif ts.key in self.data:
                    self.transition(ts, "memory")
                elif ts.state in READY:
                    continue

                tp = self.get_prefix(key)
                if self.validate:
                    assert ts.state == "ready"

                # Need to reevaluate priorities in ready stack, things have changed
                if tp.memory_producing and priority[2] < 1:
                    logger.info("Reevaluating priorities in ready heap")
                    heapq.heappush(self.ready, (priority, key))  # put it back on
                    ready = []
                    while self.ready:
                        (u, sg, m, g, wg), key = self.ready.pop()
                        tp = self.get_prefix(key)
                        m = (
                            1
                            if tp.memory_producing
                            else -1
                            if tp.memory_consuming
                            else 0
                        )
                        heapq.heappush(ready, ((u, sg, m, g, wg), key))
                    self.ready[:] = ready
                    continue

                if (
                    tp.memory_producing
                    and self.memory_limit
                    and self.rss_memory > 0.5 * self.memory_limit
                ):
                    self.paused = True
                    heapq.heappush(self.ready, (priority, key))  # put it back on
                    continue

                if ts.state in READY:
                    self.transition(ts, "executing")

        except Exception as e:
            logger.exception(e)
            if LOG_PDB:
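Finally, a standalone toy version of the re-prioritization performed inside `ensure_computing`, to make the heap rebuild easier to follow. The tuple layout matches the five-element priority above; the function name, the prefix sets, and the `key.split("-")` stand-in for `key_split` are invented for illustration:

import heapq

def reprioritize(ready, producing, consuming):
    # Recompute the memory slot of every queued entry and rebuild the heap,
    # mirroring the "Reevaluating priorities in ready heap" branch above.
    rebuilt = []
    for (u, sg, _m, g, wg), key in ready:
        prefix = key.split("-")[0]  # stand-in for key_split()
        m = 1 if prefix in producing else -1 if prefix in consuming else 0
        heapq.heappush(rebuilt, ((u, sg, m, g, wg), key))
    ready[:] = rebuilt

ready = [((0, 3, 0, 1, -1), "make_chunk-a"), ((0, 3, 0, 2, -2), "reduce_chunk-a")]
heapq.heapify(ready)
reprioritize(ready, producing={"make_chunk"}, consuming={"reduce_chunk"})
assert heapq.heappop(ready)[1] == "reduce_chunk-a"  # consumers now run first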