train.py
import os
import sys
import logging
import signal
import random
from argparse import ArgumentParser
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchsummary import summary
from ignite.engine.engine import Engine, State, Events
from ignite.handlers import ModelCheckpoint, EarlyStopping
from ignite._utils import convert_tensor
from utils import Experiment
from utils.factory import *
logging.basicConfig(level=logging.INFO, format='')
logger = logging.getLogger()
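# benchmark mode lets cuDNN auto-tune convolution algorithms, which is fastest when input sizes are
# fixed, but note it can reduce run-to-run reproducibility even when cudnn.deterministic is set below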
torch.backends.cudnn.benchmark = True
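# Illustrative config shape, inferred from the attribute accesses in main() below
# (these are the keys this script reads; any values are placeholders, not defaults):
#   seed, device, device_ids, epochs, save_freq, nsave, result_path, overwrite
#   datasets:  {train: {...}, validation: {...} (optional, or 'includes_validation' in train)}
#   model:     {<name>: {...}}
#   optimizer: {<name>: {models: [<model names>], <optimizer type>: {...}}}
#   loss:      {<name>: [<loss specs>]}
#   metrics:   {train: {...}, validation: {...}}
#   scheduler: {<optimizer name>: {...}} (optional)
#   monitor:   {early_stopping, score, scale, patience, save_score, save_scale} (optional)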
def main(config):
assert validate_config(config), "ERROR: Config file is invalid. Please see log for details."
logger.info("INFO: {}".format(config.toDict()))
    # Seed torch's random number generator; since we use torch DataLoaders this keeps shuffling reproducible.
    # Remember to seed custom datasets etc. with the same seed.
if config.seed > 0:
torch.backends.cudnn.deterministic = True
torch.cuda.manual_seed_all(config.seed)
torch.manual_seed(config.seed)
random.seed(config.seed)
np.random.seed(config.seed)
if config.device == "cpu" and torch.cuda.is_available():
logger.warning("WARNING: Not using the GPU")
elif config.device == "cuda":
config.device = f"cuda:{config.device_ids[0]}"
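    # Number of rolling checkpoints to keep; default to 5 if the config does not set nsave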
config.nsave = config.nsave if "nsave" in config else 5
logger.info("INFO: Creating datasets and dataloaders...")
# Create the training dataset
dset_train = create_dataset(config.datasets.train)
    # If the validation config has a parameter called split, we ask the training dataset for the validation split.
    # In that case the train dataset should not shuffle its data in __init__ (unless shuffling is deterministic),
    # since only get_validation_split knows how the data is split.
train_ids = None
if 'validation' in config.datasets:
        # Build a complete validation config from the training config, so the config file
        # only needs to specify the differences.
config_val = config.datasets.train.copy()
config_val.update(config.datasets.validation)
dset_val = create_dataset(config_val)
loader_val = get_data_loader(dset_val, config_val)
print("Using validation dataset of {} samples or {} batches".format(len(dset_val), len(loader_val)))
elif 'includes_validation' in config.datasets.train:
        train_ids, val_ids = dset_train.get_validation_split(config.datasets.train)
loader_val = get_data_loader(dset_train, config.datasets.train, val_ids)
print("Using validation dataset of {} samples or {} batches".format(len(val_ids), len(loader_val)))
else:
logger.warning("WARNING: No validation dataset was specified")
dset_val = None
loader_val = None
loader_train = get_data_loader(dset_train, config.datasets.train, train_ids)
dset_len = len(train_ids) if train_ids is not None else len(dset_train)
print("Using training dataset of {} samples or {} batches".format(dset_len, len(loader_train)))
    cp_paths = {}
last_epoch = 0
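    # Optionally resume: look up checkpoint file paths by tag and recover the epoch they were saved at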
if 'checkpoint' in config:
checkpoint_dir = config.checkpoint_dir if 'checkpoint_dir' in config else config.result_path
cp_paths, last_epoch = config.get_checkpoints(path=checkpoint_dir, tag=config.checkpoint)
print("Found checkpoint {} for Epoch {}".format(config.checkpoint, last_epoch))
last_epoch = last_epoch if config.resume_from == -1 else config.resume_from
# config.epochs = config.epochs - last_epoch if last_epoch else config.epochs
models = {}
for name, model in config.model.items():
logger.info("INFO: Building the {} model".format(name))
models[name] = build_model(model)
# Load the checkpoint
if name in cp_paths:
models[name].load_state_dict( torch.load( cp_paths[name] ) )
logger.info("INFO: Loaded model {} checkpoint {}".format(name, cp_paths[name]))
if len(config.device_ids) > 1:
models[name] = nn.DataParallel(models[name], device_ids=config.device_ids)
models[name].to(config.device)
print(models[name])
        if 'debug' in config and config.debug is True:
            print("*********** {} ************".format(name))
            for param_name, param in models[name].named_parameters():
                if param.requires_grad:
                    print(param_name, param.data)
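    # Build one optimizer per config entry; each entry names the models whose trainable parameters it updates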
optimizers = {}
for name, conf in config.optimizer.items():
optim_conf = conf.copy()
del optim_conf["models"]
model_params = []
for model_id in conf.models:
model_params.extend( list(filter(lambda p: p.requires_grad, models[model_id].parameters())) )
logger.info("INFO: Using {} Optimization for {}".format(list(optim_conf.keys())[0], name))
optimizers[name] = get_optimizer(model_params, optim_conf)
        # Restoring the optimizer breaks when not all parameters are included in the optimizer state,
        # so unless we are continuing training in place we simply build a new optimizer.
if name in cp_paths and 'checkpoint_dir' not in config:
optimizers[name].load_state_dict( torch.load( cp_paths[name] ) )
logger.info("INFO: Loaded {} optimizer checkpoint {}".format(name, cp_paths[name]))
for state in optimizers[name].state.values():
for k, v in state.items():
if isinstance(v, torch.Tensor):
state[k] = v.to(config.device)
losses = {}
for name, fcns in config.loss.items():
losses[name] = []
for l in fcns:
losses[name].append( get_loss(l) )
assert losses[name][-1], "Loss function {} for {} could not be found, please check your config".format(l, name)
    exp_logger = None
    if 'logger' in config:
        logger.info("INFO: Initialising the experiment logger")
        exp_logger = get_experiment_logger(config.result_path, config.logger)
if last_epoch > 0:
exp_logger.fast_forward(last_epoch, len(loader_train))
logger.info("INFO: Creating training manager and configuring callbacks")
trainer = get_trainer(models, optimizers, losses, exp_logger, config)
trainer_engine = Engine(trainer.train)
evaluator_engine = Engine(trainer.evaluate)
trainer.attach("train_loader", loader_train)
trainer.attach("validation_loader", loader_val)
trainer.attach("evaluation_engine", evaluator_engine)
trainer.attach("train_engine", trainer_engine)
for phase in config.metrics.keys():
if phase == "train": engine = trainer_engine
if phase == "validation": engine = evaluator_engine
for name, metric in config.metrics[phase].items():
metric = get_metric(metric)
if metric is not None:
metric.attach(engine, name)
else:
logger.warning("WARNING: Metric {} could not be created for {} phase".format(name, phase))
# Register default callbacks to run the validation stage
if loader_val is not None:
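        # With long epochs (more than 2000 training batches) validate four times per epoch rather than only at epoch end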
if len(loader_train) > 2000:
# Validate 4 times an epoch
num_batch = len(loader_train)//4
def validate_run(engine):
if engine.state.iteration % num_batch == 0:
evaluator_engine.run(loader_val)
trainer_engine.add_event_handler(Events.ITERATION_COMPLETED, validate_run)
else:
trainer_engine.add_event_handler(Events.EPOCH_COMPLETED, lambda engine: evaluator_engine.run(loader_val))
    # Initialise the epoch from the checkpoint - a workaround, as Ignite does not restore the epoch counter itself
if last_epoch > 0:
def set_epoch(engine, last_epoch):
engine.state.epoch = last_epoch
trainer_engine.add_event_handler(Events.STARTED, set_epoch, last_epoch)
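    # Learning rate schedulers are bucketed by how often they step: once per batch or once per epoch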
schedulers = {"batch": {}, "epoch": {}}
if 'scheduler' in config:
for sched_name, sched in config.scheduler.items():
if sched_name in optimizers:
logger.info("INFO: Setting up LR scheduler for {}".format(sched_name))
sched_fn, sched_scheme = get_lr_scheduler(optimizers[sched_name], sched)
assert sched_fn, "Learning Rate scheduler for {} could not be found, please check your config".format(sched_name)
assert sched_scheme in ["batch", "epoch"], "ERROR: Invalid scheduler scheme, must be either epoch or batch"
schedulers[sched_scheme][sched_name] = sched_fn
def epoch_scheduler(engine):
for name, sched in schedulers["epoch"].items():
sched.step()
def batch_scheduler(engine):
for name, sched in schedulers["batch"].items():
sched.step()
trainer_engine.add_event_handler(Events.ITERATION_COMPLETED, lambda engine: batch_scheduler(engine))
trainer_engine.add_event_handler(Events.EPOCH_COMPLETED, lambda engine: epoch_scheduler(engine))
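    # Forward per-iteration and per-epoch statistics to the experiment logger for both the train and evaluate phases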
if exp_logger is not None:
trainer_engine.add_event_handler(Events.ITERATION_COMPLETED, exp_logger.log_iteration, phase="train", models=models, optims=optimizers)
trainer_engine.add_event_handler(Events.EPOCH_COMPLETED, exp_logger.log_epoch, phase="train", models=models, optims=optimizers)
evaluator_engine.add_event_handler(Events.ITERATION_COMPLETED, exp_logger.log_iteration, phase="evaluate", models=models, optims=optimizers)
evaluator_engine.add_event_handler(Events.EPOCH_COMPLETED, exp_logger.log_epoch, phase="evaluate", models=models, optims=optimizers)
if "monitor" in config and config.monitor.early_stopping:
logger.info("INFO: Enabling early stopping, monitoring {}".format(config.monitor.score))
score_fn = lambda e: config.monitor.scale * e.state.metrics[config.monitor.score]
es_handler = EarlyStopping(patience=config.monitor.patience, score_function=score_fn, trainer=trainer_engine)
evaluator_engine.add_event_handler(Events.COMPLETED, es_handler)
if "monitor" in config and config.monitor.save_score:
logger.info("INFO: Saving best model based on {}".format(config.monitor.save_score))
score_fn = lambda e: config.monitor.save_scale * e.state.metrics[config.monitor.save_score]
ch_handler = ModelCheckpoint(config.result_path, 'best_checkpoint', score_function=score_fn, score_name=config.monitor.save_score, n_saved=1, require_empty=False, save_as_state_dict=True)
to_save = dict(models, **optimizers)
evaluator_engine.add_event_handler(Events.EPOCH_COMPLETED, ch_handler, to_save)
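    # Also save rolling checkpoints of every model and optimizer each save_freq epochs, keeping the most recent nsave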
if config.save_freq > 0:
ch_handler = ModelCheckpoint(config.result_path, 'checkpoint', save_interval=config.save_freq, n_saved=config.nsave, require_empty=False, save_as_state_dict=True)
to_save = dict(models, **optimizers)
trainer_engine.add_event_handler(Events.EPOCH_COMPLETED, ch_handler, to_save)
# Register custom callbacks with the engines
if check_if_implemented(trainer, "on_iteration_start"):
trainer_engine.add_event_handler(Events.ITERATION_STARTED, trainer.on_iteration_start, phase="train")
evaluator_engine.add_event_handler(Events.ITERATION_STARTED, trainer.on_iteration_start, phase="evaluate")
if check_if_implemented(trainer, "on_iteration_end"):
trainer_engine.add_event_handler(Events.ITERATION_COMPLETED, trainer.on_iteration_end, phase="train")
evaluator_engine.add_event_handler(Events.ITERATION_COMPLETED, trainer.on_iteration_end, phase="evaluate")
if check_if_implemented(trainer, "on_epoch_start"):
trainer_engine.add_event_handler(Events.EPOCH_STARTED, trainer.on_epoch_start, phase="train")
evaluator_engine.add_event_handler(Events.EPOCH_STARTED, trainer.on_epoch_start, phase="evaluate")
if check_if_implemented(trainer, "on_epoch_end"):
trainer_engine.add_event_handler(Events.EPOCH_COMPLETED, trainer.on_epoch_end, phase="train")
evaluator_engine.add_event_handler(Events.EPOCH_COMPLETED, trainer.on_epoch_end, phase="evaluate")
# Save the config for this experiment to the results directory, once we know the params are good
config.save()
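    # Tear the experiment logger down cleanly if training is interrupted with Ctrl+C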
def signal_handler(sig, frame):
print('You pressed Ctrl+C!')
if exp_logger is not None:
exp_logger.teardown()
sys.exit(0)
signal.signal(signal.SIGINT, signal_handler)
logger.info("INFO: Starting training...")
trainer_engine.run(loader_train, max_epochs=config.epochs)
if exp_logger is not None:
exp_logger.teardown()
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument('-c', '--config', default=None, type=str, required=True, help='config file path (default: None)')
parser.add_argument('--checkpoint', default=None, type=str, help='Checkpoint tag to reload')
parser.add_argument('--checkpoint_dir', default=None, type=str, help='Checkpoint directory to reload')
parser.add_argument('--suffix', default=None, type=str, help='Add to the name')
parser.add_argument('--epochs', default=None, type=int, help='Number of epochs')
parser.add_argument('--resume_from', default=None, type=int, help='Epoch to resume from, allows using checkpoints as initialisation')
args = parser.parse_args()
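    # Command line values listed in OVERLOADABLE override the corresponding config file entries when provided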
OVERLOADABLE = ['checkpoint', 'epochs', 'checkpoint_dir', 'resume_from']
overloaded = {}
for k, v in vars(args).items():
if (k in OVERLOADABLE) and (v is not None):
overloaded[k] = v
config = Experiment.load_from_path(args.config, overloaded, args.suffix)
assert config, "Config could not be loaded."
# Else load the saved config from the results dir or throw an error if one doesn't exist
if len(config.checkpoint) > 0:
logger.warning("WARNING: --config specifies resuming, overriding config with exising experiment config.")
# resume_config = Experiment(config.name, desc=config.desc, result_dir=config.result_dir).load()
# assert resume_config is not None, "No experiment {} exists, cannot resume training".format(config.name)
# config = resume_config
assert config, "Config could not be loaded for resume"
    # If resume_from is in the config but is < 0 (starting a fresh training run),
    # throw an error if the results directory already exists
    elif config.overwrite is False:
        assert not config.exists(), "Results directory {} already exists! Please specify a new experiment name or remove the old files.".format(config.result_path)
else:
empty_folder(config.result_path)
main(config)