diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index 9c25d2c85b..3217436d09 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -18,7 +18,7 @@
 from datetime import timedelta
 from functools import partial
 from numbers import Number
-from typing import Optional
+from typing import Optional, ValuesView

 import psutil
 import sortedcontainers
@@ -2387,7 +2387,10 @@ def decide_worker(self, ts: TaskState) -> WorkerState:
         if ts._dependencies or valid_workers is not None:
             ws = decide_worker(
                 ts,
-                self._workers_dv.values(),
+                dict.values(self._workers_dv),
+                dict.values(self._idle_dv),
+                # ^ NOTE: For performance, these must be actual `dict_values`, not `SortedValuesView`s.
+                # In Cython, `_workers_dv` is typed as a plain dict; in pure Python it is still a `SortedDict`.
                 valid_workers,
                 partial(self.worker_objective, ts),
             )
@@ -7623,14 +7626,19 @@ def _reevaluate_occupancy_worker(state: SchedulerState, ws: WorkerState):
 @cfunc
 @exceptval(check=False)
 def decide_worker(
-    ts: TaskState, all_workers, valid_workers: set, objective
+    ts: TaskState,
+    all_workers: ValuesView,
+    idle_workers: ValuesView,
+    valid_workers: set,
+    objective,
 ) -> WorkerState:
     """
     Decide which worker should take task *ts*.

-    We choose the worker that has the data on which *ts* depends.
+    We consider all workers which hold dependencies of *ts*,
+    plus a sample of up to 10 random workers (with preference for idle ones).

-    If several workers have dependencies then we choose the less-busy worker.
+    From those, we choose the worker where the *objective* function is minimized.

     Optionally provide *valid_workers* of where jobs are allowed to occur
     (if all workers are allowed to take the task, pass None instead).
@@ -7640,6 +7648,8 @@ def decide_worker(
     of bytes sent between workers. This is determined by calling the
     *objective* function.
     """
+    # NOTE: `all_workers` and `idle_workers` must be plain `dict_values` objects,
+    # not `SortedValuesView`s, which are much slower to iterate over.
     ws: WorkerState = None
     wws: WorkerState
     dts: TaskState
@@ -7649,7 +7659,17 @@
     if ts._actor:
         candidates = set(all_workers)
     else:
+        # Select all workers holding deps of this task
         candidates = {wws for dts in deps for wws in dts._who_has}
+        # Add up to 10 random workers into `candidates`, preferring idle ones.
+        worker_pool = valid_workers if valid_workers is not None else all_workers
+        if len(candidates) < len(worker_pool):
+            sample_from = idle_workers or worker_pool
+            candidates.update(
+                random.choices(list(sample_from), k=min(10, len(sample_from)))
+                if len(sample_from) > 10
+                else sample_from
+            )
     if valid_workers is None:
         if not candidates:
             candidates = set(all_workers)
@@ -7659,7 +7679,7 @@
             candidates = valid_workers
             if not candidates:
                 if ts._loose_restrictions:
-                    ws = decide_worker(ts, all_workers, None, objective)
+                    ws = decide_worker(ts, all_workers, idle_workers, None, objective)
                 return ws

     ncandidates: Py_ssize_t = len(candidates)
diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py
index 6a15109178..340965019f 100644
--- a/distributed/tests/test_scheduler.py
+++ b/distributed/tests/test_scheduler.py
@@ -101,14 +101,16 @@ async def test_recompute_released_results(c, s, a, b):
     assert result == 1


-@gen_cluster(client=True)
+@gen_cluster(client=True, config={"distributed.scheduler.bandwidth": "1Mb"})
 async def test_decide_worker_with_many_independent_leaves(c, s, a, b):
+    # Make data large to penalize scheduling dependent tasks on other workers
+    ballast = b"\0" * int(s.bandwidth)
     xs = await asyncio.gather(
-        c.scatter(list(range(0, 100, 2)), workers=a.address),
-        c.scatter(list(range(1, 100, 2)), workers=b.address),
+        c.scatter([bytes(i) + ballast for i in range(0, 100, 2)], workers=a.address),
+        c.scatter([bytes(i) + ballast for i in range(1, 100, 2)], workers=b.address),
     )
     xs = list(concat(zip(*xs)))
-    ys = [delayed(inc)(x) for x in xs]
+    ys = [delayed(lambda s: s[0])(x) for x in xs]
     y2s = c.persist(ys)
     await wait(y2s)

@@ -127,6 +129,27 @@ async def test_decide_worker_with_restrictions(client, s, a, b, c):
     assert x.key in a.data or x.key in b.data


+@gen_cluster(
+    client=True,
+    nthreads=[("127.0.0.1", 1)] * 3,
+    config={"distributed.scheduler.work-stealing": False},
+)
+async def test_decide_worker_select_candidate_holding_no_deps(client, s, a, b, c):
+    await client.submit(slowinc, 10, delay=0.1)  # learn that slowinc is slow
+    root = await client.scatter(1)
+    assert sum(root.key in worker.data for worker in [a, b, c]) == 1
+
+    start = time()
+    tasks = client.map(slowinc, [root] * 6, delay=0.1, pure=False)
+    await wait(tasks)
+    elapsed = time() - start
+
+    assert elapsed <= 4
+    assert all(root.key in worker.data for worker in [a, b, c]), [
+        list(worker.data.keys()) for worker in [a, b, c]
+    ]
+
+
 @pytest.mark.parametrize("ndeps", [0, 1, 4])
 @pytest.mark.parametrize(
     "nthreads",
diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py
index e5e877a2a9..4c0c3886ce 100644
--- a/distributed/tests/test_worker.py
+++ b/distributed/tests/test_worker.py
@@ -1079,8 +1079,8 @@ async def test_scheduler_delay(c, s, a, b):


 @pytest.mark.flaky(reruns=10, reruns_delay=5)
-@gen_cluster(client=True)
-async def test_statistical_profiling(c, s, a, b):
+@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)])
+async def test_statistical_profiling(c, s, a):
     futures = c.map(slowinc, range(10), delay=0.1)

     await wait(futures)
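
Review note: to make the new scheduling behavior easier to follow outside the diff, here is a minimal, self-contained sketch of the candidate-selection step the `decide_worker` hunk adds. `select_candidates` and the string-valued workers are hypothetical stand-ins for the scheduler's `WorkerState` objects, not part of the distributed API:

```python
import random


def select_candidates(dep_holders, all_workers, idle_workers, valid_workers=None):
    # Start from every worker that already holds a dependency of the task.
    candidates = set(dep_holders)
    # As in the hunk above: mix in up to 10 random workers, preferring idle
    # ones, so the objective function can pick a worker holding no
    # dependencies when all the holders are busy.
    worker_pool = valid_workers if valid_workers is not None else all_workers
    if len(candidates) < len(worker_pool):
        sample_from = idle_workers or worker_pool
        candidates.update(
            # `random.choices` samples WITH replacement; the set-update
            # absorbs duplicates, hence "up to" 10 extra candidates.
            random.choices(list(sample_from), k=min(10, len(sample_from)))
            if len(sample_from) > 10
            else sample_from
        )
    return candidates


# 100 workers, the first 5 idle, 3 of them holding dependencies:
workers = [f"tcp://10.0.0.{i}" for i in range(100)]
idle = workers[:5]
holders = {workers[10], workers[42], workers[99]}
print(sorted(select_candidates(holders, workers, idle)))
# -> the 3 dependency holders plus the 5 idle workers (8 candidates)
```

This mirrors what `test_decide_worker_select_candidate_holding_no_deps` asserts above: dependents of a single scattered `root` can spread to workers that hold no data at all, instead of piling onto the one worker that holds a tiny dependency.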
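Also worth calling out for reviewers: the `dict.values(self._workers_dv)` / `dict.values(self._idle_dv)` calls work because `sortedcontainers.SortedDict` subclasses `dict`. Calling the unbound `dict.values` bypasses `SortedDict.values` (which returns a `SortedValuesView`) and yields a plain `dict_values` view, which is much cheaper to iterate. A small sketch of the distinction, assuming only sortedcontainers is installed:

```python
from sortedcontainers import SortedDict

sd = SortedDict({"tcp://10.0.0.2": 2, "tcp://10.0.0.1": 1})

print(type(sd.values()).__name__)       # SortedValuesView: sorted order, slow iteration
print(type(dict.values(sd)).__name__)   # dict_values: arbitrary order, fast iteration
print(list(dict.values(sd)))            # same values either way, order may differ
```

Since `decide_worker` only needs membership tests and iteration, not sorted order, the plain view is the right choice here.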