diff --git a/mmcls/models/backbones/convnext.py b/mmcls/models/backbones/convnext.py
index 1e0a3e9c5ea..cb9d3a66406 100644
--- a/mmcls/models/backbones/convnext.py
+++ b/mmcls/models/backbones/convnext.py
@@ -6,6 +6,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch.utils.checkpoint as cp
 from mmcv.cnn.bricks import (NORM_LAYERS, DropPath, build_activation_layer,
                              build_norm_layer)
 from mmcv.runner import BaseModule
@@ -77,8 +78,11 @@ def __init__(self,
                  mlp_ratio=4.,
                  linear_pw_conv=True,
                  drop_path_rate=0.,
-                 layer_scale_init_value=1e-6):
+                 layer_scale_init_value=1e-6,
+                 with_cp=False):
         super().__init__()
+        self.with_cp = with_cp
+
         self.depthwise_conv = nn.Conv2d(
             in_channels,
             in_channels,
@@ -108,24 +112,33 @@ def __init__(self,
             drop_path_rate) if drop_path_rate > 0. else nn.Identity()
 
     def forward(self, x):
-        shortcut = x
-        x = self.depthwise_conv(x)
-        x = self.norm(x)
 
-        if self.linear_pw_conv:
-            x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
+        def _inner_forward(x):
+            shortcut = x
+            x = self.depthwise_conv(x)
+            x = self.norm(x)
 
-        x = self.pointwise_conv1(x)
-        x = self.act(x)
-        x = self.pointwise_conv2(x)
+            if self.linear_pw_conv:
+                x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
 
-        if self.linear_pw_conv:
-            x = x.permute(0, 3, 1, 2)  # permute back
+            x = self.pointwise_conv1(x)
+            x = self.act(x)
+            x = self.pointwise_conv2(x)
 
-        if self.gamma is not None:
-            x = x.mul(self.gamma.view(1, -1, 1, 1))
+            if self.linear_pw_conv:
+                x = x.permute(0, 3, 1, 2)  # permute back
+
+            if self.gamma is not None:
+                x = x.mul(self.gamma.view(1, -1, 1, 1))
+
+            x = shortcut + self.drop_path(x)
+            return x
+
+        if self.with_cp and x.requires_grad:
+            x = cp.checkpoint(_inner_forward, x)
+        else:
+            x = _inner_forward(x)
 
-        x = shortcut + self.drop_path(x)
         return x
 
 
@@ -169,6 +182,8 @@ class ConvNeXt(BaseBackbone):
         gap_before_final_norm (bool): Whether to globally average the feature
             map before the final norm layer. In the official repo, it's only
            used in classification task. Defaults to True.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Defaults to False.
         init_cfg (dict, optional): Initialization config dict
     """  # noqa: E501
     arch_settings = {
@@ -206,6 +221,7 @@ def __init__(self,
                  out_indices=-1,
                  frozen_stages=0,
                  gap_before_final_norm=True,
+                 with_cp=False,
                  init_cfg=None):
         super().__init__(init_cfg=init_cfg)
 
@@ -288,8 +304,8 @@ def __init__(self,
                     norm_cfg=norm_cfg,
                     act_cfg=act_cfg,
                     linear_pw_conv=linear_pw_conv,
-                    layer_scale_init_value=layer_scale_init_value)
-                for j in range(depth)
+                    layer_scale_init_value=layer_scale_init_value,
+                    with_cp=with_cp) for j in range(depth)
             ])
             block_idx += depth
 
diff --git a/tests/test_models/test_backbones/test_convnext.py b/tests/test_models/test_backbones/test_convnext.py
index 35448b458b1..ccd002d1e2b 100644
--- a/tests/test_models/test_backbones/test_convnext.py
+++ b/tests/test_models/test_backbones/test_convnext.py
@@ -84,3 +84,13 @@ def test_convnext():
     for i in range(2, 4):
         assert model.downsample_layers[i].training
         assert model.stages[i].training
+
+    # Test Activation Checkpointing
+    model = ConvNeXt(arch='tiny', out_indices=-1, with_cp=True)
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert len(feat) == 1
+    assert feat[0].shape == torch.Size([1, 768])
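
Usage sketch (not part of the diff): the snippet below shows how the new with_cp flag might be exercised end to end, assuming an mmcls installation that includes this patch. It mirrors the added test rather than prescribing an official API; the batch size and backward call are illustrative only.

# Sketch: enable activation checkpointing on ConvNeXt (assumes this patch is installed).
# with_cp=True trades extra recomputation in the backward pass for lower activation memory.
import torch

from mmcls.models.backbones import ConvNeXt

model = ConvNeXt(arch='tiny', out_indices=-1, with_cp=True)
model.init_weights()
model.train()

imgs = torch.randn(2, 3, 224, 224)
feat = model(imgs)          # tuple with one (2, 768) tensor for the 'tiny' arch
feat[0].sum().backward()    # block activations are recomputed inside each checkpointed block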