forked from jcjohnson/torch-rnn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
LanguageModelSkip_dIn.lua
397 lines (333 loc) · 14.1 KB
/
LanguageModelSkip_dIn.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
require 'torch'
require 'nn'
require 'VanillaRNN'
require 'LSTM'
require 'dimPrint' -- Self-written debugging module whose purpose is to print the dimensions of tensors passing through.
local utils = require 'util.utils'
local LM, parent = torch.class('nn.LanguageModelSkip_dIn', 'nn.Module')
--- Construct a skip-connected RNN/LSTM language model whose per-timestep input
-- is a 47-wide data vector concatenated with one character index, i.e. an
-- input batch of shape (N, T, 48). The character slice is embedded through an
-- nn.LookupTable and re-joined with the data before the recurrent stack.
-- Layer outputs are accumulated via skip connections and fed to a final
-- Linear decoder of size (num_layers * H) -> V.
--
-- kwargs fields read here: idx_to_token, model_type ('rnn' or 'lstm'),
-- wordvec_size, rnn_size, num_layers, dropout, batchnorm.
function LM:__init(kwargs)
  self.idx_to_token = utils.get_kwarg(kwargs, 'idx_to_token')
  self.token_to_idx = {}
  self.vocab_size = 0
  -- Build the inverse token map and count the vocabulary as we go.
  for idx, token in pairs(self.idx_to_token) do
    self.token_to_idx[token] = idx
    self.vocab_size = self.vocab_size + 1
  end
  self.model_type = utils.get_kwarg(kwargs, 'model_type')
  self.wordvec_dim = utils.get_kwarg(kwargs, 'wordvec_size')
  self.rnn_size = utils.get_kwarg(kwargs, 'rnn_size')
  self.num_layers = utils.get_kwarg(kwargs, 'num_layers')
  self.dropout = utils.get_kwarg(kwargs, 'dropout')
  self.batchnorm = utils.get_kwarg(kwargs, 'batchnorm')
  -- This model is identical to the skip connected network (LanguageModelSkipCon.lua), except
  -- It expects to receive a 47 long vector as the input, for every corresponding output character
  -- In the normal network, input is (N, T) tensor of character indices - they get put through a LUT
  -- Here the data has already been decoded
  -- An input batch is therefore of size (N, T, 47) instead of (N, T) as it is when taking character input
  -- As a result, removing the nn.LookupTable should sort it out.
  -- Also dimensions of the hidden cells need to be changed accordingly.
  local V, H = self.vocab_size, self.rnn_size
  local D = 47 + self.wordvec_dim -- Hard setting the input vector, data is a 47 wide vector, char comes out as self.wordvec_dim.
  self.net = nn.Sequential()
  self.rnns = {}          -- All recurrent cells, kept for resetStates().
  self.bn_view_in = {}    -- View modules resized per-minibatch in updateOutput().
  self.bn_view_out = {}
  --[[
  -- Building a sub-module to handle the double inputs
  local decoderContainer = nn.Sequential()
  --decoderContainer:add(nn.ConcatTable()) -- Network now takes 2 inputs, {data, char} and outputs {nextChar}
  local decoderConcat = nn.ConcatTable()
  local dc1 = nn.Sequential()
  local dc2 = nn.Sequential()
  dc1:add(nn.SelectTable(1)) -- This gets the data vector
  -- We don't change this vector.
  dc2:add(nn.SelectTable(2)) -- this gets the char to be decoded
  dc2:add(nn.LookupTable(V, (D-47))) -- D - 47 because it should be only the size of one character.
  decoderConcat:add(dc1)
  decoderConcat:add(dc2) -- decoderConcat should output a table.
  decoderContainer:add(decoderConcat)
  decoderContainer:add(nn.JoinTable(3, 3))
  --]]
  -- Building a sub-module to handle splitting of data/char tuple, encoding of char, re-concatenation.
  -- Input is size (N, T, 47 + V)
  -- Split to (N, T, 47), (N, T, V)
  -- Decode (N, T, V) -> (N, T, D)
  -- merge along dimension 3: (N, T, 47) + (N, T, D) -> (N, T, D+47)
  -- NOTE(review): the Narrow calls below slice widths 47 and 1 (a single index
  -- column), so the actual input appears to be (N, T, 48) rather than
  -- (N, T, 47 + V) as the comment above says — confirm against the data loader.
  local decoderContainer = nn.Sequential()
  local dConcat1 = nn.ConcatTable()
  local dCSeq1 = nn.Sequential()
  local dCSeq2 = nn.Sequential()
  dCSeq1:add(nn.Narrow(3, 1, 47)) -- Slice out the data
  dCSeq2:add(nn.Narrow(3, 48, 1)) -- Slice out the characters
  dCSeq2:add(nn.Squeeze(3)) -- Get rid of the singleton dimension
  dCSeq2:add(nn.LookupTable(V, (D-47))) -- D - 47 because it should be only the size of one character.
  dConcat1:add(dCSeq1)
  dConcat1:add(dCSeq2)
  decoderContainer:add(dConcat1)
  decoderContainer:add(nn.JoinTable(3, 3))
  self.net:add(decoderContainer)
  for i = 1, self.num_layers do
    -- Selecting input dimensions for LSTM cells
    local prev_dim = D+H -- All LSTMs in layers 2 onwards have input dimension D+H (H because of skip connections)
    if i == 1 then prev_dim = D end -- First layer LSTM has input dimension D only
    -- Selecting cell type to use
    local rnn
    local rnnContainer = nn.Sequential()
    if self.model_type == 'rnn' then
      rnn = nn.VanillaRNN(prev_dim, H)
    elseif self.model_type == 'lstm' then
      rnn = nn.LSTM(prev_dim, H)
    end
    rnn.remember_states = true
    table.insert(self.rnns, rnn) -- This table is used to find all the rnn cells and reset them later.
    rnnContainer:add(rnn)
    -- Batch normalisation
    if self.batchnorm == 1 then
      -- The views flatten (N, T, H) -> (N*T, H) around the BatchNormalization;
      -- their sizes are reset each forward pass in updateOutput().
      local view_in = nn.View(1, 1, -1):setNumInputDims(3)
      table.insert(self.bn_view_in, view_in)
      rnnContainer:add(view_in)
      rnnContainer:add(nn.BatchNormalization(H))
      local view_out = nn.View(1, -1):setNumInputDims(2)
      table.insert(self.bn_view_out, view_out)
      rnnContainer:add(view_out)
    end
    -- Dropout
    if self.dropout > 0 then
      rnnContainer:add(nn.Dropout(self.dropout))
    end
    -- Construct and link the layers
    if i == 1 then -- Set up first layer
      local t1 = nn.Sequential() -- Contains both sub-layers
      local t11 = nn.ConcatTable() -- First sub-layer
      local t12 = nn.ConcatTable() -- Second sub-layer
      t11:add(rnnContainer)
      t11:add(nn.Identity()) -- output 1 of this layer is LSTM output, output 2 is input.
      -- Output from t11: table of 2 elements:
      -- {LSTM output, network input}
      t12:add(nn.SelectTable(1)) -- Grab LSTM output only from first sublayer.
      t12:add(nn.JoinTable(3, 3)) -- For incoming skip connection to next LSTM layer.
      t12:add(nn.SelectTable(2)) -- Forward input for use in future incoming skip connections.
      -- Output from t12: table of 3 elements:
      -- {LSTM output, LSTM output + network input, network input}
      t1:add(t11)
      t1:add(t12) -- Construct the complete first layer.
      self.net:add(t1) -- Add the completed layer to the overall network container.
    elseif i ~= 1 then -- Set up any layers after the first layer
      local t1 = nn.Sequential() -- Container for both sub-layers
      local t11 = nn.ConcatTable() -- First sub-layer
      local t12 = nn.ConcatTable() -- Second sub-layer
      -- Define sequentials to contain a SelectTable and a module.
      local seq1 = nn.Sequential()
      local seq2 = nn.Sequential()
      local seq3 = nn.Sequential()
      seq1:add(nn.SelectTable(1))
      seq2:add(nn.SelectTable(2))
      seq3:add(nn.SelectTable(3)) -- This bit replicates the nn.ParallelTable functionality
      seq1:add(nn.Identity())
      seq2:add(rnnContainer) -- LSTM/RNN and Batchnorm and dropout if enabled
      seq3:add(nn.Identity())
      t11:add(seq1)
      t11:add(seq2)
      t11:add(seq3)
      -- Output from t11: table of 3 elements:
      -- {outgoing skipcon forwarder, LSTM output, input forwarder}
      local t12s = nn.Sequential() -- Container for handling outgoing skip connection
      t12s:add(nn.NarrowTable(1, 2)) -- Select only the first two outputs from t11
      t12s:add(nn.JoinTable(3)) -- Add to the outgoing skip connection 'accumulator'
      local t12sj = nn.Sequential() -- Container for handling the incoming skip connection
      t12sj:add(nn.NarrowTable(2, 2)) -- Select only the LSTM output and the forwarded network input.
      t12sj:add(nn.JoinTable(3)) -- Join network input and LSTM input
      t12:add(t12s) -- Handles outgoing skip connection to accumulator
      t12:add(t12sj) -- For incoming skip connection to next LSTM layer.
      t12:add(nn.SelectTable(3)) -- Forward input for use in future incoming skip connections.
      -- Output from t12: table of 3 elements:
      -- {Output from outgoing skipcon accumulator, LSTM output + network input, network input}
      t1:add(t11)
      t1:add(t12) -- Construct the complete layer.
      self.net:add(t1) -- Add the completed layer to the overall network container.
    end
  end
  self.net:add(nn.SelectTable(1)) -- This contains the outgoing skip connection accumulator (a large table)
  -- After all the RNNs run, we will have a tensor of shape (N, T, H);
  -- we want to apply a 1D temporal convolution to predict scores for each
  -- vocab element, giving a tensor of shape (N, T, V). Unfortunately
  -- nn.TemporalConvolution is SUPER slow, so instead we will use a pair of
  -- views (N, T, H) -> (NT, H) and (NT, V) -> (N, T, V) with a nn.Linear in
  -- between. Unfortunately N and T can change on every minibatch, so we need
  -- to set them in the forward pass.
  --
  -- MODIFICATION:
  -- Introducing skip connections means that after running, the tensor is of shape (N, T, self.num_layers * H)
  -- The same approach is used as in the original language model, but will be modified so that the layers are:
  -- view (N, T, self.num_layers * H) -> (NT, self.num_layers * H)
  -- nn.Linear(self.num_layers * H, V)
  -- view (NT, V) -> (N, T, V)
  --
  self.view1 = nn.View(1, 1, -1):setNumInputDims(3)
  self.view2 = nn.View(1, -1):setNumInputDims(2)
  self.net:add(self.view1)
  self.net:add(nn.Linear(self.num_layers * H, V))
  self.net:add(self.view2)
