Merge pull request #29 from torch/SeqBRNN
Bidirectional RNNs
Showing 18 changed files with 509 additions and 385 deletions.
Among them is a rewrite of BiSequencer.lua. Previously the module took its sequence as a table of tensors (all sequences in a batch had to be the same length, though the length could vary across batches), merged the two RNN outputs with nn.JoinTable(1,1) by default, and was implemented by decorating a structure of three nn.Sequencers (forward, backward and merge) with a pair of nn.ReverseTables around the backward one. After this commit it operates on a single sequence tensor, the merge defaults to nn.CAddTable(), and the reverse-then-unreverse plumbing is delegated to the new nn.ReverseUnreverse decorator. The rewritten BiSequencer.lua:
------------------------------------------------------------------------
--[[ BiSequencer ]]--
-- Encapsulates forward, backward and merge modules.
-- Input is a seqlen x inputsize [x ...] sequence tensor
-- Output is a seqlen x outputsize [x ...] sequence tensor
-- Applies a forward RNN to each element in the sequence in
-- forward order and applies a backward RNN in reverse order.
-- For each step, the outputs of both RNNs are merged together using
-- the merge module (defaults to nn.CAddTable()).
------------------------------------------------------------------------
local BiSequencer, parent = torch.class('nn.BiSequencer', 'nn.AbstractSequencer')
function BiSequencer:__init(forward, backward, merge)

   if not torch.isTypeOf(forward, 'nn.Module') then
      error"BiSequencer: expecting nn.Module instance at arg 1"
   end

   if not backward then
      backward = forward:clone()
      backward:reset()
   end
   if not torch.isTypeOf(backward, 'nn.Module') then
      error"BiSequencer: expecting nn.Module instance or nil at arg 2"
   end

   -- for table sequences use nn.Sequential():add(nn.ZipTable()):add(nn.Sequencer(nn.JoinTable(1,1)))
   merge = merge or nn.CAddTable()
   if not torch.isTypeOf(merge, 'nn.Module') then
      error"BiSequencer: expecting nn.Module instance or nil at arg 3"
   end

   -- make into sequencer (if not already the case)
   forward = self.isSeq(forward) and forward or nn.Sequencer(forward)
   backward = self.isSeq(backward) and backward or nn.Sequencer(backward)

   -- the backward sequence reads the input in reverse and outputs the output in correct order
   backward = nn.ReverseUnreverse(backward)

   parent.__init(self)

   local brnn = nn.Sequential()
      :add(nn.ConcatTable():add(forward):add(backward))
      :add(merge)

   -- so that it can be handled like a Container
   self.modules[1] = brnn
end
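As a usage sketch (not part of the commit): assuming nn.SeqLSTM from this library, which consumes a seqlen x batchsize x inputsize tensor, a bidirectional LSTM can be built and run like this:

require 'rnn'

local seqlen, batchsize = 5, 3
local inputsize, outputsize = 10, 20

-- backward defaults to a reset clone of forward, so one RNN suffices
local brnn = nn.BiSequencer(nn.SeqLSTM(inputsize, outputsize))

local input = torch.randn(seqlen, batchsize, inputsize)
local output = brnn:forward(input)
-- output is seqlen x batchsize x outputsize: the default nn.CAddTable()
-- sums the forward and backward outputs element-wise
print(output:size())

Because the default merge sums the two outputs, the forward and backward RNNs must produce the same outputsize.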
-- forward RNN can remember. backward RNN can't.
function BiSequencer:remember(remember)
   local fwd, bwd = self:getForward(), self:getBackward()
   fwd:remember(remember)
   bwd:remember('neither')
   return self
end
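So whichever mode is requested is applied only to the forward Sequencer; the backward one is pinned to 'neither', since it would otherwise carry state from a future it has not yet seen. A one-line sketch, assuming the usual Sequencer modes ('both', 'train', 'eval', 'neither'):

brnn:remember('both')  -- forward RNN now keeps its hidden state across forward() calls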
function BiSequencer.isSeq(module)
   return torch.isTypeOf(module, 'nn.AbstractSequencer') or torch.typename(module):find('nn.Seq.+')
end

-- multiple-inheritance
nn.Decorator.decorate(BiSequencer)
function BiSequencer:getForward()
   return self:get(1):get(1):get(1)
end

function BiSequencer:getBackward()
   return self:get(1):get(1):get(2):getModule()
end
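Both accessors index through the internal nn.Sequential and nn.ConcatTable; getBackward additionally unwraps the nn.ReverseUnreverse decorator via getModule(). For illustration (using the brnn from the sketch above):

local fwd = brnn:getForward()   -- the forward module (wrapped in nn.Sequencer if it wasn't one)
local bwd = brnn:getBackward()  -- the backward module, pulled out of nn.ReverseUnreverse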
function BiSequencer:__tostring__()
   local tab = '  '
   local line = '\n'
   local ext = '  |    '
   local extlast = '       '
   local last = '   ... -> '
   local str = torch.type(self)
   str = str .. ' {'
   str = str .. line .. tab .. 'forward: ' .. tostring(self:getForward()):gsub(line, line .. tab .. ext)
   str = str .. line .. tab .. 'backward: ' .. tostring(self:getBackward()):gsub(line, line .. tab .. ext)
   str = str .. line .. tab .. 'merge: ' .. tostring(self:get(1):get(2)):gsub(line, line .. tab .. ext)
   str = str .. line .. '}'
   return str
end
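Finally, the merge argument accepts any nn.Module that consumes the {forwardOutput, backwardOutput} table, so the default summation can be swapped for concatenation. A hedged sketch (not from the commit; it assumes 3-D seqlen x batchsize x outputsize outputs, which nn.JoinTable(3) joins along the feature dimension); the in-code comment in __init gives the analogous recipe for table sequences:

-- concatenate instead of sum: output becomes seqlen x batchsize x (2*outputsize)
local brnn2 = nn.BiSequencer(nn.SeqLSTM(inputsize, outputsize), nil, nn.JoinTable(3))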