From 03f2ca8eff8918b98169256d055353a1f15b8e32 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 16 Sep 2022 12:31:43 +0200 Subject: [PATCH] Fix TensorRT exports to ONNX opset 12 (#9441) * Fix TensorRT exports to ONNX opset 12 Signed-off-by: Glenn Jocher * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Glenn Jocher Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- export.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/export.py b/export.py index 1b25f3f8221b..cc4386ae4916 100644 --- a/export.py +++ b/export.py @@ -251,7 +251,7 @@ def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose model.model[-1].anchor_grid = grid else: # TensorRT >= 8 check_version(trt.__version__, '8.0.0', hard=True) # require tensorrt>=8.0.0 - export_onnx(model, im, file, 13, False, dynamic, simplify) # opset 13 + export_onnx(model, im, file, 12, False, dynamic, simplify) # opset 12 onnx = file.with_suffix('.onnx') LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...') @@ -274,11 +274,10 @@ def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose inputs = [network.get_input(i) for i in range(network.num_inputs)] outputs = [network.get_output(i) for i in range(network.num_outputs)] - LOGGER.info(f'{prefix} Network Description:') for inp in inputs: - LOGGER.info(f'{prefix}\tinput "{inp.name}" with shape {inp.shape} and dtype {inp.dtype}') + LOGGER.info(f'{prefix} input "{inp.name}" with shape{inp.shape} {inp.dtype}') for out in outputs: - LOGGER.info(f'{prefix}\toutput "{out.name}" with shape {out.shape} and dtype {out.dtype}') + LOGGER.info(f'{prefix} output "{out.name}" with shape{out.shape} {out.dtype}') if dynamic: if im.shape[0] <= 1: @@ -288,7 +287,7 @@ def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose profile.set_shape(inp.name, (1, *im.shape[1:]), (max(1, im.shape[0] // 2), *im.shape[1:]), im.shape) config.add_optimization_profile(profile) - LOGGER.info(f'{prefix} building FP{16 if builder.platform_has_fast_fp16 and half else 32} engine in {f}') + LOGGER.info(f'{prefix} building FP{16 if builder.platform_has_fast_fp16 and half else 32} engine as {f}') if builder.platform_has_fast_fp16 and half: config.set_flag(trt.BuilderFlag.FP16) with builder.build_engine(network, config) as engine, open(f, 'wb') as t: