[FSDP] relax checking root condition #620

Merged
23 changes: 13 additions & 10 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -857,16 +857,19 @@ def _set_is_root(self) -> None:
         for n, m in self.named_modules():
             # `n != ""` excludes self.
             if n != "" and isinstance(m, FullyShardedDataParallel):
-                assert m._is_root is None
-                m._is_root = False
-                # When root instance doesn't have params, allow children instances
-                # to queue the post_backward hook.
-                #
-                # TODO (Min): we should think if we can have a empty param at the root
-                # so that root always have a callback on the backward graph.
-                if not self._has_params:
-                    assert m._queue_wait_for_post_backward_closure is None
-                    m._queue_wait_for_post_backward_closure = self._queue_wait_for_post_backward
+                # We relax the assert for non-root instances, for the case where the nested, already
+                # initialized module is wrapped again in FSDP later, for example after training to run inference.
+                assert m._is_root is None or not m._is_root
+                if m._is_root is None:
+                    m._is_root = False
+                    # When root instance doesn't have params, allow children instances
+                    # to queue the post_backward hook.
+                    #
+                    # TODO (Min): we should think if we can have a empty param at the root
+                    # so that root always have a callback on the backward graph.
+                    if not self._has_params:
+                        assert m._queue_wait_for_post_backward_closure is None
+                        m._queue_wait_for_post_backward_closure = self._queue_wait_for_post_backward
                 if m.process_group != self.process_group:
                     self.children_share_process_group = False

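For context, here is a minimal single-process sketch (not part of the PR) of the re-wrapping pattern the relaxed assert enables. The process-group setup, device placement, and layer sizes below are illustrative assumptions, not fairscale requirements:

# Sketch only: single-process NCCL group on one CUDA device, default FSDP config.
import os

import torch
import torch.distributed as dist
from torch.nn import Linear, Module, Sequential

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="nccl", rank=0, world_size=1)


class InnerModel(Module):
    def __init__(self):
        super().__init__()
        # A nested FSDP instance inside a plain module.
        self.layers = Sequential(FSDP(Linear(5, 5)))

    def forward(self, x):
        return self.layers(x)


inner_model = InnerModel()
model = FSDP(inner_model).cuda()
# The first forward triggers lazy init: _set_is_root() marks the nested
# instance with _is_root = False.
model(torch.rand(1, 5).cuda())

# Wrapping the same inner module a second time (e.g. for inference) used to
# trip `assert m._is_root is None`, because the nested instance already had
# _is_root == False. With the relaxed check, re-wrapping is allowed.
rewrapped = FSDP(inner_model).cuda()
rewrapped(torch.rand(1, 5).cuda())

Note that the added `if m._is_root is None:` guard also keeps the post_backward closure assignment (and its own assert) from running a second time when the children were already initialized by an earlier wrapper.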
1 change: 1 addition & 0 deletions tests/ci_test_list_1.txt
@@ -1,3 +1,4 @@
 tests/nn/misc/test_flatten_params_wrapper.py
 tests/nn/data_parallel/test_fsdp.py
 tests/nn/data_parallel/test_fsdp_freezing_weights.py
+tests/nn/data_parallel/test_fsdp_multiple_wrapping.py
88 changes: 88 additions & 0 deletions tests/nn/data_parallel/test_fsdp_multiple_wrapping.py
@@ -0,0 +1,88 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring

""" Test FSDP with nested wrapping multiple times. """

import tempfile

import pytest
import torch
import torch.multiprocessing as mp
from torch.nn import Linear, Module, Sequential
Contributor:
Is this moved by black/isort? If not, CI will fail again. Our CI is pretty strict; it takes a bit of time to get used to, but it is really good once you do. :-)

Contributor Author:
Thanks, after running black I forgot to run isort.
from torch.optim import SGD

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.data_parallel import TrainingState
from fairscale.utils.testing import dist_init, skip_if_no_cuda, teardown, torch_version


def _test_func(rank, world_size, fsdp_config, tempfile_name, unused):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    assert isinstance(fsdp_config, dict), str(fsdp_config)

    class InnerModel(Module):
        def __init__(self):
            super().__init__()
            self.layers = Sequential(FSDP(Linear(5, 5), **fsdp_config),)

        def forward(self, x):
            return self.layers(x)

    inner_model = InnerModel()
    model = FSDP(inner_model, **fsdp_config).cuda()
    optim = SGD(model.parameters(), lr=0.1)

    for i in range(3):
        input = torch.rand((1, 5), dtype=torch.float).cuda()
        input.requires_grad = True
        output = model(input)
        output.sum().backward()
        optim.step()
        optim.zero_grad()
    input = torch.rand((1, 5), dtype=torch.float).cuda()
    output = model(input)

    model.assert_state(TrainingState.IDLE)

    # second time to rewrap the inner model
    rewrapped_model = FSDP(inner_model, **fsdp_config).cuda()
    rewrapped_output = rewrapped_model(input)

    assert torch.allclose(output, rewrapped_output)
    teardown()


# We use strings for precision and flatten instead of bool to
# make the pytest output more readable.
@skip_if_no_cuda
@pytest.mark.parametrize("world_size", [1, 2])
@pytest.mark.parametrize("precision", ["full", "mixed"])
@pytest.mark.parametrize("flatten", ["flatten", "no_flatten"])
def test(world_size, precision, flatten):
"""
This test simulates wrapping the module after training to run inference.
This is required in cases where later in a session, the model is wrapped again in FSDP but
contains nested FSDP wrappers within the module.
"""
if torch_version() < (1, 6, 0):
pytest.skip("older pytorch doesn't support reduce_scatter")

temp_file_name = tempfile.mkstemp()[1]
unused = tempfile.mkstemp()[1]

fsdp_config = {
"mixed_precision": precision == "mixed",
"flatten_parameters": flatten == "flatten",
}

mp.spawn(
_test_func, args=(world_size, fsdp_config, temp_file_name, unused), nprocs=world_size, join=True,
)
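To run the new test locally, a CUDA device is required (skip_if_no_cuda) and, per the guard in test(), PyTorch >= 1.6 for reduce_scatter. A plain pytest invocation such as the following should pick up all parametrizations (world_size x precision x flatten):

    pytest tests/nn/data_parallel/test_fsdp_multiple_wrapping.py

Each case spawns world_size processes with mp.spawn, runs a few training steps, re-wraps the same inner module in a fresh FSDP instance, and asserts that the re-wrapped model produces the same output as the original wrapper on the same input.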