From 7a1a5587adaf68d4250dcc939240de735f5a4ca3 Mon Sep 17 00:00:00 2001
From: wjx052333
Date: Wed, 26 Apr 2023 14:21:27 +0800
Subject: [PATCH] init from yolov5_quant_sample

---
 data/hyps/hyp.qat.yaml                   |  33 ++
 docker/build.sh                          |   2 +
 docker/launch.sh                         |  12 +
 export_qat.py                            |  66 ++++
 models/common.py                         |  13 +-
 models/experimental.py                   |   6 +-
 models/yolo.py                           |   5 +-
 train.py                                 |  29 +-
 trt/Processor.py                         | 367 +++++++++++++++++
 trt/Visualizer.py                        |  57 +++
 trt/__init__.py                          |   0
 trt/calibrator.py                        | 113 ++++++
 trt/demo.py                              |  41 ++
 trt/eval_yolo_trt.py                     | 122 ++++++
 trt/onnx_to_trt.py                       | 199 ++++++++++
 trt/onnx_to_trt_partialquant.py          | 251 ++++++++++++
 trt/trt_dynamic/Processor_dynamic.py     | 375 ++++++++++++++++++
 trt/trt_dynamic/__init__.py              |   0
 trt/trt_dynamic/eval_yolo_trt_dynamic.py | 144 +++++++
 utils_quant/__init__.py                  |   0
 utils_quant/calib_test.sh                |  25 ++
 utils_quant/check_params.py              |  83 ++++
 utils_quant/onnxrt_demo.py               | 356 +++++++++++++++++
 utils_quant/print_model_structure.py     |  43 ++
 val.py                                   |   3 +-
 yolo_quant_flow.py                       | 475 +++++++++++++++++++++++
 26 files changed, 2802 insertions(+), 18 deletions(-)
 create mode 100644 data/hyps/hyp.qat.yaml
 create mode 100644 docker/build.sh
 create mode 100644 docker/launch.sh
 create mode 100644 export_qat.py
 create mode 100644 trt/Processor.py
 create mode 100644 trt/Visualizer.py
 create mode 100644 trt/__init__.py
 create mode 100644 trt/calibrator.py
 create mode 100644 trt/demo.py
 create mode 100644 trt/eval_yolo_trt.py
 create mode 100644 trt/onnx_to_trt.py
 create mode 100644 trt/onnx_to_trt_partialquant.py
 create mode 100644 trt/trt_dynamic/Processor_dynamic.py
 create mode 100644 trt/trt_dynamic/__init__.py
 create mode 100644 trt/trt_dynamic/eval_yolo_trt_dynamic.py
 create mode 100644 utils_quant/__init__.py
 create mode 100644 utils_quant/calib_test.sh
 create mode 100644 utils_quant/check_params.py
 create mode 100644 utils_quant/onnxrt_demo.py
 create mode 100644 utils_quant/print_model_structure.py
 create mode 100644 yolo_quant_flow.py

diff --git a/data/hyps/hyp.qat.yaml b/data/hyps/hyp.qat.yaml
new file mode 100644
index 000000000000..e69977f4492b
--- /dev/null
+++ b/data/hyps/hyp.qat.yaml
@@ -0,0 +1,33 @@
+# Hyperparameters for COCO training from scratch
+# python train.py --batch 40 --cfg yolov5m.yaml --weights '' --data coco.yaml --img 640 --epochs 300
+# See tutorials for hyperparameter evolution https://github.com/ultralytics/yolov5#tutorials
+
+
+lr0: 0.0001  # initial learning rate (SGD=1E-2, Adam=1E-3)
+lrf: 0.2  # final OneCycleLR learning rate (lr0 * lrf)
+momentum: 0.937  # SGD momentum/Adam beta1
+weight_decay: 0.0005  # optimizer weight decay 5e-4
+warmup_epochs: 0  # warmup epochs (fractions ok)
+warmup_momentum: 0.8  # warmup initial momentum
+warmup_bias_lr: 0.1  # warmup initial bias lr
+box: 0.05  # box loss gain
+cls: 0.5  # cls loss gain
+cls_pw: 1.0  # cls BCELoss positive_weight
+obj: 1.0  # obj loss gain (scale with pixels)
+obj_pw: 1.0  # obj BCELoss positive_weight
+iou_t: 0.20  # IoU training threshold
+anchor_t: 4.0  # anchor-multiple threshold
+# anchors: 3  # anchors per output layer (0 to ignore)
+fl_gamma: 0.0  # focal loss gamma (efficientDet default gamma=1.5)
+hsv_h: 0.015  # image HSV-Hue augmentation (fraction)
+hsv_s: 0.7  # image HSV-Saturation augmentation (fraction)
+hsv_v: 0.4  # image HSV-Value augmentation (fraction)
+degrees: 0.0  # image rotation (+/- deg)
+translate: 0.1  # image translation (+/- fraction)
+scale: 0.5  # image scale (+/- gain)
+shear: 0.0  # image shear (+/- deg)
+perspective: 0.0  # image perspective (+/- fraction), range 0-0.001
+flipud: 0.0  # image flip up-down (probability)
+fliplr: 0.5  # image flip left-right (probability)
+mosaic: 1.0  # image mosaic (probability)
+mixup: 0.0  # image mixup (probability)
diff --git a/docker/build.sh b/docker/build.sh
new file mode 100644
index 000000000000..5f96736981f3
--- /dev/null
+++ b/docker/build.sh
@@ -0,0 +1,2 @@
+#!/bin/bash
+docker build --network=host . --rm --pull --no-cache -t yolov5_quant
\ No newline at end of file
diff --git a/docker/launch.sh b/docker/launch.sh
new file mode 100644
index 000000000000..bce1e98917b6
--- /dev/null
+++ b/docker/launch.sh
@@ -0,0 +1,12 @@
+#!/bin/bash
+
+CMD=${1:-/bin/bash}
+NV_VISIBLE_DEVICES=${2:-"0"}
+DOCKER_BRIDGE=${3:-"host"}
+
+docker run -it --rm --name yolov5_quant -p 80:8888 \
+    --gpus device=$NV_VISIBLE_DEVICES \
+    --net=$DOCKER_BRIDGE \
+    --shm-size=16g \
+    -v $(dirname $(pwd)):/root/space/projects \
+    yolov5_quant $CMD
\ No newline at end of file
diff --git a/export_qat.py b/export_qat.py
new file mode 100644
index 000000000000..acf38aa6dee9
--- /dev/null
+++ b/export_qat.py
@@ -0,0 +1,66 @@
+"""Exports a YOLOv5 *.pt model to ONNX format
+
+Usage:
+    $ export PYTHONPATH="$PWD" && python export_qat.py --weights ./weights/yolov5s.pt --img-size 640 --batch-size 1
+"""
+
+import argparse
+import sys
+import time
+import warnings
+
+sys.path.append('./')  # to run '$ python *.py' files in subdirectories
+
+import torch
+import torch.nn as nn
+
+import models
+from models.experimental import attempt_load
+from utils.activations import Hardswish, SiLU
+from utils.general import set_logging
+from utils.torch_utils import select_device
+
+# To use Pytorch's own fake quantization functions
+from pytorch_quantization import nn as quant_nn
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--weights', type=str, default='./yolov5s.pt', help='weights path')  # from yolov5/models/
+    parser.add_argument('--img-size', type=int, default=640, help='image size')  # height, width
+    parser.add_argument('--batch-size', type=int, default=1, help='batch size')
+    parser.add_argument('--dynamic', action='store_true', help='dynamic ONNX axes')
+    parser.add_argument('--grid', action='store_true', help='export Detect() layer grid')
+    parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
+    opt = parser.parse_args()
+    print(opt)
+    set_logging()
+    t = time.time()
+
+    # Load PyTorch model
+    device = select_device(opt.device)
+    model = attempt_load(opt.weights, device=device)  # load FP32 model
+    model.eval()
+    quant_nn.TensorQuantizer.use_fb_fake_quant = True
+    model.model[-1].export = not opt.grid  # set Detect() layer grid export
+
+
+    dummy_input = torch.rand(opt.batch_size, 3, opt.img_size, opt.img_size, device=device)
+
+    # ONNX export
+    try:
+        import onnx
+
+        print('\nStarting ONNX export with onnx %s...'
% onnx.__version__) + f = opt.weights.replace('.pt', '.onnx') # filename + torch.onnx.export(model, dummy_input, f, verbose=False, opset_version=13, input_names=['images'], + output_names= ['output_0', 'output_1', 'output_2'], + dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'}} if opt.dynamic else None) + + # Checks + onnx_model = onnx.load(f) # load onnx model + onnx.checker.check_model(onnx_model) # check onnx model + # print(onnx.helper.printable_graph(onnx_model.graph)) # print a human readable model + print('ONNX export success, saved as %s' % f) + except Exception as e: + print('ONNX export failure: %s' % e) diff --git a/models/common.py b/models/common.py index b1c24ad378dc..5440bff18c04 100644 --- a/models/common.py +++ b/models/common.py @@ -32,6 +32,13 @@ from utils.plots import Annotator, colors, save_one_box from utils.torch_utils import copy_attr, smart_inference_mode +try: + from pytorch_quantization import nn as quant_nn +except ImportError: + raise ImportError( + "pytorch-quantization is not installed. Install from " + "https://github.com/NVIDIA/TensorRT/tree/master/tools/pytorch-quantization." + ) def autopad(k, p=None, d=1): # kernel, padding, dilation # Pad to 'same' shape outputs @@ -48,7 +55,8 @@ class Conv(nn.Module): def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True): super().__init__() - self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False) + #wjx self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False) + self.conv = quant_nn.QuantConv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False) self.bn = nn.BatchNorm2d(c2) self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity() @@ -206,7 +214,8 @@ def __init__(self, c1, c2, k=(5, 9, 13)): c_ = c1 // 2 # hidden channels self.cv1 = Conv(c1, c_, 1, 1) self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1) - self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k]) + #wjx self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k]) + self.m = nn.ModuleList([quant_nn.QuantMaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k]) def forward(self, x): x = self.cv1(x) diff --git a/models/experimental.py b/models/experimental.py index d60d1808da11..cfffaf4eb589 100644 --- a/models/experimental.py +++ b/models/experimental.py @@ -85,7 +85,11 @@ def attempt_load(weights, device=None, inplace=True, fuse=True): if hasattr(ckpt, 'names') and isinstance(ckpt.names, (list, tuple)): ckpt.names = dict(enumerate(ckpt.names)) # convert to dict - model.append(ckpt.fuse().eval() if fuse and hasattr(ckpt, 'fuse') else ckpt.eval()) # model in eval mode + #model.append(ckpt.fuse().eval() if fuse and hasattr(ckpt, 'fuse') else ckpt.eval()) # model in eval mode + # Modified by maggie. + # 1. Since we benchmark the speed using TensorRT backend, so it is not necesary to fuse. + # 2. 
If fuse, the fuse_conv_and_bn function will be called, then the quant_nn.QuantConv2d will be replace by noraml Conv2d wjx + model.append(ckpt.eval()) # Module compatibility updates for m in model.modules(): diff --git a/models/yolo.py b/models/yolo.py index 18d2542bfb48..7227ed61c7ad 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -28,7 +28,7 @@ from utils.plots import feature_visualization from utils.torch_utils import (fuse_conv_and_bn, initialize_weights, model_info, profile, scale_img, select_device, time_sync) - +from pytorch_quantization import nn as quant_nn try: import thop # for FLOPs computation except ImportError: @@ -50,7 +50,8 @@ def __init__(self, nc=80, anchors=(), ch=(), inplace=True): # detection layer self.grid = [torch.empty(0) for _ in range(self.nl)] # init grid self.anchor_grid = [torch.empty(0) for _ in range(self.nl)] # init anchor grid self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2)) # shape(nl,na,2) - self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv + #self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv + self.m = nn.ModuleList(quant_nn.QuantConv2d(x, self.no * self.na, 1) for x in ch) # output conv wjx self.inplace = inplace # use inplace ops (e.g. slice assignment) def forward(self, x): diff --git a/train.py b/train.py index 216da6399028..42aa1bfa36a8 100644 --- a/train.py +++ b/train.py @@ -118,18 +118,21 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio # Model check_suffix(weights, '.pt') # check weights pretrained = weights.endswith('.pt') - if pretrained: - with torch_distributed_zero_first(LOCAL_RANK): - weights = attempt_download(weights) # download if not found locally - ckpt = torch.load(weights, map_location='cpu') # load checkpoint to CPU to avoid CUDA memory leak - model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create - exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else [] # exclude keys - csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32 - csd = intersect_dicts(csd, model.state_dict(), exclude=exclude) # intersect - model.load_state_dict(csd, strict=False) # load - LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}') # report + if model is None: # wjx + if pretrained: + with torch_distributed_zero_first(LOCAL_RANK): + weights = attempt_download(weights) # download if not found locally + ckpt = torch.load(weights, map_location='cpu') # load checkpoint to CPU to avoid CUDA memory leak + model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create + exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else [] # exclude keys + csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32 + csd = intersect_dicts(csd, model.state_dict(), exclude=exclude) # intersect + model.load_state_dict(csd, strict=False) # load + LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}') # report + else: + model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create else: - model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create + pretrained = False # For QAT finetuning wjx amp = check_amp(model) # check AMP # Freeze @@ -415,7 +418,9 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio data_dict, batch_size=batch_size // WORLD_SIZE * 2, 
imgsz=imgsz, - model=attempt_load(f, device).half(), + #model=attempt_load(f, device).half(), + # During QAT finetuning, close the half recision wjx + model=attempt_load(f, device), iou_thres=0.65 if is_coco else 0.60, # best pycocotools at iou 0.65 single_cls=single_cls, dataloader=val_loader, diff --git a/trt/Processor.py b/trt/Processor.py new file mode 100644 index 000000000000..42508d0cd28b --- /dev/null +++ b/trt/Processor.py @@ -0,0 +1,367 @@ +import cv2 +import sys +import tensorrt as trt +import pycuda.driver as cuda +import pycuda.autoinit # Necessary, to enable pycuda +import numpy as np +import math +import time + +sys.path.append('./') # to run '$ python *.py' files in subdirectories +import torch +import torchvision + + +class HostDeviceMem(object): + """Simple helper data class that's a little nicer to use than a 2-tuple.""" + def __init__(self, host_mem, device_mem): + self.host = host_mem + self.device = device_mem + + def __str__(self): + return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device) + + def __repr__(self): + return self.__str__() + + +def get_input_shape(engine): + """Get input shape of the TensorRT YOLO engine.""" + binding = engine[0] + assert engine.binding_is_input(binding) + binding_dims = engine.get_binding_shape(binding) + if len(binding_dims) == 4: + return tuple(binding_dims[2:]) + elif len(binding_dims) == 3: + return tuple(binding_dims[1:]) + else: + raise ValueError('bad dims of binding %s: %s' % (binding, str(binding_dims))) + +class Processor(): + def __init__(self, model, category_num=80, letter_box=False): + # load tensorrt engine + TRT_LOGGER = trt.Logger(trt.Logger.INFO) + print('TRT model path: ', model) + with open(model, 'rb') as f, trt.Runtime(TRT_LOGGER) as runtime: + engine = runtime.deserialize_cuda_engine(f.read()) + self.input_shape = get_input_shape(engine) + self.context = engine.create_execution_context() + + # Allocates all host/device in/out buffers required for an engine. + inputs = [] + outputs = [] + bindings = [] + stream = cuda.Stream() + for binding in engine: + size = trt.volume(engine.get_binding_shape(binding)) + dtype = trt.nptype(engine.get_binding_dtype(binding)) + host_mem = cuda.pagelocked_empty(size, dtype) + device_mem = cuda.mem_alloc(host_mem.nbytes) + # Append the device buffer to device bindings. + bindings.append(int(device_mem)) + # Append to the appropriate list. 
+ if engine.binding_is_input(binding): + inputs.append(HostDeviceMem(host_mem, device_mem)) + else: + outputs.append(HostDeviceMem(host_mem, device_mem)) + + # save to class + self.inputs = inputs + self.outputs = outputs + self.bindings = bindings + self.stream = stream + self.letter_box = letter_box + # post processing config + filters = (category_num + 5) * 3 + self.output_shapes = [ + (1, 3, 80, 80, 85), + (1, 3, 40, 40, 85), + (1, 3, 20, 20, 85) + ] + self.strides = np.array([8., 16., 32.]) + anchors = np.array([ + [[10,13], [16,30], [33,23]], + [[30,61], [62,45], [59,119]], + [[116,90], [156,198], [373,326]], + ]) + self.nl = len(anchors) + self.nc = category_num # classes + self.no = self.nc + 5 # outputs per anchor + self.na = len(anchors[0]) + a = anchors.copy().astype(np.float32) + a = a.reshape(self.nl, -1, 2) + self.anchors = a.copy() + self.anchor_grid = a.copy().reshape(self.nl, 1, -1, 1, 1, 2) + + + def detect(self, img, letter_box=None): + """Detect objects in the input image.""" + letter_box = self.letter_box if letter_box is None else letter_box + resized = self.pre_process(img, self.input_shape, letter_box) + + outputs = self.inference(resized) + # reshape from flat to (1, 3, x, y, 85) + reshaped = [] + for output, shape in zip(outputs, self.output_shapes): + reshaped.append(output.reshape(shape)) + return reshaped + + + def pre_process(self, img, input_shape, letter_box=False): + """Preprocess an image before TRT YOLO inferencing. + """ + if letter_box: + img_h, img_w, _ = img.shape + new_h, new_w = input_shape[0], input_shape[1] + offset_h, offset_w = 0, 0 + if (new_w / img_w) <= (new_h / img_h): + new_h = int(img_h * new_w / img_w) + offset_h = (input_shape[0] - new_h) // 2 + else: + new_w = int(img_w * new_h / img_h) + offset_w = (input_shape[1] - new_w) // 2 + resized = cv2.resize(img, (new_w, new_h)) + img = np.full((input_shape[0], input_shape[1], 3), 114, dtype=np.uint8) + img[offset_h:(offset_h + new_h), offset_w:(offset_w + new_w), :] = resized + else: + img = cv2.resize(img, (input_shape[1], input_shape[0])) + + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = img.transpose((2, 0, 1)).astype(np.float32) + img /= 255.0 + return img + + def inference(self, img): + img = img[np.newaxis, :, :, :] + + # Set host input to the image. The do_inference() function + # will copy the input to the GPU before executing. + self.inputs[0].host = np.ascontiguousarray(img) + + # Transfer input data to the GPU. + [cuda.memcpy_htod_async(inp.device, inp.host, self.stream) for inp in self.inputs] + # Run inference. + self.context.execute_async_v2(bindings=self.bindings, stream_handle=self.stream.handle) + # Transfer predictions back from the GPU. + [cuda.memcpy_dtoh_async(out.host, out.device, self.stream) for out in self.outputs] + # Synchronize the stream + self.stream.synchronize() + # Return only the host outputs. 
+ return [out.host for out in self.outputs] + + def post_process(self, outputs, img_shape, conf_thres=0.5, iou_thres=0.6): + """ + Transforms raw output into boxes, confs, classes + Applies NMS thresholding on bounding boxes and confs + Parameters: + output: raw output tensor + Returns: + boxes: x1,y1,x2,y2 tensor (dets, 4) + confs: class * obj prob tensor (dets, 1) + classes: class type tensor (dets, 1) + """ + scaled = [] + grids = [] + for out in outputs: + # print('out.shape: ', out.shape) + # print('out: ', out) + out = self.sigmoid_v(out) + # print('sigmoid_v out.shape: ', out.shape) + # print('sigmoid_v out: ', out) + _, _, width, height, _ = out.shape + grid = self.make_grid(width, height) + grids.append(grid) + scaled.append(out) + z = [] + for out, grid, stride, anchor in zip(scaled, grids, self.strides, self.anchor_grid): + _, _, width, height, _ = out.shape + out[..., 0:2] = (out[..., 0:2] * 2. - 0.5 + grid) * stride + out[..., 2:4] = (out[..., 2:4] * 2) ** 2 * anchor + + out = out.reshape((1, 3 * width * height, 85)) + z.append(out) + + pred = np.concatenate(z, 1) + + # Use Pytorch to do the post-process + det_t = self.non_max_suppression(torch.from_numpy(pred), conf_thres=conf_thres, + iou_thres=iou_thres, multi_label=True)[0] + + self.scale_coords(self.input_shape, det_t[:, :4], img_shape) + return det_t + + + def make_grid(self, nx, ny): + """ + Create scaling tensor based on box location + Source: https://github.com/ultralytics/yolov5/blob/master/models/yolo.py + Arguments + nx: x-axis num boxes + ny: y-axis num boxes + Returns + grid: tensor of shape (1, 1, nx, ny, 80) + """ + nx_vec = np.arange(nx) + ny_vec = np.arange(ny) + yv, xv = np.meshgrid(ny_vec, nx_vec) + grid = np.stack((yv, xv), axis=2) + grid = grid.reshape(1, 1, ny, nx, 2) + return grid + + + def sigmoid(self, x): + return 1 / (1 + math.exp(-x)) + + + def sigmoid_v(self, array): + return np.reciprocal(np.exp(-array) + 1.0) + + + def xywh2xyxy(self, x): + # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right + y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) + y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x + y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y + y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x + y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y + return y + + + def clip_coords(self, boxes, img_shape): + # Clip bounding xyxy bounding boxes to image shape (height, width) + boxes[:, 0].clamp_(0, img_shape[1]) # x1 + boxes[:, 1].clamp_(0, img_shape[0]) # y1 + boxes[:, 2].clamp_(0, img_shape[1]) # x2 + boxes[:, 3].clamp_(0, img_shape[0]) # y2 + + + def scale_coords(self, img1_shape, coords, img0_shape, ratio_pad=None): + # Rescale coords (xyxy) from img1_shape to img0_shape + if ratio_pad is None: # calculate from img0_shape + gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new + pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding + else: + gain = ratio_pad[0][0] + pad = ratio_pad[1] + + coords[:, [0, 2]] -= pad[0] # x padding + coords[:, [1, 3]] -= pad[1] # y padding + coords[:, :4] /= gain + self.clip_coords(coords, img0_shape) + return coords + + def non_max_suppression(self, prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, + multi_label=False, + labels=()): + """Runs Non-Maximum Suppression (NMS) on inference results + + Returns: + list of detections, on (n,6) tensor per image [xyxy, conf, cls] + """ + + nc = prediction.shape[2] - 
5 # number of classes + xc = prediction[..., 4] > conf_thres # candidates + + # Settings + min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height + max_det = 300 # maximum number of detections per image + max_nms = 30000 # maximum number of boxes into torchvision.ops.nms() + time_limit = 10.0 # seconds to quit after + redundant = True # require redundant detections + multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img) + merge = False # use merge-NMS + + t = time.time() + output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0] + for xi, x in enumerate(prediction): # image index, image inference + # Apply constraints + # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height + x = x[xc[xi]] # confidence + + # Cat apriori labels if autolabelling + if labels and len(labels[xi]): + l = labels[xi] + v = torch.zeros((len(l), nc + 5), device=x.device) + v[:, :4] = l[:, 1:5] # box + v[:, 4] = 1.0 # conf + v[range(len(l)), l[:, 0].long() + 5] = 1.0 # cls + x = torch.cat((x, v), 0) + + # If none remain process next image + if not x.shape[0]: + continue + + # Compute conf + x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf + + # Box (center x, center y, width, height) to (x1, y1, x2, y2) + box = self.xywh2xyxy(x[:, :4]) + + # Detections matrix nx6 (xyxy, conf, cls) + if multi_label: + i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T + x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1) + else: # best class only + conf, j = x[:, 5:].max(1, keepdim=True) + x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres] + + # Filter by class + if classes is not None: + x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)] + + # Apply finite constraint + # if not torch.isfinite(x).all(): + # x = x[torch.isfinite(x).all(1)] + + # Check shape + n = x.shape[0] # number of boxes + if not n: # no boxes + continue + elif n > max_nms: # excess boxes + x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence + + # Batched NMS + c = x[:, 5:6] * (0 if agnostic else max_wh) # classes + boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores + i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS + if i.shape[0] > max_det: # limit detections + i = i[:max_det] + if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean) + # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4) + iou = self.box_iou(boxes[i], boxes) > iou_thres # iou matrix + weights = iou * scores[None] # box weights + x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes + if redundant: + i = i[iou.sum(1) > 1] # require redundancy + + output[xi] = x[i] + if (time.time() - t) > time_limit: + print(f'WARNING: NMS time limit {time_limit}s exceeded') + break # time limit exceeded + + return output + + def box_iou(self, box1, box2): + # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py + """ + Return intersection-over-union (Jaccard index) of boxes. + Both sets of boxes are expected to be in (x1, y1, x2, y2) format. 
+ Arguments: + box1 (Tensor[N, 4]) + box2 (Tensor[M, 4]) + Returns: + iou (Tensor[N, M]): the NxM matrix containing the pairwise + IoU values for every element in boxes1 and boxes2 + """ + + def box_area(box): + # box = 4xn + return (box[2] - box[0]) * (box[3] - box[1]) + + area1 = box_area(box1.T) + area2 = box_area(box2.T) + + # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2) + inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2) + return inter / (area1[:, None] + area2 - inter) # iou = inter / (area1 + area2 - inter) + diff --git a/trt/Visualizer.py b/trt/Visualizer.py new file mode 100644 index 000000000000..8048cf4b8e9d --- /dev/null +++ b/trt/Visualizer.py @@ -0,0 +1,57 @@ +import cv2 +import random +import colorsys +import numpy as np +import matplotlib.pyplot as plt + +coco = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', + 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', + 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', + 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', + 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', + 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', + 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', + 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', + 'hair drier', 'toothbrush'] + +class Visualizer(): + def __init__(self): + self.color_list = self.gen_colors(coco) + + def gen_colors(self, classes): + """ + generate unique hues for each class and convert to bgr + classes -- list -- class names (80 for coco dataset) + -> list + """ + hsvs = [] + for x in range(len(classes)): + hsvs.append([float(x) / len(classes), 1., 0.7]) + random.seed(1234) + random.shuffle(hsvs) + rgbs = [] + for hsv in hsvs: + h, s, v = hsv + rgb = colorsys.hsv_to_rgb(h, s, v) + rgbs.append(rgb) + + bgrs = [] + for rgb in rgbs: + bgr = (int(rgb[2] * 255), int(rgb[1] * 255), int(rgb[0] * 255)) + bgrs.append(bgr) + return bgrs + + + def draw_results(self, img, boxes, confs, classes): + overlay = img.copy() + final = img.copy() + for box, conf, cls in zip(boxes, confs, classes): + # draw rectangle + x1, y1, x2, y2 = box + color = self.color_list[int(cls)] + cv2.rectangle(overlay, (int(x1), int(y1)), (int(x2), int(y2)), color, -1) + + cv2.addWeighted(overlay, 0.5, final, 1 - 0.5, 0, final) + + cv2.imwrite('./box_grid.jpg', final) + return final diff --git a/trt/__init__.py b/trt/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/trt/calibrator.py b/trt/calibrator.py new file mode 100644 index 000000000000..32ee26ead71b --- /dev/null +++ b/trt/calibrator.py @@ -0,0 +1,113 @@ +import os +import tensorrt as trt +import pycuda.driver as cuda +import pycuda.autoinit +import numpy as np +import cv2 +import glob + +import ctypes +import logging +logger = logging.getLogger(__name__) +ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_char_p +ctypes.pythonapi.PyCapsule_GetPointer.argtypes = [ctypes.py_object, ctypes.c_char_p] + + +""" +There are 4 types calibrator in TensorRT. 
+trt.IInt8LegacyCalibrator +trt.IInt8EntropyCalibrator +trt.IInt8EntropyCalibrator2 +trt.IInt8MinMaxCalibrator +""" + +class Calibrator(trt.IInt8MinMaxCalibrator): + def __init__(self, stream, cache_file=""): + trt.IInt8MinMaxCalibrator.__init__(self) + self.stream = stream + self.d_input = cuda.mem_alloc(self.stream.calibration_data.nbytes) + self.cache_file = cache_file + stream.reset() + + def get_batch_size(self): + return self.stream.batch_size + + def get_batch(self, names): + print("######################") + print(names) + print("######################") + batch = self.stream.next_batch() + if not batch.size: + return None + + cuda.memcpy_htod(self.d_input, batch) + return [int(self.d_input)] + + def read_calibration_cache(self): + # If there is a cache, use it instead of calibrating again. Otherwise, implicitly return None. + if os.path.exists(self.cache_file): + with open(self.cache_file, "rb") as f: + logger.info("Using calibration cache to save time: {:}".format(self.cache_file)) + return f.read() + + def write_calibration_cache(self, cache): + with open(self.cache_file, "wb") as f: + logger.info("Caching calibration data for future use: {:}".format(self.cache_file)) + f.write(cache) + +def preprocess(img, input_shape, letter_box=False): + """Preprocess an image before TRT YOLO inferencing. + """ + if letter_box: + img_h, img_w, _ = img.shape + new_h, new_w = input_shape[0], input_shape[1] + offset_h, offset_w = 0, 0 + if (new_w / img_w) <= (new_h / img_h): + new_h = int(img_h * new_w / img_w) + offset_h = (input_shape[0] - new_h) // 2 + else: + new_w = int(img_w * new_h / img_h) + offset_w = (input_shape[1] - new_w) // 2 + resized = cv2.resize(img, (new_w, new_h)) + img = np.full((input_shape[0], input_shape[1], 3), 114, dtype=np.uint8) + img[offset_h:(offset_h + new_h), offset_w:(offset_w + new_w), :] = resized + else: + img = cv2.resize(img, (input_shape[1], input_shape[0])) + + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = img.transpose((2, 0, 1)).astype(np.float32) + img /= 255.0 + return img + +class DataLoader: + def __init__(self, batch_size, batch_num, calib_img_dir, input_w, input_h): + self.index = 0 + self.length = batch_num + self.batch_size = batch_size + self.input_h = input_h + self.input_w = input_w + # self.img_list = [i.strip() for i in open('calib.txt').readlines()] + self.img_list = glob.glob(os.path.join(calib_img_dir, "*.jpg")) + assert len(self.img_list) > self.batch_size * self.length, \ + '{} must contains more than '.format(calib_img_dir) + str(self.batch_size * self.length) + ' images to calib' + print('found all {} images to calib.'.format(len(self.img_list))) + self.calibration_data = np.zeros((self.batch_size, 3, input_h, input_w), dtype=np.float32) + + def reset(self): + self.index = 0 + + def next_batch(self): + if self.index < self.length: + for i in range(self.batch_size): + assert os.path.exists(self.img_list[i + self.index * self.batch_size]), 'not found!!' 
+ img = cv2.imread(self.img_list[i + self.index * self.batch_size]) + img = preprocess(img, (self.input_h, self.input_w, 3), letter_box=True) + self.calibration_data[i] = img + + self.index += 1 + return np.ascontiguousarray(self.calibration_data, dtype=np.float32) + else: + return np.array([]) + + def __len__(self): + return self.length \ No newline at end of file diff --git a/trt/demo.py b/trt/demo.py new file mode 100644 index 000000000000..a83deb406abe --- /dev/null +++ b/trt/demo.py @@ -0,0 +1,41 @@ +import cv2 +import sys +import argparse + +from Processor import Processor +from Visualizer import Visualizer + +def cli(): + desc = 'Run TensorRT yolov5 visualizer' + parser = argparse.ArgumentParser(description=desc) + parser.add_argument('-m', '--model', default='./weights/yolov5s-simple.trt', help='trt engine file path', required=False) + parser.add_argument('-i', '--image', default='./data/images/bus.jpg', help='image file path', required=False) + args = parser.parse_args() + return args + +def main(): + # parse arguments + args = cli() + + # setup processor and visualizer + processor = Processor(model=args.model, letter_box=True) + visualizer = Visualizer() + + img = cv2.imread(args.image) + + # inference + output = processor.detect(img) + + # final results + pred = processor.post_process(output, img.shape, conf_thres=0.5) + + print('Detection result: ') + for item in pred.tolist(): + print(item) + + visualizer.draw_results(img, pred[:, :4], pred[:, 4], pred[:, 5]) + + + +if __name__ == '__main__': + main() diff --git a/trt/eval_yolo_trt.py b/trt/eval_yolo_trt.py new file mode 100644 index 000000000000..85aa818c3ad2 --- /dev/null +++ b/trt/eval_yolo_trt.py @@ -0,0 +1,122 @@ +"""eval_yolo.py + +This script is for evaluating mAP (accuracy) of YOLO models. 
+""" +import os +import sys +import json +import argparse + +sys.path.append('./') # to run '$ python *.py' files in subdirectories + +import cv2 +import torch + +from pycocotools.coco import COCO +from pycocotools.cocoeval import COCOeval + +from Processor import Processor +from utils.general import coco80_to_coco91_class + +# converts 80-index (val2014) to 91-index (paper) +coco91class = coco80_to_coco91_class() + +VAL_IMGS_DIR = '../coco/images/val2017' +VAL_ANNOTATIONS = '../coco/annotations/instances_val2017.json' + +def parse_args(): + """Parse input arguments.""" + desc = 'Evaluate mAP of YOLO TRT model' + parser = argparse.ArgumentParser(description=desc) + parser.add_argument( + '--imgs-dir', type=str, default=VAL_IMGS_DIR, + help='directory of validation images [%s]' % VAL_IMGS_DIR) + parser.add_argument( + '--annotations', type=str, default=VAL_ANNOTATIONS, + help='groundtruth annotations [%s]' % VAL_ANNOTATIONS) + parser.add_argument( + '-c', '--category-num', type=int, default=80, + help='number of object categories [80]') + parser.add_argument( + '--img-size', nargs='+', type=int, default=[640, 640], help='image size') + parser.add_argument( + '-m', '--model', type=str, default='./weights/yolov5s-simple.trt', + help=('trt model path')) + parser.add_argument( + '-l', '--letter_box', action='store_true', + help='inference with letterboxed image [False]') + parser.add_argument( + '--conf-thres', type=float, default=0.001, + help='object confidence threshold') + parser.add_argument( + '--iou-thres', type=float, default=0.6, + help='IOU threshold for NMS') + args = parser.parse_args() + return args + + +def check_args(args): + """Check and make sure command-line arguments are valid.""" + if not os.path.isdir(args.imgs_dir): + sys.exit('%s is not a valid directory' % args.imgs_dir) + if not os.path.isfile(args.annotations): + sys.exit('%s is not a valid file' % args.annotations) + + +def generate_results(processor, imgs_dir, jpgs, results_file, conf_thres, iou_thres, non_coco): + """Run detection on each jpg and write results to file.""" + results = [] + + i = 0 + for jpg in jpgs: + i+=1 + if(i%100 == 0): + print('Processing {} images'.format(i)) + img = cv2.imread(os.path.join(imgs_dir, jpg)) + image_id = int(jpg.split('.')[0].split('_')[-1]) + output = processor.detect(img) + + pred = processor.post_process(output, img.shape, conf_thres=conf_thres, + iou_thres=iou_thres) + for p in pred.tolist(): + x = float(p[0]) + y = float(p[1]) + w = float(p[2] - p[0]) + h = float(p[3] - p[1]) + results.append({'image_id': image_id, + 'category_id': coco91class[int(p[5])] if not non_coco else int(p[5]), + 'bbox': [round(x, 3) for x in [x, y, w, h]], + 'score': round(p[4], 5)}) + + with open(results_file, 'w') as f: + f.write(json.dumps(results, indent=4)) + + +def main(): + args = parse_args() + check_args(args) + + model_prefix = args.model.replace('.trt', '').split('/')[-1] + results_file = 'weights/results_{}.json'.format(model_prefix) + + # setup processor + processor = Processor(model=args.model, letter_box=True) + + jpgs = [j for j in os.listdir(args.imgs_dir) if j.endswith('.jpg')] + generate_results(processor, args.imgs_dir, jpgs, results_file, args.conf_thres, args.iou_thres, + non_coco=False) + + # Run COCO mAP evaluation + # Reference: https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb + cocoGt = COCO(args.annotations) + cocoDt = cocoGt.loadRes(results_file) + imgIds = sorted(cocoGt.getImgIds()) + cocoEval = COCOeval(cocoGt, cocoDt, 'bbox') + 
cocoEval.params.imgIds = imgIds + cocoEval.evaluate() + cocoEval.accumulate() + cocoEval.summarize() + + +if __name__ == '__main__': + main() diff --git a/trt/onnx_to_trt.py b/trt/onnx_to_trt.py new file mode 100644 index 000000000000..7a5bd75c8072 --- /dev/null +++ b/trt/onnx_to_trt.py @@ -0,0 +1,199 @@ +# onnx_to_tensorrt.py +# +# Copyright 1993-2019 NVIDIA Corporation. All rights reserved. +# +# NOTICE TO LICENSEE: +# +# This source code and/or documentation ("Licensed Deliverables") are +# subject to NVIDIA intellectual property rights under U.S. and +# international Copyright laws. +# +# These Licensed Deliverables contained herein is PROPRIETARY and +# CONFIDENTIAL to NVIDIA and is being provided under the terms and +# conditions of a form of NVIDIA software license agreement by and +# between NVIDIA and Licensee ("License Agreement") or electronically +# accepted by Licensee. Notwithstanding any terms or conditions to +# the contrary in the License Agreement, reproduction or disclosure +# of the Licensed Deliverables to any third party without the express +# written consent of NVIDIA is prohibited. +# +# NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE +# LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE +# SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS +# PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND. +# NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED +# DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY, +# NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE. +# NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE +# LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY +# SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY +# DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, +# WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS +# ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE +# OF THESE LICENSED DELIVERABLES. +# +# U.S. Government End Users. These Licensed Deliverables are a +# "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT +# 1995), consisting of "commercial computer software" and "commercial +# computer software documentation" as such terms are used in 48 +# C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government +# only as a commercial end item. Consistent with 48 C.F.R.12.212 and +# 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all +# U.S. Government End Users acquire the Licensed Deliverables with +# only those rights set forth herein. +# +# Any use of the Licensed Deliverables in individual and commercial +# software must include, in the user documentation and internal +# comments to the code, the above Disclaimer and U.S. Government End +# Users Notice. 
+# + + +from __future__ import print_function + +import argparse +import traceback +import sys +import tensorrt as trt + +sys.path.append('./') # to run '$ python *.py' files in subdirectories +from trt.calibrator import DataLoader, Calibrator + +MAX_BATCH_SIZE = 1 + +def build_engine_from_onnx(model_name, + dtype, + verbose=False, + int8_calib=False, + calib_loader=None, + calib_cache=None, + dynamic_shape=False, + fp32_layer_names=[], + fp16_layer_names=[], + ): + """Initialization routine.""" + if dtype == "int8": + t_dtype = trt.DataType.INT8 + elif dtype == "fp16": + t_dtype = trt.DataType.HALF + elif dtype == "fp32": + t_dtype = trt.DataType.FLOAT + else: + raise ValueError("Unsupported data type: %s" % dtype) + + if trt.__version__[0] < '8': + print('Exit, trt.version should be >=8. Now your trt version is ', trt.__version__[0]) + + network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) + if dtype == "int8" and calib_loader is None: + print('QAT enabled!') + network_flags = network_flags | (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_PRECISION)) + + """Build a TensorRT engine from ONNX""" + TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE) if verbose else trt.Logger() + with trt.Builder(TRT_LOGGER) as builder, builder.create_network(flags=network_flags) as network, \ + trt.OnnxParser(network, TRT_LOGGER) as parser: + with open(model_name, 'rb') as model: + if not parser.parse(model.read()): + print('ERROR: ONNX Parse Failed') + for error in range(parser.num_errors): + print(parser.get_error(error)) + return None + + print('Building an engine. This would take a while...') + print('(Use "--verbose" or "-v" to enable verbose logging.)') + config = builder.create_builder_config() + config.max_workspace_size = 2 << 30 + if t_dtype == trt.DataType.HALF: + config.flags |= 1 << int(trt.BuilderFlag.FP16) + + if t_dtype == trt.DataType.INT8: + print('trt.DataType.INT8') + config.flags |= 1 << int(trt.BuilderFlag.INT8) + config.flags |= 1 << int(trt.BuilderFlag.FP16) + + if int8_calib: + config.int8_calibrator = Calibrator(calib_loader, calib_cache) + print('Int8 calibation is enabled.') + + if dynamic_shape: + # You can adjust the shape setting according to the actual situation + profile = builder.create_optimization_profile() + profile.set_shape("images", (1, 3, 640, 640), (8, 3, 640, 640), (16, 3, 640, 640)) + config.add_optimization_profile(profile) + + engine = builder.build_engine(network, config) + + try: + assert engine + except AssertionError: + _, _, tb = sys.exc_info() + traceback.print_tb(tb) # Fixed format + tb_info = traceback.extract_tb(tb) + _, line, _, text = tb_info[-1] + raise AssertionError( + "Parsing failed on line {} in statement {}".format(line, text) + ) + + return engine + + +def main(): + """Create a TensorRT engine for ONNX-based YOLO.""" + parser = argparse.ArgumentParser() + parser.add_argument( + '-v', '--verbose', action='store_true', + help='enable verbose output (for debugging)') + parser.add_argument( + '-m', '--model', type=str, required=True, + help=('onnx model path')) + parser.add_argument( + '-d', '--dtype', type=str, required=True, + help='one type of int8, fp16, fp32') + parser.add_argument('--dynamic-shape', action='store_true', + help='Dynamic shape, defer specifying some or all tensor dimensions until runtime') + parser.add_argument( + '--qat', action='store_true', + help='whether the onnx model is qat; if it is, the int8 calibrator is not needed') + # If enable int8(not post-QAT model), then set the following + 
parser.add_argument('--img-size', type=int, + default=640, help='image size of model input') + parser.add_argument('--batch-size', type=int, + default=128, help='batch size for training: default 64') + parser.add_argument('--num-calib-batch', default=6, type=int, + help='Number of batches for calibration') + parser.add_argument('--calib-img-dir', default='../coco/images/train2017', type=str, + help='Number of batches for calibration') + parser.add_argument('--calib-cache', default='./trt/yolov5s_calibration.cache', type=str, + help='Path of calibration cache') + parser.add_argument('--calib-method', default='minmax', type=str, + help='Calibration method') + + args = parser.parse_args() + + + if args.dtype == "int8" and not args.qat: + calib_loader = DataLoader(args.batch_size, args.num_calib_batch, args.calib_img_dir, + args.img_size, args.img_size) + engine = build_engine_from_onnx(args.model, args.dtype, args.verbose, + int8_calib=True, calib_loader=calib_loader, calib_cache=args.calib_cache, + dynamic_shape=args.dynamic_shape) + else: + engine = build_engine_from_onnx(args.model, args.dtype, args.verbose, + dynamic_shape=args.dynamic_shape) + + if engine is None: + raise SystemExit('ERROR: failed to build the TensorRT engine!') + + engine_path = args.model.replace('.onnx', '.trt') + if args.dtype == "int8" and not args.qat: + engine_path = args.model.replace('.onnx', '-int8-{}-{}-{}.trt'.format(args.batch_size, args.num_calib_batch, + args.calib_method)) + + with open(engine_path, 'wb') as f: + f.write(engine.serialize()) + print('Serialized the TensorRT engine to file: %s' % engine_path) + + +if __name__ == '__main__': + main() diff --git a/trt/onnx_to_trt_partialquant.py b/trt/onnx_to_trt_partialquant.py new file mode 100644 index 000000000000..9d2120cc6fc6 --- /dev/null +++ b/trt/onnx_to_trt_partialquant.py @@ -0,0 +1,251 @@ +# onnx_to_tensorrt.py +# +# Copyright 1993-2019 NVIDIA Corporation. All rights reserved. +# +# NOTICE TO LICENSEE: +# +# This source code and/or documentation ("Licensed Deliverables") are +# subject to NVIDIA intellectual property rights under U.S. and +# international Copyright laws. +# +# These Licensed Deliverables contained herein is PROPRIETARY and +# CONFIDENTIAL to NVIDIA and is being provided under the terms and +# conditions of a form of NVIDIA software license agreement by and +# between NVIDIA and Licensee ("License Agreement") or electronically +# accepted by Licensee. Notwithstanding any terms or conditions to +# the contrary in the License Agreement, reproduction or disclosure +# of the Licensed Deliverables to any third party without the express +# written consent of NVIDIA is prohibited. +# +# NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE +# LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE +# SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS +# PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND. +# NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED +# DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY, +# NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE. 
+# NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE +# LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY +# SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY +# DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, +# WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS +# ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE +# OF THESE LICENSED DELIVERABLES. +# +# U.S. Government End Users. These Licensed Deliverables are a +# "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT +# 1995), consisting of "commercial computer software" and "commercial +# computer software documentation" as such terms are used in 48 +# C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government +# only as a commercial end item. Consistent with 48 C.F.R.12.212 and +# 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all +# U.S. Government End Users acquire the Licensed Deliverables with +# only those rights set forth herein. +# +# Any use of the Licensed Deliverables in individual and commercial +# software must include, in the user documentation and internal +# comments to the code, the above Disclaimer and U.S. Government End +# Users Notice. +# + + +from __future__ import print_function + +import argparse +import traceback +import sys +import tensorrt as trt + +sys.path.append('./') # to run '$ python *.py' files in subdirectories +from trt.calibrator import DataLoader, Calibrator + +MAX_BATCH_SIZE = 1 + +def _set_excluded_layer_id_precision(network, fp32_layer_ids, fp16_layer_ids): + """ + Step2: setting the sensitive layer to FP32/FP16 + When generating an INT8 model, it sets excluded layers' precision as fp32 or fp16. + + In detail, this function is only used when generating INT8 TensorRT models. It accepts + two lists of layer ids: (1). for the layers in fp32_layer_ids, their precision will + be set as fp32; (2). for those in fp16_layer_ids, their precision will be set as fp16. + + Args: + network: TensorRT network object. + fp32_layer_ids (list): List of layer ids. These layers use fp32. + fp16_layer_ids (list): List of layer ids. These layers use fp16. + """ + is_mixed_precision = False + use_fp16_mode = False + + for layer_idx in range(network.num_layers): + layer = network.get_layer(layer_idx) + if layer_idx in fp32_layer_ids: + is_mixed_precision = True + layer.precision = trt.float32 + layer.set_output_type(0, trt.float32) + elif layer_idx in fp16_layer_ids: + is_mixed_precision = True + use_fp16_mode = True + layer.precision = trt.float16 + layer.set_output_type(0, trt.float16) + else: + layer.precision = trt.int8 + layer.set_output_type(0, trt.int8) + + return network, is_mixed_precision, use_fp16_mode + +def build_engine_from_onnx(model_name, + dtype, + verbose=False, + int8_calib=False, + calib_loader=None, + calib_cache=None, + fp32_layer_ids=[], + fp16_layer_ids=[], + ): + """Initialization routine.""" + if dtype == "int8": + t_dtype = trt.DataType.INT8 + elif dtype == "fp16": + t_dtype = trt.DataType.HALF + elif dtype == "fp32": + t_dtype = trt.DataType.FLOAT + else: + raise ValueError("Unsupported data type: %s" % dtype) + + if trt.__version__[0] < '8': + print('Exit, trt.version should be >=8. 
Now your trt version is ', trt.__version__[0]) + + network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) + if dtype == "int8" and calib_loader is None: + print('QAT enabled!') + network_flags = network_flags | (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_PRECISION)) + + """Build a TensorRT engine from ONNX""" + TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE) if verbose else trt.Logger() + with trt.Builder(TRT_LOGGER) as builder, builder.create_network(flags=network_flags) as network, \ + trt.OnnxParser(network, TRT_LOGGER) as parser: + with open(model_name, 'rb') as model: + if not parser.parse(model.read()): + print('ERROR: ONNX Parse Failed') + for error in range(parser.num_errors): + print(parser.get_error(error)) + return None + + print('Building an engine. This would take a while...') + print('(Use "--verbose" or "-v" to enable verbose logging.)') + config = builder.create_builder_config() + config.max_workspace_size = 2 << 30 + # config.set_flag(trt.BuilderFlag.GPU_FALLBACK) + if t_dtype == trt.DataType.HALF: + config.flags |= 1 << int(trt.BuilderFlag.FP16) + + if t_dtype == trt.DataType.INT8: + print('trt.DataType.INT8') + config.flags |= 1 << int(trt.BuilderFlag.INT8) + + if int8_calib: + config.int8_calibrator = Calibrator(calib_loader, calib_cache) + print('Int8 calibation is enabled.') + + ## Step1: Print layer name and id, for partial quantization + # layer_names = [] + # for layer_idx in range(network.num_layers): + # layer = network.get_layer(layer_idx) + # layer_names.append(layer.name) + # for index, layer_name in enumerate(layer_names): + # print(index, ' layer_name: ', layer_name) + + # When use mixed precision, for TensorRT builder: + # strict_type_constraints needs to be True; + # fp16_mode needs to be True if any layer uses fp16 precision. 
+ network, strict_type_constraints, fp16_mode = _set_excluded_layer_id_precision( + network=network, + fp32_layer_ids=fp32_layer_ids, + fp16_layer_ids=fp16_layer_ids, + ) + + if strict_type_constraints: + print('Set STRICT_TYPES') + config.flags |= 1 << int(trt.BuilderFlag.STRICT_TYPES) + + if fp16_mode: + print('Set fp16_mode') + config.flags |= 1 << int(trt.BuilderFlag.FP16) + + engine = builder.build_engine(network, config) + + try: + assert engine + except AssertionError: + _, _, tb = sys.exc_info() + traceback.print_tb(tb) # Fixed format + tb_info = traceback.extract_tb(tb) + _, line, _, text = tb_info[-1] + raise AssertionError( + "Parsing failed on line {} in statement {}".format(line, text) + ) + + return engine + + +def main(): + """Create a TensorRT engine for ONNX-based YOLO.""" + parser = argparse.ArgumentParser() + parser.add_argument( + '-v', '--verbose', action='store_true', + help='enable verbose output (for debugging)') + parser.add_argument( + '-m', '--model', type=str, required=True, + help=('onnx model path')) + parser.add_argument( + '-d', '--dtype', type=str, required=True, + help='one type of int8, fp16, fp32') + parser.add_argument( + '--qat', action='store_true', + help='whether the onnx model is qat; if it is, the int8 calibrator is not needed') + # If enable int8(not post-QAT model), then set the following + parser.add_argument('--img-size', type=int, + default=640, help='image size of model input') + parser.add_argument('--batch-size', type=int, + default=128, help='batch size for training: default 64') + parser.add_argument('--num-calib-batch', default=6, type=int, + help='Number of batches for calibration') + parser.add_argument('--calib-img-dir', default='../coco/images/train2017', type=str, + help='Number of batches for calibration') + parser.add_argument('--calib-cache', default='./trt/yolov5s_calibration.cache', type=str, + help='Path of calibration cache') + + args = parser.parse_args() + + + if args.dtype == "int8" and not args.qat: + calib_loader = DataLoader(args.batch_size, args.num_calib_batch, args.calib_img_dir, + args.img_size, args.img_size) + + # For yolov5s-SiLU + fp16_lay_ids = list(range(208, 220)) # Detect layer and the layer close to detect layer + fp16_lay_ids.extend([168, 169, 170, 188, 189, 190]) + fp16_lay_ids.extend(list(range(0, 29))) # The slice layer and first two conv layer + + engine = build_engine_from_onnx(args.model, args.dtype, args.verbose, + int8_calib=True, calib_loader=calib_loader, calib_cache=args.calib_cache, + fp32_layer_ids=[], fp16_layer_ids=fp16_lay_ids) + else: + engine = build_engine_from_onnx(args.model, args.dtype, args.verbose) + + if engine is None: + raise SystemExit('ERROR: failed to build the TensorRT engine!') + + engine_path = args.model.replace('.onnx', '.trt') + if args.dtype == "int8" and not args.qat: + engine_path = args.model.replace('.onnx', '-int8-{}-{}-minmax.trt'.format(args.batch_size, args.num_calib_batch)) + + with open(engine_path, 'wb') as f: + f.write(engine.serialize()) + print('Serialized the TensorRT engine to file: %s' % engine_path) + + +if __name__ == '__main__': + main() diff --git a/trt/trt_dynamic/Processor_dynamic.py b/trt/trt_dynamic/Processor_dynamic.py new file mode 100644 index 000000000000..d5ec09b270b3 --- /dev/null +++ b/trt/trt_dynamic/Processor_dynamic.py @@ -0,0 +1,375 @@ +import cv2 +import sys +import tensorrt as trt +import pycuda.driver as cuda +import pycuda.autoinit # Necessary, to enable pycuda +import numpy as np +import math +import time + 
+sys.path.append('./') # to run '$ python *.py' files in subdirectories +import torch +import torchvision + + +class HostDeviceMem(object): + """Simple helper data class that's a little nicer to use than a 2-tuple.""" + def __init__(self, host_mem, device_mem): + self.host = host_mem + self.device = device_mem + + def __str__(self): + return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device) + + def __repr__(self): + return self.__str__() + + +class Processor(): + def __init__(self, model, category_num=80, letter_box=False, infer_shape=None): + # load tensorrt engine + TRT_LOGGER = trt.Logger(trt.Logger.INFO) + print('TRT model path: ', model) + with open(model, 'rb') as f, trt.Runtime(TRT_LOGGER) as runtime: + engine = runtime.deserialize_cuda_engine(f.read()) + if infer_shape is None: + raise ValueError('Using dynamic shape engine, you must specify the correct inference shape.') + self.input_shape = tuple(infer_shape[2:]) + self.context = engine.create_execution_context() + + # Allocates all host/device in/out buffers required for an engine. + inputs = [] + outputs = [] + bindings = [] + stream = cuda.Stream() + for binding in engine: + size = -1 * trt.volume(engine.get_binding_shape(binding)) * infer_shape[0] + dtype = trt.nptype(engine.get_binding_dtype(binding)) + print('----------- Dynamic info: ', binding, engine.get_binding_shape(binding), infer_shape[0], size, dtype) + + host_mem = cuda.pagelocked_empty(size, dtype) + device_mem = cuda.mem_alloc(host_mem.nbytes) + # Append the device buffer to device bindings. + bindings.append(int(device_mem)) + # Append to the appropriate list. + if engine.binding_is_input(binding): + inputs.append(HostDeviceMem(host_mem, device_mem)) + else: + outputs.append(HostDeviceMem(host_mem, device_mem)) + + # set_binding_shape for the python API. + # setBindingDimensions in the C++ API + self.context.set_binding_shape(0, tuple(infer_shape)) + + # save to class + self.inputs = inputs + self.outputs = outputs + self.bindings = bindings + self.stream = stream + self.letter_box = letter_box + # post processing config + self.output_shapes = [ + (infer_shape[0], 3, 80, 80, 85), + (infer_shape[0], 3, 40, 40, 85), + (infer_shape[0], 3, 20, 20, 85) + ] + self.strides = np.array([8., 16., 32.]) + anchors = np.array([ + [[10,13], [16,30], [33,23]], + [[30,61], [62,45], [59,119]], + [[116,90], [156,198], [373,326]], + ]) + self.nl = len(anchors) + self.nc = category_num # classes + self.no = self.nc + 5 # outputs per anchor + self.na = len(anchors[0]) + a = anchors.copy().astype(np.float32) + a = a.reshape(self.nl, -1, 2) + self.anchors = a.copy() + self.anchor_grid = a.copy().reshape(self.nl, 1, -1, 1, 1, 2) + self.infer_shape = infer_shape + + + def detect(self, img_list, letter_box=None): + """Detect objects in the input image.""" + # It needs to be optimized for efficiency + resized_infer = np.empty(self.infer_shape, dtype='float32') + for index, img in enumerate(img_list): + letter_box = self.letter_box if letter_box is None else letter_box + resized = self.pre_process(img, self.input_shape, letter_box) + resized_infer[index, :, :, :] = resized + + outputs = self.inference(resized_infer) + # reshape from flat to (batch_size, 3, x, y, 85) + reshaped = [] + for output, shape in zip(outputs, self.output_shapes): + reshaped.append(output.reshape(shape)) + return reshaped + + + def pre_process(self, img, input_shape, letter_box=False): + """Preprocess an image before TRT YOLO inferencing. 
+ """ + if letter_box: + img_h, img_w, _ = img.shape + new_h, new_w = input_shape[0], input_shape[1] + offset_h, offset_w = 0, 0 + if (new_w / img_w) <= (new_h / img_h): + new_h = int(img_h * new_w / img_w) + offset_h = (input_shape[0] - new_h) // 2 + else: + new_w = int(img_w * new_h / img_h) + offset_w = (input_shape[1] - new_w) // 2 + resized = cv2.resize(img, (new_w, new_h)) + img = np.full((input_shape[0], input_shape[1], 3), 114, dtype=np.uint8) + img[offset_h:(offset_h + new_h), offset_w:(offset_w + new_w), :] = resized + else: + img = cv2.resize(img, (input_shape[1], input_shape[0])) + + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = img.transpose((2, 0, 1)).astype(np.float32) + img /= 255.0 + return img + + def inference(self, img): + # No need to add axis for batch inference + # img = img[np.newaxis, :, :, :] + + # Set host input to the image. The do_inference() function + # will copy the input to the GPU before executing. + self.inputs[0].host = np.ascontiguousarray(img) + + # Transfer input data to the GPU. + [cuda.memcpy_htod_async(inp.device, inp.host, self.stream) for inp in self.inputs] + # Run inference. + self.context.execute_async_v2(bindings=self.bindings, stream_handle=self.stream.handle) + # Transfer predictions back from the GPU. + [cuda.memcpy_dtoh_async(out.host, out.device, self.stream) for out in self.outputs] + # Synchronize the stream + self.stream.synchronize() + # Return only the host outputs. + return [out.host for out in self.outputs] + + def post_process(self, outputs, img_shape_list, conf_thres=0.5, iou_thres=0.6): + """ + Transforms raw output into boxes, confs, classes + Applies NMS thresholding on bounding boxes and confs + Parameters: + output: raw output tensor + Returns: + boxes: x1,y1,x2,y2 tensor (dets, 4) + confs: class * obj prob tensor (dets, 1) + classes: class type tensor (dets, 1) + """ + scaled = [] + grids = [] + for out in outputs: + # print('out.shape: ', out.shape) + # print('out: ', out) + out = self.sigmoid_v(out) + # print('sigmoid_v out.shape: ', out.shape) + # print('sigmoid_v out: ', out) + _, _, width, height, _ = out.shape + grid = self.make_grid(width, height) + grids.append(grid) + scaled.append(out) + z = [] + for out, grid, stride, anchor in zip(scaled, grids, self.strides, self.anchor_grid): + _, _, width, height, _ = out.shape + out[..., 0:2] = (out[..., 0:2] * 2. 
- 0.5 + grid) * stride + out[..., 2:4] = (out[..., 2:4] * 2) ** 2 * anchor + + out = out.reshape((self.infer_shape[0], 3 * width * height, 85)) + z.append(out) + + det_t_list = [] + for i in range(self.infer_shape[0]): + single_image = [z[0][i:i+1, :, :], z[1][i:i+1, :, :], z[2][i:i+1, :, :]] + + # Process single image + pred = np.concatenate(single_image, 1) + + # Use Pytorch to do the post-process + det_t = self.non_max_suppression(torch.from_numpy(pred), conf_thres=conf_thres, + iou_thres=iou_thres, multi_label=True)[0] + + self.scale_coords(self.input_shape, det_t[:, :4], img_shape_list[i]) + det_t_list.append(det_t) + + return det_t_list + + + def make_grid(self, nx, ny): + """ + Create scaling tensor based on box location + Source: https://github.com/ultralytics/yolov5/blob/master/models/yolo.py + Arguments + nx: x-axis num boxes + ny: y-axis num boxes + Returns + grid: tensor of shape (1, 1, nx, ny, 80) + """ + nx_vec = np.arange(nx) + ny_vec = np.arange(ny) + yv, xv = np.meshgrid(ny_vec, nx_vec) + grid = np.stack((yv, xv), axis=2) + grid = grid.reshape(1, 1, ny, nx, 2) + return grid + + + def sigmoid(self, x): + return 1 / (1 + math.exp(-x)) + + + def sigmoid_v(self, array): + return np.reciprocal(np.exp(-array) + 1.0) + + + def xywh2xyxy(self, x): + # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right + y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) + y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x + y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y + y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x + y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y + return y + + + def clip_coords(self, boxes, img_shape): + # Clip bounding xyxy bounding boxes to image shape (height, width) + boxes[:, 0].clamp_(0, img_shape[1]) # x1 + boxes[:, 1].clamp_(0, img_shape[0]) # y1 + boxes[:, 2].clamp_(0, img_shape[1]) # x2 + boxes[:, 3].clamp_(0, img_shape[0]) # y2 + + + def scale_coords(self, img1_shape, coords, img0_shape, ratio_pad=None): + # Rescale coords (xyxy) from img1_shape to img0_shape + if ratio_pad is None: # calculate from img0_shape + gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new + pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding + else: + gain = ratio_pad[0][0] + pad = ratio_pad[1] + + coords[:, [0, 2]] -= pad[0] # x padding + coords[:, [1, 3]] -= pad[1] # y padding + coords[:, :4] /= gain + self.clip_coords(coords, img0_shape) + return coords + + def non_max_suppression(self, prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, + multi_label=False, + labels=()): + """Runs Non-Maximum Suppression (NMS) on inference results + + Returns: + list of detections, on (n,6) tensor per image [xyxy, conf, cls] + """ + + nc = prediction.shape[2] - 5 # number of classes + xc = prediction[..., 4] > conf_thres # candidates + + # Settings + min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height + max_det = 300 # maximum number of detections per image + max_nms = 30000 # maximum number of boxes into torchvision.ops.nms() + time_limit = 10.0 # seconds to quit after + redundant = True # require redundant detections + multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img) + merge = False # use merge-NMS + + t = time.time() + output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0] + for xi, x in enumerate(prediction): # image index, image inference + # Apply constraints + # 
x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height + x = x[xc[xi]] # confidence + + # Cat apriori labels if autolabelling + if labels and len(labels[xi]): + l = labels[xi] + v = torch.zeros((len(l), nc + 5), device=x.device) + v[:, :4] = l[:, 1:5] # box + v[:, 4] = 1.0 # conf + v[range(len(l)), l[:, 0].long() + 5] = 1.0 # cls + x = torch.cat((x, v), 0) + + # If none remain process next image + if not x.shape[0]: + continue + + # Compute conf + x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf + + # Box (center x, center y, width, height) to (x1, y1, x2, y2) + box = self.xywh2xyxy(x[:, :4]) + + # Detections matrix nx6 (xyxy, conf, cls) + if multi_label: + i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T + x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1) + else: # best class only + conf, j = x[:, 5:].max(1, keepdim=True) + x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres] + + # Filter by class + if classes is not None: + x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)] + + # Apply finite constraint + # if not torch.isfinite(x).all(): + # x = x[torch.isfinite(x).all(1)] + + # Check shape + n = x.shape[0] # number of boxes + if not n: # no boxes + continue + elif n > max_nms: # excess boxes + x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence + + # Batched NMS + c = x[:, 5:6] * (0 if agnostic else max_wh) # classes + boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores + i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS + if i.shape[0] > max_det: # limit detections + i = i[:max_det] + if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean) + # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4) + iou = self.box_iou(boxes[i], boxes) > iou_thres # iou matrix + weights = iou * scores[None] # box weights + x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes + if redundant: + i = i[iou.sum(1) > 1] # require redundancy + + output[xi] = x[i] + if (time.time() - t) > time_limit: + print(f'WARNING: NMS time limit {time_limit}s exceeded') + break # time limit exceeded + + return output + + def box_iou(self, box1, box2): + # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py + """ + Return intersection-over-union (Jaccard index) of boxes. + Both sets of boxes are expected to be in (x1, y1, x2, y2) format. + Arguments: + box1 (Tensor[N, 4]) + box2 (Tensor[M, 4]) + Returns: + iou (Tensor[N, M]): the NxM matrix containing the pairwise + IoU values for every element in boxes1 and boxes2 + """ + + def box_area(box): + # box = 4xn + return (box[2] - box[0]) * (box[3] - box[1]) + + area1 = box_area(box1.T) + area2 = box_area(box2.T) + + # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2) + inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2) + return inter / (area1[:, None] + area2 - inter) # iou = inter / (area1 + area2 - inter) + diff --git a/trt/trt_dynamic/__init__.py b/trt/trt_dynamic/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/trt/trt_dynamic/eval_yolo_trt_dynamic.py b/trt/trt_dynamic/eval_yolo_trt_dynamic.py new file mode 100644 index 000000000000..689dacb4874b --- /dev/null +++ b/trt/trt_dynamic/eval_yolo_trt_dynamic.py @@ -0,0 +1,144 @@ +"""eval_yolo.py + +This script is for evaluating mAP (accuracy) of YOLO models. 
+""" +import os +import sys +import json +import argparse + +sys.path.append('./') # to run '$ python *.py' files in subdirectories + +import cv2 +import torch + +from pycocotools.coco import COCO +from pycocotools.cocoeval import COCOeval + +from Processor_dynamic import Processor +from utils.general import coco80_to_coco91_class + +# converts 80-index (val2014) to 91-index (paper) +coco91class = coco80_to_coco91_class() + +VAL_IMGS_DIR = '../coco/images/val2017' +VAL_ANNOTATIONS = '../coco/annotations/instances_val2017.json' + +def parse_args(): + """Parse input arguments.""" + desc = 'Evaluate mAP of YOLO TRT model' + parser = argparse.ArgumentParser(description=desc) + parser.add_argument( + '--imgs-dir', type=str, default=VAL_IMGS_DIR, + help='directory of validation images [%s]' % VAL_IMGS_DIR) + parser.add_argument( + '--annotations', type=str, default=VAL_ANNOTATIONS, + help='groundtruth annotations [%s]' % VAL_ANNOTATIONS) + parser.add_argument( + '-c', '--category-num', type=int, default=80, + help='number of object categories [80]') + parser.add_argument('--infer-shape', nargs='+', type=int, default=[8, 3, 640, 640], + help='if build engine with dynamic shape, you must specify the dimension during inference') + parser.add_argument( + '-m', '--model', type=str, default='./weights/yolov5s-simple.trt', + help=('trt model path')) + parser.add_argument( + '-l', '--letter_box', action='store_true', + help='inference with letterboxed image [False]') + parser.add_argument( + '--conf-thres', type=float, default=0.001, + help='object confidence threshold') + parser.add_argument( + '--iou-thres', type=float, default=0.6, + help='IOU threshold for NMS') + args = parser.parse_args() + return args + + +def check_args(args): + """Check and make sure command-line arguments are valid.""" + if not os.path.isdir(args.imgs_dir): + sys.exit('%s is not a valid directory' % args.imgs_dir) + if not os.path.isfile(args.annotations): + sys.exit('%s is not a valid file' % args.annotations) + + +def generate_results(processor, imgs_dir, jpgs, results_file, conf_thres, iou_thres, non_coco, infer_shape): + """Run detection on each jpg and write results to file.""" + results = [] + + i = 0 + + # Batch param + batch_id = 0 + img_list = [] + image_id_list = [] + image_shape_list = [] + + for jpg in jpgs: + i+=1 + if(i%100 == 0): + print('Processing {} images'.format(i)) + + img = cv2.imread(os.path.join(imgs_dir, jpg)) + image_id = int(jpg.split('.')[0].split('_')[-1]) + + # Accumulated into a batch + if batch_id < infer_shape[0]: + img_list.append(img) + image_id_list.append(image_id) + image_shape_list.append(img.shape) + + batch_id += 1 + if batch_id == infer_shape[0]: + output = processor.detect(img_list) + pred = processor.post_process(output, image_shape_list, conf_thres=conf_thres, + iou_thres=iou_thres) + for index, item in enumerate(pred): + for p in item.tolist(): + x = float(p[0]) + y = float(p[1]) + w = float(p[2] - p[0]) + h = float(p[3] - p[1]) + results.append({'image_id': image_id_list[index], + 'category_id': coco91class[int(p[5])] if not non_coco else int(p[5]), + 'bbox': [round(x, 3) for x in [x, y, w, h]], + 'score': round(p[4], 5)}) + + batch_id = 0 + img_list = [] + image_id_list = [] + image_shape_list = [] + + with open(results_file, 'w') as f: + f.write(json.dumps(results, indent=4)) + + +def main(): + args = parse_args() + check_args(args) + + model_prefix = args.model.replace('.trt', '').split('/')[-1] + results_file = 'weights/results_{}.json'.format(model_prefix) + + # setup 
processor + processor = Processor(model=args.model, letter_box=True, infer_shape=args.infer_shape) + + jpgs = [j for j in os.listdir(args.imgs_dir) if j.endswith('.jpg')] + generate_results(processor, args.imgs_dir, jpgs, results_file, args.conf_thres, args.iou_thres, + non_coco=False, infer_shape=args.infer_shape) + + # Run COCO mAP evaluation + # Reference: https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb + cocoGt = COCO(args.annotations) + cocoDt = cocoGt.loadRes(results_file) + imgIds = sorted(cocoGt.getImgIds()) + cocoEval = COCOeval(cocoGt, cocoDt, 'bbox') + cocoEval.params.imgIds = imgIds + cocoEval.evaluate() + cocoEval.accumulate() + cocoEval.summarize() + + +if __name__ == '__main__': + main() diff --git a/utils_quant/__init__.py b/utils_quant/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/utils_quant/calib_test.sh b/utils_quant/calib_test.sh new file mode 100644 index 000000000000..a22ee8e8eb6a --- /dev/null +++ b/utils_quant/calib_test.sh @@ -0,0 +1,25 @@ +#!/bin/bash + +echo "Running first-time script 64, 8." +python yolo_quant_flow.py --data data/coco.yaml --cfg models/yolov5s.yaml --ckpt-path weights/yolov5s.pt \ +--hyp data/hyp.qat.yaml --calib-batch-size 64 --num-calib-batch 8 --device=1 + + +echo "Running first-time script 64, 16." +python yolo_quant_flow.py --data data/coco.yaml --cfg models/yolov5s.yaml --ckpt-path weights/yolov5s.pt \ +--hyp data/hyp.qat.yaml --calib-batch-size 64 --num-calib-batch 16 --device=1 + + +echo "Running first-time script 32, 16." +python yolo_quant_flow.py --data data/coco.yaml --cfg models/yolov5s.yaml --ckpt-path weights/yolov5s.pt \ +--hyp data/hyp.qat.yaml --calib-batch-size 32 --num-calib-batch 16 --device=1 + + +echo "Running first-time script 128, 8." +python yolo_quant_flow.py --data data/coco.yaml --cfg models/yolov5s.yaml --ckpt-path weights/yolov5s.pt \ +--hyp data/hyp.qat.yaml --calib-batch-size 128 --num-calib-batch 8 --device=1 + + +echo "Running first-time script 64, 24." +python yolo_quant_flow.py --data data/coco.yaml --cfg models/yolov5s.yaml --ckpt-path weights/yolov5s.pt \ +--hyp data/hyp.qat.yaml --calib-batch-size 64 --num-calib-batch 24 --device=1 \ No newline at end of file diff --git a/utils_quant/check_params.py b/utils_quant/check_params.py new file mode 100644 index 000000000000..97f21eb659ea --- /dev/null +++ b/utils_quant/check_params.py @@ -0,0 +1,83 @@ +import os +from pathlib import Path +import logging + +import torch +import torch.utils.data +from torch.utils.tensorboard import SummaryWriter + +import numpy as np +import yaml + +logger = logging.getLogger(__name__) + +try: + from train import train + import test # import test.py to get mAP after each epoch + from utils.general import check_file, colorstr + from utils.torch_utils import select_device +except Exception as e: + print(repr(e)) + + +def check_and_set_params(opt): + """ + i. Check the validity of parameters + ii. 
Convert for compatibility with yolov5 parameters, and supplement the necessasy parameters + """ + torch.manual_seed(opt.seed) + np.random.seed(opt.seed) + + # Conversion and supplement for compatibility with yolov5 parameters + opt.world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1 + opt.global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else -1 + opt.epochs = opt.num_finetune_epochs + opt.batch_size = opt.batch_size_train + opt.weights = opt.ckpt_path + opt.total_batch_size = opt.batch_size_train + opt.project = opt.out_dir # output folder + opt.evolve = False + opt.resume = False + opt.single_cls = False + opt.adam = False + opt.linear_lr = False + opt.sync_bn = False + opt.cache_images = False + opt.image_weights = False + opt.rect = False + opt.workers = 8 # maximum number of dataloader workers + opt.quad = False # quad dataloader + opt.noautoanchor = False # disable autoanchor check + opt.label_smoothing = 0.0 # default=0.0, Label smoothing epsilon + opt.multi_scale = False # vary img-size +/- 50% + opt.notest = False # only test final epoch + opt.name = 'exp' # save to project/name + opt.bucket = '' # gsutil bucket + opt.nosave = False # only save final checkpoint + opt.conf_thres = 0.001 # default=0.001, help='object confidence threshold' + opt.iou_thres = 0.6 # default=0.6, help='IOU threshold for NMS' + opt.exist_ok = False # action='store_true', help='existing project/name ok, do not increment' + opt.save_txt = False # action='store_true', help='save results to *.txt') + opt.save_hybrid = False # action='store_true', help='save label+prediction hybrid results to *.txt') + opt.save_conf = False # action='store_true', help='save confidences in --save-txt labels' + opt.save_json = True # save a cocoapi-compatible JSON results file + + opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp) # check files + assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified' + opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test) + opt.save_dir = Path(opt.out_dir) + + device = select_device(opt.device, batch_size=opt.batch_size) + + # Hyperparameters + with open(opt.hyp) as f: + hyp = yaml.load(f, Loader=yaml.SafeLoader) # load hyps + + logger.info(opt) + tb_writer = None # init loggers + if opt.global_rank in [-1, 0]: + prefix = colorstr('tensorboard: ') + logger.info(f"{prefix}Start with 'tensorboard --logdir {opt.out_dir}', view at http://localhost:6006/") + tb_writer = SummaryWriter(opt.save_dir) # Tensorboard + + return hyp, opt, device, tb_writer \ No newline at end of file diff --git a/utils_quant/onnxrt_demo.py b/utils_quant/onnxrt_demo.py new file mode 100644 index 000000000000..8dce1a8d8356 --- /dev/null +++ b/utils_quant/onnxrt_demo.py @@ -0,0 +1,356 @@ +import onnxruntime +import argparse + +import cv2 +import sys +import numpy as np +import math +import time + +sys.path.append('./') # to run '$ python *.py' files in subdirectories +import torch +import torchvision + +""" + Run the inference using onnx runtime, useful debug tool +""" + +def get_input_shape(binding_dims): + if len(binding_dims) == 4: + return tuple(binding_dims[2:]) + elif len(binding_dims) == 3: + return tuple(binding_dims[1:]) + else: + raise ValueError('bad dims of binding %s' % (str(binding_dims))) + + +class Processor(): + def __init__(self, model, category_num=80, letter_box=False): + # load onnx engine + self.ort_session = onnxruntime.InferenceSession(model) + + 
# get output name + self.input_name = self.ort_session.get_inputs()[0].name + self.output_names = [] + for i in range(len(self.ort_session.get_outputs())): + output_name = self.ort_session.get_outputs()[i].name + print("output name {}: ".format(i), output_name) + output_shape = self.ort_session.get_outputs()[i].shape + print("output shape {}: ".format(i), output_shape) + self.output_names.append(output_name) + + self.input_shape = get_input_shape(self.ort_session.get_inputs()[0].shape) + print('---self.input_shape: ', self.input_shape) + + + self.letter_box = letter_box + # post processing config + filters = (category_num + 5) * 3 + self.output_shapes = [ + (1, 3, 80, 80, 85), + (1, 3, 40, 40, 85), + (1, 3, 20, 20, 85) + ] + self.strides = np.array([8., 16., 32.]) + anchors = np.array([ + [[10, 13], [16, 30], [33, 23]], + [[30, 61], [62, 45], [59, 119]], + [[116, 90], [156, 198], [373, 326]], + ]) + self.nl = len(anchors) + self.nc = category_num # classes + self.no = self.nc + 5 # outputs per anchor + self.na = len(anchors[0]) + a = anchors.copy().astype(np.float32) + a = a.reshape(self.nl, -1, 2) + self.anchors = a.copy() + self.anchor_grid = a.copy().reshape(self.nl, 1, -1, 1, 1, 2) + + def detect(self, img, letter_box=None): + """Detect objects in the input image.""" + letter_box = self.letter_box if letter_box is None else letter_box + resized = self.pre_process(img, self.input_shape, letter_box) + + outputs = self.inference(resized) + # reshape from flat to (1, 3, x, y, 85) + reshaped = [] + for output, shape in zip(outputs, self.output_shapes): + reshaped.append(output.reshape(shape)) + return reshaped + + def pre_process(self, img, input_shape, letter_box=False): + """Preprocess an image before TRT YOLO inferencing. + """ + if letter_box: + img_h, img_w, _ = img.shape + new_h, new_w = input_shape[0], input_shape[1] + offset_h, offset_w = 0, 0 + if (new_w / img_w) <= (new_h / img_h): + new_h = int(img_h * new_w / img_w) + offset_h = (input_shape[0] - new_h) // 2 + else: + new_w = int(img_w * new_h / img_h) + offset_w = (input_shape[1] - new_w) // 2 + resized = cv2.resize(img, (new_w, new_h)) + img = np.full((input_shape[0], input_shape[1], 3), 114, dtype=np.uint8) + img[offset_h:(offset_h + new_h), offset_w:(offset_w + new_w), :] = resized + else: + img = cv2.resize(img, (input_shape[1], input_shape[0])) + + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = img.transpose((2, 0, 1)).astype(np.float32) + img /= 255.0 + return img + + def inference(self, img): + img = img[np.newaxis, :, :, :] + + # forward model + res = self.ort_session.run(self.output_names, {self.input_name: img}) + + # Return only the host outputs. 
+ return [item for item in res] + + def post_process(self, outputs, img_shape, conf_thres=0.5, iou_thres=0.6): + """ + Transforms raw output into boxes, confs, classes + Applies NMS thresholding on bounding boxes and confs + Parameters: + output: raw output tensor + Returns: + boxes: x1,y1,x2,y2 tensor (dets, 4) + confs: class * obj prob tensor (dets, 1) + classes: class type tensor (dets, 1) + """ + scaled = [] + grids = [] + for out in outputs: + print('out.shape: ', out.shape) + print('out: ', out) + out = self.sigmoid_v(out) + + print('sigmoid_v out.shape: ', out.shape) + print('sigmoid_v out: ', out) + _, _, width, height, _ = out.shape + grid = self.make_grid(width, height) + grids.append(grid) + scaled.append(out) + z = [] + for out, grid, stride, anchor in zip(scaled, grids, self.strides, self.anchor_grid): + _, _, width, height, _ = out.shape + out[..., 0:2] = (out[..., 0:2] * 2. - 0.5 + grid) * stride + out[..., 2:4] = (out[..., 2:4] * 2) ** 2 * anchor + + out = out.reshape((1, 3 * width * height, 85)) + z.append(out) + + pred = np.concatenate(z, 1) + + # Use Pytorch to do the post-process + det_t = self.non_max_suppression(torch.from_numpy(pred), conf_thres=conf_thres, + iou_thres=iou_thres, multi_label=True)[0] + + self.scale_coords(self.input_shape, det_t[:, :4], img_shape) + return det_t + + def make_grid(self, nx, ny): + """ + Create scaling tensor based on box location + Source: https://github.com/ultralytics/yolov5/blob/master/models/yolo.py + Arguments + nx: x-axis num boxes + ny: y-axis num boxes + Returns + grid: tensor of shape (1, 1, nx, ny, 80) + """ + nx_vec = np.arange(nx) + ny_vec = np.arange(ny) + yv, xv = np.meshgrid(ny_vec, nx_vec) + grid = np.stack((yv, xv), axis=2) + grid = grid.reshape(1, 1, ny, nx, 2) + return grid + + def sigmoid(self, x): + return 1 / (1 + math.exp(-x)) + + def sigmoid_v(self, array): + return np.reciprocal(np.exp(-array) + 1.0) + + def xywh2xyxy(self, x): + # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right + y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) + y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x + y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y + y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x + y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y + return y + + def clip_coords(self, boxes, img_shape): + # Clip bounding xyxy bounding boxes to image shape (height, width) + boxes[:, 0].clamp_(0, img_shape[1]) # x1 + boxes[:, 1].clamp_(0, img_shape[0]) # y1 + boxes[:, 2].clamp_(0, img_shape[1]) # x2 + boxes[:, 3].clamp_(0, img_shape[0]) # y2 + + def scale_coords(self, img1_shape, coords, img0_shape, ratio_pad=None): + # Rescale coords (xyxy) from img1_shape to img0_shape + if ratio_pad is None: # calculate from img0_shape + gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new + pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding + else: + gain = ratio_pad[0][0] + pad = ratio_pad[1] + + coords[:, [0, 2]] -= pad[0] # x padding + coords[:, [1, 3]] -= pad[1] # y padding + coords[:, :4] /= gain + self.clip_coords(coords, img0_shape) + return coords + + def non_max_suppression(self, prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, + multi_label=False, + labels=()): + """Runs Non-Maximum Suppression (NMS) on inference results + + Returns: + list of detections, on (n,6) tensor per image [xyxy, conf, cls] + """ + + nc = prediction.shape[2] - 5 # number of classes + xc = 
prediction[..., 4] > conf_thres # candidates + + # Settings + min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height + max_det = 300 # maximum number of detections per image + max_nms = 30000 # maximum number of boxes into torchvision.ops.nms() + time_limit = 10.0 # seconds to quit after + redundant = True # require redundant detections + multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img) + merge = False # use merge-NMS + + t = time.time() + output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0] + for xi, x in enumerate(prediction): # image index, image inference + # Apply constraints + # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height + x = x[xc[xi]] # confidence + + # Cat apriori labels if autolabelling + if labels and len(labels[xi]): + l = labels[xi] + v = torch.zeros((len(l), nc + 5), device=x.device) + v[:, :4] = l[:, 1:5] # box + v[:, 4] = 1.0 # conf + v[range(len(l)), l[:, 0].long() + 5] = 1.0 # cls + x = torch.cat((x, v), 0) + + # If none remain process next image + if not x.shape[0]: + continue + + # Compute conf + x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf + + # Box (center x, center y, width, height) to (x1, y1, x2, y2) + box = self.xywh2xyxy(x[:, :4]) + + # Detections matrix nx6 (xyxy, conf, cls) + if multi_label: + i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T + x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1) + else: # best class only + conf, j = x[:, 5:].max(1, keepdim=True) + x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres] + + # Filter by class + if classes is not None: + x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)] + + # Apply finite constraint + # if not torch.isfinite(x).all(): + # x = x[torch.isfinite(x).all(1)] + + # Check shape + n = x.shape[0] # number of boxes + if not n: # no boxes + continue + elif n > max_nms: # excess boxes + x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence + + # Batched NMS + c = x[:, 5:6] * (0 if agnostic else max_wh) # classes + boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores + i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS + if i.shape[0] > max_det: # limit detections + i = i[:max_det] + if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean) + # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4) + iou = self.box_iou(boxes[i], boxes) > iou_thres # iou matrix + weights = iou * scores[None] # box weights + x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes + if redundant: + i = i[iou.sum(1) > 1] # require redundancy + + output[xi] = x[i] + if (time.time() - t) > time_limit: + print(f'WARNING: NMS time limit {time_limit}s exceeded') + break # time limit exceeded + + return output + + def box_iou(self, box1, box2): + # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py + """ + Return intersection-over-union (Jaccard index) of boxes. + Both sets of boxes are expected to be in (x1, y1, x2, y2) format. 
+    Arguments:
+        box1 (Tensor[N, 4])
+        box2 (Tensor[M, 4])
+    Returns:
+        iou (Tensor[N, M]): the NxM matrix containing the pairwise
+            IoU values for every element in boxes1 and boxes2
+    """
+
+    def box_area(box):
+        # box = 4xn
+        return (box[2] - box[0]) * (box[3] - box[1])
+
+    area1 = box_area(box1.T)
+    area2 = box_area(box2.T)
+
+    # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
+    inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
+    return inter / (area1[:, None] + area2 - inter)  # iou = inter / (area1 + area2 - inter)
+
+
+def cli():
+    desc = 'Run YOLOv5 inference with ONNX Runtime (debug tool)'
+    parser = argparse.ArgumentParser(description=desc)
+    parser.add_argument('-m', '--model', default='./weights/yolov5s-ReLU-max-512.onnx', help='onnx model file path', required=False)
+    parser.add_argument('-i', '--image', default='./data/images/bus.jpg', help='image file path', required=False)
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    # parse arguments
+    args = cli()
+
+    # setup processor
+    processor = Processor(model=args.model, letter_box=True)
+
+    img = cv2.imread(args.image)
+
+    # inference
+    output = processor.detect(img)
+
+    # final results
+    pred = processor.post_process(output, img.shape, conf_thres=0.5)
+
+    print('Detection result: ')
+    for item in pred.tolist():
+        print(item)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/utils_quant/print_model_structure.py b/utils_quant/print_model_structure.py
new file mode 100644
index 000000000000..b6247ae2116f
--- /dev/null
+++ b/utils_quant/print_model_structure.py
@@ -0,0 +1,43 @@
+"""Prints the module structure of a (quantized) YOLOv5 model for inspection
+
+Usage:
+    $ export PYTHONPATH="$PWD" && python utils_quant/print_model_structure.py --weights ./runs/finetune/yolov5s-max-512.pth
+"""
+
+import argparse
+import sys
+
+sys.path.append('./')  # to run '$ python *.py' files in subdirectories
+
+import torch.nn as nn
+
+import models
+from models.experimental import attempt_load
+from utils.activations import Hardswish
+from utils.torch_utils import select_device
+
+from pytorch_quantization import nn as quant_nn
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--weights', type=str, default='./runs/finetune/yolov5s-max-512.pth', help='weights path')
+    parser.add_argument('--device', default='0', help='cuda device, i.e.
0 or 0,1,2,3 or cpu') + opt = parser.parse_args() + + # To use Pytorch's own fake quantization functions + quant_nn.TensorQuantizer.use_fb_fake_quant = True + + # Load PyTorch model + device = select_device(opt.device) + model = attempt_load(opt.weights, map_location=device) # load FP32 model + + # Print model name and params + for k, m in model.named_modules(): + m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility + if isinstance(m, models.common.Conv): # assign export-friendly activations + if isinstance(m.act, nn.Hardswish): + m.act = Hardswish() + print(m) + + diff --git a/val.py b/val.py index 3d01f1a5996d..c09b9ba5622b 100644 --- a/val.py +++ b/val.py @@ -117,7 +117,8 @@ def run( project=ROOT / 'runs/val', # save to project/name name='exp', # save to project/name exist_ok=False, # existing project/name ok, do not increment - half=True, # use FP16 half-precision inference + #half=True, # use FP16 half-precision inference + half=False, # For QAT: quantization aware training, disable half_precision test,wjx dnn=False, # use OpenCV DNN for ONNX inference model=None, dataloader=None, diff --git a/yolo_quant_flow.py b/yolo_quant_flow.py new file mode 100644 index 000000000000..9f5f79aea9f7 --- /dev/null +++ b/yolo_quant_flow.py @@ -0,0 +1,475 @@ +# +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import sys +import argparse +import yaml +import warnings +import collections + +import torch +import torch.utils.data +import numpy as np + +from tqdm import tqdm +from copy import deepcopy +from prettytable import PrettyTable + +import logging +logging.basicConfig(level=logging.ERROR) + +try: + from pytorch_quantization import nn as quant_nn + from pytorch_quantization import calib + from pytorch_quantization.tensor_quant import QuantDescriptor + from pytorch_quantization import quant_modules +except ImportError: + raise ImportError( + "pytorch-quantization is not installed. Install from " + "https://github.com/NVIDIA/TensorRT/tree/master/tools/pytorch-quantization." + ) + +from train import train +import test +from utils_quant.check_params import check_and_set_params +from utils.datasets import create_dataloader +from utils.general import check_img_size, colorstr +from utils.torch_utils import intersect_dicts +from models.yolo import Model + + +def get_parser(): + """ + Creates an argument parser. 
+ """ + parser = argparse.ArgumentParser(description='Object detection: Yolov5 quantization flow script') + + parser.add_argument('--data', type=str, default='data/coco.yaml', help='data.yaml path', required=True) + parser.add_argument('--out-dir', '-o', default='./runs/finetune', help='output folder: default ./runs/finetune') + parser.add_argument('--print-freq', '-pf', type=int, default=20, help='evaluation print frequency: default 20') + parser.add_argument('--threshold', '-t', type=float, default=-1.0, help='top1 accuracy threshold (less than 0.0 means no comparison): default -1.0') + + # setting for yolov5 + parser.add_argument('--model-name', '-m', default='yolov5s', help='model name: default yolov5s') + parser.add_argument('--cfg', type=str, default='', help='model.yaml path') + parser.add_argument('--ckpt-path', default='', type=str, help='path to latest checkpoint (default: none)') + parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='[train, test] image sizes') + parser.add_argument('--batch-size-train', type=int, default=64, help='batch size for training: default 64') + parser.add_argument('--batch-size-test', type=int, default=64, help='batch size for testing: default 64') + parser.add_argument('--batch-size-onnx', type=int, default=1, help='batch size for onnx: default 1') + parser.add_argument('--device', default='0', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') + parser.add_argument('--seed', type=int, default=12345, help='random seed: default 12345') + parser.add_argument('--skip-eval-accuracy', action='store_true', help='Skip the accuracy evaluation after the QDQ insert/Calibration/QAT-Fintuning') + + # setting for calibration + parser.add_argument('--hyp', type=str, default='data/hyp.qat.yaml', help='hyperparameters path') + parser.add_argument('--calib-batch-size', type=int, + default=32, help='calib batch size: default 64') + parser.add_argument('--num-calib-batch', default=16, type=int, + help='Number of batches for calibration. 0 will disable calibration. (default: 4)') + parser.add_argument('--num-finetune-epochs', default=15, type=int, + help='Number of epochs to fine tune. 0 will disable fine tune. (default: 0)') + parser.add_argument('--calibrator', type=str, choices=["max", "histogram"], default="max") + parser.add_argument('--percentile', nargs='+', type=float, default=[99.9, 99.99, 99.999, 99.9999]) + parser.add_argument('--sensitivity', action="store_true", help="Build sensitivity profile") + parser.add_argument('--evaluate-onnx', action="store_true", help="Evaluate exported ONNX") + parser.add_argument("--accu-tolerance", type=float, default=0.925, help="used by test, for coco 0.367+0.558") + parser.add_argument('--skip-layers', action="store_true", help='Skip some sensitivity layers') + parser.add_argument('--dynamic', action='store_true', help='dynamic ONNX axes') + + return parser + + +def prepare_model(calibrator, hyp, opt, device): + """ + Prepare the model for the quantization, including quant modules, settings and dataloaders. 
+ """ + # Use 'spawn' to avoid CUDA reinitialization with forked subprocess + torch.multiprocessing.set_start_method('spawn') + + ## Initialize quantization, model and data loaders + quant_desc_input = QuantDescriptor(calib_method=calibrator) + quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input) + quant_nn.QuantMaxPool2d.set_default_quant_desc_input(quant_desc_input) + quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input) + + # Model + with open(opt.data) as f: + data_dict = yaml.load(f, Loader=yaml.SafeLoader) # data dict + nc = 1 if opt.single_cls else int(data_dict['nc']) # number of classes + + # # Dynamic module replacement using monkey patching. + # # Monkey patching, take Conv2d for example, replace the Conv2d operator with quant_nn.QuantConv2d to enable FakeQuant + # quant_modules.initialize() + + pretrained = opt.weights.endswith('.pt') + if pretrained: + ckpt = torch.load(opt.weights, map_location=device) # load checkpoint + model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create + exclude = ['anchor'] if (opt.cfg or hyp.get('anchors')) and not opt.resume else [] # exclude keys + state_dict = ckpt['model'].float().state_dict() # to FP32 + state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude) # intersect + model.load_state_dict(state_dict, strict=False) # load + else: + model = Model(opt.cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create + + # # Disable the monkey patching. + # quant_modules.deactivate() + + model.eval() + model.cuda() + + train_path = data_dict['train'] + test_path = data_dict['val'] + + gs = max(int(model.stride.max()), 32) # grid size (max stride) + imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size] # verify imgsz are gs-multiples + + # Train dataloader + trainloader, dataset = create_dataloader(train_path, imgsz, opt.batch_size, gs, opt, + hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=-1, + world_size=opt.world_size, workers=opt.workers, + image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: ')) + mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class + assert mlc < nc, 'Label class %g exceeds nc=%g in %s. 
Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1) + + # Test dataloader + testloader = create_dataloader(test_path, imgsz_test, opt.batch_size*2, gs, opt, # testloader + hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1, + world_size=opt.world_size, workers=opt.workers, + pad=0.5, prefix=colorstr('val: '))[0] + + # Calib dataloader + calibloader = create_dataloader(train_path, imgsz, opt.calib_batch_size, gs, opt, + hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=-1, + world_size=opt.world_size, workers=opt.workers, + image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: '))[0] + + return model, trainloader, testloader, calibloader, dataset + + +def evaluate_accuracy(model, opt, testloader): + opt.task = 'val' + results, _, _ = test.test(opt.data, + weights=opt.weights, + batch_size=opt.batch_size_test, + model=model, + dataloader=testloader, + conf_thres=opt.conf_thres, + iou_thres=opt.iou_thres, + save_json=opt.save_json, + opt=opt) + + map50 = list(results)[3] + map = list(results)[2] + return map50, map + +def print_module_status(model): + """ + print the setting of quant module for debugging + """ + for name, module in model.named_modules(): + if isinstance(module, quant_nn.TensorQuantizer): + if module._calibrator is not None: + print('With _calibrator: ', F'{name:40}: {module}', module._learn_amax) + else: + print('No _calibrator: ', F'{name:40}: {module}') + + +def main(cmdline_args): + parser = get_parser() + opt = parser.parse_args(cmdline_args) + print(parser.description) + print(opt) + + # Check the validity of parameters + hyp, opt, device, tb_writer = check_and_set_params(opt) + + + # Prepare the pretrained model and data loaders + model, trainloader, testloader, calibloader, dataset = prepare_model(opt.calibrator, + hyp, opt, device) + + # Initial accuracy evaluation + if not opt.skip_eval_accuracy: + map50_initial, map_initial = evaluate_accuracy(model, opt, testloader) + print('Initial evaluation: ', "{:.3f}, {:.3f}".format(map50_initial, map_initial)) + + # Calibrate the model + with torch.no_grad(): + calibrate_model( + model=model, + model_name=opt.model_name, + data_loader=calibloader, + num_calib_batch=opt.num_calib_batch, + calibrator=opt.calibrator, + hist_percentile=opt.percentile, + out_dir=opt.out_dir, + device=device) + + # Evaluate after calibration + if opt.num_calib_batch > 0 and (not opt.skip_eval_accuracy): + map50_calibrated, map_calibrated = evaluate_accuracy(model, opt, testloader) + print('Calibration evaluation:', "{:.3f}, {:.3f}".format(map50_calibrated, map_calibrated)) + else: + map50_calibrated, map_calibrated = -1.0, -1.0 + + # Build sensitivy profile + if opt.sensitivity: + build_sensitivity_profile(model, opt, testloader) + + # Skip the sensitive layer + if opt.skip_layers: + skip_sensitive_layers(model, opt, testloader) + + if opt.num_finetune_epochs > 0: + # Finetune the model + train(hyp, opt, device, tb_writer, model=model, dataloader=trainloader, + testloader=testloader, dataset=dataset) + + # Evaluate after finetuning + if not opt.skip_eval_accuracy: + map50_finetuned, map_finetuned = evaluate_accuracy(model, opt, testloader) + print('Finetune evaluation: ', "{:.3f}, {:.3f}".format(map50_finetuned, map_finetuned)) + else: + map50_finetuned, map_finetuned = -1.0, -1.0 + + # Export to ONNX + onnx_filename = opt.ckpt_path.replace('.pt', '.onnx') + export_onnx(model, onnx_filename, opt.batch_size_onnx, opt.dynamic) + + # Print summary + if not opt.skip_eval_accuracy: + 
print("Accuracy summary:")
+        table = PrettyTable(['Stage', 'mAP@0.5, mAP@0.5:0.95'])
+        table.align['Stage'] = "l"
+        table.add_row(['Initial', "{:.3f}, {:.3f}".format(map50_initial, map_initial)])
+        table.add_row(['Calibrated', "{:.3f}, {:.3f}".format(map50_calibrated, map_calibrated)])
+        table.add_row(['Finetuned', "{:.3f}, {:.3f}".format(map50_finetuned, map_finetuned)])
+        print(table)
+
+    return 0
+
+
+def export_onnx(model, onnx_filename, batch_onnx, dynamic_shape):
+    model.model[-1].export = True  # Do not export Detect() layer grid
+    model.eval()
+
+    # We have to shift to PyTorch's own fake quant ops before exporting the model to ONNX
+    quant_nn.TensorQuantizer.use_fb_fake_quant = True
+
+    # Export ONNX with the requested batch size
+    print("Creating ONNX file: " + onnx_filename)
+    dummy_input = torch.randn(batch_onnx, 3, 640, 640, device='cuda')  # TODO: switch input dims by model
+
+    try:
+        import onnx
+        torch.onnx.export(model, dummy_input, onnx_filename, verbose=False, opset_version=13, input_names=['images'],
+                          output_names=['output_0', 'output_1', 'output_2'],
+                          dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'}} if dynamic_shape else None,
+                          do_constant_folding=True)
+
+        # Check the exported model
+        onnx_model = onnx.load(onnx_filename)  # load onnx model
+        onnx.checker.check_model(onnx_model)  # check onnx model
+        print('ONNX export success, saved as %s' % onnx_filename)
+    except ValueError:
+        warnings.warn(
+            UserWarning("Per-channel quantization is not yet supported in Pytorch/ONNX RT (requires ONNX opset 13)"))
+        print("Failed to export to ONNX")
+        return False
+
+    # Restore pytorch-quantization's own fake quant mechanism
+    quant_nn.TensorQuantizer.use_fb_fake_quant = False
+    # Restore the model to train/test mode, use Detect() layer grid
+    model.model[-1].export = False
+
+    return True
+
+
+def calibrate_model(model, model_name, data_loader, num_calib_batch, calibrator, hist_percentile, out_dir, device):
+    """
+    Feed data to the network and calibrate.
+ Arguments: + model: classification model + model_name: name to use when creating state files + data_loader: calibration data set + num_calib_batch: amount of calibration passes to perform + calibrator: type of calibration to use (max/histogram) + hist_percentile: percentiles to be used for historgram calibration + out_dir: dir to save state files in + """ + + if num_calib_batch > 0: + print("Calibrating model") + with torch.no_grad(): + collect_stats(model, data_loader, num_calib_batch, device) + + if not calibrator == "histogram": + compute_amax(model, method="max") + calib_output = os.path.join( + out_dir, + F"{model_name}-max-{num_calib_batch*data_loader.batch_size}.pth") + + ckpt = {'model': deepcopy(model)} + torch.save(ckpt, calib_output) + else: + for percentile in hist_percentile: + print(F"{percentile} percentile calibration") + compute_amax(model, method="percentile") + calib_output = os.path.join( + out_dir, + F"{model_name}-percentile-{percentile}-{num_calib_batch*data_loader.batch_size}.pth") + + ckpt = {'model': deepcopy(model)} + torch.save(ckpt, calib_output) + + for method in ["mse", "entropy"]: + print(F"{method} calibration") + compute_amax(model, method=method) + + calib_output = os.path.join( + out_dir, + F"{model_name}-{method}-{num_calib_batch*data_loader.batch_size}.pth") + + ckpt = {'model': deepcopy(model)} + torch.save(ckpt, calib_output) + +def collect_stats(model, data_loader, num_batches, device): + """Feed data to the network and collect statistics""" + # Enable calibrators + for name, module in model.named_modules(): + if isinstance(module, quant_nn.TensorQuantizer): + if module._calibrator is not None: + module.disable_quant() + module.enable_calib() + else: + module.disable() + + # Feed data to the network for collecting stats + for i, (img, _, _, _) in tqdm(enumerate(data_loader), total=num_batches): + img = img.to(device, non_blocking=True) + img = img.float() # uint8 to fp16/32 + img /= 255.0 # 0 - 255 to 0.0 - 1.0 + model(img) + if i >= num_batches: + break + + # Disable calibrators + for name, module in model.named_modules(): + if isinstance(module, quant_nn.TensorQuantizer): + if module._calibrator is not None: + module.enable_quant() + module.disable_calib() + else: + module.enable() + +def compute_amax(model, **kwargs): + # Load calib result + for name, module in model.named_modules(): + if isinstance(module, quant_nn.TensorQuantizer): + if module._calibrator is not None: + if isinstance(module._calibrator, calib.MaxCalibrator): + module.load_calib_amax() + else: + module.load_calib_amax(**kwargs) + # print(F"{name:40}: {module}") + model.cuda() + + +def build_sensitivity_profile(model, opt, testloader): + quant_layer_names = [] + for name, module in model.named_modules(): + if isinstance(module, quant_nn.TensorQuantizer): + module.disable() + layer_name = name.replace("._input_quantizer", "").replace("._weight_quantizer", "") + if layer_name not in quant_layer_names: + quant_layer_names.append(layer_name) + print(F"{len(quant_layer_names)} quantized layers found.") + + # Build sensitivity profile + quant_layer_sensitivity = {} + for i, quant_layer in enumerate(quant_layer_names): + print(F"Enable {quant_layer}") + for name, module in model.named_modules(): + if isinstance(module, quant_nn.TensorQuantizer) and quant_layer in name: + module.enable() + print(F"{name:40}: {module}") + + # Eval the model + map50, map50_95 = evaluate_accuracy(model, opt, testloader) + print(F"mAP@IoU=0.50: {map50}, mAP@IoU=0.50:0.95: {map50_95}") + 
quant_layer_sensitivity[quant_layer] = opt.accu_tolerance - (map50 + map50_95) + + for name, module in model.named_modules(): + if isinstance(module, quant_nn.TensorQuantizer) and quant_layer in name: + module.disable() + print(F"{name:40}: {module}") + + # Skip most sensitive layers until accuracy target is met + for name, module in model.named_modules(): + if isinstance(module, quant_nn.TensorQuantizer): + module.enable() + quant_layer_sensitivity = collections.OrderedDict(sorted(quant_layer_sensitivity.items(), key=lambda x: x[1])) + print(quant_layer_sensitivity) + + skipped_layers = [] + for quant_layer, _ in quant_layer_sensitivity.items(): + for name, module in model.named_modules(): + if isinstance(module, quant_nn.TensorQuantizer): + if quant_layer in name: + print(F"Disable {name}") + if not quant_layer in skipped_layers: + skipped_layers.append(quant_layer) + module.disable() + map50, map50_95 = evaluate_accuracy(model, opt, testloader) + if (map50 + map50_95) >= opt.accu_tolerance - 0.05: + print(F"Accuracy tolerance {opt.accu_tolerance} is met by skipping {len(skipped_layers)} sensitive layers.") + print(skipped_layers) + onnx_filename = opt.ckpt_path.replace('.pt', F'_skip{len(skipped_layers)}.onnx') + export_onnx(model, onnx_filename, opt.batch_size_onnx, opt.dynamic) + return + raise ValueError(f"Accuracy tolerance {opt.accu_tolerance} can not be met with any layer quantized!") + + +def skip_sensitive_layers(model, opt, testloader): + print('Skip the sensitive layers.') + # Sensitivity layers for yolov5s + skipped_layers = ['model.1.conv', # the first conv + 'model.2.cv1.conv', # the second conv + 'model.24.m.2', # detect layer + 'model.24.m.1', # detect layer + ] + + for name, module in model.named_modules(): + if isinstance(module, quant_nn.TensorQuantizer): + layer_name = name.replace("._input_quantizer", "").replace("._weight_quantizer", "") + if layer_name in skipped_layers: + print(F"Disable {name}") + module.disable() + + map50, map50_95 = evaluate_accuracy(model, opt, testloader) + print(F"mAP@IoU=0.50: {map50}, mAP@IoU=0.50:0.95: {map50_95}") + + onnx_filename = opt.ckpt_path.replace('.pt', F'_skip{len(skipped_layers)}.onnx') + export_onnx(model, onnx_filename, opt.batch_size_onnx, opt.dynamic) + return + + +if __name__ == '__main__': + res = main(sys.argv[1:]) + exit(res)
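
A possible end-to-end run of this flow, using only the default paths and flags from the argument parsers above; the actual checkpoint and output file names depend on the local setup (yolo_quant_flow.py writes the ONNX file next to --ckpt-path, and onnx_to_trt.py writes the engine next to the ONNX file):

# QDQ insertion, calibration, QAT fine-tuning and ONNX export
python yolo_quant_flow.py --data data/coco.yaml --cfg models/yolov5s.yaml \
    --ckpt-path weights/yolov5s.pt --hyp data/hyp.qat.yaml

# Build a TensorRT engine from the QAT ONNX model; with --qat no INT8 calibrator is needed
python trt/onnx_to_trt.py -m weights/yolov5s.onnx -d int8 --qat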