Skip to content

Commit

Permalink
[ONNX] Create an optimize method in ONNXProgram (pytorch#137667)
Browse files Browse the repository at this point in the history
Move optimization from the export call to the `optimize()` method in ONNXProgram.

Users can call `optimize()` before calling `save()` to save the model. Right now if users set `optimize=True` in `torch.onnx.export` it will have the same effect as calling `optimize()`, but in the future we can evolve the method to be more flexible (e.g. target aware etc.)

Example

```python
onnx_program = torch.onnx.export(..., dynamo=True)
onnx_program.optimize()
onnx_program.save("model.onnx")
```

Pull Request resolved: pytorch#137667
Approved by: https://github.com/titaiwangms
ghstack dependencies: pytorch#137666
  • Loading branch information
justinchuby authored and pytorchmergebot committed Oct 10, 2024
1 parent e75984c commit c37bb49
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 1 deletion.
2 changes: 1 addition & 1 deletion torch/onnx/_internal/exporter/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def export_compat(
onnx_program.model, opset_version
)
if optimize:
onnx_program.model = onnxscript_apis.optimize(onnx_program.model)
onnx_program.optimize()

if f is not None:
onnx_program.save(
Expand Down
8 changes: 8 additions & 0 deletions torch/onnx/_internal/exporter/_onnx_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,14 @@ def model_proto(self) -> onnx.ModelProto:
"""Return the ONNX ``ModelProto`` object."""
return ir.serde.serialize_model(self.model)

def optimize(self) -> None:
"""Optimize the ONNX model.
This method optimizes the ONNX model by performing constant folding and
eliminating redundancies in the graph. The optimization is done in-place.
"""
self.model = onnxscript_apis.optimize(self.model)

def save(
self,
destination: str | os.PathLike,
Expand Down

0 comments on commit c37bb49

Please sign in to comment.