forked from Rick-McCoy/Reformer-pytorch
trainer.py
"""Training entry point for the Reformer model, built on PyTorch Lightning."""
import argparse
import platform

import torch
from pytorch_lightning import Trainer
# TestTubeLogger is imported from pytorch_lightning.logging, its location in
# the older Lightning releases this script targets; newer releases moved
# loggers to pytorch_lightning.loggers.
from pytorch_lightning.logging import TestTubeLogger

from utils.hparams import HParam
from model.model import Reformer

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, required=True,
                        help="yaml file for configuration")
    parser.add_argument('-n', '--name', type=str, required=True,
                        help="name of the model for logging, saving checkpoint")
    parser.add_argument('-b', '--batch_size', type=int, required=True,
                        help="batch size to be used")
    # Boolean flags use store_true: argparse's type=bool would turn any
    # non-empty string (including "False") into True.
    parser.add_argument('-f', '--fast_dev_run', action='store_true',
                        help="enable fast dev run for debugging purposes")
    parser.add_argument('-v', '--version', type=int, required=False, default=None,
                        help="version to resume checkpoint from, default new version")
    parser.add_argument('-t', '--trace', action='store_true',
                        help="enable tracing for debugging purposes")
    parser.add_argument('-s', '--sample', action='store_true',
                        help="enable sampling, overrides -f")
    args = parser.parse_args()
    if args.sample:
        args.fast_dev_run = True
    hp = HParam(args.config)
    # Multi-process DataLoader workers are unreliable on Windows.
    if platform.system() == 'Windows':
        hp.train.num_workers = 0
    reformer = Reformer(hp, args)
    logger = TestTubeLogger(
        save_dir=hp.log.path,
        name=args.name,
        version=args.version,
    )
    trainer = Trainer(
        logger=logger,
        default_save_path=hp.log.path,
        # ddp is not supported on Windows; fall back to single-process training.
        distributed_backend=None if platform.system() == 'Windows' else 'ddp',
        fast_dev_run=args.fast_dev_run,
        gpus=1,
        accumulate_grad_batches=hp.train.accumulate,
        min_nb_epochs=hp.train.epochs,
        max_nb_epochs=hp.train.epochs,
        weights_summary='full',
        early_stop_callback=None,
        val_check_interval=0.5,
    )
    # The profiler is a no-op unless --trace is set.
    with torch.autograd.profiler.profile(enabled=args.trace, use_cuda=True) as prof:
        if not args.sample:
            trainer.fit(reformer)
        trainer.test(reformer)
    if args.trace:
        prof.export_chrome_trace('traces/trace_' + str(logger.version) + '.json')
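Usage sketch (not part of the original file; the config path and run name below are hypothetical placeholders, not files shipped with the repository):

    # train with a given config, run name, and batch size
    python trainer.py -c config/default.yaml -n my_reformer_run -b 8

    # skip training and run only the test loop, resuming logger version 0
    python trainer.py -c config/default.yaml -n my_reformer_run -b 8 -s -v 0

When -t is passed, the profiler writes a JSON trace under traces/ that can be opened in Chrome's chrome://tracing viewer; note that export_chrome_trace does not create the traces/ directory, so it must exist beforehand.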