Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[refactor] Training of densehead #6315

Merged
merged 31 commits into from
Oct 25, 2021
Merged
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
304150d
Refactor one-stage get_bboxes logic (#5317)
hhaAndroid Aug 11, 2021
38c0ecf
pull master
jshilong Sep 1, 2021
b6f5700
support onnx export for fcos
jshilong Sep 2, 2021
032ae46
support onnx export for fcos fsaf retina and ssd
jshilong Sep 2, 2021
21e5b34
resolve comments
jshilong Sep 4, 2021
9903e59
resolve comments
jshilong Sep 9, 2021
94ee765
add with nms
jshilong Sep 9, 2021
8570f4c
support cornernet
jshilong Sep 9, 2021
b42f395
resolve comments
jshilong Sep 11, 2021
a377b13
add default with nms
jshilong Sep 15, 2021
59a4fb6
pull refactor dense
jshilong Oct 11, 2021
7e56e59
fix trt arrange should be int
jshilong Oct 11, 2021
f4483f6
Merge branch 'refactor_dense' of https://github.com/open-mmlab/mmdete…
jshilong Oct 18, 2021
c7a0639
refactor anchor head anchor free head
jshilong Oct 19, 2021
5e1edf9
add dtype to single_level_grid_priors
jshilong Oct 19, 2021
cbd14f0
atss fcos autoassign
jshilong Oct 19, 2021
a362a38
fovea
jshilong Oct 19, 2021
ce8b15a
fsaf free anchor
jshilong Oct 19, 2021
7e9c91f
suport more
jshilong Oct 19, 2021
863c97e
suport more
jshilong Oct 19, 2021
17fee99
support all
jshilong Oct 19, 2021
a4eb3a7
resolve conversation
jshilong Oct 20, 2021
35ae641
fix point generator
jshilong Oct 20, 2021
f988124
fix device
jshilong Oct 20, 2021
3110071
pull refactor dense
jshilong Oct 25, 2021
352f0b1
change to distancecoder
jshilong Oct 25, 2021
7fbeef5
resolve conversation
jshilong Oct 25, 2021
4c3d31b
fix grid prior
jshilong Oct 25, 2021
3250e33
fix typos in autoassgin
jshilong Oct 25, 2021
b4fe4e6
fix typos
jshilong Oct 25, 2021
67dee09
fix doc
jshilong Oct 25, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mmdet/core/anchor/anchor_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,14 @@ class AnchorGenerator:
Examples:
>>> from mmdet.core import AnchorGenerator
>>> self = AnchorGenerator([16], [1.], [1.], [9])
>>> all_anchors = self.grid_anchors([(2, 2)], device='cpu')
>>> all_anchors = self.grid_priors([(2, 2)], device='cpu')
>>> print(all_anchors)
[tensor([[-4.5000, -4.5000, 4.5000, 4.5000],
[11.5000, -4.5000, 20.5000, 4.5000],
[-4.5000, 11.5000, 4.5000, 20.5000],
[11.5000, 11.5000, 20.5000, 20.5000]])]
>>> self = AnchorGenerator([16, 32], [1.], [1.], [9, 18])
>>> all_anchors = self.grid_anchors([(2, 2), (1, 1)], device='cpu')
>>> all_anchors = self.grid_priors([(2, 2), (1, 1)], device='cpu')
>>> print(all_anchors)
[tensor([[-4.5000, -4.5000, 4.5000, 4.5000],
[11.5000, -4.5000, 20.5000, 4.5000],
Expand Down
1 change: 1 addition & 0 deletions mmdet/core/anchor/point_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ def single_level_grid_priors(self,
if not with_stride:
shifts = torch.stack([shift_xx, shift_yy], dim=-1)
else:
# use `shape[0]` instead of `len(shift_xx)` for ONNX export
stride_w = shift_xx.new_full((shift_xx.shape[0], ),
jshilong marked this conversation as resolved.
Show resolved Hide resolved
stride_w).to(dtype)
stride_h = shift_xx.new_full((shift_yy.shape[0], ),
Expand Down
25 changes: 24 additions & 1 deletion mmdet/models/dense_heads/anchor_free_head.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from abc import abstractmethod

import torch
Expand Down Expand Up @@ -88,7 +89,13 @@ def __init__(self,
self.loss_cls = build_loss(loss_cls)
self.loss_bbox = build_loss(loss_bbox)
self.bbox_coder = build_bbox_coder(bbox_coder)

self.prior_generator = MlvlPointGenerator(strides)

# In order to keep a more general interface and be consistent with
# anchor_head. We can think of point like one anchor
self.num_base_priors = self.prior_generator.num_base_priors[0]

self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.conv_cfg = conv_cfg
Expand Down Expand Up @@ -278,7 +285,17 @@ def _get_points_single(self,
dtype,
device,
flatten=False):
"""Get points of a single scale level."""
"""Get points of a single scale level.

This function will be deprecated soon.
"""

warnings.warn(
'`_get_points_single` in `AnchorFreeHead` will be '
'deprecated soon, we support a multi level point generator now'
'you can get points of a single level feature map '
'with `self.prior_generator.single_level_grid_priors` ')

h, w = featmap_size
# First create Range with the default dtype, than convert to
# target `dtype` for onnx exporting.
Expand All @@ -301,6 +318,12 @@ def get_points(self, featmap_sizes, dtype, device, flatten=False):
Returns:
tuple: points of each image.
"""
warnings.warn(
'`get_points` in `AnchorFreeHead` will be '
'deprecated soon, we support a multi level point generator now'
'you can get points of all levels '
'with `self.prior_generator.grid_priors` ')

mlvl_points = []
for i in range(len(featmap_sizes)):
mlvl_points.append(
Expand Down
31 changes: 21 additions & 10 deletions mmdet/models/dense_heads/anchor_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,20 @@ def __init__(self,
self.fp16_enabled = False

self.prior_generator = build_prior_generator(anchor_generator)

# usually the numbers of anchors for each level are the same
# except SSD detectors
self.num_anchors = self.prior_generator.num_base_priors[0]
# except SSD detectors, so it is a int in most densehead and
# it will be a list of int in SSDHead
jshilong marked this conversation as resolved.
Show resolved Hide resolved
self.num_base_priors = self.prior_generator.num_base_priors[0]
self._init_layers()

@property
def num_anchors(self):
warnings.warn('DeprecationWarning: `num_anchors` is deprecated, '
'for consistency or also use '
'`num_base_priors` instead')
return self.prior_generator.num_base_priors[0]

@property
def anchor_generator(self):
warnings.warn('DeprecationWarning: anchor_generator is deprecated, '
Expand All @@ -108,8 +117,10 @@ def anchor_generator(self):
def _init_layers(self):
"""Initialize layers of the head."""
self.conv_cls = nn.Conv2d(self.in_channels,
self.num_anchors * self.cls_out_channels, 1)
self.conv_reg = nn.Conv2d(self.in_channels, self.num_anchors * 4, 1)
self.num_base_priors * self.cls_out_channels,
1)
self.conv_reg = nn.Conv2d(self.in_channels, self.num_base_priors * 4,
1)

def forward_single(self, x):
"""Forward feature of a single scale level.
Expand All @@ -120,9 +131,9 @@ def forward_single(self, x):
Returns:
tuple:
cls_score (Tensor): Cls scores for a single scale level \
the channels number is num_anchors * num_classes.
the channels number is num_base_priors * num_classes.
bbox_pred (Tensor): Box energies / deltas for a single scale \
level, the channels number is num_anchors * 4.
level, the channels number is num_base_priors * 4.
"""
cls_score = self.conv_cls(x)
bbox_pred = self.conv_reg(x)
Expand All @@ -140,10 +151,10 @@ def forward(self, feats):

- cls_scores (list[Tensor]): Classification scores for all \
scale levels, each is a 4D-tensor, the channels number \
is num_anchors * num_classes.
is num_base_priors * num_classes.
- bbox_preds (list[Tensor]): Box energies / deltas for all \
scale levels, each is a 4D-tensor, the channels number \
is num_anchors * 4.
is num_base_priors * 4.
"""
return multi_apply(self.forward_single, feats)

Expand All @@ -164,8 +175,8 @@ def get_anchors(self, featmap_sizes, img_metas, device='cuda'):

# since feature map sizes of all images are the same, we only compute
# anchors for one time
multi_level_anchors = self.prior_generator.grid_anchors(
featmap_sizes, device)
multi_level_anchors = self.prior_generator.grid_priors(
featmap_sizes, device=device)
anchor_list = [multi_level_anchors for _ in range(num_imgs)]

# for each image, we compute valid flags of multi level anchors
Expand Down
4 changes: 2 additions & 2 deletions mmdet/models/dense_heads/atss_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,9 @@ def _init_layers(self):
3,
padding=1)
self.atss_reg = nn.Conv2d(
self.feat_channels, self.num_anchors * 4, 3, padding=1)
self.feat_channels, self.num_base_priors * 4, 3, padding=1)
self.atss_centerness = nn.Conv2d(
self.feat_channels, self.num_anchors * 1, 3, padding=1)
self.feat_channels, self.num_base_priors * 1, 3, padding=1)
self.scales = nn.ModuleList(
[Scale(1.0) for _ in self.prior_generator.strides])

Expand Down
63 changes: 37 additions & 26 deletions mmdet/models/dense_heads/autoassign_head.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import bias_init_with_prob, normal_init
from mmcv.runner import force_fp32

from mmdet.core import distance2bbox, multi_apply
from mmdet.core import multi_apply
from mmdet.core.anchor.point_generator import MlvlPointGenerator
from mmdet.core.bbox import bbox_overlaps
from mmdet.models import HEADS
Expand Down Expand Up @@ -174,22 +176,6 @@ def init_weights(self):
normal_init(self.conv_cls, std=0.01, bias=bias_cls)
normal_init(self.conv_reg, std=0.01, bias=4.0)

def _get_points_single(self,
featmap_size,
stride,
dtype,
device,
flatten=False):
"""Almost the same as the implementation in fcos, we remove half stride
offset to align with the original implementation."""

y, x = super(FCOSHead,
self)._get_points_single(featmap_size, stride, dtype,
device)
points = torch.stack((x.reshape(-1) * stride, y.reshape(-1) * stride),
dim=-1)
return points

def forward_single(self, x, scale, stride):
"""Forward features of a single scale level.

Expand Down Expand Up @@ -349,8 +335,10 @@ def loss(self,
assert len(cls_scores) == len(bbox_preds) == len(objectnesses)
all_num_gt = sum([len(item) for item in gt_bboxes])
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
all_level_points = self.get_points(featmap_sizes, bbox_preds[0].dtype,
bbox_preds[0].device)
all_level_points = self.prior_generator.grid_priors(
featmap_sizes,
dtype=bbox_preds[0].dtype,
device=bbox_preds[0].device)
inside_gt_bbox_mask_list, bbox_targets_list = self.get_targets(
all_level_points, gt_bboxes)

Expand All @@ -364,7 +352,6 @@ def loss(self,
center_prior_weight_list.append(center_prior_weight)
temp_inside_gt_bbox_mask_list.append(inside_gt_bbox_mask)
inside_gt_bbox_mask_list = temp_inside_gt_bbox_mask_list

mlvl_points = torch.cat(all_level_points, dim=0)
bbox_preds = levels_to_images(bbox_preds)
cls_scores = levels_to_images(cls_scores)
Expand All @@ -374,17 +361,18 @@ def loss(self,
ious_list = []
num_points = len(mlvl_points)

for bbox_pred, gt_bboxe, inside_gt_bbox_mask in zip(
for bbox_pred, encoded_targets, inside_gt_bbox_mask in zip(
bbox_preds, bbox_targets_list, inside_gt_bbox_mask_list):
temp_num_gt = gt_bboxe.size(1)
temp_num_gt = encoded_targets.size(1)
expand_mlvl_points = mlvl_points[:, None, :].expand(
num_points, temp_num_gt, 2).reshape(-1, 2)
gt_bboxe = gt_bboxe.reshape(-1, 4)
encoded_targets = encoded_targets.reshape(-1, 4)
expand_bbox_pred = bbox_pred[:, None, :].expand(
num_points, temp_num_gt, 4).reshape(-1, 4)
decoded_bbox_preds = distance2bbox(expand_mlvl_points,
expand_bbox_pred)
decoded_target_preds = distance2bbox(expand_mlvl_points, gt_bboxe)
decoded_bbox_preds = self.bbox_coder.decode(
expand_mlvl_points, expand_bbox_pred)
decoded_target_preds = self.bbox_coder.decode(
expand_mlvl_points, encoded_targets)
with torch.no_grad():
ious = bbox_overlaps(
decoded_bbox_preds, decoded_target_preds, is_aligned=True)
Expand Down Expand Up @@ -511,3 +499,26 @@ def _get_target_single(self, gt_bboxes, points):
dtype=torch.bool)

return inside_gt_bbox_mask, bbox_targets

def _get_points_single(self,
featmap_size,
stride,
dtype,
device,
flatten=False):
"""Almost the same as the implementation in fcos, we remove half stride
offset to align with the original implementation.

This function will be deprecated soon.
"""
warnings.warn(
'`_get_points_single` in `AutoAssignHead` will be '
'deprecated soon, we support a multi level point generator now'
'you can get points of a single level feature map '
'with `self.prior_generator.single_level_grid_priors` ')
y, x = super(FCOSHead,
jshilong marked this conversation as resolved.
Show resolved Hide resolved
self)._get_points_single(featmap_size, stride, dtype,
device)
points = torch.stack((x.reshape(-1) * stride, y.reshape(-1) * stride),
dim=-1)
return points
28 changes: 16 additions & 12 deletions mmdet/models/dense_heads/base_dense_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ def get_bboxes(self,

featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
mlvl_priors = self.prior_generator.grid_priors(
featmap_sizes, device=cls_scores[0].device)
featmap_sizes,
dtype=cls_scores[0].device,
device=cls_scores[0].device)

result_list = []

Expand Down Expand Up @@ -118,8 +120,9 @@ def _get_bboxes_single(self,
levels of a single image, each item has shape
(num_priors * 1, H, W).
mlvl_priors (list[Tensor]): Each element in the list is
the priors of a single level in feature pyramid, has shape
(num_priors, 4).
the priors of a single level in feature pyramid. In all
anchor-based methods, it has shape (num_priors, 4). In
all anchor-free methods, it has shape (num_priors, 2).
jshilong marked this conversation as resolved.
Show resolved Hide resolved
img_meta (dict): Image meta info.
cfg (mmcv.Config): Test / postprocessing configuration,
if None, test_cfg would be used.
Expand Down Expand Up @@ -181,17 +184,17 @@ def _get_bboxes_single(self,
scores = cls_score.softmax(-1)[:, :-1]

# After https://github.com/open-mmlab/mmdetection/pull/6268/,
# this operation keeps fewer bboxes under the same `nms_pre`,
# there is no difference in performance for most models, if you
# find a slight drop in performance, You can set a larger
# this operation keeps fewer bboxes under the same `nms_pre`.
# There is no difference in performance for most models. If you
# find a slight drop in performance, you can set a larger
# `nms_pre` than before.
results = filter_scores_and_topk(
scores, cfg.score_thr, nms_pre,
dict(bbox_pred=bbox_pred, priors=priors))
scores, labels, keep_idxs, filter_results = results
scores, labels, keep_idxs, filtered_results = results

bbox_pred = filter_results['bbox_pred']
priors = filter_results['priors']
bbox_pred = filtered_results['bbox_pred']
priors = filtered_results['priors']

if with_score_factors:
score_factor = score_factor[keep_idxs]
Expand Down Expand Up @@ -380,9 +383,10 @@ def onnx_export(self,
num_levels = len(cls_scores)

featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
mlvl_priors = self.prior_generator.grid_priors(featmap_sizes,
bbox_preds[0].dtype,
bbox_preds[0].device)
mlvl_priors = self.prior_generator.grid_priors(
featmap_sizes,
dtype=bbox_preds[0].dtype,
device=bbox_preds[0].device)

mlvl_cls_scores = [cls_scores[i].detach() for i in range(num_levels)]
mlvl_bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)]
Expand Down
41 changes: 27 additions & 14 deletions mmdet/models/dense_heads/fcos_head.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
Expand Down Expand Up @@ -193,8 +195,10 @@ def loss(self,
"""
assert len(cls_scores) == len(bbox_preds) == len(centernesses)
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
all_level_points = self.get_points(featmap_sizes, bbox_preds[0].dtype,
bbox_preds[0].device)
all_level_points = self.prior_generator.grid_priors(
featmap_sizes,
dtype=bbox_preds[0].dtype,
device=bbox_preds[0].device)
labels, bbox_targets = self.get_targets(all_level_points, gt_bboxes,
gt_labels)

Expand Down Expand Up @@ -261,18 +265,6 @@ def loss(self,
loss_bbox=loss_bbox,
loss_centerness=loss_centerness)

def _get_points_single(self,
featmap_size,
stride,
dtype,
device,
flatten=False):
"""Get points according to feature map sizes."""
y, x = super()._get_points_single(featmap_size, stride, dtype, device)
points = torch.stack((x.reshape(-1) * stride, y.reshape(-1) * stride),
dim=-1) + stride // 2
return points

def get_targets(self, points, gt_bboxes_list, gt_labels_list):
"""Compute regression, classification and centerness targets for points
in multiple images.
Expand Down Expand Up @@ -438,3 +430,24 @@ def centerness_target(self, pos_bbox_targets):
left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * (
top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
return torch.sqrt(centerness_targets)

def _get_points_single(self,
featmap_size,
stride,
dtype,
device,
flatten=False):
"""Get points according to feature map size.

This function will be deprecated soon.
"""
warnings.warn(
'`_get_points_single` in `FCOSHead` will be '
'deprecated soon, we support a multi level point generator now'
'you can get points of a single level feature map '
'with `self.prior_generator.single_level_grid_priors` ')

y, x = super()._get_points_single(featmap_size, stride, dtype, device)
jshilong marked this conversation as resolved.
Show resolved Hide resolved
points = torch.stack((x.reshape(-1) * stride, y.reshape(-1) * stride),
jshilong marked this conversation as resolved.
Show resolved Hide resolved
dim=-1) + stride // 2
return points
Loading