diff --git a/models/export.py b/models/export.py index 06d51f29e2e4..b262df83bf21 100644 --- a/models/export.py +++ b/models/export.py @@ -28,6 +28,9 @@ attempt_download(opt.weights) model = torch.load(opt.weights, map_location=torch.device('cpu'))['model'].float() model.eval() + model.fuse() + + # Update model model.model[-1].export = True # set Detect() layer export=True y = model(img) # dry run @@ -47,7 +50,6 @@ print('\nStarting ONNX export with onnx %s...' % onnx.__version__) f = opt.weights.replace('.pt', '.onnx') # filename - model.fuse() # only for ONNX torch.onnx.export(model, img, f, verbose=False, opset_version=12, input_names=['images'], output_names=['classes', 'boxes'] if y is None else ['output']) diff --git a/models/yolo.py b/models/yolo.py index bba015a80a7e..f1c3c3f9084a 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -163,7 +163,7 @@ def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers if type(m) is Conv: m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatability m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv - m.bn = None # remove batchnorm + delattr(m, 'bn') # remove batchnorm m.forward = m.fuseforward # update forward self.info() return self