From a75f2abe02331307645da21cb330e7eda3e24659 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 5 Nov 2021 18:43:03 +0100 Subject: [PATCH] Fix detect.py URL inference (#5525) * Fix detect.py URL inference Allows detect.py to run inference on remote URL sources, i.e.: ```python !python detect.py --weights yolov5s.pt --source https://ultralytics.com/assets/zidane.jpg # image URL !python detect.py --weights yolov5s.pt --source https://ultralytics.com/assets/decelera_landscape.mov # video URL ``` * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- detect.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/detect.py b/detect.py index 46141ed4da3c..61044914e16b 100644 --- a/detect.py +++ b/detect.py @@ -24,10 +24,10 @@ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative from models.experimental import attempt_load -from utils.datasets import LoadImages, LoadStreams -from utils.general import (LOGGER, apply_classifier, check_img_size, check_imshow, check_requirements, check_suffix, - colorstr, increment_path, non_max_suppression, print_args, save_one_box, scale_coords, - strip_optimizer, xyxy2xywh) +from utils.datasets import IMG_FORMATS, VID_FORMATS, LoadImages, LoadStreams +from utils.general import (LOGGER, apply_classifier, check_file, check_img_size, check_imshow, check_requirements, + check_suffix, colorstr, increment_path, non_max_suppression, print_args, save_one_box, + scale_coords, strip_optimizer, xyxy2xywh) from utils.plots import Annotator, colors from utils.torch_utils import load_classifier, select_device, time_sync @@ -61,8 +61,11 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s) ): source = str(source) save_img = not nosave and not source.endswith('.txt') # save inference images - webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith( - ('rtsp://', 'rtmp://', 'http://', 'https://')) + is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS) + is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://')) + webcam = source.isnumeric() or source.endswith('.txt') or (is_url and not is_file) + if is_url and is_file: + source = check_file(source) # download # Directories save_dir = increment_path(Path(project) / name, exist_ok=exist_ok) # increment run