[TextGeneration] Add Streaming Functionality #1246

Merged 13 commits on Sep 21, 2023
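
For orientation, a minimal usage sketch of the feature this PR adds. The model stub and the sequences keyword are placeholders rather than anything taken from this diff; only the streaming flag and the generator-of-TextGenerationOutput behavior come from the changes below.

from deepsparse import Pipeline

# hypothetical model stub; substitute any text-generation model
pipe = Pipeline.create(task="text_generation", model_path="zoo:example/model-stub")

# streaming=False (the default): a single TextGenerationOutput after generation finishes
output = pipe(sequences="Write a haiku about sparsity", streaming=False)
print(output.generations[0].text)

# streaming=True: a generator of TextGenerationOutput objects, one per generated token
for partial in pipe(sequences="Write a haiku about sparsity", streaming=True):
    print(partial.generations[0].text, end="", flush=True)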
13 changes: 9 additions & 4 deletions src/deepsparse/pipeline.py
@@ -21,7 +21,7 @@
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Dict, Generator, List, Optional, Tuple, Type, Union

import numpy
from pydantic import BaseModel, Field
@@ -259,7 +259,9 @@ def __call__(self, *args, **kwargs) -> BaseModel:
)

# join together the batches of size `self._batch_size`
engine_outputs = self.join_engine_outputs(batch_outputs, orig_batch_size)
engine_outputs = self.join_engine_outputs(
batch_outputs, orig_batch_size, **context
)
timer.stop(InferenceStages.ENGINE_FORWARD)

self.log(
@@ -280,7 +282,10 @@ def __call__(self, *args, **kwargs) -> BaseModel:
# ------ POSTPROCESSING ------
timer.start(InferenceStages.POST_PROCESS)
pipeline_outputs = self.process_engine_outputs(engine_outputs, **context)
if not isinstance(pipeline_outputs, self.output_schema):
if not (
isinstance(pipeline_outputs, self.output_schema)
or isinstance(pipeline_outputs, Generator)
):
raise ValueError(
f"Outputs of {self.__class__} must be instances of "
f"{self.output_schema} found output of type "
@@ -467,7 +472,7 @@ def to_config(self) -> "PipelineConfig":
)

def join_engine_outputs(
self, batch_outputs: List[List[numpy.ndarray]], orig_batch_size: int
self, batch_outputs: List[List[numpy.ndarray]], orig_batch_size: int, **kwargs
) -> List[numpy.ndarray]:
"""
Joins list of engine outputs together into one list.
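
The pipeline.py changes above do two things: __call__ now forwards the postprocessing context into join_engine_outputs, and the output check accepts a Generator from process_engine_outputs instead of requiring an output_schema instance. Below is a minimal standalone sketch of that pattern; the helper names and toy data are assumptions, not deepsparse code.

from typing import Any, Generator, Iterable, List, Union


def join_engine_outputs(batch_outputs: List[list], **context) -> Iterable[list]:
    # when streaming, hand each batch through lazily instead of concatenating
    if context.get("streaming"):
        for batch in batch_outputs:
            yield from batch
    else:
        yield [item for batch in batch_outputs for item in batch]


def process_engine_outputs(engine_outputs, **context) -> Union[List[Any], Generator]:
    if context.get("streaming"):
        # return a generator; the caller must now tolerate Generator outputs
        return (output for output in engine_outputs)
    return list(engine_outputs)


outputs = process_engine_outputs(
    join_engine_outputs([[1, 2], [3]], streaming=True), streaming=True
)
assert isinstance(outputs, Generator)
print(list(outputs))  # [1, 2, 3]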
210 changes: 129 additions & 81 deletions src/deepsparse/transformers/pipelines/text_generation.py
@@ -33,7 +33,6 @@
import numpy
import onnx
from pydantic import BaseModel, Field
from transformers import TextStreamer

from deepsparse import Pipeline
from deepsparse.pipeline import DEEPSPARSE_ENGINE
@@ -61,6 +60,7 @@ class FinishReason(Enum):
STOP = "stop"
LENGTH = "length"
TIME = "time"
CALLBACK = "callback"


class TextGenerationInput(BaseModel):
@@ -106,12 +106,12 @@ class Config:
"to have consistent length so one "
"can compute metric in a batched fashion. ",
)
streamer: Optional[TextStreamer] = Field(
default=None,
description="Streamer object that will be used to stream the "
"generated sequences. Generated tokens are passed through "
"`streamer.put(token_ids)` and the streamer is responsible "
"for any further processing.",
streaming: bool = Field(
default=False,
description="Whether to stream the results back as they are generated. If "
"True, then the results are returned as a generator object which yields "
"the results as they are generated. If False, then the results are returned "
"as a list after it has completed.",
)
callback: Optional[Callable[[Any], Union[bool, Any]]] = Field(
default=None,
@@ -161,7 +161,7 @@ class GeneratedText(BaseModel):
"The scores have the shape [sequence_length, vocab_size]"
)
finished: bool = Field(description="Whether generation has stopped.")
finished_reason: str = Field(
finished_reason: Optional[str] = Field(
description="The reason for generation to stop. "
"Defined by FinishReason. One of stop, length, or time."
)
@@ -473,9 +473,9 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:

context = dict(
prompts=original_inputs,
streaming=inputs.streaming,
num_generated_predictions=inputs.num_generated_predictions,
return_logits=inputs.return_logits,
streamer=inputs.streamer,
include_prompt_logits=inputs.include_prompt_logits,
callback=inputs.callback,
stop=inputs.stop,
@@ -488,6 +488,40 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:

return engine_input, context

def _create_generated_text_output(
self,
sequence: str,
finish_reason: FinishReason = None,
logits: Optional[numpy.array] = None,
):
if finish_reason:
return GeneratedText(
text=sequence,
score=logits,
finished=True,
finished_reason=finish_reason.value,
)
return GeneratedText(
text=sequence,
score=logits,
finished=False,
)

def _stream_engine_outputs(self, engine_outputs, prompts, kwargs):
for output in engine_outputs:
generated_tokens, generated_logits, finished_reason = output
logits = generated_logits if kwargs.get("return_logits") else None
generation = self._create_generated_text_output(
self.tokenizer.batch_decode(generated_tokens)[0],
finished_reason[0],
logits,
)
yield TextGenerationOutput(
created=datetime.datetime.now(),
prompts=prompts,
generations=[generation],
)

def process_engine_outputs(
self, engine_outputs: List[Union[numpy.ndarray, FinishReason]], **kwargs
) -> TextGenerationOutput:
@@ -497,33 +531,29 @@ def process_engine_outputs(
:param engine_outputs: the outputs from the engine
:return: the output schema for the pipeline
"""
generated_tokens, generated_logits, finished_reason, *debug = engine_outputs
finished_reason = [f[0] for f in finished_reason]

prompts = kwargs.get("prompts")
streaming = kwargs.get("streaming")

if streaming:
return self._stream_engine_outputs(engine_outputs, prompts, kwargs)

generated_tokens, generated_logits, finished_reason, *debug = list(
*engine_outputs
)
sequences = self.tokenizer.batch_decode(
generated_tokens, skip_special_tokens=True
)
num_preds = kwargs.get("num_generated_predictions", 1)
prompts = kwargs.get("prompts")

def _create_generated_text_output(
sequence: str,
finish_reason: FinishReason,
logits: Optional[numpy.array] = None,
):
return GeneratedText(
text=sequence,
score=logits,
finished=True,
finished_reason=finish_reason.value,
)

logits = generated_logits if kwargs.get("return_logits") else None

num_preds = kwargs.get("num_generated_predictions", 1)
finished_reason = [f[0] for f in finished_reason]

if logits is not None:
generations = list(
self.executor.map(
_create_generated_text_output,
self._create_generated_text_output,
sequences,
finished_reason,
logits,
@@ -532,7 +562,7 @@ def _create_generated_text_output(
else:
generations = list(
self.executor.map(
_create_generated_text_output, sequences, finished_reason
self._create_generated_text_output, sequences, finished_reason
)
)

@@ -582,8 +612,8 @@ def engine_forward(
# names in this context

with self.timer_manager.new_timer_context(total_inference=False) as timer:
streamer = context.get("streamer")
finished_reason = []
streaming = context.get("streaming")

if not self.cache_support_enabled:
prompt_logits = self.multitoken_engine(engine_inputs)
@@ -610,9 +640,6 @@
)
token_generator.generate(prompt_logits[-1][0, -1, :])

if streamer is not None:
streamer.put(numpy.array(token_generator.tokens))

# create the generated output
max_tokens = context.get("max_tokens", 0)
max_tokens = max_tokens if max_tokens > 0 else (100 * self.sequence_length)
@@ -638,9 +665,6 @@
generated_tokens.append(token)
generated_logits.append(logits)

if streamer is not None:
streamer.put(numpy.array([token]))

if (
token == self.tokenizer.eos_token_id
and not self.force_max_tokens
@@ -656,30 +680,38 @@
finished_reason.append(FinishReason.STOP)
break

# TODO: Add any generic callback reason?
if callback is not None and callback(token) is False:
_LOGGER.debug(
"callback %s returned False, stopping generation."
% callback.__qualname__
)
finished_reason.append(FinishReason.CALLBACK)
break

if len(generated_tokens) == max_tokens:
finished_reason.append(FinishReason.LENGTH)

if streamer is not None:
streamer.end()
if streaming:
yield (numpy.array([token]), numpy.array([logits]), [None])

returns = (
numpy.array([generated_tokens]),
numpy.concatenate(generated_logits, axis=1),
finished_reason,
)
if streaming:
yield (
numpy.array([token]),
numpy.array([logits]),
[finished_reason[-1]],
)

if not streaming:
returns = (
numpy.array([generated_tokens]),
numpy.concatenate(generated_logits, axis=1),
finished_reason,
)

if self._debug is True:
return *returns, session
if self._debug is True:
yield *returns, session

return returns
yield returns

def prompt_inference(
self,
@@ -870,6 +902,7 @@ def join_engine_outputs(
self,
batch_outputs: List[List[Union[numpy.ndarray, FinishReason]]],
orig_batch_size: int,
**kwargs,
) -> List[Union[numpy.ndarray, FinishReason]]:
"""
Takes a list of outputs (batches) from the engine
@@ -881,48 +914,63 @@
:param orig_batch_size: The original batch size
:return: A list of joined outputs
"""
tokens, logits, finish_reason, *debug = zip(*batch_outputs)
if self.cache_support_enabled:
# if the model has kv cache, we need to account for
# the fact that the predicted outputs may have
# different lengths

# find the longest sequence in the batch of tokens
max_len = max([token.shape[1] for token in tokens])

# pad all tokens to the same length
tokens = [
pad_to_fixed_length(
array=prediction,
max_len=max_len,
value=self.tokenizer.pad_token_id,
axis=1,
)
for prediction in tokens
]
streaming = kwargs.get("streaming")
if streaming:
for batch in batch_outputs:
for outputs in batch:
yield outputs
else:
batch_outputs = [list(*b) for b in batch_outputs]
tokens, logits, finish_reason, *debug = zip(*batch_outputs)
if self.cache_support_enabled:
# if the model has kv cache, we need to account for
# the fact that the predicted outputs may have
# different lengths

# find the longest sequence in the batch of tokens
max_len = max([token.shape[1] for token in tokens])

# pad all tokens to the same length
tokens = [
pad_to_fixed_length(
array=prediction,
max_len=max_len,
value=self.tokenizer.pad_token_id,
axis=1,
)
for prediction in tokens
]

# find the longest sequence in the batch of logits
max_len = max([logits.shape[1] for logits in logits])
# find the longest sequence in the batch of logits
max_len = max([logits.shape[1] for logits in logits])

# pad all logits to the same length
logits = [
pad_to_fixed_length(array=single_logits, max_len=max_len, axis=1)
for single_logits in logits
]
# pad all logits to the same length
logits = [
pad_to_fixed_length(array=single_logits, max_len=max_len, axis=1)
for single_logits in logits
]

tokens = numpy.concatenate(tokens, axis=0)
logits = numpy.concatenate(logits, axis=0)
tokens = numpy.concatenate(tokens, axis=0)
logits = numpy.concatenate(logits, axis=0)

if debug:
sessions = debug[0]
kv_cache_state = numpy.stack(session.cached_inputs for session in sessions)
num_processed_tokens = numpy.stack(
session.total_num_processed_tokens for session in sessions
)
if debug:
sessions = debug[0]
kv_cache_state = numpy.stack(
session.cached_inputs for session in sessions
)
num_processed_tokens = numpy.stack(
session.total_num_processed_tokens for session in sessions
)

return [tokens, logits, finish_reason, kv_cache_state, num_processed_tokens]
yield [
tokens,
logits,
finish_reason,
kv_cache_state,
num_processed_tokens,
]

return [tokens, logits, finish_reason]
yield [tokens, logits, finish_reason]

@staticmethod
def causal_mask_input_present(model_path: str) -> bool:
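
Taken together, the streaming path defines a simple per-token contract: engine_forward yields (token_ids, logits, finish_reason) tuples as tokens are produced, join_engine_outputs passes them through untouched when streaming is set, and _stream_engine_outputs decodes each tuple into a partial generation. A simplified sketch of that contract follows; the stand-in functions and token values are invented for illustration and do not appear in the diff.

from typing import Generator, List, Optional, Tuple

import numpy


def fake_engine_forward(
    start_token: int, max_tokens: int
) -> Generator[Tuple[numpy.ndarray, numpy.ndarray, List[Optional[str]]], None, None]:
    # stand-in for the per-step tuples engine_forward yields when streaming
    for step in range(max_tokens):
        token = start_token + step
        logits = numpy.zeros((1, 8))  # placeholder logits for a single step
        reason = "length" if step == max_tokens - 1 else None
        yield numpy.array([token]), numpy.array([logits]), [reason]


def fake_stream_outputs(engine_outputs):
    # stand-in for _stream_engine_outputs: decode each step into a partial result
    for tokens, _logits, finish_reason in engine_outputs:
        yield {"text": f"<tok{int(tokens[0])}>", "finished": finish_reason[0] is not None}


for partial in fake_stream_outputs(fake_engine_forward(start_token=100, max_tokens=3)):
    print(partial)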