diff --git a/configs/ci/integration/sft_lora/resume.toml b/configs/ci/integration/sft_lora/resume.toml new file mode 100644 index 0000000000..bcc9145e98 --- /dev/null +++ b/configs/ci/integration/sft_lora/resume.toml @@ -0,0 +1,30 @@ +max_steps = 20 + +[ckpt] +resume_step = 10 + +[ckpt.weights] +save_adapter_separately = true + +[model] +name = "PrimeIntellect/Qwen3-0.6B" + +[model.lora] +rank = 8 +target_modules = [ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", +] + +[data] +name = "PrimeIntellect/Reverse-Text-SFT" +batch_size = 4 +seq_len = 1024 + +[optim] +lr = 1.5e-5 diff --git a/configs/ci/integration/sft_lora/start.toml b/configs/ci/integration/sft_lora/start.toml new file mode 100644 index 0000000000..0d775bdfa6 --- /dev/null +++ b/configs/ci/integration/sft_lora/start.toml @@ -0,0 +1,29 @@ +max_steps = 10 + +[ckpt] + +[ckpt.weights] +save_adapter_separately = true + +[model] +name = "PrimeIntellect/Qwen3-0.6B" + +[model.lora] +rank = 8 +target_modules = [ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", +] + +[data] +name = "PrimeIntellect/Reverse-Text-SFT" +batch_size = 4 +seq_len = 1024 + +[optim] +lr = 1.5e-5 diff --git a/src/prime_rl/trainer/ckpt.py b/src/prime_rl/trainer/ckpt.py index d9fca0fa5e..f2416c29f2 100644 --- a/src/prime_rl/trainer/ckpt.py +++ b/src/prime_rl/trainer/ckpt.py @@ -12,6 +12,7 @@ from torch.distributed.checkpoint.state_dict_loader import load as dcp_load from torch.distributed.checkpoint.state_dict_saver import save as dcp_save from torch.distributed.checkpoint.stateful import Stateful +from torch.distributed.tensor import DTensor from torch.nn import Module from torch.optim.lr_scheduler import LRScheduler from torch.optim.optimizer import Optimizer @@ -22,10 +23,9 @@ from prime_rl.trainer.lora import has_lora_layers, save_lora_config from prime_rl.trainer.models import PreTrainedModelPrimeRL from prime_rl.trainer.optim import CPUOffloadOptimizer -from prime_rl.trainer.runs import Progress +from prime_rl.trainer.runs import Progress, get_multi_run_manager from prime_rl.trainer.weights import ( gather_weights_on_master, - get_adapter_state_dict, save_state_dict, ) from prime_rl.trainer.world import get_world @@ -304,6 +304,19 @@ def mark_stable(self, step: int) -> None: step_path = self.get_step_path(step) (step_path / "STABLE").touch() + def get_run_adapter_state_dict(self) -> dict[str, Tensor]: + lora_state_dict = { + f"base_model.model.{key}": (value.full_tensor() if isinstance(value, DTensor) else value).to( + "cpu", non_blocking=False + ) + for key, value in get_multi_run_manager().get_state_dict_for_run(0).items() + } + + if not lora_state_dict: + raise ValueError("The LoRA state dict is empty. Something went wrong.") + + return lora_state_dict + def save_to_path( self, path: Path, @@ -339,10 +352,12 @@ def save_to_path( gen_config.save_pretrained(path) tokenizer.save_pretrained(path) - if self.config.save_adapter_separately and lora_state_dict is not None: + if lora_state_dict is not None: adapter_path = path / "lora_adapters" adapter_path.mkdir(parents=True, exist_ok=True) - torch.save(lora_state_dict, adapter_path / "adapter_model.bin") + save_state_dict( + lora_state_dict, adapter_path, self.config.save_format, save_sharded=False, adapter=True + ) if self.lora_config: save_lora_config( model, @@ -374,11 +389,11 @@ def save( for key in getattr(model, "_tied_weights_keys", []): state_dict.pop(key, None) - if has_lora_layers(model): - self.logger.debug("Getting LoRA state dict on master rank for weight checkpoint") + if has_lora_layers(model) and self.config.save_adapter_separately: + self.logger.debug("Getting run adapter state dict for weight checkpoint") start_time = time.perf_counter() - lora_state_dict = get_adapter_state_dict(model, self.world.is_master) - self.logger.debug(f"Got LoRA state dict on master rank in {time.perf_counter() - start_time:.2f} seconds") + lora_state_dict = self.get_run_adapter_state_dict() + self.logger.debug(f"Got run adapter state dict in {time.perf_counter() - start_time:.2f} seconds") else: lora_state_dict = None diff --git a/src/prime_rl/trainer/sft/train.py b/src/prime_rl/trainer/sft/train.py index 089b78703d..29fb79494e 100644 --- a/src/prime_rl/trainer/sft/train.py +++ b/src/prime_rl/trainer/sft/train.py @@ -16,7 +16,8 @@ from prime_rl.utils.pathing import resolve_latest_ckpt_step from prime_rl.configs.sft import SFTConfig from prime_rl.utils.cp import setup_cp_params, shard_for_cp -from prime_rl.trainer.runs import Progress +from prime_rl.trainer.runs import Progress, get_multi_run_manager, setup_multi_run_manager +from prime_rl.trainer.models.layers.lora import set_lora_num_tokens from prime_rl.utils.logger import setup_logger from prime_rl.trainer.optim import setup_optimizer from prime_rl.trainer.scheduler import setup_scheduler @@ -82,6 +83,9 @@ def train(config: SFTConfig): ) torch.set_float32_matmul_precision("high") + if config.model.lora is not None: + setup_multi_run_manager(config.output_dir, 1, torch.device("cuda", world.local_rank), config.model.lora) + # Initialize parallel dimensions parallel_dims = get_parallel_dims(config.model, config.data.seq_len) @@ -120,6 +124,11 @@ def train(config: SFTConfig): config.model, parallel_dims, loading_from_ckpt_later, fused_cross_entropy=config.loss_impl == "liger_fused" ) + if config.model.lora is not None: + multi_run_manager = get_multi_run_manager() + multi_run_manager.reset_run_parameters(0) + multi_run_manager.scaling_factors[0] = config.model.lora.alpha / config.model.lora.rank + logger.info(f"Initializing tokenizer ({config.tokenizer})") tokenizer = setup_tokenizer(config.tokenizer) @@ -199,6 +208,9 @@ def compute_loss(micro_batch: dict) -> tuple[torch.Tensor, torch.Tensor]: target_ids = shard_for_cp(target_ids, cp_rank=cp_rank, cp_world_size=cp_size) loss_mask = shard_for_cp(loss_mask, cp_rank=cp_rank, cp_world_size=cp_size) + if config.model.lora is not None: + set_lora_num_tokens(torch.full((1,), input_ids.numel(), dtype=torch.int32, device="cuda")) + token_count = loss_mask.sum(dtype=torch.int64) with maybe_activation_offloading(config.model.ac_offloading): diff --git a/tests/integration/test_sft_lora.py b/tests/integration/test_sft_lora.py new file mode 100644 index 0000000000..8f0a938daa --- /dev/null +++ b/tests/integration/test_sft_lora.py @@ -0,0 +1,129 @@ +from pathlib import Path +from typing import Callable + +import pytest + +from prime_rl.trainer.weights import load_state_dict +from tests.conftest import ProcessResult +from tests.utils import check_loss_goes_down, strip_escape_codes + +pytestmark = [pytest.mark.slow, pytest.mark.gpu] + +TIMEOUT = 300 # 5 minutes + + +def assert_adapter_checkpoint(adapter_dir: Path) -> None: + assert (adapter_dir / "adapter_config.json").exists() + state_dict = load_state_dict(adapter_dir) + assert state_dict + assert all(".0.weight" not in key for key in state_dict) + assert any(key.endswith("lora_A.weight") for key in state_dict) + assert all(key.startswith("base_model.model.") for key in state_dict) + + +@pytest.fixture(scope="module") +def wandb_name(branch_name: str) -> str: + """Fixture for W&B name for SFT LoRA CI integration tests.""" + return f"test-sft-lora-{branch_name}" + + +@pytest.fixture(scope="module") +def sft_lora_process( + run_process: Callable[..., ProcessResult], + wandb_project: str, + wandb_name: str, + output_dir: Path, +) -> ProcessResult: + """Fixture for running SFT LoRA CI integration test""" + cmd = [ + "uv", + "run", + "torchrun", + "--local-ranks-filter", + "0", + "--nproc-per-node", + "2", + "src/prime_rl/trainer/sft/train.py", + "@", + "configs/ci/integration/sft_lora/start.toml", + "--wandb.project", + wandb_project, + "--wandb.name", + wandb_name, + "--output-dir", + output_dir.as_posix(), + ] + + return run_process(cmd, timeout=TIMEOUT) + + +@pytest.fixture(scope="module") +def sft_lora_resume_process( + sft_lora_process, # Resume training can only start when regular SFT LoRA process is finished + run_process: Callable[..., ProcessResult], + wandb_project: str, + wandb_name: str, + output_dir: Path, +) -> ProcessResult: + """Fixture for resuming SFT LoRA CI integration test""" + wandb_name += "-resume" + cmd = [ + "uv", + "run", + "torchrun", + "--local-ranks-filter", + "0", + "--nproc-per-node", + "2", + "src/prime_rl/trainer/sft/train.py", + "@", + "configs/ci/integration/sft_lora/resume.toml", + "--wandb.project", + wandb_project, + "--wandb.name", + wandb_name, + "--output-dir", + output_dir.as_posix(), + ] + + return run_process(cmd, timeout=TIMEOUT) + + +def test_no_error(sft_lora_process: ProcessResult): + """Tests that the SFT LoRA process does not fail.""" + assert sft_lora_process.returncode == 0, f"Process has non-zero return code ({sft_lora_process})" + + +def test_loss_goes_down(sft_lora_process: ProcessResult, output_dir: Path): + """Tests that the loss goes down in the SFT LoRA process""" + trainer_log_path = output_dir / "logs" / "trainer" / "rank_0.log" + print(f"Checking trainer path in {trainer_log_path}") + with open(trainer_log_path, "r") as f: + trainer_stdout = strip_escape_codes(f.read()).splitlines() + check_loss_goes_down(trainer_stdout) + + +def test_adapter_checkpoint_written(sft_lora_process: ProcessResult, output_dir: Path): + """Tests that the adapter checkpoint is written with valid PEFT-compatible keys.""" + adapter_dir = output_dir / "weights" / "step_10" / "lora_adapters" + assert_adapter_checkpoint(adapter_dir) + + +def test_no_error_resume(sft_lora_resume_process: ProcessResult): + """Tests that the SFT LoRA resume process does not fail.""" + assert sft_lora_resume_process.returncode == 0, f"Process has non-zero return code ({sft_lora_resume_process})" + + +def test_loss_goes_down_resume(sft_lora_resume_process: ProcessResult, output_dir: Path): + """Tests that the loss goes down in the SFT LoRA resume process""" + trainer_log_path = output_dir / "logs" / "trainer" / "rank_0.log" + print(f"Checking trainer path in {trainer_log_path}") + with open(trainer_log_path, "r") as f: + trainer_stdout = strip_escape_codes(f.read()).splitlines() + check_loss_goes_down(trainer_stdout) + + +def test_adapter_checkpoint_written_resume(sft_lora_resume_process: ProcessResult, output_dir: Path): + """Tests that the adapter checkpoint is written after resuming with valid PEFT-compatible keys.""" + adapter_dir = output_dir / "weights" / "step_20" / "lora_adapters" + assert_adapter_checkpoint(adapter_dir)