Skip to content

Commit

Permalink
[Feature] Add kd examples (#305)
Browse files Browse the repository at this point in the history
* support kd for mbv2 and shufflenetv2

* WIP: fix ckpt path

* WIP: fix kd r34-r18

* add metafile

* fix metafile

* delete
  • Loading branch information
HIT-cwh committed Oct 26, 2022
1 parent 1e8f886 commit db32b32
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 11 deletions.
8 changes: 5 additions & 3 deletions configs/distill/mmcls/kd/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@ A very simple way to improve the performance of almost any machine learning algo

### Classification

| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
| :------: | :------: | :----------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :-------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| logits | ImageNet | [resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb32_in1k.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb32_in1k.py) | 71.54 | 73.62 | 69.90 | [config](./wsld_cls_head_resnet34_resnet18_8xb32_in1k.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth) \|[model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/distill/wsld/wsld_cls_head_resnet34_resnet18_8xb32_in1k/wsld_cls_head_resnet34_resnet18_8xb32_in1k_acc-71.54_20211222-91f28cf6.pth?versionId=CAEQHxiBgMC6memK7xciIGMzMDFlYTA4YzhlYTRiMTNiZWU0YTVhY2I5NjVkMjY2) \| [log](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/distill/wsld/wsld_cls_head_resnet34_resnet18_8xb32_in1k/wsld_cls_head_resnet34_resnet18_8xb32_in1k_20211221_181516.log.json?versionId=CAEQHxiBgIDLmemK7xciIGNkM2FiN2Y4N2E5YjRhNDE4NDVlNmExNDczZDIxN2E5) |
| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
| :------: | :------: | :-----------------------------------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| logits | ImageNet | [resnet34](https://github.com/open-mmlab/mmclassification/blob/dev-1.x/configs/resnet/resnet34_8xb32_in1k.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/dev-1.x/configs/resnet/resnet18_8xb32_in1k.py) | 71.81 | 73.62 | 69.90 | [config](./kd_logits_resnet34_resnet18_8xb32_in1k.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth) \|[model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/kd/kl_r18_w3/kd_logits_resnet34_resnet18_8xb32_in1k_w3_20221011_181115-5c6a834d.pth?versionId=CAEQThiBgID1_Me0oBgiIDE3NTk3MDgxZmU2YjRlMjVhMzg1ZTQwMmRhNmYyNGU2) \| [log](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/kd/kl_r18_w3/kd_logits_resnet34_resnet18_8xb32_in1k_w3_20221011_181115-5c6a834d.json?versionId=CAEQThiBgMDx_se0oBgiIDQxNTM2MWZjZGRhNjRhZDZiZTIzY2Y0NDU3NDA4ODBl) |
| logits | ImageNet | [resnet50](https://github.com/open-mmlab/mmclassification/blob/dev-1.x/configs/resnet/resnet50_8xb32_in1k.py) | [mobilenet-v2](https://github.com/open-mmlab/mmclassification/blob/dev-1.x/configs/mobilenet_v2/mobilenet-v2_8xb32_in1k.py) | | 76.55 | 71.86 | [config](./kd_logits_resnet50_mobilenet-v2_8xb32_in1k.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth) \|[model](<>) \| \[log\](\<>>) |
| logits | ImageNet | [resnet34](https://github.com/open-mmlab/mmclassification/blob/dev-1.x/configs/resnet/resnet50_8xb32_in1k.py) | [shufflenet-v2](https://github.com/open-mmlab/mmclassification/blob/dev-1.x/configs/shufflenet_v2/shufflenet-v2-1x_16xb64_in1k.py) | | 76.55 | 69.55 | [config](./kd_logits_resnet50_shufflenet-v2-1x_16xb64_in1k.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth) \|\[model\](\<>>) \| \[log\](\<>>) |

## Citation

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
'mmcls::_base_/default_runtime.py'
]

teacher_ckpt = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth' # noqa: E501

model = dict(
_scope_='mmrazor',
type='SingleTeacherDistill',
Expand All @@ -17,16 +19,16 @@
architecture=dict(
cfg_path='mmcls::resnet/resnet18_8xb32_in1k.py', pretrained=False),
teacher=dict(
cfg_path='mmcls::resnet/resnet34_8xb32_in1k.py', pretrained=True),
teacher_ckpt='resnet34_8xb32_in1k_20210831-f257d4e6.pth',
cfg_path='mmcls::resnet/resnet34_8xb32_in1k.py', pretrained=False),
teacher_ckpt=teacher_ckpt,
distiller=dict(
type='ConfigurableDistiller',
student_recorders=dict(
fc=dict(type='ModuleOutputs', source='head.fc')),
teacher_recorders=dict(
fc=dict(type='ModuleOutputs', source='head.fc')),
distill_losses=dict(
loss_kl=dict(type='KLDivergence', tau=1, loss_weight=5)),
loss_kl=dict(type='KLDivergence', tau=1, loss_weight=3)),
loss_forward_mappings=dict(
loss_kl=dict(
preds_S=dict(from_student=True, recorder='fc'),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
_base_ = ['mmcls::mobilenet_v2/mobilenet-v2_8xb32_in1k.py']

student = _base_.model

teacher_ckpt = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth' # noqa: E501

model = dict(
_scope_='mmrazor',
_delete_=True,
type='SingleTeacherDistill',
data_preprocessor=dict(
type='ImgDataPreprocessor',
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# convert image from BGR to RGB
bgr_to_rgb=True),
architecture=student,
teacher=dict(
cfg_path='mmcls::resnet/resnet50_8xb32_in1k.py', pretrained=False),
teacher_ckpt=teacher_ckpt,
distiller=dict(
type='ConfigurableDistiller',
student_recorders=dict(
fc=dict(type='ModuleOutputs', source='head.fc')),
teacher_recorders=dict(
fc=dict(type='ModuleOutputs', source='head.fc')),
distill_losses=dict(
loss_kl=dict(type='KLDivergence', tau=1, loss_weight=3)),
loss_forward_mappings=dict(
loss_kl=dict(
preds_S=dict(from_student=True, recorder='fc'),
preds_T=dict(from_student=False, recorder='fc')))))

find_unused_parameters = True

val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
_base_ = ['mmcls::shufflenet_v2/shufflenet-v2-1x_16xb64_in1k.py']

student = _base_.model

teacher_ckpt = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth' # noqa: E501

model = dict(
_scope_='mmrazor',
_delete_=True,
type='SingleTeacherDistill',
data_preprocessor=dict(
type='ImgDataPreprocessor',
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# convert image from BGR to RGB
bgr_to_rgb=True),
architecture=student,
teacher=dict(
cfg_path='mmcls::resnet/resnet50_8xb32_in1k.py', pretrained=False),
teacher_ckpt=teacher_ckpt,
distiller=dict(
type='ConfigurableDistiller',
student_recorders=dict(
fc=dict(type='ModuleOutputs', source='head.fc')),
teacher_recorders=dict(
fc=dict(type='ModuleOutputs', source='head.fc')),
distill_losses=dict(
loss_kl=dict(type='KLDivergence', tau=1, loss_weight=3)),
loss_forward_mappings=dict(
loss_kl=dict(
preds_S=dict(from_student=True, recorder='fc'),
preds_T=dict(from_student=False, recorder='fc')))))

find_unused_parameters = True

val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')
56 changes: 51 additions & 5 deletions configs/distill/mmcls/kd/metafile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@ Collections:
URL: https://arxiv.org/abs/1503.02531
Title: Distilling the Knowledge in a Neural Network
README: configs/distill/mmcls/kd/README.md
Code:
URL: https://github.com/open-mmlab/mmrazor/blob/v0.1.0/mmrazor/models/losses/weighted_soft_label_distillation.py
Version: v0.1.0

Models:
- Name: kd_logits_resnet34_resnet18_8xb32_in1k
In Collection: KD
Expand All @@ -31,6 +29,54 @@ Models:
- Task: Image Classification
Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 71.54
Top 1 Accuracy: 71.81
Config: configs/distill/mmcls/kd/kd_logits_resnet34_resnet18_8xb32_in1k.py
Weights: https://download.openmmlab.com/mmrazor/v0.1/distill/wsld/wsld_cls_head_resnet34_resnet18_8xb32_in1k/wsld_cls_head_resnet34_resnet18_8xb32_in1k_acc-71.54_20211222-91f28cf6.pth
Weights: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/kd/kl_r18_w3/kd_logits_resnet34_resnet18_8xb32_in1k_w3_20221011_181115-5c6a834d.pth?versionId=CAEQThiBgID1_Me0oBgiIDE3NTk3MDgxZmU2YjRlMjVhMzg1ZTQwMmRhNmYyNGU2

- Name: kd_logits_resnet50_mobilenet-v2_8xb32_in1k
In Collection: KD
Metadata:
Location: logits
Student:
Config: mmcls::mobilenet_v2/mobilenet-v2_8xb32_in1k.py
Weights: https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth
Metrics:
Top 1 Accuracy: 71.86
Top 5 Accuracy: 90.42
Teacher:
Config: mmcls::resnet/resnet50_8xb32_in1k.py
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth
Metrics:
Top 1 Accuracy: 76.55
Top 5 Accuracy: 93.06
Results:
- Task: Image Classification
Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 73.68
Config: configs/distill/mmcls/kd/kd_logits_resnet50_mobilenet-v2_8xb32_in1k.py
Weights: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/kd/kl_mbv2/kd_logits_resnet50_mobilenet-v2_8xb32_in1k_20221003_124149-d9548d27.pth?versionId=CAEQThiBgICdqMS0oBgiIGEzMmI1NmIwODI2ODQyODFiM2ZmOTE3NDQ3NmU1Yjhh

- Name: kd_logits_resnet50_shufflenet-v2-1x_16xb64_in1k
In Collection: KD
Metadata:
Location: logits
Student:
Config: mmcls::shufflenet_v2/shufflenet-v2-1x_16xb64_in1k.py
Weights: https://download.openmmlab.com/mmclassification/v0/shufflenet_v2/shufflenet_v2_batch1024_imagenet_20200812-5bf4721e.pth
Metrics:
Top 1 Accuracy: 69.55
Top 5 Accuracy: 88.92
Teacher:
Config: mmcls::resnet/resnet50_8xb32_in1k.py
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth
Metrics:
Top 1 Accuracy: 76.55
Top 5 Accuracy: 93.06
Results:
- Task: Image Classification
Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 71.27
Config: configs/distill/mmcls/kd/kd_logits_resnet50_shufflenet-v2-1x_16xb64_in1k.py
Weights: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/kd/kl_shuffle/kd_logits_resnet50_shufflenet-v2-1x_16xb64_in1k_20221003_124324-666a2414.pth?versionId=CAEQThiBgMCe78e0oBgiIDdhYjA0NTE4ZDZjMDRmOTU4YzNkN2E2ODNmODUwOGY5

0 comments on commit db32b32

Please sign in to comment.