Commit

changqi1 committed Jun 15, 2024
1 parent c03cc38 commit ea4e80f
Showing 7 changed files with 208 additions and 45 deletions.
4 changes: 3 additions & 1 deletion pyproject.toml
@@ -1,3 +1,5 @@
[tool.black]
line-length = 120
target-version = ["py38", "py39", "py310", "py311"]
target-version = ["py38", "py39", "py310", "py311"]
[build-system]
requires = ["setuptools", "cmake"]
80 changes: 71 additions & 9 deletions src/kernels/rotary_embedding_kernels.cpp
@@ -53,6 +53,7 @@ void llamaSetCosSinCache(
// return q_embed, k_embed
//

// For LLaMA
template <typename T>
static inline void llamaApplyRotaryPosEmbeding(T *query, T *key, int qStride, int kStride, float *emb_cos,
float *emb_sin, int inv_freq_size, const int *qkShape, const int *positionIds) {
@@ -129,6 +130,7 @@ void llamaApplyRotaryPosEmbeding(float16_t *query, float16_t *key, int qStride,
query, key, qStride, kStride, emb_cos, emb_sin, inv_freq_size, qkShape, positionIds);
}

// For LLaMA continuous batching
template <typename T>
static inline void llamaApplyRotaryPosEmbed(T *query, T *key, float *emb_cos, float *emb_sin, int qStride, int kStride,
int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds) {
@@ -293,12 +295,11 @@ void chatglm2ApplyRotaryPosEmbeding(float16_t *query, float16_t *key, int qStrid
query, key, qStride, kStride, emb_cos, emb_sin, inv_freq_size, qkShape, positionIds);
}


// For ChatGLM2/3 continuous batching

template <typename T>
static inline void chatglm2ApplyRotaryPosEmbed(T *query, T *key, float *emb_cos, float *emb_sin, int qStride, int kStride,
int inv_freq_size, int totSeqLen, int qHeads, int kHeads, const int *positionIds) {
static inline void chatglm2ApplyRotaryPosEmbed(T *query, T *key, float *emb_cos, float *emb_sin, int qStride,
int kStride, int inv_freq_size, int totSeqLen, int qHeads, int kHeads, const int *positionIds) {
int dim = inv_freq_size * 2;
const int head_num = qHeads + kHeads;
const int half = inv_freq_size;
@@ -341,14 +342,14 @@ static inline void chatglm2ApplyRotaryPosEmbed(T *query, T *key, float *emb_cos,
}
}

void chatglm2ApplyRotaryPosEmbed(float *query, float *key, float *emb_cos, float *emb_sin, int qStride, int kStride, int dim,
int totSeqLen, int qHeads, int kHeads, const int *positionIds) {
void chatglm2ApplyRotaryPosEmbed(float *query, float *key, float *emb_cos, float *emb_sin, int qStride, int kStride,
int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds) {
chatglm2ApplyRotaryPosEmbed<float>(
query, key, emb_cos, emb_sin, qStride, kStride, dim, totSeqLen, qHeads, kHeads, positionIds);
}

void chatglm2ApplyRotaryPosEmbed(bfloat16_t *query, bfloat16_t *key, float *emb_cos, float *emb_sin, int qStride,
int kStride, int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds) {
int kStride, int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds) {
chatglm2ApplyRotaryPosEmbed<bfloat16_t>(
query, key, emb_cos, emb_sin, qStride, kStride, dim, totSeqLen, qHeads, kHeads, positionIds);
}
@@ -467,7 +468,6 @@ static inline void llamaApplyRotaryPosEmbeding(void *device, T *query, T *key, i
const int half_head_size = (head_size + 1) / 2;
using namespace sycl;

// Reorder input
sycl::queue *gpu_queue = static_cast<sycl::queue *>(device);
sycl::buffer<int, 1> positionIdsBuf(positionIds, sycl::range<1>(seqLen));
gpu_queue->submit([&](sycl::handler &cgh) {
@@ -484,8 +484,8 @@
const sycl::half cos = (sycl::half)emb_cos[pos * half_head_size + idx_half_head_dim];
const sycl::half sin = (sycl::half)emb_sin[pos * half_head_size + idx_half_head_dim];

sycl::half *q = (sycl::half *)query + idx_bs_seq * qStride + idx_head_num * head_size + idx_half_head_dim;
sycl::half *k = (sycl::half *)key + idx_bs_seq * kStride + idx_head_num * head_size + idx_half_head_dim;
sycl::half *q = (sycl::half *)(query + idx_bs_seq * qStride + idx_head_num * head_size + idx_half_head_dim);
sycl::half *k = (sycl::half *)(key + idx_bs_seq * kStride + idx_head_num * head_size + idx_half_head_dim);

if (idx_head_num < qHeads) {
auto q1 = q[0];
@@ -519,6 +519,68 @@ void llamaApplyRotaryPosEmbeding(void *device, float16_t *query, float16_t *key,
llamaApplyRotaryPosEmbeding<sycl::half>(device, (sycl::half *)query, (sycl::half *)key, qStride, kStride, emb_cos,
emb_sin, inv_freq_size, qkShape, positionIds);
}

// For LLaMA continuous batching
template <typename T>
static inline void llamaApplyRotaryPosEmbed(void *device, T *query, T *key, float *emb_cos, float *emb_sin, int qStride, int kStride,
int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds) {
const int half = (dim + 1) / 2;
const int heads = std::max(qHeads, kHeads);
using namespace sycl;

sycl::queue *gpu_queue = static_cast<sycl::queue *>(device);
sycl::buffer<int, 1> positionIdsBuf(positionIds, sycl::range<1>(totSeqLen));
gpu_queue->submit([&](sycl::handler &cgh) {
sycl::accessor position(positionIdsBuf, cgh, sycl::read_only);
sycl::range<2> globalSize(heads, totSeqLen);
sycl::range<2> workGroupSize(1, 1);

cgh.parallel_for(sycl::nd_range(globalSize, workGroupSize), [=](sycl::nd_item<2> item) {
size_t idx_seq = item.get_global_id(0);
size_t idx_head = item.get_global_id(1);
size_t pos = position[idx_seq];

sycl::half *q = (sycl::half *)(query + idx_seq * qStride + idx_head * dim);
sycl::half *k = (sycl::half *)(key + idx_seq * kStride + idx_head * dim);

for (int i = 0; i < half; i += 16) {

sycl::half cos = (sycl::half)emb_cos[pos * half + i];
sycl::half sin = (sycl::half)emb_sin[pos * half + i];

if (idx_head < qHeads) {
auto q1 = q[0];
q[0] = q1 * cos - q[half] * sin;
q[half] = q[half] * cos + q1 * sin;
}
if (idx_head < kHeads) {
auto k1 = k[0];
k[0] = k1 * cos - k[half] * sin;
k[half] = k[half] * cos + k1 * sin;
}
}
});
});
gpu_queue->wait();
}

void llamaApplyRotaryPosEmbed(void *device, float *query, float *key, float *emb_cos, float *emb_sin, int qStride,
int kStride, int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds) {
llamaApplyRotaryPosEmbed<float>(
device, query, key, emb_cos, emb_sin, qStride, kStride, dim, totSeqLen, qHeads, kHeads, positionIds);
}

void llamaApplyRotaryPosEmbed(void *device, bfloat16_t *query, bfloat16_t *key, float *emb_cos, float *emb_sin,
int qStride, int kStride, int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds) {
llamaApplyRotaryPosEmbed<bfloat16_t>(
device, query, key, emb_cos, emb_sin, qStride, kStride, dim, totSeqLen, qHeads, kHeads, positionIds);
}

void llamaApplyRotaryPosEmbed(void *device, float16_t *query, float16_t *key, float *emb_cos, float *emb_sin,
int qStride, int kStride, int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds) {
llamaApplyRotaryPosEmbed<sycl::half>(device, (sycl::half *)query, (sycl::half *)key, emb_cos, emb_sin, qStride,
kStride, dim, totSeqLen, qHeads, kHeads, positionIds);
}
#endif

} // namespace xft
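The LLaMA kernels in this file all apply the half-split rotary update that is visible in the SYCL bodies: element i of a head is paired with element i + half and rotated by the cos/sin values cached for that token's position (x1' = x1·cos − x2·sin, x2' = x2·cos + x1·sin). Below is a minimal standalone C++ sketch of that rotation for one head — illustration only, with the [position, dim/2] cos/sin layout inferred from how the kernels index emb_cos/emb_sin; it is not the xft:: kernel itself.

```cpp
#include <cstddef>

// Rotate one head of `dim` floats in place using half-split RoPE.
// `pos` selects the row of the precomputed cos/sin tables (assumed layout:
// [max_position, dim / 2]).
static void rotateHeadHalfSplit(float *x, int dim, int pos,
                                const float *embCos, const float *embSin) {
    const int half = dim / 2;
    for (int i = 0; i < half; ++i) {
        const float c = embCos[(size_t)pos * half + i];
        const float s = embSin[(size_t)pos * half + i];
        const float x1 = x[i];
        const float x2 = x[i + half];
        x[i]        = x1 * c - x2 * s; // first half
        x[i + half] = x2 * c + x1 * s; // second half
    }
}
```

In the continuous-batching overloads the same rotation runs once per packed token, with positionIds[t] supplying pos for token t of the flattened totSeqLen sequence.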
14 changes: 12 additions & 2 deletions src/kernels/rotary_embedding_kernels.h
@@ -52,8 +52,8 @@ void chatglm2ApplyRotaryPosEmbeding(bfloat16_t *query, bfloat16_t *key, int qStr
void chatglm2ApplyRotaryPosEmbeding(float16_t *query, float16_t *key, int qStride, int kStride, float *emb_cos,
float *emb_sin, int inv_freq_size, const int *qkShape, const int *positionIds);

void chatglm2ApplyRotaryPosEmbed(float *query, float *key, float *emb_cos, float *emb_sin, int qStride,
int kStride, int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds);
void chatglm2ApplyRotaryPosEmbed(float *query, float *key, float *emb_cos, float *emb_sin, int qStride, int kStride,
int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds);

void chatglm2ApplyRotaryPosEmbed(bfloat16_t *query, bfloat16_t *key, float *emb_cos, float *emb_sin, int qStride,
int kStride, int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds);
@@ -84,6 +84,16 @@ void llamaApplyRotaryPosEmbeding(void *device, bfloat16_t *query, bfloat16_t *ke

void llamaApplyRotaryPosEmbeding(void *device, float16_t *query, float16_t *key, int qStride, int kStride,
float *emb_cos, float *emb_sin, int inv_freq_size, const int *qkShape, const int *positionIds);

// For LLaMA continuous batching
void llamaApplyRotaryPosEmbed(void *device, float *query, float *key, float *embCos, float *embSin, int qStride,
int kStride, int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds);

void llamaApplyRotaryPosEmbed(void *device, bfloat16_t *query, bfloat16_t *key, float *emb_cos, float *emb_sin,
int qStride, int kStride, int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds);

void llamaApplyRotaryPosEmbed(void *device, float16_t *query, float16_t *key, float *emb_cos, float *emb_sin,
int qStride, int kStride, int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds);
#endif

} // namespace xft
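As a rough caller-side sketch of the continuous-batching entry points: the snippet below drives the CPU overload (without the leading void *device argument), which is assumed to be declared alongside the ones shown above, matching the calls made in rotary_embedding.cpp further down. All shapes, the packed Q/K layout, and the toy cos/sin cache are assumptions for illustration; the real cache comes from llamaSetCosSinCache.

```cpp
#include <cmath>
#include <vector>
#include "rotary_embedding_kernels.h"

int main() {
    const int dim = 8, qHeads = 2, kHeads = 1, maxPos = 16;
    const int qStride = qHeads * dim, kStride = kHeads * dim;

    // Two requests packed back to back (3 tokens + 2 tokens); positionIds
    // holds each token's position inside its own sequence.
    const std::vector<int> positionIds = {0, 1, 2, 0, 1};
    const int totSeqLen = static_cast<int>(positionIds.size());

    // Toy cos/sin tables built with the usual 10000^(-2i/dim) frequencies.
    const int half = dim / 2;
    std::vector<float> embCos(maxPos * half), embSin(maxPos * half);
    for (int p = 0; p < maxPos; ++p) {
        for (int i = 0; i < half; ++i) {
            const float angle = p * std::pow(10000.0f, -2.0f * i / dim);
            embCos[p * half + i] = std::cos(angle);
            embSin[p * half + i] = std::sin(angle);
        }
    }

    std::vector<float> query(totSeqLen * qStride, 1.0f);
    std::vector<float> key(totSeqLen * kStride, 1.0f);

    // Rotates query and key in place, one packed token at a time.
    xft::llamaApplyRotaryPosEmbed(query.data(), key.data(), embCos.data(), embSin.data(),
            qStride, kStride, dim, totSeqLen, qHeads, kHeads, positionIds.data());
    return 0;
}
```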
37 changes: 28 additions & 9 deletions src/layers/attention.h
@@ -336,15 +336,6 @@ class Attention {
dbg.dumpMatrix(key);
#endif

#ifdef XFT_GPU
int64_t qkvSize = qkvRows * qkvStride * sizeof(ImT);
ImT *qkvTmp = (ImT *)xft::alloc(qkvSize);
xft::memcopy(qkvTmp, qkvGroupMatMul.Data(), qkvSize, ctx->device); // error: need CPU ptr and GPU ptr
query.Assign(qkvTmp, inputBuffer.Rows(), qCols, qkvCols);
key.Assign(qkvTmp + qCols, inputBuffer.Rows(), kvCols, qkvCols);
value.Assign(qkvTmp + qCols + kvCols, inputBuffer.Rows(), kvCols, qkvCols);
#endif

// Revise attnFactor before softmax (for some models, attnFactor may not be the default value)
// We initially introduced this code for ChatGLM, but eventually found it made no difference and was unnecessary.
// However, we have chosen to keep it in the codebase in case it becomes useful for future models.
@@ -362,6 +353,13 @@
xft::Matrix<ImT> attnSplit(imBuffer.Data(), imBuffer.Rows(), qCols, qCols);

#ifdef XFT_GPU
int64_t qkvSize = qkvRows * qkvStride * sizeof(ImT);
ImT *qkvTmp = (ImT *)xft::alloc(qkvSize);
xft::memcopy(qkvTmp, qkvGroupMatMul.Data(), qkvSize, ctx->device);
query.Assign(qkvTmp, inputBuffer.Rows(), qCols, qkvCols);
key.Assign(qkvTmp + qCols, inputBuffer.Rows(), kvCols, qkvCols);
value.Assign(qkvTmp + qCols + kvCols, inputBuffer.Rows(), kvCols, qkvCols);

int64_t attnSplitSize = imBuffer.Rows() * qCols * sizeof(ImT);
ImT *attnSplitTmp = (ImT *)xft::alloc(attnSplitSize);
attnSplit.Assign(attnSplitTmp, imBuffer.Rows(), qCols, qCols);
@@ -384,6 +382,7 @@
xft::memcopy(imBuffer.Data(), attnSplit.Data(), attnSplitSize, ctx->device);
attnSplit.Assign(imBuffer.Data(), imBuffer.Rows(), qCols, qCols);
xft::dealloc(qkvTmp);
xft::dealloc(attnSplitTmp);
#endif

#ifdef XFT_DEBUG
@@ -560,6 +559,19 @@ class Attention {
// For multiple nodes inference, not the whole result buffer
xft::Matrix<ImT> attnSplit(imBuffer.Data(), imBuffer.Rows(), qCols, qCols);

#ifdef XFT_GPU
int64_t qkvSize = qkvRows * qkvStride * sizeof(ImT);
ImT *qkvTmp = (ImT *)xft::alloc(qkvSize);
xft::memcopy(qkvTmp, qkvGroupMatMul.Data(), qkvSize, ctx->device);
query.Assign(qkvTmp, inputBuffer.Rows(), qCols, qkvCols);
key.Assign(qkvTmp + qCols, inputBuffer.Rows(), kvCols, qkvCols);
value.Assign(qkvTmp + qCols + kvCols, inputBuffer.Rows(), kvCols, qkvCols);

int64_t attnSplitSize = imBuffer.Rows() * qCols * sizeof(ImT);
ImT *attnSplitTmp = (ImT *)xft::alloc(attnSplitSize);
attnSplit.Assign(attnSplitTmp, imBuffer.Rows(), qCols, qCols);
#endif

if (seqs[0]->getStep() == 0) { // First token generation
if (totInSeqLen > getFlashThresh() * seqs.size()) {
flashAttention(ctx, query, key, value, attnSplit, keyCaches, valueCaches, seqs);
@@ -573,6 +585,13 @@
}
t4.release();

#ifdef XFT_GPU
xft::memcopy(imBuffer.Data(), attnSplit.Data(), attnSplitSize, ctx->device);
attnSplit.Assign(imBuffer.Data(), imBuffer.Rows(), qCols, qCols);
xft::dealloc(qkvTmp);
xft::dealloc(attnSplitTmp);
#endif

#ifdef XFT_DEBUG
dbg.debugPrint(">>> attention_%d (softmax * value): [%d, %d] (%d)\n", ctx->splitIdx, attnSplit.Rows(),
attnSplit.Cols(), attnSplit.Stride());
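The attention hunks above stage the GPU-resident QKV result and the attention output through temporary buffers (xft::alloc / xft::memcopy) and now also release attnSplitTmp, closing a leak. The following is a small illustrative sketch of that paired allocate/copy/free lifecycle expressed as an RAII guard; it is not the library's code, and the xft:: declarations are assumptions shaped after how the diff calls them.

```cpp
#include <cstddef>

// Assumed declarations, shaped after the calls in the diff (not the real header).
namespace xft {
void *alloc(size_t size);
void dealloc(void *ptr);
void memcopy(void *dst, void *src, size_t size, void *device);
} // namespace xft

// Illustrative RAII guard so every staged buffer is released exactly once,
// even on early returns. The actual layer manages the raw pointers by hand.
struct StagedBuffer {
    void *ptr;
    explicit StagedBuffer(size_t size) : ptr(xft::alloc(size)) {}
    ~StagedBuffer() { xft::dealloc(ptr); }
    StagedBuffer(const StagedBuffer &) = delete;
    StagedBuffer &operator=(const StagedBuffer &) = delete;
};

void stageQkvExample(void *device, void *qkvOnDevice, size_t qkvSize) {
    StagedBuffer qkvTmp(qkvSize);                     // stands in for xft::alloc(qkvSize)
    xft::memcopy(qkvTmp.ptr, qkvOnDevice, qkvSize, device);
    // ... run self-attention against the staged copy ...
}   // released here, mirroring the added xft::dealloc(attnSplitTmp) calls
```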
2 changes: 1 addition & 1 deletion src/layers/decoder_block.h
@@ -121,7 +121,7 @@ class DecoderBlock {

// Copy final result to the output buffer
if (inputBuf != outputBuf && layersOnDuty % 2 == 0) {
std::memcpy(outputBuf, inputBuf, totInSeqLen * ctx->hiddenSize * sizeof(T));
xft::memcopy(outputBuf, inputBuf, totInSeqLen * ctx->hiddenSize * sizeof(T), ctx->device);
}
}

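The single change above replaces std::memcpy with xft::memcopy so the final copy also works when the buffers live on a GPU. Below is a hedged sketch of what such a device-aware helper can look like with SYCL; the real xft::memcopy is not shown in this diff and may differ.

```cpp
#include <cstddef>
#include <cstring>
#include <sycl/sycl.hpp>

// Illustrative only: plain memcpy on the host path, queue-based copy when a
// device is attached (the kernels above cast `device` to sycl::queue* the same way).
void memcopySketch(void *dst, const void *src, size_t size, void *device) {
    if (device == nullptr) {
        std::memcpy(dst, src, size);
    } else {
        auto *queue = static_cast<sycl::queue *>(device);
        queue->memcpy(dst, src, size).wait();
    }
}
```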
92 changes: 71 additions & 21 deletions src/layers/rotary_embedding.cpp
@@ -16,6 +16,7 @@

#include "allocator.h"
#include "compile_util.h"
#include "timeline.h"

LlamaRotaryEmbedding::LlamaRotaryEmbedding(DecoderContext *ctx) {
const std::string inv_freq_str = "inv_freq";
@@ -105,30 +106,20 @@ LlamaRotaryEmbedding::LlamaRotaryEmbedding(const int dim, const int max_position
// |_____| |_____|
// head_size/2 head_size/2

#ifdef XFT_GPU

void LlamaRotaryEmbedding::forward(
float *query, float *key, int qStride, int kStride, const int *qkShape, const int *positionIds) {
xft::llamaApplyRotaryPosEmbeding(this->device,
query, key, qStride, kStride, emb_cos, emb_sin, inv_freq_size, qkShape, positionIds);
}

void LlamaRotaryEmbedding::forward(
bfloat16_t *query, bfloat16_t *key, int qStride, int kStride, const int *qkShape, const int *positionIds) {
xft::llamaApplyRotaryPosEmbeding(this->device,
query, key, qStride, kStride, emb_cos, emb_sin, inv_freq_size, qkShape, positionIds);
}

void LlamaRotaryEmbedding::forward(
float16_t *query, float16_t *key, int qStride, int kStride, const int *qkShape, const int *positionIds) {
xft::llamaApplyRotaryPosEmbeding(this->device,
query, key, qStride, kStride, emb_cos, emb_sin, inv_freq_size, qkShape, positionIds);
}
TimeLine t("LlamaRotaryEmbedding.forward");

if (device != nullptr) {
#ifdef XFT_GPU
xft::llamaApplyRotaryPosEmbeding(this->device,
query, key, qStride, kStride, emb_cos, emb_sin, inv_freq_size, qkShape, positionIds);
return;
#else
printf("[Warning] %s:%d: Defined GPU device, but did not use it.\n", __FILE__, __LINE__);
#endif
}

void LlamaRotaryEmbedding::forward(
float *query, float *key, int qStride, int kStride, const int *qkShape, const int *positionIds) {
int dim = inv_freq_size * 2;
REQUIRES(dim == qkShape[3], "Incorrect shape, this dimension is not the head size.");

@@ -175,33 +166,92 @@ void LlamaRotaryEmbedding::forward(

void LlamaRotaryEmbedding::forward(
bfloat16_t *query, bfloat16_t *key, int qStride, int kStride, const int *qkShape, const int *positionIds) {
TimeLine t("LlamaRotaryEmbedding.forward");

if (device != nullptr) {
#ifdef XFT_GPU
xft::llamaApplyRotaryPosEmbeding(this->device,
query, key, qStride, kStride, emb_cos, emb_sin, inv_freq_size, qkShape, positionIds);
return;
#else
printf("[Warning] %s:%d: Defined GPU device, but did not use it.\n", __FILE__, __LINE__);
#endif
}

xft::llamaApplyRotaryPosEmbeding(
query, key, qStride, kStride, emb_cos, emb_sin, inv_freq_size, qkShape, positionIds);
}

void LlamaRotaryEmbedding::forward(
float16_t *query, float16_t *key, int qStride, int kStride, const int *qkShape, const int *positionIds) {
TimeLine t("LlamaRotaryEmbedding.forward");

if (device != nullptr) {
#ifdef XFT_GPU
xft::llamaApplyRotaryPosEmbeding(this->device,
query, key, qStride, kStride, emb_cos, emb_sin, inv_freq_size, qkShape, positionIds);
return;
#else
printf("[Warning] %s:%d: Defined GPU device, but did not use it.\n", __FILE__, __LINE__);
#endif
}

xft::llamaApplyRotaryPosEmbeding(
query, key, qStride, kStride, emb_cos, emb_sin, inv_freq_size, qkShape, positionIds);
}

#endif // GPU

// For continuous batching
void LlamaRotaryEmbedding::forward(
float *query, float *key, int totSeqLen, int qStride, int kStride, int qHeads, int kHeads, int *positionIds) {
TimeLine t("LlamaRotaryEmbedding.forward");

if (device != nullptr) {
#ifdef XFT_GPU
xft::llamaApplyRotaryPosEmbed(this->device,
query, key, emb_cos, emb_sin, qStride, kStride, this->dim, totSeqLen, qHeads, kHeads, positionIds);
return;
#else
printf("[Warning] %s:%d: Defined GPU device, but did not use it.\n", __FILE__, __LINE__);
#endif
}

xft::llamaApplyRotaryPosEmbed(
query, key, emb_cos, emb_sin, qStride, kStride, this->dim, totSeqLen, qHeads, kHeads, positionIds);
}

void LlamaRotaryEmbedding::forward(bfloat16_t *query, bfloat16_t *key, int totSeqLen, int qStride, int kStride,
int qHeads, int kHeads, int *positionIds) {
TimeLine t("LlamaRotaryEmbedding.forward");

if (device != nullptr) {
#ifdef XFT_GPU
xft::llamaApplyRotaryPosEmbed(this->device,
query, key, emb_cos, emb_sin, qStride, kStride, this->dim, totSeqLen, qHeads, kHeads, positionIds);
return;
#else
printf("[Warning] %s:%d: Defined GPU device, but did not use it.\n", __FILE__, __LINE__);
#endif
}

xft::llamaApplyRotaryPosEmbed(
query, key, emb_cos, emb_sin, qStride, kStride, this->dim, totSeqLen, qHeads, kHeads, positionIds);
}

void LlamaRotaryEmbedding::forward(float16_t *query, float16_t *key, int totSeqLen, int qStride, int kStride,
int qHeads, int kHeads, int *positionIds) {
TimeLine t("LlamaRotaryEmbedding.forward");

if (device != nullptr) {
#ifdef XFT_GPU
xft::llamaApplyRotaryPosEmbed(this->device,
query, key, emb_cos, emb_sin, qStride, kStride, this->dim, totSeqLen, qHeads, kHeads, positionIds);
return;
#else
printf("[Warning] %s:%d: Defined GPU device, but did not use it.\n", __FILE__, __LINE__);
#endif
}

xft::llamaApplyRotaryPosEmbed(
query, key, emb_cos, emb_sin, qStride, kStride, this->dim, totSeqLen, qHeads, kHeads, positionIds);
}
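Every forward() overload in this file now follows the same dispatch shape: open a TimeLine scope, take the GPU path when a device pointer is set and the build has XFT_GPU, otherwise warn and fall through to the CPU kernel. Reduced to a runnable skeleton — the kernel names here are placeholders, not the library's API:

```cpp
#include <cstdio>

// Placeholder kernels standing in for the xft:: CPU and GPU implementations.
static void cpuRopeKernel() { std::puts("CPU RoPE path"); }
#ifdef XFT_GPU
static void gpuRopeKernel(void *device) { (void)device; std::puts("GPU RoPE path"); }
#endif

// Skeleton of the dispatch used by the forward() overloads above: prefer the
// GPU when a device is attached and compiled in, warn and fall back otherwise.
void ropeForwardDispatchSketch(void *device) {
    if (device != nullptr) {
#ifdef XFT_GPU
        gpuRopeKernel(device);
        return;
#else
        std::printf("[Warning] %s:%d: Defined GPU device, but did not use it.\n", __FILE__, __LINE__);
#endif
    }
    cpuRopeKernel();
}
```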