2 changes: 2 additions & 0 deletions renderers/__init__.py
@@ -19,6 +19,7 @@
reject_assistant_in_extension,
trim_to_turn_close,
)
from renderers.client import RendererTransport
from renderers.deepseek_v3 import DeepSeekV3Renderer
from renderers.default import DefaultRenderer
from renderers.glm5 import GLM5Renderer
@@ -55,6 +56,7 @@
"RenderedTokens",
"Renderer",
"RendererPool",
"RendererTransport",
"TextPart",
"ThinkingPart",
"ToolCall",
93 changes: 72 additions & 21 deletions renderers/client.py
@@ -14,13 +14,15 @@
import asyncio
import base64
import logging
from typing import Any, cast
from typing import Any, Literal, cast

import numpy as np
from openai import AsyncOpenAI, BadRequestError

from renderers.base import Message, Renderer, RendererPool, ToolSpec

RendererTransport = Literal["prime_vllm_generate", "dynamo_chat_nvext"]
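# Selects the token-in backend generate() talks to: vLLM's /inference/v1/generate
# route ("prime_vllm_generate") or a Dynamo chat completions endpoint that accepts
# token input via the nvext extension ("dynamo_chat_nvext").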

_request_logger = logging.getLogger("renderers.client")


@@ -44,8 +46,9 @@ async def generate(
cache_salt: str | None = None,
priority: int | None = None,
extra_headers: dict[str, str] | None = None,
transport: RendererTransport = "prime_vllm_generate",
) -> dict[str, Any]:
"""Tokenize messages, call vLLM /inference/v1/generate, parse the response.
"""Tokenize messages, call the selected token-in backend, parse response.

``sampling_params`` is forwarded to vLLM verbatim. Two fields are always
set by us and override caller values: ``stop_token_ids`` (from the
@@ -82,24 +85,57 @@ def _prepare(r: Renderer):
sp["logprobs"] = 1
sp.setdefault("skip_special_tokens", False)

body: dict[str, Any] = {
"model": model,
"token_ids": prompt_ids,
"sampling_params": sp,
}
if cache_salt is not None:
body["cache_salt"] = cache_salt
if priority is not None:
body["priority"] = priority

# /inference/v1/generate is mounted at the server root, not under /v1
# like the OpenAI-compatible endpoints. Build an absolute URL so the
# AsyncOpenAI client doesn't prepend its automatic /v1.
base = str(client.base_url).rstrip("/").removesuffix("/v1")
endpoint = f"{base}/inference/v1/generate"
if transport == "prime_vllm_generate":
body: dict[str, Any] = {
"model": model,
"token_ids": prompt_ids,
"sampling_params": sp,
}
if cache_salt is not None:
body["cache_salt"] = cache_salt
if priority is not None:
body["priority"] = priority

# /inference/v1/generate is mounted at the server root, not under /v1
# like the OpenAI-compatible endpoints. Build an absolute URL so the
# AsyncOpenAI client doesn't prepend its automatic /v1.
base = str(client.base_url).rstrip("/").removesuffix("/v1")
endpoint = f"{base}/inference/v1/generate"
elif transport == "dynamo_chat_nvext":
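# Dynamo chat path: nvext.token_data carries the pre-rendered prompt token IDs,
# and extra_fields asks the server to echo completion token IDs in the response.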
nvext: dict[str, Any] = {
"token_data": prompt_ids,
"extra_fields": ["completion_token_ids"],
}
if priority is not None:
nvext["agent_hints"] = {"priority": priority}

body = {
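# "messages" is a chat-shaped placeholder; the prompt is expected to come from
# nvext.token_data rather than from this content.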
"model": model,
"messages": [{"role": "user", "content": "(token-in mode)"}],
"stream": False,
"logprobs": True,
"stop_token_ids": stop_token_ids,
"nvext": nvext,
}
if cache_salt is not None:
body["cache_salt"] = cache_salt

passthrough = dict(sp)
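# Forward the remaining sampling params as top-level chat-completion fields:
# stop_token_ids and logprobs are already set on the body above,
# skip_special_tokens is dropped, and max_tokens maps to max_completion_tokens.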
passthrough.pop("stop_token_ids", None)
passthrough.pop("logprobs", None)
passthrough.pop("skip_special_tokens", None)
max_tokens = passthrough.pop("max_tokens", None)
if max_tokens is not None:
body["max_completion_tokens"] = max_tokens
body.update({k: v for k, v in passthrough.items() if v is not None})
endpoint = "/chat/completions"
else:
raise ValueError(f"Unsupported renderer transport: {transport}")

_request_logger.debug(
"POST %s prompt_len=%d max_tokens=%s",
"POST %s transport=%s prompt_len=%d max_tokens=%s",
endpoint,
transport,
len(prompt_ids),
sp.get("max_tokens"),
)
@@ -121,7 +157,23 @@ def _prepare(r: Renderer):
raise

choice = (data.get("choices") or [{}])[0]
completion_ids = choice.get("token_ids") or []
if transport == "dynamo_chat_nvext":
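# Echoed completion token IDs may appear on the choice itself or inside a
# choice-level or response-level nvext blob; try each location in turn.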
completion_ids = (
choice.get("token_ids")
or choice.get("nvext", {}).get("completion_token_ids")
or data.get("nvext", {}).get("completion_token_ids")
or []
)
raw_re = (
choice.get("routed_experts")
or choice.get("nvext", {}).get("routed_experts")
or data.get("nvext", {}).get("routed_experts")
)
request_id = data.get("id") or data.get("request_id") or ""
else:
completion_ids = choice.get("token_ids") or []
raw_re = choice.get("routed_experts")
request_id = data.get("request_id") or ""

if pool is not None:
parsed = await _run_pooled(pool, lambda r: r.parse_response(completion_ids))
@@ -134,7 +186,6 @@ def _prepare(r: Renderer):
completion_logprobs = [float(c.get("logprob") or 0.0) for c in content_lp or []]

routed_experts = None
raw_re = choice.get("routed_experts")
if isinstance(raw_re, dict) and "data" in raw_re and "shape" in raw_re:
routed_experts = (
np.frombuffer(base64.b85decode(raw_re["data"]), dtype=np.int32)
@@ -152,7 +203,7 @@ def _prepare(r: Renderer):
finish_reason = "tool_calls"

return {
"request_id": data.get("request_id") or "",
"request_id": request_id,
"prompt_ids": list(prompt_ids),
"completion_ids": list(completion_ids),
"completion_logprobs": completion_logprobs,
63 changes: 62 additions & 1 deletion tests/test_client.py
@@ -40,14 +40,17 @@ class _FakeClient:
URL off ``client.base_url``, so we expose one that includes the /v1 suffix
the OpenAI SDK normally appends."""

def __init__(self):
def __init__(self, response=None):
self.calls = []
self.base_url = "http://fake-host:8000/v1"
self.response = response

async def post(self, path, *, cast_to=dict, body=None, options=None):
self.calls.append(
{"path": path, "cast_to": cast_to, "body": body, "options": options}
)
if self.response is not None:
return self.response
routed_experts = np.array([[[1]], [[2]]], dtype=np.int32)
return {
"request_id": "gen-test",
@@ -147,3 +150,61 @@ def test_generate_uses_prebuilt_prompt_ids_without_rendering():

assert client.calls[0]["body"]["token_ids"] == [11, 12, 13]
assert result["prompt_ids"] == [11, 12, 13]


def test_generate_can_use_dynamo_chat_nvext_transport():
client = _FakeClient(
response={
"id": "chatcmpl-test",
"model": "test-model",
"nvext": {"completion_token_ids": [7, 8]},
"choices": [
{
"logprobs": {
"content": [
{"token": "token_id:7", "logprob": -0.1},
{"token": "token_id:8", "logprob": -0.2},
]
},
"finish_reason": "stop",
}
],
}
)

result = asyncio.run(
generate(
client=client,
renderer=_FakeRenderer(),
messages=[{"role": "user", "content": "hi"}],
model="test-model",
tools=[{"type": "function", "function": {"name": "echo"}}],
sampling_params={"temperature": 0.3, "max_tokens": 7, "min_tokens": 2},
priority=4,
cache_salt="ckpt-42",
transport="dynamo_chat_nvext",
)
)

assert client.calls[0]["path"] == "/chat/completions"
assert client.calls[0]["body"] == {
"model": "test-model",
"messages": [{"role": "user", "content": "(token-in mode)"}],
"stream": False,
"logprobs": True,
"stop_token_ids": [99],
"nvext": {
"token_data": [1, 2, 3],
"extra_fields": ["completion_token_ids"],
"agent_hints": {"priority": 4},
},
"cache_salt": "ckpt-42",
"max_completion_tokens": 7,
"temperature": 0.3,
"min_tokens": 2,
}
assert result["request_id"] == "chatcmpl-test"
assert result["prompt_ids"] == [1, 2, 3]
assert result["completion_ids"] == [7, 8]
assert result["completion_logprobs"] == [-0.1, -0.2]
assert result["finish_reason"] == "tool_calls"