
accelerate deepspeed and gradient accumulation integrate #23236

Merged: 37 commits merged into main from smangrul/accelerate-deepspeed-integrate on May 31, 2023

Conversation

@pacman100 (Contributor) commented May 9, 2023

What does this PR do?

  1. Shifts the DeepSpeed integration to Accelerate.
  2. Shifts gradient accumulation to Accelerate (a minimal usage sketch follows this list).
  3. To be merged after "shift torch dynamo handling to accelerate" #23168.
  4. No user-facing change. Users can now use `accelerate launch` with the Trainer for DeepSpeed, e.g.:
accelerate launch --num_processes=2 --mixed_precision=bf16 --use_deepspeed --gradient_accumulation_steps=1 --gradient_clipping=1 --zero3_init_flag=True --zero3_save_16bit_model=False --zero_stage=3 --offload_optimizer_device=none --offload_param_device=none ./examples/pytorch/text-classification/run_glue.py  --model_name_or_path bert-base-cased   --task_name $TASK_NAME   --do_train   --do_eval   --max_seq_length 128   --per_device_train_batch_size 16   --learning_rate 5e-5   --num_train_epochs 3   --output_dir /tmp/$TASK_NAME/ --overwrite_output_dir --bf16

The usual run using torchrun and Trainer args is unaffected:

torchrun --nnodes 1 --nproc-per-node 2 ./examples/pytorch/text-classification/run_glue.py  --model_name_or_path bert-base-cased   --task_name $TASK_NAME   --do_train   --do_eval   --max_seq_length 128   --per_device_train_batch_size 16   --learning_rate 5e-5   --num_train_epochs 3   --output_dir /tmp/$TASK_NAME/ --overwrite_output_dir --deepspeed ~/transformers/tests/deepspeed/ds_config_zero2.json
  5. Save and load utils are changed accordingly.
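
To illustrate points 1, 2, and 5, here is a minimal, hypothetical sketch (not the actual Trainer implementation) of what deferring DeepSpeed and gradient accumulation to Accelerate looks like; the toy model, optimizer, and dataset below are placeholders added for this example:

```python
# Hypothetical, simplified sketch (not the actual Trainer code) of a training
# loop that defers gradient accumulation and DeepSpeed to Accelerate.
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator

# Toy placeholders standing in for the real model, optimizer, and data.
model = torch.nn.Linear(128, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
dataset = TensorDataset(torch.randn(64, 128), torch.randint(0, 2, (64,)))
dataloader = DataLoader(dataset, batch_size=16)

# When the script is started with `accelerate launch --use_deepspeed ...`, the
# Accelerator picks up the DeepSpeed plugin from the launcher's environment;
# plain python/torchrun runs fall back to DDP or single-process training.
accelerator = Accelerator(gradient_accumulation_steps=4)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

model.train()
for inputs, labels in dataloader:
    # `accumulate` tracks the accumulation boundary, so the loop does not need
    # manual `step % gradient_accumulation_steps` bookkeeping.
    with accelerator.accumulate(model):
        loss = torch.nn.functional.cross_entropy(model(inputs), labels)
        accelerator.backward(loss)  # routes to the DeepSpeed engine when active
        optimizer.step()
        optimizer.zero_grad()
```

Because the same `Accelerator` object builds the DeepSpeed engine, the save/load utilities mentioned in point 5 can also route through Accelerate (for example, consolidating ZeRO-3 shards via `accelerator.get_state_dict(model)`) instead of Trainer-specific DeepSpeed code.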

@pacman100 pacman100 requested review from sgugger and muellerzr and removed request for sgugger May 9, 2023 15:11
@HuggingFaceDocBuilderDev commented May 9, 2023

The documentation is not available anymore as the PR was closed or merged.

@muellerzr (Contributor) left a comment

Thanks for working on this! Looks great, once tests pass fully :)

@pacman100 pacman100 changed the base branch from smangrul/accelerate-dynamo-integrate to main May 10, 2023 04:10
@pacman100 pacman100 changed the base branch from main to smangrul/accelerate-dynamo-integrate May 10, 2023 04:14
@pacman100 pacman100 changed the base branch from smangrul/accelerate-dynamo-integrate to main May 10, 2023 04:15
@pacman100 pacman100 changed the base branch from main to smangrul/accelerate-dynamo-integrate May 10, 2023 04:15
@muellerzr (Contributor) left a comment

LG2M, one clarification which indeed is true

src/transformers/trainer.py (review thread: resolved)
@sgugger (Collaborator) left a comment

Thanks for working on this. Is the diff longer than expected because of other PRs to be merged before?

Might be cool to have Stas take a look (not pinging him here too early) once this is ready to merge and tests are confirmed to all pass.

@pacman100 (Contributor, Author) commented:
> Thanks for working on this. Is the diff longer than expected because of other PRs to be merged before?

Due to updating from main, it is not showing the diff wrt previous branches. Weird.

> Might be cool to have Stas take a look (not pinging him here too early) once this is ready to merge and tests are confirmed to all pass.

Yes, definitely. All tests are passing already. Checked the slow tests offline.

@pacman100 pacman100 changed the base branch from smangrul/accelerate-dynamo-integrate to main May 10, 2023 19:07
@pacman100 pacman100 changed the base branch from main to smangrul/accelerate-dynamo-integrate May 10, 2023 19:08
@pacman100 (Contributor, Author) commented:
@sgugger, the diff is now specific only to the DeepSpeed changes + gradient accumulation changes + saving/loading changes w.r.t. the previous PR.

@pacman100 pacman100 requested a review from stas00 May 10, 2023 19:09
@pacman100 (Contributor, Author) commented:
Hello @stas00, please review this PR, which aims to shift the DeepSpeed handling in the Trainer to Accelerate. Thank you!

@stas00 (Contributor) left a comment

@pacman100, I think the main concern is that this PR appears to break backward compatibility (BC) in at least a few places - please correct me if I'm wrong. I think those cases are super minor and probably won't cause many breakages for users, if any. I'll leave it up to you to decide.

As I don't know all the nuances of Accelerate's Deepspeed integration I can't do a detailed review, but if all the SLOW tests pass it should be good.

Request: as Accelerate is taking over, please remove me from the issue/PR templates as the DeepSpeed integration maintainer while you're at it, since I won't be able to support users any longer.

Thank you!

docs/source/en/main_classes/deepspeed.mdx (review thread: outdated, resolved)
src/transformers/deepspeed.py (review thread: resolved)
@muellerzr (Contributor) left a comment

Thanks for doing this! Looks great

Base automatically changed from smangrul/accelerate-dynamo-integrate to main May 31, 2023 09:12
@pacman100 pacman100 changed the title Smangrul/accelerate deepspeed integrate accelerate deepspeed and gradient accumulation integrate May 31, 2023
@pacman100 pacman100 merged commit a73b1d5 into main May 31, 2023
@pacman100 pacman100 deleted the smangrul/accelerate-deepspeed-integrate branch May 31, 2023 09:46
sheonhan pushed a commit to sheonhan/transformers that referenced this pull request Jun 1, 2023
accelerate deepspeed and gradient accumulation integrate (#23236)

* mixed precision support via accelerate

* fix issues

* fix for the sharded ddp case

* fix flax and tf failing tests

* refactor the place to create `Accelerator` object

* move ddp prep to accelerate

* fix 😅

* resolving comments

* move fsdp handling to accelerate

* fixes

* fix saving

* shift torch dynamo handling to accelerate

* shift deepspeed integration and save & load utils to accelerate

* fix accelerate launcher support

* oops

* fix 🐛

* save ckpt fix

* Trigger CI

* nasty 🐛 😅

* as deepspeed needs grad_acc fixes, transfer grad_acc to accelerate

* make tests happy

* quality ✨

* loss tracked needs to account for grad_acc

* fixing the deepspeed tests

* quality ✨

* 😅😅😅

* tests 😡

* quality ✨

* Trigger CI

* resolve comments and fix the issue with the previous merge from branch

* Trigger CI

* accelerate took over deepspeed integration

---------

Co-authored-by: Stas Bekman <stas@stason.org>
gojiteji pushed a commit to gojiteji/transformers that referenced this pull request Jun 5, 2023
novice03 pushed a commit to novice03/transformers that referenced this pull request Jun 23, 2023