
rfc: graph: propose to support Grouped-query Attention #2018

Open · wants to merge 5 commits into base: rfcs

Conversation

gyhintel (Contributor)

Description

This RFC proposes supporting Grouped-query Attention (GQA) in the oneDNN Graph API.
Link to the rendered document.

@gyhintel added the RFC (A design document) label on Jul 31, 2024
dzarukin (Contributor) commented on Jul 31, 2024

@gyhintel, thanks for the RFC. Some questions:

  1. If the user specifies a graph with different K x Q dimensions (H dim versus G dim, as in the example), do you see any issues with extending the pattern matcher / underlying implementation to perform an extra check that H is divisible by G and, if so, performing the logic described in Option 2 of the document?
  2. If (1) can be done, would it mean that the user must modify their graph, or dnnl::graph, to match our pattern for GQA?
  3. If (2) forces the user to make modifications, would adding the API from Option 1 of the document help expand our pattern matcher so that no extra actions are needed on the user's side?

gyhintel (Contributor, Author) commented on Aug 1, 2024

> 1. If the user specifies a graph with different K x Q dimensions (H dim versus G dim, as in the example), do you see any issues with extending the pattern matcher / underlying implementation to perform an extra check that H is divisible by G and, if so, performing the logic described in Option 2 of the document?

  1. If there is no Reshape above the typical SDPA, Query (in shape (N, H, S, D)) and Key/Value (in shape (N, G, S, D)) have different head-number dimensions and cannot perform the dot-product directly. From the doc: "For example src can be broadcasted to wei, if the corresponding dimension in src is 1 (and vice versa)."
  2. We can extend the MatMul broadcasting rules to support group broadcast; this is Option 3. In this situation, we need to perform an extra check that H is divisible by G, and there should be no issues (see the sketch below).
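
To make the mismatch concrete, here is a minimal PyTorch sketch (all sizes are illustrative) of why plain MatMul broadcasting fails for these shapes and what the group-broadcast check and expansion amount to:

```python
import torch

# Illustrative sizes: N = batch, H = query heads, G = KV heads, S = sequence, D = head size.
N, H, G, S, D = 2, 8, 2, 16, 64
query = torch.randn(N, H, S, D)
key = torch.randn(N, G, S, D)

# Plain broadcasting cannot align H = 8 with G = 2 (neither dimension is 1),
# so `query @ key.transpose(-2, -1)` raises a shape error.

# The group-broadcast rule first checks divisibility ...
assert H % G == 0, "query heads must be an integer multiple of KV heads"
# ... and is then equivalent to repeating each KV head H // G times:
key_grouped = key.repeat_interleave(H // G, dim=1)  # (N, H, S, D)
scores = query @ key_grouped.transpose(-2, -1)      # (N, H, S, S)
```

This assumes the contiguous head grouping used by GQA (query heads 0 .. H/G - 1 share KV head 0, and so on).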
> 2. If (1) can be done, would it mean that the user must modify their graph, or dnnl::graph, to match our pattern for GQA?

Yes, it means the pattern cannot be used to optimize a framework graph directly. Users will have to map their GQA implementation graph onto our pattern. This is the second con of Option 2.

> 3. If (2) forces the user to make modifications, would adding the API from Option 1 of the document help expand our pattern matcher so that no extra actions are needed on the user's side?

In the current PyTorch implementation (sketched below), there are no extra actions on their side. But if the implementation in the community changes, we would still need to handle the new implementation. This is the second con of Option 1.
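
For context, the GQA subgraph that Huggingface-style models currently build in PyTorch is roughly the following expand + reshape helper (paraphrased from the transformers `repeat_kv` utility; exact code varies by version):

```python
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (N, G, S, D) -> (N, G * n_rep, S, D): each KV head is shared by n_rep query heads.
    n, g, s, d = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(n, g, n_rep, s, d)
    return hidden_states.reshape(n, g * n_rep, s, d)
```

After this helper, Key/Value have the same head dimension as Query and a plain SDPA follows, which is the subgraph a pattern matcher would currently see.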

1. The pattern is less intuitive from the GQA definition.
2. The pattern cannot be used to optimize a framework graph directly. Frameworks
will have to implement GQA fusion by themselves and leverage this option to
optimize the fused GQA.
petercad (Contributor) commented on Aug 5, 2024

If this turns out to be a serious con, it would be reasonable to add a pass to match the Option 1 subgraph and convert it to the Option 2 subgraph, right?

gyhintel (Contributor, Author)

  1. If it is a serious con, we need to implement Option 1 by adding new ops and new patterns. Matching the Option 1 subgraph and converting it to the Option 2 subgraph is one possible backend implementation; we can also implement it in other ways in the backend.
  2. If the pass can be done on the framework side, we only need to implement Option 2.

Contributor

We will have to support and match the subgraph in Option 1 once the request pops up. With that, oneDNN will support and maintain several different patterns for the same GQA functionality. Maybe that's not an issue: even if we choose Option 1 as the initial step now, the pattern may still change in the future, as mentioned in the cons of Option 1.

(see broadcasting in
[ONNX](https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md) and
[NumPy](https://numpy.org/doc/stable/user/basics.broadcasting.html#general-broadcasting-rules)),
but actually it's added to the MatMul operation of cuDNN in order to support
gyhintel (Contributor, Author)

Added, thanks.

that the new broadcasting rule is only supported by the fused attention.
2. Same as Option 2: the pattern still cannot be used to optimize a framework
graph directly. Frameworks will have to implement GQA fusion by themselves
and leverage this option to optimize the fused GQA.
Contributor

Another con here may be that we rely on oneDNN matmul primitive kernels for the reference implementation and testing in benchdnn, and those kernels do not support the new broadcasting rule. Extending the broadcast semantics on the graph side will also require additional effort for the reference implementation and testing (see the sketch below).
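
If it helps, a reference for the extended rule can fall back on standard broadcasting by materializing the group dimension first. A minimal NumPy sketch (the function name and shapes are illustrative, not an existing benchdnn API):

```python
import numpy as np

def group_broadcast_matmul_ref(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    # Reference for the proposed rule: a is (N, H, M, K), b is (N, G, K, L)
    # with H divisible by G; each of b's G heads serves H // G of a's heads.
    h, g = a.shape[1], b.shape[1]
    assert h % g == 0, "H must be an integer multiple of G"
    b = np.repeat(b, h // g, axis=1)  # (N, H, K, L): plain broadcasting from here on
    return a @ b

a = np.random.rand(2, 8, 16, 64)
b = np.random.rand(2, 2, 64, 16)
assert group_broadcast_matmul_ref(a, b).shape == (2, 8, 16, 16)
```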

gyhintel (Contributor, Author)

Added, thanks.


## GQA in PyTorch

Unlike SDPA, PyTorch does not support GQA as a fused operation. In Huggingface
Contributor

FYI - the PyTorch PR just got merged this week: pytorch/pytorch#132689

dzarukin (Contributor)
Any clarity at this point on whether users are fine with implementing Option 2 on their side, or whether Option 1 must be implemented instead?

@gyhintel requested a review from sanjivmshah and removed the review request for sanjivmshah on August 23, 2024 08:08
gyhintel (Contributor, Author)

@chunyuan-w, @sanchitintel, could you help take a look at this RFC? Thanks!


| Matrix A | Matrix B | Matrix C = A x B |
| -- | -- | -- |
| B1 x 1 x B3 x M x K | B1 x B2 x 1 x M x K | B1 x B2 x B3 x M x N |
Contributor

Suggested change:
- | B1 x 1 x B3 x M x K | B1 x B2 x 1 x M x K | B1 x B2 x B3 x M x N |
+ | B1 x 1 x B3 x M x K | B1 x B2 x 1 x K x N | B1 x B2 x B3 x M x N |

gyhintel (Contributor, Author)

Fixed, thanks.
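
For reference, the corrected row follows the standard NumPy/ONNX batch-broadcasting rules cited in the document and can be checked directly:

```python
import numpy as np

B1, B2, B3, M, K, N = 2, 3, 4, 5, 6, 7
a = np.ones((B1, 1, B3, M, K))   # Matrix A: B1 x 1  x B3 x M x K
b = np.ones((B1, B2, 1, K, N))   # Matrix B: B1 x B2 x 1  x K x N
c = a @ b                        # batch dims broadcast to B1 x B2 x B3
assert c.shape == (B1, B2, B3, M, N)
```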

@gyhintel requested a review from a team as a code owner on September 3, 2024 04:42