make the loss equal weight per unmasked token #1961

Merged: hallerite merged 6 commits into PrimeIntellect-ai:main from sapiosaturn:dist-sft-fixes on Mar 16, 2026

@sapiosaturn (Contributor) commented Mar 5, 2026

warning: not tested yet, do not merge yet.

There is a likely algorithmic weighting bug in PrimeRL SFT. Each rank computes a masked-token loss mean on its local shard (loss[loss_mask].mean()), backpropagates that scalar divided only by grad_accum_steps, and then relies on FSDP/HSDP gradient averaging over the hsdp mesh, which includes context-parallel ranks when model.cp > 1. The optimized objective is therefore an equal-weight average of per-rank, per-microstep local means, not a global mean over all unmasked tokens in the step.

This is harmless only in the special case where every participating rank and every gradient-accumulation microstep has the same number of supervised tokens. That is roughly true for pretraining-style dense loss but not for SFT, because loss_mask is sparse and variable, and because the default SFT packing mode is pack_function = "cat" in prime_rl/configs/sft.py, with CatDataset in prime_rl/trainer/sft/data.py concatenating variable numbers of examples into fixed-length windows. As a result, windows, DP ranks, and GA microsteps can contain very different counts of trainable tokens yet still contribute equal optimization weight.

Context parallelism makes this more severe: the same local-mean reduction now happens on half-sequences or smaller sequence shards, so token weighting depends on how supervised tokens happen to be distributed across CP partitions. CP SFT therefore optimizes a different objective from both globally token-normalized SFT and non-CP SFT.
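To make the weighting difference concrete, here is a minimal pure-Python sketch. It is not PrimeRL code: the helper names and the per-token loss values are made up for illustration, and each inner list stands in for the unmasked-token losses seen by one rank (or one microstep).

```python
# Illustrative only: hypothetical helpers and made-up loss values.

def mean_of_local_means(shards):
    """Equal-weight average of per-shard means (what per-rank .mean() plus averaging yields)."""
    local_means = [sum(losses) / len(losses) for losses in shards]
    return sum(local_means) / len(local_means)


def global_token_mean(shards):
    """Mean over every unmasked token in the step, regardless of which shard holds it."""
    total_loss = sum(sum(losses) for losses in shards)
    total_tokens = sum(len(losses) for losses in shards)
    return total_loss / total_tokens


# Shard 0 holds 1 supervised token, shard 1 holds 4.
unbalanced = [[4.0], [1.0, 1.0, 1.0, 1.0]]
print(mean_of_local_means(unbalanced))  # the lone token on shard 0 gets weight 1/2
print(global_token_mean(unbalanced))    # every token gets weight 1/5

# With equal token counts per shard the two objectives coincide.
balanced = [[4.0, 2.0], [1.0, 1.0]]
print(mean_of_local_means(balanced), global_token_mean(balanced))
```

In the unbalanced case the single token on shard 0 carries as much weight as the four tokens on shard 1 combined, which is the skew described above.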


Note

Medium Risk
Changes the effective training objective and gradient scaling across distributed meshes, which can materially affect convergence/metrics. Touches core training-loop math (DP/CP/FSDP), so any mismatch in token counting or scaling could destabilize training.

Overview
Fixes SFT loss/gradient weighting to be per unmasked token globally instead of averaging per-rank local means.

compute_loss now returns (loss_sum, token_count) over unmasked tokens, validation aggregates these across the dp_cp mesh to log a token-weighted mean loss, and the training loop accumulates per-microstep loss sums then rescales gradients after backprop using the global token count (undoing FSDP’s fsdp_gradient_divide_factor) so the optimized objective matches a true global token mean under DP+CP+GA.

Written by Cursor Bugbot for commit a8fdc13.
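The rescaling described in the overview can be sketched numerically. This is a toy model, not the merged implementation: gradients are plain floats, fsdp_reduce is a hypothetical stand-in for FSDP's reduction (sum across ranks divided by a divide factor, assumed here to equal the number of ranks), and the token count is assumed to have already been summed across the mesh.

```python
# Toy model of "backprop the loss sum, then rescale by the global token count".
# All names below are hypothetical; only the arithmetic is the point.

def fsdp_reduce(local_grads, divide_factor):
    """Stand-in for FSDP gradient reduction: sum over ranks, divided by divide_factor."""
    return sum(local_grads) / divide_factor


def rescaled_grad(per_rank_token_grads, divide_factor):
    # Each rank backpropagates its loss *sum*, so its local gradient is the
    # sum of its per-token gradients (no local normalization).
    local_grads = [sum(token_grads) for token_grads in per_rank_token_grads]

    # Token count summed over all ranks (an all-reduce in a real run).
    global_step_token_count = sum(len(g) for g in per_rank_token_grads)

    reduced = fsdp_reduce(local_grads, divide_factor)

    # Undo the divide factor, then normalize by the global token count.
    grad_scale = divide_factor / global_step_token_count
    return reduced * grad_scale


per_rank = [[4.0], [1.0, 1.0, 1.0, 1.0]]
result = rescaled_grad(per_rank, divide_factor=2)
expected = sum(sum(g) for g in per_rank) / 5  # gradient of the global token mean
print(result, expected)
```

Under these assumptions the rescaled gradient matches the gradient of the global token mean; gradient accumulation would add the microstep sums and counts before the same rescale.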

@sapiosaturn (Contributor, Author)

manual gradient div is very ugly, but should work. need to test

@samsja marked this pull request as ready for review March 8, 2026 02:49
Comment thread src/prime_rl/trainer/sft/train.py Outdated
Comment thread src/prime_rl/trainer/sft/train.py Outdated
@hallerite (Member)

Thanks @sapiosaturn! I made some changes & resolved merge conflicts.
[image]
Purple is prime-rl main & blue is this PR, so lgtm!

@sapiosaturn (Contributor, Author)

awesome, would test at high batch size (maybe with grad acc) + large seq len just to make sure it's not too numerically unstable with all the summation haha

@cursor (Bot) left a comment

Cursor Bugbot has reviewed your changes and found 1 potential issue.

with maybe_record_function("forward"):
    local_loss_sum, local_token_count = compute_loss(micro_batch)

step_local_token_count += local_token_count

NaN micro-step tokens inflate gradient scaling denominator

Medium Severity

step_local_token_count += local_token_count runs unconditionally before the NaN check, so tokens from NaN micro-steps are included in global_step_token_count. Since NaN micro-steps contribute zero gradients (via nan_to_num), the inflated denominator in grad_scale under-scales valid gradients. The eval loop (run_eval_loop) correctly only adds token_count inside the non-NaN branch — the training loop is inconsistent with that.
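A small sketch of the counting issue, under stated assumptions: this is not the training loop from train.py, the helper name and the (loss_sum, token_count) pairs are invented, and the only behavior modeled is whether tokens from NaN micro-steps enter the denominator.

```python
import math

# Hypothetical helper: micro_steps is a list of (loss_sum, token_count) pairs
# for one rank within a single optimizer step.

def step_token_count(micro_steps, count_nan_steps):
    total = 0
    for loss_sum, token_count in micro_steps:
        if math.isnan(loss_sum) and not count_nan_steps:
            # NaN micro-step: its gradients are zeroed, so its tokens
            # should not enter the scaling denominator either.
            continue
        total += token_count
    return total


micro_steps = [(6.0, 3), (float("nan"), 5)]
valid_loss_sum = 6.0  # only the first micro-step contributes gradient signal

inflated = step_token_count(micro_steps, count_nan_steps=True)    # counts NaN tokens
corrected = step_token_count(micro_steps, count_nan_steps=False)  # skips NaN tokens

print(valid_loss_sum / inflated)   # under-scaled: valid signal divided by too many tokens
print(valid_loss_sum / corrected)  # mean over the tokens that actually contributed
```

Moving the count update inside the non-NaN branch, as the eval loop is said to do, corresponds to count_nan_steps=False here.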

@samsja (Member) left a comment

lgtm

@hallerite merged commit 6c55a9a into PrimeIntellect-ai:main Mar 16, 2026
19 of 21 checks passed