48 changes: 35 additions & 13 deletions src/llmcompressor/modeling/deepseek_v3.py
@@ -10,12 +10,18 @@ class DeepseekV3MoECalibrate(torch.nn.Module):
Patched DeepseekV3MoE which sends all tokens to all experts for calibration
"""

def __init__(self, config: DeepseekV3Config, original: OriginalDeepseekV3MoE):
def __init__(
self,
config: DeepseekV3Config,
original: OriginalDeepseekV3MoE,
calibrate_all_experts: bool,
):
super().__init__()
self.config = config
self.experts = original.experts
self.gate = original.gate
self.shared_experts = original.shared_experts
self.calibrate_all_experts = calibrate_all_experts

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residuals = hidden_states
@@ -30,24 +36,40 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
)
expert_mask = expert_mask.permute(2, 0, 1)

for expert_idx in range(len(self.experts)):
expert = self.experts[expert_idx]
mask = expert_mask[expert_idx]
token_indices, weight_indices = torch.where(mask)
for expert_idx, expert in enumerate(self.experts):
token_indices, weight_indices = torch.where(expert_mask[expert_idx])
has_tokens = token_indices.numel() > 0

expert_weights = topk_weights[token_indices, weight_indices]
expert_input = hidden_states[token_indices]
expert_output = expert(expert_input)
weighted_output = expert_output * expert_weights.unsqueeze(-1)
if self.calibrate_all_experts:
expert_input = hidden_states
expert_output = expert(expert_input)

if token_indices.numel() > 0:
final_hidden_states.index_add_(0, token_indices, weighted_output)
if has_tokens:
expert_weights = topk_weights[token_indices, weight_indices]
routed_output = expert_output[
token_indices
] * expert_weights.unsqueeze(-1)
final_hidden_states.index_add_(0, token_indices, routed_output)
else:
# Normal MoE: only process tokens routed to this expert
if has_tokens:
expert_input = hidden_states[token_indices]
expert_output = expert(expert_input)
expert_weights = topk_weights[token_indices, weight_indices]
routed_output = expert_output * expert_weights.unsqueeze(-1)
final_hidden_states.index_add_(0, token_indices, routed_output)
# End MoE

hidden_states = final_hidden_states.type(hidden_states.dtype).view(*orig_shape)
hidden_states = hidden_states + self.shared_experts(residuals)
return hidden_states


def replace(config: DeepseekV3Config, module: OriginalDeepseekV3MoE):
return DeepseekV3MoECalibrate(config=config, original=module)
def replace(
config: DeepseekV3Config,
module: OriginalDeepseekV3MoE,
calibrate_all_experts: bool,
):
return DeepseekV3MoECalibrate(
config=config, original=module, calibrate_all_experts=calibrate_all_experts
)
32 changes: 28 additions & 4 deletions src/llmcompressor/modeling/llama4.py
@@ -17,14 +17,20 @@


class SequentialLlama4TextMoe(torch.nn.Module):
def __init__(self, config: Llama4TextConfig, original: Llama4TextMoe):
def __init__(
self,
config: Llama4TextConfig,
original: Llama4TextMoe,
calibrate_all_experts: bool,
):
super().__init__()
self.top_k = config.num_experts_per_tok
self.hidden_dim = config.hidden_size
self.num_experts = config.num_local_experts
self.experts = SequentialLlama4TextExperts(config, original.experts)
self.router = original.router
self.shared_expert = original.shared_expert
self.calibrate_all_experts = calibrate_all_experts

def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.tensor]:
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
@@ -44,7 +50,21 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.tensor]:

out = self.shared_expert(hidden_states)
for i in range(self.num_experts):
out += self.experts[i](hidden_states) * router_scores[i].reshape(-1, 1)
expert_output = None
if self.calibrate_all_experts:
# Run all tokens for calibration
expert_output = self.experts[i](hidden_states)

# Only top-k tokens contribute to final output
top_token_mask = router_scores[i] > 0
if top_token_mask.any():
if expert_output is None:
expert_output = self.experts[i](hidden_states[top_token_mask])
else:
expert_output = expert_output[top_token_mask]
out[top_token_mask] += expert_output * router_scores[
i, top_token_mask
].unsqueeze(-1)

if version.parse(transformers.__version__) >= version.parse("4.54.0"):
return out, router_logits
@@ -72,5 +92,9 @@ def __init__(self, config: Llama4TextConfig, original: Llama4TextExperts):
self[i].down_proj.weight.data = down.t().clone().contiguous()


def replace(config: Llama4Config, module: Llama4TextMoe):
return SequentialLlama4TextMoe(config=config.get_text_config(), original=module)
def replace(config: Llama4Config, module: Llama4TextMoe, calibrate_all_experts: bool):
return SequentialLlama4TextMoe(
config=config.get_text_config(),
original=module,
calibrate_all_experts=calibrate_all_experts,
)
28 changes: 22 additions & 6 deletions src/llmcompressor/modeling/prepare.py
@@ -15,11 +15,19 @@
}


def replace_modules_for_calibration(model: PreTrainedModel) -> PreTrainedModel:
def replace_modules_for_calibration(
model: PreTrainedModel,
calibrate_all_experts: bool = False,
) -> PreTrainedModel:

for name, module in model.named_modules():
cls_name = module.__class__.__name__
if cls_name in replacements:
new_module = replacements[cls_name](config=model.config, module=module)
new_module = replacements[cls_name](
config=model.config,
module=module,
calibrate_all_experts=calibrate_all_experts,
)
replace_module(model, name, new_module)

return model
@@ -28,7 +36,7 @@ def replace_modules_for_calibration(model: PreTrainedModel) -> PreTrainedModel:
# ------------------- module replacements; during calibration --------------------


def update_qwen3_moe(model, stack):
def update_qwen3_moe(model, stack, calibrate_all_experts):
for module in model.modules():
cls_name = module.__class__.__name__
if cls_name == "Qwen3MoeDecoderLayer":
@@ -37,7 +45,11 @@ def update_qwen3_moe(model, stack):
patch_attr(
module,
"mlp",
replace_Qwen3MoE(config=model.config, module=module.mlp),
replace_Qwen3MoE(
config=model.config,
module=module.mlp,
calibrate_all_experts=calibrate_all_experts,
),
)
)

@@ -47,9 +59,13 @@
}


def moe_calibration_context(model: PreTrainedModel, stack):
def moe_calibration_context(
model: PreTrainedModel,
stack,
calibrate_all_experts: bool = False,
):
# Temporarily updates the MoE modules within the context
# Once the context exits, parameter updates persist
cls_name = model.__class__.__name__
if cls_name in moe_context:
moe_context.get(cls_name)(model, stack)
moe_context.get(cls_name)(model, stack, calibrate_all_experts)
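
A minimal usage sketch of the new flag (assuming an already-loaded PreTrainedModel in `model` and a hypothetical `calibration_dataloader`): `replace_modules_for_calibration` swaps modules permanently, while `moe_calibration_context` patches them only for the lifetime of the supplied ExitStack.

import torch
from contextlib import ExitStack

from llmcompressor.modeling.prepare import (
    moe_calibration_context,
    replace_modules_for_calibration,
)

# Permanent replacement (DeepSeek-V3, Llama-4): with calibrate_all_experts=True,
# every expert sees every calibration token, but only routed tokens contribute
# to the layer output.
model = replace_modules_for_calibration(model, calibrate_all_experts=True)

# Scoped replacement (Qwen3-MoE): modules are patched only inside the stack.
with ExitStack() as stack:
    moe_calibration_context(model, stack, calibrate_all_experts=True)
    with torch.no_grad():
        for batch in calibration_dataloader:  # hypothetical calibration data
            model(**batch)
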
49 changes: 35 additions & 14 deletions src/llmcompressor/modeling/qwen3_moe.py
@@ -23,12 +23,16 @@

class Qwen3MoeSparseMoeBlock(torch.nn.Module):
def __init__(
self, config: Qwen3MoeConfig, original: OriginalQwen3MoeSparseMoeBlock
self,
config: Qwen3MoeConfig,
original: OriginalQwen3MoeSparseMoeBlock,
calibrate_all_experts: bool,
):
super().__init__()
self.num_experts = config.num_experts
self.top_k = config.top_k
self.norm_topk_prob = config.norm_topk_prob
self.calibrate_all_experts = calibrate_all_experts

# gating
self.gate = original.gate
@@ -64,24 +68,41 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

for expert_idx in range(len(self.experts)):
expert_layer = self.experts[expert_idx]
cached_output = None

if self.calibrate_all_experts:
cached_output = expert_layer(hidden_states)

idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
# Index the correct hidden states and compute the expert hidden state for
# the current expert. We need to make sure to multiply the output hidden
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
expert_output = expert_layer(current_state)
current_hidden_states = expert_output * routing_weights[top_x, idx, None]
# However `index_add_` only support torch tensors for indexing so we'll use
# the `top_x` tensor here.
final_hidden_states.index_add_(
0, top_x, current_hidden_states.to(hidden_states.dtype)
)
if top_x.numel() > 0:
if cached_output is not None:
expert_output = cached_output[top_x]
else:
# Index the correct hidden states and compute the expert hidden state for
# the current expert. We need to make sure to multiply the output hidden
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
expert_output = expert_layer(current_state)
current_hidden_states = (
expert_output * routing_weights[top_x, idx, None]
)
# However `index_add_` only support torch tensors for indexing so we'll use
# the `top_x` tensor here.
final_hidden_states.index_add_(
0, top_x, current_hidden_states.to(hidden_states.dtype)
)

final_hidden_states = final_hidden_states.reshape(
batch_size, sequence_length, hidden_dim
)
return final_hidden_states, router_logits


def replace(config: Qwen3MoeConfig, module: OriginalQwen3MoeSparseMoeBlock):
return Qwen3MoeSparseMoeBlock(config=config, original=module)
def replace(
config: Qwen3MoeConfig,
module: OriginalQwen3MoeSparseMoeBlock,
calibrate_all_experts: bool,
):
return Qwen3MoeSparseMoeBlock(
config=config, original=module, calibrate_all_experts=calibrate_all_experts
)
49 changes: 49 additions & 0 deletions tests/llmcompressor/modeling/test_calib_deepseek_v3.py
@@ -0,0 +1,49 @@
from functools import partial

import pytest
import torch
from transformers import AutoModelForCausalLM

from llmcompressor.modeling.deepseek_v3 import DeepseekV3MoECalibrate
from llmcompressor.modeling.prepare import replace_modules_for_calibration
from llmcompressor.utils.dev import skip_weights_download


@pytest.mark.parametrize("model_stub", ["unsloth/DeepSeek-R1-0528-BF16"])
def test_calib_replace_deepseekv3moe_all_experts(model_stub):
with skip_weights_download():
model = AutoModelForCausalLM.from_pretrained(model_stub)

replace_modules_for_calibration(model, calibrate_all_experts=True)

# Find a Deepseek MoE layer
moe_layer = None
for _, module in model.named_modules():
if isinstance(module, DeepseekV3MoECalibrate):
moe_layer = module
break

assert moe_layer is not None

num_experts = len(moe_layer.experts)
expert_triggered = [False for _ in range(num_experts)]

# Define the hook function
def hook_fn(i, module, input, output):
expert_triggered[i] = True

# Attach hooks using functools.partial to bind each index
for i, expert in enumerate(moe_layer.experts):
expert.register_forward_hook(partial(hook_fn, i))

# Create dummy input tensor that simulates hidden_states
hidden_dim = model.config.hidden_size
batch, seq_len = 4, 32
sample = torch.randn(batch, seq_len, hidden_dim, dtype=torch.float32)

# Forward through the MoE layer directly
with torch.no_grad():
_ = moe_layer(sample)

# Assert all experts are used
assert all(expert_triggered), f"Not all experts were triggered: {expert_triggered}"
51 changes: 51 additions & 0 deletions tests/llmcompressor/modeling/test_calib_llama4.py
@@ -0,0 +1,51 @@
from functools import partial

import pytest
import torch
from transformers import Llama4ForConditionalGeneration

from llmcompressor.modeling.llama4 import SequentialLlama4TextMoe
from llmcompressor.modeling.prepare import replace_modules_for_calibration
from llmcompressor.utils.dev import skip_weights_download


@pytest.mark.parametrize("model_stub", ["meta-llama/Llama-4-Scout-17B-16E-Instruct"])
def test_calib_replace_llama4_moe_all_experts(model_stub):
with skip_weights_download(Llama4ForConditionalGeneration):
model = Llama4ForConditionalGeneration.from_pretrained(
model_stub, torch_dtype="auto"
)

replace_modules_for_calibration(model, calibrate_all_experts=True)

# Find a Llama4 MoE layer
moe_layer = None
for _, module in model.named_modules():
if isinstance(module, SequentialLlama4TextMoe):
moe_layer = module
break

assert moe_layer is not None

num_experts = len(moe_layer.experts)
expert_triggered = [False for _ in range(num_experts)]

# Define the hook function
def hook_fn(i, module, input, output):
expert_triggered[i] = True

# Attach hooks using functools.partial to bind each index
for i, expert in enumerate(moe_layer.experts):
expert.register_forward_hook(partial(hook_fn, i))

# Create dummy input tensor that simulates hidden_states
hidden_dim = model.config.hidden_size
batch, seq_len = 4, 32
sample = torch.randn(batch, seq_len, hidden_dim, dtype=torch.float32)

# Forward through the MoE layer directly
with torch.no_grad():
_ = moe_layer(sample)

# Assert all experts are used
assert all(expert_triggered), f"Not all experts were triggered: {expert_triggered}"