GQA Attention (#59)
RishikeshMagar authored Jan 11, 2024
1 parent 31dcd8f commit 0d4cba5
Showing 2 changed files with 185 additions and 4 deletions.
100 changes: 96 additions & 4 deletions protein_lm/modeling/models/apt/model_pytorch.py
@@ -8,6 +8,7 @@
from transformers.pytorch_utils import Conv1D
from transformers.activations import ACT2FN
from transformers.utils import logging

from protein_lm.modeling.utils.rotary_embedding import RotaryEmbedding
from protein_lm.modeling.utils.rerope_embedding import RectifiedRotaryEmbedding
from protein_lm.modeling.utils.alibi_embedding import create_alibi_tensor
@@ -34,6 +35,7 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
self.max_sequence_length = config.max_sequence_length
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.attn_type = config.attn_type
self.head_dim = self.embed_dim // self.num_heads
self.split_size = self.embed_dim
if self.head_dim * self.num_heads != self.embed_dim:
@@ -48,7 +50,15 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
# Layer-wise attention scaling, reordering, and upcasting
self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
self.layer_idx = layer_idx
self.reorder_and_upcast_attn = config.reorder_and_upcast_attn

if self.attn_type == "gqa":
self.gqa_attn = True
elif self.attn_type == "reorder_and_upcast_attn":
self.reorder_and_upcast_attn = True
elif self.attn_type == "standard":
self.standard_attn = True

#self.reorder_and_upcast_attn = config.reorder_and_upcast_attn #comment out because config now states attn type

if self.is_cross_attention:
self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
@@ -116,6 +126,87 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None,alibi_bia

return attn_output, attn_weights

    def _gqa_attn(self, query, key, value, attention_mask=None,
                  alibi_bias=None, dropout=0.0):
        """Grouped-query attention implementation.

        Expected shapes: (batch_size, num_heads, query_len, query_dim),
        matching _upcast_and_reordered_attn.
        """
        # Check for potential issues before moving on
        if not query.ndim == key.ndim == value.ndim == 4:
            raise ValueError(
                f"Expected query, key, and value to be 4-dimensional, but got shapes "
                f"{query.shape}, {key.shape}, and {value.shape}."
            )

        batch_size, num_heads, query_len, query_dim = query.shape

        # Scale the query before the matmul (equivalent to scaling the
        # attention logits by 1/sqrt(head_dim)).
        scale_factor = 1.0
        if self.scale_attn_weights:
            scale_factor /= float(value.size(-1)) ** 0.5
        query = query * scale_factor

        # Determine the number of groups.
        # For example, with 4 query heads and 2 key heads there are 2 groups of
        # 2 heads each, and the query tensor is reshaped to
        # (batch_size, 2, 2, query_len, query_dim), i.e.:
        #   grouped query:             (batch_size, num_groups, heads_per_group, query_len, query_dim)
        #   grouped attention weights: (batch_size, num_groups, heads_per_group, query_len, key_len)
        #   attention weights:         (batch_size, num_heads, query_len, key_len)

        n_groups = query.size(1) // key.size(1)

        if n_groups > 1:
            query_shape = query.shape
            grouped_shape = (query_shape[0], n_groups, query_shape[1] // n_groups, query_shape[2], query_shape[3])
            query_grouped = query.reshape(grouped_shape)
            # Unsqueeze the key over the group dimension so it broadcasts
            # against the grouped query, then collapse the groups.
            attn_weights_grouped = torch.matmul(query_grouped, key.unsqueeze(1).transpose(-2, -1))
            attn_weights = attn_weights_grouped.sum(dim=1)

        else:
            # With a single group, fall back to the normal attention product.
            attn_weights = torch.matmul(query, key.transpose(-2, -1))

if self.scale_attn_by_inverse_layer_idx:
attn_weights = attn_weights / float(self.layer_idx + 1)

        if attention_mask is not None:
            # Apply the attention mask; expected shape (batch_size, query_len, key_len).
            attn_weights += attention_mask.unsqueeze(1)  # unsqueeze adds the head dimension

# Causal masking ensures that the attention mechanism doesn't attend to "future" tokens in sequences.
## Adapted to work with groups and ensure similarity with vanilla attention
if not self.is_cross_attention:
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
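            # self.bias is the cached lower-triangular (causal) mask buffer; the
            # slice selects the (query_length, key_length) window for this pass.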
mask_value = torch.finfo(attn_weights.dtype).min
mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)

# print("attn_weights:", attn_weights)
# Softmax normalization to get the attention scores
attn_weights = nn.functional.softmax(attn_weights, dim=-1)

if alibi_bias is not None:
attn_weights = attn_weights + alibi_bias[:,:,:attn_weights.size(-1)]

# Apply dropout if specified
attn_weights = attn_weights.type(value.dtype)
attn_weights = self.attn_dropout(attn_weights)

# Compute the output by multiplying the attention scores with the value tensor.
attn_output = torch.matmul(attn_weights, value)

return attn_output, attn_weights

def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None,alibi_bias=None):
# Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
bsz, num_heads, q_seq_len, dk = query.size()
@@ -233,9 +324,10 @@ def forward(

if self.reorder_and_upcast_attn:
attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask,alibi_bias=alibi_bias)
        elif self.standard_attn:
            attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask, alibi_bias=alibi_bias)
        elif self.gqa_attn:
            attn_output, attn_weights = self._gqa_attn(query, key, value, attention_mask, alibi_bias=alibi_bias)
attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
attn_output = self.c_proj(attn_output)
attn_output = self.resid_dropout(attn_output)
@@ -244,7 +336,7 @@ def forward(
if output_attentions:
outputs += (attn_weights,)

        return outputs  # a, present, (attentions)


class APTMLP(nn.Module):
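For reference, here is a short standalone sketch of the grouped-query attention idea that _gqa_attn targets: several query heads share each key/value head. It is an illustration only, assuming the number of query heads is a multiple of the number of key/value heads, and it broadcasts the shared key/value heads with repeat_interleave rather than the grouped-reshape-and-sum used in the method above.

import torch

def gqa_sketch(query, key, value):
    """Grouped-query attention sketch: all query heads, fewer shared key/value heads."""
    # query: (batch, num_q_heads, q_len, head_dim)
    # key, value: (batch, num_kv_heads, kv_len, head_dim)
    batch, num_q_heads, q_len, head_dim = query.shape
    num_kv_heads = key.size(1)
    group_size = num_q_heads // num_kv_heads  # query heads sharing one kv head

    # Repeat each key/value head so it lines up with its group of query heads.
    key = key.repeat_interleave(group_size, dim=1)
    value = value.repeat_interleave(group_size, dim=1)

    attn = torch.matmul(query, key.transpose(-2, -1)) / head_dim ** 0.5
    attn = attn.softmax(dim=-1)
    return torch.matmul(attn, value)  # (batch, num_q_heads, q_len, head_dim)

# Shape check: 12 query heads sharing 4 key/value heads (group size 3).
q = torch.randn(2, 12, 16, 64)
k = torch.randn(2, 4, 16, 64)
v = torch.randn(2, 4, 16, 64)
assert gqa_sketch(q, k, v).shape == (2, 12, 16, 64)

With equal query and key/value head counts (group size 1) this reduces to standard multi-head attention, which is the case exercised by the tests below.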
89 changes: 89 additions & 0 deletions protein_lm/tests/test_attention.py
@@ -0,0 +1,89 @@
import pytest
import torch
from torch.nn import functional as F

from protein_lm.modeling.models.apt.model_pytorch import APTAttention

class ParameterConfig:
def __init__(self):
self.max_position_embeddings = 512
self.position_embedding = 'rope'
self.max_sequence_length = 512
self.hidden_size = 768
self.num_attention_heads = 12
self.scale_attn_weights = True
self.scale_attn_by_inverse_layer_idx = True
self.reorder_and_upcast_attn = True
self.attn_pdrop = 0.1
self.resid_pdrop = 0.1
self.rope_scaling_factor = 1
self.rope_theta = 1
self.attn_type = 'gqa'


def test_vanilla_attn():
# Initialize with mock config
config = ParameterConfig()
attention = APTAttention(config, is_cross_attention=False, layer_idx=0)

# generate random input tensors
batch_size = 4
seq_length = 100
num_heads = config.num_attention_heads
query_dim = config.hidden_size // num_heads
query = torch.randn(batch_size, num_heads, seq_length, query_dim)
key = torch.randn(batch_size, num_heads, seq_length, query_dim)
value = torch.randn(batch_size, num_heads, seq_length, query_dim)

# Create a random attention mask for testing
attention_mask = torch.ones(batch_size,seq_length, seq_length)
padding_positions = 10
attention_mask[:, -padding_positions:, :] = float('-inf')
attention_mask[:, :, -padding_positions:] = float('-inf')
attention_mask = attention_mask.unsqueeze(1)
# Pass them through the _attn method
attn_output, attn_weights = attention._attn(query, key, value, attention_mask=attention_mask)

# Check the shapes and types of the output
assert isinstance(attn_output, torch.Tensor)
assert attn_output.shape == (batch_size, num_heads, seq_length, query_dim)
assert isinstance(attn_weights, torch.Tensor)
assert attn_weights.shape == (batch_size, num_heads, seq_length, seq_length)
print("Test passed!")

def test_gqa_attn():
# Initialize with mock config
config = ParameterConfig()
attention = APTAttention(config, is_cross_attention=False, layer_idx=0)

# generate random input tensors
batch_size = 4
seq_length = 100
num_heads = config.num_attention_heads
query_dim = config.hidden_size // num_heads
query = torch.randn(batch_size, num_heads, seq_length, query_dim)
key = torch.randn(batch_size, num_heads, seq_length, query_dim)
value = torch.randn(batch_size, num_heads, seq_length, query_dim)

# Create a random attention mask for testing
attention_mask = torch.ones(batch_size,seq_length, seq_length)
padding_positions = 10
attention_mask[:, -padding_positions:, :] = float('-inf')
attention_mask[:, :, -padding_positions:] = float('-inf')

# Pass them through the _gqa_attn method
attn_output, attn_weights = attention._gqa_attn(query, key, value, attention_mask=attention_mask)

# Check the shapes and types of the output
assert isinstance(attn_output, torch.Tensor)
assert attn_output.shape == (batch_size, num_heads, seq_length, query_dim)
assert isinstance(attn_weights, torch.Tensor)
assert attn_weights.shape == (batch_size, num_heads, seq_length, seq_length)
print("Test passed!")


if __name__ == "__main__":
    test_gqa_attn()
    test_vanilla_attn()
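Both checks can also be collected and run through pytest; a minimal programmatic invocation, assuming the repository root as the working directory and the test path shown in the file header above:

import pytest

pytest.main(["-q", "protein_lm/tests/test_attention.py"])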


