diff --git a/src/deepsparse/transformers/pipelines/text_generation.py b/src/deepsparse/transformers/pipelines/text_generation.py
index 8aefbb86e3..a6b07318ea 100644
--- a/src/deepsparse/transformers/pipelines/text_generation.py
+++ b/src/deepsparse/transformers/pipelines/text_generation.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import copy
 import datetime
 import logging
 import os
@@ -578,10 +579,24 @@ def _stream_engine_outputs(
         self, engine_outputs, prompts, generation_config, **kwargs
     ):
         for output in engine_outputs:
-            generated_tokens, generated_logits, finished_reason = output
+            (
+                generated_tokens,
+                generated_logits,
+                finished_reason,
+                past_tokens_queue,
+            ) = output
             logits = generated_logits if generation_config.output_scores else None
+            from transformers import LlamaTokenizer, LlamaTokenizerFast
+
+            if isinstance(self.tokenizer, (LlamaTokenizer, LlamaTokenizerFast)):
+                # temporary fix for Llama2/Mistral/... models
+                generated_string = self._generate_streamed_text_from_past_tokens(
+                    generated_tokens, past_tokens_queue
+                )
+            else:
+                generated_string = self.tokenizer.batch_decode(generated_tokens)[0]
             generation = self._create_generated_text_output(
-                self.tokenizer.batch_decode(generated_tokens)[0],
+                generated_string,
                 finished_reason[0],
                 logits,
             )
@@ -599,6 +614,33 @@ def _stream_engine_outputs(
                 **schema_kwargs,
             )
 
+    def _generate_streamed_text_from_past_tokens(
+        self, generated_tokens: numpy.ndarray, past_tokens_queue: List[int]
+    ) -> str:
+        """
+        An auxiliary method that helps to properly generate the streamed text.
+        Some models, like Llama2 and Mistral, use LlamaTokenizer, which is
+        based on the SentencePiece tokenizer. This tokenizer does not emit
+        the appropriate prefix spaces when decoding token by token.
+        Decoding works correctly if the previously generated tokens are
+        included, as this allows the tokenizer to infer the appropriate
+        spaces from the last n consecutive tokens.
+
+        :param generated_tokens: the generated tokens from the engine
+        :param past_tokens_queue: the queue of last n tokens (n is the
+            original prompt length in tokens)
+        :return: the generated string
+        """
+        string_from_n_tokens = self.tokenizer.decode(
+            past_tokens_queue, skip_special_tokens=True
+        )
+        past_tokens_queue.append(generated_tokens[0])
+        string_from_n_plus_1_tokens = self.tokenizer.decode(
+            past_tokens_queue, skip_special_tokens=True
+        )
+        past_tokens_queue.pop(0)
+        return string_from_n_plus_1_tokens[len(string_from_n_tokens) :]
+
     def process_engine_outputs(
         self, engine_outputs: List[Union[numpy.ndarray, FinishReason]], **kwargs
     ) -> TextGenerationOutput:
@@ -733,6 +775,9 @@ def engine_forward(
         prompt_logits, session = self.prompt_inference(engine_inputs)
 
         tokens = engine_inputs[0][engine_inputs[1].nonzero()].tolist()
+        # copy the tokens so that we can use them for streaming
+        past_tokens_queue = copy.copy(tokens)
+
         token_generator = TokenGenerator(
             logits_shape=prompt_logits[-1].shape[-1],
             tokens=tokens,
@@ -771,6 +816,7 @@ def engine_forward(
                 numpy.array([generated_tokens[-1]]),
                 numpy.array([generated_logits[-1]]),
                 [None],
+                past_tokens_queue,
             )
 
         while len(generated_tokens) < max_tokens:
@@ -811,7 +857,12 @@ def engine_forward(
                     break
 
             if streaming:
-                yield (numpy.array([token]), numpy.array([logits]), [None])
+                yield (
+                    numpy.array([token]),
+                    numpy.array([logits]),
+                    [None],
+                    past_tokens_queue,
+                )
 
             # Run the autoregressive inference only to put the
             # kv cache entry for the last generated token into the
@@ -826,12 +877,14 @@ def engine_forward(
                     numpy.array([generated_tokens]),
                     numpy.concatenate(generated_logits, axis=1),
                     [FinishReason.LENGTH],
+                    past_tokens_queue,
                 )
             else:
                 yield (
                     numpy.array([token]),
                     numpy.array([logits]),
                     [finished_reason[-1]],
+                    past_tokens_queue,
                )
 
         if not streaming:
diff --git a/tests/deepsparse/transformers/pipelines/test_text_generation.py b/tests/deepsparse/transformers/pipelines/test_text_generation.py
index fb25a33883..ad9526d54a 100644
--- a/tests/deepsparse/transformers/pipelines/test_text_generation.py
+++ b/tests/deepsparse/transformers/pipelines/test_text_generation.py
@@ -130,3 +130,29 @@ def test_streaming_mode_returns_generator(pipeline, prompt):
         isinstance(response, pipeline.output_schema)
         for response in response_generator
     ), "Pipeline should return a generator of output_schema \
        objects in streaming mode"
+
+
+def test_streaming_with_several_prompts(pipeline, prompt):
+    additional_prompt = "Never gonna run around and desert you"
+    prompts = [prompt, additional_prompt]
+
+    generations_first_prompt_only = list(pipeline(prompt=prompts[0], streaming=True))
+    generations_second_prompt_only = list(pipeline(prompt=prompts[1], streaming=True))
+
+    bag_of_words_first_prompt = [
+        g.generations[0].text for g in generations_first_prompt_only
+    ]
+    bag_of_words_second_prompt = [
+        g.generations[0].text for g in generations_second_prompt_only
+    ]
+
+    generations = pipeline(prompt=prompts, streaming=True)
+    bag_of_words_shared = []
+    for r in generations:
+        for gen in r.generations:
+            text = gen.text
+            bag_of_words_shared.append(text)
+
+    assert sorted(bag_of_words_first_prompt + bag_of_words_second_prompt) == sorted(
+        bag_of_words_shared
+    )
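For reference, the decode-with-context trick that `_generate_streamed_text_from_past_tokens` applies can be reproduced outside the pipeline. The following is a minimal sketch, assuming any SentencePiece-based (Llama-style) tokenizer; the `huggyllama/llama-7b` checkpoint name and the `stream_decode` helper are only illustrative, not part of this change:

```python
from transformers import AutoTokenizer

# Example checkpoint only; any SentencePiece-based (Llama-style) tokenizer
# drops the leading space when a token is decoded in isolation.
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")

prompt = "Never gonna give you up"
past_tokens_queue = tokenizer(prompt).input_ids  # queue of the last n tokens

def stream_decode(new_token_id: int) -> str:
    """Decode one new token by diffing the text decoded with and without it,
    so the tokenizer can recover the prefix space from the preceding tokens."""
    text_before = tokenizer.decode(past_tokens_queue, skip_special_tokens=True)
    past_tokens_queue.append(new_token_id)
    text_after = tokenizer.decode(past_tokens_queue, skip_special_tokens=True)
    past_tokens_queue.pop(0)  # keep the queue at its original length n
    return text_after[len(text_before):]
```

Decoding the new token on its own with `tokenizer.decode([new_token_id])` would lose the leading space, which is why the pipeline carries the rolling `past_tokens_queue` through every streamed yield.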