Skip to content

Commit

Permalink
Detect() and Segment() fixes for CoreML and Paddle (#9458)
Browse files Browse the repository at this point in the history
* Detect() and Segment() fixes for CoreML and Paddle

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
glenn-jocher and pre-commit-ci[bot] committed Sep 17, 2022
1 parent 6a9fffd commit 0608374
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,17 +64,17 @@ def forward(self, x):
if self.dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)

y = x[i].clone()
y[..., :5 + self.nc].sigmoid_()
if self.inplace:
y[..., 0:2] = (y[..., 0:2] * 2 + self.grid[i]) * self.stride[i] # xy
y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
else: # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
xy, wh, etc = y.split((2, 2, self.no - 4), 4) # tensor_split((2, 4, 5), 4) if torch 1.8.0
if isinstance(self, Segment): # (boxes + masks)
xy, wh, conf, mask = x[i].split((2, 2, self.nc + 1, self.no - self.nc - 5), 4)
xy = (xy.sigmoid() * 2 + self.grid[i]) * self.stride[i] # xy
wh = (wh.sigmoid() * 2) ** 2 * self.anchor_grid[i] # wh
y = torch.cat((xy, wh, conf.sigmoid(), mask), 4)
else: # Detect (boxes only)
xy, wh, conf = x[i].sigmoid().split((2, 2, self.nc + 1), 4)
xy = (xy * 2 + self.grid[i]) * self.stride[i] # xy
wh = (wh * 2) ** 2 * self.anchor_grid[i] # wh
y = torch.cat((xy, wh, etc), 4)
z.append(y.view(bs, -1, self.no))
y = torch.cat((xy, wh, conf), 4)
z.append(y.view(bs, self.na * nx * ny, self.no))

return x if self.training else (torch.cat(z, 1),) if self.export else (torch.cat(z, 1), x)

Expand Down

0 comments on commit 0608374

Please sign in to comment.