------------------------------------------------------------------------
--[[ GRU ]]--
-- Gated Recurrent Units architecture.
-- http://www.wildml.com/2015/10/recurrent-neural-network-tutorial-part-4-implementing-a-grulstm-rnn-with-python-and-theano/
-- Expects 1D or 2D input.
-- The first input in a sequence uses a zero value for the hidden state.
------------------------------------------------------------------------
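-- The recurrence built by buildModel() below (standard GRU equations,
-- where [*] denotes the element-wise product):
--   r(t) = sigmoid(W_r*x(t) + U_r*s(t-1))          -- reset gate
--   z(t) = sigmoid(W_z*x(t) + U_z*s(t-1))          -- update gate
--   h(t) = tanh(W_h*x(t) + U_h*(r(t) [*] s(t-1)))  -- candidate hidden state
--   s(t) = (1 - z(t)) [*] h(t) + z(t) [*] s(t-1)   -- output / hidden state
------------------------------------------------------------------------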
assert(not nn.GRU, "update nnx package : luarocks install nnx")
local GRU, parent = torch.class('nn.GRU', 'nn.AbstractRecurrent')

function GRU:__init(inputSize, outputSize, rho)
   parent.__init(self, rho or 9999)
   self.inputSize = inputSize
   self.outputSize = outputSize
   -- build the model
   self.recurrentModule = self:buildModel()
   -- make it work with nn.Container
   self.modules[1] = self.recurrentModule
   self.sharedClones[1] = self.recurrentModule

   -- zero tensor used as output(0), i.e. the initial hidden state
   self.zeroTensor = torch.Tensor()

   self.cells = {}
   self.gradCells = {}
end
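
-- A minimal usage sketch (sizes are hypothetical; nn.Sequencer is the usual
-- way to apply the module over a whole table of time-steps):
--   local gru = nn.GRU(100, 120)         -- inputSize=100, outputSize=120
--   local rnn = nn.Sequencer(gru)
--   local outputs = rnn:forward(inputs)  -- inputs : table of 1D or 2D tensors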

------------------------- factory methods -------------------------------
function GRU:buildModel()
   -- input : {input, prevOutput}
   -- output : {output}

   -- Calculate both gates (reset r and update z) in one go
   self.i2g = nn.Linear(self.inputSize, 2*self.outputSize)
   self.o2g = nn.LinearNoBias(self.outputSize, 2*self.outputSize)

   local para = nn.ParallelTable():add(self.i2g):add(self.o2g)
   local gates = nn.Sequential()
   gates:add(para)
   gates:add(nn.CAddTable())

   -- Reshape to (batch_size, n_gates, hid_size)
   -- Then slice the n_gates dimension, i.e. dimension 2
   gates:add(nn.Reshape(2, self.outputSize))
   gates:add(nn.SplitTable(1, 2))
   local transfer = nn.ParallelTable()
   transfer:add(nn.Sigmoid()):add(nn.Sigmoid())
   gates:add(transfer)
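   -- at this point `gates` maps {x(t), s(t-1)} to {r(t), z(t)},
   -- i.e. the reset and update gates, each of size outputSize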

   local concat = nn.ConcatTable()
   concat:add(nn.Identity()):add(gates)
   local seq = nn.Sequential()
   seq:add(concat)
   seq:add(nn.FlattenTable()) -- x(t), s(t-1), r, z

   -- Rearrange to x(t), s(t-1), r, z, s(t-1)
   local concat = nn.ConcatTable()
   concat:add(nn.NarrowTable(1,4)):add(nn.SelectTable(2))
   seq:add(concat):add(nn.FlattenTable())

   -- h
   local hidden = nn.Sequential()
   local concat = nn.ConcatTable()
   local t1 = nn.Sequential()
   t1:add(nn.SelectTable(1)):add(nn.Linear(self.inputSize, self.outputSize))
   local t2 = nn.Sequential()
   t2:add(nn.NarrowTable(2,2)):add(nn.CMulTable()):add(nn.LinearNoBias(self.outputSize, self.outputSize))
   concat:add(t1):add(t2)
   hidden:add(concat):add(nn.CAddTable()):add(nn.Tanh())
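   -- `hidden` computes the candidate hidden state:
   --   h(t) = tanh(W_h*x(t) + U_h*(r(t) [*] s(t-1)))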

   local z1 = nn.Sequential()
   z1:add(nn.SelectTable(4))
   z1:add(nn.SAdd(-1, true)) -- Scalar add & negation

   local z2 = nn.Sequential()
   z2:add(nn.NarrowTable(4,2))
   z2:add(nn.CMulTable())
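   -- z1 computes (1 - z(t)) and z2 computes z(t) [*] s(t-1)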

   local o1 = nn.Sequential()
   local concat = nn.ConcatTable()
   concat:add(hidden):add(z1)
   o1:add(concat):add(nn.CMulTable())

   local o2 = nn.Sequential()
   local concat = nn.ConcatTable()
   concat:add(o1):add(z2)
   o2:add(concat):add(nn.CAddTable())

   seq:add(o2)
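   -- o2 therefore outputs s(t) = (1 - z(t)) [*] h(t) + z(t) [*] s(t-1)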

   return seq
end

------------------------- forward backward -------------------------------
function GRU:updateOutput(input)
   local prevOutput
   if self.step == 1 then
      prevOutput = self.userPrevOutput or self.zeroTensor
      if input:dim() == 2 then
         self.zeroTensor:resize(input:size(1), self.outputSize):zero()
      else
         self.zeroTensor:resize(self.outputSize):zero()
      end
   else
      -- previous output of this module
      prevOutput = self.output
   end

   -- output(t) = gru{input(t), output(t-1)}
   local output
   if self.train ~= false then
      self:recycle()
      local recurrentModule = self:getStepModule(self.step)
      -- the actual forward propagation
      output = recurrentModule:updateOutput{input, prevOutput}
   else
      output = self.recurrentModule:updateOutput{input, prevOutput}
   end

   if self.train ~= false then
      local input_ = self.inputs[self.step]
      self.inputs[self.step] = self.copyInputs
         and nn.rnn.recursiveCopy(input_, input)
         or nn.rnn.recursiveSet(input_, input)
   end

   self.outputs[self.step] = output

   self.output = output

   self.step = self.step + 1
   self.gradPrevOutput = nil
   self.updateGradInputStep = nil
   self.accGradParametersStep = nil
   self.gradParametersAccumulated = false
   return self.output
end

function GRU:backwardThroughTime(timeStep, rho)
   assert(self.step > 1, "expecting at least one updateOutput")
   self.gradInputs = {} -- used by Sequencer, Repeater
   timeStep = timeStep or self.step
   local rho = math.min(rho or self.rho, timeStep-1)
   local stop = timeStep - rho

   if self.fastBackward then
      local gradInput
      for step=timeStep-1,math.max(stop,1),-1 do
         -- set the output/gradOutput states of current Module
         local recurrentModule = self:getStepModule(step)

         -- backward propagate through this step
         local gradOutput = self.gradOutputs[step]
         if self.gradPrevOutput then
            self._gradOutputs[step] = nn.rnn.recursiveCopy(self._gradOutputs[step], self.gradPrevOutput)
            nn.rnn.recursiveAdd(self._gradOutputs[step], gradOutput)
            gradOutput = self._gradOutputs[step]
         end

         local scale = self.scales[step]
         local output = (step == 1) and (self.userPrevOutput or self.zeroTensor) or self.outputs[step-1]
         local inputTable = {self.inputs[step], output}
         local gradInputTable = recurrentModule:backward(inputTable, gradOutput, scale)
         gradInput, self.gradPrevOutput = unpack(gradInputTable)
         table.insert(self.gradInputs, 1, gradInput)
         if self.userPrevOutput then self.userGradPrevOutput = self.gradPrevOutput end
      end
      self.gradParametersAccumulated = true
      return gradInput
   else
      local gradInput = self:updateGradInputThroughTime()
      self:accGradParametersThroughTime()
      return gradInput
   end
end

function GRU:updateGradInputThroughTime(timeStep, rho)
   assert(self.step > 1, "expecting at least one updateOutput")
   self.gradInputs = {}
   local gradInput
   timeStep = timeStep or self.step
   local rho = math.min(rho or self.rho, timeStep-1)
   local stop = timeStep - rho

   for step=timeStep-1,math.max(stop,1),-1 do
      -- set the output/gradOutput states of current Module
      local recurrentModule = self:getStepModule(step)

      -- backward propagate through this step
      local gradOutput = self.gradOutputs[step]
      if self.gradPrevOutput then
         self._gradOutputs[step] = nn.rnn.recursiveCopy(self._gradOutputs[step], self.gradPrevOutput)
         nn.rnn.recursiveAdd(self._gradOutputs[step], gradOutput)
         gradOutput = self._gradOutputs[step]
      end

      local output = (step == 1) and (self.userPrevOutput or self.zeroTensor) or self.outputs[step-1]
      local inputTable = {self.inputs[step], output}
      local gradInputTable = recurrentModule:updateGradInput(inputTable, gradOutput)
      gradInput, self.gradPrevOutput = unpack(gradInputTable)
      table.insert(self.gradInputs, 1, gradInput)
      if self.userPrevOutput then self.userGradPrevOutput = self.gradPrevOutput end
   end

   return gradInput
end

function GRU:accGradParametersThroughTime(timeStep, rho)
   timeStep = timeStep or self.step
   local rho = math.min(rho or self.rho, timeStep-1)
   local stop = timeStep - rho

   for step=timeStep-1,math.max(stop,1),-1 do
      -- set the output/gradOutput states of current Module
      local recurrentModule = self:getStepModule(step)

      -- backward propagate through this step
      local scale = self.scales[step]
      local output = (step == 1) and (self.userPrevOutput or self.zeroTensor) or self.outputs[step-1]
      local inputTable = {self.inputs[step], output}
      local gradOutput = (step == self.step-1) and self.gradOutputs[step] or self._gradOutputs[step]
      recurrentModule:accGradParameters(inputTable, gradOutput, scale)
   end

   self.gradParametersAccumulated = true
   return gradInput
end

function GRU:accUpdateGradParametersThroughTime(lr, timeStep, rho)
   timeStep = timeStep or self.step
   local rho = math.min(rho or self.rho, timeStep-1)
   local stop = timeStep - rho

   for step=timeStep-1,math.max(stop,1),-1 do
      -- set the output/gradOutput states of current Module
      local recurrentModule = self:getStepModule(step)

      -- backward propagate through this step
      local scale = self.scales[step]
      local output = (step == 1) and (self.userPrevOutput or self.zeroTensor) or self.outputs[step-1]
      local inputTable = {self.inputs[step], output}
      local gradOutput = (step == self.step-1) and self.gradOutputs[step] or self._gradOutputs[step]
      recurrentModule:accUpdateGradParameters(inputTable, gradOutput, lr*scale)
   end

   return gradInput
end