From ea4e80fc32a516e30c74391fba6abf22e26e4ca4 Mon Sep 17 00:00:00 2001 From: changqi1 Date: Sat, 15 Jun 2024 17:15:54 +0000 Subject: [PATCH] Run successfully. --- pyproject.toml | 4 +- src/kernels/rotary_embedding_kernels.cpp | 80 ++++++++++++++++++--- src/kernels/rotary_embedding_kernels.h | 14 +++- src/layers/attention.h | 37 +++++++--- src/layers/decoder_block.h | 2 +- src/layers/rotary_embedding.cpp | 92 ++++++++++++++++++------ src/models/common_decoder.h | 24 ++++++- 7 files changed, 208 insertions(+), 45 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ccf31b72..d37a451f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,5 @@ [tool.black] line-length = 120 -target-version = ["py38", "py39", "py310", "py311"] \ No newline at end of file +target-version = ["py38", "py39", "py310", "py311"] +[build-system] +requires = ["setuptools", "cmake"] diff --git a/src/kernels/rotary_embedding_kernels.cpp b/src/kernels/rotary_embedding_kernels.cpp index 9812bd35..3341ba5c 100644 --- a/src/kernels/rotary_embedding_kernels.cpp +++ b/src/kernels/rotary_embedding_kernels.cpp @@ -53,6 +53,7 @@ void llamaSetCosSinCache( // return q_embed, k_embed // +// For LLaMA template static inline void llamaApplyRotaryPosEmbeding(T *query, T *key, int qStride, int kStride, float *emb_cos, float *emb_sin, int inv_freq_size, const int *qkShape, const int *positionIds) { @@ -129,6 +130,7 @@ void llamaApplyRotaryPosEmbeding(float16_t *query, float16_t *key, int qStride, query, key, qStride, kStride, emb_cos, emb_sin, inv_freq_size, qkShape, positionIds); } +// For LLaMA continous batching template static inline void llamaApplyRotaryPosEmbed(T *query, T *key, float *emb_cos, float *emb_sin, int qStride, int kStride, int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds) { @@ -293,12 +295,11 @@ void chatglm2ApplyRotaryPosEmbeding(float16_t *query, float16_t *key, int qStrid query, key, qStride, kStride, emb_cos, emb_sin, inv_freq_size, qkShape, positionIds); } - // For ChatGLM2/3 continous batching template -static inline void chatglm2ApplyRotaryPosEmbed(T *query, T *key, float *emb_cos, float *emb_sin, int qStride, int kStride, - int inv_freq_size, int totSeqLen, int qHeads, int kHeads, const int *positionIds) { +static inline void chatglm2ApplyRotaryPosEmbed(T *query, T *key, float *emb_cos, float *emb_sin, int qStride, + int kStride, int inv_freq_size, int totSeqLen, int qHeads, int kHeads, const int *positionIds) { int dim = inv_freq_size * 2; const int head_num = qHeads + kHeads; const int half = inv_freq_size; @@ -341,14 +342,14 @@ static inline void chatglm2ApplyRotaryPosEmbed(T *query, T *key, float *emb_cos, } } -void chatglm2ApplyRotaryPosEmbed(float *query, float *key, float *emb_cos, float *emb_sin, int qStride, int kStride, int dim, - int totSeqLen, int qHeads, int kHeads, const int *positionIds) { +void chatglm2ApplyRotaryPosEmbed(float *query, float *key, float *emb_cos, float *emb_sin, int qStride, int kStride, + int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds) { chatglm2ApplyRotaryPosEmbed( query, key, emb_cos, emb_sin, qStride, kStride, dim, totSeqLen, qHeads, kHeads, positionIds); } void chatglm2ApplyRotaryPosEmbed(bfloat16_t *query, bfloat16_t *key, float *emb_cos, float *emb_sin, int qStride, - int kStride, int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds) { + int kStride, int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds) { chatglm2ApplyRotaryPosEmbed( query, key, emb_cos, emb_sin, qStride, kStride, dim, totSeqLen, qHeads, kHeads, positionIds); } @@ -467,7 +468,6 @@ static inline void llamaApplyRotaryPosEmbeding(void *device, T *query, T *key, i const int half_head_size = (head_size + 1) / 2; using namespace sycl; - // Reorder input sycl::queue *gpu_queue = static_cast(device); sycl::buffer positionIdsBuf(positionIds, sycl::range<1>(seqLen)); gpu_queue->submit([&](sycl::handler &cgh) { @@ -484,8 +484,8 @@ static inline void llamaApplyRotaryPosEmbeding(void *device, T *query, T *key, i const sycl::half cos = (sycl::half)emb_cos[pos * half_head_size + idx_half_head_dim]; const sycl::half sin = (sycl::half)emb_sin[pos * half_head_size + idx_half_head_dim]; - sycl::half *q = (sycl::half *)query + idx_bs_seq * qStride + idx_head_num * head_size + idx_half_head_dim; - sycl::half *k = (sycl::half *)key + idx_bs_seq * kStride + idx_head_num * head_size + idx_half_head_dim; + sycl::half *q = (sycl::half *)(query + idx_bs_seq * qStride + idx_head_num * head_size + idx_half_head_dim); + sycl::half *k = (sycl::half *)(key + idx_bs_seq * kStride + idx_head_num * head_size + idx_half_head_dim); if (idx_head_num < qHeads) { auto q1 = q[0]; @@ -519,6 +519,68 @@ void llamaApplyRotaryPosEmbeding(void *device, float16_t *query, float16_t *key, llamaApplyRotaryPosEmbeding(device, (sycl::half *)query, (sycl::half *)key, qStride, kStride, emb_cos, emb_sin, inv_freq_size, qkShape, positionIds); } + +// For LLaMA continous batching +template +static inline void llamaApplyRotaryPosEmbed(void *device, T *query, T *key, float *emb_cos, float *emb_sin, int qStride, int kStride, + int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds) { + const int half = (dim + 1) / 2; + const int heads = std::max(qHeads, kHeads); + using namespace sycl; + + sycl::queue *gpu_queue = static_cast(device); + sycl::buffer positionIdsBuf(positionIds, sycl::range<1>(totSeqLen)); + gpu_queue->submit([&](sycl::handler &cgh) { + sycl::accessor position(positionIdsBuf, cgh, sycl::read_only); + sycl::range<2> globalSize(heads, totSeqLen); + sycl::range<2> workGroupSize(1, 1); + + cgh.parallel_for(sycl::nd_range(globalSize, workGroupSize), [=](sycl::nd_item<2> item) { + size_t idx_seq = item.get_global_id(0); + size_t idx_head = item.get_global_id(1); + size_t pos = position[idx_seq]; + + sycl::half *q = (sycl::half *)(query + idx_seq * qStride + idx_head * dim); + sycl::half *k = (sycl::half *)(key + idx_seq * kStride + idx_head * dim); + + for (int i = 0; i < half; i += 16) { + + sycl::half cos = (sycl::half)emb_cos[pos * half + i]; + sycl::half sin = (sycl::half)emb_sin[pos * half + i]; + + if (idx_head < qHeads) { + auto q1 = q[0]; + q[0] = q1 * cos - q[half] * sin; + q[half] = q[half] * cos + q1 * sin; + } + if (idx_head < kHeads) { + auto k1 = k[0]; + k[0] = k1 * cos - k[half] * sin; + k[half] = k[half] * cos + k1 * sin; + } + } + }); + }); + gpu_queue->wait(); +} + +void llamaApplyRotaryPosEmbed(void *device, float *query, float *key, float *emb_cos, float *emb_sin, int qStride, + int kStride, int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds) { + llamaApplyRotaryPosEmbed( + device, query, key, emb_cos, emb_sin, qStride, kStride, dim, totSeqLen, qHeads, kHeads, positionIds); +} + +void llamaApplyRotaryPosEmbed(void *device, bfloat16_t *query, bfloat16_t *key, float *emb_cos, float *emb_sin, + int qStride, int kStride, int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds) { + llamaApplyRotaryPosEmbed( + device, query, key, emb_cos, emb_sin, qStride, kStride, dim, totSeqLen, qHeads, kHeads, positionIds); +} + +void llamaApplyRotaryPosEmbed(void *device, float16_t *query, float16_t *key, float *emb_cos, float *emb_sin, + int qStride, int kStride, int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds) { + llamaApplyRotaryPosEmbed(device, (sycl::half *)query, (sycl::half *)key, emb_cos, emb_sin, qStride, + kStride, dim, totSeqLen, qHeads, kHeads, positionIds); +} #endif } // namespace xft diff --git a/src/kernels/rotary_embedding_kernels.h b/src/kernels/rotary_embedding_kernels.h index 8e782bfe..a2fe580d 100644 --- a/src/kernels/rotary_embedding_kernels.h +++ b/src/kernels/rotary_embedding_kernels.h @@ -52,8 +52,8 @@ void chatglm2ApplyRotaryPosEmbeding(bfloat16_t *query, bfloat16_t *key, int qStr void chatglm2ApplyRotaryPosEmbeding(float16_t *query, float16_t *key, int qStride, int kStride, float *emb_cos, float *emb_sin, int inv_freq_size, const int *qkShape, const int *positionIds); -void chatglm2ApplyRotaryPosEmbed(float *query, float *key, float *emb_cos, float *emb_sin, int qStride, - int kStride, int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds); +void chatglm2ApplyRotaryPosEmbed(float *query, float *key, float *emb_cos, float *emb_sin, int qStride, int kStride, + int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds); void chatglm2ApplyRotaryPosEmbed(bfloat16_t *query, bfloat16_t *key, float *emb_cos, float *emb_sin, int qStride, int kStride, int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds); @@ -84,6 +84,16 @@ void llamaApplyRotaryPosEmbeding(void *device, bfloat16_t *query, bfloat16_t *ke void llamaApplyRotaryPosEmbeding(void *device, float16_t *query, float16_t *key, int qStride, int kStride, float *emb_cos, float *emb_sin, int inv_freq_size, const int *qkShape, const int *positionIds); + +// For LLaMA continous batching +void llamaApplyRotaryPosEmbed(void *device, float *query, float *key, float *embCos, float *embSin, int qStride, + int kStride, int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds); + +void llamaApplyRotaryPosEmbed(void *device, bfloat16_t *query, bfloat16_t *key, float *emb_cos, float *emb_sin, + int qStride, int kStride, int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds); + +void llamaApplyRotaryPosEmbed(void *device, float16_t *query, float16_t *key, float *emb_cos, float *emb_sin, + int qStride, int kStride, int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds); #endif } // namespace xft diff --git a/src/layers/attention.h b/src/layers/attention.h index 135d9bd4..da8a6164 100644 --- a/src/layers/attention.h +++ b/src/layers/attention.h @@ -336,15 +336,6 @@ class Attention { dbg.dumpMatrix(key); #endif -#ifdef XFT_GPU - int64_t qkvSize = qkvRows * qkvStride * sizeof(ImT); - ImT *qkvTmp = (ImT *)xft::alloc(qkvSize); - xft::memcopy(qkvTmp, qkvGroupMatMul.Data(), qkvSize, ctx->device); // error: need CPU ptr and GPU ptr - query.Assign(qkvTmp, inputBuffer.Rows(), qCols, qkvCols); - key.Assign(qkvTmp + qCols, inputBuffer.Rows(), kvCols, qkvCols); - value.Assign(qkvTmp + qCols + kvCols, inputBuffer.Rows(), kvCols, qkvCols); -#endif - // Revise attnFactor before softmax (for some models, attnFactor may be not the default value) // We initially introduced the code for ChatGLM, but eventually found it has no difference and was unnecessary. // However, we have chosen to keep it in the codebase in case it becomes useful for future models. @@ -362,6 +353,13 @@ class Attention { xft::Matrix attnSplit(imBuffer.Data(), imBuffer.Rows(), qCols, qCols); #ifdef XFT_GPU + int64_t qkvSize = qkvRows * qkvStride * sizeof(ImT); + ImT *qkvTmp = (ImT *)xft::alloc(qkvSize); + xft::memcopy(qkvTmp, qkvGroupMatMul.Data(), qkvSize, ctx->device); + query.Assign(qkvTmp, inputBuffer.Rows(), qCols, qkvCols); + key.Assign(qkvTmp + qCols, inputBuffer.Rows(), kvCols, qkvCols); + value.Assign(qkvTmp + qCols + kvCols, inputBuffer.Rows(), kvCols, qkvCols); + int64_t attnSplitSize = imBuffer.Rows() * qCols * sizeof(ImT); ImT *attnSplitTmp = (ImT *)xft::alloc(attnSplitSize); attnSplit.Assign(attnSplitTmp, imBuffer.Rows(), qCols, qCols); @@ -384,6 +382,7 @@ class Attention { xft::memcopy(imBuffer.Data(), attnSplit.Data(), attnSplitSize, ctx->device); attnSplit.Assign(imBuffer.Data(), imBuffer.Rows(), qCols, qCols); xft::dealloc(qkvTmp); + xft::dealloc(attnSplitTmp); #endif #ifdef XFT_DEBUG @@ -560,6 +559,19 @@ class Attention { // For multiple nodes inference, not the whole result buffer xft::Matrix attnSplit(imBuffer.Data(), imBuffer.Rows(), qCols, qCols); +#ifdef XFT_GPU + int64_t qkvSize = qkvRows * qkvStride * sizeof(ImT); + ImT *qkvTmp = (ImT *)xft::alloc(qkvSize); + xft::memcopy(qkvTmp, qkvGroupMatMul.Data(), qkvSize, ctx->device); + query.Assign(qkvTmp, inputBuffer.Rows(), qCols, qkvCols); + key.Assign(qkvTmp + qCols, inputBuffer.Rows(), kvCols, qkvCols); + value.Assign(qkvTmp + qCols + kvCols, inputBuffer.Rows(), kvCols, qkvCols); + + int64_t attnSplitSize = imBuffer.Rows() * qCols * sizeof(ImT); + ImT *attnSplitTmp = (ImT *)xft::alloc(attnSplitSize); + attnSplit.Assign(attnSplitTmp, imBuffer.Rows(), qCols, qCols); +#endif + if (seqs[0]->getStep() == 0) { // First token generation if (totInSeqLen > getFlashThresh() * seqs.size()) { flashAttention(ctx, query, key, value, attnSplit, keyCaches, valueCaches, seqs); @@ -573,6 +585,13 @@ class Attention { } t4.release(); +#ifdef XFT_GPU + xft::memcopy(imBuffer.Data(), attnSplit.Data(), attnSplitSize, ctx->device); + attnSplit.Assign(imBuffer.Data(), imBuffer.Rows(), qCols, qCols); + xft::dealloc(qkvTmp); + xft::dealloc(attnSplitTmp); +#endif + #ifdef XFT_DEBUG dbg.debugPrint(">>> attention_%d (softmax * value): [%d, %d] (%d)\n", ctx->splitIdx, attnSplit.Rows(), attnSplit.Cols(), attnSplit.Stride()); diff --git a/src/layers/decoder_block.h b/src/layers/decoder_block.h index ba352066..c7348a92 100644 --- a/src/layers/decoder_block.h +++ b/src/layers/decoder_block.h @@ -121,7 +121,7 @@ class DecoderBlock { // Copy final result to the output buffer if (inputBuf != outputBuf && layersOnDuty % 2 == 0) { - std::memcpy(outputBuf, inputBuf, totInSeqLen * ctx->hiddenSize * sizeof(T)); + xft::memcopy(outputBuf, inputBuf, totInSeqLen * ctx->hiddenSize * sizeof(T), ctx->device); } } diff --git a/src/layers/rotary_embedding.cpp b/src/layers/rotary_embedding.cpp index 5efc6020..0e1b5ed6 100644 --- a/src/layers/rotary_embedding.cpp +++ b/src/layers/rotary_embedding.cpp @@ -16,6 +16,7 @@ #include "allocator.h" #include "compile_util.h" +#include "timeline.h" LlamaRotaryEmbedding::LlamaRotaryEmbedding(DecoderContext *ctx) { const std::string inv_freq_str = "inv_freq"; @@ -105,30 +106,20 @@ LlamaRotaryEmbedding::LlamaRotaryEmbedding(const int dim, const int max_position // |_____| |_____| // head_size/2 head_size/2 -#ifdef XFT_GPU - void LlamaRotaryEmbedding::forward( float *query, float *key, int qStride, int kStride, const int *qkShape, const int *positionIds) { - xft::llamaApplyRotaryPosEmbeding(this->device, - query, key, qStride, kStride, emb_cos, emb_sin, inv_freq_size, qkShape, positionIds); -} - -void LlamaRotaryEmbedding::forward( - bfloat16_t *query, bfloat16_t *key, int qStride, int kStride, const int *qkShape, const int *positionIds) { - xft::llamaApplyRotaryPosEmbeding(this->device, - query, key, qStride, kStride, emb_cos, emb_sin, inv_freq_size, qkShape, positionIds); -} - -void LlamaRotaryEmbedding::forward( - float16_t *query, float16_t *key, int qStride, int kStride, const int *qkShape, const int *positionIds) { - xft::llamaApplyRotaryPosEmbeding(this->device, - query, key, qStride, kStride, emb_cos, emb_sin, inv_freq_size, qkShape, positionIds); -} + TimeLine t("LlamaRotaryEmbedding.forward"); + if (device != nullptr) { +#ifdef XFT_GPU + xft::llamaApplyRotaryPosEmbeding(this->device, + query, key, qStride, kStride, emb_cos, emb_sin, inv_freq_size, qkShape, positionIds); + return; #else + printf("[Warning] %s:%d: Defined GPU device, but did not use it.\n", __FILE__, __LINE__); +#endif + } -void LlamaRotaryEmbedding::forward( - float *query, float *key, int qStride, int kStride, const int *qkShape, const int *positionIds) { int dim = inv_freq_size * 2; REQUIRES(dim == qkShape[3], "Incorrect shape, this dimention is not the head size."); @@ -175,33 +166,92 @@ void LlamaRotaryEmbedding::forward( void LlamaRotaryEmbedding::forward( bfloat16_t *query, bfloat16_t *key, int qStride, int kStride, const int *qkShape, const int *positionIds) { + TimeLine t("LlamaRotaryEmbedding.forward"); + + if (device != nullptr) { +#ifdef XFT_GPU + xft::llamaApplyRotaryPosEmbeding(this->device, + query, key, qStride, kStride, emb_cos, emb_sin, inv_freq_size, qkShape, positionIds); + return; +#else + printf("[Warning] %s:%d: Defined GPU device, but did not use it.\n", __FILE__, __LINE__); +#endif + } + xft::llamaApplyRotaryPosEmbeding( query, key, qStride, kStride, emb_cos, emb_sin, inv_freq_size, qkShape, positionIds); } void LlamaRotaryEmbedding::forward( float16_t *query, float16_t *key, int qStride, int kStride, const int *qkShape, const int *positionIds) { + TimeLine t("LlamaRotaryEmbedding.forward"); + + if (device != nullptr) { +#ifdef XFT_GPU + xft::llamaApplyRotaryPosEmbeding(this->device, + query, key, qStride, kStride, emb_cos, emb_sin, inv_freq_size, qkShape, positionIds); + return; +#else + printf("[Warning] %s:%d: Defined GPU device, but did not use it.\n", __FILE__, __LINE__); +#endif + } + xft::llamaApplyRotaryPosEmbeding( query, key, qStride, kStride, emb_cos, emb_sin, inv_freq_size, qkShape, positionIds); } -#endif // GPU - // For continuous batching void LlamaRotaryEmbedding::forward( float *query, float *key, int totSeqLen, int qStride, int kStride, int qHeads, int kHeads, int *positionIds) { + TimeLine t("LlamaRotaryEmbedding.forward"); + + if (device != nullptr) { +#ifdef XFT_GPU + xft::llamaApplyRotaryPosEmbed(this->device, + query, key, emb_cos, emb_sin, qStride, kStride, this->dim, totSeqLen, qHeads, kHeads, positionIds); + return; +#else + printf("[Warning] %s:%d: Defined GPU device, but did not use it.\n", __FILE__, __LINE__); +#endif + } + xft::llamaApplyRotaryPosEmbed( query, key, emb_cos, emb_sin, qStride, kStride, this->dim, totSeqLen, qHeads, kHeads, positionIds); } void LlamaRotaryEmbedding::forward(bfloat16_t *query, bfloat16_t *key, int totSeqLen, int qStride, int kStride, int qHeads, int kHeads, int *positionIds) { + TimeLine t("LlamaRotaryEmbedding.forward"); + + if (device != nullptr) { +#ifdef XFT_GPU + xft::llamaApplyRotaryPosEmbed(this->device, + query, key, emb_cos, emb_sin, qStride, kStride, this->dim, totSeqLen, qHeads, kHeads, positionIds); + return; +#else + printf("[Warning] %s:%d: Defined GPU device, but did not use it.\n", __FILE__, __LINE__); +#endif + } + xft::llamaApplyRotaryPosEmbed( query, key, emb_cos, emb_sin, qStride, kStride, this->dim, totSeqLen, qHeads, kHeads, positionIds); } void LlamaRotaryEmbedding::forward(float16_t *query, float16_t *key, int totSeqLen, int qStride, int kStride, int qHeads, int kHeads, int *positionIds) { + TimeLine t("LlamaRotaryEmbedding.forward"); + + if (device != nullptr) { +#ifdef XFT_GPU + xft::llamaApplyRotaryPosEmbed(this->device, + query, key, emb_cos, emb_sin, qStride, kStride, this->dim, totSeqLen, qHeads, kHeads, positionIds); + return; + +#else + printf("[Warning] %s:%d: Defined GPU device, but did not use it.\n", __FILE__, __LINE__); +#endif + } + xft::llamaApplyRotaryPosEmbed( query, key, emb_cos, emb_sin, qStride, kStride, this->dim, totSeqLen, qHeads, kHeads, positionIds); } \ No newline at end of file diff --git a/src/models/common_decoder.h b/src/models/common_decoder.h index bcbfc05a..68e692fa 100644 --- a/src/models/common_decoder.h +++ b/src/models/common_decoder.h @@ -571,6 +571,16 @@ class CommonDecoder : public AbstractDecoder { // Embedding this->embeddingForward(allInputIds.data(), embBuf, totInputSeqLen); +#ifdef XFT_GPU + size_t embBufSize = totInputSeqLen * hiddenSize * sizeof(AttnInT); + AttnInT *embBufTmp = (AttnInT *)xft::alloc(embBufSize, ctx->device); + AttnInT *outBufTmp = (AttnInT *)xft::alloc( + actBuffers->Rows() * actBuffers->Cols() * sizeof(float) - embBufSize, ctx->device); + xft::memcopy(embBufTmp, embBuf, embBufSize, ctx->device); + embBuf = embBufTmp; + outBuf = outBufTmp; +#endif + // Decoder block (all layers) decoderBlock->forward(ctx, seqs, embBuf, embBuf); @@ -581,7 +591,7 @@ class CommonDecoder : public AbstractDecoder { int offset = -1; for (int b = 0; b < batchSize; ++b) { offset += seqs[b]->getInputSeqLen(); - memcpy(lnIn + b * hiddenSize, embBuf + offset * hiddenSize, hiddenSize * sizeof(MlpOutT)); + xft::memcopy(lnIn + b * hiddenSize, embBuf + offset * hiddenSize, hiddenSize * sizeof(MlpOutT), ctx->device); } } @@ -604,14 +614,24 @@ class CommonDecoder : public AbstractDecoder { // Predictor float *finalOut = (float *)outBuf; + auto splitSize = this->predictor->getSplitSize(); this->predictor->forward(ctx, lnOut, finalOut, logitRows); #ifdef XFT_DEBUG - auto splitSize = this->predictor->getSplitSize(); dbg.debugPrint("finalOut:\n"); dbg.dumpMatrix(finalOut, logitRows, splitSize, splitSize); #endif +#ifdef XFT_GPU + xft::dealloc(embBuf, ctx->device); + embBuf = (AttnInT *)actBuffers->Data(); + + float *finalOutTmp = (float *)(embBuf + totInputSeqLen * hiddenSize); + xft::memcopy(finalOutTmp, finalOut, logitRows * splitSize * sizeof(float), ctx->device); + xft::dealloc(outBuf, ctx->device); + finalOut = finalOutTmp; +#endif + return std::tuple( finalOut, this->predictor->getSplitOffset(), this->predictor->getSplitSize()); }