Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Chat pipeline] session context manager #1276

Merged
merged 1 commit into from
Sep 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 44 additions & 2 deletions src/deepsparse/transformers/pipelines/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import contextvars
import logging
from contextlib import contextmanager
from typing import Any, Callable, Dict, Iterator, List, Tuple, Type, Union

import numpy
from pydantic import Field, validator
Expand All @@ -37,6 +39,7 @@


_LOGGER = logging.getLogger(__name__)
_SESSION_IDS_CONTEXT = contextvars.ContextVar("_SESSION_ID", default=None)

__all__ = ["ChatPipeline"]

Expand Down Expand Up @@ -117,6 +120,41 @@ def output_schema(self) -> Type[ChatOutput]:
"""
return ChatOutput

@contextmanager
def session(
    self,
    session_ids: Union[None, List[str], str] = None,
    inference_batch_size: int = 1,
) -> Iterator[None]:
    """
    Context manager that sets and keeps a default session id(s) within
    the context

    example:
    In the following - both responses in the context will share the same
    session id
    ```
    with chat_pipeline.session():
        first_response = chat_pipeline("first prompt")
        second_response = chat_pipeline("second prompt")
    ```

    :param session_ids: actual value to set session ids to in context
        must match the inference batch size. If not supplied, will
        create default values. Default None
    :param inference_batch_size: if generating default session ids, number
        of session ids to create. default 1
    """
    if session_ids is None:
        # one generated id per sequence in the inference batch
        session_ids = [generate_session_id() for _ in range(inference_batch_size)]

    # publish the ids via the contextvar so calls made inside the context
    # (see add_session_ids_to_engine_input) can pick them up
    token = _SESSION_IDS_CONTEXT.set(session_ids)
    try:
        yield
    finally:
        # always restore the prior value, even if the context body raised,
        # so a failed inference cannot leak session ids into later calls
        _SESSION_IDS_CONTEXT.reset(token)

def process_inputs(
self, inputs: ChatInput
) -> Tuple[List[numpy.ndarray], Dict[str, Any]]:
Expand Down Expand Up @@ -234,7 +272,11 @@ def add_session_ids_to_engine_input(
:return: the engine input with the session ids
"""
session_ids = inputs.session_ids
if session_ids is None:
if session_ids is None and _SESSION_IDS_CONTEXT.get() is not None:
# respect directly setting session IDs first, then try to pull
# from context
session_ids = _SESSION_IDS_CONTEXT.get()
elif session_ids is None:
# session_ids is None, so we need to generate
# a session id for each input sequence
# TODO: Talk to Dipika whether this aligns with the
Expand Down
5 changes: 4 additions & 1 deletion src/deepsparse/transformers/pipelines/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,14 +416,17 @@ def parse_inputs(self, *args, **kwargs) -> TextGenerationInput:
these kwargs will be used to instantiate one
:return: parsed TextGenerationInput object
"""
if "sequences" in kwargs and "prompt" not in kwargs:
# support prompt and sequences interchangeably
kwargs["prompt"] = kwargs["sequences"]
if (
args
and not isinstance(args[0], TextGenerationInput)
and "prompt" not in kwargs
and "sequences" not in kwargs
):
# assume first argument is "sequences" (prompt) by default
kwargs["sequences"] = args[0]
kwargs["prompt"] = args[0]
args = args[1:]

return super().parse_inputs(*args, **kwargs)
Expand Down
45 changes: 45 additions & 0 deletions tests/deepsparse/transformers/pipelines/test_chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
from deepsparse import Pipeline


@pytest.mark.parametrize(
    "pipeline_kwargs",
    [
        dict(
            model_path="zoo:nlg/text_generation/codegen_mono-350m/pytorch/"
            "huggingface/bigpython_bigquery_thepile/base-none",
            engine_type="onnxruntime",
        ),
    ],
)
@pytest.mark.skip(reason="too heavy for now to run in gha")
def test_chat_pipeline_session_manager(pipeline_kwargs):
    """Inferences run inside one ``session()`` context share a session id;
    an inference made after the context exits gets a different one."""
    pipeline = Pipeline.create(task="chat", **pipeline_kwargs)

    with pipeline.session():
        in_context_first = pipeline(
            prompt="first", generation_config=dict(max_new_tokens=1)
        )
        in_context_second = pipeline(
            prompt="second", generation_config=dict(max_new_tokens=1)
        )
    # both calls inside the context must have been routed to one session
    assert in_context_first.session_ids == in_context_second.session_ids

    # outside the context a fresh session id must be generated
    after_context = pipeline(prompt="third", generation_config=dict(max_new_tokens=1))
    assert after_context.session_ids != in_context_first.session_ids
Loading