     Qwen3MoeSparseMoeBlock as OriginalQwen3MoeSparseMoeBlock,
 )
 
+from llmcompressor.modeling.config import CalibrationConfig
+
 
 class Qwen3MoeSparseMoeBlock(torch.nn.Module):
     def __init__(
-        self, config: Qwen3MoeConfig, original: OriginalQwen3MoeSparseMoeBlock
+        self,
+        config: Qwen3MoeConfig,
+        original: OriginalQwen3MoeSparseMoeBlock,
+        calib_config: CalibrationConfig,
     ):
         super().__init__()
         self.num_experts = config.num_experts
         self.top_k = config.top_k
         self.norm_topk_prob = config.norm_topk_prob
+        self.calib_config = calib_config
 
         # gating
         self.gate = original.gate
         self.experts = original.experts
 
+        if not self.calib_config.moe_calibrate_gated_acts:
+            self.gate.top_k = self.num_experts  # ungate experts
+
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
         # router_logits: (batch * sequence_length, n_experts)
         router_logits = self.gate(hidden_states)
 
-        routing_weights = torch.nn.functional.softmax(
-            router_logits, dim=1, dtype=torch.float
-        )
-        routing_weights, selected_experts = torch.topk(
-            routing_weights, self.top_k, dim=-1
-        )
-        if self.norm_topk_prob:  # only diff with mixtral sparse moe block!
-            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
-        # we cast back to the input dtype
-        routing_weights = routing_weights.to(hidden_states.dtype)
+        if self.calib_config.moe_calibrate_gated_acts:
+            routing_weights = torch.nn.functional.softmax(
+                router_logits, dim=1, dtype=torch.float
+            )
+            routing_weights, selected_experts = torch.topk(
+                routing_weights, self.top_k, dim=-1
+            )
+            # only diff with mixtral sparse moe block!
+            if self.norm_topk_prob:
+                routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+            # we cast back to the input dtype
+            routing_weights = routing_weights.to(hidden_states.dtype)
+
+        else:
+            # ungate experts
+            selected_experts = torch.arange(
+                self.num_experts, device=hidden_states.device
+            )
+            selected_experts = selected_experts.unsqueeze(0).expand(
+                hidden_states.shape[0], -1
+            )
+            routing_weights = (
+                torch.ones_like(selected_experts, dtype=hidden_states.dtype)
+                / self.num_experts
+            )
+
         final_hidden_states = torch.zeros(
             (batch_size * sequence_length, hidden_dim),
             dtype=hidden_states.dtype,
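
When `moe_calibrate_gated_acts` is disabled, the new `else` branch above routes every token to every expert with a uniform weight of `1 / num_experts`. A minimal standalone sketch of what that branch produces, using toy sizes and assuming float16 activations (not part of the PR):

import torch

num_experts, num_tokens, hidden_dim = 4, 3, 8
hidden_states = torch.randn(num_tokens, hidden_dim, dtype=torch.float16)

# same construction as the "ungate experts" branch above
selected_experts = torch.arange(num_experts, device=hidden_states.device)
selected_experts = selected_experts.unsqueeze(0).expand(hidden_states.shape[0], -1)
routing_weights = (
    torch.ones_like(selected_experts, dtype=hidden_states.dtype) / num_experts
)

print(selected_experts.shape)  # torch.Size([3, 4]) -- every token selects every expert
print(routing_weights[0])      # tensor([0.2500, 0.2500, 0.2500, 0.2500], dtype=torch.float16)
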
@@ -65,23 +90,37 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         for expert_idx in range(len(self.experts)):
             expert_layer = self.experts[expert_idx]
             idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
-            # Index the correct hidden states and compute the expert hidden state for
-            # the current expert. We need to make sure to multiply the output hidden
-            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
-            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
-            expert_output = expert_layer(current_state)
-            current_hidden_states = expert_output * routing_weights[top_x, idx, None]
-            # However `index_add_` only support torch tensors for indexing so we'll use
-            # the `top_x` tensor here.
-            final_hidden_states.index_add_(
-                0, top_x, current_hidden_states.to(hidden_states.dtype)
-            )
+
+            has_tokens = idx.numel() > 0
+
+            if self.calib_config.moe_calibrate_all_experts or has_tokens:
+                # Index the correct hidden states and compute the expert hidden state for
+                # the current expert. We need to make sure to multiply the output hidden
+                # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+                current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+                expert_output = expert_layer(current_state)
+                current_hidden_states = (
+                    expert_output * routing_weights[top_x, idx, None]
+                )
+
+                # However `index_add_` only support torch tensors for indexing so we'll use
+                # the `top_x` tensor here.
+                if has_tokens and self.calib_config.moe_calibrate_gated_acts:
+                    final_hidden_states.index_add_(
+                        0, top_x, current_hidden_states.to(hidden_states.dtype)
+                    )
 
         final_hidden_states = final_hidden_states.reshape(
             batch_size, sequence_length, hidden_dim
         )
         return final_hidden_states, router_logits
 
 
-def replace(config: Qwen3MoeConfig, module: OriginalQwen3MoeSparseMoeBlock):
-    return Qwen3MoeSparseMoeBlock(config=config, original=module)
+def replace(
+    config: Qwen3MoeConfig,
+    module: OriginalQwen3MoeSparseMoeBlock,
+    calib_config: CalibrationConfig,
+):
+    return Qwen3MoeSparseMoeBlock(
+        config=config, original=module, calib_config=calib_config
+    )
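
For context, a hypothetical calibration-time usage of the updated `replace` helper. This sketch assumes `CalibrationConfig` accepts the two flags read above as keyword arguments, and `model` / `original_moe_block` are placeholders for a loaded Qwen3-MoE model and one of its sparse-MoE blocks; it is not taken from the PR:

from llmcompressor.modeling.config import CalibrationConfig

# run every expert on every calibration batch, but keep the gated top-k outputs
calib_config = CalibrationConfig(
    moe_calibrate_gated_acts=True,
    moe_calibrate_all_experts=True,
)

calibration_block = replace(
    config=model.config,        # Qwen3MoeConfig of the loaded model (placeholder)
    module=original_moe_block,  # transformers Qwen3MoeSparseMoeBlock to wrap (placeholder)
    calib_config=calib_config,
)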