[Kernel] Upgrade xDNN to v1.5.2 and make AMX_FP16 work #468

Merged 4 commits on Jul 9, 2024
6 changes: 3 additions & 3 deletions cmake/xdnn.cmake
@@ -26,9 +26,9 @@ include(ExternalProject)

 # cmake-format: off
 ExternalProject_Add(xdnn_lib
-    URL https://github.com/intel/xFasterTransformer/releases/download/IntrinsicGemm/xdnn_v1.5.1.tar.gz
-    URL_HASH MD5=9ac7a7031b542eca2d9ec80d4c0f8be2
-    TIMEOUT 60
+    URL https://github.com/intel/xFasterTransformer/releases/download/IntrinsicGemm/xdnn_v1.5.2.tar.gz
+    URL_HASH MD5=884f2e1e2c846ff19f33c889681f8dc2
+    TIMEOUT 120
     SOURCE_DIR ${CMAKE_SOURCE_DIR}/3rdparty/xdnn
     CONFIGURE_COMMAND ""
     BUILD_COMMAND ""
25 changes: 20 additions & 5 deletions src/kernels/attention_kernels.h
@@ -19,6 +19,7 @@
 #include <omp.h>
 #include "aligned_type.h"
 #include "amx_sgemm_bf16bf16bf16.h"
+#include "amx_sgemm_f16f16f16.h"
 #include "bfloat16.h"
 #include "compile_util.h"
 #include "copy_util.h"
@@ -107,9 +108,23 @@ void small_amx_gemm_16bits_compute(int m, int n, int k, T *A, int lda, T *packedB

     if (std::is_same_v<T, bfloat16_t>) {
         xdnn_small_amx_sgemm_bf16bf16bf16_compute(
-                m, n, k, (XDNN_BF16 *)A, lda, (XDNN_BF16 *)packedB, (XDNN_BF16 *)C, ldc);
+                m, n, k, (XDNN_BF16 *)A, lda, (XDNN_BF16 *)packedB, ldb, (XDNN_BF16 *)C, ldc);
     } else {
-        //xdnn_small_amx_sgemm_f16f16f16_compute(m, n, k, (XDNN_FP16 *)A, lda, (XDNN_FP16 *)packedB, ldb, (XDNN_FP16 *)C, ldc);
+        xdnn_small_amx_sgemm_f16f16f16_compute(m, n, k, (XDNN_FP16 *)A, lda, (XDNN_FP16 *)packedB, ldb, (XDNN_FP16 *)C, ldc);
     }
 }
 
+template <typename T>
+void small_softmax(T *data, float scale, int elements) {
+    static_assert(std::is_same_v<T, float> || std::is_same_v<T, bfloat16_t> || std::is_same_v<T, float16_t>,
+            "Unsupported data type for small_softmax");
+
+    if constexpr (std::is_same_v<T, float>) {
+        small_softmax_f32(data, scale, elements);
+    } else if constexpr (std::is_same_v<T, bfloat16_t>) {
+        small_softmax_bf16((XDNN_BF16 *)data, scale, elements);
+    } else if constexpr (std::is_same_v<T, float16_t>) {
+        DecoderUtil::computeSoftmax(data, scale, elements);
+    }
+}
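For reference (not part of this diff): the new `small_softmax` dispatcher computes the same thing for every supported element type, namely a numerically stable softmax over the first `elements` scores after multiplying by `scale`. Below is a minimal scalar sketch of that semantics; the `reference_softmax` helper and the inputs are made up for illustration.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Illustrative scalar reference for what small_softmax does in place:
// softmax(scale * x[0..elements)) with max subtraction for stability.
void reference_softmax(float *data, float scale, int elements) {
    float maxVal = -INFINITY;
    for (int i = 0; i < elements; ++i) maxVal = std::max(maxVal, data[i] * scale);

    float sum = 0.0f;
    for (int i = 0; i < elements; ++i) {
        data[i] = std::exp(data[i] * scale - maxVal); // exp(scale*x - max)
        sum += data[i];
    }
    for (int i = 0; i < elements; ++i) data[i] /= sum; // normalize to sum to 1
}

int main() {
    std::vector<float> scores = {1.0f, 2.0f, 3.0f, 4.0f}; // made-up attention scores
    reference_softmax(scores.data(), 0.125f, (int)scores.size());
    for (float v : scores) std::printf("%.4f ", v); // prints a row that sums to 1
    std::printf("\n");
    return 0;
}
```

The AVX-512 `computeSoftmax` added to `decoder_util.h` further down follows the same three passes (scaled max, exp and sum, normalize), just 16 lanes at a time with a masked tail vector.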

@@ -266,7 +281,7 @@ void selfAttention_SeparateCopy(T *output, T *query, T *key, T *value, int qHead
     for (int seq = 0; seq < endSeq - startSeq; ++seq) {
         int elements = startSeq + seq + 1;
         if (alibiSlopes == nullptr) {
-            small_softmax_bf16((XDNN_BF16 *)(C + seq * ldc), scale, elements);
+            small_softmax(C + seq * ldc, scale, elements);
         } else {
             DecoderUtil::alibiSoftmax(C + seq * ldc, scale, alibiSlopes[i], elements);
         }
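Not part of the diff, but a note on why `elements` grows per row: `elements = startSeq + seq + 1` keeps the softmax causal, so row `startSeq + seq` only normalizes over the keys it is allowed to attend to. A tiny standalone check of that arithmetic, with made-up `startSeq`/`endSeq` values:

```cpp
#include <cassert>

int main() {
    // Made-up block of query rows: global positions 4, 5, 6.
    int startSeq = 4, endSeq = 7;
    int expected[] = {5, 6, 7}; // row startSeq+seq attends to keys 0..startSeq+seq
    for (int seq = 0; seq < endSeq - startSeq; ++seq) {
        int elements = startSeq + seq + 1;
        assert(elements == expected[seq]);
    }
    return 0;
}
```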
@@ -416,7 +431,7 @@ void selfAttention_FusedCopy(T *output, T *query, T *key, T *value, int qHeadNum
     for (int seq = 0; seq < endSeq - startSeq; ++seq) {
         int elements = startSeq + seq + 1;
         if (alibiSlopes == nullptr) {
-            small_softmax_bf16((XDNN_BF16 *)(C + seq * ldc), scale, elements);
+            small_softmax(C + seq * ldc, scale, elements);
         } else {
             DecoderUtil::alibiSoftmax(C + seq * ldc, scale, alibiSlopes[i], elements);
         }
@@ -765,7 +780,7 @@ void crossAttnByHead(T *output, const T *query, const T *key, const T *value, in
     for (int seq = 0; seq < queryLen; ++seq) {
         int elements = pastSeqLens[b] + seq + 1;
         if (alibiSlopes == nullptr) {
-            small_softmax_f32(S + seq * keyLen, scale, elements);
+            small_softmax(S + seq * keyLen, scale, elements);
         } else {
             DecoderUtil::alibiSoftmax(S + seq * keyLen, scale, alibiSlopes[i], elements);
         }
43 changes: 43 additions & 0 deletions src/utils/decoder_util.h
@@ -336,6 +336,49 @@ class DecoderUtil {
         }
     }
 
+    // General version
+    static void computeSoftmax(float16_t *data, float scale, int size) {
+        int vecs = (size + 15) / 16; // how many avx512 vectors
+        __mmask16 tailMask = (size % 16 == 0 ? 0xffff : (1 << (size % 16)) - 1); // mask of last vector
+
+        __m512 vsum = _mm512_set1_ps(0);
+
+        // maxVal is used to avoid exp(x) = inf
+        float maxVal = std::numeric_limits<float>::lowest();
+        __m512 vmax = _mm512_set1_ps(maxVal);
+        __m512 vfactor = _mm512_set1_ps(scale);
+
+        int i = 0;
+        for (i = 0; i < vecs; ++i) {
+            __mmask16 k = (i == vecs - 1 ? tailMask : 0xffff);
+            __m512 vx = xft::load_avx512(k, data + i * 16);
+            vmax = _mm512_mask_max_ps(vmax, k, vmax, vx * vfactor);
+        }
+
+        maxVal = _mm512_reduce_max_ps(vmax);
+        vmax = _mm512_set1_ps(maxVal);
+
+        // Compute vexp(vx - vmax) and sum it
+        for (i = 0; i < vecs; ++i) {
+            __mmask16 k = (i == vecs - 1 ? tailMask : 0xffff);
+            __m512 vx = xft::load_avx512(k, data + i * 16);
+            vx = BertUtil::vexp(vx * vfactor - vmax);
+            xft::store_avx512(data + i * 16, k, vx);
+            vsum = _mm512_mask_add_ps(vsum, k, vsum, vx);
+        }
+
+        float sum = _mm512_reduce_add_ps(vsum);
+        __m512 vrsum = _mm512_set1_ps(1.0f / sum);
+
+        // Compute exp/sum(exp) and store
+        for (i = 0; i < vecs; ++i) {
+            __mmask16 k = (i == vecs - 1 ? tailMask : 0xffff);
+            __m512 vx = xft::load_avx512(k, data + i * 16);
+            vx = vx * vrsum;
+            xft::store_avx512(data + i * 16, k, vx);
+        }
+    }
+
     // Softmax: skip the calculation when attention mask is the lowest value
     static void softmaxSkipMask(float *data, const float *attnMask, int size, float scale) {
         int vecs = (size + 15) / 16; // how many avx512 vectors
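Finally, a standalone sanity check of the `vecs`/`tailMask` arithmetic used by the new `computeSoftmax`; the concrete `size` is illustrative:

```cpp
#include <cassert>

int main() {
    int size = 35;               // e.g. 35 fp16 scores in one row (made-up)
    int vecs = (size + 15) / 16; // 3 AVX-512 vectors: two full, one partial
    unsigned tailMask = (size % 16 == 0) ? 0xffffu : (1u << (size % 16)) - 1;
    assert(vecs == 3);
    assert(tailMask == 0x7u);    // only the low 3 lanes of the last vector are live
    return 0;
}
```

With `size = 35`, the last of the three vectors has only 3 live lanes, so the masked loads and stores never touch memory past the end of the row.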