Skip to content

Commit

Permalink
Address comments from @bfineran
Browse files Browse the repository at this point in the history
  • Loading branch information
rahul-tuli committed Aug 25, 2023
1 parent 9bfe1c6 commit 4d8c42a
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions src/deepsparse/transformers/pipelines/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,10 @@ class Config:
"`streamer.put(token_ids)` and the streamer is responsible "
"for any further processing.",
)
callback: Optional[Callable[[Any], bool]] = Field(
callback: Optional[Callable[[Any], Union[bool, Any]]] = Field(
default=None,
description="Callable that will be invoked "
"on each generated token. The callable must return a "
"Boolean value. If invocation returns `True`, the "
"generation will continue. If the callable returns "
"on each generated token. If the callable returns "
"`False`, the generation will stop. Default is `None`.",
)

Expand Down Expand Up @@ -470,7 +468,7 @@ def engine_forward(
):
break

if callback is not None and not callback(token):
if callback is not None and callback(token) is False:
_LOGGER.info(
"callback %s returned False, stopping generation."
% callback.__qualname__
Expand Down

0 comments on commit 4d8c42a

Please sign in to comment.