Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GQA Rotary and Packed QKV with Flash #18906

Merged
merged 58 commits into from
Jan 24, 2024
Merged

Conversation

aciddelgado
Copy link
Contributor

Description

These changes add rotary embedding and packed qkv input to gqa. As of now, the changes are only supported with Flash-Attention (SM >= 80) but should soon be supported with Memory Efficient Attention as well.

Motivation and Context

With the fusion of rotary embedding into this Attention op, we hope to observe some perf gain. The packed QKV should also provide some perf gain in the context of certain models, like Llama2, that would benefit from running ops on the fused QKV matrix, rather than the separate Q, K, and V.

tianleiwu
tianleiwu previously approved these changes Jan 22, 2024
tianleiwu
tianleiwu previously approved these changes Jan 22, 2024
@aciddelgado aciddelgado merged commit cbb29d8 into main Jan 24, 2024
95 of 98 checks passed
@aciddelgado aciddelgado deleted the aciddelgado/gqa_rotary_packed branch January 24, 2024 00:34
YUNQIUGUO pushed a commit that referenced this pull request Jan 30, 2024
### Description
These changes add rotary embedding and packed qkv input to gqa. As of
now, the changes are only supported with Flash-Attention (SM >= 80) but
should soon be supported with Memory Efficient Attention as well.



### Motivation and Context
With the fusion of rotary embedding into this Attention op, we hope to
observe some perf gain. The packed QKV should also provide some perf
gain in the context of certain models, like Llama2, that would benefit
from running ops on the fused QKV matrix, rather than the separate Q, K,
and V.

---------

Co-authored-by: Yufeng Li <liyufeng1987@gmail.com>
kunal-vaishnavi added a commit that referenced this pull request Mar 13, 2024
### Description
This PR updates the replacement of MultiHeadAttention (MHA) with
GroupQueryAttention (GQA). It is related to the changes in [this
PR](#18906).

### Motivation and Context
The updated replacement of MHA with GQA includes the following fusion
changes.
- Apply sliding window within GQA
- Fuse the rotary embeddings within GQA
- Fuse the 3 MatMuls into 1 packed MatMul if possible
- Fuse the 3 Adds into 1 packed Add if possible
YUNQIUGUO pushed a commit that referenced this pull request Mar 21, 2024
### Description
This PR updates the replacement of MultiHeadAttention (MHA) with
GroupQueryAttention (GQA). It is related to the changes in [this
PR](#18906).

### Motivation and Context
The updated replacement of MHA with GQA includes the following fusion
changes.
- Apply sliding window within GQA
- Fuse the rotary embeddings within GQA
- Fuse the 3 MatMuls into 1 packed MatMul if possible
- Fuse the 3 Adds into 1 packed Add if possible
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants