Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure restart clears taskgroups et al #6944

Merged
merged 2 commits into from
Aug 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 18 additions & 13 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1431,6 +1431,21 @@ def new_task(

return ts

def _clear_task_state(self):
    """Empty every task-related collection held on this object.

    All of the attributes are looked up before the first ``clear()`` call,
    so a missing attribute raises before anything has been emptied.
    """
    logger.debug("Clear task state")
    task_collections = (
        self.unrunnable,
        self.erred_tasks,
        self.computations,
        self.task_prefixes,
        self.task_groups,
        self.task_metadata,
        self.unknown_durations,
        self.replicated_tasks,
    )
    for coll in task_collections:
        coll.clear()

#####################
# State Transitions #
#####################
Expand Down Expand Up @@ -3063,8 +3078,6 @@ def __init__(
resources = {}
aliases = {}

self._task_state_collections = [unrunnable]

self._worker_collections = [
workers,
host_info,
Expand Down Expand Up @@ -3365,7 +3378,7 @@ async def start_unsafe(self):

enable_gc_diagnosis()

self.clear_task_state()
self._clear_task_state()

for addr in self._start_address:
await self.listen(
Expand Down Expand Up @@ -5143,13 +5156,6 @@ async def gather(self, keys, serializers=None):
self.log_event("all", {"action": "gather", "count": len(keys)})
return result

def clear_task_state(self):
    """Empty each collection registered in ``self._task_state_collections``."""
    # XXX what about nested state such as ClientState.wants_what
    # (see also fire-and-forget...)
    logger.info("Clear task state")
    for registered in self._task_state_collections:
        registered.clear()

@log_errors
async def restart(self, client=None, timeout=30, wait_for_workers=True):
"""
Expand Down Expand Up @@ -5189,9 +5195,8 @@ async def restart(self, client=None, timeout=30, wait_for_workers=True):
stimulus_id=stimulus_id,
)

self.clear_task_state()
self.erred_tasks.clear()
self.computations.clear()
self._clear_task_state()
assert not self.tasks
self.report({"op": "restart"})

for plugin in list(self.plugins.values()):
Expand Down
5 changes: 5 additions & 0 deletions distributed/tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,11 +615,16 @@ async def test_ready_remove_worker(s, a, b):

@gen_cluster(client=True, Worker=Nanny, timeout=60)
async def test_restart(c, s, a, b):

futures = c.map(inc, range(20))
await wait(futures)

await s.restart()

assert not s.computations
assert not s.task_prefixes
assert not s.task_groups

assert len(s.workers) == 2

for ws in s.workers.values():
Expand Down