From 17962bd63f397b20ce7542761da3e9a058c7f678 Mon Sep 17 00:00:00 2001 From: jindonghe1 Date: Tue, 19 Aug 2025 20:29:11 +0800 Subject: [PATCH 1/4] feat: support expert dynamic load balancing for DeepSeek. --- xllm/core/common/global_flags.cpp | 6 + xllm/core/common/global_flags.h | 6 + xllm/core/common/options.h | 6 + xllm/core/common/types.h | 6 + .../distributed_runtime/worker_service.cpp | 40 +- xllm/core/framework/CMakeLists.txt | 1 + xllm/core/framework/eplb/CMakeLists.txt | 50 ++ xllm/core/framework/eplb/eplb_executor.cpp | 114 ++++ xllm/core/framework/eplb/eplb_executor.h | 49 ++ xllm/core/framework/eplb/eplb_manager.cpp | 233 +++++++++ xllm/core/framework/eplb/eplb_manager.h | 72 +++ xllm/core/framework/eplb/eplb_policy.cpp | 156 ++++++ xllm/core/framework/eplb/eplb_policy.h | 31 ++ xllm/core/framework/eplb/eplb_policy_test.cpp | 23 + .../framework/eplb/expert_buffer_manager.cpp | 37 ++ .../framework/eplb/expert_buffer_manager.h | 33 ++ .../eplb/expert_weight_buffer_shm.cpp | 239 +++++++++ .../framework/eplb/expert_weight_buffer_shm.h | 78 +++ .../framework/eplb/shared_memory_manager.cpp | 99 ++++ .../framework/eplb/shared_memory_manager.h | 38 ++ xllm/core/framework/model/causal_lm.h | 14 + xllm/core/framework/model/causal_vlm.h | 7 + .../core/framework/model/model_input_params.h | 3 +- .../framework/model/npu_dp_ep_padding.cpp | 11 +- xllm/core/framework/model/npu_dp_ep_padding.h | 2 + .../model/npu_dp_ep_padding_test.cpp | 8 +- xllm/core/layers/npu/CMakeLists.txt | 1 + .../layers/npu/deepseek_v2_decoder_layer.cpp | 486 ++++++++++++++---- .../layers/npu/deepseek_v2_decoder_layer.h | 110 +++- xllm/core/runtime/CMakeLists.txt | 1 + xllm/core/runtime/forward_params.h | 9 + xllm/core/runtime/llm_engine.cpp | 102 +++- xllm/core/runtime/llm_engine.h | 9 + xllm/core/runtime/llm_worker_impl.cpp | 17 +- xllm/core/runtime/master.cpp | 11 +- xllm/core/runtime/params_utils.cpp | 36 ++ xllm/core/runtime/params_utils.h | 2 + xllm/core/runtime/worker_impl.cpp | 20 +- xllm/core/runtime/worker_impl.h | 5 + xllm/models/deepseek_v2.h | 27 + xllm/models/deepseek_v2_mtp.h | 5 + xllm/models/llama.h | 6 + xllm/models/qwen3_moe.h | 6 + xllm/models/qwen_base.h | 6 + xllm/proto/worker.proto | 9 + xllm/xllm.cpp | 3 + 46 files changed, 2100 insertions(+), 133 deletions(-) create mode 100644 xllm/core/framework/eplb/CMakeLists.txt create mode 100644 xllm/core/framework/eplb/eplb_executor.cpp create mode 100644 xllm/core/framework/eplb/eplb_executor.h create mode 100644 xllm/core/framework/eplb/eplb_manager.cpp create mode 100644 xllm/core/framework/eplb/eplb_manager.h create mode 100644 xllm/core/framework/eplb/eplb_policy.cpp create mode 100644 xllm/core/framework/eplb/eplb_policy.h create mode 100644 xllm/core/framework/eplb/eplb_policy_test.cpp create mode 100644 xllm/core/framework/eplb/expert_buffer_manager.cpp create mode 100644 xllm/core/framework/eplb/expert_buffer_manager.h create mode 100644 xllm/core/framework/eplb/expert_weight_buffer_shm.cpp create mode 100644 xllm/core/framework/eplb/expert_weight_buffer_shm.h create mode 100644 xllm/core/framework/eplb/shared_memory_manager.cpp create mode 100644 xllm/core/framework/eplb/shared_memory_manager.h diff --git a/xllm/core/common/global_flags.cpp b/xllm/core/common/global_flags.cpp index cd94dbf9..5bc72353 100644 --- a/xllm/core/common/global_flags.cpp +++ b/xllm/core/common/global_flags.cpp @@ -116,6 +116,12 @@ DEFINE_double(prefill_scheduling_memory_usage_threshold, DEFINE_string(communication_backend, "hccl", "npu communication backend."); 
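For reference, the three EPLB flags added just below interact as follows: enable_eplb gates the feature, eplb_update_rate is the minimum number of seconds between rebalancing decisions (compared against absl::ToUnixSeconds timestamps in EplbManager::rebalance_experts_loop), and eplb_update_threshold is the cosine-similarity cutoff below which a layer's expert placement is recomputed. A minimal sketch of that drift check, mirroring the logic in EplbPolicy::rebalance_experts further down; the free function layer_needs_rebalance and its threshold parameter are illustrative only, not part of the patch:

// Sketch only; `threshold` stands in for FLAGS_eplb_update_threshold.
#include <torch/torch.h>

bool layer_needs_rebalance(const torch::Tensor& current_load,
                           const torch::Tensor& old_load,
                           double threshold) {
  namespace F = torch::nn::functional;
  // Normalize both per-expert load vectors by their max so only the shape of
  // the distribution matters, then compare them with cosine similarity.
  auto cur = current_load.to(torch::kFloat64);
  auto prev = old_load.to(torch::kFloat64);
  cur = (cur / (torch::max(cur).item<double>() + 1e-6)).unsqueeze(0);
  prev = (prev / (torch::max(prev).item<double>() + 1e-6)).unsqueeze(0);
  double cos_sim =
      F::cosine_similarity(cur, prev, F::CosineSimilarityFuncOptions().dim(1))
          .item<double>();
  return cos_sim < threshold;  // drifted enough -> recompute this layer's placement
}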
+DEFINE_bool(enable_eplb, false, "Whether to use ep load balance."); + +DEFINE_int64(eplb_update_rate, 1000, "eplb update rate."); + +DEFINE_double(eplb_update_threshold, 0.8, "eplb update threshold."); + DEFINE_string(rank_tablefile, "", "atb hccl rank table file."); DEFINE_int32(expert_parallel_degree, 0, "ep degree"); diff --git a/xllm/core/common/global_flags.h b/xllm/core/common/global_flags.h index 563a42c0..6c588ba5 100644 --- a/xllm/core/common/global_flags.h +++ b/xllm/core/common/global_flags.h @@ -67,6 +67,12 @@ DECLARE_int32(num_response_handling_threads); DECLARE_string(communication_backend); +DECLARE_bool(enable_eplb); + +DECLARE_int64(eplb_update_rate); + +DECLARE_double(eplb_update_threshold); + DECLARE_string(rank_tablefile); DECLARE_bool(enable_mla); diff --git a/xllm/core/common/options.h b/xllm/core/common/options.h index 9ea5a596..cd73dd37 100644 --- a/xllm/core/common/options.h +++ b/xllm/core/common/options.h @@ -70,6 +70,12 @@ class Options { // thread num to handle requests PROPERTY(size_t, num_handling_threads) = 4; + PROPERTY(std::optional, enable_eplb); + + PROPERTY(std::optional, eplb_update_rate); + + PROPERTY(std::optional, eplb_update_threshold); + PROPERTY(std::optional, communication_backend); PROPERTY(std::optional, rank_tablefile); diff --git a/xllm/core/common/types.h b/xllm/core/common/types.h index ab3d5f3d..3f6f218b 100644 --- a/xllm/core/common/types.h +++ b/xllm/core/common/types.h @@ -252,4 +252,10 @@ struct JsonTool { : type(tool_type), function(func) {} }; +struct EplbInfo { + int32_t prepare_layer_id = -1; + std::vector expert_ids; + int32_t update_layer_id = -1; +}; + } // namespace xllm diff --git a/xllm/core/distributed_runtime/worker_service.cpp b/xllm/core/distributed_runtime/worker_service.cpp index df95cbf2..a8de3ecd 100644 --- a/xllm/core/distributed_runtime/worker_service.cpp +++ b/xllm/core/distributed_runtime/worker_service.cpp @@ -31,6 +31,7 @@ limitations under the License. 
#include #endif +#include "common/global_flags.h" #include "common/metrics.h" #include "framework/request/sequence.h" #include "framework/sampling/sampling_params.h" @@ -331,6 +332,8 @@ void WorkerService::ExecuteModel(::google::protobuf::RpcController* controller, torch::Tensor top_tokens; torch::Tensor top_logprobs; torch::Tensor embeddings; + torch::Tensor expert_load_data; + int32_t prepared_layer_id = -1; // execute model auto future = worker_->step_async(forward_inputs); @@ -341,6 +344,10 @@ void WorkerService::ExecuteModel(::google::protobuf::RpcController* controller, if (forward_outputs) { DCHECK(forward_outputs.has_value()) << "Failed to execute model"; const auto& sample_output = forward_outputs.value().sample_output; + expert_load_data = safe_to( + forward_outputs.value().expert_load_data, torch::kCPU, true); + prepared_layer_id = forward_outputs.value().prepared_layer_id; + { #if defined(USE_NPU) c10::StreamGuard streamGuard( @@ -376,15 +383,19 @@ void WorkerService::ExecuteModel(::google::protobuf::RpcController* controller, #endif } } - } else if (worker_->is_driver()) { - // construct fake output tensor - auto options = - torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU); - int32_t prefill_seq_len = - static_cast(pb_forward_input->prefill_seq_len()); - next_tokens = torch::arange( - -1, -1 * (num_sequences - prefill_seq_len + 1), -1, options); - std::move(future).deferValue([](auto&&) {}); + } else { + if (worker_->is_driver()) { + // construct fake output tensor + auto options = + torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU); + int32_t prefill_seq_len = + static_cast(pb_forward_input->prefill_seq_len()); + next_tokens = torch::arange( + -1, -1 * (num_sequences - prefill_seq_len + 1), -1, options); + std::move(future).deferValue([](auto&&) {}); + } + expert_load_data = + torch::zeros({1, 1}).to(torch::kInt64).contiguous(); } forward_output_to_proto(next_tokens, @@ -392,6 +403,8 @@ void WorkerService::ExecuteModel(::google::protobuf::RpcController* controller, top_tokens, top_logprobs, embeddings, + expert_load_data, + prepared_layer_id, pb_forward_output); COUNTER_ADD(worker_service_latency_seconds, timer.elapsed_seconds()); }); @@ -415,6 +428,9 @@ void WorkerService::GetLastStepResult( auto forward_outputs = std::move(future).get(); if (forward_outputs) { const auto& sample_output = forward_outputs.value().sample_output; + const auto& expert_load_data = safe_to( + forward_outputs.value().expert_load_data, torch::kCPU, true); + int32_t prepared_layer_id = forward_outputs.value().prepared_layer_id; #if defined(USE_NPU) c10::StreamGuard streamGuard( npu_stream_helper_->D2H_memcpy_stream.unwrap()); @@ -429,8 +445,8 @@ void WorkerService::GetLastStepResult( // [num_seq] const auto& next_tokens = safe_to(sample_output.next_tokens, torch::kCPU, true); - if (next_tokens.defined()) { - // [num_seq] + if (next_tokens.defined() || FLAGS_enable_eplb) { + // [num_seq] FloatTensor const auto& logprobs = safe_to(sample_output.logprobs, torch::kCPU, true); // [num_seq, topk] @@ -451,6 +467,8 @@ void WorkerService::GetLastStepResult( top_tokens, top_logprobs, embeddings, + expert_load_data, + prepared_layer_id, pb_forward_output); } } diff --git a/xllm/core/framework/CMakeLists.txt b/xllm/core/framework/CMakeLists.txt index c982ebc3..1b1a0d9c 100644 --- a/xllm/core/framework/CMakeLists.txt +++ b/xllm/core/framework/CMakeLists.txt @@ -13,6 +13,7 @@ add_subdirectory(request) add_subdirectory(sampling) add_subdirectory(state_dict) add_subdirectory(tokenizer) 
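A note on the EplbInfo struct added to common/types.h above: it carries the per-layer rebalancing handshake between the scheduler and the workers. The element type of its vector (dropped from the hunk above) is presumably int32_t, matching the int32 expert-distribution tensor it is filled from in EplbManager::get_eplb_info(); a reconstruction under that assumption:

#include <cstdint>
#include <vector>

struct EplbInfo {
  // Layer whose new expert placement should be staged next (-1 = none).
  int32_t prepare_layer_id = -1;
  // Flattened [device, experts_per_device] placement for that layer.
  std::vector<int32_t> expert_ids;
  // Layer whose staged weights are ready to be swapped in (-1 = none).
  int32_t update_layer_id = -1;
};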
+add_subdirectory(eplb) cc_library( NAME diff --git a/xllm/core/framework/eplb/CMakeLists.txt b/xllm/core/framework/eplb/CMakeLists.txt new file mode 100644 index 00000000..4f24ac2f --- /dev/null +++ b/xllm/core/framework/eplb/CMakeLists.txt @@ -0,0 +1,50 @@ +include(cc_binary) +include(cc_library) +include(cc_test) + +include_directories( + ${CMAKE_SOURCE_DIR}/xllm/core/kernels/ascend + ${CMAKE_SOURCE_DIR}/xllm/core/kernels/ascend/core/include +) + +cc_library( + NAME + eplb + HDRS + eplb_executor.h + eplb_manager.h + eplb_policy.h + expert_weight_buffer_shm.h + shared_memory_manager.h + expert_buffer_manager.h + SRCS + eplb_executor.cpp + eplb_manager.cpp + eplb_policy.cpp + expert_weight_buffer_shm.cpp + shared_memory_manager.cpp + expert_buffer_manager.cpp + DEPS + torch_npu + llm_engine + :request + :common + glog::glog + torch +) + +set(TEST_SRCS + eplb_policy_test.cpp +) + +cc_test( + NAME + eplb_policy_test + SRCS + ${TEST_SRCS} + DEPS + torch + :eplb + GTest::gtest_main +) + diff --git a/xllm/core/framework/eplb/eplb_executor.cpp b/xllm/core/framework/eplb/eplb_executor.cpp new file mode 100644 index 00000000..00fb0ef2 --- /dev/null +++ b/xllm/core/framework/eplb/eplb_executor.cpp @@ -0,0 +1,114 @@ +#include "eplb_executor.h" + +#include +#include +#include +#if defined(USE_NPU) +#include +#include +#include +#include +#endif +#include +#include +#include +#include +#include +#include + +#include "runtime/forward_params.h" + +namespace xllm { +#if defined(USE_NPU) +struct EplbExecutor::EplbStream { + c10_npu::NPUStream eplb_stream; + EplbStream() : eplb_stream(c10_npu::getNPUStreamFromPool()) {} +}; +#endif +EplbExecutor::EplbExecutor(CausalLM* model) + : model_(model), eplb_worker_(&EplbExecutor::eplb_worker_loop, this) { +#if defined(USE_NPU) + eplb_stream_ = std::make_unique(); +#endif +} + +EplbExecutor::~EplbExecutor() { + { + std::unique_lock lock(queue_mutex_); + stop_ = true; + } + condition_.notify_one(); + if (eplb_worker_.joinable()) { + eplb_worker_.join(); + } +} + +void EplbExecutor::eplb_execute(const EplbInfo& eplb_info) { + if (eplb_info.update_layer_id != -1) { + model_->update_expert_weight(eplb_info.update_layer_id); + }; + if (eplb_info.prepare_layer_id != -1) { + prepare_expert_weight_async( + eplb_info.prepare_layer_id, + eplb_info.expert_ids, + [eplb_info](int32_t id) { + LOG(INFO) << "prepare expert weight complete, layer: " + << eplb_info.prepare_layer_id << std::endl; + }); + }; +} + +void EplbExecutor::prepare_expert_weight_async( + int32_t layer_id, + const std::vector& expert_ids, + Callback callback) { + { + std::unique_lock lock(queue_mutex_); + tasks_.emplace(Task{layer_id, expert_ids, callback}); + } + condition_.notify_one(); +} + +int32_t EplbExecutor::get_ready_layer_id() const { + std::lock_guard lock(ready_mutex_); + return ready_layer_id_; +} + +void EplbExecutor::reset_ready_layer_id() { + std::lock_guard lock(ready_mutex_); + ready_layer_id_ = -1; +} + +void EplbExecutor::eplb_worker_loop() { + while (true) { + Task task; + { + std::unique_lock lock(queue_mutex_); + condition_.wait(lock, [this] { return !tasks_.empty() || stop_; }); + if (stop_) return; + task = std::move(tasks_.front()); + tasks_.pop(); + } + auto prepare_start = std::chrono::high_resolution_clock::now(); + + c10::StreamGuard streamGuard(eplb_stream_->eplb_stream.unwrap()); + model_->prepare_expert_weight(task.layer_id, task.expert_ids); + aclrtSynchronizeStream(eplb_stream_->eplb_stream.stream()); + auto prepare_end = std::chrono::high_resolution_clock::now(); + auto 
prepare_duration = + std::chrono::duration_cast(prepare_end - + prepare_start) + .count(); + LOG(INFO) << "prepare_expert_weight | layer=" << task.layer_id + << " | experts=" << task.expert_ids.size() + << " | duration=" << prepare_duration << "ms"; + { + std::lock_guard lock(ready_mutex_); + ready_layer_id_ = task.layer_id; + } + if (task.callback) { + task.callback(task.layer_id); + } + } +} +} // namespace xllm diff --git a/xllm/core/framework/eplb/eplb_executor.h b/xllm/core/framework/eplb/eplb_executor.h new file mode 100644 index 00000000..cbf37f19 --- /dev/null +++ b/xllm/core/framework/eplb/eplb_executor.h @@ -0,0 +1,49 @@ +#pragma once + +#include + +#include +#include + +#include "common/macros.h" +#include "framework/model/causal_lm.h" +#include "framework/model/model_input_params.h" +#include "runtime/forward_params.h" + +namespace xllm { + +class EplbExecutor final { + public: + using Callback = std::function; + EplbExecutor(CausalLM* model); + + virtual ~EplbExecutor(); + void reset_ready_layer_id(); + int32_t get_ready_layer_id() const; + void eplb_execute(const EplbInfo& eplb_info); + + private: + struct Task { + int32_t layer_id; + std::vector expert_ids; + Callback callback; + }; + + void eplb_worker_loop(); + void prepare_expert_weight_async(int32_t layer_id, + const std::vector& expert_ids, + Callback callback = nullptr); + CausalLM* model_; + std::thread eplb_worker_; + std::queue tasks_; + std::mutex queue_mutex_; + std::condition_variable condition_; + bool stop_ = false; + + mutable std::mutex ready_mutex_; + int32_t ready_layer_id_ = -1; + struct EplbStream; + std::unique_ptr eplb_stream_; +}; + +} // namespace xllm diff --git a/xllm/core/framework/eplb/eplb_manager.cpp b/xllm/core/framework/eplb/eplb_manager.cpp new file mode 100644 index 00000000..2c02bfde --- /dev/null +++ b/xllm/core/framework/eplb/eplb_manager.cpp @@ -0,0 +1,233 @@ + +#include "eplb_manager.h" + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "common/device_memory.h" +#include "common/global_flags.h" + +namespace xllm { + +using namespace std::chrono_literals; + +EplbManager::EplbManager(EplbPolicy* eplb_policy, + int32_t layer_num, + int32_t device_num, + int32_t experts_num) + : eplb_policy_(eplb_policy), + layer_num_(layer_num), + device_num_(device_num), + experts_num_(experts_num), + device_experts_num_((experts_num + device_num) / device_num) { + // Initialize tensors with mutex protection + { + std::lock_guard lock(state_.mtx); + state_.expert_load = + torch::zeros({layer_num_, experts_num_}, torch::kInt64); + state_.prepared_layer_id.resize(device_num, -1); + state_.expert_distribution = torch::zeros( + {layer_num_, device_num_, device_experts_num_}, torch::kInt32); + for (int32_t layer = 0; layer < layer_num_; ++layer) { + for (int32_t device = 0; device < device_num_; ++device) { + int32_t base = device * (device_experts_num_ - 1); + for (int32_t expert = 0; expert < device_experts_num_; ++expert) { + int32_t value = base + expert; + if (expert == device_experts_num_ - 1) { + --value; + } + state_.expert_distribution[layer][device][expert] = value; + } + } + } + } + + // Start worker threads + rebalance_thread_ = std::thread(&EplbManager::rebalance_experts_loop, this); + manager_thread_ = std::thread(&EplbManager::eplb_manager_loop, this); +} + +EplbManager::~EplbManager() { + { + std::lock_guard lock(state_.mtx); + state_.stop = true; + state_.data_cv.notify_all(); + state_.state_cv.notify_all(); + } + + if 
(rebalance_thread_.joinable()) rebalance_thread_.join(); + if (manager_thread_.joinable()) manager_thread_.join(); +} + +void EplbManager::update_expert_load( + const std::vector expert_load) { + std::lock_guard lock(state_.mtx); + state_.expert_load_queue.push(expert_load); + state_.data_cv.notify_one(); +} + +void EplbManager::aggregate_multi_layer_expert_loads( + torch::Tensor& expert_load, + torch::Tensor& expert_ids_list, + std::vector& expert_loads_list) { + auto options = torch::TensorOptions().dtype(torch::kInt32); + + for (int32_t device = 0; device < device_num_; ++device) { + using namespace torch::indexing; + torch::Tensor expert_load_data_right = expert_loads_list[device].slice( + 1, 1, expert_loads_list[device].size(1)); + torch::Tensor expert_load_data_left = expert_loads_list[device].slice( + 1, 0, expert_loads_list[device].size(1) - 1); + torch::Tensor expert_load_data_sub = + expert_load_data_right - expert_load_data_left; + torch::Tensor first_col = + expert_loads_list[device].select(1, 0).unsqueeze(1); + + expert_loads_list[device] = + torch::cat({first_col, expert_load_data_sub}, 1); + } + + for (int32_t layer = 0; layer < layer_num_; ++layer) { + std::vector layer_ids, layer_loads; + for (int32_t device = 0; device < device_num_; ++device) { + auto ids = expert_ids_list[layer][device]; + auto loads = expert_loads_list[device][layer]; + + layer_ids.emplace_back(ids.flatten().to(torch::kInt64)); + layer_loads.emplace_back(loads.flatten().to(torch::kInt64)); + } + + torch::Tensor all_ids = torch::cat(layer_ids); + torch::Tensor all_loads = torch::cat(layer_loads); + expert_load[layer].scatter_add_(0, all_ids, all_loads); + } +} + +void EplbManager::rebalance_experts_loop() { + int64_t latest_record_time = absl::ToUnixSeconds(absl::Now()); + while (true) { + std::vector> expert_load_batch; + { + std::unique_lock lock(state_.mtx); + state_.data_cv.wait(lock, [&] { + return state_.stop || !state_.expert_load_queue.empty(); + }); + + if (state_.stop) return; + + while (!state_.expert_load_queue.empty()) { + // expert_load_batch.emplace_back(state_.expert_load_queue.front()); + // state_.expert_load_queue.pop(); + aggregate_multi_layer_expert_loads(state_.expert_load, + state_.expert_distribution, + state_.expert_load_queue.front()); + state_.expert_load_queue.pop(); + int64_t current_time = absl::ToUnixSeconds(absl::Now()); + if (current_time - latest_record_time >= FLAGS_eplb_update_rate) { + latest_record_time = current_time; + auto result = eplb_policy_->rebalance_experts(state_.expert_load); + state_.expert_distribution = result.first; + state_.enable_update_vec = result.second; + state_.expert_load = torch::div(state_.expert_load, 2, "trunc"); + state_.to_be_prepared = find_next_true(state_.enable_update_vec, 0); + state_.state_cv.notify_all(); + } + } + } + } +} + +size_t EplbManager::find_next_true(const std::vector& vec, + size_t start_pos) { + if (start_pos >= vec.size()) return static_cast(-1); + auto begin = vec.begin() + start_pos; + auto it = std::find(begin, vec.end(), true); + return (it != vec.end()) ? 
static_cast(it - vec.begin()) + : static_cast(-1); +} + +void EplbManager::eplb_manager_loop() { + while (true) { + { + std::unique_lock lock(state_.mtx); + state_.state_cv.wait( + lock, [&] { return state_.to_be_prepared != -1 || state_.stop; }); + + if (state_.stop) { + return; + } + } + while (true) { + { + std::unique_lock lock(state_.mtx); + // Update preparation status + if (state_.to_be_prepared >= 0) { + bool all_prepared = true; + for (auto& layer_id : state_.prepared_layer_id) { + if (layer_id != state_.to_be_prepared) { + all_prepared = false; + break; + } + } + if (all_prepared) { + state_.ready_layer_id = state_.to_be_prepared; + // state_.preparing_layer_id = state_.to_be_prepared; + state_.to_be_prepared = find_next_true(state_.enable_update_vec, + ++state_.to_be_prepared); + if (state_.to_be_prepared == -1) { + state_.preparing_layer_id = -1; + } + } + } + if (state_.to_be_prepared < 0) { + break; + } + } + } + } +} + +EplbInfo EplbManager::get_eplb_info() { + EplbInfo info; + { + std::lock_guard lock(state_.mtx); + info.update_layer_id = state_.ready_layer_id; + if (state_.preparing_layer_id != state_.to_be_prepared && + state_.to_be_prepared != -1) { + info.prepare_layer_id = state_.to_be_prepared; + torch::Tensor distribution = + state_.expert_distribution[state_.to_be_prepared].contiguous(); + info.expert_ids = + std::vector(distribution.data_ptr(), + distribution.data_ptr() + distribution.numel()); + state_.preparing_layer_id = state_.to_be_prepared; + } else { + info.prepare_layer_id = -1; + } + state_.ready_layer_id = -1; + } + return info; +} + +void EplbManager::set_prepared_layer_ids( + const std::vector& expert_layer_ids) { + std::lock_guard lock(state_.mtx); + for (size_t i = 0; + i < expert_layer_ids.size() && i < state_.prepared_layer_id.size(); + ++i) { + if (expert_layer_ids[i] == state_.to_be_prepared) { + state_.prepared_layer_id[i] = expert_layer_ids[i]; + } + } +} + +} // namespace xllm diff --git a/xllm/core/framework/eplb/eplb_manager.h b/xllm/core/framework/eplb/eplb_manager.h new file mode 100644 index 00000000..0b251c4b --- /dev/null +++ b/xllm/core/framework/eplb/eplb_manager.h @@ -0,0 +1,72 @@ +// eplb_manager.h +#pragma once + +#include +#include +#include +#include + +#include "eplb_executor.h" +#include "eplb_policy.h" +#include "framework/model/model_input_params.h" +namespace xllm { + +class EplbManager { + public: + EplbManager(EplbPolicy* eplb_policy, + int32_t layer_num, + int32_t device_num, + int32_t experts_num); + ~EplbManager(); + + void update_expert_load(const std::vector expert_load); + EplbInfo get_eplb_info(); + void set_prepared_layer_ids(const std::vector& expert_layer_ids); + + private: + // Thread functions + void rebalance_experts_loop(); + void eplb_manager_loop(); + size_t find_next_true(const std::vector& vec, size_t start_pos); + // Shared data with mutex protection + struct ThreadSafeData { + std::mutex mtx; + std::condition_variable data_cv; + std::condition_variable state_cv; + bool stop = false; + + // Expert load tracking + torch::Tensor expert_load; + torch::Tensor expert_distribution; + std::vector enable_update_vec; + std::queue> expert_load_queue; + + // Layer state tracking + int32_t to_be_prepared = -1; + std::vector prepared_layer_id; + int32_t ready_layer_id = -1; + int32_t preparing_layer_id = -1; + }; + + // Components + EplbPolicy* eplb_policy_; + ThreadSafeData state_; + + // Constants + const int32_t layer_num_; + const int32_t device_num_; + const int32_t experts_num_; + const int32_t 
device_experts_num_; + + // Threads + std::thread rebalance_thread_; + std::thread manager_thread_; + + // Internal functions + void aggregate_multi_layer_expert_loads( + torch::Tensor& expert_load, + torch::Tensor& expert_ids_list, + std::vector& expert_loads_list); +}; + +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/framework/eplb/eplb_policy.cpp b/xllm/core/framework/eplb/eplb_policy.cpp new file mode 100644 index 00000000..3323f3b0 --- /dev/null +++ b/xllm/core/framework/eplb/eplb_policy.cpp @@ -0,0 +1,156 @@ +#include "eplb_policy.h" + +#include +#include +#include + +#include "common/global_flags.h" + +namespace xllm { + +EplbPolicy::EplbPolicy(int32_t device_experts_num, + int32_t device_num, + int32_t layer_num) + : device_experts_num_(device_experts_num), + device_num_(device_num), + layer_num_(layer_num) { + old_expert_load_ = + torch::zeros({layer_num_, device_experts_num * device_num - device_num}, + torch::kInt64); + expert_distribution_ = torch::full( + {layer_num_, device_num_, device_experts_num_}, -1, torch::kInt32); +} + +std::pair> EplbPolicy::rebalance_experts( + torch::Tensor expert_load) { + std::vector enable_update_vec(layer_num_, false); + for (int64_t i = 0; i < layer_num_; ++i) { + auto current_load = expert_load[i].to(torch::kFloat64); + auto prev_load = old_expert_load_[i].to(torch::kFloat64); + + auto current_max_val = torch::max(current_load).item() + 1e-6f; + auto prev_max_val = torch::max(prev_load).item() + 1e-6f; + + current_load = (current_load / current_max_val).unsqueeze(0); + ; + prev_load = (prev_load / prev_max_val).unsqueeze(0); + ; + + auto cos_sim = + torch::nn::functional::cosine_similarity( + current_load, + prev_load, + torch::nn::functional::CosineSimilarityFuncOptions().dim(1)) + .item(); + if (cos_sim < FLAGS_eplb_update_threshold) { + enable_update_vec[i] = true; + old_expert_load_[i] = expert_load[i]; + } + } + + for (int64_t i = 0; i < layer_num_; ++i) { + if (enable_update_vec[i]) { + auto balanced = compute_balanced_pack(expert_load[i]); + expert_distribution_.index_put_({i}, balanced); + } + } + expert_distribution_ = expert_distribution_.contiguous(); + return {expert_distribution_, enable_update_vec}; +} + +torch::Tensor EplbPolicy::compute_balanced_pack( + const torch::Tensor& expert_loads) { + // Parameter Validation + TORCH_CHECK(expert_loads.dim() == 1, "expert_loads must be 1D tensor"); + const int64_t num_experts = expert_loads.size(0); + + // Generate Redundant Experts + auto [updated_weights, redundancy_map] = + update_origin_weights(expert_loads, device_num_); + + // Initialize Allocation Matrix + auto options = torch::TensorOptions().dtype(torch::kInt64); + auto device_assignments = + torch::full({device_num_, device_experts_num_}, -1, options); + auto device_loads = torch::zeros({device_num_}, torch::kInt64); + + // Assign Redundant Experts + for (int64_t origin_id = 0; origin_id < num_experts; ++origin_id) { + auto redundant_ids = redundancy_map[origin_id]; + for (int64_t i = 0; i < redundant_ids.size(0); ++i) { + if (redundant_ids[i].item() == -1) { + break; + } + auto min_idx = torch::argmin(device_loads).item(); + auto available_pos = torch::nonzero(device_assignments[min_idx] == -1); + if (available_pos.size(0) == 0) { + throw std::runtime_error("Device " + std::to_string(min_idx) + + " is full"); + } + auto pos = available_pos.select(0, 0).item(); + + device_assignments[min_idx][pos] = origin_id; + device_loads[min_idx] += updated_weights[origin_id].item(); + } + } + + // Assign Primary 
Experts + auto sorted_indices = torch::argsort(-updated_weights); + for (int64_t i = 0; i < sorted_indices.size(0); ++i) { + auto expert_id = sorted_indices[i].item(); + auto weight = updated_weights[expert_id].item(); + + auto candidate = (device_assignments == -1).sum(1) > 0; + if (candidate.sum().item() == 0) break; + + auto valid_devices_vec = torch::where(candidate); + auto valid_devices = valid_devices_vec[0]; + + auto min_idx = torch::argmin(device_loads.index({valid_devices})); + auto target_device = valid_devices[min_idx].item(); + + auto pos = torch::nonzero(device_assignments[target_device] == -1); + if (pos.size(0) == 0) { + throw std::runtime_error("Target device " + + std::to_string(target_device) + " is full"); + } + auto pos_idx = pos.select(0, 0).item(); + device_assignments[target_device][pos_idx] = expert_id; + device_loads[target_device] += weight; + } + + return device_assignments; +} + +std::pair EplbPolicy::update_origin_weights( + torch::Tensor expert_loads, + int32_t redundancy_experts) { + // Parameter Validation + TORCH_CHECK(expert_loads.dim() == 1, "expert_loads must be 1D tensor"); + const int64_t num_experts = expert_loads.size(0); + + // Initialize Data Structures + auto redundancy_map = + torch::full({num_experts, redundancy_experts}, -1, torch::kInt64); + auto current_weights = expert_loads.clone(); + + // Dynamic Weight Adjustment + for (int i = 0; i < redundancy_experts; ++i) { + auto max_idx = torch::argmax(current_weights).item(); + auto redundancy_count = + torch::sum(redundancy_map[max_idx] != -1).item() + 1; + + // Update redundancy mapping + redundancy_map[max_idx][redundancy_count - 1] = num_experts + i; + + // Adjust weights using dynamic formula + auto new_weight = + (current_weights[max_idx].item() * redundancy_count) / + (redundancy_count + 1.0); + current_weights[max_idx] = new_weight; + } + + return {current_weights, redundancy_map}; +} + +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/framework/eplb/eplb_policy.h b/xllm/core/framework/eplb/eplb_policy.h new file mode 100644 index 00000000..ec32d5e2 --- /dev/null +++ b/xllm/core/framework/eplb/eplb_policy.h @@ -0,0 +1,31 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include + +namespace xllm { + +class EplbPolicy { + public: + EplbPolicy(int32_t device_experts_num, int32_t device_num, int32_t layer_num); + virtual ~EplbPolicy() {}; + std::pair> rebalance_experts( + torch::Tensor expert_load); + + private: + torch::Tensor old_expert_load_; + int32_t device_experts_num_; + int32_t device_num_; + int32_t layer_num_; + torch::Tensor expert_distribution_; + torch::Tensor compute_balanced_pack(const torch::Tensor& expert_loads); + std::pair update_origin_weights( + torch::Tensor expert_loads, + int32_t redundancy_experts); +}; +} // namespace xllm diff --git a/xllm/core/framework/eplb/eplb_policy_test.cpp b/xllm/core/framework/eplb/eplb_policy_test.cpp new file mode 100644 index 00000000..122aeb10 --- /dev/null +++ b/xllm/core/framework/eplb/eplb_policy_test.cpp @@ -0,0 +1,23 @@ +#include "eplb_policy.h" + +#include +#include +#include + +namespace xllm { + +TEST(EplbPolicyTest, Build) { + std::string rank_table_file; + EplbPolicy eplb_policy(5, 4, 1); + std::vector tensors; + tensors.push_back(torch::arange(0, 16)); + + auto expert_load = torch::stack(tensors, 0); + expert_load[0] = + torch::tensor({100, 100, 100, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 100}); + auto [rebalance_expert, enable_update_vec] = + 
eplb_policy.rebalance_experts(expert_load); + LOG(INFO) << "rebalance_expert:" << rebalance_expert; +} + +} // namespace xllm diff --git a/xllm/core/framework/eplb/expert_buffer_manager.cpp b/xllm/core/framework/eplb/expert_buffer_manager.cpp new file mode 100644 index 00000000..953c93ba --- /dev/null +++ b/xllm/core/framework/eplb/expert_buffer_manager.cpp @@ -0,0 +1,37 @@ +#include "expert_buffer_manager.h" + +namespace xllm { + +ExpertBufferManager::ExpertBufferManager(int32_t num_experts, + int32_t num_layers, + int64_t shm_size_per_expert) + : num_experts_(num_experts), + num_layers_(num_layers), + shm_size_per_expert_(shm_size_per_expert) { + expert_buffers_.reserve(num_experts); + for (int32_t i = 0; i < num_experts; ++i) { + expert_buffers_.emplace_back( + std::make_unique(i, num_layers, shm_size_per_expert)); + } +} + +void ExpertBufferManager::add_tensor(int32_t expert_id, + int32_t layer_id, + const std::string& tensor_name, + const torch::Tensor& tensor) { + if (expert_id < 0 || expert_id >= num_experts_) { + throw std::runtime_error("Invalid expert ID: " + std::to_string(expert_id)); + } + expert_buffers_[expert_id]->add_tensor(layer_id, tensor_name, tensor); +} + +torch::Tensor ExpertBufferManager::get_tensor(int32_t expert_id, + int32_t layer_id, + const std::string& tensor_name) { + if (expert_id < 0 || expert_id >= num_experts_) { + throw std::runtime_error("Invalid expert ID: " + std::to_string(expert_id)); + } + return expert_buffers_[expert_id]->get_tensor(layer_id, tensor_name); +} + +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/framework/eplb/expert_buffer_manager.h b/xllm/core/framework/eplb/expert_buffer_manager.h new file mode 100644 index 00000000..e337ffc2 --- /dev/null +++ b/xllm/core/framework/eplb/expert_buffer_manager.h @@ -0,0 +1,33 @@ +#pragma once + +#include + +#include + +#include "expert_weight_buffer_shm.h" + +namespace xllm { + +class ExpertBufferManager { + public: + ExpertBufferManager(int32_t num_experts, + int32_t num_layers, + int64_t shm_size_per_expert); + + void add_tensor(int32_t expert_id, + int32_t layer_id, + const std::string& tensor_name, + const torch::Tensor& tensor); + + torch::Tensor get_tensor(int32_t expert_id, + int32_t layer_id, + const std::string& tensor_name); + + private: + std::vector> expert_buffers_; + const int32_t num_experts_; + const int32_t num_layers_; + const int64_t shm_size_per_expert_; +}; + +} // namespace xllm diff --git a/xllm/core/framework/eplb/expert_weight_buffer_shm.cpp b/xllm/core/framework/eplb/expert_weight_buffer_shm.cpp new file mode 100644 index 00000000..4db8d765 --- /dev/null +++ b/xllm/core/framework/eplb/expert_weight_buffer_shm.cpp @@ -0,0 +1,239 @@ +#include "expert_weight_buffer_shm.h" + +#include + +#include +#include +#include +namespace xllm { + +ExpertBufferShm::ExpertBufferShm(int32_t expert_id, + int32_t max_layers, + int64_t total_size) + : expert_id_(expert_id), + max_layers_(max_layers), + layer_data_region_size_(total_size / max_layers) { + // Memory alignment calculation (64-byte alignment for performance) + constexpr size_t kAlignment = 64; + + // Calculate aligned header size (header + padding) + size_t header_size = + ((sizeof(SharedHeader) + kAlignment - 1) / kAlignment) * kAlignment; + + // Calculate aligned metadata region size (all experts' metadata + padding) + size_t meta_size = ((max_layers * MAX_TENSORS_PER_LAYER * sizeof(TensorMeta) + + kAlignment - 1) / + kAlignment) * + kAlignment; + + bool is_creator; + std::string shm_name = "xllm_expert_" 
+ std::to_string(expert_id_); + + // Create/attach shared memory segment with calculated size + shm_ = std::make_unique( + shm_name, header_size + meta_size + total_size, is_creator); + + // Memory region pointers setup: + header_ = static_cast(shm_->base_address()); + tensor_metas_ = reinterpret_cast( + static_cast(shm_->base_address()) + header_size); + data_base_ = + static_cast(shm_->base_address()) + header_size + meta_size; + + if (is_creator) { + initialize_as_creator(); + } + verify_and_recover(); +} + +ExpertBufferShm::~ExpertBufferShm() { + std::lock_guard lock(local_mutex_); + std::atomic_thread_fence(std::memory_order_seq_cst); + header_ = nullptr; + tensor_metas_ = nullptr; + data_base_ = nullptr; +} + +void ExpertBufferShm::initialize_as_creator() { + header_->initialized_layers.store(0, std::memory_order_release); + + pthread_mutexattr_t attr; + pthread_mutexattr_init(&attr); + pthread_mutexattr_setpshared(&attr, PTHREAD_PROCESS_SHARED); + pthread_mutexattr_setrobust(&attr, PTHREAD_MUTEX_ROBUST); + + if (pthread_mutex_init(&header_->allocation_mutex, &attr) != 0) { + pthread_mutexattr_destroy(&attr); + throw std::runtime_error("Mutex initialization failed"); + } + pthread_mutexattr_destroy(&attr); + + memset(tensor_metas_, + 0, + max_layers_ * MAX_TENSORS_PER_LAYER * sizeof(TensorMeta)); +} + +void ExpertBufferShm::verify_and_recover() { + int rc = pthread_mutex_lock(&header_->allocation_mutex); + if (rc == EOWNERDEAD) { + pthread_mutex_consistent(&header_->allocation_mutex); + LOG(WARNING) << "Recovered from orphaned mutex for expert " << expert_id_; + } else if (rc != 0) { + throw std::runtime_error("Failed to acquire mutex"); + } + pthread_mutex_unlock(&header_->allocation_mutex); +} + +size_t ExpertBufferShm::get_layer_offset(int32_t layer_id) const { + if (layer_id < 0 || layer_id >= max_layers_) { + throw std::runtime_error("Invalid layer ID: " + std::to_string(layer_id) + + " for expert " + std::to_string(expert_id_)); + } + return layer_id * layer_data_region_size_; +} + +void ExpertBufferShm::add_tensor(int32_t layer_id, + const std::string& tensor_name, + const torch::Tensor& tensor) { + if (layer_id < 0 || layer_id >= max_layers_) { + throw std::runtime_error("Invalid layer ID: " + std::to_string(layer_id) + + " for expert " + std::to_string(expert_id_)); + } + if (tensor_name.empty()) { + throw std::runtime_error("Tensor name cannot be empty"); + } + if (!tensor.defined() || !tensor.is_contiguous()) { + throw std::runtime_error("Tensor must be defined and contiguous"); + } + if (tensor.device().type() != torch::kCPU) { + throw std::runtime_error("Only CPU tensors can be stored in shared memory"); + } + + std::lock_guard lock(local_mutex_); + + // Get this expert's metadata block + TensorMeta* layer_metas = &tensor_metas_[layer_id * MAX_TENSORS_PER_LAYER]; + + // Find available slot and check for duplicates + int available_slot = -1; + for (int i = 0; i < MAX_TENSORS_PER_LAYER; ++i) { + TensorMeta& meta = layer_metas[i]; + if (meta.tensor_name[0] == '\0') { + if (available_slot == -1) available_slot = i; + } else if (strcmp(meta.tensor_name, tensor_name.c_str()) == 0) { + throw std::runtime_error( + "Tensor '" + tensor_name + "' already exists for expert " + + std::to_string(expert_id_) + " layer " + std::to_string(layer_id)); + } + } + + if (available_slot == -1) { + throw std::runtime_error("No available slots for expert " + + std::to_string(expert_id_) + " layer " + + std::to_string(layer_id)); + } + + // Prepare the tensor metadata + TensorMeta& meta = 
layer_metas[available_slot]; + strncpy(meta.tensor_name, tensor_name.c_str(), sizeof(meta.tensor_name) - 1); + meta.tensor_name[sizeof(meta.tensor_name) - 1] = '\0'; + + meta.rank = tensor.dim(); + for (int i = 0; i < meta.rank; ++i) { + meta.shape[i] = tensor.size(i); + } + meta.dtype = static_cast(tensor.scalar_type()); + + constexpr size_t alignment = 64; + size_t raw_size = tensor.nbytes(); + size_t aligned_size = (raw_size + alignment - 1) & ~(alignment - 1); + + // Calculate offset by summing sizes of previous tensors in this expert + size_t layer_data_offset = 0; + for (int i = 0; i < MAX_TENSORS_PER_LAYER; ++i) { + if (&layer_metas[i] == &meta) break; + layer_data_offset += layer_metas[i].actual_size; + } + + if (layer_data_offset + aligned_size > layer_data_region_size_) { + throw std::runtime_error( + "Insufficient space in expert " + std::to_string(expert_id_) + + " layer " + std::to_string(layer_id) + " (needs " + + std::to_string(aligned_size) + " bytes, has " + + std::to_string(layer_data_region_size_ - layer_data_offset) + + " remaining)"); + } + + // Set final storage location + meta.data_offset = get_layer_offset(layer_id) + layer_data_offset; + meta.actual_size = raw_size; + + // Copy tensor data to shared memory + void* dest = data_base_ + meta.data_offset; + memcpy(dest, tensor.data_ptr(), raw_size); + + // Zero-fill any alignment padding + if (aligned_size > raw_size) { + memset(static_cast(dest) + raw_size, 0, aligned_size - raw_size); + } +} + +torch::Tensor ExpertBufferShm::get_tensor(int32_t layer_id, + const std::string& tensor_name) { + if (layer_id < 0 || layer_id >= max_layers_) { + throw std::runtime_error( + fmt::format("Invalid layer ID {} for expert {}", layer_id, expert_id_)); + } + + // Validate expert ID + std::lock_guard lock(local_mutex_); + + // Get this expert's metadata block + TensorMeta* layer_metas = &tensor_metas_[layer_id * MAX_TENSORS_PER_LAYER]; + + // Search for the requested tensor + for (int i = 0; i < MAX_TENSORS_PER_LAYER; ++i) { + TensorMeta& meta = layer_metas[i]; + + // Skip empty slots + if (meta.tensor_name[0] == '\0') { + continue; + } + + // Check for name match + if (strcmp(meta.tensor_name, tensor_name.c_str()) == 0) { + // Validate metadata + if (meta.data_offset < 0 || meta.actual_size == 0 || + meta.data_offset + meta.actual_size > shm_->size()) { + throw std::runtime_error(fmt::format( + "Corrupted tensor metadata for {} in expert {} layer {}", + tensor_name, + expert_id_, + layer_id)); + } + + // Create tensor options from stored type + auto options = torch::TensorOptions() + .dtype(static_cast(meta.dtype)) + .device(torch::kCPU) + .requires_grad(false); + + // Convert shape array to vector + std::vector shape(meta.shape, meta.shape + meta.rank); + + // Create tensor from shared memory + void* src = data_base_ + meta.data_offset; + torch::Tensor result = torch::from_blob(src, shape, options); + + return result; + } + } + + throw std::runtime_error( + fmt::format("Tensor {} not found in expert {} layer {}", + tensor_name, + expert_id_, + layer_id)); +} + +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/framework/eplb/expert_weight_buffer_shm.h b/xllm/core/framework/eplb/expert_weight_buffer_shm.h new file mode 100644 index 00000000..4d6343ff --- /dev/null +++ b/xllm/core/framework/eplb/expert_weight_buffer_shm.h @@ -0,0 +1,78 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include + +#include "shared_memory_manager.h" + +namespace xllm { + +// Maximum number of 
tensors each expert-layer pair can store +constexpr int MAX_TENSORS_PER_LAYER = 16; +// Maximum number of layers per expert +constexpr int MAX_LAYERS_PER_EXPERT = 128; + +// Shared memory header structure containing control information +struct SharedHeader { + std::atomic initialized_layers; // Number of initialized layers + pthread_mutex_t allocation_mutex; // Cross-process synchronization mutex +}; + +// Metadata structure for each stored tensor +struct TensorMeta { + char tensor_name[256]; // Null-terminated tensor identifier + int32_t rank; // Number of dimensions (1D, 2D, etc.) + int64_t shape[8]; // Dimensions of the tensor (max 8D) + int32_t dtype; // Data type (matches torch::Dtype) + size_t data_offset; // Byte offset in shared memory + size_t actual_size; // Unpadded data size in bytes +}; + +class ExpertBufferShm { + public: + ExpertBufferShm(int32_t expert_id, int32_t max_layers, int64_t total_size); + + virtual ~ExpertBufferShm(); + + void add_tensor(int32_t layer_id, + const std::string& tensor_name, + const torch::Tensor& tensor); + + /** + * @brief Retrieve a tensor from expert's layer memory region + * + * @param layer_id Source layer identifier + * @param tensor_name Name of the tensor to retrieve + * @return torch::Tensor A copy of the requested tensor + * @throws std::runtime_error if tensor not found + */ + torch::Tensor get_tensor(int32_t layer_id, const std::string& tensor_name); + + private: + // Initializes shared memory when creating new region + void initialize_as_creator(); + + // Verifies and recovers shared memory state + void verify_and_recover(); + + // Calculates base offset for a layer's data region + size_t get_layer_offset(int32_t layer_id) const; + + std::mutex local_mutex_; // Thread synchronization + std::unique_ptr shm_; // Shared memory manager + SharedHeader* header_ = nullptr; // Pointer to shared header + TensorMeta* tensor_metas_ = nullptr; // Array of all layers' metadata + char* data_base_ = nullptr; // Base pointer to data region + + const int32_t expert_id_; // Expert identifier + const int32_t max_layers_; // Maximum supported layers + const int64_t layer_data_region_size_; // Bytes allocated per layer +}; + +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/framework/eplb/shared_memory_manager.cpp b/xllm/core/framework/eplb/shared_memory_manager.cpp new file mode 100644 index 00000000..c8195914 --- /dev/null +++ b/xllm/core/framework/eplb/shared_memory_manager.cpp @@ -0,0 +1,99 @@ +#include "shared_memory_manager.h" + +#include + +namespace xllm { +std::vector SharedMemoryManager::pending_cleanups; +std::mutex SharedMemoryManager::cleanup_mutex; + +SharedMemoryManager::SharedMemoryManager(const std::string& name, + size_t size, + bool& is_creator) + : shm_name_(name), size_(size) { + // Register cleanup handlers for signals (once per process) + static std::once_flag flag; + std::call_once(flag, [] { + signal(SIGINT, cleanup_handler); + signal(SIGTERM, cleanup_handler); + // signal(SIGSEGV, cleanup_handler); + }); + + // First try to create exclusively (O_CREAT | O_EXCL) + fd_ = shm_open(name.c_str(), O_CREAT | O_RDWR | O_EXCL, 0666); + is_creator = (fd_ != -1); + + // If creation failed, try opening existing + if (!is_creator) { + fd_ = shm_open(name.c_str(), O_RDWR, 0666); + if (fd_ == -1) { + throw std::runtime_error("shm_open failed: " + + std::string(strerror(errno))); + } + } else { + // Track created SHM for later cleanup + std::lock_guard lock(cleanup_mutex); + pending_cleanups.push_back(name); + } + + // Set 
size for new SHM + if (is_creator && ftruncate(fd_, size) == -1) { + close(fd_); + throw std::runtime_error("ftruncate failed: " + + std::string(strerror(errno))); + } + + // Map into process address space + addr_ = mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_SHARED, fd_, 0); + if (addr_ == MAP_FAILED) { + close(fd_); + throw std::runtime_error("mmap failed: " + std::string(strerror(errno))); + } +} + +SharedMemoryManager::~SharedMemoryManager() { + // Unmap memory + LOG(INFO) << "Delete ~SharedMemoryManager"; + if (addr_ != MAP_FAILED) { + munmap(addr_, size_); + } + + // Close descriptor + if (fd_ != -1) { + close(fd_); + } + + // Cleanup if we're the creator + std::lock_guard lock(cleanup_mutex); + auto it = + std::find(pending_cleanups.begin(), pending_cleanups.end(), shm_name_); + if (it != pending_cleanups.end()) { + shm_unlink(shm_name_.c_str()); + pending_cleanups.erase(it); + } +} + +void SharedMemoryManager::cleanup_handler(int sig) { + std::lock_guard lock(cleanup_mutex); + LOG(INFO) << "Signal: " << sig << " (" << strsignal(sig) << ")"; + for (const auto& name : pending_cleanups) { + LOG(INFO) << "SharedMemoryManager cleanup_handler name:" << name; + shm_unlink(name.c_str()); + } + exit(sig); +} + +void* SharedMemoryManager::allocate(int64_t size, int64_t alignment) { + std::lock_guard lock(mutex_); + + // Calculate aligned size and check bounds + int64_t aligned_size = (size + alignment - 1) & ~(alignment - 1); + if (current_offset_ + aligned_size > size_) { + throw std::runtime_error("Shared memory overflow"); + } + + // Return current offset and advance + void* ptr = static_cast(addr_) + current_offset_; + current_offset_ += aligned_size; + return ptr; +} +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/framework/eplb/shared_memory_manager.h b/xllm/core/framework/eplb/shared_memory_manager.h new file mode 100644 index 00000000..0bcb3761 --- /dev/null +++ b/xllm/core/framework/eplb/shared_memory_manager.h @@ -0,0 +1,38 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include + +namespace xllm { + +class SharedMemoryManager { + public: + explicit SharedMemoryManager(const std::string& name, + size_t size, + bool& is_creator); + + ~SharedMemoryManager(); + void* allocate(int64_t size, int64_t alignment = alignof(max_align_t)); + void* base_address() const { return addr_; } + int64_t size() const { return size_; } + + private: + std::string shm_name_; + int fd_ = -1; + void* addr_ = MAP_FAILED; + int64_t size_ = 0; + int64_t current_offset_ = 0; + std::mutex mutex_; + + static void cleanup_handler(int sig); + static std::vector pending_cleanups; + static std::mutex cleanup_mutex; +}; + +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/framework/model/causal_lm.h b/xllm/core/framework/model/causal_lm.h index 5863bb43..0edc7b39 100644 --- a/xllm/core/framework/model/causal_lm.h +++ b/xllm/core/framework/model/causal_lm.h @@ -57,6 +57,11 @@ class CausalLM : public torch::nn::Module { virtual torch::Device device() const = 0; + virtual void prepare_expert_weight( + int32_t layer_id, + const std::vector& expert_ids) = 0; + virtual void update_expert_weight(int32_t layer_id) = 0; + virtual const torch::TensorOptions& options() const = 0; #if defined(USE_NPU) @@ -88,6 +93,15 @@ class CausalLMImpl : public CausalLM { void load_model(std::unique_ptr loader) override { model_->load_model(std::move(loader)); } + virtual void prepare_expert_weight( + int32_t layer_id, + const std::vector& expert_ids) override 
{ + return model_->prepare_expert_weight(layer_id, expert_ids); + } + + virtual void update_expert_weight(int32_t layer_id) { + return model_->update_expert_weight(layer_id); + } #if defined(USE_NPU) hf::LlmHead get_lm_head() override { return model_->get_lm_head(); }; diff --git a/xllm/core/framework/model/causal_vlm.h b/xllm/core/framework/model/causal_vlm.h index b0816731..1af4dea8 100644 --- a/xllm/core/framework/model/causal_vlm.h +++ b/xllm/core/framework/model/causal_vlm.h @@ -56,6 +56,13 @@ class CausalVLMImpl : public CausalVLM { void load_model(std::unique_ptr loader) override { model_->load_model(std::move(loader)); } + + virtual void prepare_expert_weight(int32_t layer_id, + const std::vector& expert_ids) { + return; + } + + virtual void update_expert_weight(int32_t layer_id) { return; } #if defined(USE_NPU) hf::LlmHead get_lm_head() override { return model_->get_lm_head(); }; diff --git a/xllm/core/framework/model/model_input_params.h b/xllm/core/framework/model/model_input_params.h index ad08be79..8d60db16 100644 --- a/xllm/core/framework/model/model_input_params.h +++ b/xllm/core/framework/model/model_input_params.h @@ -55,7 +55,7 @@ struct ModelInputParams { #if defined(USE_NPU) params.layer_synchronizer = layer_synchronizer; #endif - + params.expert_load_data = expert_load_data; return params; } @@ -119,6 +119,7 @@ struct ModelInputParams { #endif DpEpPaddingData dp_ep_padding_data; + torch::Tensor expert_load_data; }; } // namespace xllm diff --git a/xllm/core/framework/model/npu_dp_ep_padding.cpp b/xllm/core/framework/model/npu_dp_ep_padding.cpp index 37c34ed8..184d38cb 100644 --- a/xllm/core/framework/model/npu_dp_ep_padding.cpp +++ b/xllm/core/framework/model/npu_dp_ep_padding.cpp @@ -27,11 +27,13 @@ DpEpPadding::DpEpPadding(torch::Tensor token_size_per_dp_group, int32_t num_experts_per_tok, const nlohmann::json& mapping_npu, at::Device device, + torch::ScalarType dtype, bool is_prefill) : token_size_per_dp_group_(token_size_per_dp_group.contiguous()), num_experts_per_tok_(num_experts_per_tok), mapping_npu_(mapping_npu), device_(device), + dtype_(dtype), is_prefill_(is_prefill), expert_parallel_degree_(0) { // Validate input tensor @@ -345,11 +347,10 @@ void DpEpPadding::handle_expert_parallel() { moe_idx_data.push_back(i); } moe_idx_ = torch::tensor(moe_idx_data, torch::dtype(torch::kInt32)); - expert_array_ = safe_to( - torch::ones({moe_idx_.sizes()[0]}, torch::dtype(torch::kFloat16)) - .view({-1, 1}), - device_, - true); + expert_array_ = + safe_to(torch::ones({moe_idx_.sizes()[0]}, dtype_).view({-1, 1}), + device_, + true); } else { dynamic_ep_idx_ = torch::zeros({1}, torch::kInt32); moe_idx_ = torch::zeros({1}, torch::kInt32); diff --git a/xllm/core/framework/model/npu_dp_ep_padding.h b/xllm/core/framework/model/npu_dp_ep_padding.h index b9cc2d18..ad11756f 100644 --- a/xllm/core/framework/model/npu_dp_ep_padding.h +++ b/xllm/core/framework/model/npu_dp_ep_padding.h @@ -56,6 +56,7 @@ class DpEpPadding { int32_t num_experts_per_tok, const nlohmann::json& mapping_npu, at::Device device, + torch::ScalarType dtype, bool is_prefill); DpEpPaddingData build(); @@ -104,5 +105,6 @@ class DpEpPadding { torch::Tensor expert_array_; std::vector new_dp_size_; at::Device device_; + torch::ScalarType dtype_; }; } // namespace xllm diff --git a/xllm/core/framework/model/npu_dp_ep_padding_test.cpp b/xllm/core/framework/model/npu_dp_ep_padding_test.cpp index 4023ed47..5a84c70d 100644 --- a/xllm/core/framework/model/npu_dp_ep_padding_test.cpp +++ 
b/xllm/core/framework/model/npu_dp_ep_padding_test.cpp @@ -40,8 +40,12 @@ TEST(DpEpPaddingTest, Build) { MappingNPU mapping(rank_table_file, 16, 0, options); nlohmann::json data = mapping.to_json(); torch::Tensor token_size_per_dp_group = torch::tensor({10, 10}); - DpEpPadding dp_ep_padding( - token_size_per_dp_group, 8, data, torch::Device(torch::kCPU), true); + DpEpPadding dp_ep_padding(token_size_per_dp_group, + 8, + data, + torch::Device(torch::kCPU), + torch::Dtype(torch::kInt32), + true); DpEpPaddingData dp_ep_padding_data = dp_ep_padding.build(); LOG(INFO) << "attn_padding_idx:" << dp_ep_padding_data.attn_padding_idx(); LOG(INFO) << "attn_unpadding_idx:" << dp_ep_padding_data.attn_unpadding_idx(); diff --git a/xllm/core/layers/npu/CMakeLists.txt b/xllm/core/layers/npu/CMakeLists.txt index 208318de..da068f1a 100755 --- a/xllm/core/layers/npu/CMakeLists.txt +++ b/xllm/core/layers/npu/CMakeLists.txt @@ -53,6 +53,7 @@ cc_library( :kv_cache :prefix_cache :block + :eplb :parallel_state :state_dict glog::glog diff --git a/xllm/core/layers/npu/deepseek_v2_decoder_layer.cpp b/xllm/core/layers/npu/deepseek_v2_decoder_layer.cpp index d4d85b13..7efd6d47 100644 --- a/xllm/core/layers/npu/deepseek_v2_decoder_layer.cpp +++ b/xllm/core/layers/npu/deepseek_v2_decoder_layer.cpp @@ -17,6 +17,10 @@ limitations under the License. #include +#include + +#include "common/global_flags.h" + DECLARE_string(rank_tablefile); DECLARE_string(communication_backend); DECLARE_int32(expert_parallel_degree); @@ -24,75 +28,75 @@ DECLARE_int32(expert_parallel_degree); namespace xllm::hf { enum DecoderLayerTensorId : int { - IN_INPUT_NORM_WEIGHT = 0, //[7168] - IN_INPUT_NORM_BIAS = 1, //[7168] + IN_INPUT_NORM_WEIGHT = 0, + IN_INPUT_NORM_BIAS = 1, IN_INPUT_NORM_NEW_WEIGHT = 2, IN_INPUT_NORM_NEW_BIAS = 3, - IN_Q_PROJ_A_WEIGHT = 4, //[1536, 7168] - IN_Q_PROJ_A_BIAS = 5, //[1536] - IN_Q_PROJ_A_DESCALE = 6, //[1536] - IN_Q_PROJ_A_OFFSET = 7, //[1] - IN_Q_PROJ_A_SCALE = 8, //[1] + IN_Q_PROJ_A_WEIGHT = 4, + IN_Q_PROJ_A_BIAS = 5, + IN_Q_PROJ_A_DESCALE = 6, + IN_Q_PROJ_A_OFFSET = 7, + IN_Q_PROJ_A_SCALE = 8, IN_Q_PROJ_A_COMPRESS_IDX = 9, - IN_Q_PROJ_A_LAYERNORM_WEIGHT = 10, //[1536] - IN_Q_PROJ_A_LAYERNORM_BIAS = 11, //[1536] - - IN_Q_PROJ_B_WEIGHT = 12, //[6144, 1536] - IN_Q_PROJ_B_BIAS = 13, //[6144] - IN_Q_PROJ_B_DESCALE = 14, //[6144] - IN_Q_PROJ_B_OFFSET = 15, //[1] - IN_Q_PROJ_B_SCALE = 16, //[1] + IN_Q_PROJ_A_LAYERNORM_WEIGHT = 10, + IN_Q_PROJ_A_LAYERNORM_BIAS = 11, + + IN_Q_PROJ_B_WEIGHT = 12, + IN_Q_PROJ_B_BIAS = 13, + IN_Q_PROJ_B_DESCALE = 14, + IN_Q_PROJ_B_OFFSET = 15, + IN_Q_PROJ_B_SCALE = 16, IN_Q_PROJ_B_COMPRESS_IDX = 17, - IN_KV_PROJ_WITH_MQA_WEIGHT = 18, //[576, 7168] - IN_KV_PROJ_WITH_MQA_BIAS = 19, //[576] - IN_KV_PROJ_WITH_MQA_DESCALE = 20, //[576] - IN_KV_PROJ_WITH_MQA_OFFSET = 21, //[1] - IN_KV_PROJ_WITH_MQA_SCALE = 22, //[1] + IN_KV_PROJ_WITH_MQA_WEIGHT = 18, + IN_KV_PROJ_WITH_MQA_BIAS = 19, + IN_KV_PROJ_WITH_MQA_DESCALE = 20, + IN_KV_PROJ_WITH_MQA_OFFSET = 21, + IN_KV_PROJ_WITH_MQA_SCALE = 22, IN_KV_PROJ_WITH_MQA_COMPRESS_IDX = 23, - IN_KV_PROJ_A_LAYERNORM_WEIGHT = 24, //[512] + IN_KV_PROJ_A_LAYERNORM_WEIGHT = 24, IN_KV_PROJ_A_LAYERNORM_BIAS = 25, - IN_K_PROJ_B_FOR_Q_WEIGHT = 26, //[8, 128, 512] + IN_K_PROJ_B_FOR_Q_WEIGHT = 26, IN_K_PROJ_B_FOR_Q_BIAS = 27, IN_K_PROJ_B_FOR_Q_DESCALE = 28, IN_K_PROJ_B_FOR_Q_OFFSET = 29, IN_K_PROJ_B_FOR_Q_SCALE = 30, IN_K_PROJ_B_FOR_Q_COMPRESS_IDX = 31, - IN_V_PROJ_B_FOR_O_WEIGHT = 32, //[32, 512, 128] + IN_V_PROJ_B_FOR_O_WEIGHT = 32, IN_V_PROJ_B_FOR_O_BIAS = 33, 
IN_V_PROJ_B_FOR_O_DESCALE = 34, IN_V_PROJ_B_FOR_O_OFFSET = 35, IN_V_PROJ_B_FOR_O_SCALE = 36, IN_V_PROJ_B_FOR_O_COMPRESS_IDX = 37, - IN_ATTENTION_OUT_WEIGHT = 38, //[7168, 4096] - IN_ATTENTION_OUT_BIAS = 39, //[7168] - IN_ATTENTION_OUT_DESCALE = 40, //[7168] + IN_ATTENTION_OUT_WEIGHT = 38, + IN_ATTENTION_OUT_BIAS = 39, + IN_ATTENTION_OUT_DESCALE = 40, IN_ATTENTION_OUT_OFFSET = 41, IN_ATTENTION_OUT_SCALE = 42, IN_ATTENTION_OUT_COMPRESS_IDX = 43, - IN_SELFATTENTION_OUT_NORM_WEIGHT = 44, //[7168] + IN_SELFATTENTION_OUT_NORM_WEIGHT = 44, IN_SELFATTENTION_OUT_NORM_BIAS = 45, IN_SELFATTENTION_OUT_NEW_NORM_WEIGHT = 46, IN_SELFATTENTION_OUT_NEW_NORM_BIAS = 47, - IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT = 48, //[1024, 7168]] + IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT = 48, IN_MLP_GATEUP_BIAS_SHARED_EXPERT = 49, IN_MLP_GATEUP_DESCALE_SHARED_EXPERT = 50, - IN_MLP_GATEUP_OFFSET_SHARED_EXPERT = 51, //[1024] - IN_MLP_GATEUP_SCALE_SHARED_EXPERT = 52, //[1024] + IN_MLP_GATEUP_OFFSET_SHARED_EXPERT = 51, + IN_MLP_GATEUP_SCALE_SHARED_EXPERT = 52, IN_MLP_GATEUP_COMPRESS_IDX_SHARED_EXPERT = 53, - IN_MLP_DOWN_WEIGHT_SHARED_EXPERT = 54, //[7168, 512] + IN_MLP_DOWN_WEIGHT_SHARED_EXPERT = 54, IN_MLP_DOWN_BIAS_SHARED_EXPERT = 55, IN_MLP_DOWN_DESCALE_SHARED_EXPERT = 56, - IN_MLP_DOWN_OFFSET_SHARED_EXPERT = 57, //[7168] - IN_MLP_DOWN_SCALE_SHARED_EXPERT = 58, //[7168] + IN_MLP_DOWN_OFFSET_SHARED_EXPERT = 57, + IN_MLP_DOWN_SCALE_SHARED_EXPERT = 58, IN_MLP_DOWN_COMPRESS_IDX_SHARED_EXPERT = 59, IN_SHARED_EXPERT_GATE_WEIGHT = 60, @@ -102,25 +106,25 @@ enum DecoderLayerTensorId : int { IN_SHARED_EXPERT_GATE_SCALE = 64, IN_SHARED_EXPERT_GATE_COMPRESS_IDX = 65, - IN_BLOCK_SPARSE_MOE_GATE_WEIGHT = 66, //[256, 7168] - IN_BLOCK_SPARSE_MOE_GATE_BIAS = 67, //[256] + IN_BLOCK_SPARSE_MOE_GATE_WEIGHT = 66, + IN_BLOCK_SPARSE_MOE_GATE_BIAS = 67, IN_BLOCK_SPARSE_MOE_GATE_DESCALE = 68, IN_BLOCK_SPARSE_MOE_GATE_OFFSET = 69, IN_BLOCK_SPARSE_MOE_GATE_SCALE = 70, IN_BLOCK_SPARSE_MOE_GATE_COMPRESS_IDX = 71, - IN_MLP_GATEUP_WEIGHT_EXPERT = 72, //[256, 7168, 1024] + IN_MLP_GATEUP_WEIGHT_EXPERT = 72, IN_MLP_GATEUP_BIAS_EXPERT = 73, IN_MLP_GATEUP_DESCALE_EXPERT = 74, - IN_MLP_GATEUP_OFFSET_EXPERT = 75, //[256, 1024] - IN_MLP_GATEUP_SCALE_EXPERT = 76, //[256, 1024] + IN_MLP_GATEUP_OFFSET_EXPERT = 75, + IN_MLP_GATEUP_SCALE_EXPERT = 76, IN_MLP_GATEUP_COMPRESS_IDX_EXPERT = 77, - IN_MLP_DOWN_WEIGHT_EXPERT = 78, //[256, 512, 7168] + IN_MLP_DOWN_WEIGHT_EXPERT = 78, IN_MLP_DOWN_BIAS_EXPERT = 79, IN_MLP_DOWN_DESCALE_EXPERT = 80, - IN_MLP_DOWN_OFFSET_EXPERT = 81, //[256, 7168] - IN_MLP_DOWN_SCALE_EXPERT = 82, //[256, 7168] + IN_MLP_DOWN_OFFSET_EXPERT = 81, + IN_MLP_DOWN_SCALE_EXPERT = 82, IN_MLP_DOWN_COMPRESS_IDX_EXPERT = 83, }; @@ -259,12 +263,19 @@ DeepseekV2DecoderImpl::DeepseekV2DecoderImpl(const Context& context, auto parallel_args = context.get_parallel_args(); auto model_args = context.get_model_args(); auto options = context.get_tensor_options(); - + rank_ = parallel_args.rank(); + first_k_dense_replace_ = model_args.first_k_dense_replace(); + n_layers_ = model_args.n_layers(); + num_experts_ = model_args.n_routed_experts(); + localWorldSize_ = parallel_args.mapping().localWorldSize(); ep_size_ = parallel_args.ep_size(); ep_local_tp_size_ = parallel_args.world_size() / ep_size_; CHECK_EQ(parallel_args.world_size(), ep_size_ * ep_local_tp_size_); ep_local_tp_rank_ = parallel_args.rank() % ep_local_tp_size_; num_experts_per_partition_ = model_args.n_routed_experts() / ep_size_; + if (FLAGS_enable_eplb) { + num_experts_per_partition_++; + } ep_rank_ = 
parallel_args.rank() / ep_local_tp_size_; start_expert_id_ = ep_rank_ * num_experts_per_partition_; end_expert_id_ = start_expert_id_ + num_experts_per_partition_ - 1; @@ -291,7 +302,8 @@ void DeepseekV2DecoderImpl::initialize_tensors( block_tables_placeholder_ = torch::zeros({1, 1}).to(torch::kInt32).to(device_); tensor_placeholder_ = torch::zeros({1}).to(options); - resize_experts_weights(prefill_param_.numOfDeviceExperts); + + reserve_experts_weights(prefill_param_.numOfDeviceExperts); expert_group_ = torch::arange(1024, torch::kInt32).to(device_); one_hot_ = torch::tensor({1}, torch::kInt32).to(device_); zero_hot_ = torch::tensor({0}, torch::kInt32).to(device_); @@ -301,6 +313,32 @@ void DeepseekV2DecoderImpl::initialize_tensors( torch::tensor({num_experts_per_partition_ - 1}, torch::kInt64) .to(device_); initialize_weight_tensors(options); + initialize_device_expert_list(decode_param_.worldSize, + num_experts_per_partition_); + if (FLAGS_enable_eplb) { + auto layer_expert_routing_map_ = + build_expert_routing_map(device_expert_list_); + std::vector tensors_vec; + for (int i = 0; i < n_layers_ - first_k_dense_replace_; i++) { + tensors_vec.emplace_back(layer_expert_routing_map_); + } + expert_routing_map_ = torch::stack(tensors_vec, 0); + } +} + +void DeepseekV2DecoderImpl::initialize_device_expert_list( + int num_device, + int num_device_route_expert) { + if (FLAGS_enable_eplb) { + --num_device_route_expert; + } + for (int i = 0; i < num_device * num_device_route_expert; ++i) { + std::vector subvec; + device_expert_list_.emplace_back(i); + if (FLAGS_enable_eplb && (i + 1) % num_device_route_expert == 0) { + device_expert_list_.emplace_back(i); + } + } } void DeepseekV2DecoderImpl::param_from_args( @@ -316,26 +354,23 @@ void DeepseekV2DecoderImpl::param_from_args( initialize_kimi_k2_parameters(param, args, is_prefill); } -void DeepseekV2DecoderImpl::resize_experts_weights(int num_of_device_experts) { - experts_weights_["gate_proj.weight"] = - std::vector(num_of_device_experts); - experts_weights_["up_proj.weight"] = - std::vector(num_of_device_experts); - experts_weights_["down_proj.weight"] = - std::vector(num_of_device_experts); +void DeepseekV2DecoderImpl::reserve_experts_weights(int num_of_device_experts) { + experts_weights_.clear(); + std::vector weight_names = { + "gate_proj.weight", "up_proj.weight", "down_proj.weight"}; if (quantize_type_ == "w8a8_dynamic") { - experts_weights_["gate_proj.weight_offset"] = - std::vector(num_of_device_experts); - experts_weights_["up_proj.weight_offset"] = - std::vector(num_of_device_experts); - experts_weights_["down_proj.weight_offset"] = - std::vector(num_of_device_experts); - experts_weights_["gate_proj.weight_scale"] = - std::vector(num_of_device_experts); - experts_weights_["up_proj.weight_scale"] = - std::vector(num_of_device_experts); - experts_weights_["down_proj.weight_scale"] = + weight_names.emplace_back("gate_proj.weight_offset"); + weight_names.emplace_back("up_proj.weight_offset"); + weight_names.emplace_back("down_proj.weight_offset"); + weight_names.emplace_back("gate_proj.weight_scale"); + weight_names.emplace_back("up_proj.weight_scale"); + weight_names.emplace_back("down_proj.weight_scale"); + } + std::lock_guard lock(experts_mutex_); + for (const auto& weight_name : weight_names) { + experts_weights_[weight_name] = std::vector(num_of_device_experts); + ; } } @@ -344,6 +379,12 @@ void DeepseekV2DecoderImpl::initialize_weight_tensors( for (int i = 0; i < WEIGHT_COUNT_PER_LAYER; ++i) { at_weight_tensors_[i] = 
torch::zeros({1}).to(options); } + if (FLAGS_enable_eplb) { + const int64_t size = + 50LL * 1024LL * 1024LL * int64_t(n_layers_ - first_k_dense_replace_); + shared_buffer_ = std::make_unique( + num_experts_, n_layers_ - first_k_dense_replace_, size); + } } void DeepseekV2DecoderImpl::initialize_basic_parameters( @@ -415,10 +456,6 @@ void DeepseekV2DecoderImpl::initialize_attention_parameters( param.qkNopeHeadDim = args.qk_nope_head_dim(); param.qkRopeHeadDim = args.qk_rope_head_dim(); param.kvLoraRank = args.kv_lora_rank(); - - // sm_scale_ shows approximately 9 decimal places difference when compared - // across different engines, which may cause minimal diff during the decode - // phase param.softmaxScale = sm_scale_; if (quantize_type_ == "w8a8_dynamic" && num_speculative_tokens_ == 0) { param.enableMlaPreprocess = param.isBF16 ? false : true; @@ -484,11 +521,14 @@ void DeepseekV2DecoderImpl::initialize_mlp_parameters( parallel_args.dispatchAndCombinecommDomain(); if (layer_id_ >= param.firstKDenseReplace) { - // param.enableQkvdownDp = (param.expertParallelDegree==1 && - // param.isPrefill) ? true:false; param.enableQkvdownDp = false; - param.enableSharedExpertDp = false; // TODO - param.enableGatingDp = false; // TODO + param.enableSharedExpertDp = false; + param.enableGatingDp = false; + if (FLAGS_enable_eplb) { + param.enableExpertCumSumOutput = param.isPrefill ? false : true; + param.enableEPWB = true; + param.numOfRedundantExpert = ep_size_; + } } if (layer_id_ < param.firstKDenseReplace) { param.isDenseLayer = true; @@ -509,7 +549,6 @@ void DeepseekV2DecoderImpl::initialize_kimi_k2_parameters( param.enableFusedTopk = (args.topk_method() == "noaux_tc" && args.n_group() * 32 >= args.n_routed_experts()); param.maskfree = is_prefill; - // TODO: Pending confirmation whether kimi_k2 model supports // enable_gmmswigluquant set to true bool enable_gmmswigluquant = false; @@ -629,23 +668,70 @@ int DeepseekV2DecoderImpl::get_mapped_index( return it->second; } +std::string DeepseekV2DecoderImpl::get_expert_shm_key(int32_t layer_id, + int32_t expert_index, + std::string suffix) { + std::string shm_key = + "layer_" + std::to_string(layer_id - first_k_dense_replace_) + "_" + + "expert_" + std::to_string(expert_index) + "_" + suffix; + return shm_key; +} + void DeepseekV2DecoderImpl::process_expert_weights( const StateDict& state_dict, const std::string& name, const torch::Tensor& tensor) { int expert_index = extract_expert_index(name); - if (expert_index < start_expert_id_ || expert_index > end_expert_id_) { - return; - } - const std::string suffix = extract_endswith(name); const int index = get_mapped_index(suffix, WEIGHT_MAPPING_W8A8); if (index == -1) { return; } - const int local_index = expert_index % num_experts_per_partition_; const bool is_sharded = WEIGHT_SHARD_W8A8.count(index); - + if (FLAGS_enable_eplb && + (rank_ % localWorldSize_ == expert_index % localWorldSize_)) { + std::lock_guard lock(experts_mutex_); + torch::Tensor tmp_tensor_shm = + is_sharded ? 
get_sharded_tensor(state_dict, + name, + WEIGHT_SHARD_W8A8.at(index), + ep_local_tp_rank_, + ep_local_tp_size_) + : tensor; + std::string shm_key = get_expert_shm_key(layer_id_, expert_index, suffix); + if (!decode_param_.isBF16) { + if (absl::EndsWith(name, "_offset")) { + tmp_tensor_shm = tmp_tensor_shm.to(torch::kFloat16); + } else if (absl::EndsWith(name, "_scale")) { + tmp_tensor_shm = tmp_tensor_shm.to(torch::kFloat32); + } + } + shared_buffer_->add_tensor(expert_index, + layer_id_ - first_k_dense_replace_, + shm_key, + tmp_tensor_shm.contiguous()); + // all_experts_weights_buffer_[shm_key].emplace_back(tmp_tensor.clone()); + } + const int start_idx = ep_rank_ * num_experts_per_partition_; + const int end_idx = (ep_rank_ + 1) * num_experts_per_partition_; + const int safe_end = + std::min(end_idx, static_cast(device_expert_list_.size())); + + auto it = std::find(device_expert_list_.begin() + start_idx, + device_expert_list_.begin() + safe_end, + expert_index); + if (it == device_expert_list_.begin() + safe_end) { + return; + } + std::vector matches_pos; + for (auto iter = device_expert_list_.begin() + start_idx; + iter != device_expert_list_.begin() + safe_end; + ++iter) { + if (*iter == expert_index) { + matches_pos.emplace_back( + std::distance(device_expert_list_.begin(), iter) - start_idx); + } + } std::lock_guard lock(experts_mutex_); torch::Tensor tmp_tensor = is_sharded ? get_sharded_tensor(state_dict, @@ -655,7 +741,9 @@ void DeepseekV2DecoderImpl::process_expert_weights( ep_local_tp_size_) : tensor; - experts_weights_[suffix][local_index] = tmp_tensor.clone(); + for (auto pos : matches_pos) { + experts_weights_[suffix][pos] = tmp_tensor.clone(); + } } void DeepseekV2DecoderImpl::process_shared_expert_weights( @@ -854,7 +942,7 @@ std::string DeepseekV2DecoderImpl::extract_endswith(const std::string& input) { std::stringstream ss(input); std::string part; while (std::getline(ss, part, '.')) { - parts.push_back(part); + parts.emplace_back(part); } if (parts.size() < 2) { return ""; @@ -900,7 +988,9 @@ void DeepseekV2DecoderImpl::merge_loaded_weights() { if (layer_id_ >= prefill_param_.firstKDenseReplace) { merge_experts_weights(); } + squeeze_experts_weights(); + preprocess_linear_for_rope(); at_weight_tensors_[IN_Q_PROJ_A_WEIGHT] = @@ -1066,6 +1156,7 @@ void DeepseekV2DecoderImpl::merge_experts_weights() { torch::Tensor mlp_gateup_weight = merge_experts_weights(experts_weights_["gate_proj.weight"], experts_weights_["up_proj.weight"], + device_, /*transpose=*/true); at_weight_tensors_[IN_MLP_GATEUP_WEIGHT_EXPERT] = at_npu::native::npu_format_cast(mlp_gateup_weight, 29); @@ -1074,10 +1165,12 @@ void DeepseekV2DecoderImpl::merge_experts_weights() { if (quantize_type_ == "w8a8_dynamic") { at_weight_tensors_[IN_MLP_GATEUP_OFFSET_EXPERT] = merge_experts_weights(experts_weights_["gate_proj.weight_offset"], - experts_weights_["up_proj.weight_offset"]); + experts_weights_["up_proj.weight_offset"], + device_); at_weight_tensors_[IN_MLP_GATEUP_SCALE_EXPERT] = merge_experts_weights(experts_weights_["gate_proj.weight_scale"], - experts_weights_["up_proj.weight_scale"]); + experts_weights_["up_proj.weight_scale"], + device_); } #if defined(USE_A3) @@ -1102,42 +1195,226 @@ void DeepseekV2DecoderImpl::merge_experts_weights() { } #endif if (quantize_type_ == "w8a8_dynamic") { - at_weight_tensors_[IN_MLP_DOWN_OFFSET_EXPERT] = - merge_experts_weights(experts_weights_["down_proj.weight_offset"]); - at_weight_tensors_[IN_MLP_DOWN_SCALE_EXPERT] = - 
merge_experts_weights(experts_weights_["down_proj.weight_scale"]); + at_weight_tensors_[IN_MLP_DOWN_OFFSET_EXPERT] = merge_experts_weights( + experts_weights_["down_proj.weight_offset"], device_); + at_weight_tensors_[IN_MLP_DOWN_SCALE_EXPERT] = merge_experts_weights( + experts_weights_["down_proj.weight_scale"], device_); } } torch::Tensor DeepseekV2DecoderImpl::merge_experts_weights( std::vector& experts, + at::Device device, bool transpose) { - torch::Tensor merged_tensor = torch::stack(experts, 0).to(device_); + torch::Tensor merged_tensor = torch::stack(experts, 0).to(device); if (transpose) { merged_tensor = merged_tensor.transpose(1, 2); } merged_tensor = merged_tensor.contiguous(); - experts.clear(); return merged_tensor; } torch::Tensor DeepseekV2DecoderImpl::merge_experts_weights( std::vector& experts_gate, std::vector& experts_up, + at::Device device, bool transpose) { + auto merge_experts_weights_sart = std::chrono::high_resolution_clock::now(); + for (size_t i = 0; i < experts_up.size(); ++i) { experts_gate[i] = torch::cat({experts_gate[i], experts_up[i]}, 0); } - torch::Tensor merged_tensor = torch::stack(experts_gate, 0).to(device_); + + torch::Tensor merged_tensor = torch::stack(experts_gate, 0).to(device); + if (transpose) { merged_tensor = merged_tensor.transpose(1, 2); } + merged_tensor = merged_tensor.contiguous(); - experts_gate.clear(); - experts_up.clear(); return merged_tensor; } +void DeepseekV2DecoderImpl::merge_and_copy_gate_up_weights( + torch::Tensor& + target_buffer, // [num_experts, hidden_dim, gate_dim + up_dim] + const std::vector& experts_gate, // [gate_dim, hidden_dim] + const std::vector& experts_up, // [up_dim, hidden_dim] + bool do_transpose) { + const int64_t num_experts = experts_gate.size(); + const int64_t gate_dim = experts_gate[0].size(0); + const int64_t up_dim = experts_up[0].size(0); + const int64_t hidden_dim = experts_gate[0].size(1); + + auto prepare_experts_weights_start = + std::chrono::high_resolution_clock::now(); + target_buffer = at_npu::native::npu_format_cast(target_buffer.contiguous(), 2) + .reshape({num_experts, gate_dim + up_dim, hidden_dim}); + + prepare_experts_weights_start = std::chrono::high_resolution_clock::now(); + + for (int64_t index = 0; index < num_experts; ++index) { + target_buffer[index].slice(0, 0, gate_dim).copy_(experts_gate[index]); + + target_buffer[index] + .slice(0, gate_dim, gate_dim + up_dim) + .copy_(experts_up[index]); + } + + if (do_transpose) { + target_buffer = target_buffer.transpose(1, 2).contiguous(); + ; + } +} + +void DeepseekV2DecoderImpl::merge_and_copy_down_weights( + torch::Tensor& target_buffer, + const std::vector& experts_down) { + const int64_t num_experts = experts_down.size(); + + for (int64_t index = 0; index < num_experts; ++index) { + target_buffer[index].copy_(experts_down[index]); + } +} + +void DeepseekV2DecoderImpl::prepare_expert_weight( + const std::vector& expert_list) { + auto prepare_experts_weights_start = + std::chrono::high_resolution_clock::now(); + + expert_routing_map_buffer_ = build_expert_routing_map(expert_list); + auto& expert_buffer = ExpertBuffer::Instance(); + + const int32_t num_local_experts = num_experts_per_partition_; + const int32_t hidden_dim = + at_weight_tensors_[IN_MLP_GATEUP_WEIGHT_EXPERT].size(1); + const int32_t combined_dim = + at_weight_tensors_[IN_MLP_GATEUP_WEIGHT_EXPERT].size(2); + const int32_t gate_dim = combined_dim / 2; + + expert_buffer.initialize_or_reuse( + /*gateup_weight_shape*/ {num_experts_per_partition_, + hidden_dim, + 
combined_dim}, + /*gateup_offset_shape*/ {num_experts_per_partition_, combined_dim, 1}, + /*gateup_scale_shape*/ {num_experts_per_partition_, combined_dim, 1}, + /*down_weight_shape*/ {num_experts_per_partition_, hidden_dim, gate_dim}, + /*down_offset_shape*/ {num_experts_per_partition_, hidden_dim, 1}, + /*down_scale_shape*/ {num_experts_per_partition_, hidden_dim, 1}, + at_weight_tensors_[IN_MLP_GATEUP_WEIGHT_EXPERT].options(), + at_weight_tensors_[IN_MLP_GATEUP_OFFSET_EXPERT].options(), + at_weight_tensors_[IN_MLP_GATEUP_SCALE_EXPERT].options() + + ); + + const int start_expert_idx = num_experts_per_partition_ * ep_rank_; + const int end_expert_idx = start_expert_idx + num_experts_per_partition_ - 1; + + for (const auto& pair : experts_weights_) { + for (int expert_idx = start_expert_idx; expert_idx <= end_expert_idx; + ++expert_idx) { + std::string shm_key = + get_expert_shm_key(layer_id_, expert_list[expert_idx], pair.first); + experts_weights_[pair.first][expert_idx - start_expert_idx] = + shared_buffer_->get_tensor(expert_list[expert_idx], + layer_id_ - first_k_dense_replace_, + shm_key); + // experts_weights_[pair.first][expert_idx] = + // shared_buffer_->get_tensors(shm_key); + } + } + + merge_and_copy_gate_up_weights(expert_buffer.gateup_weight, + experts_weights_["gate_proj.weight"], + experts_weights_["up_proj.weight"], + /*do_transpose=*/true); + + merge_and_copy_gate_up_weights(expert_buffer.gateup_offset, + experts_weights_["gate_proj.weight_offset"], + experts_weights_["up_proj.weight_offset"]); + + merge_and_copy_gate_up_weights(expert_buffer.gateup_scale, + experts_weights_["gate_proj.weight_scale"], + experts_weights_["up_proj.weight_scale"]); + + merge_and_copy_down_weights(expert_buffer.down_weight, + experts_weights_["down_proj.weight"]); + + merge_and_copy_down_weights(expert_buffer.down_offset, + experts_weights_["down_proj.weight_offset"]); + + merge_and_copy_down_weights(expert_buffer.down_scale, + experts_weights_["down_proj.weight_scale"]); + + expert_buffer.gateup_weight = + at_npu::native::npu_format_cast(expert_buffer.gateup_weight, 29); + auto prepare_experts_weights_end = std::chrono::high_resolution_clock::now(); + auto prepare__experts_weights_duration = + std::chrono::duration_cast( + prepare_experts_weights_end - prepare_experts_weights_start) + .count(); +} + +torch::Tensor DeepseekV2DecoderImpl::build_expert_routing_map( + std::vector expert_lists) { + std::unordered_map> expert_routing_map; + + for (int64_t i = 0; i < expert_lists.size(); ++i) { + int64_t v = expert_lists[i]; + expert_routing_map[v].emplace_back(i); + } + + for (auto& [key, indices] : expert_routing_map) { + int num_of_duplications = indices.size(); + int selected_index = ep_rank_ % num_of_duplications; + indices = {indices[selected_index]}; + } + + int64_t map_size = expert_routing_map.size(); + auto options = torch::TensorOptions().dtype(torch::kInt32); + auto input = torch::zeros({map_size}, options); + std::vector keys; + std::vector values; + + for (const auto& [k, v] : expert_routing_map) { + keys.emplace_back(k); + values.emplace_back(static_cast(v[0])); + } + + auto index_tensor = torch::tensor(keys, torch::kInt64); + auto value_tensor = torch::tensor(values, torch::kInt32); + auto result = input.scatter(0, index_tensor, value_tensor).to(device_); + // result = result.reshape({ep_size_,result.size(0)/ep_size_}).contiguous(); + return result; +} + +void DeepseekV2DecoderImpl::update_expert_weight() { + auto& expert_buffer = ExpertBuffer::Instance(); + const auto tensor_pairs = 
{ + std::make_pair(IN_MLP_GATEUP_WEIGHT_EXPERT, + std::ref(expert_buffer.gateup_weight)), + std::make_pair(IN_MLP_GATEUP_OFFSET_EXPERT, + std::ref(expert_buffer.gateup_offset)), + std::make_pair(IN_MLP_GATEUP_SCALE_EXPERT, + std::ref(expert_buffer.gateup_scale)), + std::make_pair(IN_MLP_DOWN_WEIGHT_EXPERT, + std::ref(expert_buffer.down_weight)), + std::make_pair(IN_MLP_DOWN_OFFSET_EXPERT, + std::ref(expert_buffer.down_offset)), + std::make_pair(IN_MLP_DOWN_SCALE_EXPERT, + std::ref(expert_buffer.down_scale))}; + for (auto& [index, buffer_tensor] : tensor_pairs) { + std::swap(at_weight_tensors_[index], buffer_tensor); + atb_weight_tensors_[index] = + atb_speed::Utils::AtTensor2Tensor(at_weight_tensors_[index]); + prefill_node_.inTensors.at(index) = &atb_weight_tensors_[index]; + decode_node_.inTensors.at(index) = &atb_weight_tensors_[index]; + } + expert_routing_map_[layer_id_ - first_k_dense_replace_] = + expert_routing_map_buffer_; + expert_routing_map_ = expert_routing_map_.contiguous(); +} + void DeepseekV2DecoderImpl::squeeze_experts_weights() { for (const auto& index : SQUEEZE_WEIGHT_VEC) { if (at_weight_tensors_[index].dim() > 1) { @@ -1169,7 +1446,14 @@ int64_t DeepseekV2DecoderImpl::init_node( return -1; } node.inTensors.resize(node.operation->GetInputNum()); - node.outTensors.resize(1); + + if (FLAGS_enable_eplb && layer_id_ >= decode_param_.firstKDenseReplace && + !param.isPrefill) { + node.outTensors.resize(2); + } else { + node.outTensors.resize(1); + } + size_t inTensorId = 1; for (size_t weightTensorId = 0; weightTensorId < WEIGHT_COUNT_PER_LAYER; @@ -1179,8 +1463,15 @@ int64_t DeepseekV2DecoderImpl::init_node( node.variantPack.inTensors.reserve(node.inTensors.size()); node.variantPack.inTensors.resize(node.inTensors.size()); - node.variantPack.outTensors.reserve(1); - node.variantPack.outTensors.resize(1); + + if (FLAGS_enable_eplb && layer_id_ >= decode_param_.firstKDenseReplace && + !param.isPrefill) { + node.variantPack.outTensors.reserve(2); + node.variantPack.outTensors.resize(2); // TODO + } else { + node.variantPack.outTensors.reserve(1); + node.variantPack.outTensors.resize(1); + } return atb::NO_ERROR; } @@ -1197,6 +1488,7 @@ torch::Tensor DeepseekV2DecoderImpl::forward( std::atomic* event_flag, int node_id) { atb::Status st; + if (input_params.global_empty_kv_cache) { build_node_variant_pack(prefill_node_, x, @@ -1352,6 +1644,16 @@ void DeepseekV2DecoderImpl::build_node_variant_pack( } node.variantPack.outTensors.at(0) = internal_tensor_; + + if (FLAGS_enable_eplb && layer_id_ >= decode_param_.firstKDenseReplace) { + node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 30) = + atb_speed::Utils::AtTensor2Tensor(expert_routing_map_); + if (!is_prefill) { + node.variantPack.outTensors.at(1) = atb_speed::Utils::AtTensor2Tensor( + input_params + .expert_load_data[layer_id_ - decode_param_.firstKDenseReplace]); + } + } } DeepseekV2Decoder::DeepseekV2Decoder(const Context& context, diff --git a/xllm/core/layers/npu/deepseek_v2_decoder_layer.h b/xllm/core/layers/npu/deepseek_v2_decoder_layer.h index 87dcb18d..471fade3 100644 --- a/xllm/core/layers/npu/deepseek_v2_decoder_layer.h +++ b/xllm/core/layers/npu/deepseek_v2_decoder_layer.h @@ -22,6 +22,9 @@ limitations under the License. 
#include #include "atb_base.h" +#include "atb_layers/models/deepseekv2/layer/decoder_layer.h" +#include "framework/eplb/expert_buffer_manager.h" +#include "framework/eplb/expert_weight_buffer_shm.h" #include "framework/model/model_args.h" #include "framework/model/npu_dp_ep_padding.h" #include "framework/parallel_state.h" @@ -30,6 +33,78 @@ limitations under the License. #include "xllm_kernels/models/deepseekv2/layer/decoder_layer.h" namespace xllm::hf { +class ExpertBuffer { + public: + torch::Tensor gateup_weight; + torch::Tensor gateup_offset; + torch::Tensor gateup_scale; + torch::Tensor down_weight; + torch::Tensor down_offset; + torch::Tensor down_scale; + + static ExpertBuffer& Instance() { + static ExpertBuffer instance; + return instance; + } + + void initialize_or_reuse(const std::vector& gateup_weight_shape, + const std::vector& gateup_offset_shape, + const std::vector& gateup_scale_shape, + const std::vector& down_weight_shape, + const std::vector& down_offset_shape, + const std::vector& down_scale_shape, + const torch::TensorOptions& weight_options, + const torch::TensorOptions& offset_options, + const torch::TensorOptions& scale_options, + + bool force_reinit = false) { + std::lock_guard lock(mutex_); + + if (force_reinit) { + initialized_ = false; + } + + if (!initialized_) { + gateup_weight = + torch::empty(gateup_weight_shape, weight_options).contiguous(); + gateup_offset = + torch::empty(gateup_offset_shape, offset_options).contiguous(); + gateup_scale = + torch::empty(gateup_scale_shape, scale_options).contiguous(); + down_weight = + torch::empty(down_weight_shape, weight_options).contiguous(); + down_offset = + torch::empty(down_offset_shape, offset_options).contiguous(); + down_scale = torch::empty(down_scale_shape, scale_options).contiguous(); + initialized_ = true; + } else { + auto validate_shape = [](const torch::Tensor& t, + const std::vector& expected) { + TORCH_CHECK(t.sizes() == expected, + "Shape mismatch. 
Expected ", + expected, + " got ", + t.sizes()); + }; + + validate_shape(gateup_weight, gateup_weight_shape); + validate_shape(gateup_offset, gateup_offset_shape); + validate_shape(down_weight, down_weight_shape); + validate_shape(down_offset, down_offset_shape); + // gateup_weight = at_npu::native::npu_format_cast( + // gateup_weight.contiguous(), 2); + gateup_offset = gateup_offset.contiguous(); + gateup_scale = gateup_scale.contiguous(); + down_weight = down_weight.contiguous(); + down_offset = down_offset.contiguous(); + down_scale = down_scale.contiguous(); + } + } + + private: + std::mutex mutex_; + bool initialized_ = false; +}; class DeepseekV2DecoderImpl : public torch::nn::Module, public ATBBase { public: @@ -45,6 +120,10 @@ class DeepseekV2DecoderImpl : public torch::nn::Module, public ATBBase { void merge_loaded_weights(); + void prepare_expert_weight(const std::vector& expert_list); + + void update_expert_weight(); + torch::Tensor block_tables_placeholder_; torch::Tensor forward(torch::Tensor& x, @@ -75,8 +154,8 @@ class DeepseekV2DecoderImpl : public torch::nn::Module, public ATBBase { const ParallelArgs& parallel_args, bool is_prefill); - void resize_experts_weights(int num_of_device_experts); - + void reserve_experts_weights(int num_of_device_experts); + void initialize_device_expert_list(int numdevice, int num_layers); void initialize_basic_parameters( atb_speed::deepseekV2::DecoderLayerParam& param, const ModelArgs& args, @@ -117,6 +196,10 @@ class DeepseekV2DecoderImpl : public torch::nn::Module, public ATBBase { std::string extract_endswith(const std::string& input); + std::string get_expert_shm_key(int32_t layer_id, + int32_t expert_ids, + std::string suffix); + torch::Tensor build_expert_routing_map(std::vector expert_lists); void set_kv_weight(const StateDict& state_dict, const std::string& tensor_name, int weight_position, @@ -166,12 +249,23 @@ class DeepseekV2DecoderImpl : public torch::nn::Module, public ATBBase { torch::Tensor trans_rope_weight(torch::Tensor weight); torch::Tensor merge_experts_weights(std::vector& experts, + at::Device device, bool transpose = false); torch::Tensor merge_experts_weights(std::vector& experts_up, std::vector& experts_gate, + at::Device device, bool transpose = false); + void merge_and_copy_gate_up_weights( + torch::Tensor& target_buffer, + const std::vector& experts_gate, + const std::vector& experts_up, + bool do_transpose = false); + void merge_and_copy_down_weights( + torch::Tensor& target_buffer, + const std::vector& experts_down); + int64_t init_layer(); int64_t init_node(atb_speed::Model::Node& node, @@ -196,6 +290,10 @@ class DeepseekV2DecoderImpl : public torch::nn::Module, public ATBBase { int32_t kv_lora_rank_; int32_t qk_rope_head_dim_; + int32_t rank_; + int32_t first_k_dense_replace_; + int32_t n_layers_; + int32_t localWorldSize_; int32_t ep_size_; int32_t num_experts_; int32_t num_experts_per_partition_; @@ -220,6 +318,7 @@ class DeepseekV2DecoderImpl : public torch::nn::Module, public ATBBase { atb::Tensor internal_tensor_; + torch::Tensor at_cumsum_; torch::Tensor tensor_placeholder_; torch::Tensor slot_tensor_placeholder_; torch::Tensor int_tensor_placeholder_; @@ -232,12 +331,19 @@ class DeepseekV2DecoderImpl : public torch::nn::Module, public ATBBase { torch::Tensor at_in_device_expert_count_; std::vector int_placeholder_; + std::vector device_expert_list_; std::unordered_map shared_experts_weights_; std::unordered_map> experts_weights_; + std::unordered_map> + all_experts_weights_buffer_; std::mutex 
shared_experts_mutex_; std::mutex experts_mutex_; + + std::unique_ptr shared_buffer_ = nullptr; + torch::Tensor expert_routing_map_; + torch::Tensor expert_routing_map_buffer_; }; class DeepseekV2Decoder diff --git a/xllm/core/runtime/CMakeLists.txt b/xllm/core/runtime/CMakeLists.txt index 3545a651..76529876 100644 --- a/xllm/core/runtime/CMakeLists.txt +++ b/xllm/core/runtime/CMakeLists.txt @@ -57,6 +57,7 @@ cc_library( :worker_service :xllm_server $<$:xllm_ops> + :eplb glog::glog Folly::folly absl::strings diff --git a/xllm/core/runtime/forward_params.h b/xllm/core/runtime/forward_params.h index 67f4e730..4f507977 100644 --- a/xllm/core/runtime/forward_params.h +++ b/xllm/core/runtime/forward_params.h @@ -89,6 +89,7 @@ struct ForwardInput { inputs.input_params = input_params.to(device); inputs.sampling_params = sampling_params.to(device, dtype); inputs.transfer_kv_infos = transfer_kv_infos; + inputs.eplb_info = eplb_info; return inputs; } // flatten token ids @@ -99,6 +100,7 @@ struct ForwardInput { SamplingParameters sampling_params; // kv info for disaggregated prefill/decode std::vector transfer_kv_infos; + EplbInfo eplb_info; }; // output after forward execution @@ -112,6 +114,10 @@ struct ForwardOutput { SampleOutput sample_output; torch::Tensor logits; torch::Tensor embedding; + + torch::Tensor expert_load_data; + + int32_t prepared_layer_id; }; // Model input with raw data, which will be @@ -138,6 +144,7 @@ struct RawForwardInput { std::vector dp_global_token_nums; // kv info for disaggregated prefill/decode std::vector transfer_kv_infos; + EplbInfo eplb_info; std::vector> embeddings; // num of prefill sequence in chunked prefill case uint32_t prefill_seq_len; @@ -151,6 +158,8 @@ struct RawSampleOutput { struct RawForwardOutput { std::vector outputs; // num seqs + std::vector expert_load_data; + int32_t prepared_layer_id; }; } // namespace xllm diff --git a/xllm/core/runtime/llm_engine.cpp b/xllm/core/runtime/llm_engine.cpp index 2af9fd3f..39877fa6 100644 --- a/xllm/core/runtime/llm_engine.cpp +++ b/xllm/core/runtime/llm_engine.cpp @@ -122,6 +122,19 @@ bool LLMEngine::init_model() { n_local_q_heads_ = std::max(1, n_heads / world_size); head_dim_ = args_.head_dim(); dtype_ = util::parse_dtype(args_.dtype(), options_.devices()[0]); + if (FLAGS_enable_eplb) { + int32_t num_layers = args_.n_layers() - args_.first_k_dense_replace(); + int32_t num_experts = args_.n_routed_experts(); + expert_load_data_ = + torch::zeros({num_layers, num_experts + worker_clients_.size()}) + .to(torch::kInt64); + eplb_policy_ = + std::make_unique(num_experts / worker_clients_.size() + 1, + worker_clients_.size(), + num_layers); + eplb_manager_ = std::make_unique( + eplb_policy_.get(), num_layers, worker_clients_.size(), num_experts); + } // key + value for all layers LOG(INFO) << "Block info, block_size: " << options_.block_size() @@ -486,6 +499,7 @@ ForwardOutput LLMEngine::step(std::vector& batch) { raw_forward_inputs.reserve(dp_size); std::vector dp_global_token_nums(dp_size); bool global_empty_kv_cache = true; + EplbInfo eplb_info; for (auto dp_rank = 0; dp_rank < dp_size; ++dp_rank) { // assume the order in workers_ is its rank RawForwardInput raw_forward_input = batch[dp_rank].prepare_forward_input(); @@ -497,11 +511,17 @@ ForwardOutput LLMEngine::step(std::vector& batch) { std::vector>> futures; futures.reserve(worker_clients_num); + if (FLAGS_enable_eplb) { + eplb_info = eplb_manager_->get_eplb_info(); + } // update dp related global paramters and then execute model for (auto worker_rank = 0; 
worker_rank < worker_clients_num; ++worker_rank) { auto dp_rank = worker_rank / dp_local_tp_size; raw_forward_inputs[dp_rank].dp_global_token_nums = dp_global_token_nums; raw_forward_inputs[dp_rank].global_empty_kv_cache = global_empty_kv_cache; + if (FLAGS_enable_eplb) { + raw_forward_inputs[dp_rank].eplb_info = eplb_info; + } futures.emplace_back( worker_clients_[worker_rank]->step_async(raw_forward_inputs[dp_rank])); } @@ -509,6 +529,9 @@ ForwardOutput LLMEngine::step(std::vector& batch) { // wait for the all future to complete auto results = folly::collectAll(futures).get(); + if (FLAGS_enable_eplb && !options_.enable_schedule_overlap()) { + process_eplb_data(results, worker_clients_num); + } // concat results from dp ranks std::vector> raw_forward_outputs; raw_forward_outputs.reserve(dp_size); @@ -541,23 +564,46 @@ void LLMEngine::update_last_step_result(std::vector& last_batch) { std::vector>> futures; futures.reserve(dp_size); - for (auto worker_rank = 0; worker_rank < worker_clients_num; - worker_rank += dp_local_tp_size) { - futures.emplace_back( - worker_clients_[worker_rank]->get_last_step_result_async()); - } - // wait for the all future to complete - auto last_step_results = folly::collectAll(futures).get(); - // concat last step results from dp ranks std::vector raw_forward_outputs; raw_forward_outputs.reserve(dp_size); - for (auto worker_rank = 0; worker_rank < worker_clients_num; - worker_rank += dp_local_tp_size) { - auto result = last_step_results[worker_rank / dp_local_tp_size].value(); - if (result.has_value()) { - raw_forward_outputs.emplace_back(std::move(result.value())); - } else { - throw std::runtime_error("Failed to get last step results."); + if (FLAGS_enable_eplb) { + for (auto worker_rank = 0; worker_rank < worker_clients_num; + worker_rank++) { + futures.emplace_back( + worker_clients_[worker_rank]->get_last_step_result_async()); + } + // wait for the all future to complete + auto last_step_results = folly::collectAll(futures).get(); + // concat last step results from dp ranks + process_eplb_data(last_step_results, worker_clients_num); + for (auto worker_rank = 0; worker_rank < worker_clients_num; + worker_rank += dp_local_tp_size) { + auto result = last_step_results[worker_rank].value(); + if (result.has_value()) { + raw_forward_outputs.emplace_back(std::move(result.value())); + } else { + throw std::runtime_error("Failed to get last step results."); + } + } + } else { + for (auto worker_rank = 0; worker_rank < worker_clients_num; + worker_rank += dp_local_tp_size) { + futures.emplace_back( + worker_clients_[worker_rank]->get_last_step_result_async()); + } + + // wait for the all future to complete + auto last_step_results = folly::collectAll(futures).get(); + // concat last step results from dp ranks + + for (auto worker_rank = 0; worker_rank < worker_clients_num; + worker_rank += dp_local_tp_size) { + auto result = last_step_results[worker_rank / dp_local_tp_size].value(); + if (result.has_value()) { + raw_forward_outputs.emplace_back(std::move(result.value())); + } else { + throw std::runtime_error("Failed to get last step results."); + } } } @@ -592,4 +638,30 @@ void LLMEngine::setup_workers(const runtime::Options& options) { worker_clients_ = dist_manager_->get_worker_clients(); } +void LLMEngine::process_eplb_data( + const std::vector>>& results, + int32_t worker_clients_num) { + int32_t num_layers = args_.n_layers() - args_.first_k_dense_replace(); + int32_t num_device_experts = + args_.n_routed_experts() / worker_clients_.size() + 1; + std::vector 
tensors; + std::vector layer_ids(num_device_experts - 1, -1); + tensors.reserve(worker_clients_.size()); + for (size_t worker_rank = 0; worker_rank < results.size(); ++worker_rank) { + auto result = results[worker_rank].value(); + if (result.has_value()) { + tensors.emplace_back( + torch::from_blob(result.value().expert_load_data.data(), + {num_layers, num_device_experts}, + torch::TensorOptions().dtype(torch::kInt64)) + .clone()); + layer_ids[worker_rank] = result.value().prepared_layer_id; + } else { + LOG(ERROR) << "Failed to process EPLB data"; + } + } + eplb_manager_->set_prepared_layer_ids(layer_ids); + eplb_manager_->update_expert_load(tensors); +} + } // namespace xllm diff --git a/xllm/core/runtime/llm_engine.h b/xllm/core/runtime/llm_engine.h index 910273b4..ede2cd5b 100644 --- a/xllm/core/runtime/llm_engine.h +++ b/xllm/core/runtime/llm_engine.h @@ -24,6 +24,8 @@ limitations under the License. #include "distributed_runtime/dist_manager.h" #include "framework/batch/batch.h" #include "framework/block/block_manager_pool.h" +#include "framework/eplb/eplb_manager.h" +#include "framework/eplb/eplb_policy.h" #include "framework/quant_args.h" #include "framework/tokenizer/tokenizer.h" #include "framework/tokenizer/tokenizer_args.h" @@ -115,6 +117,7 @@ class LLMEngine : public Engine { int64_t n_local_q_heads_ = 0; int64_t head_dim_ = 0; + torch::Tensor expert_load_data_; // For multi-node serving // engine brpc server, all workers connect to engine_server_, // engine_server_ will send a UniqueId for workers to @@ -122,6 +125,12 @@ class LLMEngine : public Engine { // address to engine, engine will create WorkerClient for each worker. // Engine call workers to step via these WorkerClients. std::shared_ptr dist_manager_ = nullptr; + + std::unique_ptr eplb_manager_ = nullptr; + std::unique_ptr eplb_policy_ = nullptr; + void process_eplb_data( + const std::vector>>& results, + int32_t worker_clients_num); }; } // namespace xllm diff --git a/xllm/core/runtime/llm_worker_impl.cpp b/xllm/core/runtime/llm_worker_impl.cpp index 7383bb4b..c0d1b6cf 100644 --- a/xllm/core/runtime/llm_worker_impl.cpp +++ b/xllm/core/runtime/llm_worker_impl.cpp @@ -37,6 +37,7 @@ limitations under the License. 
#include "common/device_monitor.h" #include "common/metrics.h" #include "common/types.h" +#include "core/common/global_flags.h" #include "framework/kv_cache/kv_cache.h" #include "framework/model/model_input_params.h" #include "framework/parallel_state.h" @@ -80,6 +81,7 @@ bool LLMWorkerImpl::init_model(torch::ScalarType dtype, model_executor_ = std::make_unique(model_.get(), model_args, device_, options_); + eplb_executor_ = std::make_unique(model_.get()); return true; } @@ -95,6 +97,9 @@ std::optional LLMWorkerImpl::step(const ForwardInput& inputs) { auto& params = inputs.input_params; auto& sampling_params = inputs.sampling_params; + if (FLAGS_enable_eplb) { + eplb_executor_->eplb_execute(inputs.eplb_info); + } std::vector> futures; if (options_.instance_role() == InstanceRole::PREFILL && options_.kv_cache_transfer_mode() == "PUSH" && @@ -124,6 +129,14 @@ std::optional LLMWorkerImpl::step(const ForwardInput& inputs) { model_->logits(hidden_states, sampling_params.selected_token_idxes); } + ForwardOutput output; + if (FLAGS_enable_eplb) { + output.expert_load_data = expert_load_data_; + output.prepared_layer_id = eplb_executor_->get_ready_layer_id(); + if (output.prepared_layer_id != -1) { + eplb_executor_->reset_ready_layer_id(); + } + } if (!enable_schedule_overlap() && !driver_ && !dp_driver_ && !options_.enable_speculative_decode()) { #if defined(USE_NPU) @@ -145,11 +158,13 @@ std::optional LLMWorkerImpl::step(const ForwardInput& inputs) { } } } + if (FLAGS_enable_eplb) { + return output; + } return std::nullopt; } // driver prepare model output - ForwardOutput output; SampleOutput sample_output; if (sampling_params.selected_token_idxes.defined()) { sample_output = sampler_->forward(logits, sampling_params); diff --git a/xllm/core/runtime/master.cpp b/xllm/core/runtime/master.cpp index b5a56309..94860b4d 100644 --- a/xllm/core/runtime/master.cpp +++ b/xllm/core/runtime/master.cpp @@ -62,9 +62,18 @@ Master::Master(const Options& options, EngineType type) : options_(options) { if (options.communication_backend().has_value()) { FLAGS_communication_backend = options.communication_backend().value(); } - if (options.communication_backend().has_value()) { + if (options.expert_parallel_degree().has_value()) { FLAGS_expert_parallel_degree = options.expert_parallel_degree().value(); } + if (options.enable_eplb().has_value()) { + FLAGS_enable_eplb = options.enable_eplb().value(); + } + if (options.eplb_update_rate().has_value()) { + FLAGS_eplb_update_rate = options.eplb_update_rate().value(); + } + if (options.eplb_update_threshold().has_value()) { + FLAGS_eplb_update_threshold = options.eplb_update_threshold().value(); + } #endif // construct engine diff --git a/xllm/core/runtime/params_utils.cpp b/xllm/core/runtime/params_utils.cpp index 7c7f12d5..08bc9462 100644 --- a/xllm/core/runtime/params_utils.cpp +++ b/xllm/core/runtime/params_utils.cpp @@ -20,6 +20,7 @@ limitations under the License. 
#include +#include "common/global_flags.h" #include "common/macros.h" #include "common/metrics.h" #include "framework/model/model_input_params.h" @@ -263,6 +264,13 @@ void proto_to_forward_input(const proto::ForwardInput* pb_forward_input, transfer_kv_info.remote_instance_info = std::move(instance_info); forward_inputs.transfer_kv_infos.emplace_back(std::move(transfer_kv_info)); } + auto& eplb_info = forward_inputs.eplb_info; + eplb_info.prepare_layer_id = pb_forward_input->eplb_info().prepare_layer_id(); + eplb_info.expert_ids = + std::vector(pb_forward_input->eplb_info().expert_ids().begin(), + pb_forward_input->eplb_info().expert_ids().end()); + eplb_info.update_layer_id = pb_forward_input->eplb_info().update_layer_id(); + forward_inputs.eplb_info = eplb_info; COUNTER_ADD(proto_latency_seconds_proto2i, timer.elapsed_seconds()); } @@ -366,6 +374,13 @@ void forward_input_to_proto(const RawForwardInput& inputs, transfer_kv_info.remote_instance_info.dp_size); } } + pb_forward_input->mutable_eplb_info()->set_prepare_layer_id( + inputs.eplb_info.prepare_layer_id); + pb_forward_input->mutable_eplb_info()->set_update_layer_id( + inputs.eplb_info.update_layer_id); + ADD_VECTOR_TO_PROTO( + pb_forward_input->mutable_eplb_info()->mutable_expert_ids(), + inputs.eplb_info.expert_ids); pb_forward_input->mutable_embeds()->Reserve(inputs.embeddings.size()); for (auto t : inputs.embeddings) { proto::Embeddings embeds; @@ -383,6 +398,11 @@ void proto_to_forward_output(const proto::ForwardOutput& pb_output, Timer timer; size_t seq_nums = pb_output.outputs().size(); raw_forward_output.outputs.reserve(seq_nums); + size_t expert_load_data_size = pb_output.expert_load_data().size(); + raw_forward_output.expert_load_data.reserve(expert_load_data_size); + raw_forward_output.expert_load_data.assign( + pb_output.expert_load_data().begin(), pb_output.expert_load_data().end()); + raw_forward_output.prepared_layer_id = pb_output.prepared_layer_id(); for (size_t i = 0; i < seq_nums; ++i) { proto::SquenceOutput pb_seq_out = pb_output.outputs()[i]; RawSampleOutput s; @@ -418,6 +438,8 @@ void forward_output_to_proto(const torch::Tensor& next_tokens, const torch::Tensor& top_tokens, const torch::Tensor& top_logprobs, const torch::Tensor& embeddings, + const torch::Tensor& expert_load_data, + int32_t prepared_layer_id, proto::ForwardOutput* pb_forward_output) { Timer timer; int32_t num_seqs = next_tokens.size(0); @@ -533,6 +555,20 @@ void forward_output_to_proto(const torch::Tensor& next_tokens, *pb_forward_output->mutable_outputs()->Add() = pb_seq_out; } } + + if (FLAGS_enable_eplb) { + pb_forward_output->set_prepared_layer_id(prepared_layer_id); + + torch::Tensor expert_load_data_flattened = + expert_load_data.view({-1}).contiguous(); + if (expert_load_data_flattened.defined()) { + Slice expert_load_data_flattened_slice = { + expert_load_data_flattened.data_ptr(), + expert_load_data_flattened.size(0)}; + ADD_VECTOR_TO_PROTO(pb_forward_output->mutable_expert_load_data(), + expert_load_data_flattened_slice); + } + } COUNTER_ADD(proto_latency_seconds_o2proto, timer.elapsed_seconds()); return; } diff --git a/xllm/core/runtime/params_utils.h b/xllm/core/runtime/params_utils.h index db2f1d46..3e50a7a3 100644 --- a/xllm/core/runtime/params_utils.h +++ b/xllm/core/runtime/params_utils.h @@ -39,6 +39,8 @@ void forward_output_to_proto(const torch::Tensor& next_tokens, const torch::Tensor& top_tokens, const torch::Tensor& top_logprobs, const torch::Tensor& embeddings, + const torch::Tensor& expert_load_data, + int32_t 
prepared_layer_id, proto::ForwardOutput* pb_forward_output); Token build_token(int64_t index, diff --git a/xllm/core/runtime/worker_impl.cpp b/xllm/core/runtime/worker_impl.cpp index 57553f0f..1222f5b4 100644 --- a/xllm/core/runtime/worker_impl.cpp +++ b/xllm/core/runtime/worker_impl.cpp @@ -285,7 +285,7 @@ WorkerImpl::estimate_kv_cache_capacity_async() { void WorkerImpl::update_last_step_output( const std::optional& output) { - if (output.value().sample_output.next_tokens.defined()) { + if (output.value().sample_output.next_tokens.defined() || FLAGS_enable_eplb) { last_step_output_ = std::move(output.value()); last_step_output_valid_ = true; } else { @@ -328,8 +328,13 @@ void WorkerImpl::prepare_work_before_execute(const ForwardInput& inputs, context_.get_model_args().num_experts_per_tok(), context_.get_parallel_args().mapping_data(), device_, + dtype_, is_prefill); processed_inputs.input_params.dp_ep_padding_data = dp_ep_padding.build(); + if (FLAGS_enable_eplb) { + expert_load_data_.fill_(0); + processed_inputs.input_params.expert_load_data = expert_load_data_; + } } aclrtSynchronizeStream(npu_stream_helper_->H2D_memcpy_stream.stream()); #endif @@ -356,7 +361,7 @@ folly::SemiFuture> WorkerImpl::step_async( } const auto output = this->step(inputs); if (output.has_value()) { - if (is_driver()) { + if (is_driver() || FLAGS_enable_eplb) { std::unique_lock lock(mtx_); cv_.wait(lock, [this] { return !is_recorded_; }); update_last_step_output(output); @@ -366,7 +371,7 @@ folly::SemiFuture> WorkerImpl::step_async( update_last_step_output(output); } } else { - if (is_driver()) { + if (is_driver() || FLAGS_enable_eplb) { std::unique_lock lock(mtx_); cv_.wait(lock, [this] { return !is_recorded_; }); last_step_output_valid_ = false; @@ -454,6 +459,15 @@ bool WorkerImpl::init_model(const std::string& model_weights_path) { this->load_model(std::move(model_loader)); status_ = Status::LOADED; + if (FLAGS_enable_eplb) { + int32_t num_layers = args.n_layers() - args.first_k_dense_replace(); + int32_t num_device_experts = + args.n_routed_experts() / context_.get_parallel_args().world_size() + 1; + expert_load_data_ = torch::zeros({num_layers, num_device_experts}) + .to(torch::kInt64) + .to(device_) + .contiguous(); + } return true; } diff --git a/xllm/core/runtime/worker_impl.h b/xllm/core/runtime/worker_impl.h index dc91e7fc..14c2e572 100644 --- a/xllm/core/runtime/worker_impl.h +++ b/xllm/core/runtime/worker_impl.h @@ -28,6 +28,7 @@ limitations under the License. 
#include "framework/kv_cache/hccl_kv_cache_transfer.h" #include "framework/kv_cache/llm_data_dist_transfer.h" #endif +#include "framework/eplb/eplb_executor.h" #include "framework/model/causal_lm.h" #include "framework/model/embedding_lm.h" #include "framework/model/model_input_params.h" @@ -194,6 +195,8 @@ class WorkerImpl { std::unique_ptr sampler_; + std::unique_ptr eplb_executor_; + // params for enable_schedule_overlap case // an output to store the result of last step ForwardOutput last_step_output_; @@ -218,6 +221,8 @@ class WorkerImpl { bool is_spec_draft_ = false; Status status_ = Status::UNINITIALIZED; + + torch::Tensor expert_load_data_; }; } // namespace xllm diff --git a/xllm/models/deepseek_v2.h b/xllm/models/deepseek_v2.h index 32f88c49..269515ea 100644 --- a/xllm/models/deepseek_v2.h +++ b/xllm/models/deepseek_v2.h @@ -87,6 +87,12 @@ class DeepseekV2DecoderLayerImpl : public torch::nn::Module { void merge_loaded_weights() { decoder_layer_->merge_loaded_weights(); } + void prepare_expert_weight(const std::vector& expert_list) { + decoder_layer_->prepare_expert_weight(expert_list); + } + + void update_expert_weight() { decoder_layer_->update_expert_weight(); } + private: DeepseekV2Decoder decoder_layer_{nullptr}; }; @@ -230,6 +236,15 @@ class DeepseekV2ModelImpl : public torch::nn::Module { norm_->merge_loaded_weights(); } + void prepare_expert_weight(int32_t layer_id, + const std::vector& expert_ids) { + layers_[layer_id]->prepare_expert_weight(expert_ids); + } + + void update_expert_weight(int32_t layer_id) { + layers_[layer_id]->update_expert_weight(); + } + AtbWordEmbedding get_word_embedding() { return embed_tokens_; } void set_word_embedding(AtbWordEmbedding& word_embedding) { @@ -270,6 +285,7 @@ class DeepseekV2ForCausalLMImpl : public torch::nn::Module { context_->SetExecuteStream(stream); context_->SetAsyncTilingCopyStatus(true); lm_head_ = register_module("lm_head", LlmHead(context)); + first_k_dense_replace_ = context.get_model_args().first_k_dense_replace(); } // tokens: [num_tokens] @@ -305,6 +321,16 @@ class DeepseekV2ForCausalLMImpl : public torch::nn::Module { lm_head_->merge_loaded_weights(); } + void prepare_expert_weight(int32_t layer_id, + const std::vector& expert_ids) { + model_->prepare_expert_weight(layer_id + first_k_dense_replace_, + expert_ids); + } + + void update_expert_weight(int32_t layer_id) { + model_->update_expert_weight(layer_id + first_k_dense_replace_); + } + LlmHead get_lm_head() { return lm_head_; } void set_lm_head(LlmHead& head) { lm_head_ = head; } @@ -320,6 +346,7 @@ class DeepseekV2ForCausalLMImpl : public torch::nn::Module { LlmHead lm_head_{nullptr}; AtbWorkspace work_space_; atb::Context* context_; + int32_t first_k_dense_replace_; }; TORCH_MODULE(DeepseekV2ForCausalLM); diff --git a/xllm/models/deepseek_v2_mtp.h b/xllm/models/deepseek_v2_mtp.h index 585545d3..c36ba2a4 100644 --- a/xllm/models/deepseek_v2_mtp.h +++ b/xllm/models/deepseek_v2_mtp.h @@ -284,6 +284,11 @@ class DeepseekV2MtpForCausalLMImpl : public torch::nn::Module { // lm_head_->merge_loaded_weights(); } + void prepare_expert_weight(int32_t layer_id, + const std::vector& expert_ids) { + return; + } + void update_expert_weight(int32_t layer_id) { return; } LlmHead get_lm_head() { return lm_head_; } void set_lm_head(LlmHead& head) { lm_head_ = head; } diff --git a/xllm/models/llama.h b/xllm/models/llama.h index 8988562a..9eb6985c 100644 --- a/xllm/models/llama.h +++ b/xllm/models/llama.h @@ -310,6 +310,12 @@ class LlamaForCausalLMImpl : public torch::nn::Module { 
lm_head_->merge_loaded_weights(); } + void prepare_expert_weight(int32_t layer_id, + const std::vector& expert_ids) { + return; + } + void update_expert_weight(int32_t layer_id) { return; } + LlmHead get_lm_head() { return lm_head_; } void set_lm_head(LlmHead& head) { lm_head_ = head; } diff --git a/xllm/models/qwen3_moe.h b/xllm/models/qwen3_moe.h index 66cb9f52..44131c29 100644 --- a/xllm/models/qwen3_moe.h +++ b/xllm/models/qwen3_moe.h @@ -294,6 +294,12 @@ class Qwen3MoeForCausalLMImpl : public torch::nn::Module { lm_head_->merge_loaded_weights(); } + virtual void prepare_expert_weight(int32_t layer_id, + const std::vector& expert_ids) { + return; + } + virtual void update_expert_weight(int32_t layer_id) { return; } + LlmHead get_lm_head() { return lm_head_; } void set_lm_head(LlmHead& head) { lm_head_ = head; } diff --git a/xllm/models/qwen_base.h b/xllm/models/qwen_base.h index 0314cd54..22344e3f 100644 --- a/xllm/models/qwen_base.h +++ b/xllm/models/qwen_base.h @@ -392,6 +392,12 @@ class QWenForCausalLMImplBase : public torch::nn::Module { lm_head_->merge_loaded_weights(); } + virtual void prepare_expert_weight(int32_t layer_id, + const std::vector& expert_ids) { + return; + } + virtual void update_expert_weight(int32_t layer_id) { return; } + virtual LlmHead get_lm_head() { return lm_head_; } virtual void set_lm_head(LlmHead& head) { lm_head_ = head; } diff --git a/xllm/proto/worker.proto b/xllm/proto/worker.proto index 99046431..0ad80ded 100644 --- a/xllm/proto/worker.proto +++ b/xllm/proto/worker.proto @@ -98,6 +98,12 @@ message TransferKVInfo { InstanceInfo remote_instance_info = 5; } +message EplbInfo { + int32 prepare_layer_id = 1; + repeated int32 expert_ids = 2; + int32 update_layer_id = 3; +}; + message RequestSamplingParam { float frequency_penalty = 1; float presence_penalty = 2; @@ -155,6 +161,7 @@ message ForwardInput { repeated Embeddings embeds = 23; uint32 prefill_seq_len = 24; repeated int32 embedding_ids = 25; + EplbInfo eplb_info =26; } message Embeddings { @@ -178,6 +185,8 @@ message SquenceOutput { message ForwardOutput { repeated SquenceOutput outputs = 2; + repeated int64 expert_load_data = 3; + int32 prepared_layer_id = 4; } // master create Collective service diff --git a/xllm/xllm.cpp b/xllm/xllm.cpp index b96adf0b..bcf455cd 100644 --- a/xllm/xllm.cpp +++ b/xllm/xllm.cpp @@ -113,6 +113,9 @@ int run() { .num_speculative_tokens(FLAGS_num_speculative_tokens) .num_handling_threads(FLAGS_num_handling_threads) .communication_backend(FLAGS_communication_backend) + .enable_eplb(FLAGS_enable_eplb) + .eplb_update_rate(FLAGS_eplb_update_rate) + .eplb_update_threshold(FLAGS_eplb_update_threshold) .rank_tablefile(FLAGS_rank_tablefile) .expert_parallel_degree(FLAGS_expert_parallel_degree) .enable_mla(FLAGS_enable_mla) From b8553b9ab519ff3ee7224c413e663887bfc50f4d Mon Sep 17 00:00:00 2001 From: jindonghe1 Date: Tue, 26 Aug 2025 16:07:56 +0800 Subject: [PATCH 2/4] bugfix: fix compile issuse for EPLB. 
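Patch 1 gave the single-vector overload of merge_experts_weights() an explicit at::Device parameter, but the down-projection branch in DeepseekV2DecoderImpl::merge_experts_weights() still used the old two-argument form and no longer compiled; this patch threads device_ through those two call sites and drops the stale atb_layers include from the header. The snippet below is only a minimal, standalone sketch of the overload shape the call sites now have to match; the free function and main() are illustrative and are not code from this patch.

    #include <torch/torch.h>

    #include <iostream>
    #include <vector>

    // Illustrative stand-in for the single-vector overload: stack per-expert
    // tensors into one [num_experts, ...] tensor on the requested device and
    // optionally swap the last two dimensions.
    torch::Tensor merge_experts_weights(std::vector<torch::Tensor>& experts,
                                        torch::Device device,
                                        bool transpose = false) {
      torch::Tensor merged = torch::stack(experts, /*dim=*/0).to(device);
      if (transpose) {
        merged = merged.transpose(1, 2);
      }
      return merged.contiguous();
    }

    int main() {
      std::vector<torch::Tensor> down_proj = {torch::randn({4, 8}),
                                              torch::randn({4, 8})};
      // Every call site now passes the target device explicitly
      // (device_ in the real layer); plain CPU is used here.
      torch::Tensor merged = merge_experts_weights(
          down_proj, torch::Device(torch::kCPU), /*transpose=*/false);
      std::cout << merged.sizes() << std::endl;  // [2, 4, 8]
      return 0;
    }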
--- xllm/core/layers/npu/deepseek_v2_decoder_layer.cpp | 2 ++ xllm/core/layers/npu/deepseek_v2_decoder_layer.h | 1 - 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/xllm/core/layers/npu/deepseek_v2_decoder_layer.cpp b/xllm/core/layers/npu/deepseek_v2_decoder_layer.cpp index 7efd6d47..21297b5c 100644 --- a/xllm/core/layers/npu/deepseek_v2_decoder_layer.cpp +++ b/xllm/core/layers/npu/deepseek_v2_decoder_layer.cpp @@ -1183,12 +1183,14 @@ void DeepseekV2DecoderImpl::merge_experts_weights() { if (decode_param_.isBF16) { torch::Tensor mlp_down_weight = merge_experts_weights(experts_weights_["down_proj.weight"], + device_, /*transpose=*/true); at_weight_tensors_[IN_MLP_DOWN_WEIGHT_EXPERT] = at_npu::native::npu_format_cast(mlp_down_weight, 29); } else { torch::Tensor mlp_down_weight = merge_experts_weights(experts_weights_["down_proj.weight"], + device_, /*transpose=*/false); at_weight_tensors_[IN_MLP_DOWN_WEIGHT_EXPERT] = at_npu::native::npu_format_cast(mlp_down_weight, 2).contiguous(); diff --git a/xllm/core/layers/npu/deepseek_v2_decoder_layer.h b/xllm/core/layers/npu/deepseek_v2_decoder_layer.h index 471fade3..c9d3eefc 100644 --- a/xllm/core/layers/npu/deepseek_v2_decoder_layer.h +++ b/xllm/core/layers/npu/deepseek_v2_decoder_layer.h @@ -22,7 +22,6 @@ limitations under the License. #include #include "atb_base.h" -#include "atb_layers/models/deepseekv2/layer/decoder_layer.h" #include "framework/eplb/expert_buffer_manager.h" #include "framework/eplb/expert_weight_buffer_shm.h" #include "framework/model/model_args.h" From 25b0d5f9d8856d1da5c688f14638c13b1e9d45da Mon Sep 17 00:00:00 2001 From: jindonghe1 Date: Thu, 28 Aug 2025 17:00:49 +0800 Subject: [PATCH 3/4] bugfix: fix coredump issue when both EPLB and schedule overlap are enabled. --- xllm/core/runtime/worker_impl.cpp | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/xllm/core/runtime/worker_impl.cpp b/xllm/core/runtime/worker_impl.cpp index 1222f5b4..21875d58 100644 --- a/xllm/core/runtime/worker_impl.cpp +++ b/xllm/core/runtime/worker_impl.cpp @@ -285,10 +285,13 @@ WorkerImpl::estimate_kv_cache_capacity_async() { void WorkerImpl::update_last_step_output( const std::optional& output) { - if (output.value().sample_output.next_tokens.defined() || FLAGS_enable_eplb) { + if (output.value().sample_output.next_tokens.defined()) { last_step_output_ = std::move(output.value()); last_step_output_valid_ = true; } else { + if(FLAGS_enable_eplb) { + last_step_output_ = std::move(output.value()); + } last_step_output_valid_ = false; } } @@ -332,7 +335,7 @@ void WorkerImpl::prepare_work_before_execute(const ForwardInput& inputs, is_prefill); processed_inputs.input_params.dp_ep_padding_data = dp_ep_padding.build(); if (FLAGS_enable_eplb) { - expert_load_data_.fill_(0); + // expert_load_data_.fill_(0); processed_inputs.input_params.expert_load_data = expert_load_data_; } } @@ -391,7 +394,7 @@ ForwardOutput WorkerImpl::get_last_step_result() { ForwardOutput output; std::unique_lock lock(mtx_); cv_.wait(lock, [this] { return is_recorded_; }); - if (last_step_output_valid_) { + if (last_step_output_valid_ || FLAGS_enable_eplb) { output = last_step_output_; } is_recorded_ = false; From 906d13bd30f55e76f92d8fa161d4543bc2e0ac31 Mon Sep 17 00:00:00 2001 From: jindonghe1 Date: Fri, 29 Aug 2025 12:23:37 +0800 Subject: [PATCH 4/4] feat: support variable number of redundant expert. 
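The number of redundant expert slots per device is now configurable through --redundant_experts_num (default 1) instead of being hard-wired to one, and eplb_update_rate is renamed to eplb_update_interval to reflect that it is a time interval in seconds. Each device then holds n_routed_experts / device_num routed experts plus redundant_experts_num extra slots, and the initial distribution fills the extra slots with copies of the device's last routed expert until the EPLB policy reassigns them based on measured load. The sketch below reproduces that slot layout in isolation; initial_device_experts is a hypothetical helper written only for illustration, with redundant_experts_num standing in for FLAGS_redundant_experts_num.

    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Initial expert slots for one device: routed experts first, then the
    // redundant slots, seeded with copies of the device's last routed expert.
    std::vector<int32_t> initial_device_experts(int32_t device_id,
                                                int32_t experts_num,
                                                int32_t device_num,
                                                int32_t redundant_experts_num) {
      const int32_t routed = experts_num / device_num;       // routed experts per device
      const int32_t slots = routed + redundant_experts_num;  // physical slots per device
      const int32_t base = device_id * routed;
      std::vector<int32_t> experts(slots);
      for (int32_t slot = 0; slot < slots; ++slot) {
        experts[slot] = base + std::min(slot, routed - 1);
      }
      return experts;
    }

    int main() {
      // e.g. 256 routed experts over 16 devices with 2 redundant slots each.
      for (int32_t e : initial_device_experts(/*device_id=*/1, 256, 16, 2)) {
        std::cout << e << ' ';
      }
      std::cout << '\n';  // 16 17 ... 30 31 31 31
      return 0;
    }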
--- docs/en/features/eplb.md | 4 +- docs/zh/features/eplb.md | 6 +- xllm/core/common/global_flags.cpp | 6 +- xllm/core/common/global_flags.h | 4 +- xllm/core/common/options.h | 4 +- xllm/core/common/types.h | 10 +- xllm/core/framework/eplb/eplb_executor.h | 8 + xllm/core/framework/eplb/eplb_manager.cpp | 24 +-- xllm/core/framework/eplb/eplb_manager.h | 20 ++- xllm/core/framework/eplb/eplb_policy.cpp | 10 +- xllm/core/framework/eplb/eplb_policy.h | 11 ++ xllm/core/framework/model/causal_vlm.h | 2 +- .../layers/npu/deepseek_v2_decoder_layer.cpp | 139 +++++++++--------- .../layers/npu/deepseek_v2_decoder_layer.h | 2 + xllm/core/runtime/llm_engine.cpp | 14 +- xllm/core/runtime/llm_engine.h | 1 - xllm/core/runtime/master.cpp | 7 +- xllm/core/runtime/params_utils.cpp | 1 - xllm/core/runtime/worker_impl.cpp | 5 +- xllm/models/qwen3_embedding.h | 6 + xllm/xllm.cpp | 3 +- 21 files changed, 169 insertions(+), 118 deletions(-) diff --git a/docs/en/features/eplb.md b/docs/en/features/eplb.md index e809f058..dbc45d9f 100644 --- a/docs/en/features/eplb.md +++ b/docs/en/features/eplb.md @@ -23,8 +23,8 @@ Simply add the following gflag parameters when launching xLLM: - xLLM provides the gflag parameter `enable_eplb` (default: false). Set to true in the xLLM service startup script to enable dynamic expert load balancing. - `expert_parallel_degree` and `ep_size` are MoE-related parameters. `expert_parallel_degree` should be set to `2`, and `ep_size` must match the actual number of NPU/GPU devices. See [moe_params](./moe_params.md) -- `eplb_update_rate` sets the expert distribution update interval in seconds (default: 1000). +- `eplb_update_interval` sets the expert distribution update interval in seconds (default: 1000). - The expert distribution update uses a layer-by-layer mechanism based on expert load. When the similarity between consecutive loads for a layer is below `eplb_update_threshold`, that layer is updated (default: 1, range: 0-1). ```bash ---enable_eplb=true --expert_parallel_degree=2 --ep_size=16 --eplb_update_rate=2000 --eplb_update_threshold=0.9 \ No newline at end of file +--enable_eplb=true --expert_parallel_degree=2 --ep_size=16 --eplb_update_interval=2000 --eplb_update_threshold=0.9 \ No newline at end of file diff --git a/docs/zh/features/eplb.md b/docs/zh/features/eplb.md index 55cfa56d..184ea885 100644 --- a/docs/zh/features/eplb.md +++ b/docs/zh/features/eplb.md @@ -18,14 +18,14 @@ xLLM eplb功能主要通过以下三个模块实现: - xLLM中提供了gflags参数`enable_eplb`,默认false,如需开启动态专家负载均衡,在xLLM的服务启动脚本中设置为true即可。 - `expert_parallel_degree`与`ep_size`为moe相关参数,`expert_parallel_degree`需要设置为`2`,`ep_size`要与实际NPU/GPU卡个数保持一致。参考 [moe_params](./moe_params.md) -- `eplb_update_rate`为专家分布更新时间间隔,单位为妙,默认值为1000. -- 专家分布更新采用根据专家负载的逐层更新机制,当某一层专家的前后两次的负载相似度小于`eplb_update_threshold`时选择更新该层,默认值为1,取之范围为(0,1)。 +- `eplb_update_interval`为专家分布更新时间间隔,单位为妙,默认值为1000. 
+- 专家分布更新采用根据专家负载的逐层更新机制,当某一层专家的前后两次的负载相似度小于`eplb_update_threshold`时选择更新该层,默认值为0.8,取值范围为(0,1)。 ```bash --enable_eplb=true --expert_parallel_degree=2 --ep_size=16 - --eplb_update_rate=2000 + --eplb_update_interval=2000 --eplb_update_threshold=0.9 ``` diff --git a/xllm/core/common/global_flags.cpp b/xllm/core/common/global_flags.cpp index 5bc72353..e36ff652 100644 --- a/xllm/core/common/global_flags.cpp +++ b/xllm/core/common/global_flags.cpp @@ -118,7 +118,11 @@ DEFINE_string(communication_backend, "hccl", "npu communication backend."); DEFINE_bool(enable_eplb, false, "Whether to use ep load balance."); -DEFINE_int64(eplb_update_rate, 1000, "eplb update rate."); +DEFINE_int32(redundant_experts_num, + 1, + "number of redundant experts per device."); + +DEFINE_int64(eplb_update_interval, 1000, "eplb update interval in seconds."); DEFINE_double(eplb_update_threshold, 0.8, "eplb update threshold."); diff --git a/xllm/core/common/global_flags.h b/xllm/core/common/global_flags.h index 6c588ba5..5a230ec6 100644 --- a/xllm/core/common/global_flags.h +++ b/xllm/core/common/global_flags.h @@ -69,7 +69,9 @@ DECLARE_string(communication_backend); DECLARE_bool(enable_eplb); -DECLARE_int64(eplb_update_rate); +DECLARE_int32(redundant_experts_num); + +DECLARE_int64(eplb_update_interval); DECLARE_double(eplb_update_threshold); diff --git a/xllm/core/common/options.h b/xllm/core/common/options.h index cd73dd37..a4df96d8 100644 --- a/xllm/core/common/options.h +++ b/xllm/core/common/options.h @@ -72,7 +72,9 @@ class Options { PROPERTY(std::optional, enable_eplb); - PROPERTY(std::optional, eplb_update_rate); + PROPERTY(std::optional, redundant_experts_num); + + PROPERTY(std::optional, eplb_update_interval); PROPERTY(std::optional, eplb_update_threshold); diff --git a/xllm/core/common/types.h b/xllm/core/common/types.h index 3f6f218b..8fc55f15 100644 --- a/xllm/core/common/types.h +++ b/xllm/core/common/types.h @@ -251,10 +251,18 @@ struct JsonTool { JsonTool(const std::string& tool_type, const JsonFunction& func) : type(tool_type), function(func) {} }; - +// Information required for expert weight updates struct EplbInfo { // Target layer ID for new expert weight pre-loading (-1 = no pending load) // Values >=0 indicate the layer ID that should start loading new expert // weights int32_t prepare_layer_id = -1; // Expert IDs requiring updates, ordered by device shard assignment // Contains per-device expert indices for distributed weight updates std::vector expert_ids; // Layer ID ready for expert weight activation (-1 = no pending update) // Values >=0 indicate the layer ID whose pre-loaded weights are ready for // deployment int32_t update_layer_id = -1; }; diff --git a/xllm/core/framework/eplb/eplb_executor.h b/xllm/core/framework/eplb/eplb_executor.h index cbf37f19..59817b22 100644 --- a/xllm/core/framework/eplb/eplb_executor.h +++ b/xllm/core/framework/eplb/eplb_executor.h @@ -18,8 +18,16 @@ class EplbExecutor final { EplbExecutor(CausalLM* model); virtual ~EplbExecutor(); + + // Reset the ready layer ID marker to -1 (no layer ready) void reset_ready_layer_id(); + + // Get the currently ready layer ID + // Returns the layer ID with prepared weights (-1 if none) int32_t get_ready_layer_id() const; + + // Execute EPLB operation based on coordination info + // eplb_info contains layer preparation/activation instructions void eplb_execute(const EplbInfo& eplb_info); private: diff --git a/xllm/core/framework/eplb/eplb_manager.cpp b/xllm/core/framework/eplb/eplb_manager.cpp index 2c02bfde..ac4474e7 100644 ---
a/xllm/core/framework/eplb/eplb_manager.cpp +++ b/xllm/core/framework/eplb/eplb_manager.cpp @@ -20,17 +20,18 @@ namespace xllm { using namespace std::chrono_literals; -EplbManager::EplbManager(EplbPolicy* eplb_policy, - int32_t layer_num, +EplbManager::EplbManager(int32_t layer_num, int32_t device_num, int32_t experts_num) - : eplb_policy_(eplb_policy), - layer_num_(layer_num), + : layer_num_(layer_num), device_num_(device_num), experts_num_(experts_num), - device_experts_num_((experts_num + device_num) / device_num) { + device_experts_num_(experts_num / device_num + + FLAGS_redundant_experts_num) { // Initialize tensors with mutex protection { + eplb_policy_ = std::make_unique( + device_experts_num_, device_num_, layer_num_); std::lock_guard lock(state_.mtx); state_.expert_load = torch::zeros({layer_num_, experts_num_}, torch::kInt64); @@ -39,11 +40,13 @@ EplbManager::EplbManager(EplbPolicy* eplb_policy, {layer_num_, device_num_, device_experts_num_}, torch::kInt32); for (int32_t layer = 0; layer < layer_num_; ++layer) { for (int32_t device = 0; device < device_num_; ++device) { - int32_t base = device * (device_experts_num_ - 1); + int32_t device_route_experts_num = + device_experts_num_ - FLAGS_redundant_experts_num; + int32_t base = device * device_route_experts_num; for (int32_t expert = 0; expert < device_experts_num_; ++expert) { int32_t value = base + expert; - if (expert == device_experts_num_ - 1) { - --value; + if (expert >= device_route_experts_num) { + value = base + device_route_experts_num - 1; } state_.expert_distribution[layer][device][expert] = value; } @@ -105,7 +108,6 @@ void EplbManager::aggregate_multi_layer_expert_loads( layer_ids.emplace_back(ids.flatten().to(torch::kInt64)); layer_loads.emplace_back(loads.flatten().to(torch::kInt64)); } - torch::Tensor all_ids = torch::cat(layer_ids); torch::Tensor all_loads = torch::cat(layer_loads); expert_load[layer].scatter_add_(0, all_ids, all_loads); @@ -125,14 +127,12 @@ void EplbManager::rebalance_experts_loop() { if (state_.stop) return; while (!state_.expert_load_queue.empty()) { - // expert_load_batch.emplace_back(state_.expert_load_queue.front()); - // state_.expert_load_queue.pop(); aggregate_multi_layer_expert_loads(state_.expert_load, state_.expert_distribution, state_.expert_load_queue.front()); state_.expert_load_queue.pop(); int64_t current_time = absl::ToUnixSeconds(absl::Now()); - if (current_time - latest_record_time >= FLAGS_eplb_update_rate) { + if (current_time - latest_record_time >= FLAGS_eplb_update_interval) { latest_record_time = current_time; auto result = eplb_policy_->rebalance_experts(state_.expert_load); state_.expert_distribution = result.first; diff --git a/xllm/core/framework/eplb/eplb_manager.h b/xllm/core/framework/eplb/eplb_manager.h index 0b251c4b..31da3e72 100644 --- a/xllm/core/framework/eplb/eplb_manager.h +++ b/xllm/core/framework/eplb/eplb_manager.h @@ -13,14 +13,24 @@ namespace xllm { class EplbManager { public: - EplbManager(EplbPolicy* eplb_policy, - int32_t layer_num, - int32_t device_num, - int32_t experts_num); + // Initialize with model dimensions: + // - layer_num: Total layers in the model + // - device_num: Parallel devices in cluster + // - experts_num: Experts per model layer + EplbManager(int32_t layer_num, int32_t device_num, int32_t experts_num); + ~EplbManager(); + // Feed new expert workload data for load balancing + // Input tensors should have shape [layer_num, experts_num] void update_expert_load(const std::vector expert_load); + + // Fetch current coordination 
instructions for expert updates + // Returns struct containing layer preparation/activation commands EplbInfo get_eplb_info(); + + // Mark specified layers as prepared (call after async loading completes) + // expert_layer_ids: Prepared layer IDs per device void set_prepared_layer_ids(const std::vector& expert_layer_ids); private: @@ -49,7 +59,7 @@ class EplbManager { }; // Components - EplbPolicy* eplb_policy_; + std::unique_ptr eplb_policy_ = nullptr; ThreadSafeData state_; // Constants diff --git a/xllm/core/framework/eplb/eplb_policy.cpp b/xllm/core/framework/eplb/eplb_policy.cpp index 3323f3b0..051be2b6 100644 --- a/xllm/core/framework/eplb/eplb_policy.cpp +++ b/xllm/core/framework/eplb/eplb_policy.cpp @@ -15,7 +15,9 @@ EplbPolicy::EplbPolicy(int32_t device_experts_num, device_num_(device_num), layer_num_(layer_num) { old_expert_load_ = - torch::zeros({layer_num_, device_experts_num * device_num - device_num}, + torch::zeros({layer_num_, + device_experts_num * device_num - + device_num * FLAGS_redundant_experts_num}, torch::kInt64); expert_distribution_ = torch::full( {layer_num_, device_num_, device_experts_num_}, -1, torch::kInt32); @@ -32,9 +34,7 @@ std::pair> EplbPolicy::rebalance_experts( auto prev_max_val = torch::max(prev_load).item() + 1e-6f; current_load = (current_load / current_max_val).unsqueeze(0); - ; prev_load = (prev_load / prev_max_val).unsqueeze(0); - ; auto cos_sim = torch::nn::functional::cosine_similarity( @@ -65,8 +65,8 @@ torch::Tensor EplbPolicy::compute_balanced_pack( const int64_t num_experts = expert_loads.size(0); // Generate Redundant Experts - auto [updated_weights, redundancy_map] = - update_origin_weights(expert_loads, device_num_); + auto [updated_weights, redundancy_map] = update_origin_weights( + expert_loads, device_num_ * FLAGS_redundant_experts_num); // Initialize Allocation Matrix auto options = torch::TensorOptions().dtype(torch::kInt64); diff --git a/xllm/core/framework/eplb/eplb_policy.h b/xllm/core/framework/eplb/eplb_policy.h index ec32d5e2..446ac8d6 100644 --- a/xllm/core/framework/eplb/eplb_policy.h +++ b/xllm/core/framework/eplb/eplb_policy.h @@ -12,8 +12,19 @@ namespace xllm { class EplbPolicy { public: + // Initialize policy engine parameters: + // - device_experts_num: Experts per device (including redundancy) + // - device_num: Total parallel devices + // - layer_num: Model layers to manage EplbPolicy(int32_t device_experts_num, int32_t device_num, int32_t layer_num); + virtual ~EplbPolicy() {}; + + // Recalculate expert distribution based on latest workload + // Input: expert_load - Workload tensor [total_experts] + // Returns: pair + // expert_distribution: [layers x devices x local_experts] + // update_flags: Boolean array marking layers needing update std::pair> rebalance_experts( torch::Tensor expert_load); diff --git a/xllm/core/framework/model/causal_vlm.h b/xllm/core/framework/model/causal_vlm.h index 1af4dea8..5403850e 100644 --- a/xllm/core/framework/model/causal_vlm.h +++ b/xllm/core/framework/model/causal_vlm.h @@ -56,7 +56,7 @@ class CausalVLMImpl : public CausalVLM { void load_model(std::unique_ptr loader) override { model_->load_model(std::move(loader)); } - + virtual void prepare_expert_weight(int32_t layer_id, const std::vector& expert_ids) { return; diff --git a/xllm/core/layers/npu/deepseek_v2_decoder_layer.cpp b/xllm/core/layers/npu/deepseek_v2_decoder_layer.cpp index 21297b5c..da59ff1e 100644 --- a/xllm/core/layers/npu/deepseek_v2_decoder_layer.cpp +++ b/xllm/core/layers/npu/deepseek_v2_decoder_layer.cpp @@ -273,8 
+273,9 @@ DeepseekV2DecoderImpl::DeepseekV2DecoderImpl(const Context& context, CHECK_EQ(parallel_args.world_size(), ep_size_ * ep_local_tp_size_); ep_local_tp_rank_ = parallel_args.rank() % ep_local_tp_size_; num_experts_per_partition_ = model_args.n_routed_experts() / ep_size_; + redundant_experts_num_ = FLAGS_redundant_experts_num; if (FLAGS_enable_eplb) { - num_experts_per_partition_++; + num_experts_per_partition_ += redundant_experts_num_; } ep_rank_ = parallel_args.rank() / ep_local_tp_size_; start_expert_id_ = ep_rank_ * num_experts_per_partition_; @@ -328,15 +329,17 @@ void DeepseekV2DecoderImpl::initialize_tensors( void DeepseekV2DecoderImpl::initialize_device_expert_list( int num_device, - int num_device_route_expert) { + int num_device_expert) { + int32_t num_device_route_expert = num_device_expert; if (FLAGS_enable_eplb) { - --num_device_route_expert; + num_device_route_expert = num_device_expert - redundant_experts_num_; } for (int i = 0; i < num_device * num_device_route_expert; ++i) { - std::vector subvec; device_expert_list_.emplace_back(i); if (FLAGS_enable_eplb && (i + 1) % num_device_route_expert == 0) { - device_expert_list_.emplace_back(i); + for (int redundant_expert = 0; redundant_expert < redundant_experts_num_; + ++redundant_expert) + device_expert_list_.emplace_back(i); } } } @@ -527,7 +530,7 @@ void DeepseekV2DecoderImpl::initialize_mlp_parameters( if (FLAGS_enable_eplb) { param.enableExpertCumSumOutput = param.isPrefill ? false : true; param.enableEPWB = true; - param.numOfRedundantExpert = ep_size_; + param.numOfRedundantExpert = ep_size_ * redundant_experts_num_; } } if (layer_id_ < param.firstKDenseReplace) { @@ -681,37 +684,19 @@ void DeepseekV2DecoderImpl::process_expert_weights( const StateDict& state_dict, const std::string& name, const torch::Tensor& tensor) { + // Step 1: Early checks and basic info extraction int expert_index = extract_expert_index(name); const std::string suffix = extract_endswith(name); const int index = get_mapped_index(suffix, WEIGHT_MAPPING_W8A8); if (index == -1) { return; } + const bool is_sharded = WEIGHT_SHARD_W8A8.count(index); - if (FLAGS_enable_eplb && - (rank_ % localWorldSize_ == expert_index % localWorldSize_)) { - std::lock_guard lock(experts_mutex_); - torch::Tensor tmp_tensor_shm = - is_sharded ? 
get_sharded_tensor(state_dict, - name, - WEIGHT_SHARD_W8A8.at(index), - ep_local_tp_rank_, - ep_local_tp_size_) - : tensor; - std::string shm_key = get_expert_shm_key(layer_id_, expert_index, suffix); - if (!decode_param_.isBF16) { - if (absl::EndsWith(name, "_offset")) { - tmp_tensor_shm = tmp_tensor_shm.to(torch::kFloat16); - } else if (absl::EndsWith(name, "_scale")) { - tmp_tensor_shm = tmp_tensor_shm.to(torch::kFloat32); - } - } - shared_buffer_->add_tensor(expert_index, - layer_id_ - first_k_dense_replace_, - shm_key, - tmp_tensor_shm.contiguous()); - // all_experts_weights_buffer_[shm_key].emplace_back(tmp_tensor.clone()); - } + const bool needs_eplb = FLAGS_enable_eplb && (rank_ % localWorldSize_ == + expert_index % localWorldSize_); + + // Step 2: Check if expert is in partition const int start_idx = ep_rank_ * num_experts_per_partition_; const int end_idx = (ep_rank_ + 1) * num_experts_per_partition_; const int safe_end = @@ -720,29 +705,61 @@ void DeepseekV2DecoderImpl::process_expert_weights( auto it = std::find(device_expert_list_.begin() + start_idx, device_expert_list_.begin() + safe_end, expert_index); - if (it == device_expert_list_.begin() + safe_end) { + const bool in_partition = it != device_expert_list_.begin() + safe_end; + + // Early return if neither EPLB nor partition needs this expert + if (!needs_eplb && !in_partition) { return; } - std::vector matches_pos; - for (auto iter = device_expert_list_.begin() + start_idx; - iter != device_expert_list_.begin() + safe_end; - ++iter) { - if (*iter == expert_index) { - matches_pos.emplace_back( - std::distance(device_expert_list_.begin(), iter) - start_idx); + + // Step 3: Process tensor + torch::Tensor processed_tensor; + { + std::lock_guard lock(experts_mutex_); + processed_tensor = is_sharded + ? get_sharded_tensor(state_dict, + name, + WEIGHT_SHARD_W8A8.at(index), + ep_local_tp_rank_, + ep_local_tp_size_) + : tensor; + + if (!decode_param_.isBF16) { + if (absl::EndsWith(name, "_offset")) { + processed_tensor = processed_tensor.to(torch::kFloat16); + } else if (absl::EndsWith(name, "_scale")) { + processed_tensor = processed_tensor.to(torch::kFloat32); + } } } - std::lock_guard lock(experts_mutex_); - torch::Tensor tmp_tensor = - is_sharded ? 
get_sharded_tensor(state_dict, - name, - WEIGHT_SHARD_W8A8.at(index), - ep_local_tp_rank_, - ep_local_tp_size_) - : tensor; - for (auto pos : matches_pos) { - experts_weights_[suffix][pos] = tmp_tensor.clone(); + // Step 4: Handle EPLB case + if (needs_eplb) { + std::lock_guard lock(experts_mutex_); + std::string shm_key = get_expert_shm_key(layer_id_, expert_index, suffix); + shared_buffer_->add_tensor(expert_index, + layer_id_ - first_k_dense_replace_, + shm_key, + processed_tensor.contiguous()); + } + + // Step 5: Handle partition case + if (in_partition) { + std::vector matches_pos; + for (auto iter = it; iter != device_expert_list_.begin() + safe_end; + ++iter) { + if (*iter == expert_index) { + matches_pos.emplace_back( + std::distance(device_expert_list_.begin(), iter) - start_idx); + } + } + + if (!matches_pos.empty()) { + std::lock_guard lock(experts_mutex_); + for (auto pos : matches_pos) { + experts_weights_[suffix][pos] = processed_tensor.clone(); + } + } } } @@ -1221,8 +1238,6 @@ torch::Tensor DeepseekV2DecoderImpl::merge_experts_weights( std::vector& experts_up, at::Device device, bool transpose) { - auto merge_experts_weights_sart = std::chrono::high_resolution_clock::now(); - for (size_t i = 0; i < experts_up.size(); ++i) { experts_gate[i] = torch::cat({experts_gate[i], experts_up[i]}, 0); } @@ -1248,13 +1263,9 @@ void DeepseekV2DecoderImpl::merge_and_copy_gate_up_weights( const int64_t up_dim = experts_up[0].size(0); const int64_t hidden_dim = experts_gate[0].size(1); - auto prepare_experts_weights_start = - std::chrono::high_resolution_clock::now(); target_buffer = at_npu::native::npu_format_cast(target_buffer.contiguous(), 2) .reshape({num_experts, gate_dim + up_dim, hidden_dim}); - prepare_experts_weights_start = std::chrono::high_resolution_clock::now(); - for (int64_t index = 0; index < num_experts; ++index) { target_buffer[index].slice(0, 0, gate_dim).copy_(experts_gate[index]); @@ -1281,9 +1292,6 @@ void DeepseekV2DecoderImpl::merge_and_copy_down_weights( void DeepseekV2DecoderImpl::prepare_expert_weight( const std::vector& expert_list) { - auto prepare_experts_weights_start = - std::chrono::high_resolution_clock::now(); - expert_routing_map_buffer_ = build_expert_routing_map(expert_list); auto& expert_buffer = ExpertBuffer::Instance(); @@ -1350,11 +1358,6 @@ void DeepseekV2DecoderImpl::prepare_expert_weight( expert_buffer.gateup_weight = at_npu::native::npu_format_cast(expert_buffer.gateup_weight, 29); - auto prepare_experts_weights_end = std::chrono::high_resolution_clock::now(); - auto prepare__experts_weights_duration = - std::chrono::duration_cast( - prepare_experts_weights_end - prepare_experts_weights_start) - .count(); } torch::Tensor DeepseekV2DecoderImpl::build_expert_routing_map( @@ -1366,22 +1369,20 @@ torch::Tensor DeepseekV2DecoderImpl::build_expert_routing_map( expert_routing_map[v].emplace_back(i); } + std::vector keys; + std::vector values; for (auto& [key, indices] : expert_routing_map) { int num_of_duplications = indices.size(); int selected_index = ep_rank_ % num_of_duplications; indices = {indices[selected_index]}; + + keys.emplace_back(key); + values.emplace_back(static_cast(indices[0])); } int64_t map_size = expert_routing_map.size(); auto options = torch::TensorOptions().dtype(torch::kInt32); auto input = torch::zeros({map_size}, options); - std::vector keys; - std::vector values; - - for (const auto& [k, v] : expert_routing_map) { - keys.emplace_back(k); - values.emplace_back(static_cast(v[0])); - } auto index_tensor = torch::tensor(keys, 
torch::kInt64); auto value_tensor = torch::tensor(values, torch::kInt32); diff --git a/xllm/core/layers/npu/deepseek_v2_decoder_layer.h b/xllm/core/layers/npu/deepseek_v2_decoder_layer.h index c9d3eefc..c95a3dd3 100644 --- a/xllm/core/layers/npu/deepseek_v2_decoder_layer.h +++ b/xllm/core/layers/npu/deepseek_v2_decoder_layer.h @@ -301,6 +301,7 @@ class DeepseekV2DecoderImpl : public torch::nn::Module, public ATBBase { int32_t start_expert_id_; int32_t end_expert_id_; int32_t ep_rank_; + int32_t redundant_experts_num_; int32_t dp_size_; int32_t dp_local_tp_size_; @@ -309,6 +310,7 @@ class DeepseekV2DecoderImpl : public torch::nn::Module, public ATBBase { float sm_scale_; int32_t num_speculative_tokens_ = 0; + atb_speed::deepseekV2::DecoderLayerParam prefill_param_; atb_speed::deepseekV2::DecoderLayerParam decode_param_; diff --git a/xllm/core/runtime/llm_engine.cpp b/xllm/core/runtime/llm_engine.cpp index 39877fa6..9037f131 100644 --- a/xllm/core/runtime/llm_engine.cpp +++ b/xllm/core/runtime/llm_engine.cpp @@ -125,15 +125,8 @@ bool LLMEngine::init_model() { if (FLAGS_enable_eplb) { int32_t num_layers = args_.n_layers() - args_.first_k_dense_replace(); int32_t num_experts = args_.n_routed_experts(); - expert_load_data_ = - torch::zeros({num_layers, num_experts + worker_clients_.size()}) - .to(torch::kInt64); - eplb_policy_ = - std::make_unique(num_experts / worker_clients_.size() + 1, - worker_clients_.size(), - num_layers); eplb_manager_ = std::make_unique( - eplb_policy_.get(), num_layers, worker_clients_.size(), num_experts); + num_layers, worker_clients_.size(), num_experts); } // key + value for all layers @@ -643,9 +636,10 @@ void LLMEngine::process_eplb_data( int32_t worker_clients_num) { int32_t num_layers = args_.n_layers() - args_.first_k_dense_replace(); int32_t num_device_experts = - args_.n_routed_experts() / worker_clients_.size() + 1; + args_.n_routed_experts() / worker_clients_.size() + + FLAGS_redundant_experts_num; std::vector tensors; - std::vector layer_ids(num_device_experts - 1, -1); + std::vector layer_ids(results.size(), -1); tensors.reserve(worker_clients_.size()); for (size_t worker_rank = 0; worker_rank < results.size(); ++worker_rank) { auto result = results[worker_rank].value(); diff --git a/xllm/core/runtime/llm_engine.h b/xllm/core/runtime/llm_engine.h index ede2cd5b..f5f9c37b 100644 --- a/xllm/core/runtime/llm_engine.h +++ b/xllm/core/runtime/llm_engine.h @@ -127,7 +127,6 @@ class LLMEngine : public Engine { std::shared_ptr dist_manager_ = nullptr; std::unique_ptr eplb_manager_ = nullptr; - std::unique_ptr eplb_policy_ = nullptr; void process_eplb_data( const std::vector>>& results, int32_t worker_clients_num); diff --git a/xllm/core/runtime/master.cpp b/xllm/core/runtime/master.cpp index 94860b4d..1b705688 100644 --- a/xllm/core/runtime/master.cpp +++ b/xllm/core/runtime/master.cpp @@ -68,8 +68,11 @@ Master::Master(const Options& options, EngineType type) : options_(options) { if (options.enable_eplb().has_value()) { FLAGS_enable_eplb = options.enable_eplb().value(); } - if (options.eplb_update_rate().has_value()) { - FLAGS_eplb_update_rate = options.eplb_update_rate().value(); + if (options.redundant_experts_num().has_value()) { + FLAGS_redundant_experts_num = options.redundant_experts_num().value(); + } + if (options.eplb_update_interval().has_value()) { + FLAGS_eplb_update_interval = options.eplb_update_interval().value(); } if (options.eplb_update_threshold().has_value()) { FLAGS_eplb_update_threshold = options.eplb_update_threshold().value(); diff --git 
a/xllm/core/runtime/params_utils.cpp b/xllm/core/runtime/params_utils.cpp index 08bc9462..c6b97c09 100644 --- a/xllm/core/runtime/params_utils.cpp +++ b/xllm/core/runtime/params_utils.cpp @@ -270,7 +270,6 @@ void proto_to_forward_input(const proto::ForwardInput* pb_forward_input, std::vector(pb_forward_input->eplb_info().expert_ids().begin(), pb_forward_input->eplb_info().expert_ids().end()); eplb_info.update_layer_id = pb_forward_input->eplb_info().update_layer_id(); - forward_inputs.eplb_info = eplb_info; COUNTER_ADD(proto_latency_seconds_proto2i, timer.elapsed_seconds()); } diff --git a/xllm/core/runtime/worker_impl.cpp b/xllm/core/runtime/worker_impl.cpp index 21875d58..d5136670 100644 --- a/xllm/core/runtime/worker_impl.cpp +++ b/xllm/core/runtime/worker_impl.cpp @@ -289,7 +289,7 @@ void WorkerImpl::update_last_step_output( last_step_output_ = std::move(output.value()); last_step_output_valid_ = true; } else { - if(FLAGS_enable_eplb) { + if (FLAGS_enable_eplb) { last_step_output_ = std::move(output.value()); } last_step_output_valid_ = false; @@ -465,7 +465,8 @@ bool WorkerImpl::init_model(const std::string& model_weights_path) { if (FLAGS_enable_eplb) { int32_t num_layers = args.n_layers() - args.first_k_dense_replace(); int32_t num_device_experts = - args.n_routed_experts() / context_.get_parallel_args().world_size() + 1; + args.n_routed_experts() / context_.get_parallel_args().world_size() + + FLAGS_redundant_experts_num; expert_load_data_ = torch::zeros({num_layers, num_device_experts}) .to(torch::kInt64) .to(device_) diff --git a/xllm/models/qwen3_embedding.h b/xllm/models/qwen3_embedding.h index eda664d4..748b488b 100644 --- a/xllm/models/qwen3_embedding.h +++ b/xllm/models/qwen3_embedding.h @@ -70,6 +70,12 @@ class EmbeddingLMImpl : public EmbeddingLM { return model_->options(); } + virtual void prepare_expert_weight(int32_t layer_id, + const std::vector& expert_ids) { + return; + } + virtual void update_expert_weight(int32_t layer_id) { return; } + // Delegate head/embedding accessors to underlying model implementation. hf::LlmHead get_lm_head() override { return model_->get_lm_head(); } void set_lm_head(hf::LlmHead& head) override { model_->set_lm_head(head); } diff --git a/xllm/xllm.cpp b/xllm/xllm.cpp index bcf455cd..204ae2d1 100644 --- a/xllm/xllm.cpp +++ b/xllm/xllm.cpp @@ -114,7 +114,8 @@ int run() { .num_handling_threads(FLAGS_num_handling_threads) .communication_backend(FLAGS_communication_backend) .enable_eplb(FLAGS_enable_eplb) - .eplb_update_rate(FLAGS_eplb_update_rate) + .redundant_experts_num(FLAGS_redundant_experts_num) + .eplb_update_interval(FLAGS_eplb_update_interval) .eplb_update_threshold(FLAGS_eplb_update_threshold) .rank_tablefile(FLAGS_rank_tablefile) .expert_parallel_degree(FLAGS_expert_parallel_degree)