Fix the transformer (use final encoder outputs)
ageron committed May 10, 2019
1 parent 94dcc52 commit 152fe5c
Showing 1 changed file with 10 additions and 9 deletions.
16_nlp_with_rnns_and_attention.ipynb: 10 additions & 9 deletions
@@ -1639,17 +1639,18 @@
"metadata": {},
"outputs": [],
"source": [
"Z = encoder_in\n",
"for N in range(6):\n",
" encoder_attn = keras.layers.Attention(use_scale=True)\n",
" encoder_in = encoder_attn([encoder_in, encoder_in])\n",
" masked_decoder_attn = keras.layers.Attention(use_scale=True, causal=True)\n",
" decoder_in = masked_decoder_attn([decoder_in, decoder_in])\n",
" decoder_attn = keras.layers.Attention(use_scale=True)\n",
" final_enc = decoder_attn([decoder_in, encoder_in])\n",
" Z = keras.layers.Attention(use_scale=True)([Z, Z])\n",
"\n",
"output_layer = keras.layers.TimeDistributed(\n",
" keras.layers.Dense(vocab_size, activation=\"softmax\"))\n",
"outputs = output_layer(final_enc)"
"encoder_outputs = Z\n",
"Z = decoder_in\n",
"for N in range(6):\n",
" Z = keras.layers.Attention(use_scale=True, causal=True)([Z, Z])\n",
" Z = keras.layers.Attention(use_scale=True)([Z, encoder_outputs])\n",
"\n",
"outputs = keras.layers.TimeDistributed(\n",
" keras.layers.Dense(vocab_size, activation=\"softmax\"))(Z)"
]
},
{
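For readers who want to run the fixed cell outside the notebook, here is a minimal sketch of the new code in context. The imports, vocab_size, embed_size, and the Input/Embedding scaffolding that produces encoder_in and decoder_in are illustrative assumptions (the notebook builds its own inputs, with positional encodings, in earlier cells); only the two attention loops and the output layer come from this commit. The bug being fixed: the old loop overwrote encoder_in on every iteration and let each decoder layer attend to those intermediate encoder states, whereas the decoder layers of a Transformer should all attend to the final outputs of the full encoder stack.

# A minimal sketch of the corrected cell in context, for TF 2.x.
# vocab_size, embed_size, and the Input/Embedding scaffolding below are
# illustrative assumptions, not part of the commit; positional encodings
# from the notebook's earlier cells are omitted for brevity.
import tensorflow as tf
from tensorflow import keras

vocab_size = 10000  # assumed
embed_size = 512    # assumed

encoder_inputs = keras.layers.Input(shape=[None], dtype=tf.int32)
decoder_inputs = keras.layers.Input(shape=[None], dtype=tf.int32)
embeddings = keras.layers.Embedding(vocab_size, embed_size)
encoder_in = embeddings(encoder_inputs)
decoder_in = embeddings(decoder_inputs)

# Run the full 6-layer encoder stack first, so the decoder sees the
# *final* encoder outputs rather than each layer's intermediate state.
Z = encoder_in
for N in range(6):
    Z = keras.layers.Attention(use_scale=True)([Z, Z])
encoder_outputs = Z

# Decoder stack: masked self-attention, then cross-attention over the
# final encoder outputs. Note: `causal=True` is the TF 2.x constructor
# argument used in the notebook; recent Keras releases instead expect
# `use_causal_mask=True` passed at call time.
Z = decoder_in
for N in range(6):
    Z = keras.layers.Attention(use_scale=True, causal=True)([Z, Z])
    Z = keras.layers.Attention(use_scale=True)([Z, encoder_outputs])

outputs = keras.layers.TimeDistributed(
    keras.layers.Dense(vocab_size, activation="softmax"))(Z)

model = keras.Model(inputs=[encoder_inputs, decoder_inputs],
                    outputs=[outputs])

This sketch uses single-head scaled dot-product attention (keras.layers.Attention) as a stand-in for multi-head attention, matching the simplified architecture the notebook builds at this point; it is not a full Transformer implementation.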
