diff --git a/src/sparseml/onnx/utils/helpers.py b/src/sparseml/onnx/utils/helpers.py index 06a64c063a0..cb3b6e131dd 100644 --- a/src/sparseml/onnx/utils/helpers.py +++ b/src/sparseml/onnx/utils/helpers.py @@ -1216,16 +1216,19 @@ def get_tensor_shape(tensor: onnx.TensorProto) -> List[int]: return [dim.dim_value for dim in tensor.type.tensor_type.shape.dim] -def get_tensor_dim_shape(tensor: onnx.TensorProto, dim: int) -> int: +def get_tensor_dim_shape(tensor: onnx.TensorProto, dim: Union[int, str]) -> int: """ :param tensor: ONNX tensor to get the shape of a dimension of :param dim: dimension index of the tensor to get the shape of :return: shape of the tensor at the given dimension """ - return tensor.type.tensor_type.shape.dim[dim].dim_value + return ( + tensor.type.tensor_type.shape.dim[dim].dim_value + or tensor.type.tensor_type.shape.dim[dim].dim_param + ) -def set_tensor_dim_shape(tensor: onnx.TensorProto, dim: int, value: int): +def set_tensor_dim_shape(tensor: onnx.TensorProto, dim: int, value: Union[int, str]): """ Sets the shape of the tensor at the given dimension to the given value @@ -1233,7 +1236,10 @@ def set_tensor_dim_shape(tensor: onnx.TensorProto, dim: int, value: int): :param dim: dimension index of the tensor to modify the shape of :param value: new shape for the given dimension """ - tensor.type.tensor_type.shape.dim[dim].dim_value = value + if isinstance(value, str): + tensor.type.tensor_type.shape.dim[dim].dim_param = value + else: + tensor.type.tensor_type.shape.dim[dim].dim_value = value def override_model_input_shape(model: Union[str, onnx.ModelProto], shape: List[int]):