From c37bb492da88005a330eab27b72ba944a82ade8d Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 10 Oct 2024 11:28:56 -0700 Subject: [PATCH] [ONNX] Create an `optimize` method in ONNXProgram (#137667) Move optimization from the export call to the `optimize()` method in ONNXProgram. Users can call `optimize()` before calling `save()` to save the model. Right now if users set `optimize=True` in `torch.onnx.export` it will have the same effect as calling `optimize()`, but in the future we can evolve the method to be more flexible (e.g. target aware etc.) Example ```python onnx_program = torch.onnx.export(..., dynamo=True) onnx_program.optimize() onnx_program.save("model.onnx") ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/137667 Approved by: https://github.com/titaiwangms ghstack dependencies: #137666 --- torch/onnx/_internal/exporter/_compat.py | 2 +- torch/onnx/_internal/exporter/_onnx_program.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/torch/onnx/_internal/exporter/_compat.py b/torch/onnx/_internal/exporter/_compat.py index 0cbf65402c695..444c73190557d 100644 --- a/torch/onnx/_internal/exporter/_compat.py +++ b/torch/onnx/_internal/exporter/_compat.py @@ -205,7 +205,7 @@ def export_compat( onnx_program.model, opset_version ) if optimize: - onnx_program.model = onnxscript_apis.optimize(onnx_program.model) + onnx_program.optimize() if f is not None: onnx_program.save( diff --git a/torch/onnx/_internal/exporter/_onnx_program.py b/torch/onnx/_internal/exporter/_onnx_program.py index 3f423f787a723..4a0fad5506aaf 100644 --- a/torch/onnx/_internal/exporter/_onnx_program.py +++ b/torch/onnx/_internal/exporter/_onnx_program.py @@ -114,6 +114,14 @@ def model_proto(self) -> onnx.ModelProto: """Return the ONNX ``ModelProto`` object.""" return ir.serde.serialize_model(self.model) + def optimize(self) -> None: + """Optimize the ONNX model. + + This method optimizes the ONNX model by performing constant folding and + eliminating redundancies in the graph. The optimization is done in-place. + """ + self.model = onnxscript_apis.optimize(self.model) + def save( self, destination: str | os.PathLike,