-
Notifications
You must be signed in to change notification settings - Fork 9
/
loss.py
318 lines (260 loc) · 12.2 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
import torch
import torch.nn as nn
from utils import to_onehot_tensor
class BinaryFocalWithLogitsLoss(nn.Module):
    """Computes the focal loss with logits for binary data.

    The Focal Loss is designed to address the one-stage object detection scenario in
    which there is an extreme imbalance between foreground and background classes during
    training (e.g., 1:1000). Focal loss is defined as:

        FL = alpha(1 - p)^gamma * CE(p, y)

    where p are the probabilities, after applying the sigmoid to the logits, alpha is a
    balancing parameter, gamma is the focusing parameter, and CE(p, y) is the
    cross entropy loss. When gamma=0 and alpha=1 the focal loss equals cross entropy.

    See: https://arxiv.org/abs/1708.02002

    Arguments:
        gamma (float, optional): focusing parameter. Default: 2.
        alpha (float, optional): balancing parameter. Default: 0.25.
        reduction (string, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum'. 'none': no reduction will be applied,
            'mean': the sum of the output will be divided by the number of
            elements in the output, 'sum': the output will be summed. Default: 'mean'

    Raises:
        ValueError: if ``reduction`` is not one of 'none', 'mean' or 'sum'
            (compared case-insensitively).
    """

    def __init__(self, gamma=2, alpha=0.25, reduction="mean"):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha

        if reduction.lower() == "none":
            self.reduction_op = None
        elif reduction.lower() == "mean":
            self.reduction_op = torch.mean
        elif reduction.lower() == "sum":
            self.reduction_op = torch.sum
        else:
            raise ValueError(
                "expected one of ('none', 'mean', 'sum'), got {}".format(reduction)
            )

    def forward(self, input, target):
        """Return the focal loss of logits ``input`` against binary ``target``.

        Both tensors must have the same size; they are flattened before the loss
        is computed, so with reduction 'none' the result is one-dimensional.

        Raises:
            ValueError: if the sizes differ or ``target`` holds values other than
                0 and 1.
        """
        if input.size() != target.size():
            raise ValueError(
                "size mismatch, {} != {}".format(input.size(), target.size())
            )
        elif target.unique(sorted=True).tolist() not in [[0, 1], [0], [1]]:
            raise ValueError("target values are not binary")

        # reshape (unlike view) also accepts non-contiguous tensors, e.g. the
        # result of a transpose/permute
        input = input.reshape(-1)
        target = target.reshape(-1)
        if not target.is_floating_point():
            # Integer/bool 0-1 labels pass the check above, but the BCE call below
            # needs a floating point target
            target = target.to(input.dtype)

        # Following the paper: probabilities = probabilities if y=1; otherwise,
        # probabilities = 1-probabilities
        probabilities = torch.sigmoid(input)
        probabilities = torch.where(target == 1, probabilities, 1 - probabilities)

        # Compute the loss
        focal = self.alpha * (1 - probabilities).pow(self.gamma)
        bce = nn.functional.binary_cross_entropy_with_logits(
            input, target, reduction="none"
        )
        loss = focal * bce

        if self.reduction_op is not None:
            return self.reduction_op(loss)
        else:
            return loss
