diff --git a/detect.py b/detect.py index 53b63ebfb1..841926c72a 100644 --- a/detect.py +++ b/detect.py @@ -73,7 +73,7 @@ def detect(save_img=False): # Inference t1 = time_synchronized() - pred = model(img, augment=opt.augment)[0] + pred = model(img, augment=opt.augment) # Apply NMS pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms) diff --git a/models/export.py b/export.py similarity index 86% rename from models/export.py rename to export.py index dc12559416..06dfc942c3 100644 --- a/models/export.py +++ b/export.py @@ -21,6 +21,7 @@ parser.add_argument('--dynamic', action='store_true', help='dynamic ONNX axes') parser.add_argument('--grid', action='store_true', help='export Detect() layer grid') parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') + parser.add_argument('--simplify', action='store_true', help='simplify onnx model') opt = parser.parse_args() opt.img_size *= 2 if len(opt.img_size) == 1 else 1 # expand print(opt) @@ -68,6 +69,7 @@ print('\nStarting ONNX export with onnx %s...' % onnx.__version__) f = opt.weights.replace('.pt', '.onnx') # filename + model.eval() torch.onnx.export(model, img, f, verbose=False, opset_version=12, input_names=['images'], output_names=['classes', 'boxes'] if y is None else ['output'], dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'}, # size(1,3,640,640) @@ -76,6 +78,23 @@ # Checks onnx_model = onnx.load(f) # load onnx model onnx.checker.check_model(onnx_model) # check onnx model + + # # Metadata + # d = {'stride': int(max(model.stride))} + # for k, v in d.items(): + # meta = onnx_model.metadata_props.add() + # meta.key, meta.value = k, str(v) + # onnx.save(onnx_model, f) + + if opt.simplify: + try: + import onnxsim + + print('\nStarting to simplify ONNX...') + onnx_model, check = onnxsim.simplify(onnx_model) + assert check, 'assert check failed' + except Exception as e: + print(f'Simplifier failure: {e}') # print(onnx.helper.printable_graph(onnx_model.graph)) # print a human readable model print('ONNX export success, saved as %s' % f) except Exception as e: diff --git a/models/yolo.py b/models/yolo.py index 7e1b3da172..a3835e782e 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -50,11 +50,16 @@ def forward(self, x): self.grid[i] = self._make_grid(nx, ny).to(x[i].device) y = x[i].sigmoid() - y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy - y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh + if not torch.onnx.is_in_onnx_export(): + y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy + y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh + else: + xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy + wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh + y = torch.cat((xy, wh, y[..., 4:]), -1) z.append(y.view(bs, -1, self.no)) - return x if self.training else (torch.cat(z, 1), x) + return x if self.training else torch.cat(z, 1) @staticmethod def _make_grid(nx=20, ny=20):