feat(sft): VLM SFT via renderers + image-aware seq_len truncation #2485
Draft
eligotts wants to merge 3 commits into
Conversation
Add multimodal SFT support driven by the renderers package, so any VLM with a registered renderer trains via the same path. Replaces the in-tree incremental-tokenization machinery with a single renderer.render() call per sample, recovering per-token loss masks via RenderedTokens.message_indices and per-image pixel buffers via MultiModalData.mm_items — no model-specific imports inside prime-rl. Adding support for a new VLM amounts to registering a renderer upstream.

trainer/sft/data.py
- SFTDataset takes a Renderer (any), not a tokenizer. Drops normalize_messages postprocessing (deserialize_tool_calls, strip_message_content), the IncrementalTokenizationError skip branch, and the manual EOS append guard — all unneeded with a deterministic renderer.
- _flatten_mm_items concatenates per-image renderer items into a flat dict[str, Tensor] keyed by HF forward kwarg names (e.g. pixel_values, image_grid_thw for Qwen-VL). Model-agnostic.
- _build_mm_token_type_ids derives the per-token modality flag from the renderer's mm_placeholders.
- Image-aware seq_len truncation at the data-pipeline level (mirroring CatDataset's text-only convention): _find_image_safe_cut pulls the cut back to a position that doesn't split an image_pad run; _truncate_mm_data drops mm_items whose placeholders fall past the cut, keeping the surviving placeholder↔item correspondence 1-to-1 (see the sketch after this commit message).
- CatDataset / StackDataset yield multimodal samples standalone (no packing — pixel buffers are variable-shape per image; not worth the bookkeeping for effective batch_size 1).
- Local file ingestion: data_files config + .zst decompression + a JSONL pre-pass (_normalize_oai_record) that wraps string content in text blocks and stringifies tool_calls.function.arguments so PyArrow's JSON loader gets a stable schema across rows. Also _drop_null_fields strips the noise PyArrow's schema unification adds (text items end up with image_url=None, which trips renderers' permissive _is_image_part check).

trainer/sft/train.py
- create_renderer(tokenizer, "auto") for model-agnostic renderer pickup.
- setup_processor() loads the HF AutoProcessor when the model is a VLM and attaches it to the renderer; returns None for text-only.
- compute_loss forwards pixel_values / image_grid_thw / mm_token_type_ids to forward() when present.

trainer/model.py
- New setup_processor() helper.
- forward() respects a caller-supplied mm_token_type_ids when given, instead of unconditionally clobbering it with the qwen3_vl-only autocompute helper. The renderer's value comes from exact placeholder ranges and works for any VLM family — keeps the autocompute as a back-compat fallback.

utils/chat_template.py
- Remove build_incremental_token_mask, IncrementalTokenizationError, should_add_generation_prompt — all subsumed by build_training_sample from renderers. Keep normalize_messages, deserialize_tool_calls, strip_message_content, render_messages (still used by the orchestrator fallback path + wandb logging).

configs/sft.py
- data_files field for local JSONL/JSONL.zst loading.
- Validators ported from TrainerConfig: vlms_require_bfloat16, vlm_freeze_incompatible_with_lora. New validate_vlm_constraints forces micro_batch_size=1 and disables CP for VLM (image samples can't pack across samples; CP would split image placeholders across seq shards).

examples/plex_vlm_sft/sft.toml
- Reference config for browser-agent trace SFT: Qwen3.5-0.8B + LoRA r=8 + ac freq=1 + seq_len 16k. Same shape works for Qwen3.5-35B-A3B by swapping model.name.
skills/config/SKILL.md
- New "SFT with images" section documenting the multimodal config knobs and constraints.

Verified
- 19 SFT unit tests pass (renderer-based fixture, same coverage as before).
- Text-only smoke (examples/reverse_text/sft.toml, Qwen3-0.6B, 3 steps): loss 4.66 → 4.54 → 4.25; 22.9k tok/s; 11.6% MFU; peak 25.7 GiB. No regression vs main.
- VLM smoke (examples/plex_vlm_sft, Qwen3.5-0.8B, 3 steps with truncation): no OOM, loss decreasing, image_safe truncation keeps memory bounded at ~56 GiB peak.
- Image-aware truncation verified against full 100-record Plex slice: 28 records truncated to seq_len + 1, 72 untouched, all placeholder↔item invariants hold, 0 user-message images lost.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
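A minimal sketch of the image-safe truncation described above, assuming the image placeholder is a single repeated token id (e.g. Qwen-VL's `<|image_pad|>`) and that each image's placeholder span is available as a (start, end) pair; the real `_find_image_safe_cut` / `_truncate_mm_data` work off the renderer's placeholder metadata and may differ in detail:

```python
from typing import Sequence


def find_image_safe_cut(token_ids: Sequence[int], seq_len: int, image_pad_id: int) -> int:
    """Return a cut position <= seq_len that never lands inside an image_pad run."""
    cut = min(seq_len, len(token_ids))
    # If both neighbours of the boundary are image pads, the cut splits an image:
    # walk it back to the start of that run so the whole image is dropped instead.
    while 0 < cut < len(token_ids) and token_ids[cut - 1] == image_pad_id and token_ids[cut] == image_pad_id:
        cut -= 1
    return cut


def truncate_mm_data(mm_items: list, placeholder_runs: list[tuple[int, int]], cut: int) -> list:
    """Drop pixel buffers whose placeholder run falls past the cut, keeping the
    surviving placeholder-to-item correspondence 1-to-1."""
    return [item for item, (_, end) in zip(mm_items, placeholder_runs) if end <= cut]
```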
_resolve_local_data_files materializes each .zst into $TMPDIR via a deterministic filename derived from the source basename. When training multi-rank with `data_files`, every rank computed the same path, every rank saw `tmp.exists() == False`, and every rank entered the decompress + normalize block in parallel. Concurrent writers to the same path produced byte-interleaved corrupt output, which PyArrow rejected at load with `Missing closing quotation mark in string. in row 0`. Some ranks crashed silently while the survivors went on to training, producing a deadlock at the next NCCL collective and ultimately at `destroy_process_group` in the sync_wrapper finally block. Gate the decompress + normalize work on global rank 0, then call `torch.distributed.barrier()` so all other ranks read a fully-written file. Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
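A hedged sketch of the rank-0 gating this commit describes; the helper name and the decompression step follow the message, the `zstandard` usage is an assumption, and the `_normalize_oai_record` pre-pass is elided:

```python
import tempfile
from pathlib import Path

import torch.distributed as dist
import zstandard as zstd


def _resolve_local_data_file(src: Path) -> Path:
    # Deterministic target path derived from the source basename, as before.
    tmp = Path(tempfile.gettempdir()) / src.name.removesuffix(".zst")
    is_rank0 = not dist.is_initialized() or dist.get_rank() == 0

    if is_rank0 and not tmp.exists():
        # Only global rank 0 decompresses (and would run the normalize pre-pass);
        # every other rank skips straight to the barrier below.
        with open(src, "rb") as fin, open(tmp, "wb") as fout:
            zstd.ZstdDecompressor().copy_stream(fin, fout)

    if dist.is_initialized():
        dist.barrier()  # all other ranks wait until the file is fully written
    return tmp
```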
The default `[model.ac] mode = "full"` is unsafe with MoE: bf16 router matmul reductions are non-deterministic on GPU, so forward and the activation-checkpoint recompute pick different topk winners when scores are near-tied. The resulting per-expert tensor shape mismatch surfaces as `CheckpointError: Recomputed values ... different metadata` mid-training. Refs PyTorch #171355. Document the symptom, the cause, and the workaround (`mode = "selective"` with `targets = ["norm", "attn_proj", "linear_attn"]`, i.e. everything except `routed_experts`), plus the measured memory trade-off (Qwen3.5-35B-A3B / 32k seq_len / 8x B300: ~126 GiB peak full AC → ~249 GiB peak selective excluding routed_experts). Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
Summary
- Multimodal SFT is driven by the `renderers` package. SFTDataset now takes a `Renderer` (any registered VLM renderer works) instead of a tokenizer; per-token loss masks come from `RenderedTokens.message_indices` and per-image pixel buffers come from `MultiModalData.mm_items`. No model-specific imports anywhere in `prime-rl` — adding a new VLM means registering a renderer upstream, nothing in this repo. (A minimal loss-mask sketch follows this summary.)
- Replaces the in-tree incremental-tokenization machinery (`build_incremental_token_mask`, `IncrementalTokenizationError`, `should_add_generation_prompt`, the manual EOS append guard, the deserialize/strip pre-pass) with a single `renderer.render()` call per sample. The renderer is byte-identical to `apply_chat_template`, so the workarounds for prefix-stability assertions are no longer needed.
- Image-aware `seq_len` truncation at the data-pipeline level (mirrors the existing text-only `CatDataset` convention but slice-aware): `_find_image_safe_cut` never splits an `<|image_pad|>` placeholder run, and `_truncate_mm_data` drops `mm_items` whose placeholders fall past the cut so surviving placeholders stay 1-to-1 with surviving pixel buffers.
- Local file ingestion: `data_files` config field, transparent `.zst` decompression, and a small Arrow-schema normalization pass (`_normalize_oai_record`, `_drop_null_fields`) so heterogeneous OAI message content + per-row JSON-schema tool definitions load cleanly through HF datasets.
- `.zst` decompression is now serialized to rank 0 + barrier so multi-rank training doesn't race on the same `$TMPDIR` path (fix commit follows the feature commit).
- New `setup_processor()` in `trainer/model.py`; `forward()` now prefers a caller-supplied `mm_token_type_ids` (renderer-exact) over the qwen3_vl-only sniffer (heuristic), with the sniffer kept as a back-compat fallback.
- VLM constraints validated in `SFTConfig`: bfloat16-required, freeze-vision-with-LoRA, `micro_batch_size = 1`, no CP.
- Reference config at `examples/plex_vlm_sft/sft.toml` (Qwen3.5-0.8B + LoRA on the Plex sample). Swap `model.name` for Qwen3.5-35B-A3B to scale up.
- Multimodal config knobs and constraints documented in `skills/config/SKILL.md`.

Pairs with PrimeIntellect-ai/renderers#25 — fixes Qwen35Renderer dropping tool-message images. Without that PR, browser-agent traces lose ~82% of their visual content (post-action screenshots arrive as tool responses). With it, image counts go from 100 → 567 across the same 100-record Plex slice (one initial user image + 4–25 tool-response images per trace). The image-aware truncation here is what keeps memory bounded once those screenshots flow through.
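For reference, a minimal sketch of how a per-token loss mask can be recovered from `message_indices`, assuming it maps each token position to the index of the message that produced it (the actual `RenderedTokens` layout may differ):

```python
def build_loss_mask(message_indices: list[int], messages: list[dict]) -> list[int]:
    """1 for tokens emitted by assistant messages (trained on), 0 for everything else."""
    return [1 if messages[i]["role"] == "assistant" else 0 for i in message_indices]
```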
Test plan
- Text-only smoke: `examples/reverse_text/sft.toml` on `PrimeIntellect/Qwen3-0.6B`, 3 steps — loss 4.66 → 4.54 → 4.25, 22.9k tok/s, 11.6% MFU, peak 25.7 GiB.
- VLM smoke: `examples/plex_vlm_sft/sft.toml` on `Qwen/Qwen3.5-0.8B`, LoRA r=8, seq_len 16k, 3 steps — clean exit, peak ~56 GiB with truncation engaged.
- Image-aware truncation checked against the 100-record Plex slice: every surviving `mm_item` still corresponds to a surviving placeholder.
- Scaled 35B-A3B run with selective AC excluding `routed_experts`. Loss trajectory 0.6 → 0.4 plateau (typical for r=8 LoRA, capacity-bound). Clean checkpoints every 500 steps. Adapter eval against the perplexity-eval-env at the final checkpoint:
  - `rubric_eval_suite` (n=73, c=20): mean reward 0.636, median 0.667, 18/73 perfect, 51/73 ≥ 0.5, 73/73 submitted_result.
  - `google_flights_eval` (n=40, c=20): mean reward 0.538, median 0.631, 9/40 perfect, 25/40 ≥ 0.5, 40/40 submitted_result.

What I did not test
- … `NotImplementedError` matching the renderer's existing behavior.

Notes for reviewers
- Multi-rank `data_files` runs deadlock at `destroy_process_group` when some ranks read mid-write corrupt JSON output and crash silently while the survivors continue training (we hit this on the first multi-rank launch); hence the rank-0 + barrier fix commit in this PR.
- Don't run `[model.ac] mode = "full"` on any MoE model. The SKILL doc explains and points at `mode = "selective"` excluding `routed_experts`. Considered emitting a config validator warning for `mode="full"` on MoE — happy to add if reviewers want (a hedged sketch follows below).
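If it helps, a minimal sketch of that guard rail in pydantic style, assuming the config classes are pydantic models as the existing validator naming suggests; `is_moe` and the field names here are illustrative, not the actual SFTConfig attributes:

```python
from pydantic import BaseModel, model_validator


class ACConfig(BaseModel):
    mode: str = "full"


class ModelConfig(BaseModel):
    name: str
    is_moe: bool = False  # assumed resolvable from the model name/architecture
    ac: ACConfig = ACConfig()

    @model_validator(mode="after")
    def moe_forbids_full_ac(self):
        # Refuse the known-bad combination instead of failing mid-training.
        if self.is_moe and self.ac.mode == "full":
            raise ValueError(
                "MoE models are unsafe with [model.ac] mode='full': the bf16 router can "
                "pick different topk experts on recompute (CheckpointError). Use "
                "mode='selective' with targets excluding routed_experts."
            )
        return self
```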
Other infra findings during the scaled 35B-A3B run (NOT fixed in this PR — heads-up for reviewers)

1. Blackwell B300 (SM103) + CUDA 12.8 → `torch._grouped_mm` doesn't compile

The CUTLASS-backed `torch._grouped_mm` in `models/layers/moe.py::_run_experts_grouped_mm_impl` and `models/layers/lora/multi_linear.py::_run_lora_grouped_mm` fails to compile on Blackwell B300 (SM103) when the system CUDA toolkit is 12.8: `compute_103a` (Blackwell SM103) wasn't added until CUDA 12.9. Both grouped-GEMM call sites fall over before any forward pass completes.

Local workaround we used (not pushed): force-route both MoE experts and MultiLoRA to the existing `_run_experts_for_loop` / `_run_lora_for_loop` fallbacks. Functional, but `torch._grouped_mm` (one fused op) is much faster than the Python expert-by-expert loop, so this is a perf hit, not a correctness fix.

Proper fix would gate the fallback on hardware, e.g. `if torch.cuda.get_device_capability()[0] >= 10 and cuda_toolkit_version < (12, 9): use_grouped_mm = False`. Not included here — happy to split this into its own PR if reviewers want, or file a GH issue first.
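A hedged sketch of that gate; where it would live (model setup vs. `MoEArgs` construction) is open, and the `cuda_toolkit_version` argument is a hypothetical helper that could come from `nvcc --version` or `torch.version.cuda`, depending on which toolkit actually compiles the kernel:

```python
import torch


def grouped_mm_supported(cuda_toolkit_version: tuple[int, int]) -> bool:
    """False on Blackwell-class GPUs (SM10x) unless the toolkit knows compute_103a."""
    if not torch.cuda.is_available():
        return False
    major, _ = torch.cuda.get_device_capability()
    if major >= 10 and cuda_toolkit_version < (12, 9):
        return False
    return True
```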
2. `model.moe_use_grouped_mm = false` config flag doesn't propagate to `GroupedExperts`

Setting `[model] moe_use_grouped_mm = false` in the SFT config currently has no effect. The flag is read into `model_config.use_grouped_mm` at `trainer/model.py:473` and is meant to flow through to `MoEArgs.use_grouped_mm` → `GroupedExperts(use_grouped_mm=...)` to switch the forward path between `_run_experts_grouped_mm` and `_run_experts_for_loop` (see `models/layers/moe.py:222-227`). But during the Blackwell debug above, setting `moe_use_grouped_mm = false` did not route the model away from grouped GEMM — the forward still went through `_run_experts_grouped_mm_impl` and crashed. Worked around by hard-coding the for-loop path in `GroupedExperts.forward`.

I didn't track down where the propagation breaks (config-class wiring, FSDP wrap, the `MoEArgs` construction in `models/qwen3_5_moe/modeling_qwen3_5_moe.py:606`, or somewhere in the FlagStore/configstore mechanics). On Hopper this is silent because the default `True` path works; on Blackwell it became visible because the `True` path is broken. Worth a small follow-up to make the flag actually do what it claims.
3. Sample-heterogeneity → NCCL all-gather straggler at high seq_len with selective AC excluding `routed_experts`

When evaluating whether `seq_len = 65536` could fit alongside selective AC (which keeps MoE activations in memory rather than recomputing), one rank with a high-image-count sample (~47 images) became a hard straggler: vision encoder + non-recomputed MoE intermediates blew that rank's per-step memory and per-step time so far past the others that the cluster timed out at the default 600s NCCL all-gather. Memory wasn't the proximate failure mode — it was wall-clock variance per rank.

The fundamental issue is that selective AC excluding `routed_experts` amplifies per-rank variance from per-sample image count, because the rank with more images now stores ~5x more MoE intermediates than the rank with fewer. At `seq_len = 32768` (with image-aware truncation kicking in at the tail) the variance stays bounded and selective AC is stable. We didn't find a clean fix for 65k + selective; options would be per-sample dispatch (load-balance by image count, not by rank) or returning to full AC + a router-determinism fix. Documenting so future scaled runs don't repeat the 65k attempt without addressing the heterogeneity.
4. MoE + activation checkpointing → `CheckpointError: Recomputed values have different metadata` — unsolved

We hit this repeatedly during the early scaled runs at `[model.ac] mode = "full"`.

Root cause (covered in the new SKILL doc): the token-choice router does `topk` over near-tied bf16 scores from a non-deterministic GPU matmul reduction, so the same layer input picks different winning experts on the forward call vs the recompute call. Different routing → different `num_tokens_per_expert` → per-expert tensor shapes diverge between forward-saved and backward-recomputed → backward dies. Refs PyTorch #171355 (open, unresolved) and Lightning #19267.
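A self-contained toy illustration of that failure mode (not prime-rl code; the values are made up to sit within bf16 rounding noise):

```python
import torch

# Two experts whose router scores are within a bf16 ulp of each other.
scores_fwd = torch.tensor([0.5000, 0.4980, 0.1000, 0.1000], dtype=torch.bfloat16)
scores_rec = scores_fwd.clone()
scores_rec[1] += 2 ** -7  # a couple of ulps of noise from a different reduction order

# Forward picks expert 0, the AC recompute picks expert 1: num_tokens_per_expert
# now differs between the two passes, so the per-expert tensor shapes no longer match.
print(torch.topk(scores_fwd, k=1).indices)  # tensor([0])
print(torch.topk(scores_rec, k=1).indices)  # tensor([1])
```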
We made four attempts to actually fix it — none worked:

1. `determinism_check="none"` on `checkpoint_wrapper`: still fails, now with `RuntimeError: size of tensor a (168) must match tensor b (16)` at the actual matmul. Disabling the check just delays the failure.
2. `preserve_rng_state=True` on `checkpoint_wrapper`: still fails.
3. Computing router scores in fp32 (`F.linear(x.float(), gate.weight.float())`): still fails.
4. `torch.use_deterministic_algorithms(True)` + `CUBLAS_WORKSPACE_CONFIG=:4096:8`: dies on step 0 (`[14576, 512]`). Some op silently changes behavior in determinism mode in a way that breaks recompute completely.
What actually unblocked us was the selective-AC workaround in commit `3256a0d0d` (this PR): set `mode = "selective"` with `targets = ["norm", "attn_proj", "linear_attn"]` so the MoE block is not activation-checkpointed at all. No recompute → no drift → no error. This costs ~120 GiB more peak memory at seq_len=32k for Qwen3.5-35B-A3B (126 → 249 GiB on 8x B300) because all the routed-expert intermediates now persist forward→backward instead of being recomputed. We could afford it; lots of MoE training configs cannot.

This is not a bug we fixed — we removed MoE from the trigger surface. The actual fix would have to live in either `torch.utils.checkpoint` (relax the metadata-equality assumption for ops with data-dependent shapes) or in MoE-aware checkpointing (save the routing decision at forward time and replay it at recompute). Both are non-trivial; for context, the upstream PyTorch issue has been open since the same combo bit a Qwen3 MoE LoRA user on H100, so this is not a Blackwell quirk — anyone training MoE + full AC will hit it eventually.

If reviewers want a guard rail before merging, the easiest is a config validator that emits a hard warning (or refuses to run) when the resolved model is MoE and `model.ac.mode == "full"`. Happy to add in a follow-up.

🤖 Generated with Claude Code