forked from kuprel/pycaffe-recurrent
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_recurrent.py
78 lines (58 loc) · 2 KB
/
train_recurrent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import caffe, numpy, string, os, shutil, cPickle as pickle
def save_pkl(f, obj):
    """Pickle `obj` into the file at path `f` with the highest protocol.

    The file is opened in binary write mode and is closed before returning,
    even if pickling raises.
    """
    with open(f, 'wb') as fh:
        pickle.dump(obj, fh, protocol=-1)


def load_pkl(f):
    """Return the object unpickled from the file at path `f`.

    The file is closed before returning. Unpickling can execute arbitrary
    code, so only load files this script (or another trusted source) wrote.
    """
    with open(f, 'rb') as fh:
        return pickle.load(fh)


def sf(*x):
    """Join the `str()` form of every argument with underscores.

    Example: sf('h', 0, 2) returns 'h_0_2'; sf() returns ''.
    """
    # str.join replaces the deprecated string.join(words, sep) helper.
    return '_'.join([str(i) for i in x])
# Build the solver described by 'solver.prototxt'.
solver = caffe.get_solver('solver.prototxt')

# Load the input and target arrays. X must be 3-dimensional; its axes are
# unpacked as (m, T, b) -- presumably (samples, time steps, batch size),
# confirm against how X.npy is produced.
X = numpy.load('X.npy')
Y = numpy.load('Y.npy')
m, T, b = X.shape

# L is one more than the largest trailing integer among the blob names that
# contain 'h_<T>_' (presumably the number of hidden layers).
final_step_tag = 'h_{}_'.format(T)
L = 1 + max([int(blob_name.split('_')[-1])
             for blob_name in solver.net.blobs.keys()
             if final_step_tag in blob_name])

# Pairs of (key used in the saved pickle, key in solver.net.params),
# for layer indices 0..L inclusive.
param_corresp = [(sf('fc', layer), sf('fc', 0, layer))
                 for layer in range(L + 1)]

# Start every run with an empty 'params' directory.
if os.path.isdir('params'):
    shutil.rmtree('params')
os.makedirs('params')

# Argument passed to solver.step() on each call.
step_num = 10
# Epoch counter, starting at 1; used in the progress printout.
epoch = 1
# Row indices for the one-hot writes below; built once instead of calling
# range(b) for every write.
batch_rows = numpy.arange(b)

# Train until the process is interrupted. Sample i (first half of X/Y) feeds
# the training net and sample i + m//2 (second half) feeds the test net.
while True:

    print('epoch {}'.format(epoch))

    for i in range(m // 2):

        # Copy previous final state to current initial state
        # NOTE(review): the state carried over is blob 'h_1_<l>' (time index
        # 1), not 'h_<T>_<l>' -- confirm this is the intended carry-over.
        for l in range(L):
            state_i = solver.net.blobs[sf('h', 0, l)].data
            state_f = solver.net.blobs[sf('h', 1, l)].data
            state_i[...] = state_f

        # Insert training data: clear the entries set for the previous
        # sample, then set the entries for the current one.
        for t in range(T):
            xt = solver.net.blobs[sf('x', t)].data
            yt = solver.net.blobs[sf('y', t)].data
            if i > 0:
                xt[batch_rows, X[i - 1, t]] = 0
            xt[batch_rows, X[i, t]] = 1
            yt[...] = Y[i, t]

        # Save params
        # NOTE(review): `% 1` is always 0 for an integer iteration counter,
        # so this guard never skips. Confirm the intended interval and which
        # statements it is meant to cover.
        if solver.iter % 1 == 0:
            params = {ki: [pr.data for pr in solver.net.params[kj]]
                      for ki, kj in param_corresp}
            save_pkl('params/iter%08d.pkl' % solver.iter, params)

        solver.step(step_num)

        # Insert test data (same clear-then-set scheme, on the test net)
        for t in range(T):
            xt = solver.test_nets[0].blobs[sf('x', t)].data
            yt = solver.test_nets[0].blobs[sf('y', t)].data
            if i > 0:
                xt[batch_rows, X[i + m // 2 - 1, t]] = 0
            xt[batch_rows, X[i + m // 2, t]] = 1
            yt[...] = Y[i + m // 2, t]
        solver.test_nets[0].forward()

        # Compute loss: mean of the per-time-step 'loss_<t>' blobs
        loss = numpy.mean([solver.test_nets[0].blobs[sf('loss', t)].data
                           for t in range(T)])
        print('test loss: {}, iter {}'.format(loss, solver.iter))

    # Reset input and initial state of both nets before the next epoch
    for net in [solver.net, solver.test_nets[0]]:
        for t in range(T):
            net.blobs[sf('x', t)].data[...] = 0
        for l in range(L):
            net.blobs[sf('h', 0, l)].data[...] = 0

    epoch += 1
    # Halve the step count each epoch, never below 1. Floor division keeps
    # it an integer regardless of the division mode in effect.
    step_num = max(1, step_num // 2)