[TextGeneration] Update pipeline inputs to support GenerationConfig #1250

Merged: 19 commits, Sep 22, 2023
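A quick usage sketch of what this change enables. It is illustrative only: the model stub and prompt are placeholders (not from this PR), and only the generation_config handling follows the diff below.

# Hypothetical usage example; model stub and prompt are placeholders.
from deepsparse import Pipeline

# Pipeline-level config: applied to every request unless the request supplies its own.
text_pipeline = Pipeline.create(
    task="text_generation",
    model_path="zoo:placeholder/text-generation-model",
    generation_config={"max_new_tokens": 32, "top_k": 40, "output_scores": True},
)

output = text_pipeline(sequences="Sparsity is")
print(output.generations[0].text)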
Changes from all commits
192 changes: 120 additions & 72 deletions src/deepsparse/transformers/pipelines/text_generation.py
@@ -15,6 +15,7 @@
import datetime
import logging
import os
import pathlib
import warnings
from enum import Enum
from typing import (
@@ -33,17 +34,21 @@
import numpy
import onnx
from pydantic import BaseModel, Field
from transformers import GenerationConfig

from deepsparse import Pipeline
from deepsparse.pipeline import DEEPSPARSE_ENGINE
from deepsparse.transformers.engines import NLDecoderEngine
from deepsparse.transformers.pipelines import TransformersPipeline
from deepsparse.transformers.utils import DecoderKVCache
from deepsparse.transformers.utils.helpers import (
check_and_return_generation_config,
create_causal_mask,
initialize_kv_cache_state,
override_config,
pad_to_fixed_length,
prepends_bos_token,
process_generation_config,
repeat_inputs,
)
from deepsparse.transformers.utils.timings import TextGenerationTimings
@@ -56,6 +61,16 @@
__all__ = ["TextGenerationPipeline"]


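# Review note: these appear to be the fallback generation parameters, applied when
# neither the pipeline nor the input supplies a GenerationConfig
# (see check_and_return_generation_config in process_inputs below).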
class GenerationDefaults:
num_return_sequences = 1
max_length = 1024
max_new_tokens = None
output_scores = False
top_k = 0
top_p = 0.0
repetition_penalty = 0.0


class FinishReason(Enum):
STOP = "stop"
LENGTH = "length"
@@ -70,33 +85,14 @@ class Config:
sequences: Union[str, List[str]] = Field(
description="The input sequences to generate the text from.",
)
num_generated_predictions: int = Field(
default=1,
description="The number of text generations to create from a single prompt. If "
"the same sequence is given as an input multiple times, the number of "
"generated predictions is equivalent to the number of times the "
"sequence is repeated.",
)
max_tokens: int = Field(
default=1024,
description="Maximum number of tokens to generate per output sequence. If no "
"value is provided, will default to 1024.",
)
return_logits: bool = Field(
default=False,
description="A flag that indicates whether to return "
"the logits for the input text sequence and the "
"generated text sequence. ",
)
include_prompt_logits: bool = Field(
default=False,
description="A flag that indicates whether to return "
"the logits for the prompt. If set, prompt_logits are "
"`prepended` to the logits for the generated text sequence."
"Note: This flag is only applicable when return_logits "
"Note: This flag is only applicable when output_scores "
"is `True`.",
)

fixed_sequences_length: bool = Field(
default=False,
description="A flag that indicates whether to modify "
@@ -126,28 +122,27 @@ class Config:
" tokens is generated). Set to `None` to ignore this parameter."
" Default is `None`.",
)
top_p: Optional[float] = Field(
default=0.0,
description="Used for filtering generated tokens. Keep the"
" tokens whose cumulative probability is >= top_p."
" Default set to 0.0",
)
top_k: Optional[int] = Field(
default=0,
description="Used for filtering generated tokens. Keep"
" top_k generated tokens. Default set to 0",
)

presence_penalty: Optional[float] = Field(
default=0.0,
description="Penalty applied for generating new tokens. Any existing"
" token results in the subtraction of its corresponding logit value."
" Default set to 0.0",
)
frequency_penalty: Optional[float] = Field(
default=0.0,
description="Penalty applied for generating new tokens. The frequency of an"
" existing token is subtracted from its corresponding logit value."
" Default set to 0.0.",
)

generation_config: Union[None, str, pathlib.Path, Dict, GenerationConfig] = Field(
default=None,
description="GenerationConfig file consisting of parameters used to control "
"sequences generated for each prompt. The current supported parameters are: "
"max_length, max_new_tokens, num_return_sequences, output_scores, top_p, "
"top_k, repetition_penalty.",
)

kwargs: Optional[Dict] = Field(
default=None,
description="Any arguments to override generation_config arguments. Refer to "
"the generation_config argument for a full list of supported variables. Only "
"valid when generation_config is not None.",
)
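
A reviewer-style sketch of how generation_config and kwargs could compose at request time; text_pipeline is assumed to be an already-constructed text generation pipeline, and the precedence described (input-level config over pipeline-level config over GenerationDefaults, with kwargs overriding individual fields) is inferred from process_inputs below.

# Hypothetical request-time usage; names are placeholders.
out = text_pipeline(
    sequences=["def fibonacci(n):"],
    generation_config={"max_new_tokens": 32, "num_return_sequences": 2},
)

# kwargs override single values of the supplied config without rebuilding it.
out = text_pipeline(
    sequences="def fibonacci(n):",
    generation_config={"max_new_tokens": 32},
    kwargs={"top_k": 50},  # only top_k is overridden
)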


@@ -219,9 +214,10 @@ def __init__(
deterministic: bool = True,
sampling_temperature: float = 1.0,
prompt_sequence_length: int = 64,
sequence_length: int = 512,
sequence_length: int = 1024,
force_max_tokens: bool = False,
internal_kv_cache: bool = True,
generation_config: Union[str, pathlib.Path, Dict, GenerationConfig] = None,
**kwargs,
):
kwargs_engine_type = kwargs.get("engine_type", DEEPSPARSE_ENGINE)
@@ -271,6 +267,12 @@ def __init__(

# auxiliary flag for devs to enable debug mode for the pipeline
self._debug = False
self.generation_config = process_generation_config(generation_config)
if self.generation_config:
_LOGGER.info(
"Generation config provided for pipeline. This will be used "
"for all inputs unless an input-specific config is provided. "
)

def initialize_engines(
self,
@@ -410,22 +412,29 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:
:param inputs: the input schema for the pipeline
:return: the inputs for the engine
"""
if not self.cache_support_enabled and inputs.max_tokens > 1:
generation_config = check_and_return_generation_config(
self.generation_config, inputs.generation_config, GenerationDefaults()
)

generation_config = override_config(inputs.kwargs, generation_config)

self.streaming = inputs.streaming
if not self.cache_support_enabled and generation_config.max_length > 1:
raise ValueError(
"The model used for inference does not support kv cache. It is "
"assumed that it maps from the token sequence to predicted logits."
"Set `max_tokens` to 1 to support that scenario."
"Set `max_length` to 1 to support that scenario."
)

# If the num_generated_predictions > 1, repeat the prompt
# num_generated_predictions times. Also, update the engine so that deterministic
# If the num_return_sequences > 1, repeat the prompt
# num_return_sequences times. Also, update the engine so that deterministic
# is set to False.
original_inputs = inputs.sequences
if inputs.num_generated_predictions > 1:
if generation_config.num_return_sequences > 1:
if isinstance(inputs.sequences, str):
inputs.sequences = [inputs.sequences]
inputs.sequences = repeat_inputs(
inputs.sequences, inputs.num_generated_predictions
inputs.sequences, generation_config.num_return_sequences
)
if self.engine:
self.engine.deterministic = False
@@ -474,16 +483,14 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:
context = dict(
prompts=original_inputs,
streaming=inputs.streaming,
num_generated_predictions=inputs.num_generated_predictions,
return_logits=inputs.return_logits,
generation_config=generation_config,
include_prompt_logits=inputs.include_prompt_logits,
callback=inputs.callback,
stop=inputs.stop,
top_p=inputs.top_p,
top_k=inputs.top_k,
top_p=generation_config.top_p,
top_k=generation_config.top_k,
presence_penalty=inputs.presence_penalty,
frequency_penalty=inputs.frequency_penalty,
max_tokens=inputs.max_tokens,
frequency_penalty=generation_config.repetition_penalty,
)

return engine_input, context
@@ -532,22 +539,48 @@ def process_engine_outputs(
:return: the output schema for the pipeline
"""

def _create_generated_text_output(
sequence: str,
finish_reason: FinishReason = None,
logits: Optional[numpy.array] = None,
):
if finish_reason:
return GeneratedText(
text=sequence,
score=logits,
finished=True,
finished_reason=finish_reason.value,
)
return GeneratedText(
text=sequence,
score=logits,
finished=False,
)

generation_config = kwargs.get("generation_config")
prompts = kwargs.get("prompts")
streaming = kwargs.get("streaming")

if streaming:
return self._stream_engine_outputs(engine_outputs, prompts, kwargs)

generated_tokens, generated_logits, finished_reason, *debug = list(
*engine_outputs
)
if self._debug:
(
generated_tokens,
generated_logits,
finished_reason,
kv_cache_state,
total_num_processed_tokens,
) = list(*engine_outputs)
else:
generated_tokens, generated_logits, finished_reason = list(*engine_outputs)
sequences = self.tokenizer.batch_decode(
generated_tokens, skip_special_tokens=True
)

logits = generated_logits if kwargs.get("return_logits") else None
logits = generated_logits if generation_config.output_scores else None

num_preds = kwargs.get("num_generated_predictions", 1)
num_preds = generation_config.num_return_sequences
finished_reason = [f[0] for f in finished_reason]

if logits is not None:
@@ -566,7 +599,7 @@ def process_engine_outputs(
)
)

# If the num_generated_predictions > 1, group the generations and return
# If the num_return_sequences > 1, group the generations and return
# them as a list of lists where each list consists of the generated
# predictions for a given prompt, and all the lists are in the order matching
# the order that the prompts were given as inputs.
@@ -581,8 +614,7 @@
created=datetime.datetime.now(), prompts=prompts, generations=generations
)

if debug:
kv_cache_state, total_num_processed_tokens = debug
if self._debug:
debug_params = dict(
kv_cache_state=kv_cache_state,
total_num_processed_tokens=total_num_processed_tokens,
@@ -614,6 +646,7 @@ def engine_forward(
with self.timer_manager.new_timer_context(total_inference=False) as timer:
finished_reason = []
streaming = context.get("streaming")
generation_config = context.get("generation_config")

if not self.cache_support_enabled:
prompt_logits = self.multitoken_engine(engine_inputs)
@@ -640,10 +673,6 @@
)
token_generator.generate(prompt_logits[-1][0, -1, :])

# create the generated output
max_tokens = context.get("max_tokens", 0)
max_tokens = max_tokens if max_tokens > 0 else (100 * self.sequence_length)

# last prompt token is the first generated token
# add it to generated tokens, and the logits
generated_tokens = [token_generator.tokens[-1]]
@@ -655,6 +684,15 @@
callback = context.get("callback")
stop = context.get("stop")

max_new_tokens = generation_config.max_new_tokens
if max_new_tokens:
max_tokens = max_new_tokens + len(generated_tokens)
else:
max_tokens = generation_config.max_length
max_tokens = (
max_tokens if max_tokens > 0 else (100 * self.sequence_length)
)
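# Review note: with max_new_tokens set, the budget is counted on top of the
# token already generated from the prompt pass; otherwise max_length caps the
# total, falling back to 100 * sequence_length when left at 0.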

with timer.time(TextGenerationTimings.TOKEN_GENERATION):
while len(generated_tokens) < max_tokens:
with timer.time(TextGenerationTimings.TOKEN_GENERATION_SINGLE):
@@ -702,14 +740,19 @@
)

if not streaming:
returns = (
numpy.array([generated_tokens]),
numpy.concatenate(generated_logits, axis=1),
finished_reason,
)

if self._debug is True:
yield *returns, session
if self._debug:
returns = (
numpy.array([generated_tokens]),
numpy.concatenate(generated_logits, axis=1),
finished_reason,
[session],
)
else:
returns = (
numpy.array([generated_tokens]),
numpy.concatenate(generated_logits, axis=1),
finished_reason,
)

yield returns

@@ -921,7 +964,12 @@ def join_engine_outputs(
yield outputs
else:
batch_outputs = [list(*b) for b in batch_outputs]
tokens, logits, finish_reason, *debug = zip(*batch_outputs)
if self._debug:
tokens, logits, finish_reason, debug = zip(*batch_outputs)
else:
tokens, logits, finish_reason = zip(*batch_outputs)
debug = None

if self.cache_support_enabled:
# if the model has kv cache, we need to account for
# the fact that the predicted outputs may have
@@ -969,8 +1017,8 @@ def join_engine_outputs(
kv_cache_state,
num_processed_tokens,
]

yield [tokens, logits, finish_reason]
else:
yield [tokens, logits, finish_reason]

@staticmethod
def causal_mask_input_present(model_path: str) -> bool: