Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dsikka committed Sep 19, 2023
1 parent d0d80fd commit 627b849
Showing 1 changed file with 8 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,9 @@ def test_run_same_prompt_multiple_times(self, setup):
include_prompt_logits=True,
max_tokens=self.num_tokens_generate,
)
output_1 = next(output_1)
output_2 = next(output_2)

assert output_1.generations[0].text == output_2.generations[0].text
assert numpy.allclose(
output_1.generations[0].score,
Expand All @@ -397,6 +400,8 @@ def test_run_multiple_prompts_in_parallel(self, setup):
include_prompt_logits=True,
max_tokens=self.num_tokens_generate,
)
output = next(output)

logits_0 = output.generations[0].score
sequence_0 = output.generations[0].text

Expand All @@ -414,12 +419,14 @@ def test_num_generated_predictions(self, setup):
output_sequences = pipeline(
sequences=[self.prompt], num_generated_predictions=2
)
output_sequences = next(output_sequences)
assert len(output_sequences.generations) == 1
assert len(output_sequences.generations[0]) == 2

output_sequences = pipeline(
sequences=[self.prompt, self.prompt], num_generated_predictions=2
)
output_sequences = next(output_sequences)
assert len(output_sequences.generations) == 2

for generation in output_sequences.generations:
Expand All @@ -435,6 +442,7 @@ def _test_output(
):
# extract numpy arrays from cached_inputs
kv_cache_array = list(cache_session.cached_inputs.values())
output = next(output)

(
generated_logits,
Expand Down

0 comments on commit 627b849

Please sign in to comment.