# [Feature] Add Decoupled KD Loss #222

**Status:** Merged
## Commits (31)

- b9ecf28: add DKDLoss, config
- b19e488: linting
- 083c90f: linting
- 98da1b7: reset default reduction
- 4ecd958: Merge branch 'dev-1.x' of github.com:spynccat/mmrazor into spynccat/d…
- 63ede5e: dkd ut
- 6ff06b4: Merge branch 'dev-1.x' into spynccat/dkd_loss
- 2e6fdab: Update decoupled_kd.py
- 24d2cd6 (spynccat): Update decoupled_kd.py
- 52b77db (spynccat): Update decoupled_kd.py
- 726d478 (spynccat): fix commit
- 316b30e: fix readme
- 45bff97: fix comments
- 4afa536: linting comment
- 38dd6b2: rename loss params
- 0c39f91: fix docstring
- cdf4d82: Update decoupled_kd.py
- b633be6 (spynccat): fix gt from config
- ca25f02: merge fix
- af035fa: Merge branch 'spynccat/dkd_loss' of github.com:spynccat/mmrazor into …
- 87ec0ac: fix ut & wsld
- 4e0925c: Update README.md
- ae64330 (spynccat): Update README.md
- 541423c (spynccat): add Acknowledgement
- a9c0080 (spynccat): Update README.md
- d47cf72 (spynccat): Update README.md
- 42405f1 (spynccat): Update README.md
- 66192a8 (spynccat): Update README.md
- 25f6797 (spynccat): fix readme style
- 02f4425 (spynccat): Merge branch 'dev-1.x' into spynccat/dkd_loss
- 33b3883: fix md
## Files changed

### README.md (new, +48 lines)
# Decoupled Knowledge Distillation

> [Decoupled Knowledge Distillation](https://arxiv.org/pdf/2203.08679.pdf)

<!-- [ALGORITHM] -->

## Abstract

State-of-the-art distillation methods are mainly based on distilling deep features from intermediate layers, while the significance of logit distillation is greatly overlooked. To provide a novel viewpoint to study logit distillation, we reformulate the classical KD loss into two parts, i.e., target class knowledge distillation (TCKD) and non-target class knowledge distillation (NCKD). We empirically investigate and prove the effects of the two parts: TCKD transfers knowledge concerning the "difficulty" of training samples, while NCKD is the prominent reason why logit distillation works. More importantly, we reveal that the classical KD loss is a coupled formulation, which (1) suppresses the effectiveness of NCKD and (2) limits the flexibility to balance these two parts. To address these issues, we present Decoupled Knowledge Distillation (DKD), enabling TCKD and NCKD to play their roles more efficiently and flexibly. Compared with complex feature-based methods, our DKD achieves comparable or even better results and has better training efficiency on CIFAR-100, ImageNet, and MS-COCO datasets for image classification and object detection tasks. This paper proves the great potential of logit distillation, and we hope it will be helpful for future research. The code is available at https://github.com/megvii-research/mdistiller.
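
A minimal statement of that reformulation, following the paper's notation (p_t^T is the teacher's probability on the target class):

```latex
% Classical KD is a coupled sum of the two parts:
\mathrm{KD} = \mathrm{TCKD} + (1 - p_t^{\mathcal{T}})\,\mathrm{NCKD}
% DKD decouples the weights into independent hyper-parameters:
\mathrm{DKD} = \alpha\,\mathrm{TCKD} + \beta\,\mathrm{NCKD}
```
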
![avatar](../../../../docs/en/imgs/model_zoo/dkd/dkd.png)

## Results and models

### Classification

| Dataset  | Model     | Teacher   | Top-1 (%) | Top-5 (%) | Configs                                    | Download                                                                                             |
| -------- | --------- | --------- | --------- | --------- | ------------------------------------------ | ---------------------------------------------------------------------------------------------------- |
| ImageNet | ResNet-18 | ResNet-34 | 71.368    | 90.256    | [config](dkd_logits_r34_r18_8xb32_in1k.py) | [model & log](https://autolink.sensetime.com/pages/model/share/afc68955-e25d-4488-b044-5e801b3ff62f) |
## Citation

```latex
@article{zhao2022decoupled,
  title={Decoupled Knowledge Distillation},
  author={Zhao, Borui and Cui, Quan and Song, Renjie and Qiu, Yiyu and Liang, Jiajun},
  journal={arXiv preprint arXiv:2203.08679},
  year={2022}
}
```
## Getting Started

### Download the teacher checkpoint

https://mmclassification.readthedocs.io/en/latest/papers/resnet.html

### Distillation training

```bash
sh tools/slurm_train.sh $PARTITION $JOB_NAME \
    configs/distill/mmcls/dkd/dkd_logits_r34_r18_8xb32_in1k.py \
    $DISTILLATION_WORK_DIR
```
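
Slurm is optional; assuming mmrazor follows the standard OpenMMLab launcher layout, a single-node multi-GPU run would look like this (a sketch, with the GPU count as a placeholder):

```bash
# Assumes the usual OpenMMLab tools/dist_train.sh wrapper is present.
sh tools/dist_train.sh \
    configs/distill/mmcls/dkd/dkd_logits_r34_r18_8xb32_in1k.py 8 \
    --work-dir $DISTILLATION_WORK_DIR
```
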
## Acknowledgement

Shout out to Davidgzx for his special contribution.
### configs/distill/mmcls/dkd/dkd_logits_r34_r18_8xb32_in1k.py (new, +45 lines)
```python
_base_ = [
    'mmcls::_base_/datasets/imagenet_bs32.py',
    'mmcls::_base_/schedules/imagenet_bs256.py',
    'mmcls::_base_/default_runtime.py'
]

model = dict(
    _scope_='mmrazor',
    type='SingleTeacherDistill',
    data_preprocessor=dict(
        type='ImgDataPreprocessor',
        # RGB format normalization parameters
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        # convert image from BGR to RGB
        bgr_to_rgb=True),
    architecture=dict(
        cfg_path='mmcls::resnet/resnet18_8xb32_in1k.py', pretrained=False),
    teacher=dict(
        cfg_path='mmcls::resnet/resnet34_8xb32_in1k.py', pretrained=True),
    teacher_ckpt='resnet34_8xb32_in1k_20210831-f257d4e6.pth',
    distiller=dict(
        type='ConfigurableDistiller',
        # record the student logits and the gt labels fed to the loss module
        student_recorders=dict(
            fc=dict(type='ModuleOutputs', source='head.fc'),
            gt_labels=dict(type='ModuleInputs', source='head.loss_module')),
        teacher_recorders=dict(
            fc=dict(type='ModuleOutputs', source='head.fc')),
        distill_losses=dict(
            loss_dkd=dict(
                type='DKDLoss',
                tau=1,
                beta=0.5,
                loss_weight=1,
                reduction='mean')),
        loss_forward_mappings=dict(
            loss_dkd=dict(
                preds_S=dict(from_student=True, recorder='fc'),
                preds_T=dict(from_student=False, recorder='fc'),
                gt_labels=dict(
                    recorder='gt_labels', from_student=True, data_idx=1)))))

find_unused_parameters = True

val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')
```
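
For readers new to mmrazor's distiller: `loss_forward_mappings` wires recorded tensors into the keyword arguments of `DKDLoss.forward`. Conceptually the resolution works like the sketch below (an illustrative, hypothetical helper, not mmrazor's actual code; `data_idx=1` selects the labels from the inputs of `head.loss_module`):

```python
# Hypothetical sketch of how one mapping entry is resolved.
def resolve_mapping(mapping: dict, student_data: dict,
                    teacher_data: dict) -> dict:
    kwargs = {}
    for arg_name, spec in mapping.items():
        # pick the student or teacher recorder pool
        pool = student_data if spec['from_student'] else teacher_data
        data = pool[spec['recorder']]
        # input recorders capture a tuple; data_idx picks one element
        if 'data_idx' in spec:
            data = data[spec['data_idx']]
        kwargs[arg_name] = data
    return kwargs

# The distiller then calls: loss_dkd(**resolved_kwargs)
```
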
### mmrazor/models/losses/__init__.py (updated)
```python
# Copyright (c) OpenMMLab. All rights reserved.
from .ab_loss import ABLoss
from .cwd import ChannelWiseDivergence
from .decoupled_kd import DKDLoss
from .kl_divergence import KLDivergence
from .l2_loss import L2Loss
from .relational_kd import AngleWiseRKD, DistanceWiseRKD
from .weighted_soft_label_distillation import WSLD

__all__ = [
    'ChannelWiseDivergence', 'KLDivergence', 'AngleWiseRKD', 'DistanceWiseRKD',
    'WSLD', 'L2Loss', 'ABLoss', 'DKDLoss'
]
```
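
With the import in place, `DKDLoss` can be built from the `MODELS` registry exactly as the config above declares it (a minimal sketch assuming an mmrazor installation):

```python
from mmrazor.registry import MODELS

# Build the loss from its config dict, as ConfigurableDistiller would.
dkd = MODELS.build(
    dict(type='DKDLoss', tau=1, beta=0.5, loss_weight=1, reduction='mean'))
```
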
### mmrazor/models/losses/decoupled_kd.py (new, +157 lines)
```python
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

from mmrazor.registry import MODELS


@MODELS.register_module()
class DKDLoss(nn.Module):
    """Decoupled Knowledge Distillation, CVPR2022.

    link: https://arxiv.org/abs/2203.08679

    Reformulates the classical KD loss into two parts:

    1. target class knowledge distillation (TCKD)
    2. non-target class knowledge distillation (NCKD)

    Args:
        tau (float): Temperature coefficient. Defaults to 1.0.
        alpha (float): Weight of TCKD loss. Defaults to 1.0.
        beta (float): Weight of NCKD loss. Defaults to 1.0.
        reduction (str): Specifies the reduction to apply to the loss:
            ``'none'`` | ``'batchmean'`` | ``'sum'`` | ``'mean'``.
            ``'none'``: no reduction will be applied,
            ``'batchmean'``: the sum of the output will be divided by
                the batch size,
            ``'sum'``: the output will be summed,
            ``'mean'``: the output will be divided by the number of
                elements in the output.
            Defaults to ``'batchmean'``.
        loss_weight (float): Weight of the total loss. Defaults to 1.0.
    """

    def __init__(
        self,
        tau: float = 1.0,
        alpha: float = 1.0,
        beta: float = 1.0,
        reduction: str = 'batchmean',
        loss_weight: float = 1.0,
    ) -> None:
        super().__init__()
        self.tau = tau
        accept_reduction = {'none', 'batchmean', 'sum', 'mean'}
        assert reduction in accept_reduction, \
            f'DKDLoss supports reduction {accept_reduction}, ' \
            f'but gets {reduction}.'
        self.reduction = reduction
        self.alpha = alpha
        self.beta = beta
        self.loss_weight = loss_weight

    def forward(
        self,
        preds_S: torch.Tensor,
        preds_T: torch.Tensor,
        gt_labels: torch.Tensor,
    ) -> torch.Tensor:
        """DKDLoss forward function.

        Args:
            preds_S (torch.Tensor): The student model prediction with
                shape (N, C).
            preds_T (torch.Tensor): The teacher model prediction with
                shape (N, C).
            gt_labels (torch.Tensor): The gt label tensor of class
                indices with shape (N,).

        Return:
            torch.Tensor: The calculated loss value.
        """
        gt_mask = self._get_gt_mask(preds_S, gt_labels)
        tckd_loss = self._get_tckd_loss(preds_S, preds_T, gt_labels, gt_mask)
        nckd_loss = self._get_nckd_loss(preds_S, preds_T, gt_mask)
        loss = self.alpha * tckd_loss + self.beta * nckd_loss
        return self.loss_weight * loss

    def _get_nckd_loss(
        self,
        preds_S: torch.Tensor,
        preds_T: torch.Tensor,
        gt_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Calculate non-target class knowledge distillation."""
        # Subtracting a large constant from the gt logits drives their
        # softmax probability to zero; this masks out the target class
        # faster than indexing.
        s_nckd = F.log_softmax(preds_S / self.tau - 1000.0 * gt_mask, dim=1)
        t_nckd = F.softmax(preds_T / self.tau - 1000.0 * gt_mask, dim=1)
        return self._kl_loss(s_nckd, t_nckd)

    def _get_tckd_loss(
        self,
        preds_S: torch.Tensor,
        preds_T: torch.Tensor,
        gt_labels: torch.Tensor,
        gt_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Calculate target class knowledge distillation."""
        non_gt_mask = self._get_non_gt_mask(preds_S, gt_labels)
        s_tckd = F.softmax(preds_S / self.tau, dim=1)
        t_tckd = F.softmax(preds_T / self.tau, dim=1)
        # Collapse each distribution to two bins: target vs. non-target mass.
        mask_student = torch.log(self._cat_mask(s_tckd, gt_mask, non_gt_mask))
        mask_teacher = self._cat_mask(t_tckd, gt_mask, non_gt_mask)
        return self._kl_loss(mask_student, mask_teacher)

    def _kl_loss(
        self,
        preds_S: torch.Tensor,
        preds_T: torch.Tensor,
    ) -> torch.Tensor:
        """Calculate the KL Divergence."""
        # The deprecated `size_average` argument would silently override
        # `reduction`, so only `reduction` is passed here. Rescale by
        # tau**2 as in the paper.
        kl_loss = F.kl_div(
            preds_S, preds_T, reduction=self.reduction) * self.tau**2
        return kl_loss

    def _cat_mask(
        self,
        tckd: torch.Tensor,
        gt_mask: torch.Tensor,
        non_gt_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Calculate preds of target (pt) & preds of non-target (pnt)."""
        t1 = (tckd * gt_mask).sum(dim=1, keepdim=True)
        t2 = (tckd * non_gt_mask).sum(dim=1, keepdim=True)
        return torch.cat([t1, t2], dim=1)  # shape (N, 2)

    def _get_gt_mask(
        self,
        logits: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        """Calculate groundtruth mask on logits with target class tensor.

        Args:
            logits (torch.Tensor): The prediction logits with shape (N, C).
            target (torch.Tensor): The gt_label target with shape (N,).

        Return:
            torch.Tensor: The boolean gt mask with shape (N, C).
        """
        target = target.reshape(-1)
        return torch.zeros_like(logits).scatter_(1, target.unsqueeze(1),
                                                 1).bool()

    def _get_non_gt_mask(
        self,
        logits: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        """Calculate non-groundtruth mask on logits with target class tensor.

        Args:
            logits (torch.Tensor): The prediction logits with shape (N, C).
            target (torch.Tensor): The gt_label target with shape (N,).

        Return:
            torch.Tensor: The boolean non-gt mask with shape (N, C).
        """
        target = target.reshape(-1)
        return torch.ones_like(logits).scatter_(1, target.unsqueeze(1),
                                                0).bool()
```
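
A quick smoke test of the loss on random tensors (a usage sketch; shapes follow the docstrings above):

```python
import torch

from mmrazor.models.losses import DKDLoss

# Batch of 4 samples over 10 classes (arbitrary example values).
preds_S = torch.randn(4, 10)             # student logits, shape (N, C)
preds_T = torch.randn(4, 10)             # teacher logits, shape (N, C)
gt_labels = torch.randint(0, 10, (4, ))  # class indices, shape (N,)

loss_fn = DKDLoss(tau=1.0, alpha=1.0, beta=0.5, reduction='batchmean')
loss = loss_fn(preds_S, preds_T, gt_labels)
print(loss)  # scalar tensor
```
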