diff --git a/examples/09_rl_training_methods.py b/examples/09_rl_training_methods.py
index 2904d88..66beaae 100644
--- a/examples/09_rl_training_methods.py
+++ b/examples/09_rl_training_methods.py
@@ -17,8 +17,9 @@
ORPOTrainer, ORPOConfig,
GRPOTrainer, GRPOConfig,
# Utilities
- prepare_preference_dataset,
+ prepare_rl_dataset,
create_reward_function,
+ resume_from_checkpoint,
)
@@ -60,6 +61,8 @@ def demo_dpo_training():
},
]
+ prepared_dataset = prepare_rl_dataset(preference_data, mode="preference", tokenizer=tokenizer)
+
# Configure DPO
config = DPOConfig(
beta=0.1, # KL penalty coefficient
@@ -72,7 +75,7 @@ def demo_dpo_training():
# Create trainer
trainer = DPOTrainer(
model=model,
- train_dataset=preference_data,
+ train_dataset=prepared_dataset,
tokenizer=tokenizer,
args=config,
)
@@ -122,12 +125,19 @@ def demo_grpo_training():
},
]
- # Create a math reward function
- math_reward = create_reward_function("math")
+ prepared_dataset = prepare_rl_dataset(reasoning_data, mode="prompt", tokenizer=tokenizer)
+
+ # Compose rewards through the public RL API surface.
+ math_reward = create_reward_function(
+ rewards=[
+ {"name": "math", "source": "math", "weight": 1.0},
+ {"name": "length", "source": "length", "weight": 0.1},
+ ]
+ )
# Configure GRPO
config = GRPOConfig(
- loss_type="grpo", # grpo, dr_grpo, dapo, bnpo
+ loss_type="grpo", # Phase 1 accepts grpo/dr_grpo/dapo/bnpo via one shared objective
beta=0.04,
num_generations=4, # Multiple generations per prompt
temperature=0.7,
@@ -139,7 +149,7 @@ def demo_grpo_training():
# Create trainer with custom reward function
trainer = GRPOTrainer(
model=model,
- train_dataset=reasoning_data,
+ train_dataset=prepared_dataset,
tokenizer=tokenizer,
reward_fn=math_reward, # Custom reward!
args=config,
@@ -153,6 +163,7 @@ def demo_grpo_training():
# Would train with: trainer.train()
print("\nTo train: trainer.train()")
+ print(f"To inspect a saved checkpoint first: {resume_from_checkpoint.__name__}('./grpo_output')")
def demo_orpo_training():
@@ -221,14 +232,14 @@ def show_available_trainers():
print(f"| {name} | {method} | {use_case} |")
print("\n" + "=" * 70)
- print("GRPO Loss Types (for reasoning models)")
+ print("GRPO Loss Types (accepted in Phase 1)")
print("=" * 70)
grpo_types = [
- ("grpo", "Standard GRPO", "Default for reasoning"),
- ("dr_grpo", "Dr. GRPO", "Distilled version"),
- ("dapo", "DAPO", "Data-efficient variant"),
- ("bnpo", "BNPO", "Batch-normalized variant"),
+ ("grpo", "Standard GRPO", "Primary Phase 1 name"),
+ ("dr_grpo", "Dr. GRPO", "Accepted alias; shared Phase 1 objective"),
+ ("dapo", "DAPO", "Accepted alias; shared Phase 1 objective"),
+ ("bnpo", "BNPO", "Accepted alias; shared Phase 1 objective"),
]
print("\n| Loss Type | Name | Description |")
diff --git a/examples/10_qwen3_arithmetic_grpo_validation.py b/examples/10_qwen3_arithmetic_grpo_validation.py
new file mode 100644
index 0000000..3d2b67f
--- /dev/null
+++ b/examples/10_qwen3_arithmetic_grpo_validation.py
@@ -0,0 +1,10 @@
+from pathlib import Path
+import sys
+
+sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
+
+from mlx_tune.arithmetic_grpo_validation import main
+
+
+if __name__ == "__main__":
+ raise SystemExit(main())
diff --git a/mlx_tune/__init__.py b/mlx_tune/__init__.py
index 81e140f..4ccea8a 100644
--- a/mlx_tune/__init__.py
+++ b/mlx_tune/__init__.py
@@ -15,7 +15,15 @@
__version__ = "0.4.0" # Renamed to mlx-tune (formerly unsloth-mlx)
-from mlx_tune.model import FastLanguageModel
+from mlx_tune.model import (
+ FastLanguageModel,
+ ReferencePolicy,
+ RLModelRoles,
+ RewardModel,
+ ValueModel,
+ build_value_model,
+ create_rl_model_roles,
+)
from mlx_tune.trainer import (
prepare_dataset,
format_chat_template,
@@ -25,25 +33,44 @@
get_training_config,
)
from mlx_tune.sft_trainer import SFTTrainer, SFTConfig, TrainingArguments
+from mlx_tune.rl_api import (
+ RLCheckpointBundle,
+ PreparedRLDataset,
+ prepare_rl_dataset,
+ build_reference_policy,
+ build_reward_model,
+ create_reward_function,
+ resume_from_checkpoint,
+)
# RL Trainers
from mlx_tune.rl_trainers import (
+ RewardTrainer,
+ RewardConfig,
DPOTrainer,
DPOConfig,
ORPOTrainer,
ORPOConfig,
GRPOTrainer,
GRPOConfig,
+ PPOTrainer,
+ PPOConfig,
+ OnlineDPOTrainer,
+ OnlineDPOConfig,
+ KTOConfig,
+ SimPOConfig,
KTOTrainer,
SimPOTrainer,
+ prepare_reward_dataset,
prepare_preference_dataset,
- create_reward_function,
+ score_reward_model,
)
# Loss functions for custom training
from mlx_tune.losses import (
compute_log_probs,
compute_log_probs_with_lengths,
+ compute_completion_log_probs,
dpo_loss,
orpo_loss,
kto_loss,
@@ -52,6 +79,16 @@
grpo_loss,
grpo_batch_loss,
compute_reference_logprobs,
+ pairwise_reward_loss,
+ reward_model_pairwise_loss,
+ reward_model_regression_loss,
+ value_regression_loss,
+ value_model_regression_loss,
+ scalar_loss_metrics,
+ pairwise_ranking_accuracy,
+ precompute_preference_reference_logprobs,
+ precompute_kto_reference_logprobs,
+ ppo_sequence_loss,
)
# Vision Language Models
@@ -92,10 +129,24 @@
HFDatasetConfig,
load_dataset_with_config,
)
+from mlx_tune.trl_compat import PatchFastRL
__all__ = [
# Core
"FastLanguageModel",
+ "ReferencePolicy",
+ "RLModelRoles",
+ "RewardModel",
+ "ValueModel",
+ "build_reference_policy",
+ "build_reward_model",
+ "build_value_model",
+ "create_rl_model_roles",
+ "PreparedRLDataset",
+ "RLCheckpointBundle",
+ "prepare_rl_dataset",
+ "resume_from_checkpoint",
+ "PatchFastRL",
"__version__",
# SFT Training
"SFTTrainer",
@@ -108,6 +159,14 @@
"ORPOConfig",
"GRPOTrainer",
"GRPOConfig",
+ "RewardTrainer",
+ "RewardConfig",
+ "PPOTrainer",
+ "PPOConfig",
+ "OnlineDPOTrainer",
+ "OnlineDPOConfig",
+ "KTOConfig",
+ "SimPOConfig",
"KTOTrainer",
"SimPOTrainer",
# Vision Models
@@ -116,6 +175,7 @@
# Loss Functions
"compute_log_probs",
"compute_log_probs_with_lengths",
+ "compute_completion_log_probs",
"dpo_loss",
"orpo_loss",
"kto_loss",
@@ -124,8 +184,19 @@
"grpo_loss",
"grpo_batch_loss",
"compute_reference_logprobs",
+ "pairwise_reward_loss",
+ "reward_model_pairwise_loss",
+ "reward_model_regression_loss",
+ "value_regression_loss",
+ "value_model_regression_loss",
+ "scalar_loss_metrics",
+ "pairwise_ranking_accuracy",
+ "precompute_preference_reference_logprobs",
+ "precompute_kto_reference_logprobs",
+ "ppo_sequence_loss",
# Utilities
"prepare_dataset",
+ "prepare_reward_dataset",
"prepare_preference_dataset",
"format_chat_template",
"create_training_data",
@@ -133,6 +204,7 @@
"export_to_gguf",
"get_training_config",
"create_reward_function",
+ "score_reward_model",
"load_vlm_dataset",
# Chat Templates and Dataset Formatting
"detect_dataset_format",
diff --git a/mlx_tune/_rl_runtime.py b/mlx_tune/_rl_runtime.py
new file mode 100644
index 0000000..b236e1e
--- /dev/null
+++ b/mlx_tune/_rl_runtime.py
@@ -0,0 +1,1237 @@
+from dataclasses import dataclass, fields, replace
+import inspect
+from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple
+
+import mlx.core as mx
+import mlx.nn as nn
+
+
+@dataclass
+class PolicyEvalBatch:
+ input_ids: mx.array
+ sequence_lengths: mx.array
+ token_mask: mx.array
+ prompt_lengths: Optional[mx.array] = None
+ completion_lengths: Optional[mx.array] = None
+ rollout_logprobs: Optional[mx.array] = None
+ old_logprobs: Optional[mx.array] = None
+ old_token_logprobs: Optional[mx.array] = None
+ reference_logprobs: Optional[mx.array] = None
+ value_predictions: Optional[mx.array] = None
+ returns: Optional[mx.array] = None
+ advantages: Optional[mx.array] = None
+ labels: Optional[mx.array] = None
+ prompt_group_indices: Optional[mx.array] = None
+ sample_indices: Optional[mx.array] = None
+ token_logprobs: Optional[mx.array] = None
+ summed_logprobs: Optional[mx.array] = None
+
+
+@dataclass
+class RolloutBatch:
+ prompt_ids: List[List[int]]
+ prompt_lengths: mx.array
+ completion_ids: List[List[int]]
+ completion_lengths: mx.array
+ prompt_texts: List[str]
+ original_prompt_texts: Optional[List[str]]
+ completion_texts: List[str]
+ reward_contexts: List[Any]
+ sampled_token_logprobs: mx.array
+ rollout_logprobs: mx.array
+ eos_flags: mx.array
+ truncation_flags: mx.array
+ prompt_group_indices: mx.array
+ policy_eval: PolicyEvalBatch
+ sample_indices: Optional[mx.array] = None
+ sampled_token_logits: Optional[mx.array] = None
+ token_entropies: Optional[mx.array] = None
+ old_logprobs: Optional[mx.array] = None
+ rewards: Optional[mx.array] = None
+ reference_logprobs: Optional[mx.array] = None
+ value_predictions: Optional[mx.array] = None
+ returns: Optional[mx.array] = None
+ advantages: Optional[mx.array] = None
+
+
+@dataclass
+class RewardBatch:
+ prompt_texts: List[str]
+ completion_texts: List[str]
+ reward_contexts: List[Any]
+ scalar_rewards: mx.array
+ prompt_group_indices: mx.array
+ original_prompt_texts: Optional[List[str]] = None
+ named_reward_components: Optional[List[Dict[str, float]]] = None
+ diagnostics: Optional[List[Dict[str, Any]]] = None
+
+
+@dataclass
+class PreferenceBatch:
+ chosen: PolicyEvalBatch
+ rejected: PolicyEvalBatch
+ sample_indices: mx.array
+ chosen_reference_logprobs: Optional[mx.array] = None
+ rejected_reference_logprobs: Optional[mx.array] = None
+
+
+def pad_sequences(sequences: Sequence[Sequence[int]], pad_id: int) -> tuple[mx.array, mx.array]:
+ max_length = max(len(sequence) for sequence in sequences)
+ padded = [list(sequence) + [pad_id] * (max_length - len(sequence)) for sequence in sequences]
+ lengths = [len(sequence) for sequence in sequences]
+ return mx.array(padded), mx.array(lengths)
+
+
+def truncate_prompt_tokens(prompt_ids: Sequence[int], max_prompt_length: Optional[int]) -> List[int]:
+ prompt_tokens = list(prompt_ids)
+ if max_prompt_length is None or len(prompt_tokens) <= max_prompt_length:
+ return prompt_tokens
+ return prompt_tokens[-max_prompt_length:]
+
+
+def truncate_completion_tokens(
+ completion_ids: Sequence[int],
+ max_completion_length: Optional[int],
+) -> tuple[List[int], bool]:
+ completion_tokens = list(completion_ids)
+ if max_completion_length is None or len(completion_tokens) <= max_completion_length:
+ return completion_tokens, False
+ return completion_tokens[:max_completion_length], True
+
+
+def cap_prompt_and_completion_lengths(
+ prompt_ids: Sequence[int],
+ completion_ids: Sequence[int],
+ max_seq_length: Optional[int],
+ max_completion_length: Optional[int],
+) -> tuple[List[int], List[int], bool]:
+ effective_completion_cap = max_completion_length
+ if max_seq_length is not None:
+ effective_completion_cap = (
+ int(max_seq_length)
+ if effective_completion_cap is None
+ else min(int(effective_completion_cap), int(max_seq_length))
+ )
+ capped_completion_ids, truncated_completion = truncate_completion_tokens(
+ completion_ids,
+ effective_completion_cap,
+ )
+
+ capped_prompt_ids = list(prompt_ids)
+ if max_seq_length is not None:
+ remaining_prompt_budget = max(0, int(max_seq_length) - len(capped_completion_ids))
+ capped_prompt_ids = truncate_prompt_tokens(prompt_ids, remaining_prompt_budget)
+ return capped_prompt_ids, capped_completion_ids, truncated_completion
+
+
+def _resolve_terminal_token_ids(tokenizer: Any) -> set[int]:
+ terminal_ids: set[int] = set()
+
+ eos_token_id = getattr(tokenizer, "eos_token_id", None)
+ if eos_token_id is not None:
+ try:
+ terminal_ids.add(int(eos_token_id))
+ except (TypeError, ValueError):
+ pass
+
+ stop_token = getattr(tokenizer, "_unsloth_stop_token", None)
+ if not stop_token:
+ return terminal_ids
+
+ candidate_ids: list[Any] = []
+ if hasattr(tokenizer, "convert_tokens_to_ids"):
+ try:
+ candidate_ids.append(tokenizer.convert_tokens_to_ids(stop_token))
+ except Exception:
+ pass
+ if hasattr(tokenizer, "get_vocab"):
+ try:
+ vocab = tokenizer.get_vocab()
+ if stop_token in vocab:
+ candidate_ids.append(vocab[stop_token])
+ except Exception:
+ pass
+ if hasattr(tokenizer, "encode"):
+ try:
+ encoded = tokenizer.encode(stop_token, add_special_tokens=False)
+ if len(encoded) == 1:
+ candidate_ids.append(encoded[0])
+ except Exception:
+ pass
+
+ for token_id in candidate_ids:
+ if token_id is None:
+ continue
+ try:
+ terminal_ids.add(int(token_id))
+ except (TypeError, ValueError):
+ continue
+ return terminal_ids
+
+
+def length_mask(lengths: mx.array, width: int) -> mx.array:
+ positions = mx.arange(width)[None, :]
+ return positions < lengths[:, None]
+
+
+def completion_token_mask(
+ input_ids: mx.array,
+ prompt_lengths: mx.array,
+ completion_lengths: mx.array,
+) -> mx.array:
+ width = input_ids.shape[1] - 1
+ positions = mx.arange(width)[None, :]
+ start = mx.maximum(prompt_lengths - 1, 0)[:, None]
+ end = mx.maximum(prompt_lengths + completion_lengths - 1, 0)[:, None]
+ return (positions >= start) & (positions < end)
+
+
+def build_token_mask(
+ input_ids: mx.array,
+ sequence_lengths: mx.array,
+ mode: str = "sequence",
+ prompt_lengths: Optional[mx.array] = None,
+ completion_lengths: Optional[mx.array] = None,
+) -> mx.array:
+ if mode == "sequence":
+ return length_mask(sequence_lengths, input_ids.shape[1] - 1)
+ if mode == "completion":
+ if prompt_lengths is None or completion_lengths is None:
+ raise ValueError("Completion scoring requires prompt_lengths and completion_lengths.")
+ return completion_token_mask(input_ids, prompt_lengths, completion_lengths)
+ raise ValueError(f"Unsupported scoring mode: {mode}")
+
+
+def make_policy_eval_batch(
+ sequences: Sequence[Sequence[int]],
+ pad_id: int,
+ mode: str = "sequence",
+ prompt_lengths: Optional[Sequence[int]] = None,
+ completion_lengths: Optional[Sequence[int]] = None,
+ rollout_logprobs: Optional[mx.array] = None,
+ old_logprobs: Optional[mx.array] = None,
+ old_token_logprobs: Optional[mx.array] = None,
+ reference_logprobs: Optional[mx.array] = None,
+ value_predictions: Optional[mx.array] = None,
+ returns: Optional[mx.array] = None,
+ advantages: Optional[mx.array] = None,
+ labels: Optional[mx.array] = None,
+ prompt_group_indices: Optional[mx.array] = None,
+ sample_indices: Optional[mx.array] = None,
+) -> PolicyEvalBatch:
+ input_ids, sequence_lengths = pad_sequences(sequences, pad_id)
+ prompt_lengths_array = mx.array(prompt_lengths) if prompt_lengths is not None else None
+ completion_lengths_array = mx.array(completion_lengths) if completion_lengths is not None else None
+ token_mask = build_token_mask(
+ input_ids=input_ids,
+ sequence_lengths=sequence_lengths,
+ mode=mode,
+ prompt_lengths=prompt_lengths_array,
+ completion_lengths=completion_lengths_array,
+ )
+ return PolicyEvalBatch(
+ input_ids=input_ids,
+ sequence_lengths=sequence_lengths,
+ prompt_lengths=prompt_lengths_array,
+ completion_lengths=completion_lengths_array,
+ token_mask=token_mask,
+ rollout_logprobs=rollout_logprobs,
+ old_logprobs=old_logprobs if old_logprobs is not None else rollout_logprobs,
+ old_token_logprobs=old_token_logprobs,
+ reference_logprobs=reference_logprobs,
+ value_predictions=value_predictions,
+ returns=returns,
+ advantages=advantages,
+ labels=labels,
+ prompt_group_indices=prompt_group_indices,
+ sample_indices=sample_indices,
+ )
+
+
+def make_preference_batch(
+ chosen_sequences: Sequence[Sequence[int]],
+ rejected_sequences: Sequence[Sequence[int]],
+ pad_id: int,
+ sample_indices: Sequence[int],
+ chosen_reference_logprobs: Optional[mx.array] = None,
+ rejected_reference_logprobs: Optional[mx.array] = None,
+) -> PreferenceBatch:
+ return PreferenceBatch(
+ chosen=make_policy_eval_batch(
+ chosen_sequences,
+ pad_id=pad_id,
+ mode="sequence",
+ reference_logprobs=chosen_reference_logprobs,
+ sample_indices=mx.array(sample_indices),
+ ),
+ rejected=make_policy_eval_batch(
+ rejected_sequences,
+ pad_id=pad_id,
+ mode="sequence",
+ reference_logprobs=rejected_reference_logprobs,
+ sample_indices=mx.array(sample_indices),
+ ),
+ sample_indices=mx.array(sample_indices),
+ chosen_reference_logprobs=chosen_reference_logprobs,
+ rejected_reference_logprobs=rejected_reference_logprobs,
+ )
+
+
+def _token_log_probs(
+ model: Any,
+ input_ids: mx.array,
+ temperature: float = 1.0,
+) -> mx.array:
+ inputs = input_ids[:, :-1]
+ targets = input_ids[:, 1:]
+ logits = model(inputs)
+ if temperature != 1.0:
+ logits = logits / temperature
+ log_probs = nn.log_softmax(logits, axis=-1)
+ return mx.take_along_axis(log_probs, targets[:, :, None], axis=-1).squeeze(-1)
+
+
+def normalize_logprobs(
+ summed_logprobs: mx.array,
+ lengths: mx.array,
+ mode: str = "sum",
+) -> mx.array:
+ if mode == "sum":
+ return summed_logprobs
+ if mode in {"mean", "token_mean"}:
+ return summed_logprobs / mx.maximum(lengths.astype(summed_logprobs.dtype), 1.0)
+ raise ValueError(f"Unsupported length normalization mode: {mode}")
+
+
+def kl_against_reference(
+ policy_logprobs: mx.array,
+ reference_logprobs: mx.array,
+) -> mx.array:
+ log_ratio = policy_logprobs - reference_logprobs
+ return mx.exp(log_ratio) - log_ratio - 1.0
+
+
+def score_policy(
+ model: Any,
+ batch: PolicyEvalBatch,
+ mode: str = "sequence",
+ reference_model: Optional[Any] = None,
+ temperature: float = 1.0,
+) -> PolicyEvalBatch:
+ token_mask = build_token_mask(
+ input_ids=batch.input_ids,
+ sequence_lengths=batch.sequence_lengths,
+ mode=mode,
+ prompt_lengths=batch.prompt_lengths,
+ completion_lengths=batch.completion_lengths,
+ )
+ token_logprobs = _token_log_probs(model, batch.input_ids, temperature=temperature)
+ summed_logprobs = (token_logprobs * token_mask.astype(token_logprobs.dtype)).sum(axis=-1)
+
+ reference_logprobs = batch.reference_logprobs
+ if reference_model is not None:
+ reference_tokens = _token_log_probs(reference_model, batch.input_ids, temperature=temperature)
+ reference_logprobs = mx.stop_gradient(
+ (reference_tokens * token_mask.astype(reference_tokens.dtype)).sum(axis=-1)
+ )
+
+ return replace(
+ batch,
+ token_mask=token_mask,
+ token_logprobs=token_logprobs,
+ summed_logprobs=summed_logprobs,
+ reference_logprobs=reference_logprobs,
+ )
+
+
+def score_policy_in_chunks(
+ model: Any,
+ batch: PolicyEvalBatch,
+ batch_size: Optional[int],
+ mode: str = "sequence",
+ reference_model: Optional[Any] = None,
+ temperature: float = 1.0,
+ token_budget: Optional[int] = None,
+) -> PolicyEvalBatch:
+ if batch.input_ids.shape[0] == 0:
+ return score_policy(
+ model,
+ batch,
+ mode=mode,
+ reference_model=reference_model,
+ temperature=temperature,
+ )
+
+ if token_budget is None and batch_size is None:
+ batch_size = batch.input_ids.shape[0]
+
+ if token_budget is None and batch_size is not None and batch.input_ids.shape[0] <= batch_size:
+ return score_policy(
+ model,
+ batch,
+ mode=mode,
+ reference_model=reference_model,
+ temperature=temperature,
+ )
+
+ scored_chunks = []
+ for minibatch in assemble_minibatches(
+ batch,
+ minibatch_size=batch_size,
+ shuffle=False,
+ mode=mode,
+ token_budget=token_budget,
+ ):
+ scored_chunks.append(
+ score_policy(
+ model,
+ minibatch,
+ mode=mode,
+ reference_model=reference_model,
+ temperature=temperature,
+ )
+ )
+ return _concat_policy_eval_batches(scored_chunks)
+
+
+def _policy_eval_effective_lengths(batch: PolicyEvalBatch, mode: str) -> List[int]:
+ if batch.input_ids.shape[0] == 0:
+ return []
+ if mode == "completion" and batch.completion_lengths is not None:
+ return [max(1, int(value)) for value in batch.completion_lengths.tolist()]
+ return [max(1, int(value) - 1) for value in batch.sequence_lengths.tolist()]
+
+
+def _concat_policy_eval_batches(chunks: Sequence[PolicyEvalBatch]) -> PolicyEvalBatch:
+ def concat_attr(name: str):
+ values = [getattr(chunk, name) for chunk in chunks]
+ if values[0] is None:
+ return None
+ if hasattr(values[0], "shape"):
+ return mx.concatenate(values, axis=0)
+ if isinstance(values[0], list):
+ merged = []
+ for value in values:
+ merged.extend(value)
+ return merged
+ raise TypeError(f"Unsupported PolicyEvalBatch field type for concat: {name}")
+
+ return PolicyEvalBatch(
+ input_ids=concat_attr("input_ids"),
+ sequence_lengths=concat_attr("sequence_lengths"),
+ prompt_lengths=concat_attr("prompt_lengths"),
+ completion_lengths=concat_attr("completion_lengths"),
+ token_mask=concat_attr("token_mask"),
+ rollout_logprobs=concat_attr("rollout_logprobs"),
+ old_logprobs=concat_attr("old_logprobs"),
+ old_token_logprobs=concat_attr("old_token_logprobs"),
+ reference_logprobs=concat_attr("reference_logprobs"),
+ value_predictions=concat_attr("value_predictions"),
+ returns=concat_attr("returns"),
+ advantages=concat_attr("advantages"),
+ labels=concat_attr("labels"),
+ prompt_group_indices=concat_attr("prompt_group_indices"),
+ sample_indices=concat_attr("sample_indices"),
+ token_logprobs=concat_attr("token_logprobs"),
+ summed_logprobs=concat_attr("summed_logprobs"),
+ )
+
+
+def sample_completion(
+ policy: Any,
+ tokenizer: Any,
+ prompt_ids: Sequence[int],
+ max_tokens: int,
+ temperature: float,
+ collect_sample_stats: bool = False,
+) -> Dict[str, Any]:
+ generated_ids = list(prompt_ids)
+ sampled_logprobs: List[float] = []
+ sampled_logits: List[float] = []
+ token_entropies: List[float] = []
+ saw_eos = False
+ terminal_token_ids = _resolve_terminal_token_ids(tokenizer)
+ x = mx.array([generated_ids])
+
+ for _ in range(max_tokens):
+ logits = policy(x)[:, -1, :]
+ if temperature > 0:
+ scaled = logits / temperature
+ probs = mx.softmax(scaled, axis=-1)
+ next_token = mx.random.categorical(mx.log(probs + 1e-10))
+ log_probs = nn.log_softmax(scaled, axis=-1)
+ entropy = -mx.sum(probs * log_probs, axis=-1)[0]
+ else:
+ scaled = logits
+ next_token = mx.argmax(logits, axis=-1)
+ log_probs = nn.log_softmax(logits, axis=-1)
+ probs = mx.softmax(logits, axis=-1)
+ entropy = -mx.sum(probs * log_probs, axis=-1)[0]
+
+ next_token_id = int(next_token.item())
+ sampled_logprobs.append(float(log_probs[0, next_token_id].item()))
+ if collect_sample_stats:
+ sampled_logits.append(float(scaled[0, next_token_id].item()))
+ token_entropies.append(float(entropy.item()))
+
+ generated_ids.append(next_token_id)
+ x = mx.array([generated_ids])
+ if next_token_id in terminal_token_ids:
+ saw_eos = True
+ break
+
+ completion_ids = generated_ids[len(prompt_ids):]
+ return {
+ "generated_ids": generated_ids,
+ "completion_ids": completion_ids,
+ "sampled_logprobs": sampled_logprobs,
+ "sampled_logits": sampled_logits,
+ "token_entropies": token_entropies,
+ "eos_flag": saw_eos,
+ "truncation_flag": (not saw_eos) and len(completion_ids) >= max_tokens,
+ }
+
+
+class _RewardEvaluatorAdapter:
+ def __init__(self, evaluator: Any):
+ self.evaluator = evaluator
+ self.mode = self._resolve_mode(evaluator)
+
+ def _resolve_mode(self, evaluator: Any) -> str:
+ if evaluator is None:
+ return "none"
+ if hasattr(evaluator, "evaluate_batch"):
+ return "evaluate_batch"
+ if hasattr(evaluator, "evaluate"):
+ return "evaluate"
+ if not callable(evaluator):
+ raise TypeError("Reward evaluator must be callable or expose evaluate().")
+
+ try:
+ signature = inspect.signature(evaluator)
+ except (TypeError, ValueError):
+ return "legacy"
+
+ positional = 0
+ for parameter in signature.parameters.values():
+ if parameter.kind == inspect.Parameter.VAR_POSITIONAL:
+ return "legacy"
+ if (
+ parameter.kind in (
+ inspect.Parameter.POSITIONAL_ONLY,
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
+ )
+ and parameter.default is inspect._empty
+ ):
+ positional += 1
+ return "legacy" if positional >= 2 else "structured"
+
+ def evaluate(self, payload: Dict[str, Any]) -> tuple[float, Optional[Dict[str, float]], Optional[Dict[str, Any]]]:
+ if self.mode == "none":
+ result = 0.0
+ elif self.mode == "evaluate_batch":
+ result = self.evaluator.evaluate_batch([payload])[0]
+ elif self.mode == "evaluate":
+ result = self.evaluator.evaluate(payload)
+ elif self.mode == "structured":
+ result = self.evaluator(payload)
+ else:
+ result = self.evaluator(payload["completion_text"], payload["reward_context"])
+ return self._normalize_result(result)
+
+ def _normalize_result(
+ self,
+ result: Any,
+ ) -> tuple[float, Optional[Dict[str, float]], Optional[Dict[str, Any]]]:
+ if isinstance(result, Mapping):
+ reward = float(result.get("reward", result.get("score", 0.0)))
+ components = result.get("components")
+ diagnostics = result.get("diagnostics")
+ return reward, dict(components) if components is not None else None, diagnostics
+ return float(result), None, None
+
+ def evaluate_batch(
+ self,
+ payloads: Sequence[Dict[str, Any]],
+ ) -> List[tuple[float, Optional[Dict[str, float]], Optional[Dict[str, Any]]]]:
+ if self.mode == "none":
+ return [(0.0, None, None) for _ in payloads]
+ if hasattr(self.evaluator, "evaluate_batch"):
+ results = self.evaluator.evaluate_batch(list(payloads))
+ return [self._normalize_result(result) for result in results]
+ return [self.evaluate(payload) for payload in payloads]
+
+
+def _batched_last_token_logits(
+ policy: Any,
+ input_rows: Sequence[Sequence[int]],
+ pad_id: int,
+ cache_state: Any = None,
+) -> Tuple[mx.array, Any]:
+ if not input_rows:
+ return mx.zeros((0, 0), dtype=mx.float32), cache_state
+
+ lengths = [len(row) for row in input_rows]
+ call_signature = None
+ try:
+ call_signature = inspect.signature(policy.__call__)
+ except (TypeError, ValueError):
+ call_signature = None
+ use_cache = (
+ cache_state is not False
+ and len(input_rows) == 1
+ and len(set(lengths)) == 1
+ and (
+ hasattr(policy, "forward_with_cache")
+ or (call_signature is not None and "cache" in call_signature.parameters)
+ )
+ )
+ # Batched KV-cache decoding can preserve token choices while still
+ # producing incorrect per-token logprobs on real Qwen-family models.
+ # GRPO relies on those sampled logprobs matching a fresh rescore of the
+ # same completion, so shared cache reuse is limited to single-row decoding.
+ if use_cache:
+ prefill = cache_state is None
+ if prefill:
+ try:
+ from mlx_lm.models.cache import make_prompt_cache
+
+ cache_state = make_prompt_cache(getattr(policy, "model", policy))
+ except Exception:
+ cache_state = None
+ inputs = mx.array(input_rows if prefill else [[row[-1]] for row in input_rows])
+ try:
+ if hasattr(policy, "forward_with_cache"):
+ outputs = policy.forward_with_cache(inputs, cache=cache_state)
+ else:
+ outputs = policy(inputs, cache=cache_state)
+ if isinstance(outputs, tuple) and len(outputs) == 2:
+ logits, next_cache = outputs
+ return logits[:, -1, :], next_cache
+ if hasattr(outputs, "shape"):
+ return outputs[:, -1, :], cache_state
+ except Exception:
+ cache_state = False
+
+ padded, seq_lengths = pad_sequences(input_rows, pad_id)
+ logits = policy(padded)
+ row_positions = mx.arange(padded.shape[0])
+ last_positions = seq_lengths - 1
+ return logits[row_positions, last_positions, :], cache_state
+
+
+def collect_rollouts(
+ policy: Any,
+ tokenizer: Any,
+ prompt_samples: Sequence[Dict[str, Any]],
+ sampling_config: Dict[str, Any],
+ reward_evaluator: Any = None,
+ collect_sample_stats: bool = False,
+) -> RolloutBatch:
+ num_generations = int(sampling_config.get("num_generations", 1))
+ max_completion_length = int(sampling_config.get("max_tokens", sampling_config.get("max_completion_length", 256)))
+ max_seq_length = sampling_config.get("max_seq_length")
+ temperature = float(sampling_config.get("temperature", 0.7))
+ max_prompt_length = None
+ if max_seq_length is not None:
+ # Preserve as much prompt context as possible and spend only the
+ # remaining budget on completion tokens. Reserving the full requested
+ # completion cap up front can collapse the prompt to an unusable suffix.
+ max_prompt_length = max(0, int(max_seq_length) - 1)
+
+ generation_batch_size = int(sampling_config.get("generation_batch_size") or 0)
+ generation_batch_size = max(1, generation_batch_size or (len(prompt_samples) * max(1, num_generations)))
+ pad_id = int(getattr(tokenizer, "pad_token_id", 0) or 0)
+ terminal_token_ids = _resolve_terminal_token_ids(tokenizer)
+
+ expanded_samples: List[Dict[str, Any]] = []
+ for group_index, sample in enumerate(prompt_samples):
+ sample_index = int(sample.get("sample_index", group_index))
+ original_prompt_text = sample.get("prompt", sample.get("prompt_text", ""))
+ reward_context = sample.get("reward_context")
+ prepared_prompt_ids = truncate_prompt_tokens(sample.get("prompt_ids", []), max_prompt_length)
+ effective_prompt_text = tokenizer.decode(prepared_prompt_ids)
+ generation_budget = max_completion_length
+ if max_seq_length is not None:
+ generation_budget = min(max_completion_length, max(0, int(max_seq_length) - len(prepared_prompt_ids)))
+ for _ in range(num_generations):
+ expanded_samples.append(
+ {
+ "prompt_ids": list(prepared_prompt_ids),
+ "prompt_text": effective_prompt_text,
+ "original_prompt_text": original_prompt_text,
+ "reward_context": reward_context,
+ "sample_index": sample_index,
+ "prompt_group_index": group_index,
+ "generated_ids": list(prepared_prompt_ids),
+ "sampled_logprobs": [],
+ "sampled_logits": [],
+ "token_entropies": [],
+ "eos_flag": False,
+ "done": generation_budget == 0,
+ "max_completion_tokens": generation_budget,
+ "hit_length_limit": generation_budget == 0,
+ "cache_state": None,
+ }
+ )
+
+ for start in range(0, len(expanded_samples), generation_batch_size):
+ chunk = expanded_samples[start:start + generation_batch_size]
+ for _ in range(max_completion_length):
+ active_rows = [row for row in chunk if not row["done"]]
+ if not active_rows:
+ break
+
+ row_logits = []
+ for row in active_rows:
+ logits, row["cache_state"] = _batched_last_token_logits(
+ policy,
+ [row["generated_ids"]],
+ pad_id=pad_id,
+ cache_state=row.get("cache_state"),
+ )
+ row_logits.append(logits)
+ logits = mx.concatenate(row_logits, axis=0) if row_logits else mx.zeros((0, 0), dtype=mx.float32)
+ if temperature > 0:
+ scaled = logits / temperature
+ log_probs = nn.log_softmax(scaled, axis=-1)
+ probs = mx.softmax(scaled, axis=-1)
+ entropies = -mx.sum(probs * log_probs, axis=-1)
+ next_tokens = [int(value) for value in mx.random.categorical(log_probs).tolist()]
+ else:
+ scaled = logits
+ log_probs = nn.log_softmax(logits, axis=-1)
+ probs = mx.softmax(logits, axis=-1)
+ entropies = -mx.sum(probs * log_probs, axis=-1)
+ next_tokens = [int(value) for value in mx.argmax(logits, axis=-1).tolist()]
+
+ for row_index, row in enumerate(active_rows):
+ token_id = next_tokens[row_index]
+ row["generated_ids"].append(token_id)
+ row["sampled_logprobs"].append(float(log_probs[row_index, token_id].item()))
+ if collect_sample_stats:
+ row["sampled_logits"].append(float(scaled[row_index, token_id].item()))
+ row["token_entropies"].append(float(entropies[row_index].item()))
+ if token_id in terminal_token_ids:
+ row["eos_flag"] = True
+ row["done"] = True
+ elif len(row["generated_ids"]) - len(row["prompt_ids"]) >= row["max_completion_tokens"]:
+ row["done"] = True
+ row["hit_length_limit"] = True
+
+ prompt_texts: List[str] = []
+ original_prompt_texts: List[str] = []
+ prompt_ids: List[List[int]] = []
+ prompt_lengths: List[int] = []
+ completion_ids: List[List[int]] = []
+ completion_lengths: List[int] = []
+ completion_texts: List[str] = []
+ reward_contexts: List[Any] = []
+ rollout_logprobs: List[float] = []
+ sampled_logprob_rows: List[List[float]] = []
+ sampled_logit_rows: List[List[float]] = []
+ entropy_rows: List[List[float]] = []
+ eos_flags: List[bool] = []
+ truncation_flags: List[bool] = []
+ prompt_group_indices: List[int] = []
+ sample_indices: List[int] = []
+
+ for row in expanded_samples:
+ raw_completion_ids = row["generated_ids"][len(row["prompt_ids"]):]
+ prepared_completion_ids, truncated = truncate_completion_tokens(
+ raw_completion_ids,
+ max_completion_length,
+ )
+ sampled_logprobs = row["sampled_logprobs"][: len(prepared_completion_ids)]
+ sampled_logits = row["sampled_logits"][: len(prepared_completion_ids)]
+ entropies = row["token_entropies"][: len(prepared_completion_ids)]
+
+ prompt_texts.append(row["prompt_text"])
+ original_prompt_texts.append(row["original_prompt_text"])
+ prompt_ids.append(list(row["prompt_ids"]))
+ prompt_lengths.append(len(row["prompt_ids"]))
+ reward_contexts.append(row["reward_context"])
+ completion_ids.append(prepared_completion_ids)
+ completion_lengths.append(len(prepared_completion_ids))
+ completion_texts.append(tokenizer.decode(prepared_completion_ids))
+ sampled_logprob_rows.append(sampled_logprobs)
+ rollout_logprobs.append(sum(sampled_logprobs))
+ eos_flags.append(bool(row["eos_flag"]))
+ truncation_flags.append(
+ bool(((not row["eos_flag"]) and row.get("hit_length_limit", False)) or truncated)
+ )
+ prompt_group_indices.append(int(row["prompt_group_index"]))
+ sample_indices.append(int(row["sample_index"]))
+ if collect_sample_stats:
+ sampled_logit_rows.append(sampled_logits)
+ entropy_rows.append(entropies)
+
+ max_completion_width = max(completion_lengths) if completion_lengths else 0
+ padded_token_logprobs, _ = pad_sequences(
+ [
+ [float(value) for value in row] + [0.0] * (max_completion_width - len(row))
+ for row in sampled_logprob_rows
+ ] if sampled_logprob_rows else [[0.0]],
+ 0,
+ )
+ if not sampled_logprob_rows:
+ padded_token_logprobs = mx.zeros((0, 0), dtype=mx.float32)
+ else:
+ padded_token_logprobs = padded_token_logprobs.astype(mx.float32)
+
+ sampled_token_logits = None
+ token_entropies = None
+ if collect_sample_stats:
+ if sampled_logit_rows:
+ sampled_token_logits, _ = pad_sequences(
+ [
+ [float(value) for value in row] + [0.0] * (max_completion_width - len(row))
+ for row in sampled_logit_rows
+ ],
+ 0,
+ )
+ token_entropies, _ = pad_sequences(
+ [
+ [float(value) for value in row] + [0.0] * (max_completion_width - len(row))
+ for row in entropy_rows
+ ],
+ 0,
+ )
+ sampled_token_logits = sampled_token_logits.astype(mx.float32)
+ token_entropies = token_entropies.astype(mx.float32)
+ else:
+ sampled_token_logits = mx.zeros((0, 0), dtype=mx.float32)
+ token_entropies = mx.zeros((0, 0), dtype=mx.float32)
+
+ full_sequences = [
+ prompt_sequence + completion_sequence
+ for prompt_sequence, completion_sequence in zip(prompt_ids, completion_ids)
+ ]
+ aligned_old_token_rows = [
+ [0.0] * max(prompt_length - 1, 0) + [float(value) for value in row]
+ for prompt_length, row in zip(prompt_lengths, sampled_logprob_rows)
+ ]
+ aligned_old_token_logprobs, _ = pad_sequences(
+ aligned_old_token_rows if aligned_old_token_rows else [[0.0]],
+ 0,
+ )
+ if not aligned_old_token_rows:
+ aligned_old_token_logprobs = mx.zeros((0, 0), dtype=mx.float32)
+ else:
+ aligned_old_token_logprobs = aligned_old_token_logprobs.astype(mx.float32)
+ policy_eval = make_policy_eval_batch(
+ full_sequences,
+ pad_id=pad_id,
+ mode="completion",
+ prompt_lengths=prompt_lengths,
+ completion_lengths=completion_lengths,
+ rollout_logprobs=mx.array(rollout_logprobs, dtype=mx.float32),
+ old_logprobs=mx.array(rollout_logprobs, dtype=mx.float32),
+ old_token_logprobs=aligned_old_token_logprobs,
+ prompt_group_indices=mx.array(prompt_group_indices),
+ sample_indices=mx.array(sample_indices),
+ )
+ rollout_batch = RolloutBatch(
+ prompt_ids=prompt_ids,
+ prompt_lengths=mx.array(prompt_lengths),
+ completion_ids=completion_ids,
+ completion_lengths=mx.array(completion_lengths),
+ prompt_texts=prompt_texts,
+ original_prompt_texts=original_prompt_texts,
+ completion_texts=completion_texts,
+ reward_contexts=reward_contexts,
+ sampled_token_logprobs=padded_token_logprobs,
+ rollout_logprobs=mx.array(rollout_logprobs, dtype=mx.float32),
+ old_logprobs=mx.array(rollout_logprobs, dtype=mx.float32),
+ eos_flags=mx.array(eos_flags),
+ truncation_flags=mx.array(truncation_flags),
+ prompt_group_indices=mx.array(prompt_group_indices),
+ policy_eval=policy_eval,
+ sample_indices=mx.array(sample_indices),
+ sampled_token_logits=sampled_token_logits,
+ token_entropies=token_entropies,
+ )
+ if reward_evaluator is not None:
+ reward_batch = evaluate_rewards(rollout_batch, reward_evaluator)
+ rollout_batch.rewards = reward_batch.scalar_rewards
+ return rollout_batch
+
+
+def evaluate_rewards(rollout_batch: RolloutBatch, evaluator: Any) -> RewardBatch:
+ adapter = _RewardEvaluatorAdapter(evaluator)
+ payloads: List[Dict[str, Any]] = []
+ for index in range(len(rollout_batch.prompt_texts)):
+ payloads.append(
+ {
+ "prompt_text": rollout_batch.prompt_texts[index],
+ "original_prompt_text": (
+ rollout_batch.original_prompt_texts[index]
+ if rollout_batch.original_prompt_texts is not None
+ else rollout_batch.prompt_texts[index]
+ ),
+ "completion_text": rollout_batch.completion_texts[index],
+ "reward_context": rollout_batch.reward_contexts[index],
+ "prompt_ids": list(rollout_batch.prompt_ids[index]),
+ "completion_ids": list(rollout_batch.completion_ids[index]),
+ "prompt_length": int(rollout_batch.prompt_lengths[index].item()),
+ "completion_length": int(rollout_batch.completion_lengths[index].item()),
+ "eos_flag": bool(rollout_batch.eos_flags[index].item()),
+ "truncation_flag": bool(rollout_batch.truncation_flags[index].item()),
+ "prompt_group_index": int(rollout_batch.prompt_group_indices[index].item()),
+ "sample_index": (
+ int(rollout_batch.sample_indices[index].item())
+ if rollout_batch.sample_indices is not None
+ else index
+ ),
+ }
+ )
+
+ scalar_rewards: List[float] = []
+ named_components: List[Dict[str, float]] = []
+ diagnostics: List[Dict[str, Any]] = []
+ for reward, components, sample_diagnostics in adapter.evaluate_batch(payloads):
+ scalar_rewards.append(reward)
+ named_components.append(components or {})
+ diagnostics.append(sample_diagnostics or {})
+
+ has_components = any(component for component in named_components)
+ has_diagnostics = any(diagnostic for diagnostic in diagnostics)
+ return RewardBatch(
+ prompt_texts=list(rollout_batch.prompt_texts),
+ completion_texts=list(rollout_batch.completion_texts),
+ reward_contexts=list(rollout_batch.reward_contexts),
+ scalar_rewards=mx.array(scalar_rewards, dtype=mx.float32),
+ prompt_group_indices=mx.array(rollout_batch.prompt_group_indices),
+ original_prompt_texts=list(rollout_batch.original_prompt_texts)
+ if rollout_batch.original_prompt_texts is not None
+ else None,
+ named_reward_components=named_components if has_components else None,
+ diagnostics=diagnostics if has_diagnostics else None,
+ )
+
+
+def compute_advantages(
+ reward_batch: RewardBatch,
+ grouping: str = "per_prompt",
+ normalization: str = "zscore_if_nonzero_else_center",
+) -> mx.array:
+ rewards = reward_batch.scalar_rewards.astype(mx.float32)
+ if grouping != "per_prompt":
+ raise ValueError(f"Unsupported advantage grouping: {grouping}")
+ if normalization != "zscore_if_nonzero_else_center":
+ raise ValueError(f"Unsupported advantage normalization: {normalization}")
+
+ reward_values = rewards.tolist()
+ group_values = reward_batch.prompt_group_indices.tolist()
+ advantages = [0.0] * len(reward_values)
+ grouped_positions: Dict[int, List[int]] = {}
+ for position, group_value in enumerate(group_values):
+ grouped_positions.setdefault(int(group_value), []).append(position)
+
+ for positions in grouped_positions.values():
+ group_rewards = mx.array([reward_values[position] for position in positions], dtype=mx.float32)
+ group_std = mx.std(group_rewards)
+ if float(group_std.item()) < 1e-6:
+ group_advantages = group_rewards - mx.mean(group_rewards)
+ else:
+ group_advantages = (group_rewards - mx.mean(group_rewards)) / (group_std + 1e-8)
+ for offset, position in enumerate(positions):
+ advantages[position] = float(group_advantages[offset].item())
+ return mx.array(advantages, dtype=mx.float32)
+
+
+def score_rollout_references(
+ reference_model: Any,
+ rollout_batch: RolloutBatch,
+ batch_size: Optional[int] = 8,
+ temperature: float = 1.0,
+ token_budget: Optional[int] = None,
+) -> RolloutBatch:
+ if reference_model is None:
+ return rollout_batch
+
+ scored = score_policy_in_chunks(
+ reference_model,
+ rollout_batch.policy_eval,
+ batch_size=batch_size,
+ token_budget=token_budget,
+ mode="completion",
+ temperature=temperature,
+ )
+ reference_logprobs = mx.stop_gradient(scored.summed_logprobs.astype(mx.float32))
+ return replace(
+ rollout_batch,
+ reference_logprobs=reference_logprobs,
+ policy_eval=replace(
+ rollout_batch.policy_eval,
+ reference_logprobs=reference_logprobs,
+ ),
+ )
+
+
+def predict_rollout_values(
+ value_model: Any,
+ rollout_batch: RolloutBatch,
+ batch_size: Optional[int] = 8,
+ token_budget: Optional[int] = None,
+) -> RolloutBatch:
+ if value_model is None:
+ return rollout_batch
+
+ value_chunks = []
+ for minibatch in assemble_minibatches(
+ rollout_batch.policy_eval,
+ minibatch_size=batch_size,
+ shuffle=False,
+ mode="completion",
+ token_budget=token_budget,
+ ):
+ value_chunks.append(
+ value_model.predict(
+ minibatch.input_ids,
+ sequence_lengths=minibatch.sequence_lengths,
+ prompt_lengths=minibatch.prompt_lengths
+ if minibatch.prompt_lengths is not None
+ else None,
+ completion_lengths=minibatch.completion_lengths
+ if minibatch.completion_lengths is not None
+ else None,
+ )
+ )
+ value_predictions = (
+ mx.concatenate(value_chunks, axis=0) if value_chunks else mx.zeros((0,), dtype=mx.float32)
+ )
+ value_predictions = mx.stop_gradient(value_predictions.astype(mx.float32))
+ return replace(
+ rollout_batch,
+ value_predictions=value_predictions,
+ policy_eval=replace(
+ rollout_batch.policy_eval,
+ value_predictions=value_predictions,
+ ),
+ )
+
+
+def compute_returns_and_advantages(
+ rewards: mx.array,
+ values: Optional[mx.array] = None,
+ prompt_group_indices: Optional[mx.array] = None,
+ mode: str = "gae",
+ gamma: float = 1.0,
+ gae_lambda: float = 1.0,
+ normalize: bool = False,
+) -> tuple[mx.array, mx.array]:
+ rewards = rewards.astype(mx.float32)
+ values = mx.zeros_like(rewards) if values is None else values.astype(mx.float32)
+
+ if mode == "gae":
+ deltas = rewards - values
+ advantages = deltas * gae_lambda + deltas * (1.0 - gae_lambda)
+ returns = advantages + values
+ elif mode == "group_zscore":
+ if prompt_group_indices is None:
+ raise ValueError("prompt_group_indices is required for grouped advantages.")
+ reward_batch = RewardBatch(
+ prompt_texts=[""] * rewards.shape[0],
+ completion_texts=[""] * rewards.shape[0],
+ reward_contexts=[None] * rewards.shape[0],
+ scalar_rewards=rewards,
+ prompt_group_indices=prompt_group_indices,
+ )
+ advantages = compute_advantages(reward_batch)
+ returns = rewards
+ elif mode == "group_center":
+ if prompt_group_indices is None:
+ raise ValueError("prompt_group_indices is required for grouped advantages.")
+ reward_values = rewards.tolist()
+ advantages_list = [0.0] * len(reward_values)
+ grouped_positions: Dict[int, List[int]] = {}
+ for position, group_value in enumerate(prompt_group_indices.tolist()):
+ grouped_positions.setdefault(int(group_value), []).append(position)
+ for positions in grouped_positions.values():
+ group_rewards = [reward_values[position] for position in positions]
+ baseline = sum(group_rewards) / float(len(group_rewards))
+ for offset, position in enumerate(positions):
+ advantages_list[position] = group_rewards[offset] - baseline
+ advantages = mx.array(advantages_list, dtype=mx.float32)
+ returns = rewards
+ elif mode == "rloo":
+ if prompt_group_indices is None:
+ raise ValueError("prompt_group_indices is required for RLOO advantages.")
+ reward_values = rewards.tolist()
+ advantages_list = [0.0] * len(reward_values)
+ grouped_positions: Dict[int, List[int]] = {}
+ for position, group_value in enumerate(prompt_group_indices.tolist()):
+ grouped_positions.setdefault(int(group_value), []).append(position)
+ for positions in grouped_positions.values():
+ if len(positions) <= 1:
+ continue
+ group_rewards = [reward_values[position] for position in positions]
+ total = sum(group_rewards)
+ denominator = float(len(positions) - 1)
+ for offset, position in enumerate(positions):
+ baseline = (total - group_rewards[offset]) / denominator
+ advantages_list[position] = group_rewards[offset] - baseline
+ advantages = mx.array(advantages_list, dtype=mx.float32)
+ returns = rewards
+ else:
+ raise ValueError(f"Unsupported returns/advantages mode: {mode}")
+
+ if normalize and advantages.shape[0] > 1:
+ advantages = (advantages - mx.mean(advantages)) / (mx.std(advantages) + 1e-8)
+
+ if gamma != 1.0:
+ returns = rewards + gamma * (returns - rewards)
+ return returns.astype(mx.float32), advantages.astype(mx.float32)
+
+
+def rank_grouped_rollouts(
+ rollout_batch: RolloutBatch,
+ score_tolerance: float = 1e-6,
+) -> List[Dict[str, Any]]:
+ if rollout_batch.rewards is None:
+ raise ValueError("Rollout rewards are required for grouped ranking.")
+
+ grouped_positions: Dict[int, List[int]] = {}
+ for position, group_value in enumerate(rollout_batch.prompt_group_indices.tolist()):
+ grouped_positions.setdefault(int(group_value), []).append(position)
+
+ reward_values = rollout_batch.rewards.tolist()
+ rankings: List[Dict[str, Any]] = []
+ for group_index, positions in grouped_positions.items():
+ ordered_positions = sorted(positions, key=lambda position: reward_values[position], reverse=True)
+ ordered_scores = [reward_values[position] for position in ordered_positions]
+ rankings.append(
+ {
+ "prompt_group_index": group_index,
+ "positions": positions,
+ "ordered_positions": ordered_positions,
+ "scores": ordered_scores,
+ "best_position": ordered_positions[0],
+ "worst_position": ordered_positions[-1],
+ "all_tied": (
+ len(ordered_scores) <= 1
+ or max(ordered_scores) - min(ordered_scores) <= score_tolerance
+ ),
+ }
+ )
+ return rankings
+
+
+def _slice_value(value: Any, indices: Sequence[int]) -> Any:
+ if value is None:
+ return None
+ if hasattr(value, "shape"):
+ return value[mx.array(indices)]
+ if isinstance(value, list):
+ return [value[index] for index in indices]
+ if isinstance(value, tuple):
+ return tuple(value[index] for index in indices)
+ raise TypeError(f"Unsupported minibatch field type: {type(value)!r}")
+
+
+def assemble_minibatches(
+ batch: PolicyEvalBatch,
+ minibatch_size: Optional[int],
+ shuffle: bool = False,
+ mode: str = "sequence",
+ token_budget: Optional[int] = None,
+) -> Iterable[PolicyEvalBatch]:
+ batch_size = batch.input_ids.shape[0]
+ if batch_size == 0:
+ return []
+
+ if shuffle:
+ order = [int(value) for value in mx.random.permutation(batch_size).tolist()]
+ else:
+ order = list(range(batch_size))
+
+ effective_lengths = _policy_eval_effective_lengths(batch, mode)
+ row_budget = max(1, minibatch_size or batch_size)
+ token_budget = max(1, token_budget) if token_budget is not None else None
+
+ batches_of_indices: List[List[int]] = []
+ current_indices: List[int] = []
+ current_tokens = 0
+ for index in order:
+ row_tokens = effective_lengths[index]
+ exceeds_token_budget = token_budget is not None and current_indices and current_tokens + row_tokens > token_budget
+ exceeds_row_budget = current_indices and len(current_indices) >= row_budget
+ if exceeds_token_budget or exceeds_row_budget:
+ batches_of_indices.append(current_indices)
+ current_indices = []
+ current_tokens = 0
+ current_indices.append(index)
+ current_tokens += row_tokens
+ if current_indices:
+ batches_of_indices.append(current_indices)
+
+ minibatches: List[PolicyEvalBatch] = []
+ for indices in batches_of_indices:
+ values = {
+ field.name: _slice_value(getattr(batch, field.name), indices)
+ for field in fields(batch)
+ }
+ minibatches.append(PolicyEvalBatch(**values))
+ return minibatches
+
+
+def summarize_rollout_metrics(
+ rollout_batch: RolloutBatch,
+ policy_loss: Optional[float] = None,
+ value_loss: Optional[float] = None,
+ reward_loss: Optional[float] = None,
+) -> Dict[str, float]:
+ metrics: Dict[str, float] = {}
+ if rollout_batch.rewards is not None and rollout_batch.rewards.shape[0] > 0:
+ metrics["reward_mean"] = float(mx.mean(rollout_batch.rewards).item())
+ metrics["reward_std"] = float(mx.std(rollout_batch.rewards).item())
+ if rollout_batch.reference_logprobs is not None and rollout_batch.rollout_logprobs.shape[0] > 0:
+ completion_lengths = mx.maximum(rollout_batch.completion_lengths.astype(mx.float32), 1.0)
+ log_ratio = (
+ rollout_batch.rollout_logprobs.astype(mx.float32)
+ - rollout_batch.reference_logprobs.astype(mx.float32)
+ )
+ metrics["logprob_delta_mean"] = float(mx.mean(log_ratio).item())
+ metrics["logprob_delta_per_token_mean"] = float(mx.mean(log_ratio / completion_lengths).item())
+ normalized_policy = normalize_logprobs(
+ rollout_batch.rollout_logprobs.astype(mx.float32),
+ completion_lengths,
+ mode="mean",
+ )
+ normalized_reference = normalize_logprobs(
+ rollout_batch.reference_logprobs.astype(mx.float32),
+ completion_lengths,
+ mode="mean",
+ )
+ kl_values = kl_against_reference(
+ normalized_policy,
+ normalized_reference,
+ )
+ kl_mean = float(mx.mean(kl_values).item())
+ if kl_mean != float("inf"):
+ metrics["kl_to_reference_mean"] = kl_mean
+ if rollout_batch.token_entropies is not None and rollout_batch.completion_lengths.shape[0] > 0:
+ entropy_mask = length_mask(
+ rollout_batch.completion_lengths,
+ rollout_batch.token_entropies.shape[1],
+ ).astype(rollout_batch.token_entropies.dtype)
+ valid_tokens = float(mx.sum(entropy_mask).item())
+ if valid_tokens > 0:
+ metrics["entropy_mean"] = float(
+ (mx.sum(rollout_batch.token_entropies * entropy_mask) / valid_tokens).item()
+ )
+ if rollout_batch.completion_lengths.shape[0] > 0:
+ metrics["completion_length_mean"] = float(mx.mean(rollout_batch.completion_lengths.astype(mx.float32)).item())
+ metrics["completion_length_max"] = float(mx.max(rollout_batch.completion_lengths.astype(mx.float32)).item())
+ if rollout_batch.eos_flags is not None and rollout_batch.eos_flags.shape[0] > 0:
+ metrics["eos_rate"] = float(mx.mean(rollout_batch.eos_flags.astype(mx.float32)).item())
+ if rollout_batch.truncation_flags is not None and rollout_batch.truncation_flags.shape[0] > 0:
+ metrics["truncation_rate"] = float(mx.mean(rollout_batch.truncation_flags.astype(mx.float32)).item())
+ if policy_loss is not None:
+ metrics["policy_loss"] = float(policy_loss)
+ if value_loss is not None:
+ metrics["value_loss"] = float(value_loss)
+ if reward_loss is not None:
+ metrics["reward_loss"] = float(reward_loss)
+ return metrics
diff --git a/mlx_tune/arithmetic_grpo_validation.py b/mlx_tune/arithmetic_grpo_validation.py
new file mode 100644
index 0000000..4436c4b
--- /dev/null
+++ b/mlx_tune/arithmetic_grpo_validation.py
@@ -0,0 +1,786 @@
+"""
+Deterministic arithmetic benchmark and GRPO validation for native-thinking Qwen 3 models.
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+import random
+import re
+from pathlib import Path
+from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple
+
+from mlx_lm.sample_utils import make_sampler
+
+from mlx_tune import (
+ FastLanguageModel,
+ GRPOConfig,
+ GRPOTrainer,
+ get_chat_template,
+)
+
+
+DEFAULT_MODEL_NAME = "mlx-community/Qwen3-1.7B-4bit"
+DEFAULT_SYSTEM_PROMPT = (
+ "You are a careful arithmetic solver. Think freely if useful. "
+ "Put the final integer answer inside ...."
+)
+DEFAULT_OUTPUT_DIR = Path("./artifacts/qwen3_arithmetic_grpo_validation")
+DEFAULT_RL_SUBDIR = "rl_run"
+DEFAULT_TRAIN_SIZE = 3000
+DEFAULT_VAL_SIZE = 300
+DEFAULT_TEST_SIZE = 300
+DEFAULT_SEED = 0
+DEFAULT_MAX_COMPLETION_LENGTH = 512
+DEFAULT_MAX_SEQ_LENGTH = 768
+DEFAULT_LORA_RANK = 16
+DEFAULT_RL_TEMPERATURE = 0.9
+DEFAULT_BASELINE_TEMPERATURE = 0.0
+
+_INTEGER_RE = re.compile(r"^[+-]?\d+$")
+_SOLUTION_RE = re.compile(r"(.*?)", re.IGNORECASE | re.DOTALL)
+
+
+def _write_json(path: Path, payload: Mapping[str, Any]) -> None:
+ path.parent.mkdir(parents=True, exist_ok=True)
+ with open(path, "w") as handle:
+ json.dump(payload, handle, indent=2)
+
+
+def _write_jsonl(path: Path, rows: Iterable[Mapping[str, Any]]) -> None:
+ path.parent.mkdir(parents=True, exist_ok=True)
+ with open(path, "w") as handle:
+ for row in rows:
+ handle.write(json.dumps(dict(row)) + "\n")
+
+
+def _read_json(path: Path) -> Dict[str, Any]:
+ with open(path) as handle:
+ return json.load(handle)
+
+
+def _read_jsonl(path: Path) -> List[Dict[str, Any]]:
+ rows: List[Dict[str, Any]] = []
+ with open(path) as handle:
+ for line in handle:
+ line = line.strip()
+ if line:
+ rows.append(json.loads(line))
+ return rows
+
+
+def _dataset_dir(output_dir: Path) -> Path:
+ return output_dir / "datasets"
+
+
+def _dataset_path(output_dir: Path, split: str) -> Path:
+ return _dataset_dir(output_dir) / f"{split}.jsonl"
+
+
+def _rl_output_dir(output_dir: Path) -> Path:
+ return output_dir / DEFAULT_RL_SUBDIR
+
+
+def _rl_adapter_dir(output_dir: Path) -> Path:
+ return _rl_output_dir(output_dir) / "policy"
+
+
+def _baseline_outputs_path(output_dir: Path) -> Path:
+ return output_dir / "baseline_outputs.jsonl"
+
+
+def _baseline_metrics_path(output_dir: Path) -> Path:
+ return output_dir / "baseline_metrics.json"
+
+
+def _post_outputs_path(output_dir: Path) -> Path:
+ return output_dir / "post_rl_outputs.jsonl"
+
+
+def _post_metrics_path(output_dir: Path) -> Path:
+ return output_dir / "post_rl_metrics.json"
+
+
+def _comparison_json_path(output_dir: Path) -> Path:
+ return output_dir / "comparison.json"
+
+
+def _comparison_md_path(output_dir: Path) -> Path:
+ return output_dir / "comparison.md"
+
+
+def _training_summary_path(output_dir: Path) -> Path:
+ return output_dir / "rl_training_summary.json"
+
+
+def _safe_eval(expression: str) -> int:
+ return int(eval(expression, {"__builtins__": {}}, {}))
+
+
+def _make_easy_expression(rng: random.Random) -> Tuple[str, int]:
+ left = rng.randint(0, 50)
+ right = rng.randint(0, 50)
+ operator = rng.choice(["+", "-"])
+ expression = f"{left} {operator} {right}"
+ return expression, _safe_eval(expression)
+
+
+def _make_medium_expression(rng: random.Random) -> Tuple[str, int]:
+ values = [rng.randint(0, 20) for _ in range(3)]
+ operators = [rng.choice(["+", "-", "*"]) for _ in range(2)]
+ mode = rng.choice(["plain", "left_paren", "right_paren"])
+ if mode == "left_paren":
+ expression = f"({values[0]} {operators[0]} {values[1]}) {operators[1]} {values[2]}"
+ elif mode == "right_paren":
+ expression = f"{values[0]} {operators[0]} ({values[1]} {operators[1]} {values[2]})"
+ else:
+ expression = f"{values[0]} {operators[0]} {values[1]} {operators[1]} {values[2]}"
+ return expression, _safe_eval(expression)
+
+
+def build_sample(task_id: str, expression: str, answer: int, difficulty: str) -> Dict[str, Any]:
+ return {
+ "task_id": task_id,
+ "messages": [
+ {"role": "system", "content": DEFAULT_SYSTEM_PROMPT},
+ {"role": "user", "content": f"Compute exactly:\n{expression}"},
+ ],
+ "answer": str(answer),
+ "expression": expression,
+ "difficulty": difficulty,
+ }
+
+
+def generate_benchmark_splits(
+ output_dir: Path,
+ train_size: int = DEFAULT_TRAIN_SIZE,
+ val_size: int = DEFAULT_VAL_SIZE,
+ test_size: int = DEFAULT_TEST_SIZE,
+ seed: int = DEFAULT_SEED,
+ force: bool = False,
+) -> Dict[str, Path]:
+ output_dir.mkdir(parents=True, exist_ok=True)
+ dataset_paths = {
+ "train": _dataset_path(output_dir, "train"),
+ "val": _dataset_path(output_dir, "val"),
+ "test": _dataset_path(output_dir, "test"),
+ }
+ if not force and all(path.exists() for path in dataset_paths.values()):
+ return dataset_paths
+
+ rng = random.Random(seed)
+ existing_expressions = set()
+ split_sizes = {"train": train_size, "val": val_size, "test": test_size}
+
+ for split_name, split_size in split_sizes.items():
+ rows: List[Dict[str, Any]] = []
+ while len(rows) < split_size:
+ difficulty = "easy" if (len(rows) % 2 == 0) else "medium"
+ if difficulty == "easy":
+ expression, answer = _make_easy_expression(rng)
+ else:
+ expression, answer = _make_medium_expression(rng)
+ if expression in existing_expressions:
+ continue
+ if abs(answer) > 999:
+ continue
+ existing_expressions.add(expression)
+ task_id = f"{split_name}-{len(rows) + 1:06d}"
+ rows.append(build_sample(task_id, expression, answer, difficulty))
+ _write_jsonl(dataset_paths[split_name], rows)
+
+ return dataset_paths
+
+
+def parse_solution_response(response_text: str) -> Dict[str, Any]:
+ matches = _SOLUTION_RE.findall(response_text)
+ single_solution_tag = len(matches) == 1
+ solution_text = matches[0].strip() if single_solution_tag else None
+ parseable_solution = bool(solution_text is not None and _INTEGER_RE.fullmatch(solution_text))
+ parsed_answer = int(solution_text) if parseable_solution else None
+ return {
+ "single_solution_tag": single_solution_tag,
+ "multiple_solution_tags": len(matches) > 1,
+ "solution_text": solution_text,
+ "parseable_solution": parseable_solution,
+ "parsed_answer": parsed_answer,
+ }
+
+
+def score_solution_output(response_text: str, gold_answer: str) -> Dict[str, Any]:
+ parsed = parse_solution_response(response_text)
+ exact_match = parsed["parseable_solution"] and str(parsed["parsed_answer"]) == str(gold_answer).strip()
+ tag_reward = 0.1 if parsed["single_solution_tag"] else 0.0
+ correctness_reward = 1.0 if exact_match else 0.0
+ reward = tag_reward + correctness_reward
+ return {
+ **parsed,
+ "exact_match": exact_match,
+ "tag_reward": tag_reward,
+ "correctness_reward": correctness_reward,
+ "reward": reward,
+ }
+
+
+class ArithmeticSolutionReward:
+ def evaluate(self, payload: Mapping[str, Any]) -> Dict[str, Any]:
+ scored = score_solution_output(
+ str(payload.get("completion_text", "")),
+ str(payload.get("reward_context", "")),
+ )
+ return {
+ "reward": scored["reward"],
+ "components": {
+ "solution_tag": scored["tag_reward"],
+ "correctness": scored["correctness_reward"],
+ },
+ "diagnostics": {
+ "solution_text": scored["solution_text"],
+ "parseable_solution": scored["parseable_solution"],
+ "multiple_solution_tags": scored["multiple_solution_tags"],
+ "exact_match": scored["exact_match"],
+ },
+ }
+
+
+def _token_length(tokenizer: Any, text: str) -> int:
+ try:
+ return len(tokenizer.encode(text, add_special_tokens=False))
+ except TypeError:
+ return len(tokenizer.encode(text))
+
+
+def load_model_bundle(
+ model_name: str,
+ *,
+ max_seq_length: int,
+ load_adapter_path: Optional[Path] = None,
+ for_training: bool = False,
+ lora_rank: int = DEFAULT_LORA_RANK,
+) -> Tuple[Any, Any]:
+ model, tokenizer = FastLanguageModel.from_pretrained(
+ model_name=model_name,
+ max_seq_length=max_seq_length,
+ load_in_4bit=True,
+ )
+ tokenizer = get_chat_template(tokenizer, chat_template="qwen-3")
+ if for_training:
+ model = FastLanguageModel.get_peft_model(
+ model,
+ r=lora_rank,
+ max_seq_length=max_seq_length,
+ )
+ if load_adapter_path is not None and load_adapter_path.exists():
+ model.load_adapter(str(load_adapter_path))
+ if not for_training:
+ FastLanguageModel.for_inference(model)
+ return model, tokenizer
+
+
+def generate_completion(
+ model: Any,
+ tokenizer: Any,
+ messages: Sequence[Mapping[str, Any]],
+ *,
+ max_tokens: int,
+ temperature: float,
+) -> str:
+ prompt = tokenizer.apply_chat_template(
+ list(messages),
+ add_generation_prompt=True,
+ tokenize=False,
+ )
+ return model.generate(
+ prompt=prompt,
+ max_tokens=max_tokens,
+ sampler=make_sampler(temp=temperature),
+ verbose=False,
+ )
+
+
+def _load_split_records(output_dir: Path, split: str) -> List[Dict[str, Any]]:
+ path = _dataset_path(output_dir, split)
+ if not path.exists():
+ raise FileNotFoundError(f"Missing dataset split: {path}")
+ return _read_jsonl(path)
+
+
+def _metrics_from_rows(
+ rows: Sequence[Mapping[str, Any]],
+ *,
+ split: str,
+ model_name: str,
+ seed: int,
+) -> Dict[str, Any]:
+ count = len(rows)
+ if count == 0:
+ return {
+ "split": split,
+ "model_name": model_name,
+ "seed": seed,
+ "num_samples": 0,
+ "exact_match": 0.0,
+ "solution_tag_rate": 0.0,
+ "parseable_solution_rate": 0.0,
+ "multiple_solution_tag_rate": 0.0,
+ "avg_completion_tokens": 0.0,
+ "avg_reward": 0.0,
+ }
+ return {
+ "split": split,
+ "model_name": model_name,
+ "seed": seed,
+ "num_samples": count,
+ "exact_match": sum(1.0 for row in rows if row["exact_match"]) / count,
+ "solution_tag_rate": sum(1.0 for row in rows if row["single_solution_tag"]) / count,
+ "parseable_solution_rate": sum(1.0 for row in rows if row["parseable_solution"]) / count,
+ "multiple_solution_tag_rate": sum(1.0 for row in rows if row["multiple_solution_tags"]) / count,
+ "avg_completion_tokens": sum(float(row["completion_tokens"]) for row in rows) / count,
+ "avg_reward": sum(float(row["reward"]) for row in rows) / count,
+ }
+
+
+def _aggregate_metrics(
+ per_split: Mapping[str, Mapping[str, Any]],
+ *,
+ model_name: str,
+ seed: int,
+) -> Dict[str, Any]:
+ all_rows = sum(int(metrics["num_samples"]) for metrics in per_split.values())
+ if all_rows == 0:
+ return _metrics_from_rows([], split="aggregate", model_name=model_name, seed=seed)
+ weighted = {}
+ for field in (
+ "exact_match",
+ "solution_tag_rate",
+ "parseable_solution_rate",
+ "multiple_solution_tag_rate",
+ "avg_completion_tokens",
+ "avg_reward",
+ ):
+ weighted[field] = sum(
+ float(metrics[field]) * int(metrics["num_samples"]) for metrics in per_split.values()
+ ) / all_rows
+ return {
+ "split": "aggregate",
+ "model_name": model_name,
+ "seed": seed,
+ "num_samples": all_rows,
+ **weighted,
+ }
+
+
+def evaluate_records(
+ model: Any,
+ tokenizer: Any,
+ rows: Sequence[Mapping[str, Any]],
+ *,
+ split: str,
+ model_name: str,
+ seed: int,
+ max_tokens: int,
+ temperature: float,
+) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
+ evaluated: List[Dict[str, Any]] = []
+ for row in rows:
+ response_text = generate_completion(
+ model,
+ tokenizer,
+ row["messages"],
+ max_tokens=max_tokens,
+ temperature=temperature,
+ )
+ scored = score_solution_output(response_text, str(row["answer"]))
+ evaluated.append(
+ {
+ "task_id": row["task_id"],
+ "split": split,
+ "expression": row["expression"],
+ "difficulty": row["difficulty"],
+ "answer": row["answer"],
+ "completion_text": response_text,
+ "completion_tokens": _token_length(tokenizer, response_text),
+ **scored,
+ }
+ )
+ return evaluated, _metrics_from_rows(evaluated, split=split, model_name=model_name, seed=seed)
+
+
+def _run_eval(
+ model: Any,
+ tokenizer: Any,
+ output_dir: Path,
+ *,
+ model_name: str,
+ seed: int,
+ output_path: Path,
+ metrics_path: Path,
+ max_tokens: int,
+ temperature: float,
+) -> Dict[str, Any]:
+ all_outputs: List[Dict[str, Any]] = []
+ per_split: Dict[str, Dict[str, Any]] = {}
+ for split in ("val", "test"):
+ rows = _load_split_records(output_dir, split)
+ split_outputs, split_metrics = evaluate_records(
+ model,
+ tokenizer,
+ rows,
+ split=split,
+ model_name=model_name,
+ seed=seed,
+ max_tokens=max_tokens,
+ temperature=temperature,
+ )
+ all_outputs.extend(split_outputs)
+ per_split[split] = split_metrics
+ _write_jsonl(output_path, all_outputs)
+ payload = {
+ "model_name": model_name,
+ "seed": seed,
+ "splits": per_split,
+ "aggregate": _aggregate_metrics(per_split, model_name=model_name, seed=seed),
+ }
+ _write_json(metrics_path, payload)
+ return payload
+
+
+def run_baseline(
+ output_dir: Path,
+ *,
+ model_name: str,
+ seed: int,
+ max_seq_length: int,
+ max_completion_length: int,
+) -> Dict[str, Any]:
+ model, tokenizer = load_model_bundle(
+ model_name,
+ max_seq_length=max_seq_length,
+ )
+ return _run_eval(
+ model,
+ tokenizer,
+ output_dir,
+ model_name=model_name,
+ seed=seed,
+ output_path=_baseline_outputs_path(output_dir),
+ metrics_path=_baseline_metrics_path(output_dir),
+ max_tokens=max_completion_length,
+ temperature=DEFAULT_BASELINE_TEMPERATURE,
+ )
+
+
+def run_training(
+ output_dir: Path,
+ *,
+ model_name: str,
+ seed: int,
+ max_seq_length: int,
+ max_completion_length: int,
+ max_steps: int,
+ learning_rate: float,
+ per_device_train_batch_size: int,
+ rollout_batch_size: int,
+ num_generations: int,
+ rl_temperature: float,
+ lora_rank: int,
+ logging_steps: int,
+ eval_steps: int,
+ save_steps: int,
+) -> Dict[str, Any]:
+ train_rows = _load_split_records(output_dir, "train")
+ val_rows = _load_split_records(output_dir, "val")
+ model, tokenizer = load_model_bundle(
+ model_name,
+ max_seq_length=max_seq_length,
+ for_training=True,
+ lora_rank=lora_rank,
+ )
+ reward = ArithmeticSolutionReward()
+ config = GRPOConfig(
+ loss_type="grpo",
+ learning_rate=learning_rate,
+ per_device_train_batch_size=per_device_train_batch_size,
+ rollout_batch_size=rollout_batch_size,
+ num_generations=num_generations,
+ temperature=rl_temperature,
+ max_steps=max_steps,
+ max_seq_length=max_seq_length,
+ max_completion_length=max_completion_length,
+ logging_steps=logging_steps,
+ eval_steps=eval_steps,
+ save_steps=save_steps,
+ reward_source="online",
+ reward_normalization="none",
+ mask_truncated_completions=True,
+ output_dir=str(_rl_output_dir(output_dir)),
+ seed=seed,
+ )
+ trainer = GRPOTrainer(
+ model=model,
+ train_dataset=train_rows,
+ eval_dataset=val_rows,
+ tokenizer=tokenizer,
+ reward_fn=reward,
+ args=config,
+ )
+ result = trainer.train()
+ _write_json(_training_summary_path(output_dir), result)
+ FastLanguageModel.for_inference(model)
+ post_metrics = _run_eval(
+ model,
+ tokenizer,
+ output_dir,
+ model_name=model_name,
+ seed=seed,
+ output_path=_post_outputs_path(output_dir),
+ metrics_path=_post_metrics_path(output_dir),
+ max_tokens=max_completion_length,
+ temperature=DEFAULT_BASELINE_TEMPERATURE,
+ )
+ return {"training": result, "post_eval": post_metrics}
+
+
+def _metric_delta(baseline: Mapping[str, Any], post: Mapping[str, Any]) -> Dict[str, float]:
+ return {
+ key: float(post[key]) - float(baseline[key])
+ for key in (
+ "exact_match",
+ "solution_tag_rate",
+ "parseable_solution_rate",
+ "multiple_solution_tag_rate",
+ "avg_completion_tokens",
+ "avg_reward",
+ )
+ }
+
+
+def _load_eval_rows(path: Path) -> Dict[Tuple[str, str], Dict[str, Any]]:
+ return {
+ (row["split"], row["task_id"]): row
+ for row in _read_jsonl(path)
+ }
+
+
+def _pair_examples(
+ baseline_rows: Mapping[Tuple[str, str], Mapping[str, Any]],
+ post_rows: Mapping[Tuple[str, str], Mapping[str, Any]],
+) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]:
+ improvements: List[Dict[str, Any]] = []
+ regressions: List[Dict[str, Any]] = []
+ unchanged_failures: List[Dict[str, Any]] = []
+ for key in sorted(set(baseline_rows) & set(post_rows)):
+ baseline = baseline_rows[key]
+ post = post_rows[key]
+ pair = {
+ "split": key[0],
+ "task_id": key[1],
+ "expression": baseline["expression"],
+ "answer": baseline["answer"],
+ "baseline_completion_text": baseline["completion_text"],
+ "post_completion_text": post["completion_text"],
+ "baseline_reward": baseline["reward"],
+ "post_reward": post["reward"],
+ "baseline_exact_match": baseline["exact_match"],
+ "post_exact_match": post["exact_match"],
+ "baseline_single_solution_tag": baseline["single_solution_tag"],
+ "post_single_solution_tag": post["single_solution_tag"],
+ }
+ if float(post["reward"]) > float(baseline["reward"]):
+ improvements.append(pair)
+ elif float(post["reward"]) < float(baseline["reward"]):
+ regressions.append(pair)
+ elif not baseline["exact_match"] and not post["exact_match"]:
+ unchanged_failures.append(pair)
+ return improvements, regressions, unchanged_failures
+
+
+def _render_examples_md(title: str, rows: Sequence[Mapping[str, Any]]) -> List[str]:
+ lines = [f"### {title}", ""]
+ if not rows:
+ lines.append("_None_")
+ lines.append("")
+ return lines
+ for row in rows:
+ lines.extend(
+ [
+ f"- `{row['split']}/{row['task_id']}` `{row['expression']}` -> `{row['answer']}`",
+ f" baseline: {row['baseline_completion_text']}",
+ f" post_rl: {row['post_completion_text']}",
+ "",
+ ]
+ )
+ return lines
+
+
+def run_compare(output_dir: Path) -> Dict[str, Any]:
+ baseline_metrics = _read_json(_baseline_metrics_path(output_dir))
+ post_metrics = _read_json(_post_metrics_path(output_dir))
+ comparison = {
+ "baseline": baseline_metrics,
+ "post_rl": post_metrics,
+ "delta": {
+ split: _metric_delta(baseline_metrics["splits"][split], post_metrics["splits"][split])
+ for split in ("val", "test")
+ },
+ "aggregate_delta": _metric_delta(
+ baseline_metrics["aggregate"],
+ post_metrics["aggregate"],
+ ),
+ }
+ baseline_rows = _load_eval_rows(_baseline_outputs_path(output_dir))
+ post_rows = _load_eval_rows(_post_outputs_path(output_dir))
+ improvements, regressions, unchanged_failures = _pair_examples(baseline_rows, post_rows)
+ comparison["sample_counts"] = {
+ "improvements": len(improvements),
+ "regressions": len(regressions),
+ "unchanged_failures": len(unchanged_failures),
+ }
+ _write_json(_comparison_json_path(output_dir), comparison)
+
+ lines = [
+ "# Qwen3 Arithmetic GRPO Comparison",
+ "",
+ "| Split | Baseline Exact | Post Exact | Delta Exact | Baseline Reward | Post Reward | Delta Reward | Baseline Tag | Post Tag | Delta Tag |",
+ "| --- | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: |",
+ ]
+ for split in ("val", "test"):
+ baseline = baseline_metrics["splits"][split]
+ post = post_metrics["splits"][split]
+ delta = comparison["delta"][split]
+ lines.append(
+ "| {split} | {b_exact:.4f} | {p_exact:.4f} | {d_exact:+.4f} | "
+ "{b_reward:.4f} | {p_reward:.4f} | {d_reward:+.4f} | "
+ "{b_tag:.4f} | {p_tag:.4f} | {d_tag:+.4f} |".format(
+ split=split,
+ b_exact=baseline["exact_match"],
+ p_exact=post["exact_match"],
+ d_exact=delta["exact_match"],
+ b_reward=baseline["avg_reward"],
+ p_reward=post["avg_reward"],
+ d_reward=delta["avg_reward"],
+ b_tag=baseline["solution_tag_rate"],
+ p_tag=post["solution_tag_rate"],
+ d_tag=delta["solution_tag_rate"],
+ )
+ )
+ aggregate_delta = comparison["aggregate_delta"]
+ lines.extend(
+ [
+ "",
+ "## Conclusion",
+ "",
+ (
+ "GRPO appears to work on this benchmark."
+ if aggregate_delta["avg_reward"] > 0 or aggregate_delta["solution_tag_rate"] > 0
+ else "GRPO did not show a positive held-out signal on this run."
+ ),
+ "",
+ ]
+ )
+ lines.extend(_render_examples_md("Improvements", improvements[:10]))
+ lines.extend(_render_examples_md("Regressions", regressions[:5]))
+ lines.extend(_render_examples_md("Unchanged Failures", unchanged_failures[:5]))
+ _comparison_md_path(output_dir).write_text("\n".join(lines))
+ return comparison
+
+
+def _parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
+ parser = argparse.ArgumentParser(description=__doc__)
+ parser.add_argument("mode", choices=["generate", "baseline", "train", "compare", "all"])
+ parser.add_argument("--output-dir", type=Path, default=DEFAULT_OUTPUT_DIR)
+ parser.add_argument("--model-name", default=DEFAULT_MODEL_NAME)
+ parser.add_argument("--train-size", type=int, default=DEFAULT_TRAIN_SIZE)
+ parser.add_argument("--val-size", type=int, default=DEFAULT_VAL_SIZE)
+ parser.add_argument("--test-size", type=int, default=DEFAULT_TEST_SIZE)
+ parser.add_argument("--seed", type=int, default=DEFAULT_SEED)
+ parser.add_argument("--force-generate", action="store_true")
+ parser.add_argument("--max-completion-length", type=int, default=DEFAULT_MAX_COMPLETION_LENGTH)
+ parser.add_argument("--max-seq-length", type=int, default=DEFAULT_MAX_SEQ_LENGTH)
+ parser.add_argument("--max-steps", type=int, default=500)
+ parser.add_argument("--learning-rate", type=float, default=1e-6)
+ parser.add_argument("--per-device-train-batch-size", type=int, default=2)
+ parser.add_argument("--rollout-batch-size", type=int, default=4)
+ parser.add_argument("--num-generations", type=int, default=4)
+ parser.add_argument("--rl-temperature", type=float, default=DEFAULT_RL_TEMPERATURE)
+ parser.add_argument("--lora-rank", type=int, default=DEFAULT_LORA_RANK)
+ parser.add_argument("--logging-steps", type=int, default=10)
+ parser.add_argument("--eval-steps", type=int, default=50)
+ parser.add_argument("--save-steps", type=int, default=100)
+ return parser.parse_args(argv)
+
+
+def main(argv: Optional[Sequence[str]] = None) -> int:
+ args = _parse_args(argv)
+ generate_benchmark_splits(
+ args.output_dir,
+ train_size=args.train_size,
+ val_size=args.val_size,
+ test_size=args.test_size,
+ seed=args.seed,
+ force=args.force_generate,
+ )
+ if args.mode == "generate":
+ return 0
+ if args.mode == "baseline":
+ run_baseline(
+ args.output_dir,
+ model_name=args.model_name,
+ seed=args.seed,
+ max_seq_length=args.max_seq_length,
+ max_completion_length=args.max_completion_length,
+ )
+ return 0
+ if args.mode == "train":
+ run_training(
+ args.output_dir,
+ model_name=args.model_name,
+ seed=args.seed,
+ max_seq_length=args.max_seq_length,
+ max_completion_length=args.max_completion_length,
+ max_steps=args.max_steps,
+ learning_rate=args.learning_rate,
+ per_device_train_batch_size=args.per_device_train_batch_size,
+ rollout_batch_size=args.rollout_batch_size,
+ num_generations=args.num_generations,
+ rl_temperature=args.rl_temperature,
+ lora_rank=args.lora_rank,
+ logging_steps=args.logging_steps,
+ eval_steps=args.eval_steps,
+ save_steps=args.save_steps,
+ )
+ return 0
+ if args.mode == "compare":
+ run_compare(args.output_dir)
+ return 0
+ run_baseline(
+ args.output_dir,
+ model_name=args.model_name,
+ seed=args.seed,
+ max_seq_length=args.max_seq_length,
+ max_completion_length=args.max_completion_length,
+ )
+ run_training(
+ args.output_dir,
+ model_name=args.model_name,
+ seed=args.seed,
+ max_seq_length=args.max_seq_length,
+ max_completion_length=args.max_completion_length,
+ max_steps=args.max_steps,
+ learning_rate=args.learning_rate,
+ per_device_train_batch_size=args.per_device_train_batch_size,
+ rollout_batch_size=args.rollout_batch_size,
+ num_generations=args.num_generations,
+ rl_temperature=args.rl_temperature,
+ lora_rank=args.lora_rank,
+ logging_steps=args.logging_steps,
+ eval_steps=args.eval_steps,
+ save_steps=args.save_steps,
+ )
+ run_compare(args.output_dir)
+ return 0
+
+
+if __name__ == "__main__":
+ raise SystemExit(main())
diff --git a/mlx_tune/losses.py b/mlx_tune/losses.py
index 5625cd0..219e194 100644
--- a/mlx_tune/losses.py
+++ b/mlx_tune/losses.py
@@ -1,7 +1,7 @@
"""
Loss functions for MLX-Tune RL training.
-Provides proper loss implementations for:
+Provides native MLX losses and reference-logprob helpers for:
- DPO (Direct Preference Optimization)
- ORPO (Odds Ratio Preference Optimization)
- GRPO (Group Relative Policy Optimization)
@@ -9,10 +9,77 @@
- SimPO (Simple Preference Optimization)
"""
-from typing import Optional, Tuple, Callable, List, Any
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
import mlx.core as mx
import mlx.nn as nn
+from mlx_tune._rl_runtime import (
+ PolicyEvalBatch,
+ build_token_mask,
+ collect_rollouts,
+ completion_token_mask,
+ compute_advantages,
+ evaluate_rewards,
+ kl_against_reference,
+ length_mask,
+ make_policy_eval_batch,
+ normalize_logprobs,
+ sample_completion,
+ score_policy,
+ score_policy_in_chunks,
+)
+
+
+def _policy_eval_from_padded(
+ input_ids: mx.array,
+ lengths: mx.array,
+ mode: str = "sequence",
+ prompt_lengths: Optional[mx.array] = None,
+ completion_lengths: Optional[mx.array] = None,
+ rollout_logprobs: Optional[mx.array] = None,
+ old_logprobs: Optional[mx.array] = None,
+ old_token_logprobs: Optional[mx.array] = None,
+ reference_logprobs: Optional[mx.array] = None,
+ value_predictions: Optional[mx.array] = None,
+ returns: Optional[mx.array] = None,
+ advantages: Optional[mx.array] = None,
+ labels: Optional[mx.array] = None,
+) -> PolicyEvalBatch:
+ return PolicyEvalBatch(
+ input_ids=input_ids,
+ sequence_lengths=lengths,
+ prompt_lengths=prompt_lengths,
+ completion_lengths=completion_lengths,
+ token_mask=build_token_mask(
+ input_ids=input_ids,
+ sequence_lengths=lengths,
+ mode=mode,
+ prompt_lengths=prompt_lengths,
+ completion_lengths=completion_lengths,
+ ),
+ rollout_logprobs=rollout_logprobs,
+ old_logprobs=old_logprobs if old_logprobs is not None else rollout_logprobs,
+ old_token_logprobs=old_token_logprobs,
+ reference_logprobs=reference_logprobs,
+ value_predictions=value_predictions,
+ returns=returns,
+ advantages=advantages,
+ labels=labels,
+ )
+
+
+def _token_log_probs(
+ model: Any,
+ input_ids: mx.array,
+ temperature: float = 1.0,
+) -> mx.array:
+ batch = _policy_eval_from_padded(
+ input_ids=input_ids,
+ lengths=mx.array([input_ids.shape[1]] * input_ids.shape[0]),
+ )
+ return score_policy(model, batch, mode="sequence", temperature=temperature).token_logprobs
+
def compute_log_probs(
model: Any,
@@ -20,48 +87,12 @@ def compute_log_probs(
attention_mask: Optional[mx.array] = None,
) -> mx.array:
"""
- Compute per-token log probabilities for a batch of sequences.
-
- Args:
- model: The language model.
- input_ids: Token IDs of shape [batch_size, seq_len].
- attention_mask: Optional mask of shape [batch_size, seq_len].
-
- Returns:
- Log probabilities of shape [batch_size] (sum over sequence).
+ Compute per-sequence log probabilities for a batch of sequences.
"""
- # Get inputs (all tokens except last) and targets (all tokens except first)
- inputs = input_ids[:, :-1]
- targets = input_ids[:, 1:]
-
- # Forward pass to get logits
- logits = model(inputs) # [batch_size, seq_len-1, vocab_size]
-
- # Compute log softmax to get log probabilities
- log_probs = nn.log_softmax(logits, axis=-1) # [batch_size, seq_len-1, vocab_size]
-
- # Gather log probs for the actual target tokens
- # targets: [batch_size, seq_len-1]
- # We need to get log_probs[b, t, targets[b, t]] for each position
- batch_size, seq_len = targets.shape
-
- # Use advanced indexing to gather target log probs
- target_log_probs = mx.take_along_axis(
- log_probs,
- targets[:, :, None], # [batch_size, seq_len-1, 1]
- axis=-1
- ).squeeze(-1) # [batch_size, seq_len-1]
-
- # Apply attention mask if provided
+ token_log_probs = _token_log_probs(model, input_ids)
if attention_mask is not None:
- # Shift mask to match targets
- mask = attention_mask[:, 1:]
- target_log_probs = target_log_probs * mask
-
- # Sum log probs over sequence to get sequence log probability
- sequence_log_probs = target_log_probs.sum(axis=-1) # [batch_size]
-
- return sequence_log_probs
+ token_log_probs = token_log_probs * attention_mask[:, 1:].astype(token_log_probs.dtype)
+ return token_log_probs.sum(axis=-1)
def compute_log_probs_with_lengths(
@@ -70,38 +101,70 @@ def compute_log_probs_with_lengths(
lengths: mx.array,
) -> mx.array:
"""
- Compute per-token log probabilities with explicit length masking.
+ Compute per-sequence log probabilities with explicit length masking.
+ """
+ batch = _policy_eval_from_padded(input_ids=input_ids, lengths=lengths, mode="sequence")
+ return score_policy(model, batch, mode="sequence").summed_logprobs
- Args:
- model: The language model.
- input_ids: Token IDs of shape [batch_size, seq_len].
- lengths: Sequence lengths of shape [batch_size].
- Returns:
- Log probabilities of shape [batch_size] (sum over valid tokens).
+def compute_completion_log_probs(
+ model: Any,
+ input_ids: mx.array,
+ prompt_lengths: mx.array,
+ completion_lengths: mx.array,
+ temperature: float = 1.0,
+) -> mx.array:
"""
- inputs = input_ids[:, :-1]
- targets = input_ids[:, 1:]
-
- logits = model(inputs)
- log_probs = nn.log_softmax(logits, axis=-1)
+ Compute log probabilities over completion tokens only.
+ """
+ sequence_lengths = prompt_lengths + completion_lengths
+ batch = _policy_eval_from_padded(
+ input_ids=input_ids,
+ lengths=sequence_lengths,
+ mode="completion",
+ prompt_lengths=prompt_lengths,
+ completion_lengths=completion_lengths,
+ )
+ return score_policy(model, batch, mode="completion", temperature=temperature).summed_logprobs
+
+
+def _batched_sequence_log_probs(
+ model: Any,
+ input_ids: mx.array,
+ lengths: mx.array,
+ batch_size: int = 8,
+) -> mx.array:
+ batch = _policy_eval_from_padded(input_ids=input_ids, lengths=lengths, mode="sequence")
+ return score_policy_in_chunks(model, batch, batch_size=batch_size, mode="sequence").summed_logprobs
- target_log_probs = mx.take_along_axis(
- log_probs,
- targets[:, :, None],
- axis=-1
- ).squeeze(-1)
- # Create mask from lengths
- seq_len = targets.shape[1]
- positions = mx.arange(seq_len)[None, :] # [1, seq_len]
- mask = positions < lengths[:, None] # [batch_size, seq_len]
+def precompute_preference_reference_logprobs(
+ model: Any,
+ chosen_ids: mx.array,
+ rejected_ids: mx.array,
+ chosen_lengths: mx.array,
+ rejected_lengths: mx.array,
+ batch_size: int = 8,
+) -> Tuple[mx.array, mx.array]:
+ """
+ Precompute frozen-reference log probabilities for preference pairs.
+ """
+ ref_chosen = _batched_sequence_log_probs(model, chosen_ids, chosen_lengths, batch_size)
+ ref_rejected = _batched_sequence_log_probs(model, rejected_ids, rejected_lengths, batch_size)
+ return mx.stop_gradient(ref_chosen), mx.stop_gradient(ref_rejected)
- # Apply mask and sum
- masked_log_probs = target_log_probs * mask.astype(target_log_probs.dtype)
- sequence_log_probs = masked_log_probs.sum(axis=-1)
- return sequence_log_probs
+def precompute_kto_reference_logprobs(
+ model: Any,
+ input_ids: mx.array,
+ lengths: mx.array,
+ batch_size: int = 8,
+) -> mx.array:
+ """
+ Precompute frozen-reference log probabilities for KTO samples.
+ """
+ ref = _batched_sequence_log_probs(model, input_ids, lengths, batch_size)
+ return mx.stop_gradient(ref)
def dpo_loss(
@@ -116,60 +179,51 @@ def dpo_loss(
label_smoothing: float = 0.0,
) -> Tuple[mx.array, mx.array]:
"""
- Compute DPO (Direct Preference Optimization) loss.
-
- DPO Loss: -log(sigmoid(beta * (log_ratio_chosen - log_ratio_rejected)))
-
- Where:
- log_ratio = log_pi(y|x) - log_ref(y|x)
-
- Args:
- model: The policy model being trained.
- chosen_ids: Token IDs for chosen responses [batch_size, seq_len].
- rejected_ids: Token IDs for rejected responses [batch_size, seq_len].
- chosen_lengths: Lengths of chosen sequences [batch_size].
- rejected_lengths: Lengths of rejected sequences [batch_size].
- beta: KL penalty coefficient (temperature).
- reference_chosen_logprobs: Pre-computed reference log probs for chosen.
- reference_rejected_logprobs: Pre-computed reference log probs for rejected.
- label_smoothing: Label smoothing coefficient.
-
- Returns:
- Tuple of (loss, num_tokens).
+ Compute DPO loss.
"""
- # Compute policy model log probabilities
- log_pi_chosen = compute_log_probs_with_lengths(model, chosen_ids, chosen_lengths)
- log_pi_rejected = compute_log_probs_with_lengths(model, rejected_ids, rejected_lengths)
-
- # Handle reference model log probabilities
- if reference_chosen_logprobs is None or reference_rejected_logprobs is None:
- # Use current model with stop_gradient as reference (memory efficient)
- log_ref_chosen = mx.stop_gradient(log_pi_chosen)
- log_ref_rejected = mx.stop_gradient(log_pi_rejected)
- else:
- log_ref_chosen = reference_chosen_logprobs
- log_ref_rejected = reference_rejected_logprobs
-
- # Compute log ratios
- log_ratio_chosen = log_pi_chosen - log_ref_chosen
- log_ratio_rejected = log_pi_rejected - log_ref_rejected
-
- # DPO loss: -log(sigmoid(beta * (log_ratio_chosen - log_ratio_rejected)))
- logits = beta * (log_ratio_chosen - log_ratio_rejected)
-
+ chosen_batch = score_policy(
+ model,
+ _policy_eval_from_padded(
+ input_ids=chosen_ids,
+ lengths=chosen_lengths,
+ mode="sequence",
+ reference_logprobs=reference_chosen_logprobs,
+ ),
+ mode="sequence",
+ )
+ rejected_batch = score_policy(
+ model,
+ _policy_eval_from_padded(
+ input_ids=rejected_ids,
+ lengths=rejected_lengths,
+ mode="sequence",
+ reference_logprobs=reference_rejected_logprobs,
+ ),
+ mode="sequence",
+ )
+
+ log_pi_chosen = chosen_batch.summed_logprobs
+ log_pi_rejected = rejected_batch.summed_logprobs
+ log_ref_chosen = (
+ mx.stop_gradient(log_pi_chosen)
+ if chosen_batch.reference_logprobs is None
+ else chosen_batch.reference_logprobs
+ )
+ log_ref_rejected = (
+ mx.stop_gradient(log_pi_rejected)
+ if rejected_batch.reference_logprobs is None
+ else rejected_batch.reference_logprobs
+ )
+
+ logits = beta * ((log_pi_chosen - log_ref_chosen) - (log_pi_rejected - log_ref_rejected))
if label_smoothing > 0:
- # Smooth the labels
losses = (
-nn.log_sigmoid(logits) * (1 - label_smoothing)
- nn.log_sigmoid(-logits) * label_smoothing
)
else:
losses = -nn.log_sigmoid(logits)
-
- loss = mx.mean(losses)
- ntoks = chosen_lengths.sum() + rejected_lengths.sum()
-
- return loss, ntoks
+ return mx.mean(losses), chosen_lengths.sum() + rejected_lengths.sum()
def orpo_loss(
@@ -181,101 +235,51 @@ def orpo_loss(
beta: float = 0.1,
) -> Tuple[mx.array, mx.array]:
"""
- Compute ORPO (Odds Ratio Preference Optimization) loss.
-
- ORPO combines SFT loss with odds ratio preference loss:
- L = L_SFT + beta * L_OR
-
- Where:
- L_SFT = -log P(chosen)
- L_OR = -log(sigmoid(log(odds_ratio)))
- odds_ratio = P(chosen) / P(rejected)
-
- Args:
- model: The model being trained.
- chosen_ids: Token IDs for chosen responses.
- rejected_ids: Token IDs for rejected responses.
- chosen_lengths: Lengths of chosen sequences.
- rejected_lengths: Lengths of rejected sequences.
- beta: Weight for odds ratio loss.
-
- Returns:
- Tuple of (loss, num_tokens).
+ Compute ORPO loss.
"""
- # Compute log probabilities
log_pi_chosen = compute_log_probs_with_lengths(model, chosen_ids, chosen_lengths)
log_pi_rejected = compute_log_probs_with_lengths(model, rejected_ids, rejected_lengths)
+ avg_log_pi_chosen = normalize_logprobs(log_pi_chosen, chosen_lengths, mode="mean")
- # SFT loss on chosen (negative log likelihood)
- # Normalize by length for fair comparison
- avg_log_pi_chosen = log_pi_chosen / chosen_lengths.astype(log_pi_chosen.dtype)
- sft_loss = -mx.mean(avg_log_pi_chosen)
-
- # Odds ratio loss
- # log(odds_ratio) = log(P_chosen) - log(P_rejected)
- log_odds = log_pi_chosen - log_pi_rejected
- or_loss = -mx.mean(nn.log_sigmoid(log_odds))
-
- # Combined loss
- loss = sft_loss + beta * or_loss
-
- ntoks = chosen_lengths.sum() + rejected_lengths.sum()
- return loss, ntoks
+ sft_term = -mx.mean(avg_log_pi_chosen)
+ odds_term = -mx.mean(nn.log_sigmoid(log_pi_chosen - log_pi_rejected))
+ loss = sft_term + beta * odds_term
+ return loss, chosen_lengths.sum() + rejected_lengths.sum()
def kto_loss(
model: Any,
input_ids: mx.array,
lengths: mx.array,
- labels: mx.array, # 1 for positive, 0 for negative
+ labels: mx.array,
beta: float = 0.1,
reference_logprobs: Optional[mx.array] = None,
) -> Tuple[mx.array, mx.array]:
"""
- Compute KTO (Kahneman-Tversky Optimization) loss.
-
- KTO uses prospect theory with asymmetric treatment of gains and losses:
- L = -E[w(y) * log(sigmoid(beta * log_ratio))]
-
- Where w(y) = lambda if y is positive, 1 if y is negative.
-
- Args:
- model: The model being trained.
- input_ids: Token IDs [batch_size, seq_len].
- lengths: Sequence lengths [batch_size].
- labels: Binary labels (1=positive, 0=negative) [batch_size].
- beta: Temperature coefficient.
- reference_logprobs: Pre-computed reference log probs.
-
- Returns:
- Tuple of (loss, num_tokens).
+ Compute KTO loss.
"""
- # Compute policy log probs
- log_pi = compute_log_probs_with_lengths(model, input_ids, lengths)
-
- # Handle reference
- if reference_logprobs is None:
- log_ref = mx.stop_gradient(log_pi)
- else:
- log_ref = reference_logprobs
-
+ batch = score_policy(
+ model,
+ _policy_eval_from_padded(
+ input_ids=input_ids,
+ lengths=lengths,
+ mode="sequence",
+ reference_logprobs=reference_logprobs,
+ labels=labels,
+ ),
+ mode="sequence",
+ )
+ log_pi = batch.summed_logprobs
+ log_ref = mx.stop_gradient(log_pi) if batch.reference_logprobs is None else batch.reference_logprobs
log_ratio = log_pi - log_ref
- # KTO weights (lambda for positive, 1 for negative)
- lambda_weight = 1.0 # Can be tuned
- weights = mx.where(labels > 0.5, lambda_weight, 1.0)
-
- # Loss with asymmetric weights
positive_mask = labels > 0.5
negative_mask = ~positive_mask
-
+ weights = mx.where(positive_mask, 1.0, 1.0)
positive_loss = -nn.log_sigmoid(beta * log_ratio) * positive_mask
negative_loss = -nn.log_sigmoid(-beta * log_ratio) * negative_mask
-
loss = mx.mean(weights * (positive_loss + negative_loss))
- ntoks = lengths.sum()
-
- return loss, ntoks
+ return loss, lengths.sum()
def simpo_loss(
@@ -288,39 +292,138 @@ def simpo_loss(
gamma: float = 0.5,
) -> Tuple[mx.array, mx.array]:
"""
- Compute SimPO (Simple Preference Optimization) loss.
+ Compute SimPO loss.
+ """
+ log_pi_chosen = compute_log_probs_with_lengths(model, chosen_ids, chosen_lengths)
+ log_pi_rejected = compute_log_probs_with_lengths(model, rejected_ids, rejected_lengths)
+ r_chosen = normalize_logprobs(log_pi_chosen, chosen_lengths, mode="mean")
+ r_rejected = normalize_logprobs(log_pi_rejected, rejected_lengths, mode="mean")
+
+ logits = beta * (r_chosen - r_rejected - gamma)
+ return -mx.mean(nn.log_sigmoid(logits)), chosen_lengths.sum() + rejected_lengths.sum()
- SimPO simplifies DPO by removing the need for a reference model:
- L = -log(sigmoid(beta * (r_chosen - r_rejected - gamma)))
- Where r = log P(y|x) / |y| (length-normalized log prob).
+def pairwise_reward_loss(
+ chosen_scores: mx.array,
+ rejected_scores: mx.array,
+ margin: float = 0.0,
+) -> mx.array:
+ """
+ Logistic pairwise preference loss over scalar reward scores.
+ """
+ return -mx.mean(nn.log_sigmoid(chosen_scores - rejected_scores - margin))
+
+
+def reward_model_pairwise_loss(
+ reward_model: Any,
+ chosen_input_ids: mx.array,
+ rejected_input_ids: mx.array,
+ chosen_sequence_lengths: mx.array,
+ rejected_sequence_lengths: mx.array,
+ chosen_prompt_lengths: Optional[mx.array] = None,
+ rejected_prompt_lengths: Optional[mx.array] = None,
+ chosen_completion_lengths: Optional[mx.array] = None,
+ rejected_completion_lengths: Optional[mx.array] = None,
+ margin: float = 0.0,
+) -> Tuple[mx.array, Dict[str, mx.array]]:
+ """
+ Compute pairwise reward-model loss for chosen/rejected sequences.
+ """
+ chosen_scores, rejected_scores = reward_model.score_pairs(
+ chosen_input_ids=chosen_input_ids,
+ rejected_input_ids=rejected_input_ids,
+ chosen_sequence_lengths=chosen_sequence_lengths,
+ rejected_sequence_lengths=rejected_sequence_lengths,
+ chosen_prompt_lengths=chosen_prompt_lengths,
+ rejected_prompt_lengths=rejected_prompt_lengths,
+ chosen_completion_lengths=chosen_completion_lengths,
+ rejected_completion_lengths=rejected_completion_lengths,
+ )
+ loss = pairwise_reward_loss(chosen_scores, rejected_scores, margin=margin)
+ return loss, {
+ "chosen_scores": chosen_scores,
+ "rejected_scores": rejected_scores,
+ }
+
+
+def value_regression_loss(
+ predictions: mx.array,
+ targets: mx.array,
+ loss_type: str = "mse",
+) -> mx.array:
+ """
+ Compute a pointwise scalar regression loss.
+ """
+ if loss_type == "mse":
+ return mx.mean((predictions - targets) ** 2)
+ if loss_type == "mae":
+ return mx.mean(mx.abs(predictions - targets))
+ raise ValueError(f"Unsupported scalar regression loss: {loss_type}")
- Args:
- model: The model being trained.
- chosen_ids: Token IDs for chosen responses.
- rejected_ids: Token IDs for rejected responses.
- chosen_lengths: Lengths of chosen sequences.
- rejected_lengths: Lengths of rejected sequences.
- beta: Temperature coefficient.
- gamma: Target reward margin.
- Returns:
- Tuple of (loss, num_tokens).
+def value_model_regression_loss(
+ value_model: Any,
+ input_ids: mx.array,
+ sequence_lengths: mx.array,
+ targets: mx.array,
+ prompt_lengths: Optional[mx.array] = None,
+ completion_lengths: Optional[mx.array] = None,
+ loss_type: str = "mse",
+) -> Tuple[mx.array, mx.array]:
"""
- # Compute log probabilities
- log_pi_chosen = compute_log_probs_with_lengths(model, chosen_ids, chosen_lengths)
- log_pi_rejected = compute_log_probs_with_lengths(model, rejected_ids, rejected_lengths)
+ Compute pointwise regression loss for a scalar value model.
+ """
+ predictions = value_model.predict(
+ input_ids,
+ sequence_lengths=sequence_lengths,
+ prompt_lengths=prompt_lengths,
+ completion_lengths=completion_lengths,
+ )
+ return value_regression_loss(predictions, targets, loss_type=loss_type), predictions
- # Length-normalize to get "reward"
- r_chosen = log_pi_chosen / chosen_lengths.astype(log_pi_chosen.dtype)
- r_rejected = log_pi_rejected / rejected_lengths.astype(log_pi_rejected.dtype)
- # SimPO loss
- logits = beta * (r_chosen - r_rejected - gamma)
- loss = -mx.mean(nn.log_sigmoid(logits))
+def reward_model_regression_loss(
+ reward_model: Any,
+ input_ids: mx.array,
+ sequence_lengths: mx.array,
+ targets: mx.array,
+ prompt_lengths: Optional[mx.array] = None,
+ completion_lengths: Optional[mx.array] = None,
+ loss_type: str = "mse",
+) -> Tuple[mx.array, mx.array]:
+ """
+ Compute pointwise regression loss for a scalar reward model.
+ """
+ predictions = reward_model.score(
+ input_ids,
+ sequence_lengths=sequence_lengths,
+ prompt_lengths=prompt_lengths,
+ completion_lengths=completion_lengths,
+ )
+ return value_regression_loss(predictions, targets, loss_type=loss_type), predictions
+
- ntoks = chosen_lengths.sum() + rejected_lengths.sum()
- return loss, ntoks
+def scalar_loss_metrics(loss: mx.array, predictions: mx.array, targets: mx.array) -> Dict[str, float]:
+ """
+ Compute generic scalar regression metrics.
+ """
+ mae = mx.mean(mx.abs(predictions - targets))
+ mse = mx.mean((predictions - targets) ** 2)
+ return {
+ "loss": float(loss.item()),
+ "mae": float(mae.item()),
+ "mse": float(mse.item()),
+ }
+
+
+def pairwise_ranking_accuracy(
+ chosen_scores: mx.array,
+ rejected_scores: mx.array,
+) -> float:
+ """
+ Compute pairwise ranking accuracy for scalar preference scores.
+ """
+ return float(mx.mean((chosen_scores > rejected_scores).astype(mx.float32)).item())
def sft_loss(
@@ -329,37 +432,71 @@ def sft_loss(
lengths: mx.array,
) -> Tuple[mx.array, mx.array]:
"""
- Standard Supervised Fine-Tuning (cross-entropy) loss.
-
- Args:
- model: The model being trained.
- input_ids: Token IDs [batch_size, seq_len].
- lengths: Sequence lengths [batch_size].
-
- Returns:
- Tuple of (loss, num_tokens).
+ Standard supervised fine-tuning loss.
"""
inputs = input_ids[:, :-1]
targets = input_ids[:, 1:]
-
logits = model(inputs)
- # Create length mask
- seq_len = targets.shape[1]
- positions = mx.arange(seq_len)[None, :]
- mask = positions < lengths[:, None]
-
- # Cross entropy loss
- ce = nn.losses.cross_entropy(logits, targets, reduction='none')
- masked_ce = ce * mask.astype(ce.dtype)
-
+ mask = length_mask(lengths, targets.shape[1]).astype(logits.dtype)
+ ce = nn.losses.cross_entropy(logits, targets, reduction="none")
+ masked_ce = ce * mask
ntoks = mask.sum()
- loss = masked_ce.sum() / ntoks
+ return masked_ce.sum() / ntoks, ntoks
- return loss, ntoks
+def ppo_sequence_loss(
+ model: Any,
+ batch: PolicyEvalBatch,
+ beta: float = 0.0,
+ clip_epsilon: float = 0.2,
+ temperature: float = 1.0,
+ reference_model: Optional[Any] = None,
+) -> Tuple[mx.array, Dict[str, mx.array]]:
+ """
+ Compute a clipped PPO objective over full sampled completions.
+ """
+ scored_batch = score_policy(
+ model,
+ batch,
+ mode="completion",
+ reference_model=reference_model if batch.reference_logprobs is None else None,
+ temperature=temperature,
+ )
+ old_logprobs = batch.old_logprobs if batch.old_logprobs is not None else batch.rollout_logprobs
+ if old_logprobs is None:
+ raise ValueError("PPO loss requires stored old log probabilities.")
+ if batch.advantages is None:
+ raise ValueError("PPO loss requires advantages.")
+
+ ratios = mx.exp(scored_batch.summed_logprobs - old_logprobs)
+ clipped_ratios = mx.clip(ratios, 1.0 - clip_epsilon, 1.0 + clip_epsilon)
+ unclipped_objective = ratios * batch.advantages
+ clipped_objective = clipped_ratios * batch.advantages
+ policy_objective = mx.minimum(unclipped_objective, clipped_objective)
+
+ kl_penalty = mx.zeros_like(policy_objective)
+ if scored_batch.reference_logprobs is not None:
+ kl_scored_batch = scored_batch
+ if temperature != 1.0:
+ kl_scored_batch = score_policy(
+ model,
+ batch,
+ mode="completion",
+ reference_model=reference_model if batch.reference_logprobs is None else None,
+ temperature=1.0,
+ )
+ kl_penalty = kl_against_reference(
+ kl_scored_batch.summed_logprobs,
+ kl_scored_batch.reference_logprobs,
+ )
+ loss = -mx.mean(policy_objective - beta * kl_penalty)
+ return loss, {
+ "policy_logprobs": scored_batch.summed_logprobs,
+ "ratios": ratios,
+ "kl_penalty": kl_penalty,
+ }
-# GRPO-specific functions
def generate_with_log_probs(
model: Any,
@@ -369,57 +506,136 @@ def generate_with_log_probs(
temperature: float = 0.7,
) -> Tuple[mx.array, mx.array]:
"""
- Generate a completion and return token IDs with their log probabilities.
-
- Args:
- model: The language model.
- tokenizer: The tokenizer.
- prompt_ids: Prompt token IDs [seq_len].
- max_tokens: Maximum tokens to generate.
- temperature: Sampling temperature.
-
- Returns:
- Tuple of (generated_ids, log_probs) where:
- generated_ids: [prompt_len + gen_len]
- log_probs: [gen_len] log probability of each generated token
+ Generate a sampled completion and return sampled-token log probabilities.
"""
- generated_ids = list(prompt_ids.tolist()) if hasattr(prompt_ids, 'tolist') else list(prompt_ids)
- log_probs = []
-
- # Current sequence
- x = mx.array([generated_ids])
-
- for _ in range(max_tokens):
- # Get logits for next token
- logits = model(x)[:, -1, :] # [1, vocab_size]
-
- # Apply temperature
- if temperature > 0:
- logits = logits / temperature
- probs = mx.softmax(logits, axis=-1)
- # Sample from categorical distribution
- next_token = mx.random.categorical(mx.log(probs + 1e-10))
- else:
- # Greedy decoding
- next_token = mx.argmax(logits, axis=-1)
-
- next_token_id = next_token.item()
-
- # Get log probability of sampled token
- log_prob = nn.log_softmax(logits, axis=-1)[0, next_token_id]
- log_probs.append(log_prob)
-
- # Append to sequence
- generated_ids.append(next_token_id)
-
- # Check for EOS
- if hasattr(tokenizer, 'eos_token_id') and next_token_id == tokenizer.eos_token_id:
- break
-
- # Update input sequence
- x = mx.array([generated_ids])
-
- return mx.array(generated_ids), mx.stack(log_probs)
+ generated = sample_completion(
+ policy=model,
+ tokenizer=tokenizer,
+ prompt_ids=prompt_ids.tolist() if hasattr(prompt_ids, "tolist") else list(prompt_ids),
+ max_tokens=max_tokens,
+ temperature=temperature,
+ collect_sample_stats=False,
+ )
+ token_logprobs = generated["sampled_logprobs"]
+ if token_logprobs:
+ return mx.array(generated["generated_ids"]), mx.array(token_logprobs, dtype=mx.float32)
+ return mx.array(generated["generated_ids"]), mx.zeros((0,), dtype=mx.float32)
+
+
+def grpo_recompute_loss(
+ model: Any,
+ reference_model: Any,
+ input_ids: mx.array,
+ prompt_lengths: mx.array,
+ completion_lengths: mx.array,
+ rollout_logprobs: mx.array,
+ advantages: mx.array,
+ beta: float = 0.04,
+ clip_epsilon: float = 0.2,
+ epsilon_low: Optional[float] = None,
+ epsilon_high: Optional[float] = None,
+ temperature: float = 1.0,
+ loss_type: str = "grpo",
+ max_completion_length: Optional[int] = None,
+ old_token_logprobs: Optional[mx.array] = None,
+ reference_logprobs: Optional[mx.array] = None,
+) -> Tuple[mx.array, mx.array]:
+ """
+ Recompute GRPO-family losses on fixed sampled completions.
+ """
+ epsilon_low = clip_epsilon if epsilon_low is None else epsilon_low
+ epsilon_high = clip_epsilon if epsilon_high is None else epsilon_high
+ batch = _policy_eval_from_padded(
+ input_ids=input_ids,
+ lengths=prompt_lengths + completion_lengths,
+ mode="completion",
+ prompt_lengths=prompt_lengths,
+ completion_lengths=completion_lengths,
+ rollout_logprobs=rollout_logprobs,
+ old_logprobs=rollout_logprobs,
+ old_token_logprobs=old_token_logprobs,
+ reference_logprobs=reference_logprobs,
+ advantages=advantages,
+ )
+ scored_batch = score_policy(
+ model,
+ batch,
+ mode="completion",
+ reference_model=reference_model if reference_logprobs is None else None,
+ temperature=temperature,
+ )
+
+ sequence_logprobs = scored_batch.summed_logprobs
+ normalized_current = normalize_logprobs(sequence_logprobs, completion_lengths, mode="mean")
+ kl_scored_batch = scored_batch
+ if beta != 0.0 and temperature != 1.0:
+ kl_scored_batch = score_policy(
+ model,
+ batch,
+ mode="completion",
+ reference_model=reference_model if reference_logprobs is None else None,
+ temperature=1.0,
+ )
+ if kl_scored_batch.reference_logprobs is not None:
+ normalized_reference = normalize_logprobs(
+ kl_scored_batch.reference_logprobs,
+ completion_lengths,
+ mode="mean",
+ )
+ normalized_kl_current = normalize_logprobs(
+ kl_scored_batch.summed_logprobs,
+ completion_lengths,
+ mode="mean",
+ )
+ sequence_kl_penalty = kl_against_reference(normalized_kl_current, normalized_reference)
+ else:
+ sequence_kl_penalty = mx.zeros_like(advantages)
+
+ if loss_type == "gspo":
+ sequence_ratios = mx.exp(normalized_current - normalize_logprobs(rollout_logprobs, completion_lengths, mode="mean"))
+ clipped_sequence_ratios = mx.clip(sequence_ratios, 1.0 - epsilon_low, 1.0 + epsilon_high)
+ policy_objective = mx.minimum(sequence_ratios * advantages, clipped_sequence_ratios * advantages)
+ loss = -mx.mean(policy_objective - beta * sequence_kl_penalty)
+ return loss, completion_lengths.sum()
+
+ if old_token_logprobs is None:
+ sequence_ratios = mx.exp(sequence_logprobs - rollout_logprobs)
+ clipped_sequence_ratios = mx.clip(sequence_ratios, 1.0 - epsilon_low, 1.0 + epsilon_high)
+ policy_objective = mx.minimum(
+ sequence_ratios * advantages,
+ clipped_sequence_ratios * advantages,
+ )
+ loss = -mx.mean(policy_objective - beta * sequence_kl_penalty)
+ return loss, completion_lengths.sum()
+
+ mask = completion_token_mask(input_ids, prompt_lengths, completion_lengths).astype(mx.float32)
+ current_token_logprobs = scored_batch.token_logprobs
+ token_ratios = mx.exp(current_token_logprobs - old_token_logprobs)
+ clipped_token_ratios = mx.clip(token_ratios, 1.0 - epsilon_low, 1.0 + epsilon_high)
+ per_token_objective = mx.minimum(
+ token_ratios * advantages[:, None],
+ clipped_token_ratios * advantages[:, None],
+ )
+ masked_objective = per_token_objective * mask
+
+ if loss_type == "grpo":
+ policy_objective = masked_objective.sum(axis=-1) / mx.maximum(mask.sum(axis=-1), 1.0)
+ loss = -mx.mean(policy_objective - beta * sequence_kl_penalty)
+ elif loss_type == "bnpo" or loss_type == "dapo":
+ normalizer = mx.maximum(mask.sum(), 1.0)
+ loss = -(
+ masked_objective.sum() / normalizer
+ - beta * mx.mean(sequence_kl_penalty)
+ )
+ elif loss_type == "dr_grpo":
+ denominator = float(max_completion_length or int(mx.max(completion_lengths).item()) or 1)
+ loss = -(
+ masked_objective.sum() / max(completion_lengths.shape[0] * denominator, 1.0)
+ - beta * mx.mean(sequence_kl_penalty)
+ )
+ else:
+ raise ValueError(f"Unsupported GRPO loss_type: {loss_type}")
+ return loss, completion_lengths.sum()
def grpo_loss(
@@ -434,65 +650,29 @@ def grpo_loss(
beta: float = 0.04,
) -> Tuple[mx.array, int]:
"""
- Compute GRPO (Group Relative Policy Optimization) loss for a single prompt.
-
- GRPO:
- 1. Generates multiple completions for each prompt
- 2. Computes rewards for each completion
- 3. Uses group statistics for advantage estimation
- 4. Computes policy gradient loss
-
- Args:
- model: The policy model.
- tokenizer: The tokenizer.
- prompt_ids: Prompt token IDs.
- reward_fn: Function(completion, prompt) -> reward.
- prompt_text: Original prompt text for reward computation.
- num_generations: Number of completions to generate per prompt.
- temperature: Sampling temperature.
- max_tokens: Maximum tokens per completion.
- beta: KL penalty coefficient.
-
- Returns:
- Tuple of (loss, num_completions).
- """
- completions = []
- all_log_probs = []
-
- # Generate multiple completions
- for _ in range(num_generations):
- gen_ids, log_probs = generate_with_log_probs(
- model, tokenizer, prompt_ids,
- max_tokens=max_tokens,
- temperature=temperature,
- )
-
- # Decode completion (skip prompt)
- prompt_len = len(prompt_ids)
- completion_ids = gen_ids[prompt_len:]
- completion_text = tokenizer.decode(completion_ids.tolist())
-
- completions.append(completion_text)
- all_log_probs.append(log_probs.sum()) # Sum log probs
-
- # Compute rewards
- rewards = []
- for completion in completions:
- reward = reward_fn(completion, prompt_text)
- rewards.append(reward)
-
- rewards = mx.array(rewards)
- log_probs_tensor = mx.stack(all_log_probs)
-
- # Compute advantages using group statistics
- mean_reward = mx.mean(rewards)
- std_reward = mx.std(rewards) + 1e-8
- advantages = (rewards - mean_reward) / std_reward
-
- # Policy gradient loss: -E[advantage * log_prob]
- # We want to increase prob of high-advantage completions
- pg_loss = -mx.mean(advantages * log_probs_tensor)
-
+ Legacy log-only GRPO loss retained for compatibility.
+ """
+ del beta
+ rollout_batch = collect_rollouts(
+ policy=model,
+ tokenizer=tokenizer,
+ prompt_samples=[
+ {
+ "sample_index": 0,
+ "prompt": prompt_text,
+ "prompt_ids": prompt_ids.tolist() if hasattr(prompt_ids, "tolist") else list(prompt_ids),
+ "reward_context": prompt_text,
+ }
+ ],
+ sampling_config={
+ "num_generations": num_generations,
+ "temperature": temperature,
+ "max_tokens": max_tokens,
+ },
+ )
+ reward_batch = evaluate_rewards(rollout_batch, reward_fn)
+ advantages = compute_advantages(reward_batch)
+ pg_loss = -mx.mean(advantages * rollout_batch.rollout_logprobs)
return pg_loss, num_generations
@@ -507,47 +687,33 @@ def grpo_batch_loss(
beta: float = 0.04,
) -> Tuple[mx.array, int]:
"""
- Compute GRPO loss for a batch of prompts.
-
- Args:
- model: The policy model.
- tokenizer: The tokenizer.
- prompts: List of prompt strings.
- reward_fn: Reward function.
- num_generations: Completions per prompt.
- temperature: Sampling temperature.
- max_tokens: Max tokens per completion.
- beta: KL coefficient.
-
- Returns:
- Tuple of (average_loss, total_completions).
- """
- losses = []
- total_completions = 0
-
- for prompt in prompts:
- prompt_ids = mx.array(tokenizer.encode(prompt))
-
- loss, n_comp = grpo_loss(
- model=model,
- tokenizer=tokenizer,
- prompt_ids=prompt_ids,
- reward_fn=reward_fn,
- prompt_text=prompt,
- num_generations=num_generations,
- temperature=temperature,
- max_tokens=max_tokens,
- beta=beta,
- )
-
- losses.append(loss)
- total_completions += n_comp
-
- avg_loss = mx.mean(mx.stack(losses))
- return avg_loss, total_completions
-
+ Legacy batched GRPO loss retained for compatibility.
+ """
+ del beta
+ prompt_samples = [
+ {
+ "sample_index": index,
+ "prompt": prompt,
+ "prompt_ids": tokenizer.encode(prompt),
+ "reward_context": prompt,
+ }
+ for index, prompt in enumerate(prompts)
+ ]
+ rollout_batch = collect_rollouts(
+ policy=model,
+ tokenizer=tokenizer,
+ prompt_samples=prompt_samples,
+ sampling_config={
+ "num_generations": num_generations,
+ "temperature": temperature,
+ "max_tokens": max_tokens,
+ },
+ )
+ reward_batch = evaluate_rewards(rollout_batch, reward_fn)
+ advantages = compute_advantages(reward_batch)
+ pg_loss = -mx.mean(advantages * rollout_batch.rollout_logprobs)
+ return pg_loss, len(prompt_samples) * num_generations
-# Utility function for batched DPO
def compute_reference_logprobs(
model: Any,
@@ -557,22 +723,12 @@ def compute_reference_logprobs(
rejected_lengths: mx.array,
) -> Tuple[mx.array, mx.array]:
"""
- Compute reference log probabilities (for frozen reference model).
-
- Call this once before training to get reference logprobs,
- then pass them to dpo_loss to avoid recomputation.
-
- Args:
- model: The reference model (should be frozen/not updated).
- chosen_ids: Chosen sequence token IDs.
- rejected_ids: Rejected sequence token IDs.
- chosen_lengths: Chosen sequence lengths.
- rejected_lengths: Rejected sequence lengths.
-
- Returns:
- Tuple of (ref_chosen_logprobs, ref_rejected_logprobs).
+ Backwards-compatible alias for batched DPO reference precompute.
"""
- ref_chosen = compute_log_probs_with_lengths(model, chosen_ids, chosen_lengths)
- ref_rejected = compute_log_probs_with_lengths(model, rejected_ids, rejected_lengths)
-
- return mx.stop_gradient(ref_chosen), mx.stop_gradient(ref_rejected)
+ return precompute_preference_reference_logprobs(
+ model,
+ chosen_ids,
+ rejected_ids,
+ chosen_lengths,
+ rejected_lengths,
+ )
diff --git a/mlx_tune/model.py b/mlx_tune/model.py
index 80026ca..15f0812 100644
--- a/mlx_tune/model.py
+++ b/mlx_tune/model.py
@@ -5,9 +5,14 @@
using Apple's MLX framework under the hood.
"""
-from typing import Optional, Tuple, Union, List, Any, Dict
+from dataclasses import dataclass
+import json
+from typing import Optional, Tuple, Union, List, Any, Dict, Mapping, Sequence
from pathlib import Path
+import copy
import mlx.core as mx
+import mlx.nn as nn
+from mlx.utils import tree_flatten, tree_unflatten
from mlx_lm import load as mlx_load
import warnings
@@ -128,7 +133,13 @@ def from_pretrained(
try:
# Load model using MLX (with config for saving later)
- model, tokenizer, config = mlx_load(model_name, return_config=True, **mlx_kwargs)
+ try:
+ model, tokenizer, config = mlx_load(model_name, return_config=True, **mlx_kwargs)
+ except TypeError as exc:
+ if "return_config" not in str(exc):
+ raise
+ model, tokenizer = mlx_load(model_name, **mlx_kwargs)
+ config = None
# Wrap model with our compatibility layer
wrapped_model = MLXModelWrapper(
@@ -489,6 +500,12 @@ def set_adapter_path(self, path: str) -> None:
"""
self._adapter_path = Path(path)
+ def has_adapters(self) -> bool:
+ """
+ Return whether this wrapper currently tracks adapter state.
+ """
+ return bool(self._lora_applied or self._adapter_path is not None)
+
def get_adapter_path(self) -> Optional[Path]:
"""
Get the current adapter path.
@@ -498,6 +515,142 @@ def get_adapter_path(self) -> Optional[Path]:
"""
return self._adapter_path
+ def clone(
+ self,
+ freeze: bool = False,
+ snapshot_adapters: bool = True,
+ copy_adapter_path: bool = False,
+ ) -> "MLXModelWrapper":
+ """
+ Deep-clone the wrapped model and optionally freeze the clone.
+ """
+ clone = MLXModelWrapper(
+ model=copy.deepcopy(self.model),
+ tokenizer=self.tokenizer,
+ max_seq_length=self.max_seq_length,
+ model_name=self.model_name,
+ config=copy.deepcopy(self.config),
+ )
+ clone.lora_config = copy.deepcopy(self.lora_config)
+ clone.lora_enabled = self.lora_enabled
+ clone._lora_applied = self._lora_applied
+ clone._adapter_path = (
+ Path(self._adapter_path) if copy_adapter_path and self._adapter_path is not None else None
+ )
+ clone.inference_mode = self.inference_mode
+ clone.use_cache = self.use_cache
+
+ source_actual = self.model
+ clone_actual = clone.model
+ clone_actual.update(source_actual.parameters(), strict=False)
+ if snapshot_adapters and self.has_adapters():
+ clone.load_adapter_state(self.snapshot_adapter_state(), strict=False)
+ mx.eval(clone_actual.parameters())
+
+ if freeze:
+ clone.freeze_parameters()
+ return clone
+
+ def freeze_parameters(self) -> None:
+ """
+ Freeze all parameters on the wrapped model.
+ """
+ if hasattr(self.model, "freeze"):
+ self.model.freeze()
+ mx.eval(self.model.parameters())
+
+ def snapshot_adapter_state(self) -> Dict[str, mx.array]:
+ """
+ Capture the current trainable adapter parameter state as a flat tree.
+ """
+ if not self.has_adapters():
+ return {}
+ return {
+ name: mx.array(value)
+ for name, value in tree_flatten(self.model.trainable_parameters())
+ }
+
+ def load_adapter_state(
+ self,
+ adapter_state: Mapping[str, mx.array],
+ strict: bool = False,
+ ) -> None:
+ """
+ Restore adapter parameters from a flat tree.
+ """
+ if not adapter_state:
+ return
+ self.model.update(tree_unflatten(list(adapter_state.items())), strict=strict)
+ mx.eval(self.model.parameters())
+
+ def build_adapter_config(self) -> Dict[str, Any]:
+ """
+ Build the mlx_lm-compatible adapter configuration for the current LoRA setup.
+ """
+ num_layers = None
+ if hasattr(self.model, "layers"):
+ num_layers = len(self.model.layers)
+ elif hasattr(self.model, "model") and hasattr(self.model.model, "layers"):
+ num_layers = len(self.model.model.layers)
+
+ lora_config = self.lora_config.copy() if self.lora_config else {}
+ r = lora_config.get("r", 16)
+ alpha = lora_config.get("lora_alpha", 16)
+ adapter_config = {
+ "fine_tune_type": "lora",
+ "num_layers": num_layers,
+ "lora_parameters": {
+ "rank": r,
+ "scale": alpha / r,
+ "dropout": lora_config.get("lora_dropout", 0.0),
+ },
+ }
+
+ target_modules = lora_config.get("target_modules", [])
+ if target_modules:
+ short_to_full = {
+ "q_proj": "self_attn.q_proj",
+ "k_proj": "self_attn.k_proj",
+ "v_proj": "self_attn.v_proj",
+ "o_proj": "self_attn.o_proj",
+ "gate_proj": "mlp.gate_proj",
+ "up_proj": "mlp.up_proj",
+ "down_proj": "mlp.down_proj",
+ }
+ adapter_config["lora_parameters"]["keys"] = [
+ short_to_full.get(module, module) for module in target_modules
+ ]
+ return adapter_config
+
+ def save_adapter_snapshot(self, output_dir: str) -> bool:
+ """
+ Persist the current adapter state in mlx_lm's adapter directory layout.
+ """
+ if not self.has_adapters():
+ return False
+
+ output_path = Path(output_dir)
+ output_path.mkdir(parents=True, exist_ok=True)
+ mx.save_safetensors(
+ str(output_path / "adapters.safetensors"),
+ self.snapshot_adapter_state(),
+ )
+ with open(output_path / "adapter_config.json", "w") as handle:
+ json.dump(self.build_adapter_config(), handle, indent=2)
+ self._adapter_path = output_path
+ return True
+
+ def load_adapter_snapshot(self, adapter_path: str, strict: bool = False) -> None:
+ """
+ Load adapter state from a role/checkpoint directory.
+ """
+ adapter_dir = Path(adapter_path)
+ adapter_file = adapter_dir / "adapters.safetensors"
+ if not adapter_file.exists():
+ raise FileNotFoundError(f"Missing adapters.safetensors under {adapter_dir}")
+ self.load_adapter_state(mx.load(str(adapter_file)), strict=strict)
+ self._adapter_path = adapter_dir
+
def enable_inference_mode(self, use_cache: bool = True):
"""
Enable inference mode optimizations.
@@ -788,17 +941,556 @@ def save_pretrained_gguf(
**kwargs
)
- def __call__(self, *args, **kwargs):
+ def forward_with_cache(self, *args, cache=None, **kwargs):
+ """
+ Forward pass with optional cache support for autoregressive decoding.
+ """
+ if cache is None:
+ return self.model(*args, **kwargs)
+ return self.model(*args, cache=cache, **kwargs)
+
+ def __call__(self, *args, cache=None, **kwargs):
"""
Forward pass through the model.
Note: This is a simplified interface. For training, use MLX's
training utilities directly.
"""
- return self.model(*args, **kwargs)
+ return self.forward_with_cache(*args, cache=cache, **kwargs)
def __getattr__(self, name):
"""
Delegate attribute access to the underlying MLX model.
"""
return getattr(self.model, name)
+
+
+class ReferencePolicy:
+ """
+ Frozen reference-policy wrapper used by native RL trainers.
+
+ A reference policy can either wrap an explicit reference model provided by
+ the caller or snapshot the current policy into a detached, frozen model
+ instance before RL optimization starts.
+ """
+
+ def __init__(
+ self,
+ model: Any,
+ source: str,
+ metadata: Optional[Dict[str, Any]] = None,
+ ):
+ self.model = model
+ self.source = source
+ self.metadata = metadata or {}
+ self._freeze()
+
+ @classmethod
+ def from_model(
+ cls,
+ policy_model: Any,
+ ref_model: Optional[Any] = None,
+ ) -> "ReferencePolicy":
+ return build_reference_policy(policy_model, ref_model=ref_model, snapshot=True)
+
+ @staticmethod
+ def _unwrap(model: Any) -> Any:
+ return model.model if hasattr(model, "model") else model
+
+ @classmethod
+ def _snapshot_model(cls, model: Any) -> Any:
+ if isinstance(model, MLXModelWrapper):
+ snapshot = MLXModelWrapper(
+ model=copy.deepcopy(model.model),
+ tokenizer=model.tokenizer,
+ max_seq_length=model.max_seq_length,
+ model_name=model.model_name,
+ config=copy.deepcopy(model.config),
+ )
+ snapshot.lora_config = copy.deepcopy(model.lora_config)
+ snapshot.lora_enabled = model.lora_enabled
+ snapshot._lora_applied = model._lora_applied
+ snapshot._adapter_path = model._adapter_path
+ snapshot.inference_mode = model.inference_mode
+ snapshot.use_cache = model.use_cache
+ else:
+ snapshot = copy.deepcopy(model)
+
+ source_actual = cls._unwrap(model)
+ snapshot_actual = cls._unwrap(snapshot)
+ snapshot_actual.update(source_actual.parameters())
+ mx.eval(snapshot_actual.parameters())
+ return snapshot
+
+ def _freeze(self) -> None:
+ actual_model = self._unwrap(self.model)
+ if hasattr(actual_model, "freeze"):
+ actual_model.freeze()
+ mx.eval(actual_model.parameters())
+
+
+def _actual_model(model: Any) -> Any:
+ return model.model if hasattr(model, "model") else model
+
+
+def _clone_role_model(model: Any, freeze: bool = False) -> Any:
+ if isinstance(model, MLXModelWrapper):
+ return model.clone(freeze=freeze, snapshot_adapters=True, copy_adapter_path=False)
+
+ snapshot = copy.deepcopy(model)
+ source_actual = _actual_model(model)
+ snapshot_actual = _actual_model(snapshot)
+ if hasattr(snapshot_actual, "update"):
+ snapshot_actual.update(source_actual.parameters(), strict=False)
+ mx.eval(snapshot_actual.parameters())
+ if freeze and hasattr(snapshot_actual, "freeze"):
+ snapshot_actual.freeze()
+ mx.eval(snapshot_actual.parameters())
+ if hasattr(snapshot, "_adapter_path"):
+ snapshot._adapter_path = None
+ return snapshot
+
+
+def build_reference_policy(
+ policy_model: Any,
+ ref_model: Optional[Any] = None,
+ snapshot: bool = True,
+) -> ReferencePolicy:
+ """
+ Build an explicit frozen reference policy for RL training.
+ """
+ source_model = ref_model if ref_model is not None else policy_model
+ if snapshot:
+ role_model = _clone_role_model(source_model, freeze=True)
+ source = "explicit_snapshot" if ref_model is not None else "policy_snapshot"
+ strategy = "clone_and_freeze"
+ else:
+ role_model = source_model
+ actual_model = _actual_model(role_model)
+ if hasattr(actual_model, "freeze"):
+ actual_model.freeze()
+ mx.eval(actual_model.parameters())
+ source = "explicit_live" if ref_model is not None else "policy_live"
+ strategy = "freeze_in_place"
+
+ metadata = {
+ "source": source,
+ "snapshot_strategy": strategy,
+ "model_name": getattr(source_model, "model_name", None),
+ "adapter_path": (
+ str(source_model.get_adapter_path())
+ if hasattr(source_model, "get_adapter_path") and source_model.get_adapter_path() is not None
+ else None
+ ),
+ }
+ return ReferencePolicy(role_model, source=source, metadata=metadata)
+
+
+def _infer_hidden_size(module: Any) -> int:
+ if hasattr(module, "args"):
+ for attr in ("hidden_size", "dim"):
+ value = getattr(module.args, attr, None)
+ if value is not None:
+ return int(value)
+ for attr in ("hidden_size", "n_embd", "dim"):
+ value = getattr(module, attr, None)
+ if value is not None:
+ return int(value)
+ if hasattr(module, "embed_tokens") and hasattr(module.embed_tokens, "weight"):
+ return int(module.embed_tokens.weight.shape[-1])
+ if hasattr(module, "embedding") and hasattr(module.embedding, "weight"):
+ return int(module.embedding.weight.shape[-1])
+ raise ValueError("Could not infer backbone hidden size for scalar head construction.")
+
+
+def _resolve_hidden_backbone(model: Any) -> Any:
+ actual_model = _actual_model(model)
+ hidden_backbone = getattr(actual_model, "model", None)
+ if callable(hidden_backbone):
+ return hidden_backbone
+ if callable(actual_model):
+ return actual_model
+ raise ValueError("Scalar-head roles require a callable causal-LM backbone.")
+
+
+def _pad_sequence_batch(
+ sequences: Sequence[Sequence[int]],
+ pad_id: int,
+) -> Tuple[mx.array, mx.array]:
+ max_length = max(len(sequence) for sequence in sequences)
+ padded = [list(sequence) + [pad_id] * (max_length - len(sequence)) for sequence in sequences]
+ lengths = [len(sequence) for sequence in sequences]
+ return mx.array(padded), mx.array(lengths)
+
+
+def _flat_parameter_state(model: Any) -> Dict[str, mx.array]:
+ actual_model = _actual_model(model)
+ return {name: mx.array(value) for name, value in tree_flatten(actual_model.parameters())}
+
+
+def _load_flat_parameter_state(
+ model: Any,
+ parameter_state: Mapping[str, mx.array],
+ strict: bool = False,
+) -> None:
+ if not parameter_state:
+ return
+ actual_model = _actual_model(model)
+ actual_model.update(tree_unflatten(list(parameter_state.items())), strict=strict)
+ mx.eval(actual_model.parameters())
+
+
+@dataclass
+class RLModelRoles:
+ policy_model: Any
+ reference_policy: ReferencePolicy
+ reward_model: Optional["RewardModel"] = None
+ value_model: Optional["ValueModel"] = None
+
+ def as_dict(self) -> Dict[str, Any]:
+ return {
+ "policy": self.policy_model,
+ "reference": self.reference_policy,
+ "reward_model": self.reward_model,
+ "value_model": self.value_model,
+ }
+
+
+class ScalarHeadModel:
+ role_name = "scalar_model"
+
+ def __init__(
+ self,
+ base_model: Any,
+ pooling: str = "last_token",
+ target: str = "completion",
+ head: Optional[nn.Linear] = None,
+ head_config: Optional[Dict[str, Any]] = None,
+ ):
+ self.base_model = base_model
+ self.tokenizer = getattr(base_model, "tokenizer", None)
+ self.pooling = pooling
+ self.target = target
+ backbone_module = _resolve_hidden_backbone(base_model)
+ self._hidden_backbone = backbone_module
+ hidden_size = _infer_hidden_size(backbone_module)
+ self.head = head or nn.Linear(hidden_size, 1)
+ self.head_config = {
+ "role": self.role_name,
+ "pooling": pooling,
+ "target": target,
+ "hidden_size": hidden_size,
+ }
+ if head_config:
+ self.head_config.update(head_config)
+ self.pooling = self.head_config.get("pooling", self.pooling)
+ self.target = self.head_config.get("target", self.target)
+
+ @classmethod
+ def from_pretrained(cls, base_model: Any, output_dir: str) -> "ScalarHeadModel":
+ output_path = Path(output_dir)
+ with open(output_path / "head_config.json") as handle:
+ head_config = json.load(handle)
+ instance = cls(
+ base_model=base_model,
+ pooling=head_config.get("pooling", "last_token"),
+ target=head_config.get("target", "completion"),
+ head_config=head_config,
+ )
+ instance.load_pretrained(output_dir)
+ return instance
+
+ def _normalize_batch_inputs(
+ self,
+ input_ids: Union[mx.array, Sequence[Sequence[int]]],
+ sequence_lengths: Optional[Union[mx.array, Sequence[int]]] = None,
+ prompt_lengths: Optional[Union[mx.array, Sequence[int]]] = None,
+ completion_lengths: Optional[Union[mx.array, Sequence[int]]] = None,
+ ) -> Tuple[mx.array, mx.array, Optional[mx.array], Optional[mx.array]]:
+ if hasattr(input_ids, "shape"):
+ array_input_ids = input_ids
+ if sequence_lengths is None:
+ raise ValueError("sequence_lengths is required when input_ids is already padded.")
+ array_sequence_lengths = (
+ sequence_lengths if hasattr(sequence_lengths, "shape") else mx.array(sequence_lengths)
+ )
+ else:
+ sequences = [list(sequence) for sequence in input_ids]
+ pad_id = int(getattr(self.tokenizer, "pad_token_id", 0) or 0)
+ array_input_ids, array_sequence_lengths = _pad_sequence_batch(sequences, pad_id)
+
+ prompt_lengths_array = None
+ completion_lengths_array = None
+ if prompt_lengths is not None:
+ prompt_lengths_array = prompt_lengths if hasattr(prompt_lengths, "shape") else mx.array(prompt_lengths)
+ if completion_lengths is not None:
+ completion_lengths_array = (
+ completion_lengths
+ if hasattr(completion_lengths, "shape")
+ else mx.array(completion_lengths)
+ )
+ return array_input_ids, array_sequence_lengths, prompt_lengths_array, completion_lengths_array
+
+ def _hidden_states(self, input_ids: mx.array) -> mx.array:
+ hidden_states = self._hidden_backbone(input_ids)
+ if isinstance(hidden_states, tuple):
+ hidden_states = hidden_states[0]
+ if hasattr(hidden_states, "last_hidden_state"):
+ hidden_states = hidden_states.last_hidden_state
+ return hidden_states
+
+ def _last_indices(
+ self,
+ sequence_lengths: mx.array,
+ prompt_lengths: Optional[mx.array],
+ completion_lengths: Optional[mx.array],
+ ) -> mx.array:
+ if self.target == "completion":
+ if prompt_lengths is None or completion_lengths is None:
+ raise ValueError("Completion scalar scoring requires prompt_lengths and completion_lengths.")
+ completion_last = prompt_lengths + completion_lengths - 1
+ prompt_last = mx.maximum(prompt_lengths - 1, 0)
+ has_completion = completion_lengths > 0
+ return mx.where(has_completion, completion_last, prompt_last)
+ return mx.maximum(sequence_lengths - 1, 0)
+
+ def _token_mask(
+ self,
+ width: int,
+ sequence_lengths: mx.array,
+ prompt_lengths: Optional[mx.array],
+ completion_lengths: Optional[mx.array],
+ ) -> mx.array:
+ positions = mx.arange(width)[None, :]
+ if self.pooling == "mean_sequence":
+ return positions < sequence_lengths[:, None]
+ if self.pooling == "mean_completion":
+ if prompt_lengths is None or completion_lengths is None:
+ raise ValueError("Completion pooling requires prompt_lengths and completion_lengths.")
+ start = prompt_lengths[:, None]
+ end = (prompt_lengths + completion_lengths)[:, None]
+ mask = (positions >= start) & (positions < end)
+ has_completion = completion_lengths[:, None] > 0
+ fallback = positions == mx.maximum(prompt_lengths - 1, 0)[:, None]
+ return mx.where(has_completion, mask, fallback)
+ raise ValueError(f"Unsupported scalar pooling mode: {self.pooling}")
+
+ def _pool_hidden_states(
+ self,
+ hidden_states: mx.array,
+ sequence_lengths: mx.array,
+ prompt_lengths: Optional[mx.array],
+ completion_lengths: Optional[mx.array],
+ ) -> mx.array:
+ if self.pooling == "last_token":
+ indices = self._last_indices(sequence_lengths, prompt_lengths, completion_lengths)
+ batch_indices = mx.arange(hidden_states.shape[0])
+ return hidden_states[batch_indices, indices]
+
+ token_mask = self._token_mask(
+ hidden_states.shape[1],
+ sequence_lengths,
+ prompt_lengths,
+ completion_lengths,
+ )
+ weights = token_mask.astype(hidden_states.dtype)[:, :, None]
+ totals = weights.sum(axis=1)
+ totals = mx.maximum(totals, 1.0)
+ return (hidden_states * weights).sum(axis=1) / totals
+
+ def score(
+ self,
+ input_ids: Union[mx.array, Sequence[Sequence[int]]],
+ sequence_lengths: Optional[Union[mx.array, Sequence[int]]] = None,
+ prompt_lengths: Optional[Union[mx.array, Sequence[int]]] = None,
+ completion_lengths: Optional[Union[mx.array, Sequence[int]]] = None,
+ ) -> mx.array:
+ array_input_ids, array_sequence_lengths, prompt_lengths_array, completion_lengths_array = (
+ self._normalize_batch_inputs(
+ input_ids=input_ids,
+ sequence_lengths=sequence_lengths,
+ prompt_lengths=prompt_lengths,
+ completion_lengths=completion_lengths,
+ )
+ )
+ hidden_states = self._hidden_states(array_input_ids)
+ pooled = self._pool_hidden_states(
+ hidden_states,
+ array_sequence_lengths,
+ prompt_lengths_array,
+ completion_lengths_array,
+ )
+ return self.head(pooled).squeeze(-1)
+
+ def save_pretrained(self, output_dir: str) -> None:
+ output_path = Path(output_dir)
+ output_path.mkdir(parents=True, exist_ok=True)
+ mx.save_safetensors(
+ str(output_path / "weights.safetensors"),
+ _flat_parameter_state(self.base_model),
+ )
+ mx.save_safetensors(
+ str(output_path / "head.safetensors"),
+ dict(tree_flatten(self.head.parameters())),
+ )
+ with open(output_path / "head_config.json", "w") as handle:
+ json.dump(self.head_config, handle, indent=2)
+ with open(output_path / "role.json", "w") as handle:
+ json.dump(
+ {
+ "role": self.role_name,
+ "pooling": self.pooling,
+ "target": self.target,
+ "backbone_weight_format": "weights.safetensors",
+ "backbone_has_adapters": bool(
+ isinstance(self.base_model, MLXModelWrapper) and self.base_model.has_adapters()
+ ),
+ },
+ handle,
+ indent=2,
+ )
+ if isinstance(self.base_model, MLXModelWrapper):
+ self.base_model.save_adapter_snapshot(str(output_path))
+
+ def load_pretrained(self, output_dir: str) -> None:
+ output_path = Path(output_dir)
+ weights_path = output_path / "weights.safetensors"
+ if weights_path.exists():
+ _load_flat_parameter_state(self.base_model, mx.load(str(weights_path)), strict=False)
+ head_weights = mx.load(str(output_path / "head.safetensors"))
+ self.head.update(tree_unflatten(list(head_weights.items())), strict=False)
+ mx.eval(self.head.parameters())
+ if isinstance(self.base_model, MLXModelWrapper):
+ adapter_file = output_path / "adapters.safetensors"
+ if adapter_file.exists():
+ self.base_model.load_adapter_snapshot(str(output_path), strict=False)
+ config_path = output_path / "head_config.json"
+ if config_path.exists():
+ with open(config_path) as handle:
+ self.head_config = json.load(handle)
+ self.pooling = self.head_config.get("pooling", self.pooling)
+ self.target = self.head_config.get("target", self.target)
+
+
+class RewardModel(ScalarHeadModel):
+ role_name = "reward_model"
+
+ def score_pairs(
+ self,
+ chosen_input_ids: Union[mx.array, Sequence[Sequence[int]]],
+ rejected_input_ids: Union[mx.array, Sequence[Sequence[int]]],
+ chosen_sequence_lengths: Optional[Union[mx.array, Sequence[int]]] = None,
+ rejected_sequence_lengths: Optional[Union[mx.array, Sequence[int]]] = None,
+ chosen_prompt_lengths: Optional[Union[mx.array, Sequence[int]]] = None,
+ rejected_prompt_lengths: Optional[Union[mx.array, Sequence[int]]] = None,
+ chosen_completion_lengths: Optional[Union[mx.array, Sequence[int]]] = None,
+ rejected_completion_lengths: Optional[Union[mx.array, Sequence[int]]] = None,
+ ) -> Tuple[mx.array, mx.array]:
+ return (
+ self.score(
+ chosen_input_ids,
+ sequence_lengths=chosen_sequence_lengths,
+ prompt_lengths=chosen_prompt_lengths,
+ completion_lengths=chosen_completion_lengths,
+ ),
+ self.score(
+ rejected_input_ids,
+ sequence_lengths=rejected_sequence_lengths,
+ prompt_lengths=rejected_prompt_lengths,
+ completion_lengths=rejected_completion_lengths,
+ ),
+ )
+
+ def evaluate(self, payload: Dict[str, Any]) -> float:
+ sequence = [list(payload["prompt_ids"]) + list(payload["completion_ids"])]
+ scores = self.score(
+ sequence,
+ prompt_lengths=[len(payload["prompt_ids"])],
+ completion_lengths=[len(payload["completion_ids"])],
+ )
+ return float(scores[0].item())
+
+
+class ValueModel(ScalarHeadModel):
+ role_name = "value_model"
+
+ def predict(
+ self,
+ input_ids: Union[mx.array, Sequence[Sequence[int]]],
+ sequence_lengths: Optional[Union[mx.array, Sequence[int]]] = None,
+ prompt_lengths: Optional[Union[mx.array, Sequence[int]]] = None,
+ completion_lengths: Optional[Union[mx.array, Sequence[int]]] = None,
+ ) -> mx.array:
+ return self.score(
+ input_ids,
+ sequence_lengths=sequence_lengths,
+ prompt_lengths=prompt_lengths,
+ completion_lengths=completion_lengths,
+ )
+
+
+def build_reward_model(
+ base_model: Any,
+ pooling: str = "last_token",
+ target: str = "completion",
+ snapshot: bool = True,
+ head_config: Optional[Dict[str, Any]] = None,
+) -> RewardModel:
+ role_model = _clone_role_model(base_model, freeze=False) if snapshot else base_model
+ if snapshot and hasattr(role_model, "_adapter_path"):
+ role_model._adapter_path = None
+ return RewardModel(role_model, pooling=pooling, target=target, head_config=head_config)
+
+
+def build_value_model(
+ base_model: Any,
+ pooling: str = "last_token",
+ target: str = "completion",
+ snapshot: bool = True,
+ head_config: Optional[Dict[str, Any]] = None,
+) -> ValueModel:
+ role_model = _clone_role_model(base_model, freeze=False) if snapshot else base_model
+ if snapshot and hasattr(role_model, "_adapter_path"):
+ role_model._adapter_path = None
+ return ValueModel(role_model, pooling=pooling, target=target, head_config=head_config)
+
+
+def create_rl_model_roles(
+ policy_model: Any,
+ ref_model: Optional[Any] = None,
+ reward_model: Optional[RewardModel] = None,
+ value_model: Optional[ValueModel] = None,
+ reward_base_model: Optional[Any] = None,
+ value_base_model: Optional[Any] = None,
+ reference_snapshot: bool = True,
+ reward_pooling: str = "last_token",
+ reward_target: str = "completion",
+ value_pooling: str = "last_token",
+ value_target: str = "completion",
+) -> RLModelRoles:
+ resolved_reward_model = reward_model
+ if resolved_reward_model is None and reward_base_model is not None:
+ resolved_reward_model = build_reward_model(
+ reward_base_model,
+ pooling=reward_pooling,
+ target=reward_target,
+ )
+
+ resolved_value_model = value_model
+ if resolved_value_model is None and value_base_model is not None:
+ resolved_value_model = build_value_model(
+ value_base_model,
+ pooling=value_pooling,
+ target=value_target,
+ )
+
+ return RLModelRoles(
+ policy_model=policy_model,
+ reference_policy=build_reference_policy(
+ policy_model,
+ ref_model=ref_model,
+ snapshot=reference_snapshot,
+ ),
+ reward_model=resolved_reward_model,
+ value_model=resolved_value_model,
+ )
diff --git a/mlx_tune/rl_api.py b/mlx_tune/rl_api.py
new file mode 100644
index 0000000..2f1a0ba
--- /dev/null
+++ b/mlx_tune/rl_api.py
@@ -0,0 +1,878 @@
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+import inspect
+import json
+from pathlib import Path
+from typing import Any, Dict, Iterable, Iterator, List, Mapping, Optional, Sequence
+import warnings
+
+import mlx.core as mx
+from mlx.utils import tree_unflatten
+
+from mlx_tune.chat_templates import apply_chat_template_to_sample
+from mlx_tune.model import build_reference_policy, build_reward_model
+
+
+MANIFEST_FILE = "manifest.json"
+STATE_FILE = "trainer_state.safetensors"
+METADATA_FILE = "trainer_state.json"
+REFERENCE_FILE = "reference_model.safetensors"
+REFERENCE_METADATA_FILE = "reference_metadata.json"
+CHECKPOINT_FORMAT_NAME = "mlx_tune_rl_checkpoint"
+CHECKPOINT_FORMAT_VERSION = 4
+SUPPORTED_RL_DATASET_MODES = (
+ "prompt",
+ "preference",
+ "reward_scalar",
+ "reward_pairwise",
+ "rollout",
+ "chat",
+)
+
+
+@dataclass
+class PreparedRLDataset:
+ samples: List[Dict[str, Any]]
+ mode: str
+ adapter_name: str
+ metadata: Dict[str, Any] = field(default_factory=dict)
+
+ def __iter__(self) -> Iterator[Dict[str, Any]]:
+ return iter(self.samples)
+
+ def __len__(self) -> int:
+ return len(self.samples)
+
+ def __getitem__(self, index: int) -> Dict[str, Any]:
+ return self.samples[index]
+
+
+@dataclass
+class RLRoleState:
+ role: str
+ weight_format: Any
+ parameter_state: Dict[str, mx.array] = field(default_factory=dict)
+ head_state: Dict[str, mx.array] = field(default_factory=dict)
+ adapter_state: Dict[str, mx.array] = field(default_factory=dict)
+ head_config: Dict[str, Any] = field(default_factory=dict)
+ metadata: Dict[str, Any] = field(default_factory=dict)
+
+
+@dataclass
+class RLCheckpointBundle:
+ manifest: Dict[str, Any]
+ algorithm: str
+ restored_roles: Dict[str, RLRoleState]
+ optimizer_state_trees: Dict[str, Dict[str, Any]]
+ scheduler_metadata: Dict[str, Dict[str, Any]]
+ trainer_state: Dict[str, Any]
+ rng_state: Dict[str, mx.array]
+ runtime_cache: Dict[str, mx.array]
+ metrics_history: List[Dict[str, Any]]
+ source_format: str
+
+
+def _read_json(path: Path, default: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
+ if not path.exists():
+ return {} if default is None else default
+ with open(path) as handle:
+ return json.load(handle)
+
+
+def _load_jsonl(path: Path) -> List[Dict[str, Any]]:
+ if not path.exists():
+ return []
+ rows: List[Dict[str, Any]] = []
+ with open(path) as handle:
+ for line in handle:
+ line = line.strip()
+ if line:
+ rows.append(json.loads(line))
+ return rows
+
+
+def _extract_prefixed_tree(prefix: str, flat_state: Dict[str, mx.array]) -> Dict[str, Any]:
+ prefix_with_dot = f"{prefix}."
+ items = [(key[len(prefix_with_dot):], value) for key, value in flat_state.items() if key.startswith(prefix_with_dot)]
+ return tree_unflatten(items) if items else {}
+
+
+def _role_dir(checkpoint_dir: Path, role_name: str) -> Path:
+ return checkpoint_dir / role_name
+
+
+def _legacy_checkpoint_exists(checkpoint_dir: Path) -> bool:
+ return (
+ (checkpoint_dir / "adapters" / "adapters.safetensors").exists()
+ or (checkpoint_dir / STATE_FILE).exists()
+ or (checkpoint_dir / REFERENCE_FILE).exists()
+ )
+
+
+def _metrics_path(checkpoint_dir: Path) -> Path:
+ return checkpoint_dir / "metrics" / "history.jsonl"
+
+
+def _runtime_cache_path(checkpoint_dir: Path) -> Path:
+ return checkpoint_dir / "runtime" / "cache.safetensors"
+
+
+def _trainer_state_path(checkpoint_dir: Path) -> Path:
+ return checkpoint_dir / "trainer" / "state.json"
+
+
+def _rng_path(checkpoint_dir: Path) -> Path:
+ return checkpoint_dir / "trainer" / "rng.safetensors"
+
+
+def _scheduler_path(checkpoint_dir: Path, name: Optional[str] = None) -> Path:
+ if name is None:
+ return checkpoint_dir / "scheduler" / "state.json"
+ return checkpoint_dir / "schedulers" / name / "state.json"
+
+
+def _optimizer_path(checkpoint_dir: Path, name: Optional[str] = None) -> Path:
+ if name is None:
+ return checkpoint_dir / "optimizer" / "state.safetensors"
+ return checkpoint_dir / "optimizers" / name / "state.safetensors"
+
+
+def _load_role_state(checkpoint_dir: Path, role_name: str, weight_format: Any) -> RLRoleState:
+ role_dir = _role_dir(checkpoint_dir, role_name)
+ metadata = _read_json(role_dir / "metadata.json")
+ head_config = _read_json(role_dir / "head_config.json")
+ parameter_state: Dict[str, mx.array] = {}
+ head_state: Dict[str, mx.array] = {}
+ adapter_state: Dict[str, mx.array] = {}
+
+ if isinstance(weight_format, str):
+ weight_path = role_dir / weight_format
+ if weight_path.exists():
+ parameter_state = dict(mx.load(str(weight_path)))
+ elif isinstance(weight_format, Mapping):
+ backbone_path = role_dir / str(weight_format.get("backbone", "weights.safetensors"))
+ head_path = role_dir / str(weight_format.get("head", "head.safetensors"))
+ adapters_name = weight_format.get("adapters")
+ if backbone_path.exists():
+ parameter_state = dict(mx.load(str(backbone_path)))
+ if head_path.exists():
+ head_state = dict(mx.load(str(head_path)))
+ if adapters_name:
+ adapter_path = role_dir / str(adapters_name)
+ if adapter_path.exists():
+ adapter_state = dict(mx.load(str(adapter_path)))
+
+ return RLRoleState(
+ role=role_name,
+ weight_format=weight_format,
+ parameter_state=parameter_state,
+ head_state=head_state,
+ adapter_state=adapter_state,
+ head_config=head_config,
+ metadata=metadata,
+ )
+
+
+def _build_manifest_bundle(checkpoint_dir: Path) -> RLCheckpointBundle:
+ manifest = _read_json(checkpoint_dir / MANIFEST_FILE)
+ if not manifest:
+ raise FileNotFoundError(f"Checkpoint manifest not found under {checkpoint_dir}")
+
+ roles = {
+ role_name: _load_role_state(
+ checkpoint_dir,
+ role_name,
+ manifest.get("role_weight_formats", {}).get(role_name, "weights.safetensors"),
+ )
+ for role_name in manifest.get("roles_present", [])
+ }
+
+ trainer_locations = manifest.get("trainer_state_locations", {})
+ optimizer_state_trees: Dict[str, Dict[str, Any]] = {}
+ optimizer_locations = trainer_locations.get("optimizers") or {}
+ if optimizer_locations:
+ for name, relative_path in optimizer_locations.items():
+ path = checkpoint_dir / relative_path
+ if path.exists():
+ optimizer_state_trees[name] = _extract_prefixed_tree("optimizer", dict(mx.load(str(path))))
+ else:
+ legacy_optimizer_path = _optimizer_path(checkpoint_dir)
+ if legacy_optimizer_path.exists():
+ optimizer_state_trees["default"] = _extract_prefixed_tree(
+ "optimizer",
+ dict(mx.load(str(legacy_optimizer_path))),
+ )
+
+ scheduler_metadata: Dict[str, Dict[str, Any]] = {}
+ scheduler_locations = trainer_locations.get("schedulers") or {}
+ if scheduler_locations:
+ for name, relative_path in scheduler_locations.items():
+ scheduler_metadata[name] = _read_json(checkpoint_dir / relative_path)
+ else:
+ default_scheduler = _read_json(_scheduler_path(checkpoint_dir))
+ if default_scheduler:
+ scheduler_metadata["default"] = default_scheduler
+
+ rng_state = dict(mx.load(str(_rng_path(checkpoint_dir)))) if _rng_path(checkpoint_dir).exists() else {}
+ runtime_cache = (
+ dict(mx.load(str(_runtime_cache_path(checkpoint_dir))))
+ if _runtime_cache_path(checkpoint_dir).exists()
+ else {}
+ )
+ trainer_state = _read_json(_trainer_state_path(checkpoint_dir))
+ metrics_history = _load_jsonl(_metrics_path(checkpoint_dir))
+
+ return RLCheckpointBundle(
+ manifest=manifest,
+ algorithm=manifest.get("algorithm", trainer_state.get("algorithm", "rl")),
+ restored_roles=roles,
+ optimizer_state_trees=optimizer_state_trees,
+ scheduler_metadata=scheduler_metadata,
+ trainer_state=trainer_state,
+ rng_state=rng_state,
+ runtime_cache=runtime_cache,
+ metrics_history=metrics_history,
+ source_format="manifest",
+ )
+
+
+def _build_legacy_manifest(metadata: Dict[str, Any], has_reference: bool) -> Dict[str, Any]:
+ roles_present = ["policy"]
+ role_weight_formats: Dict[str, Any] = {"policy": "adapters.safetensors"}
+ if has_reference:
+ roles_present.append("reference")
+ role_weight_formats["reference"] = "weights.safetensors"
+ return {
+ "format_name": CHECKPOINT_FORMAT_NAME,
+ "format_version": CHECKPOINT_FORMAT_VERSION,
+ "algorithm": metadata.get("algorithm", "rl"),
+ "roles_present": roles_present,
+ "role_weight_formats": role_weight_formats,
+ "trainer_state_locations": {
+ "optimizer": "trainer_state.safetensors",
+ "rng": "trainer_state.safetensors",
+ "runtime_cache": "trainer_state.safetensors",
+ "trainer": "trainer_state.json",
+ },
+ "metrics_path": None,
+ }
+
+
+def _build_legacy_bundle(checkpoint_dir: Path) -> RLCheckpointBundle:
+ state_path = checkpoint_dir / STATE_FILE
+ metadata_path = checkpoint_dir / METADATA_FILE
+ if not state_path.exists() or not metadata_path.exists():
+ raise FileNotFoundError(f"Checkpoint state not found under {checkpoint_dir}")
+
+ metadata = _read_json(metadata_path)
+ flat_state = dict(mx.load(str(state_path)))
+ policy_state: Dict[str, mx.array] = {}
+ adapter_path = checkpoint_dir / "adapters" / "adapters.safetensors"
+ if adapter_path.exists():
+ policy_state = dict(mx.load(str(adapter_path)))
+
+ roles = {
+ "policy": RLRoleState(
+ role="policy",
+ weight_format="adapters.safetensors",
+ parameter_state=policy_state,
+ ),
+ }
+ reference_path = checkpoint_dir / REFERENCE_FILE
+ if reference_path.exists():
+ roles["reference"] = RLRoleState(
+ role="reference",
+ weight_format="weights.safetensors",
+ parameter_state=dict(mx.load(str(reference_path))),
+ metadata=_read_json(checkpoint_dir / REFERENCE_METADATA_FILE),
+ )
+
+ runtime_cache = {
+ key: value
+ for key, value in flat_state.items()
+ if not key.startswith("optimizer.") and not key.startswith("rng.")
+ }
+ optimizer_state = _extract_prefixed_tree("optimizer", flat_state)
+ rng_state = {key: value for key, value in flat_state.items() if key.startswith("rng.")}
+
+ return RLCheckpointBundle(
+ manifest=_build_legacy_manifest(metadata, has_reference="reference" in roles),
+ algorithm=metadata.get("algorithm", "rl"),
+ restored_roles=roles,
+ optimizer_state_trees={"default": optimizer_state} if optimizer_state else {},
+ scheduler_metadata={},
+ trainer_state=metadata,
+ rng_state=rng_state,
+ runtime_cache=runtime_cache,
+ metrics_history=[],
+ source_format="legacy",
+ )
+
+
+def resume_from_checkpoint(checkpoint_dir: str | Path) -> RLCheckpointBundle:
+ checkpoint_path = Path(checkpoint_dir)
+ if (checkpoint_path / MANIFEST_FILE).exists():
+ return _build_manifest_bundle(checkpoint_path)
+ if _legacy_checkpoint_exists(checkpoint_path):
+ return _build_legacy_bundle(checkpoint_path)
+ raise FileNotFoundError(f"Checkpoint state not found under {checkpoint_path}")
+
+
+def _tokenizer_encode(tokenizer: Any, text: str, add_special_tokens: bool) -> List[int]:
+ try:
+ return list(tokenizer.encode(text, add_special_tokens=add_special_tokens))
+ except TypeError:
+ return list(tokenizer.encode(text))
+
+
+def _is_message_sequence(value: Any) -> bool:
+ return isinstance(value, list) and all(isinstance(item, Mapping) and "role" in item for item in value)
+
+
+def _render_messages(
+ messages: Sequence[Mapping[str, Any]],
+ tokenizer: Any = None,
+ chat_template: Optional[Any] = None,
+ add_generation_prompt: bool = False,
+) -> str:
+ payload = {"messages": list(messages)}
+ if callable(chat_template):
+ return chat_template(payload, add_generation_prompt=add_generation_prompt)
+ if isinstance(chat_template, str):
+ if add_generation_prompt:
+ return chat_template.format(messages=list(messages), add_generation_prompt=True)
+ return chat_template.format(messages=list(messages))
+ if tokenizer is not None:
+ return apply_chat_template_to_sample(
+ payload,
+ tokenizer,
+ add_generation_prompt=add_generation_prompt,
+ )
+ return "\n".join(f"{message.get('role', 'user')}: {message.get('content', '')}" for message in messages)
+
+
+def _last_assistant_index(messages: Sequence[Mapping[str, Any]]) -> Optional[int]:
+ for index in range(len(messages) - 1, -1, -1):
+ if messages[index].get("role") == "assistant":
+ return index
+ return None
+
+
+def _extract_chat_prompt_response(
+ sample: Mapping[str, Any],
+ tokenizer: Any = None,
+ chat_template: Optional[Any] = None,
+) -> tuple[str, str]:
+ messages = list(sample.get("messages") or [])
+ if not messages:
+ raise ValueError("Chat adaptation requires a messages field.")
+ assistant_index = _last_assistant_index(messages)
+ if assistant_index is None:
+ prompt_messages = messages
+ response = ""
+ else:
+ prompt_messages = messages[:assistant_index]
+ response = str(messages[assistant_index].get("content", ""))
+ prompt = _render_messages(
+ prompt_messages,
+ tokenizer=tokenizer,
+ chat_template=chat_template,
+ add_generation_prompt=True,
+ )
+ return prompt, response
+
+
+def _extract_preference_value(
+ value: Any,
+ prompt: str,
+ tokenizer: Any = None,
+ chat_template: Optional[Any] = None,
+) -> tuple[str, str]:
+ if isinstance(value, str):
+ return prompt, value
+ if isinstance(value, Mapping) and _is_message_sequence(value.get("messages")):
+ prompt_text, response = _extract_chat_prompt_response(
+ value,
+ tokenizer=tokenizer,
+ chat_template=chat_template,
+ )
+ return prompt_text or prompt, response
+ if _is_message_sequence(value):
+ prompt_text, response = _extract_chat_prompt_response(
+ {"messages": value},
+ tokenizer=tokenizer,
+ chat_template=chat_template,
+ )
+ return prompt_text or prompt, response
+ raise ValueError("Unsupported preference value; expected string or chat message sequence.")
+
+
+def _normalize_prompt_sample(
+ sample: Mapping[str, Any],
+ tokenizer: Any = None,
+ chat_template: Optional[Any] = None,
+) -> Dict[str, Any]:
+ if _is_message_sequence(sample.get("messages")):
+ prompt, response = _extract_chat_prompt_response(
+ sample,
+ tokenizer=tokenizer,
+ chat_template=chat_template,
+ )
+ reward_context = sample.get("reward_context", sample.get("answer", sample.get("response", response or prompt)))
+ return {
+ "prompt": prompt,
+ "reward_context": reward_context,
+ "source_messages": list(sample.get("messages") or []),
+ }
+
+ prompt = str(sample.get("prompt", sample.get("question", "")))
+ if not prompt:
+ raise ValueError("Prompt samples require a prompt or question field.")
+ return {
+ "prompt": prompt,
+ "reward_context": sample.get(
+ "reward_context",
+ sample.get("answer", sample.get("response", prompt)),
+ ),
+ }
+
+
+def _normalize_preference_sample(
+ sample: Mapping[str, Any],
+ tokenizer: Any = None,
+ chat_template: Optional[Any] = None,
+ target_mode: str = "preference",
+) -> Dict[str, Any]:
+ prompt = str(sample.get("prompt", sample.get("question", "")))
+ if _is_message_sequence(sample.get("messages")):
+ prompt = _render_messages(
+ sample.get("messages") or [],
+ tokenizer=tokenizer,
+ chat_template=chat_template,
+ add_generation_prompt=True,
+ )
+ chosen_prompt, chosen = _extract_preference_value(
+ sample.get("chosen"),
+ prompt,
+ tokenizer=tokenizer,
+ chat_template=chat_template,
+ )
+ rejected_prompt, rejected = _extract_preference_value(
+ sample.get("rejected"),
+ chosen_prompt,
+ tokenizer=tokenizer,
+ chat_template=chat_template,
+ )
+ return {
+ "type": "pairwise" if target_mode == "reward_pairwise" else "preference",
+ "prompt": chosen_prompt or rejected_prompt,
+ "chosen": chosen,
+ "rejected": rejected,
+ }
+
+
+def _normalize_reward_scalar_sample(
+ sample: Mapping[str, Any],
+ tokenizer: Any = None,
+ chat_template: Optional[Any] = None,
+) -> Dict[str, Any]:
+ if _is_message_sequence(sample.get("messages")):
+ prompt, response = _extract_chat_prompt_response(
+ sample,
+ tokenizer=tokenizer,
+ chat_template=chat_template,
+ )
+ return {
+ "type": "scalar",
+ "prompt": prompt,
+ "response": response,
+ "score": float(sample.get("score", sample.get("reward", 0.0))),
+ }
+
+ if "text" in sample:
+ return {
+ "type": "scalar",
+ "prompt": "",
+ "response": str(sample.get("text", "")),
+ "score": float(sample.get("score", sample.get("reward", 0.0))),
+ }
+
+ prompt = str(sample.get("prompt", sample.get("question", "")))
+ response = str(sample.get("response", sample.get("completion", sample.get("assistant", ""))))
+ if not response:
+ raise ValueError("Scalar reward samples require a response, completion, assistant, or text field.")
+ return {
+ "type": "scalar",
+ "prompt": prompt,
+ "response": response,
+ "score": float(sample.get("score", sample.get("reward", 0.0))),
+ }
+
+
+def _normalize_rollout_sample(
+ sample: Mapping[str, Any],
+ tokenizer: Any = None,
+ chat_template: Optional[Any] = None,
+) -> Dict[str, Any]:
+ if _is_message_sequence(sample.get("messages")):
+ prompt, response = _extract_chat_prompt_response(
+ sample,
+ tokenizer=tokenizer,
+ chat_template=chat_template,
+ )
+ completion = response or str(sample.get("completion", sample.get("response", "")))
+ else:
+ prompt = str(sample.get("prompt", sample.get("question", "")))
+ completion = str(sample.get("completion", sample.get("response", "")))
+ if not prompt or not completion:
+ raise ValueError("Rollout samples require prompt/question and completion/response data.")
+ reward = sample.get("reward", sample.get("score"))
+ return {
+ "prompt": prompt,
+ "completion": completion,
+ "reward": None if reward is None else float(reward),
+ "reward_context": sample.get(
+ "reward_context",
+ sample.get("answer", sample.get("response", completion)),
+ ),
+ }
+
+
+def _explicit_adapter_mode(mode: str) -> str:
+ if mode not in SUPPORTED_RL_DATASET_MODES:
+ raise ValueError(
+ f"Unsupported RL dataset mode '{mode}'. Supported modes: {SUPPORTED_RL_DATASET_MODES}"
+ )
+ return mode
+
+
+def _candidate_modes(sample: Mapping[str, Any]) -> List[str]:
+ if _is_message_sequence(sample.get("messages")):
+ messages = list(sample.get("messages") or [])
+ assistant_index = _last_assistant_index(messages)
+ if "chosen" in sample and "rejected" in sample:
+ if "chosen_score" in sample or "rejected_score" in sample:
+ return ["reward_pairwise"]
+ return ["preference", "reward_pairwise"]
+ if "score" in sample or "reward" in sample:
+ if assistant_index is not None:
+ return ["reward_scalar"]
+ if "completion" in sample or "response" in sample or "rewards" in sample:
+ return ["rollout"]
+ if assistant_index is None:
+ return ["prompt"]
+ return []
+
+ keys = set(sample.keys())
+ if {"chosen", "rejected"} <= keys:
+ if "chosen_score" in keys or "rejected_score" in keys:
+ return ["reward_pairwise"]
+ return ["preference", "reward_pairwise"]
+ if "text" in keys and "score" in keys:
+ return ["reward_scalar"]
+ if ("prompt" in keys or "question" in keys) and ("completion" in keys) and ("reward" in keys or "score" in keys):
+ return ["rollout"]
+ if ("prompt" in keys or "question" in keys) and ("response" in keys) and ("score" in keys):
+ return ["reward_scalar"]
+ if ("prompt" in keys or "question" in keys) and not (
+ {"chosen", "rejected", "response", "completion", "score", "reward"} & keys
+ ):
+ return ["prompt"]
+ return []
+
+
+def prepare_rl_dataset(
+ data: Iterable[Mapping[str, Any]] | PreparedRLDataset,
+ mode: Optional[str] = None,
+ tokenizer: Any = None,
+ chat_template: Optional[Any] = None,
+ metadata: Optional[Dict[str, Any]] = None,
+) -> PreparedRLDataset:
+ if isinstance(data, PreparedRLDataset):
+ if mode is None or mode == data.mode:
+ return data
+ data = data.samples
+
+ records = list(data)
+ if not records:
+ resolved_mode = _explicit_adapter_mode(mode) if mode is not None else "prompt"
+ return PreparedRLDataset(
+ samples=[],
+ mode=resolved_mode if resolved_mode != "chat" else "prompt",
+ adapter_name=resolved_mode,
+ metadata=dict(metadata or {}),
+ )
+
+ requested_mode = _explicit_adapter_mode(mode) if mode is not None else None
+ if requested_mode == "chat":
+ candidate_sets = [_candidate_modes(record) for record in records]
+ unique_modes = sorted({candidate for candidates in candidate_sets for candidate in candidates})
+ if len(unique_modes) != 1:
+ raise ValueError(
+ "Chat auto-adaptation is ambiguous. Choose one explicit RL mode from "
+ f"{SUPPORTED_RL_DATASET_MODES[:-1]}."
+ )
+ requested_mode = unique_modes[0]
+
+ if requested_mode is None:
+ candidate_sets = [_candidate_modes(record) for record in records]
+ unique_modes = {candidate for candidates in candidate_sets for candidate in candidates}
+ if not unique_modes:
+ raise ValueError(
+ "Could not auto-detect RL dataset mode. Supported modes: "
+ f"{SUPPORTED_RL_DATASET_MODES[:-1]}."
+ )
+ if len(unique_modes) != 1:
+ raise ValueError(
+ "Ambiguous RL dataset schema. Choose an explicit mode from "
+ f"{sorted(unique_modes)}."
+ )
+ requested_mode = next(iter(unique_modes))
+
+ normalized: List[Dict[str, Any]] = []
+ adapter_name = requested_mode
+ for record in records:
+ if requested_mode == "prompt":
+ sample = _normalize_prompt_sample(record, tokenizer=tokenizer, chat_template=chat_template)
+ elif requested_mode == "preference":
+ sample = _normalize_preference_sample(
+ record,
+ tokenizer=tokenizer,
+ chat_template=chat_template,
+ target_mode="preference",
+ )
+ elif requested_mode == "reward_scalar":
+ sample = _normalize_reward_scalar_sample(
+ record,
+ tokenizer=tokenizer,
+ chat_template=chat_template,
+ )
+ elif requested_mode == "reward_pairwise":
+ sample = _normalize_preference_sample(
+ record,
+ tokenizer=tokenizer,
+ chat_template=chat_template,
+ target_mode="reward_pairwise",
+ )
+ elif requested_mode == "rollout":
+ sample = _normalize_rollout_sample(
+ record,
+ tokenizer=tokenizer,
+ chat_template=chat_template,
+ )
+ else:
+ raise ValueError(f"Unsupported RL dataset mode: {requested_mode}")
+
+ if _is_message_sequence(record.get("messages")) or _is_message_sequence(record.get("chosen")):
+ adapter_name = f"chat_{requested_mode}"
+ normalized.append(sample)
+
+ dataset_metadata = dict(metadata or {})
+ dataset_metadata.update(
+ {
+ "requested_mode": mode,
+ "num_samples": len(normalized),
+ }
+ )
+ return PreparedRLDataset(
+ samples=normalized,
+ mode=requested_mode,
+ adapter_name=adapter_name,
+ metadata=dataset_metadata,
+ )
+
+
+def prepare_reward_dataset(dataset: Iterable[Mapping[str, Any]]) -> List[Dict[str, Any]]:
+ warnings.warn(
+ "prepare_reward_dataset() is deprecated; use prepare_rl_dataset(..., mode='reward_scalar' or 'reward_pairwise') instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ records = list(dataset)
+ if not records:
+ return []
+ first = records[0]
+ mode = "reward_pairwise" if {"chosen", "rejected"} <= set(first.keys()) else "reward_scalar"
+ return list(prepare_rl_dataset(records, mode=mode))
+
+
+def prepare_preference_dataset(
+ dataset: Iterable[Mapping[str, Any]],
+ tokenizer: Any = None,
+ format_type: str = "dpo",
+) -> List[Dict[str, Any]]:
+ warnings.warn(
+ "prepare_preference_dataset() is deprecated; use prepare_rl_dataset(..., mode='preference' or 'prompt') instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ mode = "prompt" if format_type == "grpo" else "preference"
+ return list(prepare_rl_dataset(dataset, mode=mode, tokenizer=tokenizer))
+
+
+def _simple_reward_builder(reward_type: str) -> Any:
+ if reward_type == "simple":
+ return lambda response, ground_truth: 1.0 if str(ground_truth).lower() in str(response).lower() else 0.0
+ if reward_type == "math":
+ def math_reward(response: str, ground_truth: str) -> float:
+ import re
+
+ numbers = re.findall(r"-?\d+\.?\d*", response)
+ target = re.findall(r"-?\d+\.?\d*", ground_truth)
+ if numbers and target:
+ try:
+ return 1.0 if float(numbers[-1]) == float(target[-1]) else 0.0
+ except Exception:
+ return 0.0
+ return 0.0
+
+ return math_reward
+ if reward_type == "length":
+ def length_reward(response: str, _: str) -> float:
+ length = len(str(response).split())
+ if length < 10:
+ return 0.2
+ if length < 50:
+ return 0.5
+ if length < 200:
+ return 1.0
+ return 0.8
+
+ return length_reward
+ raise ValueError(f"Unknown reward type: {reward_type}")
+
+
+class _WeightedRewardComposer:
+ def __init__(self, components: Sequence[Dict[str, Any]]):
+ self.components = [dict(component) for component in components]
+ self._running_stats: Dict[str, Dict[str, float]] = {}
+
+ def _resolve_source(self, source: Any) -> Any:
+ if isinstance(source, str):
+ return _simple_reward_builder(source)
+ return source
+
+ def _normalize(self, name: str, value: float, mode: Any) -> float:
+ if mode in (None, False, "none"):
+ return value
+ if callable(mode):
+ return float(mode(value))
+ if mode not in (True, "zscore"):
+ raise ValueError(f"Unsupported reward normalization mode: {mode}")
+ stats = self._running_stats.setdefault(name, {"count": 0.0, "mean": 0.0, "m2": 0.0})
+ stats["count"] += 1.0
+ delta = value - stats["mean"]
+ stats["mean"] += delta / stats["count"]
+ delta2 = value - stats["mean"]
+ stats["m2"] += delta * delta2
+ variance = stats["m2"] / max(stats["count"] - 1.0, 1.0)
+ std = variance ** 0.5
+ if std < 1e-6:
+ return value - stats["mean"]
+ return (value - stats["mean"]) / std
+
+ def _evaluate_source(self, source: Any, payload: Dict[str, Any]) -> tuple[float, Dict[str, float], Optional[Dict[str, Any]]]:
+ evaluator = self._resolve_source(source)
+ if hasattr(evaluator, "evaluate"):
+ result = evaluator.evaluate(payload)
+ elif callable(evaluator):
+ try:
+ signature = inspect.signature(evaluator)
+ except (TypeError, ValueError):
+ signature = None
+
+ if signature is None:
+ result = evaluator(payload["completion_text"], payload["reward_context"])
+ else:
+ required = [
+ parameter
+ for parameter in signature.parameters.values()
+ if parameter.kind in (
+ inspect.Parameter.POSITIONAL_ONLY,
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
+ )
+ and parameter.default is inspect._empty
+ ]
+ if len(required) >= 2:
+ result = evaluator(payload["completion_text"], payload["reward_context"])
+ else:
+ result = evaluator(payload)
+ else:
+ raise TypeError("Reward component source must be a string, callable, or evaluator object.")
+
+ if isinstance(result, Mapping):
+ reward = float(result.get("reward", result.get("score", 0.0)))
+ return reward, dict(result.get("components") or {}), result.get("diagnostics")
+ return float(result), {}, None
+
+ def evaluate(self, payload: Dict[str, Any]) -> Dict[str, Any]:
+ total = 0.0
+ named_components: Dict[str, float] = {}
+ diagnostics: Dict[str, Any] = {}
+ for index, component in enumerate(self.components):
+ name = str(component.get("name") or f"reward_{index}")
+ weight = float(component.get("weight", 1.0))
+ raw_reward, nested_components, component_diagnostics = self._evaluate_source(component.get("source"), payload)
+ reward = self._normalize(name, raw_reward, component.get("normalize"))
+ total += weight * reward
+ named_components[name] = reward
+ for nested_name, nested_value in nested_components.items():
+ named_components[f"{name}.{nested_name}"] = float(nested_value)
+ if component_diagnostics is not None:
+ diagnostics[name] = component_diagnostics
+ result: Dict[str, Any] = {
+ "reward": total,
+ "components": named_components,
+ }
+ if diagnostics:
+ result["diagnostics"] = diagnostics
+ return result
+
+
+def create_reward_function(
+ reward_type: Any = "simple",
+ *,
+ rewards: Optional[Sequence[Any]] = None,
+) -> Any:
+ if rewards is not None:
+ components: List[Dict[str, Any]] = []
+ for index, component in enumerate(rewards):
+ if isinstance(component, Mapping):
+ item = dict(component)
+ if "source" not in item:
+ raise ValueError("Reward components must include a source field.")
+ else:
+ item = {"source": component}
+ item.setdefault("name", f"reward_{index}")
+ item.setdefault("weight", 1.0)
+ components.append(item)
+ return _WeightedRewardComposer(components)
+
+ if isinstance(reward_type, Mapping):
+ if "rewards" in reward_type:
+ return create_reward_function(rewards=reward_type["rewards"])
+ return create_reward_function(rewards=[reward_type])
+
+ if isinstance(reward_type, (list, tuple)):
+ return create_reward_function(rewards=list(reward_type))
+
+ if callable(reward_type) or hasattr(reward_type, "evaluate"):
+ return reward_type
+
+ return _simple_reward_builder(str(reward_type))
+
+
+__all__ = [
+ "RLCheckpointBundle",
+ "RLRoleState",
+ "PreparedRLDataset",
+ "SUPPORTED_RL_DATASET_MODES",
+ "build_reference_policy",
+ "build_reward_model",
+ "create_reward_function",
+ "prepare_preference_dataset",
+ "prepare_reward_dataset",
+ "prepare_rl_dataset",
+ "resume_from_checkpoint",
+]
diff --git a/mlx_tune/rl_trainers.py b/mlx_tune/rl_trainers.py
index d920c10..d9be9cf 100644
--- a/mlx_tune/rl_trainers.py
+++ b/mlx_tune/rl_trainers.py
@@ -1,131 +1,1161 @@
"""
-Reinforcement Learning Trainers for MLX-Tune
-
-Provides Unsloth/TRL-compatible RL training interfaces:
-- DPOTrainer: Direct Preference Optimization
-- ORPOTrainer: Odds Ratio Preference Optimization
-- GRPOTrainer: Group Relative Policy Optimization (DeepSeek R1 style)
-- KTOTrainer: Kahneman-Tversky Optimization
-- SimPOTrainer: Simple Preference Optimization
-
-These trainers use MLX under the hood for Apple Silicon optimization.
-Now with PROPER loss implementations using native MLX training!
+Reinforcement learning trainers for MLX-Tune.
+
+Provides TRL-style trainer interfaces for:
+- DPO
+- ORPO
+- GRPO
+- KTO
+- SimPO
"""
-from typing import Optional, Dict, Any, Union, List, Callable
from pathlib import Path
+from typing import Optional, Dict, Any, List, Callable, Tuple
+from contextlib import contextmanager
+import hashlib
import json
import subprocess
+import time
import warnings
import mlx.core as mx
-# Try to import native training components
try:
- from mlx_lm.tuner.trainer import TrainingArgs
import mlx.nn as nn
import mlx.optimizers as optim
+ from mlx.utils import tree_flatten, tree_unflatten
HAS_NATIVE_TRAINING = True
except ImportError:
HAS_NATIVE_TRAINING = False
-# Import our loss functions
from mlx_tune.losses import (
dpo_loss as compute_dpo_loss,
- orpo_loss as compute_orpo_loss,
+ grpo_recompute_loss,
kto_loss as compute_kto_loss,
+ orpo_loss as compute_orpo_loss,
+ ppo_sequence_loss,
+ reward_model_pairwise_loss,
+ reward_model_regression_loss,
+ scalar_loss_metrics,
+ pairwise_ranking_accuracy,
simpo_loss as compute_simpo_loss,
- grpo_batch_loss,
- compute_reference_logprobs,
- compute_log_probs_with_lengths,
+ value_model_regression_loss,
+)
+from mlx_tune._rl_runtime import (
+ PolicyEvalBatch,
+ PreferenceBatch,
+ RolloutBatch,
+ assemble_minibatches,
+ cap_prompt_and_completion_lengths,
+ collect_rollouts,
+ compute_advantages,
+ compute_returns_and_advantages,
+ evaluate_rewards,
+ kl_against_reference,
+ make_policy_eval_batch,
+ make_preference_batch,
+ pad_sequences,
+ predict_rollout_values,
+ rank_grouped_rollouts,
+ score_rollout_references,
+ score_policy_in_chunks,
+ summarize_rollout_metrics,
+)
+from mlx_tune.model import ReferencePolicy
+from mlx_tune.model import (
+ RewardModel,
+ ValueModel,
+ build_reference_policy,
+ build_reward_model,
+ build_value_model,
+)
+from mlx_tune.rl_api import (
+ RLCheckpointBundle,
+ create_reward_function as public_create_reward_function,
+ prepare_preference_dataset as public_prepare_preference_dataset,
+ prepare_reward_dataset as public_prepare_reward_dataset,
+ prepare_rl_dataset,
+ resume_from_checkpoint,
)
-def _save_adapters_and_config(model, adapter_path: Path):
- """
- Save adapter weights and config (required for GGUF export).
+STATE_FILE = "trainer_state.safetensors"
+METADATA_FILE = "trainer_state.json"
+REFERENCE_FILE = "reference_model.safetensors"
+REFERENCE_METADATA_FILE = "reference_metadata.json"
+MANIFEST_FILE = "manifest.json"
+CHECKPOINT_FORMAT_NAME = "mlx_tune_rl_checkpoint"
+CHECKPOINT_FORMAT_VERSION = 4
+MLX_TUNE_VERSION = "0.4.0"
+GRPO_LOSS_TYPES = {"grpo", "dr_grpo", "dapo", "bnpo", "gspo"}
+
+
+def _actual_model(model: Any) -> Any:
+ return model.model if hasattr(model, "model") else model
+
+
+def _pad_token_id(tokenizer: Any) -> int:
+ pad_id = getattr(tokenizer, "pad_token_id", None)
+ return 0 if pad_id is None else pad_id
+
+
+def _encode_text(tokenizer: Any, text: str, add_special_tokens: bool = True) -> List[int]:
+ try:
+ return list(tokenizer.encode(text, add_special_tokens=add_special_tokens))
+ except TypeError:
+ return list(tokenizer.encode(text))
+
- This helper function is used by all RL trainers to ensure
- adapter_config.json is created alongside adapters.safetensors.
+def _save_adapters_and_config(model: Any, adapter_path: Path) -> bool:
+ """
+ Save trainable parameters and adapter config in mlx_lm-compatible layout.
"""
try:
- from mlx.utils import tree_flatten
- actual_model = model.model if hasattr(model, 'model') else model
- adapter_file = adapter_path / "adapters.safetensors"
- adapter_path.mkdir(parents=True, exist_ok=True)
+ if hasattr(model, "save_adapter_snapshot"):
+ return bool(model.save_adapter_snapshot(str(adapter_path)))
- # Save adapter weights (trainable parameters only)
+ actual_model = _actual_model(model)
+ adapter_path.mkdir(parents=True, exist_ok=True)
+ adapter_file = adapter_path / "adapters.safetensors"
adapter_weights = dict(tree_flatten(actual_model.trainable_parameters()))
mx.save_safetensors(str(adapter_file), adapter_weights)
+ return True
+ except Exception as exc:
+ print(f" Warning: could not save adapters: {exc}")
+ return False
+
+
+def _save_full_model_state(model: Any, path: Path) -> None:
+ actual_model = _actual_model(model)
+ mx.save_safetensors(str(path), dict(tree_flatten(actual_model.parameters())))
+
+
+def _load_parameter_tree(model: Any, path: Path, strict: bool = False) -> None:
+ if not path.exists():
+ return
+ actual_model = _actual_model(model)
+ weights = mx.load(str(path))
+ actual_model.update(tree_unflatten(list(weights.items())), strict=strict)
+ mx.eval(actual_model.parameters())
+
+
+def _load_flat_parameter_tree(model: Any, flat_state: Dict[str, Any], strict: bool = False) -> None:
+ if not flat_state:
+ return
+ actual_model = _actual_model(model)
+ actual_model.update(tree_unflatten(list(flat_state.items())), strict=strict)
+ mx.eval(actual_model.parameters())
+
+
+def _flatten_prefixed_tree(prefix: str, tree: Dict[str, Any]) -> Dict[str, mx.array]:
+ return {f"{prefix}.{key}": value for key, value in tree_flatten(tree)}
+
+
+def _extract_prefixed_tree(prefix: str, flat_state: Dict[str, mx.array]) -> Dict[str, Any]:
+ items = []
+ prefix_with_dot = f"{prefix}."
+ for key, value in flat_state.items():
+ if key.startswith(prefix_with_dot):
+ items.append((key[len(prefix_with_dot):], value))
+ return tree_unflatten(items) if items else {}
+
+
+def _rng_state_to_dict() -> Dict[str, mx.array]:
+ return {f"rng.{idx}": state for idx, state in enumerate(mx.random.state)}
+
+
+def _restore_rng_state(flat_state: Dict[str, mx.array]) -> None:
+ rng_items = []
+ idx = 0
+ while f"rng.{idx}" in flat_state:
+ rng_items.append(mx.array(flat_state[f"rng.{idx}"]))
+ idx += 1
+ if rng_items:
+ mx.random.state = rng_items
+
+
+def _pad_sequences(sequences: List[List[int]], pad_id: int) -> Tuple[mx.array, mx.array]:
+ return pad_sequences(sequences, pad_id)
+
+
+def _read_json(path: Path, default: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
+ if not path.exists():
+ return {} if default is None else default
+ with open(path) as handle:
+ return json.load(handle)
+
+
+def _write_json(path: Path, payload: Dict[str, Any]) -> None:
+ path.parent.mkdir(parents=True, exist_ok=True)
+ with open(path, "w") as handle:
+ json.dump(payload, handle, indent=2)
+
+
+def _load_jsonl(path: Path) -> List[Dict[str, Any]]:
+ if not path.exists():
+ return []
+ history: List[Dict[str, Any]] = []
+ with open(path) as handle:
+ for line in handle:
+ line = line.strip()
+ if line:
+ history.append(json.loads(line))
+ return history
+
+
+def _save_jsonl(path: Path, rows: List[Dict[str, Any]]) -> None:
+ path.parent.mkdir(parents=True, exist_ok=True)
+ with open(path, "w") as handle:
+ for row in rows:
+ handle.write(json.dumps(row) + "\n")
+
+
+def _hash_payload(payload: Any) -> str:
+ return hashlib.sha256(json.dumps(payload, sort_keys=True, separators=(",", ":")).encode("utf-8")).hexdigest()
+
+
+class _ScalarRoleTrainTarget(nn.Module):
+ def __init__(self, backbone: Any, head: Any):
+ super().__init__()
+ self.backbone = backbone
+ self.head = head
+
+
+class _RLTrainerBase:
+ algorithm = "rl"
+ requires_reference_policy = False
+
+ def _init_native_state(self) -> None:
+ self.global_step = 0
+ self.dataset_cursor = 0
+ self.reference_policy: Optional[ReferencePolicy] = None
+ self.cache_metadata: Dict[str, Any] = {}
+ self.runtime_cache_arrays: Dict[str, mx.array] = {}
+ self.optimizer = None
+ self.optimizers: Dict[str, Any] = {}
+ self.reward_model: Optional[RewardModel] = getattr(self, "reward_model", None)
+ self.value_model: Optional[ValueModel] = getattr(self, "value_model", None)
+ self.metrics_history: List[Dict[str, Any]] = []
+ self.loaded_checkpoint_manifest: Optional[Dict[str, Any]] = None
+ self.seed = getattr(self.config, "seed", getattr(self, "seed", 0))
+ self._seed_initialized = False
+
+ def _apply_lora_if_needed(self) -> None:
+ if hasattr(self.model, "_apply_lora") and not getattr(self.model, "_lora_applied", False):
+ self.model._apply_lora()
+
+ def _optimizer_for_training(self, learning_rate: Optional[float] = None):
+ lr_schedule = optim.cosine_decay(self.learning_rate if learning_rate is None else learning_rate, self.iters)
+ return optim.AdamW(learning_rate=lr_schedule)
+
+ def _primary_role_name(self) -> str:
+ return "policy"
+
+ def _primary_role_weight_format(self) -> Any:
+ return "adapters.safetensors"
+
+ def _primary_optimizer_name(self) -> str:
+ return self._primary_role_name()
+
+ def _trainer_cursor_state(self) -> Dict[str, int]:
+ return {"dataset": int(self.dataset_cursor)}
- # Determine number of layers in the model
- num_layers = None
- if hasattr(actual_model, 'layers'):
- num_layers = len(actual_model.layers)
- elif hasattr(actual_model, 'model') and hasattr(actual_model.model, 'layers'):
- num_layers = len(actual_model.model.layers)
-
- # Save adapter_config.json
- lora_config = {}
- if hasattr(model, 'lora_config') and model.lora_config:
- lora_config = model.lora_config.copy()
-
- r = lora_config.get('r', 16)
- alpha = lora_config.get('lora_alpha', 16)
-
- adapter_config = {
- "fine_tune_type": "lora",
- "num_layers": num_layers,
- "lora_parameters": {
- "rank": r,
- "scale": alpha / r,
- "dropout": lora_config.get('lora_dropout', 0.0),
+ def _restore_trainer_cursors(self, cursors: Dict[str, Any]) -> None:
+ self.dataset_cursor = int(cursors.get("dataset", cursors.get("dataset_cursor", self.dataset_cursor)))
+
+ def _sampling_config_payload(self) -> Dict[str, Any]:
+ return {}
+
+ def _sampling_config_fingerprint(self) -> Optional[str]:
+ payload = self._sampling_config_payload()
+ if not payload:
+ return None
+ return _hash_payload(payload)
+
+ def _validate_resume_sampling_fingerprint(self, trainer_state: Dict[str, Any]) -> None:
+ current_fingerprint = self._sampling_config_fingerprint()
+ saved_fingerprint = (
+ trainer_state.get("trainer_state", {}).get("sampling_config_fingerprint")
+ or trainer_state.get("sampling_config_fingerprint")
+ )
+ if current_fingerprint and saved_fingerprint and current_fingerprint != saved_fingerprint:
+ raise ValueError(
+ "Checkpoint sampling config fingerprint does not match the current trainer configuration."
+ )
+
+ def _seed_training_run(self) -> None:
+ if self._seed_initialized:
+ return
+ mx.random.seed(int(self.seed))
+ self._seed_initialized = True
+
+ @contextmanager
+ def _preserve_rng_state(self):
+ saved_state = [mx.array(state) for state in mx.random.state]
+ try:
+ yield
+ finally:
+ mx.random.state = [mx.array(state) for state in saved_state]
+
+ def _normalize_metric_value(self, value: Any) -> Any:
+ if value is None:
+ return None
+ if hasattr(value, "item"):
+ value = value.item()
+ if isinstance(value, bool):
+ return bool(value)
+ if isinstance(value, int):
+ return int(value)
+ if isinstance(value, float):
+ return float(value)
+ return value
+
+ def _record_metrics(
+ self,
+ namespace: str,
+ metrics: Dict[str, Any],
+ step: Optional[int] = None,
+ ) -> Dict[str, Any]:
+ normalized = {
+ (key if "/" in key else f"{namespace}/{key}"): self._normalize_metric_value(value)
+ for key, value in metrics.items()
+ if value is not None
+ }
+ if not normalized:
+ return {}
+ row = {"step": self.global_step if step is None else int(step)}
+ row.update(normalized)
+ self.metrics_history.append(row)
+ if hasattr(self, "output_dir"):
+ _save_jsonl(self._metrics_path(), self.metrics_history)
+ return row
+
+ def _record_metric(self, **metrics: Any) -> None:
+ self._record_metrics("train", metrics)
+
+ def _format_metric_summary(
+ self,
+ row: Dict[str, Any],
+ namespace: str = "train",
+ keys: Optional[List[str]] = None,
+ ) -> str:
+ preferred = keys or [
+ "policy_loss",
+ "loss",
+ "value_loss",
+ "reward_loss",
+ "reward_mean",
+ "logprob_delta_per_token_mean",
+ "logprob_delta_mean",
+ "completion_length_mean",
+ "completion_length_max",
+ "eos_rate",
+ "truncation_rate",
+ "kl_to_reference_mean",
+ "rollout_generate_wall",
+ "reward_eval_wall",
+ "reference_score_wall",
+ "returns_wall",
+ "policy_update_wall",
+ "policy_update_steps",
+ "preference_win_rate",
+ ]
+ parts = [f"step={row['step']}"]
+ for key in preferred:
+ full_key = f"{namespace}/{key}"
+ if full_key not in row:
+ continue
+ value = row[full_key]
+ if isinstance(value, bool):
+ parts.append(f"{key}={value}")
+ elif isinstance(value, (int, float)):
+ parts.append(f"{key}={value:.4f}")
+ else:
+ parts.append(f"{key}={value}")
+ return " | ".join(parts)
+
+ def _gather_runtime_cache_arrays(
+ self,
+ extra_arrays: Optional[Dict[str, mx.array]] = None,
+ ) -> Dict[str, mx.array]:
+ merged: Dict[str, mx.array] = {}
+ merged.update(getattr(self, "runtime_cache_arrays", {}))
+ if hasattr(self, "_extra_state_arrays"):
+ merged.update(getattr(self, "_extra_state_arrays")())
+ if extra_arrays:
+ merged.update(extra_arrays)
+ return merged
+
+ def _save_primary_role(self, checkpoint_dir: Optional[Path] = None) -> None:
+ policy_dir = self._role_dir("policy", checkpoint_dir)
+ _save_adapters_and_config(self.model, policy_dir)
+ if hasattr(self.model, "set_adapter_path"):
+ self.model.set_adapter_path(str(policy_dir))
+ _write_json(
+ policy_dir / "role.json",
+ {
+ "role": "policy",
+ "weight_format": "adapters.safetensors",
},
+ )
+
+ def _load_primary_role(self, checkpoint_dir: Path) -> None:
+ policy_adapter_file = self._role_dir("policy", checkpoint_dir) / "adapters.safetensors"
+ if policy_adapter_file.exists():
+ _load_parameter_tree(self.model, policy_adapter_file, strict=False)
+ if hasattr(self.model, "set_adapter_path"):
+ self.model.set_adapter_path(str(self._role_dir("policy", checkpoint_dir)))
+
+ def _next_samples(self, samples: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ if not samples:
+ raise ValueError(f"{self.algorithm} training dataset is empty.")
+
+ batch = []
+ for _ in range(max(1, self.batch_size)):
+ batch.append(samples[self.dataset_cursor])
+ self.dataset_cursor = (self.dataset_cursor + 1) % len(samples)
+ return batch
+
+ def _next_rollout_samples(self, samples: List[Dict[str, Any]], count: int) -> List[Dict[str, Any]]:
+ if not samples:
+ raise ValueError(f"{self.algorithm} training dataset is empty.")
+
+ batch = []
+ for _ in range(max(1, count)):
+ batch.append(samples[self.dataset_cursor])
+ self.dataset_cursor = (self.dataset_cursor + 1) % len(samples)
+ return batch
+
+ def _observed_rollout_kl(self, rollout_batch: Optional[RolloutBatch]) -> Optional[float]:
+ if rollout_batch is None:
+ return None
+ if rollout_batch.reference_logprobs is None or rollout_batch.rollout_logprobs is None:
+ return None
+ kl_values = kl_against_reference(
+ rollout_batch.rollout_logprobs.astype(mx.float32),
+ rollout_batch.reference_logprobs.astype(mx.float32),
+ )
+ return float(mx.mean(kl_values).item())
+
+ def _effective_kl_beta(self, rollout_batch: Optional[RolloutBatch] = None) -> float:
+ base_beta = float(getattr(self, "beta", 0.0))
+ if getattr(self, "kl_penalty_mode", "kl") == "none":
+ return 0.0
+ kl_target = getattr(self, "kl_target", None)
+ observed_kl = self._observed_rollout_kl(rollout_batch)
+ if kl_target is None or observed_kl is None:
+ return base_beta
+ target = max(float(kl_target), 1e-8)
+ scale = min(max(observed_kl / target, 0.0), 10.0)
+ return base_beta * scale
+
+ def _checkpoint_dir(self, checkpoint_dir: Optional[Path] = None) -> Path:
+ return checkpoint_dir or self.output_dir
+
+ def _manifest_path(self, checkpoint_dir: Optional[Path] = None) -> Path:
+ return self._checkpoint_dir(checkpoint_dir) / MANIFEST_FILE
+
+ def _role_dir(self, role_name: str, checkpoint_dir: Optional[Path] = None) -> Path:
+ return self._checkpoint_dir(checkpoint_dir) / role_name
+
+ def _optimizer_state_path(
+ self,
+ checkpoint_dir: Optional[Path] = None,
+ optimizer_name: Optional[str] = None,
+ ) -> Path:
+ if optimizer_name is None:
+ return self._checkpoint_dir(checkpoint_dir) / "optimizer" / "state.safetensors"
+ return self._checkpoint_dir(checkpoint_dir) / "optimizers" / optimizer_name / "state.safetensors"
+
+ def _scheduler_state_path(
+ self,
+ checkpoint_dir: Optional[Path] = None,
+ optimizer_name: Optional[str] = None,
+ ) -> Path:
+ if optimizer_name is None:
+ return self._checkpoint_dir(checkpoint_dir) / "scheduler" / "state.json"
+ return self._checkpoint_dir(checkpoint_dir) / "schedulers" / optimizer_name / "state.json"
+
+ def _trainer_state_path(self, checkpoint_dir: Optional[Path] = None) -> Path:
+ return self._checkpoint_dir(checkpoint_dir) / "trainer" / "state.json"
+
+ def _trainer_rng_path(self, checkpoint_dir: Optional[Path] = None) -> Path:
+ return self._checkpoint_dir(checkpoint_dir) / "trainer" / "rng.safetensors"
+
+ def _runtime_cache_path(self, checkpoint_dir: Optional[Path] = None) -> Path:
+ return self._checkpoint_dir(checkpoint_dir) / "runtime" / "cache.safetensors"
+
+ def _metrics_path(self, checkpoint_dir: Optional[Path] = None) -> Path:
+ return self._checkpoint_dir(checkpoint_dir) / "metrics" / "history.jsonl"
+
+ def _legacy_state_path(self, checkpoint_dir: Optional[Path] = None) -> Path:
+ return self._checkpoint_dir(checkpoint_dir) / STATE_FILE
+
+ def _legacy_metadata_path(self, checkpoint_dir: Optional[Path] = None) -> Path:
+ return self._checkpoint_dir(checkpoint_dir) / METADATA_FILE
+
+ def _legacy_reference_path(self, checkpoint_dir: Optional[Path] = None) -> Path:
+ return self._checkpoint_dir(checkpoint_dir) / REFERENCE_FILE
+
+ def _legacy_reference_metadata_path(self, checkpoint_dir: Optional[Path] = None) -> Path:
+ return self._checkpoint_dir(checkpoint_dir) / REFERENCE_METADATA_FILE
+
+ def _has_manifest_checkpoint(self, checkpoint_dir: Path) -> bool:
+ return self._manifest_path(checkpoint_dir).exists()
+
+ def _has_legacy_checkpoint(self, checkpoint_dir: Path) -> bool:
+ return (
+ (checkpoint_dir / "adapters" / "adapters.safetensors").exists()
+ or self._legacy_state_path(checkpoint_dir).exists()
+ or self._legacy_reference_path(checkpoint_dir).exists()
+ )
+
+ def _save_reference_policy(self, checkpoint_dir: Optional[Path] = None) -> None:
+ if self.reference_policy is None:
+ return
+ reference_dir = self._role_dir("reference", checkpoint_dir)
+ reference_dir.mkdir(parents=True, exist_ok=True)
+ _save_full_model_state(self.reference_policy.model, reference_dir / "weights.safetensors")
+ _write_json(
+ reference_dir / "metadata.json",
+ {
+ "source": self.reference_policy.source,
+ "metadata": self.reference_policy.metadata,
+ },
+ )
+ _write_json(
+ reference_dir / "role.json",
+ {
+ "role": "reference",
+ "weight_format": "weights.safetensors",
+ },
+ )
+
+ def _save_optional_scalar_role(
+ self,
+ role_name: str,
+ role_model: Optional[Any],
+ checkpoint_dir: Optional[Path] = None,
+ ) -> None:
+ if role_model is None or role_name == self._primary_role_name():
+ return
+ role_model.save_pretrained(str(self._role_dir(role_name, checkpoint_dir)))
+
+ def _build_training_metadata(self) -> Dict[str, Any]:
+ config = self.config.to_dict() if hasattr(self.config, "to_dict") else dict(self.config)
+ trainer_state = {
+ "cursors": self._trainer_cursor_state(),
+ "sampling_config_fingerprint": self._sampling_config_fingerprint(),
+ "step_boundary": {
+ "completed_optimizer_step": int(self.global_step),
+ "checkpoint_authoritative": True,
+ },
+ }
+ runtime_state = {
+ "cache_metadata": self.cache_metadata,
+ "runtime_cache_keys": sorted(self.runtime_cache_arrays.keys()),
+ "rng_state_path": "trainer/rng.safetensors",
+ }
+ return {
+ "algorithm": self.algorithm,
+ "config": config,
+ "global_step": self.global_step,
+ "dataset_cursor": self.dataset_cursor,
+ "cache_metadata": self.cache_metadata,
+ "seed": int(self.seed),
+ "sampling_config_fingerprint": trainer_state["sampling_config_fingerprint"],
+ "trainer_state": trainer_state,
+ "runtime_state": runtime_state,
}
- target_modules = lora_config.get('target_modules', [])
- if target_modules:
- short_to_full = {
- 'q_proj': 'self_attn.q_proj', 'k_proj': 'self_attn.k_proj',
- 'v_proj': 'self_attn.v_proj', 'o_proj': 'self_attn.o_proj',
- 'gate_proj': 'mlp.gate_proj', 'up_proj': 'mlp.up_proj',
- 'down_proj': 'mlp.down_proj',
+ def _build_scheduler_state(self, optimizer: Any, learning_rate: Optional[float] = None) -> Dict[str, Any]:
+ step_value = 0
+ if optimizer is not None and getattr(optimizer, "state", None):
+ step = optimizer.state.get("step", 0)
+ step_value = int(step.item()) if hasattr(step, "item") else int(step)
+ return {
+ "name": "cosine_decay",
+ "initial_learning_rate": self.learning_rate if learning_rate is None else learning_rate,
+ "total_steps": self.iters,
+ "step": step_value,
+ }
+
+ def _normalize_optimizers(
+ self,
+ optimizer: Any = None,
+ optimizers: Optional[Dict[str, Any]] = None,
+ ) -> Dict[str, Any]:
+ if optimizers:
+ return optimizers
+ if optimizer is not None:
+ return {self._primary_optimizer_name(): optimizer}
+ return getattr(self, "optimizers", {})
+
+ def _optimizer_learning_rates(self) -> Dict[str, float]:
+ return {self._primary_optimizer_name(): self.learning_rate}
+
+ def _build_manifest(self, optimizers: Dict[str, Any]) -> Dict[str, Any]:
+ reward_base = getattr(self.reward_model, "base_model", None)
+ value_base = getattr(self.value_model, "base_model", None)
+ primary_role = self._primary_role_name()
+ roles_present = [primary_role]
+ role_weight_formats = {
+ primary_role: self._primary_role_weight_format(),
+ }
+ reference_provenance = None
+ if self.reference_policy is not None:
+ roles_present.append("reference")
+ role_weight_formats["reference"] = "weights.safetensors"
+ reference_provenance = {
+ "source": self.reference_policy.source,
+ "metadata": self.reference_policy.metadata,
+ }
+ if self.reward_model is not None and primary_role != "reward_model":
+ roles_present.append("reward_model")
+ role_weight_formats["reward_model"] = {
+ "backbone": "weights.safetensors",
+ "head": "head.safetensors",
+ "adapters": "adapters.safetensors"
+ if hasattr(reward_base, "has_adapters") and reward_base.has_adapters()
+ else None,
+ }
+ if self.value_model is not None and primary_role != "value_model":
+ roles_present.append("value_model")
+ role_weight_formats["value_model"] = {
+ "backbone": "weights.safetensors",
+ "head": "head.safetensors",
+ "adapters": "adapters.safetensors"
+ if hasattr(value_base, "has_adapters") and value_base.has_adapters()
+ else None,
}
- adapter_config["lora_parameters"]["keys"] = [
- short_to_full.get(m, m) for m in target_modules
- ]
- config_path = adapter_path / "adapter_config.json"
- with open(config_path, 'w') as f:
- json.dump(adapter_config, f, indent=2)
+ trainer_state_locations = {
+ "optimizers": {
+ name: f"optimizers/{name}/state.safetensors"
+ for name in optimizers
+ },
+ "schedulers": {
+ name: f"schedulers/{name}/state.json"
+ for name in optimizers
+ },
+ "trainer": "trainer/state.json",
+ "rng": "trainer/rng.safetensors",
+ "runtime_cache": "runtime/cache.safetensors",
+ }
+ if len(optimizers) == 1:
+ trainer_state_locations["optimizer"] = "optimizer/state.safetensors"
+ trainer_state_locations["scheduler"] = "scheduler/state.json"
- return True
- except Exception as e:
- print(f" âš Could not save adapters: {e}")
- return False
+ return {
+ "format_name": CHECKPOINT_FORMAT_NAME,
+ "format_version": CHECKPOINT_FORMAT_VERSION,
+ "algorithm": self.algorithm,
+ "roles_present": roles_present,
+ "mlx_tune_version": MLX_TUNE_VERSION,
+ "role_weight_formats": role_weight_formats,
+ "trainer_state_locations": trainer_state_locations,
+ "metrics_path": "metrics/history.jsonl",
+ "reference_provenance": reference_provenance,
+ }
+
+ def save_state(
+ self,
+ optimizer: Any = None,
+ extra_arrays: Optional[Dict[str, mx.array]] = None,
+ optimizers: Optional[Dict[str, Any]] = None,
+ ) -> None:
+ checkpoint_dir = self._checkpoint_dir()
+ checkpoint_dir.mkdir(parents=True, exist_ok=True)
+ optimizer_map = self._normalize_optimizers(optimizer=optimizer, optimizers=optimizers)
+ if not optimizer_map:
+ raise ValueError("No optimizer state provided for checkpoint save.")
+ self.optimizers = optimizer_map
+ if len(optimizer_map) == 1:
+ self.optimizer = next(iter(optimizer_map.values()))
+
+ self._save_primary_role(checkpoint_dir)
+ self._save_reference_policy(checkpoint_dir)
+ self._save_optional_scalar_role("reward_model", self.reward_model, checkpoint_dir)
+ self._save_optional_scalar_role("value_model", self.value_model, checkpoint_dir)
+
+ learning_rates = self._optimizer_learning_rates()
+ for name, current_optimizer in optimizer_map.items():
+ optimizer_path = self._optimizer_state_path(checkpoint_dir, name)
+ optimizer_path.parent.mkdir(parents=True, exist_ok=True)
+ mx.save_safetensors(str(optimizer_path), _flatten_prefixed_tree("optimizer", current_optimizer.state))
+ if len(optimizer_map) == 1:
+ legacy_path = self._optimizer_state_path(checkpoint_dir)
+ legacy_path.parent.mkdir(parents=True, exist_ok=True)
+ mx.save_safetensors(str(legacy_path), _flatten_prefixed_tree("optimizer", current_optimizer.state))
+ scheduler_state = self._build_scheduler_state(
+ current_optimizer,
+ learning_rate=learning_rates.get(name, self.learning_rate),
+ )
+ _write_json(self._scheduler_state_path(checkpoint_dir, name), scheduler_state)
+ if len(optimizer_map) == 1:
+ _write_json(self._scheduler_state_path(checkpoint_dir), scheduler_state)
+
+ rng_path = self._trainer_rng_path(checkpoint_dir)
+ rng_path.parent.mkdir(parents=True, exist_ok=True)
+ mx.save_safetensors(str(rng_path), _rng_state_to_dict())
+
+ runtime_arrays = self._gather_runtime_cache_arrays(extra_arrays)
+ self.runtime_cache_arrays = dict(runtime_arrays)
+ runtime_path = self._runtime_cache_path(checkpoint_dir)
+ if runtime_arrays:
+ runtime_path.parent.mkdir(parents=True, exist_ok=True)
+ mx.save_safetensors(str(runtime_path), runtime_arrays)
+
+ _write_json(self._trainer_state_path(checkpoint_dir), self._build_training_metadata())
+ _save_jsonl(self._metrics_path(checkpoint_dir), self.metrics_history)
+ _write_json(self._manifest_path(checkpoint_dir), self._build_manifest(optimizer_map))
+
+ def _ensure_reference_policy(self) -> None:
+ if not self.requires_reference_policy:
+ return
+ if self.reference_policy is None:
+ ref_model = getattr(self, "ref_model", None)
+ self.reference_policy = build_reference_policy(self.model, ref_model=ref_model, snapshot=True)
+
+ def _load_reference_policy(self, checkpoint_dir: Path) -> None:
+ reference_path = self._role_dir("reference", checkpoint_dir) / "weights.safetensors"
+ if reference_path.exists():
+ self.reference_policy = build_reference_policy(self.model, snapshot=True)
+ _load_parameter_tree(self.reference_policy.model, reference_path, strict=False)
+ _actual_model(self.reference_policy.model).freeze()
+ mx.eval(_actual_model(self.reference_policy.model).parameters())
+ metadata = _read_json(self._role_dir("reference", checkpoint_dir) / "metadata.json")
+ if metadata:
+ self.reference_policy.source = metadata.get("source", self.reference_policy.source)
+ self.reference_policy.metadata = metadata.get("metadata", self.reference_policy.metadata)
+ else:
+ self._ensure_reference_policy()
+
+ def _load_optional_scalar_role(self, checkpoint_dir: Path, role_name: str) -> Optional[Any]:
+ role_dir = self._role_dir(role_name, checkpoint_dir)
+ if not role_dir.exists() or role_name == self._primary_role_name():
+ return None
+ with open(role_dir / "head_config.json") as handle:
+ config = json.load(handle)
+ if role_name == "reward_model":
+ if self.reward_model is None:
+ self.reward_model = build_reward_model(
+ self.model,
+ pooling=config.get("pooling", "last_token"),
+ target=config.get("target", "completion"),
+ )
+ self.reward_model.load_pretrained(str(role_dir))
+ return self.reward_model
+ if role_name == "value_model":
+ if self.value_model is None:
+ self.value_model = build_value_model(
+ self.model,
+ pooling=config.get("pooling", "last_token"),
+ target=config.get("target", "completion"),
+ )
+ self.value_model.load_pretrained(str(role_dir))
+ return self.value_model
+ return None
+
+ def _apply_scalar_role_state(self, role_name: str, role_state: Any) -> Optional[Any]:
+ if role_name == "reward_model":
+ if self.reward_model is None:
+ self.reward_model = build_reward_model(
+ self.model,
+ pooling=role_state.head_config.get("pooling", "last_token"),
+ target=role_state.head_config.get("target", "completion"),
+ )
+ role_model = self.reward_model
+ elif role_name == "value_model":
+ if self.value_model is None:
+ self.value_model = build_value_model(
+ self.model,
+ pooling=role_state.head_config.get("pooling", "last_token"),
+ target=role_state.head_config.get("target", "completion"),
+ )
+ role_model = self.value_model
+ else:
+ return None
+
+ _load_flat_parameter_tree(role_model.base_model, role_state.parameter_state, strict=False)
+ if role_state.adapter_state:
+ _load_flat_parameter_tree(role_model.base_model, role_state.adapter_state, strict=False)
+ if role_state.head_state:
+ role_model.head.update(tree_unflatten(list(role_state.head_state.items())), strict=False)
+ mx.eval(role_model.head.parameters())
+ if role_state.head_config:
+ role_model.head_config.update(role_state.head_config)
+ role_model.pooling = role_model.head_config.get("pooling", role_model.pooling)
+ role_model.target = role_model.head_config.get("target", role_model.target)
+ return role_model
+
+ def _apply_checkpoint_bundle(
+ self,
+ bundle: RLCheckpointBundle,
+ optimizer: Any = None,
+ optimizers: Optional[Dict[str, Any]] = None,
+ ) -> Dict[str, mx.array]:
+ self.loaded_checkpoint_manifest = bundle.manifest
+ optimizer_map = self._normalize_optimizers(optimizer=optimizer, optimizers=optimizers)
+ if optimizer_map:
+ for name, current_optimizer in optimizer_map.items():
+ state_tree = bundle.optimizer_state_trees.get(name)
+ if state_tree is None and len(bundle.optimizer_state_trees) == 1:
+ state_tree = next(iter(bundle.optimizer_state_trees.values()))
+ if state_tree:
+ current_optimizer.state = state_tree
+ self.optimizers = optimizer_map
+ if len(optimizer_map) == 1:
+ self.optimizer = next(iter(optimizer_map.values()))
+
+ if bundle.rng_state:
+ _restore_rng_state(bundle.rng_state)
+ self._seed_initialized = True
+
+ trainer_state = bundle.trainer_state
+ self._validate_resume_sampling_fingerprint(trainer_state)
+ self.global_step = trainer_state.get("global_step", 0)
+ self.dataset_cursor = trainer_state.get("dataset_cursor", 0)
+ self.cache_metadata = trainer_state.get("cache_metadata", {})
+ self.seed = int(trainer_state.get("seed", self.seed))
+ self._restore_trainer_cursors(
+ trainer_state.get("trainer_state", {}).get("cursors", {})
+ )
+ self.metrics_history = list(bundle.metrics_history)
+ self.runtime_cache_arrays = dict(bundle.runtime_cache)
+
+ for role_name, role_state in bundle.restored_roles.items():
+ if role_name == "policy":
+ _load_flat_parameter_tree(self.model, role_state.parameter_state, strict=False)
+ elif role_name == "reference":
+ self.reference_policy = build_reference_policy(self.model, snapshot=True)
+ _load_flat_parameter_tree(self.reference_policy.model, role_state.parameter_state, strict=False)
+ reference_actual = _actual_model(self.reference_policy.model)
+ if hasattr(reference_actual, "freeze"):
+ reference_actual.freeze()
+ mx.eval(reference_actual.parameters())
+ if role_state.metadata:
+ self.reference_policy.source = role_state.metadata.get("source", self.reference_policy.source)
+ self.reference_policy.metadata = role_state.metadata.get("metadata", self.reference_policy.metadata)
+ elif role_name in {"reward_model", "value_model"}:
+ self._apply_scalar_role_state(role_name, role_state)
+
+ if self.requires_reference_policy and self.reference_policy is None:
+ self._ensure_reference_policy()
+ return bundle.runtime_cache
+
+ def _restore_optimizer_states(
+ self,
+ checkpoint_dir: Path,
+ optimizer_map: Dict[str, Any],
+ manifest: Optional[Dict[str, Any]] = None,
+ ) -> None:
+ trainer_locations = {} if manifest is None else manifest.get("trainer_state_locations", {})
+ optimizer_locations = trainer_locations.get("optimizers") or {}
+ for name, current_optimizer in optimizer_map.items():
+ optimizer_path = self._optimizer_state_path(checkpoint_dir, name)
+ if name in optimizer_locations:
+ optimizer_path = checkpoint_dir / optimizer_locations[name]
+ elif not optimizer_path.exists() and len(optimizer_map) == 1:
+ optimizer_path = self._optimizer_state_path(checkpoint_dir)
+ if not optimizer_path.exists():
+ continue
+ optimizer_state = _extract_prefixed_tree("optimizer", mx.load(str(optimizer_path)))
+ if optimizer_state:
+ current_optimizer.state = optimizer_state
+
+ def _load_manifest_state(
+ self,
+ optimizer: Any = None,
+ checkpoint_dir: Optional[Path] = None,
+ optimizers: Optional[Dict[str, Any]] = None,
+ ) -> Dict[str, mx.array]:
+ checkpoint_dir = Path(checkpoint_dir)
+ manifest = _read_json(self._manifest_path(checkpoint_dir))
+ if not manifest:
+ raise FileNotFoundError(f"Checkpoint manifest not found under {checkpoint_dir}")
+ self.loaded_checkpoint_manifest = manifest
+
+ self._load_primary_role(checkpoint_dir)
+
+ optimizer_map = self._normalize_optimizers(optimizer=optimizer, optimizers=optimizers)
+ self._restore_optimizer_states(checkpoint_dir, optimizer_map, manifest=manifest)
+ self.optimizers = optimizer_map
+ if len(optimizer_map) == 1:
+ self.optimizer = next(iter(optimizer_map.values()))
+
+ rng_path = self._trainer_rng_path(checkpoint_dir)
+ if rng_path.exists():
+ _restore_rng_state(mx.load(str(rng_path)))
+ self._seed_initialized = True
+
+ metadata = _read_json(self._trainer_state_path(checkpoint_dir))
+ self._validate_resume_sampling_fingerprint(metadata)
+ self.global_step = metadata.get("global_step", 0)
+ self.dataset_cursor = metadata.get("dataset_cursor", 0)
+ self.cache_metadata = metadata.get("cache_metadata", {})
+ self.seed = int(metadata.get("seed", self.seed))
+ self._restore_trainer_cursors(metadata.get("trainer_state", {}).get("cursors", {}))
+ self.metrics_history = _load_jsonl(self._metrics_path(checkpoint_dir))
+
+ self._load_reference_policy(checkpoint_dir)
+ self._load_optional_scalar_role(checkpoint_dir, "reward_model")
+ self._load_optional_scalar_role(checkpoint_dir, "value_model")
+
+ runtime_path = self._runtime_cache_path(checkpoint_dir)
+ self.runtime_cache_arrays = dict(mx.load(str(runtime_path))) if runtime_path.exists() else {}
+ return self.runtime_cache_arrays
+
+ def _load_legacy_state(
+ self,
+ optimizer: Any = None,
+ checkpoint_dir: Optional[Path] = None,
+ optimizers: Optional[Dict[str, Any]] = None,
+ ) -> Dict[str, mx.array]:
+ checkpoint_dir = Path(checkpoint_dir)
+ adapter_file = checkpoint_dir / "adapters" / "adapters.safetensors"
+ if adapter_file.exists():
+ _load_parameter_tree(self.model, adapter_file, strict=False)
+ if hasattr(self.model, "set_adapter_path"):
+ self.model.set_adapter_path(str(checkpoint_dir / "adapters"))
+
+ state_path = self._legacy_state_path(checkpoint_dir)
+ metadata_path = self._legacy_metadata_path(checkpoint_dir)
+ if not state_path.exists() or not metadata_path.exists():
+ raise FileNotFoundError(f"Checkpoint state not found under {checkpoint_dir}")
+
+ metadata = _read_json(metadata_path)
+ flat_state = mx.load(str(state_path))
+ optimizer_map = self._normalize_optimizers(optimizer=optimizer, optimizers=optimizers)
+ if optimizer_map:
+ first_optimizer = next(iter(optimizer_map.values()))
+ optimizer_state = _extract_prefixed_tree("optimizer", flat_state)
+ if optimizer_state:
+ first_optimizer.state = optimizer_state
+ self.optimizers = optimizer_map
+ if len(optimizer_map) == 1:
+ self.optimizer = first_optimizer
+ _restore_rng_state(flat_state)
+ self._seed_initialized = True
+
+ self.global_step = metadata.get("global_step", 0)
+ self.dataset_cursor = metadata.get("dataset_cursor", 0)
+ self.cache_metadata = metadata.get("cache_metadata", {})
+ self.seed = int(metadata.get("seed", self.seed))
+ self.metrics_history = []
+ self.runtime_cache_arrays = {
+ key: value
+ for key, value in flat_state.items()
+ if not key.startswith("optimizer.") and not key.startswith("rng.")
+ }
+ reference_path = self._legacy_reference_path(checkpoint_dir)
+ if reference_path.exists():
+ self.reference_policy = build_reference_policy(self.model, snapshot=True)
+ _load_parameter_tree(self.reference_policy.model, reference_path, strict=False)
+ reference_metadata = _read_json(self._legacy_reference_metadata_path(checkpoint_dir))
+ if reference_metadata:
+ self.reference_policy.source = reference_metadata.get("source", self.reference_policy.source)
+ self.reference_policy.metadata = reference_metadata.get("metadata", self.reference_policy.metadata)
+ else:
+ self._ensure_reference_policy()
+ return flat_state
-class DPOConfig:
- """
- Configuration for Direct Preference Optimization training.
+ def load_state(
+ self,
+ optimizer: Any = None,
+ checkpoint_dir: Optional[Path] = None,
+ optimizers: Optional[Dict[str, Any]] = None,
+ ) -> Dict[str, mx.array]:
+ bundle = resume_from_checkpoint(Path(checkpoint_dir))
+ return self._apply_checkpoint_bundle(bundle, optimizer=optimizer, optimizers=optimizers)
- Compatible with TRL's DPOConfig.
- Example:
- >>> config = DPOConfig(
- ... beta=0.1,
- ... learning_rate=5e-7,
- ... max_steps=100,
- ... )
- """
+def prepare_reward_dataset(dataset: Any) -> List[Dict[str, Any]]:
+ return public_prepare_reward_dataset(dataset)
+
+def _tokenize_reward_scalar_sample(
+ tokenizer: Any,
+ sample: Dict[str, Any],
+ max_seq_length: int,
+) -> Dict[str, Any]:
+ prompt = sample.get("prompt", "")
+ response = sample.get("response", "")
+ sequence_ids = tokenizer.encode(prompt + response)[:max_seq_length]
+ prompt_ids = tokenizer.encode(prompt) if prompt else []
+ prompt_length = min(len(prompt_ids), len(sequence_ids))
+ completion_length = max(len(sequence_ids) - prompt_length, 0)
+ return {
+ "ids": sequence_ids,
+ "length": len(sequence_ids),
+ "prompt_length": prompt_length,
+ "completion_length": completion_length,
+ "score": float(sample.get("score", 0.0)),
+ }
+
+
+def _tokenize_reward_pairwise_sample(
+ tokenizer: Any,
+ sample: Dict[str, Any],
+ max_seq_length: int,
+) -> Dict[str, Any]:
+ prompt = sample.get("prompt", "")
+ prompt_ids = tokenizer.encode(prompt) if prompt else []
+ chosen_ids = tokenizer.encode(prompt + sample["chosen"])[:max_seq_length]
+ rejected_ids = tokenizer.encode(prompt + sample["rejected"])[:max_seq_length]
+ chosen_prompt_length = min(len(prompt_ids), len(chosen_ids))
+ rejected_prompt_length = min(len(prompt_ids), len(rejected_ids))
+ return {
+ "chosen_ids": chosen_ids,
+ "rejected_ids": rejected_ids,
+ "chosen_length": len(chosen_ids),
+ "rejected_length": len(rejected_ids),
+ "chosen_prompt_length": chosen_prompt_length,
+ "rejected_prompt_length": rejected_prompt_length,
+ "chosen_completion_length": max(len(chosen_ids) - chosen_prompt_length, 0),
+ "rejected_completion_length": max(len(rejected_ids) - rejected_prompt_length, 0),
+ }
+
+
+def score_reward_model(
+ reward_model: RewardModel,
+ samples: Any,
+ batch_size: int = 8,
+ tokenizer: Optional[Any] = None,
+ max_seq_length: int = 2048,
+) -> List[float]:
+ scoring_tokenizer = tokenizer or getattr(reward_model, "tokenizer", None)
+ if scoring_tokenizer is None:
+ raise ValueError("score_reward_model requires a tokenizer on the reward model or as an argument.")
+
+ normalized = list(prepare_rl_dataset(samples, mode="reward_scalar", tokenizer=scoring_tokenizer))
+ if any(sample.get("type") != "scalar" for sample in normalized):
+ raise ValueError("score_reward_model only supports scalar reward samples.")
+
+ tokenized = [
+ _tokenize_reward_scalar_sample(scoring_tokenizer, sample, max_seq_length)
+ for sample in normalized
+ ]
+ scores: List[float] = []
+ pad_id = _pad_token_id(scoring_tokenizer)
+ for start in range(0, len(tokenized), max(1, batch_size)):
+ chunk = tokenized[start:start + max(1, batch_size)]
+ input_ids, lengths = _pad_sequences([sample["ids"] for sample in chunk], pad_id)
+ chunk_scores = reward_model.score(
+ input_ids,
+ sequence_lengths=lengths,
+ prompt_lengths=mx.array([sample["prompt_length"] for sample in chunk]),
+ completion_lengths=mx.array([sample["completion_length"] for sample in chunk]),
+ )
+ scores.extend(float(value.item()) for value in chunk_scores)
+ return scores
+
+
+class RLConfigBase:
+ _NON_SERIALIZED_FIELDS: set[str] = set()
+
+ @staticmethod
+ def _pop_alias(kwargs: Dict[str, Any], *names: str, default: Any = None) -> Any:
+ found = [name for name in names if name in kwargs]
+ if not found:
+ return default
+ value = kwargs.pop(found[0])
+ for name in found[1:]:
+ kwargs.pop(name)
+ return value
+
+ @staticmethod
+ def _normalize_reward_sources(
+ reward_sources: Optional[Any],
+ reward_fn: Optional[Any],
+ reward_model: Optional[Any],
+ ) -> List[Any]:
+ resolved: List[Any] = []
+ if reward_sources is not None:
+ if isinstance(reward_sources, (list, tuple)):
+ resolved.extend(list(reward_sources))
+ else:
+ resolved.append(reward_sources)
+ if reward_model is not None:
+ resolved.append({"name": "reward_model", "source": reward_model})
+ if reward_fn is not None:
+ resolved.append({"name": "reward_fn", "source": reward_fn})
+ return resolved
+
+ def _set_remaining(self, kwargs: Dict[str, Any]) -> None:
+ for key, value in kwargs.items():
+ setattr(self, key, value)
+ self._validate()
+
+ def _validate_choice(self, field: str, value: Any, choices: set[str]) -> None:
+ if value not in choices:
+ raise ValueError(f"Unsupported {field} '{value}'. Supported values: {sorted(choices)}")
+
+ def _validate(self) -> None:
+ reward_source = getattr(self, "reward_source", None)
+ if reward_source is not None:
+ self._validate_choice("reward_source", reward_source, {"auto", "online", "offline", "hybrid"})
+ advantage_estimator = getattr(self, "advantage_estimator", None)
+ if advantage_estimator is not None:
+ self._validate_choice("advantage_estimator", advantage_estimator, {"group_zscore", "rloo", "gae"})
+ kl_penalty_mode = getattr(self, "kl_penalty_mode", None)
+ if kl_penalty_mode is not None:
+ self._validate_choice("kl_penalty_mode", kl_penalty_mode, {"kl", "none"})
+
+ def to_dict(self) -> Dict[str, Any]:
+ payload: Dict[str, Any] = {}
+ for key, value in self.__dict__.items():
+ if key.startswith("_") or key in self._NON_SERIALIZED_FIELDS:
+ continue
+ if callable(value):
+ continue
+ if hasattr(value, "parameters") or hasattr(value, "score") or hasattr(value, "predict"):
+ continue
+ payload[key] = value
+ return payload
+
+
+class RewardConfig(RLConfigBase):
def __init__(
self,
- # DPO-specific
- beta: float = 0.1, # KL penalty coefficient
- loss_type: str = "sigmoid", # sigmoid, hinge, ipo, kto_pair
+ output_dir: str = "./reward_outputs",
+ learning_rate: float = 5e-6,
+ per_device_train_batch_size: int = 2,
+ num_train_epochs: int = 1,
+ max_steps: int = -1,
+ logging_steps: int = 10,
+ save_steps: int = 100,
+ max_seq_length: int = 2048,
+ pairwise_margin: float = 0.0,
+ regression_loss_type: str = "mse",
+ dataset_mode: Optional[str] = None,
+ chat_template: Optional[Any] = None,
+ auto_detect_dataset: bool = True,
+ **kwargs,
+ ):
+ self.output_dir = output_dir
+ self.learning_rate = learning_rate
+ self.per_device_train_batch_size = per_device_train_batch_size
+ self.num_train_epochs = num_train_epochs
+ self.max_steps = max_steps
+ self.logging_steps = logging_steps
+ self.save_steps = save_steps
+ self.max_seq_length = max_seq_length
+ self.pairwise_margin = pairwise_margin
+ self.regression_loss_type = regression_loss_type
+ self.dataset_mode = dataset_mode
+ self.chat_template = chat_template
+ self.auto_detect_dataset = auto_detect_dataset
+ self._set_remaining(kwargs)
+
+
+class DPOConfig(RLConfigBase):
+ def __init__(
+ self,
+ beta: float = 0.1,
+ loss_type: str = "sigmoid",
label_smoothing: float = 0.0,
- # Training args
output_dir: str = "./dpo_outputs",
learning_rate: float = 5e-7,
per_device_train_batch_size: int = 2,
@@ -137,7 +1167,10 @@ def __init__(
save_steps: int = 100,
max_seq_length: int = 2048,
max_prompt_length: int = 512,
- **kwargs
+ dataset_mode: Optional[str] = None,
+ chat_template: Optional[Any] = None,
+ auto_detect_dataset: bool = True,
+ **kwargs,
):
self.beta = beta
self.loss_type = loss_type
@@ -153,34 +1186,16 @@ def __init__(
self.save_steps = save_steps
self.max_seq_length = max_seq_length
self.max_prompt_length = max_prompt_length
+ self.dataset_mode = dataset_mode
+ self.chat_template = chat_template
+ self.auto_detect_dataset = auto_detect_dataset
+ self._set_remaining(kwargs)
- for key, value in kwargs.items():
- setattr(self, key, value)
-
- def to_dict(self) -> Dict[str, Any]:
- return {k: v for k, v in self.__dict__.items() if not k.startswith('_')}
-
-
-class ORPOConfig:
- """
- Configuration for Odds Ratio Preference Optimization training.
-
- ORPO combines SFT and preference learning into a single step,
- making it simpler and more efficient than traditional RLHF.
-
- Example:
- >>> config = ORPOConfig(
- ... beta=0.1,
- ... learning_rate=8e-6,
- ... max_steps=1000,
- ... )
- """
+class ORPOConfig(RLConfigBase):
def __init__(
self,
- # ORPO-specific
- beta: float = 0.1, # Odds ratio coefficient
- # Training args
+ beta: float = 0.1,
output_dir: str = "./orpo_outputs",
learning_rate: float = 8e-6,
per_device_train_batch_size: int = 2,
@@ -192,7 +1207,10 @@ def __init__(
save_steps: int = 100,
max_seq_length: int = 2048,
max_prompt_length: int = 512,
- **kwargs
+ dataset_mode: Optional[str] = None,
+ chat_template: Optional[Any] = None,
+ auto_detect_dataset: bool = True,
+ **kwargs,
):
self.beta = beta
self.output_dir = output_dir
@@ -206,49 +1224,29 @@ def __init__(
self.save_steps = save_steps
self.max_seq_length = max_seq_length
self.max_prompt_length = max_prompt_length
+ self.dataset_mode = dataset_mode
+ self.chat_template = chat_template
+ self.auto_detect_dataset = auto_detect_dataset
+ self._set_remaining(kwargs)
- for key, value in kwargs.items():
- setattr(self, key, value)
-
- def to_dict(self) -> Dict[str, Any]:
- return {k: v for k, v in self.__dict__.items() if not k.startswith('_')}
-
-class GRPOConfig:
- """
- Configuration for Group Relative Policy Optimization training.
-
- GRPO is used by DeepSeek to train their R1 reasoning models.
- It replaces the value model with group statistics and uses custom
- reward functions.
-
- Supports loss types:
- - 'grpo': Standard GRPO
- - 'dr_grpo': Dr. GRPO (distilled)
- - 'dapo': DAPO variant
- - 'bnpo': BNPO variant
-
- Example:
- >>> config = GRPOConfig(
- ... loss_type='grpo',
- ... num_generations=4,
- ... learning_rate=1e-6,
- ... )
- """
+class GRPOConfig(RLConfigBase):
+ _NON_SERIALIZED_FIELDS = {"reward_fn", "reward_model", "value_model"}
def __init__(
self,
- # GRPO-specific
- loss_type: str = "grpo", # grpo, dr_grpo, dapo, bnpo
- beta: float = 0.04, # KL coefficient
- num_generations: int = 4, # Number of generations per prompt
+ loss_type: str = "grpo",
+ advantage_mode: str = "group_zscore",
+ beta: float = 0.04,
+ num_generations: int = 4,
temperature: float = 0.7,
max_completion_length: int = 512,
- # Reward function (custom callable)
reward_fn: Optional[Callable] = None,
- # Training args
+ reward_model: Optional[Any] = None,
+ value_model: Optional[Any] = None,
output_dir: str = "./grpo_outputs",
learning_rate: float = 1e-6,
+ seed: int = 0,
per_device_train_batch_size: int = 1,
gradient_accumulation_steps: int = 8,
num_train_epochs: int = 1,
@@ -257,16 +1255,64 @@ def __init__(
logging_steps: int = 1,
save_steps: int = 100,
max_seq_length: int = 2048,
- **kwargs
+ clip_epsilon: float = 0.2,
+ epsilon_low: Optional[float] = None,
+ epsilon_high: Optional[float] = None,
+ rollout_batch_size: Optional[int] = None,
+ scale_rewards: Optional[bool] = None,
+ reward_normalization: str = "none",
+ mask_truncated_completions: bool = False,
+ minibatch_reuse_steps: int = 1,
+ entropy_bonus: float = 0.0,
+ advantage_estimator: Optional[str] = None,
+ reward_source: str = "auto",
+ reward_sources: Optional[Any] = None,
+ kl_target: Optional[float] = None,
+ kl_penalty_mode: str = "kl",
+ eval_steps: Optional[int] = None,
+ eval_num_batches: Optional[int] = None,
+ eval_num_generations: Optional[int] = None,
+ generation_batch_size: Optional[int] = None,
+ score_chunk_size: Optional[int] = None,
+ precompute_reference_scores: bool = False,
+ dataset_mode: Optional[str] = None,
+ chat_template: Optional[Any] = None,
+ auto_detect_dataset: bool = True,
+ **kwargs,
):
self.loss_type = loss_type
- self.beta = beta
- self.num_generations = num_generations
+ resolved_advantage = self._pop_alias(
+ kwargs,
+ "baseline_mode",
+ "advantage_estimator",
+ default=advantage_estimator or advantage_mode,
+ )
+ self.advantage_estimator = resolved_advantage
+ self.advantage_mode = resolved_advantage
+ self.kl_beta = self._pop_alias(kwargs, "kl_beta", default=beta)
+ self.beta = self.kl_beta
+ self.kl_target = kl_target
+ self.kl_penalty_mode = kl_penalty_mode
+ self.num_generations = self._pop_alias(kwargs, "generations_per_prompt", default=num_generations)
self.temperature = temperature
self.max_completion_length = max_completion_length
self.reward_fn = reward_fn
+ self.reward_model = reward_model
+ self.value_model = value_model
+ self.reward_sources = self._normalize_reward_sources(
+ self._pop_alias(kwargs, "reward_sources", default=reward_sources),
+ reward_fn,
+ reward_model,
+ )
+ self.reward_source = reward_source
+ self.reward_normalization = reward_normalization
+ self.mask_truncated_completions = mask_truncated_completions
+ self.minibatch_reuse_steps = minibatch_reuse_steps
+ self.entropy_bonus = entropy_bonus
+ self.rollout_batch_size = rollout_batch_size
self.output_dir = output_dir
self.learning_rate = learning_rate
+ self.seed = seed
self.per_device_train_batch_size = per_device_train_batch_size
self.gradient_accumulation_steps = gradient_accumulation_steps
self.num_train_epochs = num_train_epochs
@@ -275,683 +1321,2813 @@ def __init__(
self.logging_steps = logging_steps
self.save_steps = save_steps
self.max_seq_length = max_seq_length
+ self.clip_epsilon = clip_epsilon
+ self.epsilon_low = clip_epsilon if epsilon_low is None else epsilon_low
+ self.epsilon_high = clip_epsilon if epsilon_high is None else epsilon_high
+ self.scale_rewards = scale_rewards
+ self.eval_steps = eval_steps
+ self.eval_num_batches = eval_num_batches
+ self.eval_num_generations = eval_num_generations
+ self.generation_batch_size = generation_batch_size
+ self.score_chunk_size = score_chunk_size
+ self.precompute_reference_scores = precompute_reference_scores
+ self.dataset_mode = dataset_mode
+ self.chat_template = chat_template
+ self.auto_detect_dataset = auto_detect_dataset
+ if self.loss_type == "dapo":
+ self.mask_truncated_completions = True
+ self.epsilon_high = 0.28 if epsilon_high is None else epsilon_high
+ if self.scale_rewards is None:
+ self.scale_rewards = self.loss_type != "dr_grpo"
+ self._set_remaining(kwargs)
+
+
+class PPOConfig(RLConfigBase):
+ _NON_SERIALIZED_FIELDS = {"reward_fn", "reward_model", "value_model"}
- for key, value in kwargs.items():
- setattr(self, key, value)
+ def __init__(
+ self,
+ output_dir: str = "./ppo_outputs",
+ learning_rate: float = 1e-6,
+ seed: int = 0,
+ value_learning_rate: Optional[float] = None,
+ per_device_train_batch_size: int = 1,
+ num_train_epochs: int = 1,
+ max_steps: int = -1,
+ logging_steps: int = 1,
+ save_steps: int = 100,
+ max_seq_length: int = 2048,
+ max_completion_length: int = 256,
+ num_generations: int = 4,
+ ppo_epochs: int = 2,
+ temperature: float = 0.7,
+ clip_epsilon: float = 0.2,
+ beta: float = 0.0,
+ gamma: float = 1.0,
+ gae_lambda: float = 1.0,
+ reward_fn: Optional[Callable] = None,
+ reward_model: Optional[Any] = None,
+ value_model: Optional[Any] = None,
+ normalize_advantages: bool = True,
+ rollout_batch_size: Optional[int] = None,
+ reward_normalization: str = "none",
+ mask_truncated_completions: bool = False,
+ minibatch_reuse_steps: Optional[int] = None,
+ entropy_bonus: float = 0.0,
+ advantage_estimator: str = "gae",
+ reward_source: str = "auto",
+ reward_sources: Optional[Any] = None,
+ kl_target: Optional[float] = None,
+ kl_penalty_mode: str = "kl",
+ eval_steps: Optional[int] = None,
+ eval_num_batches: Optional[int] = None,
+ eval_num_generations: Optional[int] = None,
+ generation_batch_size: Optional[int] = None,
+ score_chunk_size: Optional[int] = None,
+ precompute_reference_scores: bool = False,
+ dataset_mode: Optional[str] = None,
+ chat_template: Optional[Any] = None,
+ auto_detect_dataset: bool = True,
+ **kwargs,
+ ):
+ reuse_steps = ppo_epochs if minibatch_reuse_steps is None else minibatch_reuse_steps
+ self.output_dir = output_dir
+ self.learning_rate = learning_rate
+ self.seed = seed
+ self.value_learning_rate = learning_rate if value_learning_rate is None else value_learning_rate
+ self.per_device_train_batch_size = per_device_train_batch_size
+ self.num_train_epochs = num_train_epochs
+ self.max_steps = max_steps
+ self.logging_steps = logging_steps
+ self.save_steps = save_steps
+ self.max_seq_length = max_seq_length
+ self.max_completion_length = max_completion_length
+ self.num_generations = self._pop_alias(kwargs, "generations_per_prompt", default=num_generations)
+ self.minibatch_reuse_steps = reuse_steps
+ self.ppo_epochs = reuse_steps
+ self.temperature = temperature
+ self.clip_epsilon = clip_epsilon
+ self.kl_beta = self._pop_alias(kwargs, "kl_beta", default=beta)
+ self.beta = self.kl_beta
+ self.kl_target = kl_target
+ self.kl_penalty_mode = kl_penalty_mode
+ self.gamma = gamma
+ self.gae_lambda = gae_lambda
+ self.reward_fn = reward_fn
+ self.reward_model = reward_model
+ self.value_model = value_model
+ self.reward_sources = self._normalize_reward_sources(
+ self._pop_alias(kwargs, "reward_sources", default=reward_sources),
+ reward_fn,
+ reward_model,
+ )
+ self.reward_source = reward_source
+ self.reward_normalization = reward_normalization
+ self.mask_truncated_completions = mask_truncated_completions
+ self.entropy_bonus = entropy_bonus
+ self.advantage_estimator = advantage_estimator
+ self.rollout_batch_size = rollout_batch_size
+ self.normalize_advantages = normalize_advantages
+ self.eval_steps = eval_steps
+ self.eval_num_batches = eval_num_batches
+ self.eval_num_generations = eval_num_generations
+ self.generation_batch_size = generation_batch_size
+ self.score_chunk_size = score_chunk_size
+ self.precompute_reference_scores = precompute_reference_scores
+ self.dataset_mode = dataset_mode
+ self.chat_template = chat_template
+ self.auto_detect_dataset = auto_detect_dataset
+ self._set_remaining(kwargs)
+
+
+class OnlineDPOConfig(RLConfigBase):
+ _NON_SERIALIZED_FIELDS = {"reward_fn", "reward_model"}
- def to_dict(self) -> Dict[str, Any]:
- return {k: v for k, v in self.__dict__.items()
- if not k.startswith('_') and k != 'reward_fn'}
+ def __init__(
+ self,
+ beta: float = 0.1,
+ label_smoothing: float = 0.0,
+ output_dir: str = "./online_dpo_outputs",
+ learning_rate: float = 5e-7,
+ seed: int = 0,
+ per_device_train_batch_size: int = 1,
+ num_train_epochs: int = 1,
+ max_steps: int = -1,
+ logging_steps: int = 1,
+ save_steps: int = 100,
+ max_seq_length: int = 2048,
+ max_completion_length: int = 256,
+ num_generations: int = 4,
+ temperature: float = 0.7,
+ reward_fn: Optional[Callable] = None,
+ reward_model: Optional[Any] = None,
+ rollout_batch_size: Optional[int] = None,
+ reward_normalization: str = "none",
+ mask_truncated_completions: bool = False,
+ minibatch_reuse_steps: int = 1,
+ entropy_bonus: float = 0.0,
+ reward_source: str = "auto",
+ reward_sources: Optional[Any] = None,
+ kl_target: Optional[float] = None,
+ kl_penalty_mode: str = "kl",
+ eval_steps: Optional[int] = None,
+ eval_num_batches: Optional[int] = None,
+ eval_num_generations: Optional[int] = None,
+ generation_batch_size: Optional[int] = None,
+ score_chunk_size: Optional[int] = None,
+ precompute_reference_scores: bool = False,
+ dataset_mode: Optional[str] = None,
+ chat_template: Optional[Any] = None,
+ auto_detect_dataset: bool = True,
+ **kwargs,
+ ):
+ self.kl_beta = self._pop_alias(kwargs, "kl_beta", default=beta)
+ self.beta = self.kl_beta
+ self.kl_target = kl_target
+ self.kl_penalty_mode = kl_penalty_mode
+ self.label_smoothing = label_smoothing
+ self.output_dir = output_dir
+ self.learning_rate = learning_rate
+ self.seed = seed
+ self.per_device_train_batch_size = per_device_train_batch_size
+ self.num_train_epochs = num_train_epochs
+ self.max_steps = max_steps
+ self.logging_steps = logging_steps
+ self.save_steps = save_steps
+ self.max_seq_length = max_seq_length
+ self.max_completion_length = max_completion_length
+ self.num_generations = self._pop_alias(kwargs, "generations_per_prompt", default=num_generations)
+ self.temperature = temperature
+ self.reward_fn = reward_fn
+ self.reward_model = reward_model
+ self.reward_sources = self._normalize_reward_sources(
+ self._pop_alias(kwargs, "reward_sources", default=reward_sources),
+ reward_fn,
+ reward_model,
+ )
+ self.reward_source = reward_source
+ self.reward_normalization = reward_normalization
+ self.mask_truncated_completions = mask_truncated_completions
+ self.minibatch_reuse_steps = minibatch_reuse_steps
+ self.entropy_bonus = entropy_bonus
+ self.rollout_batch_size = rollout_batch_size
+ self.eval_steps = eval_steps
+ self.eval_num_batches = eval_num_batches
+ self.eval_num_generations = eval_num_generations
+ self.generation_batch_size = generation_batch_size
+ self.score_chunk_size = score_chunk_size
+ self.precompute_reference_scores = precompute_reference_scores
+ self.dataset_mode = dataset_mode
+ self.chat_template = chat_template
+ self.auto_detect_dataset = auto_detect_dataset
+ self._set_remaining(kwargs)
+
+
+class KTOConfig(RLConfigBase):
+ def __init__(
+ self,
+ beta: float = 0.1,
+ output_dir: str = "./kto_outputs",
+ learning_rate: float = 5e-7,
+ per_device_train_batch_size: int = 1,
+ num_train_epochs: int = 1,
+ max_steps: int = 100,
+ logging_steps: int = 10,
+ save_steps: int = 100,
+ max_seq_length: int = 2048,
+ dataset_mode: Optional[str] = None,
+ chat_template: Optional[Any] = None,
+ auto_detect_dataset: bool = True,
+ **kwargs,
+ ):
+ self.beta = beta
+ self.output_dir = output_dir
+ self.learning_rate = learning_rate
+ self.per_device_train_batch_size = per_device_train_batch_size
+ self.num_train_epochs = num_train_epochs
+ self.max_steps = max_steps
+ self.logging_steps = logging_steps
+ self.save_steps = save_steps
+ self.max_seq_length = max_seq_length
+ self.dataset_mode = dataset_mode
+ self.chat_template = chat_template
+ self.auto_detect_dataset = auto_detect_dataset
+ self._set_remaining(kwargs)
-class DPOTrainer:
- """
- Direct Preference Optimization Trainer.
-
- DPO trains models on preference data (chosen vs rejected responses)
- without requiring a separate reward model.
-
- Compatible with TRL's DPOTrainer API.
- Now with PROPER DPO loss implementation!
-
- Example:
- >>> from mlx_tune import FastLanguageModel, DPOTrainer, DPOConfig
- >>>
- >>> model, tokenizer = FastLanguageModel.from_pretrained(...)
- >>> model = FastLanguageModel.get_peft_model(model, r=16)
- >>>
- >>> # Preference dataset with chosen/rejected pairs
- >>> dataset = [
- ... {"prompt": "...", "chosen": "...", "rejected": "..."},
- ... ]
- >>>
- >>> trainer = DPOTrainer(
- ... model=model,
- ... ref_model=None, # Uses stop_gradient by default
- ... train_dataset=dataset,
- ... tokenizer=tokenizer,
- ... args=DPOConfig(beta=0.1),
- ... )
- >>> trainer.train()
- """
+class SimPOConfig(RLConfigBase):
+ def __init__(
+ self,
+ gamma: float = 0.5,
+ beta: float = 2.0,
+ output_dir: str = "./simpo_outputs",
+ learning_rate: float = 5e-7,
+ per_device_train_batch_size: int = 1,
+ num_train_epochs: int = 1,
+ max_steps: int = 100,
+ logging_steps: int = 10,
+ save_steps: int = 100,
+ max_seq_length: int = 2048,
+ dataset_mode: Optional[str] = None,
+ chat_template: Optional[Any] = None,
+ auto_detect_dataset: bool = True,
+ **kwargs,
+ ):
+ self.gamma = gamma
+ self.beta = beta
+ self.output_dir = output_dir
+ self.learning_rate = learning_rate
+ self.per_device_train_batch_size = per_device_train_batch_size
+ self.num_train_epochs = num_train_epochs
+ self.max_steps = max_steps
+ self.logging_steps = logging_steps
+ self.save_steps = save_steps
+ self.max_seq_length = max_seq_length
+ self.dataset_mode = dataset_mode
+ self.chat_template = chat_template
+ self.auto_detect_dataset = auto_detect_dataset
+ self._set_remaining(kwargs)
+
+
+def _resolve_reward_evaluator(
+ reward_model: Optional[Any],
+ reward_fn: Optional[Any],
+ reward_sources: Optional[List[Any]] = None,
+) -> Any:
+ if reward_sources:
+ return public_create_reward_function(rewards=reward_sources)
+ if reward_model is not None:
+ return reward_model
+ if reward_fn is not None:
+ return reward_fn
+ return None
+
+
+def _normalize_reward_values(
+ rewards: mx.array,
+ prompt_group_indices: mx.array,
+ mode: str,
+) -> mx.array:
+ if mode in {"none", "off", ""}:
+ return rewards
+ if mode not in {"center", "zscore"}:
+ raise ValueError(f"Unsupported reward normalization mode: {mode}")
+ reward_values = rewards.tolist()
+ groups = prompt_group_indices.tolist()
+ adjusted = [0.0] * len(reward_values)
+ grouped: Dict[int, List[int]] = {}
+ for index, group in enumerate(groups):
+ grouped.setdefault(int(group), []).append(index)
+ for positions in grouped.values():
+ group_rewards = mx.array([reward_values[position] for position in positions], dtype=mx.float32)
+ mean_value = mx.mean(group_rewards)
+ centered = group_rewards - mean_value
+ if mode == "zscore":
+ std_value = mx.std(group_rewards)
+ centered = centered if float(std_value.item()) < 1e-6 else centered / std_value
+ for offset, position in enumerate(positions):
+ adjusted[position] = float(centered[offset].item())
+ return mx.array(adjusted, dtype=mx.float32)
+
+
+def _apply_truncation_mask_to_rollout(rollout_batch: RolloutBatch) -> RolloutBatch:
+ if rollout_batch.truncation_flags is None or not bool(mx.any(rollout_batch.truncation_flags).item()):
+ return rollout_batch
+
+ keep_mask = (~rollout_batch.truncation_flags).astype(mx.float32)
+ zero_lengths = mx.where(rollout_batch.truncation_flags, mx.zeros_like(rollout_batch.completion_lengths), rollout_batch.completion_lengths)
+ rollout_batch.completion_lengths = zero_lengths
+ rollout_batch.policy_eval.completion_lengths = zero_lengths
+ rollout_batch.rollout_logprobs = rollout_batch.rollout_logprobs * keep_mask
+ rollout_batch.policy_eval.rollout_logprobs = rollout_batch.rollout_logprobs
+ rollout_batch.policy_eval.old_logprobs = rollout_batch.rollout_logprobs
+ if rollout_batch.old_logprobs is not None:
+ rollout_batch.old_logprobs = rollout_batch.old_logprobs * keep_mask
+ if rollout_batch.policy_eval.old_token_logprobs is not None:
+ rollout_batch.policy_eval.old_token_logprobs = rollout_batch.policy_eval.old_token_logprobs * keep_mask[:, None]
+ if rollout_batch.reference_logprobs is not None:
+ rollout_batch.reference_logprobs = rollout_batch.reference_logprobs * keep_mask
+ rollout_batch.policy_eval.reference_logprobs = rollout_batch.reference_logprobs
+ if rollout_batch.value_predictions is not None:
+ rollout_batch.value_predictions = rollout_batch.value_predictions * keep_mask
+ rollout_batch.policy_eval.value_predictions = rollout_batch.value_predictions
+ if rollout_batch.rewards is not None:
+ rollout_batch.rewards = rollout_batch.rewards * keep_mask
+ if rollout_batch.returns is not None:
+ rollout_batch.returns = rollout_batch.returns * keep_mask
+ rollout_batch.policy_eval.returns = rollout_batch.returns
+ if rollout_batch.advantages is not None:
+ rollout_batch.advantages = rollout_batch.advantages * keep_mask
+ rollout_batch.policy_eval.advantages = rollout_batch.advantages
+ return rollout_batch
+
+
+def _prepare_on_policy_samples(
+ dataset: Any,
+ tokenizer: Any,
+ config: Any,
+) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], Optional[str]]:
+ prompt_samples: List[Dict[str, Any]] = []
+ rollout_samples: List[Dict[str, Any]] = []
+ prepared = prepare_rl_dataset(
+ dataset,
+ mode=config.dataset_mode,
+ tokenizer=tokenizer,
+ chat_template=getattr(config, "chat_template", None),
+ )
+ for sample_index, sample in enumerate(prepared):
+ if prepared.mode == "prompt":
+ prompt = sample.get("prompt", "")
+ if not prompt:
+ continue
+ prompt_samples.append(
+ {
+ "sample_index": sample_index,
+ "prompt": prompt,
+ "prompt_ids": _encode_text(tokenizer, prompt),
+ "reward_context": sample.get("reward_context", prompt),
+ }
+ )
+ elif prepared.mode == "rollout":
+ rollout_samples.append(
+ {
+ "sample_index": sample_index,
+ "prompt": sample["prompt"],
+ "completion": sample["completion"],
+ "reward": sample.get("reward"),
+ "reward_context": sample.get("reward_context", sample["completion"]),
+ }
+ )
+ return prompt_samples, rollout_samples, prepared.mode
+
+
+def _next_cursor_batch(
+ samples: List[Dict[str, Any]],
+ count: int,
+ cursor: int,
+ algorithm: str,
+) -> Tuple[List[Dict[str, Any]], int]:
+ if not samples:
+ raise ValueError(f"{algorithm} training dataset is empty.")
+
+ batch: List[Dict[str, Any]] = []
+ next_cursor = cursor
+ for _ in range(max(1, count)):
+ batch.append(samples[next_cursor])
+ next_cursor = (next_cursor + 1) % len(samples)
+ return batch, next_cursor
+
+
+def _rollout_score_batch_size(trainer: Any, num_generations: Optional[int] = None) -> int:
+ generations = trainer.num_generations if num_generations is None else num_generations
+ return max(1, trainer.rollout_batch_size * max(1, generations))
+
+
+def _fixed_rollout_cache_dataset_fingerprint(samples: List[Dict[str, Any]]) -> str:
+ return _hash_payload(
+ [
+ {
+ "sample_index": sample.get("sample_index"),
+ "prompt": sample.get("prompt"),
+ "completion": sample.get("completion"),
+ "reward_context": sample.get("reward_context"),
+ }
+ for sample in samples
+ ]
+ )
+
+
+def _preference_cache_dataset_fingerprint(samples: List[Dict[str, Any]]) -> str:
+ return _hash_payload(
+ [
+ {
+ "sample_index": sample.get("sample_index"),
+ "chosen_ids": sample.get("chosen_ids"),
+ "rejected_ids": sample.get("rejected_ids"),
+ }
+ for sample in samples
+ ]
+ )
+
+
+def _fixed_rollout_reference_cache_valid(
+ trainer: Any,
+ cache_key: str,
+ samples: List[Dict[str, Any]],
+) -> bool:
+ cached_scores = trainer.runtime_cache_arrays.get(cache_key)
+ cached_indices = trainer.runtime_cache_arrays.get(f"{cache_key}.sample_indices")
+ if cached_scores is None or cached_indices is None or cached_scores.shape[0] != len(samples):
+ return False
+ cache_info = trainer.cache_metadata.get("reference_score_caches", {}).get(cache_key, {})
+ return (
+ cached_indices.tolist() == [sample["sample_index"] for sample in samples]
+ and cache_info.get("dataset_fingerprint") == _fixed_rollout_cache_dataset_fingerprint(samples)
+ )
+
+
+def _ensure_fixed_rollout_reference_cache(
+ trainer: Any,
+ cache_key: str,
+ samples: List[Dict[str, Any]],
+) -> Optional[mx.array]:
+ if not getattr(trainer.config, "precompute_reference_scores", False):
+ return None
+ if trainer.reference_policy is None or not samples:
+ return None
+ if _fixed_rollout_reference_cache_valid(trainer, cache_key, samples):
+ return trainer.runtime_cache_arrays[cache_key]
+
+ prompt_ids = [_encode_text(trainer.tokenizer, sample["prompt"]) for sample in samples]
+ completion_ids = [
+ _encode_text(trainer.tokenizer, sample["completion"], add_special_tokens=False)
+ for sample in samples
+ ]
+ full_sequences = [prompt + completion for prompt, completion in zip(prompt_ids, completion_ids)]
+ prompt_lengths = [len(prompt) for prompt in prompt_ids]
+ completion_lengths = [len(completion) for completion in completion_ids]
+ eval_batch = make_policy_eval_batch(
+ full_sequences,
+ pad_id=_pad_token_id(trainer.tokenizer),
+ mode="completion",
+ prompt_lengths=prompt_lengths,
+ completion_lengths=completion_lengths,
+ sample_indices=mx.array([sample["sample_index"] for sample in samples]),
+ )
+ reference_logprobs = score_policy_in_chunks(
+ _actual_model(trainer.reference_policy.model),
+ eval_batch,
+ batch_size=_rollout_score_batch_size(trainer),
+ token_budget=getattr(trainer.config, "score_chunk_size", None),
+ mode="completion",
+ ).summed_logprobs
+ reference_logprobs = mx.stop_gradient(reference_logprobs.astype(mx.float32))
+ trainer.runtime_cache_arrays[cache_key] = reference_logprobs
+ trainer.runtime_cache_arrays[f"{cache_key}.sample_indices"] = mx.array(
+ [sample["sample_index"] for sample in samples],
+ dtype=mx.int32,
+ )
+ trainer.cache_metadata.setdefault("reference_score_caches", {})[cache_key] = {
+ "num_samples": len(samples),
+ "sampling_config_fingerprint": trainer._sampling_config_fingerprint(),
+ "dataset_fingerprint": _fixed_rollout_cache_dataset_fingerprint(samples),
+ }
+ return reference_logprobs
+
+
+def _preference_reference_cache_valid(
+ trainer: Any,
+ cache_key: str,
+ samples: List[Dict[str, Any]],
+) -> bool:
+ chosen = trainer.runtime_cache_arrays.get(f"{cache_key}.chosen")
+ rejected = trainer.runtime_cache_arrays.get(f"{cache_key}.rejected")
+ sample_indices = trainer.runtime_cache_arrays.get(f"{cache_key}.sample_indices")
+ if chosen is None or rejected is None or sample_indices is None:
+ return False
+ if chosen.shape[0] != len(samples) or rejected.shape[0] != len(samples):
+ return False
+ cache_info = trainer.cache_metadata.get("reference_score_caches", {}).get(cache_key, {})
+ return (
+ sample_indices.tolist() == [sample["sample_index"] for sample in samples]
+ and cache_info.get("dataset_fingerprint") == _preference_cache_dataset_fingerprint(samples)
+ )
+
+
+def _ensure_preference_reference_cache(
+ trainer: Any,
+ cache_key: str,
+ samples: List[Dict[str, Any]],
+) -> Tuple[Optional[mx.array], Optional[mx.array]]:
+ if not getattr(trainer.config, "precompute_reference_scores", False):
+ return None, None
+ if trainer.reference_policy is None or not samples:
+ return None, None
+ if _preference_reference_cache_valid(trainer, cache_key, samples):
+ return (
+ trainer.runtime_cache_arrays[f"{cache_key}.chosen"],
+ trainer.runtime_cache_arrays[f"{cache_key}.rejected"],
+ )
+
+ preference_batch = make_preference_batch(
+ chosen_sequences=[sample["chosen_ids"] for sample in samples],
+ rejected_sequences=[sample["rejected_ids"] for sample in samples],
+ pad_id=_pad_token_id(trainer.tokenizer),
+ sample_indices=[sample["sample_index"] for sample in samples],
+ )
+ reference_model = _actual_model(trainer.reference_policy.model)
+ chosen_reference = score_policy_in_chunks(
+ reference_model,
+ preference_batch.chosen,
+ batch_size=max(1, trainer.batch_size),
+ token_budget=getattr(trainer.config, "score_chunk_size", None),
+ mode="sequence",
+ ).summed_logprobs
+ rejected_reference = score_policy_in_chunks(
+ reference_model,
+ preference_batch.rejected,
+ batch_size=max(1, trainer.batch_size),
+ token_budget=getattr(trainer.config, "score_chunk_size", None),
+ mode="sequence",
+ ).summed_logprobs
+ trainer.runtime_cache_arrays[f"{cache_key}.chosen"] = mx.stop_gradient(chosen_reference.astype(mx.float32))
+ trainer.runtime_cache_arrays[f"{cache_key}.rejected"] = mx.stop_gradient(rejected_reference.astype(mx.float32))
+ trainer.runtime_cache_arrays[f"{cache_key}.sample_indices"] = mx.array(
+ [sample["sample_index"] for sample in samples],
+ dtype=mx.int32,
+ )
+ trainer.cache_metadata.setdefault("reference_score_caches", {})[cache_key] = {
+ "num_samples": len(samples),
+ "sampling_config_fingerprint": trainer._sampling_config_fingerprint(),
+ "dataset_fingerprint": _preference_cache_dataset_fingerprint(samples),
+ }
+ return (
+ trainer.runtime_cache_arrays[f"{cache_key}.chosen"],
+ trainer.runtime_cache_arrays[f"{cache_key}.rejected"],
+ )
+
+
+def _collect_fixed_rollout_batch(
+ trainer: Any,
+ samples: List[Dict[str, Any]],
+ cached_reference_logprobs: Optional[mx.array] = None,
+) -> RolloutBatch:
+ prompt_ids = []
+ completion_ids = []
+ truncation_flags = []
+ max_seq_length = getattr(trainer, "max_seq_length", getattr(trainer.config, "max_seq_length", None))
+ max_completion_length = getattr(
+ trainer,
+ "max_completion_length",
+ getattr(trainer.config, "max_completion_length", None),
+ )
+ for sample in samples:
+ raw_prompt_ids = _encode_text(trainer.tokenizer, sample["prompt"])
+ raw_completion_ids = _encode_text(trainer.tokenizer, sample["completion"], add_special_tokens=False)
+ capped_prompt_ids, capped_completion_ids, truncated = cap_prompt_and_completion_lengths(
+ raw_prompt_ids,
+ raw_completion_ids,
+ max_seq_length=max_seq_length,
+ max_completion_length=max_completion_length,
+ )
+ prompt_ids.append(capped_prompt_ids)
+ completion_ids.append(capped_completion_ids)
+ truncation_flags.append(bool(truncated))
+ full_sequences = [prompt + completion for prompt, completion in zip(prompt_ids, completion_ids)]
+ prompt_lengths = [len(prompt) for prompt in prompt_ids]
+ completion_lengths = [len(completion) for completion in completion_ids]
+ prompt_group_map: Dict[str, int] = {}
+ grouped_indices = []
+ for sample in samples:
+ grouped_indices.append(prompt_group_map.setdefault(sample["prompt"], len(prompt_group_map)))
+ eval_batch = make_policy_eval_batch(
+ full_sequences,
+ pad_id=_pad_token_id(trainer.tokenizer),
+ mode="completion",
+ prompt_lengths=prompt_lengths,
+ completion_lengths=completion_lengths,
+ sample_indices=mx.array([sample["sample_index"] for sample in samples]),
+ prompt_group_indices=mx.array(grouped_indices),
+ reference_logprobs=cached_reference_logprobs,
+ )
+ scored = score_policy_in_chunks(
+ _actual_model(trainer.model),
+ eval_batch,
+ batch_size=_rollout_score_batch_size(trainer),
+ token_budget=getattr(trainer.config, "score_chunk_size", None),
+ mode="completion",
+ )
+ reward_values = [float(sample.get("reward", 0.0) or 0.0) for sample in samples]
+ rollout_batch = RolloutBatch(
+ prompt_ids=prompt_ids,
+ prompt_lengths=mx.array(prompt_lengths),
+ completion_ids=completion_ids,
+ completion_lengths=mx.array(completion_lengths),
+ prompt_texts=[sample["prompt"] for sample in samples],
+ original_prompt_texts=[sample["prompt"] for sample in samples],
+ completion_texts=[sample["completion"] for sample in samples],
+ reward_contexts=[sample.get("reward_context") for sample in samples],
+ sampled_token_logprobs=mx.zeros(
+ (len(samples), max(completion_lengths) if completion_lengths else 0),
+ dtype=mx.float32,
+ ),
+ rollout_logprobs=scored.summed_logprobs,
+ eos_flags=mx.array([True] * len(samples)),
+ truncation_flags=mx.array(truncation_flags),
+ prompt_group_indices=mx.array(grouped_indices),
+ policy_eval=scored,
+ sample_indices=mx.array([sample["sample_index"] for sample in samples]),
+ old_logprobs=scored.summed_logprobs,
+ rewards=mx.array(reward_values, dtype=mx.float32),
+ reference_logprobs=cached_reference_logprobs,
+ )
+ rollout_batch.policy_eval.old_logprobs = scored.summed_logprobs
+ rollout_batch.policy_eval.old_token_logprobs = scored.token_logprobs
+ if cached_reference_logprobs is not None:
+ rollout_batch.policy_eval.reference_logprobs = cached_reference_logprobs
+ return rollout_batch
+
+
+class RewardTrainer(_RLTrainerBase):
+ algorithm = "reward"
+
+ def __init__(
+ self,
+ model: Any,
+ train_dataset: Any,
+ tokenizer: Optional[Any] = None,
+ args: Optional[RewardConfig] = None,
+ reward_model: Optional[RewardModel] = None,
+ use_native: bool = True,
+ **kwargs,
+ ):
+ self.reward_model = model if isinstance(model, RewardModel) else reward_model or build_reward_model(model)
+ self.model = self.reward_model.base_model
+ self.train_dataset = train_dataset
+ self.tokenizer = tokenizer or getattr(self.reward_model, "tokenizer", None) or getattr(self.model, "tokenizer", None)
+ self.use_native = use_native and HAS_NATIVE_TRAINING
+ self.config = args or RewardConfig()
+ self.output_dir = Path(self.config.output_dir)
+ self.learning_rate = self.config.learning_rate
+ self.batch_size = self.config.per_device_train_batch_size
+ self.max_steps = self.config.max_steps
+ self.max_seq_length = self.config.max_seq_length
+ self.pairwise_margin = self.config.pairwise_margin
+ self.regression_loss_type = self.config.regression_loss_type
+ self.logging_steps = self.config.logging_steps
+ self.save_steps = self.config.save_steps
+ dataset_size = len(train_dataset) if hasattr(train_dataset, "__len__") else 100
+ self.iters = self.max_steps if self.max_steps > 0 else max(
+ 1, (dataset_size // max(1, self.batch_size)) * self.config.num_train_epochs
+ )
+ self.output_dir.mkdir(parents=True, exist_ok=True)
+ self.train_samples: List[Dict[str, Any]] = []
+ self.dataset_type: Optional[str] = None
+ self._train_target = _ScalarRoleTrainTarget(
+ _actual_model(self.reward_model.base_model),
+ self.reward_model.head,
+ )
+ self._init_native_state()
+
+ def _primary_role_name(self) -> str:
+ return "reward_model"
+
+ def _primary_role_weight_format(self) -> Any:
+ reward_base = getattr(self.reward_model, "base_model", None)
+ return {
+ "backbone": "weights.safetensors",
+ "head": "head.safetensors",
+ "adapters": "adapters.safetensors"
+ if hasattr(reward_base, "has_adapters") and reward_base.has_adapters()
+ else None,
+ }
+
+ def _save_primary_role(self, checkpoint_dir: Optional[Path] = None) -> None:
+ self.reward_model.save_pretrained(str(self._role_dir("reward_model", checkpoint_dir)))
+
+ def _load_primary_role(self, checkpoint_dir: Path) -> None:
+ role_dir = self._role_dir("reward_model", checkpoint_dir)
+ if role_dir.exists():
+ self.reward_model.load_pretrained(str(role_dir))
+
+ def _prepare_training_samples(self) -> None:
+ records = list(self.train_dataset)
+ dataset_mode = self.config.dataset_mode
+ if dataset_mode is None and records:
+ first_sample = records[0]
+ dataset_mode = "reward_pairwise" if {"chosen", "rejected"} <= set(first_sample.keys()) else "reward_scalar"
+ prepared = prepare_rl_dataset(
+ records,
+ mode=dataset_mode,
+ tokenizer=self.tokenizer,
+ chat_template=getattr(self.config, "chat_template", None),
+ )
+ if not prepared:
+ raise ValueError("RewardTrainer requires pairwise or scalar reward samples.")
+ if prepared.mode not in {"reward_scalar", "reward_pairwise"}:
+ raise ValueError("RewardTrainer requires reward_scalar or reward_pairwise datasets.")
+ normalized = list(prepared)
+ sample_types = {sample["type"] for sample in normalized}
+ if len(sample_types) != 1:
+ raise ValueError("RewardTrainer currently requires all reward samples to use the same supervision type.")
+ self.dataset_type = next(iter(sample_types))
+ self.train_samples = []
+ for sample in normalized:
+ if self.dataset_type == "pairwise":
+ self.train_samples.append(
+ _tokenize_reward_pairwise_sample(self.tokenizer, sample, self.max_seq_length)
+ )
+ else:
+ if "score" not in sample:
+ raise ValueError("RewardTrainer scalar samples require a score field.")
+ self.train_samples.append(
+ _tokenize_reward_scalar_sample(self.tokenizer, sample, self.max_seq_length)
+ )
+
+ def _build_pairwise_batch(self, samples: List[Dict[str, Any]]) -> Dict[str, mx.array]:
+ pad_id = _pad_token_id(self.tokenizer)
+ chosen_ids, chosen_lengths = _pad_sequences([sample["chosen_ids"] for sample in samples], pad_id)
+ rejected_ids, rejected_lengths = _pad_sequences([sample["rejected_ids"] for sample in samples], pad_id)
+ return {
+ "chosen_ids": chosen_ids,
+ "rejected_ids": rejected_ids,
+ "chosen_lengths": chosen_lengths,
+ "rejected_lengths": rejected_lengths,
+ "chosen_prompt_lengths": mx.array([sample["chosen_prompt_length"] for sample in samples]),
+ "rejected_prompt_lengths": mx.array([sample["rejected_prompt_length"] for sample in samples]),
+ "chosen_completion_lengths": mx.array([sample["chosen_completion_length"] for sample in samples]),
+ "rejected_completion_lengths": mx.array([sample["rejected_completion_length"] for sample in samples]),
+ }
+
+ def _build_scalar_batch(self, samples: List[Dict[str, Any]]) -> Dict[str, mx.array]:
+ pad_id = _pad_token_id(self.tokenizer)
+ input_ids, lengths = _pad_sequences([sample["ids"] for sample in samples], pad_id)
+ return {
+ "input_ids": input_ids,
+ "lengths": lengths,
+ "prompt_lengths": mx.array([sample["prompt_length"] for sample in samples]),
+ "completion_lengths": mx.array([sample["completion_length"] for sample in samples]),
+ "targets": mx.array([sample["score"] for sample in samples], dtype=mx.float32),
+ }
+
+ def train(self, resume_from_checkpoint: Optional[str] = None):
+ if not self.use_native:
+ raise ValueError("RewardTrainer requires native MLX training support.")
+ return self._train_native(resume_from_checkpoint=resume_from_checkpoint)
+
+ def _train_native(self, resume_from_checkpoint: Optional[str] = None):
+ self._apply_lora_if_needed()
+ self._prepare_training_samples()
+
+ optimizer = self._optimizer_for_training()
+ self.optimizer = optimizer
+ self.optimizers = {self._primary_optimizer_name(): optimizer}
+
+ if resume_from_checkpoint is not None:
+ self.load_state(optimizer=optimizer, checkpoint_dir=Path(resume_from_checkpoint))
+
+ if self.dataset_type == "pairwise":
+ def loss_fn(_, batch):
+ loss, outputs = reward_model_pairwise_loss(
+ self.reward_model,
+ chosen_input_ids=batch["chosen_ids"],
+ rejected_input_ids=batch["rejected_ids"],
+ chosen_sequence_lengths=batch["chosen_lengths"],
+ rejected_sequence_lengths=batch["rejected_lengths"],
+ chosen_prompt_lengths=batch["chosen_prompt_lengths"],
+ rejected_prompt_lengths=batch["rejected_prompt_lengths"],
+ chosen_completion_lengths=batch["chosen_completion_lengths"],
+ rejected_completion_lengths=batch["rejected_completion_lengths"],
+ margin=self.pairwise_margin,
+ )
+ return loss, outputs
+ else:
+ def loss_fn(_, batch):
+ loss, predictions = reward_model_regression_loss(
+ self.reward_model,
+ input_ids=batch["input_ids"],
+ sequence_lengths=batch["lengths"],
+ targets=batch["targets"],
+ prompt_lengths=batch["prompt_lengths"],
+ completion_lengths=batch["completion_lengths"],
+ loss_type=self.regression_loss_type,
+ )
+ return loss, predictions
+
+ value_and_grad = nn.value_and_grad(self._train_target, lambda modules, batch: loss_fn(modules, batch)[0])
+ running_loss = 0.0
+ last_loss = None
+
+ while self.global_step < self.iters:
+ batch_samples = self._next_samples(self.train_samples)
+ batch = (
+ self._build_pairwise_batch(batch_samples)
+ if self.dataset_type == "pairwise"
+ else self._build_scalar_batch(batch_samples)
+ )
+ loss, grads = value_and_grad(self._train_target, batch)
+ optimizer.update(self._train_target, grads)
+ mx.eval(self._train_target.parameters(), optimizer.state)
+
+ last_loss = float(loss.item())
+ metric_payload: Dict[str, Any] = {"loss": last_loss}
+ if self.dataset_type == "pairwise":
+ _, outputs = loss_fn(self._train_target, batch)
+ metric_payload["ranking_accuracy"] = pairwise_ranking_accuracy(
+ outputs["chosen_scores"],
+ outputs["rejected_scores"],
+ )
+ else:
+ _, predictions = loss_fn(self._train_target, batch)
+ metric_payload.update(
+ scalar_loss_metrics(loss, predictions, batch["targets"])
+ )
+
+ running_loss += last_loss
+ self.global_step += 1
+ self._record_metric(**metric_payload)
+
+ if self.global_step % self.logging_steps == 0:
+ print(
+ f"Reward step {self.global_step}/{self.iters} | "
+ f"loss={running_loss / self.logging_steps:.4f}"
+ )
+ running_loss = 0.0
+
+ if self.global_step % self.save_steps == 0:
+ self.save_state(optimizer=optimizer)
+
+ self.save_state(optimizer=optimizer)
+ return {
+ "status": "success",
+ "global_step": self.global_step,
+ "final_loss": last_loss,
+ "reward_model_path": str(self._role_dir("reward_model")),
+ }
+
+
+class DPOTrainer(_RLTrainerBase):
+ algorithm = "dpo"
+ requires_reference_policy = True
def __init__(
self,
model: Any,
train_dataset: Any,
ref_model: Optional[Any] = None,
+ reward_model: Optional[Any] = None,
+ value_model: Optional[Any] = None,
tokenizer: Optional[Any] = None,
args: Optional[DPOConfig] = None,
use_native: bool = True,
- **kwargs
+ **kwargs,
):
self.model = model
self.ref_model = ref_model
+ self.reward_model = reward_model
+ self.value_model = value_model
self.train_dataset = train_dataset
- self.tokenizer = tokenizer or getattr(model, 'tokenizer', None)
+ self.tokenizer = tokenizer or getattr(model, "tokenizer", None)
self.use_native = use_native and HAS_NATIVE_TRAINING
-
- # Extract config
- if args is None:
- args = DPOConfig()
-
- self.config = args
- self.beta = args.beta
- self.loss_type = args.loss_type
- self.label_smoothing = args.label_smoothing
- self.output_dir = Path(args.output_dir)
- self.learning_rate = args.learning_rate
- self.batch_size = args.per_device_train_batch_size
- self.max_steps = args.max_steps
- self.max_seq_length = args.max_seq_length
- self.max_prompt_length = args.max_prompt_length
- self.gradient_accumulation_steps = args.gradient_accumulation_steps
- self.warmup_steps = args.warmup_steps
- self.logging_steps = args.logging_steps
- self.save_steps = args.save_steps
-
- # Calculate iters
- if self.max_steps > 0:
- self.iters = self.max_steps
- else:
- dataset_size = len(train_dataset) if hasattr(train_dataset, '__len__') else 100
- self.iters = max(1, (dataset_size // self.batch_size) * args.num_train_epochs)
-
+ self.config = args or DPOConfig()
+ self.beta = self.config.beta
+ self.loss_type = self.config.loss_type
+ self.label_smoothing = self.config.label_smoothing
+ self.output_dir = Path(self.config.output_dir)
+ self.learning_rate = self.config.learning_rate
+ self.batch_size = self.config.per_device_train_batch_size
+ self.max_steps = self.config.max_steps
+ self.max_seq_length = self.config.max_seq_length
+ self.max_prompt_length = self.config.max_prompt_length
+ self.gradient_accumulation_steps = self.config.gradient_accumulation_steps
+ self.warmup_steps = self.config.warmup_steps
+ self.logging_steps = self.config.logging_steps
+ self.save_steps = self.config.save_steps
+ dataset_size = len(train_dataset) if hasattr(train_dataset, "__len__") else 100
+ self.iters = self.max_steps if self.max_steps > 0 else max(
+ 1, (dataset_size // max(1, self.batch_size)) * self.config.num_train_epochs
+ )
self.output_dir.mkdir(parents=True, exist_ok=True)
- self.adapter_path = self.output_dir / "adapters"
+ self.adapter_path = self.output_dir / "policy"
self.adapter_path.mkdir(parents=True, exist_ok=True)
+ self.train_samples: List[Dict[str, Any]] = []
+ self._init_native_state()
- print(f"DPOTrainer initialized:")
- print(f" Beta: {self.beta}")
- print(f" Loss type: {self.loss_type}")
- print(f" Learning rate: {self.learning_rate}")
- print(f" Iterations: {self.iters}")
- print(f" Native training: {self.use_native}")
- print(f" Using proper DPO loss: {self.use_native}")
-
- def _tokenize_preference_pair(self, sample: Dict) -> Dict:
- """Tokenize a preference pair (prompt + chosen, prompt + rejected)."""
- prompt = sample.get('prompt', '')
- chosen = sample.get('chosen', '')
- rejected = sample.get('rejected', '')
-
- # Tokenize chosen and rejected with prompt
- chosen_text = prompt + chosen
- rejected_text = prompt + rejected
-
- chosen_ids = self.tokenizer.encode(chosen_text)
- rejected_ids = self.tokenizer.encode(rejected_text)
-
- # Truncate if needed
- if len(chosen_ids) > self.max_seq_length:
- chosen_ids = chosen_ids[:self.max_seq_length]
- if len(rejected_ids) > self.max_seq_length:
- rejected_ids = rejected_ids[:self.max_seq_length]
+ def _tokenize_preference_pair(self, sample: Dict[str, Any], sample_index: int) -> Dict[str, Any]:
+ prompt = sample.get("prompt", "")
+ chosen = sample.get("chosen", "")
+ rejected = sample.get("rejected", "")
+ chosen_ids = self.tokenizer.encode(prompt + chosen)[: self.max_seq_length]
+ rejected_ids = self.tokenizer.encode(prompt + rejected)[: self.max_seq_length]
return {
- 'chosen_ids': chosen_ids,
- 'rejected_ids': rejected_ids,
- 'chosen_length': len(chosen_ids),
- 'rejected_length': len(rejected_ids),
+ "sample_index": sample_index,
+ "chosen_ids": chosen_ids,
+ "rejected_ids": rejected_ids,
+ "chosen_length": len(chosen_ids),
+ "rejected_length": len(rejected_ids),
}
- def _prepare_dpo_batches(self):
- """Prepare batched DPO data for training."""
- tokenized_data = []
- for sample in self.train_dataset:
- if 'prompt' in sample and 'chosen' in sample and 'rejected' in sample:
- tokenized_data.append(self._tokenize_preference_pair(sample))
+ def _prepare_training_samples(self) -> None:
+ self.train_samples = []
+ prepared = prepare_rl_dataset(
+ self.train_dataset,
+ mode=self.config.dataset_mode or "preference",
+ tokenizer=self.tokenizer,
+ chat_template=getattr(self.config, "chat_template", None),
+ )
+ for sample_index, sample in enumerate(prepared):
+ self.train_samples.append(self._tokenize_preference_pair(sample, sample_index))
+ if not self.train_samples:
+ raise ValueError("DPOTrainer requires prompt/chosen/rejected samples.")
+
+ def _precompute_reference_cache(self) -> None:
+ self._ensure_reference_policy()
+ preference_batch = make_preference_batch(
+ chosen_sequences=[sample["chosen_ids"] for sample in self.train_samples],
+ rejected_sequences=[sample["rejected_ids"] for sample in self.train_samples],
+ pad_id=_pad_token_id(self.tokenizer),
+ sample_indices=[sample["sample_index"] for sample in self.train_samples],
+ )
+ ref_chosen = score_policy_in_chunks(
+ _actual_model(self.reference_policy.model),
+ preference_batch.chosen,
+ batch_size=max(1, self.batch_size),
+ mode="sequence",
+ ).summed_logprobs
+ ref_rejected = score_policy_in_chunks(
+ _actual_model(self.reference_policy.model),
+ preference_batch.rejected,
+ batch_size=max(1, self.batch_size),
+ mode="sequence",
+ ).summed_logprobs
+ ref_chosen = mx.stop_gradient(ref_chosen)
+ ref_rejected = mx.stop_gradient(ref_rejected)
+ for idx, sample in enumerate(self.train_samples):
+ sample["reference_chosen_logprobs"] = ref_chosen[idx]
+ sample["reference_rejected_logprobs"] = ref_rejected[idx]
+ self.cache_metadata = {
+ "type": "inline_preference_reference_logprobs",
+ "num_samples": len(self.train_samples),
+ }
- return tokenized_data
+ def _restore_reference_cache(self, flat_state: Dict[str, mx.array]) -> None:
+ if "dpo.reference_chosen_logprobs" not in flat_state:
+ self._precompute_reference_cache()
+ return
+ ref_chosen = flat_state["dpo.reference_chosen_logprobs"]
+ ref_rejected = flat_state["dpo.reference_rejected_logprobs"]
+ if ref_chosen.shape[0] != len(self.train_samples):
+ raise ValueError("Saved DPO cache does not match current dataset ordering.")
+ for idx, sample in enumerate(self.train_samples):
+ sample["reference_chosen_logprobs"] = ref_chosen[idx]
+ sample["reference_rejected_logprobs"] = ref_rejected[idx]
+
+ def _build_batch(self, samples: List[Dict[str, Any]]) -> PreferenceBatch:
+ return make_preference_batch(
+ chosen_sequences=[sample["chosen_ids"] for sample in samples],
+ rejected_sequences=[sample["rejected_ids"] for sample in samples],
+ pad_id=_pad_token_id(self.tokenizer),
+ sample_indices=[sample["sample_index"] for sample in samples],
+ chosen_reference_logprobs=mx.array(
+ [sample["reference_chosen_logprobs"] for sample in samples]
+ ),
+ rejected_reference_logprobs=mx.array(
+ [sample["reference_rejected_logprobs"] for sample in samples]
+ ),
+ )
- def _pad_to_length(self, ids: List[int], length: int, pad_id: int = 0) -> List[int]:
- """Pad sequence to target length."""
- if len(ids) >= length:
- return ids[:length]
- return ids + [pad_id] * (length - len(ids))
+ def _extra_state_arrays(self) -> Dict[str, mx.array]:
+ return {
+ "dpo.reference_chosen_logprobs": mx.array(
+ [sample["reference_chosen_logprobs"] for sample in self.train_samples]
+ ),
+ "dpo.reference_rejected_logprobs": mx.array(
+ [sample["reference_rejected_logprobs"] for sample in self.train_samples]
+ ),
+ }
- def train(self):
- """
- Train the model using DPO with proper loss computation.
+ def train(self, resume_from_checkpoint: Optional[str] = None):
+ if self.use_native:
+ return self._train_native(resume_from_checkpoint=resume_from_checkpoint)
+ return self._train_subprocess()
+
+ def _train_native(self, resume_from_checkpoint: Optional[str] = None):
+ self._apply_lora_if_needed()
+ self._prepare_training_samples()
+
+ actual_model = _actual_model(self.model)
+ optimizer = self._optimizer_for_training()
+ self.optimizer = optimizer
+
+ if resume_from_checkpoint is not None:
+ flat_state = self.load_state(optimizer, Path(resume_from_checkpoint))
+ self._restore_reference_cache(flat_state)
+ else:
+ self._precompute_reference_cache()
+
+ def loss_fn(model, batch):
+ loss, _ = compute_dpo_loss(
+ model=model,
+ chosen_ids=batch.chosen.input_ids,
+ rejected_ids=batch.rejected.input_ids,
+ chosen_lengths=batch.chosen.sequence_lengths,
+ rejected_lengths=batch.rejected.sequence_lengths,
+ beta=self.beta,
+ reference_chosen_logprobs=batch.chosen.reference_logprobs,
+ reference_rejected_logprobs=batch.rejected.reference_logprobs,
+ label_smoothing=self.label_smoothing,
+ )
+ return loss
- Uses native MLX training with real DPO loss when available,
- falls back to SFT approximation otherwise.
- """
- print("=" * 70)
- print("Starting DPO Training")
- print("=" * 70)
+ value_and_grad = nn.value_and_grad(actual_model, loss_fn)
+ running_loss = 0.0
+ last_loss = None
+ while self.global_step < self.iters:
+ batch_samples = self._next_samples(self.train_samples)
+ batch = self._build_batch(batch_samples)
+ loss, grads = value_and_grad(actual_model, batch)
+ optimizer.update(actual_model, grads)
+ mx.eval(actual_model.parameters(), optimizer.state)
+
+ last_loss = loss.item()
+ running_loss += last_loss
+ self.global_step += 1
+ self._record_metric(loss=last_loss)
+
+ if self.global_step % self.logging_steps == 0:
+ print(
+ f"DPO step {self.global_step}/{self.iters} | "
+ f"loss={running_loss / self.logging_steps:.4f}"
+ )
+ running_loss = 0.0
+
+ if self.global_step % self.save_steps == 0:
+ self.save_state(optimizer, self._extra_state_arrays())
+
+ self.save_state(optimizer, self._extra_state_arrays())
+ return {
+ "status": "success",
+ "adapter_path": str(self.adapter_path),
+ "global_step": self.global_step,
+ "final_loss": last_loss,
+ }
+
+ def _train_subprocess(self):
+ warnings.warn(
+ "Native DPO training not available. Using SFT-on-chosen approximation.",
+ UserWarning,
+ )
+ train_file = self.output_dir / "train.jsonl"
+ valid_file = self.output_dir / "valid.jsonl"
+ with open(train_file, "w") as handle:
+ for sample in self.train_dataset:
+ if "prompt" in sample and "chosen" in sample:
+ messages = [
+ {"role": "user", "content": sample["prompt"]},
+ {"role": "assistant", "content": sample["chosen"]},
+ ]
+ handle.write(json.dumps({"messages": messages}) + "\n")
+ valid_file.write_text(train_file.read_text())
+ cmd = [
+ "mlx_lm.lora",
+ "--model",
+ getattr(self.model, "model_name", "model"),
+ "--train",
+ "--data",
+ str(self.output_dir),
+ "--iters",
+ str(self.iters),
+ "--learning-rate",
+ str(self.learning_rate),
+ "--batch-size",
+ str(self.batch_size),
+ "--adapter-path",
+ str(self.adapter_path),
+ ]
+ subprocess.run(cmd, check=True)
+ return {"status": "success", "adapter_path": str(self.adapter_path)}
+
+
+class ORPOTrainer:
+ def __init__(
+ self,
+ model: Any,
+ train_dataset: Any,
+ tokenizer: Optional[Any] = None,
+ args: Optional[ORPOConfig] = None,
+ use_native: bool = True,
+ **kwargs,
+ ):
+ self.model = model
+ self.train_dataset = train_dataset
+ self.tokenizer = tokenizer or getattr(model, "tokenizer", None)
+ self.use_native = use_native and HAS_NATIVE_TRAINING
+ self.config = args or ORPOConfig()
+ self.beta = self.config.beta
+ self.output_dir = Path(self.config.output_dir)
+ self.learning_rate = self.config.learning_rate
+ self.batch_size = self.config.per_device_train_batch_size
+ self.max_steps = self.config.max_steps
+ self.max_seq_length = self.config.max_seq_length
+ self.logging_steps = self.config.logging_steps
+ self.save_steps = self.config.save_steps
+ dataset_size = len(train_dataset) if hasattr(train_dataset, "__len__") else 100
+ self.iters = self.max_steps if self.max_steps > 0 else max(
+ 1, (dataset_size // max(1, self.batch_size)) * self.config.num_train_epochs
+ )
+ self.output_dir.mkdir(parents=True, exist_ok=True)
+ self.adapter_path = self.output_dir / "policy"
+ self.adapter_path.mkdir(parents=True, exist_ok=True)
+
+ def _tokenize_preference_pair(self, sample: Dict[str, Any]) -> Dict[str, Any]:
+ prompt = sample.get("prompt", "")
+ chosen = sample.get("chosen", "")
+ rejected = sample.get("rejected", "")
+ chosen_ids = self.tokenizer.encode(prompt + chosen)[: self.max_seq_length]
+ rejected_ids = self.tokenizer.encode(prompt + rejected)[: self.max_seq_length]
+ return {
+ "chosen_ids": chosen_ids,
+ "rejected_ids": rejected_ids,
+ "chosen_length": len(chosen_ids),
+ "rejected_length": len(rejected_ids),
+ }
+
+ def train(self):
if self.use_native:
return self._train_native()
- else:
- return self._train_subprocess()
+ return self._train_subprocess()
def _train_native(self):
- """Train using native MLX with proper DPO loss."""
- print("\n[Using Native DPO Training with Proper Loss]")
-
- # Apply LoRA if needed
- if hasattr(self.model, '_apply_lora') and not getattr(self.model, '_lora_applied', False):
- print("Applying LoRA adapters...")
+ if hasattr(self.model, "_apply_lora") and not getattr(self.model, "_lora_applied", False):
self.model._apply_lora()
- # Prepare data
- print("Preparing preference data...")
- tokenized_data = self._prepare_dpo_batches()
- print(f"✓ Prepared {len(tokenized_data)} preference pairs")
+ prepared = prepare_rl_dataset(
+ self.train_dataset,
+ mode=self.config.dataset_mode or "preference",
+ tokenizer=self.tokenizer,
+ chat_template=getattr(self.config, "chat_template", None),
+ )
+ tokenized_data = [self._tokenize_preference_pair(sample) for sample in prepared]
+ actual_model = _actual_model(self.model)
+ optimizer = optim.AdamW(learning_rate=optim.cosine_decay(self.learning_rate, self.iters))
- # Get actual model
- actual_model = self.model.model if hasattr(self.model, 'model') else self.model
+ def loss_fn(model, batch):
+ chosen_ids, rejected_ids, chosen_lengths, rejected_lengths = batch
+ loss, _ = compute_orpo_loss(
+ model,
+ chosen_ids,
+ rejected_ids,
+ chosen_lengths,
+ rejected_lengths,
+ self.beta,
+ )
+ return loss
+
+ value_and_grad = nn.value_and_grad(actual_model, loss_fn)
+ last_loss = None
+ pad_id = _pad_token_id(self.tokenizer)
+
+ for step in range(self.iters):
+ samples = tokenized_data[step % len(tokenized_data): step % len(tokenized_data) + self.batch_size]
+ if len(samples) < self.batch_size:
+ samples += tokenized_data[: self.batch_size - len(samples)]
+ chosen_ids, chosen_lengths = _pad_sequences([sample["chosen_ids"] for sample in samples], pad_id)
+ rejected_ids, rejected_lengths = _pad_sequences([sample["rejected_ids"] for sample in samples], pad_id)
+ loss, grads = value_and_grad(actual_model, (chosen_ids, rejected_ids, chosen_lengths, rejected_lengths))
+ optimizer.update(actual_model, grads)
+ mx.eval(actual_model.parameters(), optimizer.state)
+ last_loss = loss.item()
- # Create optimizer
- lr_schedule = optim.cosine_decay(self.learning_rate, self.iters)
- optimizer = optim.AdamW(learning_rate=lr_schedule)
+ _save_adapters_and_config(self.model, self.adapter_path)
+ return {"status": "success", "adapter_path": str(self.adapter_path), "final_loss": last_loss}
- # Training loop
- print(f"\nStarting training for {self.iters} iterations...")
+ def _train_subprocess(self):
+ warnings.warn("Using SFT approximation for ORPO.", UserWarning)
+ train_file = self.output_dir / "train.jsonl"
+ valid_file = self.output_dir / "valid.jsonl"
+ with open(train_file, "w") as handle:
+ for sample in self.train_dataset:
+ if "prompt" in sample and "chosen" in sample:
+ messages = [
+ {"role": "user", "content": sample["prompt"]},
+ {"role": "assistant", "content": sample["chosen"]},
+ ]
+ handle.write(json.dumps({"messages": messages}) + "\n")
+ valid_file.write_text(train_file.read_text())
+ cmd = [
+ "mlx_lm.lora",
+ "--model",
+ getattr(self.model, "model_name", "model"),
+ "--train",
+ "--data",
+ str(self.output_dir),
+ "--iters",
+ str(self.iters),
+ "--learning-rate",
+ str(self.learning_rate),
+ "--batch-size",
+ str(self.batch_size),
+ "--adapter-path",
+ str(self.adapter_path),
+ ]
+ subprocess.run(cmd, check=True)
+ return {"status": "success", "adapter_path": str(self.adapter_path)}
+
+
+class GRPOTrainer(_RLTrainerBase):
+ algorithm = "grpo"
+ requires_reference_policy = True
+
+ def __init__(
+ self,
+ model: Any,
+ train_dataset: Any,
+ eval_dataset: Any = None,
+ eval_preference_dataset: Any = None,
+ tokenizer: Optional[Any] = None,
+ reward_fn: Optional[Callable] = None,
+ reward_model: Optional[Any] = None,
+ ref_model: Optional[Any] = None,
+ value_model: Optional[Any] = None,
+ args: Optional[GRPOConfig] = None,
+ use_native: bool = True,
+ **kwargs,
+ ):
+ self.model = model
+ self.train_dataset = train_dataset
+ self.eval_dataset = eval_dataset
+ self.eval_preference_dataset = eval_preference_dataset
+ self.tokenizer = tokenizer or getattr(model, "tokenizer", None)
+ self.ref_model = ref_model
+ self.reward_model = reward_model or getattr(args, "reward_model", None)
+ self.value_model = value_model or getattr(args, "value_model", None)
+ self.use_native = use_native and HAS_NATIVE_TRAINING
+ self.config = args or GRPOConfig()
+ self.loss_type = self.config.loss_type
+ self.phase1_loss_type = "phase1_shared_rollout_recompute"
+ self.resolved_loss_type = self._resolve_loss_type(self.loss_type)
+ self.advantage_mode = self.config.advantage_estimator
+ self.beta = self.config.kl_beta
+ self.num_generations = self.config.num_generations
+ self.max_completion_length = self.config.max_completion_length
+ self.reward_fn = reward_fn if reward_fn is not None else self.config.reward_fn
+ self.reward_sources = self.config.reward_sources
+ self.reward_source = self.config.reward_source
+ self.kl_target = self.config.kl_target
+ self.kl_penalty_mode = self.config.kl_penalty_mode
+ self.reward_normalization = self.config.reward_normalization
+ self.mask_truncated_completions = self.config.mask_truncated_completions
+ self.minibatch_reuse_steps = self.config.minibatch_reuse_steps
+ self.entropy_bonus = self.config.entropy_bonus
+ self.output_dir = Path(self.config.output_dir)
+ self.learning_rate = self.config.learning_rate
+ self.batch_size = self.config.per_device_train_batch_size
+ self.rollout_batch_size = self.config.rollout_batch_size or self.batch_size
+ self.max_steps = self.config.max_steps
+ self.temperature = self.config.temperature
+ self.clip_epsilon = self.config.clip_epsilon
+ self.epsilon_low = self.config.epsilon_low
+ self.epsilon_high = self.config.epsilon_high
+ self.scale_rewards = self.config.scale_rewards
+ self.eval_steps = self.config.eval_steps
+ self.eval_num_batches = self.config.eval_num_batches
+ self.eval_num_generations = self.config.eval_num_generations or self.num_generations
+ self.generation_batch_size = self.config.generation_batch_size
+ self.score_chunk_size = self.config.score_chunk_size
+ self.precompute_reference_scores = self.config.precompute_reference_scores
+ self.logging_steps = self.config.logging_steps
+ self.save_steps = self.config.save_steps
+ self.max_seq_length = self.config.max_seq_length
+ dataset_size = len(train_dataset) if hasattr(train_dataset, "__len__") else 100
+ self.iters = self.max_steps if self.max_steps > 0 else max(
+ 1, (dataset_size // max(1, self.batch_size)) * self.config.num_train_epochs
+ )
+ self.output_dir.mkdir(parents=True, exist_ok=True)
+ self.adapter_path = self.output_dir / "policy"
+ self.adapter_path.mkdir(parents=True, exist_ok=True)
+ self.prompt_samples: List[Dict[str, Any]] = []
+ self.rollout_samples: List[Dict[str, Any]] = []
+ self.eval_prompt_samples: List[Dict[str, Any]] = []
+ self.eval_rollout_samples: List[Dict[str, Any]] = []
+ self.eval_preference_samples: List[Dict[str, Any]] = []
+ self.prepared_dataset_mode: Optional[str] = None
+ self.prompt_dataset_cursor = 0
+ self.rollout_dataset_cursor = 0
+ self._last_rollout_batch: Optional[RolloutBatch] = None
+ self._last_rollout_phase_metrics: Dict[str, float] = {}
+ self._init_native_state()
+
+ if self.reward_model is None and self.reward_fn is None:
+ self.reward_fn = lambda response, context: len(response.split()) / 100.0
+
+ def _resolve_loss_type(self, loss_type: str) -> str:
+ if loss_type not in GRPO_LOSS_TYPES:
+ raise ValueError(
+ f"Unsupported GRPO loss_type '{loss_type}'. "
+ f"Supported values: {sorted(GRPO_LOSS_TYPES)}"
+ )
+ return loss_type
+
+ def _trainer_cursor_state(self) -> Dict[str, int]:
+ return {
+ "dataset": int(self.dataset_cursor),
+ "prompt_dataset": int(self.prompt_dataset_cursor),
+ "offline_rollout_dataset": int(self.rollout_dataset_cursor),
+ }
+
+ def _restore_trainer_cursors(self, cursors: Dict[str, Any]) -> None:
+ super()._restore_trainer_cursors(cursors)
+ self.prompt_dataset_cursor = int(cursors.get("prompt_dataset", self.prompt_dataset_cursor))
+ self.rollout_dataset_cursor = int(cursors.get("offline_rollout_dataset", self.rollout_dataset_cursor))
+
+ def _sampling_config_payload(self) -> Dict[str, Any]:
+ return {
+ "algorithm": self.algorithm,
+ "loss_type": self.resolved_loss_type,
+ "beta": self.beta,
+ "clip_epsilon": self.clip_epsilon,
+ "epsilon_low": self.epsilon_low,
+ "epsilon_high": self.epsilon_high,
+ "rollout_batch_size": self.rollout_batch_size,
+ "minibatch_reuse_steps": self.minibatch_reuse_steps,
+ "advantage_mode": self.advantage_mode,
+ "scale_rewards": self.scale_rewards,
+ "kl_target": self.kl_target,
+ "kl_penalty_mode": self.kl_penalty_mode,
+ "reward_source": self.reward_source,
+ "reward_normalization": self.reward_normalization,
+ "mask_truncated_completions": self.mask_truncated_completions,
+ "entropy_bonus": self.entropy_bonus,
+ "temperature": self.temperature,
+ "num_generations": self.num_generations,
+ "max_completion_length": self.max_completion_length,
+ "max_seq_length": self.max_seq_length,
+ "generation_batch_size": self.generation_batch_size,
+ "score_chunk_size": self.score_chunk_size,
+ }
+
+ def _prepare_prompt_samples(self) -> None:
+ self.prompt_samples, self.rollout_samples, self.prepared_dataset_mode = _prepare_on_policy_samples(
+ self.train_dataset,
+ self.tokenizer,
+ self.config,
+ )
+ if not self.prompt_samples and not self.rollout_samples:
+ raise ValueError("GRPOTrainer requires prompt or rollout samples.")
+
+ def _prepare_eval_datasets(self) -> None:
+ self.eval_prompt_samples = []
+ self.eval_rollout_samples = []
+ self.eval_preference_samples = []
+ if self.eval_dataset is not None:
+ self.eval_prompt_samples, self.eval_rollout_samples, _ = _prepare_on_policy_samples(
+ self.eval_dataset,
+ self.tokenizer,
+ self.config,
+ )
+ if self.eval_preference_dataset is not None:
+ prepared = prepare_rl_dataset(
+ self.eval_preference_dataset,
+ mode="preference",
+ tokenizer=self.tokenizer,
+ chat_template=getattr(self.config, "chat_template", None),
+ )
+ for sample_index, sample in enumerate(prepared):
+ prompt = sample.get("prompt", "")
+ prompt_ids = _encode_text(self.tokenizer, prompt)
+ chosen_ids = _encode_text(self.tokenizer, sample["chosen"], add_special_tokens=False)
+ rejected_ids = _encode_text(self.tokenizer, sample["rejected"], add_special_tokens=False)
+ self.eval_preference_samples.append(
+ {
+ "sample_index": sample_index,
+ "prompt_ids": prompt_ids,
+ "chosen_ids": prompt_ids + chosen_ids,
+ "rejected_ids": prompt_ids + rejected_ids,
+ "prompt_length": len(prompt_ids),
+ "chosen_completion_length": len(chosen_ids),
+ "rejected_completion_length": len(rejected_ids),
+ }
+ )
+
+ def _resolve_reward_evaluator(self) -> Any:
+ evaluator = _resolve_reward_evaluator(
+ self.reward_model,
+ self.reward_fn,
+ self.reward_sources,
+ )
+ if evaluator is not None:
+ return evaluator
+ return lambda response, context: len(response.split()) / 100.0
+
+ def _next_prompt_batch(self) -> List[Dict[str, Any]]:
+ batch, self.prompt_dataset_cursor = _next_cursor_batch(
+ self.prompt_samples,
+ self.rollout_batch_size,
+ self.prompt_dataset_cursor,
+ self.algorithm,
+ )
+ self.dataset_cursor = self.prompt_dataset_cursor
+ return batch
+
+ def _next_offline_rollout_batch(self) -> List[Dict[str, Any]]:
+ batch, self.rollout_dataset_cursor = _next_cursor_batch(
+ self.rollout_samples,
+ self.rollout_batch_size * self.num_generations,
+ self.rollout_dataset_cursor,
+ self.algorithm,
+ )
+ self.dataset_cursor = self.rollout_dataset_cursor
+ return batch
+
+ def _collect_fixed_rollout_batch(
+ self,
+ samples: List[Dict[str, Any]],
+ cache_key: Optional[str] = None,
+ ) -> RolloutBatch:
+ cached_reference_logprobs = None
+ if cache_key is not None:
+ cached_reference_logprobs = _ensure_fixed_rollout_reference_cache(self, cache_key, samples)
+ return _collect_fixed_rollout_batch(
+ self,
+ samples,
+ cached_reference_logprobs=cached_reference_logprobs,
+ )
+
+ def _collect_rollout_batch(
+ self,
+ prompt_samples: List[Dict[str, Any]],
+ num_generations: Optional[int] = None,
+ cache_key: Optional[str] = None,
+ ) -> RolloutBatch:
+ generations = self.num_generations if num_generations is None else num_generations
+ reward_evaluator = self._resolve_reward_evaluator()
+ rollout_generate_started_at = time.perf_counter()
+ reward_eval_wall = 0.0
+ reference_score_wall = 0.0
+ returns_wall = 0.0
+ if prompt_samples and "completion" in prompt_samples[0]:
+ rollout_batch = self._collect_fixed_rollout_batch(prompt_samples, cache_key=cache_key)
+ rollout_generate_wall = time.perf_counter() - rollout_generate_started_at
+ else:
+ rollout_batch = collect_rollouts(
+ _actual_model(self.model),
+ self.tokenizer,
+ prompt_samples=prompt_samples,
+ sampling_config={
+ "num_generations": generations,
+ "temperature": self.temperature,
+ "max_completion_length": self.max_completion_length,
+ "max_seq_length": self.max_seq_length,
+ "generation_batch_size": self.generation_batch_size,
+ },
+ reward_evaluator=None,
+ collect_sample_stats=self.entropy_bonus != 0.0,
+ )
+ rollout_generate_wall = time.perf_counter() - rollout_generate_started_at
+ if self.reward_source == "offline":
+ raise ValueError("reward_source='offline' requires rollout samples with completion/reward fields.")
+ reward_eval_started_at = time.perf_counter()
+ reward_batch = evaluate_rewards(rollout_batch, reward_evaluator)
+ reward_eval_wall += time.perf_counter() - reward_eval_started_at
+ rollout_batch.rewards = reward_batch.scalar_rewards
+ if prompt_samples and "completion" in prompt_samples[0]:
+ if self.reward_source == "online":
+ reward_eval_started_at = time.perf_counter()
+ reward_batch = evaluate_rewards(rollout_batch, reward_evaluator)
+ reward_eval_wall += time.perf_counter() - reward_eval_started_at
+ rollout_batch.rewards = reward_batch.scalar_rewards
+ elif self.reward_source == "hybrid":
+ reward_eval_started_at = time.perf_counter()
+ reward_batch = evaluate_rewards(rollout_batch, reward_evaluator)
+ reward_eval_wall += time.perf_counter() - reward_eval_started_at
+ rollout_batch.rewards = rollout_batch.rewards + reward_batch.scalar_rewards
+ if rollout_batch.rewards is None:
+ rollout_batch.rewards = mx.zeros((len(rollout_batch.prompt_ids),), dtype=mx.float32)
+ if self.entropy_bonus and rollout_batch.token_entropies is not None:
+ entropy_bonus = mx.mean(rollout_batch.token_entropies, axis=-1) * self.entropy_bonus
+ rollout_batch.rewards = rollout_batch.rewards + entropy_bonus.astype(mx.float32)
+ if self.reward_normalization != "none":
+ rollout_batch.rewards = _normalize_reward_values(
+ rollout_batch.rewards,
+ rollout_batch.prompt_group_indices,
+ self.reward_normalization,
+ )
+ reference_score_started_at = time.perf_counter()
+ rollout_batch = score_rollout_references(
+ _actual_model(self.reference_policy.model) if self.reference_policy is not None else None,
+ rollout_batch,
+ batch_size=_rollout_score_batch_size(self, num_generations=generations),
+ token_budget=self.score_chunk_size,
+ )
+ reference_score_wall = time.perf_counter() - reference_score_started_at
+ if self.mask_truncated_completions:
+ rollout_batch = _apply_truncation_mask_to_rollout(rollout_batch)
+ returns_mode = "rloo" if self.advantage_mode == "rloo" else "group_zscore"
+ if returns_mode == "group_zscore" and self.scale_rewards is False:
+ returns_mode = "group_center"
+ returns_started_at = time.perf_counter()
+ returns, advantages = compute_returns_and_advantages(
+ rewards=rollout_batch.rewards,
+ prompt_group_indices=rollout_batch.prompt_group_indices,
+ mode=returns_mode,
+ )
+ returns_wall = time.perf_counter() - returns_started_at
+ rollout_batch.returns = returns
+ rollout_batch.advantages = advantages
+ rollout_batch.policy_eval.returns = returns
+ rollout_batch.policy_eval.advantages = advantages
+ self._last_rollout_phase_metrics = {
+ "rollout_generate_wall": float(rollout_generate_wall),
+ "reward_eval_wall": float(reward_eval_wall),
+ "reference_score_wall": float(reference_score_wall),
+ "returns_wall": float(returns_wall),
+ }
+ return rollout_batch
+
+ def _evaluate_rollout_metrics(self) -> Dict[str, Any]:
+ prompt_limit = max(1, self.eval_num_batches or 1) * self.rollout_batch_size
+ rollout_limit = max(1, self.eval_num_batches or 1) * self.rollout_batch_size * self.num_generations
+ if self.eval_rollout_samples:
+ rollout_batch = self._collect_rollout_batch(
+ self.eval_rollout_samples[:rollout_limit],
+ cache_key=f"{self.algorithm}.eval_rollout_reference_logprobs",
+ )
+ elif self.eval_prompt_samples:
+ rollout_batch = self._collect_rollout_batch(
+ self.eval_prompt_samples[:prompt_limit],
+ num_generations=self.eval_num_generations,
+ )
+ else:
+ return {}
+ reference_model = _actual_model(self.reference_policy.model)
+ effective_beta = self._effective_kl_beta(rollout_batch)
+ loss, _ = grpo_recompute_loss(
+ model=_actual_model(self.model),
+ reference_model=reference_model,
+ input_ids=rollout_batch.policy_eval.input_ids,
+ prompt_lengths=rollout_batch.policy_eval.prompt_lengths,
+ completion_lengths=rollout_batch.policy_eval.completion_lengths,
+ rollout_logprobs=rollout_batch.policy_eval.rollout_logprobs,
+ old_token_logprobs=rollout_batch.policy_eval.old_token_logprobs,
+ reference_logprobs=rollout_batch.policy_eval.reference_logprobs,
+ advantages=rollout_batch.policy_eval.advantages,
+ beta=effective_beta,
+ clip_epsilon=self.clip_epsilon,
+ epsilon_low=self.epsilon_low,
+ epsilon_high=self.epsilon_high,
+ temperature=self.temperature,
+ loss_type=self.resolved_loss_type,
+ max_completion_length=self.max_completion_length,
+ )
+ return summarize_rollout_metrics(rollout_batch, policy_loss=float(loss.item()))
+
+ def _evaluate_preference_metrics(self) -> Dict[str, Any]:
+ if not self.eval_preference_samples:
+ return {}
+ limit = max(1, self.eval_num_batches or 1) * self.batch_size
+ samples = self.eval_preference_samples[:limit]
+ chosen_batch = make_policy_eval_batch(
+ [sample["chosen_ids"] for sample in samples],
+ pad_id=_pad_token_id(self.tokenizer),
+ mode="completion",
+ prompt_lengths=[sample["prompt_length"] for sample in samples],
+ completion_lengths=[sample["chosen_completion_length"] for sample in samples],
+ sample_indices=mx.array([sample["sample_index"] for sample in samples]),
+ )
+ rejected_batch = make_policy_eval_batch(
+ [sample["rejected_ids"] for sample in samples],
+ pad_id=_pad_token_id(self.tokenizer),
+ mode="completion",
+ prompt_lengths=[sample["prompt_length"] for sample in samples],
+ completion_lengths=[sample["rejected_completion_length"] for sample in samples],
+ sample_indices=mx.array([sample["sample_index"] for sample in samples]),
+ )
+ chosen_scores = score_policy_in_chunks(
+ _actual_model(self.model),
+ chosen_batch,
+ batch_size=max(1, self.batch_size),
+ token_budget=self.score_chunk_size,
+ mode="completion",
+ ).summed_logprobs
+ rejected_scores = score_policy_in_chunks(
+ _actual_model(self.model),
+ rejected_batch,
+ batch_size=max(1, self.batch_size),
+ token_budget=self.score_chunk_size,
+ mode="completion",
+ ).summed_logprobs
+ return {
+ "preference_win_rate": float(mx.mean((chosen_scores > rejected_scores).astype(mx.float32)).item())
+ }
+
+ def evaluate(self) -> Dict[str, Any]:
+ self._prepare_eval_datasets()
+ if self.reference_policy is None:
+ self._ensure_reference_policy()
+ with self._preserve_rng_state():
+ mx.random.seed(int(self.seed) + 100000 + int(self.global_step))
+ metrics: Dict[str, Any] = {}
+ metrics.update(self._evaluate_rollout_metrics())
+ metrics.update(self._evaluate_preference_metrics())
+ if not metrics:
+ return {}
+ return self._record_metrics("eval", metrics)
+
+ def train(self, resume_from_checkpoint: Optional[str] = None):
+ if self.use_native:
+ return self._train_native(resume_from_checkpoint=resume_from_checkpoint)
+ return self._train_subprocess()
+
+ def _train_native(self, resume_from_checkpoint: Optional[str] = None):
+ self._apply_lora_if_needed()
+ self._prepare_prompt_samples()
+ self._prepare_eval_datasets()
+ if resume_from_checkpoint is None:
+ self._seed_training_run()
+
+ actual_model = _actual_model(self.model)
+ optimizer = self._optimizer_for_training()
+ self.optimizer = optimizer
+
+ if resume_from_checkpoint is not None:
+ self.load_state(optimizer=optimizer, checkpoint_dir=Path(resume_from_checkpoint))
+ else:
+ self._ensure_reference_policy()
+
+ reference_model = _actual_model(self.reference_policy.model)
+ if self.rollout_samples:
+ _ensure_fixed_rollout_reference_cache(
+ self,
+ f"{self.algorithm}.train_rollout_reference_logprobs",
+ self.rollout_samples,
+ )
- # Define loss and grad function
- def loss_fn(model, batch_data):
- chosen_ids, rejected_ids, chosen_lengths, rejected_lengths = batch_data
+ effective_beta = self.beta
- loss, ntoks = compute_dpo_loss(
+ def loss_fn(model, batch):
+ loss, _ = grpo_recompute_loss(
model=model,
- chosen_ids=chosen_ids,
- rejected_ids=rejected_ids,
- chosen_lengths=chosen_lengths,
- rejected_lengths=rejected_lengths,
- beta=self.beta,
- label_smoothing=self.label_smoothing,
+ reference_model=reference_model,
+ input_ids=batch.input_ids,
+ prompt_lengths=batch.prompt_lengths,
+ completion_lengths=batch.completion_lengths,
+ rollout_logprobs=batch.rollout_logprobs,
+ old_token_logprobs=batch.old_token_logprobs,
+ reference_logprobs=batch.reference_logprobs,
+ advantages=batch.advantages,
+ beta=effective_beta,
+ clip_epsilon=self.clip_epsilon,
+ epsilon_low=self.epsilon_low,
+ epsilon_high=self.epsilon_high,
+ temperature=self.temperature,
+ loss_type=self.resolved_loss_type,
+ max_completion_length=self.max_completion_length,
)
return loss
- loss_and_grad = nn.value_and_grad(actual_model, loss_fn)
-
- total_loss = 0.0
- for step in range(self.iters):
- # Get batch
- batch_idx = step % len(tokenized_data)
- sample = tokenized_data[batch_idx]
-
- # Pad sequences
- max_len = max(sample['chosen_length'], sample['rejected_length'])
- pad_id = self.tokenizer.pad_token_id or 0
-
- chosen_padded = self._pad_to_length(sample['chosen_ids'], max_len, pad_id)
- rejected_padded = self._pad_to_length(sample['rejected_ids'], max_len, pad_id)
-
- # Create batch tensors
- chosen_ids = mx.array([chosen_padded])
- rejected_ids = mx.array([rejected_padded])
- chosen_lengths = mx.array([sample['chosen_length']])
- rejected_lengths = mx.array([sample['rejected_length']])
-
- batch_data = (chosen_ids, rejected_ids, chosen_lengths, rejected_lengths)
-
- # Compute loss and gradients
- loss, grads = loss_and_grad(actual_model, batch_data)
- optimizer.update(actual_model, grads)
- mx.eval(actual_model.parameters(), optimizer.state)
-
- total_loss += loss.item()
-
- # Logging
- if (step + 1) % self.logging_steps == 0:
- avg_loss = total_loss / self.logging_steps
- print(f" Step {step + 1}/{self.iters} | Loss: {avg_loss:.4f}")
- total_loss = 0.0
-
- # Save checkpoint
- if (step + 1) % self.save_steps == 0:
- self._save_adapters(step + 1)
+ value_and_grad = nn.value_and_grad(actual_model, loss_fn)
+ running_loss = 0.0
+ last_loss = None
+
+ while self.global_step < self.iters:
+ if self.reward_source == "offline" and self.rollout_samples:
+ prompt_samples = self._next_offline_rollout_batch()
+ rollout_batch = self._collect_rollout_batch(
+ prompt_samples,
+ cache_key=f"{self.algorithm}.train_rollout_reference_logprobs",
+ )
+ else:
+ prompt_samples = self._next_prompt_batch()
+ rollout_batch = self._collect_rollout_batch(prompt_samples)
+ self._last_rollout_batch = rollout_batch
+ effective_beta = self._effective_kl_beta(rollout_batch)
+ step_loss = 0.0
+ step_updates = 0
+ policy_update_started_at = time.perf_counter()
+ for _ in range(max(1, self.minibatch_reuse_steps)):
+ minibatches = assemble_minibatches(
+ rollout_batch.policy_eval,
+ minibatch_size=_rollout_score_batch_size(self),
+ shuffle=False,
+ mode="completion",
+ token_budget=self.score_chunk_size,
+ )
+ for minibatch in minibatches:
+ loss, grads = value_and_grad(actual_model, minibatch)
+ optimizer.update(actual_model, grads)
+ mx.eval(actual_model.parameters(), optimizer.state)
+ step_loss += loss.item()
+ step_updates += 1
+ policy_update_wall = time.perf_counter() - policy_update_started_at
+
+ last_loss = step_loss / max(1, step_updates)
+ running_loss += last_loss
+ self.global_step += 1
+ train_row = self._record_metrics(
+ "train",
+ {
+ **summarize_rollout_metrics(rollout_batch, policy_loss=last_loss),
+ **self._last_rollout_phase_metrics,
+ "policy_update_wall": float(policy_update_wall),
+ "policy_update_steps": float(step_updates),
+ },
+ )
- # Final save
- self._save_adapters(self.iters)
+ if self.global_step % self.logging_steps == 0:
+ print(f"GRPO step {self.global_step}/{self.iters} | {self._format_metric_summary(train_row)}")
+ running_loss = 0.0
- print("\n" + "=" * 70)
- print("DPO Training Complete!")
- print("=" * 70)
- print(f" Adapters saved to: {self.adapter_path}")
+ if self.eval_steps and self.global_step % self.eval_steps == 0:
+ eval_row = self.evaluate()
+ if eval_row:
+ print(f"GRPO eval | {self._format_metric_summary(eval_row, namespace='eval')}")
- return {"status": "success", "adapter_path": str(self.adapter_path)}
+ if self.global_step % self.save_steps == 0:
+ self.save_state(optimizer=optimizer)
- def _save_adapters(self, step: int):
- """Save adapter weights and config."""
- if _save_adapters_and_config(self.model, self.adapter_path):
- print(f" ✓ Saved checkpoint at step {step}")
+ self.save_state(optimizer=optimizer)
+ return {
+ "status": "success",
+ "adapter_path": str(self.adapter_path),
+ "global_step": self.global_step,
+ "final_loss": last_loss,
+ }
def _train_subprocess(self):
- """Fallback: Train using subprocess (SFT approximation)."""
warnings.warn(
- "Native DPO training not available. Using SFT on chosen responses. "
- "Install mlx-lm[train] for proper DPO loss.",
- UserWarning
+ "Native GRPO training not available. Using SFT approximation.",
+ UserWarning,
)
-
- print("\n[Using Subprocess Training (SFT Approximation)]")
-
- # Prepare SFT data from chosen responses
train_file = self.output_dir / "train.jsonl"
valid_file = self.output_dir / "valid.jsonl"
-
- with open(train_file, 'w') as f:
+ with open(train_file, "w") as handle:
for sample in self.train_dataset:
- if 'prompt' in sample and 'chosen' in sample:
- messages = [
- {"role": "user", "content": sample['prompt']},
- {"role": "assistant", "content": sample['chosen']}
- ]
- f.write(json.dumps({"messages": messages}) + '\n')
-
- import shutil
- shutil.copy(train_file, valid_file)
-
- model_name = getattr(self.model, 'model_name', 'model')
-
+ prompt = sample.get("prompt", sample.get("question", ""))
+ if not prompt:
+ continue
+ messages = [{"role": "user", "content": prompt}]
+ if "response" in sample or "answer" in sample:
+ response = sample.get("response", sample.get("answer", ""))
+ messages.append({"role": "assistant", "content": response})
+ handle.write(json.dumps({"messages": messages}) + "\n")
+ valid_file.write_text(train_file.read_text())
cmd = [
"mlx_lm.lora",
- "--model", model_name,
+ "--model",
+ getattr(self.model, "model_name", "model"),
"--train",
- "--data", str(self.output_dir),
- "--iters", str(self.iters),
- "--learning-rate", str(self.learning_rate),
- "--batch-size", str(self.batch_size),
- "--adapter-path", str(self.adapter_path),
+ "--data",
+ str(self.output_dir),
+ "--iters",
+ str(self.iters),
+ "--adapter-path",
+ str(self.adapter_path),
]
-
- print(f"Running: {' '.join(cmd)}")
subprocess.run(cmd, check=True)
-
- print("DPO Training Complete (SFT approximation)!")
return {"status": "success", "adapter_path": str(self.adapter_path)}
-class ORPOTrainer:
- """
- Odds Ratio Preference Optimization Trainer.
-
- ORPO combines SFT and preference alignment in a single training step,
- making it simpler and more memory-efficient than DPO.
-
- Compatible with TRL's ORPOTrainer API.
- Now with PROPER ORPO loss implementation!
-
- Example:
- >>> trainer = ORPOTrainer(
- ... model=model,
- ... train_dataset=preference_dataset,
- ... tokenizer=tokenizer,
- ... args=ORPOConfig(beta=0.1),
- ... )
- >>> trainer.train()
- """
+class PPOTrainer(_RLTrainerBase):
+ algorithm = "ppo"
+ requires_reference_policy = True
def __init__(
self,
model: Any,
train_dataset: Any,
+ eval_dataset: Any = None,
+ eval_preference_dataset: Any = None,
tokenizer: Optional[Any] = None,
- args: Optional[ORPOConfig] = None,
+ reward_fn: Optional[Callable] = None,
+ reward_model: Optional[Any] = None,
+ ref_model: Optional[Any] = None,
+ value_model: Optional[Any] = None,
+ args: Optional[PPOConfig] = None,
use_native: bool = True,
- **kwargs
+ **kwargs,
):
self.model = model
self.train_dataset = train_dataset
- self.tokenizer = tokenizer or getattr(model, 'tokenizer', None)
+ self.eval_dataset = eval_dataset
+ self.eval_preference_dataset = eval_preference_dataset
+ self.tokenizer = tokenizer or getattr(model, "tokenizer", None)
+ self.ref_model = ref_model
+ self.config = args or PPOConfig()
+ self.reward_model = reward_model or self.config.reward_model
+ self.reward_fn = reward_fn if reward_fn is not None else self.config.reward_fn
+ self.reward_sources = self.config.reward_sources
+ self.reward_source = self.config.reward_source
+ self.value_model = value_model or self.config.value_model or build_value_model(model)
self.use_native = use_native and HAS_NATIVE_TRAINING
-
- if args is None:
- args = ORPOConfig()
-
- self.config = args
- self.beta = args.beta
- self.output_dir = Path(args.output_dir)
- self.learning_rate = args.learning_rate
- self.batch_size = args.per_device_train_batch_size
- self.max_steps = args.max_steps
- self.max_seq_length = args.max_seq_length
- self.logging_steps = args.logging_steps
- self.save_steps = args.save_steps
-
- if self.max_steps > 0:
- self.iters = self.max_steps
- else:
- dataset_size = len(train_dataset) if hasattr(train_dataset, '__len__') else 100
- self.iters = max(1, (dataset_size // self.batch_size) * args.num_train_epochs)
-
+ self.output_dir = Path(self.config.output_dir)
+ self.learning_rate = self.config.learning_rate
+ self.value_learning_rate = self.config.value_learning_rate
+ self.batch_size = self.config.per_device_train_batch_size
+ self.rollout_batch_size = self.config.rollout_batch_size or self.batch_size
+ self.max_steps = self.config.max_steps
+ self.max_seq_length = self.config.max_seq_length
+ self.max_completion_length = self.config.max_completion_length
+ self.num_generations = self.config.num_generations
+ self.ppo_epochs = self.config.ppo_epochs
+ self.minibatch_reuse_steps = self.config.minibatch_reuse_steps
+ self.temperature = self.config.temperature
+ self.clip_epsilon = self.config.clip_epsilon
+ self.beta = self.config.kl_beta
+ self.reward_normalization = self.config.reward_normalization
+ self.mask_truncated_completions = self.config.mask_truncated_completions
+ self.entropy_bonus = self.config.entropy_bonus
+ self.gamma = self.config.gamma
+ self.gae_lambda = self.config.gae_lambda
+ self.normalize_advantages = self.config.normalize_advantages
+ self.advantage_estimator = self.config.advantage_estimator
+ self.kl_target = self.config.kl_target
+ self.kl_penalty_mode = self.config.kl_penalty_mode
+ self.eval_steps = self.config.eval_steps
+ self.eval_num_batches = self.config.eval_num_batches
+ self.eval_num_generations = self.config.eval_num_generations or self.num_generations
+ self.generation_batch_size = self.config.generation_batch_size
+ self.score_chunk_size = self.config.score_chunk_size
+ self.precompute_reference_scores = self.config.precompute_reference_scores
+ self.logging_steps = self.config.logging_steps
+ self.save_steps = self.config.save_steps
+ dataset_size = len(train_dataset) if hasattr(train_dataset, "__len__") else 100
+ self.iters = self.max_steps if self.max_steps > 0 else max(
+ 1, (dataset_size // max(1, self.batch_size)) * self.config.num_train_epochs
+ )
self.output_dir.mkdir(parents=True, exist_ok=True)
- self.adapter_path = self.output_dir / "adapters"
+ self.adapter_path = self.output_dir / "policy"
self.adapter_path.mkdir(parents=True, exist_ok=True)
+ self.prompt_samples: List[Dict[str, Any]] = []
+ self.rollout_samples: List[Dict[str, Any]] = []
+ self.eval_prompt_samples: List[Dict[str, Any]] = []
+ self.eval_rollout_samples: List[Dict[str, Any]] = []
+ self.eval_preference_samples: List[Dict[str, Any]] = []
+ self.prepared_dataset_mode: Optional[str] = None
+ self.prompt_dataset_cursor = 0
+ self.rollout_dataset_cursor = 0
+ self._last_rollout_batch: Optional[RolloutBatch] = None
+ self._value_train_target = _ScalarRoleTrainTarget(
+ _actual_model(self.value_model.base_model),
+ self.value_model.head,
+ )
+ self._init_native_state()
- print(f"ORPOTrainer initialized:")
- print(f" Beta: {self.beta}")
- print(f" Learning rate: {self.learning_rate}")
- print(f" Iterations: {self.iters}")
- print(f" Native training: {self.use_native}")
-
- def _tokenize_preference_pair(self, sample: Dict) -> Dict:
- """Tokenize a preference pair."""
- prompt = sample.get('prompt', '')
- chosen = sample.get('chosen', '')
- rejected = sample.get('rejected', '')
+ def _optimizer_learning_rates(self) -> Dict[str, float]:
+ return {"policy": self.learning_rate, "value": self.value_learning_rate}
- chosen_ids = self.tokenizer.encode(prompt + chosen)
- rejected_ids = self.tokenizer.encode(prompt + rejected)
+ def _trainer_cursor_state(self) -> Dict[str, int]:
+ return {
+ "dataset": int(self.dataset_cursor),
+ "prompt_dataset": int(self.prompt_dataset_cursor),
+ "offline_rollout_dataset": int(self.rollout_dataset_cursor),
+ }
- if len(chosen_ids) > self.max_seq_length:
- chosen_ids = chosen_ids[:self.max_seq_length]
- if len(rejected_ids) > self.max_seq_length:
- rejected_ids = rejected_ids[:self.max_seq_length]
+ def _restore_trainer_cursors(self, cursors: Dict[str, Any]) -> None:
+ super()._restore_trainer_cursors(cursors)
+ self.prompt_dataset_cursor = int(cursors.get("prompt_dataset", self.prompt_dataset_cursor))
+ self.rollout_dataset_cursor = int(cursors.get("offline_rollout_dataset", self.rollout_dataset_cursor))
+ def _sampling_config_payload(self) -> Dict[str, Any]:
return {
- 'chosen_ids': chosen_ids,
- 'rejected_ids': rejected_ids,
- 'chosen_length': len(chosen_ids),
- 'rejected_length': len(rejected_ids),
+ "algorithm": self.algorithm,
+ "beta": self.beta,
+ "clip_epsilon": self.clip_epsilon,
+ "rollout_batch_size": self.rollout_batch_size,
+ "minibatch_reuse_steps": self.minibatch_reuse_steps,
+ "gamma": self.gamma,
+ "gae_lambda": self.gae_lambda,
+ "normalize_advantages": self.normalize_advantages,
+ "value_learning_rate": self.value_learning_rate,
+ "kl_target": self.kl_target,
+ "kl_penalty_mode": self.kl_penalty_mode,
+ "reward_source": self.reward_source,
+ "reward_normalization": self.reward_normalization,
+ "mask_truncated_completions": self.mask_truncated_completions,
+ "entropy_bonus": self.entropy_bonus,
+ "advantage_estimator": self.advantage_estimator,
+ "temperature": self.temperature,
+ "num_generations": self.num_generations,
+ "max_completion_length": self.max_completion_length,
+ "max_seq_length": self.max_seq_length,
+ "generation_batch_size": self.generation_batch_size,
+ "score_chunk_size": self.score_chunk_size,
}
- def _pad_to_length(self, ids: List[int], length: int, pad_id: int = 0) -> List[int]:
- if len(ids) >= length:
- return ids[:length]
- return ids + [pad_id] * (length - len(ids))
+ def _prepare_prompt_samples(self) -> None:
+ self.prompt_samples, self.rollout_samples, self.prepared_dataset_mode = _prepare_on_policy_samples(
+ self.train_dataset,
+ self.tokenizer,
+ self.config,
+ )
+ if not self.prompt_samples and not self.rollout_samples:
+ raise ValueError("PPOTrainer requires prompt or rollout samples.")
+
+ def _prepare_eval_datasets(self) -> None:
+ self.eval_prompt_samples = []
+ self.eval_rollout_samples = []
+ self.eval_preference_samples = []
+ if self.eval_dataset is not None:
+ self.eval_prompt_samples, self.eval_rollout_samples, _ = _prepare_on_policy_samples(
+ self.eval_dataset,
+ self.tokenizer,
+ self.config,
+ )
+ if self.eval_preference_dataset is not None:
+ prepared = prepare_rl_dataset(
+ self.eval_preference_dataset,
+ mode="preference",
+ tokenizer=self.tokenizer,
+ chat_template=getattr(self.config, "chat_template", None),
+ )
+ for sample_index, sample in enumerate(prepared):
+ prompt = sample.get("prompt", "")
+ prompt_ids = _encode_text(self.tokenizer, prompt)
+ chosen_ids = _encode_text(self.tokenizer, sample["chosen"], add_special_tokens=False)
+ rejected_ids = _encode_text(self.tokenizer, sample["rejected"], add_special_tokens=False)
+ self.eval_preference_samples.append(
+ {
+ "sample_index": sample_index,
+ "prompt_ids": prompt_ids,
+ "chosen_ids": prompt_ids + chosen_ids,
+ "rejected_ids": prompt_ids + rejected_ids,
+ "prompt_length": len(prompt_ids),
+ "chosen_completion_length": len(chosen_ids),
+ "rejected_completion_length": len(rejected_ids),
+ }
+ )
+
+ def _resolve_reward_evaluator(self) -> Any:
+ evaluator = _resolve_reward_evaluator(
+ self.reward_model,
+ self.reward_fn,
+ self.reward_sources,
+ )
+ if evaluator is not None:
+ return evaluator
+ return lambda response, context: len(response.split()) / 100.0
+
+ def _next_prompt_batch(self) -> List[Dict[str, Any]]:
+ batch, self.prompt_dataset_cursor = _next_cursor_batch(
+ self.prompt_samples,
+ self.rollout_batch_size,
+ self.prompt_dataset_cursor,
+ self.algorithm,
+ )
+ self.dataset_cursor = self.prompt_dataset_cursor
+ return batch
+
+ def _next_offline_rollout_batch(self) -> List[Dict[str, Any]]:
+ batch, self.rollout_dataset_cursor = _next_cursor_batch(
+ self.rollout_samples,
+ self.rollout_batch_size * self.num_generations,
+ self.rollout_dataset_cursor,
+ self.algorithm,
+ )
+ self.dataset_cursor = self.rollout_dataset_cursor
+ return batch
- def train(self):
- """Train using ORPO with proper loss."""
- print("=" * 70)
- print("Starting ORPO Training")
- print("=" * 70)
+ def _collect_fixed_rollout_batch(
+ self,
+ samples: List[Dict[str, Any]],
+ cache_key: Optional[str] = None,
+ ) -> RolloutBatch:
+ cached_reference_logprobs = None
+ if cache_key is not None:
+ cached_reference_logprobs = _ensure_fixed_rollout_reference_cache(self, cache_key, samples)
+ return _collect_fixed_rollout_batch(
+ self,
+ samples,
+ cached_reference_logprobs=cached_reference_logprobs,
+ )
- if self.use_native:
- return self._train_native()
+ def _collect_rollout_batch(
+ self,
+ prompt_samples: List[Dict[str, Any]],
+ num_generations: Optional[int] = None,
+ cache_key: Optional[str] = None,
+ ) -> RolloutBatch:
+ generations = self.num_generations if num_generations is None else num_generations
+ reward_evaluator = self._resolve_reward_evaluator()
+ if prompt_samples and "completion" in prompt_samples[0]:
+ rollout_batch = self._collect_fixed_rollout_batch(prompt_samples, cache_key=cache_key)
else:
- return self._train_subprocess()
-
- def _train_native(self):
- """Train with native ORPO loss."""
- print("\n[Using Native ORPO Training with Proper Loss]")
+ rollout_batch = collect_rollouts(
+ _actual_model(self.model),
+ self.tokenizer,
+ prompt_samples=prompt_samples,
+ sampling_config={
+ "num_generations": generations,
+ "temperature": self.temperature,
+ "max_completion_length": self.max_completion_length,
+ "max_seq_length": self.max_seq_length,
+ "generation_batch_size": self.generation_batch_size,
+ },
+ reward_evaluator=None,
+ collect_sample_stats=self.entropy_bonus != 0.0,
+ )
+ if self.reward_source == "offline":
+ raise ValueError("reward_source='offline' requires rollout samples with completion/reward fields.")
+ reward_batch = evaluate_rewards(rollout_batch, reward_evaluator)
+ rollout_batch.rewards = reward_batch.scalar_rewards
+ if prompt_samples and "completion" in prompt_samples[0]:
+ if self.reward_source == "online":
+ reward_batch = evaluate_rewards(rollout_batch, reward_evaluator)
+ rollout_batch.rewards = reward_batch.scalar_rewards
+ elif self.reward_source == "hybrid":
+ reward_batch = evaluate_rewards(rollout_batch, reward_evaluator)
+ rollout_batch.rewards = rollout_batch.rewards + reward_batch.scalar_rewards
+ if rollout_batch.rewards is None:
+ rollout_batch.rewards = mx.zeros((len(rollout_batch.prompt_ids),), dtype=mx.float32)
+ if self.entropy_bonus and rollout_batch.token_entropies is not None:
+ entropy_bonus = mx.mean(rollout_batch.token_entropies, axis=-1) * self.entropy_bonus
+ rollout_batch.rewards = rollout_batch.rewards + entropy_bonus.astype(mx.float32)
+ if self.reward_normalization != "none":
+ rollout_batch.rewards = _normalize_reward_values(
+ rollout_batch.rewards,
+ rollout_batch.prompt_group_indices,
+ self.reward_normalization,
+ )
+ rollout_batch = score_rollout_references(
+ _actual_model(self.reference_policy.model),
+ rollout_batch,
+ batch_size=_rollout_score_batch_size(self, num_generations=generations),
+ token_budget=self.score_chunk_size,
+ )
+ if self.mask_truncated_completions:
+ rollout_batch = _apply_truncation_mask_to_rollout(rollout_batch)
+ rollout_batch = predict_rollout_values(
+ self.value_model,
+ rollout_batch,
+ batch_size=_rollout_score_batch_size(self, num_generations=generations),
+ token_budget=self.score_chunk_size,
+ )
+ advantage_mode = self.advantage_estimator
+ rollout_values = rollout_batch.value_predictions if advantage_mode == "gae" else None
+ returns, advantages = compute_returns_and_advantages(
+ rewards=rollout_batch.rewards,
+ values=rollout_values,
+ prompt_group_indices=rollout_batch.prompt_group_indices,
+ mode=advantage_mode,
+ gamma=self.gamma,
+ gae_lambda=self.gae_lambda,
+ normalize=self.normalize_advantages,
+ )
+ rollout_batch.returns = returns
+ rollout_batch.advantages = advantages
+ rollout_batch.policy_eval.returns = returns
+ rollout_batch.policy_eval.advantages = advantages
+ rollout_batch.policy_eval.reference_logprobs = rollout_batch.reference_logprobs
+ rollout_batch.policy_eval.value_predictions = rollout_batch.value_predictions
+ return rollout_batch
+
+ def _evaluate_rollout_metrics(self) -> Dict[str, Any]:
+ prompt_limit = max(1, self.eval_num_batches or 1) * self.rollout_batch_size
+ rollout_limit = max(1, self.eval_num_batches or 1) * self.rollout_batch_size * self.num_generations
+ if self.eval_rollout_samples:
+ rollout_batch = self._collect_rollout_batch(
+ self.eval_rollout_samples[:rollout_limit],
+ cache_key=f"{self.algorithm}.eval_rollout_reference_logprobs",
+ )
+ elif self.eval_prompt_samples:
+ rollout_batch = self._collect_rollout_batch(
+ self.eval_prompt_samples[:prompt_limit],
+ num_generations=self.eval_num_generations,
+ )
+ else:
+ return {}
+ effective_beta = self._effective_kl_beta(rollout_batch)
+ policy_loss, _ = ppo_sequence_loss(
+ model=_actual_model(self.model),
+ batch=rollout_batch.policy_eval,
+ beta=effective_beta,
+ clip_epsilon=self.clip_epsilon,
+ temperature=self.temperature,
+ )
+ value_loss, _ = value_model_regression_loss(
+ self.value_model,
+ input_ids=rollout_batch.policy_eval.input_ids,
+ sequence_lengths=rollout_batch.policy_eval.sequence_lengths,
+ targets=rollout_batch.policy_eval.returns,
+ prompt_lengths=rollout_batch.policy_eval.prompt_lengths,
+ completion_lengths=rollout_batch.policy_eval.completion_lengths,
+ )
+ return summarize_rollout_metrics(
+ rollout_batch,
+ policy_loss=float(policy_loss.item()),
+ value_loss=float(value_loss.item()),
+ )
- if hasattr(self.model, '_apply_lora') and not getattr(self.model, '_lora_applied', False):
- self.model._apply_lora()
+ def _evaluate_preference_metrics(self) -> Dict[str, Any]:
+ if not self.eval_preference_samples:
+ return {}
+ limit = max(1, self.eval_num_batches or 1) * self.batch_size
+ samples = self.eval_preference_samples[:limit]
+ chosen_batch = make_policy_eval_batch(
+ [sample["chosen_ids"] for sample in samples],
+ pad_id=_pad_token_id(self.tokenizer),
+ mode="completion",
+ prompt_lengths=[sample["prompt_length"] for sample in samples],
+ completion_lengths=[sample["chosen_completion_length"] for sample in samples],
+ sample_indices=mx.array([sample["sample_index"] for sample in samples]),
+ )
+ rejected_batch = make_policy_eval_batch(
+ [sample["rejected_ids"] for sample in samples],
+ pad_id=_pad_token_id(self.tokenizer),
+ mode="completion",
+ prompt_lengths=[sample["prompt_length"] for sample in samples],
+ completion_lengths=[sample["rejected_completion_length"] for sample in samples],
+ sample_indices=mx.array([sample["sample_index"] for sample in samples]),
+ )
+ chosen_scores = score_policy_in_chunks(
+ _actual_model(self.model),
+ chosen_batch,
+ batch_size=max(1, self.batch_size),
+ token_budget=self.score_chunk_size,
+ mode="completion",
+ ).summed_logprobs
+ rejected_scores = score_policy_in_chunks(
+ _actual_model(self.model),
+ rejected_batch,
+ batch_size=max(1, self.batch_size),
+ token_budget=self.score_chunk_size,
+ mode="completion",
+ ).summed_logprobs
+ return {
+ "preference_win_rate": float(mx.mean((chosen_scores > rejected_scores).astype(mx.float32)).item())
+ }
- # Prepare data
- tokenized_data = []
- for sample in self.train_dataset:
- if 'prompt' in sample and 'chosen' in sample and 'rejected' in sample:
- tokenized_data.append(self._tokenize_preference_pair(sample))
- print(f"✓ Prepared {len(tokenized_data)} preference pairs")
+ def evaluate(self) -> Dict[str, Any]:
+ self._prepare_eval_datasets()
+ if self.reference_policy is None:
+ self._ensure_reference_policy()
+ with self._preserve_rng_state():
+ mx.random.seed(int(self.seed) + 100000 + int(self.global_step))
+ metrics: Dict[str, Any] = {}
+ metrics.update(self._evaluate_rollout_metrics())
+ metrics.update(self._evaluate_preference_metrics())
+ if not metrics:
+ return {}
+ return self._record_metrics("eval", metrics)
+
+ def train(self, resume_from_checkpoint: Optional[str] = None):
+ if not self.use_native:
+ raise ValueError("PPOTrainer requires native MLX training support.")
+ return self._train_native(resume_from_checkpoint=resume_from_checkpoint)
+
+ def _train_native(self, resume_from_checkpoint: Optional[str] = None):
+ self._apply_lora_if_needed()
+ if hasattr(self.value_model.base_model, "_apply_lora") and not getattr(self.value_model.base_model, "_lora_applied", False):
+ self.value_model.base_model._apply_lora()
+ self._prepare_prompt_samples()
+ self._prepare_eval_datasets()
+ if resume_from_checkpoint is None:
+ self._seed_training_run()
+
+ policy_model = _actual_model(self.model)
+ policy_optimizer = self._optimizer_for_training(self.learning_rate)
+ value_optimizer = self._optimizer_for_training(self.value_learning_rate)
+ self.optimizer = policy_optimizer
+ self.optimizers = {"policy": policy_optimizer, "value": value_optimizer}
+
+ if resume_from_checkpoint is not None:
+ self.load_state(
+ checkpoint_dir=Path(resume_from_checkpoint),
+ optimizers=self.optimizers,
+ )
+ else:
+ self._ensure_reference_policy()
+ if self.rollout_samples:
+ _ensure_fixed_rollout_reference_cache(
+ self,
+ f"{self.algorithm}.train_rollout_reference_logprobs",
+ self.rollout_samples,
+ )
- actual_model = self.model.model if hasattr(self.model, 'model') else self.model
- lr_schedule = optim.cosine_decay(self.learning_rate, self.iters)
- optimizer = optim.AdamW(learning_rate=lr_schedule)
+ effective_beta = self.beta
- def loss_fn(model, batch_data):
- chosen_ids, rejected_ids, chosen_lengths, rejected_lengths = batch_data
- loss, _ = compute_orpo_loss(
- model, chosen_ids, rejected_ids, chosen_lengths, rejected_lengths, self.beta
+ def policy_loss_fn(model, batch):
+ loss, _ = ppo_sequence_loss(
+ model=model,
+ batch=batch,
+ beta=effective_beta,
+ clip_epsilon=self.clip_epsilon,
+ temperature=self.temperature,
)
return loss
- loss_and_grad = nn.value_and_grad(actual_model, loss_fn)
-
- total_loss = 0.0
- for step in range(self.iters):
- batch_idx = step % len(tokenized_data)
- sample = tokenized_data[batch_idx]
-
- max_len = max(sample['chosen_length'], sample['rejected_length'])
- pad_id = self.tokenizer.pad_token_id or 0
-
- chosen_ids = mx.array([self._pad_to_length(sample['chosen_ids'], max_len, pad_id)])
- rejected_ids = mx.array([self._pad_to_length(sample['rejected_ids'], max_len, pad_id)])
- chosen_lengths = mx.array([sample['chosen_length']])
- rejected_lengths = mx.array([sample['rejected_length']])
-
- loss, grads = loss_and_grad(actual_model, (chosen_ids, rejected_ids, chosen_lengths, rejected_lengths))
- optimizer.update(actual_model, grads)
- mx.eval(actual_model.parameters(), optimizer.state)
-
- total_loss += loss.item()
-
- if (step + 1) % self.logging_steps == 0:
- print(f" Step {step + 1}/{self.iters} | Loss: {total_loss / self.logging_steps:.4f}")
- total_loss = 0.0
-
- # Save adapters and config
- _save_adapters_and_config(self.model, self.adapter_path)
+ def value_loss_fn(_, batch):
+ loss, _ = value_model_regression_loss(
+ self.value_model,
+ input_ids=batch.input_ids,
+ sequence_lengths=batch.sequence_lengths,
+ targets=batch.returns,
+ prompt_lengths=batch.prompt_lengths,
+ completion_lengths=batch.completion_lengths,
+ )
+ return loss
- print("\n" + "=" * 70)
- print("ORPO Training Complete!")
- print("=" * 70)
- print(f" Adapters saved to: {self.adapter_path}")
- return {"status": "success", "adapter_path": str(self.adapter_path)}
+ policy_value_and_grad = nn.value_and_grad(policy_model, policy_loss_fn)
+ value_value_and_grad = nn.value_and_grad(self._value_train_target, value_loss_fn)
+ running_loss = 0.0
+ last_policy_loss = None
+ last_value_loss = None
+
+ while self.global_step < self.iters:
+ if self.reward_source == "offline" and self.rollout_samples:
+ prompt_samples = self._next_offline_rollout_batch()
+ rollout_batch = self._collect_rollout_batch(
+ prompt_samples,
+ cache_key=f"{self.algorithm}.train_rollout_reference_logprobs",
+ )
+ else:
+ prompt_samples = self._next_prompt_batch()
+ rollout_batch = self._collect_rollout_batch(prompt_samples)
+ self._last_rollout_batch = rollout_batch
+ effective_beta = self._effective_kl_beta(rollout_batch)
+
+ total_policy_loss = 0.0
+ total_value_loss = 0.0
+ update_count = 0
+ for _ in range(max(1, self.minibatch_reuse_steps)):
+ minibatches = assemble_minibatches(
+ rollout_batch.policy_eval,
+ minibatch_size=_rollout_score_batch_size(self),
+ shuffle=True,
+ mode="completion",
+ token_budget=self.score_chunk_size,
+ )
+ for minibatch in minibatches:
+ policy_loss, policy_grads = policy_value_and_grad(policy_model, minibatch)
+ policy_optimizer.update(policy_model, policy_grads)
+ mx.eval(policy_model.parameters(), policy_optimizer.state)
+
+ value_loss, value_grads = value_value_and_grad(self._value_train_target, minibatch)
+ value_optimizer.update(self._value_train_target, value_grads)
+ mx.eval(self._value_train_target.parameters(), value_optimizer.state)
+
+ total_policy_loss += float(policy_loss.item())
+ total_value_loss += float(value_loss.item())
+ update_count += 1
+
+ last_policy_loss = total_policy_loss / max(1, update_count)
+ last_value_loss = total_value_loss / max(1, update_count)
+ running_loss += last_policy_loss
+ self.global_step += 1
+ train_row = self._record_metrics(
+ "train",
+ summarize_rollout_metrics(
+ rollout_batch,
+ policy_loss=last_policy_loss,
+ value_loss=last_value_loss,
+ ),
+ )
- def _train_subprocess(self):
- """Fallback subprocess training."""
- warnings.warn("Using SFT approximation for ORPO.", UserWarning)
+ if self.global_step % self.logging_steps == 0:
+ print(f"PPO step {self.global_step}/{self.iters} | {self._format_metric_summary(train_row)}")
+ running_loss = 0.0
- train_file = self.output_dir / "train.jsonl"
- with open(train_file, 'w') as f:
- for sample in self.train_dataset:
- if 'prompt' in sample and 'chosen' in sample:
- messages = [
- {"role": "user", "content": sample['prompt']},
- {"role": "assistant", "content": sample['chosen']}
- ]
- f.write(json.dumps({"messages": messages}) + '\n')
+ if self.eval_steps and self.global_step % self.eval_steps == 0:
+ eval_row = self.evaluate()
+ if eval_row:
+ print(f"PPO eval | {self._format_metric_summary(eval_row, namespace='eval')}")
- import shutil
- shutil.copy(train_file, self.output_dir / "valid.jsonl")
+ if self.global_step % self.save_steps == 0:
+ self.save_state(optimizers=self.optimizers)
- cmd = [
- "mlx_lm.lora", "--model", getattr(self.model, 'model_name', 'model'),
- "--train", "--data", str(self.output_dir), "--iters", str(self.iters),
- "--learning-rate", str(self.learning_rate), "--batch-size", str(self.batch_size),
- "--adapter-path", str(self.adapter_path),
- ]
- subprocess.run(cmd, check=True)
- return {"status": "success"}
+ self.save_state(optimizers=self.optimizers)
+ return {
+ "status": "success",
+ "adapter_path": str(self.adapter_path),
+ "global_step": self.global_step,
+ "final_loss": last_policy_loss,
+ "final_value_loss": last_value_loss,
+ }
-class GRPOTrainer:
- """
- Group Relative Policy Optimization Trainer.
-
- GRPO is the technique used by DeepSeek to train reasoning models like R1.
- It removes the need for a value model by using group statistics from
- multiple generations and custom reward functions.
-
- Key features:
- - No value model needed (uses group statistics)
- - Custom reward functions (for math, code verification, etc.)
- - Supports GRPO, Dr.GRPO, DAPO, BNPO variants
- - NOW WITH FULL MULTI-GENERATION IMPLEMENTATION!
-
- Example:
- >>> def math_reward(response, prompt):
- ... # Custom reward for math problems
- ... return 1.0 if "correct" in response.lower() else 0.0
- >>>
- >>> trainer = GRPOTrainer(
- ... model=model,
- ... train_dataset=math_dataset,
- ... tokenizer=tokenizer,
- ... reward_fn=math_reward,
- ... args=GRPOConfig(
- ... loss_type='grpo',
- ... num_generations=4,
- ... ),
- ... )
- >>> trainer.train()
- """
+class OnlineDPOTrainer(_RLTrainerBase):
+ algorithm = "online_dpo"
+ requires_reference_policy = True
def __init__(
self,
model: Any,
train_dataset: Any,
+ eval_dataset: Any = None,
+ eval_preference_dataset: Any = None,
tokenizer: Optional[Any] = None,
reward_fn: Optional[Callable] = None,
- args: Optional[GRPOConfig] = None,
+ reward_model: Optional[Any] = None,
+ ref_model: Optional[Any] = None,
+ args: Optional[OnlineDPOConfig] = None,
use_native: bool = True,
- **kwargs
+ **kwargs,
):
self.model = model
self.train_dataset = train_dataset
- self.tokenizer = tokenizer or getattr(model, 'tokenizer', None)
+ self.eval_dataset = eval_dataset
+ self.eval_preference_dataset = eval_preference_dataset
+ self.tokenizer = tokenizer or getattr(model, "tokenizer", None)
+ self.ref_model = ref_model
+ self.config = args or OnlineDPOConfig()
+ self.reward_model = reward_model or self.config.reward_model
+ self.reward_fn = reward_fn if reward_fn is not None else self.config.reward_fn
+ self.reward_sources = self.config.reward_sources
+ self.reward_source = self.config.reward_source
self.use_native = use_native and HAS_NATIVE_TRAINING
-
- if args is None:
- args = GRPOConfig()
-
- self.config = args
- self.loss_type = args.loss_type
- self.beta = args.beta
- self.num_generations = args.num_generations
- self.max_completion_length = args.max_completion_length
- self.reward_fn = reward_fn or args.reward_fn
- self.output_dir = Path(args.output_dir)
- self.learning_rate = args.learning_rate
- self.batch_size = args.per_device_train_batch_size
- self.max_steps = args.max_steps
- self.temperature = args.temperature
- self.logging_steps = args.logging_steps
- self.save_steps = args.save_steps
-
- if self.max_steps > 0:
- self.iters = self.max_steps
- else:
- dataset_size = len(train_dataset) if hasattr(train_dataset, '__len__') else 100
- self.iters = max(1, (dataset_size // self.batch_size) * args.num_train_epochs)
-
+ self.beta = self.config.beta
+ self.kl_target = self.config.kl_target
+ self.kl_penalty_mode = self.config.kl_penalty_mode
+ self.label_smoothing = self.config.label_smoothing
+ self.output_dir = Path(self.config.output_dir)
+ self.learning_rate = self.config.learning_rate
+ self.batch_size = self.config.per_device_train_batch_size
+ self.rollout_batch_size = self.config.rollout_batch_size or self.batch_size
+ self.max_steps = self.config.max_steps
+ self.max_seq_length = self.config.max_seq_length
+ self.max_completion_length = self.config.max_completion_length
+ self.num_generations = self.config.num_generations
+ self.temperature = self.config.temperature
+ self.reward_normalization = self.config.reward_normalization
+ self.mask_truncated_completions = self.config.mask_truncated_completions
+ self.minibatch_reuse_steps = self.config.minibatch_reuse_steps
+ self.entropy_bonus = self.config.entropy_bonus
+ self.eval_steps = self.config.eval_steps
+ self.eval_num_batches = self.config.eval_num_batches
+ self.eval_num_generations = self.config.eval_num_generations or self.num_generations
+ self.generation_batch_size = self.config.generation_batch_size
+ self.score_chunk_size = self.config.score_chunk_size
+ self.precompute_reference_scores = self.config.precompute_reference_scores
+ self.logging_steps = self.config.logging_steps
+ self.save_steps = self.config.save_steps
+ dataset_size = len(train_dataset) if hasattr(train_dataset, "__len__") else 100
+ self.iters = self.max_steps if self.max_steps > 0 else max(
+ 1, (dataset_size // max(1, self.batch_size)) * self.config.num_train_epochs
+ )
self.output_dir.mkdir(parents=True, exist_ok=True)
- self.adapter_path = self.output_dir / "adapters"
+ self.adapter_path = self.output_dir / "policy"
self.adapter_path.mkdir(parents=True, exist_ok=True)
+ self.prompt_samples: List[Dict[str, Any]] = []
+ self.rollout_samples: List[Dict[str, Any]] = []
+ self.eval_prompt_samples: List[Dict[str, Any]] = []
+ self.eval_rollout_samples: List[Dict[str, Any]] = []
+ self.eval_preference_samples: List[Dict[str, Any]] = []
+ self.prepared_dataset_mode: Optional[str] = None
+ self.prompt_dataset_cursor = 0
+ self.rollout_dataset_cursor = 0
+ self._last_rollout_batch: Optional[RolloutBatch] = None
+ self._init_native_state()
+ if self.num_generations < 2:
+ raise ValueError("OnlineDPOTrainer requires num_generations >= 2.")
+
+ def _prepare_prompt_samples(self) -> None:
+ self.prompt_samples, self.rollout_samples, self.prepared_dataset_mode = _prepare_on_policy_samples(
+ self.train_dataset,
+ self.tokenizer,
+ self.config,
+ )
+ if not self.prompt_samples and not self.rollout_samples:
+ raise ValueError("OnlineDPOTrainer requires prompt or rollout samples.")
- # Default reward function if none provided
- if self.reward_fn is None:
- self.reward_fn = lambda response, prompt: len(response.split()) / 100.0
-
- print(f"GRPOTrainer initialized:")
- print(f" Loss type: {self.loss_type}")
- print(f" Beta: {self.beta}")
- print(f" Num generations: {self.num_generations}")
- print(f" Custom reward fn: {'Yes' if reward_fn else 'Default (length-based)'}")
- print(f" Learning rate: {self.learning_rate}")
- print(f" Iterations: {self.iters}")
- print(f" Native GRPO: {self.use_native}")
-
- def train(self):
- """
- Train using GRPO with multi-generation sampling.
- """
- print("=" * 70)
- print(f"Starting GRPO Training (loss_type={self.loss_type})")
- print("=" * 70)
-
- if self.use_native:
- return self._train_native()
- else:
- return self._train_subprocess()
-
- def _train_native(self):
- """Train with native GRPO: multi-generation + reward + policy gradient."""
- print("\n[Using Native GRPO Training with Multi-Generation]")
-
- if hasattr(self.model, '_apply_lora') and not getattr(self.model, '_lora_applied', False):
- self.model._apply_lora()
-
- # Prepare prompts
- prompts = []
- for sample in self.train_dataset:
- if 'prompt' in sample:
- prompts.append(sample['prompt'])
- elif 'question' in sample:
- prompts.append(sample['question'])
- print(f"✓ Prepared {len(prompts)} prompts")
-
- actual_model = self.model.model if hasattr(self.model, 'model') else self.model
- lr_schedule = optim.cosine_decay(self.learning_rate, self.iters)
- optimizer = optim.AdamW(learning_rate=lr_schedule)
+ def _trainer_cursor_state(self) -> Dict[str, int]:
+ return {
+ "dataset": int(self.dataset_cursor),
+ "prompt_dataset": int(self.prompt_dataset_cursor),
+ "offline_rollout_dataset": int(self.rollout_dataset_cursor),
+ }
- print(f"\nStarting training for {self.iters} iterations...")
- print(f" Generating {self.num_generations} completions per prompt")
+ def _restore_trainer_cursors(self, cursors: Dict[str, Any]) -> None:
+ super()._restore_trainer_cursors(cursors)
+ self.prompt_dataset_cursor = int(cursors.get("prompt_dataset", self.prompt_dataset_cursor))
+ self.rollout_dataset_cursor = int(cursors.get("offline_rollout_dataset", self.rollout_dataset_cursor))
- total_loss = 0.0
- for step in range(self.iters):
- # Get prompt for this step
- prompt_idx = step % len(prompts)
- prompt = prompts[prompt_idx]
+ def _sampling_config_payload(self) -> Dict[str, Any]:
+ return {
+ "algorithm": self.algorithm,
+ "beta": self.beta,
+ "label_smoothing": self.label_smoothing,
+ "kl_target": self.kl_target,
+ "kl_penalty_mode": self.kl_penalty_mode,
+ "rollout_batch_size": self.rollout_batch_size,
+ "minibatch_reuse_steps": self.minibatch_reuse_steps,
+ "reward_source": self.reward_source,
+ "reward_normalization": self.reward_normalization,
+ "mask_truncated_completions": self.mask_truncated_completions,
+ "entropy_bonus": self.entropy_bonus,
+ "temperature": self.temperature,
+ "num_generations": self.num_generations,
+ "max_completion_length": self.max_completion_length,
+ "max_seq_length": self.max_seq_length,
+ "generation_batch_size": self.generation_batch_size,
+ "score_chunk_size": self.score_chunk_size,
+ }
- # Compute GRPO loss with multi-generation
- loss, n_gen = grpo_batch_loss(
- model=actual_model,
+ def _prepare_eval_datasets(self) -> None:
+ self.eval_prompt_samples = []
+ self.eval_rollout_samples = []
+ self.eval_preference_samples = []
+ if self.eval_dataset is not None:
+ self.eval_prompt_samples, self.eval_rollout_samples, _ = _prepare_on_policy_samples(
+ self.eval_dataset,
+ self.tokenizer,
+ self.config,
+ )
+ if self.eval_preference_dataset is not None:
+ prepared = prepare_rl_dataset(
+ self.eval_preference_dataset,
+ mode="preference",
tokenizer=self.tokenizer,
- prompts=[prompt],
- reward_fn=self.reward_fn,
- num_generations=self.num_generations,
- temperature=self.temperature,
- max_tokens=self.max_completion_length,
- beta=self.beta,
+ chat_template=getattr(self.config, "chat_template", None),
)
+ for sample_index, sample in enumerate(prepared):
+ prompt = sample.get("prompt", "")
+ prompt_ids = _encode_text(self.tokenizer, prompt)
+ chosen_ids = prompt_ids + _encode_text(self.tokenizer, sample["chosen"], add_special_tokens=False)
+ rejected_ids = prompt_ids + _encode_text(self.tokenizer, sample["rejected"], add_special_tokens=False)
+ self.eval_preference_samples.append(
+ {
+ "sample_index": sample_index,
+ "chosen_ids": chosen_ids,
+ "rejected_ids": rejected_ids,
+ "chosen_length": len(chosen_ids),
+ "rejected_length": len(rejected_ids),
+ }
+ )
+
+ def _resolve_reward_evaluator(self) -> Any:
+ evaluator = _resolve_reward_evaluator(
+ self.reward_model,
+ self.reward_fn,
+ self.reward_sources,
+ )
+ if evaluator is not None:
+ return evaluator
+ return lambda response, context: len(response.split()) / 100.0
+
+ def _next_prompt_batch(self) -> List[Dict[str, Any]]:
+ batch, self.prompt_dataset_cursor = _next_cursor_batch(
+ self.prompt_samples,
+ self.rollout_batch_size,
+ self.prompt_dataset_cursor,
+ self.algorithm,
+ )
+ self.dataset_cursor = self.prompt_dataset_cursor
+ return batch
+
+ def _next_offline_rollout_batch(self) -> List[Dict[str, Any]]:
+ batch, self.rollout_dataset_cursor = _next_cursor_batch(
+ self.rollout_samples,
+ self.rollout_batch_size * self.num_generations,
+ self.rollout_dataset_cursor,
+ self.algorithm,
+ )
+ self.dataset_cursor = self.rollout_dataset_cursor
+ return batch
+
+ def _collect_fixed_rollout_batch(
+ self,
+ samples: List[Dict[str, Any]],
+ cache_key: Optional[str] = None,
+ ) -> RolloutBatch:
+ cached_reference_logprobs = None
+ if cache_key is not None:
+ cached_reference_logprobs = _ensure_fixed_rollout_reference_cache(self, cache_key, samples)
+ return _collect_fixed_rollout_batch(
+ self,
+ samples,
+ cached_reference_logprobs=cached_reference_logprobs,
+ )
- # Manual backward pass since grpo_batch_loss generates internally
- # For a proper implementation, we'd need to track gradients through generation
- # This is a simplified version that uses the loss for logging
- mx.eval(loss)
- total_loss += loss.item()
+ def _build_online_preference_batch(self, rollout_batch: RolloutBatch) -> Optional[PreferenceBatch]:
+ rankings = rank_grouped_rollouts(rollout_batch)
+ chosen_sequences: List[List[int]] = []
+ rejected_sequences: List[List[int]] = []
+ sample_indices: List[int] = []
+ for ranking in rankings:
+ if ranking["all_tied"]:
+ continue
+ best = ranking["best_position"]
+ worst = ranking["worst_position"]
+ chosen_sequences.append(rollout_batch.prompt_ids[best] + rollout_batch.completion_ids[best])
+ rejected_sequences.append(rollout_batch.prompt_ids[worst] + rollout_batch.completion_ids[worst])
+ sample_indices.append(int(rollout_batch.sample_indices[best].item()))
+ if not chosen_sequences:
+ return None
+
+ preference_batch = make_preference_batch(
+ chosen_sequences=chosen_sequences,
+ rejected_sequences=rejected_sequences,
+ pad_id=_pad_token_id(self.tokenizer),
+ sample_indices=sample_indices,
+ )
+ reference_model = _actual_model(self.reference_policy.model)
+ preference_batch.chosen_reference_logprobs = score_policy_in_chunks(
+ reference_model,
+ preference_batch.chosen,
+ batch_size=max(1, self.batch_size),
+ token_budget=self.score_chunk_size,
+ mode="sequence",
+ ).summed_logprobs
+ preference_batch.rejected_reference_logprobs = score_policy_in_chunks(
+ reference_model,
+ preference_batch.rejected,
+ batch_size=max(1, self.batch_size),
+ token_budget=self.score_chunk_size,
+ mode="sequence",
+ ).summed_logprobs
+ preference_batch.chosen.reference_logprobs = preference_batch.chosen_reference_logprobs
+ preference_batch.rejected.reference_logprobs = preference_batch.rejected_reference_logprobs
+ return preference_batch
+
+ def _collect_rollout_batch(
+ self,
+ prompt_samples: List[Dict[str, Any]],
+ num_generations: Optional[int] = None,
+ cache_key: Optional[str] = None,
+ ) -> RolloutBatch:
+ generations = self.num_generations if num_generations is None else num_generations
+ if prompt_samples and "completion" in prompt_samples[0]:
+ rollout_batch = self._collect_fixed_rollout_batch(prompt_samples, cache_key=cache_key)
+ else:
+ rollout_batch = collect_rollouts(
+ _actual_model(self.model),
+ self.tokenizer,
+ prompt_samples=prompt_samples,
+ sampling_config={
+ "num_generations": generations,
+ "temperature": self.temperature,
+ "max_completion_length": self.max_completion_length,
+ "max_seq_length": self.max_seq_length,
+ "generation_batch_size": self.generation_batch_size,
+ },
+ reward_evaluator=None,
+ collect_sample_stats=self.entropy_bonus != 0.0,
+ )
+ if self.reward_source == "offline":
+ raise ValueError("reward_source='offline' requires rollout samples with completion/reward fields.")
+ reward_batch = evaluate_rewards(rollout_batch, self._resolve_reward_evaluator())
+ rollout_batch.rewards = reward_batch.scalar_rewards
+ if prompt_samples and "completion" in prompt_samples[0]:
+ if self.reward_source == "online":
+ reward_batch = evaluate_rewards(rollout_batch, self._resolve_reward_evaluator())
+ rollout_batch.rewards = reward_batch.scalar_rewards
+ elif self.reward_source == "hybrid":
+ reward_batch = evaluate_rewards(rollout_batch, self._resolve_reward_evaluator())
+ rollout_batch.rewards = rollout_batch.rewards + reward_batch.scalar_rewards
+ if self.entropy_bonus and rollout_batch.token_entropies is not None:
+ entropy_bonus = mx.mean(rollout_batch.token_entropies, axis=-1) * self.entropy_bonus
+ rollout_batch.rewards = rollout_batch.rewards + entropy_bonus.astype(mx.float32)
+ if self.reward_normalization != "none":
+ rollout_batch.rewards = _normalize_reward_values(
+ rollout_batch.rewards,
+ rollout_batch.prompt_group_indices,
+ self.reward_normalization,
+ )
+ if self.kl_penalty_mode != "none" or self.kl_target is not None:
+ rollout_batch = score_rollout_references(
+ _actual_model(self.reference_policy.model),
+ rollout_batch,
+ batch_size=_rollout_score_batch_size(self, num_generations=generations),
+ token_budget=self.score_chunk_size,
+ )
+ if self.mask_truncated_completions:
+ rollout_batch = _apply_truncation_mask_to_rollout(rollout_batch)
+ return rollout_batch
+
+ def _evaluate_rollout_metrics(self) -> Dict[str, Any]:
+ prompt_limit = max(1, self.eval_num_batches or 1) * self.rollout_batch_size
+ rollout_limit = max(1, self.eval_num_batches or 1) * self.rollout_batch_size * self.num_generations
+ if self.eval_rollout_samples:
+ rollout_batch = self._collect_rollout_batch(self.eval_rollout_samples[:rollout_limit])
+ elif self.eval_prompt_samples:
+ rollout_batch = self._collect_rollout_batch(
+ self.eval_prompt_samples[:prompt_limit],
+ num_generations=self.eval_num_generations,
+ )
+ else:
+ return {}
+ preference_batch = self._build_online_preference_batch(rollout_batch)
+ metrics = summarize_rollout_metrics(rollout_batch)
+ if preference_batch is not None:
+ effective_beta = self._effective_kl_beta(rollout_batch)
+ loss, _ = compute_dpo_loss(
+ model=_actual_model(self.model),
+ chosen_ids=preference_batch.chosen.input_ids,
+ rejected_ids=preference_batch.rejected.input_ids,
+ chosen_lengths=preference_batch.chosen.sequence_lengths,
+ rejected_lengths=preference_batch.rejected.sequence_lengths,
+ beta=effective_beta,
+ reference_chosen_logprobs=preference_batch.chosen.reference_logprobs,
+ reference_rejected_logprobs=preference_batch.rejected.reference_logprobs,
+ label_smoothing=self.label_smoothing,
+ )
+ metrics["policy_loss"] = float(loss.item())
+ return metrics
+
+ def _evaluate_preference_metrics(self) -> Dict[str, Any]:
+ if not self.eval_preference_samples:
+ return {}
+ limit = max(1, self.eval_num_batches or 1) * self.batch_size
+ samples = self.eval_preference_samples[:limit]
+ preference_batch = make_preference_batch(
+ chosen_sequences=[sample["chosen_ids"] for sample in samples],
+ rejected_sequences=[sample["rejected_ids"] for sample in samples],
+ pad_id=_pad_token_id(self.tokenizer),
+ sample_indices=[sample["sample_index"] for sample in samples],
+ )
+ chosen_scores = score_policy_in_chunks(
+ _actual_model(self.model),
+ preference_batch.chosen,
+ batch_size=max(1, self.batch_size),
+ token_budget=self.score_chunk_size,
+ mode="sequence",
+ ).summed_logprobs
+ rejected_scores = score_policy_in_chunks(
+ _actual_model(self.model),
+ preference_batch.rejected,
+ batch_size=max(1, self.batch_size),
+ token_budget=self.score_chunk_size,
+ mode="sequence",
+ ).summed_logprobs
+ metrics = {
+ "preference_win_rate": float(mx.mean((chosen_scores > rejected_scores).astype(mx.float32)).item())
+ }
+ chosen_reference, rejected_reference = _ensure_preference_reference_cache(
+ self,
+ f"{self.algorithm}.eval_preference_reference_logprobs",
+ samples,
+ )
+ if chosen_reference is None or rejected_reference is None:
+ reference_model = _actual_model(self.reference_policy.model)
+ chosen_reference = score_policy_in_chunks(
+ reference_model,
+ preference_batch.chosen,
+ batch_size=max(1, self.batch_size),
+ token_budget=self.score_chunk_size,
+ mode="sequence",
+ ).summed_logprobs
+ rejected_reference = score_policy_in_chunks(
+ reference_model,
+ preference_batch.rejected,
+ batch_size=max(1, self.batch_size),
+ token_budget=self.score_chunk_size,
+ mode="sequence",
+ ).summed_logprobs
+ loss, _ = compute_dpo_loss(
+ model=_actual_model(self.model),
+ chosen_ids=preference_batch.chosen.input_ids,
+ rejected_ids=preference_batch.rejected.input_ids,
+ chosen_lengths=preference_batch.chosen.sequence_lengths,
+ rejected_lengths=preference_batch.rejected.sequence_lengths,
+ beta=self.beta,
+ reference_chosen_logprobs=chosen_reference,
+ reference_rejected_logprobs=rejected_reference,
+ label_smoothing=self.label_smoothing,
+ )
+ metrics["policy_loss"] = float(loss.item())
+ return metrics
+
+ def evaluate(self) -> Dict[str, Any]:
+ self._prepare_eval_datasets()
+ if self.reference_policy is None:
+ self._ensure_reference_policy()
+ with self._preserve_rng_state():
+ mx.random.seed(int(self.seed) + 100000 + int(self.global_step))
+ metrics: Dict[str, Any] = {}
+ metrics.update(self._evaluate_rollout_metrics())
+ metrics.update(self._evaluate_preference_metrics())
+ if not metrics:
+ return {}
+ return self._record_metrics("eval", metrics)
+
+ def train(self, resume_from_checkpoint: Optional[str] = None):
+ if not self.use_native:
+ raise ValueError("OnlineDPOTrainer requires native MLX training support.")
+ return self._train_native(resume_from_checkpoint=resume_from_checkpoint)
+
+ def _train_native(self, resume_from_checkpoint: Optional[str] = None):
+ self._apply_lora_if_needed()
+ self._prepare_prompt_samples()
+ self._prepare_eval_datasets()
+ if resume_from_checkpoint is None:
+ self._seed_training_run()
+
+ actual_model = _actual_model(self.model)
+ optimizer = self._optimizer_for_training()
+ self.optimizer = optimizer
+ self.optimizers = {"policy": optimizer}
+
+ if resume_from_checkpoint is not None:
+ self.load_state(optimizer=optimizer, checkpoint_dir=Path(resume_from_checkpoint))
+ else:
+ self._ensure_reference_policy()
+ if self.rollout_samples:
+ _ensure_fixed_rollout_reference_cache(
+ self,
+ f"{self.algorithm}.train_rollout_reference_logprobs",
+ self.rollout_samples,
+ )
- if (step + 1) % self.logging_steps == 0:
- avg_loss = total_loss / self.logging_steps
- print(f" Step {step + 1}/{self.iters} | Loss: {avg_loss:.4f}")
- total_loss = 0.0
+ effective_beta = self.beta
- # Save adapters and config
- _save_adapters_and_config(self.model, self.adapter_path)
+ def loss_fn(model, batch):
+ loss, _ = compute_dpo_loss(
+ model=model,
+ chosen_ids=batch.chosen.input_ids,
+ rejected_ids=batch.rejected.input_ids,
+ chosen_lengths=batch.chosen.sequence_lengths,
+ rejected_lengths=batch.rejected.sequence_lengths,
+ beta=effective_beta,
+ reference_chosen_logprobs=batch.chosen.reference_logprobs,
+ reference_rejected_logprobs=batch.rejected.reference_logprobs,
+ label_smoothing=self.label_smoothing,
+ )
+ return loss
- print("\n" + "=" * 70)
- print("GRPO Training Complete!")
- print("=" * 70)
- print(f" Adapters saved to: {self.adapter_path}")
- print(f"Note: Full GRPO with gradient flow through generation requires")
- print(f" custom implementation. This version uses reward signals.")
- return {"status": "success", "adapter_path": str(self.adapter_path)}
+ value_and_grad = nn.value_and_grad(actual_model, loss_fn)
+ running_loss = 0.0
+ last_loss = None
+
+ while self.global_step < self.iters:
+ if self.reward_source == "offline" and self.rollout_samples:
+ prompt_samples = self._next_offline_rollout_batch()
+ rollout_batch = self._collect_rollout_batch(
+ prompt_samples,
+ cache_key=f"{self.algorithm}.train_rollout_reference_logprobs",
+ )
+ else:
+ prompt_samples = self._next_prompt_batch()
+ rollout_batch = self._collect_rollout_batch(prompt_samples)
+ self._last_rollout_batch = rollout_batch
+ effective_beta = self._effective_kl_beta(rollout_batch)
+ preference_batch = self._build_online_preference_batch(rollout_batch)
+ if preference_batch is None:
+ self.global_step += 1
+ train_row = self._record_metrics(
+ "train",
+ {**summarize_rollout_metrics(rollout_batch, policy_loss=0.0), "skipped_pairs": True},
+ )
+ if self.global_step % self.logging_steps == 0:
+ print(f"Online DPO step {self.global_step}/{self.iters} | {self._format_metric_summary(train_row)}")
+ continue
+
+ loss, grads = value_and_grad(actual_model, preference_batch)
+ optimizer.update(actual_model, grads)
+ mx.eval(actual_model.parameters(), optimizer.state)
- def _train_subprocess(self):
- """Fallback to SFT approximation."""
- warnings.warn(
- "Native GRPO not available. Using SFT on provided responses.",
- UserWarning
- )
+ last_loss = float(loss.item())
+ running_loss += last_loss
+ self.global_step += 1
+ train_row = self._record_metrics(
+ "train",
+ summarize_rollout_metrics(rollout_batch, policy_loss=last_loss),
+ )
- train_file = self.output_dir / "train.jsonl"
- with open(train_file, 'w') as f:
- for sample in self.train_dataset:
- if 'prompt' in sample:
- messages = [{"role": "user", "content": sample['prompt']}]
- if 'response' in sample or 'answer' in sample:
- response = sample.get('response', sample.get('answer', ''))
- messages.append({"role": "assistant", "content": response})
- f.write(json.dumps({"messages": messages}) + '\n')
+ if self.global_step % self.logging_steps == 0:
+ print(f"Online DPO step {self.global_step}/{self.iters} | {self._format_metric_summary(train_row)}")
+ running_loss = 0.0
- import shutil
- shutil.copy(train_file, self.output_dir / "valid.jsonl")
+ if self.eval_steps and self.global_step % self.eval_steps == 0:
+ eval_row = self.evaluate()
+ if eval_row:
+ print(f"Online DPO eval | {self._format_metric_summary(eval_row, namespace='eval')}")
- cmd = [
- "mlx_lm.lora", "--model", getattr(self.model, 'model_name', 'model'),
- "--train", "--data", str(self.output_dir), "--iters", str(self.iters),
- "--adapter-path", str(self.adapter_path),
- ]
- subprocess.run(cmd, check=True)
- return {"status": "success"}
+ if self.global_step % self.save_steps == 0:
+ self.save_state(optimizer=optimizer)
+ self.save_state(optimizer=optimizer)
+ return {
+ "status": "success",
+ "adapter_path": str(self.adapter_path),
+ "global_step": self.global_step,
+ "final_loss": last_loss,
+ }
-class KTOTrainer:
- """
- Kahneman-Tversky Optimization Trainer.
- KTO uses prospect theory for preference optimization,
- treating gains and losses asymmetrically.
- Now with proper KTO loss implementation!
- """
+class KTOTrainer(_RLTrainerBase):
+ algorithm = "kto"
+ requires_reference_policy = True
def __init__(
self,
@@ -959,106 +4135,169 @@ def __init__(
train_dataset: Any,
tokenizer: Optional[Any] = None,
beta: float = 0.1,
+ args: Optional[KTOConfig] = None,
+ ref_model: Optional[Any] = None,
+ reward_model: Optional[Any] = None,
+ value_model: Optional[Any] = None,
use_native: bool = True,
- **kwargs
+ **kwargs,
):
self.model = model
self.train_dataset = train_dataset
- self.tokenizer = tokenizer or getattr(model, 'tokenizer', None)
- self.beta = beta
+ self.tokenizer = tokenizer or getattr(model, "tokenizer", None)
+ self.config = args or KTOConfig(beta=beta, **kwargs)
+ self.beta = self.config.beta
+ self.ref_model = ref_model
+ self.reward_model = reward_model
+ self.value_model = value_model
self.use_native = use_native and HAS_NATIVE_TRAINING
- self.output_dir = Path(kwargs.get('output_dir', './kto_outputs'))
- self.learning_rate = kwargs.get('learning_rate', 5e-7)
- self.iters = kwargs.get('max_steps', 100)
- self.max_seq_length = kwargs.get('max_seq_length', 2048)
- self.logging_steps = kwargs.get('logging_steps', 10)
-
+ self.output_dir = Path(self.config.output_dir)
+ self.learning_rate = self.config.learning_rate
+ self.iters = self.config.max_steps
+ self.max_seq_length = self.config.max_seq_length
+ self.batch_size = self.config.per_device_train_batch_size
+ self.logging_steps = self.config.logging_steps
+ self.save_steps = self.config.save_steps
self.output_dir.mkdir(parents=True, exist_ok=True)
self.adapter_path = self.output_dir / "adapters"
self.adapter_path.mkdir(parents=True, exist_ok=True)
+ self.train_samples: List[Dict[str, Any]] = []
+ self._init_native_state()
+
+ def _prepare_training_samples(self) -> None:
+ self.train_samples = []
+ for sample_index, sample in enumerate(self.train_dataset):
+ if "text" not in sample or "label" not in sample:
+ continue
+ ids = self.tokenizer.encode(sample["text"])[: self.max_seq_length]
+ self.train_samples.append(
+ {
+ "sample_index": sample_index,
+ "ids": ids,
+ "length": len(ids),
+ "label": float(sample["label"]),
+ }
+ )
+ if not self.train_samples:
+ raise ValueError("KTOTrainer requires text/label samples.")
+
+ def _precompute_reference_cache(self) -> None:
+ self._ensure_reference_policy()
+ eval_batch = make_policy_eval_batch(
+ [sample["ids"] for sample in self.train_samples],
+ pad_id=_pad_token_id(self.tokenizer),
+ mode="sequence",
+ labels=mx.array([sample["label"] for sample in self.train_samples]),
+ sample_indices=mx.array([sample["sample_index"] for sample in self.train_samples]),
+ )
+ reference_logprobs = score_policy_in_chunks(
+ _actual_model(self.reference_policy.model),
+ eval_batch,
+ batch_size=max(1, self.batch_size),
+ mode="sequence",
+ ).summed_logprobs
+ reference_logprobs = mx.stop_gradient(reference_logprobs)
+ for idx, sample in enumerate(self.train_samples):
+ sample["reference_logprobs"] = reference_logprobs[idx]
+ self.cache_metadata = {
+ "type": "inline_kto_reference_logprobs",
+ "num_samples": len(self.train_samples),
+ }
- print(f"KTOTrainer initialized (beta={self.beta}, native={self.use_native})")
-
- def train(self):
- """Train using KTO with proper loss."""
- print("=" * 70)
- print("Starting KTO Training")
- print("=" * 70)
-
- if not self.use_native:
- warnings.warn("KTO requires native training. Using SFT approximation.", UserWarning)
- return {"status": "fallback"}
+ def _restore_reference_cache(self, flat_state: Dict[str, mx.array]) -> None:
+ if "kto.reference_logprobs" not in flat_state:
+ self._precompute_reference_cache()
+ return
+ reference_logprobs = flat_state["kto.reference_logprobs"]
+ if reference_logprobs.shape[0] != len(self.train_samples):
+ raise ValueError("Saved KTO cache does not match current dataset ordering.")
+ for idx, sample in enumerate(self.train_samples):
+ sample["reference_logprobs"] = reference_logprobs[idx]
+
+ def _build_batch(self, samples: List[Dict[str, Any]]) -> PolicyEvalBatch:
+ return make_policy_eval_batch(
+ [sample["ids"] for sample in samples],
+ pad_id=_pad_token_id(self.tokenizer),
+ mode="sequence",
+ labels=mx.array([sample["label"] for sample in samples]),
+ reference_logprobs=mx.array([sample["reference_logprobs"] for sample in samples]),
+ sample_indices=mx.array([sample["sample_index"] for sample in samples]),
+ )
- print("\n[Using Native KTO Training with Proper Loss]")
+ def _extra_state_arrays(self) -> Dict[str, mx.array]:
+ return {
+ "kto.reference_logprobs": mx.array(
+ [sample["reference_logprobs"] for sample in self.train_samples]
+ )
+ }
- if hasattr(self.model, '_apply_lora') and not getattr(self.model, '_lora_applied', False):
- self.model._apply_lora()
+ def train(self, resume_from_checkpoint: Optional[str] = None):
+ if self.use_native:
+ return self._train_native(resume_from_checkpoint=resume_from_checkpoint)
+ warnings.warn("KTO requires native training. Using SFT approximation.", UserWarning)
+ return {"status": "fallback"}
- actual_model = self.model.model if hasattr(self.model, 'model') else self.model
- lr_schedule = optim.cosine_decay(self.learning_rate, self.iters)
- optimizer = optim.AdamW(learning_rate=lr_schedule)
-
- # Prepare data - KTO expects samples with 'text' and 'label' (1=positive, 0=negative)
- tokenized_data = []
- for sample in self.train_dataset:
- if 'text' in sample and 'label' in sample:
- ids = self.tokenizer.encode(sample['text'])[:self.max_seq_length]
- tokenized_data.append({
- 'ids': ids,
- 'length': len(ids),
- 'label': float(sample['label']),
- })
-
- print(f"✓ Prepared {len(tokenized_data)} samples")
-
- def loss_fn(model, batch_data):
- input_ids, lengths, labels = batch_data
- loss, _ = compute_kto_loss(model, input_ids, lengths, labels, self.beta)
- return loss
+ def _train_native(self, resume_from_checkpoint: Optional[str] = None):
+ self._apply_lora_if_needed()
+ self._prepare_training_samples()
- loss_and_grad = nn.value_and_grad(actual_model, loss_fn)
+ actual_model = _actual_model(self.model)
+ optimizer = self._optimizer_for_training()
+ self.optimizer = optimizer
- total_loss = 0.0
- for step in range(self.iters):
- sample = tokenized_data[step % len(tokenized_data)]
- pad_id = self.tokenizer.pad_token_id or 0
+ if resume_from_checkpoint is not None:
+ flat_state = self.load_state(optimizer, Path(resume_from_checkpoint))
+ self._restore_reference_cache(flat_state)
+ else:
+ self._precompute_reference_cache()
- max_len = sample['length']
- ids_padded = sample['ids'] + [pad_id] * (max_len - len(sample['ids']))
+ def loss_fn(model, batch):
+ loss, _ = compute_kto_loss(
+ model=model,
+ input_ids=batch.input_ids,
+ lengths=batch.sequence_lengths,
+ labels=batch.labels,
+ beta=self.beta,
+ reference_logprobs=batch.reference_logprobs,
+ )
+ return loss
- input_ids = mx.array([ids_padded])
- lengths = mx.array([sample['length']])
- labels = mx.array([sample['label']])
+ value_and_grad = nn.value_and_grad(actual_model, loss_fn)
+ running_loss = 0.0
+ last_loss = None
- loss, grads = loss_and_grad(actual_model, (input_ids, lengths, labels))
+ while self.global_step < self.iters:
+ batch_samples = self._next_samples(self.train_samples)
+ batch = self._build_batch(batch_samples)
+ loss, grads = value_and_grad(actual_model, batch)
optimizer.update(actual_model, grads)
mx.eval(actual_model.parameters(), optimizer.state)
- total_loss += loss.item()
+ last_loss = loss.item()
+ running_loss += last_loss
+ self.global_step += 1
+ self._record_metric(loss=last_loss)
- if (step + 1) % self.logging_steps == 0:
- print(f" Step {step + 1}/{self.iters} | Loss: {total_loss / self.logging_steps:.4f}")
- total_loss = 0.0
+ if self.global_step % self.logging_steps == 0:
+ print(
+ f"KTO step {self.global_step}/{self.iters} | "
+ f"loss={running_loss / self.logging_steps:.4f}"
+ )
+ running_loss = 0.0
- # Save adapters and config
- _save_adapters_and_config(self.model, self.adapter_path)
+ if self.global_step % self.save_steps == 0:
+ self.save_state(optimizer, self._extra_state_arrays())
- print("\n" + "=" * 70)
- print("KTO Training Complete!")
- print("=" * 70)
- print(f" Adapters saved to: {self.adapter_path}")
- return {"status": "success", "adapter_path": str(self.adapter_path)}
+ self.save_state(optimizer, self._extra_state_arrays())
+ return {
+ "status": "success",
+ "adapter_path": str(self.adapter_path),
+ "global_step": self.global_step,
+ "final_loss": last_loss,
+ }
class SimPOTrainer:
- """
- Simple Preference Optimization Trainer.
-
- SimPO simplifies DPO by removing the reference model requirement.
- Uses length-normalized log probabilities as implicit rewards.
- Now with proper SimPO loss implementation!
- """
-
def __init__(
self,
model: Any,
@@ -1066,207 +4305,98 @@ def __init__(
tokenizer: Optional[Any] = None,
gamma: float = 0.5,
beta: float = 2.0,
+ args: Optional[SimPOConfig] = None,
use_native: bool = True,
- **kwargs
+ **kwargs,
):
self.model = model
self.train_dataset = train_dataset
- self.tokenizer = tokenizer or getattr(model, 'tokenizer', None)
- self.gamma = gamma
- self.beta = beta
+ self.tokenizer = tokenizer or getattr(model, "tokenizer", None)
+ self.config = args or SimPOConfig(gamma=gamma, beta=beta, **kwargs)
+ self.gamma = self.config.gamma
+ self.beta = self.config.beta
self.use_native = use_native and HAS_NATIVE_TRAINING
- self.output_dir = Path(kwargs.get('output_dir', './simpo_outputs'))
- self.learning_rate = kwargs.get('learning_rate', 5e-7)
- self.iters = kwargs.get('max_steps', 100)
- self.max_seq_length = kwargs.get('max_seq_length', 2048)
- self.logging_steps = kwargs.get('logging_steps', 10)
-
+ self.output_dir = Path(self.config.output_dir)
+ self.learning_rate = self.config.learning_rate
+ self.batch_size = self.config.per_device_train_batch_size
+ self.iters = self.config.max_steps
+ self.max_seq_length = self.config.max_seq_length
+ self.logging_steps = self.config.logging_steps
self.output_dir.mkdir(parents=True, exist_ok=True)
self.adapter_path = self.output_dir / "adapters"
self.adapter_path.mkdir(parents=True, exist_ok=True)
- print(f"SimPOTrainer initialized (gamma={gamma}, beta={beta}, native={self.use_native})")
-
- def _tokenize_pair(self, sample):
- prompt = sample.get('prompt', '')
- chosen = sample.get('chosen', '')
- rejected = sample.get('rejected', '')
-
- chosen_ids = self.tokenizer.encode(prompt + chosen)[:self.max_seq_length]
- rejected_ids = self.tokenizer.encode(prompt + rejected)[:self.max_seq_length]
-
+ def _tokenize_pair(self, sample: Dict[str, Any]) -> Dict[str, Any]:
+ prompt = sample.get("prompt", "")
+ chosen = sample.get("chosen", "")
+ rejected = sample.get("rejected", "")
+ chosen_ids = self.tokenizer.encode(prompt + chosen)[: self.max_seq_length]
+ rejected_ids = self.tokenizer.encode(prompt + rejected)[: self.max_seq_length]
return {
- 'chosen_ids': chosen_ids,
- 'rejected_ids': rejected_ids,
- 'chosen_length': len(chosen_ids),
- 'rejected_length': len(rejected_ids),
+ "chosen_ids": chosen_ids,
+ "rejected_ids": rejected_ids,
+ "chosen_length": len(chosen_ids),
+ "rejected_length": len(rejected_ids),
}
- def _pad(self, ids, length, pad_id=0):
- return ids + [pad_id] * (length - len(ids)) if len(ids) < length else ids[:length]
-
def train(self):
- """Train using SimPO with proper loss."""
- print("=" * 70)
- print("Starting SimPO Training")
- print("=" * 70)
-
if not self.use_native:
warnings.warn("SimPO requires native training. Using SFT approximation.", UserWarning)
return {"status": "fallback"}
- print("\n[Using Native SimPO Training with Proper Loss]")
-
- if hasattr(self.model, '_apply_lora') and not getattr(self.model, '_lora_applied', False):
+ if hasattr(self.model, "_apply_lora") and not getattr(self.model, "_lora_applied", False):
self.model._apply_lora()
- tokenized_data = []
- for sample in self.train_dataset:
- if 'prompt' in sample and 'chosen' in sample and 'rejected' in sample:
- tokenized_data.append(self._tokenize_pair(sample))
- print(f"✓ Prepared {len(tokenized_data)} preference pairs")
-
- actual_model = self.model.model if hasattr(self.model, 'model') else self.model
- lr_schedule = optim.cosine_decay(self.learning_rate, self.iters)
- optimizer = optim.AdamW(learning_rate=lr_schedule)
+ prepared = prepare_rl_dataset(
+ self.train_dataset,
+ mode=self.config.dataset_mode or "preference",
+ tokenizer=self.tokenizer,
+ chat_template=getattr(self.config, "chat_template", None),
+ )
+ tokenized_data = [self._tokenize_pair(sample) for sample in prepared]
+ actual_model = _actual_model(self.model)
+ optimizer = optim.AdamW(learning_rate=optim.cosine_decay(self.learning_rate, self.iters))
+ pad_id = _pad_token_id(self.tokenizer)
- def loss_fn(model, batch_data):
- chosen_ids, rejected_ids, chosen_lengths, rejected_lengths = batch_data
+ def loss_fn(model, batch):
+ chosen_ids, rejected_ids, chosen_lengths, rejected_lengths = batch
loss, _ = compute_simpo_loss(
- model, chosen_ids, rejected_ids, chosen_lengths, rejected_lengths,
- self.beta, self.gamma
+ model,
+ chosen_ids,
+ rejected_ids,
+ chosen_lengths,
+ rejected_lengths,
+ self.beta,
+ self.gamma,
)
return loss
- loss_and_grad = nn.value_and_grad(actual_model, loss_fn)
+ value_and_grad = nn.value_and_grad(actual_model, loss_fn)
+ last_loss = None
- total_loss = 0.0
for step in range(self.iters):
- sample = tokenized_data[step % len(tokenized_data)]
- max_len = max(sample['chosen_length'], sample['rejected_length'])
- pad_id = self.tokenizer.pad_token_id or 0
-
- chosen_ids = mx.array([self._pad(sample['chosen_ids'], max_len, pad_id)])
- rejected_ids = mx.array([self._pad(sample['rejected_ids'], max_len, pad_id)])
- chosen_lengths = mx.array([sample['chosen_length']])
- rejected_lengths = mx.array([sample['rejected_length']])
-
- loss, grads = loss_and_grad(actual_model, (chosen_ids, rejected_ids, chosen_lengths, rejected_lengths))
+ start = (step * max(1, self.batch_size)) % len(tokenized_data)
+ samples = tokenized_data[start:start + max(1, self.batch_size)]
+ if len(samples) < max(1, self.batch_size):
+ samples += tokenized_data[: max(1, self.batch_size) - len(samples)]
+ chosen_ids, chosen_lengths = _pad_sequences([sample["chosen_ids"] for sample in samples], pad_id)
+ rejected_ids, rejected_lengths = _pad_sequences([sample["rejected_ids"] for sample in samples], pad_id)
+ loss, grads = value_and_grad(actual_model, (chosen_ids, rejected_ids, chosen_lengths, rejected_lengths))
optimizer.update(actual_model, grads)
mx.eval(actual_model.parameters(), optimizer.state)
+ last_loss = loss.item()
- total_loss += loss.item()
-
- if (step + 1) % self.logging_steps == 0:
- print(f" Step {step + 1}/{self.iters} | Loss: {total_loss / self.logging_steps:.4f}")
- total_loss = 0.0
-
- # Save adapters and config
_save_adapters_and_config(self.model, self.adapter_path)
-
- print("\n" + "=" * 70)
- print("SimPO Training Complete!")
- print("=" * 70)
- print(f" Adapters saved to: {self.adapter_path}")
- return {"status": "success", "adapter_path": str(self.adapter_path)}
+ return {"status": "success", "adapter_path": str(self.adapter_path), "final_loss": last_loss}
-# Utility functions for preference data
-
def prepare_preference_dataset(
dataset: Any,
tokenizer: Any,
format_type: str = "dpo",
-) -> List[Dict]:
- """
- Prepare dataset for preference-based training (DPO, ORPO, etc.).
+) -> List[Dict[str, Any]]:
+ return public_prepare_preference_dataset(dataset, tokenizer=tokenizer, format_type=format_type)
- Args:
- dataset: HuggingFace dataset with preference pairs
- tokenizer: Tokenizer for formatting
- format_type: 'dpo', 'orpo', or 'grpo'
-
- Returns:
- Formatted dataset ready for training
-
- Example:
- >>> from datasets import load_dataset
- >>> dataset = load_dataset("Anthropic/hh-rlhf")
- >>> formatted = prepare_preference_dataset(dataset, tokenizer, "dpo")
- """
-
- formatted_data = []
-
- for sample in dataset:
- if format_type in ["dpo", "orpo"]:
- # Expect chosen/rejected format
- if 'chosen' in sample and 'rejected' in sample:
- formatted_data.append({
- "prompt": sample.get('prompt', ''),
- "chosen": sample['chosen'],
- "rejected": sample['rejected'],
- })
- elif format_type == "grpo":
- # Expect prompt + optional ground truth
- formatted_data.append({
- "prompt": sample.get('prompt', sample.get('question', '')),
- "answer": sample.get('answer', sample.get('response', '')),
- })
-
- return formatted_data
-
-
-def create_reward_function(reward_type: str = "simple") -> Callable:
- """
- Create a reward function for GRPO training.
-
- Args:
- reward_type: Type of reward function
- - 'simple': Binary correct/incorrect
- - 'math': Extract and compare numerical answers
- - 'code': Execute and verify code output
- - 'length': Reward based on response length
-
- Returns:
- Reward function callable
-
- Example:
- >>> reward_fn = create_reward_function('math')
- >>> trainer = GRPOTrainer(..., reward_fn=reward_fn)
- """
-
- if reward_type == "simple":
- def simple_reward(response: str, ground_truth: str) -> float:
- return 1.0 if ground_truth.lower() in response.lower() else 0.0
- return simple_reward
-
- elif reward_type == "math":
- def math_reward(response: str, ground_truth: str) -> float:
- import re
- # Extract numbers from response
- numbers = re.findall(r'-?\d+\.?\d*', response)
- target = re.findall(r'-?\d+\.?\d*', ground_truth)
- if numbers and target:
- try:
- return 1.0 if float(numbers[-1]) == float(target[-1]) else 0.0
- except:
- return 0.0
- return 0.0
- return math_reward
-
- elif reward_type == "length":
- def length_reward(response: str, _: str) -> float:
- # Reward longer, more detailed responses (up to a point)
- length = len(response.split())
- if length < 10:
- return 0.2
- elif length < 50:
- return 0.5
- elif length < 200:
- return 1.0
- else:
- return 0.8 # Penalize very long responses
- return length_reward
- else:
- raise ValueError(f"Unknown reward type: {reward_type}")
+def create_reward_function(reward_type: Any = "simple", *, rewards: Optional[List[Any]] = None) -> Callable:
+ return public_create_reward_function(reward_type, rewards=rewards)
diff --git a/mlx_tune/trainer.py b/mlx_tune/trainer.py
index ee346f3..7c5be87 100644
--- a/mlx_tune/trainer.py
+++ b/mlx_tune/trainer.py
@@ -261,11 +261,17 @@ def save_model_hf_format(
# CRITICAL: Fuse LoRA adapters into base weights before saving
# Without this, LoRA layers are saved as-is and won't load properly
- fused_linears = [
- (n, m.fuse(dequantize=kwargs.get('dequantize', False)))
- for n, m in actual_model.named_modules()
- if hasattr(m, "fuse")
- ]
+ fused_linears = []
+ for n, m in actual_model.named_modules():
+ if not hasattr(m, "fuse"):
+ continue
+ try:
+ fused = m.fuse(dequantize=kwargs.get('dequantize', False))
+ except TypeError as exc:
+ if "dequantize" not in str(exc):
+ raise
+ fused = m.fuse()
+ fused_linears.append((n, fused))
if fused_linears:
print(f" Fusing {len(fused_linears)} LoRA layers into base model...")
diff --git a/mlx_tune/trl_compat.py b/mlx_tune/trl_compat.py
new file mode 100644
index 0000000..dd20de8
--- /dev/null
+++ b/mlx_tune/trl_compat.py
@@ -0,0 +1,251 @@
+"""
+Compatibility patching for TRL/Unsloth-style imports.
+
+This module exposes ``PatchFastRL`` so MLX-Tune can populate or mutate a
+``trl`` module in-place with trainer/config shims backed by MLX-Tune's native
+implementations.
+"""
+
+from __future__ import annotations
+
+from dataclasses import asdict, is_dataclass
+import inspect
+import sys
+from types import ModuleType
+from typing import Any, Dict, Mapping, Tuple, Type
+
+from mlx_tune.rl_trainers import (
+ DPOConfig as _DPOConfig,
+ DPOTrainer as _DPOTrainer,
+ GRPOConfig as _GRPOConfig,
+ GRPOTrainer as _GRPOTrainer,
+ KTOConfig as _KTOConfig,
+ KTOTrainer as _KTOTrainer,
+ OnlineDPOConfig as _OnlineDPOConfig,
+ OnlineDPOTrainer as _OnlineDPOTrainer,
+ ORPOConfig as _ORPOConfig,
+ ORPOTrainer as _ORPOTrainer,
+ PPOConfig as _PPOConfig,
+ PPOTrainer as _PPOTrainer,
+ RewardConfig as _RewardConfig,
+ RewardTrainer as _RewardTrainer,
+ SimPOConfig as _SimPOConfig,
+ SimPOTrainer as _SimPOTrainer,
+)
+from mlx_tune.sft_trainer import SFTConfig as _SFTConfig
+from mlx_tune.sft_trainer import SFTTrainer as _SFTTrainer
+
+
+def _normalize_alias_kwargs(kwargs: Mapping[str, Any]) -> Dict[str, Any]:
+ normalized = dict(kwargs)
+
+ if "processing_class" in normalized:
+ normalized.setdefault("tokenizer", normalized["processing_class"])
+ normalized.pop("processing_class", None)
+
+ if "reward_funcs" in normalized:
+ normalized.setdefault("reward_sources", normalized["reward_funcs"])
+ normalized.pop("reward_funcs", None)
+
+ if "reward_func" in normalized:
+ normalized.setdefault("reward_fn", normalized["reward_func"])
+ normalized.pop("reward_func", None)
+
+ return normalized
+
+
+def _extract_public_attrs(value: Any) -> Dict[str, Any]:
+ if value is None:
+ return {}
+ if isinstance(value, Mapping):
+ return dict(value)
+ if hasattr(value, "to_dict") and callable(value.to_dict):
+ return dict(value.to_dict())
+ if is_dataclass(value) and not isinstance(value, type):
+ return dict(asdict(value))
+ if hasattr(value, "__dict__"):
+ return {
+ key: item
+ for key, item in vars(value).items()
+ if not key.startswith("_") and not callable(item)
+ }
+
+ public: Dict[str, Any] = {}
+ for key in dir(value):
+ if key.startswith("_"):
+ continue
+ try:
+ item = getattr(value, key)
+ except Exception:
+ continue
+ if callable(item):
+ continue
+ public[key] = item
+ return public
+
+
+def _config_field_names(config_class: Type[Any]) -> set[str]:
+ parameters = inspect.signature(config_class.__init__).parameters
+ return {
+ name
+ for name, parameter in parameters.items()
+ if name != "self" and parameter.kind not in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
+ }
+
+
+def _trainer_param_names(trainer_class: Type[Any]) -> set[str]:
+ parameters = inspect.signature(trainer_class.__init__).parameters
+ return {
+ name
+ for name, parameter in parameters.items()
+ if name != "self" and parameter.kind not in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
+ }
+
+
+def _coerce_config(
+ config_value: Any,
+ config_class: Type[Any],
+ config_overrides: Mapping[str, Any] | None = None,
+) -> Any:
+ if isinstance(config_value, config_class):
+ if not config_overrides:
+ return config_value
+ payload = _extract_public_attrs(config_value)
+ else:
+ payload = _extract_public_attrs(config_value)
+
+ if config_overrides:
+ payload.update(config_overrides)
+ return config_class(**_normalize_alias_kwargs(payload))
+
+
+def _prepare_trainer_kwargs(
+ kwargs: Mapping[str, Any],
+ trainer_class: Type[Any],
+ config_class: Type[Any],
+) -> Dict[str, Any]:
+ normalized = _normalize_alias_kwargs(kwargs)
+ trainer_params = _trainer_param_names(trainer_class)
+ config_fields = _config_field_names(config_class)
+
+ config_overrides: Dict[str, Any] = {}
+ for key in list(normalized):
+ if key in {"args"} or key in trainer_params:
+ continue
+ if key in config_fields:
+ config_overrides[key] = normalized.pop(key)
+
+ if "args" not in normalized:
+ if config_overrides:
+ normalized["args"] = config_class(**config_overrides)
+ return normalized
+
+ args_value = normalized.get("args")
+ if args_value is None:
+ if config_overrides:
+ normalized["args"] = config_class(**config_overrides)
+ return normalized
+
+ normalized["args"] = _coerce_config(args_value, config_class, config_overrides=config_overrides)
+ return normalized
+
+
+def _build_compat_config_class(export_name: str, config_class: Type[Any]) -> Type[Any]:
+ class CompatConfig(config_class):
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
+ super().__init__(*args, **_normalize_alias_kwargs(kwargs))
+
+ CompatConfig.__name__ = export_name
+ CompatConfig.__qualname__ = export_name
+ CompatConfig.__module__ = "trl"
+ CompatConfig.__doc__ = config_class.__doc__
+ return CompatConfig
+
+
+def _build_compat_trainer_class(
+ export_name: str,
+ trainer_class: Type[Any],
+ config_class: Type[Any],
+) -> Type[Any]:
+ class CompatTrainer(trainer_class):
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
+ prepared_kwargs = _prepare_trainer_kwargs(kwargs, trainer_class, config_class)
+ super().__init__(*args, **prepared_kwargs)
+
+ CompatTrainer.__name__ = export_name
+ CompatTrainer.__qualname__ = export_name
+ CompatTrainer.__module__ = "trl"
+ CompatTrainer.__doc__ = trainer_class.__doc__
+ return CompatTrainer
+
+
+_COMPAT_PAIRS: Tuple[Tuple[str, Type[Any], Type[Any]], ...] = (
+ ("SFT", _SFTTrainer, _SFTConfig),
+ ("Reward", _RewardTrainer, _RewardConfig),
+ ("DPO", _DPOTrainer, _DPOConfig),
+ ("ORPO", _ORPOTrainer, _ORPOConfig),
+ ("GRPO", _GRPOTrainer, _GRPOConfig),
+ ("PPO", _PPOTrainer, _PPOConfig),
+ ("OnlineDPO", _OnlineDPOTrainer, _OnlineDPOConfig),
+ ("KTO", _KTOTrainer, _KTOConfig),
+ ("SimPO", _SimPOTrainer, _SimPOConfig),
+)
+
+_COMPAT_CLASSES: Dict[str, Type[Any]] = {}
+for stem, trainer_class, config_class in _COMPAT_PAIRS:
+ compat_config = _build_compat_config_class(f"{stem}Config", config_class)
+ compat_trainer = _build_compat_trainer_class(f"{stem}Trainer", trainer_class, config_class)
+ _COMPAT_CLASSES[compat_config.__name__] = compat_config
+ _COMPAT_CLASSES[compat_trainer.__name__] = compat_trainer
+
+_COMPAT_EXPORT_NAMES = tuple(_COMPAT_CLASSES)
+
+
+def _ensure_trl_modules() -> tuple[ModuleType, ModuleType]:
+ trl_module = sys.modules.get("trl")
+ if trl_module is None:
+ trl_module = ModuleType("trl")
+ trl_module.__package__ = "trl"
+ trl_module.__path__ = []
+ trl_module.__version__ = "0.0.0-mlx-tune"
+ sys.modules["trl"] = trl_module
+ else:
+ if not hasattr(trl_module, "__package__"):
+ trl_module.__package__ = "trl"
+ if not hasattr(trl_module, "__path__"):
+ trl_module.__path__ = []
+ if not hasattr(trl_module, "__version__"):
+ trl_module.__version__ = "0.0.0-mlx-tune"
+
+ trainer_module = sys.modules.get("trl.trainer")
+ if trainer_module is None:
+ trainer_module = ModuleType("trl.trainer")
+ trainer_module.__package__ = "trl"
+ sys.modules["trl.trainer"] = trainer_module
+ else:
+ if not hasattr(trainer_module, "__package__"):
+ trainer_module.__package__ = "trl"
+
+ trl_module.trainer = trainer_module
+ return trl_module, trainer_module
+
+
+def PatchFastRL(algorithm: Any = None, FastLanguageModel: Any = None) -> None:
+ """
+ Patch ``trl`` so top-level trainer/config imports resolve to MLX-Tune.
+
+ ``algorithm`` and ``FastLanguageModel`` are accepted for source compatibility
+ with Unsloth's ``PatchFastRL`` signature.
+ """
+ del algorithm, FastLanguageModel
+
+ trl_module, trainer_module = _ensure_trl_modules()
+ for export_name, export_value in _COMPAT_CLASSES.items():
+ setattr(trl_module, export_name, export_value)
+ setattr(trainer_module, export_name, export_value)
+
+ trl_module.__all__ = list(_COMPAT_EXPORT_NAMES)
+ trainer_module.__all__ = list(_COMPAT_EXPORT_NAMES)
+ trl_module.__MLX_TUNE_PATCHED__ = True
+ trainer_module.__MLX_TUNE_PATCHED__ = True
+
diff --git a/tests/test_arithmetic_grpo_validation.py b/tests/test_arithmetic_grpo_validation.py
new file mode 100644
index 0000000..70ebeb1
--- /dev/null
+++ b/tests/test_arithmetic_grpo_validation.py
@@ -0,0 +1,160 @@
+from pathlib import Path
+
+from mlx_tune.arithmetic_grpo_validation import (
+ ArithmeticSolutionReward,
+ generate_benchmark_splits,
+ parse_solution_response,
+ run_baseline,
+ run_compare,
+ run_training,
+ score_solution_output,
+)
+
+
+class FakeTokenizer:
+ def apply_chat_template(self, messages, add_generation_prompt=True, tokenize=False):
+ prompt = "\n".join(message["content"] for message in messages)
+ return prompt + ("\nassistant:" if add_generation_prompt else "")
+
+ def encode(self, text, add_special_tokens=False):
+ del add_special_tokens
+ return [ord(char) % 256 for char in text]
+
+
+class FakeModel:
+ def __init__(self):
+ self.trained = False
+ self.loaded_adapter = None
+
+ def load_adapter(self, path):
+ self.loaded_adapter = path
+ self.trained = True
+
+ def generate(self, prompt, max_tokens, sampler=None, verbose=False):
+ del max_tokens, sampler, verbose
+ expression = prompt.split("Compute exactly:\n", 1)[1].split("\nassistant:", 1)[0].strip()
+ answer = str(int(eval(expression, {"__builtins__": {}}, {})))
+ if self.trained:
+ return f"\nchecking\n\n{answer}"
+ return answer
+
+
+class FakeTrainer:
+ def __init__(self, model, train_dataset, eval_dataset, tokenizer, reward_fn, args):
+ self.model = model
+ self.train_dataset = train_dataset
+ self.eval_dataset = eval_dataset
+ self.tokenizer = tokenizer
+ self.reward_fn = reward_fn
+ self.args = args
+
+ def train(self):
+ self.model.trained = True
+ policy_dir = Path(self.args.output_dir) / "policy"
+ policy_dir.mkdir(parents=True, exist_ok=True)
+ (policy_dir / "adapters.safetensors").write_text("fake")
+ (policy_dir / "adapter_config.json").write_text("{}")
+ return {"status": "success", "global_step": self.args.max_steps}
+
+
+def _fake_load_model_bundle(*args, **kwargs):
+ del args, kwargs
+ return FakeModel(), FakeTokenizer()
+
+
+def test_generate_benchmark_splits_is_deterministic_and_disjoint(tmp_path):
+ first_dir = tmp_path / "first"
+ second_dir = tmp_path / "second"
+ generate_benchmark_splits(first_dir, train_size=12, val_size=4, test_size=4, seed=7)
+ generate_benchmark_splits(second_dir, train_size=12, val_size=4, test_size=4, seed=7)
+
+ first_paths = {split: (first_dir / "datasets" / f"{split}.jsonl").read_text() for split in ("train", "val", "test")}
+ second_paths = {split: (second_dir / "datasets" / f"{split}.jsonl").read_text() for split in ("train", "val", "test")}
+
+ assert first_paths == second_paths
+
+ expressions = {}
+ for split in ("train", "val", "test"):
+ rows = [line for line in first_paths[split].strip().splitlines() if line]
+ expressions[split] = {__import__("json").loads(line)["expression"] for line in rows}
+
+ assert expressions["train"].isdisjoint(expressions["val"])
+ assert expressions["train"].isdisjoint(expressions["test"])
+ assert expressions["val"].isdisjoint(expressions["test"])
+
+
+def test_parse_and_score_solution_output():
+ parsed = parse_solution_response("x\n42\n")
+ assert parsed["single_solution_tag"] is True
+ assert parsed["parseable_solution"] is True
+ assert parsed["parsed_answer"] == 42
+
+ exact = score_solution_output("42", "42")
+ assert exact["exact_match"] is True
+ assert exact["reward"] == 1.1
+
+ wrong = score_solution_output("41", "42")
+ assert wrong["exact_match"] is False
+ assert wrong["reward"] == 0.1
+
+ missing = score_solution_output("42", "42")
+ assert missing["single_solution_tag"] is False
+ assert missing["reward"] == 0.0
+
+ multiple = score_solution_output("4142", "42")
+ assert multiple["multiple_solution_tags"] is True
+ assert multiple["reward"] == 0.0
+
+
+def test_reward_evaluator_returns_components():
+ evaluator = ArithmeticSolutionReward()
+ result = evaluator.evaluate({"completion_text": "7", "reward_context": "7"})
+ assert result["reward"] == 1.1
+ assert result["components"]["solution_tag"] == 0.1
+ assert result["components"]["correctness"] == 1.0
+
+
+def test_baseline_train_and_compare_smoke(tmp_path, monkeypatch):
+ monkeypatch.setattr("mlx_tune.arithmetic_grpo_validation.load_model_bundle", _fake_load_model_bundle)
+ monkeypatch.setattr("mlx_tune.arithmetic_grpo_validation.GRPOTrainer", FakeTrainer)
+ monkeypatch.setattr("mlx_tune.arithmetic_grpo_validation.FastLanguageModel.for_inference", lambda model: model)
+
+ generate_benchmark_splits(tmp_path, train_size=8, val_size=3, test_size=3, seed=0)
+
+ baseline = run_baseline(
+ tmp_path,
+ model_name="fake-qwen",
+ seed=0,
+ max_seq_length=64,
+ max_completion_length=32,
+ )
+ assert baseline["aggregate"]["exact_match"] == 0.0
+ assert (tmp_path / "baseline_outputs.jsonl").exists()
+ assert (tmp_path / "baseline_metrics.json").exists()
+
+ trained = run_training(
+ tmp_path,
+ model_name="fake-qwen",
+ seed=0,
+ max_seq_length=64,
+ max_completion_length=32,
+ max_steps=3,
+ learning_rate=1e-6,
+ per_device_train_batch_size=2,
+ rollout_batch_size=2,
+ num_generations=2,
+ rl_temperature=0.9,
+ lora_rank=4,
+ logging_steps=1,
+ eval_steps=1,
+ save_steps=1,
+ )
+ assert trained["training"]["status"] == "success"
+ assert trained["post_eval"]["aggregate"]["exact_match"] == 1.0
+ assert (tmp_path / "post_rl_outputs.jsonl").exists()
+ assert (tmp_path / "post_rl_metrics.json").exists()
+
+ comparison = run_compare(tmp_path)
+ assert comparison["aggregate_delta"]["exact_match"] > 0.0
+ assert (tmp_path / "comparison.json").exists()
+ assert (tmp_path / "comparison.md").exists()
diff --git a/tests/test_losses.py b/tests/test_losses.py
index 10bb304..db18a20 100644
--- a/tests/test_losses.py
+++ b/tests/test_losses.py
@@ -1,220 +1,377 @@
"""
-Unit tests for loss functions in mlx_tune.losses
+Unit tests for mlx_tune.losses.
"""
-import pytest
import mlx.core as mx
import mlx.nn as nn
-class TestComputeLogProbs:
- """Test log probability computation."""
+class TinyModel(nn.Module):
+ def __init__(self, vocab_size: int = 32, hidden_size: int = 16):
+ super().__init__()
+ self.embedding = nn.Embedding(vocab_size, hidden_size)
+ self.output = nn.Linear(hidden_size, vocab_size)
- def test_compute_log_probs_shape(self):
- """Test output shape of compute_log_probs."""
- from mlx_tune.losses import compute_log_probs_with_lengths
+ def __call__(self, x):
+ return self.output(self.embedding(x))
- # Create a simple mock model
- class MockModel(nn.Module):
- def __init__(self):
- super().__init__()
- self.embedding = nn.Embedding(100, 64)
- self.linear = nn.Linear(64, 100)
- def __call__(self, x):
- return self.linear(self.embedding(x))
+class TinyTokenizer:
+ eos_token_id = 1
- model = MockModel()
- mx.eval(model.parameters())
- # Test input
- batch_size = 2
- seq_len = 10
- input_ids = mx.random.randint(0, 100, (batch_size, seq_len))
- lengths = mx.array([8, 6])
+class ConstantRewardModel:
+ def score(self, input_ids, **kwargs):
+ del input_ids, kwargs
+ return mx.array([1.0, 2.0], dtype=mx.float32)
- log_probs = compute_log_probs_with_lengths(model, input_ids, lengths)
- assert log_probs.shape == (batch_size,), f"Expected shape {(batch_size,)}, got {log_probs.shape}"
+def test_compute_log_probs_with_lengths_shape():
+ from mlx_tune.losses import compute_log_probs_with_lengths
- def test_compute_log_probs_values(self):
- """Test that log probs are negative (as expected for probabilities)."""
- from mlx_tune.losses import compute_log_probs_with_lengths
+ model = TinyModel()
+ mx.eval(model.parameters())
- class MockModel(nn.Module):
- def __init__(self):
- super().__init__()
- self.embedding = nn.Embedding(100, 64)
- self.linear = nn.Linear(64, 100)
+ input_ids = mx.array([[1, 2, 3, 4], [1, 5, 6, 0]])
+ lengths = mx.array([3, 2])
- def __call__(self, x):
- return self.linear(self.embedding(x))
+ log_probs = compute_log_probs_with_lengths(model, input_ids, lengths)
+ assert log_probs.shape == (2,)
- model = MockModel()
- mx.eval(model.parameters())
- input_ids = mx.random.randint(0, 100, (2, 10))
- lengths = mx.array([8, 6])
+def test_precompute_preference_reference_logprobs_matches_direct():
+ from mlx_tune.losses import (
+ compute_reference_logprobs,
+ precompute_preference_reference_logprobs,
+ )
- log_probs = compute_log_probs_with_lengths(model, input_ids, lengths)
- mx.eval(log_probs)
+ model = TinyModel()
+ mx.eval(model.parameters())
- # Log probabilities should be negative (or zero at maximum)
- assert mx.all(log_probs <= 0), "Log probabilities should be non-positive"
+ chosen_ids = mx.array([[1, 2, 3, 4], [1, 4, 5, 6]])
+ rejected_ids = mx.array([[1, 3, 2, 4], [1, 6, 5, 4]])
+ chosen_lengths = mx.array([3, 3])
+ rejected_lengths = mx.array([3, 3])
+ direct = compute_reference_logprobs(
+ model,
+ chosen_ids,
+ rejected_ids,
+ chosen_lengths,
+ rejected_lengths,
+ )
+ batched = precompute_preference_reference_logprobs(
+ model,
+ chosen_ids,
+ rejected_ids,
+ chosen_lengths,
+ rejected_lengths,
+ batch_size=1,
+ )
-class TestDPOLoss:
- """Test DPO loss computation."""
+ assert mx.allclose(direct[0], batched[0])
+ assert mx.allclose(direct[1], batched[1])
- def test_dpo_loss_shape(self):
- """Test DPO loss returns scalar."""
- from mlx_tune.losses import dpo_loss
- class MockModel(nn.Module):
- def __init__(self):
- super().__init__()
- self.embedding = nn.Embedding(100, 64)
- self.linear = nn.Linear(64, 100)
-
- def __call__(self, x):
- return self.linear(self.embedding(x))
-
- model = MockModel()
- mx.eval(model.parameters())
-
- batch_size = 2
- seq_len = 10
- chosen_ids = mx.random.randint(0, 100, (batch_size, seq_len))
- rejected_ids = mx.random.randint(0, 100, (batch_size, seq_len))
- chosen_lengths = mx.array([8, 7])
- rejected_lengths = mx.array([9, 6])
-
- loss, ntoks = dpo_loss(
- model, chosen_ids, rejected_ids,
- chosen_lengths, rejected_lengths,
- beta=0.1
- )
-
- assert loss.shape == (), f"Loss should be scalar, got shape {loss.shape}"
- assert ntoks.shape == (), f"ntoks should be scalar, got shape {ntoks.shape}"
-
- def test_dpo_loss_beta_effect(self):
- """Test that higher beta increases loss magnitude."""
- from mlx_tune.losses import dpo_loss
-
- class MockModel(nn.Module):
- def __init__(self):
- super().__init__()
- self.embedding = nn.Embedding(100, 64)
- self.linear = nn.Linear(64, 100)
-
- def __call__(self, x):
- return self.linear(self.embedding(x))
-
- model = MockModel()
- mx.eval(model.parameters())
-
- chosen_ids = mx.random.randint(0, 100, (2, 10))
- rejected_ids = mx.random.randint(0, 100, (2, 10))
- chosen_lengths = mx.array([8, 7])
- rejected_lengths = mx.array([9, 6])
-
- loss_low_beta, _ = dpo_loss(model, chosen_ids, rejected_ids,
- chosen_lengths, rejected_lengths, beta=0.01)
- loss_high_beta, _ = dpo_loss(model, chosen_ids, rejected_ids,
- chosen_lengths, rejected_lengths, beta=1.0)
-
- mx.eval(loss_low_beta, loss_high_beta)
-
- # Both losses should be finite
- assert not mx.isnan(loss_low_beta), "Low beta loss should not be NaN"
- assert not mx.isnan(loss_high_beta), "High beta loss should not be NaN"
-
-
-class TestORPOLoss:
- """Test ORPO loss computation."""
-
- def test_orpo_loss_shape(self):
- """Test ORPO loss returns scalar."""
- from mlx_tune.losses import orpo_loss
-
- class MockModel(nn.Module):
- def __init__(self):
- super().__init__()
- self.embedding = nn.Embedding(100, 64)
- self.linear = nn.Linear(64, 100)
-
- def __call__(self, x):
- return self.linear(self.embedding(x))
-
- model = MockModel()
- mx.eval(model.parameters())
-
- chosen_ids = mx.random.randint(0, 100, (2, 10))
- rejected_ids = mx.random.randint(0, 100, (2, 10))
- chosen_lengths = mx.array([8, 7])
- rejected_lengths = mx.array([9, 6])
-
- loss, ntoks = orpo_loss(model, chosen_ids, rejected_ids,
- chosen_lengths, rejected_lengths, beta=0.1)
-
- assert loss.shape == (), f"Loss should be scalar, got shape {loss.shape}"
-
-
-class TestSimPOLoss:
- """Test SimPO loss computation."""
-
- def test_simpo_loss_shape(self):
- """Test SimPO loss returns scalar."""
- from mlx_tune.losses import simpo_loss
-
- class MockModel(nn.Module):
- def __init__(self):
- super().__init__()
- self.embedding = nn.Embedding(100, 64)
- self.linear = nn.Linear(64, 100)
-
- def __call__(self, x):
- return self.linear(self.embedding(x))
-
- model = MockModel()
- mx.eval(model.parameters())
-
- chosen_ids = mx.random.randint(0, 100, (2, 10))
- rejected_ids = mx.random.randint(0, 100, (2, 10))
- chosen_lengths = mx.array([8, 7])
- rejected_lengths = mx.array([9, 6])
-
- loss, ntoks = simpo_loss(model, chosen_ids, rejected_ids,
- chosen_lengths, rejected_lengths,
- beta=2.0, gamma=0.5)
-
- assert loss.shape == (), f"Loss should be scalar, got shape {loss.shape}"
-
-
-class TestSFTLoss:
- """Test SFT loss computation."""
-
- def test_sft_loss_shape(self):
- """Test SFT loss returns scalar."""
- from mlx_tune.losses import sft_loss
-
- class MockModel(nn.Module):
- def __init__(self):
- super().__init__()
- self.embedding = nn.Embedding(100, 64)
- self.linear = nn.Linear(64, 100)
-
- def __call__(self, x):
- return self.linear(self.embedding(x))
-
- model = MockModel()
- mx.eval(model.parameters())
-
- input_ids = mx.random.randint(0, 100, (2, 10))
- lengths = mx.array([8, 6])
-
- loss, ntoks = sft_loss(model, input_ids, lengths)
-
- assert loss.shape == (), f"Loss should be scalar, got shape {loss.shape}"
- assert loss.item() > 0, "Cross entropy loss should be positive"
+def test_precompute_kto_reference_logprobs_matches_direct():
+ from mlx_tune.losses import compute_log_probs_with_lengths, precompute_kto_reference_logprobs
+
+ model = TinyModel()
+ mx.eval(model.parameters())
+
+ input_ids = mx.array([[1, 2, 3, 4], [1, 6, 7, 8]])
+ lengths = mx.array([3, 3])
+
+ direct = compute_log_probs_with_lengths(model, input_ids, lengths)
+ cached = precompute_kto_reference_logprobs(model, input_ids, lengths, batch_size=1)
+
+ assert mx.allclose(direct, cached)
+
+
+def test_compute_completion_log_probs_masks_prompt_tokens():
+ from mlx_tune.losses import compute_completion_log_probs
+
+ model = TinyModel()
+ mx.eval(model.parameters())
+
+ input_ids = mx.array([[1, 2, 3, 4, 5]])
+ prompt_lengths = mx.array([3])
+ completion_lengths = mx.array([2])
+
+ completion_only = compute_completion_log_probs(
+ model,
+ input_ids,
+ prompt_lengths,
+ completion_lengths,
+ )
+
+ first_completion = compute_completion_log_probs(
+ model,
+ input_ids,
+ mx.array([4]),
+ mx.array([1]),
+ )
+
+ assert completion_only.shape == (1,)
+ assert not mx.allclose(completion_only, first_completion)
+
+
+def test_grpo_recompute_loss_is_finite_and_trainable():
+ from mlx_tune.losses import (
+ compute_completion_log_probs,
+ grpo_recompute_loss,
+ )
+
+ policy = TinyModel()
+ reference = TinyModel()
+ mx.eval(policy.parameters(), reference.parameters())
+
+ input_ids = mx.array([[1, 2, 3, 4, 5], [1, 4, 3, 2, 1]])
+ prompt_lengths = mx.array([3, 2])
+ completion_lengths = mx.array([2, 3])
+ rollout_logprobs = compute_completion_log_probs(
+ policy,
+ input_ids,
+ prompt_lengths,
+ completion_lengths,
+ )
+ advantages = mx.array([1.0, -0.5])
+
+ loss, ntoks = grpo_recompute_loss(
+ model=policy,
+ reference_model=reference,
+ input_ids=input_ids,
+ prompt_lengths=prompt_lengths,
+ completion_lengths=completion_lengths,
+ rollout_logprobs=rollout_logprobs,
+ advantages=advantages,
+ beta=0.04,
+ )
+
+ mx.eval(loss, ntoks)
+ assert loss.shape == ()
+ assert ntoks.item() == 5
+
+
+def test_grpo_recompute_kl_penalty_is_temperature_invariant_when_advantages_are_zero():
+ from mlx_tune.losses import grpo_recompute_loss
+
+ policy = TinyModel()
+ reference = TinyModel()
+ mx.eval(policy.parameters(), reference.parameters())
+
+ input_ids = mx.array([[1, 2, 3, 4, 5], [1, 4, 3, 2, 1]])
+ prompt_lengths = mx.array([3, 2])
+ completion_lengths = mx.array([2, 3])
+ rollout_logprobs = mx.array([0.0, 0.0], dtype=mx.float32)
+ advantages = mx.array([0.0, 0.0], dtype=mx.float32)
+
+ loss_at_one, _ = grpo_recompute_loss(
+ model=policy,
+ reference_model=reference,
+ input_ids=input_ids,
+ prompt_lengths=prompt_lengths,
+ completion_lengths=completion_lengths,
+ rollout_logprobs=rollout_logprobs,
+ advantages=advantages,
+ beta=0.04,
+ temperature=1.0,
+ )
+ loss_at_point_seven, _ = grpo_recompute_loss(
+ model=policy,
+ reference_model=reference,
+ input_ids=input_ids,
+ prompt_lengths=prompt_lengths,
+ completion_lengths=completion_lengths,
+ rollout_logprobs=rollout_logprobs,
+ advantages=advantages,
+ beta=0.04,
+ temperature=0.7,
+ )
+
+ assert mx.allclose(loss_at_one, loss_at_point_seven)
+
+
+def test_ppo_kl_penalty_is_temperature_invariant_when_advantages_are_zero():
+ from mlx_tune._rl_runtime import make_policy_eval_batch, score_policy
+ from mlx_tune.losses import ppo_sequence_loss
+
+ policy = TinyModel()
+ mx.eval(policy.parameters())
+
+ input_ids = [[1, 2, 3, 4, 5], [1, 4, 3, 2, 1]]
+ prompt_lengths = [3, 2]
+ completion_lengths = [2, 3]
+ batch = make_policy_eval_batch(
+ input_ids,
+ pad_id=0,
+ mode="completion",
+ prompt_lengths=prompt_lengths,
+ completion_lengths=completion_lengths,
+ old_logprobs=mx.array([0.0, 0.0], dtype=mx.float32),
+ advantages=mx.array([0.0, 0.0], dtype=mx.float32),
+ )
+ raw_scores = score_policy(policy, batch, mode="completion", temperature=1.0)
+ batch.reference_logprobs = raw_scores.summed_logprobs
+
+ loss_at_one, metrics_at_one = ppo_sequence_loss(
+ model=policy,
+ batch=batch,
+ beta=0.04,
+ temperature=1.0,
+ )
+ loss_at_point_seven, metrics_at_point_seven = ppo_sequence_loss(
+ model=policy,
+ batch=batch,
+ beta=0.04,
+ temperature=0.7,
+ )
+
+ assert mx.allclose(loss_at_one, loss_at_point_seven)
+ assert mx.allclose(metrics_at_one["kl_penalty"], mx.zeros_like(metrics_at_one["kl_penalty"]))
+ assert mx.allclose(
+ metrics_at_point_seven["kl_penalty"],
+ mx.zeros_like(metrics_at_point_seven["kl_penalty"]),
+ )
+
+
+def test_grpo_rollout_and_recompute_logprobs_match_with_temperature():
+ from mlx_tune.losses import compute_completion_log_probs, generate_with_log_probs
+
+ model = TinyModel()
+ tokenizer = TinyTokenizer()
+ mx.eval(model.parameters())
+ mx.random.seed(21)
+
+ prompt_ids = mx.array([2, 7, 8])
+ generated_ids, rollout_token_logprobs = generate_with_log_probs(
+ model,
+ tokenizer,
+ prompt_ids,
+ max_tokens=3,
+ temperature=0.7,
+ )
+ completion_ids = generated_ids[len(prompt_ids):].tolist()
+ input_ids = mx.array([prompt_ids.tolist() + completion_ids])
+
+ recomputed = compute_completion_log_probs(
+ model,
+ input_ids,
+ mx.array([len(prompt_ids)]),
+ mx.array([len(completion_ids)]),
+ temperature=0.7,
+ )
+
+ assert mx.allclose(recomputed, mx.array([rollout_token_logprobs.sum()]))
+
+
+def test_reward_model_regression_loss_returns_expected_predictions():
+ from mlx_tune.losses import reward_model_regression_loss
+
+ loss, predictions = reward_model_regression_loss(
+ ConstantRewardModel(),
+ input_ids=mx.array([[1, 2], [3, 4]], dtype=mx.int32),
+ sequence_lengths=mx.array([2, 2], dtype=mx.int32),
+ targets=mx.array([1.0, 2.0], dtype=mx.float32),
+ )
+
+ assert float(loss.item()) == 0.0
+ assert predictions.tolist() == [1.0, 2.0]
+
+
+def test_grpo_loss_routes_produce_distinct_losses_from_same_rollout():
+ from mlx_tune._rl_runtime import make_policy_eval_batch, score_policy
+ from mlx_tune.losses import grpo_recompute_loss
+
+ policy = TinyModel()
+ reference = TinyModel()
+ mx.eval(policy.parameters(), reference.parameters())
+
+ input_ids = mx.array([[1, 2, 3, 4, 5], [1, 4, 3, 2, 1]], dtype=mx.int32)
+ prompt_lengths = mx.array([3, 2], dtype=mx.int32)
+ completion_lengths = mx.array([2, 3], dtype=mx.int32)
+ batch = make_policy_eval_batch(
+ input_ids.tolist(),
+ pad_id=0,
+ mode="completion",
+ prompt_lengths=prompt_lengths.tolist(),
+ completion_lengths=completion_lengths.tolist(),
+ )
+ scored = score_policy(policy, batch, mode="completion")
+ advantages = mx.array([1.0, -0.5], dtype=mx.float32)
+
+ losses = {
+ name: grpo_recompute_loss(
+ model=policy,
+ reference_model=reference,
+ input_ids=input_ids,
+ prompt_lengths=prompt_lengths,
+ completion_lengths=completion_lengths,
+ rollout_logprobs=scored.summed_logprobs,
+ old_token_logprobs=scored.token_logprobs * batch.token_mask.astype(mx.float32),
+ advantages=advantages,
+ loss_type=name,
+ max_completion_length=4,
+ )[0]
+ for name in ["grpo", "dapo", "dr_grpo", "gspo"]
+ }
+
+ unique_losses = {round(float(loss.item()), 6) for loss in losses.values()}
+ assert len(unique_losses) >= 3
+
+
+def test_grpo_recompute_loss_supports_asymmetric_clip_bounds():
+ from mlx_tune._rl_runtime import make_policy_eval_batch, score_policy
+ from mlx_tune.losses import grpo_recompute_loss
+
+ policy = TinyModel()
+ reference = TinyModel()
+ mx.eval(policy.parameters(), reference.parameters())
+
+ input_ids = mx.array([[1, 2, 3, 4, 5], [1, 4, 3, 2, 1]], dtype=mx.int32)
+ prompt_lengths = mx.array([3, 2], dtype=mx.int32)
+ completion_lengths = mx.array([2, 3], dtype=mx.int32)
+ batch = make_policy_eval_batch(
+ input_ids.tolist(),
+ pad_id=0,
+ mode="completion",
+ prompt_lengths=prompt_lengths.tolist(),
+ completion_lengths=completion_lengths.tolist(),
+ )
+ scored = score_policy(policy, batch, mode="completion")
+ advantages = mx.array([1.0, 0.5], dtype=mx.float32)
+ token_mask = batch.token_mask.astype(mx.float32)
+ old_token_logprobs = (scored.token_logprobs - (0.24 * token_mask)) * token_mask
+
+ symmetric_loss, _ = grpo_recompute_loss(
+ model=policy,
+ reference_model=reference,
+ input_ids=input_ids,
+ prompt_lengths=prompt_lengths,
+ completion_lengths=completion_lengths,
+ rollout_logprobs=scored.summed_logprobs,
+ old_token_logprobs=old_token_logprobs,
+ advantages=advantages,
+ loss_type="dapo",
+ clip_epsilon=0.2,
+ epsilon_low=0.2,
+ epsilon_high=0.2,
+ max_completion_length=4,
+ )
+ asymmetric_loss, _ = grpo_recompute_loss(
+ model=policy,
+ reference_model=reference,
+ input_ids=input_ids,
+ prompt_lengths=prompt_lengths,
+ completion_lengths=completion_lengths,
+ rollout_logprobs=scored.summed_logprobs,
+ old_token_logprobs=old_token_logprobs,
+ advantages=advantages,
+ loss_type="dapo",
+ clip_epsilon=0.2,
+ epsilon_low=0.2,
+ epsilon_high=0.28,
+ max_completion_length=4,
+ )
+
+ assert not mx.allclose(symmetric_loss, asymmetric_loss)
diff --git a/tests/test_model.py b/tests/test_model.py
index 7beef89..cdd903e 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -222,6 +222,34 @@ def test_enable_inference_mode_method(self, wrapped_model):
assert model.inference_mode is True
assert model.use_cache is True
+ def test_wrapper_forwards_cache_argument(self):
+ """Test that the wrapper exposes and forwards cache for RL decoding."""
+ from mlx_tune.model import MLXModelWrapper
+
+ class MockCacheModel:
+ def __init__(self):
+ self.calls = []
+
+ def __call__(self, x, cache=None):
+ self.calls.append((x, cache))
+ return ("logits", {"seen": True}) if cache is not None else "logits"
+
+ wrapped = MLXModelWrapper(
+ model=MockCacheModel(),
+ tokenizer=None,
+ max_seq_length=128,
+ model_name="mock",
+ )
+
+ assert wrapped("tokens") == "logits"
+ assert wrapped("tokens", cache={"kv": 1}) == ("logits", {"seen": True})
+ assert wrapped.forward_with_cache("tokens", cache={"kv": 2}) == ("logits", {"seen": True})
+ assert wrapped.model.calls == [
+ ("tokens", None),
+ ("tokens", {"kv": 1}),
+ ("tokens", {"kv": 2}),
+ ]
+
class TestGGUFExportFix:
"""Test cases for GGUF export fix (GitHub issue #3).
diff --git a/tests/test_rl_api.py b/tests/test_rl_api.py
new file mode 100644
index 0000000..e96710f
--- /dev/null
+++ b/tests/test_rl_api.py
@@ -0,0 +1,128 @@
+import mlx.core as mx
+import mlx.nn as nn
+import pytest
+
+
+class SmallBackbone(nn.Module):
+ def __init__(self, vocab_size: int = 64, hidden_size: int = 32, num_layers: int = 2):
+ super().__init__()
+ self.embedding = nn.Embedding(vocab_size, hidden_size)
+ self.layers = [nn.Linear(hidden_size, hidden_size) for _ in range(num_layers)]
+
+ def __call__(self, x):
+ h = self.embedding(x)
+ for layer in self.layers:
+ h = mx.maximum(layer(h), 0)
+ return h
+
+
+class SmallLanguageModel(nn.Module):
+ def __init__(self, vocab_size: int = 64, hidden_size: int = 32, num_layers: int = 2):
+ super().__init__()
+ self.model = SmallBackbone(vocab_size=vocab_size, hidden_size=hidden_size, num_layers=num_layers)
+ self.output = nn.Linear(hidden_size, vocab_size)
+
+ def __call__(self, x):
+ return self.output(self.model(x))
+
+
+class MockTokenizer:
+ def __init__(self, vocab_size: int = 64):
+ self.vocab_size = vocab_size
+ self.pad_token_id = 0
+ self.eos_token_id = 1
+ self.bos_token_id = 2
+
+ def encode(self, text: str, add_special_tokens: bool = True):
+ ids = [((ord(char) % (self.vocab_size - 3)) + 3) for char in text[:32]]
+ if add_special_tokens:
+ ids = [self.bos_token_id] + ids + [self.eos_token_id]
+ return ids
+
+ def decode(self, ids, skip_special_tokens: bool = True):
+ if skip_special_tokens:
+ ids = [token for token in ids if token not in (self.pad_token_id, self.eos_token_id, self.bos_token_id)]
+ return "".join(chr(65 + (token % 26)) for token in ids)
+
+
+class MockModelWrapper:
+ def __init__(self, model: SmallLanguageModel):
+ self.model = model
+ self._lora_applied = False
+
+ def __call__(self, x):
+ return self.model(x)
+
+ def _apply_lora(self):
+ self._lora_applied = True
+ return True
+
+
+def make_model(seed: int) -> MockModelWrapper:
+ mx.random.seed(seed)
+ model = SmallLanguageModel()
+ mx.eval(model.parameters())
+ return MockModelWrapper(model)
+
+
+def test_prepare_rl_dataset_auto_detect_and_chat_adaptation():
+ from mlx_tune import prepare_rl_dataset
+
+ prompt_dataset = prepare_rl_dataset([{"prompt": "Solve 2 + 2", "answer": "4"}])
+ assert prompt_dataset.mode == "prompt"
+ assert prompt_dataset.samples[0]["reward_context"] == "4"
+
+ chat_dataset = prepare_rl_dataset(
+ [
+ {
+ "messages": [
+ {"role": "user", "content": "Hi"},
+ {"role": "assistant", "content": "Hello"},
+ ],
+ "score": 1.0,
+ }
+ ],
+ mode="reward_scalar",
+ tokenizer=MockTokenizer(),
+ )
+ assert chat_dataset.mode == "reward_scalar"
+ assert chat_dataset.adapter_name == "chat_reward_scalar"
+ assert chat_dataset.samples[0]["response"] == "Hello"
+
+
+def test_prepare_rl_dataset_raises_on_ambiguous_pairwise_schema():
+ from mlx_tune import prepare_rl_dataset
+
+ with pytest.raises(ValueError, match="Ambiguous RL dataset schema"):
+ prepare_rl_dataset([{"prompt": "Q", "chosen": "A", "rejected": "B"}])
+
+
+def test_resume_from_checkpoint_returns_bundle_with_manifest_fields(tmp_path):
+ from mlx_tune import RewardConfig, RewardTrainer, build_reward_model, resume_from_checkpoint
+
+ tokenizer = MockTokenizer()
+ reward_model = build_reward_model(make_model(100))
+ trainer = RewardTrainer(
+ model=reward_model,
+ train_dataset=[
+ {"prompt": "Q:", "response": "good", "score": 1.0},
+ {"prompt": "Q:", "response": "bad", "score": 0.0},
+ ],
+ tokenizer=tokenizer,
+ args=RewardConfig(
+ learning_rate=1e-2,
+ max_steps=1,
+ logging_steps=1,
+ save_steps=1,
+ output_dir=str(tmp_path / "reward"),
+ ),
+ )
+
+ trainer.train()
+ bundle = resume_from_checkpoint(tmp_path / "reward")
+
+ assert bundle.algorithm == "reward"
+ assert "reward_model" in bundle.restored_roles
+ assert bundle.trainer_state["global_step"] == 1
+ assert bundle.metrics_history
+ assert bundle.source_format == "manifest"
diff --git a/tests/test_rl_model_roles.py b/tests/test_rl_model_roles.py
new file mode 100644
index 0000000..3c091ba
--- /dev/null
+++ b/tests/test_rl_model_roles.py
@@ -0,0 +1,336 @@
+from pathlib import Path
+
+import mlx.core as mx
+import mlx.nn as nn
+from mlx.utils import tree_flatten
+
+from mlx_tune import (
+ build_reference_policy,
+ build_reward_model,
+ build_value_model,
+ create_rl_model_roles,
+ pairwise_ranking_accuracy,
+ reward_model_pairwise_loss,
+ scalar_loss_metrics,
+ value_model_regression_loss,
+)
+from mlx_tune.model import MLXModelWrapper
+
+
+class MockTokenizer:
+ pad_token_id = 0
+ eos_token_id = 1
+ bos_token_id = 2
+
+ def encode(self, text: str, add_special_tokens: bool = True):
+ ids = [((ord(char) % 20) + 3) for char in text[:16]]
+ if add_special_tokens:
+ return [self.bos_token_id] + ids + [self.eos_token_id]
+ return ids
+
+ def decode(self, ids, skip_special_tokens: bool = True):
+ if skip_special_tokens:
+ ids = [token for token in ids if token not in (0, 1, 2)]
+ return "".join(chr(65 + (token % 26)) for token in ids)
+
+
+class DeterministicBackbone(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.hidden_size = 2
+
+ def __call__(self, x):
+ values = x.astype(mx.float32)
+ return mx.stack([values, values * 10.0], axis=-1)
+
+
+class DeterministicCausalLM(nn.Module):
+ def __init__(self, vocab_size: int = 64):
+ super().__init__()
+ self.model = DeterministicBackbone()
+ self.output = nn.Linear(2, vocab_size)
+
+ def __call__(self, x):
+ return self.output(self.model(x))
+
+
+class WeightedBackbone(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.hidden_size = 2
+ self.proj = nn.Linear(1, 2)
+
+ def __call__(self, x):
+ return self.proj(x.astype(mx.float32)[..., None])
+
+
+class WeightedCausalLM(nn.Module):
+ def __init__(self, vocab_size: int = 64):
+ super().__init__()
+ self.model = WeightedBackbone()
+ self.output = nn.Linear(2, vocab_size)
+
+ def __call__(self, x):
+ return self.output(self.model(x))
+
+
+def make_wrapper() -> MLXModelWrapper:
+ model = DeterministicCausalLM()
+ mx.eval(model.parameters())
+ wrapper = MLXModelWrapper(
+ model=model,
+ tokenizer=MockTokenizer(),
+ max_seq_length=32,
+ model_name="deterministic-model",
+ )
+ return wrapper
+
+
+def make_weighted_wrapper(seed: int) -> MLXModelWrapper:
+ mx.random.seed(seed)
+ model = WeightedCausalLM()
+ mx.eval(model.parameters())
+ return MLXModelWrapper(
+ model=model,
+ tokenizer=MockTokenizer(),
+ max_seq_length=32,
+ model_name=f"weighted-model-{seed}",
+ )
+
+
+def _set_scalar_head_to_first_feature(role_model) -> None:
+ role_model.head.update(
+ {
+ "weight": mx.array([[1.0, 0.0]], dtype=mx.float32),
+ "bias": mx.array([0.0], dtype=mx.float32),
+ },
+ strict=False,
+ )
+ mx.eval(role_model.head.parameters())
+
+
+def _parameter_snapshot(model) -> dict[str, mx.array]:
+ actual_model = model.model if hasattr(model, "model") else model
+ return {name: mx.array(value) for name, value in tree_flatten(actual_model.parameters())}
+
+
+def _parameters_match(before: dict[str, mx.array], after_model) -> bool:
+ actual_model = after_model.model if hasattr(after_model, "model") else after_model
+ after = {name: value for name, value in tree_flatten(actual_model.parameters())}
+ return all(mx.allclose(before[name], after[name]) for name in before)
+
+
+def test_reference_policy_clone_is_frozen_and_isolated():
+ policy = make_wrapper()
+ policy.lora_enabled = True
+ policy._lora_applied = True
+ policy.set_adapter_path("/tmp/live-policy")
+
+ roles = create_rl_model_roles(policy)
+ reference = roles.reference_policy
+
+ assert reference.model is not policy
+ assert reference.model.get_adapter_path() is None
+ assert not tree_flatten(reference.model.model.trainable_parameters())
+
+ reference_before = _parameter_snapshot(reference.model)
+ policy.model.output.update(
+ {
+ "weight": mx.ones_like(policy.model.output.weight),
+ "bias": mx.zeros_like(policy.model.output.bias),
+ },
+ strict=False,
+ )
+ mx.eval(policy.model.parameters())
+
+ assert _parameters_match(reference_before, reference.model)
+
+
+def test_scalar_roles_default_to_last_completion_token_and_support_mean_pooling():
+ base_model = make_wrapper()
+
+ reward_model = build_reward_model(base_model)
+ _set_scalar_head_to_first_feature(reward_model)
+
+ input_ids = mx.array([[2, 5, 7, 9], [2, 4, 6, 8]], dtype=mx.int32)
+ sequence_lengths = mx.array([4, 4], dtype=mx.int32)
+ prompt_lengths = mx.array([2, 3], dtype=mx.int32)
+ completion_lengths = mx.array([2, 1], dtype=mx.int32)
+
+ reward_scores = reward_model.score(
+ input_ids,
+ sequence_lengths=sequence_lengths,
+ prompt_lengths=prompt_lengths,
+ completion_lengths=completion_lengths,
+ )
+ assert mx.allclose(reward_scores, mx.array([9.0, 8.0], dtype=mx.float32))
+
+ mean_completion_model = build_value_model(base_model, pooling="mean_completion")
+ _set_scalar_head_to_first_feature(mean_completion_model)
+ mean_completion = mean_completion_model.predict(
+ input_ids,
+ sequence_lengths=sequence_lengths,
+ prompt_lengths=prompt_lengths,
+ completion_lengths=completion_lengths,
+ )
+ assert mx.allclose(mean_completion, mx.array([8.0, 8.0], dtype=mx.float32))
+
+ mean_sequence_model = build_value_model(base_model, pooling="mean_sequence", target="sequence")
+ _set_scalar_head_to_first_feature(mean_sequence_model)
+ mean_sequence = mean_sequence_model.predict(
+ input_ids,
+ sequence_lengths=sequence_lengths,
+ prompt_lengths=prompt_lengths,
+ completion_lengths=completion_lengths,
+ )
+ assert mx.allclose(mean_sequence, mx.array([5.75, 5.0], dtype=mx.float32))
+
+
+def test_scalar_role_save_load_round_trip_preserves_head_and_adapter_state(tmp_path):
+ base_model = make_wrapper()
+ base_model.lora_enabled = True
+ base_model._lora_applied = True
+
+ reward_model = build_reward_model(base_model)
+ _set_scalar_head_to_first_feature(reward_model)
+ reward_model.base_model.model.output.update(
+ {
+ "weight": mx.ones_like(reward_model.base_model.model.output.weight) * 0.5,
+ "bias": mx.ones_like(reward_model.base_model.model.output.bias) * 0.25,
+ },
+ strict=False,
+ )
+ mx.eval(reward_model.base_model.model.parameters())
+
+ output_dir = Path(tmp_path) / "reward_role"
+ reward_model.save_pretrained(str(output_dir))
+
+ restored = build_reward_model(make_wrapper())
+ restored.load_pretrained(str(output_dir))
+
+ score_inputs = mx.array([[2, 4, 6]], dtype=mx.int32)
+ sequence_lengths = mx.array([3], dtype=mx.int32)
+ prompt_lengths = mx.array([1], dtype=mx.int32)
+ completion_lengths = mx.array([2], dtype=mx.int32)
+
+ assert (output_dir / "head.safetensors").exists()
+ assert (output_dir / "head_config.json").exists()
+ assert (output_dir / "weights.safetensors").exists()
+ assert (output_dir / "adapters.safetensors").exists()
+ assert (output_dir / "adapter_config.json").exists()
+ assert mx.allclose(
+ reward_model.score(
+ score_inputs,
+ sequence_lengths=sequence_lengths,
+ prompt_lengths=prompt_lengths,
+ completion_lengths=completion_lengths,
+ ),
+ restored.score(
+ score_inputs,
+ sequence_lengths=sequence_lengths,
+ prompt_lengths=prompt_lengths,
+ completion_lengths=completion_lengths,
+ ),
+ )
+
+
+def test_scalar_role_save_load_round_trip_preserves_independent_backbone_weights(tmp_path):
+ original_base = make_weighted_wrapper(101)
+
+ reward_model = build_reward_model(original_base)
+ _set_scalar_head_to_first_feature(reward_model)
+
+ output_dir = Path(tmp_path) / "independent_reward_role"
+ reward_model.save_pretrained(str(output_dir))
+
+ restored = build_reward_model(make_weighted_wrapper(202))
+ restored.load_pretrained(str(output_dir))
+
+ sequence = mx.array([[2, 5, 9]], dtype=mx.int32)
+ sequence_lengths = mx.array([3], dtype=mx.int32)
+ prompt_lengths = mx.array([1], dtype=mx.int32)
+ completion_lengths = mx.array([2], dtype=mx.int32)
+
+ assert mx.allclose(
+ reward_model.base_model.model.model.proj.weight,
+ restored.base_model.model.model.proj.weight,
+ )
+ assert mx.allclose(
+ reward_model.score(
+ sequence,
+ sequence_lengths=sequence_lengths,
+ prompt_lengths=prompt_lengths,
+ completion_lengths=completion_lengths,
+ ),
+ restored.score(
+ sequence,
+ sequence_lengths=sequence_lengths,
+ prompt_lengths=prompt_lengths,
+ completion_lengths=completion_lengths,
+ ),
+ )
+
+
+def test_snapshot_false_scalar_builders_preserve_caller_adapter_path():
+ base_model = make_wrapper()
+ base_model.set_adapter_path("/tmp/existing-adapters")
+
+ reward_model = build_reward_model(base_model, snapshot=False)
+ value_model = build_value_model(base_model, snapshot=False)
+
+ assert reward_model.base_model is base_model
+ assert value_model.base_model is base_model
+ assert str(base_model.get_adapter_path()) == "/tmp/existing-adapters"
+
+
+def test_scalar_objective_helpers_return_stable_losses_and_metrics():
+ reward_model = build_reward_model(make_wrapper())
+ _set_scalar_head_to_first_feature(reward_model)
+
+ chosen_input_ids = mx.array([[2, 4, 9], [2, 3, 8]], dtype=mx.int32)
+ rejected_input_ids = mx.array([[2, 4, 5], [2, 3, 6]], dtype=mx.int32)
+ chosen_lengths = mx.array([3, 3], dtype=mx.int32)
+ rejected_lengths = mx.array([3, 3], dtype=mx.int32)
+ prompt_lengths = mx.array([1, 1], dtype=mx.int32)
+ completion_lengths = mx.array([2, 2], dtype=mx.int32)
+
+ reward_loss, reward_outputs = reward_model_pairwise_loss(
+ reward_model,
+ chosen_input_ids=chosen_input_ids,
+ rejected_input_ids=rejected_input_ids,
+ chosen_sequence_lengths=chosen_lengths,
+ rejected_sequence_lengths=rejected_lengths,
+ chosen_prompt_lengths=prompt_lengths,
+ rejected_prompt_lengths=prompt_lengths,
+ chosen_completion_lengths=completion_lengths,
+ rejected_completion_lengths=completion_lengths,
+ )
+ assert float(reward_loss.item()) > 0.0
+ assert pairwise_ranking_accuracy(
+ reward_outputs["chosen_scores"],
+ reward_outputs["rejected_scores"],
+ ) == 1.0
+
+ value_model = build_value_model(make_wrapper())
+ _set_scalar_head_to_first_feature(value_model)
+ value_targets = mx.array([9.0, 8.0], dtype=mx.float32)
+ value_loss, predictions = value_model_regression_loss(
+ value_model,
+ input_ids=chosen_input_ids,
+ sequence_lengths=chosen_lengths,
+ targets=value_targets,
+ prompt_lengths=prompt_lengths,
+ completion_lengths=completion_lengths,
+ )
+ metrics = scalar_loss_metrics(value_loss, predictions, value_targets)
+ assert metrics["loss"] == 0.0
+ assert metrics["mae"] == 0.0
+ assert metrics["mse"] == 0.0
+
+
+def test_build_reference_policy_public_builder_returns_compat_wrapper():
+ policy = make_wrapper()
+ reference = build_reference_policy(policy)
+
+ assert reference.source == "policy_snapshot"
+ assert reference.metadata["snapshot_strategy"] == "clone_and_freeze"
diff --git a/tests/test_rl_runtime.py b/tests/test_rl_runtime.py
new file mode 100644
index 0000000..abdccda
--- /dev/null
+++ b/tests/test_rl_runtime.py
@@ -0,0 +1,917 @@
+import math
+
+import mlx.core as mx
+import mlx.nn as nn
+
+
+class TinyModel(nn.Module):
+ def __init__(self, vocab_size: int = 32, hidden_size: int = 16):
+ super().__init__()
+ self.embedding = nn.Embedding(vocab_size, hidden_size)
+ self.output = nn.Linear(hidden_size, vocab_size)
+
+ def __call__(self, x):
+ return self.output(self.embedding(x))
+
+
+class ScriptedModel(nn.Module):
+ def __init__(self, next_tokens, vocab_size: int = 16):
+ super().__init__()
+ self.next_tokens = dict(next_tokens)
+ self.vocab_size = vocab_size
+
+ def __call__(self, x):
+ batch, seq_len = x.shape
+ logits = mx.full((batch, seq_len, self.vocab_size), -100.0)
+ token_id = self.next_tokens.get(seq_len, self.next_tokens.get("default", 0))
+ logits[:, -1, token_id] = 100.0
+ return logits
+
+
+class CacheOnlyScriptedModel(nn.Module):
+ def __init__(self, next_tokens, vocab_size: int = 16):
+ super().__init__()
+ self.next_tokens = dict(next_tokens)
+ self.vocab_size = vocab_size
+ self.calls = []
+
+ def make_cache(self):
+ return [{"steps": 0}]
+
+ def __call__(self, x, cache=None):
+ if cache is None:
+ raise ValueError("cache is required")
+ batch, seq_len = x.shape
+ cache[0]["steps"] += 1
+ self.calls.append((seq_len, cache[0]["steps"]))
+ logits = mx.full((batch, seq_len, self.vocab_size), -100.0)
+ token_id = self.next_tokens.get(cache[0]["steps"], self.next_tokens.get("default", 0))
+ logits[:, -1, token_id] = 100.0
+ return logits
+
+
+class BatchSizedCacheScriptedModel(nn.Module):
+ def __init__(self, vocab_size: int = 16):
+ super().__init__()
+ self.vocab_size = vocab_size
+ self.calls = []
+
+ def make_cache(self):
+ return [{"steps": 0, "batch_size": None}]
+
+ def __call__(self, x, cache=None):
+ if cache is None:
+ raise ValueError("cache is required")
+ batch, seq_len = x.shape
+ if cache[0]["batch_size"] is None:
+ cache[0]["batch_size"] = batch
+ elif cache[0]["batch_size"] != batch:
+ raise ValueError("cache batch size mismatch")
+ cache[0]["steps"] += 1
+ self.calls.append((batch, seq_len, cache[0]["steps"]))
+ logits = mx.full((batch, seq_len, self.vocab_size), -100.0)
+ for row_index in range(batch):
+ if cache[0]["steps"] == 1:
+ token_id = 1 if row_index == 0 else 5
+ elif cache[0]["steps"] == 2:
+ token_id = 1
+ else:
+ token_id = 1
+ logits[row_index, -1, token_id] = 100.0
+ return logits
+
+
+class BatchSensitiveCacheModel(nn.Module):
+ def __init__(self, vocab_size: int = 16):
+ super().__init__()
+ self.vocab_size = vocab_size
+ self.cache_calls = []
+
+ def make_cache(self):
+ return [{"steps": 0}]
+
+ def _token_for_sequence(self, row):
+ last_token = int(row[-1]) if len(row) else 0
+ if last_token == 5:
+ return 1
+ return 5
+
+ def __call__(self, x, cache=None):
+ batch, seq_len = x.shape
+ logits = mx.full((batch, seq_len, self.vocab_size), -100.0)
+ if cache is None:
+ for row_index in range(batch):
+ row = x[row_index].tolist()
+ for position in range(seq_len):
+ logits[row_index, position, self._token_for_sequence(row[: position + 1])] = 100.0
+ return logits
+
+ cache[0]["steps"] += 1
+ self.cache_calls.append((batch, seq_len, cache[0]["steps"]))
+ for row_index in range(batch):
+ token_id = self._token_for_sequence(x[row_index].tolist())
+ if batch > 1 and seq_len == 1:
+ token_id = 7
+ logits[row_index, -1, token_id] = 100.0
+ return logits
+
+
+class TinyTokenizer:
+ pad_token_id = 0
+ eos_token_id = 1
+ bos_token_id = 2
+
+ def encode(self, text: str, add_special_tokens: bool = True):
+ if text == "<|im_end|>":
+ return [9]
+ ids = [((ord(char) % 10) + 3) for char in text]
+ if add_special_tokens:
+ ids = [self.bos_token_id] + ids + [self.eos_token_id]
+ return ids
+
+ def decode(self, ids, skip_special_tokens: bool = True):
+ if skip_special_tokens:
+ ids = [token for token in ids if token not in (self.pad_token_id, self.eos_token_id, self.bos_token_id)]
+ return "".join(chr(65 + (token % 26)) for token in ids)
+
+ def convert_tokens_to_ids(self, token: str):
+ if token == "<|im_end|>":
+ return 9
+ return None
+
+ def get_vocab(self):
+ return {"<|im_end|>": 9}
+
+
+def test_collect_rollouts_returns_metadata_and_truncates_prompt_left():
+ from mlx_tune._rl_runtime import collect_rollouts
+
+ tokenizer = TinyTokenizer()
+ model = ScriptedModel({2: 5, 3: 1, "default": 1})
+
+ rollout = collect_rollouts(
+ policy=model,
+ tokenizer=tokenizer,
+ prompt_samples=[
+ {
+ "sample_index": 7,
+ "prompt": "abcdef",
+ "prompt_ids": [3, 4, 5, 6, 7, 8],
+ "reward_context": "ctx",
+ }
+ ],
+ sampling_config={
+ "num_generations": 2,
+ "temperature": 0.0,
+ "max_completion_length": 4,
+ "max_seq_length": 6,
+ },
+ )
+
+ assert rollout.prompt_ids == [[4, 5, 6, 7, 8], [4, 5, 6, 7, 8]]
+ assert rollout.prompt_lengths.tolist() == [5, 5]
+ assert rollout.prompt_texts == [tokenizer.decode([4, 5, 6, 7, 8]), tokenizer.decode([4, 5, 6, 7, 8])]
+ assert rollout.original_prompt_texts == ["abcdef", "abcdef"]
+ assert rollout.completion_ids == [[1], [1]]
+ assert rollout.completion_lengths.tolist() == [1, 1]
+ assert rollout.sampled_token_logprobs.shape == (2, 1)
+ assert rollout.eos_flags.tolist() == [True, True]
+ assert rollout.truncation_flags.tolist() == [False, False]
+ assert rollout.prompt_group_indices.tolist() == [0, 0]
+ assert rollout.sample_indices.tolist() == [7, 7]
+ assert rollout.policy_eval.input_ids.shape == (2, 6)
+ assert mx.allclose(
+ rollout.rollout_logprobs,
+ rollout.sampled_token_logprobs.sum(axis=-1),
+ )
+
+
+def test_collect_rollouts_can_capture_sample_stats_and_truncation():
+ from mlx_tune._rl_runtime import collect_rollouts
+
+ tokenizer = TinyTokenizer()
+ model = ScriptedModel({2: 5, 3: 6, 4: 7, "default": 7})
+
+ rollout = collect_rollouts(
+ policy=model,
+ tokenizer=tokenizer,
+ prompt_samples=[
+ {
+ "sample_index": 1,
+ "prompt": "ab",
+ "prompt_ids": [3, 4],
+ "reward_context": "ctx",
+ }
+ ],
+ sampling_config={
+ "num_generations": 1,
+ "temperature": 0.0,
+ "max_completion_length": 2,
+ "max_seq_length": 8,
+ },
+ collect_sample_stats=True,
+ )
+
+ assert rollout.completion_ids == [[5, 6]]
+ assert rollout.eos_flags.tolist() == [False]
+ assert rollout.truncation_flags.tolist() == [True]
+ assert rollout.sampled_token_logits.shape == (1, 2)
+ assert rollout.token_entropies.shape == (1, 2)
+
+
+def test_collect_rollouts_sampled_logprobs_match_rescored_completion_logprobs():
+ from mlx_tune._rl_runtime import collect_rollouts, score_policy
+
+ tokenizer = TinyTokenizer()
+ mx.random.seed(0)
+ model = TinyModel()
+ mx.eval(model.parameters())
+
+ rollout = collect_rollouts(
+ policy=model,
+ tokenizer=tokenizer,
+ prompt_samples=[
+ {
+ "sample_index": 0,
+ "prompt": "abc",
+ "prompt_ids": [tokenizer.bos_token_id, 3, 4, 5],
+ "reward_context": "ctx",
+ }
+ ],
+ sampling_config={
+ "num_generations": 4,
+ "temperature": 1.0,
+ "max_completion_length": 5,
+ "max_seq_length": 16,
+ },
+ )
+
+ rescored = score_policy(model, rollout.policy_eval, mode="completion", temperature=1.0)
+ masked_rescored = rescored.token_logprobs * rollout.policy_eval.token_mask.astype(mx.float32)
+
+ assert mx.allclose(rollout.rollout_logprobs, rescored.summed_logprobs, atol=1e-5)
+ assert mx.allclose(rollout.policy_eval.old_token_logprobs, masked_rescored, atol=1e-5)
+
+
+def test_collect_rollouts_respects_max_seq_length_during_generation():
+ from mlx_tune._rl_runtime import collect_rollouts
+
+ tokenizer = TinyTokenizer()
+ model = ScriptedModel({2: 5, 3: 6, 4: 7, 5: 8, 6: 9, "default": 9})
+
+ rollout = collect_rollouts(
+ policy=model,
+ tokenizer=tokenizer,
+ prompt_samples=[
+ {
+ "sample_index": 0,
+ "prompt": "abcdef",
+ "prompt_ids": [3, 4, 5, 6, 7, 8],
+ "reward_context": "ctx",
+ }
+ ],
+ sampling_config={
+ "num_generations": 1,
+ "temperature": 0.0,
+ "max_completion_length": 5,
+ "max_seq_length": 3,
+ },
+ )
+
+ assert rollout.prompt_lengths.tolist() == [2]
+ assert rollout.completion_lengths.tolist() == [1]
+ assert rollout.policy_eval.input_ids.shape == (1, 3)
+ assert rollout.truncation_flags.tolist() == [True]
+
+
+def test_collect_rollouts_reduces_completion_budget_before_collapsing_prompt():
+ from mlx_tune._rl_runtime import collect_rollouts
+
+ tokenizer = TinyTokenizer()
+ model = ScriptedModel({6: 5, 7: 6, "default": 6})
+
+ rollout = collect_rollouts(
+ policy=model,
+ tokenizer=tokenizer,
+ prompt_samples=[
+ {
+ "sample_index": 0,
+ "prompt": "abcdef",
+ "prompt_ids": [3, 4, 5, 6, 7, 8],
+ "reward_context": "ctx",
+ }
+ ],
+ sampling_config={
+ "num_generations": 1,
+ "temperature": 0.0,
+ "max_completion_length": 6,
+ "max_seq_length": 6,
+ },
+ )
+
+ assert rollout.prompt_ids == [[4, 5, 6, 7, 8]]
+ assert rollout.prompt_lengths.tolist() == [5]
+ assert rollout.completion_ids == [[6]]
+ assert rollout.completion_lengths.tolist() == [1]
+ assert rollout.policy_eval.input_ids.shape == (1, 6)
+ assert rollout.truncation_flags.tolist() == [True]
+
+
+def test_collect_rollouts_treats_unsloth_stop_token_as_terminal():
+ from mlx_tune._rl_runtime import collect_rollouts, sample_completion
+
+ tokenizer = TinyTokenizer()
+ tokenizer._unsloth_stop_token = "<|im_end|>"
+ model = ScriptedModel({2: 9, 3: 7, 4: 7, "default": 7})
+
+ sampled = sample_completion(
+ policy=model,
+ tokenizer=tokenizer,
+ prompt_ids=[3, 4],
+ max_tokens=4,
+ temperature=0.0,
+ )
+ rollout = collect_rollouts(
+ policy=model,
+ tokenizer=tokenizer,
+ prompt_samples=[
+ {
+ "sample_index": 0,
+ "prompt": "ab",
+ "prompt_ids": [3, 4],
+ "reward_context": "ctx",
+ }
+ ],
+ sampling_config={
+ "num_generations": 1,
+ "temperature": 0.0,
+ "max_completion_length": 4,
+ },
+ )
+
+ assert sampled["completion_ids"] == [9]
+ assert sampled["eos_flag"] is True
+ assert sampled["truncation_flag"] is False
+ assert rollout.completion_ids == [[9]]
+ assert rollout.eos_flags.tolist() == [True]
+ assert rollout.truncation_flags.tolist() == [False]
+
+
+def test_collect_rollouts_initializes_and_reuses_prompt_cache_for_logits_only_models():
+ from mlx_tune._rl_runtime import collect_rollouts
+
+ tokenizer = TinyTokenizer()
+ model = CacheOnlyScriptedModel({1: 5, 2: 1, "default": 1})
+
+ rollout = collect_rollouts(
+ policy=model,
+ tokenizer=tokenizer,
+ prompt_samples=[
+ {
+ "sample_index": 0,
+ "prompt": "ab",
+ "prompt_ids": [3, 4],
+ "reward_context": "ctx",
+ }
+ ],
+ sampling_config={
+ "num_generations": 2,
+ "temperature": 0.0,
+ "max_completion_length": 4,
+ "generation_batch_size": 2,
+ },
+ )
+
+ assert rollout.completion_ids == [[5, 1], [5, 1]]
+ assert rollout.eos_flags.tolist() == [True, True]
+ assert model.calls == [(2, 1), (2, 1), (1, 2), (1, 2)]
+
+
+def test_collect_rollouts_rebuilds_prompt_cache_when_active_batch_shrinks():
+ from mlx_tune._rl_runtime import collect_rollouts
+
+ tokenizer = TinyTokenizer()
+ model = BatchSizedCacheScriptedModel()
+
+ rollout = collect_rollouts(
+ policy=model,
+ tokenizer=tokenizer,
+ prompt_samples=[
+ {
+ "sample_index": 0,
+ "prompt": "ab",
+ "prompt_ids": [3, 4],
+ "reward_context": "ctx-a",
+ },
+ {
+ "sample_index": 1,
+ "prompt": "ac",
+ "prompt_ids": [3, 5],
+ "reward_context": "ctx-b",
+ },
+ ],
+ sampling_config={
+ "num_generations": 1,
+ "temperature": 0.0,
+ "max_completion_length": 3,
+ "generation_batch_size": 2,
+ },
+ )
+
+ assert rollout.completion_ids == [[1], [1]]
+ assert rollout.eos_flags.tolist() == [True, True]
+ assert rollout.truncation_flags.tolist() == [False, False]
+ assert model.calls == [(1, 2, 1), (1, 2, 1)]
+
+
+def test_collect_rollouts_uses_per_row_cache_to_preserve_rescored_logprobs():
+ from mlx_tune._rl_runtime import collect_rollouts, score_policy
+
+ tokenizer = TinyTokenizer()
+ model = BatchSensitiveCacheModel()
+
+ rollout = collect_rollouts(
+ policy=model,
+ tokenizer=tokenizer,
+ prompt_samples=[
+ {
+ "sample_index": 0,
+ "prompt": "ab",
+ "prompt_ids": [3, 4],
+ "reward_context": "ctx",
+ }
+ ],
+ sampling_config={
+ "num_generations": 2,
+ "temperature": 0.0,
+ "max_completion_length": 3,
+ "generation_batch_size": 2,
+ },
+ )
+
+ rescored = score_policy(model, rollout.policy_eval, mode="completion", temperature=1.0)
+ masked_rescored = rescored.token_logprobs * rollout.policy_eval.token_mask.astype(mx.float32)
+
+ assert rollout.completion_ids == [[5, 1], [5, 1]]
+ assert all(batch == 1 for batch, _, _ in model.cache_calls)
+ assert mx.allclose(rollout.rollout_logprobs, rescored.summed_logprobs, atol=1e-5)
+ assert mx.allclose(rollout.policy_eval.old_token_logprobs, masked_rescored, atol=1e-5)
+
+
+def test_score_policy_matches_public_logprob_helpers():
+ from mlx_tune._rl_runtime import make_policy_eval_batch, score_policy
+ from mlx_tune.losses import compute_completion_log_probs, compute_log_probs_with_lengths
+
+ model = TinyModel()
+ mx.eval(model.parameters())
+
+ sequence_batch = make_policy_eval_batch(
+ [[1, 2, 3, 4], [1, 5, 6]],
+ pad_id=0,
+ mode="sequence",
+ )
+ sequence_scores = score_policy(model, sequence_batch, mode="sequence")
+ direct_sequence = compute_log_probs_with_lengths(
+ model,
+ sequence_batch.input_ids,
+ sequence_batch.sequence_lengths,
+ )
+
+ completion_batch = make_policy_eval_batch(
+ [[1, 2, 3, 4, 5], [1, 4, 3, 2]],
+ pad_id=0,
+ mode="completion",
+ prompt_lengths=[3, 2],
+ completion_lengths=[2, 2],
+ )
+ completion_scores = score_policy(model, completion_batch, mode="completion")
+ direct_completion = compute_completion_log_probs(
+ model,
+ completion_batch.input_ids,
+ completion_batch.prompt_lengths,
+ completion_batch.completion_lengths,
+ )
+
+ assert mx.allclose(sequence_scores.summed_logprobs, direct_sequence)
+ assert mx.allclose(completion_scores.summed_logprobs, direct_completion)
+
+
+def test_reference_precompute_helpers_match_runtime_scorer():
+ from mlx_tune._rl_runtime import make_policy_eval_batch, score_policy_in_chunks
+ from mlx_tune.losses import (
+ precompute_kto_reference_logprobs,
+ precompute_preference_reference_logprobs,
+ )
+
+ model = TinyModel()
+ mx.eval(model.parameters())
+
+ chosen = make_policy_eval_batch([[1, 2, 3, 4], [1, 4, 5]], pad_id=0, mode="sequence")
+ rejected = make_policy_eval_batch([[1, 3, 2, 4], [1, 6, 5]], pad_id=0, mode="sequence")
+ direct_chosen = score_policy_in_chunks(model, chosen, batch_size=1, mode="sequence").summed_logprobs
+ direct_rejected = score_policy_in_chunks(model, rejected, batch_size=1, mode="sequence").summed_logprobs
+ cached_chosen, cached_rejected = precompute_preference_reference_logprobs(
+ model,
+ chosen.input_ids,
+ rejected.input_ids,
+ chosen.sequence_lengths,
+ rejected.sequence_lengths,
+ batch_size=1,
+ )
+
+ kto_batch = make_policy_eval_batch([[1, 2, 3], [1, 4, 5, 6]], pad_id=0, mode="sequence")
+ direct_kto = score_policy_in_chunks(model, kto_batch, batch_size=1, mode="sequence").summed_logprobs
+ cached_kto = precompute_kto_reference_logprobs(
+ model,
+ kto_batch.input_ids,
+ kto_batch.sequence_lengths,
+ batch_size=1,
+ )
+
+ assert mx.allclose(direct_chosen, cached_chosen)
+ assert mx.allclose(direct_rejected, cached_rejected)
+ assert mx.allclose(direct_kto, cached_kto)
+
+
+def test_reward_adapter_supports_legacy_and_structured_evaluators():
+ from mlx_tune._rl_runtime import collect_rollouts, evaluate_rewards
+
+ tokenizer = TinyTokenizer()
+ model = ScriptedModel({2: 5, 3: 1, "default": 1})
+ rollout = collect_rollouts(
+ policy=model,
+ tokenizer=tokenizer,
+ prompt_samples=[
+ {
+ "sample_index": 0,
+ "prompt": "ab",
+ "prompt_ids": [3, 4],
+ "reward_context": "ctx",
+ }
+ ],
+ sampling_config={"num_generations": 1, "temperature": 0.0, "max_completion_length": 2},
+ )
+
+ legacy = evaluate_rewards(
+ rollout,
+ lambda response, context: float(len(response) + len(context)),
+ )
+
+ seen_payloads = []
+
+ class StructuredEvaluator:
+ def evaluate(self, payload):
+ seen_payloads.append(payload)
+ return {
+ "reward": float(payload["completion_length"]),
+ "components": {"length": float(payload["completion_length"])},
+ "diagnostics": {"used_context": payload["reward_context"]},
+ }
+
+ structured = evaluate_rewards(rollout, StructuredEvaluator())
+
+ assert legacy.scalar_rewards.tolist() == [float(len(rollout.completion_texts[0]) + 3)]
+ assert structured.scalar_rewards.tolist() == [2.0]
+ assert structured.named_reward_components == [{"length": 2.0}]
+ assert structured.diagnostics == [{"used_context": "ctx"}]
+ assert seen_payloads[0]["prompt_text"] == tokenizer.decode([3, 4])
+ assert seen_payloads[0]["original_prompt_text"] == "ab"
+
+
+def test_reward_payload_exposes_effective_and_original_prompt_when_truncated():
+ from mlx_tune._rl_runtime import collect_rollouts, evaluate_rewards
+
+ tokenizer = TinyTokenizer()
+ model = ScriptedModel({2: 5, 3: 1, "default": 1})
+ captured = []
+
+ rollout = collect_rollouts(
+ policy=model,
+ tokenizer=tokenizer,
+ prompt_samples=[
+ {
+ "sample_index": 0,
+ "prompt": "abcdef",
+ "prompt_ids": [3, 4, 5, 6, 7, 8],
+ "reward_context": "ctx",
+ }
+ ],
+ sampling_config={
+ "num_generations": 1,
+ "temperature": 0.0,
+ "max_completion_length": 2,
+ "max_seq_length": 4,
+ },
+ )
+
+ evaluate_rewards(
+ rollout,
+ lambda payload: captured.append(payload) or float(payload["completion_length"]),
+ )
+
+ assert captured[0]["prompt_ids"] == [6, 7, 8]
+ assert captured[0]["prompt_text"] == tokenizer.decode([6, 7, 8])
+ assert captured[0]["original_prompt_text"] == "abcdef"
+
+
+def test_compute_advantages_uses_zero_variance_fallback_per_prompt():
+ from mlx_tune._rl_runtime import RewardBatch, compute_advantages
+
+ reward_batch = RewardBatch(
+ prompt_texts=["p0", "p0", "p1", "p1"],
+ completion_texts=["a", "b", "c", "d"],
+ reward_contexts=["c0", "c0", "c1", "c1"],
+ scalar_rewards=mx.array([2.0, 2.0, 1.0, 3.0], dtype=mx.float32),
+ prompt_group_indices=mx.array([0, 0, 1, 1]),
+ )
+
+ advantages = compute_advantages(reward_batch)
+
+ assert mx.allclose(advantages, mx.array([0.0, 0.0, -1.0, 1.0], dtype=mx.float32))
+
+
+def test_assemble_minibatches_preserves_order_and_prompt_groups():
+ from mlx_tune._rl_runtime import assemble_minibatches, make_policy_eval_batch
+
+ batch = make_policy_eval_batch(
+ [[1, 2, 3], [1, 4, 5], [1, 6, 7]],
+ pad_id=0,
+ mode="sequence",
+ rollout_logprobs=mx.array([0.1, 0.2, 0.3], dtype=mx.float32),
+ advantages=mx.array([1.0, 2.0, 3.0], dtype=mx.float32),
+ prompt_group_indices=mx.array([9, 9, 10]),
+ sample_indices=mx.array([0, 1, 2]),
+ )
+
+ minibatches = list(assemble_minibatches(batch, minibatch_size=2, shuffle=False))
+
+ assert len(minibatches) == 2
+ assert minibatches[0].input_ids.tolist() == [[1, 2, 3], [1, 4, 5]]
+ assert minibatches[0].prompt_group_indices.tolist() == [9, 9]
+ assert minibatches[1].input_ids.tolist() == [[1, 6, 7]]
+ assert minibatches[1].prompt_group_indices.tolist() == [10]
+
+
+def test_post_rollout_helpers_attach_reference_values_and_rloo_advantages():
+ from mlx_tune._rl_runtime import (
+ collect_rollouts,
+ compute_returns_and_advantages,
+ predict_rollout_values,
+ rank_grouped_rollouts,
+ score_rollout_references,
+ )
+
+ class ConstantValueModel:
+ def predict(self, input_ids, **kwargs):
+ del kwargs
+ return mx.array([0.5] * input_ids.shape[0], dtype=mx.float32)
+
+ tokenizer = TinyTokenizer()
+ policy = ScriptedModel({2: 5, 3: 1, "default": 1})
+ reference = ScriptedModel({2: 5, 3: 1, "default": 1})
+
+ rollout = collect_rollouts(
+ policy=policy,
+ tokenizer=tokenizer,
+ prompt_samples=[
+ {"sample_index": 0, "prompt": "ab", "prompt_ids": [3, 4], "reward_context": "x"},
+ {"sample_index": 1, "prompt": "cd", "prompt_ids": [5, 6], "reward_context": "y"},
+ ],
+ sampling_config={"num_generations": 2, "temperature": 0.0, "max_completion_length": 2},
+ )
+ rollout.rewards = mx.array([3.0, 1.0, 4.0, 2.0], dtype=mx.float32)
+
+ rollout = score_rollout_references(reference, rollout, batch_size=1)
+ rollout = predict_rollout_values(ConstantValueModel(), rollout, batch_size=2)
+ returns, advantages = compute_returns_and_advantages(
+ rollout.rewards,
+ prompt_group_indices=rollout.prompt_group_indices,
+ mode="rloo",
+ )
+ rollout.returns = returns
+ rollout.advantages = advantages
+ rankings = rank_grouped_rollouts(rollout)
+
+ assert rollout.reference_logprobs.shape == (4,)
+ assert rollout.value_predictions.tolist() == [0.5, 0.5, 0.5, 0.5]
+ assert returns.tolist() == [3.0, 1.0, 4.0, 2.0]
+ assert advantages.tolist() == [2.0, -2.0, 2.0, -2.0]
+ assert rankings[0]["best_position"] == 0
+ assert rankings[0]["worst_position"] == 1
+
+
+def test_length_normalization_and_kl_helpers_match_expected_math():
+ from mlx_tune._rl_runtime import kl_against_reference, normalize_logprobs
+
+ summed = mx.array([4.0, 6.0], dtype=mx.float32)
+ lengths = mx.array([2, 3])
+ normalized = normalize_logprobs(summed, lengths, mode="mean")
+ kl = kl_against_reference(
+ mx.array([0.0, 1.0], dtype=mx.float32),
+ mx.array([0.0, 0.5], dtype=mx.float32),
+ )
+
+ assert mx.allclose(normalized, mx.array([2.0, 2.0], dtype=mx.float32))
+ assert mx.allclose(
+ kl,
+ mx.array([0.0, float((mx.exp(mx.array(0.5)) - 0.5 - 1.0).item())], dtype=mx.float32),
+ )
+
+
+def test_collect_rollouts_batched_decode_matches_per_sequence_sampler():
+ from mlx_tune._rl_runtime import collect_rollouts, sample_completion
+
+ tokenizer = TinyTokenizer()
+ model = ScriptedModel({2: 5, 3: 6, 4: 1, "default": 1})
+ prompt_samples = [
+ {"sample_index": 0, "prompt": "ab", "prompt_ids": [3, 4], "reward_context": "x"},
+ {"sample_index": 1, "prompt": "cd", "prompt_ids": [5, 6], "reward_context": "y"},
+ ]
+
+ expected = []
+ for sample in prompt_samples:
+ for _ in range(2):
+ expected.append(
+ sample_completion(
+ policy=model,
+ tokenizer=tokenizer,
+ prompt_ids=sample["prompt_ids"],
+ max_tokens=3,
+ temperature=0.0,
+ collect_sample_stats=True,
+ )
+ )
+
+ rollout = collect_rollouts(
+ policy=model,
+ tokenizer=tokenizer,
+ prompt_samples=prompt_samples,
+ sampling_config={
+ "num_generations": 2,
+ "temperature": 0.0,
+ "max_completion_length": 3,
+ "generation_batch_size": 2,
+ },
+ collect_sample_stats=True,
+ )
+
+ assert rollout.completion_ids == [item["completion_ids"] for item in expected]
+ assert rollout.eos_flags.tolist() == [item["eos_flag"] for item in expected]
+ assert rollout.truncation_flags.tolist() == [item["truncation_flag"] for item in expected]
+ assert mx.allclose(
+ rollout.sampled_token_logprobs,
+ mx.array([item["sampled_logprobs"] for item in expected], dtype=mx.float32),
+ )
+ assert mx.allclose(
+ rollout.sampled_token_logits,
+ mx.array([item["sampled_logits"] for item in expected], dtype=mx.float32),
+ )
+ assert mx.allclose(
+ rollout.token_entropies,
+ mx.array([item["token_entropies"] for item in expected], dtype=mx.float32),
+ )
+
+
+def test_reward_adapter_prefers_evaluate_batch_when_available():
+ from mlx_tune._rl_runtime import collect_rollouts, evaluate_rewards
+
+ tokenizer = TinyTokenizer()
+ model = ScriptedModel({2: 5, 3: 1, "default": 1})
+ rollout = collect_rollouts(
+ policy=model,
+ tokenizer=tokenizer,
+ prompt_samples=[
+ {"sample_index": 0, "prompt": "ab", "prompt_ids": [3, 4], "reward_context": "ctx0"},
+ {"sample_index": 1, "prompt": "cd", "prompt_ids": [5, 6], "reward_context": "ctx1"},
+ ],
+ sampling_config={"num_generations": 1, "temperature": 0.0, "max_completion_length": 2},
+ )
+
+ calls = []
+
+ class BatchEvaluator:
+ def evaluate_batch(self, payloads):
+ calls.append([payload["reward_context"] for payload in payloads])
+ return [
+ {"reward": float(payload["completion_length"]), "components": {"len": float(payload["completion_length"])}}
+ for payload in payloads
+ ]
+
+ reward_batch = evaluate_rewards(rollout, BatchEvaluator())
+
+ assert calls == [["ctx0", "ctx1"]]
+ assert reward_batch.scalar_rewards.tolist() == [2.0, 2.0]
+ assert reward_batch.named_reward_components == [{"len": 2.0}, {"len": 2.0}]
+
+
+def test_token_budget_chunking_matches_unchunked_policy_and_value_scoring():
+ from mlx_tune._rl_runtime import (
+ collect_rollouts,
+ make_policy_eval_batch,
+ predict_rollout_values,
+ score_policy_in_chunks,
+ )
+
+ class LengthValueModel:
+ def predict(self, input_ids, sequence_lengths, **kwargs):
+ del input_ids, kwargs
+ return sequence_lengths.astype(mx.float32)
+
+ model = TinyModel()
+ mx.eval(model.parameters())
+ batch = make_policy_eval_batch(
+ [[1, 2, 3, 4, 5], [1, 4, 5], [1, 6, 7, 8], [1, 9, 10, 11, 12, 13]],
+ pad_id=0,
+ mode="sequence",
+ )
+
+ full_scores = score_policy_in_chunks(model, batch, batch_size=8, mode="sequence")
+ chunked_scores = score_policy_in_chunks(
+ model,
+ batch,
+ batch_size=8,
+ token_budget=4,
+ mode="sequence",
+ )
+
+ tokenizer = TinyTokenizer()
+ rollout = collect_rollouts(
+ policy=ScriptedModel({2: 5, 3: 1, "default": 1}),
+ tokenizer=tokenizer,
+ prompt_samples=[
+ {"sample_index": 0, "prompt": "ab", "prompt_ids": [3, 4], "reward_context": "x"},
+ {"sample_index": 1, "prompt": "cde", "prompt_ids": [5, 6, 7], "reward_context": "y"},
+ ],
+ sampling_config={"num_generations": 2, "temperature": 0.0, "max_completion_length": 2},
+ )
+ full_values = predict_rollout_values(LengthValueModel(), rollout, batch_size=8).value_predictions
+ chunked_values = predict_rollout_values(
+ LengthValueModel(),
+ rollout,
+ batch_size=8,
+ token_budget=4,
+ ).value_predictions
+
+ assert mx.allclose(full_scores.summed_logprobs, chunked_scores.summed_logprobs)
+ assert mx.allclose(full_scores.token_logprobs, chunked_scores.token_logprobs)
+ assert mx.allclose(full_values, chunked_values)
+
+
+def test_summarize_rollout_metrics_includes_completion_and_stop_stats():
+ from mlx_tune._rl_runtime import collect_rollouts, summarize_rollout_metrics
+
+ tokenizer = TinyTokenizer()
+ tokenizer._unsloth_stop_token = "<|im_end|>"
+ rollout = collect_rollouts(
+ policy=ScriptedModel({2: 9, 3: 7, 4: 7, "default": 7}),
+ tokenizer=tokenizer,
+ prompt_samples=[
+ {"sample_index": 0, "prompt": "ab", "prompt_ids": [3, 4], "reward_context": "x"},
+ {"sample_index": 1, "prompt": "cd", "prompt_ids": [5, 6], "reward_context": "y"},
+ ],
+ sampling_config={"num_generations": 1, "temperature": 0.0, "max_completion_length": 4},
+ )
+ rollout.rewards = mx.array([1.0, 0.0], dtype=mx.float32)
+
+ metrics = summarize_rollout_metrics(rollout, policy_loss=0.5)
+
+ assert metrics["completion_length_mean"] == 1.0
+ assert metrics["completion_length_max"] == 1.0
+ assert metrics["eos_rate"] == 1.0
+ assert metrics["truncation_rate"] == 0.0
+ assert metrics["reward_mean"] == 0.5
+ assert metrics["policy_loss"] == 0.5
+
+
+def test_summarize_rollout_metrics_normalizes_kl_by_completion_length():
+ from mlx_tune._rl_runtime import RolloutBatch, summarize_rollout_metrics
+
+ rollout = RolloutBatch(
+ prompt_ids=[[1, 2], [3, 4]],
+ completion_ids=[[5], [6]],
+ prompt_texts=["p1", "p2"],
+ original_prompt_texts=None,
+ completion_texts=["c1", "c2"],
+ reward_contexts=["x", "y"],
+ prompt_lengths=mx.array([2, 2], dtype=mx.int32),
+ completion_lengths=mx.array([10, 1000], dtype=mx.int32),
+ sampled_token_logprobs=mx.zeros((2, 1), dtype=mx.float32),
+ rollout_logprobs=mx.array([0.0, 0.0], dtype=mx.float32),
+ eos_flags=mx.array([True, True]),
+ truncation_flags=mx.array([False, False]),
+ prompt_group_indices=mx.array([0, 1], dtype=mx.int32),
+ policy_eval=None,
+ sample_indices=None,
+ sampled_token_logits=None,
+ token_entropies=None,
+ old_logprobs=None,
+ rewards=mx.array([1.0, 0.0], dtype=mx.float32),
+ reference_logprobs=mx.array([-10.0, -1000.0], dtype=mx.float32),
+ returns=None,
+ advantages=None,
+ )
+
+ metrics = summarize_rollout_metrics(rollout)
+
+ assert abs(metrics["logprob_delta_per_token_mean"] - 1.0) < 1e-6
+ assert abs(metrics["kl_to_reference_mean"] - (math.e - 2.0)) < 1e-5
diff --git a/tests/test_rl_trainers_integration.py b/tests/test_rl_trainers_integration.py
index cfadc55..c3e18cd 100644
--- a/tests/test_rl_trainers_integration.py
+++ b/tests/test_rl_trainers_integration.py
@@ -1,153 +1,218 @@
"""
-Integration tests for RL trainers (DPO, ORPO, GRPO, KTO, SimPO).
-
-These tests verify that the trainers actually run and produce valid results,
-not just that they can be imported or configured.
-
-Tests marked with @pytest.mark.integration require more time/resources.
+Integration tests for RL trainers.
"""
-import pytest
+import json
+from pathlib import Path
+
import mlx.core as mx
import mlx.nn as nn
-from typing import Dict, Any, List
-
-
-# =============================================================================
-# TEST FIXTURES - Small models and datasets for fast testing
-# =============================================================================
+import pytest
+from mlx.utils import tree_flatten
-class SmallLanguageModel(nn.Module):
- """A tiny language model for testing - fast to train."""
- def __init__(self, vocab_size: int = 100, hidden_size: int = 64, num_layers: int = 2):
+class SmallBackbone(nn.Module):
+ def __init__(self, vocab_size: int = 64, hidden_size: int = 32, num_layers: int = 2):
super().__init__()
- self.vocab_size = vocab_size
- self.hidden_size = hidden_size
self.embedding = nn.Embedding(vocab_size, hidden_size)
self.layers = [nn.Linear(hidden_size, hidden_size) for _ in range(num_layers)]
- self.output = nn.Linear(hidden_size, vocab_size)
def __call__(self, x):
h = self.embedding(x)
for layer in self.layers:
- h = mx.maximum(layer(h), 0) # ReLU
- return self.output(h)
+ h = mx.maximum(layer(h), 0)
+ return h
-class MockTokenizer:
- """Mock tokenizer for testing."""
+class SmallLanguageModel(nn.Module):
+ def __init__(self, vocab_size: int = 64, hidden_size: int = 32, num_layers: int = 2):
+ super().__init__()
+ self.model = SmallBackbone(vocab_size=vocab_size, hidden_size=hidden_size, num_layers=num_layers)
+ self.output = nn.Linear(hidden_size, vocab_size)
- def __init__(self, vocab_size: int = 100):
+ def __call__(self, x):
+ return self.output(self.model(x))
+
+
+class MockTokenizer:
+ def __init__(self, vocab_size: int = 64):
self.vocab_size = vocab_size
self.pad_token_id = 0
self.eos_token_id = 1
self.bos_token_id = 2
- self.pad_token = ""
- self.eos_token = ""
- self.name_or_path = "mock-tokenizer"
- def encode(self, text: str, add_special_tokens: bool = True) -> List[int]:
- """Simple encoding: hash characters to vocab indices."""
- ids = [hash(c) % (self.vocab_size - 3) + 3 for c in text[:50]]
+ def encode(self, text: str, add_special_tokens: bool = True):
+ ids = [((ord(char) % (self.vocab_size - 3)) + 3) for char in text[:32]]
if add_special_tokens:
ids = [self.bos_token_id] + ids + [self.eos_token_id]
return ids
- def decode(self, ids: List[int], skip_special_tokens: bool = True) -> str:
- """Simple decoding."""
+ def decode(self, ids, skip_special_tokens: bool = True):
if skip_special_tokens:
- ids = [i for i in ids if i not in (self.pad_token_id, self.eos_token_id, self.bos_token_id)]
- return "".join(chr(65 + (i % 26)) for i in ids)
-
- def __call__(self, text, return_tensors=None, padding=True, truncation=True, max_length=512):
- """Tokenize text."""
- if isinstance(text, str):
- ids = self.encode(text)
- else:
- ids = [self.encode(t) for t in text]
-
- if return_tensors == "mlx":
- return {"input_ids": mx.array(ids)}
- return {"input_ids": ids}
+ ids = [token for token in ids if token not in (self.pad_token_id, self.eos_token_id, self.bos_token_id)]
+ return "".join(chr(65 + (token % 26)) for token in ids)
class MockModelWrapper:
- """Wrapper to match FastLanguageModel interface."""
-
def __init__(self, model: SmallLanguageModel):
self.model = model
self._lora_applied = False
- self._lora_config = None
+ self.lora_config = None
+ self._adapter_path = None
def __call__(self, x):
return self.model(x)
def _apply_lora(self):
- """Mock LoRA application."""
self._lora_applied = True
- print(" [Mock] LoRA applied")
-
- def parameters(self):
- return self.model.parameters()
-
- def freeze(self):
- pass
-
- def unfreeze(self):
- pass
+ return True
+
+ def set_adapter_path(self, path: str):
+ self._adapter_path = path
+
+
+def write_legacy_rl_checkpoint(trainer, checkpoint_dir: Path):
+ checkpoint_dir.mkdir(parents=True, exist_ok=True)
+ adapter_dir = checkpoint_dir / "adapters"
+ adapter_dir.mkdir(parents=True, exist_ok=True)
+ adapter_weights = dict(tree_flatten(trainer.model.model.trainable_parameters()))
+ mx.save_safetensors(str(adapter_dir / "adapters.safetensors"), adapter_weights)
+
+ state_arrays = {
+ f"optimizer.{key}": value
+ for key, value in tree_flatten(trainer.optimizer.state)
+ }
+ state_arrays.update({f"rng.{idx}": state for idx, state in enumerate(mx.random.state)})
+ if hasattr(trainer, "_extra_state_arrays"):
+ state_arrays.update(trainer._extra_state_arrays())
+ mx.save_safetensors(str(checkpoint_dir / "trainer_state.safetensors"), state_arrays)
+
+ metadata = {
+ "algorithm": trainer.algorithm,
+ "config": trainer.config.to_dict() if hasattr(trainer.config, "to_dict") else dict(trainer.config),
+ "global_step": trainer.global_step,
+ "dataset_cursor": trainer.dataset_cursor,
+ "cache_metadata": trainer.cache_metadata,
+ }
+ (checkpoint_dir / "trainer_state.json").write_text(json.dumps(metadata, indent=2))
+
+ if trainer.reference_policy is not None:
+ reference_weights = dict(tree_flatten(trainer.reference_policy.model.model.parameters()))
+ mx.save_safetensors(str(checkpoint_dir / "reference_model.safetensors"), reference_weights)
+ (checkpoint_dir / "reference_metadata.json").write_text(
+ json.dumps(
+ {
+ "source": trainer.reference_policy.source,
+ "metadata": trainer.reference_policy.metadata,
+ },
+ indent=2,
+ )
+ )
-@pytest.fixture
-def small_model():
- """Create a small model for testing."""
- model = SmallLanguageModel(vocab_size=100, hidden_size=64, num_layers=2)
+def make_model(seed: int) -> MockModelWrapper:
+ mx.random.seed(seed)
+ model = SmallLanguageModel()
mx.eval(model.parameters())
return MockModelWrapper(model)
+def parameter_snapshot(model_wrapper: MockModelWrapper):
+ return {name: mx.array(value) for name, value in tree_flatten(model_wrapper.model.parameters())}
+
+
+def parameters_changed(before, after_wrapper: MockModelWrapper) -> bool:
+ after = {name: value for name, value in tree_flatten(after_wrapper.model.parameters())}
+ for name, before_value in before.items():
+ delta = mx.max(mx.abs(after[name] - before_value)).item()
+ if delta > 1e-6:
+ return True
+ return False
+
+
+def rollout_sequence_tensors(rollout_batch, index: int):
+ sequence = rollout_batch.prompt_ids[index] + rollout_batch.completion_ids[index]
+ return (
+ mx.array([sequence]),
+ mx.array([len(rollout_batch.prompt_ids[index])]),
+ mx.array([int(rollout_batch.completion_lengths[index].item())]),
+ )
+
+
+def test_format_metric_summary_includes_rollout_length_and_stop_stats():
+ from mlx_tune import GRPOTrainer
+
+ trainer = GRPOTrainer.__new__(GRPOTrainer)
+ row = {
+ "step": 3,
+ "train/policy_loss": 0.25,
+ "train/reward_mean": 1.0,
+ "train/logprob_delta_per_token_mean": 0.0125,
+ "train/logprob_delta_mean": 0.75,
+ "train/completion_length_mean": 12.0,
+ "train/completion_length_max": 32.0,
+ "train/eos_rate": 0.75,
+ "train/truncation_rate": 0.25,
+ "train/kl_to_reference_mean": 0.1,
+ "train/rollout_generate_wall": 4.5,
+ "train/reward_eval_wall": 0.2,
+ "train/reference_score_wall": 11.0,
+ "train/returns_wall": 0.01,
+ "train/policy_update_wall": 24.0,
+ "train/policy_update_steps": 3.0,
+ }
+
+ summary = trainer._format_metric_summary(row)
+
+ assert summary == (
+ "step=3 | policy_loss=0.2500 | reward_mean=1.0000 | "
+ "logprob_delta_per_token_mean=0.0125 | logprob_delta_mean=0.7500 | "
+ "completion_length_mean=12.0000 | completion_length_max=32.0000 | "
+ "eos_rate=0.7500 | truncation_rate=0.2500 | kl_to_reference_mean=0.1000 | "
+ "rollout_generate_wall=4.5000 | reward_eval_wall=0.2000 | "
+ "reference_score_wall=11.0000 | returns_wall=0.0100 | "
+ "policy_update_wall=24.0000 | policy_update_steps=3.0000"
+ )
+
+
@pytest.fixture
-def mock_tokenizer():
- """Create a mock tokenizer."""
- return MockTokenizer(vocab_size=100)
+def tokenizer():
+ return MockTokenizer()
@pytest.fixture
def preference_dataset():
- """Sample preference dataset for DPO/ORPO/SimPO."""
return [
{
"prompt": "What is machine learning?",
- "chosen": "Machine learning is a branch of AI that enables systems to learn from data.",
- "rejected": "idk its computers doing stuff"
+ "chosen": "Machine learning is learning from data.",
+ "rejected": "Computers do stuff.",
},
{
"prompt": "Explain Python.",
- "chosen": "Python is a high-level programming language known for readability.",
- "rejected": "python is a snake"
+ "chosen": "Python is a high-level programming language.",
+ "rejected": "Python is only a snake.",
},
{
"prompt": "What is deep learning?",
- "chosen": "Deep learning uses neural networks with many layers to learn patterns.",
- "rejected": "its like machine learning but deeper i guess"
+ "chosen": "Deep learning uses many neural-network layers.",
+ "rejected": "It is just regular learning.",
},
]
@pytest.fixture
def kto_dataset():
- """Sample dataset for KTO (binary feedback)."""
return [
- {"text": "Machine learning is a branch of AI.", "label": 1}, # Good
- {"text": "idk computers stuff", "label": 0}, # Bad
- {"text": "Python is a programming language.", "label": 1}, # Good
- {"text": "python snake", "label": 0}, # Bad
+ {"text": "Machine learning uses data.", "label": 1},
+ {"text": "Computers maybe stuff.", "label": 0},
+ {"text": "Python is a programming language.", "label": 1},
+ {"text": "Snake only.", "label": 0},
]
@pytest.fixture
def grpo_dataset():
- """Sample dataset for GRPO (reasoning with answers)."""
return [
{"prompt": "What is 2 + 2?", "answer": "4"},
{"prompt": "What is 5 * 3?", "answer": "15"},
@@ -155,681 +220,1562 @@ def grpo_dataset():
]
-# =============================================================================
-# DPO TRAINER INTEGRATION TESTS
-# =============================================================================
-
@pytest.mark.integration
-class TestDPOTrainerIntegration:
- """Integration tests for DPOTrainer."""
-
- def test_dpo_trainer_init(self, small_model, mock_tokenizer, preference_dataset):
- """Test DPOTrainer can be initialized."""
- from mlx_tune import DPOTrainer, DPOConfig
-
- config = DPOConfig(
- beta=0.1,
- learning_rate=1e-4,
- max_steps=2,
- output_dir="./test_dpo_output",
+class TestRewardAndPPOIntegration:
+ def test_reward_trainer_and_scoring_helper_save_manifest_checkpoint(
+ self,
+ tmp_path,
+ tokenizer,
+ ):
+ from mlx_tune import RewardConfig, RewardTrainer, build_reward_model, score_reward_model
+
+ reward_model = build_reward_model(make_model(50))
+ trainer = RewardTrainer(
+ model=reward_model,
+ train_dataset=[
+ {"prompt": "Q:", "response": "good", "score": 1.0},
+ {"prompt": "Q:", "response": "bad", "score": 0.0},
+ ],
+ tokenizer=tokenizer,
+ args=RewardConfig(
+ learning_rate=1e-2,
+ max_steps=1,
+ logging_steps=1,
+ save_steps=1,
+ output_dir=str(tmp_path / "reward"),
+ ),
)
- trainer = DPOTrainer(
- model=small_model,
- train_dataset=preference_dataset,
- tokenizer=mock_tokenizer,
- args=config,
+ result = trainer.train()
+ scores = score_reward_model(
+ reward_model,
+ [{"prompt": "Q:", "response": "good"}],
+ batch_size=1,
+ tokenizer=tokenizer,
)
- assert trainer is not None
- assert trainer.beta == 0.1
- assert len(trainer.train_dataset) == 3
+ assert result["status"] == "success"
+ assert len(scores) == 1
+ assert (tmp_path / "reward" / "manifest.json").exists()
+ assert (tmp_path / "reward" / "reward_model" / "head.safetensors").exists()
+
+ def test_ppo_trainer_persists_policy_and_value_optimizers(
+ self,
+ tmp_path,
+ tokenizer,
+ grpo_dataset,
+ ):
+ from mlx_tune import PPOConfig, PPOTrainer, build_value_model
+
+ trainer = PPOTrainer(
+ model=make_model(51),
+ train_dataset=grpo_dataset,
+ tokenizer=tokenizer,
+ reward_fn=lambda response, context: float(len(response)),
+ value_model=build_value_model(make_model(52)),
+ args=PPOConfig(
+ learning_rate=1e-2,
+ value_learning_rate=2e-2,
+ max_steps=1,
+ logging_steps=1,
+ save_steps=1,
+ num_generations=2,
+ max_completion_length=4,
+ output_dir=str(tmp_path / "ppo"),
+ ),
+ )
- def test_dpo_trainer_train_runs(self, small_model, mock_tokenizer, preference_dataset):
- """Test DPOTrainer.train() executes without errors."""
- from mlx_tune import DPOTrainer, DPOConfig
+ result = trainer.train()
- config = DPOConfig(
- beta=0.1,
- learning_rate=1e-4,
- max_steps=2, # Very short for testing
- output_dir="./test_dpo_output",
+ assert result["status"] == "success"
+ assert trainer._last_rollout_batch is not None
+ assert trainer._last_rollout_batch.value_predictions is not None
+ assert trainer._last_rollout_batch.returns is not None
+ assert (tmp_path / "ppo" / "optimizers" / "policy" / "state.safetensors").exists()
+ assert (tmp_path / "ppo" / "optimizers" / "value" / "state.safetensors").exists()
+
+ def test_ppo_trainer_uses_requested_advantage_estimator(
+ self,
+ tmp_path,
+ tokenizer,
+ grpo_dataset,
+ ):
+ from mlx_tune import PPOConfig, PPOTrainer, build_value_model
+ from mlx_tune._rl_runtime import compute_returns_and_advantages
+
+ trainer = PPOTrainer(
+ model=make_model(60),
+ ref_model=make_model(61),
+ train_dataset=grpo_dataset,
+ tokenizer=tokenizer,
+ reward_fn=lambda response, context: float(len(response)),
+ value_model=build_value_model(make_model(62)),
+ args=PPOConfig(
+ advantage_estimator="rloo",
+ normalize_advantages=False,
+ temperature=0.0,
+ max_steps=1,
+ logging_steps=1,
+ save_steps=1,
+ num_generations=2,
+ max_completion_length=4,
+ output_dir=str(tmp_path / "ppo_rloo"),
+ ),
)
- trainer = DPOTrainer(
- model=small_model,
- train_dataset=preference_dataset,
- tokenizer=mock_tokenizer,
- args=config,
+ trainer._ensure_reference_policy()
+ trainer._prepare_prompt_samples()
+ rollout_batch = trainer._collect_rollout_batch(trainer._next_prompt_batch())
+
+ _, expected_advantages = compute_returns_and_advantages(
+ rewards=rollout_batch.rewards,
+ prompt_group_indices=rollout_batch.prompt_group_indices,
+ mode="rloo",
+ gamma=trainer.gamma,
+ gae_lambda=trainer.gae_lambda,
+ normalize=False,
)
- # This should run without raising exceptions
- result = trainer.train()
-
- assert result is not None
- # Verify training completed - check model's _lora_applied flag
- assert small_model._lora_applied, "LoRA should have been applied during training"
-
- def test_dpo_loss_decreases(self, small_model, mock_tokenizer, preference_dataset):
- """Test that DPO loss decreases or stays stable during training."""
- from mlx_tune import DPOTrainer, DPOConfig
+ assert mx.allclose(rollout_batch.advantages, expected_advantages)
+
+ def test_on_policy_trainers_activate_kl_controls(
+ self,
+ tmp_path,
+ tokenizer,
+ grpo_dataset,
+ ):
+ from mlx_tune import (
+ GRPOConfig,
+ GRPOTrainer,
+ OnlineDPOConfig,
+ OnlineDPOTrainer,
+ PPOConfig,
+ PPOTrainer,
+ build_value_model,
+ )
- config = DPOConfig(
- beta=0.1,
- learning_rate=1e-3, # Higher LR for visible change
- max_steps=5,
- output_dir="./test_dpo_output",
+ grpo = GRPOTrainer(
+ model=make_model(70),
+ ref_model=make_model(71),
+ train_dataset=grpo_dataset,
+ tokenizer=tokenizer,
+ reward_fn=lambda response, context: float(len(response)),
+ args=GRPOConfig(
+ beta=0.3,
+ kl_target=1e-4,
+ max_steps=1,
+ logging_steps=1,
+ save_steps=1,
+ num_generations=2,
+ max_completion_length=3,
+ output_dir=str(tmp_path / "grpo_kl"),
+ ),
+ )
+ grpo._ensure_reference_policy()
+ grpo._prepare_prompt_samples()
+ grpo_rollout = grpo._collect_rollout_batch(grpo._next_prompt_batch())
+ assert grpo._effective_kl_beta(grpo_rollout) != grpo.beta
+
+ ppo = PPOTrainer(
+ model=make_model(72),
+ ref_model=make_model(73),
+ train_dataset=grpo_dataset,
+ tokenizer=tokenizer,
+ reward_fn=lambda response, context: float(len(response)),
+ value_model=build_value_model(make_model(74)),
+ args=PPOConfig(
+ beta=0.4,
+ kl_penalty_mode="none",
+ max_steps=1,
+ logging_steps=1,
+ save_steps=1,
+ num_generations=2,
+ max_completion_length=3,
+ output_dir=str(tmp_path / "ppo_kl_none"),
+ ),
+ )
+ ppo._ensure_reference_policy()
+ ppo._prepare_prompt_samples()
+ ppo_rollout = ppo._collect_rollout_batch(ppo._next_prompt_batch())
+ assert ppo._effective_kl_beta(ppo_rollout) == 0.0
+
+ online_dpo = OnlineDPOTrainer(
+ model=make_model(75),
+ ref_model=make_model(76),
+ train_dataset=grpo_dataset,
+ tokenizer=tokenizer,
+ reward_fn=lambda response, context: float(len(response)),
+ args=OnlineDPOConfig(
+ beta=0.2,
+ kl_target=1e-4,
+ max_steps=1,
+ logging_steps=1,
+ save_steps=1,
+ num_generations=2,
+ max_completion_length=3,
+ output_dir=str(tmp_path / "online_dpo_kl"),
+ ),
)
+ online_dpo._ensure_reference_policy()
+ online_dpo._prepare_prompt_samples()
+ online_rollout = online_dpo._collect_rollout_batch(online_dpo._next_prompt_batch())
+ assert online_dpo._effective_kl_beta(online_rollout) != online_dpo.beta
+
+@pytest.mark.integration
+class TestDPOTrainerIntegration:
+ def test_dpo_frozen_reference_stays_unchanged_across_policy_updates(
+ self,
+ tmp_path,
+ tokenizer,
+ preference_dataset,
+ ):
+ from mlx_tune import DPOConfig, DPOTrainer, compute_reference_logprobs
+
+ model = make_model(0)
trainer = DPOTrainer(
- model=small_model,
+ model=model,
train_dataset=preference_dataset,
- tokenizer=mock_tokenizer,
- args=config,
+ tokenizer=tokenizer,
+ args=DPOConfig(
+ learning_rate=5e-2,
+ max_steps=3,
+ logging_steps=1,
+ save_steps=1,
+ output_dir=str(tmp_path),
+ ),
)
result = trainer.train()
+ assert result["global_step"] == 3
+ assert trainer.reference_policy is not None
+
+ pad_id = tokenizer.pad_token_id
+ sample = trainer.train_samples[0]
+ chosen = mx.array([sample["chosen_ids"]])
+ rejected = mx.array([sample["rejected_ids"]])
+ chosen_lengths = mx.array([sample["chosen_length"]])
+ rejected_lengths = mx.array([sample["rejected_length"]])
+
+ ref_chosen, ref_rejected = compute_reference_logprobs(
+ trainer.reference_policy.model.model,
+ chosen,
+ rejected,
+ chosen_lengths,
+ rejected_lengths,
+ )
- # Check no NaN in result
- if isinstance(result, dict) and 'final_loss' in result:
- assert not mx.isnan(mx.array(result['final_loss'])), "Final loss should not be NaN"
-
- def test_dpo_different_loss_types(self, small_model, mock_tokenizer, preference_dataset):
- """Test DPO with different loss types."""
- from mlx_tune import DPOTrainer, DPOConfig
-
- loss_types = ["sigmoid", "hinge", "ipo"]
-
- for loss_type in loss_types:
- # Create fresh model for each test
- model = MockModelWrapper(SmallLanguageModel())
- mx.eval(model.model.parameters())
-
- config = DPOConfig(
- beta=0.1,
- loss_type=loss_type,
- learning_rate=1e-4,
- max_steps=2,
- output_dir="./test_dpo_output",
- )
-
- trainer = DPOTrainer(
- model=model,
- train_dataset=preference_dataset,
- tokenizer=mock_tokenizer,
- args=config,
- )
-
- # Should not raise
- result = trainer.train()
- assert result is not None, f"DPO with loss_type={loss_type} failed"
-
-
-# =============================================================================
-# ORPO TRAINER INTEGRATION TESTS
-# =============================================================================
-
-@pytest.mark.integration
-class TestORPOTrainerIntegration:
- """Integration tests for ORPOTrainer."""
-
- def test_orpo_trainer_init(self, small_model, mock_tokenizer, preference_dataset):
- """Test ORPOTrainer can be initialized."""
- from mlx_tune import ORPOTrainer, ORPOConfig
+ assert mx.allclose(ref_chosen, mx.array([sample["reference_chosen_logprobs"]]))
+ assert mx.allclose(ref_rejected, mx.array([sample["reference_rejected_logprobs"]]))
+ assert model._lora_applied
+
+ def test_dpo_loss_changes_when_reference_cache_changes(
+ self,
+ tokenizer,
+ preference_dataset,
+ ):
+ from mlx_tune import dpo_loss, precompute_preference_reference_logprobs
+
+ policy = make_model(1)
+ reference_a = make_model(2)
+ reference_b = make_model(3)
+
+ prompt = preference_dataset[0]["prompt"]
+ chosen = tokenizer.encode(prompt + preference_dataset[0]["chosen"])
+ rejected = tokenizer.encode(prompt + preference_dataset[0]["rejected"])
+
+ chosen_ids = mx.array([chosen])
+ rejected_ids = mx.array([rejected])
+ chosen_lengths = mx.array([len(chosen)])
+ rejected_lengths = mx.array([len(rejected)])
+
+ ref_a = precompute_preference_reference_logprobs(
+ reference_a.model,
+ chosen_ids,
+ rejected_ids,
+ chosen_lengths,
+ rejected_lengths,
+ )
+ ref_b = precompute_preference_reference_logprobs(
+ reference_b.model,
+ chosen_ids,
+ rejected_ids,
+ chosen_lengths,
+ rejected_lengths,
+ )
- config = ORPOConfig(
+ loss_a, _ = dpo_loss(
+ policy.model,
+ chosen_ids,
+ rejected_ids,
+ chosen_lengths,
+ rejected_lengths,
beta=0.1,
- learning_rate=1e-4,
- max_steps=2,
- output_dir="./test_orpo_output",
+ reference_chosen_logprobs=ref_a[0],
+ reference_rejected_logprobs=ref_a[1],
)
-
- trainer = ORPOTrainer(
- model=small_model,
- train_dataset=preference_dataset,
- tokenizer=mock_tokenizer,
- args=config,
+ loss_b, _ = dpo_loss(
+ policy.model,
+ chosen_ids,
+ rejected_ids,
+ chosen_lengths,
+ rejected_lengths,
+ beta=0.1,
+ reference_chosen_logprobs=ref_b[0],
+ reference_rejected_logprobs=ref_b[1],
)
- assert trainer is not None
- assert trainer.beta == 0.1
+ assert abs(loss_a.item() - loss_b.item()) > 1e-6
- def test_orpo_trainer_train_runs(self, small_model, mock_tokenizer, preference_dataset):
- """Test ORPOTrainer.train() executes without errors."""
- from mlx_tune import ORPOTrainer, ORPOConfig
+ def test_dpo_resume_restores_state_and_cache(
+ self,
+ tmp_path,
+ tokenizer,
+ preference_dataset,
+ ):
+ from mlx_tune import DPOConfig, DPOTrainer
- config = ORPOConfig(
- beta=0.1,
- learning_rate=1e-4,
- max_steps=2,
- output_dir="./test_orpo_output",
+ output_dir = tmp_path / "dpo_resume"
+ trainer = DPOTrainer(
+ model=make_model(4),
+ train_dataset=preference_dataset,
+ tokenizer=tokenizer,
+ args=DPOConfig(
+ learning_rate=1e-2,
+ max_steps=2,
+ logging_steps=1,
+ save_steps=1,
+ output_dir=str(output_dir),
+ ),
)
+ trainer.train()
- trainer = ORPOTrainer(
- model=small_model,
+ resumed = DPOTrainer(
+ model=make_model(5),
train_dataset=preference_dataset,
- tokenizer=mock_tokenizer,
- args=config,
+ tokenizer=tokenizer,
+ args=DPOConfig(
+ learning_rate=1e-2,
+ max_steps=4,
+ logging_steps=1,
+ save_steps=1,
+ output_dir=str(output_dir),
+ ),
+ )
+ result = resumed.train(resume_from_checkpoint=str(output_dir))
+
+ assert result["global_step"] == 4
+ assert resumed.cache_metadata == trainer.cache_metadata
+ assert resumed.optimizer is not None
+ assert resumed.optimizer.state["step"].item() == 4
+ assert mx.allclose(
+ mx.array([sample["reference_chosen_logprobs"] for sample in resumed.train_samples]),
+ mx.array([sample["reference_chosen_logprobs"] for sample in trainer.train_samples]),
)
- result = trainer.train()
- assert result is not None
-
- def test_orpo_combines_sft_and_preference(self, small_model, mock_tokenizer, preference_dataset):
- """Test that ORPO combines SFT and preference learning."""
- from mlx_tune import ORPOTrainer, ORPOConfig
- config = ORPOConfig(
- beta=0.1,
- learning_rate=1e-3,
- max_steps=3,
- output_dir="./test_orpo_output",
- )
+@pytest.mark.integration
+class TestGRPOTrainerIntegration:
+ def test_grpo_training_changes_parameters_and_increases_rewarded_logprob(
+ self,
+ tmp_path,
+ tokenizer,
+ grpo_dataset,
+ ):
+ from mlx_tune import GRPOConfig, GRPOTrainer, compute_completion_log_probs
+
+ mx.random.seed(7)
+ model = make_model(6)
+ before = parameter_snapshot(model)
- trainer = ORPOTrainer(
- model=small_model,
- train_dataset=preference_dataset,
- tokenizer=mock_tokenizer,
- args=config,
+ trainer = GRPOTrainer(
+ model=model,
+ train_dataset=grpo_dataset,
+ tokenizer=tokenizer,
+ reward_fn=lambda response, context: float(len(response)),
+ args=GRPOConfig(
+ learning_rate=5e-2,
+ beta=0.01,
+ num_generations=3,
+ max_completion_length=4,
+ max_steps=1,
+ logging_steps=1,
+ save_steps=1,
+ output_dir=str(tmp_path),
+ ),
)
result = trainer.train()
+ rollout = trainer._last_rollout_batch
- # ORPO should complete successfully
- assert result is not None
+ assert result["global_step"] == 1
+ assert rollout is not None
+ assert parameters_changed(before, model)
+ best_index = int(mx.argmax(rollout.rewards).item())
+ input_ids, prompt_lengths, completion_lengths = rollout_sequence_tensors(rollout, best_index)
+ updated_logprob = compute_completion_log_probs(
+ model.model,
+ input_ids,
+ prompt_lengths,
+ completion_lengths,
+ )[0].item()
+ rollout_logprob = rollout.rollout_logprobs[best_index].item()
-# =============================================================================
-# GRPO TRAINER INTEGRATION TESTS (Most important for reasoning)
-# =============================================================================
+ assert updated_logprob > rollout_logprob
-@pytest.mark.integration
-class TestGRPOTrainerIntegration:
- """Integration tests for GRPOTrainer - DeepSeek R1 style reasoning."""
+ def test_grpo_prefers_answer_context_over_prompt(
+ self,
+ tmp_path,
+ tokenizer,
+ ):
+ from mlx_tune import GRPOConfig, GRPOTrainer
- def test_grpo_trainer_init(self, small_model, mock_tokenizer, grpo_dataset):
- """Test GRPOTrainer can be initialized."""
- from mlx_tune import GRPOTrainer, GRPOConfig, create_reward_function
+ seen_contexts = []
- reward_fn = create_reward_function("simple")
-
- config = GRPOConfig(
- beta=0.04,
- num_generations=2, # Small for testing
- learning_rate=1e-5,
- max_steps=2,
- output_dir="./test_grpo_output",
- )
+ def reward_fn(response: str, context: str) -> float:
+ seen_contexts.append(context)
+ return float(len(response))
trainer = GRPOTrainer(
- model=small_model,
- train_dataset=grpo_dataset,
- tokenizer=mock_tokenizer,
+ model=make_model(7),
+ train_dataset=[{"prompt": "Solve 2 + 2", "answer": "4"}],
+ tokenizer=tokenizer,
reward_fn=reward_fn,
- args=config,
+ args=GRPOConfig(
+ learning_rate=1e-2,
+ num_generations=2,
+ max_completion_length=3,
+ max_steps=1,
+ logging_steps=1,
+ save_steps=1,
+ output_dir=str(tmp_path),
+ ),
)
+ trainer.train()
- assert trainer is not None
- assert trainer.num_generations == 2
- assert trainer.reward_fn is not None
+ assert seen_contexts
+ assert all(context == "4" for context in seen_contexts)
- def test_grpo_trainer_train_runs(self, small_model, mock_tokenizer, grpo_dataset):
- """Test GRPOTrainer.train() executes without errors."""
- from mlx_tune import GRPOTrainer, GRPOConfig, create_reward_function
-
- reward_fn = create_reward_function("simple")
+ def test_grpo_phase1_accepts_documented_loss_type_aliases(
+ self,
+ tmp_path,
+ tokenizer,
+ grpo_dataset,
+ ):
+ from mlx_tune import GRPOConfig, GRPOTrainer
+ for loss_type in ["grpo", "dr_grpo", "dapo", "bnpo"]:
+ trainer = GRPOTrainer(
+ model=make_model(20),
+ train_dataset=grpo_dataset,
+ tokenizer=tokenizer,
+ reward_fn=lambda response, context: float(len(response)),
+ args=GRPOConfig(
+ loss_type=loss_type,
+ learning_rate=1e-2,
+ beta=0.01,
+ num_generations=2,
+ max_completion_length=3,
+ max_steps=1,
+ logging_steps=1,
+ save_steps=1,
+ output_dir=str(tmp_path / loss_type),
+ ),
+ )
+ result = trainer.train()
+ assert result["status"] == "success"
+ assert trainer.phase1_loss_type == "phase1_shared_rollout_recompute"
+
+ def test_grpo_resume_restores_rng_and_optimizer_state(
+ self,
+ tmp_path,
+ tokenizer,
+ grpo_dataset,
+ ):
+ from mlx_tune import GRPOConfig, GRPOTrainer
+
+ output_dir = tmp_path / "grpo_resume"
config = GRPOConfig(
- beta=0.04,
+ learning_rate=1e-2,
+ beta=0.01,
num_generations=2,
- learning_rate=1e-5,
- max_steps=2,
- temperature=0.7,
- output_dir="./test_grpo_output",
+ max_completion_length=4,
+ max_steps=1,
+ logging_steps=1,
+ save_steps=1,
+ output_dir=str(output_dir),
)
+ mx.random.seed(11)
trainer = GRPOTrainer(
- model=small_model,
+ model=make_model(8),
train_dataset=grpo_dataset,
- tokenizer=mock_tokenizer,
- reward_fn=reward_fn,
+ tokenizer=tokenizer,
+ reward_fn=lambda response, context: float(len(response)),
args=config,
)
+ trainer.train()
- result = trainer.train()
- assert result is not None
-
- def test_grpo_multi_generation(self, small_model, mock_tokenizer, grpo_dataset):
- """Test that GRPO generates multiple completions per prompt."""
- from mlx_tune import GRPOTrainer, GRPOConfig, create_reward_function
-
- reward_fn = create_reward_function("simple")
- num_gens = 3
-
- config = GRPOConfig(
- beta=0.04,
- num_generations=num_gens,
- learning_rate=1e-5,
- max_steps=1, # Just one step to verify multi-gen
- output_dir="./test_grpo_output",
+ def load_and_rollout(seed: int):
+ restored = GRPOTrainer(
+ model=make_model(seed),
+ train_dataset=grpo_dataset,
+ tokenizer=tokenizer,
+ reward_fn=lambda response, context: float(len(response)),
+ args=GRPOConfig(
+ learning_rate=1e-2,
+ beta=0.01,
+ num_generations=2,
+ max_completion_length=4,
+ max_steps=2,
+ logging_steps=1,
+ save_steps=1,
+ output_dir=str(output_dir),
+ ),
+ )
+ restored._apply_lora_if_needed()
+ restored._prepare_prompt_samples()
+ optimizer = restored._optimizer_for_training()
+ restored.optimizer = optimizer
+ restored.load_state(optimizer, Path(output_dir))
+ rollout = restored._collect_rollout_batch(restored._next_samples(restored.prompt_samples))
+ return restored, rollout
+
+ restored_a, rollout_a = load_and_rollout(9)
+ restored_b, rollout_b = load_and_rollout(9)
+
+ assert restored_a.optimizer.state["step"].item() == 1
+ assert rollout_a.completion_ids == rollout_b.completion_ids
+ assert mx.allclose(rollout_a.rollout_logprobs, rollout_b.rollout_logprobs)
+
+ resumed = GRPOTrainer(
+ model=make_model(12),
+ train_dataset=grpo_dataset,
+ tokenizer=tokenizer,
+ reward_fn=lambda response, context: float(len(response)),
+ args=GRPOConfig(
+ learning_rate=1e-2,
+ beta=0.01,
+ num_generations=2,
+ max_completion_length=4,
+ max_steps=2,
+ logging_steps=1,
+ save_steps=1,
+ output_dir=str(output_dir),
+ ),
+ )
+ result = resumed.train(resume_from_checkpoint=str(output_dir))
+ assert result["global_step"] == 2
+
+ def test_grpo_prefers_learned_reward_model_over_reward_fn(
+ self,
+ tmp_path,
+ tokenizer,
+ grpo_dataset,
+ ):
+ from mlx_tune import GRPOConfig, GRPOTrainer, build_reward_model
+
+ reward_model = build_reward_model(make_model(30))
+ reward_model.head.update(
+ {
+ "weight": mx.zeros_like(reward_model.head.weight),
+ "bias": mx.array([1.0], dtype=mx.float32),
+ },
+ strict=False,
)
+ mx.eval(reward_model.head.parameters())
trainer = GRPOTrainer(
- model=small_model,
+ model=make_model(31),
train_dataset=grpo_dataset,
- tokenizer=mock_tokenizer,
- reward_fn=reward_fn,
- args=config,
+ tokenizer=tokenizer,
+ reward_model=reward_model,
+ reward_fn=lambda response, context: (_ for _ in ()).throw(RuntimeError("reward_fn should not run")),
+ args=GRPOConfig(
+ learning_rate=1e-2,
+ beta=0.01,
+ num_generations=2,
+ max_completion_length=4,
+ max_steps=1,
+ logging_steps=1,
+ save_steps=1,
+ output_dir=str(tmp_path),
+ ),
)
- # The trainer should be configured for multi-generation
- assert trainer.num_generations == num_gens
-
- # Training should work
result = trainer.train()
- assert result is not None
-
- def test_grpo_with_math_reward(self, small_model, mock_tokenizer, grpo_dataset):
- """Test GRPO with math reward function."""
- from mlx_tune import GRPOTrainer, GRPOConfig, create_reward_function
-
- math_reward = create_reward_function("math")
-
- config = GRPOConfig(
- beta=0.04,
- num_generations=2,
- learning_rate=1e-5,
- max_steps=2,
- output_dir="./test_grpo_output",
+ assert result["status"] == "success"
+ assert trainer._last_rollout_batch is not None
+ assert mx.allclose(
+ trainer._last_rollout_batch.rewards,
+ mx.ones_like(trainer._last_rollout_batch.rewards),
)
+ def test_grpo_offline_reward_source_uses_dataset_rewards_without_reward_evaluator(
+ self,
+ tmp_path,
+ tokenizer,
+ ):
+ from mlx_tune import GRPOConfig, GRPOTrainer
+
trainer = GRPOTrainer(
- model=small_model,
- train_dataset=grpo_dataset,
- tokenizer=mock_tokenizer,
- reward_fn=math_reward,
- args=config,
+ model=make_model(63),
+ train_dataset=[
+ {"prompt": "Solve 2 + 2", "completion": "4", "reward": 1.0},
+ {"prompt": "Solve 2 + 2", "completion": "5", "reward": -1.0},
+ ],
+ tokenizer=tokenizer,
+ reward_fn=lambda *_: (_ for _ in ()).throw(RuntimeError("reward_fn should not run")),
+ args=GRPOConfig(
+ learning_rate=1e-2,
+ beta=0.01,
+ num_generations=2,
+ reward_source="offline",
+ max_steps=1,
+ logging_steps=1,
+ save_steps=1,
+ output_dir=str(tmp_path / "grpo_offline"),
+ ),
)
result = trainer.train()
- assert result is not None
- def test_grpo_different_loss_types(self, small_model, mock_tokenizer, grpo_dataset):
- """Test GRPO with different loss types (grpo, dr_grpo, dapo, bnpo)."""
- from mlx_tune import GRPOTrainer, GRPOConfig, create_reward_function
+ assert result["status"] == "success"
+ assert trainer._last_rollout_batch is not None
+ assert trainer._last_rollout_batch.rewards.tolist() == [1.0, -1.0]
- reward_fn = create_reward_function("simple")
- loss_types = ["grpo", "dr_grpo", "dapo", "bnpo"]
+ def test_grpo_single_reward_source_component_preserves_weight_semantics(
+ self,
+ tmp_path,
+ tokenizer,
+ ):
+ from mlx_tune import GRPOConfig, GRPOTrainer
- for loss_type in loss_types:
- model = MockModelWrapper(SmallLanguageModel())
- mx.eval(model.model.parameters())
-
- config = GRPOConfig(
- loss_type=loss_type,
- beta=0.04,
+ trainer = GRPOTrainer(
+ model=make_model(64),
+ train_dataset=[{"prompt": "Solve 2 + 2", "answer": "4"}],
+ tokenizer=tokenizer,
+ args=GRPOConfig(
+ learning_rate=1e-2,
+ beta=0.01,
num_generations=2,
- learning_rate=1e-5,
+ max_completion_length=3,
max_steps=1,
- output_dir="./test_grpo_output",
- )
+ logging_steps=1,
+ save_steps=1,
+ reward_sources=[
+ {"name": "zeroed", "source": "length", "weight": 0.0},
+ ],
+ output_dir=str(tmp_path / "grpo_weighted_single_reward"),
+ ),
+ )
- trainer = GRPOTrainer(
- model=model,
- train_dataset=grpo_dataset,
- tokenizer=mock_tokenizer,
- reward_fn=reward_fn,
- args=config,
- )
+ result = trainer.train()
- result = trainer.train()
- assert result is not None, f"GRPO with loss_type={loss_type} failed"
+ assert result["status"] == "success"
+ assert trainer._last_rollout_batch is not None
+ assert mx.allclose(
+ trainer._last_rollout_batch.rewards,
+ mx.zeros_like(trainer._last_rollout_batch.rewards),
+ )
- def test_grpo_custom_reward_function(self, small_model, mock_tokenizer, grpo_dataset):
- """Test GRPO with a custom reward function."""
- from mlx_tune import GRPOTrainer, GRPOConfig
+ def test_grpo_prompt_rollouts_respect_rollout_batch_size(
+ self,
+ tmp_path,
+ tokenizer,
+ ):
+ from mlx_tune import GRPOConfig, GRPOTrainer
- # Custom reward: reward longer responses
- def length_reward(response: str, answer: str = None) -> float:
- return len(response) / 100.0 # Normalize
+ trainer = GRPOTrainer(
+ model=make_model(65),
+ train_dataset=[
+ {"prompt": "What is 1 + 1?", "answer": "2"},
+ {"prompt": "What is 2 + 2?", "answer": "4"},
+ {"prompt": "What is 3 + 3?", "answer": "6"},
+ ],
+ tokenizer=tokenizer,
+ reward_fn=lambda response, context: float(len(response)),
+ args=GRPOConfig(
+ learning_rate=1e-2,
+ beta=0.01,
+ per_device_train_batch_size=1,
+ rollout_batch_size=3,
+ num_generations=2,
+ max_completion_length=3,
+ max_steps=1,
+ logging_steps=1,
+ save_steps=1,
+ output_dir=str(tmp_path / "grpo_rollout_batch_size"),
+ ),
+ )
- config = GRPOConfig(
- beta=0.04,
- num_generations=2,
- learning_rate=1e-5,
- max_steps=2,
- output_dir="./test_grpo_output",
+ result = trainer.train()
+
+ assert result["status"] == "success"
+ assert trainer._last_rollout_batch is not None
+ assert len(trainer._last_rollout_batch.prompt_ids) == 6
+ assert sorted(set(trainer._last_rollout_batch.prompt_group_indices.tolist())) == [0, 1, 2]
+
+ def test_grpo_manifest_checkpoint_persists_roles_and_metrics(
+ self,
+ tmp_path,
+ tokenizer,
+ grpo_dataset,
+ ):
+ from mlx_tune import GRPOConfig, GRPOTrainer, build_reward_model, build_value_model
+
+ output_dir = tmp_path / "grpo_manifest"
+ reward_model = build_reward_model(make_model(33))
+ value_model = build_value_model(make_model(34))
+ score_sequence = mx.array([[2, 4, 7]], dtype=mx.int32)
+ score_sequence_lengths = mx.array([3], dtype=mx.int32)
+ score_prompt_lengths = mx.array([1], dtype=mx.int32)
+ score_completion_lengths = mx.array([2], dtype=mx.int32)
+ saved_reward_score = reward_model.score(
+ score_sequence,
+ sequence_lengths=score_sequence_lengths,
+ prompt_lengths=score_prompt_lengths,
+ completion_lengths=score_completion_lengths,
+ )
+ saved_value_score = value_model.predict(
+ score_sequence,
+ sequence_lengths=score_sequence_lengths,
+ prompt_lengths=score_prompt_lengths,
+ completion_lengths=score_completion_lengths,
)
trainer = GRPOTrainer(
- model=small_model,
+ model=make_model(32),
train_dataset=grpo_dataset,
- tokenizer=mock_tokenizer,
- reward_fn=length_reward,
- args=config,
+ tokenizer=tokenizer,
+ reward_model=reward_model,
+ value_model=value_model,
+ args=GRPOConfig(
+ learning_rate=1e-2,
+ beta=0.01,
+ num_generations=2,
+ max_completion_length=4,
+ max_steps=1,
+ logging_steps=1,
+ save_steps=1,
+ output_dir=str(output_dir),
+ ),
)
+ trainer.train()
+
+ assert (output_dir / "manifest.json").exists()
+ assert (output_dir / "policy" / "role.json").exists()
+ assert (output_dir / "reference" / "weights.safetensors").exists()
+ assert (output_dir / "reward_model" / "weights.safetensors").exists()
+ assert (output_dir / "reward_model" / "head.safetensors").exists()
+ assert (output_dir / "value_model" / "weights.safetensors").exists()
+ assert (output_dir / "value_model" / "head.safetensors").exists()
+ assert (output_dir / "optimizer" / "state.safetensors").exists()
+ assert (output_dir / "scheduler" / "state.json").exists()
+ assert (output_dir / "trainer" / "state.json").exists()
+ assert (output_dir / "trainer" / "rng.safetensors").exists()
+ assert (output_dir / "metrics" / "history.jsonl").exists()
+ assert trainer.metrics_history
+
+ resumed = GRPOTrainer(
+ model=make_model(35),
+ train_dataset=grpo_dataset,
+ tokenizer=tokenizer,
+ args=GRPOConfig(
+ learning_rate=1e-2,
+ beta=0.01,
+ num_generations=2,
+ max_completion_length=4,
+ max_steps=2,
+ logging_steps=1,
+ save_steps=1,
+ output_dir=str(output_dir),
+ ),
+ )
+ result = resumed.train(resume_from_checkpoint=str(output_dir))
+
+ assert result["global_step"] == 2
+ assert resumed.reward_model is not None
+ assert resumed.value_model is not None
+ assert mx.allclose(
+ resumed.reward_model.score(
+ score_sequence,
+ sequence_lengths=score_sequence_lengths,
+ prompt_lengths=score_prompt_lengths,
+ completion_lengths=score_completion_lengths,
+ ),
+ saved_reward_score,
+ )
+ assert mx.allclose(
+ resumed.value_model.predict(
+ score_sequence,
+ sequence_lengths=score_sequence_lengths,
+ prompt_lengths=score_prompt_lengths,
+ completion_lengths=score_completion_lengths,
+ ),
+ saved_value_score,
+ )
+ assert len(resumed.metrics_history) >= len(trainer.metrics_history)
+ assert resumed.loaded_checkpoint_manifest is not None
- result = trainer.train()
- assert result is not None
-
-
-# =============================================================================
-# KTO TRAINER INTEGRATION TESTS
-# =============================================================================
@pytest.mark.integration
class TestKTOTrainerIntegration:
- """Integration tests for KTOTrainer."""
-
- def test_kto_trainer_init(self, small_model, mock_tokenizer, kto_dataset):
- """Test KTOTrainer can be initialized."""
- from mlx_tune import KTOTrainer
+ def test_kto_uses_cached_reference_logprobs_instead_of_live_policy_outputs(
+ self,
+ tmp_path,
+ tokenizer,
+ kto_dataset,
+ ):
+ from mlx_tune import KTOTrainer, compute_log_probs_with_lengths, kto_loss
trainer = KTOTrainer(
- model=small_model,
+ model=make_model(13),
train_dataset=kto_dataset,
- tokenizer=mock_tokenizer,
- learning_rate=1e-4,
+ tokenizer=tokenizer,
+ learning_rate=2e-2,
max_steps=2,
+ logging_steps=1,
+ save_steps=1,
+ output_dir=str(tmp_path),
)
-
- assert trainer is not None
-
- def test_kto_trainer_train_runs(self, small_model, mock_tokenizer, kto_dataset):
- """Test KTOTrainer.train() executes without errors."""
+ trainer.train()
+
+ sample = trainer.train_samples[0]
+ batch = trainer._build_batch([sample])
+ cached_loss, _ = kto_loss(
+ trainer.model.model,
+ batch.input_ids,
+ batch.sequence_lengths,
+ batch.labels,
+ beta=trainer.beta,
+ reference_logprobs=batch.reference_logprobs,
+ )
+ live_policy_reference = compute_log_probs_with_lengths(
+ trainer.model.model,
+ batch.input_ids,
+ batch.sequence_lengths,
+ )
+ live_loss, _ = kto_loss(
+ trainer.model.model,
+ batch.input_ids,
+ batch.sequence_lengths,
+ batch.labels,
+ beta=trainer.beta,
+ reference_logprobs=live_policy_reference,
+ )
+ assert abs(cached_loss.item() - live_loss.item()) > 1e-6
+
+ def test_legacy_rl_checkpoint_loads_and_resaves_manifest_layout(
+ self,
+ tmp_path,
+ tokenizer,
+ kto_dataset,
+ ):
from mlx_tune import KTOTrainer
- trainer = KTOTrainer(
- model=small_model,
+ source_trainer = KTOTrainer(
+ model=make_model(40),
train_dataset=kto_dataset,
- tokenizer=mock_tokenizer,
- learning_rate=1e-4,
- max_steps=2,
+ tokenizer=tokenizer,
+ learning_rate=2e-2,
+ max_steps=1,
+ logging_steps=1,
+ save_steps=1,
+ output_dir=str(tmp_path / "source"),
)
+ source_trainer.train()
- result = trainer.train()
- assert result is not None
-
- def test_kto_binary_feedback(self, small_model, mock_tokenizer, kto_dataset):
- """Test KTO processes binary feedback correctly."""
- from mlx_tune import KTOTrainer
-
- # Ensure dataset has both positive and negative examples
- assert any(d['label'] == 1 for d in kto_dataset), "Need positive examples"
- assert any(d['label'] == 0 for d in kto_dataset), "Need negative examples"
+ legacy_dir = tmp_path / "legacy_kto"
+ write_legacy_rl_checkpoint(source_trainer, legacy_dir)
- trainer = KTOTrainer(
- model=small_model,
+ resumed = KTOTrainer(
+ model=make_model(41),
train_dataset=kto_dataset,
- tokenizer=mock_tokenizer,
- learning_rate=1e-4,
- max_steps=3,
+ tokenizer=tokenizer,
+ learning_rate=2e-2,
+ max_steps=2,
+ logging_steps=1,
+ save_steps=1,
+ output_dir=str(legacy_dir),
)
+ result = resumed.train(resume_from_checkpoint=str(legacy_dir))
- result = trainer.train()
- assert result is not None
+ assert result["global_step"] == 2
+ assert resumed.cache_metadata == source_trainer.cache_metadata
+ assert (legacy_dir / "manifest.json").exists()
+ assert (legacy_dir / "reference" / "weights.safetensors").exists()
+ assert (legacy_dir / "runtime" / "cache.safetensors").exists()
-# =============================================================================
-# SIMPO TRAINER INTEGRATION TESTS
-# =============================================================================
-
@pytest.mark.integration
-class TestSimPOTrainerIntegration:
- """Integration tests for SimPOTrainer."""
-
- def test_simpo_trainer_init(self, small_model, mock_tokenizer, preference_dataset):
- """Test SimPOTrainer can be initialized."""
- from mlx_tune import SimPOTrainer
-
- trainer = SimPOTrainer(
- model=small_model,
+class TestOtherRLTrainers:
+ def test_orpo_and_simpo_still_train(
+ self,
+ tmp_path,
+ tokenizer,
+ preference_dataset,
+ ):
+ from mlx_tune import ORPOConfig, ORPOTrainer, SimPOTrainer
+
+ orpo = ORPOTrainer(
+ model=make_model(14),
train_dataset=preference_dataset,
- tokenizer=mock_tokenizer,
- learning_rate=1e-4,
- max_steps=2,
+ tokenizer=tokenizer,
+ args=ORPOConfig(
+ learning_rate=1e-2,
+ max_steps=1,
+ output_dir=str(tmp_path / "orpo"),
+ ),
)
-
- assert trainer is not None
-
- def test_simpo_trainer_train_runs(self, small_model, mock_tokenizer, preference_dataset):
- """Test SimPOTrainer.train() executes without errors."""
- from mlx_tune import SimPOTrainer
-
- trainer = SimPOTrainer(
- model=small_model,
+ simpo = SimPOTrainer(
+ model=make_model(15),
train_dataset=preference_dataset,
- tokenizer=mock_tokenizer,
- learning_rate=1e-4,
- max_steps=2,
+ tokenizer=tokenizer,
+ learning_rate=1e-2,
+ max_steps=1,
+ output_dir=str(tmp_path / "simpo"),
)
- result = trainer.train()
- assert result is not None
-
- def test_simpo_no_reference_model(self, small_model, mock_tokenizer, preference_dataset):
- """Test SimPO works without reference model."""
- from mlx_tune import SimPOTrainer
+ assert orpo.train()["status"] == "success"
+ assert simpo.train()["status"] == "success"
+
+ def test_simpo_uses_configured_per_device_train_batch_size(
+ self,
+ tmp_path,
+ tokenizer,
+ preference_dataset,
+ monkeypatch,
+ ):
+ import mlx_tune.rl_trainers as rl_trainers_module
+ from mlx_tune import SimPOConfig, SimPOTrainer
+
+ observed_batch_sizes = []
+ original = rl_trainers_module.compute_simpo_loss
+
+ def wrapped(model, chosen_ids, rejected_ids, chosen_lengths, rejected_lengths, beta, gamma):
+ observed_batch_sizes.append(int(chosen_ids.shape[0]))
+ return original(
+ model,
+ chosen_ids,
+ rejected_ids,
+ chosen_lengths,
+ rejected_lengths,
+ beta,
+ gamma,
+ )
- # SimPO is special: it doesn't require a reference model
+ monkeypatch.setattr(rl_trainers_module, "compute_simpo_loss", wrapped)
trainer = SimPOTrainer(
- model=small_model,
+ model=make_model(16),
train_dataset=preference_dataset,
- tokenizer=mock_tokenizer,
- learning_rate=1e-4,
- max_steps=3,
+ tokenizer=tokenizer,
+ args=SimPOConfig(
+ learning_rate=1e-2,
+ per_device_train_batch_size=2,
+ max_steps=1,
+ output_dir=str(tmp_path / "simpo_batch"),
+ ),
)
result = trainer.train()
- assert result is not None
+ assert result["status"] == "success"
+ assert observed_batch_sizes == [2]
+
+ def test_online_dpo_uses_reward_model_preference_ordering(
+ self,
+ tmp_path,
+ tokenizer,
+ ):
+ from mlx_tune import OnlineDPOConfig, OnlineDPOTrainer, build_reward_model
+
+ reward_model = build_reward_model(make_model(60))
+ reward_model.head.update(
+ {
+ "weight": mx.zeros_like(reward_model.head.weight),
+ "bias": mx.array([1.0], dtype=mx.float32),
+ },
+ strict=False,
+ )
+ mx.eval(reward_model.head.parameters())
+
+ trainer = OnlineDPOTrainer(
+ model=make_model(61),
+ train_dataset=[{"prompt": "Solve 1 + 1", "reward_context": "2"}],
+ tokenizer=tokenizer,
+ reward_model=reward_model,
+ reward_fn=lambda response, context: (_ for _ in ()).throw(RuntimeError("reward_fn should not run")),
+ args=OnlineDPOConfig(
+ learning_rate=1e-2,
+ max_steps=1,
+ logging_steps=1,
+ save_steps=1,
+ num_generations=2,
+ max_completion_length=3,
+ output_dir=str(tmp_path / "online_dpo"),
+ ),
+ )
-# =============================================================================
-# GRADIENT FLOW TESTS
-# =============================================================================
-
-@pytest.mark.integration
-class TestGradientFlow:
- """Test that gradients flow correctly in RL trainers."""
-
- def _check_grads_nonzero(self, grads):
- """Recursively check if any gradients are non-zero."""
- if isinstance(grads, dict):
- for value in grads.values():
- if self._check_grads_nonzero(value):
- return True
- return False
- elif isinstance(grads, mx.array):
- return float(mx.sum(mx.abs(grads)).item()) > 0
- return False
-
- def test_dpo_gradient_not_zero(self):
- """Test DPO computes non-zero gradients."""
- from mlx_tune.losses import dpo_loss
-
- model = SmallLanguageModel(vocab_size=50, hidden_size=32)
- mx.eval(model.parameters())
-
- # Create sample data
- chosen = mx.array([[1, 2, 3, 4, 5]])
- rejected = mx.array([[1, 2, 6, 7, 8]])
- chosen_len = mx.array([5])
- rejected_len = mx.array([5])
-
- # Compute loss and gradients
- def loss_fn(model):
- loss, _ = dpo_loss(model, chosen, rejected, chosen_len, rejected_len, beta=0.1)
- return loss
-
- loss, grads = nn.value_and_grad(model, loss_fn)(model)
- mx.eval(loss, grads)
-
- # Check gradients are not all zero
- has_nonzero_grad = self._check_grads_nonzero(grads)
+ result = trainer.train()
- assert has_nonzero_grad, "DPO gradients should not all be zero"
- assert not mx.isnan(loss), "DPO loss should not be NaN"
+ assert result["status"] == "success"
+ assert trainer._last_rollout_batch is not None
+ assert mx.allclose(
+ trainer._last_rollout_batch.rewards,
+ mx.ones_like(trainer._last_rollout_batch.rewards),
+ )
- def test_orpo_gradient_not_zero(self):
- """Test ORPO computes non-zero gradients."""
- from mlx_tune.losses import orpo_loss
+ def test_online_dpo_prefers_answer_context_over_prompt(
+ self,
+ tmp_path,
+ tokenizer,
+ ):
+ from mlx_tune import OnlineDPOConfig, OnlineDPOTrainer
- model = SmallLanguageModel(vocab_size=50, hidden_size=32)
- mx.eval(model.parameters())
+ seen_contexts = []
- chosen = mx.array([[1, 2, 3, 4, 5]])
- rejected = mx.array([[1, 2, 6, 7, 8]])
- chosen_len = mx.array([5])
- rejected_len = mx.array([5])
+ def reward_fn(response: str, context: str) -> float:
+ seen_contexts.append(context)
+ return float(len(response))
- def loss_fn(model):
- loss, _ = orpo_loss(model, chosen, rejected, chosen_len, rejected_len, beta=0.1)
- return loss
+ trainer = OnlineDPOTrainer(
+ model=make_model(62),
+ train_dataset=[{"prompt": "Solve 2 + 2", "answer": "4"}],
+ tokenizer=tokenizer,
+ reward_fn=reward_fn,
+ args=OnlineDPOConfig(
+ learning_rate=1e-2,
+ max_steps=1,
+ logging_steps=1,
+ save_steps=1,
+ num_generations=2,
+ max_completion_length=3,
+ output_dir=str(tmp_path / "online_dpo_answer_context"),
+ ),
+ )
- loss, grads = nn.value_and_grad(model, loss_fn)(model)
- mx.eval(loss, grads)
+ result = trainer.train()
- has_nonzero_grad = self._check_grads_nonzero(grads)
+ assert result["status"] == "success"
+ assert seen_contexts
+ assert all(context == "4" for context in seen_contexts)
+
+ def test_online_dpo_prompt_rollouts_respect_rollout_batch_size(
+ self,
+ tmp_path,
+ tokenizer,
+ ):
+ from mlx_tune import OnlineDPOConfig, OnlineDPOTrainer
+
+ trainer = OnlineDPOTrainer(
+ model=make_model(66),
+ train_dataset=[
+ {"prompt": "Solve 1 + 1", "answer": "2"},
+ {"prompt": "Solve 2 + 2", "answer": "4"},
+ {"prompt": "Solve 3 + 3", "answer": "6"},
+ ],
+ tokenizer=tokenizer,
+ reward_fn=lambda response, context: float(len(response)),
+ args=OnlineDPOConfig(
+ learning_rate=1e-2,
+ per_device_train_batch_size=1,
+ rollout_batch_size=3,
+ max_steps=1,
+ logging_steps=1,
+ save_steps=1,
+ num_generations=2,
+ max_completion_length=3,
+ output_dir=str(tmp_path / "online_dpo_rollout_batch_size"),
+ ),
+ )
- assert has_nonzero_grad, "ORPO gradients should not all be zero"
- assert not mx.isnan(loss), "ORPO loss should not be NaN"
+ result = trainer.train()
+ assert result["status"] == "success"
+ assert trainer._last_rollout_batch is not None
+ assert len(trainer._last_rollout_batch.prompt_ids) == 6
+ assert sorted(set(trainer._last_rollout_batch.prompt_group_indices.tolist())) == [0, 1, 2]
-# =============================================================================
-# LOSS STABILITY TESTS
-# =============================================================================
@pytest.mark.integration
-class TestLossStability:
- """Test that losses remain stable (no NaN, Inf) during training."""
-
- def test_dpo_loss_stability(self):
- """Test DPO loss doesn't produce NaN or Inf."""
- from mlx_tune.losses import dpo_loss
-
- model = SmallLanguageModel(vocab_size=50, hidden_size=32)
- mx.eval(model.parameters())
-
- # Run multiple forward passes
- for i in range(10):
- chosen = mx.random.randint(0, 50, (2, 20))
- rejected = mx.random.randint(0, 50, (2, 20))
- chosen_len = mx.array([15, 18])
- rejected_len = mx.array([17, 16])
-
- loss, ntoks = dpo_loss(model, chosen, rejected, chosen_len, rejected_len, beta=0.1)
- mx.eval(loss)
-
- assert not mx.isnan(loss), f"DPO loss became NaN at iteration {i}"
- assert not mx.isinf(loss), f"DPO loss became Inf at iteration {i}"
-
- def test_orpo_loss_stability(self):
- """Test ORPO loss doesn't produce NaN or Inf."""
- from mlx_tune.losses import orpo_loss
-
- model = SmallLanguageModel(vocab_size=50, hidden_size=32)
- mx.eval(model.parameters())
-
- for i in range(10):
- chosen = mx.random.randint(0, 50, (2, 20))
- rejected = mx.random.randint(0, 50, (2, 20))
- chosen_len = mx.array([15, 18])
- rejected_len = mx.array([17, 16])
-
- loss, _ = orpo_loss(model, chosen, rejected, chosen_len, rejected_len, beta=0.1)
- mx.eval(loss)
-
- assert not mx.isnan(loss), f"ORPO loss became NaN at iteration {i}"
- assert not mx.isinf(loss), f"ORPO loss became Inf at iteration {i}"
-
- def test_simpo_loss_stability(self):
- """Test SimPO loss doesn't produce NaN or Inf."""
- from mlx_tune.losses import simpo_loss
-
- model = SmallLanguageModel(vocab_size=50, hidden_size=32)
- mx.eval(model.parameters())
-
- for i in range(10):
- chosen = mx.random.randint(0, 50, (2, 20))
- rejected = mx.random.randint(0, 50, (2, 20))
- chosen_len = mx.array([15, 18])
- rejected_len = mx.array([17, 16])
-
- loss, _ = simpo_loss(model, chosen, rejected, chosen_len, rejected_len,
- beta=0.1, gamma=0.5)
- mx.eval(loss)
-
- assert not mx.isnan(loss), f"SimPO loss became NaN at iteration {i}"
- assert not mx.isinf(loss), f"SimPO loss became Inf at iteration {i}"
+def test_grpo_evaluate_records_namespaced_metrics_and_preference_accuracy(tmp_path, tokenizer):
+ from mlx_tune import GRPOConfig, GRPOTrainer
+
+ output_dir = tmp_path / "grpo_eval"
+ trainer = GRPOTrainer(
+ model=make_model(80),
+ train_dataset=[{"prompt": "Solve 2 + 2", "answer": "4"}],
+ eval_dataset=[{"prompt": "Solve 3 + 3", "answer": "6"}],
+ eval_preference_dataset=[
+ {"prompt": "Q", "chosen": "AAAA", "rejected": "B"},
+ {"prompt": "Q", "chosen": "CCCC", "rejected": "D"},
+ ],
+ tokenizer=tokenizer,
+ reward_fn=lambda response, context: float(len(response)),
+ args=GRPOConfig(
+ seed=123,
+ learning_rate=1e-2,
+ beta=0.01,
+ num_generations=2,
+ max_completion_length=3,
+ max_steps=1,
+ eval_steps=1,
+ logging_steps=1,
+ save_steps=1,
+ output_dir=str(output_dir),
+ ),
+ )
+
+ trainer.train()
+ eval_metrics = trainer.evaluate()
+ history_rows = [json.loads(line) for line in (output_dir / "metrics" / "history.jsonl").read_text().splitlines()]
+
+ assert "eval/reward_mean" in eval_metrics
+ assert "eval/preference_win_rate" in eval_metrics
+ assert any("train/policy_loss" in row for row in history_rows)
+ assert any("eval/reward_mean" in row for row in history_rows)
+ assert any("eval/preference_win_rate" in row for row in history_rows)
-# =============================================================================
-# REWARD FUNCTION TESTS
-# =============================================================================
+@pytest.mark.integration
+@pytest.mark.parametrize("trainer_kind", ["grpo", "ppo", "online_dpo"])
+def test_on_policy_trainers_with_same_seed_are_reproducible(tmp_path, tokenizer, grpo_dataset, trainer_kind):
+ from mlx_tune import (
+ GRPOConfig,
+ GRPOTrainer,
+ OnlineDPOConfig,
+ OnlineDPOTrainer,
+ PPOConfig,
+ PPOTrainer,
+ build_value_model,
+ )
+
+ def build_trainer(output_dir):
+ if trainer_kind == "grpo":
+ return GRPOTrainer(
+ model=make_model(81),
+ train_dataset=grpo_dataset,
+ tokenizer=tokenizer,
+ reward_fn=lambda response, context: float(len(response)),
+ args=GRPOConfig(
+ seed=777,
+ learning_rate=1e-2,
+ beta=0.01,
+ num_generations=2,
+ max_completion_length=3,
+ max_steps=1,
+ logging_steps=1,
+ save_steps=1,
+ output_dir=str(output_dir),
+ ),
+ )
+ if trainer_kind == "ppo":
+ return PPOTrainer(
+ model=make_model(81),
+ train_dataset=grpo_dataset,
+ tokenizer=tokenizer,
+ reward_fn=lambda response, context: float(len(response)),
+ value_model=build_value_model(make_model(82)),
+ args=PPOConfig(
+ seed=777,
+ learning_rate=1e-2,
+ value_learning_rate=1e-2,
+ num_generations=2,
+ max_completion_length=3,
+ max_steps=1,
+ logging_steps=1,
+ save_steps=1,
+ output_dir=str(output_dir),
+ ),
+ )
+ return OnlineDPOTrainer(
+ model=make_model(81),
+ train_dataset=grpo_dataset,
+ tokenizer=tokenizer,
+ reward_fn=lambda response, context: float(len(response)),
+ args=OnlineDPOConfig(
+ seed=777,
+ learning_rate=1e-2,
+ num_generations=2,
+ max_completion_length=3,
+ max_steps=1,
+ logging_steps=1,
+ save_steps=1,
+ output_dir=str(output_dir),
+ ),
+ )
-class TestRewardFunctions:
- """Test reward functions used in GRPO."""
+ trainer_a = build_trainer(tmp_path / f"{trainer_kind}_a")
+ trainer_b = build_trainer(tmp_path / f"{trainer_kind}_b")
+ trainer_a.train()
+ trainer_b.train()
+
+ def _strip_timing(metrics_history):
+ cleaned = []
+ for row in metrics_history:
+ cleaned.append(
+ {
+ key: value
+ for key, value in row.items()
+ if not key.endswith("_wall")
+ }
+ )
+ return cleaned
- def test_simple_reward_function(self):
- """Test simple reward function."""
- from mlx_tune import create_reward_function
+ assert _strip_timing(trainer_a.metrics_history) == _strip_timing(trainer_b.metrics_history)
+ assert trainer_a._last_rollout_batch is not None
+ assert trainer_b._last_rollout_batch is not None
+ assert trainer_a._last_rollout_batch.completion_ids == trainer_b._last_rollout_batch.completion_ids
- reward_fn = create_reward_function("simple")
- # Simple reward expects (response, ground_truth)
- # Returns 1.0 if ground_truth is in response, else 0.0
- score_match = reward_fn("The answer is 42", "42")
- score_no_match = reward_fn("The answer is something", "42")
+@pytest.mark.integration
+def test_grpo_checkpoint_records_step_boundary_and_independent_cursors(tmp_path, tokenizer):
+ from mlx_tune import GRPOConfig, GRPOTrainer, resume_from_checkpoint
+
+ output_dir = tmp_path / "grpo_cursor_checkpoint"
+ trainer = GRPOTrainer(
+ model=make_model(83),
+ train_dataset=[
+ {"prompt": "Q1", "completion": "A1", "reward": 1.0},
+ {"prompt": "Q2", "completion": "A2", "reward": 0.5},
+ {"prompt": "Q3", "completion": "A3", "reward": 0.2},
+ {"prompt": "Q4", "completion": "A4", "reward": -0.1},
+ ],
+ tokenizer=tokenizer,
+ args=GRPOConfig(
+ seed=99,
+ learning_rate=1e-2,
+ reward_source="offline",
+ num_generations=2,
+ max_steps=1,
+ logging_steps=1,
+ save_steps=1,
+ output_dir=str(output_dir),
+ ),
+ )
- assert score_match == 1.0, "Should return 1.0 when ground_truth is in response"
- assert score_no_match == 0.0, "Should return 0.0 when ground_truth is not in response"
+ trainer.train()
+ bundle = resume_from_checkpoint(output_dir)
+ state = bundle.trainer_state
- def test_math_reward_function(self):
- """Test math reward function."""
- from mlx_tune import create_reward_function
+ assert bundle.manifest["format_version"] == 4
+ assert state["seed"] == 99
+ assert state["trainer_state"]["step_boundary"]["completed_optimizer_step"] == 1
+ assert state["trainer_state"]["step_boundary"]["checkpoint_authoritative"] is True
+ assert state["trainer_state"]["cursors"]["prompt_dataset"] == 0
+ assert state["trainer_state"]["cursors"]["offline_rollout_dataset"] == 2
- reward_fn = create_reward_function("math")
- # Math reward expects (response, ground_truth) and compares extracted numbers
- correct = reward_fn("The answer is 42", "42")
- incorrect = reward_fn("The answer is 99", "42")
+@pytest.mark.integration
+def test_grpo_reuses_precomputed_rollout_reference_cache_after_resume(tmp_path, tokenizer, monkeypatch):
+ import mlx_tune.rl_trainers as rl_trainers_module
+ from mlx_tune import GRPOConfig, GRPOTrainer
+
+ output_dir = tmp_path / "grpo_cache_resume"
+ train_dataset = [
+ {"prompt": "Q1", "completion": "A1", "reward": 1.0},
+ {"prompt": "Q2", "completion": "A2", "reward": 0.5},
+ {"prompt": "Q3", "completion": "A3", "reward": -0.2},
+ ]
+ trainer = GRPOTrainer(
+ model=make_model(84),
+ train_dataset=train_dataset,
+ tokenizer=tokenizer,
+ args=GRPOConfig(
+ seed=11,
+ learning_rate=1e-2,
+ reward_source="offline",
+ num_generations=2,
+ precompute_reference_scores=True,
+ max_steps=1,
+ logging_steps=1,
+ save_steps=1,
+ output_dir=str(output_dir),
+ ),
+ )
+ trainer.train()
+
+ resumed = GRPOTrainer(
+ model=make_model(85),
+ train_dataset=train_dataset,
+ tokenizer=tokenizer,
+ args=GRPOConfig(
+ seed=11,
+ learning_rate=1e-2,
+ reward_source="offline",
+ num_generations=2,
+ precompute_reference_scores=True,
+ max_steps=2,
+ logging_steps=1,
+ save_steps=1,
+ output_dir=str(output_dir),
+ ),
+ )
+ resumed._apply_lora_if_needed()
+ resumed._prepare_prompt_samples()
+ optimizer = resumed._optimizer_for_training()
+ resumed.optimizer = optimizer
+ resumed.load_state(optimizer=optimizer, checkpoint_dir=output_dir)
+ resumed._ensure_reference_policy()
+
+ original = rl_trainers_module.score_policy_in_chunks
+ reference_model = resumed.reference_policy.model.model
+ calls = {"reference": 0}
+
+ def wrapped(model, *args, **kwargs):
+ if model is reference_model:
+ calls["reference"] += 1
+ return original(model, *args, **kwargs)
+
+ monkeypatch.setattr(rl_trainers_module, "score_policy_in_chunks", wrapped)
+ batch = resumed._collect_fixed_rollout_batch(
+ resumed.rollout_samples[:2],
+ cache_key="grpo.train_rollout_reference_logprobs",
+ )
+
+ assert calls["reference"] == 0
+ assert batch.reference_logprobs is not None
- assert correct == 1.0, "Correct math answer should get 1.0"
- assert incorrect == 0.0, "Incorrect math answer should get 0.0"
- def test_length_reward_function(self):
- """Test length-based reward function."""
- from mlx_tune import create_reward_function
+@pytest.mark.integration
+def test_grpo_fixed_rollouts_respect_length_caps(tmp_path, tokenizer):
+ from mlx_tune import GRPOConfig, GRPOTrainer
+
+ trainer = GRPOTrainer(
+ model=make_model(87),
+ train_dataset=[
+ {"prompt": "abcdefghij", "completion": "klmnopqrst", "reward": 1.0},
+ ],
+ tokenizer=tokenizer,
+ args=GRPOConfig(
+ seed=13,
+ learning_rate=1e-2,
+ reward_source="offline",
+ max_seq_length=4,
+ max_completion_length=2,
+ max_steps=1,
+ logging_steps=1,
+ save_steps=1,
+ output_dir=str(tmp_path / "grpo_fixed_rollout_caps"),
+ ),
+ )
+
+ trainer._apply_lora_if_needed()
+ trainer._prepare_prompt_samples()
+ batch = trainer._collect_fixed_rollout_batch(trainer.rollout_samples)
+
+ assert batch.prompt_lengths.tolist() == [2]
+ assert batch.completion_lengths.tolist() == [2]
+ assert batch.policy_eval.input_ids.shape == (1, 4)
+ assert batch.truncation_flags.tolist() == [True]
- reward_fn = create_reward_function("length")
- # Length reward expects (response, _) where _ is ignored
- short = reward_fn("Hi there", "")
- medium = reward_fn("This is a longer response with about fifteen words in it here now", "")
- long = reward_fn(" ".join(["word"] * 100), "")
+@pytest.mark.integration
+def test_grpo_invalidates_precomputed_rollout_reference_cache_when_dataset_changes(
+ tmp_path,
+ tokenizer,
+ monkeypatch,
+):
+ import mlx_tune.rl_trainers as rl_trainers_module
+ from mlx_tune import GRPOConfig, GRPOTrainer
+
+ output_dir = tmp_path / "grpo_cache_dataset_drift"
+ original_dataset = [
+ {"prompt": "Q1", "completion": "A1", "reward": 1.0},
+ {"prompt": "Q2", "completion": "A2", "reward": 0.5},
+ ]
+ changed_dataset = [
+ {"prompt": "!!!!!!!!", "completion": "zzzzzzzz", "reward": 1.0},
+ {"prompt": "????????", "completion": "yyyyyyyy", "reward": 0.5},
+ ]
+ trainer = GRPOTrainer(
+ model=make_model(88),
+ train_dataset=original_dataset,
+ tokenizer=tokenizer,
+ args=GRPOConfig(
+ seed=19,
+ learning_rate=1e-2,
+ reward_source="offline",
+ num_generations=2,
+ precompute_reference_scores=True,
+ max_steps=1,
+ logging_steps=1,
+ save_steps=1,
+ output_dir=str(output_dir),
+ ),
+ )
+ trainer.train()
+ original_scores = mx.array(trainer.runtime_cache_arrays["grpo.train_rollout_reference_logprobs"])
+
+ resumed = GRPOTrainer(
+ model=make_model(89),
+ train_dataset=changed_dataset,
+ tokenizer=tokenizer,
+ args=GRPOConfig(
+ seed=19,
+ learning_rate=1e-2,
+ reward_source="offline",
+ num_generations=2,
+ precompute_reference_scores=True,
+ max_steps=2,
+ logging_steps=1,
+ save_steps=1,
+ output_dir=str(output_dir),
+ ),
+ )
+ resumed._apply_lora_if_needed()
+ resumed._prepare_prompt_samples()
+ optimizer = resumed._optimizer_for_training()
+ resumed.optimizer = optimizer
+ resumed.load_state(optimizer=optimizer, checkpoint_dir=output_dir)
+ resumed._ensure_reference_policy()
+
+ original = rl_trainers_module.score_policy_in_chunks
+ reference_model = resumed.reference_policy.model.model
+ calls = {"reference": 0}
+
+ def wrapped(model, *args, **kwargs):
+ if model is reference_model:
+ calls["reference"] += 1
+ return original(model, *args, **kwargs)
+
+ monkeypatch.setattr(rl_trainers_module, "score_policy_in_chunks", wrapped)
+ batch = resumed._collect_fixed_rollout_batch(
+ resumed.rollout_samples,
+ cache_key="grpo.train_rollout_reference_logprobs",
+ )
+
+ assert calls["reference"] > 0
+ assert not mx.allclose(batch.reference_logprobs, original_scores)
- # Short (<10 words) = 0.2, Medium (10-50) = 0.5, Long (50-200) = 1.0
- assert short == 0.2, f"Short response should be 0.2, got {short}"
- assert medium == 0.5, f"Medium response should be 0.5, got {medium}"
- assert long == 1.0, f"Long response should be 1.0, got {long}"
- def test_custom_reward_function(self):
- """Test using a custom reward function."""
- def my_reward(response: str, ground_truth: str = "") -> float:
- # Reward responses that contain "please" or "thank"
- score = 0.0
- if "please" in response.lower():
- score += 0.5
- if "thank" in response.lower():
- score += 0.5
- return score
+@pytest.mark.integration
+def test_grpo_resume_rejects_objective_and_update_shape_drift(tmp_path, tokenizer, grpo_dataset):
+ from mlx_tune import GRPOConfig, GRPOTrainer
+
+ output_dir = tmp_path / "grpo_resume_config_drift"
+ trainer = GRPOTrainer(
+ model=make_model(90),
+ train_dataset=grpo_dataset,
+ tokenizer=tokenizer,
+ reward_fn=lambda response, context: float(len(response)),
+ args=GRPOConfig(
+ seed=31,
+ learning_rate=1e-2,
+ beta=0.01,
+ clip_epsilon=0.2,
+ rollout_batch_size=1,
+ minibatch_reuse_steps=1,
+ num_generations=2,
+ max_completion_length=4,
+ max_steps=1,
+ logging_steps=1,
+ save_steps=1,
+ output_dir=str(output_dir),
+ ),
+ )
+ trainer.train()
+
+ resumed = GRPOTrainer(
+ model=make_model(91),
+ train_dataset=grpo_dataset,
+ tokenizer=tokenizer,
+ reward_fn=lambda response, context: float(len(response)),
+ args=GRPOConfig(
+ seed=31,
+ learning_rate=1e-2,
+ beta=0.99,
+ clip_epsilon=0.33,
+ rollout_batch_size=5,
+ minibatch_reuse_steps=7,
+ num_generations=2,
+ max_completion_length=4,
+ max_steps=2,
+ logging_steps=1,
+ save_steps=1,
+ output_dir=str(output_dir),
+ ),
+ )
+ resumed._apply_lora_if_needed()
+ resumed._prepare_prompt_samples()
+ optimizer = resumed._optimizer_for_training()
+ resumed.optimizer = optimizer
- assert my_reward("Please help me") == 0.5
- assert my_reward("Thank you") == 0.5
- assert my_reward("Please help, thank you!") == 1.0
- assert my_reward("Hello world") == 0.0
+ with pytest.raises(ValueError, match="fingerprint"):
+ resumed.load_state(optimizer=optimizer, checkpoint_dir=output_dir)
-# Run tests
-if __name__ == "__main__":
- pytest.main([__file__, "-v", "-m", "integration"])
+@pytest.mark.integration
+def test_online_dpo_eval_preference_reference_cache_reused_after_resume(tmp_path, tokenizer, monkeypatch):
+ import mlx_tune.rl_trainers as rl_trainers_module
+ from mlx_tune import OnlineDPOConfig, OnlineDPOTrainer
+
+ output_dir = tmp_path / "online_dpo_eval_cache"
+ eval_preference_dataset = [
+ {"prompt": "Q", "chosen": "AA", "rejected": "B"},
+ {"prompt": "Q", "chosen": "CC", "rejected": "D"},
+ ]
+ trainer = OnlineDPOTrainer(
+ model=make_model(86),
+ train_dataset=[{"prompt": "Solve 2 + 2", "answer": "4"}],
+ eval_preference_dataset=eval_preference_dataset,
+ tokenizer=tokenizer,
+ reward_fn=lambda response, context: float(len(response)),
+ args=OnlineDPOConfig(
+ seed=22,
+ learning_rate=1e-2,
+ num_generations=2,
+ max_completion_length=3,
+ max_steps=1,
+ eval_steps=1,
+ precompute_reference_scores=True,
+ logging_steps=1,
+ save_steps=1,
+ output_dir=str(output_dir),
+ ),
+ )
+ trainer.train()
+ trainer.evaluate()
+ trainer.save_state(optimizer=trainer.optimizer)
+
+ resumed = OnlineDPOTrainer(
+ model=make_model(87),
+ train_dataset=[{"prompt": "Solve 2 + 2", "answer": "4"}],
+ eval_preference_dataset=eval_preference_dataset,
+ tokenizer=tokenizer,
+ reward_fn=lambda response, context: float(len(response)),
+ args=OnlineDPOConfig(
+ seed=22,
+ learning_rate=1e-2,
+ num_generations=2,
+ max_completion_length=3,
+ max_steps=2,
+ eval_steps=1,
+ precompute_reference_scores=True,
+ logging_steps=1,
+ save_steps=1,
+ output_dir=str(output_dir),
+ ),
+ )
+ resumed._apply_lora_if_needed()
+ resumed._prepare_prompt_samples()
+ optimizer = resumed._optimizer_for_training()
+ resumed.optimizer = optimizer
+ resumed.load_state(optimizer=optimizer, checkpoint_dir=output_dir)
+
+ original = rl_trainers_module.score_policy_in_chunks
+ reference_model = resumed.reference_policy.model.model
+ calls = {"reference": 0}
+
+ def wrapped(model, *args, **kwargs):
+ if model is reference_model:
+ calls["reference"] += 1
+ return original(model, *args, **kwargs)
+
+ monkeypatch.setattr(rl_trainers_module, "score_policy_in_chunks", wrapped)
+ eval_metrics = resumed.evaluate()
+
+ assert calls["reference"] == 0
+ assert "eval/preference_win_rate" in eval_metrics
diff --git a/tests/test_trainers.py b/tests/test_trainers.py
index 61367ae..012c1e3 100644
--- a/tests/test_trainers.py
+++ b/tests/test_trainers.py
@@ -74,6 +74,17 @@ def test_dpoconfig_custom_beta(self):
assert config.beta == 0.5
+class TestRewardConfig:
+ def test_rewardconfig_defaults(self):
+ from mlx_tune import RewardConfig
+
+ config = RewardConfig()
+
+ assert config.learning_rate == 5e-6
+ assert config.regression_loss_type == "mse"
+ assert config.dataset_mode is None
+
+
class TestGRPOConfig:
"""Test GRPOConfig class."""
@@ -84,9 +95,13 @@ def test_grpoconfig_defaults(self):
config = GRPOConfig()
assert config.loss_type == "grpo"
+ assert config.advantage_mode == "group_zscore"
+ assert config.advantage_estimator == "group_zscore"
assert config.num_generations == 4
assert config.temperature == 0.7
assert config.beta == 0.04
+ assert config.kl_beta == 0.04
+ assert config.reward_source == "auto"
def test_grpoconfig_with_reward_fn(self):
"""Test GRPOConfig with custom reward function."""
@@ -100,6 +115,62 @@ def custom_reward(response, prompt):
assert config.reward_fn is not None
assert config.num_generations == 8
+ def test_grpoconfig_aliases_and_to_dict(self):
+ from mlx_tune import GRPOConfig
+
+ config = GRPOConfig(
+ generations_per_prompt=6,
+ advantage_estimator="rloo",
+ reward_fn=lambda *_: 1.0,
+ )
+
+ assert config.num_generations == 6
+ assert config.advantage_mode == "rloo"
+ assert config.to_dict()["num_generations"] == 6
+ assert "reward_fn" not in config.to_dict()
+
+ def test_grpoconfig_variant_defaults_match_loss_family(self):
+ from mlx_tune import GRPOConfig
+
+ dapo = GRPOConfig(loss_type="dapo")
+ dr_grpo = GRPOConfig(loss_type="dr_grpo")
+
+ assert dapo.mask_truncated_completions is True
+ assert dapo.epsilon_high == 0.28
+ assert dr_grpo.scale_rewards is False
+ assert dr_grpo.epsilon_low == dr_grpo.clip_epsilon
+ assert dr_grpo.epsilon_high == dr_grpo.clip_epsilon
+
+
+class TestPPOAndOnlineDPOConfig:
+ def test_ppoconfig_defaults(self):
+ from mlx_tune import PPOConfig
+
+ config = PPOConfig()
+
+ assert config.ppo_epochs == 2
+ assert config.minibatch_reuse_steps == 2
+ assert config.value_learning_rate == config.learning_rate
+ assert config.advantage_estimator == "gae"
+ assert config.kl_penalty_mode == "kl"
+
+ def test_online_dpoconfig_defaults(self):
+ from mlx_tune import OnlineDPOConfig
+
+ config = OnlineDPOConfig()
+
+ assert config.num_generations == 4
+ assert config.beta == 0.1
+
+ def test_new_offline_config_exports(self):
+ from mlx_tune import KTOConfig, SimPOConfig
+
+ kto = KTOConfig()
+ simpo = SimPOConfig()
+
+ assert kto.beta == 0.1
+ assert simpo.gamma == 0.5
+
class TestTrainerInitialization:
"""Test trainer initialization (without actual model loading)."""
@@ -109,21 +180,30 @@ def test_imports_work(self):
from mlx_tune import (
SFTTrainer,
SFTConfig,
+ RewardTrainer,
+ RewardConfig,
DPOTrainer,
DPOConfig,
ORPOTrainer,
ORPOConfig,
GRPOTrainer,
GRPOConfig,
+ PPOTrainer,
+ PPOConfig,
+ OnlineDPOTrainer,
+ OnlineDPOConfig,
KTOTrainer,
SimPOTrainer,
)
# Just verify imports work
assert SFTTrainer is not None
+ assert RewardTrainer is not None
assert DPOTrainer is not None
assert ORPOTrainer is not None
assert GRPOTrainer is not None
+ assert PPOTrainer is not None
+ assert OnlineDPOTrainer is not None
assert KTOTrainer is not None
assert SimPOTrainer is not None
@@ -165,6 +245,11 @@ def test_prepare_preference_dataset_import(self):
from mlx_tune import prepare_preference_dataset
assert prepare_preference_dataset is not None
+ def test_prepare_rl_dataset_import(self):
+ from mlx_tune import prepare_rl_dataset
+
+ assert prepare_rl_dataset is not None
+
def test_create_reward_function_simple(self):
"""Test create_reward_function with simple type."""
from mlx_tune import create_reward_function
@@ -205,6 +290,21 @@ def test_create_reward_function_length(self):
medium_result = reward_fn(" ".join(["word"] * 30), "")
assert medium_result == 0.5
+ def test_create_reward_function_composition(self):
+ from mlx_tune import create_reward_function
+
+ reward_fn = create_reward_function(
+ rewards=[
+ {"name": "simple", "source": "simple", "weight": 0.25},
+ {"name": "length", "source": "length", "weight": 0.75},
+ ]
+ )
+
+ result = reward_fn.evaluate({"completion_text": "The answer is 42", "reward_context": "42"})
+
+ assert result["reward"] > 0.0
+ assert set(result["components"]) >= {"simple", "length"}
+
class TestExportFunctions:
"""Test export utility functions."""
diff --git a/tests/test_trl_compat.py b/tests/test_trl_compat.py
new file mode 100644
index 0000000..d9f6395
--- /dev/null
+++ b/tests/test_trl_compat.py
@@ -0,0 +1,204 @@
+import importlib
+import sys
+from pathlib import Path
+from types import ModuleType, SimpleNamespace
+
+import mlx.core as mx
+import mlx.nn as nn
+import pytest
+
+
+class TinyModel(nn.Module):
+ def __init__(self, vocab_size: int = 32, hidden_size: int = 16):
+ super().__init__()
+ self.embedding = nn.Embedding(vocab_size, hidden_size)
+ self.output = nn.Linear(hidden_size, vocab_size)
+
+ def __call__(self, x):
+ return self.output(self.embedding(x))
+
+
+class TinyTokenizer:
+ pad_token_id = 0
+ eos_token_id = 1
+ bos_token_id = 2
+
+ def encode(self, text: str, add_special_tokens: bool = True):
+ token_ids = [((ord(char) % 10) + 3) for char in text]
+ if add_special_tokens:
+ token_ids = [self.bos_token_id] + token_ids + [self.eos_token_id]
+ return token_ids
+
+ def decode(self, ids, skip_special_tokens: bool = True):
+ if skip_special_tokens:
+ ids = [token for token in ids if token not in (self.pad_token_id, self.eos_token_id, self.bos_token_id)]
+ return "".join(chr(65 + (token % 26)) for token in ids)
+
+
+@pytest.fixture(autouse=True)
+def isolated_trl_modules():
+ saved = {
+ name: module
+ for name, module in sys.modules.items()
+ if name == "trl" or name.startswith("trl.")
+ }
+ for name in list(saved):
+ sys.modules.pop(name, None)
+ yield
+ for name in list(sys.modules):
+ if name == "trl" or name.startswith("trl."):
+ sys.modules.pop(name, None)
+ sys.modules.update(saved)
+
+
+def test_patch_fast_rl_creates_fallback_trl_module_when_missing():
+ from mlx_tune import GRPOConfig as MLXGRPOConfig
+ from mlx_tune import GRPOTrainer as MLXGRPOTrainer
+ from mlx_tune import PatchFastRL
+ from mlx_tune import SFTConfig as MLXSFTConfig
+ from mlx_tune import SFTTrainer as MLXSFTTrainer
+
+ PatchFastRL()
+
+ trl = importlib.import_module("trl")
+ from trl import GRPOConfig, GRPOTrainer, SFTConfig, SFTTrainer
+
+ assert trl.__MLX_TUNE_PATCHED__ is True
+ assert trl.trainer.__MLX_TUNE_PATCHED__ is True
+ assert issubclass(GRPOConfig, MLXGRPOConfig)
+ assert issubclass(GRPOTrainer, MLXGRPOTrainer)
+ assert issubclass(SFTConfig, MLXSFTConfig)
+ assert issubclass(SFTTrainer, MLXSFTTrainer)
+ assert trl.trainer.GRPOTrainer is GRPOTrainer
+ assert trl.trainer.SFTConfig is SFTConfig
+
+
+def test_patch_fast_rl_mutates_existing_module_in_place_and_is_idempotent():
+ from mlx_tune import PatchFastRL
+
+ trl_module = ModuleType("trl")
+ trl_module.__package__ = "trl"
+ trl_module.__path__ = []
+ trl_module.keep_me = "present"
+ trainer_module = ModuleType("trl.trainer")
+ trainer_module.keep_me_too = "present"
+ trl_module.trainer = trainer_module
+ sys.modules["trl"] = trl_module
+ sys.modules["trl.trainer"] = trainer_module
+
+ PatchFastRL()
+ first_grpo_trainer = trl_module.GRPOTrainer
+
+ PatchFastRL()
+
+ assert sys.modules["trl"] is trl_module
+ assert sys.modules["trl.trainer"] is trainer_module
+ assert trl_module.trainer is trainer_module
+ assert trl_module.keep_me == "present"
+ assert trainer_module.keep_me_too == "present"
+ assert trl_module.GRPOTrainer is first_grpo_trainer
+ assert trainer_module.GRPOTrainer is first_grpo_trainer
+
+
+def test_grpo_config_normalizes_reward_aliases():
+ from mlx_tune import PatchFastRL
+
+ PatchFastRL()
+
+ from trl import GRPOConfig
+
+ reward_fn = lambda response, context: float(len(response) + len(context))
+ reward_funcs = [reward_fn, reward_fn]
+
+ config = GRPOConfig(
+ reward_funcs=reward_funcs,
+ reward_func=reward_fn,
+ generations_per_prompt=3,
+ baseline_mode="rloo",
+ )
+
+ assert config.reward_sources[:2] == reward_funcs
+ assert config.reward_sources[2]["source"] is reward_fn
+ assert config.reward_fn is reward_fn
+ assert config.num_generations == 3
+ assert config.advantage_estimator == "rloo"
+
+
+def test_grpo_trainer_accepts_processing_class_and_foreign_args(tmp_path):
+ from mlx_tune import PatchFastRL
+
+ PatchFastRL()
+
+ from trl import GRPOTrainer
+
+ reward_fn = lambda response, context: float(len(response) + len(context))
+ tokenizer = TinyTokenizer()
+ model = TinyModel()
+ mx.eval(model.parameters())
+ output_dir = tmp_path / "grpo"
+
+ trainer = GRPOTrainer(
+ model=model,
+ processing_class=tokenizer,
+ train_dataset=[{"prompt": "hi"}],
+ eval_dataset=[{"prompt": "bye"}],
+ args=SimpleNamespace(
+ output_dir=str(output_dir),
+ learning_rate=1e-5,
+ per_device_train_batch_size=1,
+ num_train_epochs=1,
+ max_steps=1,
+ logging_steps=1,
+ save_steps=1,
+ max_seq_length=8,
+ max_completion_length=2,
+ generations_per_prompt=2,
+ reward_funcs=[reward_fn],
+ ),
+ )
+
+ assert trainer.tokenizer is tokenizer
+ assert trainer.eval_dataset == [{"prompt": "bye"}]
+ assert trainer.config.num_generations == 2
+ assert trainer.reward_sources == [reward_fn]
+ assert trainer.output_dir == output_dir
+
+
+def test_dpo_trainer_runs_via_patched_trl_import(tmp_path):
+ from mlx_tune import PatchFastRL
+
+ PatchFastRL()
+
+ from trl import DPOTrainer
+
+ tokenizer = TinyTokenizer()
+ model = TinyModel()
+ mx.eval(model.parameters())
+
+ trainer = DPOTrainer(
+ model=model,
+ processing_class=tokenizer,
+ train_dataset=[
+ {"prompt": "a", "chosen": "b", "rejected": "c"},
+ {"prompt": "d", "chosen": "e", "rejected": "f"},
+ ],
+ args=SimpleNamespace(
+ output_dir=str(tmp_path / "dpo"),
+ learning_rate=1e-4,
+ per_device_train_batch_size=1,
+ gradient_accumulation_steps=1,
+ num_train_epochs=1,
+ max_steps=1,
+ warmup_steps=0,
+ logging_steps=1,
+ save_steps=1,
+ max_seq_length=12,
+ max_prompt_length=6,
+ ),
+ )
+
+ result = trainer.train()
+
+ assert result["status"] == "success"
+ assert result["global_step"] == 1
+ assert Path(result["adapter_path"]).exists()