diff --git a/CopyGrad.lua b/CopyGrad.lua
new file mode 100644
index 0000000..4c0fc16
--- /dev/null
+++ b/CopyGrad.lua
@@ -0,0 +1,6 @@
+local CopyGrad, _ = torch.class('nn.CopyGrad', 'nn.Identity')
+
+function CopyGrad:updateGradInput(input, gradOutput)
+   self.gradInput:resizeAs(gradOutput):copy(gradOutput)
+   return self.gradInput
+end
diff --git a/NormStabilizer.lua b/NormStabilizer.lua
new file mode 100644
index 0000000..5bc25b5
--- /dev/null
+++ b/NormStabilizer.lua
@@ -0,0 +1,76 @@
+------------------------------------------------------------------------
+--[[ Norm Stabilization]]
+-- Regularizing RNNs by Stabilizing Activations
+-- Ref. A: http://arxiv.org/abs/1511.08400
+------------------------------------------------------------------------
+
+local NS, parent = torch.class("nn.NormStabilizer", "nn.AbstractRecurrent")
+
+function NS:__init(beta, rho)
+   parent.__init(self, rho or 9999)
+   self.recurrentModule = nn.CopyGrad()
+   self.beta = beta
+end
+
+function NS:_accGradParameters(input, gradOutput, scale)
+   -- No parameters to update
+end
+
+function NS:updateOutput(input)
+   local output
+   if self.train ~= false then
+      self:recycle()
+      local recurrentModule = self:getStepModule(self.step)
+      output = recurrentModule:updateOutput(input)
+   else
+      output = self.recurrentModule:updateOutput(input)
+   end
+
+   self.outputs[self.step] = output
+
+   self.output = output
+   self.step = self.step + 1
+   self.gradPrevOutput = nil
+   self.updateGradInputStep = nil
+   self.accGradParametersStep = nil
+
+   return self.output
+end
+
+function NS:_updateGradInput(input, gradOutput)
+   -- First grab h[t] and h[t+1] :
+   -- backward propagate through this step
+   local gradInput = self.recurrentModule:updateGradInput(input, gradOutput)
+   local curStep = self.updateGradInputStep-1
+   local hiddenModule = self:getStepModule(curStep)
+   local hiddenState = hiddenModule.output
+   hiddenModule.gradInput = gradInput
+
+   if curStep < self.step then
+      local batchSize = hiddenState:size(1)
+      if curStep > 1 then
+         local prevHiddenModule = self:getStepModule(curStep - 1)
+         local prevHiddenState = prevHiddenModule.output
+         -- Add norm stabilizer cost function directly to respective CopyGrad.gradInput tensors
+         for i=1,batchSize do
+            local dRegdNorm = self.beta * 2 * (hiddenState[i]:norm()-prevHiddenState[i]:norm()) / batchSize
+            local dNormdHid = torch.div(hiddenState[i], hiddenState[i]:norm())
+            hiddenModule.gradInput[i]:add(torch.mul(dNormdHid, dRegdNorm))
+         end
+      end
+      if curStep < self.step-1 then
+         local nextHiddenModule = self:getStepModule(curStep + 1)
+         local nextHiddenState = nextHiddenModule.output
+         for i=1,batchSize do
+            local dRegdNorm = self.beta * -2 * (nextHiddenState[i]:norm() - hiddenState[i]:norm()) / batchSize
+            local dNormdHid = torch.div(hiddenState[i], hiddenState[i]:norm())
+            hiddenModule.gradInput[i]:add(torch.mul(dNormdHid, dRegdNorm))
+         end
+      end
+   end
+   return hiddenModule.gradInput
+end
+
+function NS:__tostring__()
+   return "nn.NormStabilizer"
+end
diff --git a/init.lua b/init.lua
index 7f46eaf..acccf4b 100644
--- a/init.lua
+++ b/init.lua
@@ -17,12 +17,13 @@ torch.include('rnn', 'Module.lua')
 torch.include('rnn', 'Dropout.lua')
 
 -- for testing:
-torch.include('rnn', 'test.lua')
+torch.include('rnn', 'test/test.lua')
 
 -- support modules
 torch.include('rnn', 'ZeroGrad.lua')
 torch.include('rnn', 'LinearNoBias.lua')
 torch.include('rnn', 'SAdd.lua')
+torch.include('rnn', 'CopyGrad.lua')
 
 -- recurrent modules
 torch.include('rnn', 'LookupTableMaskZero.lua')
@@ -35,6 +36,7 @@ torch.include('rnn', 'FastLSTM.lua')
 torch.include('rnn', 'GRU.lua')
 torch.include('rnn', 'Recursor.lua')
 torch.include('rnn', 'Recurrence.lua')
