Refactor wait_for_state() (#6581)
crusaderky authored Jun 16, 2022
1 parent 29dae02 commit ab46f6a
Showing 5 changed files with 85 additions and 57 deletions.
40 changes: 13 additions & 27 deletions distributed/tests/test_cancelled_state.py
@@ -6,19 +6,16 @@
from distributed import Event, Lock, Worker
from distributed.client import wait
from distributed.utils_test import (
BlockedGetData,
_LockedCommPool,
assert_story,
gen_cluster,
inc,
slowinc,
wait_for_state,
)


async def wait_for_state(key, state, dask_worker):
while key not in dask_worker.tasks or dask_worker.tasks[key].state != state:
await asyncio.sleep(0.005)


async def wait_for_cancelled(key, dask_worker):
while key in dask_worker.tasks:
if dask_worker.tasks[key].state == "cancelled":
@@ -251,23 +248,13 @@ async def get_data(self, comm, *args, **kwargs):

@gen_cluster(client=True, nthreads=[("", 1)])
async def test_in_flight_lost_after_resumed(c, s, b):
block_get_data = asyncio.Lock()
in_get_data = asyncio.Event()

await block_get_data.acquire()
lock_executing = Lock()

def block_execution(lock):
with lock:
return 1
lock.acquire()
return 1

class BlockedGetData(Worker):
async def get_data(self, comm, *args, **kwargs):
in_get_data.set()
async with block_get_data:
return await super().get_data(comm, *args, **kwargs)

async with BlockedGetData(s.address, name="blocked-get-dataworker") as a:
async with BlockedGetData(s.address) as a:
fut1 = c.submit(
block_execution,
lock_executing,
@@ -277,35 +264,34 @@ async def get_data(self, comm, *args, **kwargs):
# Ensure fut1 is in memory but block any further execution afterwards to
# ensure we control when the recomputation happens
await wait(fut1)
await lock_executing.acquire()
fut2 = c.submit(inc, fut1, workers=[b.address], key="fut2")

# This ensures that B already fetches the task, i.e. after this the task
# is guaranteed to be in flight
await in_get_data.wait()
assert fut1.key in b.tasks
assert b.tasks[fut1.key].state == "flight"
await a.in_get_data.wait()
assert fut1.key in b.state.tasks
assert b.state.tasks[fut1.key].state == "flight"

s.set_restrictions({fut1.key: [a.address, b.address]})
# It is removed, i.e. get_data is guaranteed to fail and f1 is scheduled
# to be recomputed on B
await s.remove_worker(a.address, stimulus_id="foo", close=False, safe=True)

while not b.tasks[fut1.key].state == "resumed":
while not b.state.tasks[fut1.key].state == "resumed":
await asyncio.sleep(0.01)

fut1.release()
fut2.release()

while not b.tasks[fut1.key].state == "cancelled":
while not b.state.tasks[fut1.key].state == "cancelled":
await asyncio.sleep(0.01)

block_get_data.release()
while b.tasks:
a.block_get_data.set()
while b.state.tasks:
await asyncio.sleep(0.01)

assert_story(
b.story(fut1.key),
b.state.story(fut1.key),
expect=[
# The initial free-keys is rejected
("free-keys", (fut1.key,)),
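The locally defined blocking worker subclass is gone from this test; it now uses the shared BlockedGetData helper imported from distributed.utils_test. A minimal sketch of roughly what that helper provides, inferred from its usage here (the in_get_data and block_get_data events); the actual implementation in utils_test.py may differ:

class BlockedGetData(Worker):
    """Worker that pauses inside get_data() until the test unblocks it."""

    def __init__(self, *args, **kwargs):
        self.in_get_data = asyncio.Event()     # set once get_data() has been entered
        self.block_get_data = asyncio.Event()  # set by the test to let get_data() proceed
        super().__init__(*args, **kwargs)

    async def get_data(self, comm, *args, **kwargs):
        self.in_get_data.set()
        await self.block_get_data.wait()
        return await super().get_data(comm, *args, **kwargs)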
32 changes: 31 additions & 1 deletion distributed/tests/test_utils_test.py
@@ -17,7 +17,7 @@

import dask.config

from distributed import Client, Nanny, Scheduler, Worker, config, default_client
from distributed import Client, Event, Nanny, Scheduler, Worker, config, default_client
from distributed.batched import BatchedSend
from distributed.comm.core import connect
from distributed.compatibility import WINDOWS
@@ -41,6 +41,7 @@
popen,
raises_with_cause,
tls_only_security,
wait_for_state,
)
from distributed.worker import fail_hard
from distributed.worker_state_machine import (
@@ -912,3 +913,32 @@ async def test_freeze_batched_send():
assert b.comm is comm
assert await comm.read() == ("baz",)
assert e.count == 3


@gen_cluster(client=True, nthreads=[("", 1)], timeout=2)
async def test_wait_for_state(c, s, a, capsys):
ev = Event()
x = c.submit(lambda ev: ev.wait(), ev, key="x")

await asyncio.gather(
wait_for_state("x", "processing", s),
wait_for_state("x", "executing", a),
c.run(wait_for_state, "x", "executing"),
)

await ev.set()

await asyncio.gather(
wait_for_state("x", "memory", s),
wait_for_state("x", "memory", a),
c.run(wait_for_state, "x", "memory"),
)

with pytest.raises(asyncio.TimeoutError):
await asyncio.wait_for(wait_for_state("x", "bad_state", s), timeout=0.1)
with pytest.raises(asyncio.TimeoutError):
await asyncio.wait_for(wait_for_state("y", "memory", s), timeout=0.1)
assert capsys.readouterr().out == (
f"tasks[x].state='memory' on {s.address}; expected state='bad_state'\n"
f"tasks[y] not found on {s.address}\n"
)
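
A note on the c.run(...) calls in the new test above: Client.run executes the given function on every worker and, when that function's signature includes a dask_worker parameter, fills it in with the local Worker instance, which is why only the key and state are passed explicitly. A hedged equivalent of one of those calls:

# Runs wait_for_state("x", "memory", dask_worker=<that worker>) on every worker
await c.run(wait_for_state, "x", "memory")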
19 changes: 6 additions & 13 deletions distributed/tests/test_worker.py
@@ -65,6 +65,7 @@
raises_with_cause,
slowinc,
slowsum,
wait_for_state,
)
from distributed.worker import (
Worker,
@@ -3060,14 +3061,6 @@ async def test_worker_status_sync(s, a):
]


async def _wait_for_state(key: str, worker: Worker, state: str):
# Keep the sleep interval at 0 since the tests using this are very sensitive
# about timing. they intend to capture loop cycles after this specific
# condition was set
while key not in worker.tasks or worker.tasks[key].state != state:
await asyncio.sleep(0)


@gen_cluster(client=True)
async def test_task_flight_compute_oserror(c, s, a, b):
"""If the remote worker dies while a task is in flight, the task may be
@@ -3161,7 +3154,7 @@ async def test_gather_dep_cancelled_rescheduled(c, s):
fut4 = c.submit(sum, fut1, fut2, workers=[b.address], key="f4")
fut3 = c.submit(inc, fut1, workers=[b.address], key="f3")

await _wait_for_state(fut2.key, b, "flight")
await wait_for_state(fut2.key, "flight", b)
await b.in_gather_dep.wait()

fut4.release()
@@ -3174,7 +3167,7 @@ async def test_gather_dep_cancelled_rescheduled(c, s):
await a.in_get_data.wait()

fut4 = c.submit(sum, [fut1, fut2], workers=[b.address], key="f4")
await _wait_for_state(fut2.key, b, "flight")
await wait_for_state(fut2.key, "flight", b)

a.block_get_data.set()
await wait([fut3, fut4])
@@ -3227,7 +3220,7 @@ async def test_gather_dep_no_longer_in_flight_tasks(c, s, a):
fut1 = c.submit(inc, 1, workers=[a.address], key="f1")
fut2 = c.submit(sum, fut1, fut1, workers=[b.address], key="f2")

await _wait_for_state(fut1.key, b, "flight")
await wait_for_state(fut1.key, "flight", b)
await b.in_gather_dep.wait()

fut2.release()
@@ -3266,7 +3259,7 @@ async def test_deadlock_cancelled_after_inflight_before_gather_from_worker(
async with BlockedGatherDep(s.address, name="b") as b:
fut3 = c.submit(inc, fut2, workers=[b.address], key="f3")

await _wait_for_state(fut2.key, b, "flight")
await wait_for_state(fut2.key, "flight", b)

s.set_restrictions(worker={fut1B.key: a.address, fut2.key: b.address})

@@ -3276,7 +3269,7 @@ async def test_deadlock_cancelled_after_inflight_before_gather_from_worker(
address=x.address, safe=True, close=close_worker, stimulus_id="test"
)

await _wait_for_state(fut2.key, b, intermediate_state)
await wait_for_state(fut2.key, intermediate_state, b, interval=0)

b.block_gather_dep.set()
await fut3
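The per-file _wait_for_state helper removed above polled with a zero-second sleep because these tests are sensitive to event-loop timing; the shared helper exposes this as the interval keyword, and the most timing-sensitive call site above keeps that behaviour by passing interval=0:

await wait_for_state(fut2.key, intermediate_state, b, interval=0)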
24 changes: 8 additions & 16 deletions distributed/tests/test_worker_state_machine.py
@@ -16,10 +16,10 @@
BlockedGetData,
_LockedCommPool,
assert_story,
clean,
freeze_data_fetching,
gen_cluster,
inc,
wait_for_state,
)
from distributed.worker_state_machine import (
AcquireReplicasEvent,
@@ -35,27 +35,19 @@
SerializedTask,
StateMachineEvent,
TaskState,
TaskStateState,
UpdateDataEvent,
WorkerState,
merge_recs_instructions,
)


async def wait_for_state(key: str, state: TaskStateState, dask_worker: Worker) -> None:
while key not in dask_worker.tasks or dask_worker.tasks[key].state != state:
await asyncio.sleep(0.005)


@clean()
def test_task_state_tracking():
with clean():
x = TaskState("x")
assert len(TaskState._instances) == 1
assert first(TaskState._instances) == x

del x
assert len(TaskState._instances) == 0
def test_TaskState_tracking(cleanup):
gc.collect()
x = TaskState("x")
assert len(TaskState._instances) == 1
assert first(TaskState._instances) == x
del x
assert len(TaskState._instances) == 0


def test_TaskState_get_nbytes():
27 changes: 27 additions & 0 deletions distributed/utils_test.py
@@ -2392,3 +2392,30 @@ def freeze_batched_send(bcomm: BatchedSend) -> Iterator[LockedComm]:
finally:
write_event.set()
bcomm.comm = orig_comm


async def wait_for_state(
key: str, state: str, dask_worker: Worker | Scheduler, *, interval: float = 0.01
) -> None:
if isinstance(dask_worker, Worker):
tasks = dask_worker.state.tasks
elif isinstance(dask_worker, Scheduler):
tasks = dask_worker.tasks
else:
raise TypeError(dask_worker) # pragma: nocover

try:
while key not in tasks or tasks[key].state != state:
await asyncio.sleep(interval)
except (asyncio.CancelledError, asyncio.TimeoutError):
if key in tasks:
msg = (
f"tasks[{key}].state={tasks[key].state!r} on {dask_worker.address}; "
f"expected {state=}"
)
else:
msg = f"tasks[{key}] not found on {dask_worker.address}"
# 99% of the times this is triggered by @gen_cluster timeout, so raising the
# message as an exception wouldn't work.
print(msg)
raise

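A minimal usage sketch of the new helper (hypothetical test and key names, assuming gen_cluster and inc from distributed.utils_test):

@gen_cluster(client=True)
async def test_example(c, s, a, b):
    fut = c.submit(inc, 1, key="x", workers=[a.address])
    # Poll every 10 ms (the default interval) until the task result is in memory,
    # first as seen by the scheduler, then as seen by worker `a`
    await wait_for_state("x", "memory", s)
    await wait_for_state("x", "memory", a)
    assert await fut == 2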