[TextGeneration] Update pipeline inputs to support GenerationConfig (#1250)

* add streaming functionality

* set back default value

* update pipeline.py

* update tests

* fix tests

* update pipeline to use kwargs

* add TODO statements

* add streaming functionality

* Update pipeline inputs to support GenerationConfig

* add max_new_tokens

* remove todo

* update post local test runs

* remove todo missed from rebase

* refactor to use helpers, update reference to generation config variables

* update helper functions to include all generation config handling and overriding

* fix tests

* update to work with new session commit

* update to use config

* cleanup
dsikka committed Sep 22, 2023
1 parent 73913e7 commit b309fa4
Showing 3 changed files with 283 additions and 111 deletions.
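For context, a minimal usage sketch of the inputs this commit introduces (the model stub, task name, and parameter values are illustrative assumptions, not taken from the diff):

from deepsparse import Pipeline

# Create a text-generation pipeline (the model path is a placeholder stub).
pipeline = Pipeline.create(
    task="text_generation",
    model_path="zoo:example/text-generation-model",
)

# generation_config accepts a dict, a file path, or a transformers
# GenerationConfig, per the new TextGenerationInput schema below.
output = pipeline(
    sequences="Tell me about sparsity",
    generation_config={"max_new_tokens": 64, "top_k": 40},
)
print(output.generations[0].text)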
192 changes: 120 additions & 72 deletions src/deepsparse/transformers/pipelines/text_generation.py
@@ -15,6 +15,7 @@
import datetime
import logging
import os
import pathlib
import warnings
from enum import Enum
from typing import (
@@ -33,17 +34,21 @@
import numpy
import onnx
from pydantic import BaseModel, Field
from transformers import GenerationConfig

from deepsparse import Pipeline
from deepsparse.pipeline import DEEPSPARSE_ENGINE
from deepsparse.transformers.engines import NLDecoderEngine
from deepsparse.transformers.pipelines import TransformersPipeline
from deepsparse.transformers.utils import DecoderKVCache
from deepsparse.transformers.utils.helpers import (
check_and_return_generation_config,
create_causal_mask,
initialize_kv_cache_state,
override_config,
pad_to_fixed_length,
prepends_bos_token,
process_generation_config,
repeat_inputs,
)
from deepsparse.transformers.utils.timings import TextGenerationTimings
@@ -56,6 +61,16 @@
__all__ = ["TextGenerationPipeline"]


class GenerationDefaults:
num_return_sequences = 1
max_length = 1024
max_new_tokens = None
output_scores = False
top_k = 0
top_p = 0.0
repetition_penalty = 0.0
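# Editorial note: these attributes are the fallback values handed to
# check_and_return_generation_config() in process_inputs(). Based on the
# helper names and the pipeline's log message, the assumed precedence is:
# inputs.generation_config > pipeline-level generation_config > GenerationDefaults.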


class FinishReason(Enum):
STOP = "stop"
LENGTH = "length"
@@ -70,33 +85,14 @@ class Config:
sequences: Union[str, List[str]] = Field(
description="The input sequences to generate the text from.",
)
include_prompt_logits: bool = Field(
default=False,
description="A flag that indicates whether to return "
"the logits for the prompt. If set, prompt_logits are "
"`prepended` to the logits for the generated text sequence."
"Note: This flag is only applicable when return_logits "
"Note: This flag is only applicable when output_scores "
"is `True`.",
)

fixed_sequences_length: bool = Field(
default=False,
description="A flag that indicates whether to modify "
@@ -126,28 +122,27 @@ class Config:
" tokens is generated). Set to `None` to ignore this parameter."
" Default is `None`.",
)

presence_penalty: Optional[float] = Field(
default=0.0,
description="Penalty applied for generating new token. Any existing"
" token results in the subtraction of its corresponding logit value."
" Default set to 0.0",
)

generation_config: Union[None, str, pathlib.Path, Dict, GenerationConfig] = Field(
default=None,
description="GenerationConfig file consisting of parameters used to control "
"sequences generated for each prompt. The current supported parameters are: "
"max_length, max_new_tokens, num_return_sequences, output_scores, top_p, "
"top_k, repetition_penalty.",
)

kwargs: Optional[Dict] = Field(
default=None,
description="Any arguments to override generation_config arguments. Refer to "
"the generation_config argument for a full list of supported variables. Only "
"valid when generation_config is not None.",
)
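# Sketch of the intended interplay of the two fields above (hypothetical
# values; only the field names come from this schema):
#   pipeline(
#       sequences="...",
#       generation_config={"max_new_tokens": 32, "top_k": 40},
#       kwargs={"top_p": 0.9},  # overrides the config's top_p for this call
#   )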


@@ -219,9 +214,10 @@ def __init__(
deterministic: bool = True,
sampling_temperature: float = 1.0,
prompt_sequence_length: int = 64,
sequence_length: int = 1024,
force_max_tokens: bool = False,
internal_kv_cache: bool = True,
generation_config: Union[str, pathlib.Path, Dict, GenerationConfig] = None,
**kwargs,
):
kwargs_engine_type = kwargs.get("engine_type", DEEPSPARSE_ENGINE)
@@ -271,6 +267,12 @@

# auxiliary flag for devs to enable debug mode for the pipeline
self._debug = False
self.generation_config = process_generation_config(generation_config)
if self.generation_config:
_LOGGER.info(
"Generation config provided for pipline. This will be used "
"for all inputs unless an input-specific config is provided. "
)

def initialize_engines(
self,
@@ -410,22 +412,29 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:
:param inputs: the input schema for the pipeline
:return: the inputs for the engine
"""
generation_config = check_and_return_generation_config(
self.generation_config, inputs.generation_config, GenerationDefaults()
)

generation_config = override_config(inputs.kwargs, generation_config)

self.streaming = inputs.streaming
if not self.cache_support_enabled and generation_config.max_length > 1:
raise ValueError(
"The model used for inference does not support kv cache. It is "
"assumed that it maps from the token sequence to predicted logits."
"Set `max_tokens` to 1 to support that scenario."
"Set `max_length` to 1 to support that scenario."
)

# If the num_return_sequences > 1, repeat the prompt
# num_return_sequences times. Also, update the engine so that deterministic
# is set to False.
original_inputs = inputs.sequences
if generation_config.num_return_sequences > 1:
if isinstance(inputs.sequences, str):
inputs.sequences = [inputs.sequences]
inputs.sequences = repeat_inputs(
inputs.sequences, generation_config.num_return_sequences
)
if self.engine:
self.engine.deterministic = False
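# Editorial sketch of the assumed repeat semantics (repeat_inputs is defined
# in deepsparse.transformers.utils.helpers, not shown in this diff):
#   ["a", "b"] with num_return_sequences=2 -> ["a", "a", "b", "b"]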
@@ -474,16 +483,14 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:
context = dict(
prompts=original_inputs,
streaming=inputs.streaming,
generation_config=generation_config,
include_prompt_logits=inputs.include_prompt_logits,
callback=inputs.callback,
stop=inputs.stop,
top_p=generation_config.top_p,
top_k=generation_config.top_k,
presence_penalty=inputs.presence_penalty,
frequency_penalty=generation_config.repetition_penalty,
)

return engine_input, context
@@ -532,22 +539,48 @@ def process_engine_outputs(
:return: the output schema for the pipeline
"""

def _create_generated_text_output(
sequence: str,
finish_reason: FinishReason = None,
logits: Optional[numpy.array] = None,
):
if finish_reason:
return GeneratedText(
text=sequence,
score=logits,
finished=True,
finished_reason=finish_reason.value,
)
return GeneratedText(
text=sequence,
score=logits,
finished=False,
)

generation_config = kwargs.get("generation_config")
prompts = kwargs.get("prompts")
streaming = kwargs.get("streaming")

if streaming:
return self._stream_engine_outputs(engine_outputs, prompts, kwargs)

if self._debug:
(
generated_tokens,
generated_logits,
finished_reason,
kv_cache_state,
total_num_processed_tokens,
) = list(*engine_outputs)
else:
generated_tokens, generated_logits, finished_reason = list(*engine_outputs)
sequences = self.tokenizer.batch_decode(
generated_tokens, skip_special_tokens=True
)

logits = generated_logits if generation_config.output_scores else None

num_preds = generation_config.num_return_sequences
finished_reason = [f[0] for f in finished_reason]

if logits is not None:
@@ -566,7 +599,7 @@ def process_engine_outputs(
)
)

# If the num_return_sequences > 1, group the generations and return
# them as a list of lists where each list consists of the generated
# predictions for a given prompt, and all the lists are in the order matching
# the order that the prompts were given as inputs.
@@ -581,8 +614,7 @@ def process_engine_outputs(
created=datetime.datetime.now(), prompts=prompts, generations=generations
)

if self._debug:
debug_params = dict(
kv_cache_state=kv_cache_state,
total_num_processed_tokens=total_num_processed_tokens,
Expand Down Expand Up @@ -614,6 +646,7 @@ def engine_forward(
with self.timer_manager.new_timer_context(total_inference=False) as timer:
finished_reason = []
streaming = context.get("streaming")
generation_config = context.get("generation_config")

if not self.cache_support_enabled:
prompt_logits = self.multitoken_engine(engine_inputs)
@@ -640,10 +673,6 @@ def engine_forward(
)
token_generator.generate(prompt_logits[-1][0, -1, :])

# last prompt token is the first generated token
# add it to generated tokens, and the logits
generated_tokens = [token_generator.tokens[-1]]
@@ -655,6 +684,15 @@ def engine_forward(
callback = context.get("callback")
stop = context.get("stop")

max_new_tokens = generation_config.max_new_tokens
if max_new_tokens:
max_tokens = max_new_tokens + len(generated_tokens)
else:
max_tokens = generation_config.max_length
max_tokens = (
max_tokens if max_tokens > 0 else (100 * self.sequence_length)
)
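# Worked illustration of the cap above (numbers are arbitrary): with
# max_new_tokens=32 and the single seed token already in generated_tokens,
# generation stops at 32 + 1 = 33 tokens; if max_new_tokens is unset, the
# cap falls back to generation_config.max_length (1024 by default).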

with timer.time(TextGenerationTimings.TOKEN_GENERATION):
while len(generated_tokens) < max_tokens:
with timer.time(TextGenerationTimings.TOKEN_GENERATION_SINGLE):
@@ -702,14 +740,19 @@
)

if not streaming:
if self._debug:
returns = (
numpy.array([generated_tokens]),
numpy.concatenate(generated_logits, axis=1),
finished_reason,
[session],
)
else:
returns = (
numpy.array([generated_tokens]),
numpy.concatenate(generated_logits, axis=1),
finished_reason,
)

yield returns

@@ -921,7 +964,12 @@ def join_engine_outputs(
yield outputs
else:
batch_outputs = [list(*b) for b in batch_outputs]
if self._debug:
tokens, logits, finish_reason, debug = zip(*batch_outputs)
else:
tokens, logits, finish_reason = zip(*batch_outputs)
debug = None

if self.cache_support_enabled:
# if the model has kv cache, we need to account for
# the fact that the predicted outputs may have
@@ -969,8 +1017,8 @@ def join_engine_outputs(
kv_cache_state,
num_processed_tokens,
]

else:
yield [tokens, logits, finish_reason]

@staticmethod
def causal_mask_input_present(model_path: str) -> bool: