------------------------------------------------------------------------
--[[ GRU ]]--
-- Gated Recurrent Units architecture.
-- http://www.wildml.com/2015/10/recurrent-neural-network-tutorial-part-4-implementing-a-grulstm-rnn-with-python-and-theano/
-- Expects 1D or 2D input.
-- The first input in the sequence uses a zero value for the hidden state.
------------------------------------------------------------------------
assert(not nn.GRU, "update nnx package : luarocks install nnx")
local GRU, parent = torch.class('nn.GRU', 'nn.AbstractRecurrent')

function GRU:__init(inputSize, outputSize, rho)
   parent.__init(self, rho or 9999)
   self.inputSize = inputSize
   self.outputSize = outputSize
   -- build the model
   self.recurrentModule = self:buildModel()
   -- make it work with nn.Container
   self.modules[1] = self.recurrentModule
   self.sharedClones[1] = self.recurrentModule

   -- for output(0), cell(0) and gradCell(T)
   self.zeroTensor = torch.Tensor()

   self.cells = {}
   self.gradCells = {}
end
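
-- A minimal usage sketch (illustrative only : the sizes, sequence length and
-- the nn.Sequencer wrapper below are assumptions, not part of this module) :
--
--   require 'rnn'
--   local gru = nn.GRU(100, 120)          -- inputSize=100, outputSize=120
--   local seq = nn.Sequencer(gru)         -- applies the GRU over a table of time-steps
--   local inputs = {}
--   for t = 1, 8 do inputs[t] = torch.randn(32, 100) end  -- 8 steps, batch of 32
--   local outputs = seq:forward(inputs)   -- table of 8 tensors of size 32x120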

-------------------------- factory methods -----------------------------
function GRU:buildModel()
   -- input : {input, prevOutput}
   -- output : {output}

   -- Calculate both gates (reset r and update z) in one go
   self.i2g = nn.Linear(self.inputSize, 2*self.outputSize)
   self.o2g = nn.LinearNoBias(self.outputSize, 2*self.outputSize)

   local para = nn.ParallelTable():add(self.i2g):add(self.o2g)
   local gates = nn.Sequential()
   gates:add(para)
   gates:add(nn.CAddTable())

   -- Reshape to (batch_size, n_gates, hid_size)
   -- Then slice the n_gates dimension, i.e. dimension 2
   gates:add(nn.Reshape(2,self.outputSize))
   gates:add(nn.SplitTable(1,2))
   local transfer = nn.ParallelTable()
   transfer:add(nn.Sigmoid()):add(nn.Sigmoid())
   gates:add(transfer)

   local concat = nn.ConcatTable()
   concat:add(nn.Identity()):add(gates)
   local seq = nn.Sequential()
   seq:add(concat)
   seq:add(nn.FlattenTable()) -- x(t), s(t-1), r, z

   -- Rearrange to x(t), s(t-1), r, z, s(t-1)
   local concat = nn.ConcatTable()
   concat:add(nn.NarrowTable(1,4)):add(nn.SelectTable(2))
   seq:add(concat):add(nn.FlattenTable())

   -- h : candidate hidden state
   local hidden = nn.Sequential()
   local concat = nn.ConcatTable()
   local t1 = nn.Sequential()
   t1:add(nn.SelectTable(1)):add(nn.Linear(self.inputSize, self.outputSize))
   local t2 = nn.Sequential()
   t2:add(nn.NarrowTable(2,2)):add(nn.CMulTable()):add(nn.LinearNoBias(self.outputSize, self.outputSize))
   concat:add(t1):add(t2)
   hidden:add(concat):add(nn.CAddTable()):add(nn.Tanh())

   local z1 = nn.Sequential()
   z1:add(nn.SelectTable(4))
   z1:add(nn.SAdd(-1, true)) -- Scalar add & negation, i.e. (1 - z)

   local z2 = nn.Sequential()
   z2:add(nn.NarrowTable(4,2))
   z2:add(nn.CMulTable())

   local o1 = nn.Sequential()
   local concat = nn.ConcatTable()
   concat:add(hidden):add(z1)
   o1:add(concat):add(nn.CMulTable())

   local o2 = nn.Sequential()
   local concat = nn.ConcatTable()
   concat:add(o1):add(z2)
   o2:add(concat):add(nn.CAddTable())

   seq:add(o2)

   return seq
end
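
-- For reference, the graph assembled above computes the usual GRU update
-- (notation follows the tutorial linked in the header; s(t-1) is the previous
-- output, "." is elementwise multiplication, bias terms omitted) :
--
--   r    = sigmoid(W_r x(t) + U_r s(t-1))      -- reset gate  (first slice of i2g/o2g)
--   z    = sigmoid(W_z x(t) + U_z s(t-1))      -- update gate (second slice of i2g/o2g)
--   h    = tanh(W_h x(t) + U_h (r . s(t-1)))   -- candidate hidden state (t1 + t2)
--   s(t) = (1 - z) . h + z . s(t-1)            -- o1 + z2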

------------------------- forward backward -----------------------------
function GRU:updateOutput(input)
   local prevOutput
   if self.step == 1 then
      prevOutput = self.userPrevOutput or self.zeroTensor
      if input:dim() == 2 then
         self.zeroTensor:resize(input:size(1), self.outputSize):zero()
      else
         self.zeroTensor:resize(self.outputSize):zero()
      end
   else
      -- previous output of this module
      prevOutput = self.output
   end

   -- output(t) = gru{input(t), output(t-1)}
   local output
   if self.train ~= false then
      self:recycle()
      local recurrentModule = self:getStepModule(self.step)
      -- the actual forward propagation
      output = recurrentModule:updateOutput{input, prevOutput}
   else
      output = self.recurrentModule:updateOutput{input, prevOutput}
   end

   if self.train ~= false then
      local input_ = self.inputs[self.step]
      self.inputs[self.step] = self.copyInputs
         and nn.rnn.recursiveCopy(input_, input)
         or nn.rnn.recursiveSet(input_, input)
   end

   self.outputs[self.step] = output

   self.output = output

   self.step = self.step + 1
   self.gradPrevOutput = nil
   self.updateGradInputStep = nil
   self.accGradParametersStep = nil
   self.gradParametersAccumulated = false
   -- note that there is no cell state; only the output is returned
   return self.output
end
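
-- Sketch of stepping the module directly, one time-step at a time (illustrative
-- only; the inputs table and sizes are assumptions). nn.Sequencer performs this
-- loop internally :
--
--   local gru = nn.GRU(10, 20)
--   for t = 1, #inputs do
--      local output = gru:forward(inputs[t])  -- output(t) = gru{input(t), output(t-1)}
--   end
--   gru:forget()  -- reset self.step and stored states before the next sequence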

function GRU:backwardThroughTime(timeStep, rho)
   assert(self.step > 1, "expecting at least one updateOutput")
   self.gradInputs = {} -- used by Sequencer, Repeater
   timeStep = timeStep or self.step
   local rho = math.min(rho or self.rho, timeStep-1)
   local stop = timeStep - rho

   if self.fastBackward then
      local gradInput
      for step=timeStep-1,math.max(stop,1),-1 do
         -- set the output/gradOutput states of current Module
         local recurrentModule = self:getStepModule(step)

         -- backward propagate through this step
         local gradOutput = self.gradOutputs[step]
         if self.gradPrevOutput then
            self._gradOutputs[step] = nn.rnn.recursiveCopy(self._gradOutputs[step], self.gradPrevOutput)
            nn.rnn.recursiveAdd(self._gradOutputs[step], gradOutput)
            gradOutput = self._gradOutputs[step]
         end

         local scale = self.scales[step]
         local output = (step == 1) and (self.userPrevOutput or self.zeroTensor) or self.outputs[step-1]
         local inputTable = {self.inputs[step], output}
         local gradInputTable = recurrentModule:backward(inputTable, gradOutput, scale)
         gradInput, self.gradPrevOutput = unpack(gradInputTable)
         table.insert(self.gradInputs, 1, gradInput)
         if self.userPrevOutput then self.userGradPrevOutput = self.gradPrevOutput end
      end
      self.gradParametersAccumulated = true
      return gradInput
   else
      local gradInput = self:updateGradInputThroughTime()
      self:accGradParametersThroughTime()
      return gradInput
   end
end
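
-- Note on truncation : with timeStep = 10 and rho = 5, stop = 10 - 5 = 5 and the
-- BPTT loops in this and the following methods run from step 9 down to step 5,
-- i.e. gradients flow back through at most rho time-steps.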

function GRU:updateGradInputThroughTime(timeStep, rho)
   assert(self.step > 1, "expecting at least one updateOutput")
   self.gradInputs = {}
   local gradInput
   timeStep = timeStep or self.step
   local rho = math.min(rho or self.rho, timeStep-1)
   local stop = timeStep - rho

   for step=timeStep-1,math.max(stop,1),-1 do
      -- set the output/gradOutput states of current Module
      local recurrentModule = self:getStepModule(step)

      -- backward propagate through this step
      local gradOutput = self.gradOutputs[step]
      if self.gradPrevOutput then
         self._gradOutputs[step] = nn.rnn.recursiveCopy(self._gradOutputs[step], self.gradPrevOutput)
         nn.rnn.recursiveAdd(self._gradOutputs[step], gradOutput)
         gradOutput = self._gradOutputs[step]
      end

      local output = (step == 1) and (self.userPrevOutput or self.zeroTensor) or self.outputs[step-1]
      local inputTable = {self.inputs[step], output}
      local gradInputTable = recurrentModule:updateGradInput(inputTable, gradOutput)
      gradInput, self.gradPrevOutput = unpack(gradInputTable)
      table.insert(self.gradInputs, 1, gradInput)
      if self.userPrevOutput then self.userGradPrevOutput = self.gradPrevOutput end
   end

   return gradInput
end

function GRU:accGradParametersThroughTime(timeStep, rho)
   timeStep = timeStep or self.step
   local rho = math.min(rho or self.rho, timeStep-1)
   local stop = timeStep - rho

   for step=timeStep-1,math.max(stop,1),-1 do
      -- set the output/gradOutput states of current Module
      local recurrentModule = self:getStepModule(step)

      -- backward propagate through this step
      local scale = self.scales[step]
      local output = (step == 1) and (self.userPrevOutput or self.zeroTensor) or self.outputs[step-1]
      local inputTable = {self.inputs[step], output}
      local gradOutput = (step == self.step-1) and self.gradOutputs[step] or self._gradOutputs[step]
      recurrentModule:accGradParameters(inputTable, gradOutput, scale)
   end

   self.gradParametersAccumulated = true
end

function GRU:accUpdateGradParametersThroughTime(lr, timeStep, rho)
   timeStep = timeStep or self.step
   local rho = math.min(rho or self.rho, timeStep-1)
   local stop = timeStep - rho

   for step=timeStep-1,math.max(stop,1),-1 do
      -- set the output/gradOutput states of current Module
      local recurrentModule = self:getStepModule(step)

      -- backward propagate through this step
      local scale = self.scales[step]
      local output = (step == 1) and (self.userPrevOutput or self.zeroTensor) or self.outputs[step-1]
      local inputTable = {self.inputs[step], output}
      -- use the accumulated gradOutput, as in accGradParametersThroughTime above
      local gradOutput = (step == self.step-1) and self.gradOutputs[step] or self._gradOutputs[step]
      recurrentModule:accUpdateGradParameters(inputTable, gradOutput, lr*scale)
   end
end