-
Notifications
You must be signed in to change notification settings - Fork 9
/
test.py
52 lines (42 loc) · 1.49 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch
import os, sys
os.environ['TORCH_MODEL_ZOO'] = '/mnt/projects/counting/pretrained/resnet'
import models, datasets, metrics
import utils as ut
import tqdm, time
def main():
    """Evaluate a pretrained MRCNN checkpoint on the PASCAL VOC 2012 val split.

    Builds ``models.mrcnn.MRCNN`` from a fixed ``exp_dict`` and moves it to CUDA.
    It then loads weights from ``checkpoints/model_state_dict.pth`` and runs
    every validation batch through ``metrics.ap.AP50_segm``. A progress bar
    shows the running score, and the final score dictionary is printed.
    """
    exp_dict = {"model": "MRCNN",
                "max_epochs": 10,
                "batch_size": 1}

    model = models.mrcnn.MRCNN(exp_dict).cuda()

    path_base = "checkpoints"
    # Load tensors onto the CPU first so the checkpoint can be read regardless
    # of which CUDA device it was saved from. This assumes load_state_dict
    # copies values into the existing (CUDA) parameters, as nn.Module's does;
    # confirm if MRCNN overrides it.
    model_state_dict = torch.load(os.path.join(path_base, "model_state_dict.pth"),
                                  map_location="cpu")
    model.load_state_dict(model_state_dict)

    val_set = datasets.pascal2012.Pascal2012(
        split="val", exp_dict=exp_dict,
        root="/mnt/datasets/public/issam/VOCdevkit/VOC2012/")
    val_loader = torch.utils.data.DataLoader(
        val_set,
        collate_fn=val_set.collate_fn,
        batch_size=1,
        shuffle=False,
        num_workers=1,
        pin_memory=True)

    model.eval()

    metric_object = metrics.ap.AP50_segm()
    metric_name = type(metric_object).__name__

    # Evaluation only: disable autograd so no computation graphs are kept
    # (assumes add_batch does not need gradients; AP scoring shouldn't).
    # Using tqdm as a context manager closes the bar even if a batch raises.
    with torch.no_grad(), tqdm.tqdm(total=len(val_loader)) as pbar:
        for batch in val_loader:
            # Accumulate predictions for this batch, then report the running score.
            metric_object.add_batch(model, batch)
            score_dict = metric_object.get_score_dict()
            pbar.set_description(" > Validation %s: %.4f" % (metric_name, score_dict["score"]))
            pbar.update(1)

    # score should be around 40.4
    print("score:", metric_object.get_score_dict())


if __name__ == "__main__":
    main()