-
Notifications
You must be signed in to change notification settings - Fork 32
/
Copy pathsru.py
396 lines (346 loc) · 17.2 KB
/
sru.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
from __future__ import absolute_import
import numpy as np
from keras import backend as K
from keras import activations
from keras import initializers
from keras import regularizers
from keras import constraints
from keras.engine import InputSpec
from keras.legacy import interfaces
from keras.layers import Recurrent
def _time_distributed_dense(x, w, b=None, dropout=None,
                            input_dim=None, output_dim=None,
                            timesteps=None, training=None):
    """Apply the affine map `y . w + b` to every temporal slice `y` of `x`.

    The batch and time axes are folded into one so a single matrix product
    covers the whole sequence; the result is then unfolded back to 3D.

    # Arguments
        x: input tensor, indexed as `x[:, t, :]`, i.e.
            (batch, timesteps, input_dim).
        w: weight matrix of shape (input_dim, output_dim).
        b: optional bias vector; skipped when None.
        dropout: optional drop fraction. When strictly between 0 and 1, a
            single dropout mask is sampled per sample and reused for every
            timestep; it is only applied in the training phase.
        input_dim: integer; taken from `K.shape(x)[2]` when falsy.
        output_dim: integer; taken from `K.int_shape(w)[1]` when falsy.
        timesteps: integer; taken from `K.shape(x)[1]` when falsy.
        training: training-phase tensor or boolean passed to
            `K.in_train_phase`.

    # Returns
        Tensor of shape (batch, timesteps, output_dim).
    """
    input_dim = input_dim or K.shape(x)[2]
    timesteps = timesteps or K.shape(x)[1]
    output_dim = output_dim or K.int_shape(w)[1]

    if dropout is not None and 0. < dropout < 1.:
        # Build the mask from the first timestep, then tile it along time so
        # every slice of a given sample sees the same dropout pattern.
        first_step = K.reshape(x[:, 0, :], (-1, input_dim))
        step_mask = K.dropout(K.ones_like(first_step), dropout)
        sequence_mask = K.repeat(step_mask, timesteps)
        x = K.in_train_phase(x * sequence_mask, x, training=training)

    # Fold (batch, timesteps) into one axis and apply the affine map.
    flat = K.dot(K.reshape(x, (-1, input_dim)), w)
    if b is not None:
        flat = K.bias_add(flat, b)

    # Unfold back to (batch, timesteps, output_dim).
    if K.backend() != 'tensorflow':
        return K.reshape(flat, (-1, timesteps, output_dim))
    unfolded = K.reshape(flat, K.stack([-1, timesteps, output_dim]))
    unfolded.set_shape([None, None, output_dim])
    return unfolded
