Skip to content

Commit

Permalink
fixed type annotations and avoided overwriting inputs when no sequenc…
Browse files Browse the repository at this point in the history
…e_length is passed
  • Loading branch information
ProExpertProg committed Aug 8, 2023
1 parent 360db72 commit 709853d
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/deepsparse/transformers/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def overwrite_onnx_model_inputs_for_kv_cache_models(
sequence_length: int = 128,
input_ids_length: int = 1,
batch_size: int = 1,
) -> Tuple[str, List[int], Optional[numpy.dtype]]:
) -> Tuple[str, List[bool], Optional[numpy.dtype]]:
"""
Enforces the appropriate input shapes for the onnx model, as well as
checks whether kv cache is enabled or not.
Expand Down
12 changes: 8 additions & 4 deletions src/deepsparse/utils/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ def has_model_kv_cache(model: Union[str, ModelProto]) -> bool:

def overwrite_sequence_length(
model_path: str, sequence_length: Optional[int] = None
) -> str:
) -> Tuple[str, int]:
"""
Takes a path to an onnx model and enforces that it has
static input dimensions.
Expand All @@ -525,9 +525,13 @@ def overwrite_sequence_length(
overwrite_onnx_model_inputs_for_kv_cache_models,
)

onnx_file_path, _, _ = overwrite_onnx_model_inputs_for_kv_cache_models(
onnx_file_path=model_path, sequence_length=sequence_length
)
if sequence_length is not None:
onnx_file_path, _, _ = overwrite_onnx_model_inputs_for_kv_cache_models(
onnx_file_path=model_path, sequence_length=sequence_length
)
else:
onnx_file_path = model_path

attention_input_info = [
input
for input in onnx.load(onnx_file_path, load_external_data=False).graph.input
Expand Down

0 comments on commit 709853d

Please sign in to comment.