Merge pull request #29 from torch/SeqBRNN
Bidirectional RNNs
nicholas-leonard authored May 24, 2017
2 parents 6ca4c57 + d362b2c commit 31bd914
Showing 18 changed files with 509 additions and 385 deletions.
118 changes: 66 additions & 52 deletions BiSequencer.lua
@@ -1,71 +1,85 @@
 ------------------------------------------------------------------------
 --[[ BiSequencer ]]--
--- Encapsulates forward, backward and merge modules.
--- Input is a sequence (a table) of tensors.
--- Output is a sequence (a table) of tensors of the same length.
--- Applies a forward rnn to each element in the sequence in
--- forward order and applies a backward rnn in reverse order.
--- For each step, the outputs of both rnn are merged together using
--- the merge module (defaults to nn.JoinTable(1,1)).
--- The sequences in a batch must have the same size.
--- But the sequence length of each batch can vary.
--- It is implemented by decorating a structure of modules that makes
--- use of 3 Sequencers for the forward, backward and merge modules.
+-- Encapsulates forward, backward and merge modules.
+-- Input is a seqlen x inputsize [x ...] sequence tensor
+-- Output is a seqlen x outputsize [x ...] sequence tensor
+-- Applies a forward RNN to each element in the sequence in
+-- forward order and applies a backward RNN in reverse order.
+-- For each step, the outputs of both RNNs are merged together using
+-- the merge module (defaults to nn.CAddTable()).
 ------------------------------------------------------------------------
 local BiSequencer, parent = torch.class('nn.BiSequencer', 'nn.AbstractSequencer')

 function BiSequencer:__init(forward, backward, merge)
-   parent.__init(self)

    if not torch.isTypeOf(forward, 'nn.Module') then
       error"BiSequencer: expecting nn.Module instance at arg 1"
    end
-   self.forwardModule = forward

-   self.backwardModule = backward
-   if not self.backwardModule then
-      self.backwardModule = forward:clone()
-      self.backwardModule:reset()
+   if not backward then
+      backward = forward:clone()
+      backward:reset()
    end
-   if not torch.isTypeOf(self.backwardModule, 'nn.Module') then
-      error"BiSequencer: expecting nn.Module instance at arg 2"
+   if not torch.isTypeOf(backward, 'nn.Module') then
+      error"BiSequencer: expecting nn.Module instance or nil at arg 2"
    end

-   if torch.type(merge) == 'number' then
-      self.mergeModule = nn.JoinTable(1, merge)
-   elseif merge == nil then
-      self.mergeModule = nn.JoinTable(1, 1)
-   elseif torch.isTypeOf(merge, 'nn.Module') then
-      self.mergeModule = merge
-   else
-      error"BiSequencer: expecting nn.Module or number instance at arg 3"
+   -- for table sequences use nn.Sequential():add(nn.ZipTable()):add(nn.Sequencer(nn.JoinTable(1,1)))
+   merge = merge or nn.CAddTable()
+   if not torch.isTypeOf(merge, 'nn.Module') then
+      error"BiSequencer: expecting nn.Module instance or nil at arg 3"
    end

-   self.fwdSeq = nn.Sequencer(self.forwardModule)
-   self.bwdSeq = nn.Sequencer(self.backwardModule)
-   self.mergeSeq = nn.Sequencer(self.mergeModule)
-
-   local backward = nn.Sequential()
-   backward:add(nn.ReverseTable()) -- reverse
-   backward:add(self.bwdSeq)
-   backward:add(nn.ReverseTable()) -- unreverse
-
-   local concat = nn.ConcatTable()
-   concat:add(self.fwdSeq):add(backward)
+   -- make into sequencer (if not already the case)
+   forward = self.isSeq(forward) and forward or nn.Sequencer(forward)
+   backward = self.isSeq(backward) and backward or nn.Sequencer(backward)
+
+   -- the backward sequence reads the input in reverse and outputs the output in correct order
+   backward = nn.ReverseUnreverse(backward)

    local brnn = nn.Sequential()
-   brnn:add(concat)
-   brnn:add(nn.ZipTable())
-   brnn:add(self.mergeSeq)
-
-   self.output = {}
-   self.gradInput = {}
-
-   self.module = brnn
+      :add(nn.ConcatTable():add(forward):add(backward))
+      :add(merge)
+
+   parent.__init(self)
+
    -- so that it can be handled like a Container
    self.modules[1] = brnn
 end
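The rewritten constructor changes the calling convention: sequences are now `seqlen x batchsize x ...` tensors rather than tables, the backward RNN defaults to a reset clone of the forward one, and the default merge is a sum (`nn.CAddTable()`) rather than a concatenation. Since the merge module now receives the two full output tensors, a concatenating merge joins along the feature dimension. A minimal usage sketch (the sizes and the `nn.JoinTable(3)` variant are illustrative, not from this diff):

```lua
require 'rnn'

local seqlen, batchsize, inputsize, outputsize = 5, 3, 10, 10

-- backward RNN defaults to a reset clone of fwd; merge defaults to nn.CAddTable()
local fwd = nn.SeqLSTM(inputsize, outputsize)
local brnn = nn.BiSequencer(fwd)

local input = torch.randn(seqlen, batchsize, inputsize)
local output = brnn:forward(input) -- seqlen x batchsize x outputsize

-- to concatenate instead of summing, merge along the feature dimension;
-- the merged output then has 2*outputsize features:
local brnn2 = nn.BiSequencer(nn.SeqLSTM(inputsize, outputsize), nil, nn.JoinTable(3))
```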

+-- forward RNN can remember. backward RNN can't.
+function BiSequencer:remember(remember)
+   local fwd, bwd = self:getForward(), self:getBackward()
+   fwd:remember(remember)
+   bwd:remember('neither')
+   return self
+end
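The asymmetry is deliberate: the backward RNN consumes each sequence in reverse, so any hidden state carried over from a previous call would correspond to the end of the previous sequence rather than the start of the current one; only the forward RNN honours the requested policy. A sketch (`chunk1` and `chunk2` are placeholder tensors shaped like the input above):

```lua
brnn:remember('both')  -- forward RNN keeps state across calls; backward RNN is always reset
brnn:forward(chunk1)   -- first half of a long sequence
brnn:forward(chunk2)   -- forward RNN resumes from chunk1's final hidden state
```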

+function BiSequencer.isSeq(module)
+   return torch.isTypeOf(module, 'nn.AbstractSequencer') or torch.typename(module):find('nn.Seq.+')
+end
+
+-- multiple-inheritance
+nn.Decorator.decorate(BiSequencer)
+
+function BiSequencer:getForward()
+   return self:get(1):get(1):get(1)
+end
+
+function BiSequencer:getBackward()
+   return self:get(1):get(1):get(2):getModule()
+end
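These getters walk the structure assembled in `__init`: `self:get(1)` is the internal `nn.Sequential`, its first child is the `nn.ConcatTable`, whose first branch is the forward sequencer and whose second branch is the `nn.ReverseUnreverse` decorator; `getModule()` then unwraps the decorated backward sequencer:

```lua
local fwdSeq = brnn:getForward()  -- the sequencer wrapping the forward RNN
local bwdSeq = brnn:getBackward() -- the sequencer inside the nn.ReverseUnreverse decorator
```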

+function BiSequencer:__tostring__()
+   local tab = '  '
+   local line = '\n'
+   local ext = '  |    '
+   local extlast = '       '
+   local last = '   ... -> '
+   local str = torch.type(self)
+   str = str .. ' {'
+   str = str .. line .. tab .. 'forward: ' .. tostring(self:getForward()):gsub(line, line .. tab .. ext)
+   str = str .. line .. tab .. 'backward: ' .. tostring(self:getBackward()):gsub(line, line .. tab .. ext)
+   str = str .. line .. tab .. 'merge: ' .. tostring(self:get(1):get(2)):gsub(line, line .. tab .. ext)
+   str = str .. line .. '}'
+   return str
+end
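`nn.ReverseUnreverse` (registered in CMakeLists.txt below) replaces the old `nn.ReverseTable` sandwich from the table-based implementation. A rough sketch of the idea only, not the module's actual code (`reverseTime` is a hypothetical helper):

```lua
-- reverse a seqlen-first tensor along its time (first) dimension
local function reverseTime(seq)
   local n = seq:size(1)
   return seq:index(1, torch.linspace(n, 1, n):long()) -- indices n, n-1, ..., 1
end

-- conceptually, ReverseUnreverse(seqModule):forward(input) computes:
local function reverseUnreverse(seqModule, input)
   local output = seqModule:forward(reverseTime(input)) -- read the input backwards
   return reverseTime(output)                           -- restore the original time order
end
```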
10 changes: 6 additions & 4 deletions CMakeLists.txt
@@ -17,7 +17,7 @@ SET(luasrc
    AbstractSequencer.lua
    AbstractSequencerCriterion.lua
    BiSequencer.lua
-   BiSequencerLM.lua
+   deprecated/BiSequencerLM.lua
    CopyGrad.lua
    Dropout.lua
    ExpandAs.lua
@@ -38,11 +38,12 @@ SET(luasrc
    Repeater.lua
    RepeaterCriterion.lua
    SAdd.lua
-   SeqBRNN.lua
+   SeqBGRU.lua
+   SeqBLSTM.lua
    SeqGRU.lua
    SeqLSTM.lua
    deprecated/SeqLSTMP.lua
-   SeqReverseSequence.lua
+   deprecated/SeqReverseSequence.lua
    Sequencer.lua
    SequencerCriterion.lua
    ZeroGrad.lua
@@ -85,7 +86,7 @@ SET(luasrc
    ReinforceCategorical.lua
    ReinforceGamma.lua
    ReinforceNormal.lua
-   ReverseTable.lua
+   ReverseSequence.lua
    Sequential.lua
    Serial.lua
    SimpleColorTransform.lua
@@ -104,6 +105,7 @@ SET(luasrc
    WhiteNoise.lua
    ZipTable.lua
    ZipTableOneToMany.lua
+   ReverseUnreverse.lua
 )

 ADD_TORCH_PACKAGE(rnn "${src}" "${luasrc}" "An RNN library for Torch")
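The build list swaps the old `SeqBRNN.lua` for `SeqBGRU.lua` and `SeqBLSTM.lua`, ready-made bidirectional wrappers around `SeqGRU` and `SeqLSTM`. A usage sketch, assuming they follow the same `(inputsize, outputsize)` constructor convention as their unidirectional counterparts:

```lua
require 'rnn'

local blstm = nn.SeqBLSTM(10, 10)   -- assumed (inputsize, outputsize) signature
local input = torch.randn(5, 3, 10) -- seqlen x batchsize x inputsize
local output = blstm:forward(input) -- seqlen x batchsize x outputsize
```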
