[fix] Decouple move_params_to_cpu from the mixed_precision. #822

Merged · 29 commits · Oct 27, 2021

Conversation

@anj-s anj-s (Contributor) commented Oct 21, 2021

What does this PR do?

This PR decouples mixed_precision from move_params_to_cpu. We can now support full FP16 or full FP32 workloads while offloading params and grads to CPU.

The main cut points that have been modified are where we create the fp16 shard, where we move params from the fp32 shard to the device-resident fp16 shard, and where we discard the fp16 shard.

One issue with the code is that the shards are named fp16 and fp32 rather than something more general such as compute and storage, so _fp16_shard may not actually be fp16. This is confusing to read; I will rename the shards in an upcoming PR.
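
For context (not part of the PR itself), here is a minimal usage sketch of the configuration this change enables: CPU offload with full-precision compute. It assumes fairscale's FullyShardedDataParallel constructor arguments move_params_to_cpu, move_grads_to_cpu, and mixed_precision, and that a torch.distributed process group has already been initialized.

import torch
from fairscale.nn import FullyShardedDataParallel as FSDP

# Assumes torch.distributed.init_process_group(...) has been called and a
# CUDA device is available.
model = torch.nn.Linear(1024, 1024).cuda()   # plain FP32 module

fsdp_model = FSDP(
    model,
    mixed_precision=False,     # keep FP32 compute, no FP16 casting
    move_params_to_cpu=True,   # params live on CPU between uses
    move_grads_to_cpu=True,    # grads are offloaded to CPU as well
)

x = torch.randn(8, 1024, device="cuda")
fsdp_model(x).sum().backward()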

Before submitting

  • Did you have fun?
    • Make sure you had fun coding 🙃
  • Did you read the contributor guideline?
  • Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
    • N/A
  • Did you make sure to update the docs?
    • N/A
  • Did you write any new necessary tests?
    • N/A
  • Did you update the changelog? (if needed)
    • N/A

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

@facebook-github-bot facebook-github-bot added the CLA Signed label Oct 21, 2021
@anj-s anj-s marked this pull request as draft October 21, 2021 12:23
@anj-s anj-s requested a review from min-xu-ai October 23, 2021 15:40
@anj-s anj-s marked this pull request as ready for review October 23, 2021 15:40
@anj-s anj-s changed the title [fix] Decouple CPU offload from the mixed precision parameter. [fix] Decouple move_params_to_cpu from the mixed_precision. Oct 23, 2021
@min-xu-ai min-xu-ai (Contributor) left a comment

Looks nice! @zhaojuanmao, do you want to take a look too?

Comment on lines +1638 to +1641
if (self.mixed_precision or self.move_params_to_cpu) and not force_full_precision:
    self._free_fp16_param_shard([p])

if self.move_params_to_cpu and (self.params[0].dtype == self.compute_dtype):
Contributor

The two if-conditions are guarding the same code? Why have two of them? Having line 1639 and line 1642 duplicated might not be good?

Contributor Author

It was more for the sake of readability that I split this up.

@@ -180,8 +180,7 @@ class FullyShardedDataParallel(nn.Module):
         if ``True``, flatten parameters into a single contiguous tensor,
         which improves training speed.
     move_params_to_cpu (bool, Optional):
-        if ``True``, offload FP32 params to CPU. This is only relevant when
-        *``mixed_precision``* is ``True``.
+        if ``True``, offload params to CPU.
Contributor

Is there a requirement here that params need to be FP32 and not FP16? If so, perhaps mention that here?

Contributor Author

There is no requirement for now. FP32 or FP16 params can be offloaded to CPU.
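
A short sketch of what the reply describes, under the same assumptions as above (fairscale FSDP constructor arguments, initialized process group); compute_dtype is an assumed argument used here only to keep the compute dtype aligned with the FP16 params.

import torch
from fairscale.nn import FullyShardedDataParallel as FSDP

# FP32 params offloaded to CPU.
fp32_model = FSDP(torch.nn.Linear(512, 512).cuda(),
                  move_params_to_cpu=True, mixed_precision=False)

# Full-FP16 workload: the module is cast to half before wrapping; its FP16
# params are offloaded to CPU in the same way.
fp16_model = FSDP(torch.nn.Linear(512, 512).half().cuda(),
                  move_params_to_cpu=True, mixed_precision=False,
                  compute_dtype=torch.float16)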

p._fp32_shard = p._fp32_shard.pin_memory()
p.data = p._fp32_shard

if self.move_params_to_cpu or self.mixed_precision:
Contributor

p._fp16_shard is only needed when self.mixed_precision=True, right?

Contributor Author

The shard is needed any time you offload params to CPU. The shard is named fp16, which is what causes the confusion; I am renaming it in a follow-up PR.
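
To make this concrete, here is a schematic sketch (not the fairscale implementation) of the pattern behind the misleadingly named _fp16_shard when offloading: the CPU-resident shard is copied into a device-side buffer in the compute dtype before use, and that buffer is freed afterwards. The helper names below are illustrative.

import torch

def bring_shard_to_device(cpu_shard: torch.Tensor, compute_dtype: torch.dtype) -> torch.Tensor:
    # Copy the CPU shard into a device buffer. With mixed precision this is an
    # FP32 -> FP16 cast; with pure FP16 or FP32 offload it is a same-dtype
    # host-to-device copy, but a device buffer is needed either way.
    # non_blocking=True is effective because the CPU shard is pinned (see the
    # pin_memory() call in the diff above).
    device_shard = torch.empty_like(cpu_shard, dtype=compute_dtype, device="cuda")
    device_shard.copy_(cpu_shard, non_blocking=True)
    return device_shard

def free_device_shard(device_shard: torch.Tensor) -> None:
    # Release device memory between passes by shrinking the underlying storage
    # to zero elements; the tensor object itself stays alive for reuse.
    device_shard.storage().resize_(0)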

@anj-s anj-s merged commit ed7ca76 into main Oct 27, 2021
@anj-s anj-s deleted the fp32-offload branch October 27, 2021 21:31
vtantia pushed a commit that referenced this pull request Oct 29, 2021
* remove offload dependency on fp16

* update python version for cpu tests

* run CPU tests with updated PyTorch version

* split changes

* revert tests config

* fix lint errors

* update nightly and test PyTorch versions

* skip failing multiprocess pipe test

* always skip test

* always skip test

* always skip test

* lint error

* skip unsupported versions

* improve skip message

* lint errors

* modify docs

* add tests

* fix test failures

* modify comments

* fix lint errors

* fix lint errors
Labels: CLA Signed, FSDP + SSD offload