
[fix] OSS dict load/save fix - better fix than 383 and unit test #386

Merged
merged 5 commits into master from oss_dict_load_fix on Feb 14, 2021

Conversation

blefaudeux (Contributor) opened this pull request:

Before submitting

  • Was this discussed/approved via a GitHub issue? (not needed for typos or doc improvements)
  • Did you read the contributor guideline?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests? Yes: the updated test catches the case that was broken before, and the fix addresses it.

What does this PR do?

Fixes #380. A better take than #383 because it also fixes another issue which was not caught (#383 was not enough), and it reproduces the problem in an updated unit test so that this does not happen again. Thanks again @zhengwy888.

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

@facebook-github-bot added the "CLA Signed" label on Feb 13, 2021. (This label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed.)
@@ -859,10 +861,16 @@ def closure_sharded(input_tensor=input_tensor):
sharded_optim_state_dict = sync_object_ranks(sharded_optim_state_dict, RECIPIENT_RANK, device)

# - cross load the states
# run one step and check that the models are still the same
blefaudeux (Contributor, Author) commented:
this modified unit test does break on the old version
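
For context, a minimal sketch (generic PyTorch, with a hypothetical model and optimizers rather than the actual fairscale test) of the cross-load pattern the updated test exercises: save the state from one optimizer, load it into another, run one step, and check that the models are still the same.

import copy
import torch

# Two identical model/optimizer pairs (stand-ins for the DDP and OSS setups in the real test)
model_a = torch.nn.Linear(4, 2)
model_b = copy.deepcopy(model_a)
optim_a = torch.optim.SGD(model_a.parameters(), lr=0.1, momentum=0.9)
optim_b = torch.optim.SGD(model_b.parameters(), lr=0.1, momentum=0.9)

def step(model, optim):
    optim.zero_grad()
    loss = model(torch.ones(4)).sum()
    loss.backward()
    optim.step()
    return loss

# Populate optimizer state on both sides, then cross-load one state dict into the other
step(model_a, optim_a)
step(model_b, optim_b)
optim_b.load_state_dict(optim_a.state_dict())

# Run one more step and check that the models are still the same
loss_a = step(model_a, optim_a)
loss_b = step(model_b, optim_b)
assert torch.allclose(loss_a, loss_b), "losses diverged after cross-loading the state"
for p_a, p_b in zip(model_a.parameters(), model_b.parameters()):
    assert torch.allclose(p_a, p_b), "parameters diverged after cross-loading the state"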

@@ -379,6 +379,8 @@ def state_dict(self) -> Dict[str, Any]:
global_id = self.param_to_index[local_index_to_param_id[local_param_index]]
state_dict["state"][global_id] = s["state"][local_param_index]

# Make sure that the parameters are sorted in the state, as expected
state_dict["state"] = dict(sorted(state_dict["state"].items()))
blefaudeux (Contributor, Author) commented on Feb 13, 2021:
The state dict returned was sorted properly under the "param_groups" key, but not under the "state" field, which followed the partitioning order instead. When loading, I was assuming that it was sorted, so that would break.
PyTorch just uses the ordering from the "param_groups" key, and I was only testing loading OSS -> PyTorch and vice versa, so this was unfortunately not caught.
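
A minimal sketch of the bug and the fix, using a toy dict rather than the actual OSS state: when the "state" entries are inserted in partition order, rebuilding the dict sorted by key restores the ordering that "param_groups" (and the loading code) expects.

# "state" keys inserted in partition order rather than parameter order
state = {2: "state_for_param_2", 0: "state_for_param_0", 1: "state_for_param_1"}
print(list(state.keys()))  # [2, 0, 1] -> a loader assuming sorted keys would mis-assign state

# The fix applied in this PR: rebuild the dict sorted by key
state = dict(sorted(state.items()))
print(list(state.keys()))  # [0, 1, 2]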

Another commenter replied:

I didn't know Python's dictionaries are ordered, so I just looked it up. Turns out insertion order has been preserved since Python 3.6 (and guaranteed by the language since 3.7), good to know! https://stackoverflow.com/questions/39980323/are-dictionaries-ordered-in-python-3-6
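
To illustrate that point: a plain dict preserves insertion order, which is exactly why dict(sorted(...)) above is enough to fix the key ordering.

d = {}
d["b"] = 2
d["a"] = 1
print(list(d))                        # ['b', 'a'] -> insertion order is preserved
print(list(dict(sorted(d.items()))))  # ['a', 'b'] -> rebuilding from sorted items reorders the keys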


# Only add this state to the sharded optimizer if it owns this param
for pg in self.optim.param_groups:
if id(param) in [id(p) for p in pg["params"]]:
blefaudeux (Contributor, Author) commented:

This second check could mask an issue: we just checked above that this rank owns this param, so the extra check is not needed (and potentially risky).
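
A minimal sketch (hypothetical names, not fairscale's internals) of why the removed membership test is risky rather than helpful: once ownership has been established, a second check would silently skip any param for which the bookkeeping has drifted, hiding the inconsistency instead of surfacing it.

import torch

model = torch.nn.Linear(2, 2)
params = list(model.parameters())
optim = torch.optim.SGD(params, lr=0.1)

# Hypothetical ownership map: which rank owns each param (this toy rank owns everything)
param_to_rank = {id(p): 0 for p in params}
this_rank = 0

for p in params:
    if param_to_rank[id(p)] != this_rank:
        continue  # the ownership check that already exists earlier in the real code
    # A second test such as `if id(p) in [id(q) for pg in optim.param_groups for q in pg["params"]]`
    # would silently skip p whenever param_groups and the ownership map disagree,
    # masking the inconsistency instead of letting the state restore fail loudly.
    optim.state.setdefault(p, {})  # stand-in for restoring this param's sharded state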

@@ -832,8 +834,8 @@ def closure_sharded(input_tensor=input_tensor):
loss_sharded_optim = cast(torch.Tensor, sharded_optimizer.step(closure=closure_sharded))

assert torch.allclose(
loss_ddp, loss_sharded_optim
), f"Losses differ in between Pytorch optim and OSS\nworld size {world_size}"
loss_ddp, loss_sharded_optim, rtol=1e-3
blefaudeux (Contributor, Author) commented:
The rtol change is unfortunately only needed on PyTorch 1.5. Without it, on a two-GPU machine the losses come out as 26.404895782470703 vs 26.404342651367188 (which I assume is due to different casting and not structurally wrong), and the assert fires.
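
For reference, a small standalone check using the two numbers quoted above, showing why the default rtol=1e-5 fails here while rtol=1e-3 passes:

import torch

loss_ddp = torch.tensor(26.404895782470703, dtype=torch.float64)
loss_sharded = torch.tensor(26.404342651367188, dtype=torch.float64)

# Default tolerance: |a - b| ~ 5.5e-4 exceeds atol + rtol * |b| ~ 2.6e-4, so this fails
print(torch.allclose(loss_ddp, loss_sharded))             # False
# Relaxed tolerance: 1e-3 * 26.4 ~ 2.6e-2 comfortably covers the gap
print(torch.allclose(loss_ddp, loss_sharded, rtol=1e-3))  # True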

A reviewer (Contributor) commented:
I think it is worth documenting the reason for 1e-3.

blefaudeux (Contributor, Author) commented:

Sorry @min-xu-ai for the revert of the previous one; I just thought this was cleaner, and there was one fix still left out in the cold.

min-xu-ai (Contributor) left a review comment:

yeah, this seems to be much nicer.


@blefaudeux blefaudeux merged commit 54bd62d into master Feb 14, 2021
@blefaudeux blefaudeux deleted the oss_dict_load_fix branch February 16, 2021 23:06
Merging this pull request closes issue #380: [OSS] When loading a state, the insertion order for the params may not match the params_groups