Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions examples/plex_vlm_sft/sft.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# SFT on Plex browser-agent traces (multimodal: screenshots + tool calls).
# Single H200, LoRA, attention-only target_modules to keep adapter param count modest.
#
# Notes:
# - The renderer auto-resolves to ``qwen3.5`` from the tokenizer's name; the
# trainer is model-agnostic. To swap to a different VLM (e.g. Qwen3-VL-4B),
# change ``model.name`` — no other config changes needed.
# - VLM SFT forces micro_batch_size = 1 (image samples can't be packed).
# - Plex screenshots blow each sample up to ~7k–60k tokens (mostly image
# patches). Pick a seq_len that fits your single-GPU budget; samples that
# exceed it get dropped.
max_steps = 20

[ckpt]
[ckpt.weights]
save_adapter_separately = true

[model]
name = "Qwen/Qwen3.5-0.8B"
optimization_dtype = "bfloat16"
reduce_dtype = "bfloat16"

[model.vlm]
vision_encoder_attr = "model.visual"
language_model_attr = "model.language_model"
freeze_vision_encoder = true

[model.lora]
rank = 8
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]

[model.ac]
freq = 1

[data]
type = "sft"
name = "json"
data_files = ["/home/ubuntu/plex_traces_sample.jsonl.zst"]
batch_size = 1
micro_batch_size = 1
seq_len = 16384

[optim]
lr = 1e-5
49 changes: 49 additions & 0 deletions packages/prime-rl-configs/src/prime_rl/configs/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,17 @@ class SFTDataConfig(BaseDataConfig):
name: Annotated[str, Field(description="Name or path of the HF dataset to use.")] = (
"PrimeIntellect/Reverse-Text-SFT"
)
data_files: Annotated[
list[str] | None,
Field(
description=(
"Optional list of local files (JSONL, JSONL.zst, …) to load via "
"``load_dataset(name, data_files=...)``. When set, ``name`` should "
"be a loader id like ``'json'``. ``.zst`` files are transparently "
"decompressed to a tempdir before loading."
),
),
] = None
subsets: Annotated[list[str] | None, Field(description="Subsets to use from the HF dataset.")] = None
splits: Annotated[list[str] | None, Field(description="Splits to use from the HF dataset.")] = None
probabilities: Annotated[list[float] | None, Field(description="Probabilities to use for each subset/split.")] = (
Expand Down Expand Up @@ -338,6 +349,44 @@ def validate_cp_micro_batch_size(self):
raise ValueError("Validation micro batch size must be 1 when CP is enabled")
return self

@model_validator(mode="after")
def vlms_require_bfloat16(self):
if self.model.vlm is not None and (
self.model.optimization_dtype != "bfloat16" or self.model.reduce_dtype != "bfloat16"
):
raise ValueError(
"VLM models must use optimization_dtype='bfloat16' and reduce_dtype='bfloat16' to match the HF processor output dtype."
)
return self

@model_validator(mode="after")
def vlm_freeze_incompatible_with_lora(self):
if self.model.vlm is not None and not self.model.vlm.freeze_vision_encoder and self.model.lora is not None:
raise ValueError(
"freeze_vision_encoder=false is incompatible with LoRA. "
"LoRA freezes all non-adapter parameters including the vision encoder."
)
return self

@model_validator(mode="after")
def validate_vlm_constraints(self):
# VLM samples carry image patches, so per-sample pixel buffers can't be
# packed across samples and seq dimension can't be sharded without
# splitting an image. Enforce both at config time.
if self.model.vlm is None:
return self
if self.data.micro_batch_size != 1:
raise ValueError(
"VLM SFT requires data.micro_batch_size = 1 (image samples can't be packed across samples)."
)
if self.val is not None and self.val.data.micro_batch_size != 1:
raise ValueError(
"VLM SFT requires val.data.micro_batch_size = 1 (image samples can't be packed across samples)."
)
if self.model.cp > 1:
raise ValueError("VLM SFT does not support CP > 1 (image placeholders straddle seq shards).")
return self

@model_validator(mode="after")
def validate_seq_len(self):
if self.data.pack_function == "stack" and self.data.seq_len % 256 != 0:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ dependencies = [
"uvloop>=0.21.0",
"torchtitan",
"verifiers",
"renderers==0.1.6",
"renderers>=0.1.7",
"dion",
"tilelang>=0.1.8",
"flash-linear-attention",
Expand Down
58 changes: 58 additions & 0 deletions skills/config/SKILL.md
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,64 @@ uv run sft --data.type fake --data.batch-size 4

If you wish to configure values of the default variant, you don't need to set the `type` field.

### SFT with images (multimodal)

SFT supports VLMs through the `renderers` package. Set `[model.vlm]` to opt in:

```toml
[model]
name = "Qwen/Qwen3.5-0.8B"

[model.vlm]
freeze_vision_encoder = true # required when combined with LoRA
```

The trainer uses `create_renderer(tokenizer, "auto")` to pick the right renderer from the tokenizer's model name (Qwen3.5, Qwen3-VL, GLM, etc.). Image content blocks in `messages` are auto-resolved to pixel features by the renderer's processor; no prime-rl-side code knows about specific VLMs.

Constraints (enforced by validators in `SFTConfig.validate_vlm_constraints`):
- `data.micro_batch_size = 1` and `val.data.micro_batch_size = 1` (image samples can't be packed across samples — pixel buffers are variable per image).
- `model.cp = 1` (sequence sharding would split images across ranks).
- LoRA requires `model.vlm.freeze_vision_encoder = true` (a separate validator in `trainer.py`).

Loading local data files (e.g. JSONL or JSONL.zst):

```toml
[data]
type = "sft"
name = "json"
data_files = ["/path/to/file.jsonl.zst"]
```

`.zst` files are transparently decompressed to `$TMPDIR` before `load_dataset("json", data_files=...)`.

The renderer also handles tool calls and tool messages — OpenAI-format `tool_calls[].function.arguments` may be either a JSON string or a dict; the renderer accepts both.

### MoE + activation checkpointing trap

`[model.ac] mode = "full"` (the default) is **not safe** for MoE models. `torch.utils.checkpoint` raises mid-training:

```
torch.utils.checkpoint.CheckpointError:
Recomputed values for the following tensors have different metadata than during the forward pass.
```

Root cause: the token-choice router does `topk` over near-tied bf16 scores. GPU bf16 matmul reduction order is non-deterministic, so the same input produces slightly different gate logits on the second call (forward vs activation-checkpoint recompute). Different winning experts → different `num_tokens_per_expert` → per-expert tensor shapes don't match between forward-saved and backward-recomputed → backward dies. The bug is in the MoE side, not the checkpoint wrapper. See [PyTorch #171355](https://github.com/pytorch/pytorch/issues/171355).

The fix is to **not** activation-checkpoint the MoE block. Use selective AC and exclude `routed_experts`:

```toml
[model.ac]
mode = "selective"
freq = 1
# Skip routed_experts: MoE recompute drifts between forward and backward.
# Keep norm/attn_proj/linear_attn to recover most of the memory savings.
targets = ["norm", "attn_proj", "linear_attn"]
```

Trade-off: MoE outputs and per-expert intermediates stay in memory across forward → backward, so peak usage rises. For Qwen3.5-35B-A3B at seq_len=32k, 8x B300 (275 GiB each), we measured ~126 GiB peak with full AC vs ~249 GiB with selective AC excluding `routed_experts`. Plan accordingly.

This is **not Blackwell-specific** and **not VLM-specific** — anyone training any MoE model with full AC will hit it eventually (the per-step crash probability is roughly proportional to the number of router invocations).

### SFT hard distill override

For hosted multi-tenant runs where the trainer image's `trainer.loss.type` is fixed, the orchestrator exposes a per-run override that forces SFT loss on every micro-batch without rebuilding the trainer. Set `orchestrator.use_sft_loss = true` alongside `orchestrator.teacher_rollout_model`; both must be configured together (the orchestrator validator enforces this). The orchestrator stamps each `TrainingSample.sft_loss = True`, which the trainer's `compute_loss` honors by dispatching to `sft_loss_fn` per batch — independent of the trainer's configured default loss.
Expand Down
26 changes: 25 additions & 1 deletion src/prime_rl/trainer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,26 @@ def setup_tokenizer(config: TokenizerConfig) -> PreTrainedTokenizer:
return tokenizer


def setup_processor(config: TokenizerConfig):
"""Load an ``AutoProcessor`` for VLM models. Returns ``None`` for text-only models."""
from transformers import AutoProcessor

logger = get_logger()
try:
processor = AutoProcessor.from_pretrained(config.name, trust_remote_code=config.trust_remote_code)
except (ValueError, OSError, KeyError) as e:
logger.debug(f"No AutoProcessor available for {config.name} ({type(e).__name__}); treating as text-only.")
return None
# AutoProcessor returns a plain tokenizer object for text-only models. We
# only want the processor when it actually carries an image / video
# processor — gate on that.
if not (getattr(processor, "image_processor", None) or getattr(processor, "video_processor", None)):
logger.debug(f"AutoProcessor for {config.name} has no image/video processor; treating as text-only.")
return None
logger.info(f"Loaded multimodal processor: {type(processor).__name__}")
return processor


def setup_fsdp(model: nn.Module, config: ModelConfig, parallel_dims: ParallelDims):
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=DTYPE_MAP[config.reduce_dtype])
offload_policy: OffloadPolicy = CPUOffloadPolicy(pin_memory=True) if config.fsdp_cpu_offload else OffloadPolicy()
Expand Down Expand Up @@ -1144,7 +1164,11 @@ def forward(
assert image_grid_thw is not None, "pixel_values requires image_grid_thw for MRoPE computation"
kwargs["pixel_values"] = pixel_values
kwargs["image_grid_thw"] = image_grid_thw
mm_token_type_ids = _get_qwen3_vl_mm_token_type_ids(model, input_ids)
# Prefer the caller-supplied mm_token_type_ids (computed by the renderer
# from placeholder ranges — exact). Fall back to the input-id sniffing
# helper, which only handles qwen3_vl but stays for back-compat.
if mm_token_type_ids is None:
mm_token_type_ids = _get_qwen3_vl_mm_token_type_ids(model, input_ids)
if mm_token_type_ids is not None:
kwargs["mm_token_type_ids"] = mm_token_type_ids
else:
Expand Down
Loading