[NeuralChat] Support multi-card streaming inference on Gaudi (#867)
* Support multi-card streaming inference on Gaudi

Signed-off-by: lvliang-intel <liang1.lv@intel.com>
lvliang-intel authored Dec 6, 2023
1 parent 8e3ffa1 commit 9ad75c2
Showing 5 changed files with 65 additions and 20 deletions.
@@ -27,6 +27,7 @@ model_name_or_path: "Phind/Phind-CodeLlama-34B-v2"
device: "hpu"
use_deepspeed: true
world_size: 8
master_port: 29500 # default value in DeepSpeed is 29500; users can change it to avoid port conflicts

# task choices = ['textchat', 'voicechat', 'retrieval', 'text2image', 'finetune']
tasks_list: ['textchat']
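Since master_port only matters when DeepSpeed's default of 29500 is already taken, one way to decide whether to override it is to probe the port before launching. A minimal sketch using the standard socket module; the helper name is made up for illustration:

import socket

# Returns True if something is already listening on the given local port,
# in which case master_port in the YAML above should be set to a free value.
def port_in_use(port, host="127.0.0.1"):
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        return sock.connect_ex((host, port)) == 0

print("29500 in use:", port_in_use(29500))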
@@ -266,6 +266,14 @@ def predict(self, query, origin_query="", config=None):

if not query_include_prompt and not is_plugin_enabled("retrieval"):
query = self.prepare_prompt(query, self.model_name, config.task)

# The Phind/Phind-CodeLlama-34B-v2 model accepts the Alpaca/Vicuna instruction format.
if "phind" in self.model_name.lower():
conv_template = PromptTemplate(name="phind")
conv_template.append_message(conv_template.roles[0], query)
conv_template.append_message(conv_template.roles[1], None)
query = conv_template.get_prompt()

# LLM inference
response = predict(**construct_parameters(query, self.model_name, self.device, config))

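The phind template wraps the query into an Alpaca/Vicuna-style instruction prompt before inference. A rough sketch of that layout is below; the exact section headers and default system message produced by PromptTemplate(name="phind") are assumptions for illustration, not copied from the library.

# Hedged illustration of an Alpaca/Vicuna-style prompt; the real phind template may differ.
def build_phind_style_prompt(query, system="You are an intelligent programming assistant."):
    return (
        f"### System Prompt\n{system}\n\n"
        f"### User Message\n{query}\n\n"
        "### Assistant\n"
    )

print(build_phind_style_prompt("Write a binary search in C++."))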
@@ -282,30 +282,13 @@ async def chat_completion_endpoint(request: ChatCompletionRequest):
if attr == "stream":
continue
setattr(gen_config, attr, value)
buffered_texts = ""
if request.stream:
generator, _ = chatbot.predict_stream(query=request.prompt, config=gen_config)
if not isinstance(generator, types.GeneratorType):
generator = (generator,)
def stream_generator():
nonlocal buffered_texts
for output in generator:
if isinstance(output, str):
chunks = output.split()
for chunk in chunks:
ret = {
"text": chunk,
"error_code": 0,
}
buffered_texts += chunk + ' '
yield json.dumps(ret).encode() + b"\0"
else:
ret = {
"text": output,
"error_code": 0,
}
buffered_texts += output + ' '
yield json.dumps(ret).encode() + b"\0"
yield output + "\0"
yield f"data: [DONE]\n\n"
return StreamingResponse(stream_generator(), media_type="text/event-stream")
else:
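After this change the streaming branch emits each generated chunk followed by a NUL byte ("\0") and closes with "data: [DONE]\n\n", instead of JSON-encoding word-by-word chunks. A minimal sketch of consuming such a NUL-delimited stream, mirroring what the new router-side proxy below does against a worker's /v1/code_generation route; the host, port, and request fields are placeholders:

import requests

# Sketch only: a hypothetical worker address; the payload mirrors the fields
# this commit uses from ChatCompletionRequest (prompt, stream).
url = "http://127.0.0.1:8081/v1/code_generation"
resp = requests.post(url, json={"prompt": "def fib(n):", "stream": True}, stream=True, timeout=1000)
for chunk in resp.iter_lines(decode_unicode=False, delimiter=b"\0"):
    if not chunk:
        continue
    if chunk.startswith(b"data: [DONE]"):
        break
    print(chunk.decode(errors="ignore"), end="", flush=True)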
@@ -105,6 +105,7 @@ def init(self, config):
port = config.get("port", "80")
use_deepspeed = config.get("use_deepspeed", False)
world_size = config.get("world_size", 1)
master_port = config.get("master_port", 29500)
model_name_or_path = config.get("model_name_or_path", "meta-llama/Llama-2-7b-hf")
tokenizer_name_or_path = config.get("tokenizer_name_or_path", model_name_or_path)
peft_model_path = config.get("peft_model_path", "")
@@ -194,7 +195,7 @@ def init(self, config):
multi_hpu_server_file = os.path.abspath(
os.path.join(os.path.dirname(__file__), './multi_hpu_server.py'))
launch_str = f"deepspeed --num_nodes 1 --num_gpus {world_size} --no_local_rank \
{multi_hpu_server_file}"
--master_port {master_port} {multi_hpu_server_file}"
command_list = f"{launch_str} --habana --use_hpu_graphs --use_kv_cache --task chat \
--base_model_path {model_name_or_path} --host {host} --port {port} --api_list {api_str}"
try:
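For reference, substituting the example values from the YAML above (world_size: 8, master_port: 29500) into the strings assembled here yields a command along the following lines; the host, port, api_list, and script path are placeholders rather than values taken from the commit.

# Sketch of the assembled DeepSpeed launch command; placeholder values are marked.
world_size, master_port = 8, 29500
host, port, api_str = "0.0.0.0", 8000, "textchat"          # placeholders
multi_hpu_server_file = "./multi_hpu_server.py"             # resolved to an absolute path in init()
launch_str = (
    f"deepspeed --num_nodes 1 --num_gpus {world_size} --no_local_rank "
    f"--master_port {master_port} {multi_hpu_server_file}"
)
command = (
    f"{launch_str} --habana --use_hpu_graphs --use_kv_cache --task chat "
    f"--base_model_path Phind/Phind-CodeLlama-34B-v2 --host {host} --port {port} --api_list {api_str}"
)
print(command)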
@@ -157,6 +157,7 @@ async def show_available_models():
models.append(router.get_chatbot().model_name)
return {"models": models}

# The /v1/code_generation route only supports non-streaming mode.
@router.post("/v1/code_generation")
async def code_generation_endpoint(chat_request: ChatCompletionRequest):
if router.use_deepspeed:
@@ -167,7 +168,9 @@ def send_request(port):
url = f'http://{router.host}:{port}/v1/code_generation'
response = requests.post(url, json=chat_request.dict())
response.raise_for_status()
responses.append(response.content)
json_response = json.loads(response.content)
chat_completion_response = ChatCompletionResponse(response=json_response['response'])
responses.append(chat_completion_response)
except requests.exceptions.RequestException as e:
print(f"Error sending/receiving on port {port}: {e}")

@@ -181,3 +184,52 @@ def send_request(port):
if ret is not None:
raise RuntimeError("Invalid parameter.")
return router.handle_chat_completion_request(chat_request)

# The /v1/code_chat route supports both non-streaming and streaming modes.
@router.post("/v1/code_chat")
async def code_chat_endpoint(chat_request: ChatCompletionRequest):
if router.use_deepspeed:
if chat_request.stream:
responses = []
def generate_stream(port):
url = f'http://{router.host}:{port}/v1/code_generation'
response = requests.post(url, json=chat_request.dict(), stream=True, timeout=1000)
responses.append(response)
with futures.ThreadPoolExecutor(max_workers=router.world_size) as executor:
worker_ports = [router.port + i + 1 for i in range(router.world_size)]
executor.map(generate_stream, worker_ports)

while not responses:
pass
def generate():
if responses[0]:
for chunk in responses[0].iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
yield f"data: {chunk}\n\n"
yield f"data: [DONE]\n\n"

return StreamingResponse(generate(), media_type="text/event-stream")
else:
responses = []

def send_request(port):
try:
url = f'http://{router.host}:{port}/v1/code_generation'
response = requests.post(url, json=chat_request.dict())
response.raise_for_status()
json_response = json.loads(response.content)
chat_completion_response = ChatCompletionResponse(response=json_response['response'])
responses.append(chat_completion_response)
except requests.exceptions.RequestException as e:
print(f"Error sending/receiving on port {port}: {e}")

with futures.ThreadPoolExecutor(max_workers=router.world_size) as executor:
worker_ports = [router.port + i + 1 for i in range(router.world_size)]
executor.map(send_request, worker_ports)
if responses:
return responses[0]
else:
ret = check_completion_request(chat_request)
if ret is not None:
raise RuntimeError("Invalid parameter.")
return router.handle_chat_completion_request(chat_request)
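To see the new route end to end, a hedged client sketch for streaming from /v1/code_chat; the service address and any request fields beyond prompt and stream are assumptions:

import requests

# Minimal sketch, assuming the NeuralChat service listens on a placeholder address.
url = "http://localhost:8000/v1/code_chat"
payload = {"prompt": "Write a quicksort in Python.", "stream": True}

with requests.post(url, json=payload, stream=True, timeout=1000) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        if not line:
            continue
        if line.startswith("data: [DONE]"):
            break
        # Each "data: ..." event carries one chunk relayed from the worker stream.
        print(line[len("data: "):] if line.startswith("data: ") else line, end="", flush=True)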
