-
Notifications
You must be signed in to change notification settings - Fork 6
/
train.py
108 lines (90 loc) · 4.14 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import numpy as np
import time
import os
import shutil
from datetime import datetime
from config import *
from model import model
from utils import Helpers, DataPreprocessor, MiniBatchLoader
def _timestamp():
    """Return the current local time formatted as 'YYYY-MM-DD HH:MM:SS'."""
    return datetime.now().strftime('%Y-%m-%d %H:%M:%S')


def _announce(text):
    """Print a progress line followed by the current timestamp."""
    # Single-argument print() produces the same output on Python 2 and 3.
    print("%s %s" % (text, _timestamp()))


def _report(logger, message):
    """Print `message` and append it, newline-terminated, to the log file."""
    print(message)
    logger.write(message + '\n')
    # Flush per line so the log is written through immediately; this replaces
    # the unbuffered text handle (buffering=0), which Python 3 rejects.
    logger.flush()


def _format_acc(value):
    """Format an accuracy as '%.4f', or 'n/a' when none has been measured."""
    return 'n/a' if value is None else '%.4f' % value


def _validate(m, batch_loader_val):
    """Run one pass over the validation loader.

    Each batch's loss and accuracy are weighted by dw.shape[0] (used as the
    number of examples in the batch) before averaging.

    Returns:
        (mean_loss, mean_accuracy) over everything the loader yielded.
    """
    total_loss, total_acc, n = 0., 0., 0
    for dw, dt, qw, qt, a, m_dw, m_qw, tt, tm, c, m_c, cl, _ in batch_loader_val:
        outs = m.validate(dw, dt, qw, qt, c, a, m_dw, m_qw, tt, tm, m_c, cl)
        loss, acc, _ = outs[:3]
        bsize = dw.shape[0]
        total_loss += bsize * loss
        total_acc += bsize * acc
        n += bsize
    # NOTE(review): kept verbatim from the original (prints a tuple under
    # Python 2). Placement after the loop is assumed -- confirm.
    print ('validate on ', str(n) + 'validation data')
    return total_loss / n, total_acc / n


def main(save_path, params):
    """Train the model and write checkpoints and a log under `save_path`.

    Args:
        save_path: existing directory that receives a copy of config.py, the
            'log' file, 'model_init.p', 'best_model.p' and one
            'model_<epoch>.p' per finished epoch.
        params: mapping that must provide the keys 'nhidden', 'dropout',
            'word2vec', 'sub2vec', 'subdic', 'data', 'nlayers', 'train_emb',
            'sub_dim', 'use_feat' and 'gating_fn'.

    If '<save_path>/best_model.p' already exists it is loaded before training;
    otherwise the freshly built model is saved to 'model_init.p' and reloaded.
    Validation runs every VALIDATION_FREQ training iterations. Training stops
    after NUM_EPOCHS epochs, or after the first epoch in which no new best
    validation accuracy was reached.
    """
    nhidden = params['nhidden']
    dropout = params['dropout']
    word2vec = params['word2vec']
    sub2vec = params['sub2vec']
    subdict = params['subdic']
    dataset = params['data']
    nlayers = params['nlayers']
    train_emb = params['train_emb']
    sub_dim = params['sub_dim']
    use_feat = params['use_feat']
    gating_fn = params['gating_fn']

    # save settings
    shutil.copyfile('config.py', '%s/config.py' % save_path)

    use_subs = sub_dim > 0
    dp = DataPreprocessor.DataPreprocessor()
    data = dp.preprocess(dataset, no_training_set=False, use_subs=use_subs,
                         subdict=subdict)

    _announce("building minibatch loaders ...")
    batch_loader_train = MiniBatchLoader.MiniBatchLoader(
        data.training, BATCH_SIZE, sample=1)
    batch_loader_val = MiniBatchLoader.MiniBatchLoader(
        data.validation, BATCH_SIZE)

    _announce("building network ...")
    W_init, embed_dim = Helpers.load_word2vec_embeddings(
        data.dictionary[0], word2vec)
    # sub_dim is replaced by the dimension reported by the embedding loader.
    S_init, sub_dim = Helpers.load_sub_embeddings(data.dictionary[1], sub2vec)
    m = model.Model(nlayers, data.vocab_size, data.num_chars, W_init, S_init,
                    nhidden, embed_dim, dropout, train_emb,
                    sub_dim, use_feat, gating_fn)

    _announce("training ...")
    num_iter = 0
    max_acc = 0.
    # Both accuracies start undefined so the end-of-epoch summary cannot hit a
    # NameError when an epoch finishes before any train step / validation pass.
    tr_acc = None
    val_acc = None

    best_path = '%s/best_model.p' % save_path
    init_path = '%s/model_init.p' % save_path

    # The context manager guarantees the log is closed even if training raises.
    with open(save_path + '/log', 'a') as logger:
        if os.path.isfile(best_path):
            _announce('loading previously saved model')
            m.load_model(best_path)
            print("model loaded")
            # NOTE(review): max_acc still starts at 0, so the first validation
            # pass overwrites best_model.p even if it scores lower -- confirm
            # this is intended.
        else:
            _announce('saving init model')
            m.save_model(init_path)
            _announce('loading init model')
            m.load_model(init_path)

        for epoch in range(NUM_EPOCHS):
            _announce("epochs training ...")
            estart = time.time()
            new_max = False

            for dw, dt, qw, qt, a, m_dw, m_qw, tt, tm, c, m_c, cl, _ in batch_loader_train:
                loss, tr_acc, _ = m.train(
                    dw, dt, qw, qt, c, a, m_dw, m_qw, tt, tm, m_c, cl)
                _report(logger,
                        "Epoch %d TRAIN loss=%.4e acc=%.4f elapsed=%.1f" % (
                            epoch, loss, tr_acc, time.time() - estart))

                num_iter += 1
                if num_iter % VALIDATION_FREQ != 0:
                    continue

                val_loss, val_acc = _validate(m, batch_loader_val)
                if val_acc > max_acc:
                    max_acc = val_acc
                    m.save_model(best_path)
                    new_max = True
                _report(logger,
                        "Epoch %d VAL loss=%.4e acc=%.4f max_acc=%.4f" % (
                            epoch, val_loss, val_acc, max_acc))

            m.save_model('%s/model_%d.p' % (save_path, epoch))
            # val_acc is the most recent validation result, which may come
            # from an earlier epoch if none ran during this one.
            _report(logger, "After Epoch %d: Train acc=%s, Val acc=%s" % (
                epoch, _format_acc(tr_acc), _format_acc(val_acc)))

            # learning schedule: anneal from the third epoch onward
            if epoch >= 2:
                m.anneal()
            # stopping criterion: no new best validation accuracy this epoch
            if not new_max:
                break