[Feature] Add Decoupled KD Loss #222

Merged (31 commits) on Aug 15, 2022

48 changes: 48 additions & 0 deletions configs/distill/mmcls/dkd/README.md
@@ -0,0 +1,48 @@
# Decoupled Knowledge Distillation

> [Decoupled Knowledge Distillation](https://arxiv.org/pdf/2203.08679.pdf)

<!-- [ALGORITHM] -->

## Abstract

State-of-the-art distillation methods are mainly based on distilling deep features from intermediate layers, while the significance of logit distillation is greatly overlooked. To provide a novel viewpoint to study logit distillation, we reformulate the classical KD loss into two parts, i.e., target class knowledge distillation (TCKD) and non-target class knowledge distillation (NCKD). We empirically investigate and prove the effects of the two parts: TCKD transfers knowledge concerning the "difficulty" of training samples, while NCKD is the prominent reason why logit distillation works. More importantly, we reveal that the classical KD loss is a coupled formulation, which (1) suppresses the effectiveness of NCKD and (2) limits the flexibility to balance these two parts. To address these issues, we present Decoupled Knowledge Distillation (DKD), enabling TCKD and NCKD to play their roles more efficiently and flexibly. Compared with complex feature-based methods, our DKD achieves comparable or even better results and has better training efficiency on CIFAR-100, ImageNet, and MS-COCO datasets for image classification and object detection tasks. This paper proves the great potential of logit distillation, and we hope it will be helpful for future research. The code is available at https://github.com/megvii-research/mdistiller.

![avatar](../../../../docs/en/imgs/model_zoo/dkd/dkd.png)
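
As summarized in the paper, classical KD can be rewritten as a weighted sum of the two parts, where p_t^T denotes the teacher's probability on the target class; the (1 - p_t^T) factor is what couples the two terms, and DKD simply replaces it with two independent weights (exposed as `alpha` and `beta` in the `DKDLoss` implementation below):

```latex
\mathrm{KD} = \mathrm{TCKD} + (1 - p_{t}^{\mathcal{T}}) \cdot \mathrm{NCKD}
\qquad \Longrightarrow \qquad
\mathrm{DKD} = \alpha \cdot \mathrm{TCKD} + \beta \cdot \mathrm{NCKD}
```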

## Results and models

### Classification

| Dataset | Model | Teacher | Top-1 (%) | Top-5 (%) | Configs | Download |
| -------- | --------- | --------- | --------- | --------- | ------------------------------------------ | ---------------------------------------------------------------------------------------------------- |
| ImageNet | ResNet-18 | ResNet-34 | 71.368 | 90.256 | [config](dkd_logits_r34_r18_8xb32_in1k.py) | [model & log](https://autolink.sensetime.com/pages/model/share/afc68955-e25d-4488-b044-5e801b3ff62f) |

## Citation

```latex
@article{zhao2022decoupled,
title={Decoupled Knowledge Distillation},
author={Zhao, Borui and Cui, Quan and Song, Renjie and Qiu, Yiyu and Liang, Jiajun},
journal={arXiv preprint arXiv:2203.08679},
year={2022}
}
```

## Getting Started

### Download the teacher checkpoint

The teacher checkpoint can be downloaded from the MMClassification model zoo:

https://mmclassification.readthedocs.io/en/latest/papers/resnet.html
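
The `teacher_ckpt` name in the config below corresponds to the ResNet-34 entry of that model zoo. A typical way to fetch it (the exact download URL is an assumption based on the usual MMClassification hosting pattern, so please verify it against the page above):

```bash
wget https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth
```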

### Distillation training

```bash
sh tools/slurm_train.sh $PARTITION $JOB_NAME \
configs/distill/mmcls/dkd/dkd_logits_r34_r18_8xb32_in1k.py \
$DISTILLATION_WORK_DIR
```
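
If Slurm is not available, single-node training should also work through the regular launcher (a sketch assuming the standard OpenMMLab `tools/train.py` entry point and its `--work-dir` flag):

```bash
python tools/train.py \
    configs/distill/mmcls/dkd/dkd_logits_r34_r18_8xb32_in1k.py \
    --work-dir $DISTILLATION_WORK_DIR
```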

## Acknowledgement

Shout out to Davidgzx for his special contribution.
45 changes: 45 additions & 0 deletions configs/distill/mmcls/dkd/dkd_logits_r34_r18_8xb32_in1k.py
@@ -0,0 +1,45 @@
_base_ = [
'mmcls::_base_/datasets/imagenet_bs32.py',
'mmcls::_base_/schedules/imagenet_bs256.py',
'mmcls::_base_/default_runtime.py'
]

model = dict(
_scope_='mmrazor',
type='SingleTeacherDistill',
data_preprocessor=dict(
type='ImgDataPreprocessor',
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# convert image from BGR to RGB
bgr_to_rgb=True),
architecture=dict(
cfg_path='mmcls::resnet/resnet18_8xb32_in1k.py', pretrained=False),
teacher=dict(
cfg_path='mmcls::resnet/resnet34_8xb32_in1k.py', pretrained=True),
teacher_ckpt='resnet34_8xb32_in1k_20210831-f257d4e6.pth',
distiller=dict(
type='ConfigurableDistiller',
student_recorders=dict(
fc=dict(type='ModuleOutputs', source='head.fc'),
gt_labels=dict(type='ModuleInputs', source='head.loss_module')),
teacher_recorders=dict(
fc=dict(type='ModuleOutputs', source='head.fc')),
distill_losses=dict(
loss_dkd=dict(
type='DKDLoss',
tau=1,
beta=0.5,
loss_weight=1,
reduction='mean')),
loss_forward_mappings=dict(
loss_dkd=dict(
preds_S=dict(from_student=True, recorder='fc'),
preds_T=dict(from_student=False, recorder='fc'),
gt_labels=dict(
recorder='gt_labels', from_student=True, data_idx=1)))))

find_unused_parameters = True

val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')
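
The keys under `loss_forward_mappings.loss_dkd` mirror the signature of `DKDLoss.forward(preds_S, preds_T, gt_labels)` introduced in this PR. A minimal sketch of what the distiller ends up computing with the recorded tensors (illustrative only, not the actual `ConfigurableDistiller` internals; batch size and class count are placeholders):

```python
import torch

from mmrazor.models import DKDLoss

# Same hyper-parameters as the `loss_dkd` entry in the config above.
dkd = DKDLoss(tau=1, beta=0.5, loss_weight=1, reduction='mean')

# Stand-ins for what the recorders capture in one forward pass: the student
# and teacher `head.fc` outputs plus the labels fed to `head.loss_module`.
student_fc = torch.randn(4, 1000)          # preds_S
teacher_fc = torch.randn(4, 1000)          # preds_T
gt_labels = torch.randint(0, 1000, (4, ))  # class indices, shape (N,)

loss = dkd(preds_S=student_fc, preds_T=teacher_fc, gt_labels=gt_labels)
# `loss` is a single scalar, already multiplied by `loss_weight`.
```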
@@ -20,9 +20,10 @@
cfg_path='mmcls::resnet/resnet34_8xb32_in1k.py', pretrained=True),
teacher_ckpt='resnet34_8xb32_in1k_20210831-f257d4e6.pth',
distiller=dict(
type='ConfigurableDistiller',
student_recorders=dict(
fc=dict(type='ModuleOutputs', source='head.fc'),
data_samples=dict(type='ModuleInputs', source='')),
gt_labels=dict(type='ModuleInputs', source='head.loss_module')),
teacher_recorders=dict(
fc=dict(type='ModuleOutputs', source='head.fc')),
distill_losses=dict(
@@ -31,8 +32,8 @@
loss_wsld=dict(
student=dict(recorder='fc', from_student=True),
teacher=dict(recorder='fc', from_student=False),
data_samples=dict(
recorder='data_samples', from_student=True, data_idx=1)))))
gt_labels=dict(
recorder='gt_labels', from_student=True, data_idx=1)))))

find_unused_parameters = True

Binary file added docs/en/imgs/model_zoo/dkd/dkd.png
3 changes: 2 additions & 1 deletion mmrazor/models/losses/__init__.py
@@ -1,12 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .ab_loss import ABLoss
from .cwd import ChannelWiseDivergence
from .decoupled_kd import DKDLoss
from .kl_divergence import KLDivergence
from .l2_loss import L2Loss
from .relational_kd import AngleWiseRKD, DistanceWiseRKD
from .weighted_soft_label_distillation import WSLD

__all__ = [
'ChannelWiseDivergence', 'KLDivergence', 'AngleWiseRKD', 'DistanceWiseRKD',
'WSLD', 'L2Loss', 'ABLoss'
'WSLD', 'L2Loss', 'ABLoss', 'DKDLoss'
]
157 changes: 157 additions & 0 deletions mmrazor/models/losses/decoupled_kd.py
@@ -0,0 +1,157 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

from mmrazor.registry import MODELS


@MODELS.register_module()
class DKDLoss(nn.Module):
"""Decoupled Knowledge Distillation, CVPR2022.

link: https://arxiv.org/abs/2203.08679
    Reformulates the classical KD loss into two parts:
1. target class knowledge distillation (TCKD)
2. non-target class knowledge distillation (NCKD).
Args:
tau (float): Temperature coefficient. Defaults to 1.0.
alpha (float): Weight of TCKD loss. Defaults to 1.0.
beta (float): Weight of NCKD loss. Defaults to 1.0.
reduction (str): Specifies the reduction to apply to the loss:
``'none'`` | ``'batchmean'`` | ``'sum'`` | ``'mean'``.
``'none'``: no reduction will be applied,
``'batchmean'``: the sum of the output will be divided by
the batchsize,
``'sum'``: the output will be summed,
``'mean'``: the output will be divided by the number of
elements in the output.
Default: ``'batchmean'``
loss_weight (float): Weight of loss. Defaults to 1.0.
"""

def __init__(
self,
tau: float = 1.0,
alpha: float = 1.0,
beta: float = 1.0,
reduction: str = 'batchmean',
loss_weight: float = 1.0,
) -> None:
super(DKDLoss, self).__init__()
self.tau = tau
accept_reduction = {'none', 'batchmean', 'sum', 'mean'}
        assert reduction in accept_reduction, \
            f'DKDLoss supports reduction {accept_reduction}, ' \
            f'but got {reduction}.'
self.reduction = reduction
self.alpha = alpha
self.beta = beta
self.loss_weight = loss_weight

def forward(
self,
preds_S: torch.Tensor,
preds_T: torch.Tensor,
gt_labels: torch.Tensor,
) -> torch.Tensor:
"""DKDLoss forward function.

Args:
preds_S (torch.Tensor): The student model prediction, shape (N, C).
preds_T (torch.Tensor): The teacher model prediction, shape (N, C).
            gt_labels (torch.Tensor): The gt label tensor with class indices, shape (N,).

Return:
torch.Tensor: The calculated loss value.
"""
gt_mask = self._get_gt_mask(preds_S, gt_labels)
tckd_loss = self._get_tckd_loss(preds_S, preds_T, gt_labels, gt_mask)
nckd_loss = self._get_nckd_loss(preds_S, preds_T, gt_mask)
loss = self.alpha * tckd_loss + self.beta * nckd_loss
return self.loss_weight * loss

def _get_nckd_loss(
self,
preds_S: torch.Tensor,
preds_T: torch.Tensor,
gt_mask: torch.Tensor,
) -> torch.Tensor:
"""Calculate non-target class knowledge distillation."""
        # Mask out the target class by pushing its logit far negative, so the
        # softmax renormalizes over the non-target classes; faster than indexing.
s_nckd = F.log_softmax(preds_S / self.tau - 1000.0 * gt_mask, dim=1)
t_nckd = F.softmax(preds_T / self.tau - 1000.0 * gt_mask, dim=1)
return self._kl_loss(s_nckd, t_nckd)

def _get_tckd_loss(
self,
preds_S: torch.Tensor,
preds_T: torch.Tensor,
gt_labels: torch.Tensor,
gt_mask: torch.Tensor,
) -> torch.Tensor:
"""Calculate target class knowledge distillation."""
non_gt_mask = self._get_non_gt_mask(preds_S, gt_labels)
s_tckd = F.softmax(preds_S / self.tau, dim=1)
t_tckd = F.softmax(preds_T / self.tau, dim=1)
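        # Collapse the C-way distributions into binary (target vs. non-target)
        # distributions before taking the KL divergence.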
mask_student = torch.log(self._cat_mask(s_tckd, gt_mask, non_gt_mask))
mask_teacher = self._cat_mask(t_tckd, gt_mask, non_gt_mask)
return self._kl_loss(mask_student, mask_teacher)

def _kl_loss(
self,
preds_S: torch.Tensor,
preds_T: torch.Tensor,
) -> torch.Tensor:
"""Calculate the KL Divergence."""
kl_loss = F.kl_div(
preds_S, preds_T, size_average=False,
reduction=self.reduction) * self.tau**2
return kl_loss

def _cat_mask(
self,
tckd: torch.Tensor,
gt_mask: torch.Tensor,
non_gt_mask: torch.Tensor,
) -> torch.Tensor:
"""Calculate preds of target (pt) & preds of non-target (pnt)."""
t1 = (tckd * gt_mask).sum(dim=1, keepdims=True)
t2 = (tckd * non_gt_mask).sum(dim=1, keepdims=True)
return torch.cat([t1, t2], dim=1)

def _get_gt_mask(
self,
logits: torch.Tensor,
target: torch.Tensor,
) -> torch.Tensor:
"""Calculate groundtruth mask on logits with target class tensor.

Args:
logits (torch.Tensor): The prediction logits with shape (N, C).
            target (torch.Tensor): The gt_label target with shape (N,).

        Return:
            torch.Tensor: The boolean ground-truth mask with shape (N, C).
"""
target = target.reshape(-1)
return torch.zeros_like(logits).scatter_(1, target.unsqueeze(1),
1).bool()

def _get_non_gt_mask(
self,
logits: torch.Tensor,
target: torch.Tensor,
) -> torch.Tensor:
"""Calculate non-groundtruth mask on logits with target class tensor.

Args:
logits (torch.Tensor): The prediction logits with shape (N, C).
            target (torch.Tensor): The gt_label target with shape (N,).

        Return:
            torch.Tensor: The boolean non-ground-truth mask with shape (N, C).
"""
target = target.reshape(-1)
return torch.ones_like(logits).scatter_(1, target.unsqueeze(1),
0).bool()
12 changes: 1 addition & 11 deletions mmrazor/models/losses/weighted_soft_label_distillation.py
@@ -27,17 +27,7 @@ def __init__(self, tau=1.0, loss_weight=1.0, num_classes=1000):
self.softmax = nn.Softmax(dim=1)
self.logsoftmax = nn.LogSoftmax(dim=1)

def forward(self, student, teacher, data_samples):

# Unpack data samples and pack targets
if 'score' in data_samples[0].gt_label:
# Batch augmentation may convert labels to one-hot format scores.
gt_labels = torch.stack([i.gt_label.score for i in data_samples])
one_hot_labels = gt_labels.float()
else:
gt_labels = torch.hstack([i.gt_label.label for i in data_samples])
one_hot_labels = F.one_hot(
gt_labels, num_classes=self.num_classes).float()
def forward(self, student, teacher, gt_labels):

student_logits = student / self.tau
teacher_logits = teacher / self.tau
32 changes: 25 additions & 7 deletions tests/test_models/test_losses/test_distillation_losses.py
@@ -3,7 +3,7 @@

import torch

from mmrazor.models import ABLoss
from mmrazor.models import ABLoss, DKDLoss


class TestLosses(TestCase):
@@ -14,16 +14,28 @@ def setUpClass(cls):
cls.feats_2d = torch.randn(5, 2, 3)
cls.feats_3d = torch.randn(5, 2, 3, 3)

def normal_test_1d(self, loss_instance):
loss_1d = loss_instance.forward(self.feats_1d, self.feats_1d)
num_classes = 6
cls.labels = torch.randint(0, num_classes, [5])

def normal_test_1d(self, loss_instance, labels=False):
args = tuple([self.feats_1d, self.feats_1d])
if labels:
args += (self.labels, )
loss_1d = loss_instance.forward(*args)
self.assertTrue(loss_1d.numel() == 1)

def normal_test_2d(self, loss_instance):
loss_2d = loss_instance.forward(self.feats_2d, self.feats_2d)
def normal_test_2d(self, loss_instance, labels=False):
args = tuple([self.feats_2d, self.feats_2d])
if labels:
args += (self.labels, )
loss_2d = loss_instance.forward(*args)
self.assertTrue(loss_2d.numel() == 1)

def normal_test_3d(self, loss_instance):
loss_3d = loss_instance.forward(self.feats_3d, self.feats_3d)
def normal_test_3d(self, loss_instance, labels=False):
args = tuple([self.feats_3d, self.feats_3d])
if labels:
args += (self.labels, )
loss_3d = loss_instance.forward(*args)
self.assertTrue(loss_3d.numel() == 1)

def test_ab_loss(self):
@@ -32,3 +44,9 @@ def test_ab_loss(self):
self.normal_test_1d(ab_loss)
self.normal_test_2d(ab_loss)
self.normal_test_3d(ab_loss)

def test_dkd_loss(self):
dkd_loss_cfg = dict(loss_weight=1.0)
dkd_loss = DKDLoss(**dkd_loss_cfg)
# dkd requires label logits
self.normal_test_1d(dkd_loss, labels=True)
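
The new case can be run on its own with pytest (assuming the repository's regular test environment):

```bash
pytest tests/test_models/test_losses/test_distillation_losses.py::TestLosses::test_dkd_loss -q
```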