-
Notifications
You must be signed in to change notification settings - Fork 0
/
run.py
71 lines (54 loc) · 2.95 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
'''
Main module for training models with a given kernel, evaluating the models, and saving
the results as MLflow runs. The MLflow experiment ID is read from `experiment_id` in
config.yaml; it should stay "0" (MLflow's Default experiment) until new experiments
are created using MLflow.
'''
from src import experiments
from src import models
from src.helpers import make_config
from sklearn.gaussian_process import kernels as kern
import argparse
import warnings
# Suppress every warning for the whole process; this would also hide any warnings
# raised by the libraries while the models are fit.
warnings.filterwarnings("ignore")

# Parsed contents of config.yaml (the runs below read params['experiment_id']).
params = make_config('config.yaml')


def _with_white_noise(base_kernel):
    """Return ``base_kernel`` plus a WhiteKernel term with noise_level=0.001."""
    return base_kernel + kern.WhiteKernel(noise_level=0.001)


# Kernel instances selectable with --kernel, keyed by the name used on the command line.
# NOTE(review): the DotProduct kernels set sigma_0=0 but keep the default
# sigma_0_bounds, so sigma_0 is presumably still optimized rather than held at 0.
# If a homogeneous polynomial kernel was intended, sigma_0_bounds='fixed' may be
# needed; confirm before changing.
KERNELS = {
    'Linear_Kernel': _with_white_noise(kern.DotProduct(sigma_0=0)),
    'Quadratic_Kernel': _with_white_noise(kern.DotProduct(sigma_0=0) ** 2),
    'Cubic_Kernel': _with_white_noise(kern.DotProduct(sigma_0=0) ** 3),
    'RBF_Kernel': kern.RBF(),
    'RQ_Kernel': kern.RationalQuadratic(),
}
def build_parser(kernel_names):
    """Return the command-line parser for this script.

    Args:
        kernel_names: Iterable of kernel names accepted by ``--kernel``
            (the keys of ``KERNELS`` when the script runs).

    Returns:
        An ``argparse.ArgumentParser`` with ``--kernel`` and ``--n_restarts``.
    """
    parser = argparse.ArgumentParser(description='Provide kernel to test.')
    # choices: an unknown name becomes a usage error instead of a KeyError later.
    parser.add_argument('-k', '--kernel', default='RBF_Kernel', required=False,
                        choices=list(kernel_names),
                        help='Kernel used for every GP model (default: %(default)s).')
    # type=int: without it, a value given on the command line arrives as a str and
    # is forwarded unchanged as n_restarts_optimizer.
    parser.add_argument('-n', '--n_restarts', default=3, required=False, type=int,
                        help='Value passed to each model as n_restarts_optimizer '
                             '(default: %(default)s).')
    return parser


def main(argv=None):
    """Run the Vanilla, ResNet-18 and SIFT GP experiments with one kernel.

    Each run is an ``experiments.AircraftExperiment`` built from config.yaml with
    ``experiment_id=params['experiment_id']`` and then executed via ``.run()``.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    args = build_parser(KERNELS).parse_args(argv)
    kernel = KERNELS[args.kernel]

    # (label used in the progress message and run name, projection title, model class)
    runs = (
        ('Vanilla GP', 'Original Images', models.VanillaGP),
        ('ResNet-18 GP', 'ResNet-18 Features', models.ResNetGP),
        ('SIFT GP', 'SIFT Features', models.SIFTGP),
    )

    print(f"\nStarting {args.kernel} experiment..\n")
    for label, projection_title, model_cls in runs:
        print(f"Running {label}..\n")
        experiments.AircraftExperiment(config_path='config.yaml',
                                       experiment_id=params['experiment_id'],
                                       run_name=f'{label} with {args.kernel}',
                                       projection_title=projection_title,
                                       visualize_multiclass=True,
                                       model=model_cls(kernel=kernel,
                                                       n_restarts_optimizer=args.n_restarts)
                                       ).run()
    print("Experiments complete.\n")


if __name__ == '__main__':
    main()