Add SFT validation eval with val_data#1850

Merged
hallerite merged 22 commits into
PrimeIntellect-ai:mainfrom
philippnormann:feature/sft-val-eval
Mar 10, 2026

Conversation

Contributor

@philippnormann philippnormann commented Feb 22, 2026

Summary

  • Add optional val_data and eval config blocks to SFT.
  • Run periodic validation inside SFT training and log val/loss and val/num_batches.
  • Add config validation that requires eval and val_data to be set together.
  • Add unit tests for config validation behavior.
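The both-or-neither rule from the summary can be sketched with plain dataclasses. This is only an illustration: `DataConfig`, `EvalConfig`, and `SFTConfig` below are hypothetical stand-ins, not the actual prime-rl config classes, and the real check may live in a different validator.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class DataConfig:  # hypothetical stand-in for a dataset config block
    name: str
    splits: list[str]


@dataclass
class EvalConfig:  # hypothetical stand-in for the [eval] block
    interval: int
    num_batches: int


@dataclass
class SFTConfig:  # hypothetical; not the actual prime-rl class
    data: DataConfig
    val_data: Optional[DataConfig] = None
    eval: Optional[EvalConfig] = None

    def __post_init__(self) -> None:
        # Reject configs where exactly one of the two blocks is present.
        if (self.val_data is None) != (self.eval is None):
            raise ValueError("eval and val_data must be set together")
```

With this shape, a config that sets neither block or both blocks constructs normally, while setting only one of them raises `ValueError`, which matches the three cases the unit tests are described as covering.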

Why

Train loss alone is not enough for checkpoint selection and overfitting detection.

Before

  • No native periodic validation signal in SFT runs.

After

  • SFT can emit validation metrics at configurable intervals during training.

Evidence

  • Reverse-text run showing periodic validation logging behavior.
    [Figures: train/loss and val/loss curves]
  • Config used:

sft_fullft_rtext_split_200.toml

max_steps = 200

[ckpt]
interval = 20

[model]
name = "PrimeIntellect/Qwen3-0.6B"

[data]
name = "willcb/R1-reverse-wikipedia-paragraphs-v1-1000"
splits = ["train[:90%]"]
seq_len = 4096
batch_size = 32
shuffle = true
seed = 42

[val_data]
name = "willcb/R1-reverse-wikipedia-paragraphs-v1-1000"
splits = ["train[90%:]"]
seq_len = 4096
batch_size = 32
shuffle = false
seed = 42

[eval]
interval = 10
num_batches = 4

[optim]
lr = 2e-5
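
As a rough illustration of how `num_batches = 4` might bound each validation pass, the sketch below averages at most `num_batches` per-batch losses. The helper name `mean_val_loss` is hypothetical, and treating `num_batches` as a cap on validation batches is an assumption based on the config above, not a description of the PR's code.

```python
from itertools import islice
from typing import Iterable, Optional


def mean_val_loss(
    batch_losses: Iterable[float], num_batches: Optional[int] = None
) -> tuple[float, int]:
    """Average at most num_batches losses; return (mean, batches_used)."""
    total, used = 0.0, 0
    # islice with stop=None consumes the whole iterable (a full pass).
    for loss in islice(batch_losses, num_batches):
        total += loss
        used += 1
    mean = total / used if used else float("nan")
    return mean, used
```

Returning the number of batches actually used alongside the mean mirrors the idea of logging both val/loss and val/num_batches.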

Validation

  • uv run pytest tests/unit/train/sft/test_sft_eval_config.py -q
  • Unit tests cover: eval without val_data (invalid), val_data without eval (invalid), and eval + val_data (valid).
  • 200-step reverse-text run emits val/loss every 10 steps as configured.

Scope

  • This PR covers periodic SFT validation evaluation and config validation.

Note

Medium Risk
Touches core SFT training-loop and dataset-loading paths; while changes are straightforward, they can impact training correctness/performance and distributed metric aggregation if edge cases are missed.

Overview
Adds optional SFT validation via new SFTValConfig (sft.val) that loads a separate validation dataset and runs full-pass evaluation on a configurable interval (and optionally at step 0), logging val/loss.

Refactors SFT data loading by extracting load_sft_dataset() (expensive HF I/O) and extending setup_dataset() to accept a preloaded raw dataset plus max_epochs, enabling validation to reuse preloaded data.

Restructures the SFT training loop to centralize loss/forward-backward into helpers and to compute aggregated loss/NaN counts consistently across distributed ranks, with minor logging tweaks (e.g., only emitting max_vio when present).
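
One way to read "aggregated loss/NaN counts consistently across distributed ranks" is a token-weighted mean built from per-rank sums. The sketch below simulates that with a plain list instead of a real collective. `aggregate_loss` and the tuple layout are hypothetical, and in an actual run the three sums would presumably come from an all-reduce rather than Python's `sum`.

```python
import math


def aggregate_loss(per_rank: list[tuple[float, int, int]]) -> tuple[float, int]:
    """Combine per-rank (loss_sum, token_count, nan_batches) into (mean_loss, nan_count).

    Assumes each rank has already excluded NaN batches from loss_sum and token_count.
    """
    loss_sum = sum(r[0] for r in per_rank)
    tokens = sum(r[1] for r in per_rank)
    nan_count = sum(r[2] for r in per_rank)
    # With no valid tokens anywhere, report NaN instead of dividing by zero.
    mean = loss_sum / tokens if tokens > 0 else math.nan
    return mean, nan_count
```

Summing numerators and denominators before dividing gives every rank the same value, whereas averaging per-rank means would weight ranks with few tokens too heavily.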

Written by Cursor Bugbot for commit d80bb23. This will update automatically on new commits.

Comment thread src/prime_rl/configs/sft.py Outdated
Comment thread src/prime_rl/trainer/sft/train.py Outdated
Comment thread src/prime_rl/configs/sft.py Outdated
Apply CP compatibility checks to val_data, align eval scheduling with checkpoint step numbering, and document new SFT eval config fields in the changelog.
Add SFTEvalConfig.eval_on_start to support an explicit pre-training validation pass while keeping interval-based eval semantics unchanged by default.
Comment thread src/prime_rl/trainer/sft/train.py Outdated
Comment thread src/prime_rl/trainer/sft/train.py Outdated
Comment thread tests/unit/train/sft/test_sft_eval_config.py Outdated
Comment thread src/prime_rl/configs/sft.py Outdated
Comment thread src/prime_rl/configs/sft.py Outdated
Comment thread src/prime_rl/trainer/sft/train.py Outdated
Comment thread src/prime_rl/trainer/sft/train.py Outdated
@philippnormann philippnormann requested a review from samsja March 8, 2026 15:19
Comment thread src/prime_rl/trainer/sft/train.py Outdated
Comment on lines +263 to +283
def run_validation(step: int) -> None:
    val_dataset = setup_dataset(
        tokenizer, config.val.data, config.model.cp * config.model.tp, max_epochs=1, raw_dataset=val_raw_dataset
    )
    val_dataloader = setup_dataloader(val_dataset, config.val.data)

    was_training = model.training
    model.eval()
    mean_loss, nan_count, _ = run_forward_loop(val_dataloader, backward=False)
    if nan_count > 0:
        logger.warning(f"Validation at step {step}: {nan_count} batches had NaN loss")
    if mean_loss != mean_loss:
        logger.warning(f"Validation at step {step} had no valid tokens")
    else:
        logger.success(f"Validation | Step {step} | Loss: {mean_loss:.4f}")
        monitor.log({"val/loss": mean_loss, "step": step}, step=step)
    if was_training:
        model.train()

if config.val is not None and config.val.eval_on_start:
    run_validation(progress.step)
Member

we don't need a function here

Comment thread src/prime_rl/trainer/sft/train.py Outdated
Comment on lines +282 to +283
if config.val is not None and config.val.eval_on_start:
    run_validation(progress.step)
Member

Let's not do this. Let's instead edit the if statement below so that step 0 also counts as valid when eval_on_start is set.
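
A minimal sketch of the kind of single condition this comment asks for, assuming an interval-based check already exists in the training loop. The function name `should_run_validation` and its signature are hypothetical, and the merged code may express the condition differently.

```python
def should_run_validation(step: int, interval: int, eval_on_start: bool = False) -> bool:
    """Return True when a validation pass should run at this step."""
    if step == 0:
        # Step 0 is only a valid eval step when explicitly requested.
        return eval_on_start
    return step % interval == 0
```

Folding the step-0 case into the same predicate keeps one call site for validation instead of a separate pre-training call.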

Comment thread src/prime_rl/configs/sft.py
@cursor cursor Bot left a comment

Cursor Bugbot has reviewed your changes and found 1 potential issue.

Comment thread src/prime_rl/trainer/sft/train.py
@hallerite hallerite merged commit 447cafe into PrimeIntellect-ai:main Mar 10, 2026
14 of 16 checks passed
@hallerite
Member

@philippnormann thank you for the PR! Sorry for taking so long to get it into a mergeable state – it was trickier than we expected, but I think it's in a good state now.

@philippnormann philippnormann deleted the feature/sft-val-eval branch March 10, 2026 10:39
@philippnormann
Contributor Author

No worries, glad it made it in! Appreciate you and @samsja taking the time to get it into shape.

Also have #1849 open for SFT LoRA support if you get a chance to look at it.

3 participants