[Text Generation] Causal Mask Feature Branch #1126

Merged · 12 commits · Jul 27, 2023
27 changes: 22 additions & 5 deletions src/deepsparse/transformers/engines/nl_decoder_engine.py
@@ -127,6 +127,14 @@ def onnx_input_names_no_cache(self) -> List[str]:
if not name.startswith(_CACHE_INPUT_NAME)
]

@property
def num_non_blank_cache_entries(self) -> int:
"""
:return: the number of non-blank entries in the
kv cache
"""
return self.kv_cache.num_non_blank_entries

def __call__(
self,
inp: List[numpy.ndarray],
@@ -173,10 +181,16 @@ def transfer_cache_state(self, cache: DecoderKVCache):
information from another NLDecoderEngine. Call this method when
you want to transfer the kv cache state from one engine to another.

This method will also automatically set the kv cache capacity to
the appropriate value for the new engine.

:param cache: The `DecoderKVCache` object to transfer to the engine
from
"""
self.kv_cache = copy.deepcopy(cache)
cache_to_copy = copy.deepcopy(cache)
target_cache_capacity = self.sequence_length - self.input_ids_length
cache_to_copy.set_capacity(target_cache_capacity)
self.kv_cache = cache_to_copy
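
For orientation, a quick illustration of the capacity arithmetic above (the numbers are illustrative, not taken from this PR): the transferred cache is resized so that exactly input_ids_length slots stay free for the tokens of the receiving engine's next call.

```python
# illustrative values only
sequence_length = 128   # total positions the receiving engine is compiled for
input_ids_length = 1    # a single-token engine consumes one token per call
target_cache_capacity = sequence_length - input_ids_length  # 127 cache slots
```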

def overwrite_onnx_model_inputs(
self,
@@ -217,6 +231,11 @@ def overwrite_onnx_model_inputs(
external_input.type.tensor_type.shape.dim[2].dim_value = (
sequence_length - input_ids_length
)
elif external_input.name.startswith("causal_mask"):
external_input.type.tensor_type.shape.dim[
2
].dim_value = input_ids_length
external_input.type.tensor_type.shape.dim[3].dim_value = sequence_length
else:
raise ValueError(
f"Unexpected external input name: {external_input.name}"
@@ -283,10 +302,8 @@ def add_kv_cache_to_input(self, inp: List[numpy.ndarray]) -> List[numpy.ndarray]
self.reset_kv_cache()
kv_cache_state = self.kv_cache.cached_inputs

kv_cache_state["input_ids"] = inp[0]
kv_cache_state["attention_mask"] = inp[1]
if len(inp) == 3:
kv_cache_state["positions"] = inp[2]
for idx, input_name in enumerate(self.onnx_input_names_no_cache):
kv_cache_state[input_name] = inp[idx]

new_inp = [kv_cache_state[name] for name in self.engine.input_names]
return new_inp
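
For context, a runnable sketch of what this name-driven mapping achieves; the input ordering below is an assumption based on the rest of this diff, not something the engine guarantees.

```python
import numpy

# assumed ordering of the non-cache inputs (hypothetical example)
onnx_input_names_no_cache = ["input_ids", "attention_mask", "positions", "causal_mask"]
inp = [numpy.zeros((1, 4), dtype=numpy.int64) for _ in onnx_input_names_no_cache]

kv_cache_state = {}  # in the real engine this already holds the cached key/value tensors
for idx, input_name in enumerate(onnx_input_names_no_cache):
    kv_cache_state[input_name] = inp[idx]

print(list(kv_cache_state))
# ['input_ids', 'attention_mask', 'positions', 'causal_mask']
```

The loop removes the earlier assumption of a fixed two- or three-tensor input layout, so adding causal_mask (or dropping positions) does not require touching this method again.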
144 changes: 115 additions & 29 deletions src/deepsparse/transformers/pipelines/text_generation.py
@@ -16,7 +16,7 @@
import os
import warnings
from dataclasses import dataclass
from typing import List, Optional, Tuple, Type, Union
from typing import Generator, List, Optional, Tuple, Type, Union

import numpy
from pydantic import BaseModel, Field
@@ -26,7 +26,10 @@
from deepsparse.pipeline import DEEPSPARSE_ENGINE
from deepsparse.transformers.engines import NLDecoderEngine
from deepsparse.transformers.pipelines import TransformersPipeline
from deepsparse.transformers.utils.helpers import pad_to_fixed_length
from deepsparse.transformers.utils.helpers import (
create_causal_mask,
pad_to_fixed_length,
)


_LOGGER = logging.getLogger(__name__)
@@ -124,8 +127,7 @@ def __init__(
deterministic: bool = True,
sampling_temperature: float = 1.0,
max_generated_tokens: Optional[int] = 1024,
# TODO: Set this to 64 once we modify the OPT injection logic
prompt_processing_sequence_length: int = 128,
prompt_processing_sequence_length: int = 64,
force_max_tokens: bool = False,
use_deepsparse_cache: bool = True,
**kwargs,
@@ -153,12 +155,6 @@
)

if self.engine_type == DEEPSPARSE_ENGINE:
_LOGGER.warning(
"The support for deepsparse engine is limited "
f"for {self.__class__.__name__}. "
"The multi-token engine will not be "
"used for prompt processing."
)
if "WAND_OPT_FLAGS" not in os.environ:
os.environ["WAND_OPT_FLAGS"] = "default,~pyramids"

@@ -276,13 +272,16 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:

attention_mask = input_tokens["attention_mask"]

# TODO: Positions input is not required by BLOOM
# let's make it optional in the future
positions = attention_mask.cumsum(1) * attention_mask
positions -= 1  # ensure that positions start at 0
positions_input = dict(positions=positions)

input_tokens = {**input_tokens, **positions_input}
causal_mask = create_causal_mask(
input_tokens["input_ids"], input_tokens["attention_mask"]
)

input_tokens = dict(
**input_tokens, positions=positions, causal_mask=causal_mask
)
onnx_input_names = self.multitoken_engine.onnx_input_names_no_cache
engine_input = self.tokens_to_engine_input(input_tokens, onnx_input_names)
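
The create_causal_mask helper imported above is not shown in this diff. A minimal sketch of what such a helper could compute, assuming a 4-D mask of shape (1, 1, input_ids_length, sequence_length) as implied by the ONNX reshaping in overwrite_onnx_model_inputs (an illustration, not the deepsparse implementation):

```python
import numpy

def create_causal_mask_sketch(input_ids, attention_mask):
    # hypothetical stand-in for deepsparse.transformers.utils.helpers.create_causal_mask
    input_len = input_ids.shape[1]
    seq_len = attention_mask.shape[1]
    # row i: new token i may attend to every earlier position in the window
    # plus the new tokens 0..i
    causal = numpy.tril(
        numpy.ones((input_len, seq_len), dtype=numpy.int64),
        k=seq_len - input_len,
    )
    # zero out the columns that the attention_mask marks as blank cache slots
    causal *= attention_mask[0]
    return causal.reshape(1, 1, input_len, seq_len)
```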

@@ -391,20 +390,11 @@ def prompt_inference(
# to refrain from resetting if session id is being passed
self._reset_engines_cache()

# TODO: Multiple passes through the multitoken
# engine once the OPT injection is fixed
if (
len(tokens) > self.prompt_processing_sequence_length
and self.engine_type != DEEPSPARSE_ENGINE
):
# trim the input to the prompt size
engine_inputs = [
input[:, : self.prompt_processing_sequence_length]
for input in engine_inputs
]
new_token, new_logits = self.multitoken_engine(engine_inputs)
num_tokens_processed = self.prompt_processing_sequence_length
prompt_logits.append(new_logits)
if len(tokens) > self.prompt_processing_sequence_length:
for engine_inputs in self.engine_inputs_for_prefill(tokens):
new_token, new_logits = self.multitoken_engine(engine_inputs)
num_tokens_processed += self.prompt_processing_sequence_length
prompt_logits.append(new_logits)

if num_tokens_processed:
# transfer the cache state from the multi-token engine to the main engine
@@ -453,12 +443,108 @@ def autoregressive_inference(
if shift_positions_by_one:
positions -= 1
input_ids = numpy.array([[new_token]])
engine_inputs = [input_ids, attention_mask, positions]
causal_mask = create_causal_mask(input_ids, attention_mask)
engine_inputs = [input_ids, attention_mask, positions, causal_mask]

generated_token, generated_logits = self.engine(engine_inputs)

return generated_token, generated_logits

def engine_inputs_for_prefill(
self, tokens: List[int]
) -> Generator[List[numpy.ndarray], None, None]:
"""
Takes a list of tokens and creates a generator
of engine_inputs for the multitoken engine.

1. The input tokens first get batched into chunks of
size self.prompt_processing_sequence_length. This is to
ensure that they match the input size expected by the
multitoken engine. Any remaining tokens are discarded.

2. Each engine_inputs batch is then created:

- input_ids: by taking a batch of tokens

- attention_mask: by creating an appropriate mask,
whose number of unmasked entries is equal to
the sum of:
a) the number of tokens in the batch
(self.prompt_processing_sequence_length)
b) the number of non-blank cache entries
(num_non_blank_cache_entries)
so that the attention_mask properly attends to the
current input tokens, as well as the previous cache
entries.

- positions: derived directly from the input_ids

- causal_mask: derived from the input_ids and attention_mask

:param tokens: the list of tokens to process
:return: a generator of engine inputs
"""

num_batches = len(tokens) // self.prompt_processing_sequence_length

token_batches = [
tokens[i : i + self.prompt_processing_sequence_length]
for i in range(0, num_batches * self.prompt_processing_sequence_length, self.prompt_processing_sequence_length)
]

for idx, token_batch in enumerate(token_batches):
engine_inputs = []

for name in self.multitoken_engine.onnx_input_names_no_cache:
if name == "input_ids":
engine_input = numpy.array([token_batch])

elif name == "attention_mask":
num_cached_entries = (
self.multitoken_engine.num_non_blank_cache_entries
)

# create an empty attention mask
engine_input = numpy.zeros(
(1, self.sequence_length), dtype=numpy.int64
)
# fill it out with 1s (from the right), so that the number
# of unmasked entries is equal to the sum of:
engine_input[
:,
-(
# ...the number of current input tokens...
self.prompt_processing_sequence_length
# ...and the number of the previous cache entries
+ num_cached_entries
) :,
] = 1
elif name == "causal_mask":
# delay creation of the causal mask
continue
elif name == "positions":
if self.prompt_processing_sequence_length == 1:
# we need to treat `positions` as if we were in
# the autoregressive mode
engine_input = numpy.array([[idx]], dtype=numpy.int64)
else:
engine_input = (
numpy.arange(self.prompt_processing_sequence_length)
.reshape(1, -1)
.astype(numpy.int64)
)

engine_inputs.append(engine_input)

# create the causal mask once we have the input_ids and attention_mask
if "causal_mask" in self.multitoken_engine.onnx_input_names_no_cache:
causal_mask = create_causal_mask(
input_ids=engine_inputs[0], attention_mask=engine_inputs[1]
)
engine_inputs.append(causal_mask)

yield engine_inputs
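
As a worked example of the attention-mask layout described in the docstring above (all numbers are illustrative): with a sequence length of 8, a prompt chunk of 4 tokens, and 2 non-blank cache entries, the last 4 + 2 = 6 positions are unmasked.

```python
import numpy

sequence_length = 8                    # illustrative
prompt_processing_sequence_length = 4  # illustrative chunk size
num_cached_entries = 2                 # illustrative non-blank cache entries

attention_mask = numpy.zeros((1, sequence_length), dtype=numpy.int64)
attention_mask[:, -(prompt_processing_sequence_length + num_cached_entries):] = 1
print(attention_mask)  # [[0 0 1 1 1 1 1 1]]
```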

@property
def has_cache(self) -> bool:
"""