-
Notifications
You must be signed in to change notification settings - Fork 168
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement OpenAI-compatible server (#1171)
* Implement OpenAI-compatible server * Add client example and README * Style * Update and rebase * Format * Format
- Loading branch information
Showing
7 changed files
with
1,538 additions
and
0 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import openai


# Point the OpenAI client library at the local DeepSparse OpenAI-compatible
# server. The server ignores credentials, but the client requires some key.
openai.api_key = "EMPTY"
openai.api_base = "http://localhost:8000/v1"

# List the models the server exposes.
models = openai.Model.list()
print("Models:", models)

# Use the first served model for the completion request.
model = models["data"][0]["id"]

# Completion API; with stream=True the server returns incremental chunks.
stream = True
completion = openai.Completion.create(
    model=model, prompt="def fib():", stream=stream, max_tokens=16
)

print("Completion results:")
if stream:
    for chunk in completion:
        print(chunk)
else:
    print(completion)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from typing import Dict, List, Optional | ||
|
||
|
||
class CompletionOutput:
    """A single generated sequence produced for one request.

    Args:
        index: Position of this output among the request's candidates.
        text: The generated output text.
        token_ids: Token IDs that make up ``text``.
        cumulative_logprob: Running log probability of the generated text.
        logprobs: Optional per-position top-token log probabilities,
            present only when the caller requested them.
        finish_reason: Why generation stopped, or ``None`` while running.
    """

    def __init__(
        self,
        index: int,
        text: str,
        token_ids: List[int],
        cumulative_logprob: float = 0.0,
        logprobs: Optional[List[Dict[int, float]]] = None,
        finish_reason: Optional[str] = None,
    ) -> None:
        self.index = index
        self.text = text
        self.token_ids = token_ids
        self.cumulative_logprob = cumulative_logprob
        self.logprobs = logprobs
        self.finish_reason = finish_reason

    def finished(self) -> bool:
        # Generation is complete once a finish reason has been recorded.
        return self.finish_reason is not None

    def __repr__(self) -> str:
        parts = ", ".join(
            (
                f"index={self.index}",
                f"text={self.text!r}",
                f"token_ids={self.token_ids}",
                f"cumulative_logprob={self.cumulative_logprob}",
                f"logprobs={self.logprobs}",
                f"finish_reason={self.finish_reason}",
            )
        )
        return f"CompletionOutput({parts})"
|
||
|
||
class RequestOutput:
    """Aggregated output of one request to the LLM.

    Args:
        request_id: Unique identifier of the request.
        prompt: The prompt text that was submitted.
        prompt_token_ids: Token IDs of the prompt.
        outputs: The generated candidate sequences.
        finished: True once the entire request has completed.
    """

    def __init__(
        self,
        request_id: str,
        prompt: str,
        prompt_token_ids: List[int],
        outputs: List[CompletionOutput],
        finished: bool,
    ) -> None:
        self.request_id = request_id
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.outputs = outputs
        self.finished = finished

    def __repr__(self) -> str:
        parts = ", ".join(
            (
                f"request_id={self.request_id}",
                f"prompt={self.prompt!r}",
                f"prompt_token_ids={self.prompt_token_ids}",
                f"outputs={self.outputs}",
                f"finished={self.finished}",
            )
        )
        return f"RequestOutput({parts})"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,190 @@ | ||
# Adapted from | ||
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py | ||
|
||
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import time | ||
import uuid | ||
from typing import Dict, List, Literal, Optional, Union | ||
|
||
from pydantic import BaseModel, Field | ||
|
||
|
||
def random_uuid() -> str:
    """Return a random 32-character lowercase hex ID (a UUID4 without dashes)."""
    # UUID.hex is already a str; the original wrapped it in a redundant str().
    return uuid.uuid4().hex
|
||
|
||
class ErrorResponse(BaseModel):
    """OpenAI-style error payload returned to API clients."""

    object: str = "error"
    message: str
    type: str
    param: Optional[str] = None
    code: Optional[str] = None
|
||
|
||
class ModelPermission(BaseModel):
    """Permission record attached to a model card, mirroring OpenAI's schema."""

    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    created: int = Field(default_factory=lambda: int(time.time()))  # unix seconds
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = False
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: Optional[str] = None
    # Fixed: this flag was annotated `str` with a bool default, which makes
    # pydantic coerce the value to the string "False". It is a boolean flag,
    # matching the other allow_* fields.
    is_blocking: bool = False
|
||
|
||
class ModelCard(BaseModel):
    """Metadata for one served model (an entry of the `/v1/models` listing)."""

    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))  # unix seconds
    owned_by: str = "neuralmagic"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: List[ModelPermission] = Field(default_factory=list)
