Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support RepLKnet backbone. #1129

Merged
merged 8 commits into from
Nov 21, 2022

Conversation

techmonsterwang
Copy link
Collaborator

Thanks for your contribution and we appreciate it a lot. The following instructions would make your pull request more healthy and more easily get feedback. If you do not understand some items, don't worry, just make the pull request and seek help from maintainers.

Motivation

Add RepLKNet inference and test files.

Modification

Add RepLKNet inference and test files.

Checklist

After PR:

  • Add RepLKNet inference and test files..

@CLAassistant
Copy link

CLAassistant commented Oct 25, 2022

CLA assistant check
All committers have signed the CLA.

@Ezra-Yu Ezra-Yu changed the base branch from master to dev-1.x October 25, 2022 06:54
@mzr1996 mzr1996 changed the title Replknet [Feature] Support RepLKnet backbone. Nov 14, 2022
tools/test.py Outdated
@@ -138,6 +138,7 @@ def main():
# load config
cfg = Config.fromfile(args.config)
cfg = merge_args(cfg, args)
print(cfg)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove the unrelated modification.

Comment on lines 419 to 447
# def create_RepLKNet31B(drop_path_rate=0.3, num_classes=1000, with_cp=True, small_kernel_merged=False):
# return RepLKNet(arch='31B')

# def create_RepLKNet31L(drop_path_rate=0.3, num_classes=1000, with_cp=True, small_kernel_merged=False):
# return RepLKNet(large_kernel_sizes=[31,29,27,13], layers=[2,2,18,2], channels=[192,384,768,1536],
# drop_path_rate=drop_path_rate, small_kernel=5, num_classes=num_classes, with_cp=with_cp,
# small_kernel_merged=small_kernel_merged)

# def create_RepLKNetXL(drop_path_rate=0.3, num_classes=1000, with_cp=True, small_kernel_merged=False):
# return RepLKNet(large_kernel_sizes=[27,27,27,13], layers=[2,2,18,2], channels=[256,512,1024,2048],
# drop_path_rate=drop_path_rate, small_kernel=None, dw_ratio=1.5,
# num_classes=num_classes, with_cp=with_cp,
# small_kernel_merged=small_kernel_merged)

# if __name__ == '__main__':
# model = create_RepLKNet31B(small_kernel_merged=False)
# model.eval()
# print('------------------- training-time model -------------')
# for i in model.state_dict().keys():
# print(i)
# print(model)
# x = torch.randn(2, 3, 224, 224)
# origin_y = model(x)
# model.switch_to_deploy()
# print('------------------- after re-param -------------')
# print(model)
# reparam_y = model(x)
# print('------------------- the difference is ------------------------')
# print((origin_y - reparam_y).abs().sum())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove useless comments.

Comment on lines 16 to 33
def get_conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias):
if type(kernel_size) is int:
use_large_impl = kernel_size > 5
else:
assert len(kernel_size) == 2 and kernel_size[0] == kernel_size[1]
use_large_impl = kernel_size[0] > 5
has_large_impl = 'LARGE_KERNEL_CONV_IMPL' in os.environ
if has_large_impl and in_channels == out_channels and out_channels == groups and use_large_impl and stride == 1 and padding == kernel_size // 2 and dilation == 1:
sys.path.append(os.environ['LARGE_KERNEL_CONV_IMPL'])
# Please follow the instructions https://github.com/DingXiaoH/RepLKNet-pytorch/blob/main/README.md
# export LARGE_KERNEL_CONV_IMPL=absolute_path_to_where_you_cloned_the_example (i.e., depthwise_conv2d_implicit_gemm.py)
# TODO more efficient PyTorch implementations of large-kernel convolutions. Pull requests are welcomed.
# Or you may try MegEngine. We have integrated an efficient implementation into MegEngine and it will automatically use it.
from depthwise_conv2d_implicit_gemm import DepthWiseConv2dImplicitGEMM
return DepthWiseConv2dImplicitGEMM(in_channels, kernel_size, bias=bias)
else:
return nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, groups=groups, bias=bias)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we don't have the DepthWiseConv2dImplicitGEMM in MMCV, please remove this function.
You can also add this operator to mmcv package.

Comment on lines 35 to 45
use_sync_bn = False

def enable_sync_bn():
global use_sync_bn
use_sync_bn = True

def get_bn(channels):
if use_sync_bn:
return nn.SyncBatchNorm(channels)
else:
return nn.BatchNorm2d(channels)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have our sync_bn configuration, please use norm_cfg instead of this kind of code.

self.small_kernel_merged = True


class ConvFFN(BaseModule):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where is FFN?


@MODELS.register_module()
class RepLKNet(BaseBackbone):

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing docstring

@codecov
Copy link

codecov bot commented Nov 18, 2022

Codecov Report

Base: 0.02% // Head: 89.16% // Increases project coverage by +89.14% 🎉

Coverage data is based on head (d17b6f6) compared to base (b8b31e9).
Patch has no changes to coverable lines.

❗ Current head d17b6f6 differs from pull request most recent head 951b774. Consider uploading reports for the commit 951b774 to get more accurate results

Additional details and impacted files
@@             Coverage Diff              @@
##           dev-1.x    #1129       +/-   ##
============================================
+ Coverage     0.02%   89.16%   +89.14%     
============================================
  Files          121      144       +23     
  Lines         8217    11086     +2869     
  Branches      1368     1764      +396     
============================================
+ Hits             2     9885     +9883     
+ Misses        8215      952     -7263     
- Partials         0      249      +249     
Flag Coverage Δ
unittests 89.16% <ø> (+89.14%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

Impacted Files Coverage Δ
mmcls/apis/inference.py 0.00% <0.00%> (ø)
mmcls/datasets/transforms/compose.py
mmcls/models/classifiers/timm.py 25.67% <0.00%> (ø)
mmcls/structures/utils.py 77.77% <0.00%> (ø)
mmcls/models/backbones/mobilevit.py 91.15% <0.00%> (ø)
mmcls/models/utils/layer_scale.py 86.66% <0.00%> (ø)
mmcls/models/retrievers/base.py 100.00% <0.00%> (ø)
mmcls/utils/progress.py 66.66% <0.00%> (ø)
mmcls/engine/hooks/retriever_hooks.py 72.72% <0.00%> (ø)
mmcls/engine/hooks/switch_recipe_hook.py 88.46% <0.00%> (ø)
... and 134 more

Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here.

☔ View full report at Codecov.
📢 Do you have feedback about the report comment? Let us know in this issue.

@mzr1996 mzr1996 merged commit 72c6bc4 into open-mmlab:dev-1.x Nov 21, 2022
mzr1996 added a commit to mzr1996/mmpretrain that referenced this pull request Nov 24, 2022
* update replknet configs

* update replknet test

* update replknet model

* update replknet model

* update replknet model

* update replknet model

* Fix docs and config names

Co-authored-by: mzr1996 <mzr1996@163.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants