Generalize attention fusion #3403

Merged: 28 commits merged into develop on Sep 19, 2024

Conversation

@shivadbhavsar (Contributor) commented Aug 26, 2024

Complete solution for #2812

Changes (applicable when mlir attention is enabled):

  • No longer use the gemm_softmax_gemm matcher in the prefuse_ops pass
  • Add a matcher for the softmax base ops
  • Look for the dot -> fused_reduce -> dot (-> pointwise) pattern in the fuse_mlir pass
    • For a valid attention fusion, the fused_reduce should end with the softmax base ops
    • Create a single fused module consisting of these ops

Verified that the attention fusion works as before on various transformer models in our nas (bert, gpt, etc.)
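For context, a minimal sketch of how the dot -> fused_reduce -> dot pattern could be expressed with MIGraphX's matcher combinators. This is illustrative only and not the code in this PR; the combinators (`match::name`, `match::arg`, `.bind`) are assumed to behave as in the existing matchers in the repo, and the softmax check on the fused_reduce submodule is only noted in a comment:

```cpp
// Illustrative sketch only -- see the diff in src/targets/gpu/fuse_mlir.cpp
// for the real matcher used by this PR.
namespace match = migraphx::match;

inline auto attention_pattern()
{
    // First GEMM (e.g. Q * K^T).
    auto first_dot = match::name("dot").bind("first_dot");
    // fused_reduce fed by the first GEMM; whether its submodule actually ends
    // with the softmax base ops is verified separately when the match fires.
    auto softmax_reduce =
        match::name("fused_reduce")(match::arg(0)(first_dot)).bind("softmax");
    // Second GEMM consuming the softmax output (optionally followed by pointwise ops).
    return match::name("dot")(match::arg(0)(softmax_reduce)).bind("second_dot");
}
```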

@shivadbhavsar added the enhancement label Aug 26, 2024
@shivadbhavsar self-assigned this Aug 26, 2024

codecov bot commented Aug 26, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 92.04%. Comparing base (fde041e) to head (df7d9f0).
Report is 1 commit behind head on develop.

Additional details and impacted files
@@           Coverage Diff            @@
##           develop    #3403   +/-   ##
========================================
  Coverage    92.04%   92.04%           
========================================
  Files          506      506           
  Lines        20856    20864    +8     
========================================
+ Hits         19196    19204    +8     
  Misses        1660     1660           


@shivadbhavsar linked an issue Aug 27, 2024 that may be closed by this pull request
@shivadbhavsar marked this pull request as ready for review August 28, 2024 16:32
Resolved review threads on:
  • src/module.cpp
  • src/include/migraphx/match/softmax.hpp
  • src/targets/gpu/prefuse_ops.cpp
  • test/gpu/fuse_mlir.cpp
  • test/module_test.cpp
@CharlieL7 (Collaborator) left a comment:

LGTM

@shivadbhavsar (Contributor, Author) commented:

@pfultz2 @causten
I added a commit here to accept graphs with the extra reshape; it should fix the perf regressions related to that issue. This can probably be generalized to accept some number of shape ops, but I think that can be a topic for a separate PR.
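For illustration, tolerating an interposed reshape could look roughly like the following, using `match::skip` to walk past the shape op. This is only a sketch under the assumption that `match::skip` composes as it does elsewhere in the codebase, not the commit's actual code:

```cpp
// Illustrative sketch: allow an optional reshape between the fused_reduce and
// the second dot. Generalizing to more shape ops would mean listing them in skip().
namespace match = migraphx::match;

inline auto attention_pattern_with_reshape()
{
    auto first_dot = match::name("dot").bind("first_dot");
    auto softmax_reduce =
        match::name("fused_reduce")(match::arg(0)(first_dot)).bind("softmax");
    return match::name("dot")(
               match::arg(0)(match::skip(match::name("reshape"))(softmax_reduce)))
        .bind("second_dot");
}
```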

@migraphx-bot (Collaborator) commented:

Test  Batch  Rate new (df7d9f)  Rate old (7c2fdf)  Diff  Compare
torchvision-resnet50 64 3,238.65 3,249.19 -0.32%
torchvision-resnet50_fp16 64 6,973.40 6,993.27 -0.28%
torchvision-densenet121 32 2,379.41 2,434.31 -2.26%
torchvision-densenet121_fp16 32 3,962.26 4,095.02 -3.24% 🔴
torchvision-inceptionv3 32 1,452.37 1,635.79 -11.21% 🔴
torchvision-inceptionv3_fp16 32 2,716.08 2,740.83 -0.90%
cadene-inceptionv4 16 772.19 776.76 -0.59%
cadene-resnext64x4 16 808.92 808.72 0.02%
slim-mobilenet 64 7,449.34 7,455.28 -0.08%
slim-nasnetalarge 64 208.21 208.38 -0.08%
slim-resnet50v2 64 3,444.10 3,435.08 0.26%
bert-mrpc-onnx 8 720.77 1,150.34 -37.34% 🔴
bert-mrpc-tf 1 310.99 314.36 -1.07%
pytorch-examples-wlang-gru 1 344.16 418.46 -17.75% 🔴
pytorch-examples-wlang-lstm 1 303.87 499.68 -39.19% 🔴
torchvision-resnet50_1 1 799.94 772.72 3.52% 🔆
cadene-dpn92_1 1 436.34 397.74 9.70% 🔆
cadene-resnext101_1 1 382.64 383.61 -0.25%
onnx-taau-downsample 1 344.22 344.76 -0.16%
dlrm-criteoterabyte 1 35.08 35.10 -0.06%
dlrm-criteoterabyte_fp16 1 58.11 58.12 -0.02%
agentmodel 1 8,195.44 7,932.67 3.31% 🔆
unet_fp16 2 58.57 57.85 1.23%
resnet50v1_fp16 1 954.63 935.68 2.02%
resnet50v1_int8 1 957.23 949.99 0.76%
bert_base_cased_fp16 64 1,153.35 1,153.06 0.03%
bert_large_uncased_fp16 32 355.69 355.77 -0.02%
bert_large_fp16 1 212.28 210.32 0.93%
distilgpt2_fp16 16 2,158.71 2,161.65 -0.14%
yolov5s 1 540.81 534.27 1.22%
tinyllama 1 43.39 43.40 -0.01%
vicuna-fastchat 1 179.92 170.43 5.57% 🔆
whisper-tiny-encoder 1 417.97 418.17 -0.05%
whisper-tiny-decoder 1 433.56 426.09 1.75%

This build is not recommended to merge 🔴

@migraphx-bot (Collaborator) commented:


     ✅ bert-mrpc-onnx: PASSED: MIGraphX meets tolerance

     ✅ bert-mrpc-tf: PASSED: MIGraphX meets tolerance

     ✅ pytorch-examples-wlang-gru: PASSED: MIGraphX meets tolerance

     ✅ pytorch-examples-wlang-lstm: PASSED: MIGraphX meets tolerance

     ✅ torchvision-resnet50_1: PASSED: MIGraphX meets tolerance

     ✅ cadene-dpn92_1: PASSED: MIGraphX meets tolerance

     ✅ cadene-resnext101_1: PASSED: MIGraphX meets tolerance

     ✅ dlrm-criteoterabyte: PASSED: MIGraphX meets tolerance

     ✅ agentmodel: PASSED: MIGraphX meets tolerance

     ✅ unet: PASSED: MIGraphX meets tolerance

     ✅ resnet50v1: PASSED: MIGraphX meets tolerance

     ✅ bert_base_cased_fp16: PASSED: MIGraphX meets tolerance

     🔴 bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output

     ✅ bert_large: PASSED: MIGraphX meets tolerance

     ✅ yolov5s: PASSED: MIGraphX meets tolerance

     ✅ tinyllama: PASSED: MIGraphX meets tolerance

     ✅ vicuna-fastchat: PASSED: MIGraphX meets tolerance

     ✅ whisper-tiny-encoder: PASSED: MIGraphX meets tolerance

     ✅ whisper-tiny-decoder: PASSED: MIGraphX meets tolerance

     ✅ distilgpt2_fp16: PASSED: MIGraphX meets tolerance

@causten merged commit 2769c73 into develop Sep 19, 2024
47 of 48 checks passed
@causten deleted the generalized_attn branch September 19, 2024 02:55
Labels: enhancement
Linked issue (may be closed by this pull request): Fuse where into MLIR attention
7 participants