diff --git a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py index 5d04d3369f..8298e2867c 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py @@ -1269,6 +1269,12 @@ def validate_renderer_args(self): renderer_args_set.append(f"renderer.reasoning_parser={self.renderer.reasoning_parser!r}") if self.renderer.pool_size is not None: renderer_args_set.append(f"renderer.pool_size={self.renderer.pool_size!r}") + if self.renderer.preserve_all_thinking: + renderer_args_set.append(f"renderer.preserve_all_thinking={self.renderer.preserve_all_thinking!r}") + if self.renderer.preserve_thinking_between_tool_calls: + renderer_args_set.append( + f"renderer.preserve_thinking_between_tool_calls={self.renderer.preserve_thinking_between_tool_calls!r}" + ) if renderer_args_set: raise ValueError( diff --git a/packages/prime-rl-configs/src/prime_rl/configs/sft.py b/packages/prime-rl-configs/src/prime_rl/configs/sft.py index 84ee13018d..29a6090f55 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/sft.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/sft.py @@ -414,6 +414,12 @@ def validate_renderer_args(self): renderer_args_set.append(f"renderer.tool_parser={self.renderer.tool_parser!r}") if self.renderer.reasoning_parser is not None: renderer_args_set.append(f"renderer.reasoning_parser={self.renderer.reasoning_parser!r}") + if self.renderer.preserve_all_thinking: + renderer_args_set.append(f"renderer.preserve_all_thinking={self.renderer.preserve_all_thinking!r}") + if self.renderer.preserve_thinking_between_tool_calls: + renderer_args_set.append( + f"renderer.preserve_thinking_between_tool_calls={self.renderer.preserve_thinking_between_tool_calls!r}" + ) if renderer_args_set: raise ValueError( diff --git a/packages/prime-rl-configs/src/prime_rl/configs/shared.py b/packages/prime-rl-configs/src/prime_rl/configs/shared.py index d26c33d9a9..4c6d55ca66 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/shared.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/shared.py @@ -186,6 +186,37 @@ class RendererConfig(BaseConfig): ), ] = None + preserve_all_thinking: Annotated[ + bool, + Field( + description=( + "Override flag forwarded to the renderer at construction. When " + "True, every past-assistant turn's ``reasoning_content`` is " + "re-emitted between ````/```` (or the model's " + "equivalent), even if the underlying chat template would drop " + "it. Off by default — preserves byte-identical output to the " + "stock template. Strict superset of " + "``preserve_thinking_between_tool_calls``." + ), + ), + ] = False + + preserve_thinking_between_tool_calls: Annotated[ + bool, + Field( + description=( + "Override flag forwarded to the renderer at construction. When " + "True, preserves past-assistant ``reasoning_content`` only " + "inside the *current* tool cycle — the contiguous " + "assistant→tool→…→assistant block after the most recent user " + "message, when that block contains at least one tool response. " + "A new user turn closes the block; older blocks fall back to " + "the template default (typically dropped). Use " + "``preserve_all_thinking`` to keep older blocks too." + ), + ), + ] = False + class ElasticConfig(BaseConfig): """Configures elastic inference pool with DNS-based service discovery. diff --git a/src/prime_rl/orchestrator/orchestrator.py b/src/prime_rl/orchestrator/orchestrator.py index bc1128ebc7..67ef7bfa1d 100644 --- a/src/prime_rl/orchestrator/orchestrator.py +++ b/src/prime_rl/orchestrator/orchestrator.py @@ -926,6 +926,8 @@ async def setup_rollout_inference_pool( renderer=config.renderer.name, tool_parser=config.renderer.tool_parser, reasoning_parser=config.renderer.reasoning_parser, + preserve_all_thinking=config.renderer.preserve_all_thinking, + preserve_thinking_between_tool_calls=config.renderer.preserve_thinking_between_tool_calls, ) logger.info(f"Initialized {type(renderer).__name__} for {config.model.name}") inference_pool = await setup_inference_pool( @@ -937,6 +939,8 @@ async def setup_rollout_inference_pool( tool_parser=config.renderer.tool_parser, reasoning_parser=config.renderer.reasoning_parser, renderer_pool_size=config.renderer.pool_size, + preserve_all_thinking=config.renderer.preserve_all_thinking, + preserve_thinking_between_tool_calls=config.renderer.preserve_thinking_between_tool_calls, ) logger.info("Using direct renderer rollout client") return renderer, inference_pool diff --git a/src/prime_rl/trainer/sft/train.py b/src/prime_rl/trainer/sft/train.py index ace4158f2e..1c12b342ee 100644 --- a/src/prime_rl/trainer/sft/train.py +++ b/src/prime_rl/trainer/sft/train.py @@ -166,6 +166,8 @@ def train(config: SFTConfig): renderer=config.renderer.name, tool_parser=config.renderer.tool_parser, reasoning_parser=config.renderer.reasoning_parser, + preserve_all_thinking=config.renderer.preserve_all_thinking, + preserve_thinking_between_tool_calls=config.renderer.preserve_thinking_between_tool_calls, ) if isinstance(renderer, DefaultRenderer): raise ValueError( diff --git a/src/prime_rl/utils/client.py b/src/prime_rl/utils/client.py index 21659dfc46..9f59d1a2b1 100644 --- a/src/prime_rl/utils/client.py +++ b/src/prime_rl/utils/client.py @@ -68,6 +68,8 @@ def __init__( tool_parser: str | None = None, reasoning_parser: str | None = None, renderer_pool_size: int | None = None, + preserve_all_thinking: bool = False, + preserve_thinking_between_tool_calls: bool = False, ): renderer_model_name = model_name if train_client_type == "renderer" else None self._train_clients = setup_clients( @@ -78,6 +80,8 @@ def __init__( tool_parser=tool_parser, reasoning_parser=reasoning_parser, renderer_pool_size=renderer_pool_size, + preserve_all_thinking=preserve_all_thinking, + preserve_thinking_between_tool_calls=preserve_thinking_between_tool_calls, ) self._eval_clients = setup_clients(client_config, client_type=eval_client_type) self._admin_clients = setup_admin_clients(client_config) @@ -129,6 +133,8 @@ async def setup_inference_pool( tool_parser: str | None = None, reasoning_parser: str | None = None, renderer_pool_size: int | None = None, + preserve_all_thinking: bool = False, + preserve_thinking_between_tool_calls: bool = False, ) -> InferencePool: """Create an inference pool from config (static or elastic).""" logger = get_logger() @@ -152,6 +158,8 @@ async def setup_inference_pool( tool_parser=tool_parser, reasoning_parser=reasoning_parser, renderer_pool_size=renderer_pool_size, + preserve_all_thinking=preserve_all_thinking, + preserve_thinking_between_tool_calls=preserve_thinking_between_tool_calls, ) logger.info( @@ -168,6 +176,8 @@ async def setup_inference_pool( tool_parser=tool_parser, reasoning_parser=reasoning_parser, renderer_pool_size=renderer_pool_size, + preserve_all_thinking=preserve_all_thinking, + preserve_thinking_between_tool_calls=preserve_thinking_between_tool_calls, ) @@ -179,9 +189,20 @@ def setup_clients( tool_parser: str | None = None, reasoning_parser: str | None = None, renderer_pool_size: int | None = None, + preserve_all_thinking: bool = False, + preserve_thinking_between_tool_calls: bool = False, ) -> list[vf.ClientConfig]: clients = [] client_idx = 0 + # Only forward preserve flags when the client actually uses a renderer — + # MITO/TITO clients ignore them and the verifiers ClientConfig may reject + # unknown extras on older versions. + renderer_extra: dict = {} + if client_type == "renderer": + renderer_extra = { + "preserve_all_thinking": preserve_all_thinking, + "preserve_thinking_between_tool_calls": preserve_thinking_between_tool_calls, + } for base_url in client_config.base_url: for dp_rank in range(client_config.dp_rank_count): headers = client_config.headers.copy() @@ -205,6 +226,7 @@ def setup_clients( max_retries=10, extra_headers=headers, extra_headers_from_state=client_config.extra_headers_from_state, + **renderer_extra, ) ) client_idx += 1 diff --git a/src/prime_rl/utils/elastic.py b/src/prime_rl/utils/elastic.py index 902f873903..c59f81e27f 100644 --- a/src/prime_rl/utils/elastic.py +++ b/src/prime_rl/utils/elastic.py @@ -110,6 +110,8 @@ def __init__( tool_parser: str | None = None, reasoning_parser: str | None = None, renderer_pool_size: int | None = None, + preserve_all_thinking: bool = False, + preserve_thinking_between_tool_calls: bool = False, ): self.logger = get_logger() self.client_config = client_config @@ -125,6 +127,8 @@ def __init__( self.tool_parser = tool_parser self.reasoning_parser = reasoning_parser self.renderer_pool_size = renderer_pool_size + self.preserve_all_thinking = preserve_all_thinking + self.preserve_thinking_between_tool_calls = preserve_thinking_between_tool_calls self.router_url = client_config.router_url self._servers: dict[str, ServerState] = {} @@ -152,6 +156,8 @@ async def from_config( tool_parser: str | None = None, reasoning_parser: str | None = None, renderer_pool_size: int | None = None, + preserve_all_thinking: bool = False, + preserve_thinking_between_tool_calls: bool = False, ) -> ElasticInferencePool: if client_config.elastic is None: raise ValueError("Elastic inference pool requires elastic config") @@ -164,6 +170,8 @@ async def from_config( tool_parser=tool_parser, reasoning_parser=reasoning_parser, renderer_pool_size=renderer_pool_size, + preserve_all_thinking=preserve_all_thinking, + preserve_thinking_between_tool_calls=preserve_thinking_between_tool_calls, ) await pool.start() return pool @@ -214,6 +222,8 @@ def _rebuild_clients(self) -> None: tool_parser=self.tool_parser, reasoning_parser=self.reasoning_parser, renderer_pool_size=self.renderer_pool_size, + preserve_all_thinking=self.preserve_all_thinking, + preserve_thinking_between_tool_calls=self.preserve_thinking_between_tool_calls, ) if urls else [] diff --git a/tests/unit/orchestrator/test_orchestrator_setup.py b/tests/unit/orchestrator/test_orchestrator_setup.py index ff9bb5b79f..d4567d9682 100644 --- a/tests/unit/orchestrator/test_orchestrator_setup.py +++ b/tests/unit/orchestrator/test_orchestrator_setup.py @@ -50,6 +50,8 @@ async def run() -> None: tool_parser=None, reasoning_parser=None, pool_size=None, + preserve_all_thinking=False, + preserve_thinking_between_tool_calls=False, ), ) rollout_client_config = SimpleNamespace(base_url=["http://localhost:8000/v1"]) @@ -79,6 +81,8 @@ async def run() -> None: renderer="qwen3_vl", tool_parser=None, reasoning_parser=None, + preserve_all_thinking=False, + preserve_thinking_between_tool_calls=False, ) setup_pool_mock.assert_awaited_once_with( rollout_client_config, @@ -89,6 +93,8 @@ async def run() -> None: tool_parser=None, reasoning_parser=None, renderer_pool_size=None, + preserve_all_thinking=False, + preserve_thinking_between_tool_calls=False, ) asyncio.run(run())