-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[CUDA] Attention kernel provider option #21344
Merged
Merged
Changes from all commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
8802e40
attention cuda kernel options
tianleiwu 065a87f
fix test
tianleiwu a001b73
update ut dependency
tianleiwu 968f931
fix windows build
tianleiwu 21571d6
add unit test cases
tianleiwu 09e9c4d
Merge branch 'main' into tlwu/attention_kernel_cuda_option
tianleiwu a1c1eec
add back onnxruntime_common dependency
tianleiwu 8a758a0
move option object to cuda provider
tianleiwu 81ee535
reserve a flag for cudnn flash attention; print debug info
tianleiwu 04d460c
format
tianleiwu 6ad0764
exclude hipify
tianleiwu 0e843e8
refactoring
tianleiwu 2862eb2
fix typo; warn that MATH=0 is ignored
tianleiwu File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
166 changes: 166 additions & 0 deletions
166
onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#include "contrib_ops/cuda/bert/attention_kernel_options.h" | ||
#include <iomanip> | ||
#include <iostream> | ||
#include <sstream> | ||
#include "contrib_ops/cpu/bert/attention_common.h" | ||
#include "core/providers/shared_library/provider_api.h" | ||
#include "core/platform/env_var_utils.h" | ||
#include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h" | ||
|
||
using namespace onnxruntime::contrib::attention; | ||
Check warning on line 13 in onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc GitHub Actions / Lint C++
|
||
|
||
namespace onnxruntime { | ||
void AttentionKernelOptions::Initialize(int value, bool use_build_flag) { | ||
if (value > 0) { | ||
use_flash_attention_ = (value & static_cast<int>(AttentionBackend::FLASH_ATTENTION)) > 0; | ||
use_efficient_attention_ = (value & static_cast<int>(AttentionBackend::EFFICIENT_ATTENTION)) > 0; | ||
use_trt_fused_attention_ = (value & static_cast<int>(AttentionBackend::TRT_FUSED_ATTENTION)) > 0; | ||
use_cudnn_flash_attention_ = (value & static_cast<int>(AttentionBackend::CUDNN_FLASH_ATTENTION)) > 0; | ||
use_unfused_ = (value & static_cast<int>(AttentionBackend::MATH)) > 0; | ||
use_trt_flash_attention_ = (value & static_cast<int>(AttentionBackend::TRT_FLASH_ATTENTION)) > 0; | ||
use_trt_cross_attention_ = (value & static_cast<int>(AttentionBackend::TRT_CROSS_ATTENTION)) > 0; | ||
use_trt_causal_attention_ = (value & static_cast<int>(AttentionBackend::TRT_CAUSAL_ATTENTION)) > 0; | ||
} else { | ||
use_flash_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableFlashAttention, false); | ||
use_efficient_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableMemoryEfficientAttention, false); | ||
use_trt_fused_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableFusedSelfAttention, false); | ||
use_cudnn_flash_attention_ = ParseEnvironmentVariableWithDefault<bool>(kEnableCudnnFlashAttention, false); | ||
use_unfused_ = true; | ||
use_trt_flash_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableTrtFlashAttention, false); | ||
use_trt_cross_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableFusedCrossAttention, false); | ||
use_trt_causal_attention_ = ParseEnvironmentVariableWithDefault<bool>(kEnableFusedCausalAttention, false); | ||
} | ||
|
||
enable_kernel_debug_info_ = ParseEnvironmentVariableWithDefault<bool>(kEnableAttentionKernelDebugInfo, false); | ||
|
||
// When value is positive, we use 0 as default minimum sequence lengths to align with common usage in testing. | ||
min_seq_len_for_flash_attention_packed_qkv_ = ParseEnvironmentVariableWithDefault<int>( | ||
kMinSeqLenForFlashAttentionPackedQKV, | ||
value > 0 ? 0 : kDefaultMinSeqLenForFlashAttentionPackedQKV); | ||
|
||
min_seq_len_for_efficient_attention_fp32_ = ParseEnvironmentVariableWithDefault<int>( | ||
kMinSeqLenForEfficientAttentionFp32, | ||
value > 0 ? 0 : kDefaultMinSeqLenForEfficientAttentionFp32); | ||
|
||
if (use_build_flag) { | ||
// Some kernels can be disabled at build time. If they are disabled, we should not use them. | ||
#ifndef USE_FLASH_ATTENTION | ||
use_flash_attention_ = false; | ||
#endif | ||
|
||
#ifndef USE_MEMORY_EFFICIENT_ATTENTION | ||
use_efficient_attention_ = false; | ||
#endif | ||
} | ||
} | ||
|
||
void AttentionKernelOptions::InitializeOnce( | ||
int sdpa_kernel, bool use_build_flag) { | ||
std::call_once(this->initialize_once_flag_, [&]() { | ||
this->Initialize(sdpa_kernel, use_build_flag); | ||
if (this->enable_kernel_debug_info_) { | ||
this->Print(); | ||
} | ||
}); | ||
} | ||
|
||
void AttentionKernelOptions::Print() const { | ||
std::stringstream sstream; | ||
sstream << "AttentionKernelOptions:"; | ||
sstream << " FLASH_ATTENTION=" << int(use_flash_attention_); | ||
sstream << " EFFICIENT_ATTENTION=" << int(use_efficient_attention_); | ||
sstream << " TRT_FUSED_ATTENTION=" << int(use_trt_fused_attention_); | ||
sstream << " CUDNN_FLASH_ATTENTION=" << int(use_cudnn_flash_attention_); | ||
sstream << " TRT_FLASH_ATTENTION=" << int(use_trt_flash_attention_); | ||
sstream << " TRT_CROSS_ATTENTION=" << int(use_trt_cross_attention_); | ||
sstream << " TRT_CAUSAL_ATTENTION=" << int(use_trt_causal_attention_); | ||
sstream << " MATH=" << int(use_unfused_); | ||
|
||
if (!use_unfused_) { | ||
sstream << std::endl | ||
<< "Warning: Unfused kernel cannot be disabled right now. MATH=0 is ignored."; | ||
} | ||
|
||
// Output text in Cyan color to make it easier to spot | ||
std::cout << "\x1B[36m" << sstream.str() << "\x1B[0m" << std::endl; | ||
} | ||
|
||
// Classify the kernel used in TRT fused runner. | ||
void AttentionKernelDebugInfo::SetTrtFusedKernel(bool causal, bool enable_trt_flash_attention, int sequence_length) { | ||
if (causal) { | ||
use_trt_causal_attention = true; | ||
} else if (enable_trt_flash_attention && sequence_length >= contrib::cuda::kMinSequenceLengthFlashAttention) { | ||
use_trt_flash_attention = true; | ||
} else { | ||
use_trt_fused_attention = true; | ||
} | ||
} | ||
|
||
void AttentionKernelDebugInfo::Print(const char* operator_name, | ||
const std::string& node_name, | ||
bool is_float16, | ||
bool is_bfloat16) const { | ||
std::stringstream sstream; | ||
sstream << "Operator=" << operator_name; | ||
|
||
if (node_name.length() > 0) { | ||
sstream << " Node=" << node_name; | ||
} | ||
|
||
if (is_bfloat16) { | ||
sstream << " DataType=bf16"; | ||
} else if (is_float16) { | ||
sstream << " DataType=fp16"; | ||
} else { | ||
sstream << " DataType=fp32"; | ||
} | ||
|
||
if (use_flash_attention.has_value() && use_flash_attention.value()) { | ||
sstream << " FLASH_ATTENTION=" << int(use_flash_attention.value()); | ||
} | ||
|
||
if (use_efficient_attention.has_value() && use_efficient_attention.value()) { | ||
sstream << " EFFICIENT_ATTENTION=" << int(use_efficient_attention.value()); | ||
} | ||
|
||
if (use_trt_fused_attention.has_value() && use_trt_fused_attention.value()) { | ||
sstream << " TRT_FUSED_ATTENTION=" << int(use_trt_fused_attention.value()); | ||
} | ||
|
||
if (use_cudnn_flash_attention.has_value() && use_cudnn_flash_attention.value()) { | ||
sstream << " CUDNN_FLASH_ATTENTION=" << int(use_cudnn_flash_attention.value()); | ||
} | ||
|
||
if (use_trt_flash_attention.has_value() && use_trt_flash_attention.value()) { | ||
sstream << " TRT_FLASH_ATTENTION=" << int(use_trt_flash_attention.value()); | ||
} | ||
|
||
if (use_trt_cross_attention.has_value() && use_trt_cross_attention.value()) { | ||
sstream << " TRT_CROSS_ATTENTION=" << int(use_trt_cross_attention.value()); | ||
} | ||
|
||
if (use_trt_causal_attention.has_value() && use_trt_causal_attention.value()) { | ||
sstream << " TRT_CAUSAL_ATTENTION=" << int(use_trt_causal_attention.value()); | ||
} | ||
|
||
bool use_fused = (use_flash_attention.has_value() && use_flash_attention.value()) || | ||
(use_efficient_attention.has_value() && use_efficient_attention.value()) || | ||
(use_trt_fused_attention.has_value() && use_trt_fused_attention.value()) || | ||
(use_cudnn_flash_attention.has_value() && use_cudnn_flash_attention.value()) || | ||
(use_trt_flash_attention.has_value() && use_trt_flash_attention.value()) || | ||
(use_trt_cross_attention.has_value() && use_trt_cross_attention.value()) || | ||
(use_trt_causal_attention.has_value() && use_trt_causal_attention.value()); | ||
|
||
// Fall back to unfused when no fused kernel is enabled. | ||
if (!use_fused) { | ||
sstream << " MATH=1"; | ||
} | ||
|
||
// Output text in Cyan color to make it easier to spot. | ||
std::cout << "\x1B[36m" << sstream.str() << "\x1B[0m" << std::endl; | ||
} | ||
|
||
} // namespace onnxruntime |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Check warning
Code scanning / lintrunner
CLANGFORMAT/format Warning