
[fix]: Fixes an issue with pre_backward hook registering #833

Merged
merged 7 commits into main from min/fsdp_hook on Oct 27, 2021

Conversation

min-xu-ai
Contributor

What does this PR do?

Fixes an issue with pre_backward hook registration. When an inner module and an outer module both register a pre-backward hook on the same output tensor (i.e. the inner module's output is returned unchanged by the outer module) and flatten_parameters is False for the outer module, the hooks can fire in the wrong order: the outer module's pre-backward hook fires again after the inner one's has already fired.

This PR adds a unit test that reproduces the issue and a fix.
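
The following is a minimal, illustrative sketch (not the PR's actual test) of the nesting that triggers the bug, assuming fairscale's FSDP API with flatten_parameters=False as used in the diff below. It also assumes a torch.distributed process group has already been initialized, which FSDP requires.

# Illustrative sketch only: an inner FSDP-wrapped module whose output is
# returned unchanged by the outer FSDP-wrapped module, so both wrappers end up
# registering a pre-backward hook on the same output tensor.
import torch
import torch.nn as nn
from fairscale.nn import FullyShardedDataParallel as FSDP


class Inner(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


class Outer(nn.Module):
    def __init__(self):
        super().__init__()
        # The inner module is itself FSDP-wrapped.
        self.inner = FSDP(Inner(), flatten_parameters=False)

    def forward(self, x):
        # The inner output is returned as-is, so the outer wrapper hooks the
        # very same tensor that the inner wrapper already hooked.
        return self.inner(x)


# Outer wrapper with flatten_parameters=False, matching the failing case.
model = FSDP(Outer(), flatten_parameters=False).cuda()
loss = model(torch.rand(2, 4, device="cuda")).sum()
loss.backward()  # before the fix, the pre-backward hooks could fire in the wrong order here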

Before submitting

  • Did you have fun?
    • Make sure you had fun coding 🙃
  • Did you read the contributor guideline?
  • Was this discussed/approved via a GitHub issue? (not needed for typos or doc improvements)
    • N/A
  • Did you make sure to update the docs?
    • N/A
  • Did you write any new necessary tests?
    • N/A
  • Did you update the changelog? (if needed)
    • N/A

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

@facebook-github-bot added the CLA Signed label on Oct 27, 2021
@min-xu-ai marked this pull request as ready for review on October 27, 2021 04:51
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# pylint: disable=missing-module-docstring
Contributor

remove?

Contributor Author

Thanks for the review! My thinking is that we sometimes have small functions that are self-evident, so we don't need docstrings in every test file; I'd rather keep this disable and move forward.

# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring

""" Test FSDP with pre-backward hook bug. """
Contributor

nit: I would generalize docstrings and comments in this class/test to deal with hooks in general. I am guessing we will continue to add to this class given the complexity of hooks.

Contributor Author

I think we can do that when it is actually extended.

assert p.grad is not None
p.grad = None

model = FSDP(Model(), flatten_parameters=False).cuda()
Contributor

Why does this work if flatten_parameters is True? Wouldn't we have already attached the hook on the inner module's flattened params?

Contributor Author

Very good question. The effect (i.e. the assertion failure) only shows up when flatten is False. When flatten is True, the post-backward hook doesn't run between the multiple pre_backward firings, so it happened to work OK.
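
For readers less familiar with the mechanism: as the PR description notes, both wrappers register their pre-backward hook on the same output tensor when the inner output is passed straight through. A tiny plain-PyTorch illustration (not FSDP internals) of two wrappers hooking one shared output:

import torch

x = torch.randn(3, requires_grad=True)
out = (x * 2).sum()  # stand-in for the shared output tensor

order = []
out.register_hook(lambda grad: order.append("inner pre-backward"))
out.register_hook(lambda grad: order.append("outer pre-backward"))

out.backward()
print(order)  # both callbacks fire on this one tensor's backward; their relative order is what the fix guards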

@anj-s
Contributor

anj-s commented Oct 27, 2021

Another hook issue :) thanks for the fix @min-xu-ai !

@min-xu-ai changed the title from "Fixes an issue with pre_backward hook registering" to "[fix]: Fixes an issue with pre_backward hook registering" on Oct 27, 2021
@min-xu-ai merged commit 5da5c0e into main on Oct 27, 2021
@min-xu-ai deleted the min/fsdp_hook branch on October 27, 2021 17:53
@zhaojuanmao
Contributor

I also encountered this issue while working on non-flattened params in the PyTorch FSDP version. It isn't caught by the unit tests, but the bug was hit when I ran benchmarks locally.

My fix is a little different and makes quite a few changes to pre_backward_hook and rebuild_full_param. One major change is to reset the has_pre_backward_hook_run flag after each forward pass and use that flag to avoid firing pre_backward hooks multiple times per backward pass (regardless of whether it is the inner or the outer backward pass).

But it looks like the fix here works as well!

@min-xu-ai
Contributor Author

Thanks for taking a look, @zhaojuanmao. I briefly considered using a boolean guard to fix this as well, but I wasn't convinced it would hold up for models that run multiple forward passes: for those, the pre_backward hook may legitimately need to fire multiple times. I didn't actually test that, though. Maybe we can chat about this case offline when we meet next time.
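
For context, here is a rough, illustrative sketch of the boolean-guard idea being discussed. It is not the actual code of either implementation; apart from has_pre_backward_hook_run, which is named above, everything here is hypothetical.

# Illustrative sketch only: a per-wrapper flag that suppresses duplicate
# pre-backward firings within one backward pass and is reset on every forward
# pass. The multi-forward-pass concern above is about whether this reset/guard
# interplay can swallow firings that are legitimately needed more than once.
class GuardSketch:
    def __init__(self):
        self.has_pre_backward_hook_run = False

    def forward(self, *args):
        # Reset once per forward pass so each forward/backward pair gets its
        # own pre-backward firing.
        self.has_pre_backward_hook_run = False
        return args  # stand-in for the real forward computation

    def pre_backward_hook(self, *unused):
        if self.has_pre_backward_hook_run:
            return  # already fired for this backward pass; skip the duplicate
        self.has_pre_backward_hook_run = True
        # ... the usual pre-backward work (e.g. rebuilding full params) would go here ...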

vtantia pushed a commit that referenced this pull request Oct 29, 2021
* added the failing test

* fixed the bug

* fine-tune the condition

* typo

* typo

* changelog and added test to test files

Co-authored-by: Min Xu <min.xu.public@gmail.com>