[fix]: Fixes an issue with pre_backward hook registering #833
Conversation
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# pylint: disable=missing-module-docstring
remove?
Thanks for the review! Some of these test files have small, self-evident functions, so I'd rather not require docstrings in every test file and keep things moving.
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring

""" Test FSDP with pre-backward hook bug. """
nit: I would generalize docstrings and comments in this class/test to deal with hooks in general. I am guessing we will continue to add to this class given the complexity of hooks.
I think we can do that when it is actually extended.
assert p.grad is not None
p.grad = None

model = FSDP(Model(), flatten_parameters=False).cuda()
Why does this work if flatten_parameters is True? Wouldn't we have already attached the hook on the inner module's flattened params?
Very good question. The effect (i.e., the assertion failure) only shows up when flatten is False. When flatten is True, the post-backward does not run between the multiple pre_backward hooks, so it ends up working OK.
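The hook-ordering problem under discussion can be sketched without fairscale at all. The snippet below is a framework-free illustration (all names are made up for this sketch, not fairscale's actual code): when the inner module's output is returned unchanged by the outer module, both wrappers attach their pre-backward hooks to the same output, and hooks on one tensor fire in registration order — inner first, then outer, the reverse of the proper backward nesting.

```python
shared_output_hooks = []  # hook list attached to the single shared output

def register_pre_backward(owner):
    # Both wrappers append to the same list because the inner module's
    # output is the very same object the outer module returns.
    shared_output_hooks.append(owner)

# Forward pass: the inner wrapper finishes its forward first, so it
# registers its pre-backward hook before the outer wrapper does.
register_pre_backward("inner")
register_pre_backward("outer")

# Backward pass: hooks attached to one tensor fire in registration order,
# so the inner hook runs before the outer one -- the reverse of the proper
# backward nesting, where the outermost wrapper should act first.
fired = list(shared_output_hooks)
print(fired)  # ['inner', 'outer']
```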
Another hook issue :) thanks for the fix @min-xu-ai !
I also encountered this issue while working on non-flatten params in the PyTorch FSDP version. It was not caught by the unit tests; I only hit the bug when running benchmarks locally. My fix is a little different and involved quite a few changes to pre_backward_hook and rebuild_full_param. One major change is to reset the has_pre_backward_hook_run flag after each forward pass, and to use that flag to avoid firing pre_backward hooks multiple times per backward pass (regardless of whether it is an inner or outer backward pass). But it looks like the fix here works too!
Thanks for taking a look, @zhaojuanmao. I briefly considered using the boolean guard to fix this as well. However, I wasn't convinced the guard works that way for models with multiple forward passes: for those models, pre_backward may need to fire multiple times. I didn't run a test to confirm, though. Maybe we can chat offline about this case when we meet next time.
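The guard idea from the comments above can also be sketched framework-free. This is a hypothetical illustration (the class and list names are invented for the sketch) of resetting a `has_pre_backward_hook_run` flag on each forward pass so the hook fires at most once per backward pass, even when invoked from both an inner and an outer wrapper:

```python
class HookState:
    """Illustrative stand-in for per-module hook bookkeeping."""

    def __init__(self):
        self.has_pre_backward_hook_run = False

    def forward(self):
        # Reset the guard at the start of every forward pass, so the hook
        # is allowed to fire again in the next backward pass.
        self.has_pre_backward_hook_run = False

    def pre_backward(self, fired):
        # Fire at most once per backward pass, no matter how many times
        # the hook is invoked (inner or outer backward).
        if self.has_pre_backward_hook_run:
            return
        self.has_pre_backward_hook_run = True
        fired.append("pre_backward")

fired = []
state = HookState()

# First iteration: the hook is invoked twice but fires only once.
state.forward()
state.pre_backward(fired)
state.pre_backward(fired)

# Second iteration: the forward pass resets the guard, so it fires again.
state.forward()
state.pre_backward(fired)

print(fired)  # ['pre_backward', 'pre_backward']
```

Note that this sketch also shows the open question raised above: a run-at-most-once-per-forward guard only fires per forward/backward iteration, which may not match models that need pre_backward to fire multiple times within one backward pass.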
* added the failing test
* fixed the bug
* fine-tune the condition
* typo
* typo
* changelog and added test to test files

Co-authored-by: Min Xu <min.xu.public@gmail.com>
What does this PR do?
Fixes an issue with pre_backward hook registering. When an inner module and an outer module both register a hook on the same output (i.e., when the inner module's output is returned unchanged by the outer module), and flatten_parameters is False for the outer module, the pre-backward hooks can mis-fire in the wrong order; e.g., the outer module's pre-backward fires again after the inner one's.
This PR adds a unit test and a fix for that.
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.