
Fix a bug in Chat pipeline for streaming;
+ add streaming support to chatbot
rahul-tuli committed Sep 25, 2023
1 parent 17a2aa4 commit 785c1a4
Showing 2 changed files with 50 additions and 14 deletions.
57 changes: 43 additions & 14 deletions examples/chatbot-llm/chatbot.py
@@ -42,6 +42,8 @@
no_show_tokens_per_sec]
--history / --no_history Whether to include history during prompt
generation or not [default: history]
--stream / --no_stream Whether to stream output as generated or not
[default: no_stream]
--help Show this message and exit.
@@ -62,6 +64,10 @@
4) Disable history
python chatbot.py models/llama/deployment \
--no_history
5) Stream output
python chatbot.py models/llama/deployment \
--stream
"""
import click

@@ -108,13 +114,20 @@
default=True,
help="Whether to include history during prompt generation or not",
)
@click.option(
"--stream/--no_stream",
is_flag=True,
default=False,
help="Whether to stream output as generated or not",
)
def main(
model_path: str,
sequence_length: int,
sampling_temperature: float,
prompt_sequence_length: int,
show_tokens_per_sec: bool,
history: bool,
stream: bool,
):
"""
Command Line utility to interact with a text generation LLM in a chatbot style
@@ -125,31 +138,47 @@ def main(
"""
# chat pipeline, automatically adds history
task = "chat" if history else "text-generation"

pipeline = Pipeline.create(
task=task,
model_path=model_path,
sequence_length=sequence_length,
sampling_temperature=sampling_temperature,
prompt_sequence_length=prompt_sequence_length,
)

# continue prompts until a keyboard interrupt
while True:
input_text = input("User: ")
response = pipeline(**{"sequences": [input_text]})
print("Bot: ", response.generations[0].text)
response = pipeline(**{"sequences": [input_text]}, streaming=stream)
_display_bot_response(stream, response)

if show_tokens_per_sec:
times = pipeline.timer_manager.times
prefill_speed = (
1.0 * prompt_sequence_length / times["engine_prompt_prefill_single"]
)
generation_speed = 1.0 / times["engine_token_generation_single"]
print(
f"[prefill: {prefill_speed:.2f} tokens/sec]",
f"[decode: {generation_speed:.2f} tokens/sec]",
sep="\n",
)
_display_generation_speed(prompt_sequence_length, pipeline)


def _display_generation_speed(prompt_sequence_length, pipeline):
# display prefill and generation speed(s) in tokens/sec
times = pipeline.timer_manager.times
prefill_speed = 1.0 * prompt_sequence_length / times["engine_prompt_prefill_single"]
generation_speed = 1.0 / times["engine_token_generation_single"]
print(
f"[prefill: {prefill_speed:.2f} tokens/sec]",
f"[decode: {generation_speed:.2f} tokens/sec]",
sep="\n",
)


def _display_bot_response(stream: bool, response):
# print response from pipeline, streaming or not

print("Bot:", end=" ")
if stream:
for generation in response:
print(generation.generations[0].text, end=" ")
print()
else:
print(response.generations[0].text)


if "__main__" == __name__:
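For context on the script changes above: depending on the streaming argument, the pipeline call now returns either a single output object or a generator of partial outputs, which is why _display_bot_response branches on the flag. Below is a minimal standalone sketch (not part of the commit) of consuming both forms; the model path is a placeholder taken from the usage examples in the docstring.

from deepsparse import Pipeline

pipeline = Pipeline.create(
    task="text-generation",
    model_path="models/llama/deployment",  # placeholder path from the usage examples
)

# Non-streaming: a single response object holding the full generation
response = pipeline(sequences=["Hello!"])
print(response.generations[0].text)

# Streaming: a generator of partial responses, printed as they arrive
for partial in pipeline(sequences=["Hello!"], streaming=True):
    print(partial.generations[0].text, end=" ")
print()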
7 changes: 7 additions & 0 deletions src/deepsparse/transformers/pipelines/chat.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import logging
from typing import Any, Dict, List, Tuple, Type, Union

@@ -146,6 +147,12 @@ def process_engine_outputs(
engine_outputs, **kwargs
)
# create the ChatOutput from the data provided
if inspect.isgenerator(text_generation_output):
return (
ChatOutput(**output.dict(), session_ids=session_ids)
for output in text_generation_output
)

return ChatOutput(**text_generation_output.dict(), session_ids=session_ids)

def engine_forward(
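The inspect.isgenerator check above is the bug fix referenced in the commit title: when the underlying text-generation pipeline streams, process_engine_outputs receives a generator rather than a single output, so the chat pipeline wraps each yielded item lazily instead of calling .dict() on the generator itself. A standalone sketch of the same pattern, using hypothetical Inner/Wrapped classes as stand-ins for the text-generation and chat output types:

import inspect
from dataclasses import dataclass
from typing import Generator, Union


@dataclass
class Inner:
    text: str


@dataclass
class Wrapped:
    text: str
    session_id: str


def wrap(output: Union[Inner, Generator[Inner, None, None]], session_id: str):
    # Streaming case: return a generator expression so wrapping stays lazy
    if inspect.isgenerator(output):
        return (Wrapped(text=item.text, session_id=session_id) for item in output)
    # Non-streaming case: wrap the single output directly
    return Wrapped(text=output.text, session_id=session_id)


print(wrap(Inner("hello"), "abc"))  # one wrapped object
for item in wrap((Inner(t) for t in ("hel", "lo")), "abc"):
    print(item)  # wrapped items are produced as the inner generator yields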
