Skip to content

Commit

Permalink
address Luka comments
Browse files (browse the repository at this point in the history)
  • Loading branch information
dbogunowicz committed Aug 23, 2023
1 parent b3cf419 commit 8ab1c87
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 12 deletions.
2 changes: 1 addition & 1 deletion src/deepsparse/benchmark/benchmark_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def benchmark_model(
f"sequence length: {sequence_length}."
)

model_path = overwrite_cache_model_inputs(
model_path, _, _ = overwrite_cache_model_inputs(
model_path=model_path,
input_ids_length=input_ids_length,
sequence_length=sequence_length,
Expand Down
6 changes: 3 additions & 3 deletions src/deepsparse/transformers/engines/nl_decoder_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
from deepsparse.engine import Context
from deepsparse.pipeline import DEEPSPARSE_ENGINE, create_engine
from deepsparse.transformers.utils.decoder_kv_cache import DecoderKVCache
from deepsparse.transformers.utils.helpers import generate_session_id
from deepsparse.transformers.utils.helpers import (
overwrite_onnx_model_inputs_for_kv_cache_models as overwrite_onnx_model_inputs,
generate_session_id,
overwrite_onnx_model_inputs_for_kv_cache_models,
)
from deepsparse.utils.data import numpy_softmax
from deepsparse.utils.onnx import CACHE_INPUT_PREFIX, CACHE_OUTPUT_PREFIX
Expand Down Expand Up @@ -72,7 +72,7 @@ def __init__(
onnx_file_path,
output_indices_to_be_cached,
kv_cache_data_type,
) = overwrite_onnx_model_inputs(
) = overwrite_onnx_model_inputs_for_kv_cache_models(
onnx_file_path=onnx_file_path,
batch_size=engine_args.get("batch_size", 1),
sequence_length=sequence_length,
Expand Down
4 changes: 2 additions & 2 deletions src/deepsparse/transformers/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
import logging
import uuid
from typing import List, Tuple, Union
from typing import List, Optional, Tuple, Union

import numpy
import onnx
Expand All @@ -37,7 +37,7 @@ def overwrite_onnx_model_inputs_for_kv_cache_models(
sequence_length: int,
input_ids_length: int,
batch_size: int = 1,
) -> Tuple[str, List[int]]:
) -> Tuple[str, List[int], Optional[int]]:
"""
Enforces the appropriate input shapes for the onnx model, as well as
checks whether kv cache is enabled or not.
Expand Down
19 changes: 15 additions & 4 deletions src/deepsparse/utils/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,15 +512,22 @@ def overwrite_cache_model_inputs(
model_path: str,
input_ids_length: int,
sequence_length: int,
) -> Tuple[str, int]:
) -> Tuple[str, List[int], Optional[int]]:
"""
Takes a path to an onnx model and enforces that it has
static input dimensions.
:param model_path: Path to a model.
:param input_ids_length: The input_ids length to overwrite the model with.
:param sequence_length: The sequence length to overwrite the model with.
:return: Path to the model with static input dimensions.
:return: A tuple that contains:
- the path to the onnx model file that has been overwritten
with the new input shapes
- boolean list, where elements are set to True if the
corresponding model output should be cached or False
if not.
- the data type of the kv cache. If the model does not
use kv cache, then the data type is None
"""
from deepsparse.transformers.utils.helpers import (
overwrite_onnx_model_inputs_for_kv_cache_models,
Expand All @@ -531,10 +538,14 @@ def overwrite_cache_model_inputs(
f"must be less than sequence_length {sequence_length}"
)

onnx_file_path, _, _ = overwrite_onnx_model_inputs_for_kv_cache_models(
(
onnx_file_path,
output_indices_to_be_cached,
kv_cache_data_type,
) = overwrite_onnx_model_inputs_for_kv_cache_models(
onnx_file_path=model_path,
sequence_length=sequence_length,
input_ids_length=input_ids_length,
)

return onnx_file_path
return onnx_file_path, output_indices_to_be_cached, kv_cache_data_type
4 changes: 2 additions & 2 deletions (fifth changed file; its path is missing from this capture)
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from deepsparse import Pipeline
from deepsparse.transformers.utils.helpers import (
create_causal_mask,
overwrite_onnx_model_inputs,
overwrite_onnx_model_inputs_for_kv_cache_models,
)
from deepsparse.utils.onnx import CACHE_INPUT_PREFIX
from sparsezoo import Model
Expand Down Expand Up @@ -216,7 +216,7 @@ def _get_cache_state_ort_kv_cache(model_onnx_path, sequence, model_name):

# setup model and session
# (run full sequence inference)
overwrite_onnx_model_inputs(
overwrite_onnx_model_inputs_for_kv_cache_models(
model_onnx_path, sequence_length=128, input_ids_length=128
)
sess = onnxruntime.InferenceSession(model_onnx_path)
Expand Down

0 comments on commit 8ab1c87

Please sign in to comment.