attention cuda kernel options
tianleiwu committed Jul 13, 2024
1 parent 42b7ced commit 8802e40
Showing 23 changed files with 333 additions and 78 deletions.
3 changes: 2 additions & 1 deletion cmake/onnxruntime_unittests.cmake
@@ -786,8 +786,9 @@ if (onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS)
onnxruntime_add_shared_library_module(onnxruntime_providers_cuda_ut ${onnxruntime_test_providers_cuda_ut_src} $<TARGET_OBJECTS:onnxruntime_providers_cuda_obj>)
config_cuda_provider_shared_module(onnxruntime_providers_cuda_ut)
onnxruntime_add_include_to_target(onnxruntime_providers_cuda_ut GTest::gtest GTest::gmock)
add_dependencies(onnxruntime_providers_cuda_ut onnxruntime_providers_shared onnxruntime_test_utils)
target_include_directories(onnxruntime_providers_cuda_ut PRIVATE ${ONNXRUNTIME_ROOT}/core/mickey)
target_link_libraries(onnxruntime_providers_cuda_ut PRIVATE GTest::gtest GTest::gmock ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common)
target_link_libraries(onnxruntime_providers_cuda_ut PRIVATE GTest::gtest GTest::gmock ${ONNXRUNTIME_MLAS_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} onnxruntime_test_utils)
if (MSVC)
# Cutlass code has an issue with the following:
# warning C4100: 'magic': unreferenced formal parameter
@@ -38,4 +38,5 @@ struct OrtCUDAProviderOptionsV2 {
int prefer_nhwc = 0; // make the CUDA EP NHWC preferred
int use_ep_level_unified_stream = 0; // flag specifying if ep level stream is used or not
int use_tf32 = 1; // use TF32
int sdpa_kernel = 0; // Scaled Dot Product Attention kernel option
};
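
For context, a minimal usage sketch (not part of this commit): sdpa_kernel takes a bitmask of the AttentionBackend values introduced below in attention_kernel_options.h, and 0 keeps the environment-variable defaults. The include paths and the AppendExecutionProvider_CUDA_V2 overload are assumptions about the public C++ API and may differ by install.

// Sketch only: pass an explicit attention kernel selection to the CUDA EP.
#include <onnxruntime_cxx_api.h>                        // ORT C++ API (assumed install layout)
#include <core/providers/cuda/cuda_provider_options.h>  // defines OrtCUDAProviderOptionsV2

void AddCudaEpWithSdpaKernel(Ort::SessionOptions& session_options) {
  OrtCUDAProviderOptionsV2 cuda_options{};
  // Bitmask of AttentionBackend values; 0 keeps the environment-variable driven defaults.
  cuda_options.sdpa_kernel = 1 /*FLASH_ATTENTION*/ | 8 /*MATH*/;
  session_options.AppendExecutionProvider_CUDA_V2(cuda_options);
}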
8 changes: 6 additions & 2 deletions onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -166,11 +166,15 @@ constexpr const char* kDisableMemoryEfficientAttention = "ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION";
// Environment variable to enable or disable flash attention. Default is 0 (enabled).
constexpr const char* kDisableFlashAttention = "ORT_DISABLE_FLASH_ATTENTION";

// Minimum sequence length to enable memory efficient attention in FP32.
constexpr int kMinSeqLenForMemoryEfficientAttentionFp32 = 256;
// Minimum sequence length to prefer memory efficient attention when data type is float32
constexpr const char* kMinSeqLenForEfficientAttentionFp32 = "ORT_MIN_SEQ_LEN_EFFICIENT_ATTENTION_FP32";

// Default value for minimum sequence length to enable memory efficient attention in FP32.
constexpr int kDefaultMinSeqLenForEfficientAttentionFp32 = 256;

// Minimum sequence length to prefer flash attention when input format is packed QKV for MultiHeadAttention
constexpr const char* kMinSeqLenForFlashAttentionPackedQKV = "ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV";

// Default value for the above setting.
constexpr int kDefaultMinSeqLenForFlashAttentionPackedQKV = 513;

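As a hedged tuning sketch (not part of the commit), these variables are parsed once when AttentionKernelOptions initializes, so they must be set before the first CUDA attention kernel is constructed; the 1024 threshold below is an arbitrary example.

#include <cstdlib>  // setenv is POSIX; use _putenv_s on Windows

void TuneAttentionEnvForFp32() {
  // Prefer memory efficient attention only for FP32 sequences of 1024 or more
  // (the built-in default threshold is kDefaultMinSeqLenForEfficientAttentionFp32 = 256).
  setenv("ORT_MIN_SEQ_LEN_EFFICIENT_ATTENTION_FP32", "1024", /*overwrite=*/1);
  // Leave flash attention enabled; "1" would disable it.
  setenv("ORT_DISABLE_FLASH_ATTENTION", "0", /*overwrite=*/1);
}
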
30 changes: 9 additions & 21 deletions onnxruntime/contrib_ops/cuda/bert/attention.cc
@@ -3,7 +3,6 @@

#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/shared_inc/fpgeneric.h"
#include "core/platform/env_var_utils.h"
#include "contrib_ops/cuda/bert/attention_impl.h"
#include "contrib_ops/cuda/bert/attention.h"
#include "contrib_ops/cuda/bert/bert_padding.h"
@@ -40,35 +39,24 @@ REGISTER_KERNEL_TYPED(MLFloat16)

template <typename T>
Attention<T>::Attention(const OpKernelInfo& info) : CudaKernel(info), AttentionBase(info, false) {
disable_fused_self_attention_ =
sizeof(T) != 2 ||
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFusedSelfAttention, false);
kernel_options_ = AttentionKernelOptions::GetInstance(this->SdpaKernel(), false);

enable_trt_flash_attention_ =
sizeof(T) == 2 &&
!ParseEnvironmentVariableWithDefault<bool>(attention::kDisableTrtFlashAttention, false);
disable_fused_self_attention_ = sizeof(T) != 2 || !kernel_options_->UseTrtFusedAttention();

enable_fused_causal_attention_ =
sizeof(T) == 2 &&
ParseEnvironmentVariableWithDefault<bool>(attention::kEnableFusedCausalAttention, false);
enable_trt_flash_attention_ = sizeof(T) == 2 && kernel_options_->UseTrtFlashAttention();

enable_fused_causal_attention_ = sizeof(T) == 2 && kernel_options_->UseTrtCausalAttention();

#if USE_MEMORY_EFFICIENT_ATTENTION
disable_memory_efficient_attention_ =
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableMemoryEfficientAttention, false);
disable_memory_efficient_attention_ = !kernel_options_->UseEfficientAttention();
#else
disable_memory_efficient_attention_ = true;
#endif

#if USE_FLASH_ATTENTION
disable_flash_attention_ =
sizeof(T) != 2 ||
onnxruntime::ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFlashAttention, false);
min_seq_len_for_flash_attention_packed_qkv_ = ParseEnvironmentVariableWithDefault<int>(
attention::kMinSeqLenForFlashAttentionPackedQKV,
attention::kDefaultMinSeqLenForFlashAttentionPackedQKV);
disable_flash_attention_ = sizeof(T) != 2 || !kernel_options_->UseFlashAttention();
#else
disable_flash_attention_ = true;
min_seq_len_for_flash_attention_packed_qkv_ = 0;
#endif
}

@@ -134,7 +122,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
parameters.num_heads,
parameters.num_heads);
// When input is packed QKV format, TensorRT kernel might be faster when sequence length <= 512.
if (use_flash_attention && parameters.sequence_length < min_seq_len_for_flash_attention_packed_qkv_) {
if (use_flash_attention && parameters.sequence_length < kernel_options_->MinSeqLenForFlashAttentionPackedQkv()) {
use_flash_attention = false;
}
// Allocate buffers
@@ -220,7 +208,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
nullptr == past &&
nullptr == present &&
(nullptr == mask_index || parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) &&
(sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) &&
(sizeof(T) == 2 || parameters.sequence_length >= this->kernel_options_->MinSeqLenForEfficientAttentionFp32()) &&
has_memory_efficient_attention(sm, sizeof(T) == 2, parameters.head_size, parameters.v_head_size);

if (use_memory_efficient_attention) {
3 changes: 2 additions & 1 deletion onnxruntime/contrib_ops/cuda/bert/attention.h
@@ -8,6 +8,7 @@
#include "core/providers/cuda/cuda_kernel.h"
#include "contrib_ops/cpu/bert/attention_base.h"
#include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h"
#include "contrib_ops/cuda/bert/attention_kernel_options.h"

namespace onnxruntime {
namespace contrib {
@@ -27,9 +28,9 @@ class Attention final : public CudaKernel, public AttentionBase {
bool enable_trt_flash_attention_;
bool enable_fused_causal_attention_;
bool disable_memory_efficient_attention_;
int min_seq_len_for_flash_attention_packed_qkv_;
mutable std::unique_ptr<MHARunner> fused_fp16_runner_;
mutable std::once_flag fused_fp16_runner_created_;
const AttentionKernelOptions* kernel_options_;
};

} // namespace cuda
53 changes: 53 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc
@@ -0,0 +1,53 @@

#include "contrib_ops/cuda/bert/attention_kernel_options.h"
#include "contrib_ops/cpu/bert/attention_common.h"
#include "core/platform/env_var_utils.h"

namespace onnxruntime {
namespace contrib {
namespace cuda {

// Initialize the singleton instance
AttentionKernelOptions AttentionKernelOptions::instance;

void AttentionKernelOptions::Initialize(int value) {
if (value > 0) {
use_flash_attention_ = (value & static_cast<int>(AttentionBackend::FLASH_ATTENTION)) > 0;
use_efficient_attention_ = (value & static_cast<int>(AttentionBackend::EFFICIENT_ATTENTION)) > 0;
use_trt_fused_attention_ = (value & static_cast<int>(AttentionBackend::TRT_FUSED_ATTENTION)) > 0;
use_unfused_ = (value & static_cast<int>(AttentionBackend::MATH)) > 0;
use_trt_flash_attention_ = (value & static_cast<int>(AttentionBackend::TRT_FLASH_ATTENTION)) > 0;
use_trt_cross_attention_ = (value & static_cast<int>(AttentionBackend::TRT_CROSS_ATTENTION)) > 0;
use_trt_causal_attention_ = (value & static_cast<int>(AttentionBackend::TRT_CAUSAL_ATTENTION)) > 0;
} else {
use_flash_attention_ = !ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFlashAttention, false);
use_efficient_attention_ = !ParseEnvironmentVariableWithDefault<bool>(attention::kDisableMemoryEfficientAttention, false);
use_trt_fused_attention_ = !ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFusedSelfAttention, false);
use_unfused_ = true;
use_trt_flash_attention_ = !ParseEnvironmentVariableWithDefault<bool>(attention::kDisableTrtFlashAttention, false);
use_trt_cross_attention_ = !ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFusedCrossAttention, false);
use_trt_causal_attention_ = ParseEnvironmentVariableWithDefault<bool>(attention::kEnableFusedCausalAttention, false);
}

min_seq_len_for_flash_attention_packed_qkv_ = ParseEnvironmentVariableWithDefault<int>(
attention::kMinSeqLenForFlashAttentionPackedQKV,
attention::kDefaultMinSeqLenForFlashAttentionPackedQKV);

min_seq_len_for_efficient_attention_fp32_ = ParseEnvironmentVariableWithDefault<int>(
attention::kMinSeqLenForEfficientAttentionFp32,
attention::kDefaultMinSeqLenForEfficientAttentionFp32);

initialized_ = true;
}

const AttentionKernelOptions* AttentionKernelOptions::GetInstance(int sdpa_kernel, bool force_init) {
if (force_init || !instance.initialized_) {
instance.Initialize(sdpa_kernel);
}

return &instance;
}

} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
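
A small usage sketch of the precedence implemented above (illustration, not part of the commit): a positive sdpa_kernel selects backends explicitly and bypasses the kDisable*/kEnable* environment variables, while 0 falls back to them; the singleton is initialized once, so the first GetInstance call fixes the options unless force_init is true. The AttentionBackend values used here are defined in attention_kernel_options.h, shown next.

#include "contrib_ops/cuda/bert/attention_kernel_options.h"

using onnxruntime::contrib::cuda::AttentionBackend;
using onnxruntime::contrib::cuda::AttentionKernelOptions;

// Assume this is the first GetInstance call, so Initialize() sees this value.
const AttentionKernelOptions* opts = AttentionKernelOptions::GetInstance(
    static_cast<int>(AttentionBackend::FLASH_ATTENTION) |
        static_cast<int>(AttentionBackend::MATH),  // sdpa_kernel = 9
    /*force_init=*/false);

// With that mask: UseFlashAttention() and UseUnfusedAttention() return true,
// UseEfficientAttention() and the TRT options return false; the min-seq-len
// thresholds are still read from their environment variables.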
59 changes: 59 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h
@@ -0,0 +1,59 @@

#pragma once

namespace onnxruntime {
namespace contrib {
namespace cuda {

enum class AttentionBackend : int {
FLASH_ATTENTION = 1,
EFFICIENT_ATTENTION = 2,
TRT_FUSED_ATTENTION = 4,
MATH = 8, // unfused

// The following kernels might be deprecated in the future.
TRT_FLASH_ATTENTION = 16,
TRT_CROSS_ATTENTION = 32,
TRT_CAUSAL_ATTENTION = 64,
};

class AttentionKernelOptions {
public:
static const AttentionKernelOptions* GetInstance(int sdpa_kernel, bool force_init);

bool UseFlashAttention() const { return use_flash_attention_; }
bool UseEfficientAttention() const { return use_efficient_attention_; }
bool UseTrtFusedAttention() const { return use_trt_fused_attention_; }
bool UseUnfusedAttention() const { return use_unfused_; }
bool UseTrtFlashAttention() const { return use_trt_flash_attention_; }
bool UseTrtCrossAttention() const { return use_trt_cross_attention_; }
bool UseTrtCausalAttention() const { return use_trt_causal_attention_; }

int MinSeqLenForFlashAttentionPackedQkv() const { return min_seq_len_for_flash_attention_packed_qkv_; }
int MinSeqLenForEfficientAttentionFp32() const { return min_seq_len_for_efficient_attention_fp32_; }

protected:
void Initialize(int value);

private:
bool use_flash_attention_{true};
bool use_efficient_attention_{true};
bool use_trt_fused_attention_{true};
bool use_unfused_{true};
bool use_trt_flash_attention_{true};
bool use_trt_cross_attention_{true};

// Causal attention is disabled by default in #14732.
bool use_trt_causal_attention_{false};

int min_seq_len_for_flash_attention_packed_qkv_{0};

int min_seq_len_for_efficient_attention_fp32_{0};

bool initialized_{false};
static AttentionKernelOptions instance;
};

} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
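
To make the bit layout concrete, a compile-time illustration (not part of the commit) of how the enum values above compose into an sdpa_kernel value:

constexpr int kExampleSdpaKernel =
    static_cast<int>(AttentionBackend::FLASH_ATTENTION) |      // 1
    static_cast<int>(AttentionBackend::EFFICIENT_ATTENTION) |  // 2
    static_cast<int>(AttentionBackend::MATH);                  // 8
static_assert(kExampleSdpaKernel == 11, "flash | efficient | math");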
9 changes: 4 additions & 5 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
@@ -52,17 +52,16 @@ GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info)
rotary_interleaved_ = info.GetAttrOrDefault<int64_t>("rotary_interleaved", 0) == 1;
scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);

kernel_options_ = AttentionKernelOptions::GetInstance(this->SdpaKernel(), false);
#if USE_FLASH_ATTENTION
disable_flash_attention_ = sizeof(T) != 2 ||
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFlashAttention, false);
disable_flash_attention_ = sizeof(T) != 2 || !kernel_options_->UseFlashAttention();
#else
disable_flash_attention_ = true;
#endif

#if USE_MEMORY_EFFICIENT_ATTENTION
// Memory efficient attention only supports float and float16, not bfloat16.
disable_memory_efficient_attention_ = std::is_same<T, BFloat16>::value ||
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableMemoryEfficientAttention, false);
disable_memory_efficient_attention_ = std::is_same<T, BFloat16>::value || !kernel_options_->UseEfficientAttention();
#else
disable_memory_efficient_attention_ = true;
#endif
@@ -161,7 +160,7 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
!use_flash_attention &&
!disable_memory_efficient_attention_ &&
local_window_size_ == -1 &&
(sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) &&
(sizeof(T) == 2 || parameters.sequence_length >= this->kernel_options_->MinSeqLenForEfficientAttentionFp32()) &&
has_memory_efficient_attention(sm, sizeof(T) == 2, parameters.head_size, parameters.head_size);
if (!use_flash_attention && !use_memory_efficient_attention && local_window_size_ != -1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention.h
@@ -6,6 +6,7 @@
#include <memory>
#include "core/providers/cuda/cuda_kernel.h"
#include "contrib_ops/cuda/bert/group_query_attention_impl.h"
#include "contrib_ops/cuda/bert/attention_kernel_options.h"

namespace onnxruntime {
namespace contrib {
@@ -32,6 +33,7 @@ class GroupQueryAttention final : public CudaKernel {
bool disable_memory_efficient_attention_;
static constexpr int kZerosCount = 256; // In prompt case we create a zero buffer of size 256 for seqlen (assume batch_size <= 256)
IAllocatorUniquePtr<int> zeros_;
const AttentionKernelOptions* kernel_options_;
};

} // namespace cuda
28 changes: 10 additions & 18 deletions onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
@@ -2,7 +2,6 @@
// Licensed under the MIT License.

#include "core/providers/cuda/cuda_common.h"
#include "core/platform/env_var_utils.h"
#include "contrib_ops/cuda/bert/attention_impl.h"
#include "contrib_ops/cuda/bert/multihead_attention.h"
#include "contrib_ops/cpu/bert/multihead_attention_helper.h"
@@ -47,31 +46,23 @@ MultiHeadAttention<T>::MultiHeadAttention(const OpKernelInfo& info)
is_unidirectional_ = info.GetAttrOrDefault<int64_t>("unidirectional", 0) == 1;
ORT_ENFORCE(!is_unidirectional_, "Unidirectional MHA does not support CUDA kernel. Consider using Attention or GQA instead.");

disable_fused_self_attention_ = sizeof(T) != 2 ||
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFusedSelfAttention, false);

enable_trt_flash_attention_ = sizeof(T) == 2 &&
!ParseEnvironmentVariableWithDefault<bool>(attention::kDisableTrtFlashAttention, false);
kernel_options_ = AttentionKernelOptions::GetInstance(this->SdpaKernel(), false);
disable_fused_self_attention_ = sizeof(T) != 2 || !kernel_options_->UseTrtFusedAttention();
enable_trt_flash_attention_ = sizeof(T) == 2 && kernel_options_->UseTrtFlashAttention();

#if USE_FLASH_ATTENTION
disable_flash_attention_ = sizeof(T) != 2 ||
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFlashAttention, false);
min_seq_len_for_flash_attention_packed_qkv_ = ParseEnvironmentVariableWithDefault<int>(
attention::kMinSeqLenForFlashAttentionPackedQKV,
attention::kDefaultMinSeqLenForFlashAttentionPackedQKV);
disable_flash_attention_ = sizeof(T) != 2 || !kernel_options_->UseFlashAttention();
#else
disable_flash_attention_ = true;
min_seq_len_for_flash_attention_packed_qkv_ = 0;
#endif

#if USE_MEMORY_EFFICIENT_ATTENTION
disable_memory_efficient_attention_ = ParseEnvironmentVariableWithDefault<bool>(attention::kDisableMemoryEfficientAttention, false);
disable_memory_efficient_attention_ = !kernel_options_->UseEfficientAttention();
#else
disable_memory_efficient_attention_ = true;
#endif

disable_fused_cross_attention_ = sizeof(T) != 2 ||
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFusedCrossAttention, false);
disable_fused_cross_attention_ = sizeof(T) != 2 || !kernel_options_->UseTrtCrossAttention();

// Allocate cache buffers
constexpr size_t cache_bytes = sizeof(int32_t) * (static_cast<size_t>(kCumulatedSequenceLengthCacheMaxBatchSize) + 1);
@@ -155,7 +146,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
parameters.num_heads);
// When input is packed QKV format, TensorRT kernel might be faster than flash attention when sequence length <= 512.
if (use_flash_attention && key == nullptr && value == nullptr &&
parameters.sequence_length < min_seq_len_for_flash_attention_packed_qkv_) {
parameters.sequence_length < kernel_options_->MinSeqLenForFlashAttentionPackedQkv()) {
use_flash_attention = false;
}
// Allocate buffers
@@ -229,9 +220,10 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
}

#if USE_MEMORY_EFFICIENT_ATTENTION
int length_threshold = this->kernel_options_->MinSeqLenForEfficientAttentionFp32();
bool is_long_sequence = sizeof(T) == 2 || // sequence length threshold is 0 for FP16
parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32 ||
parameters.kv_sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32;
parameters.sequence_length >= length_threshold ||
parameters.kv_sequence_length >= length_threshold;

bool is_good_for_rpb = relative_position_bias != nullptr && parameters.sequence_length % (4 * sizeof(T)) == 0;

3 changes: 2 additions & 1 deletion onnxruntime/contrib_ops/cuda/bert/multihead_attention.h
@@ -8,6 +8,7 @@
#include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h"
#include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/cross_attention/fmha_cross_attention.h"
#include "contrib_ops/cuda/bert/attention_impl.h"
#include "contrib_ops/cuda/bert/attention_kernel_options.h"

namespace onnxruntime {
namespace contrib {
@@ -31,12 +32,12 @@ class MultiHeadAttention final : public CudaKernel {
bool disable_fused_cross_attention_;
bool disable_flash_attention_;
bool disable_memory_efficient_attention_;
int min_seq_len_for_flash_attention_packed_qkv_;
mutable std::unique_ptr<MHARunner> fused_fp16_runner_;
mutable std::once_flag fused_fp16_runner_created_;
mutable const FusedMultiHeadCrossAttentionKernel* fused_fp16_cross_attention_kernel_;
mutable CumulatedSequenceLengthCache cumulated_sequence_length_q_cache_;
mutable CumulatedSequenceLengthCache cumulated_sequence_length_kv_cache_;
const AttentionKernelOptions* kernel_options_;
};

} // namespace cuda