From ac92836226f3389b1458b70f63eaec953affe7b1 Mon Sep 17 00:00:00 2001
From: huangweizhe1
Date: Tue, 26 Aug 2025 20:18:34 +0800
Subject: [PATCH 1/2] feat: support multi-priority and unified online/offline
 request scheduling.

---
 xllm/core/common/global_flags.cpp             |   6 +
 xllm/core/common/global_flags.h               |   4 +
 xllm/core/common/metrics.cpp                  |   8 +
 xllm/core/common/metrics.h                    |   4 +
 xllm/core/common/options.h                    |   4 +
 .../disagg_pd_service_impl.cpp                |   5 +-
 xllm/core/framework/block/block_manager.h     |   9 +-
 .../framework/block/block_manager_impl.cpp    |  28 ++
 .../core/framework/block/block_manager_impl.h |   7 +
 .../framework/block/block_manager_pool.cpp    |  10 +
 .../core/framework/block/block_manager_pool.h |   4 +
 .../block/concurrent_block_manager_impl.cpp   |   9 +
 .../block/concurrent_block_manager_impl.h     |   4 +
 xllm/core/framework/request/CMakeLists.txt    |   2 +
 .../framework/request/priority_comparator.cpp |  54 +++
 .../framework/request/priority_comparator.h   |  36 ++
 xllm/core/framework/request/request.cpp       |  10 +-
 xllm/core/framework/request/request.h         |  15 +-
 .../core/framework/request/request_params.cpp |  20 +
 xllm/core/framework/request/request_params.h  |   7 +
 xllm/core/runtime/llm_master.cpp              |   5 +-
 xllm/core/runtime/options.h                   |   4 +
 xllm/core/scheduler/CMakeLists.txt            |   3 +
 .../scheduler/chunked_prefill_scheduler.cpp   |  32 +-
 .../chunked_prefill_scheduler_test.cpp        |   2 +
 xllm/core/scheduler/continuous_scheduler.cpp  | 314 ++++++++++++----
 xllm/core/scheduler/continuous_scheduler.h    |  71 +++-
 .../scheduler/continuous_scheduler_test.cpp   | 341 ++++++++++++++++++
 xllm/core/scheduler/decode_priority_queue.h   | 175 +++++++++
 xllm/core/scheduler/disagg_pd_scheduler.cpp   |  28 +-
 xllm/core/scheduler/disagg_pd_scheduler.h     |   9 +-
 .../scheduler/zero_eviction_scheduler.cpp     |  13 +-
 xllm/core/scheduler/zero_eviction_scheduler.h |   4 +-
 xllm/proto/chat.proto                         |   9 +
 xllm/proto/common.proto                       |  10 +
 xllm/proto/completion.proto                   |   7 +
 xllm/proto/disagg_pd.proto                    |   4 +
 xllm/proto/multimodal.proto                   |   8 +
 xllm/xllm.cpp                                 |   4 +-
 39 files changed, 1155 insertions(+), 134 deletions(-)
 create mode 100644 xllm/core/framework/request/priority_comparator.cpp
 create mode 100644 xllm/core/framework/request/priority_comparator.h
 create mode 100644 xllm/core/scheduler/continuous_scheduler_test.cpp
 create mode 100644 xllm/core/scheduler/decode_priority_queue.h

diff --git a/xllm/core/common/global_flags.cpp b/xllm/core/common/global_flags.cpp
index cd94dbf9..cad4cec1 100644
--- a/xllm/core/common/global_flags.cpp
+++ b/xllm/core/common/global_flags.cpp
@@ -199,3 +199,9 @@ DEFINE_string(etcd_addr, "", "etcd adderss for save instance meta info");
 DEFINE_bool(enable_service_routing, false, "whether to use etcd.");
 
 DEFINE_int32(heart_beat_interval, 3, "heart beat interval");
+
+DEFINE_string(priority_strategy, "FCFS", "priority strategy for requests: FCFS, priority or deadline");
+
+DEFINE_bool(enable_on_preempt_off,
+            true,
+            "whether online requests can preempt offline requests");
diff --git a/xllm/core/common/global_flags.h b/xllm/core/common/global_flags.h
index 563a42c0..49e059b2 100644
--- a/xllm/core/common/global_flags.h
+++ b/xllm/core/common/global_flags.h
@@ -126,3 +126,7 @@ DECLARE_int32(heart_beat_interval);
 DECLARE_int32(chunked_match_frequency);
 
 DECLARE_bool(use_zero_evict);
+
+DECLARE_string(priority_strategy);
+
+DECLARE_bool(enable_on_preempt_off);
diff --git a/xllm/core/common/metrics.cpp b/xllm/core/common/metrics.cpp
index be153b18..7a881219 100644
--- a/xllm/core/common/metrics.cpp
+++ b/xllm/core/common/metrics.cpp
@@ -88,6 +88,14 @@ DEFINE_GAUGE(num_running_requests, "Number of running requests
in scheduler"); DEFINE_GAUGE(num_waiting_requests, "Number of waiting requests in scheduler"); DEFINE_GAUGE(num_preempted_requests, "Number of preempted requests in scheduler"); +DEFINE_GAUGE(num_offd_preempt_off_requests, + "Number of offline decode preempt offline requests in scheduler"); +DEFINE_GAUGE(num_ond_preempt_on_requests, + "Number of online decode preempt online requests in scheduler"); +DEFINE_GAUGE(num_onp_preempt_off_requests, + "Number of online prefill preempt offline requests in scheduler"); +DEFINE_GAUGE(num_ond_preempt_off_requests, + "Number of online decode preempt offline requests in scheduler"); DEFINE_GAUGE(num_running_sequences, "Number of running sequences"); diff --git a/xllm/core/common/metrics.h b/xllm/core/common/metrics.h index 7338e8cc..64bf9c55 100644 --- a/xllm/core/common/metrics.h +++ b/xllm/core/common/metrics.h @@ -149,6 +149,10 @@ DECLARE_GAUGE(num_pending_requests); DECLARE_GAUGE(num_running_requests); DECLARE_GAUGE(num_waiting_requests); DECLARE_GAUGE(num_preempted_requests); +DECLARE_GAUGE(num_offd_preempt_off_requests); +DECLARE_GAUGE(num_ond_preempt_on_requests); +DECLARE_GAUGE(num_onp_preempt_off_requests); +DECLARE_GAUGE(num_ond_preempt_off_requests); DECLARE_GAUGE(num_running_sequences); DECLARE_GAUGE(kv_cache_utilization_perc); DECLARE_GAUGE(num_blocks_in_prefix_cache); diff --git a/xllm/core/common/options.h b/xllm/core/common/options.h index 9ea5a596..950d9e12 100644 --- a/xllm/core/common/options.h +++ b/xllm/core/common/options.h @@ -113,6 +113,10 @@ class Options { PROPERTY(bool, enable_service_routing) = false; PROPERTY(std::optional, tool_call_parser); + + PROPERTY(std::string, priority_strategy) = "FCFS"; + + PROPERTY(bool, enable_on_preempt_off) = true; }; } // namespace xllm diff --git a/xllm/core/distributed_runtime/disagg_pd_service_impl.cpp b/xllm/core/distributed_runtime/disagg_pd_service_impl.cpp index 7e0a76a6..d8d4f3a9 100644 --- a/xllm/core/distributed_runtime/disagg_pd_service_impl.cpp +++ b/xllm/core/distributed_runtime/disagg_pd_service_impl.cpp @@ -94,7 +94,10 @@ std::shared_ptr DisaggPDServiceImpl::generate_request( req.x_request_id(), req.x_request_time(), std::move(req_state), - req.service_req_id()); + req.service_req_id(), + req.offline(), + req.slo_ms(), + req.priority()); // add one sequence, rest will be added by scheduler return new_request; diff --git a/xllm/core/framework/block/block_manager.h b/xllm/core/framework/block/block_manager.h index 8c8e52b2..a5e85742 100644 --- a/xllm/core/framework/block/block_manager.h +++ b/xllm/core/framework/block/block_manager.h @@ -32,10 +32,13 @@ limitations under the License. 
#include "common/metrics.h" #include "common/types.h" #include "framework/prefix_cache/prefix_cache.h" +#include "framework/request/request.h" +#include "framework/request/sequence.h" +#include "scheduler/decode_priority_queue.h" #include "util/timer.h" namespace xllm { - +// class DecodePriorityQueue; class BlockManager { public: struct Options { @@ -59,6 +62,10 @@ class BlockManager { virtual void cache(const Slice& token_ids, const Slice& blocks) = 0; + virtual bool check_if_enough_to_evict( + DecodePriorityQueue* running_queue_to_evict, + Sequence* prefill_sequence, + size_t& num_request_to_evict) = 0; // get merged all dp rank KVCacheEvent virtual void get_merged_kvcache_event(KvCacheEvent* event) const = 0; diff --git a/xllm/core/framework/block/block_manager_impl.cpp b/xllm/core/framework/block/block_manager_impl.cpp index 128402a0..ea686495 100644 --- a/xllm/core/framework/block/block_manager_impl.cpp +++ b/xllm/core/framework/block/block_manager_impl.cpp @@ -30,6 +30,7 @@ BlockManagerImpl::BlockManagerImpl(const Options& options) } size_t total_blocks = options_.num_blocks(); + block_size_ = options_.block_size(); num_free_blocks_ = total_blocks; free_blocks_.reserve(total_blocks); for (int32_t i = 0; i < total_blocks; ++i) { @@ -73,6 +74,33 @@ void BlockManagerImpl::deallocate(const Slice& blocks) { } } +bool BlockManagerImpl::check_if_enough_to_evict( + DecodePriorityQueue* running_queue_to_evict, + Sequence* prefill_sequence, + size_t& num_request_to_evict) { + // check if it's enough when we evict this requests queue + + const size_t num_blocks_needed = + (prefill_sequence->num_tokens() + block_size_ - 1) / block_size_; + size_t num_blocks_can_evict = 0; + // count the number of blocks can be preempted + for (auto it = running_queue_to_evict->rbegin(); + it != running_queue_to_evict->rend(); + ++it) { + std::shared_ptr request_to_preempt = *it; + num_request_to_evict++; + // count the number of blocks belong to the request + for (const auto& seq : request_to_preempt->sequences()) { + num_blocks_can_evict += seq->kv_state().num_kv_blocks(); + } + if ((num_blocks_needed <= num_blocks_can_evict) || + has_enough_blocks(num_blocks_needed - num_blocks_can_evict)) { + return true; + } + } + return false; +} + bool BlockManagerImpl::has_enough_blocks(uint32_t num_blocks) { if (num_blocks <= num_free_blocks_) { return true; diff --git a/xllm/core/framework/block/block_manager_impl.h b/xllm/core/framework/block/block_manager_impl.h index 6e1e5ead..1ce3ba37 100644 --- a/xllm/core/framework/block/block_manager_impl.h +++ b/xllm/core/framework/block/block_manager_impl.h @@ -46,6 +46,10 @@ class BlockManagerImpl : public BlockManager { void get_merged_kvcache_event(KvCacheEvent* event) const override; + bool check_if_enough_to_evict(DecodePriorityQueue* running_queue_to_evict, + Sequence* prefill_sequence, + size_t& num_request_to_evict) override; + size_t num_blocks_in_prefix_cache() const override { if (options_.enable_prefix_cache()) { CHECK(prefix_cache_); @@ -99,6 +103,9 @@ class BlockManagerImpl : public BlockManager { // free block count size_t num_free_blocks_ = 0; + // block size + size_t block_size_ = 0; + // free block list std::vector free_blocks_; }; diff --git a/xllm/core/framework/block/block_manager_pool.cpp b/xllm/core/framework/block/block_manager_pool.cpp index ae835472..dd620d0d 100644 --- a/xllm/core/framework/block/block_manager_pool.cpp +++ b/xllm/core/framework/block/block_manager_pool.cpp @@ -92,6 +92,16 @@ bool BlockManagerPool::allocate(Sequence* sequence) { 
return allocate(sequence, sequence->num_tokens()); } +bool BlockManagerPool::check_if_enough_to_evict( + DecodePriorityQueue* running_queue_to_evict, + Sequence* prefill_sequence, + size_t& num_request_to_evict) { + DCHECK(prefill_sequence != nullptr); + int32_t dp_rank = prefill_sequence->dp_rank(); + return block_managers_[dp_rank]->check_if_enough_to_evict( + running_queue_to_evict, prefill_sequence, num_request_to_evict); +} + bool BlockManagerPool::allocate(std::vector& sequences) { for (auto* sequence : sequences) { DCHECK(sequence != nullptr); diff --git a/xllm/core/framework/block/block_manager_pool.h b/xllm/core/framework/block/block_manager_pool.h index f646646b..d39e0d4b 100644 --- a/xllm/core/framework/block/block_manager_pool.h +++ b/xllm/core/framework/block/block_manager_pool.h @@ -48,6 +48,10 @@ class BlockManagerPool { void get_merged_kvcache_event(KvCacheEvent* event) const; float get_gpu_cache_usage_perc() const; + bool check_if_enough_to_evict(DecodePriorityQueue* running_queue_to_evict, + Sequence* prefill_sequence, + size_t& num_request_to_evict); + std::vector num_blocks_in_prefix_cache() const; std::vector num_free_blocks() const; std::vector num_used_blocks() const; diff --git a/xllm/core/framework/block/concurrent_block_manager_impl.cpp b/xllm/core/framework/block/concurrent_block_manager_impl.cpp index d7414dc1..63a8a523 100644 --- a/xllm/core/framework/block/concurrent_block_manager_impl.cpp +++ b/xllm/core/framework/block/concurrent_block_manager_impl.cpp @@ -43,6 +43,15 @@ void ConcurrentBlockManagerImpl::cache(const Slice& token_ids, BlockManagerImpl::cache(token_ids, blocks); } +bool ConcurrentBlockManagerImpl::check_if_enough_to_evict( + DecodePriorityQueue* running_queue_to_evict, + Sequence* prefill_sequence, + size_t& num_request_to_evict) { + std::lock_guard lock(mutex_); + return BlockManagerImpl::check_if_enough_to_evict( + running_queue_to_evict, prefill_sequence, num_request_to_evict); +} + size_t ConcurrentBlockManagerImpl::num_blocks_in_prefix_cache() const { std::lock_guard lock(mutex_); return BlockManagerImpl::num_blocks_in_prefix_cache(); diff --git a/xllm/core/framework/block/concurrent_block_manager_impl.h b/xllm/core/framework/block/concurrent_block_manager_impl.h index 30233cd0..a4a77a4e 100644 --- a/xllm/core/framework/block/concurrent_block_manager_impl.h +++ b/xllm/core/framework/block/concurrent_block_manager_impl.h @@ -39,6 +39,10 @@ class ConcurrentBlockManagerImpl : public BlockManagerImpl { void cache(const Slice& token_ids, const Slice& blocks) override; + bool check_if_enough_to_evict(DecodePriorityQueue* running_queue_to_evict, + Sequence* prefill_sequence, + size_t& num_request_to_evict) override; + // get the number of blocks in the prefix cache size_t num_blocks_in_prefix_cache() const override; diff --git a/xllm/core/framework/request/CMakeLists.txt b/xllm/core/framework/request/CMakeLists.txt index 1507e185..f0db5a4e 100644 --- a/xllm/core/framework/request/CMakeLists.txt +++ b/xllm/core/framework/request/CMakeLists.txt @@ -18,6 +18,7 @@ cc_library( sequences_group.h request_state.h stopping_checker.h + priority_comparator.h SRCS finish_reason.cpp incremental_decoder.cpp @@ -32,6 +33,7 @@ cc_library( sequences_group.cpp request_state.cpp stopping_checker.cpp + priority_comparator.cpp DEPS :kv_cache :prefix_cache diff --git a/xllm/core/framework/request/priority_comparator.cpp b/xllm/core/framework/request/priority_comparator.cpp new file mode 100644 index 00000000..e7769c75 --- /dev/null +++ 
b/xllm/core/framework/request/priority_comparator.cpp @@ -0,0 +1,54 @@ +#include "priority_comparator.h" + +#include "glog/logging.h" + +namespace xllm { + +// implement operator() +bool FCFSComparator::operator()(const std::shared_ptr& a, + const std::shared_ptr& b) const { + return a->created_time() > b->created_time(); +} + +bool StrictPriorityComparator::operator()( + const std::shared_ptr& a, + const std::shared_ptr& b) const { + auto priority_a = a->priority(); + auto priority_b = b->priority(); + if (priority_a != priority_b) { + return priority_a > priority_b; // HIGH(1) < NORMAL(2) < LOW(3) + } + return a->created_time() > b->created_time(); +} + +bool DeadlineComparator::operator()(const std::shared_ptr& a, + const std::shared_ptr& b) const { + return a->slo_ms() - a->elapsed_seconds() * 1000 > + b->slo_ms() - b->elapsed_seconds() * 1000; +} + +std::function&, + const std::shared_ptr&)> +create_comparator(const std::string& priority_strategy) { + if (priority_strategy == "FCFS") { + return [](const std::shared_ptr& a, + const std::shared_ptr& b) { + return FCFSComparator()(a, b); + }; + } else if (priority_strategy == "priority") { + return [](const std::shared_ptr& a, + const std::shared_ptr& b) { + return StrictPriorityComparator()(a, b); + }; + } else if (priority_strategy == "deadline") { + return [](const std::shared_ptr& a, + const std::shared_ptr& b) { + return DeadlineComparator()(a, b); + }; + } else { + LOG(FATAL) << "Unknown strategy: " << priority_strategy; + return nullptr; + } +} + +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/framework/request/priority_comparator.h b/xllm/core/framework/request/priority_comparator.h new file mode 100644 index 00000000..6d9834da --- /dev/null +++ b/xllm/core/framework/request/priority_comparator.h @@ -0,0 +1,36 @@ +#pragma once +#include +#include +#include + +#include "common.pb.h" +#include "framework/request/request.h" + +namespace xllm { +class PriorityComparator { + public: + virtual bool operator()(const std::shared_ptr& a, + const std::shared_ptr& b) const = 0; + virtual ~PriorityComparator() = default; +}; + +struct FCFSComparator : public PriorityComparator { + bool operator()(const std::shared_ptr& a, + const std::shared_ptr& b) const override; +}; + +struct StrictPriorityComparator : public PriorityComparator { + bool operator()(const std::shared_ptr& a, + const std::shared_ptr& b) const override; +}; + +struct DeadlineComparator : public PriorityComparator { + bool operator()(const std::shared_ptr& a, + const std::shared_ptr& b) const override; +}; + +std::function&, + const std::shared_ptr&)> +create_comparator(const std::string& priority_strategy); + +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/framework/request/request.cpp b/xllm/core/framework/request/request.cpp index 41442432..1fd4a859 100644 --- a/xllm/core/framework/request/request.cpp +++ b/xllm/core/framework/request/request.cpp @@ -33,13 +33,19 @@ Request::Request(const std::string& request_id, const std::string& x_request_id, const std::string& x_request_time, const RequestState& state, - const std::string& service_request_id) + const std::string& service_request_id, + bool offline, + int32_t slo_ms, + xllm::proto::Priority priority) : request_id_(request_id), service_request_id_(service_request_id), x_request_id_(x_request_id), x_request_time_(x_request_time), state_(std::move(state)), - created_time_(absl::Now()) { + created_time_(absl::Now()), + offline_(offline), + priority_(priority), + slo_ms_(slo_ms) 
{ create_sequences_group(); } diff --git a/xllm/core/framework/request/request.h b/xllm/core/framework/request/request.h index 96a04a87..462dab89 100644 --- a/xllm/core/framework/request/request.h +++ b/xllm/core/framework/request/request.h @@ -37,7 +37,10 @@ class Request { const std::string& x_request_id, const std::string& x_request_time, const RequestState& state, - const std::string& service_request_id = ""); + const std::string& service_request_id = "", + bool offline = false, + int32_t slo_ms = 0, + xllm::proto::Priority priority = xllm::proto::Priority::NORMAL); bool finished() const; @@ -81,6 +84,10 @@ class Request { const std::string& x_request_time() const { return x_request_time_; } + const bool offline() const { return offline_; } + const int32_t slo_ms() const { return slo_ms_; } + const xllm::proto::Priority priority() const { return priority_; } + RequestState& state() { return state_; } void update_connection_status(); @@ -108,6 +115,12 @@ class Request { std::atomic cancelled_{false}; + bool offline_; + + int32_t slo_ms_; + + xllm::proto::Priority priority_; + private: void create_sequences_group(); }; diff --git a/xllm/core/framework/request/request_params.cpp b/xllm/core/framework/request/request_params.cpp index cbef1bcb..2552e453 100644 --- a/xllm/core/framework/request/request_params.cpp +++ b/xllm/core/framework/request/request_params.cpp @@ -47,6 +47,15 @@ RequestParams::RequestParams(const proto::CompletionRequest& request, request_id = generate_completion_request_id(); x_request_id = x_rid; x_request_time = x_rtime; + if (request.has_offline()) { + offline = request.offline(); + } + if (request.has_slo_ms()) { + slo_ms = request.slo_ms(); + } + if (request.has_priority()) { + priority = request.priority(); + } if (request.has_service_request_id()) { service_request_id = request.service_request_id(); @@ -186,6 +195,17 @@ void InitFromChatRequest(RequestParams& params, const ChatRequest& request) { if (request.has_request_id()) { params.request_id = request.request_id(); } + + if (request.has_offline()) { + params.offline = request.offline(); + } + if (request.has_slo_ms()) { + params.slo_ms = request.slo_ms(); + } + if (request.has_priority()) { + params.priority = request.priority(); + } + if (request.has_service_request_id()) { params.service_request_id = request.service_request_id(); } diff --git a/xllm/core/framework/request/request_params.h b/xllm/core/framework/request/request_params.h index 11ebfc5d..2e779e9a 100644 --- a/xllm/core/framework/request/request_params.h +++ b/xllm/core/framework/request/request_params.h @@ -22,6 +22,7 @@ limitations under the License. 
#include #include "chat.pb.h" +#include "common.pb.h" #include "common/macros.h" #include "completion.pb.h" #include "core/common/macros.h" @@ -124,6 +125,12 @@ struct RequestParams { std::vector tools; std::string tool_choice = "auto"; bool has_tools() const { return !tools.empty(); } + + bool offline = false; + + int32_t slo_ms = 0; + + xllm::proto::Priority priority = xllm::proto::Priority::NORMAL; }; } // namespace xllm diff --git a/xllm/core/runtime/llm_master.cpp b/xllm/core/runtime/llm_master.cpp index c430ed40..dde08c91 100644 --- a/xllm/core/runtime/llm_master.cpp +++ b/xllm/core/runtime/llm_master.cpp @@ -442,7 +442,10 @@ std::shared_ptr LLMMaster::generate_request( sp.x_request_id, sp.x_request_time, std::move(req_state), - sp.service_request_id); + sp.service_request_id, + sp.offline, + sp.slo_ms, + sp.priority); // add one sequence, rest will be added by scheduler return request; diff --git a/xllm/core/runtime/options.h b/xllm/core/runtime/options.h index d45cf990..e9647a8a 100644 --- a/xllm/core/runtime/options.h +++ b/xllm/core/runtime/options.h @@ -121,6 +121,10 @@ struct Options { // enable service routing mode. PROPERTY(bool, enable_service_routing) = false; + + PROPERTY(std::string, priority_strategy) = "FCFS"; + + PROPERTY(bool, enable_on_preempt_off) = true; }; } // namespace runtime diff --git a/xllm/core/scheduler/CMakeLists.txt b/xllm/core/scheduler/CMakeLists.txt index 52a3b075..6e7e312d 100644 --- a/xllm/core/scheduler/CMakeLists.txt +++ b/xllm/core/scheduler/CMakeLists.txt @@ -12,6 +12,7 @@ cc_library( async_response_processor.h scheduler.h scheduler_factory.h + decode_priority_queue.h SRCS chunked_prefill_scheduler.cpp zero_eviction_scheduler.cpp @@ -32,8 +33,10 @@ cc_library( cc_test( NAME chunked_prefill_scheduler_test + continuous_scheduler_test SRCS chunked_prefill_scheduler_test.cpp + continuous_scheduler_test.cpp DEPS :scheduler GTest::gtest_main diff --git a/xllm/core/scheduler/chunked_prefill_scheduler.cpp b/xllm/core/scheduler/chunked_prefill_scheduler.cpp index 2c56714e..aa0e4773 100644 --- a/xllm/core/scheduler/chunked_prefill_scheduler.cpp +++ b/xllm/core/scheduler/chunked_prefill_scheduler.cpp @@ -37,8 +37,8 @@ ChunkedPrefillScheduler::~ChunkedPrefillScheduler() { } // release all requests in the running priority queue - while (!running_queue_.empty()) { - running_queue_.pop_front(); + while (!running_queue_->empty()) { + running_queue_->pop_top(); } } @@ -52,7 +52,7 @@ void ChunkedPrefillScheduler::handle_abnormal_request( size_t& remaining_seq_budget, bool budget_exhausted, bool blocks_exhausted) { - std::shared_ptr request = running_queue_.front(); + std::shared_ptr request = running_queue_->top(); if (candidate_sequences.empty()) { if (!running_sequences_.empty()) { return; @@ -73,10 +73,10 @@ void ChunkedPrefillScheduler::handle_abnormal_request( LOG(ERROR) << "Request prompt is too long, please set a larger " "max_tokens value via --max_tokens_per_batch."; } else { - CHECK(running_queue_.size() == 1) + CHECK(running_queue_->size() == 1) << "Running queue size is not 1, there maybe a bug of request " "preemption logic. running_queue_.size =" - << running_queue_.size(); + << running_queue_->size(); if (util::sum(block_manager_pool_->num_used_blocks()) != request->total_num_blocks()) { // blocks_exhausted is true. @@ -90,7 +90,7 @@ void ChunkedPrefillScheduler::handle_abnormal_request( } // request is too long, budget or memory no enough. 
-    running_queue_.pop_front();
+    running_queue_->pop_top();
     block_manager_pool_->deallocate(request.get());
     response_processor_->process_failed_request(
         request,
         {StatusCode::RESOURCE_EXHAUSTED,
          "No enough resource to schedule a single sequence"});
   } else {
     // partially schedule the sequences in request
-    running_queue_.pop_front();
+    running_queue_->pop_top();
     running_requests_.emplace_back(request);
     running_sequences_.insert(running_sequences_.end(),
                               candidate_sequences.begin(),
@@ -119,10 +119,10 @@ void ChunkedPrefillScheduler::handle_running_queue_requests(
     std::vector<Sequence*>& prefill_stage_sequences,
     bool& budget_exhausted,
     bool& blocks_exhausted) {
-  while (!running_queue_.empty() &&
+  while (!running_queue_->empty() &&
          remaining_token_budget > options_.num_speculative_tokens() &&
          remaining_seq_budget > 0) {
-    std::shared_ptr<Request> request(running_queue_.front());
+    std::shared_ptr<Request> request(running_queue_->top());
     // TODO: check if request is timeout
     const size_t num_sequences = request->sequences().size();
@@ -186,7 +186,7 @@ void ChunkedPrefillScheduler::handle_running_queue_requests(
     if (has_enough_budget && has_enough_blocks) {
       // remove the request from the priority queue
-      running_queue_.pop_front();
+      running_queue_->pop_top();
       // add the request to the batch
       running_requests_.emplace_back(request);
       running_sequences_.insert(running_sequences_.end(),
@@ -216,13 +216,13 @@
       // memory exhausted, preempt lowest priority request and retry.
       // preemptable_requests_ only contain decoding requests.
-      if (running_queue_.size() > 1) {
-        std::shared_ptr<Request> request_to_preempt = running_queue_.back();
+      if (running_queue_->size() > 1) {
+        std::shared_ptr<Request> request_to_preempt = running_queue_->back();
         if (request_to_preempt.get() != request.get()) {
           ++num_preempted_requests;
           block_manager_pool_->deallocate(request_to_preempt.get());
-          running_queue_.pop_back();
+          running_queue_->pop_back();
           // add preemptable request to waiting priority queue
           request_to_preempt->set_preempted();
           waiting_priority_queue_.push(request_to_preempt);
@@ -343,7 +343,7 @@ void ChunkedPrefillScheduler::handle_prefill_requests(
   }
 
   if (running_sequences_.empty() && !waiting_priority_queue_.empty() &&
-      running_queue_.empty() &&
+      running_queue_->empty() &&
       block_manager_pool_->kv_cache_utilization() == 0) {
     LOG(ERROR) << "Request prompt is too long, no enough memory to schedule "
                   "a single sequence";
@@ -463,7 +463,7 @@ std::vector<Batch> ChunkedPrefillScheduler::prepare_batch() {
     }
 
     // push the request front to the priority deque
-    running_queue_.push_front(request);
+    running_queue_->push(request, false /*if_back*/);
   }
 
   // clear previous batch
@@ -551,7 +551,7 @@
             pending_requests_.load(std::memory_order_relaxed));
   GAUGE_SET(num_running_requests, running_requests_.size());
   GAUGE_SET(num_waiting_requests,
-            waiting_priority_queue_.size() + running_queue_.size());
+            waiting_priority_queue_.size() + running_queue_->size());
   GAUGE_SET(num_preempted_requests, num_preempted_requests);
 
   GAUGE_SET(num_running_sequences, running_sequences_.size());
diff --git a/xllm/core/scheduler/chunked_prefill_scheduler_test.cpp b/xllm/core/scheduler/chunked_prefill_scheduler_test.cpp
index 141c0b6c..1d2db983 100644
--- a/xllm/core/scheduler/chunked_prefill_scheduler_test.cpp
+++ b/xllm/core/scheduler/chunked_prefill_scheduler_test.cpp
@@ -309,6 +309,7 @@ TEST(ChunkedPrefillSchedulerTest, NormalSchedule) {
 // test
preempt TEST(ChunkedPrefillSchedulerTest, PreemptSchedule) { // set max free blocks: 9, support 9*32=288 tokens + // actually only 8 free blocks , because default 1 block is for padding int block_num = 9; int block_size = 32; int max_tokens_per_chunk_for_prefill = 1024; @@ -347,6 +348,7 @@ TEST(ChunkedPrefillSchedulerTest, PreemptSchedule) { int free_blocks_after_preempt = util::max(block_manager_pool->num_free_blocks()); EXPECT_TRUE(free_blocks_after_preempt > free_blocks_before_preempt); + EXPECT_TRUE(scheduler->get_waiting_requests_num() == 1); // append a new block block_manager_pool->allocate(batch[0][0]); // remove preempted request from running_requests diff --git a/xllm/core/scheduler/continuous_scheduler.cpp b/xllm/core/scheduler/continuous_scheduler.cpp index 75fd223b..da1204e9 100644 --- a/xllm/core/scheduler/continuous_scheduler.cpp +++ b/xllm/core/scheduler/continuous_scheduler.cpp @@ -26,9 +26,11 @@ limitations under the License. #include "common/metrics.h" #include "framework/batch/batch_factory.h" +#include "framework/request/priority_comparator.h" #include "framework/request/request.h" #include "framework/request/sequence.h" #include "runtime/engine.h" +#include "scheduler/decode_priority_queue.h" #include "util/utils.h" namespace xllm { @@ -37,7 +39,12 @@ constexpr size_t kRequestQueueSize = 100000; } // namespace ContinuousScheduler::ContinuousScheduler(Engine* engine, const Options& options) - : options_(options), engine_(engine), request_queue_(kRequestQueueSize) { + : options_(options), + engine_(engine), + request_queue_(kRequestQueueSize), + waiting_priority_queue_(create_comparator(options.priority_strategy())), + waiting_priority_queue_offline_( + create_comparator(options.priority_strategy())) { CHECK(engine_ != nullptr); block_manager_pool_ = engine_->block_manager_pool(); CHECK(block_manager_pool_ != nullptr); @@ -50,7 +57,7 @@ ContinuousScheduler::ContinuousScheduler(Engine* engine, const Options& options) options_.instance_role(), options_.enable_schedule_overlap(), options_.enable_decode_response_to_service()); - + create_running_queue(options); if (options_.enable_service_routing()) { XServiceClient::get_instance()->set_scheduler(this); } @@ -69,9 +76,31 @@ bool ContinuousScheduler::add_request(std::shared_ptr& request) { return false; } +void ContinuousScheduler::create_running_queue(const Options& options) { + if (options.priority_strategy() == "FCFS") { + running_queue_offline_ = std::make_unique(); + running_queue_ = std::make_unique(); + } else { + std::unique_ptr comparator; + if (options.priority_strategy() == "deadline") { + comparator = std::make_unique(); + } else if (options.priority_strategy() == "priority") { + comparator = std::make_unique(); + } else { + LOG(FATAL) << "Unknown strategy: " << options.priority_strategy(); + } + running_queue_ = + std::make_unique(std::move(comparator)); + running_queue_offline_ = + std::make_unique(std::move(comparator)); + } +} + void ContinuousScheduler::handle_prefill_requests( size_t& remaining_token_budget, size_t& remaining_seq_budget, + RequestPriorityQueue& waiting_priority_queue, + size_t& num_onp_preempt_off_requests, std::vector>& finished_requests) { // Handle new request prompt first. // Include those requests that are preempted by others. @@ -83,17 +112,17 @@ void ContinuousScheduler::handle_prefill_requests( // // NOTE: preempted requests will be pushed in waiting_priority_queue, // they may contian many sequences, so we should check here. 
-  while (!waiting_priority_queue_.empty() && remaining_seq_budget > 0 &&
+  while (!waiting_priority_queue.empty() && remaining_seq_budget > 0 &&
          remaining_token_budget > 0 &&
          block_manager_pool_->kv_cache_utilization() <
              FLAGS_prefill_scheduling_memory_usage_threshold) {
-    std::shared_ptr<Request> request(waiting_priority_queue_.top());
+    std::shared_ptr<Request> request(waiting_priority_queue.top());
     if (request->finished() || request->cancelled()) {
       block_manager_pool_->deallocate(request.get());
       // release the ownership of the request
       finished_requests.emplace_back(request);
       // remove the request from the priority queue
-      waiting_priority_queue_.pop();
+      waiting_priority_queue.pop();
       continue;
     }
@@ -105,6 +134,8 @@
 
     // TODO: FIXME later
     // Optimization of the scheduling algorithm under multiple sequences
+    // TODO: refactor like handle_decode_requests; otherwise a request with
+    // multiple long sequences may get stuck when n>1
     size_t allocated_tokens = 0;
     size_t allocated_seqs = 0;
     bool can_schedule = true;
@@ -124,12 +155,45 @@
         break;
       }
 
+      // preempt offline decode
       if (!block_manager_pool_->allocate(prefill_sequence.get())) {
-        block_manager_pool_->deallocate(prefill_sequence.get());
         can_schedule = false;
-        break;
+        if (options_.enable_on_preempt_off() && !request->offline() &&
+            !running_queue_offline_->empty()) {
+          size_t num_request_to_evict = 0;
+          // check, based on the prefill_sequence's token count, whether
+          // enough blocks can be freed for it by evicting offline requests
+          bool enough_to_evict = block_manager_pool_->check_if_enough_to_evict(
+              running_queue_offline_.get(),
+              prefill_sequence.get(),
+              num_request_to_evict);
+          if (enough_to_evict) {
+            for (size_t i = 0; i < num_request_to_evict; ++i) {
+              std::shared_ptr<Request> request_to_preempt =
+                  running_queue_offline_->back();
+              ++num_onp_preempt_off_requests;
+              block_manager_pool_->deallocate(request_to_preempt.get());
+              running_queue_offline_->pop_back();
+              // add preemptable request to waiting priority queue
+              // TO IMPROVE?: skip this offline request in the current batch
+              request_to_preempt->set_preempted();
+              waiting_priority_queue_offline_.push(request_to_preempt);
+            }
+            if (!block_manager_pool_->allocate(prefill_sequence.get())) {
+              LOG(ERROR) << "Should be able to allocate after preempting "
+                         << num_request_to_evict
+                         << " offline requests, but failed.";
+              can_schedule = false;
+            } else {
+              can_schedule = true;
+            }
+          }
+        }
+        if (!can_schedule) {
+          block_manager_pool_->deallocate(prefill_sequence.get());
+          break;
+        }
       }
-
       prefill_sequences_budget.emplace_back(num_tokens);
       prefill_sequences.emplace_back(prefill_sequence.get());
       allocated_tokens += num_tokens;
@@ -143,14 +207,13 @@
       }
       break;
     }
-
     if (prefill_sequences.empty()) {
       continue;
     }
 
     remaining_token_budget -= allocated_tokens;
     remaining_seq_budget -= allocated_seqs;
-    waiting_priority_queue_.pop();
+    waiting_priority_queue.pop();
     running_requests_.emplace_back(request);
     running_sequences_.insert(running_sequences_.end(),
                               prefill_sequences.begin(),
@@ -159,10 +222,9 @@
                                   prefill_sequences_budget.begin(),
                                   prefill_sequences_budget.end());
   }
-
-  if (running_sequences_.empty() && !waiting_priority_queue_.empty() &&
-      running_queue_.empty() &&
-      block_manager_pool_->kv_cache_utilization() == 0) {
+  // maybe we can pre-compute whether the prompt exceeds the max length
+  if (running_sequences_.empty() && !waiting_priority_queue.empty() &&
+      running_queue_->empty() && running_queue_offline_->empty()) {
     LOG(ERROR) << "Request prompt is too long, no enough memory to schedule "
                   "a single sequence.";
     // no enough memory to schedule single sequence, just finish the request
@@ -183,20 +245,14 @@
 void ContinuousScheduler::handle_decode_requests(
     size_t& remaining_token_budget,
     size_t& remaining_seq_budget,
-    size_t& num_preempted_requests) {
-  // Do nothing: have new prefill requests to handle, or have no running
-  // requests
-  if (!running_sequences_.empty() || running_queue_.empty()) {
-    return;
-  }
-
-  // Handle decoding requests.
-  // no prefill request, schedule the decode requests in the running priority
-  // queue
-  while (!running_queue_.empty() &&
+    size_t& num_offd_preempt_off_requests,
+    size_t& num_ond_preempt_on_requests,
+    size_t& num_ond_preempt_off_requests,
+    std::unique_ptr<DecodePriorityQueue>& running_queue) {
+  while (!running_queue->empty() &&
          remaining_token_budget > options_.num_speculative_tokens() &&
          remaining_seq_budget > 0) {
-    std::shared_ptr<Request> request = running_queue_.front();
+    std::shared_ptr<Request> request = running_queue->top();
     // TODO: check if request is timeout
     const size_t num_sequences = request->sequences().size();
@@ -221,9 +277,9 @@
         has_enough_budget = false;
         break;
       }
-
+      // the sequence's token has already been appended
       size_t updated_num_tokens =
-          sequence->num_tokens() + options_.num_speculative_tokens() + 1;
+          sequence->num_tokens() + options_.num_speculative_tokens();
       // no blocks left
       if (!block_manager_pool_->allocate(sequence.get(), updated_num_tokens)) {
         has_enough_blocks = false;
@@ -246,9 +302,9 @@
     // schedule candidates in the request if there are enough blocks
     if (has_enough_budget && has_enough_blocks) {
       // remove the request from the priority queue
-      running_queue_.pop_front();
+      running_queue->pop_top();
       // add the request to the batch
-      running_requests_.push_back(request);
+      running_requests_.emplace_back(request);
       running_sequences_.insert(running_sequences_.end(),
                                 candidate_sequences.begin(),
                                 candidate_sequences.end());
@@ -263,7 +319,8 @@
 
     // budget exhausted, do partially schedule the request
     if (!has_enough_budget) {
-      handle_abnormal_request(candidate_sequences,
+      handle_abnormal_request(running_queue,
+                              candidate_sequences,
                               candidate_token_budgets,
                               allocated_tokens,
                               allocated_seqs,
@@ -275,13 +332,31 @@
     }
 
     // memory exhausted, try to preempt lowest priority request
-    if (running_queue_.size() > 1) {
-      std::shared_ptr<Request> request_to_preempt = running_queue_.back();
-
+    // keep evicting until enough blocks are freed or no request can be
+    // preempted. TO IMPROVE: pre-check whether eviction frees enough; if
+    // not, skip evicting the offline or lowest-priority online request
+    if (options_.enable_on_preempt_off() && !request->offline() &&
+        !running_queue_offline_->empty()) {
+      std::shared_ptr<Request> request_to_preempt =
+          running_queue_offline_->back();
+      ++num_ond_preempt_off_requests;
+      block_manager_pool_->deallocate(request_to_preempt.get());
+      running_queue_offline_->pop_back();
+      // add preemptable request to waiting priority queue
+      request_to_preempt->set_preempted();
+      waiting_priority_queue_offline_.push(request_to_preempt);
+      continue;
+    } else if (running_queue->size() > 1) {
+      std::shared_ptr<Request> request_to_preempt = running_queue->back();
       if (request_to_preempt.get() != request.get()) {
-        ++num_preempted_requests;
+        if (request->offline()) {
+          ++num_offd_preempt_off_requests;
+        } else {
+          ++num_ond_preempt_on_requests;
+        }
+        // TO IMPROVE: offload kv cache to cpu
         block_manager_pool_->deallocate(request_to_preempt.get());
-        running_queue_.pop_back();
+        running_queue->pop_back();
        // add preemptable request to waiting priority queue
        request_to_preempt->set_preempted();
        waiting_priority_queue_.push(request_to_preempt);
@@ -293,7 +368,8 @@
     }
 
     // no requests left to preempt
-    handle_abnormal_request(candidate_sequences,
+    handle_abnormal_request(running_queue,
+                            candidate_sequences,
                             candidate_token_budgets,
                             allocated_tokens,
                             allocated_seqs,
@@ -307,6 +383,7 @@
 // NOTE: refactor ChunkedPrefillScheduler and ContinuousScheduler later.
 void ContinuousScheduler::handle_abnormal_request(
+    std::unique_ptr<DecodePriorityQueue>& running_queue,
     const std::vector<Sequence*>& candidate_sequences,
     const std::vector<size_t>& candidate_token_budgets,
     const size_t& allocated_tokens,
@@ -315,7 +392,7 @@
     size_t& remaining_seq_budget,
     bool budget_exhausted,
     bool blocks_exhausted) {
-  std::shared_ptr<Request> request = running_queue_.front();
+  std::shared_ptr<Request> request = running_queue->top();
   if (candidate_sequences.empty()) {
     if (!running_sequences_.empty()) {
       return;
@@ -336,10 +413,10 @@
       LOG(ERROR) << "Request prompt is too long, please set a larger "
                     "max_tokens value via --max_tokens_per_batch.";
     } else {
-      CHECK(running_queue_.size() == 1)
+      CHECK(running_queue->size() == 1)
           << "Running queue size is not 1, there maybe a bug of request "
             "preemption logic. running_queue_.size ="
-          << running_queue_.size();
+          << running_queue->size();
       if (util::sum(block_manager_pool_->num_used_blocks()) !=
           request->total_num_blocks()) {
         // blocks_exhausted is true.
@@ -353,7 +430,7 @@
     }
 
     // request is too long, budget or memory no enough.
-    running_queue_.pop_front();
+    running_queue->pop_top();
     block_manager_pool_->deallocate(request.get());
     response_processor_->process_failed_request(
         request,
@@ -361,7 +438,7 @@
          "No enough resource to schedule a single sequence"});
   } else {
     // partially schedule the sequences in request
-    running_queue_.pop_front();
+    running_queue->pop_top();
     running_requests_.emplace_back(request);
     running_sequences_.insert(running_sequences_.end(),
                               candidate_sequences.begin(),
@@ -406,13 +483,20 @@ std::vector<Batch> ContinuousScheduler::prepare_batch() {
       CHECK(request);
 
       // expand sequences to the target number if prefix cache is disabled.
+      // TODO: when prefix cache is disabled, we could expand sequences once
+      // one sequence's prefill finishes and share its blocks among them;
+      // otherwise n>1 without prefix cache is inefficient.
       if (!enable_prefix_cache_) {
         // expand sequences to the target number
         request->expand_sequences(false);
       }
 
       if (request->sequences()[0]->kv_state().kv_cache_tokens_num() == 0) {
-        waiting_priority_queue_.push(request);
+        if (request->offline()) {
+          waiting_priority_queue_offline_.push(request);
+        } else {
+          waiting_priority_queue_.push(request);
+        }
       } else {
         // request from prefill instance in disagge pd mode.
        running_requests_.emplace_back(request);
@@ -436,45 +520,80 @@ std::vector<Batch> ContinuousScheduler::prepare_batch() {
       *it = nullptr;
     }
   }
-
-  if (last_step_prefill_) {
-    // insert all requests to the back of running_queue_
-    // 1. last step is prefill step:
-    // new prefill has high priority, but these requests has lower priority
-    // then existed requests in running_queue_ in decoding stage.
-    // so we need to push them to the back of running_queue_.
-    for (auto it = running_requests_.begin(); it != running_requests_.end();
-         ++it) {
-      // finished request is set to nullptr
-      if (*it == nullptr) {
-        continue;
-      }
-      handle_running_requests(*it);
-      running_queue_.push_back(*it);
-    }
-  } else {
-    // insert all requests to the front of running_queue_
-    // 2. last step is decode step:
-    // We need to traverse running_requests_ array in reverse order.
-    // Because there may be some unexecuted requests with
-    // lower priorities remaining in the running_queue_.
-    // For the requests in running_requests_,
-    // their priorities are all higher than those of the
-    // remaining requests. Therefore, the `push_front`
-    // method needs to be used.
-    //
-    for (auto it = running_requests_.rbegin(); it != running_requests_.rend();
-         ++it) {
-      // finished request is set to nullptr
-      if (*it == nullptr) {
-        continue;
-      }
-      handle_running_requests(*it);
-      running_queue_.push_front(*it);
-    }
-  }
+  // process previous batch
+  // insert running requests back to the running queue, iterating from
+  // the highest priority to the lowest
+  // 1. last step is prefill step:
+  //    new prefill requests have high priority, but they have lower priority
+  //    than the existing decoding requests in running_queue_,
+  //    so we need to push them to the back of running_queue_.
+  // 2. last step is decode step:
+  //    we need to traverse the running_requests_ array in reverse order,
+  //    because there may be some unexecuted requests with
+  //    lower priorities remaining in the running_queue_.
+  //    For the requests in running_requests_,
+  //    their priorities are all higher than those of the
+  //    remaining requests. Therefore, insert all requests to the front of
+  //    running_queue_.
+  if (options_.priority_strategy() == "FCFS") {
+    if (last_step_prefill_) {
+      // insert all requests to the back of running_queue_ (case 1 above)
+      for (auto it = running_requests_.begin(); it != running_requests_.end();
+           ++it) {
+        // finished request is set to nullptr
+        if (*it == nullptr) {
+          continue;
+        }
+        handle_running_requests(*it);
+        if ((*it)->offline()) {
+          running_queue_offline_->push(*it, last_step_prefill_);
+        } else {
+          running_queue_->push(*it, last_step_prefill_);
+        }
+      }
+    } else {
+      // insert all requests to the front of running_queue_ (case 2 above)
+      for (auto it = running_requests_.rbegin();
+           it != running_requests_.rend();
+           ++it) {
+        // finished request is set to nullptr
+        if (*it == nullptr) {
+          continue;
+        }
+        handle_running_requests(*it);
+        if ((*it)->offline()) {
+          running_queue_offline_->push(*it, last_step_prefill_);
+        } else {
+          running_queue_->push(*it, last_step_prefill_);
+        }
+      }
+    }
+  } else {
+    // priority/deadline strategies: the queues re-sort on every push
+    for (auto it = running_requests_.begin(); it != running_requests_.end();
+         ++it) {
+      if (*it == nullptr) {
+        continue;
+      }
+      handle_running_requests(*it);
+      if ((*it)->offline()) {
+        running_queue_offline_->push(*it);
+      } else {
+        running_queue_->push(*it);
+      }
+    }
+  }
@@ -488,13 +607,43 @@
   size_t remaining_token_budget = options_.max_tokens_per_batch();
   size_t remaining_seq_budget = std::max(options_.max_seqs_per_batch(), 1);
   size_t num_preempted_requests = 0;
+  size_t num_offd_preempt_off_requests = 0;
+  size_t num_ond_preempt_on_requests = 0;
+  size_t num_onp_preempt_off_requests = 0;
+  size_t num_ond_preempt_off_requests = 0;
+  // TO IMPROVE?: handle online decode requests before prefilling offline ones
+  handle_prefill_requests(remaining_token_budget,
+                          remaining_seq_budget,
+                          waiting_priority_queue_,
+                          num_onp_preempt_off_requests,
+                          finished_requests);
+  handle_prefill_requests(remaining_token_budget,
+                          remaining_seq_budget,
+                          waiting_priority_queue_offline_,
+                          num_onp_preempt_off_requests,
+                          finished_requests);
+
+  if (running_sequences_.empty()) {
+    // Handle decoding requests.
+    // no prefill request, schedule the decode requests in the running
+    // priority queue
+    handle_decode_requests(remaining_token_budget,
+                           remaining_seq_budget,
+                           num_offd_preempt_off_requests,
+                           num_ond_preempt_on_requests,
+                           num_ond_preempt_off_requests,
+                           running_queue_);
+    handle_decode_requests(remaining_token_budget,
+                           remaining_seq_budget,
+                           num_offd_preempt_off_requests,
+                           num_ond_preempt_on_requests,
+                           num_ond_preempt_off_requests,
+                           running_queue_offline_);
+  }
 
-  handle_prefill_requests(
-      remaining_token_budget, remaining_seq_budget, finished_requests);
-
-  handle_decode_requests(
-      remaining_token_budget, remaining_seq_budget, num_preempted_requests);
-
+  num_preempted_requests =
+      num_offd_preempt_off_requests + num_ond_preempt_on_requests +
+      num_ond_preempt_off_requests + num_onp_preempt_off_requests;
   if (!finished_requests.empty()) {
     response_processor_->process_completed_requests(finished_requests);
   }
@@ -512,8 +661,13 @@
             pending_requests_.load(std::memory_order_relaxed));
   GAUGE_SET(num_running_requests, running_requests_.size());
   GAUGE_SET(num_waiting_requests,
-            waiting_priority_queue_.size() + running_queue_.size());
-  GAUGE_SET(num_preempted_requests, num_preempted_requests);
+            waiting_priority_queue_.size() + running_queue_->size());
+
+  GAUGE_ADD(num_preempted_requests, num_preempted_requests);
+  GAUGE_ADD(num_offd_preempt_off_requests, num_offd_preempt_off_requests);
+  GAUGE_ADD(num_ond_preempt_on_requests, num_ond_preempt_on_requests);
+  GAUGE_ADD(num_onp_preempt_off_requests, num_onp_preempt_off_requests);
+  GAUGE_ADD(num_ond_preempt_off_requests, num_ond_preempt_off_requests);
 
   GAUGE_SET(num_running_sequences, running_sequences_.size());
@@ -540,7 +694,8 @@ std::vector<Batch> ContinuousScheduler::schedule_request(
       return batch;
     }
 
-    if (!waiting_priority_queue_.empty() || !running_queue_.empty()) {
+    if (!waiting_priority_queue_.empty() || !running_queue_->empty() ||
+        !waiting_priority_queue_offline_.empty()) {
       continue;
     }
diff --git a/xllm/core/scheduler/continuous_scheduler.h b/xllm/core/scheduler/continuous_scheduler.h
index 754475a7..74f8b7ce 100644
--- a/xllm/core/scheduler/continuous_scheduler.h
+++
b/xllm/core/scheduler/continuous_scheduler.h @@ -27,21 +27,16 @@ limitations under the License. #include "common/types.h" #include "framework/batch/batch.h" #include "framework/block/block_manager_pool.h" +#include "framework/request/priority_comparator.h" #include "framework/request/request.h" #include "framework/request/sequence.h" #include "runtime/xservice_client.h" #include "scheduler.h" +#include "scheduler/decode_priority_queue.h" namespace xllm { class Engine; - -struct RequestComparator { - bool operator()(std::shared_ptr a, - std::shared_ptr b) const { - return a->created_time() > b->created_time(); - } -}; - +class DecodePriorityQueue; class ContinuousScheduler : public Scheduler { public: struct Options { @@ -87,6 +82,11 @@ class ContinuousScheduler : public Scheduler { PROPERTY(bool, enable_chunked_prefill) = true; PROPERTY(bool, enable_service_routing) = false; + + // TODO: think if distinguish prefill and decode priority strategy + PROPERTY(std::string, + priority_strategy) = "FCFS"; // priority, deadline, FCFS + PROPERTY(bool, enable_on_preempt_off) = true; }; ContinuousScheduler(Engine* engine, const Options& options); @@ -113,7 +113,25 @@ class ContinuousScheduler : public Scheduler { } virtual uint32_t get_waiting_requests_num() const override { - return waiting_priority_queue_.size(); + return waiting_priority_queue_.size() + + waiting_priority_queue_offline_.size(); + } + // for test only + std::vector prepare_batch_test() { return prepare_batch(); } + std::vector> get_running_requests() { + return running_requests_; + } + std::vector> get_waiting_requests() { + std::vector> result; + + auto temp_queue = waiting_priority_queue_; + + while (!temp_queue.empty()) { + result.push_back(temp_queue.top()); + temp_queue.pop(); + } + + return result; } protected: @@ -160,9 +178,11 @@ class ContinuousScheduler : public Scheduler { using RequestPriorityQueue = std::priority_queue, std::vector>, - RequestComparator>; + std::function&, + const std::shared_ptr&)>>; // keep all new requests, generally speaking, they do not have any kv cache. RequestPriorityQueue waiting_priority_queue_; + RequestPriorityQueue waiting_priority_queue_offline_; // keep all running request from high priority to low. // NOTE: Maybe not all requests are scheduled in one step, @@ -171,12 +191,30 @@ class ContinuousScheduler : public Scheduler { // popped from waiting_priority_queue_ but remain in prefill stage, // these requests have already allocated some kv caches, // so they can be preemeted in scheduler. 
- std::deque> running_queue_; // is last step handle prefill requests bool last_step_prefill_ = false; + // std::deque> running_queue_; + // std::deque> running_queue_offline_; + std::unique_ptr running_queue_; + std::unique_ptr running_queue_offline_; + + void handle_prefill_requests( + size_t& remaining_token_budget, + size_t& remaining_seq_budget, + RequestPriorityQueue& waiting_priority_queue, + size_t& num_onp_preempt_off_requests, + std::vector>& finished_requests); + void handle_decode_requests( + size_t& remaining_token_budget, + size_t& remaining_seq_budget, + size_t& num_offd_preempt_off_requests, + size_t& num_ond_preempt_on_requests, + size_t& num_ond_preempt_off_requests, + std::unique_ptr& running_queue); void handle_abnormal_request( + std::unique_ptr& running_queue, const std::vector& candidate_sequences, const std::vector& candidate_token_budgets, const size_t& allocated_tokens, @@ -190,15 +228,6 @@ class ContinuousScheduler : public Scheduler { // build a batch of requests from the priority queue virtual std::vector prepare_batch(); - virtual void handle_prefill_requests( - size_t& remaining_token_budget, - size_t& remaining_seq_budget, - std::vector>& finished_requests); - - void handle_decode_requests(size_t& remaining_token_budget, - size_t& remaining_seq_budget, - size_t& num_preempted_requests); - private: std::vector schedule_request(const absl::Duration& timeout); @@ -211,6 +240,8 @@ class ContinuousScheduler : public Scheduler { std::vector& sequences) const; std::vector get_active_activation_in_bytes(); + void create_running_queue(const Options& options); + private: // tokenizer std::unique_ptr tokenizer_; diff --git a/xllm/core/scheduler/continuous_scheduler_test.cpp b/xllm/core/scheduler/continuous_scheduler_test.cpp new file mode 100644 index 00000000..f677e75f --- /dev/null +++ b/xllm/core/scheduler/continuous_scheduler_test.cpp @@ -0,0 +1,341 @@ +#include "continuous_scheduler.h" + +#include +#include + +#include "chunked_prefill_scheduler.h" +#include "runtime/engine.h" +#include "util/utils.h" + +namespace xllm { + +namespace { +class FakeTokenizer : public Tokenizer { + public: + bool encode(const std::string_view& text, std::vector* ids) const { + LOG(FATAL) << "Not implemented"; + } + std::string decode(const Slice& ids, + bool skip_special_tokens) const { + LOG(FATAL) << "Not implemented"; + } + std::optional token_to_id(const std::string_view& token) const { + LOG(FATAL) << "Not implemented"; + } + std::string id_to_token(int32_t id) const { LOG(FATAL) << "Not implemented"; } + size_t vocab_size() const { LOG(FATAL) << "Not implemented"; } + std::unique_ptr clone() const { + return std::make_unique(); + } +}; + +class FakeEngine : public Engine { + public: + FakeEngine(int32_t num_blocks, int32_t block_size) { + BlockManager::Options opt; + opt.num_blocks_ = num_blocks; + opt.block_size_ = block_size; + opt.enable_prefix_cache_ = false; // we dont consider prefix cache here + fake_tokenizer_ = std::make_unique(); + fake_block_manager_ = std::make_unique(opt, 1); + } + ForwardOutput step(std::vector& batch) { + LOG(FATAL) << "Not implemented"; + } + void update_last_step_result(std::vector& batch) { + LOG(FATAL) << "Not implemented"; + } + const Tokenizer* tokenizer() const { return fake_tokenizer_.get(); } + BlockManagerPool* block_manager_pool() const { + return fake_block_manager_.get(); + } + const ModelArgs& model_args() const { LOG(FATAL) << "Not implemented"; } + const TokenizerArgs& tokenizer_args() const { + LOG(FATAL) << "Not implemented"; + 
} + std::vector get_active_activation_memory() const { + LOG(FATAL) << "Not implemented"; + } + bool init() override { return true; } + + private: + std::unique_ptr fake_tokenizer_; + std::unique_ptr fake_block_manager_; +}; + +ContinuousScheduler::Options create_scheduler_options( + int32_t max_tokens_per_batch, + int32_t max_seqs_per_batch, + int32_t num_speculative_tokens, + int32_t max_tokens_per_chunk_for_prefill, + int32_t dp_size, + const std::string& priority_strategy = "FCFS") { + ContinuousScheduler::Options opt; + opt.num_speculative_tokens_ = num_speculative_tokens; + opt.max_tokens_per_chunk_for_prefill_ = max_tokens_per_chunk_for_prefill; + opt.max_tokens_per_batch_ = max_tokens_per_batch; + opt.max_seqs_per_batch_ = max_seqs_per_batch; + opt.dp_size_ = dp_size; + opt.priority_strategy_ = priority_strategy; + + return opt; +} + +std::vector> generate_request( + const std::vector& prompt_lens, + const std::vector& max_tokens, + const std::vector& offlines, + const std::vector& priorities, + int32_t max_context_len) { + std::vector> requests; + EXPECT_TRUE(prompt_lens.size() == max_tokens.size()); + for (size_t i = 0; i < prompt_lens.size(); ++i) { + std::vector prompt_token_ids; + prompt_token_ids.resize(prompt_lens[i]); + RequestSamplingParam sampling_param; + StoppingChecker stopping_checker; + stopping_checker.set_max_generated_tokens(max_tokens[i]); + stopping_checker.set_max_context_len(max_context_len); + stopping_checker.set_ignore_eos(true); + RequestState req_state("x", + prompt_token_ids, + sampling_param, + stopping_checker, + prompt_lens[i] + 30000, + 1, + 1, + false, + false, + false, + false, + false, + nullptr, + nullptr); + auto request = std::make_shared( + "1", + "1", + "1", + std::move(req_state), + "1", + offlines[i], + 0, + static_cast(priorities[i])); + requests.emplace_back(request); + } + + return requests; +} + +// dont not consider speculative decoding. +void update_requests(std::vector> requests) { + for (auto req : requests) { + for (auto& seq : req->sequences()) { + if (seq->kv_state().kv_cache_tokens_num() == 0) { + seq->kv_state().incr_kv_cache_tokens_num(seq->num_prompt_tokens()); + } else { + seq->kv_state().incr_kv_cache_tokens_num(1); + } + Token token(1); + seq->append_token(token); + } + } +} + +} // namespace + +// TEST-1: +// test preempt +TEST(ContinuousSchedulerTest, OnDecodePreemptOffDecode) { + // set max free blocks: 9, support 9*32=288 tokens + // actually only 8 free blocks , because default 1 block is for padding + int block_num = 9; + int block_size = 32; + int max_tokens_per_chunk_for_prefill = 1024; + // set chunked max_tokens budgets 10000 per step + ContinuousScheduler::Options opt = create_scheduler_options( + 10000, 256, 0, max_tokens_per_chunk_for_prefill, 1); + auto engine = std::make_unique(block_num, block_size); + auto scheduler = std::make_unique(engine.get(), opt); + BlockManagerPool* block_manager_pool = engine->block_manager_pool(); + EXPECT_TRUE(scheduler != nullptr); + + std::vector> running_requests; + + // 1. 
schedule two new online prefill requests + auto requests = + generate_request({127, 127}, {10, 10}, {true, false}, {2, 2}, 30000); + running_requests = requests; + for (auto req : requests) { + scheduler->add_request(req); + } + auto batch = scheduler->prepare_batch_test(); + EXPECT_TRUE(batch.size() == 1); + EXPECT_TRUE(batch[0].size() == 2); + update_requests(running_requests); + + batch = scheduler->prepare_batch_test(); + + EXPECT_TRUE(batch.size() == 1); + EXPECT_TRUE(batch[0].size() == 2); + update_requests(running_requests); + + int free_blocks_before_preempt = + util::max(block_manager_pool->num_free_blocks()); + batch = scheduler->prepare_batch_test(); + EXPECT_TRUE(batch.size() == 1); + EXPECT_TRUE(batch[0].size() == 1); + int free_blocks_after_preempt = + util::max(block_manager_pool->num_free_blocks()); + EXPECT_TRUE(free_blocks_after_preempt > free_blocks_before_preempt); + + // check the running request is online request + EXPECT_TRUE(scheduler->get_running_requests().size() == 1); + EXPECT_TRUE(scheduler->get_running_requests()[0]->offline() == false); + EXPECT_TRUE(scheduler->get_waiting_requests_num() == 1); +} + +// TEST-2: +// test preempt +TEST(ContinuousSchedulerTest, OnPrefillPreemptOffDecode) { + // set max free blocks: 9, support 9*32=288 tokens + // actually only 8 free blocks , because default 1 block is for padding + int block_num = 9; + int block_size = 32; + int max_tokens_per_chunk_for_prefill = 1024; + // set chunked max_tokens budgets 10000 per step + ContinuousScheduler::Options opt = create_scheduler_options( + 10000, 256, 0, max_tokens_per_chunk_for_prefill, 1); + FLAGS_prefill_scheduling_memory_usage_threshold = 2; // release threshold + + { + // 1. two offline decode requests then one online prefill request preempt + // them + auto engine = std::make_unique(block_num, block_size); + auto scheduler = std::make_unique(engine.get(), opt); + BlockManagerPool* block_manager_pool = engine->block_manager_pool(); + EXPECT_TRUE(scheduler != nullptr); + + std::vector> running_requests; + + auto requests = + generate_request({100, 100}, {10, 10}, {true, true}, {2, 2}, 30000); + running_requests = requests; + for (auto req : requests) { + scheduler->add_request(req); + } + auto batch = scheduler->prepare_batch_test(); + EXPECT_TRUE(batch.size() == 1); + EXPECT_TRUE(batch[0].size() == 2); + EXPECT_TRUE(util::max(block_manager_pool->num_free_blocks()) == 0); + update_requests(running_requests); + + batch = scheduler->prepare_batch_test(); + EXPECT_TRUE(batch.size() == 1); + EXPECT_TRUE(batch[0].size() == 2); + EXPECT_TRUE(util::max(block_manager_pool->num_free_blocks()) == 0); + update_requests(running_requests); + + auto new_requests = + generate_request({80}, {10}, {false}, {2}, 30000); // use 3 blocks + scheduler->add_request(new_requests[0]); + batch = scheduler->prepare_batch_test(); + EXPECT_TRUE(batch.size() == 1); + EXPECT_TRUE(batch[0].size() == 1); + + // online prefill request preempt offline decode request + EXPECT_TRUE(scheduler->get_running_requests().size() == 1); + EXPECT_TRUE(scheduler->get_running_requests()[0]->offline() == false); + EXPECT_TRUE(scheduler->get_waiting_requests_num() == 1); + + // offline is evicted + EXPECT_TRUE(util::max(block_manager_pool->num_free_blocks()) == 1); + running_requests.pop_back(); + update_requests(new_requests); + } + + // 2. 
another case: a longer online prefill request arrives, but it cannot evict + // the offline request, because doing so would not free enough blocks + { + auto engine = std::make_unique(block_num, block_size); + auto scheduler = std::make_unique(engine.get(), opt); + BlockManagerPool* block_manager_pool = engine->block_manager_pool(); + EXPECT_TRUE(scheduler != nullptr); + + std::vector> running_requests; + // one online, one offline + auto requests = + generate_request({100, 100}, {10, 10}, {true, false}, {2, 2}, 30000); + running_requests = requests; + for (auto req : requests) { + scheduler->add_request(req); + } + auto batch = scheduler->prepare_batch_test(); + EXPECT_TRUE(batch.size() == 1); + EXPECT_TRUE(batch[0].size() == 2); + EXPECT_TRUE(util::max(block_manager_pool->num_free_blocks()) == 0); + update_requests(running_requests); + + auto new_requests = generate_request({200}, {10}, {false}, {2}, 30000); + scheduler->add_request(new_requests[0]); + batch = scheduler->prepare_batch_test(); + // the online request is still waiting + EXPECT_TRUE(batch.size() == 1); + EXPECT_TRUE(batch[0].size() == 2); + EXPECT_TRUE(scheduler->get_waiting_requests().size() == 1); + EXPECT_TRUE(scheduler->get_waiting_requests()[0].get() == + new_requests[0].get()); + } +} + +// TEST-3: +// test priority schedule +TEST(ContinuousSchedulerTest, PrioritySchedule) { + // set max free blocks: 12 + // actually only 11 free blocks, because 1 block is reserved for padding by default + int block_num = 12; + int block_size = 32; + int max_tokens_per_chunk_for_prefill = 1024; + // set chunked max_tokens budget to 10000 per step + ContinuousScheduler::Options opt = create_scheduler_options( + 10000, 256, 0, max_tokens_per_chunk_for_prefill, 1, "priority"); + auto engine = std::make_unique(block_num, block_size); + auto scheduler = std::make_unique(engine.get(), opt); + EXPECT_TRUE(scheduler != nullptr); + + std::vector> running_requests; + + // 1: HIGH, 2: NORMAL, 3: LOW + auto requests = generate_request( + {128, 128, 128}, {10, 10, 10}, {false, false, false}, {3, 3, 2}, 30000); + for (auto req : requests) { + scheduler->add_request(req); + } + auto batch = scheduler->prepare_batch_test(); + EXPECT_TRUE(batch.size() == 1); + EXPECT_TRUE(batch[0].size() == 2); + EXPECT_TRUE(scheduler->get_running_requests().size() == 2); + EXPECT_TRUE(scheduler->get_running_requests()[0]->priority() == 2 /*NORMAL*/); + EXPECT_TRUE(scheduler->get_running_requests()[1]->priority() == 3 /*LOW*/); + running_requests = scheduler->get_running_requests(); + update_requests(running_requests); + + // a new HIGH priority request arrives, and its prefill starts + auto new_requests = + generate_request({32}, {10}, {false}, {1}, 30000); // uses 1 block + scheduler->add_request(new_requests[0]); + batch = scheduler->prepare_batch_test(); + EXPECT_TRUE(batch.size() == 1); + EXPECT_TRUE(batch[0].size() == 1); + EXPECT_TRUE(scheduler->get_running_requests().size() == 1); + update_requests(new_requests); + + // only the HIGH and NORMAL requests decode + batch = scheduler->prepare_batch_test(); + EXPECT_TRUE(batch.size() == 1); + EXPECT_TRUE(batch[0].size() == 2); + EXPECT_TRUE(scheduler->get_running_requests().size() == 2); + EXPECT_TRUE(scheduler->get_running_requests()[0]->priority() == 1 /*HIGH*/); + EXPECT_TRUE(scheduler->get_running_requests()[1]->priority() == 2 /*NORMAL*/); +} + +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/scheduler/decode_priority_queue.h b/xllm/core/scheduler/decode_priority_queue.h new file mode 100644 index 00000000..1cb3acf6 --- /dev/null +++ 
b/xllm/core/scheduler/decode_priority_queue.h @@ -0,0 +1,175 @@ +#pragma once +#include <deque> +#include <functional> +#include <memory> +#include <set> + +#include "framework/request/priority_comparator.h" +#include "framework/request/request.h" + +namespace xllm { + +// encapsulates a concrete iterator and supports the Iterator pattern +class BaseIterator { + public: + virtual ~BaseIterator() = default; + virtual std::shared_ptr<Request> operator*() const = 0; + virtual void operator++() = 0; + virtual bool operator!=(const BaseIterator& other) const = 0; + virtual std::unique_ptr<BaseIterator> clone() const = 0; +}; + +template <typename Iterator> +class ConcreteIterator : public BaseIterator { + Iterator iter_; + + public: + explicit ConcreteIterator(Iterator iter) : iter_(iter) {} + + std::shared_ptr<Request> operator*() const override { return *iter_; } + + void operator++() override { ++iter_; } + + bool operator!=(const BaseIterator& other) const override { + const auto* derived = dynamic_cast<const ConcreteIterator*>(&other); + return derived && iter_ != derived->iter_; + } + + std::unique_ptr<BaseIterator> clone() const override { + return std::make_unique<ConcreteIterator>(iter_); + } +}; +class DecodePriorityQueue { + public: + class Iterator { + std::unique_ptr<BaseIterator> itr_; + + public: + explicit Iterator(std::unique_ptr<BaseIterator> itr) + : itr_(std::move(itr)) {} + + std::shared_ptr<Request> operator*() const { return **itr_; } + + Iterator& operator++() { + ++*itr_; + return *this; + } + + bool operator!=(const Iterator& other) const { + return itr_->operator!=(*other.itr_); + } + }; + virtual void push(std::shared_ptr<Request> req) = 0; + virtual void push(std::shared_ptr<Request> req, bool if_back) = 0; + virtual void pop_top() = 0; + virtual void pop_back() = 0; + virtual std::shared_ptr<Request> top() const = 0; + virtual std::shared_ptr<Request> back() const = 0; + virtual bool empty() const = 0; + virtual size_t size() const = 0; + virtual ~DecodePriorityQueue() = default; + + virtual Iterator begin() const = 0; + virtual Iterator end() const = 0; + virtual Iterator rbegin() const = 0; + virtual Iterator rend() const = 0; +}; + +class DynamicPriorityQueue : public DecodePriorityQueue { + private: + using QueueType = + std::set<std::shared_ptr<Request>, + std::function<bool(const std::shared_ptr<Request>&, + const std::shared_ptr<Request>&)>>; + QueueType queue_; + std::unique_ptr<PriorityComparator> comparator_; + + public: + explicit DynamicPriorityQueue(std::unique_ptr<PriorityComparator> comparator) + : comparator_(std::move(comparator)), + queue_([this](const auto& a, const auto& b) { + // delegate ordering to the priority comparator, reversed so that + // top() is the highest-precedence request; swap the arguments + // (rather than negate the result) and tie-break on pointer + // identity so std::set gets a strict weak ordering + if ((*comparator_)(b, a)) return true; + if ((*comparator_)(a, b)) return false; + return std::less<Request*>{}(a.get(), b.get()); + }) {} + + void push(std::shared_ptr<Request> req) override { queue_.insert(req); } + void push(std::shared_ptr<Request> req, bool if_back) override { + LOG(FATAL) << "DynamicPriorityQueue does not support push with if_back"; + } + void pop_top() override { queue_.erase(queue_.begin()); } + void pop_back() override { queue_.erase(std::prev(queue_.end())); } + std::shared_ptr<Request> top() const override { return *queue_.begin(); } + std::shared_ptr<Request> back() const override { return *queue_.rbegin(); } + bool empty() const override { return queue_.empty(); } + virtual size_t size() const override { return queue_.size(); } + + Iterator begin() const override { + return Iterator(std::make_unique<ConcreteIterator<QueueType::iterator>>( + queue_.begin())); + } + + Iterator end() const override { + return Iterator( + std::make_unique<ConcreteIterator<QueueType::iterator>>(queue_.end())); + } + + Iterator rbegin() const override { + return Iterator( + std::make_unique<ConcreteIterator<QueueType::reverse_iterator>>( + queue_.rbegin())); + } + + Iterator rend() const override { + return Iterator( + std::make_unique<ConcreteIterator<QueueType::reverse_iterator>>( + queue_.rend())); + } +}; + +class FCFSQueue : public DecodePriorityQueue { + // use a deque to implement the FCFS queue, for insertion and eviction efficiency + private: + std::deque<std::shared_ptr<Request>> queue_; + + public: + void 
push(std::shared_ptr<Request> req) override { queue_.push_back(req); } + + void push(std::shared_ptr<Request> req, bool if_back) override { + if (if_back) { + queue_.push_back(req); + } else { + queue_.push_front(req); + } + } + + void pop_top() override { queue_.pop_front(); } + void pop_back() override { queue_.pop_back(); } + std::shared_ptr<Request> top() const override { return queue_.front(); } + std::shared_ptr<Request> back() const override { return queue_.back(); } + bool empty() const override { return queue_.empty(); } + virtual size_t size() const override { return queue_.size(); } + + Iterator begin() const override { + return Iterator( + std::make_unique<ConcreteIterator<std::deque<std::shared_ptr<Request>>::iterator>>( + queue_.begin())); + } + + Iterator end() const override { + return Iterator( + std::make_unique<ConcreteIterator<std::deque<std::shared_ptr<Request>>::iterator>>( + queue_.end())); + } + + Iterator rbegin() const override { + return Iterator(std::make_unique< + ConcreteIterator<std::deque<std::shared_ptr<Request>>::reverse_iterator>>( + queue_.rbegin())); + } + + Iterator rend() const override { + return Iterator(std::make_unique< + ConcreteIterator<std::deque<std::shared_ptr<Request>>::reverse_iterator>>( + queue_.rend())); + } +}; +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/scheduler/disagg_pd_scheduler.cpp b/xllm/core/scheduler/disagg_pd_scheduler.cpp index 6b159860..85d5a8ea 100644 --- a/xllm/core/scheduler/disagg_pd_scheduler.cpp +++ b/xllm/core/scheduler/disagg_pd_scheduler.cpp @@ -169,6 +169,11 @@ bool DisaggPDScheduler::add_request(std::shared_ptr& request) { CHECK(request != nullptr); CHECK(!request->sequences().empty()); + if (request->offline()) { + // offline request, push to the offline queue + prefill_request_queue_offline_.push(request); + return true; + } // push and wait prefill_request_queue_.push(request); @@ -179,7 +184,19 @@ bool DisaggPDScheduler::add_request(std::shared_ptr& request) { void DisaggPDScheduler::dispatch_requests() { while (true) { std::vector> requests; - std::shared_ptr request = prefill_request_queue_.pop(); + + auto poped_result = prefill_request_queue_.try_pop(); + // OPTIMIZE: later, try to read an online prefill request several times, and + // only read an offline prefill request when none has arrived for a long time. + if (!poped_result.has_value()) { + poped_result = prefill_request_queue_offline_.try_pop(); + if (!poped_result.has_value()) { + // no offline request, sleep for a while and try again + absl::SleepFor(absl::Milliseconds(100)); + continue; + } + } + auto request = poped_result.value(); if (request == nullptr) { // nullptr is a signal to exit break; @@ -210,6 +227,8 @@ void DisaggPDScheduler::dispatch_requests() { sleep(1); } // select a D instance using round-robin (RR) currently. + // TODO: use a better decode selection strategy later, possibly different + // strategies for offline and online requests, or implement it in the xllm service.
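// A sketch of the round-robin selection that the loop below performs, with
// hypothetical stand-ins (instances, rr_cursor, is_alive) for the real
// members (decode_inst_names_, stub creation). It illustrates the policy
// only and is not code from this patch:
//
//   std::string pick_decode_instance(const std::vector<std::string>& instances,
//                                    size_t& rr_cursor,
//                                    bool (*is_alive)(const std::string&)) {
//     for (size_t tried = 0; tried < instances.size(); ++tried) {
//       const std::string& name = instances[rr_cursor++ % instances.size()];
//       if (is_alive(name)) return name;  // first healthy instance wins
//     }
//     return "";  // one full pass failed; the caller sleeps and retries
//   }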
int try_decode_count = 0; while (!stub) { if (try_decode_count == decode_inst_names_.size()) { @@ -319,7 +338,12 @@ void DisaggPDScheduler::dispatch_requests() { for (size_t i = 0; i < requests.size(); ++i) { if (resps.resps()[i].status_code() != 200) { // push back to prefill_request_queue_ - prefill_request_queue_.push(requests[i]); + if (requests[i]->offline()) { + prefill_request_queue_offline_.push(requests[i]); + } else { + prefill_request_queue_.push(requests[i]); + } + } else { for (auto& sequence : requests[i]->sequences()) { TransferKVInfo info; diff --git a/xllm/core/scheduler/disagg_pd_scheduler.h b/xllm/core/scheduler/disagg_pd_scheduler.h index adcc5d37..673d2780 100644 --- a/xllm/core/scheduler/disagg_pd_scheduler.h +++ b/xllm/core/scheduler/disagg_pd_scheduler.h @@ -114,6 +114,11 @@ class DisaggPDScheduler : public ContinuousScheduler { // for prefill, dispatch request to Decode instance std::unique_ptr dispatch_thread_; ConcurrentQueue> prefill_request_queue_; + + // folly::MPMCQueue> + // prefill_request_queue_offline_; + ConcurrentQueue> prefill_request_queue_offline_; + // for prefill save all remote requests std::unordered_map> remote_requests_map_; @@ -121,8 +126,10 @@ class DisaggPDScheduler : public ContinuousScheduler { using RequestPriorityQueue = std::priority_queue, std::vector>, - RequestComparator>; + std::function&, + const std::shared_ptr&)>>; RequestPriorityQueue waiting_priority_queue_; + RequestPriorityQueue waiting_priority_queue_offline_; // use threadpool to handle prefill-completed request ThreadPool prefill_threadpool_; diff --git a/xllm/core/scheduler/zero_eviction_scheduler.cpp b/xllm/core/scheduler/zero_eviction_scheduler.cpp index 20e376b8..b0f84362 100644 --- a/xllm/core/scheduler/zero_eviction_scheduler.cpp +++ b/xllm/core/scheduler/zero_eviction_scheduler.cpp @@ -33,10 +33,11 @@ uint32_t ceiling_div(uint32_t left, uint32_t right) { } std::vector get_running_sequences( - const std::deque>& running_queue) { + const std::unique_ptr& running_queue) { std::vector running_sequences; - for (auto& running_request : running_queue) { + for (auto it = running_queue->rbegin(); it != running_queue->rend(); ++it) { + std::shared_ptr running_request = *it; if (Request* request = running_request.get()) { for (auto& sequence : request->sequences()) { // skip finished sequence. 
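The DecodePriorityQueue convention is what makes the rewrite above work: front/top is the next request to keep running, back is the first eviction candidate, and rbegin()/rend() walk from most to least evictable. A minimal self-contained sketch of the FCFS ordering, with a plain std::deque<std::string> standing in for FCFSQueue and hypothetical request names:

#include <deque>
#include <iostream>
#include <string>

int main() {
  std::deque<std::string> q;         // models FCFSQueue's deque
  q.push_back("req-A");              // admitted first
  q.push_back("req-B");
  q.push_back("req-C");              // admitted last
  q.push_front("req-carried-over");  // models push(req, /*if_back=*/false)

  // Eviction order, as in the rbegin()/rend() loops: newest admissions first.
  for (auto it = q.rbegin(); it != q.rend(); ++it) {
    std::cout << "eviction candidate: " << *it << "\n";
  }
  // prints: req-C, req-B, req-A, req-carried-over
}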
@@ -234,7 +235,7 @@ bool BlockCapacityGuard::simulate_is_satisfied_for_candidate_sequences() { bool BlockCapacityGuard::if_accept_candidate_sequences( const std::vector& candidate_sequences, - const std::deque>& running_queue, + const std::unique_ptr& running_queue, const std::vector& running_sequences) { num_reserved_block_for_prefill_ = 0; @@ -261,8 +262,8 @@ ZeroEvictionScheduler::~ZeroEvictionScheduler() { } // release all requests in the running priority queue - while (!running_queue_.empty()) { - running_queue_.pop_front(); + while (!running_queue_->empty()) { + running_queue_->pop_top(); } } @@ -412,7 +413,7 @@ void ZeroEvictionScheduler::handle_prefill_requests( } if (running_sequences_.empty() && !waiting_priority_queue_.empty() && - running_queue_.empty() && + running_queue_->empty() && block_manager_pool_->kv_cache_utilization() == 0) { LOG(ERROR) << "Request prompt is too long, no enough memory to schedule " "a single sequence."; diff --git a/xllm/core/scheduler/zero_eviction_scheduler.h b/xllm/core/scheduler/zero_eviction_scheduler.h index e44fe742..abefd64c 100644 --- a/xllm/core/scheduler/zero_eviction_scheduler.h +++ b/xllm/core/scheduler/zero_eviction_scheduler.h @@ -59,7 +59,7 @@ class BlockCapacityGuard { bool if_accept_candidate_sequences( const std::vector& candidate_sequences, - const std::deque>& running_queue, + const std::unique_ptr& running_queue, const std::vector& running_sequences); private: @@ -118,7 +118,7 @@ class ZeroEvictionScheduler final : public ContinuousScheduler { void handle_prefill_requests( size_t& remaining_token_budget, size_t& remaining_seq_budget, - std::vector>& finished_requests) override; + std::vector>& finished_requests); bool try_allocate_block_for(std::shared_ptr request, std::vector* prefill_sequences, diff --git a/xllm/proto/chat.proto b/xllm/proto/chat.proto index 600cadf2..69209cf4 100644 --- a/xllm/proto/chat.proto +++ b/xllm/proto/chat.proto @@ -113,7 +113,16 @@ message ChatRequest { Routing routing = 27; repeated Tool tools = 28; + optional string tool_choice = 29; + + optional bool offline = 30; + + optional int32 slo_ms = 31; + + // request priority. default = DEFAULT + optional Priority priority = 32; + } message ChatLogProbData { diff --git a/xllm/proto/common.proto b/xllm/proto/common.proto index 99a3fb70..bf64c92a 100644 --- a/xllm/proto/common.proto +++ b/xllm/proto/common.proto @@ -27,6 +27,16 @@ message FunctionCall { string name = 1; string arguments = 2; // JSON string } +enum Priority { + DEFAULT = 0; + + HIGH = 1; + + NORMAL = 2; + + LOW = 3; +} + message Usage { // the number of tokens in the prompt. diff --git a/xllm/proto/completion.proto b/xllm/proto/completion.proto index ea7fca8f..19e2fc27 100644 --- a/xllm/proto/completion.proto +++ b/xllm/proto/completion.proto @@ -86,6 +86,13 @@ message CompletionRequest { repeated int32 token_ids = 24; Routing routing = 25; + + optional bool offline = 26; + + optional int32 slo_ms = 27; + + // request priority. default = DEFAULT + optional Priority priority = 28; } message LogProbs { diff --git a/xllm/proto/disagg_pd.proto b/xllm/proto/disagg_pd.proto index cd2b308e..a457e107 100644 --- a/xllm/proto/disagg_pd.proto +++ b/xllm/proto/disagg_pd.proto @@ -53,6 +53,10 @@ message DisaggRequest { bool echo = 26; bool skip_special_tokens = 27; repeated int32 prompt_tokens = 28; + bool offline = 29; + optional int32 slo_ms = 30; + optional Priority priority = 31; + } // response for DisaggRequests from decode instance. 
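To see how a client might populate the new fields, here is a hedged sketch using the generated C++ protobuf API (the protoc-generated setters and header name are assumed from the messages above; this is an illustration, not code from the patch):

#include "completion.pb.h"  // assumed name of the generated header

// Build an offline, LOW-priority completion request with a 500 ms SLO.
xllm::proto::CompletionRequest make_offline_request() {
  xllm::proto::CompletionRequest req;
  req.set_offline(true);                         // routed to the offline queues
  req.set_slo_ms(500);                           // latency target in milliseconds
  req.set_priority(xllm::proto::Priority::LOW);  // scheduled after NORMAL/HIGH
  return req;
}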
diff --git a/xllm/proto/multimodal.proto b/xllm/proto/multimodal.proto index 41a3cfb2..184168b8 100644 --- a/xllm/proto/multimodal.proto +++ b/xllm/proto/multimodal.proto @@ -127,5 +127,13 @@ message MMChatRequest { Routing routing = 26; repeated Tool tools = 27; + optional string tool_choice = 28; + + optional bool offline = 29; + + optional int32 slo_ms = 30; + + // request priority. default = DEFAULT + optional Priority priority = 31; } diff --git a/xllm/xllm.cpp b/xllm/xllm.cpp index b96adf0b..6e0815d5 100644 --- a/xllm/xllm.cpp +++ b/xllm/xllm.cpp @@ -132,7 +132,9 @@ int run() { .kv_cache_transfer_mode(FLAGS_kv_cache_transfer_mode) .etcd_addr(FLAGS_etcd_addr) .enable_service_routing(FLAGS_enable_service_routing) - .tool_call_parser(FLAGS_tool_call_parser); + .tool_call_parser(FLAGS_tool_call_parser) + .priority_strategy(FLAGS_priority_strategy) + .enable_on_preempt_off(FLAGS_enable_on_preempt_off); InstanceName::name()->set_name(options.instance_name().value_or("")); From 41129381fc07729e985175edca0c4f5034f03e30 Mon Sep 17 00:00:00 2001 From: huangweizhe1 Date: Fri, 29 Aug 2025 22:48:33 +0800 Subject: [PATCH 2/2] refactor: refactor the priority and on/offline code. --- xllm/core/common/global_flags.cpp | 2 +- xllm/core/common/global_flags.h | 2 +- xllm/core/common/metrics.cpp | 8 +- xllm/core/common/metrics.h | 8 +- xllm/core/common/options.h | 2 +- .../disagg_pd_service_impl.cpp | 17 ++- xllm/core/framework/block/block_manager.h | 4 - .../framework/block/block_manager_impl.cpp | 27 ---- .../core/framework/block/block_manager_impl.h | 4 - .../framework/block/block_manager_pool.cpp | 10 -- .../core/framework/block/block_manager_pool.h | 4 - .../block/concurrent_block_manager_impl.cpp | 9 -- .../block/concurrent_block_manager_impl.h | 4 - xllm/core/framework/request/request.cpp | 2 +- xllm/core/framework/request/request.h | 8 +- .../core/framework/request/request_params.cpp | 4 +- xllm/core/framework/request/request_params.h | 3 +- xllm/core/runtime/options.h | 2 +- .../scheduler/chunked_prefill_scheduler.cpp | 2 +- xllm/core/scheduler/continuous_scheduler.cpp | 137 ++++++++++-------- xllm/core/scheduler/continuous_scheduler.h | 14 +- .../scheduler/continuous_scheduler_test.cpp | 30 ++-- xllm/core/scheduler/disagg_pd_scheduler.cpp | 15 +- xllm/core/util/concurrent_queue.h | 16 ++ xllm/xllm.cpp | 2 +- 25 files changed, 160 insertions(+), 176 deletions(-) diff --git a/xllm/core/common/global_flags.cpp b/xllm/core/common/global_flags.cpp index cad4cec1..2ea4b7e4 100644 --- a/xllm/core/common/global_flags.cpp +++ b/xllm/core/common/global_flags.cpp @@ -202,6 +202,6 @@ DEFINE_int32(heart_beat_interval, 3, "heart beat interval"); DEFINE_string(priority_strategy, "FCFS", "priority strategy for requests"); -DEFINE_bool(enable_on_preempt_off, +DEFINE_bool(enable_online_preempt_offline, true, "whether enable online preempt offline"); diff --git a/xllm/core/common/global_flags.h b/xllm/core/common/global_flags.h index 49e059b2..8efbe41d 100644 --- a/xllm/core/common/global_flags.h +++ b/xllm/core/common/global_flags.h @@ -129,4 +129,4 @@ DECLARE_bool(use_zero_evict); DECLARE_string(priority_strategy); -DECLARE_bool(enable_on_preempt_off); +DECLARE_bool(enable_online_preempt_offline); diff --git a/xllm/core/common/metrics.cpp b/xllm/core/common/metrics.cpp index 7a881219..5f792f2b 100644 --- a/xllm/core/common/metrics.cpp +++ b/xllm/core/common/metrics.cpp @@ -88,13 +88,13 @@ DEFINE_GAUGE(num_running_requests, "Number of running requests in scheduler"); DEFINE_GAUGE(num_waiting_requests, "Number of 
waiting requests in scheduler"); DEFINE_GAUGE(num_preempted_requests, "Number of preempted requests in scheduler"); -DEFINE_GAUGE(num_offd_preempt_off_requests, +DEFINE_GAUGE(num_offline_decode_preempt_offline_requests, "Number of offline decode preempt offline requests in scheduler"); -DEFINE_GAUGE(num_ond_preempt_on_requests, +DEFINE_GAUGE(num_online_decode_preempt_online_requests, "Number of online decode preempt online requests in scheduler"); -DEFINE_GAUGE(num_onp_preempt_off_requests, +DEFINE_GAUGE(num_online_prefill_preempt_offline_requests, "Number of online prefill preempt offline requests in scheduler"); -DEFINE_GAUGE(num_ond_preempt_off_requests, +DEFINE_GAUGE(num_online_decode_preempt_offline_requests, "Number of online decode preempt offline requests in scheduler"); DEFINE_GAUGE(num_running_sequences, "Number of running sequences"); diff --git a/xllm/core/common/metrics.h b/xllm/core/common/metrics.h index 64bf9c55..48663341 100644 --- a/xllm/core/common/metrics.h +++ b/xllm/core/common/metrics.h @@ -149,10 +149,10 @@ DECLARE_GAUGE(num_pending_requests); DECLARE_GAUGE(num_running_requests); DECLARE_GAUGE(num_waiting_requests); DECLARE_GAUGE(num_preempted_requests); -DECLARE_GAUGE(num_offd_preempt_off_requests); -DECLARE_GAUGE(num_ond_preempt_on_requests); -DECLARE_GAUGE(num_onp_preempt_off_requests); -DECLARE_GAUGE(num_ond_preempt_off_requests); +DECLARE_GAUGE(num_offline_decode_preempt_offline_requests); +DECLARE_GAUGE(num_online_decode_preempt_online_requests); +DECLARE_GAUGE(num_online_prefill_preempt_offline_requests); +DECLARE_GAUGE(num_online_decode_preempt_offline_requests); DECLARE_GAUGE(num_running_sequences); DECLARE_GAUGE(kv_cache_utilization_perc); DECLARE_GAUGE(num_blocks_in_prefix_cache); diff --git a/xllm/core/common/options.h b/xllm/core/common/options.h index 950d9e12..49cce352 100644 --- a/xllm/core/common/options.h +++ b/xllm/core/common/options.h @@ -116,7 +116,7 @@ class Options { PROPERTY(std::string, priority_strategy) = "FCFS"; - PROPERTY(bool, enable_on_preempt_off) = true; + PROPERTY(bool, enable_online_preempt_offline) = true; }; } // namespace xllm diff --git a/xllm/core/distributed_runtime/disagg_pd_service_impl.cpp b/xllm/core/distributed_runtime/disagg_pd_service_impl.cpp index d8d4f3a9..4f689baf 100644 --- a/xllm/core/distributed_runtime/disagg_pd_service_impl.cpp +++ b/xllm/core/distributed_runtime/disagg_pd_service_impl.cpp @@ -90,14 +90,15 @@ std::shared_ptr DisaggPDServiceImpl::generate_request( output_callback, batch_output_callback); - auto new_request = std::make_shared(req.req_id(), - req.x_request_id(), - req.x_request_time(), - std::move(req_state), - req.service_req_id(), - req.offline(), - req.slo_ms(), - req.priority()); + auto new_request = std::make_shared( + req.req_id(), + req.x_request_id(), + req.x_request_time(), + std::move(req_state), + req.service_req_id(), + req.offline(), + req.slo_ms(), + static_cast(req.priority())); // add one sequence, rest will be added by scheduler return new_request; diff --git a/xllm/core/framework/block/block_manager.h b/xllm/core/framework/block/block_manager.h index a5e85742..26e994d1 100644 --- a/xllm/core/framework/block/block_manager.h +++ b/xllm/core/framework/block/block_manager.h @@ -62,10 +62,6 @@ class BlockManager { virtual void cache(const Slice& token_ids, const Slice& blocks) = 0; - virtual bool check_if_enough_to_evict( - DecodePriorityQueue* running_queue_to_evict, - Sequence* prefill_sequence, - size_t& num_request_to_evict) = 0; // get merged all dp rank KVCacheEvent virtual 
void get_merged_kvcache_event(KvCacheEvent* event) const = 0; diff --git a/xllm/core/framework/block/block_manager_impl.cpp b/xllm/core/framework/block/block_manager_impl.cpp index ea686495..2dfbc727 100644 --- a/xllm/core/framework/block/block_manager_impl.cpp +++ b/xllm/core/framework/block/block_manager_impl.cpp @@ -74,33 +74,6 @@ void BlockManagerImpl::deallocate(const Slice& blocks) { } } -bool BlockManagerImpl::check_if_enough_to_evict( - DecodePriorityQueue* running_queue_to_evict, - Sequence* prefill_sequence, - size_t& num_request_to_evict) { - // check if it's enough when we evict this requests queue - - const size_t num_blocks_needed = - (prefill_sequence->num_tokens() + block_size_ - 1) / block_size_; - size_t num_blocks_can_evict = 0; - // count the number of blocks can be preempted - for (auto it = running_queue_to_evict->rbegin(); - it != running_queue_to_evict->rend(); - ++it) { - std::shared_ptr request_to_preempt = *it; - num_request_to_evict++; - // count the number of blocks belong to the request - for (const auto& seq : request_to_preempt->sequences()) { - num_blocks_can_evict += seq->kv_state().num_kv_blocks(); - } - if ((num_blocks_needed <= num_blocks_can_evict) || - has_enough_blocks(num_blocks_needed - num_blocks_can_evict)) { - return true; - } - } - return false; -} - bool BlockManagerImpl::has_enough_blocks(uint32_t num_blocks) { if (num_blocks <= num_free_blocks_) { return true; diff --git a/xllm/core/framework/block/block_manager_impl.h b/xllm/core/framework/block/block_manager_impl.h index 1ce3ba37..c9f5b41d 100644 --- a/xllm/core/framework/block/block_manager_impl.h +++ b/xllm/core/framework/block/block_manager_impl.h @@ -46,10 +46,6 @@ class BlockManagerImpl : public BlockManager { void get_merged_kvcache_event(KvCacheEvent* event) const override; - bool check_if_enough_to_evict(DecodePriorityQueue* running_queue_to_evict, - Sequence* prefill_sequence, - size_t& num_request_to_evict) override; - size_t num_blocks_in_prefix_cache() const override { if (options_.enable_prefix_cache()) { CHECK(prefix_cache_); diff --git a/xllm/core/framework/block/block_manager_pool.cpp b/xllm/core/framework/block/block_manager_pool.cpp index dd620d0d..ae835472 100644 --- a/xllm/core/framework/block/block_manager_pool.cpp +++ b/xllm/core/framework/block/block_manager_pool.cpp @@ -92,16 +92,6 @@ bool BlockManagerPool::allocate(Sequence* sequence) { return allocate(sequence, sequence->num_tokens()); } -bool BlockManagerPool::check_if_enough_to_evict( - DecodePriorityQueue* running_queue_to_evict, - Sequence* prefill_sequence, - size_t& num_request_to_evict) { - DCHECK(prefill_sequence != nullptr); - int32_t dp_rank = prefill_sequence->dp_rank(); - return block_managers_[dp_rank]->check_if_enough_to_evict( - running_queue_to_evict, prefill_sequence, num_request_to_evict); -} - bool BlockManagerPool::allocate(std::vector& sequences) { for (auto* sequence : sequences) { DCHECK(sequence != nullptr); diff --git a/xllm/core/framework/block/block_manager_pool.h b/xllm/core/framework/block/block_manager_pool.h index d39e0d4b..f646646b 100644 --- a/xllm/core/framework/block/block_manager_pool.h +++ b/xllm/core/framework/block/block_manager_pool.h @@ -48,10 +48,6 @@ class BlockManagerPool { void get_merged_kvcache_event(KvCacheEvent* event) const; float get_gpu_cache_usage_perc() const; - bool check_if_enough_to_evict(DecodePriorityQueue* running_queue_to_evict, - Sequence* prefill_sequence, - size_t& num_request_to_evict); - std::vector num_blocks_in_prefix_cache() const; std::vector 
num_free_blocks() const; std::vector num_used_blocks() const; diff --git a/xllm/core/framework/block/concurrent_block_manager_impl.cpp b/xllm/core/framework/block/concurrent_block_manager_impl.cpp index 63a8a523..d7414dc1 100644 --- a/xllm/core/framework/block/concurrent_block_manager_impl.cpp +++ b/xllm/core/framework/block/concurrent_block_manager_impl.cpp @@ -43,15 +43,6 @@ void ConcurrentBlockManagerImpl::cache(const Slice& token_ids, BlockManagerImpl::cache(token_ids, blocks); } -bool ConcurrentBlockManagerImpl::check_if_enough_to_evict( - DecodePriorityQueue* running_queue_to_evict, - Sequence* prefill_sequence, - size_t& num_request_to_evict) { - std::lock_guard lock(mutex_); - return BlockManagerImpl::check_if_enough_to_evict( - running_queue_to_evict, prefill_sequence, num_request_to_evict); -} - size_t ConcurrentBlockManagerImpl::num_blocks_in_prefix_cache() const { std::lock_guard lock(mutex_); return BlockManagerImpl::num_blocks_in_prefix_cache(); diff --git a/xllm/core/framework/block/concurrent_block_manager_impl.h b/xllm/core/framework/block/concurrent_block_manager_impl.h index a4a77a4e..30233cd0 100644 --- a/xllm/core/framework/block/concurrent_block_manager_impl.h +++ b/xllm/core/framework/block/concurrent_block_manager_impl.h @@ -39,10 +39,6 @@ class ConcurrentBlockManagerImpl : public BlockManagerImpl { void cache(const Slice& token_ids, const Slice& blocks) override; - bool check_if_enough_to_evict(DecodePriorityQueue* running_queue_to_evict, - Sequence* prefill_sequence, - size_t& num_request_to_evict) override; - // get the number of blocks in the prefix cache size_t num_blocks_in_prefix_cache() const override; diff --git a/xllm/core/framework/request/request.cpp b/xllm/core/framework/request/request.cpp index 1fd4a859..71b866d0 100644 --- a/xllm/core/framework/request/request.cpp +++ b/xllm/core/framework/request/request.cpp @@ -36,7 +36,7 @@ Request::Request(const std::string& request_id, const std::string& service_request_id, bool offline, int32_t slo_ms, - xllm::proto::Priority priority) + RequestPriority priority) : request_id_(request_id), service_request_id_(service_request_id), x_request_id_(x_request_id), diff --git a/xllm/core/framework/request/request.h b/xllm/core/framework/request/request.h index 462dab89..b11f1970 100644 --- a/xllm/core/framework/request/request.h +++ b/xllm/core/framework/request/request.h @@ -31,6 +31,8 @@ limitations under the License. 
namespace xllm { +enum class RequestPriority { DEFAULT = 0, HIGH = 1, NORMAL = 2, LOW = 3 }; + class Request { public: Request(const std::string& request_id, @@ -40,7 +42,7 @@ class Request { const std::string& service_request_id = "", bool offline = false, int32_t slo_ms = 0, - xllm::proto::Priority priority = xllm::proto::Priority::NORMAL); + RequestPriority priority = RequestPriority::NORMAL); bool finished() const; @@ -86,7 +88,7 @@ class Request { const bool offline() const { return offline_; } const int32_t slo_ms() const { return slo_ms_; } - const xllm::proto::Priority priority() const { return priority_; } + const RequestPriority priority() const { return priority_; } RequestState& state() { return state_; } @@ -119,7 +121,7 @@ class Request { int32_t slo_ms_; - xllm::proto::Priority priority_; + RequestPriority priority_; private: void create_sequences_group(); diff --git a/xllm/core/framework/request/request_params.cpp b/xllm/core/framework/request/request_params.cpp index 2552e453..c6b5a806 100644 --- a/xllm/core/framework/request/request_params.cpp +++ b/xllm/core/framework/request/request_params.cpp @@ -54,7 +54,7 @@ RequestParams::RequestParams(const proto::CompletionRequest& request, slo_ms = request.slo_ms(); } if (request.has_priority()) { - priority = request.priority(); + priority = static_cast(request.priority()); } if (request.has_service_request_id()) { @@ -203,7 +203,7 @@ void InitFromChatRequest(RequestParams& params, const ChatRequest& request) { params.slo_ms = request.slo_ms(); } if (request.has_priority()) { - params.priority = request.priority(); + params.priority = static_cast(request.priority()); } if (request.has_service_request_id()) { diff --git a/xllm/core/framework/request/request_params.h b/xllm/core/framework/request/request_params.h index 2e779e9a..f34fe478 100644 --- a/xllm/core/framework/request/request_params.h +++ b/xllm/core/framework/request/request_params.h @@ -29,6 +29,7 @@ limitations under the License. 
#include "core/common/types.h" #include "embedding.pb.h" #include "multimodal.pb.h" +#include "request.h" #include "request_output.h" namespace xllm { @@ -130,7 +131,7 @@ struct RequestParams { int32_t slo_ms = 0; - xllm::proto::Priority priority = xllm::proto::Priority::NORMAL; + RequestPriority priority = RequestPriority::NORMAL; }; } // namespace xllm diff --git a/xllm/core/runtime/options.h b/xllm/core/runtime/options.h index e9647a8a..ee33618d 100644 --- a/xllm/core/runtime/options.h +++ b/xllm/core/runtime/options.h @@ -124,7 +124,7 @@ struct Options { PROPERTY(std::string, priority_strategy) = "FCFS"; - PROPERTY(bool, enable_on_preempt_off) = true; + PROPERTY(bool, enable_online_preempt_offline) = true; }; } // namespace runtime diff --git a/xllm/core/scheduler/chunked_prefill_scheduler.cpp b/xllm/core/scheduler/chunked_prefill_scheduler.cpp index aa0e4773..ee86c881 100644 --- a/xllm/core/scheduler/chunked_prefill_scheduler.cpp +++ b/xllm/core/scheduler/chunked_prefill_scheduler.cpp @@ -222,7 +222,7 @@ void ChunkedPrefillScheduler::handle_running_queue_requests( if (request_to_preempt.get() != request.get()) { ++num_preempted_requests; block_manager_pool_->deallocate(request_to_preempt.get()); - running_queue_.->pop_back(); + running_queue_->pop_back(); // add preemptable request to waiting priority queue request_to_preempt->set_preempted(); waiting_priority_queue_.push(request_to_preempt); diff --git a/xllm/core/scheduler/continuous_scheduler.cpp b/xllm/core/scheduler/continuous_scheduler.cpp index da1204e9..04489f65 100644 --- a/xllm/core/scheduler/continuous_scheduler.cpp +++ b/xllm/core/scheduler/continuous_scheduler.cpp @@ -96,11 +96,37 @@ void ContinuousScheduler::create_running_queue(const Options& options) { } } +bool ContinuousScheduler::check_if_enough_to_evict( + DecodePriorityQueue* running_queue_to_evict, + Sequence* prefill_sequence, + size_t& num_request_to_evict) { + // check if it's enough when we evict this requests queue + auto block_size = block_manager_pool_->options().block_size(); + const size_t num_blocks_needed = + (prefill_sequence->num_tokens() + block_size - 1) / block_size; + size_t num_blocks_can_evict = 0; + // count the number of blocks can be preempted + for (auto it = running_queue_to_evict->rbegin(); + it != running_queue_to_evict->rend(); + ++it) { + std::shared_ptr request_to_preempt = *it; + num_request_to_evict++; + // count the number of blocks belong to the request + for (const auto& seq : request_to_preempt->sequences()) { + num_blocks_can_evict += seq->kv_state().num_kv_blocks(); + } + if (num_blocks_needed <= num_blocks_can_evict) { + return true; + } + } + return false; +} + void ContinuousScheduler::handle_prefill_requests( size_t& remaining_token_budget, size_t& remaining_seq_budget, RequestPriorityQueue& waiting_priority_queue, - size_t& num_onp_preempt_off_requests, + size_t& num_online_prefill_preempt_offline_requests, std::vector>& finished_requests) { // Handle new request prompt first. // Include those requests that are preempted by others. 
@@ -158,20 +184,20 @@ void ContinuousScheduler::handle_prefill_requests( // preempt offline decode if (!block_manager_pool_->allocate(prefill_sequence.get())) { can_schedule = false; - if (options_.enable_on_preempt_off() && !request->offline() && + if (options_.enable_online_preempt_offline() && !request->offline() && !running_queue_offline_->empty()) { size_t num_request_to_evict = 0; // according to the prefill_sequence num tokens to check if can // allocate blocks for it through evict - bool enough_to_evict = block_manager_pool_->check_if_enough_to_evict( - running_queue_offline_.get(), - prefill_sequence.get(), - num_request_to_evict); + bool enough_to_evict = + check_if_enough_to_evict(running_queue_offline_.get(), + prefill_sequence.get(), + num_request_to_evict); if (enough_to_evict) { for (size_t i = 0; i < num_request_to_evict; ++i) { std::shared_ptr request_to_preempt = running_queue_offline_->back(); - ++num_onp_preempt_off_requests; + ++num_online_prefill_preempt_offline_requests; block_manager_pool_->deallocate(request_to_preempt.get()); running_queue_offline_->pop_back(); // add preemptable request to waiting priority queue @@ -222,14 +248,14 @@ void ContinuousScheduler::handle_prefill_requests( prefill_sequences_budget.begin(), prefill_sequences_budget.end()); } - // maybe can pre-compute if prompt beyond lnegth + // maybe can pre-compute if prompt beyond length if (running_sequences_.empty() && !waiting_priority_queue.empty() && - running_queue_->empty() && running_queue_offline_->empty()) { + running_queue_->empty()) { LOG(ERROR) << "Request prompt is too long, no enough memory to schedule " "a single sequence."; // no enough memory to schedule single sequence, just finish the request - std::shared_ptr request(waiting_priority_queue_.top()); - waiting_priority_queue_.pop(); + std::shared_ptr request(waiting_priority_queue.top()); + waiting_priority_queue.pop(); block_manager_pool_->deallocate(request.get()); response_processor_->process_failed_request( request, @@ -245,9 +271,9 @@ void ContinuousScheduler::handle_prefill_requests( void ContinuousScheduler::handle_decode_requests( size_t& remaining_token_budget, size_t& remaining_seq_budget, - size_t& num_offd_preempt_off_requests, - size_t& num_ond_preempt_on_requests, - size_t& num_ond_preempt_off_requests, + size_t& num_offline_decode_preempt_offline_requests, + size_t& num_online_decode_preempt_online_requests, + size_t& num_online_decode_preempt_offline_requests, std::unique_ptr& running_queue) { while (!running_queue->empty() && remaining_token_budget > options_.num_speculative_tokens() && @@ -291,10 +317,10 @@ void ContinuousScheduler::handle_decode_requests( } // update the allocated tokens for the sequence - allocated_tokens += options_.num_speculative_tokens() + 1; + allocated_tokens += options_.num_speculative_tokens(); allocated_seqs += 1; - candidate_sequences.push_back(sequence.get()); - candidate_token_budgets.push_back(options_.num_speculative_tokens() + 1); + candidate_sequences.emplace_back(sequence.get()); + candidate_token_budgets.emplace_back(options_.num_speculative_tokens()); } CHECK(allocated_tokens <= remaining_token_budget); CHECK(allocated_seqs <= remaining_seq_budget); @@ -333,13 +359,12 @@ void ContinuousScheduler::handle_decode_requests( // memory exhausted, try to preempt lowest priority request // continue to evict blocks until enough or no other requests that can be - // preempted TO IMPROVE: preplan if is enough to evict, if not, then not - // evict the offline request or online request 
with lowest priority - if (options_.enable_on_preempt_off() && !request->offline() && + // preempted + if (options_.enable_online_preempt_offline() && !request->offline() && !running_queue_offline_->empty()) { std::shared_ptr request_to_preempt = running_queue_offline_->back(); - ++num_ond_preempt_off_requests; + ++num_online_decode_preempt_offline_requests; block_manager_pool_->deallocate(request_to_preempt.get()); running_queue_offline_->pop_back(); // add preemptable request to waiting priority queue @@ -350,9 +375,9 @@ void ContinuousScheduler::handle_decode_requests( std::shared_ptr request_to_preempt = running_queue->back(); if (request_to_preempt.get() != request.get()) { if (request->offline()) { - ++num_offd_preempt_off_requests; + ++num_offline_decode_preempt_offline_requests; } else { - ++num_ond_preempt_on_requests; + ++num_online_decode_preempt_online_requests; } // TO IMPROVE: kv cache offload to cpu block_manager_pool_->deallocate(request_to_preempt.get()); @@ -520,23 +545,7 @@ std::vector ContinuousScheduler::prepare_batch() { *it = nullptr; } } - // process previous batch - // insert running requests back to the running queue, iterating from - // the highest priority to the lowest - // insert running requests back to the running queue, iterating from - // the highest priority to the lowest - // 1. last step is prefill step: - // new prefill has high priority, but these requests has lower priority - // then existed requests in running_queue_ in decoding stage. - // so we need to push them to the back of running_queue_-> - // 2. last step is decode step: - // We need to traverse running_requests_ array in reverse order. - // Because there may be some unexecuted requests with - // lower priorities remaining in the running_queue_-> - // For the requests in running_requests_, - // their priorities are all higher than those of the - // remaining requests. 
Therefore, insert all requests to the front of - // running_queue_ + if (options_.priority_strategy() == "FCFS") { if (last_step_prefill_) { // insert all requests to the back of running_queue_ @@ -583,11 +592,13 @@ std::vector ContinuousScheduler::prepare_batch() { } } } else { + // directly push running requests to the priority queue for (auto it = running_requests_.begin(); it != running_requests_.end(); ++it) { if (*it == nullptr) { continue; } + handle_running_requests(*it); if ((*it)->offline()) { running_queue_offline_->push(*it); } else { @@ -606,20 +617,20 @@ std::vector ContinuousScheduler::prepare_batch() { size_t remaining_token_budget = options_.max_tokens_per_batch(); size_t remaining_seq_budget = std::max(options_.max_seqs_per_batch(), 1); size_t num_preempted_requests = 0; - size_t num_offd_preempt_off_requests = 0; - size_t num_ond_preempt_on_requests = 0; - size_t num_onp_preempt_off_requests = 0; - size_t num_ond_preempt_off_requests = 0; + size_t num_offline_decode_preempt_offline_requests = 0; + size_t num_online_decode_preempt_online_requests = 0; + size_t num_online_prefill_preempt_offline_requests = 0; + size_t num_online_decode_preempt_offline_requests = 0; // TO IMPROVE?: handle online decode request before prefill offline request handle_prefill_requests(remaining_token_budget, remaining_seq_budget, waiting_priority_queue_, - num_onp_preempt_off_requests, + num_online_prefill_preempt_offline_requests, finished_requests); handle_prefill_requests(remaining_token_budget, remaining_seq_budget, waiting_priority_queue_offline_, - num_onp_preempt_off_requests, + num_online_prefill_preempt_offline_requests, finished_requests); if (running_sequences_.empty()) { @@ -628,21 +639,22 @@ std::vector ContinuousScheduler::prepare_batch() { // queue handle_decode_requests(remaining_token_budget, remaining_seq_budget, - num_offd_preempt_off_requests, - num_ond_preempt_on_requests, - num_ond_preempt_off_requests, + num_offline_decode_preempt_offline_requests, + num_online_decode_preempt_online_requests, + num_online_decode_preempt_offline_requests, running_queue_); handle_decode_requests(remaining_token_budget, remaining_seq_budget, - num_offd_preempt_off_requests, - num_ond_preempt_on_requests, - num_ond_preempt_off_requests, + num_offline_decode_preempt_offline_requests, + num_online_decode_preempt_online_requests, + num_online_decode_preempt_offline_requests, running_queue_offline_); } - num_preempted_requests = - num_offd_preempt_off_requests + num_ond_preempt_on_requests + - num_ond_preempt_off_requests + num_onp_preempt_off_requests; + num_preempted_requests = num_offline_decode_preempt_offline_requests + + num_online_decode_preempt_online_requests + + num_online_decode_preempt_offline_requests + + num_online_prefill_preempt_offline_requests; if (!finished_requests.empty()) { response_processor_->process_completed_requests(finished_requests); } @@ -663,10 +675,14 @@ std::vector ContinuousScheduler::prepare_batch() { waiting_priority_queue_.size() + running_queue_->size()); GAUGE_ADD(num_preempted_requests, num_preempted_requests); - GAUGE_ADD(num_offd_preempt_off_requests, num_offd_preempt_off_requests); - GAUGE_ADD(num_ond_preempt_on_requests, num_ond_preempt_on_requests); - GAUGE_ADD(num_onp_preempt_off_requests, num_onp_preempt_off_requests); - GAUGE_ADD(num_ond_preempt_off_requests, num_ond_preempt_off_requests); + GAUGE_ADD(num_offline_decode_preempt_offline_requests, + num_offline_decode_preempt_offline_requests); + GAUGE_ADD(num_online_decode_preempt_online_requests, + 
num_online_decode_preempt_online_requests); + GAUGE_ADD(num_online_prefill_preempt_offline_requests, + num_online_prefill_preempt_offline_requests); + GAUGE_ADD(num_online_decode_preempt_offline_requests, + num_online_decode_preempt_offline_requests); GAUGE_SET(num_running_sequences, running_sequences_.size()); @@ -694,7 +710,8 @@ std::vector ContinuousScheduler::schedule_request( } if (!waiting_priority_queue_.empty() || !running_queue_->empty() || - !waiting_priority_queue_offline_.empty()) { + !waiting_priority_queue_offline_.empty() || + !running_queue_offline_->empty()) { continue; } diff --git a/xllm/core/scheduler/continuous_scheduler.h b/xllm/core/scheduler/continuous_scheduler.h index 74f8b7ce..7acc9923 100644 --- a/xllm/core/scheduler/continuous_scheduler.h +++ b/xllm/core/scheduler/continuous_scheduler.h @@ -86,7 +86,7 @@ class ContinuousScheduler : public Scheduler { // TODO: think if distinguish prefill and decode priority strategy PROPERTY(std::string, priority_strategy) = "FCFS"; // priority, deadline, FCFS - PROPERTY(bool, enable_on_preempt_off) = true; + PROPERTY(bool, enable_online_preempt_offline) = true; }; ContinuousScheduler(Engine* engine, const Options& options); @@ -204,14 +204,14 @@ class ContinuousScheduler : public Scheduler { size_t& remaining_token_budget, size_t& remaining_seq_budget, RequestPriorityQueue& waiting_priority_queue, - size_t& num_onp_preempt_off_requests, + size_t& num_online_prefill_preempt_offline_requests, std::vector>& finished_requests); void handle_decode_requests( size_t& remaining_token_budget, size_t& remaining_seq_budget, - size_t& num_offd_preempt_off_requests, - size_t& num_ond_preempt_on_requests, - size_t& num_ond_preempt_off_requests, + size_t& num_offline_decode_preempt_offline_requests, + size_t& num_online_decode_preempt_online_requests, + size_t& num_online_decode_preempt_offline_requests, std::unique_ptr& running_queue); void handle_abnormal_request( std::unique_ptr& running_queue, @@ -242,6 +242,10 @@ class ContinuousScheduler : public Scheduler { void create_running_queue(const Options& options); + bool check_if_enough_to_evict(DecodePriorityQueue* running_queue_to_evict, + Sequence* prefill_sequence, + size_t& num_request_to_evict); + private: // tokenizer std::unique_ptr tokenizer_; diff --git a/xllm/core/scheduler/continuous_scheduler_test.cpp b/xllm/core/scheduler/continuous_scheduler_test.cpp index f677e75f..d3fe7359 100644 --- a/xllm/core/scheduler/continuous_scheduler_test.cpp +++ b/xllm/core/scheduler/continuous_scheduler_test.cpp @@ -111,15 +111,15 @@ std::vector> generate_request( false, nullptr, nullptr); - auto request = std::make_shared( - "1", - "1", - "1", - std::move(req_state), - "1", - offlines[i], - 0, - static_cast(priorities[i])); + auto request = + std::make_shared("1", + "1", + "1", + std::move(req_state), + "1", + offlines[i], + 0, + static_cast(priorities[i])); requests.emplace_back(request); } @@ -314,8 +314,10 @@ TEST(ContinuousSchedulerTest, PrioritySchedule) { EXPECT_TRUE(batch.size() == 1); EXPECT_TRUE(batch[0].size() == 2); EXPECT_TRUE(scheduler->get_running_requests().size() == 2); - EXPECT_TRUE(scheduler->get_running_requests()[0]->priority() == 2 /*NORMAL*/); - EXPECT_TRUE(scheduler->get_running_requests()[1]->priority() == 3 /*LOW*/); + EXPECT_TRUE(scheduler->get_running_requests()[0]->priority() == + RequestPriority::NORMAL /*NORMAL*/); + EXPECT_TRUE(scheduler->get_running_requests()[1]->priority() == + RequestPriority::LOW /*LOW*/); running_requests = scheduler->get_running_requests(); 
update_requests(running_requests); @@ -334,8 +336,10 @@ TEST(ContinuousSchedulerTest, PrioritySchedule) { EXPECT_TRUE(batch.size() == 1); EXPECT_TRUE(batch[0].size() == 2); EXPECT_TRUE(scheduler->get_running_requests().size() == 2); - EXPECT_TRUE(scheduler->get_running_requests()[0]->priority() == 1 /*HIGH*/); - EXPECT_TRUE(scheduler->get_running_requests()[1]->priority() == 2 /*NORMAL*/); + EXPECT_TRUE(scheduler->get_running_requests()[0]->priority() == + RequestPriority::HIGH /*HIGH*/); + EXPECT_TRUE(scheduler->get_running_requests()[1]->priority() == + RequestPriority::NORMAL /*NORMAL*/); } } // namespace xllm \ No newline at end of file diff --git a/xllm/core/scheduler/disagg_pd_scheduler.cpp b/xllm/core/scheduler/disagg_pd_scheduler.cpp index 85d5a8ea..2e0298cb 100644 --- a/xllm/core/scheduler/disagg_pd_scheduler.cpp +++ b/xllm/core/scheduler/disagg_pd_scheduler.cpp @@ -184,18 +184,19 @@ bool DisaggPDScheduler::add_request(std::shared_ptr& request) { void DisaggPDScheduler::dispatch_requests() { while (true) { std::vector> requests; - - auto poped_result = prefill_request_queue_.try_pop(); - // OPTIMIZE: later, try to read an online prefill request several times, and - // only read an offline prefill request when none has arrived for a long time. - if (!poped_result.has_value()) { + const auto timeout = absl::Milliseconds(100); + // Wait for an online request until the timeout expires. + // On timeout, try to get an offline request once; if there is none, go back + // to waiting for an online request. This keeps offline requests from + // blocking online requests for too long. + auto poped_result = prefill_request_queue_.pop(timeout); + if (!poped_result.has_value()) { // timed out waiting for an online request poped_result = prefill_request_queue_offline_.try_pop(); if (!poped_result.has_value()) { - // no offline request, sleep for a while and try again - absl::SleepFor(absl::Milliseconds(100)); continue; } } + auto request = poped_result.value(); if (request == nullptr) { // nullptr is a signal to exit break; diff --git a/xllm/core/util/concurrent_queue.h b/xllm/core/util/concurrent_queue.h index 909afdcb..14cc35d7 100644 --- a/xllm/core/util/concurrent_queue.h +++ b/xllm/core/util/concurrent_queue.h @@ -21,6 +21,8 @@ limitations under the License. #include #include +#include "absl/types/optional.h" + #if __has_attribute(guarded_by) #define GUARDED_BY(x) __attribute__((guarded_by(x))) #else @@ -74,6 +76,20 @@ class ConcurrentQueue { queue_.pop(); return value; } + // pop an element from the queue; block if the queue is empty, up to timeout + absl::optional pop(absl::Duration timeout) { + absl::MutexLock lock(&mutex_); + + auto not_empty = +[](std::queue* q) { return !q->empty(); }; + + if (mutex_.AwaitWithTimeout(absl::Condition(not_empty, &queue_), timeout)) { + T value = std::move(queue_.front()); + queue_.pop(); + return value; + } + + return absl::nullopt; + } std::optional try_pop() { absl::MutexLock lock(&mutex_); diff --git a/xllm/xllm.cpp b/xllm/xllm.cpp index 6e0815d5..42ef6a97 100644 --- a/xllm/xllm.cpp +++ b/xllm/xllm.cpp @@ -134,7 +134,7 @@ int run() { .enable_service_routing(FLAGS_enable_service_routing) .tool_call_parser(FLAGS_tool_call_parser) .priority_strategy(FLAGS_priority_strategy) - .enable_on_preempt_off(FLAGS_enable_on_preempt_off); + .enable_online_preempt_offline(FLAGS_enable_online_preempt_offline); InstanceName::name()->set_name(options.instance_name().value_or(""));
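For completeness, the timed-pop fallback that dispatch_requests now relies on can be modeled in isolation. The sketch below mirrors ConcurrentQueue::pop(absl::Duration) with std::condition_variable in place of absl::Mutex::AwaitWithTimeout (the 100 ms budget follows the code above; this is an illustration, not the shipped implementation):

#include <chrono>
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <optional>
#include <queue>

template <typename T>
class TimedQueue {
 public:
  void push(T value) {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      queue_.push(std::move(value));
    }
    cv_.notify_one();
  }
  // Block until an element arrives or the timeout expires.
  std::optional<T> pop(std::chrono::milliseconds timeout) {
    std::unique_lock<std::mutex> lock(mutex_);
    if (!cv_.wait_for(lock, timeout, [this] { return !queue_.empty(); })) {
      return std::nullopt;  // timed out, like AwaitWithTimeout returning false
    }
    T value = std::move(queue_.front());
    queue_.pop();
    return value;
  }
  std::optional<T> try_pop() {
    std::lock_guard<std::mutex> lock(mutex_);
    if (queue_.empty()) return std::nullopt;
    T value = std::move(queue_.front());
    queue_.pop();
    return value;
  }

 private:
  std::mutex mutex_;
  std::condition_variable cv_;
  std::queue<T> queue_;
};

int main() {
  TimedQueue<int> online, offline;
  offline.push(42);
  // Prefer online work; only after a 100 ms timeout take one offline item,
  // the same policy as dispatch_requests above.
  auto req = online.pop(std::chrono::milliseconds(100));
  if (!req) req = offline.try_pop();
  if (req) std::cout << "dispatching request " << *req << "\n";  // prints 42
}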