
[Improve] Use PyTorch official scaled_dot_product_attention to accelerate MultiheadAttention. #1434

Merged: 3 commits into open-mmlab:pretrain, Mar 29, 2023

Conversation

@mzr1996 (Member) commented Mar 23, 2023

Motivation

PyTorch 2.0 provides an official scaled_dot_product_attention implementation, which automatically selects an
optimized attention backend such as FlashAttention or Memory-Efficient Attention.

Modification

If the PyTorch version is 2.0 or later, use the official scaled_dot_product_attention in
MultiheadAttention.
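
Conceptually, the change dispatches to the fused op when it exists and keeps a plain PyTorch path otherwise. A minimal sketch of the idea (illustrative only, not the exact diff in this PR; the function and variable names below are made up for the example):

```python
import torch
import torch.nn.functional as F

# torch.nn.functional.scaled_dot_product_attention is available since PyTorch 2.0.
HAS_FUSED_SDPA = hasattr(F, 'scaled_dot_product_attention')


def attention(q, k, v, attn_drop=0., training=False):
    """q, k, v: (batch, num_heads, seq_len, head_dim)."""
    if HAS_FUSED_SDPA:
        # The fused op picks FlashAttention, memory-efficient attention, or the
        # math fallback automatically based on dtype, device and shapes.
        return F.scaled_dot_product_attention(
            q, k, v, dropout_p=attn_drop if training else 0.)
    # Plain implementation for PyTorch < 2.0.
    scale = q.size(-1) ** -0.5
    attn = (q @ k.transpose(-2, -1)) * scale
    attn = attn.softmax(dim=-1)
    attn = F.dropout(attn, p=attn_drop, training=training)
    return attn @ v
```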

Use cases

Here is the speed comparison on the test set (NVIDIA A100, FP16):

Original:

03/23 12:19:48 - mmengine - INFO - Epoch(test) [10/98]    eta: 0:02:17  time: 1.5650  data_time: 0.5420  memory: 4155  
03/23 12:19:52 - mmengine - INFO - Epoch(test) [20/98]    eta: 0:01:14  time: 0.3493  data_time: 0.0018  memory: 4155  
03/23 12:19:55 - mmengine - INFO - Epoch(test) [30/98]    eta: 0:00:51  time: 0.3411  data_time: 0.0014  memory: 4155  
03/23 12:19:58 - mmengine - INFO - Epoch(test) [40/98]    eta: 0:00:36  time: 0.2472  data_time: 0.0016  memory: 4155  
03/23 12:20:00 - mmengine - INFO - Epoch(test) [50/98]    eta: 0:00:26  time: 0.2463  data_time: 0.0014  memory: 4155  
03/23 12:20:03 - mmengine - INFO - Epoch(test) [60/98]    eta: 0:00:18  time: 0.2463  data_time: 0.0011  memory: 4155  
03/23 12:20:05 - mmengine - INFO - Epoch(test) [70/98]    eta: 0:00:12  time: 0.2462  data_time: 0.0010  memory: 4155  
03/23 12:20:08 - mmengine - INFO - Epoch(test) [80/98]    eta: 0:00:07  time: 0.2463  data_time: 0.0010  memory: 4155  
03/23 12:20:10 - mmengine - INFO - Epoch(test) [90/98]    eta: 0:00:03  time: 0.2471  data_time: 0.0010  memory: 4155  
03/23 12:20:13 - mmengine - INFO - Epoch(test) [98/98]  accuracy/top1: 82.3800  accuracy/top5: 96.1480

New (with FlashAttention):

03/23 12:18:33 - mmengine - INFO - Epoch(test) [10/98]    eta: 0:02:07  time: 1.4518  data_time: 0.6494  memory: 3538  
03/23 12:18:36 - mmengine - INFO - Epoch(test) [20/98]    eta: 0:01:07  time: 0.2795  data_time: 0.0016  memory: 3538  
03/23 12:18:38 - mmengine - INFO - Epoch(test) [30/98]    eta: 0:00:43  time: 0.1712  data_time: 0.0014  memory: 3538  
03/23 12:18:40 - mmengine - INFO - Epoch(test) [40/98]    eta: 0:00:30  time: 0.1957  data_time: 0.0012  memory: 3538  
03/23 12:18:41 - mmengine - INFO - Epoch(test) [50/98]    eta: 0:00:21  time: 0.1711  data_time: 0.0011  memory: 3538  
03/23 12:18:43 - mmengine - INFO - Epoch(test) [60/98]    eta: 0:00:15  time: 0.1705  data_time: 0.0010  memory: 3538  
03/23 12:18:45 - mmengine - INFO - Epoch(test) [70/98]    eta: 0:00:10  time: 0.1708  data_time: 0.0010  memory: 3538  
03/23 12:18:46 - mmengine - INFO - Epoch(test) [80/98]    eta: 0:00:06  time: 0.1706  data_time: 0.0009  memory: 3538  
03/23 12:18:48 - mmengine - INFO - Epoch(test) [90/98]    eta: 0:00:02  time: 0.1747  data_time: 0.0009  memory: 3538  
03/23 12:18:50 - mmengine - INFO - Epoch(test) [98/98]  accuracy/top1: 82.3820  accuracy/top5: 96.1500
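
To confirm that the FlashAttention backend is actually selected on this hardware, one could restrict the fused op to that single backend; if the inputs are not eligible, the call raises instead of silently falling back. This is only a verification snippet under PyTorch 2.0 and is not part of the PR:

```python
import torch
import torch.nn.functional as F

# ViT-B-like shapes: 12 heads, 197 tokens, head dim 64, FP16 on GPU.
q = k = v = torch.randn(8, 12, 197, 64, dtype=torch.float16, device='cuda')

# PyTorch 2.0 context manager that disables all backends except FlashAttention.
with torch.backends.cuda.sdp_kernel(enable_flash=True,
                                    enable_math=False,
                                    enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v)
```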

Checklist

Before PR:

  • Pre-commit or other linting tools are used to fix the potential lint issues.
  • Bug fixes are fully covered by unit tests; the case that caused the bug should be added to the unit tests.
  • The modification is covered by complete unit tests. If not, please add more unit tests to ensure correctness.
  • The documentation has been modified accordingly, like docstring or example tutorials.

After PR:

  • If the modification has potential influence on downstream or other related projects, this PR should be tested with those projects, like MMDet or MMSeg.
  • CLA has been signed and all committers have signed the CLA in this PR.

Review comment on tools/test.py (outdated, resolved)
codecov bot commented Mar 27, 2023

Codecov Report

❗ No coverage uploaded for pull request base (pretrain@c4ccae4).
Patch has no changes to coverable lines.

❗ Current head 0500319 differs from pull request most recent head e845df5. Consider uploading reports for the commit e845df5 to get more accurate results

Additional details and impacted files
@@             Coverage Diff             @@
##             pretrain    #1434   +/-   ##
===========================================
  Coverage            ?   85.17%           
===========================================
  Files               ?      228           
  Lines               ?    17095           
  Branches            ?     2680           
===========================================
  Hits                ?    14560           
  Misses              ?     2044           
  Partials            ?      491           
Flag        Coverage Δ
unittests   85.17% <0.00%> (?)

Flags with carried forward coverage won't be shown.


☔ View full report in Codecov by Sentry.

@fangyixiao18 fangyixiao18 merged commit b017670 into open-mmlab:pretrain Mar 29, 2023