Skip to content

feat(sft): default loss_mask to renderer's sampled_mask#2644

Open
hallerite wants to merge 5 commits into
mainfrom
feat/sft-loss-mask-sampled-default
Open

feat(sft): default loss_mask to renderer's sampled_mask#2644
hallerite wants to merge 5 commits into
mainfrom
feat/sft-loss-mask-sampled-default

Conversation

@hallerite
Copy link
Copy Markdown
Member

@hallerite hallerite commented May 26, 2026

Motivation

Some chat templates put the per-turn stop signal in the next message's span rather than inside the assistant turn itself. GLM-4.5 and GLM-5 are the immediate cases: after a tool-calling assistant message, the byte that closes the assistant's turn is the <|observation|> (or <|user|>) at the start of the next message — there is no <|im_end|>-style terminator inside the assistant span. vLLM uses that token as a stop id at inference (get_stop_token_ids), but the renderer attributes it to the next (tool / user) message structurally.

The previous SFT loss-mask pipeline AND'd the renderer's sampled_mask with a per-role filter (assistant=True, tool=False, user=False, system=False by default). For the GLM family this meant: the model would-sample <|observation|> at inference (sampled_mask=True), but it's attributed to the tool message (role_to_mask(tool)=False), so it was masked out of the loss and the model never learned to emit it. At inference the natural continuation of </tool_call> is another <tool_call>, so the model dumps parallel tool_calls instead of closing its turn. In a recent GLM-4.5-Air-Base SFT run, 26% of assistant turns had >1 tool call (max observed: 283 parallel calls in a single turn); on SWE-bench Verified the model's pass@1 measurably suffered — rollouts with extreme multi-tool dumps solved at roughly half the rate of the bounded ones.

