
Llama et al. / FSDP : Fix breaking change in 4.40 for FSDP #31161

Merged 14 commits from fix-llama-fsdp into main on Jun 26, 2024

Conversation

@younesbelkada (Contributor) commented May 31, 2024

What does this PR do?

Fixes: #30523

Click to see the snippet (make sure to run `accelerate config` and select the FSDP options beforehand, then run the script with `accelerate launch script.py`)
from functools import partial
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from accelerate import Accelerator

# verify that FSDP activation checkpointing support is available by importing:
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    CheckpointImpl,
    apply_activation_checkpointing,
)

from transformers.models.llama.modeling_llama import LlamaDecoderLayer

model_id = "HuggingFaceM4/tiny-random-Llama3ForCausalLM"

model = AutoModelForCausalLM.from_pretrained(model_id)

model.train()
model.gradient_checkpointing_enable()

accelerator = Accelerator()
model = accelerator.prepare(model)

check_fn = lambda submodule: isinstance(submodule, LlamaDecoderLayer)

non_reentrant_wrapper = partial(
    checkpoint_wrapper,
    offload_to_cpu=False,
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)

apply_activation_checkpointing(
    model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
)

print(model)

rand_input = torch.LongTensor([[0, 1, 0, 1]]).to(0)

model(rand_input)

#30743 introduced a breaking change for users who use Llama-based models with FSDP and activation checkpointing.

Before #30743, we were able to pass arbitrary kwargs to Llama modules and they were silently ignored. When doing FSDP + activation checkpointing, the target gradient-checkpointing classes are wrapped in a new class, and additional kwargs are passed along to that class's forward pass.

The script above used to work for transformers <= 4.40.0 and no longer works because of #30743; re-introducing kwargs in the forward method signature fixes the bug.
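
To make the failure mode concrete, here is a minimal, self-contained sketch. The `Wrapper`, `StrictLayer`, and `PermissiveLayer` classes below are illustrative stand-ins for the FSDP checkpoint wrapper and the decoder layer, not the actual torch or transformers code:

import torch
import torch.nn as nn

# Hypothetical layer with the post-#30743 signature: no **kwargs catch-all.
class StrictLayer(nn.Module):
    def forward(self, hidden_states, attention_mask=None):
        return hidden_states

# Hypothetical layer with the pre-4.40 signature: extra kwargs are silently ignored.
class PermissiveLayer(nn.Module):
    def forward(self, hidden_states, attention_mask=None, **kwargs):
        return hidden_states

# Stand-in for the checkpoint/offload wrapper: it forwards extra kwargs to the wrapped layer.
class Wrapper(nn.Module):
    def __init__(self, layer):
        super().__init__()
        self.layer = layer

    def forward(self, hidden_states, **kwargs):
        return self.layer(hidden_states, **kwargs)

x = torch.randn(1, 4, 8)
Wrapper(PermissiveLayer())(x, offload_flag=True)  # works: the extra kwarg is ignored
try:
    Wrapper(StrictLayer())(x, offload_flag=True)  # raises: unexpected keyword argument
except TypeError as e:
    print("fails without **kwargs:", e)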

cc @amyeroberts


@amyeroberts (Collaborator) left a comment


Thanks for fixing and apologies for breaking this!

Some questions before we can merge

  • Would it make sense to add a test to make sure we don't accidentally break this again?
  • Having **kwargs in the forward method isn't standard amongst transformers models. Is there something special about these models that needs this for FSDP? If not, should we be adding it to other models?
  • Is there an alternative to using this injection? Relying on kwargs being passed isn't ideal

@younesbelkada (Contributor, Author) commented

Thanks!

Would it make sense to add a test to make sure we don't accidentally break this again?

Yes, I'll add a test in this PR to test this behavior and catch bugs in the future!
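
As a rough illustration of what such a test could check, here is a minimal signature-based sketch (the actual test added in this PR may instead exercise the full FSDP + activation-checkpointing path):

import inspect

from transformers.models.llama.modeling_llama import LlamaDecoderLayer

def test_llama_decoder_layer_forward_accepts_extra_kwargs():
    # FSDP's activation-checkpointing wrapper may forward extra kwargs to the
    # wrapped layer, so the forward signature needs a **kwargs catch-all.
    params = inspect.signature(LlamaDecoderLayer.forward).parameters
    assert any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())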

Having **kwargs in the forward method isn't standard amongst transformers models. Is there something special about these models which need this for FSDP? If not, should we be adding to other models?

Yes, agreed. I think we should add it to all the most-used models. FSDP is useful for large models, so I would say we should add it for LLMs (llama, gemma, mistral, mixtral, gpt-neo, etc.) to make things consistent. Happy to do that within this PR!

Is there an alternative to using this injection? Relying on kwargs being passed isn't ideal

I am not sure; this seems to be something internal to FSDP + CPU offloading, and I don't think we can find a workaround to this :/ For me, since it used to work before, it should keep working in future transformers versions to ensure BC. What do you think?

@amyeroberts (Collaborator) commented

Yes, I'll add a test in this PR to test this behavior and catch bugs in the future!
Yes, agreed. I think we should add it to all the most-used models. FSDP is useful for large models, so I would say we should add it for LLMs (llama, gemma, mistral, mixtral, gpt-neo, etc.) to make things consistent. Happy to do that within this PR!

Awesome - thank you!

I am not sure; this seems to be something internal to FSDP + CPU offloading, and I don't think we can find a workaround to this :/ For me, since it used to work before, it should keep working in future transformers versions to ensure BC. What do you think?

Makes sense - let's leave it as-is :)

@amyeroberts (Collaborator) commented Jun 26, 2024

@younesbelkada I'm really sorry I missed the re-request for review. I don't have permissions to make changes, so I copied the branch here: #31638 and synced it with main. I couldn't get pushing to work locally, but I could make the changes through the web editor.

@amyeroberts (Collaborator) left a comment


Thanks for fixing @younesbelkada, and apologies for the delay in reviewing.

I was able to make the necessary updates to resolve conflicts with main through the online editor. As this was just merging a new input argument, it didn't affect the structure of the PR. I did remove the testing_utils scripts (which I would have asked you to remove in a review :) ).

@amyeroberts merged commit 3f93fd0 into main on Jun 26, 2024
23 checks passed
@amyeroberts deleted the fix-llama-fsdp branch on June 26, 2024 at 13:50
Development

Successfully merging this pull request may close these issues.

Llama Attention Call should not pass **kwargs