-
Notifications
You must be signed in to change notification settings - Fork 18
/
main.py
55 lines (45 loc) · 2.68 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
'''
This is a part of the supplementary material uploaded along with
the manuscript:
"Lung Pattern Classification for Interstitial Lung Diseases Using a Deep Convolutional Neural Network"
M. Anthimopoulos, S. Christodoulidis, L. Ebner, A. Christe and S. Mougiakakou
IEEE Transactions on Medical Imaging (2016)
http://dx.doi.org/10.1109/TMI.2016.2535865
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
For more information please read the README file. The files can also
be found at: https://github.com/intact-project/ild-cnn
'''
import helpers as H
import cnn_model as CNN
# debug
from ipdb import set_trace as bp
# initialization
args = H.parse_args()  # Function for parsing command-line arguments

# Training hyper-parameters passed to CNN.train().
# Each entry takes the command-line value when one was supplied (i.e. it is
# truthy) and otherwise falls back to the hard-coded default. Numeric options
# are cast explicitly so that user-supplied and default values have the same
# type; the casts suggest parse_args() returns raw strings -- TODO confirm in
# helpers.parse_args.
# NOTE(review): because the fallback tests truthiness, an already-numeric 0
# (or an empty string) is treated as "not supplied". A string "0" is truthy and
# therefore still honoured, e.g. for alpha=0 below.
train_params = {
    'do' : float(args.do) if args.do else 0.5,            # Dropout Parameter
    'a' : float(args.a) if args.a else 0.3,               # Conv Layers LeakyReLU alpha param [if alpha set to 0 LeakyReLU is equivalent with ReLU]
    'k' : int(args.k) if args.k else 4,                   # Feature maps k multiplier
    's' : float(args.s) if args.s else 1,                 # Input Image rescale factor
    'pf' : float(args.pf) if args.pf else 1,              # Percentage of the pooling layer: [0,1]
    'pt' : args.pt if args.pt else 'Avg',                 # Pooling type: Avg, Max
    'fp' : args.fp if args.fp else 'proportional',        # Feature maps policy: proportional, static
    'cl' : int(args.cl) if args.cl else 5,                # Number of Convolutional Layers
    'opt': args.opt if args.opt else 'Adam',              # Optimizer: SGD, Adagrad, Adam
    'obj': args.obj if args.obj else 'ce',                # Minimization Objective: mse, ce
    # Cast like the other numeric options; previously a user-supplied value
    # was passed through uncast while the defaults were int/float.
    'patience' : int(args.pat) if args.pat else 200,      # Patience parameter for early stopping
    'tolerance': float(args.tol) if args.tol else 1.005,  # Tolerance parameter for early stopping [default: 1.005, checks if > 0.5%]
    'res_alias': args.csv if args.csv else 'res',         # csv results filename alias
}
# loading mnist data as example
(X_train, y_train), (X_val, y_val) = H.load_data()
# train a CNN model
model = CNN.train(X_train, y_train, X_val, y_val, train_params)