GRU implementation.
jnhwkim committed Dec 3, 2015
1 parent fdc3b21 commit 01c89c6
Showing 4 changed files with 287 additions and 5 deletions.
247 changes: 247 additions & 0 deletions GRU.lua
@@ -0,0 +1,247 @@
------------------------------------------------------------------------
--[[ GRU ]]--
-- Gated Recurrent Units architecture.
-- http://www.wildml.com/2015/10/recurrent-neural-network-tutorial-part-4-implementing-a-gruGRU-rnn-with-python-and-theano/
-- Expects 1D or 2D input.
-- The first input in the sequence uses a zero value for the initial hidden state.
------------------------------------------------------------------------
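-- For reference, the update this module is assumed to implement, following the
-- tutorial linked above (W/U are the input/recurrent weight matrices, and "o"
-- denotes elementwise multiplication):
--   r(t) = sigmoid(W_r x(t) + U_r s(t-1))        -- reset gate
--   z(t) = sigmoid(W_z x(t) + U_z s(t-1))        -- update gate
--   h(t) = tanh(W_h x(t) + U_h (r(t) o s(t-1)))  -- candidate activation
--   s(t) = (1 - z(t)) o h(t) + z(t) o s(t-1)     -- new hidden state / output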
assert(not nn.GRU, "update nnx package : luarocks install nnx")
local GRU, parent = torch.class('nn.GRU', 'nn.AbstractRecurrent')

function GRU:__init(inputSize, outputSize, rho)
parent.__init(self, rho or 9999)
self.inputSize = inputSize
self.outputSize = outputSize
-- build the model
self.recurrentModule = self:buildModel()
-- make it work with nn.Container
self.modules[1] = self.recurrentModule
self.sharedClones[1] = self.recurrentModule

-- zero tensor used for output(0), i.e. the initial hidden state
self.zeroTensor = torch.Tensor()

self.cells = {}
self.gradCells = {}
end
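-- A minimal construction/forward sketch (illustrative only, not part of the API docs):
--   local gru = nn.GRU(100, 120)                    -- 100-dim input, 120-dim hidden state
--   local out = gru:forward(torch.randn(8, 100))    -- 2D (batched) input -> 8x120 output
--   out = gru:forward(torch.randn(8, 100))          -- next step reuses the previous output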

-------------------------- factory methods -----------------------------
function GRU:buildModel()
-- input : {input, prevOutput}
-- output : output(t)

-- Calculate both GRU gates (reset and update) in one go
self.i2g = nn.Linear(self.inputSize, 2*self.outputSize)
self.o2g = nn.LinearNoBias(self.outputSize, 2*self.outputSize)

local para = nn.ParallelTable():add(self.i2g):add(self.o2g)
local gates = nn.Sequential()
gates:add(para)
gates:add(nn.CAddTable())

-- Reshape to (batch_size, n_gates, hid_size)
-- Then slice the n_gates dimension, i.e. dimension 2
gates:add(nn.Reshape(2,self.outputSize))
gates:add(nn.SplitTable(1,2))
local transfer = nn.ParallelTable()
transfer:add(nn.Sigmoid()):add(nn.Sigmoid())
gates:add(transfer)

local concat = nn.ConcatTable()
concat:add(nn.Identity()):add(gates)
local seq = nn.Sequential()
seq:add(concat)
seq:add(nn.FlattenTable()) -- x(t), s(t-1), r, z

-- Rearrange to x(t), s(t-1), r, z, s(t-1)
local concat = nn.ConcatTable() --
concat:add(nn.NarrowTable(1,4)):add(nn.SelectTable(2))
seq:add(concat):add(nn.FlattenTable())

-- h
local hidden = nn.Sequential()
local concat = nn.ConcatTable()
local t1 = nn.Sequential()
t1:add(nn.SelectTable(1)):add(nn.Linear(self.inputSize, self.outputSize))
local t2 = nn.Sequential()
t2:add(nn.NarrowTable(2,2)):add(nn.CMulTable()):add(nn.LinearNoBias(self.outputSize, self.outputSize))
concat:add(t1):add(t2)
hidden:add(concat):add(nn.CAddTable()):add(nn.Tanh())

local z1 = nn.Sequential()
z1:add(nn.SelectTable(4))
z1:add(nn.SAdd(-1, true)) -- add -1 then negate, i.e. computes 1 - z

local z2 = nn.Sequential()
z2:add(nn.NarrowTable(4,2))
z2:add(nn.CMulTable())

local o1 = nn.Sequential()
local concat = nn.ConcatTable()
concat:add(hidden):add(z1)
o1:add(concat):add(nn.CMulTable())

local o2 = nn.Sequential()
local concat = nn.ConcatTable()
concat:add(o1):add(z2)
o2:add(concat):add(nn.CAddTable())

seq:add(o2)

return seq
end
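-- The module built above maps {input(t), s(t-1)} -> s(t). A quick shape check
-- (a sketch under the assumptions above, not part of the original code path):
--   local gru = nn.GRU(10, 20)
--   local s1 = gru.recurrentModule:forward{torch.randn(4, 10), torch.zeros(4, 20)}
--   -- s1 is a 4x20 tensor: the new hidden state, which is also the module output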

------------------------- forward backward -----------------------------
function GRU:updateOutput(input)
local prevOutput
if self.step == 1 then
prevOutput = self.userPrevOutput or self.zeroTensor
if input:dim() == 2 then
self.zeroTensor:resize(input:size(1), self.outputSize):zero()
else
self.zeroTensor:resize(self.outputSize):zero()
end
else
-- previous output of this module
prevOutput = self.output
end

-- output(t) = gru{input(t), output(t-1)}
local output
if self.train ~= false then
self:recycle()
local recurrentModule = self:getStepModule(self.step)
-- the actual forward propagation
output = recurrentModule:updateOutput{input, prevOutput}
else
output = self.recurrentModule:updateOutput{input, prevOutput}
end

if self.train ~= false then
local input_ = self.inputs[self.step]
self.inputs[self.step] = self.copyInputs
and nn.rnn.recursiveCopy(input_, input)
or nn.rnn.recursiveSet(input_, input)
end

self.outputs[self.step] = output

self.output = output

self.step = self.step + 1
self.gradPrevOutput = nil
self.updateGradInputStep = nil
self.accGradParametersStep = nil
self.gradParametersAccumulated = false
-- unlike the LSTM, there is no cell state; just return the output
return self.output
end

function GRU:backwardThroughTime(timeStep, rho)
assert(self.step > 1, "expecting at least one updateOutput")
self.gradInputs = {} -- used by Sequencer, Repeater
timeStep = timeStep or self.step
local rho = math.min(rho or self.rho, timeStep-1)
local stop = timeStep - rho

local gradInput
if self.fastBackward then
for step=timeStep-1,math.max(stop,1),-1 do
-- set the output/gradOutput states of current Module
local recurrentModule = self:getStepModule(step)

-- backward propagate through this step
local gradOutput = self.gradOutputs[step]
if self.gradPrevOutput then
self._gradOutputs[step] = nn.rnn.recursiveCopy(self._gradOutputs[step], self.gradPrevOutput)
nn.rnn.recursiveAdd(self._gradOutputs[step], gradOutput)
gradOutput = self._gradOutputs[step]
end

local scale = self.scales[step]
local output = (step == 1) and (self.userPrevOutput or self.zeroTensor) or self.outputs[step-1]
local inputTable = {self.inputs[step], output}
local gradInputTable = recurrentModule:backward(inputTable, gradOutput, scale)
gradInput, self.gradPrevOutput = unpack(gradInputTable)
table.insert(self.gradInputs, 1, gradInput)
if self.userPrevOutput then self.userGradPrevOutput = self.gradPrevOutput end
end
self.gradParametersAccumulated = true
return gradInput
else
local gradInput = self:updateGradInputThroughTime()
self:accGradParametersThroughTime()
return gradInput
end
end
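-- Illustration of the truncation above (assuming rho behaves as written): with
-- rho = 5 and timeStep = 12, rho is clamped to min(5, 12-1) = 5 and stop = 7,
-- so gradients flow back through steps 11, 10, 9, 8, 7 only.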

function GRU:updateGradInputThroughTime(timeStep, rho)
assert(self.step > 1, "expecting at least one updateOutput")
self.gradInputs = {}
local gradInput
timeStep = timeStep or self.step
local rho = math.min(rho or self.rho, timeStep-1)
local stop = timeStep - rho

for step=timeStep-1,math.max(stop,1),-1 do
-- set the output/gradOutput states of current Module
local recurrentModule = self:getStepModule(step)

-- backward propagate through this step
local gradOutput = self.gradOutputs[step]
if self.gradPrevOutput then
self._gradOutputs[step] = nn.rnn.recursiveCopy(self._gradOutputs[step], self.gradPrevOutput)
nn.rnn.recursiveAdd(self._gradOutputs[step], gradOutput)
gradOutput = self._gradOutputs[step]
end

local output = (step == 1) and (self.userPrevOutput or self.zeroTensor) or self.outputs[step-1]
local inputTable = {self.inputs[step], output}
local gradInputTable = recurrentModule:updateGradInput(inputTable, gradOutput)
gradInput, self.gradPrevOutput = unpack(gradInputTable)
table.insert(self.gradInputs, 1, gradInput)
if self.userPrevOutput then self.userGradPrevOutput = self.gradPrevOutput end
end

return gradInput
end

function GRU:accGradParametersThroughTime(timeStep, rho)
timeStep = timeStep or self.step
local rho = math.min(rho or self.rho, timeStep-1)
local stop = timeStep - rho

for step=timeStep-1,math.max(stop,1),-1 do
-- set the output/gradOutput states of current Module
local recurrentModule = self:getStepModule(step)

-- backward propagate through this step
local scale = self.scales[step]
local output = (step == 1) and (self.userPrevOutput or self.zeroTensor) or self.outputs[step-1]
local inputTable = {self.inputs[step], output}
local gradOutput = (step == self.step-1) and self.gradOutputs[step] or self._gradOutputs[step]
recurrentModule:accGradParameters(inputTable, gradOutput, scale)
end

self.gradParametersAccumulated = true
return gradInput
end

function GRU:accUpdateGradParametersThroughTime(lr, timeStep, rho)
timeStep = timeStep or self.step
local rho = math.min(rho or self.rho, timeStep-1)
local stop = timeStep - rho

for step=timeStep-1,math.max(stop,1),-1 do
-- set the output/gradOutput states of current Module
local recurrentModule = self:getStepModule(step)

-- backward propagate through this step
local scale = self.scales[step]
local output = (step == 1) and (self.userPrevOutput or self.zeroTensor) or self.outputs[step-1]
local inputTable = {self.inputs[step], output}
local gradOutput = (step == self.step-1) and self.gradOutputs[step] or self._gradOutputs[step]
recurrentModule:accUpdateGradParameters(inputTable, gradOutput, lr*scale) -- use the accumulated gradOutput computed above
end

return gradInput
end
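For context, a minimal usage sketch of the new module wrapped in nn.Sequencer (illustrative only; it assumes the rnn package conventions shown in the example script below):

require 'rnn'
local gru = nn.Sequencer(nn.GRU(100, 120))
local inputs = {}
for t = 1, 10 do inputs[t] = torch.randn(8, 100) end -- 10 time-steps, batch of 8
local outputs = gru:forward(inputs)                  -- table of 10 tensors, each 8x120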
29 changes: 29 additions & 0 deletions SAdd.lua
@@ -0,0 +1,29 @@
local SAdd, parent = torch.class('nn.SAdd', 'nn.Module')

function SAdd:__init(addend, negate)
parent.__init(self)

self.addend = addend
self.negate = (negate == nil) and false or negate
end

function SAdd:updateOutput(input)
self.output:resizeAs(input):copy(input)
self.output:add(self.addend)
if self.negate then
self.output:mul(-1)
end
return self.output
end

function SAdd:updateGradInput(input, gradOutput)
if self.gradInput then
self.gradInput:resizeAs(gradOutput):copy(gradOutput)
else
self.gradInput = torch.Tensor():resizeAs(gradOutput):copy(gradOutput)
end
if self.negate then
self.gradInput:mul(-1)
end
return self.gradInput
end
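nn.SAdd(addend, negate) adds a scalar and optionally negates the result; the GRU above uses nn.SAdd(-1, true) to turn the update gate z into 1 - z. A small sketch of the expected behaviour:

local m = nn.SAdd(-1, true)
print(m:forward(torch.Tensor{0.2, 0.9})) -- -(x - 1) = 1 - x, i.e. 0.8 and 0.1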
14 changes: 9 additions & 5 deletions examples/recurrent-language-model.lua
@@ -6,7 +6,7 @@ version = 1
--[[command line arguments]]--
cmd = torch.CmdLine()
cmd:text()
cmd:text('Train a Language Model on PennTreeBank dataset using RNN or LSTM')
cmd:text('Train a Language Model on PennTreeBank dataset using RNN or LSTM or GRU')
cmd:text('Example:')
cmd:text("recurrent-language-model.lua --cuda --useDevice 2 --progress --zeroFirst --cutoffNorm 4 --rho 10")
cmd:text('Options:')
@@ -27,8 +27,9 @@ cmd:option('--uniform', 0.1, 'initialize parameters using uniform distribution b

-- recurrent layer
cmd:option('--lstm', false, 'use Long Short Term Memory (nn.LSTM instead of nn.Recurrent)')
cmd:option('--gru', false, 'use Gated Recurrent Units (nn.GRU instead of nn.Recurrent)')
cmd:option('--rho', 5, 'back-propagate through time (BPTT) for rho time-steps')
cmd:option('--hiddenSize', '{200}', 'number of hidden units used at output of each recurrent layer. When more than one is specified, RNN/LSTMs are stacked')
cmd:option('--hiddenSize', '{200}', 'number of hidden units used at output of each recurrent layer. When more than one is specified, RNN/LSTMs/GRUs are stacked')
cmd:option('--zeroFirst', false, 'first step will forward zero through recurrence (i.e. add bias of recurrence). As opposed to learning bias specifically for first step.')
cmd:option('--dropout', false, 'apply dropout after each recurrent layer')
cmd:option('--dropoutProb', 0.5, 'probability of zeroing a neuron (dropout probability)')
@@ -61,13 +62,16 @@ lm = nn.Sequential()
local inputSize = opt.hiddenSize[1]
for i,hiddenSize in ipairs(opt.hiddenSize) do

if i~= 1 and not opt.lstm then
if i~= 1 and (not opt.lstm) and (not opt.gru) then
lm:add(nn.Sequencer(nn.Linear(inputSize, hiddenSize)))
end

-- recurrent layer
local rnn
if opt.lstm then
if opt.gru then
-- Gated Recurrent Units
rnn = nn.Sequencer(nn.GRU(inputSize, hiddenSize))
elseif opt.lstm then
-- Long Short Term Memory
rnn = nn.Sequencer(nn.FastLSTM(inputSize, hiddenSize))
else
@@ -117,7 +121,7 @@ if opt.uniform > 0 then
end

-- will recurse a single continuous sequence
lm:remember(opt.lstm and 'both' or 'eval')
lm:remember((opt.lstm or opt.gru) and 'both' or 'eval')

--[[Propagators]]--

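With the new --gru flag, the example script can presumably be launched along the lines of (flags as defined above, values illustrative):

th examples/recurrent-language-model.lua --gru --hiddenSize '{200}' --rho 10 --progress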
2 changes: 2 additions & 0 deletions init.lua
@@ -24,6 +24,8 @@ torch.include('rnn', 'AbstractRecurrent.lua')
torch.include('rnn', 'Recurrent.lua')
torch.include('rnn', 'LSTM.lua')
torch.include('rnn', 'FastLSTM.lua')
torch.include('rnn', 'GRU.lua')
torch.include('rnn', 'SAdd.lua')
torch.include('rnn', 'Recursor.lua')
torch.include('rnn', 'Recurrence.lua')

