From e5114707e50e2303081a75c0930a73880c257b60 Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Thu, 15 Oct 2020 20:10:08 +0200
Subject: [PATCH] Simplified inference (#1153)

---
 detect.py            |  4 +--
 hubconf.py           |  5 +---
 models/common.py     | 64 ++++++++++++++++++++++++++++++++++++++++----
 models/yolo.py       | 31 +++++++++++++--------
 sotabench.py         |  9 +++----
 utils/datasets.py    |  2 +-
 utils/torch_utils.py |  2 +-
 7 files changed, 87 insertions(+), 30 deletions(-)

diff --git a/detect.py b/detect.py
index eee5f0208244..a98fe7394854 100644
--- a/detect.py
+++ b/detect.py
@@ -149,8 +149,8 @@ def detect(save_img=False):
     parser.add_argument('--source', type=str, default='inference/images', help='source')  # file/folder, 0 for webcam
     parser.add_argument('--output', type=str, default='inference/output', help='output folder')  # output folder
     parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
-    parser.add_argument('--conf-thres', type=float, default=0.4, help='object confidence threshold')
-    parser.add_argument('--iou-thres', type=float, default=0.5, help='IOU threshold for NMS')
+    parser.add_argument('--conf-thres', type=float, default=0.25, help='object confidence threshold')
+    parser.add_argument('--iou-thres', type=float, default=0.45, help='IOU threshold for NMS')
     parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
     parser.add_argument('--view-img', action='store_true', help='display results')
     parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
diff --git a/hubconf.py b/hubconf.py
index 168b40502f25..94208e064784 100644
--- a/hubconf.py
+++ b/hubconf.py
@@ -10,7 +10,6 @@
 
 import torch
 
-from models.common import NMS
 from models.yolo import Model
 from utils.google_utils import attempt_download
 
@@ -36,9 +35,7 @@ def create(name, pretrained, channels, classes):
             state_dict = torch.load(ckpt, map_location=torch.device('cpu'))['model'].float().state_dict()  # to FP32
             state_dict = {k: v for k, v in state_dict.items() if model.state_dict()[k].shape == v.shape}  # filter
             model.load_state_dict(state_dict, strict=False)  # load
-
-            model.add_nms()  # add NMS module
-            model.eval()
+            # model = model.autoshape()  # cv2/PIL/np/torch inference: predictions = model(Image.open('image.jpg'))
         return model
 
     except Exception as e:
diff --git a/models/common.py b/models/common.py
index 314c31f91aac..022ad00ba43e 100644
--- a/models/common.py
+++ b/models/common.py
@@ -1,9 +1,12 @@
 # This file contains modules common to various models
 
-import math
+import math
+import numpy as np
 import torch
 import torch.nn as nn
-from utils.general import non_max_suppression
+
+from utils.datasets import letterbox
+from utils.general import non_max_suppression, make_divisible, scale_coords
 
 
 def autopad(k, p=None):  # kernel, padding
@@ -101,17 +104,68 @@ def forward(self, x):
 
 class NMS(nn.Module):
     # Non-Maximum Suppression (NMS) module
-    conf = 0.3  # confidence threshold
-    iou = 0.6  # IoU threshold
+    conf = 0.25  # confidence threshold
+    iou = 0.45  # IoU threshold
     classes = None  # (optional list) filter by class
 
-    def __init__(self, dimension=1):
+    def __init__(self):
         super(NMS, self).__init__()
 
     def forward(self, x):
         return non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)
 
 
+class autoShape(nn.Module):
+    # input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
+    img_size = 640  # inference size (pixels)
+    conf = 0.25  # NMS confidence threshold
+    iou = 0.45  # NMS IoU threshold
+    classes = None  # (optional list) filter by class
+
+    def __init__(self, model):
+        super(autoShape, self).__init__()
+        self.model = model
+
+    def forward(self, x, size=640, augment=False, profile=False):
+        # supports inference from various sources. For height=720, width=1280, RGB images example inputs are:
+        #   opencv:     x = cv2.imread('image.jpg')[:,:,::-1]  # HWC BGR to RGB x(720,1280,3)
+        #   PIL:        x = Image.open('image.jpg')  # HWC x(720,1280,3)
+        #   numpy:      x = np.zeros((720,1280,3))  # HWC
+        #   torch:      x = torch.zeros(16,3,720,1280)  # BCHW
+        #   multiple:   x = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...]  # list of images
+
+        p = next(self.model.parameters())  # for device and type
+        if isinstance(x, torch.Tensor):  # torch
+            return self.model(x.to(p.device).type_as(p), augment, profile)  # inference
+
+        # Pre-process
+        if not isinstance(x, list):
+            x = [x]
+        shape0, shape1 = [], []  # image and inference shapes
+        batch = range(len(x))  # batch size
+        for i in batch:
+            x[i] = np.array(x[i])[:, :, :3]  # up to 3 channels if png
+            s = x[i].shape[:2]  # HWC
+            shape0.append(s)  # image shape
+            g = (size / max(s))  # gain
+            shape1.append([y * g for y in s])
+        shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)]  # inference shape
+        x = [letterbox(x[i], new_shape=shape1, auto=False)[0] for i in batch]  # pad
+        x = np.stack(x, 0) if batch[-1] else x[0][None]  # stack
+        x = np.ascontiguousarray(x.transpose((0, 3, 1, 2)))  # BHWC to BCHW
+        x = torch.from_numpy(x).to(p.device).type_as(p) / 255.  # uint8 to fp16/32
+
+        # Inference
+        x = self.model(x, augment, profile)  # forward
+        x = non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)  # NMS
+
+        # Post-process
+        for i in batch:
+            if x[i] is not None:
+                x[i][:, :4] = scale_coords(shape1, x[i][:, :4], shape0[i])
+        return x
+
+
 class Flatten(nn.Module):
     # Use after nn.AdaptiveAvgPool2d(1) to remove last 2 dimensions
     @staticmethod
diff --git a/models/yolo.py b/models/yolo.py
index a9dc539bf29f..0d46054ed21c 100644
--- a/models/yolo.py
+++ b/models/yolo.py
@@ -1,21 +1,22 @@
 import argparse
 import logging
-import math
 import sys
 from copy import deepcopy
 from pathlib import Path
 
+import math
+
 sys.path.append('./')  # to run '$ python *.py' files in subdirectories
 logger = logging.getLogger(__name__)
 
 import torch
 import torch.nn as nn
 
-from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat, NMS
+from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat, NMS, autoShape
 from models.experimental import MixConv2d, CrossConv, C3
 from utils.general import check_anchor_order, make_divisible, check_file, set_logging
-from utils.torch_utils import (
-    time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, select_device)
+from utils.torch_utils import time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, \
+    select_device, copy_attr
 
 
 class Detect(nn.Module):
@@ -140,6 +141,7 @@ def forward_once(self, x, profile=False):
         return x
 
     def _initialize_biases(self, cf=None):  # initialize biases into Detect(), cf is class frequency
+        # https://arxiv.org/abs/1708.02002 section 3.3
         # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
         m = self.model[-1]  # Detect() module
         for mi, s in zip(m.m, m.stride):  # from
@@ -170,15 +172,26 @@ def fuse(self):  # fuse model Conv2d() + BatchNorm2d() layers
         self.info()
         return self
 
-    def add_nms(self):  # fuse model Conv2d() + BatchNorm2d() layers
-        if type(self.model[-1]) is not NMS:  # if missing NMS
-            print('Adding NMS module... ')
+    def nms(self, mode=True):  # add or remove NMS module
+        present = type(self.model[-1]) is NMS  # last layer is NMS
+        if mode and not present:
+            print('Adding NMS... ')
             m = NMS()  # module
             m.f = -1  # from
             m.i = self.model[-1].i + 1  # index
             self.model.add_module(name='%s' % m.i, module=m)  # add
+            self.eval()
+        elif not mode and present:
+            print('Removing NMS... ')
+            self.model = self.model[:-1]  # remove
         return self
 
+    def autoshape(self):  # add autoShape module
+        print('Adding autoShape... ')
+        m = autoShape(self)  # wrap model
+        copy_attr(m, self, include=('yaml', 'nc', 'hyp', 'names', 'stride'), exclude=())  # copy attributes
+        return m
+
     def info(self, verbose=False):  # print model information
         model_info(self, verbose)
@@ -263,10 +276,6 @@ def parse_model(d, ch):  # model_dict, input_channels(3)
     # img = torch.rand(8 if torch.cuda.is_available() else 1, 3, 640, 640).to(device)
     # y = model(img, profile=True)
 
-    # ONNX export
-    # model.model[-1].export = True
-    # torch.onnx.export(model, img, opt.cfg.replace('.yaml', '.onnx'), verbose=True, opset_version=11)
-
     # Tensorboard
     # from torch.utils.tensorboard import SummaryWriter
     # tb_writer = SummaryWriter()
diff --git a/sotabench.py b/sotabench.py
index 96ea6bffcbb0..9507d0754e95 100644
--- a/sotabench.py
+++ b/sotabench.py
@@ -1,6 +1,5 @@
 import argparse
 import glob
-import json
 import os
 import shutil
 from pathlib import Path
@@ -8,19 +7,17 @@
 
 import numpy as np
 import torch
 import yaml
+from sotabencheval.object_detection import COCOEvaluator
+from sotabencheval.utils import is_server
 from tqdm import tqdm
 
 from models.experimental import attempt_load
 from utils.datasets import create_dataloader
 from utils.general import (
     coco80_to_coco91_class, check_dataset, check_file, check_img_size, compute_loss, non_max_suppression, scale_coords,
-    xyxy2xywh, clip_coords, plot_images, xywh2xyxy, box_iou, output_to_target, ap_per_class, set_logging)
+    xyxy2xywh, clip_coords, set_logging)
 from utils.torch_utils import select_device, time_synchronized
 
-
-from sotabencheval.object_detection import COCOEvaluator
-from sotabencheval.utils import is_server
-
 DATA_ROOT = './.data/vision/coco' if is_server() else '../coco'  # sotabench data dir
diff --git a/utils/datasets.py b/utils/datasets.py
index 29ee4b051e85..9192dec4b7d9 100755
--- a/utils/datasets.py
+++ b/utils/datasets.py
@@ -1,5 +1,4 @@
 import glob
-import math
 import os
 import random
 import shutil
@@ -8,6 +7,7 @@
 from threading import Thread
 
 import cv2
+import math
 import numpy as np
 import torch
 from PIL import Image, ExifTags
diff --git a/utils/torch_utils.py b/utils/torch_utils.py
index c587617b821c..f6818238452f 100644
--- a/utils/torch_utils.py
+++ b/utils/torch_utils.py
@@ -1,9 +1,9 @@
 import logging
-import math
 import os
 import time
 from copy import deepcopy
 
+import math
 import torch
 import torch.backends.cudnn as cudnn
 import torch.nn as nn
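
A minimal usage sketch of the simplified inference path this patch enables, mirroring the commented hint added to hubconf.py above: load a model, wrap it with autoshape(), and run a PIL image through letterbox pre-processing, forward pass, NMS and coordinate rescaling in one call. It assumes torch.hub access to the ultralytics/yolov5 repo at this commit and Pillow installed; 'image.jpg' is a placeholder path.

import torch
from PIL import Image

# Load a pretrained model via torch.hub (network access to ultralytics/yolov5 assumed)
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
model = model.autoshape()  # wrap with autoShape: letterbox pre-process + NMS post-process
model.conf = 0.25  # NMS confidence threshold, matching the new detect.py default

img = Image.open('image.jpg')  # cv2 BGR-to-RGB arrays, numpy arrays and torch tensors also accepted
pred = model(img, size=640)  # list with one entry per image: n x 6 tensor (xyxy, conf, cls)
print(pred[0])

Boxes come back in the original image's pixel coordinates because autoShape.forward applies scale_coords after NMS; in this version an entry is None when an image has no detections at the chosen thresholds.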