Skip to content

Commit

Permalink
added anchors and excluded postprocess from detect
Browse files Browse the repository at this point in the history
  • Loading branch information
franklin-degirum committed May 6, 2023
1 parent c3e4e94 commit e37b631
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 8 deletions.
7 changes: 6 additions & 1 deletion export.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,8 @@ def export_onnx(model, im, file, opset, dynamic, simplify, prefix=colorstr('ONNX
onnx.checker.check_model(model_onnx) # check onnx model

# Metadata
d = {'stride': int(max(model.stride)), 'names': model.names}
# d = {'stride': int(max(model.stride)), 'names': model.names}
d = {'stride': int(max(model.stride)), 'names': dict(model.names), 'anchors':model.model[-1].anchors.numpy().tolist()}
for k, v in d.items():
meta = model_onnx.metadata_props.add()
meta.key, meta.value = k, str(v)
Expand Down Expand Up @@ -674,6 +675,7 @@ def run(
topk_all=100, # TF.js NMS: topk for all classes to keep
iou_thres=0.45, # TF.js NMS: IoU threshold
conf_thres=0.25, # TF.js NMS: confidence threshold
exclude_postprocess_detect=False, # onnx export excludes postprocessing for detection models
):
t = time.time()
include = [x.lower() for x in include] # to lowercase
Expand Down Expand Up @@ -707,6 +709,8 @@ def run(
m.inplace = inplace
m.dynamic = dynamic
m.export = True
m.exclude_postprocess_detect = exclude_postprocess_detect


for _ in range(2):
y = model(im) # dry runs
Expand Down Expand Up @@ -798,6 +802,7 @@ def parse_opt(known=False):
parser.add_argument('--topk-all', type=int, default=100, help='TF.js NMS: topk for all classes to keep')
parser.add_argument('--iou-thres', type=float, default=0.45, help='TF.js NMS: IoU threshold')
parser.add_argument('--conf-thres', type=float, default=0.25, help='TF.js NMS: confidence threshold')
parser.add_argument('--exclude-postprocess-detect', action='store_true', help='onnx export excludes postprocessing for detection models')
parser.add_argument(
'--include',
nargs='+',
Expand Down
20 changes: 13 additions & 7 deletions models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@
except ImportError:
thop = None


class Detect(nn.Module):
# YOLOv5 Detect head for detection models
stride = None # strides computed during build
dynamic = False # force grid reconstruction
export = False # export mode
exclude_postprocess_detect = False # onnx export excludes postprocess

def __init__(self, nc=80, anchors=(), ch=(), inplace=True): # detection layer
super().__init__()
Expand All @@ -58,23 +58,29 @@ def forward(self, x):
for i in range(self.nl):
x[i] = self.m[i](x[i]) # conv
bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

if not self.training: # inference
if self.dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)

if isinstance(self, Segment): # (boxes + masks)
x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
xy, wh, conf, mask = x[i].split((2, 2, self.nc + 1, self.no - self.nc - 5), 4)
xy = (xy.sigmoid() * 2 + self.grid[i]) * self.stride[i] # xy
wh = (wh.sigmoid() * 2) ** 2 * self.anchor_grid[i] # wh
y = torch.cat((xy, wh, conf.sigmoid(), mask), 4)
z.append(y.view(bs, self.na * nx * ny, self.no))
else: # Detect (boxes only)
xy, wh, conf = x[i].sigmoid().split((2, 2, self.nc + 1), 4)
xy = (xy * 2 + self.grid[i]) * self.stride[i] # xy
wh = (wh * 2) ** 2 * self.anchor_grid[i] # wh
y = torch.cat((xy, wh, conf), 4)
z.append(y.view(bs, self.na * nx * ny, self.no))
if self.exclude_postprocess_detect and self.export:
x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 3, 4, 1, 2).contiguous() # Mehrdad: Onnx
z.append(x[i].view(bs, self.na * nx * ny, self.no))
else:
x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
xy, wh, conf = x[i].sigmoid().split((2, 2, self.nc + 1), 4)
xy = (xy * 2 + self.grid[i]) * self.stride[i] # xy
wh = (wh * 2) ** 2 * self.anchor_grid[i] # wh
y = torch.cat((xy, wh, conf), 4)
z.append(y.view(bs, self.na * nx * ny, self.no))

return x if self.training else (torch.cat(z, 1),) if self.export else (torch.cat(z, 1), x)

Expand Down

0 comments on commit e37b631

Please sign in to comment.