@@ -853,7 +853,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
  /*!
   * \brief The KV data managed by the KV cache.
   * The array has `num_layers` NDArrays, each of them
-  * has layout (num_pages, 2, num_heads, page_size, head_dim).
+  * has layout (num_pages, 2, num_heads, page_size, num_storage).
   * Along the "2" dimension, index 0 stands for K and 1 stands for V.
   */
  Array<NDArray> pages_;
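
The layout change swaps the innermost `head_dim` extent for `num_storage`, decoupling the number of physical storage elements per head from the number of logical values. For an unquantized cache one would expect `num_storage == head_dim`, leaving the footprint unchanged. A minimal sketch of the resulting page-size arithmetic, assuming a hypothetical `PageBytes` helper and a 4-bit packing scheme purely for illustration (the actual packing is defined by whichever quantization kernels the cache is built with):

```cpp
#include <cstdint>
#include <cstdio>

// Hypothetical helper (not part of the patch): bytes occupied by one page
// under the new layout (num_pages, 2, num_heads, page_size, num_storage).
// `storage_bits` stands in for the bit width of kv_storage_dtype.
int64_t PageBytes(int64_t num_kv_heads, int64_t page_size, int64_t num_storage,
                  int64_t storage_bits) {
  return 2 * num_kv_heads * page_size * num_storage * storage_bits / 8;
}

int main() {
  // Unquantized baseline: num_storage == head_dim (128), 16-bit storage.
  std::printf("fp16 page:   %lld bytes\n", (long long)PageBytes(8, 16, 128, 16));
  // Illustrative 4-bit packing: 128 values packed 8-per-uint32 -> 16 words.
  std::printf("packed page: %lld bytes\n", (long long)PageBytes(8, 16, 16, 32));
  return 0;
}
```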
@@ -985,10 +985,11 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
       int64_t num_layers, int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim,
       int64_t reserved_num_seqs, int64_t num_total_pages, int64_t prefill_chunk_size,
       bool support_sliding_window, RoPEMode rope_mode, double rotary_scale, double rotary_theta,
-      DLDataType dtype, Device device, PackedFunc f_transpose_append, PackedFunc f_compact_copy,
-      PackedFunc f_attention_prefill, PackedFunc f_attention_decode,
-      PackedFunc f_attention_prefill_sliding_window, PackedFunc f_attention_decode_sliding_window,
-      PackedFunc f_attention_prefill_ragged, PackedFunc f_attention_prefill_with_tree_mask,
+      int64_t num_storage, DLDataType dtype, DLDataType kv_storage_dtype, Device device,
+      PackedFunc f_transpose_append, PackedFunc f_compact_copy, PackedFunc f_attention_prefill,
+      PackedFunc f_attention_decode, PackedFunc f_attention_prefill_sliding_window,
+      PackedFunc f_attention_decode_sliding_window, PackedFunc f_attention_prefill_ragged,
+      PackedFunc f_attention_prefill_with_tree_mask,
       Optional<PackedFunc> f_attention_prefill_ragged_begin_forward,
       Optional<PackedFunc> f_attention_prefill_ragged_end_forward,
       Optional<PackedFunc> f_attention_prefill_begin_forward,
@@ -1030,8 +1031,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
         device_(device) {
     pages_.reserve(num_layers);
     for (int i = 0; i < num_layers; ++i) {
-      pages_.push_back(
-          NDArray::Empty({num_total_pages, 2, num_kv_heads, page_size, head_dim}, dtype, device));
+      pages_.push_back(NDArray::Empty({num_total_pages, 2, num_kv_heads, page_size, num_storage},
+                                      kv_storage_dtype, device));
     }
     // Allocate the host memory.
     Device preferred_host_device = GetPreferredHostDevice(device);
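
Allocating `pages_` with `kv_storage_dtype` instead of the activation `dtype` is what lets the physical buffers hold a different, typically narrower or packed, type than the values the attention kernels compute with. A short sketch of the two-dtype setup, where the concrete `uint32` storage type is an illustrative assumption, not something the patch mandates:

```cpp
#include <tvm/runtime/data_type.h>

// Sketch: under the patch, pages_ is typed by kv_storage_dtype while the
// attention kernels still consume and produce the activation dtype.
void TwoDtypeSketch() {
  using tvm::runtime::DataType;
  // The uint32 storage type (e.g. holding packed low-bit values) is an
  // assumption for illustration only.
  DLDataType compute_dtype = DataType::Float(16);  // dtype of qkv/o activations
  DLDataType storage_dtype = DataType::UInt(32);   // dtype of the pages_ buffers
  (void)compute_dtype;
  (void)storage_dtype;
}
```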
@@ -1673,8 +1674,6 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
                 NDArray o_data, double attn_score_scaling_factor) final {
     // Part 1. Shape and dtype check.
     NDArray pages = pages_[layer_id];
-    CHECK(qkv_data.DataType() == pages.DataType());
-    CHECK(o_data.DataType() == pages.DataType());
 
     // qkv_data: (num_total_length, num_qo_heads + 2 * num_kv_heads, head_dim)
     // o_data: (num_total_length, num_qo_heads, head_dim)
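
The two removed checks enforced that activations share the page dtype, an invariant that no longer holds once storage and compute dtypes diverge; any quantized configuration would fail them. If a sanity check is still wanted, it can only compare the activations against each other. A hedged sketch (the helper and its placement are assumptions, not from the patch):

```cpp
#include <tvm/runtime/logging.h>
#include <tvm/runtime/ndarray.h>

using tvm::runtime::NDArray;

// Hypothetical sanity check (name and placement are assumptions): once
// pages_ may be quantized, activations cannot be validated against
// pages.DataType(), only against one another.
void CheckActivationDtypes(const NDArray& qkv_data, const NDArray& o_data) {
  CHECK(qkv_data.DataType() == o_data.DataType())
      << "qkv and o activations must share one compute dtype";
}
```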
@@ -2433,7 +2432,7 @@ TVM_REGISTER_OBJECT_TYPE(PagedAttentionKVCacheObj);
 
 TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
     .set_body([](TVMArgs args, TVMRetValue* rv) {
-      CHECK(args.size() == 25 || args.size() == 26 || args.size() == 27)
+      CHECK(args.size() == 27 || args.size() == 28 || args.size() == 29)
           << "Invalid number of KV cache constructor args.";
       ShapeTuple cache_config = args[0];
       int64_t num_layers = args[1];
@@ -2443,31 +2442,33 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
       int rope_mode = args[5];
       double rotary_scale = args[6];
       double rotary_theta = args[7];
-      NDArray init = args[8];
-      PackedFunc f_transpose_append = args[9];
-      PackedFunc f_attention_prefill = args[10];
-      PackedFunc f_attention_decode = args[11];
-      PackedFunc f_attention_prefill_sliding_window = args[12];
-      PackedFunc f_attention_decode_sliding_window = args[13];
-      PackedFunc f_attention_prefill_ragged = args[14];
-      PackedFunc f_attention_prefill_ragged_begin_forward = args[15];
-      PackedFunc f_attention_prefill_ragged_end_forward = args[16];
-      PackedFunc f_attention_prefill_begin_forward = args[17];
-      PackedFunc f_attention_prefill_end_forward = args[18];
-      PackedFunc f_attention_decode_begin_forward = args[19];
-      PackedFunc f_attention_decode_end_forward = args[20];
-      PackedFunc f_merge_inplace = args[21];
-      PackedFunc f_split_rotary = args[22];
-      PackedFunc f_copy_single_page = args[23];
-      Optional<PackedFunc> f_debug_get_kv = args[24];
+      int64_t num_storage = args[8];
+      NDArray init = args[9];
+      NDArray kv_storage_init = args[10];
+      PackedFunc f_transpose_append = args[11];
+      PackedFunc f_attention_prefill = args[12];
+      PackedFunc f_attention_decode = args[13];
+      PackedFunc f_attention_prefill_sliding_window = args[14];
+      PackedFunc f_attention_decode_sliding_window = args[15];
+      PackedFunc f_attention_prefill_ragged = args[16];
+      PackedFunc f_attention_prefill_ragged_begin_forward = args[17];
+      PackedFunc f_attention_prefill_ragged_end_forward = args[18];
+      PackedFunc f_attention_prefill_begin_forward = args[19];
+      PackedFunc f_attention_prefill_end_forward = args[20];
+      PackedFunc f_attention_decode_begin_forward = args[21];
+      PackedFunc f_attention_decode_end_forward = args[22];
+      PackedFunc f_merge_inplace = args[23];
+      PackedFunc f_split_rotary = args[24];
+      PackedFunc f_copy_single_page = args[25];
+      Optional<PackedFunc> f_debug_get_kv = args[26];
       PackedFunc f_compact_copy{nullptr};
       PackedFunc f_attention_prefill_with_tree_mask{nullptr};
 
-      if (args.size() >= 26) {
-        f_compact_copy = args[25].AsObjectRef<PackedFunc>();
+      if (args.size() >= 28) {
+        f_compact_copy = args[27].AsObjectRef<PackedFunc>();
       }
-      if (args.size() >= 27) {
-        f_attention_prefill_with_tree_mask = args[26].AsObjectRef<PackedFunc>();
+      if (args.size() >= 29) {
+        f_attention_prefill_with_tree_mask = args[28].AsObjectRef<PackedFunc>();
       }
 
       CHECK_EQ(cache_config.size(), 5);
@@ -2484,8 +2485,9 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
       ObjectPtr<PagedAttentionKVCacheObj> n = make_object<PagedAttentionKVCacheObj>(
           page_size, num_layers, num_qo_heads, num_kv_heads, head_dim, reserved_num_seqs,
           num_total_pages, prefill_chunk_size, support_sliding_window, RoPEMode(rope_mode),
-          rotary_scale, rotary_theta, init->dtype, init->device, std::move(f_transpose_append),
-          std::move(f_compact_copy), std::move(f_attention_prefill), std::move(f_attention_decode),
+          rotary_scale, rotary_theta, num_storage, init->dtype, kv_storage_init->dtype,
+          init->device, std::move(f_transpose_append), std::move(f_compact_copy),
+          std::move(f_attention_prefill), std::move(f_attention_decode),
           std::move(f_attention_prefill_sliding_window),
           std::move(f_attention_decode_sliding_window), std::move(f_attention_prefill_ragged),
           std::move(f_attention_prefill_with_tree_mask),
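
Because two positional arguments (`num_storage` at index 8 and `kv_storage_init` at index 10) were inserted in the middle of the list, every later argument shifts up by two and all callers of the builtin must be updated in lockstep with this change. A sketch of the new calling convention via the TVM registry, with the long PackedFunc tail abbreviated in comments:

```cpp
#include <tvm/runtime/logging.h>
#include <tvm/runtime/registry.h>

using namespace tvm::runtime;

// Sketch only: the create builtin now takes 27-29 positional args.
void CreateCacheSketch() {
  const PackedFunc* create =
      Registry::Get("vm.builtin.paged_attention_kv_cache_create");
  ICHECK(create != nullptr);
  // (*create)(cache_config, num_layers, num_qo_heads, num_kv_heads, head_dim,
  //           rope_mode, rotary_scale, rotary_theta,
  //           num_storage,      // new, index 8
  //           init,
  //           kv_storage_init,  // new, index 10
  //           f_transpose_append, /* ... remaining PackedFuncs ... */
  //           f_debug_get_kv,
  //           /* optional: */ f_compact_copy, f_attention_prefill_with_tree_mask);
}
```

The `_reduced` variant below shifts its indices the same way (19/20/21 becomes 21/22/23), so the same caller update applies to both entry points.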
@@ -2500,7 +2502,7 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
 
 TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
     .set_body([](TVMArgs args, TVMRetValue* rv) {
-      CHECK(args.size() == 19 || args.size() == 20 || args.size() == 21)
+      CHECK(args.size() == 21 || args.size() == 22 || args.size() == 23)
           << "Invalid number of KV cache constructor args.";
       ShapeTuple cache_config = args[0];
       int64_t num_layers = args[1];
@@ -2510,25 +2512,27 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
       int rope_mode = args[5];
       double rotary_scale = args[6];
       double rotary_theta = args[7];
-      NDArray init = args[8];
-      PackedFunc f_transpose_append = args[9];
-      PackedFunc f_attention_prefill = args[10];
-      PackedFunc f_attention_decode = args[11];
-      PackedFunc f_attention_prefill_sliding_window = args[12];
-      PackedFunc f_attention_decode_sliding_window = args[13];
-      PackedFunc f_attention_prefill_ragged = args[14];
-      PackedFunc f_merge_inplace = args[15];
-      PackedFunc f_split_rotary = args[16];
-      PackedFunc f_copy_single_page = args[17];
-      Optional<PackedFunc> f_debug_get_kv = args[18];
+      int64_t num_storage = args[8];
+      NDArray init = args[9];
+      NDArray kv_storage_init = args[10];
+      PackedFunc f_transpose_append = args[11];
+      PackedFunc f_attention_prefill = args[12];
+      PackedFunc f_attention_decode = args[13];
+      PackedFunc f_attention_prefill_sliding_window = args[14];
+      PackedFunc f_attention_decode_sliding_window = args[15];
+      PackedFunc f_attention_prefill_ragged = args[16];
+      PackedFunc f_merge_inplace = args[17];
+      PackedFunc f_split_rotary = args[18];
+      PackedFunc f_copy_single_page = args[19];
+      Optional<PackedFunc> f_debug_get_kv = args[20];
       PackedFunc f_compact_copy{nullptr};
       PackedFunc f_attention_prefill_with_tree_mask{nullptr};
 
-      if (args.size() >= 20) {
-        f_compact_copy = args[19].AsObjectRef<PackedFunc>();
+      if (args.size() >= 22) {
+        f_compact_copy = args[21].AsObjectRef<PackedFunc>();
       }
-      if (args.size() >= 21) {
-        f_attention_prefill_with_tree_mask = args[20].AsObjectRef<PackedFunc>();
+      if (args.size() >= 23) {
+        f_attention_prefill_with_tree_mask = args[22].AsObjectRef<PackedFunc>();
       }
 
       CHECK_EQ(cache_config.size(), 5);
@@ -2545,8 +2549,9 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
       ObjectPtr<PagedAttentionKVCacheObj> n = make_object<PagedAttentionKVCacheObj>(
           page_size, num_layers, num_qo_heads, num_kv_heads, head_dim, reserved_num_seqs,
           num_total_pages, prefill_chunk_size, support_sliding_window, RoPEMode(rope_mode),
-          rotary_scale, rotary_theta, init->dtype, init->device, std::move(f_transpose_append),
-          std::move(f_compact_copy), std::move(f_attention_prefill), std::move(f_attention_decode),
+          rotary_scale, rotary_theta, num_storage, init->dtype, kv_storage_init->dtype,
+          init->device, std::move(f_transpose_append), std::move(f_compact_copy),
+          std::move(f_attention_prefill), std::move(f_attention_decode),
           std::move(f_attention_prefill_sliding_window),
           std::move(f_attention_decode_sliding_window), std::move(f_attention_prefill_ragged),
           std::move(f_attention_prefill_with_tree_mask),  //