Skip to content

Commit

Permalink
Replace auto with appropriate data type
Browse files: browse the repository at this point in the history
Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
  • Loading branch information
shrinath-suresh committed Sep 4, 2023
1 parent 48f522c commit 49a3015
Showing 1 changed file with 7 additions and 7 deletions.
14 changes (7 additions & 7 deletions) in cpp/src/examples/babyllama/baby_llama_handler.cc
Original file line number | Diff line number | Diff line change
Expand Up @@ -48,7 +48,7 @@ LlmHandler::LoadModel(
float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well,
// but slower
int steps = 256; // number of steps to run for
unsigned long long rng_seed = 0;
unsigned long long rng_seed;
// build the Sampler
build_sampler(&sampler, transformer.config.vocab_size, temperature, topp,
rng_seed);
Expand Down Expand Up @@ -110,11 +110,12 @@ std::vector<torch::jit::IValue> LlmHandler::Preprocess(
int num_prompt_tokens = 0;
int* prompt_tokens = (int*)malloc(
(strlen(msgCStr) + 3) * sizeof(int)); // +3 for '\0', ?BOS, ?EOS

encode(&tokenizer, msgCStr, 1, 0, prompt_tokens, &num_prompt_tokens);

std::vector<torch::Tensor> tensor_vector;
for (int64_t i = 0; i < num_prompt_tokens; ++i) {
auto token = prompt_tokens[i];
int token = prompt_tokens[i];
torch::Tensor tensor = torch::tensor(token, torch::kInt64);
tensor_vector.push_back(tensor);
}
Expand Down Expand Up @@ -153,7 +154,7 @@ torch::Tensor LlmHandler::Inference(
std::pair<std::string&, std::map<uint8_t, std::string>&>& idx_to_req_id,
std::shared_ptr<torchserve::InferenceResponseBatch>& response_batch) {
std::vector<torch::Tensor> tensor_vector;
auto tokens_list_tensor = inputs[0].toTensor();
torch::Tensor tokens_list_tensor = inputs[0].toTensor();

int64_t num_elements = tokens_list_tensor.numel();

Expand All @@ -162,7 +163,7 @@ torch::Tensor LlmHandler::Inference(
std::vector<long> long_vector;
long_vector.reserve(num_elements);

auto data_ptr = tokens_list_tensor.data_ptr<int64_t>();
int64_t* data_ptr = tokens_list_tensor.data_ptr<int64_t>();
for (int64_t i = 0; i < num_elements; ++i) {
long_vector.push_back(data_ptr[i]);
}
Expand Down Expand Up @@ -213,7 +214,7 @@ torch::Tensor LlmHandler::Inference(
// iteration)
if (pos > 1) {
long end = time_in_ms();
auto token_per_sec = (pos - 1) / (double)(end - start) * 1000;
double token_per_sec = (pos - 1) / (double)(end - start) * 1000;
std::cout << "Achieved tok per sec: " << token_per_sec << std::endl;
}

Expand All @@ -230,7 +231,7 @@ void LlmHandler::Postprocess(
for (const auto& kv : idx_to_req_id.second) {
try {
int64_t num_elements = data.numel();
auto data_ptr = data.data_ptr<int64_t>();
int64_t* data_ptr = data.data_ptr<int64_t>();
int64_t token = 1;
std::string concatenated_string;
for (int64_t i = 0; i < num_elements; ++i) {
Expand Down Expand Up @@ -264,7 +265,6 @@ void LlmHandler::Postprocess(
"c10 error, failed to postprocess tensor");
}
}

}

} // namespace llm
Expand Down

0 comments on commit 49a3015

Please sign in to comment.