[TextGeneration] Add GeneratedText and update TextGenerationOutput (#1240)

* add GeneratedText schema to text_generation pipeline

* update finish reason

* fix finishreason enum

* add TODO statements

* rebase and update FinishReason

* update text generation tests to comply with new output schema
dsikka committed Sep 19, 2023
1 parent b24f7cf commit 404e8a4
Showing 2 changed files with 124 additions and 42 deletions.
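
Taken together, the commits above replace the flat sequences/logits output with a structured schema built from GeneratedText entries. A minimal sketch of how a caller might read the new fields, assuming a pipeline built via deepsparse.Pipeline.create (the construction arguments and model stub are placeholders, not part of this diff):

from deepsparse import Pipeline

# Hypothetical setup; the task name follows the pipeline in this diff,
# the model stub is a placeholder.
pipeline = Pipeline.create(
    task="text_generation",
    model_path="zoo:some/text-generation-model",
)

output = pipeline(sequences="Mary had a little lamb", max_tokens=16)

generation = output.generations[0]   # a GeneratedText for the single prompt
print(output.prompts)                # the prompt(s) echoed back
print(generation.text)               # the generated sequence
print(generation.finished_reason)    # "stop", "length", or "time" (see FinishReason)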
128 changes: 100 additions & 28 deletions src/deepsparse/transformers/pipelines/text_generation.py
@@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import logging
import os
import warnings
from enum import Enum
from typing import (
Any,
Callable,
@@ -52,6 +54,12 @@
__all__ = ["TextGenerationPipeline"]


class FinishReason(Enum):
STOP = "stop"
LENGTH = "length"
TIME = "time"


class TextGenerationInput(BaseModel):
class Config:
arbitrary_types_allowed = True
@@ -146,15 +154,34 @@ class Config:
)


class GeneratedText(BaseModel):
text: str = Field(
description="The generated sequence for a given prompt. If "
"streaming is enabled, this will be the next generated token."
)
score: Optional[Any] = Field(
description="The score for the generated token or sequence. "
"The scores have the shape [sequence_length, vocab_size]"
)
finished: bool = Field(description="Whether generation has stopped.")
finished_reason: str = Field(
description="The reason for generation to stop. "
"Defined by FinishReason. One of stop, length, or time."
)


# TODO: Pydantic aliases allow assignment but not reference. Still need to update.
class TextGenerationOutput(BaseModel):
sequences: Union[str, List[str], List[List[str]]] = Field(
description="The generated text sequences.",
created: datetime.datetime = Field(description="Time of inference creation.")
prompts: Union[str, List[str]] = Field(
description="Prompts used for the sequence generation. For multiple input "
"prompts, a list of prompts is returned"
)
logits: Optional[Any] = Field( # numpy array, set to Any for FastAPI compatibility
default=None,
description="The logits for the generated text sequence."
"The logits have dimensions "
"[batch_size, sequence_length, vocab_size]",
generations: Union[List[GeneratedText], List[List[GeneratedText]]] = Field(
description="For a single prompt, a single list of GeneratedText is returned. "
"If multiple prompts are given, a list of GeneratedText is returned for each "
"prompt provided. If streamng is enabled, the next generated token is returned."
"Otherwise, the full generated sequence is returned."
)
session_id: Optional[str] = Field(
default=None, description="A string identifier for the kv cache session."
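
A minimal sketch of how the two schema classes above compose, using hypothetical values and only the fields visible in this hunk (any fields of TextGenerationOutput not shown here are assumed optional):

import datetime

generation = GeneratedText(
    text="a little lamb",
    score=None,                               # logits are omitted unless requested
    finished=True,
    finished_reason=FinishReason.STOP.value,  # "stop"
)

output = TextGenerationOutput(
    created=datetime.datetime.now(),
    prompts="Mary had",
    generations=[generation],                 # one list of GeneratedText for a single prompt
)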
@@ -401,6 +428,7 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:
# If the num_generated_predictions > 1, repeat the prompt
# num_generated_predictions times. Also, update the engine so that deterministic
# is set to False.
original_inputs = inputs.sequences
if inputs.num_generated_predictions > 1:
if isinstance(inputs.sequences, str):
inputs.sequences = [inputs.sequences]
@@ -457,6 +485,7 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:
self.multitoken_engine.session_id = inputs.session_id

context = dict(
prompts=original_inputs,
num_generated_predictions=inputs.num_generated_predictions,
return_logits=inputs.return_logits,
streamer=inputs.streamer,
@@ -473,39 +502,71 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:
return engine_input, context

def process_engine_outputs(
self, engine_outputs: List[numpy.ndarray], **kwargs
self, engine_outputs: List[Union[numpy.ndarray, FinishReason]], **kwargs
) -> TextGenerationOutput:
"""
Convert the engine outputs to the output schema for the pipeline.
:param engine_outputs: the outputs from the engine
:return: the output schema for the pipeline
"""
generated_tokens, generated_logits = engine_outputs
generated_tokens, generated_logits, finished_reason = engine_outputs
finished_reason = [f[0] for f in finished_reason]

sequences = self.tokenizer.batch_decode(
generated_tokens, skip_special_tokens=True
)
num_preds = kwargs.get("num_generated_predictions", 1)
# If the num_generated_predictions > 1, group the generated sequences and return
# the sequences as a list of lists where each list consists of the generated
prompts = kwargs.get("prompts")

def _create_generated_text_output(
sequence: str,
finish_reason: FinishReason,
logits: Optional[numpy.array] = None,
):
return GeneratedText(
text=sequence,
score=logits,
finished=True,
finished_reason=finish_reason.value,
)

logits = generated_logits if kwargs.get("return_logits") else None

if logits is not None:
generations = list(
self.executor.map(
_create_generated_text_output,
sequences,
finished_reason,
logits,
)
)
else:
generations = list(
self.executor.map(
_create_generated_text_output, sequences, finished_reason
)
)

# If the num_generated_predictions > 1, group the generations and return
# them as a list of lists where each list consists of the generated
# predictions for a given prompt, and all the lists are in the order matching
# the order that the prompts were given as inputs.
if num_preds > 1:
grouped_seq = [
sequences[n : n + num_preds]
for n in range(0, len(sequences), num_preds)
grouped_generations = [
generations[n : n + num_preds]
for n in range(0, len(generations), num_preds)
]
sequences = grouped_seq
generations = grouped_generations

logits = generated_logits if kwargs.get("return_logits") else None

return TextGenerationOutput(sequences=sequences, logits=logits)
return TextGenerationOutput(
created=datetime.datetime.now(), prompts=prompts, generations=generations
)
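
The num_preds grouping above reshapes the flat list of generations into one sub-list per prompt, in prompt order. A standalone sketch of that reshaping; the helper name is illustrative and not part of the pipeline:

def group_by_prompt(generations: list, num_preds: int) -> list:
    # Flat list ordered prompt-by-prompt -> one sub-list of num_preds items per prompt.
    if num_preds <= 1:
        return generations
    return [
        generations[n : n + num_preds]
        for n in range(0, len(generations), num_preds)
    ]

# Two prompts with two predictions each:
assert group_by_prompt(["a1", "a2", "b1", "b2"], 2) == [["a1", "a2"], ["b1", "b2"]]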

def engine_forward(
self,
engine_inputs: List[numpy.ndarray],
context: Dict,
) -> Tuple[numpy.ndarray, numpy.ndarray]:
self, engine_inputs: List[numpy.ndarray], context: Dict
) -> Tuple[numpy.ndarray, numpy.ndarray, List[FinishReason]]:
"""
Run the forward pass on the engine.
@@ -522,6 +583,7 @@ def engine_forward(

with self.timer_manager.new_timer_context(total_inference=False) as timer:
streamer = context.get("streamer")
finished_reason = []

if not self.cache_support_enabled:
prompt_logits = self.multitoken_engine(engine_inputs)
@@ -583,27 +645,35 @@ def engine_forward(
token == self.tokenizer.eos_token_id
and not self.force_max_tokens
):
finished_reason.append(FinishReason.STOP)
break

if self._stop_token_generated(token, stop_tokens=stop):
_LOGGER.debug(
"Stop token %s generated. Stopping generation."
% self.tokenizer.decode(token)
)
finished_reason.append(FinishReason.STOP)
break

# TODO: Add any generic callback reason?
if callback is not None and callback(token) is False:
_LOGGER.debug(
"callback %s returned False, stopping generation."
% callback.__qualname__
)
break

if len(generated_tokens) == max_tokens:
finished_reason.append(FinishReason.LENGTH)

if streamer is not None:
streamer.end()

return numpy.array([generated_tokens]), numpy.concatenate(
generated_logits, axis=1
return (
numpy.array([generated_tokens]),
numpy.concatenate(generated_logits, axis=1),
finished_reason,
)
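
The loop above records FinishReason.STOP for an EOS or stop token and FinishReason.LENGTH once the token budget is exhausted; a callback can still end generation without a recorded reason, per the TODO. A condensed sketch of the same decision logic, with an illustrative helper and signature that are not part of the pipeline:

from typing import List, Optional

def classify_finish(
    token: int,
    eos_token_id: int,
    stop_token_ids: List[int],
    num_generated: int,
    max_tokens: int,
    force_max_tokens: bool = False,
) -> Optional[FinishReason]:
    # Stop conditions are checked before the length limit, as in the loop above.
    if token == eos_token_id and not force_max_tokens:
        return FinishReason.STOP
    if token in stop_token_ids:
        return FinishReason.STOP
    if num_generated >= max_tokens:
        return FinishReason.LENGTH
    return None  # keep generating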

def prompt_inference(
@@ -793,8 +863,10 @@ def is_cache_support_enabled(self) -> bool:
return any(default_cached_outputs(self.onnx_file_path))

def join_engine_outputs(
self, batch_outputs: List[List[numpy.ndarray]], orig_batch_size: int
) -> List[numpy.ndarray]:
self,
batch_outputs: List[List[Union[numpy.ndarray, FinishReason]]],
orig_batch_size: int,
) -> List[Union[numpy.ndarray, FinishReason]]:
"""
Takes a list of outputs (batches) from the engine
and joins them into a single output. Asserts that
@@ -805,7 +877,7 @@ def join_engine_outputs(
:param orig_batch_size: The original batch size
:return: A list of joined outputs
"""
tokens, logits = zip(*batch_outputs)
tokens, logits, finish_reason = zip(*batch_outputs)
if self.cache_support_enabled:
# if the model has kv cache, we need to account for
# the fact that the predicted outputs may have
Expand Down Expand Up @@ -837,7 +909,7 @@ def join_engine_outputs(
tokens = numpy.concatenate(tokens, axis=0)
logits = numpy.concatenate(logits, axis=0)

return [tokens, logits]
return [tokens, logits, finish_reason]
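
join_engine_outputs now carries a third, per-batch finish-reason element through the same zip(*...) unpacking used for tokens and logits. A small sketch with illustrative data:

import numpy

batch_outputs = [
    (numpy.array([[1, 2]]), numpy.zeros((1, 2, 8)), [FinishReason.STOP]),
    (numpy.array([[3, 4]]), numpy.zeros((1, 2, 8)), [FinishReason.LENGTH]),
]
tokens, logits, finish_reason = zip(*batch_outputs)
# finish_reason == ([FinishReason.STOP], [FinishReason.LENGTH]);
# process_engine_outputs later flattens it with [f[0] for f in finished_reason].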

@staticmethod
def causal_mask_input_present(model_path: str) -> bool:
38 changes: 24 additions & 14 deletions tests/deepsparse/transformers/pipelines/test_text_generation.py
@@ -379,8 +379,12 @@ def test_run_same_prompt_multiple_times(self, setup):
include_prompt_logits=True,
max_tokens=self.num_tokens_generate,
)
assert output_1.sequences[0] == output_2.sequences[0]
assert numpy.allclose(output_1.logits, output_2.logits, atol=_PRECISION)
assert output_1.generations[0].text == output_2.generations[0].text
assert numpy.allclose(
output_1.generations[0].score,
output_2.generations[0].score,
atol=_PRECISION,
)

def test_run_multiple_prompts_in_parallel(self, setup):
# Test the scenario, where multiple prompts are run in parallel
@@ -393,9 +397,14 @@ def test_run_multiple_prompts_in_parallel(self, setup):
include_prompt_logits=True,
max_tokens=self.num_tokens_generate,
)
logits_0 = output.generations[0].score
sequence_0 = output.generations[0].text

logits_1 = output.generations[1].score
sequence_1 = output.generations[1].text

assert numpy.allclose(output.logits[0], output.logits[1], atol=_PRECISION)
assert output.sequences[0] == output.sequences[1]
assert numpy.allclose(logits_0, logits_1, atol=_PRECISION)
assert sequence_0 == sequence_1

def test_num_generated_predictions(self, setup):
# Test the scenario, where multiple predictions are generated
@@ -405,14 +414,16 @@ def test_num_generated_predictions(self, setup):
output_sequences = pipeline(
sequences=[self.prompt], num_generated_predictions=2
)
assert len(output_sequences.sequences[0]) == 2
assert len(output_sequences.generations) == 1
assert len(output_sequences.generations[0]) == 2

output_sequences = pipeline(
sequences=[self.prompt, self.prompt], num_generated_predictions=2
)
assert len(output_sequences.sequences) == 2
for sequences in output_sequences.sequences:
assert len(sequences) == 2
assert len(output_sequences.generations) == 2

for generation in output_sequences.generations:
assert len(generation) == 2

def _test_output(
self,
Expand All @@ -434,6 +445,7 @@ def _test_output(

# concatenate target prompt_logits and generated_logits and check
target_logits = numpy.concatenate([prompt_logits, generated_logits], axis=1)
score = output.generations[0].score

if max_logits_difference_threshold:
# if comparing the output from the model where
Expand All @@ -442,18 +454,16 @@ def _test_output(
# to be less than the threshold
# (the threshold is established by running the
# ONNX model in ONNXRuntime)
assert (
abs(output.logits - target_logits).max()
< max_logits_difference_threshold
)
assert abs(score - target_logits[0]).max() < max_logits_difference_threshold
else:
# otherwise, we expect the logits to be exactly the same
# as the target logits; the generated sequence should
# also be the same as the target sequence, and finally
# (if applicable) the kv cache should be the same as the
# target kv cache
assert numpy.allclose(output.logits, target_logits, atol=_PRECISION)
assert self.prompt + output.sequences[0] == generated_text

assert numpy.allclose(score, target_logits[0], atol=_PRECISION)
assert self.prompt + output.generations[0].text == generated_text

if run_cache_validation:
self._test_kv_cache_state(
