Skip to content

Commit

Permalink
TensorRT-LLM Engine integration (#3228)
Browse files Browse the repository at this point in the history
* TensorRT-LLM Engine integration

* TensorRT-LLM Engine integration

* review comments

* review comments

* review comments

* Update README.md

---------

Co-authored-by: Matthias Reso <13337103+mreso@users.noreply.github.com>
  • Loading branch information
agunapal and mreso committed Jul 9, 2024
1 parent 9c587d2 commit a1c8eb2
Show file tree
Hide file tree
Showing 6 changed files with 223 additions and 3 deletions.
86 changes: 86 additions & 0 deletions examples/large_models/trt_llm/llama/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Llama TensorRT-LLM Engine integration with TorchServe

[TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM) provides users with an option to build TensorRT engines for LLMs that contain state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs.

## Pre-requisites

TRT-LLM requires Python 3.10
This example is tested with CUDA 12.1
Once TorchServe is installed, install TensorRT-LLM using the following.
This will downgrade the versions of PyTorch & Triton but this doesn't cause any issue.

```
pip install tensorrt_llm==0.10.0 --extra-index-url https://pypi.nvidia.com
pip install tensorrt-cu12==10.1.0
python -c "import tensorrt_llm"
```
shows
```
[TensorRT-LLM] TensorRT-LLM version: 0.10.0
```

## Download model from HuggingFace
```
huggingface-cli login
# or using an environment variable
huggingface-cli login --token $HUGGINGFACE_TOKEN
```
```
python ../../utils/Download_model.py --model_path model --model_name meta-llama/Meta-Llama-3-8B-Instruct
```

## Create TensorRT-LLM Engine
Clone TensorRT-LLM which will be used to create the TensorRT-LLM Engine

```
git clone -b v0.10.0 https://github.com/NVIDIA/TensorRT-LLM.git
```

Compile the model into a TensorRT engine with model weights and a model definition written in the TensorRT-LLM Python API.

```
python TensorRT-LLM/examples/llama/convert_checkpoint.py --model_dir model/models--meta-llama--Meta-Llama-3-8B-Instruct/snapshots/e1945c40cd546c78e41f1151f4db032b271faeaa/ --output_dir ./tllm_checkpoint_1gpu_bf16 --dtype bfloat16
```
```
trtllm-build --checkpoint_dir tllm_checkpoint_1gpu_bf16 --gemm_plugin bfloat16 --gpt_attention_plugin bfloat16 --output_dir ./llama-3-8b-engine
```

You can test if TensorRT-LLM Engine has been compiled correctly by running the following
```
python TensorRT-LLM/examples/run.py --engine_dir ./llama-3-8b-engine --max_output_len 100 --tokenizer_dir model/models--meta-llama--Meta-Llama-3-8B-Instruct/snapshots/e1945c40cd546c78e41f1151f4db032b271faeaa/ --input_text "How do I count to nine in French?"
```

You should see an output as follows
```
Input [Text 0]: "<|begin_of_text|>How do I count to nine in French?"
Output [Text 0 Beam 0]: " Counting to nine in French is easy and fun. Here's how you can do it:
One: Un
Two: Deux
Three: Trois
Four: Quatre
Five: Cinq
Six: Six
Seven: Sept
Eight: Huit
Nine: Neuf
That's it! You can now count to nine in French. Just remember that the numbers one to five are similar to their English counterparts, but the numbers six to nine have different pronunciations"
```

## Create model archive

```
mkdir model_store
torch-model-archiver --model-name llama3-8b --version 1.0 --handler trt_llm_handler.py --config-file model-config.yaml --archive-format no-archive --export-path model_store -f
mv model model_store/llama3-8b/.
mv llama-3-8b-engine model_store/llama3-8b/.
```

## Start TorchServe
```
torchserve --start --ncs --model-store model_store --models llama3-8b --disable-token-auth
```

## Run Inference
```
python ../../utils/test_llm_streaming_response.py -o 50 -t 2 -n 4 -m llama3-8b --prompt-text "@prompt.json" --prompt-json
```
12 changes: 12 additions & 0 deletions examples/large_models/trt_llm/llama/model-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# TorchServe frontend parameters
minWorkers: 1
maxWorkers: 1
maxBatchDelay: 100
responseTimeout: 1200
deviceType: "gpu"
asyncCommunication: true

handler:
tokenizer_dir: "model/models--meta-llama--Meta-Llama-3-8B-Instruct/snapshots/e1945c40cd546c78e41f1151f4db032b271faeaa/"
trt_llm_engine_config:
engine_dir: "llama-3-8b-engine"
3 changes: 3 additions & 0 deletions examples/large_models/trt_llm/llama/prompt.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{"prompt": "How is the climate in San Francisco?",
"temperature":0.5,
"max_new_tokens": 200}
118 changes: 118 additions & 0 deletions examples/large_models/trt_llm/llama/trt_llm_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import json
import logging
import time

import torch
from tensorrt_llm.runtime import ModelRunner
from transformers import AutoTokenizer

from ts.handler_utils.utils import send_intermediate_predict_response
from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)