+torch.include('rnn', 'NormStabilizer.lua')
 
 -- sequencer modules
 torch.include('rnn', 'AbstractSequencer.lua')
diff --git a/test/test.lua b/test/test.lua
index 59badf2..e6243db 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -5011,6 +5011,120 @@ function rnntest.clearState()
    end
 end
 
+function checkgrad(opfunc, x, eps)
+   -- Function taken from 'optim' package to avoid introducing dependency
+   -- https://github.com/torch/optim/blob/master/checkgrad.lua
+   -- first, compute true gradient:
+   local _,dC = opfunc(x)
+   dC:resize(x:size())
+
+   -- compute numeric approximations to gradient:
+   local eps = eps or 1e-7
+   local dC_est = torch.Tensor():typeAs(dC):resizeAs(dC)
+   for i = 1,dC:size(1) do
+      x[i] = x[i] + eps
+      local C1 = opfunc(x)
+      x[i] = x[i] - 2 * eps
+      local C2 = opfunc(x)
+      x[i] = x[i] + eps
+      dC_est[i] = (C1 - C2) / (2 * eps)
+   end
+
+   -- estimate error of gradient:
+   local diff = torch.norm(dC - dC_est) / torch.norm(dC + dC_est)
+   return diff,dC,dC_est
+end
+
+function rnntest.NormStabilizer()
+   local SequencerCriterion, parent = torch.class('nn.SequencerCriterionNormStab', 'nn.SequencerCriterion')
+
+   function SequencerCriterion:__init(criterion, beta)
+      parent.__init(self)
+      self.criterion = criterion
+      if torch.isTypeOf(criterion, 'nn.ModuleCriterion') then
+         error("SequencerCriterion shouldn't decorate a ModuleCriterion. "..
+               "Instead, try the other way around : "..
+               "ModuleCriterion decorates a SequencerCriterion. "..
+               "Its modules can also be similarly decorated with a Sequencer.")
+      end
+      self.clones = {}
+      self.gradInput = {}
+      self.beta = beta
+   end
+
+   function SequencerCriterion:updateOutput(inputTable, targetTable)
+      self.output = 0
+      for i,input in ipairs(inputTable) do
+         local criterion = self:getStepCriterion(i)
+         self.output = self.output + criterion:forward(input, targetTable[i])
+         if i > 1 then
+            local reg = 0
+            for j=1,input:size(1) do
+               reg = reg + ((input[j]:norm() - inputTable[i-1][j]:norm())^2)
+            end
+            self.output = self.output + self.beta * reg / input:size(1)
+         end
+      end
+      return self.output
+   end
+
+   -- Make a simple RNN and training set to test gradients
+   -- hyper-parameters
+   batchSize = 3
+   rho = 2
+   hiddenSize = 3
+   inputSize = 4
+   lr = 0.1
+   beta = 50.0
+
+   -- build simple recurrent neural network
+   local r = nn.Recurrent(
+      hiddenSize, nn.Linear(inputSize, hiddenSize),
+      nn.Linear(hiddenSize, hiddenSize), nn.Sigmoid(),
+      rho
+   )
+   local rnn = nn.Sequential()
+      :add(r)
+      :add(nn.NormStabilizer(beta))
+
+   rnn = nn.Sequencer(rnn)
+   criterion = nn.SequencerCriterionNormStab(nn.MSECriterion(), beta)
+
+   local iteration = 1
+   params, gradParams = rnn:getParameters()
+
+   while iteration < 100 do
+      -- generate a random data point
+      local inputs, targets = {}, {}
+      for step=1,rho do
+         inputs[step] = torch.randn(batchSize, inputSize)
+         targets[step] = torch.randn(batchSize, hiddenSize)
+      end
+
+      -- set up closure
+      function feval(params_new)
+         if params ~= params_new then
+            params:copy(params_new)
+         end
+
+         rnn:zeroGradParameters()
+         local outputs = rnn:forward(inputs)
+         local err = criterion:forward(outputs, targets)
+         local gradOutputs = criterion:backward(outputs, targets)
+         local gradInputs = rnn:backward(inputs, gradOutputs)
+         return err, gradParams
+      end
+
+      -- compare numerical to analytic gradient
+      local diff, dC, dC_est = checkgrad(feval, params, 1e-10)
+      mytester:assert(diff < 1e-3, "Numerical gradient and analytic gradient do not match.")
+
+      rnn:updateParameters(lr)
+
+      iteration = iteration + 1
+   end
+end
+
 function rnn.test(tests, benchmark_)
    mytester = torch.Tester()
    benchmark = benchmark_