From f387ffc33c6e85c591048ece91d7bc9011c40365 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 22 Apr 2022 14:31:05 -0700 Subject: [PATCH] Update check_requirements() with `cmds=()` argument (#7543) --- export.py | 10 ++-------- utils/general.py | 6 +++--- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/export.py b/export.py index 93d98c801d02..d7aff0d4b4e1 100644 --- a/export.py +++ b/export.py @@ -218,14 +218,8 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=F # YOLOv5 TensorRT export https://developer.nvidia.com/tensorrt try: assert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`' - try: - import tensorrt as trt - except Exception: - s = f"\n{prefix} tensorrt not found and is required by YOLOv5" - LOGGER.info(f"{s}, attempting auto-update...") - r = '-U nvidia-tensorrt --index-url https://pypi.ngc.nvidia.com' - LOGGER.info(subprocess.check_output(f"pip install {r}", shell=True).decode()) - import tensorrt as trt + check_requirements(('nvidia-tensorrt',), cmds=('-U --index-url https://pypi.ngc.nvidia.com',)) + import tensorrt as trt if trt.__version__[0] == '7': # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012 grid = model.model[-1].anchor_grid diff --git a/utils/general.py b/utils/general.py index 92e3560de8c0..31abd9420134 100755 --- a/utils/general.py +++ b/utils/general.py @@ -321,7 +321,7 @@ def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=Fals @try_except -def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True): +def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True, cmds=()): # Check installed dependencies meet requirements (pass *.txt file or list of packages) prefix = colorstr('red', 'bold', 'requirements:') check_python() # check python version @@ -334,7 +334,7 @@ def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), insta requirements = [x for x in requirements if x not in exclude] n = 0 # number of packages updates - for r in requirements: + for i, r in enumerate(requirements): try: pkg.require(r) except Exception: # DistributionNotFound or VersionConflict if requirements not met @@ -343,7 +343,7 @@ def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), insta LOGGER.info(f"{s}, attempting auto-update...") try: assert check_online(), f"'pip install {r}' skipped (offline)" - LOGGER.info(check_output(f"pip install '{r}'", shell=True).decode()) + LOGGER.info(check_output(f"pip install '{r}' {cmds[i] if cmds else ''}", shell=True).decode()) n += 1 except Exception as e: LOGGER.warning(f'{prefix} {e}')