Add SFT LoRA support #1849

Merged
Jackmin801 merged 4 commits into PrimeIntellect-ai:main from philippnormann:feature/sft-lora-support
Mar 21, 2026

@philippnormann (Contributor) commented Feb 22, 2026

Summary

  • Add LoRA runtime setup to the SFT trainer path when model.lora is enabled.
  • Initialize MultiRunManager state and LoRA scaling (alpha / rank) for SFT LoRA runs.
  • Set per-step LoRA token counts so LoRA-wrapped layers receive correct token partitioning metadata.
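
The token-count metadata in the last bullet can be pictured as follows. This is a minimal sketch, not the trainer's code: `token_offsets` is a hypothetical helper, and it assumes the tokens of each run are laid out contiguously in a flat batch, so per-run counts translate into `[start, end)` slices. With a single SFT adapter, every token in the step belongs to run 0.

```python
from itertools import accumulate
from typing import List, Tuple


def token_offsets(tokens_per_run: List[int]) -> List[Tuple[int, int]]:
    """Turn per-run token counts into [start, end) slices over a flat token batch.

    Hypothetical helper for illustration; the real trainer may store this differently.
    """
    if any(n < 0 for n in tokens_per_run):
        raise ValueError("token counts must be non-negative")
    ends = list(accumulate(tokens_per_run))
    starts = [0] + ends[:-1]
    return list(zip(starts, ends))


# SFT trains one adapter, so one run owns the whole step.
# 4096 matches seq_len in the configs below; one packed sequence is assumed.
single_run = token_offsets([4096])

# A multi-run layout, shown only for contrast with the single-run SFT case.
multi_run = token_offsets([3, 5, 2])

print(single_run)  # [(0, 4096)]
print(multi_run)   # [(0, 3), (3, 8), (8, 10)]
```

If the counts are stale or never set, the slices no longer line up with the batch, which is why they are refreshed every step.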

Why

SFT LoRA was not fully wired as a first-class runtime path and could fail at startup without manual setup.

Before

  • SFT with LoRA could fail with:
    • RuntimeError: MultiRunManager not initialized. Please call setup_multi_run_manager first.

After

  • SFT LoRA starts and advances training steps in the default trainer path.
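
The before/after behaviour follows the usual "initialize a process-wide singleton before first use" pattern. The sketch below is a toy stand-in: `ToyMultiRunManager`, `setup_toy_manager`, and `get_toy_manager` are hypothetical names, and the real `setup_multi_run_manager` signature is not shown in this PR. Only the error message is taken from the report above.

```python
from typing import Optional


class ToyMultiRunManager:
    """Hypothetical stand-in for the real manager; it only records max_runs."""

    def __init__(self, max_runs: int) -> None:
        self.max_runs = max_runs


_manager: Optional[ToyMultiRunManager] = None


def setup_toy_manager(max_runs: int = 1) -> ToyMultiRunManager:
    """Create the process-wide manager; must run before any LoRA layer asks for it."""
    global _manager
    _manager = ToyMultiRunManager(max_runs)
    return _manager


def get_toy_manager() -> ToyMultiRunManager:
    """Return the manager, or fail the same way the SFT path did before this PR."""
    if _manager is None:
        raise RuntimeError(
            "MultiRunManager not initialized. Please call setup_multi_run_manager first."
        )
    return _manager


# Before: a LoRA layer looks up the manager and nothing has set it up yet.
try:
    get_toy_manager()
except RuntimeError as err:
    startup_error = str(err)

# After: the trainer performs the setup when LoRA is enabled, so the lookup succeeds.
setup_toy_manager(max_runs=1)
manager = get_toy_manager()
```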

Evidence

  • Reverse-text SFT loss/mean convergence (full-ft vs LoRA), 200 steps.
    (Image: W&B chart, exported 2/22/2026 10:19:09 PM)
  • Configs used:

sft_fullft_rtext_200.toml

max_steps = 200

[ckpt]
interval = 20

[model]
name = "PrimeIntellect/Qwen3-0.6B"

[data]
name = "willcb/R1-reverse-wikipedia-paragraphs-v1-1000"
seq_len = 4096
batch_size = 32

[optim]
lr = 2e-5

sft_lora_rtext_200.toml

max_steps = 200

[ckpt]
interval = 20

[ckpt.weights]
save_adapter_separately = true
save_format = "safetensors"

[model]
name = "PrimeIntellect/Qwen3-0.6B"

[model.lora]
rank = 16
alpha = 32
dropout = 0.0
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]

[data]
name = "willcb/R1-reverse-wikipedia-paragraphs-v1-1000"
seq_len = 4096
batch_size = 32

[optim]
lr = 5e-4
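
For reference, the LoRA scaling mentioned in the summary is alpha / rank, so the config above (rank = 16, alpha = 32) gives a scaling of 2.0. The sketch below uses plain Python lists and hypothetical helper names (`lora_scaling`, `lora_linear`); it illustrates the standard LoRA update y = W x + scaling * B (A x), not the trainer's implementation.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def lora_scaling(alpha: float, rank: int) -> float:
    """Standard LoRA scaling factor: alpha / rank."""
    if rank <= 0:
        raise ValueError("rank must be positive")
    return alpha / rank


def matvec(m: Matrix, v: Vector) -> Vector:
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def lora_linear(x: Vector, w: Matrix, a: Matrix, b: Matrix, scaling: float) -> Vector:
    """y = W x + scaling * B (A x), with A of shape (rank, d_in) and B of shape (d_out, rank)."""
    base = matvec(w, x)
    delta = matvec(b, matvec(a, x))
    return [y + scaling * d for y, d in zip(base, delta)]


scaling = lora_scaling(alpha=32, rank=16)  # values from sft_lora_rtext_200.toml

# Tiny rank-1 example with made-up weights, just to show where the scaling enters.
w = [[1.0, 0.0], [0.0, 1.0]]
a = [[1.0, 1.0]]
b = [[0.5], [0.0]]
y = lora_linear([1.0, 2.0], w, a, b, scaling)

print(scaling)  # 2.0
print(y)        # [4.0, 2.0]
```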

Validation

  • SFT LoRA run completes 200 steps without MultiRunManager initialization failures.
  • Attached reverse-text loss curve confirms stable optimization/convergence behavior.

Scope

  • This PR covers SFT LoRA runtime support.

Note

Medium Risk
Modifies SFT training initialization and distributed checkpoint saving paths for LoRA, including collectives over FSDP/DTensor; mistakes could break training startup or produce invalid adapter checkpoints.

Overview
Enables SFT runs with model.lora by initializing the MultiRunManager, setting LoRA scaling (alpha / rank), and updating per-step LoRA token counts so LoRA layers receive correct partition metadata.

Updates weight checkpointing when save_adapter_separately is enabled to save PEFT-compatible adapter artifacts via save_state_dict, while also capturing MultiRun LoRA state across all ranks (handling DTensor.full_tensor() collectives under FSDP).
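
The collective concern can be illustrated without a GPU. In the toy below, threads play the role of ranks and a `threading.Barrier` stands in for the all-gather behind `DTensor.full_tensor()`; every name here is hypothetical. The point is only the ordering: every rank enters the gather, and the rank-0 check comes afterwards, so no rank is left waiting on a collective that others skipped.

```python
import threading
from typing import Dict, List

WORLD_SIZE = 2

# Each toy "rank" holds one shard of a LoRA weight (stand-in for a sharded DTensor).
shards: Dict[int, List[float]] = {0: [1.0, 2.0], 1: [3.0, 4.0]}
gathered: Dict[int, List[float]] = {}
saved: Dict[str, List[float]] = {}
barrier = threading.Barrier(WORLD_SIZE)


def toy_full_tensor(rank: int) -> List[float]:
    """Toy all-gather: if any rank skips this call, the barrier never releases."""
    gathered[rank] = shards[rank]
    barrier.wait(timeout=5)  # stands in for the collective; errors instead of hanging
    return [v for r in range(WORLD_SIZE) for v in gathered[r]]


def save_adapter(rank: int) -> None:
    full = toy_full_tensor(rank)  # called on every rank
    if rank == 0:                 # only rank 0 keeps and writes the result
        saved["lora_A"] = full


threads = [threading.Thread(target=save_adapter, args=(r,)) for r in range(WORLD_SIZE)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(saved)  # {'lora_A': [1.0, 2.0, 3.0, 4.0]}
```

Moving the `rank == 0` check above the gather would make rank 1 skip the barrier, which is the toy equivalent of a hung checkpoint save.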

Adds GPU CI integration configs and a new integration test that runs SFT LoRA start + resume, asserts loss decreases, and validates adapter checkpoint structure/keys.
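
A "loss decreases" assertion of the kind described above could look like the sketch below. `loss_decreased` is a hypothetical helper and the loss values are made up for illustration; the actual integration test may compare losses differently.

```python
from statistics import mean
from typing import Sequence


def loss_decreased(losses: Sequence[float], window: int = 3) -> bool:
    """Compare the mean of the last `window` losses against the first `window`.

    Averaging over a window makes the check less sensitive to single noisy steps.
    """
    if window <= 0 or len(losses) < 2 * window:
        raise ValueError("need at least 2 * window loss values")
    return mean(losses[-window:]) < mean(losses[:window])


# Illustrative values only, not taken from the runs in this PR.
example_losses = [2.9, 2.7, 2.8, 2.1, 1.6, 1.5, 1.4]
ok = loss_decreased(example_losses)
```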

Written by Cursor Bugbot for commit 833afe9.

@philippnormann force-pushed the feature/sft-lora-support branch from 6700860 to d1cfa48 on February 26, 2026 10:47
@philippnormann (Contributor, Author) commented Feb 27, 2026

Hi! It would be great to get some feedback here and in #1850 when you have time 😌

I rebased on the latest main and addressed all issues raised by the bot checks. If you think the design should be adjusted, I’d be very happy to make further changes.

I also have a follow-up PR prepared for LoRA warm-start support in the RL trainer, enabling end-to-end SFT+RL with LoRA adapters (without merges). I was waiting for feedback here and in #1850 before opening the next ones.

If possible, could someone also trigger CI for both PRs?

Thanks a lot! 🙏🏼

@philippnormann force-pushed the feature/sft-lora-support branch from 723a6c5 to 29bf608 on March 20, 2026 01:18
@philippnormann (Contributor, Author) commented

Just pushed an update that adds integration tests and fixes the adapter checkpoint export so saved adapters are PEFT-compatible / vLLM-loadable:

  • Using get_state_dict_for_run(0) in ckpt.py for clean adapter keys instead of the .0-indexed output from get_adapter_state_dict
  • Saving the adapter via save_state_dict so the format respects config (safetensors by default)
  • Integration tests in tests/integration/test_sft_lora.py with start/resume coverage and adapter key format assertions
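
To make the key-format point concrete, the sketch below shows what "clean" versus run-indexed keys might look like. It assumes the run index appears as a numeric path segment right after `lora_A` / `lora_B`; that layout, the example key, and both helper names are assumptions for illustration, not taken from `ckpt.py`.

```python
import re

# Assumed layout of a run-indexed key: "<module>.lora_A.<run>.weight".
_RUN_INDEX = re.compile(r"\.(lora_[AB])\.\d+\.")
_CLEAN_KEY = re.compile(r"\.lora_[AB]\.weight$")


def strip_run_index(key: str) -> str:
    """Drop the numeric run segment after lora_A / lora_B, if present."""
    return _RUN_INDEX.sub(r".\1.", key)


def is_clean_adapter_key(key: str) -> bool:
    """True if the key ends in lora_A.weight or lora_B.weight with no run index."""
    return _CLEAN_KEY.search(key) is not None


indexed = "model.layers.0.self_attn.q_proj.lora_A.0.weight"  # hypothetical key
clean = strip_run_index(indexed)

print(clean)  # model.layers.0.self_attn.q_proj.lora_A.weight
```

The layer index in `layers.0` is left alone because the pattern only matches a number that directly follows a `lora_A` or `lora_B` segment.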

Tested end-to-end on 1x 4090 and 2x H100; the adapter loads in vLLM.

Comment thread src/prime_rl/trainer/ckpt.py Outdated
Comment thread src/prime_rl/trainer/ckpt.py Outdated
@philippnormann (Contributor, Author) commented

Thanks for the feedback @Jackmin801!

I cleaned this up so save_to_path() now just saves the final adapter state dict it gets, and for save_adapter_separately that state dict is assembled earlier directly from get_multi_run_manager().get_state_dict_for_run(0).

The filtering was from an attempt to also keep modules_to_save, but it can be dropped since current vLLM does not support adapters with non-empty modules_to_save.
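
For context, a PEFT-style `adapter_config.json` with an empty `modules_to_save` could be assembled as in the sketch below. The field names follow PEFT's `LoraConfig`; `build_adapter_config` is a hypothetical helper, and what the trainer actually writes alongside the adapter weights is not shown in this PR.

```python
import json
from typing import Any, Dict, List


def build_adapter_config(
    base_model: str,
    rank: int,
    alpha: float,
    dropout: float,
    target_modules: List[str],
) -> Dict[str, Any]:
    """Assemble a PEFT-style LoRA config dict with no extra trainable modules."""
    return {
        "peft_type": "LORA",
        "base_model_name_or_path": base_model,
        "r": rank,
        "lora_alpha": alpha,
        "lora_dropout": dropout,
        "target_modules": list(target_modules),
        "modules_to_save": None,  # left empty, per the vLLM limitation noted above
        "bias": "none",
    }


# Values mirror the [model] and [model.lora] tables from the LoRA config in this PR.
adapter_config = build_adapter_config(
    base_model="PrimeIntellect/Qwen3-0.6B",
    rank=16,
    alpha=32,
    dropout=0.0,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)
adapter_config_json = json.dumps(adapter_config, indent=2)
```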

@Jackmin801 (Member) left a comment

Nice! Good work. lgtm

@Jackmin801 merged commit 4875c10 into PrimeIntellect-ai:main on Mar 21, 2026
12 of 16 checks passed
@philippnormann deleted the feature/sft-lora-support branch on April 1, 2026 01:27