diff --git a/renderers/__init__.py b/renderers/__init__.py
index 6b2f225..d0f7fda 100644
--- a/renderers/__init__.py
+++ b/renderers/__init__.py
@@ -19,6 +19,7 @@
     reject_assistant_in_extension,
     trim_to_turn_close,
 )
+from renderers.client import RendererTransport
 from renderers.deepseek_v3 import DeepSeekV3Renderer
 from renderers.default import DefaultRenderer
 from renderers.glm5 import GLM5Renderer
@@ -55,6 +56,7 @@
     "RenderedTokens",
     "Renderer",
     "RendererPool",
+    "RendererTransport",
     "TextPart",
     "ThinkingPart",
     "ToolCall",
diff --git a/renderers/client.py b/renderers/client.py
index b4e91fd..2efcbfa 100644
--- a/renderers/client.py
+++ b/renderers/client.py
@@ -14,13 +14,15 @@
 import asyncio
 import base64
 import logging
-from typing import Any, cast
+from typing import Any, Literal, cast
 
 import numpy as np
 from openai import AsyncOpenAI, BadRequestError
 
 from renderers.base import Message, Renderer, RendererPool, ToolSpec
 
+RendererTransport = Literal["prime_vllm_generate", "dynamo_chat_nvext"]
+
 _request_logger = logging.getLogger("renderers.client")
 
 
@@ -44,8 +46,9 @@ async def generate(
     cache_salt: str | None = None,
     priority: int | None = None,
     extra_headers: dict[str, str] | None = None,
+    transport: RendererTransport = "prime_vllm_generate",
 ) -> dict[str, Any]:
