
[FSDP][feature] optimizer state dict save and load #537

Merged
merged 33 commits into master from fsdp-gather-optimizer on Mar 25, 2021

Conversation

sshleifer
Contributor

@sshleifer sshleifer commented Mar 19, 2021

Overview

# save
fsdp = FSDP(world_size=4)
optim = Adam(fsdp.parameters())
full_state_dict = fsdp.gather_full_optim_state_dict(optim, recipient_rank=0)
# this is None if you are not on a recipient rank
# recipient_rank=None  performs the same consolidation on all ranks.

# load with different world size
fsdp2 = FSDP(world_size=2)
optim2 = Adam(fsdp2.parameters())
state_shard = fsdp2.get_shard_from_optim_state_dict(full_state_dict)
optim2.load_state_dict(state_shard)
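
For completeness, a hedged sketch of how this could plug into a checkpointing loop (the file path and the torch.save/torch.load glue are illustrative, not part of this PR):

import torch

# save: only the recipient rank gets a non-None consolidated dict
full_state_dict = fsdp.gather_full_optim_state_dict(optim, recipient_rank=0)
if full_state_dict is not None:
    torch.save(full_state_dict, "optim_full.pt")  # illustrative path

# load: each rank reads the full dict, then keeps only its own shard
full_state_dict = torch.load("optim_full.pt", map_location="cpu")
state_shard = fsdp2.get_shard_from_optim_state_dict(full_state_dict)
optim2.load_state_dict(state_shard)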

Future Work

  • fairseq integration
  • support flatten_parameters=False
  • support param groups
  • test more nested setups. An FSDP with no params could certainly break this.

On the fairseq side, I tested running with 4 GPUs and loading with 2, and this worked.

Assumptions

(0) flatten_parameters=True

(1) If there is a tensor in the optimizer state, it has the same size as, and corresponds to, a tensor in the model state. If there are singleton tensors in the optimizer state, or tensors that correspond to the average update for a column of params (and are therefore shaped differently), things will break.
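
To make (1) concrete, here is a small sanity check one could run (a sketch only; it assumes a single param group, so optimizer state keys line up with parameter order):

import torch

def check_assumption_1(model, optim):
    params = list(model.parameters())
    for pid, pstate in optim.state_dict()["state"].items():
        for v in pstate.values():
            if torch.is_tensor(v) and v.dim() > 0:
                # per-parameter tensors (e.g. Adam's exp_avg) must match the param's numel;
                # state shaped like a single row/column of the param would fail here
                assert v.numel() == params[pid].numel()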

(2) We assume that these two lists are the same if we account for padding:

mlist = [sum(m._param_numels) for m in self.modules() if isinstance(m, FullyShardedDataParallel)]
params = [p.numel() for p in self.parameters()]

We use this assumption to call get_param_views(flat_param=params_unpadded[i]) on the i-th FSDP instance.
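
For intuition, a minimal sketch of the view-splitting this relies on (the shapes and the helper name are illustrative, not the FSDP internals):

import torch

def split_flat_param(flat_param, shapes):
    # split a 1-D unpadded flat tensor back into views with the original parameter shapes
    numels = [int(torch.Size(s).numel()) for s in shapes]
    assert flat_param.numel() == sum(numels)
    return [t.view(s) for t, s in zip(flat_param.split(numels), shapes)]

shapes = [(3, 4), (4,), (4, 2)]
flat = torch.arange(sum(torch.Size(s).numel() for s in shapes), dtype=torch.float32)
assert [tuple(v.shape) for v in split_flat_param(flat, shapes)] == shapes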

New overhead introduced

  • _get_shard now returns how many padding elements it introduced.
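As a rough illustration of that bookkeeping (a sketch under the assumption that shards are equal-sized chunks of a 1-D flat tensor; the real _get_shard lives inside FSDP):

import torch
import torch.nn.functional as F

def get_shard_with_padding(flat_tensor, rank, world_size):
    # pad so the flat tensor divides evenly across ranks, then return this rank's
    # chunk together with the number of padding elements that were added
    shard_numel = (flat_tensor.numel() + world_size - 1) // world_size
    num_padded = shard_numel * world_size - flat_tensor.numel()
    padded = F.pad(flat_tensor, (0, num_padded))
    return padded.view(world_size, shard_numel)[rank].clone(), num_padded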

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Mar 19, 2021
@sshleifer sshleifer linked an issue Mar 19, 2021 that may be closed by this pull request
@sshleifer sshleifer changed the title [wip] [FSDP][feature] optimizer state dict save and load [FSDP][feature] optimizer state dict save and load Mar 20, 2021
@sshleifer sshleifer marked this pull request as ready for review March 21, 2021 21:40
@@ -1346,6 +1352,244 @@ def assert_state(self, state: Union[TrainingState, List[TrainingState]]) -> None
traceback.print_stack()
raise ValueError(msg)

# Optim State dict functions
Contributor Author

I considered moving these to a separate FSDPOptimizerMixin in fsdp_optimizer_utils.py, but decided it wasn't really a mixin since it depends heavily on FSDP.

Contributor

@myleott myleott left a comment

About half-way through, leaving initial comments and will post rest in second batch

Comment on lines 1374 to 1401
if rank == self.rank:
    sd = optim.state_dict()
    sd["num_padded"] = [m.num_padded for m in self.modules() if isinstance(m, FullyShardedDataParallel)]
    if should_collect_state:
        _all_optimizer_states.append(
            recursive_copy_to_device(sd, non_blocking=True, device=torch.device("cpu"))
        )

    # Sync with other replicas
    state_to_share = (
        sd if should_send_state else torch.tensor([0], dtype=torch.uint8, device=_default_device)
    )
    broadcast_object(
        state_to_share, src_rank=self.rank, group=self.process_group, dist_device=_default_device,
    )
else:
    # Fetch the optim state from the other replicas
    replica_state = broadcast_object(
        torch.tensor([0], dtype=torch.uint8, device=_default_device),
        src_rank=rank,
        group=self.process_group,
        dist_device=_default_device,
    )

    if should_collect_state:
        _all_optimizer_states.append(
            recursive_copy_to_device(replica_state, non_blocking=True, device=torch.device("cpu"))
        )
Contributor

can this be rearranged to remove some duplication? Something like:

for rank in range(self.world_size):
    if rank == self.rank:
        state = optim.state_dict()
        sd["num_padded"] = ...
        state = broadcast_object(state, src_rank=rank, ...)
    else:
        state = broadcast_object(None, src_rank=rank, ...)

    if should_collect_state:
        _all_optimizer_states.append(recursive_copy_to_device(state, device=torch.device("cpu")))

Contributor Author

@sshleifer sshleifer Mar 22, 2021

Just copy-pasted this function from OSS. I think the reason for the extra append is to avoid a useless broadcast from recipient_rank to itself.

Contributor Author

@sshleifer sshleifer Mar 22, 2021

I have the simplified implementation working with torch.distributed.broadcast_object_list.
I no longer need compute_device. Still calling lazy_init_ for safety.

@min-xu-ai
Contributor

Looks like really cool stuff. Some high level questions about the context:

  1. in the example, do users need to put the optimizer back into its original state after the consolidation? If not, perhaps make a comment about it in the example?
  2. this assumes there is enough GPU memory to hold the state at each (or one) rank? what's the solution for very large models?

It seems that this is only needed when we change the world size between save/restore? If the world size is not changed, is a normal save/restore with only the sharded data OK?

@sshleifer
Contributor Author

@min-xu-ai

  1. No resetting is needed by the user. This doesn't mutate the optimizer; it just combines optimizer.state_dict() across ranks. I'll add a comment and test to that effect.
  2. Consolidation happens in CPU memory (the cast is in the consolidate method). If we see use cases where there is not enough CPU memory to handle consolidation on one node, we will iterate :)
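
For reference, a toy version of the CPU cast being described (recursive_copy_to_device is the existing fairscale utility; this hypothetical helper just shows the idea):

import torch

def copy_state_to_cpu(obj):
    # recursively move any tensors inside a (possibly nested) optimizer state dict to CPU
    if torch.is_tensor(obj):
        return obj.detach().cpu()
    if isinstance(obj, dict):
        return {k: copy_state_to_cpu(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(copy_state_to_cpu(v) for v in obj)
    return obj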

Contributor

@myleott myleott left a comment

Do we envision a use case for calling consolidate_optim_state_dict without calling gather_full_optim_state_dict?

If not, perhaps simplify interface to:

fsdp = FSDP(world_size=4)
optim = Adam(fsdp.parameters())
full_state_dict = fsdp.gather_full_optim_state_dict(optim, recipient_rank=-1)

Comment on lines 1426 to 1434
# combined_state refers to tensor values in sd[state][param_id].
# Here we just aggregate them into a list inside the dictionary from a list of dictionaries.
combined_state = self._combine_tensor_optim_state(
    [x["state"] for x in self._all_optimizer_states], self.world_size
)

# constant_state refers to entries in sd[state][param_id] that are not tensors, like "step"
# we check that these are identical across workers and then take the first
constant_state = [self._extract_constant_state(combined_state, id) for id in combined_state]
Contributor

these comments/helper methods are very nice 😄


if next_global_param_id == 0:  # stateless optimizer
    num_params = sum([len(m._param_numels) for m in instance_list])  # type: ignore
    new_state_dict["param_groups"][pg_id]["params"] = list(range(num_params))
Contributor

this list could be quite large, right? I guess this only affects SGD w/o momentum, but I wonder if there's a more compact way. Let's not worry about it for now, but perhaps put a note or TODO to make it more efficient

Contributor Author

Are you talking about list(range(num_params))? If so, it affects both cases.
I'll leave a TODO

#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""These files are used by fsdp to help consolidate and shard optimizer states."""
Contributor

@myleott myleott Mar 23, 2021

❤️ this

@sshleifer sshleifer added the FSDP FullyShardedDataParallel (zero-3) label Mar 23, 2021
Contributor

@min-xu-ai min-xu-ai left a comment

Seems like super solid work. I added some minor comments. I didn't check the logic in detail mainly because I have two high level questions:

  1. should we consider some optimizer wrapper that works together with fsdp to get the full state? It seems right now everything is in fsdp. Will an optimizer wrapper help more? I haven't thought through this.
  2. I have been thinking that fsdp should support a "streaming" mode for the full state so that no single rank needs to hold all of the (non-sharded) state in memory. Should this PR try to do streaming to avoid an overly big state?

Both 1 and 2 above are kind of independent of this PR. Just wanted to put them out there in case they are helpful. If not, just let me know and I will dive deep into this version of the code and give it a more detailed review. Thanks!

@@ -19,9 +19,10 @@
from torch.nn import Parameter
import torch.nn.functional as F

import fairscale.nn.data_parallel.fsdp_optim_utils as ou
Contributor

relative import like `import .fsdp_optim_utils as ou" is more portable?

Contributor Author

SyntaxError: invalid syntax :(

Contributor

got it. perhaps `from . import fsdp_optim_utils as ou`?

Contributor Author

That works!

@@ -88,8 +89,8 @@ class FullyShardedDataParallel(nn.Module):
import torch
from fairscale.nn.auto_wrap import enable_wrap, auto_wrap
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
fsdp_params = dict(mixed_precision=True, flatten_parameters=True)
with enable_wrap(wrapper_cls=FSDP, **fsdp_params):
fsdp_params = dict(wrapper_cls=FSDP, mixed_precision=True, flatten_parameters=True)
Contributor

Thanks for fixing the doc here!

Contributor

@min-xu-ai min-xu-ai left a comment

Reviewed the test file first. It looks very good. Minor comments.


from parameterized import parameterized
import torch
from torch.optim import SGD, Adadelta, Adam # type: ignore
Contributor

we usually don't have typing in test files. so "type: ignore" is not needed?

Contributor Author

I was getting "torch.optim has no Attribute Adadelta" from mypy without this, using

mypy --ignore-missing-imports --scripts-are-modules --pretty .

from fs_test.

Contributor

I see. magic mypy. I thought it would skip the whole file since there isn't any type annotation in it.

try:
    fsdp_optim = optim_fn(fsdp.parameters(), lr=0.01,)
    optim_unwrapped = optim_fn(unwrapped_model.parameters(), lr=0.01)
except TypeError:  # AdaScale
Contributor

do you actually mean "AdaScale" here? I don't see AdaScale being used here in this test.

Contributor Author

yes, nice catch

Comment on lines +88 to +89
# Switching from fairscale.optim.utils.broadcast_object to torch.broadcast_object_list will cause this to raise
assert duration < fsdp.world_size, f"gather optim state took {duration} seconds, suspect change in _consolidate"
Contributor

this is interesting. thanks for the comment. what's the usual value for duration? I am surprised that it is somehow connected with world_size, which is not even in units of seconds.

Contributor Author

It takes longer to gather from 8 nodes than from 4 than from 2.
This actually takes 4 ms, but I accidentally regressed it during development and caused it to take 8 seconds for world size 2 and 13 for world size 4.
Now that it's fixed I want to prevent it from happening again; agreed that the units are arbitrary.

Comment on lines 97 to 98
sum([first_tensor_shape(v) for k, v in sd["state"].items()]),
sum([first_tensor_shape(v) for k, v in unwrapped_sd["state"].items()]),
Contributor

perhaps norm will be slightly better than sum for comparison in case both tensors sum to the same values? same with line 110, 111.

Contributor Author

This just checks that we have the same num elements as the base model after unflattening.
I renamed first_tensor_shape -> first_tensor_numel to make it clearer.

Contributor

@min-xu-ai min-xu-ai left a comment

Reviewed fsdp changes. I am not sure if nested FSDP cases are well supported by this change.

  1. the APIs are really only intended for the root instance?
  2. root and all inner instances should have flatten == True?
  3. all instances need to have world_size == default world_size?

If so, can you assert those are the cases in the APIs so that we don't accidentally produce incorrect optim states or crash with non-obvious errors?

def _consolidate_optim_state_dict(
    self, optim: torch.optim.Optimizer, recipient_rank: Optional[int] = None
) -> List[Dict]:
    """Update the consolidated state_dict list, one per rank.
Contributor

should this be called only on the root FSDP instance?

Contributor Author

Yes, more specifically it should be called on the instance that was the argument to optimizer(model.parameters()). Are there other cases?
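
A usage sketch to make that concrete (MyModel is a placeholder):

model = FSDP(MyModel())                       # root FSDP instance
optim = torch.optim.Adam(model.parameters())  # optimizer built from the root's parameters

# call the optim-state APIs on `model`, not on any inner wrapped submodule
full_sd = model.gather_full_optim_state_dict(optim, recipient_rank=0)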

should_collect_state = recipient_rank is None or (self.rank == recipient_rank)
all_states: List[Dict[str, Any]] = []
dummy_tensor = torch.tensor([0], dtype=torch.uint8, device=self.compute_device)
for rank in range(self.world_size):
Contributor

there might be complications here when nested FSDP instances have different world_sizes, right? For example, if BN layers are in their own world_size == 1 process groups, then we collect duplicated states for them? add a TODO?

Contributor Author

added TODO in the caller

# Assert nesting is the same as it was at save time
instance_list = self._fsdp_instances
ou.check_param_counts_before_sharding(full_optim_state_dict, len(instance_list))
if self.flatten_parameters:
Contributor

does this assume all inner FSDP instances also have flatten == True?

Contributor Author

Yes, will assert

@@ -122,15 +122,15 @@ def _flatten_params(self, flat_param: Optional[nn.Parameter] = None) -> None:
# register the views as plain attributes
self._unflatten_params_as_views()

def _get_param_views(self, flat_param: Tensor) -> Generator:
def get_param_views(self, flat_param: Tensor) -> Generator:
Contributor

since this is becoming a public method, can you please:

  1. add docstring with proper doc
  2. assert flat_param is valid before using it?

Contributor Author

@sshleifer sshleifer left a comment

Thanks for the comments!


Contributor

@min-xu-ai min-xu-ai left a comment

Finished reviewing. Great step forward. I wish there were more comments in fsdp_optim_utils.py so I could follow along better. I tried my best and it seems to make sense. It might be possible to simplify and individually test it, but we can iterate on that later as we learn more.

return unflat_state, global_to_local_id


def build_unflat_state_dict(instance_list: List[torch.nn.Module], world_optim_states: List[Dict]) -> Dict:
Contributor

add a docstring?

sshleifer and others added 3 commits March 24, 2021 17:53
Co-authored-by: Min Xu <24926999+min-xu-ai@users.noreply.github.com>
@sshleifer
Contributor Author

I'm gonna merge this tomorrow AM unless there are further comments or a CI failure @myleott

@sshleifer sshleifer merged commit 9474d75 into master Mar 25, 2021
@sshleifer sshleifer deleted the fsdp-gather-optimizer branch March 25, 2021 15:03