From b64b27cdedcf1ed8b344e030321622e8be3428a1 Mon Sep 17 00:00:00 2001 From: pujiang Date: Sun, 19 May 2024 22:03:58 -0400 Subject: [PATCH 1/3] Revert "fix bug of incorrect input offset in CB" This reverts commit 314e67f1ca09f0e15a4a4d53720a11f878383efe. --- src/kernels/attention_kernels.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/kernels/attention_kernels.h b/src/kernels/attention_kernels.h index 1a9d96bf..183b4b8e 100644 --- a/src/kernels/attention_kernels.h +++ b/src/kernels/attention_kernels.h @@ -700,7 +700,7 @@ void crossAttnByHead(T *output, const T *query, const T *key, const T *value, in size_t scoreSizePerThr = 0; for (int i = 0; i < batchSize; ++i) { scoreSizePerThr = std::max(scoreSizePerThr, (size_t)inputSeqLens[i] * (inputSeqLens[i] + pastSeqLens[i])); - inputOffsets[i] = (i > 0 ? inputOffsets[i - 1] + inputSeqLens[i - 1] : 0); + inputOffsets[i] = (i > 0 ? inputOffsets[i - 1] + inputSeqLens[i] : 0); } scoreSizePerThr = ALIGNED_SIZE(scoreSizePerThr, 16); From d20d222bfecea4791828d7844d79fdf2531e62eb Mon Sep 17 00:00:00 2001 From: pujiang Date: Thu, 4 Jul 2024 23:47:26 -0400 Subject: [PATCH 2/3] Make Slim Attention prepared for AMX_FP16; more balanced split in crossAttnByHead --- src/kernels/attention_kernels.h | 216 +++++++++++++++++--------------- src/layers/attention.h | 28 +++-- 2 files changed, 133 insertions(+), 111 deletions(-) diff --git a/src/kernels/attention_kernels.h b/src/kernels/attention_kernels.h index 183b4b8e..3d9cea18 100644 --- a/src/kernels/attention_kernels.h +++ b/src/kernels/attention_kernels.h @@ -99,9 +99,23 @@ void gemmSV( } } +// T is bfloat16_t or float16_t +// ldb is the K value during packing +template +void small_amx_gemm_16bits_compute(int m, int n, int k, T *A, int lda, T *packedB, int ldb, T *C, int ldc) { + static_assert(std::is_same_v || std::is_same_v, "AMX gemm only supports BF16/FP16."); + + if (std::is_same_v) { + xdnn_small_amx_sgemm_bf16bf16bf16_compute( + m, n, k, (XDNN_BF16 *)A, lda, (XDNN_BF16 *)packedB, (XDNN_BF16 *)C, ldc); + } else { + //xdnn_small_amx_sgemm_f16f16f16_compute(m, n, k, (XDNN_FP16 *)A, lda, (XDNN_FP16 *)packedB, ldb, (XDNN_FP16 *)C, ldc); + } +} + // Self attention while KV cache copy is separated -template -void selfAttention_SeparateCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_t *key, bfloat16_t *value, int qHeadNum, +template +void selfAttention_SeparateCopy(T *output, T *query, T *key, T *value, int qHeadNum, int kvHeadNum, int headSize, int oStride, int qStride, int kvStride, int batchSize, const int *tokenSizes, const float scale, const float *alibiSlopes, int threadNum, const Lambda1 &getKCache, const Lambda2 &getVCache) { @@ -126,8 +140,8 @@ void selfAttention_SeparateCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_ auto totalPackSize = fusedPack ? 
threadNum * (kPackSize + vPackSize) : (batchSize * kvHeadNum) * (kPackSize + vPackSize); - bfloat16_t *packBuf - = (bfloat16_t *)SimpleMemPool::instance().getBuffer("kv_packing", totalPackSize * sizeof(bfloat16_t)); + T *packBuf + = (T *)SimpleMemPool::instance().getBuffer("kv_packing", totalPackSize * sizeof(T)); // Copy key/value to cache and pack them // If packing is not fused into computing, then pack it here @@ -137,8 +151,8 @@ void selfAttention_SeparateCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_ for (int i = 0; i < kvHeadNum; ++i) { const int tokens = tokenSizes[b]; - bfloat16_t *packedB = packBuf + (b * kvHeadNum + i) * (kPackSize + vPackSize); - bfloat16_t *packedV = packedB + kPackSize; + T *packedB = packBuf + (b * kvHeadNum + i) * (kPackSize + vPackSize); + T *packedV = packedB + kPackSize; auto B = key + offsets[b] * kvStride + i * headSize; for (int s = 0; s < tokens; ++s) { @@ -181,8 +195,8 @@ void selfAttention_SeparateCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_ // Prepare score buffer auto maxScoreStride = (maxTokenSize + 31) / 32 * 32; - bfloat16_t *scores = (bfloat16_t *)SimpleMemPool::instance().getBuffer( - "qkscore", threadNum * mBlockSize * maxScoreStride * sizeof(bfloat16_t)); + T *scores = (T *)SimpleMemPool::instance().getBuffer( + "qkscore", threadNum * mBlockSize * maxScoreStride * sizeof(T)); auto totalBlocks = blkEndIndex[batchSize - 1]; std::pair packInfo[threadNum]; @@ -208,8 +222,8 @@ void selfAttention_SeparateCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_ int tid = omp_get_thread_num(); int kvHeadIdx = i / groupNum; int locationIdx = (fusedPack ? tid : b * kvHeadNum + kvHeadIdx); - bfloat16_t *packedB = packBuf + locationIdx * (kPackSize + vPackSize); - bfloat16_t *packedV = packedB + kPackSize; + T *packedB = packBuf + locationIdx * (kPackSize + vPackSize); + T *packedV = packedB + kPackSize; const int tokens = tokenSizes[b]; const int startSeq = mb * mBlockSize; @@ -234,8 +248,7 @@ void selfAttention_SeparateCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_ } // Causal mask (either with or without Alibi), use endSeq as N - xdnn_small_amx_sgemm_bf16bf16bf16_compute( - m, endSeq, k, (XDNN_BF16 *)A, lda, (XDNN_BF16 *)packedB, (XDNN_BF16 *)C, ldc); + small_amx_gemm_16bits_compute(m, endSeq, k, A, lda, packedB, headSize, C, ldc); #ifdef XFT_DEBUG if (b == 0 && i == 0) { @@ -257,7 +270,7 @@ void selfAttention_SeparateCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_ } else { DecoderUtil::alibiSoftmax(C + seq * ldc, scale, alibiSlopes[i], elements); } - memset(C + seq * ldc + elements, 0, (tokens - elements) * sizeof(bfloat16_t)); + memset(C + seq * ldc + elements, 0, (tokens - elements) * sizeof(T)); } #ifdef XFT_DEBUG @@ -274,7 +287,7 @@ void selfAttention_SeparateCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_ lda = ldc; ldc = oStride; A = C; - C = (bfloat16_t *)output + (offsets[b] + startSeq) * ldc + i * headSize; + C = (T *)output + (offsets[b] + startSeq) * ldc + i * headSize; if constexpr (fusedPack) { if (packInfo[tid].first != b || packInfo[tid].second != kvHeadIdx) { @@ -287,8 +300,7 @@ void selfAttention_SeparateCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_ } } - xdnn_small_amx_sgemm_bf16bf16bf16_compute( - m, n, k, (XDNN_BF16 *)A, lda, (XDNN_BF16 *)packedV, (XDNN_BF16 *)C, ldc); + small_amx_gemm_16bits_compute(m, n, k, A, lda, packedV, tokens, C, ldc); #ifdef XFT_DEBUG if (b == 0 && i == 0) { @@ -301,8 +313,8 @@ void selfAttention_SeparateCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_ }); } 
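[The patch's new small_amx_gemm_16bits_compute wrapper routes both 16-bit element types through one templated entry point. Below is a minimal, self-contained sketch of that dispatch pattern; the stub GEMMs and placeholder 16-bit structs are assumptions for illustration, standing in for the real xdnn entry points and for xFT's bfloat16_t/float16_t. Note the sketch uses `if constexpr`, which discards the non-matching branch at compile time; the patch itself uses a plain `if` plus casts, which compiles but instantiates both calls for both types.]

    #include <cstdio>
    #include <type_traits>

    // Placeholder 16-bit types standing in for bfloat16_t / float16_t.
    struct bf16_stub { unsigned short bits; };
    struct fp16_stub { unsigned short bits; };

    // Stub GEMMs standing in for
    // xdnn_small_amx_sgemm_{bf16bf16bf16,f16f16f16}_compute.
    void gemm_bf16(int m, int n, int k, const bf16_stub *A, int lda,
                   const bf16_stub *packedB, int ldb, bf16_stub *C, int ldc) {
        std::printf("bf16 GEMM: m=%d n=%d k=%d ldb=%d\n", m, n, k, ldb);
    }
    void gemm_fp16(int m, int n, int k, const fp16_stub *A, int lda,
                   const fp16_stub *packedB, int ldb, fp16_stub *C, int ldc) {
        std::printf("fp16 GEMM: m=%d n=%d k=%d ldb=%d\n", m, n, k, ldb);
    }

    // One templated entry point, dispatching on element type at compile time.
    template <typename T>
    void small_gemm_16bits(int m, int n, int k, const T *A, int lda,
                           const T *packedB, int ldb, T *C, int ldc) {
        static_assert(std::is_same_v<T, bf16_stub> || std::is_same_v<T, fp16_stub>,
                      "only 16-bit float types are supported");
        if constexpr (std::is_same_v<T, bf16_stub>) {
            gemm_bf16(m, n, k, A, lda, packedB, ldb, C, ldc);
        } else {
            gemm_fp16(m, n, k, A, lda, packedB, ldb, C, ldc);
        }
    }

    int main() {
        bf16_stub a[4]{}, b[4]{}, c[4]{};
        small_gemm_16bits(2, 2, 2, a, 2, b, 2, c, 2); // selects the bf16 path
    }

[This mirrors why the patch threads an explicit ldb (the K value used during packing) through the wrapper: the two call sites pass different packing leading dimensions, headSize for the Q*K GEMM and tokens for the score*V GEMM.]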
-template -void selfAttention_FusedCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_t *key, bfloat16_t *value, int qHeadNum, +template +void selfAttention_FusedCopy(T *output, T *query, T *key, T *value, int qHeadNum, int kvHeadNum, int headSize, int oStride, int qStride, int kvStride, int batchSize, const int *tokenSizes, const float scale, const float *alibiSlopes, int threadNum, const Lambda1 &getKCache, const Lambda2 &getVCache) { @@ -331,11 +343,11 @@ void selfAttention_FusedCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_t * // Prepare buffers (packing buffer and score buffer) const int kPackSize = xdnn_small_amx_sgemm_bf16bf16bf16_packb_size(maxTokenSize, headSize, 32, 32); const int vPackSize = xdnn_small_amx_sgemm_bf16bf16bf16_packb_size(headSize, maxTokenSize, 32, 32); - bfloat16_t *packBuf = (bfloat16_t *)SimpleMemPool::instance().getBuffer( - "kv_packing", threadNum * (kPackSize + vPackSize) * sizeof(bfloat16_t)); + T *packBuf = (T *)SimpleMemPool::instance().getBuffer( + "kv_packing", threadNum * (kPackSize + vPackSize) * sizeof(T)); int maxScoreStride = (maxTokenSize + 31) / 32 * 32; - bfloat16_t *scores = (bfloat16_t *)SimpleMemPool::instance().getBuffer( - "qkscore", threadNum * mBlockSize * maxScoreStride * sizeof(bfloat16_t)); + T *scores = (T *)SimpleMemPool::instance().getBuffer( + "qkscore", threadNum * mBlockSize * maxScoreStride * sizeof(T)); #ifdef XFT_DEBUG printf("maxTokenSize=%d, tokenSizes[0]=%d, offsets[0]=%d, kvStride=%d\n", maxTokenSize, tokenSizes[0], offsets[0], @@ -349,8 +361,8 @@ void selfAttention_FusedCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_t * const int tokens = tokenSizes[b]; const int mBlockNum = (tokens + mBlockSize - 1) / mBlockSize; - bfloat16_t *packedB = packBuf + tid * (kPackSize + vPackSize); - bfloat16_t *packedV = packedB + kPackSize; + T *packedB = packBuf + tid * (kPackSize + vPackSize); + T *packedV = packedB + kPackSize; // Copy key/value to cache and pack them auto B = key + offsets[b] * kvStride + i * headSize; @@ -386,8 +398,8 @@ void selfAttention_FusedCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_t * auto A = query + (offsets[b] + startSeq) * qStride + i * headSize; auto C = scores + tid * mBlockSize * maxScoreStride; - xdnn_small_amx_sgemm_bf16bf16bf16_compute( - m, n, k, (XDNN_BF16 *)A, lda, (XDNN_BF16 *)packedB, (XDNN_BF16 *)C, ldc); + small_amx_gemm_16bits_compute( + m, n, k, A, lda, packedB, headSize, C, ldc); #ifdef XFT_DEBUG if (b == 0 && i == 0) { @@ -408,7 +420,7 @@ void selfAttention_FusedCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_t * } else { DecoderUtil::alibiSoftmax(C + seq * ldc, scale, alibiSlopes[i], elements); } - memset(C + seq * ldc + elements, 0, (tokens - elements) * sizeof(bfloat16_t)); + memset(C + seq * ldc + elements, 0, (tokens - elements) * sizeof(T)); } #ifdef XFT_DEBUG @@ -425,10 +437,9 @@ void selfAttention_FusedCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_t * lda = ldc; ldc = oStride; A = C; - C = (bfloat16_t *)output + (offsets[b] + startSeq) * ldc + i * headSize; + C = (T *)output + (offsets[b] + startSeq) * ldc + i * headSize; - xdnn_small_amx_sgemm_bf16bf16bf16_compute( - m, n, k, (XDNN_BF16 *)A, lda, (XDNN_BF16 *)packedV, (XDNN_BF16 *)C, ldc); + small_amx_gemm_16bits_compute(m, n, k, A, lda, packedV, tokens, C, ldc); #ifdef XFT_DEBUG if (b == 0 && i == 0) { @@ -443,8 +454,8 @@ void selfAttention_FusedCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_t * } // end for b } -template -void selfAttention(bfloat16_t *output, bfloat16_t *query, bfloat16_t 
*key, bfloat16_t *value, int qHeadNum, +template +void selfAttention(T *output, T *query, T *key, T *value, int qHeadNum, int kvHeadNum, int headSize, int oStride, int qStride, int kvStride, int batchSize, const int *tokenSizes, const float scale, const float *alibiSlopes, int threadNum, const Lambda1 &getKCache, const Lambda2 &getVCache) { @@ -707,84 +718,87 @@ void crossAttnByHead(T *output, const T *query, const T *key, const T *value, in size_t scoreSize = scoreSizePerThr * threadNum; float *scoreBuf = (float *)SimpleMemPool::instance().getBuffer("scoreBuf", sizeof(float) * scoreSize); -#pragma omp parallel for collapse(2) - for (int b = 0; b < batchSize; ++b) { - for (int i = 0; i < responsibleHeads; ++i) { - // Copy current key to cached keys (if needed) - int kvHdx = i / groupNum; - auto keyMatInfo = getKHead(b, kvHdx); - auto valueMat = getVHead(b, kvHdx); - bool bCopyCache = (i % groupNum == 0); - - // Q * K - auto Q = query + inputOffsets[b] * qStride + i * headSize; - auto S = scoreBuf + omp_get_thread_num() * scoreSizePerThr; - - const int queryLen = inputSeqLens[b]; - const int keyLen = pastSeqLens[b] + inputSeqLens[b]; - - if (bCopyCache) { - int m = queryLen; - int n = keyLen; - int lda = qStride; - int ldc = keyLen; +#pragma omp parallel for collapse(3) + for (int kvh = 0; kvh < kvHeadNum; ++kvh) { + for (int b = 0; b < batchSize; ++b) { + for (int groupOff = 0; groupOff < groupNum; ++groupOff) { + int i = kvh * groupNum + groupOff; - // Copy to Key cache and compute Query * Key - auto src = key + inputOffsets[b] * kvStride + kvHdx * headSize; - storeKVCache(keyMatInfo, src, pastSeqLens[b], inputSeqLens[b], headSize, kvStride); + // Copy current key to cached keys (if needed) + int kvHdx = kvh; + auto keyMatInfo = getKHead(b, kvHdx); + auto valueMat = getVHead(b, kvHdx); + bool bCopyCache = (i % groupNum == 0); - gemmQK(Q, keyMatInfo, S, m, n, headSize, lda, ldc); - } else { - // Note: when KV cache is not copied by me, then 2 times gemm to avoid synchronization - int m = queryLen; - int n = pastSeqLens[b]; - int lda = qStride; - int ldc = keyLen; - gemmQK(Q, keyMatInfo, S, m, n, headSize, lda, ldc); + // Q * K + auto Q = query + inputOffsets[b] * qStride + i * headSize; + auto S = scoreBuf + omp_get_thread_num() * scoreSizePerThr; - int ldb = kvStride; - auto B = key + inputOffsets[b] * kvStride + kvHdx * headSize; - small_gemm_transb(Q, B, S + n, m, inputSeqLens[b], headSize, lda, ldb, ldc); - } + const int queryLen = inputSeqLens[b]; + const int keyLen = pastSeqLens[b] + inputSeqLens[b]; + + if (bCopyCache) { + int m = queryLen; + int n = keyLen; + int lda = qStride; + int ldc = keyLen; - // Softmax(Q * K) - for (int seq = 0; seq < queryLen; ++seq) { - int elements = pastSeqLens[b] + seq + 1; - if (alibiSlopes == nullptr) { - small_softmax_f32(S + seq * keyLen, scale, elements); + // Copy to Key cache and compute Query * Key + auto src = key + inputOffsets[b] * kvStride + kvHdx * headSize; + storeKVCache(keyMatInfo, src, pastSeqLens[b], inputSeqLens[b], headSize, kvStride); + + gemmQK(Q, keyMatInfo, S, m, n, headSize, lda, ldc); } else { - DecoderUtil::alibiSoftmax(S + seq * keyLen, scale, alibiSlopes[i], elements); + // Note: when KV cache is not copied by me, then 2 times gemm to avoid synchronization + int m = queryLen; + int n = pastSeqLens[b]; + int lda = qStride; + int ldc = keyLen; + gemmQK(Q, keyMatInfo, S, m, n, headSize, lda, ldc); + + int ldb = kvStride; + auto B = key + inputOffsets[b] * kvStride + kvHdx * headSize; + small_gemm_transb(Q, B, S + n, m, 
inputSeqLens[b], headSize, lda, ldb, ldc); } - if (keyLen > elements) { memset(S + seq * keyLen + elements, 0, (keyLen - elements) * sizeof(float)); } - } - // Softmax * V - if (bCopyCache) { - // Copy current value to cached values - auto src = value + inputOffsets[b] * kvStride + kvHdx * headSize; - storeKVCache(valueMat, src, pastSeqLens[b], inputSeqLens[b], headSize, kvStride); - - int m = queryLen; - auto result = output + inputOffsets[b] * oStride + i * headSize; - gemmSV(S, valueMat, result, m, headSize, keyLen, keyLen, oStride); - } else { - // Note: when KV cache is not copied by me, then 2 times gemm to avoid synchronization - int m = queryLen; - float f32Out[m * headSize]; // accumulate in FP32 - gemmSV(S, valueMat, f32Out, m, headSize, pastSeqLens[b], keyLen, headSize); - - auto B = value + inputOffsets[b] * kvStride + kvHdx * headSize; - small_gemm(S + pastSeqLens[b], B, f32Out, m, headSize, m, keyLen, kvStride, headSize, true); - - // f32Out -> output - auto result = output + inputOffsets[b] * oStride + i * headSize; - for (int t = 0; t < m; ++t) { - xft::copy(result + t * oStride, f32Out + t * headSize, headSize); + // Softmax(Q * K) + for (int seq = 0; seq < queryLen; ++seq) { + int elements = pastSeqLens[b] + seq + 1; + if (alibiSlopes == nullptr) { + small_softmax_f32(S + seq * keyLen, scale, elements); + } else { + DecoderUtil::alibiSoftmax(S + seq * keyLen, scale, alibiSlopes[i], elements); + } + if (keyLen > elements) { memset(S + seq * keyLen + elements, 0, (keyLen - elements) * sizeof(float)); } } - } - } // end for i - } // end for b + // Softmax * V + if (bCopyCache) { + // Copy current value to cached values + auto src = value + inputOffsets[b] * kvStride + kvHdx * headSize; + storeKVCache(valueMat, src, pastSeqLens[b], inputSeqLens[b], headSize, kvStride); + + int m = queryLen; + auto result = output + inputOffsets[b] * oStride + i * headSize; + gemmSV(S, valueMat, result, m, headSize, keyLen, keyLen, oStride); + } else { + // Note: when KV cache is not copied by me, then 2 times gemm to avoid synchronization + int m = queryLen; + float f32Out[m * headSize]; // accumulate in FP32 + gemmSV(S, valueMat, f32Out, m, headSize, pastSeqLens[b], keyLen, headSize); + + auto B = value + inputOffsets[b] * kvStride + kvHdx * headSize; + small_gemm(S + pastSeqLens[b], B, f32Out, m, headSize, m, keyLen, kvStride, headSize, true); + + // f32Out -> output + auto result = output + inputOffsets[b] * oStride + i * headSize; + for (int t = 0; t < m; ++t) { + xft::copy(result + t * oStride, f32Out + t * headSize, headSize); + } + } + } // end for groupOff + } // end for b + } // end for kvh } // scaled dot-product attention: bmm1 + softmax + bmm2 diff --git a/src/layers/attention.h b/src/layers/attention.h index 924270dc..1ece2a46 100644 --- a/src/layers/attention.h +++ b/src/layers/attention.h @@ -371,8 +371,12 @@ class Attention { if (pastSeqLen == 0) { if (ctx->inputSeqLen > getFlashThresh()) { flashAttention(ctx, query, key, value, attnSplit, presentKey, presentValue, attnMask, pastSeqLen); - } else if constexpr (std::is_same_v && std::is_same_v) { - selfAttentionBF16(ctx, query, key, value, attnSplit, presentKey, presentValue); + } else if constexpr ((std::is_same_v && std::is_same_v) +#if defined(AMX_FP16_WEIGHT_ONLY_FP16) + || (std::is_same_v && std::is_same_v) +#endif + ) { + selfAttention16bits(ctx, query, key, value, attnSplit, presentKey, presentValue); } else { fusedAttention(ctx, query, key, value, attnSplit, presentKey, presentValue, attnMask, pastSeqLen); } @@ -576,8 
+580,12 @@ class Attention { if (seqs[0]->getStep() == 0) { // First token generation if (totInSeqLen > getFlashThresh() * seqs.size()) { flashAttention(ctx, query, key, value, attnSplit, keyCaches, valueCaches, seqs); - } else if constexpr (std::is_same_v && std::is_same_v) { - selfAttentionBF16(ctx, query, key, value, attnSplit, keyCaches, valueCaches, seqs); + } else if constexpr ((std::is_same_v && std::is_same_v) +#if defined(AMX_FP16_WEIGHT_ONLY_FP16) + || (std::is_same_v && std::is_same_v) +#endif + ) { + selfAttention16bits(ctx, query, key, value, attnSplit, keyCaches, valueCaches, seqs); } else { fusedAttention(ctx, query, key, value, attnSplit, keyCaches, valueCaches, seqs); } @@ -654,9 +662,9 @@ class Attention { } protected: - template - void selfAttentionBF16(DecoderContext *ctx, xft::Matrix &query, xft::Matrix &key, - xft::Matrix &value, xft::Matrix &result, KVCacheTensor &presentKey, + template + void selfAttention16bits(DecoderContext *ctx, xft::Matrix &query, xft::Matrix &key, + xft::Matrix &value, xft::Matrix &result, KVCacheTensor &presentKey, KVCacheTensor &presentValue) { int responsibleQHeads = this->endQHead - this->startQHead; int responsibleKVHeads = this->endKVHead - this->startKVHead; @@ -674,9 +682,9 @@ class Attention { [&](int b, int headIdx, int seqIdx) { return presentValue.getSequence(seqIdx, b, headIdx); }); } - template - void selfAttentionBF16(DecoderContext *ctx, xft::Matrix &query, xft::Matrix &key, - xft::Matrix &value, xft::Matrix &result, + template + void selfAttention16bits(DecoderContext *ctx, xft::Matrix &query, xft::Matrix &key, + xft::Matrix &value, xft::Matrix &result, std::vector *> &keyCaches, std::vector *> &valueCaches, std::vector &seqs) { int responsibleQHeads = this->endQHead - this->startQHead; From 5bdb3a68a54540701e9c3edd7e7653471001a009 Mon Sep 17 00:00:00 2001 From: pujiang Date: Mon, 8 Jul 2024 10:16:43 -0400 Subject: [PATCH 3/3] upgrade xdnn and make AMX_FP16 work --- cmake/xdnn.cmake | 6 ++--- src/kernels/attention_kernels.h | 25 +++++++++++++++---- src/utils/decoder_util.h | 43 +++++++++++++++++++++++++++++++++ 3 files changed, 66 insertions(+), 8 deletions(-) diff --git a/cmake/xdnn.cmake b/cmake/xdnn.cmake index 79cb2121..7c0e051a 100644 --- a/cmake/xdnn.cmake +++ b/cmake/xdnn.cmake @@ -26,9 +26,9 @@ include(ExternalProject) # cmake-format: off ExternalProject_Add(xdnn_lib - URL https://github.com/intel/xFasterTransformer/releases/download/IntrinsicGemm/xdnn_v1.5.1.tar.gz - URL_HASH MD5=9ac7a7031b542eca2d9ec80d4c0f8be2 - TIMEOUT 60 + URL https://github.com/intel/xFasterTransformer/releases/download/IntrinsicGemm/xdnn_v1.5.2.tar.gz + URL_HASH MD5=884f2e1e2c846ff19f33c889681f8dc2 + TIMEOUT 120 SOURCE_DIR ${CMAKE_SOURCE_DIR}/3rdparty/xdnn CONFIGURE_COMMAND "" BUILD_COMMAND "" diff --git a/src/kernels/attention_kernels.h b/src/kernels/attention_kernels.h index 3d9cea18..58322ebf 100644 --- a/src/kernels/attention_kernels.h +++ b/src/kernels/attention_kernels.h @@ -19,6 +19,7 @@ #include #include "aligned_type.h" #include "amx_sgemm_bf16bf16bf16.h" +#include "amx_sgemm_f16f16f16.h" #include "bfloat16.h" #include "compile_util.h" #include "copy_util.h" @@ -107,9 +108,23 @@ void small_amx_gemm_16bits_compute(int m, int n, int k, T *A, int lda, T *packed if (std::is_same_v) { xdnn_small_amx_sgemm_bf16bf16bf16_compute( - m, n, k, (XDNN_BF16 *)A, lda, (XDNN_BF16 *)packedB, (XDNN_BF16 *)C, ldc); + m, n, k, (XDNN_BF16 *)A, lda, (XDNN_BF16 *)packedB, ldb, (XDNN_BF16 *)C, ldc); } else { - //xdnn_small_amx_sgemm_f16f16f16_compute(m, n, k, 
(XDNN_FP16 *)A, lda, (XDNN_FP16 *)packedB, ldb, (XDNN_FP16 *)C, ldc); + xdnn_small_amx_sgemm_f16f16f16_compute(m, n, k, (XDNN_FP16 *)A, lda, (XDNN_FP16 *)packedB, ldb, (XDNN_FP16 *)C, ldc); + } +} + +template +void small_softmax(T *data, float scale, int elements) { + static_assert(std::is_same_v || std::is_same_v || std::is_same_v, + "Unsupported data type for small_softmax"); + + if constexpr (std::is_same_v) { + small_softmax_f32(data, scale, elements); + } else if constexpr (std::is_same_v) { + small_softmax_bf16((XDNN_BF16 *)data, scale, elements); + } else if constexpr (std::is_same_v) { + DecoderUtil::computeSoftmax(data, scale, elements); } } @@ -266,7 +281,7 @@ void selfAttention_SeparateCopy(T *output, T *query, T *key, T *value, int qHead for (int seq = 0; seq < endSeq - startSeq; ++seq) { int elements = startSeq + seq + 1; if (alibiSlopes == nullptr) { - small_softmax_bf16((XDNN_BF16 *)(C + seq * ldc), scale, elements); + small_softmax(C + seq * ldc, scale, elements); } else { DecoderUtil::alibiSoftmax(C + seq * ldc, scale, alibiSlopes[i], elements); } @@ -416,7 +431,7 @@ void selfAttention_FusedCopy(T *output, T *query, T *key, T *value, int qHeadNum for (int seq = 0; seq < endSeq - startSeq; ++seq) { int elements = startSeq + seq + 1; if (alibiSlopes == nullptr) { - small_softmax_bf16((XDNN_BF16 *)(C + seq * ldc), scale, elements); + small_softmax(C + seq * ldc, scale, elements); } else { DecoderUtil::alibiSoftmax(C + seq * ldc, scale, alibiSlopes[i], elements); } @@ -765,7 +780,7 @@ void crossAttnByHead(T *output, const T *query, const T *key, const T *value, in for (int seq = 0; seq < queryLen; ++seq) { int elements = pastSeqLens[b] + seq + 1; if (alibiSlopes == nullptr) { - small_softmax_f32(S + seq * keyLen, scale, elements); + small_softmax(S + seq * keyLen, scale, elements); } else { DecoderUtil::alibiSoftmax(S + seq * keyLen, scale, alibiSlopes[i], elements); } diff --git a/src/utils/decoder_util.h b/src/utils/decoder_util.h index 2399a5e8..8a43420d 100644 --- a/src/utils/decoder_util.h +++ b/src/utils/decoder_util.h @@ -336,6 +336,49 @@ class DecoderUtil { } } + // General version + static void computeSoftmax(float16_t *data, float scale, int size) { + int vecs = (size + 15) / 16; // how many avx512 vectors + __mmask16 tailMask = (size % 16 == 0 ? 0xffff : (1 << (size % 16)) - 1); // mask of last vector + + __m512 vsum = _mm512_set1_ps(0); + + // maxVal is used to avoid exp(x) = inf + float maxVal = std::numeric_limits::lowest(); + __m512 vmax = _mm512_set1_ps(maxVal); + __m512 vfactor = _mm512_set1_ps(scale); + + int i = 0; + for (i = 0; i < vecs; ++i) { + __mmask16 k = (i == vecs - 1 ? tailMask : 0xffff); + __m512 vx = xft::load_avx512(k, data + i * 16); + vmax = _mm512_mask_max_ps(vmax, k, vmax, vx * vfactor); + } + + maxVal = _mm512_reduce_max_ps(vmax); + vmax = _mm512_set1_ps(maxVal); + + // Compute vexp(vx - vmax) and sum it + for (i = 0; i < vecs; ++i) { + __mmask16 k = (i == vecs - 1 ? tailMask : 0xffff); + __m512 vx = xft::load_avx512(k, data + i * 16); + vx = BertUtil::vexp(vx * vfactor - vmax); + xft::store_avx512(data + i * 16, k, vx); + vsum = _mm512_mask_add_ps(vsum, k, vsum, vx); + } + + float sum = _mm512_reduce_add_ps(vsum); + __m512 vrsum = _mm512_set1_ps(1.0f / sum); + + // Compute exp/sum(exp) and store + for (i = 0; i < vecs; ++i) { + __mmask16 k = (i == vecs - 1 ? 
tailMask : 0xffff); + __m512 vx = xft::load_avx512(k, data + i * 16); + vx = vx * vrsum; + xft::store_avx512(data + i * 16, k, vx); + } + } + // Softmax: skip the calculation when attention mask is the lowest value static void softmaxSkipMask(float *data, const float *attnMask, int size, float scale) { int vecs = (size + 15) / 16; // how many avx512 vectors
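[The computeSoftmax overload added above vectorizes a standard max-subtracted softmax over float16_t data with masked AVX-512 loads. For reference, here is a scalar float equivalent of the same three passes (max, exp+sum, normalize); the function name and the use of plain float are simplifications for illustration, not part of the patch. Subtracting the running maximum before exp() keeps the exponentials from overflowing, exactly as the vectorized code does with _mm512_reduce_max_ps.]

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <limits>

    // In-place scaled softmax over `size` elements:
    // y[i] = exp(scale*x[i] - max_j(scale*x[j])) / sum_j exp(scale*x[j] - max).
    void softmax_scalar(float *data, float scale, int size) {
        // Pass 1: find the maximum of the scaled inputs.
        float maxVal = -std::numeric_limits<float>::infinity();
        for (int i = 0; i < size; ++i) maxVal = std::max(maxVal, data[i] * scale);

        // Pass 2: exponentiate with the max subtracted and accumulate the sum.
        float sum = 0.f;
        for (int i = 0; i < size; ++i) {
            data[i] = std::exp(data[i] * scale - maxVal);
            sum += data[i];
        }

        // Pass 3: normalize by the sum (multiply by reciprocal, as the AVX-512
        // version does with vrsum).
        float rsum = 1.f / sum;
        for (int i = 0; i < size; ++i) data[i] *= rsum;
    }

    int main() {
        float v[4] = {1.f, 2.f, 3.f, 4.f};
        softmax_scalar(v, 1.f, 4);
        for (float x : v) std::printf("%.4f ", x); // ~0.0321 0.0871 0.2369 0.6439
    }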