Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add --include torchscript onnx coreml argument #3137

Merged
merged 5 commits into from
May 12, 2021
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 61 additions & 54 deletions models/export.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Exports a YOLOv5 *.pt model to ONNX and TorchScript formats
"""Exports a YOLOv5 *.pt model to TorchScript, ONNX, CoreML formats

Usage:
$ export PYTHONPATH="$PWD" && python models/export.py --weights yolov5s.pt --img 640 --batch 1
$ python path/to/models/export.py --weights yolov5s.pt --img 640 --batch 1
"""

import argparse
Expand Down Expand Up @@ -33,8 +33,10 @@
parser.add_argument('--optimize', action='store_true', help='optimize TorchScript for mobile') # TorchScript-only
parser.add_argument('--dynamic', action='store_true', help='dynamic ONNX axes') # ONNX-only
parser.add_argument('--simplify', action='store_true', help='simplify ONNX model') # ONNX-only
parser.add_argument('--skip-format', action='append', dest="skipped_formats", help='exclude output format')
opt = parser.parse_args()
opt.img_size *= 2 if len(opt.img_size) == 1 else 1 # expand
opt.skipped_formats = [i.lower() for i in opt.skipped_formats] if opt.skipped_formats else []
print(opt)
set_logging()
t = time.time()
Expand Down Expand Up @@ -74,62 +76,67 @@
print(f"\n{colorstr('PyTorch:')} starting from {opt.weights} ({file_size(opt.weights):.1f} MB)")

# TorchScript export -----------------------------------------------------------------------------------------------
prefix = colorstr('TorchScript:')
try:
print(f'\n{prefix} starting export with torch {torch.__version__}...')
f = opt.weights.replace('.pt', '.torchscript.pt') # filename
ts = torch.jit.trace(model, img, strict=False)
(optimize_for_mobile(ts) if opt.optimize else ts).save(f)
print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
except Exception as e:
print(f'{prefix} export failure: {e}')
if "torchscript" not in opt.skipped_formats:
prefix = colorstr('TorchScript:')
try:
print(f'\n{prefix} starting export with torch {torch.__version__}...')
f = opt.weights.replace('.pt', '.torchscript.pt') # filename
ts = torch.jit.trace(model, img, strict=False)
(optimize_for_mobile(ts) if opt.optimize else ts).save(f)
print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
except Exception as e:
print(f'{prefix} export failure: {e}')

# ONNX export ------------------------------------------------------------------------------------------------------
prefix = colorstr('ONNX:')
try:
import onnx

print(f'{prefix} starting export with onnx {onnx.__version__}...')
f = opt.weights.replace('.pt', '.onnx') # filename
torch.onnx.export(model, img, f, verbose=False, opset_version=12, input_names=['images'],
dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'}, # size(1,3,640,640)
'output': {0: 'batch', 2: 'y', 3: 'x'}} if opt.dynamic else None)

# Checks
model_onnx = onnx.load(f) # load onnx model
onnx.checker.check_model(model_onnx) # check onnx model
# print(onnx.helper.printable_graph(model_onnx.graph)) # print

# Simplify
if opt.simplify:
try:
check_requirements(['onnx-simplifier'])
import onnxsim

print(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
model_onnx, check = onnxsim.simplify(model_onnx,
dynamic_input_shape=opt.dynamic,
input_shapes={'images': list(img.shape)} if opt.dynamic else None)
assert check, 'assert check failed'
onnx.save(model_onnx, f)
except Exception as e:
print(f'{prefix} simplifier failure: {e}')
print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
except Exception as e:
print(f'{prefix} export failure: {e}')
if "onnx" not in opt.skipped_formats:
prefix = colorstr('ONNX:')
try:
import onnx

print(f'{prefix} starting export with onnx {onnx.__version__}...')
f = opt.weights.replace('.pt', '.onnx') # filename
torch.onnx.export(model, img, f, verbose=False, opset_version=12, input_names=['images'],
dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'}, # size(1,3,640,640)
'output': {0: 'batch', 2: 'y', 3: 'x'}} if opt.dynamic else None)

# Checks
model_onnx = onnx.load(f) # load onnx model
onnx.checker.check_model(model_onnx) # check onnx model
# print(onnx.helper.printable_graph(model_onnx.graph)) # print

# Simplify
if opt.simplify:
try:
check_requirements(['onnx-simplifier'])
import onnxsim

print(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
model_onnx, check = onnxsim.simplify(
model_onnx,
dynamic_input_shape=opt.dynamic,
input_shapes={'images': list(img.shape)} if opt.dynamic else None)
assert check, 'assert check failed'
onnx.save(model_onnx, f)
except Exception as e:
print(f'{prefix} simplifier failure: {e}')
print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
except Exception as e:
print(f'{prefix} export failure: {e}')

# CoreML export ----------------------------------------------------------------------------------------------------
prefix = colorstr('CoreML:')
try:
import coremltools as ct

print(f'{prefix} starting export with coremltools {ct.__version__}...')
model = ct.convert(ts, inputs=[ct.ImageType(name='image', shape=img.shape, scale=1 / 255.0, bias=[0, 0, 0])])
f = opt.weights.replace('.pt', '.mlmodel') # filename
model.save(f)
print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
except Exception as e:
print(f'{prefix} export failure: {e}')
if "coreml" not in opt.skipped_formats:
prefix = colorstr('CoreML:')
try:
import coremltools as ct

print(f'{prefix} starting export with coremltools {ct.__version__}...')
model = ct.convert(ts, inputs=[ct.ImageType(name='image', shape=img.shape, scale=1 / 255.0,
bias=[0, 0, 0])])
f = opt.weights.replace('.pt', '.mlmodel') # filename
model.save(f)
print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
except Exception as e:
print(f'{prefix} export failure: {e}')

# Finish
print(f'\nExport complete ({time.time() - t:.2f}s). Visualize with https://github.com/lutzroeder/netron.')