Skip to content

Commit

Permalink
T5 Translation with torch.compile & TensorRT backend (#3223)
Browse files Browse the repository at this point in the history
* Added T5 TensorRT example with torch.compile

* Added T5 TensorRT example with torch.compile

* lint check

* Update examples/torch_tensorrt/torchcompile/T5/README.md

Co-authored-by: Matthias Reso <13337103+mreso@users.noreply.github.com>

* Update T5_handler.py

review comments

---------

Co-authored-by: Matthias Reso <13337103+mreso@users.noreply.github.com>
  • Loading branch information
agunapal and mreso committed Jul 9, 2024
1 parent 9a86e06 commit 9c587d2
Show file tree
Hide file tree
Showing 10 changed files with 216 additions and 1 deletion.
56 changes: 56 additions & 0 deletions examples/torch_tensorrt/torchcompile/T5/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# TorchServe inference with torch.compile with tensorrt backend

This example shows how to run TorchServe inference with T5 [Torch-TensorRT](https://github.com/pytorch/TensorRT) model



[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#inference) is an encode-decoder model used for a variety of text tasks out of the box by prepending a different text corresponding to each task. In this example, we use T5 for translation from English to German.

### Pre-requisites

- Verified to be working with `torch-tensorrt==2.3.0`
Installation instructions can be found in [pytorch/TensorRT](https://github.com/pytorch/TensorRT)

Change directory to examples directory `cd examples/torch_tensorrt/T5/torchcompile`

### torch.compile config

To use `tensorrt` backend with `torch.compile`, we specify the following config in `model-config.yaml`

```
pt2:
compile:
enable: True
backend: tensorrt
```

### Download the model

```
python ../../../large_models/Huggingface_accelerate/Download_model.py --model_name google-t5/t5-small
```

### Create the model archive
```
mkdir model_store
torch-model-archiver --model-name t5-translation --version 1.0 --handler T5_handler.py --config-file model-config.yaml -r requirements.txt --archive-format no-archive --export-path model_store -f
mv model model_store/t5-translation/.
```

### Start TorchServe

```
torchserve --start --ncs --ts-config config.properties --model-store model_store --models t5-translation --disable-token-auth
```

### Run Inference

```
curl -X POST http://127.0.0.1:8080/predictions/t5-translation -T sample_text.txt
```

results in

```
Das Haus ist wunderbar
```
142 changes: 142 additions & 0 deletions examples/torch_tensorrt/torchcompile/T5/T5_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import logging

import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)


class T5Handler(BaseHandler):
"""
Transformers handler class for sequence, token classification and question answering.
"""

def __init__(self):
super(T5Handler, self).__init__()
self.tokenizer = None
self.model = None
self.initialized = False

def initialize(self, ctx):
"""In this initialize function, the T5 model is loaded. It also has
the torch.compile calls for encoder and decoder.
Args:
ctx (context): It is a JSON Object containing information
pertaining to the model artifacts parameters.
"""
self.manifest = ctx.manifest
self.model_yaml_config = (
ctx.model_yaml_config
if ctx is not None and hasattr(ctx, "model_yaml_config")
else {}
)
properties = ctx.system_properties
model_dir = properties.get("model_dir")

self.device = torch.device(
"cuda:" + str(properties.get("gpu_id"))
if torch.cuda.is_available() and properties.get("gpu_id") is not None
else "cpu"
)

# read configs for the mode, model_name, etc. from the handler config
model_path = self.model_yaml_config.get("handler", {}).get("model_path", None)
if not model_path:
logger.error("Missing model path")

self.tokenizer = T5Tokenizer.from_pretrained(model_path)
self.model = T5ForConditionalGeneration.from_pretrained(model_path)
self.model.to(self.device)

self.model.eval()

pt2_value = self.model_yaml_config.get("pt2", {})
if "compile" in pt2_value:
compile_options = pt2_value["compile"]
if compile_options["enable"] == True:
del compile_options["enable"]

compile_options_str = ", ".join(
[f"{k} {v}" for k, v in compile_options.items()]
)
self.model.encoder = torch.compile(
self.model.encoder,
**compile_options,
)
self.model.decoder = torch.compile(
self.model.decoder,
**compile_options,
)
logger.info(f"Compiled model with {compile_options_str}")
logger.info("T5 model from path %s loaded successfully", model_dir)

self.initialized = True

def preprocess(self, requests):
"""
Basic text preprocessing, based on the user's choice of application mode.
Args:
requests (list): A list of dictionaries with a "data" or "body" field, each
containing the input text to be processed.
Returns:
inputs: A batched tensor of inputs: the batch of input ids and
attention masks.
"""

# Prefix for translation from English to German
task_prefix = "translate English to German: "
input_texts = [task_prefix + self.preprocess_requests(r) for r in requests]

logger.debug("Received texts: '%s'", input_texts)
inputs = self.tokenizer(
input_texts,
padding=True,
return_tensors="pt",
).to(self.device)

return inputs

def preprocess_requests(self, request):
"""
Preprocess request
Args:
request : Request to be decoded.
Returns:
str: Decoded input text
"""
input_text = request.get("data") or request.get("body")
if isinstance(input_text, (bytes, bytearray)):
input_text = input_text.decode("utf-8")
return input_text

@torch.inference_mode()
def inference(self, input_batch):
"""
Generates the translated text for the given input
Args:
input_batch : A tensors: the batch of input ids and attention masks, as returned by the
preprocess function.
Returns:
list: A list of strings with the translated text for each input text in the batch.
"""
outputs = self.model.generate(
**input_batch,
)

inferences = self.tokenizer.batch_decode(
outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
)

logger.debug("Generated text: %s", inferences)
return inferences

def postprocess(self, inference_output):
"""Post Process Function converts the predicted response into Torchserve readable format.
Args:
inference_output (list): It contains the predicted response of the input text.
Returns:
(list): Returns a list of the Predictions.
"""
return inference_output
5 changes: 5 additions & 0 deletions examples/torch_tensorrt/torchcompile/T5/config.properties
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
inference_address=http://127.0.0.1:8080
management_address=http://127.0.0.1:8081
metrics_address=http://127.0.0.1:8082
enable_envvars_config=true
install_py_dep_per_model=true
8 changes: 8 additions & 0 deletions examples/torch_tensorrt/torchcompile/T5/model-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
minWorkers: 1
maxWorkers: 1
handler:
model_path: model/models--google-t5--t5-small/snapshots/df1b051c49625cf57a3d0d8d3863ed4d13564fe4
pt2:
compile:
enable: True
backend: tensorrt
3 changes: 3 additions & 0 deletions examples/torch_tensorrt/torchcompile/T5/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
transformers>=4.41.2
sentencepiece>=0.2.0
torch-tensorrt>=2.3.0
1 change: 1 addition & 0 deletions examples/torch_tensorrt/torchcompile/T5/sample_text.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
The house is wonderful
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ This example shows how to run TorchServe inference with [Torch-TensorRT](https:/

- Verified to be working with `torch-tensorrt==2.3.0`

Change directory to examples directory `cd examples/torch_tensorrt/torchcompile`
Change directory to examples directory `cd examples/torch_tensorrt/resnet50/torchcompile`

### torch.compile config

Expand Down

0 comments on commit 9c587d2

Please sign in to comment.