Skip to content

Commit

Permalink
Adding test script for finding tokens per second llama-7b-chat and gg…
Browse files Browse the repository at this point in the history
…ml version

Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
  • Loading branch information
shrinath-suresh committed Aug 21, 2023
1 parent 002e221 commit f351d1d
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import requests
import json
import time

def send_text_file(url, file_path):
    """POST the contents of ``file_path`` to ``url`` and print throughput stats.

    The file is sent as the raw request body. The response text is treated as
    the generated answer, and the elapsed wall-clock time of the request is
    used to report an approximate generation rate.

    Note that "tokens" here are whitespace-separated words of the response
    text, not model tokenizer tokens, so the figure is only an approximation.

    Args:
        url: Inference endpoint that accepts the prompt as the request body.
        file_path: Path of the prompt file to send.

    Raises:
        OSError: If the prompt file cannot be read.
        requests.HTTPError: If the server answers with a 4xx/5xx status.
    """
    with open(file_path, 'rb') as fp:
        file_bytes = fp.read()

    # perf_counter is monotonic, so the interval cannot be skewed by
    # system clock adjustments the way time.time() can.
    start_time = time.perf_counter()
    response = requests.post(url, data=file_bytes)
    time_taken = time.perf_counter() - start_time

    # An error page is not a generated answer; fail loudly instead of
    # reporting a throughput number for it.
    response.raise_for_status()

    generated_answer = response.text
    print("Generated Answer: ", generated_answer)

    # split() with no argument splits on any whitespace run and yields an
    # empty list for an empty response (split(' ') would report one token).
    number_of_tokens = len(generated_answer.split())
    print("Number of tokens: ", number_of_tokens)
    print("Time taken: ", time_taken)

    # Divide by the real (float) duration: truncating to whole seconds
    # inflated the rate and raised ZeroDivisionError for sub-second replies.
    tokens_per_second = number_of_tokens / time_taken if time_taken > 0 else 0.0
    print("Tokens per second:", tokens_per_second)


if __name__ == "__main__":
    # Endpoint URL and prompt file handed to send_text_file for this run.
    url, file_path = (
        "http://localhost:8080/predictions/llm",
        "llm_handler/prompt.txt",
    )

    send_text_file(url=url, file_path=file_path)

42 changes: 42 additions & 0 deletions cpp/test/resources/torchscript_model/llm/Llama-2-7b-chat-hf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from transformers import AutoTokenizer
import transformers
import torch
import time

# Hugging Face model id to benchmark and the auth token passed when loading it.
model = "meta-llama/Llama-2-7b-chat-hf"
hf_api_key = "<INSERT-YOUR-HF-KEY-HERE>"

tokenizer = AutoTokenizer.from_pretrained(model, use_auth_token=hf_api_key)
pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    torch_dtype=torch.float16,
    device_map="auto",
    use_auth_token=hf_api_key,
)

# Time only the generation call. perf_counter is monotonic, and stopping the
# clock before any printing keeps console I/O out of the measurement.
start_time = time.perf_counter()
sequences = pipeline(
    'Hello my name is\n',
    do_sample=True,
    top_k=10,
    num_return_sequences=1,
    eos_token_id=tokenizer.eos_token_id,
    max_length=512,
)
time_taken = time.perf_counter() - start_time

# Collect the texts and join once instead of repeated string concatenation.
generated_texts = []
for seq in sequences:
    generated_texts.append(seq['generated_text'])
    print(f"Result: {seq['generated_text']}")
result = "".join(generated_texts)

print("Generated String:", result)
print("Total time taken:", time_taken)

# Whitespace-separated words are used as a stand-in for tokens. split() with
# no argument also splits on newlines and ignores runs of spaces.
# NOTE(review): generated_text presumably still contains the prompt, so the
# prompt's words are counted too - confirm against the pipeline defaults.
num_words = len(result.split())

print("Total words generated: ", num_words)

# Use the real (float) duration: int(time_taken) rounded the time down, which
# inflated the rate and could divide by zero for runs shorter than one second.
tokens_per_second = num_words / time_taken if time_taken > 0 else 0.0

print("Tokens per second: ", tokens_per_second)

0 comments on commit f351d1d

Please sign in to comment.