20 changes: 15 additions & 5 deletions vllm/v1/attention/backends/flashinfer.py
@@ -302,11 +302,7 @@ def __init__(
self.disable_split_kv = False

self.compilation_config = vllm_config.compilation_config
max_num_pages_per_req = cdiv(
self.model_config.max_model_len, self.kv_cache_spec.block_size
)
max_num_reqs = vllm_config.scheduler_config.max_num_seqs
max_num_pages = max_num_reqs * max_num_pages_per_req
speculative_config = vllm_config.speculative_config
num_spec_tokens = (
speculative_config.num_speculative_tokens
@@ -333,7 +329,21 @@ def __init__(
self.num_kv_heads = self.kv_cache_spec.num_kv_heads
self.head_dim = self.kv_cache_spec.head_size
FlashInferBackend.validate_head_size(self.head_dim)
self.page_size = self.kv_cache_spec.block_size

# IMPORTANT: page_size must match the actual kernel block size,
# not the KV manager block size!
# When using hybrid blocks (self.kv_cache_spec.block_size != kernel_block_size),
# the KV cache is allocated with kernel_block_size, so we must use that
# for page_size when calling FlashInfer.
kernel_block_size = self.kv_cache_spec.find_compatible_kernel_block_sizes(
FlashInferBackend, return_all=False
)[0]
self.page_size = kernel_block_size

Collaborator: Can you refactor _select_common_block_size a bit so that you can call something like GPUModelRunner._select_common_block_size? I don't think it's a good idea to move to self.kv_cache_spec.find_compatible_kernel_block_sizes.

Collaborator (Author): To avoid a misunderstanding, I'd like to double-check: do you mean find_compatible_kernel_block_sizes or _select_common_block_size?

Collaborator: That's up to you. Just call GPUModelRunner.xxxx() here and minimize the changes to the GPU model runner.

# Calculate buffer sizes using the actual kernel block size (page_size)
# to ensure buffers are correctly sized in hybrid mode
max_num_pages_per_req = cdiv(self.model_config.max_model_len, self.page_size)
max_num_pages = max_num_reqs * max_num_pages_per_req

self.cache_dtype = self.cache_config.cache_dtype
if self.cache_dtype.startswith("fp8"):
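As a reading aid, here is a minimal standalone sketch of the buffer-sizing arithmetic in the flashinfer.py hunk above. All numbers are assumed for illustration (they are not taken from this PR); the point is that sizing the FlashInfer page buffers from the KV manager block size would undercount pages whenever the chosen kernel block size is smaller.

def cdiv(a: int, b: int) -> int:
    # Ceiling division, mirroring vllm.utils.math_utils.cdiv.
    return -(-a // b)

max_model_len = 8192          # assumed model context length
max_num_reqs = 4              # assumed scheduler max_num_seqs
kv_manager_block_size = 64    # assumed KV manager block size (hybrid mode)
kernel_block_size = 16        # assumed kernel block size picked for FlashInfer

# page_size must be the kernel block size, so the page count per request
# (and hence the buffer size) grows accordingly:
pages_from_manager_size = max_num_reqs * cdiv(max_model_len, kv_manager_block_size)  # 512
pages_from_kernel_size = max_num_reqs * cdiv(max_model_len, kernel_block_size)       # 2048
print(pages_from_manager_size, pages_from_kernel_size)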
44 changes: 44 additions & 0 deletions vllm/v1/kv_cache_interface.py
@@ -4,6 +4,7 @@
import copy
from dataclasses import dataclass, fields
from math import prod
from typing import TYPE_CHECKING

import torch
from typing_extensions import Self
@@ -13,6 +14,9 @@
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import get_dtype_size

if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend

logger = init_logger(__name__)


@@ -71,6 +75,46 @@ def page_size_bytes(self) -> int:
* get_dtype_size(self.dtype)
)

def find_compatible_kernel_block_sizes(
self, backend_cls: "type[AttentionBackend]", return_all: bool = False
) -> list[int]:
"""Find compatible kernel block sizes for this spec and backend.

Args:
backend_cls: The attention backend class.
return_all: If True, return all compatible sizes;
if False, return only the largest one.

Returns:
List of compatible kernel block sizes

Raises:
ValueError: If no compatible block size found
"""
from vllm.attention.backends.abstract import MultipleOf

kv_manager_block_size = self.block_size
supported_sizes = backend_cls.get_supported_kernel_block_size()
compatible_sizes = []

for block_size in supported_sizes:
if isinstance(block_size, int) and kv_manager_block_size % block_size == 0:
compatible_sizes.append(block_size)
elif (
isinstance(block_size, MultipleOf)
and kv_manager_block_size % block_size.base == 0
):
compatible_sizes.append(kv_manager_block_size)

if not compatible_sizes:
raise ValueError(
f"No compatible kernel block size found for "
f"kv_manager_block_size={kv_manager_block_size}, "
f"supported_sizes={supported_sizes}"
)

return compatible_sizes if return_all else [max(compatible_sizes)]


@dataclass(frozen=True)
class FullAttentionSpec(AttentionSpec):
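For illustration, a standalone sketch of the selection rule that the new AttentionSpec.find_compatible_kernel_block_sizes implements: an int entry is compatible when it divides the KV manager block size, while a MultipleOf(base) entry lets the manager block size itself be used when it is a multiple of base. The MultipleOf stand-in and the pick_kernel_block_sizes helper below are illustrative re-implementations, not vllm imports.

from dataclasses import dataclass


@dataclass
class MultipleOf:
    # Stand-in for vllm.attention.backends.abstract.MultipleOf.
    base: int


def pick_kernel_block_sizes(
    kv_manager_block_size: int,
    supported_sizes: list,
    return_all: bool = False,
) -> list[int]:
    compatible: list[int] = []
    for size in supported_sizes:
        if isinstance(size, int) and kv_manager_block_size % size == 0:
            compatible.append(size)
        elif isinstance(size, MultipleOf) and kv_manager_block_size % size.base == 0:
            compatible.append(kv_manager_block_size)
    if not compatible:
        raise ValueError("no compatible kernel block size")
    return compatible if return_all else [max(compatible)]


# A backend advertising [16, 32, MultipleOf(8)] with a 64-token KV manager
# block yields [16, 32, 64]; the default (return_all=False) pick is 64.
print(pick_kernel_block_sizes(64, [16, 32, MultipleOf(8)], return_all=True))  # [16, 32, 64]
print(pick_kernel_block_sizes(64, [16, 32, MultipleOf(8)]))                   # [64]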
53 changes: 10 additions & 43 deletions vllm/v1/worker/gpu_model_runner.py
@@ -19,7 +19,7 @@

import vllm.envs as envs
from vllm.attention import Attention, AttentionType
from vllm.attention.backends.abstract import AttentionBackend, MultipleOf
from vllm.attention.backends.abstract import AttentionBackend
from vllm.compilation.counter import compilation_counter
from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
@@ -4146,44 +4146,6 @@ def calculate_reorder_batch_threshold(self) -> None:
else:
self.reorder_batch_threshold = reorder_batch_threshold_i

def _find_compatible_block_sizes(
self,
kv_manager_block_size: int,
backend_cls: type[AttentionBackend],
return_all: bool = False,
) -> list[int]:
"""
Find compatible block sizes for a backend.

Args:
kv_manager_block_size: Physical block size of KV cache
backend_cls: Attention backend class
return_all: Return all compatible sizes if True, max size if False

Returns:
Compatible block size(s) based on return_all parameter

Raises:
ValueError: If no compatible block size found
"""
supported_block_size = backend_cls.get_supported_kernel_block_size()
compatible_sizes = []

for block_size in supported_block_size:
if isinstance(block_size, int):
if kv_manager_block_size % block_size == 0:
compatible_sizes.append(block_size)
elif (
isinstance(block_size, MultipleOf)
and kv_manager_block_size % block_size.base == 0
):
compatible_sizes.append(kv_manager_block_size)

if not compatible_sizes:
raise ValueError(f"No compatible block size for {kv_manager_block_size}")

return compatible_sizes if return_all else [max(compatible_sizes)]

def _select_common_block_size(
self, kv_manager_block_size: int, attn_groups: list[AttentionGroup]
) -> int:
@@ -4204,8 +4166,13 @@ def _select_common_block_size(
all_backend_supports = []

for attn_group in attn_groups:
compatible_sizes = self._find_compatible_block_sizes(
kv_manager_block_size, attn_group.backend, return_all=True
kv_cache_spec = attn_group.kv_cache_spec
assert isinstance(kv_cache_spec, AttentionSpec), (
f"Expected AttentionSpec for attention group {attn_group.layer_names}, "
f"but got {type(kv_cache_spec)}"
)
compatible_sizes = kv_cache_spec.find_compatible_kernel_block_sizes(
attn_group.backend, return_all=True
)
supported_sizes = sorted(list(set(compatible_sizes)), reverse=True)
all_backend_supports.append(set(supported_sizes))
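The rest of _select_common_block_size is collapsed in this diff; assuming it keeps the behavior of intersecting the per-group sets and taking the largest shared size (an assumption, not shown above), the final selection reduces to a sketch like the following, with illustrative names.

def select_common_block_size(per_group_sizes: list[set[int]]) -> int:
    # Intersect the compatible sizes reported by every attention group
    # and pick the largest size they all support.
    common = set.intersection(*per_group_sizes)
    if not common:
        raise ValueError("attention groups share no compatible kernel block size")
    return max(common)


# e.g. one group supporting {16, 32, 64} and another supporting {32, 64}
# agree on 64.
print(select_common_block_size([{16, 32, 64}, {32, 64}]))  # 64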
@@ -4388,8 +4355,8 @@ def _reshape_kv_cache_tensors(
if isinstance(kv_cache_spec, AttentionSpec):
has_attn = True
kv_manager_block_size = kv_cache_spec.block_size
kernel_size_list = self._find_compatible_block_sizes(
kv_manager_block_size, attn_backend, return_all=False
kernel_size_list = kv_cache_spec.find_compatible_kernel_block_sizes(
attn_backend, return_all=False
)
kernel_size = kernel_size_list[0]
num_blocks_per_kv_block = kv_manager_block_size // kernel_size
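Finally, a tiny arithmetic sketch of the reshape step above, with assumed values: when the KV manager block is larger than the kernel block, each manager block is viewed as several kernel-sized pages. The toy cache layout here is only for illustration; the actual tensor layout in vllm differs.

import torch

kv_manager_block_size = 64   # assumed KV manager block size
kernel_size = 16             # assumed kernel block size returned above
num_blocks_per_kv_block = kv_manager_block_size // kernel_size  # -> 4

# Assumed toy cache layout: (num_manager_blocks, block_size, num_kv_heads, head_size).
num_manager_blocks, num_kv_heads, head_size = 8, 2, 4
kv_cache = torch.zeros(num_manager_blocks, kv_manager_block_size, num_kv_heads, head_size)

# View each manager block as num_blocks_per_kv_block kernel-sized pages.
kernel_view = kv_cache.view(
    num_manager_blocks * num_blocks_per_kv_block, kernel_size, num_kv_heads, head_size
)
print(kernel_view.shape)  # torch.Size([32, 16, 2, 4])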