Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ dependencies = [
"torchvision",
"torchaudio",
"torchdata>=0.11.0",
"transformers",
"transformers==5.6.2",
"vllm>=0.22.0",
"mooncake-transfer-engine>=0.3.10.post2",
"wandb>=0.26.1",
Expand Down Expand Up @@ -128,7 +128,6 @@ dev = [
"ruff>=0.12.1",
]


[tool.uv]
# Enforce a uv version that supports the friendly-duration form
# (`"7 days"`) in the static pyproject parser. Older uvs silently parse
Expand All @@ -147,7 +146,7 @@ environments = [
override-dependencies = [
"nvidia-cudnn-cu12>=9.15",
"nvidia-cutlass-dsl>=4.4.1",
"transformers>=5.1.0.dev0",
"transformers==5.6.2",
"torch>=2.9.0",
"openenv-core",
]
Expand Down Expand Up @@ -231,7 +230,6 @@ torchvision = { index = "pytorch-cu128" }
torchaudio = { index = "pytorch-cu128" }
torchtitan = { git = "https://github.com/pytorch/torchtitan", rev = "a1fdd7e" }
dion = { git = "https://github.com/samsja/dion.git", rev = "d891eeb" }
transformers = { git = "https://github.com/huggingface/transformers.git", rev = "c1c3424" }
flash-attn-4 = { git = "https://github.com/Dao-AILab/flash-attention.git", subdirectory = "flash_attn/cute", rev = "96bd151" }
vllm-router = { url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.26/vllm_router-0.1.26-cp38-abi3-manylinux_2_28_x86_64.whl" }
vllm = [
Expand Down
6 changes: 6 additions & 0 deletions skills/training/start-run/SKILL.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ uv run rl @ examples/reverse_text/rl.toml --dry-run
- Config: `RLConfig` (`packages/prime-rl-configs/src/prime_rl/configs/rl.py`)
- Entrypoint: `src/prime_rl/entrypoints/rl.py`
- SLURM: single- and multi-node
- Environment packages: before launching a config with a non-core verifier env id,
verify the package imports under `uv run` (for example
`uv run python -c "import importlib.util; print(importlib.util.find_spec('rlm_swe'))"`).
If a local env exists under `deps/research-environments/environments/` but does not
import, add it to the root `pyproject.toml` env extra, workspace members, and
`[tool.uv.sources]`, then run `uv sync --all-extras`.

## `sft` — SFT training

Expand Down
1 change: 0 additions & 1 deletion src/prime_rl/trainer/ckpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,6 @@ def save(
f"Converted PrimeRL format to HF format in {time.perf_counter() - start_time:.2f} seconds"
)
else:
# For regular transformers models, revert internal format to original HF hub format
from transformers.core_model_loading import revert_weight_conversion

self.logger.debug("Reverting transformers internal format to HF hub format for weight checkpoint")
Expand Down
4 changes: 2 additions & 2 deletions src/prime_rl/trainer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def _patch_qwen3_5_moe_conversion_mapping():
incorrectly maps qwen3_5_moe → qwen2_moe, which assumes per-expert 2D checkpoint weights,
causing revert_weight_conversion to produce wrong shapes during weight broadcasting.

Remove once the pinned transformers commit fixes this.
Remove once an official Transformers release fixes this.
"""
from transformers.conversion_mapping import (
get_checkpoint_conversion_mapping,
Expand All @@ -99,7 +99,7 @@ def _patch_qwen3_5_text_position_ids():
"""Fix Qwen3.5 passing 3D MRoPE position_ids to decoder layers instead of 2D text_position_ids.

Upstream fix: https://github.com/huggingface/transformers/pull/44399
Remove once the pinned transformers commit includes this fix.
Remove once an official Transformers release includes this fix.
"""
import inspect

Expand Down
1 change: 0 additions & 1 deletion src/prime_rl/trainer/rl/broadcast/filesystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ def broadcast_weights(self, model: nn.Module, step: int) -> None:
if isinstance(model, PreTrainedModelPrimeRL) and model.is_prime_state_dict(state_dict):
model.convert_to_hf(state_dict)
else:
# For regular transformers models, revert internal format to original HF hub format
from transformers.core_model_loading import revert_weight_conversion

state_dict = revert_weight_conversion(model, state_dict)
Expand Down
Loading
Loading