-
Notifications
You must be signed in to change notification settings - Fork 0
/
metrics.py
55 lines (42 loc) · 2.02 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch
class Metrics:
    """Per-class F1 / IoU and overall accuracy computed from class logits.

    After construction the instance exposes:

    * ``acc`` -- 0-dim tensor: the fraction of elements whose target class
      is in ``0 .. nb_classes - 2`` that were predicted correctly. Elements
      whose target is the last class index (``nb_classes - 1``), or any value
      outside ``0 .. nb_classes - 1``, do not contribute.
    * ``f1_score`` -- tensor of shape ``(nb_classes,)``: per-class
      F1 / Dice score ``2*TP / (2*TP + FP + FN)``.
    * ``iou`` -- tensor of shape ``(nb_classes,)``: per-class
      intersection-over-union ``TP / (TP + FP + FN)``.
    * ``mean_f1_score`` / ``mean_iou`` -- unweighted means of the first
      ``nb_classes - 1`` entries of ``f1_score`` / ``iou``.

    The last class index is excluded from ``acc`` and from both means;
    presumably it is a background/"ignore" class -- NOTE(review): confirm
    with callers. A class absent from both prediction and target scores 0
    (not NaN) and still counts toward the means. With ``nb_classes == 1``
    the means evaluate ``0 / 0`` and are therefore NaN.

    Args:
        logits: score tensor whose class axis is dimension 1; predictions
            are the ``argmax`` over that axis.
        target: integer class labels with one entry per prediction (same
            number of elements as ``logits.argmax(1)``). It does not need
            to be contiguous.
        nb_classes: number of class indices to evaluate (``0 .. nb_classes-1``).
    """

    def __init__(self, logits: torch.Tensor, target: torch.Tensor, nb_classes: int) -> None:
        super().__init__()
        self.acc, self.f1_score, self.iou = self._get_stat(logits, target, nb_classes)
        # The means skip the last class index, consistent with `acc`.
        self.mean_f1_score = self.f1_score[:nb_classes - 1].sum() / (nb_classes - 1)
        self.mean_iou = self.iou[:nb_classes - 1].sum() / (nb_classes - 1)

    def _get_stat(self, logits: torch.Tensor, target: torch.Tensor, nb_classes: int):
        """Return ``(accuracy, per_class_f1, per_class_iou)``.

        See the class docstring for the exact definitions. Runs under
        ``torch.no_grad()`` so no autograd graph is recorded.
        """
        with torch.no_grad():
            # reshape (not view): view() raises on non-contiguous tensors,
            # e.g. a transposed or sliced label map.
            pred = torch.argmax(logits, dim=1).reshape(-1)
            target = target.reshape(-1)

            # Buffers are created without a device argument, i.e. on torch's
            # default device.
            pixel_counter = torch.zeros(nb_classes)  # TP + FN per class
            true_pos = torch.zeros(nb_classes)       # TP per class
            f1 = torch.zeros(nb_classes)
            iou = torch.zeros(nb_classes)

            for k in range(nb_classes):
                pred_inds = pred == k      # TP + FP mask
                target_inds = target == k  # TP + FN mask

                intersection = pred_inds[target_inds].sum().float()  # TP
                pred_count = pred_inds.sum().float()                 # TP + FP
                target_count = target_inds.sum().float()             # TP + FN
                union = pred_count + target_count - intersection     # TP + FP + FN
                # 2*TP + FP + FN: the F1 / Dice denominator.
                denominator = pred_count + target_count

                pixel_counter[k] = target_count
                true_pos[k] = intersection
                # The 1e-10 avoids 0/0 for classes absent from both tensors.
                f1[k] = (2 * intersection) / (denominator + 1e-10)
                iou[k] = intersection / (union + 1e-10)

            # Overall accuracy excludes the last class index.
            accuracy = true_pos[:nb_classes - 1].sum() / (pixel_counter[:nb_classes - 1].sum() + 1e-10)
        return accuracy, f1, iou