# Integrate vllm with example Lora and Mistral (#3077)

* init
* init
* fix typo
* update client to load json prompt
* fix vllm parameter parse
* fix stop criteria bugs
* update readme and add test
* update test
* fix lint
* update mistral example
* update mistral example
* update model config
* fix output text
* update pytest

Co-authored-by: Matthias Reso <13337103+mreso@users.noreply.github.com>
Showing 16 changed files with 453 additions and 132 deletions.
@@ -0,0 +1,15 @@

# Example showing inference with vLLM

This folder contains multiple demonstrations showcasing the integration of [vLLM Engine](https://github.com/vllm-project/vllm) with TorchServe, running inference with continuous batching.
vLLM achieves high throughput using PagedAttention. More details can be found [here](https://vllm.ai/).

- demo1: [Mistral](mistral)
- demo2: [lora](lora)

### Supported vLLM Configuration

* LLMEngine configuration:
  vLLM [EngineArgs](https://github.com/vllm-project/vllm/blob/258a2c58d08fc7a242556120877a89404861fbce/vllm/engine/arg_utils.py#L15) is defined in the `handler/vllm_engine_config` section of `model-config.yaml`.

* Sampling parameters for text generation:
  vLLM [SamplingParams](https://github.com/vllm-project/vllm/blob/258a2c58d08fc7a242556120877a89404861fbce/vllm/sampling_params.py#L27) is defined in JSON format, for example [prompt.json](lora/prompt.json).
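
For orientation, here is a minimal sketch of how the engine configuration sits in a `model-config.yaml` handler section (the `model_path` value is a placeholder; a full working config is shown in the lora example later on this page):

```yaml
handler:
  # Keys under vllm_engine_config are copied onto vLLM's EngineArgs
  vllm_engine_config:
    max_num_seqs: 16
    max_model_len: 250
  # Placeholder: point this at the downloaded model snapshot
  model_path: "model/<downloaded-snapshot>"
```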
@@ -0,0 +1,132 @@
```python
import json
import logging
import pathlib

from vllm import EngineArgs, LLMEngine, SamplingParams
from vllm.lora.request import LoRARequest

from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)


class BaseVLLMHandler(BaseHandler):
    def __init__(self):
        super().__init__()

        self.vllm_engine = None
        self.model = None
        self.model_dir = None
        self.lora_ids = {}
        self.adapters = None
        self.initialized = False

    def initialize(self, ctx):
        ctx.cache = {}

        self.model_dir = ctx.system_properties.get("model_dir")
        vllm_engine_config = self._get_vllm_engine_config(
            ctx.model_yaml_config.get("handler", {})
        )
        self.adapters = ctx.model_yaml_config.get("handler", {}).get("adapters", {})
        self.vllm_engine = LLMEngine.from_engine_args(vllm_engine_config)
        self.initialized = True

    def preprocess(self, requests):
        for req_id, req_data in zip(self.context.request_ids.values(), requests):
            # Only add a request to the vLLM engine once; on later batches the
            # same request id is already in flight.
            if req_id not in self.context.cache:
                data = req_data.get("data") or req_data.get("body")
                if isinstance(data, (bytes, bytearray)):
                    data = json.loads(data.decode("utf-8"))

                prompt = data.get("prompt")
                sampling_params = self._get_sampling_params(data)
                lora_request = self._get_lora_request(data)
                self.context.cache[req_id] = {
                    "text_len": 0,
                    "stopping_criteria": self._create_stopping_criteria(req_id),
                }
                self.vllm_engine.add_request(
                    req_id, prompt, sampling_params, lora_request=lora_request
                )

        return requests

    def inference(self, input_batch):
        inference_outputs = self.vllm_engine.step()
        results = {}

        for output in inference_outputs:
            req_id = output.request_id
            # Stream only the text generated since the last engine step plus
            # the newest token id; remember how much text was already sent.
            results[req_id] = {
                "text": output.outputs[0].text[
                    self.context.cache[req_id]["text_len"] :
                ],
                "tokens": output.outputs[0].token_ids[-1],
                "finished": output.finished,
            }
            self.context.cache[req_id]["text_len"] = len(output.outputs[0].text)

        return [results[i] for i in self.context.request_ids.values()]

    def postprocess(self, inference_outputs):
        self.context.stopping_criteria = [
            self.context.cache[req_id]["stopping_criteria"]
            for req_id in self.context.request_ids.values()
        ]

        return inference_outputs

    def _get_vllm_engine_config(self, handler_config: dict):
        vllm_engine_params = handler_config.get("vllm_engine_config", {})
        model = vllm_engine_params.get("model", "")
        if len(model) == 0:
            model_path = handler_config.get("model_path", "")
            assert (
                len(model_path) > 0
            ), "please define model in vllm_engine_config or model_path in handler"
            model = str(pathlib.Path(self.model_dir).joinpath(model_path))
        logger.info(f"EngineArgs model={model}")
        vllm_engine_config = EngineArgs(model=model)
        self._set_attr_value(vllm_engine_config, vllm_engine_params)
        return vllm_engine_config

    def _get_sampling_params(self, req_data: dict):
        sampling_params = SamplingParams()
        self._set_attr_value(sampling_params, req_data)

        return sampling_params

    def _get_lora_request(self, req_data: dict):
        adapter_name = req_data.get("lora_adapter", "")

        if len(adapter_name) > 0:
            adapter_path = self.adapters.get(adapter_name, "")
            assert len(adapter_path) > 0, f"{adapter_name} misses adapter path"
            lora_id = self.lora_ids.setdefault(adapter_name, len(self.lora_ids) + 1)
            adapter_path = str(pathlib.Path(self.model_dir).joinpath(adapter_path))
            logger.info(f"adapter_path={adapter_path}")
            return LoRARequest(adapter_name, lora_id, adapter_path)

        return None

    def _clean_up(self, req_id):
        del self.context.cache[req_id]

    def _create_stopping_criteria(self, req_id):
        class StoppingCriteria(object):
            def __init__(self, outer, req_id):
                self.req_id = req_id
                self.outer = outer

            def __call__(self, res):
                # Drop the request from the cache once generation has finished.
                if res["finished"]:
                    self.outer._clean_up(self.req_id)
                return res["finished"]

        return StoppingCriteria(outer=self, req_id=req_id)

    def _set_attr_value(self, obj, config: dict):
        # Copy only keys that match existing attributes; unknown keys are ignored.
        items = vars(obj)
        for k, v in config.items():
            if k in items:
                setattr(obj, k, v)
```
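
Since `_set_attr_value` drives both the engine and the sampling configuration, a small standalone snippet may help make its contract concrete: only keys that match existing attributes on the target object are copied, everything else is silently skipped (illustrative code, not part of the handler):

```python
from vllm import SamplingParams

params = SamplingParams()
overrides = {"temperature": 0.8, "max_tokens": 128, "not_a_field": 1}

# Same logic as the handler's _set_attr_value
for k, v in overrides.items():
    if k in vars(params):
        setattr(params, k, v)

assert params.temperature == 0.8 and params.max_tokens == 128
assert not hasattr(params, "not_a_field")  # unknown key was skipped
```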
File renamed without changes.
@@ -0,0 +1,48 @@
# Example showing inference with vLLM on LoRA model

This is an example showing how to integrate [vLLM](https://github.com/vllm-project/vllm) with TorchServe and run inference on the base model `Llama-2-7b-hf` plus the LoRA adapter `llama-2-7b-sql-lora-test` with continuous batching.

### Step 1: Download Model from HuggingFace

Log in with a HuggingFace account
```bash
huggingface-cli login
# or using an environment variable
huggingface-cli login --token $HUGGINGFACE_TOKEN
```

```bash
python ../../utils/Download_model.py --model_path model --model_name meta-llama/Llama-2-7b-hf --use_auth_token True
mkdir adapters && cd adapters
python ../../../utils/Download_model.py --model_path model --model_name yard1/llama-2-7b-sql-lora-test --use_auth_token True
cd ..
```

### Step 2: Generate model artifacts

Add the downloaded paths to `model_path:` and `adapter_1:` in `model-config.yaml` and run the following.

```bash
torch-model-archiver --model-name llama-7b-lora --version 1.0 --handler ../base_vllm_handler.py --config-file model-config.yaml -r ../requirements.txt --archive-format no-archive
mv model llama-7b-lora
mv adapters llama-7b-lora
```

### Step 3: Add the model artifacts to model store

```bash
mkdir model_store
mv llama-7b-lora model_store
```

### Step 4: Start torchserve

```bash
torchserve --start --ncs --ts-config ../config.properties --model-store model_store --models llama-7b-lora
```

### Step 5: Run inference

```bash
python ../../utils/test_llm_streaming_response.py -m lora -o 50 -t 2 -n 4 --prompt-text "@prompt.json" --prompt-json
```
@@ -0,0 +1,20 @@
```yaml
# TorchServe frontend parameters
minWorkers: 1
maxWorkers: 1
batchSize: 16
maxBatchDelay: 100
responseTimeout: 1200
deviceType: "gpu"
continuousBatching: true

handler:
  model_path: "model/models--meta-llama--Llama-2-7b-hf/snapshots/01c7f73d771dfac7d292323805ebc428287df4f9"
  vllm_engine_config:
    enable_lora: true
    max_loras: 4
    max_cpu_loras: 4
    max_num_seqs: 16
    max_model_len: 250

  adapters:
    adapter_1: "adapters/model/models--yard1--llama-2-7b-sql-lora-test/snapshots/0dfa347e8877a4d4ed19ee56c140fa518470028c/"
```
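
For reference, `_get_vllm_engine_config` turns the YAML above into roughly the following vLLM `EngineArgs` (a sketch; `<model_dir>` stands for the model archive directory that TorchServe resolves at runtime):

```python
from vllm import EngineArgs

engine_args = EngineArgs(
    model="<model_dir>/model/models--meta-llama--Llama-2-7b-hf/snapshots/01c7f73d771dfac7d292323805ebc428287df4f9",
    enable_lora=True,
    max_loras=4,
    max_cpu_loras=4,
    max_num_seqs=16,
    max_model_len=250,
)
```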
@@ -0,0 +1,9 @@
```json
{
  "prompt": "A robot may not injure a human being",
  "max_new_tokens": 50,
  "temperature": 0.8,
  "logprobs": 1,
  "prompt_logprobs": 1,
  "max_tokens": 128,
  "lora_adapter": "adapter_1"
}
```
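
Note that the handler's `_set_attr_value` only copies keys that exist as `SamplingParams` attributes; `prompt` is extracted separately, `lora_adapter` selects the adapter configured in `model-config.yaml`, and keys without a matching attribute (such as `max_new_tokens`) are ignored.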