
[FSDP] Can we access parameter views when using flatten_parameters=True? #430

Open · SeanNaren opened this issue Feb 24, 2021 · 9 comments
Labels
FSDP FullyShardedDataParallel (zero-3)

Comments

@SeanNaren

❓ Questions and Help

This should explain the case:

import torch
from fairscale.nn.data_parallel import FullyShardedDataParallel
import os

os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '1337'
torch.distributed.init_process_group("gloo", rank=0, world_size=1)


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(5, 5)


model = FullyShardedDataParallel(Model(), flatten_parameters=False)

# prints parameter list of module
print([p for p in model.module.layer.parameters()])

model = FullyShardedDataParallel(Model(), flatten_parameters=True)

# prints []
print([p for p in model.module.layer.parameters()])

# Throws an error: ValueError: optimizer got an empty parameter list
optimizer = torch.optim.SGD(model.layer.parameters(), lr=1e-3)

When flatten_parameters=True, the original parameters are removed because they have been migrated into a contiguous buffer. This means that calling .parameters() on a specific submodule (for the case where we only want to set up optimizers for certain parts of the model) returns nothing.

Is there any remedy for this? We were experimenting with using views to replace this functionality, but a view doesn't give back an nn.Parameter, I think. Alternatively, we could tell users who run into issues like the above to turn off flatten_parameters.
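
For reference, this is why a plain view wouldn't be a drop-in replacement: a view into the flat buffer is an ordinary tensor, not an nn.Parameter, so .parameters() would still not pick it up. A minimal illustration in plain PyTorch (not FSDP internals):

import torch

# Stand-in for the contiguous flat buffer that flattening creates.
flat = torch.nn.Parameter(torch.zeros(30))

# A slice/view recovers the original 5x5 shape...
weight_view = flat[:25].view(5, 5)

# ...but it is a plain tensor, not an nn.Parameter, so a module holding it
# would not report it via .parameters().
print(isinstance(weight_view, torch.nn.Parameter))  # False
print(weight_view.requires_grad)  # True, gradients still flow to the flat buffer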

@myleott
Contributor

myleott commented Feb 24, 2021

Hmm, this is a bit tricky. For the first part, this works (the params will get re-flattened after the context manager exits):

with model.unflatten_params():
    print([p for p in model.module.layer.parameters()])

The optimizer use case is harder though. The optimizer won't be happy with the re-flattening that happens when the context manager exits:

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(5, 5)

    def forward(self, x):
        return self.layer(x)

(...)

print("param norm (before)", torch.norm(torch.stack([p.norm() for p in model.parameters()])))

with model.unflatten_params():
    optimizer = torch.optim.SGD(model.layer.parameters(), lr=0.1)
optimizer.zero_grad()
loss = model(torch.rand(8, 5)).sum()
loss.backward()
optimizer.step()
print("param norm (after, broken)", torch.norm(torch.stack([p.norm() for p in model.parameters()])))

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
optimizer.zero_grad()
loss = model(torch.rand(8, 5)).sum()
loss.backward()
optimizer.step()
print("param norm (after, works)", torch.norm(torch.stack([p.norm() for p in model.parameters()])))

Setting flatten_parameters=False should work, but will be quite slow until we improve our bucketing logic...
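
For completeness, the flatten_parameters=False escape hatch looks roughly like this with the toy model above (just a sketch; the per-module optimizer works because the original nn.Parameters stay attached to their modules):

# Sketch: with flatten_parameters=False the submodule keeps its own
# nn.Parameters, so a per-module optimizer can be constructed as usual.
model = FullyShardedDataParallel(Model(), flatten_parameters=False)
optimizer = torch.optim.SGD(model.module.layer.parameters(), lr=0.1)

optimizer.zero_grad()
loss = model(torch.rand(8, 5)).sum()
loss.backward()
optimizer.step()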

@SeanNaren
Author

Thanks @myleott, that was helpful!

I think from a Lightning perspective we'll make it clear in the docs to only use the FSDP model when setting up optimizers for now, whilst we figure out a longer-term solution in the FSDP code. Does that sound reasonable?
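
Concretely, the guidance would boil down to something like this (a sketch using the toy model from the top of the issue):

model = FullyShardedDataParallel(Model(), flatten_parameters=True)

# Recommended for now: build the optimizer from the wrapped FSDP model,
# which does see the flattened parameter(s).
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# Not supported with flatten_parameters=True: the submodule's parameter
# list is empty because its weights were moved into the flat buffer.
# optimizer = torch.optim.SGD(model.module.layer.parameters(), lr=1e-3)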

Regarding your comment about using communication hooks in the future (#413 (comment)): that would technically fix this issue, right? Obviously it's quite a bit of work to move to comm hooks / redo the bucketing!

@myleott
Contributor

myleott commented Feb 24, 2021

Yes, right now flattening is essential for performance, but once we speed up our bucketing solution the gap should be smaller, making it practical to set flatten_parameters=False.

To give you a sense of speed difference, here's a benchmark on 8xV100s for WMT'16 En-De translation in fairseq:

| wrapper     | flatten params? | words per second |
|-------------|-----------------|------------------|
| PyTorch DDP | no              | 150k             |
| PyTorch DDP | yes             | 178k             |
| FSDP        | no              | 73k              |
| FSDP        | yes             | 192k             |

@min-xu-ai
Contributor

min-xu-ai commented Feb 27, 2021

BTW, the new summon_full_params context manager is merged in. Is there still something in this issue that needs to keep it open?
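
(Rough usage sketch, in case it helps, assuming the FSDP-wrapped model from earlier in the thread:)

# Sketch: inside the context manager the full (unsharded) parameters are
# available for inspection; they are freed/re-sharded again on exit.
with model.summon_full_params():
    print([p.shape for p in model.parameters()])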

@min-xu-ai
Contributor

Sorry, I think I misunderstood the issue at first; please ignore my comment about summon_full_params above.

However, what's left to do in this issue? I am a bit unclear.

@SeanNaren
Author

Hey @min-xu-ai!

Unfortunately this still won't work, because the original weights have been bucketed into the flat buffer in place, which removes the pointers to the original parameters without a replacement, I think.

optimizer = torch.optim.SGD(model.layer.parameters(), lr=1e-3)

will not work unless we set flatten_parameters=False. I'm sure this will be solved in the future once we're able to use a different bucketing technique that exposes views of the original parameters so they can still be accessed!

@min-xu-ai
Contributor

Let me understand a bit more. Will bucketing really help here? Bucketing will help performance since FSDP will loop over fewer params internally. However, for your use case, do you need module.layer.parameters() to return the original full unsharded params or the sharded ones? Is the optimizer assumed to be point-wise, like the one used with FSDP? When is the optimizer call on module.layer.parameters() made, before or after the model is wrapped? Sorry if I missed the context from reading this thread.

min-xu-ai added the FSDP FullyShardedDataParallel (zero-3) label on Apr 3, 2021
@anj-s
Contributor

anj-s commented Oct 18, 2021

@SeanNaren @min-xu-ai Is there an action item here to follow up on? From my understanding, flatten_parameters=False plus an enhanced bucketing strategy would allow for this feature. In the meantime, is there something we can do, and what is the priority for the suggested change?

@min-xu-ai
Contributor

@zhaojuanmao

I think the new FSDP code will likely address this by adding an API that returns views into the flattened parameter for each of the original params. The views can be partial or even empty.
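
To illustrate the idea in plain PyTorch (independent of whatever the final API looks like): the flat parameter can hand out reshaped views that alias its storage, and on a rank that holds none of a given original param the corresponding view could simply be empty.

import torch

# Illustrative only: one flat buffer backing a 5x5 weight and a 5-element bias.
flat_param = torch.nn.Parameter(torch.zeros(30))

views, offset = [], 0
for shape in [(5, 5), (5,)]:
    numel = 1
    for d in shape:
        numel *= d
    views.append(flat_param[offset:offset + numel].view(shape))
    offset += numel

print([v.shape for v in views])  # [torch.Size([5, 5]), torch.Size([5])]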
