Skip to content

Commit

Permalink
PyTorch Hub load directly when possible (ultralytics#2986)
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher authored and danny-schwartz7 committed May 22, 2021
1 parent 49239ba commit 7ed06ff
Showing 1 changed file with 23 additions and 19 deletions.
42 changes: 23 additions & 19 deletions hubconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import torch

from models.yolo import Model
from models.yolo import Model, attempt_load
from utils.general import check_requirements, set_logging
from utils.google_utils import attempt_download
from utils.torch_utils import select_device
Expand All @@ -26,33 +26,37 @@ def create(name, pretrained, channels, classes, autoshape, verbose):
pretrained (bool): load pretrained weights into the model
channels (int): number of input channels
classes (int): number of model classes
autoshape (bool): apply YOLOv5 .autoshape() wrapper to model
verbose (bool): print all information to screen
Returns:
pytorch model
YOLOv5 pytorch model
"""
set_logging(verbose=verbose)
fname = f'{name}.pt' # checkpoint filename
try:
set_logging(verbose=verbose)

cfg = list((Path(__file__).parent / 'models').rglob(f'{name}.yaml'))[0] # model.yaml path
model = Model(cfg, channels, classes)
if pretrained:
fname = f'{name}.pt' # checkpoint filename
attempt_download(fname) # download if not found locally
ckpt = torch.load(fname, map_location=torch.device('cpu')) # load
msd = model.state_dict() # model state_dict
csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape} # filter
model.load_state_dict(csd, strict=False) # load
if len(ckpt['model'].names) == classes:
model.names = ckpt['model'].names # set class names attribute
if autoshape:
model = model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS
if pretrained and channels == 3 and classes == 80:
model = attempt_load(fname, map_location=torch.device('cpu')) # download/load FP32 model
else:
cfg = list((Path(__file__).parent / 'models').rglob(f'{name}.yaml'))[0] # model.yaml path
model = Model(cfg, channels, classes) # create model
if pretrained:
attempt_download(fname) # download if not found locally
ckpt = torch.load(fname, map_location=torch.device('cpu')) # load
msd = model.state_dict() # model state_dict
csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape} # filter
model.load_state_dict(csd, strict=False) # load
if len(ckpt['model'].names) == classes:
model.names = ckpt['model'].names # set class names attribute
if autoshape:
model = model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS
device = select_device('0' if torch.cuda.is_available() else 'cpu') # default to GPU if available
return model.to(device)

except Exception as e:
help_url = 'https://github.com/ultralytics/yolov5/issues/36'
s = 'Cache maybe be out of date, try force_reload=True. See %s for help.' % help_url
s = 'Cache may be out of date, try `force_reload=True`. See %s for help.' % help_url
raise Exception(s) from e


Expand Down

0 comments on commit 7ed06ff

Please sign in to comment.