
Outputs are not correctly checked - [9] Causal Self Attention #17

@arnavm2k3

Description

Both of the implementations below pass all the checks for problem 9, Causal Self Attention:

```python
import math
import torch

def causal_attention(Q, K, V):
    B, seq, d_k = Q.size()
    # Lower-triangular mask: 1 at allowed positions, 0 at future positions
    mask = 1 - torch.triu(torch.ones(seq, seq), diagonal=1)
    scores = (torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(d_k)) * mask
    scores[scores == 0] = float('-inf')

    attention = torch.bmm(torch.softmax(scores, dim=-1), V)
    return attention
```
```python
def causal_attention(Q, K, V):
    B, seq, d_k = Q.size()
    mask = 1 - torch.triu(torch.ones(seq, seq), diagonal=1)
    # Divides by d_k rather than math.sqrt(d_k)
    scores = (torch.bmm(Q, K.transpose(1, 2)) / d_k) * mask
    scores[scores == 0] = float('-inf')

    attention = torch.bmm(torch.softmax(scores, dim=-1), V)
    return attention
```

The only difference between the two is the scaling factor: `math.sqrt(d_k)` in the first versus `d_k` in the second. Only the `math.sqrt(d_k)` scaling is correct, so the checker should reject the second version.
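For what it's worth, the two scalings do produce different outputs on random inputs, so a checker that compares against a reference output should be able to tell them apart. A minimal sketch below (the `scale` parameter and the `masked_fill`-based masking are just for illustration, not the problem's required API; `masked_fill` also sidesteps the `scores == 0` trick, which would incorrectly mask a valid position whose score happens to be exactly zero):

```python
import math
import torch

def causal_attention(Q, K, V, scale):
    B, seq, d_k = Q.size()
    # True above the diagonal = future positions to be masked out
    future = torch.triu(torch.ones(seq, seq, dtype=torch.bool), diagonal=1)
    scores = torch.bmm(Q, K.transpose(1, 2)) / scale
    scores = scores.masked_fill(future, float('-inf'))
    return torch.bmm(torch.softmax(scores, dim=-1), V)

torch.manual_seed(0)
Q, K, V = (torch.randn(2, 4, 8) for _ in range(3))
out_sqrt = causal_attention(Q, K, V, math.sqrt(8))  # correct: divide by sqrt(d_k)
out_dk = causal_attention(Q, K, V, 8.0)             # incorrect: divide by d_k
print(torch.allclose(out_sqrt, out_dk))  # prints False
```

Note that the first position's output is identical under both scalings (it attends only to itself, so the softmax is 1 regardless of scale); every later position differs, which is what the check needs to catch.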
