From 5446a5716d106d2c3d55c8bee5e30c5e25fdb54f Mon Sep 17 00:00:00 2001 From: Chen Meng Date: Thu, 13 Jun 2024 22:12:09 -0400 Subject: [PATCH] [Layers] Add qwenRope support for CB Qwen1.0 --- src/kernels/rotary_embedding_kernels.cpp | 86 ++++++++++++++++++++++-- src/kernels/rotary_embedding_kernels.h | 17 ++++- src/layers/attention.h | 10 +-- src/layers/rotary_embedding_qwen.cpp | 15 +++-- 4 files changed, 106 insertions(+), 22 deletions(-) diff --git a/src/kernels/rotary_embedding_kernels.cpp b/src/kernels/rotary_embedding_kernels.cpp index 9812bd35..c639c2f1 100644 --- a/src/kernels/rotary_embedding_kernels.cpp +++ b/src/kernels/rotary_embedding_kernels.cpp @@ -293,12 +293,11 @@ void chatglm2ApplyRotaryPosEmbeding(float16_t *query, float16_t *key, int qStrid query, key, qStride, kStride, emb_cos, emb_sin, inv_freq_size, qkShape, positionIds); } - // For ChatGLM2/3 continous batching template -static inline void chatglm2ApplyRotaryPosEmbed(T *query, T *key, float *emb_cos, float *emb_sin, int qStride, int kStride, - int inv_freq_size, int totSeqLen, int qHeads, int kHeads, const int *positionIds) { +static inline void chatglm2ApplyRotaryPosEmbed(T *query, T *key, float *emb_cos, float *emb_sin, int qStride, + int kStride, int inv_freq_size, int totSeqLen, int qHeads, int kHeads, const int *positionIds) { int dim = inv_freq_size * 2; const int head_num = qHeads + kHeads; const int half = inv_freq_size; @@ -341,14 +340,14 @@ static inline void chatglm2ApplyRotaryPosEmbed(T *query, T *key, float *emb_cos, } } -void chatglm2ApplyRotaryPosEmbed(float *query, float *key, float *emb_cos, float *emb_sin, int qStride, int kStride, int dim, - int totSeqLen, int qHeads, int kHeads, const int *positionIds) { +void chatglm2ApplyRotaryPosEmbed(float *query, float *key, float *emb_cos, float *emb_sin, int qStride, int kStride, + int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds) { chatglm2ApplyRotaryPosEmbed( query, key, emb_cos, emb_sin, qStride, kStride, 
dim, totSeqLen, qHeads, kHeads, positionIds); } void chatglm2ApplyRotaryPosEmbed(bfloat16_t *query, bfloat16_t *key, float *emb_cos, float *emb_sin, int qStride, - int kStride, int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds) { + int kStride, int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds) { chatglm2ApplyRotaryPosEmbed( query, key, emb_cos, emb_sin, qStride, kStride, dim, totSeqLen, qHeads, kHeads, positionIds); } @@ -450,6 +449,81 @@ void qwenApplyRotaryPosEmbeding(float16_t *query, float16_t *key, int qStride, i maxSupportedSeqLength, qkShape, positionIds); } +template +static inline void qwenApplyRotaryPosEmbed(T *query, T *key, float *emb_cos, float *emb_sin, int qStride, int kStride, + int dim, const float *logn, int maxSupportedSeqLength, int totSeqLen, int qHeads, int kHeads, + const int *positionIds) { + const int half = (dim + 1) / 2; + const int heads = std::max(qHeads, kHeads); + +#pragma omp parallel for collapse(2) + for (int head = 0; head < heads; ++head) { + for (int seq = 0; seq < totSeqLen; ++seq) { + int pos = positionIds[seq]; + + float *pcos = emb_cos + pos * dim; + float *psin = emb_sin + pos * dim; + + T *q = query + seq * qStride + head * dim; + T *k = key + seq * kStride + head * dim; + + __m512 pScale = _mm512_set1_ps(logn[pos]); + + // Process 16 elements per iteration; the mask disables lanes past `half` in the final partial chunk + for (int i = 0; i < half; i += 16) { + int remain = half - i; + __mmask16 mask = (remain >= 16 ? 
0xffff : (1 << remain) - 1); + + __m512 pCosVec = _mm512_maskz_loadu_ps(mask, &pcos[i]); + __m512 pCosHalfVec = _mm512_maskz_loadu_ps(mask, &pcos[i + half]); + __m512 pSinVec = _mm512_maskz_loadu_ps(mask, &psin[i]); + __m512 pSinHalfVec = _mm512_maskz_loadu_ps(mask, &psin[i + half]); + + if (head < qHeads) { + __m512 qVec = xft::load_avx512(mask, &q[i]); + __m512 qHalfVec = xft::load_avx512(mask, &q[i + half]); + __m512 qNew + = _mm512_mul_ps(_mm512_fmsub_ps(qVec, pCosVec, _mm512_mul_ps(qHalfVec, pSinVec)), pScale); + __m512 qHalfNew = _mm512_mul_ps( + _mm512_fmadd_ps(qHalfVec, pCosHalfVec, _mm512_mul_ps(qVec, pSinHalfVec)), pScale); + xft::store_avx512(&q[i], mask, qNew); + xft::store_avx512(&q[i + half], mask, qHalfNew); + } + + if (head < kHeads) { + __m512 kVec = xft::load_avx512(mask, &k[i]); + __m512 kHalfVec = xft::load_avx512(mask, &k[i + half]); + __m512 kNew = _mm512_fmsub_ps(kVec, pCosVec, _mm512_mul_ps(kHalfVec, pSinVec)); + __m512 kHalfNew = _mm512_fmadd_ps(kHalfVec, pCosHalfVec, _mm512_mul_ps(kVec, pSinHalfVec)); + xft::store_avx512(&k[i], mask, kNew); + xft::store_avx512(&k[i + half], mask, kHalfNew); + } + } + } + } +} + +void qwenApplyRotaryPosEmbed(float *query, float *key, float *emb_cos, float *emb_sin, int qStride, int kStride, + int dim, const float *logn, int maxSupportedSeqLength, int totSeqLen, int qHeads, int kHeads, + const int *positionIds) { + qwenApplyRotaryPosEmbed(query, key, emb_cos, emb_sin, qStride, kStride, dim, logn, maxSupportedSeqLength, + totSeqLen, qHeads, kHeads, positionIds); +} + +void qwenApplyRotaryPosEmbed(bfloat16_t *query, bfloat16_t *key, float *emb_cos, float *emb_sin, int qStride, + int kStride, int dim, const float *logn, int maxSupportedSeqLength, int totSeqLen, int qHeads, int kHeads, + const int *positionIds) { + qwenApplyRotaryPosEmbed(query, key, emb_cos, emb_sin, qStride, kStride, dim, logn, + maxSupportedSeqLength, totSeqLen, qHeads, kHeads, positionIds); +} + +void qwenApplyRotaryPosEmbed(float16_t 
*query, float16_t *key, float *emb_cos, float *emb_sin, int qStride, int kStride, + int dim, const float *logn, int maxSupportedSeqLength, int totSeqLen, int qHeads, int kHeads, + const int *positionIds) { + qwenApplyRotaryPosEmbed(query, key, emb_cos, emb_sin, qStride, kStride, dim, logn, maxSupportedSeqLength, + totSeqLen, qHeads, kHeads, positionIds); +} + #ifdef XFT_GPU // For LLaMA template diff --git a/src/kernels/rotary_embedding_kernels.h b/src/kernels/rotary_embedding_kernels.h index 8e782bfe..90fae897 100644 --- a/src/kernels/rotary_embedding_kernels.h +++ b/src/kernels/rotary_embedding_kernels.h @@ -52,8 +52,9 @@ void chatglm2ApplyRotaryPosEmbeding(bfloat16_t *query, bfloat16_t *key, int qStr void chatglm2ApplyRotaryPosEmbeding(float16_t *query, float16_t *key, int qStride, int kStride, float *emb_cos, float *emb_sin, int inv_freq_size, const int *qkShape, const int *positionIds); -void chatglm2ApplyRotaryPosEmbed(float *query, float *key, float *emb_cos, float *emb_sin, int qStride, - int kStride, int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds); +// For ChatGLM2 continuous batching +void chatglm2ApplyRotaryPosEmbed(float *query, float *key, float *emb_cos, float *emb_sin, int qStride, int kStride, + int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds); void chatglm2ApplyRotaryPosEmbed(bfloat16_t *query, bfloat16_t *key, float *emb_cos, float *emb_sin, int qStride, int kStride, int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds); @@ -74,6 +75,18 @@ void qwenApplyRotaryPosEmbeding(float16_t *query, float16_t *key, int qStride, i float *cur_emb_sin, int inv_freq_size, const float *logn, int maxSupportedSeqLength, const int *qkShape, const int *positionIds); +// For Qwen1.0 continuous batching (positionIds is indexed per token, totSeqLen entries; logn scales query only) +void qwenApplyRotaryPosEmbed(float *query, float *key, float *embCos, float *embSin, int qStride, int kStride, int dim, + const float *logn, int maxSupportedSeqLength, int totSeqLen, int qHeads, int kHeads, 
const int *positionIds); + +void qwenApplyRotaryPosEmbed(bfloat16_t *query, bfloat16_t *key, float *embCos, float *embSin, int qStride, int kStride, + int dim, const float *logn, int maxSupportedSeqLength, int totSeqLen, int qHeads, int kHeads, + const int *positionIds); + +void qwenApplyRotaryPosEmbed(float16_t *query, float16_t *key, float *embCos, float *embSin, int qStride, int kStride, + int dim, const float *logn, int maxSupportedSeqLength, int totSeqLen, int qHeads, int kHeads, + const int *positionIds); + #ifdef XFT_GPU // For LLaMA void llamaApplyRotaryPosEmbeding(void *device, float *query, float *key, int qStride, int kStride, float *emb_cos, diff --git a/src/layers/attention.h b/src/layers/attention.h index 135d9bd4..3c8ecb96 100644 --- a/src/layers/attention.h +++ b/src/layers/attention.h @@ -521,19 +521,15 @@ class Attention { if (ctx->maxPosEmbed > 0) { int qheads = this->endQHead - this->startQHead; int kheads = this->endKVHead - this->startKVHead; - int totInputSeqLen = 0; - for (auto seq : seqs) { - totInputSeqLen += seq->getInputSeqLen(); - } // Use the default position ids - std::vector posIds(totInputSeqLen); + std::vector posIds(totInSeqLen); int loc = 0; for (auto seq : seqs) { std::iota(posIds.begin() + loc, posIds.begin() + loc + seq->getInputSeqLen(), seq->getPastSeqLen()); loc += seq->getInputSeqLen(); } - qkpo.forward(query.Data(), key.Data(), totInputSeqLen, query.Stride(), key.Stride(), qheads, kheads, - posIds.data()); + qkpo.forward( + query.Data(), key.Data(), totInSeqLen, query.Stride(), key.Stride(), qheads, kheads, posIds.data()); } t3.release(); diff --git a/src/layers/rotary_embedding_qwen.cpp b/src/layers/rotary_embedding_qwen.cpp index 849294b6..bd27b299 100644 --- a/src/layers/rotary_embedding_qwen.cpp +++ b/src/layers/rotary_embedding_qwen.cpp @@ -314,20 +314,21 @@ void QwenRotaryEmbedding::forward( maxSupportedSeqLength, qkShape, positionIds); } +// For continuous batching void QwenRotaryEmbedding::forward( float *query, 
float *key, int totSeqLen, int qStride, int kStride, int qHeads, int kHeads, int *positionIds) { - printf("Unsupported QwenRotaryEmbedding in cb mode !\n"); - exit(1); + xft::qwenApplyRotaryPosEmbed(query, key, cur_emb_cos, cur_emb_sin, qStride, kStride, this->dim, logn, + maxSupportedSeqLength, totSeqLen, qHeads, kHeads, positionIds); } void QwenRotaryEmbedding::forward(bfloat16_t *query, bfloat16_t *key, int totSeqLen, int qStride, int kStride, int qHeads, int kHeads, int *positionIds) { - printf("Unsupported QwenRotaryEmbedding in cb mode !\n"); - exit(1); + xft::qwenApplyRotaryPosEmbed(query, key, cur_emb_cos, cur_emb_sin, qStride, kStride, this->dim, logn, + maxSupportedSeqLength, totSeqLen, qHeads, kHeads, positionIds); } void QwenRotaryEmbedding::forward(float16_t *query, float16_t *key, int totSeqLen, int qStride, int kStride, int qHeads, int kHeads, int *positionIds) { - printf("Unsupported QwenRotaryEmbedding in cb mode !\n"); - exit(1); -} \ No newline at end of file + xft::qwenApplyRotaryPosEmbed(query, key, cur_emb_cos, cur_emb_sin, qStride, kStride, this->dim, logn, + maxSupportedSeqLength, totSeqLen, qHeads, kHeads, positionIds); +}