diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index 37d7ee4b35a73..4e3510472cd96 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -813,64 +813,6 @@ def _decide_input_format(model, args): return args -def _from_dynamic_axes_to_dynamic_shapes( - model, - dynamic_axes: Mapping[str, Mapping[int, str]] - | Mapping[str, Sequence[int]] - | None = None, - input_names: Sequence[str] | None = None, -) -> dict[str, Any] | None: - """ - - dynamic_axes examples: - (1) dynamic_axes = {"x": {0: "my_custom_axis_name_1"}, "y": {1: "my_custom_axis_name_2"}} - (2) dynamic_axes = {"x": [0], "y": [1]} - - these will be converted to dynamic_shapes respectively: - (1) dynamic_shapes = {"x": {0: Dim("my_custom_axis_name_1")}, "y": {1: Dim("my_custom_axis_name_2")}} - (2) dynamic_shapes = {"x": {0: Dim("x_dim_0")}, "y": {1: Dim("y_dim_1")}} # auto-generated dim names - - """ - if dynamic_axes is None: - return None - - if input_names is None: - input_names_set = set() - else: - input_names_set = set(input_names) - - dynamic_shapes: dict[str, Any | None] = {} - for input_name, axes in dynamic_axes.items(): - if input_name in input_names_set: - raise ValueError( - "Assinging new input names is not supported yet. Please use model forward signature " - "to specify input names in dynamix_axes." - ) - if isinstance(axes, dict): - dynamic_shapes[input_name] = { - k: torch.export.Dim(v) for k, v in axes.items() - } - elif isinstance(axes, list): - dynamic_shapes[input_name] = { - k: torch.export.Dim(f"{input_name}_dim_{k}") for k in axes - } - else: - raise TypeError( - f"dynamic_axes value must be either a dict or a list, but got {type(axes)}" - ) - # torch.export.export needs static dim to present in dynamic_shapes - # for all input tensors, so we need to add them with None - try: - sig = _signature(model) - except ValueError as e: - warnings.warn(f"{e}, skipping auto filling None on static axes...") - return dynamic_shapes - for input_name in sig.parameters.keys(): - if input_name not in dynamic_shapes: - dynamic_shapes[input_name] = None - return dynamic_shapes - - def _trace(func, args, operator_export_type, return_outs=False): # Special case for common case of passing a single Tensor if isinstance(args, torch.Tensor):