
Commit eae52e8

refactor: change host KV cache memory layout from layer-wise to block-wise.
1 parent 746ea08 commit eae52e8
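
What the layout change means in practice: before this commit the host cache held one pinned tensor pair per model layer, with the block index as the leading dimension; after it, the host cache holds one pinned tensor pair per block, with the layer index as the leading dimension. A whole block (every layer's K and V) is then contiguous in two buffers, so the store can move it with two slices instead of 2 * n_layers. A minimal sketch of the two indexing schemes (shapes and names are illustrative, not taken from the repo):

    // layer-wise (old): host_kv_caches[layer] -> [num_host_blocks, token, head, dim]
    // block-wise (new): host_kv_caches[block] -> [n_layers, token, head, dim]
    #include <torch/torch.h>
    #include <vector>

    int main() {
      const int64_t n_layers = 4, n_blocks = 8, token = 16, head = 2, dim = 64;

      std::vector<torch::Tensor> layer_wise;  // old: block id indexes dim 0
      for (int64_t l = 0; l < n_layers; ++l)
        layer_wise.push_back(torch::empty({n_blocks, token, head, dim}));

      std::vector<torch::Tensor> block_wise;  // new: layer id indexes dim 0
      for (int64_t b = 0; b < n_blocks; ++b)
        block_wise.push_back(torch::empty({n_layers, token, head, dim}));

      // reading logical block 3:
      auto old_k = layer_wise[0][3];  // one layer at a time, n_layers lookups
      auto new_k = block_wise[3];     // all layers in one contiguous tensor
      return 0;
    }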

File tree

5 files changed: +140 -129 lines changed


xllm/core/framework/kv_cache/kv_cache_store.cpp

Lines changed: 31 additions & 53 deletions
@@ -43,29 +43,24 @@ bool KVCacheStore::init(const StoreConfig& config,
   }
   client_ptr_ = client_opt.value();

-  auto key_tensor_one_layer = host_kv_caches_->at(0).get_k_cache();
-  auto value_tensor_one_layer = host_kv_caches_->at(0).get_v_cache();
+  auto k_tensor_one_block = host_kv_caches_->at(0).get_k_cache();
+  auto v_tensor_one_block = host_kv_caches_->at(0).get_v_cache();

-  key_cache_size_per_layer_ =
-      key_tensor_one_layer[0].numel() * key_tensor_one_layer[0].element_size();
-  value_cache_size_per_layer_ = value_tensor_one_layer[0].numel() *
-                                value_tensor_one_layer[0].element_size();
+  k_cache_size_per_block_ =
+      k_tensor_one_block.numel() * k_tensor_one_block.element_size();
+  v_cache_size_per_block_ =
+      v_tensor_one_block.numel() * v_tensor_one_block.element_size();

-  auto key_cache_host_size =
-      key_tensor_one_layer.numel() * key_tensor_one_layer.element_size();
-  auto value_cache_host_size =
-      value_tensor_one_layer.numel() * value_tensor_one_layer.element_size();
-
-  LOG(INFO) << "key_cache_size_per_layer: " << key_cache_size_per_layer_;
-  LOG(INFO) << "value_cache_size_per_layer: " << value_cache_size_per_layer_;
+  LOG(INFO) << "k_cache_size_per_block: " << k_cache_size_per_block_;
+  LOG(INFO) << "v_cache_size_per_block: " << v_cache_size_per_block_;

   if (config_.protocol == "rdma") {
-    for (int layer = 0; layer < host_kv_caches_->size(); layer++) {
+    for (int block = 0; block < host_kv_caches_->size(); block++) {
       void* key_cache = static_cast<char*>(
-          host_kv_caches_->at(layer).get_k_cache().data_ptr());
+          host_kv_caches_->at(block).get_k_cache().data_ptr());

       auto register_k_result = client_ptr_->RegisterLocalMemory(
-          key_cache, key_cache_host_size, "cpu:0", false, false);
+          key_cache, k_cache_size_per_block_, "cpu:0", false, false);

       if (!register_k_result.has_value()) {
         LOG(ERROR) << "Failed to register local memory for key cache: "
@@ -74,10 +69,10 @@ bool KVCacheStore::init(const StoreConfig& config,
       }

       void* value_cache = static_cast<char*>(
-          host_kv_caches_->at(layer).get_v_cache().data_ptr());
+          host_kv_caches_->at(block).get_v_cache().data_ptr());

       auto register_v_result = client_ptr_->RegisterLocalMemory(
-          value_cache, value_cache_host_size, "cpu:0", false, false);
+          value_cache, v_cache_size_per_block_, "cpu:0", false, false);

       if (!register_v_result.has_value()) {
         LOG(ERROR) << "Failed to register local memory for value cache: "
@@ -119,23 +114,14 @@ uint32_t KVCacheStore::batch_put(
     str_keys.emplace_back(str_key);

-    std::vector<mooncake::Slice> slice;
-    slice.reserve(host_kv_caches_->size() * 2);
-    for (int layer = 0; layer < host_kv_caches_->size(); layer++) {
-      void* key_cache =
-          static_cast<char*>(
-              host_kv_caches_->at(layer).get_k_cache().data_ptr()) +
-          block_info.dst_block_id * key_cache_size_per_layer_;
-      slice.emplace_back(mooncake::Slice{key_cache, key_cache_size_per_layer_});
-
-      void* value_cache =
-          static_cast<char*>(
-              host_kv_caches_->at(layer).get_v_cache().data_ptr()) +
-          block_info.dst_block_id * value_cache_size_per_layer_;
-      slice.emplace_back(
-          mooncake::Slice{value_cache, value_cache_size_per_layer_});
-    }
-    slices.emplace_back(std::move(slice));
+    void* k_cache =
+        host_kv_caches_->at(block_info.dst_block_id).get_k_cache().data_ptr();
+    void* v_cache =
+        host_kv_caches_->at(block_info.dst_block_id).get_v_cache().data_ptr();
+
+    slices.emplace_back(std::vector<mooncake::Slice>{
+        mooncake::Slice{k_cache, k_cache_size_per_block_},
+        mooncake::Slice{v_cache, v_cache_size_per_block_}});
   }

   if (str_keys.size() == 0) {
@@ -177,24 +163,16 @@ uint32_t KVCacheStore::batch_get(
     str_keys.emplace_back(str_key);

-    slices.insert(std::make_pair(str_key, std::vector<mooncake::Slice>()));
-
-    slices[str_key].reserve(host_kv_caches_->size() * 2);
-    for (int layer = 0; layer < host_kv_caches_->size(); layer++) {
-      void* key_cache =
-          static_cast<char*>(
-              host_kv_caches_->at(layer).get_k_cache().data_ptr()) +
-          block_info.dst_block_id * key_cache_size_per_layer_;
-      slices[str_key].emplace_back(
-          mooncake::Slice{key_cache, key_cache_size_per_layer_});
-
-      void* value_cache =
-          static_cast<char*>(
-              host_kv_caches_->at(layer).get_v_cache().data_ptr()) +
-          block_info.dst_block_id * value_cache_size_per_layer_;
-      slices[str_key].emplace_back(
-          mooncake::Slice{value_cache, value_cache_size_per_layer_});
-    }
+    void* k_cache =
+        host_kv_caches_->at(block_info.dst_block_id).get_k_cache().data_ptr();
+    void* v_cache =
+        host_kv_caches_->at(block_info.dst_block_id).get_v_cache().data_ptr();
+
+    slices.insert(
+        std::make_pair(str_key,
+                       std::vector<mooncake::Slice>{
+                           mooncake::Slice{k_cache, k_cache_size_per_block_},
+                           mooncake::Slice{v_cache, v_cache_size_per_block_}}));
   }

   if (str_keys.size() == 0) {
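
The batch_put/batch_get rewrite above swaps pointer arithmetic for direct indexing: previously a block's bytes lived at a computed byte offset inside each per-layer tensor, so every key needed 2 * n_layers slices; now dst_block_id simply selects the per-block tensor, whose whole buffer is the payload. A toy model of the two addressing schemes (all names hypothetical, a sketch rather than the repo's code):

    #include <cstddef>
    #include <vector>

    struct Region { void* ptr; size_t size; };

    // old: one region per layer, each at a computed byte offset
    std::vector<Region> regions_layer_wise(const std::vector<char*>& layer_base,
                                           size_t block_bytes, int dst_block_id) {
      std::vector<Region> out;
      for (char* base : layer_base)
        out.push_back({base + dst_block_id * block_bytes, block_bytes});
      return out;
    }

    // new: exactly one region per K (or V) buffer, no offset math
    Region region_block_wise(const std::vector<char*>& block_base,
                             size_t block_bytes, int dst_block_id) {
      return {block_base[dst_block_id], block_bytes};
    }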

xllm/core/framework/kv_cache/kv_cache_store.h

Lines changed: 2 additions & 2 deletions
@@ -69,8 +69,8 @@ class KVCacheStore {

   std::vector<xllm::KVCache>* host_kv_caches_;

-  uint64_t key_cache_size_per_layer_;
-  uint64_t value_cache_size_per_layer_;
+  uint64_t k_cache_size_per_block_;
+  uint64_t v_cache_size_per_block_;

   std::shared_ptr<mooncake::Client> client_ptr_;
 };

xllm/core/framework/request/sequence.h

Lines changed: 2 additions & 2 deletions
@@ -239,10 +239,10 @@ class Sequence final {

   void sync_result() {
     if (futures_.has_value()) {
-      auto success_cnt = host_kv_state_.num_kv_blocks();
+      uint32_t success_cnt = host_kv_state_.num_kv_blocks();
       for (auto& future : futures_.value()) {
         if (future.isReady()) {
-          success_cnt = std::min(success_cnt, size_t(future.value()));
+          success_cnt = std::min(success_cnt, future.value());
         } else {
           return;
         }
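
Context for the sequence.h change: std::min deduces a single template argument, so comparing a size_t against a uint32_t is ill-formed without a cast. Declaring success_cnt as uint32_t up front (matching the uint32_t that batch_put/batch_get return elsewhere in this commit) removes the cast entirely. A compilable illustration (hypothetical function, not from the repo):

    #include <algorithm>
    #include <cstdint>

    uint32_t clamp_demo(size_t total, uint32_t done) {
      // std::min(total, done);         // ill-formed: size_t vs uint32_t deduction conflict
      uint32_t cnt = static_cast<uint32_t>(total);
      return std::min(cnt, done);       // fine: both uint32_t, as in the commit
    }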

xllm/core/runtime/worker_impl.cpp

Lines changed: 101 additions & 72 deletions
@@ -132,17 +132,19 @@ bool WorkerImpl::allocate_host_kv_cache(

   CHECK(model_ != nullptr) << "Model is not initialized.";
   CHECK(host_kv_caches_.empty()) << "KV caches are already initialized.";
+  CHECK(device_kv_cache_shape[0][0] == device_kv_cache_shape[1][0]);

   std::vector<std::vector<int64_t>> host_kv_cache_shape = device_kv_cache_shape;
-  host_kv_cache_shape[0][0] =
+  const int64_t num_layers = context_.get_model_args().n_layers();
+  int64_t host_block_size =
       device_kv_cache_shape[0][0] * options_.host_blocks_factor();
-  host_kv_cache_shape[1][0] =
-      device_kv_cache_shape[1][0] * options_.host_blocks_factor();
+  host_kv_cache_shape[0][0] = num_layers;
+  host_kv_cache_shape[1][0] = num_layers;

-  // create a KVCache for each layer
-  const int64_t num_layers = context_.get_model_args().n_layers();
-  host_kv_caches_.reserve(num_layers);
-  for (int64_t i = 0; i < num_layers; ++i) {
+  // create one KVCache per host block, each shaped [layers, token, head, dim]
+  host_kv_caches_.reserve(host_block_size);
+
+  for (int64_t i = 0; i < host_block_size; ++i) {
     torch::Tensor key_cache, value_cache;
     key_cache = torch::empty(host_kv_cache_shape[0],
                              torch::dtype(dtype_).device(torch::kCPU))
@@ -152,8 +154,7 @@ bool WorkerImpl::allocate_host_kv_cache(
                     .pin_memory();
     host_kv_caches_.emplace_back(key_cache, value_cache);
   }
-  LOG(INFO) << "Initializing host k cache size: " << host_kv_cache_shape[0][0];
-  LOG(INFO) << "Initializing host v cache size: " << host_kv_cache_shape[1][0];
+  LOG(INFO) << "Initializing host kv cache block count: " << host_block_size;

   int32_t device_id = device_.index();
   h2d_attrs_.dstLoc.id = device_id;
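
Sizing intuition for the new allocation, with made-up numbers: if the device has 1024 KV blocks and options_.host_blocks_factor() is 4, host_block_size is 4096, so 4096 pinned tensor pairs of shape [n_layers, token, head, dim] are created. The added CHECK encodes the assumption that the K and V device shapes agree on the block count, since only device_kv_cache_shape[0][0] feeds that computation.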
@@ -688,22 +689,8 @@ uint32_t WorkerImpl::transfer_kv_blocks(

   switch (block_transfer_info[0].transfer_type) {
     case TransferType::G2H: {
-      folly::Promise<uint32_t> promise;
-      auto future = promise.getSemiFuture();
-
-      batchget_threadpool_.schedule(
-          [this, &block_transfer_info, promise = std::move(promise)]() mutable {
-            promise.setValue(
-                KVCacheStore::get_instance().batch_get(block_transfer_info));
-          });
-
-      try {
-        auto timeout = std::chrono::seconds(KVSTORE_TIMEOUT);
-        return std::move(future).wait(timeout);
-      } catch (const folly::FutureTimeout& e) {
-        LOG(WARNING) << "BatchGet operation timed out";
-        return 0;
-      }
+      Slice<BlockTransferInfo> info_slice{block_transfer_info};
+      return load_from_store(info_slice);
     }
     case TransferType::D2G:
       return offload_kv_blocks(block_transfer_info);
@@ -793,23 +780,7 @@ uint32_t WorkerImpl::offload_kv_blocks(
         promise = std::move(promise),
         slice = std::move(slice)]() mutable {
          bool ret = d2h_batch_copy(slice);
-          uint32_t success_cnt = 0;
-
-          folly::Promise<uint32_t> store_promise;
-          auto future = store_promise.getSemiFuture();
-
-          batchput_threadpool_.schedule(
-              [this, &slice, promise = std::move(store_promise)]() mutable {
-                promise.setValue(KVCacheStore::get_instance().batch_put(slice));
-              });
-
-          try {
-            auto timeout = std::chrono::seconds(KVSTORE_TIMEOUT);
-            success_cnt = std::move(future).wait(timeout);
-          } catch (const folly::FutureTimeout& e) {
-            LOG(WARNING) << "BatchPut operation timed out";
-          }
-
+          auto success_cnt = offload_to_store(slice);
          if (success_cnt != slice.size()) {
            LOG(WARNING) << "KVCacheStore not all put success: " << success_cnt
                         << "/" << slice.size();
@@ -895,6 +866,7 @@ bool WorkerImpl::d2h_batch_copy(Slice<BlockTransferInfo>& block_transfer_info) {
 #if defined(USE_NPU)
   CHECK(copy_stream_.count(std::this_thread::get_id()) != 0)
       << "WorkerImpl::d2h_batch_copy can only be called in copy_threadpool_.";
+
   const int64_t num_layers = context_.get_model_args().n_layers();
   uint32_t num_batches = block_transfer_info.size() * num_layers * 2;
   void** srcs = new void*[num_batches];
@@ -904,26 +876,25 @@ bool WorkerImpl::d2h_batch_copy(Slice<BlockTransferInfo>& block_transfer_info) {
   size_t attrs_indexes[1] = {0};
   size_t fail_index;
   uint32_t curr_index = 0;
-  for (int layer_id = 0; layer_id < num_layers; layer_id++) {
-    auto src_k_cache = kv_caches_.at(layer_id).get_k_cache();
-    auto dst_k_cache = host_kv_caches_.at(layer_id).get_k_cache();
-    auto src_v_cache = kv_caches_.at(layer_id).get_v_cache();
-    auto dst_v_cache = host_kv_caches_.at(layer_id).get_v_cache();
-
-    for (int idx = 0; idx < block_transfer_info.size(); idx++) {
-      srcs[curr_index] =
-          src_k_cache[block_transfer_info[idx].src_block_id].data_ptr();
-      dsts[curr_index] =
-          dst_k_cache[block_transfer_info[idx].dst_block_id].data_ptr();

+  for (const auto& info : block_transfer_info) {
+    auto dst_k_cache = host_kv_caches_.at(info.dst_block_id).get_k_cache();
+    auto dst_v_cache = host_kv_caches_.at(info.dst_block_id).get_v_cache();
+
+    for (int layer_id = 0; layer_id < num_layers; layer_id++) {
+      auto src_k_cache = kv_caches_.at(layer_id).get_k_cache();
+      auto src_v_cache = kv_caches_.at(layer_id).get_v_cache();
+
+      srcs[curr_index] = src_k_cache[info.src_block_id].data_ptr();
+      dsts[curr_index] = dst_k_cache[layer_id].data_ptr();
       copy_size[curr_index] = key_cache_size_per_layer_;
+
       curr_index++;

-      srcs[curr_index] =
-          src_v_cache[block_transfer_info[idx].src_block_id].data_ptr();
-      dsts[curr_index] =
-          dst_v_cache[block_transfer_info[idx].dst_block_id].data_ptr();
+      srcs[curr_index] = src_v_cache[info.src_block_id].data_ptr();
+      dsts[curr_index] = dst_v_cache[layer_id].data_ptr();
       copy_size[curr_index] = value_cache_size_per_layer_;
+
       curr_index++;
     }
   }
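
Note the loop inversion in d2h_batch_copy: blocks are now the outer loop because the host side is indexed block-first while the device side stays layer-first; the total batch count (blocks * layers * 2) is unchanged, only the host-side addressing flips:

    // old: host side = host_kv_caches_[layer_id].get_k_cache()[dst_block_id]
    // new: host side = host_kv_caches_[dst_block_id].get_k_cache()[layer_id]

h2d_batch_copy below mirrors the same inversion with source and destination swapped.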
@@ -961,6 +932,7 @@ bool WorkerImpl::h2d_batch_copy(Slice<BlockTransferInfo>& block_transfer_info) {
 #if defined(USE_NPU)
   CHECK(copy_stream_.count(std::this_thread::get_id()) != 0)
       << "WorkerImpl::h2d_batch_copy can only be called in copy_threadpool_.";
+
   const int64_t num_layers = context_.get_model_args().n_layers();
   uint32_t num_batches = block_transfer_info.size() * num_layers * 2;
   void** srcs = new void*[num_batches];
@@ -971,24 +943,21 @@ bool WorkerImpl::h2d_batch_copy(Slice<BlockTransferInfo>& block_transfer_info) {
   size_t fail_index;
   uint32_t curr_index = 0;

-  for (int layer_id = 0; layer_id < num_layers; layer_id++) {
-    auto src_k_cache = host_kv_caches_.at(layer_id).get_k_cache();
-    auto dst_k_cache = kv_caches_.at(layer_id).get_k_cache();
-    auto src_v_cache = host_kv_caches_.at(layer_id).get_v_cache();
-    auto dst_v_cache = kv_caches_.at(layer_id).get_v_cache();
-
-    for (int idx = 0; idx < block_transfer_info.size(); idx++) {
-      srcs[curr_index] =
-          src_k_cache[block_transfer_info[idx].src_block_id].data_ptr();
-      dsts[curr_index] =
-          dst_k_cache[block_transfer_info[idx].dst_block_id].data_ptr();
+  for (const auto& info : block_transfer_info) {
+    auto src_k_cache = host_kv_caches_.at(info.src_block_id).get_k_cache();
+    auto src_v_cache = host_kv_caches_.at(info.src_block_id).get_v_cache();
+
+    for (int layer_id = 0; layer_id < num_layers; layer_id++) {
+      auto dst_k_cache = kv_caches_.at(layer_id).get_k_cache();
+      auto dst_v_cache = kv_caches_.at(layer_id).get_v_cache();
+
+      srcs[curr_index] = src_k_cache[layer_id].data_ptr();
+      dsts[curr_index] = dst_k_cache[info.dst_block_id].data_ptr();
       copy_size[curr_index] = key_cache_size_per_layer_;
       curr_index++;

-      srcs[curr_index] =
-          src_v_cache[block_transfer_info[idx].src_block_id].data_ptr();
-      dsts[curr_index] =
-          dst_v_cache[block_transfer_info[idx].dst_block_id].data_ptr();
+      srcs[curr_index] = src_v_cache[layer_id].data_ptr();
+      dsts[curr_index] = dst_v_cache[info.dst_block_id].data_ptr();
       copy_size[curr_index] = value_cache_size_per_layer_;
       curr_index++;
     }
@@ -1022,4 +991,64 @@ bool WorkerImpl::h2d_batch_copy(Slice<BlockTransferInfo>& block_transfer_info) {
   return false;
 }

+uint32_t WorkerImpl::offload_to_store(
+    Slice<BlockTransferInfo>& block_transfer_info) {
+  if (!options_.enable_kvcache_store()) {
+    return block_transfer_info.size();
+  }
+
+  folly::Promise<uint32_t> promise;
+  auto future = promise.getSemiFuture();
+
+  batchput_threadpool_.schedule(
+      [this, &block_transfer_info, promise = std::move(promise)]() mutable {
+        promise.setValue(
+            KVCacheStore::get_instance().batch_put(block_transfer_info));
+      });
+
+  auto timeout = std::chrono::seconds(KVSTORE_TIMEOUT);
+  return std::move(future)
+      .via(folly::getGlobalCPUExecutor())
+      .within(timeout)
+      .thenTry([](folly::Try<uint32_t>&& t) -> uint32_t {
+        if (t.hasValue()) {
+          return t.value();
+        } else {
+          LOG(WARNING) << "BatchPut operation timed out";
+          return 0u;
+        }
+      })
+      .get();
+}
+
+uint32_t WorkerImpl::load_from_store(
+    Slice<BlockTransferInfo>& block_transfer_info) {
+  if (!options_.enable_kvcache_store()) {
+    return 0;
+  }
+
+  folly::Promise<uint32_t> promise;
+  auto future = promise.getSemiFuture();
+
+  batchget_threadpool_.schedule(
+      [this, &block_transfer_info, promise = std::move(promise)]() mutable {
+        promise.setValue(
+            KVCacheStore::get_instance().batch_get(block_transfer_info));
+      });
+
+  auto timeout = std::chrono::seconds(KVSTORE_TIMEOUT);
+  return std::move(future)
+      .via(folly::getGlobalCPUExecutor())
+      .within(timeout)
+      .thenTry([](folly::Try<uint32_t>&& t) -> uint32_t {
+        if (t.hasValue()) {
+          return t.value();
+        } else {
+          LOG(WARNING) << "BatchGet operation timed out";
+          return 0u;
+        }
+      })
+      .get();
+}
+
 } // namespace xllm
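
One behavioral note on the extracted helpers: the old inline code tried to catch folly::FutureTimeout around a blocking wait, while the new chain routes the result through .via(executor).within(timeout).thenTry(...), so a timeout surfaces as an errored folly::Try and is mapped to 0 instead of escaping as an exception; the trailing .get() blocks the caller for the combined outcome. It also centralizes the enable_kvcache_store() check: offloading reports full success when the store is disabled, loading reports zero blocks fetched.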
