
update fsdp mixed precision #2047

Merged: 6 commits into mosaicml:dev on Mar 9, 2023
Conversation

@vchiley (Contributor) commented Mar 8, 2023

What does this PR do?

Our MixedPrecision setting 'DEFAULT' should be updated to:

# If mixed_precision = 'default'
mixed_precision = MixedPrecision(
  param_dtype=autocast_precision,
  reduce_dtype=torch.float32,
  buffer_dtype=autocast_precision,
)

This will make 'DEFAULT' emulate standard AMP behavior.

  • Use 'FULL' if you need full (fp32) precision.
  • Use 'PURE' if you want to live dangerously.
  • Standard AMP should be the 'DEFAULT' MixedPrecision setting (see above).

The current 'DEFAULT' shouldn't really be used.
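
For reference, a rough sketch of what the three presets correspond to in terms of torch.distributed.fsdp.MixedPrecision (the dtype choices here are illustrative, not copied from the implementation; the actual mapping lives in composer/trainer/mosaic_fsdp.py):

import torch
from torch.distributed.fsdp import MixedPrecision

autocast_precision = torch.bfloat16  # e.g. the trainer's autocast dtype

# 'FULL': keep params, gradient reduction, and buffers in fp32.
mp_full = MixedPrecision(
  param_dtype=torch.float32,
  reduce_dtype=torch.float32,
  buffer_dtype=torch.float32,
)

# 'PURE': params, gradient reduction, and buffers all in the autocast dtype.
mp_pure = MixedPrecision(
  param_dtype=autocast_precision,
  reduce_dtype=autocast_precision,
  buffer_dtype=autocast_precision,
)

# Proposed 'DEFAULT': low-precision params/buffers with fp32 gradient reduction,
# which mirrors standard AMP behavior.
mp_default = MixedPrecision(
  param_dtype=autocast_precision,
  reduce_dtype=torch.float32,
  buffer_dtype=autocast_precision,
)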

What issue(s) does this change relate to?

Fixes https://mosaicml.atlassian.net/browse/CO-1896

Before submitting

  • Have you read the contributor guidelines?
  • Was this change discussed/approved in an issue first? It is much more likely to be merged if so.
  • Did you update any related docs and document your change?
  • Did you run the tests locally to make sure they pass?
  • Did you run pre-commit on your change? (see the pre-commit section of prerequisites)

@vchiley self-assigned this Mar 8, 2023
@dakinggg (Contributor) left a comment:

LGTM, would like abhi to approve as well

(Resolved review threads on docs/source/notes/distributed_training.rst and composer/trainer/mosaic_fsdp.py)
vchiley and others added 2 commits March 8, 2023 16:33
@vchiley (Contributor, Author) commented Mar 9, 2023:

Fix for the test failing in #2050.

@bcui19 (Contributor) left a comment:

LGTM!

@abhi-mosaic (Contributor) commented:

@vchiley can we get a side-by-side LLM run before and after this change? I.e., can we do a small 125M run with the previous 'DEFAULT' and the new 'DEFAULT' behavior and post the screenshot here for posterity? Basically loss and MFU side by side.

@vchiley (Contributor, Author) commented Mar 9, 2023:

Testing the behavior change here:

gpt125m-flash-mpf is using mp: 'FULL'
gpt125m-flash-mpp is using mp: 'PURE'
gpt125m-flash-mpd is using mp: 'DEFAULT'

gpt125m-flash-mpp16b16 is using mp: {'param_dtype': 'bf16', 'buffer_dtype': 'bf16'}, i.e. the proposed 'DEFAULT' (here reduce_dtype will default to None and use fp32; see the sketch below).
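
A hypothetical sketch of how a dict config like the one above could be mapped onto an FSDP MixedPrecision object, with an unspecified reduce_dtype falling back to fp32 as described (the helper name and dtype table are made up for illustration; this is not Composer's actual parsing code):

import torch
from torch.distributed.fsdp import MixedPrecision

# Hypothetical string-to-dtype table for the names used in the configs above.
_DTYPES = {'bf16': torch.bfloat16, 'fp16': torch.float16, 'fp32': torch.float32}

def mixed_precision_from_dict(cfg):
  """Illustrative helper: any dtype left unspecified falls back to fp32."""
  return MixedPrecision(
    param_dtype=_DTYPES[cfg.get('param_dtype', 'fp32')],
    reduce_dtype=_DTYPES[cfg.get('reduce_dtype', 'fp32')],
    buffer_dtype=_DTYPES[cfg.get('buffer_dtype', 'fp32')],
  )

# The gpt125m-flash-mpp16b16 config, equivalent to the proposed 'DEFAULT':
mp = mixed_precision_from_dict({'param_dtype': 'bf16', 'buffer_dtype': 'bf16'})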

Running MosaicGPT 125M on 8 GPUs to the Chinchilla point shows no loss difference between any of the runs. The speed is virtually the same for all of the configs, except 'FULL', which is ~3% slower.
(The old 'DEFAULT' is ~1% slower than the new 'DEFAULT'.)
(In this exact setup the new 'DEFAULT' is ~0.2% faster than 'PURE'; this is either within measurement noise, or the casting is more expensive than one would hope.)

[Screenshot: 125M run comparison, 2023-03-09 12:21 PM]

I tried running the 7B model with the same set of four mixed precision configs on 32 GPUs to compare performance:
[Screenshot: 7B run comparison, 2023-03-09 12:24 PM]

It seems like 'PURE' and the new 'DEFAULT' run at the same speed and are faster than the old 'DEFAULT', and 'FULL' is slowest.
The best explanation I have: the bwd pass costs roughly 2x the fwd pass; if param communication in the bwd pass is bf16, the bwd pass isn't memory bound and becomes compute bound.
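
A back-of-envelope illustration of the bytes-on-the-wire difference (numbers are rough and only meant to show the 2x gap between fp32 and bf16 parameter all-gathers for a ~7B-param model):

# Rough per-pass parameter all-gather volume for a ~7B-param model under FSDP.
n_params = 7e9

for name, bytes_per_param in [('fp32', 4), ('bf16', 2)]:
  gb = n_params * bytes_per_param / 1e9
  print(f'{name}: ~{gb:.0f} GB of params gathered per full fwd or bwd pass')

# fp32 param communication moves 2x the bytes of bf16, so it is far more likely to
# leave the (roughly 2x-fwd-cost) bwd pass waiting on communication rather than compute.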

@vchiley merged commit aaa8dcc into mosaicml:dev on Mar 9, 2023
bandish-shah pushed a commit that referenced this pull request Mar 14, 2023
* update fsdp mixed precision

* Update docs/source/notes/distributed_training.rst

Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>

* Update docs/source/notes/distributed_training.rst

* Apply suggestions from code review

Co-authored-by: Mihir Patel <mihir.v.patel7@gmail.com>

---------

Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>
Co-authored-by: Mihir Patel <mihir.v.patel7@gmail.com>