-
Notifications
You must be signed in to change notification settings - Fork 0
/
dice_loss.py
156 lines (124 loc) · 4.96 KB
/
dice_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# Code taken from the following repository on GitHub:
# https://github.com/bonlime/pytorch-tools
from enum import Enum
import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss
def soft_dice_score(y_pred, y_true, dims=None, eps=1e-4):
    """Compute the (Laplace-smoothed) soft Dice score between two tensors.

    `Soft` means that when both `y_pred` and `y_true` are all-zero this
    function returns 1 (thanks to `eps` in numerator and denominator),
    whereas many other implementations return 0.

    Args:
        y_pred (torch.Tensor): Of shape `NxCx*` where * means any
            number of additional dimensions
        y_true (torch.Tensor): `NxCx*`, same shape as `y_pred`
        dims (Tuple[int], optional): Dims to reduce over; reduces over
            all elements when omitted
        eps (float): Laplace smoothing term

    Raises:
        ValueError: if the two tensors differ in shape.
    """
    if y_pred.size() != y_true.size():
        raise ValueError("Input and target shapes should match")
    # Reduce over the requested dims, or over everything when dims is None.
    reduce_kwargs = {} if dims is None else {"dim": dims}
    overlap = (y_pred * y_true).sum(**reduce_kwargs)
    total = (y_pred + y_true).sum(**reduce_kwargs)
    return (2.0 * overlap + eps) / (total + eps)
class Mode(Enum):
    """Target layout accepted by the segmentation losses in this module.

    BINARY / MULTICLASS expect targets of shape [N, H, W];
    MULTILABEL expects targets of shape [N, C, H, W]
    (see `DiceLoss` for how each mode is handled).
    """

    BINARY = "binary"
    MULTICLASS = "multiclass"
    MULTILABEL = "multilabel"
class Loss(_Loss):
    """Base loss that supports composition via `+` and `*`.

    `loss_a + loss_b` yields a `SumOfLosses`; `weight * loss` yields a
    `WeightedLoss`. `0 + loss` returns the loss unchanged, so the
    built-in `sum([...])` (which starts from 0) also works.

    Raises:
        ValueError: when the other operand is not a `Loss` (for `+`)
            or not an int/float (for `*`).
    """

    def __add__(self, other):
        if isinstance(other, Loss):
            return SumOfLosses(self, other)
        else:
            raise ValueError("Loss should be inherited from `Loss` class")

    def __radd__(self, other):
        # `sum()` starts from the integer 0; treat it as the additive
        # identity so `sum([l1, l2, l3])` composes losses instead of raising.
        if other == 0:
            return self
        return self.__add__(other)

    def __mul__(self, value):
        if isinstance(value, (int, float)):
            return WeightedLoss(self, value)
        else:
            raise ValueError("Loss should be multiplied by int or float")

    def __rmul__(self, other):
        return self.__mul__(other)
class WeightedLoss(Loss):
    """Scale another loss by a fixed factor.

    Useful for balancing several losses that live on different scales.
    The factor is kept as a 1-element tensor in `self.weight` and is
    lazily moved onto the device of the wrapped loss's output.
    """

    def __init__(self, loss, weight=1.0):
        super().__init__()
        self.loss = loss
        self.weight = torch.Tensor([weight])

    def forward(self, *inputs):
        base = self.loss(*inputs)
        # Migrate the stored factor to the result's device on first use
        # (a no-op check afterwards).
        if self.weight.device != base.device:
            self.weight = self.weight.to(base.device)
        return base * self.weight[0]
class SumOfLosses(Loss):
    """Sum of two losses evaluated on the same inputs.

    Produced by `Loss.__add__`. Both operands are stored as attributes,
    so (being modules themselves) they move with `.to()` / `.cuda()`.
    """

    def __init__(self, l1, l2):
        super().__init__()
        self.l1 = l1
        self.l2 = l2

    def forward(self, *inputs):
        # Implement `forward` rather than overriding `__call__` directly:
        # nn.Module.__call__ dispatches here, keeping hook machinery intact,
        # and instances remain callable exactly as before.
        return self.l1(*inputs) + self.l2(*inputs)
class DiceLoss(Loss):
    """
    Implementation of Dice loss for image segmentation task.
    It supports binary, multiclass and multilabel cases
    Args:
        mode (str): Target mode {'binary', 'multiclass', 'multilabel'}
            'multilabel' - expects y_true of shape [N, C, H, W]
            'multiclass', 'binary' - expects y_true of shape [N, H, W]
        log_loss (bool): If True, loss computed as `-log(dice)`; otherwise `1 - dice`
        from_logits (bool): If True assumes input is raw logits
        eps (float): small epsilon for numerical stability
    Shape:
        y_pred: [N, C, H, W]
        y_true: [N, C, H, W] or [N, H, W] depending on mode
    """

    # Class attribute so subclasses can swap in another score function
    # (e.g. soft Jaccard) without rewriting `forward`.
    IOU_FUNCTION = soft_dice_score

    def __init__(self, mode="binary", log_loss=False, from_logits=False, eps=1.0):
        super().__init__()
        self.mode = Mode(mode)  # raises ValueError if `mode` is not valid
        self.log_loss = log_loss
        self.from_logits = from_logits
        self.eps = eps

    def forward(self, y_pred, y_true):
        """Return the mean Dice loss over all channels with true pixels."""
        if self.from_logits:
            # Apply activations to get [0..1] class probabilities
            if self.mode == Mode.BINARY or self.mode == Mode.MULTILABEL:
                y_pred = y_pred.sigmoid()
            elif self.mode == Mode.MULTICLASS:
                y_pred = y_pred.softmax(dim=1)
        bs = y_true.size(0)
        num_classes = y_pred.size(1)
        dims = (0, 2)  # reduce over batch and flattened spatial dims, keep channels
        if self.mode == Mode.BINARY:
            y_true = y_true.view(bs, 1, -1)
            y_pred = y_pred.view(bs, 1, -1)
        elif self.mode == Mode.MULTICLASS:
            y_true = y_true.view(bs, -1)
            y_pred = y_pred.view(bs, num_classes, -1)
            # one_hot requires an integer tensor; cast defensively in case
            # class-index targets arrive as floats.
            y_true = F.one_hot(y_true.long(), num_classes)  # N, H*W -> N, H*W, C
            y_true = y_true.permute(0, 2, 1)  # N, C, H*W
        elif self.mode == Mode.MULTILABEL:
            y_true = y_true.view(bs, num_classes, -1)
            y_pred = y_pred.view(bs, num_classes, -1)
        scores = self.__class__.IOU_FUNCTION(
            y_pred, y_true.type(y_pred.dtype), dims=dims, eps=self.eps
        )
        if self.log_loss:
            loss = -torch.log(scores)
        else:
            loss = 1 - scores
        # Dice loss is defined for non-empty classes only,
        # so we zero the contribution of channels that have no true pixels
        mask = y_true.sum(dims) > 0
        loss *= mask.float()
        return loss.mean()