Skip to content

Commit

Permalink
[ONNX] support for setting/getting string onnx tensor shapes (#1478) (#…
Browse files Browse the repository at this point in the history
…1480)

* [ONNX] support for setting/getting string onnx tensor shapes

* Apply suggestions from code review
  • Loading branch information
bfineran committed Mar 24, 2023
1 parent faebb8b commit 3fa8cc4
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions src/sparseml/onnx/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1216,24 +1216,30 @@ def get_tensor_shape(tensor: onnx.TensorProto) -> List[int]:
return [dim.dim_value for dim in tensor.type.tensor_type.shape.dim]


def get_tensor_dim_shape(tensor: onnx.TensorProto, dim: int) -> int:
def get_tensor_dim_shape(tensor: onnx.TensorProto, dim: Union[int, str]) -> int:
"""
:param tensor: ONNX tensor to get the shape of a dimension of
:param dim: dimension index of the tensor to get the shape of
:return: shape of the tensor at the given dimension
"""
return tensor.type.tensor_type.shape.dim[dim].dim_value
return (
tensor.type.tensor_type.shape.dim[dim].dim_value
or tensor.type.tensor_type.shape.dim[dim].dim_param
)


def set_tensor_dim_shape(tensor: onnx.TensorProto, dim: int, value: int):
def set_tensor_dim_shape(tensor: onnx.TensorProto, dim: int, value: Union[int, str]):
"""
Sets the shape of the tensor at the given dimension to the given value
:param tensor: ONNX tensor to modify the shape of
:param dim: dimension index of the tensor to modify the shape of
:param value: new shape for the given dimension
"""
tensor.type.tensor_type.shape.dim[dim].dim_value = value
if isinstance(value, str):
tensor.type.tensor_type.shape.dim[dim].dim_param = value
else:
tensor.type.tensor_type.shape.dim[dim].dim_value = value


def override_model_input_shape(model: Union[str, onnx.ModelProto], shape: List[int]):
Expand Down

0 comments on commit 3fa8cc4

Please sign in to comment.