[server] Update OpenAI endpoints #1445

Merged: 15 commits, Dec 13, 2023
2 changes: 2 additions & 0 deletions setup.py
@@ -147,6 +147,8 @@ def _parse_requirements_file(file_path):
"transformers<4.35",
"datasets<=2.14.6",
"scikit-learn",
"fschat==0.2.33",
"accelerate==0.24.1",
"seqeval",
]
_sentence_transformers_integration_deps = ["optimum-deepsparse"] + _torch_deps
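The two new pins support the server changes below: fschat provides the conversation templates that the chat endpoint uses to turn a role/content message into a model-specific prompt, and accelerate accompanies the pinned transformers stack. A minimal sketch of the FastChat flow the handler relies on, assuming fschat 0.2.33; the model name here is only an illustrative placeholder:

    # Sketch only: mirrors the conversation-template usage in openai_server.py below.
    from fastchat.model.model_adapter import get_conversation_template

    conv = get_conversation_template("llama-2")              # placeholder model name
    conv.append_message(conv.roles[0], "What is sparsity?")  # user turn
    conv.append_message(conv.roles[1], None)                 # blank assistant turn to start generation
    prompt = conv.get_prompt()                                # model-specific prompt string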
216 changes: 206 additions & 10 deletions src/deepsparse/server/openai_server.py
@@ -29,6 +29,11 @@
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
ChatMessage,
CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse,
DeltaMessage,
ModelCard,
ModelList,
@@ -79,22 +84,54 @@ async def show_available_models():
response_model=ChatCompletionResponse,
)
async def create_chat_completion(raw_request: Request):
"""Completion API similar to OpenAI's API.
# Completion API similar to OpenAI's API.

# See https://platform.openai.com/docs/api-reference/chat/create
# for the API specification. This API mimics the OpenAI ChatCompletion API.

See https://platform.openai.com/docs/api-reference/chat/create
for the API specification. This API mimics the OpenAI ChatCompletion API.
"""
request = ChatCompletionRequest(**await raw_request.json())
_LOGGER.info(f"Received chat completion request: {request}")

prompt = request.messages
if isinstance(request.messages, str):
prompt = request.messages
else:
# The else case assumes a FastChat-compliant dictionary:
# fetch a model-specific conversation template from FastChat
_LOGGER.warning(
"A dictionary message was found. This dictionary must "
"be FastChat-compliant."
)
try:
from fastchat.model.model_adapter import get_conversation_template
except ImportError as e:
return create_error_response(HTTPStatus.FAILED_DEPENDENCY, str(e))

conv = get_conversation_template(request.model)
message = request.messages
# add the message to the conversation template, based on its role
msg_role = message["role"]
if msg_role == "system":
conv.system_message = message["content"]
elif msg_role == "user":
conv.append_message(conv.roles[0], message["content"])
elif msg_role == "assistant":
conv.append_message(conv.roles[1], message["content"])
else:
return create_error_response(
HTTPStatus.BAD_REQUEST, "Message role not recognized"
)

# blank message to start generation
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

request_id = f"cmpl-{random_uuid()}"
created_time = int(time.time())
model = request.model

pipeline = app.model_to_pipeline.get(model)
if not pipeline:
return create_error_response(
HTTPStatus.BAD_REQUEST, f"{model} is not available"
)

@@ -126,7 +163,7 @@ async def abort_request() -> None:
# Abort the request if the client disconnects.
background_tasks.add_task(abort_request)
return StreamingResponse(
chat_completion_stream_generator(
request,
result_generator,
request_id=request_id,
@@ -143,7 +180,7 @@ async def abort_request() -> None:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await abort_request()
return create_error_response(
HTTPStatus.BAD_REQUEST, "Client disconnected"
)
final_res = res
@@ -175,6 +212,107 @@ async def abort_request() -> None:
)
return response

@app.post(
"/v1/completions",
tags=["model", "inference"],
response_model=CompletionResponse,
)
async def create_completion(raw_request: Request):
request = CompletionRequest(**await raw_request.json())
_LOGGER.info(f"Received completion request: {request}")

model = request.model

pipeline = app.model_to_pipeline.get(model)
if not pipeline:
return create_error_response(
HTTPStatus.BAD_REQUEST, f"{model} is not available"
)

request_id = f"cmpl-{random_uuid()}"
created_time = int(time.time())
prompt = request.prompt

try:
sampling_params = dict(
num_return_sequences=request.n,
presence_penalty=request.presence_penalty,
frequency_penalty=request.frequency_penalty,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
stop=request.stop,
max_tokens=request.max_tokens,
stream=request.stream,
)
except ValueError as e:
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))

result_generator = OpenAIServer.generate(
prompt=prompt,
request_id=request_id,
generation_kwargs=sampling_params,
pipeline=pipeline,
)

async def abort_request() -> None:
await pipeline.abort(request_id)

# Streaming response
if request.stream:
background_tasks = BackgroundTasks()
# Abort the request if the client disconnects.
background_tasks.add_task(abort_request)
return StreamingResponse(
completion_stream_generator(
request=request,
request_id=request_id,
result_generator=result_generator,
pipeline=pipeline,
created_time=created_time,
),
media_type="text/event-stream",
background=background_tasks,
)

# Non-streaming response
final_res: RequestOutput = None
async for res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await abort_request()
return create_error_response(
HTTPStatus.BAD_REQUEST, "Client disconnected"
)
final_res = res
assert final_res is not None
choices = []
for output in final_res.outputs:
choice_data = CompletionResponseChoice(
index=output.index,
text=output.text,
finish_reason=output.finish_reason,
)
choices.append(choice_data)

num_prompt_tokens = len(final_res.prompt_token_ids)
num_generated_tokens = sum(
len(output.token_ids) for output in final_res.outputs
)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
response = CompletionResponse(
id=request_id,
created=created_time,
model=model,
choices=choices,
usage=usage,
)
return response

return app
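For context, a hypothetical client-side sketch of calling the two routes registered above with the requests library. The base URL, port, and model identifier are assumptions, not values taken from this PR; substitute the address your deepsparse server is bound to and the model it was started with.

    import requests

    BASE = "http://localhost:5543/v1"   # assumed server address and port
    MODEL = "zoo:llama2-7b-chat"        # placeholder model identifier

    # /v1/chat/completions: per the handler above, messages may be a plain string
    # or a single FastChat-compliant {"role": ..., "content": ...} dictionary.
    chat = requests.post(
        f"{BASE}/chat/completions",
        json={
            "model": MODEL,
            "messages": {"role": "user", "content": "What is sparsity?"},
            "max_tokens": 64,
        },
    )
    print(chat.json()["choices"][0]["message"]["content"])

    # /v1/completions: plain prompt-in, text-out.
    completion = requests.post(
        f"{BASE}/completions",
        json={"model": MODEL, "prompt": "Sparsity is", "max_tokens": 32},
    )
    print(completion.json()["choices"][0]["text"])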

def _add_model(
@@ -210,7 +348,7 @@ def _add_model(
@staticmethod
async def generate(
prompt: str, request_id: str, generation_kwargs: dict, pipeline: Pipeline
):
def tokenize(text: str) -> List[int]:
return pipeline.tokenizer(text)

@@ -244,6 +382,7 @@ def tokenize(text: str) -> List[int]:
finish_reason=prompt_generation.finished_reason,
)
generated_outputs.append(completion)

yield RequestOutput(
request_id=request_id,
prompt=prompt,
@@ -292,7 +431,7 @@ def map_generation_schema(generation_kwargs: Dict) -> Dict:
if generation_kwargs["num_return_sequences"] > 1:
generation_kwargs["do_sample"] = True

return dict(generation_kwargs)


def create_stream_response_json(
@@ -319,8 +458,65 @@ def create_stream_response_json(
return response_json


def create_completion_stream_response_json(
index: int,
text: str,
request_id: str,
created_time: int,
pipeline: Pipeline,
finish_reason: Optional[str] = None,
) -> str:
choice_data = CompletionResponseStreamChoice(
index=index,
text=text,
finish_reason=finish_reason,
)
response = CompletionStreamResponse(
id=request_id,
created=created_time,
model=pipeline.model_path,
choices=[choice_data],
)
response_json = response.json(ensure_ascii=False)

return response_json


async def completion_stream_generator(
request, result_generator, request_id, created_time, pipeline
) -> AsyncGenerator[str, None]:
previous_texts = [""] * request.n
previous_num_tokens = [0] * request.n
async for res in result_generator:
res: RequestOutput
for output in res.outputs:
i = output.index
delta_text = output.text[len(previous_texts[i]) :]
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
response_json = create_completion_stream_response_json(
index=i,
text=delta_text,
request_id=request_id,
created_time=created_time,
pipeline=pipeline,
)
yield f"data: {response_json}\n\n"
if output.finish_reason is not None:
response_json = create_completion_stream_response_json(
index=i,
text="",
request_id=request_id,
created_time=created_time,
pipeline=pipeline,
finish_reason=output.finish_reason,
)
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"


async def chat_completion_stream_generator(
request, result_generator, request_id, created_time, pipeline
) -> AsyncGenerator[str, None]:
# First chunk with role
for i in range(request.n):
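When stream=True, the branches above return a text/event-stream response whose chunks are "data: <json>" lines and which terminates with "data: [DONE]". A hedged sketch of consuming that stream client-side; the URL and model name are placeholders as before.

    import json
    import requests

    resp = requests.post(
        "http://localhost:5543/v1/completions",   # assumed server address
        json={"model": "zoo:llama2-7b-chat", "prompt": "Sparsity is", "stream": True},
        stream=True,
    )
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue                               # skip blank separators between events
        payload = line[len("data: "):]
        if payload == "[DONE]":                    # terminal sentinel emitted by the generator
            break
        chunk = json.loads(payload)
        print(chunk["choices"][0]["text"], end="", flush=True)
    print()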