diff --git a/CHANGELOG.md b/CHANGELOG.md index 7148829503..e30a402374 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ Documenting **breaking** configuration changes — renamed, removed, or moved fields that require users to update existing configs. +- **`trainer.loss` (default loss)**: Default IPO+KL loss now uses **double-sided importance-ratio masking** (advantage-sign-agnostic) instead of advantage-conditioned probability-difference masking, and the Kimi-K2.5 KL penalty is **only applied to kept (unmasked) tokens**. `dppo_mask_low` and `dppo_mask_high` are now interpreted as importance-ratio bounds (was: probability-difference thresholds). `dppo_mask_high` default changed from `0.2` to `5.0` to match the new ratio semantics; `dppo_mask_low` default stays at `0.2`. Configs that set either field explicitly on the old probability-difference scale must be re-tuned: for example, an explicit `dppo_mask_high = 0.2` now drops every token whose importance ratio exceeds `0.2`, which includes on-policy tokens (ratio ≈ 1). `kl_tau`, `adv_tau`, `teacher_tau` unchanged. (2026-05-02) - **`orchestrator.advantage.length_shaping` → `orchestrator.advantage.length_penalty`**: The boolean `length_shaping` flag has been replaced by `length_penalty: Literal["tokens", "turns"] | None` (default: `None`). `length_shaping = true` becomes `length_penalty = "tokens"`; `length_shaping = false` becomes `length_penalty = None`. The new `"turns"` option applies the same correctness-gated efficiency shaping using trajectory turn count instead of completion-token count. (2026-05-01) - **`AdvantageInputs` API**: Replaced the `rewards`/`completion_lengths`/`num_turns` tensor fields with a single `rollouts: list[list[vf.RolloutOutput]]` (grouped by problem). Custom advantage functions can now access any rollout metadata. Existing custom advantages must update their signatures and extract per-rollout fields directly (e.g. `torch.tensor([[r["reward"] for r in g] for g in inputs.rollouts])`). (2026-05-01) - **`orchestrator.teacher_rollout_model` now requires `orchestrator.use_sft_loss = true`**: External teacher rollout configs no longer rely on `trainer.loss.type = "sft"` to select SFT loss. 
Existing hard-distill configs must set `orchestrator.use_sft_loss = true` alongside `orchestrator.teacher_rollout_model`; the orchestrator validates the pair and stamps training samples with the per-batch SFT loss bool. (2026-04-26) diff --git a/src/prime_rl/configs/trainer.py b/src/prime_rl/configs/trainer.py index 0065dffd3d..dbe539d304 100644 --- a/src/prime_rl/configs/trainer.py +++ b/src/prime_rl/configs/trainer.py @@ -653,8 +653,14 @@ class DefaultLossConfig(BaseModel): type: Literal["default"] = "default" - dppo_mask_low: Annotated[float, Field(ge=0, description="The low threshold for masking tokens.")] = 0.2 - dppo_mask_high: Annotated[float, Field(ge=0, description="The high threshold for masking tokens.")] = 0.2 + dppo_mask_low: Annotated[ + float, + Field(ge=0, description="Lower bound α on the importance ratio. Tokens below are dropped from PG and KL."), + ] = 0.2 + dppo_mask_high: Annotated[ + float, + Field(ge=0, description="Upper bound β on the importance ratio. Tokens above are dropped from PG and KL."), + ] = 5.0 adv_tau: Annotated[float, Field(ge=0, description="The tau for advantages.")] = 1.0 teacher_tau: Annotated[float, Field(ge=0, description="The tau for teacher logprobs.")] = 0.0 kl_tau: Annotated[float, Field(ge=0, description="The tau for KL divergence.")] = 1e-3 diff --git a/src/prime_rl/trainer/rl/loss.py b/src/prime_rl/trainer/rl/loss.py index d67c055a68..46fac096ae 100644 --- a/src/prime_rl/trainer/rl/loss.py +++ b/src/prime_rl/trainer/rl/loss.py @@ -106,15 +106,15 @@ def _safe_mean(values: Tensor, mask: Tensor) -> Tensor: def default_loss_fn(inputs: LossInputs, loss_config: DefaultLossConfig) -> LossOutputs: """ - DPPO+KL loss, combining: + IPO+KL loss, combining: - DPPO-Binary TV Loss (https://arxiv.org/pdf/2602.04879) - Kimi-K2.5 KL Loss (https://arxiv.org/pdf/2602.02276) - The mask is conditioned on the advantage sign: for positive advantages, - we mask tokens whose probability increased too much (trust region violation - in the 
upweight direction); for negative advantages, we mask tokens whose - probability decreased too much (trust region violation in the downweight - direction). + Two modifications vs the original DPPO+KL: + - Double-sided importance-ratio masking (advantage-sign-agnostic): tokens + whose ratio π_train/π_infer falls outside [α, β] are dropped. + - The KL penalty is applied only on kept (unmasked) tokens — masked tokens + contribute nothing to either the PG or the KL term. """ trainer_logprobs = inputs.trainer_logprobs inference_logprobs = inputs.inference_logprobs @@ -122,22 +122,15 @@ def default_loss_fn(inputs: LossInputs, loss_config: DefaultLossConfig) -> LossO advantages = inputs.advantages loss_mask = inputs.loss_mask - trainer_probs = torch.exp(trainer_logprobs) - inference_probs = torch.exp(inference_logprobs) - probs_diff = trainer_probs - inference_probs - dppo_invalid_mask_high = probs_diff > loss_config.dppo_mask_high - dppo_invalid_mask_low = probs_diff < -loss_config.dppo_mask_low - dppo_invalid_mask = torch.where(advantages > 0, dppo_invalid_mask_high, dppo_invalid_mask_low) - - is_masked = dppo_invalid_mask - is_masked_high = (advantages > 0) & dppo_invalid_mask_high - is_masked_low = (advantages < 0) & dppo_invalid_mask_low - keep_mask = loss_mask & ~is_masked - log_importance_ratio = trainer_logprobs - inference_logprobs importance_ratio = torch.exp(log_importance_ratio) mismatch_kl = importance_ratio - log_importance_ratio - 1 + is_masked_low = importance_ratio.detach() < loss_config.dppo_mask_low + is_masked_high = importance_ratio.detach() > loss_config.dppo_mask_high + is_masked = is_masked_low | is_masked_high + keep_mask = loss_mask & ~is_masked + advantages = loss_config.adv_tau * advantages if teacher_logprobs is not None: teacher_kl = teacher_logprobs - trainer_logprobs @@ -146,7 +139,7 @@ def default_loss_fn(inputs: LossInputs, loss_config: DefaultLossConfig) -> LossO teacher_kl = None pg_loss = keep_mask * advantages * importance_ratio - 
kl_loss = loss_mask * log_importance_ratio**2 + kl_loss = keep_mask * log_importance_ratio**2 loss = (-pg_loss + loss_config.kl_tau * kl_loss).sum() metrics = {