From 3cecc7d9bd59582544923ebb80bab29f58cd6d19 Mon Sep 17 00:00:00 2001 From: Philipp Normann Date: Sun, 22 Feb 2026 22:09:08 +0100 Subject: [PATCH 01/19] Add SFT validation eval with val_data --- src/prime_rl/configs/sft.py | 17 +++++ src/prime_rl/trainer/sft/train.py | 75 ++++++++++++++++++++ tests/unit/train/sft/test_sft_eval_config.py | 22 ++++++ 3 files changed, 114 insertions(+) create mode 100644 tests/unit/train/sft/test_sft_eval_config.py diff --git a/src/prime_rl/configs/sft.py b/src/prime_rl/configs/sft.py index 82479e9cbf..0870f67f44 100644 --- a/src/prime_rl/configs/sft.py +++ b/src/prime_rl/configs/sft.py @@ -106,6 +106,11 @@ def validate_subsets_and_splits(self): return self +class SFTEvalConfig(BaseConfig): + interval: Annotated[int, Field(ge=1, description="Run validation every N training steps.")] = 50 + num_batches: Annotated[int, Field(ge=1, description="Number of validation batches per evaluation.")] = 8 + + DataConfig: TypeAlias = Annotated[FakeDataConfig | SFTDataConfig, Field(discriminator="type")] @@ -173,6 +178,12 @@ class SFTConfig(BaseSettings): # The data configuration data: DataConfig = SFTDataConfig() + # Optional validation data configuration + val_data: SFTDataConfig | None = None + + # Optional validation evaluation configuration + eval: SFTEvalConfig | None = None + # The optimizer configuration optim: OptimizerConfig = AdamWConfig() @@ -312,6 +323,12 @@ def validate_lora_adapter_saving(self): ) return self + @model_validator(mode="after") + def validate_eval_and_val_data(self): + if (self.eval is None) != (self.val_data is None): + raise ValueError("SFT validation requires both eval and val_data to be set") + return self + @model_validator(mode="after") def validate_opt_and_fsdp_offload(self): if self.optim.type == "muon" and self.model.fsdp_cpu_offload: diff --git a/src/prime_rl/trainer/sft/train.py b/src/prime_rl/trainer/sft/train.py index 13780a5df6..43047cf294 100644 --- a/src/prime_rl/trainer/sft/train.py +++ b/src/prime_rl/trainer/sft/train.py @@ -141,6 +141,13 @@ def train(config: SFTConfig): dataloader = setup_dataloader(dataset, config.data) dataiter = iter(dataloader) + val_dataiter = None + if config.eval is not None and config.val_data is not None: + logger.info(f"Initializing validation data ({config.val_data})") + val_dataset = setup_dataset(tokenizer, config.val_data, config.model.cp * config.model.tp) + val_dataloader = setup_dataloader(val_dataset, config.val_data) + val_dataiter = iter(val_dataloader) + # Optionally, resume training from a checkpoint progress = Progress() @@ -174,6 +181,71 @@ def train(config: SFTConfig): case _: raise ValueError(f"Invalid loss implementation: {config.loss_impl}") + def run_validation(eval_step: int) -> None: + assert config.eval is not None + assert val_dataiter is not None + + was_training = model.training + model.eval() + + val_loss_sum = torch.tensor(0.0, device="cuda") + val_loss_count = torch.tensor(0.0, device="cuda") + + with torch.no_grad(): + for _ in range(config.eval.num_batches): + micro_batch = next(val_dataiter) + input_ids = micro_batch["input_ids"].to("cuda") + position_ids = micro_batch["position_ids"].to("cuda") + target_ids = micro_batch["target_ids"].to("cuda") + loss_mask = micro_batch["loss_mask"].to("cuda") + + if cp_enabled: + input_ids, position_ids = setup_cp_params(input_ids, position_ids, cp_rank, cp_size, cp_group) + target_ids = shard_for_cp(target_ids, cp_rank=cp_rank, cp_world_size=cp_size) + loss_mask = shard_for_cp(loss_mask, cp_rank=cp_rank, cp_world_size=cp_size) + + if config.model.lora is not None: + lora_num_tokens = torch.full((1,), input_ids.numel(), dtype=torch.int32, device="cuda") + set_lora_num_tokens(lora_num_tokens) + + out = forward(model, input_ids, position_ids) + logits = out["logits"] + B, L, V = logits.shape + loss = ce_loss(logits.view(-1, V), target_ids.view(-1)).view(B, L) + loss = loss[loss_mask].mean() + + if not torch.isnan(loss): + val_loss_sum += loss.detach() + val_loss_count += 1 + + del logits + + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_loss_count, op=dist.ReduceOp.SUM) + + if val_loss_count.item() == 0: + logger.warning(f"Validation at step {eval_step} had only NaN losses") + val_metrics = { + "val/loss": float("nan"), + "val/num_batches": 0.0, + "step": eval_step, + } + else: + mean_val_loss = (val_loss_sum / val_loss_count).item() + logger.success( + f"Validation | Step {eval_step} | Loss: {mean_val_loss:.4f} | Batches: {int(val_loss_count.item())}" + ) + val_metrics = { + "val/loss": mean_val_loss, + "val/num_batches": val_loss_count.item(), + "step": eval_step, + } + + monitor.log(val_metrics, step=eval_step) + + if was_training: + model.train() + logger.info(f"Starting training loop (max_steps={config.max_steps or 'infinite'})") max_memory = torch.cuda.mem_get_info()[1] / 1024**3 # GiB is_first_step = True @@ -405,6 +477,9 @@ def train(config: SFTConfig): } monitor.log(max_vio_log_metrics, step=progress.step) + if config.eval is not None and val_dataiter is not None and (progress.step + 1) % config.eval.interval == 0: + run_validation(progress.step) + is_first_step = False progress.step += 1 diff --git a/tests/unit/train/sft/test_sft_eval_config.py b/tests/unit/train/sft/test_sft_eval_config.py new file mode 100644 index 0000000000..9d5bfc645d --- /dev/null +++ b/tests/unit/train/sft/test_sft_eval_config.py @@ -0,0 +1,22 @@ +import pytest + +from prime_rl.trainer.sft.config import SFTDataConfig, SFTEvalConfig, SFTTrainerConfig + + +def test_sft_eval_requires_val_data(): + with pytest.raises(ValueError, match="both eval and val_data"): + SFTTrainerConfig(eval=SFTEvalConfig(interval=10, num_batches=2)) + + +def test_sft_val_data_requires_eval(): + with pytest.raises(ValueError, match="both eval and val_data"): + SFTTrainerConfig(val_data=SFTDataConfig()) + + +def test_sft_eval_with_val_data_is_valid(): + config = SFTTrainerConfig( + eval=SFTEvalConfig(interval=10, num_batches=2), + val_data=SFTDataConfig(name="willcb/R1-reverse-wikipedia-paragraphs-v1-1000", splits=["train[:5%]"]), + ) + assert config.eval is not None + assert config.val_data is not None From 03343f2ac4dd0d05df48159433e89e92428d604a Mon Sep 17 00:00:00 2001 From: Philipp Normann Date: Mon, 23 Feb 2026 09:51:49 +0100 Subject: [PATCH 02/19] Fix SFT validation config safety and step alignment Apply CP compatibility checks to val_data, align eval scheduling with checkpoint step numbering, and document new SFT eval config fields in the changelog. --- src/prime_rl/configs/sft.py | 28 +++++++++++++------- src/prime_rl/trainer/sft/train.py | 7 ++++- tests/unit/train/sft/test_sft_eval_config.py | 28 ++++++++++++++++++++ 3 files changed, 53 insertions(+), 10 deletions(-) diff --git a/src/prime_rl/configs/sft.py b/src/prime_rl/configs/sft.py index 0870f67f44..b83bb809a8 100644 --- a/src/prime_rl/configs/sft.py +++ b/src/prime_rl/configs/sft.py @@ -278,27 +278,37 @@ def validate_slurm_output_dir(self): @model_validator(mode="after") def validate_pack_function(self): - if self.model.cp > 1 and self.data.pack_function != "cat": - raise ValueError("Packing function must be 'cat' when CP is enabled") + if self.model.cp > 1: + if self.data.pack_function != "cat": + raise ValueError("Packing function must be 'cat' when CP is enabled") + if self.val_data is not None and self.val_data.pack_function != "cat": + raise ValueError("Validation packing function must be 'cat' when CP is enabled") return self @model_validator(mode="after") def validate_cp_seq_len(self): - if self.model.cp > 1 and self.data.seq_len % self.model.cp != 0: - raise ValueError("Sequence length must be divisible by CP degree") + if self.model.cp > 1: + if self.data.seq_len % self.model.cp != 0: + raise ValueError("Sequence length must be divisible by CP degree") + if self.val_data is not None and self.val_data.seq_len % self.model.cp != 0: + raise ValueError("Validation sequence length must be divisible by CP degree") return self @model_validator(mode="after") def validate_cp_micro_batch_size(self): - if self.model.cp > 1 and self.data.micro_batch_size != 1: - raise ValueError("Micro batch size must be 1 when CP is enabled") + if self.model.cp > 1: + if self.data.micro_batch_size != 1: + raise ValueError("Micro batch size must be 1 when CP is enabled") + if self.val_data is not None and self.val_data.micro_batch_size != 1: + raise ValueError("Validation micro batch size must be 1 when CP is enabled") return self @model_validator(mode="after") def validate_seq_len(self): - if self.data.pack_function == "stack": - if self.data.seq_len % 256 != 0: - raise ValueError("The sequence length must be divisible by 256 when using pack function stack") + if self.data.pack_function == "stack" and self.data.seq_len % 256 != 0: + raise ValueError("The sequence length must be divisible by 256 when using pack function stack") + if self.val_data is not None and self.val_data.pack_function == "stack" and self.val_data.seq_len % 256 != 0: + raise ValueError("The validation sequence length must be divisible by 256 when using pack function stack") return self @model_validator(mode="after") diff --git a/src/prime_rl/trainer/sft/train.py b/src/prime_rl/trainer/sft/train.py index 43047cf294..77d27da816 100644 --- a/src/prime_rl/trainer/sft/train.py +++ b/src/prime_rl/trainer/sft/train.py @@ -477,7 +477,12 @@ def run_validation(eval_step: int) -> None: } monitor.log(max_vio_log_metrics, step=progress.step) - if config.eval is not None and val_dataiter is not None and (progress.step + 1) % config.eval.interval == 0: + if ( + config.eval is not None + and val_dataiter is not None + and not is_first_step + and progress.step % config.eval.interval == 0 + ): run_validation(progress.step) is_first_step = False diff --git a/tests/unit/train/sft/test_sft_eval_config.py b/tests/unit/train/sft/test_sft_eval_config.py index 9d5bfc645d..9a266bc31e 100644 --- a/tests/unit/train/sft/test_sft_eval_config.py +++ b/tests/unit/train/sft/test_sft_eval_config.py @@ -1,5 +1,6 @@ import pytest +from prime_rl.trainer.config import ModelConfig from prime_rl.trainer.sft.config import SFTDataConfig, SFTEvalConfig, SFTTrainerConfig @@ -20,3 +21,30 @@ def test_sft_eval_with_val_data_is_valid(): ) assert config.eval is not None assert config.val_data is not None + + +def test_sft_val_data_requires_cp_compatible_pack_function(): + with pytest.raises(ValueError, match="Validation packing function must be 'cat' when CP is enabled"): + SFTTrainerConfig( + model=ModelConfig(cp=2), + eval=SFTEvalConfig(interval=10, num_batches=2), + val_data=SFTDataConfig(pack_function="stack", seq_len=256), + ) + + +def test_sft_val_data_requires_cp_compatible_seq_len(): + with pytest.raises(ValueError, match="Validation sequence length must be divisible by CP degree"): + SFTTrainerConfig( + model=ModelConfig(cp=2), + eval=SFTEvalConfig(interval=10, num_batches=2), + val_data=SFTDataConfig(seq_len=127), + ) + + +def test_sft_val_data_requires_cp_compatible_micro_batch_size(): + with pytest.raises(ValueError, match="Validation micro batch size must be 1 when CP is enabled"): + SFTTrainerConfig( + model=ModelConfig(cp=2), + eval=SFTEvalConfig(interval=10, num_batches=2), + val_data=SFTDataConfig(micro_batch_size=2), + ) From fbd6f9044488850537473e05b19a7ab678a1f6d2 Mon Sep 17 00:00:00 2001 From: Philipp Normann Date: Mon, 23 Feb 2026 09:51:49 +0100 Subject: [PATCH 03/19] Add optional eval-on-start for SFT validation Add SFTEvalConfig.eval_on_start to support an explicit pre-training validation pass while keeping interval-based eval semantics unchanged by default. --- src/prime_rl/configs/sft.py | 1 + src/prime_rl/trainer/sft/train.py | 3 +++ tests/unit/train/sft/test_sft_eval_config.py | 3 ++- 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/prime_rl/configs/sft.py b/src/prime_rl/configs/sft.py index b83bb809a8..10c188d98d 100644 --- a/src/prime_rl/configs/sft.py +++ b/src/prime_rl/configs/sft.py @@ -109,6 +109,7 @@ def validate_subsets_and_splits(self): class SFTEvalConfig(BaseConfig): interval: Annotated[int, Field(ge=1, description="Run validation every N training steps.")] = 50 num_batches: Annotated[int, Field(ge=1, description="Number of validation batches per evaluation.")] = 8 + eval_on_start: Annotated[bool, Field(description="Run a validation pass before the training loop starts.")] = False DataConfig: TypeAlias = Annotated[FakeDataConfig | SFTDataConfig, Field(discriminator="type")] diff --git a/src/prime_rl/trainer/sft/train.py b/src/prime_rl/trainer/sft/train.py index 77d27da816..22d3668b7f 100644 --- a/src/prime_rl/trainer/sft/train.py +++ b/src/prime_rl/trainer/sft/train.py @@ -246,6 +246,9 @@ def run_validation(eval_step: int) -> None: if was_training: model.train() + if config.eval is not None and val_dataiter is not None and config.eval.eval_on_start: + run_validation(progress.step) + logger.info(f"Starting training loop (max_steps={config.max_steps or 'infinite'})") max_memory = torch.cuda.mem_get_info()[1] / 1024**3 # GiB is_first_step = True diff --git a/tests/unit/train/sft/test_sft_eval_config.py b/tests/unit/train/sft/test_sft_eval_config.py index 9a266bc31e..860144afcd 100644 --- a/tests/unit/train/sft/test_sft_eval_config.py +++ b/tests/unit/train/sft/test_sft_eval_config.py @@ -16,11 +16,12 @@ def test_sft_val_data_requires_eval(): def test_sft_eval_with_val_data_is_valid(): config = SFTTrainerConfig( - eval=SFTEvalConfig(interval=10, num_batches=2), + eval=SFTEvalConfig(interval=10, num_batches=2, eval_on_start=True), val_data=SFTDataConfig(name="willcb/R1-reverse-wikipedia-paragraphs-v1-1000", splits=["train[:5%]"]), ) assert config.eval is not None assert config.val_data is not None + assert config.eval.eval_on_start is True def test_sft_val_data_requires_cp_compatible_pack_function(): From f26ccbba44b5d517c573dc842457432b26360df3 Mon Sep 17 00:00:00 2001 From: Philipp Normann Date: Thu, 26 Feb 2026 11:58:16 +0100 Subject: [PATCH 04/19] Fix SFT validation LoRA import and changelog entry --- CHANGELOG.md | 1 + src/prime_rl/trainer/sft/train.py | 1 + 2 files changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 944aaefdb2..805c681ed0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -82,4 +82,5 @@ Documenting changes which affect configuration usage patterns (added/moved/remov - **`clean`**: Removed from `RLConfig`. The old `clean` flag (default: `True`) silently deleted logs, rollouts, and broadcasts on every local RL run. Superseded by the explicit `clean_output_dir` flag (2026-02-24) - **Config consolidation**: All config modules moved into `prime_rl.configs` subpackage. `utils/config.py` + `transport/config.py` → `configs/shared.py`; `trainer/config.py` + `trainer/rl/config.py` → `configs/trainer.py`; `trainer/sft/config.py` → `configs/sft.py`; `orchestrator/config.py` → `configs/orchestrator.py`; `inference/config.py` → `configs/inference.py`; `rl_config.py` → `configs/rl.py`. Class renames: `SFTTrainerConfig` → `SFTConfig`, `RLTrainerConfig` → `TrainerConfig`. Component prefixes dropped from orchestrator and inference config classes (e.g. `OrchestratorCheckpointConfig` → `CheckpointConfig`). TypeAlias renames: dropped `Type` suffix (e.g. `LossConfigType` → `LossConfig`, `TransportConfigType` → `TransportConfig`), renamed `LossConfig` class → `DefaultLossConfig`. No TOML key changes. (2026-02-24) - **`orchestrator.oversampling_factor`**: Added rollout-only over-sampling config that resolves `max_inflight_rollouts = int(batch_size * oversampling_factor)` when `max_inflight_rollouts` is unset. Cannot be used with `token_batch_size` or together with explicit `max_inflight_rollouts` (2026-02-25) +- **`sft.val_data`** and **`sft.eval`**: Added optional periodic SFT validation with `val/loss` and `val/num_batches` logging. `sft.eval` and `sft.val_data` must be configured together. Added `sft.eval.eval_on_start` (default: `False`) to optionally run validation before training starts. (2026-02-26) - **`model.fused_lm_head_chunk_size`**: Changed default value from 2048 to 8192 for RL training (2026-02-26) diff --git a/src/prime_rl/trainer/sft/train.py b/src/prime_rl/trainer/sft/train.py index 22d3668b7f..6967861f9d 100644 --- a/src/prime_rl/trainer/sft/train.py +++ b/src/prime_rl/trainer/sft/train.py @@ -29,6 +29,7 @@ ) from prime_rl.trainer.parallel_dims import get_parallel_dims from prime_rl.trainer.perf import get_perf_counter +from prime_rl.trainer.models.layers.lora import set_lora_num_tokens from prime_rl.trainer.sft.data import setup_dataloader, setup_dataset from prime_rl.trainer.utils import ( MemoryProfiler, From 3c19f26b09f7cddbf369a0b69f7e94a8adb2c56d Mon Sep 17 00:00:00 2001 From: Philipp Normann Date: Fri, 27 Feb 2026 18:22:24 +0100 Subject: [PATCH 05/19] Fix SFT eval test imports after config move --- tests/unit/train/sft/test_sft_eval_config.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/unit/train/sft/test_sft_eval_config.py b/tests/unit/train/sft/test_sft_eval_config.py index 860144afcd..a26c12ea73 100644 --- a/tests/unit/train/sft/test_sft_eval_config.py +++ b/tests/unit/train/sft/test_sft_eval_config.py @@ -1,21 +1,21 @@ import pytest -from prime_rl.trainer.config import ModelConfig -from prime_rl.trainer.sft.config import SFTDataConfig, SFTEvalConfig, SFTTrainerConfig +from prime_rl.configs.sft import SFTConfig, SFTDataConfig, SFTEvalConfig +from prime_rl.configs.trainer import ModelConfig def test_sft_eval_requires_val_data(): with pytest.raises(ValueError, match="both eval and val_data"): - SFTTrainerConfig(eval=SFTEvalConfig(interval=10, num_batches=2)) + SFTConfig(eval=SFTEvalConfig(interval=10, num_batches=2)) def test_sft_val_data_requires_eval(): with pytest.raises(ValueError, match="both eval and val_data"): - SFTTrainerConfig(val_data=SFTDataConfig()) + SFTConfig(val_data=SFTDataConfig()) def test_sft_eval_with_val_data_is_valid(): - config = SFTTrainerConfig( + config = SFTConfig( eval=SFTEvalConfig(interval=10, num_batches=2, eval_on_start=True), val_data=SFTDataConfig(name="willcb/R1-reverse-wikipedia-paragraphs-v1-1000", splits=["train[:5%]"]), ) @@ -26,7 +26,7 @@ def test_sft_eval_with_val_data_is_valid(): def test_sft_val_data_requires_cp_compatible_pack_function(): with pytest.raises(ValueError, match="Validation packing function must be 'cat' when CP is enabled"): - SFTTrainerConfig( + SFTConfig( model=ModelConfig(cp=2), eval=SFTEvalConfig(interval=10, num_batches=2), val_data=SFTDataConfig(pack_function="stack", seq_len=256), @@ -35,7 +35,7 @@ def test_sft_val_data_requires_cp_compatible_pack_function(): def test_sft_val_data_requires_cp_compatible_seq_len(): with pytest.raises(ValueError, match="Validation sequence length must be divisible by CP degree"): - SFTTrainerConfig( + SFTConfig( model=ModelConfig(cp=2), eval=SFTEvalConfig(interval=10, num_batches=2), val_data=SFTDataConfig(seq_len=127), @@ -44,7 +44,7 @@ def test_sft_val_data_requires_cp_compatible_seq_len(): def test_sft_val_data_requires_cp_compatible_micro_batch_size(): with pytest.raises(ValueError, match="Validation micro batch size must be 1 when CP is enabled"): - SFTTrainerConfig( + SFTConfig( model=ModelConfig(cp=2), eval=SFTEvalConfig(interval=10, num_batches=2), val_data=SFTDataConfig(micro_batch_size=2), From 3741a1f09913797c9c1b8ef4d9a173930e13695c Mon Sep 17 00:00:00 2001 From: hallerite Date: Fri, 6 Mar 2026 03:46:44 +0000 Subject: [PATCH 06/19] refactor SFT validation config into single SFTValConfig --- CHANGELOG.md | 2 +- src/prime_rl/configs/sft.py | 25 ++---- src/prime_rl/trainer/sft/data.py | 87 +++++++++++--------- src/prime_rl/trainer/sft/train.py | 81 +++++++++--------- tests/unit/train/sft/test_sft_eval_config.py | 45 +++++----- 5 files changed, 120 insertions(+), 120 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 805c681ed0..f195e8c43b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -82,5 +82,5 @@ Documenting changes which affect configuration usage patterns (added/moved/remov - **`clean`**: Removed from `RLConfig`. The old `clean` flag (default: `True`) silently deleted logs, rollouts, and broadcasts on every local RL run. Superseded by the explicit `clean_output_dir` flag (2026-02-24) - **Config consolidation**: All config modules moved into `prime_rl.configs` subpackage. `utils/config.py` + `transport/config.py` → `configs/shared.py`; `trainer/config.py` + `trainer/rl/config.py` → `configs/trainer.py`; `trainer/sft/config.py` → `configs/sft.py`; `orchestrator/config.py` → `configs/orchestrator.py`; `inference/config.py` → `configs/inference.py`; `rl_config.py` → `configs/rl.py`. Class renames: `SFTTrainerConfig` → `SFTConfig`, `RLTrainerConfig` → `TrainerConfig`. Component prefixes dropped from orchestrator and inference config classes (e.g. `OrchestratorCheckpointConfig` → `CheckpointConfig`). TypeAlias renames: dropped `Type` suffix (e.g. `LossConfigType` → `LossConfig`, `TransportConfigType` → `TransportConfig`), renamed `LossConfig` class → `DefaultLossConfig`. No TOML key changes. (2026-02-24) - **`orchestrator.oversampling_factor`**: Added rollout-only over-sampling config that resolves `max_inflight_rollouts = int(batch_size * oversampling_factor)` when `max_inflight_rollouts` is unset. Cannot be used with `token_batch_size` or together with explicit `max_inflight_rollouts` (2026-02-25) -- **`sft.val_data`** and **`sft.eval`**: Added optional periodic SFT validation with `val/loss` and `val/num_batches` logging. `sft.eval` and `sft.val_data` must be configured together. Added `sft.eval.eval_on_start` (default: `False`) to optionally run validation before training starts. (2026-02-26) +- **`sft.val`**: Added optional periodic SFT validation with `val/loss` and `val/num_batches` logging. Configure via `sft.val.data` (validation dataset) and `sft.val.interval` (every N steps, default 50). Runs the full validation dataset each pass. Added `sft.val.eval_on_start` (default: `False`) to optionally run validation before training starts. (2026-02-26) - **`model.fused_lm_head_chunk_size`**: Changed default value from 2048 to 8192 for RL training (2026-02-26) diff --git a/src/prime_rl/configs/sft.py b/src/prime_rl/configs/sft.py index 10c188d98d..72576e5487 100644 --- a/src/prime_rl/configs/sft.py +++ b/src/prime_rl/configs/sft.py @@ -106,10 +106,10 @@ def validate_subsets_and_splits(self): return self -class SFTEvalConfig(BaseConfig): +class SFTValConfig(BaseConfig): interval: Annotated[int, Field(ge=1, description="Run validation every N training steps.")] = 50 - num_batches: Annotated[int, Field(ge=1, description="Number of validation batches per evaluation.")] = 8 eval_on_start: Annotated[bool, Field(description="Run a validation pass before the training loop starts.")] = False + data: SFTDataConfig = SFTDataConfig() DataConfig: TypeAlias = Annotated[FakeDataConfig | SFTDataConfig, Field(discriminator="type")] @@ -179,11 +179,8 @@ class SFTConfig(BaseSettings): # The data configuration data: DataConfig = SFTDataConfig() - # Optional validation data configuration - val_data: SFTDataConfig | None = None - - # Optional validation evaluation configuration - eval: SFTEvalConfig | None = None + # Optional validation configuration + val: SFTValConfig | None = None # The optimizer configuration optim: OptimizerConfig = AdamWConfig() @@ -282,7 +279,7 @@ def validate_pack_function(self): if self.model.cp > 1: if self.data.pack_function != "cat": raise ValueError("Packing function must be 'cat' when CP is enabled") - if self.val_data is not None and self.val_data.pack_function != "cat": + if self.val is not None and self.val.data.pack_function != "cat": raise ValueError("Validation packing function must be 'cat' when CP is enabled") return self @@ -291,7 +288,7 @@ def validate_cp_seq_len(self): if self.model.cp > 1: if self.data.seq_len % self.model.cp != 0: raise ValueError("Sequence length must be divisible by CP degree") - if self.val_data is not None and self.val_data.seq_len % self.model.cp != 0: + if self.val is not None and self.val.data.seq_len % self.model.cp != 0: raise ValueError("Validation sequence length must be divisible by CP degree") return self @@ -300,7 +297,7 @@ def validate_cp_micro_batch_size(self): if self.model.cp > 1: if self.data.micro_batch_size != 1: raise ValueError("Micro batch size must be 1 when CP is enabled") - if self.val_data is not None and self.val_data.micro_batch_size != 1: + if self.val is not None and self.val.data.micro_batch_size != 1: raise ValueError("Validation micro batch size must be 1 when CP is enabled") return self @@ -308,7 +305,7 @@ def validate_cp_micro_batch_size(self): def validate_seq_len(self): if self.data.pack_function == "stack" and self.data.seq_len % 256 != 0: raise ValueError("The sequence length must be divisible by 256 when using pack function stack") - if self.val_data is not None and self.val_data.pack_function == "stack" and self.val_data.seq_len % 256 != 0: + if self.val is not None and self.val.data.pack_function == "stack" and self.val.data.seq_len % 256 != 0: raise ValueError("The validation sequence length must be divisible by 256 when using pack function stack") return self @@ -334,12 +331,6 @@ def validate_lora_adapter_saving(self): ) return self - @model_validator(mode="after") - def validate_eval_and_val_data(self): - if (self.eval is None) != (self.val_data is None): - raise ValueError("SFT validation requires both eval and val_data to be set") - return self - @model_validator(mode="after") def validate_opt_and_fsdp_offload(self): if self.optim.type == "muon" and self.model.fsdp_cpu_offload: diff --git a/src/prime_rl/trainer/sft/data.py b/src/prime_rl/trainer/sft/data.py index 3f511b0df0..8417792223 100644 --- a/src/prime_rl/trainer/sft/data.py +++ b/src/prime_rl/trainer/sft/data.py @@ -12,7 +12,7 @@ from torchdata.stateful_dataloader import StatefulDataLoader from transformers.tokenization_utils import PreTrainedTokenizer -from prime_rl.configs.sft import DataConfig, LossMaskConfig +from prime_rl.configs.sft import DataConfig, LossMaskConfig, SFTDataConfig from prime_rl.trainer.world import get_world from prime_rl.utils.logger import get_logger @@ -541,54 +541,67 @@ def setup_and_interleave_datasets( return dataset -def setup_dataset(tokenizer: PreTrainedTokenizer, config: DataConfig, non_dp_size: int = 1) -> StatefulIterableDataset: +def load_sft_dataset(config: SFTDataConfig) -> Dataset: + """Load and interleave the raw HF dataset. This is the expensive I/O step.""" + logger = get_logger() + if config.subsets is None and config.splits is None: + return setup_and_interleave_datasets( + dataset_name=config.name, + subsets_and_splits=[(None, "train")], + probabilities=config.probabilities, + stopping_strategy=config.stopping_strategy, + ) + elif config.subsets is not None and config.splits is None: + logger.debug(f"Loading datasets for subsets {config.subsets} with default split 'train'") + return setup_and_interleave_datasets( + dataset_name=config.name, + subsets_and_splits=[(subset, "train") for subset in config.subsets], + probabilities=config.probabilities, + stopping_strategy=config.stopping_strategy, + ) + elif config.subsets is None and config.splits is not None: + logger.debug(f"Loading datasets for splits {config.splits} with default subset 'None'") + return setup_and_interleave_datasets( + dataset_name=config.name, + subsets_and_splits=[(None, split) for split in config.splits], + probabilities=config.probabilities, + stopping_strategy=config.stopping_strategy, + ) + else: + assert config.subsets is not None and config.splits is not None + logger.debug(f"Loading datasets for subsets {config.subsets} with splits {config.splits}") + return setup_and_interleave_datasets( + dataset_name=config.name, + subsets_and_splits=list(zip(config.subsets, config.splits)), + probabilities=config.probabilities, + stopping_strategy=config.stopping_strategy, + ) + + +def setup_dataset( + tokenizer: PreTrainedTokenizer, + config: DataConfig, + non_dp_size: int = 1, + *, + max_epochs: int | None = None, + raw_dataset: Dataset | None = None, +) -> StatefulIterableDataset: if config.type == "fake": - # Shouldnt matter to handle non_dp_size if dataset is random return FakeDataset( vocab_size=tokenizer.vocab_size, seq_len=config.seq_len, length=config.length, input_ids=config.input_ids ) elif config.type == "sft": - logger = get_logger() - if config.subsets is None and config.splits is None: - dataset = setup_and_interleave_datasets( - dataset_name=config.name, - subsets_and_splits=[(None, "train")], - probabilities=config.probabilities, - stopping_strategy=config.stopping_strategy, - ) - elif config.subsets is not None and config.splits is None: - logger.debug(f"Loading datasets for subsets {config.subsets} with default split 'train'") - dataset = setup_and_interleave_datasets( - dataset_name=config.name, - subsets_and_splits=[(subset, "train") for subset in config.subsets], - probabilities=config.probabilities, - stopping_strategy=config.stopping_strategy, - ) - elif config.subsets is None and config.splits is not None: - logger.debug(f"Loading datasets for splits {config.splits} with default subset 'None'") - dataset = setup_and_interleave_datasets( - dataset_name=config.name, - subsets_and_splits=[(None, split) for split in config.splits], - probabilities=config.probabilities, - stopping_strategy=config.stopping_strategy, - ) - else: - assert config.subsets is not None and config.splits is not None - logger.debug(f"Loading datasets for subsets {config.subsets} with splits {config.splits}") - dataset = setup_and_interleave_datasets( - dataset_name=config.name, - subsets_and_splits=list(zip(config.subsets, config.splits)), - probabilities=config.probabilities, - stopping_strategy=config.stopping_strategy, - ) + if raw_dataset is None: + raw_dataset = load_sft_dataset(config) return SFTDataset( - dataset, + raw_dataset, tokenizer, shuffle=config.shuffle, seed=config.seed, seq_len=config.seq_len, loss_mask_config=config.loss_mask, non_dp_size=non_dp_size, + max_epochs=max_epochs, ) else: raise ValueError(f"Invalid dataset type: {config.type}") diff --git a/src/prime_rl/trainer/sft/train.py b/src/prime_rl/trainer/sft/train.py index 6967861f9d..5dc96d1dd8 100644 --- a/src/prime_rl/trainer/sft/train.py +++ b/src/prime_rl/trainer/sft/train.py @@ -30,7 +30,7 @@ from prime_rl.trainer.parallel_dims import get_parallel_dims from prime_rl.trainer.perf import get_perf_counter from prime_rl.trainer.models.layers.lora import set_lora_num_tokens -from prime_rl.trainer.sft.data import setup_dataloader, setup_dataset +from prime_rl.trainer.sft.data import load_sft_dataset, setup_dataloader, setup_dataset from prime_rl.trainer.utils import ( MemoryProfiler, export_benchmark_json, @@ -142,12 +142,10 @@ def train(config: SFTConfig): dataloader = setup_dataloader(dataset, config.data) dataiter = iter(dataloader) - val_dataiter = None - if config.eval is not None and config.val_data is not None: - logger.info(f"Initializing validation data ({config.val_data})") - val_dataset = setup_dataset(tokenizer, config.val_data, config.model.cp * config.model.tp) - val_dataloader = setup_dataloader(val_dataset, config.val_data) - val_dataiter = iter(val_dataloader) + val_raw_dataset = None + if config.val is not None: + logger.info(f"Loading validation dataset ({config.val.data.name})") + val_raw_dataset = load_sft_dataset(config.val.data) # Optionally, resume training from a checkpoint progress = Progress() @@ -182,19 +180,25 @@ def train(config: SFTConfig): case _: raise ValueError(f"Invalid loss implementation: {config.loss_impl}") - def run_validation(eval_step: int) -> None: - assert config.eval is not None - assert val_dataiter is not None + def run_validation(step: int) -> None: + val_dataset = setup_dataset( + tokenizer, + config.val.data, + config.model.cp * config.model.tp, + max_epochs=1, + raw_dataset=val_raw_dataset, + ) + val_dataloader = setup_dataloader(val_dataset, config.val.data) was_training = model.training model.eval() val_loss_sum = torch.tensor(0.0, device="cuda") - val_loss_count = torch.tensor(0.0, device="cuda") + num_batches = torch.tensor(0, dtype=torch.int64, device="cuda") + nan_batches = torch.tensor(0, dtype=torch.int64, device="cuda") with torch.no_grad(): - for _ in range(config.eval.num_batches): - micro_batch = next(val_dataiter) + for micro_batch in val_dataloader: input_ids = micro_batch["input_ids"].to("cuda") position_ids = micro_batch["position_ids"].to("cuda") target_ids = micro_batch["target_ids"].to("cuda") @@ -215,39 +219,37 @@ def run_validation(eval_step: int) -> None: loss = ce_loss(logits.view(-1, V), target_ids.view(-1)).view(B, L) loss = loss[loss_mask].mean() - if not torch.isnan(loss): + if torch.isnan(loss): + nan_batches += 1 + else: val_loss_sum += loss.detach() - val_loss_count += 1 + num_batches += 1 - del logits + del out, logits dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_loss_count, op=dist.ReduceOp.SUM) - - if val_loss_count.item() == 0: - logger.warning(f"Validation at step {eval_step} had only NaN losses") - val_metrics = { - "val/loss": float("nan"), - "val/num_batches": 0.0, - "step": eval_step, - } + dist.all_reduce(num_batches, op=dist.ReduceOp.SUM) + dist.all_reduce(nan_batches, op=dist.ReduceOp.SUM) + + total_batches = num_batches.item() + nan_batches.item() + + if nan_batches.item() > 0: + logger.warning(f"Validation at step {step}: {nan_batches.item()}/{total_batches} batches had NaN loss") + + if num_batches.item() == 0: + logger.warning(f"Validation at step {step} had only NaN losses") + val_metrics = {"val/loss": float("nan"), "val/num_batches": 0, "step": step} else: - mean_val_loss = (val_loss_sum / val_loss_count).item() - logger.success( - f"Validation | Step {eval_step} | Loss: {mean_val_loss:.4f} | Batches: {int(val_loss_count.item())}" - ) - val_metrics = { - "val/loss": mean_val_loss, - "val/num_batches": val_loss_count.item(), - "step": eval_step, - } + mean_val_loss = (val_loss_sum / num_batches).item() + logger.success(f"Validation | Step {step} | Loss: {mean_val_loss:.4f} | Batches: {total_batches}") + val_metrics = {"val/loss": mean_val_loss, "val/num_batches": total_batches, "step": step} - monitor.log(val_metrics, step=eval_step) + monitor.log(val_metrics, step=step) if was_training: model.train() - if config.eval is not None and val_dataiter is not None and config.eval.eval_on_start: + if config.val is not None and config.val.eval_on_start: run_validation(progress.step) logger.info(f"Starting training loop (max_steps={config.max_steps or 'infinite'})") @@ -481,12 +483,7 @@ def run_validation(eval_step: int) -> None: } monitor.log(max_vio_log_metrics, step=progress.step) - if ( - config.eval is not None - and val_dataiter is not None - and not is_first_step - and progress.step % config.eval.interval == 0 - ): + if config.val is not None and not is_first_step and progress.step % config.val.interval == 0: run_validation(progress.step) is_first_step = False diff --git a/tests/unit/train/sft/test_sft_eval_config.py b/tests/unit/train/sft/test_sft_eval_config.py index a26c12ea73..4a774675cc 100644 --- a/tests/unit/train/sft/test_sft_eval_config.py +++ b/tests/unit/train/sft/test_sft_eval_config.py @@ -1,35 +1,30 @@ import pytest -from prime_rl.configs.sft import SFTConfig, SFTDataConfig, SFTEvalConfig +from prime_rl.configs.sft import SFTConfig, SFTDataConfig, SFTValConfig from prime_rl.configs.trainer import ModelConfig -def test_sft_eval_requires_val_data(): - with pytest.raises(ValueError, match="both eval and val_data"): - SFTConfig(eval=SFTEvalConfig(interval=10, num_batches=2)) - - -def test_sft_val_data_requires_eval(): - with pytest.raises(ValueError, match="both eval and val_data"): - SFTConfig(val_data=SFTDataConfig()) - - -def test_sft_eval_with_val_data_is_valid(): +def test_sft_val_config_is_valid(): config = SFTConfig( - eval=SFTEvalConfig(interval=10, num_batches=2, eval_on_start=True), - val_data=SFTDataConfig(name="willcb/R1-reverse-wikipedia-paragraphs-v1-1000", splits=["train[:5%]"]), + val=SFTValConfig( + interval=10, + eval_on_start=True, + data=SFTDataConfig(name="willcb/R1-reverse-wikipedia-paragraphs-v1-1000", splits=["train[:5%]"]), + ), ) - assert config.eval is not None - assert config.val_data is not None - assert config.eval.eval_on_start is True + assert config.val is not None + assert config.val.eval_on_start is True + assert config.val.data.name == "willcb/R1-reverse-wikipedia-paragraphs-v1-1000" def test_sft_val_data_requires_cp_compatible_pack_function(): with pytest.raises(ValueError, match="Validation packing function must be 'cat' when CP is enabled"): SFTConfig( model=ModelConfig(cp=2), - eval=SFTEvalConfig(interval=10, num_batches=2), - val_data=SFTDataConfig(pack_function="stack", seq_len=256), + val=SFTValConfig( + interval=10, + data=SFTDataConfig(pack_function="stack", seq_len=256), + ), ) @@ -37,8 +32,10 @@ def test_sft_val_data_requires_cp_compatible_seq_len(): with pytest.raises(ValueError, match="Validation sequence length must be divisible by CP degree"): SFTConfig( model=ModelConfig(cp=2), - eval=SFTEvalConfig(interval=10, num_batches=2), - val_data=SFTDataConfig(seq_len=127), + val=SFTValConfig( + interval=10, + data=SFTDataConfig(seq_len=127), + ), ) @@ -46,6 +43,8 @@ def test_sft_val_data_requires_cp_compatible_micro_batch_size(): with pytest.raises(ValueError, match="Validation micro batch size must be 1 when CP is enabled"): SFTConfig( model=ModelConfig(cp=2), - eval=SFTEvalConfig(interval=10, num_batches=2), - val_data=SFTDataConfig(micro_batch_size=2), + val=SFTValConfig( + interval=10, + data=SFTDataConfig(micro_batch_size=2), + ), ) From a34ca733d85aed6f39aa69e4abced4611b9c1bd6 Mon Sep 17 00:00:00 2001 From: hallerite Date: Fri, 6 Mar 2026 04:23:16 +0000 Subject: [PATCH 07/19] remove tests --- tests/unit/train/sft/test_sft_eval_config.py | 50 -------------------- 1 file changed, 50 deletions(-) delete mode 100644 tests/unit/train/sft/test_sft_eval_config.py diff --git a/tests/unit/train/sft/test_sft_eval_config.py b/tests/unit/train/sft/test_sft_eval_config.py deleted file mode 100644 index 4a774675cc..0000000000 --- a/tests/unit/train/sft/test_sft_eval_config.py +++ /dev/null @@ -1,50 +0,0 @@ -import pytest - -from prime_rl.configs.sft import SFTConfig, SFTDataConfig, SFTValConfig -from prime_rl.configs.trainer import ModelConfig - - -def test_sft_val_config_is_valid(): - config = SFTConfig( - val=SFTValConfig( - interval=10, - eval_on_start=True, - data=SFTDataConfig(name="willcb/R1-reverse-wikipedia-paragraphs-v1-1000", splits=["train[:5%]"]), - ), - ) - assert config.val is not None - assert config.val.eval_on_start is True - assert config.val.data.name == "willcb/R1-reverse-wikipedia-paragraphs-v1-1000" - - -def test_sft_val_data_requires_cp_compatible_pack_function(): - with pytest.raises(ValueError, match="Validation packing function must be 'cat' when CP is enabled"): - SFTConfig( - model=ModelConfig(cp=2), - val=SFTValConfig( - interval=10, - data=SFTDataConfig(pack_function="stack", seq_len=256), - ), - ) - - -def test_sft_val_data_requires_cp_compatible_seq_len(): - with pytest.raises(ValueError, match="Validation sequence length must be divisible by CP degree"): - SFTConfig( - model=ModelConfig(cp=2), - val=SFTValConfig( - interval=10, - data=SFTDataConfig(seq_len=127), - ), - ) - - -def test_sft_val_data_requires_cp_compatible_micro_batch_size(): - with pytest.raises(ValueError, match="Validation micro batch size must be 1 when CP is enabled"): - SFTConfig( - model=ModelConfig(cp=2), - val=SFTValConfig( - interval=10, - data=SFTDataConfig(micro_batch_size=2), - ), - ) From bb6d3ea7dae613c3962ed18d00a2dbb33367437c Mon Sep 17 00:00:00 2001 From: hallerite Date: Fri, 6 Mar 2026 04:46:48 +0000 Subject: [PATCH 08/19] factorize --- src/prime_rl/trainer/sft/train.py | 112 +++++++++++------------------- 1 file changed, 42 insertions(+), 70 deletions(-) diff --git a/src/prime_rl/trainer/sft/train.py b/src/prime_rl/trainer/sft/train.py index 5dc96d1dd8..8ee444ca3e 100644 --- a/src/prime_rl/trainer/sft/train.py +++ b/src/prime_rl/trainer/sft/train.py @@ -29,7 +29,6 @@ ) from prime_rl.trainer.parallel_dims import get_parallel_dims from prime_rl.trainer.perf import get_perf_counter -from prime_rl.trainer.models.layers.lora import set_lora_num_tokens from prime_rl.trainer.sft.data import load_sft_dataset, setup_dataloader, setup_dataset from prime_rl.trainer.utils import ( MemoryProfiler, @@ -180,6 +179,26 @@ def train(config: SFTConfig): case _: raise ValueError(f"Invalid loss implementation: {config.loss_impl}") + def compute_loss(micro_batch: dict) -> tuple[torch.Tensor, torch.Tensor]: + """Forward pass + CE loss. Returns (per_token_loss, loss_mask), both (B, L).""" + input_ids = micro_batch["input_ids"].to("cuda") + position_ids = micro_batch["position_ids"].to("cuda") + target_ids = micro_batch["target_ids"].to("cuda") + loss_mask = micro_batch["loss_mask"].to("cuda") + + if cp_enabled: + input_ids, position_ids = setup_cp_params(input_ids, position_ids, cp_rank, cp_size, cp_group) + target_ids = shard_for_cp(target_ids, cp_rank=cp_rank, cp_world_size=cp_size) + loss_mask = shard_for_cp(loss_mask, cp_rank=cp_rank, cp_world_size=cp_size) + + out = forward(model, input_ids, position_ids) + logits = out["logits"] + B, L, V = logits.shape + per_token_loss = ce_loss(logits.view(-1, V), target_ids.view(-1)).view(B, L) + + del out, logits + return per_token_loss, loss_mask + def run_validation(step: int) -> None: val_dataset = setup_dataset( tokenizer, @@ -193,56 +212,35 @@ def run_validation(step: int) -> None: was_training = model.training model.eval() - val_loss_sum = torch.tensor(0.0, device="cuda") - num_batches = torch.tensor(0, dtype=torch.int64, device="cuda") + total_loss = torch.tensor(0.0, device="cuda") + total_tokens = torch.tensor(0, dtype=torch.int64, device="cuda") nan_batches = torch.tensor(0, dtype=torch.int64, device="cuda") with torch.no_grad(): for micro_batch in val_dataloader: - input_ids = micro_batch["input_ids"].to("cuda") - position_ids = micro_batch["position_ids"].to("cuda") - target_ids = micro_batch["target_ids"].to("cuda") - loss_mask = micro_batch["loss_mask"].to("cuda") - - if cp_enabled: - input_ids, position_ids = setup_cp_params(input_ids, position_ids, cp_rank, cp_size, cp_group) - target_ids = shard_for_cp(target_ids, cp_rank=cp_rank, cp_world_size=cp_size) - loss_mask = shard_for_cp(loss_mask, cp_rank=cp_rank, cp_world_size=cp_size) - - if config.model.lora is not None: - lora_num_tokens = torch.full((1,), input_ids.numel(), dtype=torch.int32, device="cuda") - set_lora_num_tokens(lora_num_tokens) - - out = forward(model, input_ids, position_ids) - logits = out["logits"] - B, L, V = logits.shape - loss = ce_loss(logits.view(-1, V), target_ids.view(-1)).view(B, L) - loss = loss[loss_mask].mean() - - if torch.isnan(loss): + per_token_loss, loss_mask = compute_loss(micro_batch) + masked_loss = per_token_loss[loss_mask] + + if torch.isnan(masked_loss).any(): nan_batches += 1 else: - val_loss_sum += loss.detach() - num_batches += 1 - - del out, logits + total_loss += masked_loss.sum() + total_tokens += masked_loss.numel() - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(num_batches, op=dist.ReduceOp.SUM) + dist.all_reduce(total_loss, op=dist.ReduceOp.SUM) + dist.all_reduce(total_tokens, op=dist.ReduceOp.SUM) dist.all_reduce(nan_batches, op=dist.ReduceOp.SUM) - total_batches = num_batches.item() + nan_batches.item() - if nan_batches.item() > 0: - logger.warning(f"Validation at step {step}: {nan_batches.item()}/{total_batches} batches had NaN loss") + logger.warning(f"Validation at step {step}: {nan_batches.item()} batches had NaN loss") - if num_batches.item() == 0: - logger.warning(f"Validation at step {step} had only NaN losses") - val_metrics = {"val/loss": float("nan"), "val/num_batches": 0, "step": step} + if total_tokens.item() == 0: + logger.warning(f"Validation at step {step} had no valid tokens") + val_metrics = {"val/loss": float("nan"), "val/num_tokens": 0, "step": step} else: - mean_val_loss = (val_loss_sum / num_batches).item() - logger.success(f"Validation | Step {step} | Loss: {mean_val_loss:.4f} | Batches: {total_batches}") - val_metrics = {"val/loss": mean_val_loss, "val/num_batches": total_batches, "step": step} + mean_val_loss = (total_loss / total_tokens).item() + logger.success(f"Validation | Step {step} | Loss: {mean_val_loss:.4f} | Tokens: {total_tokens.item()}") + val_metrics = {"val/loss": mean_val_loss, "val/num_tokens": total_tokens.item(), "step": step} monitor.log(val_metrics, step=step) @@ -305,53 +303,27 @@ def run_validation(step: int) -> None: batch_max_vio, max_vio = torch.tensor(0.0).to("cuda"), None for micro_step in range(grad_accum_steps): micro_batch = next(dataiter) - input_ids = micro_batch["input_ids"].to("cuda") - position_ids = micro_batch["position_ids"].to("cuda") - target_ids = micro_batch["target_ids"].to("cuda") - loss_mask = micro_batch["loss_mask"].to("cuda") - - if cp_enabled: - input_ids, position_ids = setup_cp_params(input_ids, position_ids, cp_rank, cp_size, cp_group) - target_ids = shard_for_cp(target_ids, cp_rank=cp_rank, cp_world_size=cp_size) - loss_mask = shard_for_cp(loss_mask, cp_rank=cp_rank, cp_world_size=cp_size) - - assert input_ids.shape == position_ids.shape == target_ids.shape == loss_mask.shape, ( - f"input_ids.shape: {input_ids.shape}, position_ids.shape: {position_ids.shape}, target_ids.shape: {target_ids.shape}, loss_mask.shape: {loss_mask.shape}" - ) if config.log.log_data: logger.debug("Printing samples of the first micro batch") - print_sample(input_ids.flatten().tolist(), loss_mask.flatten().tolist(), tokenizer) + input_ids_log = micro_batch["input_ids"].flatten().tolist() + loss_mask_log = micro_batch["loss_mask"].flatten().tolist() + print_sample(input_ids_log, loss_mask_log, tokenizer) - # Forward pass logger.debug("Starting forward pass") with maybe_record_function("forward"), maybe_activation_offloading(config.model.ac_offloading): - out = forward(model, input_ids, position_ids) - - logits = out["logits"] - B, L, V = logits.shape + per_token_loss, loss_mask = compute_loss(micro_batch) - # Compute loss - loss = ce_loss(logits.view(-1, V), target_ids.view(-1)).view(B, L) - - # Compute average loss over unmasked tokens - loss = loss[loss_mask].mean() - - # Accumulate average loss over gradient accumulation steps + loss = per_token_loss[loss_mask].mean() current_loss = loss.detach() / grad_accum_steps - # only add if the loss is not nan if not torch.isnan(current_loss): batch_loss += current_loss else: nan_loss_count += 1 logger.warning("Loss is nan, not taking into account in the batch loss calculation") - # Delete logits before backward pass to avoid memory spike - del logits - - # Backward pass logger.debug("Starting backward pass") with maybe_record_function("backward"): (loss / grad_accum_steps).backward() From aee17b50edf76f7245e2151e0943ae1f412f46a7 Mon Sep 17 00:00:00 2001 From: Philipp Normann Date: Fri, 6 Mar 2026 10:25:26 +0100 Subject: [PATCH 09/19] remove eval_on_start --- CHANGELOG.md | 2 +- src/prime_rl/configs/sft.py | 1 - src/prime_rl/trainer/sft/train.py | 3 --- 3 files changed, 1 insertion(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f195e8c43b..031d9bff8e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -82,5 +82,5 @@ Documenting changes which affect configuration usage patterns (added/moved/remov - **`clean`**: Removed from `RLConfig`. The old `clean` flag (default: `True`) silently deleted logs, rollouts, and broadcasts on every local RL run. Superseded by the explicit `clean_output_dir` flag (2026-02-24) - **Config consolidation**: All config modules moved into `prime_rl.configs` subpackage. `utils/config.py` + `transport/config.py` → `configs/shared.py`; `trainer/config.py` + `trainer/rl/config.py` → `configs/trainer.py`; `trainer/sft/config.py` → `configs/sft.py`; `orchestrator/config.py` → `configs/orchestrator.py`; `inference/config.py` → `configs/inference.py`; `rl_config.py` → `configs/rl.py`. Class renames: `SFTTrainerConfig` → `SFTConfig`, `RLTrainerConfig` → `TrainerConfig`. Component prefixes dropped from orchestrator and inference config classes (e.g. `OrchestratorCheckpointConfig` → `CheckpointConfig`). TypeAlias renames: dropped `Type` suffix (e.g. `LossConfigType` → `LossConfig`, `TransportConfigType` → `TransportConfig`), renamed `LossConfig` class → `DefaultLossConfig`. No TOML key changes. (2026-02-24) - **`orchestrator.oversampling_factor`**: Added rollout-only over-sampling config that resolves `max_inflight_rollouts = int(batch_size * oversampling_factor)` when `max_inflight_rollouts` is unset. Cannot be used with `token_batch_size` or together with explicit `max_inflight_rollouts` (2026-02-25) -- **`sft.val`**: Added optional periodic SFT validation with `val/loss` and `val/num_batches` logging. Configure via `sft.val.data` (validation dataset) and `sft.val.interval` (every N steps, default 50). Runs the full validation dataset each pass. Added `sft.val.eval_on_start` (default: `False`) to optionally run validation before training starts. (2026-02-26) +- **`sft.val`**: Added optional periodic SFT validation with `val/loss` and `val/num_tokens` logging. Configure via `sft.val.data` (validation dataset) and `sft.val.interval` (every N steps, default 50). Runs the full validation dataset each pass. (2026-02-26) - **`model.fused_lm_head_chunk_size`**: Changed default value from 2048 to 8192 for RL training (2026-02-26) diff --git a/src/prime_rl/configs/sft.py b/src/prime_rl/configs/sft.py index 72576e5487..a6f37c0be1 100644 --- a/src/prime_rl/configs/sft.py +++ b/src/prime_rl/configs/sft.py @@ -108,7 +108,6 @@ def validate_subsets_and_splits(self): class SFTValConfig(BaseConfig): interval: Annotated[int, Field(ge=1, description="Run validation every N training steps.")] = 50 - eval_on_start: Annotated[bool, Field(description="Run a validation pass before the training loop starts.")] = False data: SFTDataConfig = SFTDataConfig() diff --git a/src/prime_rl/trainer/sft/train.py b/src/prime_rl/trainer/sft/train.py index 8ee444ca3e..754ca20213 100644 --- a/src/prime_rl/trainer/sft/train.py +++ b/src/prime_rl/trainer/sft/train.py @@ -247,9 +247,6 @@ def run_validation(step: int) -> None: if was_training: model.train() - if config.val is not None and config.val.eval_on_start: - run_validation(progress.step) - logger.info(f"Starting training loop (max_steps={config.max_steps or 'infinite'})") max_memory = torch.cuda.mem_get_info()[1] / 1024**3 # GiB is_first_step = True From 4ecfd60ca73444834e9c4ba760b941ccb2dee799 Mon Sep 17 00:00:00 2001 From: Philipp Normann Date: Fri, 6 Mar 2026 10:30:23 +0100 Subject: [PATCH 10/19] require sft val data --- src/prime_rl/configs/sft.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/prime_rl/configs/sft.py b/src/prime_rl/configs/sft.py index a6f37c0be1..396d58ee59 100644 --- a/src/prime_rl/configs/sft.py +++ b/src/prime_rl/configs/sft.py @@ -108,7 +108,7 @@ def validate_subsets_and_splits(self): class SFTValConfig(BaseConfig): interval: Annotated[int, Field(ge=1, description="Run validation every N training steps.")] = 50 - data: SFTDataConfig = SFTDataConfig() + data: SFTDataConfig DataConfig: TypeAlias = Annotated[FakeDataConfig | SFTDataConfig, Field(discriminator="type")] From 9a3384d784322d6daddf316b75c0dd96db024466 Mon Sep 17 00:00:00 2001 From: hallerite Date: Fri, 6 Mar 2026 19:38:11 +0000 Subject: [PATCH 11/19] handle ac --- src/prime_rl/trainer/sft/train.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/prime_rl/trainer/sft/train.py b/src/prime_rl/trainer/sft/train.py index 754ca20213..10295133de 100644 --- a/src/prime_rl/trainer/sft/train.py +++ b/src/prime_rl/trainer/sft/train.py @@ -191,7 +191,8 @@ def compute_loss(micro_batch: dict) -> tuple[torch.Tensor, torch.Tensor]: target_ids = shard_for_cp(target_ids, cp_rank=cp_rank, cp_world_size=cp_size) loss_mask = shard_for_cp(loss_mask, cp_rank=cp_rank, cp_world_size=cp_size) - out = forward(model, input_ids, position_ids) + with maybe_activation_offloading(config.model.ac_offloading): + out = forward(model, input_ids, position_ids) logits = out["logits"] B, L, V = logits.shape per_token_loss = ce_loss(logits.view(-1, V), target_ids.view(-1)).view(B, L) @@ -308,7 +309,7 @@ def run_validation(step: int) -> None: print_sample(input_ids_log, loss_mask_log, tokenizer) logger.debug("Starting forward pass") - with maybe_record_function("forward"), maybe_activation_offloading(config.model.ac_offloading): + with maybe_record_function("forward"): per_token_loss, loss_mask = compute_loss(micro_batch) loss = per_token_loss[loss_mask].mean() From bf4b6dde95a49fe854e45bdb77521cc6dd936fb6 Mon Sep 17 00:00:00 2001 From: hallerite Date: Fri, 6 Mar 2026 20:20:25 +0000 Subject: [PATCH 12/19] refactor --- src/prime_rl/trainer/sft/train.py | 153 +++++++++++------------------- 1 file changed, 56 insertions(+), 97 deletions(-) diff --git a/src/prime_rl/trainer/sft/train.py b/src/prime_rl/trainer/sft/train.py index 10295133de..fdca0b7789 100644 --- a/src/prime_rl/trainer/sft/train.py +++ b/src/prime_rl/trainer/sft/train.py @@ -11,7 +11,7 @@ from prime_rl.trainer.models.layers.attn import substitute_ring_attn from prime_rl.utils.act_offloading import maybe_activation_offloading import torch -from torch.profiler import profile, ProfilerActivity, record_function +from torch.profiler import profile, ProfilerActivity from prime_rl.trainer.ckpt import setup_ckpt_managers from prime_rl.utils.pathing import resolve_latest_ckpt_step from prime_rl.configs.sft import SFTConfig @@ -34,7 +34,6 @@ MemoryProfiler, export_benchmark_json, get_ckpt_disk_metrics, - print_sample, setup_torch_distributed, print_benchmark, ) @@ -180,7 +179,6 @@ def train(config: SFTConfig): raise ValueError(f"Invalid loss implementation: {config.loss_impl}") def compute_loss(micro_batch: dict) -> tuple[torch.Tensor, torch.Tensor]: - """Forward pass + CE loss. Returns (per_token_loss, loss_mask), both (B, L).""" input_ids = micro_batch["input_ids"].to("cuda") position_ids = micro_batch["position_ids"].to("cuda") target_ids = micro_batch["target_ids"].to("cuda") @@ -200,62 +198,73 @@ def compute_loss(micro_batch: dict) -> tuple[torch.Tensor, torch.Tensor]: del out, logits return per_token_loss, loss_mask - def run_validation(step: int) -> None: - val_dataset = setup_dataset( - tokenizer, - config.val.data, - config.model.cp * config.model.tp, - max_epochs=1, - raw_dataset=val_raw_dataset, - ) - val_dataloader = setup_dataloader(val_dataset, config.val.data) - - was_training = model.training - model.eval() - + def run_forward_loop(data_iter, num_steps=None, *, backward=True): total_loss = torch.tensor(0.0, device="cuda") - total_tokens = torch.tensor(0, dtype=torch.int64, device="cuda") - nan_batches = torch.tensor(0, dtype=torch.int64, device="cuda") - - with torch.no_grad(): - for micro_batch in val_dataloader: + nan_count = torch.tensor(0, device="cuda") + valid_steps = torch.tensor(0, device="cuda") + max_vio_total = torch.tensor(0.0, device="cuda") + divisor = num_steps or 1 + + ctx = nullcontext() if backward else torch.no_grad() + with ctx: + for step, micro_batch in enumerate(data_iter): per_token_loss, loss_mask = compute_loss(micro_batch) - masked_loss = per_token_loss[loss_mask] + loss = per_token_loss[loss_mask].mean() - if torch.isnan(masked_loss).any(): - nan_batches += 1 + if not torch.isnan(loss.detach()): + total_loss += loss.detach() + valid_steps += 1 else: - total_loss += masked_loss.sum() - total_tokens += masked_loss.numel() + nan_count += 1 + + if backward: + (loss / divisor).backward() + + if is_tt_moe_model(model): + max_vio = get_load_balance_stats(model)["max_vio"] + if max_vio is not None: + max_vio = max_vio.mean() + dist.all_reduce(max_vio, op=dist.ReduceOp.MAX) + max_vio_total += max_vio / divisor + + if num_steps is not None and step + 1 >= num_steps: + break dist.all_reduce(total_loss, op=dist.ReduceOp.SUM) - dist.all_reduce(total_tokens, op=dist.ReduceOp.SUM) - dist.all_reduce(nan_batches, op=dist.ReduceOp.SUM) + dist.all_reduce(valid_steps, op=dist.ReduceOp.SUM) + dist.all_reduce(nan_count, op=dist.ReduceOp.SUM) - if nan_batches.item() > 0: - logger.warning(f"Validation at step {step}: {nan_batches.item()} batches had NaN loss") + mean_loss = (total_loss / valid_steps).item() if valid_steps.item() > 0 else float("nan") + return mean_loss, nan_count.item(), max_vio_total + + def run_validation(step: int) -> None: + val_dataset = setup_dataset( + tokenizer, config.val.data, config.model.cp * config.model.tp, max_epochs=1, raw_dataset=val_raw_dataset + ) + val_dataloader = setup_dataloader(val_dataset, config.val.data) - if total_tokens.item() == 0: + was_training = model.training + model.eval() + mean_loss, nan_count, _ = run_forward_loop(val_dataloader, backward=False) + if nan_count > 0: + logger.warning(f"Validation at step {step}: {nan_count} batches had NaN loss") + if mean_loss != mean_loss: logger.warning(f"Validation at step {step} had no valid tokens") - val_metrics = {"val/loss": float("nan"), "val/num_tokens": 0, "step": step} else: - mean_val_loss = (total_loss / total_tokens).item() - logger.success(f"Validation | Step {step} | Loss: {mean_val_loss:.4f} | Tokens: {total_tokens.item()}") - val_metrics = {"val/loss": mean_val_loss, "val/num_tokens": total_tokens.item(), "step": step} - - monitor.log(val_metrics, step=step) - + logger.success(f"Validation | Step {step} | Loss: {mean_loss:.4f}") + monitor.log({"val/loss": mean_loss, "step": step}, step=step) if was_training: model.train() + if config.val is not None and config.val.eval_on_start: + run_validation(progress.step) + logger.info(f"Starting training loop (max_steps={config.max_steps or 'infinite'})") max_memory = torch.cuda.mem_get_info()[1] / 1024**3 # GiB is_first_step = True - maybe_record_function = nullcontext if config.trace_path: logger.info(f"Tracing to {config.trace_path}") prof = profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True).__enter__() - maybe_record_function = record_function while True: # Reset peak memory stats torch.cuda.reset_peak_memory_stats() @@ -296,48 +305,7 @@ def run_validation(step: int) -> None: step_start_time = time.perf_counter() forward_backward_start_time = time.perf_counter() - batch_loss = torch.tensor(0.0).to("cuda") - nan_loss_count = torch.tensor(0).to("cuda") - batch_max_vio, max_vio = torch.tensor(0.0).to("cuda"), None - for micro_step in range(grad_accum_steps): - micro_batch = next(dataiter) - - if config.log.log_data: - logger.debug("Printing samples of the first micro batch") - input_ids_log = micro_batch["input_ids"].flatten().tolist() - loss_mask_log = micro_batch["loss_mask"].flatten().tolist() - print_sample(input_ids_log, loss_mask_log, tokenizer) - - logger.debug("Starting forward pass") - with maybe_record_function("forward"): - per_token_loss, loss_mask = compute_loss(micro_batch) - - loss = per_token_loss[loss_mask].mean() - - current_loss = loss.detach() / grad_accum_steps - - if not torch.isnan(current_loss): - batch_loss += current_loss - else: - nan_loss_count += 1 - logger.warning("Loss is nan, not taking into account in the batch loss calculation") - - logger.debug("Starting backward pass") - with maybe_record_function("backward"): - (loss / grad_accum_steps).backward() - - if is_tt_moe_model(model): - max_vio = get_load_balance_stats(model)["max_vio"] - if max_vio is not None: - max_vio = max_vio.mean() - dist.all_reduce(max_vio, op=dist.ReduceOp.MAX) - batch_max_vio += max_vio / grad_accum_steps - - # Debug log with *local, micro step* stats - micro_step_message = f"Micro Step {micro_step}/{grad_accum_steps} | Loss: {loss.item():.4f} | Dataloader Step: {dataloader.state_dict()['dataset_state']['dataset']['step']}" - if is_tt_moe_model(model) and max_vio is not None: - micro_step_message += f" | Max Vio: {max_vio.item():.4f}" - logger.debug(micro_step_message) + batch_loss, nan_loss_count, batch_max_vio = run_forward_loop(dataiter, grad_accum_steps, backward=True) logger.debug(f"Clipping gradients with max norm {config.optim.max_norm}") grad_norm = clip_grad_norm_( @@ -360,11 +328,6 @@ def run_validation(step: int) -> None: if memory_profiler is not None: memory_profiler.step() - # Synchronize the tensor metrics across all steps and ranks - logger.debug("Synchronizing tensor metrics across all steps and ranks") - dist.all_reduce(batch_loss, op=dist.ReduceOp.AVG) - dist.all_reduce(nan_loss_count, op=dist.ReduceOp.SUM) - # Compute step metrics # Divide by CP and TP since those ranks process the same data num_tokens = config.data.batch_size * config.data.seq_len // (config.model.cp * config.model.tp) @@ -378,8 +341,8 @@ def run_validation(step: int) -> None: # Log step metrics step_time = time.perf_counter() - step_start_time - step_message = f"Step {progress.step} | Time: {step_time:.2f}s | Loss: {batch_loss.item():.4f} | Grad. Norm: {grad_norm:.4f} | LR: {current_lr:.2e} | Throughput: {throughput:.0f} tokens/s | MFU: {mfu:.1f}% | Peak Mem.: {peak_memory:.1f}/{max_memory:.1f} GiB ({peak_memory / max_memory * 100:.1f}%)" - if is_tt_moe_model(model) and max_vio is not None: + step_message = f"Step {progress.step} | Time: {step_time:.2f}s | Loss: {batch_loss:.4f} | Grad. Norm: {grad_norm:.4f} | LR: {current_lr:.2e} | Throughput: {throughput:.0f} tokens/s | MFU: {mfu:.1f}% | Peak Mem.: {peak_memory:.1f}/{max_memory:.1f} GiB ({peak_memory / max_memory * 100:.1f}%)" + if is_tt_moe_model(model) and batch_max_vio.item() > 0: step_message += f" | Max Vio: {batch_max_vio.item():.4f}" logger.success(step_message) @@ -425,8 +388,8 @@ def run_validation(step: int) -> None: monitor.log(optim_metrics, step=progress.step) loss_log_metrics = { - "loss/mean": batch_loss.item(), - "loss/nan_count": nan_loss_count.item(), + "loss/mean": batch_loss, + "loss/nan_count": nan_loss_count, "step": progress.step, } # Log tensor stats @@ -446,12 +409,8 @@ def run_validation(step: int) -> None: disk_metrics["step"] = progress.step monitor.log(disk_metrics, step=progress.step) - if is_tt_moe_model(model): - max_vio_log_metrics = { - "max_vio/mean": batch_max_vio.item(), - "step": progress.step, - } - monitor.log(max_vio_log_metrics, step=progress.step) + if is_tt_moe_model(model) and batch_max_vio.item() > 0: + monitor.log({"max_vio/mean": batch_max_vio.item(), "step": progress.step}, step=progress.step) if config.val is not None and not is_first_step and progress.step % config.val.interval == 0: run_validation(progress.step) From 0bcdb39dfe17c2f96571e770c3f9fc954e24417a Mon Sep 17 00:00:00 2001 From: hallerite Date: Fri, 6 Mar 2026 20:23:42 +0000 Subject: [PATCH 13/19] add option for val at start --- src/prime_rl/configs/sft.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/prime_rl/configs/sft.py b/src/prime_rl/configs/sft.py index 396d58ee59..1f603e3876 100644 --- a/src/prime_rl/configs/sft.py +++ b/src/prime_rl/configs/sft.py @@ -108,6 +108,7 @@ def validate_subsets_and_splits(self): class SFTValConfig(BaseConfig): interval: Annotated[int, Field(ge=1, description="Run validation every N training steps.")] = 50 + eval_on_start: Annotated[bool, Field(description="Run validation before the first training step.")] = False data: SFTDataConfig From b5cdc934fed64da515e9f787973eabebdc1ec8f9 Mon Sep 17 00:00:00 2001 From: hallerite Date: Fri, 6 Mar 2026 20:34:45 +0000 Subject: [PATCH 14/19] add back some stuff --- src/prime_rl/trainer/sft/train.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/src/prime_rl/trainer/sft/train.py b/src/prime_rl/trainer/sft/train.py index fdca0b7789..2e757b50d5 100644 --- a/src/prime_rl/trainer/sft/train.py +++ b/src/prime_rl/trainer/sft/train.py @@ -11,7 +11,7 @@ from prime_rl.trainer.models.layers.attn import substitute_ring_attn from prime_rl.utils.act_offloading import maybe_activation_offloading import torch -from torch.profiler import profile, ProfilerActivity +from torch.profiler import profile, ProfilerActivity, record_function from prime_rl.trainer.ckpt import setup_ckpt_managers from prime_rl.utils.pathing import resolve_latest_ckpt_step from prime_rl.configs.sft import SFTConfig @@ -34,6 +34,7 @@ MemoryProfiler, export_benchmark_json, get_ckpt_disk_metrics, + print_sample, setup_torch_distributed, print_benchmark, ) @@ -198,6 +199,8 @@ def compute_loss(micro_batch: dict) -> tuple[torch.Tensor, torch.Tensor]: del out, logits return per_token_loss, loss_mask + maybe_record_function = nullcontext + def run_forward_loop(data_iter, num_steps=None, *, backward=True): total_loss = torch.tensor(0.0, device="cuda") nan_count = torch.tensor(0, device="cuda") @@ -208,7 +211,13 @@ def run_forward_loop(data_iter, num_steps=None, *, backward=True): ctx = nullcontext() if backward else torch.no_grad() with ctx: for step, micro_batch in enumerate(data_iter): - per_token_loss, loss_mask = compute_loss(micro_batch) + if backward and config.log.log_data: + input_ids_log = micro_batch["input_ids"].flatten().tolist() + loss_mask_log = micro_batch["loss_mask"].flatten().tolist() + print_sample(input_ids_log, loss_mask_log, tokenizer) + + with maybe_record_function("forward"): + per_token_loss, loss_mask = compute_loss(micro_batch) loss = per_token_loss[loss_mask].mean() if not torch.isnan(loss.detach()): @@ -218,7 +227,8 @@ def run_forward_loop(data_iter, num_steps=None, *, backward=True): nan_count += 1 if backward: - (loss / divisor).backward() + with maybe_record_function("backward"): + (loss / divisor).backward() if is_tt_moe_model(model): max_vio = get_load_balance_stats(model)["max_vio"] @@ -265,6 +275,7 @@ def run_validation(step: int) -> None: if config.trace_path: logger.info(f"Tracing to {config.trace_path}") prof = profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True).__enter__() + maybe_record_function = record_function # noqa: F841 – captured by run_forward_loop closure while True: # Reset peak memory stats torch.cuda.reset_peak_memory_stats() From e49fd1899c76d829cc6feec321a9d662fcf038dd Mon Sep 17 00:00:00 2001 From: hallerite Date: Mon, 9 Mar 2026 22:14:08 +0000 Subject: [PATCH 15/19] smarter gating --- src/prime_rl/trainer/sft/train.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/prime_rl/trainer/sft/train.py b/src/prime_rl/trainer/sft/train.py index ab85e45ade..0de359d19c 100644 --- a/src/prime_rl/trainer/sft/train.py +++ b/src/prime_rl/trainer/sft/train.py @@ -279,9 +279,6 @@ def run_validation(step: int) -> None: if was_training: model.train() - if config.val is not None and config.val.eval_on_start: - run_validation(progress.step) - logger.info(f"Starting training loop (max_steps={config.max_steps or 'infinite'})") max_memory = torch.cuda.mem_get_info()[1] / 1024**3 # GiB is_first_step = True @@ -436,7 +433,11 @@ def run_validation(step: int) -> None: if is_tt_moe_model(model) and batch_max_vio.item() > 0: monitor.log({"max_vio/mean": batch_max_vio.item(), "step": progress.step}, step=progress.step) - if config.val is not None and not is_first_step and progress.step % config.val.interval == 0: + should_eval = config.val is not None and ( + (is_first_step and config.val.eval_on_start) + or (not is_first_step and progress.step % config.val.interval == 0) + ) + if should_eval: run_validation(progress.step) is_first_step = False From 47f3b3072d30525636bd69e5b4fc4f0f6eb07063 Mon Sep 17 00:00:00 2001 From: hallerite Date: Mon, 9 Mar 2026 22:45:08 +0000 Subject: [PATCH 16/19] smol fix --- src/prime_rl/trainer/sft/train.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/prime_rl/trainer/sft/train.py b/src/prime_rl/trainer/sft/train.py index 0de359d19c..c524bd571a 100644 --- a/src/prime_rl/trainer/sft/train.py +++ b/src/prime_rl/trainer/sft/train.py @@ -319,6 +319,13 @@ def run_validation(step: int) -> None: if config.max_steps is not None and progress.step >= config.max_steps: break + # Run validation (at start if eval_on_start, then every interval steps) + if config.val is not None and ( + (is_first_step and config.val.eval_on_start) + or (not is_first_step and progress.step % config.val.interval == 0) + ): + run_validation(progress.step) + memory_profiler = ( MemoryProfiler(progress.step, config.memory_profiler_path) if config.memory_profiler_path else None ) @@ -433,13 +440,6 @@ def run_validation(step: int) -> None: if is_tt_moe_model(model) and batch_max_vio.item() > 0: monitor.log({"max_vio/mean": batch_max_vio.item(), "step": progress.step}, step=progress.step) - should_eval = config.val is not None and ( - (is_first_step and config.val.eval_on_start) - or (not is_first_step and progress.step % config.val.interval == 0) - ) - if should_eval: - run_validation(progress.step) - is_first_step = False progress.step += 1 From 67b828a1948ea2d9e76c0bb79d75be1ea19a7202 Mon Sep 17 00:00:00 2001 From: hallerite Date: Mon, 9 Mar 2026 22:49:42 +0000 Subject: [PATCH 17/19] fix no-recompile --- src/prime_rl/trainer/sft/train.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/prime_rl/trainer/sft/train.py b/src/prime_rl/trainer/sft/train.py index c524bd571a..c0d176c415 100644 --- a/src/prime_rl/trainer/sft/train.py +++ b/src/prime_rl/trainer/sft/train.py @@ -266,8 +266,7 @@ def run_validation(step: int) -> None: ) val_dataloader = setup_dataloader(val_dataset, config.val.data) - was_training = model.training - model.eval() + # No train/eval switch: no dropout in these models, and toggling would trigger torch.compile recompilation mean_loss, nan_count, _ = run_forward_loop(val_dataloader, backward=False) if nan_count > 0: logger.warning(f"Validation at step {step}: {nan_count} batches had NaN loss") @@ -276,8 +275,6 @@ def run_validation(step: int) -> None: else: logger.success(f"Validation | Step {step} | Loss: {mean_loss:.4f}") monitor.log({"val/loss": mean_loss, "step": step}, step=step) - if was_training: - model.train() logger.info(f"Starting training loop (max_steps={config.max_steps or 'infinite'})") max_memory = torch.cuda.mem_get_info()[1] / 1024**3 # GiB From 4d3d4b8cac7521690c88324688a3292b309b056e Mon Sep 17 00:00:00 2001 From: hallerite Date: Mon, 9 Mar 2026 23:01:44 +0000 Subject: [PATCH 18/19] make sure forward is compiled --- src/prime_rl/trainer/sft/train.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/prime_rl/trainer/sft/train.py b/src/prime_rl/trainer/sft/train.py index c0d176c415..fc45086d7e 100644 --- a/src/prime_rl/trainer/sft/train.py +++ b/src/prime_rl/trainer/sft/train.py @@ -316,13 +316,6 @@ def run_validation(step: int) -> None: if config.max_steps is not None and progress.step >= config.max_steps: break - # Run validation (at start if eval_on_start, then every interval steps) - if config.val is not None and ( - (is_first_step and config.val.eval_on_start) - or (not is_first_step and progress.step % config.val.interval == 0) - ): - run_validation(progress.step) - memory_profiler = ( MemoryProfiler(progress.step, config.memory_profiler_path) if config.memory_profiler_path else None ) @@ -332,6 +325,14 @@ def run_validation(step: int) -> None: batch_loss, nan_loss_count, batch_max_vio = run_forward_loop(dataiter, grad_accum_steps, backward=True) + # Run validation after forward-backward (so torch.compile sees training graph first) but before + # optimizer step (so eval_on_start evaluates untrained weights) + if config.val is not None and ( + (is_first_step and config.val.eval_on_start) + or (not is_first_step and progress.step % config.val.interval == 0) + ): + run_validation(progress.step) + logger.debug(f"Clipping gradients with max norm {config.optim.max_norm}") grad_norm = clip_grad_norm_( model.parameters(), max_norm=config.optim.max_norm, ep_enabled=parallel_dims.ep_enabled From 81fefe7227f67ce43aac197bf55c76a530f04d63 Mon Sep 17 00:00:00 2001 From: hallerite Date: Mon, 9 Mar 2026 23:13:27 +0000 Subject: [PATCH 19/19] fix timing metric --- src/prime_rl/trainer/sft/train.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/prime_rl/trainer/sft/train.py b/src/prime_rl/trainer/sft/train.py index fc45086d7e..d741fca8c2 100644 --- a/src/prime_rl/trainer/sft/train.py +++ b/src/prime_rl/trainer/sft/train.py @@ -324,6 +324,7 @@ def run_validation(step: int) -> None: forward_backward_start_time = time.perf_counter() batch_loss, nan_loss_count, batch_max_vio = run_forward_loop(dataiter, grad_accum_steps, backward=True) + forward_backward_time = time.perf_counter() - forward_backward_start_time # Run validation after forward-backward (so torch.compile sees training graph first) but before # optimizer step (so eval_on_start evaluates untrained weights) @@ -348,8 +349,6 @@ def run_validation(step: int) -> None: current_lr = optimizer.param_groups[0]["lr"] scheduler.step() - forward_backward_time = time.perf_counter() - forward_backward_start_time - # Optionally, dump memory snapshot if memory_profiler is not None: memory_profiler.step()