[Text Generation] Automatically benchmark in auto-regressive setting (#1142)

* initial commit

* improve logging docstring

* more verbose logging

* add sequence_length as variable

* fixed type annotations and avoided overwriting inputs when no sequence_length is passed

* fix bad merge

* tested

* update defaults

* address Luka comments

---------

Co-authored-by: Luka Govedic <luka.govedic@gmail.com>
dbogunowicz and ProExpertProg committed Aug 24, 2023
1 parent 8ccfc6f commit 703b47f
Showing 5 changed files with 132 additions and 17 deletions.
62 changes: 62 additions & 0 deletions src/deepsparse/benchmark/benchmark_model.py
@@ -81,6 +81,14 @@
zoo:nlp/question_answering/bert-base/pytorch/huggingface/squad/base-none \
--input_shapes "[1,512],[1,512],[1,512]"
##########
Example on a CodeGen model (with KV cache support)
from SparseZoo, with input_ids_length 10 and sequence_length 256:
deepsparse.benchmark \
zoo:nlg/text_generation/codegen_mono-350m/pytorch/huggingface/bigpython_bigquery_thepile/pruned50-none \
--input_ids_length 10 --sequence_length 256
##########
Example on local ONNX model:
deepsparse.benchmark /PATH/TO/model.onnx
@@ -110,8 +118,10 @@
from deepsparse.log import set_logging_level
from deepsparse.utils import (
generate_random_inputs,
has_model_kv_cache,
model_to_path,
override_onnx_input_shapes,
overwrite_cache_model_inputs,
parse_input_shapes,
)

Expand Down Expand Up @@ -143,6 +153,26 @@ def parse_args():
default=1,
help="The batch size to run the analysis for. Must be greater than 0",
)

parser.add_argument(
"-seq_len",
"--sequence_length",
type=int,
default=2048,
help="The sequence length to use when benchmarking "
"models with KV cache support. "
"Must be greater than 0; default is 2048",
)

parser.add_argument(
"-input_ids_len",
"--input_ids_length",
type=int,
default=1,
help="The input_ids length to use when benchmarking "
"models with KV cache support. "
"Must be greater than 0; default is 1",
)
parser.add_argument(
"-i",
"-shapes",
@@ -265,6 +295,8 @@ def load_custom_engine(custom_engine_identifier: str):
def benchmark_model(
model_path: str,
batch_size: int = 1,
sequence_length: int = 2048,
input_ids_length: int = 1,
input_shapes: str = "",
num_cores: int = None,
scenario: str = "sync",
@@ -290,6 +322,28 @@ def benchmark_model(

orig_model_path = model_path
model_path = model_to_path(model_path)

if has_model_kv_cache(model_path):
if batch_size != 1:
raise ValueError(
"Unable to run models with KV cache support "
"for batch sizes other than one. "
"Please set batch size to 1 and try again."
)

_LOGGER.info(
"Found model with KV cache support. "
"Benchmarking the autoregressive model with "
f"input_ids_length: {input_ids_length} and "
f"sequence length: {sequence_length}."
)

model_path, _, _ = overwrite_cache_model_inputs(
model_path=model_path,
input_ids_length=input_ids_length,
sequence_length=sequence_length,
)

num_streams = parse_num_streams(num_streams, num_cores, scenario)

# Compile the ONNX into a runnable model
@@ -351,6 +405,8 @@ def benchmark_model(
"orig_model_path": orig_model_path,
"model_path": model_path,
"batch_size": batch_size,
"sequence_length": sequence_length,
"input_ids_length": input_ids_length,
"input_shapes": input_shapes,
"num_cores": num_cores,
"scenario": scenario,
@@ -376,6 +432,8 @@ def main():

result = benchmark_model(
model_path=args.model_path,
sequence_length=args.sequence_length,
input_ids_length=args.input_ids_length,
batch_size=args.batch_size,
input_shapes=args.input_shapes,
num_cores=args.num_cores,
@@ -392,6 +450,10 @@ def main():
# Results summary
print("Original Model Path: {}".format(args.model_path))
print("Batch Size: {}".format(args.batch_size))
if args.sequence_length is not None:
print("Sequence Length: {}".format(args.sequence_length))
if args.input_ids_length is not None:
print("Input IDs Length: {}".format(args.input_ids_length))
print("Scenario: {}".format(args.scenario))
print(
"Throughput (items/sec): {:.4f}".format(
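For orientation, a minimal Python sketch of calling the updated API directly; this is illustrative, not part of the commit. The zoo stub comes from the docstring example above, and it assumes the export dict built in this diff is what benchmark_model returns:

from deepsparse.benchmark.benchmark_model import benchmark_model

# KV cache models must run with batch_size=1, and input_ids_length
# must be smaller than sequence_length (both enforced in this diff).
results = benchmark_model(
    model_path=(
        "zoo:nlg/text_generation/codegen_mono-350m/pytorch/"
        "huggingface/bigpython_bigquery_thepile/pruned50-none"
    ),
    batch_size=1,
    sequence_length=256,
    input_ids_length=10,
)
print(results["sequence_length"], results["input_ids_length"])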
19 changes: 9 additions & 10 deletions src/deepsparse/transformers/engines/nl_decoder_engine.py
@@ -22,17 +22,16 @@
from deepsparse.transformers.utils.decoder_kv_cache import DecoderKVCache
from deepsparse.transformers.utils.helpers import (
generate_session_id,
overwrite_onnx_model_inputs,
overwrite_onnx_model_inputs_for_kv_cache_models,
)
from deepsparse.utils.data import numpy_softmax
from deepsparse.utils.onnx import CACHE_INPUT_PREFIX, CACHE_OUTPUT_PREFIX


_LOGGER = logging.getLogger(__name__)

__all__ = ["NLDecoderEngine"]

_CACHE_INPUT_NAME = "past_key_values"


class NLDecoderEngine:
"""
@@ -69,17 +68,17 @@ def __init__(
):
# flag to indicate if the model is quantized or not
self.kv_cache_data_type = None

(
onnx_file_path,
output_indices_to_be_cached,
kv_cache_data_type,
) = overwrite_onnx_model_inputs(
) = overwrite_onnx_model_inputs_for_kv_cache_models(
onnx_file_path=onnx_file_path,
batch_size=engine_args.get("batch_size", 1),
sequence_length=sequence_length,
input_ids_length=input_ids_length,
)

kv_cache_enabled = False
if sum(output_indices_to_be_cached):
kv_cache_enabled = True
@@ -129,7 +128,7 @@ def onnx_input_names_no_cache(self) -> List[str]:
return [
name
for name in self.engine.input_names
if not name.startswith(_CACHE_INPUT_NAME)
if not name.startswith(CACHE_INPUT_PREFIX)
]

@property
@@ -284,7 +283,7 @@ def update_kv_cache(
cache_onnx_names = [
name
for name in self.engine.input_names
if name.startswith(_CACHE_INPUT_NAME)
if name.startswith(CACHE_INPUT_PREFIX)
]
kv_cache_state = {
name: array for name, array in zip(cache_onnx_names, kv_cache_state)
@@ -302,7 +301,7 @@ def _initialize_kv_cache_state(self, length: int) -> Dict[str, numpy.ndarray]:
cache_engine_input_index = next(
i
for i, name in enumerate(self.engine.input_names)
if _CACHE_INPUT_NAME in name
if CACHE_INPUT_PREFIX in name
)
batch_size, num_attention_heads, _, hidden_dims = self.engine.input_shapes[
cache_engine_input_index
@@ -314,9 +313,9 @@
)

cache_keys = [
output_name.replace("present", _CACHE_INPUT_NAME)
output_name.replace(CACHE_OUTPUT_PREFIX, CACHE_INPUT_PREFIX)
for output_name in self.engine.output_names
if output_name.startswith("present")
if output_name.startswith(CACHE_OUTPUT_PREFIX)
]
return {key: empty_kv_cache_tensor for key in cache_keys}

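For reference, the constants this file now uses, with values inferred from the string literals the diff replaces (the suffixed tensor names are illustrative assumptions, not from the diff):

# deepsparse.utils.onnx, as implied by the replaced literals:
CACHE_INPUT_PREFIX = "past_key_values"  # e.g. past_key_values.0.key (assumed naming)
CACHE_OUTPUT_PREFIX = "present"  # e.g. present.0.key (assumed naming)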
9 changes: 4 additions & 5 deletions src/deepsparse/transformers/utils/helpers.py
@@ -11,10 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import uuid
from typing import List, Tuple, Union
from typing import List, Optional, Tuple, Union

import numpy
import onnx
@@ -24,21 +23,21 @@


__all__ = [
"overwrite_onnx_model_inputs_for_kv_cache_models",
"generate_session_id",
"pad_to_fixed_length",
"create_causal_mask",
"overwrite_onnx_model_inputs",
]

_LOGGER = logging.getLogger(__name__)


def overwrite_onnx_model_inputs(
def overwrite_onnx_model_inputs_for_kv_cache_models(
onnx_file_path: str,
sequence_length: int,
input_ids_length: int,
batch_size: int = 1,
) -> Tuple[str, List[int]]:
) -> Tuple[str, List[int], Optional[int]]:
"""
Enforces the appropriate input shapes for the onnx model, as well as
checks whether kv cache is enabled or not.
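The renamed helper keeps the same call pattern; a hedged sketch mirroring the test usage at the bottom of this diff (the model path is a placeholder):

from deepsparse.transformers.utils.helpers import (
    overwrite_onnx_model_inputs_for_kv_cache_models,
)

# Returns the onnx file path, the list of output indices to cache,
# and the kv cache data type (None if the model has no kv cache).
onnx_path, output_indices_to_be_cached, kv_cache_data_type = (
    overwrite_onnx_model_inputs_for_kv_cache_models(
        onnx_file_path="/PATH/TO/model.onnx",  # placeholder
        sequence_length=128,
        input_ids_length=1,
    )
)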
55 changes: 55 additions & 0 deletions src/deepsparse/utils/onnx.py
@@ -50,6 +50,8 @@
"truncate_onnx_model",
"truncate_onnx_embedding_model",
"default_cached_outputs",
"has_model_kv_cache",
"overwrite_cache_model_inputs",
"CACHE_INPUT_PREFIX",
"CACHE_OUTPUT_PREFIX",
]
@@ -494,3 +496,56 @@ def default_cached_outputs(model_path: str) -> List[bool]:
assert len(output_names) > 0

return [name.startswith(CACHE_OUTPUT_PREFIX) for name in output_names]


def has_model_kv_cache(model: Union[str, ModelProto]) -> bool:
"""
Check whether a model has KV cache support.
:param model: Path to a model or a model proto.
:return: True if the model has KV cache support, False otherwise.
"""
return any(default_cached_outputs(model))


def overwrite_cache_model_inputs(
model_path: str,
input_ids_length: int,
sequence_length: int,
) -> Tuple[str, List[int], Optional[int]]:
"""
Takes a path to an onnx model and enforces that it has
static input dimensions.
:param model_path: Path to a model.
:param input_ids_length: The input_ids length to overwrite the model with.
:param sequence_length: The sequence length to overwrite the model with.
:return: A tuple that contains:
- the path to the onnx model file that has been overwritten
with the new input shapes
- boolean list, where elements are set to True if the
corresponding model output should be cached or False
if not.
- the data type of the kv cache. If the model does not
use kv cache, then the data type is None
"""
from deepsparse.transformers.utils.helpers import (
overwrite_onnx_model_inputs_for_kv_cache_models,
)

assert input_ids_length < sequence_length, (
f"input_ids_length {input_ids_length} "
f"must be less than sequence_length {sequence_length}"
)

(
onnx_file_path,
output_indices_to_be_cached,
kv_cache_data_type,
) = overwrite_onnx_model_inputs_for_kv_cache_models(
onnx_file_path=model_path,
sequence_length=sequence_length,
input_ids_length=input_ids_length,
)

return onnx_file_path, output_indices_to_be_cached, kv_cache_data_type
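
A minimal sketch of the two new utils used together, as benchmark_model does earlier in this diff (the model path is a placeholder):

from deepsparse.utils import has_model_kv_cache, overwrite_cache_model_inputs

model_path = "/PATH/TO/model.onnx"  # placeholder

if has_model_kv_cache(model_path):
    # Swap the model's dynamic cache dimensions for static ones
    # before compiling it for an autoregressive benchmark.
    model_path, cached_outputs, kv_cache_dtype = overwrite_cache_model_inputs(
        model_path=model_path,
        input_ids_length=1,
        sequence_length=2048,
    )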
@@ -22,7 +22,7 @@
from deepsparse import Pipeline
from deepsparse.transformers.utils.helpers import (
create_causal_mask,
overwrite_onnx_model_inputs,
overwrite_onnx_model_inputs_for_kv_cache_models,
)
from deepsparse.utils.onnx import CACHE_INPUT_PREFIX
from sparsezoo import Model
@@ -216,7 +216,7 @@ def _get_cache_state_ort_kv_cache(model_onnx_path, sequence, model_name):

# setup model and session
# (run full sequence inference)
overwrite_onnx_model_inputs(
overwrite_onnx_model_inputs_for_kv_cache_models(
model_onnx_path, sequence_length=128, input_ids_length=128
)
sess = onnxruntime.InferenceSession(model_onnx_path)