Add support for hpu_backend and Resnet50 compile example (#3182)
* Add support for hpu_backend and Resnet50 compile example

* Add skip_torch_install flag to install_dependencies.py

* Explanation of the PT_HPU_LAZY_MODE flag

* hpu perf RN50

* fix typo, update config desc

* clarify compile desc

---------

Co-authored-by: Ankith Gunapal <agunapal@ischool.Berkeley.edu>
Co-authored-by: Rafal Litka <rafal.litka@intel.com>
3 people committed Jun 24, 2024
1 parent 5f3df71 commit 42e3fb1
Showing 7 changed files with 182 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/README.md
@@ -41,6 +41,7 @@ TorchServe is a performant, flexible and easy to use tool for serving PyTorch ea
- [TorchServe Integrations](../examples/README.md#torchserve-integrations)
- [TorchServe UseCases](../examples/README.md#usecases)
* [Workflow Examples](https://github.com/pytorch/serve/tree/master/examples/Workflows) - Examples of how to compose models in a workflow with TorchServe
* [Resnet50 HPU compile](../examples/pt2/torch_compile_hpu/README.md) - An example of how to run the model in compile mode with the HPU device

## Advanced Features

142 changes: 142 additions & 0 deletions examples/pt2/torch_compile_hpu/README.md
@@ -0,0 +1,142 @@

# TorchServe Inference with torch.compile using the HPU backend for a ResNet50 model

This guide provides steps on how to optimize a ResNet50 model using `torch.compile` with the [HPU backend](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Getting_Started_with_Inference.html), aiming to enhance inference performance when deployed through TorchServe. `torch.compile` allows for JIT compilation of Python code into optimized kernels with a simple API.
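Outside of TorchServe, the same backend is selected directly through the `torch.compile` API. The following is only an illustrative sketch, not part of this example's files; it assumes a Gaudi device, the Habana PyTorch packages, and running with `PT_HPU_LAZY_MODE=0` (explained in the model-archive step below):

```python
import torch
import habana_frameworks.torch.core as htcore  # registers the "hpu" device (requires the Habana PyTorch build)
from torchvision.models import resnet50

# Sketch: move the model to the Gaudi device, then compile it with the HPU backend.
model = resnet50().eval().to("hpu")
model = torch.compile(model, backend="hpu_backend")

with torch.no_grad():
    output = model(torch.randn(1, 3, 224, 224).to("hpu"))
print(output.shape)  # torch.Size([1, 1000])
```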

### Prerequisites and installation
First, install the `Intel® Gaudi® AI accelerator software for PyTorch`. See the [Installation_Guide](https://docs.habana.ai/en/latest/Installation_Guide/index.html), which covers the installation procedure, including software verification and the subsequent steps for software installation and management.

Then install the TorchServe dependencies with the `--skip_torch_install` flag so that the Habana PyTorch build you already have installed is not overwritten. Finally, install `torchserve`, `torch-model-archiver`, and `torch-workflow-archiver` as shown below.

```bash
python ./ts_scripts/install_dependencies.py --skip_torch_install

# Latest release
pip install torchserve torch-model-archiver torch-workflow-archiver
```
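Optionally, you can run a quick sanity check that the serving packages landed in the active environment (not part of the original steps; assumes `pip` points at the same environment):

```bash
# Should report the installed versions without touching the Habana torch build
torchserve --version
pip show torchserve torch-model-archiver torch-workflow-archiver
```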

## Workflow
1. Configure torch.compile.
2. Create model archive.
3. Start TorchServe.
4. Run Inference.
5. Stop TorchServe.

First, navigate to `examples/pt2/torch_compile_hpu`
```bash
cd examples/pt2/torch_compile_hpu
```

### 1. Configure torch.compile

`torch.compile` allows various configurations that can influence performance outcomes. Explore the different options in the [official PyTorch documentation](https://pytorch.org/docs/stable/generated/torch.compile.html).


In this example, we use the following config, which is provided in the `model-config.yaml` file:

```bash
echo "minWorkers: 1
maxWorkers: 1
pt2:
compile:
enable: True
backend: hpu_backend" > model-config.yaml
```
Using this configuration activates compile mode. Eager mode can be enabled by setting `enable: False` or by removing the whole `pt2:` section.
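For comparison, an eager-mode variant of the same file would simply omit the `pt2:` section (a sketch, keeping the worker settings unchanged):

```bash
echo "minWorkers: 1
maxWorkers: 1" > model-config.yaml
```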

### 2. Create model archive

Download the pre-trained model and prepare the model archive:
```bash
wget https://download.pytorch.org/models/resnet50-11ad3fa6.pth
mkdir model_store
PT_HPU_LAZY_MODE=0 torch-model-archiver --model-name resnet-50 --version 1.0 --model-file model.py \
--serialized-file resnet50-11ad3fa6.pth --export-path model_store \
--extra-files ../../image_classifier/index_to_name.json --handler hpu_image_classifier.py \
--config-file model-config.yaml
```
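If the command succeeds, a `.mar` archive named after `--model-name` should appear in the export path (a quick check, not required for the next steps):

```bash
ls model_store
# resnet-50.mar
```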

`PT_HPU_LAZY_MODE=0` selects `eager+torch.compile` mode. The Gaudi integration with PyTorch officially supports two modes of operation: `lazy` and `eager+torch.compile` (beta). Since `lazy` is currently the default, this flag is required for compile mode until `eager+torch.compile` becomes the default. [More information](https://docs.habana.ai/en/latest/PyTorch/Reference/Runtime_Flags.html#pytorch-runtime-flags)
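If you prefer not to prefix each command, the flag can instead be exported once for the shell session (same effect, purely a convenience):

```bash
# Every subsequent torch-model-archiver / torchserve call in this shell then runs in eager+torch.compile mode
export PT_HPU_LAZY_MODE=0
```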

### 3. Start TorchServe

Start the TorchServe server using the following command:
```bash
PT_HPU_LAZY_MODE=0 torchserve --start --ncs --model-store model_store --models resnet-50.mar
```
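Before sending requests, you can optionally confirm the model registered via TorchServe's management API (default port 8081; adjust if you have changed it in your configuration):

```bash
curl http://127.0.0.1:8081/models
# Expect "resnet-50" to be listed once the worker is up
```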

### 4. Run Inference

**Note:** `torch.compile` requires a warm-up phase to reach optimal performance. Ensure you run at least as many inferences as the `maxWorkers` specified before measuring performance.

```bash
# Open a new terminal
cd examples/pt2/torch_compile_hpu
curl http://127.0.0.1:8080/predictions/resnet-50 -T ../../image_classifier/kitten.jpg
```

The expected output will be JSON-formatted classification probabilities, such as:

```json
{
"tabby": 0.2724992632865906,
"tiger_cat": 0.1374046504497528,
"Egyptian_cat": 0.046274710446596146,
"lynx": 0.003206699388101697,
"lens_cap": 0.002257900545373559
}
```
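Because the first requests to a compiled worker include compilation overhead (see the warm-up note above), a simple loop like the one below can be used to warm the model up before measuring anything; the request count of 10 is arbitrary:

```bash
for i in $(seq 1 10); do
  curl -s http://127.0.0.1:8080/predictions/resnet-50 -T ../../image_classifier/kitten.jpg > /dev/null
done
```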

### 5. Stop the server
Stop TorchServe with the following command:

```bash
torchserve --stop
```
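If you exported `PT_HPU_LAZY_MODE` earlier, you may also want to restore the default lazy mode for the rest of your session (optional):

```bash
unset PT_HPU_LAZY_MODE
```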

### 6. Performance improvement from using `torch.compile`

To measure the handler's `preprocess`, `inference`, and `postprocess` times, run the following:

#### Measure inference time with PyTorch eager

```bash
echo "minWorkers: 1
maxWorkers: 1
handler:
profile: true" > model-config.yaml
```

Once the `yaml` file is updated, create the model archive, start TorchServe, and run inference using the steps shown above.
After a few warm-up iterations, we see the following:

```bash
[INFO ] W-9000-resnet-50_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_preprocess.Milliseconds:6.921529769897461|#ModelName:resnet-50,Level:Model|#type:GAUGE|###,1718265363,fe1dcea2-854d-4847-848e-a05e922d456c, pattern=[METRICS]
[INFO ] W-9000-resnet-50_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_inference.Milliseconds:5.218982696533203|#ModelName:resnet-50,Level:Model|#type:GAUGE|###,1718265363,fe1dcea2-854d-4847-848e-a05e922d456c, pattern=[METRICS]
[INFO ] W-9000-resnet-50_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_postprocess.Milliseconds:8.724212646484375|#ModelName:resnet-50,Level:Model|#type:GAUGE|###,1718265363,fe1dcea2-854d-4847-848e-a05e922d456c, pattern=[METRICS]
```

#### Measure inference time with `torch.compile`

```bash
echo "minWorkers: 1
maxWorkers: 1
pt2:
compile:
enable: True
backend: hpu_backend
handler:
profile: true" > model-config.yaml
```

Once the `yaml` file is updated, create the model archive, start TorchServe, and run inference using the steps shown above.
`torch.compile` needs a few inferences to warm up. Once warmed up, we see the following:
```bash
[INFO ] W-9000-resnet-50_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_preprocess.Milliseconds:6.833314895629883|#ModelName:resnet-50,Level:Model|#type:GAUGE|###,1718265582,53da9032-4ad3-49df-8cd4-2d499eea7691, pattern=[METRICS]
[INFO ] W-9000-resnet-50_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_inference.Milliseconds:0.7846355438232422|#ModelName:resnet-50,Level:Model|#type:GAUGE|###,1718265582,53da9032-4ad3-49df-8cd4-2d499eea7691, pattern=[METRICS]
[INFO ] W-9000-resnet-50_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_postprocess.Milliseconds:1.9681453704833984|#ModelName:resnet-50,Level:Model|#type:GAUGE|###,1718265582,53da9032-4ad3-49df-8cd4-2d499eea7691, pattern=[METRICS]
```

### Conclusion

`torch.compile` reduces the handler inference time from 5.22 ms to 0.78 ms in this example, roughly a 6.7x speedup.
18 changes: 18 additions & 0 deletions examples/pt2/torch_compile_hpu/hpu_image_classifier.py
@@ -0,0 +1,18 @@
import habana_frameworks.torch.core as htcore # nopycln: import
import torch

from ts.torch_handler.image_classifier import ImageClassifier


class HPUImageClassifier(ImageClassifier):
    def set_hpu(self):
        self.map_location = "hpu"
        self.device = torch.device(self.map_location)

    def _load_pickled_model(self, model_dir, model_file, model_pt_path):
        """
        This override lets us set the device to hpu and reuse the default base handler without modifying it.
        """
        model = super()._load_pickled_model(model_dir, model_file, model_pt_path)
        self.set_hpu()
        return model
6 changes: 6 additions & 0 deletions examples/pt2/torch_compile_hpu/model-config.yaml
@@ -0,0 +1,6 @@
minWorkers: 1
maxWorkers: 1
pt2:
  compile:
    enable: True
    backend: hpu_backend
6 changes: 6 additions & 0 deletions examples/pt2/torch_compile_hpu/model.py
@@ -0,0 +1,6 @@
from torchvision.models.resnet import Bottleneck, ResNet


class ImageClassifier(ResNet):
    def __init__(self):
        super(ImageClassifier, self).__init__(Bottleneck, [3, 4, 6, 3])
1 change: 1 addition & 0 deletions ts/utils/util.py
@@ -27,6 +27,7 @@ class PT2Backend(str, enum.Enum):
IPEX = "ipex"
TORCHXLA_TRACE_ONCE = "torchxla_trace_once"
OPENVINO = "openvino"
HPU_BACKEND = "hpu_backend"


logger = logging.getLogger(__name__)
8 changes: 8 additions & 0 deletions ts_scripts/install_dependencies.py
@@ -140,6 +140,8 @@ def install_python_packages(self, cuda_version, requirements_file_path, nightly)
            os.system(
                f"pip3 install numpy --pre torch torchvision torchaudio torchtext --index-url https://download.pytorch.org/whl/nightly/{pt_nightly}"
            )
        elif args.skip_torch_install:
            print("Skipping Torch installation")
        else:
            self.install_torch_packages(cuda_version)

@@ -379,6 +381,12 @@ def get_brew_version():
help="Install nightly version of torch package",
)

parser.add_argument(
"--skip_torch_install",
action="store_true",
help="Skip Torch installation",
)

parser.add_argument(
"--force",
action="store_true",
