diff --git a/hubconf.py b/hubconf.py index 168b40502f25..e439530b7660 100644 --- a/hubconf.py +++ b/hubconf.py @@ -36,9 +36,7 @@ def create(name, pretrained, channels, classes): state_dict = torch.load(ckpt, map_location=torch.device('cpu'))['model'].float().state_dict() # to FP32 state_dict = {k: v for k, v in state_dict.items() if model.state_dict()[k].shape == v.shape} # filter model.load_state_dict(state_dict, strict=False) # load - - model.add_nms() # add NMS module - model.eval() + # model = model.autoshape() # cv2/PIL/np inference: predictions = model(cv2.imread('img.jpg')) return model except Exception as e: diff --git a/models/common.py b/models/common.py index 314c31f91aac..a54050337d14 100644 --- a/models/common.py +++ b/models/common.py @@ -1,9 +1,12 @@ # This file contains modules common to various models import math +import numpy as np import torch import torch.nn as nn -from utils.general import non_max_suppression + +from utils.datasets import letterbox +from utils.general import non_max_suppression, make_divisible, scale_coords def autopad(k, p=None): # kernel, padding @@ -101,17 +104,37 @@ def forward(self, x): class NMS(nn.Module): # Non-Maximum Suppression (NMS) module - conf = 0.3 # confidence threshold - iou = 0.6 # IoU threshold + conf = 0.25 # confidence threshold + iou = 0.45 # IoU threshold classes = None # (optional list) filter by class - def __init__(self, dimension=1): + def __init__(self): super(NMS, self).__init__() def forward(self, x): return non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) +class autoShape(nn.Module): + # auto-reshape image size model wrapper + img_size = 640 # inference size (pixels) + + def __init__(self, model): + super(autoShape, self).__init__() + self.model = model + + def forward(self, x, shape=640, augment=False, profile=False): # x = cv2.imread('img.jpg') + x0shape = x.shape[:2] + p = next(self.model.parameters()) + x, ratio, (dw, dh) = letterbox(x, new_shape=make_divisible(shape or max(x0shape), int(self.stride.max()))) + x1shape = x.shape[:2] + x = np.ascontiguousarray(x[:, :, ::-1].transpose(2, 0, 1)) # BGR to RGB, to 3x640x640 + x = torch.from_numpy(x).to(p.device).type_as(p).unsqueeze(0) / 255. # uint8 to fp16/32 + x = self.model(x, augment, profile) # forward + x[0][:, :4] = scale_coords(x1shape, x[0][:, :4], x0shape) + return x + + class Flatten(nn.Module): # Use after nn.AdaptiveAvgPool2d(1) to remove last 2 dimensions @staticmethod diff --git a/models/yolo.py b/models/yolo.py index 46f1375e50aa..e7f9c8e99134 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -11,11 +11,11 @@ import torch import torch.nn as nn -from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat, NMS +from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat, NMS, autoShape from models.experimental import MixConv2d, CrossConv, C3 from utils.general import check_anchor_order, make_divisible, check_file, set_logging -from utils.torch_utils import ( - time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, select_device) +from utils.torch_utils import time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, \ + select_device, copy_attr class Detect(nn.Module): @@ -140,6 +140,7 @@ def forward_once(self, x, profile=False): return x def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is class frequency + # https://arxiv.org/abs/1708.02002 section 3.3 # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1. m = self.model[-1] # Detect() module for mi, s in zip(m.m, m.stride): # from @@ -170,15 +171,27 @@ def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers self.info() return self - def add_nms(self): # fuse model Conv2d() + BatchNorm2d() layers - if type(self.model[-1]) is not NMS: # if missing NMS - print('Adding NMS module... ') + def nms(self, mode=True): # add or remove NMS module + present = type(self.model[-1]) is NMS # last layer is NMS + if mode and not present: + print('Adding NMS... ') m = NMS() # module m.f = -1 # from m.i = self.model[-1].i + 1 # index self.model.add_module(name='%s' % m.i, module=m) # add + self.eval() + elif not mode and present: + print('Removing NMS... ') + self.model = self.model[:-1] # remove return self + def autoshape(self): # add autoShape module + print('Adding autoShape... ') + self.nms() # add NMS + m = autoShape(self) # wrap model + copy_attr(m, self, include=('names', 'stride', 'nc', 'autoshape'), exclude=()) # copy attributes + return m + def info(self, verbose=False): # print model information model_info(self, verbose) @@ -263,10 +276,6 @@ def parse_model(d, ch): # model_dict, input_channels(3) # img = torch.rand(8 if torch.cuda.is_available() else 1, 3, 640, 640).to(device) # y = model(img, profile=True) - # ONNX export - # model.model[-1].export = True - # torch.onnx.export(model, img, opt.cfg.replace('.yaml', '.onnx'), verbose=True, opset_version=11) - # Tensorboard # from torch.utils.tensorboard import SummaryWriter # tb_writer = SummaryWriter()