Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TextGeneration] Sampling arguments for generation #1225

Merged
merged 15 commits into from
Sep 15, 2023
4 changes: 3 additions & 1 deletion src/deepsparse/transformers/engines/nl_decoder_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ def run(self, inputs: List[numpy.ndarray], val_inp: bool) -> List[numpy.ndarray]
def __call__(
self,
inp: List[numpy.ndarray],
token_generator: TokenGenerator,
val_inp: bool = True,
) -> Tuple[numpy.ndarray, numpy.ndarray]:
"""
Expand Down Expand Up @@ -203,7 +204,8 @@ def __call__(
logits = out[0]

# select batch idx 0, batch is always 1
token = self.generate_token(logits=logits[0, -1, :])
# token = self.generate_token(logits=logits[0, -1, :])
token = token_generator.generate(logits=logits[0, -1, :])

return token, logits

Expand Down
4 changes: 4 additions & 0 deletions src/deepsparse/transformers/pipelines/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,11 +482,15 @@ def engine_forward(
)
callback = context.get("callback")
stop = context.get("stop")

token_generator = TokenGenerator(**token_generator_kwargs)

with timer.time(_TextGenerationTimings.TOKEN_GENERATION):
while len(generated_tokens) < max_tokens:
with timer.time(_TextGenerationTimings.TOKEN_GENERATION_SINGLE):

token, logits = self.autoregressive_inference(tokens)

tokens.append(token)
generated_tokens.append(token)
generated_logits.append(logits)
Expand Down
87 changes: 87 additions & 0 deletions src/deepsparse/transformers/token_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import numpy
from deepsparse.utils.data import numpy_softmax

class TokenGenerator:
    """
    Samples tokens from model logits, supporting greedy (deterministic)
    decoding as well as temperature / top-k / top-p sampling with
    frequency and presence penalties.

    :param logits: a logits array whose last axis is the vocabulary;
        only used to size the per-token frequency counters
    :param deterministic: if True, ``generate`` always returns the argmax
    :param sampling_temperature: softmax temperature; values > 1 flatten the
        distribution, values < 1 sharpen it
    :param top_k: if > 0, restrict sampling to the k highest-scoring tokens
    :param top_p: if > 0, restrict sampling to the smallest set of tokens
        whose cumulative probability exceeds ``top_p`` (nucleus sampling)
    :param frequency_penalty: subtracted from a token's logit in proportion
        to how many times that token has already been generated
    :param presence_penalty: subtracted once from the logit of every token
        that has been generated at least once
    """

    def __init__(
        self,
        logits: numpy.ndarray,
        deterministic: bool = True,
        sampling_temperature: float = 1.0,
        top_k: int = 0,
        top_p: float = 0.0,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
    ):
        # One counter per vocabulary entry. The vocabulary is the LAST axis
        # of the logits (generate() receives 1-D (vocab_size,) arrays), so
        # size off shape[-1] rather than the full shape.
        self.token_frequencies = numpy.zeros(logits.shape[-1])

        self.deterministic = deterministic
        # fixed typo: was "sampling_termperature", which made generate()
        # raise AttributeError when reading self.sampling_temperature
        self.sampling_temperature = sampling_temperature
        self.top_k = top_k
        self.top_p = top_p
        self.frequency_penalty = frequency_penalty
        self.presence_penalty = presence_penalty

    def update_frequencies(self, token: int):
        """Record that ``token`` was generated (feeds the penalty terms)."""
        self.token_frequencies[token] += 1

    # backward-compatible alias for the originally misspelled method name
    update_frequences = update_frequencies

    def generate(self, logits: numpy.ndarray) -> numpy.ndarray:
        """
        Samples a token from the logits, applying temperature, top-k,
        top-p, and frequency/presence penalties as configured.

        :param logits: the logits from the model with shape (vocab_size,)
        :return: the sampled token id
        """
        if self.deterministic:
            return numpy.argmax(logits)

        # work on a float copy so the caller's logits are never mutated
        logits = numpy.array(logits, dtype=numpy.float64)

        if self.sampling_temperature != 1.0:
            logits /= self.sampling_temperature

        if self.top_k:
            logits = self.apply_top_k(logits)
        if self.top_p:
            logits = self.apply_top_p(logits)

        if self.frequency_penalty != 0.0:
            logits = self.apply_frequency_penalty(logits)
        if self.presence_penalty != 0.0:
            logits = self.apply_presence_penalty(logits)

        # fixed: was self.numpy_softmax, which does not exist on the class
        probs = self._softmax(logits)

        token = numpy.random.choice(len(probs), p=probs)
        self.update_frequencies(token)

        return token

    def apply_frequency_penalty(self, logits: numpy.ndarray) -> numpy.ndarray:
        """Penalize each token proportionally to its generation count."""
        return logits - self.frequency_penalty * self.token_frequencies

    def apply_presence_penalty(self, logits: numpy.ndarray) -> numpy.ndarray:
        """Penalize every token that has been generated at least once."""
        return logits - self.presence_penalty * (self.token_frequencies > 0)

    # adapted from
    # https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    def apply_top_k(
        self,
        logits: numpy.ndarray,
        top_k: int = None,
        filter_value: float = -float("Inf"),
    ) -> numpy.ndarray:
        """
        Mask every logit outside the ``top_k`` largest to ``filter_value``.

        :param logits: 1-D logits with shape (vocab_size,)
        :param top_k: how many tokens to keep; defaults to ``self.top_k``
            (generate() calls this with logits only, so top_k must default)
        :param filter_value: value assigned to filtered-out logits
        :return: a new array with the tail logits masked out
        """
        top_k = self.top_k if top_k is None else top_k
        top_k = min(top_k, logits.shape[-1])
        # k-th largest value; everything strictly below it gets masked
        kth_largest = numpy.partition(logits, -top_k)[-top_k]
        return numpy.where(logits < kth_largest, filter_value, logits)

    # adapted from
    # https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    def apply_top_p(
        self,
        logits: numpy.ndarray,
        top_p: float = None,
        filter_value: float = -float("Inf"),
    ) -> numpy.ndarray:
        """
        Nucleus filtering: keep the smallest set of tokens whose cumulative
        probability exceeds ``top_p``; mask the rest to ``filter_value``.

        :param logits: 1-D logits with shape (vocab_size,)
        :param top_p: cumulative probability mass to keep; defaults to
            ``self.top_p``
        :param filter_value: value assigned to filtered-out logits
        :return: a new array with the low-probability tail masked out
        """
        top_p = self.top_p if top_p is None else top_p
        sorted_indices = numpy.argsort(logits)
        sorted_logits = logits[sorted_indices]
        # fixed: original took a plain softmax with no cumsum, so the
        # "cumulative" probabilities never accumulated
        cumulative_probs = numpy.cumsum(self._softmax(sorted_logits))
        # ascending order: the leading entries whose combined mass is
        # <= 1 - top_p form the tail to drop
        sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
        # fixed: original called Tensor.scatter (a torch method) on a
        # numpy array; plain fancy indexing does the same mapping here
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits = numpy.array(logits, dtype=numpy.float64)
        logits[indices_to_remove] = filter_value
        return logits

    @staticmethod
    def _softmax(x: numpy.ndarray) -> numpy.ndarray:
        """Numerically stable softmax of a 1-D array."""
        shifted = x - numpy.max(x)
        exps = numpy.exp(shifted)
        return exps / exps.sum()
Loading