-
Notifications
You must be signed in to change notification settings - Fork 6
/
pretrain.py
82 lines (65 loc) · 3.75 KB
/
pretrain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from network import *
import argparse
import os
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import Adam
from keras.callbacks import EarlyStopping, ReduceLROnPlateau
# Command-line interface. Sizes, counts and patiences are ints; the LR
# reduction factor and the dropout rate are fractions, so they are parsed as
# floats (an int `type` would reject values such as `-dropout_rate 0.3`, and a
# str `type` would hand a string height to `target_size`).
parser = argparse.ArgumentParser(description='Pretrain the models')
parser.add_argument('-network', required=True, type=str, help='select the backbone network')
parser.add_argument('-train_dir', required=True, type=str, help='train image directory')
parser.add_argument('-val_dir', required=True, type=str, help='validation image directory')
parser.add_argument('-img_height', type=int, default=64, help='image height')
parser.add_argument('-img_width', type=int, default=64, help='image width')
parser.add_argument('-batch_size', type=int, default=128, help='batch_size')
parser.add_argument('-es_patience', type=int, default=20, help='early stopping patience')
parser.add_argument('-reduce_factor', type=float, default=0.1, help='reduce factor')
parser.add_argument('-reduce_patience', default=20, type=int, help='reduce patience')
parser.add_argument('-step', type=int, default=200, help='steps per epoch')
parser.add_argument('-epochs', type=int, default=300, help='epochs')
parser.add_argument('-dropout_rate', type=float, default=0.2, help='dropout rate')
parser.add_argument('-gpu_ids', type=str, default='0', help='select the GPU to use')
# Read the command line, then expose only the requested GPU id(s) to CUDA.
# NOTE(review): this only takes effect if the backend has not yet initialised
# its GPU devices at this point — confirm nothing imported above does so.
args = parser.parse_args()
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_ids

# The trained model is written to <current working directory>/weights.
root_dir = os.getcwd()
weight_save_dir = os.path.join(root_dir, 'weights')
# Model selection: build the requested backbone (name matched case-insensitively)
# for img_height x img_width inputs, then print its layer summary.
network_name = args.network.lower()
if network_name == 'xception':
    model = xception(args.img_height, args.img_width, args.dropout_rate)
elif network_name == 'resnetv2':
    # resNetV2 is called without a dropout-rate argument.
    model = resNetV2(args.img_height, args.img_width)
elif network_name == 'squeezenet':
    model = squeezeNet(args.img_height, args.img_width, args.dropout_rate)
else:
    # Exit with a usage message (argparse exits with status 2) instead of
    # failing later with a NameError on `model`.
    parser.error("unknown -network '{}'; expected one of: xception, resnetv2, "
                 "squeezenet".format(args.network))
model.summary()
# Compile for multi-class classification: Adam with its default settings,
# categorical cross-entropy loss, and accuracy reported as the metric.
optimizer = Adam()
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])

# Print how many trainable weight variables the model exposes.
trainable_count = len(model.trainable_weights)
print(trainable_count)
# Shared preprocessing for training and validation: pixel values are scaled by
# 1/255 and every augmentation is explicitly switched off.
datagenerator = ImageDataGenerator(
    rotation_range=0.0,
    shear_range=0,
    zoom_range=0,
    width_shift_range=0,
    height_shift_range=0,
    horizontal_flip=False,
    rescale=1. / 255,
)


def _directory_flow(directory, shuffle):
    """Return a batch iterator over `directory` (one sub-folder per class).

    Images are resized to (img_height, img_width) from the CLI arguments and
    labels are produced in 'categorical' (one-hot) mode.
    """
    return datagenerator.flow_from_directory(
        directory,
        target_size=(args.img_height, args.img_width),
        batch_size=args.batch_size,
        shuffle=shuffle,
        class_mode='categorical',
    )


train_generator = _directory_flow(args.train_dir, shuffle=True)
# Validation batches are served in a fixed (unshuffled) order.
validation_generator = _directory_flow(args.val_dir, shuffle=False)
# Stop training once val_accuracy stops improving for es_patience epochs, and
# scale the learning rate by reduce_factor once val_loss plateaus.
early_stopping = EarlyStopping(monitor='val_accuracy', patience=args.es_patience)
lr_plateau = ReduceLROnPlateau(monitor='val_loss',
                               factor=args.reduce_factor,
                               patience=args.reduce_patience)
callback_list = [early_stopping, lr_plateau]

# NOTE(review): EarlyStopping is not asked to restore the best weights, so the
# model left in memory (and saved afterwards) is the final-epoch one — confirm.
history = model.fit_generator(
    train_generator,
    steps_per_epoch=args.step,
    epochs=args.epochs,
    validation_data=validation_generator,
    # Number of batches in the validation iterator, i.e. one full pass.
    validation_steps=len(validation_generator),
    callbacks=callback_list,
)
# Save the whole trained model (not only its weights) at <cwd>/weights.
# NOTE(review): the path has no extension, so whether this becomes an HDF5 file
# or a SavedModel directory depends on the Keras version in use — confirm.
model.save(filepath=weight_save_dir)