-
Notifications
You must be signed in to change notification settings - Fork 13
/
train_cnn.py
188 lines (150 loc) · 8.29 KB
/
train_cnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
# -*- coding: utf-8 -*-
"""
Created on Wed Jun 26 10:42:26 2019
@author: rulix
"""
import logging
import os
import time
from parser import argparser
import numpy as np
import scipy.io as sio
import tensorflow as tf
import utils
from CorrFusionNet import model
logging.basicConfig(format='%(asctime)-15s %(levelname)s: %(message)s', datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO)
## for evaluation in each epoch
def test(model=None, session=None, file_list=None, batch_size=None, use_tfboard=True, summary=None, tb_writer=None, step=0):
    """Evaluate `model` on every .npz file in `file_list`.

    Resets the model's streaming (tf.metrics) accumulators, runs the metric
    update ops over all mini-batches of every file, and returns the final
    accuracies for the two temporal branches as a tuple (acc_t1, acc_t2).

    Parameters:
        model:       the CorrFusionNet model object (provides placeholders,
                     metric ops and `local_init`).
        session:     an open tf.Session holding the model's variables.
        file_list:   iterable of .npz file paths loadable by utils.LoadNpy.
        batch_size:  mini-batch size used to slice each file.
        use_tfboard: when truthy, write merged summaries to `tb_writer`.
        summary:     merged summary tensor (required if use_tfboard).
        tb_writer:   tf.summary.FileWriter for this split (may be None).
        step:        epoch index, used for logging and as the summary step.
    """
    session.run(model.local_init)  # reset the streaming accuracy metrics
    for pfile in file_list:
        logging.info('Epoch %2d, evaluating on file: %s......'%(step,pfile))
        xbatch1, xbatch2, ybatch1, ybatch2 = utils.LoadNpy(pfile)
        for k1 in range(0, np.shape(xbatch1)[0], batch_size):
            lb = int(k1)
            # Clamp the upper bound so the last (possibly short) batch is used.
            ub = int(np.min((lb+batch_size,np.shape(xbatch1)[0])))
            feed_dict = {model.inputs_t1: xbatch1[lb:ub, :], model.labels_t1: ybatch1[lb:ub],
                         model.inputs_t2: xbatch2[lb:ub, :], model.labels_t2: ybatch2[lb:ub]}
            session.run([model.metrics_t1_op, model.metrics_t2_op], feed_dict=feed_dict)
            # BUG FIX: the original read the global `args.use_tfboard`, which
            # ignored this function's `use_tfboard` parameter and raised
            # NameError when the module was imported rather than run as a
            # script. Use the parameter, and guard against a missing writer.
            if use_tfboard and tb_writer is not None:
                tb_writer.add_summary(session.run(summary, feed_dict=feed_dict), global_step=step)
                tb_writer.flush()
    acc_t1, acc_t2 = session.run([model.metrics_t1, model.metrics_t2])
    return acc_t1, acc_t2
