4 changes: 2 additions & 2 deletions docs/en/features/eplb.md
@@ -23,8 +23,8 @@ Simply add the following gflag parameters when launching xLLM:

- xLLM provides the gflag parameter `enable_eplb` (default: false). Set to true in the xLLM service startup script to enable dynamic expert load balancing.
- `expert_parallel_degree` and `ep_size` are MoE-related parameters. `expert_parallel_degree` should be set to `2`, and `ep_size` must match the actual number of NPU/GPU devices. See [moe_params](./moe_params.md)
- `eplb_update_rate` sets the expert distribution update interval in seconds (default: 1000).
- `eplb_update_interval` sets the expert distribution update interval in seconds (default: 1000).
- The expert distribution update uses a layer-by-layer mechanism based on expert load. When the similarity between consecutive loads for a layer is below `eplb_update_threshold`, that layer is updated (default: 1, range: 0-1).

```bash
--enable_eplb=true --expert_parallel_degree=2 --ep_size=16 --eplb_update_rate=2000 --eplb_update_threshold=0.9
--enable_eplb=true --expert_parallel_degree=2 --ep_size=16 --eplb_update_interval=2000 --eplb_update_threshold=0.9
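As a complement to the flag description above, here is a minimal sketch of the per-layer update rule. It is not the actual `eplb_policy` implementation: it assumes cosine similarity over per-expert load counts, and the function name `should_update_layer` is illustrative.

```cpp
#include <cmath>
#include <cstdint>
#include <vector>

// Hypothetical sketch: refresh a layer's expert placement only when the
// similarity between its previous and current expert-load vectors falls
// below --eplb_update_threshold. Assumes one entry per expert in each vector.
bool should_update_layer(const std::vector<int64_t>& prev_load,
                         const std::vector<int64_t>& curr_load,
                         double eplb_update_threshold) {
  double dot = 0.0, prev_sq = 0.0, curr_sq = 0.0;
  for (size_t i = 0; i < prev_load.size(); ++i) {
    dot += static_cast<double>(prev_load[i]) * curr_load[i];
    prev_sq += static_cast<double>(prev_load[i]) * prev_load[i];
    curr_sq += static_cast<double>(curr_load[i]) * curr_load[i];
  }
  if (prev_sq == 0.0 || curr_sq == 0.0) return false;  // no load observed yet
  const double similarity = dot / (std::sqrt(prev_sq) * std::sqrt(curr_sq));
  return similarity < eplb_update_threshold;  // low similarity => load shifted
}
```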
6 changes: 3 additions & 3 deletions docs/zh/features/eplb.md
@@ -18,14 +18,14 @@ The xLLM eplb feature is implemented mainly through the following three modules:

- xLLM provides the gflags parameter `enable_eplb` (default: false). To enable dynamic expert load balancing, set it to true in the xLLM service startup script.
- `expert_parallel_degree` and `ep_size` are MoE-related parameters. `expert_parallel_degree` should be set to `2`, and `ep_size` must match the actual number of NPU/GPU devices. See [moe_params](./moe_params.md)
- `eplb_update_rate` is the expert distribution update interval, in seconds (default: 1000).
- The expert distribution update uses a layer-by-layer mechanism based on expert load: a layer is updated when the similarity between its two consecutive load measurements falls below `eplb_update_threshold` (default: 1, value range (0,1)).
- `eplb_update_interval` is the expert distribution update interval, in seconds (default: 1000).
- The expert distribution update uses a layer-by-layer mechanism based on expert load: a layer is updated when the similarity between its two consecutive load measurements falls below `eplb_update_threshold` (default: 1, value range (0,1)).

```bash
--enable_eplb=true
--expert_parallel_degree=2
--ep_size=16
--eplb_update_rate=2000
--eplb_update_interval=2000
--eplb_update_threshold=0.9
```

10 changes: 10 additions & 0 deletions xllm/core/common/global_flags.cpp
@@ -116,6 +116,16 @@ DEFINE_double(prefill_scheduling_memory_usage_threshold,

DEFINE_string(communication_backend, "hccl", "npu communication backend.");

DEFINE_bool(enable_eplb, false, "Whether to use ep load balance.");

DEFINE_int32(redundant_experts_num,
1,
"num of redundant experts on per device.");

DEFINE_int64(eplb_update_interval, 1000, "eplb update interval.");

DEFINE_double(eplb_update_threshold, 0.8, "eplb update threshold.");

DEFINE_string(rank_tablefile, "", "atb hccl rank table file.");

DEFINE_int32(expert_parallel_degree, 0, "ep degree");
8 changes: 8 additions & 0 deletions xllm/core/common/global_flags.h
@@ -67,6 +67,14 @@ DECLARE_int32(num_response_handling_threads);

DECLARE_string(communication_backend);

DECLARE_bool(enable_eplb);

DECLARE_int32(redundant_experts_num);

DECLARE_int64(eplb_update_interval);

DECLARE_double(eplb_update_threshold);

DECLARE_string(rank_tablefile);

DECLARE_bool(enable_mla);
8 changes: 8 additions & 0 deletions xllm/core/common/options.h
@@ -70,6 +70,14 @@ class Options {
// thread num to handle requests
PROPERTY(size_t, num_handling_threads) = 4;

PROPERTY(std::optional<bool>, enable_eplb);

PROPERTY(std::optional<int32_t>, redundant_experts_num);

PROPERTY(std::optional<int64_t>, eplb_update_interval);

PROPERTY(std::optional<double>, eplb_update_threshold);

PROPERTY(std::optional<std::string>, communication_backend);

PROPERTY(std::optional<std::string>, rank_tablefile);
14 changes: 14 additions & 0 deletions xllm/core/common/types.h
@@ -251,5 +251,19 @@ struct JsonTool {
JsonTool(const std::string& tool_type, const JsonFunction& func)
: type(tool_type), function(func) {}
};
// Information required for an expert distribution update
struct EplbInfo {
Member: add comments for struct / class and its public fields / methods please.

Collaborator (Author): Done! Added detailed comments.
// Target layer ID for new expert weight pre-loading (-1 = no pending load)
// Values >=0 indicate the layer ID that should start loading new expert
// weights
int32_t prepare_layer_id = -1;
// Expert IDs requiring updates, ordered by device shard assignment
// Contains per-device expert indices for distributed weight updates
std::vector<int32_t> expert_ids;
// Layer ID ready for expert weight activation (-1 = no pending update)
// Values >=0 indicate the layer ID whose pre-loaded weights are ready for
// deployment
int32_t update_layer_id = -1;
};

} // namespace xllm
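As a brief illustration of how the new struct is meant to be read, the sketch below fills an `EplbInfo` with hypothetical values; the include path, layer IDs, and expert IDs are placeholders, not taken from this PR.

```cpp
#include "core/common/types.h"  // include path assumed for illustration

// Hypothetical example: ask layer 3 to pre-load weights for four experts,
// while activating the weights already prepared for layer 2.
xllm::EplbInfo make_example_eplb_info() {
  xllm::EplbInfo info;
  info.prepare_layer_id = 3;        // start loading new expert weights for layer 3
  info.expert_ids = {0, 5, 9, 12};  // per-device expert indices (placeholder values)
  info.update_layer_id = 2;         // layer 2's pre-loaded weights are ready to deploy
  return info;
}
```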
40 changes: 29 additions & 11 deletions xllm/core/distributed_runtime/worker_service.cpp
@@ -31,6 +31,7 @@ limitations under the License.
#include <torch_npu/torch_npu.h>
#endif

#include "common/global_flags.h"
#include "common/metrics.h"
#include "framework/request/sequence.h"
#include "framework/sampling/sampling_params.h"
@@ -331,6 +332,8 @@ void WorkerService::ExecuteModel(::google::protobuf::RpcController* controller,
torch::Tensor top_tokens;
torch::Tensor top_logprobs;
torch::Tensor embeddings;
torch::Tensor expert_load_data;
int32_t prepared_layer_id = -1;

// execute model
auto future = worker_->step_async(forward_inputs);
@@ -341,6 +344,10 @@ void WorkerService::ExecuteModel(::google::protobuf::RpcController* controller,
if (forward_outputs) {
DCHECK(forward_outputs.has_value()) << "Failed to execute model";
const auto& sample_output = forward_outputs.value().sample_output;
expert_load_data = safe_to(
forward_outputs.value().expert_load_data, torch::kCPU, true);
prepared_layer_id = forward_outputs.value().prepared_layer_id;

{
#if defined(USE_NPU)
c10::StreamGuard streamGuard(
@@ -376,22 +383,28 @@
#endif
}
}
} else if (worker_->is_driver()) {
// construct fake output tensor
auto options =
torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU);
int32_t prefill_seq_len =
static_cast<int32_t>(pb_forward_input->prefill_seq_len());
next_tokens = torch::arange(
-1, -1 * (num_sequences - prefill_seq_len + 1), -1, options);
std::move(future).deferValue([](auto&&) {});
} else {
if (worker_->is_driver()) {
// construct fake output tensor
auto options =
torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU);
int32_t prefill_seq_len =
static_cast<int32_t>(pb_forward_input->prefill_seq_len());
next_tokens = torch::arange(
-1, -1 * (num_sequences - prefill_seq_len + 1), -1, options);
std::move(future).deferValue([](auto&&) {});
}
expert_load_data =
torch::zeros({1, 1}).to(torch::kInt64).contiguous();
}

forward_output_to_proto(next_tokens,
logprobs,
top_tokens,
top_logprobs,
embeddings,
expert_load_data,
prepared_layer_id,
pb_forward_output);
COUNTER_ADD(worker_service_latency_seconds, timer.elapsed_seconds());
});
@@ -415,6 +428,9 @@ void WorkerService::GetLastStepResult(
auto forward_outputs = std::move(future).get();
if (forward_outputs) {
const auto& sample_output = forward_outputs.value().sample_output;
const auto& expert_load_data = safe_to(
forward_outputs.value().expert_load_data, torch::kCPU, true);
int32_t prepared_layer_id = forward_outputs.value().prepared_layer_id;
#if defined(USE_NPU)
c10::StreamGuard streamGuard(
npu_stream_helper_->D2H_memcpy_stream.unwrap());
@@ -429,8 +445,8 @@
// [num_seq]
const auto& next_tokens =
safe_to(sample_output.next_tokens, torch::kCPU, true);
if (next_tokens.defined()) {
// [num_seq]
if (next_tokens.defined() || FLAGS_enable_eplb) {
// [num_seq] FloatTensor
const auto& logprobs =
safe_to(sample_output.logprobs, torch::kCPU, true);
// [num_seq, topk]
Expand All @@ -451,6 +467,8 @@ void WorkerService::GetLastStepResult(
top_tokens,
top_logprobs,
embeddings,
expert_load_data,
prepared_layer_id,
pb_forward_output);
}
}
1 change: 1 addition & 0 deletions xllm/core/framework/CMakeLists.txt
@@ -13,6 +13,7 @@ add_subdirectory(request)
add_subdirectory(sampling)
add_subdirectory(state_dict)
add_subdirectory(tokenizer)
add_subdirectory(eplb)

cc_library(
NAME
50 changes: 50 additions & 0 deletions xllm/core/framework/eplb/CMakeLists.txt
@@ -0,0 +1,50 @@
include(cc_binary)
include(cc_library)
include(cc_test)

include_directories(
${CMAKE_SOURCE_DIR}/xllm/core/kernels/ascend
${CMAKE_SOURCE_DIR}/xllm/core/kernels/ascend/core/include
)

cc_library(
NAME
eplb
HDRS
eplb_executor.h
eplb_manager.h
eplb_policy.h
expert_weight_buffer_shm.h
shared_memory_manager.h
expert_buffer_manager.h
SRCS
eplb_executor.cpp
eplb_manager.cpp
eplb_policy.cpp
expert_weight_buffer_shm.cpp
shared_memory_manager.cpp
expert_buffer_manager.cpp
DEPS
torch_npu
llm_engine
:request
:common
glog::glog
torch
)

set(TEST_SRCS
eplb_policy_test.cpp
)

cc_test(
NAME
eplb_policy_test
SRCS
${TEST_SRCS}
DEPS
torch
:eplb
GTest::gtest_main
)

114 changes: 114 additions & 0 deletions xllm/core/framework/eplb/eplb_executor.cpp
@@ -0,0 +1,114 @@
#include "eplb_executor.h"

#include <c10/core/Device.h>
#include <c10/core/TensorOptions.h>
#include <glog/logging.h>
#if defined(USE_NPU)
#include <torch_npu/csrc/core/npu/NPUFormat.h>
#include <torch_npu/csrc/core/npu/NPUFunctions.h>
#include <torch_npu/csrc/framework/OpCommand.h>
#include <torch_npu/torch_npu.h>
#endif
#include <condition_variable>
#include <functional>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>

#include "runtime/forward_params.h"

namespace xllm {
#if defined(USE_NPU)
struct EplbExecutor::EplbStream {
c10_npu::NPUStream eplb_stream;
EplbStream() : eplb_stream(c10_npu::getNPUStreamFromPool()) {}
};
#endif
EplbExecutor::EplbExecutor(CausalLM* model)
: model_(model), eplb_worker_(&EplbExecutor::eplb_worker_loop, this) {
#if defined(USE_NPU)
eplb_stream_ = std::make_unique<EplbStream>();
#endif
}

EplbExecutor::~EplbExecutor() {
{
std::unique_lock<std::mutex> lock(queue_mutex_);
stop_ = true;
}
condition_.notify_one();
if (eplb_worker_.joinable()) {
eplb_worker_.join();
}
}

void EplbExecutor::eplb_execute(const EplbInfo& eplb_info) {
if (eplb_info.update_layer_id != -1) {
model_->update_expert_weight(eplb_info.update_layer_id);
};
if (eplb_info.prepare_layer_id != -1) {
prepare_expert_weight_async(
eplb_info.prepare_layer_id,
eplb_info.expert_ids,
[eplb_info](int32_t id) {
LOG(INFO) << "prepare expert weight complete, layer: "
<< eplb_info.prepare_layer_id;
});
};
}

void EplbExecutor::prepare_expert_weight_async(
int32_t layer_id,
const std::vector<int32_t>& expert_ids,
Callback callback) {
{
std::unique_lock<std::mutex> lock(queue_mutex_);
tasks_.emplace(Task{layer_id, expert_ids, callback});
}
condition_.notify_one();
}

int32_t EplbExecutor::get_ready_layer_id() const {
std::lock_guard<std::mutex> lock(ready_mutex_);
return ready_layer_id_;
}

void EplbExecutor::reset_ready_layer_id() {
std::lock_guard<std::mutex> lock(ready_mutex_);
ready_layer_id_ = -1;
}

void EplbExecutor::eplb_worker_loop() {
while (true) {
Task task;
{
std::unique_lock<std::mutex> lock(queue_mutex_);
condition_.wait(lock, [this] { return !tasks_.empty() || stop_; });
if (stop_) return;
task = std::move(tasks_.front());
tasks_.pop();
}
auto prepare_start = std::chrono::high_resolution_clock::now();

c10::StreamGuard streamGuard(eplb_stream_->eplb_stream.unwrap());
model_->prepare_expert_weight(task.layer_id, task.expert_ids);
aclrtSynchronizeStream(eplb_stream_->eplb_stream.stream());
auto prepare_end = std::chrono::high_resolution_clock::now();
auto prepare_duration =
std::chrono::duration_cast<std::chrono::milliseconds>(prepare_end -
prepare_start)
.count();
LOG(INFO) << "prepare_expert_weight | layer=" << task.layer_id
<< " | experts=" << task.expert_ids.size()
<< " | duration=" << prepare_duration << "ms";
{
std::lock_guard<std::mutex> lock(ready_mutex_);
ready_layer_id_ = task.layer_id;
}
if (task.callback) {
task.callback(task.layer_id);
}
}
}
} // namespace xllm
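For context, a minimal usage sketch of the executor added in this file, using only the methods shown in the diff; the header path, the concrete IDs, and the way the `CausalLM*` is obtained are assumptions.

```cpp
#include "framework/eplb/eplb_executor.h"  // path assumed from this PR's layout

// Drive one prepare/poll cycle with a hypothetical EplbInfo. The CausalLM
// instance would come from the surrounding worker, which is not shown here.
void run_eplb_cycle(xllm::CausalLM* model) {
  xllm::EplbExecutor executor(model);

  xllm::EplbInfo info;
  info.prepare_layer_id = 3;        // pre-load new expert weights for layer 3
  info.expert_ids = {0, 5, 9, 12};  // illustrative expert IDs
  info.update_layer_id = -1;        // nothing ready to activate yet
  executor.eplb_execute(info);      // queues the async prepare on the EPLB stream

  // Later, poll whether the background worker finished preparing the layer.
  if (executor.get_ready_layer_id() == 3) {
    executor.reset_ready_layer_id();
    // A follow-up EplbInfo with update_layer_id = 3 would then swap the weights in.
  }
}
```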