Asynchronous worker communication and vllm integration (#3146)
* Added dummy async comm worker thread

* First version of async worker in frontend running

* [WIP]Running async worker but requests get corrupted if parallel

* First version running with thread feeding + async predict

* shorten vllm test time

* Added AsyncVLLMEngine

* Extend vllm test with multiple possible prompts

* Batch size =1 and remove stream in test

* Switched vllm examples to async comm and added llama3 example

* Fix typo

* Corrected java file formatting

* Cleanup and silent chatty debug message

* Added multi-gpu support to vllm examples

* fix java format

* Remove debugging messages

* Fix async comm worker test

* Added cl_socket to fixture

* Added multi worker note to vllm example readme

* Disable tests

* Enable async worker comm test

* Debug CI

* Fix python version <= 3.9 issue in async worker

* Renamed async worker test

* Update frontend/server/src/main/java/org/pytorch/serve/wlm/AsyncBatchAggregator.java

Remove job from jobs_in_backend on error

Co-authored-by: Naman Nandan <namankt55@gmail.com>

* Unskip vllm example test

* Clean up async worker code

* Safely remove jobs from jobs_in_backend

* Let worker die if one of the threads in async service dies

* Add description of parallelLevel and parallelType=custom to docs/large_model_inference.md

* Added description of parallelLevel to model-archiver readme.md

* fix typo + added words

* Fix skip condition for vllm example test

---------

Co-authored-by: Naman Nandan <namankt55@gmail.com>
mreso and namannandan authored Jun 22, 2024
1 parent 4c96e6f commit 5f3df71
Showing 28 changed files with 1,267 additions and 186 deletions.
30 changes: 27 additions & 3 deletions docs/large_model_inference.md
@@ -3,6 +3,7 @@
This document explains how TorchServe supports large model serving; here, a large model refers to a model that does not fit into one GPU and therefore needs to be split into multiple partitions over multiple GPUs.
This page is split into the following sections:
- [How it works](#how-it-works)
- [Large Model Inference with vLLM](../examples/large_models/vllm/)
- [Large Model Inference with PiPPy](#pippy-pytorch-native-solution-for-large-model-inference)
- [Large Model Inference with Deep Speed](#deepspeed)
- [Deep Speed MII](#deepspeed-mii)
@@ -11,13 +12,36 @@ This page is split into the following sections:

## How it works?

During deployment of a worker for a large model, TorchServe utilizes [torchrun](https://pytorch.org/docs/stable/elastic/run.html) to set up the distributed environment for model parallel processing. TorchServe can support multiple workers for a large model. By default, TorchServe uses a round-robin algorithm to assign GPUs to a worker on a host. In case of large model inference, the GPUs assigned to each worker are automatically calculated based on the number of GPUs specified in the model_config.yaml. CUDA_VISIBLE_DEVICES is set based on this number.
For GPU inference of smaller models, TorchServe executes a single process per worker, which gets assigned a single GPU.
For large model inference, the model needs to be split over multiple GPUs.
There are different modes to achieve this split, which usually include pipeline parallelism (PP), tensor parallelism (TP), or a combination of these.
Which mode is selected and how the split is implemented depend on the framework being used.
TorchServe allows users to utilize any framework for their model deployment and tries to accommodate the needs of these frameworks through flexible configurations.
Some frameworks require a separate process for each GPU (PiPPy, DeepSpeed), while others require a single process that gets assigned all GPUs (vLLM).
In case multiple processes are required, TorchServe utilizes [torchrun](https://pytorch.org/docs/stable/elastic/run.html) to set up the distributed environment for the worker.
During the setup, `torchrun` will start a new process for each GPU assigned to the worker.
Whether torchrun is utilized depends on the parameter parallelType, which can be set in the `model-config.yaml` to one of the following options:

For instance, suppose there are eight GPUs on a node and one worker needs 4 GPUs (i.e., nproc-per-node=4). In this case, TorchServe would assign CUDA_VISIBLE_DEVICES="0,1,2,3" to worker1 and CUDA_VISIBLE_DEVICES="4,5,6,7" to worker2.
* `pp` - for pipeline parallel
* `tp` - for tensor parallel
* `pptp` - for pipeline + tensor parallel
* `custom`

In addition to this default behavior, TorchServe provides the flexibility for users to specify GPUs for a worker. For instance, if the user sets "deviceIds: [2,3,4,5]" in the [model config YAML file](https://github.com/pytorch/serve/blob/5ee02e4f050c9b349025d87405b246e970ee710b/model-archiver/README.md?plain=1#L164), and nproc-per-node is set to 2, then TorchServe would assign CUDA_VISIBLE_DEVICES="2,3" to worker1 and CUDA_VISIBLE_DEVICES="4,5" to worker2.
The first three options set up the environment using torchrun, while the "custom" option leaves the way of parallelization to the user and assigns all GPUs assigned to a worker to a single process.
The number of assigned GPUs is determined either by the number of processes started by torchrun (i.e., configured through nproc-per-node) OR by the parameter parallelLevel.
This means that parallelLevel should NOT be set if nproc-per-node is set, and vice versa; the sketch below contrasts the two variants.
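
A minimal sketch of the two variants in `model-config.yaml` (values are illustrative; the `torchrun` section follows the convention used in the TorchServe large-model examples):

```yaml
# TorchServe frontend parameters (illustrative)

# Variant A: torchrun-based split (pp/tp/pptp) -- torchrun starts one process per GPU
parallelType: "tp"
torchrun:
  nproc-per-node: 4   # processes (= GPUs) for this worker; do NOT also set parallelLevel

# Variant B: framework-managed split -- a single process gets all GPUs assigned to the worker
# parallelType: "custom"
# parallelLevel: 4    # GPUs assigned to the worker; do NOT also set nproc-per-node
```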

By default, TorchServe uses a round-robin algorithm to assign GPUs to a worker on a host.
In case of large model inference, the GPUs assigned to each worker are automatically calculated based on the number of GPUs specified in the model_config.yaml.
CUDA_VISIBLE_DEVICES is set based on this number.

For instance, suppose there are eight GPUs on a node and one worker needs 4 GPUs (i.e., nproc-per-node=4 OR parallelLevel=4).
In this case, TorchServe would assign CUDA_VISIBLE_DEVICES="0,1,2,3" to worker1 and CUDA_VISIBLE_DEVICES="4,5,6,7" to worker2.

In addition to this default behavior, TorchServe provides the flexibility for users to specify GPUs for a worker. For instance, if the user sets "deviceIds: [2,3,4,5]" in the [model config YAML file](https://github.com/pytorch/serve/blob/5ee02e4f050c9b349025d87405b246e970ee710b/model-archiver/README.md?plain=1#L164), and nproc-per-node (OR parallelLevel) is set to 2, then TorchServe would assign CUDA_VISIBLE_DEVICES="2,3" to worker1 and CUDA_VISIBLE_DEVICES="4,5" to worker2.
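
A hedged sketch of such a configuration (the values below are purely illustrative and mirror the example above):

```yaml
# TorchServe frontend parameters (illustrative)
minWorkers: 2
maxWorkers: 2
deviceIds: [2, 3, 4, 5]   # GPUs this model's workers may use
parallelType: "custom"
parallelLevel: 2          # 2 GPUs per worker -> worker1 gets "2,3", worker2 gets "4,5"
```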

Using the PiPPy integration as an example, the image below illustrates the internals of TorchServe large model inference.
For an example using vLLM, see [this example](../examples/large_models/vllm/).

![ts-lmi-internal](https://github.com/raw/pytorch/serve/master/docs/images/ts-lmi-internal.png)

34 changes: 31 additions & 3 deletions examples/large_models/vllm/Readme.md
@@ -1,10 +1,16 @@
# Example showing inference with vLLM

This folder contains multiple demonstrations showcasing the integration of [vLLM Engine](https://github.com/vllm-project/vllm) with TorchServe, running inference with continuous batching.
vLLM achieves high throughput using PagedAttention. More details can be found [here](https://vllm.ai/)
vLLM achieves high throughput using PagedAttention. More details can be found [here](https://vllm.ai/).
The vLLM integration uses our new asynchronous worker communication mode, which decouples the communication between frontend and backend from running the actual inference.
With this new feature, TorchServe is able to feed incoming requests into the vLLM engine while running the engine asynchronously in the backend.
As long as at least a single request is inside the engine, it will continue to run and asynchronously stream out the results until the request is finished.
New requests are added to the engine in a continuous fashion, similar to the continuous batching mode shown in other examples.
For all examples, distributed inference can be enabled by following the instructions [here](./Readme.md#distributed-inference); the asynchronous mode itself is switched on in the model config, as sketched below.
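
A minimal sketch of the frontend settings that switch a model to the asynchronous mode (values are illustrative; the complete configurations are in each demo's model-config.yaml):

```yaml
# TorchServe frontend parameters (illustrative)
minWorkers: 1
maxWorkers: 1              # a single worker per engine is recommended, see the multi-worker note below
asyncCommunication: true   # enable the asynchronous worker communication mode
```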

- demo1: [Mistral](mistral)
- demo2: [lora](lora)
- demo1: [Meta-Llama3](llama3)
- demo2: [Mistral](mistral)
- demo3: [lora](lora)

### Supported vLLM Configuration
* LLMEngine configuration:
@@ -13,3 +19,25 @@ vLLM [EngineArgs](https://github.com/vllm-project/vllm/blob/258a2c58d08fc7a24255

* Sampling parameters for text generation:
vLLM [SamplingParams](https://github.com/vllm-project/vllm/blob/258a2c58d08fc7a242556120877a89404861fbce/vllm/sampling_params.py#L27) is defined in the JSON format, for example, [prompt.json](lora/prompt.json).

### Distributed Inference
All examples can be easily distributed over multiple GPUs by enabling tensor parallelism in vLLM.
To enable distributed inference, the following additions need to be made to the model-config.yaml of the examples, where 4 is the number of desired GPUs to use for inference:

```yaml
# TorchServe frontend parameters
...
parallelType: "custom"
parallelLevel: 4

handler:
...
vllm_engine_config:
...
tensor_parallel_size: 4
```
### Multi-worker Note:
While this example in theory works with multiple workers, it would distribute the incoming requests in a round-robin fashion, which might lead to non-optimal worker/hardware utilization.
It is therefore advised to use only a single worker per engine and to utilize tensor parallelism to distribute the model over multiple GPUs, as described in the previous section.
This will result in better hardware utilization and inference performance.
111 changes: 51 additions & 60 deletions examples/large_models/vllm/base_vllm_handler.py
@@ -1,9 +1,12 @@
import json
import logging
import pathlib
import time

from vllm import EngineArgs, LLMEngine, SamplingParams
from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
from vllm.lora.request import LoRARequest

from ts.handler_utils.utils import send_intermediate_predict_response
from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)
@@ -21,59 +24,64 @@ def __init__(self):
self.initialized = False

def initialize(self, ctx):
ctx.cache = {}

self.model_dir = ctx.system_properties.get("model_dir")
vllm_engine_config = self._get_vllm_engine_config(
ctx.model_yaml_config.get("handler", {})
)
self.adapters = ctx.model_yaml_config.get("handler", {}).get("adapters", {})
self.vllm_engine = LLMEngine.from_engine_args(vllm_engine_config)

self.vllm_engine = AsyncLLMEngine.from_engine_args(vllm_engine_config)
self.initialized = True

def preprocess(self, requests):
for req_id, req_data in zip(self.context.request_ids.values(), requests):
if req_id not in self.context.cache:
data = req_data.get("data") or req_data.get("body")
if isinstance(data, (bytes, bytearray)):
data = data.decode("utf-8")

prompt = data.get("prompt")
sampling_params = self._get_sampling_params(req_data)
lora_request = self._get_lora_request(req_data)
self.context.cache[req_id] = {
"text_len": 0,
"stopping_criteria": self._create_stopping_criteria(req_id),
}
self.vllm_engine.add_request(
req_id, prompt, sampling_params, lora_request=lora_request
)
async def handle(self, data, context):
start_time = time.time()

return requests
metrics = context.metrics

def inference(self, input_batch):
inference_outputs = self.vllm_engine.step()
results = {}
data_preprocess = await self.preprocess(data)
output = await self.inference(data_preprocess, context)
output = await self.postprocess(output)

for output in inference_outputs:
req_id = output.request_id
results[req_id] = {
"text": output.outputs[0].text[
self.context.cache[req_id]["text_len"] :
],
stop_time = time.time()
metrics.add_time(
"HandlerTime", round((stop_time - start_time) * 1000, 2), None, "ms"
)
return output

async def preprocess(self, requests):
input_batch = []
assert len(requests) == 1, "Expecting batch_size = 1"
for req_data in requests:
data = req_data.get("data") or req_data.get("body")
if isinstance(data, (bytes, bytearray)):
data = data.decode("utf-8")

prompt = data.get("prompt")
sampling_params = self._get_sampling_params(data)
lora_request = self._get_lora_request(data)
input_batch += [(prompt, sampling_params, lora_request)]
return input_batch

async def inference(self, input_batch, context):
logger.debug(f"Inputs: {input_batch[0]}")
prompt, params, lora = input_batch[0]
generator = self.vllm_engine.generate(
prompt, params, context.request_ids[0], lora
)
text_len = 0
async for output in generator:
result = {
"text": output.outputs[0].text[text_len:],
"tokens": output.outputs[0].token_ids[-1],
"finished": output.finished,
}
self.context.cache[req_id]["text_len"] = len(output.outputs[0].text)

return [results[i] for i in self.context.request_ids.values()]

def postprocess(self, inference_outputs):
self.context.stopping_criteria = [
self.context.cache[req_id]["stopping_criteria"]
for req_id in self.context.request_ids.values()
]
text_len = len(output.outputs[0].text)
if not output.finished:
send_intermediate_predict_response(
[json.dumps(result)], context.request_ids, "Result", 200, context
)
return [json.dumps(result)]

async def postprocess(self, inference_outputs):
return inference_outputs

def _get_vllm_engine_config(self, handler_config: dict):
@@ -85,8 +93,8 @@ def _get_vllm_engine_config(self, handler_config: dict):
len(model_path) > 0
), "please define model in vllm_engine_config or model_path in handler"
model = str(pathlib.Path(self.model_dir).joinpath(model_path))
logger.info(f"EngineArgs model={model}")
vllm_engine_config = EngineArgs(model=model)
logger.debug(f"EngineArgs model: {model}")
vllm_engine_config = AsyncEngineArgs(model=model)
self._set_attr_value(vllm_engine_config, vllm_engine_params)
return vllm_engine_config

@@ -104,27 +112,10 @@ def _get_lora_request(self, req_data: dict):
assert len(adapter_path) > 0, f"{adapter_name} misses adapter path"
lora_id = self.lora_ids.setdefault(adapter_name, len(self.lora_ids) + 1)
adapter_path = str(pathlib.Path(self.model_dir).joinpath(adapter_path))
logger.info(f"adapter_path=${adapter_path}")
return LoRARequest(adapter_name, lora_id, adapter_path)

return None

def _clean_up(self, req_id):
del self.context.cache[req_id]

def _create_stopping_criteria(self, req_id):
class StoppingCriteria(object):
def __init__(self, outer, req_id):
self.req_id = req_id
self.outer = outer

def __call__(self, res):
if res["finished"]:
self.outer._clean_up(self.req_id)
return res["finished"]

return StoppingCriteria(outer=self, req_id=req_id)

def _set_attr_value(self, obj, config: dict):
items = vars(obj)
for k, v in config.items():
45 changes: 45 additions & 0 deletions examples/large_models/vllm/llama3/Readme.md
@@ -0,0 +1,45 @@
# Example showing inference with vLLM on Meta-Llama-3 model

This is an example showing how to integrate [vLLM](https://github.com/vllm-project/vllm) with TorchServe and run inference on model `meta-llama/Meta-Llama-3-8B-Instruct` with continuous batching.
This example supports distributed inference by following [these instructions](../Readme.md#distributed-inference).

### Step 1: Download Model from HuggingFace

Log in with a HuggingFace account
```bash
huggingface-cli login
# or using an environment variable
huggingface-cli login --token $HUGGINGFACE_TOKEN
```

```bash
python ../../utils/Download_model.py --model_path model --model_name meta-llama/Meta-Llama-3-8B-Instruct --use_auth_token True
```

### Step 2: Generate model artifacts

Add the downloaded path to "model_path:" in `model-config.yaml` and run the following.

```bash
torch-model-archiver --model-name llama3-8b --version 1.0 --handler ../base_vllm_handler.py --config-file model-config.yaml -r ../requirements.txt --archive-format no-archive
mv model llama3-8b
```

### Step 3: Add the model artifacts to model store

```bash
mkdir model_store
mv llama3-8b model_store
```

### Step 4: Start torchserve

```bash
torchserve --start --ncs --ts-config ../config.properties --model-store model_store --models llama3-8b
```

### Step 5: Run inference

```bash
python ../../utils/test_llm_streaming_response.py -o 50 -t 2 -n 4 --prompt-text "@prompt.json" --prompt-json
```
13 changes: 13 additions & 0 deletions examples/large_models/vllm/llama3/model-config.yaml
@@ -0,0 +1,13 @@
# TorchServe frontend parameters
minWorkers: 1
maxWorkers: 1
maxBatchDelay: 100
responseTimeout: 1200
deviceType: "gpu"
asyncCommunication: true

handler:
model_path: "model/models--meta-llama--Meta-Llama-3-8B-Instruct/snapshots/e1945c40cd546c78e41f1151f4db032b271faeaa/"
vllm_engine_config:
max_num_seqs: 16
max_model_len: 250
9 changes: 9 additions & 0 deletions examples/large_models/vllm/llama3/prompt.json
@@ -0,0 +1,9 @@
{
"prompt": "A robot may not injure a human being",
"max_new_tokens": 50,
"temperature": 0.8,
"logprobs": 1,
"prompt_logprobs": 1,
"max_tokens": 128,
"adapter": "adapter_1"
}
1 change: 1 addition & 0 deletions examples/large_models/vllm/lora/Readme.md
@@ -1,6 +1,7 @@
# Example showing inference with vLLM on LoRA model

This is an example showing how to integrate [vLLM](https://github.com/vllm-project/vllm) with TorchServe and run inference on model `Llama-2-7b-hf` + LoRA model `llama-2-7b-sql-lora-test` with continuous batching.
This example supports distributed inference by following [these instructions](../Readme.md#distributed-inference).

### Step 1: Download Model from HuggingFace

5 changes: 2 additions & 3 deletions examples/large_models/vllm/lora/model-config.yaml
@@ -1,14 +1,13 @@
# TorchServe frontend parameters
minWorkers: 1
maxWorkers: 1
batchSize: 16
maxBatchDelay: 100
responseTimeout: 1200
deviceType: "gpu"
continuousBatching: true
asyncCommunication: true

handler:
model_path: "model/models--meta-llama--Llama-2-7b-hf/snapshots/01c7f73d771dfac7d292323805ebc428287df4f9"
model_path: "model/models--meta-llama--Llama-2-7b-chat-hf/snapshots/f5db02db724555f92da89c216ac04704f23d4590/"
vllm_engine_config:
enable_lora: true
max_loras: 4
1 change: 1 addition & 0 deletions examples/large_models/vllm/mistral/Readme.md
@@ -1,6 +1,7 @@
# Example showing inference with vLLM on Mistral model

This is an example showing how to integrate [vLLM](https://github.com/vllm-project/vllm) with TorchServe and run inference on model `mistralai/Mistral-7B-v0.1` with continuous batching.
This example supports distributed inference by following [these instructions](../Readme.md#distributed-inference).

### Step 1: Download Model from HuggingFace

3 changes: 1 addition & 2 deletions examples/large_models/vllm/mistral/model-config.yaml
@@ -1,11 +1,10 @@
# TorchServe frontend parameters
minWorkers: 1
maxWorkers: 1
batchSize: 16
maxBatchDelay: 100
responseTimeout: 1200
deviceType: "gpu"
continuousBatching: true
asyncCommunication: true

handler:
model_path: "model/models--mistralai--Mistral-7B-v0.1/snapshots/26bca36bde8333b5d7f72e9ed20ccda6a618af24"
