Skip to content

Commit

Permalink
BERT with torch.compile (#3201)
Browse files Browse the repository at this point in the history
* Added support for torch.compile with BERT

* Added support for torch.compile with BERT

* Added support for torch.compile with BERT

* Added support for torch.compile with BERT

* Added support for torch.compile with BERT

* Added support for torch.compile with BERT

* Update examples/Huggingface_Transformers/README.md

Co-authored-by: Matthias Reso <13337103+mreso@users.noreply.github.com>

* Updated based on review comments

---------

Co-authored-by: Matthias Reso <13337103+mreso@users.noreply.github.com>
  • Loading branch information
agunapal and mreso committed Jun 27, 2024
1 parent fccb13a commit affdcdd
Show file tree
Hide file tree
Showing 6 changed files with 147 additions and 109 deletions.
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import json
import os
import sys

import torch
import transformers
import yaml
from transformers import (
AutoConfig,
AutoModelForCausalLM,
Expand Down Expand Up @@ -151,9 +151,10 @@ def transformers_model_dowloader(
if len(sys.argv) > 1:
filename = os.path.join(dirname, sys.argv[1])
else:
filename = os.path.join(dirname, "setup_config.json")
filename = os.path.join(dirname, "model-config.yaml")
f = open(filename)
settings = json.load(f)
model_yaml_config = yaml.safe_load(f)
settings = model_yaml_config["handler"]
mode = settings["mode"]
model_name = settings["model_name"]
num_labels = int(settings["num_labels"])
Expand Down
167 changes: 90 additions & 77 deletions examples/Huggingface_Transformers/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,37 @@ To get started [install Torchserve](https://github.com/pytorch/serve) and then

### **Getting Started with the Demo**

If you're finetuning an existing model then you need to save your model and tokenizer with `save_pretrained()` which will create a `pytorch_model.bin`, `vocab.txt` and `config.json` file. Make sure to create them then run
If you're finetuning an existing model then you need to save your model and tokenizer with `save_pretrained()` which will create a `model.safetensors`, `vocab.txt` and `config.json` file. Make sure to create them then run

```
mkdir Transformer_model
mv pytorch_model.bin vocab.txt config.json Transformer_model/
mv model.safetensors vocab.txt config.json Transformer_model/
```

If you'd like to download a pretrained model without fine tuning we've provided a simple helper script which will do the above for you. All you need to do is change [setup.config.json](https://github.com/pytorch/serve/blob/master/examples/Huggingface_Transformers/setup_config.json) to your liking and run
If you'd like to download a pretrained model without fine tuning we've provided a simple helper script which will do the above for you. All you need to do is change [model-config.yaml](https://github.com/pytorch/serve/blob/master/examples/Huggingface_Transformers/model-config.yaml) to your liking and run

`python Download_Transformer_models.py`

In this example, we are using `torch.compile` by default

This is enabled by the following config in `model-config.yaml` file

```
pt2:
compile:
enable: True
backend: inductor
mode: reduce-overhead
```
When batch_size is 1, for BERT models, the operations are memory bound. Hence, we make use of `reduce-overhead` mode to make use of CUDAGraph and get better performance.

To use PyTorch Eager or TorchScript, you can remove the above config.

For Torchscript support, check out [torchscript.md](torchscript.md)

#### Setting the setup_config.json
#### Setting the handler config in model-config.yaml

In the setup_config.json :
In `model-config.yaml` :

*model_name* : bert-base-uncased , roberta-base or other available pre-trained models.

Expand All @@ -55,7 +70,7 @@ In the setup_config.json :

*batch_size* : Input batch size when tracing the model for `neuron` or `neuronx` as target hardware.

Once, `setup_config.json` has been set properly, the next step is to run
Once, `model-config.yaml` has been set properly, the next step is to run

`python Download_Transformer_models.py`

Expand All @@ -78,17 +93,17 @@ For examples of how to configure a model for a use case and what the input forma

## Sequence Classification

### Create model archive eager mode
### Create model archive for eager mode or torch.compile

```
torch-model-archiver --model-name BERTSeqClassification --version 1.0 --serialized-file Transformer_model/pytorch_model.bin --handler ./Transformer_handler_generalized.py --extra-files "Transformer_model/config.json,./setup_config.json,./Seq_classification_artifacts/index_to_name.json"
torch-model-archiver --model-name BERTSeqClassification --version 1.0 --serialized-file Transformer_model/model.safetensors --handler ./Transformer_handler_generalized.py --config-file model-config.yaml --extra-files "Transformer_model/config.json,./Seq_classification_artifacts/index_to_name.json"
```

### Create model archive Torchscript mode

```
torch-model-archiver --model-name BERTSeqClassification --version 1.0 --serialized-file Transformer_model/traced_model.pt --handler ./Transformer_handler_generalized.py --extra-files "./setup_config.json,./Seq_classification_artifacts/index_to_name.json"
torch-model-archiver --model-name BERTSeqClassification --version 1.0 --serialized-file Transformer_model/traced_model.pt --handler ./Transformer_handler_generalized.py --config-file model-config.yaml --extra-files "./Seq_classification_artifacts/index_to_name.json"
```

Expand All @@ -99,7 +114,7 @@ To register the model on TorchServe using the above model archive file, we run t
```
mkdir model_store
mv BERTSeqClassification.mar model_store/
torchserve --start --model-store model_store --models my_tc=BERTSeqClassification.mar --ncs
torchserve --start --model-store model_store --models my_tc=BERTSeqClassification.mar --disable-token --ncs
```

Expand All @@ -113,43 +128,43 @@ To get an explanation: `curl -X POST http://127.0.0.1:8080/explanations/my_tc -T

## Token Classification

Change `setup_config.json` to
Change the `handler` section in `model-config.yaml` to

```
{
"model_name":"bert-base-uncased",
"mode":"token_classification",
"do_lower_case":true,
"num_labels":"9",
"save_mode":"pretrained",
"max_length":"150",
"captum_explanation":true,
"FasterTransformer":false,
"embedding_name": "bert"
}
handler:
model_name: bert-base-uncased
mode: token_classification
do_lower_case: true
num_labels: 9
save_mode: pretrained
max_length: 150
captum_explanation: true
embedding_name: bert
BetterTransformer: false
model_parallel: false
```

```
rm -r Transformer_model
python Download_Transformer_models.py
```

### Create model archive eager mode
### Create model archive for eager mode or torch.compile
```
torch-model-archiver --model-name BERTTokenClassification --version 1.0 --serialized-file Transformer_model/pytorch_model.bin --handler ./Transformer_handler_generalized.py --extra-files "Transformer_model/config.json,./setup_config.json,./Token_classification_artifacts/index_to_name.json"
torch-model-archiver --model-name BERTTokenClassification --version 1.0 --serialized-file Transformer_model/model.safetensors --handler ./Transformer_handler_generalized.py --config-file model-config.yaml --extra-files "Transformer_model/config.json,./Token_classification_artifacts/index_to_name.json"
```

### Create model archive Torchscript mode
```
torch-model-archiver --model-name BERTTokenClassification --version 1.0 --serialized-file Transformer_model/traced_model.pt --handler ./Transformer_handler_generalized.py --extra-files "./setup_config.json,./Token_classification_artifacts/index_to_name.json"
torch-model-archiver --model-name BERTTokenClassification --version 1.0 --serialized-file Transformer_model/traced_model.pt --handler ./Transformer_handler_generalized.py --config-file model-config.yaml --extra-files "./Token_classification_artifacts/index_to_name.json"
```

### Register the model

```
mkdir model_store
mv BERTTokenClassification.mar model_store
torchserve --start --model-store model_store --models my_tc=BERTTokenClassification.mar --ncs
torchserve --start --model-store model_store --models my_tc=BERTTokenClassification.mar --disable-token --ncs
```

### Run an inference
Expand All @@ -158,63 +173,62 @@ To get an explanation: `curl -X POST http://127.0.0.1:8080/explanations/my_tc -T

## Question Answering

Change `setup_config.json` to
Change the `handler` section in `model-config.yaml` to
```
{
"model_name":"distilbert-base-cased-distilled-squad",
"mode":"question_answering",
"do_lower_case":true,
"num_labels":"0",
"save_mode":"pretrained",
"max_length":"128",
"captum_explanation":true,
"FasterTransformer":false,
"embedding_name": "distilbert"
}
handler:
model_name: distilbert-base-cased-distilled-squad
mode: question_answering
do_lower_case: true
num_labels: 0
save_mode: pretrained
max_length: 150
captum_explanation: true
embedding_name: distilbert
BetterTransformer: false
model_parallel: false
```

```
rm -r Transformer_model
python Download_Transformer_models.py
```

### Create model archive eager mode
### Create model archive for eager mode or torch.compile
```
torch-model-archiver --model-name BERTQA --version 1.0 --serialized-file Transformer_model/pytorch_model.bin --handler ./Transformer_handler_generalized.py --extra-files "Transformer_model/config.json,./setup_config.json"
torch-model-archiver --model-name BERTQA --version 1.0 --serialized-file Transformer_model/model.safetensors --handler ./Transformer_handler_generalized.py --config-file model-config.yaml --extra-files "Transformer_model/config.json"
```

### Create model archive Torchscript mode
```
torch-model-archiver --model-name BERTQA --version 1.0 --serialized-file Transformer_model/traced_model.pt --handler ./Transformer_handler_generalized.py --extra-files "./setup_config.json"
torch-model-archiver --model-name BERTQA --version 1.0 --serialized-file Transformer_model/traced_model.pt --handler ./Transformer_handler_generalized.py --config-file model-config.yaml
```

### Register the model

```
mkdir model_store
mv BERTQA.mar model_store
torchserve --start --model-store model_store --models my_tc=BERTQA.mar --ncs
torchserve --start --model-store model_store --models my_tc=BERTQA.mar --disable-token --ncs
```
### Run an inference
To run an inference: `curl -X POST http://127.0.0.1:8080/predictions/my_tc -T QA_artifacts/sample_text_captum_input.txt`
To get an explanation: `curl -X POST http://127.0.0.1:8080/explanations/my_tc -T QA_artifacts/sample_text_captum_input.txt`

## Text Generation

Change `setup_config.json` to

Change the `handler` section in `model-config.yaml` to
```
{
"model_name":"gpt2",
"mode":"text_generation",
"do_lower_case":true,
"num_labels":"0",
"save_mode":"pretrained",
"max_length":"150",
"captum_explanation":true,
"FasterTransformer":false,
"embedding_name": "gpt2"
}
handler:
model_name: gpt2
mode: text_generation
do_lower_case: true
num_labels: 0
save_mode: pretrained
max_length: 150
captum_explanation: true
embedding_name: gpt2
BetterTransformer: false
model_parallel: false
```

```
Expand All @@ -225,13 +239,13 @@ python Download_Transformer_models.py
### Create model archive eager mode

```
torch-model-archiver --model-name Textgeneration --version 1.0 --serialized-file Transformer_model/pytorch_model.bin --handler ./Transformer_handler_generalized.py --extra-files "Transformer_model/config.json,./setup_config.json"
torch-model-archiver --model-name Textgeneration --version 1.0 --serialized-file Transformer_model/model.safetensors --handler ./Transformer_handler_generalized.py --config-file model-config.yaml --extra-files "Transformer_model/config.json"
```

### Create model archive Torchscript mode

```
torch-model-archiver --model-name Textgeneration --version 1.0 --serialized-file Transformer_model/traced_model.pt --handler ./Transformer_handler_generalized.py --extra-files "./setup_config.json"
torch-model-archiver --model-name Textgeneration --version 1.0 --serialized-file Transformer_model/traced_model.pt --handler ./Transformer_handler_generalized.py --config-file model-config.yaml
```

### Register the model
Expand All @@ -241,7 +255,7 @@ To register the model on TorchServe using the above model archive file, we run t
```
mkdir model_store
mv Textgeneration.mar model_store/
torchserve --start --model-store model_store --models my_tc=Textgeneration.mar --ncs
torchserve --start --model-store model_store --models my_tc=Textgeneration.mar --disable-token --ncs
```

### Run an inference
Expand All @@ -258,7 +272,7 @@ For batch inference the main difference is that you need set the batch size whil
```
mkdir model_store
mv BERTSeqClassification.mar model_store/
torchserve --start --model-store model_store --ncs
torchserve --start --model-store model_store --disable-token --ncs
curl -X POST "localhost:8081/models?model_name=BERTSeqClassification&url=BERTSeqClassification.mar&batch_size=4&max_batch_delay=5000&initial_workers=3&synchronous=true"
```
Expand Down Expand Up @@ -316,36 +330,35 @@ When a json file is passed as a request format to the curl, Torchserve unwraps t

## Speed up inference with Better Transformer (Flash Attentions/ Xformer Memory Efficient kernels)

In the setup_config.json, specify `"BetterTransformer":true,`.
In the `model-config.yaml`, specify `"BetterTransformer":true,`.


[Better Transformer(Accelerated Transformer)](https://pytorch.org/blog/a-better-transformer-for-fast-transformer-encoder-inference/) from PyTorch is integrated into [Huggingface Optimum](https://huggingface.co/docs/optimum/bettertransformer/overview) that bring major speedups for many of encoder models on different modalities (text, image, audio). It is a one liner API that we have also added in the `Transformer_handler_generalized.py` in this example as well. That as shown above you just need to set `"BetterTransformer":true,` in the setup_config.json.
[Better Transformer(Accelerated Transformer)](https://pytorch.org/blog/a-better-transformer-for-fast-transformer-encoder-inference/) from PyTorch is integrated into [Huggingface Optimum](https://huggingface.co/docs/optimum/bettertransformer/overview) that bring major speedups for many of encoder models on different modalities (text, image, audio). It is a one liner API that we have also added in the `Transformer_handler_generalized.py` in this example as well. That as shown above you just need to set `"BetterTransformer":true,` in the `model-config.yaml`.

Main speed ups in the Better Transformer comes from kernel fusion in the [TransformerEncoder] (https://pytorch.org/docs/stable/generated/torch.nn.TransformerEncoder.html) and making use of sparsity with [nested tensors](https://pytorch.org/tutorials/prototype/nestedtensor.html) when input sequences are padded to avoid unnecessary computation on padded tensors. We have seen up to 4.5x speed up with distill_bert when used higher batch sizes with padding. Please read more about it in this [blog post](https://medium.com/pytorch/bettertransformer-out-of-the-box-performance-for-huggingface-transformers-3fbe27d50ab2). You get some speedups even with Batch size = 1 and no padding however, major speed ups will show up when running inference with higher batch sizes (8.16,32) with padding.

The Accelerated Transformer integration with HuggingFace also added the support for decoder models, please read more about it [here](https://pytorch.org/blog/out-of-the-box-acceleration/). This adds the native support for Flash Attentions and Xformer Memory Efficient kernels in PyTorch and make it available on HuggingFace deocder models. This will brings significant speed up and memory savings with just one line of the code as before.
The Accelerated Transformer integration with HuggingFace also added the support for decoder models, please read more about it [here](https://pytorch.org/blog/out-of-the-box-acceleration/). This adds the native support for Flash Attentions and Xformer Memory Efficient kernels in PyTorch and make it available on HuggingFace decoder models. This will brings significant speed up and memory savings with just one line of the code as before.


## Model Parallelism

[Parallelize] (https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2Model.parallelize) is a an experimental feature that HuggingFace recently added to support large model inference for some very large models, GPT2 and T5. GPT2 model choices based on their size are gpt2-medium, gpt2-large, gpt2-xl. This feature only supports LMHeadModel that could be used for text generation, other application such as sequence, token classification and question answering are not supported. We have added parallelize support for GPT2 model in the custom handler in this example that will enable you to perform model parallel inference for GPT2 models used for text generation. The same logic in the handler can be extended to T5 and the applications it supports. Make sure that you register your model with one worker using this feature. To run this example, a machine with #gpus > 1 is required. The number of required gpus depends on the size of the model. This feature only supports single node, one machine with multi-gpus.

Change `setup_config.json` to

Change the `handler` section in `model-config.yaml` to
```
{
"model_name":"gpt2",
"mode":"text_generation",
"do_lower_case":true,
"num_labels":"0",
"save_mode":"pretrained",
"max_length":"150",
"captum_explanation":true,
"embedding_name": "gpt2",
"FasterTransformer":false,
"model_parallel":true
}
handler:
model_name: gpt2
mode: text_generation
do_lower_case: true
num_labels: 0
save_mode: pretrained
max_length: 150
captum_explanation: true
embedding_name: gpt2
BetterTransformer: false
model_parallel: true
```

```
rm -r Transformer_model
python Download_Transformer_models.py
Expand All @@ -364,7 +377,7 @@ To register the model on TorchServe using the above model archive file, we run t
```
mkdir model_store
mv Textgeneration.mar model_store/
torchserve --start --model-store model_store
torchserve --start --model-store model_store --disable-token
curl -X POST "localhost:8081/models?model_name=Textgeneration&url=Textgeneration.mar&batch_size=1&max_batch_delay=5000&initial_workers=1&synchronous=true"
```

Expand Down
Loading

0 comments on commit affdcdd

Please sign in to comment.