[server] Update server routes to be compliant with MLServer #1237

Merged · 16 commits · Oct 11, 2023
Changes from 9 commits
2 changes: 1 addition & 1 deletion examples/openai-server/protocol.py
@@ -72,7 +72,7 @@ class UsageInfo(BaseModel):

class ChatCompletionRequest(BaseModel):
model: str
-    messages: Union[str, List[Dict[str, str]]]
+    messages: Union[str, List[str]]
temperature: Optional[float] = 0.7
top_p: Optional[float] = 1.0
n: Optional[int] = 1
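Illustrative sketch (not from the PR): a minimal stand-in mirroring the updated schema, to show what the narrowed `messages` field now accepts; the model name is a placeholder.

```python
from typing import List, Optional, Union

from pydantic import BaseModel


class ChatCompletionRequest(BaseModel):
    # mirrors the updated schema above; other fields omitted
    model: str
    messages: Union[str, List[str]]
    temperature: Optional[float] = 0.7


# a single prompt string is accepted...
ChatCompletionRequest(model="placeholder-model", messages="Hello!")
# ...as is a list of prompt strings; the old role/content dict form
# (e.g. [{"role": "user", "content": "Hello!"}]) would now fail validation
ChatCompletionRequest(model="placeholder-model", messages=["Hello!", "And again"])
```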
1 change: 1 addition & 0 deletions setup.py
@@ -297,6 +297,7 @@ def _setup_entry_points() -> Dict:
"deepsparse.benchmark_pipeline=deepsparse.benchmark.benchmark_pipeline:main", # noqa E501
"deepsparse.benchmark_sweep=deepsparse.benchmark.benchmark_sweep:main",
"deepsparse.server=deepsparse.server.cli:main",
"deepsparse.openai=deepsparse.server.cli:openai",
"deepsparse.object_detection.annotate=deepsparse.yolo.annotate:main",
"deepsparse.yolov8.annotate=deepsparse.yolov8.annotate:main",
"deepsparse.yolov8.eval=deepsparse.yolov8.validation:main",
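Reviewer note: given this entry point, a fresh install should expose a `deepsparse.openai` console command wired to the new `openai` subcommand in `deepsparse.server.cli` (added below), alongside the existing `deepsparse.server` command.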
76 changes: 42 additions & 34 deletions src/deepsparse/server/cli.py
@@ -29,7 +29,9 @@

from deepsparse.pipeline import SupportedTasks
from deepsparse.server.config import EndpointConfig, ServerConfig
-from deepsparse.server.server import start_server
+from deepsparse.server.deepsparse_server import DeepsparseServer
+from deepsparse.server.openai_server import OpenAIServer
+from deepsparse.server.sagemaker import SagemakerServer


HOST_OPTION = click.option(
@@ -109,7 +111,7 @@

INTEGRATION_OPTION = click.option(
"--integration",
-    type=click.Choice(["local", "sagemaker"], case_sensitive=False),
+    type=click.Choice(["local", "sagemaker", "openai"], case_sensitive=False),
default="local",
help=(
"Name of deployment integration that this server will be deployed to "
@@ -206,6 +208,20 @@ def main(
...
```
"""

    def _fetch_server(integration: str, config_path: str):
        if integration == "local":
            return DeepsparseServer(server_config=config_path)
        elif integration == "sagemaker":
            return SagemakerServer(server_config=config_path)
        elif integration == "openai":
            return OpenAIServer(server_config=config_path)
        else:
            raise ValueError(
                f"{integration} is not a supported integration. Must be "
                "one of local, sagemaker, or openai."
            )

if ctx.invoked_subcommand is not None:
return

@@ -234,14 +250,33 @@ def main(
config_path = os.path.join(tmp_dir, "server-config.yaml")
with open(config_path, "w") as fp:
yaml.dump(cfg.dict(), fp)
-        start_server(
-            config_path, host, port, log_level, hot_reload_config=hot_reload_config
+
+        server = _fetch_server(integration=integration, config_path=config_path)
+        server.start_server(
+            host, port, log_level, hot_reload_config=hot_reload_config
)

if config_file is not None:
-        start_server(
-            config_file, host, port, log_level, hot_reload_config=hot_reload_config
-        )
+        server = _fetch_server(integration=integration, config_path=config_file)
+        server.start_server(host, port, log_level, hot_reload_config=hot_reload_config)


@main.command(
context_settings=dict(
token_normalize_func=lambda x: x.replace("-", "_"), show_default=True
),
)
@click.argument("config-file", type=str)
@HOST_OPTION
@PORT_OPTION
@LOG_LEVEL_OPTION
@HOT_RELOAD_OPTION
def openai(
config_file: str, host: str, port: int, log_level: str, hot_reload_config: bool
):

server = OpenAIServer(server_config=config_file)
server.start_server(host, port, log_level, hot_reload_config=hot_reload_config)
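Illustrative invocation (not from the PR): roughly what running the new command boils down to, per the body above; the config path, host, port, and log level are placeholders, with defaults coming from the shared `HOST_OPTION`/`PORT_OPTION`/`LOG_LEVEL_OPTION` definitions.

```python
# Roughly what `deepsparse.openai server-config.yaml --host 0.0.0.0 --port 5543`
# executes (all values below are placeholders):
from deepsparse.server.openai_server import OpenAIServer

server = OpenAIServer(server_config="server-config.yaml")
server.start_server("0.0.0.0", 5543, "info", hot_reload_config=False)
```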


@main.command(
@@ -264,9 +299,6 @@ def config(
"Use the `--config_file` argument instead.",
category=DeprecationWarning,
)
-    start_server(
-        config_path, host, port, log_level, hot_reload_config=hot_reload_config
-    )


@main.command(
@@ -311,30 +343,6 @@ def task(
category=DeprecationWarning,
)

-    cfg = ServerConfig(
-        num_cores=num_cores,
-        num_workers=num_workers,
-        integration=integration,
-        endpoints=[
-            EndpointConfig(
-                task=task,
-                name=f"{task}",
-                route="/predict",
-                model=model_path,
-                batch_size=batch_size,
-            )
-        ],
-        loggers={},
-    )
-
-    with TemporaryDirectory() as tmp_dir:
-        config_path = os.path.join(tmp_dir, "server-config.yaml")
-        with open(config_path, "w") as fp:
-            yaml.dump(cfg.dict(), fp)
-        start_server(
-            config_path, host, port, log_level, hot_reload_config=hot_reload_config
-        )


if __name__ == "__main__":
main()
169 changes: 169 additions & 0 deletions src/deepsparse/server/deepsparse_server.py
@@ -0,0 +1,169 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from functools import partial

from deepsparse import Pipeline
from deepsparse.server.config import EndpointConfig
from deepsparse.server.server import CheckReady, ModelMetaData, ProxyPipeline, Server
from fastapi import FastAPI


_LOGGER = logging.getLogger(__name__)


class DeepsparseServer(Server):
def _add_routes(self, app):
@app.get("/v2/health/ready", tags=["health"], response_model=CheckReady)
@app.get("/v2/health/live", tags=["health"], response_model=CheckReady)
def _check_health():
return CheckReady(status="OK")

@app.get("/v2", tags=["metadata", "server"], response_model=str)
def _get_server_info():
return "This is the deepsparse server. Hello!"

@app.post("/endpoints", tags=["endpoints"], response_model=bool)
def _add_endpoint_endpoint(cfg: EndpointConfig):
if cfg.name is None:
cfg.name = f"endpoint-{len(app.routes)}"
self._add_endpoint(
app,
cfg,
)
# force regeneration of the docs
app.openapi_schema = None
return True

@app.delete("/endpoints", tags=["endpoints"], response_model=bool)
def _delete_endpoint(cfg: EndpointConfig):
_LOGGER.info(f"Deleting endpoint for {cfg}")
matching = [r for r in app.routes if r.path == cfg.route]
assert len(matching) == 1
app.routes.remove(matching[0])
# force regeneration of the docs
app.openapi_schema = None
return True

for endpoint_config in self.server_config.endpoints:
self._add_endpoint(
app,
endpoint_config,
)

_LOGGER.info(f"Added endpoints: {[route.path for route in app.routes]}")
return app

def _add_endpoint(
self,
app: FastAPI,
endpoint_config: EndpointConfig,
):
pipeline_config = endpoint_config.to_pipeline_config()
pipeline_config.kwargs["executor"] = self.executor

_LOGGER.info(f"Initializing pipeline for '{endpoint_config.name}'")
pipeline = Pipeline.from_config(
pipeline_config, self.context, self.server_logger
)

_LOGGER.info(f"Adding endpoints for '{endpoint_config.name}'")
self._add_inference_endpoints(
app,
endpoint_config,
pipeline,
)
self._add_status_and_metadata_endpoints(app, endpoint_config, pipeline)

def _add_status_and_metadata_endpoints(
self,
app: FastAPI,
endpoint_config: EndpointConfig,
pipeline: Pipeline,
):
routes_and_fns = []
meta_and_fns = []

if endpoint_config.route:
endpoint_config.route = self.clean_up_route(endpoint_config.route)
route_ready = f"{endpoint_config.route}/ready"
route_meta = endpoint_config.route
else:
route_ready = f"/v2/models/{endpoint_config.name}/ready"
route_meta = f"/v2/models/{endpoint_config.name}"

routes_and_fns.append((route_ready, Server.pipeline_ready))
meta_and_fns.append(
(route_meta, partial(Server.model_metadata, ProxyPipeline(pipeline)))
)

self._update_routes(
app=app,
routes_and_fns=meta_and_fns,
response_model=ModelMetaData,
methods=["GET"],
tags=["model", "metadata"],
)
self._update_routes(
app=app,
routes_and_fns=routes_and_fns,
response_model=CheckReady,
methods=["GET"],
tags=["model", "health"],
)

def _add_inference_endpoints(
self,
app: FastAPI,
endpoint_config: EndpointConfig,
pipeline: Pipeline,
):
routes_and_fns = []
route = (
f"{endpoint_config.route}/infer"
if endpoint_config.route
else f"/v2/models/{endpoint_config.name}/infer"
)
route = self.clean_up_route(route)

routes_and_fns.append(
(
route,
partial(
Server.predict,
ProxyPipeline(pipeline),
self.server_config.system_logging,
),
)
)
if hasattr(pipeline.input_schema, "from_files"):
routes_and_fns.append(
(
route + "/from_files",
partial(
Server.predict_from_files,
ProxyPipeline(pipeline),
self.server_config.system_logging,
),
)
)

self._update_routes(
app=app,
routes_and_fns=routes_and_fns,
response_model=pipeline.output_schema,
methods=["POST"],
tags=["model", "inference"],
)
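Illustrative client session (not from the PR): exercising the MLServer-style routes registered above; the host, port, and endpoint name `my-model` are placeholders, and the inference payload shape depends on the pipeline's input schema.

```python
import requests

BASE = "http://localhost:5543"  # placeholder host/port

# server-level health, handled by _check_health above
requests.get(f"{BASE}/v2/health/ready")   # -> CheckReady(status="OK")
requests.get(f"{BASE}/v2/health/live")

# per-model metadata and readiness, from _add_status_and_metadata_endpoints
requests.get(f"{BASE}/v2/models/my-model")
requests.get(f"{BASE}/v2/models/my-model/ready")

# inference, from _add_inference_endpoints; payload is pipeline-specific
requests.post(f"{BASE}/v2/models/my-model/infer", json={"sequences": "hello"})

# endpoints can also be added or removed at runtime with an EndpointConfig
# body, via the POST/DELETE /endpoints handlers above
```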
16 changes: 15 additions & 1 deletion src/deepsparse/server/helpers.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from http import HTTPStatus
from typing import Any, Dict, List, Optional

import numpy
@@ -25,9 +26,22 @@
)
from deepsparse.loggers.config import MetricFunctionConfig, SystemLoggingGroup
from deepsparse.server.config import EndpointConfig, ServerConfig
from deepsparse.server.protocol import ErrorResponse
from fastapi.responses import JSONResponse


-__all__ = ["server_logger_from_config", "prep_outputs_for_serialization"]
+__all__ = [
+    "create_error_response",
+    "server_logger_from_config",
+    "prep_outputs_for_serialization",
+]


def create_error_response(status_code: HTTPStatus, message: str) -> JSONResponse:
return JSONResponse(
ErrorResponse(message=message, type="invalid_request_error").dict(),
status_code=status_code.value,
)


def server_logger_from_config(config: ServerConfig) -> BaseLogger:
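Illustrative sketch (not from the PR): how a route handler might use the new `create_error_response` helper; the handler and its guard clause are made up for illustration.

```python
from http import HTTPStatus

from deepsparse.server.helpers import create_error_response


def handle_request(payload: dict):
    # hypothetical validation inside a route handler
    if "model" not in payload:
        return create_error_response(
            HTTPStatus.BAD_REQUEST, "request must include a 'model' field"
        )
    ...
```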