[Layers] Enable AMX FP16 of FlashAttn
abenmao committed Jun 27, 2024
1 parent 6656c54 commit 55aa3fa
Showing 4 changed files with 66 additions and 42 deletions.
23 changes: 10 additions & 13 deletions src/layers/attention.h
@@ -31,6 +31,7 @@
#include "simple_mem_pool.h"
#include "transformer_ctx.h"
#include "transformer_util.h"
#include "type_selector.h"

/**
* WeiT: weight data type
@@ -50,9 +51,7 @@ class Attention {
: layerId(layerId), qkpo(ctx->attHeadSize, ctx->maxPosEmbed), norm(ctx) {

//todo(marvin): clear this code after all rotary_emb refactor
-if constexpr (std::is_same<QKPO_CLS, LlamaRotaryEmbedding>::value) {
-    qkpo = LlamaRotaryEmbedding(ctx);
-}
+if constexpr (std::is_same<QKPO_CLS, LlamaRotaryEmbedding>::value) { qkpo = LlamaRotaryEmbedding(ctx); }

// Group attention or multi-head attention (multi-head attn is a special case of group attn)
if (ctx->attHeadNum % ctx->kvHeadNum == 0) {
@@ -1030,11 +1029,8 @@
void flashAttention(DecoderContext *ctx, xft::Matrix<ImT> &query, xft::Matrix<ImT> &key, xft::Matrix<ImT> &value,
xft::Matrix<ImT> &result, KVCacheTensor<KVCacheT> &presentKey, KVCacheTensor<KVCacheT> &presentValue,
const float *attnMask, int pastSeqLen) {
-#if defined(AVX512_BF16_WEIGHT_ONLY_BF16)
-using AttnT = bfloat16_t;
-#else
-using AttnT = float;
-#endif
+using AttnT = typename AttnTypeSelector<ImT>::type;

// How many heads this task should do
int batchSize = ctx->batchSize;
int respQHeads = this->endQHead - this->startQHead;
@@ -1068,6 +1064,8 @@
AttnT *dstPtr = kvBuf + seq * kvStride + i;
if constexpr (std::is_same_v<AttnT, bfloat16_t> && std::is_same_v<ImT, float>) {
bfloat16_t::cvt_float_to_bfloat16(srcPtr, dstPtr, headSize);
+} else if constexpr (std::is_same_v<AttnT, float16_t> && std::is_same_v<ImT, float>) {
+    float16_t::cvt_float_to_float16(srcPtr, dstPtr, headSize);
} else if constexpr (std::is_same_v<AttnT, float> && std::is_same_v<ImT, bfloat16_t>) {
bfloat16_t::cvt_bfloat16_to_float(srcPtr, dstPtr, headSize);
} else if constexpr (std::is_same_v<AttnT, float> && std::is_same_v<ImT, float16_t>) {
@@ -1099,11 +1097,8 @@
void flashAttention(DecoderContext *ctx, xft::Matrix<ImT> &query, xft::Matrix<ImT> &key, xft::Matrix<ImT> &value,
xft::Matrix<ImT> &result, std::vector<KVCacheTensor<KVCacheT> *> &keyCaches,
std::vector<KVCacheTensor<KVCacheT> *> &valueCaches, std::vector<xft::SequenceMeta *> &seqs) {
-#if defined(AVX512_BF16_WEIGHT_ONLY_BF16)
-using AttnT = bfloat16_t;
-#else
-using AttnT = float;
-#endif
+using AttnT = typename AttnTypeSelector<ImT>::type;

// How many heads this task should do
int batchSize = seqs.size();
int respQHeads = this->endQHead - this->startQHead;
@@ -1137,6 +1132,8 @@
AttnT *dstPtr = kvBuf + seq * kvStride + i;
if constexpr (std::is_same_v<AttnT, bfloat16_t> && std::is_same_v<ImT, float>) {
bfloat16_t::cvt_float_to_bfloat16(srcPtr, dstPtr, headSize);
+} else if constexpr (std::is_same_v<AttnT, float16_t> && std::is_same_v<ImT, float>) {
+    float16_t::cvt_float_to_float16(srcPtr, dstPtr, headSize);
} else if constexpr (std::is_same_v<AttnT, float> && std::is_same_v<ImT, bfloat16_t>) {
bfloat16_t::cvt_bfloat16_to_float(srcPtr, dstPtr, headSize);
} else if constexpr (std::is_same_v<AttnT, float> && std::is_same_v<ImT, float16_t>) {
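For illustration, the conversion dispatch above can be read in isolation. A minimal sketch of the K/V packing step (assuming the repo's type_selector.h and the float16_t/bfloat16_t helpers shown in this diff; packRow and the raw-buffer setup are hypothetical):

#include <cstring>
#include <type_traits>
#include "type_selector.h" // AttnTypeSelector, float16_t, bfloat16_t

// Sketch: copy one row of K/V from the activation type ImT into the
// flash-attention compute type AttnT, converting only when the types differ.
template <typename ImT>
void packRow(const ImT *srcPtr, void *kvBuf, int headSize) {
    using AttnT = typename AttnTypeSelector<ImT>::type;
    AttnT *dstPtr = static_cast<AttnT *>(kvBuf);
    if constexpr (std::is_same_v<AttnT, ImT>) {
        std::memcpy(dstPtr, srcPtr, headSize * sizeof(ImT)); // same type: plain copy
    } else if constexpr (std::is_same_v<AttnT, bfloat16_t> && std::is_same_v<ImT, float>) {
        bfloat16_t::cvt_float_to_bfloat16(srcPtr, dstPtr, headSize);
    } else if constexpr (std::is_same_v<AttnT, float16_t> && std::is_same_v<ImT, float>) {
        float16_t::cvt_float_to_float16(srcPtr, dstPtr, headSize); // path added by this commit
    }
}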
10 changes: 0 additions & 10 deletions src/models/env_config.cpp
@@ -41,16 +41,6 @@ int getFlashThresh() {
return envFlashThresh;
}

-bool enableSkipMsk() {
-    // Skip masked attn in flash attention for better perf of long sequence, default disabled
-    static int skipMsk = -1;
-    if (skipMsk == -1) {
-        skipMsk = (getenv("ENABLE_SKIP_MASK") ? atoi(getenv("ENABLE_SKIP_MASK")) : 0);
-        if (skipMsk == 1) printf("ENABLE_SKIP_MASK is enabled for ignoring mask Q*K.\n");
-    }
-    return skipMsk == 1;
-}

bool kvTrans() {
// Transpose KV Tensor to [batchSize, headNum, seqLen, headSize] for better perf of long sequence, default disabled
// TODO: add support for reorder and expand when beam_search>1
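The toggles that remain in env_config.cpp (getFlashThresh, kvTrans) all follow the same idiom as the removed function: read the environment once and cache the result in a function-local static. A minimal sketch of that pattern, with a hypothetical XFT_EXAMPLE_FLAG variable standing in for the real ones:

#include <cstdlib>

static bool exampleFlag() {
    static int cached = -1; // -1 means "environment not read yet"
    if (cached == -1) {
        const char *v = getenv("XFT_EXAMPLE_FLAG"); // hypothetical variable
        cached = v ? atoi(v) : 0;                   // default: disabled
    }
    return cached == 1;
}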
37 changes: 21 additions & 16 deletions src/utils/decoder_util.h
@@ -34,7 +34,6 @@

extern int getFlashThresh();
extern bool enableCATMLP();
-extern bool enableSkipMsk();

class DecoderUtil {
public:
@@ -554,16 +553,31 @@

dnnl_sgemm(ta[0], tb[0], m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);

-} else if (std::is_same_v<T, bfloat16_t>) {
+} else {
CBLAS_TRANSPOSE ta, tb;
ta = transa ? CblasTrans : CblasNoTrans;
tb = transb ? CblasTrans : CblasNoTrans;

-    cblas_gemm_bf16bf16f32(CblasRowMajor, ta, tb, m, n, k, alpha, (const MKL_BF16 *)(A), lda,
-            (const MKL_BF16 *)(B), ldb, beta, C, ldc);
-} else {
-    printf("Datatype Not supported yet\n");
-    exit(-1);
+    if (std::is_same_v<T, bfloat16_t>) {
+        cblas_gemm_bf16bf16f32(CblasRowMajor, ta, tb, m, n, k, alpha, (const MKL_BF16 *)(A), lda,
+                (const MKL_BF16 *)(B), ldb, beta, C, ldc);
+    } else if (std::is_same_v<T, float16_t>) {
+        static int mkl_enable_inst = -1;
+        if (mkl_enable_inst == -1) {
+#ifdef AMX_FP16_WEIGHT_ONLY_FP16
+            // AMX FP16
+            mkl_enable_inst = mkl_enable_instructions(MKL_ENABLE_AVX512_E5);
+#else
+            // AVX512_FP16, skip E4 avoiding illegal instruction error
+            mkl_enable_inst = mkl_enable_instructions(MKL_ENABLE_AVX512_E3);
+#endif
+        }
+        cblas_gemm_f16f16f32(CblasRowMajor, ta, tb, m, n, k, alpha, (const MKL_F16 *)(A), lda,
+                (const MKL_F16 *)(B), ldb, beta, C, ldc);
+    } else {
+        printf("Datatype Not supported yet\n");
+        exit(-1);
+    }
}
}

@@ -806,13 +820,4 @@
sgemm((T *)AB, C, expABC, m, n, k, k, vStride, n, false, false);
updateOutTile(output, expABC, preSum, sum, preMax, max, m, n, stride);
}

-static bool skipMskAttn(const float *attnMask, int m, int n, int stride) {
-    float lowest = std::numeric_limits<float>::lowest();
-    // left bottom is lowest
-    if (attnMask[(m - 1) * stride] == lowest)
-        return true;
-    else
-        return false;
-}
};
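The new float16_t branch in sgemm gates MKL's ISA dispatch once before the first GEMM call. A condensed, standalone sketch of that pattern (mkl_enable_instructions, the MKL_ENABLE_AVX512_E* codes, and cblas_gemm_f16f16f32 are used as in the diff; the shapes and leading dimensions here are arbitrary):

#include <mkl.h>

// FP16 x FP16 -> FP32 GEMM with one-time instruction gating.
// mkl_enable_instructions only takes effect before the first MKL call,
// hence the cached static, mirroring the diff.
void f16GemmExample(const MKL_F16 *A, const MKL_F16 *B, float *C, int m, int n, int k) {
    static int gated = -1;
    if (gated == -1) {
#ifdef AMX_FP16_WEIGHT_ONLY_FP16
        gated = mkl_enable_instructions(MKL_ENABLE_AVX512_E5); // AMX FP16
#else
        gated = mkl_enable_instructions(MKL_ENABLE_AVX512_E3); // AVX512_FP16, avoid E4
#endif
    }
    cblas_gemm_f16f16f32(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k,
            1.0f, A, k, B, n, 0.0f, C, n);
}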
38 changes: 35 additions & 3 deletions src/utils/type_selector.h
@@ -31,7 +31,7 @@ struct TypeSelector<bfloat16_t> {
using OutType = bfloat16_t;
};

#ifdef XFT_GPU
template <>
struct TypeSelector<float16_t> {
using InType = float16_t;
@@ -40,11 +40,43 @@ struct TypeSelector<float16_t> {
};
#endif

#ifdef AMX_FP16_WEIGHT_ONLY_FP16
#ifdef AVX512_FP16_WEIGHT_ONLY_FP16
template <>
struct TypeSelector<float16_t> {
using InType = float16_t;
using ImType = float16_t;
using OutType = float16_t;
};
#endif
#endif

+template <typename T>
+struct AttnTypeSelector;

+template <>
+struct AttnTypeSelector<float> {
+#if defined(AVX512_BF16_WEIGHT_ONLY_BF16)
+    using type = bfloat16_t;
+#elif defined(AVX512_FP16_WEIGHT_ONLY_FP16)
+    using type = float16_t;
+#else
+    using type = float;
+#endif
+};

+template <>
+struct AttnTypeSelector<bfloat16_t> {
+#if defined(AVX512_BF16_WEIGHT_ONLY_BF16)
+    using type = bfloat16_t;
+#else
+    using type = float;
+#endif
+};

+template <>
+struct AttnTypeSelector<float16_t> {
+#if defined(AVX512_FP16_WEIGHT_ONLY_FP16)
+    using type = float16_t;
+#else
+    using type = float;
+#endif
+};

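As a compile-time sanity check, here is what the new trait yields on a build that defines only AVX512_FP16_WEIGHT_ONLY_FP16 (an illustrative sketch, not part of the commit):

#include <type_traits>
#include "type_selector.h"

// FP32 and FP16 activations both run the flash-attention inner kernels in
// float16_t; bfloat16_t falls back to float since the BF16 kernel is off.
static_assert(std::is_same_v<AttnTypeSelector<float>::type, float16_t>);
static_assert(std::is_same_v<AttnTypeSelector<float16_t>::type, float16_t>);
static_assert(std::is_same_v<AttnTypeSelector<bfloat16_t>::type, float>);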