[Text Generation] Automatically benchmark in auto-regressive setting #1142

Merged: 17 commits, Aug 24, 2023
62 changes: 62 additions & 0 deletions src/deepsparse/benchmark/benchmark_model.py
@@ -81,6 +81,14 @@
zoo:nlp/question_answering/bert-base/pytorch/huggingface/squad/base-none \
--input_shapes "[1,512],[1,512],[1,512]"

##########
Example on a CodeGen model (with KV cache support)
from SparseZoo with input_ids_length 10 and sequence length 256:
deepsparse.benchmark \
zoo:nlg/text_generation/codegen_mono-350m/pytorch/
huggingface/bigpython_bigquery_thepile/pruned50-none \
--input_ids_length 10 --sequence_length 256

##########
Example on local ONNX model:
deepsparse.benchmark /PATH/TO/model.onnx
@@ -110,8 +118,10 @@
from deepsparse.log import set_logging_level
from deepsparse.utils import (
generate_random_inputs,
has_model_kv_cache,
model_to_path,
override_onnx_input_shapes,
overwrite_cache_model_inputs,
parse_input_shapes,
)

@@ -143,6 +153,26 @@ def parse_args():
default=1,
help="The batch size to run the analysis for. Must be greater than 0",
)

parser.add_argument(
"-seq_len",
"--sequence_length",
type=int,
default=2048,
help="The sequence length to run the "
"KV cache supported model benchmarks for. "
"Must be greater than 0, default is 2048",
)

parser.add_argument(
"-input_ids_len",
"--input_ids_length",
type=int,
default=1,
help="The input ids length to run the "
"KV cache supported model benchmarks for. "
"Must be greater than 0, default is 1",
)
parser.add_argument(
"-i",
"-shapes",
@@ -265,6 +295,8 @@ def load_custom_engine(custom_engine_identifier: str):
def benchmark_model(
model_path: str,
batch_size: int = 1,
sequence_length: int = 2048,
input_ids_length: int = 1,
input_shapes: str = "",
num_cores: int = None,
scenario: str = "sync",
@@ -290,6 +322,28 @@ def benchmark_model(

orig_model_path = model_path
model_path = model_to_path(model_path)

if has_model_kv_cache(model_path):
if batch_size != 1:
raise ValueError(
"Unable to run models with KV cache support "
"for batch size different than one."
"Please set batch size to 1 and try again"
)

_LOGGER.info(
"Found model with KV cache support. "
"Benchmarking the autoregressive model with "
f"input_ids_length: {input_ids_length} and "
f"sequence length: {sequence_length}."
)

model_path, _, _ = overwrite_cache_model_inputs(
model_path=model_path,
input_ids_length=input_ids_length,
sequence_length=sequence_length,
)

num_streams = parse_num_streams(num_streams, num_cores, scenario)

# Compile the ONNX into a runnable model
@@ -351,6 +405,8 @@ def benchmark_model(
"orig_model_path": orig_model_path,
"model_path": model_path,
"batch_size": batch_size,
"sequence_length": sequence_length,
"input_ids_length": input_ids_length,
"input_shapes": input_shapes,
"num_cores": num_cores,
"scenario": scenario,
@@ -376,6 +432,8 @@ def main():

result = benchmark_model(
model_path=args.model_path,
sequence_length=args.sequence_length,
input_ids_length=args.input_ids_length,
batch_size=args.batch_size,
input_shapes=args.input_shapes,
num_cores=args.num_cores,
@@ -392,6 +450,10 @@ def main():
# Results summary
print("Original Model Path: {}".format(args.model_path))
print("Batch Size: {}".format(args.batch_size))
if args.sequence_length is not None:
print("Sequence Length: {}".format(args.sequence_length))
if args.input_ids_length is not None:
print("Input IDs Length: {}".format(args.input_ids_length))
print("Scenario: {}".format(args.scenario))
print(
"Throughput (items/sec): {:.4f}".format(
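For context, a minimal sketch of how the new arguments might be exercised from Python rather than the CLI, assuming benchmark_model returns the export dict extended in this diff (the zoo stub is the CodeGen example from the docstring above):

from deepsparse.benchmark.benchmark_model import benchmark_model

# Sketch only: benchmark a KV-cache model in the autoregressive setting.
# batch_size must stay at 1, per the new check in benchmark_model.
results = benchmark_model(
    model_path="zoo:nlg/text_generation/codegen_mono-350m/pytorch/"
    "huggingface/bigpython_bigquery_thepile/pruned50-none",
    batch_size=1,
    sequence_length=256,
    input_ids_length=10,
)
print(results["sequence_length"], results["input_ids_length"])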
19 changes: 9 additions & 10 deletions src/deepsparse/transformers/engines/nl_decoder_engine.py
@@ -22,17 +22,16 @@
from deepsparse.transformers.utils.decoder_kv_cache import DecoderKVCache
from deepsparse.transformers.utils.helpers import (
generate_session_id,
- overwrite_onnx_model_inputs,
+ overwrite_onnx_model_inputs_for_kv_cache_models,
)
from deepsparse.utils.data import numpy_softmax
+ from deepsparse.utils.onnx import CACHE_INPUT_PREFIX, CACHE_OUTPUT_PREFIX


_LOGGER = logging.getLogger(__name__)

__all__ = ["NLDecoderEngine"]

- _CACHE_INPUT_NAME = "past_key_values"


class NLDecoderEngine:
"""
@@ -69,17 +68,17 @@ def __init__(
):
# flag to indicate if the model is quantized or not
self.kv_cache_data_type = None

(
onnx_file_path,
output_indices_to_be_cached,
kv_cache_data_type,
- ) = overwrite_onnx_model_inputs(
+ ) = overwrite_onnx_model_inputs_for_kv_cache_models(
onnx_file_path=onnx_file_path,
batch_size=engine_args.get("batch_size", 1),
sequence_length=sequence_length,
input_ids_length=input_ids_length,
)

kv_cache_enabled = False
if sum(output_indices_to_be_cached):
kv_cache_enabled = True
@@ -129,7 +128,7 @@ def onnx_input_names_no_cache(self) -> List[str]:
return [
name
for name in self.engine.input_names
- if not name.startswith(_CACHE_INPUT_NAME)
+ if not name.startswith(CACHE_INPUT_PREFIX)
]

@property
@@ -284,7 +283,7 @@ def update_kv_cache(
cache_onnx_names = [
name
for name in self.engine.input_names
- if name.startswith(_CACHE_INPUT_NAME)
+ if name.startswith(CACHE_INPUT_PREFIX)
]
kv_cache_state = {
name: array for name, array in zip(cache_onnx_names, kv_cache_state)
@@ -302,7 +301,7 @@ def _initialize_kv_cache_state(self, length: int) -> Dict[str, numpy.ndarray]:
cache_engine_input_index = next(
i
for i, name in enumerate(self.engine.input_names)
- if _CACHE_INPUT_NAME in name
+ if CACHE_INPUT_PREFIX in name
)
batch_size, num_attention_heads, _, hidden_dims = self.engine.input_shapes[
cache_engine_input_index
@@ -314,9 +313,9 @@ def _initialize_kv_cache_state(self, length: int) -> Dict[str, numpy.ndarray]:
)

cache_keys = [
output_name.replace("present", _CACHE_INPUT_NAME)
output_name.replace(CACHE_OUTPUT_PREFIX, CACHE_INPUT_PREFIX)
for output_name in self.engine.output_names
if output_name.startswith("present")
if output_name.startswith(CACHE_OUTPUT_PREFIX)
]
return {key: empty_kv_cache_tensor for key in cache_keys}

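As an aside, the prefix constants make the output-to-input renaming in _initialize_kv_cache_state explicit; a self-contained sketch, assuming the prefixes resolve to the "past_key_values" and "present" literals visible in the removed lines:

# Illustrative values, matching the literals replaced in this diff
CACHE_INPUT_PREFIX = "past_key_values"
CACHE_OUTPUT_PREFIX = "present"

# ONNX cache outputs "present.*" map to cache inputs "past_key_values.*"
output_names = ["logits", "present.0.key", "present.0.value"]
cache_keys = [
    name.replace(CACHE_OUTPUT_PREFIX, CACHE_INPUT_PREFIX)
    for name in output_names
    if name.startswith(CACHE_OUTPUT_PREFIX)
]
assert cache_keys == ["past_key_values.0.key", "past_key_values.0.value"]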
9 changes: 4 additions & 5 deletions src/deepsparse/transformers/utils/helpers.py
@@ -11,10 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import uuid
- from typing import List, Tuple, Union
+ from typing import List, Optional, Tuple, Union

import numpy
import onnx
@@ -24,21 +23,21 @@


__all__ = [
"overwrite_onnx_model_inputs_for_kv_cache_models",
"generate_session_id",
"pad_to_fixed_length",
"create_causal_mask",
"overwrite_onnx_model_inputs",
]

_LOGGER = logging.getLogger(__name__)


- def overwrite_onnx_model_inputs(
+ def overwrite_onnx_model_inputs_for_kv_cache_models(
onnx_file_path: str,
sequence_length: int,
input_ids_length: int,
batch_size: int = 1,
- ) -> Tuple[str, List[int]]:
+ ) -> Tuple[str, List[int], Optional[int]]:
"""
Enforces the appropriate input shapes for the onnx model, as well as
checks whether kv cache is enabled or not.
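The renamed helper now advertises a three-element return, with the new Optional element carrying the kv cache data type. A hedged sketch of a direct call, using a placeholder model path:

from deepsparse.transformers.utils.helpers import (
    overwrite_onnx_model_inputs_for_kv_cache_models,
)

# Sketch: "/PATH/TO/model.onnx" stands in for a decoder-only ONNX
# model that exposes KV cache inputs and outputs.
onnx_path, cached_output_flags, kv_cache_dtype = (
    overwrite_onnx_model_inputs_for_kv_cache_models(
        onnx_file_path="/PATH/TO/model.onnx",
        sequence_length=128,
        input_ids_length=1,
    )
)
# kv_cache_dtype is None when the model has no cache outputs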
55 changes: 55 additions & 0 deletions src/deepsparse/utils/onnx.py
@@ -50,6 +50,8 @@
"truncate_onnx_model",
"truncate_onnx_embedding_model",
"default_cached_outputs",
"has_model_kv_cache",
"overwrite_cache_model_inputs",
"CACHE_INPUT_PREFIX",
"CACHE_OUTPUT_PREFIX",
]
@@ -494,3 +496,56 @@ def default_cached_outputs(model_path: str) -> List[bool]:
assert len(output_names) > 0

return [name.startswith(CACHE_OUTPUT_PREFIX) for name in output_names]


def has_model_kv_cache(model: Union[str, ModelProto]) -> bool:
"""
Check whether a model has KV cache support.

:param model: Path to a model or a model proto.
:return: True if the model has KV cache support, False otherwise.
"""
return bool(any(default_cached_outputs(model)))


def overwrite_cache_model_inputs(
model_path: str,
input_ids_length: int,
sequence_length: int,
) -> Tuple[str, List[int], Optional[int]]:
"""
Takes a path to an onnx model and enforces that it has
static input dimensions.

:param model_path: Path to a model.
:param input_ids_length: The input_ids length to overwrite the model with.
:param sequence_length: The sequence length to overwrite the model with.
:return: A tuple that contains:
- the path to the onnx model file that has been overwritten
with the new input shapes
- boolean list, where elements are set to True if the
corresponding model output should be cached or False
if not.
- the data type of the kv cache. If the model does not
use kv cache, then the data type is None
"""
from deepsparse.transformers.utils.helpers import (
overwrite_onnx_model_inputs_for_kv_cache_models,
)

assert input_ids_length < sequence_length, (
f"input_ids_length {input_ids_length} "
f"must be less than sequence_length {sequence_length}"
)

(
onnx_file_path,
output_indices_to_be_cached,
kv_cache_data_type,
) = overwrite_onnx_model_inputs_for_kv_cache_models(
onnx_file_path=model_path,
sequence_length=sequence_length,
input_ids_length=input_ids_length,
)

return onnx_file_path, output_indices_to_be_cached, kv_cache_data_type
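
Taken together, the two new utilities reproduce the flow added to benchmark_model above; a minimal sketch, again with a placeholder model path:

from deepsparse.utils import has_model_kv_cache, overwrite_cache_model_inputs

model_path = "/PATH/TO/model.onnx"  # placeholder
if has_model_kv_cache(model_path):
    # input_ids_length must be strictly less than sequence_length,
    # per the assert in overwrite_cache_model_inputs
    model_path, cached_outputs, kv_cache_dtype = overwrite_cache_model_inputs(
        model_path=model_path,
        input_ids_length=1,
        sequence_length=2048,
    )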
@@ -22,7 +22,7 @@
from deepsparse import Pipeline
from deepsparse.transformers.utils.helpers import (
create_causal_mask,
- overwrite_onnx_model_inputs,
+ overwrite_onnx_model_inputs_for_kv_cache_models,
)
from deepsparse.utils.onnx import CACHE_INPUT_PREFIX
from sparsezoo import Model
@@ -216,7 +216,7 @@ def _get_cache_state_ort_kv_cache(model_onnx_path, sequence, model_name):

# setup model and session
# (run full sequence inference)
- overwrite_onnx_model_inputs(
+ overwrite_onnx_model_inputs_for_kv_cache_models(
model_onnx_path, sequence_length=128, input_ids_length=128
)
sess = onnxruntime.InferenceSession(model_onnx_path)