Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ONNX export using --grid --simplify --dynamic simultaneously #2982

Merged
merged 13 commits into from
May 3, 2021
10 changes: 6 additions & 4 deletions models/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
parser.add_argument('--weights', type=str, default='./yolov5s.pt', help='weights path')
parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='image size') # height, width
parser.add_argument('--batch-size', type=int, default=1, help='batch size')
parser.add_argument('--grid', action='store_true', help='export Detect() layer grid')
parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--inplace', action='store_true', help='inplace ops') # inplace = True
parser.add_argument('--dynamic', action='store_true', help='dynamic ONNX axes') # ONNX-only
parser.add_argument('--simplify', action='store_true', help='simplify ONNX model') # ONNX-only
opt = parser.parse_args()
Expand Down Expand Up @@ -56,9 +56,11 @@
m.act = Hardswish()
elif isinstance(m.act, nn.SiLU):
m.act = SiLU()
# elif isinstance(m, models.yolo.Detect):
# m.forward = m.forward_export # assign forward (optional)
model.model[-1].export = not opt.grid # set Detect() layer grid export
elif isinstance(m, models.yolo.Detect):
m.inplace = opt.inplace
m.onnx_dynamic = opt.dynamic
# m.forward = m.forward_export # assign forward (optional)

for _ in range(2):
y = model(img) # dry runs
print(f"\n{colorstr('PyTorch:')} starting from {opt.weights} ({file_size(opt.weights):.1f} MB)")
Expand Down
7 changes: 3 additions & 4 deletions models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

class Detect(nn.Module):
stride = None # strides computed during build
export = False # onnx export
onnx_dynamic = False # ONNX export parameter

def __init__(self, nc=80, anchors=(), ch=(), inplace=True): # detection layer
super(Detect, self).__init__()
Expand All @@ -42,14 +42,13 @@ def __init__(self, nc=80, anchors=(), ch=(), inplace=True): # detection layer
def forward(self, x):
# x = x.copy() # for profiling
z = [] # inference output
self.training |= self.export
for i in range(self.nl):
x[i] = self.m[i](x[i]) # conv
bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

if not self.training: # inference
if self.grid[i].shape[2:4] != x[i].shape[2:4]:
if self.grid[i].shape[2:4] != x[i].shape[2:4] or self.onnx_dynamic:
self.grid[i] = self._make_grid(nx, ny).to(x[i].device)

y = x[i].sigmoid()
Expand All @@ -58,7 +57,7 @@ def forward(self, x):
y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
else: # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i].view(bs, self.na, 1, 1, 2) # wh
glenn-jocher marked this conversation as resolved.
Show resolved Hide resolved
y = torch.cat((xy, wh, y[..., 4:]), -1)
z.append(y.view(bs, -1, self.no))

Expand Down