Skip to content

Commit e1e1335

Browse files
committed
supports both online and mini-batch mode
1 parent 146f912 commit e1e1335

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

examples/encoder-decoder-coupling.lua

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ require 'rnn'
99

1010
torch.manualSeed(123)
1111

12-
version = 1.1 --minibatch training
12+
version = 1.1 --supports both online and mini-batch training
1313

1414
--[[ Forward coupling: Copy encoder cell and output to decoder LSTM ]]--
1515
function forwardConnect(encLSTM, decLSTM)
@@ -32,20 +32,20 @@ function main()
3232

3333
-- Some example data
3434
local encInSeq, decInSeq, decOutSeq = torch.Tensor({{1,2,3},{3,2,1}}), torch.Tensor({{1,2,3,4},{4,3,2,1}}), torch.Tensor({{2,3,4,1},{1,2,4,3}})
35-
decOutSeq = nn.SplitTable(2):forward(decOutSeq)
35+
decOutSeq = nn.SplitTable(1, 1):forward(decOutSeq)
3636

3737
-- Encoder
3838
local enc = nn.Sequential()
3939
enc:add(nn.LookupTable(opt.vocabSz, opt.hiddenSz))
40-
enc:add(nn.SplitTable(2)) --Split along the second axis, where the first axis is the mini-batch
40+
enc:add(nn.SplitTable(1, 2)) --works for both online and mini-batch mode
4141
local encLSTM = nn.LSTM(opt.hiddenSz, opt.hiddenSz)
4242
enc:add(nn.Sequencer(encLSTM))
4343
enc:add(nn.SelectTable(-1))
4444

4545
-- Decoder
4646
local dec = nn.Sequential()
4747
dec:add(nn.LookupTable(opt.vocabSz, opt.hiddenSz))
48-
dec:add(nn.SplitTable(2)) --Split along the second axis, where the first axis is the mini-batch
48+
dec:add(nn.SplitTable(1, 2)) --works for both online and mini-batch mode
4949
local decLSTM = nn.LSTM(opt.hiddenSz, opt.hiddenSz)
5050
dec:add(nn.Sequencer(decLSTM))
5151
dec:add(nn.Sequencer(nn.Linear(opt.hiddenSz, opt.vocabSz)))

0 commit comments

Comments
 (0)