-
Notifications
You must be signed in to change notification settings - Fork 858
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding test script for finding tokens per second llama-7b-chat and gg…
…ml version Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
- Loading branch information
1 parent
002e221
commit f351d1d
Showing
2 changed files
with
67 additions
and
0 deletions.
There are no files selected for viewing
25 changes: 25 additions & 0 deletions
25
cpp/test/resources/torchscript_model/llm/Llama-2-7b-chat-ggml-hf.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import requests | ||
import json | ||
import time | ||
|
||
def send_text_file(url, file_path):
    """POST the contents of *file_path* to *url* and print generation throughput.

    Reads the prompt file as raw bytes, sends it to the inference endpoint,
    then prints the generated text, a whitespace-based token count, the
    elapsed wall-clock time, and the tokens-per-second rate.

    Args:
        url: Inference endpoint that accepts the raw prompt bytes via POST.
        file_path: Path of the prompt file to send.
    """
    with open(file_path, 'rb') as fp:
        file_bytes = fp.read()

    start_time = time.time()
    response = requests.post(url, data=file_bytes)
    time_taken = time.time() - start_time

    generated_answer = response.text
    # Fixed typo in the printed label ("Anser" -> "Answer").
    print("Generated Answer: ", generated_answer)

    # Splitting on spaces is only a rough proxy for the true token count.
    number_of_tokens = len(generated_answer.split(' '))
    print("Number of tokens: ", number_of_tokens)
    print("Time taken: ", time_taken)
    # Divide by the float elapsed time: the original int() truncation raised
    # ZeroDivisionError for sub-second responses and skewed the rate otherwise.
    print("Tokens per second:", number_of_tokens / time_taken)
|
||
|
||
if __name__ == "__main__":
    # Benchmark the locally served LLM endpoint with the sample prompt file.
    endpoint = "http://localhost:8080/predictions/llm"
    prompt_file = "llm_handler/prompt.txt"
    send_text_file(endpoint, prompt_file)
|
42 changes: 42 additions & 0 deletions
42
cpp/test/resources/torchscript_model/llm/Llama-2-7b-chat-hf.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
from transformers import AutoTokenizer
import transformers
import torch
import time

# Benchmark script: measure text-generation throughput of
# meta-llama/Llama-2-7b-chat-hf via the Hugging Face pipeline API.

model = "meta-llama/Llama-2-7b-chat-hf"
hf_api_key = "<INSERT-YOUR-HF-KEY-HERE>"  # replace with a valid HF access token

tokenizer = AutoTokenizer.from_pretrained(model, use_auth_token=hf_api_key)
pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    torch_dtype=torch.float16,  # half precision to reduce memory footprint
    device_map="auto",          # let HF place layers on available devices
    use_auth_token=hf_api_key,
)

start_time = time.time()
sequences = pipeline(
    'Hello my name is\n',
    do_sample=True,
    top_k=10,
    num_return_sequences=1,
    eos_token_id=tokenizer.eos_token_id,
    max_length=512,
)
result = ""
for seq in sequences:
    result += seq['generated_text']
    print(f"Result: {seq['generated_text']}")
time_taken = time.time() - start_time

print("Generated String:", result)
print("Total time taken:", time_taken)

# Whitespace word count only approximates the true token count.
num_words = len(result.split(' '))

print("Total words generated: ", num_words)

# Divide by the float elapsed time: the original int() truncation raised
# ZeroDivisionError for sub-second runs and overstated the rate otherwise.
tokens_per_second = num_words / time_taken

print("Tokens per second: ", tokens_per_second)