Implements NormStabilizer layer and adds tests
Jonathan Uesato committed May 2, 2016
1 parent 4731643 commit 2cf512b
Showing 4 changed files with 199 additions and 1 deletion.
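
For orientation before the diffs: nn.NormStabilizer implements the norm-stabilization regularizer of arXiv:1511.08400 (see NormStabilizer.lua below) and, as exercised in the new test, sits after a recurrent layer with the whole stack wrapped in nn.Sequencer. A minimal usage sketch assembled from the new test; the sizes and the beta value are illustrative, not prescribed by the module:

   require 'rnn'

   -- sizes and beta taken from the test below
   local batchSize, inputSize, hiddenSize, rho, beta = 3, 4, 3, 2, 50.0

   -- recurrent layer followed by the norm stabilizer, applied over sequences
   local r = nn.Recurrent(
      hiddenSize, nn.Linear(inputSize, hiddenSize),
      nn.Linear(hiddenSize, hiddenSize), nn.Sigmoid(), rho)
   local rnn = nn.Sequencer(nn.Sequential():add(r):add(nn.NormStabilizer(beta)))

   -- forward a table of rho time steps, each holding a batch of inputs
   local inputs = {}
   for step = 1, rho do
      inputs[step] = torch.randn(batchSize, inputSize)
   end
   local outputs = rnn:forward(inputs)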
6 changes: 6 additions & 0 deletions CopyGrad.lua
@@ -0,0 +1,6 @@
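-- CopyGrad behaves like nn.Identity on the forward pass but, unlike Identity,
-- copies gradOutput into its own gradInput tensor rather than passing the
-- reference through, so NormStabilizer can add the regularizer's gradient to
-- gradInput without touching the shared gradOutput.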
local CopyGrad, _ = torch.class('nn.CopyGrad', 'nn.Identity')

function CopyGrad:updateGradInput(input, gradOutput)
   self.gradInput:resizeAs(gradOutput):copy(gradOutput)
   return self.gradInput
end
76 changes: 76 additions & 0 deletions NormStabilizer.lua
@@ -0,0 +1,76 @@
------------------------------------------------------------------------
--[[ Norm Stabilization]]
-- Regularizing RNNs by Stabilizing Activations
-- Ref. A: http://arxiv.org/abs/1511.08400
------------------------------------------------------------------------
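--
-- Implementation note (following the gradient computation in _updateGradInput
-- below and the criterion used in test/test.lua): the penalty being applied is
--    R = beta/batchSize * sum_t sum_i (||h[t][i]|| - ||h[t-1][i]||)^2
-- where h[t][i] is the hidden state of batch element i at time step t.
-- The forward pass leaves activations unchanged (CopyGrad is an identity);
-- during backprop the gradient
--    dR/dh[t][i] = 2*beta/batchSize
--                  * ((||h[t][i]|| - ||h[t-1][i]||) - (||h[t+1][i]|| - ||h[t][i]||))
--                  * h[t][i] / ||h[t][i]||
-- is added directly to the gradInput of the corresponding step module, with
-- the boundary terms dropped at the first and last steps.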

local NS, parent = torch.class("nn.NormStabilizer", "nn.AbstractRecurrent")

function NS:__init(beta, rho)
   parent.__init(self, rho or 9999)
   self.recurrentModule = nn.CopyGrad()
   self.beta = beta
end

function NS:_accGradParameters(input, gradOutput, scale)
   -- No parameters to update
end

function NS:updateOutput(input)
   local output
   if self.train ~= false then
      self:recycle()
      local recurrentModule = self:getStepModule(self.step)
      output = recurrentModule:updateOutput(input)
   else
      output = self.recurrentModule:updateOutput(input)
   end

   self.outputs[self.step] = output

   self.output = output
   self.step = self.step + 1
   self.gradPrevOutput = nil
   self.updateGradInputStep = nil
   self.accGradParametersStep = nil

   return self.output
end

function NS:_updateGradInput(input, gradOutput)
   -- First grab h[t] and h[t+1] :
   -- backward propagate through this step
   local gradInput = self.recurrentModule:updateGradInput(input, gradOutput)
   local curStep = self.updateGradInputStep-1
   local hiddenModule = self:getStepModule(curStep)
   local hiddenState = hiddenModule.output
   hiddenModule.gradInput = gradInput

   if curStep < self.step then
      local batchSize = hiddenState:size(1)
      if curStep > 1 then
         local prevHiddenModule = self:getStepModule(curStep - 1)
         local prevHiddenState = prevHiddenModule.output
         -- Add norm stabilizer cost function directly to respective CopyGrad.gradInput tensors
         for i=1,batchSize do
            local dRegdNorm = self.beta * 2 * (hiddenState[i]:norm()-prevHiddenState[i]:norm()) / batchSize
            local dNormdHid = torch.div(hiddenState[i], hiddenState[i]:norm())
            hiddenModule.gradInput[i]:add(torch.mul(dNormdHid, dRegdNorm))
         end
      end
      if curStep < self.step-1 then
         local nextHiddenModule = self:getStepModule(curStep + 1)
         local nextHiddenState = nextHiddenModule.output
         for i=1,batchSize do
            local dRegdNorm = self.beta * -2 * (nextHiddenState[i]:norm() - hiddenState[i]:norm()) / batchSize
            local dNormdHid = torch.div(hiddenState[i], hiddenState[i]:norm())
            hiddenModule.gradInput[i]:add(torch.mul(dNormdHid, dRegdNorm))
         end
      end
   end
   return hiddenModule.gradInput
end

function NS:__tostring__()
return "nn.NormStabilizer"
end
4 changes: 3 additions & 1 deletion init.lua
@@ -17,12 +17,13 @@ torch.include('rnn', 'Module.lua')
torch.include('rnn', 'Dropout.lua')

-- for testing:
-torch.include('rnn', 'test.lua')
+torch.include('rnn', 'test/test.lua')

-- support modules
torch.include('rnn', 'ZeroGrad.lua')
torch.include('rnn', 'LinearNoBias.lua')
torch.include('rnn', 'SAdd.lua')
+torch.include('rnn', 'CopyGrad.lua')

-- recurrent modules
torch.include('rnn', 'LookupTableMaskZero.lua')
@@ -35,6 +36,7 @@ torch.include('rnn', 'FastLSTM.lua')
torch.include('rnn', 'GRU.lua')
torch.include('rnn', 'Recursor.lua')
torch.include('rnn', 'Recurrence.lua')
+torch.include('rnn', 'NormStabilizer.lua')

-- sequencer modules
torch.include('rnn', 'AbstractSequencer.lua')
114 changes: 114 additions & 0 deletions test/test.lua
@@ -5011,6 +5011,120 @@ function rnntest.clearState()
end
end

function checkgrad(opfunc, x, eps)
   -- Function taken from the 'optim' package to avoid introducing a dependency
   -- https://github.com/torch/optim/blob/master/checkgrad.lua
   -- first, compute true gradient:
   local _,dC = opfunc(x)
   dC:resize(x:size())

   -- compute numeric approximations to the gradient via central differences:
   --    dC_est[i] = (C(x + eps*e_i) - C(x - eps*e_i)) / (2*eps)
   local eps = eps or 1e-7
   local dC_est = torch.Tensor():typeAs(dC):resizeAs(dC)
   for i = 1,dC:size(1) do
      x[i] = x[i] + eps
      local C1 = opfunc(x)
      x[i] = x[i] - 2 * eps
      local C2 = opfunc(x)
      x[i] = x[i] + eps
      dC_est[i] = (C1 - C2) / (2 * eps)
   end

   -- estimate error of gradient:
   local diff = torch.norm(dC - dC_est) / torch.norm(dC + dC_est)
   return diff,dC,dC_est
end

function rnntest.NormStabilizer()
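   -- Note: the helper criterion defined below adds the same norm-stabilizer
   -- penalty to the reported loss, so that checkgrad's numerical gradient of
   -- the objective matches the analytic gradient that nn.NormStabilizer adds
   -- during backprop.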
   local SequencerCriterion, parent = torch.class('nn.SequencerCriterionNormStab', 'nn.SequencerCriterion')

   function SequencerCriterion:__init(criterion, beta)
      parent.__init(self)
      self.criterion = criterion
      if torch.isTypeOf(criterion, 'nn.ModuleCriterion') then
         error("SequencerCriterion shouldn't decorate a ModuleCriterion. "..
            "Instead, try the other way around : "..
            "ModuleCriterion decorates a SequencerCriterion. "..
            "Its modules can also be similarly decorated with a Sequencer.")
      end
      self.clones = {}
      self.gradInput = {}
      self.beta = beta
   end

   function SequencerCriterion:updateOutput(inputTable, targetTable)
      self.output = 0
      for i,input in ipairs(inputTable) do
         local criterion = self:getStepCriterion(i)
         self.output = self.output + criterion:forward(input, targetTable[i])
         if i > 1 then
            local reg = 0
            for j=1,input:size(1) do
               reg = reg + ((input[j]:norm() - inputTable[i-1][j]:norm())^2)
            end
            self.output = self.output + self.beta * reg / input:size(1)
         end
      end
      return self.output
   end

   -- Make a simple RNN and training set to test gradients
   -- hyper-parameters
   batchSize = 3
   rho = 2
   hiddenSize = 3
   inputSize = 4
   lr = 0.1
   beta = 50.0

   -- build simple recurrent neural network
   local r = nn.Recurrent(
      hiddenSize, nn.Linear(inputSize, hiddenSize),
      nn.Linear(hiddenSize, hiddenSize), nn.Sigmoid(),
      rho
   )
   local rnn = nn.Sequential()
      :add(r)
      :add(nn.NormStabilizer(beta))

   rnn = nn.Sequencer(rnn)
   criterion = nn.SequencerCriterionNormStab(nn.MSECriterion(), beta)

   local iteration = 1
   params, gradParams = rnn:getParameters()

   while iteration < 100 do
      -- generate a random data point
      local inputs, targets = {}, {}
      for step=1,rho do
         inputs[step] = torch.randn(batchSize, inputSize)
         targets[step] = torch.randn(batchSize, hiddenSize)
      end

      -- set up closure
      function feval(params_new)
         if params ~= params_new then
            params:copy(params_new)
         end

         rnn:zeroGradParameters()
         local outputs = rnn:forward(inputs)
         local err = criterion:forward(outputs, targets)
         local gradOutputs = criterion:backward(outputs, targets)
         local gradInputs = rnn:backward(inputs, gradOutputs)
         return err, gradParams
      end

      -- compare numerical to analytic gradient
      local diff, dC, dC_est = checkgrad(feval, params, 1e-10)
      mytester:assert(diff < 1e-3, "Numerical gradient and analytic gradient do not match.")

      rnn:updateParameters(lr)

      iteration = iteration + 1
   end
end

function rnn.test(tests, benchmark_)
   mytester = torch.Tester()
   benchmark = benchmark_
