add onnxslim support #37

Open · wants to merge 1 commit into main
export.py (32 changes: 23 additions & 9 deletions)
@@ -22,7 +22,7 @@
ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative

from models.experimental import attempt_load
-from models.yolo import ClassificationModel, Detect, DetectionModel, SegmentationModel
+from models.yolo import ClassificationModel, Detect, DetectionModel, SegmentationModel, DualDDetect
from utils.dataloaders import LoadImages
from utils.general import (LOGGER, Profile, check_dataset, check_img_size, check_requirements, check_version,
check_yaml, colorstr, file_size, get_default_args, print_args, url2file, yaml_save)
@@ -84,7 +84,7 @@ def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:'


@try_export
-def export_onnx(model, im, file, opset, dynamic, simplify, prefix=colorstr('ONNX:')):
+def export_onnx(model, im, file, opset, dynamic, simplify, slim, prefix=colorstr('ONNX:')):
    # YOLO ONNX export
    check_requirements('onnx')
    import onnx
@@ -136,6 +136,18 @@ def export_onnx(model, im, file, opset, dynamic, simplify, prefix=colorstr('ONNX
            onnx.save(model_onnx, f)
        except Exception as e:
            LOGGER.info(f'{prefix} simplifier failure: {e}')

+    if slim:
+        try:
+            check_requirements('onnxslim')
+            from onnxslim import slim
+
+            LOGGER.info(f'{prefix} slimming with onnxslim...')
+            model_onnx = slim(model_onnx)
+            onnx.save(model_onnx, f)
+        except Exception as e:
+            LOGGER.info(f'{prefix} slim failure: {e}')

    return f, model_onnx


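For reference, the step this hunk adds can be reproduced outside of export.py. A minimal sketch, assuming onnxslim is installed and model.onnx was written by a previous export (the filename is illustrative):

import onnx
from onnxslim import slim

model_onnx = onnx.load('model.onnx')  # graph produced by a prior export
model_onnx = slim(model_onnx)         # return a slimmed copy of the ModelProto
onnx.save(model_onnx, 'model.onnx')   # overwrite in place, as the hunk above does
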
@@ -193,7 +205,7 @@ def export_coreml(model, im, file, int8, half, prefix=colorstr('CoreML:')):


@try_export
-def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):
+def export_engine(model, im, file, half, dynamic, simplify, slim, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):
    # YOLO TensorRT export https://developer.nvidia.com/tensorrt
    assert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`'
    try:
@@ -206,11 +218,11 @@ def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose
    if trt.__version__[0] == '7':  # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
        grid = model.model[-1].anchor_grid
        model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid]
-        export_onnx(model, im, file, 12, dynamic, simplify)  # opset 12
+        export_onnx(model, im, file, 12, dynamic, simplify, slim)  # opset 12
        model.model[-1].anchor_grid = grid
    else:  # TensorRT >= 8
        check_version(trt.__version__, '8.0.0', hard=True)  # require tensorrt>=8.0.0
-        export_onnx(model, im, file, 12, dynamic, simplify)  # opset 12
+        export_onnx(model, im, file, 12, dynamic, simplify, slim)  # opset 12
    onnx = file.with_suffix('.onnx')

    LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
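
Because the TensorRT exporter builds its engine from an intermediate ONNX file, the new flag is threaded through export_engine into export_onnx, so slimming runs before TensorRT parses the graph. A hypothetical end-to-end call with slimming enabled (the weights name and device are illustrative):

from export import run

# Build a TensorRT engine from a slimmed intermediate ONNX graph.
run(weights='yolo.pt', include=('engine',), device='0', slim=True)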
@@ -456,6 +468,7 @@ def run(
        int8=False,  # CoreML/TF INT8 quantization
        dynamic=False,  # ONNX/TF/TensorRT: dynamic axes
        simplify=False,  # ONNX: simplify model
+        slim=False,  # ONNX: slim model
        opset=12,  # ONNX: opset version
        verbose=False,  # TensorRT: verbose log
        workspace=4,  # TensorRT: workspace size (GB)
@@ -494,7 +507,7 @@
    # Update model
    model.eval()
    for k, m in model.named_modules():
-        if isinstance(m, (Detect, V6Detect)):
+        if isinstance(m, (Detect, DualDDetect)):
            m.inplace = inplace
            m.dynamic = dynamic
            m.export = True
@@ -503,7 +516,7 @@
    y = model(im)  # dry runs
    if half and not coreml:
        im, model = im.half(), model.half()  # to FP16
-    shape = tuple((y[0] if isinstance(y, tuple) else y).shape)  # model output shape
+    shape = tuple((y[0] if isinstance(y, (tuple, list)) else y).shape)  # model output shape
    metadata = {'stride': int(max(model.stride)), 'names': model.names}  # model metadata
    LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with output shape {shape} ({file_size(file):.1f} MB)")

@@ -513,9 +526,9 @@
    if jit:  # TorchScript
        f[0], _ = export_torchscript(model, im, file, optimize)
    if engine:  # TensorRT required before ONNX
-        f[1], _ = export_engine(model, im, file, half, dynamic, simplify, workspace, verbose)
+        f[1], _ = export_engine(model, im, file, half, dynamic, simplify, slim, workspace, verbose)
    if onnx or xml:  # OpenVINO requires ONNX
-        f[2], _ = export_onnx(model, im, file, opset, dynamic, simplify)
+        f[2], _ = export_onnx(model, im, file, opset, dynamic, simplify, slim)
    if xml:  # OpenVINO
        f[3], _ = export_openvino(file, metadata, half)
    if coreml:  # CoreML
@@ -577,6 +590,7 @@ def parse_opt():
    parser.add_argument('--int8', action='store_true', help='CoreML/TF INT8 quantization')
    parser.add_argument('--dynamic', action='store_true', help='ONNX/TF/TensorRT: dynamic axes')
    parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model')
+    parser.add_argument('--slim', action='store_true', help='ONNX: slim model')
    parser.add_argument('--opset', type=int, default=12, help='ONNX: opset version')
    parser.add_argument('--verbose', action='store_true', help='TensorRT: verbose log')
    parser.add_argument('--workspace', type=int, default=4, help='TensorRT: workspace size (GB)')
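
The new --slim switch mirrors --simplify: both are opt-in store_true flags that parse_opt forwards into run(). A hypothetical programmatic equivalent of python export.py --weights yolo.pt --include onnx --simplify --slim (the weights name is illustrative):

from export import run

# Export to ONNX, then apply both the simplifier and onnxslim passes.
run(weights='yolo.pt', include=('onnx',), simplify=True, slim=True)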
requirements.txt (1 change: 1 addition & 0 deletions)
@@ -37,6 +37,7 @@ seaborn>=0.11.0
# tensorflow>=2.4.1
# tensorflowjs>=3.9.0
# openvino-dev
+# onnxslim

# Deploy ----------------------------------------------------------------------
# tritonclient[all]~=2.24.0
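
As with the other optional exporter dependencies above, onnxslim stays commented out here; export.py installs it on demand through check_requirements('onnxslim'). A rough sketch of that on-demand pattern, assuming the helper pip-installs missing packages as the Ultralytics utilities typically do:

import importlib
import subprocess
import sys

def ensure_package(name):  # hypothetical stand-in for check_requirements
    # Import the package if available, pip-installing it first if missing.
    try:
        importlib.import_module(name)
    except ImportError:
        subprocess.check_call([sys.executable, '-m', 'pip', 'install', name])

ensure_package('onnxslim')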