Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement] RepVGG for YOLOX-PAI for dev-1.x #1126

Merged
merged 1 commit into from
Nov 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 91 additions & 12 deletions mmcls/models/backbones/repvgg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
import torch
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer
from mmcv.cnn import (ConvModule, build_activation_layer, build_conv_layer,
build_norm_layer)
from mmengine.model import BaseModule, Sequential
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
from torch import nn

from mmcls.registry import MODELS
from ..utils.se_layer import SELayer
Expand Down Expand Up @@ -254,6 +256,51 @@ def _norm_to_conv3x3(self, branch_nrom):
return tmp_conv3x3


class MTSPPF(BaseModule):
"""MTSPPF block for YOLOX-PAI RepVGG backbone.

Args:
in_channels (int): The input channels of the block.
out_channels (int): The output channels of the block.
norm_cfg (dict): dictionary to construct and config norm layer.
Default: dict(type='BN').
act_cfg (dict): Config dict for activation layer.
Default: dict(type='ReLU').
kernel_size (int): Kernel size of pooling. Default: 5.
"""

def __init__(self,
in_channels,
out_channels,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
kernel_size=5):
super().__init__()
hidden_features = in_channels // 2 # hidden channels
self.conv1 = ConvModule(
in_channels,
hidden_features,
1,
stride=1,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.conv2 = ConvModule(
hidden_features * 4,
out_channels,
1,
stride=1,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.maxpool = nn.MaxPool2d(
kernel_size=kernel_size, stride=1, padding=kernel_size // 2)

def forward(self, x):
x = self.conv1(x)
y1 = self.maxpool(x)
y2 = self.maxpool(y1)
return self.conv2(torch.cat([x, y1, y2, self.maxpool(y2)], 1))


@MODELS.register_module()
class RepVGG(BaseBackbone):
"""RepVGG backbone.
Expand All @@ -262,17 +309,22 @@ class RepVGG(BaseBackbone):
<https://arxiv.org/abs/2101.03697>`_

Args:
arch (str | dict): The parameter of RepVGG.
If it's a dict, it should contain the following keys:

arch (str | dict): RepVGG architecture. If use string,
choose from 'A0', 'A1`', 'A2', 'B0', 'B1', 'B1g2', 'B1g4', 'B2'
, 'B2g2', 'B2g4', 'B3', 'B3g2', 'B3g4' or 'D2se'. If use dict,
it should have below keys:
- num_blocks (Sequence[int]): Number of blocks in each stage.
- width_factor (Sequence[float]): Width deflator in each stage.
- group_layer_map (dict | None): RepVGG Block that declares
the need to apply group convolution.
- se_cfg (dict | None): Se Layer config
- se_cfg (dict | None): Se Layer config.
- stem_channels (int, optional): The stem channels, the final
stem channels will be
``min(stem_channels, base_channels*width_factor[0])``.
If not set here, 64 is used by default in the code.
in_channels (int): Number of input image channels. Default: 3.
base_channels (int): Base channels of RepVGG backbone, work
with width_factor together. Default: 64.
base_channels (int): Base channels of RepVGG backbone, work with
width_factor together. Defaults to 64.
out_indices (Sequence[int]): Output from which stages. Default: (3, ).
strides (Sequence[int]): Strides of the first block of each stage.
Default: (2, 2, 2, 2).
Expand All @@ -292,6 +344,7 @@ class RepVGG(BaseBackbone):
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Default: False.
add_ppf (bool): Whether to use the MTSPPF block. Default: False.
init_cfg (dict or list[dict], optional): Initialization config dict.
"""

Expand Down Expand Up @@ -323,7 +376,8 @@ class RepVGG(BaseBackbone):
num_blocks=[4, 6, 16, 1],
width_factor=[1, 1, 1, 2.5],
group_layer_map=None,
se_cfg=None),
se_cfg=None,
stem_channels=64),
'B1':
dict(
num_blocks=[4, 6, 16, 1],
Expand Down Expand Up @@ -383,7 +437,14 @@ class RepVGG(BaseBackbone):
num_blocks=[8, 14, 24, 1],
width_factor=[2.5, 2.5, 2.5, 5],
group_layer_map=None,
se_cfg=dict(ratio=16, divisor=1))
se_cfg=dict(ratio=16, divisor=1)),
'yolox-pai-small':
dict(
num_blocks=[3, 5, 7, 3],
width_factor=[1, 1, 1, 1],
group_layer_map=None,
se_cfg=None,
stem_channels=32),
}

