1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -2,6 +2,7 @@

Documenting **breaking** configuration changes — renamed, removed, or moved fields that require users to update existing configs.

- **`trainer.loss.kl_tau` (default loss)**: Re-added the Kimi-K2.5 KL penalty (`log_importance_ratio**2`) to the default loss, but applied it only over unmasked (kept) tokens — masked tokens are dropped from the policy gradient and should not contribute to the KL either. Default `kl_tau = 0.0` (off); set explicitly to re-enable. (2026-05-09)
- **`trainer.loss` (default loss)**: Dropped the Kimi-K2.5 KL term and removed the advantage-conditioned mask from the default DPPO-Binary TV loss. Tokens are now dropped symmetrically when `π_train - π_infer` falls outside `[-dppo_diff_low, dppo_diff_high]` regardless of advantage sign. Removed `kl_tau`. Renamed `dppo_mask_low` → `dppo_diff_low` (default `0.2`) and `dppo_mask_high` → `dppo_diff_high` (default `0.2`). `adv_tau` and `teacher_tau` are unchanged. (2026-05-06)
- **`orchestrator.advantage.length_penalty` → discriminated sub-config**: The scalar `length_penalty: Literal["tokens","turns"] | None` is replaced by a `LengthPenaltyConfig | None` discriminated on `type`. Token shaping now takes weighted completion + tool-response token costs. Migration: `length_penalty = "tokens"` becomes `[orchestrator.advantage.length_penalty]\ntype = "tokens"` (default weights `completion_weight = 1.0`, `tool_response_weight = 1.0` — total context). `length_penalty = "turns"` becomes `[orchestrator.advantage.length_penalty]\ntype = "turns"`. (2026-05-06)
- **`orchestrator.advantage.length_shaping` → `orchestrator.advantage.length_penalty`**: The boolean `length_shaping` flag has been replaced by `length_penalty: Literal["tokens", "turns"] | None` (default: `None`). `length_shaping = true` becomes `length_penalty = "tokens"`; `length_shaping = false` becomes `length_penalty = None`. The new `"turns"` option applies the same correctness-gated efficiency shaping using trajectory turn count instead of completion-token count. (2026-05-01)
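For illustration, a sketch of a post-migration config covering these entries in TOML. The table paths mirror the field paths quoted above; every value shown is a stated default or a rename called out in the entries, not a recommendation:

```toml
[trainer.loss]
dppo_diff_low = 0.2    # renamed from dppo_mask_low
dppo_diff_high = 0.2   # renamed from dppo_mask_high
adv_tau = 1.0
teacher_tau = 0.0
kl_tau = 0.0           # 0.0 keeps the re-added KL penalty off; set > 0 to opt back in

[orchestrator.advantage.length_penalty]
type = "tokens"              # replaces length_penalty = "tokens" / length_shaping = true
completion_weight = 1.0      # default weight for completion tokens
tool_response_weight = 1.0   # default weight for tool-response tokens
```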
7 changes: 7 additions & 0 deletions packages/prime-rl-configs/src/prime_rl/configs/trainer.py
@@ -696,6 +696,13 @@ class DefaultLossConfig(BaseModel):
] = 0.2
adv_tau: Annotated[float, Field(ge=0, description="The tau for advantages.")] = 1.0
teacher_tau: Annotated[float, Field(ge=0, description="The tau for teacher logprobs.")] = 0.0
kl_tau: Annotated[
float,
Field(
ge=0,
description="The tau for the KL penalty (squared log importance ratio), applied only over unmasked tokens.",
),
] = 0.0


class SFTLossConfig(BaseModel):
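A minimal usage sketch for the new field. The import path is inferred from the file location above, and it assumes the remaining `DefaultLossConfig` fields also carry defaults:

```python
# Hypothetical usage sketch; path and defaults are assumptions, not part of the diff.
from prime_rl.configs.trainer import DefaultLossConfig

cfg = DefaultLossConfig()            # kl_tau defaults to 0.0, i.e. KL penalty off
cfg = DefaultLossConfig(kl_tau=0.1)  # opt back in to the Kimi-K2.5-style penalty

# Field(ge=0) makes pydantic reject negative values at construction time.
try:
    DefaultLossConfig(kl_tau=-0.1)
except ValueError as err:            # pydantic's ValidationError subclasses ValueError
    print(err)
```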
14 changes: 11 additions & 3 deletions src/prime_rl/trainer/rl/loss.py
@@ -111,8 +111,14 @@ def default_loss_fn(inputs: LossInputs, loss_config: DefaultLossConfig) -> LossO
Token-level masked importance sampling: tokens whose probability
difference π_train - π_infer falls outside [-dppo_diff_low, dppo_diff_high]
are dropped (gradient set to 0), not clipped. The mask is symmetric
(not advantage-conditioned). No KL penalty — the double-sided
difference mask is what keeps the update inside the trust region.
(not advantage-conditioned). The double-sided difference mask is what
keeps the update inside the trust region.

An optional Kimi-K2.5-style KL penalty (https://arxiv.org/pdf/2602.02276),
scaled by `kl_tau`, is applied only over unmasked (kept) tokens — the
masked tokens are already dropped from the policy gradient and including
them in the KL would push gradient through positions we explicitly chose
to ignore.
"""
trainer_logprobs = inputs.trainer_logprobs
inference_logprobs = inputs.inference_logprobs
@@ -138,7 +144,8 @@ def default_loss_fn(inputs: LossInputs, loss_config: DefaultLossConfig) -> LossO
teacher_kl = None

pg_loss = keep_mask * advantages * importance_ratio
loss = (-pg_loss).sum()
kl_loss = keep_mask * log_importance_ratio**2
loss = (-pg_loss + loss_config.kl_tau * kl_loss).sum()

metrics = {
"mismatch_kl": _safe_mean(mismatch_kl, loss_mask), # all trainable tokens
@@ -147,6 +154,7 @@ def default_loss_fn(inputs: LossInputs, loss_config: DefaultLossConfig) -> LossO
"is_masked": _safe_mean(is_masked, loss_mask),
"is_masked_low": _safe_mean(is_masked_low, loss_mask),
"is_masked_high": _safe_mean(is_masked_high, loss_mask),
"kl_penalty": _safe_mean(log_importance_ratio**2, keep_mask),
}
if teacher_kl is not None:
metrics["teacher_kl"] = _safe_mean(teacher_kl, loss_mask)
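To make the docstring and the hunks above easier to follow, here is a self-contained sketch of the loss computation. The function name, tensor shapes, and the exact ratio and masking conventions are assumptions; the real `default_loss_fn` also handles `loss_mask`, teacher logprobs, and the logged metrics:

```python
# Standalone sketch of the masked importance-sampling loss described above.
import torch


def sketch_default_loss(
    trainer_logprobs: torch.Tensor,    # [T], log pi_train per token, requires grad
    inference_logprobs: torch.Tensor,  # [T], log pi_infer per token, no grad
    advantages: torch.Tensor,          # [T]
    dppo_diff_low: float = 0.2,
    dppo_diff_high: float = 0.2,
    kl_tau: float = 0.0,
) -> torch.Tensor:
    log_importance_ratio = trainer_logprobs - inference_logprobs
    importance_ratio = log_importance_ratio.exp()

    # Double-sided difference mask on probabilities: drop (do not clip) tokens
    # whose pi_train - pi_infer falls outside [-dppo_diff_low, dppo_diff_high],
    # regardless of the advantage sign.
    prob_diff = trainer_logprobs.exp() - inference_logprobs.exp()
    keep_mask = ((prob_diff >= -dppo_diff_low) & (prob_diff <= dppo_diff_high)).float()

    # Policy-gradient term over kept tokens only.
    pg_loss = keep_mask * advantages * importance_ratio

    # Optional Kimi-K2.5-style KL penalty (squared log importance ratio), also
    # restricted to kept tokens so no gradient flows through positions the
    # difference mask already dropped.
    kl_loss = keep_mask * log_importance_ratio**2

    return (-pg_loss + kl_tau * kl_loss).sum()
```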