-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
100 lines (68 loc) · 2.5 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from utils.args import get_args
from utils.config import process_config, process_config_test
from data_loader.rgb_data_loader import RGBTrainGenerator, RGBCVGenerator
from models.resnetRGB_model import ResNetRGBModel
from trainer.resnetRGB_trainer import ResNetRGBTrainer
from data_loader.persistence_data_loader import PersistenceTrainGenerator, PersistenceCVGenerator
from models.Persistence_model import PersistenceModel
from trainer.Persistence_trainer import PersistenceTrainer
from data_loader.combined_data_loader import CombinedTrainGenerator, CombinedCVGenerator
from models.resnetCombined_model import ResNetCombinedModel
from trainer.resnetCombined_trainer import ResNetCombinedTrainer
# Data distribution:
#                 Train    Val   Test
#   Malignant     38322   2036    509
#   Benign        18635   2035    517
# (per-class counts for each split; presumably number of samples -- confirm)


def _print_banner(title):
    """Print *title* framed above and below by a rule of 60 dashes."""
    rule = '-' * 60
    print(rule)
    print(title)
    print(rule)


def main():
    """Parse CLI arguments, build the config and run the selected training.

    The config is built with ``process_config(get_args())`` and printed.
    Then exactly one training routine is dispatched:

    * ``config.rgb`` truthy -> ``train_rgb`` (checked first, so it wins
      if ``config.combined`` is also set);
    * else ``config.combined`` truthy -> ``train_combined``;
    * otherwise -> ``train_persistence``.
    """
    args = get_args()
    config = process_config(args)
    # Single-argument print(...) parses the same as the old print statement
    # on Python 2 and is also valid Python 3.
    print('\n%s\n' % (config,))
    if config.rgb:
        _print_banner('Training on RGB images')
        train_rgb(config)
    elif config.combined:
        _print_banner('Training on Combined images')
        train_combined(config)
    else:
        _print_banner('Training on Persistence images')
        train_persistence(config)
def train_combined(config):
    """Train the combined ResNet model.

    ``main`` selects this path when ``config.combined`` is set.

    Builds the combined training and cross-validation generators and a
    ``ResNetCombinedModel`` from *config*. It then creates a
    ``ResNetCombinedTrainer`` and runs ``trainer.train`` on the two
    generators.

    Args:
        config: the configuration object that ``main`` builds with
            ``process_config``.
    """
    print('Creating data generators')
    train_generator = CombinedTrainGenerator(config)
    cv_generator = CombinedCVGenerator(config)
    print('Creating combined model')
    model = ResNetCombinedModel(config)
    print('Creating trainer')
    # The trainer receives the wrapper's ``.model`` attribute, not the
    # ResNetCombinedModel wrapper itself.
    trainer = ResNetCombinedTrainer(model.model, config)
    print('Training')
    trainer.train(train_generator, cv_generator)
def train_persistence(config):
    """Train the persistence-image model.

    ``main`` uses this as the default path when neither ``config.rgb`` nor
    ``config.combined`` is set.

    Builds the persistence training and cross-validation generators and a
    ``PersistenceModel`` from *config*. It then creates a
    ``PersistenceTrainer`` and runs ``trainer.train`` on the two generators.

    Args:
        config: the configuration object that ``main`` builds with
            ``process_config``.
    """
    print('Creating data generators')
    train_generator = PersistenceTrainGenerator(config)
    cv_generator = PersistenceCVGenerator(config)
    print('creating model')
    model = PersistenceModel(config)
    print('Creating trainer')
    # The trainer receives the wrapper's ``.model`` attribute, not the
    # PersistenceModel wrapper itself.
    trainer = PersistenceTrainer(model.model, config)
    print('Training....')
    trainer.train(train_generator, cv_generator)
def train_rgb(config):
    """Train the RGB ResNet model.

    ``main`` selects this path when ``config.rgb`` is set.

    Builds the RGB training and cross-validation generators and a
    ``ResNetRGBModel`` from *config*. It then creates a ``ResNetRGBTrainer``
    and runs ``trainer.train`` on the two generators.

    Args:
        config: the configuration object that ``main`` builds with
            ``process_config``.
    """
    print('Creating data generators')
    train_generator = RGBTrainGenerator(config)
    cv_generator = RGBCVGenerator(config)
    print('creating model')
    model = ResNetRGBModel(config)
    print('Creating trainer')
    # The trainer receives the wrapper's ``.model`` attribute, not the
    # ResNetRGBModel wrapper itself.
    trainer = ResNetRGBTrainer(model.model, config)
    print('Training....')
    trainer.train(train_generator, cv_generator)
if __name__ == "__main__":
    # Start training only when this file is executed as a script, not when
    # it is imported as a module.
    main()