
feat(sft): VLM SFT via renderers + image-aware seq_len truncation #2485

Draft
eligotts wants to merge 3 commits into main from feat/sft-multimodal-renderer

Conversation


eligotts (Contributor) commented May 13, 2026

Summary

  • Add multimodal SFT support driven by the renderers package. SFTDataset now takes a Renderer (any registered VLM renderer works) instead of a tokenizer; per-token loss masks come from RenderedTokens.message_indices and per-image pixel buffers come from MultiModalData.mm_items. No model-specific imports anywhere in prime-rl — adding a new VLM means registering a renderer upstream, nothing in this repo.
  • Replace the in-tree incremental-tokenization machinery (build_incremental_token_mask, IncrementalTokenizationError, should_add_generation_prompt, the manual EOS append guard, the deserialize/strip pre-pass) with a single renderer.render() call per sample. The renderer is byte-identical to apply_chat_template, so the workarounds for prefix-stability assertions are no longer needed.
  • Image-aware seq_len truncation at the data-pipeline level (mirrors the existing text-only CatDataset convention, but slice-aware): _find_image_safe_cut never splits an <|image_pad|> placeholder run, and _truncate_mm_data drops mm_items whose placeholders fall past the cut, so surviving placeholders stay 1-to-1 with surviving pixel buffers (see the sketch after this list).
  • Local data ingestion: new data_files config field, transparent .zst decompression, and a small Arrow-schema normalization pass (_normalize_oai_record, _drop_null_fields) so heterogeneous OAI message content + per-row JSON-schema tool definitions load cleanly through HF datasets. .zst decompression is now serialized to rank 0 + barrier so multi-rank training doesn't race on the same $TMPDIR path (fix commit follows the feature commit).
  • New setup_processor() in trainer/model.py; forward() now prefers a caller-supplied mm_token_type_ids (renderer-exact) over the qwen3_vl-only sniffer (heuristic), with the sniffer kept as a back-compat fallback.
  • New VLM config validators in SFTConfig: bfloat16-required, freeze-vision-with-LoRA, micro_batch_size = 1, no CP.
  • New example config examples/plex_vlm_sft/sft.toml (Qwen3.5-0.8B + LoRA on the Plex sample). Swap model.name for Qwen3.5-35B-A3B to scale up.
  • New "SFT with images" + "MoE + activation checkpointing trap" sections in skills/config/SKILL.md.

Pairs with

PrimeIntellect-ai/renderers#25 — fixes Qwen35Renderer dropping tool-message images. Without that PR, browser-agent traces lose ~82% of their visual content (post-action screenshots arrive as tool responses). With it, image counts go from 100 → 567 across the same 100-record Plex slice (one initial user image + 4–25 tool-response images per trace). The image-aware truncation here is what keeps memory bounded once those screenshots flow through.

Test plan

  • 19 SFT unit tests pass — same coverage as main, with the renderer-based fixture.
  • Text-only regression: examples/reverse_text/sft.toml on PrimeIntellect/Qwen3-0.6B, 3 steps — loss 4.66 → 4.54 → 4.25, 22.9k tok/s, 11.6% MFU, peak 25.7 GiB.
  • VLM smoke (single H200): examples/plex_vlm_sft/sft.toml on Qwen/Qwen3.5-0.8B, LoRA r=8, seq_len 16k, 3 steps — clean exit, peak ~56 GiB with truncation engaged.
  • Image-aware truncation invariants checked across all 100 records of the Plex sample: 28 truncated, 72 untouched, every surviving placeholder still falls within the truncated length, every surviving mm_item still corresponds to a surviving placeholder.
  • Qwen3.5-35B-A3B end-to-end on 8x B300 (multi-rank FSDP): 4,600-step LoRA r=8 (attn-only) run on the 14 GB Plex traces dataset, seq_len 32k, selective AC excluding routed_experts. Loss trajectory 0.6 → 0.4 plateau (typical for r=8 LoRA, capacity-bound). Clean checkpoints every 500 steps. Adapter eval against the perplexity-eval-env at the final checkpoint:
    • rubric_eval_suite (n=73, c=20): mean reward 0.636, median 0.667, 18/73 perfect, 51/73 ≥ 0.5, 73/73 submitted_result.
    • google_flights_eval (n=40, c=20): mean reward 0.538, median 0.631, 9/40 perfect, 25/40 ≥ 0.5, 40/40 submitted_result.

What I did not test

  • Resume-from-checkpoint round trip for VLM samples.
  • Video parts in any role — the code raises NotImplementedError matching the renderer's existing behavior.

Notes for reviewers

  • The data-race fix is its own follow-up commit on top of the feature commit, so it's easy to review in isolation. The race manifests as a deadlock at destroy_process_group: some ranks read the half-written, corrupted JSON output and crash silently while the survivors continue training (we hit this on the first multi-rank launch).
  • The MoE + activation-checkpointing trap (PyTorch issue #171355) is unavoidable with mode = "full" on any MoE model. The SKILL doc explains and points at mode = "selective" excluding routed_experts. Considered emitting a config validator warning for mode="full" on MoE — happy to add if reviewers want.

Other infra findings during the scaled 35B-A3B run (NOT fixed in this PR — heads-up for reviewers)

1. Blackwell B300 (SM103) + CUDA 12.8 → torch._grouped_mm doesn't compile

The CUTLASS-backed torch._grouped_mm in models/layers/moe.py::_run_experts_grouped_mm_impl and models/layers/lora/multi_linear.py::_run_lora_grouped_mm fails to compile on Blackwell B300 (SM103) when the system CUDA toolkit is 12.8:

nvcc fatal: Unsupported gpu architecture 'compute_103a'
RuntimeError: cutlass cannot run, error 7   # cudaErrorInvalidConfiguration

compute_103a (Blackwell SM103) wasn't added until CUDA 12.9. Both grouped-GEMM call sites fall over before any forward pass completes.

Local workaround we used (not pushed): force-route both MoE experts and MultiLoRA to the existing _run_experts_for_loop / _run_lora_for_loop fallbacks. Functional but torch._grouped_mm (one fused op) is much faster than the Python expert-by-expert loop, so this is a perf hit not a correctness fix.

Proper fix would gate the fallback on hardware: e.g. if torch.cuda.get_device_capability()[0] >= 10 and cuda_toolkit_version < (12, 9): use_grouped_mm = False. Not included here — happy to split this into its own PR if reviewers want, or file a GH issue first.
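
A sketch of that gate (assumption: torch.version.cuda, the toolkit PyTorch was built against, is a good-enough proxy for the toolkit that compiles the CUTLASS kernels; a real fix might probe nvcc instead):

import torch

def grouped_mm_supported() -> bool:
    # Hypothetical gate, not in this PR: Blackwell (capability major >= 10) needs compute_103a,
    # which nvcc only understands from CUDA 12.9 onward.
    major, _ = torch.cuda.get_device_capability()
    toolkit = tuple(int(x) for x in torch.version.cuda.split("."))
    return not (major >= 10 and toolkit < (12, 9))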

2. model.moe_use_grouped_mm = false config flag doesn't propagate to GroupedExperts

Setting [model] moe_use_grouped_mm = false in the SFT config currently has no effect. The flag is read into model_config.use_grouped_mm at trainer/model.py:473 and is meant to flow through MoEArgs.use_grouped_mm into GroupedExperts(use_grouped_mm=...) to switch the forward path between _run_experts_grouped_mm and _run_experts_for_loop (see models/layers/moe.py:222-227). But during the Blackwell debug above, setting moe_use_grouped_mm = false did not route the model away from grouped GEMM — the forward still went through _run_experts_grouped_mm_impl and crashed. Worked around by hard-coding the for-loop path in GroupedExperts.forward.

I didn't track down where the propagation breaks (config-class wiring, FSDP wrap, the MoEArgs construction in models/qwen3_5_moe/modeling_qwen3_5_moe.py:606, or somewhere in the FlagStore/configstore mechanics). On Hopper this is silent because the default True path works; on Blackwell it became visible because the True path is broken. Worth a small follow-up to make the flag actually do what it claims.
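
For that follow-up, a tiny diagnostic of the kind that could pin down where the flag is lost, assuming GroupedExperts exposes the use_grouped_mm attribute the config is supposed to set:

import torch

def check_grouped_mm_flag(model: torch.nn.Module, expected: bool) -> None:
    # Walk the fully set-up (post-wrap) model and assert the configured flag actually landed.
    for name, module in model.named_modules():
        if hasattr(module, "use_grouped_mm"):
            actual = module.use_grouped_mm
            assert actual == expected, f"{name}: use_grouped_mm={actual}, config says {expected}"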

3. Sample-heterogeneity → NCCL all-gather straggler at high seq_len with selective AC excluding routed_experts

When evaluating whether seq_len = 65536 could fit alongside selective AC (which keeps MoE activations in memory rather than recomputing), one rank with a high-image-count sample (~47 images) became a hard straggler: vision encoder + non-recomputed MoE intermediates blew that rank's per-step memory and per-step time so far past the others that the cluster timed out at the default 600s NCCL all-gather. Memory wasn't the proximate failure mode — it was wall-clock variance per rank.

ProcessGroupNCCL's watchdog got stuck for 480 seconds without making progress
WorkNCCL(SeqNum=54, OpType=_ALLGATHER_BASE, NumelIn=104554048, NumelOut=836432384, Timeout(ms)=600000)

The fundamental issue is that selective AC excluding routed_experts amplifies per-rank variance from per-sample image count, because the rank with more images now stores ~5x more MoE intermediates than the rank with fewer. At seq_len = 32768 (with image-aware truncation kicking in at the tail) the variance stays bounded and selective AC is stable. We didn't find a clean fix for 65k+selective; options would be per-sample dispatch (load-balance by image count, not by rank) or returning to full AC + a router-determinism fix. Documenting so future scaled runs don't repeat the 65k attempt without addressing the heterogeneity.

4. MoE + activation checkpointing → CheckpointError: Recomputed values have different metadata — unsolved

We hit this repeatedly during the early scaled runs at [model.ac] mode = "full":

torch.utils.checkpoint.CheckpointError:
  Recomputed values for the following tensors have different metadata than during the forward pass.
tensor at position 45:
  saved metadata:      shape=torch.Size([760, 512]), dtype=bf16
  recomputed metadata: shape=torch.Size([880, 512]), dtype=bf16

Root cause (covered in the new SKILL doc): the token-choice router does topk over near-tied bf16 scores from a non-deterministic GPU matmul reduction, so the same layer input picks different winning experts on the forward call vs the recompute call. Different routing → different num_tokens_per_expert → per-expert tensor shapes diverge between forward-saved and backward-recomputed → backward dies. Refs PyTorch #171355 (open, unresolved) and Lightning #19267.
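
A toy illustration of the failure mode (not the repo's router): two expert scores one bf16 ulp apart, plus a perturbation of the size a different reduction order can introduce, flips the topk winner; a flipped winner changes num_tokens_per_expert and therefore every downstream per-expert shape.

import torch

scores = torch.tensor([[0.50390625, 0.5, 0.1]], dtype=torch.bfloat16)   # experts 0 and 1 one ulp apart
winner_forward = scores.topk(1, dim=-1).indices.item()                  # expert 0 wins

perturbed = scores.clone()
perturbed[0, 1] += 2 * 2 ** -8                   # two bf16 ulps, on the order of reduction-order noise
winner_recompute = perturbed.topk(1, dim=-1).indices.item()             # expert 1 wins

print(winner_forward, winner_recompute)          # 0 1: forward and recompute route differently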

We made four attempts to actually fix it — none worked:

  • determinism_check="none" on checkpoint_wrapper: bypasses the metadata guard, but the recomputed tensors really do have the wrong shapes; backward then dies with RuntimeError: size of tensor a (168) must match tensor b (16) at the actual matmul. Disabling the check only delays the failure.
  • preserve_rng_state=True on checkpoint_wrapper: no effect. The drift is not RNG-based.
  • fp32 cast on the router gate (F.linear(x.float(), gate.weight.float())): no effect. The router input is itself recomputed from upstream bf16 ops in attention / LoRA / LayerNorm, so the gate gets slightly different inputs even when its own matmul is fp32.
  • torch.use_deterministic_algorithms(True) + CUBLAS_WORKSPACE_CONFIG=:4096:8: made it worse; the recompute collapsed all per-expert tensors to an identical [14576, 512], dying on step 0. Some op silently changes behavior in determinism mode in a way that breaks recompute completely.

What actually unblocked us was the selective-AC workaround in commit 3256a0d0d (this PR): set mode = "selective" with targets = ["norm", "attn_proj", "linear_attn"] so the MoE block is not activation-checkpointed at all. No recompute → no drift → no error. This costs ~120 GiB more peak memory at seq_len=32k for Qwen3.5-35B-A3B (126 → 249 GiB on 8x B300) because all the routed-expert intermediates now persist forward→backward instead of being recomputed. We could afford it; lots of MoE training configs cannot.

This is not a bug we fixed — we removed MoE from the trigger surface. The actual fix would have to live in either torch.utils.checkpoint (relax the metadata-equality assumption for ops with data-dependent shapes) or in MoE-aware checkpointing (save the routing decision at forward time and replay it at recompute). Both are non-trivial; for context, the upstream PyTorch issue has been open since the same combo bit a Qwen3 MoE LoRA user on H100, so this is not a Blackwell quirk — anyone training MoE + full AC will hit it eventually.

If reviewers want a guard rail before merging, the easiest is a config validator that emits a hard warning (or refuses to run) when the resolved model is MoE and model.ac.mode == "full". Happy to add in a follow-up.
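
If that guard rail is wanted, a sketch of the shape it could take, assuming the config classes are pydantic models (as the validator naming suggests) and that an is_moe flag can be resolved from the model; names here are illustrative, not the repo's API:

from pydantic import BaseModel, model_validator

class ACConfig(BaseModel):
    mode: str = "full"

class ModelConfig(BaseModel):
    name: str
    is_moe: bool = False              # assumed: derivable from the resolved model config
    ac: ACConfig = ACConfig()

class SFTConfig(BaseModel):
    model: ModelConfig

    @model_validator(mode="after")
    def refuse_moe_with_full_ac(self):
        if self.model.is_moe and self.model.ac.mode == "full":
            raise ValueError(
                "MoE + ac.mode='full' hits CheckpointError (PyTorch #171355); "
                "use mode='selective' with targets excluding routed_experts"
            )
        return self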

🤖 Generated with Claude Code

Add multimodal SFT support driven by the renderers package, so any VLM
with a registered renderer trains via the same path. Replaces the
in-tree incremental-tokenization machinery with a single renderer.render()
call per sample, recovering per-token loss masks via RenderedTokens.message_indices
and per-image pixel buffers via MultiModalData.mm_items — no model-specific
imports inside prime-rl. Adding support for a new VLM amounts to registering
a renderer upstream.

trainer/sft/data.py
- SFTDataset takes a Renderer (any), not a tokenizer. Drops normalize_messages
  postprocessing (deserialize_tool_calls, strip_message_content), the
  IncrementalTokenizationError skip branch, and the manual EOS append guard —
  all unneeded with a deterministic renderer.
- _flatten_mm_items concatenates per-image renderer items into a flat
  dict[str, Tensor] keyed by HF forward kwarg names (e.g. pixel_values,
  image_grid_thw for Qwen-VL). Model-agnostic.
- _build_mm_token_type_ids derives the per-token modality flag from the
  renderer's mm_placeholders.
- Image-aware seq_len truncation at the data-pipeline level (mirroring
  CatDataset's text-only convention): _find_image_safe_cut pulls the cut
  back to a position that doesn't split an image_pad run; _truncate_mm_data
  drops mm_items whose placeholders fall past the cut, keeping the
  surviving placeholder↔item correspondence 1-to-1.
- CatDataset / StackDataset yield multimodal samples standalone (no
  packing — pixel buffers are variable-shape per image; not worth the
  bookkeeping for effective batch_size 1).
- Local file ingestion: data_files config + .zst decompression + a
  JSONL pre-pass (_normalize_oai_record) that wraps string content in
  text blocks and stringifies tool_calls.function.arguments so PyArrow's
  JSON loader gets a stable schema across rows. Also _drop_null_fields
  strips the noise PyArrow's schema unification adds (text items end up
  with image_url=None which trips renderers' permissive _is_image_part check).
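
An illustrative version of that pre-pass (not the PR's exact code; the real one also pairs with _drop_null_fields):

import json

def normalize_oai_record(record: dict) -> dict:
    # Give every row the same Arrow-friendly shape before handing it to HF datasets.
    for msg in record.get("messages", []):
        content = msg.get("content")
        if isinstance(content, str):
            # Bare string content becomes a single text block, so 'content' is always a list of parts.
            msg["content"] = [{"type": "text", "text": content}]
        for call in msg.get("tool_calls") or []:
            args = call.get("function", {}).get("arguments")
            if isinstance(args, (dict, list)):
                # Stringify per-row JSON-schema arguments so PyArrow sees one stable schema.
                call["function"]["arguments"] = json.dumps(args)
    return record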

trainer/sft/train.py
- create_renderer(tokenizer, "auto") for model-agnostic renderer pickup.
- setup_processor() loads the HF AutoProcessor when the model is a VLM
  and attaches it to the renderer; returns None for text-only.
- compute_loss forwards pixel_values / image_grid_thw / mm_token_type_ids
  to forward() when present.
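
Schematically (batch keys follow the HF forward kwargs named above; the exact plumbing inside compute_loss may differ):

def optional_mm_kwargs(batch: dict) -> dict:
    # Forward only the VLM tensors that are present; text-only batches add nothing.
    keys = ("pixel_values", "image_grid_thw", "mm_token_type_ids")
    return {k: batch[k] for k in keys if batch.get(k) is not None}

# logits = model(input_ids=batch["input_ids"], **optional_mm_kwargs(batch))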

trainer/model.py
- New setup_processor() helper.
- forward() respects a caller-supplied mm_token_type_ids when given,
  instead of unconditionally clobbering it with the qwen3_vl-only
  autocompute helper. The renderer's value comes from exact placeholder
  ranges and works for any VLM family — keeps the autocompute as a
  back-compat fallback.

utils/chat_template.py
- Remove build_incremental_token_mask, IncrementalTokenizationError,
  should_add_generation_prompt — all subsumed by build_training_sample
  from renderers. Keep normalize_messages, deserialize_tool_calls,
  strip_message_content, render_messages (still used by orchestrator
  fallback path + wandb logging).

configs/sft.py
- data_files field for local JSONL/JSONL.zst loading.
- Validators ported from TrainerConfig: vlms_require_bfloat16,
  vlm_freeze_incompatible_with_lora. New validate_vlm_constraints
  forces micro_batch_size=1 and disables CP for VLM (image samples
  can't pack across samples; CP would split image placeholders across
  seq shards).
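
A compact sketch of how validate_vlm_constraints could read, again assuming pydantic-style config classes (field names are illustrative, not the repo's):

from pydantic import BaseModel, model_validator

class VLMConstraintsExample(BaseModel):
    is_vlm: bool = False          # the real check would derive this from the model
    micro_batch_size: int = 1
    cp_size: int = 1

    @model_validator(mode="after")
    def validate_vlm_constraints(self):
        if self.is_vlm:
            if self.micro_batch_size != 1:
                raise ValueError("VLM SFT requires micro_batch_size=1 (image samples cannot pack)")
            if self.cp_size > 1:
                raise ValueError("VLM SFT disables CP (it would split image placeholder runs across shards)")
        return self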

examples/plex_vlm_sft/sft.toml
- Reference config for browser-agent trace SFT: Qwen3.5-0.8B + LoRA r=8
  + ac freq=1 + seq_len 16k. Same shape works for Qwen3.5-35B-A3B by
  swapping model.name.

skills/config/SKILL.md
- New "SFT with images" section documenting the multimodal config knobs
  and constraints.

Verified
- 19 SFT unit tests pass (renderer-based fixture, same coverage as before).
- Text-only smoke (examples/reverse_text/sft.toml, Qwen3-0.6B, 3 steps):
  loss 4.66 → 4.54 → 4.25; 22.9k tok/s; 11.6% MFU; peak 25.7 GiB. No
  regression vs main.
- VLM smoke (examples/plex_vlm_sft, Qwen3.5-0.8B, 3 steps with truncation):
  no OOM, loss decreasing, image_safe truncation keeps memory bounded at
  ~56 GiB peak.
- Image-aware truncation verified against full 100-record Plex slice: 28
  records truncated to seq_len + 1, 72 untouched, all placeholder↔item
  invariants hold, 0 user-message images lost.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
hallerite self-requested a review May 13, 2026 11:20
eligotts and others added 2 commits May 14, 2026 22:20
_resolve_local_data_files materializes each .zst into $TMPDIR via a
deterministic filename derived from the source basename. When training
multi-rank with `data_files`, every rank computed the same path, every
rank saw `tmp.exists() == False`, and every rank entered the decompress +
normalize block in parallel. Concurrent writers to the same path
produced byte-interleaved corrupt output, which PyArrow rejected at
load with `Missing closing quotation mark in string. in row 0`.
Some ranks crashed silently while the survivors went on to training,
producing a deadlock at the next NCCL collective and ultimately at
`destroy_process_group` in the sync_wrapper finally block.

Gate the decompress + normalize work on global rank 0, then call
`torch.distributed.barrier()` so all other ranks read a fully-written
file.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
The default `[model.ac] mode = "full"` is unsafe with MoE: bf16 router
matmul reductions are non-deterministic on GPU, so forward and the
activation-checkpoint recompute pick different topk winners when scores
are near-tied. The resulting per-expert tensor shape mismatch surfaces as
`CheckpointError: Recomputed values ... different metadata` mid-training.
Refs PyTorch #171355.

Document the symptom, the cause, and the workaround
(`mode = "selective"` with `targets = ["norm", "attn_proj", "linear_attn"]`,
i.e. everything except `routed_experts`), plus the measured memory
trade-off (Qwen3.5-35B-A3B / 32k seq_len / 8x B300:
~126 GiB peak full AC → ~249 GiB peak selective excluding routed_experts).

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>