-
Notifications
You must be signed in to change notification settings - Fork 11
/
validation.py
122 lines (94 loc) · 3.6 KB
/
validation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# -*- coding: utf-8 -*-
"""
@Time : 2019/1/30 23:34
@Author : Wang Xin
@Email : wangxin_buaa@163.com
"""
import numpy as np
import torch
import time
import joblib
import torch.nn.functional as F
from libs.DenseCRF import DenseCRF
from libs.metrics import AverageMeter, Result
from libs import utils
def validate(args, val_loader, model, epoch, logger):
    """Run one validation epoch over ``val_loader``.

    Computes softmax predictions, optionally refines them with DenseCRF,
    accumulates accuracy/IoU metrics, saves a comparison image grid for the
    epoch, and logs averages to the TensorBoard ``logger``.

    Args:
        args: parsed CLI namespace; reads ``args.crf`` and ``args.print_freq``
            (and whatever ``utils.get_output_directory`` needs).
        val_loader: iterable of dicts with ``'image'`` and ``'label'`` tensors.
        model: segmentation network; called as ``model(input)``.
        epoch: current epoch number, used in log tags and the image filename.
        logger: TensorBoard-style writer with ``add_scalar``.

    Returns:
        Tuple ``(avg, img_merge)`` — the averaged metrics object and the
        merged visualization image (``None`` if the loader yielded nothing).
    """
    average_meter = AverageMeter()
    model.eval()  # switch to evaluation mode

    output_directory = utils.get_output_directory(args, check=True)
    # Save one visualization row roughly every `skip` iterations (8 rows total).
    # Guard against loaders shorter than 9 batches: the original
    # `len(val_loader) // 9` would be 0 and `i % skip` below would raise
    # ZeroDivisionError.
    skip = max(len(val_loader) // 9, 1)

    # Initialized up front so the final `return` cannot hit a NameError when
    # the loader yields no batches.
    img_merge = None

    if args.crf:
        # DenseCRF hyper-parameters (pairwise potential weights/std-devs);
        # values follow the common DeepLab post-processing setup.
        ITER_MAX = 10
        POS_W = 3
        POS_XY_STD = 1
        BI_W = 4
        BI_XY_STD = 67
        BI_RGB_STD = 3
        postprocessor = DenseCRF(
            iter_max=ITER_MAX,
            pos_xy_std=POS_XY_STD,
            pos_w=POS_W,
            bi_xy_std=BI_XY_STD,
            bi_rgb_std=BI_RGB_STD,
            bi_w=BI_W,
        )

    end = time.time()
    for i, samples in enumerate(val_loader):
        input = samples['image']
        target = samples['label']
        input, target = input.cuda(), target.cuda()
        torch.cuda.synchronize()
        data_time = time.time() - end

        # Forward pass (timed separately from data loading).
        end = time.time()
        with torch.no_grad():
            pred = model(input)
        torch.cuda.synchronize()
        gpu_time = time.time() - end

        # Measure accuracy and record metrics for this batch.
        result = Result()
        pred = F.softmax(pred, 1)
        if pred.size() != target.size():
            # Upsample the probability maps to the label resolution before
            # evaluation.
            pred = F.interpolate(pred, size=(target.size()[-2], target.size()[-1]), mode='bilinear', align_corners=True)

        pred = pred.data.cpu().numpy()
        target = target.data.cpu().numpy()

        # Optional DenseCRF post-processing on (raw image, softmax map) pairs,
        # parallelized across the batch.
        if args.crf:
            images = input.data.cpu().numpy().astype(np.uint8).transpose(0, 2, 3, 1)
            pred = joblib.Parallel(n_jobs=-1)(
                [joblib.delayed(postprocessor)(*pair) for pair in zip(images, pred)]
            )

        result.evaluate(pred, target, n_class=21)
        average_meter.update(result, gpu_time, data_time, input.size(0))
        end = time.time()

        # Collect up to 8 (rgb | target | pred) rows for visualization, then
        # flush the merged grid to disk once at iteration 8 * skip.
        rgb = input.data.cpu().numpy()[0]
        target = target[0]
        pred = np.argmax(pred, axis=1)
        pred = pred[0]

        if i == 0:
            img_merge = utils.merge_into_row(rgb, target, pred)
        elif (i < 8 * skip) and (i % skip == 0):
            row = utils.merge_into_row(rgb, target, pred)
            img_merge = utils.add_row(img_merge, row)
        elif i == 8 * skip:
            filename = output_directory + '/comparison_' + str(epoch) + '.png'
            utils.save_image(img_merge, filename)

        if (i + 1) % args.print_freq == 0:
            print('Test: [{0}/{1}]\t'
                  't_GPU={gpu_time:.3f}({average.gpu_time:.3f})\n\t'
                  'mean_acc={result.mean_acc:.3f}({average.mean_acc:.3f}) '
                  'mean_iou={result.mean_iou:.3f}({average.mean_iou:.3f})'.format(
                i + 1, len(val_loader), gpu_time=gpu_time, result=result, average=average_meter.average()))

    avg = average_meter.average()

    logger.add_scalar('Test/mean_acc', avg.mean_acc, epoch)
    logger.add_scalar('Test/mean_iou', avg.mean_iou, epoch)

    print('\n*\n'
          'mean_acc={average.mean_acc:.3f}\n'
          'mean_iou={average.mean_iou:.3f}\n'
          't_GPU={time:.3f}\n'.format(
        average=avg, time=avg.gpu_time))

    return avg, img_merge