27 commits
29c4ab3
feat: add frozen-reference native rl training
PastaPastaPasta Mar 6, 2026
a7462ab
fix: align grpo rollout ratios with sampling temperature
PastaPastaPasta Mar 6, 2026
18a58a1
feat: add internal rl runtime foundation
PastaPastaPasta Mar 6, 2026
e422b8b
fix: stabilize grpo kl and reward payloads
PastaPastaPasta Mar 6, 2026
a36beea
feat: add rl role builders and checkpoints
PastaPastaPasta Mar 6, 2026
c4e2f19
fix: restore scalar role bundles faithfully
PastaPastaPasta Mar 6, 2026
1190a32
feat: add reward and online rl trainers
PastaPastaPasta Mar 6, 2026
c6002c2
fix: correct ppo kl and online dpo reward context
PastaPastaPasta Mar 6, 2026
b27356a
feat: add parity-focused rl api surface
PastaPastaPasta Mar 6, 2026
fa0e3f9
fix: honor reward source composition semantics
PastaPastaPasta Mar 6, 2026
d5cb341
feat: harden on-policy rl training runtime
PastaPastaPasta Mar 6, 2026
f39ab07
fix: validate rl resume cache and config drift
PastaPastaPasta Mar 6, 2026
fa6b58c
feat: add trl rl compatibility patch layer
PastaPastaPasta Mar 6, 2026
524b998
fix: align grpo family variant semantics
PastaPastaPasta Mar 6, 2026
1aa7de8
fix: activate on-policy parity control knobs
PastaPastaPasta Mar 6, 2026
87a476f
fix: enforce rl rollout length caps
PastaPastaPasta Mar 6, 2026
7b30e63
feat: add qwen3 arithmetic grpo validation benchmark
PastaPastaPasta Mar 6, 2026
aa1cda8
fix: stop grpo rollouts on chat end tokens
PastaPastaPasta Mar 7, 2026
0a2d39b
fix: expose cache-aware model wrapper API
PastaPastaPasta Mar 7, 2026
7443495
feat: log grpo subphase timings
PastaPastaPasta Mar 8, 2026
38d3473
fix: initialize mlx-lm prompt cache for grpo rollouts
PastaPastaPasta Mar 8, 2026
1fd1d86
fix: rebuild grpo prompt cache when rollouts finish early
PastaPastaPasta Mar 8, 2026
ef14f1a
fix: preserve prompt context in grpo rollouts
PastaPastaPasta Mar 8, 2026
fbdfcf5
fix: stabilize grpo reference drift metrics
PastaPastaPasta Mar 8, 2026
9766cae
fix: batch grpo categorical sampling
PastaPastaPasta Mar 8, 2026
42b860e
fix: isolate rollout caches per sample
PastaPastaPasta Mar 8, 2026
d48c738
fix: normalize logged grpo kl by completion length
PastaPastaPasta Mar 8, 2026
33 changes: 22 additions & 11 deletions examples/09_rl_training_methods.py
@@ -17,8 +17,9 @@
ORPOTrainer, ORPOConfig,
GRPOTrainer, GRPOConfig,
# Utilities
prepare_preference_dataset,
prepare_rl_dataset,
create_reward_function,
resume_from_checkpoint,
)


@@ -60,6 +61,8 @@ def demo_dpo_training():
},
]

prepared_dataset = prepare_rl_dataset(preference_data, mode="preference", tokenizer=tokenizer)

# Configure DPO
config = DPOConfig(
beta=0.1, # KL penalty coefficient
@@ -72,7 +75,7 @@
# Create trainer
trainer = DPOTrainer(
model=model,
train_dataset=preference_data,
train_dataset=prepared_dataset,
tokenizer=tokenizer,
args=config,
)
@@ -122,12 +125,19 @@ def demo_grpo_training():
},
]

# Create a math reward function
math_reward = create_reward_function("math")
prepared_dataset = prepare_rl_dataset(reasoning_data, mode="prompt", tokenizer=tokenizer)

# Compose rewards through the public RL API surface.
math_reward = create_reward_function(
rewards=[
{"name": "math", "source": "math", "weight": 1.0},
{"name": "length", "source": "length", "weight": 0.1},
]
)

# Configure GRPO
config = GRPOConfig(
loss_type="grpo", # grpo, dr_grpo, dapo, bnpo
loss_type="grpo", # Phase 1 accepts grpo/dr_grpo/dapo/bnpo via one shared objective
beta=0.04,
num_generations=4, # Multiple generations per prompt
temperature=0.7,
@@ -139,7 +149,7 @@
# Create trainer with custom reward function
trainer = GRPOTrainer(
model=model,
train_dataset=reasoning_data,
train_dataset=prepared_dataset,
tokenizer=tokenizer,
reward_fn=math_reward, # Custom reward!
args=config,
@@ -153,6 +163,7 @@

# Would train with: trainer.train()
print("\nTo train: trainer.train()")
print(f"To inspect a saved checkpoint first: {resume_from_checkpoint.__name__}('./grpo_output')")


def demo_orpo_training():
@@ -221,14 +232,14 @@ def show_available_trainers():
print(f"| {name} | {method} | {use_case} |")

print("\n" + "=" * 70)
print("GRPO Loss Types (for reasoning models)")
print("GRPO Loss Types (accepted in Phase 1)")
print("=" * 70)

grpo_types = [
("grpo", "Standard GRPO", "Default for reasoning"),
("dr_grpo", "Dr. GRPO", "Distilled version"),
("dapo", "DAPO", "Data-efficient variant"),
("bnpo", "BNPO", "Batch-normalized variant"),
("grpo", "Standard GRPO", "Primary Phase 1 name"),
("dr_grpo", "Dr. GRPO", "Accepted alias; shared Phase 1 objective"),
("dapo", "DAPO", "Accepted alias; shared Phase 1 objective"),
("bnpo", "BNPO", "Accepted alias; shared Phase 1 objective"),
]

print("\n| Loss Type | Name | Description |")
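For reviewers who want the end-to-end shape of the updated example, here is a minimal sketch of the GRPO flow it now demonstrates. The names (prepare_rl_dataset, create_reward_function, GRPOTrainer, GRPOConfig, resume_from_checkpoint) are taken from this diff; the model id, dataset record schema, and reward weights are illustrative assumptions, not verified defaults.

# Sketch only: mirrors the calls added in examples/09_rl_training_methods.py.
from mlx_tune import (
    FastLanguageModel, GRPOTrainer, GRPOConfig,
    prepare_rl_dataset, create_reward_function, resume_from_checkpoint,
)

model, tokenizer = FastLanguageModel.from_pretrained("Qwen/Qwen3-0.6B")  # hypothetical model id
reasoning_data = [{"prompt": "What is 12 * 7?", "answer": "84"}]  # illustrative prompt-mode record
prepared_dataset = prepare_rl_dataset(reasoning_data, mode="prompt", tokenizer=tokenizer)

# Composite reward, as in the example: math correctness plus a small length shaping term.
math_reward = create_reward_function(
    rewards=[
        {"name": "math", "source": "math", "weight": 1.0},
        {"name": "length", "source": "length", "weight": 0.1},
    ]
)

trainer = GRPOTrainer(
    model=model,
    train_dataset=prepared_dataset,
    tokenizer=tokenizer,
    reward_fn=math_reward,
    args=GRPOConfig(loss_type="grpo", beta=0.04, num_generations=4, temperature=0.7),
)
trainer.train()

bundle = resume_from_checkpoint("./grpo_output")  # inspect or resume a saved run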
10 changes: 10 additions & 0 deletions examples/10_qwen3_arithmetic_grpo_validation.py
@@ -0,0 +1,10 @@
from pathlib import Path
import sys

sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

from mlx_tune.arithmetic_grpo_validation import main


if __name__ == "__main__":
raise SystemExit(main())
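The new example is intentionally a thin launcher: the sys.path shim lets it run straight from a source checkout without installing the package. When mlx_tune is already importable, the shim can be dropped and the entry point called directly — a sketch, assuming main() returns a process exit code as the launcher implies:

# Sketch: run the arithmetic GRPO validation benchmark from an installed package.
from mlx_tune.arithmetic_grpo_validation import main

exit_code = main()  # assumed to return an int suitable for SystemExit
print("validation benchmark exited with", exit_code)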
76 changes: 74 additions & 2 deletions mlx_tune/__init__.py
@@ -15,7 +15,15 @@

__version__ = "0.4.0" # Renamed to mlx-tune (formerly unsloth-mlx)

from mlx_tune.model import FastLanguageModel
from mlx_tune.model import (
FastLanguageModel,
ReferencePolicy,
RLModelRoles,
RewardModel,
ValueModel,
build_value_model,
create_rl_model_roles,
)
from mlx_tune.trainer import (
prepare_dataset,
format_chat_template,
@@ -25,25 +33,44 @@
get_training_config,
)
from mlx_tune.sft_trainer import SFTTrainer, SFTConfig, TrainingArguments
from mlx_tune.rl_api import (
RLCheckpointBundle,
PreparedRLDataset,
prepare_rl_dataset,
build_reference_policy,
build_reward_model,
create_reward_function,
resume_from_checkpoint,
)

# RL Trainers
from mlx_tune.rl_trainers import (
RewardTrainer,
RewardConfig,
DPOTrainer,
DPOConfig,
ORPOTrainer,
ORPOConfig,
GRPOTrainer,
GRPOConfig,
PPOTrainer,
PPOConfig,
OnlineDPOTrainer,
OnlineDPOConfig,
KTOConfig,
SimPOConfig,
KTOTrainer,
SimPOTrainer,
prepare_reward_dataset,
prepare_preference_dataset,
create_reward_function,
score_reward_model,
)

# Loss functions for custom training
from mlx_tune.losses import (
compute_log_probs,
compute_log_probs_with_lengths,
compute_completion_log_probs,
dpo_loss,
orpo_loss,
kto_loss,
@@ -52,6 +79,16 @@
grpo_loss,
grpo_batch_loss,
compute_reference_logprobs,
pairwise_reward_loss,
reward_model_pairwise_loss,
reward_model_regression_loss,
value_regression_loss,
value_model_regression_loss,
scalar_loss_metrics,
pairwise_ranking_accuracy,
precompute_preference_reference_logprobs,
precompute_kto_reference_logprobs,
ppo_sequence_loss,
)

# Vision Language Models
@@ -92,10 +129,24 @@
HFDatasetConfig,
load_dataset_with_config,
)
from mlx_tune.trl_compat import PatchFastRL

__all__ = [
# Core
"FastLanguageModel",
"ReferencePolicy",
"RLModelRoles",
"RewardModel",
"ValueModel",
"build_reference_policy",
"build_reward_model",
"build_value_model",
"create_rl_model_roles",
"PreparedRLDataset",
"RLCheckpointBundle",
"prepare_rl_dataset",
"resume_from_checkpoint",
"PatchFastRL",
"__version__",
# SFT Training
"SFTTrainer",
@@ -108,6 +159,14 @@
"ORPOConfig",
"GRPOTrainer",
"GRPOConfig",
"RewardTrainer",
"RewardConfig",
"PPOTrainer",
"PPOConfig",
"OnlineDPOTrainer",
"OnlineDPOConfig",
"KTOConfig",
"SimPOConfig",
"KTOTrainer",
"SimPOTrainer",
# Vision Models
@@ -116,6 +175,7 @@
# Loss Functions
"compute_log_probs",
"compute_log_probs_with_lengths",
"compute_completion_log_probs",
"dpo_loss",
"orpo_loss",
"kto_loss",
@@ -124,15 +184,27 @@
"grpo_loss",
"grpo_batch_loss",
"compute_reference_logprobs",
"pairwise_reward_loss",
"reward_model_pairwise_loss",
"reward_model_regression_loss",
"value_regression_loss",
"value_model_regression_loss",
"scalar_loss_metrics",
"pairwise_ranking_accuracy",
"precompute_preference_reference_logprobs",
"precompute_kto_reference_logprobs",
"ppo_sequence_loss",
# Utilities
"prepare_dataset",
"prepare_reward_dataset",
"prepare_preference_dataset",
"format_chat_template",
"create_training_data",
"save_model_hf_format",
"export_to_gguf",
"get_training_config",
"create_reward_function",
"score_reward_model",
"load_vlm_dataset",
# Chat Templates and Dataset Formatting
"detect_dataset_format",
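Since most of this change is new public surface, a quick smoke test can confirm the re-exports resolve from the package root. This only checks that the names exist — all of them appear in the __all__ list above; none of the behavior is exercised.

# Sketch: verify the newly exported RL symbols are reachable from the package root.
import mlx_tune

new_exports = [
    "PPOTrainer", "PPOConfig", "OnlineDPOTrainer", "OnlineDPOConfig",
    "RewardTrainer", "RewardConfig",
    "prepare_rl_dataset", "create_reward_function", "resume_from_checkpoint",
    "build_reference_policy", "build_reward_model", "create_rl_model_roles",
    "PatchFastRL", "ppo_sequence_loss",
]
missing = [name for name in new_exports if not hasattr(mlx_tune, name)]
assert not missing, f"missing exports: {missing}"
print(f"mlx_tune {mlx_tune.__version__}: all {len(new_exports)} new RL exports present")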