
feat(trainer): KL penalty over unmasked tokens only #2458

Draft
samsja wants to merge 1 commit into feat/dppo-diff-default-loss from feat/dppo-kl-on-unmasked

Conversation


@samsja samsja commented May 9, 2026

Summary

  • Re-add the Kimi-K2.5 KL term (log_importance_ratio**2 scaled by kl_tau) on top of the symmetric DPPO-Binary TV loss.
  • Apply the KL over keep_mask (unmasked / kept tokens) instead of the full loss_mask. The trust-region mask already drops out-of-region tokens from the policy gradient; putting the KL on them too pushes gradient through positions we explicitly chose to ignore — which is exactly what the current run does. See the sketch after this list.
  • Default kl_tau = 0.0 (off) so existing configs are unchanged; set explicitly to re-enable.
  • Add a kl_penalty metric averaged over kept tokens for monitoring.
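
A minimal sketch of the new term, assuming hypothetical names (`kl_penalty_term`, and the `log_importance_ratio` / `keep_mask` / `kl_tau` tensors) that stand in for the actual variables in `loss.py`; the real implementation may differ:

```python
import torch

def kl_penalty_term(
    log_importance_ratio: torch.Tensor,  # (batch, seq): log(pi_train / pi_infer) per token
    keep_mask: torch.Tensor,             # (batch, seq): 1.0 for tokens inside the trust region
    kl_tau: float,                       # 0.0 disables the term (the new default)
) -> tuple[torch.Tensor, torch.Tensor]:
    """Kimi-K2.5-style KL term applied only to kept (unmasked) tokens."""
    sq = log_importance_ratio**2
    kept = keep_mask.sum().clamp(min=1)         # avoid division by zero if everything is masked
    kl_penalty = (sq * keep_mask).sum() / kept  # logged as the `kl_penalty` metric
    kl_loss = kl_tau * kl_penalty               # zero contribution when kl_tau == 0.0
    return kl_loss, kl_penalty.detach()
```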

Why this matters

Before this branch's symmetric refactor, main applied kl_loss = loss_mask * log_importance_ratio**2. After the refactor the KL was removed entirely. The current run wants the KL back, but computing it over loss_mask is incorrect when the trust-region mask is non-trivial: masked tokens (those whose π_train/π_infer log-ratio falls outside the band) would dominate the squared log-ratio and contribute the bulk of the KL — exactly the positions where we don't want gradient. The toy example below makes this concrete.
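
A toy illustration (numbers invented for exposition) of how a single out-of-band token swamps a loss_mask-averaged KL, while the keep_mask version ignores it:

```python
import torch

# Two in-band tokens with small log-ratios, one out-of-band outlier —
# exactly the kind of token the trust-region mask drops.
log_ratio = torch.tensor([0.1, -0.2, 3.0])
loss_mask = torch.tensor([1.0, 1.0, 1.0])  # all response tokens
keep_mask = torch.tensor([1.0, 1.0, 0.0])  # trust region excludes the outlier

sq = log_ratio**2
print((sq * loss_mask).sum() / loss_mask.sum())  # ~3.02, dominated by the outlier
print((sq * keep_mask).sum() / keep_mask.sum())  # 0.025, in-region tokens only
```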

Files

  • src/prime_rl/trainer/rl/loss.py — KL term added with keep_mask; metric exposed.
  • packages/prime-rl-configs/src/prime_rl/configs/trainer.py — kl_tau field (default 0.0) re-added to DefaultLossConfig (a sketch follows this list).
  • CHANGELOG.md — entry for the new field.
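
For reference, a hedged sketch of what the re-added field could look like; the actual DefaultLossConfig in trainer.py may use a different base class or field metadata:

```python
from pydantic import BaseModel, Field

class DefaultLossConfig(BaseModel):
    # ... existing DPPO-Binary TV loss fields ...

    kl_tau: float = Field(
        default=0.0,  # off by default, so existing configs are unchanged
        ge=0.0,
        description=(
            "Scale for the squared-log-ratio KL penalty over kept tokens. "
            "Set > 0 to re-enable the Kimi-K2.5 term."
        ),
    )
```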

🤖 Generated with Claude Code

Re-add the Kimi-K2.5 KL term (log_importance_ratio**2 scaled by kl_tau)
that the symmetric DPPO-Binary TV refactor dropped, but apply it over
keep_mask rather than loss_mask. Tokens outside the trust region are
already excluded from the policy gradient — putting the KL on them too
would push gradient through positions we explicitly chose to ignore,
which is what the previous run was doing wrong.

Default kl_tau = 0.0 so existing configs are unchanged. Adds a
`kl_penalty` metric averaged over kept tokens.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>