From 7c58654d150f80ba5596fbef7f9f904397636d94 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sat, 1 May 2021 17:35:02 +0200 Subject: [PATCH] Update hubconf.py for unified loading (#3005) --- hubconf.py | 34 +++++++--------------------------- 1 file changed, 7 insertions(+), 27 deletions(-) diff --git a/hubconf.py b/hubconf.py index ee5a7d87224d..7f897d15c314 100644 --- a/hubconf.py +++ b/hubconf.py @@ -18,7 +18,7 @@ check_requirements(Path(__file__).parent / 'requirements.txt', exclude=('tensorboard', 'pycocotools', 'thop')) -def create(name, pretrained, channels, classes, autoshape, verbose): +def create(name, pretrained, channels=3, classes=80, autoshape=True, verbose=True): """Creates a specified YOLOv5 model Arguments: @@ -33,7 +33,7 @@ def create(name, pretrained, channels, classes, autoshape, verbose): YOLOv5 pytorch model """ set_logging(verbose=verbose) - fname = f'{name}.pt' # checkpoint filename + fname = Path(name).with_suffix('.pt') # checkpoint filename try: if pretrained and channels == 3 and classes == 80: model = attempt_load(fname, map_location=torch.device('cpu')) # download/load FP32 model @@ -60,30 +60,9 @@ def create(name, pretrained, channels, classes, autoshape, verbose): raise Exception(s) from e -def custom(path_or_model='path/to/model.pt', autoshape=True, verbose=True): - """YOLOv5-custom model https://github.com/ultralytics/yolov5 - - Arguments (3 options): - path_or_model (str): 'path/to/model.pt' - path_or_model (dict): torch.load('path/to/model.pt') - path_or_model (nn.Module): torch.load('path/to/model.pt')['model'] - - Returns: - pytorch model - """ - set_logging(verbose=verbose) - - model = torch.load(path_or_model) if isinstance(path_or_model, str) else path_or_model # load checkpoint - if isinstance(model, dict): - model = model['ema' if model.get('ema') else 'model'] # load model - - hub_model = Model(model.yaml).to(next(model.parameters()).device) # create - hub_model.load_state_dict(model.float().state_dict()) # load state_dict - hub_model.names = model.names # class names - if autoshape: - hub_model = hub_model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS - device = select_device('0' if torch.cuda.is_available() else 'cpu') # default to GPU if available - return hub_model.to(device) +def custom(path='path/to/model.pt', autoshape=True, verbose=True): + # YOLOv5 custom or local model + return create(path, autoshape, verbose) def yolov5s(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True): @@ -127,7 +106,8 @@ def yolov5x6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=Tr if __name__ == '__main__': - model = create(name='yolov5s', pretrained=True, channels=3, classes=80, autoshape=True, verbose=True) # pretrained + model = create(name='weights/yolov5s.pt', pretrained=True, channels=3, classes=80, autoshape=True, + verbose=True) # pretrained # model = custom(path_or_model='path/to/model.pt') # custom # Verify inference