
Commit 01c89c6

GRU implementation.
1 parent fdc3b21 commit 01c89c6

File tree

  GRU.lua
  SAdd.lua
  examples/recurrent-language-model.lua
  init.lua

4 files changed, +287 -5 lines changed

GRU.lua

Lines changed: 247 additions & 0 deletions
@@ -0,0 +1,247 @@
------------------------------------------------------------------------
--[[ GRU ]]--
-- Gated Recurrent Units architecture.
-- http://www.wildml.com/2015/10/recurrent-neural-network-tutorial-part-4-implementing-a-grulstm-rnn-with-python-and-theano/
-- Expects 1D or 2D input.
-- The first input in the sequence uses a zero value for the previous hidden state.
------------------------------------------------------------------------
assert(not nn.GRU, "update nnx package : luarocks install nnx")
local GRU, parent = torch.class('nn.GRU', 'nn.AbstractRecurrent')

function GRU:__init(inputSize, outputSize, rho)
   parent.__init(self, rho or 9999)
   self.inputSize = inputSize
   self.outputSize = outputSize
   -- build the model
   self.recurrentModule = self:buildModel()
   -- make it work with nn.Container
   self.modules[1] = self.recurrentModule
   self.sharedClones[1] = self.recurrentModule

   -- zero tensor used as the initial hidden state, i.e. output(0)
   self.zeroTensor = torch.Tensor()

   -- GRU has no cell state; these tables stay empty
   self.cells = {}
   self.gradCells = {}
end

-------------------------- factory methods -----------------------------
function GRU:buildModel()
   -- input : {input, prevOutput}
   -- output : {output}

   -- Calculate the reset and update gates in one go
   self.i2g = nn.Linear(self.inputSize, 2*self.outputSize)
   self.o2g = nn.LinearNoBias(self.outputSize, 2*self.outputSize)

   local para = nn.ParallelTable():add(self.i2g):add(self.o2g)
   local gates = nn.Sequential()
   gates:add(para)
   gates:add(nn.CAddTable())

   -- Reshape to (batch_size, n_gates, hid_size)
   -- Then slice the n_gates dimension, i.e. dimension 2
   gates:add(nn.Reshape(2,self.outputSize))
   gates:add(nn.SplitTable(1,2))
   local transfer = nn.ParallelTable()
   transfer:add(nn.Sigmoid()):add(nn.Sigmoid())
   gates:add(transfer)

   local concat = nn.ConcatTable()
   concat:add(nn.Identity()):add(gates)
   local seq = nn.Sequential()
   seq:add(concat)
   seq:add(nn.FlattenTable()) -- x(t), s(t-1), r, z

   -- Rearrange to x(t), s(t-1), r, z, s(t-1)
   local concat = nn.ConcatTable()
   concat:add(nn.NarrowTable(1,4)):add(nn.SelectTable(2))
   seq:add(concat):add(nn.FlattenTable())

   -- candidate hidden state h
   local hidden = nn.Sequential()
   local concat = nn.ConcatTable()
   local t1 = nn.Sequential()
   t1:add(nn.SelectTable(1)):add(nn.Linear(self.inputSize, self.outputSize))
   local t2 = nn.Sequential()
   t2:add(nn.NarrowTable(2,2)):add(nn.CMulTable()):add(nn.LinearNoBias(self.outputSize, self.outputSize))
   concat:add(t1):add(t2)
   hidden:add(concat):add(nn.CAddTable()):add(nn.Tanh())

   local z1 = nn.Sequential()
   z1:add(nn.SelectTable(4))
   z1:add(nn.SAdd(-1, true)) -- Scalar add & negate : (1 - z)

   local z2 = nn.Sequential()
   z2:add(nn.NarrowTable(4,2))
   z2:add(nn.CMulTable())

   local o1 = nn.Sequential()
   local concat = nn.ConcatTable()
   concat:add(hidden):add(z1)
   o1:add(concat):add(nn.CMulTable())

   local o2 = nn.Sequential()
   local concat = nn.ConcatTable()
   concat:add(o1):add(z2)
   o2:add(concat):add(nn.CAddTable())

   seq:add(o2)

   return seq
end

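For reference, the graph built above corresponds to the usual GRU update from the linked tutorial. The sketch below is not part of the commit; it spells out the same per-step computation on plain tensors for a single 1D input, with weight matrices Wr, Ur, Wz, Uz, Wh, Uh as illustrative stand-ins for the parameters held by i2g, o2g and the two inner Linear/LinearNoBias modules.

-- Reference sketch only: the per-step update that buildModel wires up.
-- x : inputSize ; sPrev : outputSize
-- Wr, Wz, Wh : outputSize x inputSize ; Ur, Uz, Uh : outputSize x outputSize
require 'torch'

local function gruStep(x, sPrev, Wr, Ur, Wz, Uz, Wh, Uh)
   local r = torch.sigmoid(torch.mv(Wr, x) + torch.mv(Ur, sPrev))            -- reset gate
   local z = torch.sigmoid(torch.mv(Wz, x) + torch.mv(Uz, sPrev))            -- update gate
   -- candidate hidden state, computed from the reset-gated previous state
   local h = torch.tanh(torch.mv(Wh, x) + torch.mv(Uh, torch.cmul(r, sPrev)))
   -- s(t) = (1 - z) .* h + z .* s(t-1) ; (1 - z) is what nn.SAdd(-1, true) produces
   local oneMinusZ = z:clone():mul(-1):add(1)
   return torch.cmul(h, oneMinusZ) + torch.cmul(sPrev, z)
end
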
------------------------- forward backward -----------------------------
function GRU:updateOutput(input)
   local prevOutput
   if self.step == 1 then
      prevOutput = self.userPrevOutput or self.zeroTensor
      if input:dim() == 2 then
         self.zeroTensor:resize(input:size(1), self.outputSize):zero()
      else
         self.zeroTensor:resize(self.outputSize):zero()
      end
   else
      -- previous output of this module
      prevOutput = self.output
   end

   -- output(t) = gru{input(t), output(t-1)}
   local output
   if self.train ~= false then
      self:recycle()
      local recurrentModule = self:getStepModule(self.step)
      -- the actual forward propagation
      output = recurrentModule:updateOutput{input, prevOutput}
   else
      output = self.recurrentModule:updateOutput{input, prevOutput}
   end

   if self.train ~= false then
      local input_ = self.inputs[self.step]
      self.inputs[self.step] = self.copyInputs
         and nn.rnn.recursiveCopy(input_, input)
         or nn.rnn.recursiveSet(input_, input)
   end

   self.outputs[self.step] = output

   self.output = output

   self.step = self.step + 1
   self.gradPrevOutput = nil
   self.updateGradInputStep = nil
   self.accGradParametersStep = nil
   self.gradParametersAccumulated = false
   -- unlike LSTM, there is no cell state; only the output is returned
   return self.output
end

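A quick way to exercise updateOutput directly (a minimal sketch, assuming the rnn package with this commit installed): feed one time-step per call; the first call uses the zero hidden state, and forget() resets the module before a new sequence. The sizes here are illustrative.

require 'rnn'

local gru = nn.GRU(10, 20)
local out
for t = 1, 5 do
   local x = torch.randn(10)     -- 1D input; a 2D (batch x inputSize) tensor also works
   out = gru:forward(x)          -- 20-dim output, fed back as output(t-1) at the next step
end
print(out:size())                -- 20
gru:forget()                     -- reset the step counter/state before a new sequence
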
function GRU:backwardThroughTime(timeStep, rho)
   assert(self.step > 1, "expecting at least one updateOutput")
   self.gradInputs = {} -- used by Sequencer, Repeater
   timeStep = timeStep or self.step
   local rho = math.min(rho or self.rho, timeStep-1)
   local stop = timeStep - rho

   if self.fastBackward then
      local gradInput
      for step=timeStep-1,math.max(stop,1),-1 do
         -- set the output/gradOutput states of current Module
         local recurrentModule = self:getStepModule(step)

         -- backward propagate through this step
         local gradOutput = self.gradOutputs[step]
         if self.gradPrevOutput then
            self._gradOutputs[step] = nn.rnn.recursiveCopy(self._gradOutputs[step], self.gradPrevOutput)
            nn.rnn.recursiveAdd(self._gradOutputs[step], gradOutput)
            gradOutput = self._gradOutputs[step]
         end

         local scale = self.scales[step]
         local output = (step == 1) and (self.userPrevOutput or self.zeroTensor) or self.outputs[step-1]
         local inputTable = {self.inputs[step], output}
         local gradInputTable = recurrentModule:backward(inputTable, gradOutput, scale)
         gradInput, self.gradPrevOutput = unpack(gradInputTable)
         table.insert(self.gradInputs, 1, gradInput)
         if self.userPrevOutput then self.userGradPrevOutput = self.gradPrevOutput end
      end
      self.gradParametersAccumulated = true
      return gradInput
   else
      local gradInput = self:updateGradInputThroughTime()
      self:accGradParametersThroughTime()
      return gradInput
   end
end

function GRU:updateGradInputThroughTime(timeStep, rho)
   assert(self.step > 1, "expecting at least one updateOutput")
   self.gradInputs = {}
   local gradInput
   timeStep = timeStep or self.step
   local rho = math.min(rho or self.rho, timeStep-1)
   local stop = timeStep - rho

   for step=timeStep-1,math.max(stop,1),-1 do
      -- set the output/gradOutput states of current Module
      local recurrentModule = self:getStepModule(step)

      -- backward propagate through this step
      local gradOutput = self.gradOutputs[step]
      if self.gradPrevOutput then
         self._gradOutputs[step] = nn.rnn.recursiveCopy(self._gradOutputs[step], self.gradPrevOutput)
         nn.rnn.recursiveAdd(self._gradOutputs[step], gradOutput)
         gradOutput = self._gradOutputs[step]
      end

      local output = (step == 1) and (self.userPrevOutput or self.zeroTensor) or self.outputs[step-1]
      local inputTable = {self.inputs[step], output}
      local gradInputTable = recurrentModule:updateGradInput(inputTable, gradOutput)
      gradInput, self.gradPrevOutput = unpack(gradInputTable)
      table.insert(self.gradInputs, 1, gradInput)
      if self.userPrevOutput then self.userGradPrevOutput = self.gradPrevOutput end
   end

   return gradInput
end

function GRU:accGradParametersThroughTime(timeStep, rho)
   timeStep = timeStep or self.step
   local rho = math.min(rho or self.rho, timeStep-1)
   local stop = timeStep - rho

   for step=timeStep-1,math.max(stop,1),-1 do
      -- set the output/gradOutput states of current Module
      local recurrentModule = self:getStepModule(step)

      -- backward propagate through this step
      local scale = self.scales[step]
      local output = (step == 1) and (self.userPrevOutput or self.zeroTensor) or self.outputs[step-1]
      local inputTable = {self.inputs[step], output}
      local gradOutput = (step == self.step-1) and self.gradOutputs[step] or self._gradOutputs[step]
      recurrentModule:accGradParameters(inputTable, gradOutput, scale)
   end

   self.gradParametersAccumulated = true
end

function GRU:accUpdateGradParametersThroughTime(lr, timeStep, rho)
   timeStep = timeStep or self.step
   local rho = math.min(rho or self.rho, timeStep-1)
   local stop = timeStep - rho

   for step=timeStep-1,math.max(stop,1),-1 do
      -- set the output/gradOutput states of current Module
      local recurrentModule = self:getStepModule(step)

      -- backward propagate through this step
      local scale = self.scales[step]
      local output = (step == 1) and (self.userPrevOutput or self.zeroTensor) or self.outputs[step-1]
      local inputTable = {self.inputs[step], output}
      local gradOutput = (step == self.step-1) and self.gradOutputs[step] or self._gradOutputs[step]
      recurrentModule:accUpdateGradParameters(inputTable, gradOutput, lr*scale)
   end
end
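Not part of the diff, but for context, a minimal usage sketch of the new module wrapped in nn.Sequencer, which handles the per-step calls and the backward pass through time; sizes and variable names are illustrative.

require 'rnn'

local inputSize, outputSize, seqLen, batchSize = 10, 20, 5, 3
local gru = nn.Sequencer(nn.GRU(inputSize, outputSize))

-- a sequence is a table of time-steps; each step is a batchSize x inputSize tensor
local inputs = {}
for t = 1, seqLen do
   inputs[t] = torch.randn(batchSize, inputSize)
end

local outputs = gru:forward(inputs)   -- table of seqLen tensors, each batchSize x outputSize
print(#outputs, outputs[1]:size())
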

SAdd.lua

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
local SAdd, parent = torch.class('nn.SAdd', 'nn.Module')

function SAdd:__init(addend, negate)
   parent.__init(self)

   self.addend = addend
   self.negate = (negate == nil) and false or negate
end

function SAdd:updateOutput(input)
   self.output:resizeAs(input):copy(input)
   self.output = self.output + self.addend
   if self.negate then
      self.output = -self.output
   end
   return self.output
end

function SAdd:updateGradInput(input, gradOutput)
   if self.gradInput then
      self.gradInput:resizeAs(gradOutput):copy(gradOutput)
   else
      self.gradInput = torch.Tensor():resizeAs(gradOutput):copy(gradOutput)
   end
   if self.negate then
      self.gradInput = -self.gradInput
   end
   return self.gradInput
end
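As a quick illustration (not in the commit): nn.SAdd(-1, true) maps x to -(x + (-1)) = 1 - x, which is how the GRU graph above turns the update gate z into (1 - z).

require 'rnn'   -- this commit makes nn.SAdd available through init.lua

local m = nn.SAdd(-1, true)
print(m:forward(torch.Tensor{0.2, 0.9}))   -- 0.8  0.1 : elementwise 1 - x
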

examples/recurrent-language-model.lua

Lines changed: 9 additions & 5 deletions
@@ -6,7 +6,7 @@ version = 1
 --[[command line arguments]]--
 cmd = torch.CmdLine()
 cmd:text()
-cmd:text('Train a Language Model on PennTreeBank dataset using RNN or LSTM')
+cmd:text('Train a Language Model on PennTreeBank dataset using RNN or LSTM or GRU')
 cmd:text('Example:')
 cmd:text("recurrent-language-model.lua --cuda --useDevice 2 --progress --zeroFirst --cutoffNorm 4 --rho 10")
 cmd:text('Options:')
@@ -27,8 +27,9 @@ cmd:option('--uniform', 0.1, 'initialize parameters using uniform distribution b

 -- recurrent layer
 cmd:option('--lstm', false, 'use Long Short Term Memory (nn.LSTM instead of nn.Recurrent)')
+cmd:option('--gru', false, 'use Gated Recurrent Units (nn.GRU instead of nn.Recurrent)')
 cmd:option('--rho', 5, 'back-propagate through time (BPTT) for rho time-steps')
-cmd:option('--hiddenSize', '{200}', 'number of hidden units used at output of each recurrent layer. When more than one is specified, RNN/LSTMs are stacked')
+cmd:option('--hiddenSize', '{200}', 'number of hidden units used at output of each recurrent layer. When more than one is specified, RNN/LSTMs/GRUs are stacked')
 cmd:option('--zeroFirst', false, 'first step will forward zero through recurrence (i.e. add bias of recurrence). As opposed to learning bias specifically for first step.')
 cmd:option('--dropout', false, 'apply dropout after each recurrent layer')
 cmd:option('--dropoutProb', 0.5, 'probability of zeroing a neuron (dropout probability)')
@@ -61,13 +62,16 @@ lm = nn.Sequential()
 local inputSize = opt.hiddenSize[1]
 for i,hiddenSize in ipairs(opt.hiddenSize) do

-   if i~= 1 and not opt.lstm then
+   if i~= 1 and (not opt.lstm) and (not opt.gru) then
       lm:add(nn.Sequencer(nn.Linear(inputSize, hiddenSize)))
    end

    -- recurrent layer
    local rnn
-   if opt.lstm then
+   if opt.gru then
+      -- Gated Recurrent Units
+      rnn = nn.Sequencer(nn.GRU(inputSize, hiddenSize))
+   elseif opt.lstm then
       -- Long Short Term Memory
       rnn = nn.Sequencer(nn.FastLSTM(inputSize, hiddenSize))
    else
@@ -117,7 +121,7 @@ if opt.uniform > 0 then
 end

 -- will recurse a single continuous sequence
-lm:remember(opt.lstm and 'both' or 'eval')
+lm:remember((opt.lstm or opt.gru) and 'both' or 'eval')

 --[[Propagators]]--

init.lua

Lines changed: 2 additions & 0 deletions
@@ -24,6 +24,8 @@ torch.include('rnn', 'AbstractRecurrent.lua')
 torch.include('rnn', 'Recurrent.lua')
 torch.include('rnn', 'LSTM.lua')
 torch.include('rnn', 'FastLSTM.lua')
+torch.include('rnn', 'GRU.lua')
+torch.include('rnn', 'SAdd.lua')
 torch.include('rnn', 'Recursor.lua')
 torch.include('rnn', 'Recurrence.lua')
