[fix] OSS dict load/save fix - better fix than 383 and unit test (#386)
* WIP, needs to be fixed !

* should be a fix, many thanks Weiyi Zheng

* slightly better unit test, sorting the states on the way out

* reproducing the issue from Weiyi in a unit test, and finally properly fixing

* fixing unit test on pytorch1.5 - original loss diff 26.404895782470703 - 26.404342651367188
blefaudeux committed Feb 14, 2021
1 parent b666d6a commit 54bd62d
Showing 2 changed files with 19 additions and 16 deletions.
19 changes: 7 additions & 12 deletions fairscale/optim/oss.py
@@ -379,6 +379,8 @@ def state_dict(self) -> Dict[str, Any]:
                 global_id = self.param_to_index[local_index_to_param_id[local_param_index]]
                 state_dict["state"][global_id] = s["state"][local_param_index]
 
+        # Make sure that the parameters are sorted in the state, as expected
+        state_dict["state"] = dict(sorted(state_dict["state"].items()))
         return state_dict
 
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
@@ -391,23 +393,16 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
 
         # NOTE: PyTorch 1.5 does not index linearly but with the id(params) at saving time
         # we work around that here by using the fact that the params are ordered as in the param_groups
+        pytorch15_index_redirect = {k: i for i, k in enumerate(state_dict["state"].keys())}
 
-        for i_param, (key, value) in enumerate(state_dict["state"].items()):
-            param = self.index_to_param[i_param]
+        for key, value in state_dict["state"].items():
+            param = self.index_to_param[pytorch15_index_redirect[key]]
 
             # Populate the sharded optimizer state on the fly
             if self.param_to_rank[param] != self.rank:
                 state_dict["state"][key] = None
-
-            if key in self.index_to_param:
-                param = self.index_to_param[i_param]
-
-                # Only add this state to the sharded optimizer if it owns this param
-                for pg in self.optim.param_groups:
-                    if id(param) in [id(p) for p in pg["params"]]:
-                        self.optim.state[param] = recursive_copy_to_device(
-                            value, non_blocking=True, device=param.device
-                        )
+            else:
+                self.optim.state[param] = recursive_copy_to_device(value, non_blocking=True, device=param.device)
 
         super().load_state_dict(state_dict)
 
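For context on the hunk above: PyTorch 1.5 keys the saved optimizer state by id(param), while newer versions use the parameter's linear index; either way the entries are written in param_group order, so enumerating the saved keys recovers the linear index. A minimal standalone sketch of that redirect idea (illustrative names and data, not fairscale code):

# Sketch of the key-redirect trick, using only plain Python dicts.
def build_index_redirect(state: dict) -> dict:
    # Map whatever key was used at save time (id() or linear index) to the
    # linear parameter index, relying on the insertion order of the saved state.
    return {key: i for i, key in enumerate(state.keys())}

state_pt15 = {140230: {"step": 3}, 98112: {"step": 3}}  # id()-style keys (PyTorch 1.5)
state_pt16 = {0: {"step": 3}, 1: {"step": 3}}           # linear-index keys (PyTorch >= 1.6)

assert list(build_index_redirect(state_pt15).values()) == [0, 1]
assert build_index_redirect(state_pt16) == {0: 0, 1: 1}
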
16 changes: 12 additions & 4 deletions tests/optim/test_oss.py
@@ -774,7 +774,9 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
 
     def check_optimizer_equivalence(optimizer: Type[torch.optim.Optimizer], change_train_graph: bool = False):
         # Any model works. Add one different buffer per rank
-        trunk = torch.nn.Sequential(torch.nn.Linear(in_channels, hidden), torch.nn.Linear(hidden, hidden))
+        trunk = torch.nn.Sequential(
+            torch.nn.Linear(in_channels, hidden), torch.nn.Linear(hidden, hidden), torch.nn.Linear(hidden, hidden)
+        )
         trunk.register_buffer("test_buffer", torch.ones((1)) * rank)
         trunk.to(device)
 
@@ -832,8 +834,8 @@ def closure_sharded(input_tensor=input_tensor):
             loss_sharded_optim = cast(torch.Tensor, sharded_optimizer.step(closure=closure_sharded))
 
             assert torch.allclose(
-                loss_ddp, loss_sharded_optim
-            ), f"Losses differ in between Pytorch optim and OSS\nworld size {world_size}"
+                loss_ddp, loss_sharded_optim, rtol=1e-3
+            ), f"Losses differ in between Pytorch optim and OSS\n {loss_ddp.item()} - {loss_sharded_optim.item()} - world size {world_size}"
 
             check_same_model_params(oss_ddp_model, ddp_model)
 
@@ -859,10 +861,16 @@ def closure_sharded(input_tensor=input_tensor):
         sharded_optim_state_dict = sync_object_ranks(sharded_optim_state_dict, RECIPIENT_RANK, device)
 
         # - cross load the states
+        # run one step and check that the models are still the same
+        ddp_state_dict_ref = copy.deepcopy(ddp_state_dict)  # OSS will remove some states
         ddp_optimizer.load_state_dict(sharded_optim_state_dict)  # mixup on purpose !
         sharded_optimizer.load_state_dict(ddp_state_dict)
         check_step()
 
-        # - run one step and check that the models are still the same
+        # - self load, rewind, check no problem
+        # run one step and check that the models are still the same
+        ddp_optimizer.load_state_dict(ddp_state_dict_ref)
+        sharded_optimizer.load_state_dict(sharded_optim_state_dict)
+        check_step()
 
     for opt in [torch.optim.Adam, torch.optim.SGD]:
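The cross-load and rewind checks above rely on the basic state_dict round trip behaving the same for the sharded and the non-sharded optimizer. A minimal sketch of that round trip with a plain torch optimizer (illustrative only, not part of the test):

# Round-trip sketch: snapshot an optimizer's state and restore it into a fresh optimizer.
import copy

import torch

model = torch.nn.Linear(4, 4)
opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# One step so the optimizer holds per-parameter state (momentum buffers)
model(torch.ones(2, 4)).sum().backward()
opt.step()

snapshot = copy.deepcopy(opt.state_dict())

# A second optimizer over the same parameters can be rebuilt from the snapshot
opt_restored = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
opt_restored.load_state_dict(snapshot)

for p in model.parameters():
    assert torch.equal(opt.state[p]["momentum_buffer"], opt_restored.state[p]["momentum_buffer"])
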
