[Feature] Support gem pooling (#677)
* add gem pooling

* add example config

* fix params

* add assert

* add param clamp

* add test assert

* add clamp

* fix conflict
okotaku authored Feb 16, 2022
1 parent fcd5791 commit 43024cd
Showing 4 changed files with 100 additions and 2 deletions.
17 changes: 17 additions & 0 deletions configs/_base_/models/resnet34_gem.py
@@ -0,0 +1,17 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='ResNet',
        depth=34,
        num_stages=4,
        out_indices=(3, ),
        style='pytorch'),
    neck=dict(type='GeneralizedMeanPooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=512,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    ))
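For reference, a config like this could be sanity-checked by building the classifier from it. The following is only a sketch, assuming the standard mmcls 0.x build API (mmcv's Config loader and mmcls.models.build_classifier); it is not part of this commit.

# Sketch only (not part of this commit); assumes the mmcls 0.x build API.
from mmcv import Config
from mmcls.models import build_classifier

cfg = Config.fromfile('configs/_base_/models/resnet34_gem.py')
model = build_classifier(cfg.model)
print(type(model.neck).__name__)   # GeneralizedMeanPooling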
3 changes: 2 additions & 1 deletion mmcls/models/necks/__init__.py
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .gap import GlobalAveragePooling
from .gem import GeneralizedMeanPooling
from .hr_fuse import HRFuseScales

__all__ = ['GlobalAveragePooling', 'HRFuseScales']
__all__ = ['GlobalAveragePooling', 'GeneralizedMeanPooling', 'HRFuseScales']
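With the class exported here, the new neck can be used as a drop-in replacement for GlobalAveragePooling on backbone feature maps. A minimal usage sketch (not part of this commit):

# Minimal usage sketch, not part of this commit.
import torch
from mmcls.models.necks import GeneralizedMeanPooling

neck = GeneralizedMeanPooling(p=3., eps=1e-6, clamp=True)
feats = torch.rand(2, 512, 7, 7)   # (batch, channels, H, W), e.g. ResNet-34 stage-4 output
pooled = neck(feats)
print(pooled.shape)                # torch.Size([2, 512])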
53 changes: 53 additions & 0 deletions mmcls/models/necks/gem.py
@@ -0,0 +1,53 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch import Tensor, nn
from torch.nn import functional as F
from torch.nn.parameter import Parameter

from ..builder import NECKS


def gem(x: Tensor, p: Parameter, eps: float = 1e-6,
        clamp: bool = True) -> Tensor:
    if clamp:
        x = x.clamp(min=eps)
    return F.avg_pool2d(x.pow(p), (x.size(-2), x.size(-1))).pow(1. / p)


@NECKS.register_module()
class GeneralizedMeanPooling(nn.Module):
"""Generalized Mean Pooling neck.
Note that we use `view` to remove extra channel after pooling. We do not
use `squeeze` as it will also remove the batch dimension when the tensor
has a batch dimension of size 1, which can lead to unexpected errors.
Args:
p (float): Parameter value.
Default: 3.
eps (float): epsilon.
Default: 1e-6
clamp (bool): Use clamp before pooling.
Default: True
"""

def __init__(self, p=3., eps=1e-6, clamp=True):
assert p >= 1, "'p' must be a value greater then 1"
super(GeneralizedMeanPooling, self).__init__()
self.p = Parameter(torch.ones(1) * p)
self.eps = eps
self.clamp = clamp

    def forward(self, inputs):
        if isinstance(inputs, tuple):
            outs = tuple([
                gem(x, p=self.p, eps=self.eps, clamp=self.clamp)
                for x in inputs
            ])
            outs = tuple(
                [out.view(x.size(0), -1) for out, x in zip(outs, inputs)])
        elif isinstance(inputs, torch.Tensor):
            outs = gem(inputs, p=self.p, eps=self.eps, clamp=self.clamp)
            outs = outs.view(inputs.size(0), -1)
        else:
            raise TypeError('neck inputs should be tuple or torch.tensor')
        return outs
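Because the class is registered in NECKS, it can also be built from a config dict, and the parameter p interpolates between global average pooling (p = 1) and global max pooling (p approaching infinity). The sketch below illustrates both points; it assumes mmcls.models.build_neck (the standard mmcls 0.x registry builder) and is not part of this commit.

# Illustration only, not part of this commit; assumes mmcls.models.build_neck.
import torch
import torch.nn.functional as F
from mmcls.models import build_neck

x = torch.rand(2, 16, 7, 7)

# Built through the NECKS registry, as it would be from a model config.
neck = build_neck(dict(type='GeneralizedMeanPooling', p=1.))
avg = F.adaptive_avg_pool2d(x, 1).view(2, -1)
assert torch.allclose(neck(x), avg, atol=1e-5)   # p = 1 reduces to average pooling

neck = build_neck(dict(type='GeneralizedMeanPooling', p=64.))
mx = F.adaptive_max_pool2d(x, 1).view(2, -1)
print((neck(x) - mx).abs().max())                # small: large p approaches max pooling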
29 changes: 28 additions & 1 deletion tests/test_models/test_neck.py
@@ -2,7 +2,8 @@
import pytest
import torch

from mmcls.models.necks import GlobalAveragePooling, HRFuseScales
from mmcls.models.necks import (GeneralizedMeanPooling, GlobalAveragePooling,
                                HRFuseScales)


def test_gap_neck():
@@ -39,6 +40,32 @@ def test_gap_neck():
GlobalAveragePooling(dim='other')


def test_gem_neck():

    # test gem_neck
    neck = GeneralizedMeanPooling()
    # batch_size, num_features, feature_size(2)
    fake_input = torch.rand(1, 16, 24, 24)

    output = neck(fake_input)
    # batch_size, num_features
    assert output.shape == (1, 16)

    # test tuple input gem_neck
    neck = GeneralizedMeanPooling()
    # batch_size, num_features, feature_size(2)
    fake_input = (torch.rand(1, 8, 24, 24), torch.rand(1, 16, 24, 24))

    output = neck(fake_input)
    # batch_size, num_features
    assert output[0].shape == (1, 8)
    assert output[1].shape == (1, 16)

    with pytest.raises(AssertionError):
        # p must be a value not less than 1
        GeneralizedMeanPooling(p=0.5)


def test_hr_fuse_scales():

in_channels = (18, 32, 64, 128)
