diff --git a/src/deepsparse/pipeline.py b/src/deepsparse/pipeline.py
index 1fb09cca0f..7da8de5ac2 100644
--- a/src/deepsparse/pipeline.py
+++ b/src/deepsparse/pipeline.py
@@ -21,7 +21,7 @@
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple, Type, Union
+from typing import Any, Dict, Generator, List, Optional, Tuple, Type, Union
 
 import numpy
 from pydantic import BaseModel, Field
@@ -259,7 +259,9 @@ def __call__(self, *args, **kwargs) -> BaseModel:
             )
 
             # join together the batches of size `self._batch_size`
-            engine_outputs = self.join_engine_outputs(batch_outputs, orig_batch_size)
+            engine_outputs = self.join_engine_outputs(
+                batch_outputs, orig_batch_size, **context
+            )
             timer.stop(InferenceStages.ENGINE_FORWARD)
 
             self.log(
@@ -280,7 +282,10 @@ def __call__(self, *args, **kwargs) -> BaseModel:
             # ------ POSTPROCESSING ------
             timer.start(InferenceStages.POST_PROCESS)
             pipeline_outputs = self.process_engine_outputs(engine_outputs, **context)
-            if not isinstance(pipeline_outputs, self.output_schema):
+            if not (
+                isinstance(pipeline_outputs, (self.output_schema, Generator))
+                or isinstance(pipeline_outputs, Generator)
+            ):
                 raise ValueError(
                     f"Outputs of {self.__class__} must be instances of "
                     f"{self.output_schema} found output of type "
@@ -467,7 +472,7 @@ def to_config(self) -> "PipelineConfig":
         )
 
     def join_engine_outputs(
-        self, batch_outputs: List[List[numpy.ndarray]], orig_batch_size: int
+        self, batch_outputs: List[List[numpy.ndarray]], orig_batch_size: int, **kwargs
     ) -> List[numpy.ndarray]:
         """
         Joins list of engine outputs together into one list.
diff --git a/src/deepsparse/transformers/pipelines/text_generation.py b/src/deepsparse/transformers/pipelines/text_generation.py
index 20e0d546ad..e27e5af9a5 100644
--- a/src/deepsparse/transformers/pipelines/text_generation.py
+++ b/src/deepsparse/transformers/pipelines/text_generation.py
@@ -33,7 +33,6 @@
 import numpy
 import onnx
 from pydantic import BaseModel, Field
-from transformers import TextStreamer
 
 from deepsparse import Pipeline
 from deepsparse.pipeline import DEEPSPARSE_ENGINE
@@ -61,6 +60,7 @@ class FinishReason(Enum):
     STOP = "stop"
     LENGTH = "length"
     TIME = "time"
+    CALLBACK = "callback"
 
 
 class TextGenerationInput(BaseModel):
@@ -106,12 +106,12 @@ class Config:
         "to have consistent length so one "
         "can compute metric in a batched fashion. ",
     )
-    streamer: Optional[TextStreamer] = Field(
-        default=None,
-        description="Streamer object that will be used to stream the "
-        "generated sequences. Generated tokens are passed through "
-        "`streamer.put(token_ids)` and the streamer is responsible "
-        "for any further processing.",
+    streaming: bool = Field(
+        default=False,
+        description="Whether to stream the results back as they are generated. If "
+        "True, then the results are returned as a generator object which yields "
+        "the results as they are generated. If False, then the results are returned "
+        "as a list after it has completed.",
     )
     callback: Optional[Callable[[Any], Union[bool, Any]]] = Field(
         default=None,
@@ -161,7 +161,7 @@ class GeneratedText(BaseModel):
         "The scores have the shape [sequence_length, vocab_size]"
     )
     finished: bool = Field(description="Whether generation has stopped.")
-    finished_reason: str = Field(
+    finished_reason: Optional[str] = Field(
         description="The reason for generation to stop. "
         "Defined by FinishReason. One of stop, length, or time."
     )
@@ -473,9 +473,9 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:
 
         context = dict(
             prompts=original_inputs,
+            streaming=inputs.streaming,
             num_generated_predictions=inputs.num_generated_predictions,
             return_logits=inputs.return_logits,
-            streamer=inputs.streamer,
             include_prompt_logits=inputs.include_prompt_logits,
             callback=inputs.callback,
             stop=inputs.stop,
@@ -488,6 +488,40 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:
 
         return engine_input, context
 
+    def _create_generated_text_output(
+        self,
+        sequence: str,
+        finish_reason: Optional[FinishReason] = None,
+        logits: Optional[numpy.array] = None,
+    ):
+        if finish_reason:
+            return GeneratedText(
+                text=sequence,
+                score=logits,
+                finished=True,
+                finished_reason=finish_reason.value,
+            )
+        return GeneratedText(
+            text=sequence,
+            score=logits,
+            finished=False,
+        )
+
+    def _stream_engine_outputs(self, engine_outputs, prompts, kwargs):
+        for output in engine_outputs:
+            generated_tokens, generated_logits, finished_reason = output
+            logits = generated_logits if kwargs.get("return_logits") else None
+            generation = self._create_generated_text_output(
+                self.tokenizer.batch_decode(generated_tokens)[0],
+                finished_reason[0],
+                logits,
+            )
+            yield TextGenerationOutput(
+                created=datetime.datetime.now(),
+                prompts=prompts,
+                generations=[generation],
+            )
+
     def process_engine_outputs(
         self, engine_outputs: List[Union[numpy.ndarray, FinishReason]], **kwargs
     ) -> TextGenerationOutput:
@@ -497,33 +531,29 @@ def process_engine_outputs(
         :param engine_outputs: the outputs from the engine
         :return: the output schema for the pipeline
         """
-        generated_tokens, generated_logits, finished_reason, *debug = engine_outputs
-        finished_reason = [f[0] for f in finished_reason]
+        prompts = kwargs.get("prompts")
+        streaming = kwargs.get("streaming")
+
+        if streaming:
+            return self._stream_engine_outputs(engine_outputs, prompts, kwargs)
+
+        generated_tokens, generated_logits, finished_reason, *debug = list(
+            *engine_outputs
+        )
 
         sequences = self.tokenizer.batch_decode(
             generated_tokens, skip_special_tokens=True
         )
-        num_preds = kwargs.get("num_generated_predictions", 1)
-        prompts = kwargs.get("prompts")
-
-        def _create_generated_text_output(
-            sequence: str,
-            finish_reason: FinishReason,
-            logits: Optional[numpy.array] = None,
-        ):
-            return GeneratedText(
-                text=sequence,
-                score=logits,
-                finished=True,
-                finished_reason=finish_reason.value,
-            )
 
         logits = generated_logits if kwargs.get("return_logits") else None
+        num_preds = kwargs.get("num_generated_predictions", 1)
+        finished_reason = [f[0] for f in finished_reason]
+
         if logits is not None:
             generations = list(
                 self.executor.map(
-                    _create_generated_text_output,
+                    self._create_generated_text_output,
                     sequences,
                     finished_reason,
                     logits,
@@ -532,7 +562,7 @@ def _create_generated_text_output(
         else:
             generations = list(
                 self.executor.map(
-                    _create_generated_text_output, sequences, finished_reason
+                    self._create_generated_text_output, sequences, finished_reason
                 )
             )
 
@@ -582,8 +612,8 @@ def engine_forward(
         # names in this context
         with self.timer_manager.new_timer_context(total_inference=False) as timer:
-            streamer = context.get("streamer")
             finished_reason = []
+            streaming = context.get("streaming")
 
             if not self.cache_support_enabled:
                 prompt_logits = self.multitoken_engine(engine_inputs)
@@ -610,9 +640,6 @@ def engine_forward(
             )
             token_generator.generate(prompt_logits[-1][0, -1, :])
 
-            if streamer is not None:
-                streamer.put(numpy.array(token_generator.tokens))
-
             # create the generated output
             max_tokens = context.get("max_tokens", 0)
             max_tokens = max_tokens if max_tokens > 0 else (100 * self.sequence_length)
@@ -638,9 +665,6 @@ def engine_forward(
                     generated_tokens.append(token)
                     generated_logits.append(logits)
 
-                    if streamer is not None:
-                        streamer.put(numpy.array([token]))
-
                     if (
                         token == self.tokenizer.eos_token_id
                         and not self.force_max_tokens
@@ -656,30 +680,38 @@ def engine_forward(
                         finished_reason.append(FinishReason.STOP)
                         break
 
-                    # TODO: Add any generic callback reason?
                     if callback is not None and callback(token) is False:
                         _LOGGER.debug(
                             "callback %s returned False, stopping generation."
                             % callback.__qualname__
                         )
+                        finished_reason.append(FinishReason.CALLBACK)
                         break
 
                     if len(generated_tokens) == max_tokens:
                         finished_reason.append(FinishReason.LENGTH)
 
-            if streamer is not None:
-                streamer.end()
+                    if streaming:
+                        yield (numpy.array([token]), numpy.array([logits]), [None])
 
-            returns = (
-                numpy.array([generated_tokens]),
-                numpy.concatenate(generated_logits, axis=1),
-                finished_reason,
-            )
+            if streaming:
+                yield (
+                    numpy.array([token]),
+                    numpy.array([logits]),
+                    [finished_reason[-1]],
+                )
+
+            if not streaming:
+                returns = (
+                    numpy.array([generated_tokens]),
+                    numpy.concatenate(generated_logits, axis=1),
+                    finished_reason,
+                )
 
-            if self._debug is True:
-                return *returns, session
+                if self._debug is True:
+                    yield *returns, session
 
-            return returns
+                yield returns
 
     def prompt_inference(
         self,
@@ -870,6 +902,7 @@ def join_engine_outputs(
         self,
         batch_outputs: List[List[Union[numpy.ndarray, FinishReason]]],
         orig_batch_size: int,
+        **kwargs,
     ) -> List[Union[numpy.ndarray, FinishReason]]:
         """
         Takes a list of outputs (batches) from the engine
@@ -881,48 +914,63 @@ def join_engine_outputs(
         :param orig_batch_size: The original batch size
         :return: A list of joined outputs
         """
-        tokens, logits, finish_reason, *debug = zip(*batch_outputs)
-        if self.cache_support_enabled:
-            # if the model has kv cache, we need to account for
-            # the fact that the predicted outputs may have
-            # different lengths
-
-            # find the longest sequence in the batch of tokens
-            max_len = max([token.shape[1] for token in tokens])
-
-            # pad all tokens to the same length
-            tokens = [
-                pad_to_fixed_length(
-                    array=prediction,
-                    max_len=max_len,
-                    value=self.tokenizer.pad_token_id,
-                    axis=1,
-                )
-                for prediction in tokens
-            ]
+        streaming = kwargs.get("streaming")
+        if streaming:
+            for batch in batch_outputs:
+                for outputs in batch:
+                    yield outputs
+        else:
+            batch_outputs = [list(*b) for b in batch_outputs]
+            tokens, logits, finish_reason, *debug = zip(*batch_outputs)
+            if self.cache_support_enabled:
+                # if the model has kv cache, we need to account for
+                # the fact that the predicted outputs may have
+                # different lengths
+
+                # find the longest sequence in the batch of tokens
+                max_len = max([token.shape[1] for token in tokens])
+
+                # pad all tokens to the same length
+                tokens = [
+                    pad_to_fixed_length(
+                        array=prediction,
+                        max_len=max_len,
+                        value=self.tokenizer.pad_token_id,
+                        axis=1,
+                    )
+                    for prediction in tokens
+                ]
 
-            # find the longest sequence in the batch of logits
-            max_len = max([logits.shape[1] for logits in logits])
+                # find the longest sequence in the batch of logits
+                max_len = max([logits.shape[1] for logits in logits])
 
-            # pad all logits to the same length
-            logits = [
-                pad_to_fixed_length(array=single_logits, max_len=max_len, axis=1)
-                for single_logits in logits
-            ]
+                # pad all logits to the same length
+                logits = [
+                    pad_to_fixed_length(array=single_logits, max_len=max_len, axis=1)
+                    for single_logits in logits
+                ]
 
-        tokens = numpy.concatenate(tokens, axis=0)
-        logits = numpy.concatenate(logits, axis=0)
+            tokens = numpy.concatenate(tokens, axis=0)
+            logits = numpy.concatenate(logits, axis=0)
 
-        if debug:
-            sessions = debug[0]
-            kv_cache_state = numpy.stack(session.cached_inputs for session in sessions)
-            num_processed_tokens = numpy.stack(
-                session.total_num_processed_tokens for session in sessions
-            )
+            if debug:
+                sessions = debug[0]
+                kv_cache_state = numpy.stack(
+                    session.cached_inputs for session in sessions
+                )
+                num_processed_tokens = numpy.stack(
+                    session.total_num_processed_tokens for session in sessions
+                )
 
-            return [tokens, logits, finish_reason, kv_cache_state, num_processed_tokens]
+                yield [
+                    tokens,
+                    logits,
+                    finish_reason,
+                    kv_cache_state,
+                    num_processed_tokens,
+                ]
 
-        return [tokens, logits, finish_reason]
+            yield [tokens, logits, finish_reason]
 
     @staticmethod
     def causal_mask_input_present(model_path: str) -> bool:
diff --git a/tests/deepsparse/transformers/pipelines/test_text_generation.py b/tests/deepsparse/transformers/pipelines/test_text_generation.py
index d8c4fde2a1..2f96d422b2 100644
--- a/tests/deepsparse/transformers/pipelines/test_text_generation.py
+++ b/tests/deepsparse/transformers/pipelines/test_text_generation.py
@@ -374,6 +374,7 @@ def test_run_same_prompt_multiple_times(self, setup):
             include_prompt_logits=True,
             max_tokens=self.num_tokens_generate,
         )
+
         assert output_1.generations[0].text == output_2.generations[0].text
         assert numpy.allclose(
             output_1.generations[0].score,
@@ -392,6 +393,7 @@ def test_run_multiple_prompts_in_parallel(self, setup):
             include_prompt_logits=True,
             max_tokens=self.num_tokens_generate,
         )
+
         logits_0 = output.generations[0].score
         sequence_0 = output.generations[0].text
 
@@ -409,12 +411,14 @@ def test_num_generated_predictions(self, setup):
         output_sequences = pipeline(
             sequences=[self.prompt], num_generated_predictions=2
         )
+
         assert len(output_sequences.generations) == 1
         assert len(output_sequences.generations[0]) == 2
 
         output_sequences = pipeline(
             sequences=[self.prompt, self.prompt], num_generated_predictions=2
         )
+
         assert len(output_sequences.generations) == 2
 
         for generation in output_sequences.generations:
@@ -427,7 +431,7 @@ def _test_output(
         max_logits_difference_threshold: Optional[float] = None,
         run_cache_validation: bool = True,
     ):
-
+        # extract numpy arrays from cached_inputs
         (
             generated_logits,
             prompt_logits,
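Reviewer note (not part of the patch): a minimal usage sketch of the new `streaming` flag added to `TextGenerationInput` above. The `model_path` stub and the prompt are illustrative placeholders, not values taken from this repository; attribute access follows the `TextGenerationOutput` / `GeneratedText` schema shown in the diff. With `streaming=False` (the default) the pipeline returns a single output object once generation completes; with `streaming=True` it returns a generator that yields one `TextGenerationOutput` per generated token.

from deepsparse import Pipeline

# model_path below is a placeholder; substitute a real SparseZoo stub or local model
text_pipeline = Pipeline.create(
    task="text_generation",
    model_path="zoo:some/text_generation_model",  # hypothetical stub
)

# streaming=False (default): one TextGenerationOutput after generation completes
output = text_pipeline(sequences="Tell me a story about", streaming=False)
print(output.generations[0].text)

# streaming=True: a generator that yields a TextGenerationOutput per generated token
for partial in text_pipeline(sequences="Tell me a story about", streaming=True):
    print(partial.generations[0].text, end="", flush=True)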