Fixed infinite lock by adding request ids to the preprocess method
Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
shrinath-suresh committed Aug 21, 2023
1 parent 8addde3 commit 002e221
Showing 2 changed files with 141 additions and 27 deletions.
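
The change, in brief: Preprocess was returning without recording which request id belonged to which batch index, so responses could not be routed back and the worker waited on the batch indefinitely — the infinite lock the title refers to. The key addition is the bookkeeping line in Preprocess:

    idx_to_req_id.second[idx++] = request.request_id;

which gives Postprocess an entry to attach each response to. The commit also moves llama context creation out of LoadModel into a new initialize_context() helper invoked per batch.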
155 changes: 133 additions & 22 deletions cpp/src/examples/llm/llm_handler.cc
@@ -1,4 +1,4 @@
#include "src/examples/image_classifier/llm/llm_handler.hh"
nclude "src/examples/image_classifier/llm/llm_handler.hh"

#include <torch/script.h>
#include <torch/torch.h>
@@ -11,6 +11,29 @@

namespace llm {

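// Builds a fresh llama_context from the stored gpt_params. Called at the
// start of every Preprocess call so each batch starts from a clean KV
// cache; the matching llama_free happens at the end of Inference.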
void LlmHandler::initialize_context() {
// gpt_params params;
params.seed = 42;
params.n_threads = 4;
params.repeat_last_n = 64;

auto lparams = llama_context_default_params();
lparams.n_ctx = params.n_ctx;
lparams.n_gqa = params.n_gqa;
lparams.seed = params.seed;
lparams.f16_kv = params.memory_f16;
lparams.use_mmap = params.use_mmap;
lparams.use_mlock = params.use_mlock;

llama_ctx = llama_new_context_with_model(llamamodel, lparams);

if (llama_ctx == nullptr) {
std::cerr << "Failed to initialize llama context" << std::endl;
} else {
std::cout << "Context initialized successfully" << std::endl;
}
}

std::pair<std::shared_ptr<torch::jit::script::Module>,
std::shared_ptr<torch::Device>>
LlmHandler::LoadModel(
@@ -23,13 +23,46 @@ LlmHandler::LoadModel(
manifest_->GetModel().serialized_file),
*device));

// Load LLM
gpt_params params;
// TODO: Fetch the path from context
params.model = "/home/ubuntu/serve/cpp/llama-2-7b-chat.ggmlv3.q4_0.bin";
llama_backend_init(params.numa);
std::tie(llamamodel, llama_ctx) = llama_init_from_gpt_params(params);

auto lparams = llama_context_default_params();
lparams.n_ctx = params.n_ctx;
lparams.n_gqa = params.n_gqa;
lparams.seed = params.seed;
lparams.f16_kv = params.memory_f16;
lparams.use_mmap = params.use_mmap;
lparams.use_mlock = params.use_mlock;
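// Load the ggml weights once at model-load time; context creation is
// deferred to initialize_context() so each batch gets its own context.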
llamamodel = llama_load_model_from_file(params.model.c_str(), lparams);
// llama_ctx = llama_new_context_with_model(llamamodel, lparams);
// initialize_context();

// // Load LLM
// gpt_params params;
// // TODO: Fetch the path from context
// params.model = "/home/ubuntu/serve/cpp/llama-2-7b-chat.ggmlv3.q4_0.bin";
// llama_backend_init(params.numa);
// std::tie(llamamodel, llama_ctx) = llama_init_from_gpt_params(params);

return std::make_pair(module, device);
} catch (const c10::Error& e) {
@@ -50,13 +50,48 @@ std::vector<torch::jit::IValue> LlmHandler::Preprocess(
std::pair<std::string&, std::map<uint8_t, std::string>&>& idx_to_req_id,
std::shared_ptr<torchserve::InferenceRequestBatch>& request_batch,
std::shared_ptr<torchserve::InferenceResponseBatch>& response_batch) {
std::cout << "Initializing llama context" << std::endl;

initialize_context();

std::cout << "Llama context initialized" << std::endl;

std::vector<torch::jit::IValue> batch_ivalue;
std::vector<torch::Tensor> batch_tensors;

uint8_t idx = 0;
for (auto& request : *request_batch) {
try {
std::vector new_data = request.parameters["data"];
std::string msg = torchserve::Converter::VectorToStr(new_data);
(*response_batch)[request.request_id] =
std::make_shared<torchserve::InferenceResponse>(request.request_id);
idx_to_req_id.first += idx_to_req_id.first.empty()
? request.request_id
: "," + request.request_id;

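// Look the payload up under the "data" parameter first, falling back to
// "body" if it is absent.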
auto data_it = request.parameters.find(
torchserve::PayloadType::kPARAMETER_NAME_DATA);
auto dtype_it =
request.headers.find(torchserve::PayloadType::kHEADER_NAME_DATA_TYPE);
if (data_it == request.parameters.end()) {
data_it = request.parameters.find(
torchserve::PayloadType::kPARAMETER_NAME_BODY);
dtype_it = request.headers.find(
torchserve::PayloadType::kHEADER_NAME_BODY_TYPE);
}

if (data_it == request.parameters.end() ||
dtype_it == request.headers.end()) {
TS_LOGF(ERROR, "Empty payload for request id: {}", request.request_id);
(*response_batch)[request.request_id]->SetResponse(
500, "data_type", torchserve::PayloadType::kCONTENT_TYPE_TEXT,
"Empty payload");
continue;
}

std::cout << "Received Input: " << data_it->second << std::endl;

// std::vector new_data = request.parameters["data"];
// std::string msg = torchserve::Converter::VectorToStr(new_data);
std::string msg = torchserve::Converter::VectorToStr(data_it->second);

// tokenization

@@ -82,6 +82,7 @@ std::vector<torch::jit::IValue> LlmHandler::Preprocess(

torch::Tensor stacked_tensor = torch::stack(tensor_vector);
batch_ivalue.push_back(stacked_tensor);
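// Record which request id owns this batch index: the missing bookkeeping
// behind the infinite lock this commit fixes. Without the entry,
// Postprocess has nothing to respond to.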
idx_to_req_id.second[idx++] = request.request_id;

} catch (const std::runtime_error& e) {
TS_LOGF(ERROR, "Failed to load tensor for request id: {}, error: {}",
@@ -128,13 +198,11 @@ torch::Tensor LlmHandler::Inference(
tokens_list.push_back(id);
}

std::vector<std::string> generated_tokens;
gpt_params params;
// gpt_params params;

const int max_context_size = 64;

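// Keep generating while the KV cache holds fewer than max_context_size
// tokens, stopping early if the model finishes its sequence.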
while (llama_get_kv_cache_token_count(llama_ctx) < max_context_size) {

if (llama_eval(llama_ctx, tokens_list.data(), int(tokens_list.size()),
llama_get_kv_cache_token_count(llama_ctx),
params.n_threads)) {
Expand Down Expand Up @@ -164,23 +232,66 @@ torch::Tensor LlmHandler::Inference(
break;
}

generated_tokens.push_back(llama_token_to_str(llama_ctx, new_token_id));

// Print the new token :
std::cout << llama_token_to_str(llama_ctx, new_token_id) << std::endl;
std::cout << "New Token: " << llama_token_to_str(llama_ctx, new_token_id);

// Push this new token for next evaluation :
tokens_list.push_back(new_token_id);
}

} // end of main loop
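// Pack the generated token ids into a tensor so the result can be handed
// to Postprocess through the handler's tensor-based interface.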
std::vector<torch::Tensor> tensor_vector;
for (auto id : tokens_list) {
torch::Tensor tensor = torch::tensor(id, torch::kLong);
tensor_vector.push_back(tensor);
}

torch::Tensor inference_result =
torch::from_blob(tokens_list.data(),
{static_cast<long>(tokens_list.size())}, torch::kInt32);
torch::Tensor stacked_tensor = torch::stack(tensor_vector);

return inference_result;
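// Release the per-batch context; initialize_context() builds a fresh one
// for the next batch in Preprocess.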
llama_free(llama_ctx);
return stacked_tensor;
}

void LlmHandler::Postprocess(
const torch::Tensor& data,
std::pair<std::string&, std::map<uint8_t, std::string>&>& idx_to_req_id,
std::shared_ptr<torchserve::InferenceResponseBatch>& response_batch) {
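// Every request in the batch receives the same decoded text, since the
// handler generates a single sequence per batch.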
for (const auto& kv : idx_to_req_id.second) {
try {
int64_t num_elements = data.numel();

// Decode each generated token id in the tensor back into text
std::stringstream generated_text_stream;

auto data_ptr = data.data_ptr<int64_t>();
for (int64_t i = 0; i < num_elements; ++i) {
generated_text_stream << llama_token_to_str(llama_ctx, data_ptr[i]);
}

std::string generated_text_str = generated_text_stream.str();
std::cout << "Generated Text Str: " << generated_text_str << std::endl;

auto response = (*response_batch)[kv.second];

response->SetResponse(200, "data_type",
torchserve::PayloadType::kDATA_TYPE_STRING,
generated_text_str);
} catch (const std::runtime_error& e) {
TS_LOGF(ERROR, "Failed to load tensor for request id: {}, error: {}",
kv.second, e.what());
auto response = (*response_batch)[kv.second];
response->SetResponse(500, "data_type",
torchserve::PayloadType::kDATA_TYPE_STRING,
"runtime_error, failed to postprocess tensor");
} catch (const c10::Error& e) {
TS_LOGF(ERROR,
"Failed to postprocess tensor for request id: {}, error: {}",
kv.second, e.msg());
auto response = (*response_batch)[kv.second];
response->SetResponse(500, "data_type",
torchserve::PayloadType::kDATA_TYPE_STRING,
"c10 error, failed to postprocess tensor");
}
}
}

} // namespace llm

13 changes: 8 additions & 5 deletions cpp/src/examples/llm/llm_handler.hh
@@ -9,6 +9,7 @@
namespace llm {
class LlmHandler : public torchserve::torchscripted::BaseHandler {
private:
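// Generation parameters shared across the handler; fields are populated
// in LoadModel and initialize_context().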
gpt_params params;
llama_model* llamamodel;
llama_context* llama_ctx;

@@ -18,6 +19,8 @@ class LlmHandler : public torchserve::torchscripted::BaseHandler {
// NOLINTEND(bugprone-exception-escape)
~LlmHandler() override = default;

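// Builds a fresh llama_context from params; called at the top of
// Preprocess.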
void initialize_context();

virtual std::pair<std::shared_ptr<torch::jit::script::Module>,
std::shared_ptr<torch::Device>>
LoadModel(std::shared_ptr<torchserve::LoadModelRequest>& load_model_request);
@@ -37,11 +40,11 @@ class LlmHandler : public torchserve::torchscripted::BaseHandler {
std::shared_ptr<torchserve::InferenceResponseBatch>& response_batch)
override;

// void Postprocess(
// const torch::Tensor& data,
// std::pair<std::string&, std::map<uint8_t, std::string>&>&
// idx_to_req_id, std::shared_ptr<torchserve::InferenceResponseBatch>&
// response_batch) override;
void Postprocess(
const torch::Tensor& data,
std::pair<std::string&, std::map<uint8_t, std::string>&>& idx_to_req_id,
std::shared_ptr<torchserve::InferenceResponseBatch>& response_batch)
override;
};
} // namespace llm
#endif // LLM_HANDLER_HH_
