
[FSDP] redundant additional allgather during backward when using FSDP FULL_SHARD with gradient checkpointing #30404

Closed
2 of 4 tasks
yundai424 opened this issue Apr 22, 2024 · 7 comments · Fixed by #31578


@yundai424
Contributor

yundai424 commented Apr 22, 2024

System Info

  • transformers version: 4.39.3
  • Platform: Linux-5.15.138.1-4.cm2-x86_64-with-glibc2.35
  • Python version: 3.10.2
  • Huggingface_hub version: 0.22.2
  • Safetensors version: 0.4.2
  • Accelerate version: 0.29.3
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.1.2 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: yes

Who can help?

@pacman100 @muellerzr

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

For the reproduction script, please refer to this gist: https://gist.github.com/yundai424/1b6f0fa9f23796033aaf585ae9e1d4b6

Input alpaca data can be downloaded via https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/blob/1d971b5591d198b810ccca6feaed420df9f39d28/alpaca/alpaca_data.json

If the script is run as-is (essentially very basic finetuning, with FSDP and gradient checkpointing enabled via the --fsdp and --gradient_checkpointing flags to TrainingArguments), it uses the existing way of enabling gradient checkpointing, which goes directly through torch.utils.checkpoint. Looking at the profiled trace view, it's obvious that there are 2 allgather ops per transformer block backward; tracing the call stack a bit shows that the 1st one corresponds to the allgather for the gradient checkpointing forward recompute, and the 2nd one to the actual pre_backward hook for FSDP. This is unnecessarily redundant because the time gap between them is very short (i.e. just the forward compute in gradient checkpointing), and it turns 2 collectives per backprop into 3 (hence 3 collectives per training step into 4):

[Screenshot (Apr 22, 2024): profiler trace showing 2 allgather ops per transformer block backward]
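For reference, the flags involved condense to something like the following minimal sketch (placeholder values, not the actual gist):

```python
from transformers import TrainingArguments

# Minimal sketch of the setup described above (placeholder values, not the
# actual gist): FSDP FULL_SHARD plus the default gradient checkpointing path,
# which goes through torch.utils.checkpoint.
training_args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=1,
    fsdp="full_shard auto_wrap",
    gradient_checkpointing=True,
)
```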

If lines 67-75 in the training script are uncommented, it uses a different torch gradient checkpointing implementation, torch.distributed.algorithms._checkpoint, as also suggested by Meta's official llama recipe. It's another wrapping layer on top of torch.utils.checkpoint. This addresses the redundant allgather described above, resulting in a trace that looks as follows:

[Screenshot (Apr 22, 2024): profiler trace showing a single allgather per transformer block backward]

Note that the relative order of wrapping with the FSDP module and calling apply_activation_checkpointing doesn't matter here. The code example above calls apply_activation_checkpointing before wrapping with the FSDP module, but it works the other way around too.
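For concreteness, a minimal sketch of that alternative wrapping, assuming a Llama-style model whose transformer blocks are LlamaDecoderLayer (the check_fn and the model variable depend on the actual training script):

```python
import functools

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# Wrap each transformer block with the non-reentrant checkpoint wrapper; as
# described above, this avoids the extra allgather during the backward-time
# recompute.
non_reentrant_wrapper = functools.partial(
    checkpoint_wrapper,
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)
apply_activation_checkpointing(
    model,  # the model from the training script (FSDP wrapping can come before or after)
    checkpoint_wrapper_fn=non_reentrant_wrapper,
    check_fn=lambda submodule: isinstance(submodule, LlamaDecoderLayer),
)
```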

Expected behavior

When using FSDP with gradient checkpointing, Transformers should instead use apply_activation_checkpointing from torch.distributed.algorithms._checkpoint.checkpoint_wrapper.

Will be happy to draft a PR.

@yundai424 yundai424 changed the title [FSDP] redundant additional allgather during backward when using FSDP with gradient checkpointing [FSDP] redundant additional allgather during backward when using FSDP FULL_SHARD with gradient checkpointing Apr 22, 2024
@SunMarc
Member

SunMarc commented Apr 23, 2024

Hi @yundai424, thanks for the detailed report! If you can submit a PR to fix this issue, that would be awesome 🔥 cc @pacman100

@SunMarc
Member

SunMarc commented Jun 17, 2024

Hi @yundai424, hope you are doing well. Do you still want to draft the PR? Otherwise, I will try to fix the issue based on your report. Any advice is appreciated.

@yundai424
Contributor Author

Hi @SunMarc, thanks for checking! I later found there's already another flag (this) under fsdp_config that uses the torch distributed checkpoint wrapper, which addresses the issue. But it would be nice to just merge them.
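For reference, a sketch of enabling that path through TrainingArguments; the activation_checkpointing key inside fsdp_config is the flag referenced above, and the other values are placeholders:

```python
from transformers import TrainingArguments

# Sketch: enable FSDP's own activation checkpointing via fsdp_config instead of
# the plain gradient_checkpointing flag (the two cannot both be True).
training_args = TrainingArguments(
    output_dir="out",
    fsdp="full_shard auto_wrap",
    fsdp_config={"activation_checkpointing": True},
    gradient_checkpointing=False,
)
```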

@SunMarc
Member

SunMarc commented Jun 20, 2024

Hi @yundai424, thanks for answering, and I'm glad you managed to fix the issue! What do you mean by merging them? I guess you are talking about gradient_checkpointing and activation_checkpointing.

@yundai424
Contributor Author

yundai424 commented Jun 21, 2024

@SunMarc yes exactly!

@SunMarc
Member

SunMarc commented Jun 21, 2024

I see here that we can't set both args to True and that we advise the user to use FSDP activation_checkpointing.

I think we should not merge the two args, for backward compatibility, but maybe add a warning when gradient_checkpointing is set to True while using FSDP, advising users to use the activation_checkpointing arg instead, since it uses a different torch gradient checkpointing implementation that is better suited for FSDP! Would you like to open the PR? Otherwise, I can do it!
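A rough sketch of what such a warning could look like (hypothetical placement, with args standing for the parsed TrainingArguments; not the actual patch):

```python
import warnings

# Hypothetical sketch of the proposed warning: if FSDP is enabled and the user
# asked for plain gradient_checkpointing, point them to the FSDP-native
# activation_checkpointing flag instead.
if args.fsdp and args.gradient_checkpointing:
    warnings.warn(
        "When using FSDP, prefer setting `activation_checkpointing: true` in "
        "`fsdp_config` over `gradient_checkpointing=True`; it uses "
        "torch.distributed.algorithms._checkpoint and avoids a redundant "
        "allgather during backward."
    )
```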

@yundai424
Contributor Author

yundai424 commented Jun 24, 2024

@SunMarc that makes sense, I can open the PR to add the warning :) thanks!
