Switch to transformers streamer
mgoin committed Jul 23, 2023
1 parent 2d2a959 commit 0a56709
Showing 3 changed files with 17 additions and 165 deletions.
5 changes: 2 additions & 3 deletions src/deepsparse/pipeline.py
@@ -246,9 +246,7 @@ def __call__(self, *args, **kwargs) -> BaseModel:
         )
 
         # submit split batches to engine threadpool
-        batch_outputs = list(
-            self.executor.map(self.engine_forward, batches, **postprocess_kwargs)
-        )
+        batch_outputs = list(self.executor.map(self.engine_forward, batches))
 
         # join together the batches of size `self._batch_size`
         engine_outputs = join_engine_outputs(batch_outputs, orig_batch_size)
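Context for the change above: concurrent.futures.Executor.map accepts only positional iterables (plus the timeout and chunksize keywords), so spreading **postprocess_kwargs into it raises a TypeError; per-call options now travel through instance state instead (see the text_generation.py changes below). A minimal sketch of the constraint, with illustrative names:

    from concurrent.futures import ThreadPoolExecutor
    from functools import partial

    def engine_forward(batch, return_logits=False):
        # stand-in for the pipeline's per-batch engine call
        return (batch, return_logits)

    batches = [[1, 2], [3, 4]]
    with ThreadPoolExecutor() as executor:
        # executor.map(engine_forward, batches, return_logits=True)  # TypeError
        # keyword arguments must be bound before mapping:
        outputs = list(executor.map(partial(engine_forward, return_logits=True), batches))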
@@ -683,6 +681,7 @@ def create_engine(
                 model=onnx_file_path,
                 **engine_args,
             )
+        engine_args.pop("cache_output_bools", None)
         return Engine(onnx_file_path, **engine_args)
 
     if engine_type == ORT_ENGINE:
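The new pop strips a deepsparse-specific option before the plain Engine is constructed; presumably Engine does not accept cache_output_bools, and pop with a default keeps the call safe when the key is absent. A tiny sketch of the pattern (hypothetical names):

    def build_engine(model_path, num_cores=1):
        # stand-in for a constructor with a fixed keyword set
        return (model_path, num_cores)

    engine_args = {"num_cores": 4, "cache_output_bools": True}
    engine_args.pop("cache_output_bools", None)  # no KeyError if already absent
    engine = build_engine("model.onnx", **engine_args)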
28 changes: 15 additions & 13 deletions src/deepsparse/transformers/pipelines/text_generation.py
@@ -19,14 +19,14 @@
 
 import numpy
 from pydantic import BaseModel, Field
+from transformers import TextStreamer
 
 from deepsparse import Pipeline
 from deepsparse.cpu import cpu_avx512_compatible
 from deepsparse.pipeline import DEEPSPARSE_ENGINE
 from deepsparse.transformers.engines import NLDecoderEngine
 from deepsparse.transformers.pipelines import TransformersPipeline
 from deepsparse.transformers.utils.helpers import pad_to_fixed_length
-from deepsparse.transformers.utils.streamers import BaseStreamer
 
 
 _LOGGER = logging.getLogger(__name__)
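For context, transformers.TextStreamer wraps a tokenizer, decodes token ids as they are put() in, prints text incrementally, and flushes the remainder on end(). A minimal sketch of the interface this commit adopts (assumes network access to fetch a tokenizer):

    import numpy
    from transformers import AutoTokenizer, TextStreamer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    streamer = TextStreamer(tokenizer)

    ids = tokenizer("Hello world", return_tensors="np").input_ids[0]
    streamer.put(numpy.array(ids))  # decodes and prints as tokens arrive
    streamer.end()                  # flushes any buffered text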
@@ -35,6 +35,9 @@
 
 
 class TextGenerationInput(BaseModel):
+    class Config:
+        arbitrary_types_allowed = True
+
     sequences: Union[str, List[str]] = Field(
         description="The input sequences to generate the text from.",
     )
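The nested Config is needed because pydantic v1 has no validator for arbitrary classes such as TextStreamer; arbitrary_types_allowed = True falls back to an isinstance check. A short sketch of the failure it avoids (Streamer is a stand-in class):

    from typing import Optional
    from pydantic import BaseModel

    class Streamer:  # stand-in for transformers.TextStreamer
        pass

    class ExampleInput(BaseModel):
        # without arbitrary_types_allowed, pydantic v1 raises
        # "no validator found for <class 'Streamer'>" at class definition
        class Config:
            arbitrary_types_allowed = True

        streamer: Optional[Streamer] = None

    inp = ExampleInput(streamer=Streamer())  # accepted via isinstance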
@@ -60,7 +63,7 @@ class TextGenerationInput(BaseModel):
         "to have consistent length so one "
         "can compute metric in a batched fashion. ",
     )
-    streamer: Optional[BaseStreamer] = Field(
+    streamer: Optional[TextStreamer] = Field(
         default=None,
         description="Streamer object that will be used to stream the "
         "generated sequences. Generated tokens are passed through "
@@ -166,6 +169,7 @@ def __init__(
         self.max_generated_tokens = max_generated_tokens
         self.prompt_processing_sequence_length = prompt_processing_sequence_length
         self.force_max_tokens = force_max_tokens
+        self.streamer = None
 
         # override tokenizer to pad to left
         self.tokenizer.padding_side = "left"
@@ -290,9 +294,9 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:
             self.engine.session_id = inputs.session_id
             self.multitoken_engine.session_id = inputs.session_id
 
-        postprocessing_kwargs = dict(
-            return_logits=inputs.return_logits, streamer=inputs.streamer
-        )
+        self.streamer = inputs.streamer
+
+        postprocessing_kwargs = dict(return_logits=inputs.return_logits)
         return engine_input, postprocessing_kwargs
 
     def process_engine_outputs(
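Moving the streamer from postprocessing kwargs to self.streamer is what lets pipeline.py stop forwarding **postprocess_kwargs through executor.map: engine_forward now needs only the batch and reads per-call state from the instance. The trade-off is that this state is shared, so the design assumes one in-flight __call__ per pipeline object. A compressed sketch of the pattern (illustrative names):

    class StreamingPipelineSketch:
        def __init__(self):
            self.streamer = None  # per-call state parked on the instance

        def process_inputs(self, inputs):
            self.streamer = inputs.get("streamer")  # stash before dispatch
            return inputs["sequences"]

        def engine_forward(self, batch):
            if self.streamer is not None:
                print("streaming:", batch)  # stand-in for streamer.put(...)
            return batch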
@@ -324,8 +328,6 @@ def engine_forward(
             sequence of generated tokens and a sequence
             of logits for each generated token
         """
-        streamer = kwargs.get("streamer")
-
         if not self.multitoken_engine.kv_cache_enabled:
             tokens, prompt_logits = self.multitoken_engine(engine_inputs)
             return numpy.array([tokens]), prompt_logits
@@ -334,8 +336,8 @@
             # run the prompt through
             tokens, prompt_logits = self.prompt_inference(engine_inputs)
 
-            if streamer is not None:
-                streamer.put(tokens)
+            if self.streamer is not None:
+                self.streamer.put(numpy.array(tokens))
 
             # create the generated output
             max_tokens = (
@@ -356,14 +358,14 @@
                 generated_tokens.append(token)
                 generated_logits.append(logits)
 
-                if streamer is not None:
-                    streamer.put(token)
+                if self.streamer is not None:
+                    self.streamer.put(numpy.array([token]))
 
                 if token == self.tokenizer.eos_token_id and not self.force_max_tokens:
                     break
 
-        if streamer is not None:
-            streamer.end()
+        if self.streamer is not None:
+            self.streamer.end()
 
         return numpy.array([generated_tokens]), numpy.concatenate(
             generated_logits, axis=1
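End to end, the commit lets callers stream generation through the stock transformers streamer. A hedged usage sketch follows; the model stub is a placeholder you must substitute, and the field names assume this file's input/output schemas:

    from deepsparse import Pipeline
    from transformers import TextStreamer

    pipeline = Pipeline.create(
        task="text_generation",
        model_path="zoo:<text-generation-model-stub>",  # placeholder, not a real stub
    )
    streamer = TextStreamer(pipeline.tokenizer)  # prints text as tokens arrive
    output = pipeline(sequences="def fibonacci(n):", streamer=streamer)
    print(output.sequences)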
149 changes: 0 additions & 149 deletions src/deepsparse/transformers/utils/streamers.py

This file was deleted.
