[Feature] Add swin-transformer model. (open-mmlab#271)
* Add Swin Transformer archs S, B and L.

* Add SwinTransformer configs

* Add training config files for Swin.

* Align init method with original code

* Use `nn.Unfold` to merge patches (see the sketch after this list)

* Change all ConfigDict to dict

* Add init_cfg for all subclasses of BaseModule.

* Use the mmcv version of the init function

* Add Swin README

* Use safer cfg copy method

* Improve docstrings and variable names.

* Fix some differences in RandAugment

Fix the BGR bug and align the scheduler config.

Fix the label smoothing parameter difference.

* Fix missing drop path in attention

* Fix a bug in the relative position table when the window width is not equal to its height.

* Make `PatchMerging` more general, supporting kernel, stride, padding and dilation.

* Rename `residual` to `identity` in attention and FFN.

* Add `auto_pad` option to automatically pad the feature map

* Improve docstring.

* Fix bug in ShiftWMSA padding.

* Remove unused `key` and `value` in ShiftWMSA

* Move `PatchMerging` into utils and use common `PatchEmbed`.

* Use the latest `LinearClsHead`, training augments and label smoothing settings, and remove the original `SwinLinearClsHead`.

* Mark some configs as "Evaluation Only".

* Remove useless comment in config

* 1. Move `ShiftWindowMSA` and `WindowMSA` to `utils/attention.py`
2. Add docstrings for each module.
3. Fix some variable names.
4. Other small improvements.

* Add unit tests for swin-transformer and `PatchMerging`.

* Fix some bugs in unit tests.

* Fix a bug in `rel_position_index` when the window is not square.

* Make WindowMSA implicit, and add unit tests.

* Add metafile.yml, update readme and model_zoo.
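
On the `nn.Unfold` point above: unfolding turns the 2x2 patch merge into a single strided gather plus a linear projection. A minimal sketch of the idea — names and shapes are illustrative, not the exact `PatchMerging` added in this commit:

```python
import torch
import torch.nn as nn

class PatchMergingSketch(nn.Module):
    """Merge each 2x2 neighborhood of tokens and double the channel count."""

    def __init__(self, in_channels):
        super().__init__()
        # Gathers each 2x2 window of C-dim tokens into a single 4*C column.
        self.sampler = nn.Unfold(kernel_size=2, stride=2)
        self.norm = nn.LayerNorm(4 * in_channels)
        self.reduction = nn.Linear(4 * in_channels, 2 * in_channels, bias=False)

    def forward(self, x, hw_shape):
        B, L, C = x.shape
        H, W = hw_shape
        x = x.transpose(1, 2).reshape(B, C, H, W)
        x = self.sampler(x)          # (B, 4*C, H//2 * W//2)
        x = x.transpose(1, 2)        # (B, L//4, 4*C)
        return self.reduction(self.norm(x)), (H // 2, W // 2)

tokens = torch.randn(2, 56 * 56, 96)  # e.g. Swin-T stage-1 tokens
out, shape = PatchMergingSketch(96)(tokens, (56, 56))
print(out.shape, shape)               # torch.Size([2, 784, 192]) (28, 28)
```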
mzr1996 authored Jul 1, 2021
1 parent a2d604b commit 5f521f6
Showing 28 changed files with 1,569 additions and 4 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -49,6 +49,7 @@ Supported backbones:
- [x] ShuffleNetV2
- [x] MobileNetV2
- [x] MobileNetV3
- [x] Swin-Transformer

## Installation

1 change: 1 addition & 0 deletions README_zh-CN.md
@@ -49,6 +49,7 @@ MMClassification is an open-source image classification toolbox based on PyTorch, and is part of the [O
- [x] ShuffleNetV2
- [x] MobileNetV2
- [x] MobileNetV3
- [x] Swin-Transformer

## Installation

122 changes: 122 additions & 0 deletions configs/_base_/datasets/imagenet_bs128_swin_224.py
@@ -0,0 +1,122 @@
# dataset settings
dataset_type = 'ImageNet'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

policies = [
    dict(type='AutoContrast'),
    dict(type='Equalize'),
    dict(type='Invert'),
    dict(
        type='Rotate',
        interpolation='bicubic',
        magnitude_key='angle',
        pad_val=tuple([round(x) for x in img_norm_cfg['mean'][::-1]]),
        magnitude_range=(0, 30)),
    dict(type='Posterize', magnitude_key='bits', magnitude_range=(4, 0)),
    dict(type='Solarize', magnitude_key='thr', magnitude_range=(256, 0)),
    dict(
        type='SolarizeAdd',
        magnitude_key='magnitude',
        magnitude_range=(0, 110)),
    dict(
        type='ColorTransform',
        magnitude_key='magnitude',
        magnitude_range=(0, 0.9)),
    dict(type='Contrast', magnitude_key='magnitude', magnitude_range=(0, 0.9)),
    dict(
        type='Brightness', magnitude_key='magnitude',
        magnitude_range=(0, 0.9)),
    dict(
        type='Sharpness', magnitude_key='magnitude', magnitude_range=(0, 0.9)),
    dict(
        type='Shear',
        interpolation='bicubic',
        magnitude_key='magnitude',
        magnitude_range=(0, 0.3),
        pad_val=tuple([round(x) for x in img_norm_cfg['mean'][::-1]]),
        direction='horizontal'),
    dict(
        type='Shear',
        interpolation='bicubic',
        magnitude_key='magnitude',
        magnitude_range=(0, 0.3),
        pad_val=tuple([round(x) for x in img_norm_cfg['mean'][::-1]]),
        direction='vertical'),
    dict(
        type='Translate',
        interpolation='bicubic',
        magnitude_key='magnitude',
        magnitude_range=(0, 0.45),
        pad_val=tuple([round(x) for x in img_norm_cfg['mean'][::-1]]),
        direction='horizontal'),
    dict(
        type='Translate',
        interpolation='bicubic',
        magnitude_key='magnitude',
        magnitude_range=(0, 0.45),
        pad_val=tuple([round(x) for x in img_norm_cfg['mean'][::-1]]),
        direction='vertical')
]

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        size=224,
        backend='pillow',
        interpolation='bicubic'),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(
        type='RandAugment',
        policies=policies,
        num_policies=2,
        total_level=10,
        magnitude_level=9,
        magnitude_std=0.5),
    dict(
        type='RandomErasing',
        erase_prob=0.25,
        mode='rand',
        min_area_ratio=0.02,
        max_area_ratio=1 / 3,
        fill_color=img_norm_cfg['mean'][::-1],
        fill_std=img_norm_cfg['std'][::-1]),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='Resize',
        size=(256, -1),
        backend='pillow',
        interpolation='bicubic'),
    dict(type='CenterCrop', crop_size=224),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
data = dict(
    samples_per_gpu=128,
    workers_per_gpu=8,
    train=dict(
        type=dataset_type,
        data_prefix='data/imagenet/train',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        data_prefix='data/imagenet/val',
        ann_file='data/imagenet/meta/val.txt',
        pipeline=test_pipeline),
    test=dict(
        # replace `data/val` with `data/test` for standard test
        type=dataset_type,
        data_prefix='data/imagenet/val',
        ann_file='data/imagenet/meta/val.txt',
        pipeline=test_pipeline))

evaluation = dict(interval=10, metric='accuracy')
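
Worth noting: the `pad_val` entries reverse `img_norm_cfg['mean']` because images are still in BGR order before `Normalize` (with `to_rgb=True`) converts them. For the RandAugment knobs, a minimal sketch of how a policy's effective magnitude is derived — assuming the mmcls convention that `magnitude_level / total_level` linearly interpolates over `magnitude_range`, with Gaussian jitter of `magnitude_std`:

```python
import random

def effective_magnitude(magnitude_range, magnitude_level=9, total_level=10,
                        magnitude_std=0.5):
    """Interpolate a policy's magnitude from its range (assumed mmcls semantics)."""
    low, high = magnitude_range
    # Jitter the level per sample, then clamp to [0, total_level].
    level = random.gauss(magnitude_level, magnitude_std)
    level = min(max(level, 0), total_level)
    return low + (high - low) * level / total_level

# e.g. Rotate with magnitude_range=(0, 30) at level 9/10 -> about 27 degrees
print(effective_magnitude((0, 30)))
```

Reversed ranges such as Posterize's `(4, 0)` work with the same linear rule, so a higher level yields a stronger (lower-bit) effect.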
43 changes: 43 additions & 0 deletions configs/_base_/datasets/imagenet_bs128_swin_384.py
@@ -0,0 +1,43 @@
# dataset settings
dataset_type = 'ImageNet'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        size=384,
        backend='pillow',
        interpolation='bicubic'),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', size=384, backend='pillow', interpolation='bicubic'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
data = dict(
    samples_per_gpu=128,
    workers_per_gpu=8,
    train=dict(
        type=dataset_type,
        data_prefix='data/imagenet/train',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        data_prefix='data/imagenet/val',
        ann_file='data/imagenet/meta/val.txt',
        pipeline=test_pipeline),
    test=dict(
        # replace `data/val` with `data/test` for standard test
        type=dataset_type,
        data_prefix='data/imagenet/val',
        ann_file='data/imagenet/meta/val.txt',
        pipeline=test_pipeline))
evaluation = dict(interval=10, metric='accuracy')
22 changes: 22 additions & 0 deletions configs/_base_/models/swin_transformer/base_224.py
@@ -0,0 +1,22 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='SwinTransformer', arch='base', img_size=224, drop_path_rate=0.5),
    neck=dict(type='GlobalAveragePooling', dim=1),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=1024,
        init_cfg=None,  # suppress the default init_cfg of LinearClsHead.
        loss=dict(
            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
        cal_acc=False),
    init_cfg=[
        dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
        dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
    ],
    train_cfg=dict(augments=[
        dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
        dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
    ]))
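
A config like this can be instantiated directly as a sanity check — a minimal sketch assuming the mmcv/mmcls 0.x APIs (`Config.fromfile`, `build_classifier`) and this commit's config path:

```python
import torch
from mmcv import Config
from mmcls.models import build_classifier

# Load the base config and build the classifier from its `model` dict.
cfg = Config.fromfile('configs/_base_/models/swin_transformer/base_224.py')
model = build_classifier(cfg.model)
model.init_weights()  # applies the TruncNormal/Constant init_cfg above
model.eval()

# One forward pass through backbone + neck on a dummy batch.
with torch.no_grad():
    feats = model.extract_feat(torch.randn(1, 3, 224, 224))
# For arch='base', the pooled feature should have 1024 channels,
# matching the head's in_channels above.
print(feats.shape if hasattr(feats, 'shape') else [f.shape for f in feats])
```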
16 changes: 16 additions & 0 deletions configs/_base_/models/swin_transformer/base_384.py
@@ -0,0 +1,16 @@
# model settings
# Only for evaluation
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='SwinTransformer',
        arch='base',
        img_size=384,
        stage_cfg=dict(block_cfg=dict(window_size=12))),
    neck=dict(type='GlobalAveragePooling', dim=1),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=1024,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5)))
12 changes: 12 additions & 0 deletions configs/_base_/models/swin_transformer/large_224.py
@@ -0,0 +1,12 @@
# model settings
# Only for evaluation
model = dict(
    type='ImageClassifier',
    backbone=dict(type='SwinTransformer', arch='large', img_size=224),
    neck=dict(type='GlobalAveragePooling', dim=1),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=1536,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5)))
16 changes: 16 additions & 0 deletions configs/_base_/models/swin_transformer/large_384.py
@@ -0,0 +1,16 @@
# model settings
# Only for evaluation
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='SwinTransformer',
        arch='large',
        img_size=384,
        stage_cfg=dict(block_cfg=dict(window_size=12))),
    neck=dict(type='GlobalAveragePooling', dim=1),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=1536,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5)))
23 changes: 23 additions & 0 deletions configs/_base_/models/swin_transformer/small_224.py
@@ -0,0 +1,23 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='SwinTransformer', arch='small', img_size=224,
        drop_path_rate=0.3),
    neck=dict(type='GlobalAveragePooling', dim=1),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=768,
        init_cfg=None,  # suppress the default init_cfg of LinearClsHead.
        loss=dict(
            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
        cal_acc=False),
    init_cfg=[
        dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
        dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
    ],
    train_cfg=dict(augments=[
        dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
        dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
    ]))
22 changes: 22 additions & 0 deletions configs/_base_/models/swin_transformer/tiny_224.py
@@ -0,0 +1,22 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='SwinTransformer', arch='tiny', img_size=224, drop_path_rate=0.2),
    neck=dict(type='GlobalAveragePooling', dim=1),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=768,
        init_cfg=None,  # suppress the default init_cfg of LinearClsHead.
        loss=dict(
            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
        cal_acc=False),
    init_cfg=[
        dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
        dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
    ],
    train_cfg=dict(augments=[
        dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
        dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
    ]))
30 changes: 30 additions & 0 deletions configs/_base_/schedules/imagenet_bs1024_adamw_swin.py
@@ -0,0 +1,30 @@
paramwise_cfg = dict(
    norm_decay_mult=0.0,
    bias_decay_mult=0.0,
    custom_keys={
        '.absolute_pos_embed': dict(decay_mult=0.0),
        '.relative_position_bias_table': dict(decay_mult=0.0)
    })

# batch size is 128 per GPU, with 8 GPUs
# lr = 5e-4 * 128 * 8 / 512 = 0.001
optimizer = dict(
    type='AdamW',
    lr=5e-4 * 128 * 8 / 512,
    weight_decay=0.05,
    eps=1e-8,
    betas=(0.9, 0.999),
    paramwise_cfg=paramwise_cfg)
optimizer_config = dict(grad_clip=dict(max_norm=5.0))

# learning policy
lr_config = dict(
    policy='CosineAnnealing',
    by_epoch=False,
    min_lr_ratio=1e-2,
    warmup='linear',
    warmup_ratio=1e-3,
    warmup_iters=20 * 1252,
    warmup_by_epoch=False)

runner = dict(type='EpochBasedRunner', max_epochs=300)
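
For intuition, this schedule is a linear warmup over the first 20 epochs (1252 iterations each at a total batch size of 1024) followed by per-iteration cosine annealing down to 1% of the base lr. A sketch of the intended shape — an illustration under assumed mmcv `CosineAnnealing` hook semantics, not the hook itself:

```python
import math

base_lr = 5e-4 * 128 * 8 / 512          # 0.001
iters_per_epoch = 1252                  # ImageNet at total batch size 1024
max_iters = 300 * iters_per_epoch
warmup_iters = 20 * iters_per_epoch
warmup_ratio = 1e-3
min_lr = base_lr * 1e-2                 # from min_lr_ratio=1e-2

def lr_at(it):
    if it < warmup_iters:
        # Linear warmup from base_lr * warmup_ratio up to base_lr.
        k = it / warmup_iters
        return base_lr * (warmup_ratio + (1 - warmup_ratio) * k)
    # Cosine annealing over the full run (by_epoch=False -> per iteration).
    progress = it / max_iters
    return min_lr + (base_lr - min_lr) * (1 + math.cos(math.pi * progress)) / 2

print(lr_at(0), lr_at(warmup_iters), lr_at(max_iters - 1))
```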
41 changes: 41 additions & 0 deletions configs/swin_transformer/README.md
@@ -0,0 +1,41 @@
# Swin Transformer: Hierarchical Vision Transformer using Shifted Windows

## Introduction

[ALGORITHM]

```latex
@article{liu2021Swin,
title={Swin Transformer: Hierarchical Vision Transformer using Shifted Windows},
author={Liu, Ze and Lin, Yutong and Cao, Yue and Hu, Han and Wei, Yixuan and Zhang, Zheng and Lin, Stephen and Guo, Baining},
journal={arXiv preprint arXiv:2103.14030},
year={2021}
}
```

## Pretrained models

The pre-trained models are converted from the [model zoo of Swin Transformer](https://github.com/microsoft/Swin-Transformer#main-results-on-imagenet-with-pretrained-models).

### ImageNet 1k

| Model | Pretrain | Resolution | Params (M) | Flops (G) | Top-1 (%) | Top-5 (%) | Download |
|:---------:|:------------:|:-----------:|:---------:|:---------:|:---------:|:---------:|:--------:|
| Swin-T | ImageNet-1k | 224x224 | 28.29 | 4.36 | 81.18 | 95.52 | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_tiny_patch4_window7_224-160bb0a5.pth)|
| Swin-S | ImageNet-1k | 224x224 | 49.61 | 8.52 | 83.21 | 96.25 | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_small_patch4_window7_224-cc7a01c9.pth)|
| Swin-B | ImageNet-1k | 224x224 | 87.77 | 15.14 | 83.42 | 96.44 | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_base_patch4_window7_224-4670dd19.pth)|
| Swin-B | ImageNet-1k | 384x384 | 87.90 | 44.49 | 84.49 | 96.95 | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_base_patch4_window12_384-02c598a4.pth)|
| Swin-B | ImageNet-22k | 224x224 | 87.77 | 15.14 | 85.16 | 97.50 | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_base_patch4_window7_224_22kto1k-f967f799.pth)|
| Swin-B | ImageNet-22k | 384x384 | 87.90 | 44.49 | 86.44 | 98.05 | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_base_patch4_window12_384_22kto1k-d59b0d1d.pth)|
| Swin-L | ImageNet-22k | 224x224 | 196.53 | 34.04 | 86.24 | 97.88 | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_large_patch4_window7_224_22kto1k-5f0996db.pth)|
| Swin-L | ImageNet-22k | 384x384 | 196.74 | 100.04 | 87.25 | 98.25 | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_large_patch4_window12_384_22kto1k-0a40944b.pth)|


## Results and models

### ImageNet

| Model | Pretrain | Resolution | Params (M) | Flops (G) | Top-1 (%) | Top-5 (%) | Config | Download |
|:---------:|:------------:|:-----------:|:---------:|:---------:|:---------:|:---------:|:----------:|:--------:|
| Swin-T | ImageNet-1k | 224x224 | 28.29 | 4.36 | 81.18 | 95.61 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/swin_transformer/swin_tiny_224_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_tiny_224_imagenet-66df6be6.pth) | [log](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_tiny_224_imagenet-66df6be6.log.json)|
| Swin-S | ImageNet-1k | 224x224 | 49.61 | 8.52 | 83.02 | 96.29 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/swin_transformer/swin_small_224_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_small_224_imagenet-7f9d988b.pth) | [log](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_small_224_imagenet-7f9d988b.log.json)|
| Swin-B | ImageNet-1k | 224x224 | 87.77 | 15.14 | 83.36 | 96.44 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/swin_transformer/swin_base_224_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_base_224_imagenet-93230b0d.pth) | [log](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_base_224_imagenet-93230b0d.log.json)|
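
To try one of the checkpoints above, something like the following should work — a sketch assuming mmcls's `init_model`/`inference_model` helpers and locally downloaded config and checkpoint files:

```python
from mmcls.apis import init_model, inference_model

# Assumed paths: the config from this commit and a checkpoint from the table
# above, downloaded next to the script.
config = 'configs/swin_transformer/swin_tiny_224_imagenet.py'
checkpoint = 'swin_tiny_224_imagenet-66df6be6.pth'

model = init_model(config, checkpoint, device='cpu')
result = inference_model(model, 'demo/demo.JPEG')  # any test image
print(result)  # expected keys include the predicted class and its score
```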