diff --git a/CHANGELOG.md b/CHANGELOG.md
index d8e4c880f0..cd9c78a3a9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,6 +2,7 @@
 
 Documenting **breaking** configuration changes — renamed, removed, or moved fields that require users to update existing configs.
 
+- **`trainer.loss.kl_tau` (default loss)**: Re-added the Kimi-K2.5 KL penalty (`log_importance_ratio**2`) to the default loss, but applied it only over unmasked (kept) tokens — masked tokens are dropped from the policy gradient and should not contribute to the KL either. Default `kl_tau = 0.0` (off); set explicitly to re-enable. (2026-05-09)
 - **`trainer.loss` (default loss)**: Dropped the Kimi-K2.5 KL term and removed the advantage-conditioned mask from the default DPPO-Binary TV loss. Tokens are now dropped symmetrically when `π_train - π_infer` falls outside `[-dppo_diff_low, dppo_diff_high]` regardless of advantage sign. Removed `kl_tau`. Renamed `dppo_mask_low` → `dppo_diff_low` (default `0.2`) and `dppo_mask_high` → `dppo_diff_high` (default `0.2`). `adv_tau` and `teacher_tau` are unchanged. (2026-05-06)
 - **`orchestrator.advantage.length_penalty` → discriminated sub-config**: The scalar `length_penalty: Literal["tokens","turns"] | None` is replaced by a `LengthPenaltyConfig | None` discriminated on `type`. Token shaping now takes weighted completion + tool-response token costs. Migration: `length_penalty = "tokens"` becomes `[orchestrator.advantage.length_penalty]\ntype = "tokens"` (default weights `completion_weight = 1.0`, `tool_response_weight = 1.0` — total context). `length_penalty = "turns"` becomes `[orchestrator.advantage.length_penalty]\ntype = "turns"`. (2026-05-06)
 - **`orchestrator.advantage.length_shaping` → `orchestrator.advantage.length_penalty`**: The boolean `length_shaping` flag has been replaced by `length_penalty: Literal["tokens", "turns"] | None` (default: `None`). `length_shaping = true` becomes `length_penalty = "tokens"`; `length_shaping = false` becomes `length_penalty = None`. The new `"turns"` option applies the same correctness-gated efficiency shaping using trajectory turn count instead of completion-token count. (2026-05-01)
diff --git a/packages/prime-rl-configs/src/prime_rl/configs/trainer.py b/packages/prime-rl-configs/src/prime_rl/configs/trainer.py
index 2e37a56595..fae19a8869 100644
--- a/packages/prime-rl-configs/src/prime_rl/configs/trainer.py
+++ b/packages/prime-rl-configs/src/prime_rl/configs/trainer.py
@@ -696,6 +696,13 @@ class DefaultLossConfig(BaseModel):
     ] = 0.2
     adv_tau: Annotated[float, Field(ge=0, description="The tau for advantages.")] = 1.0
     teacher_tau: Annotated[float, Field(ge=0, description="The tau for teacher logprobs.")] = 0.0
+    kl_tau: Annotated[
+        float,
+        Field(
+            ge=0,
+            description="The tau for the KL penalty (squared log importance ratio), applied only over unmasked tokens.",
+        ),
+    ] = 0.0
 
 
 class SFTLossConfig(BaseModel):
diff --git a/src/prime_rl/trainer/rl/loss.py b/src/prime_rl/trainer/rl/loss.py
index cb84aa5d16..2485e586c2 100644
--- a/src/prime_rl/trainer/rl/loss.py
+++ b/src/prime_rl/trainer/rl/loss.py
@@ -111,8 +111,14 @@ def default_loss_fn(inputs: LossInputs, loss_config: DefaultLossConfig) -> LossO
     Token-level masked importance sampling: tokens whose probability
     difference π_train - π_infer falls outside [-dppo_diff_low, dppo_diff_high]
     are dropped (gradient set to 0), not clipped. The mask is symmetric
-    (not advantage-conditioned). No KL penalty — the double-sided
-    difference mask is what keeps the update inside the trust region.
+    (not advantage-conditioned). The double-sided difference mask is what
+    keeps the update inside the trust region.
+
+    An optional Kimi-K2.5-style KL penalty (https://arxiv.org/pdf/2602.02276),
+    scaled by `kl_tau`, is applied only over unmasked (kept) tokens — the
+    masked tokens are already dropped from the policy gradient and including
+    them in the KL would push gradient through positions we explicitly chose
+    to ignore.
     """
     trainer_logprobs = inputs.trainer_logprobs
     inference_logprobs = inputs.inference_logprobs
@@ -138,7 +144,8 @@ def default_loss_fn(inputs: LossInputs, loss_config: DefaultLossConfig) -> LossO
         teacher_kl = None
 
     pg_loss = keep_mask * advantages * importance_ratio
-    loss = (-pg_loss).sum()
+    kl_loss = keep_mask * log_importance_ratio**2
+    loss = (-pg_loss + loss_config.kl_tau * kl_loss).sum()
 
     metrics = {
         "mismatch_kl": _safe_mean(mismatch_kl, loss_mask),  # all trainable tokens
@@ -147,6 +154,7 @@ def default_loss_fn(inputs: LossInputs, loss_config: DefaultLossConfig) -> LossO
         "is_masked": _safe_mean(is_masked, loss_mask),
         "is_masked_low": _safe_mean(is_masked_low, loss_mask),
         "is_masked_high": _safe_mean(is_masked_high, loss_mask),
+        "kl_penalty": _safe_mean(log_importance_ratio**2, keep_mask),
     }
     if teacher_kl is not None:
         metrics["teacher_kl"] = _safe_mean(teacher_kl, loss_mask)
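Note (not part of the patch): a minimal, self-contained sketch of how the re-added penalty composes with the masked policy-gradient term. The standalone function name `kl_penalized_loss` and its bare-tensor signature are illustrative; the real code reads these tensors off `LossInputs` and `DefaultLossConfig` as shown in the diff, and the sketch assumes `keep_mask` has already been built from the `dppo_diff_low`/`dppo_diff_high` difference test.

```python
import torch


def kl_penalized_loss(
    trainer_logprobs: torch.Tensor,    # log-probs under the training policy
    inference_logprobs: torch.Tensor,  # log-probs under the inference policy
    advantages: torch.Tensor,          # per-token advantages
    keep_mask: torch.Tensor,           # 1.0 for kept tokens, 0.0 for dropped tokens
    kl_tau: float = 0.0,               # 0.0 (the default) disables the penalty
) -> torch.Tensor:
    # Same quantities default_loss_fn computes in the diff above.
    log_importance_ratio = trainer_logprobs - inference_logprobs
    importance_ratio = torch.exp(log_importance_ratio)

    # Policy-gradient term: gradient only flows through kept tokens.
    pg_loss = keep_mask * advantages * importance_ratio

    # Kimi-K2.5-style penalty (squared log importance ratio), restricted to the
    # same kept tokens so no gradient reaches positions the difference mask dropped.
    kl_loss = keep_mask * log_importance_ratio**2

    return (-pg_loss + kl_tau * kl_loss).sum()
```

With the default `kl_tau = 0.0` the extra term vanishes and the loss reduces to the previous masked policy-gradient objective; per the changelog entry, the penalty is re-enabled by setting `trainer.loss.kl_tau` to a positive value in the trainer config (any specific value is the user's choice, not a documented recommendation).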