Add types to _TextGenerationTimings attributes
Revert to using `timer.time` for `TOKEN_GENERATION`
Remove `finally` clauses from the `contextmanager` functions
Address review comments from @rahul-tuli
rahul-tuli committed Jul 21, 2023
1 parent 385e6d4 commit f7572f9
Showing 2 changed files with 22 additions and 26 deletions.
32 changes: 17 additions & 15 deletions src/deepsparse/transformers/pipelines/text_generation.py
@@ -33,10 +33,10 @@

 @dataclass(frozen=True)
 class _TextGenerationTimings:
-    PROMPT_PREFILL = "engine_prompt_prefill"
-    PROMPT_PREFILL_SINGLE = "engine_prompt_prefill_single"
-    TOKEN_GENERATION = "engine_token_generation"
-    TOKEN_GENERATION_SINGLE = "engine_token_generation_single"
+    PROMPT_PREFILL: str = "engine_prompt_prefill"
+    PROMPT_PREFILL_SINGLE: str = "engine_prompt_prefill_single"
+    TOKEN_GENERATION: str = "engine_token_generation"
+    TOKEN_GENERATION_SINGLE: str = "engine_token_generation_single"


 class TextGenerationInput(BaseModel):
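
A note on why the annotations matter: @dataclass only treats annotated class attributes as fields, so without the `: str` annotations these constants were plain class attributes that the dataclass machinery ignored. A minimal runnable sketch (hypothetical class names) of the difference:

from dataclasses import dataclass, fields

@dataclass(frozen=True)
class Unannotated:
    STAGE = "engine_prompt_prefill"  # no annotation: a plain class attribute

@dataclass(frozen=True)
class Annotated:
    STAGE: str = "engine_prompt_prefill"  # annotated: a real dataclass field

print([f.name for f in fields(Unannotated)])  # []
print([f.name for f in fields(Annotated)])    # ['STAGE']

Either way the values remain reachable as _TextGenerationTimings.PROMPT_PREFILL and so on; the annotations mainly make the intent explicit and visible to type checkers.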
@@ -337,17 +337,19 @@ def engine_forward(
         generated_tokens = [tokens[-1]]
         generated_logits = prompt_logits

-        timer.start(_TextGenerationTimings.TOKEN_GENERATION)
-        while len(generated_tokens) < max_tokens:
-            with timer.time(_TextGenerationTimings.TOKEN_GENERATION_SINGLE):
-                token, logits = self.autoregressive_inference(tokens)
-            tokens.append(token)
-            generated_tokens.append(token)
-            generated_logits.append(logits)
-
-            if token == self.tokenizer.eos_token_id and not self.force_max_tokens:
-                break
-        timer.stop(_TextGenerationTimings.TOKEN_GENERATION)
+        with timer.time(_TextGenerationTimings.TOKEN_GENERATION):
+            while len(generated_tokens) < max_tokens:
+                with timer.time(_TextGenerationTimings.TOKEN_GENERATION_SINGLE):
+                    token, logits = self.autoregressive_inference(tokens)
+                tokens.append(token)
+                generated_tokens.append(token)
+                generated_logits.append(logits)
+
+                if (
+                    token == self.tokenizer.eos_token_id
+                    and not self.force_max_tokens
+                ):
+                    break

         return numpy.array([generated_tokens]), numpy.concatenate(
             generated_logits, axis=1
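
The rewrite above swaps manual timer.start(...)/timer.stop(...) bracketing for a single with timer.time(...) block, with the per-token timer nested inside it. A minimal sketch of that nesting, using a stand-in timer rather than the real deepsparse StagedTimer (whose aggregation details are not shown in this diff):

import time
from contextlib import contextmanager

class MiniTimer:
    """Stand-in with the start/stop/time shape the diff relies on."""

    def __init__(self):
        self._starts = {}
        self.measurements = {}

    def start(self, stage: str):
        self._starts[stage] = time.perf_counter()

    def stop(self, stage: str):
        elapsed = time.perf_counter() - self._starts.pop(stage)
        # accumulate per-stage timings; repeated stages append
        self.measurements.setdefault(stage, []).append(elapsed)

    @contextmanager
    def time(self, stage: str):
        self.start(stage)
        yield
        self.stop(stage)

timer = MiniTimer()
with timer.time("engine_token_generation"):
    for _ in range(3):  # one iteration per generated token
        with timer.time("engine_token_generation_single"):
            time.sleep(0.01)  # stands in for autoregressive_inference

print(timer.measurements)  # three single-token timings plus one total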
16 changes: 5 additions & 11 deletions src/deepsparse/utils/timer.py
@@ -128,11 +128,8 @@ def time(self, stage: str):
         :param stage: the name of the stage to time
         """
         self.start(stage)
-
-        try:
-            yield
-        finally:
-            self.stop(stage)
+        yield
+        self.stop(stage)

     def start(self, stage: str):
         """
@@ -363,9 +360,6 @@ def new_timer_context(self, total_inference: bool = True) -> StagedTimer:
         self._timers = [timer]

         timer_context.set(timer)
-
-        try:
-            yield timer
-        finally:
-            if total_inference:
-                timer.stop(InferenceStages.TOTAL_INFERENCE)
+        yield timer
+        if total_inference:
+            timer.stop(InferenceStages.TOTAL_INFERENCE)
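
One behavioral consequence of dropping try/finally, shown here as a standalone sketch (hypothetical names, not deepsparse code): in a generator-based context manager, the code after yield only runs if the body completes normally, so an exception raised inside the with block now propagates before stop() is called and the stage is left open. With finally, the stop runs regardless:

from contextlib import contextmanager

@contextmanager
def timed_no_finally(label: str):
    print(f"start {label}")
    yield
    print(f"stop {label}")  # skipped when the body raises

@contextmanager
def timed_with_finally(label: str):
    print(f"start {label}")
    try:
        yield
    finally:
        print(f"stop {label}")  # runs even when the body raises

try:
    with timed_no_finally("total_inference"):
        raise RuntimeError("engine failure")
except RuntimeError:
    pass  # "stop total_inference" was never printed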
