diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index 4c69b24e7f..f0df021cbd 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -1371,14 +1371,24 @@ def func(*args): assert (ntasks_per_worker < ideal * 1.5).all(), (ideal, ntasks_per_worker) -def test_balance_steal_communication_heavy_tasks(): - dependencies = {"a": 10, "b": 10} +@pytest.mark.parametrize( + "cost, ntasks, expect_steal", + [ + pytest.param(10, 5, False, id="not enough work to steal"), + pytest.param(10, 10, True, id="enough work to steal"), + pytest.param(20, 10, False, id="not enough work for increased cost"), + ], +) +def test_balance_expensive_tasks(cost, ntasks, expect_steal): + dependencies = {"a": cost, "b": cost} dependency_placement = [["a"], ["b"]] - task_placement = [[["a", "b"]] * 10, []] + task_placement = [[["a", "b"]] * ntasks, []] def _correct_placement(actual): actual_task_counts = [len(placed) for placed in actual] - return sum(actual_task_counts) == 10 and actual_task_counts[1] > 0 + return sum(actual_task_counts) == ntasks and ( + (actual_task_counts[1] > 0) == expect_steal + ) _run_dependency_balance_test( dependencies,