
Ensure restart clears taskgroups et al #6944

Merged: 2 commits, merged on Aug 25, 2022

Changes from 1 commit
31 changes: 18 additions & 13 deletions distributed/scheduler.py
@@ -1431,6 +1431,21 @@ def new_task(
 
         return ts
 
+    def _clear_task_state(self):
+
+        logger.debug("Clear task state")
+        for collection in [
+            self.unrunnable,
+            self.erred_tasks,
+            self.computations,
+            self.task_prefixes,
+            self.task_groups,
+            self.task_metadata,
+            self.unknown_durations,
+            self.replicated_tasks,
+        ]:
+            collection.clear()
+
     #####################
     # State Transitions #
     #####################
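Aside, not part of the diff: the new helper consolidates clears that restart() previously performed one collection at a time. A minimal standalone sketch of the pattern, with hypothetical names rather than distributed's real attributes:

# Sketch of the consolidation pattern (hypothetical MiniScheduler, not
# distributed's API): one private helper owns the list of task-state
# collections, so a newly added collection is registered in one place.
class MiniScheduler:
    def __init__(self):
        self.unrunnable = set()
        self.erred_tasks = {}
        self.task_groups = {}

    def _clear_task_state(self):
        # set and dict both expose .clear(), so heterogeneous
        # collections can be wiped uniformly in one loop.
        for collection in [self.unrunnable, self.erred_tasks, self.task_groups]:
            collection.clear()

sched = MiniScheduler()
sched.unrunnable.add("x")
sched._clear_task_state()
assert not sched.unrunnable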
@@ -3063,8 +3078,6 @@ def __init__(
         resources = {}
         aliases = {}
 
-        self._task_state_collections = [unrunnable]
-
         self._worker_collections = [
             workers,
             host_info,
@@ -3365,7 +3378,7 @@ async def start_unsafe(self):
 
         enable_gc_diagnosis()
 
-        self.clear_task_state()
+        self._clear_task_state()
 
         for addr in self._start_address:
             await self.listen(
@@ -5143,13 +5156,6 @@ async def gather(self, keys, serializers=None):
         self.log_event("all", {"action": "gather", "count": len(keys)})
         return result
 
-    def clear_task_state(self):
-        # XXX what about nested state such as ClientState.wants_what
-        # (see also fire-and-forget...)
-        logger.info("Clear task state")
-        for collection in self._task_state_collections:
-            collection.clear()
-
     @log_errors
     async def restart(self, client=None, timeout=30, wait_for_workers=True):
         """
@@ -5189,9 +5195,8 @@ async def restart(self, client=None, timeout=30, wait_for_workers=True):
                 stimulus_id=stimulus_id,
             )
 
-        self.clear_task_state()
-        self.erred_tasks.clear()
-        self.computations.clear()
+        self._clear_task_state()
+        assert not self.tasks
         self.report({"op": "restart"})
 
         for plugin in list(self.plugins.values()):
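For context, a small usage sketch of the path this diff touches (a local cluster is assumed): Client.restart() drives Scheduler.restart(), which after this change wipes all task-state collections via _clear_task_state().

# Usage sketch: after restart(), the scheduler's task state is empty and
# previously created futures are cancelled and must be resubmitted.
from distributed import Client

def inc(x):
    return x + 1

client = Client()                      # spins up a local cluster
futures = client.map(inc, range(10))
print(client.gather(futures))          # [1, 2, ..., 10]

client.restart()                       # restart workers; task state is cleared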
8 changes: 8 additions & 0 deletions distributed/tests/test_scheduler.py
@@ -615,11 +615,19 @@ async def test_ready_remove_worker(s, a, b):
 
 @gen_cluster(client=True, Worker=Nanny, timeout=60)
 async def test_restart(c, s, a, b):
+    from distributed.scheduler import TaskState
+
+    before = TaskState._instances
Review thread on the line "before = TaskState._instances" (later marked resolved by fjetter):

crusaderky (Collaborator), Aug 24, 2022:
You're creating a new reference to the same WeakSet. The assertion below is always true because it's the same object. Also, I suspect this may be prone to race conditions with late or delayed cleanup from previous tests?

fjetter (Member, Author):
You're right, of course. I don't even know whether this is a sane test, to be honest. We have looked into TaskState garbage collection before, and it is a bit tricky. I'll probably just remove this.
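To make crusaderky's point concrete, here is a standalone sketch using only the stdlib (nothing from distributed): binding a WeakSet to a second name merely aliases it, so a later equality check compares the set with itself, while a length snapshot does detect cleanup.

import gc
from weakref import WeakSet

class Thing:
    pass

instances = WeakSet()
t = Thing()
instances.add(t)

before = instances         # alias: the very same WeakSet object
n_before = len(instances)  # a snapshot that keeps no member alive

del t                      # drop the only strong reference
gc.collect()               # CPython frees immediately; explicit for clarity

assert instances is before           # same object, so "== before" is vacuous
assert len(instances) != n_before    # the count snapshot sees the change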

     futures = c.map(inc, range(20))
     await wait(futures)
 
     await s.restart()
 
+    assert TaskState._instances == before
+    assert not s.computations
+    assert not s.task_prefixes
+    assert not s.task_groups
+
     assert len(s.workers) == 2
 
     for ws in s.workers.values():
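If the instance check were kept rather than removed, a more robust variant (hypothetical, following the review discussion above) would snapshot a count instead of aliasing the WeakSet and force garbage collection around the comparison:

# Hypothetical rework of the resolved check: count TaskState instances and
# collect garbage explicitly so objects lingering from earlier tests are
# less likely to skew the comparison.
import gc

from distributed import Nanny, wait
from distributed.scheduler import TaskState
from distributed.utils_test import gen_cluster, inc

@gen_cluster(client=True, Worker=Nanny, timeout=60)
async def test_restart_clears_task_state(c, s, a, b):
    gc.collect()
    n_before = len(TaskState._instances)

    futures = c.map(inc, range(20))
    await wait(futures)
    await s.restart()

    del futures
    gc.collect()
    assert len(TaskState._instances) <= n_before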