def __init__(self,
Expand All @@ -400,6 +461,7 @@ def __init__(self,
with_cp=False,
deploy=False,
norm_eval=False,
add_ppf=False,
init_cfg=[
dict(type='Kaiming', layer=['Conv2d']),
dict(
Expand Down Expand Up @@ -427,9 +489,9 @@ def __init__(self,
if arch['se_cfg'] is not None:
assert isinstance(arch['se_cfg'], dict)

self.base_channels = base_channels
self.arch = arch
self.in_channels = in_channels
self.base_channels = base_channels
self.out_indices = out_indices
self.strides = strides
self.dilations = dilations
Expand All @@ -441,7 +503,12 @@ def __init__(self,
self.with_cp = with_cp
self.norm_eval = norm_eval

channels = min(64, int(base_channels * self.arch['width_factor'][0]))
# defaults to 64 to prevert BC-breaking if stem_channels
# not in arch dict;
# the stem channels should not be larger than that of stage1.
channels = min(
arch.get('stem_channels', 64),
int(self.base_channels * self.arch['width_factor'][0]))
self.stem = RepVGGBlock(
self.in_channels,
channels,
Expand All @@ -459,7 +526,7 @@ def __init__(self,
num_blocks = self.arch['num_blocks'][i]
stride = self.strides[i]
dilation = self.dilations[i]
out_channels = int(base_channels * 2**i *
out_channels = int(self.base_channels * 2**i *
self.arch['width_factor'][i])

stage, next_create_block_idx = self._make_stage(
Expand All @@ -471,6 +538,16 @@ def __init__(self,

channels = out_channels

if add_ppf:
self.ppf = MTSPPF(
out_channels,
out_channels,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
kernel_size=5)
else:
self.ppf = nn.Identity()

def _make_stage(self, in_channels, out_channels, num_blocks, stride,
dilation, next_create_block_idx, init_cfg):
strides = [stride] + [1] * (num_blocks - 1)
Expand Down Expand Up @@ -507,6 +584,8 @@ def forward(self, x):
for i, stage_name in enumerate(self.stages):
stage = getattr(self, stage_name)
x = stage(x)
if i + 1 == len(self.stages):
x = self.ppf(x)
if i in self.out_indices:
outs.append(x)

Expand Down
78 changes: 67 additions & 11 deletions tests/test_models/test_backbones/test_repvgg.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,18 +202,36 @@ def test_repvgg_backbone():
# Test RepVGG forward with layer 3 forward
model = RepVGG('A0', out_indices=(3, ))
model.init_weights()
model.train()
model.eval()

for m in model.modules():
if is_norm(m):
assert isinstance(m, _BatchNorm)

imgs = torch.randn(1, 3, 224, 224)
imgs = torch.randn(1, 3, 32, 32)
feat = model(imgs)
assert isinstance(feat, tuple)
assert len(feat) == 1
assert isinstance(feat[0], torch.Tensor)
assert feat[0].shape == torch.Size((1, 1280, 1, 1))

# Test with custom arch
cfg = dict(
num_blocks=[3, 5, 7, 3],
width_factor=[1, 1, 1, 1],
group_layer_map=None,
se_cfg=None,
stem_channels=16)
model = RepVGG(arch=cfg, out_indices=(3, ))
model.eval()
assert model.stem.out_channels == min(16, 64 * 1)

imgs = torch.randn(1, 3, 32, 32)
feat = model(imgs)
assert isinstance(feat, tuple)
assert len(feat) == 1
assert isinstance(feat[0], torch.Tensor)
assert feat[0].shape == torch.Size((1, 1280, 7, 7))
assert feat[0].shape == torch.Size((1, 512, 1, 1))

# Test RepVGG forward
model_test_settings = [
Expand All @@ -233,31 +251,31 @@ def test_repvgg_backbone():
dict(model_name='D2se', out_sizes=(160, 320, 640, 2560))
]

choose_models = ['A0', 'B1', 'B1g2', 'D2se']
choose_models = ['A0', 'B1', 'B1g2']
# Test RepVGG model forward
for model_test_setting in model_test_settings:
if model_test_setting['model_name'] not in choose_models:
continue
model = RepVGG(
model_test_setting['model_name'], out_indices=(0, 1, 2, 3))
model.init_weights()
model.eval()

# Test Norm
for m in model.modules():
if is_norm(m):
assert isinstance(m, _BatchNorm)

model.train()
imgs = torch.randn(1, 3, 224, 224)
imgs = torch.randn(1, 3, 32, 32)
feat = model(imgs)
assert feat[0].shape == torch.Size(
(1, model_test_setting['out_sizes'][0], 56, 56))
(1, model_test_setting['out_sizes'][0], 8, 8))
assert feat[1].shape == torch.Size(
(1, model_test_setting['out_sizes'][1], 28, 28))
(1, model_test_setting['out_sizes'][1], 4, 4))
assert feat[2].shape == torch.Size(
(1, model_test_setting['out_sizes'][2], 14, 14))
(1, model_test_setting['out_sizes'][2], 2, 2))
assert feat[3].shape == torch.Size(
(1, model_test_setting['out_sizes'][3], 7, 7))
(1, model_test_setting['out_sizes'][3], 1, 1))

# Test eval of "train" mode and "deploy" mode
gap = nn.AdaptiveAvgPool2d(output_size=(1))
Expand All @@ -275,11 +293,49 @@ def test_repvgg_backbone():
torch.allclose(feat[i], feat_deploy[i])
torch.allclose(pred, pred_deploy)

# Test RepVGG forward with add_ppf
model = RepVGG('A0', out_indices=(3, ), add_ppf=True)
model.init_weights()
model.train()

for m in model.modules():
if is_norm(m):
assert isinstance(m, _BatchNorm)

imgs = torch.randn(1, 3, 64, 64)
feat = model(imgs)
assert isinstance(feat, tuple)
assert len(feat) == 1
assert isinstance(feat[0], torch.Tensor)
assert feat[0].shape == torch.Size((1, 1280, 2, 2))

# Test RepVGG forward with 'stem_channels' not in arch
arch = dict(
num_blocks=[2, 4, 14, 1],
width_factor=[0.75, 0.75, 0.75, 2.5],
group_layer_map=None,
se_cfg=None)
model = RepVGG(arch, add_ppf=True)
model.stem.in_channels = min(64, 64 * 0.75)
model.init_weights()
model.train()

for m in model.modules():
if is_norm(m):
assert isinstance(m, _BatchNorm)

imgs = torch.randn(1, 3, 64, 64)
feat = model(imgs)
assert isinstance(feat, tuple)
assert len(feat) == 1
assert isinstance(feat[0], torch.Tensor)
assert feat[0].shape == torch.Size((1, 1280, 2, 2))


def test_repvgg_load():
# Test output before and load from deploy checkpoint
model = RepVGG('A1', out_indices=(0, 1, 2, 3))
inputs = torch.randn((1, 3, 224, 224))
inputs = torch.randn((1, 3, 32, 32))
ckpt_path = os.path.join(tempfile.gettempdir(), 'ckpt.pth')
model.switch_to_deploy()
model.eval()
Expand Down