diff --git a/copy.lua b/copy.lua new file mode 100644 index 00000000..954450f5 --- /dev/null +++ b/copy.lua @@ -0,0 +1,158 @@ +--[[ + + Training a NTM to memorize input. + + The current version seems to work, giving good output after 5000 iterations + or so. Proper initialization of the read/write weights seems to be crucial + here. + +--]] + +require('../') +require('./util') +require('optim') +require('sys') + +torch.manualSeed(0) + +-- NTM config +local config = { + input_dim = 10, + output_dim = 10, + mem_rows = 128, + mem_cols = 20, + cont_dim = 100 +} + +local input_dim = config.input_dim +local start_symbol = torch.zeros(input_dim) +start_symbol[1] = 1 +local end_symbol = torch.zeros(input_dim) +end_symbol[2] = 1 + +function generate_sequence(len, bits) + local seq = torch.zeros(len, bits + 2) + for i = 1, len do + seq[{i, {3, bits + 2}}] = torch.rand(bits):round() + end + return seq +end + +function forward(model, seq, print_flag) + local len = seq:size(1) + local loss = 0 + + -- present start symbol + model:forward(start_symbol) + + -- present inputs + if print_flag then print('write head max') end + for j = 1, len do + model:forward(seq[j]) + if print_flag then print_write_max(model) end + end + + -- present end symbol + model:forward(end_symbol) + + -- present targets + local zeros = torch.zeros(input_dim) + local outputs = torch.Tensor(len, input_dim) + local criteria = {} + if print_flag then print('read head max') end + for j = 1, len do + criteria[j] = nn.BCECriterion() + outputs[j] = model:forward(zeros) + loss = loss + criteria[j]:forward(outputs[j], seq[j]) * input_dim + if print_flag then print_read_max(model) end + end + return outputs, criteria, loss +end + +function backward(model, seq, outputs, criteria) + local len = seq:size(1) + local zeros = torch.zeros(input_dim) + for j = len, 1, -1 do + model:backward( + zeros, + criteria[j] + :backward(outputs[j], seq[j]) + :mul(input_dim) + ) + end + + model:backward(end_symbol, zeros) + for j = len, 1, -1 do + model:backward(seq[j], zeros) + end + model:backward(start_symbol, zeros) +end + +local model = ntm.NTM(config) +local params, grads = model:getParameters() + +local num_iters = 10000 +local start = sys.clock() +local print_interval = 25 +local min_len = 1 +local max_len = 20 + +print(string.rep('=', 80)) +print("NTM copy task") +print('training up to ' .. num_iters .. ' iteration(s)') +print('min sequence length = ' .. min_len) +print('max sequence length = ' .. max_len) +print(string.rep('=', 80)) +print('num params: ' .. params:size(1)) + +local rmsprop_state = { + learningRate = 1e-4, + momentum = 0.9, + decay = 0.95 +} + +-- local adagrad_state = { +-- learningRate = 1e-3 +-- } + +-- train +for iter = 1, num_iters do + local print_flag = (iter % print_interval == 0) + local feval = function(x) + if print_flag then + print(string.rep('-', 80)) + print('iter = ' .. iter) + print('learn rate = ' .. rmsprop_state.learningRate) + print('momentum = ' .. rmsprop_state.momentum) + print('decay = ' .. 
rmsprop_state.decay) + printf('t = %.1fs\n', sys.clock() - start) + end + + local loss = 0 + grads:zero() + + local len = math.floor(torch.random(min_len, max_len)) + local seq = generate_sequence(len, input_dim - 2) + local outputs, criteria, sample_loss = forward(model, seq, print_flag) + loss = loss + sample_loss + backward(model, seq, outputs, criteria) + if print_flag then + print("target:") + print(seq) + print("output:") + print(outputs) + end + + -- clip gradients + grads:clamp(-10, 10) + if print_flag then + print('max grad = ' .. grads:max()) + print('min grad = ' .. grads:min()) + print('loss = ' .. loss) + end + return loss, grads + end + + --optim.adagrad(feval, params, adagrad_state) + ntm.rmsprop(feval, params, rmsprop_state) +end diff --git a/lr.lua b/lr.lua new file mode 100644 index 00000000..725b979b --- /dev/null +++ b/lr.lua @@ -0,0 +1,4 @@ +local function lr() + return 0.008 +end +return lr diff --git a/model/CircularConvolution.lua b/model/CircularConvolution.lua new file mode 100644 index 00000000..ba1fbce0 --- /dev/null +++ b/model/CircularConvolution.lua @@ -0,0 +1,96 @@ +--[[ + + Input: A table {x, k} of a vector x and a convolution kernel k. + + Output: Circular convolution of x with k. + + TODO: This module can probably be implemented more efficiently. + +--]] + +local CircularConvolution, parent = torch.class('nn.CircularConvolution', 'nn.Module') + +function CircularConvolution:__init() + parent.__init(self) + self.gradInput = {} +end + +function rotate_left(input, step) + local output = input.new():resizeAs(input) + local size = input:size(1) + output[{{1, size - step}}] = input[{{step + 1, size}}] + output[{{size - step + 1, size}}] = input[{{1, step}}] + return output +end + +function rotate_right(input, step) + local output = input.new():resizeAs(input) + local size = input:size(1) + output[{{step + 1, size}}] = input[{{1, size - step}}] + output[{{1, step}}] = input[{{size - step + 1, size}}] + return output +end + +-- function CircularConvolution:updateOutput_orig(input) +-- local a, b = unpack(input) +-- local size = a:size(1) +-- self.b = b:repeatTensor(1,2) +-- local circ = a.new():resize(size, size) +-- for i = 0, size - 1 do +-- circ[i + 1] = self.b:narrow(2, size - i + 1, size) +-- end +-- self.output:set(torch.mv(circ:t(), a)) +-- return self.output +-- end + +-- function CircularConvolution:updateGradInput_orig(input, gradOutput) +-- local a, b = unpack(input) +-- local size = a:size(1) +-- for i = 1, 2 do +-- self.gradInput[i] = self.gradInput[i] or input[1].new() +-- self.gradInput[i]:resize(size) +-- end + +-- a = a:repeatTensor(1, 2) +-- for i = 0, size - 1 do +-- self.gradInput[1][i + 1] = gradOutput:dot(self.b:narrow(2, size - i + 1, size)) +-- self.gradInput[2][i + 1] = gradOutput:dot(a:narrow(2, size - i + 1, size)) +-- end +-- return self.gradInput +-- end + +function CircularConvolution:updateOutput(input) + local v, k = unpack(input) + self.size = v:size(1) + self.kernel_size = k:size(1) + self.kernel_shift = math.floor(self.kernel_size / 2) + self.output = v.new():resize(self.size):zero() + for i = 1, self.size do + for j = 1, self.kernel_size do + local idx = i + self.kernel_shift - j + 1 + if idx < 1 then idx = idx + self.size end + if idx > self.size then idx = idx - self.size end + self.output[{{i}}]:add(k[j] * v[idx]) + end + end + return self.output +end + +function CircularConvolution:updateGradInput(input, gradOutput) + local v, k = unpack(input) + self.gradInput[1] = self.gradInput[1] or v.new() + self.gradInput[2] = 
self.gradInput[2] or k.new() + self.gradInput[1]:resize(self.size) + self.gradInput[2]:resize(self.kernel_size) + + local gradOutput2 = rotate_right(gradOutput:repeatTensor(1, 2):view(2 * self.size), self.kernel_shift) + for i = 1, self.size do + self.gradInput[1][i] = k:dot(gradOutput2:narrow(1, i, self.kernel_size)) + end + + local v2 = rotate_left(v:repeatTensor(1, 2):view(2 * self.size), self.kernel_shift + 1) + for i = 1, self.kernel_size do + self.gradInput[2][i] = gradOutput:dot(v2:narrow(1, self.size - i + 1, self.size)) + end + return self.gradInput +end diff --git a/model/GRU.lua b/model/GRU.lua index 11ae34e3..d3154962 100644 --- a/model/GRU.lua +++ b/model/GRU.lua @@ -6,7 +6,7 @@ Creates one timestep of one GRU Paper reference: http://arxiv.org/pdf/1412.3555v1.pdf ]]-- function GRU.gru(input_size, rnn_size, n, dropout) - dropout = dropout or 0 + dropout = dropout or 0 -- there are n+1 inputs (hiddens on each layer and x) local inputs = {} table.insert(inputs, nn.Identity()()) -- x @@ -26,11 +26,12 @@ function GRU.gru(input_size, rnn_size, n, dropout) local prev_h = inputs[L+1] -- the input to this layer - if L == 1 then - x = OneHot(input_size)(inputs[1]) + if L == 1 then + print(input_size) + x = nn.LookupTable(input_size,rnn_size)(inputs[1]) input_size_L = input_size - else - x = outputs[(L-1)] + else + x = outputs[(L-1)] if dropout > 0 then x = nn.Dropout(dropout)(x) end -- apply dropout, if any input_size_L = rnn_size end diff --git a/model/LSTMEX.lua b/model/LSTMEX.lua new file mode 100644 index 00000000..abe71a19 --- /dev/null +++ b/model/LSTMEX.lua @@ -0,0 +1,267 @@ +require "./Print.lua" +local LSTMEX = {} + +-- retunrs memory_dim vector - convex combination of memory slots +function LSTMEX.ReadHead(name, control, control_dim, memory, memory_slots, memory_dim) + print(string.format( + '%d guided attention reader on %dx%d memory matrix initialized...', + control_dim, + memory_slots, + memory_dim + )) + local transform = nn.Linear(control_dim, memory_slots)(control) + local address = nn.SoftMax()(transform):annotate{name=name} + address = nn.Reshape(memory_slots,1,true)(address) + local hologram = nn.MM(true){memory, address} + return nn.Reshape(memory_dim,true)(hologram) + --return nn.Tanh()(hologram) +end + +-- writes to whole memory weighted decomposition of x ruled by y signal +function LSTMEX.WriteHead(name, control, control_dim, x, memory, memory_slots, memory_dim) + print(string.format( + '%d guided writer on %dx%d memory matrix initialized...', + control_dim, + memory_slots, + memory_dim + )) + local transform = nn.Linear(control_dim, memory_slots)(control) + local address = nn.SoftMax()(transform):annotate{name=name} + address = nn.Reshape(memory_slots,1,true)(address) + --address = nn.Print('address',true)(address) + + local tx = nn.Reshape(1,memory_dim,true)(x) + --tx = nn.Print('x')(tx) + local delta = nn.MM(){address, tx} + --delta = nn.Print('delta')(delta) + local updated_memory = nn.CAddTable()({delta, memory}) + return updated_memory +end + +-- Gated eraser by control signal +function LSTMEX.EraseHead(name, control, control_dim, memory, memory_slots, memory_dim) + print(string.format( + '%d guided eraser on %dx%d memory matrix initialized...', + control_dim, + memory_slots, + memory_dim + )) + --control = nn.Print('erase_signal', true)(control) + local transform = nn.Linear(control_dim, memory_slots)(control) + --transform = nn.Print('transformed_erase')(transform) + local address = nn.SoftMax()(transform):annotate{name=name} + --address = 
nn.Print('erase_address')(address) + address = nn.AddConstant(1,false)(nn.MulConstant(-1,false)(address)) + --address = nn.Print('mul_mask', true)(address) + address = nn.Replicate(memory_dim,3,3)(address) + --address = nn.Print('replicated_mask')(address) + local updated_memory = nn.CMulTable()({address, memory}) + return updated_memory +end + + +-- writes to whole memory weighted decomposition of x ruled by y signal +function LSTMEX.EraseHeadModule(name, control_dim, memory_slots, memory_dim) + local control = nn.Identity()() + local memory = nn.Identity()() + local updated_memory = LSTMEX.EraseHead(name,control,control_dim,memory,memory_slots,memory_dim) + return nn.gModule({control,memory}, {updated_memory}) +end +-- writes to whole memory weighted decomposition of x ruled by y signal +function LSTMEX.WriteHeadModule(name, control_dim, memory_slots, memory_dim) + local control = nn.Identity()() + local x = nn.Identity()() + local memory = nn.Identity()() + local updated_memory = LSTMEX.WriteHead(name,control,control_dim,x,memory,memory_slots,memory_dim) + return nn.gModule({control,x,memory}, {updated_memory}) +end + +-- writes to whole memory weighted decomposition of x ruled by y signal +function LSTMEX.WriteEraseHead(name, i, f, x, i_dim, f_dim, x_dim, memory, memory_slots, memory_dim) + print(string.format( + 'WriteErase head %d guided writer on %dx%d memory matrix initialized...', + i_dim, + memory_slots, + memory_dim + )) + + local write_weights = nn.SoftMax()(nn.Linear(i_dim, memory_slots)(i)) + local delta = nn.MM(){ -- memory_slots X memory_dim matrix + nn.Reshape(memory_slots,1,true)(write_weights), + nn.Reshape(1,memory_dim,true)(x) + } + + + --delta = nn.Print('delta')(delta) + --local updated_memory = nn.CAddTable()({delta, memory}) + + --control = nn.Print('erase_signal', true)(control) + local T = nn.SoftMax()(nn.Linear(f_dim, memory_slots)(f)) + T = nn.Replicate(memory_dim,3,3)(T) + local M = nn.CAddTable()({ + nn.CMulTable()({T,memory}), + nn.CMulTable()({ + delta, + nn.AddConstant(1,false)(nn.MulConstant(-1,false)(T)) + }) + }) + + return M +end + + +function LSTMEX.lstm(input_size, rnn_size, n, dropout, memory_slots) + dropout = dropout or 0 + + -- there will be 2*n+1 inputs + local inputs = {} + table.insert(inputs, nn.Identity()()) -- x + for L = 1,n do + table.insert(inputs, nn.Identity()()) -- prev_c[L] + table.insert(inputs, nn.Identity()()) -- prev_h[L] + end + + local x, input_size_L + local outputs = {} + for L = 1,n do + -- c,h from previos timesteps + local prev_h = inputs[L*2+1] + local prev_c = inputs[L*2] + -- the input to this layer + if L == 1 then + x = OneHot(input_size)(inputs[1]) + --x = nn.Print(x) + input_size_L = input_size + else + x = outputs[(L-1)*2] + if dropout > 0 then x = nn.Dropout(dropout)(x) end -- apply dropout, if any + input_size_L = rnn_size + end + -- evaluate the input sums at once for efficiency + local i2h = nn.Linear(input_size_L, 4 * rnn_size)(x):annotate{name='i2h_'..L} + local h2h = nn.Linear(rnn_size, 4 * rnn_size)(prev_h):annotate{name='h2h_'..L} + local all_input_sums = nn.CAddTable()({i2h, h2h}) + local reshaped = nn.Reshape(4, rnn_size)(all_input_sums) + local n1, n2, n3, n4 = nn.SplitTable(2)(reshaped):split(4) + -- decode the gates + local in_gate = nn.Sigmoid()(n1) + local forget_gate = nn.Sigmoid()(n2) + local out_gate = nn.Sigmoid()(n3) + -- decode the write inputs + local in_transform = nn.Tanh()(n4) + -- perform the LSTM update + + -- erase controlled by forget gate + local erased_c = 
LSTMEX.EraseHead('Erase',forget_gate,rnn_size,prev_c,memory_slots,rnn_size) + + -- write controlled by input gate + local next_c = LSTMEX.WriteHead('Write',in_gate,rnn_size,in_transform,erased_c,memory_slots,rnn_size) + next_c = nn.PrintTensor(10,"Memory")(next_c) + --local next_c = nn.CAddTable()({ + -- nn.CMulTable()({forget_gate, prev_c}), + -- nn.CMulTable()({in_gate, in_transform}) + -- }) + --next_c = nn.Print()(next_c) + -- read controlled by output gate + local next_h = LSTMEX.ReadHead('Read', out_gate, rnn_size, next_c, memory_slots, rnn_size) + next_h = nn.Tanh()(next_h) + --local next_h = nn.CMulTable()({out_gate, nn.Tanh()(next_c)}) + + table.insert(outputs, next_c) + table.insert(outputs, next_h) + end + + -- set up the decoder + local top_h = outputs[#outputs] + if dropout > 0 then top_h = nn.Dropout(dropout)(top_h) end + local proj = nn.Linear(rnn_size, input_size)(top_h):annotate{name='decoder'} + local logsoft = nn.LogSoftMax()(proj) + table.insert(outputs, logsoft) + + return nn.gModule(inputs, outputs) +end + +function LSTMEX.lstm2(input_size, rnn_size, n, dropout, memory_slots) + dropout = dropout or 0 + + -- there will be 2*n+1 inputs + local inputs = {} + table.insert(inputs, nn.Identity()()) -- x + for L = 1,n do + table.insert(inputs, nn.Identity()()) -- prev_c[L] + table.insert(inputs, nn.Identity()()) -- prev_h[L] + end + + local x, input_size_L + local outputs = {} + for L = 1,n do + -- c,h from previos timesteps + local prev_h = inputs[L*2+1] + local prev_c = inputs[L*2] + -- the input to this layer + if L == 1 then + x = OneHot(input_size)(inputs[1]) + --x = nn.Print(x) + input_size_L = input_size + else + x = outputs[(L-1)*2] + if dropout > 0 then x = nn.Dropout(dropout)(x) end -- apply dropout, if any + input_size_L = rnn_size + end + -- evaluate the input sums at once for efficiency + local i2h = nn.Linear(input_size_L, 4 * rnn_size)(x):annotate{name='i2h_'..L} + local h2h = nn.Linear(rnn_size, 4 * rnn_size)(prev_h):annotate{name='h2h_'..L} + local all_input_sums = nn.CAddTable()({i2h, h2h}) + local reshaped = nn.Reshape(4, rnn_size)(all_input_sums) + local n1, n2, n3, n4 = nn.SplitTable(2)(reshaped):split(4) + -- decode the gates + local in_gate = nn.Sigmoid()(n1) + local forget_gate = nn.Sigmoid()(n2) + local out_gate = nn.Sigmoid()(n3) + -- decode the write inputs + local in_transform = nn.Tanh()(n4) + -- perform the LSTM update + local next_c = LSTMEX.WriteEraseHead( + 'WriteErase', + in_gate, + forget_gate, + in_transform, + rnn_size, --| + rnn_size, --| HUooooooooooH! 
+ rnn_size, --| + prev_c, + memory_slots, + rnn_size + ) + next_c = nn.PrintTensor(10,"Memory")(next_c) + --next_c = nn.Print("next_c")(next_c) + -- erase controlled by forget gate + --local erased_c = LSTMEX.EraseHead('Erase',forget_gate,rnn_size,prev_c,memory_slots,rnn_size) + + -- write controlled by input gate + --local next_c = LSTMEX.WriteHead('Write',in_gate,rnn_size,in_transform,erased_c,memory_slots,rnn_size) + --local next_c = nn.CAddTable()({ + -- nn.CMulTable()({forget_gate, prev_c}), + -- nn.CMulTable()({in_gate, in_transform}) + -- }) + --next_c = nn.Print()(next_c) + -- read controlled by output gate + local next_h = LSTMEX.ReadHead('Read', out_gate, rnn_size, next_c, memory_slots, rnn_size) + next_h = nn.Tanh()(next_h) + --local next_h = nn.CMulTable()({out_gate, nn.Tanh()(next_c)}) + + table.insert(outputs, next_c) + table.insert(outputs, next_h) + end + + -- set up the decoder + local top_h = outputs[#outputs] + if dropout > 0 then top_h = nn.Dropout(dropout)(top_h) end + local proj = nn.Linear(rnn_size, input_size)(top_h):annotate{name='decoder'} + local logsoft = nn.LogSoftMax()(proj) + table.insert(outputs, logsoft) + + return nn.gModule(inputs, outputs) +end + +return LSTMEX diff --git a/model/LSTMNTM.lua b/model/LSTMNTM.lua new file mode 100644 index 00000000..e82bdcac --- /dev/null +++ b/model/LSTMNTM.lua @@ -0,0 +1,267 @@ + +local LSTMEX = {} + +-- retunrs memory_dim vector - convex combination of memory slots +function LSTMEX.ReadHead(name, control, control_dim, memory, memory_slots, memory_dim) + print(string.format( + '%d guided attention reader on %dx%d memory matrix initialized...', + control_dim, + memory_slots, + memory_dim + )) + local transform = nn.Linear(control_dim, memory_slots)(control) + local address = nn.SoftMax()(transform):annotate{name=name} + address = nn.Reshape(memory_slots,1,true)(address) + local hologram = nn.MM(true){memory, address} + return nn.Reshape(memory_dim,true)(hologram) + --return nn.Tanh()(hologram) +end + +-- writes to whole memory weighted decomposition of x ruled by y signal +function LSTMEX.WriteHead(name, control, control_dim, x, memory, memory_slots, memory_dim) + print(string.format( + '%d guided writer on %dx%d memory matrix initialized...', + control_dim, + memory_slots, + memory_dim + )) + local transform = nn.Linear(control_dim, memory_slots)(control) + local address = nn.SoftMax()(transform):annotate{name=name} + address = nn.Reshape(memory_slots,1,true)(address) + --address = nn.Print('address',true)(address) + + local tx = nn.Reshape(1,memory_dim,true)(x) + --tx = nn.Print('x')(tx) + local delta = nn.MM(){address, tx} + --delta = nn.Print('delta')(delta) + local updated_memory = nn.CAddTable()({delta, memory}) + return updated_memory +end + +-- Gated eraser by control signal +function LSTMEX.EraseHead(name, control, control_dim, memory, memory_slots, memory_dim) + print(string.format( + '%d guided eraser on %dx%d memory matrix initialized...', + control_dim, + memory_slots, + memory_dim + )) + --control = nn.Print('erase_signal', true)(control) + local transform = nn.Linear(control_dim, memory_slots)(control) + --transform = nn.Print('transformed_erase')(transform) + local address = nn.SoftMax()(transform):annotate{name=name} + --address = nn.Print('erase_address')(address) + address = nn.AddConstant(1,false)(nn.MulConstant(-1,false)(address)) + --address = nn.Print('mul_mask', true)(address) + address = nn.Replicate(memory_dim,3,3)(address) + --address = nn.Print('replicated_mask')(address) + local updated_memory 
= nn.CMulTable()({address, memory}) + return updated_memory +end + + +-- writes to whole memory weighted decomposition of x ruled by y signal +function LSTMEX.EraseHeadModule(name, control_dim, memory_slots, memory_dim) + local control = nn.Identity()() + local memory = nn.Identity()() + local updated_memory = LSTMEX.EraseHead(name,control,control_dim,memory,memory_slots,memory_dim) + return nn.gModule({control,memory}, {updated_memory}) +end +-- writes to whole memory weighted decomposition of x ruled by y signal +function LSTMEX.WriteHeadModule(name, control_dim, memory_slots, memory_dim) + local control = nn.Identity()() + local x = nn.Identity()() + local memory = nn.Identity()() + local updated_memory = LSTMEX.WriteHead(name,control,control_dim,x,memory,memory_slots,memory_dim) + return nn.gModule({control,x,memory}, {updated_memory}) +end + +-- writes to whole memory weighted decomposition of x ruled by y signal +function LSTMEX.WriteEraseHead(name, i, f, x, i_dim, f_dim, x_dim, memory, memory_slots, memory_dim) + print(string.format( + 'WriteErase head %d guided writer on %dx%d memory matrix initialized...', + i_dim, + memory_slots, + memory_dim + )) + + local write_weights = nn.SoftMax()(nn.Linear(i_dim, memory_slots)(i)) + local delta = nn.MM(){ -- memory_slots X memory_dim matrix + nn.Reshape(memory_slots,1,true)(write_weights), + nn.Reshape(1,memory_dim,true)(x) + } + + + --delta = nn.Print('delta')(delta) + --local updated_memory = nn.CAddTable()({delta, memory}) + + --control = nn.Print('erase_signal', true)(control) + local T = nn.SoftMax()(nn.Linear(f_dim, memory_slots)(f)) + T = nn.Replicate(memory_dim,3,3)(T) + local M = nn.CAddTable()({ + nn.CMulTable()({T,memory}), + nn.CMulTable()({ + delta, + nn.AddConstant(1,false)(nn.MulConstant(-1,false)(T)) + }) + }) + + return M +end + + +function LSTMEX.lstm(input_size, rnn_size, n, dropout, memory_slots) + dropout = dropout or 0 + + -- there will be 2*n+1 inputs + local inputs = {} + table.insert(inputs, nn.Identity()()) -- x + for L = 1,n do + table.insert(inputs, nn.Identity()()) -- prev_c[L] + table.insert(inputs, nn.Identity()()) -- prev_h[L] + end + + local x, input_size_L + local outputs = {} + for L = 1,n do + -- c,h from previos timesteps + local prev_h = inputs[L*2+1] + local prev_c = inputs[L*2] + -- the input to this layer + if L == 1 then + x = OneHot(input_size)(inputs[1]) + --x = nn.Print(x) + input_size_L = input_size + else + x = outputs[(L-1)*2] + if dropout > 0 then x = nn.Dropout(dropout)(x) end -- apply dropout, if any + input_size_L = rnn_size + end + -- evaluate the input sums at once for efficiency + local i2h = nn.Linear(input_size_L, 4 * rnn_size)(x):annotate{name='i2h_'..L} + local h2h = nn.Linear(rnn_size, 4 * rnn_size)(prev_h):annotate{name='h2h_'..L} + local all_input_sums = nn.CAddTable()({i2h, h2h}) + local reshaped = nn.Reshape(4, rnn_size)(all_input_sums) + local n1, n2, n3, n4 = nn.SplitTable(2)(reshaped):split(4) + -- decode the gates + local in_gate = nn.Sigmoid()(n1) + local forget_gate = nn.Sigmoid()(n2) + local out_gate = nn.Sigmoid()(n3) + -- decode the write inputs + local in_transform = nn.Tanh()(n4) + -- perform the LSTM update + + -- erase controlled by forget gate + local erased_c = LSTMEX.EraseHead('Erase',forget_gate,rnn_size,prev_c,memory_slots,rnn_size) + + -- write controlled by input gate + local next_c = LSTMEX.WriteHead('Write',in_gate,rnn_size,in_transform,erased_c,memory_slots,rnn_size) + next_c = nn.PrintTensor(10,"Memory")(next_c) + --local next_c = nn.CAddTable()({ + -- 
nn.CMulTable()({forget_gate, prev_c}), + -- nn.CMulTable()({in_gate, in_transform}) + -- }) + --next_c = nn.Print()(next_c) + -- read controlled by output gate + local next_h = LSTMEX.ReadHead('Read', out_gate, rnn_size, next_c, memory_slots, rnn_size) + next_h = nn.Tanh()(next_h) + --local next_h = nn.CMulTable()({out_gate, nn.Tanh()(next_c)}) + + table.insert(outputs, next_c) + table.insert(outputs, next_h) + end + + -- set up the decoder + local top_h = outputs[#outputs] + if dropout > 0 then top_h = nn.Dropout(dropout)(top_h) end + local proj = nn.Linear(rnn_size, input_size)(top_h):annotate{name='decoder'} + local logsoft = nn.LogSoftMax()(proj) + table.insert(outputs, logsoft) + + return nn.gModule(inputs, outputs) +end + +function LSTMEX.lstm2(input_size, rnn_size, n, dropout, memory_slots) + dropout = dropout or 0 + + -- there will be 2*n+1 inputs + local inputs = {} + table.insert(inputs, nn.Identity()()) -- x + for L = 1,n do + table.insert(inputs, nn.Identity()()) -- prev_c[L] + table.insert(inputs, nn.Identity()()) -- prev_h[L] + end + + local x, input_size_L + local outputs = {} + for L = 1,n do + -- c,h from previos timesteps + local prev_h = inputs[L*2+1] + local prev_c = inputs[L*2] + -- the input to this layer + if L == 1 then + x = OneHot(input_size)(inputs[1]) + --x = nn.Print(x) + input_size_L = input_size + else + x = outputs[(L-1)*2] + if dropout > 0 then x = nn.Dropout(dropout)(x) end -- apply dropout, if any + input_size_L = rnn_size + end + -- evaluate the input sums at once for efficiency + local i2h = nn.Linear(input_size_L, 4 * rnn_size)(x):annotate{name='i2h_'..L} + local h2h = nn.Linear(rnn_size, 4 * rnn_size)(prev_h):annotate{name='h2h_'..L} + local all_input_sums = nn.CAddTable()({i2h, h2h}) + local reshaped = nn.Reshape(4, rnn_size)(all_input_sums) + local n1, n2, n3, n4 = nn.SplitTable(2)(reshaped):split(4) + -- decode the gates + local in_gate = nn.Sigmoid()(n1) + local forget_gate = nn.Sigmoid()(n2) + local out_gate = nn.Sigmoid()(n3) + -- decode the write inputs + local in_transform = nn.Tanh()(n4) + -- perform the LSTM update + local next_c = LSTMEX.WriteEraseHead( + 'WriteErase', + in_gate, + forget_gate, + in_transform, + rnn_size, --| + rnn_size, --| HUooooooooooH! 
+ rnn_size, --| + prev_c, + memory_slots, + rnn_size + ) + next_c = nn.PrintTensor(10,"Memory")(next_c) + --next_c = nn.Print("next_c")(next_c) + -- erase controlled by forget gate + --local erased_c = LSTMEX.EraseHead('Erase',forget_gate,rnn_size,prev_c,memory_slots,rnn_size) + + -- write controlled by input gate + --local next_c = LSTMEX.WriteHead('Write',in_gate,rnn_size,in_transform,erased_c,memory_slots,rnn_size) + --local next_c = nn.CAddTable()({ + -- nn.CMulTable()({forget_gate, prev_c}), + -- nn.CMulTable()({in_gate, in_transform}) + -- }) + --next_c = nn.Print()(next_c) + -- read controlled by output gate + local next_h = LSTMEX.ReadHead('Read', out_gate, rnn_size, next_c, memory_slots, rnn_size) + next_h = nn.Tanh()(next_h) + --local next_h = nn.CMulTable()({out_gate, nn.Tanh()(next_c)}) + + table.insert(outputs, next_c) + table.insert(outputs, next_h) + end + + -- set up the decoder + local top_h = outputs[#outputs] + if dropout > 0 then top_h = nn.Dropout(dropout)(top_h) end + local proj = nn.Linear(rnn_size, input_size)(top_h):annotate{name='decoder'} + local logsoft = nn.LogSoftMax()(proj) + table.insert(outputs, logsoft) + + return nn.gModule(inputs, outputs) +end + +return LSTMEX diff --git a/model/NTM.lua b/model/NTM.lua new file mode 100644 index 00000000..f259ed88 --- /dev/null +++ b/model/NTM.lua @@ -0,0 +1,527 @@ +--[[ + + Implementation of the Neural Turing Machine described here: + + http://arxiv.org/pdf/1410.5401v2.pdf + + Variable names take after the notation in the paper. Identifiers with "r" + appended indicate read-head variables, and likewise for those with "w" appended. + + The NTM take a configuration table at initialization time with the following + options: + + * input_dim dimension of input vectors (required) + * output_dim dimension of output vectors (required) + * mem_rows number of rows of memory + * mem_cols number of columns of memory + * cont_dim dimension of controller state + * cont_layers number of controller layers + * shift_range allowed range for shifting read/write weights + * write_heads number of write heads + * read_heads number of read heads + +--]] +require "./SmoothCosineSimilarity" +require "./ScalarMulTable" +require "./CircularConvolution" +require "./PowTable" +require "./NormalizeBySum" +require "./OuterProd" +require "./Squeeze" + +local function share_params(cell, src, ...) + for i = 1, #cell.forwardnodes do + local node = cell.forwardnodes[i] + if node.data.module then + node.data.module:share(src.forwardnodes[i].data.module, ...) 
+ end + end +end + +local NTM, parent = torch.class('nn.NTM', 'nn.Module') + +function NTM:__init(config) + self.input_dim = config.input_dim or error('config.input_dim must be specified') + self.output_dim = config.output_dim or error('config.output_dim must be specified') + self.mem_rows = config.mem_rows or 20 + self.mem_cols = config.mem_cols or 128 + self.cont_dim = config.cont_dim or 128 + self.cont_layers = config.cont_layers or 1 + self.shift_range = config.shift_range or 1 + self.write_heads = config.write_heads or 1 + self.read_heads = config.read_heads or 1 + + self.depth = 0 + self.cells = {} + self.master_cell = self:new_cell() + self.init_module = self:new_init_module() + + self:init_grad_inputs() +end + +function NTM:init_grad_inputs() + local ww_gradInput + if self.write_heads == 1 then + ww_gradInput = torch.zeros(self.mem_rows) + else + ww_gradInput = {} + for i = 1, self.write_heads do + ww_gradInput[i] = torch.zeros(self.mem_rows) + end + end + + local wr_gradInput, r_gradInput + if self.read_heads == 1 then + wr_gradInput = torch.zeros(self.mem_rows) + r_gradInput = torch.zeros(self.mem_cols) + else + wr_gradInput, r_gradInput = {}, {} + for i = 1, self.read_heads do + wr_gradInput[i] = torch.zeros(self.mem_rows) + r_gradInput[i] = torch.zeros(self.mem_cols) + end + end + + local m_gradInput, c_gradInput + if self.cont_layers == 1 then + m_gradInput = torch.zeros(self.cont_dim) + c_gradInput = torch.zeros(self.cont_dim) + else + m_gradInput, c_gradInput = {}, {} + for i = 1, self.cont_layers do + m_gradInput[i] = torch.zeros(self.cont_dim) + c_gradInput[i] = torch.zeros(self.cont_dim) + end + end + + self.gradInput = { + torch.zeros(self.input_dim), -- input + torch.zeros(self.mem_rows, self.mem_cols), -- M + wr_gradInput, + ww_gradInput, + r_gradInput, + m_gradInput, + c_gradInput + } +end + +-- The initialization module initializes the state of NTM memory, +-- read/write weights, and the state of the LSTM controller. +function NTM:new_init_module() + local dummy = nn.Identity()() -- always zero + local output_init = nn.Tanh()(nn.Linear(1, self.input_dim)(dummy)) + + -- memory + local M_init_lin = nn.Linear(1, self.mem_rows * self.mem_cols) + local M_init = nn.View(self.mem_rows, self.mem_cols)( + nn.Tanh()(M_init_lin(dummy))) + + -- read weights + local wr_init, r_init = {}, {} + for i = 1, self.read_heads do + local wr_init_lin = nn.Linear(1, self.mem_rows) + wr_init[i] = nn.SoftMax()(wr_init_lin(dummy)) + r_init[i] = nn.Tanh()(nn.Linear(1, self.mem_cols)(dummy)) + + -- We initialize the read and write distributions such that the + -- weights decay exponentially over the rows of NTM memory. + -- This sort of initialization seems to be important in my experiments (kst). 
+ wr_init_lin.bias:copy(torch.range(self.mem_rows, 1, -1)) + end + + -- write weights + local ww_init = {} + for i = 1, self.write_heads do + local ww_init_lin = nn.Linear(1, self.mem_rows) + ww_init[i] = nn.SoftMax()(ww_init_lin(dummy)) + + -- See initialization comment above + ww_init_lin.bias:copy(torch.range(self.mem_rows, 1, -1)) + end + + -- controller state + local m_init, c_init = {}, {} + for i = 1, self.cont_layers do + m_init[i] = nn.Tanh()(nn.Linear(1, self.cont_dim)(dummy)) + c_init[i] = nn.Tanh()(nn.Linear(1, self.cont_dim)(dummy)) + end + + -- wrap tables as nngraph nodes + ww_init = nn.Identity()(ww_init) + wr_init = nn.Identity()(wr_init) + r_init = nn.Identity()(r_init) + m_init = nn.Identity()(m_init) + c_init = nn.Identity()(c_init) + + local inits = { + output_init, M_init, wr_init, ww_init, r_init, m_init, c_init + } + return nn.gModule({dummy}, inits) +end + +-- Create a new NTM cell. Each cell shares the parameters of the "master" cell +-- and stores the outputs of each iteration of forward propagation. +function NTM:new_cell() + -- input to the network + local input = nn.LookupTable(self.input_dim,self.input_dim)() + local inn = nn.Squeeze()(input) + + -- previous memory state and read/write weights + local M_p = nn.Identity()() + local wr_p = nn.Identity()() + local ww_p = nn.Identity()() + + -- vector read from memory + local r_p = nn.Identity()() + + -- LSTM controller output + local mtable_p = nn.Identity()() + local ctable_p = nn.Identity()() + + -- output and hidden states of the controller module + local mtable, ctable = self:new_controller_module(inn, r_p, mtable_p, ctable_p) + local m = (self.cont_layers == 1) and mtable + or nn.SelectTable(self.cont_layers)(mtable) + local M, wr, ww, r = self:new_mem_module(M_p, wr_p, ww_p, m) + local output = self:new_output_module(m) + + local inputs = {input, M_p, wr_p, ww_p, r_p, mtable_p, ctable_p} + local outputs = {output, M, wr, ww, r, mtable, ctable} + + local cell = nn.gModule(inputs, outputs) + if self.master_cell ~= nil then + share_params(cell, self.master_cell, 'weight', 'bias', 'gradWeight', 'gradBias') + end + return cell +end + +-- Create a new LSTM controller +function NTM:new_controller_module(input, r_p, mtable_p, ctable_p) + -- multilayer LSTM + local mtable, ctable = {}, {} + for layer = 1, self.cont_layers do + local new_gate, m_p, c_p + if self.cont_layers == 1 then + m_p = mtable_p + c_p = ctable_p + else + m_p = nn.SelectTable(layer)(mtable_p) + c_p = nn.SelectTable(layer)(ctable_p) + end + + if layer == 1 then + new_gate = function() + local in_modules = { + nn.Linear(self.input_dim, self.cont_dim)(input), + nn.Linear(self.cont_dim, self.cont_dim)(m_p) + } + if self.read_heads == 1 then + table.insert(in_modules, nn.Linear(self.mem_cols, self.cont_dim)(r_p)) + else + for i = 1, self.read_heads do + local vec = nn.SelectTable(i)(r_p) + table.insert(in_modules, nn.Linear(self.mem_cols, self.cont_dim)(vec)) + end + end + return nn.CAddTable()(in_modules) + end + else + new_gate = function() + return nn.CAddTable(){ + nn.Linear(self.cont_dim, self.cont_dim)(mtable[layer - 1]), + nn.Linear(self.cont_dim, self.cont_dim)(m_p) + } + end + end + + -- input, forget, and output gates + local i = nn.Sigmoid()(new_gate()) + local f = nn.Sigmoid()(new_gate()) + local o = nn.Sigmoid()(new_gate()) + local update = nn.Tanh()(new_gate()) + + -- update the state of the LSTM cell + ctable[layer] = nn.CAddTable(){ + nn.CMulTable(){f, c_p}, + nn.CMulTable(){i, update} + } + + mtable[layer] = nn.CMulTable(){o, 
nn.Tanh()(ctable[layer])} + end + + mtable = nn.Identity()(mtable) + ctable = nn.Identity()(ctable) + return mtable, ctable +end + +-- Create a new module to read/write to memory +function NTM:new_mem_module(M_p, wr_p, ww_p, m) + -- read heads + local wr, r + if self.read_heads == 1 then + wr, r = self:new_read_head(M_p, wr_p, m) + else + wr, r = {}, {} + for i = 1, self.read_heads do + local prev_weights = nn.SelectTable(i)(wr_p) + wr[i], r[i] = self:new_read_head(M_p, prev_weights, m) + end + wr = nn.Identity()(wr) + r = nn.Identity()(r) + end + + -- write heads + local ww, a, e, M_erase, M_write + if self.write_heads == 1 then + ww, a, e = self:new_write_head(M_p, ww_p, m) + M_erase = nn.AddConstant(1)(nn.MulConstant(-1)(nn.OuterProd(){ww, e})) + M_write = nn.OuterProd(){ww, a} + else + ww, a, e, M_erase, M_write = {}, {}, {}, {}, {} + for i = 1, self.write_heads do + local prev_weights = nn.SelectTable(i)(ww_p) + ww[i], a[i], e[i] = self:new_write_head(M_p, prev_weights, m) + M_erase[i] = nn.AddConstant(1)(nn.MulConstant(-1)(nn.OuterProd(){ww[i], e[i]})) + M_write[i] = nn.OuterProd(){ww[i], a[i]} + end + M_erase = nn.CMulTable()(M_erase) + M_write = nn.CAddTable()(M_write) + ww = nn.Identity()(ww) + end + + -- erase some history from memory + --M_erase = nn.PrintTensor(1,"EraseMemory")(M_erase) + local Mtilde = nn.CMulTable(){M_p, M_erase} + --M_write = nn.PrintTensor(1,"WriteMemory")(M_write) + -- write to memory + local M = nn.CAddTable(){Mtilde, M_write} + + M = nn.PrintTensor(50,"Memory")(M) + return M, wr, ww, r +end + +function NTM:new_read_head(M_p, w_p, m) + return self:new_head(M_p, w_p, m, true) +end + +function NTM:new_write_head(M_p, w_p, m) + return self:new_head(M_p, w_p, m, false) +end + +-- Create a new head +function NTM:new_head(M_p, w_p, m, is_read) + -- key vector + local k = nn.Tanh()(nn.Linear(self.cont_dim, self.mem_cols)(m)) + -- circular convolution kernel + local s = nn.SoftMax()(nn.Linear(self.cont_dim, 2 * self.shift_range + 1)(m)) + -- weight sharpening parameter + local beta = nn.SoftPlus()(nn.Linear(self.cont_dim, 1)(m)) + -- gating parameter + local g = nn.Sigmoid()(nn.Linear(self.cont_dim, 1)(m)) + -- exponential focusing parameter + local gamma = nn.AddConstant(1)( + nn.SoftPlus()(nn.Linear(self.cont_dim, 1)(m))) + + local sim = nn.SmoothCosineSimilarity(){M_p, k} + local wc = nn.SoftMax()(nn.ScalarMulTable(){sim, beta}) + local wg = nn.CAddTable(){ + nn.ScalarMulTable(){wc, g}, + nn.ScalarMulTable(){w_p, nn.AddConstant(1)(nn.MulConstant(-1)(g))} + } + + local wtilde = nn.CircularConvolution(){wg, s} + local wpow = nn.PowTable(){wtilde, gamma} + local w = nn.Normalize(2)(wpow) + + if is_read then + local r = nn.MixtureTable(){w, M_p} + return w, r + else + local a = nn.Tanh()(nn.Linear(self.cont_dim, self.mem_cols)(m)) + local e = nn.Sigmoid()(nn.Linear(self.cont_dim, self.mem_cols)(m)) + return w, a, e + end +end + +-- Create an output module, e.g. to output binary strings. +function NTM:new_output_module(m) + local output = nn.LogSoftMax()(nn.Linear(self.cont_dim, self.output_dim)(m)) + return output +end + +function NTM:r() + self.ddepth = 0 + self.dcell = nil + self.dprev_outputs = nil +end +-- Forward propagate one time step. The outputs of previous time steps are +-- cached for backpropagation. 
+function NTM:f(input) + self.ddepth = self.ddepth or 0 + self.ddepth = self.ddepth + 1 + self.dcell = self.dcell or self:new_cell() + + local prev_outputs + if self.ddepth == 1 then + self.dprev_outputs = self.init_module:forward(torch.Tensor{0}) + else + self.dprev_outputs = self.dcell.output + end + + -- get inputs + local inputs = {input} + for i = 2, #self.dprev_outputs do + inputs[i] = self.dprev_outputs[i] + end + --print('F',inputs) + local outputs = self.dcell:forward(inputs) + return outputs[1] +end + +function NTM:set_last_state() + self.prev_output = self.cells[self.depth].output +end +-- Forward propagate one time step. The outputs of previous time steps are +-- cached for backpropagation. +function NTM:forward(input) + self.depth = self.depth + 1 + local cell = self.cells[self.depth] + --print('forward depth', self.depth) + if cell == nil then + cell = self:new_cell() + self.cells[self.depth] = cell + end + + local prev_outputs + if self.depth == 1 then + --if self.prev_output == nil then + prev_outputs = self.init_module:forward(torch.Tensor{0}) + -- self.prev_output = prev_outputs + --else + --print('use previous state') + --print(self.prev_output[5][{{1,10}}]) + --prev_outputs = self.prev_output + --end + else + prev_outputs = self.cells[self.depth - 1].output + end + + -- get inputs + local inputs = {input} + for i = 2, #prev_outputs do + inputs[i] = prev_outputs[i] + end + --print('FORWARD',inputs) + local outputs = cell:forward(inputs) + self.output = outputs[1] + return self.output +end + +-- Backward propagate one time step. Throws an error if called more times than +-- forward has been called. +function NTM:backward(input, grad_output) + if self.depth == 0 then + error("No cells to backpropagate through") + end + local cell = self.cells[self.depth] + local grad_outputs = {grad_output} + for i = 2, #self.gradInput do + grad_outputs[i] = self.gradInput[i] + end + + -- get inputs + local prev_outputs + if self.depth == 1 then + --print("remember previous state") + --print(self.prev_output[5][{{1,10}}]) + --prev_outputs = self.prev_output + prev_outputs = self.init_module:forward(torch.Tensor{0}) + else + prev_outputs = self.cells[self.depth - 1].output + end + + local inputs = {input} + for i = 2, #prev_outputs do + inputs[i] = prev_outputs[i] + end + + self.gradInput = cell:backward(inputs, grad_outputs) + self.depth = self.depth - 1 + + if self.depth == 0 then + + self.gradInput[1] = torch.zeros(self.input_dim) -- fix TODO: hm hm hm + + self.init_module:backward(torch.Tensor{0}, self.gradInput) + for i = 1, #self.gradInput do + local gradInput = self.gradInput[i] + if type(gradInput) == 'table' then + for _, t in pairs(gradInput) do + --print(t) + t:zero() + end + else + self.gradInput[i]:zero() + end + end + end + return self.gradInput +end + +-- Get the state of memory +function NTM:get_memory(depth) + if self.depth == 0 then + return self.initial_values[2] + end + local depth = depth or self.depth + return self.cells[self.depth].output[2] +end + +-- Get read head weights over the rows of memory +function NTM:get_read_weights(depth) + if self.depth == 0 then + return self.initial_values[3] + end + local depth = depth or self.depth + return self.cells[depth].output[3] +end + +-- Get write head weights over the rows of memory +function NTM:get_write_weights(depth) + if self.depth == 0 then + return self.initial_values[4] + end + local depth = depth or self.depth + return self.cells[depth].output[4] +end + +-- Get the vector read from memory +function 
NTM:get_read_vector(depth) + if self.depth == 0 then + return self.initial_values[5] + end + local depth = depth or self.depth + return self.cells[depth].output[5] +end + +function NTM:parameters() + local p, g = self.master_cell:parameters() + local pi, gi = self.init_module:parameters() + tablex.insertvalues(p, pi) + tablex.insertvalues(g, gi) + return p, g +end + +function NTM:forget() + self.depth = 0 + self:zeroGradParameters() + for i = 1, #self.gradInput do + --print(self.gradInput[i]) + self.gradInput[i]:zero() + end +end + +function NTM:zeroGradParameters() + self.master_cell:zeroGradParameters() + self.init_module:zeroGradParameters() +end diff --git a/model/NormalizeBySum.lua b/model/NormalizeBySum.lua new file mode 100644 index 00000000..69d5e0e7 --- /dev/null +++ b/model/NormalizeBySum.lua @@ -0,0 +1,30 @@ +--[[ + + Divides each element of a Tensor by their sum. + +--]] + +local NormalizeBySum, parent = torch.class('nn.NormalizeBySum', 'nn.Module') + +function NormalizeBySum:__init() + parent.__init(self) +end + +function NormalizeBySum:updateOutput(input) + self.output:resizeAs(input):copy(input) + self.sum = input:sum() + self.output:div(self.sum) + return self.output +end + +function NormalizeBySum:updateGradInput(input, gradOutput) + local size = input:size(1) + self.gradInput:resizeAs(input) + for i = 1, size do + local output = torch.Tensor(size):copy(self.output) + output:div(-self.sum) + output[i] = output[i] + (1 / self.sum) + self.gradInput[i] = gradOutput:dot(output) + end + return self.gradInput +end diff --git a/model/OuterProd.lua b/model/OuterProd.lua new file mode 100644 index 00000000..01511c4f --- /dev/null +++ b/model/OuterProd.lua @@ -0,0 +1,69 @@ +--[[ + + Input: a table of 2 or 3 vectors. + + Output: the outer product of the vectors. 
+ +--]] + +local OuterProd, parent = torch.class('nn.OuterProd', 'nn.Module') + +function OuterProd:__init() + parent.__init(self) + self.gradInput = {} +end + +function OuterProd:updateOutput(input) + local order = #input + self.order = order + if order == 2 then + self.output:set(torch.ger(input[1], input[2])) + self.size = self.output:size() + elseif order == 3 then + -- allocate + self.size = torch.LongStorage(order) + local idx = 1 + for i = 1, order do + self.size[i] = input[i]:size(1) + end + self.output:resize(self.size):zero() + + local u, v, w = unpack(input) + local uv = torch.ger(u, v) + for i = 1, self.size[3] do + self.output[{{}, {}, i}]:add(w[i], uv) + end + else + error('outer products of order higher than 3 unsupported') + end + return self.output +end + +function OuterProd:updateGradInput(input, gradOutput) + local order = #input + for i = 1, order do + self.gradInput[i] = self.gradInput[i] or input[1].new() + self.gradInput[i]:resizeAs(input[i]) + end + + if order == 2 then + self.gradInput[1]:copy(gradOutput * input[2]) + self.gradInput[2]:copy(gradOutput:t() * input[1]) + else + local u, v, w = unpack(input) + local du, dv, dw = u:size(1), v:size(1), w:size(1) + local uv = input[1].new():resize(du, dv):zero() + for i = 1, dw do + uv:add(w[i], gradOutput[{{}, {}, i}]) + end + self.gradInput[1]:copy(uv * input[2]) + self.gradInput[2]:copy(uv:t() * input[1]) + + local vw = input[1].new():resize(dv, dw):zero() + for i = 1, du do + vw:add(u[i], gradOutput[{i, {}, {}}]) + end + self.gradInput[3]:copy(vw:t() * input[2]) + end + return self.gradInput +end diff --git a/model/PowTable.lua b/model/PowTable.lua new file mode 100644 index 00000000..7555865c --- /dev/null +++ b/model/PowTable.lua @@ -0,0 +1,37 @@ +--[[ + + Input: A table {x, y} of a Tensor x and a scalar y. + + Output: x^y (elementwise) + +--]] + +local PowTable, parent = torch.class('nn.PowTable', 'nn.Module') + +function PowTable:__init() + parent.__init(self) + self.gradInput = {} +end + +function PowTable:updateOutput(input) + local v, p = unpack(input) + return self.output:set(torch.pow(v, p[1])) +end + +function PowTable:updateGradInput(input, gradOutput) + local v, p = unpack(input) + p = p[1] + self.gradInput[1] = self.gradInput[1] or input[1].new() + self.gradInput[2] = self.gradInput[2] or input[2].new() + self.gradInput[2]:resizeAs(input[2]) + + self.gradInput[1]:set(torch.cmul(gradOutput, torch.pow(v, p - 1)) * p) + local pgrad = 0 + for i = 1, v:size(1) do + if v[i] > 0 then + pgrad = pgrad + math.log(v[i]) * self.output[i] * gradOutput[i] + end + end + self.gradInput[2][1] = pgrad + return self.gradInput +end diff --git a/model/Print.lua b/model/Print.lua new file mode 100644 index 00000000..d9bb185f --- /dev/null +++ b/model/Print.lua @@ -0,0 +1,99 @@ +require "gnuplot" +--[[ + An Identity layer that prints its input. 
+--]] +local Print, parent = torch.class('nn.Print', 'nn.Module') +function Print:__init(label) + parent:__init(self) + self.label = label +end +function Print:updateOutput(input) + self.output = input + if self.label ~= nil then + print(self.label) + end + print(input) + return self.output +end +function Print:updateGradInput(input, gradOutput) + self.gradInput = gradOutput + return self.gradInput +end + +--print tensor size +local PrintSize, parent = torch.class('nn.PrintSize', 'nn.Module') +function PrintSize:__init(label) + parent:__init(self) + self.label = label +end +function PrintSize:updateOutput(input) + self.output = input + local sizes = {} + local size = input:size() + for i=1,input:nDimension() do + table.insert(sizes, size[i]) + end + print(string.format("%s -> %s", self.label,table.concat(sizes,"x"))) + return self.output +end +function PrintSize:updateGradInput(input, gradOutput) + self.gradInput = gradOutput + return self.gradInput +end + + + +local PrintAddress, parent = torch.class('nn.PrintAddress', 'nn.Module') + +function PrintAddress:__init(label) + parent:__init(self) + self.label = label + self.look = 0 +end + +function PrintAddress:updateOutput(input) + self.output = input + local v, index = torch.max(input[1],1) + v, index = v[1], index[1] + if self.look ~= index then + print(string.format('%s moved from %d to %d %.2f',self.label,self.look,index,v)) + self.look = index + end + return self.output +end + + +function PrintAddress:updateGradInput(input, gradOutput) + self.gradInput = gradOutput + return self.gradInput +end + +local PrintTensor, parent = torch.class('nn.PrintTensor', 'nn.Module') + +function PrintTensor:__init(interval, label) + parent:__init(self) + self.label = label + self.interval = interval + self.count = 0 +end + +function PrintTensor:updateOutput(input) + self.output = input + --gnuplot.figure(1) + self.count = self.count + 1 + if self.count % self.interval == 0 then + gnuplot.title(self.label) + if input:nDimension() == 3 then + gnuplot.imagesc(input[1],'memory') + else + gnuplot.imagesc(input ,'memory') + end + end + return self.output +end + + +function PrintTensor:updateGradInput(input, gradOutput) + self.gradInput = gradOutput + return self.gradInput +end diff --git a/model/ScalarDivTable.lua b/model/ScalarDivTable.lua new file mode 100644 index 00000000..93a3ce3f --- /dev/null +++ b/model/ScalarDivTable.lua @@ -0,0 +1,32 @@ +--[[ + +Input: A table {x, y} of a Tensor and a scalar. + +Output: x / y + +--]] + + +local ScalarDivTable, parent = torch.class('nn.ScalarDivTable', 'nn.Module') + +function ScalarDivTable:__init() + parent.__init(self) + self.gradInput = {} +end + +function ScalarDivTable:updateOutput(input) + local v, scale = unpack(input) + return self.output:set(v / scale[1]) +end + +function ScalarDivTable:updateGradInput(input, gradOutput) + local v, scale = unpack(input) + self.gradInput[1] = self.gradInput[1] or input[1].new() + self.gradInput[2] = self.gradInput[2] or input[2].new() + self.gradInput[2]:resizeAs(input[2]) + + local c = scale[1] + self.gradInput[1]:set(gradOutput / c) + self.gradInput[2][1] = -gradOutput:dot(v) / (c * c) + return self.gradInput +end diff --git a/model/ScalarMulTable.lua b/model/ScalarMulTable.lua new file mode 100644 index 00000000..a85508c8 --- /dev/null +++ b/model/ScalarMulTable.lua @@ -0,0 +1,30 @@ +--[[ + + Input: A table {x, y} of a Tensor x and a scalar y. 
+ + Output: x * y + +--]] + +local ScalarMulTable, parent = torch.class('nn.ScalarMulTable', 'nn.Module') + +function ScalarMulTable:__init() + parent.__init(self) + self.gradInput = {} +end + +function ScalarMulTable:updateOutput(input) + local v, scale = unpack(input) + return self.output:set(v * scale[1]) +end + +function ScalarMulTable:updateGradInput(input, gradOutput) + local v, scale = unpack(input) + self.gradInput[1] = self.gradInput[1] or input[1].new() + self.gradInput[2] = self.gradInput[2] or input[2].new() + self.gradInput[2]:resizeAs(input[2]) + + self.gradInput[1]:set(gradOutput * scale[1]) + self.gradInput[2][1] = gradOutput:dot(v) + return self.gradInput +end diff --git a/model/SmoothCosineSimilarity.lua b/model/SmoothCosineSimilarity.lua new file mode 100644 index 00000000..39b1b49e --- /dev/null +++ b/model/SmoothCosineSimilarity.lua @@ -0,0 +1,58 @@ +--[[ + +Input: a table of two inputs {M, k}, where + M = an n-by-m matrix + k = an m-dimensional vector + +Output: an n-dimensional vector + +Each element is an approximation of the cosine similarity between k and the +corresponding row of M. It's an approximation since we add a constant to the +denominator of the cosine similarity function to remove the singularity when +one of the inputs is zero. + +--]] + +local SmoothCosineSimilarity, parent = torch.class('nn.SmoothCosineSimilarity', 'nn.Module') + +function SmoothCosineSimilarity:__init(smoothen) + parent.__init(self) + self.gradInput = {} + self.smooth = smoothen or 1e-3 +end + +function SmoothCosineSimilarity:updateOutput(input) + local M, k = unpack(input) + self.rownorms = torch.cmul(M, M):sum(2):sqrt():view(M:size(1)) + self.knorm = math.sqrt(k:dot(k)) + self.dot = M * k + self.output:set(torch.cdiv(self.dot, self.rownorms * self.knorm + self.smooth)) + return self.output +end + +function SmoothCosineSimilarity:updateGradInput(input, gradOutput) + local M, k = unpack(input) + self.gradInput[1] = self.gradInput[1] or input[1].new() + self.gradInput[2] = self.gradInput[2] or input[2].new() + + -- M gradient + local rows = M:size(1) + local Mgrad = self.gradInput[1] + Mgrad:set(k:repeatTensor(rows, 1)) + for i = 1, rows do + if self.rownorms[i] > 0 then + Mgrad[i]:add(-self.output[i] * self.knorm / self.rownorms[i], M[i]) + end + Mgrad[i]:mul(gradOutput[i] / (self.rownorms[i] * self.knorm + self.smooth)) + end + + -- k gradient + self.gradInput[2]:set(M:t() * torch.cdiv(gradOutput, self.rownorms * self.knorm + self.smooth)) + if self.knorm > 0 then + local scale = torch.cmul(self.output, self.rownorms) + :cdiv(self.rownorms * self.knorm + self.smooth) + :dot(gradOutput) / self.knorm + self.gradInput[2]:add(-scale, k) + end + return self.gradInput +end diff --git a/model/Squeeze.lua b/model/Squeeze.lua new file mode 100644 index 00000000..e9b15160 --- /dev/null +++ b/model/Squeeze.lua @@ -0,0 +1,12 @@ +local Squeeze, parent = torch.class('nn.Squeeze', 'nn.Module') + +function Squeeze:updateOutput(input) + self.size = input:size() + self.output = input:squeeze() + return self.output +end + +function Squeeze:updateGradInput(input, gradOutput) + self.gradInput = gradOutput:view(self.size) + return self.gradInput +end diff --git a/model/TT.lua b/model/TT.lua new file mode 100644 index 00000000..d774b86d --- /dev/null +++ b/model/TT.lua @@ -0,0 +1,171 @@ +require 'nn' + +local TensorTrain, parent = torch.class('nn.TensorTrain', 'nn.Module') + +function TensorTrain:__init(outChannels, outHeight, outWidth) + parent.__init(self) + + self.weight = torch.Tensor() + self.bias = 
torch.Tensor() + self.gradWeight = torch.Tensor() + self.gradBias = torch.Tensor() + + self.outHeight = outHeight + self.outWidth = outWidth + self.outChannels = outChannels + + self.W = { + n = nil, + m = nil, + tt = {core = nil, ps = nil, r = nil}, + mul = function(a, b) + local n=a.n; local m=a.m; local tt=a.tt; local cra=tt.core; local d=tt.d; local ps=tt.ps; local r=tt.r; + local rb=b:size(2); + local c=torch.view(b,torch.cat(m:t(),rb, 1):long()); + + for k=1,d do + local cr=cra:sub(ps(k),ps[k+1]-1); + cr=torch.view(cr,r[k],n[k],m[k],r[k+1]); + cr=torch.permute(cr,2,4,1,3); cr=torch.view(cr,n[k]*r[k+1],r[k]*m[k]); + local M=c:nElement(); + c=torch.view(c,r[k]*m[k],M/(r[k]*m[k])); + c=cr*c; c=torch.view(c,n[k],c:nElement()/n[k]); + c=torch.permute(c,2,1); + end + c=c:view(-1); c=torch.view(c,rb,c:nElement()/rb); + c=c:t(); + return c + end, + + t = function(tt) + local t = tt.tt; + local m = tt.m; + local n = tt.n; + local d = t.d; + local r = t.r; + for i=1,d do + local cr = t[i] + cr = torch.view(cr, r[i], n[i], m[i], r[i+1]); + cr = torch.permute(cr, 1, 3, 2, 4); + t[i] = torch.view(cr, r[i], m[i]*n[i], r[i+1]); + + end + local tt1 = {mul = tt.mul, t = tt.t, rank = tt.rank, tocell = tt.tocell}; + tt1.tt = t; + tt1.m=tt.n; + tt1.n=tt.m; + return tt1 + end, + + rank = function(a) + return r=a.tt.r; + end, + + tocell = function(tt) + local d = tt.tt.d; + local cc = {} + local n = tt.n; + local m = tt.m; + local r = tt.tt.r; + local ps = tt.tt.ps; + local cr = tt.tt.core; + for i=1:d do + cc[i] = torch.view(cr:sub(ps(i),(ps(i+1)-1)), r[i], n[i], m[i], r[i+1]); + end + return cc + end + } + + --TODO: should the constructor arguments resemble Linear or SpatialConvolution? + --TODO: self:reset() +end + +function TensorTrain:updateOutput(input) + assert(input:dim() == 4) + + local inHeight, inWidth, inChannels, batchSize = input:size(1), input:size(2), input:size(3), input:size(4) + + self.output = W:mul(torch.view(input, -1, batchSize)) + if self.bias:nElement() > 0 then + self.output:add(torch.view(self.bias, self.outHeight, self.outWidth, self.outChannels, 1):expandAs(self.output)) + end + self.output = torch.view(self.output, self.outHeight, self.outWidth, self.outChannels, batchSize) + return self.output +end + +function TensorTrain:updateGradInput(input, gradOutput) + local inHeight, inWidth, inChannels, batchSize = input:size(1), input:size(2), input:size(3), input:size(4) + + self.gradInput = W:t():mul(torch.view(self.gradInput, -1, batchSize)) + self.gradInput = torch.view(self.gradInput, inHeight, inWidth, inChannels, batchSize) + return self.gradInput +end + +function TensorTrain:accGradParameters(input, gradOutput, scale) + local inHeight, inWidth, inChannels, batchSize = input:size(1), input:size(2), input:size(3), input:size(4) + if self.bias:nElement() > 0 then + self.gradBias = self.gradInput:sum(4) + else + self.gradBias = [] + end + + local DZDWCore = input.new(W_core:size()):zero() + local rankArr = self.W:rank() + local corePos = W.ps + + local numDims = W.n:size(1) + local coreArr = W:tocell() + + local rightSum = torch.view(input, -1, batchSize) + rightSum = rightSum:t() + + local leftSum + for derDim = numDims, 1, -1 do + if derDim < numDims then + local rightDim = derDim + 1 + local sumSize = W.m[rightDim] * rankArr[rightDim+1] + local core = torch.view(coreArr[rightDim], -1, sumSize) + rightSum = torch.view(rightSum, -1, W.m[rightDim]) + rightSum = core * (torch.view(rightSum:t(), sumSize, -1)) + end + + if derDim >= 2 then + local core = 
torch.permute(coreArr[derDim-1], 1, 2, 4, 3) + core = torch.view(core, -1, W.m[derDim-1]) + + leftSum = torch.view(rightSum, rankArr[derDim+1]*torch.prod(W.n:sub(derDim+1, -1))*batchSize*torch.prod(W.m:sub(1, derDim-2)), torch.prod(W.m:sub(derDim-1, derDim))) + leftSum = core * torch.view(leftSum:t(), W.m[derDim-1], -1) + + local leftSumDims = torch.LongStorage{rankArr[derDim-1]*W.n[derDim-1], rankArr[derDim]*W.m[derDim]*rankArr[derDim+1], torch.prod(W.n:sub(derDim+1, -1))*batchSize, torch.prod(W.m:sub(1, derDim-2))} + leftSum = torch.view(leftSum, leftSumDims) + leftSum = torch.permute(leftSum, 1, 3, 2, 4) + + for leftDim = derDim-2:1,-1 do + local sumSize = W.m[leftDim] * rankArr[leftDim+1] + core = torch.view(coreArr[leftDim], -1, sumSize) + leftSum = torch.view(leftSum, -1, W.m[leftDim]) + leftSum = core * torch.view(leftSum:t(), sumSize, -1) + end + elseif derDim == 1 then + leftSum = torch.view(rightSum, rankArr[derDim+1], -1, batchSize, W.m[derDim]) + leftSum = torch.permute(leftSum, 2, 3, 4, 1) + else + error('Something bad happened(') + end + + local coreSize = rankArr[derDim] * W.n[derDim] * W.m[derDim] * rankArr[derDim+1] + local leftISize = torch.prod(W.n:sub(1, derDim-1)) + local rightISize = torch.prod(W.n:sub(derDim+1, -1)) + + local currout_dzdx = torch.view(self.gradInput, leftISize, W.n[derDim], rightISize*batchSize) + + currout_dzdx = torch.permute(currout_dzdx, 2, 1, 3) + local sumSize = leftISize * rightISize * batchSize + local der = torch.view(currout_dzdx, -1, sumSize) * torch.view(leftSum, sumSize, -1) + + der = torch.view(der, W.n[derDim], rankArr[derDim], W.m[derDim]*rankArr[derDim+1]) + der = torch.permute(der, 2, 1, 3) + DZDWCore:sub(corePos[derDim], corePos[derDim+1]-1) = der + end + self.gradWeight = DZDWCore +end diff --git a/sample.lua b/sample.lua index e22ece74..94b256e1 100644 --- a/sample.lua +++ b/sample.lua @@ -3,7 +3,7 @@ This file samples characters from a trained model -Code is based on implementation in +Code is based on implementation in https://github.com/oxford-cs-ml-2015/practical6 ]]-- @@ -94,18 +94,26 @@ for c,i in pairs(vocab) do ivocab[i] = c end -- initialize the rnn state to all zeros gprint('creating an ' .. checkpoint.opt.model .. '...') -local current_state -current_state = {} -for L = 1,checkpoint.opt.num_layers do - -- c and h for all layers - local h_init = torch.zeros(1, checkpoint.opt.rnn_size):double() - if opt.gpuid >= 0 and opt.opencl == 0 then h_init = h_init:cuda() end - if opt.gpuid >= 0 and opt.opencl == 1 then h_init = h_init:cl() end - table.insert(current_state, h_init:clone()) - if checkpoint.opt.model == 'lstm' then + +local current_state = {} +for L=1,opt.num_layers do + local h_init = torch.zeros(1, opt.rnn_size) + if opt.gpuid >=0 and opt.opencl == 0 then h_init = h_init:cuda() end + if opt.gpuid >=0 and opt.opencl == 1 then h_init = h_init:cl() end + if opt.model == 'lstmex' then + local m_init = torch.zeros(1, opt.lstmex_memory_slots ,opt.rnn_size) + if opt.gpuid >=0 and opt.opencl == 0 then m_init = m_init:cuda() end + if opt.gpuid >=0 and opt.opencl == 1 then m_init = m_init:cl() end + table.insert(current_state, m_init:clone()) + table.insert(current_state, h_init:clone()) + else + table.insert(current_state, h_init:clone()) + end + if opt.model == 'lstm' then table.insert(current_state, h_init:clone()) end end + state_size = #current_state -- do a few seeded timesteps @@ -113,7 +121,7 @@ local seed_text = opt.primetext if string.len(seed_text) > 0 then gprint('seeding with ' .. 
seed_text) gprint('--------------------------') - for c in seed_text:gmatch'.' do + for char_code, c in pairs(UTF8ToCharArray(seed_text)) do prev_char = torch.Tensor{vocab[c]} io.write(ivocab[prev_char[1]]) if opt.gpuid >= 0 and opt.opencl == 0 then prev_char = prev_char:cuda() end @@ -158,4 +166,3 @@ for i=1, opt.length do io.write(ivocab[prev_char[1]]) end io.write('\n') io.flush() - diff --git a/sample_ntm.lua b/sample_ntm.lua new file mode 100644 index 00000000..39954acd --- /dev/null +++ b/sample_ntm.lua @@ -0,0 +1,168 @@ + +--[[ + +This file samples characters from a trained model + +Code is based on implementation in +https://github.com/oxford-cs-ml-2015/practical6 + +]]-- + +require 'torch' +require 'nn' +require 'nngraph' +require 'optim' +require 'lfs' + +require 'util.OneHot' +require 'util.misc' +require 'model.Print' +require 'model.NTM' + +cmd = torch.CmdLine() +cmd:text() +cmd:text('Sample from a character-level language model') +cmd:text() +cmd:text('Options') +-- required: +cmd:argument('-model','model checkpoint to use for sampling') +-- optional parameters +cmd:option('-seed',123,'random number generator\'s seed') +cmd:option('-sample',1,' 0 to use max at each timestep, 1 to sample at each timestep') +cmd:option('-primetext',"",'used as a prompt to "seed" the state of the LSTM using a given sequence, before we sample.') +cmd:option('-length',2000,'number of characters to sample') +cmd:option('-temperature',1,'temperature of sampling') +cmd:option('-gpuid',0,'which gpu to use. -1 = use CPU') +cmd:option('-opencl',0,'use OpenCL (instead of CUDA)') +cmd:option('-verbose',1,'set to 0 to ONLY print the sampled text, no diagnostics') +cmd:text() + +-- parse input params +opt = cmd:parse(arg) + +-- gated print: simple utility function wrapping a print +function gprint(str) + if opt.verbose == 1 then print(str) end +end + +-- check that cunn/cutorch are installed if user wants to use the GPU +if opt.gpuid >= 0 and opt.opencl == 0 then + local ok, cunn = pcall(require, 'cunn') + local ok2, cutorch = pcall(require, 'cutorch') + if not ok then gprint('package cunn not found!') end + if not ok2 then gprint('package cutorch not found!') end + if ok and ok2 then + gprint('using CUDA on GPU ' .. opt.gpuid .. '...') + gprint('Make sure that your saved checkpoint was also trained with GPU. If it was trained with CPU use -gpuid -1 for sampling as well') + cutorch.setDevice(opt.gpuid + 1) -- note +1 to make it 0 indexed! sigh lua + cutorch.manualSeed(opt.seed) + else + gprint('Falling back on CPU mode') + opt.gpuid = -1 -- overwrite user setting + end +end + +-- check that clnn/cltorch are installed if user wants to use OpenCL +if opt.gpuid >= 0 and opt.opencl == 1 then + local ok, cunn = pcall(require, 'clnn') + local ok2, cutorch = pcall(require, 'cltorch') + if not ok then print('package clnn not found!') end + if not ok2 then print('package cltorch not found!') end + if ok and ok2 then + gprint('using OpenCL on GPU ' .. opt.gpuid .. '...') + gprint('Make sure that your saved checkpoint was also trained with GPU. If it was trained with CPU use -gpuid -1 for sampling as well') + cltorch.setDevice(opt.gpuid + 1) -- note +1 to make it 0 indexed! sigh lua + torch.manualSeed(opt.seed) + else + gprint('Falling back on CPU mode') + opt.gpuid = -1 -- overwrite user setting + end +end + +torch.manualSeed(opt.seed) + +-- load the model checkpoint +if not lfs.attributes(opt.model, 'mode') then + gprint('Error: File ' .. opt.model .. ' does not exist. 
Are you sure you didn\'t forget to prepend cv/ ?') +end +checkpoint = torch.load(opt.model) +protos = checkpoint.protos +protos.rnn:evaluate() -- put in eval mode so that dropout works properly + +-- initialize the vocabulary (and its inverted version) +local vocab = checkpoint.vocab +local ivocab = {} +for c,i in pairs(vocab) do ivocab[i] = c end + +-- initialize the rnn state to all zeros +gprint('creating an ' .. checkpoint.opt.model .. '...') + +local current_state = {} +print(opt) +for L=1,checkpoint.opt.num_layers do + local h_init = torch.zeros(1, checkpoint.opt.rnn_size) + if opt.gpuid >=0 and opt.opencl == 0 then h_init = h_init:cuda() end + if opt.gpuid >=0 and opt.opencl == 1 then h_init = h_init:cl() end + if opt.model == 'lstmex' then + local m_init = torch.zeros(1, checkpoint.opt.lstmex_memory_slots ,opt.rnn_size) + if opt.gpuid >=0 and opt.opencl == 0 then m_init = m_init:cuda() end + if opt.gpuid >=0 and opt.opencl == 1 then m_init = m_init:cl() end + table.insert(current_state, m_init:clone()) + table.insert(current_state, h_init:clone()) + else + table.insert(current_state, h_init:clone()) + end + if checkpoint.opt.model == 'lstm' then + table.insert(current_state, h_init:clone()) + end +end + +state_size = #current_state + +-- do a few seeded timesteps +local seed_text = opt.primetext +if string.len(seed_text) > 0 then + gprint('seeding with ' .. seed_text) + gprint('--------------------------') + for char_code, c in pairs(UTF8ToCharArray(seed_text)) do + prev_char = torch.Tensor{vocab[c]} + io.write(ivocab[prev_char[1]]) + if opt.gpuid >= 0 and opt.opencl == 0 then prev_char = prev_char:cuda() end + if opt.gpuid >= 0 and opt.opencl == 1 then prev_char = prev_char:cl() end + local lst = protos.rnn:forward{prev_char, unpack(current_state)} + -- lst is a list of [state1,state2,..stateN,output]. We want everything but last piece + current_state = {} + for i=1,state_size do table.insert(current_state, lst[i]) end + prediction = lst[#lst] -- last element holds the log probabilities + end +else + -- fill with uniform probabilities over characters (? hmm) + gprint('missing seed text, using uniform probability over first character') + gprint('--------------------------') + prediction = torch.Tensor(1, #ivocab):fill(1)/(#ivocab) + if opt.gpuid >= 0 and opt.opencl == 0 then prediction = prediction:cuda() end + if opt.gpuid >= 0 and opt.opencl == 1 then prediction = prediction:cl() end +end + +-- start sampling/argmaxing +for i=1, opt.length do + + -- log probabilities from the previous timestep + if opt.sample == 0 then + -- use argmax + local _, prev_char_ = prediction:squeeze():max(1) + prev_char = prev_char_:resize(1) + else + -- use sampling + prediction:div(opt.temperature) -- scale by temperature + local probs = torch.exp(prediction):squeeze() + probs:div(torch.sum(probs)) -- renormalize so probs sum to one + prev_char = torch.multinomial(probs:float(), 1):resize(1):float() + end + + -- forward the rnn for next character + prediction = protos.rnn:f(prev_char) + + io.write(ivocab[prev_char[1]]) +end +io.write('\n') io.flush() diff --git a/train.lua b/train.lua index 75dd09c7..f6f4e8c8 100644 --- a/train.lua +++ b/train.lua @@ -3,11 +3,11 @@ This file trains a character-level multi-layer RNN on text data -Code is based on implementation in +Code is based on implementation in https://github.com/oxford-cs-ml-2015/practical6 but modified to have multi-layer support, GPU support, as well as many other common model/optimization bells and whistles. 
-The practical6 code is in turn based on +The practical6 code is in turn based on https://github.com/wojciechz/learning_to_execute which is turn based on other stuff in Torch, etc... (long lineage) @@ -24,8 +24,11 @@ require 'util.misc' local CharSplitLMMinibatchLoader = require 'util.CharSplitLMMinibatchLoader' local model_utils = require 'util.model_utils' local LSTM = require 'model.LSTM' +local LSTMEX = require 'model.LSTMEX' +local LSTMNTM = require 'model.LSTMNTM' local GRU = require 'model.GRU' local RNN = require 'model.RNN' +require 'model.NTM' cmd = torch.CmdLine() cmd:text() @@ -33,19 +36,20 @@ cmd:text('Train a character-level language model') cmd:text() cmd:text('Options') -- data -cmd:option('-data_dir','data/tinyshakespeare','data directory. Should contain the file input.txt with input data') +cmd:option('-data_dir','data/ru','data directory. Should contain the file input.txt with input data') -- model params -cmd:option('-rnn_size', 128, 'size of LSTM internal state') -cmd:option('-num_layers', 2, 'number of layers in the LSTM') -cmd:option('-model', 'lstm', 'lstm,gru or rnn') +cmd:option('-rnn_size', 127, 'size of LSTM internal state') +cmd:option('-num_layers', 1, 'number of layers in the LSTM') +cmd:option('-lstmex_memory_slots', 3, 'number of LSTM internal memory slots') +cmd:option('-model', 'gru', 'lstm, gru or rnn, lstmex or ntm or lstmntm') -- optimization -cmd:option('-learning_rate',2e-3,'learning rate') +cmd:option('-learning_rate',5e-5,'learning rate') cmd:option('-learning_rate_decay',0.97,'learning rate decay') cmd:option('-learning_rate_decay_after',10,'in number of epochs, when to start decaying the learning rate') cmd:option('-decay_rate',0.95,'decay rate for rmsprop') cmd:option('-dropout',0,'dropout for regularization, used after each RNN hidden layer. 0 = no dropout') cmd:option('-seq_length',50,'number of timesteps to unroll for') -cmd:option('-batch_size',50,'number of sequences to train on in parallel') +cmd:option('-batch_size',128,'number of sequences to train on in parallel') cmd:option('-max_epochs',50,'number of full passes through the training data') cmd:option('-grad_clip',5,'clip gradients at this value') cmd:option('-train_frac',0.95,'fraction of data that goes into train set') @@ -57,7 +61,7 @@ cmd:option('-seed',123,'torch manual random number generator seed') cmd:option('-print_every',1,'how many steps/minibatches between printing out the loss') cmd:option('-eval_val_every',1000,'every how many iterations should we evaluate on validation data?') cmd:option('-checkpoint_dir', 'cv', 'output directory where checkpoints get written') -cmd:option('-savefile','lstm','filename to autosave the checkpont to. Will be inside checkpoint_dir/') +cmd:option('-savefile','model','filename to autosave the checkpont to. Will be inside checkpoint_dir/') -- GPU/CPU cmd:option('-gpuid',0,'which gpu to use. 
-1 = use CPU') cmd:option('-opencl',0,'use OpenCL (instead of CUDA)') @@ -68,10 +72,10 @@ opt = cmd:parse(arg) torch.manualSeed(opt.seed) -- train / val / test split for data, in fractions local test_frac = math.max(0, 1 - (opt.train_frac + opt.val_frac)) -local split_sizes = {opt.train_frac, opt.val_frac, test_frac} +local split_sizes = {opt.train_frac, opt.val_frac, test_frac} -- initialize cunn/cutorch for training on the GPU and fall back to CPU gracefully -if opt.gpuid >= 0 and opt.opencl == 0 then +if opt.gpuid >= 0 then local ok, cunn = pcall(require, 'cunn') local ok2, cutorch = pcall(require, 'cutorch') if not ok then print('package cunn not found!') end @@ -88,24 +92,6 @@ if opt.gpuid >= 0 and opt.opencl == 0 then end end --- initialize clnn/cltorch for training on the GPU and fall back to CPU gracefully -if opt.gpuid >= 0 and opt.opencl == 1 then - local ok, cunn = pcall(require, 'clnn') - local ok2, cutorch = pcall(require, 'cltorch') - if not ok then print('package clnn not found!') end - if not ok2 then print('package cltorch not found!') end - if ok and ok2 then - print('using OpenCL on GPU ' .. opt.gpuid .. '...') - cltorch.setDevice(opt.gpuid + 1) -- note +1 to make it 0 indexed! sigh lua - torch.manualSeed(opt.seed) - else - print('If cltorch and clnn are installed, your OpenCL driver may be improperly configured.') - print('Check your OpenCL driver installation, check output of clinfo command, and try again.') - print('Falling back on CPU mode') - opt.gpuid = -1 -- overwrite user setting - end -end - -- create the data loader class local loader = CharSplitLMMinibatchLoader.create(opt.data_dir, opt.batch_size, opt.seq_length, split_sizes) local vocab_size = loader.vocab_size -- the number of distinct characters @@ -122,8 +108,8 @@ if string.len(opt.init_from) > 0 then protos = checkpoint.protos -- make sure the vocabs are the same local vocab_compatible = true - for c,i in pairs(checkpoint.vocab) do - if not vocab[c] == i then + for c,i in pairs(checkpoint.vocab) do + if not vocab[c] == i then vocab_compatible = false end end @@ -138,6 +124,14 @@ else protos = {} if opt.model == 'lstm' then protos.rnn = LSTM.lstm(vocab_size, opt.rnn_size, opt.num_layers, opt.dropout) + elseif opt.model == 'lstmex' then + protos.rnn = LSTMEX.lstm2(vocab_size, opt.rnn_size, opt.num_layers, opt.dropout, opt.lstmex_memory_slots) + elseif opt.model == 'ntm' then + local ntm_conf = { + input_dim = vocab_size, + output_dim = vocab_size + } + protos.rnn = nn.NTM(ntm_conf) elseif opt.model == 'gru' then protos.rnn = GRU.gru(vocab_size, opt.rnn_size, opt.num_layers, opt.dropout) elseif opt.model == 'rnn' then @@ -150,21 +144,25 @@ end init_state = {} for L=1,opt.num_layers do local h_init = torch.zeros(opt.batch_size, opt.rnn_size) - if opt.gpuid >=0 and opt.opencl == 0 then h_init = h_init:cuda() end - if opt.gpuid >=0 and opt.opencl == 1 then h_init = h_init:cl() end - table.insert(init_state, h_init:clone()) + if opt.gpuid >=0 then h_init = h_init:cuda() end + if opt.model == 'lstmex' then + local m_init = torch.zeros(opt.batch_size, opt.lstmex_memory_slots ,opt.rnn_size) + if opt.gpuid >=0 then m_init = m_init:cuda() end + table.insert(init_state, m_init:clone()) + table.insert(init_state, h_init:clone()) + else + table.insert(init_state, h_init:clone()) + end if opt.model == 'lstm' then table.insert(init_state, h_init:clone()) end end -- ship the model to the GPU if desired -if opt.gpuid >= 0 and opt.opencl == 0 then +if opt.gpuid >= 0 then for k,v in pairs(protos) do v:cuda() end end -if 
opt.gpuid >= 0 and opt.opencl == 1 then - for k,v in pairs(protos) do v:cl() end -end + -- put the above things into one flattened parameters tensor params, grad_params = model_utils.combine_all_parameters(protos.rnn) @@ -203,26 +201,23 @@ function eval_split(split_index, max_batches) loader:reset_batch_pointer(split_index) -- move batch iteration pointer for this split to front local loss = 0 local rnn_state = {[0] = init_state} - + for i = 1,n do -- iterate over batches in the split -- fetch a batch local x, y = loader:next_batch(split_index) - if opt.gpuid >= 0 and opt.opencl == 0 then -- ship the input arrays to GPU + if opt.gpuid >= 0 then -- ship the input arrays to GPU -- have to convert to float because integers can't be cuda()'d x = x:float():cuda() y = y:float():cuda() end - if opt.gpuid >= 0 and opt.opencl == 1 then -- ship the input arrays to GPU - x = x:cl() - y = y:cl() - end + -- forward pass for t=1,opt.seq_length do clones.rnn[t]:evaluate() -- for dropout proper functioning local lst = clones.rnn[t]:forward{x[{{}, t}], unpack(rnn_state[t-1])} rnn_state[t] = {} for i=1,#init_state do table.insert(rnn_state[t], lst[i]) end - prediction = lst[#lst] + prediction = lst[#lst] loss = loss + clones.criterion[t]:forward(prediction, y[{{}, t}]) end -- carry over lstm state @@ -244,15 +239,12 @@ function feval(x) ------------------ get minibatch ------------------- local x, y = loader:next_batch(1) - if opt.gpuid >= 0 and opt.opencl == 0 then -- ship the input arrays to GPU + if opt.gpuid >= 0 then -- ship the input arrays to GPU -- have to convert to float because integers can't be cuda()'d x = x:float():cuda() y = y:float():cuda() end - if opt.gpuid >= 0 and opt.opencl == 1 then -- ship the input arrays to GPU - x = x:cl() - y = y:cl() - end + ------------------- forward pass ------------------- local rnn_state = {[0] = init_state_global} local predictions = {} -- softmax outputs @@ -277,7 +269,7 @@ function feval(x) drnn_state[t-1] = {} for k,v in pairs(dlst) do if k > 1 then -- k == 1 is gradient on x, which we dont need - -- note we do k-1 because first item is dembeddings, and then follow the + -- note we do k-1 because first item is dembeddings, and then follow the -- derivatives of the state, starting at index 2. I know... drnn_state[t-1][k-1] = v end @@ -303,6 +295,7 @@ for i = 1, iterations do local epoch = i / loader.ntrain local timer = torch.Timer() + optim_state.learningRate = dofile("lr.lua")() local _, loss = optim.rmsprop(feval, params, optim_state) local time = timer:time().real @@ -341,7 +334,7 @@ for i = 1, iterations do if i % opt.print_every == 0 then print(string.format("%d/%d (epoch %.3f), train_loss = %6.8f, grad/param norm = %6.4e, time/batch = %.2fs", i, iterations, epoch, train_loss, grad_params:norm() / params:norm(), time)) end - + if i % 10 == 0 then collectgarbage() end -- handle early stopping if things are going really bad @@ -355,5 +348,3 @@ for i = 1, iterations do break -- halt end end - - diff --git a/train_ntm.lua b/train_ntm.lua new file mode 100644 index 00000000..503caa41 --- /dev/null +++ b/train_ntm.lua @@ -0,0 +1,376 @@ + +--[[ + +This file trains a character-level multi-layer RNN on text data + +Code is based on implementation in +https://github.com/oxford-cs-ml-2015/practical6 +but modified to have multi-layer support, GPU support, as well as +many other common model/optimization bells and whistles. 
+The practical6 code is in turn based on +https://github.com/wojciechz/learning_to_execute +which is turn based on other stuff in Torch, etc... (long lineage) + +]]-- + +require 'torch' +torch.setdefaulttensortype('torch.DoubleTensor') +require 'nn' +require 'nngraph' +require 'optim' +require 'lfs' + +require 'util.OneHot' +require 'util.misc' +local CharSplitLMMinibatchLoader = require 'util.CharSplitLMMinibatchLoader' +local model_utils = require 'util.model_utils' +local LSTM = require 'model.LSTM' +local LSTMEX = require 'model.LSTMEX' +local LSTMNTM = require 'model.LSTMNTM' +local GRU = require 'model.GRU' +local RNN = require 'model.RNN' +require 'model.NTM' + +cmd = torch.CmdLine() +cmd:text() +cmd:text('Train a character-level language model') +cmd:text() +cmd:text('Options') +-- data +cmd:option('-data_dir','data/ru','data directory. Should contain the file input.txt with input data') +-- model params +cmd:option('-rnn_size', 128, 'size of LSTM internal state') +cmd:option('-num_layers', 1, 'number of layers in the LSTM') +cmd:option('-lstmex_memory_slots', 3, 'number of LSTM internal memory slots') +cmd:option('-model', 'ntm', 'lstm, gru or rnn, lstmex or ntm or lstmntm') +-- optimization +cmd:option('-learning_rate',5e-5,'learning rate') +cmd:option('-learning_rate_decay',0.97,'learning rate decay') +cmd:option('-learning_rate_decay_after',10,'in number of epochs, when to start decaying the learning rate') +cmd:option('-decay_rate',0.95,'decay rate for rmsprop') +cmd:option('-dropout',0,'dropout for regularization, used after each RNN hidden layer. 0 = no dropout') +cmd:option('-seq_length',50,'number of timesteps to unroll for') +cmd:option('-batch_size',1,'number of sequences to train on in parallel') +cmd:option('-max_epochs',50,'number of full passes through the training data') +cmd:option('-grad_clip',5,'clip gradients at this value') +cmd:option('-train_frac',0.99,'fraction of data that goes into train set') +cmd:option('-val_frac',0.01,'fraction of data that goes into validation set') + -- test_frac will be computed as (1 - train_frac - val_frac) +cmd:option('-init_from', '', 'initialize network parameters from checkpoint at this path') +-- bookkeeping +cmd:option('-seed',123,'torch manual random number generator seed') +cmd:option('-print_every',1,'how many steps/minibatches between printing out the loss') +cmd:option('-eval_val_every',1000,'every how many iterations should we evaluate on validation data?') +cmd:option('-checkpoint_dir', 'cv', 'output directory where checkpoints get written') +cmd:option('-savefile','model','filename to autosave the checkpont to. Will be inside checkpoint_dir/') +-- GPU/CPU +cmd:option('-gpuid',0,'which gpu to use. -1 = use CPU') +cmd:option('-opencl',0,'use OpenCL (instead of CUDA)') +cmd:text() + +-- parse input params +opt = cmd:parse(arg) +torch.manualSeed(opt.seed) +-- train / val / test split for data, in fractions +local test_frac = math.max(0, 1 - (opt.train_frac + opt.val_frac)) +local split_sizes = {opt.train_frac, opt.val_frac, test_frac} + +-- initialize cunn/cutorch for training on the GPU and fall back to CPU gracefully +if opt.gpuid >= 0 and opt.opencl == 0 then + local ok, cunn = pcall(require, 'cunn') + local ok2, cutorch = pcall(require, 'cutorch') + if not ok then print('package cunn not found!') end + if not ok2 then print('package cutorch not found!') end + if ok and ok2 then + print('using CUDA on GPU ' .. opt.gpuid .. '...') + cutorch.setDevice(opt.gpuid + 1) -- note +1 to make it 0 indexed! 
sigh lua + cutorch.manualSeed(opt.seed) + else + print('If cutorch and cunn are installed, your CUDA toolkit may be improperly configured.') + print('Check your CUDA toolkit installation, rebuild cutorch and cunn, and try again.') + print('Falling back on CPU mode') + opt.gpuid = -1 -- overwrite user setting + end +end + +-- initialize clnn/cltorch for training on the GPU and fall back to CPU gracefully +if opt.gpuid >= 0 and opt.opencl == 1 then + local ok, cunn = pcall(require, 'clnn') + local ok2, cutorch = pcall(require, 'cltorch') + if not ok then print('package clnn not found!') end + if not ok2 then print('package cltorch not found!') end + if ok and ok2 then + print('using OpenCL on GPU ' .. opt.gpuid .. '...') + cltorch.setDevice(opt.gpuid + 1) -- note +1 to make it 0 indexed! sigh lua + torch.manualSeed(opt.seed) + else + print('If cltorch and clnn are installed, your OpenCL driver may be improperly configured.') + print('Check your OpenCL driver installation, check output of clinfo command, and try again.') + print('Falling back on CPU mode') + opt.gpuid = -1 -- overwrite user setting + end +end + +-- create the data loader class +local loader = CharSplitLMMinibatchLoader.create(opt.data_dir, opt.batch_size, opt.seq_length, split_sizes) +local vocab_size = loader.vocab_size -- the number of distinct characters +local vocab = loader.vocab_mapping +print('vocab size: ' .. vocab_size) +-- make sure output directory exists +if not path.exists(opt.checkpoint_dir) then lfs.mkdir(opt.checkpoint_dir) end + +-- define the model: prototypes for one timestep, then clone them in time +local do_random_init = true +local criterions = {} +if string.len(opt.init_from) > 0 then + print('loading an LSTM from checkpoint ' .. opt.init_from) + local checkpoint = torch.load(opt.init_from) + protos = checkpoint.protos + -- make sure the vocabs are the same + local vocab_compatible = true + for c,i in pairs(checkpoint.vocab) do + if not vocab[c] == i then + vocab_compatible = false + end + end + assert(vocab_compatible, 'error, the character vocabulary for this dataset and the one in the saved checkpoint are not the same. This is trouble.') + -- overwrite model settings based on checkpoint to ensure compatibility + print('overwriting rnn_size=' .. checkpoint.opt.rnn_size .. ', num_layers=' .. checkpoint.opt.num_layers .. ' based on the checkpoint.') + opt.rnn_size = checkpoint.opt.rnn_size + opt.num_layers = checkpoint.opt.num_layers + do_random_init = false +else + print('creating an ' .. opt.model .. ' with ' .. opt.num_layers .. 
' layers') + protos = {} + if opt.model == 'lstm' then + protos.rnn = LSTM.lstm(vocab_size, opt.rnn_size, opt.num_layers, opt.dropout) + elseif opt.model == 'lstmex' then + protos.rnn = LSTMEX.lstm2(vocab_size, opt.rnn_size, opt.num_layers, opt.dropout, opt.lstmex_memory_slots) + elseif opt.model == 'ntm' then + local ntm_conf = { + input_dim = vocab_size, + output_dim = vocab_size + } + protos.rnn = nn.NTM(ntm_conf) + elseif opt.model == 'gru' then + protos.rnn = GRU.gru(vocab_size, opt.rnn_size, opt.num_layers, opt.dropout) + elseif opt.model == 'rnn' then + protos.rnn = RNN.rnn(vocab_size, opt.rnn_size, opt.num_layers, opt.dropout) + end + + for i=1,opt.seq_length do + criterions[i] = nn.ClassNLLCriterion() + end +end + +-- the initial state of the cell/hidden states +init_state = {} +for L=1,opt.num_layers do + local h_init = torch.zeros(opt.batch_size, opt.rnn_size) + if opt.gpuid >=0 and opt.opencl == 0 then h_init = h_init:cuda() end + if opt.gpuid >=0 and opt.opencl == 1 then h_init = h_init:cl() end + if opt.model == 'lstmex' then + local m_init = torch.zeros(opt.batch_size, opt.lstmex_memory_slots ,opt.rnn_size) + if opt.gpuid >=0 and opt.opencl == 0 then m_init = m_init:cuda() end + if opt.gpuid >=0 and opt.opencl == 1 then m_init = m_init:cl() end + table.insert(init_state, m_init:clone()) + table.insert(init_state, h_init:clone()) + else + table.insert(init_state, h_init:clone()) + end + if opt.model == 'lstm' then + table.insert(init_state, h_init:clone()) + end +end + +-- ship the model to the GPU if desired +if opt.gpuid >= 0 and opt.opencl == 0 then + for k,v in pairs(protos) do v:cuda() end +end +if opt.gpuid >= 0 and opt.opencl == 1 then + for k,v in pairs(protos) do v:cl() end +end + +-- put the above things into one flattened parameters tensor +params, grad_params = model_utils.combine_all_parameters(protos.rnn) + +-- initialization +if do_random_init then + params:uniform(-0.08, 0.08) -- small uniform numbers +end +-- initialize the LSTM forget gates with slightly higher biases to encourage remembering in the beginning +if opt.model == 'lstm' then + for layer_idx = 1, opt.num_layers do + for _,node in ipairs(protos.rnn.forwardnodes) do + if node.data.annotations.name == "i2h_" .. layer_idx then + print('setting forget gate biases to 1 in LSTM layer ' .. layer_idx) + -- the gates are, in order, i,f,o,g, so f is the 2nd block of weights + node.data.module.bias[{{opt.rnn_size+1, 2*opt.rnn_size}}]:fill(1.0) + end + end + end +end + +print('number of parameters in the model: ' .. params:nElement()) +-- make a bunch of clones after flattening, as that reallocates memory +--clones = {} +--for name,proto in pairs(protos) do +-- print('cloning ' .. name) +-- clones[name] = model_utils.clone_many_times(proto, opt.seq_length, not proto.parameters) +--end + +-- evaluate the loss over an entire split +function eval_split(split_index, max_batches) + print('evaluating loss over split index ' .. 
split_index) + local n = loader.split_sizes[split_index] + n = 50 + if max_batches ~= nil then n = math.min(max_batches, n) end + + loader:reset_batch_pointer(split_index) -- move batch iteration pointer for this split to front + local loss = 0 + local rnn_state = {[0] = init_state} + + for i = 1,n do -- iterate over batches in the split + -- fetch a batch + local x, y = loader:next_batch(split_index) + if opt.gpuid >= 0 and opt.opencl == 0 then -- ship the input arrays to GPU + -- have to convert to float because integers can't be cuda()'d + x = x:float():cuda() + y = y:float():cuda() + end + if opt.gpuid >= 0 and opt.opencl == 1 then -- ship the input arrays to GPU + x = x:cl() + y = y:cl() + end + -- forward pass + for t=1,opt.seq_length do + protos.rnn:evaluate() -- for dropout proper functioning + local prediction = protos.rnn:f(x[{{1}, t}]) + loss = loss + + criterions[t]:forward(prediction, y[{{1}, t}]) + end + -- carry over lstm state + print(i .. '/' .. n .. '...') + end + protos.rnn:r() + loss = loss / opt.seq_length / n + return loss +end + +-- do fwd/bwd and return loss, grad_params +local init_state_global = clone_list(init_state) +function feval(x) + if x ~= params then + params:copy(x) + end + grad_params:zero() + + ------------------ get minibatch ------------------- + local x, y = loader:next_batch(1) + if opt.gpuid >= 0 and opt.opencl == 0 then -- ship the input arrays to GPU + -- have to convert to float because integers can't be cuda()'d + x = x:float():cuda() + y = y:float():cuda() + end + if opt.gpuid >= 0 and opt.opencl == 1 then -- ship the input arrays to GPU + x = x:cl() + y = y:cl() + end + ------------------- forward pass ------------------- + local rnn_state = {[0] = init_state_global} + local predictions = {} -- softmax outputs + local loss = 0 + for t=1,opt.seq_length do + protos.rnn:training() -- make sure we are in correct mode (this is cheap, sets flag) + --predictions[t] = protos.rnn:forward(x[{{}, t}]) + predictions[t] = protos.rnn:forward(x[{{1}, t}]):clone() + --loss = loss + protos.criterions[t]:forward(predictions[t], y[{{}, t}]) + loss = loss + criterions[t]:forward(predictions[t], y[{{1}, t}]) + end + + loss = loss / opt.seq_length + -- maintain last state + --protos.rnn:set_last_state() + ------------------ backward pass ------------------- + -- initialize gradient at time t to be zeros (there's no influence from future) + local drnn_state = {[opt.seq_length] = clone_list(init_state, true)} -- true also zeros the clones + for t=opt.seq_length,1,-1 do + -- backprop through loss, and softmax/linear + local grad = criterions[t]:backward(predictions[t], y[{{1}, t}]) + --print(t,grad) + protos.rnn:backward(x[{{1}, t}], grad) + end + ------------------------ misc ---------------------- + -- transfer final state to initial state (BPTT) + init_state_global = rnn_state[#rnn_state] -- NOTE: I don't think this needs to be a clone, right? + -- grad_params:div(opt.seq_length) -- this line should be here but since we use rmsprop it would have no effect. 
Removing for efficiency + -- clip gradient element-wise + grad_params:clamp(-opt.grad_clip, opt.grad_clip) + return loss, grad_params +end + +-- start optimization here +train_losses = {} +val_losses = {} +local optim_state = {learningRate = opt.learning_rate, alpha = opt.decay_rate} +local iterations = opt.max_epochs * loader.ntrain +local iterations_per_epoch = loader.ntrain +local loss0 = nil +for i = 1, iterations do + local epoch = i / loader.ntrain + + local timer = torch.Timer() + optim_state.learningRate = dofile("lr.lua")() + local _, loss = optim.rmsprop(feval, params, optim_state) + local time = timer:time().real + + local train_loss = loss[1] -- the loss is inside a list, pop it + train_losses[i] = train_loss + + -- exponential learning rate decay + if i % loader.ntrain == 0 and opt.learning_rate_decay < 1 then + if epoch >= opt.learning_rate_decay_after then + local decay_factor = opt.learning_rate_decay + optim_state.learningRate = optim_state.learningRate * decay_factor -- decay it + print('decayed learning rate by a factor ' .. decay_factor .. ' to ' .. optim_state.learningRate) + end + end + + -- every now and then or on last iteration + if i % opt.eval_val_every == 0 or i == iterations then + -- evaluate loss on validation data + local val_loss = eval_split(2) -- 2 = validation + val_losses[i] = val_loss + + local savefile = string.format('%s/lm_%s_epoch%.3f_%.4f.t7', opt.checkpoint_dir, opt.savefile, epoch, val_loss) + print('saving checkpoint to ' .. savefile) + local checkpoint = {} + + checkpoint.protos = protos + checkpoint.opt = opt + checkpoint.train_losses = train_losses + checkpoint.val_loss = val_loss + checkpoint.val_losses = val_losses + checkpoint.i = i + checkpoint.epoch = epoch + checkpoint.vocab = loader.vocab_mapping + torch.save(savefile, checkpoint) + end + + if i % opt.print_every == 0 then + print(string.format("%d/%d (epoch %.3f), train_loss = %6.8f, grad/param norm = %6.4e, time/batch = %.2fs , lr = %f", i, iterations, epoch, train_loss, grad_params:norm() / params:norm(), time, optim_state.learningRate)) + end + + if i % 10 == 0 then collectgarbage() end + + -- handle early stopping if things are going really bad + if loss[1] ~= loss[1] then + print('loss is NaN. This usually indicates a bug. Please check the issues page for existing issues, or create a new issue, if none exist. Ideally, please state: your operating system, 32-bit/64-bit, your blas version, cpu/cuda/cl?') + break -- halt + end + if loss0 == nil then loss0 = loss[1] end + if loss[1] > loss0 * 3 then + print('loss is exploding, aborting.') + break -- halt + end +end diff --git a/util/CharSplitLMMinibatchLoader.lua b/util/CharSplitLMMinibatchLoader.lua index 08e95ecd..d9ee844a 100644 --- a/util/CharSplitLMMinibatchLoader.lua +++ b/util/CharSplitLMMinibatchLoader.lua @@ -2,6 +2,8 @@ -- Modified from https://github.com/oxford-cs-ml-2015/practical6 -- the modification included support for train/val/test splits +require './misc.lua' + local CharSplitLMMinibatchLoader = {} CharSplitLMMinibatchLoader.__index = CharSplitLMMinibatchLoader @@ -22,7 +24,7 @@ function CharSplitLMMinibatchLoader.create(data_dir, batch_size, seq_length, spl print('vocab.t7 and data.t7 do not exist. Running preprocessing...') run_prepro = true else - -- check if the input file was modified since last time we + -- check if the input file was modified since last time we -- ran the prepro. 
if so, we have to rerun the preprocessing local input_attr = lfs.attributes(input_file) local vocab_attr = lfs.attributes(vocab_file) @@ -46,14 +48,14 @@ function CharSplitLMMinibatchLoader.create(data_dir, batch_size, seq_length, spl local len = data:size(1) if len % (batch_size * seq_length) ~= 0 then print('cutting off end of data so that the batches/sequences divide evenly') - data = data:sub(1, batch_size * seq_length + data = data:sub(1, batch_size * seq_length * math.floor(len / (batch_size * seq_length))) end -- count vocab self.vocab_size = 0 - for _ in pairs(self.vocab_mapping) do - self.vocab_size = self.vocab_size + 1 + for _ in pairs(self.vocab_mapping) do + self.vocab_size = self.vocab_size + 1 end -- self.batches is a table of tensors @@ -78,7 +80,7 @@ function CharSplitLMMinibatchLoader.create(data_dir, batch_size, seq_length, spl assert(split_fractions[1] >= 0 and split_fractions[1] <= 1, 'bad split fraction ' .. split_fractions[1] .. ' for train, not between 0 and 1') assert(split_fractions[2] >= 0 and split_fractions[2] <= 1, 'bad split fraction ' .. split_fractions[2] .. ' for val, not between 0 and 1') assert(split_fractions[3] >= 0 and split_fractions[3] <= 1, 'bad split fraction ' .. split_fractions[3] .. ' for test, not between 0 and 1') - if split_fractions[3] == 0 then + if split_fractions[3] == 0 then -- catch a common special case where the user might not want a test set self.ntrain = math.floor(self.nbatches * split_fractions[1]) self.nval = self.nbatches - self.ntrain @@ -125,10 +127,9 @@ end -- *** STATIC method *** function CharSplitLMMinibatchLoader.text_to_tensor(in_textfile, out_vocabfile, out_tensorfile) local timer = torch.Timer() - print('loading text file...') local cache_len = 10000 - local rawdata + local line local tot_len = 0 f = io.open(in_textfile, "r") @@ -136,15 +137,20 @@ function CharSplitLMMinibatchLoader.text_to_tensor(in_textfile, out_vocabfile, o print('creating vocabulary mapping...') -- record all characters to a set local unordered = {} - rawdata = f:read(cache_len) - repeat - for char in rawdata:gmatch'.' do - if not unordered[char] then unordered[char] = true end - end - tot_len = tot_len + #rawdata - rawdata = f:read(cache_len) - until not rawdata + -- read file line by line + while true do + line = f:read() + if line == nil then break end -- no more lines to read + for char_code, char in pairs(UTF8ToCharArray(line)) do + if not unordered[char] then unordered[char] = true end + tot_len = tot_len + 1 + end + -- don't forget end of line character it is excluded by f:read() + if not unordered['\n'] then unordered['\n'] = true end + tot_len = tot_len + 1 + end f:close() + -- sort into a table (i.e. 
keys become 1..N) local ordered = {} for char in pairs(unordered) do ordered[#ordered + 1] = char end @@ -154,19 +160,24 @@ function CharSplitLMMinibatchLoader.text_to_tensor(in_textfile, out_vocabfile, o for i, char in ipairs(ordered) do vocab_mapping[char] = i end + -- construct a tensor with all the data print('putting data into tensor...') local data = torch.ByteTensor(tot_len) -- store it into 1D first, then rearrange f = io.open(in_textfile, "r") - local currlen = 0 - rawdata = f:read(cache_len) - repeat - for i=1, #rawdata do - data[currlen+i] = vocab_mapping[rawdata:sub(i, i)] -- lua has no string indexing using [] - end - currlen = currlen + #rawdata - rawdata = f:read(cache_len) - until not rawdata + local currlen = 1 + + while true do + line = f:read() + if line == nil then break end -- no more lines to read + for char_code, char in pairs(UTF8ToCharArray(line)) do + data[currlen] = vocab_mapping[char] + currlen = currlen + 1 + end + -- don't forget end of line character it is excluded by f:read() + data[currlen] = vocab_mapping['\n'] + currlen = currlen + 1 + end f:close() -- save output preprocessed files @@ -177,4 +188,3 @@ function CharSplitLMMinibatchLoader.text_to_tensor(in_textfile, out_vocabfile, o end return CharSplitLMMinibatchLoader - diff --git a/util/OneHot.lua b/util/OneHot.lua index 538b434d..ab94748e 100644 --- a/util/OneHot.lua +++ b/util/OneHot.lua @@ -11,7 +11,8 @@ function OneHot:__init(outputSize) end function OneHot:updateOutput(input) - self.output:resize(input:size(1), self.outputSize):zero() + --self.output:resize(input:size(1), self.outputSize):zero() + self.output:resize(self.outputSize):zero() if self._eye == nil then self._eye = torch.eye(self.outputSize) end self._eye = self._eye:float() local longInput = input:long() diff --git a/util/misc.lua b/util/misc.lua index 043f65cb..e458b44b 100644 --- a/util/misc.lua +++ b/util/misc.lua @@ -10,4 +10,47 @@ function clone_list(tensor_list, zero_too) if zero_too then out[k]:zero() end end return out -end \ No newline at end of file +end + +-- Multi byte characters start with a byte with bits 7 and 8 set, trailing bytes have bit 7 not set and bit 8 set. +-- https://forums.coronalabs.com/topic/42019-split-utf-8-string-word-with-foreign-characters-to-letters/ by ingemar +function UTF8ToCharArray(str) + local charArray = {}; + local iStart = 0; + local strLen = str:len(); + + local function bit(b) + return 2 ^ (b - 1); + end + + local function hasbit(w, b) + return w % (b + b) >= b; + end + + local checkMultiByte = function(i) + if (iStart ~= 0) then + charArray[#charArray + 1] = str:sub(iStart, i - 1); + iStart = 0; + end + end + + for i = 1, strLen do + local b = str:byte(i); + local multiStart = hasbit(b, bit(7)) and hasbit(b, bit(8)); + local multiTrail = not hasbit(b, bit(7)) and hasbit(b, bit(8)); + + if (multiStart) then + checkMultiByte(i); + iStart = i; + + elseif (not multiTrail) then + checkMultiByte(i); + charArray[#charArray + 1] = str:sub(i, i); + end + end + + -- process if last character is multi-byte + checkMultiByte(strLen + 1); + + return charArray; +end
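
For reference, a minimal usage sketch of the UTF8ToCharArray helper added above (illustrative only, not part of the patch): it splits a UTF-8 string into per-character substrings, so a multi-byte character such as a Cyrillic letter (the default -data_dir is now data/ru) becomes a single vocabulary symbol instead of two or more raw bytes. The snippet assumes it is run from the repository root so that util/misc.lua is on the package path.

require 'util.misc'   -- defines UTF8ToCharArray globally

local chars = UTF8ToCharArray('ab я!')
for i, c in ipairs(chars) do
  print(i, c, #c)     -- 'я' is a single entry of byte length 2
end
-- expected order: 'a', 'b', ' ', 'я', '!'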