Skip to content

Commit

Permalink
[Pipeline Refactor][server][OpenAI] Enable OpenAI to use new text gen…
Browse files Browse the repository at this point in the history
… pipeline (#1477)

* add pipeline.from_config

* update pipeline config; update deepsparse server to use new pipeline

* update tests

* fix imports in server config.py

* fix state update with parse_inputs being added

* update openai server to use new v2 text gen pipeline; disable streaming while unavailable in v2
  • Loading branch information
dsikka committed Dec 14, 2023
1 parent 7c183bb commit 543199e
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 10 deletions.
36 changes: 27 additions & 9 deletions src/deepsparse/server/openai_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@
from http import HTTPStatus
from typing import AsyncGenerator, Dict, List, Optional

from deepsparse.legacy import Pipeline
from deepsparse.legacy.tasks import SupportedTasks
from deepsparse import Pipeline
from deepsparse.server.config import EndpointConfig
from deepsparse.server.helpers import create_error_response
from deepsparse.server.output import CompletionOutput, RequestOutput
Expand All @@ -43,6 +42,8 @@
random_uuid,
)
from deepsparse.server.server import Server
from deepsparse.tasks import SupportedTasks
from deepsparse.utils import InferenceState
from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.responses import StreamingResponse

Expand Down Expand Up @@ -92,6 +93,12 @@ async def create_chat_completion(raw_request: Request):
request = ChatCompletionRequest(**await raw_request.json())
_LOGGER.debug("Received chat completion request %s" % request)

# disable streaming until enabled for v2.
if request.stream:
return create_error_response(
HTTPStatus.BAD_REQUEST, "Streaming is unavilable"
)

if isinstance(request.messages, str):
prompt = request.messages
else:
Expand Down Expand Up @@ -215,6 +222,12 @@ async def create_completion(raw_request: Request):
request = CompletionRequest(**await raw_request.json())
_LOGGER.debug("Received completion request: %s" % request)

# disable streaming until enabled for v2.
if request.stream:
return create_error_response(
HTTPStatus.BAD_REQUEST, "Streaming is unavilable"
)

model = request.model

pipeline = app.model_to_pipeline.get(model)
Expand Down Expand Up @@ -318,19 +331,20 @@ def _add_model(
:param endpoint_config: endpoint config for the specific model being added
"""
pipeline_config = endpoint_config.to_pipeline_config()
pipeline_config.kwargs["executor"] = self.executor

_LOGGER.debug("Initializing pipeline for %s" % endpoint_config.name)

if not SupportedTasks.is_text_generation(pipeline_config.task):
if not (
SupportedTasks.is_text_generation(pipeline_config.task)
or SupportedTasks.is_code_generation(pipeline_config.task)
):
raise ValueError(
"OpenAI API is only available for one of the following "
f"tasks: {SupportedTasks.text_generation._fields}"
f"tasks: {SupportedTasks.text_generation._fields}, "
f"{SupportedTasks.code_generation._fields}"
)

pipeline = Pipeline.from_config(
pipeline_config, self.context, self.server_logger
)
pipeline = Pipeline.from_config(pipeline_config, context=self.context)

if not self.model_to_pipeline.get(endpoint_config.model):
model_card = ModelCard(
Expand Down Expand Up @@ -369,7 +383,11 @@ def tokenize(text: str) -> List[int]:
presence_penalty = generation_kwargs.pop("presence_penalty")
stop = generation_kwargs.pop("stop")

output = pipeline(
inference_state = InferenceState()
inference_state.create_state({})

output = await pipeline.run_async(
inference_state=inference_state,
sequences=prompt,
generation_config=generation_kwargs,
streaming=stream,
Expand Down
2 changes: 1 addition & 1 deletion tests/server/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from deepsparse.legacy import Pipeline
from deepsparse import Pipeline
from deepsparse.server.config import EndpointConfig, ServerConfig
from deepsparse.server.openai_server import (
ChatCompletionRequest,
Expand Down

0 comments on commit 543199e

Please sign in to comment.