From 3fa8cc4049f3c3d35c7a5612567e787dac536a3f Mon Sep 17 00:00:00 2001
From: Benjamin Fineran
Date: Fri, 24 Mar 2023 14:44:46 -0400
Subject: [PATCH] [ONNX] support for setting/getting string onnx tensor shapes (#1478) (#1480)

* [ONNX] support for setting/getting string onnx tensor shapes

* Apply suggestions from code review
---
 src/sparseml/onnx/utils/helpers.py | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/src/sparseml/onnx/utils/helpers.py b/src/sparseml/onnx/utils/helpers.py
index 06a64c063a0..cb3b6e131dd 100644
--- a/src/sparseml/onnx/utils/helpers.py
+++ b/src/sparseml/onnx/utils/helpers.py
@@ -1216,16 +1216,19 @@ def get_tensor_shape(tensor: onnx.TensorProto) -> List[int]:
     return [dim.dim_value for dim in tensor.type.tensor_type.shape.dim]
 
 
-def get_tensor_dim_shape(tensor: onnx.TensorProto, dim: int) -> int:
+def get_tensor_dim_shape(tensor: onnx.TensorProto, dim: Union[int, str]) -> int:
     """
     :param tensor: ONNX tensor to get the shape of a dimension of
     :param dim: dimension index of the tensor to get the shape of
     :return: shape of the tensor at the given dimension
     """
-    return tensor.type.tensor_type.shape.dim[dim].dim_value
+    return (
+        tensor.type.tensor_type.shape.dim[dim].dim_value
+        or tensor.type.tensor_type.shape.dim[dim].dim_param
+    )
 
 
-def set_tensor_dim_shape(tensor: onnx.TensorProto, dim: int, value: int):
+def set_tensor_dim_shape(tensor: onnx.TensorProto, dim: int, value: Union[int, str]):
     """
     Sets the shape of the tensor at the given dimension to the given value
 
@@ -1233,7 +1236,10 @@ def set_tensor_dim_shape(tensor: onnx.TensorProto, dim: int, value: int):
     :param dim: dimension index of the tensor to modify the shape of
     :param value: new shape for the given dimension
     """
-    tensor.type.tensor_type.shape.dim[dim].dim_value = value
+    if isinstance(value, str):
+        tensor.type.tensor_type.shape.dim[dim].dim_param = value
+    else:
+        tensor.type.tensor_type.shape.dim[dim].dim_value = value
 
 
 def override_model_input_shape(model: Union[str, onnx.ModelProto], shape: List[int]):