Skip to content

Commit

Permalink
[server] Add model argument to server cli (#1584)
Browse files Browse the repository at this point in the history
* update model path to be an argument; remove unused openai command pathway

* add model path arg and option
  • Loading branch information
dsikka committed Feb 7, 2024
1 parent 06e2233 commit 6fe858a
Showing 1 changed file with 14 additions and 18 deletions.
32 changes: 14 additions & 18 deletions src/deepsparse/server/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
),
)

MODEL_ARG = click.argument("model", type=str, default=None, required=False)
MODEL_OPTION = click.option(
"--model_path",
type=str,
Expand Down Expand Up @@ -152,6 +153,7 @@
@PORT_OPTION
@LOG_LEVEL_OPTION
@HOT_RELOAD_OPTION
@MODEL_ARG
@MODEL_OPTION
@BATCH_OPTION
@CORES_OPTION
Expand All @@ -167,6 +169,7 @@ def main(
log_level: str,
hot_reload_config: bool,
model_path: str,
model: str,
batch_size: int,
num_cores: int,
num_workers: int,
Expand Down Expand Up @@ -216,6 +219,17 @@ def main(
...
```
"""
# the server cli can take a model argument or --model_path option
# if the --model_path option is provided, use that
# otherwise if the argument is given and --model_path is not used, use the
# argument instead
if model and model_path == "default":
model_path = model

if integration == INTEGRATION_OPENAI:
if task is None or task != "text_generation":
task = "text_generation"

if ctx.invoked_subcommand is not None:
return

Expand Down Expand Up @@ -254,24 +268,6 @@ def main(
server.start_server(host, port, log_level, hot_reload_config=hot_reload_config)


@main.command(
context_settings=dict(
token_normalize_func=lambda x: x.replace("-", "_"), show_default=True
),
)
@click.argument("config-file", type=str)
@HOST_OPTION
@PORT_OPTION
@LOG_LEVEL_OPTION
@HOT_RELOAD_OPTION
def openai(
config_file: str, host: str, port: int, log_level: str, hot_reload_config: bool
):

server = OpenAIServer(server_config=config_file)
server.start_server(host, port, log_level, hot_reload_config=hot_reload_config)


@main.command(
context_settings=dict(
token_normalize_func=lambda x: x.replace("-", "_"), show_default=True
Expand Down

0 comments on commit 6fe858a

Please sign in to comment.