diff --git a/hubconf.py b/hubconf.py
index f74e70c85a65..40bbb1ed0826 100644
--- a/hubconf.py
+++ b/hubconf.py
@@ -41,8 +41,7 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
             cfg = list((Path(__file__).parent / 'models').rglob(f'{name}.yaml'))[0]  # model.yaml path
             model = Model(cfg, channels, classes)  # create model
             if pretrained:
-                attempt_download(fname)  # download if not found locally
-                ckpt = torch.load(fname, map_location=torch.device('cpu'))  # load
+                ckpt = torch.load(attempt_download(fname), map_location=torch.device('cpu'))  # load
                 msd = model.state_dict()  # model state_dict
                 csd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32
                 csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape}  # filter
diff --git a/models/experimental.py b/models/experimental.py
index afa787907104..d316b18373c3 100644
--- a/models/experimental.py
+++ b/models/experimental.py
@@ -116,8 +116,7 @@ def attempt_load(weights, map_location=None, inplace=True):
     # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
     model = Ensemble()
     for w in weights if isinstance(weights, list) else [weights]:
-        attempt_download(w)
-        ckpt = torch.load(w, map_location=map_location)  # load
+        ckpt = torch.load(attempt_download(w), map_location=map_location)  # load
         model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval())  # FP32 model
 
     # Compatibility updates
diff --git a/train.py b/train.py
index c8d617fc228f..a344f53c6b5c 100644
--- a/train.py
+++ b/train.py
@@ -83,7 +83,7 @@ def train(hyp, opt, device, tb_writer=None):
     pretrained = weights.endswith('.pt')
     if pretrained:
         with torch_distributed_zero_first(rank):
-            attempt_download(weights)  # download if not found locally
+            weights = attempt_download(weights)  # download if not found locally
         ckpt = torch.load(weights, map_location=device)  # load checkpoint
         model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
         exclude = ['anchor'] if (opt.cfg or hyp.get('anchors')) and not opt.resume else []  # exclude keys
diff --git a/utils/google_utils.py b/utils/google_utils.py
index 63d3e5b212f3..ac5c54dba97f 100644
--- a/utils/google_utils.py
+++ b/utils/google_utils.py
@@ -16,11 +16,37 @@ def gsutil_getsize(url=''):
     return eval(s.split(' ')[0]) if len(s) else 0  # bytes
 
 
+def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):
+    # Attempts to download file from url or url2, checks and removes incomplete downloads < min_bytes
+    file = Path(file)
+    try:  # GitHub
+        print(f'Downloading {url} to {file}...')
+        torch.hub.download_url_to_file(url, str(file))
+        assert file.exists() and file.stat().st_size > min_bytes  # check
+    except Exception as e:  # GCP
+        file.unlink(missing_ok=True)  # remove partial downloads
+        print(f'Download error: {e}\nRe-attempting {url2 or url} to {file}...')
+        os.system(f"curl -L '{url2 or url}' -o '{file}' --retry 3 -C -")  # curl download, retry and resume on fail
+    finally:
+        if not file.exists() or file.stat().st_size < min_bytes:  # check
+            file.unlink(missing_ok=True)  # remove partial downloads
+            print(f'ERROR: Download failure: {error_msg or url}')
+        print('')
+
+
 def attempt_download(file, repo='ultralytics/yolov5'):
     # Attempt file download if does not exist
     file = Path(str(file).strip().replace("'", ''))
 
     if not file.exists():
+        # URL specified
+        name = file.name
+        if str(file).startswith(('http:/', 'https:/')):  # download
+            url = str(file).replace(':/', '://')  # Pathlib turns :// -> :/
+            safe_download(file=name, url=url, min_bytes=1E5)
+            return name
+
+        # GitHub assets
         file.parent.mkdir(parents=True, exist_ok=True)  # make parent dir (if required)
         try:
             response = requests.get(f'https://api.github.com/repos/{repo}/releases/latest').json()  # github api
@@ -34,27 +60,14 @@ def attempt_download(file, repo='ultralytics/yolov5'):
             except:
                 tag = 'v5.0'  # current release
 
-        name = file.name
         if name in assets:
-            msg = f'{file} missing, try downloading from https://github.com/{repo}/releases/'
-            redundant = False  # second download option
-            try:  # GitHub
-                url = f'https://github.com/{repo}/releases/download/{tag}/{name}'
-                print(f'Downloading {url} to {file}...')
-                torch.hub.download_url_to_file(url, file)
-                assert file.exists() and file.stat().st_size > 1E6  # check
-            except Exception as e:  # GCP
-                print(f'Download error: {e}')
-                assert redundant, 'No secondary mirror'
-                url = f'https://storage.googleapis.com/{repo}/ckpt/{name}'
-                print(f'Downloading {url} to {file}...')
-                os.system(f"curl -L '{url}' -o '{file}' --retry 3 -C -")  # curl download, retry and resume on fail
-            finally:
-                if not file.exists() or file.stat().st_size < 1E6:  # check
-                    file.unlink(missing_ok=True)  # remove partial downloads
-                    print(f'ERROR: Download failure: {msg}')
-                print('')
-                return
+            safe_download(file,
+                          url=f'https://github.com/{repo}/releases/download/{tag}/{name}',
+                          # url2=f'https://storage.googleapis.com/{repo}/ckpt/{name}',  # backup url (optional)
+                          min_bytes=1E5,
+                          error_msg=f'{file} missing, try downloading from https://github.com/{repo}/releases/')
+
+    return str(file)
 
 
 def gdrive_download(id='16TiPfZj7htmTyhntwcZyEEAejOUxuT6m', file='tmp.zip'):
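A minimal usage sketch of the new contract (assuming the script runs from the repo root so utils/google_utils.py is importable, and that 'yolov5s.pt' is a published release asset; both are illustrative, not part of the diff): attempt_download() now returns the local path, so download-and-load collapses into a single torch.load call.

# sketch only: exercise the refactored attempt_download() return value
import torch

from utils.google_utils import attempt_download

weights = attempt_download('yolov5s.pt')  # downloads the release asset if missing, returns the local path
ckpt = torch.load(weights, map_location='cpu')  # load the checkpoint straight from the returned path
model = ckpt['ema' if ckpt.get('ema') else 'model'].float()  # FP32 model, mirroring attempt_load() above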