Commit 046fe45

fix recurrent docstrings (#2597)
1 parent 25d7f69 commit 046fe45

File tree

1 file changed: src/layers/recurrent.jl (+47 −35 lines)
@@ -26,7 +26,7 @@ the initial hidden state. The output of the `cell` is considered to be:
 
 The input `x` should be an array of size `in x len` or `in x len x batch_size`,
 where `in` is the input dimension of the cell, `len` is the sequence length, and `batch_size` is the batch size.
-The `state` should be a valid state for the recurrent cell. If not provided, it obtained by calling
+The `state` should be a valid state for the recurrent cell. If not provided, it is obtained by calling
 `Flux.initialstates(cell)`.
 
 The output is an array of size `out x len x batch_size`, where `out` is the output dimension of the cell.
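To make the documented shape contract concrete, here is a minimal sketch using `RNN` as one layer that follows it; the sizes and variable names are illustrative assumptions, not part of the diff:

```julia
using Flux

layer = RNN(2 => 3)                      # in = 2, out = 3
x = rand(Float32, 2, 5, 7)               # in x len x batch_size
state = Flux.initialstates(layer.cell)   # the default state, per the docstring
y = layer(x, state)
size(y)                                  # (3, 5, 7): out x len x batch_size
```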
@@ -107,7 +107,7 @@ See [`RNN`](@ref) for a layer that processes entire sequences.
 
     rnncell(x, [h])
 
-The arguments of the forward pass are:
+The arguments for the forward pass are:
 
 - `x`: The input to the RNN. It should be a vector of size `in` or a matrix of size `in x batch_size`.
 - `h`: The hidden state of the RNN. It should be a vector of size `out` or a matrix of size `out x batch_size`.
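The argument shapes this hunk documents, as a quick sketch (the variable names and sizes are ours):

```julia
using Flux

cell = RNNCell(3 => 5)
x = rand(Float32, 3)         # vector of size `in`
h = zeros(Float32, 5)        # hidden state of size `out`
h′ = cell(x, h)              # new hidden state, also of size `out`

xb = rand(Float32, 3, 16)    # batched call: in x batch_size
hb = zeros(Float32, 5, 16)   # out x batch_size
hb′ = cell(xb, hb)
```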
@@ -210,12 +210,12 @@ end
 The most basic recurrent layer. Essentially acts as a `Dense` layer, but with the
 output fed back into the input each time step.
 
-In the forward pass computes
+The forward pass computes
 
 ```math
 h_t = \sigma(W_i x_t + W_h h_{t-1} + b)
 ```
-for all `len` steps `t` in the in input sequence.
+for all `len` steps `t` in the input sequence.
 
 See [`RNNCell`](@ref) for a layer that processes a single time step.
 
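Since the corrected sentence says the update runs for all `len` steps, a pure-Julia sketch of that loop may help; this is an illustrative re-implementation for intuition, not Flux's code:

```julia
# Hand-rolled version of the documented recurrence.
function naive_rnn(Wi, Wh, b, x; σ = tanh)
    out, len = size(Wi, 1), size(x, 2)
    h = zeros(eltype(x), out)     # h_0 = 0, matching the documented default
    hs = similar(x, out, len)
    for t in 1:len                # h_t = σ(W_i x_t + W_h h_{t-1} + b)
        h = σ.(Wi * x[:, t] .+ Wh * h .+ b)
        hs[:, t] = h
    end
    return hs                     # out x len: all hidden states
end

# Example: 3 inputs, 5 hidden units, sequence length 10.
Wi, Wh, b = randn(Float32, 5, 3), randn(Float32, 5, 5), zeros(Float32, 5)
hs = naive_rnn(Wi, Wh, b, randn(Float32, 3, 10))
```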
@@ -225,7 +225,7 @@ See [`RNNCell`](@ref) for a layer that processes a single time step.
 - `σ`: The non-linearity to apply to the output. Default is `tanh`.
 - `return_state`: Option to return the last state together with the output. Default is `false`.
 - `init_kernel`: The initialization function to use for the input to hidden connection weights. Default is `glorot_uniform`.
-- `init_recurrent_kernel`: The initialization function to use for the hidden to hidden connection weights. Default is `glorot_uniform`.
+- `init_recurrent_kernel`: The initialization function to use for the hidden-to-hidden connection weights. Default is `glorot_uniform`.
 - `bias`: Whether to include a bias term initialized to zero. Default is `true`.
 
 # Forward
@@ -239,7 +239,7 @@ The arguments of the forward pass are:
   If given, it is a vector of size `out` or a matrix of size `out x batch_size`.
   If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).
 
-Returns all new hidden states `h_t` as an array of size `out x len x batch_size`. When `return_state = true` it returns
+Returns all the new hidden states `h_t` as an array of size `out x len x batch_size`. When `return_state = true` it returns
 a tuple of the hidden states `h_t` and the last state of the iteration.
 
 # Examples
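The `return_state` behaviour described just above, sketched with assumed sizes and an unbatched input:

```julia
using Flux

rnn = RNN(3 => 5; return_state = true)
x = rand(Float32, 3, 10)    # in x len
hs, last_h = rnn(x)         # all hidden states (out x len), plus the final state
```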
@@ -330,11 +330,13 @@ Behaves like an RNN but generally exhibits a longer memory span over sequences.
 In the forward pass, computes
 
 ```math
-i_t = \sigma(W_{xi} x_t + W_{hi} h_{t-1} + b_i)
-f_t = \sigma(W_{xf} x_t + W_{hf} h_{t-1} + b_f)
-c_t = f_t \odot c_{t-1} + i_t \odot \tanh(W_{xc} x_t + W_{hc} h_{t-1} + b_c)
-o_t = \sigma(W_{xo} x_t + W_{ho} h_{t-1} + b_o)
-h_t = o_t \odot \tanh(c_t)
+\begin{aligned}
+i_t &= \sigma(W_{xi} x_t + W_{hi} h_{t-1} + b_i)\\
+f_t &= \sigma(W_{xf} x_t + W_{hf} h_{t-1} + b_f)\\
+c_t &= f_t \odot c_{t-1} + i_t \odot \tanh(W_{xc} x_t + W_{hc} h_{t-1} + b_c)\\
+o_t &= \sigma(W_{xo} x_t + W_{ho} h_{t-1} + b_o)\\
+h_t &= o_t \odot \tanh(c_t)
+\end{aligned}
 ```
 
 See also [`LSTM`](@ref) for a layer that processes entire sequences.
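A hedged usage sketch of the cell whose equations this hunk realigns; that the state is an `(h, c)` pair and that the call returns the updated state are assumptions read off the equations and the surrounding docstrings:

```julia
using Flux

cell = LSTMCell(3 => 5)
x = rand(Float32, 3)
state = Flux.initialstates(cell)   # assumed to be the (h, c) pair of zeros
new_state = cell(x, state)         # updated (h, c); the output h is its first element
```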
@@ -430,14 +432,16 @@ recurrent layer. Behaves like an RNN but generally exhibits a longer memory span
 See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
 for a good overview of the internals.
 
-In the forward pass, computes
+In the forward pass, it computes
 
 ```math
-i_t = \sigma(W_{xi} x_t + W_{hi} h_{t-1} + b_i)
-f_t = \sigma(W_{xf} x_t + W_{hf} h_{t-1} + b_f)
-c_t = f_t \odot c_{t-1} + i_t \odot \tanh(W_{xc} x_t + W_{hc} h_{t-1} + b_c)
-o_t = \sigma(W_{xo} x_t + W_{ho} h_{t-1} + b_o)
-h_t = o_t \odot \tanh(c_t)
+\begin{aligned}
+i_t &= \sigma(W_{xi} x_t + W_{hi} h_{t-1} + b_i)\\
+f_t &= \sigma(W_{xf} x_t + W_{hf} h_{t-1} + b_f)\\
+c_t &= f_t \odot c_{t-1} + i_t \odot \tanh(W_{xc} x_t + W_{hc} h_{t-1} + b_c)\\
+o_t &= \sigma(W_{xo} x_t + W_{ho} h_{t-1} + b_o)\\
+h_t &= o_t \odot \tanh(c_t)
+\end{aligned}
 ```
 for all `len` steps `t` in the input sequence.
 See [`LSTMCell`](@ref) for a layer that processes a single time step.
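The sequence-level counterpart, under the same shape assumptions as the earlier sketches:

```julia
using Flux

lstm = LSTM(3 => 5)
x = rand(Float32, 3, 10, 16)   # in x len x batch_size
h = lstm(x)                    # out x len x batch_size: h_t for every step t
```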
@@ -447,7 +451,7 @@ See [`LSTMCell`](@ref) for a layer that processes a single time step.
 - `in => out`: The input and output dimensions of the layer.
 - `return_state`: Option to return the last state together with the output. Default is `false`.
 - `init_kernel`: The initialization function to use for the input to hidden connection weights. Default is `glorot_uniform`.
-- `init_recurrent_kernel`: The initialization function to use for the hidden to hidden connection weights. Default is `glorot_uniform`.
+- `init_recurrent_kernel`: The initialization function to use for the hidden-to-hidden connection weights. Default is `glorot_uniform`.
 - `bias`: Whether to include a bias term initialized to zero. Default is `true`.
 
 # Forward
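Both initializer keywords from this list can be overridden at construction; a hypothetical configuration, where we assume `Flux.orthogonal` is available as an initializer:

```julia
using Flux

# Orthogonal recurrent weights are a common choice for training stability.
lstm = LSTM(4 => 8;
            return_state = false,
            init_kernel = Flux.glorot_uniform,
            init_recurrent_kernel = Flux.orthogonal,  # assumed export
            bias = true)
```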
@@ -536,10 +540,12 @@ This implements the variant proposed in v1 of the referenced paper.
 In the forward pass, computes
 
 ```math
-r = \sigma(W_{xi} x + W_{hi} h + b_i)
-z = \sigma(W_{xz} x + W_{hz} h + b_z)
-h̃ = \tanh(W_{xh} x + r \odot W_{hh} h + b_h)
-h' = (1 - z) \odot h̃ + z \odot h
+\begin{aligned}
+r &= \sigma(W_{xi} x + W_{hi} h + b_i)\\
+z &= \sigma(W_{xz} x + W_{hz} h + b_z)\\
+h̃ &= \tanh(W_{xh} x + r \odot W_{hh} h + b_h)\\
+h' &= (1 - z) \odot h̃ + z \odot h
+\end{aligned}
 ```
 
 See also [`GRU`](@ref) for a layer that processes entire sequences.
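A quick sketch of one step of this cell; the sizes are illustrative, and the returned `h'` follows the last equation above:

```julia
using Flux

cell = GRUCell(3 => 5)
x = rand(Float32, 3)
h = zeros(Float32, 5)   # or Flux.initialstates(cell)
h′ = cell(x, h)         # h' from the update equation, size `out`
```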
@@ -635,10 +641,12 @@ the variant proposed in v1 of the referenced paper.
 The forward pass computes
 
 ```math
-r_t = \sigma(W_{xi} x_t + W_{hi} h_{t-1} + b_i)
-z_t = \sigma(W_{xz} x_t + W_{hz} h_{t-1} + b_z)
-h̃_t = \tanh(W_{xh} x_t + r_t \odot W_{hh} h_{t-1} + b_h)
-h_t = (1 - z_t) \odot h̃_t + z_t \odot h_{t-1}
+\begin{aligned}
+r_t &= \sigma(W_{xi} x_t + W_{hi} h_{t-1} + b_i)\\
+z_t &= \sigma(W_{xz} x_t + W_{hz} h_{t-1} + b_z)\\
+h̃_t &= \tanh(W_{xh} x_t + r_t \odot W_{hh} h_{t-1} + b_h)\\
+h_t &= (1 - z_t) \odot h̃_t + z_t \odot h_{t-1}
+\end{aligned}
 ```
 for all `len` steps `t` in the input sequence.
 See [`GRUCell`](@ref) for a layer that processes a single time step.
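And the sequence form, again with assumed sizes:

```julia
using Flux

gru = GRU(3 => 5)
x = rand(Float32, 3, 10, 16)   # in x len x batch_size
h = gru(x)                     # out x len x batch_size
```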
@@ -724,10 +732,12 @@ This implements the variant proposed in v3 of the referenced paper.
 
 The forward pass computes
 ```math
-r = \sigma(W_{xi} x + W_{hi} h + b_i)
-z = \sigma(W_{xz} x + W_{hz} h + b_z)
-h̃ = \tanh(W_{xh} x + W_{hh̃} (r \odot W_{hh} h) + b_h)
-h' = (1 - z) \odot h̃ + z \odot h
+\begin{aligned}
+r &= \sigma(W_{xi} x + W_{hi} h + b_i)\\
+z &= \sigma(W_{xz} x + W_{hz} h + b_z)\\
+h̃ &= \tanh(W_{xh} x + W_{hh̃} (r \odot W_{hh} h) + b_h)\\
+h' &= (1 - z) \odot h̃ + z \odot h
+\end{aligned}
 ```
 and returns `h'`. This is a single time step of the GRU.
 
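One step of the v3 cell, sketched; note from the equations that the only difference from v1 is the extra `W_{hh̃}` projection around the reset-gated recurrent term:

```julia
using Flux

cell = GRUv3Cell(3 => 5)
x = rand(Float32, 3)
h = zeros(Float32, 5)
h′ = cell(x, h)   # one v3 step, returns h' as documented above
```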
@@ -813,10 +823,12 @@ the variant proposed in v3 of the referenced paper.
 The forward pass computes
 
 ```math
-r_t = \sigma(W_{xi} x_t + W_{hi} h_{t-1} + b_i)
-z_t = \sigma(W_{xz} x_t + W_{hz} h_{t-1} + b_z)
-h̃_t = \tanh(W_{xh} x_t + W_{hh̃} (r_t \odot W_{hh} h_{t-1}) + b_h)
-h_t = (1 - z_t) \odot h̃_t + z_t \odot h_{t-1}
+\begin{aligned}
+r_t &= \sigma(W_{xi} x_t + W_{hi} h_{t-1} + b_i)\\
+z_t &= \sigma(W_{xz} x_t + W_{hz} h_{t-1} + b_z)\\
+h̃_t &= \tanh(W_{xh} x_t + W_{hh̃} (r_t \odot W_{hh} h_{t-1}) + b_h)\\
+h_t &= (1 - z_t) \odot h̃_t + z_t \odot h_{t-1}
+\end{aligned}
 ```
 for all `len` steps `t` in the input sequence.
 See [`GRUv3Cell`](@ref) for a layer that processes a single time step.
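A sketch of the sequence-level v3 layer, with an unbatched input of assumed size:

```julia
using Flux

gruv3 = GRUv3(3 => 5)
x = rand(Float32, 3, 10)   # in x len
h = gruv3(x)               # out x len
```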
@@ -893,4 +905,4 @@ end
 
 function Base.show(io::IO, m::GRUv3)
     print(io, "GRUv3(", size(m.cell.Wi, 2), " => ", size(m.cell.Wi, 1) ÷ 3, ")")
-end
+end
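As a closing note on this hunk's `show` method: `size(m.cell.Wi, 1) ÷ 3` recovers `out` because the input weights of the three GRU gates are stacked in `Wi`. A sketch of the printed form, inferred from the `print` call rather than executed:

```julia
using Flux

m = GRUv3(2 => 4)
# size(m.cell.Wi, 2) gives the input size, size(m.cell.Wi, 1) ÷ 3 the output size
show(stdout, m)   # expected to print: GRUv3(2 => 4)
```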
