@@ -132,17 +132,19 @@ bool WorkerImpl::allocate_host_kv_cache(
 
   CHECK(model_ != nullptr) << "Model is not initialized.";
   CHECK(host_kv_caches_.empty()) << "KV caches are already initialized.";
+  CHECK(device_kv_cache_shape[0][0] == device_kv_cache_shape[1][0]);
 
   std::vector<std::vector<int64_t>> host_kv_cache_shape = device_kv_cache_shape;
-  host_kv_cache_shape[0][0] =
+  const int64_t num_layers = context_.get_model_args().n_layers();
+  int64_t host_bolck_size =
       device_kv_cache_shape[0][0] * options_.host_blocks_factor();
-  host_kv_cache_shape[1][0] =
-      device_kv_cache_shape[1][0] * options_.host_blocks_factor();
+  host_kv_cache_shape[0][0] = num_layers;
+  host_kv_cache_shape[1][0] = num_layers;
 
-  // create a KVCache for each layer
-  const int64_t num_layers = context_.get_model_args().n_layers();
-  host_kv_caches_.reserve(num_layers);
-  for (int64_t i = 0; i < num_layers; ++i) {
+  // create a KVCache shape: block_size * [layers, token, head, dim]
+  host_kv_caches_.reserve(host_bolck_size);
+
+  for (int64_t i = 0; i < host_bolck_size; ++i) {
     torch::Tensor key_cache, value_cache;
     key_cache = torch::empty(host_kv_cache_shape[0],
                              torch::dtype(dtype_).device(torch::kCPU))
@@ -152,8 +154,7 @@ bool WorkerImpl::allocate_host_kv_cache(
                       .pin_memory();
     host_kv_caches_.emplace_back(key_cache, value_cache);
   }
-  LOG(INFO) << "Initializing host k cache size: " << host_kv_cache_shape[0][0];
-  LOG(INFO) << "Initializing host v cache size: " << host_kv_cache_shape[1][0];
+  LOG(INFO) << "Initializing host kv block size: " << host_bolck_size;
 
   int32_t device_id = device_.index();
   h2d_attrs_.dstLoc.id = device_id;
@@ -688,22 +689,8 @@ uint32_t WorkerImpl::transfer_kv_blocks(
 
   switch (block_transfer_info[0].transfer_type) {
     case TransferType::G2H: {
-      folly::Promise<uint32_t> promise;
-      auto future = promise.getSemiFuture();
-
-      batchget_threadpool_.schedule(
-          [this, &block_transfer_info, promise = std::move(promise)]() mutable {
-            promise.setValue(
-                KVCacheStore::get_instance().batch_get(block_transfer_info));
-          });
-
-      try {
-        auto timeout = std::chrono::seconds(KVSTORE_TIMEOUT);
-        return std::move(future).wait(timeout);
-      } catch (const folly::FutureTimeout& e) {
-        LOG(WARNING) << "BatchGet operation timed out";
-        return 0;
-      }
+      Slice<BlockTransferInfo> info_slice{block_transfer_info};
+      return load_from_store(info_slice);
     }
     case TransferType::D2G:
       return offload_kv_blocks(block_transfer_info);
@@ -793,23 +780,7 @@ uint32_t WorkerImpl::offload_kv_blocks(
                                promise = std::move(promise),
                                slice = std::move(slice)]() mutable {
       bool ret = d2h_batch_copy(slice);
-      uint32_t success_cnt = 0;
-
-      folly::Promise<uint32_t> store_promise;
-      auto future = store_promise.getSemiFuture();
-
-      batchput_threadpool_.schedule(
-          [this, &slice, promise = std::move(store_promise)]() mutable {
-            promise.setValue(KVCacheStore::get_instance().batch_put(slice));
-          });
-
-      try {
-        auto timeout = std::chrono::seconds(KVSTORE_TIMEOUT);
-        success_cnt = std::move(future).wait(timeout);
-      } catch (const folly::FutureTimeout& e) {
-        LOG(WARNING) << "BatchPut operation timed out";
-      }
-
+      auto success_cnt = offload_to_store(slice);
       if (success_cnt != slice.size()) {
         LOG(WARNING) << "KVCacheStore not all put success: " << success_cnt
                      << "/" << slice.size();
@@ -895,6 +866,7 @@ bool WorkerImpl::d2h_batch_copy(Slice<BlockTransferInfo>& block_transfer_info) {
 #if defined(USE_NPU)
   CHECK(copy_stream_.count(std::this_thread::get_id()) != 0)
       << "WorkerImpl::d2h_batch_copy can only be called in copy_threadpool_.";
+
   const int64_t num_layers = context_.get_model_args().n_layers();
   uint32_t num_batches = block_transfer_info.size() * num_layers * 2;
   void** srcs = new void*[num_batches];
@@ -904,26 +876,25 @@ bool WorkerImpl::d2h_batch_copy(Slice<BlockTransferInfo>& block_transfer_info) {
   size_t attrs_indexes[1] = {0};
   size_t fail_index;
   uint32_t curr_index = 0;
-  for (int layer_id = 0; layer_id < num_layers; layer_id++) {
-    auto src_k_cache = kv_caches_.at(layer_id).get_k_cache();
-    auto dst_k_cache = host_kv_caches_.at(layer_id).get_k_cache();
-    auto src_v_cache = kv_caches_.at(layer_id).get_v_cache();
-    auto dst_v_cache = host_kv_caches_.at(layer_id).get_v_cache();
-
-    for (int idx = 0; idx < block_transfer_info.size(); idx++) {
-      srcs[curr_index] =
-          src_k_cache[block_transfer_info[idx].src_block_id].data_ptr();
-      dsts[curr_index] =
-          dst_k_cache[block_transfer_info[idx].dst_block_id].data_ptr();
 
+  for (const auto& info : block_transfer_info) {
+    auto dst_k_cache = host_kv_caches_.at(info.dst_block_id).get_k_cache();
+    auto dst_v_cache = host_kv_caches_.at(info.dst_block_id).get_v_cache();
+
+    for (int layer_id = 0; layer_id < num_layers; layer_id++) {
+      auto src_k_cache = kv_caches_.at(layer_id).get_k_cache();
+      auto src_v_cache = kv_caches_.at(layer_id).get_v_cache();
+
+      srcs[curr_index] = src_k_cache[info.src_block_id].data_ptr();
+      dsts[curr_index] = dst_k_cache[layer_id].data_ptr();
       copy_size[curr_index] = key_cache_size_per_layer_;
+
       curr_index++;
 
-      srcs[curr_index] =
-          src_v_cache[block_transfer_info[idx].src_block_id].data_ptr();
-      dsts[curr_index] =
-          dst_v_cache[block_transfer_info[idx].dst_block_id].data_ptr();
+      srcs[curr_index] = src_v_cache[info.src_block_id].data_ptr();
+      dsts[curr_index] = dst_v_cache[layer_id].data_ptr();
       copy_size[curr_index] = value_cache_size_per_layer_;
+
       curr_index++;
     }
   }
@@ -961,6 +932,7 @@ bool WorkerImpl::h2d_batch_copy(Slice<BlockTransferInfo>& block_transfer_info) {
 #if defined(USE_NPU)
   CHECK(copy_stream_.count(std::this_thread::get_id()) != 0)
       << "WorkerImpl::h2d_batch_copy can only be called in copy_threadpool_.";
+
   const int64_t num_layers = context_.get_model_args().n_layers();
   uint32_t num_batches = block_transfer_info.size() * num_layers * 2;
   void** srcs = new void*[num_batches];
@@ -971,24 +943,21 @@ bool WorkerImpl::h2d_batch_copy(Slice<BlockTransferInfo>& block_transfer_info) {
   size_t fail_index;
   uint32_t curr_index = 0;
 
-  for (int layer_id = 0; layer_id < num_layers; layer_id++) {
-    auto src_k_cache = host_kv_caches_.at(layer_id).get_k_cache();
-    auto dst_k_cache = kv_caches_.at(layer_id).get_k_cache();
-    auto src_v_cache = host_kv_caches_.at(layer_id).get_v_cache();
-    auto dst_v_cache = kv_caches_.at(layer_id).get_v_cache();
-
-    for (int idx = 0; idx < block_transfer_info.size(); idx++) {
-      srcs[curr_index] =
-          src_k_cache[block_transfer_info[idx].src_block_id].data_ptr();
-      dsts[curr_index] =
-          dst_k_cache[block_transfer_info[idx].dst_block_id].data_ptr();
+  for (const auto& info : block_transfer_info) {
+    auto src_k_cache = host_kv_caches_.at(info.src_block_id).get_k_cache();
+    auto src_v_cache = host_kv_caches_.at(info.src_block_id).get_v_cache();
+
+    for (int layer_id = 0; layer_id < num_layers; layer_id++) {
+      auto dst_k_cache = kv_caches_.at(layer_id).get_k_cache();
+      auto dst_v_cache = kv_caches_.at(layer_id).get_v_cache();
+
+      srcs[curr_index] = src_k_cache[layer_id].data_ptr();
+      dsts[curr_index] = dst_k_cache[info.dst_block_id].data_ptr();
       copy_size[curr_index] = key_cache_size_per_layer_;
       curr_index++;
 
-      srcs[curr_index] =
-          src_v_cache[block_transfer_info[idx].src_block_id].data_ptr();
-      dsts[curr_index] =
-          dst_v_cache[block_transfer_info[idx].dst_block_id].data_ptr();
+      srcs[curr_index] = src_v_cache[layer_id].data_ptr();
+      dsts[curr_index] = dst_v_cache[info.dst_block_id].data_ptr();
       copy_size[curr_index] = value_cache_size_per_layer_;
       curr_index++;
     }
   }
@@ -1022,4 +991,64 @@ bool WorkerImpl::h2d_batch_copy(Slice<BlockTransferInfo>& block_transfer_info) {
   return false;
 }
 
+uint32_t WorkerImpl::offload_to_store(
+    Slice<BlockTransferInfo>& block_transfer_info) {
+  if (!options_.enable_kvcache_store()) {
+    return block_transfer_info.size();
+  }
+
+  folly::Promise<uint32_t> promise;
+  auto future = promise.getSemiFuture();
+
+  batchput_threadpool_.schedule(
+      [this, &block_transfer_info, promise = std::move(promise)]() mutable {
+        promise.setValue(
+            KVCacheStore::get_instance().batch_put(block_transfer_info));
+      });
+
+  auto timeout = std::chrono::seconds(KVSTORE_TIMEOUT);
+  return std::move(future)
+      .via(folly::getGlobalCPUExecutor())
+      .within(timeout)
+      .thenTry([](folly::Try<uint32_t>&& t) -> uint32_t {
+        if (t.hasValue()) {
+          return t.value();
+        } else {
+          LOG(WARNING) << "BatchPut operation timed out";
+          return 0u;
+        }
+      })
+      .get();
+}
+
+uint32_t WorkerImpl::load_from_store(
+    Slice<BlockTransferInfo>& block_transfer_info) {
+  if (!options_.enable_kvcache_store()) {
+    return 0;
+  }
+
+  folly::Promise<uint32_t> promise;
+  auto future = promise.getSemiFuture();
+
+  batchget_threadpool_.schedule(
+      [this, &block_transfer_info, promise = std::move(promise)]() mutable {
+        promise.setValue(
+            KVCacheStore::get_instance().batch_get(block_transfer_info));
+      });
+
+  auto timeout = std::chrono::seconds(KVSTORE_TIMEOUT);
+  return std::move(future)
+      .via(folly::getGlobalCPUExecutor())
+      .within(timeout)
+      .thenTry([](folly::Try<uint32_t>&& t) -> uint32_t {
+        if (t.hasValue()) {
+          return t.value();
+        } else {
+          LOG(WARNING) << "BatchGet operation timed out";
+          return 0u;
+        }
+      })
+      .get();
+}
+
 }  // namespace xllm
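
Below is a minimal standalone sketch (not part of the commit) of the indexing this change introduces: host_kv_caches_ is now indexed by host block id, each host cache tensor holds one row per layer, and the rewritten d2h_batch_copy enumerates copies block by block, layer by layer, key then value. The names CopyEntry and plan_d2h are hypothetical and exist only to illustrate that ordering.

// Hypothetical illustration only; not part of WorkerImpl.
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

struct CopyEntry {
  int64_t src_block;  // device block id (row in the per-layer device cache)
  int64_t dst_block;  // host block id (index into host_kv_caches_)
  int64_t layer;      // row inside the host block's [layers, ...] tensor
  bool is_value;      // false = key cache, true = value cache
};

// Enumerate copy descriptors in the same order as the rewritten d2h loop:
// outer loop over blocks to transfer, inner loop over layers, key then value.
std::vector<CopyEntry> plan_d2h(
    const std::vector<std::pair<int64_t, int64_t>>& transfers,  // {src, dst}
    int64_t num_layers) {
  std::vector<CopyEntry> plan;
  plan.reserve(transfers.size() * num_layers * 2);
  for (const auto& [src, dst] : transfers) {
    for (int64_t layer = 0; layer < num_layers; ++layer) {
      plan.push_back({src, dst, layer, false});  // K: device[layer][src] -> host[dst][layer]
      plan.push_back({src, dst, layer, true});   // V: device[layer][src] -> host[dst][layer]
    }
  }
  return plan;
}

int main() {
  // Offload two device blocks into host blocks 0 and 1 of a 3-layer model:
  // 2 blocks * 3 layers * 2 (K and V) = 12 copy descriptors, matching
  // num_batches = block_transfer_info.size() * num_layers * 2 in the diff.
  for (const auto& e : plan_d2h({{7, 0}, {9, 1}}, 3)) {
    std::cout << (e.is_value ? "V" : "K") << ": device[layer " << e.layer
              << "][block " << e.src_block << "] -> host[block " << e.dst_block
              << "][layer " << e.layer << "]\n";
  }
  return 0;
}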