[Layers] Enable AMX FP16 of FlashAttn #459

Merged · 1 commit · Jul 2, 2024
23 changes: 10 additions & 13 deletions src/layers/attention.h
@@ -31,6 +31,7 @@
#include "simple_mem_pool.h"
#include "transformer_ctx.h"
#include "transformer_util.h"
#include "type_selector.h"

/**
* WeiT: weight data type
@@ -50,9 +51,7 @@ class Attention {
: layerId(layerId), qkpo(ctx->attHeadSize, ctx->maxPosEmbed), norm(ctx) {

//todo(marvin): clear this code after all rotary_emb refactor
if constexpr (std::is_same<QKPO_CLS, LlamaRotaryEmbedding>::value) {
qkpo = LlamaRotaryEmbedding(ctx);
}
if constexpr (std::is_same<QKPO_CLS, LlamaRotaryEmbedding>::value) { qkpo = LlamaRotaryEmbedding(ctx); }

// Group attention or multi-head attention (multi-head attn is a special case of group attn)
if (ctx->attHeadNum % ctx->kvHeadNum == 0) {
@@ -1030,11 +1029,8 @@ class Attention {
void flashAttention(DecoderContext *ctx, xft::Matrix<ImT> &query, xft::Matrix<ImT> &key, xft::Matrix<ImT> &value,
xft::Matrix<ImT> &result, KVCacheTensor<KVCacheT> &presentKey, KVCacheTensor<KVCacheT> &presentValue,
const float *attnMask, int pastSeqLen) {
#if defined(AVX512_BF16_WEIGHT_ONLY_BF16)
using AttnT = bfloat16_t;
#else
using AttnT = float;
#endif
using AttnT = typename AttnTypeSelector<ImT>::type;

// How many heads this task should do
int batchSize = ctx->batchSize;
int respQHeads = this->endQHead - this->startQHead;
@@ -1068,6 +1064,8 @@
AttnT *dstPtr = kvBuf + seq * kvStride + i;
if constexpr (std::is_same_v<AttnT, bfloat16_t> && std::is_same_v<ImT, float>) {
bfloat16_t::cvt_float_to_bfloat16(srcPtr, dstPtr, headSize);
} else if constexpr (std::is_same_v<AttnT, float16_t> && std::is_same_v<ImT, float>) {
float16_t::cvt_float_to_float16(srcPtr, dstPtr, headSize);
} else if constexpr (std::is_same_v<AttnT, float> && std::is_same_v<ImT, bfloat16_t>) {
bfloat16_t::cvt_bfloat16_to_float(srcPtr, dstPtr, headSize);
} else if constexpr (std::is_same_v<AttnT, float> && std::is_same_v<ImT, float16_t>) {
@@ -1099,11 +1097,8 @@
void flashAttention(DecoderContext *ctx, xft::Matrix<ImT> &query, xft::Matrix<ImT> &key, xft::Matrix<ImT> &value,
xft::Matrix<ImT> &result, std::vector<KVCacheTensor<KVCacheT> *> &keyCaches,
std::vector<KVCacheTensor<KVCacheT> *> &valueCaches, std::vector<xft::SequenceMeta *> &seqs) {
#if defined(AVX512_BF16_WEIGHT_ONLY_BF16)
using AttnT = bfloat16_t;
#else
using AttnT = float;
#endif
using AttnT = typename AttnTypeSelector<ImT>::type;

// How many heads this task should do
int batchSize = seqs.size();
int respQHeads = this->endQHead - this->startQHead;
@@ -1137,6 +1132,8 @@
AttnT *dstPtr = kvBuf + seq * kvStride + i;
if constexpr (std::is_same_v<AttnT, bfloat16_t> && std::is_same_v<ImT, float>) {
bfloat16_t::cvt_float_to_bfloat16(srcPtr, dstPtr, headSize);
} else if constexpr (std::is_same_v<AttnT, float16_t> && std::is_same_v<ImT, float>) {
float16_t::cvt_float_to_float16(srcPtr, dstPtr, headSize);
} else if constexpr (std::is_same_v<AttnT, float> && std::is_same_v<ImT, bfloat16_t>) {
bfloat16_t::cvt_bfloat16_to_float(srcPtr, dstPtr, headSize);
} else if constexpr (std::is_same_v<AttnT, float> && std::is_same_v<ImT, float16_t>) {
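Note: the newly included type_selector.h is referenced above but its contents are not part of this diff. A minimal sketch of what AttnTypeSelector presumably provides, assuming it folds the old AVX512_BF16_WEIGHT_ONLY_BF16 compile-time switch into a type trait and adds an FP16 mapping; the macro guards and specializations below are assumptions, not the actual header:

```cpp
// Hypothetical sketch of type_selector.h -- the real header is not shown in this diff.
// Maps the intermediate activation type ImT to the type used inside flash
// attention (AttnT), replacing the old #if AVX512_BF16_WEIGHT_ONLY_BF16 blocks.
template <typename ImT>
struct AttnTypeSelector {
    using type = float; // conservative fallback: run attention in FP32
};

#if defined(AVX512_BF16_WEIGHT_ONLY_BF16)
// Preserves the previous behavior: FP32 activations are converted to BF16.
template <>
struct AttnTypeSelector<float> { using type = bfloat16_t; };
#elif defined(AMX_FP16_WEIGHT_ONLY_FP16)
// The new float -> float16_t conversion branch above implies a mapping like this.
template <>
struct AttnTypeSelector<float> { using type = float16_t; };
#endif

template <>
struct AttnTypeSelector<bfloat16_t> { using type = bfloat16_t; };

// FP16 activations stay FP16 so the AMX/AVX512 FP16 kernels can be used.
template <>
struct AttnTypeSelector<float16_t> { using type = float16_t; };
```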
10 changes: 10 additions & 0 deletions src/models/baichuan.cpp
@@ -177,6 +177,11 @@ void Baichuan<WeiT, KVCacheT>::embeddingForward(int *ids, bfloat16_t *output, in
embedding->forward(ids, output, tokenSize);
}

template <typename WeiT, typename KVCacheT>
void Baichuan<WeiT, KVCacheT>::embeddingForward(int *ids, float16_t *output, int tokenSize) {
embedding->forward(ids, output, tokenSize);
}

template <typename WeiT, typename KVCacheT>
void Baichuan<WeiT, KVCacheT>::lastLayerNormForward(float *input, float *output, int rows) {
finalLN.forward(input, output, rows);
@@ -187,4 +192,9 @@ void Baichuan<WeiT, KVCacheT>::lastLayerNormForward(bfloat16_t *input, bfloat16_
finalLN.forward(input, output, rows);
}

template <typename WeiT, typename KVCacheT>
void Baichuan<WeiT, KVCacheT>::lastLayerNormForward(float16_t *input, float16_t *output, int rows) {
finalLN.forward(input, output, rows);
}

IMPLEMENT_MODEL(Baichuan, baichuan)
2 changes: 2 additions & 0 deletions src/models/baichuan.h
@@ -36,8 +36,10 @@ class Baichuan
void prepareAttnMask(int *ids, int step);
void embeddingForward(int *ids, float *output, int tokenSize);
void embeddingForward(int *ids, bfloat16_t *output, int tokenSize);
void embeddingForward(int *ids, float16_t *output, int tokenSize);
void lastLayerNormForward(float *input, float *output, int rows);
void lastLayerNormForward(bfloat16_t *input, bfloat16_t *output, int rows);
void lastLayerNormForward(float16_t *input, float16_t *output, int rows);

private:
void setEmbeddingWeights(const std::string &modelPath);
12 changes: 11 additions & 1 deletion src/models/chatglm2.cpp
@@ -115,6 +115,11 @@ void ChatGLM2<WeiT, KVCacheT>::embeddingForward(int *ids, bfloat16_t *output, in
embedding->forward(ids, output, tokenSize);
}

template <typename WeiT, typename KVCacheT>
void ChatGLM2<WeiT, KVCacheT>::embeddingForward(int *ids, float16_t *output, int tokenSize) {
embedding->forward(ids, output, tokenSize);
}

template <typename WeiT, typename KVCacheT>
void ChatGLM2<WeiT, KVCacheT>::lastLayerNormForward(float *input, float *output, int rows) {
finalLN.forward(input, output, rows);
@@ -125,6 +130,11 @@ void ChatGLM2<WeiT, KVCacheT>::lastLayerNormForward(bfloat16_t *input, bfloat16_
finalLN.forward(input, output, rows);
}

template <typename WeiT, typename KVCacheT>
void ChatGLM2<WeiT, KVCacheT>::lastLayerNormForward(float16_t *input, float16_t *output, int rows) {
finalLN.forward(input, output, rows);
}

// Return the position_ids + block_position_ids
// if position_ids is None:
// position_ids = self.get_position_ids(input_ids, device=input_ids.device)
@@ -175,4 +185,4 @@ int *ChatGLM2<WeiT, KVCacheT>::getPositionIds(int *ids, int batchSize, int seqLe
return positionIds;
}

IMPLEMENT_MODEL(ChatGLM2, chatglm2)
IMPLEMENT_MODEL(ChatGLM2, chatglm2)
4 changes: 3 additions & 1 deletion src/models/chatglm2.h
@@ -36,8 +36,10 @@ class ChatGLM2
virtual void prepareAttnMask(int *ids, int step);
virtual void embeddingForward(int *ids, float *output, int tokenSize);
virtual void embeddingForward(int *ids, bfloat16_t *output, int tokenSize);
virtual void embeddingForward(int *ids, float16_t *output, int tokenSize);
virtual void lastLayerNormForward(float *input, float *output, int rows);
virtual void lastLayerNormForward(bfloat16_t *input, bfloat16_t *output, int rows);
virtual void lastLayerNormForward(float16_t *input, float16_t *output, int rows);
virtual int *getPositionIds(int *ids, int batchSize, int seqLen, int step) override;

private:
@@ -59,4 +61,4 @@ class ChatGLM2
int posBufSize;
};

REGISTER_MODEL(ChatGLM2, chatglm2)
REGISTER_MODEL(ChatGLM2, chatglm2)
10 changes: 0 additions & 10 deletions src/models/env_config.cpp
@@ -41,16 +41,6 @@ int getFlashThresh() {
return envFlashThresh;
}

bool enableSkipMsk() {
// Skip masked attn in flash attention for better perf of long sequence, default disabled
static int skipMsk = -1;
if (skipMsk == -1) {
skipMsk = (getenv("ENABLE_SKIP_MASK") ? atoi(getenv("ENABLE_SKIP_MASK")) : 0);
if (skipMsk == 1) printf("ENABLE_SKIP_MASK is enabled for ignoring mask Q*K.\n");
}
return skipMsk == 1;
}

bool kvTrans() {
// Transpose KV Tensor to [batchSize, headNum, seqLen, headSize] for better perf of long sequence, default disabled
// TODO: add support for reorder and expand when beam_search>1
12 changes: 11 additions & 1 deletion src/models/gemma.cpp
@@ -114,6 +114,11 @@ void GemmaLLM<WeiT, KVCacheT>::embeddingForward(int *ids, bfloat16_t *output, in
embedding->forward(ids, output, tokenSize);
}

template <typename WeiT, typename KVCacheT>
void GemmaLLM<WeiT, KVCacheT>::embeddingForward(int *ids, float16_t *output, int tokenSize) {
embedding->forward(ids, output, tokenSize);
}

template <typename WeiT, typename KVCacheT>
void GemmaLLM<WeiT, KVCacheT>::lastLayerNormForward(float *input, float *output, int rows) {
finalLN.forward(input, output, rows);
@@ -124,4 +129,9 @@ void GemmaLLM<WeiT, KVCacheT>::lastLayerNormForward(bfloat16_t *input, bfloat16_
finalLN.forward(input, output, rows);
}

IMPLEMENT_MODEL(GemmaLLM, gemma)
template <typename WeiT, typename KVCacheT>
void GemmaLLM<WeiT, KVCacheT>::lastLayerNormForward(float16_t *input, float16_t *output, int rows) {
finalLN.forward(input, output, rows);
}

IMPLEMENT_MODEL(GemmaLLM, gemma)
4 changes: 3 additions & 1 deletion src/models/gemma.h
@@ -36,9 +36,11 @@ class GemmaLLM

void embeddingForward(int *ids, float *output, int tokenSize);
void embeddingForward(int *ids, bfloat16_t *output, int tokenSize);
void embeddingForward(int *ids, float16_t *output, int tokenSize);

void lastLayerNormForward(float *input, float *output, int rows);
void lastLayerNormForward(bfloat16_t *input, bfloat16_t *output, int rows);
void lastLayerNormForward(float16_t *input, float16_t *output, int rows);

private:
void setEmbeddingWeights(const std::string &modelPath);
@@ -49,4 +51,4 @@ class GemmaLLM
RmsNorm finalLN;
};

REGISTER_MODEL(GemmaLLM, gemma)
REGISTER_MODEL(GemmaLLM, gemma)
12 changes: 11 additions & 1 deletion src/models/qwen2.cpp
@@ -114,6 +114,11 @@ void Qwen2LLM<WeiT, KVCacheT>::embeddingForward(int *ids, bfloat16_t *output, in
embedding->forward(ids, output, tokenSize);
}

template <typename WeiT, typename KVCacheT>
void Qwen2LLM<WeiT, KVCacheT>::embeddingForward(int *ids, float16_t *output, int tokenSize) {
embedding->forward(ids, output, tokenSize);
}

template <typename WeiT, typename KVCacheT>
void Qwen2LLM<WeiT, KVCacheT>::lastLayerNormForward(float *input, float *output, int rows) {
finalLN.forward(input, output, rows);
@@ -124,4 +129,9 @@ void Qwen2LLM<WeiT, KVCacheT>::lastLayerNormForward(bfloat16_t *input, bfloat16_
finalLN.forward(input, output, rows);
}

IMPLEMENT_MODEL(Qwen2LLM, qwen2)
template <typename WeiT, typename KVCacheT>
void Qwen2LLM<WeiT, KVCacheT>::lastLayerNormForward(float16_t *input, float16_t *output, int rows) {
finalLN.forward(input, output, rows);
}

IMPLEMENT_MODEL(Qwen2LLM, qwen2)
4 changes: 3 additions & 1 deletion src/models/qwen2.h
@@ -36,9 +36,11 @@ class Qwen2LLM

void embeddingForward(int *ids, float *output, int tokenSize);
void embeddingForward(int *ids, bfloat16_t *output, int tokenSize);
void embeddingForward(int *ids, float16_t *output, int tokenSize);

void lastLayerNormForward(float *input, float *output, int rows);
void lastLayerNormForward(bfloat16_t *input, bfloat16_t *output, int rows);
void lastLayerNormForward(float16_t *input, float16_t *output, int rows);

private:
void setEmbeddingWeights(const std::string &modelPath);
@@ -49,4 +51,4 @@ class Qwen2LLM
RmsNorm finalLN;
};

REGISTER_MODEL(Qwen2LLM, qwen2)
REGISTER_MODEL(Qwen2LLM, qwen2)
12 changes: 11 additions & 1 deletion src/models/yarn_llama.cpp
@@ -98,6 +98,11 @@ void YaRNLlama<WeiT, KVCacheT>::embeddingForward(int *ids, bfloat16_t *output, i
embedding->forward(ids, output, tokenSize);
}

template <typename WeiT, typename KVCacheT>
void YaRNLlama<WeiT, KVCacheT>::embeddingForward(int *ids, float16_t *output, int tokenSize) {
embedding->forward(ids, output, tokenSize);
}

template <typename WeiT, typename KVCacheT>
void YaRNLlama<WeiT, KVCacheT>::lastLayerNormForward(float *input, float *output, int rows) {
finalLN.forward(input, output, rows);
@@ -108,4 +113,9 @@ void YaRNLlama<WeiT, KVCacheT>::lastLayerNormForward(bfloat16_t *input, bfloat16
finalLN.forward(input, output, rows);
}

IMPLEMENT_MODEL(YaRNLlama, yarn_llama)
template <typename WeiT, typename KVCacheT>
void YaRNLlama<WeiT, KVCacheT>::lastLayerNormForward(float16_t *input, float16_t *output, int rows) {
finalLN.forward(input, output, rows);
}

IMPLEMENT_MODEL(YaRNLlama, yarn_llama)
4 changes: 3 additions & 1 deletion src/models/yarn_llama.h
@@ -38,9 +38,11 @@ class YaRNLlama

void embeddingForward(int *ids, float *output, int tokenSize);
void embeddingForward(int *ids, bfloat16_t *output, int tokenSize);
void embeddingForward(int *ids, float16_t *output, int tokenSize);

void lastLayerNormForward(float *input, float *output, int rows);
void lastLayerNormForward(bfloat16_t *input, bfloat16_t *output, int rows);
void lastLayerNormForward(float16_t *input, float16_t *output, int rows);

private:
void setEmbeddingWeights(const std::string &modelPath);
@@ -51,4 +53,4 @@ class YaRNLlama
RmsNorm finalLN;
};

REGISTER_MODEL(YaRNLlama, yarn_llama)
REGISTER_MODEL(YaRNLlama, yarn_llama)
37 changes: 21 additions & 16 deletions src/utils/decoder_util.h
@@ -34,7 +34,6 @@

extern int getFlashThresh();
extern bool enableCATMLP();
extern bool enableSkipMsk();

class DecoderUtil {
public:
@@ -554,16 +553,31 @@ class DecoderUtil {

dnnl_sgemm(ta[0], tb[0], m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);

} else if (std::is_same_v<T, bfloat16_t>) {
} else {
CBLAS_TRANSPOSE ta, tb;
ta = transa ? CblasTrans : CblasNoTrans;
tb = transb ? CblasTrans : CblasNoTrans;

cblas_gemm_bf16bf16f32(CblasRowMajor, ta, tb, m, n, k, alpha, (const MKL_BF16 *)(A), lda,
(const MKL_BF16 *)(B), ldb, beta, C, ldc);
} else {
printf("Datatype Not supported yet\n");
exit(-1);
if (std::is_same_v<T, bfloat16_t>) {
cblas_gemm_bf16bf16f32(CblasRowMajor, ta, tb, m, n, k, alpha, (const MKL_BF16 *)(A), lda,
(const MKL_BF16 *)(B), ldb, beta, C, ldc);
} else if (std::is_same_v<T, float16_t>) {
static int mkl_enable_inst = -1;
if (mkl_enable_inst == -1) {
#ifdef AMX_FP16_WEIGHT_ONLY_FP16
// AMX FP16
mkl_enable_inst = mkl_enable_instructions(MKL_ENABLE_AVX512_E5);
#else
// AVX512_FP16, skip E4 avoiding illegal instruction error
mkl_enable_inst = mkl_enable_instructions(MKL_ENABLE_AVX512_E3);
#endif
}
cblas_gemm_f16f16f32(CblasRowMajor, ta, tb, m, n, k, alpha, (const MKL_F16 *)(A), lda,
(const MKL_F16 *)(B), ldb, beta, C, ldc);
} else {
printf("Datatype Not supported yet\n");
exit(-1);
}
}
}

@@ -806,13 +820,4 @@ class DecoderUtil {
sgemm((T *)AB, C, expABC, m, n, k, k, vStride, n, false, false);
updateOutTile(output, expABC, preSum, sum, preMax, max, m, n, stride);
}

static bool skipMskAttn(const float *attnMask, int m, int n, int stride) {
float lowest = std::numeric_limits<float>::lowest();
// left bottom is lowest
if (attnMask[(m - 1) * stride] == lowest)
return true;
else
return false;
}
};
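The decoder_util.h change routes float16_t GEMMs to oneMKL's cblas_gemm_f16f16f32 and, on the first call, restricts MKL's dispatch with mkl_enable_instructions: MKL_ENABLE_AVX512_E5 when the build targets AMX FP16, otherwise MKL_ENABLE_AVX512_E3 so machines without AMX FP16 do not hit an illegal instruction. A minimal standalone sketch of that FP16 path, with matrix shapes and leading dimensions assumed for a plain row-major call (illustrative only, not the project's wrapper):

```cpp
#include <mkl.h>

// Minimal sketch of the new FP16 GEMM path in isolation.
// C (FP32, m x n) = A (FP16, m x k) * B (FP16, k x n)
void fp16_gemm_f32out(const MKL_F16 *A, const MKL_F16 *B, float *C,
                      int m, int n, int k) {
    // Restrict MKL's dispatch once: E5 enables the AMX FP16 kernels, E3 stops
    // at AVX512 FP16 so non-AMX machines do not fault on unsupported code.
#ifdef AMX_FP16_WEIGHT_ONLY_FP16
    static int enabled = mkl_enable_instructions(MKL_ENABLE_AVX512_E5);
#else
    static int enabled = mkl_enable_instructions(MKL_ENABLE_AVX512_E3);
#endif
    (void)enabled;

    cblas_gemm_f16f16f32(CblasRowMajor, CblasNoTrans, CblasNoTrans,
                         m, n, k, /*alpha=*/1.0f, A, /*lda=*/k, B, /*ldb=*/n,
                         /*beta=*/0.0f, C, /*ldc=*/n);
}
```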