import torch
import torch.nn as nn

class ScaledDotAttention(nn.Module):
    def __init__(self, hidden_size):
        super(ScaledDotAttention, self).__init__()

        self.hidden_size = hidden_size

        self.Q = nn.Linear(hidden_size, hidden_size)
        self.K = nn.Linear(hidden_size, hidden_size)
        self.V = nn.Linear(hidden_size, hidden_size)
        # Normalize over the annotation (seq_len) dimension, which is the last
        # dimension of the q @ k^T score matrix computed in forward.
        self.softmax = nn.Softmax(dim=-1)
        self.scaling_factor = torch.rsqrt(torch.tensor(self.hidden_size, dtype=torch.float))

    def forward(self, queries, keys, values):
        """The forward pass of the scaled dot attention mechanism.

        Arguments:
            queries: The current decoder hidden state, 2D or 3D tensor. (batch_size x (k) x hidden_size)
            keys: The encoder hidden states for each step of the input sequence. (batch_size x seq_len x hidden_size)
            values: The encoder hidden states for each step of the input sequence. (batch_size x seq_len x hidden_size)

        Returns:
            context: weighted average of the values (batch_size x k x hidden_size)
            attention_weights: Normalized attention weights for each encoder hidden state. (batch_size x k x seq_len)

            The output must be a softmax weighting over the seq_len annotations.
        """
        # Allow a single 2D query (batch_size x hidden_size) by treating it as k=1.
        if queries.dim() == 2:
            queries = queries.unsqueeze(1)

        q = self.Q(queries)  # batch_size x k x hidden_size
        k = self.K(keys)     # batch_size x seq_len x hidden_size
        v = self.V(values)   # batch_size x seq_len x hidden_size

        # Scaled dot-product scores: batch_size x k x seq_len.
        unnormalized_attention = torch.bmm(q, k.transpose(1, 2)) * self.scaling_factor
        attention_weights = self.softmax(unnormalized_attention)
        # Weighted average of the values: batch_size x k x hidden_size.
        context = torch.bmm(attention_weights, v)
        return context, attention_weights
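
# A minimal usage sketch (an illustration, not part of the original assignment
# file): it checks the expected shapes and that each query's weights form a
# distribution over the annotations. The sizes (batch of 4, k=5 queries,
# seq_len=7, hidden_size=32) are arbitrary.
def _scaled_dot_attention_example():
    attn = ScaledDotAttention(hidden_size=32)
    queries = torch.randn(4, 5, 32)        # batch_size x k x hidden_size
    keys = values = torch.randn(4, 7, 32)  # batch_size x seq_len x hidden_size
    context, weights = attn(queries, keys, values)
    assert context.shape == (4, 5, 32)
    assert weights.shape == (4, 5, 7)
    # Each query's weights sum to 1 over the 7 annotations.
    assert torch.allclose(weights.sum(dim=-1), torch.ones(4, 5))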

class CausalScaledDotAttention(nn.Module):
    def __init__(self, hidden_size):
        super(CausalScaledDotAttention, self).__init__()

        self.hidden_size = hidden_size
        self.neg_inf = torch.tensor(-1e7)

        self.Q = nn.Linear(hidden_size, hidden_size)
        self.K = nn.Linear(hidden_size, hidden_size)
        self.V = nn.Linear(hidden_size, hidden_size)
        # Normalize over the key (seq_len) dimension, the last dimension of the score matrix.
        self.softmax = nn.Softmax(dim=-1)
        self.scaling_factor = torch.rsqrt(torch.tensor(self.hidden_size, dtype=torch.float))

    def forward(self, queries, keys, values):
        """The forward pass of the causal scaled dot attention mechanism.

        Arguments:
            queries: The current decoder hidden state, 2D or 3D tensor. (batch_size x (k) x hidden_size)
            keys: The encoder hidden states for each step of the input sequence. (batch_size x seq_len x hidden_size)
            values: The encoder hidden states for each step of the input sequence. (batch_size x seq_len x hidden_size)

        Returns:
            context: weighted average of the values (batch_size x k x hidden_size)
            attention_weights: Normalized attention weights for each encoder hidden state. (batch_size x k x seq_len)

            The output must be a softmax weighting over the seq_len annotations,
            with positions after the current one masked out.
        """
        # Allow a single 2D query (batch_size x hidden_size) by treating it as k=1.
        if queries.dim() == 2:
            queries = queries.unsqueeze(1)

        q = self.Q(queries)
        k = self.K(keys)
        v = self.V(values)

        # Scaled dot-product scores: batch_size x k x seq_len.
        unnormalized_attention = torch.bmm(q, k.transpose(1, 2)) * self.scaling_factor
        # Causal mask: add a large negative value to the scores strictly above
        # the diagonal so position t cannot attend to positions > t.
        # (diagonal=1 leaves the diagonal unmasked, so each position can still
        # attend to itself.)
        mask = self.neg_inf * torch.triu(torch.ones_like(unnormalized_attention), diagonal=1)
        attention_weights = self.softmax(unnormalized_attention + mask)
        context = torch.bmm(attention_weights, v)
        return context, attention_weights
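
# A minimal usage sketch (an illustration, not part of the original assignment
# file): with the causal mask, the weight a position assigns to any later
# position is numerically zero, so each weight matrix is lower triangular.
def _causal_scaled_dot_attention_example():
    attn = CausalScaledDotAttention(hidden_size=32)
    x = torch.randn(4, 6, 32)         # batch_size x seq_len x hidden_size
    context, weights = attn(x, x, x)  # self-attention: queries = keys = values
    assert context.shape == (4, 6, 32)
    assert weights.shape == (4, 6, 6)
    # Entries strictly above the diagonal correspond to attending to the future.
    future = torch.triu(torch.ones(6, 6), diagonal=1).bool()
    assert weights[:, future].abs().max() < 1e-4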

class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size, hidden_size, num_layers):
        super(TransformerDecoder, self).__init__()

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.embedding = nn.Embedding(vocab_size, hidden_size)

        # One masked (causal) self-attention block per layer.
        self.self_attentions = nn.ModuleList([CausalScaledDotAttention(
            hidden_size=hidden_size,
        ) for i in range(self.num_layers)])
        # One encoder-decoder attention block per layer.
        self.encoder_attentions = nn.ModuleList([ScaledDotAttention(
            hidden_size=hidden_size,
        ) for i in range(self.num_layers)])
        # One position-wise feed-forward block per layer.
        self.attention_mlps = nn.ModuleList([nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
        ) for i in range(self.num_layers)])
        self.out = nn.Linear(hidden_size, vocab_size)

        self.positional_encodings = self.create_positional_encodings()
    def forward(self, inputs, annotations, hidden_init, cell_init):
        """Forward pass of the transformer decoder.

        Arguments:
            inputs: Input token indexes across a batch for all the time steps. (batch_size x decoder_seq_len)
            annotations: The encoder hidden states for each step of the input sequence. (batch_size x seq_len x hidden_size)
            hidden_init: Not used by the transformer decoder (kept for interface compatibility).
            cell_init: Not used by the transformer decoder (kept for interface compatibility).

        Returns:
            output: Un-normalized scores for each token in the vocabulary, across a batch for all the decoding time steps. (batch_size x decoder_seq_len x vocab_size)
            attentions: A tuple (encoder_attention_weights, self_attention_weights), each stacked across layers:
                encoder_attention_weights: (num_layers x batch_size x decoder_seq_len x encoder_seq_len)
                self_attention_weights: (num_layers x batch_size x decoder_seq_len x decoder_seq_len)
        """
        batch_size, seq_len = inputs.size()

        embed = self.embedding(inputs)  # batch_size x seq_len x hidden_size
        embed = embed + self.positional_encodings[:seq_len]

        encoder_attention_weights_list = []
        self_attention_weights_list = []
        contexts = embed
        for i in range(self.num_layers):
            # Masked self-attention over the decoder inputs, with a residual connection.
            new_contexts, self_attention_weights = self.self_attentions[i](
                contexts, contexts, contexts)  # batch_size x seq_len x hidden_size
            residual_contexts = contexts + new_contexts
            # Attention over the encoder annotations, with a residual connection.
            new_contexts, encoder_attention_weights = self.encoder_attentions[i](
                residual_contexts, annotations, annotations)  # batch_size x seq_len x hidden_size
            residual_contexts = residual_contexts + new_contexts
            # Position-wise feed-forward block, with a residual connection.
            new_contexts = self.attention_mlps[i](residual_contexts)
            contexts = residual_contexts + new_contexts

            encoder_attention_weights_list.append(encoder_attention_weights)
            self_attention_weights_list.append(self_attention_weights)

        output = self.out(contexts)  # batch_size x seq_len x vocab_size
        encoder_attention_weights = torch.stack(encoder_attention_weights_list)
        self_attention_weights = torch.stack(self_attention_weights_list)
        return output, (encoder_attention_weights, self_attention_weights)
    def create_positional_encodings(self, max_seq_len=1000):
        """Creates positional encodings for the inputs.

        Arguments:
            max_seq_len: a number larger than the maximum string length we expect to encounter during training

        Returns:
            pos_encodings: (max_seq_len, hidden_dim) Positional encodings for a sequence with length max_seq_len.
        """
        pos_indices = torch.arange(max_seq_len)[..., None]
        dim_indices = torch.arange(self.hidden_size // 2)[None, ...]
        exponents = (2 * dim_indices).float() / (self.hidden_size)
        trig_args = pos_indices / (10000 ** exponents)
        sin_terms = torch.sin(trig_args)  # (max_seq_len, hidden_size // 2)
        cos_terms = torch.cos(trig_args)  # (max_seq_len, hidden_size // 2)

        # Interleave: sine terms on the even dimensions, cosine terms on the odd ones.
        pos_encodings = torch.zeros((max_seq_len, self.hidden_size))
        pos_encodings[:, 0::2] = sin_terms
        pos_encodings[:, 1::2] = cos_terms

        # Move to the GPU only when one is available, so the model also runs on CPU.
        if torch.cuda.is_available():
            pos_encodings = pos_encodings.cuda()

        return pos_encodings
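
# A small smoke test (an illustration, not part of the original assignment
# file): it builds a toy decoder and checks the output and stacked attention
# shapes. The sizes (vocab of 20, hidden size 32, 2 layers, batch of 4,
# decoder_seq_len=6, encoder_seq_len=10) are arbitrary.
if __name__ == "__main__":
    torch.manual_seed(0)

    decoder = TransformerDecoder(vocab_size=20, hidden_size=32, num_layers=2)
    inputs = torch.randint(0, 20, (4, 6))  # batch_size x decoder_seq_len
    annotations = torch.randn(4, 10, 32)   # batch_size x encoder_seq_len x hidden_size
    if torch.cuda.is_available():
        decoder = decoder.cuda()
        inputs, annotations = inputs.cuda(), annotations.cuda()

    output, (encoder_attn, self_attn) = decoder(inputs, annotations, None, None)
    assert output.shape == (4, 6, 20)           # scores over the vocabulary
    assert encoder_attn.shape == (2, 4, 6, 10)  # num_layers x batch x dec_len x enc_len
    assert self_attn.shape == (2, 4, 6, 6)      # num_layers x batch x dec_len x dec_len
    print("transformer.py smoke test passed.")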