From da9a1b719ba7d10e209ff89efe28b074fb9a5f16 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Wed, 15 Dec 2021 15:27:08 +0100 Subject: [PATCH] Allow `--weights URL` (#5991) --- models/common.py | 4 ++-- utils/downloads.py | 9 ++++++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/models/common.py b/models/common.py index c2edff4d3021..4f1afa13396c 100644 --- a/models/common.py +++ b/models/common.py @@ -296,7 +296,7 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=False): check_suffix(w, suffixes) # check weights have acceptable suffix pt, jit, onnx, engine, tflite, pb, saved_model, coreml = (suffix == x for x in suffixes) # backend booleans stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults - attempt_download(w) # download if not local + w = attempt_download(w) # download if not local if jit: # TorchScript LOGGER.info(f'Loading {w} for TorchScript inference...') @@ -306,7 +306,7 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=False): d = json.loads(extra_files['config.txt']) # extra_files dict stride, names = int(d['stride']), d['names'] elif pt: # PyTorch - model = attempt_load(weights, map_location=device) + model = attempt_load(weights if isinstance(weights, list) else w, map_location=device) stride = int(model.stride.max()) # model stride names = model.module.names if hasattr(model, 'module') else model.names # get class names self.model = model # explicitly assign for to(), cpu(), cuda(), half() diff --git a/utils/downloads.py b/utils/downloads.py index 998a7a582a33..a8bacae4420f 100644 --- a/utils/downloads.py +++ b/utils/downloads.py @@ -49,9 +49,12 @@ def attempt_download(file, repo='ultralytics/yolov5'): # from utils.downloads i name = Path(urllib.parse.unquote(str(file))).name # decode '%2F' to '/' etc. if str(file).startswith(('http:/', 'https:/')): # download url = str(file).replace(':/', '://') # Pathlib turns :// -> :/ - name = name.split('?')[0] # parse authentication https://url.com/file.txt?auth... - safe_download(file=name, url=url, min_bytes=1E5) - return name + file = name.split('?')[0] # parse authentication https://url.com/file.txt?auth... + if Path(file).is_file(): + print(f'Found {url} locally at {file}') # file already exists + else: + safe_download(file=file, url=url, min_bytes=1E5) + return file # GitHub assets file.parent.mkdir(parents=True, exist_ok=True) # make parent dir (if required)