Skip to content

Commit

Permalink
Initialize bias into YoloHead (#67)
Browse files Browse the repository at this point in the history
* Initialize weights and bias in YoloHead

* Fix format in Docs

* Fix unittest

* Minor fixes
  • Loading branch information
zhiqwang committed Feb 22, 2021
1 parent e5e9012 commit 5825161
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 3 deletions.
3 changes: 2 additions & 1 deletion test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ def _init_test_yolo_head(self):
in_channels = self._get_in_channels()
num_anchors = self._get_num_anchors()
num_classes = self._get_num_classes()
box_head = YoloHead(in_channels, num_anchors, num_classes)
strides = self._get_strides()
box_head = YoloHead(in_channels, num_anchors, strides, num_classes)
return box_head

def test_yolo_head(self):
Expand Down
23 changes: 21 additions & 2 deletions yolort/models/box_head.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Modified from ultralytics/yolov5 by Zhiqiang Wang
import math
import torch
from torch import nn, Tensor

Expand All @@ -10,14 +11,31 @@


class YoloHead(nn.Module):
def __init__(self, in_channels: List[int], num_anchors: int, num_classes: int):
def __init__(self, in_channels: List[int], num_anchors: int, strides: List[int], num_classes: int):
super().__init__()
self.num_anchors = num_anchors # anchors
self.num_classes = num_classes
self.num_outputs = num_classes + 5 # number of outputs per anchor
self.strides = strides

self.head = nn.ModuleList(
nn.Conv2d(ch, self.num_outputs * self.num_anchors, 1) for ch in in_channels) # output conv

self._initialize_biases() # Init weights, biases

def _initialize_biases(self, cf=None):
"""
Initialize biases into YoloHead, cf is class frequency
Check section 3.3 in <https://arxiv.org/abs/1708.02002>
"""
for mi, s in zip(self.head, self.strides):
b = mi.bias.view(self.num_anchors, -1) # conv.bias(255) to (3,85)
# obj (8 objects per 640 image)
b.data[:, 4] += math.log(8 / (640 / s) ** 2)
# classes
b.data[:, 5:] += torch.log(cf / cf.sum()) if cf else math.log(0.6 / (self.num_classes - 0.99))
mi.bias = nn.Parameter(b.view(-1), requires_grad=True)

def get_result_from_head(self, features: Tensor, idx: int) -> Tensor:
"""
This is equivalent to self.head[idx](features),
Expand Down Expand Up @@ -199,7 +217,8 @@ def assign_targets_to_anchors(
# Append
a = targets_with_gain[:, 6].long() # anchor indices
# image, anchor, grid indices
indices.append((bc[0], a, grid_ij[:, 1].clamp_(0, gain[3] - 1), grid_ij[:, 0].clamp_(0, gain[2] - 1)))
indices.append((bc[0], a, grid_ij[:, 1].clamp_(0, gain[3] - 1),
grid_ij[:, 0].clamp_(0, gain[2] - 1)))
targets_box.append(torch.cat((grid_xy - grid_ij, grid_wh), 1)) # box
anchors_encode.append(anchors_per_layer[a]) # anchors
targets_cls.append(bc[1]) # class
Expand Down
1 change: 1 addition & 0 deletions yolort/models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(
head = YoloHead(
backbone.out_channels,
anchor_generator.num_anchors,
anchor_generator.strides,
num_classes,
)
self.head = head
Expand Down

0 comments on commit 5825161

Please sign in to comment.