end
--- Forward pass. N and T can change on every minibatch, so the flattening
-- views (and any batchnorm views) are resized here before running the net.
-- @param input (N, T, 48) tensor of data vectors + character indices
-- @return (N, T, V) tensor of vocabulary scores
function LM:updateOutput(input)
  local batch = input:size(1)
  local seqlen = input:size(2)
  local flat = batch * seqlen
  self.view1:resetSize(flat, -1)
  self.view2:resetSize(batch, seqlen, -1)
  for _, view in ipairs(self.bn_view_in) do
    view:resetSize(flat, -1)
  end
  for _, view in ipairs(self.bn_view_out) do
    view:resetSize(batch, seqlen, -1)
  end
  return self.net:forward(input)
end
--- Backward pass; delegates directly to the wrapped nn.Sequential container.
function LM:backward(input, gradOutput, scale)
  local gradInput = self.net:backward(input, gradOutput, scale)
  return gradInput
end
--- Return the flattened (weights, gradWeights) of the wrapped container.
function LM:parameters()
  local weights, gradWeights = self.net:parameters()
  return weights, gradWeights
end
--- Clear the remembered hidden state of every recurrent cell in the stack.
function LM:resetStates()
  for _, cell in ipairs(self.rnns) do
    cell:resetStates()
  end
