Commit

Merge branch 'main' into sampling
bfineran committed Sep 13, 2023
2 parents 0a6a315 + a49ab47 commit 816c5ea
Showing 3 changed files with 93 additions and 26 deletions.
74 changes: 51 additions & 23 deletions src/deepsparse/transformers/pipelines/text_generation.py
@@ -40,6 +40,7 @@
from deepsparse.transformers.utils.helpers import (
create_causal_mask,
pad_to_fixed_length,
repeat_inputs,
)
from deepsparse.transformers.utils.timings import TextGenerationTimings
from deepsparse.transformers.utils.token_generator import TokenGenerator
@@ -58,6 +59,18 @@ class Config:
sequences: Union[str, List[str]] = Field(
description="The input sequences to generate the text from.",
)
num_generated_predictions: int = Field(
default=1,
description="The number of text generations to create from a single prompt. If "
"the same sequence is given as an input multiple times, the number of generated"
"the number of generated predictins is equivalent to the number of times the "
"the sequence is repeated.",
)
max_tokens: int = Field(
default=1024,
description="Maximum number of tokens to generate per output sequence. If no "
"value is provided, will default to 1024.",
)
return_logits: bool = Field(
default=False,
description="A flag that indicates whether to return "
@@ -134,7 +147,7 @@ class Config:


class TextGenerationOutput(BaseModel):
sequences: Union[str, List[str]] = Field(
sequences: Union[str, List[str], List[List[str]]] = Field(
description="The generated text sequences.",
)
logits: Optional[Any] = Field( # numpy array, set to Any for FastAPI compatibility
@@ -167,11 +180,6 @@ class TextGenerationPipeline(TransformersPipeline):
from the probability distribution computed from the logits.
Higher values will result in more random samples. Should
be greater than 0.0.
:param max_generated_tokens: the maximum number of tokens to generate
given the input sequence. If None, the model will generate
tokens until the end of the sequence is reached.
Otherwise, it will generate up to the maximum number of tokens or end of
sequence is reached.
:param sequence_length: sequence length to compile model and tokenizer for.
This controls the maximum context length of the pipeline. Default is 512
:param prompt_sequence_length: For large prompts, the prompt is
@@ -188,7 +196,6 @@ def __init__(
self,
deterministic: bool = True,
sampling_temperature: float = 1.0,
max_generated_tokens: Optional[int] = 1024,
prompt_sequence_length: int = 64,
sequence_length: int = 512,
force_max_tokens: bool = False,
@@ -227,16 +234,8 @@ def __init__(
if "WAND_OPT_FLAGS" not in os.environ:
os.environ["WAND_OPT_FLAGS"] = "default,~pyramids"

if not self.cache_support_enabled and max_generated_tokens > 1:
raise ValueError(
"The model used for inference does not support kv cache. It is "
"assumed that it maps from the token sequence to predicted logits."
"Set `max_generated_tokens` to 1 to support that scenario."
)

self.deterministic = deterministic
self.sampling_temperature = sampling_temperature
self.max_generated_tokens = max_generated_tokens
self.prompt_sequence_length = prompt_sequence_length
self.force_max_tokens = force_max_tokens
self.internal_kv_cache = internal_kv_cache
@@ -393,6 +392,26 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:
:param inputs: the input schema for the pipeline
:return: the inputs for the engine
"""
if not self.cache_support_enabled and inputs.max_tokens > 1:
raise ValueError(
"The model used for inference does not support kv cache. It is "
"assumed that it maps from the token sequence to predicted logits."
"Set `max_tokens` to 1 to support that scenario."
)

# If the num_generated_predictions > 1, repeat the prompt
# num_generated_predictions times. Also, update the engine so that deterministic
# is set to False.
if inputs.num_generated_predictions > 1:
if isinstance(inputs.sequences, str):
inputs.sequences = [inputs.sequences]
inputs.sequences = repeat_inputs(
inputs.sequences, inputs.num_generated_predictions
)
if self.engine:
self.engine.deterministic = False
if self.multitoken_engine:
self.multitoken_engine.deterministic = False

if inputs.fixed_sequences_length or not self.cache_support_enabled:
# to enforce a fixed sequence length, we need to
@@ -438,7 +457,8 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:
self.engine.session_id = inputs.session_id
self.multitoken_engine.session_id = inputs.session_id

postprocessing_kwargs = dict(
context = dict(
num_generated_predictions=inputs.num_generated_predictions,
return_logits=inputs.return_logits,
streamer=inputs.streamer,
include_prompt_logits=inputs.include_prompt_logits,
@@ -448,10 +468,9 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:
top_k=inputs.top_k,
presence_penalty=inputs.presence_penalty,
frequency_penalty=inputs.frequency_penalty,
max_tokens=inputs.max_tokens,
)

return engine_input, postprocessing_kwargs

def process_engine_outputs(
self, engine_outputs: List[numpy.ndarray], **kwargs
) -> TextGenerationOutput:
@@ -465,6 +484,18 @@ def process_engine_outputs(
sequences = self.tokenizer.batch_decode(
generated_tokens, skip_special_tokens=True
)
num_preds = kwargs.get("num_generated_predictions", 1)
# If the num_generated_predictions > 1, group the generated sequences and return
# the sequences as a list of lists where each list consists of the generated
# predictions for a given prompt, and all the lists are in the order matching
# the order that the prompts were given as inputs.
if num_preds > 1:
grouped_seq = [
sequences[n : n + num_preds]
for n in range(0, len(sequences), num_preds)
]
sequences = grouped_seq

logits = generated_logits if kwargs.get("return_logits") else None

return TextGenerationOutput(sequences=sequences, logits=logits)
@@ -520,11 +551,8 @@ def engine_forward(
streamer.put(numpy.array(token_generator.tokens))

# create the generated output
max_tokens = (
self.max_generated_tokens
if self.max_generated_tokens and self.max_generated_tokens > 0
else 100 * self.sequence_length
) # set safety for absolute max generation
max_tokens = context.get("max_tokens", 0)
max_tokens = max_tokens if max_tokens > 0 else (100 * self.sequence_length)

# last prompt token is the first generated token
# add it to generated tokens, and the logits
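For reference, a minimal usage sketch of the new request-level parameters introduced above. The task name, model stub, and prompt are placeholders, not taken from this commit; only num_generated_predictions and max_tokens come from the diff:

from deepsparse import Pipeline

# Hypothetical text-generation stub; any supported model path would do.
pipeline = Pipeline.create(
    task="text_generation",
    model_path="zoo:example/text_generation/stub",  # placeholder
)

# max_tokens and num_generated_predictions are now per-request inputs
# rather than pipeline constructor arguments.
output = pipeline(
    sequences=["def fib(n):"],
    num_generated_predictions=2,
    max_tokens=64,
)

# With num_generated_predictions > 1, generations are grouped per prompt:
# output.sequences -> [["<generation 1>", "<generation 2>"]]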
20 changes: 20 additions & 0 deletions src/deepsparse/transformers/utils/helpers.py
@@ -22,6 +22,7 @@
"generate_session_id",
"pad_to_fixed_length",
"create_causal_mask",
"repeat_inputs",
]

_LOGGER = logging.getLogger(__name__)
@@ -36,6 +37,25 @@ def generate_session_id() -> str:
return session_id


def repeat_inputs(
input_sequences: List[str], num_generated_predictions: int
) -> List[str]:
"""
:param input_sequences: List of input sequences to repeat
:param num_generated_predictions: number of times to repeat each sequence
:return: a list of input sequences, where sequences have been repeated
num_generated_predictions times if the sequence appears in input_sequences just
once. If the sequence appears multiple times in input_sequences, the
num_generated_predictions for the sequence is ignored.
"""
repeated_seq = []

for seq in input_sequences:
repeated_seq.extend(numpy.repeat([seq], num_generated_predictions))
return repeated_seq


def pad_to_fixed_length(
array: numpy.ndarray, max_len: int, axis: int = 0, value: int = 0
) -> numpy.ndarray:
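A quick sketch of what the new helper returns, with illustrative prompt strings:

from deepsparse.transformers.utils.helpers import repeat_inputs

# Each prompt is repeated back-to-back, preserving prompt order.
repeated = repeat_inputs(["prompt one", "prompt two"], num_generated_predictions=2)
# repeated -> ["prompt one", "prompt one", "prompt two", "prompt two"]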
25 changes: 22 additions & 3 deletions tests/deepsparse/transformers/pipelines/test_text_generation.py
@@ -93,7 +93,6 @@ def setup(self, model_stub, model_name, uses_bos_token, internal_kv_cache):
model_path=model_stub,
sequence_length=32,
prompt_sequence_length=4,
max_generated_tokens=self.max_generated_tokens,
internal_kv_cache=self.internal_kv_cache,
)
short_prompt = "this"
@@ -126,7 +125,9 @@ def test_model_output_sequences(self, setup):
# test model output against sources of truth
pipeline, model_name, _, short_prompt, long_prompt = setup

output_sequences = pipeline(sequences=[short_prompt, long_prompt])
output_sequences = pipeline(
sequences=[short_prompt, long_prompt], max_tokens=self.max_generated_tokens
)

# test against huggingface model
output_hugging_face = self._get_output_huggingface(
@@ -135,6 +136,23 @@ def test_model_output_sequences(self, setup):
assert short_prompt + output_sequences.sequences[0] == output_hugging_face[0]
assert long_prompt + output_sequences.sequences[1] == output_hugging_face[1]

def test_num_generated_predictions(self, setup):
pipeline = setup[0]
short_prompt = setup[3]

output_sequences = pipeline(
sequences=[short_prompt], num_generated_predictions=2
)

assert len(output_sequences.sequences[0]) == 2

output_sequences = pipeline(
sequences=[short_prompt, short_prompt], num_generated_predictions=2
)
assert len(output_sequences.sequences) == 2
for sequences in output_sequences.sequences:
assert len(sequences) == 2

def test_model_output_cache(self, setup):
pipeline, model_name, _, short_prompt, long_prompt = setup
if self.internal_kv_cache:
@@ -158,6 +176,7 @@ def dummy_callback(token):
"sequences": "def fib(a, b, accumulator=0)",
"callback": dummy_callback,
"return_logits": True,
"max_tokens": self.max_generated_tokens,
}

outs = pipeline(**inputs)
@@ -167,7 +186,7 @@ def _test_cache_state(self, prompt, pipeline, model_name):
# make sure that the cache state after running a prompt
# is correct

pipeline(sequences=prompt)
pipeline(sequences=prompt, max_tokens=self.max_generated_tokens)
cache_state_dict = pipeline.engine.kv_cache.cached_inputs
cache_state_list = [cache_state_dict[key] for key in cache_state_dict.keys()]

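For illustration, the grouping exercised by the test above mirrors the list comprehension added to process_engine_outputs: a flat list of decoded sequences is reshaped into one list per prompt. The strings below are placeholders:

# Flat decoder output for two prompts with num_generated_predictions=2
sequences = ["p1 gen a", "p1 gen b", "p2 gen a", "p2 gen b"]
num_preds = 2

grouped = [
    sequences[n : n + num_preds] for n in range(0, len(sequences), num_preds)
]
# grouped -> [["p1 gen a", "p1 gen b"], ["p2 gen a", "p2 gen b"]]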
