src/prime_rl/inference/patches.py: 58 additions & 0 deletions
@@ -18,6 +18,7 @@ def transformers_v5_compat():
monkey_patch_deep_gemm_ep_scatter()
monkey_patch_dp_engine_core_pause_resume_deadlock()
monkey_patch_offloading_connector_cpu_block_count()
monkey_patch_routed_experts_capturer_mk_path()


@triton.jit
@@ -1086,3 +1087,60 @@ def _patched__post_init__(self: FusedMoEConfig):
self.is_lora_enabled = False

FusedMoEConfig.__post_init__ = _patched__post_init__


def monkey_patch_routed_experts_capturer_mk_path():
"""Fix RoutedExpertsCapturer.capture() for the modular-kernel (DeepEP) path.

DefaultMoERunner has two MoE dispatch paths:
- naive: dp_size > 1 and not supports_internal_mk -> DP combine concats
all ranks' tokens BEFORE select_experts, so topk_ids.shape[0] equals
cross-DP total.
- modular kernel (MK): supports_internal_mk=True (DeepEP, DEEPGEMM Fp8 MoE,
...) -> DP combine happens INSIDE quant_method.apply, so select_experts
sees only this rank's tokens.

The shipped capture() in vLLM 0.19 hardcodes `assert cumsum[-1] ==
topk_ids.shape[0]`, which only holds for the naive path. With DeepEP every
DP worker trips the assert during CUDA-graph warmup and engine cores die.

Mirrors the simpler post-refactor behavior on vLLM main (PR #39917): only
slice when topk_ids is the cross-DP concatenation; otherwise copy verbatim.
This handles the MK path (local tokens) AND CUDA-graph capture warmup
(where dp_metadata still claims max but the captured batch is smaller),
both of which would trip the strict either/or check from the earlier
upstream fix (PR #37879).
"""
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe.routed_experts_capturer import (
RoutedExpertsCapturer,
)

def _patched_capture(self, layer_id: int, topk_ids: torch.Tensor) -> None:
if self._device_buffer is None:
raise RuntimeError("Buffer not initialized. Call init_buffer() first.")

if layer_id >= self._device_buffer.shape[1]:
return

ctx = get_forward_context()
n = topk_ids.shape[0]

if ctx.dp_metadata is not None:
num_tokens_dp = ctx.dp_metadata.num_tokens_across_dp_cpu
total = int(num_tokens_dp.sum().item())
if total > 0 and n == total:
# Naive dispatch: all DP ranks' tokens concatenated upstream.
token_num_per_dp = int(num_tokens_dp[self.dp_rank].item())
cumsum = torch.cumsum(num_tokens_dp, dim=0)
end_loc = int(cumsum[self.dp_rank].item())
start_loc = end_loc - token_num_per_dp
self._device_buffer[:token_num_per_dp, layer_id, :] = topk_ids[
start_loc:end_loc, :
]
return

# MK path / single-DP / CUDA-graph capture: copy local tensor as-is.
self._device_buffer[:n, layer_id, :] = topk_ids

RoutedExpertsCapturer.capture = _patched_capture
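
For reference, a minimal standalone sketch of the slicing arithmetic used in the naive-dispatch branch of _patched_capture above. The per-rank counts, dp_rank, and topk values are made-up assumptions for illustration and stand in for dp_metadata.num_tokens_across_dp_cpu; this is not the vLLM API, only the cumsum/start/end offset logic mirrors the patch.

import torch

# Hypothetical per-DP-rank token counts (stand-in for
# dp_metadata.num_tokens_across_dp_cpu) and this worker's rank.
num_tokens_across_dp = torch.tensor([4, 2, 3])
dp_rank = 1
topk = 2

# topk_ids as the naive path sees it: all ranks' tokens concatenated.
total = int(num_tokens_across_dp.sum().item())                   # 9
topk_ids = torch.arange(total * topk).reshape(total, topk)

# Same arithmetic as the patch: cumsum gives each rank's end offset,
# subtracting this rank's own count gives its start offset.
cumsum = torch.cumsum(num_tokens_across_dp, dim=0)               # tensor([4, 6, 9])
end_loc = int(cumsum[dp_rank].item())                            # 6
start_loc = end_loc - int(num_tokens_across_dp[dp_rank].item())  # 4
local_topk_ids = topk_ids[start_loc:end_loc, :]                  # rows 4..5: this rank's 2 tokens

assert local_topk_ids.shape == (2, topk)

The patched capture() only takes this slicing branch when topk_ids.shape[0] equals the cross-DP total; in the MK path, the single-DP case, and CUDA-graph capture warmup it falls through to the verbatim copy of the local tensor.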