From e8cf24b6c8ba9bf8c99251f2978a5570aea59778 Mon Sep 17 00:00:00 2001
From: Glenn Jocher <glenn.jocher@ultralytics.com>
Date: Tue, 7 Jul 2020 15:40:50 -0700
Subject: [PATCH] Initial model ensemble capability #318

---
 detect.py              |  7 +++----
 models/experimental.py | 17 +++++++++++++++++
 test.py                | 27 ++++++++++++---------------
 utils/datasets.py      |  2 +-
 4 files changed, 33 insertions(+), 20 deletions(-)

diff --git a/detect.py b/detect.py
index d02f0a922817..7b9d9b69b142 100644
--- a/detect.py
+++ b/detect.py
@@ -2,7 +2,7 @@
 
 import torch.backends.cudnn as cudnn
 
-from utils import google_utils
+from models.experimental import *
 from utils.datasets import *
 from utils.utils import *
 
@@ -20,8 +20,7 @@ def detect(save_img=False):
     half = device.type != 'cpu'  # half precision only supported on CUDA
 
     # Load model
-    google_utils.attempt_download(weights)
-    model = torch.load(weights, map_location=device)['model'].float().eval()  # load FP32 model
+    model = attempt_load(weights, map_location=device)  # load FP32 model
     imgsz = check_img_size(imgsz, s=model.stride.max())  # check img_size
     if half:
         model.half()  # to FP16
@@ -137,7 +136,7 @@ def detect(save_img=False):
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('--weights', type=str, default='weights/yolov5s.pt', help='model.pt path')
+    parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)')
     parser.add_argument('--source', type=str, default='inference/images', help='source')  # file/folder, 0 for webcam
     parser.add_argument('--output', type=str, default='inference/output', help='output folder')  # output folder
     parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
diff --git a/models/experimental.py b/models/experimental.py
index 146a61b67a44..32a88f284648 100644
--- a/models/experimental.py
+++ b/models/experimental.py
@@ -1,6 +1,7 @@
 # This file contains experimental modules
 
 from models.common import *
+from utils import google_utils
 
 
 class CrossConv(nn.Module):
@@ -119,3 +120,19 @@ def forward(self, x, augment=False):
         for module in self:
             y.append(module(x, augment)[0])
         return torch.cat(y, 1), None  # ensembled inference output, train output
+
+
+def attempt_load(weights, map_location=None):
+    # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
+    model = Ensemble()
+    for w in weights if isinstance(weights, list) else [weights]:
+        google_utils.attempt_download(w)
+        model.append(torch.load(w, map_location=map_location)['model'].float().fuse().eval())  # load FP32 model
+
+    if len(model) == 1:
+        return model[-1]  # return model
+    else:
+        print('Ensemble created with %s\n' % weights)
+        for k in ['names', 'stride']:
+            setattr(model, k, getattr(model[-1], k))
+        return model  # return ensemble
diff --git a/test.py b/test.py
index 1cfae9591287..a39ac9e41276 100644
--- a/test.py
+++ b/test.py
@@ -1,9 +1,8 @@
 import argparse
 import json
 
-from utils import google_utils
+from models.experimental import *
 from utils.datasets import *
-from utils.utils import *
 
 
 def test(data,
@@ -20,28 +19,26 @@ def test(data,
          dataloader=None,
          merge=False):
     # Initialize/load model and set device
-    if model is None:
-        training = False
-        merge = opt.merge  # use Merge NMS
+    training = model is not None
+    if training:  # called by train.py
+        device = next(model.parameters()).device  # get model device
+
+    else:  # called directly
         device = torch_utils.select_device(opt.device, batch_size=batch_size)
+        merge = opt.merge  # use Merge NMS
 
         # Remove previous
         for f in glob.glob('test_batch*.jpg'):
             os.remove(f)
 
         # Load model
-        google_utils.attempt_download(weights)
-        model = torch.load(weights, map_location=device)['model'].float().fuse().to(device)  # load to FP32
+        model = attempt_load(weights, map_location=device)  # load FP32 model
         imgsz = check_img_size(imgsz, s=model.stride.max())  # check img_size
 
         # Multi-GPU disabled, incompatible with .half() https://github.com/ultralytics/yolov5/issues/99
         # if device.type != 'cpu' and torch.cuda.device_count() > 1:
         #     model = nn.DataParallel(model)
 
-    else:  # called by train.py
-        training = True
-        device = next(model.parameters()).device  # get model device
-
     # Half
     half = device.type != 'cpu' and torch.cuda.device_count() == 1  # half precision only supported on single-GPU
     if half:
@@ -56,11 +53,11 @@ def test(data,
     niou = iouv.numel()
 
     # Dataloader
-    if dataloader is None:  # not training
+    if not training:
         img = torch.zeros((1, 3, imgsz, imgsz), device=device)  # init img
         _ = model(img.half() if half else img) if device.type != 'cpu' else None  # run once
         path = data['test'] if opt.task == 'test' else data['val']  # path to val/test images
-        dataloader = create_dataloader(path, imgsz, batch_size, int(max(model.stride)), opt,
+        dataloader = create_dataloader(path, imgsz, batch_size, model.stride.max(), opt,
                                        hyp=None, augment=False, cache=False, pad=0.5, rect=True)[0]
 
     seen = 0
@@ -193,7 +190,7 @@ def test(data,
     if save_json and map50 and len(jdict):
         imgIds = [int(Path(x).stem.split('_')[-1]) for x in dataloader.dataset.img_files]
         f = 'detections_val2017_%s_results.json' % \
-            (weights.split(os.sep)[-1].replace('.pt', '') if weights else '')  # filename
+            (weights.split(os.sep)[-1].replace('.pt', '') if isinstance(weights, str) else '')  # filename
         print('\nCOCO mAP with pycocotools... saving %s...' % f)
         with open(f, 'w') as file:
             json.dump(jdict, file)
@@ -226,7 +223,7 @@
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(prog='test.py')
-    parser.add_argument('--weights', type=str, default='weights/yolov5s.pt', help='model.pt path')
+    parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)')
     parser.add_argument('--data', type=str, default='data/coco128.yaml', help='*.data path')
     parser.add_argument('--batch-size', type=int, default=32, help='size of each image batch')
     parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
diff --git a/utils/datasets.py b/utils/datasets.py
index 1ebd709482fe..5c88f37e2e25 100755
--- a/utils/datasets.py
+++ b/utils/datasets.py
@@ -48,7 +48,7 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa
                                   rect=rect,  # rectangular training
                                   cache_images=cache,
                                   single_cls=opt.single_cls,
-                                  stride=stride,
+                                  stride=int(stride),
                                   pad=pad)
 
     batch_size = min(batch_size, len(dataset))
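
For reference, a minimal usage sketch of the new attempt_load() entry point added
in models/experimental.py, covering both the single-model and ensemble paths. The
weight filenames and the 640x640 dummy input are illustrative assumptions, not
part of the patch:

    import torch
    from models.experimental import attempt_load

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Single weights file: returns the fused FP32 model itself
    model = attempt_load('yolov5s.pt', map_location=device)

    # List of weights files: returns an Ensemble (an nn.ModuleList subclass)
    # whose forward() concatenates each member's detections along the box dim
    ensemble = attempt_load(['yolov5s.pt', 'yolov5x.pt'], map_location=device)

    img = torch.zeros((1, 3, 640, 640), device=device)  # dummy BCHW input
    pred = ensemble(img)[0]  # (1, n_s + n_x, 5 + nc) stacked candidate boxes
    print(ensemble.names, ensemble.stride.max())  # attrs copied from last member

Because the ensembled output keeps the same (batch, boxes, outputs) layout as a
single model's inference output, it flows into the existing NMS step unchanged,
which is presumably why detect.py and test.py only need their model-loading lines
swapped for attempt_load().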