Update pipeline inputs to support GenerationConfig
dsikka committed Sep 18, 2023
1 parent 0243798 commit 506c687
Showing 1 changed file with 76 additions and 25 deletions.
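
In practice, the change lets a transformers GenerationConfig (or a dict or JSON file that resolves to one) be attached when the pipeline is created. A minimal usage sketch, assuming a local deployment directory and the text_generation task name (both illustrative, not taken from this diff):

from transformers import GenerationConfig

from deepsparse import Pipeline

# Hypothetical model path; any DeepSparse text generation deployment works here.
pipeline = Pipeline.create(
    task="text_generation",
    model_path="./deployment",
    generation_config=GenerationConfig(
        max_length=128,          # maps to the pipeline's max_tokens
        num_return_sequences=2,  # maps to num_generated_predictions
        output_scores=True,      # maps to return_logits
    ),
)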
101 changes: 76 additions & 25 deletions src/deepsparse/transformers/pipelines/text_generation.py
@@ -13,8 +13,10 @@
# limitations under the License.

import datetime
import json
import logging
import os
import pathlib
import warnings
from enum import Enum
from typing import (
@@ -33,6 +35,7 @@
import numpy
import onnx
from pydantic import BaseModel, Field
from transformers import GenerationConfig

from deepsparse import Pipeline
from deepsparse.pipeline import DEEPSPARSE_ENGINE
@@ -67,24 +70,6 @@ class Config:
sequences: Union[str, List[str]] = Field(
description="The input sequences to generate the text from.",
)
num_generated_predictions: int = Field(
default=1,
description="The number of text generations to create from a single prompt. If "
"the same sequence is given as an input multiple times, the number of generated"
"the number of generated predictins is equivalent to the number of times the "
"the sequence is repeated.",
)
max_tokens: int = Field(
default=1024,
description="Maximum number of tokens to generate per output sequence. If no "
"value is provided, will default to 1024.",
)
return_logits: bool = Field(
default=False,
description="A flag that indicates whether to return "
"the logits for the input text sequence and the "
"generated text sequence. ",
)
include_prompt_logits: bool = Field(
default=False,
description="A flag that indicates whether to return "
@@ -93,13 +78,15 @@ class Config:
"Note: This flag is only applicable when return_logits "
"is `True`.",
)
# TODO: Session ID vs session?
session_id: Optional[str] = Field(
default=None,
description="A user may set a string identifier "
"for the kv cache session. If None, "
"and the model is using kv cache, it "
"will be set to a random uuid.",
)
# TODO: how does the padding length work with our two options of max and largest?
fixed_sequences_length: bool = Field(
default=False,
description="A flag that indicates whether to modify "
@@ -151,6 +138,13 @@ class Config:
description="Penalty applied for generating new token. Existing"
" token frequencies summed to subtraction the logit of its"
" corresponding logit value. Default set to 0.0.",
# TODO: what about the other parameters? Such as special tokens, which are just
# taken from the tokenizer?
generation_config: Union[None, str, pathlib.Path, Dict, GenerationConfig] = Field(
default=None,
description="GenerationConfig file consisting of parameters used to control "
" sequences generated for each prompt. The current supported parameters are: "
" max_length, num_return_sequences, output_scores, ",
)


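Per the new field's annotation, the same settings can arrive in several forms. A hedged sketch of equivalent inputs (the dict keys mirror the supported parameters listed above):

from transformers import GenerationConfig

# All of the following are accepted by the generation_config field:
as_object = GenerationConfig(max_length=64, num_return_sequences=1, output_scores=False)
as_dict = {"max_length": 64, "num_return_sequences": 1, "output_scores": False}
as_path = "generation_config.json"  # path to a JSON file holding the dict above
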
@@ -224,9 +218,10 @@ def __init__(
deterministic: bool = True,
sampling_temperature: float = 1.0,
prompt_sequence_length: int = 64,
sequence_length: int = 512,
sequence_length: int = 1024,
force_max_tokens: bool = False,
internal_kv_cache: bool = True,
generation_config: Union[str, pathlib.Path, Dict, GenerationConfig] = None,
**kwargs,
):
kwargs_engine_type = kwargs.get("engine_type", DEEPSPARSE_ENGINE)
@@ -275,6 +270,13 @@
self.engine, self.multitoken_engine = self.initialize_engines()
self.streaming = False

self.generation_config = self._process_generation_config(generation_config)
if self.generation_config:
_LOGGER.info(
"Generation config provided for pipline. This will be used "
"for all inputs unless and input-specific config is provided. "
)

def initialize_engines(
self,
) -> Tuple[Optional[NLDecoderEngine], Optional[NLDecoderEngine]]:
@@ -412,15 +414,64 @@ def output_schema(self) -> Type[BaseModel]:
"""
return TextGenerationOutput

def _process_generation_config(
self, generation_config: Union[None, str, pathlib.Path, Dict, GenerationConfig]
):
if isinstance(generation_config, GenerationConfig):
return generation_config

if not generation_config:
return None

if isinstance(generation_config, dict):
config_dir = os.getcwd()
config_name = "generation_config.json"
local_config_path = os.path.join(config_dir, config_name)
_LOGGER.info(
"Dictionary provided for the generation config. Creating temporary "
" generation_config.json"
)
with open(local_config_path, "w") as f:
json.dump(generation_config, f)

if isinstance(generation_config, (str, pathlib.Path)):
generation_config = pathlib.Path(generation_config)
config_dir = generation_config.parent.absolute()
config_name = generation_config.name

generation_config = GenerationConfig.from_pretrained(config_dir, config_name)
return generation_config

def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:
"""
Convert the input schema for the pipeline to the inputs for the engine.
:param inputs: the input schema for the pipeline
:return: the inputs for the engine
"""
generation_config = self._process_generation_config(inputs.generation_config)
if generation_config is None:
if self.generation_config:
generation_config = self.generation_config
else:
_LOGGER.info(
"Input generation config detection. This will override any "
" config provided during pipeline creation for this input."
)

if generation_config:
num_generated_predictions = generation_config.num_return_sequences
max_tokens = generation_config.max_length
return_logits = generation_config.output_scores

else:
# TODO: maybe just define defaults at the top
num_generated_predictions = 1
max_tokens = 1024
return_logits = False

self.streaming = inputs.streaming
if not self.cache_support_enabled and inputs.max_tokens > 1:
if not self.cache_support_enabled and max_tokens > 1:
raise ValueError(
"The model used for inference does not support kv cache. It is "
"assumed that it maps from the token sequence to predicted logits."
@@ -431,11 +482,11 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:
# num_generated_predictions times. Also, update the engine so that deterministic
# is set to False.
original_inputs = inputs.sequences
if inputs.num_generated_predictions > 1:
if num_generated_predictions > 1:
if isinstance(inputs.sequences, str):
inputs.sequences = [inputs.sequences]
inputs.sequences = repeat_inputs(
inputs.sequences, inputs.num_generated_predictions
inputs.sequences, num_generated_predictions
)
if self.engine:
self.engine.deterministic = False
@@ -488,16 +539,16 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:

context = dict(
prompts=original_inputs,
num_generated_predictions=inputs.num_generated_predictions,
return_logits=inputs.return_logits,
num_generated_predictions=num_generated_predictions,
return_logits=return_logits,
include_prompt_logits=inputs.include_prompt_logits,
callback=inputs.callback,
stop=inputs.stop,
top_p=inputs.top_p,
top_k=inputs.top_k,
presence_penalty=inputs.presence_penalty,
frequency_penalty=inputs.frequency_penalty,
max_tokens=inputs.max_tokens,
max_tokens=max_tokens,
)

return engine_input, context
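
For reference, the dict branch of _process_generation_config serializes the dict to a generation_config.json in the current working directory and reloads it through transformers. The equivalent standalone round trip (a sketch, not the pipeline API):

import json
import os
import pathlib

from transformers import GenerationConfig

config_dict = {"max_length": 64, "num_return_sequences": 2, "output_scores": True}
local_path = pathlib.Path(os.getcwd()) / "generation_config.json"
local_path.write_text(json.dumps(config_dict))  # temporary generation_config.json
config = GenerationConfig.from_pretrained(local_path.parent, local_path.name)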
