From f65ec280ea7eb989c7c273b3946dc9d4ffb024ca Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 28 May 2021 13:30:24 +0200 Subject: [PATCH 1/9] Enable direct `--weights URL` definition @KalenMike this PR will enable direct --weights URL definition. Example use case: ``` python train.py --weights https://storage.googleapis.com/bucket/dir/model.pt ``` --- utils/google_utils.py | 51 ++++++++++++++++++++++++++----------------- 1 file changed, 31 insertions(+), 20 deletions(-) diff --git a/utils/google_utils.py b/utils/google_utils.py index 63d3e5b212f3..56366b22788a 100644 --- a/utils/google_utils.py +++ b/utils/google_utils.py @@ -16,12 +16,38 @@ def gsutil_getsize(url=''): return eval(s.split(' ')[0]) if len(s) else 0 # bytes +def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''): + # Attempts to download file from url or url2, checks and removes incomplete downloads < min_size bytes + file = Path(file) + try: # GitHub + print(f'Downloading {url} to {file}...') + torch.hub.download_url_to_file(url, str(file)) + assert file.exists() and file.stat().st_size > min_bytes # check + except Exception as e: # GCP + file.unlink(missing_ok=True) # remove partial downloads + print(f'Download error: {e}\nRe-attempting {url2 or url} to {file}...') + os.system(f"curl -L '{url2 or url}' -o '{file}' --retry 3 -C -") # curl download, retry and resume on fail + finally: + if not file.exists() or file.stat().st_size < min_bytes: # check + file.unlink(missing_ok=True) # remove partial downloads + print(f'ERROR: Download failure: {error_msg or url}') + print('') + return + + def attempt_download(file, repo='ultralytics/yolov5'): # Attempt file download if does not exist file = Path(str(file).strip().replace("'", '')) if not file.exists(): file.parent.mkdir(parents=True, exist_ok=True) # make parent dir (if required) + name = file.name + + # URL specified + if str(file).startswith(('http://', 'https://')): # download + safe_download(file=name, url=str(file), min_bytes=1E5) + + # GitHub assets try: response = requests.get(f'https://api.github.com/repos/{repo}/releases/latest').json() # github api assets = [x['name'] for x in response['assets']] # release assets, i.e. ['yolov5s.pt', 'yolov5m.pt', ...] @@ -34,27 +60,12 @@ def attempt_download(file, repo='ultralytics/yolov5'): except: tag = 'v5.0' # current release - name = file.name if name in assets: - msg = f'{file} missing, try downloading from https://github.com/{repo}/releases/' - redundant = False # second download option - try: # GitHub - url = f'https://github.com/{repo}/releases/download/{tag}/{name}' - print(f'Downloading {url} to {file}...') - torch.hub.download_url_to_file(url, file) - assert file.exists() and file.stat().st_size > 1E6 # check - except Exception as e: # GCP - print(f'Download error: {e}') - assert redundant, 'No secondary mirror' - url = f'https://storage.googleapis.com/{repo}/ckpt/{name}' - print(f'Downloading {url} to {file}...') - os.system(f"curl -L '{url}' -o '{file}' --retry 3 -C -") # curl download, retry and resume on fail - finally: - if not file.exists() or file.stat().st_size < 1E6: # check - file.unlink(missing_ok=True) # remove partial downloads - print(f'ERROR: Download failure: {msg}') - print('') - return + safe_download(file, + url=f'https://github.com/{repo}/releases/download/{tag}/{name}', + url2=f'https://storage.googleapis.com/{repo}/ckpt/{name}', # backup + min_bytes=1E5, + error_msg=f'{file} missing, try downloading from https://github.com/{repo}/releases/') def gdrive_download(id='16TiPfZj7htmTyhntwcZyEEAejOUxuT6m', file='tmp.zip'): From 1a475fbda59e9ea436a81634594b962eb2d901a6 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 28 May 2021 13:35:06 +0200 Subject: [PATCH 2/9] cleanup --- utils/google_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/google_utils.py b/utils/google_utils.py index 56366b22788a..cc1d496cffd0 100644 --- a/utils/google_utils.py +++ b/utils/google_utils.py @@ -17,7 +17,7 @@ def gsutil_getsize(url=''): def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''): - # Attempts to download file from url or url2, checks and removes incomplete downloads < min_size bytes + # Attempts to download file from url or url2, checks and removes incomplete downloads < min_bytes file = Path(file) try: # GitHub print(f'Downloading {url} to {file}...') From e383f5201522b3b1c55a4cb1c8afa77277fe96d9 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 28 May 2021 13:54:06 +0200 Subject: [PATCH 3/9] bug fixes --- utils/google_utils.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/utils/google_utils.py b/utils/google_utils.py index cc1d496cffd0..1662fb6ef6bc 100644 --- a/utils/google_utils.py +++ b/utils/google_utils.py @@ -17,7 +17,7 @@ def gsutil_getsize(url=''): def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''): - # Attempts to download file from url or url2, checks and removes incomplete downloads < min_bytes + # Attempts to download file from url or url2, checks and removes incomplete downloads < min_size bytes file = Path(file) try: # GitHub print(f'Downloading {url} to {file}...') @@ -40,14 +40,15 @@ def attempt_download(file, repo='ultralytics/yolov5'): file = Path(str(file).strip().replace("'", '')) if not file.exists(): - file.parent.mkdir(parents=True, exist_ok=True) # make parent dir (if required) - name = file.name - # URL specified - if str(file).startswith(('http://', 'https://')): # download - safe_download(file=name, url=str(file), min_bytes=1E5) + name = file.name + if str(file).startswith(('http:/', 'https:/')): # download + url = str(file).replace(':/', '://') # Pathlib turns :// -> :/ + safe_download(file=name, url=url, min_bytes=1E5) + return name # GitHub assets + file.parent.mkdir(parents=True, exist_ok=True) # make parent dir (if required) try: response = requests.get(f'https://api.github.com/repos/{repo}/releases/latest').json() # github api assets = [x['name'] for x in response['assets']] # release assets, i.e. ['yolov5s.pt', 'yolov5m.pt', ...] @@ -67,6 +68,8 @@ def attempt_download(file, repo='ultralytics/yolov5'): min_bytes=1E5, error_msg=f'{file} missing, try downloading from https://github.com/{repo}/releases/') + return str(file) + def gdrive_download(id='16TiPfZj7htmTyhntwcZyEEAejOUxuT6m', file='tmp.zip'): # Downloads a file from Google Drive. from yolov5.utils.google_utils import *; gdrive_download() From 0b13b94c4ff195e5640399d0f827de30b3c41e54 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 28 May 2021 13:55:22 +0200 Subject: [PATCH 4/9] weights = attempt_download(weights) --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index c8d617fc228f..a344f53c6b5c 100644 --- a/train.py +++ b/train.py @@ -83,7 +83,7 @@ def train(hyp, opt, device, tb_writer=None): pretrained = weights.endswith('.pt') if pretrained: with torch_distributed_zero_first(rank): - attempt_download(weights) # download if not found locally + weights = attempt_download(weights) # download if not found locally ckpt = torch.load(weights, map_location=device) # load checkpoint model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create exclude = ['anchor'] if (opt.cfg or hyp.get('anchors')) and not opt.resume else [] # exclude keys From d588968d6b8e510cfc3d4199ca7abbc442074e2b Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 28 May 2021 14:24:50 +0200 Subject: [PATCH 5/9] Update experimental.py --- models/experimental.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/models/experimental.py b/models/experimental.py index afa787907104..d316b18373c3 100644 --- a/models/experimental.py +++ b/models/experimental.py @@ -116,8 +116,7 @@ def attempt_load(weights, map_location=None, inplace=True): # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a model = Ensemble() for w in weights if isinstance(weights, list) else [weights]: - attempt_download(w) - ckpt = torch.load(w, map_location=map_location) # load + ckpt = torch.load(attempt_download(w), map_location=map_location) # load model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval()) # FP32 model # Compatibility updates From eca773db5fbb9d8d8d70dccb86db09967e70fa5c Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 28 May 2021 14:25:57 +0200 Subject: [PATCH 6/9] Update hubconf.py --- hubconf.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/hubconf.py b/hubconf.py index f74e70c85a65..40bbb1ed0826 100644 --- a/hubconf.py +++ b/hubconf.py @@ -41,8 +41,7 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo cfg = list((Path(__file__).parent / 'models').rglob(f'{name}.yaml'))[0] # model.yaml path model = Model(cfg, channels, classes) # create model if pretrained: - attempt_download(fname) # download if not found locally - ckpt = torch.load(fname, map_location=torch.device('cpu')) # load + ckpt = torch.load(attempt_download(fname), map_location=torch.device('cpu')) # load msd = model.state_dict() # model state_dict csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32 csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape} # filter From d9e5c9a15bbea1f4bb970aa61b1eebdd605e4856 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 28 May 2021 14:34:45 +0200 Subject: [PATCH 7/9] return bug fix --- utils/google_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/utils/google_utils.py b/utils/google_utils.py index 1662fb6ef6bc..d6798d69b88c 100644 --- a/utils/google_utils.py +++ b/utils/google_utils.py @@ -32,7 +32,6 @@ def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''): file.unlink(missing_ok=True) # remove partial downloads print(f'ERROR: Download failure: {error_msg or url}') print('') - return def attempt_download(file, repo='ultralytics/yolov5'): @@ -68,7 +67,7 @@ def attempt_download(file, repo='ultralytics/yolov5'): min_bytes=1E5, error_msg=f'{file} missing, try downloading from https://github.com/{repo}/releases/') - return str(file) + return str(file) def gdrive_download(id='16TiPfZj7htmTyhntwcZyEEAejOUxuT6m', file='tmp.zip'): From 0cd14c4ddbb32d3e8df1b74a1134d2d155841452 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 28 May 2021 14:50:25 +0200 Subject: [PATCH 8/9] comment mirror --- utils/google_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/google_utils.py b/utils/google_utils.py index d6798d69b88c..dc3ab7e1fd6a 100644 --- a/utils/google_utils.py +++ b/utils/google_utils.py @@ -63,7 +63,7 @@ def attempt_download(file, repo='ultralytics/yolov5'): if name in assets: safe_download(file, url=f'https://github.com/{repo}/releases/download/{tag}/{name}', - url2=f'https://storage.googleapis.com/{repo}/ckpt/{name}', # backup + # url2=f'https://storage.googleapis.com/{repo}/ckpt/{name}', # backup url (optional) min_bytes=1E5, error_msg=f'{file} missing, try downloading from https://github.com/{repo}/releases/') From eecf1c7322e0892eeb4de0252ff613caa4605bd5 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 28 May 2021 14:52:56 +0200 Subject: [PATCH 9/9] min_bytes --- utils/google_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/google_utils.py b/utils/google_utils.py index dc3ab7e1fd6a..ac5c54dba97f 100644 --- a/utils/google_utils.py +++ b/utils/google_utils.py @@ -17,7 +17,7 @@ def gsutil_getsize(url=''): def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''): - # Attempts to download file from url or url2, checks and removes incomplete downloads < min_size bytes + # Attempts to download file from url or url2, checks and removes incomplete downloads < min_bytes file = Path(file) try: # GitHub print(f'Downloading {url} to {file}...')