[CUDA] Attention kernel provider option #21344

Merged (13 commits) on Jul 19, 2024
2 changes: 2 additions & 0 deletions cmake/onnxruntime_rocm_hipify.cmake
@@ -15,6 +15,8 @@ set(contrib_ops_excluded_files
"bert/attention_softmax.h"
"bert/attention_softmax.cu"
"bert/attention_prepare_qkv.cu"
"bert/attention_kernel_options.h"
"bert/attention_kernel_options.cc"
"bert/decoder_attention_impl.h"
"bert/decoder_attention_impl.cu"
"bert/decoder_masked_multihead_attention.h"
3 changes: 2 additions & 1 deletion cmake/onnxruntime_unittests.cmake
@@ -786,8 +786,9 @@ if (onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS)
onnxruntime_add_shared_library_module(onnxruntime_providers_cuda_ut ${onnxruntime_test_providers_cuda_ut_src} $<TARGET_OBJECTS:onnxruntime_providers_cuda_obj>)
config_cuda_provider_shared_module(onnxruntime_providers_cuda_ut)
onnxruntime_add_include_to_target(onnxruntime_providers_cuda_ut GTest::gtest GTest::gmock)
add_dependencies(onnxruntime_providers_cuda_ut onnxruntime_test_utils onnxruntime_common)
target_include_directories(onnxruntime_providers_cuda_ut PRIVATE ${ONNXRUNTIME_ROOT}/core/mickey)
target_link_libraries(onnxruntime_providers_cuda_ut PRIVATE GTest::gtest GTest::gmock ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common)
target_link_libraries(onnxruntime_providers_cuda_ut PRIVATE GTest::gtest GTest::gmock ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_test_utils onnxruntime_common)
if (MSVC)
# Cutlass code has an issue with the following:
# warning C4100: 'magic': unreferenced formal parameter
include/onnxruntime/core/providers/cuda/cuda_provider_options.h
@@ -38,4 +38,5 @@
int prefer_nhwc = 0; // make the CUDA EP NHWC preferred
int use_ep_level_unified_stream = 0; // flag specifying if ep level stream is used or not
int use_tf32 = 1; // use TF32
int sdpa_kernel = 0; // Scaled Dot Product Attention kernel option

};
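Note: the new `sdpa_kernel` value is interpreted as a bitmask over the `AttentionBackend` enum introduced in attention_common.h below (FLASH_ATTENTION=1, EFFICIENT_ATTENTION=2, TRT_FUSED_ATTENTION=4, and so on). A minimal sketch of passing it through the CUDA execution provider's key/value options is shown here; the string key `"sdpa_kernel"` and the model path are assumptions for illustration and should be checked against the provider-options parser in this PR.

```cpp
// Sketch: select attention backends via the sdpa_kernel provider option.
#include <string>
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "sdpa_kernel_demo");
  Ort::SessionOptions session_options;
  const OrtApi& api = Ort::GetApi();

  OrtCUDAProviderOptionsV2* cuda_options = nullptr;
  Ort::ThrowOnError(api.CreateCUDAProviderOptions(&cuda_options));

  // FLASH_ATTENTION (1) | EFFICIENT_ATTENTION (2) | MATH (16) = 19
  const std::string sdpa = std::to_string(1 | 2 | 16);
  const char* keys[] = {"sdpa_kernel"};   // key name assumed to match the new field
  const char* values[] = {sdpa.c_str()};
  Ort::ThrowOnError(api.UpdateCUDAProviderOptions(cuda_options, keys, values, 1));

  Ort::ThrowOnError(api.SessionOptionsAppendExecutionProvider_CUDA_V2(
      session_options, cuda_options));
  api.ReleaseCUDAProviderOptions(cuda_options);  // options are copied on append

  Ort::Session session(env, ORT_TSTR("model.onnx"), session_options);  // placeholder model path
  return 0;
}
```

Setting the value to 0 (the default) keeps the existing environment-variable behavior, as implemented in attention_kernel_options.cc further down.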
28 changes: 26 additions & 2 deletions onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -147,6 +147,23 @@ constexpr const char* kDisableSparseAttentionV1 = "ORT_DISABLE_SPARSE_ATTENTION_
} // namespace sparse_attention

namespace attention {

enum class AttentionBackend : int {
FLASH_ATTENTION = 1,
EFFICIENT_ATTENTION = 2,
TRT_FUSED_ATTENTION = 4,
CUDNN_FLASH_ATTENTION = 8, // reserved for cuDNN flash attention.
MATH = 16, // unfused kernel cannot be disabled right now.

// The following kernels might be deprecated in the future.
TRT_FLASH_ATTENTION = 32,
TRT_CROSS_ATTENTION = 64,
TRT_CAUSAL_ATTENTION = 128,
};

// Environment variable to enable debug information of attention kernel to be printed. Default is 0 (disabled).
constexpr const char* kEnableAttentionKernelDebugInfo = "ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO";

// Environment variable to enable or disable TRT fused self attention kernel. Default is 0 (enabled).
constexpr const char* kDisableFusedSelfAttention = "ORT_DISABLE_FUSED_ATTENTION";

@@ -157,6 +174,9 @@ constexpr const char* kDisableFusedCrossAttention = "ORT_DISABLE_FUSED_CROSS_ATT
// Note that those causal attention kernels use fp16 accumulation. There is potential accuracy drop using those kernels.
constexpr const char* kEnableFusedCausalAttention = "ORT_ENABLE_FUSED_CAUSAL_ATTENTION";

// Environment variable to enable or disable cuDNN flash attention.
constexpr const char* kEnableCudnnFlashAttention = "ORT_ENABLE_CUDNN_FLASH_ATTENTION";

// Environment variable to enable or disable TRT flash attention. This applies to both self and causal attention. Default is 0 (enabled).
constexpr const char* kDisableTrtFlashAttention = "ORT_DISABLE_TRT_FLASH_ATTENTION";

@@ -166,11 +186,15 @@ constexpr const char* kDisableMemoryEfficientAttention = "ORT_DISABLE_MEMORY_EFF
// Environment variable to enable or disable flash attention. Default is 0 (enabled).
constexpr const char* kDisableFlashAttention = "ORT_DISABLE_FLASH_ATTENTION";

// Minimum sequence length to enable memory efficient attention in FP32.
constexpr int kMinSeqLenForMemoryEfficientAttentionFp32 = 256;
// Minimum sequence length to prefer memory efficient attention when data type is float32
constexpr const char* kMinSeqLenForEfficientAttentionFp32 = "ORT_MIN_SEQ_LEN_EFFICIENT_ATTENTION_FP32";

// Default value for minimum sequence length to enable memory efficient attention in FP32.
constexpr int kDefaultMinSeqLenForEfficientAttentionFp32 = 256;

// Minimum sequence length to prefer flash attention when input format is packed QKV for MultiHeadAttention
constexpr const char* kMinSeqLenForFlashAttentionPackedQKV = "ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV";

// Default value for the above setting.
constexpr int kDefaultMinSeqLenForFlashAttentionPackedQKV = 513;
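Since each `AttentionBackend` value occupies its own bit, any combination can be packed into the single `sdpa_kernel` integer and recovered with a bitwise AND. The following self-contained illustration (not part of the PR) re-declares the enum and verifies the arithmetic at compile time.

```cpp
// Illustration only: the backend values are distinct powers of two, so any
// combination can be encoded in one integer and tested with bitwise AND.
enum class AttentionBackend : int {
  FLASH_ATTENTION = 1,
  EFFICIENT_ATTENTION = 2,
  TRT_FUSED_ATTENTION = 4,
  CUDNN_FLASH_ATTENTION = 8,
  MATH = 16,
  TRT_FLASH_ATTENTION = 32,
  TRT_CROSS_ATTENTION = 64,
  TRT_CAUSAL_ATTENTION = 128,
};

constexpr bool HasBackend(int mask, AttentionBackend b) {
  return (mask & static_cast<int>(b)) != 0;
}

// sdpa_kernel = 19 enables flash, efficient, and the unfused math fallback.
static_assert(HasBackend(19, AttentionBackend::FLASH_ATTENTION), "1");
static_assert(HasBackend(19, AttentionBackend::EFFICIENT_ATTENTION), "2");
static_assert(HasBackend(19, AttentionBackend::MATH), "16");
static_assert(!HasBackend(19, AttentionBackend::TRT_FUSED_ATTENTION), "4");

int main() { return 0; }
```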

53 changes: 23 additions & 30 deletions onnxruntime/contrib_ops/cuda/bert/attention.cc
@@ -3,7 +3,6 @@

#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/shared_inc/fpgeneric.h"
#include "core/platform/env_var_utils.h"
#include "contrib_ops/cuda/bert/attention_impl.h"
#include "contrib_ops/cuda/bert/attention.h"
#include "contrib_ops/cuda/bert/bert_padding.h"
@@ -40,36 +39,17 @@ REGISTER_KERNEL_TYPED(MLFloat16)

template <typename T>
Attention<T>::Attention(const OpKernelInfo& info) : CudaKernel(info), AttentionBase(info, false) {
disable_fused_self_attention_ =
sizeof(T) != 2 ||
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFusedSelfAttention, false);
kernel_options_ = this->GetAttentionKernelOptions();

enable_trt_flash_attention_ =
sizeof(T) == 2 &&
!ParseEnvironmentVariableWithDefault<bool>(attention::kDisableTrtFlashAttention, false);
disable_fused_self_attention_ = sizeof(T) != 2 || !kernel_options_->UseTrtFusedAttention();

enable_fused_causal_attention_ =
sizeof(T) == 2 &&
ParseEnvironmentVariableWithDefault<bool>(attention::kEnableFusedCausalAttention, false);
enable_trt_flash_attention_ = sizeof(T) == 2 && kernel_options_->UseTrtFlashAttention();

#if USE_MEMORY_EFFICIENT_ATTENTION
disable_memory_efficient_attention_ =
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableMemoryEfficientAttention, false);
#else
disable_memory_efficient_attention_ = true;
#endif
enable_fused_causal_attention_ = sizeof(T) == 2 && kernel_options_->UseTrtCausalAttention();

#if USE_FLASH_ATTENTION
disable_flash_attention_ =
sizeof(T) != 2 ||
onnxruntime::ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFlashAttention, false);
min_seq_len_for_flash_attention_packed_qkv_ = ParseEnvironmentVariableWithDefault<int>(
attention::kMinSeqLenForFlashAttentionPackedQKV,
attention::kDefaultMinSeqLenForFlashAttentionPackedQKV);
#else
disable_flash_attention_ = true;
min_seq_len_for_flash_attention_packed_qkv_ = 0;
#endif
disable_memory_efficient_attention_ = !kernel_options_->UseEfficientAttention();

disable_flash_attention_ = sizeof(T) != 2 || !kernel_options_->UseFlashAttention();
}

template <typename T>
@@ -134,7 +114,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
parameters.num_heads,
parameters.num_heads);
// When input is packed QKV format, TensorRT kernel might be faster when sequence length <= 512.
if (use_flash_attention && parameters.sequence_length < min_seq_len_for_flash_attention_packed_qkv_) {
if (use_flash_attention && parameters.sequence_length < kernel_options_->MinSeqLenForFlashAttentionPackedQkv()) {
use_flash_attention = false;
}
// Allocate buffers
@@ -220,7 +200,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
nullptr == past &&
nullptr == present &&
(nullptr == mask_index || parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) &&
(sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) &&
(sizeof(T) == 2 || parameters.sequence_length >= this->kernel_options_->MinSeqLenForEfficientAttentionFp32()) &&
has_memory_efficient_attention(sm, sizeof(T) == 2, parameters.head_size, parameters.v_head_size);

if (use_memory_efficient_attention) {
@@ -231,6 +211,20 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
constexpr bool use_memory_efficient_attention = false;
#endif

if (kernel_options_->AllowDebugInfo()) {
AttentionKernelDebugInfo debug_info;
debug_info.use_flash_attention = use_flash_attention;
debug_info.use_efficient_attention = use_memory_efficient_attention;
if (fused_runner != nullptr) {
debug_info.SetTrtFusedKernel(is_unidirectional_, enable_trt_flash_attention_, sequence_length);
}

debug_info.Print("Attention",
this->Node().Name(),
std::is_same<T, MLFloat16>::value,
std::is_same<T, BFloat16>::value);
}

cublasHandle_t cublas = GetCublasHandle(context);

typedef typename ToCudaType<T>::MappedType CudaT;
@@ -268,7 +262,6 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
use_fused_cross_attention,
use_memory_efficient_attention);
IAllocatorUniquePtr<void> work_space = IAllocator::MakeUniquePtr<void>(allocator, workSpaceSize, false, context->GetComputeStream());
;

typedef typename ToCudaType<T>::MappedType CudaT;
AttentionData<CudaT> data;
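With this change the constructor no longer parses environment variables per operator; every flag comes from the shared `AttentionKernelOptions` instance, and `ComputeInternal` consults the two configurable sequence-length thresholds. The helpers below are a simplified sketch of that gating logic (hypothetical names, not the PR's code); the real conditions also depend on mask type, past/present state, and head sizes.

```cpp
// Hypothetical helpers sketching the threshold checks used above.
bool PreferFlashAttention(bool flash_supported, bool packed_qkv,
                          int sequence_length, int min_seq_len_packed_qkv) {
  if (!flash_supported) return false;
  // For packed QKV input, the TRT fused kernel may be faster at short
  // sequence lengths, so flash attention is skipped below the threshold.
  if (packed_qkv && sequence_length < min_seq_len_packed_qkv) return false;
  return true;
}

bool PreferEfficientAttentionFp32(bool efficient_supported, bool is_fp16,
                                  int sequence_length, int min_seq_len_fp32) {
  if (!efficient_supported) return false;
  // In fp32, memory efficient attention is only preferred once the sequence
  // length reaches the configurable minimum; fp16 has no such restriction.
  return is_fp16 || sequence_length >= min_seq_len_fp32;
}
```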
4 changes: 3 additions & 1 deletion onnxruntime/contrib_ops/cuda/bert/attention.h
@@ -8,6 +8,7 @@
#include "core/providers/cuda/cuda_kernel.h"
#include "contrib_ops/cpu/bert/attention_base.h"
#include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h"
#include "contrib_ops/cuda/bert/attention_kernel_options.h"

namespace onnxruntime {
namespace contrib {
@@ -27,9 +28,10 @@ class Attention final : public CudaKernel, public AttentionBase {
bool enable_trt_flash_attention_;
bool enable_fused_causal_attention_;
bool disable_memory_efficient_attention_;
int min_seq_len_for_flash_attention_packed_qkv_;
mutable std::unique_ptr<MHARunner> fused_fp16_runner_;
mutable std::once_flag fused_fp16_runner_created_;

const AttentionKernelOptions* kernel_options_;
};

} // namespace cuda
166 changes: 166 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc
@@ -0,0 +1,166 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "contrib_ops/cuda/bert/attention_kernel_options.h"
#include <iomanip>
#include <iostream>
#include <sstream>
#include "contrib_ops/cpu/bert/attention_common.h"
#include "core/providers/shared_library/provider_api.h"
#include "core/platform/env_var_utils.h"
#include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h"

using namespace onnxruntime::contrib::attention;

namespace onnxruntime {
void AttentionKernelOptions::Initialize(int value, bool use_build_flag) {
if (value > 0) {
use_flash_attention_ = (value & static_cast<int>(AttentionBackend::FLASH_ATTENTION)) > 0;
use_efficient_attention_ = (value & static_cast<int>(AttentionBackend::EFFICIENT_ATTENTION)) > 0;
use_trt_fused_attention_ = (value & static_cast<int>(AttentionBackend::TRT_FUSED_ATTENTION)) > 0;
use_cudnn_flash_attention_ = (value & static_cast<int>(AttentionBackend::CUDNN_FLASH_ATTENTION)) > 0;
use_unfused_ = (value & static_cast<int>(AttentionBackend::MATH)) > 0;
use_trt_flash_attention_ = (value & static_cast<int>(AttentionBackend::TRT_FLASH_ATTENTION)) > 0;
use_trt_cross_attention_ = (value & static_cast<int>(AttentionBackend::TRT_CROSS_ATTENTION)) > 0;
use_trt_causal_attention_ = (value & static_cast<int>(AttentionBackend::TRT_CAUSAL_ATTENTION)) > 0;
} else {
use_flash_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableFlashAttention, false);
use_efficient_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableMemoryEfficientAttention, false);
use_trt_fused_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableFusedSelfAttention, false);
use_cudnn_flash_attention_ = ParseEnvironmentVariableWithDefault<bool>(kEnableCudnnFlashAttention, false);
use_unfused_ = true;
use_trt_flash_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableTrtFlashAttention, false);
use_trt_cross_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableFusedCrossAttention, false);
use_trt_causal_attention_ = ParseEnvironmentVariableWithDefault<bool>(kEnableFusedCausalAttention, false);
}

enable_kernel_debug_info_ = ParseEnvironmentVariableWithDefault<bool>(kEnableAttentionKernelDebugInfo, false);

// When value is positive, we use 0 as the default minimum sequence length to align with common usage in testing.
min_seq_len_for_flash_attention_packed_qkv_ = ParseEnvironmentVariableWithDefault<int>(
kMinSeqLenForFlashAttentionPackedQKV,
value > 0 ? 0 : kDefaultMinSeqLenForFlashAttentionPackedQKV);

min_seq_len_for_efficient_attention_fp32_ = ParseEnvironmentVariableWithDefault<int>(
kMinSeqLenForEfficientAttentionFp32,
value > 0 ? 0 : kDefaultMinSeqLenForEfficientAttentionFp32);

if (use_build_flag) {
// Some kernels can be disabled at build time. If they are disabled, we should not use them.
#ifndef USE_FLASH_ATTENTION
use_flash_attention_ = false;
#endif

#ifndef USE_MEMORY_EFFICIENT_ATTENTION
use_efficient_attention_ = false;
#endif
}
}

void AttentionKernelOptions::InitializeOnce(
int sdpa_kernel, bool use_build_flag) {
std::call_once(this->initialize_once_flag_, [&]() {
this->Initialize(sdpa_kernel, use_build_flag);
if (this->enable_kernel_debug_info_) {
this->Print();
}
});
}

void AttentionKernelOptions::Print() const {
std::stringstream sstream;
sstream << "AttentionKernelOptions:";
sstream << " FLASH_ATTENTION=" << int(use_flash_attention_);
sstream << " EFFICIENT_ATTENTION=" << int(use_efficient_attention_);
sstream << " TRT_FUSED_ATTENTION=" << int(use_trt_fused_attention_);
sstream << " CUDNN_FLASH_ATTENTION=" << int(use_cudnn_flash_attention_);
sstream << " TRT_FLASH_ATTENTION=" << int(use_trt_flash_attention_);
sstream << " TRT_CROSS_ATTENTION=" << int(use_trt_cross_attention_);
sstream << " TRT_CAUSAL_ATTENTION=" << int(use_trt_causal_attention_);
sstream << " MATH=" << int(use_unfused_);

if (!use_unfused_) {
sstream << std::endl
<< "Warning: Unfused kernel cannot be disabled right now. MATH=0 is ignored.";
}

// Output text in Cyan color to make it easier to spot
std::cout << "\x1B[36m" << sstream.str() << "\x1B[0m" << std::endl;
}

// Classify the kernel used in TRT fused runner.
void AttentionKernelDebugInfo::SetTrtFusedKernel(bool causal, bool enable_trt_flash_attention, int sequence_length) {
if (causal) {
use_trt_causal_attention = true;
} else if (enable_trt_flash_attention && sequence_length >= contrib::cuda::kMinSequenceLengthFlashAttention) {
use_trt_flash_attention = true;
} else {
use_trt_fused_attention = true;
}
}

void AttentionKernelDebugInfo::Print(const char* operator_name,
const std::string& node_name,
bool is_float16,
bool is_bfloat16) const {
std::stringstream sstream;
sstream << "Operator=" << operator_name;

if (node_name.length() > 0) {
sstream << " Node=" << node_name;
}

if (is_bfloat16) {
sstream << " DataType=bf16";
} else if (is_float16) {
sstream << " DataType=fp16";
} else {
sstream << " DataType=fp32";
}

if (use_flash_attention.has_value() && use_flash_attention.value()) {
sstream << " FLASH_ATTENTION=" << int(use_flash_attention.value());
}

if (use_efficient_attention.has_value() && use_efficient_attention.value()) {
sstream << " EFFICIENT_ATTENTION=" << int(use_efficient_attention.value());
}

if (use_trt_fused_attention.has_value() && use_trt_fused_attention.value()) {
sstream << " TRT_FUSED_ATTENTION=" << int(use_trt_fused_attention.value());
}

if (use_cudnn_flash_attention.has_value() && use_cudnn_flash_attention.value()) {
sstream << " CUDNN_FLASH_ATTENTION=" << int(use_cudnn_flash_attention.value());
}

if (use_trt_flash_attention.has_value() && use_trt_flash_attention.value()) {
sstream << " TRT_FLASH_ATTENTION=" << int(use_trt_flash_attention.value());
}

if (use_trt_cross_attention.has_value() && use_trt_cross_attention.value()) {
sstream << " TRT_CROSS_ATTENTION=" << int(use_trt_cross_attention.value());
}

if (use_trt_causal_attention.has_value() && use_trt_causal_attention.value()) {
sstream << " TRT_CAUSAL_ATTENTION=" << int(use_trt_causal_attention.value());
}

bool use_fused = (use_flash_attention.has_value() && use_flash_attention.value()) ||
(use_efficient_attention.has_value() && use_efficient_attention.value()) ||
(use_trt_fused_attention.has_value() && use_trt_fused_attention.value()) ||
(use_cudnn_flash_attention.has_value() && use_cudnn_flash_attention.value()) ||
(use_trt_flash_attention.has_value() && use_trt_flash_attention.value()) ||
(use_trt_cross_attention.has_value() && use_trt_cross_attention.value()) ||
(use_trt_causal_attention.has_value() && use_trt_causal_attention.value());

// Fall back to unfused when no fused kernel is enabled.
if (!use_fused) {
sstream << " MATH=1";
}

// Output text in Cyan color to make it easier to spot.
std::cout << "\x1B[36m" << sstream.str() << "\x1B[0m" << std::endl;
}

} // namespace onnxruntime
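For reference, a rough usage sketch of the new class follows, based only on the accessors referenced in attention.cc above (`InitializeOnce`, `UseFlashAttention`, `AllowDebugInfo`, and the `MinSeqLen...` getters). The attention_kernel_options.h header is not part of this excerpt, so treat the signatures as assumptions; the sketch only builds inside the ONNX Runtime source tree.

```cpp
// Approximate in-tree usage; accessor names are taken from attention.cc.
#include "contrib_ops/cuda/bert/attention_kernel_options.h"

namespace onnxruntime {

// The CUDA execution provider is expected to own a single instance and
// initialize it once from the sdpa_kernel provider option.
void ConfigureAttentionKernels(AttentionKernelOptions& options, int sdpa_kernel) {
  // use_build_flag=true disables kernels excluded at build time
  // (USE_FLASH_ATTENTION, USE_MEMORY_EFFICIENT_ATTENTION).
  options.InitializeOnce(sdpa_kernel, /*use_build_flag=*/true);
}

bool ShouldTryFlashAttention(const AttentionKernelOptions& options, bool is_fp16) {
  // Mirrors attention.cc: flash attention is half-precision only and must be
  // enabled either by the sdpa_kernel bitmask or by the env-var defaults.
  return is_fp16 && options.UseFlashAttention();
}

}  // namespace onnxruntime
```

When `ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO=1`, the `Print` functions above emit one cyan line per operator of the form `Operator=Attention Node=<name> DataType=fp16 FLASH_ATTENTION=1` (format taken from the Print implementation), which makes it easy to confirm which backend each node actually ran.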