Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

Documenting **breaking** configuration changes — renamed, removed, or moved fields that require users to update existing configs.

- **`trainer.loss` (default loss)**: Default IPO+KL loss now uses **double-sided importance-ratio masking** (advantage-sign-agnostic) instead of advantage-conditioned probability-difference masking, and the Kimi-K2.5 KL penalty is **only applied to kept (unmasked) tokens**. `dppo_mask_low` and `dppo_mask_high` are now interpreted as importance-ratio bounds (was: probability-difference thresholds). `dppo_mask_high` default changed from `0.2` → `5.0` to match the new ratio semantics; `dppo_mask_low` default stays at `0.2`. `kl_tau`, `adv_tau`, `teacher_tau` unchanged. (2026-05-02)
- **`orchestrator.advantage.length_shaping` → `orchestrator.advantage.length_penalty`**: The boolean `length_shaping` flag has been replaced by `length_penalty: Literal["tokens", "turns"] | None` (default: `None`). `length_shaping = true` becomes `length_penalty = "tokens"`; `length_shaping = false` becomes `length_penalty = None`. The new `"turns"` option applies the same correctness-gated efficiency shaping using trajectory turn count instead of completion-token count. (2026-05-01)
- **`AdvantageInputs` API**: Replaced the `rewards`/`completion_lengths`/`num_turns` tensor fields with a single `rollouts: list[list[vf.RolloutOutput]]` (grouped by problem). Custom advantage functions can now access any rollout metadata. Existing custom advantages must update their signatures and extract per-rollout fields directly (e.g. `torch.tensor([[r["reward"] for r in g] for g in inputs.rollouts])`). (2026-05-01)
- **`orchestrator.teacher_rollout_model` now requires `orchestrator.use_sft_loss = true`**: External teacher rollout configs no longer rely on `trainer.loss.type = "sft"` to select SFT loss. Existing hard-distill configs must set `orchestrator.use_sft_loss = true` alongside `orchestrator.teacher_rollout_model`; the orchestrator validates the pair and stamps training samples with the per-batch SFT loss bool. (2026-04-26)
Expand Down
10 changes: 8 additions & 2 deletions src/prime_rl/configs/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,8 +653,14 @@ class DefaultLossConfig(BaseModel):

type: Literal["default"] = "default"

dppo_mask_low: Annotated[float, Field(ge=0, description="The low threshold for masking tokens.")] = 0.2
dppo_mask_high: Annotated[float, Field(ge=0, description="The high threshold for masking tokens.")] = 0.2
dppo_mask_low: Annotated[
float,
Field(ge=0, description="Lower bound α on the importance ratio. Tokens below are dropped from PG and KL."),
] = 0.2
dppo_mask_high: Annotated[
float,
Field(ge=0, description="Upper bound β on the importance ratio. Tokens above are dropped from PG and KL."),
] = 5.0
adv_tau: Annotated[float, Field(ge=0, description="The tau for advantages.")] = 1.0
teacher_tau: Annotated[float, Field(ge=0, description="The tau for teacher logprobs.")] = 0.0
kl_tau: Annotated[float, Field(ge=0, description="The tau for KL divergence.")] = 1e-3
Expand Down
31 changes: 12 additions & 19 deletions src/prime_rl/trainer/rl/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,38 +106,31 @@ def _safe_mean(values: Tensor, mask: Tensor) -> Tensor:

def default_loss_fn(inputs: LossInputs, loss_config: DefaultLossConfig) -> LossOutputs:
"""
DPPO+KL loss, combining:
IPO+KL loss, combining:
- DPPO-Binary TV Loss (https://arxiv.org/pdf/2602.04879)
- Kimi-K2.5 KL Loss (https://arxiv.org/pdf/2602.02276)

The mask is conditioned on the advantage sign: for positive advantages,
we mask tokens whose probability increased too much (trust region violation
in the upweight direction); for negative advantages, we mask tokens whose
probability decreased too much (trust region violation in the downweight
direction).
Two modifications vs the original DPPO+KL:
- Double-sided importance-ratio masking (advantage-sign-agnostic): tokens
whose ratio π_train/π_infer falls outside [α, β] are dropped.
- The KL penalty is applied only on kept (unmasked) tokens — masked tokens
contribute nothing to either the PG or the KL term.
"""
trainer_logprobs = inputs.trainer_logprobs
inference_logprobs = inputs.inference_logprobs
teacher_logprobs = inputs.teacher_logprobs
advantages = inputs.advantages
loss_mask = inputs.loss_mask

trainer_probs = torch.exp(trainer_logprobs)
inference_probs = torch.exp(inference_logprobs)
probs_diff = trainer_probs - inference_probs
dppo_invalid_mask_high = probs_diff > loss_config.dppo_mask_high
dppo_invalid_mask_low = probs_diff < -loss_config.dppo_mask_low
dppo_invalid_mask = torch.where(advantages > 0, dppo_invalid_mask_high, dppo_invalid_mask_low)

is_masked = dppo_invalid_mask
is_masked_high = (advantages > 0) & dppo_invalid_mask_high
is_masked_low = (advantages < 0) & dppo_invalid_mask_low
keep_mask = loss_mask & ~is_masked

log_importance_ratio = trainer_logprobs - inference_logprobs
importance_ratio = torch.exp(log_importance_ratio)
mismatch_kl = importance_ratio - log_importance_ratio - 1

is_masked_low = importance_ratio.detach() < loss_config.dppo_mask_low
is_masked_high = importance_ratio.detach() > loss_config.dppo_mask_high
is_masked = is_masked_low | is_masked_high
keep_mask = loss_mask & ~is_masked

advantages = loss_config.adv_tau * advantages
if teacher_logprobs is not None:
teacher_kl = teacher_logprobs - trainer_logprobs
Expand All @@ -146,7 +139,7 @@ def default_loss_fn(inputs: LossInputs, loss_config: DefaultLossConfig) -> LossO
teacher_kl = None

pg_loss = keep_mask * advantages * importance_ratio
kl_loss = loss_mask * log_importance_ratio**2
kl_loss = keep_mask * log_importance_ratio**2
loss = (-pg_loss + loss_config.kl_tau * kl_loss).sum()

metrics = {
Expand Down
Loading