class SRU(Recurrent):
    """Simple Recurrent Unit - https://arxiv.org/pdf/1709.02755.pdf.

    At every timestep, with `x` the current input slice and `c_prev` the
    previous cell state, the layer computes:

        x_w = x . W
        f   = recurrent_activation(x . W_f + b_f)    # forget gate
        r   = recurrent_activation(x . W_r + b_r)    # reset gate
        c   = f * c_prev + (1 - f) * x_w
        h   = r * activation(c) + (1 - r) * x_hw

    `x_hw` (the highway term) is `x` itself when `project_input` is False and
    the input dimension equals `units`; otherwise it is `x . W_p`, a learned
    projection stored as a fourth slice of the kernel. The states carried
    between timesteps are `[h, c]`, but only `c` feeds back into the
    computation: there is no hidden-to-hidden weight matrix.

    # Arguments
        units: Positive integer, dimensionality of the output space.
        activation: Activation function applied to the cell state
            (see [activations](../activations.md)).
            If you pass None, no activation is applied
            (ie. "linear" activation: `a(x) = x`).
        recurrent_activation: Activation function applied to the forget
            and reset gates
            (see [activations](../activations.md)).
        use_bias: Boolean, whether the gates use a bias vector.
        project_input: Boolean. If True, the highway term always uses a
            learned projection of the input. A projection is used regardless
            of this flag when the input dimension differs from `units`.
        kernel_initializer: Initializer for the `kernel` weights matrix,
            used for the linear transformations of the inputs.
            (see [initializers](../initializers.md)).
        recurrent_initializer: Accepted for API compatibility with other
            Keras recurrent layers and kept in the config; no weight of this
            layer uses it.
        bias_initializer: Initializer for the bias vector
            (see [initializers](../initializers.md)).
        unit_forget_bias: Boolean.
            If True, the forget-gate half of the bias is initialized to ones
            and the reset-gate half with `bias_initializer`.
            This is recommended in [Jozefowicz et al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
        kernel_regularizer: Regularizer function applied to
            the `kernel` weights matrix
            (see [regularizer](../regularizers.md)).
        recurrent_regularizer: Accepted for API compatibility and kept in
            the config; no weight of this layer uses it.
        bias_regularizer: Regularizer function applied to the bias vector
            (see [regularizer](../regularizers.md)).
        activity_regularizer: Regularizer function applied to
            the output of the layer (its "activation").
            (see [regularizer](../regularizers.md)).
        kernel_constraint: Constraint function applied to
            the `kernel` weights matrix
            (see [constraints](../constraints.md)).
        recurrent_constraint: Accepted for API compatibility and kept in
            the config; no weight of this layer uses it.
        bias_constraint: Constraint function applied to the bias vector
            (see [constraints](../constraints.md)).
        dropout: Float between 0 and 1 (clipped into that range).
            Fraction of the input units to drop before the input
            transformations; the same mask is reused at every timestep.
        recurrent_dropout: Float between 0 and 1 (clipped into that range).
            Fraction of units to drop from the transformed input: from
            `x . W` in implementations 0 and 1, and from the whole fused
            product `x . kernel` in implementation 2.
        implementation: one of {0, 1, 2}.
            0: the input transformations for the whole sequence are
               computed up front, before the recurrent loop; each
               transformation samples its own input dropout mask.
            1: each step computes one matrix product per kernel slice.
            2: each step computes a single matrix product with the full
               kernel and slices the result; all slices share one input
               dropout mask.

    # References
        - [Training RNNs as Fast as CNNs](https://arxiv.org/abs/1709.02755)
        - [A Theoretically Grounded Application of Dropout in Recurrent Neural Networks](http://arxiv.org/abs/1512.05287)
    """

    @interfaces.legacy_recurrent_support
    def __init__(self, units,
                 activation='tanh',
                 recurrent_activation='sigmoid',
                 use_bias=True,
                 project_input=False,
                 kernel_initializer='glorot_uniform',
                 recurrent_initializer='orthogonal',
                 bias_initializer='zeros',
                 unit_forget_bias=True,
                 kernel_regularizer=None,
                 recurrent_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 recurrent_constraint=None,
                 bias_constraint=None,
                 dropout=0.,
                 recurrent_dropout=0.,
                 implementation=2,
                 **kwargs):
        super(SRU, self).__init__(**kwargs)
        self.units = units
        self.activation = activations.get(activation)
        self.recurrent_activation = activations.get(recurrent_activation)
        self.use_bias = use_bias
        self.project_input = project_input

        self.kernel_initializer = initializers.get(kernel_initializer)
        self.recurrent_initializer = initializers.get(recurrent_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.unit_forget_bias = unit_forget_bias

        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)

        self.kernel_constraint = constraints.get(kernel_constraint)
        self.recurrent_constraint = constraints.get(recurrent_constraint)
        self.bias_constraint = constraints.get(bias_constraint)

        # Clip the dropout rates into [0, 1].
        self.dropout = min(1., max(0., dropout))
        self.recurrent_dropout = min(1., max(0., recurrent_dropout))

        # States are [h, c], each (batch_size, units).
        self.state_spec = [InputSpec(shape=(None, self.units)),
                           InputSpec(shape=(None, self.units))]
        self.implementation = implementation

    def build(self, input_shape):
        """Create the kernel (and bias) and the per-gate views into them."""
        if isinstance(input_shape, list):
            input_shape = input_shape[0]

        batch_size = input_shape[0] if self.stateful else None
        self.input_dim = input_shape[2]
        # Inputs are (batch_size, timesteps, input_dim).
        self.input_spec[0] = InputSpec(shape=(batch_size, None, self.input_dim))

        self.states = [None, None]
        if self.stateful:
            self.reset_states()

        # Kernel layout is [W | W_f | W_r], plus a highway projection W_p
        # when requested or when x cannot be added to h directly
        # (input_dim != units).
        if self.project_input or self.input_dim != self.units:
            self.kernel_dim = 4
        else:
            self.kernel_dim = 3

        self.kernel = self.add_weight(shape=(self.input_dim, self.units * self.kernel_dim),
                                      name='kernel',
                                      initializer=self.kernel_initializer,
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)

        if self.use_bias:
            if self.unit_forget_bias:
                def bias_initializer(shape, *args, **kwargs):
                    # Bias layout is [b_f | b_r] (see bias_f/bias_r below),
                    # so the ones must come first to land on the forget gate.
                    return K.concatenate([
                        initializers.Ones()((self.units,), *args, **kwargs),
                        self.bias_initializer((self.units,), *args, **kwargs),
                    ])
            else:
                bias_initializer = self.bias_initializer
            self.bias = self.add_weight(shape=(self.units * 2,),
                                        name='bias',
                                        initializer=bias_initializer,
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None

        self.kernel_w = self.kernel[:, :self.units]
        self.kernel_f = self.kernel[:, self.units: self.units * 2]
        self.kernel_r = self.kernel[:, self.units * 2: self.units * 3]
        if self.kernel_dim == 4:
            self.kernel_p = self.kernel[:, self.units * 3: self.units * 4]
        else:
            self.kernel_p = None

        if self.use_bias:
            self.bias_f = self.bias[:self.units]
            self.bias_r = self.bias[self.units: self.units * 2]
        else:
            self.bias_f = None
            self.bias_r = None

        self.built = True

    def preprocess_input(self, inputs, training=None):
        """For implementation 0, precompute every input transformation.

        Returns a tensor whose last axis packs `[x_w | f | r | x_hw]`
        (4 * units wide), with the gates already activated. `x_hw` is the
        projected input when a projection kernel exists and the raw input
        otherwise, so `step` always finds the highway term in the 4th slot.
        Other implementations return `inputs` unchanged.
        """
        if self.implementation != 0:
            return inputs

        input_shape = K.int_shape(inputs)
        input_dim = input_shape[2]
        timesteps = input_shape[1]

        def project(kernel, bias):
            # Shared-call wrapper around the time-distributed affine map.
            return _time_distributed_dense(inputs, kernel, bias,
                                           self.dropout, input_dim, self.units,
                                           timesteps, training=training)

        x_w = project(self.kernel_w, None)
        x_f = self.recurrent_activation(project(self.kernel_f, self.bias_f))
        x_r = self.recurrent_activation(project(self.kernel_r, self.bias_r))
        if self.kernel_dim == 4:
            x_hw = project(self.kernel_p, None)
        else:
            # input_dim == units here; the raw input is the highway term and
            # would otherwise be unavailable to step().
            x_hw = inputs
        return K.concatenate([x_w, x_f, x_r, x_hw], axis=2)

    def get_constants(self, inputs, training=None):
        """Build `[dp_mask, rec_dp_mask]`, the dropout masks used by `step`.

        `dp_mask` holds 3 input-dropout masks of width input_dim (or scalar
        ones when unused). `rec_dp_mask` holds `kernel_dim` masks whose width
        matches what `step` multiplies them with: `units * kernel_dim` for
        implementation 2 (the fused product) and `units` otherwise (x_w).
        """
        constants = []

        # Implementation 0 applies input dropout in preprocess_input instead.
        if self.implementation != 0 and 0 < self.dropout < 1:
            input_dim = K.int_shape(inputs)[-1]
            ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
            ones = K.tile(ones, (1, int(input_dim)))

            def dropped_inputs():
                return K.dropout(ones, self.dropout)
            dp_mask = [K.in_train_phase(dropped_inputs, ones, training=training)
                       for _ in range(3)]
            constants.append(dp_mask)
        else:
            constants.append([K.cast_to_floatx(1.) for _ in range(3)])

        if 0 < self.recurrent_dropout < 1:
            if self.implementation == 2:
                mask_width = self.units * self.kernel_dim
            else:
                mask_width = self.units
            rec_ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
            rec_ones = K.tile(rec_ones, (1, mask_width))

            def dropped_recurrent():
                return K.dropout(rec_ones, self.recurrent_dropout)
            rec_dp_mask = [K.in_train_phase(dropped_recurrent, rec_ones, training=training)
                           for _ in range(self.kernel_dim)]
            constants.append(rec_dp_mask)
        else:
            constants.append([K.cast_to_floatx(1.) for _ in range(self.kernel_dim)])
        return constants

    def step(self, inputs, states):
        """Advance one timestep; returns `(h, [h, c])`.

        `states` is `[h_tm1, c_tm1, dp_mask, rec_dp_mask]`; `h_tm1` is not
        read because the SRU has no hidden-to-hidden connection.
        """
        c_tm1 = states[1]
        dp_mask = states[2]
        rec_dp_mask = states[3]

        if self.implementation == 2:
            # One fused product with the whole kernel.
            z = K.dot(inputs * dp_mask[0], self.kernel)
            z = z * rec_dp_mask[0]
            x_w = z[:, :self.units]
            gates = z[:, self.units: 3 * self.units]
            if self.use_bias:
                gates = K.bias_add(gates, self.bias)
            # Gates are squashed whether or not a bias is used.
            gates = self.recurrent_activation(gates)
            f = gates[:, :self.units]
            r = gates[:, self.units: 2 * self.units]
            if self.kernel_dim == 4:
                x_hw = z[:, 3 * self.units: 4 * self.units]
            else:
                x_hw = inputs
        else:
            if self.implementation == 0:
                # preprocess_input packed [x_w | f | r | x_hw], gates activated.
                x_w = inputs[:, :self.units]
                f = inputs[:, self.units: 2 * self.units]
                r = inputs[:, 2 * self.units: 3 * self.units]
                x_hw = inputs[:, 3 * self.units: 4 * self.units]
            elif self.implementation == 1:
                x_w = K.dot(inputs * dp_mask[0], self.kernel_w)
                f = K.dot(inputs * dp_mask[1], self.kernel_f)
                r = K.dot(inputs * dp_mask[2], self.kernel_r)
                if self.use_bias:
                    f = K.bias_add(f, self.bias_f)
                    r = K.bias_add(r, self.bias_r)
                f = self.recurrent_activation(f)
                r = self.recurrent_activation(r)
                if self.kernel_dim == 4:
                    x_hw = K.dot(inputs * dp_mask[0], self.kernel_p)
                else:
                    x_hw = inputs
            else:
                raise ValueError('Unknown `implementation` mode.')
            x_w = x_w * rec_dp_mask[0]

        c = f * c_tm1 + (1 - f) * x_w
        h = r * self.activation(c) + (1 - r) * x_hw
        if 0 < self.dropout + self.recurrent_dropout:
            h._uses_learning_phase = True
        return h, [h, c]

    def get_config(self):
        """Serialize constructor arguments, merged over the base config."""
        config = {'units': self.units,
                  'activation': activations.serialize(self.activation),
                  'recurrent_activation': activations.serialize(self.recurrent_activation),
                  'use_bias': self.use_bias,
                  # Needed on reload: it changes the kernel shape.
                  'project_input': self.project_input,
                  'kernel_initializer': initializers.serialize(self.kernel_initializer),
                  'recurrent_initializer': initializers.serialize(self.recurrent_initializer),
                  'bias_initializer': initializers.serialize(self.bias_initializer),
                  'unit_forget_bias': self.unit_forget_bias,
                  'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
                  'recurrent_regularizer': regularizers.serialize(self.recurrent_regularizer),
                  'bias_regularizer': regularizers.serialize(self.bias_regularizer),
                  'activity_regularizer': regularizers.serialize(self.activity_regularizer),
                  'kernel_constraint': constraints.serialize(self.kernel_constraint),
                  'recurrent_constraint': constraints.serialize(self.recurrent_constraint),
                  'bias_constraint': constraints.serialize(self.bias_constraint),
                  'dropout': self.dropout,
                  'recurrent_dropout': self.recurrent_dropout,
                  'implementation': self.implementation}
        base_config = super(SRU, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))