Skip to content

Commit

Permalink
[cherry-pick-1.4.4][ONNX] patch previous commit to accept path-like o…
Browse files Browse the repository at this point in the history
…bjects (#1475) (#1476)

* [ONNX] override_model_input_shape helper function (#1471)

* [ONNX] patch previous commit to accept path-like objects (#1475)

* bump version to 1.4.4
  • Loading branch information
bfineran committed Mar 23, 2023
1 parent cfbfedf commit faebb8b
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 1 deletion.
23 changes: 23 additions & 0 deletions src/sparseml/onnx/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
"get_tensor_shape",
"get_tensor_dim_shape",
"set_tensor_dim_shape",
"override_model_input_shape",
]


Expand Down Expand Up @@ -1233,3 +1234,25 @@ def set_tensor_dim_shape(tensor: onnx.TensorProto, dim: int, value: int):
:param value: new shape for the given dimension
"""
tensor.type.tensor_type.shape.dim[dim].dim_value = value


def override_model_input_shape(model: Union[str, onnx.ModelProto], shape: List[int]):
"""
Set the shape of the first input of the given model to the given shape.
If given a file, the file will be overwritten
:param model: ONNX model or model path to overrwrite
:param shape: shape as list of integers to override with. must match
existing dimensions
"""
if not isinstance(model, onnx.ModelProto):
model_path = model
model = onnx.load(model)
else:
model_path = None

for dim, dim_size in enumerate(shape):
set_tensor_dim_shape(model.graph.input[0], dim, dim_size)

if model_path:
onnx.save(model, model_path)
2 changes: 1 addition & 1 deletion src/sparseml/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from datetime import date


version_base = "1.4.3"
version_base = "1.4.4"
is_release = False # change to True to set the generated version as a release version


Expand Down

0 comments on commit faebb8b

Please sign in to comment.