class FocalWithLogitsLoss(nn.Module):
    """Computes the focal loss with logits.

    The Focal Loss is designed to address the one-stage object detection scenario in
    which there is an extreme imbalance between foreground and background classes during
    training (e.g., 1:1000). Focal loss is defined as:

        FL = alpha(1 - p)^gamma * CE(p, y)

    where p are the probabilities, after applying the softmax layer to the logits,
    alpha is a balancing parameter, gamma is the focusing parameter, and CE(p, y) is the
    cross entropy loss. When gamma=0 and alpha=1 the focal loss equals cross entropy.

    See: https://arxiv.org/abs/1708.02002

    Arguments:
        gamma (float, optional): focusing parameter. Default: 2.
        alpha (float, optional): balancing parameter. Default: 0.25.
        reduction (string, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum'. 'none': no reduction will be applied,
            'mean': the sum of the output will be divided by the number of
            elements in the output, 'sum': the output will be summed. Default: 'mean'

    Raises:
        ValueError: if ``reduction`` is not one of 'none', 'mean' or 'sum'
            (compared case-insensitively).
    """

    def __init__(self, gamma=2, alpha=0.25, reduction="mean"):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha

        if reduction.lower() == "none":
            self.reduction_op = None
        elif reduction.lower() == "mean":
            self.reduction_op = torch.mean
        elif reduction.lower() == "sum":
            self.reduction_op = torch.sum
        else:
            raise ValueError(
                "expected one of ('none', 'mean', 'sum'), got {}".format(reduction)
            )

    def forward(self, input, target):
        """Return the focal loss of class logits ``input`` against ``target``.

        ``input`` is either 2-D with classes on dim 1, paired with a 1-D ``target``
        of class indices, or 4-D with classes on dim 1, paired with a 3-D
        ``target``. 4-D inputs are flattened to one row per position, so with
        reduction 'none' the result is always one-dimensional.

        Raises:
            ValueError: if ``input`` is not 2-D/4-D, ``target`` is not 1-D/3-D, or
                the target does not have exactly one dimension less than the input.
        """
        input_dim = input.dim()
        target_dim = target.dim()

        if input_dim not in (2, 4):
            raise ValueError(
                "expected input of size 4 or 2, got {}".format(input_dim)
            )
        if target_dim not in (1, 3):
            raise ValueError(
                "expected target of size 3 or 1, got {}".format(target_dim)
            )
        # Compare the ranks BEFORE flattening; once both tensors are flattened the
        # target is always 1-D and the input 2-D, so the check could never fail
        if target_dim != input_dim - 1:
            raise ValueError(
                "expected target dimension {} for input dimension {}, got {}".format(
                    input_dim - 1, input_dim, target_dim
                )
            )

        if input_dim == 4:
            # Move the class dimension last and collapse the rest, so each row
            # holds the class logits of a single position
            input = input.permute(0, 2, 3, 1)
            input = input.reshape(-1, input.size(-1))
            target = target.reshape(-1)

        ce = nn.functional.cross_entropy(input, target, reduction="none")

        # CE = -log(p_t), where p_t is the softmax probability (over the classes)
        # of the target class, hence p_t = exp(-CE). The softmax must be taken
        # across classes for each sample, never across the batch.
        probabilities = torch.exp(-ce)

        focal = self.alpha * (1 - probabilities).pow(self.gamma)
        loss = focal * ce

        if self.reduction_op is not None:
            return self.reduction_op(loss)
        else:
            return loss
