Skip to content

Commit

Permalink
Adopted llama.cpp api changes
Browse files Browse the repository at this point in the history
  • Loading branch information
mreso committed Jan 12, 2024
1 parent eda470f commit 28c9b02
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ repos:
- id: black
additional_dependencies: ['click==8.0.4']
- repo: https://github.com/PyCQA/isort
rev: 5.10.1
rev: 5.12.0
hooks:
- id: isort
args: ["--profile", "black"]
6 changes: 4 additions & 2 deletions cpp/src/examples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ set(MNIST_SOURCE_FILES "")
list(APPEND MNIST_SOURCE_FILES ${MNIST_SRC_DIR}/mnist_handler.cc)
add_library(mnist_handler SHARED ${MNIST_SOURCE_FILES})
target_include_directories(mnist_handler PUBLIC ${MNIST_SRC_DIR})
target_link_libraries(mnist_handler PRIVATE ts_backends_torch_scripted ts_utils ${TORCH_LIBRARIES})
target_link_libraries(mnist_handler PRIVATE ts_backends_torch_scripted ts_utils ${TORCH_LIBRARIES})

set(LLM_SRC_DIR "${torchserve_cpp_SOURCE_DIR}/src/examples/llamacpp")
set(LLAMACPP_SRC_DIR "/home/ubuntu/llama.cpp")
Expand All @@ -20,10 +20,12 @@ set(MY_OBJECT_FILES
${LLAMACPP_SRC_DIR}/ggml.o
${LLAMACPP_SRC_DIR}/llama.o
${LLAMACPP_SRC_DIR}/common.o
${LLAMACPP_SRC_DIR}/k_quants.o
${LLAMACPP_SRC_DIR}/ggml-quants.o
${LLAMACPP_SRC_DIR}/ggml-alloc.o
${LLAMACPP_SRC_DIR}/grammar-parser.o
${LLAMACPP_SRC_DIR}/console.o
${LLAMACPP_SRC_DIR}/build-info.o
${LLAMACPP_SRC_DIR}/ggml-backend.o

)

Expand Down
11 changes: 5 additions & 6 deletions cpp/src/examples/llamacpp/llamacpp_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ LlamacppHandler::LoadModel(

llama_backend_init(params.numa);
ctx_params = llama_context_default_params();
llamamodel = llama_load_model_from_file(params.model.c_str(), ctx_params);
model_params = llama_model_default_params();
llamamodel = llama_load_model_from_file(params.model.c_str(), model_params);

return std::make_pair(module, device);
} catch (const c10::Error& e) {
Expand All @@ -74,7 +75,6 @@ std::vector<torch::jit::IValue> LlamacppHandler::Preprocess(
std::pair<std::string&, std::map<uint8_t, std::string>&>& idx_to_req_id,
std::shared_ptr<torchserve::InferenceRequestBatch>& request_batch,
std::shared_ptr<torchserve::InferenceResponseBatch>& response_batch) {

initialize_context();

std::vector<torch::jit::IValue> batch_ivalue;
Expand Down Expand Up @@ -181,8 +181,7 @@ torch::Tensor LlamacppHandler::Inference(
// evaluate the transformer

if (llama_eval(llama_ctx, tokens_list.data(), int(tokens_list.size()),
llama_get_kv_cache_token_count(llama_ctx),
params.n_threads)) {
llama_get_kv_cache_token_count(llama_ctx))) {
std::cout << "Failed to eval\n" << __func__ << std::endl;
break;
}
Expand All @@ -194,7 +193,7 @@ torch::Tensor LlamacppHandler::Inference(
llama_token new_token_id = 0;

auto logits = llama_get_logits(llama_ctx);
auto n_vocab = llama_n_vocab(llama_ctx);
auto n_vocab = llama_n_vocab(llamamodel);

std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
Expand All @@ -210,7 +209,7 @@ torch::Tensor LlamacppHandler::Inference(
new_token_id = llama_sample_token_greedy(llama_ctx, &candidates_p);

// is it an end of stream ?
if (new_token_id == llama_token_eos(llama_ctx)) {
if (new_token_id == llama_token_eos(llamamodel)) {
std::cout << "Reached [end of text]\n";
break;
}
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/examples/llamacpp/llamacpp_handler.hh
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ namespace llm {
class LlamacppHandler : public torchserve::torchscripted::BaseHandler {
private:
gpt_params params;
llama_model_params model_params;
llama_model* llamamodel;
llama_context_params ctx_params;
llama_context* llama_ctx;
Expand Down Expand Up @@ -52,4 +53,4 @@ class LlamacppHandler : public torchserve::torchscripted::BaseHandler {
override;
};
} // namespace llm
#endif // LLAMACPP_HANDLER_HH_
#endif // LLAMACPP_HANDLER_HH_
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,7 @@ TEST_F(TorchScriptedBackendTest, TestLoadPredictLlmHandler) {
"test/resources/torchscript_model/llamacpp/llamacpp_handler", "llm",
-1, "", "", 1, false),
"test/resources/torchscript_model/llamacpp/llamacpp_handler",
"test/resources/torchscript_model/llamacpp/prompt.txt", "llm_ts",
200);
"test/resources/torchscript_model/llamacpp/prompt.txt", "llm_ts", 200);
}

TEST_F(TorchScriptedBackendTest, TestBackendInitWrongModelDir) {
Expand Down

0 comments on commit 28c9b02

Please sign in to comment.