
Commit

review suggestion - names to dataclass
bfineran committed Jul 18, 2023
1 parent 83b1412 commit 27fec25
Showing 1 changed file with 14 additions and 9 deletions.
src/deepsparse/transformers/pipelines/text_generation.py — 23 changes: 14 additions & 9 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import logging
+from dataclasses import dataclass
 from typing import List, Optional, Tuple, Type, Union
 
 import numpy
@@ -30,10 +31,12 @@
 __all__ = ["TextGenerationPipeline"]
 
 
-PROMPT_PREFILL = "engine_prompt_prefill"
-PROMPT_PREFILL_SINGLE = "engine_prompt_prefill_single"
-TOKEN_GENERATION = "engine_token_generation"
-TOKEN_GENERATION_SINGLE = "engine_token_generation_single"
+@dataclass(frozen=True)
+class _TextGenerationTimings:
+    PROMPT_PREFILL = "engine_prompt_prefill"
+    PROMPT_PREFILL_SINGLE = "engine_prompt_prefill_single"
+    TOKEN_GENERATION = "engine_token_generation"
+    TOKEN_GENERATION_SINGLE = "engine_token_generation_single"
 
 
 class TextGenerationInput(BaseModel):
@@ -321,7 +324,7 @@ def engine_forward(
 
         else:
             # run the prompt through
-            with timer.time(PROMPT_PREFILL):
+            with timer.time(_TextGenerationTimings.PROMPT_PREFILL):
                 tokens, prompt_logits = self.prompt_inference(engine_inputs)
 
             # create the generated output
@@ -334,17 +337,17 @@
             generated_tokens = [tokens[-1]]
             generated_logits = prompt_logits
 
-        timer.start(TOKEN_GENERATION)
+        timer.start(_TextGenerationTimings.TOKEN_GENERATION)
         while len(generated_tokens) < max_tokens:
-            with timer.time(TOKEN_GENERATION_SINGLE):
+            with timer.time(_TextGenerationTimings.TOKEN_GENERATION_SINGLE):
                 token, logits = self.autoregressive_inference(tokens)
             tokens.append(token)
             generated_tokens.append(token)
             generated_logits.append(logits)
 
             if token == self.tokenizer.eos_token_id and not self.force_max_tokens:
                 break
-        timer.stop(TOKEN_GENERATION)
+        timer.stop(_TextGenerationTimings.TOKEN_GENERATION)
 
         return numpy.array([generated_tokens]), numpy.concatenate(
             generated_logits, axis=1
@@ -400,7 +403,7 @@ def prompt_inference(
 
         for token in tokens[num_tokens_processed:]:
             run_tokens.append(token)
-            with self.timer_manager.current.time(PROMPT_PREFILL_SINGLE):
+            with self.timer_manager.current.time(
+                _TextGenerationTimings.PROMPT_PREFILL_SINGLE
+            ):
                 new_token, new_logits = self.autoregressive_inference(
                     run_tokens, shift_positions_by_one=not bool(num_tokens_processed)
                 )
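For context, here is a minimal, self-contained sketch of the pattern this commit adopts: a frozen dataclass used purely as a namespace for related string constants, so call sites reference one grouped name instead of loose module-level globals. The `Timings` class and `demo_timer` function below are illustrative stand-ins, not part of the deepsparse codebase.

from dataclasses import dataclass


@dataclass(frozen=True)
class Timings:
    # Class-level constants with no type annotations, so the dataclass
    # generates no instance fields; the class acts as a read-only
    # namespace for the timing-phase names.
    PROMPT_PREFILL = "engine_prompt_prefill"
    TOKEN_GENERATION = "engine_token_generation"


def demo_timer(phase: str) -> None:
    # Stand-in for the pipeline's timer; shows how a call site
    # consumes the namespaced constant.
    print(f"timing phase: {phase}")


demo_timer(Timings.PROMPT_PREFILL)  # prints: timing phase: engine_prompt_prefill

Because the attributes carry no annotations, `@dataclass` records no fields here; the decorator mostly signals intent, and the class is never instantiated — constants are accessed directly on the class, as the diff above does with `_TextGenerationTimings.PROMPT_PREFILL`.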
