-
Notifications
You must be signed in to change notification settings - Fork 228
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* support kd for mbv2 and shufflenetv2 * WIP: fix ckpt path * WIP: fix kd r34-r18 * add metafile * fix metafile * delete
- Loading branch information
Showing
5 changed files
with
135 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
37 changes: 37 additions & 0 deletions
37
configs/distill/mmcls/kd/kd_logits_resnet50_mobilenet-v2_8xb32_in1k.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
_base_ = ['mmcls::mobilenet_v2/mobilenet-v2_8xb32_in1k.py'] | ||
|
||
student = _base_.model | ||
|
||
teacher_ckpt = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth' # noqa: E501 | ||
|
||
model = dict( | ||
_scope_='mmrazor', | ||
_delete_=True, | ||
type='SingleTeacherDistill', | ||
data_preprocessor=dict( | ||
type='ImgDataPreprocessor', | ||
# RGB format normalization parameters | ||
mean=[123.675, 116.28, 103.53], | ||
std=[58.395, 57.12, 57.375], | ||
# convert image from BGR to RGB | ||
bgr_to_rgb=True), | ||
architecture=student, | ||
teacher=dict( | ||
cfg_path='mmcls::resnet/resnet50_8xb32_in1k.py', pretrained=False), | ||
teacher_ckpt=teacher_ckpt, | ||
distiller=dict( | ||
type='ConfigurableDistiller', | ||
student_recorders=dict( | ||
fc=dict(type='ModuleOutputs', source='head.fc')), | ||
teacher_recorders=dict( | ||
fc=dict(type='ModuleOutputs', source='head.fc')), | ||
distill_losses=dict( | ||
loss_kl=dict(type='KLDivergence', tau=1, loss_weight=3)), | ||
loss_forward_mappings=dict( | ||
loss_kl=dict( | ||
preds_S=dict(from_student=True, recorder='fc'), | ||
preds_T=dict(from_student=False, recorder='fc'))))) | ||
|
||
find_unused_parameters = True | ||
|
||
val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop') |
37 changes: 37 additions & 0 deletions
37
configs/distill/mmcls/kd/kd_logits_resnet50_shufflenet-v2-1x_16xb64_in1k.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
_base_ = ['mmcls::shufflenet_v2/shufflenet-v2-1x_16xb64_in1k.py'] | ||
|
||
student = _base_.model | ||
|
||
teacher_ckpt = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth' # noqa: E501 | ||
|
||
model = dict( | ||
_scope_='mmrazor', | ||
_delete_=True, | ||
type='SingleTeacherDistill', | ||
data_preprocessor=dict( | ||
type='ImgDataPreprocessor', | ||
# RGB format normalization parameters | ||
mean=[123.675, 116.28, 103.53], | ||
std=[58.395, 57.12, 57.375], | ||
# convert image from BGR to RGB | ||
bgr_to_rgb=True), | ||
architecture=student, | ||
teacher=dict( | ||
cfg_path='mmcls::resnet/resnet50_8xb32_in1k.py', pretrained=False), | ||
teacher_ckpt=teacher_ckpt, | ||
distiller=dict( | ||
type='ConfigurableDistiller', | ||
student_recorders=dict( | ||
fc=dict(type='ModuleOutputs', source='head.fc')), | ||
teacher_recorders=dict( | ||
fc=dict(type='ModuleOutputs', source='head.fc')), | ||
distill_losses=dict( | ||
loss_kl=dict(type='KLDivergence', tau=1, loss_weight=3)), | ||
loss_forward_mappings=dict( | ||
loss_kl=dict( | ||
preds_S=dict(from_student=True, recorder='fc'), | ||
preds_T=dict(from_student=False, recorder='fc'))))) | ||
|
||
find_unused_parameters = True | ||
|
||
val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters