From 29bf6088ef723cda5f55afe9ea46ac199a5db057 Mon Sep 17 00:00:00 2001 From: Philipp Normann Date: Sun, 22 Feb 2026 22:09:01 +0100 Subject: [PATCH 1/4] Add SFT LoRA support --- configs/ci/integration/sft_lora/resume.toml | 30 +++++ configs/ci/integration/sft_lora/start.toml | 29 +++++ src/prime_rl/trainer/ckpt.py | 23 +++- src/prime_rl/trainer/sft/train.py | 14 ++- tests/integration/test_sft_lora.py | 129 ++++++++++++++++++++ 5 files changed, 222 insertions(+), 3 deletions(-) create mode 100644 configs/ci/integration/sft_lora/resume.toml create mode 100644 configs/ci/integration/sft_lora/start.toml create mode 100644 tests/integration/test_sft_lora.py diff --git a/configs/ci/integration/sft_lora/resume.toml b/configs/ci/integration/sft_lora/resume.toml new file mode 100644 index 0000000000..bcc9145e98 --- /dev/null +++ b/configs/ci/integration/sft_lora/resume.toml @@ -0,0 +1,30 @@ +max_steps = 20 + +[ckpt] +resume_step = 10 + +[ckpt.weights] +save_adapter_separately = true + +[model] +name = "PrimeIntellect/Qwen3-0.6B" + +[model.lora] +rank = 8 +target_modules = [ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", +] + +[data] +name = "PrimeIntellect/Reverse-Text-SFT" +batch_size = 4 +seq_len = 1024 + +[optim] +lr = 1.5e-5 diff --git a/configs/ci/integration/sft_lora/start.toml b/configs/ci/integration/sft_lora/start.toml new file mode 100644 index 0000000000..0d775bdfa6 --- /dev/null +++ b/configs/ci/integration/sft_lora/start.toml @@ -0,0 +1,29 @@ +max_steps = 10 + +[ckpt] + +[ckpt.weights] +save_adapter_separately = true + +[model] +name = "PrimeIntellect/Qwen3-0.6B" + +[model.lora] +rank = 8 +target_modules = [ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", +] + +[data] +name = "PrimeIntellect/Reverse-Text-SFT" +batch_size = 4 +seq_len = 1024 + +[optim] +lr = 1.5e-5 diff --git a/src/prime_rl/trainer/ckpt.py b/src/prime_rl/trainer/ckpt.py index d9fca0fa5e..8ee83fb487 100644 --- a/src/prime_rl/trainer/ckpt.py +++ b/src/prime_rl/trainer/ckpt.py @@ -12,6 +12,7 @@ from torch.distributed.checkpoint.state_dict_loader import load as dcp_load from torch.distributed.checkpoint.state_dict_saver import save as dcp_save from torch.distributed.checkpoint.stateful import Stateful +from torch.distributed.tensor import DTensor from torch.nn import Module from torch.optim.lr_scheduler import LRScheduler from torch.optim.optimizer import Optimizer @@ -22,7 +23,7 @@ from prime_rl.trainer.lora import has_lora_layers, save_lora_config from prime_rl.trainer.models import PreTrainedModelPrimeRL from prime_rl.trainer.optim import CPUOffloadOptimizer -from prime_rl.trainer.runs import Progress +from prime_rl.trainer.runs import Progress, get_multi_run_manager from prime_rl.trainer.weights import ( gather_weights_on_master, get_adapter_state_dict, @@ -313,6 +314,18 @@ def save_to_path( tokenizer: PreTrainedTokenizer, ): """Save HF-compatible weight checkpoint to a given path.""" + # Gather LoRA run state on all ranks: full_tensor() is a collective + # that must be called by every rank simultaneously under FSDP. + if self.config.save_adapter_separately and lora_state_dict is not None: + lora_run_state = { + f"base_model.model.{key}": ( + value.full_tensor() if isinstance(value, DTensor) else value + ).to("cpu", non_blocking=False) + for key, value in get_multi_run_manager().get_state_dict_for_run(0).items() + } + else: + lora_run_state = None + if self.world.is_master: path.mkdir(parents=True, exist_ok=True) start_time = time.perf_counter() @@ -342,7 +355,13 @@ def save_to_path( if self.config.save_adapter_separately and lora_state_dict is not None: adapter_path = path / "lora_adapters" adapter_path.mkdir(parents=True, exist_ok=True) - torch.save(lora_state_dict, adapter_path / "adapter_model.bin") + lora_state_dict = { + key: value + for key, value in lora_state_dict.items() + if "lora_A" not in key and "lora_B" not in key + } + lora_state_dict.update(lora_run_state) + save_state_dict(lora_state_dict, adapter_path, self.config.save_format, save_sharded=False, adapter=True) if self.lora_config: save_lora_config( model, diff --git a/src/prime_rl/trainer/sft/train.py b/src/prime_rl/trainer/sft/train.py index 089b78703d..29fb79494e 100644 --- a/src/prime_rl/trainer/sft/train.py +++ b/src/prime_rl/trainer/sft/train.py @@ -16,7 +16,8 @@ from prime_rl.utils.pathing import resolve_latest_ckpt_step from prime_rl.configs.sft import SFTConfig from prime_rl.utils.cp import setup_cp_params, shard_for_cp -from prime_rl.trainer.runs import Progress +from prime_rl.trainer.runs import Progress, get_multi_run_manager, setup_multi_run_manager +from prime_rl.trainer.models.layers.lora import set_lora_num_tokens from prime_rl.utils.logger import setup_logger from prime_rl.trainer.optim import setup_optimizer from prime_rl.trainer.scheduler import setup_scheduler @@ -82,6 +83,9 @@ def train(config: SFTConfig): ) torch.set_float32_matmul_precision("high") + if config.model.lora is not None: + setup_multi_run_manager(config.output_dir, 1, torch.device("cuda", world.local_rank), config.model.lora) + # Initialize parallel dimensions parallel_dims = get_parallel_dims(config.model, config.data.seq_len) @@ -120,6 +124,11 @@ def train(config: SFTConfig): config.model, parallel_dims, loading_from_ckpt_later, fused_cross_entropy=config.loss_impl == "liger_fused" ) + if config.model.lora is not None: + multi_run_manager = get_multi_run_manager() + multi_run_manager.reset_run_parameters(0) + multi_run_manager.scaling_factors[0] = config.model.lora.alpha / config.model.lora.rank + logger.info(f"Initializing tokenizer ({config.tokenizer})") tokenizer = setup_tokenizer(config.tokenizer) @@ -199,6 +208,9 @@ def compute_loss(micro_batch: dict) -> tuple[torch.Tensor, torch.Tensor]: target_ids = shard_for_cp(target_ids, cp_rank=cp_rank, cp_world_size=cp_size) loss_mask = shard_for_cp(loss_mask, cp_rank=cp_rank, cp_world_size=cp_size) + if config.model.lora is not None: + set_lora_num_tokens(torch.full((1,), input_ids.numel(), dtype=torch.int32, device="cuda")) + token_count = loss_mask.sum(dtype=torch.int64) with maybe_activation_offloading(config.model.ac_offloading): diff --git a/tests/integration/test_sft_lora.py b/tests/integration/test_sft_lora.py new file mode 100644 index 0000000000..2bac7792e6 --- /dev/null +++ b/tests/integration/test_sft_lora.py @@ -0,0 +1,129 @@ +from pathlib import Path +from typing import Callable + +import pytest + +from prime_rl.trainer.weights import load_state_dict +from tests.conftest import ProcessResult +from tests.utils import check_loss_goes_down, strip_escape_codes + +pytestmark = [pytest.mark.slow, pytest.mark.gpu] + +TIMEOUT = 300 # 5 minutes + + +def assert_adapter_checkpoint(adapter_dir: Path) -> None: + assert (adapter_dir / "adapter_config.json").exists() + state_dict = load_state_dict(adapter_dir) + assert state_dict + assert all(".0.weight" not in key for key in state_dict) + assert any(key.endswith("lora_A.weight") for key in state_dict) + assert all(key.startswith("base_model.model.") for key in state_dict) + + +@pytest.fixture(scope="module") +def wandb_name(branch_name: str) -> str: + """Fixture for W&B name for SFT LoRA CI integration tests.""" + return f"test-sft-lora-{branch_name}" + + +@pytest.fixture(scope="module") +def sft_lora_process( + run_process: Callable[..., ProcessResult], + wandb_project: str, + wandb_name: str, + output_dir: Path, +) -> ProcessResult: + """Fixture for running SFT LoRA CI integration test""" + cmd = [ + "uv", + "run", + "torchrun", + "--local-ranks-filter", + "0", + "--nproc-per-node", + "2", + "src/prime_rl/trainer/sft/train.py", + "@", + "configs/ci/integration/sft_lora/start.toml", + "--wandb.project", + wandb_project, + "--wandb.name", + wandb_name, + "--output-dir", + output_dir.as_posix(), + ] + + return run_process(cmd, timeout=TIMEOUT) + + +@pytest.fixture(scope="module") +def sft_lora_resume_process( + sft_lora_process, + run_process: Callable[..., ProcessResult], + wandb_project: str, + wandb_name: str, + output_dir: Path, +) -> ProcessResult: + """Fixture for resuming SFT LoRA CI integration test""" + wandb_name += "-resume" + cmd = [ + "uv", + "run", + "torchrun", + "--local-ranks-filter", + "0", + "--nproc-per-node", + "2", + "src/prime_rl/trainer/sft/train.py", + "@", + "configs/ci/integration/sft_lora/resume.toml", + "--wandb.project", + wandb_project, + "--wandb.name", + wandb_name, + "--output-dir", + output_dir.as_posix(), + ] + + return run_process(cmd, timeout=TIMEOUT) + + +def test_no_error(sft_lora_process: ProcessResult): + """Tests that the SFT LoRA process does not fail.""" + assert sft_lora_process.returncode == 0, f"Process has non-zero return code ({sft_lora_process})" + + +def test_loss_goes_down(sft_lora_process: ProcessResult, output_dir: Path): + """Tests that the loss goes down in the SFT LoRA process""" + trainer_log_path = output_dir / "logs" / "trainer" / "rank_0.log" + print(f"Checking trainer path in {trainer_log_path}") + with open(trainer_log_path, "r") as f: + trainer_stdout = strip_escape_codes(f.read()).splitlines() + check_loss_goes_down(trainer_stdout) + + +def test_adapter_checkpoint_written(sft_lora_process: ProcessResult, output_dir: Path): + """Tests that the adapter checkpoint is written with valid PEFT-compatible keys.""" + adapter_dir = output_dir / "weights" / "step_10" / "lora_adapters" + assert_adapter_checkpoint(adapter_dir) + + +def test_no_error_resume(sft_lora_resume_process: ProcessResult): + """Tests that the SFT LoRA resume process does not fail.""" + assert sft_lora_resume_process.returncode == 0, f"Process has non-zero return code ({sft_lora_resume_process})" + + +def test_loss_goes_down_resume(sft_lora_resume_process: ProcessResult, output_dir: Path): + """Tests that the loss goes down in the SFT LoRA resume process""" + trainer_log_path = output_dir / "logs" / "trainer" / "rank_0.log" + print(f"Checking trainer path in {trainer_log_path}") + with open(trainer_log_path, "r") as f: + trainer_stdout = strip_escape_codes(f.read()).splitlines() + check_loss_goes_down(trainer_stdout) + + +def test_adapter_checkpoint_written_resume(sft_lora_resume_process: ProcessResult, output_dir: Path): + """Tests that the adapter checkpoint is written after resuming with valid PEFT-compatible keys.""" + adapter_dir = output_dir / "weights" / "step_20" / "lora_adapters" + assert_adapter_checkpoint(adapter_dir) From 833afe9430af7a622474bbe060a5acf10035ecde Mon Sep 17 00:00:00 2001 From: Philipp Normann Date: Fri, 20 Mar 2026 02:35:30 +0100 Subject: [PATCH 2/4] Fix ruff formatting and add missing comment --- src/prime_rl/trainer/ckpt.py | 14 +++++++------- tests/integration/test_sft_lora.py | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/prime_rl/trainer/ckpt.py b/src/prime_rl/trainer/ckpt.py index 8ee83fb487..d863ebcbc6 100644 --- a/src/prime_rl/trainer/ckpt.py +++ b/src/prime_rl/trainer/ckpt.py @@ -318,9 +318,9 @@ def save_to_path( # that must be called by every rank simultaneously under FSDP. if self.config.save_adapter_separately and lora_state_dict is not None: lora_run_state = { - f"base_model.model.{key}": ( - value.full_tensor() if isinstance(value, DTensor) else value - ).to("cpu", non_blocking=False) + f"base_model.model.{key}": (value.full_tensor() if isinstance(value, DTensor) else value).to( + "cpu", non_blocking=False + ) for key, value in get_multi_run_manager().get_state_dict_for_run(0).items() } else: @@ -356,12 +356,12 @@ def save_to_path( adapter_path = path / "lora_adapters" adapter_path.mkdir(parents=True, exist_ok=True) lora_state_dict = { - key: value - for key, value in lora_state_dict.items() - if "lora_A" not in key and "lora_B" not in key + key: value for key, value in lora_state_dict.items() if "lora_A" not in key and "lora_B" not in key } lora_state_dict.update(lora_run_state) - save_state_dict(lora_state_dict, adapter_path, self.config.save_format, save_sharded=False, adapter=True) + save_state_dict( + lora_state_dict, adapter_path, self.config.save_format, save_sharded=False, adapter=True + ) if self.lora_config: save_lora_config( model, diff --git a/tests/integration/test_sft_lora.py b/tests/integration/test_sft_lora.py index 2bac7792e6..8f0a938daa 100644 --- a/tests/integration/test_sft_lora.py +++ b/tests/integration/test_sft_lora.py @@ -59,7 +59,7 @@ def sft_lora_process( @pytest.fixture(scope="module") def sft_lora_resume_process( - sft_lora_process, + sft_lora_process, # Resume training can only start when regular SFT LoRA process is finished run_process: Callable[..., ProcessResult], wandb_project: str, wandb_name: str, From 69dc853daabcc59d1e6c97b6ca93ef051fad9630 Mon Sep 17 00:00:00 2001 From: Philipp Normann Date: Sat, 21 Mar 2026 02:48:44 +0100 Subject: [PATCH 3/4] Use MultiRunManager for SFT LoRA adapter export --- src/prime_rl/trainer/ckpt.py | 36 ++++++++++++++++-------------------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/src/prime_rl/trainer/ckpt.py b/src/prime_rl/trainer/ckpt.py index d863ebcbc6..ba0ec6d928 100644 --- a/src/prime_rl/trainer/ckpt.py +++ b/src/prime_rl/trainer/ckpt.py @@ -26,7 +26,6 @@ from prime_rl.trainer.runs import Progress, get_multi_run_manager from prime_rl.trainer.weights import ( gather_weights_on_master, - get_adapter_state_dict, save_state_dict, ) from prime_rl.trainer.world import get_world @@ -305,6 +304,19 @@ def mark_stable(self, step: int) -> None: step_path = self.get_step_path(step) (step_path / "STABLE").touch() + def get_adapter_export_state_dict(self) -> dict[str, Tensor]: + lora_state_dict = { + f"base_model.model.{key}": (value.full_tensor() if isinstance(value, DTensor) else value).to( + "cpu", non_blocking=False + ) + for key, value in get_multi_run_manager().get_state_dict_for_run(0).items() + } + + if not lora_state_dict: + raise ValueError("The LoRA state dict is empty. Something went wrong.") + + return lora_state_dict + def save_to_path( self, path: Path, @@ -314,18 +326,6 @@ def save_to_path( tokenizer: PreTrainedTokenizer, ): """Save HF-compatible weight checkpoint to a given path.""" - # Gather LoRA run state on all ranks: full_tensor() is a collective - # that must be called by every rank simultaneously under FSDP. - if self.config.save_adapter_separately and lora_state_dict is not None: - lora_run_state = { - f"base_model.model.{key}": (value.full_tensor() if isinstance(value, DTensor) else value).to( - "cpu", non_blocking=False - ) - for key, value in get_multi_run_manager().get_state_dict_for_run(0).items() - } - else: - lora_run_state = None - if self.world.is_master: path.mkdir(parents=True, exist_ok=True) start_time = time.perf_counter() @@ -352,13 +352,9 @@ def save_to_path( gen_config.save_pretrained(path) tokenizer.save_pretrained(path) - if self.config.save_adapter_separately and lora_state_dict is not None: + if lora_state_dict is not None: adapter_path = path / "lora_adapters" adapter_path.mkdir(parents=True, exist_ok=True) - lora_state_dict = { - key: value for key, value in lora_state_dict.items() if "lora_A" not in key and "lora_B" not in key - } - lora_state_dict.update(lora_run_state) save_state_dict( lora_state_dict, adapter_path, self.config.save_format, save_sharded=False, adapter=True ) @@ -393,10 +389,10 @@ def save( for key in getattr(model, "_tied_weights_keys", []): state_dict.pop(key, None) - if has_lora_layers(model): + if has_lora_layers(model) and self.config.save_adapter_separately: self.logger.debug("Getting LoRA state dict on master rank for weight checkpoint") start_time = time.perf_counter() - lora_state_dict = get_adapter_state_dict(model, self.world.is_master) + lora_state_dict = self.get_adapter_export_state_dict() self.logger.debug(f"Got LoRA state dict on master rank in {time.perf_counter() - start_time:.2f} seconds") else: lora_state_dict = None From fa222265a3fcb82508fa4b10962b9693079dca4c Mon Sep 17 00:00:00 2001 From: Philipp Normann Date: Sat, 21 Mar 2026 03:28:42 +0100 Subject: [PATCH 4/4] Rename SFT LoRA run adapter export helper --- src/prime_rl/trainer/ckpt.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/prime_rl/trainer/ckpt.py b/src/prime_rl/trainer/ckpt.py index ba0ec6d928..f2416c29f2 100644 --- a/src/prime_rl/trainer/ckpt.py +++ b/src/prime_rl/trainer/ckpt.py @@ -304,7 +304,7 @@ def mark_stable(self, step: int) -> None: step_path = self.get_step_path(step) (step_path / "STABLE").touch() - def get_adapter_export_state_dict(self) -> dict[str, Tensor]: + def get_run_adapter_state_dict(self) -> dict[str, Tensor]: lora_state_dict = { f"base_model.model.{key}": (value.full_tensor() if isinstance(value, DTensor) else value).to( "cpu", non_blocking=False @@ -390,10 +390,10 @@ def save( state_dict.pop(key, None) if has_lora_layers(model) and self.config.save_adapter_separately: - self.logger.debug("Getting LoRA state dict on master rank for weight checkpoint") + self.logger.debug("Getting run adapter state dict for weight checkpoint") start_time = time.perf_counter() - lora_state_dict = self.get_adapter_export_state_dict() - self.logger.debug(f"Got LoRA state dict on master rank in {time.perf_counter() - start_time:.2f} seconds") + lora_state_dict = self.get_run_adapter_state_dict() + self.logger.debug(f"Got run adapter state dict in {time.perf_counter() - start_time:.2f} seconds") else: lora_state_dict = None