From 2fb5aa4f3e59f5005efed42cf1f98997716e8fde Mon Sep 17 00:00:00 2001
From: Duyi-Wang
Date: Mon, 17 Jun 2024 16:12:24 +0800
Subject: [PATCH] [Model] Support hybrid model in continuous batching.

---
 include/abstract_decoder.h |  4 ++--
 include/models.h           |  2 +-
 src/models/hybrid_model.h  | 10 ++++++----
 src/models/models.cpp      |  8 ++++----
 4 files changed, 13 insertions(+), 11 deletions(-)

diff --git a/include/abstract_decoder.h b/include/abstract_decoder.h
index ca8132db..2944ebfa 100644
--- a/include/abstract_decoder.h
+++ b/include/abstract_decoder.h
@@ -36,9 +36,9 @@ class AbstractDecoder {
     //   |   |             |             |||||||||||||||             |          |
     //   v   |_____________|_____________|||||||||||||||_____________|__________|
     //       |<----------------------- vocabSize ----------------------------->|
-    virtual std::tuple<float *, int, int> forward(int *ids, int64_t *dims, int step, bool logits_all = false) = 0;
+    virtual std::tuple<float *, int, int> forward(int *ids, int64_t *dims, int step, bool logitsAll = false) = 0;
 
-    virtual std::tuple<float *, int, int> forward(std::vector<SequenceMeta *> &seq, bool logits_all = false) = 0;
+    virtual std::tuple<float *, int, int> forward(std::vector<SequenceMeta *> &seq, bool logitsAll = false) = 0;
 
     // Reorder cached keys and values, size=batchSize*beamSize
     virtual void reorderCache(int *idx, int size) = 0;
diff --git a/include/models.h b/include/models.h
index bb1fdac6..41518b9b 100644
--- a/include/models.h
+++ b/include/models.h
@@ -74,7 +74,7 @@ class Model {
 
     bool isDone();
 
-    std::tuple<float *, int, int> forward(bool logits_all = true);
+    std::tuple<float *, int, int> forward(bool logitsAll = true);
 
     std::vector<int32_t> generate();
 
diff --git a/src/models/hybrid_model.h b/src/models/hybrid_model.h
index 03dfa7a9..da514282 100644
--- a/src/models/hybrid_model.h
+++ b/src/models/hybrid_model.h
@@ -72,10 +72,12 @@ class HybridModel : public AbstractDecoder {
         }
     }
 
-    // TODO
-    std::tuple<float *, int, int> forward(std::vector<SequenceMeta *> &seq, bool logits_all = false) {
-        throw std::logic_error("Method not implemented");
-        return std::make_tuple(nullptr, 0, 0);
+    std::tuple<float *, int, int> forward(std::vector<SequenceMeta *> &seq, bool logitsAll = false) {
+        if (seq[0]->getStep() == 0) {
+            return firstModel->forward(seq, logitsAll);
+        } else {
+            return nextModel->forward(seq, logitsAll);
+        }
     }
 
     void reorderCache(int *idx, int size) { return firstModel->reorderCache(idx, size); }
diff --git a/src/models/models.cpp b/src/models/models.cpp
index 165b510a..2356d8ec 100644
--- a/src/models/models.cpp
+++ b/src/models/models.cpp
@@ -742,14 +742,14 @@ std::vector<int32_t> Model::finalize() {
     }
 }
 
-std::tuple<float *, int, int> Model::forward(bool logits_all) {
+std::tuple<float *, int, int> Model::forward(bool logitsAll) {
     // This forward will sync and gather all logits.
     // Return is a tuple of (logits, totalSeqSize, VocabSize)
     // TODO: Deprecate the following Path
     // Old path reture is (logits, offset, size)
     if (searcher != nullptr) {
         int64_t dims[3] = {batchSize, 1, seqLen};
-        return decoder->forward(inputIds.data(), dims, 0, logits_all);
+        return decoder->forward(inputIds.data(), dims, 0, logitsAll);
     }
     // TODO: checking waiting queue
     if (workingGroup.empty()) {
@@ -768,10 +768,10 @@ std::tuple<float *, int, int> Model::forward(bool logits_all) {
         }
     }
 
-    std::tuple<float *, int, int> result = decoder->forward(workingSeqs, logits_all);
+    std::tuple<float *, int, int> result = decoder->forward(workingSeqs, logitsAll);
     int totalSeqSize = workingSeqs.size();
 
-    if (logits_all && workingSeqs[0]->getStep() == 0) {
+    if (logitsAll && workingSeqs[0]->getStep() == 0) {
        totalSeqSize = 0;
        for (auto x : workingSeqs) {
            totalSeqSize += x->getInputSeqLen();
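
Note: the dispatch rule this patch introduces routes a working batch by its generation step: sequences at step 0 (the prefill pass) go to firstModel, every later decode step goes to nextModel. Below is a minimal standalone sketch of that rule for illustration only; SequenceMeta here is a hypothetical stand-in, and StubDecoder and hybridForward are invented for this sketch, not the project's real types.

#include <cstdio>
#include <tuple>
#include <vector>

// Hypothetical stand-in for the sequence metadata used in continuous batching.
struct SequenceMeta {
    int step = 0;
    int getStep() const { return step; }
};

// Hypothetical stand-in for an AbstractDecoder implementation. Mirrors the
// forward(seq, logitsAll) shape: returns (logits, totalSeqSize, vocabSize).
struct StubDecoder {
    const char *name;
    std::tuple<float *, int, int> forward(std::vector<SequenceMeta *> &seq, bool logitsAll = false) {
        std::printf("%s handles %zu sequence(s), logitsAll=%d\n", name, seq.size(), (int)logitsAll);
        return std::tuple<float *, int, int>(nullptr, static_cast<int>(seq.size()), 0);
    }
};

// Same dispatch as the patched HybridModel::forward:
// step 0 -> firstModel (prefill), any later step -> nextModel (decode).
std::tuple<float *, int, int> hybridForward(StubDecoder &firstModel, StubDecoder &nextModel,
        std::vector<SequenceMeta *> &seq, bool logitsAll = false) {
    if (seq[0]->getStep() == 0) {
        return firstModel.forward(seq, logitsAll);
    } else {
        return nextModel.forward(seq, logitsAll);
    }
}

int main() {
    SequenceMeta a, b;
    std::vector<SequenceMeta *> batch {&a, &b};
    StubDecoder first {"firstModel"}, next {"nextModel"};

    hybridForward(first, next, batch);   // step 0: prefill routed to firstModel
    a.step = b.step = 1;
    hybridForward(first, next, batch);   // step > 0: decode routed to nextModel
    return 0;
}

Since the check reads only seq[0]->getStep(), the patch assumes every sequence in a working batch is in the same phase (all prefill or all decode); a mixed batch would be routed entirely by its first sequence.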