diff --git a/CHANGELOG.md b/CHANGELOG.md index 9e0de85b98..d8e4c880f0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ Documenting **breaking** configuration changes — renamed, removed, or moved fields that require users to update existing configs. +- **`trainer.loss` (default loss)**: Dropped the Kimi-K2.5 KL term and removed the advantage-conditioned mask from the default DPPO-Binary TV loss. Tokens are now dropped symmetrically when `π_train - π_infer` falls outside `[-dppo_diff_low, dppo_diff_high]` regardless of advantage sign. Removed `kl_tau`. Renamed `dppo_mask_low` → `dppo_diff_low` (default `0.2`) and `dppo_mask_high` → `dppo_diff_high` (default `0.2`). `adv_tau` and `teacher_tau` are unchanged. (2026-05-06) - **`orchestrator.advantage.length_penalty` → discriminated sub-config**: The scalar `length_penalty: Literal["tokens","turns"] | None` is replaced by a `LengthPenaltyConfig | None` discriminated on `type`. Token shaping now takes weighted completion + tool-response token costs. Migration: `length_penalty = "tokens"` becomes `[orchestrator.advantage.length_penalty]\ntype = "tokens"` (default weights `completion_weight = 1.0`, `tool_response_weight = 1.0` — total context). `length_penalty = "turns"` becomes `[orchestrator.advantage.length_penalty]\ntype = "turns"`. (2026-05-06) - **`orchestrator.advantage.length_shaping` → `orchestrator.advantage.length_penalty`**: The boolean `length_shaping` flag has been replaced by `length_penalty: Literal["tokens", "turns"] | None` (default: `None`). `length_shaping = true` becomes `length_penalty = "tokens"`; `length_shaping = false` becomes `length_penalty = None`. The new `"turns"` option applies the same correctness-gated efficiency shaping using trajectory turn count instead of completion-token count. (2026-05-01) - **`AdvantageInputs` API**: Replaced the `rewards`/`completion_lengths`/`num_turns` tensor fields with a single `rollouts: list[list[vf.RolloutOutput]]` (grouped by problem). Custom advantage functions can now access any rollout metadata. Existing custom advantages must update their signatures and extract per-rollout fields directly (e.g. `torch.tensor([[r["reward"] for r in g] for g in inputs.rollouts])`). (2026-05-01) diff --git a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py index 3911816227..e152741bf2 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py @@ -1082,6 +1082,14 @@ class OrchestratorConfig(BaseConfig): ), ] = 8 + max_error_reschedule_attempts: Annotated[ + int | None, + Field( + ge=1, + description="Maximum number of times the scheduler will reschedule a group whose rollouts errored or returned empty trajectories. After this many consecutive failed attempts, the group is dropped from the current step's batch (the trainer proceeds with the rollouts from other groups). `None` means retry indefinitely (legacy behavior). Useful for unblocking single-example hangs in agent envs.", + ), + ] = None + max_async_level: Annotated[ int, Field( @@ -1261,6 +1269,10 @@ def validate_renderer_args(self): renderer_args_set.append(f"renderer.reasoning_parser={self.renderer.reasoning_parser!r}") if self.renderer.pool_size is not None: renderer_args_set.append(f"renderer.pool_size={self.renderer.pool_size!r}") + if self.renderer.preserve_all_thinking: + renderer_args_set.append("renderer.preserve_all_thinking=True") + if self.renderer.preserve_thinking_between_tool_calls: + renderer_args_set.append("renderer.preserve_thinking_between_tool_calls=True") if renderer_args_set: raise ValueError( diff --git a/packages/prime-rl-configs/src/prime_rl/configs/shared.py b/packages/prime-rl-configs/src/prime_rl/configs/shared.py index d26c33d9a9..24cb5d97c5 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/shared.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/shared.py @@ -186,6 +186,34 @@ class RendererConfig(BaseConfig): ), ] = None + preserve_all_thinking: Annotated[ + bool, + Field( + description=( + "If True, the renderer keeps `...` blocks for ALL " + "assistant turns when re-rendering the prompt for the next request. " + "By default the GLM/Qwen renderers strip thinking from older turns " + "(only the most recent assistant keeps it). Stripping breaks the " + "trajectory-step prefix property at every turn-rebuild — re-rendered " + "tokens no longer extend the streamed tokens, so the splitter opens " + "extra training samples (e.g. 1 + 2*compactions instead of 1 + " + "compactions). Turn this on for RL with multi-turn agents that " + "compact context." + ), + ), + ] = False + + preserve_thinking_between_tool_calls: Annotated[ + bool, + Field( + description=( + "Narrower variant of preserve_all_thinking: keep thinking only " + "between consecutive tool-call assistant turns. See " + "renderers.Renderer.render for exact semantics." + ), + ), + ] = False + class ElasticConfig(BaseConfig): """Configures elastic inference pool with DNS-based service discovery. diff --git a/packages/prime-rl-configs/src/prime_rl/configs/trainer.py b/packages/prime-rl-configs/src/prime_rl/configs/trainer.py index a076d1e29c..2e37a56595 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/trainer.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/trainer.py @@ -682,15 +682,20 @@ class CheckpointConfig(BaseConfig): class DefaultLossConfig(BaseModel): - """Config for the default loss.""" + """Config for the default loss (DPPO-Binary TV / arXiv:2602.04879).""" type: Literal["default"] = "default" - dppo_mask_low: Annotated[float, Field(ge=0, description="The low threshold for masking tokens.")] = 0.2 - dppo_mask_high: Annotated[float, Field(ge=0, description="The high threshold for masking tokens.")] = 0.2 + dppo_diff_low: Annotated[ + float, + Field(ge=0, description="Lower bound on (π_train - π_infer). Tokens below -dppo_diff_low are dropped."), + ] = 0.2 + dppo_diff_high: Annotated[ + float, + Field(ge=0, description="Upper bound on (π_train - π_infer). Tokens above dppo_diff_high are dropped."), + ] = 0.2 adv_tau: Annotated[float, Field(ge=0, description="The tau for advantages.")] = 1.0 teacher_tau: Annotated[float, Field(ge=0, description="The tau for teacher logprobs.")] = 0.0 - kl_tau: Annotated[float, Field(ge=0, description="The tau for KL divergence.")] = 1e-3 class SFTLossConfig(BaseModel): diff --git a/src/prime_rl/orchestrator/orchestrator.py b/src/prime_rl/orchestrator/orchestrator.py index bc1128ebc7..67ef7bfa1d 100644 --- a/src/prime_rl/orchestrator/orchestrator.py +++ b/src/prime_rl/orchestrator/orchestrator.py @@ -926,6 +926,8 @@ async def setup_rollout_inference_pool( renderer=config.renderer.name, tool_parser=config.renderer.tool_parser, reasoning_parser=config.renderer.reasoning_parser, + preserve_all_thinking=config.renderer.preserve_all_thinking, + preserve_thinking_between_tool_calls=config.renderer.preserve_thinking_between_tool_calls, ) logger.info(f"Initialized {type(renderer).__name__} for {config.model.name}") inference_pool = await setup_inference_pool( @@ -937,6 +939,8 @@ async def setup_rollout_inference_pool( tool_parser=config.renderer.tool_parser, reasoning_parser=config.renderer.reasoning_parser, renderer_pool_size=config.renderer.pool_size, + preserve_all_thinking=config.renderer.preserve_all_thinking, + preserve_thinking_between_tool_calls=config.renderer.preserve_thinking_between_tool_calls, ) logger.info("Using direct renderer rollout client") return renderer, inference_pool diff --git a/src/prime_rl/orchestrator/scheduler.py b/src/prime_rl/orchestrator/scheduler.py index a8427b69f7..9187b3ff8d 100644 --- a/src/prime_rl/orchestrator/scheduler.py +++ b/src/prime_rl/orchestrator/scheduler.py @@ -42,6 +42,10 @@ class GroupState: rollouts_to_schedule: int completed_rollouts: list[vf.RolloutOutput] = field(default_factory=list) pinned_client: vf.ClientConfig | None = None + # Number of rollout attempts in this group that returned errored or empty + # trajectories. Compared against config.max_error_reschedule_attempts to + # decide when to drop a permanently-stuck group. + failed_attempts: int = 0 class Scheduler: @@ -114,6 +118,7 @@ def __init__( self.empty_rollouts_by_env: dict[str, int] = defaultdict(int) self.errored_rollouts_by_env: dict[str, int] = defaultdict(int) self.total_rollouts_by_env: dict[str, int] = defaultdict(int) + self.dropped_groups_by_env: dict[str, int] = defaultdict(int) self.last_batch_generation_time = 0.0 @property @@ -454,6 +459,24 @@ async def generate_batch(self, step: int) -> list[vf.RolloutOutput]: rollout["env_name"] = env_name valid_rollouts.append(rollout) + if has_failures: + group.failed_attempts += 1 + max_attempts = self.config.max_error_reschedule_attempts + if max_attempts is not None and group.failed_attempts >= max_attempts: + # Permanently-stuck group: drop it from this step and let the + # rest of the batch proceed. Avoids a single bad example (e.g. + # an agent rollout whose sandbox poll keeps timing out) + # blocking step progress forever. + self.dropped_groups_by_env[env_name] += 1 + self.logger.warning( + f"Dropping group {group_id} ({env_name}) after {group.failed_attempts} " + f"failed attempts ({len(group.completed_rollouts)}/{self.rollouts_per_example} " + f"complete). Set orchestrator.max_error_reschedule_attempts higher (or to None) " + f"to retry more aggressively." + ) + await self.drop_group(group_id) + continue + if has_failures and env.requires_group_scoring: # Group scoring requires all rollouts — discard partial results, reschedule full group group.completed_rollouts.clear() diff --git a/src/prime_rl/trainer/models/kernels/fp8_utils.py b/src/prime_rl/trainer/models/kernels/fp8_utils.py index 83dd81b52a..82c26fb7e4 100644 --- a/src/prime_rl/trainer/models/kernels/fp8_utils.py +++ b/src/prime_rl/trainer/models/kernels/fp8_utils.py @@ -139,17 +139,18 @@ def _unpack_grouped_rows_kernel( actual_m = tl.load(actual_ms_ptr + pid_g) row_offsets = (pid_blk - block_start) * GROUP_BLOCK_M + pid_sub * BLOCK_M + tl.arange(0, BLOCK_M) col_offsets = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + src_rows_i64 = (pid_blk * GROUP_BLOCK_M + pid_sub * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64) + dst_rows_i64 = (dst_start + row_offsets).to(tl.int64) + col_offsets_i64 = col_offsets.to(tl.int64) valid_rows = row_offsets < actual_m valid_cols = col_offsets < cols x = tl.load( - x_ptr - + (pid_blk * GROUP_BLOCK_M + pid_sub * BLOCK_M + tl.arange(0, BLOCK_M))[:, None] * stride_xm - + col_offsets[None, :] * stride_xn, + x_ptr + src_rows_i64[:, None] * stride_xm + col_offsets_i64[None, :] * stride_xn, mask=valid_rows[:, None] & valid_cols[None, :], other=0.0, ) tl.store( - out_ptr + (dst_start + row_offsets)[:, None] * stride_ym + col_offsets[None, :] * stride_yn, + out_ptr + dst_rows_i64[:, None] * stride_ym + col_offsets_i64[None, :] * stride_yn, x, mask=valid_rows[:, None] & valid_cols[None, :], ) @@ -176,9 +177,11 @@ def _per_token_fp8_kernel( pid_k = tl.program_id(axis=1) row_offsets = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) col_offsets = pid_k * BLOCK_K + tl.arange(0, BLOCK_K) + row_offsets_i64 = row_offsets.to(tl.int64) + col_offsets_i64 = col_offsets.to(tl.int64) mask = (row_offsets[:, None] < rows) & (col_offsets[None, :] < cols) x = tl.load( - x_ptr + row_offsets[:, None] * stride_xm + col_offsets[None, :] * stride_xn, + x_ptr + row_offsets_i64[:, None] * stride_xm + col_offsets_i64[None, :] * stride_xn, mask=mask, other=0.0, ).to(tl.float32) @@ -188,11 +191,11 @@ def _per_token_fp8_kernel( scale = tl.exp2(tl.ceil(tl.log2(scale))) y = x / scale[:, None] tl.store( - out_ptr + row_offsets[:, None] * stride_ym + col_offsets[None, :] * stride_yn, + out_ptr + row_offsets_i64[:, None] * stride_ym + col_offsets_i64[None, :] * stride_yn, y.to(tl.float8e4nv), mask=mask, ) - tl.store(sf_ptr + row_offsets * stride_sm + pid_k * stride_sk, scale, mask=row_offsets < rows) + tl.store(sf_ptr + row_offsets_i64 * stride_sm + pid_k * stride_sk, scale, mask=row_offsets < rows) @triton.jit @@ -226,12 +229,13 @@ def _grouped_per_token_fp8_kernel( local_block = pid_blk - block_start row_offsets = local_block * GROUP_BLOCK_M + pid_sub * BLOCK_M + tl.arange(0, BLOCK_M) col_offsets = pid_k * BLOCK_K + tl.arange(0, BLOCK_K) + src_rows_i64 = (src_start + row_offsets).to(tl.int64) + dst_rows_i64 = (pid_blk * GROUP_BLOCK_M + pid_sub * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64) + col_offsets_i64 = col_offsets.to(tl.int64) valid_rows = row_offsets < actual_m valid_cols = col_offsets < cols - src_rows = src_start + row_offsets - dst_rows = pid_blk * GROUP_BLOCK_M + pid_sub * BLOCK_M + tl.arange(0, BLOCK_M) x = tl.load( - x_ptr + src_rows[:, None] * stride_xm + col_offsets[None, :] * stride_xn, + x_ptr + src_rows_i64[:, None] * stride_xm + col_offsets_i64[None, :] * stride_xn, mask=valid_rows[:, None] & valid_cols[None, :], other=0.0, ).to(tl.float32) @@ -241,11 +245,11 @@ def _grouped_per_token_fp8_kernel( scale = tl.exp2(tl.ceil(tl.log2(scale))) y = x / scale[:, None] tl.store( - out_ptr + dst_rows[:, None] * stride_ym + col_offsets[None, :] * stride_yn, + out_ptr + dst_rows_i64[:, None] * stride_ym + col_offsets_i64[None, :] * stride_yn, y.to(tl.float8e4nv), mask=valid_rows[:, None] & valid_cols[None, :], ) - tl.store(sf_ptr + dst_rows * stride_sm + pid_k * stride_sk, scale, mask=valid_rows) + tl.store(sf_ptr + dst_rows_i64 * stride_sm + pid_k * stride_sk, scale, mask=valid_rows) @triton.jit @@ -277,11 +281,13 @@ def _grouped_per_channel_fp8_sm90_kmajor_kernel( local_block = pid_blk - block_start row_offsets = local_block * BLOCK_K + tl.arange(0, BLOCK_K) col_offsets = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + src_rows_i64 = (src_start + row_offsets).to(tl.int64) + row_offsets_i64 = row_offsets.to(tl.int64) + col_offsets_i64 = col_offsets.to(tl.int64) valid_rows = row_offsets < actual_m valid_cols = col_offsets < cols - src_rows = src_start + row_offsets x = tl.load( - x_ptr + src_rows[:, None] * stride_xm + col_offsets[None, :] * stride_xn, + x_ptr + src_rows_i64[:, None] * stride_xm + col_offsets_i64[None, :] * stride_xn, mask=valid_rows[:, None] & valid_cols[None, :], other=0.0, ).to(tl.float32) @@ -290,11 +296,11 @@ def _grouped_per_channel_fp8_sm90_kmajor_kernel( if USE_UE8M0: scale = tl.exp2(tl.ceil(tl.log2(scale))) y = x / scale[None, :] - flat_base = block_start * BLOCK_K * cols - out_ptrs = out_ptr + flat_base + col_offsets[:, None] * aligned_m + row_offsets[None, :] + flat_base = block_start.to(tl.int64) * BLOCK_K * cols + out_ptrs = out_ptr + flat_base + col_offsets_i64[:, None] * aligned_m + row_offsets_i64[None, :] tl.store(out_ptrs, tl.trans(y).to(tl.float8e4nv), mask=valid_cols[:, None] & (row_offsets[None, :] < aligned_m)) tl.store( - sf_ptr + pid_blk * stride_sf0 + col_offsets * stride_sf1, + sf_ptr + pid_blk * stride_sf0 + col_offsets_i64 * stride_sf1, scale, mask=valid_cols, ) diff --git a/src/prime_rl/trainer/rl/loss.py b/src/prime_rl/trainer/rl/loss.py index d67c055a68..cb84aa5d16 100644 --- a/src/prime_rl/trainer/rl/loss.py +++ b/src/prime_rl/trainer/rl/loss.py @@ -106,15 +106,13 @@ def _safe_mean(values: Tensor, mask: Tensor) -> Tensor: def default_loss_fn(inputs: LossInputs, loss_config: DefaultLossConfig) -> LossOutputs: """ - DPPO+KL loss, combining: - - DPPO-Binary TV Loss (https://arxiv.org/pdf/2602.04879) - - Kimi-K2.5 KL Loss (https://arxiv.org/pdf/2602.02276) - - The mask is conditioned on the advantage sign: for positive advantages, - we mask tokens whose probability increased too much (trust region violation - in the upweight direction); for negative advantages, we mask tokens whose - probability decreased too much (trust region violation in the downweight - direction). + DPPO-Binary TV loss (https://arxiv.org/pdf/2602.04879), symmetric variant. + + Token-level masked importance sampling: tokens whose probability + difference π_train - π_infer falls outside [-dppo_diff_low, dppo_diff_high] + are dropped (gradient set to 0), not clipped. The mask is symmetric + (not advantage-conditioned). No KL penalty — the double-sided + difference mask is what keeps the update inside the trust region. """ trainer_logprobs = inputs.trainer_logprobs inference_logprobs = inputs.inference_logprobs @@ -122,22 +120,16 @@ def default_loss_fn(inputs: LossInputs, loss_config: DefaultLossConfig) -> LossO advantages = inputs.advantages loss_mask = inputs.loss_mask - trainer_probs = torch.exp(trainer_logprobs) - inference_probs = torch.exp(inference_logprobs) - probs_diff = trainer_probs - inference_probs - dppo_invalid_mask_high = probs_diff > loss_config.dppo_mask_high - dppo_invalid_mask_low = probs_diff < -loss_config.dppo_mask_low - dppo_invalid_mask = torch.where(advantages > 0, dppo_invalid_mask_high, dppo_invalid_mask_low) - - is_masked = dppo_invalid_mask - is_masked_high = (advantages > 0) & dppo_invalid_mask_high - is_masked_low = (advantages < 0) & dppo_invalid_mask_low - keep_mask = loss_mask & ~is_masked - log_importance_ratio = trainer_logprobs - inference_logprobs importance_ratio = torch.exp(log_importance_ratio) mismatch_kl = importance_ratio - log_importance_ratio - 1 + probs_diff = (torch.exp(trainer_logprobs) - torch.exp(inference_logprobs)).detach() + is_masked_low = probs_diff < -loss_config.dppo_diff_low + is_masked_high = probs_diff > loss_config.dppo_diff_high + is_masked = is_masked_low | is_masked_high + keep_mask = loss_mask & ~is_masked + advantages = loss_config.adv_tau * advantages if teacher_logprobs is not None: teacher_kl = teacher_logprobs - trainer_logprobs @@ -146,8 +138,7 @@ def default_loss_fn(inputs: LossInputs, loss_config: DefaultLossConfig) -> LossO teacher_kl = None pg_loss = keep_mask * advantages * importance_ratio - kl_loss = loss_mask * log_importance_ratio**2 - loss = (-pg_loss + loss_config.kl_tau * kl_loss).sum() + loss = (-pg_loss).sum() metrics = { "mismatch_kl": _safe_mean(mismatch_kl, loss_mask), # all trainable tokens diff --git a/src/prime_rl/utils/client.py b/src/prime_rl/utils/client.py index 21659dfc46..fedbdddb8e 100644 --- a/src/prime_rl/utils/client.py +++ b/src/prime_rl/utils/client.py @@ -68,6 +68,8 @@ def __init__( tool_parser: str | None = None, reasoning_parser: str | None = None, renderer_pool_size: int | None = None, + preserve_all_thinking: bool = False, + preserve_thinking_between_tool_calls: bool = False, ): renderer_model_name = model_name if train_client_type == "renderer" else None self._train_clients = setup_clients( @@ -78,6 +80,8 @@ def __init__( tool_parser=tool_parser, reasoning_parser=reasoning_parser, renderer_pool_size=renderer_pool_size, + preserve_all_thinking=preserve_all_thinking, + preserve_thinking_between_tool_calls=preserve_thinking_between_tool_calls, ) self._eval_clients = setup_clients(client_config, client_type=eval_client_type) self._admin_clients = setup_admin_clients(client_config) @@ -129,6 +133,8 @@ async def setup_inference_pool( tool_parser: str | None = None, reasoning_parser: str | None = None, renderer_pool_size: int | None = None, + preserve_all_thinking: bool = False, + preserve_thinking_between_tool_calls: bool = False, ) -> InferencePool: """Create an inference pool from config (static or elastic).""" logger = get_logger() @@ -152,6 +158,8 @@ async def setup_inference_pool( tool_parser=tool_parser, reasoning_parser=reasoning_parser, renderer_pool_size=renderer_pool_size, + preserve_all_thinking=preserve_all_thinking, + preserve_thinking_between_tool_calls=preserve_thinking_between_tool_calls, ) logger.info( @@ -168,6 +176,8 @@ async def setup_inference_pool( tool_parser=tool_parser, reasoning_parser=reasoning_parser, renderer_pool_size=renderer_pool_size, + preserve_all_thinking=preserve_all_thinking, + preserve_thinking_between_tool_calls=preserve_thinking_between_tool_calls, ) @@ -179,6 +189,8 @@ def setup_clients( tool_parser: str | None = None, reasoning_parser: str | None = None, renderer_pool_size: int | None = None, + preserve_all_thinking: bool = False, + preserve_thinking_between_tool_calls: bool = False, ) -> list[vf.ClientConfig]: clients = [] client_idx = 0 @@ -196,6 +208,8 @@ def setup_clients( renderer_pool_size=renderer_pool_size, tool_parser=tool_parser, reasoning_parser=reasoning_parser, + preserve_all_thinking=preserve_all_thinking, + preserve_thinking_between_tool_calls=preserve_thinking_between_tool_calls, api_base_url=base_url, api_key_var=client_config.api_key_var, timeout=client_config.timeout, diff --git a/src/prime_rl/utils/elastic.py b/src/prime_rl/utils/elastic.py index 902f873903..c59f81e27f 100644 --- a/src/prime_rl/utils/elastic.py +++ b/src/prime_rl/utils/elastic.py @@ -110,6 +110,8 @@ def __init__( tool_parser: str | None = None, reasoning_parser: str | None = None, renderer_pool_size: int | None = None, + preserve_all_thinking: bool = False, + preserve_thinking_between_tool_calls: bool = False, ): self.logger = get_logger() self.client_config = client_config @@ -125,6 +127,8 @@ def __init__( self.tool_parser = tool_parser self.reasoning_parser = reasoning_parser self.renderer_pool_size = renderer_pool_size + self.preserve_all_thinking = preserve_all_thinking + self.preserve_thinking_between_tool_calls = preserve_thinking_between_tool_calls self.router_url = client_config.router_url self._servers: dict[str, ServerState] = {} @@ -152,6 +156,8 @@ async def from_config( tool_parser: str | None = None, reasoning_parser: str | None = None, renderer_pool_size: int | None = None, + preserve_all_thinking: bool = False, + preserve_thinking_between_tool_calls: bool = False, ) -> ElasticInferencePool: if client_config.elastic is None: raise ValueError("Elastic inference pool requires elastic config") @@ -164,6 +170,8 @@ async def from_config( tool_parser=tool_parser, reasoning_parser=reasoning_parser, renderer_pool_size=renderer_pool_size, + preserve_all_thinking=preserve_all_thinking, + preserve_thinking_between_tool_calls=preserve_thinking_between_tool_calls, ) await pool.start() return pool @@ -214,6 +222,8 @@ def _rebuild_clients(self) -> None: tool_parser=self.tool_parser, reasoning_parser=self.reasoning_parser, renderer_pool_size=self.renderer_pool_size, + preserve_all_thinking=self.preserve_all_thinking, + preserve_thinking_between_tool_calls=self.preserve_thinking_between_tool_calls, ) if urls else [] diff --git a/tests/unit/train/rl/test_loss.py b/tests/unit/train/rl/test_loss.py index 696897e368..1a8bcb1c29 100644 --- a/tests/unit/train/rl/test_loss.py +++ b/tests/unit/train/rl/test_loss.py @@ -14,7 +14,7 @@ def test_grpo_loss(): advantages = [torch.randn(50).cuda(), torch.randn(30).cuda()] loss_mask = [torch.ones(50, dtype=torch.bool).cuda(), torch.ones(30, dtype=torch.bool).cuda()] - loss_fn = setup_loss_fn(DefaultLossConfig(dppo_mask_high=10.0)) + loss_fn = setup_loss_fn(DefaultLossConfig(dppo_diff_high=10.0)) loss, _ = compute_loss( trainer_logprobs, inference_logprobs, @@ -34,7 +34,7 @@ def test_gspo_loss(): advantages = [torch.randn(40).cuda(), torch.randn(60).cuda()] loss_mask = [torch.ones(40, dtype=torch.bool).cuda(), torch.ones(60, dtype=torch.bool).cuda()] - loss_fn = setup_loss_fn(DefaultLossConfig(dppo_mask_high=10.0)) + loss_fn = setup_loss_fn(DefaultLossConfig(dppo_diff_high=10.0)) loss, _ = compute_loss( trainer_logprobs, inference_logprobs,