[TextGeneration] Sampling arguments for generation (#1225)
* draft

* draft

* draft

* implementation

* delete commented line

* tests, update sampling calculation

* comments and bug fixes

* comments

* remove generated token test from nl decoder engine

* update prompt seq len name

Co-authored-by: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com>

* re-add missing code

---------

Co-authored-by: Benjamin Fineran <bfineran@users.noreply.github.com>
Co-authored-by: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com>
3 people committed Sep 15, 2023
1 parent 544e372 commit 55307ad
Showing 7 changed files with 410 additions and 58 deletions.
6 changes: 4 additions & 2 deletions src/deepsparse/pipeline.py
@@ -232,7 +232,7 @@ def __call__(self, *args, **kwargs) -> BaseModel:
f"Inputs parsed to {type(pipeline_inputs)}"
)
# batch size of the inputs may be `> self._batch_size` at this point
engine_inputs: List[numpy.ndarray] = self.process_inputs(pipeline_inputs)
engine_inputs = self.process_inputs(pipeline_inputs)
if isinstance(engine_inputs, tuple):
engine_inputs, context = engine_inputs
else:
@@ -494,7 +494,9 @@ def split_engine_inputs(
return split_engine_inputs(items, batch_size)

def engine_forward(
self, engine_inputs: List[numpy.ndarray], context: Dict = {}
self,
engine_inputs: List[numpy.ndarray],
context: Dict = {},
) -> List[numpy.ndarray]:
"""
:param engine_inputs: list of numpy inputs to Pipeline engine forward
26 changes: 3 additions & 23 deletions src/deepsparse/transformers/engines/nl_decoder_engine.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional

import numpy
from transformers import AutoTokenizer
@@ -23,7 +23,6 @@
from deepsparse.transformers.utils.helpers import generate_session_id
from deepsparse.transformers.utils.timings import TextGenerationTimings
from deepsparse.utils import TimerManager
from deepsparse.utils.data import numpy_softmax
from deepsparse.utils.onnx import (
CACHE_INPUT_PREFIX,
CACHE_OUTPUT_PREFIX,
@@ -184,7 +183,7 @@ def __call__(
self,
inp: List[numpy.ndarray],
val_inp: bool = True,
) -> Tuple[numpy.ndarray, numpy.ndarray]:
) -> numpy.ndarray:
"""
The main entry point for running the engine.
@@ -212,10 +211,7 @@ def __call__(
else:
logits = out[0]

# select batch idx 0, batch is always 1
token = self.generate_token(logits=logits[0, -1, :])

return token, logits
return logits

def __str__(self):
return f"{self.__class__.__name__}: {self.engine}"
@@ -238,22 +234,6 @@ def transfer_cache_state(self, cache: DecoderKVCache):
cache.set_capacity(self.cache_length)
self.kv_cache = cache

def generate_token(self, logits: numpy.ndarray) -> numpy.ndarray:
"""
Samples a token from the logits using the sampling temperature.
:param logits: the logits from the model with shape (vocab_size,)
:return: the sampled token
"""
if self.deterministic:
return numpy.argmax(logits)

logits /= self.sampling_temperature

probs = numpy_softmax(logits)

return numpy.random.choice(len(probs), p=probs)

def reset_kv_cache(self):
"""
Resets the kv cache state.
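
Note: the generate_token method removed above previously handled sampling inside NLDecoderEngine; this commit moves that responsibility into the pipeline via a new TokenGenerator helper. A minimal sketch of the removed sampling step, kept self-contained by re-implementing the softmax locally (assumed to match deepsparse.utils.data.numpy_softmax):

    import numpy

    def numpy_softmax(x: numpy.ndarray, axis: int = -1) -> numpy.ndarray:
        # numerically stable softmax over the given axis
        x = x - numpy.max(x, axis=axis, keepdims=True)
        exp = numpy.exp(x)
        return exp / numpy.sum(exp, axis=axis, keepdims=True)

    def sample_token(
        logits: numpy.ndarray, deterministic: bool, sampling_temperature: float = 1.0
    ) -> int:
        # logits has shape (vocab_size,); greedy argmax when deterministic,
        # otherwise temperature-scaled multinomial sampling
        if deterministic:
            return int(numpy.argmax(logits))
        probs = numpy_softmax(logits / sampling_temperature)
        return int(numpy.random.choice(len(probs), p=probs))
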
82 changes: 65 additions & 17 deletions src/deepsparse/transformers/pipelines/text_generation.py
@@ -43,6 +43,7 @@
repeat_inputs,
)
from deepsparse.transformers.utils.timings import TextGenerationTimings
from deepsparse.transformers.utils.token_generator import TokenGenerator
from deepsparse.utils.onnx import default_cached_outputs


@@ -120,6 +121,29 @@ class Config:
" tokens is generated). Set to `None` to ignore this parameter."
" Default is `None`.",
)
top_p: Optional[float] = Field(
default=0.0,
description="Used for filtering generated tokens. Keep the"
" tokens where its cumulative probability is >= top_p"
" Default set to 0.0",
)
top_k: Optional[int] = Field(
default=0,
description="Used for filtering generated tokens. Keep"
" top_k generated tokens. Default set to 0",
)
presence_penalty: Optional[float] = Field(
default=0.0,
description="Penalty applied for generating new token. Any existing"
" token results in the subtraction of its corresponding logit value."
" Default set to 0.0",
)
frequency_penalty: Optional[float] = Field(
default=0.0,
description="Penalty applied for generating new token. Existing"
" token frequencies summed to subtraction the logit of its"
" corresponding logit value. Default set to 0.0.",
)


class TextGenerationOutput(BaseModel):
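
For reference, a hypothetical call using the new sampling arguments, assuming the standard deepsparse Pipeline.create entry point and that the TextGenerationInput fields above are accepted as call-time keyword arguments (the model stub is a placeholder, not from this commit):

    from deepsparse import Pipeline

    # model_path is a placeholder; substitute a real SparseZoo stub or ONNX export
    pipeline = Pipeline.create(task="text_generation", model_path="zoo:<model-stub>")

    output = pipeline(
        sequences="Once upon a time",
        max_tokens=64,
        top_k=40,
        top_p=0.9,
        presence_penalty=0.5,
        frequency_penalty=0.5,
    )
    print(output.sequences)
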
@@ -439,8 +463,13 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:
include_prompt_logits=inputs.include_prompt_logits,
callback=inputs.callback,
stop=inputs.stop,
top_p=inputs.top_p,
top_k=inputs.top_k,
presence_penalty=inputs.presence_penalty,
frequency_penalty=inputs.frequency_penalty,
max_tokens=inputs.max_tokens,
)

return engine_input, context

def process_engine_outputs(
@@ -473,7 +502,9 @@ def process_engine_outputs(
return TextGenerationOutput(sequences=sequences, logits=logits)

def engine_forward(
self, engine_inputs: List[numpy.ndarray], context: Dict
self,
engine_inputs: List[numpy.ndarray],
context: Dict,
) -> Tuple[numpy.ndarray, numpy.ndarray]:
"""
Run the forward pass on the engine.
@@ -488,28 +519,45 @@ def engine_forward(
# as such, a new context needs to be created since we are no longer in the
# main thread. That is why `engine_` is prepended to each of the timer phase
# names in this context

with self.timer_manager.new_timer_context(total_inference=False) as timer:
streamer = context.get("streamer")

if not self.cache_support_enabled:
tokens, prompt_logits = self.multitoken_engine(engine_inputs)
return numpy.array([tokens]), prompt_logits
prompt_logits = self.multitoken_engine(engine_inputs)
token_generator = TokenGenerator(
logits_shape=prompt_logits[-1].shape[-1],
deterministic=self.deterministic,
**context,
)
for prompt_logit in prompt_logits:
token_generator.generate(prompt_logit)
return numpy.array([self.tokens]), prompt_logits

else:
# run the prompt through
with timer.time(TextGenerationTimings.PROMPT_PREFILL):
tokens, prompt_logits = self.prompt_inference(engine_inputs)
prompt_logits = self.prompt_inference(engine_inputs)

tokens = engine_inputs[0][engine_inputs[1].nonzero()].tolist()
token_generator = TokenGenerator(
logits_shape=prompt_logits[-1].shape[-1],
tokens=tokens,
deterministic=self.deterministic,
**context,
)
token_generator.generate(prompt_logits[-1][0, -1, :])

if streamer is not None:
streamer.put(numpy.array(tokens))
streamer.put(numpy.array(token_generator.tokens))

# create the generated output
max_tokens = context.get("max_tokens", 0)
max_tokens = max_tokens if max_tokens > 0 else (100 * self.sequence_length)

# last prompt token is the first generated token
# add it to generated tokens, and the logits
generated_tokens = [tokens[-1]]
generated_tokens = [token_generator.tokens[-1]]
generated_logits = (
prompt_logits
if context.get("include_prompt_logits")
Expand All @@ -521,8 +569,10 @@ def engine_forward(
with timer.time(TextGenerationTimings.TOKEN_GENERATION):
while len(generated_tokens) < max_tokens:
with timer.time(TextGenerationTimings.TOKEN_GENERATION_SINGLE):
token, logits = self.autoregressive_inference(tokens)
tokens.append(token)
logits = self.autoregressive_inference(
tokens=token_generator.tokens
)
token = token_generator.generate(logits=logits[0, -1, :])
generated_tokens.append(token)
generated_logits.append(logits)

@@ -557,7 +607,8 @@ def engine_forward(
)

def prompt_inference(
self, engine_inputs: List[numpy.ndarray]
self,
engine_inputs: List[numpy.ndarray],
) -> Tuple[List[int], List[numpy.ndarray]]:
"""
An inference run that processes the prompt through the
@@ -574,13 +625,12 @@
tokens = engine_inputs[0][engine_inputs[1].nonzero()].tolist()

prompt_logits = []
new_token = None
num_tokens_processed = 0

if len(tokens) > self.prompt_sequence_length and self.enable_multitoken_prefill:
self.multitoken_engine.reset_kv_cache()
for engine_inputs in self.engine_inputs_for_prefill(tokens):
new_token, new_logits = self.multitoken_engine(engine_inputs)
new_logits = self.multitoken_engine(engine_inputs)
num_tokens_processed += self.prompt_sequence_length
prompt_logits.append(new_logits)

@@ -598,13 +648,11 @@
with self.timer_manager.current.time(
TextGenerationTimings.PROMPT_PREFILL_SINGLE
):
new_token, new_logits = self.autoregressive_inference(run_tokens)
new_logits = self.autoregressive_inference(run_tokens)

prompt_logits.append(new_logits)

tokens.append(new_token)

return tokens, prompt_logits
return prompt_logits

def autoregressive_inference(
self,
@@ -641,9 +689,9 @@ def autoregressive_inference(
engine_inputs_map[name] for name in self.engine.onnx_input_names_no_cache
]

generated_token, generated_logits = self.engine(engine_inputs)
generated_logits = self.engine(engine_inputs)

return generated_token, generated_logits
return generated_logits

def engine_inputs_for_prefill(
self, tokens: List[int]
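
The new TokenGenerator module itself (deepsparse.transformers.utils.token_generator, imported above) is among the changed files not shown in this view. As a rough sketch only, here is how such a helper might apply the new sampling arguments before choosing a token; the class name, signature, and exact filtering/penalty semantics below are assumptions for illustration, not the committed implementation:

    from collections import Counter
    from typing import List, Optional

    import numpy

    def _softmax(x: numpy.ndarray) -> numpy.ndarray:
        x = x - numpy.max(x)
        exp = numpy.exp(x)
        return exp / exp.sum()

    class SketchTokenGenerator:
        """Illustrative stand-in for the TokenGenerator used above."""

        def __init__(
            self,
            deterministic: bool = True,
            sampling_temperature: float = 1.0,
            top_k: int = 0,
            top_p: float = 0.0,
            presence_penalty: float = 0.0,
            frequency_penalty: float = 0.0,
            tokens: Optional[List[int]] = None,
            **kwargs,  # absorb extra context entries, as **context does above
        ):
            self.deterministic = deterministic
            self.sampling_temperature = sampling_temperature
            self.top_k = top_k
            self.top_p = top_p
            self.presence_penalty = presence_penalty
            self.frequency_penalty = frequency_penalty
            self.tokens = list(tokens) if tokens else []

        def generate(self, logits: numpy.ndarray) -> int:
            logits = logits.astype(numpy.float64).copy()

            if self.deterministic:
                token = int(numpy.argmax(logits))
                self.tokens.append(token)
                return token

            # penalties: subtract a flat presence penalty for any seen token and
            # a frequency penalty scaled by how often the token has appeared
            for token_id, count in Counter(self.tokens).items():
                logits[token_id] -= self.presence_penalty
                logits[token_id] -= self.frequency_penalty * count

            logits /= self.sampling_temperature

            if self.top_k > 0:
                # keep only the top_k highest logits
                cutoff = numpy.sort(logits)[-self.top_k]
                logits[logits < cutoff] = -numpy.inf

            if self.top_p > 0.0:
                # nucleus filtering: keep the smallest set of tokens whose
                # cumulative probability reaches top_p
                probs = _softmax(logits)
                order = numpy.argsort(probs)[::-1]
                keep = int(numpy.searchsorted(numpy.cumsum(probs[order]), self.top_p)) + 1
                mask = numpy.full_like(logits, -numpy.inf)
                mask[order[:keep]] = logits[order[:keep]]
                logits = mask

            probs = _softmax(logits)
            token = int(numpy.random.choice(len(probs), p=probs))
            self.tokens.append(token)
            return token

Consolidating sampling in one helper lets both the cache-enabled and cache-free branches of engine_forward above share the same filtering and penalty logic instead of duplicating it in the engine.
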
@@ -522,7 +522,6 @@ def _get_tag(self, entity_name: str) -> Tuple[str, str]:
return bi, tag

def _group_entities(self, entities: List[dict]) -> List[dict]:

entity_groups = []
entity_group_disagg = []
