Skip to content

Commit

Permalink
check_anchor_order() update
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher committed Jun 21, 2020
1 parent 1f1917e commit fc171e2
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
3 changes: 2 additions & 1 deletion models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,9 @@ def __init__(self, model_cfg='yolov5s.yaml', ch=3, nc=None): # model, input cha

# Build strides, anchors
m = self.model[-1] # Detect()
m.stride = torch.tensor([64 / x.shape[-2] for x in self.forward(torch.zeros(1, ch, 64, 64))]) # forward
m.stride = torch.tensor([128 / x.shape[-2] for x in self.forward(torch.zeros(1, ch, 128, 128))]) # forward
m.anchors /= m.stride.view(-1, 1, 1)
check_anchor_order(m)
self.stride = m.stride

# Init weights, biases
Expand Down
14 changes: 13 additions & 1 deletion utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ def check_anchors(dataset, model, thr=4.0, imgsz=640):
print('\nAnalyzing anchors... ', end='')
m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1] # Detect()
shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)])).float() # wh
scale = np.random.uniform(0.9, 1.1, size=(shapes.shape[0], 1)) # augment scale
wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes * scale, dataset.labels)])).float() # wh

def metric(k): # compute metric
r = wh[:, None] / k[None]
Expand All @@ -77,12 +78,23 @@ def metric(k): # compute metric
new_anchors = torch.tensor(new_anchors, device=m.anchors.device).type_as(m.anchors)
m.anchor_grid[:] = new_anchors.clone().view_as(m.anchor_grid) # for inference
m.anchors[:] = new_anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1) # loss
check_anchor_order(m)
print('New anchors saved to model. Update model *.yaml to use these anchors in the future.')
else:
print('Original anchors better than new anchors. Proceeding with original anchors.')
print('') # newline


def check_anchor_order(m):
# Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary
a = m.anchor_grid.prod(-1).view(-1) # anchor area
da = a[-1] - a[0] # delta a
ds = m.stride[-1] - m.stride[0] # delta s
if da.sign() != ds.sign(): # same order
m.anchors[:] = m.anchors.flip(0)
m.anchor_grid[:] = m.anchor_grid.flip(0)


def check_file(file):
# Searches for file if not found locally
if os.path.isfile(file):
Expand Down

0 comments on commit fc171e2

Please sign in to comment.