Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add paged attention support #1355

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open

Conversation

cyanguwa
Copy link
Collaborator

@cyanguwa cyanguwa commented Dec 4, 2024

Description

This PR adds paged attention support for FusedAttention, FlashAttention, and UnfusedDotProductAttention.

  • KV cache is maintained in 'bshd' format, and it supports FP32/FP16/BF16, not FP8 yet
  • Context parallelism is not supported with KV cache
  • FusedAttention and UnfusedDotProductAttention support page_size >= 16, and FlashAttention supports page_size >= 256
  • All backends support both pure generation and mixed generation/context in the batch
  • FlashAttention supports paged attention through flash_attn_varlen_func, not flash_attn_with_kvcache, due to some numerical issues
  • UnfusedDotProductAttention supports paged attention by converting paged cache tensors to non-paged, before attention

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refractor

Changes

Please list the changes introduced in this PR:

  • Add paged attention support for FusedAttention, FlashAttention, and UnfusedDotProductAttention

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@cyanguwa
Copy link
Collaborator Author

cyanguwa commented Dec 4, 2024

/te-ci pytorch L0

@cyanguwa
Copy link
Collaborator Author

cyanguwa commented Jan 6, 2025

/te-ci pytorch L0

v_cache: torch.Tensor
The value cache tensor containing previous and the current tokens
"""
k_cache, v_cache = self.cache[layer_number]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
k_cache, v_cache = self.cache[layer_number]
assert layer_number in self.cache
k_cache, v_cache = self.cache[layer_number]

def __init__(
self,
max_batch_size: int,
max_seqlen_kv: int,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

corresponding docstring says max_sequence_length, we should change one of those

seq_s = self.sequences[seq] - step_dict[seq]
seq_e = self.sequences[seq]
if qkv_format == "bshd":
new_k_cache[i, seq_s:seq_e, :, :] = k[i, : step_dict[seq], :, :]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

k[i, : step_dict[seq], :, :]

k isn't supposed to have any tokens beyond step_dict[seq], right?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same for v

Comment on lines +128 to +129
seq_s = self.sequences[seq] - step_dict[seq]
seq_e = self.sequences[seq]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These could potentially be moved into a method since this could be reused from outside like when getting the start positions of RoPE embeddings application

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants