diff --git a/src/prime_rl/inference/patches.py b/src/prime_rl/inference/patches.py
index 4f495aa423..3873b79965 100644
--- a/src/prime_rl/inference/patches.py
+++ b/src/prime_rl/inference/patches.py
@@ -18,6 +18,7 @@ def transformers_v5_compat():
     monkey_patch_deep_gemm_ep_scatter()
     monkey_patch_dp_engine_core_pause_resume_deadlock()
     monkey_patch_offloading_connector_cpu_block_count()
+    monkey_patch_routed_experts_capturer_mk_path()
 
 
 @triton.jit
@@ -1086,3 +1087,60 @@ def _patched__post_init__(self: FusedMoEConfig):
         self.is_lora_enabled = False
 
     FusedMoEConfig.__post_init__ = _patched__post_init__
+
+
+def monkey_patch_routed_experts_capturer_mk_path():
+    """Fix RoutedExpertsCapturer.capture() for the modular-kernel (DeepEP) path.
+
+    DefaultMoERunner has two MoE dispatch paths:
+    - naive: dp_size > 1 and not supports_internal_mk -> DP combine concats
+      all ranks' tokens BEFORE select_experts, so topk_ids.shape[0] equals
+      the cross-DP total.
+    - modular kernel (MK): supports_internal_mk=True (DeepEP, DEEPGEMM Fp8 MoE,
+      ...) -> DP combine happens INSIDE quant_method.apply, so select_experts
+      sees only this rank's tokens.
+
+    The shipped capture() in vLLM 0.19 hardcodes `assert cumsum[-1] ==
+    topk_ids.shape[0]`, which only holds for the naive path. With DeepEP every
+    DP worker trips the assert during CUDA-graph warmup and the engine cores die.
+
+    Mirrors the simpler post-refactor behavior on vLLM main (PR #39917): only
+    slice when topk_ids is the cross-DP concatenation; otherwise copy verbatim.
+    This handles the MK path (local tokens) AND CUDA-graph capture warmup
+    (where dp_metadata still claims the max but the captured batch is smaller),
+    both of which would trip the strict either/or check from the earlier
+    upstream fix (PR #37879).
+    """
+    from vllm.forward_context import get_forward_context
+    from vllm.model_executor.layers.fused_moe.routed_experts_capturer import (
+        RoutedExpertsCapturer,
+    )
+
+    def _patched_capture(self, layer_id: int, topk_ids: torch.Tensor) -> None:
+        if self._device_buffer is None:
+            raise RuntimeError("Buffer not initialized. Call init_buffer() first.")
+
+        if layer_id >= self._device_buffer.shape[1]:
+            return
+
+        ctx = get_forward_context()
+        n = topk_ids.shape[0]
+
+        if ctx.dp_metadata is not None:
+            num_tokens_dp = ctx.dp_metadata.num_tokens_across_dp_cpu
+            total = int(num_tokens_dp.sum().item())
+            if total > 0 and n == total:
+                # Naive dispatch: all DP ranks' tokens concatenated upstream.
+                token_num_per_dp = int(num_tokens_dp[self.dp_rank].item())
+                cumsum = torch.cumsum(num_tokens_dp, dim=0)
+                end_loc = int(cumsum[self.dp_rank].item())
+                start_loc = end_loc - token_num_per_dp
+                self._device_buffer[:token_num_per_dp, layer_id, :] = topk_ids[
+                    start_loc:end_loc, :
+                ]
+                return
+
+        # MK path / single-DP / CUDA-graph capture: copy the local tensor as-is.
+        self._device_buffer[:n, layer_id, :] = topk_ids
+
+    RoutedExpertsCapturer.capture = _patched_capture
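
For reference, the branch logic of _patched_capture can be exercised standalone. The sketch below is illustrative only and not part of the patch: it assumes a single layer and top_k=4, and the names num_tokens_dp, dp_rank, buffer, and capture are hypothetical stand-ins for vLLM's dp_metadata.num_tokens_across_dp_cpu, the capturer's dp_rank, and _device_buffer. Only the slicing decision mirrors the patched method.

    import torch

    # Tokens per DP rank (stand-in for num_tokens_across_dp_cpu); this rank is 1.
    num_tokens_dp = torch.tensor([3, 5, 2])
    dp_rank = 1
    # [max_local_tokens, num_layers, top_k] stand-in for _device_buffer.
    buffer = torch.zeros(8, 1, 4, dtype=torch.long)

    def capture(topk_ids: torch.Tensor) -> None:
        n = topk_ids.shape[0]
        total = int(num_tokens_dp.sum().item())
        if total > 0 and n == total:
            # Naive path: topk_ids holds all ranks' tokens; slice out ours.
            local = int(num_tokens_dp[dp_rank].item())
            end = int(torch.cumsum(num_tokens_dp, dim=0)[dp_rank].item())
            buffer[:local, 0, :] = topk_ids[end - local : end, :]
        else:
            # MK path / warmup: topk_ids already holds only local tokens.
            buffer[:n, 0, :] = topk_ids

    capture(torch.arange(40).reshape(10, 4))  # naive: 10 == 3+5+2, slices rows 3:8
    capture(torch.arange(20).reshape(5, 4))   # MK: 5 != 10, copied verbatim

The second call is the case that tripped the shipped assert: the tensor is smaller than the cross-DP total, so the patched logic copies it as-is instead of asserting.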