12 changes: 11 additions & 1 deletion csrc/cache.h
@@ -1,6 +1,7 @@
#pragma once

#include <torch/all.h>
#include <c10/util/Optional.h>

#include <map>
#include <vector>
@@ -57,6 +58,15 @@ void cp_gather_cache(
torch::Tensor const& cu_seq_lens, // [BATCH+1]
int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt);

// Gather and upconvert FP8 KV cache to BF16 workspace
void cp_gather_and_upconvert_fp8_kv_cache(
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656]
torch::Tensor const& dst, // [TOT_TOKENS, 576]
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::Tensor const& seq_lens, // [BATCH]
torch::Tensor const& workspace_starts, // [BATCH]
int64_t batch_size);

// Indexer K quantization and cache function
void indexer_k_quant_and_cache(
torch::Tensor& k, // [num_tokens, head_dim]
@@ -71,4 +81,4 @@ void cp_gather_indexer_k_quant_cache(
torch::Tensor& dst_k, // [num_tokens, head_dim]
torch::Tensor& dst_scale, // [num_tokens, head_dim / quant_block_size * 4]
const torch::Tensor& block_table, // [batch_size, num_blocks]
const torch::Tensor& cu_seq_lens); // [batch_size + 1]
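For reviewers, the 656-byte source entry and 576-element destination implied by the shape comments above decompose as follows. This is a quick sanity-check sketch; the constant names are illustrative, not part of the PR:

```python
# Per-token FP8 MLA cache entry consumed by the new kernel:
#   512 FP8 (e4m3) latent values, quantized in 4 tiles of 128 elements,
#   then 4 float32 per-tile scales, then 64 unquantized bf16 rope values.
FP8_VALS, SCALE_TILES, ROPE_VALS = 512, 4, 64

src_entry_bytes = FP8_VALS * 1 + SCALE_TILES * 4 + ROPE_VALS * 2
dst_entry_elems = FP8_VALS + ROPE_VALS

assert src_entry_bytes == 656  # last dim of src_cache
assert dst_entry_elems == 576  # last dim of dst (MLA head_dim)
```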
131 changes: 130 additions & 1 deletion csrc/cache_kernels.cu
@@ -2,6 +2,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAException.h>
#include <c10/util/Optional.h>

#include "cuda_utils.h"
#include "cuda_compat.h"
@@ -514,7 +515,8 @@ __global__ void indexer_k_quant_and_cache_kernel(
const int quant_block_size, // quantization block size
const int cache_block_size, // cache block size
const int cache_stride, // stride for each token in kv_cache
const bool use_ue8m0 // use ue8m0 scale format
) {
constexpr int VEC_SIZE = 4;
const int64_t token_idx = blockIdx.x;
@@ -1058,6 +1060,82 @@ void gather_and_maybe_dequant_cache(
}

namespace vllm {

// Gather and upconvert FP8 KV cache tokens to BF16 workspace
// Similar to cp_gather_cache but specifically for FP8->BF16 conversion
__global__ void cp_gather_and_upconvert_fp8_kv_cache(
const uint8_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656]
__nv_bfloat16* __restrict__ dst, // [TOT_TOKENS, 576]
const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES]
const int32_t* __restrict__ seq_lens, // [BATCH]
const int32_t* __restrict__ workspace_starts, // [BATCH]
const int32_t block_size, const int32_t head_dim,
const int64_t block_table_stride, const int64_t cache_block_stride,
const int64_t cache_entry_stride, const int64_t dst_entry_stride) {
const int64_t bid = blockIdx.x; // Batch ID
const int32_t num_splits = gridDim.y;
const int32_t split = blockIdx.y;
const int32_t seq_start = workspace_starts[bid];
const int32_t seq_len = seq_lens[bid];
const int32_t tot_slots = seq_len;
const int32_t split_slots = cuda_utils::ceil_div(tot_slots, num_splits);

const int32_t split_start = split * split_slots;
const int32_t split_end = min((split + 1) * split_slots, tot_slots);

const bool is_active_split = (split_start < tot_slots);

if (!is_active_split) return;

// Adjust the pointer for the block_table for this batch
const int32_t batch_offset = bid * block_table_stride;
int32_t offset = split_start;
int32_t offset_div = offset / block_size;
offset = offset % block_size;
const int32_t* batch_block_table = block_table + batch_offset;

// Adjust dst pointer based on the cumulative sequence lengths
dst += seq_start * dst_entry_stride;

const int tid = threadIdx.x;

// Process each token in this split
for (int pid = split_start; pid < split_end; ++pid) {
auto block_id = batch_block_table[offset_div];
const uint8_t* token_ptr =
src_cache + block_id * cache_block_stride + offset * cache_entry_stride;
__nv_bfloat16* dst_ptr = dst + pid * dst_entry_stride;

// FP8 format: 512 bytes fp8 + 16 bytes scales + 128 bytes rope (64 bf16)
const uint8_t* no_pe_ptr = token_ptr;
const float* scales_ptr = reinterpret_cast<const float*>(token_ptr + 512);
const __nv_bfloat16* rope_ptr =
reinterpret_cast<const __nv_bfloat16*>(token_ptr + 512 + 16);

// Parallelize fp8 dequant (512 elements) and rope copy (64 elements)
if (tid < 512) {
// FP8 dequantization
const int tile = tid >> 7; // each tile is 128 elements
const float scale = scales_ptr[tile];
const uint8_t val = no_pe_ptr[tid];
dst_ptr[tid] =
fp8::scaled_convert<__nv_bfloat16, uint8_t,
vllm::Fp8KVCacheDataType::kFp8E4M3>(val, scale);
} else if (tid < 576) {
// Rope copy (64 bf16 elements)
const int rope_idx = tid - 512;
dst_ptr[512 + rope_idx] = rope_ptr[rope_idx];
}

// Move to next token
offset += 1;
if (offset == block_size) {
offset_div += 1;
offset = 0;
}
}
}
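For clarity, here is a host-side Python reference of the per-token work above: threads 0-511 each dequantize one FP8 element using the scale of their 128-element tile (`tid >> 7`), and threads 512-575 each copy one rope element. `fp8_e4m3_to_float` is a stand-in for `fp8::scaled_convert`, not a real helper:

```python
import struct

def upconvert_token(token_bytes: bytes, fp8_e4m3_to_float) -> list[float]:
    """Reference for one thread block's work on a single 656-byte token."""
    assert len(token_bytes) == 656
    scales = struct.unpack("<4f", token_bytes[512:528])  # 4 per-tile scales
    # Dequantize 512 fp8 values; tile index is i >> 7, as in the kernel.
    out = [fp8_e4m3_to_float(token_bytes[i]) * scales[i >> 7] for i in range(512)]
    # Append the 64 rope values, widening bf16 -> f32 (bf16 is the top half of f32).
    for i in range(64):
        (b16,) = struct.unpack_from("<H", token_bytes, 528 + 2 * i)
        out.append(struct.unpack("<f", struct.pack("<I", b16 << 16))[0])
    return out
```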

template <typename scalar_t>
// Note(hc): The cp_gather_cache allows seq_starts to no longer be divisible by
// block_size.
@@ -1199,6 +1277,57 @@ void cp_gather_cache(
}
}

void cp_gather_and_upconvert_fp8_kv_cache(
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656]
torch::Tensor const& dst, // [TOT_TOKENS, 576]
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::Tensor const& seq_lens, // [BATCH]
torch::Tensor const& workspace_starts, // [BATCH]
int64_t batch_size) {
at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

int32_t block_size = src_cache.size(1);
int32_t head_dim = dst.size(1);

TORCH_CHECK(block_table.dtype() == torch::kInt32,
"block_table must be int32");
TORCH_CHECK(seq_lens.dtype() == torch::kInt32, "seq_lens must be int32");
TORCH_CHECK(workspace_starts.dtype() == torch::kInt32,
"workspace_starts must be int32");

TORCH_CHECK(src_cache.device() == dst.device(),
"src_cache and dst must be on the same device");
TORCH_CHECK(src_cache.device() == block_table.device(),
"src_cache and block_table must be on the same device");
TORCH_CHECK(src_cache.device() == seq_lens.device(),
"src_cache and seq_lens must be on the same device");
TORCH_CHECK(src_cache.device() == workspace_starts.device(),
"src_cache and workspace_starts must be on the same device");

TORCH_CHECK(src_cache.dtype() == torch::kUInt8, "src_cache must be uint8");
TORCH_CHECK(dst.dtype() == torch::kBFloat16, "dst must be bfloat16");
TORCH_CHECK(head_dim == 576, "head_dim must be 576 for MLA");

int64_t block_table_stride = block_table.stride(0);
int64_t cache_block_stride = src_cache.stride(0);
int64_t cache_entry_stride = src_cache.stride(1);
int64_t dst_entry_stride = dst.stride(0);

// Decide on the number of splits based on the batch size
int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16;
dim3 grid(batch_size, num_splits);
dim3 block(576);

vllm::cp_gather_and_upconvert_fp8_kv_cache<<<grid, block, 0, stream>>>(
src_cache.data_ptr<uint8_t>(),
reinterpret_cast<__nv_bfloat16*>(dst.data_ptr()),
block_table.data_ptr<int32_t>(), seq_lens.data_ptr<int32_t>(),
workspace_starts.data_ptr<int32_t>(), block_size, head_dim,
block_table_stride, cache_block_stride, cache_entry_stride,
dst_entry_stride);
}
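The launch configuration above uses one 576-thread block per (batch, split) pair and picks more splits for small batches so there is enough parallelism to keep the GPU occupied. A sketch of the heuristic, for illustration only:

```python
def launch_grid(batch_size: int) -> tuple[int, int, int]:
    """Mirrors the C++ heuristic: returns (grid.x, grid.y, block.x)."""
    num_splits = 2 if batch_size > 128 else 4 if batch_size > 64 else 16
    return (batch_size, num_splits, 576)

assert launch_grid(8) == (8, 16, 576)     # small batch: many splits
assert launch_grid(100) == (100, 4, 576)
assert launch_grid(256) == (256, 2, 576)  # large batch: batch dim suffices
```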

// Macro to dispatch the kernel based on the data type.
#define CALL_INDEXER_K_QUANT_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \
vllm::indexer_k_quant_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE> \
6 changes: 6 additions & 0 deletions csrc/ops.h
@@ -60,6 +60,12 @@ void merge_attn_states(torch::Tensor& output,
const torch::Tensor& suffix_output,
const torch::Tensor& suffix_lse);

torch::Tensor convert_logical_index_to_physical_index(
torch::Tensor req_id, torch::Tensor block_table,
torch::Tensor token_indices, int64_t block_size,
const std::optional<torch::Tensor>& prefill_request_id,
const std::optional<torch::Tensor>& workspace_starts);

void convert_vertical_slash_indexes(
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
7 changes: 7 additions & 0 deletions csrc/torch_bindings.cpp
@@ -726,6 +726,13 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
"Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()");
cache_ops.impl("cp_gather_cache", torch::kCUDA, &cp_gather_cache);

cache_ops.def(
"cp_gather_and_upconvert_fp8_kv_cache(Tensor src_cache, Tensor! dst, "
"Tensor block_table, Tensor seq_lens, Tensor workspace_starts, int "
"batch_size) -> ()");
cache_ops.impl("cp_gather_and_upconvert_fp8_kv_cache", torch::kCUDA,
&cp_gather_and_upconvert_fp8_kv_cache);
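Once the extension is built, the registration can be sanity-checked from Python; note that `Tensor! dst` in the schema marks `dst` as mutated in place. A minimal check, assuming a compiled CUDA build:

```python
import torch

# Raises AttributeError if the op failed to register.
op = torch.ops._C_cache_ops.cp_gather_and_upconvert_fp8_kv_cache
print(op.default._schema)  # shows the Tensor! (mutating) annotation on dst
```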

cache_ops.def(
"indexer_k_quant_and_cache(Tensor k, Tensor! kv_cache, Tensor "
"slot_mapping, "
23 changes: 23 additions & 0 deletions vllm/_custom_ops.py
@@ -2149,6 +2149,29 @@ def cp_gather_cache(
)


def cp_gather_and_upconvert_fp8_kv_cache(
src_cache: torch.Tensor,
dst: torch.Tensor,
block_table: torch.Tensor,
seq_lens: torch.Tensor,
workspace_starts: torch.Tensor,
batch_size: int,
) -> None:
"""Gather and upconvert FP8 KV cache to BF16 workspace.

Args:
src_cache: FP8 KV cache [num_blocks, block_size, 656]
dst: BF16 output workspace [total_tokens, 576]
block_table: Block indices [num_reqs, max_blocks]
seq_lens: Sequence lengths [num_reqs]
workspace_starts: Workspace start offsets [num_reqs]
batch_size: Number of requests
"""
torch.ops._C_cache_ops.cp_gather_and_upconvert_fp8_kv_cache(
src_cache, dst, block_table, seq_lens, workspace_starts, batch_size
)
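A hypothetical smoke test for this wrapper (shapes only; the random bytes decode to garbage values but exercise the kernel end to end). Assumes a CUDA device and a build that includes this op:

```python
import torch
from vllm import _custom_ops as ops

num_blocks, block_size = 4, 16
seq_lens = torch.tensor([3, 5], dtype=torch.int32, device="cuda")
workspace_starts = torch.tensor([0, 3], dtype=torch.int32, device="cuda")
block_table = torch.tensor([[0, 1], [2, 3]], dtype=torch.int32, device="cuda")

src_cache = torch.randint(
    0, 256, (num_blocks, block_size, 656), dtype=torch.uint8, device="cuda"
)
dst = torch.empty(int(seq_lens.sum()), 576, dtype=torch.bfloat16, device="cuda")

ops.cp_gather_and_upconvert_fp8_kv_cache(
    src_cache, dst, block_table, seq_lens, workspace_starts, batch_size=2
)
torch.cuda.synchronize()
```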


def indexer_k_quant_and_cache(
k: torch.Tensor,
kv_cache: torch.Tensor,
4 changes: 4 additions & 0 deletions vllm/envs.py
@@ -217,6 +217,7 @@
VLLM_NCCL_INCLUDE_PATH: str | None = None
VLLM_USE_FBGEMM: bool = False
VLLM_GC_DEBUG: str = ""
VLLM_DEBUG_WORKSPACE: bool = False
VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False
VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"

@@ -1439,6 +1440,9 @@ def get_vllm_port() -> int | None:
# - VLLM_GC_DEBUG='{"top_objects":5}': enable GC debugger with
# top 5 collected objects
"VLLM_GC_DEBUG": lambda: os.getenv("VLLM_GC_DEBUG", ""),
# Debug workspace allocations: enables logging of workspace resize operations.
"VLLM_DEBUG_WORKSPACE": lambda: bool(int(os.getenv("VLLM_DEBUG_WORKSPACE", "0"))),
# Disables parallel execution of shared_experts via separate cuda stream
"VLLM_DISABLE_SHARED_EXPERTS_STREAM": lambda: os.getenv(
"VLLM_DISABLE_SHARED_EXPERTS_STREAM", False