[Feature] Add Transferal Perceptual Loss (#372)
* [Feature] Add Texture Perceptual Loss

* Rename

* Rename

Co-authored-by: liyinshuo <liyinshuo@sensetime.com>
Yshuo-Li and liyinshuo authored Jun 21, 2021
1 parent 82a964d commit 7f71f74
Showing 3 changed files with 125 additions and 10 deletions.
8 changes: 5 additions & 3 deletions mmedit/models/losses/__init__.py
@@ -2,13 +2,15 @@
                                MSECompositionLoss)
 from .gan_loss import DiscShiftLoss, GANLoss, GradientPenaltyLoss
 from .gradient_loss import GradientLoss
-from .perceptual_loss import PerceptualLoss, PerceptualVGG
+from .perceptual_loss import (PerceptualLoss, PerceptualVGG,
+                              TransferalPerceptualLoss)
 from .pixelwise_loss import CharbonnierLoss, L1Loss, MaskedTVLoss, MSELoss
 from .utils import mask_reduce_loss, reduce_loss
 
 __all__ = [
     'L1Loss', 'MSELoss', 'CharbonnierLoss', 'L1CompositionLoss',
     'MSECompositionLoss', 'CharbonnierCompLoss', 'GANLoss',
-    'GradientPenaltyLoss', 'PerceptualLoss', 'PerceptualVGG', 'reduce_loss',
-    'mask_reduce_loss', 'DiscShiftLoss', 'MaskedTVLoss', 'GradientLoss'
+    'TransferalPerceptualLoss', 'GradientPenaltyLoss', 'PerceptualLoss',
+    'PerceptualVGG', 'reduce_loss', 'mask_reduce_loss', 'DiscShiftLoss',
+    'MaskedTVLoss', 'GradientLoss'
 ]
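
With this export in place, the new loss can be imported from the losses subpackage and, as the updated test below assumes, from mmedit.models as well. A minimal import check, assuming an environment with mmedit at this commit:

from mmedit.models.losses import TransferalPerceptualLoss  # exported here
from mmedit.models import TransferalPerceptualLoss         # path used by tests/test_losses.py

print(TransferalPerceptualLoss)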
62 changes: 61 additions & 1 deletion mmedit/models/losses/perceptual_loss.py
@@ -2,6 +2,7 @@
 import torch.nn as nn
 import torchvision.models.vgg as vgg
 from mmcv.runner import load_checkpoint
+from torch.nn import functional as F
 
 from mmedit.utils import get_root_logger
 from ..registry import LOSSES
@@ -115,7 +116,9 @@ class PerceptualLoss(nn.Module):
             in forward function of vgg according to the statistics of dataset.
             Importantly, the input image must be in range [-1, 1].
         pretrained (str): Path for pretrained weights. Default:
-            'torchvision://vgg19'
+            'torchvision://vgg19'.
+        criterion (str): Criterion type. Options are 'l1' and 'mse'.
+            Default: 'l1'.
     """
 
     def __init__(self,
@@ -138,6 +141,7 @@ def __init__(self,
             use_input_norm=use_input_norm,
             pretrained=pretrained)
 
+        criterion = criterion.lower()
         if criterion == 'l1':
             self.criterion = torch.nn.L1Loss()
         elif criterion == 'mse':
@@ -202,3 +206,59 @@ def _gram_mat(self, x):
         features_t = features.transpose(1, 2)
         gram = features.bmm(features_t) / (c * h * w)
         return gram
+
+
+@LOSSES.register_module()
+class TransferalPerceptualLoss(nn.Module):
+    """Transferal perceptual loss.
+
+    Args:
+        loss_weight (float): Loss weight. Default: 1.0.
+        use_attention (bool): If True, use soft-attention tensor. Default: True
+        criterion (str): Criterion type. Options are 'l1' and 'mse'.
+            Default: 'mse'.
+    """
+
+    def __init__(self, loss_weight=1.0, use_attention=True, criterion='mse'):
+        super().__init__()
+        self.use_attention = use_attention
+        self.loss_weight = loss_weight
+        criterion = criterion.lower()
+        if criterion == 'l1':
+            self.loss_function = torch.nn.L1Loss()
+        elif criterion == 'mse':
+            self.loss_function = torch.nn.MSELoss()
+        else:
+            raise ValueError(
+                f"criterion should be 'l1' or 'mse', but got {criterion}")
+
+    def forward(self, maps, soft_attention, textures):
+        """Forward function.
+
+        Args:
+            maps (Tuple[Tensor]): Input tensors.
+            soft_attention (Tensor): Soft-attention tensor.
+            textures (Tuple[Tensor]): Ground-truth tensors.
+
+        Returns:
+            Tensor: Forward results.
+        """
+
+        if self.use_attention:
+            h, w = soft_attention.shape[-2:]
+            softs = [torch.sigmoid(soft_attention)]
+            for i in range(1, len(maps)):
+                softs.append(
+                    F.interpolate(
+                        soft_attention,
+                        size=(h * pow(2, i), w * pow(2, i)),
+                        mode='bicubic',
+                        align_corners=False))
+        else:
+            softs = [1., 1., 1.]
+
+        loss_texture = 0
+        for map, soft, texture in zip(maps, softs, textures):
+            loss_texture += self.loss_function(map * soft, texture * soft)
+
+        return loss_texture * self.loss_weight
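
For reference, a minimal usage sketch of the new loss (not part of the commit). It mirrors the shapes exercised by the test below: two feature maps at 8x8 and 16x16 paired with one 8x8 soft-attention map, which forward() passes through a sigmoid at the first scale and bicubically upsamples by 2**i for the finer scales. The config-dict form at the end is an assumption based on mmedit's usual LOSSES registry flow.

import torch

from mmedit.models.losses import TransferalPerceptualLoss

# Predicted feature maps at two scales and their ground-truth textures.
maps = (torch.rand(2, 8, 8, 8), torch.rand(2, 4, 16, 16))
textures = (torch.rand(2, 8, 8, 8), torch.rand(2, 4, 16, 16))
# A single soft-attention map at the coarsest (8x8) scale.
soft_attention = torch.rand(2, 1, 8, 8)

loss_fn = TransferalPerceptualLoss(
    loss_weight=1.0, use_attention=True, criterion='mse')
loss = loss_fn(maps, soft_attention, textures)  # scalar: sum of masked MSE terms
print(loss.item())

# Assumed config-dict construction via the LOSSES registry:
# loss_cfg = dict(type='TransferalPerceptualLoss', loss_weight=1.0, criterion='mse')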
65 changes: 59 additions & 6 deletions tests/test_losses.py
@@ -5,12 +5,12 @@
 import pytest
 import torch
 
-from mmedit.models.losses import (CharbonnierCompLoss, CharbonnierLoss,
-                                  DiscShiftLoss, GANLoss, GradientLoss,
-                                  GradientPenaltyLoss, L1CompositionLoss,
-                                  L1Loss, MaskedTVLoss, MSECompositionLoss,
-                                  MSELoss, PerceptualLoss, PerceptualVGG,
-                                  mask_reduce_loss, reduce_loss)
+from mmedit.models import (CharbonnierCompLoss, CharbonnierLoss, DiscShiftLoss,
+                           GANLoss, GradientLoss, GradientPenaltyLoss,
+                           L1CompositionLoss, L1Loss, MaskedTVLoss,
+                           MSECompositionLoss, MSELoss, PerceptualLoss,
+                           PerceptualVGG, TransferalPerceptualLoss,
+                           mask_reduce_loss, reduce_loss)
 
 
 def test_utils():
@@ -300,6 +300,59 @@ def test_perceptual_loss(init_weights):
     init_weights.reset_mock()
 
 
+def test_t_perceptual_loss():
+
+    maps = [
+        torch.rand((2, 8, 8, 8), requires_grad=True),
+        torch.rand((2, 4, 16, 16), requires_grad=True)
+    ]
+    textures = [torch.rand((2, 8, 8, 8)), torch.rand((2, 4, 16, 16))]
+    soft = torch.rand((2, 1, 8, 8))
+
+    loss_t_percep = TransferalPerceptualLoss()
+    t_percep = loss_t_percep(maps, soft, textures)
+    assert t_percep.item() > 0
+
+    loss_t_percep = TransferalPerceptualLoss(
+        use_attention=False, criterion='l1')
+    t_percep = loss_t_percep(maps, soft, textures)
+    assert t_percep.item() > 0
+
+    if torch.cuda.is_available():
+        maps = [
+            torch.rand((2, 8, 8, 8)).cuda(),
+            torch.rand((2, 4, 16, 16)).cuda()
+        ]
+        textures = [
+            torch.rand((2, 8, 8, 8)).cuda(),
+            torch.rand((2, 4, 16, 16)).cuda()
+        ]
+        soft = torch.rand((2, 1, 8, 8)).cuda()
+        loss_t_percep = TransferalPerceptualLoss().cuda()
+        maps[0].requires_grad = True
+        maps[1].requires_grad = True
+
+        t_percep = loss_t_percep(maps, soft, textures)
+        assert t_percep.item() > 0
+
+        optim = torch.optim.SGD(params=maps, lr=10)
+        optim.zero_grad()
+        t_percep.backward()
+        optim.step()
+
+        t_percep_new = loss_t_percep(maps, soft, textures)
+        assert t_percep_new < t_percep
+
+        loss_t_percep = TransferalPerceptualLoss(
+            use_attention=False, criterion='l1').cuda()
+        t_percep = loss_t_percep(maps, soft, textures)
+        assert t_percep.item() > 0
+
+    # test whether criterion is valid
+    with pytest.raises(ValueError):
+        TransferalPerceptualLoss(criterion='l2')
+
+
 def test_gan_losses():
     """Test gan losses."""
     with pytest.raises(NotImplementedError):
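
To make the attention-resizing step concrete, a small sketch (not part of the commit) that reproduces how forward() builds one mask per feature-map scale and checks that each mask broadcasts against its map; the three-scale shapes are illustrative.

import torch
import torch.nn.functional as F

maps = (torch.rand(2, 8, 8, 8), torch.rand(2, 4, 16, 16), torch.rand(2, 2, 32, 32))
soft_attention = torch.rand(2, 1, 8, 8)

h, w = soft_attention.shape[-2:]
softs = [torch.sigmoid(soft_attention)]        # scale 0: 8x8, squashed into (0, 1)
for i in range(1, len(maps)):                  # scales 1..n-1: upsample by 2**i
    softs.append(
        F.interpolate(soft_attention, size=(h * 2 ** i, w * 2 ** i),
                      mode='bicubic', align_corners=False))

for feat, mask in zip(maps, softs):
    assert mask.shape[-2:] == feat.shape[-2:]  # spatial sizes line up: 8, 16, 32
    masked = feat * mask                       # (N, C, H, W) * (N, 1, H, W) broadcasts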
