Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1269,6 +1269,12 @@ def validate_renderer_args(self):
renderer_args_set.append(f"renderer.reasoning_parser={self.renderer.reasoning_parser!r}")
if self.renderer.pool_size is not None:
renderer_args_set.append(f"renderer.pool_size={self.renderer.pool_size!r}")
if self.renderer.preserve_all_thinking:
renderer_args_set.append(f"renderer.preserve_all_thinking={self.renderer.preserve_all_thinking!r}")
if self.renderer.preserve_thinking_between_tool_calls:
renderer_args_set.append(
f"renderer.preserve_thinking_between_tool_calls={self.renderer.preserve_thinking_between_tool_calls!r}"
)

if renderer_args_set:
raise ValueError(
Expand Down
6 changes: 6 additions & 0 deletions packages/prime-rl-configs/src/prime_rl/configs/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,12 @@ def validate_renderer_args(self):
renderer_args_set.append(f"renderer.tool_parser={self.renderer.tool_parser!r}")
if self.renderer.reasoning_parser is not None:
renderer_args_set.append(f"renderer.reasoning_parser={self.renderer.reasoning_parser!r}")
if self.renderer.preserve_all_thinking:
renderer_args_set.append(f"renderer.preserve_all_thinking={self.renderer.preserve_all_thinking!r}")
if self.renderer.preserve_thinking_between_tool_calls:
renderer_args_set.append(
f"renderer.preserve_thinking_between_tool_calls={self.renderer.preserve_thinking_between_tool_calls!r}"
)

if renderer_args_set:
raise ValueError(
Expand Down
31 changes: 31 additions & 0 deletions packages/prime-rl-configs/src/prime_rl/configs/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,37 @@ class RendererConfig(BaseConfig):
),
] = None

preserve_all_thinking: Annotated[
bool,
Field(
description=(
"Override flag forwarded to the renderer at construction. When "
"True, every past-assistant turn's ``reasoning_content`` is "
"re-emitted between ``<think>``/``</think>`` (or the model's "
"equivalent), even if the underlying chat template would drop "
"it. Off by default — preserves byte-identical output to the "
"stock template. Strict superset of "
"``preserve_thinking_between_tool_calls``."
),
),
] = False

preserve_thinking_between_tool_calls: Annotated[
bool,
Field(
description=(
"Override flag forwarded to the renderer at construction. When "
"True, preserves past-assistant ``reasoning_content`` only "
"inside the *current* tool cycle — the contiguous "
"assistant→tool→…→assistant block after the most recent user "
"message, when that block contains at least one tool response. "
"A new user turn closes the block; older blocks fall back to "
"the template default (typically dropped). Use "
"``preserve_all_thinking`` to keep older blocks too."
),
),
] = False
Comment thread
cursor[bot] marked this conversation as resolved.


