Skip to content

Commit

Permalink
[Feature] Add callback for Text Generation Pipelines (#1204)
Browse files Browse the repository at this point in the history
* Add callback argument to text-generation pipelines

* Add passing test

* Style

* Address comments from @bfineran

* update log level to debug
  • Loading branch information
rahul-tuli committed Aug 25, 2023
1 parent ec0cab0 commit 6a70b96
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 2 deletions.
18 changes: 16 additions & 2 deletions src/deepsparse/transformers/pipelines/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import os
import warnings
from dataclasses import dataclass
from typing import Any, Dict, Generator, List, Optional, Tuple, Type, Union
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union

import numpy
import onnx
Expand Down Expand Up @@ -91,6 +91,12 @@ class Config:
"`streamer.put(token_ids)` and the streamer is responsible "
"for any further processing.",
)
callback: Optional[Callable[[Any], Union[bool, Any]]] = Field(
default=None,
description="Callable that will be invoked "
"on each generated token. If the callable returns "
"`False`, the generation will stop. Default is `None`.",
)


class TextGenerationOutput(BaseModel):
Expand Down Expand Up @@ -377,6 +383,7 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:
return_logits=inputs.return_logits,
streamer=inputs.streamer,
include_prompt_logits=inputs.include_prompt_logits,
callback=inputs.callback,
)
return engine_input, postprocessing_kwargs

Expand Down Expand Up @@ -443,7 +450,7 @@ def engine_forward(
if context.get("include_prompt_logits")
else [prompt_logits[-1]]
)

callback = context.get("callback")
with timer.time(_TextGenerationTimings.TOKEN_GENERATION):
while len(generated_tokens) < max_tokens:
with timer.time(_TextGenerationTimings.TOKEN_GENERATION_SINGLE):
Expand All @@ -461,6 +468,13 @@ def engine_forward(
):
break

if callback is not None and callback(token) is False:
_LOGGER.debug(
"callback %s returned False, stopping generation."
% callback.__qualname__
)
break

if streamer is not None:
streamer.end()

Expand Down
20 changes: 20 additions & 0 deletions tests/deepsparse/transformers/pipelines/test_text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ def _initialize_kv_cache_state(model, length=0):
return kv_cache


START = 0 # global variable for dummy_callback


@pytest.mark.parametrize(
"use_deepsparse_cache",
[True, False],
Expand Down Expand Up @@ -143,6 +146,23 @@ def test_model_output_cache(self, setup):
self._test_cache_state(short_prompt, pipeline, model_name)
self._test_cache_state(long_prompt, pipeline, model_name)

def test_callback(self, setup):
    """Verify that a user-supplied ``callback`` stops generation early.

    The callback is invoked once per generated token; generation halts
    as soon as it returns ``False``.  Here the callback allows exactly
    three tokens, so the returned logits must have three entries along
    the token axis.
    """
    pipeline, *_ = setup

    # Closure-local counter instead of a module-level ``global`` so the
    # test is idempotent: re-running it in the same session starts the
    # count from zero every time.
    token_count = 0

    def dummy_callback(token):
        nonlocal token_count
        token_count += 1
        # True for the first two tokens, False on the third -> stop.
        return token_count < 3

    inputs = {
        "sequences": "def fib(a, b, accumulator=0)",
        "callback": dummy_callback,
        "return_logits": True,
    }

    outs = pipeline(**inputs)
    # Exactly 3 tokens were generated before the callback stopped it.
    assert outs.logits.shape[1] == 3

def _test_cache_state(self, prompt, pipeline, model_name):
# make sure that the cache state after running a prompt
# is correct
Expand Down

0 comments on commit 6a70b96

Please sign in to comment.