class BinaryDiceWithLogitsLoss(nn.Module):
    """Computes the Sørensen–Dice loss with logits for binary data.

        DC = 2 * intersection(X, Y) / (|X| + |Y|)

    where, X and Y are sets of binary data, in this case, probabilities and targets.
    |X| and |Y| are the cardinalities of the corresponding sets. Probabilities are
    computed using the sigmoid.

    The optimizer minimizes the loss function therefore:

        DL = -DC (min(-x) = max(x))

    To make the loss positive (convenience) and because -DC lies within [-1, 0],
    add 1:

        DL = 1 - DC

    See: https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient

    Arguments:
        reduction (string, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum'. 'none': no reduction will be applied,
            'mean': the sum of the output will be divided by the number of
            elements in the output, 'sum': the output will be summed. Default: 'mean'
        eps (float, optional): small value to avoid division by zero. Default: 1e-6.

    Raises:
        ValueError: if ``reduction`` is not one of 'none', 'mean' or 'sum'
            (compared case-insensitively).
    """

    def __init__(self, reduction="mean", eps=1e-6):
        super().__init__()
        self.eps = eps

        if reduction.lower() == "none":
            self.reduction_op = None
        elif reduction.lower() == "mean":
            self.reduction_op = torch.mean
        elif reduction.lower() == "sum":
            self.reduction_op = torch.sum
        else:
            raise ValueError(
                "expected one of ('none', 'mean', 'sum'), got {}".format(reduction)
            )

    def forward(self, input, target):
        """Return the Dice loss of logits ``input`` against binary ``target``.

        The coefficient is computed over all elements at once, so the loss is a
        single scalar whichever reduction is selected.

        Raises:
            ValueError: if the sizes differ or ``target`` holds values other than
                0 and 1.
        """
        if input.size() != target.size():
            raise ValueError(
                "size mismatch, {} != {}".format(input.size(), target.size())
            )
        elif target.unique(sorted=True).tolist() not in [[0, 1], [0], [1]]:
            raise ValueError("target values are not binary")

        # reshape (unlike view) also accepts non-contiguous tensors, e.g. the
        # result of a transpose/permute
        input = input.reshape(-1)
        target = target.reshape(-1)

        # Dice = 2 * intersection(X, Y) / (|X| + |Y|)
        # X and Y are sets of binary data, in this case, probabilities and targets
        # |X| and |Y| are the cardinalities of the corresponding sets
        probabilities = torch.sigmoid(input)
        num = torch.sum(target * probabilities)
        den_t = torch.sum(target)
        den_p = torch.sum(probabilities)
        loss = 1 - (2 * (num / (den_t + den_p + self.eps)))

        if self.reduction_op is not None:
            return self.reduction_op(loss)
        else:
            return loss
class DiceWithLogitsLoss(nn.Module):
    """Multi-class Sørensen–Dice loss computed from logits.

    For every class the Dice coefficient is

        DC = 2 * intersection(X, Y) / (|X| + |Y|)

    with X the softmax probabilities of that class and Y the one-hot targets of
    that class. Since a larger coefficient is better, the quantity minimized is

        DL = 1 - DC

    See: https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient

    Arguments:
        reduction (string, optional): how the per-class losses are combined:
            'none' | 'mean' | 'sum'. 'none': the per-class losses are returned
            as they are, 'mean': they are averaged, 'sum': they are summed.
            Default: 'mean'
        eps (float, optional): small value added to the denominator to avoid
            division by zero. Default: 1e-6.
    """

    def __init__(self, reduction="mean", eps=1e-6):
        super().__init__()
        self.eps = eps

        # Reduction name (case-insensitive) -> callable applied to the loss
        reducers = {"none": None, "mean": torch.mean, "sum": torch.sum}
        key = reduction.lower()
        if key not in reducers:
            raise ValueError(
                "expected one of ('none', 'mean', 'sum'), got {}".format(reduction)
            )
        self.reduction_op = reducers[key]

    def forward(self, input, target):
        """Return the Dice loss of class logits ``input`` against ``target``.

        ``input`` must be 2-D or 4-D with the classes on dim 1, and ``target``
        must have exactly one dimension less (1-D or 3-D respectively).

        Raises:
            ValueError: on any other combination of dimensions.
        """
        in_rank = input.dim()
        tgt_rank = target.dim()

        if in_rank not in (2, 4):
            raise ValueError("expected input of size 4 or 2, got {}".format(in_rank))
        if tgt_rank not in (1, 3):
            raise ValueError("expected target of size 3 or 1, got {}".format(tgt_rank))

        # Dimensions summed over for each valid (input, target) rank pair; dim 1
        # (the classes) is always kept
        sum_dims_by_rank = {(4, 3): (0, 3, 2), (2, 1): 0}
        if (in_rank, tgt_rank) not in sum_dims_by_rank:
            raise ValueError(
                "expected target dimension {} for input dimension {}, got {}".format(
                    in_rank - 1, in_rank, tgt_rank
                )
            )
        sum_dims = sum_dims_by_rank[(in_rank, tgt_rank)]

        # NOTE(review): to_onehot_tensor comes from utils and is not visible here;
        # presumably it one-hot encodes target with the classes on axis 1 so that
        # the result lines up with input -- confirm against utils
        onehot = to_onehot_tensor(target, num_classes=input.size(1), axis=1)
        probs = nn.functional.softmax(input, 1)

        # Soft intersection and the two cardinalities, one value per class
        overlap = torch.sum(onehot * probs, dim=sum_dims)
        cardinality = torch.sum(onehot, dim=sum_dims) + torch.sum(probs, dim=sum_dims)
        loss = 1 - (2 * (overlap / (cardinality + self.eps)))

        if self.reduction_op is None:
            return loss
        return self.reduction_op(loss)
