forked from Deeachain/Segmentation-Pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict.py
129 lines (111 loc) · 6.47 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import os
import torch
import torch.backends.cudnn as cudnn
from argparse import ArgumentParser
from prettytable import PrettyTable
from builders.model_builder import build_model
from builders.dataset_builder import build_dataset_test
from builders.loss_builder import build_loss
from builders.validation_builder import predict_multiscale_sliding
def main(args):
    """
    Run inference / evaluation of a segmentation model on the test set.

    param args: parsed command-line arguments (see the ArgumentParser in __main__)
    return: None
    """
    # Pretty-print all arguments for the run log.
    table = PrettyTable(['args_name', 'args_value'])
    for arg_name, arg_value in vars(args).items():
        table.add_row([arg_name, arg_value])
    print(table.get_string(title="Predict Arguments"))

    # build the model
    model = build_model(args.model, args.classes, args.backbone, args.pretrained, args.out_stride, args.mult_grid)

    # load the test set; ground truth is only available in 'validation' mode
    with_gt = args.predict_type == 'validation'
    testLoader, class_dict_df = build_dataset_test(args.root, args.dataset, args.crop_size, args.batch_size,
                                                   args.num_workers, mode=args.predict_mode, gt=with_gt)

    if args.cuda:
        # NOTE(review): setting CUDA_VISIBLE_DEVICES this late only has effect if
        # CUDA has not been initialized earlier in the process — confirm.
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        # Check availability BEFORE calling model.cuda(), so the friendly error is reachable.
        if not torch.cuda.is_available():
            raise Exception("no GPU found or wrong gpu id, please run without --cuda")
        model = model.cuda()  # using GPU for inference
        cudnn.benchmark = True

    if not os.path.exists(args.save_seg_dir):
        os.makedirs(args.save_seg_dir)

    if args.checkpoint:
        if not os.path.isfile(args.checkpoint):
            raise FileNotFoundError("no checkpoint found at '{}'".format(args.checkpoint))
        # map_location lets GPU-trained checkpoints load on a CPU-only run.
        map_location = 'cuda' if args.cuda else 'cpu'
        state_dict = torch.load(args.checkpoint, map_location=map_location)['model']
        # Weights saved from nn.DataParallel carry a 'module.' prefix on every key;
        # strip it so they load into a single-device model.
        if any(key.startswith('module.') for key in state_dict):
            state_dict = {key[len('module.'):]: value for key, value in state_dict.items()}
        model.load_state_dict(state_dict, strict=True)

    # define loss function, respectively
    criterion = build_loss(args, None, 255)

    print(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n"
          ">>>>>>>>>>> beginning testing >>>>>>>>>>>>\n"
          ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")
    predict_multiscale_sliding(args=args, model=model, testLoader=testLoader, class_dict_df=class_dict_df,
                               scales=args.scales, overlap=args.overlap, criterion=criterion,
                               mode=args.predict_type, save_result=True)
if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--model', default="UNet", help="model name")
    parser.add_argument('--backbone', type=str, default="resnet18", help="backbone name")
    parser.add_argument('--pretrained', action='store_true',
                        help="whether choice backbone pretrained on imagenet")
    parser.add_argument('--out_stride', type=int, default=32, help="output stride of backbone")
    parser.add_argument('--mult_grid', action='store_true',
                        help="whether choice mult_grid in backbone last layer")
    parser.add_argument('--root', type=str, default="", help="path of datasets")
    parser.add_argument('--predict_mode', default="sliding", choices=["sliding", "whole"],
                        help="Default use whole predict mode")
    parser.add_argument('--predict_type', default="validation", choices=["validation", "predict"],
                        help="Default use validation type")
    parser.add_argument('--flip_merge', action='store_true', help="Default use predict without flip_merge")
    parser.add_argument('--scales', type=float, nargs='+', default=[1.0], help="predict with multi_scales")
    parser.add_argument('--overlap', type=float, default=0.0, help="sliding predict overlap rate")
    parser.add_argument('--dataset', default="paris", help="dataset: cityscapes")
    parser.add_argument('--num_workers', type=int, default=4, help="the number of parallel threads")
    parser.add_argument('--batch_size', type=int, default=1,
                        help=" the batch_size is set to 1 when evaluating or testing NOTES:image size should fixed!")
    parser.add_argument('--tile_hw_size', type=str, default='512, 512',
                        help=" the tile_size is when evaluating or testing")
    parser.add_argument('--crop_size', type=int, default=769, help="crop size of image")
    parser.add_argument('--input_size', type=str, default=(769, 769),
                        help=" the input_size is for build ProbOhemCrossEntropy2d loss")
    parser.add_argument('--checkpoint', type=str, default='',
                        help="use the file to load the checkpoint for evaluating or testing ")
    parser.add_argument('--save_seg_dir', type=str, default="./outputs/",
                        help="saving path of prediction result")
    parser.add_argument('--loss', type=str, default="CrossEntropyLoss2d",
                        choices=['CrossEntropyLoss2d', 'ProbOhemCrossEntropy2d', 'CrossEntropyLoss2dLabelSmooth',
                                 'LovaszSoftmax', 'FocalLoss2d'], help="choice loss for train or val in list")
    # A plain `default=True` argument with no type would turn `--cuda False` into the
    # truthy string 'False'; parse the string into a real bool instead.
    parser.add_argument('--cuda', type=lambda s: str(s).lower() in ('true', '1', 'yes'),
                        default=True, help="run on CPU or GPU")
    parser.add_argument("--gpus", default="0", type=str, help="gpu ids (default: 0)")
    args = parser.parse_args()

    # Derive the output subdirectory name from the checkpoint path as
    # "<parent_dir>_<file_stem>"; guard against paths with fewer than two
    # components (e.g. the default empty checkpoint), which previously raised IndexError.
    path_parts = args.checkpoint.split('/')
    file_stem = path_parts[-1].split('.')[0]
    save_dirname = path_parts[-2] + '_' + file_stem if len(path_parts) >= 2 else file_stem
    args.save_seg_dir = os.path.join(args.save_seg_dir, args.dataset, args.predict_mode, save_dirname)

    # Number of segmentation classes per supported dataset.
    dataset_classes = {
        'cityscapes': 19,
        'paris': 3,
        'austin': 2,
        'road': 2,
        'postdam': 6,
        'vaihingen': 6,
    }
    if args.dataset not in dataset_classes:
        raise NotImplementedError(
            "This repository now supports datasets %s is not included" % args.dataset)
    args.classes = dataset_classes[args.dataset]

    main(args)