make the loss equal weight per unmaked token#1961
Conversation
|
manual gradient div is very ugly, but should work. need to test |
|
Thanks @sapiosaturn ! I made some changes & resolved merge conflicts. |
|
awesome, would test at high batch size (maybe with grad acc) + large seq len just to make sure it's not too numerically unstable with all the summation haha |
There was a problem hiding this comment.
Cursor Bugbot has reviewed your changes and found 1 potential issue.
Bugbot Autofix is OFF. To automatically fix reported issues with cloud agents, enable autofix in the Cursor dashboard.
| with maybe_record_function("forward"): | ||
| local_loss_sum, local_token_count = compute_loss(micro_batch) | ||
|
|
||
| step_local_token_count += local_token_count |
There was a problem hiding this comment.
NaN micro-step tokens inflate gradient scaling denominator
Medium Severity
step_local_token_count += local_token_count runs unconditionally before the NaN check, so tokens from NaN micro-steps are included in global_step_token_count. Since NaN micro-steps contribute zero gradients (via nan_to_num), the inflated denominator in grad_scale under-scales valid gradients. The eval loop (run_eval_loop) correctly only adds token_count inside the non-NaN branch — the training loop is inconsistent with that.



warning: not tested yet, do not merge yet.
There is a likely algorithmic weighting bug in PrimeRL SFT because each rank computes a masked-token loss mean on its local shard (loss[loss_mask].mean()), backpropagates that scalar divided only by grad_accum_steps, and then relies on FSDP/HSDP gradient averaging over the hsdp mesh, which includes context-parallel ranks when model.cp > 1; that means the optimized objective is effectively an equal-weight average of per-rank, per-microstep local means, not a global mean over all unmasked tokens in the step. This is harmless only in the special case where every participating rank and every gradient-accumulation microstep has the same number of supervised tokens, which is roughly true for pretraining-style dense loss but not for SFT, because loss_mask is sparse and variable, and because the default SFT packing mode is pack_function = "cat" in prime_rl/configs/sft.py, with CatDataset in prime_rl/trainer/sft/data.py concatenating variable numbers of examples into fixed-length windows; as a result, windows, DP ranks, and GA microsteps can contain very different counts of trainable tokens, yet still contribute equal optimization weight. Context parallelism makes this more severe, because the same local-mean reduction now happens on half-sequences or smaller sequence shards, so token weighting depends on how supervised tokens happen to be distributed across CP partitions, making CP SFT optimize a different objective from both globally token-normalized SFT and non-CP SFT.
Note
Medium Risk
Changes the effective training objective and gradient scaling across distributed meshes, which can materially affect convergence/metrics. Touches core training-loop math (DP/CP/FSDP), so any mismatch in token counting or scaling could destabilize training.
Overview
Fixes SFT loss/gradient weighting to be per unmasked token globally instead of averaging per-rank local means.
compute_lossnow returns(loss_sum, token_count)over unmasked tokens, validation aggregates these across thedp_cpmesh to log a token-weighted mean loss, and the training loop accumulates per-microstep loss sums then rescales gradients after backprop using the global token count (undoing FSDP’sfsdp_gradient_divide_factor) so the optimized objective matches a true global token mean under DP+CP+GA.Written by Cursor Bugbot for commit a8fdc13. This will update automatically on new commits. Configure here.