Improve variable name resolution by making more variables local.
SilverNexus committed Feb 10, 2019
Parent: 0275073 · Commit: 06e8adb
Showing 1 changed file with 7 additions and 6 deletions.
train.lua: 13 changes (7 additions & 6 deletions)
@@ -209,7 +209,7 @@ for name,proto in pairs(protos) do
 end

 -- preprocessing helper function
-function prepro(x,y)
+local function prepro(x,y)
     x = x:transpose(1,2):contiguous() -- swap the axes for faster indexing
     y = y:transpose(1,2):contiguous()
     if opt.gpuid >= 0 and opt.opencl == 0 then -- ship the input arrays to GPU
@@ -271,22 +271,23 @@ function feval(x)
     local rnn_state = {[0] = init_state_global}
     local predictions = {} -- softmax outputs
     local loss = 0
+    local local_clones = clones
     for t=1,opt.seq_length do
-        clones.rnn[t]:training() -- make sure we are in correct mode (this is cheap, sets flag)
-        local lst = clones.rnn[t]:forward{x[t], unpack(rnn_state[t-1])}
+        local_clones.rnn[t]:training() -- make sure we are in correct mode (this is cheap, sets flag)
+        local lst = local_clones.rnn[t]:forward{x[t], unpack(rnn_state[t-1])}
         rnn_state[t] = {unpack(lst, 1, #init_state)} -- extract the state, without output
         predictions[t] = lst[#lst] -- last element is the prediction
-        loss = loss + clones.criterion[t]:forward(predictions[t], y[t])
+        loss = loss + local_clones.criterion[t]:forward(predictions[t], y[t])
     end
     loss = loss / opt.seq_length
     ------------------ backward pass -------------------
     -- initialize gradient at time t to be zeros (there's no influence from future)
     local drnn_state = {[opt.seq_length] = clone_list(init_state, true)} -- true also zeros the clones
     for t=opt.seq_length,1,-1 do
         -- backprop through loss, and softmax/linear
-        local doutput_t = clones.criterion[t]:backward(predictions[t], y[t])
+        local doutput_t = local_clones.criterion[t]:backward(predictions[t], y[t])
         drnn_state[t][#drnn_state[t]+1] = doutput_t
-        local dlst = clones.rnn[t]:backward({x[t], unpack(rnn_state[t-1])}, drnn_state[t])
+        local dlst = local_clones.rnn[t]:backward({x[t], unpack(rnn_state[t-1])}, drnn_state[t])
         -- dlst[1] is the gradient on x, which we don't need
         -- using unpack should slide the values into the correct indexes, allowing us to forego a loop.
         drnn_state[t-1] = {unpack(dlst, 2)}
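For context on the change itself: in Lua, every access to a global goes through a table lookup in _G, while a local compiles to a direct register access, so the usual optimization is exactly what this commit does: declare helpers with "local function" and cache hot globals in locals before tight loops. A minimal standalone sketch of both patterns (the names counts, square, and accumulate are illustrative only, not from train.lua):

    -- 'local function' keeps the helper out of the global table _G,
    -- so calls resolve to a local/upvalue instead of a global lookup.
    local function square(x)
        return x * x
    end

    -- Caching a global table in a local before a hot loop: one global
    -- lookup up front instead of one per iteration.
    counts = { total = 0 }            -- a global, reached through _G on every use

    local function accumulate(n)
        local c = counts              -- cache the global once
        for i = 1, n do
            c.total = c.total + square(i)
        end
        return c.total
    end

    print(accumulate(3))              --> 14  (1 + 4 + 9)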

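The surrounding context lines also lean on Lua 5.1's unpack to copy a slice of an array into a fresh table, as the in-diff comment notes: unpack(t, i, j) returns t[i] through t[j] (i defaults to 1, j to #t), and collecting the results in a table constructor slides them down to start at index 1. A small illustrative sketch (the values are made-up stand-ins, not real gradient tensors):

    local dlst = { 'dx', 'dh1', 'dc1', 'dh2' }   -- stand-ins for the backward-pass outputs

    -- Drop dlst[1] (the gradient on the input, which the training loop
    -- discards) and slide the rest down to indexes 1..3, with no explicit loop.
    local shifted = { unpack(dlst, 2) }

    print(shifted[1], shifted[2], shifted[3])    --> dh1  dc1  dh2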