Both of the implementations below pass all the checks in the 9th problem, Causal Self-Attention:
import math
import torch

def causal_attention(Q, K, V):
    B, seq, d_k = Q.size()
    # Lower-triangular mask: 1 on/below the diagonal, 0 above it
    mask = 1 - torch.triu(torch.ones(seq, seq), diagonal=1)
    # Scale by sqrt(d_k), then zero out future positions
    scores = (torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(d_k)) * mask
    scores[scores == 0] = float('-inf')
    attention = torch.bmm(torch.softmax(scores, dim=-1), V)
    return attention
def causal_attention(Q, K, V):
    B, seq, d_k = Q.size()
    mask = 1 - torch.triu(torch.ones(seq, seq), diagonal=1)
    # Variant: scale by d_k instead of math.sqrt(d_k)
    scores = (torch.bmm(Q, K.transpose(1, 2)) / d_k) * mask
    scores[scores == 0] = float('-inf')
    attention = torch.bmm(torch.softmax(scores, dim=-1), V)
    return attention
The only difference between the two is the scaling factor: math.sqrt(d_k) in the first versus d_k in the second.
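As a side note, the scores == 0 trick above has a subtle edge case: a legitimate (non-masked) dot product that happens to be exactly zero would also be set to -inf. A minimal sketch of an alternative that avoids this, using a boolean mask with masked_fill and the standard sqrt(d_k) scaling (this is my own variant, not the problem's reference solution):

```python
import math
import torch

def causal_attention_masked(Q, K, V):
    # Q, K, V: (B, seq, d_k)
    B, seq, d_k = Q.size()
    # Scaled dot-product scores: (B, seq, seq)
    scores = torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(d_k)
    # Boolean mask: True strictly above the diagonal (future positions)
    mask = torch.triu(torch.ones(seq, seq, dtype=torch.bool), diagonal=1)
    # Set future positions to -inf so softmax assigns them zero weight
    scores = scores.masked_fill(mask, float('-inf'))
    return torch.bmm(torch.softmax(scores, dim=-1), V)
```

A quick sanity check: the first query position can only attend to itself, so its output should equal the first row of V.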