[fix][ShardedDDP] Properly handle .eval() mode #587
Conversation
@@ -490,6 +486,9 @@ def _setup_backward_hooks(self) -> None:
    # Go through the parameters, attach the hook
    self._grad_accs = []
    self._manual_reduce = []
    if not self.training:
Do we need to remove existing hooks in eval mode? Just curious; otherwise we could move this to the top of the function.
I thought that it was better for correctness: if there's a stray .backward() call somewhere, it still respects the eval() setting. The documentation is not super clear, to me at least: https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=eval#torch.nn.Module.train
@@ -624,3 +623,19 @@ def _flush_reduce_calls(self) -> None:
    bucket.sent = True

    self._consume_work_handles()

def _detect_train_change(self) -> bool:
Nice!
trainability_changed = trainable_mask != self._reference_trainable_mask

# - the whole model is not trainable but we still have grad hooks
trainability_changed |= not self.training and len(self._grad_hooks) > 0
Does this mean that grad_hooks should be greater than 0 in eval mode? Not sure I understand why this should be the case.
It was meant to detect that the trainability changed, i.e. we're in eval() mode but there are grad hooks still in place, so we should refresh. It's tied to the question above; I'm not sure of the reference behavior here.
From my offline conversation with @blefaudeux to understand this better:
- We can't detect when a module switches from train->eval unless we use the presence of hooks as an indicator.
- We refresh the trainable state 1) at the beginning, 2) when a parameter changes its requires_grad property, and 3) on a train<->eval switch.
Thanks for the explanation @blefaudeux !
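The refresh triggers described above can be illustrated with a small standalone sketch. This is not the actual fairscale implementation; the class and attribute names (`TrainabilityTracker`, `_reference_trainable_mask`, `_grad_hooks`) are chosen to mirror the PR's naming but the code is a simplified illustration:

```python
# Hedged sketch of the trainability-change detection discussed above.
# A refresh is needed when either (a) some parameter flipped its
# requires_grad flag since the last snapshot, or (b) the module was
# switched to eval() while grad hooks are still attached.

class TrainabilityTracker:
    def __init__(self, requires_grad_flags):
        # Reference snapshot taken when hooks were last (re)built.
        self._reference_trainable_mask = list(requires_grad_flags)
        self.training = True
        # One hook per trainable parameter (placeholder objects here).
        self._grad_hooks = ["hook"] * sum(requires_grad_flags)

    def detect_train_change(self, requires_grad_flags):
        trainable_mask = list(requires_grad_flags)
        # 1) a parameter changed its requires_grad property
        changed = trainable_mask != self._reference_trainable_mask
        # 2) model is in eval() mode but grad hooks are still in place
        changed |= (not self.training) and len(self._grad_hooks) > 0
        return changed

tracker = TrainabilityTracker([True, True, False])
assert not tracker.detect_train_change([True, True, False])
# Freezing a parameter triggers a refresh
assert tracker.detect_train_change([True, False, False])
# Switching to eval while hooks remain also triggers a refresh
tracker.training = False
assert tracker.detect_train_change([True, True, False])
```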
Fixes the issue upstream in Lightning, thanks so much for the quick fix @blefaudeux :)
Before submitting
What does this PR do?
Fixes eval mode not being properly handled by ShardedDDP: at best it kept all the grad hooks in place, at worst it could crash when there was a trainability change which should have been ignored. Adds a unit test to catch that.
Closes #586
cc @SeanNaren @ananthsub
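The guard added by this PR can be sketched as follows. This is an illustrative standalone mock, not the real ShardedDDP code; `HookManager` and its method names are hypothetical stand-ins for the hook-setup path touched by the diff:

```python
# Hedged sketch of the eval-mode guard: when the module is in eval()
# mode, clear any previously registered backward hooks and attach no
# new ones, so no gradient reduction happens during evaluation.

class HookManager:
    def __init__(self):
        self.training = True
        self._grad_hooks = []

    def setup_backward_hooks(self, params):
        # Drop any previously registered hooks before rebuilding.
        self._grad_hooks.clear()
        if not self.training:
            # Eval mode: leave no hooks in place.
            return
        for p in params:
            if p.get("requires_grad", False):
                self._grad_hooks.append(("reduce_hook", p["name"]))

mgr = HookManager()
mgr.setup_backward_hooks([{"name": "w", "requires_grad": True}])
assert len(mgr._grad_hooks) == 1
mgr.training = False  # analogous to module.eval()
mgr.setup_backward_hooks([{"name": "w", "requires_grad": True}])
assert mgr._grad_hooks == []
```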
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