Skip to content

Commit

Permalink
supports both online and mini-batch mode
Browse files Browse the repository at this point in the history
  • Loading branch information
cheng6076 committed Dec 1, 2015
1 parent 146f912 commit e1e1335
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions examples/encoder-decoder-coupling.lua
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ require 'rnn'

torch.manualSeed(123)

version = 1.1 --minibatch training
version = 1.1 --supports both online and mini-batch training

--[[ Forward coupling: Copy encoder cell and output to decoder LSTM ]]--
function forwardConnect(encLSTM, decLSTM)
Expand All @@ -32,20 +32,20 @@ function main()

-- Some example data
local encInSeq, decInSeq, decOutSeq = torch.Tensor({{1,2,3},{3,2,1}}), torch.Tensor({{1,2,3,4},{4,3,2,1}}), torch.Tensor({{2,3,4,1},{1,2,4,3}})
decOutSeq = nn.SplitTable(2):forward(decOutSeq)
decOutSeq = nn.SplitTable(1, 1):forward(decOutSeq)

-- Encoder
local enc = nn.Sequential()
enc:add(nn.LookupTable(opt.vocabSz, opt.hiddenSz))
enc:add(nn.SplitTable(2)) --Split along the second axis, where the first axis is the mini-batch
enc:add(nn.SplitTable(1, 2)) --works for both online and mini-batch mode
local encLSTM = nn.LSTM(opt.hiddenSz, opt.hiddenSz)
enc:add(nn.Sequencer(encLSTM))
enc:add(nn.SelectTable(-1))

-- Decoder
local dec = nn.Sequential()
dec:add(nn.LookupTable(opt.vocabSz, opt.hiddenSz))
dec:add(nn.SplitTable(2)) --Split along the second axis, where the first axis is the mini-batch
dec:add(nn.SplitTable(1, 2)) --works for both online and mini-batch mode
local decLSTM = nn.LSTM(opt.hiddenSz, opt.hiddenSz)
dec:add(nn.Sequencer(decLSTM))
dec:add(nn.Sequencer(nn.Linear(opt.hiddenSz, opt.vocabSz)))
Expand Down

0 comments on commit e1e1335

Please sign in to comment.