end
--- Encode a string into a 1D tensor of token indices.
-- Fix: the original assert message ('Got invalid idx') gave no clue which
-- character failed; include the offending token and its position.
-- @param s string to encode, one token per character
-- @return torch.DoubleTensor of length #s holding token indices
function LM:encode_string(s)
  local encoded = torch.DoubleTensor(#s)
  for i = 1, #s do
    local token = s:sub(i, i)
    local idx = self.token_to_idx[token]
    assert(idx ~= nil,
      string.format('Got invalid idx for token %q at position %d', token, i))
    encoded[i] = idx
  end
  return encoded
end
--- Decode a 1D tensor of token indices back into a string.
-- Fix: the original built the string with repeated '..' in a loop, which is
-- O(n^2) in total copied bytes; accumulate pieces and join once instead.
-- @param encoded 1D tensor of token indices
-- @return the decoded string
function LM:decode_string(encoded)
  assert(torch.isTensor(encoded) and encoded:dim() == 1)
  local pieces = {}
  for i = 1, encoded:size(1) do
    local idx = encoded[i]
    -- As in the original, an unknown idx yields nil and errors on join.
    pieces[i] = self.idx_to_token[idx]
  end
  return table.concat(pieces)
end
-- The below function modified from http://stackoverflow.com/questions/1426954/split-string-in-lua
-- The below function modified from http://stackoverflow.com/questions/1426954/split-string-in-lua
--- Split a comma/space separated string of numbers into a torch.LongTensor.
-- Fix: 'sep' and 'i' were accidental globals (Lua variables are global by
-- default); both are now declared local so concurrent/nested calls cannot
-- clobber each other and the global namespace stays clean.
-- @param inputstr e.g. "1, 2, 3"
-- @return torch.LongTensor of the parsed numbers
function LM:parseInArr(inputstr)
  local sep = ", "
  local t = {}
  local i = 1
  -- "[^, ]+" matches runs of characters that are neither comma nor space.
  for str in string.gmatch(inputstr, "([^" .. sep .. "]+)") do
    t[i] = tonumber(str)
    i = i + 1
  end
  return torch.LongTensor(t)
end
--[[
Sample from the language model. Note that this will reset the states of the
underlying RNNs.
Inputs (passed as fields of a kwargs table):
- start_text: comma-separated data-vector string used to seed the model
- length: number of characters to sample
- sample, temperature, verbose, nullstop: sampling options
Returns:
- the decoded string of sampled characters.
--]]
--- Sample characters from the model, seeding with a parsed data vector.
-- Resets the underlying RNN states before and after sampling.
--
-- kwargs:
--   length      - number of characters to sample (default 100)
--   start_text  - comma-separated data vector string (parsed by parseInArr)
--   verbose     - print seeding info when > 0
--   sample      - 0 = argmax, 1 = sample from the softmax (default 1)
--   temperature - softmax temperature (default 1)
--   nullstop    - when > 0, stop and truncate after a generated "\n\n."
--
-- Fixes versus the previous revision:
--   1. next_char is declared local BEFORE the argmax/sample branch. It used
--      to be 'local' inside the sample==0 branch only, so the later
--      sampled:copy(next_char) read an unrelated global (nil on the argmax
--      path) — and the sample==1 path leaked a global.
--   2. The non-nullstop loop now feeds the net the same (data .. char)
--      concatenation as the nullstop loop. It previously forwarded the bare
--      47-wide x_in, ignoring the character just generated, although the
--      network's first layer slices out column 48 as the character.
--   3. netInput inside the loops is declared local (was a global leak).
--
-- NOTE(review): both loops assume x_in is set, i.e. start_text was given;
-- sampling with empty start_text would fail at torch.cat — confirm callers
-- always pass start_text.
-- @return the decoded string of sampled characters
function LM:sample(kwargs)
  local T = utils.get_kwarg(kwargs, 'length', 100)
  local start_text = utils.get_kwarg(kwargs, 'start_text', '')
  local verbose = utils.get_kwarg(kwargs, 'verbose', 0)
  local sample = utils.get_kwarg(kwargs, 'sample', 1)
  local temperature = utils.get_kwarg(kwargs, 'temperature', 1)
  local nullStop = utils.get_kwarg(kwargs, 'nullstop', 0) -- Argument to stop sampling/truncate output after a null character is generated.
  if nullStop > 0 then -- Change the sample limit if the nullStop argument is set to 1.
    T = 20000 -- Hardcoding this to be truncated later, should be adequate... HEY I'M SURE IT'S PROBABLY FINE RIGHT
  end
  local sampled = torch.LongTensor(1, T)
  self:resetStates()
  local scores, first_t, x_in
  if #start_text > 0 then
    if verbose > 0 then
      print('Seeding with: "' .. start_text .. '"')
    end
    x_in = self:parseInArr(start_text):view(1, 1, -1)
    local T0 = 1
    --x_in is (1, 1, 47) at the moment
    local start_char = torch.LongTensor(1, 1, 1)
    local temp1 = self:encode_string("."):view(1, -1)
    start_char[{{}, {}, 1}] = temp1 -- Seed with start character.
    -- encoder produces vector of (1, stringlength)
    local netInput = torch.cat(x_in, start_char, 3) -- Concatenate data+char to make (1, 1, 48) tensor
    --netInput = netInput:type('torch.CudaTensor')
    netInput = netInput:type('torch.FloatTensor')
    scores = self:forward(netInput)[{{}, {T0, T0}}]
    first_t = 1
  else
    if verbose > 0 then
      print('Seeding with uniform probabilities')
    end
    local w = self.net:get(1).weight
    scores = w.new(1, 1, self.vocab_size):fill(1)
    first_t = 1
  end
  local n1Flag = 0 -- Having to use flags to detect the string "\n\n." which should mean the end of a forecast.
  local n2Flag = 0
  -- Trying to remove a little overhead by repeating the loop twice - once for nullStop, once for no nullStop.
  if nullStop > 0 then
    for t = first_t, T do
      local next_char -- Declared before the branch so both paths share it.
      if sample == 0 then
        local _
        _, next_char = scores:max(3)
        next_char = next_char[{{}, {}, 1}]
      else
        local probs = torch.div(scores, temperature):double():exp():squeeze()
        probs:div(torch.sum(probs))
        next_char = torch.multinomial(probs, 1):view(1, 1)
      end
      sampled[{{}, {t, t}}]:copy(next_char)
      -- netInput contains a concatenation of the data and the chosen character:
      -- [x_in][next_char]
      local netInput = torch.cat(x_in, next_char, 3)
      netInput = netInput:type('torch.FloatTensor')
      scores = self:forward(netInput)
      if (n1Flag == 0 and n2Flag == 0) then -- No newlines detected yet.
        if (self:decode_string(next_char[1]) == "\n") then
          n1Flag = 1
        end
      elseif (n1Flag == 1 and n2Flag == 0) then -- First newline detected already
        if (self:decode_string(next_char[1]) == "\n") then
          n2Flag = 1 -- Say that we've detected the second newline
        else
          n1Flag = 0 -- Reset since the latest character is just a normal newline.
        end
      elseif (n1Flag == 1 and n2Flag == 1) then -- 2 newlines detected in a row.
        if (self:decode_string(next_char[1]) == ".") then -- This is the last of the "\n\n."
          sampled:resize(1, t-1) -- Resize output vector to the final size
          break -- If a null character is received then stop sampling. Don't write the character to the output here.
        else
          n1Flag = 0
          n2Flag = 0
        end
      end
    end
  else -- Same thing, without the comparisons and truncation.
    for t = first_t, T do
      local next_char
      if sample == 0 then
        local _
        _, next_char = scores:max(3)
        next_char = next_char[{{}, {}, 1}]
      else
        local probs = torch.div(scores, temperature):double():exp():squeeze()
        probs:div(torch.sum(probs))
        next_char = torch.multinomial(probs, 1):view(1, 1)
      end
      sampled[{{}, {t, t}}]:copy(next_char)
      -- Feed the same data+char concatenation as the nullStop loop above.
      local netInput = torch.cat(x_in, next_char, 3)
      netInput = netInput:type('torch.FloatTensor')
      scores = self:forward(netInput)
    end
  end
  self:resetStates()
  return self:decode_string(sampled[1])
end
--- Release intermediate buffers held by the wrapped network container.
function LM:clearState()
  local container = self.net
  container:clearState()
end