
Commit f8f5448: update to use config
dsikka committed Sep 21, 2023
1 parent b95b107 · commit f8f5448
Showing 1 changed file with 43 additions and 37 deletions.
tests/deepsparse/transformers/pipelines/test_text_generation.py (43 additions, 37 deletions)
@@ -15,6 +15,7 @@
 from typing import List, Optional, Tuple
 
 import numpy
+from transformers import GenerationConfig
 
 import pytest
 from deepsparse import Pipeline
@@ -175,11 +176,13 @@ def test_ort_single_token_prefill(self, setup):
             engine_type="onnxruntime",
         )
         pipeline._debug = True
+
+        config = GenerationConfig(
+            output_scores=True, max_length=self.num_tokens_generate
+        )
+
         output = pipeline(
-            sequences=self.prompt,
-            return_logits=True,
-            include_prompt_logits=True,
-            max_tokens=self.num_tokens_generate,
+            sequences=self.prompt, include_prompt_logits=True, generation_config=config
         )
         assert output.total_num_processed_tokens[0] < self.sequence_length
         self._test_output(
@@ -207,11 +210,11 @@ def test_ort_multi_token_prefill(self, setup):
             engine_type="onnxruntime",
         )
         pipeline._debug = True
+        config = GenerationConfig(
+            output_scores=True, max_length=self.num_tokens_generate
+        )
         output = pipeline(
-            sequences=self.prompt,
-            return_logits=True,
-            include_prompt_logits=True,
-            max_tokens=self.num_tokens_generate,
+            sequences=self.prompt, include_prompt_logits=True, generation_config=config
         )
 
         assert output.total_num_processed_tokens[0] < self.sequence_length
@@ -241,11 +244,12 @@ def test_ort_generation_after_kv_cache_has_been_filled(self, setup):
             engine_type="onnxruntime",
         )
         pipeline._debug = True
+
+        config = GenerationConfig(
+            output_scores=True, max_length=self.num_tokens_generate
+        )
         output = pipeline(
-            sequences=self.prompt,
-            return_logits=True,
-            include_prompt_logits=True,
-            max_tokens=self.num_tokens_generate,
+            sequences=self.prompt, include_prompt_logits=True, generation_config=config
         )
 
         assert output.total_num_processed_tokens[0] > self.sequence_length_short, (
@@ -276,11 +280,11 @@ def test_deepsparse_single_token_prefill(self, setup):
             internal_kv_cache=self.internal_kv_cache,
         )
         pipeline._debug = True
+        config = GenerationConfig(
+            output_scores=True, max_length=self.num_tokens_generate
+        )
         output = pipeline(
-            sequences=self.prompt,
-            return_logits=True,
-            include_prompt_logits=True,
-            max_tokens=self.num_tokens_generate,
+            sequences=self.prompt, include_prompt_logits=True, generation_config=config
         )
 
         assert output.total_num_processed_tokens[0] < self.sequence_length
@@ -307,11 +311,11 @@ def test_deepsparse_multi_token_prefill(self, setup):
         )
         pipeline._debug = True
 
+        config = GenerationConfig(
+            output_scores=True, max_length=self.num_tokens_generate
+        )
         output = pipeline(
-            sequences=self.prompt,
-            return_logits=True,
-            include_prompt_logits=True,
-            max_tokens=self.num_tokens_generate,
+            sequences=self.prompt, include_prompt_logits=True, generation_config=config
         )
 
         assert output.total_num_processed_tokens[0] < self.sequence_length
@@ -337,11 +341,11 @@ def test_deepsparse_generation_after_kv_cache_has_been_filled(self, setup):
             internal_kv_cache=self.internal_kv_cache,
         )
         pipeline._debug = True
+        config = GenerationConfig(
+            output_scores=True, max_length=self.num_tokens_generate
+        )
         output = pipeline(
-            sequences=self.prompt,
-            return_logits=True,
-            include_prompt_logits=True,
-            max_tokens=self.num_tokens_generate,
+            sequences=self.prompt, include_prompt_logits=True, generation_config=config
        )
 
         assert output.total_num_processed_tokens[0] > self.sequence_length_short, (
@@ -361,18 +365,16 @@ def test_run_same_prompt_multiple_times(self, setup):
         # Test the scenario, where the same prompt is run multiple times
         # Every run should produce the same output
         pipeline = self.get_pipeline()
+        config = GenerationConfig(
+            output_scores=True, max_length=self.num_tokens_generate
+        )
+
         output_1 = pipeline(
-            sequences=self.prompt,
-            return_logits=True,
-            include_prompt_logits=True,
-            max_tokens=self.num_tokens_generate,
+            sequences=self.prompt, include_prompt_logits=True, generation_config=config
         )
 
         output_2 = pipeline(
-            sequences=self.prompt,
-            return_logits=True,
-            include_prompt_logits=True,
-            max_tokens=self.num_tokens_generate,
+            sequences=self.prompt, include_prompt_logits=True, generation_config=config
         )
 
         assert output_1.generations[0].text == output_2.generations[0].text
@@ -387,11 +389,13 @@ def test_run_multiple_prompts_in_parallel(self, setup):
         # Same two prompts should produce the same output
         pipeline = self.get_pipeline()
 
+        config = GenerationConfig(
+            output_scores=True, max_length=self.num_tokens_generate
+        )
         output = pipeline(
             sequences=[self.prompt, self.prompt],
-            return_logits=True,
+            generation_config=config,
             include_prompt_logits=True,
-            max_tokens=self.num_tokens_generate,
         )
 
         logits_0 = output.generations[0].score
@@ -408,14 +412,16 @@ def test_num_generated_predictions(self, setup):
         # from the same prompt
         pipeline = self.get_pipeline()
 
-        output_sequences = pipeline(
-            sequences=[self.prompt], num_generated_predictions=2
+        config = GenerationConfig(
+            num_return_sequences=2, max_length=self.num_tokens_generate
         )
+
+        output_sequences = pipeline(sequences=[self.prompt], generation_config=config)
         assert len(output_sequences.generations) == 1
         assert len(output_sequences.generations[0]) == 2
 
         output_sequences = pipeline(
-            sequences=[self.prompt, self.prompt], num_generated_predictions=2
+            sequences=[self.prompt, self.prompt], generation_config=config
         )
         assert len(output_sequences.generations) == 2
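
The change is the same in every hunk above: per-call keyword arguments (return_logits, max_tokens, num_generated_predictions) are replaced by a transformers GenerationConfig passed through the pipeline's generation_config argument. Below is a minimal sketch of that calling convention; the Pipeline.create() arguments, model stub, prompt, and literal values are illustrative assumptions and are not part of this commit, which only touches the test file.

    # Sketch of the generation_config-based call style these tests adopt.
    # The model path and pipeline kwargs are placeholders for illustration only.
    from deepsparse import Pipeline
    from transformers import GenerationConfig

    pipeline = Pipeline.create(
        task="text_generation",
        model_path="zoo:placeholder/text-generation-model",  # hypothetical stub
    )

    # Generation settings now live on a GenerationConfig instead of per-call kwargs.
    config = GenerationConfig(
        output_scores=True,      # was return_logits=True
        max_length=16,           # was max_tokens=...
        num_return_sequences=2,  # was num_generated_predictions=...
    )

    output = pipeline(
        sequences=["The capital of France is"],
        include_prompt_logits=True,
        generation_config=config,
    )
    print(output.generations[0].text)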