feat(sft): default loss_mask to renderer's sampled_mask #2644
The existing LossMaskConfig was per-role booleans, AND'd against the renderer's sampled_mask in build_training_sample. With renderer fixes landing in PrimeIntellect-ai/renderers#66, the per-token sampled_mask is now the authoritative "what the model would produce at inference" signal, including chat-template stop tokens (e.g. GLM's <|user|> / <|observation|> after a tool-calling assistant turn) whose structural attribution lives in the *next* message's span. AND-ing those with a role filter that says "assistant only" silently masks them out of the loss, which is exactly the bug that caused GLM-4.5-Air SFT to dump 50-283 parallel tool_calls per turn on SWE-bench Verified.

Refactor LossMaskConfig to a discriminated union mirroring prime-rl's existing ``Literal | dict-class`` pattern (cf. ``fused_lm_head_token_chunk_size: int | Literal["auto", "disabled"]``):

- ``"sampled"`` (new default, recommended with renderers): every token the renderer marks ``is_sampled=True`` is trainable, regardless of role. Stop tokens whose attribution lives in the next message's span are correctly trained.
- ``"all"``: every renderer token contributes to the loss. Intended for debugging.
- ``LossMaskRolesConfig`` (renamed from the old per-role LossMaskConfig): AND the renderer's ``is_sampled`` with per-role booleans. Strict opt-in for callers who want to restrict supervision to a subset of roles even when other roles' tokens are model-sampled.

Existing TOML configs with ``[data.loss_mask] assistant = true`` parse as ``LossMaskRolesConfig`` via the union, so the old behaviour is preserved. Configs that don't override ``loss_mask`` pick up the new ``"sampled"`` default, which is the fix.

The chat-template fallback path (no renderer) has no ``sampled_mask`` signal, so it always uses ``LossMaskRolesConfig()`` defaults regardless of the configured mode.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
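As a rough illustration of the three config shapes described in this commit message, the sketch below resolves a raw config value into either a sentinel string or a per-role config. It is a stdlib-only stand-in, not the pydantic union used in prime-rl: `parse_loss_mask` is a hypothetical helper, and the field defaults are taken from the assistant-only defaults stated in the PR description.

```python
from dataclasses import dataclass
from typing import Literal, Union


@dataclass(frozen=True)
class LossMaskRolesConfig:
    # Defaults mirror the assistant-only behaviour described in the PR.
    system: bool = False
    user: bool = False
    assistant: bool = True
    tool: bool = False


LossMaskConfig = Union[Literal["sampled", "all"], LossMaskRolesConfig]


def parse_loss_mask(raw: object = "sampled") -> LossMaskConfig:
    """Hypothetical stand-in for the union dispatch: sentinel strings first,
    then a table of per-role booleans."""
    if isinstance(raw, str):
        if raw in ("sampled", "all"):
            return raw
        raise ValueError(f"unknown loss_mask mode: {raw!r}")
    if isinstance(raw, dict):
        # A TOML table such as ``[data.loss_mask] assistant = true`` arrives
        # as a dict and resolves to the per-role config.
        return LossMaskRolesConfig(**raw)
    raise TypeError(f"unsupported loss_mask value: {raw!r}")


default_cfg = parse_loss_mask()                      # no override: new default
all_cfg = parse_loss_mask("all")                     # debugging mode
legacy_cfg = parse_loss_mask({"assistant": True})    # old-style TOML table
```

Under these assumptions a config that never sets `loss_mask` resolves to `"sampled"`, while an old-style table still resolves to `LossMaskRolesConfig`.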
When a renderer is configured its token stream is authoritative — it mirrors the model's chat template (which carries no EOS by design) and trains the real stop signals via sampled_mask (e.g. GLM's turn-closing <|observation|> / <|user|>). Stop injecting an EOS the renderer didn't emit, and don't assert EOS-in-targets on the renderer path. The chat-template fallback (no sampled_mask) is unchanged: it still appends EOS so the model learns to stop.
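The EOS rule in this commit can be pictured with a small sketch. `finalize_sample` and its parameters are hypothetical names, not the actual `SFTDataset` code, and the "append only if EOS is not already the last token" check is an assumption made for the example.

```python
from typing import List, Tuple


def finalize_sample(
    input_ids: List[int],
    loss_mask: List[bool],
    eos_token_id: int,
    used_renderer: bool,
) -> Tuple[List[int], List[bool]]:
    """Hypothetical post-processing step illustrating the two paths."""
    ids, mask = list(input_ids), list(loss_mask)
    if used_renderer:
        # Renderer stream is authoritative: nothing is injected, and the real
        # stop tokens are already marked trainable through sampled_mask.
        return ids, mask
    # Chat-template fallback: append EOS (assumed: only when it is not already
    # the last token) and train on it so the model learns to stop.
    if not ids or ids[-1] != eos_token_id:
        ids.append(eos_token_id)
        mask.append(True)
    return ids, mask


renderer_out = finalize_sample([5, 6, 7], [False, True, True], eos_token_id=0, used_renderer=True)
fallback_out = finalize_sample([5, 6, 7], [False, True, True], eos_token_id=0, used_renderer=False)
```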
STACKING_DATASET_BUCKET_TIMEOUT = 10

def _always_true_role_filter(message: dict) -> bool:
    return True

def _role_filter_for(cfg: LossMaskRolesConfig):

# boolean lookup. The chat-template fallback below has no
# ``is_sampled`` signal so we treat sentinel modes as the
# ``LossMaskRolesConfig()`` defaults (assistant-only).
loss_mask_cfg = self.loss_mask_config
I feel like all this logic, plus the always_true_role_filter, should be in role_filter_for which just returns a tuple of the renderer_role_filter and fallback_role_filter. I found this split in logic between the role filter helper and this dataset quite confusing.
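One possible shape for the suggestion above, as a sketch only: a single helper that returns both filters as a tuple. `role_filters_for` and `_roles_filter` are hypothetical names, `LossMaskRolesConfig` is re-declared here as a plain dataclass so the snippet is self-contained, and the choice to reuse the per-role filter on both paths is an assumption. The sentinel branches follow the behaviour in the diff, where the fallback always uses `LossMaskRolesConfig()` defaults.

```python
from dataclasses import dataclass
from typing import Callable, Optional, Tuple, Union

RoleFilter = Callable[[dict], bool]


@dataclass(frozen=True)
class LossMaskRolesConfig:
    system: bool = False
    user: bool = False
    assistant: bool = True
    tool: bool = False


def _always_true_role_filter(message: dict) -> bool:
    return True


def _roles_filter(cfg: LossMaskRolesConfig) -> RoleFilter:
    def role_filter(message: dict) -> bool:
        # Unknown roles are treated as not trainable.
        return bool(getattr(cfg, message["role"], False))

    return role_filter


def role_filters_for(
    loss_mask: Union[str, LossMaskRolesConfig],
) -> Tuple[Optional[RoleFilter], RoleFilter]:
    """Return (renderer_role_filter, fallback_role_filter)."""
    default_fallback = _roles_filter(LossMaskRolesConfig())
    if loss_mask == "sampled":
        # None lets the renderer path rely on sampled_mask alone.
        return None, default_fallback
    if loss_mask == "all":
        return _always_true_role_filter, default_fallback
    per_role = _roles_filter(loss_mask)
    return per_role, per_role


renderer_filter, fallback_filter = role_filters_for("sampled")
```

The dataset would then only unpack the tuple, keeping every mode decision in one place.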
Cursor Bugbot has reviewed your changes and found 1 potential issue.
Reviewed by Cursor Bugbot for commit de439e3.
    fallback_role_filter = _role_filter_for(LossMaskRolesConfig())
elif loss_mask_cfg == "all":
    renderer_role_filter = _always_true_role_filter
    fallback_role_filter = _role_filter_for(LossMaskRolesConfig())
"all" fallback silently restricts loss to assistant-only tokens
Medium Severity
When loss_mask = "all" and no renderer is configured, fallback_role_filter is set to _role_filter_for(LossMaskRolesConfig()), which defaults to assistant-only masking. The "all" mode is documented as training on "every…token…regardless of role," but this fallback silently contradicts that. The chat-template build_incremental_token_mask uses role_to_mask directly (no is_sampled dependency), so _always_true_role_filter would correctly honor the "all" semantic in the fallback path.
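The difference Bugbot describes can be seen at the message level with a toy example. This is not the real `build_incremental_token_mask` (its signature is not shown in this PR); `message_level_mask`, `assistant_only`, and `always_true` are hypothetical helpers that only show which messages each `role_to_mask` callable would keep.

```python
from typing import Callable, Dict, List

Message = Dict[str, str]


def assistant_only(message: Message) -> bool:
    # Equivalent to the LossMaskRolesConfig() defaults described in the PR.
    return message["role"] == "assistant"


def always_true(message: Message) -> bool:
    return True


def message_level_mask(
    messages: List[Message], role_to_mask: Callable[[Message], bool]
) -> List[bool]:
    """Toy stand-in: one mask entry per message instead of per token."""
    return [role_to_mask(m) for m in messages]


conversation = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "4"},
    {"role": "tool", "content": "ok"},
]

# Behaviour flagged in the review: "all" without a renderer uses role defaults.
current_all_fallback = message_level_mask(conversation, assistant_only)
# Behaviour Bugbot suggests: honour "all" in the fallback path as well.
suggested_all_fallback = message_level_mask(conversation, always_true)
```

With the assistant-only filter only the third message is trainable; with the always-true filter all four are, which matches the documented meaning of `"all"`.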
@@ -246,15 +290,15 @@ def should_mask(message: dict) -> bool:
input_ids, loss_mask = build_training_sample(
I dislike how different the names build_training_sample and build_incremental_token_mask are, given that they do the same thing, just with a different backend. Not necessarily important for this PR, but I didn't want to forget it.
…default

The loss_mask="sampled" path passes role_to_mask=None to build_training_sample, which only falls back to the renderer's sampled_mask as of renderers#66 (1d69232). The branch previously still pinned 2ec28a8, where role_to_mask is a required positional, so the feature was broken against its own pinned renderer. Bump the submodule to the exact validated commit.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>


Motivation
Some chat templates put the per-turn stop signal in the next message's span rather than inside the assistant turn itself. GLM-4.5 and GLM-5 are the immediate cases: after a tool-calling assistant message, the token that closes the assistant's turn is the `<|observation|>` (or `<|user|>`) at the start of the next message; there is no `<|im_end|>`-style terminator inside the assistant span. vLLM uses that token as a stop id at inference (`get_stop_token_ids`), but the renderer attributes it structurally to the next (tool / user) message.

The previous SFT loss-mask pipeline AND'd the renderer's `sampled_mask` with a per-role filter (`assistant=True, tool=False, user=False, system=False` by default). For the GLM family this meant: the model would sample `<|observation|>` at inference (`sampled_mask=True`), but it's attributed to the tool message (`role_to_mask(tool)=False`), so it was masked out of the loss and the model never learned to emit it. At inference the natural continuation of `</tool_call>` is another `<tool_call>`, so the model dumps parallel tool_calls instead of closing its turn. In a recent GLM-4.5-Air-Base SFT run, 26% of assistant turns had >1 tool call (max observed: 283 parallel calls in a single turn); on SWE-bench Verified the model's `pass@1` measurably suffered: rollouts with extreme multi-tool dumps solved at roughly half the rate of the bounded ones.

The companion renderer fix (PrimeIntellect-ai/renderers#66) updates `glm45.py` / `glm5.py` to mark that stop opener as `is_sampled=True`. This PR teaches `build_training_sample`'s caller (`SFTDataset`) to actually use that signal, by making the `sampled_mask`-only path the default rather than always AND-ing with a role filter.

Changes
Refactor `LossMaskConfig` from "a class of per-role booleans" to a discriminated union, mirroring prime-rl's existing `Literal | dict-class` pattern (cf. `fused_lm_head_token_chunk_size: int | Literal["auto", "disabled"]` in `configs/trainer.py`):

- `"sampled"` (new default, recommended with renderers): every token the renderer marks `is_sampled=True` is trainable, regardless of role. Stop tokens that the chat template attributes to the next message's span are correctly trained.
- `"all"`: every renderer token contributes to the loss. Mostly useful for debugging.
- `LossMaskRolesConfig` (renamed from the old per-role `LossMaskConfig` body): AND the renderer's `is_sampled` with per-role booleans. Use this when you specifically want to restrict supervision to a subset of roles even when other roles' tokens are model-sampled.

`SFTDataset` dispatches:

- `"sampled"` → `role_to_mask=None` (relies on `build_training_sample`'s new `role_to_mask` default added in renderers#66, which falls back to `sampled_mask`).
- `"all"` → `role_to_mask=lambda m: True`.
- `LossMaskRolesConfig` → callable that looks up the per-role bool, same as before.

The chat-template fallback path (no renderer registered) has no `sampled_mask` signal, so it falls back to `LossMaskRolesConfig()` defaults regardless of the configured mode.

Backward compatibility
Existing TOML configs of the form `[data.loss_mask] assistant = true` still parse: pydantic's union dispatch tries `Literal["sampled", "all"]` first (fails for a dict), then `LossMaskRolesConfig` (succeeds). Per-role behavior is preserved verbatim.

Configs that don't set `loss_mask` at all pick up the new `"sampled"` default. This is a deliberate behavior change for SFT runs that relied on the implicit assistant-only filter, but the new default is what fixes the bug for renderers whose stop signal lives outside the assistant span. For configs that want to keep the old behavior, set `[data.loss_mask] assistant = true` explicitly.

Test plan
- Unit tests (`tests/unit/train/sft`) and orchestrator SFT-trajectory tests pass under both old (per-role) and new (`"sampled"`) configs.
- `"sampled"`, `"all"`, and `{"assistant": true, "tool": true}` all parse to the expected variant.
- `SFTDataConfig().loss_mask == "sampled"`.

Renderer dependency
- renderers#66 marks the `<|user|>` / `<|observation|>` openers as `is_sampled=True` when they close an assistant turn, and makes `role_to_mask` optional in `build_training_sample` (falls back to `sampled_mask`).
- The `"sampled"` path passes `role_to_mask=None`, which only resolves to the `sampled_mask` fallback as of renderers#66. The branch now pins `deps/renderers` at `35c2407` (via the `main` merge), a descendant of renderers#66 (`1d69232`), so the feature works against its pinned renderer.

🤖 Generated with Claude Code
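To make the `role_to_mask=None` fallback concrete, here is a toy model of the masking rule described in this PR. `toy_loss_mask` is a hypothetical function, not the renderers API, and the token strings, roles, and sampled flags are illustrative values chosen to mimic a GLM-style tool call followed by the `<|observation|>` opener.

```python
from typing import Callable, List, Optional, Sequence


def toy_loss_mask(
    sampled_mask: Sequence[bool],
    token_roles: Sequence[str],
    role_to_mask: Optional[Callable[[dict], bool]] = None,
) -> List[bool]:
    """Toy rule: None -> sampled_mask alone; callable -> sampled AND role."""
    if role_to_mask is None:
        return list(sampled_mask)
    return [
        bool(sampled and role_to_mask({"role": role}))
        for sampled, role in zip(sampled_mask, token_roles)
    ]


# Illustrative token stream: the stop opener belongs to the tool message's span.
tokens = ["<|assistant|>", "<tool_call>", "get_weather", "</tool_call>", "<|observation|>", "sunny"]
token_roles = ["assistant", "assistant", "assistant", "assistant", "tool", "tool"]
sampled_mask = [False, True, True, True, True, False]

# Old behaviour: AND with the assistant-only role filter drops the stop opener.
old_mask = toy_loss_mask(sampled_mask, token_roles, lambda m: m["role"] == "assistant")
# New "sampled" default: role_to_mask=None keeps every sampled token.
new_mask = toy_loss_mask(sampled_mask, token_roles)

stop_index = tokens.index("<|observation|>")
```

In this toy setup `old_mask[stop_index]` is `False` and `new_mask[stop_index]` is `True`, which is the behaviour change the `"sampled"` default is meant to deliver.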
Note
Medium Risk
Changes default SFT supervision and token/EOS handling for renderer-based training, which affects what the model learns on multi-turn/tool templates; explicit per-role TOML configs remain compatible.
Overview
SFT loss masking is refactored so renderer runs default to training every token the renderer marks as model-sampled (`"sampled"`), instead of always intersecting with assistant-only role flags.
`LossMaskConfig` becomes `"sampled" | "all" | LossMaskRolesConfig` (per-role booleans renamed to `LossMaskRolesConfig`). `SFTDataset` maps those modes to `role_to_mask=None`, always-true, or per-role filters for `build_training_sample` / the chat-template fallback; without a renderer, `"sampled"` / `"all"` still use assistant-only role defaults because there is no `is_sampled` signal.
Renderer path: EOS is no longer forced onto the stream and the EOS-in-`target_ids` assert applies only to the legacy tokenizer path, so turn-ending tokens from the template (e.g. GLM stop openers on the next message) can be supervised correctly.
Reviewed by Cursor Bugbot for commit 38857d1.