[FSDP][feature] optimizer state dict save and load #537

Merged Mar 25, 2021 (33 commits)
Commits
ee088bb  consolidate works (sshleifer, Mar 19, 2021)
ad7df24  cat (sshleifer, Mar 19, 2021)
ed7526a  Unpad before cat (sshleifer, Mar 19, 2021)
ed75c59  update params list (sshleifer, Mar 19, 2021)
44158f7  simple case passing (sshleifer, Mar 19, 2021)
f82f3b6  found other bug (sshleifer, Mar 19, 2021)
1022e1e  Broken tests for other optimizers (sshleifer, Mar 19, 2021)
75119c2  boom boom (sshleifer, Mar 20, 2021)
89947a4  Merge branch 'master' into fsdp-gather-optimizer (sshleifer, Mar 20, 2021)
8dcf0a8  remove oss changes (sshleifer, Mar 20, 2021)
2caf928  passing besides mypy (sshleifer, Mar 20, 2021)
0b888fd  Smaller delta (sshleifer, Mar 20, 2021)
a2aacd0  Nesting works (sshleifer, Mar 21, 2021)
0fc045d  passing, lint attempt (sshleifer, Mar 21, 2021)
d859734  merge master (sshleifer, Mar 21, 2021)
3635277  update test list (sshleifer, Mar 21, 2021)
dbb426f  mypy (sshleifer, Mar 22, 2021)
f537632  Simpler consolidate_optim_state_dict (sshleifer, Mar 22, 2021)
a04b406  slightly cleaner (sshleifer, Mar 22, 2021)
e5e91df  Simplified signature, helper fn for unflattening (sshleifer, Mar 22, 2021)
ea9d4b5  add todo (sshleifer, Mar 22, 2021)
47e7cba  Give CI more time to show me a traceback (sshleifer, Mar 23, 2021)
6cebcec  Fix broadcast_object regression (sshleifer, Mar 23, 2021)
93c0857  Move most dictionary manipulation to fsdp_optim_utils.py (sshleifer, Mar 23, 2021)
c93d1db  passing (sshleifer, Mar 23, 2021)
13b0537  style (sshleifer, Mar 23, 2021)
9d3dfb7  passing (sshleifer, Mar 23, 2021)
9f619b2  stateless fix (sshleifer, Mar 23, 2021)
a4778b7  Min comments (sshleifer, Mar 23, 2021)
c77a9f7  Min comments (sshleifer, Mar 24, 2021)
aeefe69  Apply suggestions from code review (sshleifer, Mar 24, 2021)
d645337  Min comments (sshleifer, Mar 24, 2021)
75bdd3f  also test param groups (sshleifer, Mar 25, 2021)
151 changes: 151 additions & 0 deletions fairscale/nn/data_parallel/fsdp_optim_utils.py
@@ -0,0 +1,151 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""These files are used by fsdp to help consolidate and shard optimizer states."""
Contributor @myleott (Mar 23, 2021):
❤️ this

import copy
from typing import Dict, Generator, List, Tuple

import torch


# This function helps shard a full optimizer state dict
def flatten_optim_state_dict(sd: Dict) -> Dict:
"""Called by FSDP.get_shard_from_optim_state_dict"""
param_id_map = sd["param_id_map"]
num_local_params = len(set(param_id_map.values()))
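# Illustrative note (example values are assumptions, not from this PR): a param_id_map
# like {0: 0, 1: 0, 2: 1} means unflattened params 0 and 1 were views of flat param 0
# and unflattened param 2 was a view of flat param 1, so num_local_params == 2.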
if sd["state"]:
new_state: Dict = {local_id: {} for local_id in range(num_local_params)}
else:
new_state = {}
constant_state = {}

# assumes sd sorted
for expanded_pid, buffers in sd["state"].items():
consolidated_pid = param_id_map[expanded_pid]
for buffer_name, p in buffers.items():
if torch.is_tensor(p):
if buffer_name not in new_state[consolidated_pid]:
new_state[consolidated_pid][buffer_name] = []
new_state[consolidated_pid][buffer_name].append(p.reshape(-1))
else:
assert isinstance(p, (float, int)), f"unexpected type {type(p)} in optimizer state[{buffer_name}]"
constant_state[buffer_name] = p
# TODO(SS): THIS COULD BE WRONG. What if step is different for different params... At least check

for consolidated_pid, state in new_state.items():
for buffer_name, tensors in state.items():
new_state[consolidated_pid][buffer_name] = torch.cat(tensors)
new_state[consolidated_pid].update(constant_state)
new_sd = {"state": new_state, "param_groups": sd["param_groups"]}

for pg_id, _ in enumerate(sd["param_groups"]):
# TODO: this list could be huge. Can we avoid materializing?
new_sd["param_groups"][pg_id]["params"] = list(range(num_local_params))

return new_sd


# The following functions help save the list of optimizer states, one from each rank
# build_unflat_state_dict is the interface used by FSDP
def _extract_constant_state(combined_state: Dict[int, Dict[str, List]], param_id: int) -> Dict:
constant_state = {}  # This state is like "step" in Adam: not a tensor, so we don't unpad or cat it.
for k, v in combined_state[param_id].items():

if torch.is_tensor(v[0]):
continue
elif len(set(v)) == 1:
constant_state[k] = v[0]
else:
raise TypeError(f"Don't know how to expand optimizer param {k} with values {v}")
return constant_state


def _combine_tensor_optim_state(states: List[Dict]) -> Dict[int, Dict]:
combined_state = states[0]
for param_id in combined_state:
combined_state[param_id] = {k: [v] for k, v in combined_state[param_id].items()}
if len(states) == 1:
return combined_state

for rank, s in enumerate(states[1:]):
for param_id, param_state in s.items():
for k, tensor in param_state.items():
combined_state[param_id][k].append(tensor)
return combined_state


def _unflatten_optim_state(
combined_state: Dict[int, Dict], instance_list: List[torch.nn.Module], world_pad_info: List[List[List[int]]],
) -> Tuple[Dict[int, Dict], Dict[int, int]]:
local_to_global_param_id: Dict[int, List[int]] = {}
next_global_id = 0 # gets incremented
unflat_state = {}
pad_info = {id: [s[id][0] for s in world_pad_info] for id in combined_state}
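# Illustrative note (inferred from the asserts in build_unflat_state_dict, not stated
# explicitly in this PR): world_pad_info[rank][instance] is a one-element list holding
# how many padding elements that rank appended to that instance's flat param, so
# pad_info[id] collects the per-rank pad counts for flat param `id`.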

# constant_state refers to entries in sd[state][param_id] that are not tensors, like "step"
# we check that these are identical across workers and then take the first
constant_state = [_extract_constant_state(combined_state, id) for id in combined_state]

# loop over parameters in state.
# Tensor state will be unpadded, concatenated, and restored to its original
# shape with FlattenParamsWrapper.get_param_views.
# get_param_views returns multiple tensors, each of which is a new parameter with a new "global" id.
for local_id in combined_state:
local_to_global_param_id[local_id] = []
# undo the work of shard_parameters
for k, v in combined_state[local_id].items():
if k in constant_state[local_id]:
continue
assert isinstance(v, list), f"got {k}: {v} for {local_id}"
v_unpad = [t[:-np] if np > 0 else t for t, np in zip(v, pad_info[local_id])]
flat_buffer = torch.cat(v_unpad)
param_views: Generator = instance_list[local_id].get_param_views(flat_buffer) # type: ignore
for i, param_view in enumerate(param_views):
if i == len(local_to_global_param_id[local_id]): # make a new ID
local_to_global_param_id[local_id].append(next_global_id)
next_global_id += 1
global_id = local_to_global_param_id[local_id][i]
if global_id not in unflat_state:
unflat_state[global_id] = copy.deepcopy(constant_state[local_id])

assert k not in unflat_state[global_id], f"already added {k} to new[{global_id}]"
unflat_state[global_id][k] = param_view

global_to_local_id = {
new_id: old_pid for old_pid, global_ids in local_to_global_param_id.items() for new_id in global_ids
}

return unflat_state, global_to_local_id


def build_unflat_state_dict(instance_list: List[torch.nn.Module], world_optim_states: List[Dict]) -> Dict:
Contributor:
add a docstring?
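A possible docstring, sketched from the function body below (a reviewer-style suggestion; the wording is an assumption, not the author's):

def build_unflat_state_dict(instance_list: List[torch.nn.Module], world_optim_states: List[Dict]) -> Dict:
    """Convert one sharded optimizer state dict per rank into a single consolidated, unflattened state dict.

    Args:
        instance_list: all FSDP instances in the model, including the root.
        world_optim_states: the local optimizer state dict gathered from each rank; each
            must carry a "num_padded" entry recording how much each shard was padded.

    Returns:
        A state dict whose "state" is keyed by unflattened (global) param ids and which also
        contains a "param_id_map" from global ids back to local flat-param ids.
    """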

world_pad_info: List[List[List[int]]] = [s.pop("num_padded") for s in world_optim_states]
assert all(len(s) == len(instance_list) for s in world_pad_info)
assert all(len(s[0]) == 1 for s in world_pad_info)
param_groups = copy.deepcopy(world_optim_states[0]["param_groups"])
assert len(param_groups) == 1
# combined_state refers to tensor values in sd[state][param_id].
# Here we just aggregate them into a dictionary of lists (from a list of dictionaries)
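# e.g. (assumed buffer name "exp_avg", purely illustrative):
#   rank0 state {0: {"exp_avg": t0}} and rank1 state {0: {"exp_avg": t1}}
#   combine into combined_state == {0: {"exp_avg": [t0, t1]}}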
combined_state = _combine_tensor_optim_state([x["state"] for x in world_optim_states])
# cleanup world_optim_states, which is no longer needed
del world_optim_states
new_state_dict = {"state": {}, "param_groups": param_groups}
# local ids are in the current state, global_ids will be in returned state.
unflat_state, global_to_local_id = _unflatten_optim_state(combined_state, instance_list, world_pad_info)
num_params = sum([len(m._param_numels) for m in instance_list]) # type: ignore
new_state_dict["param_groups"][0]["params"] = list(range(num_params))
new_state_dict["param_id_map"] = global_to_local_id
# Make sure that the parameters are sorted in the state, as expected for a pytorch dict
new_state_dict["state"] = dict(sorted(unflat_state.items()))
return new_state_dict


def check_param_counts_before_sharding(full_optim_state_dict: Dict, n_instances: int) -> None:
n_local_params_in_opt = len(set(full_optim_state_dict["param_id_map"].values()))
msg = (
f"Including itself, this model has {n_instances} nested instances. When the optimizer state was saved "
f"there were {n_local_params_in_opt}"
)
stateless = len(full_optim_state_dict["state"]) == 0
assert stateless or (n_instances == n_local_params_in_opt), msg
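For orientation, a minimal end-to-end sketch of how these utilities are reached through the FSDP methods changed below (gather_full_optim_state_dict and get_shard_from_optim_state_dict come from this PR; the checkpoint path and the rank-0 save are illustrative assumptions):

import torch
import torch.distributed as dist

# model is an FSDP-wrapped module and optim an ordinary torch optimizer (e.g. Adam).
full_osd = model.gather_full_optim_state_dict(optim, recipient_rank=0)
if dist.get_rank() == 0:
    # gather_full_optim_state_dict returns None on non-recipient ranks, so only rank 0 saves.
    torch.save(full_osd, "optim_state.pt")  # hypothetical path

# Later, on every rank: reload the consolidated state and keep only this rank's shard.
full_osd = torch.load("optim_state.pt")
shard = model.get_shard_from_optim_state_dict(full_osd)
optim.load_state_dict(shard)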
151 changes: 7 additions & 144 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -19,6 +19,7 @@
from torch.nn import Parameter
import torch.nn.functional as F

import fairscale.nn.data_parallel.fsdp_optim_utils as ou
Contributor:
relative import like `import .fsdp_optim_utils as ou` is more portable?

Contributor Author:
SyntaxError: invalid syntax :(

Contributor:
got it. perhaps `from . import fsdp_optim_utils as ou`?

Contributor Author:
That works!
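For reference, a minimal sketch of the two import forms discussed above (syntax only; whether the final commit adopts the relative form is not visible in this single-commit view):

# Inside the fairscale.nn.data_parallel package:
# import .fsdp_optim_utils as ou        # SyntaxError: a plain `import` cannot start with a dot
from . import fsdp_optim_utils as ou    # relative-import equivalent of the absolute import in this diff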

from fairscale.nn.misc import FlattenParamsWrapper
from fairscale.nn.wrap import auto_wrap, default_auto_wrap_policy, enable_wrap
from fairscale.optim.utils import broadcast_object, calc_grad_norm, recursive_copy_to_device
@@ -1352,7 +1353,6 @@ def assert_state(self, state: Union[TrainingState, List[TrainingState]]) -> None
traceback.print_stack()
raise ValueError(msg)

# Optim State dict functions
def _consolidate_optim_state_dict(
self, optim: torch.optim.Optimizer, recipient_rank: Optional[int] = None
) -> List[Dict]:
@@ -1401,106 +1401,10 @@ def gather_full_optim_state_dict(
world_optim_states = self._consolidate_optim_state_dict(optim, recipient_rank)
if self.rank != recipient_rank and recipient_rank is not None:
return None

# Unify the shard states by concatenating tensors and unflattening params
world_pad_info: List[List[List[int]]] = [s.pop("num_padded") for s in world_optim_states]
instance_list: List[nn.Module] = self._fsdp_instances
assert all(len(s) == len(instance_list) for s in world_pad_info)
assert all(len(s[0]) == 1 for s in world_pad_info)

param_groups = copy.deepcopy(world_optim_states[0]["param_groups"])
assert len(param_groups) == 1

# combined_state refers to tensor values in sd[state][param_id].
# Here we just aggregate them into a dictionary of lists (from a list of dictionaries)
combined_state = self._combine_tensor_optim_state([x["state"] for x in world_optim_states], self.world_size)
# cleanup all_optimizer_states_list
del world_optim_states

new_state_dict = {"state": {}, "param_groups": param_groups}

# local ids are in the current state, global_ids will be in returned state.
unflat_state, global_to_local_id = self._unflatten_optim_state(combined_state, instance_list, world_pad_info)

num_params = sum([len(m._param_numels) for m in instance_list]) # type: ignore
new_state_dict["param_groups"][0]["params"] = list(range(num_params))

new_state_dict["param_id_map"] = global_to_local_id
# Make sure that the parameters are sorted in the state, as expected for a pytorch dict
new_state_dict["state"] = dict(sorted(unflat_state.items()))
new_state_dict = ou.build_unflat_state_dict(self._fsdp_instances, world_optim_states)
return new_state_dict

@staticmethod
def _unflatten_optim_state(
combined_state: Dict[int, Dict], instance_list: List[nn.Module], world_pad_info: List[List[List[int]]],
) -> Tuple[Dict[int, Dict], Dict[int, int]]:
local_to_global_param_id: Dict[int, List[int]] = {}
next_global_id = 0 # gets incremented
unflat_state = {}
pad_info = {id: [s[id][0] for s in world_pad_info] for id in combined_state}

# constant_state refers to entries in sd[state][param_id] that are not tensors, like "step"
# we check that these are identical across workers and then take the first
constant_state = [FullyShardedDataParallel._extract_constant_state(combined_state, id) for id in combined_state]

# loop over parameters in state.
# Tensor state will be padded, concatenated, and then restored to their original
# shape with FlattenParamsWrapper.get_views
# get_views multiple tensors, each of which is a new parameter with a new "global" id.
for local_id in combined_state:
local_to_global_param_id[local_id] = []
# undo the work of shard_parameters
for k, v in combined_state[local_id].items():
if k in constant_state[local_id]:
continue
assert isinstance(v, list), f"got {k}: {v} for {local_id}"
v_unpad = [t[:-np] if np > 0 else t for t, np in zip(v, pad_info[local_id])]
flat_buffer = torch.cat(v_unpad)
param_views: Generator = instance_list[local_id].get_param_views(flat_buffer) # type: ignore
for i, param_view in enumerate(param_views):
if i == len(local_to_global_param_id[local_id]): # make a new ID
local_to_global_param_id[local_id].append(next_global_id)
next_global_id += 1
global_id = local_to_global_param_id[local_id][i]
if global_id not in unflat_state:
unflat_state[global_id] = copy.deepcopy(constant_state[local_id])

assert k not in unflat_state[global_id], f"already added {k} to new[{global_id}]"
unflat_state[global_id][k] = param_view

global_to_local_id = {
new_id: old_pid for old_pid, global_ids in local_to_global_param_id.items() for new_id in global_ids
}

return unflat_state, global_to_local_id

@staticmethod
def _combine_tensor_optim_state(states: List[Dict], world_size: int) -> Dict[int, Dict]:
combined_state = states[0]
for param_id in combined_state:
combined_state[param_id] = {k: [v] for k, v in combined_state[param_id].items()}
if world_size == 1:
return combined_state

for rank, s in enumerate(states[1:]):
for param_id, param_state in s.items():
for k, tensor in param_state.items():
combined_state[param_id][k].append(tensor)
return combined_state

@staticmethod
def _extract_constant_state(combined_state: Dict[int, Dict[str, List]], param_id: int) -> Dict:
constant_state = {} # This state is like step in Adam, not a tensor so we dont unpad or cat it.
for k, v in combined_state[param_id].items():

if torch.is_tensor(v[0]):
continue
elif len(set(v)) == 1:
constant_state[k] = v[0]
else:
raise TypeError(f"Dont know how to expand optimizer param {k} with values {v}")
return constant_state

@property
def _fsdp_instances(self) -> List[nn.Module]:
"""Returns all fsdp modules in self.modules() including self."""
@@ -1509,64 +1413,23 @@ def _fsdp_instances(self) -> List[nn.Module]:
def get_shard_from_optim_state_dict(self, full_optim_state_dict: Dict[str, Any]) -> Dict[str, Any]:
"""Get the portion of the optimizer state dict associated with the shard"""
# Assert nesting is the same as it was at save time
n_instances = len(self._fsdp_instances)
n_local_params_in_opt = len(set(full_optim_state_dict["param_id_map"].values()))
msg = f"Including itself, this model has {n_instances} nested instances. When the optimizer state was saved there were {n_local_params_in_opt}"
stateless = len(full_optim_state_dict["state"]) == 0
assert stateless or (n_instances == n_local_params_in_opt), msg

stateless = len(full_optim_state_dict["state"]) == 0
instance_list = self._fsdp_instances
ou.check_param_counts_before_sharding(full_optim_state_dict, len(instance_list))
if self.flatten_parameters:
Contributor:
does this assume all inner FSDP instances also have flatten == True?

Contributor Author:
Yes, will assert

full_optim_state_dict = self._flatten_optim_state_dict(full_optim_state_dict)
assert stateless or len(full_optim_state_dict["state"]) == len(instance_list)
full_optim_state_dict = ou.flatten_optim_state_dict(full_optim_state_dict)
assert len(full_optim_state_dict["state"]) in (0, len(instance_list))

# get the portion of dict associated with the shard
# get the portion of dict associated with the shard, in place
for id, s in full_optim_state_dict["state"].items():
for k, v in s.items():
if torch.is_tensor(v):
v_shard, _ = self._get_shard(v)
else:
v_shard = v # dont partition entries that are not tensors
v_shard = v # dont shard entries that are not tensors
full_optim_state_dict["state"][id][k] = v_shard

return full_optim_state_dict

@staticmethod
def _flatten_optim_state_dict(sd: Dict) -> Dict:
param_id_map = sd["param_id_map"]
num_local_params = len(set(param_id_map.values()))
if sd["state"]:
new_state: Dict = {local_id: {} for local_id in range(num_local_params)}
else:
new_state = {}
constant_state = {}

# assumes sd sorted
for expanded_pid, buffers in sd["state"].items():
consolidated_pid = param_id_map[expanded_pid]
for buffer_name, p in buffers.items():
if torch.is_tensor(p):
if buffer_name not in new_state[consolidated_pid]:
new_state[consolidated_pid][buffer_name] = []
new_state[consolidated_pid][buffer_name].append(p.reshape(-1))
else:
assert isinstance(p, (float, int)), f"unexpected type {type(p)} in optimizer state[{buffer_name}]"
constant_state[buffer_name] = p
# TODO(SS): THIS COULD BE WRONG. What if step is different for different params... At least check

for consolidated_pid, state in new_state.items():
for buffer_name, tensors in state.items():
new_state[consolidated_pid][buffer_name] = torch.cat(tensors)
new_state[consolidated_pid].update(constant_state)
new_sd = {"state": new_state, "param_groups": sd["param_groups"]}

for pg_id, _ in enumerate(sd["param_groups"]):
# TODO: this list could be huge. Can we avoid materializing?
new_sd["param_groups"][pg_id]["params"] = list(range(num_local_params))

return new_sd


@torch.no_grad()
def cast_inputs_to_fp16(*args: Any, **kwargs: Any) -> Tuple[Any, Any]: