diff --git a/packages/prime-rl-configs/src/prime_rl/configs/shared.py b/packages/prime-rl-configs/src/prime_rl/configs/shared.py index d26c33d9a9..acc2c5935c 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/shared.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/shared.py @@ -310,6 +310,19 @@ class ClientConfig(BaseConfig): ), ] = None + backend: Annotated[ + Literal["vllm", "dynamo"], + Field( + description=( + "Inference backend selector. Picks the AdminAPI implementation used for " + "pause/resume/update_weights/load_lora_adapter/list_models. Default 'vllm' " + "matches prime-rl's bundled vLLM frontend. 'dynamo' targets NVIDIA Dynamo's " + "native /v1/rl/* admin routes (DYN_ENABLE_RL=true) and routes /v1/models to " + "the OpenAI-compat root rather than under /v1/rl/." + ), + ), + ] = "vllm" + elastic: Annotated[ ElasticConfig | None, Field( diff --git a/src/prime_rl/utils/client.py b/src/prime_rl/utils/client.py index 1dcc6dcfb9..c175b79e60 100644 --- a/src/prime_rl/utils/client.py +++ b/src/prime_rl/utils/client.py @@ -113,6 +113,29 @@ async def init_broadcaster( response.raise_for_status() +class DynamoAdminAPI(VLLMAdminAPI): + """NVIDIA Dynamo admin endpoints (DYN_ENABLE_RL=true). + + Inherits all relative-path endpoints from VLLMAdminAPI -- with + ``admin_base_url`` pointed at ``…/v1/rl``, httpx joins them onto the right + place. Only ``list_models`` differs: Dynamo serves ``/v1/models`` at the + OpenAI-compat root, not under ``/v1/rl/``, so we send an absolute URL to + bypass the admin prefix. + """ + + async def list_models(self, client: AsyncClient) -> list[dict]: + url = httpx.URL(client.base_url).copy_with(path="/v1/models") + response = await client.get(str(url)) + return response.json()["data"] + + +def setup_admin_api(client_config: ClientConfig) -> AdminAPI: + """Pick the AdminAPI implementation that matches ``client_config.backend``.""" + if client_config.backend == "dynamo": + return DynamoAdminAPI() + return VLLMAdminAPI() + + _DEFAULT_ADMIN: AdminAPI = VLLMAdminAPI() @@ -181,6 +204,7 @@ def __init__( ) self._eval_clients = setup_clients(client_config, client_type=eval_client_type) self._admin_clients = setup_admin_clients(client_config) + self._admin_api = setup_admin_api(client_config) self._skip_model_check = client_config.skip_model_check self._wait_for_ready_timeout = client_config.wait_for_ready_timeout self._eval_cycle = cycle(self._eval_clients) @@ -206,12 +230,16 @@ async def get_eval_client(self) -> vf.ClientConfig: async def wait_for_ready(self, model_name: str, timeout: int | None = None) -> None: await check_health( - self._admin_clients, timeout=timeout if timeout is not None else self._wait_for_ready_timeout + self._admin_clients, + timeout=timeout if timeout is not None else self._wait_for_ready_timeout, + admin=self._admin_api, + ) + await maybe_check_has_model( + self._admin_clients, model_name, skip_model_check=self._skip_model_check, admin=self._admin_api ) - await maybe_check_has_model(self._admin_clients, model_name, skip_model_check=self._skip_model_check) async def update_weights(self, weight_dir: Path | None, lora_name: str | None = None, step: int = 0) -> None: - await update_weights(self._admin_clients, weight_dir, lora_name=lora_name, step=step) + await update_weights(self._admin_clients, weight_dir, lora_name=lora_name, step=step, admin=self._admin_api) def get_metrics(self) -> dict[str, float]: return {} diff --git a/src/prime_rl/utils/elastic.py b/src/prime_rl/utils/elastic.py index 902f873903..8c438f1e64 100644 --- a/src/prime_rl/utils/elastic.py +++ b/src/prime_rl/utils/elastic.py @@ -20,7 +20,7 @@ from httpx import AsyncClient from prime_rl.configs.shared import ClientConfig -from prime_rl.utils.client import load_lora_adapter, setup_admin_clients, setup_clients +from prime_rl.utils.client import load_lora_adapter, setup_admin_api, setup_admin_clients, setup_clients from prime_rl.utils.logger import get_logger # --- Shared discovery functions --- @@ -129,6 +129,7 @@ def __init__( self._servers: dict[str, ServerState] = {} self._admin_clients: dict[str, AsyncClient] = {} + self._admin_api = setup_admin_api(client_config) self._lock = asyncio.Lock() self._desired: AdapterState = AdapterState() @@ -335,7 +336,9 @@ async def _sync_server_adapter(self, ip: str) -> bool: if self._desired.name and self._desired.path: try: self.logger.debug(f"Loading adapter {self._desired.name} on {ip}") - await load_lora_adapter([self._admin_clients[ip]], self._desired.name, self._desired.path) + await load_lora_adapter( + [self._admin_clients[ip]], self._desired.name, self._desired.path, admin=self._admin_api + ) except Exception as e: server.status = "unhealthy" server.sync_failures += 1 @@ -370,12 +373,8 @@ async def _check_server_health(self, admin_client: AsyncClient, ip: str) -> bool return False try: - response = await admin_client.get("/v1/models") - response.raise_for_status() - data = response.json() - models = [m.get("id") for m in data.get("data", [])] - - if self.base_model_name not in models: + models = await self._admin_api.list_models(admin_client) + if self.base_model_name not in [m.get("id") for m in models]: self.logger.debug(f"Server {ip} does not have base model {self.base_model_name}") return False except Exception as e: