Shashank/seq id flash attn #738

Merged

Conversation

@ShashankMosaicML (Contributor) commented on Nov 15, 2023:

This PR does three flash-attention-related things (see the code sketch after this list):

  1. Adds support for sequence id masking when using flash attention 2.1.2 or higher.
  2. Gets rid of kv tensor repetition for grouped query attention when using flash attention 2.0.0 or higher.
  3. Adds support for sliding window attention when using flash attention 2.3.0 or higher.
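
For reference, here is a minimal, hypothetical sketch (not the code in this PR) of how the three features map onto the flash-attn API. It assumes flash-attn >= 2.3.0 and a CUDA device; the tensor names and the `sequence_id_to_cu_seqlens` helper are illustrative only.

```python
import torch
from flash_attn import flash_attn_varlen_func  # window_size requires flash-attn >= 2.3.0


def sequence_id_to_cu_seqlens(sequence_id: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: turn per-token sequence ids of one packed row,
    e.g. [0, 0, 0, 1, 1, 2], into cumulative sequence lengths [0, 3, 5, 6]."""
    seqlens = torch.bincount(sequence_id)
    return torch.nn.functional.pad(seqlens.cumsum(0), (1, 0)).to(torch.int32)


seqlen, n_heads, n_kv_heads, head_dim = 6, 8, 2, 64
device, dtype = 'cuda', torch.bfloat16

q = torch.randn(seqlen, n_heads, head_dim, device=device, dtype=dtype)
# (2) With flash-attn >= 2.0.0, k and v keep their smaller number of kv heads for
#     grouped query attention; repeating the kv tensors up to n_heads is not needed.
k = torch.randn(seqlen, n_kv_heads, head_dim, device=device, dtype=dtype)
v = torch.randn(seqlen, n_kv_heads, head_dim, device=device, dtype=dtype)

# (1) Sequence id masking: cu_seqlens marks where each packed sequence starts and
#     ends, so tokens never attend across sequence boundaries.
sequence_id = torch.tensor([0, 0, 0, 1, 1, 2], device=device)
cu_seqlens = sequence_id_to_cu_seqlens(sequence_id)
max_seqlen = int(torch.diff(cu_seqlens).max())

# (3) Sliding window attention (flash-attn >= 2.3.0): each token attends to at most
#     `window` previous tokens in its own sequence; (-1, -1) disables the window.
window = 100
out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
    causal=True,
    window_size=(window, 0),
)  # (seqlen, n_heads, head_dim)
```

This only shows the shape of the API calls involved; the PR's actual wiring is in the llmfoundry files referenced in the review threads below.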

WandB link to the experiments: https://wandb.ai/mosaic-ml/seq_id_FA_final_tests

Loss and throughput curves for a 125M model trained for Chinchilla-optimal steps:

  1. control is the main llmfoundry branch.
  2. treat is the shashank/seq_id_flash_attn branch (this PR) with no config changes.
  3. treat-seq-id-masking is the same branch with sequence id masking turned on.
  4. treat-sliding-window-100 is the same branch with a sliding window of size 100.
[Four screenshots (2023-12-01): loss and throughput curves for the runs listed above.]

@ShashankMosaicML marked this pull request as ready for review on November 16, 2023 at 17:26
@vchiley (Contributor) left a comment:

High level looks OK.

In the PR description, can you include a figure showing the MFU diff with and without masking, and a figure showing the convergence diff with and without masking?

[Review threads (resolved): llmfoundry/models/layers/attention.py, llmfoundry/models/mpt/configuration_mpt.py, llmfoundry/models/mpt/modeling_mpt.py]
@dakinggg (Collaborator) left a comment:
It would be good to train some models to show the equivalence of seq id masking between flash attention and the other attention implementations.
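
For illustration only, here is a rough sketch (not this PR's tests) of the kind of kernel-level equivalence check implied, separate from the full training runs requested: flash attention driven by cu_seqlens derived from sequence_id should match plain PyTorch attention under an explicit block-diagonal causal mask. It assumes flash-attn >= 2.1.2, torch >= 2.0, and a CUDA device; all names are made up for the example.

```python
import torch
from flash_attn import flash_attn_varlen_func

torch.manual_seed(0)
device, dtype = 'cuda', torch.bfloat16
seqlen, n_heads, head_dim = 8, 4, 64
sequence_id = torch.tensor([0, 0, 0, 1, 1, 1, 1, 2], device=device)

q, k, v = (torch.randn(seqlen, n_heads, head_dim, device=device, dtype=dtype) for _ in range(3))

# Flash path: cumulative sequence lengths mark the packed-sequence boundaries.
seqlens = torch.bincount(sequence_id)
cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(0), (1, 0)).to(torch.int32)
out_flash = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=int(seqlens.max()), max_seqlen_k=int(seqlens.max()),
    causal=True,
)

# Reference path: the same math in plain PyTorch with an explicit mask that is
# block-diagonal over sequences and causal within each sequence.
same_seq = sequence_id[:, None] == sequence_id[None, :]
causal = torch.tril(torch.ones(seqlen, seqlen, dtype=torch.bool, device=device))
mask = same_seq & causal  # token i may attend to token j iff same sequence and j <= i
out_ref = torch.nn.functional.scaled_dot_product_attention(
    q.transpose(0, 1).unsqueeze(0),  # -> (1, n_heads, seqlen, head_dim)
    k.transpose(0, 1).unsqueeze(0),
    v.transpose(0, 1).unsqueeze(0),
    attn_mask=mask,
).squeeze(0).transpose(0, 1)  # back to (seqlen, n_heads, head_dim)

assert torch.allclose(out_flash, out_ref, atol=2e-2, rtol=2e-2)
```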

[Review threads (resolved): llmfoundry/models/layers/attention.py, llmfoundry/models/mpt/configuration_mpt.py, llmfoundry/models/mpt/modeling_mpt.py, tests/test_flash_triton_torch.py, tests/test_flash_attn.py, tests/test_model.py]
@dakinggg (Collaborator) left a comment:
LGTM, let's look into that slow test a bit before merging.

[Review threads (resolved): tests/test_flash_attn.py, tests/test_model.py]
@ShashankMosaicML merged commit 84b5d96 into mosaicml:main on Dec 4, 2023
10 checks passed
@ShashankMosaicML deleted the shashank/seq_id_flash_attn branch on December 4, 2023 at 17:37