[1.7 Cherry Picks] (#1572)
* [server] Disable the elastic scheduler when continuous batching is enabled (#1569)

* update server to disable the context/elastic scheduler when continuous batching is enabled

* clean up when context is created

* [TextGeneration] Fix initialization; don't try v1 init for text gen (#1571)

* only check capacity condition during prefill; a check already exists in generation

* don't try v1 if running text gen; just raise the error
dsikka committed Feb 1, 2024
1 parent 73d0471 commit 16f9409
Showing 5 changed files with 49 additions and 17 deletions.
src/deepsparse/pipeline.py (5 additions, 1 deletion)
@@ -27,6 +27,7 @@
     SchedulerGroup,
 )
 from deepsparse.subgraph_execute import SubGraphExecutor
+from deepsparse.tasks import SupportedTasks
 from deepsparse.utils import InferenceState, PipelineState
 from deepsparse.utils.subgraph import SubGraph
 from deepsparse.utils.time import TIMER_KEY, InferenceStages, TimerManager
@@ -139,7 +140,10 @@ def create(cls, task: str, **kwargs) -> "Pipeline":
                 "Pipeline was not created for the given task. The "
                 "provided task should be registered using the OperatorRegistry"
             )
-        except Exception:
+        except Exception as e:
+            if SupportedTasks.is_text_generation(task):
+                raise e
+
             _LOGGER.warning(f"Could not create v2 '{task}' pipeline, trying legacy")
             from deepsparse.legacy import Pipeline

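The pipeline.py change means a text generation task that fails v2 construction now surfaces its original error instead of being retried (and failing again, confusingly) on the legacy v1 path. Below is a minimal standalone sketch of that control flow; the helper functions are stand-ins for deepsparse internals, not real APIs.

    def is_text_generation(task: str) -> bool:
        # stand-in for SupportedTasks.is_text_generation
        return task == "text_generation"

    def build_v2_pipeline(task: str):
        # simulate a v2 construction failure
        raise ValueError(f"no v2 operator registered for {task!r}")

    def build_legacy_pipeline(task: str) -> str:
        return f"legacy pipeline for {task}"

    def create(task: str):
        try:
            return build_v2_pipeline(task)
        except Exception as e:
            # text generation only exists as a v2 pipeline, so falling back
            # would just produce a second, misleading error; re-raise the real one
            if is_text_generation(task):
                raise e
            return build_legacy_pipeline(task)

    print(create("image_classification"))  # falls back to the legacy pipeline
    # create("text_generation")            # re-raises the original ValueError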
src/deepsparse/server/deepsparse_server.py (23 additions, 4 deletions)
@@ -13,6 +13,7 @@
 # limitations under the License.

 import logging
+from concurrent.futures import ThreadPoolExecutor
 from functools import partial

 from deepsparse import Pipeline
@@ -73,12 +74,30 @@ def _add_endpoint(
         endpoint_config: EndpointConfig,
     ):
         pipeline_config = endpoint_config.to_pipeline_config()
-        pipeline_config.kwargs["executor"] = self.executor

         _LOGGER.info(f"Initializing pipeline for '{endpoint_config.name}'")
-        pipeline = Pipeline.from_config(
-            pipeline_config, context=self.context, logger=self.server_logger
-        )
+        if pipeline_config.kwargs.get("continuous_batch_sizes"):
+            pipeline_config.kwargs["executor"] = ThreadPoolExecutor(
+                max_workers=self.server_config.num_workers
+            )
+            _LOGGER.info(
+                "for continuous batching, the single stream scheduler will be enabled."
+            )
+            pipeline_config.num_cores = self.server_config.num_cores
+            pipeline_config.scheduler = "single"
+
+            pipeline = Pipeline.from_config(
+                pipeline_config,
+                num_streams=self.server_config.num_workers,
+                logger=self.server_logger,
+            )
+        else:
+            pipeline_config.kwargs["executor"] = ThreadPoolExecutor(
+                max_workers=self.context.num_streams
+            )
+            pipeline = Pipeline.from_config(
+                pipeline_config, context=self.context, logger=self.server_logger
+            )

         _LOGGER.info(f"Adding endpoints for '{endpoint_config.name}'")
         self._add_inference_endpoints(
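Condensed, the new _add_endpoint logic chooses between two execution setups based on the continuous_batch_sizes kwarg. A sketch under stand-in names (configure_endpoint and its return convention are illustrative, not the server's real API):

    from concurrent.futures import ThreadPoolExecutor

    def configure_endpoint(pipeline_config, server_config, context) -> dict:
        # returns the extra kwargs that Pipeline.from_config would receive
        if pipeline_config.kwargs.get("continuous_batch_sizes"):
            # continuous batching gets its own worker pool and the single-stream
            # scheduler, bypassing the server-wide elastic-scheduler Context
            pipeline_config.kwargs["executor"] = ThreadPoolExecutor(
                max_workers=server_config.num_workers
            )
            pipeline_config.num_cores = server_config.num_cores
            pipeline_config.scheduler = "single"
            return {"num_streams": server_config.num_workers}
        # default path: share the elastic-scheduler Context and size the
        # executor from its stream count
        pipeline_config.kwargs["executor"] = ThreadPoolExecutor(
            max_workers=context.num_streams
        )
        return {"context": context}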
src/deepsparse/server/openai_server.py (13 additions, 1 deletion)
@@ -376,7 +376,19 @@ def _add_model(
             f"{SupportedTasks.code_generation._fields}"
         )

-        pipeline = Pipeline.from_config(pipeline_config, context=self.context)
+        if pipeline_config.kwargs.get("continuous_batch_sizes"):
+            _LOGGER.info(
+                "for continuous batching, the single stream scheduler will be enabled."
+            )
+            pipeline_config.num_cores = self.server_config.num_cores
+            pipeline_config.scheduler = "single"
+
+            pipeline = Pipeline.from_config(
+                pipeline_config,
+                num_streams=self.server_config.num_workers,
+            )
+        else:
+            pipeline = Pipeline.from_config(pipeline_config, context=self.context)

         if not self.model_to_pipeline.get(endpoint_config.model):
             model_card = ModelCard(
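Both servers key off the same continuous_batch_sizes pipeline kwarg. As a rough illustration, an endpoint that takes the single-stream path might be configured like this, expressed as a Python dict; only continuous_batch_sizes, num_cores, and num_workers appear in the diffs above, so the surrounding field names and the model stub are assumptions rather than the documented server schema:

    server_config = {
        "num_cores": 8,    # copied onto pipeline_config.num_cores above
        "num_workers": 2,  # becomes num_streams and the ThreadPoolExecutor size
        "endpoints": [
            {
                "task": "text_generation",
                "model": "zoo:placeholder/text_generation/model",  # hypothetical stub
                "kwargs": {
                    # the presence of this key routes the endpoint through the
                    # continuous-batching / single-stream branch added here
                    "continuous_batch_sizes": [4],
                },
            }
        ],
    }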
src/deepsparse/server/server.py (4 additions, 10 deletions)
@@ -16,7 +16,6 @@
 import os
 from abc import abstractmethod
 from collections import Counter
-from concurrent.futures import ThreadPoolExecutor
 from copy import deepcopy
 from typing import AsyncGenerator, List, Optional, Union

@@ -76,10 +75,11 @@ def __init__(self, server_config: Union[str, ServerConfig]):
         self.server_config = server_config

         _LOGGER.info(f"Using config: {repr(self.server_config)}")
-
-        self.context = None
-        self.executor = None
         self.server_logger = server_logger_from_config(self.server_config)
+        self.context = Context(
+            num_cores=self.server_config.num_cores,
+            num_streams=self.server_config.num_workers,
+        )

     def start_server(
         self,
@@ -109,12 +109,6 @@ def start_server(
             self.config_path, f"http://{host}:{port}/endpoints", 0.5
         )

-        self.context = Context(
-            num_cores=self.server_config.num_cores,
-            num_streams=self.server_config.num_workers,
-        )
-        self.executor = ThreadPoolExecutor(max_workers=self.context.num_streams)
-
         app = self._build_app()

         uvicorn.run(
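Constructing the Context in __init__, instead of leaving self.context as None until start_server() ran, guarantees it exists on every path that builds a pipeline (the else branches above read self.context.num_streams), matching the "clean up when context is created" bullet. A lifecycle sketch with stand-in classes:

    from dataclasses import dataclass

    @dataclass
    class Context:  # stand-in for the real deepsparse Context
        num_cores: int
        num_streams: int

    class Server:  # stand-in for the real server class
        def __init__(self, num_cores: int, num_workers: int):
            # built eagerly, so endpoint setup can always rely on it
            self.context = Context(num_cores=num_cores, num_streams=num_workers)

        def start_server(self) -> None:
            # pipelines created while building the app can already use self.context
            print(f"serving with {self.context.num_streams} streams")

    Server(num_cores=8, num_workers=2).start_server()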
@@ -51,7 +51,10 @@ def can_operate(self, inp: Any) -> bool:
         if inp.get("in_generation"):
             return True

-        if kv_cache.total_num_processed_tokens >= kv_cache.capacity:
+        if (
+            kv_cache.total_num_processed_tokens >= kv_cache.capacity
+            and inp.get("in_generation") is None
+        ):
             raise RuntimeError(
                 "Not enough kv_cache capacity to run generation. Please use a larger "
                 "sequence_length or a shorter prompt"
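With the extra condition, the guard fires only during prefill: inp.get("in_generation") is None before generation starts, so once generation is underway this RuntimeError can no longer trigger (a separate check already exists in generation, per the commit message). A toy sketch of the resulting behavior, with plain ints standing in for the kv_cache object:

    def can_operate(inp: dict, processed_tokens: int, capacity: int) -> bool:
        # toy stand-in for the operator method above
        if inp.get("in_generation"):
            return True
        # raises only during prefill, while "in_generation" is still unset (None)
        if processed_tokens >= capacity and inp.get("in_generation") is None:
            raise RuntimeError(
                "Not enough kv_cache capacity to run generation. Please use a larger "
                "sequence_length or a shorter prompt"
            )
        return False  # remaining conditions omitted

    can_operate({"in_generation": False}, processed_tokens=512, capacity=512)  # no error
    # can_operate({}, processed_tokens=512, capacity=512)  # still raises at prefill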
