Skip to content

Commit

Permalink
Merge pull request #64 from cheng6076/master
Browse files Browse the repository at this point in the history
Update encoder-decoder-coupling.lua
  • Loading branch information
nicholas-leonard committed Dec 1, 2015
2 parents 98051e2 + e1e1335 commit fdc3b21
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions examples/encoder-decoder-coupling.lua
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ require 'rnn'

torch.manualSeed(123)

version = 1
version = 1.1 --supports both online and mini-batch training

--[[ Forward coupling: Copy encoder cell and output to decoder LSTM ]]--
function forwardConnect(encLSTM, decLSTM)
Expand All @@ -31,20 +31,21 @@ function main()
opt.inputSeqLen = 3 -- length of the encoded sequence

-- Some example data
local encInSeq, decInSeq, decOutSeq = torch.Tensor({1,2,3}), torch.Tensor({1,2,3,4}), torch.Tensor({2,3,4,1})

local encInSeq, decInSeq, decOutSeq = torch.Tensor({{1,2,3},{3,2,1}}), torch.Tensor({{1,2,3,4},{4,3,2,1}}), torch.Tensor({{2,3,4,1},{1,2,4,3}})
decOutSeq = nn.SplitTable(1, 1):forward(decOutSeq)

-- Encoder
local enc = nn.Sequential()
enc:add(nn.LookupTable(opt.vocabSz, opt.hiddenSz))
enc:add(nn.SplitTable(1))
enc:add(nn.SplitTable(1, 2)) --works for both online and mini-batch mode
local encLSTM = nn.LSTM(opt.hiddenSz, opt.hiddenSz)
enc:add(nn.Sequencer(encLSTM))
enc:add(nn.SelectTable(-1))

-- Decoder
local dec = nn.Sequential()
dec:add(nn.LookupTable(opt.vocabSz, opt.hiddenSz))
dec:add(nn.SplitTable(1))
dec:add(nn.SplitTable(1, 2)) --works for both online and mini-batch mode
local decLSTM = nn.LSTM(opt.hiddenSz, opt.hiddenSz)
dec:add(nn.Sequencer(decLSTM))
dec:add(nn.Sequencer(nn.Linear(opt.hiddenSz, opt.vocabSz)))
Expand Down

0 comments on commit fdc3b21

Please sign in to comment.