Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support Activation Checkpointing for ConvNeXt #1152

Merged
merged 4 commits into from
Nov 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 32 additions & 16 deletions mmcls/models/backbones/convnext.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from mmcv.cnn.bricks import (NORM_LAYERS, DropPath, build_activation_layer,
build_norm_layer)
from mmcv.runner import BaseModule
Expand Down Expand Up @@ -77,8 +78,11 @@ def __init__(self,
mlp_ratio=4.,
linear_pw_conv=True,
drop_path_rate=0.,
layer_scale_init_value=1e-6):
layer_scale_init_value=1e-6,
with_cp=False):
super().__init__()
self.with_cp = with_cp

self.depthwise_conv = nn.Conv2d(
in_channels,
in_channels,
Expand Down Expand Up @@ -108,24 +112,33 @@ def __init__(self,
drop_path_rate) if drop_path_rate > 0. else nn.Identity()

def forward(self, x):
shortcut = x
x = self.depthwise_conv(x)
x = self.norm(x)

if self.linear_pw_conv:
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
def _inner_forward(x):
shortcut = x
x = self.depthwise_conv(x)
x = self.norm(x)

x = self.pointwise_conv1(x)
x = self.act(x)
x = self.pointwise_conv2(x)
if self.linear_pw_conv:
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)

if self.linear_pw_conv:
x = x.permute(0, 3, 1, 2) # permute back
x = self.pointwise_conv1(x)
x = self.act(x)
x = self.pointwise_conv2(x)

if self.gamma is not None:
x = x.mul(self.gamma.view(1, -1, 1, 1))
if self.linear_pw_conv:
x = x.permute(0, 3, 1, 2) # permute back

if self.gamma is not None:
x = x.mul(self.gamma.view(1, -1, 1, 1))

x = shortcut + self.drop_path(x)
return x

if self.with_cp and x.requires_grad:
x = cp.checkpoint(_inner_forward, x)
else:
x = _inner_forward(x)

x = shortcut + self.drop_path(x)
return x


Expand Down Expand Up @@ -169,6 +182,8 @@ class ConvNeXt(BaseBackbone):
gap_before_final_norm (bool): Whether to globally average the feature
map before the final norm layer. In the official repo, it's only
used in classification task. Defaults to True.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Defaults to False.
init_cfg (dict, optional): Initialization config dict
""" # noqa: E501
arch_settings = {
Expand Down Expand Up @@ -206,6 +221,7 @@ def __init__(self,
out_indices=-1,
frozen_stages=0,
gap_before_final_norm=True,
with_cp=False,
init_cfg=None):
super().__init__(init_cfg=init_cfg)

Expand Down Expand Up @@ -288,8 +304,8 @@ def __init__(self,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
linear_pw_conv=linear_pw_conv,
layer_scale_init_value=layer_scale_init_value)
for j in range(depth)
layer_scale_init_value=layer_scale_init_value,
with_cp=with_cp) for j in range(depth)
])
block_idx += depth

Expand Down
10 changes: 10 additions & 0 deletions tests/test_models/test_backbones/test_convnext.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,13 @@ def test_convnext():
for i in range(2, 4):
assert model.downsample_layers[i].training
assert model.stages[i].training

# Test Activation Checkpointing
model = ConvNeXt(arch='tiny', out_indices=-1, with_cp=True)
model.init_weights()
model.train()

imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 1
assert feat[0].shape == torch.Size([1, 768])