
Commit bfd630d

feat: support variable number of redundant experts.

1 parent 25b0d5f

21 files changed: +169 −118 lines

docs/en/features/eplb.md
Lines changed: 2 additions & 2 deletions

@@ -23,8 +23,8 @@ Simply add the following gflag parameters when launching xLLM:
 
 - xLLM provides the gflag parameter `enable_eplb` (default: false). Set to true in the xLLM service startup script to enable dynamic expert load balancing.
 - `expert_parallel_degree` and `ep_size` are MoE-related parameters. `expert_parallel_degree` should be set to `2`, and `ep_size` must match the actual number of NPU/GPU devices. See [moe_params](./moe_params.md)
-- `eplb_update_rate` sets the expert distribution update interval in seconds (default: 1000).
+- `eplb_update_interval` sets the expert distribution update interval in seconds (default: 1000).
 - The expert distribution update uses a layer-by-layer mechanism based on expert load. When the similarity between consecutive loads for a layer is below `eplb_update_threshold`, that layer is updated (default: 1, range: 0-1).
 
 ```bash
---enable_eplb=true --expert_parallel_degree=2 --ep_size=16 --eplb_update_rate=2000 --eplb_update_threshold=0.9
+--enable_eplb=true --expert_parallel_degree=2 --ep_size=16 --eplb_update_interval=2000 --eplb_update_threshold=0.9
docs/zh/features/eplb.md
Lines changed: 3 additions & 3 deletions

@@ -18,14 +18,14 @@ The xLLM EPLB feature is mainly implemented by the following three modules:
 
 - xLLM provides the gflags parameter `enable_eplb` (default: false). To enable dynamic expert load balancing, set it to true in the xLLM service startup script.
 - `expert_parallel_degree` and `ep_size` are MoE-related parameters. `expert_parallel_degree` should be set to `2`, and `ep_size` must match the actual number of NPU/GPU devices. See [moe_params](./moe_params.md)
-- `eplb_update_rate` sets the expert distribution update interval in seconds (default: 1000).
-- The expert distribution update uses a layer-by-layer mechanism based on expert load: a layer is updated when the similarity between its two most recent loads is below `eplb_update_threshold` (default: 1, range: (0,1)).
+- `eplb_update_interval` sets the expert distribution update interval in seconds (default: 1000).
+- The expert distribution update uses a layer-by-layer mechanism based on expert load: a layer is updated when the similarity between its two most recent loads is below `eplb_update_interval` (default: 1, range: (0,1)).
 
 ```bash
 --enable_eplb=true
 --expert_parallel_degree=2
 --ep_size=16
---eplb_update_rate=2000
+--eplb_update_interval=2000
 --eplb_update_threshold=0.9
 ```

xllm/core/common/global_flags.cpp
Lines changed: 5 additions & 1 deletion

@@ -118,7 +118,11 @@ DEFINE_string(communication_backend, "hccl", "npu communication backend.");
 
 DEFINE_bool(enable_eplb, false, "Whether to use ep load balance.");
 
-DEFINE_int64(eplb_update_rate, 1000, "eplb update rate.");
+DEFINE_int32(redundant_experts_num,
+             1,
+             "num of redundant experts per device.");
+
+DEFINE_int64(eplb_update_interval, 1000, "eplb update interval.");
 
 DEFINE_double(eplb_update_threshold, 0.8, "eplb update threshold.");
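
A note on range checking: the hunk above defines the new flag but does not constrain it. If a guard were wanted, gflags' standard validator hook could express one. The sketch below is illustrative only; the validator function is invented here, not part of this commit, under the assumption that negative counts should be rejected before they reach the slot arithmetic.

```cpp
#include <gflags/gflags.h>

DECLARE_int32(redundant_experts_num);

// Hypothetical guard (not in the commit): reject negative counts so that
// device_experts_num_ = experts_num / device_num + redundant_experts_num
// can never shrink below the per-device routed expert count.
static bool ValidateRedundantExperts(const char* flagname, int32_t value) {
  return value >= 0;
}
DEFINE_validator(redundant_experts_num, &ValidateRedundantExperts);
```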

xllm/core/common/global_flags.h
Lines changed: 3 additions & 1 deletion

@@ -69,7 +69,9 @@ DECLARE_string(communication_backend);
 
 DECLARE_bool(enable_eplb);
 
-DECLARE_int64(eplb_update_rate);
+DECLARE_int32(redundant_experts_num);
+
+DECLARE_int64(eplb_update_interval);
 
 DECLARE_double(eplb_update_threshold);

xllm/core/common/options.h
Lines changed: 3 additions & 1 deletion

@@ -72,7 +72,9 @@ class Options {
 
   PROPERTY(std::optional<bool>, enable_eplb);
 
-  PROPERTY(std::optional<int64_t>, eplb_update_rate);
+  PROPERTY(std::optional<int32_t>, redundant_experts_num);
+
+  PROPERTY(std::optional<int64_t>, eplb_update_interval);
 
   PROPERTY(std::optional<double>, eplb_update_threshold);

xllm/core/common/types.h
Lines changed: 9 additions & 1 deletion

@@ -251,10 +251,18 @@ struct JsonTool {
   JsonTool(const std::string& tool_type, const JsonFunction& func)
       : type(tool_type), function(func) {}
 };
-
+// Information required for expert updates
 struct EplbInfo {
+  // Target layer ID for new expert weight pre-loading (-1 = no pending load)
+  // Values >= 0 indicate the layer ID that should start loading new expert
+  // weights
   int32_t prepare_layer_id = -1;
+  // Expert IDs requiring updates, ordered by device shard assignment
+  // Contains per-device expert indices for distributed weight updates
   std::vector<int32_t> expert_ids;
+  // Layer ID ready for expert weight activation (-1 = no pending update)
+  // Values >= 0 indicate the layer ID whose pre-loaded weights are ready for
+  // deployment
   int32_t update_layer_id = -1;
 };
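
Taken together, the two sentinels describe a two-phase handshake: a layer is first named for background pre-loading, and only later named again for activation. A minimal sketch of a consumer under that reading; the handler functions are invented placeholders, not APIs from this commit.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Shape of the struct as added in this commit.
struct EplbInfo {
  int32_t prepare_layer_id = -1;
  std::vector<int32_t> expert_ids;
  int32_t update_layer_id = -1;
};

// Hypothetical stand-ins for the real executor hooks.
void begin_async_weight_load(int32_t layer, const std::vector<int32_t>& ids) {
  std::printf("preload layer %d (%zu expert slots)\n", layer, ids.size());
}
void activate_preloaded_weights(int32_t layer) {
  std::printf("activate layer %d\n", layer);
}

void handle(const EplbInfo& info) {
  if (info.prepare_layer_id >= 0) {
    // Phase 1: start copying the rebalanced expert weights in the background.
    begin_async_weight_load(info.prepare_layer_id, info.expert_ids);
  }
  if (info.update_layer_id >= 0) {
    // Phase 2: an earlier copy finished; swap the layer onto the new weights.
    activate_preloaded_weights(info.update_layer_id);
  }
}
```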

xllm/core/framework/eplb/eplb_executor.h
Lines changed: 8 additions & 0 deletions

@@ -18,8 +18,16 @@ class EplbExecutor final {
   EplbExecutor(CausalLM* model);
 
   virtual ~EplbExecutor();
+
+  // Reset the ready layer ID marker to -1 (no layer ready)
   void reset_ready_layer_id();
+
+  // Get the currently ready layer ID
+  // Returns the layer ID with prepared weights (-1 if none)
   int32_t get_ready_layer_id() const;
+
+  // Execute the EPLB operation based on coordination info
+  // eplb_info: contains layer preparation/activation instructions
   void eplb_execute(const EplbInfo& eplb_info);
 
  private:
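
One plausible way these three methods compose with EplbManager is a per-step poll: execute the latest coordination info, then report any finished load back and clear the marker. This driver is an assumed usage written for illustration (including the `xllm` namespace qualification); only the five method calls are APIs that appear in this commit.

```cpp
#include "xllm/core/framework/eplb/eplb_executor.h"
#include "xllm/core/framework/eplb/eplb_manager.h"

void eplb_step(xllm::EplbExecutor& executor, xllm::EplbManager& manager) {
  // Ask the coordinator what to do this step (prepare and/or activate).
  xllm::EplbInfo info = manager.get_eplb_info();
  executor.eplb_execute(info);

  // If an async weight load has completed, report the layer back to the
  // manager, then clear the marker so it is not reported twice.
  int32_t ready = executor.get_ready_layer_id();
  if (ready >= 0) {
    manager.set_prepared_layer_ids({ready});
    executor.reset_ready_layer_id();
  }
}
```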

xllm/core/framework/eplb/eplb_manager.cpp
Lines changed: 12 additions & 12 deletions

@@ -20,17 +20,18 @@ namespace xllm {
 
 using namespace std::chrono_literals;
 
-EplbManager::EplbManager(EplbPolicy* eplb_policy,
-                         int32_t layer_num,
+EplbManager::EplbManager(int32_t layer_num,
                          int32_t device_num,
                          int32_t experts_num)
-    : eplb_policy_(eplb_policy),
-      layer_num_(layer_num),
+    : layer_num_(layer_num),
       device_num_(device_num),
       experts_num_(experts_num),
-      device_experts_num_((experts_num + device_num) / device_num) {
+      device_experts_num_(experts_num / device_num +
+                          FLAGS_redundant_experts_num) {
   // Initialize tensors with mutex protection
   {
+    eplb_policy_ = std::make_unique<EplbPolicy>(
+        device_experts_num_, device_num_, layer_num_);
     std::lock_guard<std::mutex> lock(state_.mtx);
     state_.expert_load =
         torch::zeros({layer_num_, experts_num_}, torch::kInt64);
@@ -39,11 +40,13 @@ EplbManager::EplbManager(EplbPolicy* eplb_policy,
         {layer_num_, device_num_, device_experts_num_}, torch::kInt32);
     for (int32_t layer = 0; layer < layer_num_; ++layer) {
       for (int32_t device = 0; device < device_num_; ++device) {
-        int32_t base = device * (device_experts_num_ - 1);
+        int32_t device_route_experts_num =
+            device_experts_num_ - FLAGS_redundant_experts_num;
+        int32_t base = device * device_route_experts_num;
         for (int32_t expert = 0; expert < device_experts_num_; ++expert) {
           int32_t value = base + expert;
-          if (expert == device_experts_num_ - 1) {
-            --value;
+          if (expert >= device_route_experts_num) {
+            value = base + device_route_experts_num - 1;
           }
           state_.expert_distribution[layer][device][expert] = value;
         }
@@ -105,7 +108,6 @@ void EplbManager::aggregate_multi_layer_expert_loads(
     layer_ids.emplace_back(ids.flatten().to(torch::kInt64));
     layer_loads.emplace_back(loads.flatten().to(torch::kInt64));
   }
-
   torch::Tensor all_ids = torch::cat(layer_ids);
   torch::Tensor all_loads = torch::cat(layer_loads);
   expert_load[layer].scatter_add_(0, all_ids, all_loads);
@@ -125,14 +127,12 @@ void EplbManager::rebalance_experts_loop() {
     if (state_.stop) return;
 
     while (!state_.expert_load_queue.empty()) {
-      // expert_load_batch.emplace_back(state_.expert_load_queue.front());
-      // state_.expert_load_queue.pop();
       aggregate_multi_layer_expert_loads(state_.expert_load,
                                          state_.expert_distribution,
                                          state_.expert_load_queue.front());
       state_.expert_load_queue.pop();
      int64_t current_time = absl::ToUnixSeconds(absl::Now());
-      if (current_time - latest_record_time >= FLAGS_eplb_update_rate) {
+      if (current_time - latest_record_time >= FLAGS_eplb_update_interval) {
        latest_record_time = current_time;
        auto result = eplb_policy_->rebalance_experts(state_.expert_load);
        state_.expert_distribution = result.first;
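
The constructor change is the core of the feature: each device now gets `experts_num / device_num` routed slots plus `FLAGS_redundant_experts_num` redundant ones, and the initial distribution parks every redundant slot on the device's last routed expert. A standalone sketch of just that mapping, with made-up sizes (2 devices, 8 experts, 1 redundant slot per device):

```cpp
#include <cstdint>
#include <cstdio>

// Re-implements the initial-distribution math from the hunk above with plain
// ints instead of torch tensors. Sizes are invented for the example.
int main() {
  const int32_t device_num = 2, experts_num = 8, redundant = 1;
  const int32_t slots = experts_num / device_num + redundant;  // 5 slots/device
  const int32_t routed = slots - redundant;                    // 4 routed slots

  for (int32_t device = 0; device < device_num; ++device) {
    const int32_t base = device * routed;
    std::printf("device %d:", device);
    for (int32_t expert = 0; expert < slots; ++expert) {
      int32_t value = base + expert;
      if (expert >= routed) value = base + routed - 1;  // redundant slot copies
      std::printf(" %d", value);                        // the last routed expert
    }
    std::printf("\n");
  }
  // Prints: "device 0: 0 1 2 3 3" and "device 1: 4 5 6 7 7"
  return 0;
}
```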

xllm/core/framework/eplb/eplb_manager.h
Lines changed: 15 additions & 5 deletions

@@ -13,14 +13,24 @@ namespace xllm {
 
 class EplbManager {
  public:
-  EplbManager(EplbPolicy* eplb_policy,
-              int32_t layer_num,
-              int32_t device_num,
-              int32_t experts_num);
+  // Initialize with model dimensions:
+  // - layer_num: Total layers in the model
+  // - device_num: Parallel devices in cluster
+  // - experts_num: Experts per model layer
+  EplbManager(int32_t layer_num, int32_t device_num, int32_t experts_num);
+
   ~EplbManager();
 
+  // Feed new expert workload data for load balancing
+  // Input tensors should have shape [layer_num, experts_num]
   void update_expert_load(const std::vector<torch::Tensor> expert_load);
+
+  // Fetch current coordination instructions for expert updates
+  // Returns struct containing layer preparation/activation commands
   EplbInfo get_eplb_info();
+
+  // Mark specified layers as prepared (call after async loading completes)
+  // expert_layer_ids: Prepared layer IDs per device
   void set_prepared_layer_ids(const std::vector<int32_t>& expert_layer_ids);
 
  private:
@@ -49,7 +59,7 @@ class EplbManager {
   };
 
   // Components
-  EplbPolicy* eplb_policy_;
+  std::unique_ptr<EplbPolicy> eplb_policy_ = nullptr;
   ThreadSafeData state_;
 
   // Constants
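
Since the policy object is now owned internally via `std::unique_ptr`, callers only pass model dimensions. A minimal construction-and-feed sketch following the header comments above; the dimension values are invented (chosen so experts divide evenly across devices) and the include path is assumed from the repository layout.

```cpp
#include <torch/torch.h>

#include "xllm/core/framework/eplb/eplb_manager.h"

// Assumed driver, not code from the commit.
void drive_eplb() {
  const int32_t layer_num = 4, device_num = 2, experts_num = 8;
  xllm::EplbManager manager(layer_num, device_num, experts_num);

  // One workload snapshot shaped [layer_num, experts_num], zeros as a stub.
  std::vector<torch::Tensor> load = {
      torch::zeros({layer_num, experts_num}, torch::kInt64)};
  manager.update_expert_load(load);

  // Poll the coordination decision produced by the rebalance thread.
  xllm::EplbInfo info = manager.get_eplb_info();
  (void)info;  // prepare_layer_id / update_layer_id stay -1 until a rebalance
}
```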

xllm/core/framework/eplb/eplb_policy.cpp
Lines changed: 5 additions & 5 deletions

@@ -15,7 +15,9 @@ EplbPolicy::EplbPolicy(int32_t device_experts_num,
       device_num_(device_num),
       layer_num_(layer_num) {
   old_expert_load_ =
-      torch::zeros({layer_num_, device_experts_num * device_num - device_num},
+      torch::zeros({layer_num_,
+                    device_experts_num * device_num -
+                        device_num * FLAGS_redundant_experts_num},
                    torch::kInt64);
   expert_distribution_ = torch::full(
       {layer_num_, device_num_, device_experts_num_}, -1, torch::kInt32);
@@ -32,9 +34,7 @@ std::pair<torch::Tensor, std::vector<bool>> EplbPolicy::rebalance_experts(
   auto prev_max_val = torch::max(prev_load).item<double>() + 1e-6f;
 
   current_load = (current_load / current_max_val).unsqueeze(0);
-  ;
   prev_load = (prev_load / prev_max_val).unsqueeze(0);
-  ;
 
   auto cos_sim =
       torch::nn::functional::cosine_similarity(
@@ -65,8 +65,8 @@ torch::Tensor EplbPolicy::compute_balanced_pack(
   const int64_t num_experts = expert_loads.size(0);
 
   // Generate Redundant Experts
-  auto [updated_weights, redundancy_map] =
-      update_origin_weights(expert_loads, device_num_);
+  auto [updated_weights, redundancy_map] = update_origin_weights(
+      expert_loads, device_num_ * FLAGS_redundant_experts_num);
 
   // Initialize Allocation Matrix
   auto options = torch::TensorOptions().dtype(torch::kInt64);
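
In the middle hunk, the dropped stray semicolons sit inside the test that decides whether a layer is rebalanced: both load snapshots are scaled by their maxima, then compared with cosine similarity against the update threshold. A self-contained libtorch sketch of that per-layer test; the helper function is invented, and the 0.8 default is taken from the `eplb_update_threshold` definition in global_flags.cpp.

```cpp
#include <torch/torch.h>

namespace F = torch::nn::functional;

// Hypothetical helper (not in the commit): true when a layer's load pattern
// has drifted enough to rebalance, mirroring the normalize-then-cosine
// comparison in EplbPolicy::rebalance_experts.
bool layer_needs_update(const torch::Tensor& current_load,
                        const torch::Tensor& prev_load,
                        double threshold = 0.8 /* eplb_update_threshold */) {
  auto current_max = torch::max(current_load).item<double>() + 1e-6;
  auto prev_max = torch::max(prev_load).item<double>() + 1e-6;
  auto cur = (current_load / current_max).unsqueeze(0);
  auto prev = (prev_load / prev_max).unsqueeze(0);
  auto cos_sim =
      F::cosine_similarity(cur, prev, F::CosineSimilarityFuncOptions().dim(1));
  return cos_sim.item<double>() < threshold;
}
```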
