
Refactor flash attention implementation in transformers #31446

Merged — 62 commits merged into main from backend-compatible on Jul 11, 2024

Conversation

@ArthurZucker (Collaborator) commented Jun 17, 2024

What does this PR do?

EDIT: just refactor for now

Enables us to run transformers models with ragged tensors:

[image]

One of the goals is also to make it easy for people to re-define the ExtraKwargs TypedDict when building on top of transformers.
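For context, a minimal sketch of that idea, assuming an ExtraKwargs TypedDict unpacked into forward with typing_extensions.Unpack; the field names and signature below are illustrative assumptions, not the merged API:

```python
# Sketch (not the merged API): an overridable TypedDict of extra attention
# kwargs, threaded through model forwards via typing_extensions.Unpack.
import torch
from typing_extensions import TypedDict, Unpack


class ExtraKwargs(TypedDict, total=False):
    # Hypothetical ragged/varlen metadata a flash-attention kernel could consume.
    cu_seqlens_q: torch.LongTensor
    cu_seqlens_k: torch.LongTensor
    max_seqlen_q: int
    max_seqlen_k: int


def attention_forward(hidden_states: torch.Tensor, **kwargs: Unpack[ExtraKwargs]) -> torch.Tensor:
    # A downstream library could re-define ExtraKwargs and pass its own fields
    # straight through to its attention implementation.
    ...
```

A library building on top of transformers would then only need to swap in its own TypedDict to have extra fields flow through the forward signatures.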

@ArthurZucker (Collaborator, Author) commented:

cc @fxmarty, @LysandreJik and @OlivierDehaene

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker (Collaborator, Author) left a comment:

Thanks so much @fxmarty for seeing this through to the end!

@fxmarty (Contributor) commented Jul 10, 2024

Compared to main, no additional flash attention tests fail here (on H100).

================================================================== short test summary info ==================================================================
FAILED tests/models/bark/test_modeling_bark.py::BarkSemanticModelTest::test_flash_attn_2_from_config - ValueError: Unrecognized configuration class <class 'transformers.models.bark.configuration_bark.BarkSemanticConfig'> for this kind of AutoModel: AutoMo...
FAILED tests/models/bark/test_modeling_bark.py::BarkCoarseModelTest::test_flash_attn_2_from_config - ValueError: Unrecognized configuration class <class 'transformers.models.bark.configuration_bark.BarkCoarseConfig'> for this kind of AutoModel: AutoMode...
FAILED tests/models/dpr/test_modeling_dpr.py::DPRModelTest::test_sdpa_can_dispatch_on_flash - RuntimeError: No available kernel. Aborting execution.
FAILED tests/models/gemma/test_modeling_gemma.py::GemmaIntegrationTest::test_model_2b_flash_attn - OSError: You are trying to access a gated repo.
FAILED tests/models/gemma2/test_modeling_gemma2.py::Gemma2ModelTest::test_flash_attn_2_equivalence - AssertionError: assert False
FAILED tests/models/gemma2/test_modeling_gemma2.py::Gemma2ModelTest::test_sdpa_can_dispatch_on_flash - RuntimeError: No available kernel. Aborting execution.
FAILED tests/models/gpt_neox/test_modeling_gpt_neox.py::GPTNeoXModelTest::test_flash_attn_2_generate_padding_right - AssertionError: False is not true
FAILED tests/models/idefics2/test_modeling_idefics2.py::Idefics2ModelTest::test_flash_attn_2_inference_equivalence_right_padding - ValueError: You are attempting to perform batched generation with padding_side='right' this may lead to unexpected behaviour for Flash Attention version...
FAILED tests/models/idefics2/test_modeling_idefics2.py::Idefics2ForConditionalGenerationModelTest::test_flash_attn_2_inference_equivalence_right_padding - ValueError: You are attempting to perform batched generation with padding_side='right' this may lead to unexpected behaviour for Flash Attention version...
FAILED tests/models/jamba/test_modeling_jamba.py::JambaModelTest::test_flash_attn_2_generate_padding_right - AssertionError: ValueError not raised
FAILED tests/models/jamba/test_modeling_jamba.py::JambaModelTest::test_sdpa_can_dispatch_on_flash - RuntimeError: No available kernel. Aborting execution.
FAILED tests/models/m2m_100/test_modeling_m2m_100.py::M2M100ModelTest::test_flash_attn_2_from_config - ValueError: Unrecognized configuration class <class 'transformers.models.m2m_100.configuration_m2m_100.M2M100Config'> for this kind of AutoModel: AutoMo...
FAILED tests/models/m2m_100/test_modeling_m2m_100.py::M2M100ModelIntegrationTests::test_flash_attn_2_seq_to_seq_generation - RuntimeError: FlashAttention only support fp16 and bf16 data type
FAILED tests/models/mixtral/test_modeling_mixtral.py::MixtralModelTest::test_flash_attn_2_generate_padding_right - AssertionError: ValueError not raised
FAILED tests/models/olmo/test_modeling_olmo.py::OlmoModelTest::test_flash_attn_2_generate_padding_right - AssertionError: False is not true
FAILED tests/models/phi3/test_modeling_phi3.py::Phi3ModelTest::test_flash_attn_2_generate_padding_right - AssertionError: False is not true
FAILED tests/models/qwen2/test_modeling_qwen2.py::Qwen2ModelTest::test_flash_attn_2_generate_padding_right - AssertionError: ValueError not raised
FAILED tests/models/qwen2/test_modeling_qwen2.py::Qwen2ModelTest::test_flash_attn_2_inference_equivalence - AssertionError: assert False
FAILED tests/models/qwen2_moe/test_modeling_qwen2_moe.py::Qwen2MoeModelTest::test_flash_attn_2_generate_padding_right - AssertionError: ValueError not raised
FAILED tests/models/qwen2_moe/test_modeling_qwen2_moe.py::Qwen2MoeModelTest::test_flash_attn_2_inference_equivalence - AssertionError: assert False
FAILED tests/models/stablelm/test_modeling_stablelm.py::StableLmModelTest::test_flash_attn_2_generate_padding_right - AssertionError: False is not true
FAILED tests/models/starcoder2/test_modeling_starcoder2.py::Starcoder2ModelTest::test_flash_attn_2_generate_padding_right - AssertionError: ValueError not raised
FAILED tests/models/unispeech/test_modeling_unispeech.py::UniSpeechRobustModelTest::test_flash_attn_2_inference_equivalence - RuntimeError: mat1 and mat2 must have the same dtype, but got Float and BFloat16
FAILED tests/models/unispeech/test_modeling_unispeech.py::UniSpeechRobustModelTest::test_flash_attn_2_inference_equivalence_right_padding - RuntimeError: mat1 and mat2 must have the same dtype, but got Float and BFloat16
FAILED tests/models/unispeech/test_modeling_unispeech.py::UniSpeechRobustModelTest::test_sdpa_can_dispatch_on_flash - RuntimeError: mat1 and mat2 must have the same dtype, but got Float and Half
FAILED tests/models/whisper/test_modeling_whisper.py::WhisperModelTest::test_flash_attn_2_from_config - IndexError: too many indices for tensor of dimension 2
FAILED tests/models/whisper/test_modeling_whisper.py::WhisperModelTest::test_flash_attn_2_inference_equivalence_right_padding - IndexError: too many indices for tensor of dimension 2
FAILED tests/models/whisper/test_modeling_whisper.py::WhisperStandaloneDecoderModelTest::test_flash_attn_2_from_config - IndexError: too many indices for tensor of dimension 2
FAILED tests/models/whisper/test_modeling_whisper.py::WhisperStandaloneDecoderModelTest::test_flash_attn_2_generate_left_padding - IndexError: too many indices for tensor of dimension 2
FAILED tests/models/whisper/test_modeling_whisper.py::WhisperStandaloneDecoderModelTest::test_flash_attn_2_inference_equivalence - AssertionError: assert False
FAILED tests/models/whisper/test_modeling_whisper.py::WhisperStandaloneDecoderModelTest::test_flash_attn_2_inference_equivalence_right_padding - AssertionError: assert False
================================= 31 failed, 1064 passed, 1789 skipped, 62671 deselected, 166 warnings in 331.97s (0:05:31) =================================

Testing on MI250 for extra safety, then this is good to merge.

edit: all good, can be merged
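For anyone reproducing a run like the one above, a rough sketch using transformers' RUN_SLOW gate and a pytest keyword filter (the model directories chosen here are an assumption; the summary above covers many more):

```python
# Sketch: rerun the slow flash-attention tests for a few model folders.
# RUN_SLOW=1 is how the transformers test suite enables @slow-decorated tests.
import os
import sys

import pytest

os.environ["RUN_SLOW"] = "1"
sys.exit(pytest.main([
    "-k", "flash_attn",      # keep only flash-attention related tests
    "tests/models/llama",
    "tests/models/gemma",
    "tests/models/mistral",
    "tests/models/mixtral",
]))
```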

@fxmarty merged commit e314395 into main on Jul 11, 2024 — 26 checks passed.
@fxmarty deleted the backend-compatible branch on July 11, 2024 at 12:37.
amyeroberts pushed a commit to amyeroberts/transformers that referenced this pull request Jul 19, 2024
…31446)

* dumb commit

* nit

* update

* something like this

* unpack in modeling utils

* safe import

* oups

* update

* nits

* diff convert gemma

* update

* start propagating

* udpate other modeling code as well

* update for sliding window models

* nits

* more init cleanups

* styling

* fixup

* noice

* pass fixup

* typo typing_extension -> typing_extensions

* torch.nn.functionnal -> torch.nn.functional

* add to import structure

* unpack

* simplify a bit more for this first version

* nut

* update

* update

* nit

* ease the import of `Unpack`

* remove useless `use_sliding_window`

* no qua please

* protect import?

* style

* [run-slow]

* [run slow] llama,gemma,mistral,mixtral

* remove extra kwargs

* fix llama

* address review comments

* apply diff_model_converter to modeling_gemma.py

* remove cache_position 1

* remove cache_position 2

* some cleaning

* refactor gemma2 as well

* apply review comments

* rename file to modeling_flash_attention_utils.py

* siglip refactor

* remove dead code

* is the hub down?

* still down?

* fix siglip

* fix gemma2

* fatal: Could not read from remote repository.

* fix typo in softcap implem

* flacky

* Failed: Timeout >120.0s

---------

Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>
MHRDYN7 pushed a commit to MHRDYN7/transformers that referenced this pull request Jul 23, 2024
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request Jul 24, 2024