-
Notifications
You must be signed in to change notification settings - Fork 0
/
evaluate_multiple_checkpoints.py
114 lines (100 loc) · 4.17 KB
/
evaluate_multiple_checkpoints.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import sys, yaml, argparse
from os.path import join, basename
import os
import subprocess
import re
from utils_ import prep_email, notify_email, get_run_checkpoints
"""
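Script that evaluates the last K checkpoints of a given run: it writes one
validation config per checkpoint, runs each through run_task.py and prints
the accuracies found in the run folder.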
"""
# command-line interface
########################
parser = argparse.ArgumentParser()
parser.add_argument("configfile",
                    help="YAML configuration file; its 'run' section is used")
parser.add_argument("--onlyprint", action="store_true",
                    help="only print the selected checkpoints and config file names")
# number of checkpoints to evaluate. Note that tensorflow has a similar param, keeping at most k checkpoints.
parser.add_argument("-num_checkpoints", type=int,
                    help="evaluate at most this many of the last selected checkpoints")
parser.add_argument("-omit_epochs", nargs="*", dest="omit",
                    help="skip checkpoints whose name contains _ep_<N>_ for any given N")
parser.add_argument("-only_epochs", nargs="*", dest="only",
                    help="keep only checkpoints whose name contains _ep_<N>_ for some given N")
args = parser.parse_args()

# The two epoch filters are mutually exclusive. Reject the combination before
# touching the config file or the email credentials.
if args.omit and args.only:
    print("Cannot specify [only] and [omit] flags at the same time")
    sys.exit(1)

# Turn bare epoch numbers into the "_ep_<N>_" substring that is later matched
# against checkpoint names.
if args.omit:
    args.omit = ["_ep_{}_".format(x) for x in args.omit]
if args.only:
    args.only = ["_ep_{}_".format(x) for x in args.only]

# load the run configuration
############################
with open(args.configfile, "r") as f:
    # safe_load: a bare yaml.load(f) warns on PyYAML >= 5.1 and raises a
    # TypeError on PyYAML >= 6 because the Loader argument became mandatory.
    # NOTE(review): safe_load rejects python-specific YAML tags -- confirm the
    # run configs contain plain scalars / lists / mappings only.
    config = yaml.safe_load(f)['run']
# Not used further in this script; the lookup makes a config without a
# 'resume_file' entry fail here with a KeyError.
resume_file = config['resume_file']
run_folder = config['run_folder']
email_notify = config['logging']['email_notify']
if email_notify:
    # prep_email is a project helper; it is unpacked into the three values
    # later handed to notify_email.
    sender, password, recipient = prep_email(email_notify)
# select the checkpoints to evaluate
####################################
# NOTE(review): get_run_checkpoints is a project helper; the "last K" slice
# below presumably relies on it returning checkpoints oldest-first -- confirm.
raw_checkpoints = get_run_checkpoints(run_folder)
checkpoints = []
for chkp in raw_checkpoints:
    # Drop the first and last character when the entry starts with a quote
    # (assumes the matching closing quote is the final character).
    if chkp.startswith(('"', "'")):
        chkp = chkp[1:-1]
    if args.omit and any(x in chkp for x in args.omit):
        print("Omitting {} due to epoch omission arguments.".format(chkp))
        continue
    if args.only and not any(x in chkp for x in args.only):
        print("Omitting {} due to epoch restriction arguments.".format(chkp))
        continue
    checkpoints.append(chkp)

if args.num_checkpoints:
    if len(checkpoints) < args.num_checkpoints:
        # fewer checkpoints than requested: evaluate all of them
        print("Unable to run for {} checkpoints, as there are only {}".format(args.num_checkpoints, len(checkpoints)))
    elif len(checkpoints) > args.num_checkpoints:
        # more checkpoints than requested: the slice below really truncates
        print("Limiting evaluation from {} to the {} last checkpoints".format(len(checkpoints), args.num_checkpoints))
    num_checkpoints = min(args.num_checkpoints, len(checkpoints))
    # num_checkpoints is 0 only when the list is empty, in which case the
    # [-0:] slice yields the (empty) list unchanged.
    checkpoints = checkpoints[-num_checkpoints:]

print("Checkpoints:")
for line in checkpoints:
    print(line)
# write configuration files
###########################
# One validation config is derived per checkpoint. The shared `config` dict is
# updated in place on each iteration and dumped next to the original
# configuration file as <configfile stem>.<checkpoint basename>.yml.
config_files, run_ids = [], []
# Prefix for the generated run ids: the configured run id if there is one,
# else empty. Read with .get(): indexing a misspelled key here raised a
# KeyError whenever 'run_id' was present.
base_run_id = config.get("run_id")
base_run_id = "" if base_run_id is None else str(base_run_id)
for i, checkpoint_path in enumerate(checkpoints):
    conffile = os.path.splitext(args.configfile)[0] + "." + os.path.basename(checkpoint_path) + ".yml"
    config_files.append(conffile)
    config['resume_file'] = checkpoint_path
    config['phase'] = "defs.phase.val"
    config['run_id'] = base_run_id + "multiple_eval_%d" % (i + 1)
    # no email notification
    config['logging']['email_notify'] = ""
    run_ids.append(config['run_id'])
    curr_config = {"run": config}
    # in print-only mode the file names are collected but nothing is written
    if not args.onlyprint:
        with open(conffile, "w") as f:
            yaml.dump(curr_config, f, default_flow_style=False)
# run each validation run
#########################
if not args.onlyprint:
    for i, conf in enumerate(config_files):
        # Build the argument list directly instead of concatenating a string
        # and splitting it again, so the config path is passed through verbatim.
        cmd = ["python3", "run_task.py", conf]
        print("Running %d/%d validation, with command:" % (i + 1, len(config_files)), cmd)
        try:
            completed = subprocess.run(cmd)
        finally:
            # delete the config file locally, even if the launch itself raised
            os.remove(conf)
        if completed.returncode != 0:
            # keep going with the remaining checkpoints, but make the failure visible
            print("(!) Validation run for {} exited with status {}".format(conf, completed.returncode))
else:
    # print-only mode: list the config files that would have been used and stop
    for conf in config_files:
        print(conf)
    # NOTE(review): exit status 1 is kept as-is for compatibility, although
    # this is not an error path -- consider switching to 0.
    sys.exit(1)
# report the accuracy recorded for every evaluated checkpoint
#############################################################
print("Getting results from", run_folder)
folder_entries = list(os.listdir(run_folder))
for index, (_, run_name) in enumerate(zip(config_files, run_ids)):
    # accuracy files are recognised by "_<run id>_" plus the word "accuracy"
    # appearing in the file name
    tag = "_%s_" % run_name
    matches = []
    for entry in folder_entries:
        if "accuracy" in entry and tag in entry:
            matches.append(entry)
    if not matches:
        print("(!) No accuracy files for", run_name, ":", matches)
        continue
    if len(matches) > 1:
        # ambiguous: warn, then fall through and report the first match
        print("(!) Multiple accuracy files for", run_name, ":", matches)
    first_match = matches[0]
    with open(join(run_folder, first_match), "r") as handle:
        accuracy_text = handle.read()
    print(run_name, first_match, basename(checkpoints[index]), accuracy_text)

# email_notify holds the value read from the config when the script started
if email_notify:
    notify_email(sender, password, recipient, "Multi-chekcpoint evaluation complete.", msgtype="INFO")