
change kv cache layout #5

Open · yaochengji wants to merge 387 commits into main from chengji/kv-cache

Conversation

yaochengji (Owner)

FILL IN THE PR DESCRIPTION HERE

FIX #xxxx (link existing issues this PR will resolve)

BEFORE SUBMITTING, PLEASE READ https://docs.vllm.ai/en/latest/contributing/overview.html

@@ -130,7 +130,7 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
         # Calculate the TPU KV cache size based on profiling.
         usable_memory_size = int(total_memory_size *
                                  self.cache_config.gpu_memory_utilization)
-        tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0)
+        tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0) // 4


Why the "// 4"?

yaochengji (Owner, Author)

NVM, I will remove it. It was only used to measure the cost of the torch.permute on the pre-allocated KV cache buffer. Anyway, once the new Pallas attention kernel is finished, the torch.permute will disappear.
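
For context, here is a minimal sketch of the sizing logic under discussion (the function and argument names are illustrative, not the exact vLLM code); an experimental divisor of 4 reproduces the temporary "// 4" above, which only shrinks the pre-allocated buffer so the cost of the torch.permute can be measured:

def compute_tpu_kv_cache_bytes(total_memory_size: int,
                               gpu_memory_utilization: float,
                               profiled: int,
                               experiment_divisor: int = 1) -> int:
    # Usable memory is the configured fraction of total device memory.
    usable_memory_size = int(total_memory_size * gpu_memory_utilization)
    # Whatever remains after the profiling run goes to the KV cache;
    # experiment_divisor=4 mimics the temporary "// 4" in the diff.
    return max(usable_memory_size - profiled, 0) // experiment_divisor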

@@ -172,14 +173,14 @@ def _dummy_run(
     ) -> None:
         exec_mode = ExecutionMode(exec_mode)
         if exec_mode.is_prefill():
-            seq_len = (seq_len + 15) // 16 * 16
+            seq_len = (seq_len + 15) // MIN_PREFILL_SEQ_LEN * MIN_PREFILL_SEQ_LEN


replace 15 with "MIN_PREFILL_SEQ_LEN-1"?

yaochengji (Owner, Author)

good catch!
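
To illustrate the suggested change, a small standalone sketch (not the actual _dummy_run code) of rounding the prefill length up to a multiple of MIN_PREFILL_SEQ_LEN without hard-coding 15:

MIN_PREFILL_SEQ_LEN = 16

def round_up_prefill_len(seq_len: int) -> int:
    # Round up to the next multiple of MIN_PREFILL_SEQ_LEN; writing the
    # offset as MIN_PREFILL_SEQ_LEN - 1 keeps the rounding correct if the
    # constant ever changes.
    return ((seq_len + MIN_PREFILL_SEQ_LEN - 1)
            // MIN_PREFILL_SEQ_LEN * MIN_PREFILL_SEQ_LEN)

assert round_up_prefill_len(1) == 16
assert round_up_prefill_len(16) == 16
assert round_up_prefill_len(17) == 32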

key_cache = key_cache.flatten(0, 1)
value_cache = value_cache.flatten(0, 1)
slot_mapping = slot_mapping.flatten()
key_cache.index_copy_(0, slot_mapping, key)


IIUC, you are copying block by block (of size hn x hs), hence the speedup?

yaochengji (Owner, Author)

It's now block_size x hn x hs per copy.
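
A toy reproduction of the update pattern quoted above (all shapes are illustrative assumptions, not the real cache dimensions): with the blocked layout the slots of one block are contiguous in the flattened cache, so each write effectively moves a block_size x hn x hs slab.

import torch

num_blocks, block_size, num_heads, head_size = 4, 16, 2, 8

# Cache laid out as [num_blocks, block_size, num_heads, head_size],
# then flattened so that dim 0 enumerates individual token slots.
key_cache = torch.zeros(num_blocks, block_size, num_heads, head_size)
key_cache = key_cache.flatten(0, 1)   # [num_blocks * block_size, hn, hs]

# New keys filling one whole block, and the contiguous slots they land in.
key = torch.randn(block_size, num_heads, head_size)
slot_mapping = torch.arange(2 * block_size, 3 * block_size)

key_cache.index_copy_(0, slot_mapping, key)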

value = value.flatten(0, 1)
key_cache = key_cache.flatten(0, 1)
value_cache = value_cache.flatten(0, 1)
slot_mapping = slot_mapping.flatten()


I wonder why you need to flatten the slot_mapping here

yaochengji (Owner, Author)

It's a 2D tensor before that
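
To make the shapes concrete, a small sketch (illustrative shapes, not the real metadata) of why the flatten is needed: slot_mapping arrives as a 2D [batch, seq_len] tensor, while index_copy_ expects a 1D index whose length matches the flattened key/value.

import torch

batch_size, seq_len, num_heads, head_size = 2, 4, 2, 8

slot_mapping = torch.tensor([[0, 1, 2, 3],      # slots for sequence 0
                             [8, 9, 10, 11]])   # slots for sequence 1
value = torch.randn(batch_size, seq_len, num_heads, head_size)
value_cache = torch.zeros(16, num_heads, head_size)

value = value.flatten(0, 1)             # [batch * seq_len, hn, hs]
slot_mapping = slot_mapping.flatten()   # [batch * seq_len]
value_cache.index_copy_(0, slot_mapping, value)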

@@ -11,6 +11,8 @@
                               AttentionMetadata, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
 
+MIN_PREFILL_SEQ_LEN = 16

Can we clarify which TPU generations work optimally with this seq-len? Does it make sense to add this to a config file?

Wondering if v5p or older generations would require a different threshold. Same comment regarding future TPU generations. :)

yaochengji (Owner, Author)

Basically the larger the better. It has two limits:

  1. the new size cannot be smaller than the actual size
  2. the new size cannot be larger than the page size

I'm actually thinking about directly setting MIN_PREFILL_SEQ_LEN to the page size for simplicity; then we can control the size via the page size.
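
One possible way to encode those two limits, sketched under the assumption that MIN_PREFILL_SEQ_LEN acts as the padding granularity and the page (block) size is known to the runner; setting min_prefill_seq_len equal to page_size, as suggested, collapses the two knobs into one:

def padded_prefill_len(seq_len: int, page_size: int,
                       min_prefill_seq_len: int = 16) -> int:
    # Limit 2: the padding granularity should not exceed the page size.
    granularity = min(min_prefill_seq_len, page_size)
    # Limit 1: rounding up guarantees the padded length is never smaller
    # than the actual sequence length.
    return (seq_len + granularity - 1) // granularity * granularity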

github-actions (bot)

This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you!

github-actions bot added the stale label on May 15, 2025
DarkLight1337 and others added 21 commits May 26, 2025 00:45
Signed-off-by: reidliu41 <[email protected]>
Co-authored-by: reidliu41 <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
dsikka and others added 18 commits June 5, 2025 18:21
Signed-off-by: Luis Vega <[email protected]>
Co-authored-by: Luis Vega <[email protected]>
yaochengji force-pushed the chengji/kv-cache branch 2 times, most recently from d0e50eb to 1f7c757 on June 6, 2025 at 17:48
Signed-off-by: Chengji Yao <[email protected]>