[Kernel] Upgrade xDNN to v1.5.2 and make AMX_FP16 work (#468)
* Revert "fix bug of incorrect input offset in CB"

This reverts commit 314e67f1ca09f0e15a4a4d53720a11f878383efe.

* Make Slim Attention prepared for AMX_FP16; more balanced split in crossAttnByHead

* upgrade xdnn and make AMX_FP16 work
pujiang2018 committed Jul 9, 2024
1 parent 69b91cf commit fcea26f
Showing 3 changed files with 66 additions and 8 deletions.
6 changes: 3 additions & 3 deletions cmake/xdnn.cmake
@@ -26,9 +26,9 @@ include(ExternalProject)

 # cmake-format: off
 ExternalProject_Add(xdnn_lib
-    URL https://github.com/intel/xFasterTransformer/releases/download/IntrinsicGemm/xdnn_v1.5.1.tar.gz
-    URL_HASH MD5=9ac7a7031b542eca2d9ec80d4c0f8be2
-    TIMEOUT 60
+    URL https://github.com/intel/xFasterTransformer/releases/download/IntrinsicGemm/xdnn_v1.5.2.tar.gz
+    URL_HASH MD5=884f2e1e2c846ff19f33c889681f8dc2
+    TIMEOUT 120
     SOURCE_DIR ${CMAKE_SOURCE_DIR}/3rdparty/xdnn
     CONFIGURE_COMMAND ""
     BUILD_COMMAND ""
25 changes: 20 additions & 5 deletions src/kernels/attention_kernels.h
@@ -19,6 +19,7 @@
 #include <omp.h>
 #include "aligned_type.h"
 #include "amx_sgemm_bf16bf16bf16.h"
+#include "amx_sgemm_f16f16f16.h"
 #include "bfloat16.h"
 #include "compile_util.h"
 #include "copy_util.h"
@@ -107,9 +108,23 @@ void small_amx_gemm_16bits_compute(int m, int n, int k, T *A, int lda, T *packedB

     if (std::is_same_v<T, bfloat16_t>) {
         xdnn_small_amx_sgemm_bf16bf16bf16_compute(
-                m, n, k, (XDNN_BF16 *)A, lda, (XDNN_BF16 *)packedB, (XDNN_BF16 *)C, ldc);
+                m, n, k, (XDNN_BF16 *)A, lda, (XDNN_BF16 *)packedB, ldb, (XDNN_BF16 *)C, ldc);
     } else {
-        //xdnn_small_amx_sgemm_f16f16f16_compute(m, n, k, (XDNN_FP16 *)A, lda, (XDNN_FP16 *)packedB, ldb, (XDNN_FP16 *)C, ldc);
+        xdnn_small_amx_sgemm_f16f16f16_compute(m, n, k, (XDNN_FP16 *)A, lda, (XDNN_FP16 *)packedB, ldb, (XDNN_FP16 *)C, ldc);
     }
 }
 
+template <typename T>
+void small_softmax(T *data, float scale, int elements) {
+    static_assert(std::is_same_v<T, float> || std::is_same_v<T, bfloat16_t> || std::is_same_v<T, float16_t>,
+            "Unsupported data type for small_softmax");
+
+    if constexpr (std::is_same_v<T, float>) {
+        small_softmax_f32(data, scale, elements);
+    } else if constexpr (std::is_same_v<T, bfloat16_t>) {
+        small_softmax_bf16((XDNN_BF16 *)data, scale, elements);
+    } else if constexpr (std::is_same_v<T, float16_t>) {
+        DecoderUtil::computeSoftmax(data, scale, elements);
+    }
+}
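A note for readers skimming the diff: the small_softmax wrapper added above uses if constexpr so the per-type kernel is resolved at compile time, which is what lets the call sites further down pass float, bfloat16_t, or float16_t buffers through one name. The following is a minimal, self-contained C++17 sketch of that dispatch pattern, not xFasterTransformer code; toy_small_softmax, kernel_f32, and kernel_f64 are hypothetical stand-ins for small_softmax and its per-type kernels.

#include <cstdio>
#include <type_traits>

// Hypothetical stand-ins for the real per-type kernels
// (small_softmax_f32, small_softmax_bf16, DecoderUtil::computeSoftmax).
static void kernel_f32(float *, float, int) { std::puts("float path"); }
static void kernel_f64(double *, float, int) { std::puts("double path"); }

// Same shape as the small_softmax wrapper in this hunk: one type-agnostic
// entry point, with the kernel chosen at compile time by if constexpr.
template <typename T>
void toy_small_softmax(T *data, float scale, int elements) {
    static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>,
            "Unsupported data type in this sketch");
    if constexpr (std::is_same_v<T, float>) {
        kernel_f32(data, scale, elements);
    } else {
        kernel_f64(data, scale, elements);
    }
}

int main() {
    float a[4] = {};
    double b[4] = {};
    toy_small_softmax(a, 1.0f, 4);  // only the float branch is instantiated
    toy_small_softmax(b, 1.0f, 4);  // only the double branch is instantiated
    return 0;
}

Because the untaken if constexpr branch is discarded at instantiation, the float16_t case can route to the new DecoderUtil::computeSoftmax without the bf16 or f32 calls ever being compiled for that type.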

@@ -266,7 +281,7 @@ void selfAttention_SeparateCopy(T *output, T *query, T *key, T *value, int qHead
 for (int seq = 0; seq < endSeq - startSeq; ++seq) {
     int elements = startSeq + seq + 1;
     if (alibiSlopes == nullptr) {
-        small_softmax_bf16((XDNN_BF16 *)(C + seq * ldc), scale, elements);
+        small_softmax(C + seq * ldc, scale, elements);
     } else {
         DecoderUtil::alibiSoftmax(C + seq * ldc, scale, alibiSlopes[i], elements);
     }
@@ -416,7 +431,7 @@ void selfAttention_FusedCopy(T *output, T *query, T *key, T *value, int qHeadNum
 for (int seq = 0; seq < endSeq - startSeq; ++seq) {
     int elements = startSeq + seq + 1;
     if (alibiSlopes == nullptr) {
-        small_softmax_bf16((XDNN_BF16 *)(C + seq * ldc), scale, elements);
+        small_softmax(C + seq * ldc, scale, elements);
     } else {
         DecoderUtil::alibiSoftmax(C + seq * ldc, scale, alibiSlopes[i], elements);
     }
@@ -765,7 +780,7 @@ void crossAttnByHead(T *output, const T *query, const T *key, const T *value, in
 for (int seq = 0; seq < queryLen; ++seq) {
     int elements = pastSeqLens[b] + seq + 1;
     if (alibiSlopes == nullptr) {
-        small_softmax_f32(S + seq * keyLen, scale, elements);
+        small_softmax(S + seq * keyLen, scale, elements);
     } else {
         DecoderUtil::alibiSoftmax(S + seq * keyLen, scale, alibiSlopes[i], elements);
     }
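One API detail worth noting in the hunks above is the extra leading-dimension argument ldb now threaded through to the xdnn compute routine for the packed B matrix. As a generic refresher on what a leading dimension means, here is a self-contained naive GEMM sketch; it assumes plain row-major storage rather than xdnn's packed-B layout, and gemm_ref is a hypothetical name used only for this illustration.

#include <cstdio>

// Naive row-major GEMM: C[m x n] = A[m x k] * B[k x n], where each operand may
// be a sub-block of a larger buffer addressed through its leading dimension.
// Illustration only; the real xdnn kernels use AMX tiles and a packed B format.
void gemm_ref(int m, int n, int k, const float *A, int lda, const float *B, int ldb,
        float *C, int ldc) {
    for (int i = 0; i < m; ++i) {
        for (int j = 0; j < n; ++j) {
            float acc = 0.0f;
            for (int p = 0; p < k; ++p) {
                acc += A[i * lda + p] * B[p * ldb + j];
            }
            C[i * ldc + j] = acc;
        }
    }
}

int main() {
    // 2x2 product taken from the top-left corners of wider 4-column buffers.
    float A[2 * 4] = {1, 2, 0, 0, 3, 4, 0, 0};  // lda = 4
    float B[2 * 4] = {5, 6, 0, 0, 7, 8, 0, 0};  // ldb = 4
    float C[2 * 4] = {};                        // ldc = 4
    gemm_ref(2, 2, 2, A, 4, B, 4, C, 4);
    std::printf("%g %g\n%g %g\n", C[0], C[1], C[4], C[5]);  // 19 22 / 43 50
    return 0;
}

Passing lda, ldb, and ldc separately from m, n, and k is what lets a kernel work on a tile of a larger attention buffer without copying it into a contiguous block first.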
43 changes: 43 additions & 0 deletions src/utils/decoder_util.h
@@ -336,6 +336,49 @@ class DecoderUtil {
         }
     }
 
+    // General version
+    static void computeSoftmax(float16_t *data, float scale, int size) {
+        int vecs = (size + 15) / 16; // how many avx512 vectors
+        __mmask16 tailMask = (size % 16 == 0 ? 0xffff : (1 << (size % 16)) - 1); // mask of last vector
+
+        __m512 vsum = _mm512_set1_ps(0);
+
+        // maxVal is used to avoid exp(x) = inf
+        float maxVal = std::numeric_limits<float>::lowest();
+        __m512 vmax = _mm512_set1_ps(maxVal);
+        __m512 vfactor = _mm512_set1_ps(scale);
+
+        int i = 0;
+        for (i = 0; i < vecs; ++i) {
+            __mmask16 k = (i == vecs - 1 ? tailMask : 0xffff);
+            __m512 vx = xft::load_avx512(k, data + i * 16);
+            vmax = _mm512_mask_max_ps(vmax, k, vmax, vx * vfactor);
+        }
+
+        maxVal = _mm512_reduce_max_ps(vmax);
+        vmax = _mm512_set1_ps(maxVal);
+
+        // Compute vexp(vx - vmax) and sum it
+        for (i = 0; i < vecs; ++i) {
+            __mmask16 k = (i == vecs - 1 ? tailMask : 0xffff);
+            __m512 vx = xft::load_avx512(k, data + i * 16);
+            vx = BertUtil::vexp(vx * vfactor - vmax);
+            xft::store_avx512(data + i * 16, k, vx);
+            vsum = _mm512_mask_add_ps(vsum, k, vsum, vx);
+        }
+
+        float sum = _mm512_reduce_add_ps(vsum);
+        __m512 vrsum = _mm512_set1_ps(1.0f / sum);
+
+        // Compute exp/sum(exp) and store
+        for (i = 0; i < vecs; ++i) {
+            __mmask16 k = (i == vecs - 1 ? tailMask : 0xffff);
+            __m512 vx = xft::load_avx512(k, data + i * 16);
+            vx = vx * vrsum;
+            xft::store_avx512(data + i * 16, k, vx);
+        }
+    }
+
     // Softmax: skip the calculation when attention mask is the lowest value
     static void softmaxSkipMask(float *data, const float *attnMask, int size, float scale) {
         int vecs = (size + 15) / 16; // how many avx512 vectors
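The DecoderUtil::computeSoftmax added above is a numerically stable scaled softmax done in three masked AVX-512 passes: find max(scale * x), store exp(scale * x - max) while accumulating its sum, then multiply by 1/sum. For spot-checking the vectorized float16 path on small inputs, a scalar reference of the same math might look like the sketch below; it is an illustration only, using float instead of float16_t so it compiles standalone, and softmax_reference is a hypothetical name.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <limits>

// Scalar reference for the scaled, numerically stable softmax that the
// AVX-512 computeSoftmax implements with masked loads and stores.
void softmax_reference(float *data, float scale, int size) {
    // Pass 1: maximum of scale * x, used to keep exp() from overflowing
    float maxVal = std::numeric_limits<float>::lowest();
    for (int i = 0; i < size; ++i) maxVal = std::max(maxVal, data[i] * scale);

    // Pass 2: exponentiate the shifted values and accumulate their sum
    float sum = 0.0f;
    for (int i = 0; i < size; ++i) {
        data[i] = std::exp(data[i] * scale - maxVal);
        sum += data[i];
    }

    // Pass 3: normalize with a reciprocal multiply, mirroring vrsum above
    float rsum = 1.0f / sum;
    for (int i = 0; i < size; ++i) data[i] *= rsum;
}

int main() {
    float scores[5] = {0.1f, 0.7f, 0.2f, 1.5f, -0.3f};
    softmax_reference(scores, 1.0f / std::sqrt(128.0f), 5);  // e.g. scale = 1/sqrt(headDim)
    for (float v : scores) std::printf("%.4f ", v);
    std::printf("\n");
    return 0;
}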
