
Fix: SAM not working with FSDP/DeepSpeed and LR scheduler. #3259

Merged: 5 commits into mosaicml:dev on May 10, 2024

Conversation

@Joqsan (Contributor) commented on May 6, 2024

Description

I traced this bug back to Composer from LLMFoundry after we noticed that our models didn't learn anything when running distributed training with FSDP alongside SAM and an LR scheduler. Running the same setup without distributed training (on a single GPU) works fine.

Reproducibility

Check out the source branch at commit d882d8f4, which contains only the tests (without the fix), and run:

## All tests pass (single GPU, no distributed training)
make test-dist-gpu WORLD_SIZE=1 EXTRA_ARGS="tests/algorithms/test_sam.py::TestSAMParamGroups"

## All tests fail (FSDP and DeepSpeed tests)
make test-dist-gpu WORLD_SIZE=2 EXTRA_ARGS="tests/algorithms/test_sam.py::TestSAMParamGroups"

Problem (FSDP)

  • When using SAM with an LR scheduler that resets the starting param_groups[0]['lr'] to 0.0 (such as CosineAnnealingWithWarmupScheduler), the scheduler varies the learning rate at SAMOptimizer.param_groups[0]['lr'], but not the one at base_optimizer.param_groups[0]['lr'] (which is also reset to 0.0).
  • So the latter stays at 0.0 for the whole training run, which is why we don't see any learning (a toy sketch of the symptom follows this list).
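The following self-contained toy (plain PyTorch, not Composer's actual classes) illustrates the symptom under the assumption that the wrapper and the base optimizer no longer share their param-group dicts; the warmup reset to 0.0 is applied by hand:

# Toy sketch of the symptom in plain PyTorch (not Composer's real classes).
# An LR scheduler only mutates the param-group dicts of the optimizer it is
# attached to, so once the wrapper and its base optimizer stop sharing those
# dicts, the base optimizer's lr never moves again.
import torch
from torch.optim.lr_scheduler import LambdaLR

model = torch.nn.Linear(4, 4)
base_optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
wrapper = torch.optim.SGD(model.parameters(), lr=0.1)  # stand-in for a SAMOptimizer that
                                                       # ended up with its own group dicts

# Mimic the warmup reset described above: the base optimizer's lr sits at 0.0.
base_optimizer.param_groups[0]['lr'] = 0.0

# Linear warmup over 10 steps, attached to the wrapper (the optimizer the
# scheduler is handed).
scheduler = LambdaLR(wrapper, lr_lambda=lambda step: min(step / 10, 1.0))

for _ in range(5):
    scheduler.step()

print(wrapper.param_groups[0]['lr'])         # 0.05 -- the scheduler moved it
print(base_optimizer.param_groups[0]['lr'])  # 0.0  -- but this lr drives the real update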

Cause (short answer)

  • SAMOptimizer.param_groups[i] and base_optimizer.param_groups[i] do not reference the same param groups (and hence not the same params) because of the way Composer handles the sharding.

Cause (long answer)

  • Right when Event.INIT is triggered, SAMOptimizer binds references to the param groups in base_optimizer, so up to this point both optimizers reference the same param groups.
  • Later (after initializing the optimizers and anything triggered by Event.INIT), Composer initializes FSDP (here), which is surprising since the FSDP docs advise initializing the optimizer after doing the sharding, not the other way around.
  • Composer does the sharding by clearing the param groups of the SAMOptimizer and binding the new sharded params to it.
  • At the end of the whole Trainer initialization, SAMOptimizer and base_optimizer reference different param groups:
    • SAMOptimizer the sharded ones.
    • base_optimizer the original ones.
  • During training, state.schedulers[0].step() updates the 'lr' entries of SAMOptimizer's param groups, which are not the ones referenced by base_optimizer, but the real optimizer step is done by base_optimizer, not SAMOptimizer. So no learning (a toy reproduction follows this list).
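To make the reference breakage concrete, here is a toy reproduction. ToySAM is a hypothetical stand-in for SAMOptimizer, and the clear()/add_param_group() calls stand in for what Composer's FSDP preparation effectively does to the wrapper's param groups:

# Toy reproduction of the reference breakage (hypothetical ToySAM class).
import torch

class ToySAM(torch.optim.Optimizer):
    def __init__(self, base_optimizer):
        self.base_optimizer = base_optimizer
        # torch.optim.Optimizer.__init__ builds a new param_groups *list*, but
        # it appends the very same group *dicts*, so 'lr' etc. are shared.
        super().__init__(base_optimizer.param_groups, defaults={})

model = torch.nn.Linear(4, 4)
base_optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
sam = ToySAM(base_optimizer)
assert sam.param_groups[0] is base_optimizer.param_groups[0]  # Event.INIT: shared

# Sharding applied to the wrapper only: drop its groups and register the
# (notionally sharded) params as a brand-new group dict.
sam.param_groups.clear()
sam.state.clear()
sam.add_param_group({'params': list(model.parameters()), 'lr': 0.1})
assert sam.param_groups[0] is not base_optimizer.param_groups[0]  # no longer shared

# From here on, a scheduler stepping on `sam` only touches sam.param_groups,
# while base_optimizer.step() keeps reading its own, now-stale, 'lr'.
sam.param_groups[0]['lr'] = 0.05
print(base_optimizer.param_groups[0]['lr'])  # still 0.1 -- never updated again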

Problem (DeepSpeed)

  • As far as I debugged, the problem above is specific to PyTorch's FSDP and does not affect DeepSpeed (which handles the sharding process on its own).
  • But a problem arises when allocating the internal optimizer state while initializing the DeepSpeedEngine (here).
    • SAMOptimizer performs the optimizer step iff a closure is passed, while BF16_Optimizer.initialize_optimizer_states() needs to do an optimizer step without any closure (a toy sketch of this mismatch follows this list).
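A minimal toy sketch of the mismatch, with a hypothetical ClosureOnlyOptimizer that only mimics the closure requirement; it raises just to make the incompatibility visible, the point being that a plain, closure-less step() cannot trigger SAM's update:

# Toy sketch of the DeepSpeed incompatibility (hypothetical class; not
# Composer's or DeepSpeed's real code).
import torch

class ClosureOnlyOptimizer(torch.optim.Optimizer):
    def __init__(self, base_optimizer):
        self.base_optimizer = base_optimizer
        super().__init__(base_optimizer.param_groups, defaults={})

    def step(self, closure=None):
        if closure is None:
            # SAM needs the closure for its ascent/descent steps, so a
            # closure-less call cannot perform a normal update.
            raise RuntimeError('this optimizer requires a closure to step')
        loss = closure()              # the real SAM two-step update would go here
        self.base_optimizer.step()
        return loss

model = torch.nn.Linear(2, 2)
opt = ClosureOnlyOptimizer(torch.optim.SGD(model.parameters(), lr=0.1))

# Roughly what the engine's state initialization does: a plain step().
try:
    opt.step()                        # no closure -> cannot do a SAM update
except RuntimeError as e:
    print(e)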

Solution

  • Let any sharding-related initialization be done on the param groups of the would-be base_optimizer instance, and only after all of that wrap the resulting optimizer in a SAMOptimizer instance, binding references to the sharded param groups.
  • For DeepSpeed this works too, even though the underlying base_optimizer is a DeepSpeed ZeRO BF16_Optimizer instance, since it uses the same underlying parameter groups as the original PyTorch optimizer (see here and here).

This fix is achieved by triggering the SAMOptimizer wrapping later in the code (a rough sketch of the resulting ordering is shown below). After this change, distributed training works as expected.
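A rough sketch of that ordering, with hypothetical names: shard_fn stands in for Composer's FSDP/DeepSpeed preparation, and ToySAM for the SAMOptimizer wrapping.

# Sketch of the fixed ordering (hypothetical helper names). All sharding-
# related surgery is applied to the base optimizer's param groups first, and
# only then is the SAM wrapper created, so it binds to the sharded groups.
import torch

class ToySAM(torch.optim.Optimizer):
    """Toy stand-in for SAMOptimizer: shares the base optimizer's group dicts."""
    def __init__(self, base_optimizer):
        self.base_optimizer = base_optimizer
        super().__init__(base_optimizer.param_groups, defaults={})

def prepare_for_distributed(model, base_optimizer, shard_fn):
    # `shard_fn` stands in for Composer's FSDP/DeepSpeed preparation, which may
    # clear and repopulate `base_optimizer.param_groups` with sharded params.
    sharded_model = shard_fn(model, base_optimizer)

    # Wrap *after* sharding: the wrapper now shares the (sharded) group dicts,
    # so LR scheduler updates reach the optimizer that actually takes the step.
    sam = ToySAM(base_optimizer)
    assert sam.param_groups[0] is base_optimizer.param_groups[0]
    return sharded_model, sam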

Before submitting

  • Have you read the contributor guidelines?
  • Is this change a documentation change or typo fix? If so, skip the rest of this checklist.
  • Was this change discussed/approved in a GitHub issue first? It is much more likely to be merged if so.
  • Did you update any related docs and document your change?
  • Did you update any related tests and add any new tests related to your change? (see testing)
  • Did you run the tests locally to make sure they pass?
  • Did you run pre-commit on your change? (see the pre-commit section of prerequisites)

@Joqsan requested a review from a team as a code owner on May 6, 2024, 14:28
@mvpatel2000 (Contributor) commented:

Thanks for the PR! Once the tests pass, you can tag me for review.

Also, the PR description is really well done 🙏 -- appreciate the detailed writeup. Fix makes sense to me! Once tests pass I'll approve

@Joqsan (Contributor, Author) commented on May 8, 2024

@mvpatel2000 Hi!

I fixed the failing tests.

Now the SAM algorithm always sets state.scaler to ClosureGradScaler(). This is safe, since Trainer._use_grad_scaling() takes care of enabling/disabling the grad scaling step during training.

@Joqsan (Contributor, Author) commented on May 8, 2024

@mvpatel2000 yapf formatting workflows are failing, but they are not related to files I changed in this PR.

How should I proceed?

@mvpatel2000 (Contributor) commented:

I can help out with linting :) will take a look soon!

@mvpatel2000 (Contributor) left a review:

LGTM!

@mvpatel2000 merged commit e625a06 into mosaicml:dev on May 10, 2024
15 checks passed
j316chuck pushed a commit that referenced this pull request May 16, 2024
* add SAM tests with FSDP and DeepSpeed

* fix SAM for distributed training

* SAMOptimizer always needs ClosureGradScaler

* lint

---------

Co-authored-by: Mihir Patel <mihir.v.patel7@gmail.com>