From 011595202aea03958cc7f003874084d1d6f37ada Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 30 Apr 2021 14:52:47 +0200 Subject: [PATCH] PyTorch Hub load directly when possible --- hubconf.py | 42 +++++++++++++++++++++++------------------- 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/hubconf.py b/hubconf.py index 747f7d41bcec..ee5a7d87224d 100644 --- a/hubconf.py +++ b/hubconf.py @@ -9,7 +9,7 @@ import torch -from models.yolo import Model +from models.yolo import Model, attempt_load from utils.general import check_requirements, set_logging from utils.google_utils import attempt_download from utils.torch_utils import select_device @@ -26,33 +26,37 @@ def create(name, pretrained, channels, classes, autoshape, verbose): pretrained (bool): load pretrained weights into the model channels (int): number of input channels classes (int): number of model classes + autoshape (bool): apply YOLOv5 .autoshape() wrapper to model + verbose (bool): print all information to screen Returns: - pytorch model + YOLOv5 pytorch model """ + set_logging(verbose=verbose) + fname = f'{name}.pt' # checkpoint filename try: - set_logging(verbose=verbose) - - cfg = list((Path(__file__).parent / 'models').rglob(f'{name}.yaml'))[0] # model.yaml path - model = Model(cfg, channels, classes) - if pretrained: - fname = f'{name}.pt' # checkpoint filename - attempt_download(fname) # download if not found locally - ckpt = torch.load(fname, map_location=torch.device('cpu')) # load - msd = model.state_dict() # model state_dict - csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32 - csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape} # filter - model.load_state_dict(csd, strict=False) # load - if len(ckpt['model'].names) == classes: - model.names = ckpt['model'].names # set class names attribute - if autoshape: - model = model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS + if pretrained and channels == 3 and classes == 80: + model = attempt_load(fname, map_location=torch.device('cpu')) # download/load FP32 model + else: + cfg = list((Path(__file__).parent / 'models').rglob(f'{name}.yaml'))[0] # model.yaml path + model = Model(cfg, channels, classes) # create model + if pretrained: + attempt_download(fname) # download if not found locally + ckpt = torch.load(fname, map_location=torch.device('cpu')) # load + msd = model.state_dict() # model state_dict + csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32 + csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape} # filter + model.load_state_dict(csd, strict=False) # load + if len(ckpt['model'].names) == classes: + model.names = ckpt['model'].names # set class names attribute + if autoshape: + model = model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS device = select_device('0' if torch.cuda.is_available() else 'cpu') # default to GPU if available return model.to(device) except Exception as e: help_url = 'https://github.com/ultralytics/yolov5/issues/36' - s = 'Cache maybe be out of date, try force_reload=True. See %s for help.' % help_url + s = 'Cache may be out of date, try `force_reload=True`. See %s for help.' % help_url raise Exception(s) from e