diff --git a/mmrazor/models/losses/kl_divergence.py b/mmrazor/models/losses/kl_divergence.py
index 7946172f6..defa367b7 100644
--- a/mmrazor/models/losses/kl_divergence.py
+++ b/mmrazor/models/losses/kl_divergence.py
@@ -1,5 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
@@ -13,18 +12,34 @@ class KLDivergence(nn.Module):
 
     Args:
         tau (float): Temperature coefficient. Defaults to 1.0.
+        reduction (str): Specifies the reduction to apply to the loss:
+            ``'none'`` | ``'batchmean'`` | ``'sum'`` | ``'mean'``.
+            ``'none'``: no reduction will be applied,
+            ``'batchmean'``: the sum of the output will be divided by
+                the batchsize,
+            ``'sum'``: the output will be summed,
+            ``'mean'``: the output will be divided by the number of
+                elements in the output.
+            Default: ``'batchmean'``
         loss_weight (float): Weight of loss. Defaults to 1.0.
     """

     def __init__(
         self,
         tau=1.0,
+        reduction='batchmean',
         loss_weight=1.0,
     ):
         super(KLDivergence, self).__init__()
         self.tau = tau
         self.loss_weight = loss_weight
 
+        accept_reduction = {'none', 'batchmean', 'sum', 'mean'}
+        assert reduction in accept_reduction, \
+            f'KLDivergence supports reduction {accept_reduction}, ' \
+            f'but gets {reduction}.'
+        self.reduction = reduction
+
     def forward(self, preds_S, preds_T):
         """Forward computation.
 
@@ -37,11 +52,9 @@ def forward(self, preds_S, preds_T):
         Return:
             torch.Tensor: The calculated loss value.
         """
-        N = preds_S.shape[0]
         preds_T = preds_T.detach()
         softmax_pred_T = F.softmax(preds_T / self.tau, dim=1)
-
-        logsoftmax = torch.nn.LogSoftmax(dim=1)
-        loss = torch.sum(-softmax_pred_T * logsoftmax(preds_S / self.tau)) * (
-            self.tau**2)
-        return self.loss_weight * loss / N
+        logsoftmax_preds_S = F.log_softmax(preds_S / self.tau, dim=1)
+        loss = (self.tau**2) * F.kl_div(
+            logsoftmax_preds_S, softmax_pred_T, reduction=self.reduction)
+        return self.loss_weight * loss
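
For reference, a minimal usage sketch of the refactored loss (not part of the patch). The import path and the parameter values are assumptions based on how mmrazor typically exposes its losses; tensor shapes are illustrative only.

```python
# Usage sketch (assumption: KLDivergence is re-exported from
# mmrazor.models.losses; adjust the import to your installation).
import torch

from mmrazor.models.losses import KLDivergence

# Dummy logits: batch of 4 samples, 10 classes.
preds_S = torch.randn(4, 10, requires_grad=True)  # student
preds_T = torch.randn(4, 10)                      # teacher

# Default 'batchmean' sums the element-wise KL terms and divides by the
# batch size, so the explicit division by N in the old code path is no
# longer needed.
loss_fn = KLDivergence(tau=4.0, reduction='batchmean', loss_weight=1.0)
loss = loss_fn(preds_S, preds_T)
print(loss.shape)  # torch.Size([]) -- scalar loss
loss.backward()    # gradients flow only to the student logits (teacher is detached)

# 'none' keeps the element-wise terms so callers can apply their own reduction.
per_elem = KLDivergence(tau=4.0, reduction='none')(preds_S, preds_T)
print(per_elem.shape)  # torch.Size([4, 10])
```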