Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions configs/ci/integration/sft_lora/resume.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# CI integration config for the *resume* leg of the SFT LoRA test
# (launched by tests/integration/test_sft_lora.py after the start run finishes).
# Runs until max_steps = 20, starting from the checkpoint at step 10.
max_steps = 20

[ckpt]
# Step of the checkpoint to resume from; matches max_steps of start.toml.
resume_step = 10

[ckpt.weights]
# Also write the LoRA adapter under `lora_adapters/` next to the weight checkpoint.
save_adapter_separately = true

[model]
name = "PrimeIntellect/Qwen3-0.6B"

[model.lora]
rank = 8
# Module names the LoRA adapters are applied to.
target_modules = [
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
]

[data]
name = "PrimeIntellect/Reverse-Text-SFT"
batch_size = 4
seq_len = 1024

[optim]
lr = 1.5e-5
29 changes: 29 additions & 0 deletions configs/ci/integration/sft_lora/start.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# CI integration config for the *initial* leg of the SFT LoRA test
# (launched by tests/integration/test_sft_lora.py). The resume leg continues
# from the step-10 checkpoint this run is expected to leave behind.
max_steps = 10

# NOTE(review): empty table — presumably enables checkpointing with default
# settings; confirm against the checkpoint config schema.
[ckpt]

[ckpt.weights]
# Also write the LoRA adapter under `lora_adapters/` next to the weight checkpoint.
save_adapter_separately = true

[model]
name = "PrimeIntellect/Qwen3-0.6B"

[model.lora]
rank = 8
# Module names the LoRA adapters are applied to.
target_modules = [
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
]

[data]
name = "PrimeIntellect/Reverse-Text-SFT"
batch_size = 4
seq_len = 1024

[optim]
lr = 1.5e-5
31 changes: 23 additions & 8 deletions src/prime_rl/trainer/ckpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from torch.distributed.checkpoint.state_dict_loader import load as dcp_load
from torch.distributed.checkpoint.state_dict_saver import save as dcp_save
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.tensor import DTensor
from torch.nn import Module
from torch.optim.lr_scheduler import LRScheduler
from torch.optim.optimizer import Optimizer
Expand All @@ -22,10 +23,9 @@
from prime_rl.trainer.lora import has_lora_layers, save_lora_config
from prime_rl.trainer.models import PreTrainedModelPrimeRL
from prime_rl.trainer.optim import CPUOffloadOptimizer
from prime_rl.trainer.runs import Progress
from prime_rl.trainer.runs import Progress, get_multi_run_manager
from prime_rl.trainer.weights import (
gather_weights_on_master,
get_adapter_state_dict,
save_state_dict,
)
from prime_rl.trainer.world import get_world
Expand Down Expand Up @@ -304,6 +304,19 @@ def mark_stable(self, step: int) -> None:
step_path = self.get_step_path(step)
(step_path / "STABLE").touch()

def get_run_adapter_state_dict(self) -> dict[str, Tensor]:
    """Collect the LoRA adapter tensors of run 0 as a CPU state dict.

    The tensors come from the multi-run manager's state dict for run index 0.
    Every key is prefixed with ``base_model.model.``; ``DTensor`` values are
    materialized with ``full_tensor()`` and every value is then copied to the
    CPU with a blocking transfer.

    Returns:
        Mapping from prefixed parameter name to CPU tensor.

    Raises:
        ValueError: If the manager returns no entries for run 0.
    """
    run_state = get_multi_run_manager().get_state_dict_for_run(0)

    adapter_tensors = {}
    for name, tensor in run_state.items():
        # Sharded values must be gathered into a regular tensor before the copy.
        if isinstance(tensor, DTensor):
            tensor = tensor.full_tensor()
        adapter_tensors[f"base_model.model.{name}"] = tensor.to("cpu", non_blocking=False)

    if adapter_tensors:
        return adapter_tensors

    raise ValueError("The LoRA state dict is empty. Something went wrong.")

