
Commit 9490426

kv cache quantization
1 parent: 070546e


3 files changed (+66, -51 lines)


src/runtime/relax_vm/paged_kv_cache.cc

Lines changed: 56 additions & 51 deletions
@@ -853,7 +853,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
   /*!
    * \brief The KV data managed by the KV cache.
    * The array has `num_layers` NDArrays, each of them
-   * has layout (num_pages, 2, num_heads, page_size, head_dim).
+   * has layout (num_pages, 2, num_heads, page_size, num_storage).
    * Along on the "2" dimension, index 0 stands for K and 1 stands for V.
    */
   Array<NDArray> pages_;
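For orientation, here is a minimal NumPy sketch of the page layout documented above, with illustrative sizes (the only change in this hunk is that the innermost extent is now num_storage instead of head_dim):

import numpy as np

# Illustrative sizes only; the cache keeps one such array per layer in pages_.
num_pages, num_heads, page_size, num_storage = 4, 2, 16, 128
pages = np.zeros((num_pages, 2, num_heads, page_size, num_storage), dtype="float16")

# Per the comment above, index 0 on the "2" axis holds K and index 1 holds V.
k_row = pages[3, 0, 1, 7]  # K storage for page 3, head 1, in-page slot 7
v_row = pages[3, 1, 1, 7]  # matching V storage
assert k_row.shape == v_row.shape == (num_storage,)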
@@ -985,10 +985,11 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
       int64_t num_layers, int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim,
       int64_t reserved_num_seqs, int64_t num_total_pages, int64_t prefill_chunk_size,
       bool support_sliding_window, RoPEMode rope_mode, double rotary_scale, double rotary_theta,
-      DLDataType dtype, Device device, PackedFunc f_transpose_append, PackedFunc f_compact_copy,
-      PackedFunc f_attention_prefill, PackedFunc f_attention_decode,
-      PackedFunc f_attention_prefill_sliding_window, PackedFunc f_attention_decode_sliding_window,
-      PackedFunc f_attention_prefill_ragged, PackedFunc f_attention_prefill_with_tree_mask,
+      int64_t num_storage, DLDataType dtype, DLDataType kv_storage_dtype, Device device,
+      PackedFunc f_transpose_append, PackedFunc f_compact_copy, PackedFunc f_attention_prefill,
+      PackedFunc f_attention_decode, PackedFunc f_attention_prefill_sliding_window,
+      PackedFunc f_attention_decode_sliding_window, PackedFunc f_attention_prefill_ragged,
+      PackedFunc f_attention_prefill_with_tree_mask,
       Optional<PackedFunc> f_attention_prefill_ragged_begin_forward,
       Optional<PackedFunc> f_attention_prefill_ragged_end_forward,
       Optional<PackedFunc> f_attention_prefill_begin_forward,
@@ -1030,8 +1031,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
         device_(device) {
     pages_.reserve(num_layers);
     for (int i = 0; i < num_layers; ++i) {
-      pages_.push_back(
-          NDArray::Empty({num_total_pages, 2, num_kv_heads, page_size, head_dim}, dtype, device));
+      pages_.push_back(NDArray::Empty({num_total_pages, 2, num_kv_heads, page_size, num_storage},
+                                      kv_storage_dtype, device));
     }
     // Allocate the host memory.
     Device preferred_host_device = GetPreferredHostDevice(device);
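The allocation above is where the two new parameters take effect: each per-layer page pool now has innermost extent num_storage and element type kv_storage_dtype, so its footprint is decoupled from head_dim and the model dtype. A minimal sketch of the size arithmetic, with illustrative values that are not taken from this commit:

import numpy as np

def kv_pool_bytes(num_total_pages, num_kv_heads, page_size, num_storage, kv_storage_dtype):
    # Mirrors the NDArray::Empty shape above:
    # (num_total_pages, 2, num_kv_heads, page_size, num_storage).
    num_elems = num_total_pages * 2 * num_kv_heads * page_size * num_storage
    return num_elems * np.dtype(kv_storage_dtype).itemsize

# Illustrative configuration only (not from the commit).
num_total_pages, num_kv_heads, page_size, head_dim = 512, 8, 16, 128

# With num_storage == head_dim and a float16 storage dtype this reproduces the
# old head_dim/dtype-based footprint; other (num_storage, kv_storage_dtype)
# choices shrink or reshape the pool without changing the attention interface.
print(kv_pool_bytes(num_total_pages, num_kv_heads, page_size, head_dim, "float16"))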
@@ -1673,8 +1674,6 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
                  NDArray o_data, double attn_score_scaling_factor) final {
     // Part 1. Shape and dtype check.
     NDArray pages = pages_[layer_id];
-    CHECK(qkv_data.DataType() == pages.DataType());
-    CHECK(o_data.DataType() == pages.DataType());
 
     // qkv_data: (num_total_length, num_qo_heads + 2 * num_kv_heads, head_dim)
     // o_data: (num_total_length, num_qo_heads, head_dim)
@@ -2433,7 +2432,7 @@ TVM_REGISTER_OBJECT_TYPE(PagedAttentionKVCacheObj);
 
 TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
     .set_body([](TVMArgs args, TVMRetValue* rv) {
-      CHECK(args.size() == 25 || args.size() == 26 || args.size() == 27)
+      CHECK(args.size() == 27 || args.size() == 28 || args.size() == 29)
          << "Invalid number of KV cache constructor args.";
      ShapeTuple cache_config = args[0];
      int64_t num_layers = args[1];
@@ -2443,31 +2442,33 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
      int rope_mode = args[5];
      double rotary_scale = args[6];
      double rotary_theta = args[7];
-      NDArray init = args[8];
-      PackedFunc f_transpose_append = args[9];
-      PackedFunc f_attention_prefill = args[10];
-      PackedFunc f_attention_decode = args[11];
-      PackedFunc f_attention_prefill_sliding_window = args[12];
-      PackedFunc f_attention_decode_sliding_window = args[13];
-      PackedFunc f_attention_prefill_ragged = args[14];
-      PackedFunc f_attention_prefill_ragged_begin_forward = args[15];
-      PackedFunc f_attention_prefill_ragged_end_forward = args[16];
-      PackedFunc f_attention_prefill_begin_forward = args[17];
-      PackedFunc f_attention_prefill_end_forward = args[18];
-      PackedFunc f_attention_decode_begin_forward = args[19];
-      PackedFunc f_attention_decode_end_forward = args[20];
-      PackedFunc f_merge_inplace = args[21];
-      PackedFunc f_split_rotary = args[22];
-      PackedFunc f_copy_single_page = args[23];
-      Optional<PackedFunc> f_debug_get_kv = args[24];
+      int64_t num_storage = args[8];
+      NDArray init = args[9];
+      NDArray kv_storage_init = args[10];
+      PackedFunc f_transpose_append = args[11];
+      PackedFunc f_attention_prefill = args[12];
+      PackedFunc f_attention_decode = args[13];
+      PackedFunc f_attention_prefill_sliding_window = args[14];
+      PackedFunc f_attention_decode_sliding_window = args[15];
+      PackedFunc f_attention_prefill_ragged = args[16];
+      PackedFunc f_attention_prefill_ragged_begin_forward = args[17];
+      PackedFunc f_attention_prefill_ragged_end_forward = args[18];
+      PackedFunc f_attention_prefill_begin_forward = args[19];
+      PackedFunc f_attention_prefill_end_forward = args[20];
+      PackedFunc f_attention_decode_begin_forward = args[21];
+      PackedFunc f_attention_decode_end_forward = args[22];
+      PackedFunc f_merge_inplace = args[23];
+      PackedFunc f_split_rotary = args[24];
+      PackedFunc f_copy_single_page = args[25];
+      Optional<PackedFunc> f_debug_get_kv = args[26];
      PackedFunc f_compact_copy{nullptr};
      PackedFunc f_attention_prefill_with_tree_mask{nullptr};
 
-      if (args.size() >= 26) {
-        f_compact_copy = args[25].AsObjectRef<PackedFunc>();
+      if (args.size() >= 28) {
+        f_compact_copy = args[27].AsObjectRef<PackedFunc>();
      }
-      if (args.size() >= 27) {
-        f_attention_prefill_with_tree_mask = args[26].AsObjectRef<PackedFunc>();
+      if (args.size() >= 29) {
+        f_attention_prefill_with_tree_mask = args[28].AsObjectRef<PackedFunc>();
      }
 
      CHECK_EQ(cache_config.size(), 5);
@@ -2484,8 +2485,9 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
      ObjectPtr<PagedAttentionKVCacheObj> n = make_object<PagedAttentionKVCacheObj>(
          page_size, num_layers, num_qo_heads, num_kv_heads, head_dim, reserved_num_seqs,
          num_total_pages, prefill_chunk_size, support_sliding_window, RoPEMode(rope_mode),
-          rotary_scale, rotary_theta, init->dtype, init->device, std::move(f_transpose_append),
-          std::move(f_compact_copy), std::move(f_attention_prefill), std::move(f_attention_decode),
+          rotary_scale, rotary_theta, num_storage, init->dtype, kv_storage_init->dtype,
+          init->device, std::move(f_transpose_append), std::move(f_compact_copy),
+          std::move(f_attention_prefill), std::move(f_attention_decode),
          std::move(f_attention_prefill_sliding_window),
          std::move(f_attention_decode_sliding_window), std::move(f_attention_prefill_ragged),
          std::move(f_attention_prefill_with_tree_mask),
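To summarize the index shift for callers (the same shift applies to the reduced creator below): both creators gain num_storage at position 8 and a kv_storage_init NDArray at position 10, so every later argument moves back by two, e.g. the optional f_compact_copy moves from args[25] to args[27] here and from args[19] to args[21] in the reduced variant. A hypothetical migration helper, not part of this commit, expressing that mapping:

def migrate_kv_cache_create_args(old_args, num_storage, kv_storage_init):
    # Map a pre-commit argument list for vm.builtin.paged_attention_kv_cache_create
    # (or _create_reduced) onto the new layout unpacked above.
    new_args = list(old_args)
    new_args.insert(8, num_storage)       # new args[8]: int64 num_storage
    new_args.insert(10, kv_storage_init)  # new args[10]: NDArray carrying kv_storage_dtype
    return new_args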
@@ -2500,7 +2502,7 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
 
 TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
     .set_body([](TVMArgs args, TVMRetValue* rv) {
-      CHECK(args.size() == 19 || args.size() == 20 || args.size() == 21)
+      CHECK(args.size() == 21 || args.size() == 22 || args.size() == 23)
          << "Invalid number of KV cache constructor args.";
      ShapeTuple cache_config = args[0];
      int64_t num_layers = args[1];
@@ -2510,25 +2512,27 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
      int rope_mode = args[5];
      double rotary_scale = args[6];
      double rotary_theta = args[7];
-      NDArray init = args[8];
-      PackedFunc f_transpose_append = args[9];
-      PackedFunc f_attention_prefill = args[10];
-      PackedFunc f_attention_decode = args[11];
-      PackedFunc f_attention_prefill_sliding_window = args[12];
-      PackedFunc f_attention_decode_sliding_window = args[13];
-      PackedFunc f_attention_prefill_ragged = args[14];
-      PackedFunc f_merge_inplace = args[15];
-      PackedFunc f_split_rotary = args[16];
-      PackedFunc f_copy_single_page = args[17];
-      Optional<PackedFunc> f_debug_get_kv = args[18];
+      int64_t num_storage = args[8];
+      NDArray init = args[9];
+      NDArray kv_storage_init = args[10];
+      PackedFunc f_transpose_append = args[11];
+      PackedFunc f_attention_prefill = args[12];
+      PackedFunc f_attention_decode = args[13];
+      PackedFunc f_attention_prefill_sliding_window = args[14];
+      PackedFunc f_attention_decode_sliding_window = args[15];
+      PackedFunc f_attention_prefill_ragged = args[16];
+      PackedFunc f_merge_inplace = args[17];
+      PackedFunc f_split_rotary = args[18];
+      PackedFunc f_copy_single_page = args[19];
+      Optional<PackedFunc> f_debug_get_kv = args[20];
      PackedFunc f_compact_copy{nullptr};
      PackedFunc f_attention_prefill_with_tree_mask{nullptr};
 
-      if (args.size() >= 20) {
-        f_compact_copy = args[19].AsObjectRef<PackedFunc>();
+      if (args.size() >= 22) {
+        f_compact_copy = args[21].AsObjectRef<PackedFunc>();
      }
-      if (args.size() >= 21) {
-        f_attention_prefill_with_tree_mask = args[20].AsObjectRef<PackedFunc>();
+      if (args.size() >= 23) {
+        f_attention_prefill_with_tree_mask = args[22].AsObjectRef<PackedFunc>();
      }
 
      CHECK_EQ(cache_config.size(), 5);
@@ -2545,8 +2549,9 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
      ObjectPtr<PagedAttentionKVCacheObj> n = make_object<PagedAttentionKVCacheObj>(
          page_size, num_layers, num_qo_heads, num_kv_heads, head_dim, reserved_num_seqs,
          num_total_pages, prefill_chunk_size, support_sliding_window, RoPEMode(rope_mode),
-          rotary_scale, rotary_theta, init->dtype, init->device, std::move(f_transpose_append),
-          std::move(f_compact_copy), std::move(f_attention_prefill), std::move(f_attention_decode),
+          rotary_scale, rotary_theta, num_storage, init->dtype, kv_storage_init->dtype,
+          init->device, std::move(f_transpose_append), std::move(f_compact_copy),
+          std::move(f_attention_prefill), std::move(f_attention_decode),
          std::move(f_attention_prefill_sliding_window),
          std::move(f_attention_decode_sliding_window), std::move(f_attention_prefill_ragged),
          std::move(f_attention_prefill_with_tree_mask),  //

tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py

Lines changed: 5 additions & 0 deletions
@@ -344,6 +344,9 @@ def set_global_func():
 
 def create_kv_cache(rope_mode):
     support_sliding_window = 0
+    num_storage = head_dim
+    kv_storage_dtype = dtype
+
     cache = fcreate(
         tvm.runtime.ShapeTuple(
             [
@@ -361,7 +364,9 @@ def create_kv_cache(rope_mode):
         rope_mode,
         rope_scale,
         rope_theta,
+        num_storage,
         tvm.nd.empty((), dtype, device=device),
+        tvm.nd.empty((), kv_storage_dtype, device=device),
         ftranspose_append,
         fattention_prefill,
         fattention_decode,

tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py

Lines changed: 5 additions & 0 deletions
@@ -142,6 +142,9 @@ def set_global_func(head_dim, dtype):
 
 
 def create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window):
+    num_storage = head_dim
+    kv_storage_dtype = dtype
+
     fcreate = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_create_reduced")
     cache = fcreate(
         tvm.runtime.ShapeTuple(
@@ -160,7 +163,9 @@ def create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window):
         rope_mode,
         rope_scale,
         rope_theta,
+        num_storage,
         tvm.nd.empty((), dtype, device=device),
+        tvm.nd.empty((), kv_storage_dtype, device=device),
         ftranspose_append,
         fattn_prefill,
         fattn_decode,
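Both updated tests set num_storage = head_dim and kv_storage_dtype = dtype, so the storage layout still matches the model layout and no quantization is exercised yet; the added tvm.nd.empty((), kv_storage_dtype, device=device) argument is read by the registration code only for its dtype. A hedged sketch of how these two knobs might be chosen for an actual quantized cache (hypothetical values; the real packing scheme is defined by the kernels, not by this commit):

# Baseline used by the updated tests: storage layout equals the model layout.
def baseline_storage(head_dim, dtype):
    return {"num_storage": head_dim, "kv_storage_dtype": dtype}

# Hypothetical quantized setting (NOT from this commit): packing two 4-bit values
# per uint8 element would halve num_storage; real schemes may also reserve extra
# elements for per-group scales or zero points.
def hypothetical_int4_storage(head_dim):
    return {"num_storage": head_dim // 2, "kv_storage_dtype": "uint8"}

print(baseline_storage(128, "float16"))
print(hypothetical_int4_storage(128))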
