-
Notifications
You must be signed in to change notification settings - Fork 276
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[FSDP] relax checking root condition #620
Merged
min-xu-ai
merged 11 commits into
facebookresearch:master
from
shuyingsunshine21:lightning_fsdp_root_relax
Apr 23, 2021
Merged
Changes from all commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
1813311
relax checking root condition
6baf139
formatting
c2dd952
Merge branch 'master' of https://github.com/facebookresearch/fairscal…
e16c499
add unittest
0abce4a
add unittest to ci test list
39703fe
isort for import of unittest
5ff9f0d
format black .
255b57e
move test to list 1
39c63fe
Merge branch 'master' of https://github.com/facebookresearch/fairscal…
0dcba4d
add skip no cuda
da366c9
black and isort
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
tests/nn/misc/test_flatten_params_wrapper.py | ||
tests/nn/data_parallel/test_fsdp.py | ||
tests/nn/data_parallel/test_fsdp_freezing_weights.py | ||
tests/nn/data_parallel/test_fsdp_multiple_wrapping.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | ||
# | ||
# This source code is licensed under the BSD license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# pylint: disable=missing-module-docstring | ||
# pylint: disable=missing-class-docstring | ||
# pylint: disable=missing-function-docstring | ||
|
||
""" Test FSDP with nested wrapping multiple times. """ | ||
|
||
import tempfile | ||
|
||
import pytest | ||
import torch | ||
import torch.multiprocessing as mp | ||
from torch.nn import Linear, Module, Sequential | ||
from torch.optim import SGD | ||
|
||
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP | ||
from fairscale.nn.data_parallel import TrainingState | ||
from fairscale.utils.testing import dist_init, skip_if_no_cuda, teardown, torch_version | ||
|
||
|
||
def _test_func(rank, world_size, fsdp_config, tempfile_name, unused): | ||
result = dist_init(rank, world_size, tempfile_name, unused) | ||
assert result, "Dist init failed" | ||
|
||
assert isinstance(fsdp_config, dict), str(fsdp_config) | ||
|
||
class InnerModel(Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.layers = Sequential(FSDP(Linear(5, 5), **fsdp_config),) | ||
|
||
def forward(self, x): | ||
return self.layers(x) | ||
|
||
inner_model = InnerModel() | ||
model = FSDP(inner_model, **fsdp_config).cuda() | ||
optim = SGD(model.parameters(), lr=0.1) | ||
|
||
for i in range(3): | ||
input = torch.rand((1, 5), dtype=torch.float).cuda() | ||
input.requires_grad = True | ||
output = model(input) | ||
output.sum().backward() | ||
optim.step() | ||
optim.zero_grad() | ||
input = torch.rand((1, 5), dtype=torch.float).cuda() | ||
output = model(input) | ||
|
||
model.assert_state(TrainingState.IDLE) | ||
|
||
# second time to rewrap the inner model | ||
rewrapped_model = FSDP(inner_model, **fsdp_config).cuda() | ||
rewrapped_output = rewrapped_model(input) | ||
|
||
assert torch.allclose(output, rewrapped_output) | ||
teardown() | ||
|
||
|
||
# We use strings for precision and flatten instead of bool to | ||
# make the pytest output more readable. | ||
@skip_if_no_cuda | ||
@pytest.mark.parametrize("world_size", [1, 2]) | ||
@pytest.mark.parametrize("precision", ["full", "mixed"]) | ||
@pytest.mark.parametrize("flatten", ["flatten", "no_flatten"]) | ||
def test(world_size, precision, flatten): | ||
""" | ||
This test simulates wrapping the module after training to run inference. | ||
This is required in cases where later in a session, the model is wrapped again in FSDP but | ||
contains nested FSDP wrappers within the module. | ||
""" | ||
if torch_version() < (1, 6, 0): | ||
pytest.skip("older pytorch doesn't support reduce_scatter") | ||
|
||
temp_file_name = tempfile.mkstemp()[1] | ||
unused = tempfile.mkstemp()[1] | ||
|
||
fsdp_config = { | ||
"mixed_precision": precision == "mixed", | ||
"flatten_parameters": flatten == "flatten", | ||
} | ||
|
||
mp.spawn( | ||
_test_func, args=(world_size, fsdp_config, temp_file_name, unused), nprocs=world_size, join=True, | ||
) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is this moved by black/isort? If not, CI will fail again. Our CI is pretty strict, it will take a bit of time to get used to. But it is really good once get used to. :-)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks, after black forgot to do isort.