[TextGeneration][Timer] text gen specific timings + improved timing tooling (#1121)

* [TextGeneration][Timer] text gen specific timings + improved timing tooling

* review suggestion - names to dataclass

* Add types to `_TextGenerationTimings` attributes
Revert to using timer.time for `TOKEN_GENERATION`
Remove finally clause from `contextmanagers`
Address review comments from @rahul-tuli

---------

Co-authored-by: Rahul Tuli <rahul@neuralmagic.com>
Co-authored-by: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com>
3 people committed Jul 24, 2023
1 parent 2524935 commit 29a8f68
Showing 2 changed files with 80 additions and 39 deletions.
84 changes: 52 additions & 32 deletions src/deepsparse/transformers/pipelines/text_generation.py
@@ -15,6 +15,7 @@
 import logging
 import os
 import warnings
+from dataclasses import dataclass
 from typing import List, Optional, Tuple, Type, Union
 
 import numpy
@@ -33,6 +34,14 @@
 __all__ = ["TextGenerationPipeline"]
 
 
+@dataclass(frozen=True)
+class _TextGenerationTimings:
+    PROMPT_PREFILL: str = "engine_prompt_prefill"
+    PROMPT_PREFILL_SINGLE: str = "engine_prompt_prefill_single"
+    TOKEN_GENERATION: str = "engine_token_generation"
+    TOKEN_GENERATION_SINGLE: str = "engine_token_generation_single"
+
+
 class TextGenerationInput(BaseModel):
     sequences: Union[str, List[str]] = Field(
         description="The input sequences to generate the text from.",
@@ -314,35 +323,43 @@ def engine_forward(
             sequence of generated tokens and a sequence
             of logits for each generated token
         """
-        if not self.multitoken_engine.kv_cache_enabled:
-            tokens, prompt_logits = self.multitoken_engine(engine_inputs)
-            return numpy.array([tokens]), prompt_logits
-
-        else:
-            # run the prompt through
-            tokens, prompt_logits = self.prompt_inference(engine_inputs)
-
-            # create the generated output
-            max_tokens = (
-                self.max_generated_tokens
-                if self.max_generated_tokens and self.max_generated_tokens > 0
-                else 100 * self.sequence_length
-            )  # set safety for absolute max generation
-
-            generated_tokens = [tokens[-1]]
-            generated_logits = prompt_logits
-
-            while len(generated_tokens) < max_tokens:
-                (
-                    token,
-                    logits,
-                ) = self.autoregressive_inference(tokens)
-                tokens.append(token)
-                generated_tokens.append(token)
-                generated_logits.append(logits)
-
-                if token == self.tokenizer.eos_token_id and not self.force_max_tokens:
-                    break
+        # engine_forward is always called in a threadpool due to batch splitting
+        # as such, a new context needs to be created since we are no longer in the
+        # main thread. That is why `engine_` is prepended to each of the timer phase
+        # names in this context
+        with self.timer_manager.new_timer_context(total_inference=False) as timer:
+            if not self.multitoken_engine.kv_cache_enabled:
+                tokens, prompt_logits = self.multitoken_engine(engine_inputs)
+                return numpy.array([tokens]), prompt_logits
+
+            else:
+                # run the prompt through
+                with timer.time(_TextGenerationTimings.PROMPT_PREFILL):
+                    tokens, prompt_logits = self.prompt_inference(engine_inputs)
+
+                # create the generated output
+                max_tokens = (
+                    self.max_generated_tokens
+                    if self.max_generated_tokens and self.max_generated_tokens > 0
+                    else 100 * self.sequence_length
+                )  # set safety for absolute max generation
+
+                generated_tokens = [tokens[-1]]
+                generated_logits = prompt_logits
+
+                with timer.time(_TextGenerationTimings.TOKEN_GENERATION):
+                    while len(generated_tokens) < max_tokens:
+                        with timer.time(_TextGenerationTimings.TOKEN_GENERATION_SINGLE):
+                            token, logits = self.autoregressive_inference(tokens)
+                        tokens.append(token)
+                        generated_tokens.append(token)
+                        generated_logits.append(logits)
+
+                        if (
+                            token == self.tokenizer.eos_token_id
+                            and not self.force_max_tokens
+                        ):
+                            break

             return numpy.array([generated_tokens]), numpy.concatenate(
                 generated_logits, axis=1
@@ -398,9 +415,12 @@ def prompt_inference(

         for token in tokens[num_tokens_processed:]:
             run_tokens.append(token)
-            new_token, new_logits = self.autoregressive_inference(
-                run_tokens, shift_positions_by_one=not bool(num_tokens_processed)
-            )
+            with self.timer_manager.current.time(
+                _TextGenerationTimings.PROMPT_PREFILL_SINGLE
+            ):
+                new_token, new_logits = self.autoregressive_inference(
+                    run_tokens, shift_positions_by_one=not bool(num_tokens_processed)
+                )
             prompt_logits.append(new_logits)
 
             tokens.append(new_token)
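Taken together, the changes to this file nest per-token timings inside whole-phase timings: PROMPT_PREFILL wraps one PROMPT_PREFILL_SINGLE measurement per prompt token, and TOKEN_GENERATION wraps one TOKEN_GENERATION_SINGLE measurement per generated token. The following is a minimal, self-contained sketch of that nesting; StageRecorder and the sleep-based workload are illustrative stand-ins rather than deepsparse classes, and only the stage-name strings come from the _TextGenerationTimings dataclass in this diff.

# Illustrative sketch only: StageRecorder is a stand-in for deepsparse's StagedTimer,
# not the real class. The stage-name strings mirror the _TextGenerationTimings
# constants added in this diff; the sleeps stand in for real engine calls.
import time
from collections import defaultdict
from contextlib import contextmanager


class StageRecorder:
    def __init__(self):
        self.times = defaultdict(list)

    @contextmanager
    def time(self, stage):
        start = time.perf_counter()
        yield
        self.times[stage].append(time.perf_counter() - start)


PROMPT_PREFILL = "engine_prompt_prefill"
PROMPT_PREFILL_SINGLE = "engine_prompt_prefill_single"
TOKEN_GENERATION = "engine_token_generation"
TOKEN_GENERATION_SINGLE = "engine_token_generation_single"

timer = StageRecorder()

# Whole-prompt prefill wraps one per-token timing, as in prompt_inference() above.
with timer.time(PROMPT_PREFILL):
    for _ in range(4):  # pretend 4 prompt tokens remain to be processed one by one
        with timer.time(PROMPT_PREFILL_SINGLE):
            time.sleep(0.001)  # stand-in for autoregressive_inference()

# Whole generation wraps one per-token timing, as in engine_forward() above.
with timer.time(TOKEN_GENERATION):
    for _ in range(8):  # pretend 8 tokens are generated
        with timer.time(TOKEN_GENERATION_SINGLE):
            time.sleep(0.001)  # stand-in for autoregressive_inference()

for stage, measurements in timer.times.items():
    print(f"{stage}: {len(measurements)} measurement(s), {sum(measurements):.4f}s total")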
35 changes: 28 additions & 7 deletions src/deepsparse/utils/timer.py
@@ -114,6 +114,23 @@ def has_stage(self, stage: str) -> bool:
         """
         return stage in self.stages
 
+    @contextmanager
+    def time(self, stage: str):
+        """
+        Context Manager to record the time for a stage in the given context
+
+        example:
+        ```
+        with timer.time(STAGE_NAME):
+            # do something...
+        ```
+
+        :param stage: the name of the stage to time
+        """
+        self.start(stage)
+        yield
+        self.stop(stage)
+
     def start(self, stage: str):
         """
         Start the timer for a specific stage. If the stage doesn't exist,
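As the added method shows, time() is just a start()/stop() pair expressed as a context manager. The sketch below restates that equivalence outside deepsparse with a tiny stand-in class (Stopwatch is illustrative, not the real StagedTimer); both spellings record one measurement for a stage.

# Illustrative sketch only: Stopwatch is a stand-in, not deepsparse's StagedTimer.
# It shows that the new time() context manager is exactly a start()/stop() pair.
import time
from contextlib import contextmanager


class Stopwatch:
    def __init__(self):
        self.measurements = {}

    def start(self, stage):
        self.measurements[stage] = time.perf_counter()

    def stop(self, stage):
        self.measurements[stage] = time.perf_counter() - self.measurements[stage]

    @contextmanager
    def time(self, stage):
        # same shape as the added StagedTimer.time(): start, yield to the caller, stop
        self.start(stage)
        yield
        self.stop(stage)


watch = Stopwatch()

# explicit pairing (the pre-existing start/stop API)
watch.start("decode")
sum(range(100_000))  # some work to time
watch.stop("decode")

# equivalent context-manager form added in this commit
with watch.time("prefill"):
    sum(range(100_000))

print(watch.measurements)  # e.g. {'decode': 0.0009..., 'prefill': 0.0009...}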
@@ -322,23 +339,27 @@ def all_times(self) -> Dict[str, List[float]]:
         return all_times
 
     @contextmanager
-    def new_timer_context(self) -> StagedTimer:
+    def new_timer_context(self, total_inference: bool = True) -> StagedTimer:
         """
         Create a new StagedTimer object and set it as the current context.
+
+        :param total_inference: if True, measures the entire context as total inference
+            automatically and assumes this is the main inference thread. if False,
+            assumes this is not the main inference thread and will not overwrite
+            any other timers in non-multi/benchmark mode. Default True
         :return: the new StagedTimer object.
         """
         timer = StagedTimer(enabled=self.enabled)
-        timer.start(InferenceStages.TOTAL_INFERENCE)
 
-        if self.multi:
+        if total_inference:
+            timer.start(InferenceStages.TOTAL_INFERENCE)
+
+        if self.multi or not total_inference:
             self._timers.append(timer)
         else:
             self._timers = [timer]
 
         timer_context.set(timer)
 
-        try:
-            yield timer
-        finally:
+        yield timer
+
+        if total_inference:
             timer.stop(InferenceStages.TOTAL_INFERENCE)
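Per the diff above, the total_inference flag changes two things: whether the TOTAL_INFERENCE stage is measured around the context, and whether the new timer replaces or is appended alongside existing timers. A self-contained sketch of just that behavior follows; MiniStagedTimer and MiniTimerManager are illustrative stand-ins that model only the branching visible in this hunk, not the real deepsparse classes.

# Illustrative sketch only: stand-in classes modeling the total_inference branching
# from the diff above, not the real deepsparse TimerManager/StagedTimer.
import time
from contextlib import contextmanager

TOTAL_INFERENCE = "total_inference"


class MiniStagedTimer:
    def __init__(self):
        self.times = {}

    def start(self, stage):
        self.times[stage] = time.perf_counter()

    def stop(self, stage):
        self.times[stage] = time.perf_counter() - self.times[stage]


class MiniTimerManager:
    def __init__(self, multi=False):
        self.multi = multi
        self._timers = []

    @contextmanager
    def new_timer_context(self, total_inference=True):
        timer = MiniStagedTimer()
        if total_inference:
            timer.start(TOTAL_INFERENCE)
        # a non-main-thread timer is appended so it never overwrites existing timers
        if self.multi or not total_inference:
            self._timers.append(timer)
        else:
            self._timers = [timer]
        yield timer
        if total_inference:
            timer.stop(TOTAL_INFERENCE)


manager = MiniTimerManager()

# main inference thread: TOTAL_INFERENCE is started and stopped automatically
with manager.new_timer_context() as main_timer:
    # worker-style context (nested here only to keep the example single-threaded;
    # in the pipeline, engine_forward opens this from a threadpool worker)
    with manager.new_timer_context(total_inference=False) as worker_timer:
        time.sleep(0.001)

print(len(manager._timers))                   # 2: the worker timer was appended, not a replacement
print(TOTAL_INFERENCE in main_timer.times)    # True
print(TOTAL_INFERENCE in worker_timer.times)  # False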
