Add paged attention support #1355
base: main
Conversation
Commit history: multiple commits signed off by Charlene Yang <[email protected]>, plus pre-commit.ci commits (for more information, see https://pre-commit.ci).
/te-ci pytorch L0
    v_cache: torch.Tensor
        The value cache tensor containing previous and the current tokens
    """
    k_cache, v_cache = self.cache[layer_number]
Suggested change:
    assert layer_number in self.cache
    k_cache, v_cache = self.cache[layer_number]
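For context, a minimal sketch of the dict-backed per-layer cache pattern that this assert would guard; the class and method names are hypothetical illustrations, not the PR's actual API:

import torch

class PerLayerKVCache:
    """Toy per-layer KV cache keyed by layer number (illustration only)."""

    def __init__(self):
        self.cache = {}  # layer_number -> (k_cache, v_cache)

    def allocate(self, layer_number, max_batch_size, max_seqlen_kv, num_heads, head_dim):
        # Pre-allocate the buffers for one layer before decoding starts.
        k_cache = torch.zeros(max_batch_size, max_seqlen_kv, num_heads, head_dim)
        v_cache = torch.zeros_like(k_cache)
        self.cache[layer_number] = (k_cache, v_cache)

    def retrieve(self, layer_number):
        # The suggested assert fails with a clear message if allocate() was
        # never called for this layer, instead of raising a bare KeyError.
        assert layer_number in self.cache, f"no KV cache allocated for layer {layer_number}"
        k_cache, v_cache = self.cache[layer_number]
        return k_cache, v_cache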
def __init__(
    self,
    max_batch_size: int,
    max_seqlen_kv: int,
The corresponding docstring says max_sequence_length; we should change one of them so the name and the docstring match.
seq_s = self.sequences[seq] - step_dict[seq]
seq_e = self.sequences[seq]
if qkv_format == "bshd":
    new_k_cache[i, seq_s:seq_e, :, :] = k[i, : step_dict[seq], :, :]
Regarding k[i, : step_dict[seq], :, :]: k isn't supposed to have any tokens beyond step_dict[seq], right?
same for v
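To make the indexing concrete, here is a toy illustration (not the PR's code) of the bshd copy above, under the assumption this question raises: k carries only the step's new tokens for each sequence (any positions beyond step_dict[seq] are padding), and they land in the cache window [seq_s, seq_e):

import torch

# Toy shapes: 1 sequence, 8-slot cache, 2 heads, head_dim 4.
new_k_cache = torch.zeros(1, 8, 2, 4)

step = 2                             # step_dict[seq]: new tokens added this step
total = 3 + step                     # self.sequences[seq]: 3 previously cached + 2 new
seq_s, seq_e = total - step, total   # window for the new tokens: [3, 5)

# k's sequence dimension is a buffer that may be longer than step;
# slicing to step keeps exactly the new tokens and drops the padding.
k = torch.randn(1, 4, 2, 4)
new_k_cache[0, seq_s:seq_e, :, :] = k[0, :step, :, :]
assert torch.equal(new_k_cache[0, 3:5], k[0, :2])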
seq_s = self.sequences[seq] - step_dict[seq]
seq_e = self.sequences[seq]
These could potentially be moved into a method, since the same computation could be reused from outside, e.g. when getting the start positions for applying RoPE embeddings.
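One possible shape for such a helper, sketched with a hypothetical method name (not taken from the PR):

def get_step_window(self, seq, step_dict):
    """Return the [start, end) cache positions occupied by the current
    step's tokens of sequence `seq`. The start offset could also be reused
    from outside, e.g. as the position offset when applying RoPE embeddings
    to the step's new tokens."""
    seq_e = self.sequences[seq]
    seq_s = seq_e - step_dict[seq]
    return seq_s, seq_e

The cache-update loop would then read seq_s, seq_e = self.get_step_window(seq, step_dict), and a RoPE path could use seq_s as the starting position index for the same tokens.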
Description
This PR adds paged attention support for FusedAttention, FlashAttention, and UnfusedDotProductAttention.
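For readers unfamiliar with the technique, here is a minimal, generic sketch of the core idea behind a paged KV cache: K/V storage is split into fixed-size pages from a shared pool and addressed through a per-sequence page table, so sequences do not need contiguous pre-allocated buffers. This is an illustration of the concept only, not this PR's implementation or API:

import torch

page_size, num_pages, num_heads, head_dim = 16, 64, 2, 8

# One global pool of KV pages shared by all sequences.
k_pool = torch.zeros(num_pages, page_size, num_heads, head_dim)
v_pool = torch.zeros(num_pages, page_size, num_heads, head_dim)

# Per-sequence page table: logical page index -> physical page in the pool.
page_table = {0: [5, 12]}   # sequence 0 currently owns physical pages 5 and 12
seq_len = {0: 20}           # 20 cached tokens: 16 fill page 5, 4 sit in page 12

def append_token(seq, k_tok, v_tok, free_pages):
    # Write one new token's K/V into the sequence's last page, grabbing a
    # free page from the pool when the last page is full.
    pos = seq_len[seq]
    if pos % page_size == 0 and pos // page_size == len(page_table[seq]):
        page_table[seq].append(free_pages.pop())
    page = page_table[seq][pos // page_size]
    k_pool[page, pos % page_size] = k_tok
    v_pool[page, pos % page_size] = v_tok
    seq_len[seq] = pos + 1

def gather_kv(seq):
    # Gather the sequence's cached K/V in logical order for attention.
    n = seq_len[seq]
    k = torch.cat([k_pool[p] for p in page_table[seq]], dim=0)[:n]
    v = torch.cat([v_pool[p] for p in page_table[seq]], dim=0)[:n]
    return k, v

An actual paged-attention kernel indexes the pages directly rather than gathering them into a contiguous tensor, but the bookkeeping above is the essential data structure.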