[Text Generation] Automatically benchmark in auto-regressive setting #1142

Merged 17 commits on Aug 24, 2023
Changes from 2 commits
7 changes: 7 additions & 0 deletions src/deepsparse/benchmark/benchmark_model.py
@@ -104,7 +104,9 @@
from deepsparse.cpu import cpu_architecture
from deepsparse.log import set_logging_level
from deepsparse.utils import (
assert_model_sequence_length_one,
generate_random_inputs,
has_model_kv_cache,
model_to_path,
override_onnx_input_shapes,
parse_input_shapes,
@@ -357,6 +359,11 @@ def benchmark_model(

orig_model_path = model_path
model_path = model_to_path(model_path)

if has_model_kv_cache(model_path):
_LOGGER.info("Found model that contains KV cache support.")
model_path = assert_model_sequence_length_one(model_path)

num_streams = parse_num_streams(num_streams, num_cores, scenario)

# Compile the ONNX into a runnable model
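In effect, the benchmark entrypoint now normalizes KV-cache models before compiling them. A minimal sketch of the new branch, assuming a local decoder export at "model.onnx" (the path is illustrative, not from the PR):

```python
from deepsparse.utils import assert_model_sequence_length_one, has_model_kv_cache

model_path = "model.onnx"  # assumed path to a decoder export with KV cache inputs
if has_model_kv_cache(model_path):
    # Rewrite the static input shapes so the model consumes one new token per
    # forward pass, i.e. runs in the auto-regressive (decode) setting.
    model_path = assert_model_sequence_length_one(model_path)
```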
90 changes: 17 additions & 73 deletions src/deepsparse/transformers/engines/nl_decoder_engine.py
@@ -16,23 +16,23 @@
from typing import Any, Dict, List, Optional, Tuple

import numpy
import onnx
from transformers import AutoTokenizer

from deepsparse.engine import Context
from deepsparse.pipeline import DEEPSPARSE_ENGINE, create_engine
from deepsparse.transformers.utils.decoder_kv_cache import DecoderKVCache
from deepsparse.transformers.utils.helpers import generate_session_id, softmax
from deepsparse.utils.onnx import translate_onnx_type_to_numpy
from sparsezoo.utils.onnx import save_onnx
from deepsparse.transformers.utils.helpers import generate_session_id
from deepsparse.transformers.utils.helpers import (
overwrite_onnx_model_inputs_for_kv_cache_models as overwrite_onnx_model_inputs,
)
from deepsparse.transformers.utils.helpers import softmax
from deepsparse.utils.onnx import CACHE_INPUT_NAME, CACHE_OUTPUT_NAME


_LOGGER = logging.getLogger(__name__)

__all__ = ["NLDecoderEngine"]

_CACHE_INPUT_NAME = "past_key_values"


class NLDecoderEngine:
"""
@@ -70,7 +70,11 @@ def __init__(
# flag to indicate if the model is quantized or not
self.kv_cache_data_type = None

onnx_file_path, output_indices_to_be_cached = self.overwrite_onnx_model_inputs(
(
onnx_file_path,
output_indices_to_be_cached,
kv_cache_data_type,
) = overwrite_onnx_model_inputs(
onnx_file_path=onnx_file_path,
batch_size=engine_args.get("batch_size", 1),
sequence_length=sequence_length,
@@ -79,6 +83,7 @@
kv_cache_enabled = False
if sum(output_indices_to_be_cached):
kv_cache_enabled = True
self.kv_cache_data_type = kv_cache_data_type
if use_deepsparse_cache and engine_type == DEEPSPARSE_ENGINE:
# inform the engine that we are using the kv cache
engine_args["cache_output_bools"] = output_indices_to_be_cached
@@ -123,7 +128,7 @@ def onnx_input_names_no_cache(self) -> List[str]:
return [
name
for name in self.engine.input_names
if not name.startswith(_CACHE_INPUT_NAME)
if not name.startswith(CACHE_INPUT_NAME)
]
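For illustration, the renamed constant drives a simple prefix filter (input names assumed from a typical HF decoder export):

```python
CACHE_INPUT_NAME = "past_key_values"

# Assumed engine input names for a decoder with one cached layer:
input_names = [
    "input_ids", "attention_mask", "positions",
    "past_key_values.0.key", "past_key_values.0.value",
]
no_cache = [n for n in input_names if not n.startswith(CACHE_INPUT_NAME)]
print(no_cache)  # ["input_ids", "attention_mask", "positions"]
```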

def __call__(
@@ -176,67 +181,6 @@ def transfer_cache_state(self, cache: DecoderKVCache):
"""
self.kv_cache = copy.deepcopy(cache)

def overwrite_onnx_model_inputs(
self,
onnx_file_path: str,
sequence_length: int,
input_ids_length: int,
batch_size: int = 1,
) -> Tuple[str, List[int]]:
"""
Enforces the appropriate input shapes for the onnx model, as well as
checks whether kv cache is enabled or not.

:param onnx_file_path: The path to the onnx model file that will be
overwritten with the new input shapes
:param batch_size: The batch size to use for the input
:param sequence_length: The sequence length to use for the input
:param input_ids_length: The length of input_ids
:return: The path to the onnx model file that has been overwritten
with the new input shapes, as well as the indices of the inputs
that should be cached
"""
model = onnx.load(onnx_file_path, load_external_data=False)
initializer_input_names = set(node.name for node in model.graph.initializer)
external_inputs = [
inp for inp in model.graph.input if inp.name not in initializer_input_names
]
for external_input in external_inputs:
# overwrite the batch size for all the inputs
external_input.type.tensor_type.shape.dim[0].dim_value = batch_size

if external_input.name in ["input_ids", "positions"]:
external_input.type.tensor_type.shape.dim[
1
].dim_value = input_ids_length
elif external_input.name == "attention_mask":
external_input.type.tensor_type.shape.dim[1].dim_value = sequence_length
elif external_input.name.startswith(_CACHE_INPUT_NAME):
external_input.type.tensor_type.shape.dim[2].dim_value = (
sequence_length - input_ids_length
)
else:
raise ValueError(
f"Unexpected external input name: {external_input.name}"
)

_LOGGER.info(
"Overwriting in-place the input shapes "
f"of the transformer model at {onnx_file_path}"
)
save_onnx(model, onnx_file_path)

output_indices_to_be_cached = [
1 if inp.name.startswith("present") else 0 for inp in model.graph.output
]

kv_cache_elem_type = next(
inp for inp in model.graph.input if inp.name.startswith(_CACHE_INPUT_NAME)
).type.tensor_type.elem_type
self.kv_cache_data_type = translate_onnx_type_to_numpy(kv_cache_elem_type)

return onnx_file_path, output_indices_to_be_cached

def generate_token(self, logits: numpy.ndarray) -> numpy.ndarray:
"""
Samples a token from the logits using the sampling temperature.
@@ -301,7 +245,7 @@ def update_kv_cache(
cache_onnx_names = [
name
for name in self.engine.input_names
if name.startswith(_CACHE_INPUT_NAME)
if name.startswith(CACHE_INPUT_NAME)
]
kv_cache_state = {
name: array for name, array in zip(cache_onnx_names, kv_cache_state)
@@ -319,7 +263,7 @@ def _initialize_kv_cache_state(self, length: int) -> Dict[str, numpy.ndarray]:
cache_engine_input_index = next(
i
for i, name in enumerate(self.engine.input_names)
if _CACHE_INPUT_NAME in name
if CACHE_INPUT_NAME in name
)
batch_size, num_attention_heads, _, hidden_dims = self.engine.input_shapes[
cache_engine_input_index
@@ -331,9 +275,9 @@
)

cache_keys = [
output_name.replace("present", _CACHE_INPUT_NAME)
output_name.replace(CACHE_OUTPUT_NAME, CACHE_INPUT_NAME)
for output_name in self.engine.output_names
if output_name.startswith("present")
if output_name.startswith(CACHE_OUTPUT_NAME)
]
return {key: empty_kv_cache_tensor for key in cache_keys}

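For reference, the renaming step above pairs each cached output with the input that receives it on the next decode step; a small sketch with names assumed from the usual HF export convention:

```python
CACHE_INPUT_NAME = "past_key_values"
CACHE_OUTPUT_NAME = "present"

# Assumed cached output name from a decoder export:
output_name = "present.0.key"
cache_key = output_name.replace(CACHE_OUTPUT_NAME, CACHE_INPUT_NAME)
print(cache_key)  # "past_key_values.0.key" -- fed back in on the next step
```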
100 changes: 89 additions & 11 deletions src/deepsparse/transformers/utils/helpers.py
@@ -11,13 +11,100 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import uuid
from typing import List, Optional, Tuple

import numpy
import onnx

from deepsparse.utils.onnx import (
CACHE_INPUT_NAME,
default_cached_outputs,
translate_onnx_type_to_numpy,
)
from sparsezoo.utils import save_onnx


__all__ = [
"overwrite_onnx_model_inputs_for_kv_cache_models",
"generate_session_id",
"pad_to_fixed_length",
"softmax",
]

_LOGGER = logging.getLogger(__name__)


def overwrite_onnx_model_inputs_for_kv_cache_models(
onnx_file_path: str,
sequence_length: int = 128,
input_ids_length: int = 1,
batch_size: int = 1,
) -> Tuple[str, List[int], Optional[numpy.dtype]]:
"""
Enforces the appropriate input shapes for the onnx model, and checks
whether the kv cache is enabled.

:param onnx_file_path: The path to the onnx model file that will be
overwritten with the new input shapes
:param batch_size: The batch size to use for the input
:param sequence_length: The sequence length to use for the input
:param input_ids_length: The length of input_ids
:return: A tuple that contains:
- the path to the onnx model file that has been overwritten
with the new input shapes
- boolean list, where elements are set to True if the
corresponding model output should be cached or False
if not.
- the data type of the kv cache. If the model does not
use kv cache, then the data type is None
"""
model = onnx.load(onnx_file_path, load_external_data=False)
initializer_input_names = set(node.name for node in model.graph.initializer)
external_inputs = [
inp for inp in model.graph.input if inp.name not in initializer_input_names
]
for external_input in external_inputs:
# overwrite the batch size for all the inputs
external_input.type.tensor_type.shape.dim[0].dim_value = batch_size

if external_input.name in ["input_ids", "positions"]:
external_input.type.tensor_type.shape.dim[1].dim_value = input_ids_length
elif external_input.name == "attention_mask":
external_input.type.tensor_type.shape.dim[1].dim_value = sequence_length
elif external_input.name.startswith(CACHE_INPUT_NAME):
external_input.type.tensor_type.shape.dim[2].dim_value = (
sequence_length - input_ids_length
)
else:
raise ValueError(f"Unexpected external input name: {external_input.name}")

_LOGGER.info(
"Overwriting in-place the input shapes "
f"of the transformer model at {onnx_file_path}"
)
save_onnx(model, onnx_file_path)

output_indices_to_be_cached = default_cached_outputs(model)

__all__ = ["softmax", "generate_session_id", "pad_to_fixed_length"]
kv_cache_data_type = None
if sum(output_indices_to_be_cached):
kv_cache_elem_type = next(
inp for inp in model.graph.input if inp.name.startswith(CACHE_INPUT_NAME)
).type.tensor_type.elem_type
kv_cache_data_type = translate_onnx_type_to_numpy(kv_cache_elem_type)

return onnx_file_path, output_indices_to_be_cached, kv_cache_data_type
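A hedged usage sketch of the relocated helper; the path and dimensions below are assumed for illustration:

```python
# "model.onnx" is an assumed local decoder export with KV cache inputs.
path, cached, kv_dtype = overwrite_onnx_model_inputs_for_kv_cache_models(
    "model.onnx", sequence_length=128, input_ids_length=1, batch_size=1
)
# Expected static input shapes after the in-place rewrite:
#   input_ids / positions : [1, 1]
#   attention_mask        : [1, 128]
#   past_key_values.*     : [1, num_heads, 127, hidden]  (seq_len - input_ids_len)
# `cached` flags the "present*" outputs; `kv_dtype` is e.g. numpy.float32,
# or None when the model exposes no KV cache inputs.
```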


def generate_session_id() -> str:
"""
Generate uuid for session id. This is used to
identify the kv cache session for the user
"""
session_id = str(uuid.uuid4())
return session_id


def softmax(x: numpy.ndarray) -> numpy.ndarray:
@@ -36,15 +123,6 @@ def softmax(x: numpy.ndarray) -> numpy.ndarray:
return numerator / denominator


def generate_session_id() -> str:
"""
Generate uuid for session id. This is used to
identify the kv cache session for the user
"""
session_id = str(uuid.uuid4())
return session_id


def pad_to_fixed_length(
array: numpy.ndarray, max_len: int, axis: int = 0, value: int = 0
) -> numpy.ndarray:
52 changes: 42 additions & 10 deletions src/deepsparse/utils/onnx.py
@@ -50,10 +50,17 @@
"truncate_onnx_model",
"truncate_onnx_embedding_model",
"default_cached_outputs",
"has_model_kv_cache",
"CACHE_INPUT_NAME",
"CACHE_OUTPUT_NAME",
"assert_model_sequence_length_one",
]

_LOGGER = logging.getLogger(__name__)

CACHE_INPUT_NAME = "past_key_values"
CACHE_OUTPUT_NAME = "present"


@contextlib.contextmanager
def save_onnx_to_temp_files(model: onnx.ModelProto, with_external_data=False) -> str:
@@ -475,20 +482,45 @@ def truncate_onnx_embedding_model(
return output_filepath, tmp_file


def default_cached_outputs(model_path: str) -> List[bool]:
def default_cached_outputs(model: Union[str, ModelProto]) -> List[bool]:
"""
Get a list of bools that indicate which outputs should be cached.
The elements that are set to True correspond to cached outputs,
the rest are set to False.

:param model: Path to a model or a model proto
:return A list of bools that indicates caching of all outputs except the first one.
:return A list of bools that indicate which outputs should be cached.
"""

outputs = list(onnx.load(model_path).graph.output)
model = (
onnx.load(model, load_external_data=False) if isinstance(model, str) else model
)
outputs = model.graph.output
assert len(outputs) > 0

# Create a boolean list of every output of the
# model [logits, key0, value0, key1, value1, ..., keyN, valueN]
cached_outputs = [True for i in range(len(outputs))]
# Assume first input is logits and logits ought not to be cached
cached_outputs[0] = False
return cached_outputs
return [output.name.startswith(CACHE_OUTPUT_NAME) for output in outputs]


def has_model_kv_cache(model: Union[str, ModelProto]) -> bool:
"""
Check whether a model has KV cache support.

:param model: Path to a model or a model proto.
:return: True if the model has KV cache support, False otherwise.
"""
return bool(sum(default_cached_outputs(model)))
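Taken together, a hedged example of what the two helpers compute for a typical decoder export (output names assumed):

```python
# Assumed graph output names for a one-layer decoder export:
output_names = ["logits", "present.0.key", "present.0.value"]

# default_cached_outputs reduces to a prefix check per output:
cached = [name.startswith("present") for name in output_names]
print(cached)  # [False, True, True]

# has_model_kv_cache is then just "any output is cached":
print(bool(sum(cached)))  # True
```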


def assert_model_sequence_length_one(model_path: str) -> str:
"""
Takes a path to an onnx model and enforces static input
dimensions, such that the model processes a single new token per forward pass.

:param model_path: Path to a model.
:return: Path to the model with static input dimensions.
"""
from deepsparse.transformers.utils.helpers import (
overwrite_onnx_model_inputs_for_kv_cache_models,
)

onnx_file_path, _, _ = overwrite_onnx_model_inputs_for_kv_cache_models(model_path)
return onnx_file_path
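Usage sketch, with an assumed path: the helper simply reapplies the shape-rewriting helper above with its one-token defaults:

```python
# "model.onnx" is an assumed local path to a KV-cache decoder export.
static_path = assert_model_sequence_length_one("model.onnx")
# The rewritten model accepts exactly one new token per forward pass,
# which is what the benchmark's auto-regressive scenario requires.
```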