Skip to content

Commit

Permalink
Initial model ensemble capability #318
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher committed Jul 7, 2020
1 parent 121d90b commit e8cf24b
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 20 deletions.
7 changes: 3 additions & 4 deletions detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import torch.backends.cudnn as cudnn

from utils import google_utils
from models.experimental import *
from utils.datasets import *
from utils.utils import *

Expand All @@ -20,8 +20,7 @@ def detect(save_img=False):
half = device.type != 'cpu' # half precision only supported on CUDA

# Load model
google_utils.attempt_download(weights)
model = torch.load(weights, map_location=device)['model'].float().eval() # load FP32 model
model = attempt_load(weights, map_location=device) # load FP32 model
imgsz = check_img_size(imgsz, s=model.stride.max()) # check img_size
if half:
model.half() # to FP16
Expand Down Expand Up @@ -137,7 +136,7 @@ def detect(save_img=False):

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--weights', type=str, default='weights/yolov5s.pt', help='model.pt path')
parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)')
parser.add_argument('--source', type=str, default='inference/images', help='source') # file/folder, 0 for webcam
parser.add_argument('--output', type=str, default='inference/output', help='output folder') # output folder
parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
Expand Down
17 changes: 17 additions & 0 deletions models/experimental.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# This file contains experimental modules

from models.common import *
from utils import google_utils


class CrossConv(nn.Module):
Expand Down Expand Up @@ -119,3 +120,19 @@ def forward(self, x, augment=False):
for module in self:
y.append(module(x, augment)[0])
return torch.cat(y, 1), None # ensembled inference output, train output


def attempt_load(weights, map_location=None):
# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
model = Ensemble()
for w in weights if isinstance(weights, list) else [weights]:
google_utils.attempt_download(w)
model.append(torch.load(w, map_location=map_location)['model'].float().fuse().eval()) # load FP32 model

if len(model) == 1:
return model[-1] # return model
else:
print('Ensemble created with %s\n' % weights)
for k in ['names', 'stride']:
setattr(model, k, getattr(model[-1], k))
return model # return ensemble
27 changes: 12 additions & 15 deletions test.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import argparse
import json

from utils import google_utils
from models.experimental import *
from utils.datasets import *
from utils.utils import *


def test(data,
Expand All @@ -20,28 +19,26 @@ def test(data,
dataloader=None,
merge=False):
# Initialize/load model and set device
if model is None:
training = False
merge = opt.merge # use Merge NMS
training = model is not None
if training: # called by train.py
device = next(model.parameters()).device # get model device

else: # called directly
device = torch_utils.select_device(opt.device, batch_size=batch_size)
merge = opt.merge # use Merge NMS

# Remove previous
for f in glob.glob('test_batch*.jpg'):
os.remove(f)

# Load model
google_utils.attempt_download(weights)
model = torch.load(weights, map_location=device)['model'].float().fuse().to(device) # load to FP32
model = attempt_load(weights, map_location=device) # load FP32 model
imgsz = check_img_size(imgsz, s=model.stride.max()) # check img_size

# Multi-GPU disabled, incompatible with .half() https://github.com/ultralytics/yolov5/issues/99
# if device.type != 'cpu' and torch.cuda.device_count() > 1:
# model = nn.DataParallel(model)

else: # called by train.py
training = True
device = next(model.parameters()).device # get model device

# Half
half = device.type != 'cpu' and torch.cuda.device_count() == 1 # half precision only supported on single-GPU
if half:
Expand All @@ -56,11 +53,11 @@ def test(data,
niou = iouv.numel()

# Dataloader
if dataloader is None: # not training
if not training:
img = torch.zeros((1, 3, imgsz, imgsz), device=device) # init img
_ = model(img.half() if half else img) if device.type != 'cpu' else None # run once
path = data['test'] if opt.task == 'test' else data['val'] # path to val/test images
dataloader = create_dataloader(path, imgsz, batch_size, int(max(model.stride)), opt,
dataloader = create_dataloader(path, imgsz, batch_size, model.stride.max(), opt,
hyp=None, augment=False, cache=False, pad=0.5, rect=True)[0]

seen = 0
Expand Down Expand Up @@ -193,7 +190,7 @@ def test(data,
if save_json and map50 and len(jdict):
imgIds = [int(Path(x).stem.split('_')[-1]) for x in dataloader.dataset.img_files]
f = 'detections_val2017_%s_results.json' % \
(weights.split(os.sep)[-1].replace('.pt', '') if weights else '') # filename
(weights.split(os.sep)[-1].replace('.pt', '') if isinstance(weights, str) else '') # filename
print('\nCOCO mAP with pycocotools... saving %s...' % f)
with open(f, 'w') as file:
json.dump(jdict, file)
Expand Down Expand Up @@ -226,7 +223,7 @@ def test(data,

if __name__ == '__main__':
parser = argparse.ArgumentParser(prog='test.py')
parser.add_argument('--weights', type=str, default='weights/yolov5s.pt', help='model.pt path')
parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)')
parser.add_argument('--data', type=str, default='data/coco128.yaml', help='*.data path')
parser.add_argument('--batch-size', type=int, default=32, help='size of each image batch')
parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
Expand Down
2 changes: 1 addition & 1 deletion utils/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa
rect=rect, # rectangular training
cache_images=cache,
single_cls=opt.single_cls,
stride=stride,
stride=int(stride),
pad=pad)

batch_size = min(batch_size, len(dataset))
Expand Down

0 comments on commit e8cf24b

Please sign in to comment.