class TRTLLMHandler(BaseHandler):
def __init__(self):
super().__init__()

self.trt_llm_engine = None
self.tokenizer = None
self.model = None
self.model_dir = None
self.lora_ids = {}
self.adapters = None
self.initialized = False

def initialize(self, ctx):
self.model_dir = ctx.system_properties.get("model_dir")

trt_llm_engine_config = ctx.model_yaml_config.get("handler").get(
"trt_llm_engine_config"
)

tokenizer_dir = ctx.model_yaml_config.get("handler").get("tokenizer_dir")
self.tokenizer = AutoTokenizer.from_pretrained(
tokenizer_dir,
legacy=False,
padding_side="left",
truncation_side="left",
trust_remote_code=True,
use_fast=True,
)

if self.tokenizer.pad_token_id is None:
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

self.trt_llm_engine = ModelRunner.from_dir(**trt_llm_engine_config)
self.initialized = True

async def handle(self, data, context):
start_time = time.time()

metrics = context.metrics

data_preprocess = await self.preprocess(data)
output, input_batch = await self.inference(data_preprocess, context)
output = await self.postprocess(output, input_batch, context)

stop_time = time.time()
metrics.add_time(
"HandlerTime", round((stop_time - start_time) * 1000, 2), None, "ms"
)
return output

async def preprocess(self, requests):
input_batch = []
assert len(requests) == 1, "Expecting batch_size = 1"
for req_data in requests:
data = req_data.get("data") or req_data.get("body")
if isinstance(data, (bytes, bytearray)):
data = data.decode("utf-8")

prompt = data.get("prompt")
temperature = data.get("temperature", 1.0)
max_new_tokens = data.get("max_new_tokens", 50)
input_ids = self.tokenizer.encode(
prompt, add_special_tokens=True, truncation=True
)
input_batch.append(input_ids)

input_batch = [torch.tensor(x, dtype=torch.int32) for x in input_batch]

return (input_batch, temperature, max_new_tokens)

async def inference(self, input_batch, context):
input_ids_batch, temperature, max_new_tokens = input_batch

with torch.no_grad():
outputs = self.trt_llm_engine.generate(
batch_input_ids=input_ids_batch,
temperature=temperature,
max_new_tokens=max_new_tokens,
end_id=self.tokenizer.eos_token_id,
pad_id=self.tokenizer.pad_token_id,
output_sequence_lengths=True,
streaming=True,
return_dict=True,
)
return outputs, input_ids_batch

async def postprocess(self, inference_outputs, input_batch, context):
for inference_output in inference_outputs:
output_ids = inference_output["output_ids"]
sequence_lengths = inference_output["sequence_lengths"]

batch_size, _, _ = output_ids.size()
for batch_idx in range(batch_size):
output_end = sequence_lengths[batch_idx][0]
outputs = output_ids[batch_idx][0][output_end - 1 : output_end].tolist()
output_text = self.tokenizer.decode(outputs)
send_intermediate_predict_response(
[json.dumps({"text": output_text})],
context.request_ids,
"Result",
200,
context,
)
return [""] * len(input_batch)
6 changes: 3 additions & 3 deletions examples/large_models/utils/test_llm_streaming_response.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import argparse
import json
import random
import threading
from queue import Queue

import orjson
import requests

max_prompt_random_tokens = 20
Expand All @@ -27,7 +27,7 @@ def _predict(self):
combined_text = ""
for chunk in response.iter_content(chunk_size=None):
if chunk:
data = orjson.loads(chunk)
data = json.loads(chunk)
if self.args.demo_streaming:
print(data["text"], end="", flush=True)
else:
Expand All @@ -41,7 +41,7 @@ def _get_url(self):
def _format_payload(self):
prompt_input = _load_curl_like_data(self.args.prompt_text)
if self.args.prompt_json:
prompt_input = orjson.loads(prompt_input)
prompt_input = json.loads(prompt_input)
prompt = prompt_input.get("prompt", None)
assert prompt is not None
prompt_list = prompt.split(" ")
Expand Down
1 change: 1 addition & 0 deletions ts_scripts/spellcheck_conf/wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1256,6 +1256,7 @@ parallelLevel
parallelType
parallelization
pptp
TRT
torchcompile
HPU
hpu
Expand Down

0 comments on commit a1c8eb2

Please sign in to comment.