13 changes: 13 additions & 0 deletions packages/prime-rl-configs/src/prime_rl/configs/rl.py
@@ -791,6 +791,19 @@ def auto_setup_router_replay(self):
        )
        return self

    @model_validator(mode="after")
    def validate_router_replay_without_kv_offload(self):
        if (
            self.trainer.enable_router_replay
            and self.inference is not None
            and self.inference.kv_cache_offload is not None
        ):
            raise ValueError(
                "Router replay with inference.kv_cache_offload is not supported. "
                "External KV cache hits do not carry routed-expert decisions."
            )
        return self

    @model_validator(mode="after")
    def auto_setup_deployment(self):
        if self.deployment.type == "single_node":  # single-node
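For orientation, here is a self-contained sketch of the cross-field check the new validator adds. The stub models below are illustrative stand-ins, not the actual prime-rl config classes, and the `kv_cache_offload` payload is invented for the example:

```python
from pydantic import BaseModel, ValidationError, model_validator


class TrainerStub(BaseModel):
    enable_router_replay: bool = False


class InferenceStub(BaseModel):
    kv_cache_offload: dict | None = None


class RLConfigStub(BaseModel):
    trainer: TrainerStub = TrainerStub()
    inference: InferenceStub | None = None

    @model_validator(mode="after")
    def validate_router_replay_without_kv_offload(self):
        # Same cross-field rule as the diff: replay needs routing decisions
        # that externally cached KV hits never carry.
        if (
            self.trainer.enable_router_replay
            and self.inference is not None
            and self.inference.kv_cache_offload is not None
        ):
            raise ValueError("Router replay with inference.kv_cache_offload is not supported.")
        return self


try:
    RLConfigStub(
        trainer=TrainerStub(enable_router_replay=True),
        inference=InferenceStub(kv_cache_offload={"backend": "cpu"}),
    )
except ValidationError as exc:
    print(exc)  # surfaces the ValueError raised inside the validator
```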
10 changes: 6 additions & 4 deletions pyproject.toml
@@ -18,7 +18,7 @@ dependencies = [
    "torchaudio",
    "torchdata>=0.11.0",
    "transformers",
    "vllm>=0.20.2",
    "vllm",
    "wandb>=0.26.1",
    "ring-flash-attn>=0.1.8",
    "prime>=0.6.4",
@@ -36,6 +36,7 @@ dependencies = [
    "tilelang>=0.1.8",
    "flash-linear-attention",
    "nvidia-ml-py>=12.575.51",
    "pybase64>=1.4.2",
]

[project.scripts]
@@ -130,6 +131,7 @@ override-dependencies = [
[tool.uv.exclude-newer-package]
# we want latest vllm, remove next patch
vllm = false
tokenspeed-mla = false
flash_attn_3 = false
# Self-vendored packages on our primeintellect index
reverse-text = false
@@ -166,15 +168,15 @@ prime-rl-configs = { workspace = true }
torch = { index = "pytorch-cu128" }
torchvision = { index = "pytorch-cu128" }
torchaudio = { index = "pytorch-cu128" }
verifiers = { git = "https://github.com/PrimeIntellect-ai/verifiers.git", rev = "aa428f3" }
verifiers = { git = "https://github.com/PrimeIntellect-ai/verifiers", rev = "461a730" }
torchtitan = { git = "https://github.com/pytorch/torchtitan", rev = "a1fdd7e" }
dion = { git = "https://github.com/samsja/dion.git", rev = "d891eeb" }
transformers = { git = "https://github.com/huggingface/transformers.git", rev = "c1c3424" }
flash-attn-4 = { git = "https://github.com/Dao-AILab/flash-attention.git", subdirectory = "flash_attn/cute", rev = "96bd151" }
pydantic-config = { git = "https://github.com/samsja/pydantic_config.git", branch = "main" }
vllm-router = { url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.22/vllm_router-0.1.22-cp38-abi3-manylinux_2_28_x86_64.whl" }
vllm-router = { url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.25/vllm_router-0.1.25-cp38-abi3-manylinux_2_28_x86_64.whl" }
vllm = [
    { url = "https://github.com/vllm-project/vllm/releases/download/v0.20.2/vllm-0.20.2+cu129-cp38-abi3-manylinux_2_31_x86_64.whl", marker = "platform_machine == 'x86_64'" },
    { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.20.2rc1.dev354+g24337fb86.cu129-cp38-abi3-manylinux_2_34_x86_64.whl", marker = "platform_machine == 'x86_64'" },
    { url = "https://github.com/vllm-project/vllm/releases/download/v0.20.2/vllm-0.20.2+cu129-cp38-abi3-manylinux_2_31_aarch64.whl", marker = "platform_machine == 'aarch64'" },
]
reverse-text = { index = "primeintellect" }
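The new `pybase64` dependency backs the routed-experts serialization introduced later in this diff; it is call-compatible with the stdlib `base64` helpers it replaces. A minimal sketch of the round trip (the array shape is illustrative, not the real routing tensor layout):

```python
import numpy as np
import pybase64

# Hypothetical routing decisions as a small uint8 tensor; the real shape comes from the engine.
experts = np.arange(8, dtype=np.uint8).reshape(1, 4, 2)
encoded = pybase64.b64encode(memoryview(experts)).decode("ascii")
assert pybase64.b64decode(encoded) == experts.tobytes()
```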
62 changes: 58 additions & 4 deletions src/prime_rl/inference/patches.py
@@ -19,6 +19,51 @@ def transformers_v5_compat():
    monkey_patch_deep_gemm_silu_mul_quant_int64()
    monkey_patch_dp_engine_core_pause_resume_deadlock()
    monkey_patch_vllm_layerwise_reload_alias_buffers()
    monkey_patch_return_routed_experts_with_nixl_connector()


def monkey_patch_return_routed_experts_with_nixl_connector():
    from vllm import envs
    from vllm.config.vllm import VllmConfig
    from vllm.logger import init_logger

    logger = init_logger(__name__)
    original_post_init = VllmConfig.__post_init__

    if getattr(original_post_init, "_prime_rl_allows_nixl_routed_experts", False):
        return

    def _is_nixl_routed_experts_pd_config(config: VllmConfig) -> bool:
        kv_transfer_config = config.kv_transfer_config
        return (
            config.model_config is not None
            and config.model_config.enable_return_routed_experts
            and kv_transfer_config is not None
            and kv_transfer_config.kv_connector == "NixlConnector"
            and kv_transfer_config.is_kv_transfer_instance
        )

    def _post_init(config: VllmConfig):
        if not _is_nixl_routed_experts_pd_config(config):
            return original_post_init(config)

        if config.parallel_config.pipeline_parallel_size > 1:
            raise ValueError("--enable-return-routed-experts is incompatible with pipeline parallelism (PP > 1).")
        if envs.VLLM_USE_V2_MODEL_RUNNER:
            raise ValueError("VLLM_USE_V2_MODEL_RUNNER does not yet support: routed experts capture")

        # vLLM's __post_init__ rejects routed-experts capture with any KV connector,
        # but our P/D path uses NIXL and stitches prefill/decode routed experts in
        # the router. CPU KV offload remains rejected by prime-rl config validation.
        config.model_config.enable_return_routed_experts = False
        try:
            return original_post_init(config)
        finally:
            config.model_config.enable_return_routed_experts = True

    _post_init._prime_rl_allows_nixl_routed_experts = True
    VllmConfig.__post_init__ = _post_init
    logger.warning("Enabled vLLM routed-experts capture with NIXL connector patch.")


def monkey_patch_vllm_layerwise_reload_alias_buffers():
@@ -897,9 +942,9 @@ def monkey_patch_dp_engine_core_pause_resume_deadlock():
    - on resume, wake every DP rank and force an immediate global unfinished
      sync instead of waiting for the normal 32-step cadence

    This keeps the upstream pause-side fix from
    https://github.com/vllm-project/vllm/pull/37024 and extends it with the
    resume-side wave-state fix.
    This also bypasses vLLM's two-phase DP pause implementation
    (https://github.com/vllm-project/vllm/pull/39366), which makes resume
    reject states that our weight-update flow can validly hit.
    """
    from vllm.config import ParallelConfig
    from vllm.v1.core.sched.interface import PauseState

    _base_add_request = EngineCore.add_request
    _base_handle_client_request = EngineCoreProc._handle_client_request
    _base_resume_scheduler = DPEngineCoreProc.resume_scheduler
    _base_pause_complete = EngineCoreProc._pause_complete
    _base_resume_scheduler = EngineCoreProc.resume_scheduler

    def _patched_add_request(self, request: Request, request_wave: int = 0):
        _base_add_request(self, request, request_wave)
@@ -930,8 +976,15 @@ def _patched_handle_client_request(self, request_type, request):
        else:
            _base_handle_client_request(self, request_type, request)

    def _patched_pause_complete(self) -> bool:
        self.pending_pause = False
        self.ignore_start_dp_wave = False
        return _base_pause_complete(self)

    def _patched_resume_scheduler(self):
        was_paused = self.scheduler.pause_state != PauseState.UNPAUSED
        self.pending_pause = False
        self.ignore_start_dp_wave = False
        _base_resume_scheduler(self)
        if was_paused:
            self.engines_running = True
@@ -948,6 +1001,7 @@ def _patched_has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:

    DPEngineCoreProc.add_request = _patched_add_request
    DPEngineCoreProc._handle_client_request = _patched_handle_client_request
    DPEngineCoreProc._pause_complete = _patched_pause_complete
    DPEngineCoreProc.resume_scheduler = _patched_resume_scheduler
    DPEngineCoreProc._has_global_unfinished_reqs = _patched_has_global_unfinished_reqs

40 changes: 40 additions & 0 deletions src/prime_rl/inference/vllm/routed_experts.py
@@ -0,0 +1,40 @@
from __future__ import annotations

from collections.abc import AsyncIterator
from typing import Any

import numpy as np
import pybase64
from vllm.outputs import RequestOutput


def serialize_routed_experts(routed_experts: Any) -> dict[str, Any] | None:
    if routed_experts is None:
        return None

    array = np.asarray(routed_experts)
    assert array.ndim == 3
    assert np.issubdtype(array.dtype, np.integer)
    if array.size:
        assert array.min() >= 0
        assert array.max() <= np.iinfo(np.uint8).max

    compact = np.ascontiguousarray(array.astype(np.uint8, copy=False))
    return {
        "data": pybase64.b64encode(memoryview(compact)).decode("ascii"),
        "shape": list(compact.shape),
    }


class RoutedExpertsCapture:
    def __init__(self, generator: AsyncIterator[RequestOutput]):
        self._generator = generator
        self.routed_experts: dict[int, dict[str, Any]] = {}

    async def __aiter__(self):
        async for request_output in self._generator:
            for output in request_output.outputs:
                encoded = serialize_routed_experts(getattr(output, "routed_experts", None))
                if encoded is not None:
                    self.routed_experts[output.index] = encoded
            yield request_output
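Since the payload is just raw uint8 bytes plus a shape, decoding it on the consumer side is a short numpy call. A hedged sketch of the inverse helper (the function name is ours; the trainer-side decoder is not part of this diff):

```python
import numpy as np
import pybase64


def deserialize_routed_experts(payload: dict | None) -> np.ndarray | None:
    """Invert serialize_routed_experts: base64 text plus shape back to a uint8 ndarray."""
    if payload is None:
        return None
    raw = pybase64.b64decode(payload["data"], validate=True)
    return np.frombuffer(raw, dtype=np.uint8).reshape(payload["shape"])
```

Round-tripping the output of `serialize_routed_experts` through this recovers the original array values, cast to uint8.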
24 changes: 7 additions & 17 deletions src/prime_rl/inference/vllm/serving_chat_with_tokens.py
@@ -14,22 +14,11 @@
from vllm.reasoning import ReasoningParser
from vllm.sampling_params import BeamSearchParams, SamplingParams

from prime_rl.inference.vllm.serving_tokens import _RoutedExpertsCaptureBase
from prime_rl.inference.vllm.routed_experts import RoutedExpertsCapture

logger = init_logger(__name__)


class _RoutedExpertsCapture(_RoutedExpertsCaptureBase):
"""Chat-endpoint variant: mutates choices in-place because
``ChatCompletionResponseChoice`` is ``extra='allow'``, so an extra
``routed_experts`` attribute survives serialization."""

def post_process(self, response: ChatCompletionResponse) -> None:
for choice in response.choices:
if choice.index in self.routed_experts:
choice.routed_experts = self.routed_experts[choice.index]


class ChatCompletionRequestWithTokens(ChatCompletionRequest):
    field_names: ClassVar[Optional[set[str]]] = None
    tokens: list[int] = Field(description=("Prompt tokens to use for the request."))
@@ -55,11 +44,10 @@ async def chat_completion_full_generator(
        # 1. We create a custom generator that encapsulates the original result_generator in self._generator
        # 2. We override its __aiter__ method to also capture the routed experts as an extra field in ChatCompletionResponse.choices
        # 3. We override the full_generator method to use the custom generator instead of the original one if expert routing is enabled
        capture = None
        if self.model_config.enable_return_routed_experts:
            capture = _RoutedExpertsCapture(result_generator)
            capture = RoutedExpertsCapture(result_generator)
            result_generator = capture
        else:
            capture = None

        response = await super().chat_completion_full_generator(
            request,
@@ -72,8 +60,10 @@ async def chat_completion_full_generator(
            reasoning_parser,
        )

        if capture and isinstance(response, ChatCompletionResponse):
            capture.post_process(response)
        if capture is not None and isinstance(response, ChatCompletionResponse):
            for choice in response.choices:
                if choice.index in capture.routed_experts:
                    choice.routed_experts = capture.routed_experts[choice.index]

        return response

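The in-place attach above relies on vLLM's `ChatCompletionResponseChoice` being declared with `extra='allow'`, which is the point the removed helper docstring made. A minimal standalone illustration of that pydantic behavior, using a stand-in model rather than the vLLM class:

```python
from pydantic import BaseModel, ConfigDict


class ChoiceStub(BaseModel):
    model_config = ConfigDict(extra="allow")
    index: int


choice = ChoiceStub(index=0)
choice.routed_experts = {"data": "AAECAw==", "shape": [1, 2, 2]}  # unknown attribute is accepted
assert "routed_experts" in choice.model_dump()  # and it survives serialization
```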
65 changes: 14 additions & 51 deletions src/prime_rl/inference/vllm/serving_tokens.py
@@ -10,9 +10,9 @@
   header and forwarded to ``engine_client.generate``. The DP-replicated
   inference servers prime-RL runs need this to target a specific replica.

2. ``routed_experts`` per-token export — when the engine emits routing
   decisions (``enable_return_routed_experts``), surface them on each choice.
   This is what the trainer's router-replay path consumes.
2. Compact ``routed_experts`` export — when the engine emits routing
   decisions, surface them as base64 raw-byte payloads without requiring a vLLM
   source fork.

3. Server-side ``max_tokens`` defaulting — ``ServingTokens`` hands the
   client-supplied ``SamplingParams`` to the engine verbatim, and
@@ -30,13 +30,11 @@

from __future__ import annotations

import base64
from collections.abc import AsyncGenerator
from functools import cached_property
from typing import Any

import numpy as np
from fastapi import Request
from pydantic import Field
from vllm.entrypoints.openai.engine.protocol import ErrorResponse, RequestResponseMetadata
from vllm.entrypoints.serve.disagg.protocol import (
GenerateRequest,
@@ -48,55 +46,20 @@
from vllm.outputs import RequestOutput
from vllm.sampling_params import RequestOutputKind, SamplingParams

from prime_rl.inference.vllm.routed_experts import RoutedExpertsCapture


class PrimeRlGenerateResponseChoice(GenerateResponseChoice):
    routed_experts: dict | None = Field(
        default=None,
        description=(
            "Per-token expert routing decisions (base85-encoded int32 array + shape). "
            "Populated only when the engine was launched with "
            "``enable_return_routed_experts=True``; otherwise ``None``."
        ),
    )
    routed_experts: dict[str, Any] | None = None


class PrimeRlGenerateResponse(GenerateResponse):
    choices: list[PrimeRlGenerateResponseChoice]


def encode_routed_experts(arr: np.ndarray) -> dict:
    return {
        "data": base64.b85encode(arr.tobytes()).decode("ascii"),
        "shape": list(arr.shape),
    }


class _RoutedExpertsCaptureBase:
    """Wraps the engine result generator and accumulates a
    ``{output_index: encoded_experts}`` map as outputs stream. Subclasses
    implement ``post_process`` to fold the captured map into the response
    in whatever shape the endpoint returns (in-place vs rebuilt)."""

    def __init__(self, generator: AsyncGenerator[RequestOutput, None]):
        self._generator = generator
        self.routed_experts: dict[int, dict] = {}

    async def __aiter__(self):
        async for request_output in self._generator:
            for output in request_output.outputs:
                if output.routed_experts is not None:
                    self.routed_experts[output.index] = encode_routed_experts(output.routed_experts)
            yield request_output


class _RoutedExpertsCapture(_RoutedExpertsCaptureBase):
    """Generate-endpoint variant: rebuilds the response with
    ``PrimeRlGenerateResponseChoice`` because upstream's
    ``GenerateResponseChoice`` isn't ``extra='allow'``, so an attribute
    set after construction wouldn't survive serialization."""

class _GenerateRoutedExpertsCapture(RoutedExpertsCapture):
    def post_process(self, response: GenerateResponse) -> PrimeRlGenerateResponse:
        new_choices = [
        choices = [
            PrimeRlGenerateResponseChoice(
                **choice.model_dump(),
                routed_experts=self.routed_experts.get(choice.index),
@@ -105,7 +68,7 @@ def post_process(self, response: GenerateResponse) -> PrimeRlGenerateResponse:
        ]
        return PrimeRlGenerateResponse(
            request_id=response.request_id,
            choices=new_choices,
            choices=choices,
            prompt_logprobs=response.prompt_logprobs,
            kv_transfer_params=response.kv_transfer_params,
        )
@@ -135,7 +98,7 @@ async def _client_set_max_tokens(raw_request: Request | None) -> bool:


class PrimeRlServingTokens(ServingTokens):
"""ServingTokens + DP-rank routing + routed_experts export + max_tokens defaulting."""
"""ServingTokens + DP-rank routing + compact routed experts + max_tokens defaulting."""

    @cached_property
    def _max_tokens_defaults(self) -> tuple[dict, int | None]:
@@ -306,10 +269,10 @@ async def serve_tokens_full_generator(  # type: ignore[override]
        # encoded experts surface in the JSON. Skipping the wrapper when the
        # engine isn't producing routed experts keeps us a no-op subclass on
        # the common path.
        capture: _RoutedExpertsCapture | None = None
        capture: _GenerateRoutedExpertsCapture | None = None
        if self.model_config.enable_return_routed_experts:
            capture = _RoutedExpertsCapture(result_generator)
            result_generator = capture  # type: ignore[assignment]
            capture = _GenerateRoutedExpertsCapture(result_generator)
            result_generator = capture

        response = await super().serve_tokens_full_generator(
            request, result_generator, request_id, model_name, request_metadata