From cc1d7df03c7c3c37367e76b237ac4b087ea040d4 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 22 Apr 2022 12:31:33 -0700 Subject: [PATCH] Autoinstall TensorRT if missing (#7537) * Autoinstall TensorRT if missing May resolve https://github.com/ultralytics/yolov5/issues/7464 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update export.py * Update export.py * Update export.py * Update export.py Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- export.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/export.py b/export.py index 2a5eff23c1a6..93d98c801d02 100644 --- a/export.py +++ b/export.py @@ -217,7 +217,15 @@ def export_coreml(model, im, file, int8, half, prefix=colorstr('CoreML:')): def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')): # YOLOv5 TensorRT export https://developer.nvidia.com/tensorrt try: - import tensorrt as trt # pip install -U nvidia-tensorrt --index-url https://pypi.ngc.nvidia.com + assert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`' + try: + import tensorrt as trt + except Exception: + s = f"\n{prefix} tensorrt not found and is required by YOLOv5" + LOGGER.info(f"{s}, attempting auto-update...") + r = '-U nvidia-tensorrt --index-url https://pypi.ngc.nvidia.com' + LOGGER.info(subprocess.check_output(f"pip install {r}", shell=True).decode()) + import tensorrt as trt if trt.__version__[0] == '7': # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012 grid = model.model[-1].anchor_grid @@ -230,7 +238,6 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=F onnx = file.with_suffix('.onnx') LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...') - assert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`' assert onnx.exists(), f'failed to export ONNX file: {onnx}' f = file.with_suffix('.engine') # TensorRT engine file logger = trt.Logger(trt.Logger.INFO)