Skip to content

Commit cac63da

Browse files
committed
feat: add page-aligned tensor creator for host KV cache.
1 parent 3078669 commit cac63da

File tree

13 files changed

+191
-105
lines changed

13 files changed

+191
-105
lines changed

xllm/core/distributed_runtime/comm_channel.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -351,11 +351,7 @@ class ClientStreamReceiver : public brpc::StreamInputHandler {
351351

352352
~ClientStreamReceiver() {
353353
if (!promise_set_.exchange(true)) {
354-
try {
355-
close_promise_.set_value();
356-
} catch (const std::exception& e) {
357-
LOG(WARNING) << "Exception in destructor: " << e.what();
358-
}
354+
close_promise_.set_value();
359355
}
360356
}
361357

xllm/core/distributed_runtime/worker_service.cpp

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,12 @@ WorkerService::WorkerService(runtime::Options options,
4040
device_.set_device();
4141
device_.init_device_context();
4242
stream_ = device_.get_stream_from_pool();
43-
threadpool_ = std::make_unique<ThreadPool>(
44-
4, [this]() mutable { device_.set_device(); });
43+
std::vector<folly::Function<void()>> set_device;
44+
set_device.reserve(4);
45+
for (int i = 0; i < 4; i++) {
46+
set_device.emplace_back([this]() mutable { device_.set_device(); });
47+
}
48+
threadpool_ = std::make_unique<ThreadPool>(4, set_device);
4549
}
4650

4751
WorkerService::WorkerService(runtime::Options options,
@@ -54,8 +58,12 @@ WorkerService::WorkerService(runtime::Options options,
5458
device_.set_device();
5559
device_.init_device_context();
5660
stream_ = device_.get_stream_from_pool();
57-
threadpool_ = std::make_unique<ThreadPool>(
58-
4, [this]() mutable { device_.set_device(); });
61+
std::vector<folly::Function<void()>> set_device;
62+
set_device.reserve(4);
63+
for (int i = 0; i < 4; i++) {
64+
set_device.emplace_back([this]() mutable { device_.set_device(); });
65+
}
66+
threadpool_ = std::make_unique<ThreadPool>(4, set_device);
5967
}
6068

6169
WorkerService::~WorkerService() = default;
@@ -442,11 +450,7 @@ class ServerStreamHandler : public brpc::StreamInputHandler {
442450
public:
443451
~ServerStreamHandler() {
444452
if (!promise_set_.exchange(true)) {
445-
try {
446-
close_promise_.set_value();
447-
} catch (const std::exception& e) {
448-
LOG(WARNING) << "Exception in destructor: " << e.what();
449-
}
453+
close_promise_.set_value();
450454
}
451455
}
452456

xllm/core/framework/batch/batch.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ namespace xllm {
3434

3535
struct ModelArgs;
3636

37+
static uint64_t batch_counter_ = 1;
3738
class Batch {
3839
public:
3940
Batch() = default;
@@ -57,7 +58,11 @@ class Batch {
5758

5859
void set_batch_id() {
5960
if (batch_id_ == 0x0) {
60-
batch_id_ = absl::ToUnixMicros(absl::Now());
61+
batch_id_ = batch_counter_;
62+
batch_counter_++;
63+
if (batch_counter_ == UINT64_MAX) {
64+
batch_counter_ = 1;
65+
}
6166
}
6267
}
6368

xllm/core/framework/block/block_manager_impl.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,9 @@ void BlockManagerImpl::deallocate(const Slice<Block>& blocks) {
7070
for (const auto& block : blocks) {
7171
// the block is not shared by other sequence
7272
if (block.is_valid() && block.ref_count() <= 2) {
73-
if (num_used_blocks_ > 0) {
74-
num_used_blocks_.fetch_sub(1, std::memory_order_relaxed);
75-
} else {
73+
auto origin_num_used_blocks =
74+
num_used_blocks_.fetch_sub(1, std::memory_order_relaxed);
75+
if (origin_num_used_blocks < 0) {
7676
LOG(ERROR) << "num_used_blocks_==0 cannot fetch_sub for id:"
7777
<< block.id()
7878
<< ", total block size: " << num_total_blocks();
@@ -84,7 +84,7 @@ void BlockManagerImpl::deallocate(const Slice<Block>& blocks) {
8484
error_msg.append(std::to_string(id)).append(" ");
8585
}
8686
}
87-
LOG(ERROR) << error_msg;
87+
LOG(FATAL) << error_msg;
8888
}
8989
}
9090
}

xllm/core/framework/block/block_manager_pool.cpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -156,13 +156,9 @@ void BlockManagerPool::set_offload_callback(
156156
device_block_mgr_ptr = block_managers_[i].get()](
157157
std::vector<folly::Try<uint32_t>>&& results) {
158158
for (auto&& result : results) {
159-
try {
160-
if (result.value() != host_blocks.size()) {
161-
LOG(FATAL) << "Offload copy fail, expected "
162-
<< host_blocks.size() << ", got " << result.value();
163-
}
164-
} catch (const std::exception& e) {
165-
LOG(FATAL) << "Offload copy fail! Exception caught: " << e.what();
159+
if (result.value() != host_blocks.size()) {
160+
LOG(FATAL) << "Offload copy fail, expected " << host_blocks.size()
161+
<< ", got " << result.value();
166162
}
167163
}
168164
host_block_mgr_ptr->cache(host_blocks);
@@ -212,6 +208,11 @@ bool BlockManagerPool::allocate(Sequence* sequence, size_t num_tokens) {
212208
allocate_shared(sequence);
213209
if (sequence->host_kv_state().num_kv_blocks() == 0) {
214210
allocate_host_shared(sequence);
211+
if (sequence->kv_state().shared_kv_blocks_num() <
212+
sequence->host_kv_state().shared_kv_blocks_num())
213+
LOG(INFO) << "device : host = : "
214+
<< sequence->kv_state().shared_kv_blocks_num() << " : "
215+
<< sequence->host_kv_state().shared_kv_blocks_num();
215216
}
216217
}
217218

xllm/core/framework/kv_cache/kv_cache_store.cpp

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -55,30 +55,18 @@ bool KVCacheStore::init(const StoreConfig& config,
5555
LOG(INFO) << "v_cache_size_per_block: " << v_cache_size_per_block_;
5656

5757
if (config_.protocol == "rdma") {
58-
for (int block = 0; block < host_kv_caches_->size(); block++) {
59-
void* key_cache = static_cast<char*>(
60-
host_kv_caches_->at(block).get_k_cache().data_ptr());
61-
62-
auto register_k_result = client_ptr_->RegisterLocalMemory(
63-
key_cache, k_cache_size_per_block_, "cpu:0", false, false);
64-
65-
if (!register_k_result.has_value()) {
66-
LOG(ERROR) << "Failed to register local memory for key cache: "
67-
<< toString(register_k_result.error());
68-
return false;
69-
}
70-
71-
void* value_cache = static_cast<char*>(
72-
host_kv_caches_->at(block).get_v_cache().data_ptr());
73-
74-
auto register_v_result = client_ptr_->RegisterLocalMemory(
75-
value_cache, v_cache_size_per_block_, "cpu:0", false, false);
76-
77-
if (!register_v_result.has_value()) {
78-
LOG(ERROR) << "Failed to register local memory for value cache: "
79-
<< toString(register_v_result.error());
58+
if (config_.total_size > 0 && config_.tensor_data != nullptr) {
59+
auto result = client_ptr_->RegisterLocalMemory(
60+
config_.tensor_data, config_.total_size, "cpu:0", false, false);
61+
if (!result.has_value()) {
62+
LOG(ERROR) << "Failed to register local memory: "
63+
<< toString(result.error());
8064
return false;
8165
}
66+
} else {
67+
LOG(FATAL) << "rdma must RegisterLocalMemory, but got register size: "
68+
<< config_.total_size
69+
<< ", and data ptr: " << uint64_t(config_.tensor_data);
8270
}
8371
}
8472
is_initialized_ = true;

xllm/core/framework/kv_cache/kv_cache_store.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ struct StoreConfig {
1919
std::string master_server_address = "";
2020
int replica_num = 1;
2121
uint32_t tp_rank = 0;
22+
size_t total_size = 0;
23+
void* tensor_data = nullptr;
2224
};
2325

2426
class KVCacheStore {

xllm/core/framework/model/model_input_params.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,11 @@ limitations under the License.
2727

2828
namespace xllm {
2929

30-
enum class TransferType : uint8_t { G2H = 0, H2D = 1, D2G = 2 };
30+
enum class TransferType : uint8_t {
31+
G2H = 0, // global memory(KVCache store) to host memory(DRAM)
32+
H2D = 1, // host memory(DRAM) to device memory(HBM)
33+
D2G = 2 // host memory(DRAM) to global memory(KVCache store)
34+
};
3135

3236
struct BlockTransferInfo {
3337
int32_t src_block_id = -1;

xllm/core/runtime/params_utils.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -731,7 +731,8 @@ bool block_transfer_info_to_proto(
731731
pb_cache.set_dst_block_id(info.dst_block_id);
732732
pb_cache.set_hash_key(info.hash_key, MURMUR_HASH3_VALUE_LEN);
733733

734-
*pb_block_transfer_info->mutable_transfer_infos()->Add() = pb_cache;
734+
*pb_block_transfer_info->mutable_transfer_infos()->Add() =
735+
std::move(pb_cache);
735736
}
736737
pb_block_transfer_info->set_batch_id(batch_id);
737738
pb_block_transfer_info->set_transfer_type(proto::TransferType(transfer_type));

0 commit comments

Comments
 (0)