Skip to content

Commit d6fba79

Browse files
committed
feat: implement batch prefetch from store.
1 parent 1926248 commit d6fba79

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+492
-155
lines changed

xllm/core/common/global_flags.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -338,14 +338,18 @@ DEFINE_string(store_protocol,
338338
"tcp",
339339
"KV cache store protocol(e.g. tcp, rdma).");
340340

341-
DEFINE_string(store_master_server_entry,
341+
DEFINE_string(store_master_server_address,
342342
"",
343343
"The address information of the store master service.");
344344

345-
DEFINE_string(store_metadata_connstring,
345+
DEFINE_string(store_metadata_server,
346346
"",
347347
"The address of the kv cache store metadata service.");
348348

349+
DEFINE_string(store_local_hostname,
350+
"",
351+
"The local host name of the kv cache store client.");
352+
349353
// --- computation communication parallel config ---
350354

351355
DEFINE_bool(

xllm/core/common/global_flags.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,9 +163,11 @@ DECLARE_bool(enable_kvcache_store);
163163

164164
DECLARE_string(store_protocol);
165165

166-
DECLARE_string(store_master_server_entry);
166+
DECLARE_string(store_master_server_address);
167167

168-
DECLARE_string(store_metadata_connstring);
168+
DECLARE_string(store_metadata_server);
169+
170+
DECLARE_string(store_local_hostname);
169171

170172
DECLARE_bool(enable_multi_stream_parallel);
171173

xllm/core/common/options.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,9 @@ std::string Options::to_string() const {
5353
<< ", enable_cache_upload: " << enable_cache_upload()
5454
<< ", enable_kvcache_store: " << enable_kvcache_store()
5555
<< ", store_protocol: " << store_protocol()
56-
<< ", store_master_server_entry: " << store_master_server_entry()
57-
<< ", store_metadata_connstring: " << store_metadata_connstring()
56+
<< ", store_master_server_address: " << store_master_server_address()
57+
<< ", store_metadata_server: " << store_metadata_server()
58+
<< ", store_local_hostname: " << store_local_hostname()
5859
<< ", enable_multi_stream_parallel: " << enable_multi_stream_parallel()
5960
<< ", enable_continuous_kvcache: " << enable_continuous_kvcache()
6061
<< ", disable_ttft_profiling: " << disable_ttft_profiling()

xllm/core/common/options.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,11 @@ class Options {
141141

142142
PROPERTY(std::string, store_protocol) = "tcp";
143143

144-
PROPERTY(std::string, store_master_server_entry) = "";
144+
PROPERTY(std::string, store_master_server_address) = "";
145145

146-
PROPERTY(std::string, store_metadata_connstring) = "";
146+
PROPERTY(std::string, store_metadata_server) = "";
147+
148+
PROPERTY(std::string, store_local_hostname) = "";
147149

148150
PROPERTY(bool, enable_multi_stream_parallel) = false;
149151

xllm/core/distributed_runtime/comm_channel.cpp

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ limitations under the License.
1818
#include <brpc/controller.h>
1919
#include <glog/logging.h>
2020

21+
#include <future>
22+
2123
namespace xllm {
2224

2325
bool CommChannel::init_brpc(const std::string& server_address) {
@@ -335,6 +337,94 @@ void CommChannel::transfer_kv_blocks(
335337
stub_->TransferBlocks(&cntl, &pb_block_transfer_info, &response, nullptr);
336338
}
337339

340+
class ClientStreamReceiver : public brpc::StreamInputHandler {
341+
private:
342+
const std::atomic<bool>& termination_flag_;
343+
std::shared_ptr<std::atomic<uint32_t>> success_cnt_;
344+
std::promise<void> close_promise_;
345+
std::atomic<bool> promise_set_{false};
346+
347+
public:
348+
ClientStreamReceiver(const std::atomic<bool>& termination_flag,
349+
std::shared_ptr<std::atomic<uint32_t>>& success_cnt)
350+
: termination_flag_(termination_flag), success_cnt_(success_cnt) {}
351+
352+
~ClientStreamReceiver() {
353+
if (!promise_set_.exchange(true)) {
354+
try {
355+
close_promise_.set_value();
356+
} catch (const std::exception& e) {
357+
LOG(WARNING) << "Exception in destructor: " << e.what();
358+
}
359+
}
360+
}
361+
362+
std::future<void> get_close_future() { return close_promise_.get_future(); }
363+
364+
int on_received_messages(brpc::StreamId id,
365+
butil::IOBuf* const messages[],
366+
size_t size) override {
367+
for (size_t i = 0; i < size; ++i) {
368+
std::string msg_str = messages[i]->to_string();
369+
int32_t success_cnt = std::stoi(msg_str);
370+
371+
if (success_cnt > 0 &&
372+
!termination_flag_.load(std::memory_order_acquire)) {
373+
success_cnt_->fetch_add(success_cnt, std::memory_order_relaxed);
374+
} else {
375+
brpc::StreamClose(id);
376+
if (!promise_set_.exchange(true)) {
377+
close_promise_.set_value();
378+
}
379+
break;
380+
}
381+
}
382+
return 0;
383+
}
384+
385+
virtual void on_idle_timeout(brpc::StreamId id) override {
386+
if (!promise_set_.exchange(true)) {
387+
close_promise_.set_value();
388+
}
389+
}
390+
391+
virtual void on_closed(brpc::StreamId id) override {
392+
if (!promise_set_.exchange(true)) {
393+
close_promise_.set_value();
394+
}
395+
}
396+
};
397+
398+
void CommChannel::prefetch_from_storage(
399+
const std::atomic<bool>& flag,
400+
const std::vector<BlockTransferInfo>& block_transfer_info,
401+
std::shared_ptr<std::atomic<uint32_t>>& success_cnt) {
402+
proto::BlockTransferInfos pb_block_transfer_info;
403+
if (!block_transfer_info_to_proto(
404+
0x0, block_transfer_info, &pb_block_transfer_info)) {
405+
return;
406+
}
407+
ClientStreamReceiver receiver(flag, success_cnt);
408+
brpc::Controller cntl;
409+
brpc::StreamOptions stream_options;
410+
brpc::StreamId stream_id;
411+
proto::Status response;
412+
stream_options.handler = &receiver;
413+
if (brpc::StreamCreate(&stream_id, cntl, &stream_options) != 0) {
414+
LOG(ERROR) << "Failed to create stream";
415+
return;
416+
}
417+
418+
stub_->PrefetchFromStorage(
419+
&cntl, &pb_block_transfer_info, &response, nullptr);
420+
421+
if (cntl.Failed()) {
422+
LOG(ERROR) << "Fail to connect stream, " << cntl.ErrorText();
423+
}
424+
425+
receiver.get_close_future().wait();
426+
}
427+
338428
bool CommChannel::get_last_step_result_async(
339429
folly::Promise<std::optional<RawForwardOutput>>& promise) {
340430
proto::Empty req;

xllm/core/distributed_runtime/comm_channel.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,6 @@ class CommChannel {
8787
const uint64_t kv_cache_size,
8888
const std::vector<std::vector<int64_t>>& kv_cache_shape);
8989

90-
virtual bool load_kv_blocks_from_store_async(
91-
const std::vector<CacheBlockInfo>& cache_block_info,
92-
folly::Promise<uint32_t>& promise);
93-
9490
virtual void transfer_kv_blocks(
9591
const std::vector<BlockTransferInfo>& block_transfer_info,
9692
folly::Promise<uint32_t>& promise);
@@ -99,6 +95,11 @@ class CommChannel {
9995
const uint64_t batch_id,
10096
const std::vector<BlockTransferInfo>& block_transfer_info);
10197

98+
virtual void prefetch_from_storage(
99+
const std::atomic<bool>& flag,
100+
const std::vector<BlockTransferInfo>& block_transfer_info,
101+
std::shared_ptr<std::atomic<uint32_t>>& success_cnt);
102+
102103
virtual bool get_last_step_result_async(
103104
folly::Promise<std::optional<RawForwardOutput>>& promise);
104105

xllm/core/distributed_runtime/remote_worker.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ limitations under the License.
3535
#include "util/hash_util.h"
3636

3737
namespace xllm {
38+
3839
RemoteWorker::RemoteWorker(int32_t global_rank,
3940
const std::string& server_address,
4041
const torch::Device& d,
@@ -286,7 +287,7 @@ folly::SemiFuture<uint32_t> RemoteWorker::transfer_kv_blocks(
286287
const std::vector<BlockTransferInfo>& block_transfer_info) {
287288
folly::Promise<uint32_t> promise;
288289
auto future = promise.getSemiFuture();
289-
general_threadpool_.schedule(
290+
copy_threadpool_.schedule(
290291
[this,
291292
block_transfer_info = std::move(block_transfer_info),
292293
promise = std::move(promise)]() mutable {
@@ -298,14 +299,27 @@ folly::SemiFuture<uint32_t> RemoteWorker::transfer_kv_blocks(
298299
void RemoteWorker::transfer_kv_blocks(
299300
const uint64_t batch_id,
300301
const std::vector<BlockTransferInfo>& block_transfer_info) {
301-
general_threadpool_.schedule(
302+
copy_threadpool_.schedule(
302303
[this,
303304
batch_id = batch_id,
304305
block_transfer_info = std::move(block_transfer_info)]() mutable {
305306
channel_->transfer_kv_blocks(batch_id, block_transfer_info);
306307
});
307308
}
308309

310+
void RemoteWorker::prefetch_from_storage(
311+
const std::atomic<bool>& flag,
312+
const std::vector<BlockTransferInfo>& block_transfer_info,
313+
std::shared_ptr<std::atomic<uint32_t>>& success_cnt) {
314+
copy_threadpool_.schedule(
315+
[this,
316+
flag = &flag,
317+
block_transfer_info = std::move(block_transfer_info),
318+
success_cnt = success_cnt]() mutable {
319+
channel_->prefetch_from_storage(flag, block_transfer_info, success_cnt);
320+
});
321+
}
322+
309323
const torch::Device& RemoteWorker::device() const {
310324
LOG(ERROR) << "RemoteWorker Method device is UnImplemented.";
311325
}

xllm/core/distributed_runtime/remote_worker.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,11 @@ class RemoteWorker : public WorkerClient {
117117
const uint64_t batch_id,
118118
const std::vector<BlockTransferInfo>& block_transfer_info) override;
119119

120+
virtual void prefetch_from_storage(
121+
const std::atomic<bool>& flag,
122+
const std::vector<BlockTransferInfo>& block_transfer_info,
123+
std::shared_ptr<std::atomic<uint32_t>>& success_cnt) override;
124+
120125
// Run the model and return the output.
121126
virtual folly::SemiFuture<std::optional<ForwardOutput>> step_async(
122127
const ForwardInput& inputs) override;
@@ -144,9 +149,8 @@ class RemoteWorker : public WorkerClient {
144149
// connection resource
145150
std::unique_ptr<CommChannel> channel_;
146151
ThreadPool threadpool_;
147-
// general working thread
148-
// do some overlap work with model execute
149-
ThreadPool general_threadpool_{4};
152+
// copy working thread
153+
ThreadPool copy_threadpool_{4};
150154
const torch::Device device_;
151155
};
152156
} // namespace xllm

xllm/core/distributed_runtime/worker_service.cpp

Lines changed: 111 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -416,13 +416,12 @@ void WorkerService::PullKVCache(::google::protobuf::RpcController* controller,
416416

417417
void WorkerService::TransferBlocks(
418418
::google::protobuf::RpcController* controller,
419-
const ::xllm::proto::BlockTransferInfos* req,
420-
::xllm::proto::TransferStatus* resp,
419+
const proto::BlockTransferInfos* req,
420+
proto::TransferStatus* resp,
421421
::google::protobuf::Closure* done) {
422422
brpc::ClosureGuard done_guard(done);
423423
std::vector<BlockTransferInfo> block_transfer_info;
424-
uint64_t batch_id;
425-
proto_to_block_transfer_info(*req, batch_id, block_transfer_info);
424+
uint64_t batch_id = proto_to_block_transfer_info(*req, block_transfer_info);
426425

427426
if (batch_id == 0x0) {
428427
resp->set_success_cnt(worker_->transfer_kv_blocks(block_transfer_info));
@@ -432,6 +431,114 @@ void WorkerService::TransferBlocks(
432431
return;
433432
}
434433

434+
class ServerStreamHandler : public brpc::StreamInputHandler {
435+
private:
436+
std::promise<void> close_promise_;
437+
std::atomic<bool> promise_set_{false};
438+
439+
public:
440+
~ServerStreamHandler() {
441+
if (!promise_set_.exchange(true)) {
442+
try {
443+
close_promise_.set_value();
444+
} catch (const std::exception& e) {
445+
LOG(WARNING) << "Exception in destructor: " << e.what();
446+
}
447+
}
448+
}
449+
450+
std::future<void> get_close_future() { return close_promise_.get_future(); }
451+
452+
int on_received_messages(brpc::StreamId id,
453+
butil::IOBuf* const messages[],
454+
size_t size) override {
455+
LOG(WARNING) << "ServerStreamHandler::on_received_messages not implement.";
456+
return 0;
457+
}
458+
459+
void on_closed(brpc::StreamId id) override {
460+
if (!promise_set_.exchange(true)) {
461+
close_promise_.set_value();
462+
}
463+
}
464+
465+
void on_idle_timeout(brpc::StreamId id) override {
466+
if (!promise_set_.exchange(true)) {
467+
LOG(WARNING) << "Stream idle timeout: " << id;
468+
close_promise_.set_value();
469+
}
470+
}
471+
};
472+
473+
void WorkerService::PrefetchFromStorage(
474+
google::protobuf::RpcController* controller,
475+
const proto::BlockTransferInfos* req,
476+
proto::Status* resp,
477+
google::protobuf::Closure* done) {
478+
brpc::ClosureGuard done_guard(done);
479+
brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
480+
481+
auto stream_handler = std::make_unique<ServerStreamHandler>();
482+
auto stream_id = std::make_unique<brpc::StreamId>();
483+
brpc::StreamOptions stream_options;
484+
stream_options.handler = stream_handler.get();
485+
if (brpc::StreamAccept(stream_id.get(), *cntl, &stream_options) != 0) {
486+
resp->set_ok(false);
487+
LOG(ERROR) << "Failed to accept stream!";
488+
return;
489+
}
490+
491+
std::vector<BlockTransferInfo> block_transfer_info;
492+
proto_to_block_transfer_info(*req, block_transfer_info);
493+
494+
copy_threadpool_.schedule(
495+
[this,
496+
block_transfer_info = std::move(block_transfer_info),
497+
stream_id = std::move(stream_id),
498+
stream_handler = std::move(stream_handler)]() mutable {
499+
Slice<BlockTransferInfo> transfer_slice{block_transfer_info};
500+
auto close_future = stream_handler->get_close_future();
501+
bool is_completed = false;
502+
503+
for (size_t i = 0; i < transfer_slice.size();
504+
i += stream_copy_batch_size_) {
505+
auto current_slice = transfer_slice.slice(
506+
i, std::min(i + stream_copy_batch_size_, transfer_slice.size()));
507+
508+
auto success_cnt = worker_->prefetch_from_storage(current_slice);
509+
510+
if (success_cnt != current_slice.size() ||
511+
i + stream_copy_batch_size_ >= transfer_slice.size()) {
512+
is_completed = true;
513+
}
514+
515+
butil::IOBuf buf;
516+
buf.append(std::to_string(success_cnt));
517+
if (brpc::StreamWrite(*stream_id.get(), buf) != 0) {
518+
brpc::StreamClose(*stream_id.get());
519+
is_completed = false;
520+
break;
521+
}
522+
523+
if (is_completed) {
524+
if (success_cnt != 0) {
525+
butil::IOBuf buf_end;
526+
buf_end.append("0");
527+
brpc::StreamWrite(*stream_id.get(), buf_end);
528+
}
529+
break;
530+
}
531+
}
532+
if (is_completed) {
533+
close_future.wait();
534+
}
535+
brpc::StreamClose(*stream_id.get());
536+
});
537+
538+
resp->set_ok(true);
539+
return;
540+
}
541+
435542
void WorkerService::GetDeviceInfo(::google::protobuf::RpcController* controller,
436543
const proto::Empty* req,
437544
proto::DeviceInfo* resp,

0 commit comments

Comments
 (0)