[Feature] Add reduction for neck (open-mmlab#978)
* feat: add reduction for neck

* feat: add reduction for neck

* feat: add reduction for neck

* feat: add linear reduction neck

* feat: add reduction neck

* mod: return the output of LinearReduction as a tuple

* fix typo

* fix unit tests

* fix unit tests

Co-authored-by: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com>
zzc98 and Ezra-Yu authored Nov 4, 2022
1 parent 2f05fd9 commit 28986fb
Showing 3 changed files with 126 additions and 2 deletions.
6 changes: 5 additions & 1 deletion mmcls/models/necks/__init__.py
@@ -2,5 +2,9 @@
from .gap import GlobalAveragePooling
from .gem import GeneralizedMeanPooling
from .hr_fuse import HRFuseScales
from .reduction import LinearReduction

__all__ = ['GlobalAveragePooling', 'GeneralizedMeanPooling', 'HRFuseScales']
__all__ = [
'GlobalAveragePooling', 'GeneralizedMeanPooling', 'HRFuseScales',
'LinearReduction'
]
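
Since LinearReduction is exported here and registered via @MODELS.register_module(), it can also be built from a config dict. A minimal sketch, not part of the diff, assuming a checkout of mmcls that includes this commit:

from mmcls.registry import MODELS

# Build the neck from a plain config dict through the registry.
neck = MODELS.build(
    dict(type='LinearReduction', in_channels=10, out_channels=5))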
71 changes: 71 additions & 0 deletions mmcls/models/necks/reduction.py
@@ -0,0 +1,71 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from mmcv.cnn import build_activation_layer, build_norm_layer
from mmengine.model import BaseModule

from mmcls.registry import MODELS


@MODELS.register_module()
class LinearReduction(BaseModule):
"""Neck with Dimension reduction.
Args:
in_channels (int): Number of channels in the input.
out_channels (int): Number of channels in the output.
norm_cfg (dict, optional): dictionary to construct and
config norm layer. Defaults to dict(type='BN1d').
act_cfg (dict, optional): dictionary to construct and
config activate layer. Defaults to None.
init_cfg (dict, optional): dictionary to initialize weights.
Defaults to None.
"""

def __init__(self,
in_channels: int,
out_channels: int,
norm_cfg: Optional[dict] = dict(type='BN1d'),
act_cfg: Optional[dict] = None,
init_cfg: Optional[dict] = None):
super(LinearReduction, self).__init__(init_cfg=init_cfg)

self.in_channels = in_channels
self.out_channels = out_channels
self.norm_cfg = copy.deepcopy(norm_cfg)
self.act_cfg = copy.deepcopy(act_cfg)

self.reduction = nn.Linear(
in_features=in_channels, out_features=out_channels)
if norm_cfg:
self.norm = build_norm_layer(norm_cfg, out_channels)[1]
else:
self.norm = nn.Identity()
if act_cfg:
self.act = build_activation_layer(act_cfg)
else:
self.act = nn.Identity()

    def forward(self,
                inputs: Union[Tuple, torch.Tensor]) -> Tuple[torch.Tensor]:
        """Forward function.

        Args:
            inputs (Union[Tuple, torch.Tensor]): The features extracted from
                the backbone. Multiple stage inputs are acceptable but only
                the last stage will be used.

        Returns:
            Tuple[torch.Tensor]: A tuple of reduced features.
        """
        assert isinstance(inputs, (tuple, torch.Tensor)), (
            'The inputs of `LinearReduction` neck must be tuple or '
            f'`torch.Tensor`, but got {type(inputs)}.')
if isinstance(inputs, tuple):
inputs = inputs[-1]

out = self.act(self.norm(self.reduction(inputs)))
return (out, )
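
For reference, a minimal usage sketch of the new neck (not part of the diff; assumes torch and this version of mmcls are installed). eval() is used so BN1d accepts a single-sample batch:

import torch

from mmcls.models.necks import LinearReduction

neck = LinearReduction(
    in_channels=10,
    out_channels=5,
    norm_cfg=dict(type='BN1d'),
    act_cfg=dict(type='ReLU'))
neck.eval()

# Multi-stage features from a backbone; only the last stage is reduced.
feats = (torch.rand(1, 20), torch.rand(1, 10))
out = neck(feats)
print(out[0].shape)  # torch.Size([1, 5])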
51 changes: 50 additions & 1 deletion tests/test_models/test_necks.py
@@ -3,7 +3,7 @@
import torch

from mmcls.models.necks import (GeneralizedMeanPooling, GlobalAveragePooling,
HRFuseScales)
HRFuseScales, LinearReduction)


def test_gap_neck():
@@ -85,3 +85,52 @@ def test_hr_fuse_scales():
assert isinstance(outs, tuple)
assert len(outs) == 1
assert outs[0].shape == (3, 1024, 7, 7)


def test_linear_reduction():
# test linear_reduction without `act_cfg` and `norm_cfg`
neck = LinearReduction(10, 5, None, None)
neck.eval()
assert isinstance(neck.act, torch.nn.Identity)
assert isinstance(neck.norm, torch.nn.Identity)

    # (batch_size, in_channels)
    fake_input = torch.rand(1, 10)
    output = neck(fake_input)
    # (batch_size, out_channels)
    assert output[-1].shape == (1, 5)

    # multi-stage inputs: only the last stage is used
    fake_input = (torch.rand(1, 20), torch.rand(1, 10))
    output = neck(fake_input)
    # (batch_size, out_channels)
    assert output[-1].shape == (1, 5)

# test linear_reduction with `init_cfg`
neck = LinearReduction(
10, 5, init_cfg=dict(type='Xavier', layer=['Linear']))

# test linear_reduction with `act_cfg` and `norm_cfg`
neck = LinearReduction(
10, 5, act_cfg=dict(type='ReLU'), norm_cfg=dict(type='BN1d'))
neck.eval()

assert isinstance(neck.act, torch.nn.ReLU)
assert isinstance(neck.norm, torch.nn.BatchNorm1d)

    # (batch_size, in_channels)
    fake_input = torch.rand(1, 10)
    output = neck(fake_input)
    # (batch_size, out_channels)
    assert output[-1].shape == (1, 5)

    # multi-stage inputs: only the last stage is used
    fake_input = (torch.rand(1, 20), torch.rand(1, 10))
    output = neck(fake_input)
    # (batch_size, out_channels)
    assert output[-1].shape == (1, 5)

with pytest.raises(AssertionError):
neck([])
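
One way to exercise just the new test from the repository root (a hypothetical invocation, not shown in the commit):

import pytest

# Run only the LinearReduction test added by this commit.
pytest.main(['tests/test_models/test_necks.py::test_linear_reduction', '-v'])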
