[Layers] Add qwenRope support for CB Qwen1.0 (#449)
abenmao committed Jun 17, 2024
1 parent 2331613 commit 8bd8d68
Showing 4 changed files with 106 additions and 22 deletions.
86 changes: 80 additions & 6 deletions src/kernels/rotary_embedding_kernels.cpp
@@ -293,12 +293,11 @@ void chatglm2ApplyRotaryPosEmbeding(float16_t *query, float16_t *key, int qStrid
query, key, qStride, kStride, emb_cos, emb_sin, inv_freq_size, qkShape, positionIds);
}


// For ChatGLM2/3 continuous batching

template <typename T>
-static inline void chatglm2ApplyRotaryPosEmbed(T *query, T *key, float *emb_cos, float *emb_sin, int qStride, int kStride,
-        int inv_freq_size, int totSeqLen, int qHeads, int kHeads, const int *positionIds) {
+static inline void chatglm2ApplyRotaryPosEmbed(T *query, T *key, float *emb_cos, float *emb_sin, int qStride,
+        int kStride, int inv_freq_size, int totSeqLen, int qHeads, int kHeads, const int *positionIds) {
int dim = inv_freq_size * 2;
const int head_num = qHeads + kHeads;
const int half = inv_freq_size;
@@ -341,14 +340,14 @@ static inline void chatglm2ApplyRotaryPosEmbed(T *query, T *key, float *emb_cos,
}
}

-void chatglm2ApplyRotaryPosEmbed(float *query, float *key, float *emb_cos, float *emb_sin, int qStride, int kStride, int dim,
-        int totSeqLen, int qHeads, int kHeads, const int *positionIds) {
+void chatglm2ApplyRotaryPosEmbed(float *query, float *key, float *emb_cos, float *emb_sin, int qStride, int kStride,
+        int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds) {
chatglm2ApplyRotaryPosEmbed<float>(
query, key, emb_cos, emb_sin, qStride, kStride, dim, totSeqLen, qHeads, kHeads, positionIds);
}

void chatglm2ApplyRotaryPosEmbed(bfloat16_t *query, bfloat16_t *key, float *emb_cos, float *emb_sin, int qStride,
-    int kStride, int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds) {
+        int kStride, int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds) {
chatglm2ApplyRotaryPosEmbed<bfloat16_t>(
query, key, emb_cos, emb_sin, qStride, kStride, dim, totSeqLen, qHeads, kHeads, positionIds);
}
@@ -450,6 +449,81 @@ void qwenApplyRotaryPosEmbeding(float16_t *query, float16_t *key, int qStride, i
maxSupportedSeqLength, qkShape, positionIds);
}

template <typename T>
static inline void qwenApplyRotaryPosEmbed(T *query, T *key, float *emb_cos, float *emb_sin, int qStride, int kStride,
int dim, const float *logn, int maxSupportedSeqLength, int totSeqLen, int qHeads, int kHeads,
const int *positionIds) {
const int half = (dim + 1) / 2;
const int heads = std::max(qHeads, kHeads);

#pragma omp parallel for collapse(2)
for (int head = 0; head < heads; ++head) {
for (int seq = 0; seq < totSeqLen; ++seq) {
int pos = positionIds[seq];

float *pcos = emb_cos + pos * dim;
float *psin = emb_sin + pos * dim;

T *q = query + seq * qStride + head * dim;
T *k = key + seq * kStride + head * dim;

__m512 pScale = _mm512_set1_ps(logn[pos]);

// Process chunks of 16 elements at a time
for (int i = 0; i < half; i += 16) {
int remain = half - i;
__mmask16 mask = (remain >= 16 ? 0xffff : (1 << remain) - 1);

__m512 pCosVec = _mm512_maskz_loadu_ps(mask, &pcos[i]);
__m512 pCosHalfVec = _mm512_maskz_loadu_ps(mask, &pcos[i + half]);
__m512 pSinVec = _mm512_maskz_loadu_ps(mask, &psin[i]);
__m512 pSinHalfVec = _mm512_maskz_loadu_ps(mask, &psin[i + half]);

if (head < qHeads) {
__m512 qVec = xft::load_avx512(mask, &q[i]);
__m512 qHalfVec = xft::load_avx512(mask, &q[i + half]);
__m512 qNew
= _mm512_mul_ps(_mm512_fmsub_ps(qVec, pCosVec, _mm512_mul_ps(qHalfVec, pSinVec)), pScale);
__m512 qHalfNew = _mm512_mul_ps(
_mm512_fmadd_ps(qHalfVec, pCosHalfVec, _mm512_mul_ps(qVec, pSinHalfVec)), pScale);
xft::store_avx512(&q[i], mask, qNew);
xft::store_avx512(&q[i + half], mask, qHalfNew);
}

if (head < kHeads) {
__m512 kVec = xft::load_avx512(mask, &k[i]);
__m512 kHalfVec = xft::load_avx512(mask, &k[i + half]);
__m512 kNew = _mm512_fmsub_ps(kVec, pCosVec, _mm512_mul_ps(kHalfVec, pSinVec));
__m512 kHalfNew = _mm512_fmadd_ps(kHalfVec, pCosHalfVec, _mm512_mul_ps(kVec, pSinHalfVec));
xft::store_avx512(&k[i], mask, kNew);
xft::store_avx512(&k[i + half], mask, kHalfNew);
}
}
}
}
}
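
For reference, the arithmetic the masked AVX-512 loop above performs per (head, token) pair is the standard "rotate-half" RoPE update, with the precomputed logn extrapolation factor multiplied into the query only (the kernel scales qNew/qHalfNew by pScale but leaves the key path unscaled). A scalar sketch, not part of this commit; the function name is hypothetical:

// Scalar equivalent of one (head, token) iteration of the kernel above.
// q and k point at this head's dim-sized slice; pcos/psin at the tables for
// this token's absolute position; half == (dim + 1) / 2 as in the kernel.
static void qwenRopeScalarRef(float *q, float *k, const float *pcos, const float *psin,
        int half, float lognScale, bool doQuery, bool doKey) {
    for (int i = 0; i < half; ++i) {
        if (doQuery) {
            float q0 = q[i], q1 = q[i + half];
            q[i] = (q0 * pcos[i] - q1 * psin[i]) * lognScale;                      // fmsub, then scale
            q[i + half] = (q1 * pcos[i + half] + q0 * psin[i + half]) * lognScale; // fmadd, then scale
        }
        if (doKey) {
            float k0 = k[i], k1 = k[i + half];
            k[i] = k0 * pcos[i] - k1 * psin[i];
            k[i + half] = k1 * pcos[i + half] + k0 * psin[i + half];
        }
    }
}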

void qwenApplyRotaryPosEmbed(float *query, float *key, float *emb_cos, float *emb_sin, int qStride, int kStride,
int dim, const float *logn, int maxSupportedSeqLength, int totSeqLen, int qHeads, int kHeads,
const int *positionIds) {
qwenApplyRotaryPosEmbed<float>(query, key, emb_cos, emb_sin, qStride, kStride, dim, logn, maxSupportedSeqLength,
totSeqLen, qHeads, kHeads, positionIds);
}

void qwenApplyRotaryPosEmbed(bfloat16_t *query, bfloat16_t *key, float *emb_cos, float *emb_sin, int qStride,
int kStride, int dim, const float *logn, int maxSupportedSeqLength, int totSeqLen, int qHeads, int kHeads,
const int *positionIds) {
qwenApplyRotaryPosEmbed<bfloat16_t>(query, key, emb_cos, emb_sin, qStride, kStride, dim, logn,
maxSupportedSeqLength, totSeqLen, qHeads, kHeads, positionIds);
}

void qwenApplyRotaryPosEmbed(float16_t *query, float16_t *key, float *emb_cos, float *emb_sin, int qStride, int kStride,
int dim, const float *logn, int maxSupportedSeqLength, int totSeqLen, int qHeads, int kHeads,
const int *positionIds) {
qwenApplyRotaryPosEmbed<float16_t>(query, key, emb_cos, emb_sin, qStride, kStride, dim, logn, maxSupportedSeqLength,
totSeqLen, qHeads, kHeads, positionIds);
}
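
One implementation detail worth noting: the kernel covers head sizes whose half-dimension is not a multiple of 16 with masked loads and stores instead of a scalar tail loop. A standalone sketch of the mask computation used above (illustrative, not from the commit):

#include <immintrin.h>

// For a 16-lane chunk with `remain` valid elements left (remain >= 1),
// select all lanes when a full chunk remains, else only the low `remain` lanes.
static inline __mmask16 tailMask16(int remain) {
    return remain >= 16 ? (__mmask16)0xffff : (__mmask16)((1 << remain) - 1);
}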

#ifdef XFT_GPU
// For LLaMA
template <typename T>
17 changes: 15 additions & 2 deletions src/kernels/rotary_embedding_kernels.h
@@ -52,8 +52,9 @@ void chatglm2ApplyRotaryPosEmbeding(bfloat16_t *query, bfloat16_t *key, int qStr
void chatglm2ApplyRotaryPosEmbeding(float16_t *query, float16_t *key, int qStride, int kStride, float *emb_cos,
float *emb_sin, int inv_freq_size, const int *qkShape, const int *positionIds);

-void chatglm2ApplyRotaryPosEmbed(float *query, float *key, float *emb_cos, float *emb_sin, int qStride,
-        int kStride, int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds);
+// For ChatGLM2 continuous batching
+void chatglm2ApplyRotaryPosEmbed(float *query, float *key, float *emb_cos, float *emb_sin, int qStride, int kStride,
+        int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds);

void chatglm2ApplyRotaryPosEmbed(bfloat16_t *query, bfloat16_t *key, float *emb_cos, float *emb_sin, int qStride,
int kStride, int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds);
@@ -74,6 +75,18 @@ void qwenApplyRotaryPosEmbeding(float16_t *query, float16_t *key, int qStride, i
float *cur_emb_sin, int inv_freq_size, const float *logn, int maxSupportedSeqLength, const int *qkShape,
const int *positionIds);

// For Qwen1.0 continuous batching
void qwenApplyRotaryPosEmbed(float *query, float *key, float *embCos, float *embSin, int qStride, int kStride, int dim,
const float *logn, int maxSupportedSeqLength, int totSeqLen, int qHeads, int kHeads, const int *positionIds);

void qwenApplyRotaryPosEmbed(bfloat16_t *query, bfloat16_t *key, float *embCos, float *embSin, int qStride, int kStride,
int dim, const float *logn, int maxSupportedSeqLength, int totSeqLen, int qHeads, int kHeads,
const int *positionIds);

void qwenApplyRotaryPosEmbed(float16_t *query, float16_t *key, float *embCos, float *embSin, int qStride, int kStride,
int dim, const float *logn, int maxSupportedSeqLength, int totSeqLen, int qHeads, int kHeads,
const int *positionIds);

#ifdef XFT_GPU
// For LLaMA
void llamaApplyRotaryPosEmbeding(void *device, float *query, float *key, int qStride, int kStride, float *emb_cos,
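A minimal calling sketch for the new Qwen continuous-batching entry point (illustrative only; the packed query/key buffers, strides, and the embCos/embSin/logn tables are assumed to have been prepared by the caller, as the layer code below does):

// Two sequences packed along the token axis: a 3-token prompt (past length 0)
// followed by 2 tokens of another request (past length 5). Each entry of
// positionIds is that token's absolute position, used to index the tables.
int positionIds[] = {0, 1, 2, 5, 6};
int totSeqLen = 5;
xft::qwenApplyRotaryPosEmbed(query, key, embCos, embSin, qStride, kStride, dim,
        logn, maxSupportedSeqLength, totSeqLen, qHeads, kHeads, positionIds);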
10 changes: 3 additions & 7 deletions src/layers/attention.h
@@ -521,19 +521,15 @@ class Attention {
if (ctx->maxPosEmbed > 0) {
int qheads = this->endQHead - this->startQHead;
int kheads = this->endKVHead - this->startKVHead;
-            int totInputSeqLen = 0;
-            for (auto seq : seqs) {
-                totInputSeqLen += seq->getInputSeqLen();
-            }
             // Use the default position ids
-            std::vector<int> posIds(totInputSeqLen);
+            std::vector<int> posIds(totInSeqLen);
int loc = 0;
for (auto seq : seqs) {
std::iota(posIds.begin() + loc, posIds.begin() + loc + seq->getInputSeqLen(), seq->getPastSeqLen());
loc += seq->getInputSeqLen();
}
-            qkpo.forward(query.Data(), key.Data(), totInputSeqLen, query.Stride(), key.Stride(), qheads, kheads,
-                    posIds.data());
+            qkpo.forward(
+                    query.Data(), key.Data(), totInSeqLen, query.Stride(), key.Stride(), qheads, kheads, posIds.data());
}
t3.release();

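For a concrete picture of the position ids built by the std::iota loop above (illustrative values, not from the commit):

// seq A: pastSeqLen = 4, inputSeqLen = 3  -> iota fills 4, 5, 6
// seq B: pastSeqLen = 0, inputSeqLen = 2  -> iota fills 0, 1
// posIds = {4, 5, 6, 0, 1}; totInSeqLen = 5 == total packed query/key rows.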
15 changes: 8 additions & 7 deletions src/layers/rotary_embedding_qwen.cpp
@@ -314,20 +314,21 @@ void QwenRotaryEmbedding::forward(
maxSupportedSeqLength, qkShape, positionIds);
}

+// For continuous batching
 void QwenRotaryEmbedding::forward(
         float *query, float *key, int totSeqLen, int qStride, int kStride, int qHeads, int kHeads, int *positionIds) {
-    printf("Unsupported QwenRotaryEmbedding in cb mode !\n");
-    exit(1);
+    xft::qwenApplyRotaryPosEmbed(query, key, cur_emb_cos, cur_emb_sin, qStride, kStride, this->dim, logn,
+            maxSupportedSeqLength, totSeqLen, qHeads, kHeads, positionIds);
}

void QwenRotaryEmbedding::forward(bfloat16_t *query, bfloat16_t *key, int totSeqLen, int qStride, int kStride,
int qHeads, int kHeads, int *positionIds) {
printf("Unsupported QwenRotaryEmbedding in cb mode !\n");
exit(1);
xft::qwenApplyRotaryPosEmbed(query, key, cur_emb_cos, cur_emb_sin, qStride, kStride, this->dim, logn,
maxSupportedSeqLength, totSeqLen, qHeads, kHeads, positionIds);
}

void QwenRotaryEmbedding::forward(float16_t *query, float16_t *key, int totSeqLen, int qStride, int kStride, int qHeads,
int kHeads, int *positionIds) {
printf("Unsupported QwenRotaryEmbedding in cb mode !\n");
exit(1);
}
xft::qwenApplyRotaryPosEmbed(query, key, cur_emb_cos, cur_emb_sin, qStride, kStride, this->dim, logn,
maxSupportedSeqLength, totSeqLen, qHeads, kHeads, positionIds);
}
