-
Notifications
You must be signed in to change notification settings - Fork 6
/
loss_functions.py
119 lines (79 loc) · 3.42 KB
/
loss_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import torch
def soft_dice_loss(input:torch.Tensor, target:torch.Tensor):
    """Soft (differentiable) Dice loss for binary segmentation logits.

    Applies a sigmoid to `input`, flattens both tensors, and returns
    1 - Dice, where Dice = 2*<p, t> / (|p| + |t|).
    """
    smooth = 1e-6  # keeps the denominator strictly positive
    probs = torch.sigmoid(input).flatten()
    labels = target.flatten()
    overlap = (probs * labels).sum()
    dice = (2. * overlap) / (probs.sum() + labels.sum() + smooth)
    return 1 - dice
def soft_dice_loss_multi_class(input:torch.Tensor, y:torch.Tensor):
    """Mean soft Dice loss over classes for multi-class logits.

    `input` holds raw logits of shape [B, C, H, W]; `y` is a target of the
    same shape (one-hot along dim 1 — assumed, not checked).  Per-class Dice
    is computed over batch and spatial dims, then averaged across classes.
    """
    probs = torch.softmax(input, dim=1)
    reduce_over = (0, 2, 3)  # batch, height, width — everything but class
    overlap = (y * probs).sum(dim=reduce_over)
    totals = y.sum(dim=reduce_over) + probs.sum(dim=reduce_over)
    # clamp keeps the denominator strictly positive for absent classes
    dice_per_class = 2. * overlap / totals.clamp(1e-6)
    return 1 - dice_per_class.mean()
def soft_dice_loss_multi_class_debug(input:torch.Tensor, y:torch.Tensor):
    """Soft Dice loss over classes, also returning the per-class losses.

    Same computation as `soft_dice_loss_multi_class`, with the per-class
    breakdown exposed for debugging.

    Args:
        input: raw logits of shape [B, C, H, W].
        y: target of the same shape (one-hot along dim 1 — assumed).

    Returns:
        loss: scalar mean soft Dice loss over classes.
        loss_components: tensor of shape [C] with the per-class losses
            (loss == loss_components.mean()).
    """
    p = torch.softmax(input, dim=1)
    eps = 1e-6
    sum_dims = (0, 2, 3)  # batch, height, width
    intersection = (y * p).sum(dim=sum_dims)
    denom = (y.sum(dim=sum_dims) + p.sum(dim=sum_dims)).clamp(eps)
    # Compute the per-class losses once; the scalar loss is simply their
    # mean (the original duplicated this expression for both results).
    loss_components = 1 - 2. * intersection / denom
    loss = loss_components.mean()
    return loss, loss_components
def generalized_soft_dice_loss_multi_class(input:torch.Tensor, y:torch.Tensor):
    """Generalized soft Dice loss: per-class terms weighted by the inverse
    square of the class volume, so rare classes count as much as common ones
    (cf. Sudre et al. 2017 — presumed reference; formula matches)."""
    probs = torch.softmax(input, dim=1)
    stabiliser = 1e-12  # keeps weights finite for classes absent from y
    # TODO [B, C, H, W] -> [C, B, H, W] because softdice includes all pixels
    reduce_over = (0, 2, 3)  # batch, height, width
    target_volume = y.sum(dim=reduce_over)
    class_weights = 1 / (target_volume ** 2 + stabiliser)
    weighted_overlap = ((y * probs).sum(dim=reduce_over) * class_weights).sum()
    weighted_total = ((target_volume + probs.sum(dim=reduce_over)) * class_weights).sum()
    return 1 - (2. * weighted_overlap / weighted_total)
def jaccard_like_loss_multi_class(input:torch.Tensor, y:torch.Tensor):
    """Mean Jaccard-like (Tanimoto) loss over classes for multi-class logits.

    Mirrors the binary `jaccard_like_loss`: the per-class denominator is
    ||y||^2 + ||p||^2 - <y, p>, reduced over batch and spatial dims.

    Args:
        input: raw logits of shape [B, C, H, W].
        y: target of the same shape (one-hot along dim 1 — assumed).
    """
    p = torch.softmax(input, dim=1)
    eps = 1e-6
    # TODO [B, C, H, W] -> [C, B, H, W] because softdice includes all pixels
    sum_dims = (0, 2, 3)  # batch, height, width
    intersection = (y * p).sum(dim=sum_dims)
    # NOTE(fix): the intersection is SUBTRACTED (Tanimoto denominator),
    # matching the binary jaccard_like_loss; the previous `+` was a sign
    # typo under which even a perfect prediction had nonzero loss.
    denom = (y ** 2 + p ** 2).sum(dim=sum_dims) - intersection + eps
    loss = 1 - (2. * intersection / denom).mean()
    return loss
def jaccard_like_loss(input:torch.Tensor, target:torch.Tensor):
    """Jaccard-like (Tanimoto) loss for binary segmentation logits.

    Denominator is ||p||^2 + ||t||^2 - <p, t>, so overlap is scored against
    the squared magnitudes of prediction and target.
    """
    probs = torch.sigmoid(input).flatten()
    labels = target.flatten()
    overlap = (probs * labels).sum()
    smooth = 1e-6
    tanimoto_denom = (probs ** 2).sum() + (labels ** 2).sum() - overlap + smooth
    return 1 - (2. * overlap) / tanimoto_denom
def jaccard_like_balanced_loss(input:torch.Tensor, target:torch.Tensor):
    """Balanced Jaccard-like loss: combines the Tanimoto score of the
    foreground with that of the background (complemented masks), so both
    classes must be predicted well.

    Returns 1 - score(fg) - score(bg); lower is better.
    """
    input_sigmoid = torch.sigmoid(input)
    eps = 1e-6
    iflat = input_sigmoid.flatten()
    tflat = target.flatten()
    intersection = (iflat * tflat).sum()
    denom = (iflat ** 2 + tflat ** 2).sum() - intersection + eps
    jaccard = (2. * intersection) / denom
    # background (complement) term
    n_iflat = 1 - iflat
    n_tflat = 1 - tflat
    neg_intersection = (n_iflat * n_tflat).sum()
    # NOTE(fix): eps added here too — previously this denominator could be
    # exactly zero (e.g. an all-foreground target predicted perfectly),
    # producing NaN/inf.
    neg_denom = (n_iflat ** 2 + n_tflat ** 2).sum() - neg_intersection + eps
    n_jaccard = (2. * neg_intersection) / neg_denom
    return 1 - jaccard - n_jaccard
def soft_dice_loss_balanced(input:torch.Tensor, target:torch.Tensor):
    """Balanced soft Dice loss: Dice of the foreground plus Dice of the
    background (complemented masks), returned as 1 minus both scores."""
    smooth = 1e-6
    probs = torch.sigmoid(input).flatten()
    labels = target.flatten()
    # foreground Dice
    fg_overlap = (probs * labels).sum()
    fg_dice = (2. * fg_overlap) / (probs.sum() + labels.sum() + smooth)
    # background Dice on the complemented masks
    bg_probs = 1 - probs
    bg_labels = 1 - labels
    bg_overlap = (bg_probs * bg_labels).sum()
    bg_dice = (2 * bg_overlap) / (bg_probs.sum() + bg_labels.sum() + smooth)
    return 1 - fg_dice - bg_dice