Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement streamer for text-generation and add context arg to Pipeline.engine_forward #1140

Merged
merged 4 commits into from
Aug 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions src/deepsparse/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import os
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Type, Union

Expand Down Expand Up @@ -228,9 +229,9 @@ def __call__(self, *args, **kwargs) -> BaseModel:
# batch size of the inputs may be `> self._batch_size` at this point
engine_inputs: List[numpy.ndarray] = self.process_inputs(pipeline_inputs)
if isinstance(engine_inputs, tuple):
engine_inputs, postprocess_kwargs = engine_inputs
engine_inputs, context = engine_inputs
else:
postprocess_kwargs = {}
context = {}

timer.stop(InferenceStages.PRE_PROCESS)
self.log(
Expand All @@ -247,7 +248,10 @@ def __call__(self, *args, **kwargs) -> BaseModel:
)

# submit split batches to engine threadpool
batch_outputs = list(self.executor.map(self.engine_forward, batches))
engine_forward_with_context = partial(self.engine_forward, context=context)
batch_outputs = list(
self.executor.map(engine_forward_with_context, batches)
)

# join together the batches of size `self._batch_size`
engine_outputs = self.join_engine_outputs(batch_outputs, orig_batch_size)
Expand All @@ -270,9 +274,7 @@ def __call__(self, *args, **kwargs) -> BaseModel:

# ------ POSTPROCESSING ------
timer.start(InferenceStages.POST_PROCESS)
pipeline_outputs = self.process_engine_outputs(
engine_outputs, **postprocess_kwargs
)
pipeline_outputs = self.process_engine_outputs(engine_outputs, **context)
if not isinstance(pipeline_outputs, self.output_schema):
raise ValueError(
f"Outputs of {self.__class__} must be instances of "
Expand Down Expand Up @@ -486,10 +488,13 @@ def split_engine_inputs(
"""
return split_engine_inputs(items, batch_size)

def engine_forward(self, engine_inputs: List[numpy.ndarray]) -> List[numpy.ndarray]:
def engine_forward(
self, engine_inputs: List[numpy.ndarray], context: Dict = {}
) -> List[numpy.ndarray]:
"""
:param engine_inputs: list of numpy inputs to Pipeline engine forward
pass
:param context: optional dictionary to be used during engine execution
:return: result of forward pass to Pipeline engine
"""
return self.engine(engine_inputs)
Expand Down
30 changes: 27 additions & 3 deletions src/deepsparse/transformers/pipelines/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@
import os
import warnings
from dataclasses import dataclass
from typing import Generator, List, Optional, Tuple, Type, Union
from typing import Dict, Generator, List, Optional, Tuple, Type, Union

import numpy
from pydantic import BaseModel, Field
from transformers import TextStreamer

from deepsparse import Pipeline
from deepsparse.cpu import cpu_avx512_compatible
Expand All @@ -46,6 +47,9 @@ class _TextGenerationTimings:


class TextGenerationInput(BaseModel):
class Config:
arbitrary_types_allowed = True

sequences: Union[str, List[str]] = Field(
description="The input sequences to generate the text from.",
)
Expand All @@ -71,6 +75,13 @@ class TextGenerationInput(BaseModel):
"to have consistent length so one "
"can compute metric in a batched fashion. ",
)
streamer: Optional[TextStreamer] = Field(
default=None,
description="Streamer object that will be used to stream the "
"generated sequences. Generated tokens are passed through "
"`streamer.put(token_ids)` and the streamer is responsible "
"for any further processing.",
)


class TextGenerationOutput(BaseModel):
Expand Down Expand Up @@ -290,7 +301,9 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:
self.engine.session_id = inputs.session_id
self.multitoken_engine.session_id = inputs.session_id

postprocessing_kwargs = dict(return_logits=inputs.return_logits)
postprocessing_kwargs = dict(
return_logits=inputs.return_logits, streamer=inputs.streamer
)
return engine_input, postprocessing_kwargs

def process_engine_outputs(
Expand All @@ -311,7 +324,7 @@ def process_engine_outputs(
return TextGenerationOutput(sequences=sequences, logits=logits)

def engine_forward(
self, engine_inputs: List[numpy.ndarray], **kwargs
self, engine_inputs: List[numpy.ndarray], context: Dict
) -> Tuple[numpy.ndarray, numpy.ndarray]:
"""
Run the forward pass on the engine.
Expand All @@ -327,6 +340,8 @@ def engine_forward(
# main thread. That is why `engine_` is prepended to each of the timer phase
# names in this context
with self.timer_manager.new_timer_context(total_inference=False) as timer:
streamer = context.get("streamer")

if not self.multitoken_engine.kv_cache_enabled:
tokens, prompt_logits = self.multitoken_engine(engine_inputs)
return numpy.array([tokens]), prompt_logits
Expand All @@ -336,6 +351,9 @@ def engine_forward(
with timer.time(_TextGenerationTimings.PROMPT_PREFILL):
tokens, prompt_logits = self.prompt_inference(engine_inputs)

if streamer is not None:
streamer.put(numpy.array(tokens))

# create the generated output
max_tokens = (
self.max_generated_tokens
Expand All @@ -354,12 +372,18 @@ def engine_forward(
generated_tokens.append(token)
generated_logits.append(logits)

if streamer is not None:
streamer.put(numpy.array([token]))

if (
token == self.tokenizer.eos_token_id
and not self.force_max_tokens
):
break

if streamer is not None:
streamer.end()

return numpy.array([generated_tokens]), numpy.concatenate(
generated_logits, axis=1
)
Expand Down
2 changes: 1 addition & 1 deletion tests/deepsparse/pipelines/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def test_pipeline_call_is_async(engine_mock):
executor = ThreadPoolExecutor(max_workers=1)
pipeline = Pipeline.create("token_classification", batch_size=1, executor=executor)

def sleep_then_engine_forward(xs):
def sleep_then_engine_forward(xs, context):
# each call to engine_forward also sleeps
time.sleep(20 / 1000)
return pipeline.engine(xs)
Expand Down
Loading