
[FSDP] relax checking root condition #620

Merged

Conversation

@shuyingsunshine21 (Contributor) commented Apr 21, 2021

Before submitting

  • Was this discussed/approved via a GitHub issue? (not needed for typos or doc improvements)
  • Did you read the contributor guideline?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?

What does this PR do?

When integrating with Lightning, we found that after training the model stays nested in FSDP wrappers, so when we call trainer.test(model) and wrap it again, it fails the assertion that the root has not been set: the non-root instances have already been marked. This PR relaxes that assertion. (Link to discussion: Lightning-AI/pytorch-lightning#6152)
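
A minimal sketch of the failing pattern (illustrative only, assuming a process group is already initialized; see the runnable test @SeanNaren posts further down in this thread):

import torch.nn as nn

from fairscale.nn import FullyShardedDataParallel as FSDP

# Outer wrapper used for training; the inner Linear is wrapped in its own FSDP instance.
module = nn.Sequential(FSDP(nn.Linear(5, 5)))
model = FSDP(module)
# ... training runs; lazy initialization marks the nested instance as non-root (_is_root = False) ...

# Wrapping the same module again for inference (e.g. for trainer.test(model)) creates a new
# outer instance whose initialization previously asserted that every nested FSDP instance had
# _is_root still unset, which fails because the nested instance is already marked non-root.
model = FSDP(module)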

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

@facebook-github-bot added the CLA Signed label Apr 21, 2021
@shuyingsunshine21 (Contributor, Author)

cc @min-xu-ai, @SeanNaren

@min-xu-ai (Contributor) left a comment

Nice. Thanks! Once CI passes I can merge if you don't see the merge button.

if not self._has_params:
    assert m._queue_wait_for_post_backward_closure is None
    m._queue_wait_for_post_backward_closure = self._queue_wait_for_post_backward
assert m._is_root is None or m._is_root == False
@min-xu-ai (Contributor)

Suggested change:

-assert m._is_root is None or m._is_root == False
+# We relax the assert for non-root instance. A lightning unit test triggers this otherwise.
+assert m._is_root is None or m._is_root == False

@shuyingsunshine21 (Contributor, Author)

> Nice. Thanks! Once CI passes I can merge if you don't see the merge button.

Looks like CI passes.

if not self._has_params:
    assert m._queue_wait_for_post_backward_closure is None
    m._queue_wait_for_post_backward_closure = self._queue_wait_for_post_backward
# We relax the assert for non-root instance. A lightning unit test triggers this otherwise.

@ananthsub

Not sure if we need to mention Lightning here inside of fairscale. Eventually this comment will also be unclear about what it was relaxed from or why it was relaxed.

@min-xu-ai (Contributor)

I don't want to leave no comment at all, since then it might be hard to figure out in the future why this was relaxed. Maybe you can suggest something clearer?

@SeanNaren commented Apr 21, 2021

@ananthsub makes a good point. Apologies for being lazy on this one; here is a pure PyTorch test we can use to simulate this and remove the PL line:

import os
import unittest
from unittest import mock

import torch
import torch.nn as nn
from fairscale.nn import FullyShardedDataParallel
import torch.nn.functional as F


@mock.patch.dict(os.environ, {"MASTER_ADDR": "localhost", "MASTER_PORT": "1337"}, clear=True)
@unittest.skipIf(not torch.cuda.is_available(), "Test Requires CUDA")
def test_wrapping_module():
    """
    This test simulates wrapping the module after training to run inference.
    This is required in cases where later in a session, the model is wrapped again in FSDP but
    contains nested FSDP wrappers within the module.
    """
    device = torch.device("cuda")
    torch.cuda.set_device(0)

    torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)

    module = nn.Sequential(
        FullyShardedDataParallel(nn.Linear(5, 5)),
    )

    model = FullyShardedDataParallel(module).to(device)

    input = torch.rand((1, 5), dtype=torch.float).to(device)
    output = model(input)
    loss = F.mse_loss(input, output)
    loss.backward()

    model = FullyShardedDataParallel(module).to(device)
    second_output = model(input)

    assert torch.allclose(output, second_output)

    torch.distributed.destroy_process_group()

We can add this as a unit test to ensure this behaviour works!

@min-xu-ai (Contributor)

@SeanNaren, this is lovely. I can add a test file once my work on bug 617 is done.

@shuyingsunshine21, you can add a new test file as well, but please use the other FSDP tests as an example. We can't use Sean's code above as-is, since we don't want a hard-coded TCP port, which may cause port conflicts when multiple people run the tests on the same machine. Also, a new test file needs to be added to one of the test list text files under the tests dir. It is totally fine to leave it to me if you can wait a bit.
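
For illustration, one common way to avoid a hard-coded port is to let the OS pick a free one (a sketch only; the helper name find_free_port is made up here, and fairscale's existing test helpers such as dist_init already handle this setup):

import socket


def find_free_port() -> int:
    # Bind to port 0 so the OS assigns an unused port, then report which one it chose.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("localhost", 0))
        return s.getsockname()[1]


# e.g. set os.environ["MASTER_PORT"] = str(find_free_port()) before init_process_group,
# instead of hard-coding port 1337.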

@shuyingsunshine21 (Contributor, Author)

The CI test failure might be related to #624.

@min-xu-ai (Contributor)

> The CI test failure might be related to #624.

Yeah, sorry about that. It will be merged within the next hour.

@@ -5,6 +5,7 @@ tests/nn/data_parallel/test_fsdp_no_sync.py
 tests/nn/data_parallel/test_fsdp_summon_full_params.py
 tests/nn/data_parallel/test_fsdp_input.py
 tests/nn/data_parallel/test_fsdp_multiple_forward.py
+tests/nn/data_parallel/test_fsdp_multiple_wrapping.py
@min-xu-ai (Contributor)

Can you put the file in list_1.txt, since it is the shortest right now?

@shuyingsunshine21 (Contributor, Author)

I was curious about where to put it and what the difference between the lists is, so I put it in a similar place as the rest of the FSDP tests.

@shuyingsunshine21 (Contributor, Author)

> Yeah, sorry about that. It will be merged within the next hour.

No problem.

@min-xu-ai (Contributor)

Please merge with master. I think Ben has fixed it already. My PR is going in soon too.

 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.nn.data_parallel import TrainingState
-from fairscale.utils.testing import dist_init, teardown, torch_version
+from fairscale.utils.testing import dist_init, teardown, torch_version, skip_if_no_cuda
 from torch.nn import Linear, Module, Sequential
@min-xu-ai (Contributor)

Is this moved by black/isort? If not, CI will fail again. Our CI is pretty strict; it will take a bit of time to get used to, but it is really good once you do. :-)

@shuyingsunshine21 (Contributor, Author)

Thanks, after running black I forgot to run isort.
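
For reference, a rough sketch of how the import block would look after black and isort, assuming the repo's isort config treats fairscale as the first-party package (the exact grouping depends on that configuration):

# Third-party imports (torch) come before first-party (fairscale) imports,
# and names within each `from ... import ...` are sorted alphabetically.
from torch.nn import Linear, Module, Sequential

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.data_parallel import TrainingState
from fairscale.utils.testing import dist_init, skip_if_no_cuda, teardown, torch_version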

@shuyingsunshine21 (Contributor, Author)

Weird that it does not trigger CI.

@min-xu-ai (Contributor)

Yeah, I have seen it today too. Perhaps a CI bug. I ended up making and pushing a new commit to trigger it.

@shuyingsunshine21 (Contributor, Author)

All passed :)

@min-xu-ai merged commit d3b86d6 into facebookresearch:master Apr 23, 2021
@min-xu-ai (Contributor)

> All passed :)

Nice! Thanks again!