class ElasticConfig(BaseConfig):
"""Configures elastic inference pool with DNS-based service discovery.
Expand Down
4 changes: 4 additions & 0 deletions src/prime_rl/orchestrator/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -926,6 +926,8 @@ async def setup_rollout_inference_pool(
renderer=config.renderer.name,
tool_parser=config.renderer.tool_parser,
reasoning_parser=config.renderer.reasoning_parser,
preserve_all_thinking=config.renderer.preserve_all_thinking,
preserve_thinking_between_tool_calls=config.renderer.preserve_thinking_between_tool_calls,
)
logger.info(f"Initialized {type(renderer).__name__} for {config.model.name}")
inference_pool = await setup_inference_pool(
Expand All @@ -937,6 +939,8 @@ async def setup_rollout_inference_pool(
tool_parser=config.renderer.tool_parser,
reasoning_parser=config.renderer.reasoning_parser,
renderer_pool_size=config.renderer.pool_size,
preserve_all_thinking=config.renderer.preserve_all_thinking,
preserve_thinking_between_tool_calls=config.renderer.preserve_thinking_between_tool_calls,
)
logger.info("Using direct renderer rollout client")
return renderer, inference_pool
Expand Down
2 changes: 2 additions & 0 deletions src/prime_rl/trainer/sft/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@ def train(config: SFTConfig):
renderer=config.renderer.name,
tool_parser=config.renderer.tool_parser,
reasoning_parser=config.renderer.reasoning_parser,
preserve_all_thinking=config.renderer.preserve_all_thinking,
preserve_thinking_between_tool_calls=config.renderer.preserve_thinking_between_tool_calls,
)
if isinstance(renderer, DefaultRenderer):
raise ValueError(
Expand Down
22 changes: 22 additions & 0 deletions src/prime_rl/utils/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ def __init__(
tool_parser: str | None = None,
reasoning_parser: str | None = None,
renderer_pool_size: int | None = None,
preserve_all_thinking: bool = False,
preserve_thinking_between_tool_calls: bool = False,
):
renderer_model_name = model_name if train_client_type == "renderer" else None
self._train_clients = setup_clients(
Expand All @@ -78,6 +80,8 @@ def __init__(
tool_parser=tool_parser,
reasoning_parser=reasoning_parser,
renderer_pool_size=renderer_pool_size,
preserve_all_thinking=preserve_all_thinking,
preserve_thinking_between_tool_calls=preserve_thinking_between_tool_calls,
)
self._eval_clients = setup_clients(client_config, client_type=eval_client_type)
self._admin_clients = setup_admin_clients(client_config)
Expand Down Expand Up @@ -129,6 +133,8 @@ async def setup_inference_pool(
tool_parser: str | None = None,
reasoning_parser: str | None = None,
renderer_pool_size: int | None = None,
preserve_all_thinking: bool = False,
preserve_thinking_between_tool_calls: bool = False,
) -> InferencePool:
"""Create an inference pool from config (static or elastic)."""
logger = get_logger()
Expand All @@ -152,6 +158,8 @@ async def setup_inference_pool(
tool_parser=tool_parser,
reasoning_parser=reasoning_parser,
renderer_pool_size=renderer_pool_size,
preserve_all_thinking=preserve_all_thinking,
preserve_thinking_between_tool_calls=preserve_thinking_between_tool_calls,
)

logger.info(
Expand All @@ -168,6 +176,8 @@ async def setup_inference_pool(
tool_parser=tool_parser,
reasoning_parser=reasoning_parser,
renderer_pool_size=renderer_pool_size,
preserve_all_thinking=preserve_all_thinking,
preserve_thinking_between_tool_calls=preserve_thinking_between_tool_calls,
)


Expand All @@ -179,9 +189,20 @@ def setup_clients(
tool_parser: str | None = None,
reasoning_parser: str | None = None,
renderer_pool_size: int | None = None,
preserve_all_thinking: bool = False,
preserve_thinking_between_tool_calls: bool = False,
) -> list[vf.ClientConfig]:
clients = []
client_idx = 0
# Only forward preserve flags when the client actually uses a renderer —
# MITO/TITO clients ignore them and the verifiers ClientConfig may reject
# unknown extras on older versions.
renderer_extra: dict = {}
if client_type == "renderer":
renderer_extra = {
"preserve_all_thinking": preserve_all_thinking,
"preserve_thinking_between_tool_calls": preserve_thinking_between_tool_calls,
}
for base_url in client_config.base_url:
for dp_rank in range(client_config.dp_rank_count):
headers = client_config.headers.copy()
Expand All @@ -205,6 +226,7 @@ def setup_clients(
max_retries=10,
extra_headers=headers,
extra_headers_from_state=client_config.extra_headers_from_state,
**renderer_extra,
)
)
client_idx += 1
Expand Down
10 changes: 10 additions & 0 deletions src/prime_rl/utils/elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ def __init__(
tool_parser: str | None = None,
reasoning_parser: str | None = None,
renderer_pool_size: int | None = None,
preserve_all_thinking: bool = False,
preserve_thinking_between_tool_calls: bool = False,
):
self.logger = get_logger()
self.client_config = client_config
Expand All @@ -125,6 +127,8 @@ def __init__(
self.tool_parser = tool_parser
self.reasoning_parser = reasoning_parser
self.renderer_pool_size = renderer_pool_size
self.preserve_all_thinking = preserve_all_thinking
self.preserve_thinking_between_tool_calls = preserve_thinking_between_tool_calls
self.router_url = client_config.router_url

self._servers: dict[str, ServerState] = {}
Expand Down Expand Up @@ -152,6 +156,8 @@ async def from_config(
tool_parser: str | None = None,
reasoning_parser: str | None = None,
renderer_pool_size: int | None = None,
preserve_all_thinking: bool = False,
preserve_thinking_between_tool_calls: bool = False,
) -> ElasticInferencePool:
if client_config.elastic is None:
raise ValueError("Elastic inference pool requires elastic config")
Expand All @@ -164,6 +170,8 @@ async def from_config(
tool_parser=tool_parser,
reasoning_parser=reasoning_parser,
renderer_pool_size=renderer_pool_size,
preserve_all_thinking=preserve_all_thinking,
preserve_thinking_between_tool_calls=preserve_thinking_between_tool_calls,
)
await pool.start()
return pool
Expand Down Expand Up @@ -214,6 +222,8 @@ def _rebuild_clients(self) -> None:
tool_parser=self.tool_parser,
reasoning_parser=self.reasoning_parser,
renderer_pool_size=self.renderer_pool_size,
preserve_all_thinking=self.preserve_all_thinking,
preserve_thinking_between_tool_calls=self.preserve_thinking_between_tool_calls,
)
if urls
else []
Expand Down
6 changes: 6 additions & 0 deletions tests/unit/orchestrator/test_orchestrator_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ async def run() -> None:
tool_parser=None,
reasoning_parser=None,
pool_size=None,
preserve_all_thinking=False,
preserve_thinking_between_tool_calls=False,
),
)
rollout_client_config = SimpleNamespace(base_url=["http://localhost:8000/v1"])
Expand Down Expand Up @@ -79,6 +81,8 @@ async def run() -> None:
renderer="qwen3_vl",
tool_parser=None,
reasoning_parser=None,
preserve_all_thinking=False,
preserve_thinking_between_tool_calls=False,
)
setup_pool_mock.assert_awaited_once_with(
rollout_client_config,
Expand All @@ -89,6 +93,8 @@ async def run() -> None:
tool_parser=None,
reasoning_parser=None,
renderer_pool_size=None,
preserve_all_thinking=False,
preserve_thinking_between_tool_calls=False,
)

asyncio.run(run())
Loading