Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable direct --weights URL definition #3373

Merged
merged 9 commits into from
May 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions hubconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
cfg = list((Path(__file__).parent / 'models').rglob(f'{name}.yaml'))[0] # model.yaml path
model = Model(cfg, channels, classes) # create model
if pretrained:
attempt_download(fname) # download if not found locally
ckpt = torch.load(fname, map_location=torch.device('cpu')) # load
ckpt = torch.load(attempt_download(fname), map_location=torch.device('cpu')) # load
msd = model.state_dict() # model state_dict
csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape} # filter
Expand Down
3 changes: 1 addition & 2 deletions models/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,7 @@ def attempt_load(weights, map_location=None, inplace=True):
# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
model = Ensemble()
for w in weights if isinstance(weights, list) else [weights]:
attempt_download(w)
ckpt = torch.load(w, map_location=map_location) # load
ckpt = torch.load(attempt_download(w), map_location=map_location) # load
model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval()) # FP32 model

# Compatibility updates
Expand Down
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def train(hyp, opt, device, tb_writer=None):
pretrained = weights.endswith('.pt')
if pretrained:
with torch_distributed_zero_first(rank):
attempt_download(weights) # download if not found locally
weights = attempt_download(weights) # download if not found locally
ckpt = torch.load(weights, map_location=device) # load checkpoint
model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
exclude = ['anchor'] if (opt.cfg or hyp.get('anchors')) and not opt.resume else [] # exclude keys
Expand Down
53 changes: 33 additions & 20 deletions utils/google_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,37 @@ def gsutil_getsize(url=''):
return eval(s.split(' ')[0]) if len(s) else 0 # bytes


def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):
# Attempts to download file from url or url2, checks and removes incomplete downloads < min_bytes
file = Path(file)
try: # GitHub
print(f'Downloading {url} to {file}...')
torch.hub.download_url_to_file(url, str(file))
assert file.exists() and file.stat().st_size > min_bytes # check
except Exception as e: # GCP
file.unlink(missing_ok=True) # remove partial downloads
print(f'Download error: {e}\nRe-attempting {url2 or url} to {file}...')
os.system(f"curl -L '{url2 or url}' -o '{file}' --retry 3 -C -") # curl download, retry and resume on fail
finally:
if not file.exists() or file.stat().st_size < min_bytes: # check
file.unlink(missing_ok=True) # remove partial downloads
print(f'ERROR: Download failure: {error_msg or url}')
print('')


def attempt_download(file, repo='ultralytics/yolov5'):
# Attempt file download if does not exist
file = Path(str(file).strip().replace("'", ''))

if not file.exists():
# URL specified
name = file.name
if str(file).startswith(('http:/', 'https:/')): # download
url = str(file).replace(':/', '://') # Pathlib turns :// -> :/
safe_download(file=name, url=url, min_bytes=1E5)
return name

# GitHub assets
file.parent.mkdir(parents=True, exist_ok=True) # make parent dir (if required)
try:
response = requests.get(f'https://github.com/gitapi/repos/{repo}/releases/latest').json() # github api
Expand All @@ -34,27 +60,14 @@ def attempt_download(file, repo='ultralytics/yolov5'):
except:
tag = 'v5.0' # current release

name = file.name
if name in assets:
msg = f'{file} missing, try downloading from https://github.com/{repo}/releases/'
redundant = False # second download option
try: # GitHub
url = f'https://github.com/{repo}/releases/download/{tag}/{name}'
print(f'Downloading {url} to {file}...')
torch.hub.download_url_to_file(url, file)
assert file.exists() and file.stat().st_size > 1E6 # check
except Exception as e: # GCP
print(f'Download error: {e}')
assert redundant, 'No secondary mirror'
url = f'https://storage.googleapis.com/{repo}/ckpt/{name}'
print(f'Downloading {url} to {file}...')
os.system(f"curl -L '{url}' -o '{file}' --retry 3 -C -") # curl download, retry and resume on fail
finally:
if not file.exists() or file.stat().st_size < 1E6: # check
file.unlink(missing_ok=True) # remove partial downloads
print(f'ERROR: Download failure: {msg}')
print('')
return
safe_download(file,
url=f'https://github.com/{repo}/releases/download/{tag}/{name}',
# url2=f'https://storage.googleapis.com/{repo}/ckpt/{name}', # backup url (optional)
min_bytes=1E5,
error_msg=f'{file} missing, try downloading from https://github.com/{repo}/releases/')

return str(file)


def gdrive_download(id='16TiPfZj7htmTyhntwcZyEEAejOUxuT6m', file='tmp.zip'):
Expand Down