Integrate vllm with example Lora and Mistral (#3077)
* init

* init

* fix typo

* update client to load json prompt

* fix vllm parameter parse

* fix stop criteria bugs

* update readme and add test

* update  test

* fix lint

* update mistral example

* update mistral example

* update model config

* fix output text

* update pytest

---------

Co-authored-by: Matthias Reso <13337103+mreso@users.noreply.github.com>
lxning and mreso authored May 3, 2024
1 parent a09de83 commit f2c26f3
Showing 16 changed files with 453 additions and 132 deletions.
34 changes: 26 additions & 8 deletions examples/large_models/utils/test_llm_streaming_response.py
@@ -31,28 +31,40 @@ def _predict(self):
if self.args.demo_streaming:
print(data["text"], end="", flush=True)
else:
combined_text += data["text"]
combined_text += data.get("text", "")
if not self.args.demo_streaming:
self.queue.put_nowait(f"payload={payload}\n, output={combined_text}\n")

def _get_url(self):
return f"http://localhost:8080/predictions/{self.args.model}"

def _format_payload(self):
prompt = _load_curl_like_data(self.args.prompt_text)
prompt_list = prompt.split(" ")
prompt_input = _load_curl_like_data(self.args.prompt_text)
if self.args.prompt_json:
prompt_input = orjson.loads(prompt_input)
prompt = prompt_input.get("prompt", None)
assert prompt is not None
prompt_list = prompt.split(" ")
rt = int(prompt_input.get("max_new_tokens", self.args.max_tokens))
else:
prompt_list = prompt_input.split(" ")
rt = self.args.max_tokens
rp = len(prompt_list)
rt = self.args.max_tokens
if self.args.prompt_randomize:
rp = random.randint(0, max_prompt_random_tokens)
rt = rp + self.args.max_tokens
for _ in range(rp):
prompt_list.insert(0, chr(ord("a") + random.randint(0, 25)))
cur_prompt = " ".join(prompt_list)
return {
"prompt": cur_prompt,
"max_new_tokens": rt,
}
if self.args.prompt_json:
prompt_input["prompt"] = cur_prompt
prompt_input["max_new_tokens"] = rt
return prompt_input
else:
return {
"prompt": cur_prompt,
"max_new_tokens": rt,
}


def _load_curl_like_data(text):
@@ -112,6 +124,12 @@ def parse_args():
default=1,
help="Execute the number of prediction in each thread",
)
parser.add_argument(
"--prompt-json",
action=argparse.BooleanOptionalAction,
default=False,
help="Flag the imput prompt is a json format with prompt parameters",
)
parser.add_argument(
"--demo-streaming",
action=argparse.BooleanOptionalAction,
15 changes: 15 additions & 0 deletions examples/large_models/vllm/Readme.md
@@ -0,0 +1,15 @@
# Example showing inference with vLLM

This folder contains multiple demonstrations showcasing the integration of [vLLM Engine](https://github.com/vllm-project/vllm) with TorchServe, running inference with continuous batching.
vLLM achieves high throughput using PagedAttention. More details can be found [here](https://vllm.ai/).

- demo1: [Mistral](mistral)
- demo2: [lora](lora)

### Supported vLLM Configuration
* LLMEngine configuration:
vLLM [EngineArgs](https://github.com/vllm-project/vllm/blob/258a2c58d08fc7a242556120877a89404861fbce/vllm/engine/arg_utils.py#L15) is defined in the `handler/vllm_engine_config` section of `model-config.yaml`.


* Sampling parameters for text generation:
vLLM [SamplingParams](https://github.com/vllm-project/vllm/blob/258a2c58d08fc7a242556120877a89404861fbce/vllm/sampling_params.py#L27) is defined in JSON format; see, for example, [prompt.json](lora/prompt.json).
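
As a rough illustration of how these two configuration surfaces are consumed, the sketch below mirrors the attribute-copy approach in `base_vllm_handler.py`: keys from `vllm_engine_config` are copied onto an `EngineArgs` object and keys from the request JSON onto a `SamplingParams` object, but only when a matching attribute already exists. The values are borrowed from the LoRA demo, and the `apply_config` helper name is illustrative, not part of the handler.

```python
from vllm import EngineArgs, SamplingParams


def apply_config(obj, config: dict):
    # Copy only keys that correspond to existing attributes; everything else
    # (e.g. "prompt", "max_new_tokens", "adapter") is silently ignored here.
    known = vars(obj)
    for key, value in config.items():
        if key in known:
            setattr(obj, key, value)


engine_args = EngineArgs(model="model/path/to/snapshot")  # placeholder path
apply_config(engine_args, {"enable_lora": True, "max_loras": 4, "max_model_len": 250})

sampling_params = SamplingParams()
apply_config(sampling_params, {"temperature": 0.8, "max_tokens": 128, "logprobs": 1})
```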
132 changes: 132 additions & 0 deletions examples/large_models/vllm/base_vllm_handler.py
@@ -0,0 +1,132 @@
import logging
import pathlib

from vllm import EngineArgs, LLMEngine, SamplingParams
from vllm.lora.request import LoRARequest

from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)


class BaseVLLMHandler(BaseHandler):
    def __init__(self):
        super().__init__()

        self.vllm_engine = None
        self.model = None
        self.model_dir = None
        self.lora_ids = {}
        self.adapters = None
        self.initialized = False

    def initialize(self, ctx):
        ctx.cache = {}

        self.model_dir = ctx.system_properties.get("model_dir")
        vllm_engine_config = self._get_vllm_engine_config(
            ctx.model_yaml_config.get("handler", {})
        )
        self.adapters = ctx.model_yaml_config.get("handler", {}).get("adapters", {})
        self.vllm_engine = LLMEngine.from_engine_args(vllm_engine_config)
        self.initialized = True

    def preprocess(self, requests):
        for req_id, req_data in zip(self.context.request_ids.values(), requests):
            # Only add a request to the engine the first time it is seen;
            # later batches for the same request are already in flight.
            if req_id not in self.context.cache:
                data = req_data.get("data") or req_data.get("body")
                if isinstance(data, (bytes, bytearray)):
                    data = data.decode("utf-8")

                prompt = data.get("prompt")
                sampling_params = self._get_sampling_params(req_data)
                lora_request = self._get_lora_request(req_data)
                self.context.cache[req_id] = {
                    "text_len": 0,
                    "stopping_criteria": self._create_stopping_criteria(req_id),
                }
                self.vllm_engine.add_request(
                    req_id, prompt, sampling_params, lora_request=lora_request
                )

        return requests

    def inference(self, input_batch):
        # Advance the engine by one scheduler step and collect the outputs of
        # all in-flight requests; only the newly generated text is returned.
        inference_outputs = self.vllm_engine.step()
        results = {}

        for output in inference_outputs:
            req_id = output.request_id
            results[req_id] = {
                "text": output.outputs[0].text[
                    self.context.cache[req_id]["text_len"] :
                ],
                "tokens": output.outputs[0].token_ids[-1],
                "finished": output.finished,
            }
            self.context.cache[req_id]["text_len"] = len(output.outputs[0].text)

        return [results[i] for i in self.context.request_ids.values()]

    def postprocess(self, inference_outputs):
        self.context.stopping_criteria = [
            self.context.cache[req_id]["stopping_criteria"]
            for req_id in self.context.request_ids.values()
        ]

        return inference_outputs

    def _get_vllm_engine_config(self, handler_config: dict):
        vllm_engine_params = handler_config.get("vllm_engine_config", {})
        model = vllm_engine_params.get("model", {})
        if len(model) == 0:
            # Fall back to model_path, resolved relative to the model directory.
            model_path = handler_config.get("model_path", {})
            assert (
                len(model_path) > 0
            ), "please define model in vllm_engine_config or model_path in handler"
            model = str(pathlib.Path(self.model_dir).joinpath(model_path))
        logger.info(f"EngineArgs model={model}")
        vllm_engine_config = EngineArgs(model=model)
        self._set_attr_value(vllm_engine_config, vllm_engine_params)
        return vllm_engine_config

    def _get_sampling_params(self, req_data: dict):
        sampling_params = SamplingParams()
        self._set_attr_value(sampling_params, req_data)

        return sampling_params

    def _get_lora_request(self, req_data: dict):
        adapter_name = req_data.get("lora_adapter", "")

        if len(adapter_name) > 0:
            adapter_path = self.adapters.get(adapter_name, "")
            assert len(adapter_path) > 0, f"{adapter_name} is missing its adapter path"
            lora_id = self.lora_ids.setdefault(adapter_name, len(self.lora_ids) + 1)
            adapter_path = str(pathlib.Path(self.model_dir).joinpath(adapter_path))
            logger.info(f"adapter_path={adapter_path}")
            return LoRARequest(adapter_name, lora_id, adapter_path)

        return None

    def _clean_up(self, req_id):
        del self.context.cache[req_id]

    def _create_stopping_criteria(self, req_id):
        class StoppingCriteria(object):
            def __init__(self, outer, req_id):
                self.req_id = req_id
                self.outer = outer

            def __call__(self, res):
                # Once a request reports finished, drop its cache entry.
                if res["finished"]:
                    self.outer._clean_up(self.req_id)
                return res["finished"]

        return StoppingCriteria(outer=self, req_id=req_id)

    def _set_attr_value(self, obj, config: dict):
        # Copy only config keys that correspond to existing attributes on obj.
        items = vars(obj)
        for k, v in config.items():
            if k in items:
                setattr(obj, k, v)
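
One detail worth calling out in the handler above is the per-request `text_len` bookkeeping: `LLMEngine.step()` returns the full text generated so far for each in-flight request, so the handler slices off only the newly produced suffix and streams that delta back to the client. A minimal standalone sketch of that logic, with a hypothetical `cache` dict standing in for `ctx.cache`:

```python
# Hypothetical per-request cache, mirroring what initialize()/preprocess() set up.
cache = {"req-0": {"text_len": 0}}


def text_delta(req_id: str, full_text: str) -> str:
    # Return only the text generated since the previous engine step.
    start = cache[req_id]["text_len"]
    cache[req_id]["text_len"] = len(full_text)
    return full_text[start:]


print(text_delta("req-0", "A robot"))          # -> "A robot"
print(text_delta("req-0", "A robot may not"))  # -> " may not"
```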
48 changes: 48 additions & 0 deletions examples/large_models/vllm/lora/Readme.md
@@ -0,0 +1,48 @@
# Example showing inference with vLLM on LoRA model

This is an example showing how to integrate [vLLM](https://github.com/vllm-project/vllm) with TorchServe and run inference on model `Llama-2-7b-hf` + LoRA model `llama-2-7b-sql-lora-test` with continuous batching.

### Step 1: Download Model from HuggingFace

Login with a HuggingFace account
```
huggingface-cli login
# or using an environment variable
huggingface-cli login --token $HUGGINGFACE_TOKEN
```

```bash
python ../../utils/Download_model.py --model_path model --model_name meta-llama/Llama-2-7b-chat-hf --use_auth_token True
mkdir adapters && cd adapters
python ../../../utils/Download_model.py --model_path model --model_name yard1/llama-2-7b-sql-lora-test --use_auth_token True
cd ..
```

### Step 2: Generate model artifacts

Add the downloaded path to "model_path:" and "adapter_1:" in `model-config.yaml` and run the following.

```bash
torch-model-archiver --model-name llama-7b-lora --version 1.0 --handler ../base_vllm_handler.py --config-file model-config.yaml -r ../requirements.txt --archive-format no-archive
mv model llama-7b-lora
mv adapters llama-7b-lora
```
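
Before moving the folder into the model store, you can optionally sanity-check that the paths you entered in `model-config.yaml` resolve inside the assembled `llama-7b-lora` directory. A small sketch of such a check (it assumes PyYAML, which is not otherwise required by this example):

```python
import pathlib

import yaml  # assumed dependency for this check only (pip install pyyaml)

base = pathlib.Path("llama-7b-lora")
handler_cfg = yaml.safe_load((base / "model-config.yaml").read_text())["handler"]

# model_path and adapter paths are resolved relative to the archive directory,
# just as base_vllm_handler.py joins them with the model directory at runtime.
assert (base / handler_cfg["model_path"]).exists(), "model_path not found"
for name, rel_path in handler_cfg.get("adapters", {}).items():
    assert (base / rel_path).exists(), f"adapter {name} not found at {rel_path}"
print("model-config.yaml paths look good")
```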

### Step 3: Add the model artifacts to model store

```bash
mkdir model_store
mv llama-7b-lora model_store
```

### Step 4: Start torchserve

```bash
torchserve --start --ncs --ts-config ../config.properties --model-store model_store --models llama-7b-lora
```

### Step 5: Run inference

```bash
python ../../utils/test_llm_streaming_response.py -m lora -o 50 -t 2 -n 4 --prompt-text "@prompt.json" --prompt-json
```
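
For a single ad-hoc request instead of the multi-threaded test client, a minimal sketch could look like the following. It assumes the `requests` library and passes the adapter under the `lora_adapter` key, which is the field `_get_lora_request` in `base_vllm_handler.py` looks up; each streamed chunk is a JSON object whose `text` field carries the newly generated tokens, as in `test_llm_streaming_response.py`.

```python
import json

import requests  # assumed client-side dependency

payload = {
    "prompt": "A robot may not injure a human being",
    "max_new_tokens": 50,
    "temperature": 0.8,
    "lora_adapter": "adapter_1",  # must match an entry under handler/adapters in model-config.yaml
}

with requests.post(
    "http://localhost:8080/predictions/lora", json=payload, stream=True
) as response:
    for chunk in response.iter_content(chunk_size=None):
        # Each chunk is a JSON object with the incremental "text".
        print(json.loads(chunk).get("text", ""), end="", flush=True)
```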
20 changes: 20 additions & 0 deletions examples/large_models/vllm/lora/model-config.yaml
@@ -0,0 +1,20 @@
# TorchServe frontend parameters
minWorkers: 1
maxWorkers: 1
batchSize: 16
maxBatchDelay: 100
responseTimeout: 1200
deviceType: "gpu"
continuousBatching: true

handler:
    model_path: "model/models--meta-llama--Llama-2-7b-hf/snapshots/01c7f73d771dfac7d292323805ebc428287df4f9"
    vllm_engine_config:
        enable_lora: true
        max_loras: 4
        max_cpu_loras: 4
        max_num_seqs: 16
        max_model_len: 250

    adapters:
        adapter_1: "adapters/model/models--yard1--llama-2-7b-sql-lora-test/snapshots/0dfa347e8877a4d4ed19ee56c140fa518470028c/"
9 changes: 9 additions & 0 deletions examples/large_models/vllm/lora/prompt.json
@@ -0,0 +1,9 @@
{
  "prompt": "A robot may not injure a human being",
  "max_new_tokens": 50,
  "temperature": 0.8,
  "logprobs": 1,
  "prompt_logprobs": 1,
  "max_tokens": 128,
  "adapter": "adapter_1"
}
39 changes: 14 additions & 25 deletions examples/large_models/vllm/mistral/Readme.md
@@ -1,9 +1,8 @@
# Example showing inference with vLLM with mistralai/Mistral-7B-v0.1 model
# Example showing inference with vLLM on Mistral model

This is an example showing how to integrate [vLLM](https://github.com/vllm-project/vllm) with TorchServe and run inference on `mistralai/Mistral-7B-v0.1` model.
vLLM achieves high throughput using PagedAttention. More details can be found [here](https://vllm.ai/)
This is an example showing how to integrate [vLLM](https://github.com/vllm-project/vllm) with TorchServe and run inference on model `mistralai/Mistral-7B-v0.1` with continuous batching.

### Step 1: Login to HuggingFace
### Step 1: Download Model from HuggingFace

Login with a HuggingFace account
```
@@ -13,43 +12,33 @@ huggingface-cli login --token $HUGGINGFACE_TOKEN
```

```bash
python ../../Huggingface_accelerate/Download_model.py --model_path model --model_name mistralai/Mistral-7B-v0.1
python ../../utils/Download_model.py --model_path model --model_name mistralai/Mistral-7B-v0.1 --use_auth_token True
```
Model will be saved in the following path, `mistralai/Mistral-7B-v0.1`.

### Step 2: Generate MAR file
### Step 2: Generate model artifacts

Add the downloaded path to " model_path:" in `model-config.yaml` and run the following.
Add the downloaded path to "model_path:" in `model-config.yaml` and run the following.

```bash
torch-model-archiver --model-name mistral7b --version 1.0 --handler custom_handler.py --config-file model-config.yaml -r requirements.txt --archive-format tgz
torch-model-archiver --model-name mistral --version 1.0 --handler ../base_vllm_handler.py --config-file model-config.yaml -r ../requirements.txt --archive-format no-archive
mv model mistral
```

### Step 3: Add the mar file to model store
### Step 3: Add the model artifacts to model store

```bash
mkdir model_store
mv mistral7b.tar.gz model_store
mv mistral model_store
```

### Step 3: Start torchserve

### Step 4: Start torchserve

```bash
torchserve --start --ncs --ts-config config.properties --model-store model_store --models mistral7b.tar.gz
torchserve --start --ncs --ts-config ../config.properties --model-store model_store --models mistral
```

### Step 4: Run inference
### Step 5: Run inference

```bash
curl -v "http://localhost:8080/predictions/mistral7b" -T sample_text.txt
```

results in the following output
```
Mayonnaise is made of eggs, oil, vinegar, salt and pepper. Using an electric blender, combine all the ingredients and beat at high speed for 4 to 5 minutes.
Try it with some mustard and paprika mixed in, and a bit of sweetener if you like. But use real mayonnaise or it isn’t the same. Marlou
What in the world is mayonnaise?
python ../../utils/test_llm_streaming_response.py -m mistral -o 50 -t 2 -n 4 --prompt-text "@prompt.json" --prompt-json
```
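
As with the LoRA example, a single request can also be sent directly. The sketch below assumes the `requests` library and collects the streamed JSON chunks into one string:

```python
import json

import requests  # assumed client-side dependency

payload = {
    "prompt": "What is the capital of France?",
    "max_new_tokens": 50,
    "temperature": 0.8,
}

generated = ""
with requests.post(
    "http://localhost:8080/predictions/mistral", json=payload, stream=True
) as response:
    for chunk in response.iter_content(chunk_size=None):
        # Each chunk is a JSON object with the incremental "text".
        generated += json.loads(chunk).get("text", "")
print(generated)
```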