[TextGeneration][Timer] text gen specific timings + improved timing tooling #1121

Merged: 5 commits, merged Jul 24, 2023
Changes from all commits
84 changes: 52 additions & 32 deletions src/deepsparse/transformers/pipelines/text_generation.py
@@ -15,6 +15,7 @@
import logging
import os
import warnings
from dataclasses import dataclass
from typing import List, Optional, Tuple, Type, Union

import numpy
@@ -33,6 +34,14 @@
__all__ = ["TextGenerationPipeline"]


@dataclass(frozen=True)
class _TextGenerationTimings:
PROMPT_PREFILL: str = "engine_prompt_prefill"
PROMPT_PREFILL_SINGLE: str = "engine_prompt_prefill_single"
TOKEN_GENERATION: str = "engine_token_generation"
TOKEN_GENERATION_SINGLE: str = "engine_token_generation_single"


class TextGenerationInput(BaseModel):
sequences: Union[str, List[str]] = Field(
description="The input sequences to generate the text from.",
@@ -314,35 +323,43 @@ def engine_forward(
sequence of generated tokens and a sequence
of logits for each generated token
"""
if not self.multitoken_engine.kv_cache_enabled:
tokens, prompt_logits = self.multitoken_engine(engine_inputs)
return numpy.array([tokens]), prompt_logits

else:
# run the prompt through
tokens, prompt_logits = self.prompt_inference(engine_inputs)

# create the generated output
max_tokens = (
self.max_generated_tokens
if self.max_generated_tokens and self.max_generated_tokens > 0
else 100 * self.sequence_length
) # set safety for absolute max generation

generated_tokens = [tokens[-1]]
generated_logits = prompt_logits

while len(generated_tokens) < max_tokens:
(
token,
logits,
) = self.autoregressive_inference(tokens)
tokens.append(token)
generated_tokens.append(token)
generated_logits.append(logits)

if token == self.tokenizer.eos_token_id and not self.force_max_tokens:
break
# engine_forward is always called in a threadpool due to batch splitting
# as such, a new context needs to be created since we are no longer in the
# main thread. That is why `engine_` is prepended to each of the timer phase
# names in this context
with self.timer_manager.new_timer_context(total_inference=False) as timer:
if not self.multitoken_engine.kv_cache_enabled:
tokens, prompt_logits = self.multitoken_engine(engine_inputs)
return numpy.array([tokens]), prompt_logits

else:
# run the prompt through
with timer.time(_TextGenerationTimings.PROMPT_PREFILL):
tokens, prompt_logits = self.prompt_inference(engine_inputs)

# create the generated output
max_tokens = (
self.max_generated_tokens
if self.max_generated_tokens and self.max_generated_tokens > 0
else 100 * self.sequence_length
) # set safety for absolute max generation

generated_tokens = [tokens[-1]]
generated_logits = prompt_logits

with timer.time(_TextGenerationTimings.TOKEN_GENERATION):
while len(generated_tokens) < max_tokens:
with timer.time(_TextGenerationTimings.TOKEN_GENERATION_SINGLE):
token, logits = self.autoregressive_inference(tokens)
tokens.append(token)
generated_tokens.append(token)
generated_logits.append(logits)

if (
token == self.tokenizer.eos_token_id
and not self.force_max_tokens
):
break

return numpy.array([generated_tokens]), numpy.concatenate(
generated_logits, axis=1
@@ -398,9 +415,12 @@ def prompt_inference(

for token in tokens[num_tokens_processed:]:
run_tokens.append(token)
new_token, new_logits = self.autoregressive_inference(
run_tokens, shift_positions_by_one=not bool(num_tokens_processed)
)
with self.timer_manager.current.time(
_TextGenerationTimings.PROMPT_PREFILL_SINGLE
):
new_token, new_logits = self.autoregressive_inference(
run_tokens, shift_positions_by_one=not bool(num_tokens_processed)
)
prompt_logits.append(new_logits)

tokens.append(new_token)
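The pipeline diff above introduces four engine-side stages (prompt prefill, single prompt prefill step, token generation, single token generation step). As a rough illustration of what those stages make possible, here is a hypothetical usage sketch, not part of this diff: the pipeline construction arguments are placeholders, and the `timer_manager.all_times` access pattern is assumed from the timer.py hunk below.

```python
# Hypothetical sketch: run a text generation pipeline, then inspect the
# new engine-side timing stages recorded by this PR.
from deepsparse import Pipeline

# `model_path` is a placeholder; substitute a real SparseZoo stub or local path.
pipeline = Pipeline.create(task="text_generation", model_path="<model stub or path>")
pipeline(sequences="Write a haiku about sparsity")

# `all_times` (per the timer.py diff) maps stage names such as
# "engine_prompt_prefill" and "engine_token_generation_single" to lists of
# measured durations in seconds; property-style access is an assumption here.
for stage, durations in pipeline.timer_manager.all_times.items():
    avg = sum(durations) / len(durations)
    print(f"{stage}: {len(durations)} measurement(s), avg {avg:.4f}s")
```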
35 changes: 28 additions & 7 deletions src/deepsparse/utils/timer.py
@@ -114,6 +114,23 @@ def has_stage(self, stage: str) -> bool:
"""
return stage in self.stages

@contextmanager
def time(self, stage: str):
"""
Context Manager to record the time for a stage in the given context

example:
```
with timer.time(STAGE_NAME):
# do something...
```

:param stage: the name of the stage to time
"""
self.start(stage)
yield
self.stop(stage)

def start(self, stage: str):
"""
Start the timer for a specific stage. If the stage doesn't exist,
@@ -322,23 +339,27 @@ def all_times(self) -> Dict[str, List[float]]:
return all_times

@contextmanager
def new_timer_context(self) -> StagedTimer:
def new_timer_context(self, total_inference: bool = True) -> StagedTimer:
"""
Create a new StagedTimer object and set it as the current context.

:param total_inference: if True, measures the entire context as total inference
automatically and assumes this is the main inference thread. if False,
assumes this is not the main inference thread and will not overwrite
any other timers in non-multi/benchmark mode. Default True
:return: the new StagedTimer object.
"""
timer = StagedTimer(enabled=self.enabled)
timer.start(InferenceStages.TOTAL_INFERENCE)

if self.multi:
if total_inference:
timer.start(InferenceStages.TOTAL_INFERENCE)

if self.multi or not total_inference:
self._timers.append(timer)
else:
self._timers = [timer]

timer_context.set(timer)

try:
yield timer
finally:
yield timer
if total_inference:
timer.stop(InferenceStages.TOTAL_INFERENCE)
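Taken together, the two timer.py changes let a worker thread open its own StagedTimer (without the TOTAL_INFERENCE stage) and nest stage measurements inside it via the new `time()` context manager. Below is a minimal sketch of that pattern under stated assumptions: the import path is taken from this diff, the default `TimerManager()` constructor is assumed to work, and the prefill/decode functions are placeholders rather than real engine calls.

```python
import time

from deepsparse.utils.timer import TimerManager  # import path assumed from the diff


def prefill(inputs):
    """Placeholder for the multitoken prompt pass."""
    time.sleep(0.001)


def generate_one_token():
    """Placeholder for a single autoregressive decode step."""
    time.sleep(0.001)


timer_manager = TimerManager()  # constructor arguments are not shown in this diff


def engine_worker(inputs):
    # Not the main inference thread: skip TOTAL_INFERENCE and append to the
    # manager's timer list instead of overwriting it (total_inference=False).
    with timer_manager.new_timer_context(total_inference=False) as timer:
        with timer.time("engine_prompt_prefill"):
            prefill(inputs)
        with timer.time("engine_token_generation"):
            for _ in range(8):
                with timer.time("engine_token_generation_single"):
                    generate_one_token()


engine_worker(inputs=None)
print(timer_manager.all_times)  # property-style access assumed; may differ in the actual API
```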