diff --git a/export.py b/export.py index e73715ea13e9..574bf8d9ed61 100644 --- a/export.py +++ b/export.py @@ -477,6 +477,7 @@ def run( if isinstance(m, Detect): m.inplace = inplace m.onnx_dynamic = dynamic + m.export = True if hasattr(m, 'forward_export'): m.forward = m.forward_export # assign custom forward (optional) diff --git a/models/yolo.py b/models/yolo.py index 3dd5fe9dcd25..fee5e932fd4d 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -37,6 +37,7 @@ class Detect(nn.Module): stride = None # strides computed during build onnx_dynamic = False # ONNX export parameter + export = False # export mode def __init__(self, nc=80, anchors=(), ch=(), inplace=True): # detection layer super().__init__() @@ -72,7 +73,7 @@ def forward(self, x): y = torch.cat((xy, wh, conf), 4) z.append(y.view(bs, -1, self.no)) - return x if self.training else (torch.cat(z, 1), x) + return x if self.training else (torch.cat(z, 1),) if self.export else (torch.cat(z, 1), x) def _make_grid(self, nx=20, ny=20, i=0): d = self.anchors[i].device