The companion renderer fix (PrimeIntellect-ai/renderers#66) updates glm45.py / glm5.py to mark that stop opener as is_sampled=True. This PR teaches build_training_sample's caller (SFTDataset) to actually use that signal — by making the sampled_mask-only path the default rather than always AND-ing with a role filter.

Changes

Refactor LossMaskConfig from "a class of per-role booleans" to a discriminated union, mirroring prime-rl's existing Literal | dict-class pattern (cf. fused_lm_head_token_chunk_size: int | Literal[\"auto\", \"disabled\"] in configs/trainer.py):

LossMaskConfig: TypeAlias = Literal["sampled", "all"] | LossMaskRolesConfig
  • \"sampled\" (new default, recommended with renderers): every token the renderer marks is_sampled=True is trainable, regardless of role. Stop tokens that the chat template attributes to the next message's span are correctly trained.
  • \"all\": every renderer token contributes to the loss. Mostly useful for debugging.
  • LossMaskRolesConfig (renamed from the old per-role LossMaskConfig body): AND the renderer's is_sampled with per-role booleans. Use this when you specifically want to restrict supervision to a subset of roles even when other roles' tokens are model-sampled.

SFTDataset dispatches:

  • \"sampled\"role_to_mask=None (relies on build_training_sample's new role_to_mask default added in renderers#66, which falls back to sampled_mask).
  • \"all\"role_to_mask=lambda m: True.
  • LossMaskRolesConfig → callable that looks up the per-role bool, same as before.

The chat-template fallback path (no renderer registered) has no sampled_mask signal so it falls back to LossMaskRolesConfig() defaults regardless of the configured mode.

Backward compatibility

Existing TOML configs of the form:

[data.loss_mask]
assistant = true
tool = false

still parse: pydantic's union dispatch tries Literal[\"sampled\", \"all\"] first (fails for a dict), then LossMaskRolesConfig (succeeds). Per-role behavior is preserved verbatim.

Configs that don't set loss_mask at all pick up the new \"sampled\" default. This is a deliberate behavior change for SFT runs that relied on the implicit assistant-only filter — but the new default is what fixes the bug for renderers whose stop signal lives outside the assistant span. For configs that want to keep the old behavior, set [data.loss_mask] assistant = true explicitly.

Test plan

  • Existing SFT unit tests (tests/unit/train/sft) and orchestrator SFT-trajectory tests pass under both old (per-role) and new (\"sampled\") configs.
  • Type-adapter round-trip: \"sampled\", \"all\", and {\"assistant\": true, \"tool\": true} all parse to the expected variant.
  • Default SFTDataConfig().loss_mask == \"sampled\".

Renderer dependency

  • PrimeIntellect-ai/renderers#66merged. Marks GLM <|user|> / <|observation|> openers as is_sampled=True when they close an assistant turn, and makes role_to_mask optional in build_training_sample (falls back to sampled_mask).
  • This dependency is required, not cosmetic: the "sampled" path passes role_to_mask=None, which only resolves to the sampled_mask fallback as of renderers#66. The branch now pins deps/renderers at 35c2407 (via the main merge), a descendant of add seed as optional arg #66 (1d69232), so the feature works against its pinned renderer.

🤖 Generated with Claude Code


Note

Medium Risk
Changes default SFT supervision and token/EOS handling for renderer-based training, which affects what the model learns on multi-turn/tool templates; explicit per-role TOML configs remain compatible.

Overview
SFT loss masking is refactored so renderer runs default to training every token the renderer marks as model-sampled ("sampled"), instead of always intersecting with assistant-only role flags.

LossMaskConfig becomes "sampled" | "all" | LossMaskRolesConfig (per-role booleans renamed to LossMaskRolesConfig). SFTDataset maps those modes to role_to_mask=None, always-true, or per-role filters for build_training_sample / the chat-template fallback; without a renderer, "sampled"/"all" still use assistant-only role defaults because there is no is_sampled signal.

Renderer path: EOS is no longer forced onto the stream and the EOS-in-target_ids assert applies only to the legacy tokenizer path, so turn-ending tokens from the template (e.g. GLM stop openers on the next message) can be supervised correctly.

Reviewed by Cursor Bugbot for commit 38857d1. Bugbot is set up for automated code reviews on this repo. Configure here.

hallerite and others added 2 commits May 27, 2026 02:29
The existing LossMaskConfig was per-role booleans, AND'd against the
renderer's sampled_mask in build_training_sample. With renderer fixes
landing in PrimeIntellect-ai/renderers#66, the per-token sampled_mask
is now the authoritative "what the model would produce at inference"
signal — including chat-template stop tokens (e.g. GLM's <|user|> /
<|observation|> after a tool-calling assistant turn) whose structural
attribution lives in the *next* message's span. AND-ing those with a
role filter that says "assistant only" silently masks them out of the
loss, which is exactly the bug that caused GLM-4.5-Air SFT to dump
50-283 parallel tool_calls per turn on SWE-bench Verified.

Refactor LossMaskConfig to a discriminated union mirroring prime-rl's
existing ``Literal | dict-class`` pattern (cf.
``fused_lm_head_token_chunk_size: int | Literal["auto", "disabled"]``):

- ``"sampled"`` (new default, recommended with renderers): every token
  the renderer marks ``is_sampled=True`` is trainable, regardless of
  role. Stop tokens whose attribution lives in the next message's span
  are correctly trained.
- ``"all"``: every renderer token contributes to the loss. Debugging.
- ``LossMaskRolesConfig`` (renamed from the old per-role LossMaskConfig):
  AND the renderer's ``is_sampled`` with per-role booleans. Strict
  opt-in for callers who want to restrict supervision to a subset of
  roles even when other roles' tokens are model-sampled.

Existing TOML configs with ``[data.loss_mask] assistant = true`` parse
as ``LossMaskRolesConfig`` via the union — old behaviour preserved.
Configs that don't override ``loss_mask`` pick up the new ``"sampled"``
default, which is the fix.

The chat-template fallback path (no renderer) has no ``sampled_mask``
signal so it always uses ``LossMaskRolesConfig()`` defaults regardless
of the configured mode.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
When a renderer is configured its token stream is authoritative — it
mirrors the model's chat template (which carries no EOS by design) and
trains the real stop signals via sampled_mask (e.g. GLM's turn-closing
<|observation|> / <|user|>). Stop injecting an EOS the renderer didn't
emit, and don't assert EOS-in-targets on the renderer path. The
chat-template fallback (no sampled_mask) is unchanged: it still appends
EOS so the model learns to stop.
STACKING_DATASET_BUCKET_TIMEOUT = 10


def _always_true_role_filter(message: dict) -> bool:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why private

return True


def _role_filter_for(cfg: LossMaskRolesConfig):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why private

# boolean lookup. The chat-template fallback below has no
# ``is_sampled`` signal so we treat sentinel modes as the
# ``LossMaskRolesConfig()`` defaults (assistant-only).
loss_mask_cfg = self.loss_mask_config
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like all this logic, plus the always_true_role_filter, should be in role_filter_for which just returns a tuple of the renderer_role_filter and fallback_role_filter. I found this split in logic between the role filter helper and this dataset quite confusing.

Copy link
Copy Markdown

@cursor cursor Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cursor Bugbot has reviewed your changes and found 1 potential issue.

Fix All in Cursor

❌ Bugbot Autofix is OFF. To automatically fix reported issues with cloud agents, enable autofix in the Cursor dashboard.

Reviewed by Cursor Bugbot for commit de439e3. Configure here.

fallback_role_filter = _role_filter_for(LossMaskRolesConfig())
elif loss_mask_cfg == "all":
renderer_role_filter = _always_true_role_filter
fallback_role_filter = _role_filter_for(LossMaskRolesConfig())
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"all" fallback silently restricts loss to assistant-only tokens

Medium Severity

When loss_mask = "all" and no renderer is configured, fallback_role_filter is set to _role_filter_for(LossMaskRolesConfig()), which defaults to assistant-only masking. The "all" mode is documented as training on "every…token…regardless of role," but this fallback silently contradicts that. The chat-template build_incremental_token_mask uses role_to_mask directly (no is_sampled dependency), so _always_true_role_filter would correctly honor the "all" semantic in the fallback path.

Additional Locations (1)
Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit de439e3. Configure here.

@@ -246,15 +290,15 @@ def should_mask(message: dict) -> bool:
input_ids, loss_mask = build_training_sample(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i dislike the extremely different names between build_training_sample and build_incremental_token_mask even though they do the same thing, just with a different backend. Not necessarily important for this PR, but I didn't want to forget it

hallerite and others added 2 commits June 3, 2026 01:57
…default

The loss_mask="sampled" path passes role_to_mask=None to
build_training_sample, which only falls back to the renderer's
sampled_mask as of renderers#66 (1d69232). The branch previously still
pinned 2ec28a8, where role_to_mask is a required positional — so the
feature was broken against its own pinned renderer. Bump the submodule
to the exact validated commit.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants