# run_eval.py
import argparse
import os
import sys
import time

import torch
from omegaconf import DictConfig, OmegaConf
from prodict import Prodict
from tqdm import tqdm

from lib import config_utils
from lib.arguments import eval_parser
from lib.data_utils import get_dataset
from lib.eval_tools import Imputation
from lib.logger import AverageMeter
from lib.metrics import EvalMetrics
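
# The evaluation pipeline below proceeds in three steps: (1) restore the
# configuration used during training and merge it with the test-data settings,
# (2) run the imputation model once per test sample (batch_size=1), and
# (3) accumulate per-sample metrics in AverageMeter objects and report averages.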


def print_stats(stats, evaluator, print_only_masked=False):
    """Pretty-print the evaluation statistics returned by Evaluator.evaluate()."""
    prefix = evaluator.compute_metrics.prefix

    if not print_only_masked:
        print('Metrics computed over all pixels:')
        for k, v in stats.items():
            if 'occluded_input_pixels' in k or 'observed_input_pixels' in k:
                continue
            metric = k.replace(prefix, '')
            print(f'{metric.upper()}: {v}')

    if evaluator.compute_metrics.eval_occluded_observed:
        print('\nMetrics computed over all masked input pixels:')
        for k, v in stats.items():
            if 'occluded_input_pixels' in k:
                metric = k.replace(prefix, '').replace('_occluded_input_pixels', '').replace('_images', '')
                print(f'{metric.upper()}: {v}')

    if not print_only_masked:
        print('\nMetrics computed over all observed input pixels:')
        for k, v in stats.items():
            if 'observed_input_pixels' in k:
                metric = k.replace(prefix, '').replace('_observed_input_pixels', '').replace('_images', '')
                print(f'{metric.upper()}: {v}')


class Evaluator:
    def __init__(self, args: argparse.Namespace, args_test_data: DictConfig):
        self.args = args
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.args_metrics = {
            'masked_metrics': True,
            'sam_units': 'deg',
            'eval_occluded_observed': True,
            'mae': True, 'rmse': True, 'mse': False, 'ssim': True, 'psnr': True, 'sam': True
        }
        self.compute_metrics = EvalMetrics(self.args_metrics)
        _ = torch.set_grad_enabled(False)

        if not os.path.isfile(args.config_file):
            raise FileNotFoundError(f'Cannot find the configuration file used during training: {args.config_file}\n')

        # Read the config file used during training
        self.config = config_utils.read_config(args.config_file)

        # Merge the generic data settings (used during training) with the test-specific data settings
        self.config.data.update(args_test_data)
        self.config.data.preprocessed = True

        # Evaluate the entire image sequence
        self.config.data.max_seq_length = None

        # Get the data loader
        dset = get_dataset(self.config, phase='test')
        self.dataloader = torch.utils.data.DataLoader(
            dataset=dset, batch_size=1, shuffle=False, num_workers=self.config.misc.num_workers, drop_last=False
        )

        # Get the imputation model
        self.imputation = Imputation(
            config_file_train=self.args.config_file,
            method=self.args.method,
            mode=self.args.mode,
            checkpoint=self.args.checkpoint
        )

    def evaluate(self):
        """Run imputation over the test set and return the metrics averaged over all samples."""
        self._initialize_stats()
        for batch in tqdm(self.dataloader, leave=False):
            _, y_pred = self.imputation.impute_sample(batch)

            # Evaluation
            metrics = self.compute_metrics(batch, y_pred)
            for key, value in metrics.items():
                self.stats[key].update(value)

        # Average the metrics over all samples
        for metric in self.stats.keys():
            self.stats[metric] = self.stats[metric].avg
        return self.stats

    def _initialize_stats(self):
        stats = Prodict()
        eval_occluded_observed = self.args_metrics.get('eval_occluded_observed', True)
        for metric, val in self.args_metrics.items():
            # Skip the configuration flags; only the boolean metric switches define meters
            if metric in ['masked_metrics', 'sam_units', 'eval_occluded_observed']:
                continue
            if not val:
                continue
            metric_name = f'masked_{metric}' if (
                self.args_metrics['masked_metrics'] and 'ssim' not in metric
            ) else metric
            stats[metric_name] = AverageMeter()
            if eval_occluded_observed and 'ssim' not in metric:
                stats[f'{metric_name}_occluded_input_pixels'] = AverageMeter()
                stats[f'{metric_name}_observed_input_pixels'] = AverageMeter()
            if eval_occluded_observed and 'ssim' in metric:
                stats[f'{metric_name}_images_occluded_input_pixels'] = AverageMeter()
                stats[f'{metric_name}_images_observed_input_pixels'] = AverageMeter()
        self.stats = stats
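
# With the default self.args_metrics above, _initialize_stats() creates
# AverageMeter entries for keys such as 'masked_mae',
# 'masked_mae_occluded_input_pixels', and 'masked_mae_observed_input_pixels'
# (likewise for rmse, psnr, and sam), plus 'ssim',
# 'ssim_images_occluded_input_pixels', and 'ssim_images_observed_input_pixels'.
# AverageMeter (from lib.logger) is assumed to expose .update(value) and .avg,
# consistent with how it is used in evaluate().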


if __name__ == '__main__':
    if len(sys.argv) < 2:
        eval_parser.print_help()
        sys.exit(1)
    args = eval_parser.parse_args()

    # Extract the settings w.r.t. the test data
    if args.test_data.test_config is not None:
        if not os.path.isfile(args.test_data.test_config):
            raise FileNotFoundError(f'Cannot find the test configuration file: {args.test_data.test_config}\n')
        args_test_data = config_utils.read_config(args.test_data.test_config).data
    else:
        args_test_data = OmegaConf.create()

    # Individual command-line arguments override the test configuration file
    if args.test_data.data_dir is not None:
        if not os.path.exists(args.test_data.data_dir):
            raise FileNotFoundError(f'Cannot find the data directory: {args.test_data.data_dir}\n')
        args_test_data.root = args.test_data.data_dir
    if args.test_data.hdf5_file is not None:
        if not os.path.isfile(os.path.join(args_test_data.root, args.test_data.hdf5_file)):
            raise FileNotFoundError(
                f'Cannot find the data file: {os.path.join(args_test_data.root, args.test_data.hdf5_file)}\n'
            )
        args_test_data.hdf5_file = args.test_data.hdf5_file
    if args.test_data.split is not None:
        args_test_data.split = args.test_data.split
    if args.test_data.mode is not None:
        args_test_data.mode = args.test_data.mode

    evaluator = Evaluator(args, args_test_data)

    since = time.time()
    stats = evaluator.evaluate()
    time_elapsed = time.time() - since
    print(f'Evaluation completed in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s\n')

    print('Statistics:\n===========')
    print_stats(stats, evaluator)
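
# Example invocation (a sketch; the exact flag names are defined by
# lib.arguments.eval_parser and are assumed here, not verified against the
# parser definition):
#
#   python run_eval.py \
#       --config_file runs/experiment/config.yaml \
#       --method <method> --mode <mode> \
#       --checkpoint runs/experiment/checkpoint.pth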