Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions packages/prime-rl-configs/src/prime_rl/configs/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,19 @@ class ClientConfig(BaseConfig):
),
] = None

backend: Annotated[
Literal["vllm", "dynamo"],
Field(
description=(
"Inference backend selector. Picks the AdminAPI implementation used for "
"pause/resume/update_weights/load_lora_adapter/list_models. Default 'vllm' "
"matches prime-rl's bundled vLLM frontend. 'dynamo' targets NVIDIA Dynamo's "
"native /v1/rl/* admin routes (DYN_ENABLE_RL=true) and routes /v1/models to "
"the OpenAI-compat root rather than under /v1/rl/."
),
),
] = "vllm"

elastic: Annotated[
ElasticConfig | None,
Field(
Expand Down
34 changes: 31 additions & 3 deletions src/prime_rl/utils/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,29 @@ async def init_broadcaster(
response.raise_for_status()


class DynamoAdminAPI(VLLMAdminAPI):
    """NVIDIA Dynamo admin endpoints (DYN_ENABLE_RL=true).

    Inherits all relative-path endpoints from VLLMAdminAPI -- with
    ``admin_base_url`` pointed at ``…/v1/rl``, httpx joins them onto the right
    place. Only ``list_models`` differs: Dynamo serves ``/v1/models`` at the
    OpenAI-compat root, not under ``/v1/rl/``, so we send an absolute URL to
    bypass the admin prefix.
    """

    async def list_models(self, client: AsyncClient) -> list[dict]:
        """Return the model entries served at the root-level ``/v1/models``.

        Args:
            client: Admin client; only the scheme/host/port of its
                ``base_url`` are reused, the path is replaced.

        Returns:
            The list stored under the ``"data"`` key of the JSON response.

        Raises:
            httpx.HTTPStatusError: If the server responds with a 4xx/5xx
                status, instead of failing later with an opaque ``KeyError``
                or JSON decode error on the error body.
        """
        # Replace the whole path (dropping the admin prefix) while keeping the
        # rest of the base URL; an absolute URL is not re-joined onto base_url.
        url = httpx.URL(client.base_url).copy_with(path="/v1/models")
        response = await client.get(str(url))
        # Surface HTTP failures explicitly before trying to parse the body.
        response.raise_for_status()
        return response.json()["data"]


def setup_admin_api(client_config: ClientConfig) -> AdminAPI:
    """Build the AdminAPI implementation selected by ``client_config.backend``.

    ``"dynamo"`` yields a ``DynamoAdminAPI``; every other value falls back to
    ``VLLMAdminAPI``.
    """
    # Choose the class first, then instantiate exactly one of them.
    admin_cls = DynamoAdminAPI if client_config.backend == "dynamo" else VLLMAdminAPI
    return admin_cls()


# Module-level AdminAPI instance backed by the vLLM implementation.
# NOTE(review): presumably the fallback for helpers' ``admin=`` parameter when
# no backend-specific API is passed -- confirm against the helper signatures.
_DEFAULT_ADMIN: AdminAPI = VLLMAdminAPI()


Expand Down Expand Up @@ -181,6 +204,7 @@ def __init__(
)
self._eval_clients = setup_clients(client_config, client_type=eval_client_type)
self._admin_clients = setup_admin_clients(client_config)
self._admin_api = setup_admin_api(client_config)
self._skip_model_check = client_config.skip_model_check
self._wait_for_ready_timeout = client_config.wait_for_ready_timeout
self._eval_cycle = cycle(self._eval_clients)
Expand All @@ -206,12 +230,16 @@ async def get_eval_client(self) -> vf.ClientConfig:

async def wait_for_ready(self, model_name: str, timeout: int | None = None) -> None:
await check_health(
self._admin_clients, timeout=timeout if timeout is not None else self._wait_for_ready_timeout
self._admin_clients,
timeout=timeout if timeout is not None else self._wait_for_ready_timeout,
admin=self._admin_api,
)
await maybe_check_has_model(
self._admin_clients, model_name, skip_model_check=self._skip_model_check, admin=self._admin_api
)
await maybe_check_has_model(self._admin_clients, model_name, skip_model_check=self._skip_model_check)

async def update_weights(self, weight_dir: Path | None, lora_name: str | None = None, step: int = 0) -> None:
await update_weights(self._admin_clients, weight_dir, lora_name=lora_name, step=step)
await update_weights(self._admin_clients, weight_dir, lora_name=lora_name, step=step, admin=self._admin_api)

def get_metrics(self) -> dict[str, float]:
return {}
Expand Down
15 changes: 7 additions & 8 deletions src/prime_rl/utils/elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from httpx import AsyncClient

from prime_rl.configs.shared import ClientConfig
from prime_rl.utils.client import load_lora_adapter, setup_admin_clients, setup_clients
from prime_rl.utils.client import load_lora_adapter, setup_admin_api, setup_admin_clients, setup_clients
from prime_rl.utils.logger import get_logger

# --- Shared discovery functions ---
Expand Down Expand Up @@ -129,6 +129,7 @@ def __init__(

self._servers: dict[str, ServerState] = {}
self._admin_clients: dict[str, AsyncClient] = {}
self._admin_api = setup_admin_api(client_config)
self._lock = asyncio.Lock()
self._desired: AdapterState = AdapterState()

Expand Down Expand Up @@ -335,7 +336,9 @@ async def _sync_server_adapter(self, ip: str) -> bool:
if self._desired.name and self._desired.path:
try:
self.logger.debug(f"Loading adapter {self._desired.name} on {ip}")
await load_lora_adapter([self._admin_clients[ip]], self._desired.name, self._desired.path)
await load_lora_adapter(
[self._admin_clients[ip]], self._desired.name, self._desired.path, admin=self._admin_api
)
except Exception as e:
server.status = "unhealthy"
server.sync_failures += 1
Expand Down Expand Up @@ -370,12 +373,8 @@ async def _check_server_health(self, admin_client: AsyncClient, ip: str) -> bool
return False

try:
response = await admin_client.get("/v1/models")
response.raise_for_status()
data = response.json()
models = [m.get("id") for m in data.get("data", [])]

if self.base_model_name not in models:
models = await self._admin_api.list_models(admin_client)
if self.base_model_name not in [m.get("id") for m in models]:
self.logger.debug(f"Server {ip} does not have base model {self.base_model_name}")
return False
except Exception as e:
Expand Down
Loading