
Torchserve support for Intel GPUs #3132

Merged 37 commits on Jun 25, 2024
Changes from 30 commits
292d6ad
Merge changes from intel-sandbox/serve
krish-navulla May 1, 2024
0f81bdf
Merge branch 'pytorch:master' into xpu-enabled
krish-navulla May 2, 2024
88bea0d
ipex_gpu_enable - New config in config.properties
krish-navulla May 2, 2024
ee159f1
Instructions for IPEX GPU support
krish-navulla May 3, 2024
bbbb626
Final Commits 1
krish-navulla May 3, 2024
d7f0c8f
Style: Refactor code formatting
krish-navulla May 3, 2024
881572b
Readme Updated
krish-navulla May 3, 2024
e5f3e6a
Code Refactoring
krish-navulla May 3, 2024
e91db65
Code Refactoring
May 3, 2024
57a1ff6
Merge branch 'pytorch:master' into xpu-enabled
krish-navulla May 3, 2024
f308eef
Merge branch 'master' into xpu-enabled
krish-navulla May 3, 2024
78bd30c
Merge branch 'pytorch:master' into xpu-enabled
krish-navulla May 3, 2024
45f971f
Final Commit
krish-navulla May 3, 2024
a603fe1
Merge branch 'master' into xpu-enabled
krish-navulla May 4, 2024
c5179a2
Merge branch 'master' into xpu-enabled
krish-navulla May 6, 2024
f8e539e
Merge branch 'pytorch:master' into xpu-enabled
krish-navulla May 10, 2024
c5b2dbf
Merge branch 'pytorch:master' into xpu-enabled
krish-navulla May 10, 2024
ae03184
self.device mapping to XPU
krish-navulla May 14, 2024
52289dd
Merge branch 'pytorch:master' into xpu-enabled
krish-navulla May 14, 2024
d64f314
Code Refactoring
krish-navulla May 14, 2024
a4564da
Mulitple GPU device engagement enabled
krish-navulla May 15, 2024
27d705c
Merge branch 'master' into xpu-enabled
krish-navulla May 21, 2024
9188deb
Remove unused changes
Kanya-Mo May 29, 2024
e3be79c
Revert "Remove unused changes"
Kanya-Mo Jun 6, 2024
5fe7645
Add performance gain info for GPU
anupren Jun 14, 2024
e029268
Update README.md
anupren Jun 14, 2024
a8bdb47
Add units to table
anupren Jun 14, 2024
ee5187d
Update metric reading configuration.
Kanya-Mo Jun 19, 2024
578dfab
Update system metrics script path.
Kanya-Mo Jun 20, 2024
9f77a01
Merge branch 'master' into xpu-enabled
Kanya-Mo Jun 21, 2024
e48f2c1
Update ConfigManager.java
Kanya-Mo Jun 21, 2024
bca4012
Reformat ConfigManager.java
Kanya-Mo Jun 22, 2024
77f4495
Merge branch 'master' into xpu-enabled
Kanya-Mo Jun 22, 2024
7c4a69d
Merge branch 'master' into xpu-enabled
agunapal Jun 22, 2024
86f7ced
Fix spelling issues
Kanya-Mo Jun 24, 2024
927483d
Fix lint changed file.
Kanya-Mo Jun 24, 2024
fd20a57
Merge branch 'master' into xpu-enabled
Kanya-Mo Jun 24, 2024
90 changes: 90 additions & 0 deletions examples/intel_extension_for_pytorch/README.md
@@ -9,6 +9,8 @@ Here we show how to use TorchServe with Intel® Extension for PyTorch*.
* [Install Intel® Extension for PyTorch*](https://github.com/pytorch/serve/blob/master/examples/intel_extension_for_pytorch/README.md#install-intel-extension-for-pytorch)
* [Serving model with Intel® Extension for PyTorch*](https://github.com/pytorch/serve/blob/master/examples/intel_extension_for_pytorch/README.md#serving-model-with-intel-extension-for-pytorch)
* [TorchServe with Launcher](#torchserve-with-launcher)
* [TorchServe with Intel® Extension for PyTorch* and Intel GPUs](#torchserve-with-intel-extension-for-pytorch-and-intel-gpus)
* [Performance Gain with Intel® Extension for PyTorch* and Intel GPU](https://github.com/pytorch/serve/blob/master/examples/intel_extension_for_pytorch/README.md#performance-gain-with-intel-extension-for-pytorch-and-intel-gpu)
* [Creating and Exporting INT8 model for Intel® Extension for PyTorch*](https://github.com/pytorch/serve/blob/master/examples/intel_extension_for_pytorch/README.md#creating-and-exporting-int8-model-for-intel-extension-for-pytorch)
* [Benchmarking with Launcher](#benchmarking-with-launcher)
* [Performance Boost with Intel® Extension for PyTorch* and Launcher](https://github.com/pytorch/serve/blob/master/examples/intel_extension_for_pytorch/README.md#performance-boost-with-intel-extension-for-pytorch-and-launcher)
@@ -73,6 +75,7 @@ CPU usage is shown below. 4 main worker threads were launched, each launching 14
![26](https://user-images.githubusercontent.com/93151422/170373651-fd8a0363-febf-4528-bbae-e1ddef119358.gif)



#### Scaling workers
Additionally when dynamically [scaling the number of workers](https://pytorch.org/serve/management_api.html#scale-workers), cores that were pinned to killed workers by the launcher could be left unutilized. To address this problem, launcher internally restarts the workers to re-distribute cores that were pinned to killed workers to the remaining, alive workers. This is taken care internally, so users do not have to worry about this.

@@ -90,6 +93,93 @@ Add the following lines in `config.properties` to use launcher with its default
cpu_launcher_enable=true
```

## TorchServe with Intel® Extension for PyTorch* and Intel GPUs

TorchServe can also leverage Intel GPUs for acceleration, providing additional performance benefits. To use TorchServe with an Intel GPU, the machine must have the latest Intel oneAPI Base Kit installed and activated, and the Intel® Extension for PyTorch* (IPEX) GPU package installed.


### Installation and Setup for Intel GPU Support
**Install Intel oneAPI Base Kit:**
Follow the installation instructions for your operating system from the [Intel oneAPI Base Kit Installation](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit-download.htm) page.

**Install the IPEX GPU package to enable TorchServe to utilize Intel GPUs for acceleration:**
Follow the installation instructions for your operating system from the [Intel® Extension for PyTorch* XPU/GPU Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu).

**Activate the Intel oneAPI Base Kit:**
Activate the Intel oneAPI Base Kit using the following command:
```bash
source /path/to/oneapi/setvars.sh
```

**Install xpu-smi:**
Install xpu-smi so that TorchServe can detect the number of Intel GPU devices present. xpu-smi also reports information about Intel GPUs, including temperature, utilization, and other metrics. See the [xpu-smi Installation Guide](https://dgpu-docs.intel.com/driver/installation.html#ubuntu-package-repository).

**Enable Intel GPU Support in TorchServe:**
To enable TorchServe to use Intel GPUs, set the following configuration in `config.properties`:
```
ipex_enable=true
ipex_gpu_enable=true
```
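TorchServe forwards `ipex_gpu_enable` to backend workers as the `TS_IPEX_GPU_ENABLE` environment variable (see `getBackendConfiguration` in `ConfigManager.java`). A minimal sketch of the resulting device-selection order in the base handler is shown below; the helper name is illustrative, and the hardware-availability flags are injected as parameters so the sketch runs without GPU hardware:

```python
import os


def select_device(gpu_id, cuda_available, xpu_available):
    """Pick the torch device string a worker would use.

    Mirrors the selection order in base_handler.py: CUDA first, then XPU
    (only when TS_IPEX_GPU_ENABLE is set), otherwise CPU.
    """
    ipex_gpu = os.environ.get("TS_IPEX_GPU_ENABLE", "false") == "true"
    if cuda_available and gpu_id is not None:
        return f"cuda:{gpu_id}"
    if xpu_available and gpu_id is not None and ipex_gpu:
        return f"xpu:{gpu_id}"
    return "cpu"


os.environ["TS_IPEX_GPU_ENABLE"] = "true"
print(select_device(0, cuda_available=False, xpu_available=True))  # xpu:0
```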
To enable metric reading, additionally set:
```
system_metrics_cmd=${PATH to examples/intel_extension_for_pytorch/intel_gpu_metric_collector.py} --gpu ${Number of GPUs}
```
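The metric collector relies on `xpu-smi` output. As an illustration only (the actual `intel_gpu_metric_collector.py` ships with this PR, and the exact `xpu-smi dump` column set depends on the metrics requested), here is a sketch of parsing a CSV-style dump into per-device utilization; the sample text is hypothetical:

```python
import csv
import io

# Hypothetical sample in the style of `xpu-smi dump` CSV output;
# the real columns depend on the requested metric IDs.
SAMPLE = """Timestamp, DeviceId, GPU Utilization (%), GPU Memory Utilization (%)
06:14:46.000, 0, 83.50, 42.10
06:14:46.000, 1, 12.00, 7.25
"""


def parse_utilization(text):
    """Return {device_id: gpu_utilization} from CSV-style dump output."""
    reader = csv.reader(io.StringIO(text))
    header = [c.strip() for c in next(reader)]
    dev_col = header.index("DeviceId")
    util_col = header.index("GPU Utilization (%)")
    return {int(row[dev_col]): float(row[util_col]) for row in reader if row}


print(parse_utilization(SAMPLE))  # {0: 83.5, 1: 12.0}
```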

## Performance Gain with Intel® Extension for PyTorch* and Intel GPU

To measure the performance gain from using an Intel GPU, the TorchServe-recommended [Apache Bench benchmark](https://github.com/pytorch/serve/tree/master/benchmarks#benchmarking-with-apache-bench) is executed on the FastRCNN FP32 model.

A `model_config.json` file is created, and the following parameters are added:

```
{
"url": "https://torchserve.pytorch.org/mar_files/fastrcnn.mar",
"requests": "10000",
"concurrency": "100",
"workers": "1",
"batch_delay": "100",
"batch_size": "1",
"input": "../examples/image_classifier/kitten.jpg",
"backend_profiling": "FALSE",
"exec_env": "local"
}
```

The batch size can be changed according to your requirements.

The following lines are added to `config.properties` to utilize IPEX and the Intel GPU:

```
ipex_enable=true
ipex_gpu_enable=true
```

To reproduce the test, use the following command:

```
python benchmark-ab.py --config model_config.json --config_properties config.properties
```

This test is performed on a server with an Intel(R) Core(TM) i5-9600K CPU and Intel(R) Arc(TM) A770 Graphics, and the results are compared with an Intel(R) Xeon(R) Gold 6438Y CPU server.
It is recommended to use only one worker per GPU; running more than one worker per GPU is not validated and may degrade performance due to context switching.
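Putting the pieces together, a hedged `config.properties` for this single-GPU benchmark could look like the following (`default_workers_per_model` is a standard TorchServe property; the value of 1 follows the one-worker-per-GPU recommendation above):

```
ipex_enable=true
ipex_gpu_enable=true
default_workers_per_model=1
```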


| Model | Batch size | GPU Throughput(img/sec) | CPU Throughput(img/sec) | GPU TS Latency mean(ms) | CPU TS Latency mean(ms) | Throughput speedup ratio | Latency speedup ratio |
|:-----:|:----------:|:--------------:|:--------------:|:-------------------:|:-------------------:|:-------------------------:|:----------------------:|
| FastRCNN_FP32 | 1 | 15.74 | 2.89 | 6352.388 | 34636.68 | 5.45 | 5.45 |
| | 2 | 17.69 | 2.67 | 5651.999 | 37520.781 | 6.63 | 6.64 |
| | 4 | 18.57 | 2.39 | 5385.389 | 41886.902 | 7.77 | 7.78 |
| | 8 | 18.68 | 2.32 | 5354.58 | 43146.797 | 8.05 | 8.06 |
| | 16 | 19.26 | 2.39 | 5193.307 | 41903.752 | 8.06 | 8.07 |
| | 32 | 19.06 | 2.49 | 5245.912 | 40172.39 | 7.65 | 7.66 |
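The two speedup columns are straightforward ratios of the measured values; a quick check for the batch-size-1 row:

```python
# Batch size 1 values taken from the table above.
throughput_a, throughput_b = 15.74, 2.89    # img/sec
latency_a, latency_b = 6352.388, 34636.68   # ms (TS latency mean)

throughput_speedup = throughput_a / throughput_b
latency_speedup = latency_b / latency_a     # lower latency is better

print(round(throughput_speedup, 2))  # 5.45
print(round(latency_speedup, 2))     # 5.45
```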

<p align="center">
<img src="https://github.com/pytorch/serve/assets/113945574/c30aeacc-9825-42b1-bde8-2d9dca09bb8a" />
</p>
The graph above shows the throughput and latency speedup ratios when using the Intel GPU. The speedup ratio increases steadily, reaching almost 8x at batch size 8, and yields diminishing returns beyond that. Increasing the batch size further to 64 results in a `RuntimeError: Native API failed. Native API returns: -5 (PI_ERROR_OUT_OF_RESOURCES)` error because the GPU is overloaded.

Note: The optimal configuration will vary depending on the hardware used.

## Creating and Exporting INT8 model for Intel® Extension for PyTorch*
Intel® Extension for PyTorch* supports both eager and torchscript mode. In this section, we show how to deploy INT8 model for Intel® Extension for PyTorch*. Refer to [here](https://github.com/intel/intel-extension-for-pytorch/blob/master/docs/tutorials/features/int8_overview.md) for more details on Intel® Extension for PyTorch* optimizations for quantization.

@@ -79,6 +79,7 @@ public final class ConfigManager {
private static final String TS_IPEX_ENABLE = "ipex_enable";
private static final String TS_CPU_LAUNCHER_ENABLE = "cpu_launcher_enable";
private static final String TS_CPU_LAUNCHER_ARGS = "cpu_launcher_args";
private static final String TS_IPEX_GPU_ENABLE = "ipex_gpu_enable";

private static final String TS_ASYNC_LOGGING = "async_logging";
private static final String TS_CORS_ALLOWED_ORIGIN = "cors_allowed_origin";
@@ -161,7 +162,7 @@ public final class ConfigManager {
private static Pattern pattern = Pattern.compile("\\$\\$([^$]+[^$])\\$\\$");

private Pattern blacklistPattern;
private Properties prop;

private boolean snapshotDisabled;

@@ -272,6 +273,7 @@ private ConfigManager(Arguments args) throws IOException {
getAvailableGpu(),
getIntProperty(TS_NUMBER_OF_GPU, Integer.MAX_VALUE))));


String pythonExecutable = args.getPythonExecutable();
if (pythonExecutable != null) {
prop.setProperty(PYTHON_EXECUTABLE, pythonExecutable);
@@ -473,6 +475,10 @@ public String getCPULauncherArgs() {
return getProperty(TS_CPU_LAUNCHER_ARGS, null);
}

public boolean isIPEXGpuEnabled() {
return Boolean.parseBoolean(getProperty(TS_IPEX_GPU_ENABLE, "false"));
}

public boolean getDisableTokenAuthorization() {
return Boolean.parseBoolean(getProperty(TS_DISABLE_TOKEN_AUTHORIZATION, "false"));
}
@@ -490,8 +496,9 @@ public int getJobQueueSize() {
}

public int getNumberOfGpu() {
return getIntProperty(TS_NUMBER_OF_GPU, 0);
}

public boolean getModelControlMode() {
return Boolean.parseBoolean(getProperty(MODEL_CONTROL_MODE, "false"));
@@ -647,7 +654,7 @@ public String getCertificateFile() {
public String getSystemMetricsCmd() {
return prop.getProperty(SYSTEM_METRICS_CMD, "");
}

public SslContext getSslContext() throws IOException, GeneralSecurityException {
List<String> supportedCiphers =
Arrays.asList(
@@ -902,6 +909,7 @@ public HashMap<String, String> getBackendConfiguration() {
// Append properties used by backend worker here
config.put("TS_DECODE_INPUT_REQUEST", prop.getProperty(TS_DECODE_INPUT_REQUEST, "true"));
config.put("TS_IPEX_ENABLE", prop.getProperty(TS_IPEX_ENABLE, "false"));
config.put("TS_IPEX_GPU_ENABLE", prop.getProperty(TS_IPEX_GPU_ENABLE, "false"));
return config;
}

@@ -922,6 +930,7 @@ private static String getCanonicalPath(String path) {

private static int getAvailableGpu() {
try {
List<Integer> gpuIds = new ArrayList<>();
String visibleCuda = System.getenv("CUDA_VISIBLE_DEVICES");
if (visibleCuda != null && !visibleCuda.isEmpty()) {
@@ -953,7 +962,10 @@ private static int getAvailableGpu() {
// No MPS devices detected
return 0;
} else {
try {
Process process =
Runtime.getRuntime().exec("nvidia-smi --query-gpu=index --format=csv");
int ret = process.waitFor();
if (ret != 0) {
@@ -967,13 +979,35 @@ private static int getAvailableGpu() {
for (int i = 1; i < list.size(); i++) {
gpuIds.add(Integer.parseInt(list.get(i)));
}
}
} catch (IOException | InterruptedException e) {
System.out.println("nvidia-smi not available or failed: " + e.getMessage());
}
try {
Process process = Runtime.getRuntime().exec("xpu-smi discovery --dump 1");
int ret = process.waitFor();
if (ret != 0) {
return 0;
}
List<String> list =
IOUtils.readLines(process.getInputStream(), StandardCharsets.UTF_8);
if (list.isEmpty() || !list.get(0).contains("Device ID")) {
throw new AssertionError("Unexpected xpu-smi response.");
}
for (int i = 1; i < list.size(); i++) {
gpuIds.add(Integer.parseInt(list.get(i)));
}
} catch (IOException | InterruptedException e) {
System.out.println("xpu-smi not available or failed: " + e.getMessage());
}
}
return gpuIds.size();
} catch (IOException | InterruptedException e) {
return 0;
}
}

public List<String> getAllowedUrls() {
String allowedURL = prop.getProperty(TS_ALLOWED_URLS, DEFAULT_TS_ALLOWED_URLS);
18 changes: 12 additions & 6 deletions ts/torch_handler/base_handler.py
@@ -20,6 +20,7 @@
load_label_mapping,
)


if packaging.version.parse(torch.__version__) >= packaging.version.parse("1.8.1"):
from torch.profiler import ProfilerActivity, profile, record_function

@@ -69,7 +70,6 @@
if os.environ.get("TS_IPEX_ENABLE", "false") == "true":
try:
import intel_extension_for_pytorch as ipex

IPEX_AVAILABLE = True
except ImportError as error:
logger.warning(
@@ -79,7 +79,7 @@
else:
IPEX_AVAILABLE = False


try:
import onnxruntime as ort
import psutil
@@ -147,17 +147,22 @@ def initialize(self, context):
RuntimeError: Raises the Runtime error when the model.py is missing

"""

if context is not None and hasattr(context, "model_yaml_config"):
self.model_yaml_config = context.model_yaml_config

properties = context.system_properties

if torch.cuda.is_available() and properties.get("gpu_id") is not None:
self.map_location = "cuda"
self.device = torch.device(
self.map_location + ":" + str(properties.get("gpu_id"))
)
elif torch.xpu.is_available() and properties.get("gpu_id") is not None and os.environ.get("TS_IPEX_GPU_ENABLE", "false") == "true":
self.map_location = "xpu"
self.device = torch.device(
self.map_location + ":" + str(properties.get("gpu_id"))
)
torch.xpu.set_device(self.device)  # make this worker's XPU the active device
elif torch.backends.mps.is_available() and properties.get("gpu_id") is not None:
self.map_location = "mps"
self.device = torch.device("mps")
@@ -273,6 +278,7 @@ def initialize(self, context):

elif IPEX_AVAILABLE:
self.model = self.model.to(memory_format=torch.channels_last)
self.model = self.model.to(self.device)
self.model = ipex.optimize(self.model)
logger.info(f"Compiled model with ipex")

@@ -560,4 +566,4 @@ def get_device(self):
Returns:
string : self device
"""
return self.device