[server] Refactor + OpenAI Chat Completion Support (#1288)
* refactor server for different integrations; additional functionality for chat completion streaming and non-streaming

* further refactor server

* add support such that openai can host multiple models

* update all tests

* fix output for n > 1

* add inline comment explaining ProxyPipeline

* [server] Update OpenAI Model Support (#1300)

* update server

* allow users to send requests with new models

* use v1; move around baseroutes

* add openai path

* PR comments

* clean up output classes to be dataclasses, add docstrings, clean up generation kwargs
dsikka committed Oct 10, 2023
1 parent 6cffa85 commit d99a82c
Showing 16 changed files with 1,247 additions and 476 deletions.
examples/openai-server/protocol.py (1 addition, 1 deletion)
@@ -72,7 +72,7 @@ class UsageInfo(BaseModel):

class ChatCompletionRequest(BaseModel):
    model: str
-    messages: Union[str, List[Dict[str, str]]]
+    messages: Union[str, List[str]]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
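For reference, a minimal sketch of a request against the updated schema; the import path and model name are illustrative, not taken from this diff:

```python
# Hypothetical usage; run from the examples/openai-server directory.
from protocol import ChatCompletionRequest

req = ChatCompletionRequest(
    model="mpt-7b-chat",                    # placeholder model name
    messages=["Hello, what is sparsity?"],  # a plain List[str] now validates
)
print(req.temperature, req.top_p, req.n)    # defaults: 0.7 1.0 1
```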
setup.py (1 addition, 0 deletions)
@@ -306,6 +306,7 @@ def _setup_entry_points() -> Dict:
"deepsparse.benchmark_pipeline=deepsparse.benchmark.benchmark_pipeline:main", # noqa E501
"deepsparse.benchmark_sweep=deepsparse.benchmark.benchmark_sweep:main",
"deepsparse.server=deepsparse.server.cli:main",
"deepsparse.openai=deepsparse.server.cli:openai",
"deepsparse.object_detection.annotate=deepsparse.yolo.annotate:main",
"deepsparse.yolov8.annotate=deepsparse.yolov8.annotate:main",
"deepsparse.yolov8.eval=deepsparse.yolov8.validation:main",
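The new console script exposes the OpenAI-compatible server as `deepsparse.openai`. A sketch of the equivalent programmatic invocation through click, assuming a config file name and the `--port` flag defined in cli.py:

```python
# Equivalent to the shell command `deepsparse.openai server-config.yaml --port 5543`;
# the file name and port value are assumptions for illustration.
from deepsparse.server.cli import openai

openai(["server-config.yaml", "--port", "5543"], standalone_mode=False)
```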
src/deepsparse/server/cli.py (42 additions, 34 deletions)
@@ -29,7 +29,9 @@

from deepsparse.pipeline import SupportedTasks
from deepsparse.server.config import EndpointConfig, ServerConfig
-from deepsparse.server.server import start_server
+from deepsparse.server.deepsparse_server import DeepsparseServer
+from deepsparse.server.openai_server import OpenAIServer
+from deepsparse.server.sagemaker import SagemakerServer


HOST_OPTION = click.option(
@@ -109,7 +111,7 @@

INTEGRATION_OPTION = click.option(
"--integration",
type=click.Choice(["local", "sagemaker"], case_sensitive=False),
type=click.Choice(["local", "sagemaker", "openai"], case_sensitive=False),
default="local",
help=(
"Name of deployment integration that this server will be deployed to "
@@ -206,6 +208,20 @@ def main(
    ...
    ```
    """

+    def _fetch_server(integration: str, config_path: str):
+        if integration == "local":
+            return DeepsparseServer(server_config=config_path)
+        elif integration == "sagemaker":
+            return SagemakerServer(server_config=config_path)
+        elif integration == "openai":
+            return OpenAIServer(server_config=config_path)
+        else:
+            raise ValueError(
+                f"{integration} is not a supported integration. Must be "
+                "one of local, sagemaker, or openai."
+            )

    if ctx.invoked_subcommand is not None:
        return
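Conceptually, `main` now resolves the integration flag to a server class and delegates startup to it; a sketch with assumed argument values:

```python
# Illustrative only: _fetch_server is local to main(), and the host, port,
# and log-level values here are assumptions rather than the CLI defaults.
server = _fetch_server(integration="openai", config_path="server-config.yaml")
server.start_server("0.0.0.0", 5543, "info", hot_reload_config=False)
```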

@@ -234,14 +250,33 @@ def main(
        config_path = os.path.join(tmp_dir, "server-config.yaml")
        with open(config_path, "w") as fp:
            yaml.dump(cfg.dict(), fp)
-        start_server(
-            config_path, host, port, log_level, hot_reload_config=hot_reload_config
+
+        server = _fetch_server(integration=integration, config_path=config_path)
+        server.start_server(
+            host, port, log_level, hot_reload_config=hot_reload_config
        )

    if config_file is not None:
-        start_server(
-            config_file, host, port, log_level, hot_reload_config=hot_reload_config
-        )
+        server = _fetch_server(integration=integration, config_path=config_file)
+        server.start_server(host, port, log_level, hot_reload_config=hot_reload_config)


+@main.command(
+    context_settings=dict(
+        token_normalize_func=lambda x: x.replace("-", "_"), show_default=True
+    ),
+)
+@click.argument("config-file", type=str)
+@HOST_OPTION
+@PORT_OPTION
+@LOG_LEVEL_OPTION
+@HOT_RELOAD_OPTION
+def openai(
+    config_file: str, host: str, port: int, log_level: str, hot_reload_config: bool
+):
+    server = OpenAIServer(server_config=config_file)
+    server.start_server(host, port, log_level, hot_reload_config=hot_reload_config)
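The command takes a server config file as its argument; a minimal sketch of generating one with the same config classes the CLI uses, assuming the remaining ServerConfig fields have usable defaults (task and model values are placeholders):

```python
# Writes a config the `openai` command can consume; the task name and
# model stub are placeholders, not values taken from this PR.
import yaml

from deepsparse.server.config import EndpointConfig, ServerConfig

cfg = ServerConfig(
    integration="openai",
    endpoints=[EndpointConfig(task="text_generation", model="zoo:mpt-7b-chat")],
    loggers={},
)
with open("server-config.yaml", "w") as fp:
    yaml.dump(cfg.dict(), fp)
```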


@main.command(
@@ -264,9 +299,6 @@ def config(
"Use the `--config_file` argument instead.",
category=DeprecationWarning,
)
start_server(
config_path, host, port, log_level, hot_reload_config=hot_reload_config
)


@main.command(
@@ -311,30 +343,6 @@ def task(
        category=DeprecationWarning,
    )
-
-    cfg = ServerConfig(
-        num_cores=num_cores,
-        num_workers=num_workers,
-        integration=integration,
-        endpoints=[
-            EndpointConfig(
-                task=task,
-                name=f"{task}",
-                route="/predict",
-                model=model_path,
-                batch_size=batch_size,
-            )
-        ],
-        loggers={},
-    )
-
-    with TemporaryDirectory() as tmp_dir:
-        config_path = os.path.join(tmp_dir, "server-config.yaml")
-        with open(config_path, "w") as fp:
-            yaml.dump(cfg.dict(), fp)
-        start_server(
-            config_path, host, port, log_level, hot_reload_config=hot_reload_config
-        )


if __name__ == "__main__":
    main()
src/deepsparse/server/deepsparse_server.py (169 additions, 0 deletions; new file)
@@ -0,0 +1,169 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from functools import partial

from deepsparse import Pipeline
from deepsparse.server.config import EndpointConfig
from deepsparse.server.server import CheckReady, ModelMetaData, ProxyPipeline, Server
from fastapi import FastAPI


_LOGGER = logging.getLogger(__name__)


class DeepsparseServer(Server):
    def _add_routes(self, app):
        @app.get("/v2/health/ready", tags=["health"], response_model=CheckReady)
        @app.get("/v2/health/live", tags=["health"], response_model=CheckReady)
        def _check_health():
            return CheckReady(status="OK")

        @app.get("/v2", tags=["metadata", "server"], response_model=str)
        def _get_server_info():
            return "This is the deepsparse server. Hello!"

        @app.post("/endpoints", tags=["endpoints"], response_model=bool)
        def _add_endpoint_endpoint(cfg: EndpointConfig):
            if cfg.name is None:
                cfg.name = f"endpoint-{len(app.routes)}"
            self._add_endpoint(
                app,
                cfg,
            )
            # force regeneration of the docs
            app.openapi_schema = None
            return True

        @app.delete("/endpoints", tags=["endpoints"], response_model=bool)
        def _delete_endpoint(cfg: EndpointConfig):
            _LOGGER.info(f"Deleting endpoint for {cfg}")
            matching = [r for r in app.routes if r.path == cfg.route]
            assert len(matching) == 1
            app.routes.remove(matching[0])
            # force regeneration of the docs
            app.openapi_schema = None
            return True

        for endpoint_config in self.server_config.endpoints:
            self._add_endpoint(
                app,
                endpoint_config,
            )

        _LOGGER.info(f"Added endpoints: {[route.path for route in app.routes]}")
        return app

    def _add_endpoint(
        self,
        app: FastAPI,
        endpoint_config: EndpointConfig,
    ):
        pipeline_config = endpoint_config.to_pipeline_config()
        pipeline_config.kwargs["executor"] = self.executor

        _LOGGER.info(f"Initializing pipeline for '{endpoint_config.name}'")
        pipeline = Pipeline.from_config(
            pipeline_config, self.context, self.server_logger
        )

        _LOGGER.info(f"Adding endpoints for '{endpoint_config.name}'")
        self._add_inference_endpoints(
            app,
            endpoint_config,
            pipeline,
        )
        self._add_status_and_metadata_endpoints(app, endpoint_config, pipeline)

    def _add_status_and_metadata_endpoints(
        self,
        app: FastAPI,
        endpoint_config: EndpointConfig,
        pipeline: Pipeline,
    ):
        routes_and_fns = []
        meta_and_fns = []

        if endpoint_config.route:
            endpoint_config.route = self.clean_up_route(endpoint_config.route)
            route_ready = f"{endpoint_config.route}/ready"
            route_meta = endpoint_config.route
        else:
            route_ready = f"/v2/models/{endpoint_config.name}/ready"
            route_meta = f"/v2/models/{endpoint_config.name}"

        routes_and_fns.append((route_ready, Server.pipeline_ready))
        meta_and_fns.append(
            (route_meta, partial(Server.model_metadata, ProxyPipeline(pipeline)))
        )

        self._update_routes(
            app=app,
            routes_and_fns=meta_and_fns,
            response_model=ModelMetaData,
            methods=["GET"],
            tags=["model", "metadata"],
        )
        self._update_routes(
            app=app,
            routes_and_fns=routes_and_fns,
            response_model=CheckReady,
            methods=["GET"],
            tags=["model", "health"],
        )

    def _add_inference_endpoints(
        self,
        app: FastAPI,
        endpoint_config: EndpointConfig,
        pipeline: Pipeline,
    ):
        routes_and_fns = []
        route = (
            f"{endpoint_config.route}/infer"
            if endpoint_config.route
            else f"/v2/models/{endpoint_config.name}/infer"
        )
        route = self.clean_up_route(route)

        routes_and_fns.append(
            (
                route,
                partial(
                    Server.predict,
                    ProxyPipeline(pipeline),
                    self.server_config.system_logging,
                ),
            )
        )
        if hasattr(pipeline.input_schema, "from_files"):
            routes_and_fns.append(
                (
                    route + "/from_files",
                    partial(
                        Server.predict_from_files,
                        ProxyPipeline(pipeline),
                        self.server_config.system_logging,
                    ),
                )
            )

        self._update_routes(
            app=app,
            routes_and_fns=routes_and_fns,
            response_model=pipeline.output_schema,
            methods=["POST"],
            tags=["model", "inference"],
        )
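A quick smoke test of the routes this class registers, assuming a server is already running locally (base URL, task, and model stub are assumptions):

```python
# Exercises the health, info, and dynamic-endpoint routes defined above.
import requests

base = "http://localhost:5543"
print(requests.get(f"{base}/v2/health/ready").json())  # {"status": "OK"}
print(requests.get(f"{base}/v2").json())               # server greeting

# register a new endpoint at runtime via POST /endpoints
cfg = {"task": "text_classification", "model": "zoo:placeholder-model-stub"}
print(requests.post(f"{base}/endpoints", json=cfg).json())  # True on success
```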
src/deepsparse/server/helpers.py (15 additions, 1 deletion)
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+from http import HTTPStatus
from typing import Any, Dict, List, Optional

import numpy
@@ -25,9 +26,22 @@
)
from deepsparse.loggers.config import MetricFunctionConfig, SystemLoggingGroup
from deepsparse.server.config import EndpointConfig, ServerConfig
+from deepsparse.server.protocol import ErrorResponse
+from fastapi.responses import JSONResponse


__all__ = ["server_logger_from_config", "prep_outputs_for_serialization"]
__all__ = [
"create_error_response",
"server_logger_from_config",
"prep_outputs_for_serialization",
]


+def create_error_response(status_code: HTTPStatus, message: str) -> JSONResponse:
+    return JSONResponse(
+        ErrorResponse(message=message, type="invalid_request_error").dict(),
+        status_code=status_code.value,
+    )


def server_logger_from_config(config: ServerConfig) -> BaseLogger:
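A sketch of how a route handler might use the new helper; the handler and the failure condition are illustrative:

```python
# Returns an OpenAI-style error payload with the matching HTTP status code.
from http import HTTPStatus

from deepsparse.server.helpers import create_error_response


def validate_model(requested_model: str, served_models: list):
    if requested_model not in served_models:
        return create_error_response(
            HTTPStatus.NOT_FOUND, f"model {requested_model} not found"
        )
```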