From 0aee7ad3269ec5a3af3027d36e7a01d31df5752e Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Sat, 25 Oct 2025 06:22:02 -0700 Subject: [PATCH 01/16] wip Signed-off-by: Lucas Wilkinson --- csrc/cache.h | 18 +- csrc/cache_kernels.cu | 288 +++++++++++- csrc/ops.h | 6 + csrc/torch_bindings.cpp | 18 + vllm/_custom_ops.py | 54 +++ vllm/envs.py | 4 + .../layers/fused_moe/modular_kernel.py | 97 ++-- vllm/model_executor/models/deepseek_v2.py | 44 +- .../attention/backends/mla/flashmla_sparse.py | 427 ++++++++++++------ vllm/v1/attention/backends/mla/indexer.py | 4 +- vllm/v1/core/kv_cache_utils.py | 2 + vllm/v1/worker/gpu_model_runner.py | 6 + vllm/v1/worker/gpu_worker.py | 5 + vllm/v1/worker/workspace.py | 318 +++++++++++++ 14 files changed, 1096 insertions(+), 195 deletions(-) create mode 100644 vllm/v1/worker/workspace.py diff --git a/csrc/cache.h b/csrc/cache.h index b162a4a2bc31..49245cbca139 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include @@ -57,6 +58,15 @@ void cp_gather_cache( torch::Tensor const& cu_seq_lens, // [BATCH+1] int64_t batch_size, std::optional seq_starts = std::nullopt); +// Gather and upconvert FP8 KV cache to BF16 workspace +void cp_gather_and_upconvert_fp8_kv_cache( + torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656] + torch::Tensor const& dst, // [TOT_TOKENS, 576] + torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] + torch::Tensor const& seq_lens, // [BATCH] + torch::Tensor const& workspace_starts, // [BATCH] + int64_t batch_size); + // Indexer K quantization and cache function void indexer_k_quant_and_cache( torch::Tensor& k, // [num_tokens, head_dim] @@ -71,4 +81,10 @@ void cp_gather_indexer_k_quant_cache( torch::Tensor& dst_k, // [num_tokens, head_dim] torch::Tensor& dst_scale, // [num_tokens, head_dim / quant_block_size * 4] const torch::Tensor& block_table, // [batch_size, num_blocks] - const torch::Tensor& cu_seq_lens); // [batch_size + 1] \ No newline at end of file + const torch::Tensor& cu_seq_lens); // [batch_size + 1] + +torch::Tensor convert_logical_index_to_physical_index( + torch::Tensor req_id, torch::Tensor block_table, + torch::Tensor token_indices, int64_t block_size, + const std::optional& prefill_request_id, + const std::optional& workspace_starts); \ No newline at end of file diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 0aa0dc14c748..ed97f63ab851 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -2,6 +2,7 @@ #include #include #include +#include #include "cuda_utils.h" #include "cuda_compat.h" @@ -514,7 +515,8 @@ __global__ void indexer_k_quant_and_cache_kernel( const int quant_block_size, // quantization block size const int cache_block_size, // cache block size const int cache_stride, // stride for each token in kv_cache - const bool use_ue8m0 // use ue8m0 scale format + + const bool use_ue8m0 // use ue8m0 scale format ) { constexpr int VEC_SIZE = 4; const int64_t token_idx = blockIdx.x; @@ -1058,6 +1060,84 @@ void gather_and_maybe_dequant_cache( } namespace vllm { + +// Gather and upconvert FP8 KV cache tokens to BF16 workspace +// Similar to cp_gather_cache but specifically for FP8->BF16 conversion +__global__ void cp_gather_and_upconvert_fp8_kv_cache( + const uint8_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656] + __nv_bfloat16* __restrict__ dst, // [TOT_TOKENS, 576] + const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES] + const int32_t* __restrict__ seq_lens, // [BATCH] + const int32_t* 
__restrict__ workspace_starts, // [BATCH] + const int32_t block_size, const int32_t head_dim, + const int64_t block_table_stride, const int64_t cache_block_stride, + const int64_t cache_entry_stride, const int64_t dst_entry_stride) { + const int64_t bid = blockIdx.x; // Batch ID + const int32_t num_splits = gridDim.y; + const int32_t split = blockIdx.y; + const int32_t seq_start = workspace_starts[bid]; + const int32_t seq_len = seq_lens[bid]; + const int32_t tot_slots = seq_len; + const int32_t split_slots = cuda_utils::ceil_div(tot_slots, num_splits); + + const int32_t split_start = split * split_slots; + const int32_t split_end = min((split + 1) * split_slots, tot_slots); + + const bool is_active_split = (split_start < tot_slots); + + if (!is_active_split) return; + + // Adjust the pointer for the block_table for this batch + const int32_t batch_offset = bid * block_table_stride; + int32_t offset = split_start; + int32_t offset_div = offset / block_size; + offset = offset % block_size; + const int32_t* batch_block_table = block_table + batch_offset; + + // Adjust dst pointer based on the cumulative sequence lengths + dst += seq_start * dst_entry_stride; + + const int tid = threadIdx.x; + + // Process each token in this split + for (int pid = split_start; pid < split_end; ++pid) { + auto block_id = batch_block_table[offset_div]; + const uint8_t* token_ptr = + src_cache + block_id * cache_block_stride + offset * cache_entry_stride; + __nv_bfloat16* dst_ptr = dst + pid * dst_entry_stride; + + // FP8 format: 512 bytes fp8 + 16 bytes scales + 128 bytes rope (64 bf16) + const uint8_t* no_pe_ptr = token_ptr; + const float* scales_ptr = reinterpret_cast(token_ptr + 512); + const __nv_bfloat16* rope_ptr = + reinterpret_cast(token_ptr + 512 + 16); + + // Parallelize fp8 dequant (512 elements) and rope copy (64 elements) + if (tid < 512) { + // FP8 dequantization + const int tile = tid >> 7; // each tile is 128 elements + const float scale = scales_ptr[tile]; + const uint8_t val = no_pe_ptr[tid]; + dst_ptr[tid] = + fp8::scaled_convert<__nv_bfloat16, uint8_t, + vllm::Fp8KVCacheDataType::kFp8E4M3>(val, scale); + } else if (tid < 576) { + // Rope copy (64 bf16 elements) + const int rope_idx = tid - 512; + dst_ptr[512 + rope_idx] = rope_ptr[rope_idx]; + } + // Threads 576+ are idle + // No sync needed - each iteration processes independent tokens + + // Move to next token + offset += 1; + if (offset == block_size) { + offset_div += 1; + offset = 0; + } + } +} + template // Note(hc): The cp_gather_cache allows seq_starts to no longer be divisible by // block_size. 
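(Editorial note.) The kernel above expands each 656-byte `fp8_ds_mla` cache entry into 576 bf16 values: 512 NoPE elements dequantized in four 128-element tiles (one float scale per tile), followed by 64 RoPE bf16 elements copied through unchanged. A minimal torch-level sketch of that per-entry layout, assuming the 512 + 16 + 128 byte split described in the kernel comments; the helper name is illustrative and not part of the patch:

import torch

def upconvert_fp8_ds_mla_entry_ref(entry: torch.Tensor) -> torch.Tensor:
    # Reference for a single 656-byte fp8_ds_mla cache entry -> 576 bf16 values.
    # Layout assumed from the kernel comments:
    #   [0:512)   512 x fp8e4m3 NoPE values (four 128-element tiles)
    #   [512:528) 4 x float32 per-tile scales
    #   [528:656) 64 x bf16 RoPE values (copied through unchanged)
    assert entry.dtype == torch.uint8 and entry.numel() == 656
    no_pe = entry[:512].view(torch.float8_e4m3fn).to(torch.float32)
    scales = entry[512:528].view(torch.float32)         # [4]
    rope = entry[528:656].view(torch.bfloat16)          # [64]
    no_pe = (no_pe.view(4, 128) * scales[:, None]).reshape(512)
    return torch.cat([no_pe.to(torch.bfloat16), rope])  # [576]
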
@@ -1199,6 +1279,58 @@ void cp_gather_cache( } } +// Host function to launch the gather-and-upconvert kernel +void cp_gather_and_upconvert_fp8_kv_cache( + torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656] + torch::Tensor const& dst, // [TOT_TOKENS, 576] + torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] + torch::Tensor const& seq_lens, // [BATCH] + torch::Tensor const& workspace_starts, // [BATCH] + int64_t batch_size) { + at::cuda::OptionalCUDAGuard device_guard(src_cache.device()); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + int32_t block_size = src_cache.size(1); + int32_t head_dim = dst.size(1); + + TORCH_CHECK(block_table.dtype() == torch::kInt32, + "block_table must be int32"); + TORCH_CHECK(seq_lens.dtype() == torch::kInt32, "seq_lens must be int32"); + TORCH_CHECK(workspace_starts.dtype() == torch::kInt32, + "workspace_starts must be int32"); + + TORCH_CHECK(src_cache.device() == dst.device(), + "src_cache and dst must be on the same device"); + TORCH_CHECK(src_cache.device() == block_table.device(), + "src_cache and block_table must be on the same device"); + TORCH_CHECK(src_cache.device() == seq_lens.device(), + "src_cache and seq_lens must be on the same device"); + TORCH_CHECK(src_cache.device() == workspace_starts.device(), + "src_cache and workspace_starts must be on the same device"); + + TORCH_CHECK(src_cache.dtype() == torch::kUInt8, "src_cache must be uint8"); + TORCH_CHECK(dst.dtype() == torch::kBFloat16, "dst must be bfloat16"); + TORCH_CHECK(head_dim == 576, "head_dim must be 576 for MLA"); + + int64_t block_table_stride = block_table.stride(0); + int64_t cache_block_stride = src_cache.stride(0); + int64_t cache_entry_stride = src_cache.stride(1); + int64_t dst_entry_stride = dst.stride(0); + + // Decide on the number of splits based on the batch size + int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16; + dim3 grid(batch_size, num_splits); + dim3 block(1024); + + vllm::cp_gather_and_upconvert_fp8_kv_cache<<>>( + src_cache.data_ptr(), + reinterpret_cast<__nv_bfloat16*>(dst.data_ptr()), + block_table.data_ptr(), seq_lens.data_ptr(), + workspace_starts.data_ptr(), block_size, head_dim, + block_table_stride, cache_block_stride, cache_entry_stride, + dst_entry_stride); +} + // Macro to dispatch the kernel based on the data type. #define CALL_INDEXER_K_QUANT_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \ vllm::indexer_k_quant_and_cache_kernel \ @@ -1238,7 +1370,159 @@ void indexer_k_quant_and_cache( CALL_INDEXER_K_QUANT_AND_CACHE); } -// Macro to dispatch the kernel based on the data amount. 
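(Editorial note.) The CUDA kernel added below replaces the Triton `_convert_req_index_to_global_index_kernel` removed later in this patch: decode tokens are mapped through the block table to global cache slots, while prefill tokens are mapped to offsets into the gathered BF16 workspace. A rough torch-level reference of that mapping, useful when reading the kernel; illustrative only, not part of the patch:

import torch

def convert_logical_to_physical_ref(
    req_id: torch.Tensor,           # [num_tokens]
    block_table: torch.Tensor,      # [num_requests, max_num_blocks_per_req]
    token_indices: torch.Tensor,    # [num_tokens, num_topk]
    block_size: int,
    prefill_request_id: torch.Tensor | None = None,  # -1 for decode tokens
    workspace_starts: torch.Tensor | None = None,    # per-prefill-request start
) -> torch.Tensor:
    out = torch.full_like(token_indices, -1)
    for t in range(token_indices.shape[0]):
        is_prefill = (
            prefill_request_id is not None and int(prefill_request_id[t]) >= 0
        )
        for j in range(token_indices.shape[1]):
            tok = int(token_indices[t, j])
            if tok < 0:
                continue  # invalid slots stay -1
            if is_prefill and workspace_starts is not None:
                # Prefill: offset into the per-chunk BF16 workspace
                out[t, j] = int(workspace_starts[int(prefill_request_id[t])]) + tok
            else:
                # Decode: logical index -> block table -> global cache slot
                block_id, off = divmod(tok, block_size)
                if block_id < block_table.shape[1]:
                    out[t, j] = (
                        int(block_table[int(req_id[t]), block_id]) * block_size + off
                    )
    return out
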
+namespace vllm { + +// Simplified kernel: convert per-request indices to global slots or workspace +// offsets +__global__ void convert_logical_index_to_physical_index_kernel( + const int32_t* __restrict__ req_id, // [num_tokens] + const int32_t* __restrict__ block_table, // [num_requests, + // max_num_blocks_per_req] + const int32_t* __restrict__ token_indices, // [num_tokens, NUM_TOPK_TOKENS] + int32_t* __restrict__ out, // [num_tokens, NUM_TOPK_TOKENS] + const int32_t* __restrict__ prefill_request_id, // [num_tokens], -1 for + // decode, >=0 for prefill + const int32_t* __restrict__ workspace_starts, // [num_prefill_reqs+1] or + // nullptr + int num_topk_tokens, int block_size, int max_num_blocks_per_req, + int bt_stride0, int bt_stride1, int ti_stride0, int ti_stride1, + int out_stride0, int out_stride1) { + const int token_id = blockIdx.x; + const int tid = threadIdx.x; + + // Load request id and prefill request id for this token + const int req = req_id[token_id]; + const int prefill_req_id = + prefill_request_id != nullptr ? prefill_request_id[token_id] : -1; + const bool is_prefill = prefill_req_id >= 0; + + // Loop over topk_indices + for (int indice_id = tid; indice_id < num_topk_tokens; + indice_id += blockDim.x) { + // Load token index (logical index within request) + const int ti_offset = token_id * ti_stride0 + indice_id * ti_stride1; + const int tok = token_indices[ti_offset]; + + // Check if token is invalid + bool is_invalid = tok < 0; + + int out_val = -1; + + if (is_prefill && workspace_starts != nullptr && !is_invalid) { + // Map to workspace offset: workspace_starts[prefill_req_id] + + // logical_token_id + out_val = workspace_starts[prefill_req_id] + tok; + } else if (!is_invalid) { + // Map to global cache slot (decode path) + // Compute block id and in-block offset + const int block_id = tok / block_size; + const int inblock_off = tok % block_size; + + // Guard block_table access + const bool valid_block = block_id < max_num_blocks_per_req; + int base = 0; + if (valid_block) { + const int bt_offset = req * bt_stride0 + block_id * bt_stride1; + base = block_table[bt_offset]; + } + + if (valid_block) { + out_val = base * block_size + inblock_off; + } + } + + // Store result + const int out_offset = token_id * out_stride0 + indice_id * out_stride1; + out[out_offset] = out_val; + } +} + +} // namespace vllm + +// Host function to launch the simplified index conversion kernel +torch::Tensor convert_logical_index_to_physical_index( + torch::Tensor req_id, // int32 [num_tokens] + torch::Tensor block_table, // int32 [num_requests, max_num_blocks_per_req] + torch::Tensor token_indices, // int32 [num_tokens, NUM_TOPK_TOKENS] + int64_t block_size, // KV cache block size + const std::optional& + prefill_request_id, // int32 [num_tokens], -1 for decode + const std::optional& + workspace_starts // int32 [num_prefill_reqs+1] +) { + constexpr int THREADS_PER_BLOCK = 256; + + // Validate input tensors + TORCH_CHECK(req_id.is_cuda(), "req_id must be a CUDA tensor"); + TORCH_CHECK(block_table.is_cuda(), "block_table must be a CUDA tensor"); + TORCH_CHECK(token_indices.is_cuda(), "token_indices must be a CUDA tensor"); + TORCH_CHECK(req_id.dtype() == torch::kInt32, "req_id must be int32"); + TORCH_CHECK(block_table.dtype() == torch::kInt32, + "block_table must be int32"); + TORCH_CHECK(token_indices.dtype() == torch::kInt32, + "token_indices must be int32"); + + // Ensure contiguous + req_id = req_id.contiguous(); + block_table = block_table.contiguous(); + token_indices = 
token_indices.contiguous(); + + // Extract dimensions + const int num_tokens = req_id.size(0); + const int num_topk_tokens = token_indices.size(1); + const int max_num_blocks_per_req = block_table.size(1); + + // Create output tensor + auto out = torch::empty_like(token_indices); + + // Extract strides + const int bt_stride0 = block_table.stride(0); + const int bt_stride1 = block_table.stride(1); + const int ti_stride0 = token_indices.stride(0); + const int ti_stride1 = token_indices.stride(1); + const int out_stride0 = out.stride(0); + const int out_stride1 = out.stride(1); + + // Handle optional prefill tensors + const int32_t* prefill_request_id_ptr = nullptr; + const int32_t* workspace_starts_ptr = nullptr; + + if (prefill_request_id.has_value()) { + auto& prid = prefill_request_id.value(); + TORCH_CHECK(prid.is_cuda(), "prefill_request_id must be a CUDA tensor"); + TORCH_CHECK(prid.is_contiguous(), "prefill_request_id must be contiguous"); + TORCH_CHECK(prid.dtype() == torch::kInt32, + "prefill_request_id must be int32"); + prefill_request_id_ptr = prid.data_ptr(); + } + + if (workspace_starts.has_value()) { + auto& ws = workspace_starts.value(); + TORCH_CHECK(ws.is_cuda(), "workspace_starts must be a CUDA tensor"); + TORCH_CHECK(ws.is_contiguous(), "workspace_starts must be contiguous"); + TORCH_CHECK(ws.dtype() == torch::kInt32, "workspace_starts must be int32"); + workspace_starts_ptr = ws.data_ptr(); + } + + // Get CUDA stream + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Launch kernel + dim3 grid(num_tokens); + dim3 block(THREADS_PER_BLOCK); + + vllm::convert_logical_index_to_physical_index_kernel<<>>( + req_id.data_ptr(), block_table.data_ptr(), + token_indices.data_ptr(), out.data_ptr(), + prefill_request_id_ptr, workspace_starts_ptr, num_topk_tokens, block_size, + max_num_blocks_per_req, bt_stride0, bt_stride1, ti_stride0, ti_stride1, + out_stride0, out_stride1); + + return out; +} + +// Macro to dispatch the kernel based on the data type. #define CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(BLOCK_Y_SIZE) \ vllm::cp_gather_indexer_k_quant_cache_kernel \ <<& prefill_request_id, + const std::optional& workspace_starts); + void convert_vertical_slash_indexes( torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS] torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S] diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 9c0f524dcab1..452b6ff8605d 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -89,6 +89,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor suffix_lse) -> ()"); ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states); + ops.def( + "convert_logical_index_to_physical_index(" + " Tensor req_id," + " Tensor block_table," + " Tensor token_indices," + " int block_size," + " Tensor? prefill_request_id," + " Tensor? workspace_starts) -> Tensor"); + ops.impl("convert_logical_index_to_physical_index", torch::kCUDA, + &convert_logical_index_to_physical_index); + ops.def( "convert_vertical_slash_indexes(" " Tensor! block_count, Tensor! block_offset, " @@ -726,6 +737,13 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { "Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()"); cache_ops.impl("cp_gather_cache", torch::kCUDA, &cp_gather_cache); + cache_ops.def( + "cp_gather_and_upconvert_fp8_kv_cache(Tensor src_cache, Tensor! 
dst, " + "Tensor block_table, Tensor seq_lens, Tensor workspace_starts, int " + "batch_size) -> ()"); + cache_ops.impl("cp_gather_and_upconvert_fp8_kv_cache", torch::kCUDA, + &cp_gather_and_upconvert_fp8_kv_cache); + cache_ops.def( "indexer_k_quant_and_cache(Tensor k, Tensor! kv_cache, Tensor " "slot_mapping, " diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 657b11046809..feb8c8b476ed 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -2149,6 +2149,29 @@ def cp_gather_cache( ) +def cp_gather_and_upconvert_fp8_kv_cache( + src_cache: torch.Tensor, + dst: torch.Tensor, + block_table: torch.Tensor, + seq_lens: torch.Tensor, + workspace_starts: torch.Tensor, + batch_size: int, +) -> None: + """Gather and upconvert FP8 KV cache to BF16 workspace. + + Args: + src_cache: FP8 KV cache [num_blocks, block_size, 656] + dst: BF16 output workspace [total_tokens, 576] + block_table: Block indices [num_reqs, max_blocks] + seq_lens: Sequence lengths [num_reqs] + workspace_starts: Workspace start offsets [num_reqs] + batch_size: Number of requests + """ + torch.ops._C_cache_ops.cp_gather_and_upconvert_fp8_kv_cache( + src_cache, dst, block_table, seq_lens, workspace_starts, batch_size + ) + + def indexer_k_quant_and_cache( k: torch.Tensor, kv_cache: torch.Tensor, @@ -2173,6 +2196,37 @@ def cp_gather_indexer_k_quant_cache( ) +def convert_logical_index_to_physical_index( + req_id: torch.Tensor, + block_table: torch.Tensor, + token_indices: torch.Tensor, + block_size: int, + prefill_request_id: torch.Tensor | None = None, + workspace_starts: torch.Tensor | None = None, +) -> torch.Tensor: + """Convert per-request logical indices to physical cache slots or workspace offsets. + + For decode tokens, maps to physical cache slots. + For prefill tokens, maps to workspace offsets. + + Args: + req_id: Request ID for each token + block_table: Block table mapping requests to cache blocks + token_indices: Per-request logical token indices to convert + block_size: Size of each cache block + prefill_request_id: Request ID for prefill tokens (-1 for decode) + workspace_starts: Cumulative sum of prefill sequence lengths + """ + return torch.ops._C.convert_logical_index_to_physical_index( + req_id, + block_table, + token_indices, + block_size, + prefill_request_id, + workspace_starts, + ) + + def get_device_attribute(attribute: int, device: int) -> int: return torch.ops._C_cuda_utils.get_device_attribute(attribute, device) diff --git a/vllm/envs.py b/vllm/envs.py index 81f189ada9a6..746fae1cce30 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -217,6 +217,7 @@ VLLM_NCCL_INCLUDE_PATH: str | None = None VLLM_USE_FBGEMM: bool = False VLLM_GC_DEBUG: str = "" + VLLM_DEBUG_WORKSPACE: bool = False VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary" @@ -1439,6 +1440,9 @@ def get_vllm_port() -> int | None: # - VLLM_GC_DEBUG='{"top_objects":5}': enable GC debugger with # top 5 collected objects "VLLM_GC_DEBUG": lambda: os.getenv("VLLM_GC_DEBUG", ""), + # Debug workspace allocations. + # logging of workspace resize operations. 
+ "VLLM_DEBUG_WORKSPACE": lambda: bool(int(os.getenv("VLLM_DEBUG_WORKSPACE", "0"))), # Disables parallel execution of shared_experts via separate cuda stream "VLLM_DISABLE_SHARED_EXPERTS_STREAM": lambda: os.getenv( "VLLM_DISABLE_SHARED_EXPERTS_STREAM", False diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 3b5916f8ccaf..dec7eeb2ae9f 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -10,6 +10,7 @@ import torch import vllm.envs as envs +from vllm.forward_context import get_forward_context from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.utils import ( _resize_cache, @@ -18,12 +19,12 @@ ) from vllm.utils.math_utils import cdiv from vllm.v1.worker.ubatching import ( - dbo_current_ubatch_id, dbo_enabled, dbo_maybe_run_recv_hook, dbo_register_recv_hook, dbo_yield, ) +from vllm.v1.worker.workspace import WorkspaceSpec, current_workspace_manager # # This file defines a set of base classes used to make MoE kernels more modular. @@ -639,25 +640,6 @@ def _slice_scales( return None -class SharedResizableBuffer: - def __init__(self): - self.buffer = None - - def get( - self, shape: tuple[int, ...], device: torch.device, dtype: torch.dtype - ) -> torch.Tensor: - assert shape != () - shape_numel = prod(shape) - if ( - self.buffer is None - or self.buffer.numel() < shape_numel - or self.buffer.device != device - or self.buffer.dtype != dtype - ): - self.buffer = torch.empty(shape_numel, device=device, dtype=dtype) - return self.buffer[:shape_numel].view(*shape) - - @final class FusedMoEModularKernel(torch.nn.Module): """ @@ -672,22 +654,6 @@ class FusedMoEModularKernel(torch.nn.Module): objects. """ - class SharedBuffers: - def __init__(self) -> None: - self.fused_out = SharedResizableBuffer() - self.workspace13 = SharedResizableBuffer() - self.workspace2 = SharedResizableBuffer() - - # Persistent buffers that are shared across `FusedMoEModularKernel` - # instances (layers), to save memory and allocattions. - # - # We have two sets of buffers to support dual batch overlap (DBO) where each - # microbatch (ubatch) should use its own set of buffers to avoid - # cross-ubatch contimination. - # NOTE that memory is lazily allocated for these buffers, meaning that if - # DBO isn't being used, the second SharedBuffers will be empty. - shared_buffers: list[SharedBuffers] = [SharedBuffers(), SharedBuffers()] - def __init__( self, prepare_finalize: FusedMoEPrepareAndFinalize, @@ -758,10 +724,6 @@ def _allocate_buffers( assert M_full > 0 and M_chunk > 0 num_chunks, _ = self._chunk_info(M_full) - - # select per-ubatch buffers to avoid cross-ubatch reuse under DBO - ubatch_idx = dbo_current_ubatch_id() - buffers = self.shared_buffers[ubatch_idx] workspace_dtype = self.fused_experts.workspace_dtype(out_dtype) # Get intermediate workspace shapes based off the chunked M size. @@ -786,13 +748,48 @@ def _allocate_buffers( expert_tokens_meta, ) + # For modular kernels that use "mk.FusedMoEModularKernel.Standard" format + # we may not see the worst case during profiling in the DP+EP case due to + # random token routing. Force allocating the worst case. 
+ is_profile_run = get_forward_context().attn_metadata is None + if is_profile_run and self.fused_experts.supports_chunking(): + ( + max_workspace13_shape, + max_workspace2_shape, + max_fused_out_shape, + ) = self.fused_experts.workspace_shapes( + envs.VLLM_FUSED_MOE_CHUNK_SIZE, + N, + K, + top_k, + global_num_experts, + local_num_experts, + None, # Pass None to avoid using sampled token counts + ) + max_workspace13_spec = WorkspaceSpec( + shape=max_workspace13_shape, + dtype=workspace_dtype, + name="moe.workspace13", + ) + max_workspace2_spec = WorkspaceSpec( + shape=max_workspace2_shape, + dtype=workspace_dtype, + name="moe.workspace2", + ) + max_fused_out_spec = WorkspaceSpec( + shape=max_fused_out_shape, dtype=out_dtype, name="moe.fused_out" + ) + current_workspace_manager().reserve_simultaneous( + max_workspace13_spec, max_workspace2_spec, max_fused_out_spec + ) + # We can reuse the memory between cache1 and cache3 because by the # time we need cache3, we're done with cache1. - workspace13 = buffers.workspace13.get( - workspace13_shape, device=device, dtype=workspace_dtype + workspace13_spec = WorkspaceSpec( + shape=workspace13_shape, dtype=workspace_dtype, name="moe.workspace13" ) - workspace2 = buffers.workspace2.get( - workspace2_shape, device=device, dtype=workspace_dtype + workspace2_spec = WorkspaceSpec( + shape=workspace2_shape, dtype=workspace_dtype, name="moe.workspace2" ) # Construct the entire output that can then be processed in chunks. @@ -800,10 +797,18 @@ def _allocate_buffers( # as it is large enough. This will not always be the case for standard # format experts and with experts that have empty workspaces. if num_chunks == 1 and prod(workspace13_shape) >= prod(fused_out_shape): + workspace13, workspace2 = current_workspace_manager().get_simultaneous( + workspace13_spec, workspace2_spec + ) fused_out = _resize_cache(workspace13, fused_out_shape) else: - fused_out = buffers.fused_out.get( - fused_out_shape, device=device, dtype=out_dtype + fused_out_spec = WorkspaceSpec( + shape=fused_out_shape, dtype=out_dtype, name="moe.fused_out" + ) + workspace13, workspace2, fused_out = ( + current_workspace_manager().get_simultaneous( + workspace13_spec, workspace2_spec, fused_out_spec + ) ) return workspace13, workspace2, fused_out diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index db7b86ffaf96..ce510f4e12e7 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -86,6 +86,7 @@ DeepseekV32IndexerMetadata, ) from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec +from vllm.v1.worker.workspace import WorkspaceSpec, current_workspace_manager from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP from .utils import ( @@ -520,6 +521,20 @@ def sparse_attn_indexer( attn_metadata = get_forward_context().attn_metadata # assert isinstance(attn_metadata, dict) if not isinstance(attn_metadata, dict): + # Reserve workspace memory for the actual run. 
+ k_fp8_spec = WorkspaceSpec( + shape=(total_seq_lens, head_dim), + dtype=torch.float8_e4m3fn, + name="sparse_attn_indexer.k_fp8", + ) + k_scale_spec = WorkspaceSpec( + shape=(total_seq_lens, 4), + dtype=torch.uint8, + name="sparse_attn_indexer.k_scale", + ) + + current_workspace_manager().reserve_simultaneous(k_fp8_spec, k_scale_spec) + return sparse_attn_indexer_fake( hidden_states, k_cache_prefix, @@ -553,17 +568,26 @@ def sparse_attn_indexer( topk_indices_buffer[: hidden_states.shape[0]] = -1 if has_prefill: prefill_metadata = attn_metadata.prefill + k_fp8_spec = WorkspaceSpec( + shape=(total_seq_lens, head_dim), + dtype=torch.float8_e4m3fn, + name="sparse_attn_indexer.k_fp8", + ) + k_scale_spec = WorkspaceSpec( + shape=(total_seq_lens, 4), + dtype=torch.uint8, + name="sparse_attn_indexer.k_scale", + ) + + # Get the full shared workspace buffers once (will allocate on first use) + workspace_manager = current_workspace_manager() + k_fp8_full, k_scale_full = workspace_manager.get_simultaneous( + k_fp8_spec, k_scale_spec + ) + for chunk in prefill_metadata.chunks: - k_fp8 = torch.empty( - [chunk.total_seq_lens, head_dim], - device=k.device, - dtype=torch.float8_e4m3fn, - ) - k_scale = torch.empty( - [chunk.total_seq_lens, 4], - device=k.device, - dtype=torch.uint8, - ) + k_fp8 = k_fp8_full[: chunk.total_seq_lens] + k_scale = k_scale_full[: chunk.total_seq_lens] ops.cp_gather_indexer_k_quant_cache( kv_cache, k_fp8, diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index bf8e4d5a6289..b306df09fff2 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -24,12 +24,18 @@ from vllm.triton_utils import tl, triton from vllm.utils.math_utils import cdiv from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl +from vllm.v1.attention.backends.mla.indexer import ( + get_max_prefill_buffer_size, + split_prefill_chunks, +) from vllm.v1.attention.backends.utils import ( AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, + split_decodes_and_prefills, ) from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.worker.workspace import WorkspaceSpec, current_workspace_manager if TYPE_CHECKING: from vllm.model_executor.models.deepseek_v2 import Indexer @@ -108,6 +114,49 @@ class FlashMLASparseMetadata: block_size: int = 64 topk_tokens: int = 2048 + num_prefill_reqs: int = 0 + num_decode_reqs: int = 0 + num_prefill_tokens: int = 0 + num_decode_tokens: int = 0 + + # Sequence lengths (context + query) for prefill requests + prefill_seq_lens: torch.Tensor | None = None + + # Request ID for each token: -1 for decode tokens, request index + # (0, 1, 2, ...) for prefill tokens. Shape: [num_actual_tokens] + prefill_request_id: torch.Tensor | None = None + + # Workspace start offsets for all prefill requests + # Shape: [num_prefill_reqs], adjusted in-place per chunk to be + # 0-indexed within each chunk. Used to map prefill tokens to workspace + # offsets in convert_logical_index_to_physical_index + prefill_workspace_starts: torch.Tensor | None = None + + @dataclass + class ChunkMetadata: + """Metadata for a chunk of prefill requests. + + Prefill requests may be chunked to fit within the fixed workspace size. 
+ """ + + seq_lens: ( + torch.Tensor + ) # [num_reqs_in_chunk] sequence lengths (for gather kernel) + tokens_slice: slice # Slice to extract query tokens for this chunk + block_table: ( + torch.Tensor + ) # [num_reqs_in_chunk, max_blocks] block table for chunk + req_start_idx: int # Starting request index in the original request list + workspace_starts: ( + torch.Tensor + ) # [num_reqs_in_chunk] workspace starts, adjusted to start at 0 for this chunk + request_slice: ( + slice # Slice to extract requests for this chunk from the full request list + ) + chunk_size: int # Total number of tokens in this chunk (sum of seq_lens) + + prefill_chunks: list[ChunkMetadata] | None = None + @dataclass class FP8KernelMetadata: scheduler_metadata: torch.Tensor | None @@ -118,126 +167,6 @@ class FP8KernelMetadata: fp8_extra_metadata: FP8KernelMetadata | None = None -@triton.jit -def _convert_req_index_to_global_index_kernel( - req_id_ptr, # int32 [num_tokens] - block_table_ptr, # int32 [num_requests, max_num_blocks_per_req] - token_indices_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS] - out_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS] - # shapes (compile-time where possible) - max_num_blocks_per_req: tl.constexpr, - BLOCK_SIZE: tl.constexpr, - BLOCK_N: tl.constexpr, # tile width along columns - # strides (in elements) - bt_stride0, - bt_stride1, - ti_stride0, - ti_stride1, - out_stride0, - out_stride1, -): - # program_id(0) -> token_id (row) - # program_id(1) -> tile index along columns - token_id = tl.program_id(0) - tile_id = tl.program_id(1) - - # Each program covers BLOCK_N consecutive columns - indice_id = tile_id * BLOCK_N + tl.arange(0, BLOCK_N) - - # Load request id for this token (no mask: grid is exact) - req = tl.load(req_id_ptr + token_id) - - # Load token indices for this tile - ti_ptr = token_indices_ptr + token_id * ti_stride0 + indice_id * ti_stride1 - tok = tl.load(ti_ptr) # int32 - - # Only token == -1 should propagate as -1 - is_invalid_tok = tok < 0 - - # Compute block id and in-block offset - block_id = tok // BLOCK_SIZE - inblock_off = tok % BLOCK_SIZE - - # Guard block_table access - valid_block = block_id < max_num_blocks_per_req - bt_ptr = block_table_ptr + req * bt_stride0 + block_id * bt_stride1 - base = tl.load(bt_ptr, mask=valid_block, other=0) - - # If token == -1 OR block_id OOB, output -1; else base * BLOCK_SIZE + offset - out_val = tl.where( - is_invalid_tok | (~valid_block), -1, base * BLOCK_SIZE + inblock_off - ) - - # Store results - out_ptr_ij = out_ptr + token_id * out_stride0 + indice_id * out_stride1 - tl.store(out_ptr_ij, out_val) - - -def triton_convert_req_index_to_global_index( - req_id: torch.Tensor, # int32 [num_tokens] - block_table: torch.Tensor, # int32 [num_requests, max_num_blocks_per_req] - token_indices: torch.Tensor, # int32 [num_tokens, NUM_TOPK_TOKENS] - BLOCK_SIZE: int = 64, - NUM_TOPK_TOKENS: int = 2048, - BLOCK_N: int = 128, # tile width along columns -): - """ - out[token_id, indice_id] = - block_table[req_id[token_id], - token_indices[token_id, indice_id] // BLOCK_SIZE] * BLOCK_SIZE - + token_indices[token_id, indice_id] % BLOCK_SIZE - - Only when token_indices[token_id, indice_id] == -1 do we output -1. - For safety, we also output -1 if the derived block_id would be - out-of-bounds. 
- """ - assert req_id.dtype == torch.int32 - assert block_table.dtype == torch.int32 - assert token_indices.dtype == torch.int32 - assert token_indices.shape[1] == NUM_TOPK_TOKENS - assert NUM_TOPK_TOKENS % BLOCK_N == 0, ( - f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible byBLOCK_N ({BLOCK_N})" - ) - - num_tokens = req_id.shape[0] - num_requests, max_num_blocks_per_req = block_table.shape - tiles_per_row = NUM_TOPK_TOKENS // BLOCK_N - - # Ensure contiguous tensors on the same device - req_id_c = req_id.contiguous() - block_table_c = block_table.contiguous() - token_indices_c = token_indices.contiguous() - out = torch.empty_like(token_indices_c) - - # Strides in elements - bt_stride0, bt_stride1 = block_table_c.stride() - ti_stride0, ti_stride1 = token_indices_c.stride() - out_stride0, out_stride1 = out.stride() - - # Exact 2D grid: tokens × column tiles - grid = (num_tokens, tiles_per_row) - - _convert_req_index_to_global_index_kernel[grid]( - req_id_c, - block_table_c, - token_indices_c, - out, - # shapes / constexprs - max_num_blocks_per_req, - BLOCK_SIZE, - BLOCK_N, - # strides - bt_stride0, - bt_stride1, - ti_stride0, - ti_stride1, - out_stride0, - out_stride1, - ) - return out - - -@dataclass class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetadata]): cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH @@ -247,18 +176,25 @@ def __init__( layer_names: list[str], vllm_config: VllmConfig, device: torch.device, - ): + ) -> None: + self.vllm_config = vllm_config + self.layer_names = layer_names cache_config = vllm_config.cache_config self.kv_cache_spec = kv_cache_spec self.model_config = vllm_config.model_config parallel_config = vllm_config.parallel_config self.device = device + # Treat requests with query length <= 1 as decodes to match the + # DeepGEMM indexer constraint (fp8_paged_mqa_logits only supports next_n <= 2) + self._init_reorder_batch_threshold(1, supports_spec_as_decode=True) + props = torch.cuda.get_device_properties(device) sm_count = props.multi_processor_count self.num_heads = self.model_config.get_num_attention_heads(parallel_config) self.mla_dims = get_mla_dims(self.model_config) + self.topk_tokens = vllm_config.model_config.hf_config.index_topk self.use_fp8_kv_cache = cache_config.cache_dtype == "fp8_ds_mla" self.topk_tokens_tensor = torch.tensor( @@ -319,6 +255,135 @@ def build( ) req_id_per_token = self.req_id_per_token_buffer[:num_tokens] + (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) = ( + split_decodes_and_prefills( + common_attn_metadata, decode_threshold=self.reorder_batch_threshold or 1 + ) + ) + + num_prefill_reqs = num_prefills + num_decode_reqs = num_decodes + prefill_token_count = num_prefill_tokens + decode_token_count = num_decode_tokens + + assert num_prefill_reqs + num_decode_reqs == common_attn_metadata.num_reqs + assert prefill_token_count + decode_token_count == num_tokens + + # Extract prefill sequence lengths (context + query, not just query) + # Decode requests come first in the batch, prefill requests follow + prefill_seq_lens = None + prefill_request_id = None + prefill_workspace_starts = None + prefill_chunks = None + + # For pure decode batches, prefill_request_id will be None + # For mixed batches, it will have -1 for decode and request_id for prefill + if num_prefill_reqs > 0: + # Get sequence lengths from common_attn_metadata + # seq_lens includes both context and query + seq_lens_cpu = common_attn_metadata.seq_lens_cpu + # Prefill requests are at the 
end (after decode requests) + prefill_seq_lens_cpu = seq_lens_cpu[num_decode_reqs:] + prefill_seq_lens = torch.tensor( + prefill_seq_lens_cpu, dtype=torch.int32, device=self.device + ) + + # Build prefill_request_id: -1 for decode, request index for + # prefill. This enables a single + # convert_logical_index_to_physical_index call for all tokens + prefill_request_id = torch.full( + (num_tokens,), -1, dtype=torch.int32, device=self.device + ) + # Map prefill tokens to their request IDs (0, 1, 2, ...) + for req_idx in range(num_prefill_reqs): + # Get query token range for this prefill request + global_req_idx = num_decode_reqs + req_idx + req_query_start = common_attn_metadata.query_start_loc[global_req_idx] + req_query_end = common_attn_metadata.query_start_loc[global_req_idx + 1] + prefill_request_id[req_query_start:req_query_end] = req_idx + + # Compute cumulative sequence lengths for workspace mapping + # Shape: [num_prefill_reqs] - cumsum gives END of each sequence + # We'll convert to STARTS below + prefill_workspace_starts = torch.cumsum( + prefill_seq_lens, dim=0, dtype=torch.int32 + ) + # Convert ends to starts by shifting: + # cumsum=[10,25,45,50] -> starts=[0,10,25,45] + prefill_workspace_starts = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device=self.device), + prefill_workspace_starts[:-1], + ] + ) + + # Chunk prefill requests to fit within workspace size + max_prefill_buffer_size = get_max_prefill_buffer_size(self.vllm_config) + chunk_bounds = split_prefill_chunks( + prefill_seq_lens_cpu, max_prefill_buffer_size, 0 + ) + + # Adjust workspace_starts in-place per chunk to be 0-indexed + # within each chunk + # Example: seq_lens=[10,15,20,5], chunks=[[0,2],[2,4]] + # Initial: workspace_starts=[0,10,25,45] + # After: workspace_starts=[0,10,0,20] + # (chunk 0 starts at 0, chunk 1 starts at 0) + for chunk_start, chunk_end in chunk_bounds: + offset = prefill_workspace_starts[chunk_start].item() + prefill_workspace_starts[chunk_start:chunk_end] -= offset + + # Create chunk metadata for each chunk + prefill_chunks = [] + for chunk_start, chunk_end in chunk_bounds: + # Get sequence lengths for this chunk + chunk_seq_lens = prefill_seq_lens[chunk_start:chunk_end] + + # Compute chunk size from CPU tensor to avoid GPU-to-CPU transfer later + chunk_size = int(prefill_seq_lens_cpu[chunk_start:chunk_end].sum()) + + # Create request slice (used for block_table and workspace_starts) + chunk_req_slice = slice( + num_decode_reqs + chunk_start, num_decode_reqs + chunk_end + ) + + # Determine token slice for this chunk's queries + # query_start_loc indices for prefill requests start after + # decode requests + token_start = common_attn_metadata.query_start_loc[ + chunk_req_slice.start + ].item() + token_end = common_attn_metadata.query_start_loc[ + chunk_req_slice.stop + ].item() + tokens_slice = slice(token_start, token_end) + + # Extract block table for this chunk + chunk_block_table = common_attn_metadata.block_table_tensor[ + chunk_req_slice + ] + + # Extract workspace_starts for this chunk + # (already adjusted to start at 0) + chunk_workspace_starts = prefill_workspace_starts[chunk_start:chunk_end] + + # Store request slice for mapping global req IDs to chunk-local IDs + request_slice = slice( + num_decode_reqs + chunk_start, num_decode_reqs + chunk_end + ) + + prefill_chunks.append( + FlashMLASparseMetadata.ChunkMetadata( + seq_lens=chunk_seq_lens, + tokens_slice=tokens_slice, + block_table=chunk_block_table, + req_start_idx=chunk_start, + workspace_starts=chunk_workspace_starts, + 
request_slice=request_slice, + chunk_size=chunk_size, + ) + ) + fp8_extra_metadata = None if self.use_fp8_kv_cache: tile_scheduler_metadata, num_splits = get_mla_metadata( @@ -361,6 +426,14 @@ def build( req_id_per_token=req_id_per_token, block_size=self.kv_cache_spec.block_size, topk_tokens=self.topk_tokens, + num_prefill_reqs=num_prefill_reqs, + num_decode_reqs=num_decode_reqs, + num_prefill_tokens=prefill_token_count, + num_decode_tokens=decode_token_count, + prefill_seq_lens=prefill_seq_lens, + prefill_request_id=prefill_request_id, + prefill_workspace_starts=prefill_workspace_starts, + prefill_chunks=prefill_chunks, fp8_extra_metadata=fp8_extra_metadata, ) return metadata @@ -402,6 +475,26 @@ def __init__( self.topk_indices_buffer = indexer.topk_indices_buffer self.padding = 128 if current_platform.is_device_capability(100) else 64 + # Reserve fixed workspace for prefill upconversion + # Workspace size: 5 * max_model_len tokens * head_dim (576) * + # sizeof(bfloat16). This allows us to gather and upconvert up to + # 5*max_model_len tokens from the KV cache for prefill attention. + # Prefill requests are chunked to fit within this workspace. + # Memory usage: 5 * max_model_len * 576 * 2 bytes + # Example: DeepSeek-V3.2 with max_model_len=163840 -> + # 5 * 163840 * 576 * 2 = ~900 MB + vllm_config = indexer.vllm_config + max_prefill_buffer_size = get_max_prefill_buffer_size(vllm_config) + + self.prefill_workspace_spec = WorkspaceSpec( + shape=(max_prefill_buffer_size, head_size), + dtype=torch.bfloat16, + name="FlashMLASparseImpl.prefill_workspace", + ) + + if kv_cache_dtype == "fp8_ds_mla": + current_workspace_manager().reserve(self.prefill_workspace_spec) + def _forward_bf16_kv( self, q: torch.Tensor, @@ -465,7 +558,7 @@ def forward( k_c_normed: torch.Tensor, # key in unified attn k_pe: torch.Tensor, # value in unified attn kv_cache: torch.Tensor, - attn_metadata: FlashMLASparseMetadata, + attn_metadata: FlashMLASparseMetadata | None, output: torch.Tensor | None = None, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, @@ -481,6 +574,7 @@ def forward( ) if attn_metadata is None: + # Dummy run - no need to allocate buffers # The zero fill is required when used with DP + EP # to ensure all ranks within a DP group compute the # same expert outputs. @@ -504,18 +598,11 @@ def forward( topk_indices = self.topk_indices_buffer[:num_actual_toks] - # TODO: handle index / kv_cache correctly - topk_indices_global = triton_convert_req_index_to_global_index( - attn_metadata.req_id_per_token, - attn_metadata.block_table, - topk_indices, - BLOCK_SIZE=attn_metadata.block_size, - NUM_TOPK_TOKENS=attn_metadata.topk_tokens, - ) + use_fp8_cache = self.kv_cache_dtype == "fp8_ds_mla" q = torch.cat([ql_nope, q_pe], dim=-1) - # write the latent and rope to kv cache + # CRITICAL: Write to KV cache FIRST before gathering/upconverting if kv_cache.numel() > 0: ops.concat_and_cache_mla( k_c_normed, @@ -526,14 +613,86 @@ def forward( scale=layer._k_scale, ) - if self.kv_cache_dtype != "fp8_ds_mla": + # Convert per-request indices to global slots (decode) or workspace + # offsets (prefill). Single call for all tokens! 
+ # prefill_workspace_starts has been adjusted in-place per chunk so + # prefill indices automatically come out chunk-local + num_decode_tokens = attn_metadata.num_decode_tokens + num_prefill_tokens = attn_metadata.num_prefill_tokens + + topk_indices_global = ops.convert_logical_index_to_physical_index( + attn_metadata.req_id_per_token, + attn_metadata.block_table, + topk_indices, + block_size=attn_metadata.block_size, + prefill_request_id=attn_metadata.prefill_request_id, + workspace_starts=attn_metadata.prefill_workspace_starts, + ) + + if not use_fp8_cache: attn_out = self._forward_bf16_kv( q, kv_cache, topk_indices_global, attn_metadata ) else: - attn_out = self._forward_fp8_kv( - q, kv_cache, topk_indices_global, attn_metadata + attn_out = q.new_empty( + (num_actual_toks, self.num_heads, self.kv_lora_rank), + dtype=q.dtype, + device=q.device, ) + # Process decode tokens + if num_decode_tokens > 0: + attn_out[:num_decode_tokens] = self._forward_fp8_kv( + q[:num_decode_tokens], + kv_cache, + topk_indices_global[:num_decode_tokens], + attn_metadata, + ) + + # Process prefill tokens in chunks + if num_prefill_tokens > 0: + if kv_cache.numel() == 0: + raise RuntimeError( + "Expected non-empty kv_cache for fp8_ds_mla prefill handling" + ) + + assert attn_metadata.prefill_chunks is not None + + # Get the reserved workspace (reused for all chunks) + prefill_bf16_workspace = current_workspace_manager().get( + self.prefill_workspace_spec + ) + + # Process each chunk + for chunk in attn_metadata.prefill_chunks: + # Gather and upconvert this chunk into workspace + chunk_workspace = prefill_bf16_workspace[: chunk.chunk_size] + + ops.cp_gather_and_upconvert_fp8_kv_cache( + kv_cache, + chunk_workspace, + chunk.block_table, + chunk.seq_lens, + chunk.workspace_starts, + len(chunk.block_table), + ) + + # Get query tokens and precomputed workspace indices for this chunk + chunk_q = q[chunk.tokens_slice] + chunk_topk_indices_workspace = topk_indices_global[ + chunk.tokens_slice + ] + + # Run attention for this chunk + chunk_attn_out = self._forward_bf16_kv( + chunk_q, + chunk_workspace, + chunk_topk_indices_workspace, + attn_metadata, + ) + + # Write results back to output tensor + attn_out[chunk.tokens_slice] = chunk_attn_out + self._v_up_proj(attn_out, out=output[:num_actual_toks]) return output diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index 49009a939d0b..da2209c202aa 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -176,9 +176,9 @@ def kv_spans_from_batches( def get_max_prefill_buffer_size(vllm_config: VllmConfig): max_model_len = vllm_config.model_config.max_model_len - # NOTE(Chen): 2 is a magic number for controlling the prefill buffer size. + # NOTE(Chen): 5 is a magic number for controlling the prefill buffer size. # May be tuned later. - return max_model_len * 2 + return max_model_len * 5 def split_prefill_chunks( diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 6e026215d402..22310697f122 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -819,7 +819,9 @@ def get_num_blocks( available_memory: Memory available for KV cache in bytes. page_size: The page size of the KV cache. 
""" + num_blocks = int(available_memory // page_size // num_layers) + num_blocks = max(num_blocks, 0) num_blocks = may_override_num_blocks(vllm_config, num_blocks) return num_blocks diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9212221bb600..6c04d1a28146 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -138,6 +138,7 @@ check_ubatch_thresholds, ) from vllm.v1.worker.utils import is_residual_scattered_for_sp +from vllm.v1.worker.workspace import lock_workspace from .utils import ( AttentionGroup, @@ -261,6 +262,7 @@ def __init__( self.device = device self.pin_memory = is_pin_memory_available() self.dtype = self.model_config.dtype + self.kv_cache_dtype = kv_cache_dtype_str_to_dtype( cache_config.cache_dtype, self.model_config ) @@ -3908,6 +3910,10 @@ def freeze_gc(): # after here. set_cudagraph_capturing_enabled(False) + # Lock workspace to prevent resizing during execution. + # Max workspace sizes should have been captured during warmup/profiling. + lock_workspace() + end_time = time.perf_counter() elapsed_time = end_time - start_time cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index c2bf1419bebd..62681cd09c1a 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -51,6 +51,7 @@ from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.utils import is_residual_scattered_for_sp from vllm.v1.worker.worker_base import WorkerBase +from vllm.v1.worker.workspace import init_workspace_manager logger = init_logger(__name__) @@ -245,6 +246,9 @@ def init_device(self): else: raise RuntimeError(f"Not support device type: {self.device_config.device}") + # Initialize workspace manager + init_workspace_manager(self.device, self.vllm_config) + # Construct the model runner self.model_runner: GPUModelRunner = GPUModelRunner( self.vllm_config, self.device @@ -352,6 +356,7 @@ def determine_available_memory(self) -> int: ) gc.collect() + # Workspaces are now allocated dynamically, no need to pre-reserve memory return int(self.available_kv_cache_memory_bytes) def get_kv_connector_handshake_metadata(self) -> dict | None: diff --git a/vllm/v1/worker/workspace.py b/vllm/v1/worker/workspace.py new file mode 100644 index 000000000000..e3cc3225fa60 --- /dev/null +++ b/vllm/v1/worker/workspace.py @@ -0,0 +1,318 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass +from itertools import accumulate +from math import prod +from typing import Optional + +import torch + +import vllm.envs as envs +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.utils import round_up +from vllm.v1.worker.ubatching import dbo_current_ubatch_id + +logger = init_logger(__name__) + + +@dataclass(frozen=True) +class WorkspaceSpec: + """Specification of a workspace to be allocated. + + Attributes: + shape: The shape of the workspace. + dtype: The data type of the workspace. + name: Optional name for debugging. + """ + + shape: tuple[int, ...] 
+ dtype: torch.dtype + name: str = "unnamed" + + def num_bytes(self) -> int: + return prod(self.shape) * self.dtype.itemsize + + +# Constants +_MB = 1024**2 +_GiB = 1024**3 + +# Global workspace manager instance +_manager: Optional["WorkspaceManager"] = None + + +def is_workspace_manager_initialized() -> bool: + """Check if workspace manager has been initialized. + + Returns: + True if workspace manager is initialized, False otherwise. + """ + return _manager is not None + + +def current_workspace_manager() -> "WorkspaceManager": + """Get the current workspace manager instance. + + Raises: + AssertionError: If workspace manager has not been initialized. + """ + assert _manager is not None, ( + "WorkspaceManager not initialized. Call init_workspace_manager() " + "with a device before using workspace functions." + ) + return _manager + + +class WorkspaceManager: + """Manager for workspace allocation. + + Manages workspace buffers for DBO (Dual Batch Overlap) execution. + Can be locked to prevent further growth during execution. + """ + + def __init__(self, device: torch.device, vllm_config): + self._device = device + self._vllm_config = vllm_config + # Cache num ubatches at init based on configuration + self._num_ubatches = 2 if vllm_config.parallel_config.enable_dbo else 1 + self._current_workspaces: list[torch.Tensor | None] = [None, None] + self._locked: bool = False + + @staticmethod + def _workspace_size_bytes(workspace: torch.Tensor | None) -> int: + """Get size of workspace in bytes.""" + if workspace is None: + return 0 + return workspace.numel() * workspace.element_size() + + def lock(self) -> None: + """Lock the workspace to prevent further growth. + + After locking, any attempt to allocate a larger workspace will raise + an assertion error. This ensures workspace size is fixed during execution. + """ + self._locked = True + if envs.VLLM_DEBUG_WORKSPACE: + logger.info( + "[WORKSPACE DEBUG] Workspace locked. Current sizes: %s", + [ + self._workspace_size_bytes(ws) / _MB + for ws in self._current_workspaces + if ws is not None + ], + ) + + def is_locked(self) -> bool: + """Check if workspace is locked.""" + return self._locked + + def current_allocated_size_bytes(self) -> int: + """Get the size of the current workspace in bytes.""" + return self._workspace_size_bytes( + self._current_workspaces[dbo_current_ubatch_id()] + ) + + def reserve(self, spec: "WorkspaceSpec") -> None: + """Reserve workspace memory for a given spec. + + Allocates the workspace immediately if needed. + + Args: + spec: The workspace specification. + """ + # Allocate if workspace needs resize + # Note: both ubatches always have the same size, so we only check the first + num_bytes = spec.num_bytes() + if self._workspace_size_bytes(self._current_workspaces[0]) < num_bytes: + self._increase_size(num_bytes, spec.name) + + def reserve_simultaneous(self, *specs: "WorkspaceSpec") -> None: + """Reserve workspace memory for multiple specs simultaneously. + + Allocates a single workspace large enough for all specs immediately if needed. + + Args: + *specs: One or more workspace specifications. 
+ """ + # Calculate total bytes needed for specs + spec_bytes = [spec.num_bytes() for spec in specs] + aligned_bytes = [round_up(byte_count, 256) for byte_count in spec_bytes] + total_bytes = sum(aligned_bytes) + + # Allocate if workspace needs resize + if self._workspace_size_bytes(self._current_workspaces[0]) < total_bytes: + workspace_names = ", ".join(spec.name for spec in specs) + self._increase_size(total_bytes, f"[{workspace_names}]") + + def get(self, spec: "WorkspaceSpec") -> torch.Tensor: + """Get a workspace tensor for the given spec. + + Args: + spec: The workspace specification. + + Returns: + A tensor view into the workspace buffer with the requested shape and dtype. + """ + shape, num_bytes = self._shape_and_bytes_for_spec(spec) + current_workspace = self._ensure_workspace_size(num_bytes, spec.name) + return current_workspace[:num_bytes].view(spec.dtype).reshape(shape) + + def _shape_and_bytes_for_spec( + self, spec: "WorkspaceSpec" + ) -> tuple[tuple[int, ...], int]: + """Return adjusted shape and actual size for a workspace spec.""" + num_bytes = spec.num_bytes() + shape = spec.shape + return shape, num_bytes + + def get_simultaneous(self, *specs: "WorkspaceSpec") -> list[torch.Tensor]: + """Get multiple workspace tensors simultaneously from a single allocation. + + Args: + *specs: One or more workspace specifications. + + Returns: + List of tensor views into the workspace buffer, one per spec. + """ + adjusted = [self._shape_and_bytes_for_spec(spec) for spec in specs] + adjusted_shapes = [shape for shape, _ in adjusted] + actual_bytes = [actual for _, actual in adjusted] + aligned_bytes = [round_up(actual, 256) for actual in actual_bytes] + total_bytes = sum(aligned_bytes) + + # Calculate cumulative offsets using itertools.accumulate + offsets = list(accumulate([0] + aligned_bytes[:-1])) + + workspace_names = ", ".join(spec.name for spec in specs) + current_workspace = self._ensure_workspace_size( + total_bytes, f"[{workspace_names}]" + ) + + return [ + current_workspace[offsets[i] : offsets[i] + actual_bytes[i]] + .view(specs[i].dtype) + .reshape(adjusted_shapes[i]) + for i in range(len(specs)) + ] + + def _ensure_workspace_size(self, num_bytes: int, name: str) -> torch.Tensor: + """Ensure workspace is allocated and large enough, return current workspace.""" + ubatch_id = dbo_current_ubatch_id() + current_workspace = self._current_workspaces[ubatch_id] + + # Manager owns a single device; no cross-device assertions needed + + if self._workspace_size_bytes(current_workspace) < num_bytes: + self._increase_size(num_bytes, name) + current_workspace = self._current_workspaces[ubatch_id] + + return current_workspace + + def _increase_size( + self, + required_bytes: int, + name: str = "unnamed", + ) -> None: + """Allocate or resize workspace for all ubatches. + + If DBO is enabled, allocates for both ubatches. Otherwise, allocates for + ubatch 0. Uses PyTorch's resize_() for efficient in-place resizing when + possible. + + Invariant: Both ubatches always have the same size after this function + completes. + + Args: + required_bytes: The number of bytes required. + name: Name for debugging/logging. 
+ """ + # Manager owns a single device; no cross-device assertions needed + + # Check if we need to grow the workspace + current_size = self._workspace_size_bytes(self._current_workspaces[0]) + if self._locked and current_size < required_bytes: + raise AssertionError( + f"Workspace is locked but allocation for '{name}' requires " + f"{required_bytes / _MB:.2f} MB, current size is " + f"{current_size / _MB:.2f} MB. " + "Workspace growth is not allowed after locking." + ) + + was_unallocated = self._current_workspaces[0] is None + + for ubatch_id in range(self._num_ubatches): + current_workspace = self._current_workspaces[ubatch_id] + + if current_workspace is None: + self._current_workspaces[ubatch_id] = torch.empty( + (required_bytes,), dtype=torch.uint8, device=self._device + ) + elif self._workspace_size_bytes(current_workspace) < required_bytes: + # Use resize_() for efficient in-place resizing + current_workspace.resize_(required_bytes) + + if envs.VLLM_DEBUG_WORKSPACE: + total_mb = required_bytes * self._num_ubatches / _MB + if was_unallocated: + logger.info( + "[WORKSPACE DEBUG] Allocated workspace '%s': %.2f MB " + "(%d ubatches, total memory %.2f MB)", + name, + required_bytes / _MB, + self._num_ubatches, + total_mb, + ) + else: + logger.info( + "[WORKSPACE DEBUG] Resized workspace '%s': %.2f MB -> %.2f " + "MB (%d ubatches, total memory %.2f MB)", + name, + current_size / _MB, + required_bytes / _MB, + self._num_ubatches, + total_mb, + ) + + +def init_workspace_manager(device: torch.device, vllm_config: VllmConfig) -> None: + """Initialize the workspace manager with a device. + + Must be called before using any workspace functions. Typically called + from GPUModelRunner.__init__. + + Args: + device: The device to allocate workspace on. + """ + global _manager + if _manager is not None: + logger.warning( + "WorkspaceManager already initialized on device %s, " + "reinitializing on device %s", + _manager._device, + device, + ) + _manager = WorkspaceManager(device, vllm_config) + + +def lock_workspace() -> None: + """Lock the workspace to prevent further growth. + + After calling this function, any attempt to allocate a workspace larger + than the current size will raise an AssertionError. This ensures that + workspace size is fixed during execution and prevents unexpected memory + allocations in the hot path. 
+ + Example: + # During initialization + init_workspace_manager(device) + reserve_workspace(spec1) + reserve_workspace(spec2) + + # Lock after warmup/profiling + lock_workspace() + + # Now all get_workspace calls must fit in pre-allocated size + """ + current_workspace_manager().lock() From ea4bb18a22b7dbde756c2677628ff86468bf29cd Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Sat, 25 Oct 2025 07:40:38 -0700 Subject: [PATCH 02/16] cleanup Signed-off-by: Lucas Wilkinson --- .../attention/backends/mla/flashmla_sparse.py | 186 ++++++++---------- 1 file changed, 81 insertions(+), 105 deletions(-) diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index b306df09fff2..7db96afe5b30 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -24,10 +24,6 @@ from vllm.triton_utils import tl, triton from vllm.utils.math_utils import cdiv from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl -from vllm.v1.attention.backends.mla.indexer import ( - get_max_prefill_buffer_size, - split_prefill_chunks, -) from vllm.v1.attention.backends.utils import ( AttentionCGSupport, AttentionMetadataBuilder, @@ -56,6 +52,45 @@ """ +def split_prefill_chunks( + seq_lens_cpu: torch.Tensor, workspace_size: int +) -> list[tuple[int, int]]: + """ + Split the prefill chunks into a list of tuples of (reqs_start, reqs_end) + such that the total sequence length of each chunk is less than the + maximum prefill buffer size. + + Args: + seq_lens_cpu: The sequence lengths of the prefill requests. + max_prefill_buffer_size: The maximum prefill buffer size. + + Returns: + A list of tuples of (reqs_start, reqs_end). + """ + chunk_bounds = [] + i, n = 0, len(seq_lens_cpu) + assert np.all(seq_lens_cpu <= workspace_size) + + while i < n: + start, total = i, 0 + while i < n and (total + (cur := seq_lens_cpu[i].item())) <= workspace_size: + total += cur + i += 1 + chunk_bounds.append((start, i)) + return chunk_bounds + + +def get_prefill_workspace_size(vllm_config: VllmConfig): + max_model_len = vllm_config.model_config.max_model_len + # NOTE(Lucas): 5 is a magic number for controlling the prefill buffer size. + # May be tuned later. + # Memory usage: 5 * max_model_len * 576 * 2 bytes + # Example: DeepSeek-V3.2 with max_model_len=163840 -> + # 5 * 163840 * 576 * 2 = ~900 MB + # This fits nicely below the typical MoE workspace size of >2GB so this is "free" + return max_model_len * 5 + + class FlashMLASparseBackend(AttentionBackend): accept_output_buffer: bool = True @@ -139,21 +174,12 @@ class ChunkMetadata: Prefill requests may be chunked to fit within the fixed workspace size. 
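A worked example of the greedy chunking policy implemented by split_prefill_chunks above, restated as a pure-Python sketch (greedy_chunks is an illustrative name, and it assumes every sequence fits in the workspace, as the assert above guarantees): with sequence lengths [10, 15, 20, 5] and a 30-token workspace, requests 0-1 form the first chunk (10 + 15 = 25) and requests 2-3 the second (20 + 5 = 25).

def greedy_chunks(seq_lens: list[int], workspace: int) -> list[tuple[int, int]]:
    bounds, i = [], 0
    while i < len(seq_lens):
        start, total = i, 0
        # Pack consecutive requests until the next one would overflow the workspace.
        while i < len(seq_lens) and total + seq_lens[i] <= workspace:
            total += seq_lens[i]
            i += 1
        bounds.append((start, i))
    return bounds

assert greedy_chunks([10, 15, 20, 5], 30) == [(0, 2), (2, 4)]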
""" - seq_lens: ( - torch.Tensor - ) # [num_reqs_in_chunk] sequence lengths (for gather kernel) - tokens_slice: slice # Slice to extract query tokens for this chunk - block_table: ( - torch.Tensor - ) # [num_reqs_in_chunk, max_blocks] block table for chunk - req_start_idx: int # Starting request index in the original request list - workspace_starts: ( - torch.Tensor - ) # [num_reqs_in_chunk] workspace starts, adjusted to start at 0 for this chunk - request_slice: ( - slice # Slice to extract requests for this chunk from the full request list - ) - chunk_size: int # Total number of tokens in this chunk (sum of seq_lens) + seq_lens: torch.Tensor + tokens_slice: slice + block_table: torch.Tensor + req_start_idx: int + workspace_starts: torch.Tensor + chunk_size: int prefill_chunks: list[ChunkMetadata] | None = None @@ -279,14 +305,12 @@ def build( # For pure decode batches, prefill_request_id will be None # For mixed batches, it will have -1 for decode and request_id for prefill if num_prefill_reqs > 0: - # Get sequence lengths from common_attn_metadata - # seq_lens includes both context and query seq_lens_cpu = common_attn_metadata.seq_lens_cpu - # Prefill requests are at the end (after decode requests) + seq_lens = common_attn_metadata.seq_lens + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + prefill_seq_lens_cpu = seq_lens_cpu[num_decode_reqs:] - prefill_seq_lens = torch.tensor( - prefill_seq_lens_cpu, dtype=torch.int32, device=self.device - ) + prefill_seq_lens = seq_lens[num_decode_reqs:] # Build prefill_request_id: -1 for decode, request index for # prefill. This enables a single @@ -302,76 +326,46 @@ def build( req_query_end = common_attn_metadata.query_start_loc[global_req_idx + 1] prefill_request_id[req_query_start:req_query_end] = req_idx - # Compute cumulative sequence lengths for workspace mapping - # Shape: [num_prefill_reqs] - cumsum gives END of each sequence - # We'll convert to STARTS below - prefill_workspace_starts = torch.cumsum( - prefill_seq_lens, dim=0, dtype=torch.int32 + # will be adjusted by chunk loop + prefill_workspace_starts_cpu = torch.zeros( + num_prefill_reqs, dtype=torch.int32, pin_memory=True ) - # Convert ends to starts by shifting: - # cumsum=[10,25,45,50] -> starts=[0,10,25,45] - prefill_workspace_starts = torch.cat( - [ - torch.zeros(1, dtype=torch.int32, device=self.device), - prefill_workspace_starts[:-1], - ] + torch.cumsum(prefill_seq_lens[1:], out=prefill_workspace_starts_cpu) + # populated by non-blocking copy after prefill_workspace_starts_cpu is + # updated by each chunk + prefill_workspace_starts = torch.empty( + num_prefill_reqs, dtype=torch.int32, device=self.device ) # Chunk prefill requests to fit within workspace size - max_prefill_buffer_size = get_max_prefill_buffer_size(self.vllm_config) + max_prefill_buffer_size = get_prefill_workspace_size(self.vllm_config) chunk_bounds = split_prefill_chunks( - prefill_seq_lens_cpu, max_prefill_buffer_size, 0 + prefill_seq_lens_cpu, max_prefill_buffer_size ) + prefill_chunks = [] - # Adjust workspace_starts in-place per chunk to be 0-indexed - # within each chunk - # Example: seq_lens=[10,15,20,5], chunks=[[0,2],[2,4]] - # Initial: workspace_starts=[0,10,25,45] - # After: workspace_starts=[0,10,0,20] - # (chunk 0 starts at 0, chunk 1 starts at 0) for chunk_start, chunk_end in chunk_bounds: - offset = prefill_workspace_starts[chunk_start].item() - prefill_workspace_starts[chunk_start:chunk_end] -= offset + # Adjust workspace_starts in-place per chunk to be + # 0-indexed within each 
chunk + # Example: seq_lens=[10,15,20,5], chunks=[[0,2],[2,4]] + # Initial: workspace_starts=[0,10,25,45] + # After: workspace_starts=[0,10,0,20] + # (chunk 0 starts at 0, chunk 1 starts at 0) + offset = prefill_workspace_starts_cpu[chunk_start].item() + prefill_workspace_starts_cpu[chunk_start:chunk_end] -= offset - # Create chunk metadata for each chunk - prefill_chunks = [] - for chunk_start, chunk_end in chunk_bounds: - # Get sequence lengths for this chunk chunk_seq_lens = prefill_seq_lens[chunk_start:chunk_end] - - # Compute chunk size from CPU tensor to avoid GPU-to-CPU transfer later chunk_size = int(prefill_seq_lens_cpu[chunk_start:chunk_end].sum()) - - # Create request slice (used for block_table and workspace_starts) - chunk_req_slice = slice( - num_decode_reqs + chunk_start, num_decode_reqs + chunk_end - ) - - # Determine token slice for this chunk's queries - # query_start_loc indices for prefill requests start after - # decode requests - token_start = common_attn_metadata.query_start_loc[ - chunk_req_slice.start - ].item() - token_end = common_attn_metadata.query_start_loc[ - chunk_req_slice.stop - ].item() + token_start = query_start_loc_cpu[chunk_start].item() + token_end = query_start_loc_cpu[chunk_end].item() tokens_slice = slice(token_start, token_end) - # Extract block table for this chunk + # Create chunk view of gpu tensor + chunk_workspace_starts = prefill_workspace_starts[chunk_start:chunk_end] chunk_block_table = common_attn_metadata.block_table_tensor[ - chunk_req_slice + num_decode_reqs + chunk_start : num_decode_reqs + chunk_end ] - # Extract workspace_starts for this chunk - # (already adjusted to start at 0) - chunk_workspace_starts = prefill_workspace_starts[chunk_start:chunk_end] - - # Store request slice for mapping global req IDs to chunk-local IDs - request_slice = slice( - num_decode_reqs + chunk_start, num_decode_reqs + chunk_end - ) - prefill_chunks.append( FlashMLASparseMetadata.ChunkMetadata( seq_lens=chunk_seq_lens, @@ -379,11 +373,14 @@ def build( block_table=chunk_block_table, req_start_idx=chunk_start, workspace_starts=chunk_workspace_starts, - request_slice=request_slice, chunk_size=chunk_size, ) ) + prefill_workspace_starts.copy_( + prefill_workspace_starts_cpu, non_blocking=True + ) + fp8_extra_metadata = None if self.use_fp8_kv_cache: tile_scheduler_metadata, num_splits = get_mla_metadata( @@ -476,18 +473,11 @@ def __init__( self.padding = 128 if current_platform.is_device_capability(100) else 64 # Reserve fixed workspace for prefill upconversion - # Workspace size: 5 * max_model_len tokens * head_dim (576) * - # sizeof(bfloat16). This allows us to gather and upconvert up to - # 5*max_model_len tokens from the KV cache for prefill attention. - # Prefill requests are chunked to fit within this workspace. 
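A small torch-only sketch of the per-chunk workspace_starts adjustment walked through above, using the same illustrative seq_lens=[10, 15, 20, 5] and chunk bounds [(0, 2), (2, 4)]:

import torch

seq_lens = torch.tensor([10, 15, 20, 5], dtype=torch.int32)
starts = torch.zeros_like(seq_lens)
starts[1:] = torch.cumsum(seq_lens[:-1], dim=0)    # [0, 10, 25, 45]
for chunk_start, chunk_end in [(0, 2), (2, 4)]:
    offset = int(starts[chunk_start])              # start of the chunk's first request
    starts[chunk_start:chunk_end] -= offset
assert starts.tolist() == [0, 10, 0, 20]           # each chunk is 0-indexed in the workspace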
- # Memory usage: 5 * max_model_len * 576 * 2 bytes - # Example: DeepSeek-V3.2 with max_model_len=163840 -> - # 5 * 163840 * 576 * 2 = ~900 MB vllm_config = indexer.vllm_config - max_prefill_buffer_size = get_max_prefill_buffer_size(vllm_config) + prefill_workspace_size = get_prefill_workspace_size(vllm_config) self.prefill_workspace_spec = WorkspaceSpec( - shape=(max_prefill_buffer_size, head_size), + shape=(prefill_workspace_size, head_size), dtype=torch.bfloat16, name="FlashMLASparseImpl.prefill_workspace", ) @@ -507,7 +497,7 @@ def _forward_bf16_kv( -1, 1, kv_c_and_k_pe_cache.shape[-1] ) - # NOTE(Chen): kernel requires num_local_head to be a multiple of + # NOTE(Lucas): kernel requires num_local_head to be a multiple of # 64 on hopper and 128 on blackwell if self.num_heads % self.padding != 0: assert self.padding % self.num_heads == 0 @@ -614,7 +604,7 @@ def forward( ) # Convert per-request indices to global slots (decode) or workspace - # offsets (prefill). Single call for all tokens! + # offsets (prefill). # prefill_workspace_starts has been adjusted in-place per chunk so # prefill indices automatically come out chunk-local num_decode_tokens = attn_metadata.num_decode_tokens @@ -649,25 +639,14 @@ def forward( attn_metadata, ) - # Process prefill tokens in chunks if num_prefill_tokens > 0: - if kv_cache.numel() == 0: - raise RuntimeError( - "Expected non-empty kv_cache for fp8_ds_mla prefill handling" - ) - assert attn_metadata.prefill_chunks is not None - - # Get the reserved workspace (reused for all chunks) prefill_bf16_workspace = current_workspace_manager().get( self.prefill_workspace_spec ) - # Process each chunk for chunk in attn_metadata.prefill_chunks: - # Gather and upconvert this chunk into workspace chunk_workspace = prefill_bf16_workspace[: chunk.chunk_size] - ops.cp_gather_and_upconvert_fp8_kv_cache( kv_cache, chunk_workspace, @@ -677,13 +656,11 @@ def forward( len(chunk.block_table), ) - # Get query tokens and precomputed workspace indices for this chunk chunk_q = q[chunk.tokens_slice] chunk_topk_indices_workspace = topk_indices_global[ chunk.tokens_slice ] - # Run attention for this chunk chunk_attn_out = self._forward_bf16_kv( chunk_q, chunk_workspace, @@ -691,7 +668,6 @@ def forward( attn_metadata, ) - # Write results back to output tensor attn_out[chunk.tokens_slice] = chunk_attn_out self._v_up_proj(attn_out, out=output[:num_actual_toks]) From b6b4451f1e257e5c3f34745329961070b2201a7c Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Sat, 25 Oct 2025 07:43:48 -0700 Subject: [PATCH 03/16] fix Signed-off-by: Lucas Wilkinson --- vllm/v1/attention/backends/mla/flashmla_sparse.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 7db96afe5b30..09613fd828cc 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -330,7 +330,9 @@ def build( prefill_workspace_starts_cpu = torch.zeros( num_prefill_reqs, dtype=torch.int32, pin_memory=True ) - torch.cumsum(prefill_seq_lens[1:], out=prefill_workspace_starts_cpu) + prefill_workspace_starts_cpu[1:] = torch.cumsum( + prefill_seq_lens_cpu[:-1], dim=0 + ) # populated by non-blocking copy after prefill_workspace_starts_cpu is # updated by each chunk prefill_workspace_starts = torch.empty( From a7bf346d2f56a748884fd2798c7bfe29314ccf76 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Sat, 25 Oct 2025 07:48:00 -0700 Subject: [PATCH 04/16] fix 
Signed-off-by: Lucas Wilkinson --- vllm/v1/attention/backends/mla/flashmla_sparse.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 09613fd828cc..42f21323c55c 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -69,7 +69,7 @@ def split_prefill_chunks( """ chunk_bounds = [] i, n = 0, len(seq_lens_cpu) - assert np.all(seq_lens_cpu <= workspace_size) + assert torch.all(seq_lens_cpu <= workspace_size).item() while i < n: start, total = i, 0 From fa7947bd4b22ed87ef84c6b26c08f880bc9968eb Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Sat, 25 Oct 2025 13:02:26 -0700 Subject: [PATCH 05/16] fix Signed-off-by: Lucas Wilkinson --- vllm/v1/attention/backends/mla/flashmla_sparse.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 42f21323c55c..a9b982a5c7fc 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -358,8 +358,8 @@ def build( chunk_seq_lens = prefill_seq_lens[chunk_start:chunk_end] chunk_size = int(prefill_seq_lens_cpu[chunk_start:chunk_end].sum()) - token_start = query_start_loc_cpu[chunk_start].item() - token_end = query_start_loc_cpu[chunk_end].item() + token_start = query_start_loc_cpu[num_decode_reqs + chunk_start].item() + token_end = query_start_loc_cpu[num_decode_reqs + chunk_end].item() tokens_slice = slice(token_start, token_end) # Create chunk view of gpu tensor From 90419b35478edd56292d7caac3aeb0bf1e736ed5 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Sun, 26 Oct 2025 06:12:47 -0700 Subject: [PATCH 06/16] clean-up revert to triton Signed-off-by: Lucas Wilkinson --- csrc/cache.h | 6 - csrc/cache_kernels.cu | 152 --------------- csrc/torch_bindings.cpp | 11 -- vllm/_custom_ops.py | 31 ---- .../attention/backends/mla/flashmla_sparse.py | 173 +++++++++++++++++- vllm/v1/worker/workspace.py | 64 ++++--- 6 files changed, 206 insertions(+), 231 deletions(-) diff --git a/csrc/cache.h b/csrc/cache.h index 49245cbca139..6f989f10aedd 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -82,9 +82,3 @@ void cp_gather_indexer_k_quant_cache( torch::Tensor& dst_scale, // [num_tokens, head_dim / quant_block_size * 4] const torch::Tensor& block_table, // [batch_size, num_blocks] const torch::Tensor& cu_seq_lens); // [batch_size + 1] - -torch::Tensor convert_logical_index_to_physical_index( - torch::Tensor req_id, torch::Tensor block_table, - torch::Tensor token_indices, int64_t block_size, - const std::optional& prefill_request_id, - const std::optional& workspace_starts); \ No newline at end of file diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index ed97f63ab851..29986447bc56 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -1370,158 +1370,6 @@ void indexer_k_quant_and_cache( CALL_INDEXER_K_QUANT_AND_CACHE); } -namespace vllm { - -// Simplified kernel: convert per-request indices to global slots or workspace -// offsets -__global__ void convert_logical_index_to_physical_index_kernel( - const int32_t* __restrict__ req_id, // [num_tokens] - const int32_t* __restrict__ block_table, // [num_requests, - // max_num_blocks_per_req] - const int32_t* __restrict__ token_indices, // [num_tokens, NUM_TOPK_TOKENS] - int32_t* __restrict__ out, // [num_tokens, NUM_TOPK_TOKENS] - const 
int32_t* __restrict__ prefill_request_id, // [num_tokens], -1 for - // decode, >=0 for prefill - const int32_t* __restrict__ workspace_starts, // [num_prefill_reqs+1] or - // nullptr - int num_topk_tokens, int block_size, int max_num_blocks_per_req, - int bt_stride0, int bt_stride1, int ti_stride0, int ti_stride1, - int out_stride0, int out_stride1) { - const int token_id = blockIdx.x; - const int tid = threadIdx.x; - - // Load request id and prefill request id for this token - const int req = req_id[token_id]; - const int prefill_req_id = - prefill_request_id != nullptr ? prefill_request_id[token_id] : -1; - const bool is_prefill = prefill_req_id >= 0; - - // Loop over topk_indices - for (int indice_id = tid; indice_id < num_topk_tokens; - indice_id += blockDim.x) { - // Load token index (logical index within request) - const int ti_offset = token_id * ti_stride0 + indice_id * ti_stride1; - const int tok = token_indices[ti_offset]; - - // Check if token is invalid - bool is_invalid = tok < 0; - - int out_val = -1; - - if (is_prefill && workspace_starts != nullptr && !is_invalid) { - // Map to workspace offset: workspace_starts[prefill_req_id] + - // logical_token_id - out_val = workspace_starts[prefill_req_id] + tok; - } else if (!is_invalid) { - // Map to global cache slot (decode path) - // Compute block id and in-block offset - const int block_id = tok / block_size; - const int inblock_off = tok % block_size; - - // Guard block_table access - const bool valid_block = block_id < max_num_blocks_per_req; - int base = 0; - if (valid_block) { - const int bt_offset = req * bt_stride0 + block_id * bt_stride1; - base = block_table[bt_offset]; - } - - if (valid_block) { - out_val = base * block_size + inblock_off; - } - } - - // Store result - const int out_offset = token_id * out_stride0 + indice_id * out_stride1; - out[out_offset] = out_val; - } -} - -} // namespace vllm - -// Host function to launch the simplified index conversion kernel -torch::Tensor convert_logical_index_to_physical_index( - torch::Tensor req_id, // int32 [num_tokens] - torch::Tensor block_table, // int32 [num_requests, max_num_blocks_per_req] - torch::Tensor token_indices, // int32 [num_tokens, NUM_TOPK_TOKENS] - int64_t block_size, // KV cache block size - const std::optional& - prefill_request_id, // int32 [num_tokens], -1 for decode - const std::optional& - workspace_starts // int32 [num_prefill_reqs+1] -) { - constexpr int THREADS_PER_BLOCK = 256; - - // Validate input tensors - TORCH_CHECK(req_id.is_cuda(), "req_id must be a CUDA tensor"); - TORCH_CHECK(block_table.is_cuda(), "block_table must be a CUDA tensor"); - TORCH_CHECK(token_indices.is_cuda(), "token_indices must be a CUDA tensor"); - TORCH_CHECK(req_id.dtype() == torch::kInt32, "req_id must be int32"); - TORCH_CHECK(block_table.dtype() == torch::kInt32, - "block_table must be int32"); - TORCH_CHECK(token_indices.dtype() == torch::kInt32, - "token_indices must be int32"); - - // Ensure contiguous - req_id = req_id.contiguous(); - block_table = block_table.contiguous(); - token_indices = token_indices.contiguous(); - - // Extract dimensions - const int num_tokens = req_id.size(0); - const int num_topk_tokens = token_indices.size(1); - const int max_num_blocks_per_req = block_table.size(1); - - // Create output tensor - auto out = torch::empty_like(token_indices); - - // Extract strides - const int bt_stride0 = block_table.stride(0); - const int bt_stride1 = block_table.stride(1); - const int ti_stride0 = token_indices.stride(0); - const int ti_stride1 = 
token_indices.stride(1); - const int out_stride0 = out.stride(0); - const int out_stride1 = out.stride(1); - - // Handle optional prefill tensors - const int32_t* prefill_request_id_ptr = nullptr; - const int32_t* workspace_starts_ptr = nullptr; - - if (prefill_request_id.has_value()) { - auto& prid = prefill_request_id.value(); - TORCH_CHECK(prid.is_cuda(), "prefill_request_id must be a CUDA tensor"); - TORCH_CHECK(prid.is_contiguous(), "prefill_request_id must be contiguous"); - TORCH_CHECK(prid.dtype() == torch::kInt32, - "prefill_request_id must be int32"); - prefill_request_id_ptr = prid.data_ptr(); - } - - if (workspace_starts.has_value()) { - auto& ws = workspace_starts.value(); - TORCH_CHECK(ws.is_cuda(), "workspace_starts must be a CUDA tensor"); - TORCH_CHECK(ws.is_contiguous(), "workspace_starts must be contiguous"); - TORCH_CHECK(ws.dtype() == torch::kInt32, "workspace_starts must be int32"); - workspace_starts_ptr = ws.data_ptr(); - } - - // Get CUDA stream - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - // Launch kernel - dim3 grid(num_tokens); - dim3 block(THREADS_PER_BLOCK); - - vllm::convert_logical_index_to_physical_index_kernel<<>>( - req_id.data_ptr(), block_table.data_ptr(), - token_indices.data_ptr(), out.data_ptr(), - prefill_request_id_ptr, workspace_starts_ptr, num_topk_tokens, block_size, - max_num_blocks_per_req, bt_stride0, bt_stride1, ti_stride0, ti_stride1, - out_stride0, out_stride1); - - return out; -} - // Macro to dispatch the kernel based on the data type. #define CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(BLOCK_Y_SIZE) \ vllm::cp_gather_indexer_k_quant_cache_kernel \ diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 452b6ff8605d..639a231119a2 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -89,17 +89,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor suffix_lse) -> ()"); ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states); - ops.def( - "convert_logical_index_to_physical_index(" - " Tensor req_id," - " Tensor block_table," - " Tensor token_indices," - " int block_size," - " Tensor? prefill_request_id," - " Tensor? workspace_starts) -> Tensor"); - ops.impl("convert_logical_index_to_physical_index", torch::kCUDA, - &convert_logical_index_to_physical_index); - ops.def( "convert_vertical_slash_indexes(" " Tensor! block_count, Tensor! block_offset, " diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index feb8c8b476ed..a13c5fd46434 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -2196,37 +2196,6 @@ def cp_gather_indexer_k_quant_cache( ) -def convert_logical_index_to_physical_index( - req_id: torch.Tensor, - block_table: torch.Tensor, - token_indices: torch.Tensor, - block_size: int, - prefill_request_id: torch.Tensor | None = None, - workspace_starts: torch.Tensor | None = None, -) -> torch.Tensor: - """Convert per-request logical indices to physical cache slots or workspace offsets. - - For decode tokens, maps to physical cache slots. - For prefill tokens, maps to workspace offsets. 
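A tiny numeric illustration of the prefill branch described above (values are illustrative), which is removed here and reimplemented in Triton just below: a prefill token's logical index is simply shifted by its request's chunk-local start offset in the gathered workspace, while decode tokens go through the block table instead.

workspace_starts = [0, 10, 0, 20]   # chunk-local start of each prefill request
prefill_req_id, tok = 1, 7          # 8th selected token of prefill request 1
assert workspace_starts[prefill_req_id] + tok == 17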
- - Args: - req_id: Request ID for each token - block_table: Block table mapping requests to cache blocks - token_indices: Per-request logical token indices to convert - block_size: Size of each cache block - prefill_request_id: Request ID for prefill tokens (-1 for decode) - workspace_starts: Cumulative sum of prefill sequence lengths - """ - return torch.ops._C.convert_logical_index_to_physical_index( - req_id, - block_table, - token_indices, - block_size, - prefill_request_id, - workspace_starts, - ) - - def get_device_attribute(attribute: int, device: int) -> int: return torch.ops._C_cuda_utils.get_device_attribute(attribute, device) diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index a9b982a5c7fc..2c807795d13c 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -52,6 +52,174 @@ """ +@triton.jit +def _convert_req_index_to_global_index_kernel( + req_id_ptr, # int32 [num_tokens] + block_table_ptr, # int32 [num_requests, max_num_blocks_per_req] + token_indices_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS] + out_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS] + prefill_request_id_ptr, # int32 [num_tokens], -1 for decode, >=0 for prefill + workspace_starts_ptr, # int32 [num_prefill_reqs+1] or nullptr + # shapes (compile-time where possible) + max_num_blocks_per_req: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + BLOCK_N: tl.constexpr, # tile width along columns + HAS_PREFILL: tl.constexpr, + # strides (in elements) + bt_stride0, + bt_stride1, + ti_stride0, + ti_stride1, + out_stride0, + out_stride1, +): + # program_id(0) -> token_id (row) + # program_id(1) -> tile index along columns + token_id = tl.program_id(0) + tile_id = tl.program_id(1) + + # Each program covers BLOCK_N consecutive columns + indice_id = tile_id * BLOCK_N + tl.arange(0, BLOCK_N) + + # Load request id for this token (no mask: grid is exact) + req = tl.load(req_id_ptr + token_id) + + # Load prefill request id if prefill support is enabled + if HAS_PREFILL: + prefill_req_id = tl.load(prefill_request_id_ptr + token_id) + is_prefill = prefill_req_id >= 0 + + # Load token indices for this tile + ti_ptr = token_indices_ptr + token_id * ti_stride0 + indice_id * ti_stride1 + tok = tl.load(ti_ptr) # int32 + + # Only token == -1 should propagate as -1 + is_invalid_tok = tok < 0 + + # Prefill path: map to workspace offset + if HAS_PREFILL: + workspace_start = tl.load( + workspace_starts_ptr + prefill_req_id, mask=is_prefill, other=0 + ) + prefill_out = workspace_start + tok + + # Compute block id and in-block offset + block_id = tok // BLOCK_SIZE + inblock_off = tok % BLOCK_SIZE + + # Guard block_table access + valid_block = block_id < max_num_blocks_per_req + bt_ptr = block_table_ptr + req * bt_stride0 + block_id * bt_stride1 + base = tl.load(bt_ptr, mask=valid_block, other=0) + + # If token == -1 OR block_id OOB, output -1; else base * BLOCK_SIZE + offset + if HAS_PREFILL: + decode_out = tl.where(valid_block, base * BLOCK_SIZE + inblock_off, -1) + out_val = tl.where( + is_invalid_tok, -1, tl.where(is_prefill, prefill_out, decode_out) + ) + else: + out_val = tl.where( + is_invalid_tok | (~valid_block), -1, base * BLOCK_SIZE + inblock_off + ) + + # Store results + out_ptr_ij = out_ptr + token_id * out_stride0 + indice_id * out_stride1 + tl.store(out_ptr_ij, out_val) + + +def triton_convert_req_index_to_global_index( + req_id: torch.Tensor, # int32 [num_tokens] + block_table: torch.Tensor, # int32 
[num_requests, max_num_blocks_per_req] + token_indices: torch.Tensor, # int32 [num_tokens, NUM_TOPK_TOKENS] + BLOCK_SIZE: int = 64, + NUM_TOPK_TOKENS: int = 2048, + BLOCK_N: int = 128, # tile width along columns + prefill_request_id: torch.Tensor | None = None, + workspace_starts: torch.Tensor | None = None, +): + """ + out[token_id, indice_id] = + block_table[req_id[token_id], + token_indices[token_id, indice_id] // BLOCK_SIZE] * BLOCK_SIZE + + token_indices[token_id, indice_id] % BLOCK_SIZE + + Only when token_indices[token_id, indice_id] == -1 do we output -1. + For safety, we also output -1 if the derived block_id would be + out-of-bounds. + + When prefill_request_id and workspace_starts are provided, prefill tokens + are mapped to workspace offsets instead of global cache slots. + """ + assert req_id.dtype == torch.int32 + assert block_table.dtype == torch.int32 + assert token_indices.dtype == torch.int32 + assert token_indices.shape[1] == NUM_TOPK_TOKENS + assert NUM_TOPK_TOKENS % BLOCK_N == 0, ( + f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible by BLOCK_N ({BLOCK_N})" + ) + + has_prefill = prefill_request_id is not None and workspace_starts is not None + if has_prefill: + assert prefill_request_id is not None + assert workspace_starts is not None + assert prefill_request_id.dtype == torch.int32 + assert workspace_starts.dtype == torch.int32 + + num_tokens = req_id.shape[0] + num_requests, max_num_blocks_per_req = block_table.shape + tiles_per_row = NUM_TOPK_TOKENS // BLOCK_N + + # Ensure contiguous tensors on the same device + req_id_c = req_id.contiguous() + block_table_c = block_table.contiguous() + token_indices_c = token_indices.contiguous() + out = torch.empty_like(token_indices_c) + + # Strides in elements + bt_stride0, bt_stride1 = block_table_c.stride() + ti_stride0, ti_stride1 = token_indices_c.stride() + out_stride0, out_stride1 = out.stride() + + # Prepare prefill pointers + if has_prefill: + assert prefill_request_id is not None # for mypy + assert workspace_starts is not None # for mypy + prefill_request_id_c = prefill_request_id.contiguous() + workspace_starts_c = workspace_starts.contiguous() + prefill_request_id_ptr = prefill_request_id_c + workspace_starts_ptr = workspace_starts_c + else: + # Dummy pointers (won't be accessed when HAS_PREFILL=False) + prefill_request_id_ptr = req_id_c + workspace_starts_ptr = req_id_c + + # Exact 2D grid: tokens × column tiles + grid = (num_tokens, tiles_per_row) + + _convert_req_index_to_global_index_kernel[grid]( + req_id_c, + block_table_c, + token_indices_c, + out, + prefill_request_id_ptr, + workspace_starts_ptr, + # shapes / constexprs + max_num_blocks_per_req, + BLOCK_SIZE, + BLOCK_N, + has_prefill, + # strides + bt_stride0, + bt_stride1, + ti_stride0, + ti_stride1, + out_stride0, + out_stride1, + ) + return out + + def split_prefill_chunks( seq_lens_cpu: torch.Tensor, workspace_size: int ) -> list[tuple[int, int]]: @@ -612,11 +780,12 @@ def forward( num_decode_tokens = attn_metadata.num_decode_tokens num_prefill_tokens = attn_metadata.num_prefill_tokens - topk_indices_global = ops.convert_logical_index_to_physical_index( + topk_indices_global = triton_convert_req_index_to_global_index( attn_metadata.req_id_per_token, attn_metadata.block_table, topk_indices, - block_size=attn_metadata.block_size, + BLOCK_SIZE=attn_metadata.block_size, + NUM_TOPK_TOKENS=topk_indices.shape[1], prefill_request_id=attn_metadata.prefill_request_id, workspace_starts=attn_metadata.prefill_workspace_starts, ) diff --git 
a/vllm/v1/worker/workspace.py b/vllm/v1/worker/workspace.py index e3cc3225fa60..0157e30d32ab 100644 --- a/vllm/v1/worker/workspace.py +++ b/vllm/v1/worker/workspace.py @@ -117,34 +117,50 @@ def current_allocated_size_bytes(self) -> int: def reserve(self, spec: "WorkspaceSpec") -> None: """Reserve workspace memory for a given spec. - Allocates the workspace immediately if needed. + This is a convenience wrapper around get() that makes it easier to grep + for workspace reservations in the codebase for auditing purposes. Args: spec: The workspace specification. """ - # Allocate if workspace needs resize - # Note: both ubatches always have the same size, so we only check the first - num_bytes = spec.num_bytes() - if self._workspace_size_bytes(self._current_workspaces[0]) < num_bytes: - self._increase_size(num_bytes, spec.name) + # TODO(Lucas): Assert that only reserves (ds/reserve_simultaneous) can + # increase the workspace size, so that reserve must be called before `get`. + # This will encourage the use of reserve which is mostly just useful for + # grepping/auditing the codebase. + + # Note: We don't assert !locked here because reserve can be called + # during forward passes. The actual locking logic is in _increase_size. + # Call get() to perform the actual allocation + self.get(spec) + + def ds(self, spec: "WorkspaceSpec") -> None: + """Alias for reserve() for backwards compatibility. + + Reserve workspace memory for a given spec. + + Args: + spec: The workspace specification. + """ + self.reserve(spec) def reserve_simultaneous(self, *specs: "WorkspaceSpec") -> None: """Reserve workspace memory for multiple specs simultaneously. - Allocates a single workspace large enough for all specs immediately if needed. + This is a convenience wrapper around get_simultaneous() that makes it easier + to grep for workspace reservations in the codebase for auditing purposes. Args: *specs: One or more workspace specifications. """ - # Calculate total bytes needed for specs - spec_bytes = [spec.num_bytes() for spec in specs] - aligned_bytes = [round_up(byte_count, 256) for byte_count in spec_bytes] - total_bytes = sum(aligned_bytes) + # TODO(Lucas): Assert that only reserves (ds/reserve_simultaneous) can + # increase the workspace size, so that reserve must be called before `get`. + # This will encourage the use of reserve which is mostly just useful for + # grepping/auditing the codebase. - # Allocate if workspace needs resize - if self._workspace_size_bytes(self._current_workspaces[0]) < total_bytes: - workspace_names = ", ".join(spec.name for spec in specs) - self._increase_size(total_bytes, f"[{workspace_names}]") + # Note: We don't assert !locked here because reserve can be called + # during forward passes. The actual locking logic is in _increase_size. + # Call get_simultaneous() to perform the actual allocation + self.get_simultaneous(*specs) def get(self, spec: "WorkspaceSpec") -> torch.Tensor: """Get a workspace tensor for the given spec. @@ -155,17 +171,9 @@ def get(self, spec: "WorkspaceSpec") -> torch.Tensor: Returns: A tensor view into the workspace buffer with the requested shape and dtype. 
""" - shape, num_bytes = self._shape_and_bytes_for_spec(spec) - current_workspace = self._ensure_workspace_size(num_bytes, spec.name) - return current_workspace[:num_bytes].view(spec.dtype).reshape(shape) - - def _shape_and_bytes_for_spec( - self, spec: "WorkspaceSpec" - ) -> tuple[tuple[int, ...], int]: - """Return adjusted shape and actual size for a workspace spec.""" num_bytes = spec.num_bytes() - shape = spec.shape - return shape, num_bytes + current_workspace = self._ensure_workspace_size(num_bytes, spec.name) + return current_workspace[:num_bytes].view(spec.dtype).reshape(spec.shape) def get_simultaneous(self, *specs: "WorkspaceSpec") -> list[torch.Tensor]: """Get multiple workspace tensors simultaneously from a single allocation. @@ -176,9 +184,7 @@ def get_simultaneous(self, *specs: "WorkspaceSpec") -> list[torch.Tensor]: Returns: List of tensor views into the workspace buffer, one per spec. """ - adjusted = [self._shape_and_bytes_for_spec(spec) for spec in specs] - adjusted_shapes = [shape for shape, _ in adjusted] - actual_bytes = [actual for _, actual in adjusted] + actual_bytes = [spec.num_bytes() for spec in specs] aligned_bytes = [round_up(actual, 256) for actual in actual_bytes] total_bytes = sum(aligned_bytes) @@ -193,7 +199,7 @@ def get_simultaneous(self, *specs: "WorkspaceSpec") -> list[torch.Tensor]: return [ current_workspace[offsets[i] : offsets[i] + actual_bytes[i]] .view(specs[i].dtype) - .reshape(adjusted_shapes[i]) + .reshape(specs[i].shape) for i in range(len(specs)) ] From 55d9b843122ff6e01430dcb93fe0eee4250e7678 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Sun, 26 Oct 2025 14:33:01 -0700 Subject: [PATCH 07/16] cleanup Signed-off-by: Lucas Wilkinson --- csrc/cache_kernels.cu | 4 +- vllm/model_executor/models/deepseek_v2.py | 34 +- .../attention/backends/mla/flashmla_sparse.py | 445 +++++++++--------- vllm/v1/worker/workspace.py | 89 ++-- 4 files changed, 265 insertions(+), 307 deletions(-) diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 29986447bc56..f34f5bbc86c3 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -1126,8 +1126,6 @@ __global__ void cp_gather_and_upconvert_fp8_kv_cache( const int rope_idx = tid - 512; dst_ptr[512 + rope_idx] = rope_ptr[rope_idx]; } - // Threads 576+ are idle - // No sync needed - each iteration processes independent tokens // Move to next token offset += 1; @@ -1320,7 +1318,7 @@ void cp_gather_and_upconvert_fp8_kv_cache( // Decide on the number of splits based on the batch size int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16; dim3 grid(batch_size, num_splits); - dim3 block(1024); + dim3 block(576); vllm::cp_gather_and_upconvert_fp8_kv_cache<<>>( src_cache.data_ptr(), diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index ce510f4e12e7..8ee1a8a831b7 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -519,20 +519,20 @@ def sparse_attn_indexer( ) -> torch.Tensor: # careful! this will be None in dummy run attn_metadata = get_forward_context().attn_metadata + + k_fp8_spec = WorkspaceSpec( + shape=(total_seq_lens, head_dim), + dtype=torch.float8_e4m3fn, + name="sparse_attn_indexer.k_fp8", + ) + k_scale_spec = WorkspaceSpec( + shape=(total_seq_lens, 4), + dtype=torch.uint8, + name="sparse_attn_indexer.k_scale", + ) + # assert isinstance(attn_metadata, dict) if not isinstance(attn_metadata, dict): - # Reserve workspace memory for the actual run. 
- k_fp8_spec = WorkspaceSpec( - shape=(total_seq_lens, head_dim), - dtype=torch.float8_e4m3fn, - name="sparse_attn_indexer.k_fp8", - ) - k_scale_spec = WorkspaceSpec( - shape=(total_seq_lens, 4), - dtype=torch.uint8, - name="sparse_attn_indexer.k_scale", - ) - current_workspace_manager().reserve_simultaneous(k_fp8_spec, k_scale_spec) return sparse_attn_indexer_fake( @@ -568,16 +568,6 @@ def sparse_attn_indexer( topk_indices_buffer[: hidden_states.shape[0]] = -1 if has_prefill: prefill_metadata = attn_metadata.prefill - k_fp8_spec = WorkspaceSpec( - shape=(total_seq_lens, head_dim), - dtype=torch.float8_e4m3fn, - name="sparse_attn_indexer.k_fp8", - ) - k_scale_spec = WorkspaceSpec( - shape=(total_seq_lens, 4), - dtype=torch.uint8, - name="sparse_attn_indexer.k_scale", - ) # Get the full shared workspace buffers once (will allocate on first use) workspace_manager = current_workspace_manager() diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 2c807795d13c..3c2e12832994 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -52,174 +52,6 @@ """ -@triton.jit -def _convert_req_index_to_global_index_kernel( - req_id_ptr, # int32 [num_tokens] - block_table_ptr, # int32 [num_requests, max_num_blocks_per_req] - token_indices_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS] - out_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS] - prefill_request_id_ptr, # int32 [num_tokens], -1 for decode, >=0 for prefill - workspace_starts_ptr, # int32 [num_prefill_reqs+1] or nullptr - # shapes (compile-time where possible) - max_num_blocks_per_req: tl.constexpr, - BLOCK_SIZE: tl.constexpr, - BLOCK_N: tl.constexpr, # tile width along columns - HAS_PREFILL: tl.constexpr, - # strides (in elements) - bt_stride0, - bt_stride1, - ti_stride0, - ti_stride1, - out_stride0, - out_stride1, -): - # program_id(0) -> token_id (row) - # program_id(1) -> tile index along columns - token_id = tl.program_id(0) - tile_id = tl.program_id(1) - - # Each program covers BLOCK_N consecutive columns - indice_id = tile_id * BLOCK_N + tl.arange(0, BLOCK_N) - - # Load request id for this token (no mask: grid is exact) - req = tl.load(req_id_ptr + token_id) - - # Load prefill request id if prefill support is enabled - if HAS_PREFILL: - prefill_req_id = tl.load(prefill_request_id_ptr + token_id) - is_prefill = prefill_req_id >= 0 - - # Load token indices for this tile - ti_ptr = token_indices_ptr + token_id * ti_stride0 + indice_id * ti_stride1 - tok = tl.load(ti_ptr) # int32 - - # Only token == -1 should propagate as -1 - is_invalid_tok = tok < 0 - - # Prefill path: map to workspace offset - if HAS_PREFILL: - workspace_start = tl.load( - workspace_starts_ptr + prefill_req_id, mask=is_prefill, other=0 - ) - prefill_out = workspace_start + tok - - # Compute block id and in-block offset - block_id = tok // BLOCK_SIZE - inblock_off = tok % BLOCK_SIZE - - # Guard block_table access - valid_block = block_id < max_num_blocks_per_req - bt_ptr = block_table_ptr + req * bt_stride0 + block_id * bt_stride1 - base = tl.load(bt_ptr, mask=valid_block, other=0) - - # If token == -1 OR block_id OOB, output -1; else base * BLOCK_SIZE + offset - if HAS_PREFILL: - decode_out = tl.where(valid_block, base * BLOCK_SIZE + inblock_off, -1) - out_val = tl.where( - is_invalid_tok, -1, tl.where(is_prefill, prefill_out, decode_out) - ) - else: - out_val = tl.where( - is_invalid_tok | (~valid_block), -1, base * BLOCK_SIZE + inblock_off - ) - - # 
Store results - out_ptr_ij = out_ptr + token_id * out_stride0 + indice_id * out_stride1 - tl.store(out_ptr_ij, out_val) - - -def triton_convert_req_index_to_global_index( - req_id: torch.Tensor, # int32 [num_tokens] - block_table: torch.Tensor, # int32 [num_requests, max_num_blocks_per_req] - token_indices: torch.Tensor, # int32 [num_tokens, NUM_TOPK_TOKENS] - BLOCK_SIZE: int = 64, - NUM_TOPK_TOKENS: int = 2048, - BLOCK_N: int = 128, # tile width along columns - prefill_request_id: torch.Tensor | None = None, - workspace_starts: torch.Tensor | None = None, -): - """ - out[token_id, indice_id] = - block_table[req_id[token_id], - token_indices[token_id, indice_id] // BLOCK_SIZE] * BLOCK_SIZE - + token_indices[token_id, indice_id] % BLOCK_SIZE - - Only when token_indices[token_id, indice_id] == -1 do we output -1. - For safety, we also output -1 if the derived block_id would be - out-of-bounds. - - When prefill_request_id and workspace_starts are provided, prefill tokens - are mapped to workspace offsets instead of global cache slots. - """ - assert req_id.dtype == torch.int32 - assert block_table.dtype == torch.int32 - assert token_indices.dtype == torch.int32 - assert token_indices.shape[1] == NUM_TOPK_TOKENS - assert NUM_TOPK_TOKENS % BLOCK_N == 0, ( - f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible by BLOCK_N ({BLOCK_N})" - ) - - has_prefill = prefill_request_id is not None and workspace_starts is not None - if has_prefill: - assert prefill_request_id is not None - assert workspace_starts is not None - assert prefill_request_id.dtype == torch.int32 - assert workspace_starts.dtype == torch.int32 - - num_tokens = req_id.shape[0] - num_requests, max_num_blocks_per_req = block_table.shape - tiles_per_row = NUM_TOPK_TOKENS // BLOCK_N - - # Ensure contiguous tensors on the same device - req_id_c = req_id.contiguous() - block_table_c = block_table.contiguous() - token_indices_c = token_indices.contiguous() - out = torch.empty_like(token_indices_c) - - # Strides in elements - bt_stride0, bt_stride1 = block_table_c.stride() - ti_stride0, ti_stride1 = token_indices_c.stride() - out_stride0, out_stride1 = out.stride() - - # Prepare prefill pointers - if has_prefill: - assert prefill_request_id is not None # for mypy - assert workspace_starts is not None # for mypy - prefill_request_id_c = prefill_request_id.contiguous() - workspace_starts_c = workspace_starts.contiguous() - prefill_request_id_ptr = prefill_request_id_c - workspace_starts_ptr = workspace_starts_c - else: - # Dummy pointers (won't be accessed when HAS_PREFILL=False) - prefill_request_id_ptr = req_id_c - workspace_starts_ptr = req_id_c - - # Exact 2D grid: tokens × column tiles - grid = (num_tokens, tiles_per_row) - - _convert_req_index_to_global_index_kernel[grid]( - req_id_c, - block_table_c, - token_indices_c, - out, - prefill_request_id_ptr, - workspace_starts_ptr, - # shapes / constexprs - max_num_blocks_per_req, - BLOCK_SIZE, - BLOCK_N, - has_prefill, - # strides - bt_stride0, - bt_stride1, - ti_stride0, - ti_stride1, - out_stride0, - out_stride1, - ) - return out - - def split_prefill_chunks( seq_lens_cpu: torch.Tensor, workspace_size: int ) -> list[tuple[int, int]]: @@ -259,49 +91,6 @@ def get_prefill_workspace_size(vllm_config: VllmConfig): return max_model_len * 5 -class FlashMLASparseBackend(AttentionBackend): - accept_output_buffer: bool = True - - @staticmethod - def get_name() -> str: - return "FLASHMLA_SPARSE" - - @staticmethod - def get_metadata_cls() -> type[AttentionMetadata]: - return FlashMLASparseMetadata - 
- @staticmethod - def get_builder_cls() -> type["FlashMLASparseMetadataBuilder"]: - return FlashMLASparseMetadataBuilder - - @staticmethod - def get_impl_cls() -> type["FlashMLASparseImpl"]: - return FlashMLASparseImpl - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, # assumed to be 1 for MLA - head_size: int, - cache_dtype_str: str = "auto", - ) -> tuple[int, ...]: - if cache_dtype_str == "fp8_ds_mla": - # custom storage fromat is 656 bytes - # see FlashMLA readme.md for details - return (num_blocks, block_size, 656) - else: - return (num_blocks, block_size, head_size) - - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.bfloat16] - - @classmethod - def get_supported_head_sizes(cls) -> list[int]: - return [576] - - @dataclass class FlashMLASparseMetadata: num_reqs: int @@ -347,7 +136,7 @@ class ChunkMetadata: block_table: torch.Tensor req_start_idx: int workspace_starts: torch.Tensor - chunk_size: int + chunk_tot_seqlen: int prefill_chunks: list[ChunkMetadata] | None = None @@ -490,8 +279,8 @@ def build( for req_idx in range(num_prefill_reqs): # Get query token range for this prefill request global_req_idx = num_decode_reqs + req_idx - req_query_start = common_attn_metadata.query_start_loc[global_req_idx] - req_query_end = common_attn_metadata.query_start_loc[global_req_idx + 1] + req_query_start = query_start_loc_cpu[global_req_idx] + req_query_end = query_start_loc_cpu[global_req_idx + 1] prefill_request_id[req_query_start:req_query_end] = req_idx # will be adjusted by chunk loop @@ -512,8 +301,8 @@ def build( chunk_bounds = split_prefill_chunks( prefill_seq_lens_cpu, max_prefill_buffer_size ) - prefill_chunks = [] + prefill_chunks = [] for chunk_start, chunk_end in chunk_bounds: # Adjust workspace_starts in-place per chunk to be # 0-indexed within each chunk @@ -525,7 +314,7 @@ def build( prefill_workspace_starts_cpu[chunk_start:chunk_end] -= offset chunk_seq_lens = prefill_seq_lens[chunk_start:chunk_end] - chunk_size = int(prefill_seq_lens_cpu[chunk_start:chunk_end].sum()) + chunk_tot_seqlen = prefill_seq_lens_cpu[chunk_start:chunk_end].sum() token_start = query_start_loc_cpu[num_decode_reqs + chunk_start].item() token_end = query_start_loc_cpu[num_decode_reqs + chunk_end].item() tokens_slice = slice(token_start, token_end) @@ -543,7 +332,7 @@ def build( block_table=chunk_block_table, req_start_idx=chunk_start, workspace_starts=chunk_workspace_starts, - chunk_size=chunk_size, + chunk_tot_seqlen=chunk_tot_seqlen, ) ) @@ -606,6 +395,214 @@ def build( return metadata +@triton.jit +def _convert_req_index_to_global_index_kernel( + req_id_ptr, # int32 [num_tokens] + block_table_ptr, # int32 [num_requests, max_num_blocks_per_req] + token_indices_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS] + out_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS] + prefill_request_id_ptr, # int32 [num_tokens], -1 for decode, >=0 for prefill + workspace_starts_ptr, # int32 [num_prefill_reqs+1] or nullptr + # shapes (compile-time where possible) + max_num_blocks_per_req: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + BLOCK_N: tl.constexpr, # tile width along columns + HAS_PREFILL: tl.constexpr, + # strides (in elements) + bt_stride0, + bt_stride1, + ti_stride0, + ti_stride1, + out_stride0, + out_stride1, +): + # program_id(0) -> token_id (row) + # program_id(1) -> tile index along columns + token_id = tl.program_id(0) + tile_id = tl.program_id(1) + + # Each program covers BLOCK_N consecutive columns + indice_id = tile_id * BLOCK_N + 
tl.arange(0, BLOCK_N) + + # Load request id for this token (no mask: grid is exact) + req = tl.load(req_id_ptr + token_id) + + # Load prefill request id if prefill support is enabled + is_prefill = tl.full((BLOCK_N,), 0, dtype=tl.bool) + if HAS_PREFILL: + prefill_req_id = tl.load(prefill_request_id_ptr + token_id) + is_prefill = prefill_req_id >= 0 + + # Load token indices for this tile + ti_ptr = token_indices_ptr + token_id * ti_stride0 + indice_id * ti_stride1 + tok = tl.load(ti_ptr) # int32 + + # Only token == -1 should propagate as -1 + is_invalid_tok = tok < 0 + + # Compute block id and in-block offset + block_id = tok // BLOCK_SIZE + inblock_off = tok % BLOCK_SIZE + + # Guard block_table access + valid_block = block_id < max_num_blocks_per_req + bt_ptr = block_table_ptr + req * bt_stride0 + block_id * bt_stride1 + base = tl.load(bt_ptr, mask=valid_block & ~is_prefill, other=0) + + if HAS_PREFILL: + workspace_start = tl.load( + workspace_starts_ptr + prefill_req_id, mask=is_prefill, other=0 + ) + prefill_out = workspace_start + tok + decode_out = tl.where(valid_block, base * BLOCK_SIZE + inblock_off, -1) + out_val = tl.where(is_prefill, prefill_out, decode_out) + out_val = tl.where(is_invalid_tok, -1, out_val) + else: + # If token == -1 OR block_id OOB, output -1; else base * BLOCK_SIZE + offset + out_val = tl.where( + is_invalid_tok | (~valid_block), -1, base * BLOCK_SIZE + inblock_off + ) + + # Store results + out_ptr_ij = out_ptr + token_id * out_stride0 + indice_id * out_stride1 + tl.store(out_ptr_ij, out_val) + + +def triton_convert_req_index_to_global_index( + req_id: torch.Tensor, # int32 [num_tokens] + block_table: torch.Tensor, # int32 [num_requests, max_num_blocks_per_req] + token_indices: torch.Tensor, # int32 [num_tokens, NUM_TOPK_TOKENS] + BLOCK_SIZE: int = 64, + NUM_TOPK_TOKENS: int = 2048, + BLOCK_N: int = 128, # tile width along columns + prefill_request_id: torch.Tensor | None = None, + workspace_starts: torch.Tensor | None = None, +): + """ + out[token_id, indice_id] = + block_table[req_id[token_id], + token_indices[token_id, indice_id] // BLOCK_SIZE] * BLOCK_SIZE + + token_indices[token_id, indice_id] % BLOCK_SIZE + + Only when token_indices[token_id, indice_id] == -1 do we output -1. + For safety, we also output -1 if the derived block_id would be + out-of-bounds. + + When prefill_request_id and workspace_starts are provided, prefill tokens + are mapped to workspace offsets instead of global cache slots. 
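For intuition, a dense eager-mode reference of the decode mapping defined above (a sketch that ignores the prefill/workspace branch; names are illustrative and it favours clarity over speed):

import torch

def ref_convert_decode(req_id, block_table, token_indices, block_size):
    # Block-table rows for each token's request: [num_tokens, max_blocks]
    rows = block_table[req_id.long()]
    block_id = token_indices // block_size
    inblock = token_indices % block_size
    # Clamp so gather never indexes out of range; invalid entries are masked below.
    safe_block = block_id.clamp(min=0, max=block_table.shape[1] - 1).long()
    out = rows.gather(1, safe_block) * block_size + inblock
    invalid = (token_indices < 0) | (block_id >= block_table.shape[1])
    return torch.where(invalid, torch.full_like(out, -1), out)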
+ """ + assert req_id.dtype == torch.int32 + assert block_table.dtype == torch.int32 + assert token_indices.dtype == torch.int32 + assert token_indices.shape[1] == NUM_TOPK_TOKENS + assert NUM_TOPK_TOKENS % BLOCK_N == 0, ( + f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible by BLOCK_N ({BLOCK_N})" + ) + + has_prefill = prefill_request_id is not None and workspace_starts is not None + if has_prefill: + assert prefill_request_id is not None + assert workspace_starts is not None + assert prefill_request_id.dtype == torch.int32 + assert workspace_starts.dtype == torch.int32 + + num_tokens = req_id.shape[0] + num_requests, max_num_blocks_per_req = block_table.shape + tiles_per_row = NUM_TOPK_TOKENS // BLOCK_N + + # Ensure contiguous tensors on the same device + req_id_c = req_id.contiguous() + block_table_c = block_table.contiguous() + token_indices_c = token_indices.contiguous() + out = torch.empty_like(token_indices_c) + + # Strides in elements + bt_stride0, bt_stride1 = block_table_c.stride() + ti_stride0, ti_stride1 = token_indices_c.stride() + out_stride0, out_stride1 = out.stride() + + # Prepare prefill pointers + if has_prefill: + assert prefill_request_id is not None # for mypy + assert workspace_starts is not None # for mypy + prefill_request_id_c = prefill_request_id.contiguous() + workspace_starts_c = workspace_starts.contiguous() + prefill_request_id_ptr = prefill_request_id_c + workspace_starts_ptr = workspace_starts_c + else: + # Dummy pointers (won't be accessed when HAS_PREFILL=False) + prefill_request_id_ptr = req_id_c + workspace_starts_ptr = req_id_c + + # Exact 2D grid: tokens × column tiles + grid = (num_tokens, tiles_per_row) + + _convert_req_index_to_global_index_kernel[grid]( + req_id_c, + block_table_c, + token_indices_c, + out, + prefill_request_id_ptr, + workspace_starts_ptr, + # shapes / constexprs + max_num_blocks_per_req, + BLOCK_SIZE, + BLOCK_N, + has_prefill, + # strides + bt_stride0, + bt_stride1, + ti_stride0, + ti_stride1, + out_stride0, + out_stride1, + ) + return out + + +class FlashMLASparseBackend(AttentionBackend): + accept_output_buffer: bool = True + + @staticmethod + def get_name() -> str: + return "FLASHMLA_SPARSE" + + @staticmethod + def get_metadata_cls() -> type[AttentionMetadata]: + return FlashMLASparseMetadata + + @staticmethod + def get_builder_cls() -> type["FlashMLASparseMetadataBuilder"]: + return FlashMLASparseMetadataBuilder + + @staticmethod + def get_impl_cls() -> type["FlashMLASparseImpl"]: + return FlashMLASparseImpl + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, # assumed to be 1 for MLA + head_size: int, + cache_dtype_str: str = "auto", + ) -> tuple[int, ...]: + if cache_dtype_str == "fp8_ds_mla": + # custom storage fromat is 656 bytes + # see FlashMLA readme.md for details + return (num_blocks, block_size, 656) + else: + return (num_blocks, block_size, head_size) + + @classmethod + def get_supported_dtypes(cls) -> list[torch.dtype]: + return [torch.bfloat16] + + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [576] + + class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]): def __init__( self, @@ -667,7 +664,7 @@ def _forward_bf16_kv( -1, 1, kv_c_and_k_pe_cache.shape[-1] ) - # NOTE(Lucas): kernel requires num_local_head to be a multiple of + # NOTE(Chen): kernel requires num_local_head to be a multiple of # 64 on hopper and 128 on blackwell if self.num_heads % self.padding != 0: assert self.padding % self.num_heads == 0 @@ -773,13 
+770,13 @@ def forward( scale=layer._k_scale, ) + num_decode_tokens = attn_metadata.num_decode_tokens + num_prefill_tokens = attn_metadata.num_prefill_tokens + # Convert per-request indices to global slots (decode) or workspace # offsets (prefill). # prefill_workspace_starts has been adjusted in-place per chunk so # prefill indices automatically come out chunk-local - num_decode_tokens = attn_metadata.num_decode_tokens - num_prefill_tokens = attn_metadata.num_prefill_tokens - topk_indices_global = triton_convert_req_index_to_global_index( attn_metadata.req_id_per_token, attn_metadata.block_table, @@ -817,7 +814,7 @@ def forward( ) for chunk in attn_metadata.prefill_chunks: - chunk_workspace = prefill_bf16_workspace[: chunk.chunk_size] + chunk_workspace = prefill_bf16_workspace[: chunk.chunk_tot_seqlen] ops.cp_gather_and_upconvert_fp8_kv_cache( kv_cache, chunk_workspace, @@ -832,14 +829,12 @@ def forward( chunk.tokens_slice ] - chunk_attn_out = self._forward_bf16_kv( + attn_out[chunk.tokens_slice] = self._forward_bf16_kv( chunk_q, chunk_workspace, chunk_topk_indices_workspace, attn_metadata, ) - attn_out[chunk.tokens_slice] = chunk_attn_out - self._v_up_proj(attn_out, out=output[:num_actual_toks]) return output diff --git a/vllm/v1/worker/workspace.py b/vllm/v1/worker/workspace.py index 0157e30d32ab..d7bfaa0d5d94 100644 --- a/vllm/v1/worker/workspace.py +++ b/vllm/v1/worker/workspace.py @@ -43,28 +43,6 @@ def num_bytes(self) -> int: _manager: Optional["WorkspaceManager"] = None -def is_workspace_manager_initialized() -> bool: - """Check if workspace manager has been initialized. - - Returns: - True if workspace manager is initialized, False otherwise. - """ - return _manager is not None - - -def current_workspace_manager() -> "WorkspaceManager": - """Get the current workspace manager instance. - - Raises: - AssertionError: If workspace manager has not been initialized. - """ - assert _manager is not None, ( - "WorkspaceManager not initialized. Call init_workspace_manager() " - "with a device before using workspace functions." - ) - return _manager - - class WorkspaceManager: """Manager for workspace allocation. @@ -123,7 +101,7 @@ def reserve(self, spec: "WorkspaceSpec") -> None: Args: spec: The workspace specification. """ - # TODO(Lucas): Assert that only reserves (ds/reserve_simultaneous) can + # TODO(Lucas): Assert that only reserves (reserve/reserve_simultaneous) can # increase the workspace size, so that reserve must be called before `get`. # This will encourage the use of reserve which is mostly just useful for # grepping/auditing the codebase. @@ -133,16 +111,6 @@ def reserve(self, spec: "WorkspaceSpec") -> None: # Call get() to perform the actual allocation self.get(spec) - def ds(self, spec: "WorkspaceSpec") -> None: - """Alias for reserve() for backwards compatibility. - - Reserve workspace memory for a given spec. - - Args: - spec: The workspace specification. - """ - self.reserve(spec) - def reserve_simultaneous(self, *specs: "WorkspaceSpec") -> None: """Reserve workspace memory for multiple specs simultaneously. @@ -234,9 +202,6 @@ def _increase_size( required_bytes: The number of bytes required. name: Name for debugging/logging. """ - # Manager owns a single device; no cross-device assertions needed - - # Check if we need to grow the workspace current_size = self._workspace_size_bytes(self._current_workspaces[0]) if self._locked and current_size < required_bytes: raise AssertionError( @@ -246,8 +211,6 @@ def _increase_size( "Workspace growth is not allowed after locking." 
) - was_unallocated = self._current_workspaces[0] is None - for ubatch_id in range(self._num_ubatches): current_workspace = self._current_workspaces[ubatch_id] @@ -261,25 +224,37 @@ def _increase_size( if envs.VLLM_DEBUG_WORKSPACE: total_mb = required_bytes * self._num_ubatches / _MB - if was_unallocated: - logger.info( - "[WORKSPACE DEBUG] Allocated workspace '%s': %.2f MB " - "(%d ubatches, total memory %.2f MB)", - name, - required_bytes / _MB, - self._num_ubatches, - total_mb, - ) - else: - logger.info( - "[WORKSPACE DEBUG] Resized workspace '%s': %.2f MB -> %.2f " - "MB (%d ubatches, total memory %.2f MB)", - name, - current_size / _MB, - required_bytes / _MB, - self._num_ubatches, - total_mb, - ) + logger.info( + "[WORKSPACE DEBUG] Resized workspace '%s': %.2f MB -> %.2f " + "MB (%d ubatches, total memory %.2f MB)", + name, + current_size / _MB, + required_bytes / _MB, + self._num_ubatches, + total_mb, + ) + + +def is_workspace_manager_initialized() -> bool: + """Check if workspace manager has been initialized. + + Returns: + True if workspace manager is initialized, False otherwise. + """ + return _manager is not None + + +def current_workspace_manager() -> "WorkspaceManager": + """Get the current workspace manager instance. + + Raises: + AssertionError: If workspace manager has not been initialized. + """ + assert _manager is not None, ( + "WorkspaceManager not initialized. Call init_workspace_manager() " + "with a device before using workspace functions." + ) + return _manager def init_workspace_manager(device: torch.device, vllm_config: VllmConfig) -> None: From 90145645770770ddc414a9dc31258746bede0456 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Sun, 26 Oct 2025 14:39:15 -0700 Subject: [PATCH 08/16] cleanup Signed-off-by: Lucas Wilkinson --- csrc/cache_kernels.cu | 3 +- .../attention/backends/mla/flashmla_sparse.py | 480 +++++++++--------- 2 files changed, 241 insertions(+), 242 deletions(-) diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index f34f5bbc86c3..0fab0bd5e6f5 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -1277,7 +1277,6 @@ void cp_gather_cache( } } -// Host function to launch the gather-and-upconvert kernel void cp_gather_and_upconvert_fp8_kv_cache( torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656] torch::Tensor const& dst, // [TOT_TOKENS, 576] @@ -1368,7 +1367,7 @@ void indexer_k_quant_and_cache( CALL_INDEXER_K_QUANT_AND_CACHE); } -// Macro to dispatch the kernel based on the data type. +// Macro to dispatch the kernel based on the data amount. #define CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(BLOCK_Y_SIZE) \ vllm::cp_gather_indexer_k_quant_cache_kernel \ << list[tuple[int, int]]: - """ - Split the prefill chunks into a list of tuples of (reqs_start, reqs_end) - such that the total sequence length of each chunk is less than the - maximum prefill buffer size. +class FlashMLASparseBackend(AttentionBackend): + accept_output_buffer: bool = True - Args: - seq_lens_cpu: The sequence lengths of the prefill requests. - max_prefill_buffer_size: The maximum prefill buffer size. + @staticmethod + def get_name() -> str: + return "FLASHMLA_SPARSE" - Returns: - A list of tuples of (reqs_start, reqs_end). 
- """ - chunk_bounds = [] - i, n = 0, len(seq_lens_cpu) - assert torch.all(seq_lens_cpu <= workspace_size).item() + @staticmethod + def get_metadata_cls() -> type[AttentionMetadata]: + return FlashMLASparseMetadata - while i < n: - start, total = i, 0 - while i < n and (total + (cur := seq_lens_cpu[i].item())) <= workspace_size: - total += cur - i += 1 - chunk_bounds.append((start, i)) - return chunk_bounds + @staticmethod + def get_builder_cls() -> type["FlashMLASparseMetadataBuilder"]: + return FlashMLASparseMetadataBuilder + @staticmethod + def get_impl_cls() -> type["FlashMLASparseImpl"]: + return FlashMLASparseImpl -def get_prefill_workspace_size(vllm_config: VllmConfig): - max_model_len = vllm_config.model_config.max_model_len - # NOTE(Lucas): 5 is a magic number for controlling the prefill buffer size. - # May be tuned later. - # Memory usage: 5 * max_model_len * 576 * 2 bytes - # Example: DeepSeek-V3.2 with max_model_len=163840 -> - # 5 * 163840 * 576 * 2 = ~900 MB - # This fits nicely below the typical MoE workspace size of >2GB so this is "free" - return max_model_len * 5 + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, # assumed to be 1 for MLA + head_size: int, + cache_dtype_str: str = "auto", + ) -> tuple[int, ...]: + if cache_dtype_str == "fp8_ds_mla": + # custom storage fromat is 656 bytes + # see FlashMLA readme.md for details + return (num_blocks, block_size, 656) + else: + return (num_blocks, block_size, head_size) + + @classmethod + def get_supported_dtypes(cls) -> list[torch.dtype]: + return [torch.bfloat16] + + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [576] @dataclass @@ -150,6 +154,210 @@ class FP8KernelMetadata: fp8_extra_metadata: FP8KernelMetadata | None = None +@triton.jit +def _convert_req_index_to_global_index_kernel( + req_id_ptr, # int32 [num_tokens] + block_table_ptr, # int32 [num_requests, max_num_blocks_per_req] + token_indices_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS] + out_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS] + prefill_request_id_ptr, # int32 [num_tokens], -1 for decode, >=0 for prefill + workspace_starts_ptr, # int32 [num_prefill_reqs+1] or nullptr + # shapes (compile-time where possible) + max_num_blocks_per_req: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + BLOCK_N: tl.constexpr, # tile width along columns + HAS_PREFILL: tl.constexpr, + # strides (in elements) + bt_stride0, + bt_stride1, + ti_stride0, + ti_stride1, + out_stride0, + out_stride1, +): + # program_id(0) -> token_id (row) + # program_id(1) -> tile index along columns + token_id = tl.program_id(0) + tile_id = tl.program_id(1) + + # Each program covers BLOCK_N consecutive columns + indice_id = tile_id * BLOCK_N + tl.arange(0, BLOCK_N) + + # Load request id for this token (no mask: grid is exact) + req = tl.load(req_id_ptr + token_id) + + # Load prefill request id if prefill support is enabled + is_prefill = tl.full((BLOCK_N,), 0, dtype=tl.bool) + if HAS_PREFILL: + prefill_req_id = tl.load(prefill_request_id_ptr + token_id) + is_prefill = prefill_req_id >= 0 + + # Load token indices for this tile + ti_ptr = token_indices_ptr + token_id * ti_stride0 + indice_id * ti_stride1 + tok = tl.load(ti_ptr) # int32 + + # Only token == -1 should propagate as -1 + is_invalid_tok = tok < 0 + + # Compute block id and in-block offset + block_id = tok // BLOCK_SIZE + inblock_off = tok % BLOCK_SIZE + + # Guard block_table access + valid_block = block_id < max_num_blocks_per_req + bt_ptr = block_table_ptr + req * bt_stride0 + 
block_id * bt_stride1 + base = tl.load(bt_ptr, mask=valid_block & ~is_prefill, other=0) + + if HAS_PREFILL: + workspace_start = tl.load( + workspace_starts_ptr + prefill_req_id, mask=is_prefill, other=0 + ) + prefill_out = workspace_start + tok + decode_out = tl.where(valid_block, base * BLOCK_SIZE + inblock_off, -1) + out_val = tl.where(is_prefill, prefill_out, decode_out) + out_val = tl.where(is_invalid_tok, -1, out_val) + else: + # If token == -1 OR block_id OOB, output -1; else base * BLOCK_SIZE + offset + out_val = tl.where( + is_invalid_tok | (~valid_block), -1, base * BLOCK_SIZE + inblock_off + ) + + # Store results + out_ptr_ij = out_ptr + token_id * out_stride0 + indice_id * out_stride1 + tl.store(out_ptr_ij, out_val) + + +def triton_convert_req_index_to_global_index( + req_id: torch.Tensor, # int32 [num_tokens] + block_table: torch.Tensor, # int32 [num_requests, max_num_blocks_per_req] + token_indices: torch.Tensor, # int32 [num_tokens, NUM_TOPK_TOKENS] + BLOCK_SIZE: int = 64, + NUM_TOPK_TOKENS: int = 2048, + BLOCK_N: int = 128, # tile width along columns + prefill_request_id: torch.Tensor | None = None, + workspace_starts: torch.Tensor | None = None, +): + """ + out[token_id, indice_id] = + block_table[req_id[token_id], + token_indices[token_id, indice_id] // BLOCK_SIZE] * BLOCK_SIZE + + token_indices[token_id, indice_id] % BLOCK_SIZE + + Only when token_indices[token_id, indice_id] == -1 do we output -1. + For safety, we also output -1 if the derived block_id would be + out-of-bounds. + + When prefill_request_id and workspace_starts are provided, prefill tokens + are mapped to workspace offsets instead of global cache slots. + """ + assert req_id.dtype == torch.int32 + assert block_table.dtype == torch.int32 + assert token_indices.dtype == torch.int32 + assert token_indices.shape[1] == NUM_TOPK_TOKENS + assert NUM_TOPK_TOKENS % BLOCK_N == 0, ( + f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible by BLOCK_N ({BLOCK_N})" + ) + + has_prefill = prefill_request_id is not None and workspace_starts is not None + if has_prefill: + assert prefill_request_id is not None + assert workspace_starts is not None + assert prefill_request_id.dtype == torch.int32 + assert workspace_starts.dtype == torch.int32 + + num_tokens = req_id.shape[0] + num_requests, max_num_blocks_per_req = block_table.shape + tiles_per_row = NUM_TOPK_TOKENS // BLOCK_N + + # Ensure contiguous tensors on the same device + req_id_c = req_id.contiguous() + block_table_c = block_table.contiguous() + token_indices_c = token_indices.contiguous() + out = torch.empty_like(token_indices_c) + + # Strides in elements + bt_stride0, bt_stride1 = block_table_c.stride() + ti_stride0, ti_stride1 = token_indices_c.stride() + out_stride0, out_stride1 = out.stride() + + # Prepare prefill pointers + if has_prefill: + assert prefill_request_id is not None # for mypy + assert workspace_starts is not None # for mypy + prefill_request_id_c = prefill_request_id.contiguous() + workspace_starts_c = workspace_starts.contiguous() + prefill_request_id_ptr = prefill_request_id_c + workspace_starts_ptr = workspace_starts_c + else: + # Dummy pointers (won't be accessed when HAS_PREFILL=False) + prefill_request_id_ptr = req_id_c + workspace_starts_ptr = req_id_c + + # Exact 2D grid: tokens × column tiles + grid = (num_tokens, tiles_per_row) + + _convert_req_index_to_global_index_kernel[grid]( + req_id_c, + block_table_c, + token_indices_c, + out, + prefill_request_id_ptr, + workspace_starts_ptr, + # shapes / constexprs + max_num_blocks_per_req, + 
BLOCK_SIZE, + BLOCK_N, + has_prefill, + # strides + bt_stride0, + bt_stride1, + ti_stride0, + ti_stride1, + out_stride0, + out_stride1, + ) + return out + + +def split_prefill_chunks( + seq_lens_cpu: torch.Tensor, workspace_size: int +) -> list[tuple[int, int]]: + """ + Split the prefill chunks into a list of tuples of (reqs_start, reqs_end) + such that the total sequence length of each chunk is less than the + maximum prefill buffer size. + + Args: + seq_lens_cpu: The sequence lengths of the prefill requests. + max_prefill_buffer_size: The maximum prefill buffer size. + + Returns: + A list of tuples of (reqs_start, reqs_end). + """ + chunk_bounds = [] + i, n = 0, len(seq_lens_cpu) + assert torch.all(seq_lens_cpu <= workspace_size).item() + + while i < n: + start, total = i, 0 + while i < n and (total + (cur := seq_lens_cpu[i].item())) <= workspace_size: + total += cur + i += 1 + chunk_bounds.append((start, i)) + return chunk_bounds + + +def get_prefill_workspace_size(vllm_config: VllmConfig): + max_model_len = vllm_config.model_config.max_model_len + # NOTE(Lucas): 5 is a magic number for controlling the prefill buffer size. + # May be tuned later. + # Memory usage: 5 * max_model_len * 576 * 2 bytes + # Example: DeepSeek-V3.2 with max_model_len=163840 -> + # 5 * 163840 * 576 * 2 = ~900 MB + # This fits nicely below the typical MoE workspace size of >2GB so this is "free" + return max_model_len * 5 + + class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetadata]): cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH @@ -395,214 +603,6 @@ def build( return metadata -@triton.jit -def _convert_req_index_to_global_index_kernel( - req_id_ptr, # int32 [num_tokens] - block_table_ptr, # int32 [num_requests, max_num_blocks_per_req] - token_indices_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS] - out_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS] - prefill_request_id_ptr, # int32 [num_tokens], -1 for decode, >=0 for prefill - workspace_starts_ptr, # int32 [num_prefill_reqs+1] or nullptr - # shapes (compile-time where possible) - max_num_blocks_per_req: tl.constexpr, - BLOCK_SIZE: tl.constexpr, - BLOCK_N: tl.constexpr, # tile width along columns - HAS_PREFILL: tl.constexpr, - # strides (in elements) - bt_stride0, - bt_stride1, - ti_stride0, - ti_stride1, - out_stride0, - out_stride1, -): - # program_id(0) -> token_id (row) - # program_id(1) -> tile index along columns - token_id = tl.program_id(0) - tile_id = tl.program_id(1) - - # Each program covers BLOCK_N consecutive columns - indice_id = tile_id * BLOCK_N + tl.arange(0, BLOCK_N) - - # Load request id for this token (no mask: grid is exact) - req = tl.load(req_id_ptr + token_id) - - # Load prefill request id if prefill support is enabled - is_prefill = tl.full((BLOCK_N,), 0, dtype=tl.bool) - if HAS_PREFILL: - prefill_req_id = tl.load(prefill_request_id_ptr + token_id) - is_prefill = prefill_req_id >= 0 - - # Load token indices for this tile - ti_ptr = token_indices_ptr + token_id * ti_stride0 + indice_id * ti_stride1 - tok = tl.load(ti_ptr) # int32 - - # Only token == -1 should propagate as -1 - is_invalid_tok = tok < 0 - - # Compute block id and in-block offset - block_id = tok // BLOCK_SIZE - inblock_off = tok % BLOCK_SIZE - - # Guard block_table access - valid_block = block_id < max_num_blocks_per_req - bt_ptr = block_table_ptr + req * bt_stride0 + block_id * bt_stride1 - base = tl.load(bt_ptr, mask=valid_block & ~is_prefill, other=0) - - if HAS_PREFILL: - workspace_start = tl.load( - 
workspace_starts_ptr + prefill_req_id, mask=is_prefill, other=0 - ) - prefill_out = workspace_start + tok - decode_out = tl.where(valid_block, base * BLOCK_SIZE + inblock_off, -1) - out_val = tl.where(is_prefill, prefill_out, decode_out) - out_val = tl.where(is_invalid_tok, -1, out_val) - else: - # If token == -1 OR block_id OOB, output -1; else base * BLOCK_SIZE + offset - out_val = tl.where( - is_invalid_tok | (~valid_block), -1, base * BLOCK_SIZE + inblock_off - ) - - # Store results - out_ptr_ij = out_ptr + token_id * out_stride0 + indice_id * out_stride1 - tl.store(out_ptr_ij, out_val) - - -def triton_convert_req_index_to_global_index( - req_id: torch.Tensor, # int32 [num_tokens] - block_table: torch.Tensor, # int32 [num_requests, max_num_blocks_per_req] - token_indices: torch.Tensor, # int32 [num_tokens, NUM_TOPK_TOKENS] - BLOCK_SIZE: int = 64, - NUM_TOPK_TOKENS: int = 2048, - BLOCK_N: int = 128, # tile width along columns - prefill_request_id: torch.Tensor | None = None, - workspace_starts: torch.Tensor | None = None, -): - """ - out[token_id, indice_id] = - block_table[req_id[token_id], - token_indices[token_id, indice_id] // BLOCK_SIZE] * BLOCK_SIZE - + token_indices[token_id, indice_id] % BLOCK_SIZE - - Only when token_indices[token_id, indice_id] == -1 do we output -1. - For safety, we also output -1 if the derived block_id would be - out-of-bounds. - - When prefill_request_id and workspace_starts are provided, prefill tokens - are mapped to workspace offsets instead of global cache slots. - """ - assert req_id.dtype == torch.int32 - assert block_table.dtype == torch.int32 - assert token_indices.dtype == torch.int32 - assert token_indices.shape[1] == NUM_TOPK_TOKENS - assert NUM_TOPK_TOKENS % BLOCK_N == 0, ( - f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible by BLOCK_N ({BLOCK_N})" - ) - - has_prefill = prefill_request_id is not None and workspace_starts is not None - if has_prefill: - assert prefill_request_id is not None - assert workspace_starts is not None - assert prefill_request_id.dtype == torch.int32 - assert workspace_starts.dtype == torch.int32 - - num_tokens = req_id.shape[0] - num_requests, max_num_blocks_per_req = block_table.shape - tiles_per_row = NUM_TOPK_TOKENS // BLOCK_N - - # Ensure contiguous tensors on the same device - req_id_c = req_id.contiguous() - block_table_c = block_table.contiguous() - token_indices_c = token_indices.contiguous() - out = torch.empty_like(token_indices_c) - - # Strides in elements - bt_stride0, bt_stride1 = block_table_c.stride() - ti_stride0, ti_stride1 = token_indices_c.stride() - out_stride0, out_stride1 = out.stride() - - # Prepare prefill pointers - if has_prefill: - assert prefill_request_id is not None # for mypy - assert workspace_starts is not None # for mypy - prefill_request_id_c = prefill_request_id.contiguous() - workspace_starts_c = workspace_starts.contiguous() - prefill_request_id_ptr = prefill_request_id_c - workspace_starts_ptr = workspace_starts_c - else: - # Dummy pointers (won't be accessed when HAS_PREFILL=False) - prefill_request_id_ptr = req_id_c - workspace_starts_ptr = req_id_c - - # Exact 2D grid: tokens × column tiles - grid = (num_tokens, tiles_per_row) - - _convert_req_index_to_global_index_kernel[grid]( - req_id_c, - block_table_c, - token_indices_c, - out, - prefill_request_id_ptr, - workspace_starts_ptr, - # shapes / constexprs - max_num_blocks_per_req, - BLOCK_SIZE, - BLOCK_N, - has_prefill, - # strides - bt_stride0, - bt_stride1, - ti_stride0, - ti_stride1, - out_stride0, - out_stride1, - ) - 
return out - - -class FlashMLASparseBackend(AttentionBackend): - accept_output_buffer: bool = True - - @staticmethod - def get_name() -> str: - return "FLASHMLA_SPARSE" - - @staticmethod - def get_metadata_cls() -> type[AttentionMetadata]: - return FlashMLASparseMetadata - - @staticmethod - def get_builder_cls() -> type["FlashMLASparseMetadataBuilder"]: - return FlashMLASparseMetadataBuilder - - @staticmethod - def get_impl_cls() -> type["FlashMLASparseImpl"]: - return FlashMLASparseImpl - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, # assumed to be 1 for MLA - head_size: int, - cache_dtype_str: str = "auto", - ) -> tuple[int, ...]: - if cache_dtype_str == "fp8_ds_mla": - # custom storage fromat is 656 bytes - # see FlashMLA readme.md for details - return (num_blocks, block_size, 656) - else: - return (num_blocks, block_size, head_size) - - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.bfloat16] - - @classmethod - def get_supported_head_sizes(cls) -> list[int]: - return [576] - - class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]): def __init__( self, From 50a8571a3a27b31d941cb58f8767c1c379f616cd Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Sun, 26 Oct 2025 14:50:40 -0700 Subject: [PATCH 09/16] cleanup Signed-off-by: Lucas Wilkinson --- .../attention/backends/mla/flashmla_sparse.py | 51 ++++++++++--------- 1 file changed, 26 insertions(+), 25 deletions(-) diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 58f160d44ce5..c7e3597bf51e 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -18,7 +18,7 @@ flash_mla_with_kvcache, get_mla_metadata, ) -from vllm.config import VllmConfig +from vllm.config import VllmConfig, get_current_vllm_config from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.triton_utils import tl, triton @@ -639,8 +639,7 @@ def __init__( self.topk_indices_buffer = indexer.topk_indices_buffer self.padding = 128 if current_platform.is_device_capability(100) else 64 - # Reserve fixed workspace for prefill upconversion - vllm_config = indexer.vllm_config + vllm_config = get_current_vllm_config() prefill_workspace_size = get_prefill_workspace_size(vllm_config) self.prefill_workspace_spec = WorkspaceSpec( @@ -755,11 +754,25 @@ def forward( topk_indices = self.topk_indices_buffer[:num_actual_toks] + # Convert per-request indices to global slots (decode) or workspace + # offsets (prefill). 
+ # prefill_workspace_starts has been adjusted in-place per chunk so + # prefill indices automatically come out chunk-local + topk_indices_global = triton_convert_req_index_to_global_index( + attn_metadata.req_id_per_token, + attn_metadata.block_table, + topk_indices, + BLOCK_SIZE=attn_metadata.block_size, + NUM_TOPK_TOKENS=topk_indices.shape[1], + prefill_request_id=attn_metadata.prefill_request_id, + workspace_starts=attn_metadata.prefill_workspace_starts, + ) + use_fp8_cache = self.kv_cache_dtype == "fp8_ds_mla" q = torch.cat([ql_nope, q_pe], dim=-1) - # CRITICAL: Write to KV cache FIRST before gathering/upconverting + # write the latent and rope to kv cache if kv_cache.numel() > 0: ops.concat_and_cache_mla( k_c_normed, @@ -773,34 +786,14 @@ def forward( num_decode_tokens = attn_metadata.num_decode_tokens num_prefill_tokens = attn_metadata.num_prefill_tokens - # Convert per-request indices to global slots (decode) or workspace - # offsets (prefill). - # prefill_workspace_starts has been adjusted in-place per chunk so - # prefill indices automatically come out chunk-local - topk_indices_global = triton_convert_req_index_to_global_index( - attn_metadata.req_id_per_token, - attn_metadata.block_table, - topk_indices, - BLOCK_SIZE=attn_metadata.block_size, - NUM_TOPK_TOKENS=topk_indices.shape[1], - prefill_request_id=attn_metadata.prefill_request_id, - workspace_starts=attn_metadata.prefill_workspace_starts, - ) - if not use_fp8_cache: attn_out = self._forward_bf16_kv( q, kv_cache, topk_indices_global, attn_metadata ) else: - attn_out = q.new_empty( - (num_actual_toks, self.num_heads, self.kv_lora_rank), - dtype=q.dtype, - device=q.device, - ) - # Process decode tokens if num_decode_tokens > 0: - attn_out[:num_decode_tokens] = self._forward_fp8_kv( + attn_out = self._forward_fp8_kv( q[:num_decode_tokens], kv_cache, topk_indices_global[:num_decode_tokens], @@ -808,6 +801,14 @@ def forward( ) if num_prefill_tokens > 0: + decode_attn_out = attn_out + attn_out = q.new_empty( + (num_actual_toks, self.num_heads, self.kv_lora_rank), + dtype=q.dtype, + device=q.device, + ) + attn_out[:num_prefill_tokens] = decode_attn_out[:num_prefill_tokens] + assert attn_metadata.prefill_chunks is not None prefill_bf16_workspace = current_workspace_manager().get( self.prefill_workspace_spec From 09ba3e0f0f5f728f97ede5dfa71e9da13927cfac Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Sun, 26 Oct 2025 14:54:22 -0700 Subject: [PATCH 10/16] cleanup Signed-off-by: Lucas Wilkinson --- vllm/v1/worker/gpu_worker.py | 1 - vllm/v1/worker/workspace.py | 8 -------- 2 files changed, 9 deletions(-) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 62681cd09c1a..5f0647f7468d 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -356,7 +356,6 @@ def determine_available_memory(self) -> int: ) gc.collect() - # Workspaces are now allocated dynamically, no need to pre-reserve memory return int(self.available_kv_cache_memory_bytes) def get_kv_connector_handshake_metadata(self) -> dict | None: diff --git a/vllm/v1/worker/workspace.py b/vllm/v1/worker/workspace.py index d7bfaa0d5d94..8953c29fe2cd 100644 --- a/vllm/v1/worker/workspace.py +++ b/vllm/v1/worker/workspace.py @@ -105,10 +105,6 @@ def reserve(self, spec: "WorkspaceSpec") -> None: # increase the workspace size, so that reserve must be called before `get`. # This will encourage the use of reserve which is mostly just useful for # grepping/auditing the codebase. 
- - # Note: We don't assert !locked here because reserve can be called - # during forward passes. The actual locking logic is in _increase_size. - # Call get() to perform the actual allocation self.get(spec) def reserve_simultaneous(self, *specs: "WorkspaceSpec") -> None: @@ -124,10 +120,6 @@ def reserve_simultaneous(self, *specs: "WorkspaceSpec") -> None: # increase the workspace size, so that reserve must be called before `get`. # This will encourage the use of reserve which is mostly just useful for # grepping/auditing the codebase. - - # Note: We don't assert !locked here because reserve can be called - # during forward passes. The actual locking logic is in _increase_size. - # Call get_simultaneous() to perform the actual allocation self.get_simultaneous(*specs) def get(self, spec: "WorkspaceSpec") -> torch.Tensor: From 579c51ce10c9e04074e8af4f87495a20d7f19f75 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Sun, 26 Oct 2025 15:23:13 -0700 Subject: [PATCH 11/16] keep Signed-off-by: Lucas Wilkinson --- vllm/v1/attention/backends/mla/flashmla_sparse.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index c7e3597bf51e..765b1234cce1 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -187,7 +187,7 @@ def _convert_req_index_to_global_index_kernel( req = tl.load(req_id_ptr + token_id) # Load prefill request id if prefill support is enabled - is_prefill = tl.full((BLOCK_N,), 0, dtype=tl.bool) + is_prefill = tl.full((BLOCK_N,), 0, dtype=tl.int1) if HAS_PREFILL: prefill_req_id = tl.load(prefill_request_id_ptr + token_id) is_prefill = prefill_req_id >= 0 From 9ddee463d6825b4769e1b335feebf86c7f4634e7 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Mon, 27 Oct 2025 23:16:24 -0700 Subject: [PATCH 12/16] fix Signed-off-by: Lucas Wilkinson --- .../attention/backends/mla/flashmla_sparse.py | 46 +++++++++++-------- 1 file changed, 27 insertions(+), 19 deletions(-) diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 765b1234cce1..b3e6faf74ceb 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -187,7 +187,6 @@ def _convert_req_index_to_global_index_kernel( req = tl.load(req_id_ptr + token_id) # Load prefill request id if prefill support is enabled - is_prefill = tl.full((BLOCK_N,), 0, dtype=tl.int1) if HAS_PREFILL: prefill_req_id = tl.load(prefill_request_id_ptr + token_id) is_prefill = prefill_req_id >= 0 @@ -199,6 +198,13 @@ def _convert_req_index_to_global_index_kernel( # Only token == -1 should propagate as -1 is_invalid_tok = tok < 0 + # Prefill path: map to workspace offset + if HAS_PREFILL: + workspace_start = tl.load( + workspace_starts_ptr + prefill_req_id, mask=is_prefill, other=0 + ) + prefill_out = workspace_start + tok + # Compute block id and in-block offset block_id = tok // BLOCK_SIZE inblock_off = tok % BLOCK_SIZE @@ -206,18 +212,15 @@ def _convert_req_index_to_global_index_kernel( # Guard block_table access valid_block = block_id < max_num_blocks_per_req bt_ptr = block_table_ptr + req * bt_stride0 + block_id * bt_stride1 - base = tl.load(bt_ptr, mask=valid_block & ~is_prefill, other=0) + base = tl.load(bt_ptr, mask=valid_block, other=0) + # If token == -1 OR block_id OOB, output -1; else base * BLOCK_SIZE + offset if HAS_PREFILL: - workspace_start = 
tl.load( - workspace_starts_ptr + prefill_req_id, mask=is_prefill, other=0 - ) - prefill_out = workspace_start + tok decode_out = tl.where(valid_block, base * BLOCK_SIZE + inblock_off, -1) - out_val = tl.where(is_prefill, prefill_out, decode_out) - out_val = tl.where(is_invalid_tok, -1, out_val) + out_val = tl.where( + is_invalid_tok, -1, tl.where(is_prefill, prefill_out, decode_out) + ) else: - # If token == -1 OR block_id OOB, output -1; else base * BLOCK_SIZE + offset out_val = tl.where( is_invalid_tok | (~valid_block), -1, base * BLOCK_SIZE + inblock_off ) @@ -791,24 +794,29 @@ def forward( q, kv_cache, topk_indices_global, attn_metadata ) else: - # Process decode tokens - if num_decode_tokens > 0: + # Pure decode case: direct call without allocation + if num_prefill_tokens == 0: attn_out = self._forward_fp8_kv( - q[:num_decode_tokens], - kv_cache, - topk_indices_global[:num_decode_tokens], - attn_metadata, + q, kv_cache, topk_indices_global, attn_metadata ) - - if num_prefill_tokens > 0: - decode_attn_out = attn_out + else: + # Mixed or pure prefill: allocate output tensor attn_out = q.new_empty( (num_actual_toks, self.num_heads, self.kv_lora_rank), dtype=q.dtype, device=q.device, ) - attn_out[:num_prefill_tokens] = decode_attn_out[:num_prefill_tokens] + # Fill decode portion if present + if num_decode_tokens > 0: + attn_out[:num_decode_tokens] = self._forward_fp8_kv( + q[:num_decode_tokens], + kv_cache, + topk_indices_global[:num_decode_tokens], + attn_metadata, + ) + + # Process prefill chunks assert attn_metadata.prefill_chunks is not None prefill_bf16_workspace = current_workspace_manager().get( self.prefill_workspace_spec From 96157d2fa80b3aff3ef6fc1859fbd088ec72b14f Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 28 Oct 2025 00:59:32 -0700 Subject: [PATCH 13/16] cleanup Signed-off-by: Lucas Wilkinson --- vllm/v1/attention/backends/mla/flashmla_sparse.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index b3e6faf74ceb..0ca8ba4da7d6 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -186,11 +186,6 @@ def _convert_req_index_to_global_index_kernel( # Load request id for this token (no mask: grid is exact) req = tl.load(req_id_ptr + token_id) - # Load prefill request id if prefill support is enabled - if HAS_PREFILL: - prefill_req_id = tl.load(prefill_request_id_ptr + token_id) - is_prefill = prefill_req_id >= 0 - # Load token indices for this tile ti_ptr = token_indices_ptr + token_id * ti_stride0 + indice_id * ti_stride1 tok = tl.load(ti_ptr) # int32 @@ -198,8 +193,10 @@ def _convert_req_index_to_global_index_kernel( # Only token == -1 should propagate as -1 is_invalid_tok = tok < 0 - # Prefill path: map to workspace offset + # Prefill path: load metadata and compute workspace offset if HAS_PREFILL: + prefill_req_id = tl.load(prefill_request_id_ptr + token_id) + is_prefill = prefill_req_id >= 0 workspace_start = tl.load( workspace_starts_ptr + prefill_req_id, mask=is_prefill, other=0 ) From 6caacbd215fb42c4a9570feaeb8ee735959fed2e Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 28 Oct 2025 05:51:35 -0700 Subject: [PATCH 14/16] cleanup Signed-off-by: Lucas Wilkinson --- .../attention/backends/mla/flashmla_sparse.py | 26 +++++++------------ 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py 
b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 0ca8ba4da7d6..0d429df3ef17 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -192,16 +192,10 @@ def _convert_req_index_to_global_index_kernel( # Only token == -1 should propagate as -1 is_invalid_tok = tok < 0 - - # Prefill path: load metadata and compute workspace offset + is_prefill = False if HAS_PREFILL: prefill_req_id = tl.load(prefill_request_id_ptr + token_id) is_prefill = prefill_req_id >= 0 - workspace_start = tl.load( - workspace_starts_ptr + prefill_req_id, mask=is_prefill, other=0 - ) - prefill_out = workspace_start + tok - # Compute block id and in-block offset block_id = tok // BLOCK_SIZE inblock_off = tok % BLOCK_SIZE @@ -209,18 +203,18 @@ def _convert_req_index_to_global_index_kernel( # Guard block_table access valid_block = block_id < max_num_blocks_per_req bt_ptr = block_table_ptr + req * bt_stride0 + block_id * bt_stride1 - base = tl.load(bt_ptr, mask=valid_block, other=0) + is_invalid_tok |= ~valid_block + base = tl.load(bt_ptr, mask=valid_block & ~is_prefill, other=0) + out_val = base * BLOCK_SIZE + inblock_off - # If token == -1 OR block_id OOB, output -1; else base * BLOCK_SIZE + offset + # Override with prefill output if prefill is enabled if HAS_PREFILL: - decode_out = tl.where(valid_block, base * BLOCK_SIZE + inblock_off, -1) - out_val = tl.where( - is_invalid_tok, -1, tl.where(is_prefill, prefill_out, decode_out) - ) - else: - out_val = tl.where( - is_invalid_tok | (~valid_block), -1, base * BLOCK_SIZE + inblock_off + workspace_start = tl.load( + workspace_starts_ptr + prefill_req_id, mask=is_prefill, other=0 ) + prefill_out = workspace_start + tok + out_val = tl.where(is_prefill, prefill_out, out_val) + out_val = tl.where(is_invalid_tok, -1, out_val) # Store results out_ptr_ij = out_ptr + token_id * out_stride0 + indice_id * out_stride1 From cd528d12448c2358323133b8dfa7e97cc0986933 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Sun, 2 Nov 2025 17:30:32 -0800 Subject: [PATCH 15/16] cleanup Signed-off-by: Lucas Wilkinson --- .../layers/fused_moe/modular_kernel.py | 39 +--- vllm/model_executor/models/deepseek_v2.py | 22 +- .../attention/backends/mla/flashmla_sparse.py | 15 +- vllm/v1/worker/workspace.py | 207 +++++++----------- 4 files changed, 111 insertions(+), 172 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index dec7eeb2ae9f..5485be7ef366 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -24,7 +24,7 @@ dbo_register_recv_hook, dbo_yield, ) -from vllm.v1.worker.workspace import WorkspaceSpec, current_workspace_manager +from vllm.v1.worker.workspace import current_workspace_manager # # This file defines a set of base classes used to make MoE kernels more modular. 
@@ -766,48 +766,31 @@ def _allocate_buffers( local_num_experts, None, # Pass None to avoid using sampled token counts ) - max_workspace13_spec = WorkspaceSpec( - shape=max_workspace13_shape, - dtype=workspace_dtype, - name="moe.workspace13", - ) - max_workspace2_spec = WorkspaceSpec( - shape=max_workspace2_shape, - dtype=workspace_dtype, - name="moe.workspace2", - ) - max_fused_out_spec = WorkspaceSpec( - shape=max_fused_out_shape, dtype=out_dtype, name="moe.fused_out" - ) - current_workspace_manager().reserve_simultaneous( - max_workspace13_spec, max_workspace2_spec, max_fused_out_spec + + current_workspace_manager().get_simultaneous( + (max_workspace13_shape, workspace_dtype), + (max_workspace2_shape, workspace_dtype), + (max_fused_out_shape, out_dtype), ) # We can reuse the memory between cache1 and cache3 because by the # time we need cache3, we're done with cache1. - workspace13_spec = WorkspaceSpec( - shape=workspace13_shape, dtype=workspace_dtype, name="moe.workspace13" - ) - workspace2_spec = WorkspaceSpec( - shape=workspace2_shape, dtype=workspace_dtype, name="moe.workspace2" - ) - # Construct the entire output that can then be processed in chunks. # Reuse workspace13 for the output in the non-chunked case as long # as it is large enough. This will not always be the case for standard # format experts and with experts that have empty workspaces. if num_chunks == 1 and prod(workspace13_shape) >= prod(fused_out_shape): workspace13, workspace2 = current_workspace_manager().get_simultaneous( - workspace13_spec, workspace2_spec + (workspace13_shape, workspace_dtype), + (workspace2_shape, workspace_dtype), ) fused_out = _resize_cache(workspace13, fused_out_shape) else: - fused_out_spec = WorkspaceSpec( - shape=fused_out_shape, dtype=out_dtype, name="moe.fused_out" - ) workspace13, workspace2, fused_out = ( current_workspace_manager().get_simultaneous( - workspace13_spec, workspace2_spec, fused_out_spec + (workspace13_shape, workspace_dtype), + (workspace2_shape, workspace_dtype), + (fused_out_shape, out_dtype), ) ) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 8ee1a8a831b7..c992d0951c2b 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -86,7 +86,7 @@ DeepseekV32IndexerMetadata, ) from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec -from vllm.v1.worker.workspace import WorkspaceSpec, current_workspace_manager +from vllm.v1.worker.workspace import current_workspace_manager from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP from .utils import ( @@ -520,20 +520,13 @@ def sparse_attn_indexer( # careful! 
this will be None in dummy run attn_metadata = get_forward_context().attn_metadata - k_fp8_spec = WorkspaceSpec( - shape=(total_seq_lens, head_dim), - dtype=torch.float8_e4m3fn, - name="sparse_attn_indexer.k_fp8", - ) - k_scale_spec = WorkspaceSpec( - shape=(total_seq_lens, 4), - dtype=torch.uint8, - name="sparse_attn_indexer.k_scale", - ) - # assert isinstance(attn_metadata, dict) if not isinstance(attn_metadata, dict): - current_workspace_manager().reserve_simultaneous(k_fp8_spec, k_scale_spec) + # Reserve workspace for indexer during profiling run + current_workspace_manager().get_simultaneous( + ((total_seq_lens, head_dim), torch.float8_e4m3fn), + ((total_seq_lens, 4), torch.uint8), + ) return sparse_attn_indexer_fake( hidden_states, @@ -572,7 +565,8 @@ def sparse_attn_indexer( # Get the full shared workspace buffers once (will allocate on first use) workspace_manager = current_workspace_manager() k_fp8_full, k_scale_full = workspace_manager.get_simultaneous( - k_fp8_spec, k_scale_spec + ((total_seq_lens, head_dim), torch.float8_e4m3fn), + ((total_seq_lens, 4), torch.uint8), ) for chunk in prefill_metadata.chunks: diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 0d429df3ef17..d31adddfab6d 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -31,7 +31,7 @@ split_decodes_and_prefills, ) from vllm.v1.kv_cache_interface import AttentionSpec -from vllm.v1.worker.workspace import WorkspaceSpec, current_workspace_manager +from vllm.v1.worker.workspace import current_workspace_manager if TYPE_CHECKING: from vllm.model_executor.models.deepseek_v2 import Indexer @@ -636,14 +636,13 @@ def __init__( vllm_config = get_current_vllm_config() prefill_workspace_size = get_prefill_workspace_size(vllm_config) - self.prefill_workspace_spec = WorkspaceSpec( - shape=(prefill_workspace_size, head_size), - dtype=torch.bfloat16, - name="FlashMLASparseImpl.prefill_workspace", - ) + self.prefill_workspace_shape = (prefill_workspace_size, head_size) if kv_cache_dtype == "fp8_ds_mla": - current_workspace_manager().reserve(self.prefill_workspace_spec) + # Reserve workspace during initialization + current_workspace_manager().get( + self.prefill_workspace_shape, torch.bfloat16 + ) def _forward_bf16_kv( self, @@ -810,7 +809,7 @@ def forward( # Process prefill chunks assert attn_metadata.prefill_chunks is not None prefill_bf16_workspace = current_workspace_manager().get( - self.prefill_workspace_spec + self.prefill_workspace_shape, torch.bfloat16 ) for chunk in attn_metadata.prefill_chunks: diff --git a/vllm/v1/worker/workspace.py b/vllm/v1/worker/workspace.py index 8953c29fe2cd..9c1241859015 100644 --- a/vllm/v1/worker/workspace.py +++ b/vllm/v1/worker/workspace.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from dataclasses import dataclass +import inspect +import os from itertools import accumulate from math import prod from typing import Optional @@ -11,28 +12,14 @@ import vllm.envs as envs from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.utils import round_up +from vllm.utils.math_utils import round_up from vllm.v1.worker.ubatching import dbo_current_ubatch_id logger = init_logger(__name__) -@dataclass(frozen=True) -class WorkspaceSpec: - """Specification of a workspace to be allocated. - - Attributes: - shape: The shape of the workspace. 
- dtype: The data type of the workspace. - name: Optional name for debugging. - """ - - shape: tuple[int, ...] - dtype: torch.dtype - name: str = "unnamed" - - def num_bytes(self) -> int: - return prod(self.shape) * self.dtype.itemsize +def _compute_bytes(shape: tuple[int, ...], dtype: torch.dtype) -> int: + return prod(shape) * dtype.itemsize # Constants @@ -92,139 +79,115 @@ def current_allocated_size_bytes(self) -> int: self._current_workspaces[dbo_current_ubatch_id()] ) - def reserve(self, spec: "WorkspaceSpec") -> None: - """Reserve workspace memory for a given spec. - - This is a convenience wrapper around get() that makes it easier to grep - for workspace reservations in the codebase for auditing purposes. - - Args: - spec: The workspace specification. - """ - # TODO(Lucas): Assert that only reserves (reserve/reserve_simultaneous) can - # increase the workspace size, so that reserve must be called before `get`. - # This will encourage the use of reserve which is mostly just useful for - # grepping/auditing the codebase. - self.get(spec) - - def reserve_simultaneous(self, *specs: "WorkspaceSpec") -> None: - """Reserve workspace memory for multiple specs simultaneously. - - This is a convenience wrapper around get_simultaneous() that makes it easier - to grep for workspace reservations in the codebase for auditing purposes. - - Args: - *specs: One or more workspace specifications. - """ - # TODO(Lucas): Assert that only reserves (ds/reserve_simultaneous) can - # increase the workspace size, so that reserve must be called before `get`. - # This will encourage the use of reserve which is mostly just useful for - # grepping/auditing the codebase. - self.get_simultaneous(*specs) - - def get(self, spec: "WorkspaceSpec") -> torch.Tensor: - """Get a workspace tensor for the given spec. + def get(self, shape: tuple[int, ...], dtype: torch.dtype) -> torch.Tensor: + """Get a workspace tensor for the given shape and dtype. Args: - spec: The workspace specification. + shape: The shape of the workspace tensor. + dtype: The data type of the workspace tensor. Returns: A tensor view into the workspace buffer with the requested shape and dtype. """ - num_bytes = spec.num_bytes() - current_workspace = self._ensure_workspace_size(num_bytes, spec.name) - return current_workspace[:num_bytes].view(spec.dtype).reshape(spec.shape) + num_bytes = _compute_bytes(shape, dtype) + current_workspace = self._ensure_workspace_size(num_bytes) + return current_workspace[:num_bytes].view(dtype).reshape(shape) - def get_simultaneous(self, *specs: "WorkspaceSpec") -> list[torch.Tensor]: + def get_simultaneous( + self, *shapes_and_dtypes: tuple[tuple[int, ...], torch.dtype] + ) -> list[torch.Tensor]: """Get multiple workspace tensors simultaneously from a single allocation. Args: - *specs: One or more workspace specifications. + *shapes_and_dtypes: One or more (shape, dtype) tuples. Returns: - List of tensor views into the workspace buffer, one per spec. + List of tensor views into the workspace buffer, one per shape/dtype pair. 
""" - actual_bytes = [spec.num_bytes() for spec in specs] + actual_bytes = [_compute_bytes(s, d) for s, d in shapes_and_dtypes] aligned_bytes = [round_up(actual, 256) for actual in actual_bytes] total_bytes = sum(aligned_bytes) # Calculate cumulative offsets using itertools.accumulate offsets = list(accumulate([0] + aligned_bytes[:-1])) - workspace_names = ", ".join(spec.name for spec in specs) - current_workspace = self._ensure_workspace_size( - total_bytes, f"[{workspace_names}]" - ) + current_workspace = self._ensure_workspace_size(total_bytes) return [ current_workspace[offsets[i] : offsets[i] + actual_bytes[i]] - .view(specs[i].dtype) - .reshape(specs[i].shape) - for i in range(len(specs)) + .view(shapes_and_dtypes[i][1]) + .reshape(shapes_and_dtypes[i][0]) + for i in range(len(shapes_and_dtypes)) ] - def _ensure_workspace_size(self, num_bytes: int, name: str) -> torch.Tensor: - """Ensure workspace is allocated and large enough, return current workspace.""" - ubatch_id = dbo_current_ubatch_id() - current_workspace = self._current_workspaces[ubatch_id] - - # Manager owns a single device; no cross-device assertions needed - - if self._workspace_size_bytes(current_workspace) < num_bytes: - self._increase_size(num_bytes, name) - current_workspace = self._current_workspaces[ubatch_id] - - return current_workspace - - def _increase_size( - self, - required_bytes: int, - name: str = "unnamed", - ) -> None: - """Allocate or resize workspace for all ubatches. - - If DBO is enabled, allocates for both ubatches. Otherwise, allocates for - ubatch 0. Uses PyTorch's resize_() for efficient in-place resizing when - possible. - - Invariant: Both ubatches always have the same size after this function - completes. + def _ensure_workspace_size(self, required_bytes: int) -> torch.Tensor: + """Ensure workspace is allocated and large enough, return current workspace. Args: required_bytes: The number of bytes required. - name: Name for debugging/logging. - """ - current_size = self._workspace_size_bytes(self._current_workspaces[0]) - if self._locked and current_size < required_bytes: - raise AssertionError( - f"Workspace is locked but allocation for '{name}' requires " - f"{required_bytes / _MB:.2f} MB, current size is " - f"{current_size / _MB:.2f} MB. " - "Workspace growth is not allowed after locking." - ) - for ubatch_id in range(self._num_ubatches): - current_workspace = self._current_workspaces[ubatch_id] + Returns: + The current workspace tensor. + """ + ubatch_id = dbo_current_ubatch_id() + current_workspace = self._current_workspaces[ubatch_id] + current_size = self._workspace_size_bytes(current_workspace) + + if current_size < required_bytes: + + def get_caller_info() -> str: + """Find first frame outside WorkspaceManager.""" + curr_frame = inspect.currentframe() + if curr_frame is None: + return "unknown" + # Walk up the stack skipping WorkspaceManager frames + curr_frame = curr_frame.f_back + while curr_frame is not None: + # TODO: This only catches instance methods (self), missing + # classmethods and staticmethods. 
Once Python 3.11+ is the + # minimum supported version, use co_qualname instead: + # qualname = curr_frame.f_code.co_qualname + # if qualname.startswith("WorkspaceManager."): + if isinstance(curr_frame.f_locals.get("self"), WorkspaceManager): + curr_frame = curr_frame.f_back + continue + filename = os.path.basename(curr_frame.f_code.co_filename) + return ( + f"{filename}:{curr_frame.f_lineno}:{curr_frame.f_code.co_name}" + ) + return "unknown" + + if self._locked: + raise AssertionError( + f"Workspace is locked but allocation from '{get_caller_info()}' " + f"requires {required_bytes / _MB:.2f} MB, current size is " + f"{current_size / _MB:.2f} MB. " + "Workspace growth is not allowed after locking." + ) - if current_workspace is None: - self._current_workspaces[ubatch_id] = torch.empty( - (required_bytes,), dtype=torch.uint8, device=self._device + for ubatch_id in range(self._num_ubatches): + current_workspace = self._current_workspaces[ubatch_id] + if current_workspace is None: + self._current_workspaces[ubatch_id] = torch.empty( + (required_bytes,), dtype=torch.uint8, device=self._device + ) + elif self._workspace_size_bytes(current_workspace) < required_bytes: + current_workspace.resize_(required_bytes) + + if envs.VLLM_DEBUG_WORKSPACE: + logger.info( + "[WORKSPACE DEBUG] Resized workspace from '%s': %.2f MB -> " + "%.2f MB (%d ubatches, total memory %.2f MB)", + get_caller_info(), + current_size / _MB, + required_bytes / _MB, + self._num_ubatches, + required_bytes * self._num_ubatches / _MB, ) - elif self._workspace_size_bytes(current_workspace) < required_bytes: - # Use resize_() for efficient in-place resizing - current_workspace.resize_(required_bytes) - if envs.VLLM_DEBUG_WORKSPACE: - total_mb = required_bytes * self._num_ubatches / _MB - logger.info( - "[WORKSPACE DEBUG] Resized workspace '%s': %.2f MB -> %.2f " - "MB (%d ubatches, total memory %.2f MB)", - name, - current_size / _MB, - required_bytes / _MB, - self._num_ubatches, - total_mb, - ) + current_workspace = self._current_workspaces[dbo_current_ubatch_id()] + + return current_workspace def is_workspace_manager_initialized() -> bool: @@ -280,8 +243,8 @@ def lock_workspace() -> None: Example: # During initialization init_workspace_manager(device) - reserve_workspace(spec1) - reserve_workspace(spec2) + reserve_workspace(shape1, dtype1) + reserve_workspace(shape2, dtype2) # Lock after warmup/profiling lock_workspace() From 39ba79c9fd3fc9ce9cd2a051fe83c215b3922d88 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Mon, 3 Nov 2025 21:35:34 -0800 Subject: [PATCH 16/16] cleanup Signed-off-by: Lucas Wilkinson --- vllm/v1/core/kv_cache_utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 22310697f122..6e026215d402 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -819,9 +819,7 @@ def get_num_blocks( available_memory: Memory available for KV cache in bytes. page_size: The page size of the KV cache. """ - num_blocks = int(available_memory // page_size // num_layers) - num_blocks = max(num_blocks, 0) num_blocks = may_override_num_blocks(vllm_config, num_blocks) return num_blocks
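
Reviewer note: the sketch below shows how the reworked workspace API from the last two patches is expected to be driven end-to-end. It is illustrative only — the shapes, dtypes, and the `warm_up_workspaces` wrapper are placeholders rather than call sites from this series; only `init_workspace_manager`, `current_workspace_manager`, `get`, `get_simultaneous`, and `lock_workspace` come from the diffs above.

    import torch

    from vllm.config import VllmConfig
    from vllm.v1.worker.workspace import (
        current_workspace_manager,
        init_workspace_manager,
        lock_workspace,
    )


    def warm_up_workspaces(device: torch.device, vllm_config: VllmConfig) -> None:
        # One manager per worker; it owns a single uint8 arena per ubatch.
        init_workspace_manager(device, vllm_config)
        mgr = current_workspace_manager()

        # During warmup/profiling, callers request the largest views they will
        # ever need; the arena grows (via torch.empty or resize_()) to the
        # high-water mark. Shapes here are arbitrary placeholders.
        prefill_ws = mgr.get((4096, 576), torch.bfloat16)

        # Multiple buffers carved out of one allocation, 256-byte aligned.
        k_fp8, k_scale = mgr.get_simultaneous(
            ((4096, 128), torch.float8_e4m3fn),
            ((4096, 4), torch.uint8),
        )

        # After profiling, freeze the arena: any later request that would grow
        # it raises an AssertionError naming the offending caller.
        lock_workspace()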