Implement OpenAI-compatible server
mgoin committed Aug 7, 2023
1 parent ffeb98f commit 2d8a4cd
Showing 3 changed files with 1,015 additions and 0 deletions.
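Since this commit adds an OpenAI-compatible server, a client should be able to exercise it with the stock openai Python package (the 0.x API, current as of this commit). A minimal sketch; the server address, port, and model name below are assumptions, not values from this commit:

import openai

# Point the client at the locally running OpenAI-compatible server
# (assumed address; adjust to wherever the server is launched).
openai.api_base = "http://localhost:8000/v1"
openai.api_key = "EMPTY"  # local server, no real key expected

# "mpt-7b-chat" is a placeholder model name, not from this commit.
completion = openai.Completion.create(
    model="mpt-7b-chat",
    prompt="Say hello",
    max_tokens=8,
)
print(completion.choices[0].text)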
94 changes: 94 additions & 0 deletions examples/openai-server/outputs.py
@@ -0,0 +1,94 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Optional


class CompletionOutput:
"""The output data of one completion output of a request.
Args:
index: The index of the output in the request.
text: The generated output text.
token_ids: The token IDs of the generated output text.
cumulative_logprob: The cumulative log probability of the generated
output text.
logprobs: The log probabilities of the most likely tokens at each
position, if logprobs were requested.
finish_reason: The reason why the sequence is finished.
"""

def __init__(
self,
index: int,
text: str,
token_ids: List[int],
cumulative_logprob: float = 0.0,
logprobs: Optional[List[Dict[int, float]]] = None,
finish_reason: Optional[str] = None,
) -> None:
self.index = index
self.text = text
self.token_ids = token_ids
self.cumulative_logprob = cumulative_logprob
self.logprobs = logprobs
self.finish_reason = finish_reason

def finished(self) -> bool:
return self.finish_reason is not None

def __repr__(self) -> str:
return (
f"CompletionOutput(index={self.index}, "
f"text={self.text!r}, "
f"token_ids={self.token_ids}, "
f"cumulative_logprob={self.cumulative_logprob}, "
f"logprobs={self.logprobs}, "
f"finish_reason={self.finish_reason})"
)


class RequestOutput:
"""The output data of a request to the LLM.
Args:
request_id: The unique ID of the request.
prompt: The prompt string of the request.
prompt_token_ids: The token IDs of the prompt.
outputs: The output sequences of the request.
finished: Whether the whole request is finished.
"""

def __init__(
self,
request_id: str,
prompt: str,
prompt_token_ids: List[int],
outputs: List[CompletionOutput],
finished: bool,
) -> None:
self.request_id = request_id
self.prompt = prompt
self.prompt_token_ids = prompt_token_ids
self.outputs = outputs
self.finished = finished

def __repr__(self) -> str:
return (
f"RequestOutput(request_id={self.request_id}, "
f"prompt={self.prompt!r}, "
f"prompt_token_ids={self.prompt_token_ids}, "
f"outputs={self.outputs}, "
f"finished={self.finished})"
)
190 changes: 190 additions & 0 deletions examples/openai-server/protocol.py
@@ -0,0 +1,190 @@
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py

# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
import uuid
from typing import Dict, List, Literal, Optional, Union

from pydantic import BaseModel, Field


def random_uuid() -> str:
return uuid.uuid4().hex


class ErrorResponse(BaseModel):
object: str = "error"
message: str
type: str
param: Optional[str] = None
code: Optional[str] = None


class ModelPermission(BaseModel):
id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
object: str = "model_permission"
created: int = Field(default_factory=lambda: int(time.time()))
allow_create_engine: bool = False
allow_sampling: bool = True
allow_logprobs: bool = True
allow_search_indices: bool = False
allow_view: bool = True
allow_fine_tuning: bool = False
organization: str = "*"
group: Optional[str] = None
is_blocking: bool = False


class ModelCard(BaseModel):
id: str
object: str = "model"
created: int = Field(default_factory=lambda: int(time.time()))
owned_by: str = "neuralmagic"
root: Optional[str] = None
parent: Optional[str] = None
permission: List[ModelPermission] = Field(default_factory=list)


class ModelList(BaseModel):
object: str = "list"
data: List[ModelCard] = Field(default_factory=list)


class UsageInfo(BaseModel):
prompt_tokens: int = 0
total_tokens: int = 0
completion_tokens: Optional[int] = 0


class ChatCompletionRequest(BaseModel):
model: str
messages: Union[str, List[Dict[str, str]]]
temperature: Optional[float] = 0.7
top_p: Optional[float] = 1.0
n: Optional[int] = 1
max_tokens: Optional[int] = 16
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False
presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0
logit_bias: Optional[Dict[str, float]] = None
user: Optional[str] = None
# Additional parameters
best_of: Optional[int] = None
top_k: Optional[int] = -1
ignore_eos: Optional[bool] = False
use_beam_search: Optional[bool] = False


class CompletionRequest(BaseModel):
model: str
prompt: Union[str, List[str]]
suffix: Optional[str] = None
max_tokens: Optional[int] = 16
temperature: Optional[float] = 1.0
top_p: Optional[float] = 1.0
n: Optional[int] = 1
stream: Optional[bool] = False
logprobs: Optional[int] = None
echo: Optional[bool] = False
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0
best_of: Optional[int] = None
logit_bias: Optional[Dict[str, float]] = None
user: Optional[str] = None
# Additional parameters
top_k: Optional[int] = -1
ignore_eos: Optional[bool] = False
use_beam_search: Optional[bool] = False


class LogProbs(BaseModel):
text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[Optional[float]] = Field(default_factory=list)
tokens: List[str] = Field(default_factory=list)
top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)


class CompletionResponseChoice(BaseModel):
index: int
text: str
logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length"]] = None


class CompletionResponse(BaseModel):
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[CompletionResponseChoice]
usage: UsageInfo


class CompletionResponseStreamChoice(BaseModel):
index: int
text: str
logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length"]] = None


class CompletionStreamResponse(BaseModel):
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[CompletionResponseStreamChoice]


class ChatMessage(BaseModel):
role: str
content: str


class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
finish_reason: Optional[Literal["stop", "length"]] = None


class ChatCompletionResponse(BaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
object: str = "chat.completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseChoice]
usage: UsageInfo


class DeltaMessage(BaseModel):
role: Optional[str] = None
content: Optional[str] = None


class ChatCompletionResponseStreamChoice(BaseModel):
index: int
delta: DeltaMessage
finish_reason: Optional[Literal["stop", "length"]] = None


class ChatCompletionStreamResponse(BaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
object: str = "chat.completion.chunk"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseStreamChoice]
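As a usage sketch, these pydantic models validate incoming OpenAI-style payloads and serialize outgoing ones. The example below parses a request body and builds a matching response; the model name and token counts are illustrative:

from protocol import (CompletionRequest, CompletionResponse,
                      CompletionResponseChoice, UsageInfo)

# Validate an incoming JSON body against the request schema;
# unspecified fields pick up their defaults.
raw = {"model": "mpt-7b-chat", "prompt": "Say hello", "max_tokens": 8}
request = CompletionRequest(**raw)
assert request.temperature == 1.0  # default applied

# Build the corresponding OpenAI-style response body.
response = CompletionResponse(
    model=request.model,
    choices=[
        CompletionResponseChoice(index=0, text=" Hello!", finish_reason="stop"),
    ],
    usage=UsageInfo(prompt_tokens=3, completion_tokens=2, total_tokens=5),
)
print(response.json())  # pydantic v1-style serialization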
