nn.rnn.recursive* -> nn.utils.recursive*
Nicholas Leonard committed May 2, 2017
1 parent 43ec8c2 commit 587be96
Showing 10 changed files with 279 additions and 230 deletions.
66 changes: 33 additions & 33 deletions BiSequencerLM.lua
@@ -1,25 +1,25 @@
------------------------------------------------------------------------
--[[ BiSequencerLM ]]--
-- Encapsulates forward, backward and merge modules.
-- Input is a sequence (a table) of tensors.
-- Output is a sequence (a table) of tensors of the same length.
-- Applies a `fwd` rnn instance to the first `N-1` elements in the
-- sequence in forward order.
-- Applies the `bwd` rnn in reverse order to the last `N-1` elements
-- (from second-to-last element to first element).
-- Note : you shouldn't stack these for language modeling.
-- Instead, stack each fwd/bwd seqs and encapsulate these.
------------------------------------------------------------------------
local _ = require 'moses'
local BiSequencerLM, parent = torch.class('nn.BiSequencerLM', 'nn.AbstractSequencer')

function BiSequencerLM:__init(forward, backward, merge)

if not torch.isTypeOf(forward, 'nn.Module') then
error"BiSequencerLM: expecting nn.Module instance at arg 1"
end
self.forwardModule = forward

self.backwardModule = backward
if not self.backwardModule then
self.backwardModule = forward:clone()
@@ -28,7 +28,7 @@ function BiSequencerLM:__init(forward, backward, merge)
if not torch.isTypeOf(self.backwardModule, 'nn.Module') then
error"BiSequencerLM: expecting nn.Module instance at arg 2"
end

if torch.type(merge) == 'number' then
self.mergeModule = nn.JoinTable(1, merge)
elseif merge == nil then
@@ -38,36 +38,36 @@ function BiSequencerLM:__init(forward, backward, merge)
else
error"BiSequencerLM: expecting nn.Module or number instance at arg 3"
end

if torch.isTypeOf(self.forwardModule, 'nn.AbstractRecurrent') then
self.fwdSeq = nn.Sequencer(self.forwardModule)
else -- assumes a nn.Sequencer or stack thereof
self.fwdSeq = self.forwardModule
end

if torch.isTypeOf(self.backwardModule, 'nn.AbstractRecurrent') then
self.bwdSeq = nn.Sequencer(self.backwardModule)
else
self.bwdSeq = self.backwardModule
end
self.mergeSeq = nn.Sequencer(self.mergeModule)

self._fwd = self.fwdSeq

self._bwd = nn.Sequential()
self._bwd:add(nn.ReverseTable())
self._bwd:add(self.bwdSeq)
self._bwd:add(nn.ReverseTable())

self._merge = nn.Sequential()
self._merge:add(nn.ZipTable())
self._merge:add(self.mergeSeq)


parent.__init(self)

self.modules = {self._fwd, self._bwd, self._merge}

self.output = {}
self.gradInput = {}
end
@@ -76,64 +76,64 @@ function BiSequencerLM:updateOutput(input)
assert(torch.type(input) == 'table', 'Expecting table at arg 1')
local nStep = #input
assert(nStep > 1, "Expecting at least 2 elements in table")

-- forward through fwd and bwd rnn in fwd and reverse order
self._fwdOutput = self._fwd:updateOutput(_.first(input, nStep - 1))
self._bwdOutput = self._bwd:updateOutput(_.last(input, nStep - 1))

-- empty outputs
for k,v in ipairs(self.output) do self.output[k] = nil end

-- padding for first and last elements of fwd and bwd outputs, respectively
-   self._firstStep = nn.rnn.recursiveResizeAs(self._firstStep, self._fwdOutput[1])
-   nn.rnn.recursiveFill(self._firstStep, 0)
-   self._lastStep = nn.rnn.recursiveResizeAs(self._lastStep, self._bwdOutput[1])
-   nn.rnn.recursiveFill(self._lastStep, 0)
+   self._firstStep = nn.utils.recursiveResizeAs(self._firstStep, self._fwdOutput[1])
+   nn.utils.recursiveFill(self._firstStep, 0)
+   self._lastStep = nn.utils.recursiveResizeAs(self._lastStep, self._bwdOutput[1])
+   nn.utils.recursiveFill(self._lastStep, 0)

-- { { zeros, fwd1, fwd2, ..., fwdN}, {bwd1, bwd2, ..., bwdN, zeros} }
self._mergeInput = {_.clone(self._fwdOutput), _.clone(self._bwdOutput)}
table.insert(self._mergeInput[1], 1, self._firstStep)
table.insert(self._mergeInput[2], self._lastStep)
assert(#self._mergeInput[1] == #self._mergeInput[2])

self.output = self._merge:updateOutput(self._mergeInput)

return self.output
end

function BiSequencerLM:updateGradInput(input, gradOutput)
local nStep = #input

self._mergeGradInput = self._merge:updateGradInput(self._mergeInput, gradOutput)
self._fwdGradInput = self._fwd:updateGradInput(_.first(input, nStep - 1), _.last(self._mergeGradInput[1], nStep - 1))
self._bwdGradInput = self._bwd:updateGradInput(_.last(input, nStep - 1), _.first(self._mergeGradInput[2], nStep - 1))

-- add fwd rnn gradInputs to bwd rnn gradInputs
for i=1,nStep do
if i == 1 then
self.gradInput[1] = self._fwdGradInput[1]
elseif i == nStep then
self.gradInput[nStep] = self._bwdGradInput[nStep-1]
else
-   self.gradInput[i] = nn.rnn.recursiveCopy(self.gradInput[i], self._fwdGradInput[i])
-   nn.rnn.recursiveAdd(self.gradInput[i], self._bwdGradInput[i-1])
+   self.gradInput[i] = nn.utils.recursiveCopy(self.gradInput[i], self._fwdGradInput[i])
+   nn.utils.recursiveAdd(self.gradInput[i], self._bwdGradInput[i-1])
end
end

return self.gradInput
end

function BiSequencerLM:accGradParameters(input, gradOutput, scale)
local nStep = #input

self._merge:accGradParameters(self._mergeInput, gradOutput, scale)
self._fwd:accGradParameters(_.first(input, nStep - 1), _.last(self._mergeGradInput[1], nStep - 1), scale)
self._bwd:accGradParameters(_.last(input, nStep - 1), _.first(self._mergeGradInput[2], nStep - 1), scale)
end

function BiSequencerLM:accUpdateGradParameters(input, gradOutput, lr)
local nStep = #input

self._merge:accUpdateGradParameters(self._mergeInput, gradOutput, lr)
self._fwd:accUpdateGradParameters(_.first(input, nStep - 1), _.last(self._mergeGradInput[1], nStep - 1), lr)
self._bwd:accUpdateGradParameters(_.last(input, nStep - 1), _.first(self._mergeGradInput[2], nStep - 1), lr)
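For context, a minimal usage sketch of BiSequencerLM. This is not part of the commit: the sizes and the choice of nn.GRU are illustrative, and it assumes the default merge concatenates the forward and backward features.

require 'rnn'

local inputSize, hiddenSize, batchSize, seqlen = 4, 5, 2, 3

-- fwd is cloned internally for the backward direction when arg 2 is nil
local fwd = nn.GRU(inputSize, hiddenSize)
local blm = nn.BiSequencerLM(fwd)

-- input: a table (sequence) of seqlen tensors of size batchSize x inputSize
local input = {}
for t = 1, seqlen do input[t] = torch.randn(batchSize, inputSize) end

local output = blm:forward(input)
print(#output)          -- seqlen: one output per input step
print(output[1]:size()) -- batchSize x 2*hiddenSize if the default merge concatenates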
4 changes: 2 additions & 2 deletions FireModule.lua
@@ -16,7 +16,7 @@ function FireModule:__init(nInputPlane, s1x1, e1x1, e3x3, activation)
print('Warning: <FireModule> s1x1 is recommended to be smaller'..
' then e1x1+e3x3')
end

self.module = nn.Sequential()
self.squeeze = nn.SpatialConvolution(nInputPlane, s1x1, 1, 1)
self.expand = nn.Concat(2)
@@ -28,7 +28,7 @@ function FireModule:__init(nInputPlane, s1x1, e1x1, e3x3, activation)
self.module:add(nn[self.activation]())
self.module:add(self.expand)
self.module:add(nn[self.activation]())

Parent.__init(self, self.module)
end

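As a reminder, FireModule is a SqueezeNet-style fire module: a 1x1 squeeze convolution followed by concatenated 1x1 and 3x3 expand convolutions. A hedged usage sketch follows; the sizes are arbitrary and the explicit 'ReLU' argument simply matches the constructor signature shown above.

require 'rnn'

-- nInputPlane=96, s1x1=16, e1x1=64, e3x3=64
local fire = nn.FireModule(96, 16, 64, 64, 'ReLU')

local x = torch.randn(1, 96, 32, 32)
local y = fire:forward(x)
print(y:size()) -- presumably 1 x 128 x 32 x 32 (e1x1 + e3x3 feature maps)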
4 changes: 2 additions & 2 deletions GRU.lua
@@ -154,8 +154,8 @@ function GRU:_updateGradInput(input, gradOutput)
-- backward propagate through this step
local _gradOutput = self:getGradHiddenState(step, input)
assert(_gradOutput)
-   self._gradOutputs[step] = nn.rnn.recursiveCopy(self._gradOutputs[step], _gradOutput)
-   nn.rnn.recursiveAdd(self._gradOutputs[step], gradOutput)
+   self._gradOutputs[step] = nn.utils.recursiveCopy(self._gradOutputs[step], _gradOutput)
+   nn.utils.recursiveAdd(self._gradOutputs[step], gradOutput)
gradOutput = self._gradOutputs[step]

local gradInputTable = stepmodule:updateGradInput({input, self:getHiddenState(step-1)}, gradOutput)
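The rename is purely mechanical: the recursive table helpers now live under nn.utils instead of nn.rnn. A small sketch of what recursiveCopy and recursiveAdd do on nested tables of tensors, as used in the hunks above; the concrete values are only illustrative.

require 'rnn'

-- leaves are tensors; the helpers recurse into nested tables
local dst = {torch.zeros(2), {torch.zeros(3)}}
local src = {torch.ones(2), {torch.ones(3)}}

dst = nn.utils.recursiveCopy(dst, src)  -- resize dst leaves as needed and copy src into them
nn.utils.recursiveAdd(dst, src)         -- in-place: dst leaf = dst leaf + src leaf
print(dst[1])                           -- each element is now 2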
4 changes: 2 additions & 2 deletions Recurrence.lua
@@ -102,8 +102,8 @@ function Recurrence:_updateGradInput(input, gradOutput)

-- backward propagate through this step
local _gradOutput = self:getGradHiddenState(step, input)[1]
-   self._gradOutputs[step] = nn.rnn.recursiveCopy(self._gradOutputs[step], _gradOutput)
-   nn.rnn.recursiveAdd(self._gradOutputs[step], gradOutput)
+   self._gradOutputs[step] = nn.utils.recursiveCopy(self._gradOutputs[step], _gradOutput)
+   nn.utils.recursiveAdd(self._gradOutputs[step], gradOutput)
gradOutput = self._gradOutputs[step]

local gradInputTable = stepmodule:updateGradInput({input, self:getHiddenState(step-1)[1]}, gradOutput)
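nn.Recurrence, whose backward pass is touched above, wraps a step module that maps {input[t], output[t-1]} to output[t]. A minimal sketch of building a vanilla RNN this way; the module choices and sizes are assumptions for illustration.

require 'rnn'

local inputSize, hiddenSize = 10, 10

-- step module: output[t] = tanh(W*input[t] + U*output[t-1])
local stepmodule = nn.Sequential()
   :add(nn.ParallelTable()
      :add(nn.Linear(inputSize, hiddenSize))   -- W
      :add(nn.Linear(hiddenSize, hiddenSize))) -- U
   :add(nn.CAddTable())
   :add(nn.Tanh())

-- outputSize = hiddenSize, nInputDim = 1 (dimensionality of a non-batched input)
local rnn = nn.Recurrence(stepmodule, hiddenSize, 1)
local seq = nn.Sequencer(rnn)

local input = {torch.randn(10), torch.randn(10), torch.randn(10)}
print(#seq:forward(input)) -- 3 outputs, one per step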
56 changes: 28 additions & 28 deletions RecurrentAttention.lua
@@ -1,10 +1,10 @@
------------------------------------------------------------------------
--[[ RecurrentAttention ]]--
-- Ref. A. http://papers.nips.cc/paper/5542-recurrent-models-of-visual-attention.pdf
-- B. http://incompleteideas.net/sutton/williams-92.pdf
-- module which takes an RNN as argument with other
-- hyper-parameters such as the maximum number of steps,
-- action (actions sampling module like ReinforceNormal) and
------------------------------------------------------------------------
local RecurrentAttention, parent = torch.class("nn.RecurrentAttention", "nn.AbstractSequencer")

@@ -14,33 +14,33 @@ function RecurrentAttention:__init(rnn, action, nStep, hiddenSize)
assert(torch.type(nStep) == 'number')
assert(torch.type(hiddenSize) == 'table')
assert(torch.type(hiddenSize[1]) == 'number', "Does not support table hidden layers" )

self.rnn = rnn
-- we can decorate the module with a Recursor to make it AbstractRecurrent
self.rnn = (not torch.isTypeOf(rnn, 'nn.AbstractRecurrent')) and nn.Recursor(rnn) or rnn

-- samples an x,y actions for each example
self.action = (not torch.isTypeOf(action, 'nn.AbstractRecurrent')) and nn.Recursor(action) or action
self.hiddenSize = hiddenSize
self.nStep = nStep

self.modules = {self.rnn, self.action}

self.output = {} -- rnn output
self.actions = {} -- action output

self.forwardActions = false

self.gradHidden = {}
end

function RecurrentAttention:updateOutput(input)
self.rnn:forget()
self.action:forget()
local nDim = input:dim()

for step=1,self.nStep do

if step == 1 then
-- sample an initial starting actions by forwarding zeros through the action
self._initInput = self._initInput or input.new()
@@ -50,20 +50,20 @@ function RecurrentAttention:updateOutput(input)
-- sample actions from previous hidden activation (rnn output)
self.actions[step] = self.action:updateOutput(self.output[step-1])
end

-- rnn handles the recurrence internally
local output = self.rnn:updateOutput{input, self.actions[step]}
self.output[step] = self.forwardActions and {output, self.actions[step]} or output
end

return self.output
end

function RecurrentAttention:updateGradInput(input, gradOutput)
assert(self.rnn.step - 1 == self.nStep, "inconsistent rnn steps")
assert(torch.type(gradOutput) == 'table', "expecting gradOutput table")
assert(#gradOutput == self.nStep, "gradOutput should have nStep elements")

-- back-propagate through time (BPTT)
for step=self.nStep,1,-1 do
-- 1. backward through the action layer
@@ -78,22 +78,22 @@ function RecurrentAttention:updateGradInput(input, gradOutput)
end
gradAction_ = self._gradAction
end

if step == self.nStep then
-      self.gradHidden[step] = nn.rnn.recursiveCopy(self.gradHidden[step], gradOutput_)
+      self.gradHidden[step] = nn.utils.recursiveCopy(self.gradHidden[step], gradOutput_)
else
-- gradHidden = gradOutput + gradAction
-      nn.rnn.recursiveAdd(self.gradHidden[step], gradOutput_)
+      nn.utils.recursiveAdd(self.gradHidden[step], gradOutput_)
end

if step == 1 then
-- backward through initial starting actions
self.action:updateGradInput(self._initInput, gradAction_)
else
local gradAction = self.action:updateGradInput(self.output[step-1], gradAction_)
-      self.gradHidden[step-1] = nn.rnn.recursiveCopy(self.gradHidden[step-1], gradAction)
+      self.gradHidden[step-1] = nn.utils.recursiveCopy(self.gradHidden[step-1], gradAction)
end

-- 2. backward through the rnn layer
local gradInput = self.rnn:updateGradInput({input, self.actions[step]}, self.gradHidden[step])[1]
if step == self.nStep then
@@ -110,19 +110,19 @@ function RecurrentAttention:accGradParameters(input, gradOutput, scale)
assert(self.rnn.step - 1 == self.nStep, "inconsistent rnn steps")
assert(torch.type(gradOutput) == 'table', "expecting gradOutput table")
assert(#gradOutput == self.nStep, "gradOutput should have nStep elements")

-- back-propagate through time (BPTT)
for step=self.nStep,1,-1 do
-- 1. backward through the action layer
local gradAction_ = self.forwardActions and gradOutput[step][2] or self._gradAction

if step == 1 then
-- backward through initial starting actions
self.action:accGradParameters(self._initInput, gradAction_, scale)
else
self.action:accGradParameters(self.output[step-1], gradAction_, scale)
end

-- 2. backward through the rnn layer
self.rnn:accGradParameters({input, self.actions[step]}, self.gradHidden[step], scale)
end
@@ -132,20 +132,20 @@ function RecurrentAttention:accUpdateGradParameters(input, gradOutput, lr)
assert(self.rnn.step - 1 == self.nStep, "inconsistent rnn steps")
assert(torch.type(gradOutput) == 'table', "expecting gradOutput table")
assert(#gradOutput == self.nStep, "gradOutput should have nStep elements")

-- backward through the action layers
for step=self.nStep,1,-1 do
-- 1. backward through the action layer
local gradAction_ = self.forwardActions and gradOutput[step][2] or self._gradAction

if step == 1 then
-- backward through initial starting actions
self.action:accUpdateGradParameters(self._initInput, gradAction_, lr)
else
-- Note : gradOutput is ignored by REINFORCE modules so we give action.output as a dummy variable
self.action:accUpdateGradParameters(self.output[step-1], gradAction_, lr)
end

-- 2. backward through the rnn layer
self.rnn:accUpdateGradParameters({input, self.actions[step]}, self.gradHidden[step], lr)
end
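For orientation, a forward-only sketch of how RecurrentAttention is wired: the rnn argument consumes {input, action} at each step, and the action module samples the next action from the previous rnn output. Everything below is an assumption for illustration: a toy feed-forward "glimpse" stands in for a real glimpse sensor, nn.ReinforceNormal is assumed to be available from the dpnn side of the package, and training would additionally need a REINFORCE reward, which is omitted.

require 'rnn'

local inputSize, hiddenSize, nStep = 16, 8, 5

-- consumes {input, action (x,y)} and returns a hiddenSize vector per example
local rnn = nn.Sequential()
   :add(nn.ParallelTable()
      :add(nn.Linear(inputSize, hiddenSize))
      :add(nn.Linear(2, hiddenSize)))
   :add(nn.CAddTable())
   :add(nn.Tanh())

-- samples an (x,y) action from the previous rnn output
local action = nn.Sequential()
   :add(nn.Linear(hiddenSize, 2))
   :add(nn.HardTanh())
   :add(nn.ReinforceNormal(0.1))

local ram = nn.RecurrentAttention(rnn, action, nStep, {hiddenSize})

local output = ram:forward(torch.randn(4, inputSize)) -- batch of 4
print(#output) -- nStep hidden states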
6 changes: 3 additions & 3 deletions Repeater.lua
@@ -25,7 +25,7 @@ function Repeater:updateOutput(input)
self.module:forget()
-- TODO make copy outputs optional
for step=1,self.seqlen do
-      self.output[step] = nn.rnn.recursiveCopy(self.output[step], self.module:updateOutput(input))
+      self.output[step] = nn.utils.recursiveCopy(self.output[step], self.module:updateOutput(input))
end
return self.output
end
@@ -39,9 +39,9 @@ function Repeater:updateGradInput(input, gradOutput)
for step=self.seqlen,1,-1 do
local gradInput = self.module:updateGradInput(input, gradOutput[step])
if step == self.seqlen then
-         self.gradInput = nn.rnn.recursiveCopy(self.gradInput, gradInput)
+         self.gradInput = nn.utils.recursiveCopy(self.gradInput, gradInput)
else
-         nn.rnn.recursiveAdd(self.gradInput, gradInput)
+         nn.utils.recursiveAdd(self.gradInput, gradInput)
end
end

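Repeater feeds the same input into its recurrent module seqlen times and returns the table of outputs, which is why updateGradInput above accumulates the per-step gradients with recursiveCopy/recursiveAdd. A brief sketch, assuming the constructor takes the module followed by the number of repetitions.

require 'rnn'

local inputSize, hiddenSize, seqlen = 10, 10, 5

-- the wrapped module should be recurrent, otherwise every step yields the same output
local r = nn.Repeater(nn.GRU(inputSize, hiddenSize), seqlen)

local output = r:forward(torch.randn(2, inputSize)) -- one input, repeated seqlen times
print(#output) -- seqlen outputs of size 2 x hiddenSize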
4 changes: 2 additions & 2 deletions deprecated/LSTM.lua
@@ -251,8 +251,8 @@ function LSTM:_updateGradInput(input, gradOutput)
local _gradOutput, gradCell = gradHiddenState[1], gradHiddenState[2]
assert(_gradOutput and gradCell)

-   self._gradOutputs[step] = nn.rnn.recursiveCopy(self._gradOutputs[step], _gradOutput)
-   nn.rnn.recursiveAdd(self._gradOutputs[step], gradOutput)
+   self._gradOutputs[step] = nn.utils.recursiveCopy(self._gradOutputs[step], _gradOutput)
+   nn.utils.recursiveAdd(self._gradOutputs[step], gradOutput)
gradOutput = self._gradOutputs[step]

local inputTable = self:getHiddenState(step-1)
