[TextGeneration] Sampling arguments for generation #1225

Merged · 15 commits · Sep 15, 2023
22 changes: 1 addition & 21 deletions src/deepsparse/transformers/engines/nl_decoder_engine.py
@@ -24,7 +24,6 @@
generate_session_id,
overwrite_onnx_model_inputs_for_kv_cache_models,
)
from deepsparse.utils.data import numpy_softmax
from deepsparse.utils.onnx import CACHE_INPUT_PREFIX, CACHE_OUTPUT_PREFIX


@@ -202,10 +201,7 @@ def __call__(
else:
logits = out[0]

# select batch idx 0, batch is always 1
token = self.generate_token(logits=logits[0, -1, :])

return token, logits
return logits

def __str__(self):
return f"{self.__class__.__name__}: {self.engine}"
@@ -228,22 +224,6 @@ def transfer_cache_state(self, cache: DecoderKVCache):
cache.set_capacity(self.cache_length)
self.kv_cache = cache

def generate_token(self, logits: numpy.ndarray) -> numpy.ndarray:
"""
Samples a token from the logits using the sampling temperature.

:param logits: the logits from the model with shape (vocab_size,)
:return: the sampled token
"""
if self.deterministic:
return numpy.argmax(logits)

logits /= self.sampling_temperature

probs = numpy_softmax(logits)

return numpy.random.choice(len(probs), p=probs)

def reset_kv_cache(self):
"""
Resets the kv cache state.
59 changes: 42 additions & 17 deletions src/deepsparse/transformers/pipelines/text_generation.py
@@ -42,6 +42,7 @@
create_causal_mask,
pad_to_fixed_length,
)
from deepsparse.transformers.utils.token_generator import TokenGenerator
from deepsparse.utils.onnx import default_cached_outputs


@@ -115,6 +116,21 @@ class Config:
" tokens is generated). Set to `None` to ignore this parameter."
" Default is `None`.",
)
top_p: Optional[float] = Field(
    default=0.0,
    description="Keep the smallest set of tokens whose cumulative"
    " probability exceeds the given top_p (nucleus sampling)."
    " Set to 0.0 to disable.",
)
top_k: Optional[int] = Field(
    default=0,
    description="Keep only the top_k highest-scoring tokens when sampling."
    " Set to 0 to disable.",
)
presence_penalty: Optional[float] = Field(
    default=0.0,
    description="Penalty subtracted from the logits of every token that has"
    " already appeared in the sequence, regardless of count.",
)
frequency_penalty: Optional[float] = Field(
    default=0.0,
    description="Penalty subtracted from a token's logits in proportion to"
    " how many times it has already appeared in the sequence.",
)
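For context, a minimal sketch of how these new sampling fields might be exercised end to end. The deployment path is a placeholder, and the assumption that the fields can be passed as call-time keyword arguments alongside `sequences` is mine, not something this diff shows:

from deepsparse import Pipeline

# hypothetical deployment directory; substitute any text-generation model
pipeline = Pipeline.create(task="text_generation", model_path="./deployment")

generation = pipeline(
    sequences="The quick brown fox",
    top_k=40,               # keep only the 40 highest-scoring tokens
    top_p=0.9,              # then keep the smallest nucleus covering 90% probability
    presence_penalty=0.2,   # flat penalty for any token already generated
    frequency_penalty=0.1,  # penalty scaled by how often a token was generated
)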


class TextGenerationOutput(BaseModel):
@@ -290,7 +306,6 @@ def initialize_engines(
if (
self.cache_support_enabled and self.enable_multitoken_prefill
) or not self.cache_support_enabled:

multitoken_engine = NLDecoderEngine(
onnx_file_path=self.onnx_file_path,
engine_type=self.engine_type,
@@ -450,20 +465,29 @@ def engine_forward(
# as such, a new context needs to be created since we are no longer in the
# main thread. That is why `engine_` is prepended to each of the timer phase
# names in this context

with self.timer_manager.new_timer_context(total_inference=False) as timer:
streamer = context.get("streamer")

if not self.cache_support_enabled:
tokens, prompt_logits = self.multitoken_engine(engine_inputs)
return numpy.array([tokens]), prompt_logits
prompt_logits = self.multitoken_engine(engine_inputs)
token_generator = TokenGenerator(prompt_logits[0])
for prompt_logit in prompt_logits:
    token_generator.generate(prompt_logit)
return numpy.array([token_generator.tokens]), prompt_logits

else:
# run the prompt through
with timer.time(_TextGenerationTimings.PROMPT_PREFILL):
tokens, prompt_logits = self.prompt_inference(engine_inputs)
prompt_logits = self.prompt_inference(engine_inputs)

tokens = engine_inputs[0][engine_inputs[1].nonzero()].tolist()
token_generator = TokenGenerator(logits=prompt_logits[-1], tokens=tokens)
token_generator.generate(prompt_logits[-1])

tokens = []
if streamer is not None:
streamer.put(numpy.array(tokens))
streamer.put(numpy.array(token_generator.tokens))

# create the generated output
max_tokens = (
@@ -474,7 +498,7 @@

# last prompt token is the first generated token
# add it to generated tokens, and the logits
generated_tokens = [tokens[-1]]
generated_tokens = [token_generator.tokens[-1]]
generated_logits = (
prompt_logits
if context.get("include_prompt_logits")
@@ -486,8 +510,11 @@
with timer.time(_TextGenerationTimings.TOKEN_GENERATION):
while len(generated_tokens) < max_tokens:
with timer.time(_TextGenerationTimings.TOKEN_GENERATION_SINGLE):
token, logits = self.autoregressive_inference(tokens)
tokens.append(token)
logits = self.autoregressive_inference(
tokens=token_generator.tokens
)
token = token_generator.generate(logits=logits[0, -1, :])

generated_tokens.append(token)
generated_logits.append(logits)

@@ -522,7 +549,8 @@ def engine_forward(
)

def prompt_inference(
self, engine_inputs: List[numpy.ndarray]
self,
engine_inputs: List[numpy.ndarray],
) -> Tuple[List[int], List[numpy.ndarray]]:
"""
An inference run that processes the prompt through the
@@ -539,7 +567,6 @@ def prompt_inference(
tokens = engine_inputs[0][engine_inputs[1].nonzero()].tolist()

prompt_logits = []
new_token = None
num_tokens_processed = 0

if (
@@ -548,7 +575,7 @@
):
self.multitoken_engine.reset_kv_cache()
for engine_inputs in self.engine_inputs_for_prefill(tokens):
new_token, new_logits = self.multitoken_engine(engine_inputs)
new_logits = self.multitoken_engine(engine_inputs)
num_tokens_processed += self.prompt_processing_sequence_length
prompt_logits.append(new_logits)

@@ -565,13 +592,11 @@
with self.timer_manager.current.time(
_TextGenerationTimings.PROMPT_PREFILL_SINGLE
):
new_token, new_logits = self.autoregressive_inference(run_tokens)
new_logits = self.autoregressive_inference(run_tokens)

prompt_logits.append(new_logits)

tokens.append(new_token)

return tokens, prompt_logits
return prompt_logits

def autoregressive_inference(
self,
@@ -608,9 +633,9 @@ def autoregressive_inference(
engine_inputs_map[name] for name in self.engine.onnx_input_names_no_cache
]

generated_token, generated_logits = self.engine(engine_inputs)
generated_logits = self.engine(engine_inputs)

return generated_token, generated_logits
return generated_logits

def engine_inputs_for_prefill(
self, tokens: List[int]
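Taken together, the changes to this file move all token selection out of the engines and into `engine_forward`. A condensed, non-verbatim sketch of the resulting loop, mirroring the diff above (streaming, timers, and finish-reason handling omitted):

# prefill: the engines now return logits only
prompt_logits = self.prompt_inference(engine_inputs)
prompt_tokens = engine_inputs[0][engine_inputs[1].nonzero()].tolist()

# seed the sampler with the prompt tokens, then draw the first new token
token_generator = TokenGenerator(logits=prompt_logits[-1], tokens=prompt_tokens)
token_generator.generate(prompt_logits[-1])

generated_tokens = [token_generator.tokens[-1]]
while len(generated_tokens) < max_tokens:
    logits = self.autoregressive_inference(tokens=token_generator.tokens)
    token = token_generator.generate(logits=logits[0, -1, :])
    generated_tokens.append(token)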
@@ -525,7 +525,6 @@ def _get_tag(self, entity_name: str) -> Tuple[str, str]:
return bi, tag

def _group_entities(self, entities: List[dict]) -> List[dict]:

entity_groups = []
entity_group_disagg = []

120 changes: 120 additions & 0 deletions src/deepsparse/transformers/utils/token_generator.py
@@ -0,0 +1,120 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

import numpy

from deepsparse.utils.data import numpy_softmax


class TokenGenerator:
def __init__(
self,
logits: numpy.ndarray,
tokens: List[int] = [],
deterministic: bool = True,
sampling_temperature: float = 1.0,
top_k: int = 0,
top_p: float = 0.0,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
**kwargs,
):
self.token_frequencies = numpy.zeros(logits.shape[-1])

self.deterministic = deterministic
self.sampling_temperature = sampling_temperature
self.top_k = top_k
self.top_p = top_p
self.frequency_penalty = frequency_penalty
self.presence_penalty = presence_penalty
self.tokens = []
for token in tokens:
self.update_frequencies(token)

def update_frequencies(self, token: numpy.ndarray):
self.tokens.append(token)
self.token_frequencies[token] += 1

def generate(self, logits: numpy.ndarray) -> numpy.ndarray:
"""
Samples a token from the logits. In deterministic mode, the argmax is
returned. Otherwise temperature scaling, top-k/top-p filtering, and
frequency/presence penalties are applied before sampling from the
resulting distribution.

:param logits: the logits from the model with shape (vocab_size,)
:return: the sampled token
"""
if self.deterministic:
token = numpy.argmax(logits)
self.tokens.append(token)
return token

if self.sampling_temperature != 1.0:
logits /= self.sampling_temperature

if self.top_k:
logits = self.apply_top_k(logits)
if self.top_p:
logits = self.apply_top_p(logits)

# penalties here
if self.frequency_penalty != 0.0:
logits = self.apply_frequency_penalty(logits)
if self.presence_penalty != 0.0:
logits = self.apply_presence_penalty(logits)

probs = numpy_softmax(logits)
token = numpy.random.choice(len(probs), p=probs)
self.update_frequencies(token)

return token

def apply_frequency_penalty(self, logits: numpy.ndarray):
logits -= self.frequency_penalty * self.token_frequencies
return logits

def apply_presence_penalty(self, logits: numpy.ndarray):
logits -= self.presence_penalty * (self.token_frequencies > 0)
return logits
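# Worked example of the two penalties: with token_frequencies = [2, 0, 1]
# and logits = [1.0, 1.0, 1.0], frequency_penalty=0.5 gives [0.0, 1.0, 0.5]
# (the penalty scales with repeat count), while presence_penalty=0.5 gives
# [0.5, 1.0, 0.5] (flat, applied once per token that has appeared at all).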

# from https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
def apply_top_k(self, logits: numpy.ndarray, filter_value=-float("Inf")):
logits_shape = logits.shape
logits = logits.reshape(logits.shape[-1])
top_k_indices = numpy.argpartition(logits, -self.top_k)[-self.top_k :]
logits[~numpy.isin(numpy.arange(len(logits)), top_k_indices)] = filter_value

return logits.reshape(logits_shape)
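# Worked example of top-k: logits = [0.1, 0.5, 0.2, 0.9] with top_k=2 becomes
# [-inf, 0.5, -inf, 0.9]; only the two highest-scoring tokens stay eligible.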

# from https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
def apply_top_p(
self, logits: numpy.ndarray, filter_value=-float("Inf"), min_tokens_to_keep=1
):
logits_shape = logits.shape
logits = logits.reshape(logits.shape[-1])

sorted_indices = numpy.argsort(logits)
sorted_logits = logits[sorted_indices]
logit_cumulative_probs = numpy.cumsum(numpy_softmax(sorted_logits))

# logits are sorted ascending, so the nucleus (the smallest set of
# highest-probability tokens whose cumulative probability reaches top_p)
# sits at the end; remove every token whose ascending cumulative
# probability stays below 1 - top_p
sorted_indices_to_remove = logit_cumulative_probs < (1 - self.top_p)
# Keep at least min_tokens_to_keep highest-probability tokens
sorted_indices_to_remove[..., -min_tokens_to_keep:] = 0

# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices[sorted_indices_to_remove]
logits[indices_to_remove] = filter_value

return logits.reshape(logits_shape)
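For reference, a minimal standalone sketch of driving TokenGenerator directly. The vocabulary size, logits, and prompt token ids below are invented for illustration; only the constructor arguments and the `generate` call mirror the class above:

import numpy

from deepsparse.transformers.utils.token_generator import TokenGenerator

vocab_size = 8
prompt_logits = numpy.random.rand(1, 4, vocab_size)  # (batch, sequence, vocab)

generator = TokenGenerator(
    logits=prompt_logits[-1],   # used only to size the frequency table
    tokens=[3, 1, 4],           # prompt token ids seed the frequency counts
    deterministic=False,
    sampling_temperature=0.8,
    top_k=4,                    # keep the 4 highest-scoring tokens
    top_p=0.9,                  # then keep the smallest 90%-probability nucleus
    frequency_penalty=0.1,
    presence_penalty=0.1,
)

# sample the next token from the logits of the last prompt position
next_token = generator.generate(prompt_logits[0, -1, :])
print(next_token, generator.tokens)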