Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix export bugs #189

Merged
merged 2 commits into from
Aug 5, 2022
Merged

fix export bugs #189

merged 2 commits into from
Aug 5, 2022

Conversation

ziqi-jin
Copy link
Contributor

@ziqi-jin ziqi-jin commented Jul 27, 2022

你好,我在使用YOLOv5-Lite导出ONNX模型时遇到了问题,问题和修改如下

  1. export.py 文件默认会使用 opt.concat == True,这样就会在推理阶段使用m.cat_forward,但是opt.concat 的类型是str,这就导致我设置为False 也是会使用m.cat_forward
  • 修改:我修改了 opt.concat 的数据类型,同样默认也是True
  1. 我希望使用opt.concat == False 的情况,让导出过程可以走 m.forward,但是由于models\yolo.py的Detect类中有两行代码是inplace操作,虽然可以导出ONNX文件,但是最后的结果是错误的。代码为
                y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
                y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh

因此我对代码进行了修改,修改为

                if not torch.onnx.is_in_onnx_export():
                    y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
                    y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                else:
                    xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
                    wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i].data  # wh
                    y = torch.cat((xy, wh, y[..., 4:]), -1)

此外,还对_make_grid的判断进行了修改,新的完整处理逻辑如下如下

        for i in range(self.nl):
            x[i] = self.m[i](x[i])  # conv
            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

            if not self.training:  # inference
                if torch.onnx.is_in_onnx_export():
                    self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
                elif self.grid[i].shape[2:4] != x[i].shape[2:4]:
                    self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
                logits = x[i][..., 5:]

                y = x[i].sigmoid()
                if not torch.onnx.is_in_onnx_export():
                    y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
                    y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                else:
                    xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
                    wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i].data  # wh
                    y = torch.cat((xy, wh, y[..., 4:]), -1)
                z.append(y.view(bs, -1, self.no))

这样导出的ONNX文件可以正确预测。

@ppogg ppogg merged commit 690bb11 into ppogg:master Aug 5, 2022
@ppogg
Copy link
Owner

ppogg commented Aug 5, 2022

good job!

@ppogg ppogg mentioned this pull request Aug 5, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants