diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
index fa569e6e6..5e952b1f2 100644
--- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
+++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -857,16 +857,19 @@ def _set_is_root(self) -> None:
         for n, m in self.named_modules():
             # `n != ""` excludes self.
             if n != "" and isinstance(m, FullyShardedDataParallel):
-                assert m._is_root is None
-                m._is_root = False
-                # When root instance doesn't have params, allow children instances
-                # to queue the post_backward hook.
-                #
-                # TODO (Min): we should think if we can have a empty param at the root
-                # so that root always have a callback on the backward graph.
-                if not self._has_params:
-                    assert m._queue_wait_for_post_backward_closure is None
-                    m._queue_wait_for_post_backward_closure = self._queue_wait_for_post_backward
+                # We relax the assert for non-root instances: a nested, already-initialized module may be
+                # wrapped in FSDP again later, for example after training to run inference.
+                assert m._is_root is None or not m._is_root
+                if m._is_root is None:
+                    m._is_root = False
+                    # When root instance doesn't have params, allow children instances
+                    # to queue the post_backward hook.
+                    #
+                    # TODO (Min): we should think if we can have a empty param at the root
+                    # so that root always have a callback on the backward graph.
+                    if not self._has_params:
+                        assert m._queue_wait_for_post_backward_closure is None
+                        m._queue_wait_for_post_backward_closure = self._queue_wait_for_post_backward
                 if m.process_group != self.process_group:
                     self.children_share_process_group = False
diff --git a/tests/ci_test_list_1.txt b/tests/ci_test_list_1.txt
index efbd5210a..0370eb200 100644
--- a/tests/ci_test_list_1.txt
+++ b/tests/ci_test_list_1.txt
@@ -1,3 +1,4 @@
 tests/nn/misc/test_flatten_params_wrapper.py
 tests/nn/data_parallel/test_fsdp.py
 tests/nn/data_parallel/test_fsdp_freezing_weights.py
+tests/nn/data_parallel/test_fsdp_multiple_wrapping.py
diff --git a/tests/nn/data_parallel/test_fsdp_multiple_wrapping.py b/tests/nn/data_parallel/test_fsdp_multiple_wrapping.py
new file mode 100644
index 000000000..3e035da4d
--- /dev/null
+++ b/tests/nn/data_parallel/test_fsdp_multiple_wrapping.py
@@ -0,0 +1,88 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pylint: disable=missing-module-docstring
+# pylint: disable=missing-class-docstring
+# pylint: disable=missing-function-docstring
+
+""" Test FSDP with nested wrapping multiple times. """
+
+import tempfile
+
+import pytest
+import torch
+import torch.multiprocessing as mp
+from torch.nn import Linear, Module, Sequential
+from torch.optim import SGD
+
+from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
+from fairscale.nn.data_parallel import TrainingState
+from fairscale.utils.testing import dist_init, skip_if_no_cuda, teardown, torch_version
+
+
+def _test_func(rank, world_size, fsdp_config, tempfile_name, unused):
+    result = dist_init(rank, world_size, tempfile_name, unused)
+    assert result, "Dist init failed"
+
+    assert isinstance(fsdp_config, dict), str(fsdp_config)
+
+    class InnerModel(Module):
+        def __init__(self):
+            super().__init__()
+            self.layers = Sequential(FSDP(Linear(5, 5), **fsdp_config),)
+
+        def forward(self, x):
+            return self.layers(x)
+
+    inner_model = InnerModel()
+    model = FSDP(inner_model, **fsdp_config).cuda()
+    optim = SGD(model.parameters(), lr=0.1)
+
+    for i in range(3):
+        input = torch.rand((1, 5), dtype=torch.float).cuda()
+        input.requires_grad = True
+        output = model(input)
+        output.sum().backward()
+        optim.step()
+        optim.zero_grad()
+    input = torch.rand((1, 5), dtype=torch.float).cuda()
+    output = model(input)
+
+    model.assert_state(TrainingState.IDLE)
+
+    # Wrap the inner model a second time to simulate re-wrapping for inference.
+    rewrapped_model = FSDP(inner_model, **fsdp_config).cuda()
+    rewrapped_output = rewrapped_model(input)
+
+    assert torch.allclose(output, rewrapped_output)
+    teardown()
+
+
+# We use strings for precision and flatten instead of bool to
+# make the pytest output more readable.
+@skip_if_no_cuda
+@pytest.mark.parametrize("world_size", [1, 2])
+@pytest.mark.parametrize("precision", ["full", "mixed"])
+@pytest.mark.parametrize("flatten", ["flatten", "no_flatten"])
+def test(world_size, precision, flatten):
+    """
+    This test simulates wrapping the module after training to run inference.
+    This covers the case where, later in a session, a module that already
+    contains nested FSDP wrappers is wrapped in FSDP again.
+    """
+    if torch_version() < (1, 6, 0):
+        pytest.skip("older pytorch doesn't support reduce_scatter")
+
+    temp_file_name = tempfile.mkstemp()[1]
+    unused = tempfile.mkstemp()[1]
+
+    fsdp_config = {
+        "mixed_precision": precision == "mixed",
+        "flatten_parameters": flatten == "flatten",
+    }
+
+    mp.spawn(
+        _test_func, args=(world_size, fsdp_config, temp_file_name, unused), nprocs=world_size, join=True,
+    )