diff --git a/src/llmcompressor/modeling/deepseek_v3.py b/src/llmcompressor/modeling/deepseek_v3.py index 287b343bd..8cc7f4727 100644 --- a/src/llmcompressor/modeling/deepseek_v3.py +++ b/src/llmcompressor/modeling/deepseek_v3.py @@ -10,12 +10,18 @@ class DeepseekV3MoECalibrate(torch.nn.Module): Patched DeepseekV3MoE which sends all tokens to all experts for calibration """ - def __init__(self, config: DeepseekV3Config, original: OriginalDeepseekV3MoE): + def __init__( + self, + config: DeepseekV3Config, + original: OriginalDeepseekV3MoE, + calibrate_all_experts: bool, + ): super().__init__() self.config = config self.experts = original.experts self.gate = original.gate self.shared_experts = original.shared_experts + self.calibrate_all_experts = calibrate_all_experts def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: residuals = hidden_states @@ -30,18 +36,28 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: ) expert_mask = expert_mask.permute(2, 0, 1) - for expert_idx in range(len(self.experts)): - expert = self.experts[expert_idx] - mask = expert_mask[expert_idx] - token_indices, weight_indices = torch.where(mask) + for expert_idx, expert in enumerate(self.experts): + token_indices, weight_indices = torch.where(expert_mask[expert_idx]) + has_tokens = token_indices.numel() > 0 - expert_weights = topk_weights[token_indices, weight_indices] - expert_input = hidden_states[token_indices] - expert_output = expert(expert_input) - weighted_output = expert_output * expert_weights.unsqueeze(-1) + if self.calibrate_all_experts: + expert_input = hidden_states + expert_output = expert(expert_input) - if token_indices.numel() > 0: - final_hidden_states.index_add_(0, token_indices, weighted_output) + if has_tokens: + expert_weights = topk_weights[token_indices, weight_indices] + routed_output = expert_output[ + token_indices + ] * expert_weights.unsqueeze(-1) + final_hidden_states.index_add_(0, token_indices, routed_output) + else: + # Normal MoE: only process tokens routed to this expert + if has_tokens: + expert_input = hidden_states[token_indices] + expert_output = expert(expert_input) + expert_weights = topk_weights[token_indices, weight_indices] + routed_output = expert_output * expert_weights.unsqueeze(-1) + final_hidden_states.index_add_(0, token_indices, routed_output) # End MoE hidden_states = final_hidden_states.type(hidden_states.dtype).view(*orig_shape) @@ -49,5 +65,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -def replace(config: DeepseekV3Config, module: OriginalDeepseekV3MoE): - return DeepseekV3MoECalibrate(config=config, original=module) +def replace( + config: DeepseekV3Config, + module: OriginalDeepseekV3MoE, + calibrate_all_experts: bool, +): + return DeepseekV3MoECalibrate( + config=config, original=module, calibrate_all_experts=calibrate_all_experts + ) diff --git a/src/llmcompressor/modeling/llama4.py b/src/llmcompressor/modeling/llama4.py index fee1a5afd..ee47241fe 100644 --- a/src/llmcompressor/modeling/llama4.py +++ b/src/llmcompressor/modeling/llama4.py @@ -17,7 +17,12 @@ class SequentialLlama4TextMoe(torch.nn.Module): - def __init__(self, config: Llama4TextConfig, original: Llama4TextMoe): + def __init__( + self, + config: Llama4TextConfig, + original: Llama4TextMoe, + calibrate_all_experts: bool, + ): super().__init__() self.top_k = config.num_experts_per_tok self.hidden_dim = config.hidden_size @@ -25,6 +30,7 @@ def __init__(self, config: Llama4TextConfig, original: Llama4TextMoe): self.experts = SequentialLlama4TextExperts(config, original.experts) self.router = original.router self.shared_expert = original.shared_expert + self.calibrate_all_experts = calibrate_all_experts def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.tensor]: hidden_states = hidden_states.reshape(-1, self.hidden_dim) @@ -44,7 +50,21 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.tens out = self.shared_expert(hidden_states) for i in range(self.num_experts): - out += self.experts[i](hidden_states) * router_scores[i].reshape(-1, 1) + expert_output = None + if self.calibrate_all_experts: + # Run all tokens for calibration + expert_output = self.experts[i](hidden_states) + + # Only top-k tokens contribute to final output + top_token_mask = router_scores[i] > 0 + if top_token_mask.any(): + if expert_output is None: + expert_output = self.experts[i](hidden_states[top_token_mask]) + else: + expert_output = expert_output[top_token_mask] + out[top_token_mask] += expert_output * router_scores[ + i, top_token_mask + ].unsqueeze(-1) if version.parse(transformers.__version__) >= version.parse("4.54.0"): return out, router_logits @@ -72,5 +92,9 @@ def __init__(self, config: Llama4TextConfig, original: Llama4TextExperts): self[i].down_proj.weight.data = down.t().clone().contiguous() -def replace(config: Llama4Config, module: Llama4TextMoe): - return SequentialLlama4TextMoe(config=config.get_text_config(), original=module) +def replace(config: Llama4Config, module: Llama4TextMoe, calibrate_all_experts: bool): + return SequentialLlama4TextMoe( + config=config.get_text_config(), + original=module, + calibrate_all_experts=calibrate_all_experts, + ) diff --git a/src/llmcompressor/modeling/prepare.py b/src/llmcompressor/modeling/prepare.py index cb61f5fad..0ec960df7 100644 --- a/src/llmcompressor/modeling/prepare.py +++ b/src/llmcompressor/modeling/prepare.py @@ -15,11 +15,19 @@ } -def replace_modules_for_calibration(model: PreTrainedModel) -> PreTrainedModel: +def replace_modules_for_calibration( + model: PreTrainedModel, + calibrate_all_experts: bool = False, +) -> PreTrainedModel: + for name, module in model.named_modules(): cls_name = module.__class__.__name__ if cls_name in replacements: - new_module = replacements[cls_name](config=model.config, module=module) + new_module = replacements[cls_name]( + config=model.config, + module=module, + calibrate_all_experts=calibrate_all_experts, + ) replace_module(model, name, new_module) return model @@ -28,7 +36,7 @@ def replace_modules_for_calibration(model: PreTrainedModel) -> PreTrainedModel: # ------------------- module replacements; during calibration -------------------- -def update_qwen3_moe(model, stack): +def update_qwen3_moe(model, stack, calibrate_all_experts): for module in model.modules(): cls_name = module.__class__.__name__ if cls_name == "Qwen3MoeDecoderLayer": @@ -37,7 +45,11 @@ def update_qwen3_moe(model, stack): patch_attr( module, "mlp", - replace_Qwen3MoE(config=model.config, module=module.mlp), + replace_Qwen3MoE( + config=model.config, + module=module.mlp, + calibrate_all_experts=calibrate_all_experts, + ), ) ) @@ -47,9 +59,13 @@ def update_qwen3_moe(model, stack): } -def moe_calibration_context(model: PreTrainedModel, stack): +def moe_calibration_context( + model: PreTrainedModel, + stack, + calibrate_all_experts: bool = False, +): # Temporarily updates the MoE modules within the context # Once the context exists, parameter updates persist cls_name = model.__class__.__name__ if cls_name in moe_context: - moe_context.get(cls_name)(model, stack) + moe_context.get(cls_name)(model, stack, calibrate_all_experts) diff --git a/src/llmcompressor/modeling/qwen3_moe.py b/src/llmcompressor/modeling/qwen3_moe.py index fcd5d9925..9c8e8e352 100644 --- a/src/llmcompressor/modeling/qwen3_moe.py +++ b/src/llmcompressor/modeling/qwen3_moe.py @@ -23,12 +23,16 @@ class Qwen3MoeSparseMoeBlock(torch.nn.Module): def __init__( - self, config: Qwen3MoeConfig, original: OriginalQwen3MoeSparseMoeBlock + self, + config: Qwen3MoeConfig, + original: OriginalQwen3MoeSparseMoeBlock, + calibrate_all_experts: bool, ): super().__init__() self.num_experts = config.num_experts self.top_k = config.top_k self.norm_topk_prob = config.norm_topk_prob + self.calibrate_all_experts = calibrate_all_experts # gating self.gate = original.gate @@ -64,18 +68,29 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: for expert_idx in range(len(self.experts)): expert_layer = self.experts[expert_idx] + cached_output = None + + if self.calibrate_all_experts: + cached_output = expert_layer(hidden_states) + idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) - # Index the correct hidden states and compute the expert hidden state for - # the current expert. We need to make sure to multiply the output hidden - # states by `routing_weights` on the corresponding tokens (top-1 and top-2) - current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) - expert_output = expert_layer(current_state) - current_hidden_states = expert_output * routing_weights[top_x, idx, None] - # However `index_add_` only support torch tensors for indexing so we'll use - # the `top_x` tensor here. - final_hidden_states.index_add_( - 0, top_x, current_hidden_states.to(hidden_states.dtype) - ) + if top_x.numel() > 0: + if cached_output is not None: + expert_output = cached_output[top_x] + else: + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + expert_output = expert_layer(current_state) + current_hidden_states = ( + expert_output * routing_weights[top_x, idx, None] + ) + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. + final_hidden_states.index_add_( + 0, top_x, current_hidden_states.to(hidden_states.dtype) + ) final_hidden_states = final_hidden_states.reshape( batch_size, sequence_length, hidden_dim @@ -83,5 +98,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return final_hidden_states, router_logits -def replace(config: Qwen3MoeConfig, module: OriginalQwen3MoeSparseMoeBlock): - return Qwen3MoeSparseMoeBlock(config=config, original=module) +def replace( + config: Qwen3MoeConfig, + module: OriginalQwen3MoeSparseMoeBlock, + calibrate_all_experts: bool, +): + return Qwen3MoeSparseMoeBlock( + config=config, original=module, calibrate_all_experts=calibrate_all_experts + ) diff --git a/tests/llmcompressor/modeling/test_calib_deepseek_v3.py b/tests/llmcompressor/modeling/test_calib_deepseek_v3.py new file mode 100644 index 000000000..a1fea04cf --- /dev/null +++ b/tests/llmcompressor/modeling/test_calib_deepseek_v3.py @@ -0,0 +1,49 @@ +from functools import partial + +import pytest +import torch +from transformers import AutoModelForCausalLM + +from llmcompressor.modeling.deepseek_v3 import DeepseekV3MoECalibrate +from llmcompressor.modeling.prepare import replace_modules_for_calibration +from llmcompressor.utils.dev import skip_weights_download + + +@pytest.mark.parametrize("model_stub", ["unsloth/DeepSeek-R1-0528-BF16"]) +def test_calib_replace_deepseekv3moe_all_experts(model_stub): + with skip_weights_download(): + model = AutoModelForCausalLM.from_pretrained(model_stub) + + replace_modules_for_calibration(model, calibrate_all_experts=True) + + # Find a Deepseek MoE layer + moe_layer = None + for _, module in model.named_modules(): + if isinstance(module, DeepseekV3MoECalibrate): + moe_layer = module + break + + assert moe_layer is not None + + num_experts = len(moe_layer.experts) + expert_triggered = [False for _ in range(num_experts)] + + # Define the hook function + def hook_fn(i, module, input, output): + expert_triggered[i] = True + + # Attach hooks using functools.partial to bind each index + for i, expert in enumerate(moe_layer.experts): + expert.register_forward_hook(partial(hook_fn, i)) + + # Create dummy input tensor that simulates hidden_states + hidden_dim = model.config.hidden_size + batch, seq_len = 4, 32 + sample = torch.randn(batch, seq_len, hidden_dim, dtype=torch.float32) + + # Forward through the MoE layer directly + with torch.no_grad(): + _ = moe_layer(sample) + + # Assert all experts are used + assert all(expert_triggered), f"Not all experts were triggered: {expert_triggered}" diff --git a/tests/llmcompressor/modeling/test_calib_llama4.py b/tests/llmcompressor/modeling/test_calib_llama4.py new file mode 100644 index 000000000..78e8857ab --- /dev/null +++ b/tests/llmcompressor/modeling/test_calib_llama4.py @@ -0,0 +1,51 @@ +from functools import partial + +import pytest +import torch +from transformers import Llama4ForConditionalGeneration + +from llmcompressor.modeling.llama4 import SequentialLlama4TextMoe +from llmcompressor.modeling.prepare import replace_modules_for_calibration +from llmcompressor.utils.dev import skip_weights_download + + +@pytest.mark.parametrize("model_stub", ["meta-llama/Llama-4-Scout-17B-16E-Instruct"]) +def test_calib_replace_llama4_moe_all_experts(model_stub): + with skip_weights_download(Llama4ForConditionalGeneration): + model = Llama4ForConditionalGeneration.from_pretrained( + model_stub, torch_dtype="auto" + ) + + replace_modules_for_calibration(model, calibrate_all_experts=True) + + # Find a Llama4 MoE layer + moe_layer = None + for _, module in model.named_modules(): + if isinstance(module, SequentialLlama4TextMoe): + moe_layer = module + break + + assert moe_layer is not None + + num_experts = len(moe_layer.experts) + expert_triggered = [False for _ in range(num_experts)] + + # Define the hook function + def hook_fn(i, module, input, output): + expert_triggered[i] = True + + # Attach hooks using functools.partial to bind each index + for i, expert in enumerate(moe_layer.experts): + expert.register_forward_hook(partial(hook_fn, i)) + + # Create dummy input tensor that simulates hidden_states + hidden_dim = model.config.hidden_size + batch, seq_len = 4, 32 + sample = torch.randn(batch, seq_len, hidden_dim, dtype=torch.float32) + + # Forward through the MoE layer directly + with torch.no_grad(): + _ = moe_layer(sample) + + # Assert all experts are used + assert all(expert_triggered), f"Not all experts were triggered: {expert_triggered}" diff --git a/tests/llmcompressor/modeling/test_calib_qwen3.py b/tests/llmcompressor/modeling/test_calib_qwen3.py new file mode 100644 index 000000000..f77933c53 --- /dev/null +++ b/tests/llmcompressor/modeling/test_calib_qwen3.py @@ -0,0 +1,58 @@ +import contextlib +from functools import partial + +import pytest +import torch +from transformers import AutoModelForCausalLM + +from llmcompressor.modeling.prepare import moe_calibration_context +from llmcompressor.modeling.qwen3_moe import Qwen3MoeSparseMoeBlock +from llmcompressor.utils.dev import skip_weights_download +from llmcompressor.utils.helpers import DisableQuantization, calibration_forward_context + + +@pytest.mark.parametrize("model_stub", ["Qwen/Qwen3-30B-A3B"]) +def test_calib_replace_qwen3moe_all_experts(model_stub): + with skip_weights_download(): + model = AutoModelForCausalLM.from_pretrained(model_stub) + + # Qwen3MoE layer replacement is temporary within the context + with contextlib.ExitStack() as stack: + stack.enter_context(calibration_forward_context(model)) + stack.enter_context(DisableQuantization(model)) + + moe_calibration_context(model, stack, calibrate_all_experts=True) + + # Find one MoE layer + moe_layer = None + for _, module in model.named_modules(): + if isinstance(module, Qwen3MoeSparseMoeBlock): + moe_layer = module + break + + assert moe_layer is not None + + num_experts = len(moe_layer.experts) + expert_triggered = [False for _ in range(num_experts)] + + # Define the hook function + def hook_fn(i, module, input, output): + expert_triggered[i] = True + + # Attach hooks using functools.partial to bind each index + for i, expert in enumerate(moe_layer.experts): + expert.register_forward_hook(partial(hook_fn, i)) + + # Create dummy input tensor that simulates hidden_states + hidden_dim = model.config.hidden_size + batch, seq_len = 4, 32 + sample = torch.randn(batch, seq_len, hidden_dim, dtype=torch.float32) + + # Forward through the MoE layer directly + with torch.no_grad(): + _ = moe_layer(sample) + + # Assert all experts are used + assert all( + expert_triggered + ), f"Not all experts were triggered: {expert_triggered}"