[Layers] Add qwenRope support for Qwen1.0 in CB mode #449

Merged
merged 1 commit on Jun 17, 2024
86 changes: 80 additions & 6 deletions src/kernels/rotary_embedding_kernels.cpp
@@ -293,12 +293,11 @@ void chatglm2ApplyRotaryPosEmbeding(float16_t *query, float16_t *key, int qStrid
query, key, qStride, kStride, emb_cos, emb_sin, inv_freq_size, qkShape, positionIds);
}


// For ChatGLM2/3 continuous batching

template <typename T>
static inline void chatglm2ApplyRotaryPosEmbed(T *query, T *key, float *emb_cos, float *emb_sin, int qStride, int kStride,
int inv_freq_size, int totSeqLen, int qHeads, int kHeads, const int *positionIds) {
static inline void chatglm2ApplyRotaryPosEmbed(T *query, T *key, float *emb_cos, float *emb_sin, int qStride,
int kStride, int inv_freq_size, int totSeqLen, int qHeads, int kHeads, const int *positionIds) {
int dim = inv_freq_size * 2;
const int head_num = qHeads + kHeads;
const int half = inv_freq_size;
@@ -341,14 +340,14 @@ static inline void chatglm2ApplyRotaryPosEmbed(T *query, T *key, float *emb_cos,
}
}

void chatglm2ApplyRotaryPosEmbed(float *query, float *key, float *emb_cos, float *emb_sin, int qStride, int kStride, int dim,
int totSeqLen, int qHeads, int kHeads, const int *positionIds) {
void chatglm2ApplyRotaryPosEmbed(float *query, float *key, float *emb_cos, float *emb_sin, int qStride, int kStride,
int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds) {
chatglm2ApplyRotaryPosEmbed<float>(
query, key, emb_cos, emb_sin, qStride, kStride, dim, totSeqLen, qHeads, kHeads, positionIds);
}

void chatglm2ApplyRotaryPosEmbed(bfloat16_t *query, bfloat16_t *key, float *emb_cos, float *emb_sin, int qStride,
int kStride, int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds) {
int kStride, int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds) {
chatglm2ApplyRotaryPosEmbed<bfloat16_t>(
query, key, emb_cos, emb_sin, qStride, kStride, dim, totSeqLen, qHeads, kHeads, positionIds);
}
@@ -450,6 +449,81 @@ void qwenApplyRotaryPosEmbeding(float16_t *query, float16_t *key, int qStride, i
maxSupportedSeqLength, qkShape, positionIds);
}

template <typename T>
static inline void qwenApplyRotaryPosEmbed(T *query, T *key, float *emb_cos, float *emb_sin, int qStride, int kStride,
int dim, const float *logn, int maxSupportedSeqLength, int totSeqLen, int qHeads, int kHeads,
const int *positionIds) {
const int half = (dim + 1) / 2;
const int heads = std::max(qHeads, kHeads);

#pragma omp parallel for collapse(2)
for (int head = 0; head < heads; ++head) {
for (int seq = 0; seq < totSeqLen; ++seq) {
Review comment (Contributor):
For the next step, considering the first token, should we swap the two loops so that each thread accesses contiguous memory? It may be worth testing such an implementation. OK for the current version.

Reply (Contributor Author):
Got it~

(A sketch of the suggested loop order follows at the end of this file's diff.)

int pos = positionIds[seq];

float *pcos = emb_cos + pos * dim;
float *psin = emb_sin + pos * dim;

T *q = query + seq * qStride + head * dim;
T *k = key + seq * kStride + head * dim;

__m512 pScale = _mm512_set1_ps(logn[pos]);

// Process chunks of 16 elements at a time
for (int i = 0; i < half; i += 16) {
int remain = half - i;
__mmask16 mask = (remain >= 16 ? 0xffff : (1 << remain) - 1);

__m512 pCosVec = _mm512_maskz_loadu_ps(mask, &pcos[i]);
__m512 pCosHalfVec = _mm512_maskz_loadu_ps(mask, &pcos[i + half]);
__m512 pSinVec = _mm512_maskz_loadu_ps(mask, &psin[i]);
__m512 pSinHalfVec = _mm512_maskz_loadu_ps(mask, &psin[i + half]);

if (head < qHeads) {
__m512 qVec = xft::load_avx512(mask, &q[i]);
__m512 qHalfVec = xft::load_avx512(mask, &q[i + half]);
__m512 qNew
= _mm512_mul_ps(_mm512_fmsub_ps(qVec, pCosVec, _mm512_mul_ps(qHalfVec, pSinVec)), pScale);
__m512 qHalfNew = _mm512_mul_ps(
_mm512_fmadd_ps(qHalfVec, pCosHalfVec, _mm512_mul_ps(qVec, pSinHalfVec)), pScale);
xft::store_avx512(&q[i], mask, qNew);
xft::store_avx512(&q[i + half], mask, qHalfNew);
}

if (head < kHeads) {
__m512 kVec = xft::load_avx512(mask, &k[i]);
__m512 kHalfVec = xft::load_avx512(mask, &k[i + half]);
__m512 kNew = _mm512_fmsub_ps(kVec, pCosVec, _mm512_mul_ps(kHalfVec, pSinVec));
__m512 kHalfNew = _mm512_fmadd_ps(kHalfVec, pCosHalfVec, _mm512_mul_ps(kVec, pSinHalfVec));
xft::store_avx512(&k[i], mask, kNew);
xft::store_avx512(&k[i + half], mask, kHalfNew);
}
}
}
}
}

void qwenApplyRotaryPosEmbed(float *query, float *key, float *emb_cos, float *emb_sin, int qStride, int kStride,
int dim, const float *logn, int maxSupportedSeqLength, int totSeqLen, int qHeads, int kHeads,
const int *positionIds) {
qwenApplyRotaryPosEmbed<float>(query, key, emb_cos, emb_sin, qStride, kStride, dim, logn, maxSupportedSeqLength,
totSeqLen, qHeads, kHeads, positionIds);
}

void qwenApplyRotaryPosEmbed(bfloat16_t *query, bfloat16_t *key, float *emb_cos, float *emb_sin, int qStride,
int kStride, int dim, const float *logn, int maxSupportedSeqLength, int totSeqLen, int qHeads, int kHeads,
const int *positionIds) {
qwenApplyRotaryPosEmbed<bfloat16_t>(query, key, emb_cos, emb_sin, qStride, kStride, dim, logn,
maxSupportedSeqLength, totSeqLen, qHeads, kHeads, positionIds);
}

void qwenApplyRotaryPosEmbed(float16_t *query, float16_t *key, float *emb_cos, float *emb_sin, int qStride, int kStride,
int dim, const float *logn, int maxSupportedSeqLength, int totSeqLen, int qHeads, int kHeads,
const int *positionIds) {
qwenApplyRotaryPosEmbed<float16_t>(query, key, emb_cos, emb_sin, qStride, kStride, dim, logn, maxSupportedSeqLength,
totSeqLen, qHeads, kHeads, positionIds);
}

#ifdef XFT_GPU
// For LLaMA
template <typename T>
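On the reviewer's suggestion above to swap the two loops so that each thread reads contiguous memory, here is a minimal, untested sketch of that variant. It is not part of this PR; it reuses the names from the template above and elides the unchanged AVX-512 rotation body.

// Hypothetical loop-order variant (not in this PR): tokens in the outer loop,
// heads in the inner loop, so each OpenMP thread walks one contiguous
// qStride/kStride-wide row of query/key and loads the per-token cos/sin/logn
// values only once per token.
#pragma omp parallel for
for (int seq = 0; seq < totSeqLen; ++seq) {
    int pos = positionIds[seq];
    float *pcos = emb_cos + pos * dim;
    float *psin = emb_sin + pos * dim;
    __m512 pScale = _mm512_set1_ps(logn[pos]);
    for (int head = 0; head < heads; ++head) {
        T *q = query + seq * qStride + head * dim;
        T *k = key + seq * kStride + head * dim;
        // ... same masked AVX-512 rotation as above: update q when head < qHeads,
        // update k when head < kHeads ...
    }
}

Whether this wins depends on the shape: for a long prompt (many tokens, first-token phase) the sequence loop offers plenty of parallelism, while for short decode batches the original head-major loop keeps more threads busy.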
17 changes: 15 additions & 2 deletions src/kernels/rotary_embedding_kernels.h
@@ -52,8 +52,9 @@ void chatglm2ApplyRotaryPosEmbeding(bfloat16_t *query, bfloat16_t *key, int qStr
void chatglm2ApplyRotaryPosEmbeding(float16_t *query, float16_t *key, int qStride, int kStride, float *emb_cos,
float *emb_sin, int inv_freq_size, const int *qkShape, const int *positionIds);

void chatglm2ApplyRotaryPosEmbed(float *query, float *key, float *emb_cos, float *emb_sin, int qStride,
int kStride, int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds);
// For ChatGLM2 continuous batching
void chatglm2ApplyRotaryPosEmbed(float *query, float *key, float *emb_cos, float *emb_sin, int qStride, int kStride,
int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds);

void chatglm2ApplyRotaryPosEmbed(bfloat16_t *query, bfloat16_t *key, float *emb_cos, float *emb_sin, int qStride,
int kStride, int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds);
@@ -74,6 +75,18 @@ void qwenApplyRotaryPosEmbeding(float16_t *query, float16_t *key, int qStride, i
float *cur_emb_sin, int inv_freq_size, const float *logn, int maxSupportedSeqLength, const int *qkShape,
const int *positionIds);

// For Qwen1.0 continuous batching
void qwenApplyRotaryPosEmbed(float *query, float *key, float *embCos, float *embSin, int qStride, int kStride, int dim,
const float *logn, int maxSupportedSeqLength, int totSeqLen, int qHeads, int kHeads, const int *positionIds);

void qwenApplyRotaryPosEmbed(bfloat16_t *query, bfloat16_t *key, float *embCos, float *embSin, int qStride, int kStride,
int dim, const float *logn, int maxSupportedSeqLength, int totSeqLen, int qHeads, int kHeads,
const int *positionIds);

void qwenApplyRotaryPosEmbed(float16_t *query, float16_t *key, float *embCos, float *embSin, int qStride, int kStride,
int dim, const float *logn, int maxSupportedSeqLength, int totSeqLen, int qHeads, int kHeads,
const int *positionIds);

#ifdef XFT_GPU
// For LLaMA
void llamaApplyRotaryPosEmbeding(void *device, float *query, float *key, int qStride, int kStride, float *emb_cos,
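For orientation, a rough sketch of how these new continuous-batching entry points might be driven: query and key hold every token of the batch packed back to back along the sequence dimension, and positionIds carries each token's position within its own sequence. All buffer names and sizes below are hypothetical, and a tightly packed stride of heads * dim is assumed purely for illustration.

// Hypothetical call site: two packed sequences with 3 and 4 new tokens.
// query/key are bfloat16_t buffers; embCos/embSin/logn are the precomputed tables.
int totSeqLen = 7;                        // 3 + 4 tokens packed together
int qHeads = 32, kHeads = 32, dim = 128;  // rotary dim per head
int qStride = qHeads * dim, kStride = kHeads * dim;
int posIds[] = {0, 1, 2, 0, 1, 2, 3};     // per-token position in its own sequence
xft::qwenApplyRotaryPosEmbed(query, key, embCos, embSin, qStride, kStride, dim,
        logn, maxSupportedSeqLength, totSeqLen, qHeads, kHeads, posIds);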
10 changes: 3 additions & 7 deletions src/layers/attention.h
@@ -521,19 +521,15 @@ class Attention {
if (ctx->maxPosEmbed > 0) {
int qheads = this->endQHead - this->startQHead;
int kheads = this->endKVHead - this->startKVHead;
int totInputSeqLen = 0;
for (auto seq : seqs) {
totInputSeqLen += seq->getInputSeqLen();
}
// Use the default position ids
std::vector<int> posIds(totInputSeqLen);
std::vector<int> posIds(totInSeqLen);
int loc = 0;
for (auto seq : seqs) {
std::iota(posIds.begin() + loc, posIds.begin() + loc + seq->getInputSeqLen(), seq->getPastSeqLen());
loc += seq->getInputSeqLen();
}
qkpo.forward(query.Data(), key.Data(), totInputSeqLen, query.Stride(), key.Stride(), qheads, kheads,
posIds.data());
qkpo.forward(
query.Data(), key.Data(), totInSeqLen, query.Stride(), key.Stride(), qheads, kheads, posIds.data());
}
t3.release();

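As a worked example of the position-id construction above (numbers hypothetical): with two sequences in the batch, one resuming from a cache of 5 tokens with 2 new tokens and one starting fresh with 3 new tokens, the std::iota loop fills posIds per sequence starting at its pastSeqLen:

// seq A: pastSeqLen = 5, inputSeqLen = 2  ->  positions 5, 6
// seq B: pastSeqLen = 0, inputSeqLen = 3  ->  positions 0, 1, 2
// totInSeqLen = 2 + 3 = 5, posIds = {5, 6, 0, 1, 2}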
15 changes: 8 additions & 7 deletions src/layers/rotary_embedding_qwen.cpp
@@ -314,20 +314,21 @@ void QwenRotaryEmbedding::forward(
maxSupportedSeqLength, qkShape, positionIds);
}

// For continuous batching
void QwenRotaryEmbedding::forward(
float *query, float *key, int totSeqLen, int qStride, int kStride, int qHeads, int kHeads, int *positionIds) {
printf("Unsupported QwenRotaryEmbedding in cb mode !\n");
exit(1);
xft::qwenApplyRotaryPosEmbed(query, key, cur_emb_cos, cur_emb_sin, qStride, kStride, this->dim, logn,
maxSupportedSeqLength, totSeqLen, qHeads, kHeads, positionIds);
}

void QwenRotaryEmbedding::forward(bfloat16_t *query, bfloat16_t *key, int totSeqLen, int qStride, int kStride,
int qHeads, int kHeads, int *positionIds) {
printf("Unsupported QwenRotaryEmbedding in cb mode !\n");
exit(1);
xft::qwenApplyRotaryPosEmbed(query, key, cur_emb_cos, cur_emb_sin, qStride, kStride, this->dim, logn,
maxSupportedSeqLength, totSeqLen, qHeads, kHeads, positionIds);
}

void QwenRotaryEmbedding::forward(float16_t *query, float16_t *key, int totSeqLen, int qStride, int kStride, int qHeads,
int kHeads, int *positionIds) {
printf("Unsupported QwenRotaryEmbedding in cb mode !\n");
exit(1);
}
xft::qwenApplyRotaryPosEmbed(query, key, cur_emb_cos, cur_emb_sin, qStride, kStride, this->dim, logn,
maxSupportedSeqLength, totSeqLen, qHeads, kHeads, positionIds);
}