From b7f333f925373e0acc388524009c71cfcb339033 Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Thu, 21 Sep 2023 17:37:53 -0400
Subject: [PATCH] update to use config

---
 .../pipelines/test_text_generation.py         | 80 ++++++++++---------
 1 file changed, 43 insertions(+), 37 deletions(-)

diff --git a/tests/deepsparse/transformers/pipelines/test_text_generation.py b/tests/deepsparse/transformers/pipelines/test_text_generation.py
index ad1e29ead3..ad64dae3f3 100644
--- a/tests/deepsparse/transformers/pipelines/test_text_generation.py
+++ b/tests/deepsparse/transformers/pipelines/test_text_generation.py
@@ -15,6 +15,7 @@
 from typing import List, Optional, Tuple
 
 import numpy
+from transformers import GenerationConfig
 import pytest
 
 from deepsparse import Pipeline
@@ -175,11 +176,13 @@ def test_ort_single_token_prefill(self, setup):
             engine_type="onnxruntime",
         )
         pipeline._debug = True
+
+        config = GenerationConfig(
+            output_scores=True, max_length=self.num_tokens_generate
+        )
+
         output = pipeline(
-            sequences=self.prompt,
-            return_logits=True,
-            include_prompt_logits=True,
-            max_tokens=self.num_tokens_generate,
+            sequences=self.prompt, include_prompt_logits=True, generation_config=config
         )
         assert output.total_num_processed_tokens[0] < self.sequence_length
         self._test_output(
@@ -207,11 +210,11 @@ def test_ort_multi_token_prefill(self, setup):
             engine_type="onnxruntime",
         )
         pipeline._debug = True
+        config = GenerationConfig(
+            output_scores=True, max_length=self.num_tokens_generate
+        )
         output = pipeline(
-            sequences=self.prompt,
-            return_logits=True,
-            include_prompt_logits=True,
-            max_tokens=self.num_tokens_generate,
+            sequences=self.prompt, include_prompt_logits=True, generation_config=config
         )
 
         assert output.total_num_processed_tokens[0] < self.sequence_length
@@ -241,11 +244,12 @@ def test_ort_generation_after_kv_cache_has_been_filled(self, setup):
             engine_type="onnxruntime",
         )
         pipeline._debug = True
+
+        config = GenerationConfig(
+            output_scores=True, max_length=self.num_tokens_generate
+        )
         output = pipeline(
-            sequences=self.prompt,
-            return_logits=True,
-            include_prompt_logits=True,
-            max_tokens=self.num_tokens_generate,
+            sequences=self.prompt, include_prompt_logits=True, generation_config=config
         )
 
         assert output.total_num_processed_tokens[0] > self.sequence_length_short, (
@@ -276,11 +280,11 @@ def test_deepsparse_single_token_prefill(self, setup):
             internal_kv_cache=self.internal_kv_cache,
         )
         pipeline._debug = True
+        config = GenerationConfig(
+            output_scores=True, max_length=self.num_tokens_generate
+        )
         output = pipeline(
-            sequences=self.prompt,
-            return_logits=True,
-            include_prompt_logits=True,
-            max_tokens=self.num_tokens_generate,
+            sequences=self.prompt, include_prompt_logits=True, generation_config=config
        )
 
         assert output.total_num_processed_tokens[0] < self.sequence_length
@@ -307,11 +311,11 @@ def test_deepsparse_multi_token_prefill(self, setup):
         )
         pipeline._debug = True
 
+        config = GenerationConfig(
+            output_scores=True, max_length=self.num_tokens_generate
+        )
         output = pipeline(
-            sequences=self.prompt,
-            return_logits=True,
-            include_prompt_logits=True,
-            max_tokens=self.num_tokens_generate,
+            sequences=self.prompt, include_prompt_logits=True, generation_config=config
         )
 
         assert output.total_num_processed_tokens[0] < self.sequence_length
@@ -337,11 +341,11 @@ def test_deepsparse_generation_after_kv_cache_has_been_filled(self, setup):
             internal_kv_cache=self.internal_kv_cache,
         )
         pipeline._debug = True
+        config = GenerationConfig(
+            output_scores=True, max_length=self.num_tokens_generate
+        )
         output = pipeline(
-            sequences=self.prompt,
-            return_logits=True,
-            include_prompt_logits=True,
-            max_tokens=self.num_tokens_generate,
+            sequences=self.prompt, include_prompt_logits=True, generation_config=config
         )
 
         assert output.total_num_processed_tokens[0] > self.sequence_length_short, (
@@ -361,18 +365,16 @@ def test_run_same_prompt_multiple_times(self, setup):
         # Test the scenario, where the same prompt is run multiple times
         # Every run should produce the same output
         pipeline = self.get_pipeline()
 
+        config = GenerationConfig(
+            output_scores=True, max_length=self.num_tokens_generate
+        )
         output_1 = pipeline(
-            sequences=self.prompt,
-            return_logits=True,
-            include_prompt_logits=True,
-            max_tokens=self.num_tokens_generate,
+            sequences=self.prompt, include_prompt_logits=True, generation_config=config
         )
+
         output_2 = pipeline(
-            sequences=self.prompt,
-            return_logits=True,
-            include_prompt_logits=True,
-            max_tokens=self.num_tokens_generate,
+            sequences=self.prompt, include_prompt_logits=True, generation_config=config
         )
 
         assert output_1.generations[0].text == output_2.generations[0].text
@@ -387,11 +389,13 @@ def test_run_multiple_prompts_in_parallel(self, setup):
         # Same two prompts should produce the same output
         pipeline = self.get_pipeline()
 
+        config = GenerationConfig(
+            output_scores=True, max_length=self.num_tokens_generate
+        )
         output = pipeline(
             sequences=[self.prompt, self.prompt],
-            return_logits=True,
+            generation_config=config,
             include_prompt_logits=True,
-            max_tokens=self.num_tokens_generate,
         )
 
         logits_0 = output.generations[0].score
@@ -408,14 +412,16 @@ def test_num_generated_predictions(self, setup):
         # from the same prompt
         pipeline = self.get_pipeline()
 
-        output_sequences = pipeline(
-            sequences=[self.prompt], num_generated_predictions=2
+        config = GenerationConfig(
+            num_return_sequences=2, max_length=self.num_tokens_generate
         )
+
+        output_sequences = pipeline(sequences=[self.prompt], generation_config=config)
         assert len(output_sequences.generations) == 1
         assert len(output_sequences.generations[0]) == 2
 
         output_sequences = pipeline(
-            sequences=[self.prompt, self.prompt], num_generated_predictions=2
+            sequences=[self.prompt, self.prompt], generation_config=config
         )
         assert len(output_sequences.generations) == 2
 
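
Note (not part of the patch): a minimal sketch of the calling convention these tests migrate to, for reference. The deployment path and prompt below are placeholders, and constructing the pipeline with Pipeline.create is an assumption about the surrounding deepsparse API; the tests themselves build the pipeline through their get_pipeline helper.

    from deepsparse import Pipeline
    from transformers import GenerationConfig

    # Hypothetical deployment path; any text-generation ONNX deployment would do.
    pipeline = Pipeline.create(task="text-generation", model_path="/path/to/deployment")

    # Generation options that used to be per-call kwargs (return_logits, max_tokens)
    # now travel in a single transformers.GenerationConfig.
    config = GenerationConfig(output_scores=True, max_length=16)

    output = pipeline(
        sequences="def fibonacci(n):",
        include_prompt_logits=True,
        generation_config=config,
    )
    print(output.generations[0].text)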