
[fix][ShardedDDP] Properly handle .eval() mode #587

Merged
Merged 3 commits into master from shardedddp_handle_training_switch on Apr 7, 2021

Conversation

@blefaudeux (Contributor) commented on Apr 7, 2021

Before submitting

  • Was this discussed/approved via a GitHub issue? (not needed for typos or doc improvements)
  • Did you read the contributor guideline?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?

What does this PR do?

Fixes eval mode not being properly handled by ShardedDDP: at best it kept all the grad hooks in place, at worst it could crash on a trainability change that should have been ignored. Adds a unit test to catch that (a rough sketch of the scenario follows below).

Closes #586
cc @SeanNaren @ananthsub
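
For context, here is a rough sketch of the scenario being fixed, not the actual unit test; the single-rank gloo group and the toy model are illustrative assumptions:

```python
# Rough sketch of the scenario this PR fixes (not the actual unit test).
# The single-rank gloo group and toy model are illustrative assumptions.
import os
import torch
import torch.distributed as dist
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.optim import OSS

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

model = torch.nn.Linear(8, 8)
optimizer = OSS(model.parameters(), optim=torch.optim.SGD, lr=0.1)
ddp = ShardedDDP(model, optimizer)

# Regular training step: grad hooks are attached and reductions run as usual.
ddp.train()
ddp(torch.randn(4, 8)).sum().backward()
optimizer.step()

# Switch to eval: before this fix the grad hooks stayed in place, and a
# requires_grad change while in eval() could crash instead of being ignored.
ddp.eval()
for p in model.parameters():
    p.requires_grad = False
with torch.no_grad():
    _ = ddp(torch.randn(4, 8))
```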

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

@facebook-github-bot added the CLA Signed label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Apr 7, 2021
@@ -490,6 +486,9 @@ def _setup_backward_hooks(self) -> None:
# Go through the parameters, attach the hook
self._grad_accs = []
self._manual_reduce = []
if not self.training:
Contributor:

Do we need to remove the existing hooks in eval mode? Just curious; otherwise we could move this to the top of the function.

Contributor Author (blefaudeux):

I thought it was better for correctness: if there's a .backward() call left somewhere, it still respects the eval() setting? The documentation is not super clear, to me at least: https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=eval#torch.nn.Module.train
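
For reference, the guard under discussion plausibly amounts to an early return in `_setup_backward_hooks`; this is a paraphrase of the truncated diff above, and the early return itself is an assumption rather than the verbatim merged code:

```python
def _setup_backward_hooks(self) -> None:
    # Go through the parameters, attach the hook
    self._grad_accs = []
    self._manual_reduce = []

    # In eval() mode, skip re-attaching the grad hooks altogether, so that a
    # stray .backward() call does not trigger any gradient reduction.
    if not self.training:
        return
    ...
```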

@@ -624,3 +623,19 @@ def _flush_reduce_calls(self) -> None:
bucket.sent = True

self._consume_work_handles()

def _detect_train_change(self) -> bool:
Contributor:

Nice!

trainability_changed = trainable_mask != self._reference_trainable_mask

# - the whole model is not trainable but we still have grad hooks
trainability_changed |= not self.training and len(self._grad_hooks) > 0
Contributor:

Does this mean that grad_hooks should be greater than 0 in eval mode? Not sure I understand why this should be the case.

Contributor Author (blefaudeux):

It was meant to detect that the trainability changed, i.e. we're in eval() mode but there are grad hooks in place, so we should refresh. It's tied to the question above; I'm not sure of the reference behavior here.

Contributor:

From my offline conversation with @blefaudeux to understand this better:

  • We can't detect when a module switches from train->eval unless we use the presence of hooks as an indicator.
  • We refresh the trainability state 1) at the beginning, 2) when params change their requires_grad property, 3) on a train<->eval switch (sketched below).

Thanks for the explanation @blefaudeux !
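
To summarize, the detection logic described above looks roughly like this; the `_all_params` attribute is an assumed name and the body is a paraphrase of the diff, not the exact merged code:

```python
def _detect_train_change(self) -> bool:
    # Which parameters currently require a gradient (assumed attribute name)
    trainable_mask = [p.requires_grad for p in self._all_params]

    # - one or more parameters flipped their requires_grad flag
    trainability_changed = trainable_mask != self._reference_trainable_mask

    # - the whole model is not trainable but we still have grad hooks,
    #   i.e. a train() -> eval() switch happened since the last refresh
    trainability_changed |= not self.training and len(self._grad_hooks) > 0

    return trainability_changed
```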

@SeanNaren

Fixes the issue upstream in Lightning, thanks so much for the quick fix @blefaudeux :)

@blefaudeux merged commit ce1f2ce into master on Apr 7, 2021
@blefaudeux deleted the shardedddp_handle_training_switch branch on April 7, 2021 at 22:12
Labels
CLA Signed
Development

Successfully merging this pull request may close these issues.

[ShardedDDP] Handle transition to eval + parameter change
4 participants