[FSDP][feature] optimizer state dict save and load #537
New file: fairscale/nn/data_parallel/fsdp_optim_utils.py (path per the `import fairscale.nn.data_parallel.fsdp_optim_utils` in the second file below)
@@ -0,0 +1,151 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""These files are used by fsdp to help consolidate and shard optimizer states."""
import copy
from typing import Dict, Generator, List, Tuple

import torch

# This function helps shard an optimizer state dict
def flatten_optim_state_dict(sd: Dict) -> Dict:
    """Called by FSDP.get_shard_from_optim_state_dict"""
    param_id_map = sd["param_id_map"]
    num_local_params = len(set(param_id_map.values()))
    if sd["state"]:
        new_state: Dict = {local_id: {} for local_id in range(num_local_params)}
    else:
        new_state = {}
    constant_state = {}

    # assumes sd sorted
    for expanded_pid, buffers in sd["state"].items():
        consolidated_pid = param_id_map[expanded_pid]
        for buffer_name, p in buffers.items():
            if torch.is_tensor(p):
                if buffer_name not in new_state[consolidated_pid]:
                    new_state[consolidated_pid][buffer_name] = []
                new_state[consolidated_pid][buffer_name].append(p.reshape(-1))
            else:
                assert isinstance(p, (float, int)), f"unexpected type {type(p)} in optimizer state[{buffer_name}]"
                constant_state[buffer_name] = p
                # TODO(SS): THIS COULD BE WRONG. What if step is different for different params... At least check
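    # (Reviewer sketch, not part of this diff.) One way to address the TODO above is to
    # assert that repeated non-tensor entries agree before overwriting them, e.g.:
    #
    #     if buffer_name in constant_state:
    #         assert constant_state[buffer_name] == p, (
    #             f"constant optimizer state {buffer_name} differs across params: "
    #             f"{constant_state[buffer_name]} vs {p}"
    #         )
    #     constant_state[buffer_name] = p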
    for consolidated_pid, state in new_state.items():
        for buffer_name, tensors in state.items():
            new_state[consolidated_pid][buffer_name] = torch.cat(tensors)
        new_state[consolidated_pid].update(constant_state)
    new_sd = {"state": new_state, "param_groups": sd["param_groups"]}

    for pg_id, _ in enumerate(sd["param_groups"]):
        # TODO: this list could be huge. Can we avoid materializing?
        new_sd["param_groups"][pg_id]["params"] = list(range(num_local_params))

    return new_sd

# All functions help saving the list of optimizer states, one from each rank
# build_unflat_state_dict is the interface used by FSDP
def _extract_constant_state(combined_state: Dict[int, Dict[str, List]], param_id: int) -> Dict:
    constant_state = {}  # This state is like step in Adam, not a tensor so we dont unpad or cat it.
    for k, v in combined_state[param_id].items():
        if torch.is_tensor(v[0]):
            continue
        elif len(set(v)) == 1:
            constant_state[k] = v[0]
        else:
            raise TypeError(f"Dont know how to expand optimizer param {k} with values {v}")
    return constant_state

def _combine_tensor_optim_state(states: List[Dict]) -> Dict[int, Dict]:
    combined_state = states[0]
    for param_id in combined_state:
        combined_state[param_id] = {k: [v] for k, v in combined_state[param_id].items()}
    if len(states) == 1:
        return combined_state

    for rank, s in enumerate(states[1:]):
        for param_id, param_state in s.items():
            for k, tensor in param_state.items():
                combined_state[param_id][k].append(tensor)
    return combined_state


def _unflatten_optim_state(
    combined_state: Dict[int, Dict], instance_list: List[torch.nn.Module], world_pad_info: List[List[List[int]]],
) -> Tuple[Dict[int, Dict], Dict[int, int]]:
    local_to_global_param_id: Dict[int, List[int]] = {}
    next_global_id = 0  # gets incremented
    unflat_state = {}
    pad_info = {id: [s[id][0] for s in world_pad_info] for id in combined_state}

    # constant_state refers to entries in sd[state][param_id] that are not tensors, like "step"
    # we check that these are identical across workers and then take the first
    constant_state = [_extract_constant_state(combined_state, id) for id in combined_state]

    # loop over parameters in state.
    # Tensor state will be padded, concatenated, and then restored to their original
    # shape with FlattenParamsWrapper.get_views
    # get_views returns multiple tensors, each of which is a new parameter with a new "global" id.
    for local_id in combined_state:
        local_to_global_param_id[local_id] = []
        # undo the work of shard_parameters
        for k, v in combined_state[local_id].items():
            if k in constant_state[local_id]:
                continue
            assert isinstance(v, list), f"got {k}: {v} for {local_id}"
            v_unpad = [t[:-np] if np > 0 else t for t, np in zip(v, pad_info[local_id])]
            flat_buffer = torch.cat(v_unpad)
            param_views: Generator = instance_list[local_id].get_param_views(flat_buffer)  # type: ignore
            for i, param_view in enumerate(param_views):
                if i == len(local_to_global_param_id[local_id]):  # make a new ID
                    local_to_global_param_id[local_id].append(next_global_id)
                    next_global_id += 1
                global_id = local_to_global_param_id[local_id][i]
                if global_id not in unflat_state:
                    unflat_state[global_id] = copy.deepcopy(constant_state[local_id])

                assert k not in unflat_state[global_id], f"already added {k} to new[{global_id}]"
                unflat_state[global_id][k] = param_view

    global_to_local_id = {
        new_id: old_pid for old_pid, global_ids in local_to_global_param_id.items() for new_id in global_ids
    }

    return unflat_state, global_to_local_id


def build_unflat_state_dict(instance_list: List[torch.nn.Module], world_optim_states: List[Dict]) -> Dict:

Review comment: add a docstring?
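A possible wording, sketched from how the function is used elsewhere in this PR (the phrasing below is illustrative, not code from the diff):

    """Build an unflattened optimizer state dict from one optimizer-state shard per rank.

    Sketch of the behavior visible in this file: pops per-rank "num_padded" info,
    concatenates each rank's flattened tensor state, strips padding, restores
    per-parameter views via get_param_views, and returns a state dict keyed by
    global parameter ids, plus a "param_id_map" mapping back to local ids.
    """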
    world_pad_info: List[List[List[int]]] = [s.pop("num_padded") for s in world_optim_states]
    assert all(len(s) == len(instance_list) for s in world_pad_info)
    assert all(len(s[0]) == 1 for s in world_pad_info)
    param_groups = copy.deepcopy(world_optim_states[0]["param_groups"])
    assert len(param_groups) == 1
    # combined_state refers to tensor values in sd[state][param_id].
    # Here we just aggregate them into a dictionary of lists (from a list of dictionaries)
    combined_state = _combine_tensor_optim_state([x["state"] for x in world_optim_states])
    # cleanup all_optimizer_states_list
    del world_optim_states
    new_state_dict = {"state": {}, "param_groups": param_groups}
    # local ids are in the current state, global_ids will be in returned state.
    unflat_state, global_to_local_id = _unflatten_optim_state(combined_state, instance_list, world_pad_info)
    num_params = sum([len(m._param_numels) for m in instance_list])  # type: ignore
    new_state_dict["param_groups"][0]["params"] = list(range(num_params))
    new_state_dict["param_id_map"] = global_to_local_id
    # Make sure that the parameters are sorted in the state, as expected for a pytorch dict
    new_state_dict["state"] = dict(sorted(unflat_state.items()))
    return new_state_dict


def check_param_counts_before_sharding(full_optim_state_dict: Dict, n_instances: int) -> None:
    n_local_params_in_opt = len(set(full_optim_state_dict["param_id_map"].values()))
    msg = (
        f"Including itself, this model has {n_instances} nested instances. When the optimizer state was saved "
        f"there were {n_local_params_in_opt}"
    )
    stateless = len(full_optim_state_dict["state"]) == 0
    assert stateless or (n_instances == n_local_params_in_opt), msg
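End of the new utils file. As a usage sketch (mine, not from this PR): the two FSDP methods that drive this code, gather_full_optim_state_dict and get_shard_from_optim_state_dict, are intended for a save/load round trip roughly like the following. fsdp_model, optimizer, and the file path are placeholder names; recipient_rank mirrors the parameter visible in the diff below.

    # consolidate and unflatten optimizer state on the recipient rank, then save it
    full_osd = fsdp_model.gather_full_optim_state_dict(optimizer, recipient_rank=0)
    if full_osd is not None:  # non-recipient ranks receive None
        torch.save(full_osd, "optim_state_full.pt")

    # later, on every rank: reload the full dict and keep only this rank's shard
    full_osd = torch.load("optim_state_full.pt")
    shard_osd = fsdp_model.get_shard_from_optim_state_dict(full_osd)
    optimizer.load_state_dict(shard_osd)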

fairscale/nn/data_parallel/fully_sharded_data_parallel.py (modified)
@@ -19,6 +19,7 @@
from torch.nn import Parameter
import torch.nn.functional as F

import fairscale.nn.data_parallel.fsdp_optim_utils as ou
Review comment: relative import like `import .fsdp_optim_utils as ou` is more portable?
Reply: got it. perhaps `from . import fsdp_optim_utils as ou`?
Reply: That works!
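Side note (mine, not part of the thread): `import .fsdp_optim_utils as ou` is not valid Python syntax, which is presumably why the `from . import` spelling came up. Inside the fairscale.nn.data_parallel package the two working forms would be:

    # absolute import, as currently written in this PR
    import fairscale.nn.data_parallel.fsdp_optim_utils as ou

    # relative import, as suggested in the review
    from . import fsdp_optim_utils as ou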
from fairscale.nn.misc import FlattenParamsWrapper
from fairscale.nn.wrap import auto_wrap, default_auto_wrap_policy, enable_wrap
from fairscale.optim.utils import broadcast_object, calc_grad_norm, recursive_copy_to_device

@@ -1352,7 +1353,6 @@ def assert_state(self, state: Union[TrainingState, List[TrainingState]]) -> None
            traceback.print_stack()
            raise ValueError(msg)

    # Optim State dict functions
    def _consolidate_optim_state_dict(
        self, optim: torch.optim.Optimizer, recipient_rank: Optional[int] = None
    ) -> List[Dict]:

@@ -1401,106 +1401,10 @@ def gather_full_optim_state_dict(
        world_optim_states = self._consolidate_optim_state_dict(optim, recipient_rank)
        if self.rank != recipient_rank and recipient_rank is not None:
            return None

        # Unify the shard states by concatenating tensors and unflattening params
        world_pad_info: List[List[List[int]]] = [s.pop("num_padded") for s in world_optim_states]
        instance_list: List[nn.Module] = self._fsdp_instances
        assert all(len(s) == len(instance_list) for s in world_pad_info)
        assert all(len(s[0]) == 1 for s in world_pad_info)

        param_groups = copy.deepcopy(world_optim_states[0]["param_groups"])
        assert len(param_groups) == 1

        # combined_state refers to tensor values in sd[state][param_id].
        # Here we just aggregate them into a dictionary of lists (from a list of dictionaries)
        combined_state = self._combine_tensor_optim_state([x["state"] for x in world_optim_states], self.world_size)
        # cleanup all_optimizer_states_list
        del world_optim_states

        new_state_dict = {"state": {}, "param_groups": param_groups}

        # local ids are in the current state, global_ids will be in returned state.
        unflat_state, global_to_local_id = self._unflatten_optim_state(combined_state, instance_list, world_pad_info)

        num_params = sum([len(m._param_numels) for m in instance_list])  # type: ignore
        new_state_dict["param_groups"][0]["params"] = list(range(num_params))

        new_state_dict["param_id_map"] = global_to_local_id
        # Make sure that the parameters are sorted in the state, as expected for a pytorch dict
        new_state_dict["state"] = dict(sorted(unflat_state.items()))
        new_state_dict = ou.build_unflat_state_dict(self._fsdp_instances, world_optim_states)
        return new_state_dict

    @staticmethod
    def _unflatten_optim_state(
        combined_state: Dict[int, Dict], instance_list: List[nn.Module], world_pad_info: List[List[List[int]]],
    ) -> Tuple[Dict[int, Dict], Dict[int, int]]:
        local_to_global_param_id: Dict[int, List[int]] = {}
        next_global_id = 0  # gets incremented
        unflat_state = {}
        pad_info = {id: [s[id][0] for s in world_pad_info] for id in combined_state}

        # constant_state refers to entries in sd[state][param_id] that are not tensors, like "step"
        # we check that these are identical across workers and then take the first
        constant_state = [FullyShardedDataParallel._extract_constant_state(combined_state, id) for id in combined_state]

        # loop over parameters in state.
        # Tensor state will be padded, concatenated, and then restored to their original
        # shape with FlattenParamsWrapper.get_views
        # get_views multiple tensors, each of which is a new parameter with a new "global" id.
        for local_id in combined_state:
            local_to_global_param_id[local_id] = []
            # undo the work of shard_parameters
            for k, v in combined_state[local_id].items():
                if k in constant_state[local_id]:
                    continue
                assert isinstance(v, list), f"got {k}: {v} for {local_id}"
                v_unpad = [t[:-np] if np > 0 else t for t, np in zip(v, pad_info[local_id])]
                flat_buffer = torch.cat(v_unpad)
                param_views: Generator = instance_list[local_id].get_param_views(flat_buffer)  # type: ignore
                for i, param_view in enumerate(param_views):
                    if i == len(local_to_global_param_id[local_id]):  # make a new ID
                        local_to_global_param_id[local_id].append(next_global_id)
                        next_global_id += 1
                    global_id = local_to_global_param_id[local_id][i]
                    if global_id not in unflat_state:
                        unflat_state[global_id] = copy.deepcopy(constant_state[local_id])

                    assert k not in unflat_state[global_id], f"already added {k} to new[{global_id}]"
                    unflat_state[global_id][k] = param_view

        global_to_local_id = {
            new_id: old_pid for old_pid, global_ids in local_to_global_param_id.items() for new_id in global_ids
        }

        return unflat_state, global_to_local_id

    @staticmethod
    def _combine_tensor_optim_state(states: List[Dict], world_size: int) -> Dict[int, Dict]:
        combined_state = states[0]
        for param_id in combined_state:
            combined_state[param_id] = {k: [v] for k, v in combined_state[param_id].items()}
        if world_size == 1:
            return combined_state

        for rank, s in enumerate(states[1:]):
            for param_id, param_state in s.items():
                for k, tensor in param_state.items():
                    combined_state[param_id][k].append(tensor)
        return combined_state

    @staticmethod
    def _extract_constant_state(combined_state: Dict[int, Dict[str, List]], param_id: int) -> Dict:
        constant_state = {}  # This state is like step in Adam, not a tensor so we dont unpad or cat it.
        for k, v in combined_state[param_id].items():
            if torch.is_tensor(v[0]):
                continue
            elif len(set(v)) == 1:
                constant_state[k] = v[0]
            else:
                raise TypeError(f"Dont know how to expand optimizer param {k} with values {v}")
        return constant_state

    @property
    def _fsdp_instances(self) -> List[nn.Module]:
        """Returns all fsdp modules in self.modules() including self."""
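The body of this property falls outside the hunk shown; a minimal sketch of what the docstring describes (illustrative only, not the code from this PR):

        # sketch: collect self plus every nested FSDP module
        return [m for m in self.modules() if isinstance(m, FullyShardedDataParallel)]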

@@ -1509,64 +1413,23 @@ def _fsdp_instances(self) -> List[nn.Module]:
    def get_shard_from_optim_state_dict(self, full_optim_state_dict: Dict[str, Any]) -> Dict[str, Any]:
        """Get the portion of the optimizer state dict associated with the shard"""

        # Assert nesting is the same as it was at save time
        n_instances = len(self._fsdp_instances)
        n_local_params_in_opt = len(set(full_optim_state_dict["param_id_map"].values()))
        msg = f"Including itself, this model has {n_instances} nested instances. When the optimizer state was saved there were {n_local_params_in_opt}"
        stateless = len(full_optim_state_dict["state"]) == 0
        assert stateless or (n_instances == n_local_params_in_opt), msg

        stateless = len(full_optim_state_dict["state"]) == 0
        instance_list = self._fsdp_instances
        ou.check_param_counts_before_sharding(full_optim_state_dict, len(instance_list))
        if self.flatten_parameters:
Review comment: does this assume all inner FSDP instances also have flatten == True?
Reply: Yes, will assert
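A hedged sketch of what that assertion could look like (placement and wording are illustrative; only self.flatten_parameters and instance_list are taken from the surrounding diff):

            # sketch: require every nested FSDP instance to flatten its parameters
            assert all(
                getattr(m, "flatten_parameters", False) for m in instance_list
            ), "flatten_parameters must be True on all nested FSDP instances"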
            full_optim_state_dict = self._flatten_optim_state_dict(full_optim_state_dict)
            assert stateless or len(full_optim_state_dict["state"]) == len(instance_list)
            full_optim_state_dict = ou.flatten_optim_state_dict(full_optim_state_dict)
            assert len(full_optim_state_dict["state"]) in (0, len(instance_list))

        # get the portion of dict associated with the shard
        # get the portion of dict associated with the shard, in place
        for id, s in full_optim_state_dict["state"].items():
            for k, v in s.items():
                if torch.is_tensor(v):
                    v_shard, _ = self._get_shard(v)
                else:
                    v_shard = v  # dont partition entries that are not tensors
                    v_shard = v  # dont shard entries that are not tensors
                full_optim_state_dict["state"][id][k] = v_shard

        return full_optim_state_dict

    @staticmethod
    def _flatten_optim_state_dict(sd: Dict) -> Dict:
        param_id_map = sd["param_id_map"]
        num_local_params = len(set(param_id_map.values()))
        if sd["state"]:
            new_state: Dict = {local_id: {} for local_id in range(num_local_params)}
        else:
            new_state = {}
        constant_state = {}

        # assumes sd sorted
        for expanded_pid, buffers in sd["state"].items():
            consolidated_pid = param_id_map[expanded_pid]
            for buffer_name, p in buffers.items():
                if torch.is_tensor(p):
                    if buffer_name not in new_state[consolidated_pid]:
                        new_state[consolidated_pid][buffer_name] = []
                    new_state[consolidated_pid][buffer_name].append(p.reshape(-1))
                else:
                    assert isinstance(p, (float, int)), f"unexpected type {type(p)} in optimizer state[{buffer_name}]"
                    constant_state[buffer_name] = p
                    # TODO(SS): THIS COULD BE WRONG. What if step is different for different params... At least check

        for consolidated_pid, state in new_state.items():
            for buffer_name, tensors in state.items():
                new_state[consolidated_pid][buffer_name] = torch.cat(tensors)
            new_state[consolidated_pid].update(constant_state)
        new_sd = {"state": new_state, "param_groups": sd["param_groups"]}

        for pg_id, _ in enumerate(sd["param_groups"]):
            # TODO: this list could be huge. Can we avoid materializing?
            new_sd["param_groups"][pg_id]["params"] = list(range(num_local_params))

        return new_sd

    @torch.no_grad()
    def cast_inputs_to_fp16(*args: Any, **kwargs: Any) -> Tuple[Any, Any]:

Review comment: ❤️ this