-
Notifications
You must be signed in to change notification settings - Fork 8
/
test_image_patch.py
129 lines (112 loc) · 5.58 KB
/
test_image_patch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import argparse
import torch
import os
import numpy as np
import datasets.crowd as crowd
from Networks import ALTGVT
import torch.nn.functional as F
def tensor_divideByfactor(img_tensor, factor=32):
    """Resize an NCHW tensor so both spatial dims are multiples of `factor`.

    Height and width are rounded DOWN to the nearest multiple of `factor`
    and the tensor is bilinearly resampled to that size.
    """
    _, _, height, width = img_tensor.size()
    new_h = (height // factor) * factor
    new_w = (width // factor) * factor
    return F.interpolate(img_tensor, (new_h, new_w),
                         mode='bilinear', align_corners=True)
def cal_new_tensor(img_tensor, min_size=256):
    """Upscale an NCHW tensor so its shorter side is at least `min_size`.

    If both spatial dims already meet `min_size`, the input is returned
    unchanged. Otherwise the shorter side is scaled up to exactly
    `min_size` and the longer side is scaled by the same ratio
    (truncated to int), preserving aspect ratio approximately.
    """
    _, _, height, width = img_tensor.size()
    if min(height, width) >= min_size:
        return img_tensor
    # ratio_h >= ratio_w  <=>  min_size/h >= min_size/w  <=>  h <= w,
    # i.e. the height is the limiting (shorter) side.
    if height <= width:
        target = (min_size, int(min_size / height * width))
    else:
        target = (int(min_size / width * height), min_size)
    return F.interpolate(img_tensor, target, mode='bilinear', align_corners=True)
# Command-line interface for the evaluation script, built from a spec
# table: (flag, keyword-arguments) pairs fed to add_argument.
parser = argparse.ArgumentParser(description='Test ')
_CLI_SPEC = [
    ('--device', dict(default='0', help='assign device')),
    ('--batch-size', dict(type=int, default=8,
                          help='train batch size')),
    ('--crop-size', dict(type=int, default=256,
                         help='the crop size of the train image')),
    ('--model-path', dict(type=str, required=True,
                          help='saved model path')),
    ('--data-path', dict(type=str,
                         help='dataset path')),
    ('--dataset', dict(type=str, default='sha',
                       help='dataset name: qnrf, nwpu, sha, shb, custom')),
    ('--pred-density-map-path', dict(type=str, default='inference_results',
                                     help='save predicted density maps when pred-density-map-path is not empty.')),
]
for _flag, _kwargs in _CLI_SPEC:
    parser.add_argument(_flag, **_kwargs)
def test(args, isSave = True, save_path="ALGVT_sha_test.txt"):
    """Evaluate a trained ALTGVT crowd-counting model on a test set.

    The image is split into (possibly overlapping) crop_size x crop_size
    patches, each patch is run through the model in mini-batches, and the
    per-patch density predictions are spliced back to full resolution,
    averaging where patches overlap. Reports per-image error plus MAE/MSE.

    Args:
        args: parsed CLI namespace (device, batch_size, crop_size,
            model_path, data_path, dataset, ...).
        isSave: when True, write per-image results to `save_path`.
        save_path: output text file for per-image results
            (new parameter; default keeps the old hard-coded name).

    Raises:
        NotImplementedError: for an unrecognized --dataset value.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = args.device  # set vis gpu
    device = torch.device('cuda')
    model_path = args.model_path
    crop_size = args.crop_size
    data_path = args.data_path
    # Each dataset class shares the (root, crop_size, downsample_ratio) ctor.
    if args.dataset.lower() == 'qnrf':
        dataset = crowd.Crowd_qnrf(os.path.join(data_path, 'test'), crop_size, 8, method='val')
    elif args.dataset.lower() == 'nwpu':
        dataset = crowd.Crowd_nwpu(os.path.join(data_path, 'val'), crop_size, 8, method='val')
    elif args.dataset.lower() == 'sha' or args.dataset.lower() == 'shb':
        dataset = crowd.Crowd_sh(os.path.join(data_path, 'test_data'), crop_size, 8, method='val')
    elif args.dataset.lower() == 'custom':
        dataset = crowd.CustomDataset(data_path, crop_size, downsample_ratio=8, method='test')
    else:
        raise NotImplementedError
    dataloader = torch.utils.data.DataLoader(dataset, 1, shuffle=False,
                                             num_workers=1, pin_memory=True)
    model = ALTGVT.alt_gvt_large(pretrained=True)
    model.to(device)
    model.load_state_dict(torch.load(model_path, device))
    model.eval()
    image_errs = []
    result = []
    for inputs, count, name in dataloader:
        with torch.no_grad():
            # nputs = cal_new_tensor(inputs, min_size=args.crop_size)
            inputs = inputs.to(device)
            crop_imgs, crop_masks = [], []
            b, c, h, w = inputs.size()
            rh, rw = args.crop_size, args.crop_size
            # Tile the image; edge tiles are shifted inward (max(min(...)))
            # so every tile is exactly rh x rw, which creates overlap.
            for i in range(0, h, rh):
                gis, gie = max(min(h - rh, i), 0), min(h, i + rh)
                for j in range(0, w, rw):
                    gjs, gje = max(min(w - rw, j), 0), min(w, j + rw)
                    crop_imgs.append(inputs[:, :, gis:gie, gjs:gje])
                    mask = torch.zeros([b, 1, h, w]).to(device)
                    mask[:, :, gis:gie, gjs:gje].fill_(1.0)
                    crop_masks.append(mask)
            crop_imgs, crop_masks = map(lambda x: torch.cat(x, dim=0), (crop_imgs, crop_masks))
            crop_preds = []
            nz, bz = crop_imgs.size(0), args.batch_size
            for i in range(0, nz, bz):
                gs, gt = i, min(nz, i + bz)
                crop_pred, _ = model(crop_imgs[gs:gt])
                _, _, h1, w1 = crop_pred.size()
                # Upsample the 1/8-resolution density map back to patch
                # resolution; /64 preserves the total count (8*8 scale).
                crop_pred = F.interpolate(crop_pred, size=(h1 * 8, w1 * 8), mode='bilinear', align_corners=True) / 64
                crop_preds.append(crop_pred)
            crop_preds = torch.cat(crop_preds, dim=0)
            # splice them to the original size
            idx = 0
            pred_map = torch.zeros([b, 1, h, w]).to(device)
            for i in range(0, h, rh):
                gis, gie = max(min(h - rh, i), 0), min(h, i + rh)
                for j in range(0, w, rw):
                    gjs, gje = max(min(w - rw, j), 0), min(w, j + rw)
                    pred_map[:, :, gis:gie, gjs:gje] += crop_preds[idx]
                    idx += 1
            # for the overlapping area, compute average value
            mask = crop_masks.sum(dim=0).unsqueeze(0)
            outputs = pred_map / mask
        img_err = count[0].item() - torch.sum(outputs).item()
        print("Img name: ", name, "Error: ", img_err, "GT count: ", count[0].item(), "Model out: ", torch.sum(outputs).item())
        image_errs.append(img_err)
        result.append([name, count[0].item(), torch.sum(outputs).item(), img_err])
    image_errs = np.array(image_errs)
    mse = np.sqrt(np.mean(np.square(image_errs)))
    mae = np.mean(np.abs(image_errs))
    print('{}: mae {}, mse {}\n'.format(model_path, mae, mse))
    if isSave:
        # `with` closes the file on exit — the original's explicit
        # f.close() inside the block was redundant and has been removed.
        with open(save_path, "w") as f:
            for row in result:
                f.write(str(row).replace('[', '').replace(']', '').replace(',', ' ') + "\n")
# Script entry point: parse CLI options and run the evaluation,
# saving per-image results to disk.
if __name__ == '__main__':
    cli_args = parser.parse_args()
    test(cli_args, isSave=True)