Restore flash attention support #431

Draft · wants to merge 1 commit into base: main
24 changes: 14 additions & 10 deletions lit_llama/adapter.py
@@ -87,8 +87,8 @@ def forward(
         self,
         x: torch.Tensor,
         rope: RoPECache,
-        mask: torch.Tensor,
         max_seq_length: int,
+        mask: Optional[torch.Tensor] = None,
         input_pos: Optional[torch.Tensor] = None,
         kv_cache: Optional[KVCache] = None,
         adapter_kv_cache: Optional[KVCache] = None,
@@ -142,7 +142,7 @@ def forward(

         # efficient attention using Flash Attention CUDA kernels
         # ↓ (B, nh, T, hs) @ (B, nh, T, hs).mT --> (B, nh, T, T) @ (B, nh, T, hs) --> (B, nh, T, hs)
-        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)  # (B, nh, T, hs)
+        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=mask is None)  # (B, nh, T, hs)

         # "Adapters are applied to the topmost layers to better tune the language
         # representations with higher-level semantics".
@@ -203,14 +203,14 @@ def forward(
         self,
         x: torch.Tensor,
         rope: RoPECache,
-        mask: torch.Tensor,
         max_seq_length: int,
+        mask: Optional[torch.Tensor] = None,
         input_pos: Optional[torch.Tensor] = None,
         kv_cache: Optional[KVCache] = None,
         adapter_kv_cache: Optional[KVCache] = None,
     ) -> Tuple[torch.Tensor, Optional[KVCache], Optional[KVCache]]:
         h, new_kv_cache, new_adapter_kv_cache = self.attn(
-            self.rms_1(x), rope, mask, max_seq_length, input_pos, kv_cache, adapter_kv_cache
+            self.rms_1(x), rope, max_seq_length, mask, input_pos, kv_cache, adapter_kv_cache
         )
         x = x + h
         x = x + self.mlp(self.rms_2(x))
@@ -253,6 +253,7 @@ def forward(
         self, idx: torch.Tensor, max_seq_length: Optional[int] = None, input_pos: Optional[torch.Tensor] = None
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[KVCache]]]:
         B, T = idx.size()
+        use_kv_cache = input_pos is not None

         block_size = self.config.block_size
         if max_seq_length is None:
@@ -263,23 +264,26 @@

         if self.rope_cache is None:
             self.rope_cache = self.build_rope_cache(idx)  # (block_size, head_size / 2, 2)
-        if self.mask_cache is None:
+        # passing `attn_mask` to SDPA downgrades it to use the inefficient implementation. since we only need the mask
+        # for the kv-cache support (only during inference), we only create it in that situation
+        # this will be resolved by https://github.com/pytorch/pytorch/issues/96099
+        if use_kv_cache and self.mask_cache is None:
             self.mask_cache = self.build_mask_cache(idx)  # (1, 1, block_size, block_size)

-        if input_pos is not None:
+        if use_kv_cache:
             rope = self.rope_cache.index_select(0, input_pos)
             mask = self.mask_cache.index_select(2, input_pos)
             mask = mask[:, :, :, :max_seq_length]
         else:
             rope = self.rope_cache[:T]
-            mask = self.mask_cache[:, :, :T, :T]
+            mask = None

         # forward the model itself
         x = self.transformer.wte(idx)  # token embeddings of shape (B, T, n_embd)

-        if input_pos is None:  # proxy for use_cache=False
+        if not use_kv_cache:
             for block in self.transformer.h:
-                x, *_ = block(x, rope, mask, max_seq_length)
+                x, *_ = block(x, rope, max_seq_length)
         else:
             if not self.kv_caches:
                 head_size = self.config.n_embd // self.config.n_head
@@ -292,7 +296,7 @@ def forward(
                 self.adapter_kv_caches = [None for _ in range(self.config.n_layer)]
             for i, block in enumerate(self.transformer.h):
                 x, self.kv_caches[i], self.adapter_kv_caches[i] = block(
-                    x, rope, mask, max_seq_length, input_pos, self.kv_caches[i], self.adapter_kv_caches[i]
+                    x, rope, max_seq_length, mask, input_pos, self.kv_caches[i], self.adapter_kv_caches[i]
                 )

         x = self.transformer.ln_f(x)  # (B, T, n_embd)
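The key behavioral change above: when no kv cache is in use, the blocks now call SDPA with `attn_mask=None` and `is_causal=True`, which leaves PyTorch free to dispatch to the fused flash-attention kernel; an explicit mask is only built and passed during cached inference. A minimal standalone sketch of the two call paths (not part of this PR; sizes are made up):

```python
# Sketch of the two SDPA call paths this diff switches between (illustrative sizes).
import torch
import torch.nn.functional as F

B, nh, T, hs = 1, 8, 128, 64  # batch, heads, sequence length, head size (hypothetical)
q, k, v = (torch.randn(B, nh, T, hs) for _ in range(3))

# Training / prefill path after this PR: no mask tensor is materialized and the
# causal structure is handled inside the kernel, so flash attention is eligible.
y_causal = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)

# KV-cache decoding path: an explicit boolean mask (True = attend) is still passed,
# which forces SDPA onto the generic implementation.
mask = torch.ones(T, T, dtype=torch.bool).tril().view(1, 1, T, T)
y_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

# Both compute causal attention; the difference is which kernel SDPA may select.
print((y_causal - y_masked).abs().max())  # expected to be ~0 up to float error
```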
24 changes: 14 additions & 10 deletions lit_llama/model.py
@@ -75,6 +75,7 @@ def forward(
         self, idx: torch.Tensor, max_seq_length: Optional[int] = None, input_pos: Optional[torch.Tensor] = None
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[KVCache]]]:
         B, T = idx.size()
+        use_kv_cache = input_pos is not None

         block_size = self.config.block_size
         if max_seq_length is None:
@@ -85,23 +86,26 @@

         if self.rope_cache is None:
             self.rope_cache = self.build_rope_cache(idx)
-        if self.mask_cache is None:
+        # passing `attn_mask` to SDPA downgrades it to use the inefficient implementation. since we only need the mask
+        # for the kv-cache support (only during inference), we only create it in that situation
+        # this will be resolved by https://github.com/pytorch/pytorch/issues/96099
+        if use_kv_cache and self.mask_cache is None:
             self.mask_cache = self.build_mask_cache(idx)

-        if input_pos is not None:
+        if use_kv_cache:
             rope = self.rope_cache.index_select(0, input_pos)
             mask = self.mask_cache.index_select(2, input_pos)
             mask = mask[:, :, :, :max_seq_length]
         else:
             rope = self.rope_cache[:T]
-            mask = self.mask_cache[:, :, :T, :T]
+            mask = None

         # forward the model itself
         x = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)

-        if input_pos is None:  # proxy for use_cache=False
+        if not use_kv_cache:
             for block in self.transformer.h:
-                x, _ = block(x, rope, mask, max_seq_length)
+                x, _ = block(x, rope, max_seq_length)
         else:
             if not self.kv_caches:
                 head_size = self.config.n_embd // self.config.n_head
@@ -111,7 +115,7 @@ def forward(
                     for _ in range(self.config.n_layer)
                 ]
             for i, block in enumerate(self.transformer.h):
-                x, self.kv_caches[i] = block(x, rope, mask, max_seq_length, input_pos, self.kv_caches[i])
+                x, self.kv_caches[i] = block(x, rope, max_seq_length, mask, input_pos, self.kv_caches[i])

         x = self.transformer.ln_f(x)

@@ -155,12 +159,12 @@ def forward(
         self,
         x: torch.Tensor,
         rope: RoPECache,
-        mask: MaskCache,
         max_seq_length: int,
+        mask: Optional[torch.Tensor] = None,
         input_pos: Optional[torch.Tensor] = None,
         kv_cache: Optional[KVCache] = None,
     ) -> Tuple[torch.Tensor, Optional[KVCache]]:
-        h, new_kv_cache = self.attn(self.rms_1(x), rope, mask, max_seq_length, input_pos, kv_cache)
+        h, new_kv_cache = self.attn(self.rms_1(x), rope, max_seq_length, mask, input_pos, kv_cache)
         x = x + h
         x = x + self.mlp(self.rms_2(x))
         return x, new_kv_cache
@@ -184,8 +188,8 @@ def forward(
         self,
         x: torch.Tensor,
         rope: RoPECache,
-        mask: MaskCache,
         max_seq_length: int,
+        mask: Optional[torch.Tensor] = None,
         input_pos: Optional[torch.Tensor] = None,
         kv_cache: Optional[KVCache] = None,
     ) -> Tuple[torch.Tensor, Optional[KVCache]]:
@@ -225,7 +229,7 @@ def forward(
         # y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)

         # efficient attention using Flash Attention CUDA kernels
-        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=mask is None)

         y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
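The mask is kept for the kv-cache path because a newly decoded query position must attend to all cached key positions, not just a causal prefix of the current (length-1) input, so `is_causal=True` would be wrong there. A small standalone sketch of the mask-cache slicing shown in the diff (sizes are hypothetical; the mask construction mirrors the `(1, 1, block_size, block_size)` shape noted in the diff comments):

```python
# Sketch of the mask_cache slicing kept for cached decoding (illustrative sizes).
import torch

block_size, max_seq_length = 8, 6
# Lower-triangular causal mask cache, shaped (1, 1, block_size, block_size).
mask_cache = torch.ones(block_size, block_size, dtype=torch.bool).tril().view(1, 1, block_size, block_size)

# Decoding one new token that sits at absolute position 4 (hypothetical).
input_pos = torch.tensor([4])
mask = mask_cache.index_select(2, input_pos)  # (1, 1, 1, block_size): row 4 of the causal mask
mask = mask[:, :, :, :max_seq_length]         # truncate keys to max_seq_length -> (1, 1, 1, 6)
print(mask)  # True for key positions 0..4, False afterwards
```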
4 changes: 2 additions & 2 deletions tests/test_model.py
@@ -82,7 +82,7 @@ def test_to_orig_llama(lit_llama, orig_llama, kv_cache) -> None:
     kv_cache_shape = (batch_size, n_head, block_size, head_size)
     ours_kv_cache = torch.zeros(kv_cache_shape), torch.zeros(kv_cache_shape)
     (llama_block_out, ours_kv_cache) = llama_model.transformer.h[0](
-        llama_embed, llama_rope, llama_mask, seq_len, torch.arange(block_size), ours_kv_cache
+        llama_embed, llama_rope, seq_len, llama_mask, torch.arange(block_size), ours_kv_cache
     )
     ours_k_cache = ours_kv_cache[0].permute(0, 2, 1, 3)
     ours_v_cache = ours_kv_cache[1].permute(0, 2, 1, 3)
@@ -92,7 +92,7 @@ def test_to_orig_llama(lit_llama, orig_llama, kv_cache) -> None:
     orig_llama_block_out = orig_llama_model.layers[0](
         orig_llama_embed, 0, orig_llama_model.freqs_cis[:seq_len], orig_llama_mask
     )
-    (llama_block_out, _) = llama_model.transformer.h[0](llama_embed, llama_rope, llama_mask, seq_len)
+    (llama_block_out, _) = llama_model.transformer.h[0](llama_embed, llama_rope, seq_len, llama_mask)
     assert torch.allclose(orig_llama_block_out, llama_block_out)

     expected = orig_llama_model(token_sample, 0)