From 92e47b85d952274480c8c5efa5900e686241a96b Mon Sep 17 00:00:00 2001 From: daquexian Date: Wed, 20 Jul 2022 01:01:24 +0800 Subject: [PATCH] Upgrade onnxsim to v0.4.1 (#8632) * upgrade onnxsim to v0.4.1 Signed-off-by: daquexian * Update export.py * Update export.py * Update export.py * Update export.py * Update export.py Co-authored-by: Glenn Jocher --- export.py | 9 ++++----- requirements.txt | 2 +- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/export.py b/export.py index 9868fcae95c3..3629915f028d 100644 --- a/export.py +++ b/export.py @@ -152,13 +152,12 @@ def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorst # Simplify if simplify: try: - check_requirements(('onnx-simplifier',)) + cuda = torch.cuda.is_available() + check_requirements(('onnxruntime-gpu' if cuda else 'onnxruntime', 'onnx-simplifier>=0.4.1')) import onnxsim LOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...') - model_onnx, check = onnxsim.simplify(model_onnx, - dynamic_input_shape=dynamic, - input_shapes={'images': list(im.shape)} if dynamic else None) + model_onnx, check = onnxsim.simplify(model_onnx) assert check, 'assert check failed' onnx.save(model_onnx, f) except Exception as e: @@ -493,7 +492,7 @@ def run( imgsz *= 2 if len(imgsz) == 1 else 1 # expand assert nc == len(names), f'Model class count {nc} != len(names) {len(names)}' if optimize: - assert device.type != 'cuda', '--optimize not compatible with cuda devices, i.e. use --device cpu' + assert device.type == 'cpu', '--optimize not compatible with cuda devices, i.e. use --device cpu' # Input gs = int(max(model.stride)) # grid size (max stride) diff --git a/requirements.txt b/requirements.txt index a3284d6529eb..8548f67b5a48 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,7 +25,7 @@ seaborn>=0.11.0 # Export -------------------------------------- # coremltools>=4.1 # CoreML export # onnx>=1.9.0 # ONNX export -# onnx-simplifier>=0.3.6 # ONNX simplifier +# onnx-simplifier>=0.4.1 # ONNX simplifier # nvidia-pyindex # TensorRT export # nvidia-tensorrt # TensorRT export # scikit-learn==0.19.2 # CoreML quantization