Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ Documenting changes which affect configuration usage patterns (added/moved/remov
- **`trainer.enable_router_replay`**: Added flag to enable router replay. If True, routed experts are returned in the batch. Requires `enable_return_routed_experts=True` in the inference config (or passing `--enable-return-routed-experts` to the vLLM server). Only supported for custom models. (2026-02-22)
- **`inference.enable_return_routed_experts`**: Added flag to enable return routed experts. Passed to vLLM as `--enable-return-routed-experts` (2026-02-22)
- **`orchestrator.oversampling_factor`**: Added rollout-only over-sampling config that resolves `max_inflight_rollouts = int(batch_size * oversampling_factor)` when `max_inflight_rollouts` is unset. Cannot be used with `token_batch_size` or together with explicit `max_inflight_rollouts` (2026-02-25)
- **`sft.val`**: Added optional periodic SFT validation with `val/loss` and `val/num_tokens` logging. Configure via `sft.val.data` (validation dataset) and `sft.val.interval` (every N steps, default 50). Runs the full validation dataset each pass. (2026-02-26)
- **`model.fused_lm_head_chunk_size`**: Changed default value from 2048 to 8192 for RL training (2026-02-26)
- **`inference.data_parallel_size_local`** and **`inference.data_parallel_rpc_port`**: Added data-parallel node-local controls for vLLM, passed as `--data-parallel-size-local` and `--data-parallel-rpc-port` (defaults: `None`, `13345`) (2026-02-26)
- **`dump_config`**: Removed from `RLConfig`. Replaced by `dry_run` (see below) (2026-02-26)
Expand Down
37 changes: 28 additions & 9 deletions src/prime_rl/configs/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,12 @@ def validate_subsets_and_splits(self):
return self


class SFTValConfig(BaseConfig):
    """Configuration for optional periodic validation during SFT training."""

    # How often validation runs, counted in training steps (must be >= 1).
    interval: Annotated[int, Field(ge=1, description="Run validation every N training steps.")] = 50

    # Whether to also run one validation pass before any training step.
    eval_on_start: Annotated[bool, Field(description="Run validation before the first training step.")] = False

    # Dataset used for validation; required, no default.
    data: SFTDataConfig


DataConfig: TypeAlias = Annotated[FakeDataConfig | SFTDataConfig, Field(discriminator="type")]


Expand Down Expand Up @@ -163,6 +169,9 @@ class SFTConfig(BaseConfig):
# The data configuration
data: DataConfig = SFTDataConfig()

# Optional validation configuration
val: SFTValConfig | None = None

# The optimizer configuration
optim: OptimizerConfig = AdamWConfig()

Expand Down Expand Up @@ -261,27 +270,37 @@ def validate_deployment(self):

@model_validator(mode="after")
def validate_pack_function(self):
    """Require the 'cat' pack function for all datasets when CP is enabled.

    The check applies to the training data and, when a validation config is
    set, to the validation data as well. Nothing is checked when
    ``model.cp <= 1``.

    Returns:
        ``self``, unchanged.

    Raises:
        ValueError: if ``model.cp > 1`` and the training or validation
            pack function is not ``"cat"``.
    """
    if self.model.cp > 1:
        if self.data.pack_function != "cat":
            raise ValueError("Packing function must be 'cat' when CP is enabled")
        if self.val is not None and self.val.data.pack_function != "cat":
            raise ValueError("Validation packing function must be 'cat' when CP is enabled")
    return self

@model_validator(mode="after")
def validate_cp_seq_len(self):
    """Require sequence lengths to be divisible by the CP degree when CP is enabled.

    The check applies to the training data and, when a validation config is
    set, to the validation data as well. Nothing is checked when
    ``model.cp <= 1``.

    Returns:
        ``self``, unchanged.

    Raises:
        ValueError: if ``model.cp > 1`` and the training or validation
            ``seq_len`` is not a multiple of ``model.cp``.
    """
    if self.model.cp > 1:
        if self.data.seq_len % self.model.cp != 0:
            raise ValueError("Sequence length must be divisible by CP degree")
        if self.val is not None and self.val.data.seq_len % self.model.cp != 0:
            raise ValueError("Validation sequence length must be divisible by CP degree")
    return self

@model_validator(mode="after")
def validate_cp_micro_batch_size(self):
    """Require a micro batch size of 1 for all datasets when CP is enabled.

    The check applies to the training data and, when a validation config is
    set, to the validation data as well. Nothing is checked when
    ``model.cp <= 1``.

    Returns:
        ``self``, unchanged.

    Raises:
        ValueError: if ``model.cp > 1`` and the training or validation
            ``micro_batch_size`` is not 1.
    """
    if self.model.cp > 1:
        if self.data.micro_batch_size != 1:
            raise ValueError("Micro batch size must be 1 when CP is enabled")
        if self.val is not None and self.val.data.micro_batch_size != 1:
            raise ValueError("Validation micro batch size must be 1 when CP is enabled")
    return self

@model_validator(mode="after")
def validate_seq_len(self):
    """Require seq_len to be a multiple of 256 wherever the 'stack' pack function is used.

    The training data and (when configured) the validation data are checked
    independently: each one is only constrained if its own pack function is
    ``"stack"``.

    Returns:
        ``self``, unchanged.

    Raises:
        ValueError: if a dataset using ``"stack"`` has a ``seq_len`` that is
            not divisible by 256.
    """
    if self.data.pack_function == "stack" and self.data.seq_len % 256 != 0:
        raise ValueError("The sequence length must be divisible by 256 when using pack function stack")
    if self.val is not None and self.val.data.pack_function == "stack" and self.val.data.seq_len % 256 != 0:
        raise ValueError("The validation sequence length must be divisible by 256 when using pack function stack")
    return self

@model_validator(mode="after")
Expand Down
87 changes: 50 additions & 37 deletions src/prime_rl/trainer/sft/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers.tokenization_utils import PreTrainedTokenizer

from prime_rl.configs.sft import DataConfig, LossMaskConfig
from prime_rl.configs.sft import DataConfig, LossMaskConfig, SFTDataConfig
from prime_rl.trainer.world import get_world
from prime_rl.utils.logger import get_logger

Expand Down Expand Up @@ -541,54 +541,67 @@ def setup_and_interleave_datasets(
return dataset


def setup_dataset(tokenizer: PreTrainedTokenizer, config: DataConfig, non_dp_size: int = 1) -> StatefulIterableDataset:
def load_sft_dataset(config: SFTDataConfig) -> Dataset:
    """Load and interleave the raw HF dataset described by ``config``.

    This is the expensive I/O step. The (subset, split) pairs to load are
    derived from ``config.subsets`` and ``config.splits``:

    - neither set: a single ``(None, "train")`` pair;
    - only subsets: each subset with the default split ``"train"``;
    - only splits: each split with subset ``None``;
    - both set: subsets and splits zipped pairwise.
    """
    logger = get_logger()
    subsets, splits = config.subsets, config.splits

    if subsets is None and splits is None:
        pairs = [(None, "train")]
    elif splits is None:
        logger.debug(f"Loading datasets for subsets {config.subsets} with default split 'train'")
        pairs = [(subset, "train") for subset in subsets]
    elif subsets is None:
        logger.debug(f"Loading datasets for splits {config.splits} with default subset 'None'")
        pairs = [(None, split) for split in splits]
    else:
        logger.debug(f"Loading datasets for subsets {config.subsets} with splits {config.splits}")
        pairs = list(zip(subsets, splits))

    # Every case funnels into the same loader; only the pair list differs.
    return setup_and_interleave_datasets(
        dataset_name=config.name,
        subsets_and_splits=pairs,
        probabilities=config.probabilities,
        stopping_strategy=config.stopping_strategy,
    )


def setup_dataset(
    tokenizer: PreTrainedTokenizer,
    config: DataConfig,
    non_dp_size: int = 1,
    *,
    max_epochs: int | None = None,
    raw_dataset: Dataset | None = None,
) -> StatefulIterableDataset:
    """Build the dataset selected by ``config.type``.

    Args:
        tokenizer: Tokenizer passed to the SFT dataset; for the fake dataset
            only its ``vocab_size`` is read.
        config: Data configuration; ``config.type`` selects ``"fake"`` or ``"sft"``.
        non_dp_size: Forwarded to ``SFTDataset``; not used for the fake dataset.
        max_epochs: Forwarded to ``SFTDataset``; not used for the fake dataset.
        raw_dataset: Already-loaded raw dataset. When given, the load step is
            skipped; otherwise the dataset is loaded via ``load_sft_dataset``.

    Returns:
        A ``FakeDataset`` or an ``SFTDataset``.

    Raises:
        ValueError: if ``config.type`` is neither ``"fake"`` nor ``"sft"``.
    """
    if config.type == "fake":
        # Shouldn't matter to handle non_dp_size if the dataset is random
        return FakeDataset(
            vocab_size=tokenizer.vocab_size, seq_len=config.seq_len, length=config.length, input_ids=config.input_ids
        )
    elif config.type == "sft":
        # Only pay for the load when the caller did not supply a dataset.
        if raw_dataset is None:
            raw_dataset = load_sft_dataset(config)
        return SFTDataset(
            raw_dataset,
            tokenizer,
            shuffle=config.shuffle,
            seed=config.seed,
            seq_len=config.seq_len,
            loss_mask_config=config.loss_mask,
            non_dp_size=non_dp_size,
            max_epochs=max_epochs,
        )
    else:
        raise ValueError(f"Invalid dataset type: {config.type}")
Expand Down
Loading
Loading