def save_to_path(
self,
path: Path,
Expand Down Expand Up @@ -339,10 +352,12 @@ def save_to_path(
gen_config.save_pretrained(path)
tokenizer.save_pretrained(path)

if self.config.save_adapter_separately and lora_state_dict is not None:
if lora_state_dict is not None:
adapter_path = path / "lora_adapters"
adapter_path.mkdir(parents=True, exist_ok=True)
torch.save(lora_state_dict, adapter_path / "adapter_model.bin")
save_state_dict(
lora_state_dict, adapter_path, self.config.save_format, save_sharded=False, adapter=True
)
if self.lora_config:
save_lora_config(
model,
Expand Down Expand Up @@ -374,11 +389,11 @@ def save(
for key in getattr(model, "_tied_weights_keys", []):
state_dict.pop(key, None)

if has_lora_layers(model):
self.logger.debug("Getting LoRA state dict on master rank for weight checkpoint")
if has_lora_layers(model) and self.config.save_adapter_separately:
self.logger.debug("Getting run adapter state dict for weight checkpoint")
start_time = time.perf_counter()
lora_state_dict = get_adapter_state_dict(model, self.world.is_master)
self.logger.debug(f"Got LoRA state dict on master rank in {time.perf_counter() - start_time:.2f} seconds")
lora_state_dict = self.get_run_adapter_state_dict()
self.logger.debug(f"Got run adapter state dict in {time.perf_counter() - start_time:.2f} seconds")
else:
lora_state_dict = None

Expand Down
14 changes: 13 additions & 1 deletion src/prime_rl/trainer/sft/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
from prime_rl.utils.pathing import resolve_latest_ckpt_step
from prime_rl.configs.sft import SFTConfig
from prime_rl.utils.cp import setup_cp_params, shard_for_cp
from prime_rl.trainer.runs import Progress
from prime_rl.trainer.runs import Progress, get_multi_run_manager, setup_multi_run_manager
from prime_rl.trainer.models.layers.lora import set_lora_num_tokens
from prime_rl.utils.logger import setup_logger
from prime_rl.trainer.optim import setup_optimizer
from prime_rl.trainer.scheduler import setup_scheduler
Expand Down Expand Up @@ -82,6 +83,9 @@ def train(config: SFTConfig):
)
torch.set_float32_matmul_precision("high")

if config.model.lora is not None:
setup_multi_run_manager(config.output_dir, 1, torch.device("cuda", world.local_rank), config.model.lora)

# Initialize parallel dimensions
parallel_dims = get_parallel_dims(config.model, config.data.seq_len)

Expand Down Expand Up @@ -120,6 +124,11 @@ def train(config: SFTConfig):
config.model, parallel_dims, loading_from_ckpt_later, fused_cross_entropy=config.loss_impl == "liger_fused"
)

if config.model.lora is not None:
multi_run_manager = get_multi_run_manager()
multi_run_manager.reset_run_parameters(0)
multi_run_manager.scaling_factors[0] = config.model.lora.alpha / config.model.lora.rank

logger.info(f"Initializing tokenizer ({config.tokenizer})")
tokenizer = setup_tokenizer(config.tokenizer)

Expand Down Expand Up @@ -199,6 +208,9 @@ def compute_loss(micro_batch: dict) -> tuple[torch.Tensor, torch.Tensor]:
target_ids = shard_for_cp(target_ids, cp_rank=cp_rank, cp_world_size=cp_size)
loss_mask = shard_for_cp(loss_mask, cp_rank=cp_rank, cp_world_size=cp_size)

if config.model.lora is not None:
set_lora_num_tokens(torch.full((1,), input_ids.numel(), dtype=torch.int32, device="cuda"))

token_count = loss_mask.sum(dtype=torch.int64)

with maybe_activation_offloading(config.model.ac_offloading):
Expand Down
129 changes: 129 additions & 0 deletions tests/integration/test_sft_lora.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
from pathlib import Path
from typing import Callable

import pytest

from prime_rl.trainer.weights import load_state_dict
from tests.conftest import ProcessResult
from tests.utils import check_loss_goes_down, strip_escape_codes

# Module-wide markers: every test in this file is tagged `slow` and `gpu`.
pytestmark = [pytest.mark.slow, pytest.mark.gpu]

# Timeout, in seconds, passed to `run_process` for each launched training process.
TIMEOUT = 300 # 5 minutes


def assert_adapter_checkpoint(adapter_dir: Path) -> None:
    """Assert that ``adapter_dir`` contains a well-formed LoRA adapter checkpoint.

    The checks run in this order:

    1. ``adapter_config.json`` exists in ``adapter_dir``.
    2. The state dict loaded from ``adapter_dir`` is non-empty.
    3. No key contains ``.0.weight``.
    4. At least one key ends with ``lora_A.weight``.
    5. Every key starts with ``base_model.model.``.

    Args:
        adapter_dir: Directory the adapter checkpoint was written to.

    Raises:
        AssertionError: On the first failed check. The message names the
            missing path or lists the offending keys so a CI failure can be
            diagnosed from the log alone.
    """
    config_path = adapter_dir / "adapter_config.json"
    assert config_path.exists(), f"Missing adapter config: {config_path}"

    state_dict = load_state_dict(adapter_dir)
    assert state_dict, f"Adapter state dict loaded from {adapter_dir} is empty"

    keys = list(state_dict)

    indexed_keys = [key for key in keys if ".0.weight" in key]
    assert not indexed_keys, f"Adapter keys must not contain '.0.weight': {indexed_keys}"

    assert any(key.endswith("lora_A.weight") for key in keys), (
        f"No key ending with 'lora_A.weight' among adapter keys: {keys}"
    )

    unprefixed_keys = [key for key in keys if not key.startswith("base_model.model.")]
    assert not unprefixed_keys, f"Adapter keys missing the 'base_model.model.' prefix: {unprefixed_keys}"


@pytest.fixture(scope="module")
def wandb_name(branch_name: str) -> str:
    """Return the module-scoped W&B run name for the SFT LoRA CI integration tests."""
    run_name = f"test-sft-lora-{branch_name}"
    return run_name


@pytest.fixture(scope="module")
def sft_lora_process(
    run_process: Callable[..., ProcessResult],
    wandb_project: str,
    wandb_name: str,
    output_dir: Path,
) -> ProcessResult:
    """Launch the initial SFT LoRA training run once per module and return its result.

    The trainer is started via ``uv run torchrun`` with two processes per node
    and ``configs/ci/integration/sft_lora/start.toml``; the W&B project, W&B
    run name and output directory are appended as CLI overrides.
    """
    launcher = ["uv", "run", "torchrun", "--local-ranks-filter", "0", "--nproc-per-node", "2"]
    entrypoint = ["src/prime_rl/trainer/sft/train.py", "@", "configs/ci/integration/sft_lora/start.toml"]
    overrides = [
        "--wandb.project",
        wandb_project,
        "--wandb.name",
        wandb_name,
        "--output-dir",
        output_dir.as_posix(),
    ]

    return run_process(launcher + entrypoint + overrides, timeout=TIMEOUT)


@pytest.fixture(scope="module")
def sft_lora_resume_process(
    sft_lora_process,  # Resume training can only start when regular SFT LoRA process is finished
    run_process: Callable[..., ProcessResult],
    wandb_project: str,
    wandb_name: str,
    output_dir: Path,
) -> ProcessResult:
    """Launch the resumed SFT LoRA training run once per module and return its result.

    Requesting ``sft_lora_process`` only enforces ordering; its value is not
    used. The command mirrors the initial run but points at
    ``configs/ci/integration/sft_lora/resume.toml`` and appends ``-resume`` to
    the W&B run name.
    """
    resume_name = wandb_name + "-resume"

    launcher = ["uv", "run", "torchrun", "--local-ranks-filter", "0", "--nproc-per-node", "2"]
    entrypoint = ["src/prime_rl/trainer/sft/train.py", "@", "configs/ci/integration/sft_lora/resume.toml"]
    overrides = [
        "--wandb.project",
        wandb_project,
        "--wandb.name",
        resume_name,
        "--output-dir",
        output_dir.as_posix(),
    ]

    return run_process(launcher + entrypoint + overrides, timeout=TIMEOUT)


def test_no_error(sft_lora_process: ProcessResult):
    """The initial SFT LoRA run must exit with return code 0."""
    exit_code = sft_lora_process.returncode
    assert exit_code == 0, f"Process has non-zero return code ({sft_lora_process})"


def test_loss_goes_down(sft_lora_process: ProcessResult, output_dir: Path):
    """The rank-0 trainer log of the initial SFT LoRA run must show a decreasing loss."""
    trainer_log_path = output_dir.joinpath("logs", "trainer", "rank_0.log")
    print(f"Checking trainer path in {trainer_log_path}")

    # Escape codes are removed before the log is split into lines for the check.
    raw_log = trainer_log_path.read_text()
    trainer_stdout = strip_escape_codes(raw_log).splitlines()
    check_loss_goes_down(trainer_stdout)


def test_adapter_checkpoint_written(sft_lora_process: ProcessResult, output_dir: Path):
    """The initial run must leave a valid, PEFT-compatible adapter checkpoint at step 10."""
    step_dir = output_dir / "weights" / "step_10"
    assert_adapter_checkpoint(step_dir / "lora_adapters")


def test_no_error_resume(sft_lora_resume_process: ProcessResult):
    """The resumed SFT LoRA run must exit with return code 0."""
    exit_code = sft_lora_resume_process.returncode
    assert exit_code == 0, f"Process has non-zero return code ({sft_lora_resume_process})"


def test_loss_goes_down_resume(sft_lora_resume_process: ProcessResult, output_dir: Path):
    """The rank-0 trainer log must show a decreasing loss once the resumed run has finished."""
    trainer_log_path = output_dir.joinpath("logs", "trainer", "rank_0.log")
    print(f"Checking trainer path in {trainer_log_path}")

    # Escape codes are removed before the log is split into lines for the check.
    raw_log = trainer_log_path.read_text()
    trainer_stdout = strip_escape_codes(raw_log).splitlines()
    check_loss_goes_down(trainer_stdout)


def test_adapter_checkpoint_written_resume(sft_lora_resume_process: ProcessResult, output_dir: Path):
    """The resumed run must leave a valid, PEFT-compatible adapter checkpoint at step 20."""
    step_dir = output_dir / "weights" / "step_20"
    assert_adapter_checkpoint(step_dir / "lora_adapters")
Loading