Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ONNX Simplifier #2815

Merged
merged 2 commits into from
Apr 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 30 additions & 15 deletions models/export.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Exports a YOLOv5 *.pt model to ONNX and TorchScript formats

Usage:
$ export PYTHONPATH="$PWD" && python models/export.py --weights ./weights/yolov5s.pt --img 640 --batch 1
$ export PYTHONPATH="$PWD" && python models/export.py --weights yolov5s.pt --img 640 --batch 1
"""

import argparse
Expand All @@ -16,7 +16,7 @@
import models
from models.experimental import attempt_load
from utils.activations import Hardswish, SiLU
from utils.general import set_logging, check_img_size
from utils.general import colorstr, check_img_size, check_requirements, set_logging
from utils.torch_utils import select_device

if __name__ == '__main__':
Expand Down Expand Up @@ -59,46 +59,61 @@
y = model(img) # dry run

# TorchScript export
prefix = colorstr('TorchScript:')
try:
print('\nStarting TorchScript export with torch %s...' % torch.__version__)
print(f'\n{prefix} starting export with torch {torch.__version__}...')
f = opt.weights.replace('.pt', '.torchscript.pt') # filename
ts = torch.jit.trace(model, img, strict=False)
ts.save(f)
print('TorchScript export success, saved as %s' % f)
print(f'{prefix} export success, saved as {f}')
except Exception as e:
print('TorchScript export failure: %s' % e)
print(f'{prefix} export failure: {e}')

# ONNX export
prefix = colorstr('ONNX:')
try:
import onnx

print('\nStarting ONNX export with onnx %s...' % onnx.__version__)
print(f'{prefix} starting export with onnx {onnx.__version__}...')
f = opt.weights.replace('.pt', '.onnx') # filename
torch.onnx.export(model, img, f, verbose=False, opset_version=12, input_names=['images'],
output_names=['classes', 'boxes'] if y is None else ['output'],
dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'}, # size(1,3,640,640)
'output': {0: 'batch', 2: 'y', 3: 'x'}} if opt.dynamic else None)

# Checks
onnx_model = onnx.load(f) # load onnx model
onnx.checker.check_model(onnx_model) # check onnx model
# print(onnx.helper.printable_graph(onnx_model.graph)) # print a human readable model
print('ONNX export success, saved as %s' % f)
model_onnx = onnx.load(f) # load onnx model
onnx.checker.check_model(model_onnx) # check onnx model
# print(onnx.helper.printable_graph(model_onnx.graph)) # print

# Simplify
try:
check_requirements(['onnx-simplifier'])
import onnxsim

print(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
model_onnx, check = onnxsim.simplify(model_onnx)
assert check, 'assert check failed'
onnx.save(model_onnx, f)
except Exception as e:
print(f'{prefix} simplifier failure: {e}')
print(f'{prefix} export success, saved as {f}')
except Exception as e:
print('ONNX export failure: %s' % e)
print(f'{prefix} export failure: {e}')

# CoreML export
prefix = colorstr('CoreML:')
try:
import coremltools as ct

print('\nStarting CoreML export with coremltools %s...' % ct.__version__)
print(f'{prefix} starting export with coremltools {onnx.__version__}...')
# convert model from torchscript and apply pixel scaling as per detect.py
model = ct.convert(ts, inputs=[ct.ImageType(name='image', shape=img.shape, scale=1 / 255.0, bias=[0, 0, 0])])
f = opt.weights.replace('.pt', '.mlmodel') # filename
model.save(f)
print('CoreML export success, saved as %s' % f)
print(f'{prefix} export success, saved as {f}')
except Exception as e:
print('CoreML export failure: %s' % e)
print(f'{prefix} export failure: {e}')

# Finish
print('\nExport complete (%.2fs). Visualize with https://github.com/lutzroeder/netron.' % (time.time() - t))
print(f'\nExport complete ({time.time() - t:.2f}s). Visualize with https://github.com/lutzroeder/netron.')
2 changes: 1 addition & 1 deletion utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def check_requirements(requirements='requirements.txt', exclude=()):
except Exception as e: # DistributionNotFound or VersionConflict if requirements not met
n += 1
print(f"{prefix} {e.req} not found and is required by YOLOv5, attempting auto-update...")
print(subprocess.check_output(f"pip install '{e.req}'", shell=True).decode())
print(subprocess.check_output(f"pip install {e.req}", shell=True).decode())

if n: # if packages updated
source = file.resolve() if 'file' in locals() else requirements
Expand Down