-    """Tokenize messages, call vLLM /inference/v1/generate, parse the response.
+    """Tokenize messages, call the selected token-in backend, parse response.
 
     ``sampling_params`` is forwarded to vLLM verbatim. Two fields are always
     set by us and override caller values: ``stop_token_ids`` (from the
@@ -82,24 +85,57 @@ def _prepare(r: Renderer):
         sp["logprobs"] = 1
     sp.setdefault("skip_special_tokens", False)
 
-    body: dict[str, Any] = {
-        "model": model,
-        "token_ids": prompt_ids,
-        "sampling_params": sp,
-    }
-    if cache_salt is not None:
-        body["cache_salt"] = cache_salt
-    if priority is not None:
-        body["priority"] = priority
-
-    # /inference/v1/generate is mounted at the server root, not under /v1
-    # like the OpenAI-compatible endpoints. Build an absolute URL so the
-    # AsyncOpenAI client doesn't prepend its automatic /v1.
-    base = str(client.base_url).rstrip("/").removesuffix("/v1")
-    endpoint = f"{base}/inference/v1/generate"
+    if transport == "prime_vllm_generate":
+        body: dict[str, Any] = {
+            "model": model,
+            "token_ids": prompt_ids,
+            "sampling_params": sp,
+        }
+        if cache_salt is not None:
+            body["cache_salt"] = cache_salt
+        if priority is not None:
+            body["priority"] = priority
+
+        # /inference/v1/generate is mounted at the server root, not under /v1
+        # like the OpenAI-compatible endpoints. Build an absolute URL so the
+        # AsyncOpenAI client doesn't prepend its automatic /v1.
+        base = str(client.base_url).rstrip("/").removesuffix("/v1")
+        endpoint = f"{base}/inference/v1/generate"
+    elif transport == "dynamo_chat_nvext":
+        nvext: dict[str, Any] = {
+            "token_data": prompt_ids,
+            "extra_fields": ["completion_token_ids"],
+        }
+        if priority is not None:
+            nvext["agent_hints"] = {"priority": priority}
+
+        body = {
+            "model": model,
+            "messages": [{"role": "user", "content": "(token-in mode)"}],
+            "stream": False,
+            "logprobs": True,
+            "stop_token_ids": stop_token_ids,
+            "nvext": nvext,
+        }
+        if cache_salt is not None:
+            body["cache_salt"] = cache_salt
+
+        passthrough = dict(sp)
+        passthrough.pop("stop_token_ids", None)
+        passthrough.pop("logprobs", None)
+        passthrough.pop("skip_special_tokens", None)
+        max_tokens = passthrough.pop("max_tokens", None)
+        if max_tokens is not None:
+            body["max_completion_tokens"] = max_tokens
+        body.update({k: v for k, v in passthrough.items() if v is not None})
+        endpoint = "/chat/completions"
+    else:
+        raise ValueError(f"Unsupported renderer transport: {transport}")
+
     _request_logger.debug(
-        "POST %s prompt_len=%d max_tokens=%s",
+        "POST %s transport=%s prompt_len=%d max_tokens=%s",
         endpoint,
+        transport,
         len(prompt_ids),
         sp.get("max_tokens"),
     )
@@ -121,7 +157,23 @@ def _prepare(r: Renderer):
         raise
 
     choice = (data.get("choices") or [{}])[0]
-    completion_ids = choice.get("token_ids") or []
+    if transport == "dynamo_chat_nvext":
+        completion_ids = (
+            choice.get("token_ids")
+            or choice.get("nvext", {}).get("completion_token_ids")
+            or data.get("nvext", {}).get("completion_token_ids")
+            or []
+        )
+        raw_re = (
+            choice.get("routed_experts")
+            or choice.get("nvext", {}).get("routed_experts")
+            or data.get("nvext", {}).get("routed_experts")
+        )
+        request_id = data.get("id") or data.get("request_id") or ""
+    else:
+        completion_ids = choice.get("token_ids") or []
+        raw_re = choice.get("routed_experts")
+        request_id = data.get("request_id") or ""
 
     if pool is not None:
         parsed = await _run_pooled(pool, lambda r: r.parse_response(completion_ids))
@@ -134,7 +186,6 @@ def _prepare(r: Renderer):
     completion_logprobs = [float(c.get("logprob") or 0.0) for c in content_lp or []]
 
     routed_experts = None
-    raw_re = choice.get("routed_experts")
     if isinstance(raw_re, dict) and "data" in raw_re and "shape" in raw_re:
         routed_experts = (
             np.frombuffer(base64.b85decode(raw_re["data"]), dtype=np.int32)
@@ -152,7 +203,7 @@ def _prepare(r: Renderer):
         finish_reason = "tool_calls"
 
     return {
-        "request_id": data.get("request_id") or "",
+        "request_id": request_id,
         "prompt_ids": list(prompt_ids),
         "completion_ids": list(completion_ids),
         "completion_logprobs": completion_logprobs,
diff --git a/tests/test_client.py b/tests/test_client.py
index e0093bd..93bbdd2 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -40,14 +40,17 @@ class _FakeClient:
     URL off ``client.base_url``, so we expose one that includes the /v1
     suffix the OpenAI SDK normally appends."""
 
-    def __init__(self):
+    def __init__(self, response=None):
         self.calls = []
         self.base_url = "http://fake-host:8000/v1"
+        self.response = response
 
     async def post(self, path, *, cast_to=dict, body=None, options=None):
         self.calls.append(
             {"path": path, "cast_to": cast_to, "body": body, "options": options}
         )
+        if self.response is not None:
+            return self.response
         routed_experts = np.array([[[1]], [[2]]], dtype=np.int32)
         return {
             "request_id": "gen-test",
@@ -147,3 +150,61 @@ def test_generate_uses_prebuilt_prompt_ids_without_rendering():
 
     assert client.calls[0]["body"]["token_ids"] == [11, 12, 13]
     assert result["prompt_ids"] == [11, 12, 13]
+
+
+def test_generate_can_use_dynamo_chat_nvext_transport():
+    client = _FakeClient(
+        response={
+            "id": "chatcmpl-test",
+            "model": "test-model",
+            "nvext": {"completion_token_ids": [7, 8]},
+            "choices": [
+                {
+                    "logprobs": {
+                        "content": [
+                            {"token": "token_id:7", "logprob": -0.1},
+                            {"token": "token_id:8", "logprob": -0.2},
+                        ]
+                    },
+                    "finish_reason": "stop",
+                }
+            ],
+        }
+    )
+
+    result = asyncio.run(
+        generate(
+            client=client,
+            renderer=_FakeRenderer(),
+            messages=[{"role": "user", "content": "hi"}],
+            model="test-model",
+            tools=[{"type": "function", "function": {"name": "echo"}}],
+            sampling_params={"temperature": 0.3, "max_tokens": 7, "min_tokens": 2},
+            priority=4,
+            cache_salt="ckpt-42",
+            transport="dynamo_chat_nvext",
+        )
+    )
+
+    assert client.calls[0]["path"] == "/chat/completions"
+    assert client.calls[0]["body"] == {
+        "model": "test-model",
+        "messages": [{"role": "user", "content": "(token-in mode)"}],
+        "stream": False,
+        "logprobs": True,
+        "stop_token_ids": [99],
+        "nvext": {
+            "token_data": [1, 2, 3],
+            "extra_fields": ["completion_token_ids"],
+            "agent_hints": {"priority": 4},
+        },
+        "cache_salt": "ckpt-42",
+        "max_completion_tokens": 7,
+        "temperature": 0.3,
+        "min_tokens": 2,
+    }
+    assert result["request_id"] == "chatcmpl-test"
+    assert result["prompt_ids"] == [1, 2, 3]
+    assert result["completion_ids"] == [7, 8]
+    assert result["completion_logprobs"] == [-0.1, -0.2]
+    assert result["finish_reason"] == "tool_calls"