[Text Generation] Causal Mask Feature Branch #1126

Merged
merged 12 commits on Jul 27, 2023
27 changes: 22 additions & 5 deletions src/deepsparse/transformers/engines/nl_decoder_engine.py
@@ -127,6 +127,14 @@ def onnx_input_names_no_cache(self) -> List[str]:
if not name.startswith(_CACHE_INPUT_NAME)
]

@property
def num_non_blank_cache_entries(self) -> int:
"""
:return: the number of non-blank entries in the
kv cache
"""
return self.kv_cache.num_non_blank_entries

def __call__(
self,
inp: List[numpy.ndarray],
@@ -173,10 +181,16 @@ def transfer_cache_state(self, cache: DecoderKVCache):
information from another NLDecoderEngine. Call this method when
you want to transfer the kv cache state from one engine to another.

This method will also automatically set the kv cache capacity to
the appropriate value for the new engine.

:param cache: The `DecoderKVCache` object to transfer to the engine
from
"""
self.kv_cache = copy.deepcopy(cache)
cache_to_copy = copy.deepcopy(cache)
target_cache_capacity = self.sequence_length - self.input_ids_length
cache_to_copy.set_capacity(target_cache_capacity)
self.kv_cache = cache_to_copy

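# Hedged note on how this is used (the call site lives in text_generation.py,
# outside this hunk): once the prompt has been processed, the single-token
# engine adopts the multi-token engine's cache, roughly
#
#   engine.transfer_cache_state(cache=multitoken_engine.kv_cache)
#
# and the deep copy is padded out to this engine's own
# sequence_length - input_ids_length entries before it is used.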
def overwrite_onnx_model_inputs(
self,
@@ -217,6 +231,11 @@ def overwrite_onnx_model_inputs(
external_input.type.tensor_type.shape.dim[2].dim_value = (
sequence_length - input_ids_length
)
elif external_input.name.startswith("causal_mask"):
external_input.type.tensor_type.shape.dim[
2
].dim_value = input_ids_length
external_input.type.tensor_type.shape.dim[3].dim_value = sequence_length
else:
raise ValueError(
f"Unexpected external input name: {external_input.name}"
@@ -283,10 +302,8 @@ def add_kv_cache_to_input(self, inp: List[numpy.ndarray]) -> List[numpy.ndarray]
self.reset_kv_cache()
kv_cache_state = self.kv_cache.cached_inputs

kv_cache_state["input_ids"] = inp[0]
kv_cache_state["attention_mask"] = inp[1]
if len(inp) == 3:
kv_cache_state["positions"] = inp[2]
for idx, input_name in enumerate(self.onnx_input_names_no_cache):
kv_cache_state[input_name] = inp[idx]

new_inp = [kv_cache_state[name] for name in self.engine.input_names]
return new_inp
139 changes: 110 additions & 29 deletions src/deepsparse/transformers/pipelines/text_generation.py
@@ -16,7 +16,7 @@
import os
import warnings
from dataclasses import dataclass
from typing import List, Optional, Tuple, Type, Union
from typing import Generator, List, Optional, Tuple, Type, Union

import numpy
from pydantic import BaseModel, Field
@@ -26,7 +26,10 @@
from deepsparse.pipeline import DEEPSPARSE_ENGINE
from deepsparse.transformers.engines import NLDecoderEngine
from deepsparse.transformers.pipelines import TransformersPipeline
from deepsparse.transformers.utils.helpers import pad_to_fixed_length
from deepsparse.transformers.utils.helpers import (
create_causal_mask,
pad_to_fixed_length,
)


_LOGGER = logging.getLogger(__name__)
@@ -124,8 +127,7 @@ def __init__(
deterministic: bool = True,
sampling_temperature: float = 1.0,
max_generated_tokens: Optional[int] = 1024,
# TODO: Set this to 64 once we modify the OPT injection logic
prompt_processing_sequence_length: int = 128,
prompt_processing_sequence_length: int = 64,
force_max_tokens: bool = False,
use_deepsparse_cache: bool = True,
**kwargs,
@@ -153,12 +155,6 @@ def __init__(
)

if self.engine_type == DEEPSPARSE_ENGINE:
_LOGGER.warning(
"The support for deepsparse engine is limited "
f"for {self.__class__.__name__}. "
"The multi-token engine will not be "
"used for prompt processing."
)
if "WAND_OPT_FLAGS" not in os.environ:
os.environ["WAND_OPT_FLAGS"] = "default,~pyramids"

@@ -276,13 +272,16 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:

attention_mask = input_tokens["attention_mask"]

# TODO: Positions input is not required by BLOOM
# let's make it optional in the future
positions = attention_mask.cumsum(1) * attention_mask
positions -= 1 # assert that positions start at 0
positions_input = dict(positions=positions)

input_tokens = {**input_tokens, **positions_input}
causal_mask = create_causal_mask(
input_tokens["input_ids"], input_tokens["attention_mask"]
)

input_tokens = dict(
**input_tokens, positions=positions, causal_mask=causal_mask
)
onnx_input_names = self.multitoken_engine.onnx_input_names_no_cache
engine_input = self.tokens_to_engine_input(input_tokens, onnx_input_names)

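# A small standalone check of the positions arithmetic above (made-up mask;
# left padding is an assumption about the tokenizer configuration):
import numpy

attention_mask = numpy.array([[0, 0, 1, 1, 1]])         # two padded positions
positions = attention_mask.cumsum(1) * attention_mask   # [[0, 0, 1, 2, 3]]
positions -= 1                                           # [[-1, -1, 0, 1, 2]]
# the first real token gets position 0; the padded positions end up at -1 but
# are masked out by the attention mask and the causal mask built from it
assert positions.tolist() == [[-1, -1, 0, 1, 2]]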
@@ -391,20 +390,11 @@ def prompt_inference(
# to refrain from resetting if session id is being passed
self._reset_engines_cache()

# TODO: Multiple passes through the multitoken
# engine once the OPT injection is fixed
if (
len(tokens) > self.prompt_processing_sequence_length
and self.engine_type != DEEPSPARSE_ENGINE
):
# trim the input to the prompt size
engine_inputs = [
input[:, : self.prompt_processing_sequence_length]
for input in engine_inputs
]
new_token, new_logits = self.multitoken_engine(engine_inputs)
num_tokens_processed = self.prompt_processing_sequence_length
prompt_logits.append(new_logits)
if len(tokens) > self.prompt_processing_sequence_length:
for engine_inputs in self.engine_inputs_for_prefill(tokens):
new_token, new_logits = self.multitoken_engine(engine_inputs)
num_tokens_processed += self.prompt_processing_sequence_length
prompt_logits.append(new_logits)

if num_tokens_processed:
# transfer the cache state from the multi-token engine to the main engine
@@ -453,12 +443,103 @@ def autoregressive_inference(
if shift_positions_by_one:
positions -= 1
input_ids = numpy.array([[new_token]])
engine_inputs = [input_ids, attention_mask, positions]
causal_mask = create_causal_mask(input_ids, attention_mask)
engine_inputs = [input_ids, attention_mask, positions, causal_mask]

generated_token, generated_logits = self.engine(engine_inputs)

return generated_token, generated_logits

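# create_causal_mask is defined in deepsparse.transformers.utils.helpers and is
# not part of this diff; the sketch below only illustrates the kind of mask it
# is expected to produce, assuming the (1, 1, input_ids_length, sequence_length)
# shape configured in overwrite_onnx_model_inputs and right-aligned new tokens.
import numpy

def sketch_causal_mask(
    input_ids: numpy.ndarray, attention_mask: numpy.ndarray
) -> numpy.ndarray:
    batch_size, input_ids_length = input_ids.shape
    sequence_length = attention_mask.shape[1]
    # each new token may attend to every earlier position; the diagonal offset
    # aligns the new tokens with the right edge of the (cache + new) axis
    causal = numpy.tril(
        numpy.ones(
            (batch_size, 1, input_ids_length, sequence_length), dtype=numpy.int64
        ),
        k=sequence_length - input_ids_length,
    )
    # zero out blank cache entries and padding using the attention mask
    return causal * attention_mask[:, None, None, :]

# single-token decoding: the one row collapses to the attention mask itself
mask = sketch_causal_mask(numpy.array([[42]]), numpy.array([[0, 0, 1, 1, 1, 1]]))
assert mask.shape == (1, 1, 1, 6)
assert mask[0, 0, 0].tolist() == [0, 0, 1, 1, 1, 1]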
def engine_inputs_for_prefill(
self, tokens: List[int]
) -> Generator[List[numpy.ndarray], None, None]:
"""
Takes a list of tokens and creates a generator
of engine_inputs for the multitoken engine.

1. The input tokens first get batched into chunks of
size self.prompt_processing_sequence_length. This is to
ensure that they match the input size expected by the
multitoken engine. Any remaining tokens are discarded.

2. Each engine_inputs batch is then assembled from:

- input_ids: by taking a batch of tokens

- attention_mask: by creating an appropriate mask
whose number of unmasked entries is equal to
the sum of:
a) the number of tokens in the batch
(self.prompt_processing_sequence_length)
b) the number of non-blank cache entries
(num_non_blank_cache_entries)
so that the attention_mask properly attends to the
current input tokens, as well as the previous cache
entries.

- positions: derived directly from the input_ids

- causal_mask: derived from the input_ids and attention_mask

:param tokens: the list of tokens to process
:return: a generator of engine inputs
"""

num_batches = len(tokens) // self.prompt_processing_sequence_length

token_batches = [
tokens[i : i + self.prompt_processing_sequence_length]
for i in range(
0,
num_batches * self.prompt_processing_sequence_length,
self.prompt_processing_sequence_length,
)
]

for idx, token_batch in enumerate(token_batches):
engine_inputs = []

for name in self.multitoken_engine.onnx_input_names_no_cache:
if name == "input_ids":
engine_input = numpy.array([token_batch])

elif name == "attention_mask":
num_cached_entries = (
self.multitoken_engine.num_non_blank_cache_entries
)

# create an empty attention mask
engine_input = numpy.zeros(
(1, self.sequence_length), dtype=numpy.int64
)
# fill it out with 1s (from the right), so that the number
# of unmasked entries is equal to the sum of:
engine_input[
:,
-(
# ...the number of current input tokens...
self.prompt_processing_sequence_length
# ...and the number of the previous cache entries
+ num_cached_entries
) :,
] = 1
elif name == "causal_mask":
# delay creation of the causal mask
continue
elif name == "positions":
engine_input = (
numpy.arange(self.prompt_processing_sequence_length)
.reshape(1, -1)
.astype(numpy.int64)
)

engine_inputs.append(engine_input)

# create the causal mask once we have the input_ids and attention_mask
if "causal_mask" in self.multitoken_engine.onnx_input_names_no_cache:
causal_mask = create_causal_mask(
input_ids=engine_inputs[0], attention_mask=engine_inputs[1]
)
engine_inputs.append(causal_mask)

yield engine_inputs

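# Worked example of the prefill attention mask built above (illustrative sizes,
# not the defaults from this PR): sequence_length=16, a 4-token prompt chunk,
# and 4 non-blank entries already in the cache from the previous chunk.
import numpy

sequence_length = 16
prompt_processing_sequence_length = 4
num_cached_entries = 4

attention_mask = numpy.zeros((1, sequence_length), dtype=numpy.int64)
attention_mask[:, -(prompt_processing_sequence_length + num_cached_entries) :] = 1
# eight trailing ones: the 4 cached tokens plus the 4 tokens of this chunk
assert attention_mask.tolist() == [[0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]]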
@property
def has_cache(self) -> bool:
"""
88 changes: 79 additions & 9 deletions src/deepsparse/transformers/utils/decoder_kv_cache.py
@@ -34,14 +34,15 @@ def __init__(self, use_deepsparse_cache: bool = False):
This object is used to handle the manipulation of the
key/value buffers on the DeepSparse engine side.
"""
self.total_num_processed_tokens = None

# assuming that kv cache arrays are of shape
# [batch_size, num_heads, sequence_length, hidden_size]
self._sequence_len_axis = SEQUENCE_LENGTH_AXIS
self._use_deepsparse_cache = use_deepsparse_cache
self._session_id = None
self._freeze_first_position = None
self._state = None
self._total_num_processed_tokens = None
self._kv_cache = None

def setup_session(
@@ -74,7 +75,7 @@ def setup_session(
self._session_id = session_id
self._state = state
self._freeze_first_position = freeze_first_position
self._total_num_processed_tokens = num_processed_tokens
self.total_num_processed_tokens = num_processed_tokens

if self._use_deepsparse_cache:
raise NotImplementedError("DeepSparse cache is not supported yet.")
@@ -97,20 +98,16 @@ def update_session(
input batch: (batch_size, length).
Corresponds to `input_ids.shape[1]`
"""
self._total_num_processed_tokens += input_ids_len
self.total_num_processed_tokens += input_ids_len
total_cache_capacity = state[list(state.keys())[0]].shape[
self._sequence_len_axis
]
# total_capacity = num_tokens (num of non-blank tokens) +
# + num_padded_entries (num of blank tokens)
num_padded_entries = max(
0, total_cache_capacity - self._total_num_processed_tokens
0, total_cache_capacity - self.total_num_processed_tokens
)
# we want to remove input_ids_len entries from the cache
# because len_input_ids + inp_cache_len = out_cache_len
# TODO: Make it more general once
# multitoken regression is supported
num_entries_to_delete = 1 # input_ids_len
num_entries_to_delete = input_ids_len

if num_padded_entries:
"""
Expand Down Expand Up @@ -175,6 +172,42 @@ def update_session(

self._state = state

def set_capacity(self, capacity: int):
"""
Enforce a new total capacity for the state
of cached inputs.

This means popping old entries if the new
total capacity is smaller than the current one,

or

padding the state with blank entries if the new
total capacity is greater than the current one.

:param capacity: the new length of the
self._state along the
`self._sequence_len_axis` dimension
"""
capacity_difference = self.capacity - capacity
state = self.cached_inputs

if capacity_difference > 0:
raise NotImplementedError(
"The scenario when the capacity "
"needs to be reduced is not yet "
"supported."
)

elif capacity_difference < 0:
indices = [0] * abs(capacity_difference)
state = self._add_entries(state, indices=indices)

else:
pass

self._state = state

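# Worked example of the sign convention above (illustrative numbers): growing a
# 64-entry cache to the 127 entries a single-token engine needs gives a
# negative difference, so blank entries are inserted; a shrink would hit the
# NotImplementedError branch instead.
current_capacity, new_capacity = 64, 127
capacity_difference = current_capacity - new_capacity
assert capacity_difference == -63   # negative -> pad with 63 blank entries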
def _delete_entries(
self, state: Dict[str, Any], indices: List[int]
) -> Dict[str, Any]:
@@ -183,12 +216,49 @@ def _delete_entries(
state[key] = numpy.ascontiguousarray(state[key])
return state

def _add_entries(
self, state: Dict[str, Any], indices: List[int], padding_value: int = 0
) -> Dict[str, Any]:
for key, value in state.items():
# required to make sure that both
# quantized and non quantized caches
# are supported
state_dtype = value.dtype
# change padding_value dtype to match the state dtype
padding_value = numpy.array(padding_value, dtype=state_dtype)

state[key] = numpy.insert(
value, indices, padding_value, axis=self._sequence_len_axis
)
return state

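# Standalone check of the numpy.insert pattern above (toy shapes): inserting
# three zero-valued entries at index 0 along the sequence axis (axis=2) grows a
# [batch, heads, seq, hidden] cache from seq=5 to seq=8, with the padding
# placed at the "oldest" end of the cache.
import numpy

cache = numpy.arange(1 * 2 * 5 * 4, dtype=numpy.float32).reshape(1, 2, 5, 4)
padding_value = numpy.array(0, dtype=cache.dtype)
padded = numpy.insert(cache, [0, 0, 0], padding_value, axis=2)
assert padded.shape == (1, 2, 8, 4)
assert numpy.all(padded[:, :, :3, :] == 0)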
@property
def session_id(self):
if self._session_id is None:
raise ValueError("Attempted to access session_id before setting up session")
return self._session_id

@property
def num_non_blank_entries(self):
"""
:return: the number of non-blank entries in the kv cache
"""
return min(self.capacity, self.total_num_processed_tokens)

@property
def capacity(self) -> int:
"""
Return the maximum number of kv cache entries
that the decoder can hold, until the old entries
start to get erased to make place for new entries

:return: the maximum number of kv cache entries
that the decoder can hold
"""
return self.cached_inputs[list(self.cached_inputs.keys())[0]].shape[
self._sequence_len_axis
]

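# Quick illustration of the two properties above (illustrative numbers): with
# capacity == 127 and total_num_processed_tokens == 40, num_non_blank_entries
# is min(127, 40) == 40; once more than 127 tokens have been processed, the
# value saturates at 127 because older entries get overwritten.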
@session_id.setter
def session_id(self, session_id: str):
self._session_id = session_id