[Layers] Enable AMX FP16 of FlashAttn
abenmao committed Jun 28, 2024
1 parent 6656c54 commit 56629a6
Showing 14 changed files with 134 additions and 50 deletions.
23 changes: 10 additions & 13 deletions src/layers/attention.h
@@ -31,6 +31,7 @@
#include "simple_mem_pool.h"
#include "transformer_ctx.h"
#include "transformer_util.h"
#include "type_selector.h"

/**
* WeiT: weight data type
@@ -50,9 +51,7 @@ class Attention {
: layerId(layerId), qkpo(ctx->attHeadSize, ctx->maxPosEmbed), norm(ctx) {

//todo(marvin): clear this code after all rotary_emb refactor
if constexpr (std::is_same<QKPO_CLS, LlamaRotaryEmbedding>::value) {
qkpo = LlamaRotaryEmbedding(ctx);
}
if constexpr (std::is_same<QKPO_CLS, LlamaRotaryEmbedding>::value) { qkpo = LlamaRotaryEmbedding(ctx); }

// Group attention or multi-head attention (multi-head attn is a special case of group attn)
if (ctx->attHeadNum % ctx->kvHeadNum == 0) {
@@ -1030,11 +1029,8 @@ class Attention {
void flashAttention(DecoderContext *ctx, xft::Matrix<ImT> &query, xft::Matrix<ImT> &key, xft::Matrix<ImT> &value,
xft::Matrix<ImT> &result, KVCacheTensor<KVCacheT> &presentKey, KVCacheTensor<KVCacheT> &presentValue,
const float *attnMask, int pastSeqLen) {
#if defined(AVX512_BF16_WEIGHT_ONLY_BF16)
using AttnT = bfloat16_t;
#else
using AttnT = float;
#endif
using AttnT = typename AttnTypeSelector<ImT>::type;

// How many heads this task should do
int batchSize = ctx->batchSize;
int respQHeads = this->endQHead - this->startQHead;
Expand Down Expand Up @@ -1068,6 +1064,8 @@ class Attention {
AttnT *dstPtr = kvBuf + seq * kvStride + i;
if constexpr (std::is_same_v<AttnT, bfloat16_t> && std::is_same_v<ImT, float>) {
bfloat16_t::cvt_float_to_bfloat16(srcPtr, dstPtr, headSize);
} else if constexpr (std::is_same_v<AttnT, float16_t> && std::is_same_v<ImT, float>) {
float16_t::cvt_float_to_float16(srcPtr, dstPtr, headSize);
} else if constexpr (std::is_same_v<AttnT, float> && std::is_same_v<ImT, bfloat16_t>) {
bfloat16_t::cvt_bfloat16_to_float(srcPtr, dstPtr, headSize);
} else if constexpr (std::is_same_v<AttnT, float> && std::is_same_v<ImT, float16_t>) {
@@ -1099,11 +1097,8 @@
void flashAttention(DecoderContext *ctx, xft::Matrix<ImT> &query, xft::Matrix<ImT> &key, xft::Matrix<ImT> &value,
xft::Matrix<ImT> &result, std::vector<KVCacheTensor<KVCacheT> *> &keyCaches,
std::vector<KVCacheTensor<KVCacheT> *> &valueCaches, std::vector<xft::SequenceMeta *> &seqs) {
#if defined(AVX512_BF16_WEIGHT_ONLY_BF16)
using AttnT = bfloat16_t;
#else
using AttnT = float;
#endif
using AttnT = typename AttnTypeSelector<ImT>::type;

// How many heads this task should do
int batchSize = seqs.size();
int respQHeads = this->endQHead - this->startQHead;
Expand Down Expand Up @@ -1137,6 +1132,8 @@ class Attention {
AttnT *dstPtr = kvBuf + seq * kvStride + i;
if constexpr (std::is_same_v<AttnT, bfloat16_t> && std::is_same_v<ImT, float>) {
bfloat16_t::cvt_float_to_bfloat16(srcPtr, dstPtr, headSize);
} else if constexpr (std::is_same_v<AttnT, float16_t> && std::is_same_v<ImT, float>) {
float16_t::cvt_float_to_float16(srcPtr, dstPtr, headSize);
} else if constexpr (std::is_same_v<AttnT, float> && std::is_same_v<ImT, bfloat16_t>) {
bfloat16_t::cvt_bfloat16_to_float(srcPtr, dstPtr, headSize);
} else if constexpr (std::is_same_v<AttnT, float> && std::is_same_v<ImT, float16_t>) {
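
Note on the AttnT change in this file: the old #if block always chose bfloat16_t or float, while the new alias delegates the choice to AttnTypeSelector from the newly included type_selector.h, whose definition is not part of this diff. Below is a minimal sketch of what such a selector could look like, assuming it keys on the same build flags as the old block plus the new AMX FP16 flag and falls back to FP32; the real header may also specialize on ImT.

// Hypothetical sketch only -- the actual trait is defined in type_selector.h,
// which is not shown in this commit. Illustrates how a compile-time type
// selector can replace the previous #if/#else block.
// bfloat16_t and float16_t are the project's 16-bit types.
template <typename ImT>
struct AttnTypeSelectorSketch {
#if defined(AMX_FP16_WEIGHT_ONLY_FP16)
    using type = float16_t;      // FP16 attention compute, the path enabled by this commit
#elif defined(AVX512_BF16_WEIGHT_ONLY_BF16)
    using type = bfloat16_t;     // previous BF16 path
#else
    using type = float;          // portable FP32 fallback
#endif
};

// Usage mirrors the diff: using AttnT = typename AttnTypeSelectorSketch<ImT>::type;
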
10 changes: 10 additions & 0 deletions src/models/baichuan.cpp
@@ -177,6 +177,11 @@ void Baichuan<WeiT, KVCacheT>::embeddingForward(int *ids, bfloat16_t *output, in
embedding->forward(ids, output, tokenSize);
}

template <typename WeiT, typename KVCacheT>
void Baichuan<WeiT, KVCacheT>::embeddingForward(int *ids, float16_t *output, int tokenSize) {
embedding->forward(ids, output, tokenSize);
}

template <typename WeiT, typename KVCacheT>
void Baichuan<WeiT, KVCacheT>::lastLayerNormForward(float *input, float *output, int rows) {
finalLN.forward(input, output, rows);
@@ -187,4 +192,9 @@ void Baichuan<WeiT, KVCacheT>::lastLayerNormForward(bfloat16_t *input, bfloat16_
finalLN.forward(input, output, rows);
}

template <typename WeiT, typename KVCacheT>
void Baichuan<WeiT, KVCacheT>::lastLayerNormForward(float16_t *input, float16_t *output, int rows) {
finalLN.forward(input, output, rows);
}

IMPLEMENT_MODEL(Baichuan, baichuan)
2 changes: 2 additions & 0 deletions src/models/baichuan.h
@@ -36,8 +36,10 @@ class Baichuan
void prepareAttnMask(int *ids, int step);
void embeddingForward(int *ids, float *output, int tokenSize);
void embeddingForward(int *ids, bfloat16_t *output, int tokenSize);
void embeddingForward(int *ids, float16_t *output, int tokenSize);
void lastLayerNormForward(float *input, float *output, int rows);
void lastLayerNormForward(bfloat16_t *input, bfloat16_t *output, int rows);
void lastLayerNormForward(float16_t *input, float16_t *output, int rows);

private:
void setEmbeddingWeights(const std::string &modelPath);
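
The remaining model files below repeat the same pattern: each gains a float16_t overload of embeddingForward and lastLayerNormForward that simply forwards to the existing embedding and final-norm objects. A plausible reading, sketched here under the assumption that the models implement a common decoder interface declaring one overload per activation type (the base class itself is not part of this diff):

// Hypothetical interface sketch, not taken from this commit. Once FP16
// activations are supported, the presumed base class needs a float16_t
// variant of each virtual, so every concrete model must provide one.
struct bfloat16_t;   // 16-bit types provided by the project
struct float16_t;

class DecoderInterfaceSketch {
public:
    virtual ~DecoderInterfaceSketch() = default;
    virtual void embeddingForward(int *ids, float *output, int tokenSize) = 0;
    virtual void embeddingForward(int *ids, bfloat16_t *output, int tokenSize) = 0;
    virtual void embeddingForward(int *ids, float16_t *output, int tokenSize) = 0;           // mirrors the overload added here
    virtual void lastLayerNormForward(float *input, float *output, int rows) = 0;
    virtual void lastLayerNormForward(bfloat16_t *input, bfloat16_t *output, int rows) = 0;
    virtual void lastLayerNormForward(float16_t *input, float16_t *output, int rows) = 0;    // mirrors the overload added here
};
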
12 changes: 11 additions & 1 deletion src/models/chatglm2.cpp
@@ -115,6 +115,11 @@ void ChatGLM2<WeiT, KVCacheT>::embeddingForward(int *ids, bfloat16_t *output, in
embedding->forward(ids, output, tokenSize);
}

template <typename WeiT, typename KVCacheT>
void ChatGLM2<WeiT, KVCacheT>::embeddingForward(int *ids, float16_t *output, int tokenSize) {
embedding->forward(ids, output, tokenSize);
}

template <typename WeiT, typename KVCacheT>
void ChatGLM2<WeiT, KVCacheT>::lastLayerNormForward(float *input, float *output, int rows) {
finalLN.forward(input, output, rows);
@@ -125,6 +130,11 @@ void ChatGLM2<WeiT, KVCacheT>::lastLayerNormForward(bfloat16_t *input, bfloat16_
finalLN.forward(input, output, rows);
}

template <typename WeiT, typename KVCacheT>
void ChatGLM2<WeiT, KVCacheT>::lastLayerNormForward(float16_t *input, float16_t *output, int rows) {
finalLN.forward(input, output, rows);
}

// Return the position_ids + block_position_ids
// if position_ids is None:
// position_ids = self.get_position_ids(input_ids, device=input_ids.device)
@@ -175,4 +185,4 @@ int *ChatGLM2<WeiT, KVCacheT>::getPositionIds(int *ids, int batchSize, int seqLe
return positionIds;
}

IMPLEMENT_MODEL(ChatGLM2, chatglm2)
IMPLEMENT_MODEL(ChatGLM2, chatglm2)
4 changes: 3 additions & 1 deletion src/models/chatglm2.h
@@ -36,8 +36,10 @@ class ChatGLM2
virtual void prepareAttnMask(int *ids, int step);
virtual void embeddingForward(int *ids, float *output, int tokenSize);
virtual void embeddingForward(int *ids, bfloat16_t *output, int tokenSize);
virtual void embeddingForward(int *ids, float16_t *output, int tokenSize);
virtual void lastLayerNormForward(float *input, float *output, int rows);
virtual void lastLayerNormForward(bfloat16_t *input, bfloat16_t *output, int rows);
virtual void lastLayerNormForward(float16_t *input, float16_t *output, int rows);
virtual int *getPositionIds(int *ids, int batchSize, int seqLen, int step) override;

private:
@@ -59,4 +61,4 @@ class ChatGLM2
int posBufSize;
};

REGISTER_MODEL(ChatGLM2, chatglm2)
REGISTER_MODEL(ChatGLM2, chatglm2)
10 changes: 0 additions & 10 deletions src/models/env_config.cpp
@@ -41,16 +41,6 @@ int getFlashThresh() {
return envFlashThresh;
}

bool enableSkipMsk() {
// Skip masked attn in flash attention for better perf of long sequence, default disabled
static int skipMsk = -1;
if (skipMsk == -1) {
skipMsk = (getenv("ENABLE_SKIP_MASK") ? atoi(getenv("ENABLE_SKIP_MASK")) : 0);
if (skipMsk == 1) printf("ENABLE_SKIP_MASK is enabled for ignoring mask Q*K.\n");
}
return skipMsk == 1;
}

bool kvTrans() {
// Transpose KV Tensor to [batchSize, headNum, seqLen, headSize] for better perf of long sequence, default disabled
// TODO: add support for reorder and expand when beam_search>1
12 changes: 11 additions & 1 deletion src/models/gemma.cpp
@@ -114,6 +114,11 @@ void GemmaLLM<WeiT, KVCacheT>::embeddingForward(int *ids, bfloat16_t *output, in
embedding->forward(ids, output, tokenSize);
}

template <typename WeiT, typename KVCacheT>
void GemmaLLM<WeiT, KVCacheT>::embeddingForward(int *ids, float16_t *output, int tokenSize) {
embedding->forward(ids, output, tokenSize);
}

template <typename WeiT, typename KVCacheT>
void GemmaLLM<WeiT, KVCacheT>::lastLayerNormForward(float *input, float *output, int rows) {
finalLN.forward(input, output, rows);
@@ -124,4 +129,9 @@ void GemmaLLM<WeiT, KVCacheT>::lastLayerNormForward(bfloat16_t *input, bfloat16_
finalLN.forward(input, output, rows);
}

IMPLEMENT_MODEL(GemmaLLM, gemma)
template <typename WeiT, typename KVCacheT>
void GemmaLLM<WeiT, KVCacheT>::lastLayerNormForward(float16_t *input, float16_t *output, int rows) {
finalLN.forward(input, output, rows);
}

IMPLEMENT_MODEL(GemmaLLM, gemma)
4 changes: 3 additions & 1 deletion src/models/gemma.h
@@ -36,9 +36,11 @@ class GemmaLLM

void embeddingForward(int *ids, float *output, int tokenSize);
void embeddingForward(int *ids, bfloat16_t *output, int tokenSize);
void embeddingForward(int *ids, float16_t *output, int tokenSize);

void lastLayerNormForward(float *input, float *output, int rows);
void lastLayerNormForward(bfloat16_t *input, bfloat16_t *output, int rows);
void lastLayerNormForward(float16_t *input, float16_t *output, int rows);

private:
void setEmbeddingWeights(const std::string &modelPath);
@@ -49,4 +51,4 @@ class GemmaLLM
RmsNorm finalLN;
};

REGISTER_MODEL(GemmaLLM, gemma)
REGISTER_MODEL(GemmaLLM, gemma)
12 changes: 11 additions & 1 deletion src/models/qwen2.cpp
@@ -114,6 +114,11 @@ void Qwen2LLM<WeiT, KVCacheT>::embeddingForward(int *ids, bfloat16_t *output, in
embedding->forward(ids, output, tokenSize);
}

template <typename WeiT, typename KVCacheT>
void Qwen2LLM<WeiT, KVCacheT>::embeddingForward(int *ids, float16_t *output, int tokenSize) {
embedding->forward(ids, output, tokenSize);
}

template <typename WeiT, typename KVCacheT>
void Qwen2LLM<WeiT, KVCacheT>::lastLayerNormForward(float *input, float *output, int rows) {
finalLN.forward(input, output, rows);
@@ -124,4 +129,9 @@ void Qwen2LLM<WeiT, KVCacheT>::lastLayerNormForward(bfloat16_
finalLN.forward(input, output, rows);
}

IMPLEMENT_MODEL(Qwen2LLM, qwen2)
template <typename WeiT, typename KVCacheT>
void Qwen2LLM<WeiT, KVCacheT>::lastLayerNormForward(float16_t *input, float16_t *output, int rows) {
finalLN.forward(input, output, rows);
}

IMPLEMENT_MODEL(Qwen2LLM, qwen2)
4 changes: 3 additions & 1 deletion src/models/qwen2.h
@@ -36,9 +36,11 @@ class Qwen2LLM

void embeddingForward(int *ids, float *output, int tokenSize);
void embeddingForward(int *ids, bfloat16_t *output, int tokenSize);
void embeddingForward(int *ids, float16_t *output, int tokenSize);

void lastLayerNormForward(float *input, float *output, int rows);
void lastLayerNormForward(bfloat16_t *input, bfloat16_t *output, int rows);
void lastLayerNormForward(float16_t *input, float16_t *output, int rows);

private:
void setEmbeddingWeights(const std::string &modelPath);
@@ -49,4 +51,4 @@ class Qwen2LLM
RmsNorm finalLN;
};

REGISTER_MODEL(Qwen2LLM, qwen2)
REGISTER_MODEL(Qwen2LLM, qwen2)
12 changes: 11 additions & 1 deletion src/models/yarn_llama.cpp
@@ -98,6 +98,11 @@ void YaRNLlama<WeiT, KVCacheT>::embeddingForward(int *ids, bfloat16_t *output, i
embedding->forward(ids, output, tokenSize);
}

template <typename WeiT, typename KVCacheT>
void YaRNLlama<WeiT, KVCacheT>::embeddingForward(int *ids, float16_t *output, int tokenSize) {
embedding->forward(ids, output, tokenSize);
}

template <typename WeiT, typename KVCacheT>
void YaRNLlama<WeiT, KVCacheT>::lastLayerNormForward(float *input, float *output, int rows) {
finalLN.forward(input, output, rows);
@@ -108,4 +113,9 @@ void YaRNLlama<WeiT, KVCacheT>::lastLayerNormForward(bfloat16
finalLN.forward(input, output, rows);
}

IMPLEMENT_MODEL(YaRNLlama, yarn_llama)
template <typename WeiT, typename KVCacheT>
void YaRNLlama<WeiT, KVCacheT>::lastLayerNormForward(float16_t *input, float16_t *output, int rows) {
finalLN.forward(input, output, rows);
}

IMPLEMENT_MODEL(YaRNLlama, yarn_llama)
4 changes: 3 additions & 1 deletion src/models/yarn_llama.h
@@ -38,9 +38,11 @@ class YaRNLlama

void embeddingForward(int *ids, float *output, int tokenSize);
void embeddingForward(int *ids, bfloat16_t *output, int tokenSize);
void embeddingForward(int *ids, float16_t *output, int tokenSize);

void lastLayerNormForward(float *input, float *output, int rows);
void lastLayerNormForward(bfloat16_t *input, bfloat16_t *output, int rows);
void lastLayerNormForward(float16_t *input, float16_t *output, int rows);

private:
void setEmbeddingWeights(const std::string &modelPath);
@@ -51,4 +53,4 @@ class YaRNLlama
RmsNorm finalLN;
};

REGISTER_MODEL(YaRNLlama, yarn_llama)
REGISTER_MODEL(YaRNLlama, yarn_llama)
37 changes: 21 additions & 16 deletions src/utils/decoder_util.h
@@ -34,7 +34,6 @@

extern int getFlashThresh();
extern bool enableCATMLP();
extern bool enableSkipMsk();

class DecoderUtil {
public:
@@ -554,16 +553,31 @@

dnnl_sgemm(ta[0], tb[0], m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);

} else if (std::is_same_v<T, bfloat16_t>) {
} else {
CBLAS_TRANSPOSE ta, tb;
ta = transa ? CblasTrans : CblasNoTrans;
tb = transb ? CblasTrans : CblasNoTrans;

cblas_gemm_bf16bf16f32(CblasRowMajor, ta, tb, m, n, k, alpha, (const MKL_BF16 *)(A), lda,
(const MKL_BF16 *)(B), ldb, beta, C, ldc);
} else {
printf("Datatype Not supported yet\n");
exit(-1);
if (std::is_same_v<T, bfloat16_t>) {
cblas_gemm_bf16bf16f32(CblasRowMajor, ta, tb, m, n, k, alpha, (const MKL_BF16 *)(A), lda,
(const MKL_BF16 *)(B), ldb, beta, C, ldc);
} else if (std::is_same_v<T, float16_t>) {
static int mkl_enable_inst = -1;
if (mkl_enable_inst == -1) {
#ifdef AMX_FP16_WEIGHT_ONLY_FP16
// AMX FP16
mkl_enable_inst = mkl_enable_instructions(MKL_ENABLE_AVX512_E5);
#else
// AVX512_FP16, skip E4 avoiding illegal instruction error
mkl_enable_inst = mkl_enable_instructions(MKL_ENABLE_AVX512_E3);
#endif
}
cblas_gemm_f16f16f32(CblasRowMajor, ta, tb, m, n, k, alpha, (const MKL_F16 *)(A), lda,
(const MKL_F16 *)(B), ldb, beta, C, ldc);
} else {
printf("Datatype Not supported yet\n");
exit(-1);
}
}
}
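
For reference, the new float16_t branch above routes through oneMKL's cblas_gemm_f16f16f32 (FP16 inputs, FP32 compute and output), after optionally pinning MKL's instruction-set dispatch via mkl_enable_instructions. Below is a small self-contained sketch of that MKL call outside this codebase; the matrix sizes, the F16C-based float-to-half conversion, and the check value are illustrative assumptions, not taken from this repository.

// Standalone sketch: FP16 x FP16 -> FP32 GEMM with oneMKL.
// Build (illustrative): g++ -mf16c example.cpp and link against MKL.
#include <mkl.h>
#include <immintrin.h>   // _cvtss_sh (F16C) for float -> half conversion
#include <vector>
#include <cstdio>

int main() {
    const int m = 4, n = 4, k = 8;
    std::vector<float> a(m * k, 1.0f), b(k * n, 0.5f), c(m * n, 0.0f);
    std::vector<unsigned short> a16(m * k), b16(k * n);   // IEEE half bit patterns

    // Convert the FP32 inputs to half precision, rounding to nearest even.
    for (size_t i = 0; i < a16.size(); ++i)
        a16[i] = _cvtss_sh(a[i], _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
    for (size_t i = 0; i < b16.size(); ++i)
        b16[i] = _cvtss_sh(b[i], _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);

    // C = 1.0 * A * B + 0.0 * C; A and B are stored as FP16, C is accumulated in FP32.
    cblas_gemm_f16f16f32(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, 1.0f,
                         reinterpret_cast<const MKL_F16 *>(a16.data()), k,
                         reinterpret_cast<const MKL_F16 *>(b16.data()), k, 0.0f,
                         c.data(), n);

    printf("C[0][0] = %.3f (expected %.3f)\n", c[0], k * 1.0f * 0.5f);
    return 0;
}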

@@ -806,13 +820,4 @@
sgemm((T *)AB, C, expABC, m, n, k, k, vStride, n, false, false);
updateOutTile(output, expABC, preSum, sum, preMax, max, m, n, stride);
}

static bool skipMskAttn(const float *attnMask, int m, int n, int stride) {
float lowest = std::numeric_limits<float>::lowest();
// left bottom is lowest
if (attnMask[(m - 1) * stride] == lowest)
return true;
else
return false;
}
};