class BCE_BDWithLogitsLoss(nn.Module):
    """Sum of the binary cross entropy and the binary Dice loss, both with logits.

        L = BCE(input, target) + DL(input, target)

    Arguments:
        reduction (string, optional): reduction used by both terms:
            'none' | 'mean' | 'sum' (case-insensitive). Default: 'mean'.
            With 'none' the cross entropy is element-wise while the Dice term is
            a single scalar, which is broadcast over the elements when added.

    Raises:
        ValueError: if ``reduction`` is not one of the values above.
    """

    def __init__(self, reduction="mean"):
        super().__init__()
        # Built first: BinaryDiceWithLogitsLoss validates ``reduction``
        # case-insensitively and raises ValueError for unsupported values
        self.bdl_logits = BinaryDiceWithLogitsLoss(reduction=reduction)

        # Lower-case the name so both terms accept the same spellings, and pass
        # 'mean' through unchanged: the old 'elementwise_mean' alias is deprecated
        # in PyTorch and only triggers a warning
        self.bce_logits = nn.BCEWithLogitsLoss(reduction=reduction.lower())

    def forward(self, input, target):
        """Return BCE-with-logits plus binary Dice loss for ``input``/``target``."""
        return self.bce_logits(input, target) + self.bdl_logits(input, target)
class BCE_LogBDWithLogitsLoss(nn.Module):
    """Binary cross entropy minus the log of the Dice coefficient, with logits.

        L = BCE(input, target) - log(DC(input, target))

    where DC is recovered from the binary Dice loss as DC = 1 - DL.

    Arguments:
        reduction (string, optional): reduction used by both terms:
            'none' | 'mean' | 'sum' (case-insensitive). Default: 'mean'.
            With 'none' the cross entropy is element-wise while the Dice term is
            a single scalar, which is broadcast over the elements.

    Raises:
        ValueError: if ``reduction`` is not one of the values above.
    """

    def __init__(self, reduction="mean"):
        super().__init__()
        # Built first: BinaryDiceWithLogitsLoss validates ``reduction``
        # case-insensitively and raises ValueError for unsupported values
        self.bdl_logits = BinaryDiceWithLogitsLoss(reduction=reduction)

        # Lower-case the name so both terms accept the same spellings, and pass
        # 'mean' through unchanged: the old 'elementwise_mean' alias is deprecated
        # in PyTorch and only triggers a warning
        self.bce_logits = nn.BCEWithLogitsLoss(reduction=reduction.lower())

    def forward(self, input, target):
        """Return BCE-with-logits minus log(Dice coefficient)."""
        bce = self.bce_logits(input, target)

        # The Dice loss is 1 - DC, so 1 - loss gives back the coefficient
        dice_coefficient = 1 - self.bdl_logits(input, target)

        # DC is exactly 0 when target contains no positives, which would make the
        # log -inf and the whole loss infinite; floor it at the Dice epsilon
        dice_coefficient = dice_coefficient.clamp(min=self.bdl_logits.eps)

        return bce - torch.log(dice_coefficient)