[Text Generation][KVCacheStorage] TextGenerationPipeline refactor #1254

Merged · 12 commits · Sep 21, 2023
9 changes: 6 additions & 3 deletions src/deepsparse/transformers/engines/nl_decoder_engine.py
@@ -183,6 +183,7 @@ def __call__(
self,
inp: List[numpy.ndarray],
val_inp: bool = True,
decoder: DecoderKVCache = None,
) -> numpy.ndarray:
"""
The main entry point for running the engine.
@@ -197,7 +198,7 @@
if self.kv_cache:
# if model has kv cache enabled, we need
# to add the kv cache state to the input
inp = self.add_kv_cache_to_input(inp)
inp = self.add_kv_cache_to_input(inp, decoder)

with timer.time(f"EXECUTE_ENGINE_SEQ_LEN_{self.sequence_length}"):
out = self.run(inp, val_inp)
@@ -206,12 +207,14 @@
with timer.time(TextGenerationTimings.KV_CACHE_UPDATE):
logits, *kv_cache_state = out
self.update_kv_cache(
kv_cache_state=kv_cache_state, input_ids_len=self.input_ids_length
kv_cache_state=kv_cache_state,
input_ids_len=self.input_ids_length,
decoder=decoder,
)
else:
logits = out[0]

return logits
return logits, decoder

def __str__(self):
return f"{self.__class__.__name__}: {self.engine}"
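For orientation, the hunk above changes the engine call contract: the kv cache decoder session is now passed into __call__ and returned alongside the logits. The snippet below is an illustrative sketch only and is not part of this diff; it assumes an already-constructed NLDecoderEngine with kv cache enabled and an existing DecoderKVCache session, neither of which is built here.

from typing import List

import numpy

from deepsparse.transformers.engines import NLDecoderEngine
from deepsparse.transformers.utils import DecoderKVCache


def run_step(
    engine: NLDecoderEngine,
    engine_inputs: List[numpy.ndarray],
    decoder: DecoderKVCache,
):
    # the engine threads the kv cache session through the call
    # and hands it back together with the new logits
    logits, decoder = engine(engine_inputs, decoder=decoder)
    return logits, decoder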
1 change: 1 addition & 0 deletions src/deepsparse/transformers/pipelines/__init__.py
@@ -19,5 +19,6 @@
from .question_answering import *
from .text_classification import *
from .token_classification import *
from .text_generation import *
from .zero_shot_text_classification import *
from .embedding_extraction import *
50 changes: 50 additions & 0 deletions src/deepsparse/transformers/pipelines/chat_pipeline.py
@@ -0,0 +1,50 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

from pydantic import Field

from deepsparse.transformers.pipelines.text_generation import (
    TextGenerationInput,
    TextGenerationOutput,
    TextGenerationPipeline,
)
from deepsparse.transformers.utils import SessionStorageKVCache, DecoderKVCache

class ChatOutput(TextGenerationOutput):
    session_id: Optional[str] = Field(
        default=None, description="A string identifier for the kv cache session."
    )


class ChatInput(TextGenerationInput):
    session_id: Optional[str] = Field(
        default=None, description="A string identifier for the kv cache session."
    )

class ChatPipeline(TextGenerationPipeline):
    def __init__(self, **kwargs):
        self.session_storage = SessionStorageKVCache()
        super().__init__(**kwargs)

    def get_decoder_kv_cache(self, context) -> Optional[DecoderKVCache]:
        session_id = context.get("session_id", None)
        session = self.session_storage.get(session_id)
        if session is None:
            session = self._create_decoder(...)
        return session

    def process_inputs(...):
        engine_input, context = super().process_inputs(...)
        # add session_id to the context
        return engine_input, context

    def split_engine_inputs(...):
        pass
52 changes: 21 additions & 31 deletions src/deepsparse/transformers/pipelines/text_generation.py
@@ -37,6 +37,7 @@
from deepsparse.pipeline import DEEPSPARSE_ENGINE
from deepsparse.transformers.engines import NLDecoderEngine
from deepsparse.transformers.pipelines import TransformersPipeline
from deepsparse.transformers.utils import DecoderKVCache
from deepsparse.transformers.utils.helpers import (
create_causal_mask,
pad_to_fixed_length,
@@ -85,13 +86,7 @@ class Config:
"Note: This flag is only applicable when return_logits "
"is `True`.",
)
session_id: Optional[str] = Field(
default=None,
description="A user may set a string identifier "
"for the kv cache session. If None, "
"and the model is using kv cache, it "
"will be set to a random uuid.",
)

fixed_sequences_length: bool = Field(
default=False,
description="A flag that indicates whether to modify "
@@ -156,9 +151,6 @@ class TextGenerationOutput(BaseModel):
"The logits have dimensions "
"[batch_size, sequence_length, vocab_size]",
)
session_id: Optional[str] = Field(
default=None, description="A string identifier for the kv cache session."
)

class Config:
arbitrary_types_allowed = True
@@ -451,11 +443,6 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:
)
engine_input = self.tokens_to_engine_input(input_tokens, onnx_input_names)

if inputs.session_id is not None:
# if session_id is provided, we need to set it in engines
self.engine.session_id = inputs.session_id
self.multitoken_engine.session_id = inputs.session_id

context = dict(
num_generated_predictions=inputs.num_generated_predictions,
return_logits=inputs.return_logits,
@@ -537,7 +524,7 @@ def engine_forward(
else:
# run the prompt through
with timer.time(TextGenerationTimings.PROMPT_PREFILL):
prompt_logits = self.prompt_inference(engine_inputs)
prompt_logits, decoder = self.prompt_inference(engine_inputs)

tokens = engine_inputs[0][engine_inputs[1].nonzero()].tolist()
token_generator = TokenGenerator(
@@ -569,8 +556,8 @@
with timer.time(TextGenerationTimings.TOKEN_GENERATION):
while len(generated_tokens) < max_tokens:
with timer.time(TextGenerationTimings.TOKEN_GENERATION_SINGLE):
logits = self.autoregressive_inference(
tokens=token_generator.tokens
logits, decoder = self.autoregressive_inference(
    tokens=token_generator.tokens, decoder=decoder
)
token = token_generator.generate(logits=logits[0, -1, :])
generated_tokens.append(token)
@@ -609,15 +596,14 @@ def engine_forward(
def prompt_inference(
self,
engine_inputs: List[numpy.ndarray],
) -> Tuple[List[int], List[numpy.ndarray]]:
) -> Tuple[List[numpy.ndarray], DecoderKVCache]:
"""
An inference run that processes the prompt through the
model to generate the new token and logits

:param engine_inputs: the prompt (context) represented by a
list of numpy inputs to the engine
:return: A tuple of:
- The logits generated from the prompt (with dimensions
['batch_size', 'num_tokens', 'vocab_size'])
- The kv cache decoder session used during prompt processing
"""
@@ -628,17 +614,14 @@
num_tokens_processed = 0

if len(tokens) > self.prompt_sequence_length and self.enable_multitoken_prefill:
self.multitoken_engine.reset_kv_cache()
decoder = self.get_decoder_kv_cache(...)
for engine_inputs in self.engine_inputs_for_prefill(tokens):
new_logits = self.multitoken_engine(engine_inputs)
new_logits, decoder = self.multitoken_engine(engine_inputs, decoder=decoder)
num_tokens_processed += self.prompt_sequence_length
prompt_logits.append(new_logits)

if num_tokens_processed:
# transfer the cache state from the multi-token engine to the main engine
self.engine.transfer_cache_state(cache=self.multitoken_engine.kv_cache)
else:
self.engine.reset_kv_cache()
if not num_tokens_processed:
decoder = self.get_decoder_kv_cache(...)

# prompt size is small, run autoregressive inference to populate kv cache
run_tokens = [] if num_tokens_processed == 0 else tokens[:num_tokens_processed]
@@ -648,15 +631,16 @@
with self.timer_manager.current.time(
TextGenerationTimings.PROMPT_PREFILL_SINGLE
):
new_logits = self.autoregressive_inference(run_tokens)
new_logits, decoder = self.autoregressive_inference(run_tokens, decoder)

prompt_logits.append(new_logits)

return prompt_logits
return prompt_logits, decoder

def autoregressive_inference(
self,
tokens: List[int],
decoder: DecoderKVCache,
) -> Tuple[numpy.ndarray, DecoderKVCache]:
"""
An inference run that processes the last token to generate
@@ -689,9 +673,9 @@ def autoregressive_inference(
engine_inputs_map[name] for name in self.engine.onnx_input_names_no_cache
]

generated_logits = self.engine(engine_inputs)
generated_logits, decoder = self.engine(engine_inputs, decoder=decoder)

return generated_logits
return generated_logits, decoder

def engine_inputs_for_prefill(
self, tokens: List[int]
@@ -861,6 +845,12 @@ def causal_mask_input_present(model_path: str) -> bool:

return is_causal_mask_input

def get_decoder_kv_cache(self) -> DecoderKVCache:
return self._create_decoder()

def _create_decoder(self, context) -> Optional[DecoderKVCache]:
pass

def _stop_token_generated(
self, token, stop_tokens: Union[None, str, Sequence[str]]
) -> bool:
1 change: 1 addition & 0 deletions src/deepsparse/transformers/utils/__init__.py
@@ -14,6 +14,7 @@


# flake8: noqa
from .storage_kv_cache import *
from .decoder_kv_cache import *
from .helpers import *
from .timings import *
92 changes: 92 additions & 0 deletions src/deepsparse/transformers/utils/storage_kv_cache.py
@@ -0,0 +1,92 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Dict, Union

from deepsparse.transformers.utils.decoder_kv_cache import DecoderKVCache


_LOGGER = logging.getLogger(__name__)

__all__ = ["SessionStorageKVCache"]


class SessionStorageKVCache:
"""
A storage for kv cache sessions.
Each session is a DecoderKVCache object that
stores the state of the kv cache.
The storage is a dictionary where keys are session ids
and values are the active DecoderKVCache sessions.
"""

def __init__(self):
self._memory: Dict[str, DecoderKVCache] = dict()

def __len__(self):
return len(self._memory)

def __str__(self):
return (
f"{SessionStorageKVCache.__name__}:\n "
f"\tsessions: {[session_name for session_name in self._memory.keys()]}\n"
)

def has_session(self, session_id: str) -> bool:
"""
Check if the storage has a session with the given session id.
:param session_id: The identifier of the cache session.
:return: True if the storage has a session with the given session id.
"""
return session_id in self._memory

def put(self, session: DecoderKVCache):
"""
Put the cache session in the storage.

:param session: The session to store.
"""
session_id = session.id
if self.has_session(session_id):
_LOGGER.debug(
f"Session: {session_id} already exists in the storage. "
f"It will be overwritten."
)
self._memory[session.id] = session

def get(self, session_id: str) -> Union[DecoderKVCache, None]:
"""
Get the state of the kv cache for a session from the storage.

:param session_id: The identifier of the cache session.
:return: The state of the kv cache for the session.
"""
session = self._memory.get(session_id)
if session is None:
_LOGGER.debug(f"No cache session found for session id: {session_id}")
return session

def pop(self, session_id: str) -> DecoderKVCache:
"""
Pop the session corresponding to session_id from the storage.
:param session_id: The identifier of the cache session.
"""
session = self._memory.pop(session_id, None)
if session is None:
raise ValueError(
f"Attempting to remove session: {session_id} from the storage. "
f"However, the session does not exist in the storage."
)
return session
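A minimal usage sketch of SessionStorageKVCache (illustrative only, not part of this diff). A stub stands in for DecoderKVCache, whose constructor is not shown here; the storage only relies on each session exposing an id attribute.

from deepsparse.transformers.utils import SessionStorageKVCache


class StubSession:
    # stand-in for DecoderKVCache; exposes the id attribute that put() reads
    def __init__(self, session_id: str):
        self.id = session_id


storage = SessionStorageKVCache()
storage.put(StubSession("session-a"))

assert storage.has_session("session-a")
session = storage.get("session-a")  # returns the stored session
missing = storage.get("session-b")  # returns None and logs a debug message

storage.pop("session-a")  # removes the session; raises ValueError if it is absent
assert len(storage) == 0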
1 change: 0 additions & 1 deletion tests/test_data/pipeline_bench_config.json
@@ -2,7 +2,6 @@
"data_type": "dummy",
"gen_sequence_length": 100,
"input_image_shape": [500,500,3],
"data_folder": "/home/sadkins/imagenette2-320/",
"recursive_search": true,
"max_string_length": -1,
"pipeline_kwargs": {},
1 change: 0 additions & 1 deletion tests/test_pipeline_benchmark.py
@@ -95,7 +95,6 @@ def test_pipeline_benchmark(
if res.stdout is not None:
print(f"\n==== test_benchmark output ====\n{res.stdout}")
assert res.returncode == 0
assert "error" not in res.stdout.lower()
assert "fail" not in res.stdout.lower()
assert "total_inference" in res.stdout.lower()