def main(trn_file=None, val_file=None, tst_file=None, args=None):
    """Build the CorrFusionNet graph, train it, and evaluate every epoch.

    Parameters:
        trn_file / val_file / tst_file: lists of .npz file paths for the
            training / validation / test splits (loaded via utils.LoadNpy).
        args: parsed command-line namespace; fields used here:
            num_classes, batch_size, epoches, use_tfboard, tb_path,
            log_path, model_path, save_model.

    Side effects: writes `log.txt` under args.log_path, optional TensorBoard
    event files under args.tb_path, and checkpoints under args.model_path
    whenever the summed validation accuracy improves.

    Returns True when training completes.
    """
    # Build the two-branch model graph for 200x200 RGB image pairs.
    inputs_shape = [None, 200, 200, 3]
    base_model = model(inputs_shape=inputs_shape)
    base_model.forward(num_classes=args.num_classes)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth=True  # allocate GPU memory on demand
    sess = tf.Session(config=config)

    # Piecewise-constant learning-rate schedule, dropping at epochs 30/60/90.
    global_steps = tf.Variable(0, trainable=False)
    # NOTE(review): 11783 looks like a hard-coded training-set size — confirm.
    steps_per_epoch = int(11783 / args.batch_size)
    boundaries = [30*steps_per_epoch, 60*steps_per_epoch, 90*steps_per_epoch]
    values = [5*1e-3, 1e-3, 1e-4, 1e-5]
    learning_rate = tf.train.piecewise_constant(global_steps, boundaries, values)
    tf.summary.scalar(name='lr/learning_rate', tensor=learning_rate)
    tf.summary.scalar(name='lr/global_steps', tensor=global_steps)

    # Two grouped optimizers: momentum SGD over the full loss, plus a
    # small-rate Adam restricted to the dense layers for the DCCA loss term.
    # NOTE(review): both optimizers use fixed learning rates, so the
    # `learning_rate` schedule above is only logged to TensorBoard and never
    # actually applied — confirm whether that is intended.
    optimizer = tf.train.MomentumOptimizer(learning_rate=1e-3,momentum=0.9).minimize(base_model.losses, global_step=global_steps)
    conv_vars = tf.trainable_variables(scope='conv_layers')
    dense_vars = tf.trainable_variables()[len(conv_vars):]  # all non-conv variables
    optimizer_dcca = tf.train.AdamOptimizer(learning_rate=1e-6).minimize(base_model.dcca_loss, var_list=dense_vars, global_step=global_steps)
    optimizer = tf.group(optimizer, optimizer_dcca)

    initializer = tf.global_variables_initializer()
    sess.run(base_model.local_init)
    sess.run(initializer)

    ## tensorboard logger (one writer per split)
    summary_merge = None
    writer_trn = writer_tst = writer_val = None
    if args.use_tfboard:
        writer_trn = tf.summary.FileWriter(logdir=args.tb_path+'/trn', graph=sess.graph)
        writer_val = tf.summary.FileWriter(logdir=args.tb_path+'/val', graph=sess.graph)
        writer_tst = tf.summary.FileWriter(logdir=args.tb_path+'/tst', graph=sess.graph)
        summary_merge = tf.summary.merge_all()

    # Best validation accuracies seen so far (used for checkpoint selection).
    temp_acc_t1 = 0.
    temp_acc_t2 = 0.
    ## storing the outputs to log.txt; `with` closes it even if training raises
    with open(args.log_path+'log.txt', 'w') as f:
        for step in range(args.epoches):
            ### optimization
            logging.info('Epoch %2d, training started......'%(step))
            for trn in trn_file:
                logging.info('Epoch %2d, training on file: %s......'%(step,trn))
                xtrn1, xtrn2, ytrn1, ytrn2 = utils.LoadNpy(trn)
                for k1 in range(0, np.shape(xtrn1)[0], args.batch_size):
                    lb = int(k1)
                    # Clamp so the final short batch is still trained on.
                    ub = int(np.min((lb + args.batch_size,np.shape(xtrn1)[0])))
                    feed_dict = {base_model.inputs_t1: xtrn1[lb:ub, :], base_model.labels_t1: ytrn1[lb:ub],
                                 base_model.inputs_t2: xtrn2[lb:ub, :], base_model.labels_t2: ytrn2[lb:ub]}
                    sess.run([optimizer], feed_dict=feed_dict)
            logging.info('Epoch %2d, training finished.....'%(step))
            ### training
            logging.info('Epoch %2d, evaluating on training set......'%(step))
            f.writelines('Epoch %2d, evaluating on training set......\n'%(step))
            trn_acc_t1, trn_acc_t2 = test(model=base_model, session=sess, file_list=trn_file, batch_size=args.batch_size, use_tfboard=args.use_tfboard, summary=summary_merge, tb_writer=writer_trn, step=step)
            logging.info('Epoch %2d, evaluating on training set finished, acc_t1 is: %.4f, acc_t2 is %.4f.....'%(step, trn_acc_t1, trn_acc_t2))
            f.writelines('Epoch %2d, evaluating on training set finished, acc_t1 is: %.4f, acc_t2 is %.4f.....\n'%(step, trn_acc_t1, trn_acc_t2))
            ### validation
            logging.info('Epoch %2d, evaluating on validation set......'%(step))
            f.writelines('Epoch %2d, evaluating on validation set......\n'%(step))
            val_acc_t1, val_acc_t2 = test(model=base_model, session=sess, file_list=val_file, batch_size=args.batch_size, use_tfboard=args.use_tfboard, summary=summary_merge, tb_writer=writer_val, step=step)
            logging.info('Epoch %2d, evaluating on validation set finished, acc_t1 is: %.4f, acc_t2 is %.4f.....'%(step, val_acc_t1, val_acc_t2))
            f.writelines('Epoch %2d, evaluating on validation set finished, acc_t1 is: %.4f, acc_t2 is %.4f.....\n'%(step, val_acc_t1, val_acc_t2))
            ### testing
            logging.info('Epoch %2d, evaluating on testing set......'%(step))
            f.writelines('Epoch %2d, evaluating on testing set......\n'%(step))
            tst_acc_t1, tst_acc_t2 = test(model=base_model, session=sess, file_list=tst_file, batch_size=args.batch_size, use_tfboard=args.use_tfboard, summary=summary_merge, tb_writer=writer_tst, step=step)
            logging.info('Epoch %2d, evaluating on testing set finished, acc_t1 is: %.4f, acc_t2 is %.4f.....'%(step, tst_acc_t1, tst_acc_t2))
            f.writelines('Epoch %2d, evaluating on testing set finished, acc_t1 is: %.4f, acc_t2 is %.4f.....\n\n'%(step, tst_acc_t1, tst_acc_t2))
            #### save model when summed validation accuracy improves
            # BUG FIX: the original used bitwise `&`, which does not
            # short-circuit and fails for non-integer `save_model` values;
            # logical `and` is the intended operator here.
            if args.save_model and ((val_acc_t1+val_acc_t2) > (temp_acc_t1+temp_acc_t2)):
                model_name = 'model.ckpt'
                logging.info('Saving model to %s......'%(args.model_path+'/'+model_name))
                saver = tf.train.Saver(max_to_keep=3,)
                saver.save(sess=sess,save_path=args.model_path+'/'+model_name)
                logging.info('Model saved......\n')
                temp_acc_t1 = val_acc_t1
                temp_acc_t2 = val_acc_t2
            else:
                logging.info('Performance is worse, don\'t save model....\n')
    sess.close()
    return True
if __name__ == '__main__':
    # Parse CLI options and pin the visible GPU before TF is touched.
    args = argparser()
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    ## load data from npz files
    trn_file = [args.trn_dir + name for name in os.listdir(args.trn_dir)]
    print(trn_file)
    val_file = [args.val_dir + name for name in os.listdir(args.val_dir)]
    print(val_file)
    tst_file = [args.tst_dir + name for name in os.listdir(args.tst_dir)]
    print(tst_file)
    # Remember the base output directories; each run gets its own subdir.
    log_path = args.log_path
    model_path = args.model_path
    tb_path = args.tb_path
    # Three independent training runs, written to .../0/, .../1/, .../2/.
    for run in range(3):
        suffix = str(run) + '/'
        args.log_path = log_path + suffix
        args.model_path = model_path + suffix
        args.tb_path = tb_path + suffix
        if not os.path.exists(args.log_path):
            os.makedirs(args.log_path)
        if args.save_model and not os.path.exists(args.model_path):
            os.makedirs(args.model_path)
        main(trn_file, val_file, tst_file, args)