[Text Generation][KVCacheStorage] TextGenerationPipeline refactor #1254

Merged
merged 12 commits on Sep 21, 2023
206 changes: 63 additions & 143 deletions src/deepsparse/transformers/engines/nl_decoder_engine.py
@@ -11,21 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import logging
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple

import numpy
from transformers import AutoTokenizer

from deepsparse.engine import Context
from deepsparse.pipeline import DEEPSPARSE_ENGINE, create_engine
from deepsparse.transformers.utils.decoder_kv_cache import DecoderKVCache
from deepsparse.transformers.utils.helpers import generate_session_id
from deepsparse.transformers.utils.timings import TextGenerationTimings
from deepsparse.utils import TimerManager
from deepsparse.utils.onnx import (
CACHE_INPUT_PREFIX,
CACHE_OUTPUT_PREFIX,
overwrite_onnx_model_inputs_for_kv_cache_models,
)

@@ -37,20 +35,16 @@

class NLDecoderEngine:
"""
The NLDecoderEngine (NaturalLanguageDecoderEngine) handles the
The NLDecoderEngine (Natural Language Decoder Engine) handles the
logic around the inference for Natural Language pipeline,
including batching and kv cache logic.
including batching and kv cache manipulation logic.

:param onnx_file_path: The path to the onnx model file
:param engine_type: The type of engine to use for the inference
:param engine_args: The arguments to pass to the engine
:param sequence_length: The maximum sequence length to run the engine for
:param input_ids_length: The maximum input ids length to run the engine for
:param engine_context: The context to run the engine in
:param sampling_temperature: The temperature to use for sampling
:param deterministic: Whether to use deterministic sampling
:param tokenizer: The tokenizer to used for engine inputs
:param engine_context: The context to run the engine in
:param internal_kv_cache: Whether to use the deepsparse
kv cache in the DecoderKVCache object or not
"""
@@ -62,9 +56,6 @@ def __init__(
engine_args: Dict[str, Any],
sequence_length: int,
input_ids_length: int,
tokenizer: AutoTokenizer,
sampling_temperature: float = 1.0,
deterministic: bool = True,
engine_context: Optional[Context] = None,
internal_kv_cache=False,
timer_manager: TimerManager = None,
@@ -98,30 +89,11 @@ def __init__(
)
self.timer_manager = timer_manager or TimerManager()
self.sequence_length = sequence_length
self.sampling_temperature = sampling_temperature
self.deterministic = deterministic
self.input_ids_length = input_ids_length
self.cache_length = sequence_length - input_ids_length
self.kv_cache_enabled = kv_cache_enabled
self.kv_cache = DecoderKVCache(internal_kv_cache) if kv_cache_enabled else None
self._freeze_first_position = self._should_freeze_first_position(tokenizer)
self._session_id = generate_session_id()
self._engine_type = engine_type
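
For a sense of the arithmetic introduced here: `cache_length = sequence_length - input_ids_length` is the number of positions the kv cache must hold. A minimal sketch with example values (the concrete lengths below are illustrative, not taken from this PR):

```python
# Sketch: how cache_length falls out of sequence_length and input_ids_length.
# Example values only; real values come from the pipeline configuration.
sequence_length = 128

# a single-token decode engine processes one new token per call
decode_input_ids_length = 1
decode_cache_length = sequence_length - decode_input_ids_length  # 127 cached positions

# a multitoken (prompt-processing) engine might take e.g. 16 tokens per call
prompt_input_ids_length = 16
prompt_cache_length = sequence_length - prompt_input_ids_length  # 112 cached positions
```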

@property
def session_id(self) -> str:
"""
:return: The session id for the kv_cache if enabled
"""
return self._session_id

@session_id.setter
def session_id(self, session_id: str):
"""
:param session_id: The session id to set for the kv_cache
"""
self._session_id = session_id

@property
def onnx_input_names_no_cache(self) -> List[str]:
"""
@@ -135,25 +107,43 @@ def onnx_input_names_no_cache(self) -> List[str]:
]

@property
def num_non_blank_cache_entries(self) -> int:
def onnx_input_names_cached(self) -> List[str]:
"""
:return: The cached input names for the onnx model
"""
return [
name
for name in self.engine.input_names
if name.startswith(CACHE_INPUT_PREFIX)
]

@property
def cache_shape(self) -> Tuple[int, int, int, int]:
"""
:return A number of non-blank entries in the
kv cache
:return: The shape of the kv cache inputs
for the onnx model. The shape is
(batch_size, num_heads, sequence_length, hidden_size)
"""
return self.kv_cache.num_non_blank_entries
cache_engine_input_index = next(
i
for i, name in enumerate(self.engine.input_names)
if CACHE_INPUT_PREFIX in name
)
return self.engine.input_shapes[cache_engine_input_index]
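
To make the two properties above concrete, here is a self-contained toy of how the cached input names and the cache shape fall out of the engine's input metadata. The names, shapes, and the value of `CACHE_INPUT_PREFIX` are assumptions for illustration (deepsparse conventionally prefixes cache inputs with `past_key_values`):

```python
from typing import List, Tuple

CACHE_INPUT_PREFIX = "past_key_values"  # assumed value of the deepsparse constant

# invented engine metadata: one shape per input name, same order
input_names: List[str] = [
    "input_ids",
    "attention_mask",
    "past_key_values.0.key",
    "past_key_values.0.value",
]
input_shapes: List[Tuple[int, ...]] = [
    (1, 1),
    (1, 128),
    (1, 16, 127, 64),  # (batch_size, num_heads, sequence_length, hidden_size)
    (1, 16, 127, 64),
]

# onnx_input_names_cached: filter by the cache prefix
cached_names = [n for n in input_names if n.startswith(CACHE_INPUT_PREFIX)]

# cache_shape: shape of the first cache input found
cache_index = next(i for i, n in enumerate(input_names) if CACHE_INPUT_PREFIX in n)
print(cached_names)               # ['past_key_values.0.key', 'past_key_values.0.value']
print(input_shapes[cache_index])  # (1, 16, 127, 64)
```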

@property
def internal_cache_active(self) -> bool:
def output_names(self) -> List[str]:
"""
:return: Whether the internal kv cache is active
:return: The output names for the onnx model
"""
return self.kv_cache_enabled and self.kv_cache.engine_internal_cache is not None
return self.engine.output_names

def run(self, inputs: List[numpy.ndarray], val_inp: bool) -> List[numpy.ndarray]:
def run(
self, inputs: List[numpy.ndarray], val_inp: bool, kv_cache: DecoderKVCache
) -> List[numpy.ndarray]:
"""
Run the engine with the given inputs.

If the self.internal_cache_active=True, the internal
If the kv_cache.engine_internal_cache=True, the internal
deepsparse kv cache management is enabled. In this case
the LIB.kv_cache class object will be passed to the engine
call as well. In this scenario also the inputs will not be
@@ -163,25 +153,27 @@ def run(self, inputs: List[numpy.ndarray], val_inp: bool) -> List[numpy.ndarray]

:param inputs: The inputs to run the engine with
:param val_inp: Whether the input is for validation or not
:param kv_cache: The kv cache object to use for the inference

:return: The output of the engine
"""

if self.internal_cache_active:
if bool(kv_cache.engine_internal_cache):
# conventionally, before dispatching
# inputs to the engine, we validate them
# if val_inp=True. However, in this case
# we want to pass the empty kv cache inputs
# (batch_size=0) to the engine. Therefore,
# we skip the validation
return self.engine._eng_net.execute_list_out(
inputs, self.kv_cache.engine_internal_cache
inputs, kv_cache.engine_internal_cache
)
# run the engine without the LIB.kv_cache object
return self.engine.run(inputs, val_inp)
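
The "empty kv cache inputs (batch_size=0)" mentioned in the comment can be made concrete with a tiny numpy example. This is a sketch; the dimensions are illustrative:

```python
import numpy

# With the internal cache the engine tracks kv state itself, so the cache
# inputs handed to it are zero-batch placeholders: right rank and dtype,
# but no data. Such inputs would fail strict validation, which is why
# val_inp is bypassed on this path.
empty_cache_input = numpy.zeros((0, 16, 127, 64), dtype=numpy.float32)
print(empty_cache_input.shape[0])  # 0 -> no cache data is actually copied
```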

def __call__(
self,
inp: List[numpy.ndarray],
kv_cache: Optional[DecoderKVCache] = None,
val_inp: bool = True,
) -> numpy.ndarray:
"""
@@ -190,23 +182,28 @@ def __call__(
:param inp: The input to run the engine with. We expect a
list of numpy arrays that contain the input ids,
attention mask, and position ids (optionally)
:param kv_cache: The DecoderKVCache object that contains
the kv cache state
:param val_inp: Whether the input is for validation or not

:return: The generated token and corresponding logits
"""
timer = self.timer_manager.current
if self.kv_cache:
if self.kv_cache_enabled:
# if model has kv cache enabled, we need
# to add the kv cache state to the input
inp = self.add_kv_cache_to_input(inp)
inp = self.add_kv_cache_to_input(inp, kv_cache)

with timer.time(f"EXECUTE_ENGINE_SEQ_LEN_{self.sequence_length}"):
out = self.run(inp, val_inp)
out = self.run(inp, val_inp, kv_cache)

if self.kv_cache:
if self.kv_cache_enabled:
with timer.time(TextGenerationTimings.KV_CACHE_UPDATE):
logits, *kv_cache_state = out
self.update_kv_cache(
kv_cache_state=kv_cache_state, input_ids_len=self.input_ids_length
kv_cache_state=kv_cache_state,
input_ids_len=self.input_ids_length,
kv_cache=kv_cache,
)
else:
logits = out[0]
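
The unpacking in the kv-cache branch above relies on the engine returning logits first, followed by one array per cache output. A standalone illustration (all arrays and dimensions invented):

```python
import numpy

out = [
    numpy.zeros((1, 1, 32000)),     # logits (batch, tokens, vocab) -- example dims
    numpy.zeros((1, 16, 127, 64)),  # e.g. present.0.key
    numpy.zeros((1, 16, 127, 64)),  # e.g. present.0.value
]
logits, *kv_cache_state = out
print(logits.shape, len(kv_cache_state))  # (1, 1, 32000) 2
```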
@@ -219,36 +216,11 @@ def __str__(self):
def __repr__(self):
return str(self)

def transfer_cache_state(self, cache: DecoderKVCache):
def add_kv_cache_to_input(
self, inp: List[numpy.ndarray], kv_cache: DecoderKVCache
) -> List[numpy.ndarray]:
"""
Transfers the kv cache state and the number of tokens processed
information from another NLDecoderEngine. Call this method when
you want to transfer the kv cache state from one engine to another.

This method will also automatically set the kv cache capacity to
the appropriate value for the new engine.

:param cache: The `DecoderKVCache` object to transfer to the engine
from
"""
cache.set_capacity(self.cache_length)
self.kv_cache = cache

def reset_kv_cache(self):
"""
Resets the kv cache state.
"""
kv_cache_state = self._initialize_kv_cache_state(self.cache_length)
self.kv_cache.setup(
session_id=self._session_id,
state=kv_cache_state,
num_processed_tokens=0,
freeze_first_position=self._freeze_first_position,
)

def add_kv_cache_to_input(self, inp: List[numpy.ndarray]) -> List[numpy.ndarray]:
"""
Takes the input and adds the past kv cache state to it.
Takes the input and adds the kv cache state to it.

If the internal kv cache is enabled, the kv cache state
will always be an empty array. This is just to make sure
@@ -262,17 +234,11 @@ def add_kv_cache_to_input(self, inp: List[numpy.ndarray]) -> List[numpy.ndarray]:


:param inp: The input to the model
:param kv_cache: The kv cache object

:return The input with the kv cache state added to it
"""
if self.internal_cache_active:
kv_cache_state = self._initialize_kv_cache_state(
self.cache_length, empty=True
)
else:
kv_cache_state = self.kv_cache.cached_inputs
if kv_cache_state is None:
self.reset_kv_cache()
kv_cache_state = self.kv_cache.cached_inputs
kv_cache_state = copy.copy(kv_cache.cached_inputs)

for idx, input_name in enumerate(self.onnx_input_names_no_cache):
kv_cache_state[input_name] = inp[idx]
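
A toy rendering of the merge performed above. Note the `copy.copy`: a shallow copy of the cached-inputs dict means new entries are added without mutating the `DecoderKVCache` object's own mapping. Names and values are invented, and the final re-ordering into the engine's input order happens outside the visible hunk, so it is only assumed here:

```python
import copy

cached_inputs = {"past_key_values.0.key": [0.0], "past_key_values.0.value": [0.0]}
onnx_input_names_no_cache = ["input_ids", "attention_mask"]
inp = [[101], [1]]  # stand-ins for numpy arrays

merged = copy.copy(cached_inputs)  # shallow copy; kv cache dict stays untouched
for idx, name in enumerate(onnx_input_names_no_cache):
    merged[name] = inp[idx]

# presumably then flattened back into the engine's declared input order, e.g.:
# engine_inputs = [merged[name] for name in engine.input_names]
```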
@@ -284,75 +250,29 @@ def update_kv_cache(
self,
kv_cache_state: List[numpy.ndarray],
input_ids_len: int,
kv_cache: DecoderKVCache,
):
"""
Updates the state of the kv cache
Updates the kv cache using the new kv cache state.

If the internal kv cache is enabled, we refrain from
updating the kv cache state as it is being tracked internally
inside the engine. We only update the number of tokens processed.

:param kv_cache_state: The state of the kv cache storage
:param kv_cache_state: The new state of the kv cache storage
:param input_ids_len: The length of input_ids
:param kv_cache: The kv cache object to update
"""
if self.internal_cache_active:
self.kv_cache.total_num_processed_tokens += input_ids_len
if bool(kv_cache.engine_internal_cache):
kv_cache.total_num_processed_tokens += input_ids_len
return

cache_onnx_names = [
name
for name in self.engine.input_names
if name.startswith(CACHE_INPUT_PREFIX)
]
kv_cache_state = {
name: array for name, array in zip(cache_onnx_names, kv_cache_state)
name: array
for name, array in zip(self.onnx_input_names_cached, kv_cache_state)
}

self.kv_cache.update(
kv_cache.update(
state=kv_cache_state,
input_ids_len=input_ids_len,
)
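
The dict comprehension above pairs cache input names with the new state arrays positionally, i.e. it assumes the engine emits cache outputs in the same order as it declares cache inputs. A minimal illustration (names and arrays invented):

```python
onnx_input_names_cached = ["past_key_values.0.key", "past_key_values.0.value"]
new_state_arrays = ["key_tensor", "value_tensor"]  # stand-ins for numpy arrays

kv_cache_state = {
    name: array for name, array in zip(onnx_input_names_cached, new_state_arrays)
}
print(kv_cache_state)
# {'past_key_values.0.key': 'key_tensor', 'past_key_values.0.value': 'value_tensor'}
```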

def _initialize_kv_cache_state(
self, length: int, empty: bool = False
) -> Dict[str, numpy.ndarray]:
# initialize empty kv cache of size
# (batch_size, num_attention_heads, length, hidden_dims)
# if empty is True, we initialize empty kv_cache
# and set the batch_size to 0

cache_engine_input_index = next(
i
for i, name in enumerate(self.engine.input_names)
if CACHE_INPUT_PREFIX in name
)
batch_size, num_attention_heads, _, hidden_dims = self.engine.input_shapes[
cache_engine_input_index
]

empty_kv_cache_tensor = numpy.zeros(
(
batch_size if not empty else 0,
num_attention_heads,
length,
hidden_dims,
),
dtype=self.kv_cache_data_type,
)

cache_keys = [
output_name.replace(CACHE_OUTPUT_PREFIX, CACHE_INPUT_PREFIX)
for output_name in self.engine.output_names
if output_name.startswith(CACHE_OUTPUT_PREFIX)
]
return {key: empty_kv_cache_tensor for key in cache_keys}

@staticmethod
def _should_freeze_first_position(tokenizer) -> bool:
# use tokenizer to find out whether we should freeze the first position
# (True if tokenizer has a prefix for a BOS token)
if tokenizer is None:
return False
if hasattr(tokenizer, "add_bos_token"):
return True
return False
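
For readers skimming the deleted `_initialize_kv_cache_state` helper above: it built a dict of zero tensors keyed by cache input names derived from the cache output names, a responsibility that presumably moves to `DecoderKVCache` or the pipeline after this PR. A compact standalone rendering, with assumed prefix constants and invented dimensions:

```python
import numpy

CACHE_INPUT_PREFIX = "past_key_values"  # assumed constant values; check
CACHE_OUTPUT_PREFIX = "present"         # deepsparse.utils.onnx for the real ones

output_names = ["logits", "present.0.key", "present.0.value"]
batch_size, num_heads, length, hidden_dims = 1, 16, 127, 64

empty = numpy.zeros((batch_size, num_heads, length, hidden_dims), dtype=numpy.float32)
cache_keys = [
    name.replace(CACHE_OUTPUT_PREFIX, CACHE_INPUT_PREFIX)
    for name in output_names
    if name.startswith(CACHE_OUTPUT_PREFIX)
]
state = {key: empty for key in cache_keys}
print(list(state))  # ['past_key_values.0.key', 'past_key_values.0.value']
```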
1 change: 1 addition & 0 deletions src/deepsparse/transformers/pipelines/__init__.py
@@ -19,5 +19,6 @@
from .question_answering import *
from .text_classification import *
from .token_classification import *
from .text_generation import *
from .zero_shot_text_classification import *
from .embedding_extraction import *
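
With the text generation pipeline now exported from the package, the usual entry point would be `Pipeline.create`. A sketch, where the task name, input field, and model path are placeholders rather than values from this PR:

```python
from deepsparse import Pipeline

# hypothetical invocation; model_path is a placeholder
pipeline = Pipeline.create(task="text_generation", model_path="path/to/model")
output = pipeline(sequences="Once upon a time")  # assumed input field name
```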