
update pipeline to use kwargs
dsikka committed Sep 21, 2023
1 parent 95a6377 commit 1dc2bd1
Showing 2 changed files with 14 additions and 9 deletions.
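The diff below moves the pipeline's streaming flag out of instance state and into the per-request context that is forwarded as keyword arguments. As a quick orientation, here is a minimal before/after sketch of that pattern; the class and field names are illustrative, not the actual deepsparse classes.

```python
# Illustrative before/after sketch of the pattern this commit adopts; the names
# below are hypothetical, not the deepsparse API.

# Before: the flag lives on the pipeline instance, so two concurrent requests
# with different settings can overwrite each other.
class BeforePipeline:
    def process_inputs(self, inputs):
        self.streaming = inputs["streaming"]  # shared mutable state
        return inputs["sequences"]

    def process_engine_outputs(self, engine_outputs, **kwargs):
        return iter(engine_outputs) if self.streaming else list(engine_outputs)


# After: the flag rides along with each request in a context dict that the base
# pipeline forwards as **kwargs.
class AfterPipeline:
    def process_inputs(self, inputs):
        context = dict(streaming=inputs["streaming"])  # travels with the request
        return inputs["sequences"], context

    def process_engine_outputs(self, engine_outputs, **kwargs):
        return iter(engine_outputs) if kwargs.get("streaming") else list(engine_outputs)


outputs = AfterPipeline().process_engine_outputs(["a", "b"], streaming=True)
print(list(outputs))  # ['a', 'b'] -- consumed lazily when streaming
```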
6 changes: 4 additions & 2 deletions src/deepsparse/pipeline.py
@@ -259,7 +259,9 @@ def __call__(self, *args, **kwargs) -> BaseModel:
)

# join together the batches of size `self._batch_size`
- engine_outputs = self.join_engine_outputs(batch_outputs, orig_batch_size)
+ engine_outputs = self.join_engine_outputs(
+     batch_outputs, orig_batch_size, **context
+ )
timer.stop(InferenceStages.ENGINE_FORWARD)

self.log(
@@ -470,7 +472,7 @@ def to_config(self) -> "PipelineConfig":
)

def join_engine_outputs(
-     self, batch_outputs: List[List[numpy.ndarray]], orig_batch_size: int
+     self, batch_outputs: List[List[numpy.ndarray]], orig_batch_size: int, **kwargs
) -> List[numpy.ndarray]:
"""
Joins list of engine outputs together into one list.
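Accepting a **kwargs catch-all on the base join_engine_outputs is what lets __call__ forward the whole context dict unconditionally. A rough sketch of why this stays backward compatible, using hypothetical classes rather than deepsparse code:

```python
# Rough sketch (hypothetical classes): the base __call__ can always forward
# **context, and subclasses that ignore the extra keys keep their behaviour.
from typing import List

import numpy


class BaseJoin:
    def join_engine_outputs(self, batch_outputs, orig_batch_size: int, **kwargs):
        # default: flatten the batches, ignoring any extra context keys
        return [out for batch in batch_outputs for out in batch]


class StreamingAwareJoin(BaseJoin):
    def join_engine_outputs(self, batch_outputs, orig_batch_size: int, **kwargs):
        if kwargs.get("streaming"):
            # hand outputs back one at a time instead of joining them
            return (out for batch in batch_outputs for out in batch)
        return super().join_engine_outputs(batch_outputs, orig_batch_size, **kwargs)


batches: List[List[numpy.ndarray]] = [[numpy.zeros(2)], [numpy.ones(2)]]
print(BaseJoin().join_engine_outputs(batches, orig_batch_size=2, streaming=True))
print(list(StreamingAwareJoin().join_engine_outputs(batches, orig_batch_size=2, streaming=True)))
```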
17 changes: 10 additions & 7 deletions src/deepsparse/transformers/pipelines/text_generation.py
@@ -273,7 +273,6 @@ def __init__(
self.tokenizer.pad_token = self.tokenizer.eos_token

self.engine, self.multitoken_engine = self.initialize_engines()
- self.streaming = False

def initialize_engines(
self,
@@ -419,7 +418,6 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:
:param inputs: the input schema for the pipeline
:return: the inputs for the engine
"""
- self.streaming = inputs.streaming
if not self.cache_support_enabled and inputs.max_tokens > 1:
raise ValueError(
"The model used for inference does not support kv cache. It is "
@@ -488,6 +486,7 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:

context = dict(
prompts=original_inputs,
+ streaming=inputs.streaming,
num_generated_predictions=inputs.num_generated_predictions,
return_logits=inputs.return_logits,
include_prompt_logits=inputs.include_prompt_logits,
@@ -547,8 +546,9 @@ def process_engine_outputs(
"""

prompts = kwargs.get("prompts")
+ streaming = kwargs.get("streaming")

- if self.streaming:
+ if streaming:
return self._stream_engine_outputs(engine_outputs, prompts, kwargs)

generated_tokens, generated_logits, finished_reason = list(*engine_outputs)
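When the streaming flag is set, process_engine_outputs hands off to _stream_engine_outputs instead of materialising the whole generation. That helper is not part of this diff, so the sketch below only assumes it wraps each incremental engine output in a response object as it arrives; its signature and output type here are guesses for illustration.

```python
# Sketch of a _stream_engine_outputs-style generator; names and types are
# assumptions, not the real deepsparse helper.
from dataclasses import dataclass
from typing import Any, Dict, Generator, Iterable, List, Optional


@dataclass
class PartialGeneration:
    prompt: str
    token_id: int
    finish_reason: Optional[str]


def stream_engine_outputs_sketch(
    engine_outputs: Iterable[tuple],
    prompts: List[str],
    kwargs: Dict[str, Any],
) -> Generator[PartialGeneration, None, None]:
    # each engine output is a (token_ids, logits, finished_reason) tuple, mirroring
    # what the streaming branch of engine_forward yields per step in later hunks
    for token_ids, _logits, finished_reason in engine_outputs:
        yield PartialGeneration(
            prompt=prompts[0],
            token_id=int(token_ids[0]),
            finish_reason=finished_reason[0],
        )


fake_outputs = [([7], None, [None]), ([8], None, ["length"])]
for partial in stream_engine_outputs_sketch(fake_outputs, ["hello"], {}):
    print(partial)
```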
@@ -611,6 +611,7 @@ def engine_forward(

with self.timer_manager.new_timer_context(total_inference=False) as timer:
finished_reason = []
+ streaming = context.get("streaming")

if not self.cache_support_enabled:
prompt_logits = self.multitoken_engine(engine_inputs)
@@ -688,17 +689,17 @@
if len(generated_tokens) == max_tokens:
finished_reason.append(FinishReason.LENGTH)

- if self.streaming:
+ if streaming:
yield (numpy.array([token]), numpy.array([logits]), [None])

- if self.streaming:
+ if streaming:
yield (
numpy.array([token]),
numpy.array([logits]),
[finished_reason[-1]],
)

- if not self.streaming:
+ if not streaming:
yield (
numpy.array([generated_tokens]),
numpy.concatenate(generated_logits, axis=1),
@@ -895,6 +896,7 @@ def join_engine_outputs(
self,
batch_outputs: List[List[Union[numpy.ndarray, FinishReason]]],
orig_batch_size: int,
+ **kwargs,
) -> List[Union[numpy.ndarray, FinishReason]]:
"""
Takes a list of outputs (batches) from the engine
@@ -906,7 +908,8 @@
:param orig_batch_size: The original batch size
:return: A list of joined outputs
"""
- if self.streaming:
+ streaming = kwargs.get("streaming")
+ if streaming:
for batch in batch_outputs:
for outputs in batch:
yield outputs
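From the caller's side, streaming becomes just another per-request option. A hedged usage sketch follows: the input fields mirror those referenced in this diff (streaming, max_tokens), but the model stub is a placeholder and the exact output types are not shown here.

```python
# Hedged usage sketch: one pipeline instance can serve streaming and
# non-streaming calls side by side, since the flag travels with each request.
from deepsparse import Pipeline

text_pipeline = Pipeline.create(
    task="text_generation",
    model_path="zoo:some/model/stub",  # placeholder stub
)

# non-streaming: one joined result after generation finishes
result = text_pipeline(sequences=["def fibonacci(n):"], max_tokens=32, streaming=False)
print(result)

# streaming: iterate over partial results as tokens are produced
for partial in text_pipeline(sequences=["def fibonacci(n):"], max_tokens=32, streaming=True):
    print(partial)
```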
