From fe38440cd888febc83a86996cac48a65fccd79b9 Mon Sep 17 00:00:00 2001 From: Linaom1214 Date: Fri, 22 Jul 2022 21:24:13 +0800 Subject: [PATCH] End2end (#61) * export end2end onnx model * fixbug * add web demo (#58) * Update README.md * main code update yolov7-tiny deploy cfg * main code update yolov7-tiny training cfg * main code @liguagua752109150 https://github.com/WongKinYiu/yolov7/issues/33#issuecomment-1178669212 * main code @albertfaromatics https://github.com/WongKinYiu/yolov7/issues/35#issuecomment-1178800685 * main code update link * main code add custom hyp * main code update default activation function * main code update path * main figure add more tasks * main code update readme * main code update reparameterization * Update README.md * main code update readme * main code update aux training * main code update aux training * main code update aux training * main figure update yolov7 prediction * main code update readme * main code rename * main code rename * main code rename * main code rename * main code update readme * main code update visualization * main code fix gain for train_aux * main code update loss * main code update instance segmentation demo * main code update keypoint detection demo * main code update pose demo * main code update pose * main code update pose * main code update pose * main code update pose * main code update trace * Update README.md * main code fix ciou * main code fix nan of aux training https://github.com/WongKinYiu/yolov7/issues/250#issue-1312356380 @hudingding * support onnx to tensorrt convert (#114) * fuse IDetect (#148) * Fixes #199 (#203) * minor fix * resolve conflict * resolve conflict * resolve conflict * resolve conflict * resolve conflict * resolve * resolve * resolve * resolve Co-authored-by: AK391 <81195143+AK391@users.noreply.github.com> Co-authored-by: Alexey Co-authored-by: Kin-Yiu, Wong <102582011@cc.ncu.edu.tw> Co-authored-by: linghu8812 <36389436+linghu8812@users.noreply.github.com> Co-authored-by: Alexander <84590713+SashaAlderson@users.noreply.github.com> Co-authored-by: Ben Raymond Co-authored-by: AlexeyAB84 --- .idea/.gitignore | 3 + .idea/inspectionProfiles/Project_Default.xml | 46 ++++++ .../inspectionProfiles/profiles_settings.xml | 6 + .idea/misc.xml | 4 + .idea/modules.xml | 8 + .idea/vcs.xml | 6 + .idea/yolov7.iml | 12 ++ README.md | 7 + export.py | 35 ++-- models/common.py | 2 +- models/yolo.py | 22 ++- utils/add_nms.py | 151 ++++++++++++++++++ 12 files changed, 284 insertions(+), 18 deletions(-) create mode 100644 .idea/.gitignore create mode 100644 .idea/inspectionProfiles/Project_Default.xml create mode 100644 .idea/inspectionProfiles/profiles_settings.xml create mode 100644 .idea/misc.xml create mode 100644 .idea/modules.xml create mode 100644 .idea/vcs.xml create mode 100644 .idea/yolov7.iml create mode 100644 utils/add_nms.py diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000000..359bb5307e --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,3 @@ +# 默认忽略的文件 +/shelf/ +/workspace.xml diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000000..28c863b6ff --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,46 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000000..105ce2da2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000000..d1e22ecb89 --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000000..ecffae6dc6 --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000000..94a25f7f4c --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/.idea/yolov7.iml b/.idea/yolov7.iml new file mode 100644 index 0000000000..8b8c395472 --- /dev/null +++ b/.idea/yolov7.iml @@ -0,0 +1,12 @@ + + + + + + + + + + \ No newline at end of file diff --git a/README.md b/README.md index 4d1f48c9b6..1892b1b067 100644 --- a/README.md +++ b/README.md @@ -151,6 +151,13 @@ python detect.py --weights yolov7.pt --conf 0.25 --img-size 640 --source inferen + +## Export +Use the args `--include-nms` can to export end to end onnx model which include the `EfficientNMS`. +```shell +python models/export.py --weights yolov7.pt --grid --include-nms +``` + ## Citation ``` diff --git a/export.py b/export.py index 06dfc942c3..fd4975f6f9 100644 --- a/export.py +++ b/export.py @@ -12,6 +12,7 @@ from utils.activations import Hardswish, SiLU from utils.general import set_logging, check_img_size from utils.torch_utils import select_device +from utils.add_nms import RegisterNMS if __name__ == '__main__': parser = argparse.ArgumentParser() @@ -22,6 +23,7 @@ parser.add_argument('--grid', action='store_true', help='export Detect() layer grid') parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') parser.add_argument('--simplify', action='store_true', help='simplify onnx model') + parser.add_argument('--include-nms', action='store_true', help='export end2end onnx') opt = parser.parse_args() opt.img_size *= 2 if len(opt.img_size) == 1 else 1 # expand print(opt) @@ -52,7 +54,9 @@ # m.forward = m.forward_export # assign forward (optional) model.model[-1].export = not opt.grid # set Detect() layer grid export y = model(img) # dry run - + if opt.include_nms: + model.model[-1].include_nms = True + y = None # TorchScript export try: print('\nStarting TorchScript export with torch %s...' % torch.__version__) @@ -75,16 +79,23 @@ dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'}, # size(1,3,640,640) 'output': {0: 'batch', 2: 'y', 3: 'x'}} if opt.dynamic else None) - # Checks - onnx_model = onnx.load(f) # load onnx model - onnx.checker.check_model(onnx_model) # check onnx model - - # # Metadata - # d = {'stride': int(max(model.stride))} - # for k, v in d.items(): - # meta = onnx_model.metadata_props.add() - # meta.key, meta.value = k, str(v) - # onnx.save(onnx_model, f) + if opt.include_nms: + print('Registering NMS plugin...') + mo = RegisterNMS(f) + mo.register_nms() + mo.save(f) + else: + # Checks + onnx_model = onnx.load(f) # load onnx model + onnx.checker.check_model(onnx_model) # check onnx model + # print(onnx.helper.printable_graph(onnx_model.graph)) # print a human readable model + + # # Metadata + # d = {'stride': int(max(model.stride))} + # for k, v in d.items(): + # meta = onnx_model.metadata_props.add() + # meta.key, meta.value = k, str(v) + # onnx.save(onnx_model, f) if opt.simplify: try: @@ -95,11 +106,9 @@ assert check, 'assert check failed' except Exception as e: print(f'Simplifier failure: {e}') - # print(onnx.helper.printable_graph(onnx_model.graph)) # print a human readable model print('ONNX export success, saved as %s' % f) except Exception as e: print('ONNX export failure: %s' % e) - # CoreML export try: import coremltools as ct diff --git a/models/common.py b/models/common.py index 53e3f87193..111af708de 100644 --- a/models/common.py +++ b/models/common.py @@ -236,7 +236,7 @@ def forward(self, x): class ResX(Res): # ResNet bottleneck def __init__(self, c1, c2, shortcut=True, g=32, e=0.5): # ch_in, ch_out, shortcut, groups, expansion - super().__init__(c1, c2, shortcu, g, e) + super().__init__(c1, c2, shortcut, g, e) c_ = int(c2 * e) # hidden channels diff --git a/models/yolo.py b/models/yolo.py index 5f45aad886..5d2845f1f0 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -5,7 +5,7 @@ sys.path.append('./') # to run '$ python *.py' files in subdirectories logger = logging.getLogger(__name__) - +import torch from models.common import * from models.experimental import * from utils.autoanchor import check_anchor_order @@ -23,7 +23,7 @@ class Detect(nn.Module): stride = None # strides computed during build export = False # onnx export - + include_nms = False def __init__(self, nc=80, anchors=(), ch=()): # detection layer super(Detect, self).__init__() self.nc = nc # number of classes @@ -48,7 +48,6 @@ def forward(self, x): if not self.training: # inference if self.grid[i].shape[2:4] != x[i].shape[2:4]: self.grid[i] = self._make_grid(nx, ny).to(x[i].device) - y = x[i].sigmoid() if not torch.onnx.is_in_onnx_export(): y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy @@ -59,13 +58,28 @@ def forward(self, x): y = torch.cat((xy, wh, y[..., 4:]), -1) z.append(y.view(bs, -1, self.no)) - return x if self.training else (torch.cat(z, 1), x) + if self.include_nms: + z = self.convert(z) + + return x if self.training else (z, ) if self.include_nms else (torch.cat(z, 1), x) @staticmethod def _make_grid(nx=20, ny=20): yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)]) return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float() + def convert(self, z): + z = torch.cat(z, 1) + box = z[:, :, :4] + conf = z[:, :, 4:5] + score = z[:, :, 5:] + score *= conf + convert_matrix = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1], [-0.5, 0, 0.5, 0], [0, -0.5, 0, 0.5]], + dtype=torch.float32, + device=z.device) + box @= convert_matrix + return (box, score) + class IDetect(nn.Module): stride = None # strides computed during build diff --git a/utils/add_nms.py b/utils/add_nms.py new file mode 100644 index 0000000000..8cfa23919e --- /dev/null +++ b/utils/add_nms.py @@ -0,0 +1,151 @@ +import numpy as np +import onnx +from onnx import shape_inference +import onnx_graphsurgeon as gs +import logging + +LOGGER = logging.getLogger(__name__) + +class RegisterNMS(object): + def __init__( + self, + onnx_model_path: str, + precision: str = "fp32", + ): + + self.graph = gs.import_onnx(onnx.load(onnx_model_path)) + assert self.graph + LOGGER.info("ONNX graph created successfully") + # Fold constants via ONNX-GS that PyTorch2ONNX may have missed + self.graph.fold_constants() + self.precision = precision + self.batch_size = 1 + def infer(self): + """ + Sanitize the graph by cleaning any unconnected nodes, do a topological resort, + and fold constant inputs values. When possible, run shape inference on the + ONNX graph to determine tensor shapes. + """ + for _ in range(3): + count_before = len(self.graph.nodes) + + self.graph.cleanup().toposort() + try: + for node in self.graph.nodes: + for o in node.outputs: + o.shape = None + model = gs.export_onnx(self.graph) + model = shape_inference.infer_shapes(model) + self.graph = gs.import_onnx(model) + except Exception as e: + LOGGER.info(f"Shape inference could not be performed at this time:\n{e}") + try: + self.graph.fold_constants(fold_shapes=True) + except TypeError as e: + LOGGER.error( + "This version of ONNX GraphSurgeon does not support folding shapes, " + f"please upgrade your onnx_graphsurgeon module. Error:\n{e}" + ) + raise + + count_after = len(self.graph.nodes) + if count_before == count_after: + # No new folding occurred in this iteration, so we can stop for now. + break + + def save(self, output_path): + """ + Save the ONNX model to the given location. + Args: + output_path: Path pointing to the location where to write + out the updated ONNX model. + """ + self.graph.cleanup().toposort() + model = gs.export_onnx(self.graph) + onnx.save(model, output_path) + LOGGER.info(f"Saved ONNX model to {output_path}") + + def register_nms( + self, + *, + score_thresh: float = 0.25, + nms_thresh: float = 0.45, + detections_per_img: int = 100, + ): + """ + Register the ``EfficientNMS_TRT`` plugin node. + NMS expects these shapes for its input tensors: + - box_net: [batch_size, number_boxes, 4] + - class_net: [batch_size, number_boxes, number_labels] + Args: + score_thresh (float): The scalar threshold for score (low scoring boxes are removed). + nms_thresh (float): The scalar threshold for IOU (new boxes that have high IOU + overlap with previously selected boxes are removed). + detections_per_img (int): Number of best detections to keep after NMS. + """ + + self.infer() + # Find the concat node at the end of the network + op_inputs = self.graph.outputs + op = "EfficientNMS_TRT" + attrs = { + "plugin_version": "1", + "background_class": -1, # no background class + "max_output_boxes": detections_per_img, + "score_threshold": score_thresh, + "iou_threshold": nms_thresh, + "score_activation": False, + "box_coding": 0, + } + + if self.precision == "fp32": + dtype_output = np.float32 + elif self.precision == "fp16": + dtype_output = np.float16 + else: + raise NotImplementedError(f"Currently not supports precision: {self.precision}") + + # NMS Outputs + output_num_detections = gs.Variable( + name="num_detections", + dtype=np.int32, + shape=[self.batch_size, 1], + ) # A scalar indicating the number of valid detections per batch image. + output_boxes = gs.Variable( + name="detection_boxes", + dtype=dtype_output, + shape=[self.batch_size, detections_per_img, 4], + ) + output_scores = gs.Variable( + name="detection_scores", + dtype=dtype_output, + shape=[self.batch_size, detections_per_img], + ) + output_labels = gs.Variable( + name="detection_classes", + dtype=np.int32, + shape=[self.batch_size, detections_per_img], + ) + + op_outputs = [output_num_detections, output_boxes, output_scores, output_labels] + + # Create the NMS Plugin node with the selected inputs. The outputs of the node will also + # become the final outputs of the graph. + self.graph.layer(op=op, name="batched_nms", inputs=op_inputs, outputs=op_outputs, attrs=attrs) + LOGGER.info(f"Created NMS plugin '{op}' with attributes: {attrs}") + + self.graph.outputs = op_outputs + + self.infer() + + def save(self, output_path): + """ + Save the ONNX model to the given location. + Args: + output_path: Path pointing to the location where to write + out the updated ONNX model. + """ + self.graph.cleanup().toposort() + model = gs.export_onnx(self.graph) + onnx.save(model, output_path) + LOGGER.info(f"Saved ONNX model to {output_path}") \ No newline at end of file