[Pipeline Refactor][server][OpenAI] Enable OpenAI to use new text gen pipeline #1477

Merged (11 commits, Dec 14, 2023)
src/deepsparse/server/openai_server.py: 36 changes (27 additions, 9 deletions)
@@ -18,8 +18,7 @@
 from http import HTTPStatus
 from typing import AsyncGenerator, Dict, List, Optional

-from deepsparse.legacy import Pipeline
-from deepsparse.legacy.tasks import SupportedTasks
+from deepsparse import Pipeline
 from deepsparse.server.config import EndpointConfig
 from deepsparse.server.helpers import create_error_response
 from deepsparse.server.output import CompletionOutput, RequestOutput
@@ -43,6 +42,8 @@
     random_uuid,
 )
 from deepsparse.server.server import Server
+from deepsparse.tasks import SupportedTasks
+from deepsparse.utils import InferenceState
 from fastapi import BackgroundTasks, FastAPI, Request
 from fastapi.responses import StreamingResponse
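These import changes move the OpenAI server off `deepsparse.legacy` and onto the refactored pipeline stack, with `SupportedTasks` and `InferenceState` now sourced from the new modules. For orientation, a minimal sketch of standing up the v2 pipeline directly; `Pipeline.create`, the task name, and the model stub are assumptions for illustration and are not part of this diff:

```python
# Hypothetical orientation sketch; not taken from this PR.
# Pipeline.create, the task name, and the model stub are assumed for illustration.
from deepsparse import Pipeline

pipeline = Pipeline.create(
    task="text_generation",               # assumed v2 task identifier
    model_path="zoo:some/text-gen-stub",  # placeholder model stub
)
# The v2 text generation pipeline takes prompts via `sequences`, as this PR shows.
output = pipeline(sequences="Write a haiku about sparsity.")
print(output)
```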
@@ -92,6 +93,12 @@ async def create_chat_completion(raw_request: Request):
     request = ChatCompletionRequest(**await raw_request.json())
     _LOGGER.debug("Received chat completion request %s" % request)

+    # disable streaming until enabled for v2.
+    if request.stream:
+        return create_error_response(
+            HTTPStatus.BAD_REQUEST, "Streaming is unavailable"
+        )
+
     if isinstance(request.messages, str):
         prompt = request.messages
     else:
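Both endpoints now reject `stream=True` until streaming lands in the v2 pipeline. A hypothetical client-side view of that guard; the host, port, route, and payload shape are assumptions based on the OpenAI-compatible surface, not confirmed by this diff:

```python
# Hypothetical client sketch; URL and payload values are placeholders.
import requests

resp = requests.post(
    "http://localhost:5543/v1/chat/completions",  # assumed server address and route
    json={
        "model": "my-model",   # placeholder model name
        "messages": "Hello!",  # a plain string is accepted per the isinstance check above
        "stream": True,        # trips the new guard
    },
)
print(resp.status_code)  # 400, from HTTPStatus.BAD_REQUEST
print(resp.json())       # error payload containing "Streaming is unavailable"
```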
@@ -215,6 +222,12 @@ async def create_completion(raw_request: Request):
     request = CompletionRequest(**await raw_request.json())
     _LOGGER.debug("Received completion request: %s" % request)

+    # disable streaming until enabled for v2.
+    if request.stream:
+        return create_error_response(
+            HTTPStatus.BAD_REQUEST, "Streaming is unavailable"
+        )
+
     model = request.model

     pipeline = app.model_to_pipeline.get(model)
@@ -318,19 +331,20 @@ def _add_model(
     :param endpoint_config: endpoint config for the specific model being added
     """
     pipeline_config = endpoint_config.to_pipeline_config()
-    pipeline_config.kwargs["executor"] = self.executor

     _LOGGER.debug("Initializing pipeline for %s" % endpoint_config.name)

-    if not SupportedTasks.is_text_generation(pipeline_config.task):
+    if not (
+        SupportedTasks.is_text_generation(pipeline_config.task)
+        or SupportedTasks.is_code_generation(pipeline_config.task)
+    ):
         raise ValueError(
             "OpenAI API is only available for one of the following "
-            f"tasks: {SupportedTasks.text_generation._fields}"
+            f"tasks: {SupportedTasks.text_generation._fields}, "
+            f"{SupportedTasks.code_generation._fields}"
         )

-    pipeline = Pipeline.from_config(
-        pipeline_config, self.context, self.server_logger
-    )
+    pipeline = Pipeline.from_config(pipeline_config, context=self.context)

     if not self.model_to_pipeline.get(endpoint_config.model):
         model_card = ModelCard(
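The task gate now admits code generation alongside text generation, and `Pipeline.from_config` takes the context as a keyword argument, with the legacy `server_logger` argument dropped. Distilled into a standalone helper (the helper name is invented; the calls themselves come straight from the diff):

```python
# Distilled restatement of the gate and construction above; the helper name is invented.
from deepsparse import Pipeline
from deepsparse.tasks import SupportedTasks


def build_openai_pipeline(pipeline_config, context):
    """Reject unsupported tasks, then build the v2 pipeline."""
    if not (
        SupportedTasks.is_text_generation(pipeline_config.task)
        or SupportedTasks.is_code_generation(pipeline_config.task)
    ):
        raise ValueError(f"OpenAI API does not support task: {pipeline_config.task}")
    # Context is now passed as a keyword; the legacy server_logger argument is gone.
    return Pipeline.from_config(pipeline_config, context=context)
```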
@@ -369,7 +383,11 @@ def tokenize(text: str) -> List[int]:
     presence_penalty = generation_kwargs.pop("presence_penalty")
     stop = generation_kwargs.pop("stop")

-    output = pipeline(
+    inference_state = InferenceState()
+    inference_state.create_state({})
+
+    output = await pipeline.run_async(
+        inference_state=inference_state,
         sequences=prompt,
         generation_config=generation_kwargs,
         streaming=stream,
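This is the core behavioral change: the pipeline is now invoked asynchronously with an explicit `InferenceState` instead of being called directly. A minimal sketch of the new calling convention, assuming a pipeline built as above; the generation settings are placeholders:

```python
# Minimal sketch of the v2 async calling convention; values are placeholders.
import asyncio

from deepsparse.utils import InferenceState


async def generate(pipeline, prompt: str):
    inference_state = InferenceState()
    inference_state.create_state({})  # seed with an empty state dict, as in the PR
    return await pipeline.run_async(
        inference_state=inference_state,
        sequences=prompt,
        generation_config={"max_new_tokens": 32},  # assumed key, for illustration
        streaming=False,  # the server currently rejects streaming requests
    )


# output = asyncio.run(generate(pipeline, "Hello"))
```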
tests/server/test_openai.py: 2 changes (1 addition, 1 deletion)
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import pytest
-from deepsparse.legacy import Pipeline
+from deepsparse import Pipeline
 from deepsparse.server.config import EndpointConfig, ServerConfig
 from deepsparse.server.openai_server import (
     ChatCompletionRequest,