feat(trainer): KL penalty over unmasked tokens only #2458
Draft
samsja wants to merge 1 commit into
Conversation
Re-add the Kimi-K2.5 KL term (`log_importance_ratio**2` scaled by `kl_tau`) that the symmetric DPPO-Binary TV refactor dropped, but apply it over `keep_mask` rather than `loss_mask`. Tokens outside the trust region are already excluded from the policy gradient; putting the KL on them too would push gradient through positions we explicitly chose to ignore, which is what the previous run was doing wrong. Default `kl_tau = 0.0`, so existing configs are unchanged. Adds a `kl_penalty` metric averaged over kept tokens.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
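For reference, a minimal sketch of the masking described above, assuming per-token log-prob tensors and a boolean `keep_mask` (the loss tokens that survive the trust-region filter); the function name and signature are illustrative, not the actual `loss.py` API:

```python
import torch

def kl_over_kept_tokens(
    logprobs: torch.Tensor,      # per-token log-probs under the training policy
    old_logprobs: torch.Tensor,  # per-token log-probs under the inference policy
    keep_mask: torch.Tensor,     # loss tokens *inside* the trust region
    kl_tau: float = 0.0,         # 0.0 keeps the penalty off, matching existing configs
) -> tuple[torch.Tensor, torch.Tensor]:
    log_importance_ratio = logprobs - old_logprobs
    # Square the log-ratio only where keep_mask is set: out-of-region tokens
    # contribute neither to the penalty nor to its gradient.
    kl_sq = (log_importance_ratio ** 2) * keep_mask
    kl_penalty = kl_sq.sum() / keep_mask.sum().clamp(min=1)
    # The scaled term is added to the loss; the detached mean is what gets
    # logged as `kl_penalty`.
    return kl_tau * kl_penalty, kl_penalty.detach()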
Summary
- Re-adds the KL term (`log_importance_ratio**2` scaled by `kl_tau`) on top of the symmetric DPPO-Binary TV loss.
- Applies it over `keep_mask` (unmasked / kept tokens) instead of the full `loss_mask`. The trust-region mask already drops out-of-region tokens from the policy gradient; putting the KL on them too pushes gradient through positions we explicitly chose to ignore, which is what the current run is doing.
- Defaults to `kl_tau = 0.0` (off) so existing configs are unchanged; set it explicitly to re-enable (see the config sketch after this list).
- Adds a `kl_penalty` metric averaged over kept tokens for monitoring.
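The re-added field, as a hedged sketch; `DefaultLossConfig` is the class named in the Files section below, but the base class and validation shown here are assumptions (a pydantic-style config):

```python
from pydantic import BaseModel, Field

class DefaultLossConfig(BaseModel):
    # Weight on the squared log-importance-ratio penalty over kept tokens.
    # 0.0 (the default) disables the term, so existing configs are unchanged.
    kl_tau: float = Field(default=0.0, ge=0.0)
```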
Why this matters

Before this branch's symmetric refactor, main applied `kl_loss = loss_mask * log_importance_ratio**2`. After the refactor the KL was removed entirely. The current run wants the KL back, but using `loss_mask` is incorrect when we have a non-trivial trust-region mask: the masked-out tokens (large π_train − π_infer divergence outside the band) would dominate the squared log-ratio and contribute the bulk of the KL, exactly where we don't want gradient.
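To make the failure mode concrete, a before/after sketch with synthetic tensors (the 0.5 band threshold is made up for illustration):

```python
import torch

log_importance_ratio = torch.randn(4, 8)  # per-token log(pi_train / pi_infer)
loss_mask = torch.ones(4, 8)              # all positions count toward the loss
# keep_mask: loss tokens whose log-ratio stays inside the trust-region band
keep_mask = (log_importance_ratio.abs() < 0.5).float() * loss_mask

# main (pre-refactor): KL over every loss token, including the out-of-band
# ones, which by construction have the largest squared log-ratio.
kl_old = (loss_mask * log_importance_ratio ** 2).sum() / loss_mask.sum().clamp(min=1)

# this PR: KL over kept tokens only, so out-of-band positions receive no
# gradient from either the policy term or the KL term.
kl_new = (keep_mask * log_importance_ratio ** 2).sum() / keep_mask.sum().clamp(min=1)
```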
Files

- `src/prime_rl/trainer/rl/loss.py`: KL term added with `keep_mask`; metric exposed.
- `packages/prime-rl-configs/src/prime_rl/configs/trainer.py`: `kl_tau` field (default `0.0`) re-added to `DefaultLossConfig`.
- `CHANGELOG.md`: entry for the new field.

🤖 Generated with Claude Code