-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
90 lines (72 loc) · 3.57 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import argparse
import os
import random as rn
import numpy as np
import tensorflow as tf
import warnings
from keras.backend import set_session
from helpers.arguments import Mode, Dataset, Method, Model
from helpers.training import train_test_dense_net_split, train_test_attention_guided_cnn_cv, train_test_dense_net_cv, \
train_test_attention_guided_cnn_split, train_test_dense_net_split_regression, \
train_test_attention_guided_cnn_split_regression
# Seed applied to the NumPy, TensorFlow and stdlib `random` generators in
# initial_configs(); also handed to the training helpers as args['random_seed'].
RANDOM_SEED = 20
# Handed to the training helpers as args['batch_size'].
BATCH_SIZE = 16
# Stored as args['nr_folds'], only when the cross-validation method is selected.
N_FOLDS = 10
# Handed to the training helpers as args['epochs'].
EPOCHS = 50
def initial_configs():
    """Apply process-wide settings before any training code runs.

    Seeds the NumPy, TensorFlow and stdlib ``random`` generators with
    ``RANDOM_SEED``, ignores ``UserWarning``s raised from the ``keras`` module,
    makes NumPy print arrays without truncation, sets ``TF_CPP_MIN_LOG_LEVEL``
    to ``'3'`` and registers a freshly created TensorFlow session as the
    session used by the Keras backend.

    Returns:
        None. The function only has side effects on global state.
    """
    # One seed for every random number generator the project touches.
    for seed_generator in (np.random.seed, tf.set_random_seed, rn.seed):
        seed_generator(RANDOM_SEED)

    warnings.filterwarnings("ignore", category=UserWarning, module='keras')
    np.set_printoptions(threshold=np.inf)
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

    # Single-threaded op execution (presumably to limit run-to-run
    # nondeterminism - confirm), with GPU memory allocated on demand.
    session_config = tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
    session_config.gpu_options.allow_growth = True
    set_session(tf.Session(config=session_config))
# noinspection PyTypeChecker
def parse_args():
    """Parse the command line into a plain dictionary.

    Four options are mandatory and are converted with the ``from_string``
    method of their enum (``--mode``, ``--dataset``, ``--method``,
    ``--model``). Four optional switches default to ``False`` and become
    ``True`` when present (``--data-augmentation``, ``--class-activation-map``,
    ``--print-classifications``, ``--mix-up``).

    Returns:
        dict: the parsed namespace as a dict with the keys ``mode``,
        ``dataset``, ``method``, ``model``, ``data_augmentation``,
        ``class_activation_map``, ``print_classifications`` and ``mix_up``.
    """
    # (option string, destination key, enum providing choices and conversion)
    enum_options = (
        ("--mode", 'mode', Mode),
        ("--dataset", 'dataset', Dataset),
        ("--method", 'method', Method),
        ("--model", 'model', Model),
    )
    # (option string, destination key) for the boolean switches
    switch_options = (
        ("--data-augmentation", 'data_augmentation'),
        ("--class-activation-map", 'class_activation_map'),
        ("--print-classifications", 'print_classifications'),
        ("--mix-up", 'mix_up'),
    )

    arg_parser = argparse.ArgumentParser()
    for option, destination, enum_cls in enum_options:
        arg_parser.add_argument(option, dest=destination, choices=list(enum_cls),
                                type=enum_cls.from_string, required=True)
    for option, destination in switch_options:
        arg_parser.add_argument(option, dest=destination, action='store_true')

    namespace = arg_parser.parse_args()
    return vars(namespace)
def train_test_model(args):
    """Complete the run configuration and start the matching training routine.

    ``args`` is modified in place: ``image_size``, ``random_seed``,
    ``batch_size``, ``is_binary`` and ``epochs`` are always added, and
    ``nr_folds`` is added when cross-validation runs on a dataset other than
    ``Dataset.flood_heights``.

    Routine selection:
        * ``Dataset.flood_heights`` uses the ``*_split_regression`` helpers and
          only with ``Method.train_test_split``.
        * Every other dataset uses the ``*_cv`` helpers for
          ``Method.cross_validation`` and the ``*_split`` helpers for
          ``Method.train_test_split``.
        * ``Model.attention_guided`` picks the attention-guided CNN helper;
          ``Model.dense_net`` and ``Model.efficient_net`` share the dense-net
          helper.

    A combination outside this list starts no training and the function
    simply returns.

    Args:
        args (dict): parsed options; must contain ``dataset``, ``method`` and
            ``model``.

    Returns:
        None.
    """
    dataset = args['dataset']
    method = args['method']
    model = args['model']

    non_binary_datasets = (Dataset.flood_severity_3_classes,
                           Dataset.flood_severity_4_classes,
                           Dataset.flood_severity_european_floods,
                           Dataset.flood_heights)

    # EfficientNet is fed 300px inputs, every other model 224px.
    args['image_size'] = 300 if model == Model.efficient_net else 224
    args['random_seed'] = RANDOM_SEED
    args['batch_size'] = BATCH_SIZE
    args['is_binary'] = dataset not in non_binary_datasets
    args['epochs'] = EPOCHS

    # Step 1: choose the (attention-guided, dense-net) pair of routines.
    if dataset == Dataset.flood_heights:
        if method != Method.train_test_split:
            return
        attention_routine = train_test_attention_guided_cnn_split_regression
        dense_routine = train_test_dense_net_split_regression
    elif method == Method.cross_validation:
        args['nr_folds'] = N_FOLDS
        attention_routine = train_test_attention_guided_cnn_cv
        dense_routine = train_test_dense_net_cv
    elif method == Method.train_test_split:
        attention_routine = train_test_attention_guided_cnn_split
        dense_routine = train_test_dense_net_split
    else:
        return

    # Step 2: run the routine that belongs to the requested model.
    if model == Model.attention_guided:
        attention_routine(args)
    elif model in (Model.dense_net, Model.efficient_net):
        dense_routine(args)
def main():
    """Script entry point: configure the process, parse the CLI, then train.

    Global configuration is applied first, so seeding and session setup are
    done before the arguments are parsed and the training helpers are called.
    """
    initial_configs()
    train_test_model(parse_args())


# Run only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()