|
||
|
||
class ModelList(BaseModel):
    """Response body for the list-models endpoint."""

    object: str = "list"
    data: List[ModelCard] = Field(default_factory=list)
|
||
|
||
class UsageInfo(BaseModel):
    """Token accounting attached to (non-streaming) completion responses."""

    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0
|
||
|
||
class ChatCompletionRequest(BaseModel):
    """Request body for the OpenAI-compatible chat-completions endpoint.

    Mirrors OpenAI's chat parameters, plus server-specific extras below the
    "Additional parameters" marker.
    """

    model: str
    # Either a raw prompt string or OpenAI-style
    # [{"role": ..., "content": ...}] message dicts.
    messages: Union[str, List[Dict[str, str]]]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    max_tokens: Optional[int] = 16
    # One stop sequence or several; defaults to no stop sequences.
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    user: Optional[str] = None
    # Additional parameters (not part of the OpenAI schema)
    best_of: Optional[int] = None
    # NOTE(review): -1 presumably disables top-k filtering — confirm in server.
    top_k: Optional[int] = -1
    ignore_eos: Optional[bool] = False
    use_beam_search: Optional[bool] = False
|
||
|
||
class CompletionRequest(BaseModel):
    """Request body for the OpenAI-compatible completions endpoint.

    Mirrors OpenAI's completion parameters, plus server-specific extras
    below the "Additional parameters" marker.
    """

    model: str
    # A single prompt or a batch of prompts.
    prompt: Union[str, List[str]]
    suffix: Optional[str] = None
    max_tokens: Optional[int] = 16
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    stream: Optional[bool] = False
    logprobs: Optional[int] = None
    echo: Optional[bool] = False
    # One stop sequence or several; defaults to no stop sequences.
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    best_of: Optional[int] = None
    logit_bias: Optional[Dict[str, float]] = None
    user: Optional[str] = None
    # Additional parameters (not part of the OpenAI schema)
    # NOTE(review): -1 presumably disables top-k filtering — confirm in server.
    top_k: Optional[int] = -1
    ignore_eos: Optional[bool] = False
    use_beam_search: Optional[bool] = False
|
||
|
||
class LogProbs(BaseModel):
    """Per-token log-probability details, matching OpenAI's `logprobs` object."""

    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[Optional[float]] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
|
||
|
||
class CompletionResponseChoice(BaseModel):
    """One candidate result inside a (non-streaming) completion response."""

    index: int
    text: str
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None
|
||
|
||
class CompletionResponse(BaseModel):
    """Full (non-streaming) response body for the completions endpoint."""

    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))  # unix seconds
    model: str
    choices: List[CompletionResponseChoice]
    usage: UsageInfo
|
||
|
||
class CompletionResponseStreamChoice(BaseModel):
    """One candidate's incremental text inside a streamed completion chunk."""

    index: int
    text: str
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None
|
||
|
||
class CompletionStreamResponse(BaseModel):
    """One streamed chunk of a completion response (no usage field)."""

    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))  # unix seconds
    model: str
    choices: List[CompletionResponseStreamChoice]
|
||
|
||
class ChatMessage(BaseModel):
    """A single chat message: a role (e.g. "user") and its text content."""

    role: str
    content: str
|
||
|
||
class ChatCompletionResponseChoice(BaseModel):
    """One candidate message inside a (non-streaming) chat completion response."""

    index: int
    message: ChatMessage
    finish_reason: Optional[Literal["stop", "length"]] = None
|
||
|
||
class ChatCompletionResponse(BaseModel):
    """Full (non-streaming) response body for the chat-completions endpoint."""

    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: str = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))  # unix seconds
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: UsageInfo
|
||
|
||
class DeltaMessage(BaseModel):
    """Incremental piece of a chat message in a streamed chunk; either
    field may be absent in any given chunk."""

    role: Optional[str] = None
    content: Optional[str] = None
|
||
|
||
class ChatCompletionResponseStreamChoice(BaseModel):
    """One candidate's delta inside a streamed chat completion chunk."""

    index: int
    delta: DeltaMessage
    finish_reason: Optional[Literal["stop", "length"]] = None
|
||
|
||
class ChatCompletionStreamResponse(BaseModel):
    """One streamed chunk of a chat completion response."""

    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: str = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))  # unix seconds
    model: str
    choices: List[ChatCompletionResponseStreamChoice]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
deepsparse-nightly[server,transformers] | ||
openai |
Oops, something went wrong.