diff --git a/examples/09_rl_training_methods.py b/examples/09_rl_training_methods.py index 2904d88..66beaae 100644 --- a/examples/09_rl_training_methods.py +++ b/examples/09_rl_training_methods.py @@ -17,8 +17,9 @@ ORPOTrainer, ORPOConfig, GRPOTrainer, GRPOConfig, # Utilities - prepare_preference_dataset, + prepare_rl_dataset, create_reward_function, + resume_from_checkpoint, ) @@ -60,6 +61,8 @@ def demo_dpo_training(): }, ] + prepared_dataset = prepare_rl_dataset(preference_data, mode="preference", tokenizer=tokenizer) + # Configure DPO config = DPOConfig( beta=0.1, # KL penalty coefficient @@ -72,7 +75,7 @@ def demo_dpo_training(): # Create trainer trainer = DPOTrainer( model=model, - train_dataset=preference_data, + train_dataset=prepared_dataset, tokenizer=tokenizer, args=config, ) @@ -122,12 +125,19 @@ def demo_grpo_training(): }, ] - # Create a math reward function - math_reward = create_reward_function("math") + prepared_dataset = prepare_rl_dataset(reasoning_data, mode="prompt", tokenizer=tokenizer) + + # Compose rewards through the public RL API surface. + math_reward = create_reward_function( + rewards=[ + {"name": "math", "source": "math", "weight": 1.0}, + {"name": "length", "source": "length", "weight": 0.1}, + ] + ) # Configure GRPO config = GRPOConfig( - loss_type="grpo", # grpo, dr_grpo, dapo, bnpo + loss_type="grpo", # Phase 1 accepts grpo/dr_grpo/dapo/bnpo via one shared objective beta=0.04, num_generations=4, # Multiple generations per prompt temperature=0.7, @@ -139,7 +149,7 @@ def demo_grpo_training(): # Create trainer with custom reward function trainer = GRPOTrainer( model=model, - train_dataset=reasoning_data, + train_dataset=prepared_dataset, tokenizer=tokenizer, reward_fn=math_reward, # Custom reward! args=config, @@ -153,6 +163,7 @@ def demo_grpo_training(): # Would train with: trainer.train() print("\nTo train: trainer.train()") + print(f"To inspect a saved checkpoint first: {resume_from_checkpoint.__name__}('./grpo_output')") def demo_orpo_training(): @@ -221,14 +232,14 @@ def show_available_trainers(): print(f"| {name} | {method} | {use_case} |") print("\n" + "=" * 70) - print("GRPO Loss Types (for reasoning models)") + print("GRPO Loss Types (accepted in Phase 1)") print("=" * 70) grpo_types = [ - ("grpo", "Standard GRPO", "Default for reasoning"), - ("dr_grpo", "Dr. GRPO", "Distilled version"), - ("dapo", "DAPO", "Data-efficient variant"), - ("bnpo", "BNPO", "Batch-normalized variant"), + ("grpo", "Standard GRPO", "Primary Phase 1 name"), + ("dr_grpo", "Dr. GRPO", "Accepted alias; shared Phase 1 objective"), + ("dapo", "DAPO", "Accepted alias; shared Phase 1 objective"), + ("bnpo", "BNPO", "Accepted alias; shared Phase 1 objective"), ] print("\n| Loss Type | Name | Description |") diff --git a/examples/10_qwen3_arithmetic_grpo_validation.py b/examples/10_qwen3_arithmetic_grpo_validation.py new file mode 100644 index 0000000..3d2b67f --- /dev/null +++ b/examples/10_qwen3_arithmetic_grpo_validation.py @@ -0,0 +1,10 @@ +from pathlib import Path +import sys + +sys.path.insert(0, str(Path(__file__).resolve().parents[1])) + +from mlx_tune.arithmetic_grpo_validation import main + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/mlx_tune/__init__.py b/mlx_tune/__init__.py index 81e140f..4ccea8a 100644 --- a/mlx_tune/__init__.py +++ b/mlx_tune/__init__.py @@ -15,7 +15,15 @@ __version__ = "0.4.0" # Renamed to mlx-tune (formerly unsloth-mlx) -from mlx_tune.model import FastLanguageModel +from mlx_tune.model import ( + FastLanguageModel, + ReferencePolicy, + RLModelRoles, + RewardModel, + ValueModel, + build_value_model, + create_rl_model_roles, +) from mlx_tune.trainer import ( prepare_dataset, format_chat_template, @@ -25,25 +33,44 @@ get_training_config, ) from mlx_tune.sft_trainer import SFTTrainer, SFTConfig, TrainingArguments +from mlx_tune.rl_api import ( + RLCheckpointBundle, + PreparedRLDataset, + prepare_rl_dataset, + build_reference_policy, + build_reward_model, + create_reward_function, + resume_from_checkpoint, +) # RL Trainers from mlx_tune.rl_trainers import ( + RewardTrainer, + RewardConfig, DPOTrainer, DPOConfig, ORPOTrainer, ORPOConfig, GRPOTrainer, GRPOConfig, + PPOTrainer, + PPOConfig, + OnlineDPOTrainer, + OnlineDPOConfig, + KTOConfig, + SimPOConfig, KTOTrainer, SimPOTrainer, + prepare_reward_dataset, prepare_preference_dataset, - create_reward_function, + score_reward_model, ) # Loss functions for custom training from mlx_tune.losses import ( compute_log_probs, compute_log_probs_with_lengths, + compute_completion_log_probs, dpo_loss, orpo_loss, kto_loss, @@ -52,6 +79,16 @@ grpo_loss, grpo_batch_loss, compute_reference_logprobs, + pairwise_reward_loss, + reward_model_pairwise_loss, + reward_model_regression_loss, + value_regression_loss, + value_model_regression_loss, + scalar_loss_metrics, + pairwise_ranking_accuracy, + precompute_preference_reference_logprobs, + precompute_kto_reference_logprobs, + ppo_sequence_loss, ) # Vision Language Models @@ -92,10 +129,24 @@ HFDatasetConfig, load_dataset_with_config, ) +from mlx_tune.trl_compat import PatchFastRL __all__ = [ # Core "FastLanguageModel", + "ReferencePolicy", + "RLModelRoles", + "RewardModel", + "ValueModel", + "build_reference_policy", + "build_reward_model", + "build_value_model", + "create_rl_model_roles", + "PreparedRLDataset", + "RLCheckpointBundle", + "prepare_rl_dataset", + "resume_from_checkpoint", + "PatchFastRL", "__version__", # SFT Training "SFTTrainer", @@ -108,6 +159,14 @@ "ORPOConfig", "GRPOTrainer", "GRPOConfig", + "RewardTrainer", + "RewardConfig", + "PPOTrainer", + "PPOConfig", + "OnlineDPOTrainer", + "OnlineDPOConfig", + "KTOConfig", + "SimPOConfig", "KTOTrainer", "SimPOTrainer", # Vision Models @@ -116,6 +175,7 @@ # Loss Functions "compute_log_probs", "compute_log_probs_with_lengths", + "compute_completion_log_probs", "dpo_loss", "orpo_loss", "kto_loss", @@ -124,8 +184,19 @@ "grpo_loss", "grpo_batch_loss", "compute_reference_logprobs", + "pairwise_reward_loss", + "reward_model_pairwise_loss", + "reward_model_regression_loss", + "value_regression_loss", + "value_model_regression_loss", + "scalar_loss_metrics", + "pairwise_ranking_accuracy", + "precompute_preference_reference_logprobs", + "precompute_kto_reference_logprobs", + "ppo_sequence_loss", # Utilities "prepare_dataset", + "prepare_reward_dataset", "prepare_preference_dataset", "format_chat_template", "create_training_data", @@ -133,6 +204,7 @@ "export_to_gguf", "get_training_config", "create_reward_function", + "score_reward_model", "load_vlm_dataset", # Chat Templates and Dataset Formatting "detect_dataset_format", diff --git a/mlx_tune/_rl_runtime.py b/mlx_tune/_rl_runtime.py new file mode 100644 index 0000000..b236e1e --- /dev/null +++ b/mlx_tune/_rl_runtime.py @@ -0,0 +1,1237 @@ +from dataclasses import dataclass, fields, replace +import inspect +from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple + +import mlx.core as mx +import mlx.nn as nn + + +@dataclass +class PolicyEvalBatch: + input_ids: mx.array + sequence_lengths: mx.array + token_mask: mx.array + prompt_lengths: Optional[mx.array] = None + completion_lengths: Optional[mx.array] = None + rollout_logprobs: Optional[mx.array] = None + old_logprobs: Optional[mx.array] = None + old_token_logprobs: Optional[mx.array] = None + reference_logprobs: Optional[mx.array] = None + value_predictions: Optional[mx.array] = None + returns: Optional[mx.array] = None + advantages: Optional[mx.array] = None + labels: Optional[mx.array] = None + prompt_group_indices: Optional[mx.array] = None + sample_indices: Optional[mx.array] = None + token_logprobs: Optional[mx.array] = None + summed_logprobs: Optional[mx.array] = None + + +@dataclass +class RolloutBatch: + prompt_ids: List[List[int]] + prompt_lengths: mx.array + completion_ids: List[List[int]] + completion_lengths: mx.array + prompt_texts: List[str] + original_prompt_texts: Optional[List[str]] + completion_texts: List[str] + reward_contexts: List[Any] + sampled_token_logprobs: mx.array + rollout_logprobs: mx.array + eos_flags: mx.array + truncation_flags: mx.array + prompt_group_indices: mx.array + policy_eval: PolicyEvalBatch + sample_indices: Optional[mx.array] = None + sampled_token_logits: Optional[mx.array] = None + token_entropies: Optional[mx.array] = None + old_logprobs: Optional[mx.array] = None + rewards: Optional[mx.array] = None + reference_logprobs: Optional[mx.array] = None + value_predictions: Optional[mx.array] = None + returns: Optional[mx.array] = None + advantages: Optional[mx.array] = None + + +@dataclass +class RewardBatch: + prompt_texts: List[str] + completion_texts: List[str] + reward_contexts: List[Any] + scalar_rewards: mx.array + prompt_group_indices: mx.array + original_prompt_texts: Optional[List[str]] = None + named_reward_components: Optional[List[Dict[str, float]]] = None + diagnostics: Optional[List[Dict[str, Any]]] = None + + +@dataclass +class PreferenceBatch: + chosen: PolicyEvalBatch + rejected: PolicyEvalBatch + sample_indices: mx.array + chosen_reference_logprobs: Optional[mx.array] = None + rejected_reference_logprobs: Optional[mx.array] = None + + +def pad_sequences(sequences: Sequence[Sequence[int]], pad_id: int) -> tuple[mx.array, mx.array]: + max_length = max(len(sequence) for sequence in sequences) + padded = [list(sequence) + [pad_id] * (max_length - len(sequence)) for sequence in sequences] + lengths = [len(sequence) for sequence in sequences] + return mx.array(padded), mx.array(lengths) + + +def truncate_prompt_tokens(prompt_ids: Sequence[int], max_prompt_length: Optional[int]) -> List[int]: + prompt_tokens = list(prompt_ids) + if max_prompt_length is None or len(prompt_tokens) <= max_prompt_length: + return prompt_tokens + return prompt_tokens[-max_prompt_length:] + + +def truncate_completion_tokens( + completion_ids: Sequence[int], + max_completion_length: Optional[int], +) -> tuple[List[int], bool]: + completion_tokens = list(completion_ids) + if max_completion_length is None or len(completion_tokens) <= max_completion_length: + return completion_tokens, False + return completion_tokens[:max_completion_length], True + + +def cap_prompt_and_completion_lengths( + prompt_ids: Sequence[int], + completion_ids: Sequence[int], + max_seq_length: Optional[int], + max_completion_length: Optional[int], +) -> tuple[List[int], List[int], bool]: + effective_completion_cap = max_completion_length + if max_seq_length is not None: + effective_completion_cap = ( + int(max_seq_length) + if effective_completion_cap is None + else min(int(effective_completion_cap), int(max_seq_length)) + ) + capped_completion_ids, truncated_completion = truncate_completion_tokens( + completion_ids, + effective_completion_cap, + ) + + capped_prompt_ids = list(prompt_ids) + if max_seq_length is not None: + remaining_prompt_budget = max(0, int(max_seq_length) - len(capped_completion_ids)) + capped_prompt_ids = truncate_prompt_tokens(prompt_ids, remaining_prompt_budget) + return capped_prompt_ids, capped_completion_ids, truncated_completion + + +def _resolve_terminal_token_ids(tokenizer: Any) -> set[int]: + terminal_ids: set[int] = set() + + eos_token_id = getattr(tokenizer, "eos_token_id", None) + if eos_token_id is not None: + try: + terminal_ids.add(int(eos_token_id)) + except (TypeError, ValueError): + pass + + stop_token = getattr(tokenizer, "_unsloth_stop_token", None) + if not stop_token: + return terminal_ids + + candidate_ids: list[Any] = [] + if hasattr(tokenizer, "convert_tokens_to_ids"): + try: + candidate_ids.append(tokenizer.convert_tokens_to_ids(stop_token)) + except Exception: + pass + if hasattr(tokenizer, "get_vocab"): + try: + vocab = tokenizer.get_vocab() + if stop_token in vocab: + candidate_ids.append(vocab[stop_token]) + except Exception: + pass + if hasattr(tokenizer, "encode"): + try: + encoded = tokenizer.encode(stop_token, add_special_tokens=False) + if len(encoded) == 1: + candidate_ids.append(encoded[0]) + except Exception: + pass + + for token_id in candidate_ids: + if token_id is None: + continue + try: + terminal_ids.add(int(token_id)) + except (TypeError, ValueError): + continue + return terminal_ids + + +def length_mask(lengths: mx.array, width: int) -> mx.array: + positions = mx.arange(width)[None, :] + return positions < lengths[:, None] + + +def completion_token_mask( + input_ids: mx.array, + prompt_lengths: mx.array, + completion_lengths: mx.array, +) -> mx.array: + width = input_ids.shape[1] - 1 + positions = mx.arange(width)[None, :] + start = mx.maximum(prompt_lengths - 1, 0)[:, None] + end = mx.maximum(prompt_lengths + completion_lengths - 1, 0)[:, None] + return (positions >= start) & (positions < end) + + +def build_token_mask( + input_ids: mx.array, + sequence_lengths: mx.array, + mode: str = "sequence", + prompt_lengths: Optional[mx.array] = None, + completion_lengths: Optional[mx.array] = None, +) -> mx.array: + if mode == "sequence": + return length_mask(sequence_lengths, input_ids.shape[1] - 1) + if mode == "completion": + if prompt_lengths is None or completion_lengths is None: + raise ValueError("Completion scoring requires prompt_lengths and completion_lengths.") + return completion_token_mask(input_ids, prompt_lengths, completion_lengths) + raise ValueError(f"Unsupported scoring mode: {mode}") + + +def make_policy_eval_batch( + sequences: Sequence[Sequence[int]], + pad_id: int, + mode: str = "sequence", + prompt_lengths: Optional[Sequence[int]] = None, + completion_lengths: Optional[Sequence[int]] = None, + rollout_logprobs: Optional[mx.array] = None, + old_logprobs: Optional[mx.array] = None, + old_token_logprobs: Optional[mx.array] = None, + reference_logprobs: Optional[mx.array] = None, + value_predictions: Optional[mx.array] = None, + returns: Optional[mx.array] = None, + advantages: Optional[mx.array] = None, + labels: Optional[mx.array] = None, + prompt_group_indices: Optional[mx.array] = None, + sample_indices: Optional[mx.array] = None, +) -> PolicyEvalBatch: + input_ids, sequence_lengths = pad_sequences(sequences, pad_id) + prompt_lengths_array = mx.array(prompt_lengths) if prompt_lengths is not None else None + completion_lengths_array = mx.array(completion_lengths) if completion_lengths is not None else None + token_mask = build_token_mask( + input_ids=input_ids, + sequence_lengths=sequence_lengths, + mode=mode, + prompt_lengths=prompt_lengths_array, + completion_lengths=completion_lengths_array, + ) + return PolicyEvalBatch( + input_ids=input_ids, + sequence_lengths=sequence_lengths, + prompt_lengths=prompt_lengths_array, + completion_lengths=completion_lengths_array, + token_mask=token_mask, + rollout_logprobs=rollout_logprobs, + old_logprobs=old_logprobs if old_logprobs is not None else rollout_logprobs, + old_token_logprobs=old_token_logprobs, + reference_logprobs=reference_logprobs, + value_predictions=value_predictions, + returns=returns, + advantages=advantages, + labels=labels, + prompt_group_indices=prompt_group_indices, + sample_indices=sample_indices, + ) + + +def make_preference_batch( + chosen_sequences: Sequence[Sequence[int]], + rejected_sequences: Sequence[Sequence[int]], + pad_id: int, + sample_indices: Sequence[int], + chosen_reference_logprobs: Optional[mx.array] = None, + rejected_reference_logprobs: Optional[mx.array] = None, +) -> PreferenceBatch: + return PreferenceBatch( + chosen=make_policy_eval_batch( + chosen_sequences, + pad_id=pad_id, + mode="sequence", + reference_logprobs=chosen_reference_logprobs, + sample_indices=mx.array(sample_indices), + ), + rejected=make_policy_eval_batch( + rejected_sequences, + pad_id=pad_id, + mode="sequence", + reference_logprobs=rejected_reference_logprobs, + sample_indices=mx.array(sample_indices), + ), + sample_indices=mx.array(sample_indices), + chosen_reference_logprobs=chosen_reference_logprobs, + rejected_reference_logprobs=rejected_reference_logprobs, + ) + + +def _token_log_probs( + model: Any, + input_ids: mx.array, + temperature: float = 1.0, +) -> mx.array: + inputs = input_ids[:, :-1] + targets = input_ids[:, 1:] + logits = model(inputs) + if temperature != 1.0: + logits = logits / temperature + log_probs = nn.log_softmax(logits, axis=-1) + return mx.take_along_axis(log_probs, targets[:, :, None], axis=-1).squeeze(-1) + + +def normalize_logprobs( + summed_logprobs: mx.array, + lengths: mx.array, + mode: str = "sum", +) -> mx.array: + if mode == "sum": + return summed_logprobs + if mode in {"mean", "token_mean"}: + return summed_logprobs / mx.maximum(lengths.astype(summed_logprobs.dtype), 1.0) + raise ValueError(f"Unsupported length normalization mode: {mode}") + + +def kl_against_reference( + policy_logprobs: mx.array, + reference_logprobs: mx.array, +) -> mx.array: + log_ratio = policy_logprobs - reference_logprobs + return mx.exp(log_ratio) - log_ratio - 1.0 + + +def score_policy( + model: Any, + batch: PolicyEvalBatch, + mode: str = "sequence", + reference_model: Optional[Any] = None, + temperature: float = 1.0, +) -> PolicyEvalBatch: + token_mask = build_token_mask( + input_ids=batch.input_ids, + sequence_lengths=batch.sequence_lengths, + mode=mode, + prompt_lengths=batch.prompt_lengths, + completion_lengths=batch.completion_lengths, + ) + token_logprobs = _token_log_probs(model, batch.input_ids, temperature=temperature) + summed_logprobs = (token_logprobs * token_mask.astype(token_logprobs.dtype)).sum(axis=-1) + + reference_logprobs = batch.reference_logprobs + if reference_model is not None: + reference_tokens = _token_log_probs(reference_model, batch.input_ids, temperature=temperature) + reference_logprobs = mx.stop_gradient( + (reference_tokens * token_mask.astype(reference_tokens.dtype)).sum(axis=-1) + ) + + return replace( + batch, + token_mask=token_mask, + token_logprobs=token_logprobs, + summed_logprobs=summed_logprobs, + reference_logprobs=reference_logprobs, + ) + + +def score_policy_in_chunks( + model: Any, + batch: PolicyEvalBatch, + batch_size: Optional[int], + mode: str = "sequence", + reference_model: Optional[Any] = None, + temperature: float = 1.0, + token_budget: Optional[int] = None, +) -> PolicyEvalBatch: + if batch.input_ids.shape[0] == 0: + return score_policy( + model, + batch, + mode=mode, + reference_model=reference_model, + temperature=temperature, + ) + + if token_budget is None and batch_size is None: + batch_size = batch.input_ids.shape[0] + + if token_budget is None and batch_size is not None and batch.input_ids.shape[0] <= batch_size: + return score_policy( + model, + batch, + mode=mode, + reference_model=reference_model, + temperature=temperature, + ) + + scored_chunks = [] + for minibatch in assemble_minibatches( + batch, + minibatch_size=batch_size, + shuffle=False, + mode=mode, + token_budget=token_budget, + ): + scored_chunks.append( + score_policy( + model, + minibatch, + mode=mode, + reference_model=reference_model, + temperature=temperature, + ) + ) + return _concat_policy_eval_batches(scored_chunks) + + +def _policy_eval_effective_lengths(batch: PolicyEvalBatch, mode: str) -> List[int]: + if batch.input_ids.shape[0] == 0: + return [] + if mode == "completion" and batch.completion_lengths is not None: + return [max(1, int(value)) for value in batch.completion_lengths.tolist()] + return [max(1, int(value) - 1) for value in batch.sequence_lengths.tolist()] + + +def _concat_policy_eval_batches(chunks: Sequence[PolicyEvalBatch]) -> PolicyEvalBatch: + def concat_attr(name: str): + values = [getattr(chunk, name) for chunk in chunks] + if values[0] is None: + return None + if hasattr(values[0], "shape"): + return mx.concatenate(values, axis=0) + if isinstance(values[0], list): + merged = [] + for value in values: + merged.extend(value) + return merged + raise TypeError(f"Unsupported PolicyEvalBatch field type for concat: {name}") + + return PolicyEvalBatch( + input_ids=concat_attr("input_ids"), + sequence_lengths=concat_attr("sequence_lengths"), + prompt_lengths=concat_attr("prompt_lengths"), + completion_lengths=concat_attr("completion_lengths"), + token_mask=concat_attr("token_mask"), + rollout_logprobs=concat_attr("rollout_logprobs"), + old_logprobs=concat_attr("old_logprobs"), + old_token_logprobs=concat_attr("old_token_logprobs"), + reference_logprobs=concat_attr("reference_logprobs"), + value_predictions=concat_attr("value_predictions"), + returns=concat_attr("returns"), + advantages=concat_attr("advantages"), + labels=concat_attr("labels"), + prompt_group_indices=concat_attr("prompt_group_indices"), + sample_indices=concat_attr("sample_indices"), + token_logprobs=concat_attr("token_logprobs"), + summed_logprobs=concat_attr("summed_logprobs"), + ) + + +def sample_completion( + policy: Any, + tokenizer: Any, + prompt_ids: Sequence[int], + max_tokens: int, + temperature: float, + collect_sample_stats: bool = False, +) -> Dict[str, Any]: + generated_ids = list(prompt_ids) + sampled_logprobs: List[float] = [] + sampled_logits: List[float] = [] + token_entropies: List[float] = [] + saw_eos = False + terminal_token_ids = _resolve_terminal_token_ids(tokenizer) + x = mx.array([generated_ids]) + + for _ in range(max_tokens): + logits = policy(x)[:, -1, :] + if temperature > 0: + scaled = logits / temperature + probs = mx.softmax(scaled, axis=-1) + next_token = mx.random.categorical(mx.log(probs + 1e-10)) + log_probs = nn.log_softmax(scaled, axis=-1) + entropy = -mx.sum(probs * log_probs, axis=-1)[0] + else: + scaled = logits + next_token = mx.argmax(logits, axis=-1) + log_probs = nn.log_softmax(logits, axis=-1) + probs = mx.softmax(logits, axis=-1) + entropy = -mx.sum(probs * log_probs, axis=-1)[0] + + next_token_id = int(next_token.item()) + sampled_logprobs.append(float(log_probs[0, next_token_id].item())) + if collect_sample_stats: + sampled_logits.append(float(scaled[0, next_token_id].item())) + token_entropies.append(float(entropy.item())) + + generated_ids.append(next_token_id) + x = mx.array([generated_ids]) + if next_token_id in terminal_token_ids: + saw_eos = True + break + + completion_ids = generated_ids[len(prompt_ids):] + return { + "generated_ids": generated_ids, + "completion_ids": completion_ids, + "sampled_logprobs": sampled_logprobs, + "sampled_logits": sampled_logits, + "token_entropies": token_entropies, + "eos_flag": saw_eos, + "truncation_flag": (not saw_eos) and len(completion_ids) >= max_tokens, + } + + +class _RewardEvaluatorAdapter: + def __init__(self, evaluator: Any): + self.evaluator = evaluator + self.mode = self._resolve_mode(evaluator) + + def _resolve_mode(self, evaluator: Any) -> str: + if evaluator is None: + return "none" + if hasattr(evaluator, "evaluate_batch"): + return "evaluate_batch" + if hasattr(evaluator, "evaluate"): + return "evaluate" + if not callable(evaluator): + raise TypeError("Reward evaluator must be callable or expose evaluate().") + + try: + signature = inspect.signature(evaluator) + except (TypeError, ValueError): + return "legacy" + + positional = 0 + for parameter in signature.parameters.values(): + if parameter.kind == inspect.Parameter.VAR_POSITIONAL: + return "legacy" + if ( + parameter.kind in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ) + and parameter.default is inspect._empty + ): + positional += 1 + return "legacy" if positional >= 2 else "structured" + + def evaluate(self, payload: Dict[str, Any]) -> tuple[float, Optional[Dict[str, float]], Optional[Dict[str, Any]]]: + if self.mode == "none": + result = 0.0 + elif self.mode == "evaluate_batch": + result = self.evaluator.evaluate_batch([payload])[0] + elif self.mode == "evaluate": + result = self.evaluator.evaluate(payload) + elif self.mode == "structured": + result = self.evaluator(payload) + else: + result = self.evaluator(payload["completion_text"], payload["reward_context"]) + return self._normalize_result(result) + + def _normalize_result( + self, + result: Any, + ) -> tuple[float, Optional[Dict[str, float]], Optional[Dict[str, Any]]]: + if isinstance(result, Mapping): + reward = float(result.get("reward", result.get("score", 0.0))) + components = result.get("components") + diagnostics = result.get("diagnostics") + return reward, dict(components) if components is not None else None, diagnostics + return float(result), None, None + + def evaluate_batch( + self, + payloads: Sequence[Dict[str, Any]], + ) -> List[tuple[float, Optional[Dict[str, float]], Optional[Dict[str, Any]]]]: + if self.mode == "none": + return [(0.0, None, None) for _ in payloads] + if hasattr(self.evaluator, "evaluate_batch"): + results = self.evaluator.evaluate_batch(list(payloads)) + return [self._normalize_result(result) for result in results] + return [self.evaluate(payload) for payload in payloads] + + +def _batched_last_token_logits( + policy: Any, + input_rows: Sequence[Sequence[int]], + pad_id: int, + cache_state: Any = None, +) -> Tuple[mx.array, Any]: + if not input_rows: + return mx.zeros((0, 0), dtype=mx.float32), cache_state + + lengths = [len(row) for row in input_rows] + call_signature = None + try: + call_signature = inspect.signature(policy.__call__) + except (TypeError, ValueError): + call_signature = None + use_cache = ( + cache_state is not False + and len(input_rows) == 1 + and len(set(lengths)) == 1 + and ( + hasattr(policy, "forward_with_cache") + or (call_signature is not None and "cache" in call_signature.parameters) + ) + ) + # Batched KV-cache decoding can preserve token choices while still + # producing incorrect per-token logprobs on real Qwen-family models. + # GRPO relies on those sampled logprobs matching a fresh rescore of the + # same completion, so shared cache reuse is limited to single-row decoding. + if use_cache: + prefill = cache_state is None + if prefill: + try: + from mlx_lm.models.cache import make_prompt_cache + + cache_state = make_prompt_cache(getattr(policy, "model", policy)) + except Exception: + cache_state = None + inputs = mx.array(input_rows if prefill else [[row[-1]] for row in input_rows]) + try: + if hasattr(policy, "forward_with_cache"): + outputs = policy.forward_with_cache(inputs, cache=cache_state) + else: + outputs = policy(inputs, cache=cache_state) + if isinstance(outputs, tuple) and len(outputs) == 2: + logits, next_cache = outputs + return logits[:, -1, :], next_cache + if hasattr(outputs, "shape"): + return outputs[:, -1, :], cache_state + except Exception: + cache_state = False + + padded, seq_lengths = pad_sequences(input_rows, pad_id) + logits = policy(padded) + row_positions = mx.arange(padded.shape[0]) + last_positions = seq_lengths - 1 + return logits[row_positions, last_positions, :], cache_state + + +def collect_rollouts( + policy: Any, + tokenizer: Any, + prompt_samples: Sequence[Dict[str, Any]], + sampling_config: Dict[str, Any], + reward_evaluator: Any = None, + collect_sample_stats: bool = False, +) -> RolloutBatch: + num_generations = int(sampling_config.get("num_generations", 1)) + max_completion_length = int(sampling_config.get("max_tokens", sampling_config.get("max_completion_length", 256))) + max_seq_length = sampling_config.get("max_seq_length") + temperature = float(sampling_config.get("temperature", 0.7)) + max_prompt_length = None + if max_seq_length is not None: + # Preserve as much prompt context as possible and spend only the + # remaining budget on completion tokens. Reserving the full requested + # completion cap up front can collapse the prompt to an unusable suffix. + max_prompt_length = max(0, int(max_seq_length) - 1) + + generation_batch_size = int(sampling_config.get("generation_batch_size") or 0) + generation_batch_size = max(1, generation_batch_size or (len(prompt_samples) * max(1, num_generations))) + pad_id = int(getattr(tokenizer, "pad_token_id", 0) or 0) + terminal_token_ids = _resolve_terminal_token_ids(tokenizer) + + expanded_samples: List[Dict[str, Any]] = [] + for group_index, sample in enumerate(prompt_samples): + sample_index = int(sample.get("sample_index", group_index)) + original_prompt_text = sample.get("prompt", sample.get("prompt_text", "")) + reward_context = sample.get("reward_context") + prepared_prompt_ids = truncate_prompt_tokens(sample.get("prompt_ids", []), max_prompt_length) + effective_prompt_text = tokenizer.decode(prepared_prompt_ids) + generation_budget = max_completion_length + if max_seq_length is not None: + generation_budget = min(max_completion_length, max(0, int(max_seq_length) - len(prepared_prompt_ids))) + for _ in range(num_generations): + expanded_samples.append( + { + "prompt_ids": list(prepared_prompt_ids), + "prompt_text": effective_prompt_text, + "original_prompt_text": original_prompt_text, + "reward_context": reward_context, + "sample_index": sample_index, + "prompt_group_index": group_index, + "generated_ids": list(prepared_prompt_ids), + "sampled_logprobs": [], + "sampled_logits": [], + "token_entropies": [], + "eos_flag": False, + "done": generation_budget == 0, + "max_completion_tokens": generation_budget, + "hit_length_limit": generation_budget == 0, + "cache_state": None, + } + ) + + for start in range(0, len(expanded_samples), generation_batch_size): + chunk = expanded_samples[start:start + generation_batch_size] + for _ in range(max_completion_length): + active_rows = [row for row in chunk if not row["done"]] + if not active_rows: + break + + row_logits = [] + for row in active_rows: + logits, row["cache_state"] = _batched_last_token_logits( + policy, + [row["generated_ids"]], + pad_id=pad_id, + cache_state=row.get("cache_state"), + ) + row_logits.append(logits) + logits = mx.concatenate(row_logits, axis=0) if row_logits else mx.zeros((0, 0), dtype=mx.float32) + if temperature > 0: + scaled = logits / temperature + log_probs = nn.log_softmax(scaled, axis=-1) + probs = mx.softmax(scaled, axis=-1) + entropies = -mx.sum(probs * log_probs, axis=-1) + next_tokens = [int(value) for value in mx.random.categorical(log_probs).tolist()] + else: + scaled = logits + log_probs = nn.log_softmax(logits, axis=-1) + probs = mx.softmax(logits, axis=-1) + entropies = -mx.sum(probs * log_probs, axis=-1) + next_tokens = [int(value) for value in mx.argmax(logits, axis=-1).tolist()] + + for row_index, row in enumerate(active_rows): + token_id = next_tokens[row_index] + row["generated_ids"].append(token_id) + row["sampled_logprobs"].append(float(log_probs[row_index, token_id].item())) + if collect_sample_stats: + row["sampled_logits"].append(float(scaled[row_index, token_id].item())) + row["token_entropies"].append(float(entropies[row_index].item())) + if token_id in terminal_token_ids: + row["eos_flag"] = True + row["done"] = True + elif len(row["generated_ids"]) - len(row["prompt_ids"]) >= row["max_completion_tokens"]: + row["done"] = True + row["hit_length_limit"] = True + + prompt_texts: List[str] = [] + original_prompt_texts: List[str] = [] + prompt_ids: List[List[int]] = [] + prompt_lengths: List[int] = [] + completion_ids: List[List[int]] = [] + completion_lengths: List[int] = [] + completion_texts: List[str] = [] + reward_contexts: List[Any] = [] + rollout_logprobs: List[float] = [] + sampled_logprob_rows: List[List[float]] = [] + sampled_logit_rows: List[List[float]] = [] + entropy_rows: List[List[float]] = [] + eos_flags: List[bool] = [] + truncation_flags: List[bool] = [] + prompt_group_indices: List[int] = [] + sample_indices: List[int] = [] + + for row in expanded_samples: + raw_completion_ids = row["generated_ids"][len(row["prompt_ids"]):] + prepared_completion_ids, truncated = truncate_completion_tokens( + raw_completion_ids, + max_completion_length, + ) + sampled_logprobs = row["sampled_logprobs"][: len(prepared_completion_ids)] + sampled_logits = row["sampled_logits"][: len(prepared_completion_ids)] + entropies = row["token_entropies"][: len(prepared_completion_ids)] + + prompt_texts.append(row["prompt_text"]) + original_prompt_texts.append(row["original_prompt_text"]) + prompt_ids.append(list(row["prompt_ids"])) + prompt_lengths.append(len(row["prompt_ids"])) + reward_contexts.append(row["reward_context"]) + completion_ids.append(prepared_completion_ids) + completion_lengths.append(len(prepared_completion_ids)) + completion_texts.append(tokenizer.decode(prepared_completion_ids)) + sampled_logprob_rows.append(sampled_logprobs) + rollout_logprobs.append(sum(sampled_logprobs)) + eos_flags.append(bool(row["eos_flag"])) + truncation_flags.append( + bool(((not row["eos_flag"]) and row.get("hit_length_limit", False)) or truncated) + ) + prompt_group_indices.append(int(row["prompt_group_index"])) + sample_indices.append(int(row["sample_index"])) + if collect_sample_stats: + sampled_logit_rows.append(sampled_logits) + entropy_rows.append(entropies) + + max_completion_width = max(completion_lengths) if completion_lengths else 0 + padded_token_logprobs, _ = pad_sequences( + [ + [float(value) for value in row] + [0.0] * (max_completion_width - len(row)) + for row in sampled_logprob_rows + ] if sampled_logprob_rows else [[0.0]], + 0, + ) + if not sampled_logprob_rows: + padded_token_logprobs = mx.zeros((0, 0), dtype=mx.float32) + else: + padded_token_logprobs = padded_token_logprobs.astype(mx.float32) + + sampled_token_logits = None + token_entropies = None + if collect_sample_stats: + if sampled_logit_rows: + sampled_token_logits, _ = pad_sequences( + [ + [float(value) for value in row] + [0.0] * (max_completion_width - len(row)) + for row in sampled_logit_rows + ], + 0, + ) + token_entropies, _ = pad_sequences( + [ + [float(value) for value in row] + [0.0] * (max_completion_width - len(row)) + for row in entropy_rows + ], + 0, + ) + sampled_token_logits = sampled_token_logits.astype(mx.float32) + token_entropies = token_entropies.astype(mx.float32) + else: + sampled_token_logits = mx.zeros((0, 0), dtype=mx.float32) + token_entropies = mx.zeros((0, 0), dtype=mx.float32) + + full_sequences = [ + prompt_sequence + completion_sequence + for prompt_sequence, completion_sequence in zip(prompt_ids, completion_ids) + ] + aligned_old_token_rows = [ + [0.0] * max(prompt_length - 1, 0) + [float(value) for value in row] + for prompt_length, row in zip(prompt_lengths, sampled_logprob_rows) + ] + aligned_old_token_logprobs, _ = pad_sequences( + aligned_old_token_rows if aligned_old_token_rows else [[0.0]], + 0, + ) + if not aligned_old_token_rows: + aligned_old_token_logprobs = mx.zeros((0, 0), dtype=mx.float32) + else: + aligned_old_token_logprobs = aligned_old_token_logprobs.astype(mx.float32) + policy_eval = make_policy_eval_batch( + full_sequences, + pad_id=pad_id, + mode="completion", + prompt_lengths=prompt_lengths, + completion_lengths=completion_lengths, + rollout_logprobs=mx.array(rollout_logprobs, dtype=mx.float32), + old_logprobs=mx.array(rollout_logprobs, dtype=mx.float32), + old_token_logprobs=aligned_old_token_logprobs, + prompt_group_indices=mx.array(prompt_group_indices), + sample_indices=mx.array(sample_indices), + ) + rollout_batch = RolloutBatch( + prompt_ids=prompt_ids, + prompt_lengths=mx.array(prompt_lengths), + completion_ids=completion_ids, + completion_lengths=mx.array(completion_lengths), + prompt_texts=prompt_texts, + original_prompt_texts=original_prompt_texts, + completion_texts=completion_texts, + reward_contexts=reward_contexts, + sampled_token_logprobs=padded_token_logprobs, + rollout_logprobs=mx.array(rollout_logprobs, dtype=mx.float32), + old_logprobs=mx.array(rollout_logprobs, dtype=mx.float32), + eos_flags=mx.array(eos_flags), + truncation_flags=mx.array(truncation_flags), + prompt_group_indices=mx.array(prompt_group_indices), + policy_eval=policy_eval, + sample_indices=mx.array(sample_indices), + sampled_token_logits=sampled_token_logits, + token_entropies=token_entropies, + ) + if reward_evaluator is not None: + reward_batch = evaluate_rewards(rollout_batch, reward_evaluator) + rollout_batch.rewards = reward_batch.scalar_rewards + return rollout_batch + + +def evaluate_rewards(rollout_batch: RolloutBatch, evaluator: Any) -> RewardBatch: + adapter = _RewardEvaluatorAdapter(evaluator) + payloads: List[Dict[str, Any]] = [] + for index in range(len(rollout_batch.prompt_texts)): + payloads.append( + { + "prompt_text": rollout_batch.prompt_texts[index], + "original_prompt_text": ( + rollout_batch.original_prompt_texts[index] + if rollout_batch.original_prompt_texts is not None + else rollout_batch.prompt_texts[index] + ), + "completion_text": rollout_batch.completion_texts[index], + "reward_context": rollout_batch.reward_contexts[index], + "prompt_ids": list(rollout_batch.prompt_ids[index]), + "completion_ids": list(rollout_batch.completion_ids[index]), + "prompt_length": int(rollout_batch.prompt_lengths[index].item()), + "completion_length": int(rollout_batch.completion_lengths[index].item()), + "eos_flag": bool(rollout_batch.eos_flags[index].item()), + "truncation_flag": bool(rollout_batch.truncation_flags[index].item()), + "prompt_group_index": int(rollout_batch.prompt_group_indices[index].item()), + "sample_index": ( + int(rollout_batch.sample_indices[index].item()) + if rollout_batch.sample_indices is not None + else index + ), + } + ) + + scalar_rewards: List[float] = [] + named_components: List[Dict[str, float]] = [] + diagnostics: List[Dict[str, Any]] = [] + for reward, components, sample_diagnostics in adapter.evaluate_batch(payloads): + scalar_rewards.append(reward) + named_components.append(components or {}) + diagnostics.append(sample_diagnostics or {}) + + has_components = any(component for component in named_components) + has_diagnostics = any(diagnostic for diagnostic in diagnostics) + return RewardBatch( + prompt_texts=list(rollout_batch.prompt_texts), + completion_texts=list(rollout_batch.completion_texts), + reward_contexts=list(rollout_batch.reward_contexts), + scalar_rewards=mx.array(scalar_rewards, dtype=mx.float32), + prompt_group_indices=mx.array(rollout_batch.prompt_group_indices), + original_prompt_texts=list(rollout_batch.original_prompt_texts) + if rollout_batch.original_prompt_texts is not None + else None, + named_reward_components=named_components if has_components else None, + diagnostics=diagnostics if has_diagnostics else None, + ) + + +def compute_advantages( + reward_batch: RewardBatch, + grouping: str = "per_prompt", + normalization: str = "zscore_if_nonzero_else_center", +) -> mx.array: + rewards = reward_batch.scalar_rewards.astype(mx.float32) + if grouping != "per_prompt": + raise ValueError(f"Unsupported advantage grouping: {grouping}") + if normalization != "zscore_if_nonzero_else_center": + raise ValueError(f"Unsupported advantage normalization: {normalization}") + + reward_values = rewards.tolist() + group_values = reward_batch.prompt_group_indices.tolist() + advantages = [0.0] * len(reward_values) + grouped_positions: Dict[int, List[int]] = {} + for position, group_value in enumerate(group_values): + grouped_positions.setdefault(int(group_value), []).append(position) + + for positions in grouped_positions.values(): + group_rewards = mx.array([reward_values[position] for position in positions], dtype=mx.float32) + group_std = mx.std(group_rewards) + if float(group_std.item()) < 1e-6: + group_advantages = group_rewards - mx.mean(group_rewards) + else: + group_advantages = (group_rewards - mx.mean(group_rewards)) / (group_std + 1e-8) + for offset, position in enumerate(positions): + advantages[position] = float(group_advantages[offset].item()) + return mx.array(advantages, dtype=mx.float32) + + +def score_rollout_references( + reference_model: Any, + rollout_batch: RolloutBatch, + batch_size: Optional[int] = 8, + temperature: float = 1.0, + token_budget: Optional[int] = None, +) -> RolloutBatch: + if reference_model is None: + return rollout_batch + + scored = score_policy_in_chunks( + reference_model, + rollout_batch.policy_eval, + batch_size=batch_size, + token_budget=token_budget, + mode="completion", + temperature=temperature, + ) + reference_logprobs = mx.stop_gradient(scored.summed_logprobs.astype(mx.float32)) + return replace( + rollout_batch, + reference_logprobs=reference_logprobs, + policy_eval=replace( + rollout_batch.policy_eval, + reference_logprobs=reference_logprobs, + ), + ) + + +def predict_rollout_values( + value_model: Any, + rollout_batch: RolloutBatch, + batch_size: Optional[int] = 8, + token_budget: Optional[int] = None, +) -> RolloutBatch: + if value_model is None: + return rollout_batch + + value_chunks = [] + for minibatch in assemble_minibatches( + rollout_batch.policy_eval, + minibatch_size=batch_size, + shuffle=False, + mode="completion", + token_budget=token_budget, + ): + value_chunks.append( + value_model.predict( + minibatch.input_ids, + sequence_lengths=minibatch.sequence_lengths, + prompt_lengths=minibatch.prompt_lengths + if minibatch.prompt_lengths is not None + else None, + completion_lengths=minibatch.completion_lengths + if minibatch.completion_lengths is not None + else None, + ) + ) + value_predictions = ( + mx.concatenate(value_chunks, axis=0) if value_chunks else mx.zeros((0,), dtype=mx.float32) + ) + value_predictions = mx.stop_gradient(value_predictions.astype(mx.float32)) + return replace( + rollout_batch, + value_predictions=value_predictions, + policy_eval=replace( + rollout_batch.policy_eval, + value_predictions=value_predictions, + ), + ) + + +def compute_returns_and_advantages( + rewards: mx.array, + values: Optional[mx.array] = None, + prompt_group_indices: Optional[mx.array] = None, + mode: str = "gae", + gamma: float = 1.0, + gae_lambda: float = 1.0, + normalize: bool = False, +) -> tuple[mx.array, mx.array]: + rewards = rewards.astype(mx.float32) + values = mx.zeros_like(rewards) if values is None else values.astype(mx.float32) + + if mode == "gae": + deltas = rewards - values + advantages = deltas * gae_lambda + deltas * (1.0 - gae_lambda) + returns = advantages + values + elif mode == "group_zscore": + if prompt_group_indices is None: + raise ValueError("prompt_group_indices is required for grouped advantages.") + reward_batch = RewardBatch( + prompt_texts=[""] * rewards.shape[0], + completion_texts=[""] * rewards.shape[0], + reward_contexts=[None] * rewards.shape[0], + scalar_rewards=rewards, + prompt_group_indices=prompt_group_indices, + ) + advantages = compute_advantages(reward_batch) + returns = rewards + elif mode == "group_center": + if prompt_group_indices is None: + raise ValueError("prompt_group_indices is required for grouped advantages.") + reward_values = rewards.tolist() + advantages_list = [0.0] * len(reward_values) + grouped_positions: Dict[int, List[int]] = {} + for position, group_value in enumerate(prompt_group_indices.tolist()): + grouped_positions.setdefault(int(group_value), []).append(position) + for positions in grouped_positions.values(): + group_rewards = [reward_values[position] for position in positions] + baseline = sum(group_rewards) / float(len(group_rewards)) + for offset, position in enumerate(positions): + advantages_list[position] = group_rewards[offset] - baseline + advantages = mx.array(advantages_list, dtype=mx.float32) + returns = rewards + elif mode == "rloo": + if prompt_group_indices is None: + raise ValueError("prompt_group_indices is required for RLOO advantages.") + reward_values = rewards.tolist() + advantages_list = [0.0] * len(reward_values) + grouped_positions: Dict[int, List[int]] = {} + for position, group_value in enumerate(prompt_group_indices.tolist()): + grouped_positions.setdefault(int(group_value), []).append(position) + for positions in grouped_positions.values(): + if len(positions) <= 1: + continue + group_rewards = [reward_values[position] for position in positions] + total = sum(group_rewards) + denominator = float(len(positions) - 1) + for offset, position in enumerate(positions): + baseline = (total - group_rewards[offset]) / denominator + advantages_list[position] = group_rewards[offset] - baseline + advantages = mx.array(advantages_list, dtype=mx.float32) + returns = rewards + else: + raise ValueError(f"Unsupported returns/advantages mode: {mode}") + + if normalize and advantages.shape[0] > 1: + advantages = (advantages - mx.mean(advantages)) / (mx.std(advantages) + 1e-8) + + if gamma != 1.0: + returns = rewards + gamma * (returns - rewards) + return returns.astype(mx.float32), advantages.astype(mx.float32) + + +def rank_grouped_rollouts( + rollout_batch: RolloutBatch, + score_tolerance: float = 1e-6, +) -> List[Dict[str, Any]]: + if rollout_batch.rewards is None: + raise ValueError("Rollout rewards are required for grouped ranking.") + + grouped_positions: Dict[int, List[int]] = {} + for position, group_value in enumerate(rollout_batch.prompt_group_indices.tolist()): + grouped_positions.setdefault(int(group_value), []).append(position) + + reward_values = rollout_batch.rewards.tolist() + rankings: List[Dict[str, Any]] = [] + for group_index, positions in grouped_positions.items(): + ordered_positions = sorted(positions, key=lambda position: reward_values[position], reverse=True) + ordered_scores = [reward_values[position] for position in ordered_positions] + rankings.append( + { + "prompt_group_index": group_index, + "positions": positions, + "ordered_positions": ordered_positions, + "scores": ordered_scores, + "best_position": ordered_positions[0], + "worst_position": ordered_positions[-1], + "all_tied": ( + len(ordered_scores) <= 1 + or max(ordered_scores) - min(ordered_scores) <= score_tolerance + ), + } + ) + return rankings + + +def _slice_value(value: Any, indices: Sequence[int]) -> Any: + if value is None: + return None + if hasattr(value, "shape"): + return value[mx.array(indices)] + if isinstance(value, list): + return [value[index] for index in indices] + if isinstance(value, tuple): + return tuple(value[index] for index in indices) + raise TypeError(f"Unsupported minibatch field type: {type(value)!r}") + + +def assemble_minibatches( + batch: PolicyEvalBatch, + minibatch_size: Optional[int], + shuffle: bool = False, + mode: str = "sequence", + token_budget: Optional[int] = None, +) -> Iterable[PolicyEvalBatch]: + batch_size = batch.input_ids.shape[0] + if batch_size == 0: + return [] + + if shuffle: + order = [int(value) for value in mx.random.permutation(batch_size).tolist()] + else: + order = list(range(batch_size)) + + effective_lengths = _policy_eval_effective_lengths(batch, mode) + row_budget = max(1, minibatch_size or batch_size) + token_budget = max(1, token_budget) if token_budget is not None else None + + batches_of_indices: List[List[int]] = [] + current_indices: List[int] = [] + current_tokens = 0 + for index in order: + row_tokens = effective_lengths[index] + exceeds_token_budget = token_budget is not None and current_indices and current_tokens + row_tokens > token_budget + exceeds_row_budget = current_indices and len(current_indices) >= row_budget + if exceeds_token_budget or exceeds_row_budget: + batches_of_indices.append(current_indices) + current_indices = [] + current_tokens = 0 + current_indices.append(index) + current_tokens += row_tokens + if current_indices: + batches_of_indices.append(current_indices) + + minibatches: List[PolicyEvalBatch] = [] + for indices in batches_of_indices: + values = { + field.name: _slice_value(getattr(batch, field.name), indices) + for field in fields(batch) + } + minibatches.append(PolicyEvalBatch(**values)) + return minibatches + + +def summarize_rollout_metrics( + rollout_batch: RolloutBatch, + policy_loss: Optional[float] = None, + value_loss: Optional[float] = None, + reward_loss: Optional[float] = None, +) -> Dict[str, float]: + metrics: Dict[str, float] = {} + if rollout_batch.rewards is not None and rollout_batch.rewards.shape[0] > 0: + metrics["reward_mean"] = float(mx.mean(rollout_batch.rewards).item()) + metrics["reward_std"] = float(mx.std(rollout_batch.rewards).item()) + if rollout_batch.reference_logprobs is not None and rollout_batch.rollout_logprobs.shape[0] > 0: + completion_lengths = mx.maximum(rollout_batch.completion_lengths.astype(mx.float32), 1.0) + log_ratio = ( + rollout_batch.rollout_logprobs.astype(mx.float32) + - rollout_batch.reference_logprobs.astype(mx.float32) + ) + metrics["logprob_delta_mean"] = float(mx.mean(log_ratio).item()) + metrics["logprob_delta_per_token_mean"] = float(mx.mean(log_ratio / completion_lengths).item()) + normalized_policy = normalize_logprobs( + rollout_batch.rollout_logprobs.astype(mx.float32), + completion_lengths, + mode="mean", + ) + normalized_reference = normalize_logprobs( + rollout_batch.reference_logprobs.astype(mx.float32), + completion_lengths, + mode="mean", + ) + kl_values = kl_against_reference( + normalized_policy, + normalized_reference, + ) + kl_mean = float(mx.mean(kl_values).item()) + if kl_mean != float("inf"): + metrics["kl_to_reference_mean"] = kl_mean + if rollout_batch.token_entropies is not None and rollout_batch.completion_lengths.shape[0] > 0: + entropy_mask = length_mask( + rollout_batch.completion_lengths, + rollout_batch.token_entropies.shape[1], + ).astype(rollout_batch.token_entropies.dtype) + valid_tokens = float(mx.sum(entropy_mask).item()) + if valid_tokens > 0: + metrics["entropy_mean"] = float( + (mx.sum(rollout_batch.token_entropies * entropy_mask) / valid_tokens).item() + ) + if rollout_batch.completion_lengths.shape[0] > 0: + metrics["completion_length_mean"] = float(mx.mean(rollout_batch.completion_lengths.astype(mx.float32)).item()) + metrics["completion_length_max"] = float(mx.max(rollout_batch.completion_lengths.astype(mx.float32)).item()) + if rollout_batch.eos_flags is not None and rollout_batch.eos_flags.shape[0] > 0: + metrics["eos_rate"] = float(mx.mean(rollout_batch.eos_flags.astype(mx.float32)).item()) + if rollout_batch.truncation_flags is not None and rollout_batch.truncation_flags.shape[0] > 0: + metrics["truncation_rate"] = float(mx.mean(rollout_batch.truncation_flags.astype(mx.float32)).item()) + if policy_loss is not None: + metrics["policy_loss"] = float(policy_loss) + if value_loss is not None: + metrics["value_loss"] = float(value_loss) + if reward_loss is not None: + metrics["reward_loss"] = float(reward_loss) + return metrics diff --git a/mlx_tune/arithmetic_grpo_validation.py b/mlx_tune/arithmetic_grpo_validation.py new file mode 100644 index 0000000..4436c4b --- /dev/null +++ b/mlx_tune/arithmetic_grpo_validation.py @@ -0,0 +1,786 @@ +""" +Deterministic arithmetic benchmark and GRPO validation for native-thinking Qwen 3 models. +""" + +from __future__ import annotations + +import argparse +import json +import random +import re +from pathlib import Path +from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple + +from mlx_lm.sample_utils import make_sampler + +from mlx_tune import ( + FastLanguageModel, + GRPOConfig, + GRPOTrainer, + get_chat_template, +) + + +DEFAULT_MODEL_NAME = "mlx-community/Qwen3-1.7B-4bit" +DEFAULT_SYSTEM_PROMPT = ( + "You are a careful arithmetic solver. Think freely if useful. " + "Put the final integer answer inside ...." +) +DEFAULT_OUTPUT_DIR = Path("./artifacts/qwen3_arithmetic_grpo_validation") +DEFAULT_RL_SUBDIR = "rl_run" +DEFAULT_TRAIN_SIZE = 3000 +DEFAULT_VAL_SIZE = 300 +DEFAULT_TEST_SIZE = 300 +DEFAULT_SEED = 0 +DEFAULT_MAX_COMPLETION_LENGTH = 512 +DEFAULT_MAX_SEQ_LENGTH = 768 +DEFAULT_LORA_RANK = 16 +DEFAULT_RL_TEMPERATURE = 0.9 +DEFAULT_BASELINE_TEMPERATURE = 0.0 + +_INTEGER_RE = re.compile(r"^[+-]?\d+$") +_SOLUTION_RE = re.compile(r"(.*?)", re.IGNORECASE | re.DOTALL) + + +def _write_json(path: Path, payload: Mapping[str, Any]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w") as handle: + json.dump(payload, handle, indent=2) + + +def _write_jsonl(path: Path, rows: Iterable[Mapping[str, Any]]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w") as handle: + for row in rows: + handle.write(json.dumps(dict(row)) + "\n") + + +def _read_json(path: Path) -> Dict[str, Any]: + with open(path) as handle: + return json.load(handle) + + +def _read_jsonl(path: Path) -> List[Dict[str, Any]]: + rows: List[Dict[str, Any]] = [] + with open(path) as handle: + for line in handle: + line = line.strip() + if line: + rows.append(json.loads(line)) + return rows + + +def _dataset_dir(output_dir: Path) -> Path: + return output_dir / "datasets" + + +def _dataset_path(output_dir: Path, split: str) -> Path: + return _dataset_dir(output_dir) / f"{split}.jsonl" + + +def _rl_output_dir(output_dir: Path) -> Path: + return output_dir / DEFAULT_RL_SUBDIR + + +def _rl_adapter_dir(output_dir: Path) -> Path: + return _rl_output_dir(output_dir) / "policy" + + +def _baseline_outputs_path(output_dir: Path) -> Path: + return output_dir / "baseline_outputs.jsonl" + + +def _baseline_metrics_path(output_dir: Path) -> Path: + return output_dir / "baseline_metrics.json" + + +def _post_outputs_path(output_dir: Path) -> Path: + return output_dir / "post_rl_outputs.jsonl" + + +def _post_metrics_path(output_dir: Path) -> Path: + return output_dir / "post_rl_metrics.json" + + +def _comparison_json_path(output_dir: Path) -> Path: + return output_dir / "comparison.json" + + +def _comparison_md_path(output_dir: Path) -> Path: + return output_dir / "comparison.md" + + +def _training_summary_path(output_dir: Path) -> Path: + return output_dir / "rl_training_summary.json" + + +def _safe_eval(expression: str) -> int: + return int(eval(expression, {"__builtins__": {}}, {})) + + +def _make_easy_expression(rng: random.Random) -> Tuple[str, int]: + left = rng.randint(0, 50) + right = rng.randint(0, 50) + operator = rng.choice(["+", "-"]) + expression = f"{left} {operator} {right}" + return expression, _safe_eval(expression) + + +def _make_medium_expression(rng: random.Random) -> Tuple[str, int]: + values = [rng.randint(0, 20) for _ in range(3)] + operators = [rng.choice(["+", "-", "*"]) for _ in range(2)] + mode = rng.choice(["plain", "left_paren", "right_paren"]) + if mode == "left_paren": + expression = f"({values[0]} {operators[0]} {values[1]}) {operators[1]} {values[2]}" + elif mode == "right_paren": + expression = f"{values[0]} {operators[0]} ({values[1]} {operators[1]} {values[2]})" + else: + expression = f"{values[0]} {operators[0]} {values[1]} {operators[1]} {values[2]}" + return expression, _safe_eval(expression) + + +def build_sample(task_id: str, expression: str, answer: int, difficulty: str) -> Dict[str, Any]: + return { + "task_id": task_id, + "messages": [ + {"role": "system", "content": DEFAULT_SYSTEM_PROMPT}, + {"role": "user", "content": f"Compute exactly:\n{expression}"}, + ], + "answer": str(answer), + "expression": expression, + "difficulty": difficulty, + } + + +def generate_benchmark_splits( + output_dir: Path, + train_size: int = DEFAULT_TRAIN_SIZE, + val_size: int = DEFAULT_VAL_SIZE, + test_size: int = DEFAULT_TEST_SIZE, + seed: int = DEFAULT_SEED, + force: bool = False, +) -> Dict[str, Path]: + output_dir.mkdir(parents=True, exist_ok=True) + dataset_paths = { + "train": _dataset_path(output_dir, "train"), + "val": _dataset_path(output_dir, "val"), + "test": _dataset_path(output_dir, "test"), + } + if not force and all(path.exists() for path in dataset_paths.values()): + return dataset_paths + + rng = random.Random(seed) + existing_expressions = set() + split_sizes = {"train": train_size, "val": val_size, "test": test_size} + + for split_name, split_size in split_sizes.items(): + rows: List[Dict[str, Any]] = [] + while len(rows) < split_size: + difficulty = "easy" if (len(rows) % 2 == 0) else "medium" + if difficulty == "easy": + expression, answer = _make_easy_expression(rng) + else: + expression, answer = _make_medium_expression(rng) + if expression in existing_expressions: + continue + if abs(answer) > 999: + continue + existing_expressions.add(expression) + task_id = f"{split_name}-{len(rows) + 1:06d}" + rows.append(build_sample(task_id, expression, answer, difficulty)) + _write_jsonl(dataset_paths[split_name], rows) + + return dataset_paths + + +def parse_solution_response(response_text: str) -> Dict[str, Any]: + matches = _SOLUTION_RE.findall(response_text) + single_solution_tag = len(matches) == 1 + solution_text = matches[0].strip() if single_solution_tag else None + parseable_solution = bool(solution_text is not None and _INTEGER_RE.fullmatch(solution_text)) + parsed_answer = int(solution_text) if parseable_solution else None + return { + "single_solution_tag": single_solution_tag, + "multiple_solution_tags": len(matches) > 1, + "solution_text": solution_text, + "parseable_solution": parseable_solution, + "parsed_answer": parsed_answer, + } + + +def score_solution_output(response_text: str, gold_answer: str) -> Dict[str, Any]: + parsed = parse_solution_response(response_text) + exact_match = parsed["parseable_solution"] and str(parsed["parsed_answer"]) == str(gold_answer).strip() + tag_reward = 0.1 if parsed["single_solution_tag"] else 0.0 + correctness_reward = 1.0 if exact_match else 0.0 + reward = tag_reward + correctness_reward + return { + **parsed, + "exact_match": exact_match, + "tag_reward": tag_reward, + "correctness_reward": correctness_reward, + "reward": reward, + } + + +class ArithmeticSolutionReward: + def evaluate(self, payload: Mapping[str, Any]) -> Dict[str, Any]: + scored = score_solution_output( + str(payload.get("completion_text", "")), + str(payload.get("reward_context", "")), + ) + return { + "reward": scored["reward"], + "components": { + "solution_tag": scored["tag_reward"], + "correctness": scored["correctness_reward"], + }, + "diagnostics": { + "solution_text": scored["solution_text"], + "parseable_solution": scored["parseable_solution"], + "multiple_solution_tags": scored["multiple_solution_tags"], + "exact_match": scored["exact_match"], + }, + } + + +def _token_length(tokenizer: Any, text: str) -> int: + try: + return len(tokenizer.encode(text, add_special_tokens=False)) + except TypeError: + return len(tokenizer.encode(text)) + + +def load_model_bundle( + model_name: str, + *, + max_seq_length: int, + load_adapter_path: Optional[Path] = None, + for_training: bool = False, + lora_rank: int = DEFAULT_LORA_RANK, +) -> Tuple[Any, Any]: + model, tokenizer = FastLanguageModel.from_pretrained( + model_name=model_name, + max_seq_length=max_seq_length, + load_in_4bit=True, + ) + tokenizer = get_chat_template(tokenizer, chat_template="qwen-3") + if for_training: + model = FastLanguageModel.get_peft_model( + model, + r=lora_rank, + max_seq_length=max_seq_length, + ) + if load_adapter_path is not None and load_adapter_path.exists(): + model.load_adapter(str(load_adapter_path)) + if not for_training: + FastLanguageModel.for_inference(model) + return model, tokenizer + + +def generate_completion( + model: Any, + tokenizer: Any, + messages: Sequence[Mapping[str, Any]], + *, + max_tokens: int, + temperature: float, +) -> str: + prompt = tokenizer.apply_chat_template( + list(messages), + add_generation_prompt=True, + tokenize=False, + ) + return model.generate( + prompt=prompt, + max_tokens=max_tokens, + sampler=make_sampler(temp=temperature), + verbose=False, + ) + + +def _load_split_records(output_dir: Path, split: str) -> List[Dict[str, Any]]: + path = _dataset_path(output_dir, split) + if not path.exists(): + raise FileNotFoundError(f"Missing dataset split: {path}") + return _read_jsonl(path) + + +def _metrics_from_rows( + rows: Sequence[Mapping[str, Any]], + *, + split: str, + model_name: str, + seed: int, +) -> Dict[str, Any]: + count = len(rows) + if count == 0: + return { + "split": split, + "model_name": model_name, + "seed": seed, + "num_samples": 0, + "exact_match": 0.0, + "solution_tag_rate": 0.0, + "parseable_solution_rate": 0.0, + "multiple_solution_tag_rate": 0.0, + "avg_completion_tokens": 0.0, + "avg_reward": 0.0, + } + return { + "split": split, + "model_name": model_name, + "seed": seed, + "num_samples": count, + "exact_match": sum(1.0 for row in rows if row["exact_match"]) / count, + "solution_tag_rate": sum(1.0 for row in rows if row["single_solution_tag"]) / count, + "parseable_solution_rate": sum(1.0 for row in rows if row["parseable_solution"]) / count, + "multiple_solution_tag_rate": sum(1.0 for row in rows if row["multiple_solution_tags"]) / count, + "avg_completion_tokens": sum(float(row["completion_tokens"]) for row in rows) / count, + "avg_reward": sum(float(row["reward"]) for row in rows) / count, + } + + +def _aggregate_metrics( + per_split: Mapping[str, Mapping[str, Any]], + *, + model_name: str, + seed: int, +) -> Dict[str, Any]: + all_rows = sum(int(metrics["num_samples"]) for metrics in per_split.values()) + if all_rows == 0: + return _metrics_from_rows([], split="aggregate", model_name=model_name, seed=seed) + weighted = {} + for field in ( + "exact_match", + "solution_tag_rate", + "parseable_solution_rate", + "multiple_solution_tag_rate", + "avg_completion_tokens", + "avg_reward", + ): + weighted[field] = sum( + float(metrics[field]) * int(metrics["num_samples"]) for metrics in per_split.values() + ) / all_rows + return { + "split": "aggregate", + "model_name": model_name, + "seed": seed, + "num_samples": all_rows, + **weighted, + } + + +def evaluate_records( + model: Any, + tokenizer: Any, + rows: Sequence[Mapping[str, Any]], + *, + split: str, + model_name: str, + seed: int, + max_tokens: int, + temperature: float, +) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: + evaluated: List[Dict[str, Any]] = [] + for row in rows: + response_text = generate_completion( + model, + tokenizer, + row["messages"], + max_tokens=max_tokens, + temperature=temperature, + ) + scored = score_solution_output(response_text, str(row["answer"])) + evaluated.append( + { + "task_id": row["task_id"], + "split": split, + "expression": row["expression"], + "difficulty": row["difficulty"], + "answer": row["answer"], + "completion_text": response_text, + "completion_tokens": _token_length(tokenizer, response_text), + **scored, + } + ) + return evaluated, _metrics_from_rows(evaluated, split=split, model_name=model_name, seed=seed) + + +def _run_eval( + model: Any, + tokenizer: Any, + output_dir: Path, + *, + model_name: str, + seed: int, + output_path: Path, + metrics_path: Path, + max_tokens: int, + temperature: float, +) -> Dict[str, Any]: + all_outputs: List[Dict[str, Any]] = [] + per_split: Dict[str, Dict[str, Any]] = {} + for split in ("val", "test"): + rows = _load_split_records(output_dir, split) + split_outputs, split_metrics = evaluate_records( + model, + tokenizer, + rows, + split=split, + model_name=model_name, + seed=seed, + max_tokens=max_tokens, + temperature=temperature, + ) + all_outputs.extend(split_outputs) + per_split[split] = split_metrics + _write_jsonl(output_path, all_outputs) + payload = { + "model_name": model_name, + "seed": seed, + "splits": per_split, + "aggregate": _aggregate_metrics(per_split, model_name=model_name, seed=seed), + } + _write_json(metrics_path, payload) + return payload + + +def run_baseline( + output_dir: Path, + *, + model_name: str, + seed: int, + max_seq_length: int, + max_completion_length: int, +) -> Dict[str, Any]: + model, tokenizer = load_model_bundle( + model_name, + max_seq_length=max_seq_length, + ) + return _run_eval( + model, + tokenizer, + output_dir, + model_name=model_name, + seed=seed, + output_path=_baseline_outputs_path(output_dir), + metrics_path=_baseline_metrics_path(output_dir), + max_tokens=max_completion_length, + temperature=DEFAULT_BASELINE_TEMPERATURE, + ) + + +def run_training( + output_dir: Path, + *, + model_name: str, + seed: int, + max_seq_length: int, + max_completion_length: int, + max_steps: int, + learning_rate: float, + per_device_train_batch_size: int, + rollout_batch_size: int, + num_generations: int, + rl_temperature: float, + lora_rank: int, + logging_steps: int, + eval_steps: int, + save_steps: int, +) -> Dict[str, Any]: + train_rows = _load_split_records(output_dir, "train") + val_rows = _load_split_records(output_dir, "val") + model, tokenizer = load_model_bundle( + model_name, + max_seq_length=max_seq_length, + for_training=True, + lora_rank=lora_rank, + ) + reward = ArithmeticSolutionReward() + config = GRPOConfig( + loss_type="grpo", + learning_rate=learning_rate, + per_device_train_batch_size=per_device_train_batch_size, + rollout_batch_size=rollout_batch_size, + num_generations=num_generations, + temperature=rl_temperature, + max_steps=max_steps, + max_seq_length=max_seq_length, + max_completion_length=max_completion_length, + logging_steps=logging_steps, + eval_steps=eval_steps, + save_steps=save_steps, + reward_source="online", + reward_normalization="none", + mask_truncated_completions=True, + output_dir=str(_rl_output_dir(output_dir)), + seed=seed, + ) + trainer = GRPOTrainer( + model=model, + train_dataset=train_rows, + eval_dataset=val_rows, + tokenizer=tokenizer, + reward_fn=reward, + args=config, + ) + result = trainer.train() + _write_json(_training_summary_path(output_dir), result) + FastLanguageModel.for_inference(model) + post_metrics = _run_eval( + model, + tokenizer, + output_dir, + model_name=model_name, + seed=seed, + output_path=_post_outputs_path(output_dir), + metrics_path=_post_metrics_path(output_dir), + max_tokens=max_completion_length, + temperature=DEFAULT_BASELINE_TEMPERATURE, + ) + return {"training": result, "post_eval": post_metrics} + + +def _metric_delta(baseline: Mapping[str, Any], post: Mapping[str, Any]) -> Dict[str, float]: + return { + key: float(post[key]) - float(baseline[key]) + for key in ( + "exact_match", + "solution_tag_rate", + "parseable_solution_rate", + "multiple_solution_tag_rate", + "avg_completion_tokens", + "avg_reward", + ) + } + + +def _load_eval_rows(path: Path) -> Dict[Tuple[str, str], Dict[str, Any]]: + return { + (row["split"], row["task_id"]): row + for row in _read_jsonl(path) + } + + +def _pair_examples( + baseline_rows: Mapping[Tuple[str, str], Mapping[str, Any]], + post_rows: Mapping[Tuple[str, str], Mapping[str, Any]], +) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]: + improvements: List[Dict[str, Any]] = [] + regressions: List[Dict[str, Any]] = [] + unchanged_failures: List[Dict[str, Any]] = [] + for key in sorted(set(baseline_rows) & set(post_rows)): + baseline = baseline_rows[key] + post = post_rows[key] + pair = { + "split": key[0], + "task_id": key[1], + "expression": baseline["expression"], + "answer": baseline["answer"], + "baseline_completion_text": baseline["completion_text"], + "post_completion_text": post["completion_text"], + "baseline_reward": baseline["reward"], + "post_reward": post["reward"], + "baseline_exact_match": baseline["exact_match"], + "post_exact_match": post["exact_match"], + "baseline_single_solution_tag": baseline["single_solution_tag"], + "post_single_solution_tag": post["single_solution_tag"], + } + if float(post["reward"]) > float(baseline["reward"]): + improvements.append(pair) + elif float(post["reward"]) < float(baseline["reward"]): + regressions.append(pair) + elif not baseline["exact_match"] and not post["exact_match"]: + unchanged_failures.append(pair) + return improvements, regressions, unchanged_failures + + +def _render_examples_md(title: str, rows: Sequence[Mapping[str, Any]]) -> List[str]: + lines = [f"### {title}", ""] + if not rows: + lines.append("_None_") + lines.append("") + return lines + for row in rows: + lines.extend( + [ + f"- `{row['split']}/{row['task_id']}` `{row['expression']}` -> `{row['answer']}`", + f" baseline: {row['baseline_completion_text']}", + f" post_rl: {row['post_completion_text']}", + "", + ] + ) + return lines + + +def run_compare(output_dir: Path) -> Dict[str, Any]: + baseline_metrics = _read_json(_baseline_metrics_path(output_dir)) + post_metrics = _read_json(_post_metrics_path(output_dir)) + comparison = { + "baseline": baseline_metrics, + "post_rl": post_metrics, + "delta": { + split: _metric_delta(baseline_metrics["splits"][split], post_metrics["splits"][split]) + for split in ("val", "test") + }, + "aggregate_delta": _metric_delta( + baseline_metrics["aggregate"], + post_metrics["aggregate"], + ), + } + baseline_rows = _load_eval_rows(_baseline_outputs_path(output_dir)) + post_rows = _load_eval_rows(_post_outputs_path(output_dir)) + improvements, regressions, unchanged_failures = _pair_examples(baseline_rows, post_rows) + comparison["sample_counts"] = { + "improvements": len(improvements), + "regressions": len(regressions), + "unchanged_failures": len(unchanged_failures), + } + _write_json(_comparison_json_path(output_dir), comparison) + + lines = [ + "# Qwen3 Arithmetic GRPO Comparison", + "", + "| Split | Baseline Exact | Post Exact | Delta Exact | Baseline Reward | Post Reward | Delta Reward | Baseline Tag | Post Tag | Delta Tag |", + "| --- | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: |", + ] + for split in ("val", "test"): + baseline = baseline_metrics["splits"][split] + post = post_metrics["splits"][split] + delta = comparison["delta"][split] + lines.append( + "| {split} | {b_exact:.4f} | {p_exact:.4f} | {d_exact:+.4f} | " + "{b_reward:.4f} | {p_reward:.4f} | {d_reward:+.4f} | " + "{b_tag:.4f} | {p_tag:.4f} | {d_tag:+.4f} |".format( + split=split, + b_exact=baseline["exact_match"], + p_exact=post["exact_match"], + d_exact=delta["exact_match"], + b_reward=baseline["avg_reward"], + p_reward=post["avg_reward"], + d_reward=delta["avg_reward"], + b_tag=baseline["solution_tag_rate"], + p_tag=post["solution_tag_rate"], + d_tag=delta["solution_tag_rate"], + ) + ) + aggregate_delta = comparison["aggregate_delta"] + lines.extend( + [ + "", + "## Conclusion", + "", + ( + "GRPO appears to work on this benchmark." + if aggregate_delta["avg_reward"] > 0 or aggregate_delta["solution_tag_rate"] > 0 + else "GRPO did not show a positive held-out signal on this run." + ), + "", + ] + ) + lines.extend(_render_examples_md("Improvements", improvements[:10])) + lines.extend(_render_examples_md("Regressions", regressions[:5])) + lines.extend(_render_examples_md("Unchanged Failures", unchanged_failures[:5])) + _comparison_md_path(output_dir).write_text("\n".join(lines)) + return comparison + + +def _parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("mode", choices=["generate", "baseline", "train", "compare", "all"]) + parser.add_argument("--output-dir", type=Path, default=DEFAULT_OUTPUT_DIR) + parser.add_argument("--model-name", default=DEFAULT_MODEL_NAME) + parser.add_argument("--train-size", type=int, default=DEFAULT_TRAIN_SIZE) + parser.add_argument("--val-size", type=int, default=DEFAULT_VAL_SIZE) + parser.add_argument("--test-size", type=int, default=DEFAULT_TEST_SIZE) + parser.add_argument("--seed", type=int, default=DEFAULT_SEED) + parser.add_argument("--force-generate", action="store_true") + parser.add_argument("--max-completion-length", type=int, default=DEFAULT_MAX_COMPLETION_LENGTH) + parser.add_argument("--max-seq-length", type=int, default=DEFAULT_MAX_SEQ_LENGTH) + parser.add_argument("--max-steps", type=int, default=500) + parser.add_argument("--learning-rate", type=float, default=1e-6) + parser.add_argument("--per-device-train-batch-size", type=int, default=2) + parser.add_argument("--rollout-batch-size", type=int, default=4) + parser.add_argument("--num-generations", type=int, default=4) + parser.add_argument("--rl-temperature", type=float, default=DEFAULT_RL_TEMPERATURE) + parser.add_argument("--lora-rank", type=int, default=DEFAULT_LORA_RANK) + parser.add_argument("--logging-steps", type=int, default=10) + parser.add_argument("--eval-steps", type=int, default=50) + parser.add_argument("--save-steps", type=int, default=100) + return parser.parse_args(argv) + + +def main(argv: Optional[Sequence[str]] = None) -> int: + args = _parse_args(argv) + generate_benchmark_splits( + args.output_dir, + train_size=args.train_size, + val_size=args.val_size, + test_size=args.test_size, + seed=args.seed, + force=args.force_generate, + ) + if args.mode == "generate": + return 0 + if args.mode == "baseline": + run_baseline( + args.output_dir, + model_name=args.model_name, + seed=args.seed, + max_seq_length=args.max_seq_length, + max_completion_length=args.max_completion_length, + ) + return 0 + if args.mode == "train": + run_training( + args.output_dir, + model_name=args.model_name, + seed=args.seed, + max_seq_length=args.max_seq_length, + max_completion_length=args.max_completion_length, + max_steps=args.max_steps, + learning_rate=args.learning_rate, + per_device_train_batch_size=args.per_device_train_batch_size, + rollout_batch_size=args.rollout_batch_size, + num_generations=args.num_generations, + rl_temperature=args.rl_temperature, + lora_rank=args.lora_rank, + logging_steps=args.logging_steps, + eval_steps=args.eval_steps, + save_steps=args.save_steps, + ) + return 0 + if args.mode == "compare": + run_compare(args.output_dir) + return 0 + run_baseline( + args.output_dir, + model_name=args.model_name, + seed=args.seed, + max_seq_length=args.max_seq_length, + max_completion_length=args.max_completion_length, + ) + run_training( + args.output_dir, + model_name=args.model_name, + seed=args.seed, + max_seq_length=args.max_seq_length, + max_completion_length=args.max_completion_length, + max_steps=args.max_steps, + learning_rate=args.learning_rate, + per_device_train_batch_size=args.per_device_train_batch_size, + rollout_batch_size=args.rollout_batch_size, + num_generations=args.num_generations, + rl_temperature=args.rl_temperature, + lora_rank=args.lora_rank, + logging_steps=args.logging_steps, + eval_steps=args.eval_steps, + save_steps=args.save_steps, + ) + run_compare(args.output_dir) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/mlx_tune/losses.py b/mlx_tune/losses.py index 5625cd0..219e194 100644 --- a/mlx_tune/losses.py +++ b/mlx_tune/losses.py @@ -1,7 +1,7 @@ """ Loss functions for MLX-Tune RL training. -Provides proper loss implementations for: +Provides native MLX losses and reference-logprob helpers for: - DPO (Direct Preference Optimization) - ORPO (Odds Ratio Preference Optimization) - GRPO (Group Relative Policy Optimization) @@ -9,10 +9,77 @@ - SimPO (Simple Preference Optimization) """ -from typing import Optional, Tuple, Callable, List, Any +from typing import Any, Callable, Dict, List, Optional, Tuple + import mlx.core as mx import mlx.nn as nn +from mlx_tune._rl_runtime import ( + PolicyEvalBatch, + build_token_mask, + collect_rollouts, + completion_token_mask, + compute_advantages, + evaluate_rewards, + kl_against_reference, + length_mask, + make_policy_eval_batch, + normalize_logprobs, + sample_completion, + score_policy, + score_policy_in_chunks, +) + + +def _policy_eval_from_padded( + input_ids: mx.array, + lengths: mx.array, + mode: str = "sequence", + prompt_lengths: Optional[mx.array] = None, + completion_lengths: Optional[mx.array] = None, + rollout_logprobs: Optional[mx.array] = None, + old_logprobs: Optional[mx.array] = None, + old_token_logprobs: Optional[mx.array] = None, + reference_logprobs: Optional[mx.array] = None, + value_predictions: Optional[mx.array] = None, + returns: Optional[mx.array] = None, + advantages: Optional[mx.array] = None, + labels: Optional[mx.array] = None, +) -> PolicyEvalBatch: + return PolicyEvalBatch( + input_ids=input_ids, + sequence_lengths=lengths, + prompt_lengths=prompt_lengths, + completion_lengths=completion_lengths, + token_mask=build_token_mask( + input_ids=input_ids, + sequence_lengths=lengths, + mode=mode, + prompt_lengths=prompt_lengths, + completion_lengths=completion_lengths, + ), + rollout_logprobs=rollout_logprobs, + old_logprobs=old_logprobs if old_logprobs is not None else rollout_logprobs, + old_token_logprobs=old_token_logprobs, + reference_logprobs=reference_logprobs, + value_predictions=value_predictions, + returns=returns, + advantages=advantages, + labels=labels, + ) + + +def _token_log_probs( + model: Any, + input_ids: mx.array, + temperature: float = 1.0, +) -> mx.array: + batch = _policy_eval_from_padded( + input_ids=input_ids, + lengths=mx.array([input_ids.shape[1]] * input_ids.shape[0]), + ) + return score_policy(model, batch, mode="sequence", temperature=temperature).token_logprobs + def compute_log_probs( model: Any, @@ -20,48 +87,12 @@ def compute_log_probs( attention_mask: Optional[mx.array] = None, ) -> mx.array: """ - Compute per-token log probabilities for a batch of sequences. - - Args: - model: The language model. - input_ids: Token IDs of shape [batch_size, seq_len]. - attention_mask: Optional mask of shape [batch_size, seq_len]. - - Returns: - Log probabilities of shape [batch_size] (sum over sequence). + Compute per-sequence log probabilities for a batch of sequences. """ - # Get inputs (all tokens except last) and targets (all tokens except first) - inputs = input_ids[:, :-1] - targets = input_ids[:, 1:] - - # Forward pass to get logits - logits = model(inputs) # [batch_size, seq_len-1, vocab_size] - - # Compute log softmax to get log probabilities - log_probs = nn.log_softmax(logits, axis=-1) # [batch_size, seq_len-1, vocab_size] - - # Gather log probs for the actual target tokens - # targets: [batch_size, seq_len-1] - # We need to get log_probs[b, t, targets[b, t]] for each position - batch_size, seq_len = targets.shape - - # Use advanced indexing to gather target log probs - target_log_probs = mx.take_along_axis( - log_probs, - targets[:, :, None], # [batch_size, seq_len-1, 1] - axis=-1 - ).squeeze(-1) # [batch_size, seq_len-1] - - # Apply attention mask if provided + token_log_probs = _token_log_probs(model, input_ids) if attention_mask is not None: - # Shift mask to match targets - mask = attention_mask[:, 1:] - target_log_probs = target_log_probs * mask - - # Sum log probs over sequence to get sequence log probability - sequence_log_probs = target_log_probs.sum(axis=-1) # [batch_size] - - return sequence_log_probs + token_log_probs = token_log_probs * attention_mask[:, 1:].astype(token_log_probs.dtype) + return token_log_probs.sum(axis=-1) def compute_log_probs_with_lengths( @@ -70,38 +101,70 @@ def compute_log_probs_with_lengths( lengths: mx.array, ) -> mx.array: """ - Compute per-token log probabilities with explicit length masking. + Compute per-sequence log probabilities with explicit length masking. + """ + batch = _policy_eval_from_padded(input_ids=input_ids, lengths=lengths, mode="sequence") + return score_policy(model, batch, mode="sequence").summed_logprobs - Args: - model: The language model. - input_ids: Token IDs of shape [batch_size, seq_len]. - lengths: Sequence lengths of shape [batch_size]. - Returns: - Log probabilities of shape [batch_size] (sum over valid tokens). +def compute_completion_log_probs( + model: Any, + input_ids: mx.array, + prompt_lengths: mx.array, + completion_lengths: mx.array, + temperature: float = 1.0, +) -> mx.array: """ - inputs = input_ids[:, :-1] - targets = input_ids[:, 1:] - - logits = model(inputs) - log_probs = nn.log_softmax(logits, axis=-1) + Compute log probabilities over completion tokens only. + """ + sequence_lengths = prompt_lengths + completion_lengths + batch = _policy_eval_from_padded( + input_ids=input_ids, + lengths=sequence_lengths, + mode="completion", + prompt_lengths=prompt_lengths, + completion_lengths=completion_lengths, + ) + return score_policy(model, batch, mode="completion", temperature=temperature).summed_logprobs + + +def _batched_sequence_log_probs( + model: Any, + input_ids: mx.array, + lengths: mx.array, + batch_size: int = 8, +) -> mx.array: + batch = _policy_eval_from_padded(input_ids=input_ids, lengths=lengths, mode="sequence") + return score_policy_in_chunks(model, batch, batch_size=batch_size, mode="sequence").summed_logprobs - target_log_probs = mx.take_along_axis( - log_probs, - targets[:, :, None], - axis=-1 - ).squeeze(-1) - # Create mask from lengths - seq_len = targets.shape[1] - positions = mx.arange(seq_len)[None, :] # [1, seq_len] - mask = positions < lengths[:, None] # [batch_size, seq_len] +def precompute_preference_reference_logprobs( + model: Any, + chosen_ids: mx.array, + rejected_ids: mx.array, + chosen_lengths: mx.array, + rejected_lengths: mx.array, + batch_size: int = 8, +) -> Tuple[mx.array, mx.array]: + """ + Precompute frozen-reference log probabilities for preference pairs. + """ + ref_chosen = _batched_sequence_log_probs(model, chosen_ids, chosen_lengths, batch_size) + ref_rejected = _batched_sequence_log_probs(model, rejected_ids, rejected_lengths, batch_size) + return mx.stop_gradient(ref_chosen), mx.stop_gradient(ref_rejected) - # Apply mask and sum - masked_log_probs = target_log_probs * mask.astype(target_log_probs.dtype) - sequence_log_probs = masked_log_probs.sum(axis=-1) - return sequence_log_probs +def precompute_kto_reference_logprobs( + model: Any, + input_ids: mx.array, + lengths: mx.array, + batch_size: int = 8, +) -> mx.array: + """ + Precompute frozen-reference log probabilities for KTO samples. + """ + ref = _batched_sequence_log_probs(model, input_ids, lengths, batch_size) + return mx.stop_gradient(ref) def dpo_loss( @@ -116,60 +179,51 @@ def dpo_loss( label_smoothing: float = 0.0, ) -> Tuple[mx.array, mx.array]: """ - Compute DPO (Direct Preference Optimization) loss. - - DPO Loss: -log(sigmoid(beta * (log_ratio_chosen - log_ratio_rejected))) - - Where: - log_ratio = log_pi(y|x) - log_ref(y|x) - - Args: - model: The policy model being trained. - chosen_ids: Token IDs for chosen responses [batch_size, seq_len]. - rejected_ids: Token IDs for rejected responses [batch_size, seq_len]. - chosen_lengths: Lengths of chosen sequences [batch_size]. - rejected_lengths: Lengths of rejected sequences [batch_size]. - beta: KL penalty coefficient (temperature). - reference_chosen_logprobs: Pre-computed reference log probs for chosen. - reference_rejected_logprobs: Pre-computed reference log probs for rejected. - label_smoothing: Label smoothing coefficient. - - Returns: - Tuple of (loss, num_tokens). + Compute DPO loss. """ - # Compute policy model log probabilities - log_pi_chosen = compute_log_probs_with_lengths(model, chosen_ids, chosen_lengths) - log_pi_rejected = compute_log_probs_with_lengths(model, rejected_ids, rejected_lengths) - - # Handle reference model log probabilities - if reference_chosen_logprobs is None or reference_rejected_logprobs is None: - # Use current model with stop_gradient as reference (memory efficient) - log_ref_chosen = mx.stop_gradient(log_pi_chosen) - log_ref_rejected = mx.stop_gradient(log_pi_rejected) - else: - log_ref_chosen = reference_chosen_logprobs - log_ref_rejected = reference_rejected_logprobs - - # Compute log ratios - log_ratio_chosen = log_pi_chosen - log_ref_chosen - log_ratio_rejected = log_pi_rejected - log_ref_rejected - - # DPO loss: -log(sigmoid(beta * (log_ratio_chosen - log_ratio_rejected))) - logits = beta * (log_ratio_chosen - log_ratio_rejected) - + chosen_batch = score_policy( + model, + _policy_eval_from_padded( + input_ids=chosen_ids, + lengths=chosen_lengths, + mode="sequence", + reference_logprobs=reference_chosen_logprobs, + ), + mode="sequence", + ) + rejected_batch = score_policy( + model, + _policy_eval_from_padded( + input_ids=rejected_ids, + lengths=rejected_lengths, + mode="sequence", + reference_logprobs=reference_rejected_logprobs, + ), + mode="sequence", + ) + + log_pi_chosen = chosen_batch.summed_logprobs + log_pi_rejected = rejected_batch.summed_logprobs + log_ref_chosen = ( + mx.stop_gradient(log_pi_chosen) + if chosen_batch.reference_logprobs is None + else chosen_batch.reference_logprobs + ) + log_ref_rejected = ( + mx.stop_gradient(log_pi_rejected) + if rejected_batch.reference_logprobs is None + else rejected_batch.reference_logprobs + ) + + logits = beta * ((log_pi_chosen - log_ref_chosen) - (log_pi_rejected - log_ref_rejected)) if label_smoothing > 0: - # Smooth the labels losses = ( -nn.log_sigmoid(logits) * (1 - label_smoothing) - nn.log_sigmoid(-logits) * label_smoothing ) else: losses = -nn.log_sigmoid(logits) - - loss = mx.mean(losses) - ntoks = chosen_lengths.sum() + rejected_lengths.sum() - - return loss, ntoks + return mx.mean(losses), chosen_lengths.sum() + rejected_lengths.sum() def orpo_loss( @@ -181,101 +235,51 @@ def orpo_loss( beta: float = 0.1, ) -> Tuple[mx.array, mx.array]: """ - Compute ORPO (Odds Ratio Preference Optimization) loss. - - ORPO combines SFT loss with odds ratio preference loss: - L = L_SFT + beta * L_OR - - Where: - L_SFT = -log P(chosen) - L_OR = -log(sigmoid(log(odds_ratio))) - odds_ratio = P(chosen) / P(rejected) - - Args: - model: The model being trained. - chosen_ids: Token IDs for chosen responses. - rejected_ids: Token IDs for rejected responses. - chosen_lengths: Lengths of chosen sequences. - rejected_lengths: Lengths of rejected sequences. - beta: Weight for odds ratio loss. - - Returns: - Tuple of (loss, num_tokens). + Compute ORPO loss. """ - # Compute log probabilities log_pi_chosen = compute_log_probs_with_lengths(model, chosen_ids, chosen_lengths) log_pi_rejected = compute_log_probs_with_lengths(model, rejected_ids, rejected_lengths) + avg_log_pi_chosen = normalize_logprobs(log_pi_chosen, chosen_lengths, mode="mean") - # SFT loss on chosen (negative log likelihood) - # Normalize by length for fair comparison - avg_log_pi_chosen = log_pi_chosen / chosen_lengths.astype(log_pi_chosen.dtype) - sft_loss = -mx.mean(avg_log_pi_chosen) - - # Odds ratio loss - # log(odds_ratio) = log(P_chosen) - log(P_rejected) - log_odds = log_pi_chosen - log_pi_rejected - or_loss = -mx.mean(nn.log_sigmoid(log_odds)) - - # Combined loss - loss = sft_loss + beta * or_loss - - ntoks = chosen_lengths.sum() + rejected_lengths.sum() - return loss, ntoks + sft_term = -mx.mean(avg_log_pi_chosen) + odds_term = -mx.mean(nn.log_sigmoid(log_pi_chosen - log_pi_rejected)) + loss = sft_term + beta * odds_term + return loss, chosen_lengths.sum() + rejected_lengths.sum() def kto_loss( model: Any, input_ids: mx.array, lengths: mx.array, - labels: mx.array, # 1 for positive, 0 for negative + labels: mx.array, beta: float = 0.1, reference_logprobs: Optional[mx.array] = None, ) -> Tuple[mx.array, mx.array]: """ - Compute KTO (Kahneman-Tversky Optimization) loss. - - KTO uses prospect theory with asymmetric treatment of gains and losses: - L = -E[w(y) * log(sigmoid(beta * log_ratio))] - - Where w(y) = lambda if y is positive, 1 if y is negative. - - Args: - model: The model being trained. - input_ids: Token IDs [batch_size, seq_len]. - lengths: Sequence lengths [batch_size]. - labels: Binary labels (1=positive, 0=negative) [batch_size]. - beta: Temperature coefficient. - reference_logprobs: Pre-computed reference log probs. - - Returns: - Tuple of (loss, num_tokens). + Compute KTO loss. """ - # Compute policy log probs - log_pi = compute_log_probs_with_lengths(model, input_ids, lengths) - - # Handle reference - if reference_logprobs is None: - log_ref = mx.stop_gradient(log_pi) - else: - log_ref = reference_logprobs - + batch = score_policy( + model, + _policy_eval_from_padded( + input_ids=input_ids, + lengths=lengths, + mode="sequence", + reference_logprobs=reference_logprobs, + labels=labels, + ), + mode="sequence", + ) + log_pi = batch.summed_logprobs + log_ref = mx.stop_gradient(log_pi) if batch.reference_logprobs is None else batch.reference_logprobs log_ratio = log_pi - log_ref - # KTO weights (lambda for positive, 1 for negative) - lambda_weight = 1.0 # Can be tuned - weights = mx.where(labels > 0.5, lambda_weight, 1.0) - - # Loss with asymmetric weights positive_mask = labels > 0.5 negative_mask = ~positive_mask - + weights = mx.where(positive_mask, 1.0, 1.0) positive_loss = -nn.log_sigmoid(beta * log_ratio) * positive_mask negative_loss = -nn.log_sigmoid(-beta * log_ratio) * negative_mask - loss = mx.mean(weights * (positive_loss + negative_loss)) - ntoks = lengths.sum() - - return loss, ntoks + return loss, lengths.sum() def simpo_loss( @@ -288,39 +292,138 @@ def simpo_loss( gamma: float = 0.5, ) -> Tuple[mx.array, mx.array]: """ - Compute SimPO (Simple Preference Optimization) loss. + Compute SimPO loss. + """ + log_pi_chosen = compute_log_probs_with_lengths(model, chosen_ids, chosen_lengths) + log_pi_rejected = compute_log_probs_with_lengths(model, rejected_ids, rejected_lengths) + r_chosen = normalize_logprobs(log_pi_chosen, chosen_lengths, mode="mean") + r_rejected = normalize_logprobs(log_pi_rejected, rejected_lengths, mode="mean") + + logits = beta * (r_chosen - r_rejected - gamma) + return -mx.mean(nn.log_sigmoid(logits)), chosen_lengths.sum() + rejected_lengths.sum() - SimPO simplifies DPO by removing the need for a reference model: - L = -log(sigmoid(beta * (r_chosen - r_rejected - gamma))) - Where r = log P(y|x) / |y| (length-normalized log prob). +def pairwise_reward_loss( + chosen_scores: mx.array, + rejected_scores: mx.array, + margin: float = 0.0, +) -> mx.array: + """ + Logistic pairwise preference loss over scalar reward scores. + """ + return -mx.mean(nn.log_sigmoid(chosen_scores - rejected_scores - margin)) + + +def reward_model_pairwise_loss( + reward_model: Any, + chosen_input_ids: mx.array, + rejected_input_ids: mx.array, + chosen_sequence_lengths: mx.array, + rejected_sequence_lengths: mx.array, + chosen_prompt_lengths: Optional[mx.array] = None, + rejected_prompt_lengths: Optional[mx.array] = None, + chosen_completion_lengths: Optional[mx.array] = None, + rejected_completion_lengths: Optional[mx.array] = None, + margin: float = 0.0, +) -> Tuple[mx.array, Dict[str, mx.array]]: + """ + Compute pairwise reward-model loss for chosen/rejected sequences. + """ + chosen_scores, rejected_scores = reward_model.score_pairs( + chosen_input_ids=chosen_input_ids, + rejected_input_ids=rejected_input_ids, + chosen_sequence_lengths=chosen_sequence_lengths, + rejected_sequence_lengths=rejected_sequence_lengths, + chosen_prompt_lengths=chosen_prompt_lengths, + rejected_prompt_lengths=rejected_prompt_lengths, + chosen_completion_lengths=chosen_completion_lengths, + rejected_completion_lengths=rejected_completion_lengths, + ) + loss = pairwise_reward_loss(chosen_scores, rejected_scores, margin=margin) + return loss, { + "chosen_scores": chosen_scores, + "rejected_scores": rejected_scores, + } + + +def value_regression_loss( + predictions: mx.array, + targets: mx.array, + loss_type: str = "mse", +) -> mx.array: + """ + Compute a pointwise scalar regression loss. + """ + if loss_type == "mse": + return mx.mean((predictions - targets) ** 2) + if loss_type == "mae": + return mx.mean(mx.abs(predictions - targets)) + raise ValueError(f"Unsupported scalar regression loss: {loss_type}") - Args: - model: The model being trained. - chosen_ids: Token IDs for chosen responses. - rejected_ids: Token IDs for rejected responses. - chosen_lengths: Lengths of chosen sequences. - rejected_lengths: Lengths of rejected sequences. - beta: Temperature coefficient. - gamma: Target reward margin. - Returns: - Tuple of (loss, num_tokens). +def value_model_regression_loss( + value_model: Any, + input_ids: mx.array, + sequence_lengths: mx.array, + targets: mx.array, + prompt_lengths: Optional[mx.array] = None, + completion_lengths: Optional[mx.array] = None, + loss_type: str = "mse", +) -> Tuple[mx.array, mx.array]: """ - # Compute log probabilities - log_pi_chosen = compute_log_probs_with_lengths(model, chosen_ids, chosen_lengths) - log_pi_rejected = compute_log_probs_with_lengths(model, rejected_ids, rejected_lengths) + Compute pointwise regression loss for a scalar value model. + """ + predictions = value_model.predict( + input_ids, + sequence_lengths=sequence_lengths, + prompt_lengths=prompt_lengths, + completion_lengths=completion_lengths, + ) + return value_regression_loss(predictions, targets, loss_type=loss_type), predictions - # Length-normalize to get "reward" - r_chosen = log_pi_chosen / chosen_lengths.astype(log_pi_chosen.dtype) - r_rejected = log_pi_rejected / rejected_lengths.astype(log_pi_rejected.dtype) - # SimPO loss - logits = beta * (r_chosen - r_rejected - gamma) - loss = -mx.mean(nn.log_sigmoid(logits)) +def reward_model_regression_loss( + reward_model: Any, + input_ids: mx.array, + sequence_lengths: mx.array, + targets: mx.array, + prompt_lengths: Optional[mx.array] = None, + completion_lengths: Optional[mx.array] = None, + loss_type: str = "mse", +) -> Tuple[mx.array, mx.array]: + """ + Compute pointwise regression loss for a scalar reward model. + """ + predictions = reward_model.score( + input_ids, + sequence_lengths=sequence_lengths, + prompt_lengths=prompt_lengths, + completion_lengths=completion_lengths, + ) + return value_regression_loss(predictions, targets, loss_type=loss_type), predictions + - ntoks = chosen_lengths.sum() + rejected_lengths.sum() - return loss, ntoks +def scalar_loss_metrics(loss: mx.array, predictions: mx.array, targets: mx.array) -> Dict[str, float]: + """ + Compute generic scalar regression metrics. + """ + mae = mx.mean(mx.abs(predictions - targets)) + mse = mx.mean((predictions - targets) ** 2) + return { + "loss": float(loss.item()), + "mae": float(mae.item()), + "mse": float(mse.item()), + } + + +def pairwise_ranking_accuracy( + chosen_scores: mx.array, + rejected_scores: mx.array, +) -> float: + """ + Compute pairwise ranking accuracy for scalar preference scores. + """ + return float(mx.mean((chosen_scores > rejected_scores).astype(mx.float32)).item()) def sft_loss( @@ -329,37 +432,71 @@ def sft_loss( lengths: mx.array, ) -> Tuple[mx.array, mx.array]: """ - Standard Supervised Fine-Tuning (cross-entropy) loss. - - Args: - model: The model being trained. - input_ids: Token IDs [batch_size, seq_len]. - lengths: Sequence lengths [batch_size]. - - Returns: - Tuple of (loss, num_tokens). + Standard supervised fine-tuning loss. """ inputs = input_ids[:, :-1] targets = input_ids[:, 1:] - logits = model(inputs) - # Create length mask - seq_len = targets.shape[1] - positions = mx.arange(seq_len)[None, :] - mask = positions < lengths[:, None] - - # Cross entropy loss - ce = nn.losses.cross_entropy(logits, targets, reduction='none') - masked_ce = ce * mask.astype(ce.dtype) - + mask = length_mask(lengths, targets.shape[1]).astype(logits.dtype) + ce = nn.losses.cross_entropy(logits, targets, reduction="none") + masked_ce = ce * mask ntoks = mask.sum() - loss = masked_ce.sum() / ntoks + return masked_ce.sum() / ntoks, ntoks - return loss, ntoks +def ppo_sequence_loss( + model: Any, + batch: PolicyEvalBatch, + beta: float = 0.0, + clip_epsilon: float = 0.2, + temperature: float = 1.0, + reference_model: Optional[Any] = None, +) -> Tuple[mx.array, Dict[str, mx.array]]: + """ + Compute a clipped PPO objective over full sampled completions. + """ + scored_batch = score_policy( + model, + batch, + mode="completion", + reference_model=reference_model if batch.reference_logprobs is None else None, + temperature=temperature, + ) + old_logprobs = batch.old_logprobs if batch.old_logprobs is not None else batch.rollout_logprobs + if old_logprobs is None: + raise ValueError("PPO loss requires stored old log probabilities.") + if batch.advantages is None: + raise ValueError("PPO loss requires advantages.") + + ratios = mx.exp(scored_batch.summed_logprobs - old_logprobs) + clipped_ratios = mx.clip(ratios, 1.0 - clip_epsilon, 1.0 + clip_epsilon) + unclipped_objective = ratios * batch.advantages + clipped_objective = clipped_ratios * batch.advantages + policy_objective = mx.minimum(unclipped_objective, clipped_objective) + + kl_penalty = mx.zeros_like(policy_objective) + if scored_batch.reference_logprobs is not None: + kl_scored_batch = scored_batch + if temperature != 1.0: + kl_scored_batch = score_policy( + model, + batch, + mode="completion", + reference_model=reference_model if batch.reference_logprobs is None else None, + temperature=1.0, + ) + kl_penalty = kl_against_reference( + kl_scored_batch.summed_logprobs, + kl_scored_batch.reference_logprobs, + ) + loss = -mx.mean(policy_objective - beta * kl_penalty) + return loss, { + "policy_logprobs": scored_batch.summed_logprobs, + "ratios": ratios, + "kl_penalty": kl_penalty, + } -# GRPO-specific functions def generate_with_log_probs( model: Any, @@ -369,57 +506,136 @@ def generate_with_log_probs( temperature: float = 0.7, ) -> Tuple[mx.array, mx.array]: """ - Generate a completion and return token IDs with their log probabilities. - - Args: - model: The language model. - tokenizer: The tokenizer. - prompt_ids: Prompt token IDs [seq_len]. - max_tokens: Maximum tokens to generate. - temperature: Sampling temperature. - - Returns: - Tuple of (generated_ids, log_probs) where: - generated_ids: [prompt_len + gen_len] - log_probs: [gen_len] log probability of each generated token + Generate a sampled completion and return sampled-token log probabilities. """ - generated_ids = list(prompt_ids.tolist()) if hasattr(prompt_ids, 'tolist') else list(prompt_ids) - log_probs = [] - - # Current sequence - x = mx.array([generated_ids]) - - for _ in range(max_tokens): - # Get logits for next token - logits = model(x)[:, -1, :] # [1, vocab_size] - - # Apply temperature - if temperature > 0: - logits = logits / temperature - probs = mx.softmax(logits, axis=-1) - # Sample from categorical distribution - next_token = mx.random.categorical(mx.log(probs + 1e-10)) - else: - # Greedy decoding - next_token = mx.argmax(logits, axis=-1) - - next_token_id = next_token.item() - - # Get log probability of sampled token - log_prob = nn.log_softmax(logits, axis=-1)[0, next_token_id] - log_probs.append(log_prob) - - # Append to sequence - generated_ids.append(next_token_id) - - # Check for EOS - if hasattr(tokenizer, 'eos_token_id') and next_token_id == tokenizer.eos_token_id: - break - - # Update input sequence - x = mx.array([generated_ids]) - - return mx.array(generated_ids), mx.stack(log_probs) + generated = sample_completion( + policy=model, + tokenizer=tokenizer, + prompt_ids=prompt_ids.tolist() if hasattr(prompt_ids, "tolist") else list(prompt_ids), + max_tokens=max_tokens, + temperature=temperature, + collect_sample_stats=False, + ) + token_logprobs = generated["sampled_logprobs"] + if token_logprobs: + return mx.array(generated["generated_ids"]), mx.array(token_logprobs, dtype=mx.float32) + return mx.array(generated["generated_ids"]), mx.zeros((0,), dtype=mx.float32) + + +def grpo_recompute_loss( + model: Any, + reference_model: Any, + input_ids: mx.array, + prompt_lengths: mx.array, + completion_lengths: mx.array, + rollout_logprobs: mx.array, + advantages: mx.array, + beta: float = 0.04, + clip_epsilon: float = 0.2, + epsilon_low: Optional[float] = None, + epsilon_high: Optional[float] = None, + temperature: float = 1.0, + loss_type: str = "grpo", + max_completion_length: Optional[int] = None, + old_token_logprobs: Optional[mx.array] = None, + reference_logprobs: Optional[mx.array] = None, +) -> Tuple[mx.array, mx.array]: + """ + Recompute GRPO-family losses on fixed sampled completions. + """ + epsilon_low = clip_epsilon if epsilon_low is None else epsilon_low + epsilon_high = clip_epsilon if epsilon_high is None else epsilon_high + batch = _policy_eval_from_padded( + input_ids=input_ids, + lengths=prompt_lengths + completion_lengths, + mode="completion", + prompt_lengths=prompt_lengths, + completion_lengths=completion_lengths, + rollout_logprobs=rollout_logprobs, + old_logprobs=rollout_logprobs, + old_token_logprobs=old_token_logprobs, + reference_logprobs=reference_logprobs, + advantages=advantages, + ) + scored_batch = score_policy( + model, + batch, + mode="completion", + reference_model=reference_model if reference_logprobs is None else None, + temperature=temperature, + ) + + sequence_logprobs = scored_batch.summed_logprobs + normalized_current = normalize_logprobs(sequence_logprobs, completion_lengths, mode="mean") + kl_scored_batch = scored_batch + if beta != 0.0 and temperature != 1.0: + kl_scored_batch = score_policy( + model, + batch, + mode="completion", + reference_model=reference_model if reference_logprobs is None else None, + temperature=1.0, + ) + if kl_scored_batch.reference_logprobs is not None: + normalized_reference = normalize_logprobs( + kl_scored_batch.reference_logprobs, + completion_lengths, + mode="mean", + ) + normalized_kl_current = normalize_logprobs( + kl_scored_batch.summed_logprobs, + completion_lengths, + mode="mean", + ) + sequence_kl_penalty = kl_against_reference(normalized_kl_current, normalized_reference) + else: + sequence_kl_penalty = mx.zeros_like(advantages) + + if loss_type == "gspo": + sequence_ratios = mx.exp(normalized_current - normalize_logprobs(rollout_logprobs, completion_lengths, mode="mean")) + clipped_sequence_ratios = mx.clip(sequence_ratios, 1.0 - epsilon_low, 1.0 + epsilon_high) + policy_objective = mx.minimum(sequence_ratios * advantages, clipped_sequence_ratios * advantages) + loss = -mx.mean(policy_objective - beta * sequence_kl_penalty) + return loss, completion_lengths.sum() + + if old_token_logprobs is None: + sequence_ratios = mx.exp(sequence_logprobs - rollout_logprobs) + clipped_sequence_ratios = mx.clip(sequence_ratios, 1.0 - epsilon_low, 1.0 + epsilon_high) + policy_objective = mx.minimum( + sequence_ratios * advantages, + clipped_sequence_ratios * advantages, + ) + loss = -mx.mean(policy_objective - beta * sequence_kl_penalty) + return loss, completion_lengths.sum() + + mask = completion_token_mask(input_ids, prompt_lengths, completion_lengths).astype(mx.float32) + current_token_logprobs = scored_batch.token_logprobs + token_ratios = mx.exp(current_token_logprobs - old_token_logprobs) + clipped_token_ratios = mx.clip(token_ratios, 1.0 - epsilon_low, 1.0 + epsilon_high) + per_token_objective = mx.minimum( + token_ratios * advantages[:, None], + clipped_token_ratios * advantages[:, None], + ) + masked_objective = per_token_objective * mask + + if loss_type == "grpo": + policy_objective = masked_objective.sum(axis=-1) / mx.maximum(mask.sum(axis=-1), 1.0) + loss = -mx.mean(policy_objective - beta * sequence_kl_penalty) + elif loss_type == "bnpo" or loss_type == "dapo": + normalizer = mx.maximum(mask.sum(), 1.0) + loss = -( + masked_objective.sum() / normalizer + - beta * mx.mean(sequence_kl_penalty) + ) + elif loss_type == "dr_grpo": + denominator = float(max_completion_length or int(mx.max(completion_lengths).item()) or 1) + loss = -( + masked_objective.sum() / max(completion_lengths.shape[0] * denominator, 1.0) + - beta * mx.mean(sequence_kl_penalty) + ) + else: + raise ValueError(f"Unsupported GRPO loss_type: {loss_type}") + return loss, completion_lengths.sum() def grpo_loss( @@ -434,65 +650,29 @@ def grpo_loss( beta: float = 0.04, ) -> Tuple[mx.array, int]: """ - Compute GRPO (Group Relative Policy Optimization) loss for a single prompt. - - GRPO: - 1. Generates multiple completions for each prompt - 2. Computes rewards for each completion - 3. Uses group statistics for advantage estimation - 4. Computes policy gradient loss - - Args: - model: The policy model. - tokenizer: The tokenizer. - prompt_ids: Prompt token IDs. - reward_fn: Function(completion, prompt) -> reward. - prompt_text: Original prompt text for reward computation. - num_generations: Number of completions to generate per prompt. - temperature: Sampling temperature. - max_tokens: Maximum tokens per completion. - beta: KL penalty coefficient. - - Returns: - Tuple of (loss, num_completions). - """ - completions = [] - all_log_probs = [] - - # Generate multiple completions - for _ in range(num_generations): - gen_ids, log_probs = generate_with_log_probs( - model, tokenizer, prompt_ids, - max_tokens=max_tokens, - temperature=temperature, - ) - - # Decode completion (skip prompt) - prompt_len = len(prompt_ids) - completion_ids = gen_ids[prompt_len:] - completion_text = tokenizer.decode(completion_ids.tolist()) - - completions.append(completion_text) - all_log_probs.append(log_probs.sum()) # Sum log probs - - # Compute rewards - rewards = [] - for completion in completions: - reward = reward_fn(completion, prompt_text) - rewards.append(reward) - - rewards = mx.array(rewards) - log_probs_tensor = mx.stack(all_log_probs) - - # Compute advantages using group statistics - mean_reward = mx.mean(rewards) - std_reward = mx.std(rewards) + 1e-8 - advantages = (rewards - mean_reward) / std_reward - - # Policy gradient loss: -E[advantage * log_prob] - # We want to increase prob of high-advantage completions - pg_loss = -mx.mean(advantages * log_probs_tensor) - + Legacy log-only GRPO loss retained for compatibility. + """ + del beta + rollout_batch = collect_rollouts( + policy=model, + tokenizer=tokenizer, + prompt_samples=[ + { + "sample_index": 0, + "prompt": prompt_text, + "prompt_ids": prompt_ids.tolist() if hasattr(prompt_ids, "tolist") else list(prompt_ids), + "reward_context": prompt_text, + } + ], + sampling_config={ + "num_generations": num_generations, + "temperature": temperature, + "max_tokens": max_tokens, + }, + ) + reward_batch = evaluate_rewards(rollout_batch, reward_fn) + advantages = compute_advantages(reward_batch) + pg_loss = -mx.mean(advantages * rollout_batch.rollout_logprobs) return pg_loss, num_generations @@ -507,47 +687,33 @@ def grpo_batch_loss( beta: float = 0.04, ) -> Tuple[mx.array, int]: """ - Compute GRPO loss for a batch of prompts. - - Args: - model: The policy model. - tokenizer: The tokenizer. - prompts: List of prompt strings. - reward_fn: Reward function. - num_generations: Completions per prompt. - temperature: Sampling temperature. - max_tokens: Max tokens per completion. - beta: KL coefficient. - - Returns: - Tuple of (average_loss, total_completions). - """ - losses = [] - total_completions = 0 - - for prompt in prompts: - prompt_ids = mx.array(tokenizer.encode(prompt)) - - loss, n_comp = grpo_loss( - model=model, - tokenizer=tokenizer, - prompt_ids=prompt_ids, - reward_fn=reward_fn, - prompt_text=prompt, - num_generations=num_generations, - temperature=temperature, - max_tokens=max_tokens, - beta=beta, - ) - - losses.append(loss) - total_completions += n_comp - - avg_loss = mx.mean(mx.stack(losses)) - return avg_loss, total_completions - + Legacy batched GRPO loss retained for compatibility. + """ + del beta + prompt_samples = [ + { + "sample_index": index, + "prompt": prompt, + "prompt_ids": tokenizer.encode(prompt), + "reward_context": prompt, + } + for index, prompt in enumerate(prompts) + ] + rollout_batch = collect_rollouts( + policy=model, + tokenizer=tokenizer, + prompt_samples=prompt_samples, + sampling_config={ + "num_generations": num_generations, + "temperature": temperature, + "max_tokens": max_tokens, + }, + ) + reward_batch = evaluate_rewards(rollout_batch, reward_fn) + advantages = compute_advantages(reward_batch) + pg_loss = -mx.mean(advantages * rollout_batch.rollout_logprobs) + return pg_loss, len(prompt_samples) * num_generations -# Utility function for batched DPO def compute_reference_logprobs( model: Any, @@ -557,22 +723,12 @@ def compute_reference_logprobs( rejected_lengths: mx.array, ) -> Tuple[mx.array, mx.array]: """ - Compute reference log probabilities (for frozen reference model). - - Call this once before training to get reference logprobs, - then pass them to dpo_loss to avoid recomputation. - - Args: - model: The reference model (should be frozen/not updated). - chosen_ids: Chosen sequence token IDs. - rejected_ids: Rejected sequence token IDs. - chosen_lengths: Chosen sequence lengths. - rejected_lengths: Rejected sequence lengths. - - Returns: - Tuple of (ref_chosen_logprobs, ref_rejected_logprobs). + Backwards-compatible alias for batched DPO reference precompute. """ - ref_chosen = compute_log_probs_with_lengths(model, chosen_ids, chosen_lengths) - ref_rejected = compute_log_probs_with_lengths(model, rejected_ids, rejected_lengths) - - return mx.stop_gradient(ref_chosen), mx.stop_gradient(ref_rejected) + return precompute_preference_reference_logprobs( + model, + chosen_ids, + rejected_ids, + chosen_lengths, + rejected_lengths, + ) diff --git a/mlx_tune/model.py b/mlx_tune/model.py index 80026ca..15f0812 100644 --- a/mlx_tune/model.py +++ b/mlx_tune/model.py @@ -5,9 +5,14 @@ using Apple's MLX framework under the hood. """ -from typing import Optional, Tuple, Union, List, Any, Dict +from dataclasses import dataclass +import json +from typing import Optional, Tuple, Union, List, Any, Dict, Mapping, Sequence from pathlib import Path +import copy import mlx.core as mx +import mlx.nn as nn +from mlx.utils import tree_flatten, tree_unflatten from mlx_lm import load as mlx_load import warnings @@ -128,7 +133,13 @@ def from_pretrained( try: # Load model using MLX (with config for saving later) - model, tokenizer, config = mlx_load(model_name, return_config=True, **mlx_kwargs) + try: + model, tokenizer, config = mlx_load(model_name, return_config=True, **mlx_kwargs) + except TypeError as exc: + if "return_config" not in str(exc): + raise + model, tokenizer = mlx_load(model_name, **mlx_kwargs) + config = None # Wrap model with our compatibility layer wrapped_model = MLXModelWrapper( @@ -489,6 +500,12 @@ def set_adapter_path(self, path: str) -> None: """ self._adapter_path = Path(path) + def has_adapters(self) -> bool: + """ + Return whether this wrapper currently tracks adapter state. + """ + return bool(self._lora_applied or self._adapter_path is not None) + def get_adapter_path(self) -> Optional[Path]: """ Get the current adapter path. @@ -498,6 +515,142 @@ def get_adapter_path(self) -> Optional[Path]: """ return self._adapter_path + def clone( + self, + freeze: bool = False, + snapshot_adapters: bool = True, + copy_adapter_path: bool = False, + ) -> "MLXModelWrapper": + """ + Deep-clone the wrapped model and optionally freeze the clone. + """ + clone = MLXModelWrapper( + model=copy.deepcopy(self.model), + tokenizer=self.tokenizer, + max_seq_length=self.max_seq_length, + model_name=self.model_name, + config=copy.deepcopy(self.config), + ) + clone.lora_config = copy.deepcopy(self.lora_config) + clone.lora_enabled = self.lora_enabled + clone._lora_applied = self._lora_applied + clone._adapter_path = ( + Path(self._adapter_path) if copy_adapter_path and self._adapter_path is not None else None + ) + clone.inference_mode = self.inference_mode + clone.use_cache = self.use_cache + + source_actual = self.model + clone_actual = clone.model + clone_actual.update(source_actual.parameters(), strict=False) + if snapshot_adapters and self.has_adapters(): + clone.load_adapter_state(self.snapshot_adapter_state(), strict=False) + mx.eval(clone_actual.parameters()) + + if freeze: + clone.freeze_parameters() + return clone + + def freeze_parameters(self) -> None: + """ + Freeze all parameters on the wrapped model. + """ + if hasattr(self.model, "freeze"): + self.model.freeze() + mx.eval(self.model.parameters()) + + def snapshot_adapter_state(self) -> Dict[str, mx.array]: + """ + Capture the current trainable adapter parameter state as a flat tree. + """ + if not self.has_adapters(): + return {} + return { + name: mx.array(value) + for name, value in tree_flatten(self.model.trainable_parameters()) + } + + def load_adapter_state( + self, + adapter_state: Mapping[str, mx.array], + strict: bool = False, + ) -> None: + """ + Restore adapter parameters from a flat tree. + """ + if not adapter_state: + return + self.model.update(tree_unflatten(list(adapter_state.items())), strict=strict) + mx.eval(self.model.parameters()) + + def build_adapter_config(self) -> Dict[str, Any]: + """ + Build the mlx_lm-compatible adapter configuration for the current LoRA setup. + """ + num_layers = None + if hasattr(self.model, "layers"): + num_layers = len(self.model.layers) + elif hasattr(self.model, "model") and hasattr(self.model.model, "layers"): + num_layers = len(self.model.model.layers) + + lora_config = self.lora_config.copy() if self.lora_config else {} + r = lora_config.get("r", 16) + alpha = lora_config.get("lora_alpha", 16) + adapter_config = { + "fine_tune_type": "lora", + "num_layers": num_layers, + "lora_parameters": { + "rank": r, + "scale": alpha / r, + "dropout": lora_config.get("lora_dropout", 0.0), + }, + } + + target_modules = lora_config.get("target_modules", []) + if target_modules: + short_to_full = { + "q_proj": "self_attn.q_proj", + "k_proj": "self_attn.k_proj", + "v_proj": "self_attn.v_proj", + "o_proj": "self_attn.o_proj", + "gate_proj": "mlp.gate_proj", + "up_proj": "mlp.up_proj", + "down_proj": "mlp.down_proj", + } + adapter_config["lora_parameters"]["keys"] = [ + short_to_full.get(module, module) for module in target_modules + ] + return adapter_config + + def save_adapter_snapshot(self, output_dir: str) -> bool: + """ + Persist the current adapter state in mlx_lm's adapter directory layout. + """ + if not self.has_adapters(): + return False + + output_path = Path(output_dir) + output_path.mkdir(parents=True, exist_ok=True) + mx.save_safetensors( + str(output_path / "adapters.safetensors"), + self.snapshot_adapter_state(), + ) + with open(output_path / "adapter_config.json", "w") as handle: + json.dump(self.build_adapter_config(), handle, indent=2) + self._adapter_path = output_path + return True + + def load_adapter_snapshot(self, adapter_path: str, strict: bool = False) -> None: + """ + Load adapter state from a role/checkpoint directory. + """ + adapter_dir = Path(adapter_path) + adapter_file = adapter_dir / "adapters.safetensors" + if not adapter_file.exists(): + raise FileNotFoundError(f"Missing adapters.safetensors under {adapter_dir}") + self.load_adapter_state(mx.load(str(adapter_file)), strict=strict) + self._adapter_path = adapter_dir + def enable_inference_mode(self, use_cache: bool = True): """ Enable inference mode optimizations. @@ -788,17 +941,556 @@ def save_pretrained_gguf( **kwargs ) - def __call__(self, *args, **kwargs): + def forward_with_cache(self, *args, cache=None, **kwargs): + """ + Forward pass with optional cache support for autoregressive decoding. + """ + if cache is None: + return self.model(*args, **kwargs) + return self.model(*args, cache=cache, **kwargs) + + def __call__(self, *args, cache=None, **kwargs): """ Forward pass through the model. Note: This is a simplified interface. For training, use MLX's training utilities directly. """ - return self.model(*args, **kwargs) + return self.forward_with_cache(*args, cache=cache, **kwargs) def __getattr__(self, name): """ Delegate attribute access to the underlying MLX model. """ return getattr(self.model, name) + + +class ReferencePolicy: + """ + Frozen reference-policy wrapper used by native RL trainers. + + A reference policy can either wrap an explicit reference model provided by + the caller or snapshot the current policy into a detached, frozen model + instance before RL optimization starts. + """ + + def __init__( + self, + model: Any, + source: str, + metadata: Optional[Dict[str, Any]] = None, + ): + self.model = model + self.source = source + self.metadata = metadata or {} + self._freeze() + + @classmethod + def from_model( + cls, + policy_model: Any, + ref_model: Optional[Any] = None, + ) -> "ReferencePolicy": + return build_reference_policy(policy_model, ref_model=ref_model, snapshot=True) + + @staticmethod + def _unwrap(model: Any) -> Any: + return model.model if hasattr(model, "model") else model + + @classmethod + def _snapshot_model(cls, model: Any) -> Any: + if isinstance(model, MLXModelWrapper): + snapshot = MLXModelWrapper( + model=copy.deepcopy(model.model), + tokenizer=model.tokenizer, + max_seq_length=model.max_seq_length, + model_name=model.model_name, + config=copy.deepcopy(model.config), + ) + snapshot.lora_config = copy.deepcopy(model.lora_config) + snapshot.lora_enabled = model.lora_enabled + snapshot._lora_applied = model._lora_applied + snapshot._adapter_path = model._adapter_path + snapshot.inference_mode = model.inference_mode + snapshot.use_cache = model.use_cache + else: + snapshot = copy.deepcopy(model) + + source_actual = cls._unwrap(model) + snapshot_actual = cls._unwrap(snapshot) + snapshot_actual.update(source_actual.parameters()) + mx.eval(snapshot_actual.parameters()) + return snapshot + + def _freeze(self) -> None: + actual_model = self._unwrap(self.model) + if hasattr(actual_model, "freeze"): + actual_model.freeze() + mx.eval(actual_model.parameters()) + + +def _actual_model(model: Any) -> Any: + return model.model if hasattr(model, "model") else model + + +def _clone_role_model(model: Any, freeze: bool = False) -> Any: + if isinstance(model, MLXModelWrapper): + return model.clone(freeze=freeze, snapshot_adapters=True, copy_adapter_path=False) + + snapshot = copy.deepcopy(model) + source_actual = _actual_model(model) + snapshot_actual = _actual_model(snapshot) + if hasattr(snapshot_actual, "update"): + snapshot_actual.update(source_actual.parameters(), strict=False) + mx.eval(snapshot_actual.parameters()) + if freeze and hasattr(snapshot_actual, "freeze"): + snapshot_actual.freeze() + mx.eval(snapshot_actual.parameters()) + if hasattr(snapshot, "_adapter_path"): + snapshot._adapter_path = None + return snapshot + + +def build_reference_policy( + policy_model: Any, + ref_model: Optional[Any] = None, + snapshot: bool = True, +) -> ReferencePolicy: + """ + Build an explicit frozen reference policy for RL training. + """ + source_model = ref_model if ref_model is not None else policy_model + if snapshot: + role_model = _clone_role_model(source_model, freeze=True) + source = "explicit_snapshot" if ref_model is not None else "policy_snapshot" + strategy = "clone_and_freeze" + else: + role_model = source_model + actual_model = _actual_model(role_model) + if hasattr(actual_model, "freeze"): + actual_model.freeze() + mx.eval(actual_model.parameters()) + source = "explicit_live" if ref_model is not None else "policy_live" + strategy = "freeze_in_place" + + metadata = { + "source": source, + "snapshot_strategy": strategy, + "model_name": getattr(source_model, "model_name", None), + "adapter_path": ( + str(source_model.get_adapter_path()) + if hasattr(source_model, "get_adapter_path") and source_model.get_adapter_path() is not None + else None + ), + } + return ReferencePolicy(role_model, source=source, metadata=metadata) + + +def _infer_hidden_size(module: Any) -> int: + if hasattr(module, "args"): + for attr in ("hidden_size", "dim"): + value = getattr(module.args, attr, None) + if value is not None: + return int(value) + for attr in ("hidden_size", "n_embd", "dim"): + value = getattr(module, attr, None) + if value is not None: + return int(value) + if hasattr(module, "embed_tokens") and hasattr(module.embed_tokens, "weight"): + return int(module.embed_tokens.weight.shape[-1]) + if hasattr(module, "embedding") and hasattr(module.embedding, "weight"): + return int(module.embedding.weight.shape[-1]) + raise ValueError("Could not infer backbone hidden size for scalar head construction.") + + +def _resolve_hidden_backbone(model: Any) -> Any: + actual_model = _actual_model(model) + hidden_backbone = getattr(actual_model, "model", None) + if callable(hidden_backbone): + return hidden_backbone + if callable(actual_model): + return actual_model + raise ValueError("Scalar-head roles require a callable causal-LM backbone.") + + +def _pad_sequence_batch( + sequences: Sequence[Sequence[int]], + pad_id: int, +) -> Tuple[mx.array, mx.array]: + max_length = max(len(sequence) for sequence in sequences) + padded = [list(sequence) + [pad_id] * (max_length - len(sequence)) for sequence in sequences] + lengths = [len(sequence) for sequence in sequences] + return mx.array(padded), mx.array(lengths) + + +def _flat_parameter_state(model: Any) -> Dict[str, mx.array]: + actual_model = _actual_model(model) + return {name: mx.array(value) for name, value in tree_flatten(actual_model.parameters())} + + +def _load_flat_parameter_state( + model: Any, + parameter_state: Mapping[str, mx.array], + strict: bool = False, +) -> None: + if not parameter_state: + return + actual_model = _actual_model(model) + actual_model.update(tree_unflatten(list(parameter_state.items())), strict=strict) + mx.eval(actual_model.parameters()) + + +@dataclass +class RLModelRoles: + policy_model: Any + reference_policy: ReferencePolicy + reward_model: Optional["RewardModel"] = None + value_model: Optional["ValueModel"] = None + + def as_dict(self) -> Dict[str, Any]: + return { + "policy": self.policy_model, + "reference": self.reference_policy, + "reward_model": self.reward_model, + "value_model": self.value_model, + } + + +class ScalarHeadModel: + role_name = "scalar_model" + + def __init__( + self, + base_model: Any, + pooling: str = "last_token", + target: str = "completion", + head: Optional[nn.Linear] = None, + head_config: Optional[Dict[str, Any]] = None, + ): + self.base_model = base_model + self.tokenizer = getattr(base_model, "tokenizer", None) + self.pooling = pooling + self.target = target + backbone_module = _resolve_hidden_backbone(base_model) + self._hidden_backbone = backbone_module + hidden_size = _infer_hidden_size(backbone_module) + self.head = head or nn.Linear(hidden_size, 1) + self.head_config = { + "role": self.role_name, + "pooling": pooling, + "target": target, + "hidden_size": hidden_size, + } + if head_config: + self.head_config.update(head_config) + self.pooling = self.head_config.get("pooling", self.pooling) + self.target = self.head_config.get("target", self.target) + + @classmethod + def from_pretrained(cls, base_model: Any, output_dir: str) -> "ScalarHeadModel": + output_path = Path(output_dir) + with open(output_path / "head_config.json") as handle: + head_config = json.load(handle) + instance = cls( + base_model=base_model, + pooling=head_config.get("pooling", "last_token"), + target=head_config.get("target", "completion"), + head_config=head_config, + ) + instance.load_pretrained(output_dir) + return instance + + def _normalize_batch_inputs( + self, + input_ids: Union[mx.array, Sequence[Sequence[int]]], + sequence_lengths: Optional[Union[mx.array, Sequence[int]]] = None, + prompt_lengths: Optional[Union[mx.array, Sequence[int]]] = None, + completion_lengths: Optional[Union[mx.array, Sequence[int]]] = None, + ) -> Tuple[mx.array, mx.array, Optional[mx.array], Optional[mx.array]]: + if hasattr(input_ids, "shape"): + array_input_ids = input_ids + if sequence_lengths is None: + raise ValueError("sequence_lengths is required when input_ids is already padded.") + array_sequence_lengths = ( + sequence_lengths if hasattr(sequence_lengths, "shape") else mx.array(sequence_lengths) + ) + else: + sequences = [list(sequence) for sequence in input_ids] + pad_id = int(getattr(self.tokenizer, "pad_token_id", 0) or 0) + array_input_ids, array_sequence_lengths = _pad_sequence_batch(sequences, pad_id) + + prompt_lengths_array = None + completion_lengths_array = None + if prompt_lengths is not None: + prompt_lengths_array = prompt_lengths if hasattr(prompt_lengths, "shape") else mx.array(prompt_lengths) + if completion_lengths is not None: + completion_lengths_array = ( + completion_lengths + if hasattr(completion_lengths, "shape") + else mx.array(completion_lengths) + ) + return array_input_ids, array_sequence_lengths, prompt_lengths_array, completion_lengths_array + + def _hidden_states(self, input_ids: mx.array) -> mx.array: + hidden_states = self._hidden_backbone(input_ids) + if isinstance(hidden_states, tuple): + hidden_states = hidden_states[0] + if hasattr(hidden_states, "last_hidden_state"): + hidden_states = hidden_states.last_hidden_state + return hidden_states + + def _last_indices( + self, + sequence_lengths: mx.array, + prompt_lengths: Optional[mx.array], + completion_lengths: Optional[mx.array], + ) -> mx.array: + if self.target == "completion": + if prompt_lengths is None or completion_lengths is None: + raise ValueError("Completion scalar scoring requires prompt_lengths and completion_lengths.") + completion_last = prompt_lengths + completion_lengths - 1 + prompt_last = mx.maximum(prompt_lengths - 1, 0) + has_completion = completion_lengths > 0 + return mx.where(has_completion, completion_last, prompt_last) + return mx.maximum(sequence_lengths - 1, 0) + + def _token_mask( + self, + width: int, + sequence_lengths: mx.array, + prompt_lengths: Optional[mx.array], + completion_lengths: Optional[mx.array], + ) -> mx.array: + positions = mx.arange(width)[None, :] + if self.pooling == "mean_sequence": + return positions < sequence_lengths[:, None] + if self.pooling == "mean_completion": + if prompt_lengths is None or completion_lengths is None: + raise ValueError("Completion pooling requires prompt_lengths and completion_lengths.") + start = prompt_lengths[:, None] + end = (prompt_lengths + completion_lengths)[:, None] + mask = (positions >= start) & (positions < end) + has_completion = completion_lengths[:, None] > 0 + fallback = positions == mx.maximum(prompt_lengths - 1, 0)[:, None] + return mx.where(has_completion, mask, fallback) + raise ValueError(f"Unsupported scalar pooling mode: {self.pooling}") + + def _pool_hidden_states( + self, + hidden_states: mx.array, + sequence_lengths: mx.array, + prompt_lengths: Optional[mx.array], + completion_lengths: Optional[mx.array], + ) -> mx.array: + if self.pooling == "last_token": + indices = self._last_indices(sequence_lengths, prompt_lengths, completion_lengths) + batch_indices = mx.arange(hidden_states.shape[0]) + return hidden_states[batch_indices, indices] + + token_mask = self._token_mask( + hidden_states.shape[1], + sequence_lengths, + prompt_lengths, + completion_lengths, + ) + weights = token_mask.astype(hidden_states.dtype)[:, :, None] + totals = weights.sum(axis=1) + totals = mx.maximum(totals, 1.0) + return (hidden_states * weights).sum(axis=1) / totals + + def score( + self, + input_ids: Union[mx.array, Sequence[Sequence[int]]], + sequence_lengths: Optional[Union[mx.array, Sequence[int]]] = None, + prompt_lengths: Optional[Union[mx.array, Sequence[int]]] = None, + completion_lengths: Optional[Union[mx.array, Sequence[int]]] = None, + ) -> mx.array: + array_input_ids, array_sequence_lengths, prompt_lengths_array, completion_lengths_array = ( + self._normalize_batch_inputs( + input_ids=input_ids, + sequence_lengths=sequence_lengths, + prompt_lengths=prompt_lengths, + completion_lengths=completion_lengths, + ) + ) + hidden_states = self._hidden_states(array_input_ids) + pooled = self._pool_hidden_states( + hidden_states, + array_sequence_lengths, + prompt_lengths_array, + completion_lengths_array, + ) + return self.head(pooled).squeeze(-1) + + def save_pretrained(self, output_dir: str) -> None: + output_path = Path(output_dir) + output_path.mkdir(parents=True, exist_ok=True) + mx.save_safetensors( + str(output_path / "weights.safetensors"), + _flat_parameter_state(self.base_model), + ) + mx.save_safetensors( + str(output_path / "head.safetensors"), + dict(tree_flatten(self.head.parameters())), + ) + with open(output_path / "head_config.json", "w") as handle: + json.dump(self.head_config, handle, indent=2) + with open(output_path / "role.json", "w") as handle: + json.dump( + { + "role": self.role_name, + "pooling": self.pooling, + "target": self.target, + "backbone_weight_format": "weights.safetensors", + "backbone_has_adapters": bool( + isinstance(self.base_model, MLXModelWrapper) and self.base_model.has_adapters() + ), + }, + handle, + indent=2, + ) + if isinstance(self.base_model, MLXModelWrapper): + self.base_model.save_adapter_snapshot(str(output_path)) + + def load_pretrained(self, output_dir: str) -> None: + output_path = Path(output_dir) + weights_path = output_path / "weights.safetensors" + if weights_path.exists(): + _load_flat_parameter_state(self.base_model, mx.load(str(weights_path)), strict=False) + head_weights = mx.load(str(output_path / "head.safetensors")) + self.head.update(tree_unflatten(list(head_weights.items())), strict=False) + mx.eval(self.head.parameters()) + if isinstance(self.base_model, MLXModelWrapper): + adapter_file = output_path / "adapters.safetensors" + if adapter_file.exists(): + self.base_model.load_adapter_snapshot(str(output_path), strict=False) + config_path = output_path / "head_config.json" + if config_path.exists(): + with open(config_path) as handle: + self.head_config = json.load(handle) + self.pooling = self.head_config.get("pooling", self.pooling) + self.target = self.head_config.get("target", self.target) + + +class RewardModel(ScalarHeadModel): + role_name = "reward_model" + + def score_pairs( + self, + chosen_input_ids: Union[mx.array, Sequence[Sequence[int]]], + rejected_input_ids: Union[mx.array, Sequence[Sequence[int]]], + chosen_sequence_lengths: Optional[Union[mx.array, Sequence[int]]] = None, + rejected_sequence_lengths: Optional[Union[mx.array, Sequence[int]]] = None, + chosen_prompt_lengths: Optional[Union[mx.array, Sequence[int]]] = None, + rejected_prompt_lengths: Optional[Union[mx.array, Sequence[int]]] = None, + chosen_completion_lengths: Optional[Union[mx.array, Sequence[int]]] = None, + rejected_completion_lengths: Optional[Union[mx.array, Sequence[int]]] = None, + ) -> Tuple[mx.array, mx.array]: + return ( + self.score( + chosen_input_ids, + sequence_lengths=chosen_sequence_lengths, + prompt_lengths=chosen_prompt_lengths, + completion_lengths=chosen_completion_lengths, + ), + self.score( + rejected_input_ids, + sequence_lengths=rejected_sequence_lengths, + prompt_lengths=rejected_prompt_lengths, + completion_lengths=rejected_completion_lengths, + ), + ) + + def evaluate(self, payload: Dict[str, Any]) -> float: + sequence = [list(payload["prompt_ids"]) + list(payload["completion_ids"])] + scores = self.score( + sequence, + prompt_lengths=[len(payload["prompt_ids"])], + completion_lengths=[len(payload["completion_ids"])], + ) + return float(scores[0].item()) + + +class ValueModel(ScalarHeadModel): + role_name = "value_model" + + def predict( + self, + input_ids: Union[mx.array, Sequence[Sequence[int]]], + sequence_lengths: Optional[Union[mx.array, Sequence[int]]] = None, + prompt_lengths: Optional[Union[mx.array, Sequence[int]]] = None, + completion_lengths: Optional[Union[mx.array, Sequence[int]]] = None, + ) -> mx.array: + return self.score( + input_ids, + sequence_lengths=sequence_lengths, + prompt_lengths=prompt_lengths, + completion_lengths=completion_lengths, + ) + + +def build_reward_model( + base_model: Any, + pooling: str = "last_token", + target: str = "completion", + snapshot: bool = True, + head_config: Optional[Dict[str, Any]] = None, +) -> RewardModel: + role_model = _clone_role_model(base_model, freeze=False) if snapshot else base_model + if snapshot and hasattr(role_model, "_adapter_path"): + role_model._adapter_path = None + return RewardModel(role_model, pooling=pooling, target=target, head_config=head_config) + + +def build_value_model( + base_model: Any, + pooling: str = "last_token", + target: str = "completion", + snapshot: bool = True, + head_config: Optional[Dict[str, Any]] = None, +) -> ValueModel: + role_model = _clone_role_model(base_model, freeze=False) if snapshot else base_model + if snapshot and hasattr(role_model, "_adapter_path"): + role_model._adapter_path = None + return ValueModel(role_model, pooling=pooling, target=target, head_config=head_config) + + +def create_rl_model_roles( + policy_model: Any, + ref_model: Optional[Any] = None, + reward_model: Optional[RewardModel] = None, + value_model: Optional[ValueModel] = None, + reward_base_model: Optional[Any] = None, + value_base_model: Optional[Any] = None, + reference_snapshot: bool = True, + reward_pooling: str = "last_token", + reward_target: str = "completion", + value_pooling: str = "last_token", + value_target: str = "completion", +) -> RLModelRoles: + resolved_reward_model = reward_model + if resolved_reward_model is None and reward_base_model is not None: + resolved_reward_model = build_reward_model( + reward_base_model, + pooling=reward_pooling, + target=reward_target, + ) + + resolved_value_model = value_model + if resolved_value_model is None and value_base_model is not None: + resolved_value_model = build_value_model( + value_base_model, + pooling=value_pooling, + target=value_target, + ) + + return RLModelRoles( + policy_model=policy_model, + reference_policy=build_reference_policy( + policy_model, + ref_model=ref_model, + snapshot=reference_snapshot, + ), + reward_model=resolved_reward_model, + value_model=resolved_value_model, + ) diff --git a/mlx_tune/rl_api.py b/mlx_tune/rl_api.py new file mode 100644 index 0000000..2f1a0ba --- /dev/null +++ b/mlx_tune/rl_api.py @@ -0,0 +1,878 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +import inspect +import json +from pathlib import Path +from typing import Any, Dict, Iterable, Iterator, List, Mapping, Optional, Sequence +import warnings + +import mlx.core as mx +from mlx.utils import tree_unflatten + +from mlx_tune.chat_templates import apply_chat_template_to_sample +from mlx_tune.model import build_reference_policy, build_reward_model + + +MANIFEST_FILE = "manifest.json" +STATE_FILE = "trainer_state.safetensors" +METADATA_FILE = "trainer_state.json" +REFERENCE_FILE = "reference_model.safetensors" +REFERENCE_METADATA_FILE = "reference_metadata.json" +CHECKPOINT_FORMAT_NAME = "mlx_tune_rl_checkpoint" +CHECKPOINT_FORMAT_VERSION = 4 +SUPPORTED_RL_DATASET_MODES = ( + "prompt", + "preference", + "reward_scalar", + "reward_pairwise", + "rollout", + "chat", +) + + +@dataclass +class PreparedRLDataset: + samples: List[Dict[str, Any]] + mode: str + adapter_name: str + metadata: Dict[str, Any] = field(default_factory=dict) + + def __iter__(self) -> Iterator[Dict[str, Any]]: + return iter(self.samples) + + def __len__(self) -> int: + return len(self.samples) + + def __getitem__(self, index: int) -> Dict[str, Any]: + return self.samples[index] + + +@dataclass +class RLRoleState: + role: str + weight_format: Any + parameter_state: Dict[str, mx.array] = field(default_factory=dict) + head_state: Dict[str, mx.array] = field(default_factory=dict) + adapter_state: Dict[str, mx.array] = field(default_factory=dict) + head_config: Dict[str, Any] = field(default_factory=dict) + metadata: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class RLCheckpointBundle: + manifest: Dict[str, Any] + algorithm: str + restored_roles: Dict[str, RLRoleState] + optimizer_state_trees: Dict[str, Dict[str, Any]] + scheduler_metadata: Dict[str, Dict[str, Any]] + trainer_state: Dict[str, Any] + rng_state: Dict[str, mx.array] + runtime_cache: Dict[str, mx.array] + metrics_history: List[Dict[str, Any]] + source_format: str + + +def _read_json(path: Path, default: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + if not path.exists(): + return {} if default is None else default + with open(path) as handle: + return json.load(handle) + + +def _load_jsonl(path: Path) -> List[Dict[str, Any]]: + if not path.exists(): + return [] + rows: List[Dict[str, Any]] = [] + with open(path) as handle: + for line in handle: + line = line.strip() + if line: + rows.append(json.loads(line)) + return rows + + +def _extract_prefixed_tree(prefix: str, flat_state: Dict[str, mx.array]) -> Dict[str, Any]: + prefix_with_dot = f"{prefix}." + items = [(key[len(prefix_with_dot):], value) for key, value in flat_state.items() if key.startswith(prefix_with_dot)] + return tree_unflatten(items) if items else {} + + +def _role_dir(checkpoint_dir: Path, role_name: str) -> Path: + return checkpoint_dir / role_name + + +def _legacy_checkpoint_exists(checkpoint_dir: Path) -> bool: + return ( + (checkpoint_dir / "adapters" / "adapters.safetensors").exists() + or (checkpoint_dir / STATE_FILE).exists() + or (checkpoint_dir / REFERENCE_FILE).exists() + ) + + +def _metrics_path(checkpoint_dir: Path) -> Path: + return checkpoint_dir / "metrics" / "history.jsonl" + + +def _runtime_cache_path(checkpoint_dir: Path) -> Path: + return checkpoint_dir / "runtime" / "cache.safetensors" + + +def _trainer_state_path(checkpoint_dir: Path) -> Path: + return checkpoint_dir / "trainer" / "state.json" + + +def _rng_path(checkpoint_dir: Path) -> Path: + return checkpoint_dir / "trainer" / "rng.safetensors" + + +def _scheduler_path(checkpoint_dir: Path, name: Optional[str] = None) -> Path: + if name is None: + return checkpoint_dir / "scheduler" / "state.json" + return checkpoint_dir / "schedulers" / name / "state.json" + + +def _optimizer_path(checkpoint_dir: Path, name: Optional[str] = None) -> Path: + if name is None: + return checkpoint_dir / "optimizer" / "state.safetensors" + return checkpoint_dir / "optimizers" / name / "state.safetensors" + + +def _load_role_state(checkpoint_dir: Path, role_name: str, weight_format: Any) -> RLRoleState: + role_dir = _role_dir(checkpoint_dir, role_name) + metadata = _read_json(role_dir / "metadata.json") + head_config = _read_json(role_dir / "head_config.json") + parameter_state: Dict[str, mx.array] = {} + head_state: Dict[str, mx.array] = {} + adapter_state: Dict[str, mx.array] = {} + + if isinstance(weight_format, str): + weight_path = role_dir / weight_format + if weight_path.exists(): + parameter_state = dict(mx.load(str(weight_path))) + elif isinstance(weight_format, Mapping): + backbone_path = role_dir / str(weight_format.get("backbone", "weights.safetensors")) + head_path = role_dir / str(weight_format.get("head", "head.safetensors")) + adapters_name = weight_format.get("adapters") + if backbone_path.exists(): + parameter_state = dict(mx.load(str(backbone_path))) + if head_path.exists(): + head_state = dict(mx.load(str(head_path))) + if adapters_name: + adapter_path = role_dir / str(adapters_name) + if adapter_path.exists(): + adapter_state = dict(mx.load(str(adapter_path))) + + return RLRoleState( + role=role_name, + weight_format=weight_format, + parameter_state=parameter_state, + head_state=head_state, + adapter_state=adapter_state, + head_config=head_config, + metadata=metadata, + ) + + +def _build_manifest_bundle(checkpoint_dir: Path) -> RLCheckpointBundle: + manifest = _read_json(checkpoint_dir / MANIFEST_FILE) + if not manifest: + raise FileNotFoundError(f"Checkpoint manifest not found under {checkpoint_dir}") + + roles = { + role_name: _load_role_state( + checkpoint_dir, + role_name, + manifest.get("role_weight_formats", {}).get(role_name, "weights.safetensors"), + ) + for role_name in manifest.get("roles_present", []) + } + + trainer_locations = manifest.get("trainer_state_locations", {}) + optimizer_state_trees: Dict[str, Dict[str, Any]] = {} + optimizer_locations = trainer_locations.get("optimizers") or {} + if optimizer_locations: + for name, relative_path in optimizer_locations.items(): + path = checkpoint_dir / relative_path + if path.exists(): + optimizer_state_trees[name] = _extract_prefixed_tree("optimizer", dict(mx.load(str(path)))) + else: + legacy_optimizer_path = _optimizer_path(checkpoint_dir) + if legacy_optimizer_path.exists(): + optimizer_state_trees["default"] = _extract_prefixed_tree( + "optimizer", + dict(mx.load(str(legacy_optimizer_path))), + ) + + scheduler_metadata: Dict[str, Dict[str, Any]] = {} + scheduler_locations = trainer_locations.get("schedulers") or {} + if scheduler_locations: + for name, relative_path in scheduler_locations.items(): + scheduler_metadata[name] = _read_json(checkpoint_dir / relative_path) + else: + default_scheduler = _read_json(_scheduler_path(checkpoint_dir)) + if default_scheduler: + scheduler_metadata["default"] = default_scheduler + + rng_state = dict(mx.load(str(_rng_path(checkpoint_dir)))) if _rng_path(checkpoint_dir).exists() else {} + runtime_cache = ( + dict(mx.load(str(_runtime_cache_path(checkpoint_dir)))) + if _runtime_cache_path(checkpoint_dir).exists() + else {} + ) + trainer_state = _read_json(_trainer_state_path(checkpoint_dir)) + metrics_history = _load_jsonl(_metrics_path(checkpoint_dir)) + + return RLCheckpointBundle( + manifest=manifest, + algorithm=manifest.get("algorithm", trainer_state.get("algorithm", "rl")), + restored_roles=roles, + optimizer_state_trees=optimizer_state_trees, + scheduler_metadata=scheduler_metadata, + trainer_state=trainer_state, + rng_state=rng_state, + runtime_cache=runtime_cache, + metrics_history=metrics_history, + source_format="manifest", + ) + + +def _build_legacy_manifest(metadata: Dict[str, Any], has_reference: bool) -> Dict[str, Any]: + roles_present = ["policy"] + role_weight_formats: Dict[str, Any] = {"policy": "adapters.safetensors"} + if has_reference: + roles_present.append("reference") + role_weight_formats["reference"] = "weights.safetensors" + return { + "format_name": CHECKPOINT_FORMAT_NAME, + "format_version": CHECKPOINT_FORMAT_VERSION, + "algorithm": metadata.get("algorithm", "rl"), + "roles_present": roles_present, + "role_weight_formats": role_weight_formats, + "trainer_state_locations": { + "optimizer": "trainer_state.safetensors", + "rng": "trainer_state.safetensors", + "runtime_cache": "trainer_state.safetensors", + "trainer": "trainer_state.json", + }, + "metrics_path": None, + } + + +def _build_legacy_bundle(checkpoint_dir: Path) -> RLCheckpointBundle: + state_path = checkpoint_dir / STATE_FILE + metadata_path = checkpoint_dir / METADATA_FILE + if not state_path.exists() or not metadata_path.exists(): + raise FileNotFoundError(f"Checkpoint state not found under {checkpoint_dir}") + + metadata = _read_json(metadata_path) + flat_state = dict(mx.load(str(state_path))) + policy_state: Dict[str, mx.array] = {} + adapter_path = checkpoint_dir / "adapters" / "adapters.safetensors" + if adapter_path.exists(): + policy_state = dict(mx.load(str(adapter_path))) + + roles = { + "policy": RLRoleState( + role="policy", + weight_format="adapters.safetensors", + parameter_state=policy_state, + ), + } + reference_path = checkpoint_dir / REFERENCE_FILE + if reference_path.exists(): + roles["reference"] = RLRoleState( + role="reference", + weight_format="weights.safetensors", + parameter_state=dict(mx.load(str(reference_path))), + metadata=_read_json(checkpoint_dir / REFERENCE_METADATA_FILE), + ) + + runtime_cache = { + key: value + for key, value in flat_state.items() + if not key.startswith("optimizer.") and not key.startswith("rng.") + } + optimizer_state = _extract_prefixed_tree("optimizer", flat_state) + rng_state = {key: value for key, value in flat_state.items() if key.startswith("rng.")} + + return RLCheckpointBundle( + manifest=_build_legacy_manifest(metadata, has_reference="reference" in roles), + algorithm=metadata.get("algorithm", "rl"), + restored_roles=roles, + optimizer_state_trees={"default": optimizer_state} if optimizer_state else {}, + scheduler_metadata={}, + trainer_state=metadata, + rng_state=rng_state, + runtime_cache=runtime_cache, + metrics_history=[], + source_format="legacy", + ) + + +def resume_from_checkpoint(checkpoint_dir: str | Path) -> RLCheckpointBundle: + checkpoint_path = Path(checkpoint_dir) + if (checkpoint_path / MANIFEST_FILE).exists(): + return _build_manifest_bundle(checkpoint_path) + if _legacy_checkpoint_exists(checkpoint_path): + return _build_legacy_bundle(checkpoint_path) + raise FileNotFoundError(f"Checkpoint state not found under {checkpoint_path}") + + +def _tokenizer_encode(tokenizer: Any, text: str, add_special_tokens: bool) -> List[int]: + try: + return list(tokenizer.encode(text, add_special_tokens=add_special_tokens)) + except TypeError: + return list(tokenizer.encode(text)) + + +def _is_message_sequence(value: Any) -> bool: + return isinstance(value, list) and all(isinstance(item, Mapping) and "role" in item for item in value) + + +def _render_messages( + messages: Sequence[Mapping[str, Any]], + tokenizer: Any = None, + chat_template: Optional[Any] = None, + add_generation_prompt: bool = False, +) -> str: + payload = {"messages": list(messages)} + if callable(chat_template): + return chat_template(payload, add_generation_prompt=add_generation_prompt) + if isinstance(chat_template, str): + if add_generation_prompt: + return chat_template.format(messages=list(messages), add_generation_prompt=True) + return chat_template.format(messages=list(messages)) + if tokenizer is not None: + return apply_chat_template_to_sample( + payload, + tokenizer, + add_generation_prompt=add_generation_prompt, + ) + return "\n".join(f"{message.get('role', 'user')}: {message.get('content', '')}" for message in messages) + + +def _last_assistant_index(messages: Sequence[Mapping[str, Any]]) -> Optional[int]: + for index in range(len(messages) - 1, -1, -1): + if messages[index].get("role") == "assistant": + return index + return None + + +def _extract_chat_prompt_response( + sample: Mapping[str, Any], + tokenizer: Any = None, + chat_template: Optional[Any] = None, +) -> tuple[str, str]: + messages = list(sample.get("messages") or []) + if not messages: + raise ValueError("Chat adaptation requires a messages field.") + assistant_index = _last_assistant_index(messages) + if assistant_index is None: + prompt_messages = messages + response = "" + else: + prompt_messages = messages[:assistant_index] + response = str(messages[assistant_index].get("content", "")) + prompt = _render_messages( + prompt_messages, + tokenizer=tokenizer, + chat_template=chat_template, + add_generation_prompt=True, + ) + return prompt, response + + +def _extract_preference_value( + value: Any, + prompt: str, + tokenizer: Any = None, + chat_template: Optional[Any] = None, +) -> tuple[str, str]: + if isinstance(value, str): + return prompt, value + if isinstance(value, Mapping) and _is_message_sequence(value.get("messages")): + prompt_text, response = _extract_chat_prompt_response( + value, + tokenizer=tokenizer, + chat_template=chat_template, + ) + return prompt_text or prompt, response + if _is_message_sequence(value): + prompt_text, response = _extract_chat_prompt_response( + {"messages": value}, + tokenizer=tokenizer, + chat_template=chat_template, + ) + return prompt_text or prompt, response + raise ValueError("Unsupported preference value; expected string or chat message sequence.") + + +def _normalize_prompt_sample( + sample: Mapping[str, Any], + tokenizer: Any = None, + chat_template: Optional[Any] = None, +) -> Dict[str, Any]: + if _is_message_sequence(sample.get("messages")): + prompt, response = _extract_chat_prompt_response( + sample, + tokenizer=tokenizer, + chat_template=chat_template, + ) + reward_context = sample.get("reward_context", sample.get("answer", sample.get("response", response or prompt))) + return { + "prompt": prompt, + "reward_context": reward_context, + "source_messages": list(sample.get("messages") or []), + } + + prompt = str(sample.get("prompt", sample.get("question", ""))) + if not prompt: + raise ValueError("Prompt samples require a prompt or question field.") + return { + "prompt": prompt, + "reward_context": sample.get( + "reward_context", + sample.get("answer", sample.get("response", prompt)), + ), + } + + +def _normalize_preference_sample( + sample: Mapping[str, Any], + tokenizer: Any = None, + chat_template: Optional[Any] = None, + target_mode: str = "preference", +) -> Dict[str, Any]: + prompt = str(sample.get("prompt", sample.get("question", ""))) + if _is_message_sequence(sample.get("messages")): + prompt = _render_messages( + sample.get("messages") or [], + tokenizer=tokenizer, + chat_template=chat_template, + add_generation_prompt=True, + ) + chosen_prompt, chosen = _extract_preference_value( + sample.get("chosen"), + prompt, + tokenizer=tokenizer, + chat_template=chat_template, + ) + rejected_prompt, rejected = _extract_preference_value( + sample.get("rejected"), + chosen_prompt, + tokenizer=tokenizer, + chat_template=chat_template, + ) + return { + "type": "pairwise" if target_mode == "reward_pairwise" else "preference", + "prompt": chosen_prompt or rejected_prompt, + "chosen": chosen, + "rejected": rejected, + } + + +def _normalize_reward_scalar_sample( + sample: Mapping[str, Any], + tokenizer: Any = None, + chat_template: Optional[Any] = None, +) -> Dict[str, Any]: + if _is_message_sequence(sample.get("messages")): + prompt, response = _extract_chat_prompt_response( + sample, + tokenizer=tokenizer, + chat_template=chat_template, + ) + return { + "type": "scalar", + "prompt": prompt, + "response": response, + "score": float(sample.get("score", sample.get("reward", 0.0))), + } + + if "text" in sample: + return { + "type": "scalar", + "prompt": "", + "response": str(sample.get("text", "")), + "score": float(sample.get("score", sample.get("reward", 0.0))), + } + + prompt = str(sample.get("prompt", sample.get("question", ""))) + response = str(sample.get("response", sample.get("completion", sample.get("assistant", "")))) + if not response: + raise ValueError("Scalar reward samples require a response, completion, assistant, or text field.") + return { + "type": "scalar", + "prompt": prompt, + "response": response, + "score": float(sample.get("score", sample.get("reward", 0.0))), + } + + +def _normalize_rollout_sample( + sample: Mapping[str, Any], + tokenizer: Any = None, + chat_template: Optional[Any] = None, +) -> Dict[str, Any]: + if _is_message_sequence(sample.get("messages")): + prompt, response = _extract_chat_prompt_response( + sample, + tokenizer=tokenizer, + chat_template=chat_template, + ) + completion = response or str(sample.get("completion", sample.get("response", ""))) + else: + prompt = str(sample.get("prompt", sample.get("question", ""))) + completion = str(sample.get("completion", sample.get("response", ""))) + if not prompt or not completion: + raise ValueError("Rollout samples require prompt/question and completion/response data.") + reward = sample.get("reward", sample.get("score")) + return { + "prompt": prompt, + "completion": completion, + "reward": None if reward is None else float(reward), + "reward_context": sample.get( + "reward_context", + sample.get("answer", sample.get("response", completion)), + ), + } + + +def _explicit_adapter_mode(mode: str) -> str: + if mode not in SUPPORTED_RL_DATASET_MODES: + raise ValueError( + f"Unsupported RL dataset mode '{mode}'. Supported modes: {SUPPORTED_RL_DATASET_MODES}" + ) + return mode + + +def _candidate_modes(sample: Mapping[str, Any]) -> List[str]: + if _is_message_sequence(sample.get("messages")): + messages = list(sample.get("messages") or []) + assistant_index = _last_assistant_index(messages) + if "chosen" in sample and "rejected" in sample: + if "chosen_score" in sample or "rejected_score" in sample: + return ["reward_pairwise"] + return ["preference", "reward_pairwise"] + if "score" in sample or "reward" in sample: + if assistant_index is not None: + return ["reward_scalar"] + if "completion" in sample or "response" in sample or "rewards" in sample: + return ["rollout"] + if assistant_index is None: + return ["prompt"] + return [] + + keys = set(sample.keys()) + if {"chosen", "rejected"} <= keys: + if "chosen_score" in keys or "rejected_score" in keys: + return ["reward_pairwise"] + return ["preference", "reward_pairwise"] + if "text" in keys and "score" in keys: + return ["reward_scalar"] + if ("prompt" in keys or "question" in keys) and ("completion" in keys) and ("reward" in keys or "score" in keys): + return ["rollout"] + if ("prompt" in keys or "question" in keys) and ("response" in keys) and ("score" in keys): + return ["reward_scalar"] + if ("prompt" in keys or "question" in keys) and not ( + {"chosen", "rejected", "response", "completion", "score", "reward"} & keys + ): + return ["prompt"] + return [] + + +def prepare_rl_dataset( + data: Iterable[Mapping[str, Any]] | PreparedRLDataset, + mode: Optional[str] = None, + tokenizer: Any = None, + chat_template: Optional[Any] = None, + metadata: Optional[Dict[str, Any]] = None, +) -> PreparedRLDataset: + if isinstance(data, PreparedRLDataset): + if mode is None or mode == data.mode: + return data + data = data.samples + + records = list(data) + if not records: + resolved_mode = _explicit_adapter_mode(mode) if mode is not None else "prompt" + return PreparedRLDataset( + samples=[], + mode=resolved_mode if resolved_mode != "chat" else "prompt", + adapter_name=resolved_mode, + metadata=dict(metadata or {}), + ) + + requested_mode = _explicit_adapter_mode(mode) if mode is not None else None + if requested_mode == "chat": + candidate_sets = [_candidate_modes(record) for record in records] + unique_modes = sorted({candidate for candidates in candidate_sets for candidate in candidates}) + if len(unique_modes) != 1: + raise ValueError( + "Chat auto-adaptation is ambiguous. Choose one explicit RL mode from " + f"{SUPPORTED_RL_DATASET_MODES[:-1]}." + ) + requested_mode = unique_modes[0] + + if requested_mode is None: + candidate_sets = [_candidate_modes(record) for record in records] + unique_modes = {candidate for candidates in candidate_sets for candidate in candidates} + if not unique_modes: + raise ValueError( + "Could not auto-detect RL dataset mode. Supported modes: " + f"{SUPPORTED_RL_DATASET_MODES[:-1]}." + ) + if len(unique_modes) != 1: + raise ValueError( + "Ambiguous RL dataset schema. Choose an explicit mode from " + f"{sorted(unique_modes)}." + ) + requested_mode = next(iter(unique_modes)) + + normalized: List[Dict[str, Any]] = [] + adapter_name = requested_mode + for record in records: + if requested_mode == "prompt": + sample = _normalize_prompt_sample(record, tokenizer=tokenizer, chat_template=chat_template) + elif requested_mode == "preference": + sample = _normalize_preference_sample( + record, + tokenizer=tokenizer, + chat_template=chat_template, + target_mode="preference", + ) + elif requested_mode == "reward_scalar": + sample = _normalize_reward_scalar_sample( + record, + tokenizer=tokenizer, + chat_template=chat_template, + ) + elif requested_mode == "reward_pairwise": + sample = _normalize_preference_sample( + record, + tokenizer=tokenizer, + chat_template=chat_template, + target_mode="reward_pairwise", + ) + elif requested_mode == "rollout": + sample = _normalize_rollout_sample( + record, + tokenizer=tokenizer, + chat_template=chat_template, + ) + else: + raise ValueError(f"Unsupported RL dataset mode: {requested_mode}") + + if _is_message_sequence(record.get("messages")) or _is_message_sequence(record.get("chosen")): + adapter_name = f"chat_{requested_mode}" + normalized.append(sample) + + dataset_metadata = dict(metadata or {}) + dataset_metadata.update( + { + "requested_mode": mode, + "num_samples": len(normalized), + } + ) + return PreparedRLDataset( + samples=normalized, + mode=requested_mode, + adapter_name=adapter_name, + metadata=dataset_metadata, + ) + + +def prepare_reward_dataset(dataset: Iterable[Mapping[str, Any]]) -> List[Dict[str, Any]]: + warnings.warn( + "prepare_reward_dataset() is deprecated; use prepare_rl_dataset(..., mode='reward_scalar' or 'reward_pairwise') instead.", + DeprecationWarning, + stacklevel=2, + ) + records = list(dataset) + if not records: + return [] + first = records[0] + mode = "reward_pairwise" if {"chosen", "rejected"} <= set(first.keys()) else "reward_scalar" + return list(prepare_rl_dataset(records, mode=mode)) + + +def prepare_preference_dataset( + dataset: Iterable[Mapping[str, Any]], + tokenizer: Any = None, + format_type: str = "dpo", +) -> List[Dict[str, Any]]: + warnings.warn( + "prepare_preference_dataset() is deprecated; use prepare_rl_dataset(..., mode='preference' or 'prompt') instead.", + DeprecationWarning, + stacklevel=2, + ) + mode = "prompt" if format_type == "grpo" else "preference" + return list(prepare_rl_dataset(dataset, mode=mode, tokenizer=tokenizer)) + + +def _simple_reward_builder(reward_type: str) -> Any: + if reward_type == "simple": + return lambda response, ground_truth: 1.0 if str(ground_truth).lower() in str(response).lower() else 0.0 + if reward_type == "math": + def math_reward(response: str, ground_truth: str) -> float: + import re + + numbers = re.findall(r"-?\d+\.?\d*", response) + target = re.findall(r"-?\d+\.?\d*", ground_truth) + if numbers and target: + try: + return 1.0 if float(numbers[-1]) == float(target[-1]) else 0.0 + except Exception: + return 0.0 + return 0.0 + + return math_reward + if reward_type == "length": + def length_reward(response: str, _: str) -> float: + length = len(str(response).split()) + if length < 10: + return 0.2 + if length < 50: + return 0.5 + if length < 200: + return 1.0 + return 0.8 + + return length_reward + raise ValueError(f"Unknown reward type: {reward_type}") + + +class _WeightedRewardComposer: + def __init__(self, components: Sequence[Dict[str, Any]]): + self.components = [dict(component) for component in components] + self._running_stats: Dict[str, Dict[str, float]] = {} + + def _resolve_source(self, source: Any) -> Any: + if isinstance(source, str): + return _simple_reward_builder(source) + return source + + def _normalize(self, name: str, value: float, mode: Any) -> float: + if mode in (None, False, "none"): + return value + if callable(mode): + return float(mode(value)) + if mode not in (True, "zscore"): + raise ValueError(f"Unsupported reward normalization mode: {mode}") + stats = self._running_stats.setdefault(name, {"count": 0.0, "mean": 0.0, "m2": 0.0}) + stats["count"] += 1.0 + delta = value - stats["mean"] + stats["mean"] += delta / stats["count"] + delta2 = value - stats["mean"] + stats["m2"] += delta * delta2 + variance = stats["m2"] / max(stats["count"] - 1.0, 1.0) + std = variance ** 0.5 + if std < 1e-6: + return value - stats["mean"] + return (value - stats["mean"]) / std + + def _evaluate_source(self, source: Any, payload: Dict[str, Any]) -> tuple[float, Dict[str, float], Optional[Dict[str, Any]]]: + evaluator = self._resolve_source(source) + if hasattr(evaluator, "evaluate"): + result = evaluator.evaluate(payload) + elif callable(evaluator): + try: + signature = inspect.signature(evaluator) + except (TypeError, ValueError): + signature = None + + if signature is None: + result = evaluator(payload["completion_text"], payload["reward_context"]) + else: + required = [ + parameter + for parameter in signature.parameters.values() + if parameter.kind in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ) + and parameter.default is inspect._empty + ] + if len(required) >= 2: + result = evaluator(payload["completion_text"], payload["reward_context"]) + else: + result = evaluator(payload) + else: + raise TypeError("Reward component source must be a string, callable, or evaluator object.") + + if isinstance(result, Mapping): + reward = float(result.get("reward", result.get("score", 0.0))) + return reward, dict(result.get("components") or {}), result.get("diagnostics") + return float(result), {}, None + + def evaluate(self, payload: Dict[str, Any]) -> Dict[str, Any]: + total = 0.0 + named_components: Dict[str, float] = {} + diagnostics: Dict[str, Any] = {} + for index, component in enumerate(self.components): + name = str(component.get("name") or f"reward_{index}") + weight = float(component.get("weight", 1.0)) + raw_reward, nested_components, component_diagnostics = self._evaluate_source(component.get("source"), payload) + reward = self._normalize(name, raw_reward, component.get("normalize")) + total += weight * reward + named_components[name] = reward + for nested_name, nested_value in nested_components.items(): + named_components[f"{name}.{nested_name}"] = float(nested_value) + if component_diagnostics is not None: + diagnostics[name] = component_diagnostics + result: Dict[str, Any] = { + "reward": total, + "components": named_components, + } + if diagnostics: + result["diagnostics"] = diagnostics + return result + + +def create_reward_function( + reward_type: Any = "simple", + *, + rewards: Optional[Sequence[Any]] = None, +) -> Any: + if rewards is not None: + components: List[Dict[str, Any]] = [] + for index, component in enumerate(rewards): + if isinstance(component, Mapping): + item = dict(component) + if "source" not in item: + raise ValueError("Reward components must include a source field.") + else: + item = {"source": component} + item.setdefault("name", f"reward_{index}") + item.setdefault("weight", 1.0) + components.append(item) + return _WeightedRewardComposer(components) + + if isinstance(reward_type, Mapping): + if "rewards" in reward_type: + return create_reward_function(rewards=reward_type["rewards"]) + return create_reward_function(rewards=[reward_type]) + + if isinstance(reward_type, (list, tuple)): + return create_reward_function(rewards=list(reward_type)) + + if callable(reward_type) or hasattr(reward_type, "evaluate"): + return reward_type + + return _simple_reward_builder(str(reward_type)) + + +__all__ = [ + "RLCheckpointBundle", + "RLRoleState", + "PreparedRLDataset", + "SUPPORTED_RL_DATASET_MODES", + "build_reference_policy", + "build_reward_model", + "create_reward_function", + "prepare_preference_dataset", + "prepare_reward_dataset", + "prepare_rl_dataset", + "resume_from_checkpoint", +] diff --git a/mlx_tune/rl_trainers.py b/mlx_tune/rl_trainers.py index d920c10..d9be9cf 100644 --- a/mlx_tune/rl_trainers.py +++ b/mlx_tune/rl_trainers.py @@ -1,131 +1,1161 @@ """ -Reinforcement Learning Trainers for MLX-Tune - -Provides Unsloth/TRL-compatible RL training interfaces: -- DPOTrainer: Direct Preference Optimization -- ORPOTrainer: Odds Ratio Preference Optimization -- GRPOTrainer: Group Relative Policy Optimization (DeepSeek R1 style) -- KTOTrainer: Kahneman-Tversky Optimization -- SimPOTrainer: Simple Preference Optimization - -These trainers use MLX under the hood for Apple Silicon optimization. -Now with PROPER loss implementations using native MLX training! +Reinforcement learning trainers for MLX-Tune. + +Provides TRL-style trainer interfaces for: +- DPO +- ORPO +- GRPO +- KTO +- SimPO """ -from typing import Optional, Dict, Any, Union, List, Callable from pathlib import Path +from typing import Optional, Dict, Any, List, Callable, Tuple +from contextlib import contextmanager +import hashlib import json import subprocess +import time import warnings import mlx.core as mx -# Try to import native training components try: - from mlx_lm.tuner.trainer import TrainingArgs import mlx.nn as nn import mlx.optimizers as optim + from mlx.utils import tree_flatten, tree_unflatten HAS_NATIVE_TRAINING = True except ImportError: HAS_NATIVE_TRAINING = False -# Import our loss functions from mlx_tune.losses import ( dpo_loss as compute_dpo_loss, - orpo_loss as compute_orpo_loss, + grpo_recompute_loss, kto_loss as compute_kto_loss, + orpo_loss as compute_orpo_loss, + ppo_sequence_loss, + reward_model_pairwise_loss, + reward_model_regression_loss, + scalar_loss_metrics, + pairwise_ranking_accuracy, simpo_loss as compute_simpo_loss, - grpo_batch_loss, - compute_reference_logprobs, - compute_log_probs_with_lengths, + value_model_regression_loss, +) +from mlx_tune._rl_runtime import ( + PolicyEvalBatch, + PreferenceBatch, + RolloutBatch, + assemble_minibatches, + cap_prompt_and_completion_lengths, + collect_rollouts, + compute_advantages, + compute_returns_and_advantages, + evaluate_rewards, + kl_against_reference, + make_policy_eval_batch, + make_preference_batch, + pad_sequences, + predict_rollout_values, + rank_grouped_rollouts, + score_rollout_references, + score_policy_in_chunks, + summarize_rollout_metrics, +) +from mlx_tune.model import ReferencePolicy +from mlx_tune.model import ( + RewardModel, + ValueModel, + build_reference_policy, + build_reward_model, + build_value_model, +) +from mlx_tune.rl_api import ( + RLCheckpointBundle, + create_reward_function as public_create_reward_function, + prepare_preference_dataset as public_prepare_preference_dataset, + prepare_reward_dataset as public_prepare_reward_dataset, + prepare_rl_dataset, + resume_from_checkpoint, ) -def _save_adapters_and_config(model, adapter_path: Path): - """ - Save adapter weights and config (required for GGUF export). +STATE_FILE = "trainer_state.safetensors" +METADATA_FILE = "trainer_state.json" +REFERENCE_FILE = "reference_model.safetensors" +REFERENCE_METADATA_FILE = "reference_metadata.json" +MANIFEST_FILE = "manifest.json" +CHECKPOINT_FORMAT_NAME = "mlx_tune_rl_checkpoint" +CHECKPOINT_FORMAT_VERSION = 4 +MLX_TUNE_VERSION = "0.4.0" +GRPO_LOSS_TYPES = {"grpo", "dr_grpo", "dapo", "bnpo", "gspo"} + + +def _actual_model(model: Any) -> Any: + return model.model if hasattr(model, "model") else model + + +def _pad_token_id(tokenizer: Any) -> int: + pad_id = getattr(tokenizer, "pad_token_id", None) + return 0 if pad_id is None else pad_id + + +def _encode_text(tokenizer: Any, text: str, add_special_tokens: bool = True) -> List[int]: + try: + return list(tokenizer.encode(text, add_special_tokens=add_special_tokens)) + except TypeError: + return list(tokenizer.encode(text)) + - This helper function is used by all RL trainers to ensure - adapter_config.json is created alongside adapters.safetensors. +def _save_adapters_and_config(model: Any, adapter_path: Path) -> bool: + """ + Save trainable parameters and adapter config in mlx_lm-compatible layout. """ try: - from mlx.utils import tree_flatten - actual_model = model.model if hasattr(model, 'model') else model - adapter_file = adapter_path / "adapters.safetensors" - adapter_path.mkdir(parents=True, exist_ok=True) + if hasattr(model, "save_adapter_snapshot"): + return bool(model.save_adapter_snapshot(str(adapter_path))) - # Save adapter weights (trainable parameters only) + actual_model = _actual_model(model) + adapter_path.mkdir(parents=True, exist_ok=True) + adapter_file = adapter_path / "adapters.safetensors" adapter_weights = dict(tree_flatten(actual_model.trainable_parameters())) mx.save_safetensors(str(adapter_file), adapter_weights) + return True + except Exception as exc: + print(f" Warning: could not save adapters: {exc}") + return False + + +def _save_full_model_state(model: Any, path: Path) -> None: + actual_model = _actual_model(model) + mx.save_safetensors(str(path), dict(tree_flatten(actual_model.parameters()))) + + +def _load_parameter_tree(model: Any, path: Path, strict: bool = False) -> None: + if not path.exists(): + return + actual_model = _actual_model(model) + weights = mx.load(str(path)) + actual_model.update(tree_unflatten(list(weights.items())), strict=strict) + mx.eval(actual_model.parameters()) + + +def _load_flat_parameter_tree(model: Any, flat_state: Dict[str, Any], strict: bool = False) -> None: + if not flat_state: + return + actual_model = _actual_model(model) + actual_model.update(tree_unflatten(list(flat_state.items())), strict=strict) + mx.eval(actual_model.parameters()) + + +def _flatten_prefixed_tree(prefix: str, tree: Dict[str, Any]) -> Dict[str, mx.array]: + return {f"{prefix}.{key}": value for key, value in tree_flatten(tree)} + + +def _extract_prefixed_tree(prefix: str, flat_state: Dict[str, mx.array]) -> Dict[str, Any]: + items = [] + prefix_with_dot = f"{prefix}." + for key, value in flat_state.items(): + if key.startswith(prefix_with_dot): + items.append((key[len(prefix_with_dot):], value)) + return tree_unflatten(items) if items else {} + + +def _rng_state_to_dict() -> Dict[str, mx.array]: + return {f"rng.{idx}": state for idx, state in enumerate(mx.random.state)} + + +def _restore_rng_state(flat_state: Dict[str, mx.array]) -> None: + rng_items = [] + idx = 0 + while f"rng.{idx}" in flat_state: + rng_items.append(mx.array(flat_state[f"rng.{idx}"])) + idx += 1 + if rng_items: + mx.random.state = rng_items + + +def _pad_sequences(sequences: List[List[int]], pad_id: int) -> Tuple[mx.array, mx.array]: + return pad_sequences(sequences, pad_id) + + +def _read_json(path: Path, default: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + if not path.exists(): + return {} if default is None else default + with open(path) as handle: + return json.load(handle) + + +def _write_json(path: Path, payload: Dict[str, Any]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w") as handle: + json.dump(payload, handle, indent=2) + + +def _load_jsonl(path: Path) -> List[Dict[str, Any]]: + if not path.exists(): + return [] + history: List[Dict[str, Any]] = [] + with open(path) as handle: + for line in handle: + line = line.strip() + if line: + history.append(json.loads(line)) + return history + + +def _save_jsonl(path: Path, rows: List[Dict[str, Any]]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w") as handle: + for row in rows: + handle.write(json.dumps(row) + "\n") + + +def _hash_payload(payload: Any) -> str: + return hashlib.sha256(json.dumps(payload, sort_keys=True, separators=(",", ":")).encode("utf-8")).hexdigest() + + +class _ScalarRoleTrainTarget(nn.Module): + def __init__(self, backbone: Any, head: Any): + super().__init__() + self.backbone = backbone + self.head = head + + +class _RLTrainerBase: + algorithm = "rl" + requires_reference_policy = False + + def _init_native_state(self) -> None: + self.global_step = 0 + self.dataset_cursor = 0 + self.reference_policy: Optional[ReferencePolicy] = None + self.cache_metadata: Dict[str, Any] = {} + self.runtime_cache_arrays: Dict[str, mx.array] = {} + self.optimizer = None + self.optimizers: Dict[str, Any] = {} + self.reward_model: Optional[RewardModel] = getattr(self, "reward_model", None) + self.value_model: Optional[ValueModel] = getattr(self, "value_model", None) + self.metrics_history: List[Dict[str, Any]] = [] + self.loaded_checkpoint_manifest: Optional[Dict[str, Any]] = None + self.seed = getattr(self.config, "seed", getattr(self, "seed", 0)) + self._seed_initialized = False + + def _apply_lora_if_needed(self) -> None: + if hasattr(self.model, "_apply_lora") and not getattr(self.model, "_lora_applied", False): + self.model._apply_lora() + + def _optimizer_for_training(self, learning_rate: Optional[float] = None): + lr_schedule = optim.cosine_decay(self.learning_rate if learning_rate is None else learning_rate, self.iters) + return optim.AdamW(learning_rate=lr_schedule) + + def _primary_role_name(self) -> str: + return "policy" + + def _primary_role_weight_format(self) -> Any: + return "adapters.safetensors" + + def _primary_optimizer_name(self) -> str: + return self._primary_role_name() + + def _trainer_cursor_state(self) -> Dict[str, int]: + return {"dataset": int(self.dataset_cursor)} - # Determine number of layers in the model - num_layers = None - if hasattr(actual_model, 'layers'): - num_layers = len(actual_model.layers) - elif hasattr(actual_model, 'model') and hasattr(actual_model.model, 'layers'): - num_layers = len(actual_model.model.layers) - - # Save adapter_config.json - lora_config = {} - if hasattr(model, 'lora_config') and model.lora_config: - lora_config = model.lora_config.copy() - - r = lora_config.get('r', 16) - alpha = lora_config.get('lora_alpha', 16) - - adapter_config = { - "fine_tune_type": "lora", - "num_layers": num_layers, - "lora_parameters": { - "rank": r, - "scale": alpha / r, - "dropout": lora_config.get('lora_dropout', 0.0), + def _restore_trainer_cursors(self, cursors: Dict[str, Any]) -> None: + self.dataset_cursor = int(cursors.get("dataset", cursors.get("dataset_cursor", self.dataset_cursor))) + + def _sampling_config_payload(self) -> Dict[str, Any]: + return {} + + def _sampling_config_fingerprint(self) -> Optional[str]: + payload = self._sampling_config_payload() + if not payload: + return None + return _hash_payload(payload) + + def _validate_resume_sampling_fingerprint(self, trainer_state: Dict[str, Any]) -> None: + current_fingerprint = self._sampling_config_fingerprint() + saved_fingerprint = ( + trainer_state.get("trainer_state", {}).get("sampling_config_fingerprint") + or trainer_state.get("sampling_config_fingerprint") + ) + if current_fingerprint and saved_fingerprint and current_fingerprint != saved_fingerprint: + raise ValueError( + "Checkpoint sampling config fingerprint does not match the current trainer configuration." + ) + + def _seed_training_run(self) -> None: + if self._seed_initialized: + return + mx.random.seed(int(self.seed)) + self._seed_initialized = True + + @contextmanager + def _preserve_rng_state(self): + saved_state = [mx.array(state) for state in mx.random.state] + try: + yield + finally: + mx.random.state = [mx.array(state) for state in saved_state] + + def _normalize_metric_value(self, value: Any) -> Any: + if value is None: + return None + if hasattr(value, "item"): + value = value.item() + if isinstance(value, bool): + return bool(value) + if isinstance(value, int): + return int(value) + if isinstance(value, float): + return float(value) + return value + + def _record_metrics( + self, + namespace: str, + metrics: Dict[str, Any], + step: Optional[int] = None, + ) -> Dict[str, Any]: + normalized = { + (key if "/" in key else f"{namespace}/{key}"): self._normalize_metric_value(value) + for key, value in metrics.items() + if value is not None + } + if not normalized: + return {} + row = {"step": self.global_step if step is None else int(step)} + row.update(normalized) + self.metrics_history.append(row) + if hasattr(self, "output_dir"): + _save_jsonl(self._metrics_path(), self.metrics_history) + return row + + def _record_metric(self, **metrics: Any) -> None: + self._record_metrics("train", metrics) + + def _format_metric_summary( + self, + row: Dict[str, Any], + namespace: str = "train", + keys: Optional[List[str]] = None, + ) -> str: + preferred = keys or [ + "policy_loss", + "loss", + "value_loss", + "reward_loss", + "reward_mean", + "logprob_delta_per_token_mean", + "logprob_delta_mean", + "completion_length_mean", + "completion_length_max", + "eos_rate", + "truncation_rate", + "kl_to_reference_mean", + "rollout_generate_wall", + "reward_eval_wall", + "reference_score_wall", + "returns_wall", + "policy_update_wall", + "policy_update_steps", + "preference_win_rate", + ] + parts = [f"step={row['step']}"] + for key in preferred: + full_key = f"{namespace}/{key}" + if full_key not in row: + continue + value = row[full_key] + if isinstance(value, bool): + parts.append(f"{key}={value}") + elif isinstance(value, (int, float)): + parts.append(f"{key}={value:.4f}") + else: + parts.append(f"{key}={value}") + return " | ".join(parts) + + def _gather_runtime_cache_arrays( + self, + extra_arrays: Optional[Dict[str, mx.array]] = None, + ) -> Dict[str, mx.array]: + merged: Dict[str, mx.array] = {} + merged.update(getattr(self, "runtime_cache_arrays", {})) + if hasattr(self, "_extra_state_arrays"): + merged.update(getattr(self, "_extra_state_arrays")()) + if extra_arrays: + merged.update(extra_arrays) + return merged + + def _save_primary_role(self, checkpoint_dir: Optional[Path] = None) -> None: + policy_dir = self._role_dir("policy", checkpoint_dir) + _save_adapters_and_config(self.model, policy_dir) + if hasattr(self.model, "set_adapter_path"): + self.model.set_adapter_path(str(policy_dir)) + _write_json( + policy_dir / "role.json", + { + "role": "policy", + "weight_format": "adapters.safetensors", }, + ) + + def _load_primary_role(self, checkpoint_dir: Path) -> None: + policy_adapter_file = self._role_dir("policy", checkpoint_dir) / "adapters.safetensors" + if policy_adapter_file.exists(): + _load_parameter_tree(self.model, policy_adapter_file, strict=False) + if hasattr(self.model, "set_adapter_path"): + self.model.set_adapter_path(str(self._role_dir("policy", checkpoint_dir))) + + def _next_samples(self, samples: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + if not samples: + raise ValueError(f"{self.algorithm} training dataset is empty.") + + batch = [] + for _ in range(max(1, self.batch_size)): + batch.append(samples[self.dataset_cursor]) + self.dataset_cursor = (self.dataset_cursor + 1) % len(samples) + return batch + + def _next_rollout_samples(self, samples: List[Dict[str, Any]], count: int) -> List[Dict[str, Any]]: + if not samples: + raise ValueError(f"{self.algorithm} training dataset is empty.") + + batch = [] + for _ in range(max(1, count)): + batch.append(samples[self.dataset_cursor]) + self.dataset_cursor = (self.dataset_cursor + 1) % len(samples) + return batch + + def _observed_rollout_kl(self, rollout_batch: Optional[RolloutBatch]) -> Optional[float]: + if rollout_batch is None: + return None + if rollout_batch.reference_logprobs is None or rollout_batch.rollout_logprobs is None: + return None + kl_values = kl_against_reference( + rollout_batch.rollout_logprobs.astype(mx.float32), + rollout_batch.reference_logprobs.astype(mx.float32), + ) + return float(mx.mean(kl_values).item()) + + def _effective_kl_beta(self, rollout_batch: Optional[RolloutBatch] = None) -> float: + base_beta = float(getattr(self, "beta", 0.0)) + if getattr(self, "kl_penalty_mode", "kl") == "none": + return 0.0 + kl_target = getattr(self, "kl_target", None) + observed_kl = self._observed_rollout_kl(rollout_batch) + if kl_target is None or observed_kl is None: + return base_beta + target = max(float(kl_target), 1e-8) + scale = min(max(observed_kl / target, 0.0), 10.0) + return base_beta * scale + + def _checkpoint_dir(self, checkpoint_dir: Optional[Path] = None) -> Path: + return checkpoint_dir or self.output_dir + + def _manifest_path(self, checkpoint_dir: Optional[Path] = None) -> Path: + return self._checkpoint_dir(checkpoint_dir) / MANIFEST_FILE + + def _role_dir(self, role_name: str, checkpoint_dir: Optional[Path] = None) -> Path: + return self._checkpoint_dir(checkpoint_dir) / role_name + + def _optimizer_state_path( + self, + checkpoint_dir: Optional[Path] = None, + optimizer_name: Optional[str] = None, + ) -> Path: + if optimizer_name is None: + return self._checkpoint_dir(checkpoint_dir) / "optimizer" / "state.safetensors" + return self._checkpoint_dir(checkpoint_dir) / "optimizers" / optimizer_name / "state.safetensors" + + def _scheduler_state_path( + self, + checkpoint_dir: Optional[Path] = None, + optimizer_name: Optional[str] = None, + ) -> Path: + if optimizer_name is None: + return self._checkpoint_dir(checkpoint_dir) / "scheduler" / "state.json" + return self._checkpoint_dir(checkpoint_dir) / "schedulers" / optimizer_name / "state.json" + + def _trainer_state_path(self, checkpoint_dir: Optional[Path] = None) -> Path: + return self._checkpoint_dir(checkpoint_dir) / "trainer" / "state.json" + + def _trainer_rng_path(self, checkpoint_dir: Optional[Path] = None) -> Path: + return self._checkpoint_dir(checkpoint_dir) / "trainer" / "rng.safetensors" + + def _runtime_cache_path(self, checkpoint_dir: Optional[Path] = None) -> Path: + return self._checkpoint_dir(checkpoint_dir) / "runtime" / "cache.safetensors" + + def _metrics_path(self, checkpoint_dir: Optional[Path] = None) -> Path: + return self._checkpoint_dir(checkpoint_dir) / "metrics" / "history.jsonl" + + def _legacy_state_path(self, checkpoint_dir: Optional[Path] = None) -> Path: + return self._checkpoint_dir(checkpoint_dir) / STATE_FILE + + def _legacy_metadata_path(self, checkpoint_dir: Optional[Path] = None) -> Path: + return self._checkpoint_dir(checkpoint_dir) / METADATA_FILE + + def _legacy_reference_path(self, checkpoint_dir: Optional[Path] = None) -> Path: + return self._checkpoint_dir(checkpoint_dir) / REFERENCE_FILE + + def _legacy_reference_metadata_path(self, checkpoint_dir: Optional[Path] = None) -> Path: + return self._checkpoint_dir(checkpoint_dir) / REFERENCE_METADATA_FILE + + def _has_manifest_checkpoint(self, checkpoint_dir: Path) -> bool: + return self._manifest_path(checkpoint_dir).exists() + + def _has_legacy_checkpoint(self, checkpoint_dir: Path) -> bool: + return ( + (checkpoint_dir / "adapters" / "adapters.safetensors").exists() + or self._legacy_state_path(checkpoint_dir).exists() + or self._legacy_reference_path(checkpoint_dir).exists() + ) + + def _save_reference_policy(self, checkpoint_dir: Optional[Path] = None) -> None: + if self.reference_policy is None: + return + reference_dir = self._role_dir("reference", checkpoint_dir) + reference_dir.mkdir(parents=True, exist_ok=True) + _save_full_model_state(self.reference_policy.model, reference_dir / "weights.safetensors") + _write_json( + reference_dir / "metadata.json", + { + "source": self.reference_policy.source, + "metadata": self.reference_policy.metadata, + }, + ) + _write_json( + reference_dir / "role.json", + { + "role": "reference", + "weight_format": "weights.safetensors", + }, + ) + + def _save_optional_scalar_role( + self, + role_name: str, + role_model: Optional[Any], + checkpoint_dir: Optional[Path] = None, + ) -> None: + if role_model is None or role_name == self._primary_role_name(): + return + role_model.save_pretrained(str(self._role_dir(role_name, checkpoint_dir))) + + def _build_training_metadata(self) -> Dict[str, Any]: + config = self.config.to_dict() if hasattr(self.config, "to_dict") else dict(self.config) + trainer_state = { + "cursors": self._trainer_cursor_state(), + "sampling_config_fingerprint": self._sampling_config_fingerprint(), + "step_boundary": { + "completed_optimizer_step": int(self.global_step), + "checkpoint_authoritative": True, + }, + } + runtime_state = { + "cache_metadata": self.cache_metadata, + "runtime_cache_keys": sorted(self.runtime_cache_arrays.keys()), + "rng_state_path": "trainer/rng.safetensors", + } + return { + "algorithm": self.algorithm, + "config": config, + "global_step": self.global_step, + "dataset_cursor": self.dataset_cursor, + "cache_metadata": self.cache_metadata, + "seed": int(self.seed), + "sampling_config_fingerprint": trainer_state["sampling_config_fingerprint"], + "trainer_state": trainer_state, + "runtime_state": runtime_state, } - target_modules = lora_config.get('target_modules', []) - if target_modules: - short_to_full = { - 'q_proj': 'self_attn.q_proj', 'k_proj': 'self_attn.k_proj', - 'v_proj': 'self_attn.v_proj', 'o_proj': 'self_attn.o_proj', - 'gate_proj': 'mlp.gate_proj', 'up_proj': 'mlp.up_proj', - 'down_proj': 'mlp.down_proj', + def _build_scheduler_state(self, optimizer: Any, learning_rate: Optional[float] = None) -> Dict[str, Any]: + step_value = 0 + if optimizer is not None and getattr(optimizer, "state", None): + step = optimizer.state.get("step", 0) + step_value = int(step.item()) if hasattr(step, "item") else int(step) + return { + "name": "cosine_decay", + "initial_learning_rate": self.learning_rate if learning_rate is None else learning_rate, + "total_steps": self.iters, + "step": step_value, + } + + def _normalize_optimizers( + self, + optimizer: Any = None, + optimizers: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + if optimizers: + return optimizers + if optimizer is not None: + return {self._primary_optimizer_name(): optimizer} + return getattr(self, "optimizers", {}) + + def _optimizer_learning_rates(self) -> Dict[str, float]: + return {self._primary_optimizer_name(): self.learning_rate} + + def _build_manifest(self, optimizers: Dict[str, Any]) -> Dict[str, Any]: + reward_base = getattr(self.reward_model, "base_model", None) + value_base = getattr(self.value_model, "base_model", None) + primary_role = self._primary_role_name() + roles_present = [primary_role] + role_weight_formats = { + primary_role: self._primary_role_weight_format(), + } + reference_provenance = None + if self.reference_policy is not None: + roles_present.append("reference") + role_weight_formats["reference"] = "weights.safetensors" + reference_provenance = { + "source": self.reference_policy.source, + "metadata": self.reference_policy.metadata, + } + if self.reward_model is not None and primary_role != "reward_model": + roles_present.append("reward_model") + role_weight_formats["reward_model"] = { + "backbone": "weights.safetensors", + "head": "head.safetensors", + "adapters": "adapters.safetensors" + if hasattr(reward_base, "has_adapters") and reward_base.has_adapters() + else None, + } + if self.value_model is not None and primary_role != "value_model": + roles_present.append("value_model") + role_weight_formats["value_model"] = { + "backbone": "weights.safetensors", + "head": "head.safetensors", + "adapters": "adapters.safetensors" + if hasattr(value_base, "has_adapters") and value_base.has_adapters() + else None, } - adapter_config["lora_parameters"]["keys"] = [ - short_to_full.get(m, m) for m in target_modules - ] - config_path = adapter_path / "adapter_config.json" - with open(config_path, 'w') as f: - json.dump(adapter_config, f, indent=2) + trainer_state_locations = { + "optimizers": { + name: f"optimizers/{name}/state.safetensors" + for name in optimizers + }, + "schedulers": { + name: f"schedulers/{name}/state.json" + for name in optimizers + }, + "trainer": "trainer/state.json", + "rng": "trainer/rng.safetensors", + "runtime_cache": "runtime/cache.safetensors", + } + if len(optimizers) == 1: + trainer_state_locations["optimizer"] = "optimizer/state.safetensors" + trainer_state_locations["scheduler"] = "scheduler/state.json" - return True - except Exception as e: - print(f" ⚠ Could not save adapters: {e}") - return False + return { + "format_name": CHECKPOINT_FORMAT_NAME, + "format_version": CHECKPOINT_FORMAT_VERSION, + "algorithm": self.algorithm, + "roles_present": roles_present, + "mlx_tune_version": MLX_TUNE_VERSION, + "role_weight_formats": role_weight_formats, + "trainer_state_locations": trainer_state_locations, + "metrics_path": "metrics/history.jsonl", + "reference_provenance": reference_provenance, + } + + def save_state( + self, + optimizer: Any = None, + extra_arrays: Optional[Dict[str, mx.array]] = None, + optimizers: Optional[Dict[str, Any]] = None, + ) -> None: + checkpoint_dir = self._checkpoint_dir() + checkpoint_dir.mkdir(parents=True, exist_ok=True) + optimizer_map = self._normalize_optimizers(optimizer=optimizer, optimizers=optimizers) + if not optimizer_map: + raise ValueError("No optimizer state provided for checkpoint save.") + self.optimizers = optimizer_map + if len(optimizer_map) == 1: + self.optimizer = next(iter(optimizer_map.values())) + + self._save_primary_role(checkpoint_dir) + self._save_reference_policy(checkpoint_dir) + self._save_optional_scalar_role("reward_model", self.reward_model, checkpoint_dir) + self._save_optional_scalar_role("value_model", self.value_model, checkpoint_dir) + + learning_rates = self._optimizer_learning_rates() + for name, current_optimizer in optimizer_map.items(): + optimizer_path = self._optimizer_state_path(checkpoint_dir, name) + optimizer_path.parent.mkdir(parents=True, exist_ok=True) + mx.save_safetensors(str(optimizer_path), _flatten_prefixed_tree("optimizer", current_optimizer.state)) + if len(optimizer_map) == 1: + legacy_path = self._optimizer_state_path(checkpoint_dir) + legacy_path.parent.mkdir(parents=True, exist_ok=True) + mx.save_safetensors(str(legacy_path), _flatten_prefixed_tree("optimizer", current_optimizer.state)) + scheduler_state = self._build_scheduler_state( + current_optimizer, + learning_rate=learning_rates.get(name, self.learning_rate), + ) + _write_json(self._scheduler_state_path(checkpoint_dir, name), scheduler_state) + if len(optimizer_map) == 1: + _write_json(self._scheduler_state_path(checkpoint_dir), scheduler_state) + + rng_path = self._trainer_rng_path(checkpoint_dir) + rng_path.parent.mkdir(parents=True, exist_ok=True) + mx.save_safetensors(str(rng_path), _rng_state_to_dict()) + + runtime_arrays = self._gather_runtime_cache_arrays(extra_arrays) + self.runtime_cache_arrays = dict(runtime_arrays) + runtime_path = self._runtime_cache_path(checkpoint_dir) + if runtime_arrays: + runtime_path.parent.mkdir(parents=True, exist_ok=True) + mx.save_safetensors(str(runtime_path), runtime_arrays) + + _write_json(self._trainer_state_path(checkpoint_dir), self._build_training_metadata()) + _save_jsonl(self._metrics_path(checkpoint_dir), self.metrics_history) + _write_json(self._manifest_path(checkpoint_dir), self._build_manifest(optimizer_map)) + + def _ensure_reference_policy(self) -> None: + if not self.requires_reference_policy: + return + if self.reference_policy is None: + ref_model = getattr(self, "ref_model", None) + self.reference_policy = build_reference_policy(self.model, ref_model=ref_model, snapshot=True) + + def _load_reference_policy(self, checkpoint_dir: Path) -> None: + reference_path = self._role_dir("reference", checkpoint_dir) / "weights.safetensors" + if reference_path.exists(): + self.reference_policy = build_reference_policy(self.model, snapshot=True) + _load_parameter_tree(self.reference_policy.model, reference_path, strict=False) + _actual_model(self.reference_policy.model).freeze() + mx.eval(_actual_model(self.reference_policy.model).parameters()) + metadata = _read_json(self._role_dir("reference", checkpoint_dir) / "metadata.json") + if metadata: + self.reference_policy.source = metadata.get("source", self.reference_policy.source) + self.reference_policy.metadata = metadata.get("metadata", self.reference_policy.metadata) + else: + self._ensure_reference_policy() + + def _load_optional_scalar_role(self, checkpoint_dir: Path, role_name: str) -> Optional[Any]: + role_dir = self._role_dir(role_name, checkpoint_dir) + if not role_dir.exists() or role_name == self._primary_role_name(): + return None + with open(role_dir / "head_config.json") as handle: + config = json.load(handle) + if role_name == "reward_model": + if self.reward_model is None: + self.reward_model = build_reward_model( + self.model, + pooling=config.get("pooling", "last_token"), + target=config.get("target", "completion"), + ) + self.reward_model.load_pretrained(str(role_dir)) + return self.reward_model + if role_name == "value_model": + if self.value_model is None: + self.value_model = build_value_model( + self.model, + pooling=config.get("pooling", "last_token"), + target=config.get("target", "completion"), + ) + self.value_model.load_pretrained(str(role_dir)) + return self.value_model + return None + + def _apply_scalar_role_state(self, role_name: str, role_state: Any) -> Optional[Any]: + if role_name == "reward_model": + if self.reward_model is None: + self.reward_model = build_reward_model( + self.model, + pooling=role_state.head_config.get("pooling", "last_token"), + target=role_state.head_config.get("target", "completion"), + ) + role_model = self.reward_model + elif role_name == "value_model": + if self.value_model is None: + self.value_model = build_value_model( + self.model, + pooling=role_state.head_config.get("pooling", "last_token"), + target=role_state.head_config.get("target", "completion"), + ) + role_model = self.value_model + else: + return None + + _load_flat_parameter_tree(role_model.base_model, role_state.parameter_state, strict=False) + if role_state.adapter_state: + _load_flat_parameter_tree(role_model.base_model, role_state.adapter_state, strict=False) + if role_state.head_state: + role_model.head.update(tree_unflatten(list(role_state.head_state.items())), strict=False) + mx.eval(role_model.head.parameters()) + if role_state.head_config: + role_model.head_config.update(role_state.head_config) + role_model.pooling = role_model.head_config.get("pooling", role_model.pooling) + role_model.target = role_model.head_config.get("target", role_model.target) + return role_model + + def _apply_checkpoint_bundle( + self, + bundle: RLCheckpointBundle, + optimizer: Any = None, + optimizers: Optional[Dict[str, Any]] = None, + ) -> Dict[str, mx.array]: + self.loaded_checkpoint_manifest = bundle.manifest + optimizer_map = self._normalize_optimizers(optimizer=optimizer, optimizers=optimizers) + if optimizer_map: + for name, current_optimizer in optimizer_map.items(): + state_tree = bundle.optimizer_state_trees.get(name) + if state_tree is None and len(bundle.optimizer_state_trees) == 1: + state_tree = next(iter(bundle.optimizer_state_trees.values())) + if state_tree: + current_optimizer.state = state_tree + self.optimizers = optimizer_map + if len(optimizer_map) == 1: + self.optimizer = next(iter(optimizer_map.values())) + + if bundle.rng_state: + _restore_rng_state(bundle.rng_state) + self._seed_initialized = True + + trainer_state = bundle.trainer_state + self._validate_resume_sampling_fingerprint(trainer_state) + self.global_step = trainer_state.get("global_step", 0) + self.dataset_cursor = trainer_state.get("dataset_cursor", 0) + self.cache_metadata = trainer_state.get("cache_metadata", {}) + self.seed = int(trainer_state.get("seed", self.seed)) + self._restore_trainer_cursors( + trainer_state.get("trainer_state", {}).get("cursors", {}) + ) + self.metrics_history = list(bundle.metrics_history) + self.runtime_cache_arrays = dict(bundle.runtime_cache) + + for role_name, role_state in bundle.restored_roles.items(): + if role_name == "policy": + _load_flat_parameter_tree(self.model, role_state.parameter_state, strict=False) + elif role_name == "reference": + self.reference_policy = build_reference_policy(self.model, snapshot=True) + _load_flat_parameter_tree(self.reference_policy.model, role_state.parameter_state, strict=False) + reference_actual = _actual_model(self.reference_policy.model) + if hasattr(reference_actual, "freeze"): + reference_actual.freeze() + mx.eval(reference_actual.parameters()) + if role_state.metadata: + self.reference_policy.source = role_state.metadata.get("source", self.reference_policy.source) + self.reference_policy.metadata = role_state.metadata.get("metadata", self.reference_policy.metadata) + elif role_name in {"reward_model", "value_model"}: + self._apply_scalar_role_state(role_name, role_state) + + if self.requires_reference_policy and self.reference_policy is None: + self._ensure_reference_policy() + return bundle.runtime_cache + + def _restore_optimizer_states( + self, + checkpoint_dir: Path, + optimizer_map: Dict[str, Any], + manifest: Optional[Dict[str, Any]] = None, + ) -> None: + trainer_locations = {} if manifest is None else manifest.get("trainer_state_locations", {}) + optimizer_locations = trainer_locations.get("optimizers") or {} + for name, current_optimizer in optimizer_map.items(): + optimizer_path = self._optimizer_state_path(checkpoint_dir, name) + if name in optimizer_locations: + optimizer_path = checkpoint_dir / optimizer_locations[name] + elif not optimizer_path.exists() and len(optimizer_map) == 1: + optimizer_path = self._optimizer_state_path(checkpoint_dir) + if not optimizer_path.exists(): + continue + optimizer_state = _extract_prefixed_tree("optimizer", mx.load(str(optimizer_path))) + if optimizer_state: + current_optimizer.state = optimizer_state + + def _load_manifest_state( + self, + optimizer: Any = None, + checkpoint_dir: Optional[Path] = None, + optimizers: Optional[Dict[str, Any]] = None, + ) -> Dict[str, mx.array]: + checkpoint_dir = Path(checkpoint_dir) + manifest = _read_json(self._manifest_path(checkpoint_dir)) + if not manifest: + raise FileNotFoundError(f"Checkpoint manifest not found under {checkpoint_dir}") + self.loaded_checkpoint_manifest = manifest + + self._load_primary_role(checkpoint_dir) + + optimizer_map = self._normalize_optimizers(optimizer=optimizer, optimizers=optimizers) + self._restore_optimizer_states(checkpoint_dir, optimizer_map, manifest=manifest) + self.optimizers = optimizer_map + if len(optimizer_map) == 1: + self.optimizer = next(iter(optimizer_map.values())) + + rng_path = self._trainer_rng_path(checkpoint_dir) + if rng_path.exists(): + _restore_rng_state(mx.load(str(rng_path))) + self._seed_initialized = True + + metadata = _read_json(self._trainer_state_path(checkpoint_dir)) + self._validate_resume_sampling_fingerprint(metadata) + self.global_step = metadata.get("global_step", 0) + self.dataset_cursor = metadata.get("dataset_cursor", 0) + self.cache_metadata = metadata.get("cache_metadata", {}) + self.seed = int(metadata.get("seed", self.seed)) + self._restore_trainer_cursors(metadata.get("trainer_state", {}).get("cursors", {})) + self.metrics_history = _load_jsonl(self._metrics_path(checkpoint_dir)) + + self._load_reference_policy(checkpoint_dir) + self._load_optional_scalar_role(checkpoint_dir, "reward_model") + self._load_optional_scalar_role(checkpoint_dir, "value_model") + + runtime_path = self._runtime_cache_path(checkpoint_dir) + self.runtime_cache_arrays = dict(mx.load(str(runtime_path))) if runtime_path.exists() else {} + return self.runtime_cache_arrays + + def _load_legacy_state( + self, + optimizer: Any = None, + checkpoint_dir: Optional[Path] = None, + optimizers: Optional[Dict[str, Any]] = None, + ) -> Dict[str, mx.array]: + checkpoint_dir = Path(checkpoint_dir) + adapter_file = checkpoint_dir / "adapters" / "adapters.safetensors" + if adapter_file.exists(): + _load_parameter_tree(self.model, adapter_file, strict=False) + if hasattr(self.model, "set_adapter_path"): + self.model.set_adapter_path(str(checkpoint_dir / "adapters")) + + state_path = self._legacy_state_path(checkpoint_dir) + metadata_path = self._legacy_metadata_path(checkpoint_dir) + if not state_path.exists() or not metadata_path.exists(): + raise FileNotFoundError(f"Checkpoint state not found under {checkpoint_dir}") + + metadata = _read_json(metadata_path) + flat_state = mx.load(str(state_path)) + optimizer_map = self._normalize_optimizers(optimizer=optimizer, optimizers=optimizers) + if optimizer_map: + first_optimizer = next(iter(optimizer_map.values())) + optimizer_state = _extract_prefixed_tree("optimizer", flat_state) + if optimizer_state: + first_optimizer.state = optimizer_state + self.optimizers = optimizer_map + if len(optimizer_map) == 1: + self.optimizer = first_optimizer + _restore_rng_state(flat_state) + self._seed_initialized = True + + self.global_step = metadata.get("global_step", 0) + self.dataset_cursor = metadata.get("dataset_cursor", 0) + self.cache_metadata = metadata.get("cache_metadata", {}) + self.seed = int(metadata.get("seed", self.seed)) + self.metrics_history = [] + self.runtime_cache_arrays = { + key: value + for key, value in flat_state.items() + if not key.startswith("optimizer.") and not key.startswith("rng.") + } + reference_path = self._legacy_reference_path(checkpoint_dir) + if reference_path.exists(): + self.reference_policy = build_reference_policy(self.model, snapshot=True) + _load_parameter_tree(self.reference_policy.model, reference_path, strict=False) + reference_metadata = _read_json(self._legacy_reference_metadata_path(checkpoint_dir)) + if reference_metadata: + self.reference_policy.source = reference_metadata.get("source", self.reference_policy.source) + self.reference_policy.metadata = reference_metadata.get("metadata", self.reference_policy.metadata) + else: + self._ensure_reference_policy() + return flat_state -class DPOConfig: - """ - Configuration for Direct Preference Optimization training. + def load_state( + self, + optimizer: Any = None, + checkpoint_dir: Optional[Path] = None, + optimizers: Optional[Dict[str, Any]] = None, + ) -> Dict[str, mx.array]: + bundle = resume_from_checkpoint(Path(checkpoint_dir)) + return self._apply_checkpoint_bundle(bundle, optimizer=optimizer, optimizers=optimizers) - Compatible with TRL's DPOConfig. - Example: - >>> config = DPOConfig( - ... beta=0.1, - ... learning_rate=5e-7, - ... max_steps=100, - ... ) - """ +def prepare_reward_dataset(dataset: Any) -> List[Dict[str, Any]]: + return public_prepare_reward_dataset(dataset) + +def _tokenize_reward_scalar_sample( + tokenizer: Any, + sample: Dict[str, Any], + max_seq_length: int, +) -> Dict[str, Any]: + prompt = sample.get("prompt", "") + response = sample.get("response", "") + sequence_ids = tokenizer.encode(prompt + response)[:max_seq_length] + prompt_ids = tokenizer.encode(prompt) if prompt else [] + prompt_length = min(len(prompt_ids), len(sequence_ids)) + completion_length = max(len(sequence_ids) - prompt_length, 0) + return { + "ids": sequence_ids, + "length": len(sequence_ids), + "prompt_length": prompt_length, + "completion_length": completion_length, + "score": float(sample.get("score", 0.0)), + } + + +def _tokenize_reward_pairwise_sample( + tokenizer: Any, + sample: Dict[str, Any], + max_seq_length: int, +) -> Dict[str, Any]: + prompt = sample.get("prompt", "") + prompt_ids = tokenizer.encode(prompt) if prompt else [] + chosen_ids = tokenizer.encode(prompt + sample["chosen"])[:max_seq_length] + rejected_ids = tokenizer.encode(prompt + sample["rejected"])[:max_seq_length] + chosen_prompt_length = min(len(prompt_ids), len(chosen_ids)) + rejected_prompt_length = min(len(prompt_ids), len(rejected_ids)) + return { + "chosen_ids": chosen_ids, + "rejected_ids": rejected_ids, + "chosen_length": len(chosen_ids), + "rejected_length": len(rejected_ids), + "chosen_prompt_length": chosen_prompt_length, + "rejected_prompt_length": rejected_prompt_length, + "chosen_completion_length": max(len(chosen_ids) - chosen_prompt_length, 0), + "rejected_completion_length": max(len(rejected_ids) - rejected_prompt_length, 0), + } + + +def score_reward_model( + reward_model: RewardModel, + samples: Any, + batch_size: int = 8, + tokenizer: Optional[Any] = None, + max_seq_length: int = 2048, +) -> List[float]: + scoring_tokenizer = tokenizer or getattr(reward_model, "tokenizer", None) + if scoring_tokenizer is None: + raise ValueError("score_reward_model requires a tokenizer on the reward model or as an argument.") + + normalized = list(prepare_rl_dataset(samples, mode="reward_scalar", tokenizer=scoring_tokenizer)) + if any(sample.get("type") != "scalar" for sample in normalized): + raise ValueError("score_reward_model only supports scalar reward samples.") + + tokenized = [ + _tokenize_reward_scalar_sample(scoring_tokenizer, sample, max_seq_length) + for sample in normalized + ] + scores: List[float] = [] + pad_id = _pad_token_id(scoring_tokenizer) + for start in range(0, len(tokenized), max(1, batch_size)): + chunk = tokenized[start:start + max(1, batch_size)] + input_ids, lengths = _pad_sequences([sample["ids"] for sample in chunk], pad_id) + chunk_scores = reward_model.score( + input_ids, + sequence_lengths=lengths, + prompt_lengths=mx.array([sample["prompt_length"] for sample in chunk]), + completion_lengths=mx.array([sample["completion_length"] for sample in chunk]), + ) + scores.extend(float(value.item()) for value in chunk_scores) + return scores + + +class RLConfigBase: + _NON_SERIALIZED_FIELDS: set[str] = set() + + @staticmethod + def _pop_alias(kwargs: Dict[str, Any], *names: str, default: Any = None) -> Any: + found = [name for name in names if name in kwargs] + if not found: + return default + value = kwargs.pop(found[0]) + for name in found[1:]: + kwargs.pop(name) + return value + + @staticmethod + def _normalize_reward_sources( + reward_sources: Optional[Any], + reward_fn: Optional[Any], + reward_model: Optional[Any], + ) -> List[Any]: + resolved: List[Any] = [] + if reward_sources is not None: + if isinstance(reward_sources, (list, tuple)): + resolved.extend(list(reward_sources)) + else: + resolved.append(reward_sources) + if reward_model is not None: + resolved.append({"name": "reward_model", "source": reward_model}) + if reward_fn is not None: + resolved.append({"name": "reward_fn", "source": reward_fn}) + return resolved + + def _set_remaining(self, kwargs: Dict[str, Any]) -> None: + for key, value in kwargs.items(): + setattr(self, key, value) + self._validate() + + def _validate_choice(self, field: str, value: Any, choices: set[str]) -> None: + if value not in choices: + raise ValueError(f"Unsupported {field} '{value}'. Supported values: {sorted(choices)}") + + def _validate(self) -> None: + reward_source = getattr(self, "reward_source", None) + if reward_source is not None: + self._validate_choice("reward_source", reward_source, {"auto", "online", "offline", "hybrid"}) + advantage_estimator = getattr(self, "advantage_estimator", None) + if advantage_estimator is not None: + self._validate_choice("advantage_estimator", advantage_estimator, {"group_zscore", "rloo", "gae"}) + kl_penalty_mode = getattr(self, "kl_penalty_mode", None) + if kl_penalty_mode is not None: + self._validate_choice("kl_penalty_mode", kl_penalty_mode, {"kl", "none"}) + + def to_dict(self) -> Dict[str, Any]: + payload: Dict[str, Any] = {} + for key, value in self.__dict__.items(): + if key.startswith("_") or key in self._NON_SERIALIZED_FIELDS: + continue + if callable(value): + continue + if hasattr(value, "parameters") or hasattr(value, "score") or hasattr(value, "predict"): + continue + payload[key] = value + return payload + + +class RewardConfig(RLConfigBase): def __init__( self, - # DPO-specific - beta: float = 0.1, # KL penalty coefficient - loss_type: str = "sigmoid", # sigmoid, hinge, ipo, kto_pair + output_dir: str = "./reward_outputs", + learning_rate: float = 5e-6, + per_device_train_batch_size: int = 2, + num_train_epochs: int = 1, + max_steps: int = -1, + logging_steps: int = 10, + save_steps: int = 100, + max_seq_length: int = 2048, + pairwise_margin: float = 0.0, + regression_loss_type: str = "mse", + dataset_mode: Optional[str] = None, + chat_template: Optional[Any] = None, + auto_detect_dataset: bool = True, + **kwargs, + ): + self.output_dir = output_dir + self.learning_rate = learning_rate + self.per_device_train_batch_size = per_device_train_batch_size + self.num_train_epochs = num_train_epochs + self.max_steps = max_steps + self.logging_steps = logging_steps + self.save_steps = save_steps + self.max_seq_length = max_seq_length + self.pairwise_margin = pairwise_margin + self.regression_loss_type = regression_loss_type + self.dataset_mode = dataset_mode + self.chat_template = chat_template + self.auto_detect_dataset = auto_detect_dataset + self._set_remaining(kwargs) + + +class DPOConfig(RLConfigBase): + def __init__( + self, + beta: float = 0.1, + loss_type: str = "sigmoid", label_smoothing: float = 0.0, - # Training args output_dir: str = "./dpo_outputs", learning_rate: float = 5e-7, per_device_train_batch_size: int = 2, @@ -137,7 +1167,10 @@ def __init__( save_steps: int = 100, max_seq_length: int = 2048, max_prompt_length: int = 512, - **kwargs + dataset_mode: Optional[str] = None, + chat_template: Optional[Any] = None, + auto_detect_dataset: bool = True, + **kwargs, ): self.beta = beta self.loss_type = loss_type @@ -153,34 +1186,16 @@ def __init__( self.save_steps = save_steps self.max_seq_length = max_seq_length self.max_prompt_length = max_prompt_length + self.dataset_mode = dataset_mode + self.chat_template = chat_template + self.auto_detect_dataset = auto_detect_dataset + self._set_remaining(kwargs) - for key, value in kwargs.items(): - setattr(self, key, value) - - def to_dict(self) -> Dict[str, Any]: - return {k: v for k, v in self.__dict__.items() if not k.startswith('_')} - - -class ORPOConfig: - """ - Configuration for Odds Ratio Preference Optimization training. - - ORPO combines SFT and preference learning into a single step, - making it simpler and more efficient than traditional RLHF. - - Example: - >>> config = ORPOConfig( - ... beta=0.1, - ... learning_rate=8e-6, - ... max_steps=1000, - ... ) - """ +class ORPOConfig(RLConfigBase): def __init__( self, - # ORPO-specific - beta: float = 0.1, # Odds ratio coefficient - # Training args + beta: float = 0.1, output_dir: str = "./orpo_outputs", learning_rate: float = 8e-6, per_device_train_batch_size: int = 2, @@ -192,7 +1207,10 @@ def __init__( save_steps: int = 100, max_seq_length: int = 2048, max_prompt_length: int = 512, - **kwargs + dataset_mode: Optional[str] = None, + chat_template: Optional[Any] = None, + auto_detect_dataset: bool = True, + **kwargs, ): self.beta = beta self.output_dir = output_dir @@ -206,49 +1224,29 @@ def __init__( self.save_steps = save_steps self.max_seq_length = max_seq_length self.max_prompt_length = max_prompt_length + self.dataset_mode = dataset_mode + self.chat_template = chat_template + self.auto_detect_dataset = auto_detect_dataset + self._set_remaining(kwargs) - for key, value in kwargs.items(): - setattr(self, key, value) - - def to_dict(self) -> Dict[str, Any]: - return {k: v for k, v in self.__dict__.items() if not k.startswith('_')} - -class GRPOConfig: - """ - Configuration for Group Relative Policy Optimization training. - - GRPO is used by DeepSeek to train their R1 reasoning models. - It replaces the value model with group statistics and uses custom - reward functions. - - Supports loss types: - - 'grpo': Standard GRPO - - 'dr_grpo': Dr. GRPO (distilled) - - 'dapo': DAPO variant - - 'bnpo': BNPO variant - - Example: - >>> config = GRPOConfig( - ... loss_type='grpo', - ... num_generations=4, - ... learning_rate=1e-6, - ... ) - """ +class GRPOConfig(RLConfigBase): + _NON_SERIALIZED_FIELDS = {"reward_fn", "reward_model", "value_model"} def __init__( self, - # GRPO-specific - loss_type: str = "grpo", # grpo, dr_grpo, dapo, bnpo - beta: float = 0.04, # KL coefficient - num_generations: int = 4, # Number of generations per prompt + loss_type: str = "grpo", + advantage_mode: str = "group_zscore", + beta: float = 0.04, + num_generations: int = 4, temperature: float = 0.7, max_completion_length: int = 512, - # Reward function (custom callable) reward_fn: Optional[Callable] = None, - # Training args + reward_model: Optional[Any] = None, + value_model: Optional[Any] = None, output_dir: str = "./grpo_outputs", learning_rate: float = 1e-6, + seed: int = 0, per_device_train_batch_size: int = 1, gradient_accumulation_steps: int = 8, num_train_epochs: int = 1, @@ -257,16 +1255,64 @@ def __init__( logging_steps: int = 1, save_steps: int = 100, max_seq_length: int = 2048, - **kwargs + clip_epsilon: float = 0.2, + epsilon_low: Optional[float] = None, + epsilon_high: Optional[float] = None, + rollout_batch_size: Optional[int] = None, + scale_rewards: Optional[bool] = None, + reward_normalization: str = "none", + mask_truncated_completions: bool = False, + minibatch_reuse_steps: int = 1, + entropy_bonus: float = 0.0, + advantage_estimator: Optional[str] = None, + reward_source: str = "auto", + reward_sources: Optional[Any] = None, + kl_target: Optional[float] = None, + kl_penalty_mode: str = "kl", + eval_steps: Optional[int] = None, + eval_num_batches: Optional[int] = None, + eval_num_generations: Optional[int] = None, + generation_batch_size: Optional[int] = None, + score_chunk_size: Optional[int] = None, + precompute_reference_scores: bool = False, + dataset_mode: Optional[str] = None, + chat_template: Optional[Any] = None, + auto_detect_dataset: bool = True, + **kwargs, ): self.loss_type = loss_type - self.beta = beta - self.num_generations = num_generations + resolved_advantage = self._pop_alias( + kwargs, + "baseline_mode", + "advantage_estimator", + default=advantage_estimator or advantage_mode, + ) + self.advantage_estimator = resolved_advantage + self.advantage_mode = resolved_advantage + self.kl_beta = self._pop_alias(kwargs, "kl_beta", default=beta) + self.beta = self.kl_beta + self.kl_target = kl_target + self.kl_penalty_mode = kl_penalty_mode + self.num_generations = self._pop_alias(kwargs, "generations_per_prompt", default=num_generations) self.temperature = temperature self.max_completion_length = max_completion_length self.reward_fn = reward_fn + self.reward_model = reward_model + self.value_model = value_model + self.reward_sources = self._normalize_reward_sources( + self._pop_alias(kwargs, "reward_sources", default=reward_sources), + reward_fn, + reward_model, + ) + self.reward_source = reward_source + self.reward_normalization = reward_normalization + self.mask_truncated_completions = mask_truncated_completions + self.minibatch_reuse_steps = minibatch_reuse_steps + self.entropy_bonus = entropy_bonus + self.rollout_batch_size = rollout_batch_size self.output_dir = output_dir self.learning_rate = learning_rate + self.seed = seed self.per_device_train_batch_size = per_device_train_batch_size self.gradient_accumulation_steps = gradient_accumulation_steps self.num_train_epochs = num_train_epochs @@ -275,683 +1321,2813 @@ def __init__( self.logging_steps = logging_steps self.save_steps = save_steps self.max_seq_length = max_seq_length + self.clip_epsilon = clip_epsilon + self.epsilon_low = clip_epsilon if epsilon_low is None else epsilon_low + self.epsilon_high = clip_epsilon if epsilon_high is None else epsilon_high + self.scale_rewards = scale_rewards + self.eval_steps = eval_steps + self.eval_num_batches = eval_num_batches + self.eval_num_generations = eval_num_generations + self.generation_batch_size = generation_batch_size + self.score_chunk_size = score_chunk_size + self.precompute_reference_scores = precompute_reference_scores + self.dataset_mode = dataset_mode + self.chat_template = chat_template + self.auto_detect_dataset = auto_detect_dataset + if self.loss_type == "dapo": + self.mask_truncated_completions = True + self.epsilon_high = 0.28 if epsilon_high is None else epsilon_high + if self.scale_rewards is None: + self.scale_rewards = self.loss_type != "dr_grpo" + self._set_remaining(kwargs) + + +class PPOConfig(RLConfigBase): + _NON_SERIALIZED_FIELDS = {"reward_fn", "reward_model", "value_model"} - for key, value in kwargs.items(): - setattr(self, key, value) + def __init__( + self, + output_dir: str = "./ppo_outputs", + learning_rate: float = 1e-6, + seed: int = 0, + value_learning_rate: Optional[float] = None, + per_device_train_batch_size: int = 1, + num_train_epochs: int = 1, + max_steps: int = -1, + logging_steps: int = 1, + save_steps: int = 100, + max_seq_length: int = 2048, + max_completion_length: int = 256, + num_generations: int = 4, + ppo_epochs: int = 2, + temperature: float = 0.7, + clip_epsilon: float = 0.2, + beta: float = 0.0, + gamma: float = 1.0, + gae_lambda: float = 1.0, + reward_fn: Optional[Callable] = None, + reward_model: Optional[Any] = None, + value_model: Optional[Any] = None, + normalize_advantages: bool = True, + rollout_batch_size: Optional[int] = None, + reward_normalization: str = "none", + mask_truncated_completions: bool = False, + minibatch_reuse_steps: Optional[int] = None, + entropy_bonus: float = 0.0, + advantage_estimator: str = "gae", + reward_source: str = "auto", + reward_sources: Optional[Any] = None, + kl_target: Optional[float] = None, + kl_penalty_mode: str = "kl", + eval_steps: Optional[int] = None, + eval_num_batches: Optional[int] = None, + eval_num_generations: Optional[int] = None, + generation_batch_size: Optional[int] = None, + score_chunk_size: Optional[int] = None, + precompute_reference_scores: bool = False, + dataset_mode: Optional[str] = None, + chat_template: Optional[Any] = None, + auto_detect_dataset: bool = True, + **kwargs, + ): + reuse_steps = ppo_epochs if minibatch_reuse_steps is None else minibatch_reuse_steps + self.output_dir = output_dir + self.learning_rate = learning_rate + self.seed = seed + self.value_learning_rate = learning_rate if value_learning_rate is None else value_learning_rate + self.per_device_train_batch_size = per_device_train_batch_size + self.num_train_epochs = num_train_epochs + self.max_steps = max_steps + self.logging_steps = logging_steps + self.save_steps = save_steps + self.max_seq_length = max_seq_length + self.max_completion_length = max_completion_length + self.num_generations = self._pop_alias(kwargs, "generations_per_prompt", default=num_generations) + self.minibatch_reuse_steps = reuse_steps + self.ppo_epochs = reuse_steps + self.temperature = temperature + self.clip_epsilon = clip_epsilon + self.kl_beta = self._pop_alias(kwargs, "kl_beta", default=beta) + self.beta = self.kl_beta + self.kl_target = kl_target + self.kl_penalty_mode = kl_penalty_mode + self.gamma = gamma + self.gae_lambda = gae_lambda + self.reward_fn = reward_fn + self.reward_model = reward_model + self.value_model = value_model + self.reward_sources = self._normalize_reward_sources( + self._pop_alias(kwargs, "reward_sources", default=reward_sources), + reward_fn, + reward_model, + ) + self.reward_source = reward_source + self.reward_normalization = reward_normalization + self.mask_truncated_completions = mask_truncated_completions + self.entropy_bonus = entropy_bonus + self.advantage_estimator = advantage_estimator + self.rollout_batch_size = rollout_batch_size + self.normalize_advantages = normalize_advantages + self.eval_steps = eval_steps + self.eval_num_batches = eval_num_batches + self.eval_num_generations = eval_num_generations + self.generation_batch_size = generation_batch_size + self.score_chunk_size = score_chunk_size + self.precompute_reference_scores = precompute_reference_scores + self.dataset_mode = dataset_mode + self.chat_template = chat_template + self.auto_detect_dataset = auto_detect_dataset + self._set_remaining(kwargs) + + +class OnlineDPOConfig(RLConfigBase): + _NON_SERIALIZED_FIELDS = {"reward_fn", "reward_model"} - def to_dict(self) -> Dict[str, Any]: - return {k: v for k, v in self.__dict__.items() - if not k.startswith('_') and k != 'reward_fn'} + def __init__( + self, + beta: float = 0.1, + label_smoothing: float = 0.0, + output_dir: str = "./online_dpo_outputs", + learning_rate: float = 5e-7, + seed: int = 0, + per_device_train_batch_size: int = 1, + num_train_epochs: int = 1, + max_steps: int = -1, + logging_steps: int = 1, + save_steps: int = 100, + max_seq_length: int = 2048, + max_completion_length: int = 256, + num_generations: int = 4, + temperature: float = 0.7, + reward_fn: Optional[Callable] = None, + reward_model: Optional[Any] = None, + rollout_batch_size: Optional[int] = None, + reward_normalization: str = "none", + mask_truncated_completions: bool = False, + minibatch_reuse_steps: int = 1, + entropy_bonus: float = 0.0, + reward_source: str = "auto", + reward_sources: Optional[Any] = None, + kl_target: Optional[float] = None, + kl_penalty_mode: str = "kl", + eval_steps: Optional[int] = None, + eval_num_batches: Optional[int] = None, + eval_num_generations: Optional[int] = None, + generation_batch_size: Optional[int] = None, + score_chunk_size: Optional[int] = None, + precompute_reference_scores: bool = False, + dataset_mode: Optional[str] = None, + chat_template: Optional[Any] = None, + auto_detect_dataset: bool = True, + **kwargs, + ): + self.kl_beta = self._pop_alias(kwargs, "kl_beta", default=beta) + self.beta = self.kl_beta + self.kl_target = kl_target + self.kl_penalty_mode = kl_penalty_mode + self.label_smoothing = label_smoothing + self.output_dir = output_dir + self.learning_rate = learning_rate + self.seed = seed + self.per_device_train_batch_size = per_device_train_batch_size + self.num_train_epochs = num_train_epochs + self.max_steps = max_steps + self.logging_steps = logging_steps + self.save_steps = save_steps + self.max_seq_length = max_seq_length + self.max_completion_length = max_completion_length + self.num_generations = self._pop_alias(kwargs, "generations_per_prompt", default=num_generations) + self.temperature = temperature + self.reward_fn = reward_fn + self.reward_model = reward_model + self.reward_sources = self._normalize_reward_sources( + self._pop_alias(kwargs, "reward_sources", default=reward_sources), + reward_fn, + reward_model, + ) + self.reward_source = reward_source + self.reward_normalization = reward_normalization + self.mask_truncated_completions = mask_truncated_completions + self.minibatch_reuse_steps = minibatch_reuse_steps + self.entropy_bonus = entropy_bonus + self.rollout_batch_size = rollout_batch_size + self.eval_steps = eval_steps + self.eval_num_batches = eval_num_batches + self.eval_num_generations = eval_num_generations + self.generation_batch_size = generation_batch_size + self.score_chunk_size = score_chunk_size + self.precompute_reference_scores = precompute_reference_scores + self.dataset_mode = dataset_mode + self.chat_template = chat_template + self.auto_detect_dataset = auto_detect_dataset + self._set_remaining(kwargs) + + +class KTOConfig(RLConfigBase): + def __init__( + self, + beta: float = 0.1, + output_dir: str = "./kto_outputs", + learning_rate: float = 5e-7, + per_device_train_batch_size: int = 1, + num_train_epochs: int = 1, + max_steps: int = 100, + logging_steps: int = 10, + save_steps: int = 100, + max_seq_length: int = 2048, + dataset_mode: Optional[str] = None, + chat_template: Optional[Any] = None, + auto_detect_dataset: bool = True, + **kwargs, + ): + self.beta = beta + self.output_dir = output_dir + self.learning_rate = learning_rate + self.per_device_train_batch_size = per_device_train_batch_size + self.num_train_epochs = num_train_epochs + self.max_steps = max_steps + self.logging_steps = logging_steps + self.save_steps = save_steps + self.max_seq_length = max_seq_length + self.dataset_mode = dataset_mode + self.chat_template = chat_template + self.auto_detect_dataset = auto_detect_dataset + self._set_remaining(kwargs) -class DPOTrainer: - """ - Direct Preference Optimization Trainer. - - DPO trains models on preference data (chosen vs rejected responses) - without requiring a separate reward model. - - Compatible with TRL's DPOTrainer API. - Now with PROPER DPO loss implementation! - - Example: - >>> from mlx_tune import FastLanguageModel, DPOTrainer, DPOConfig - >>> - >>> model, tokenizer = FastLanguageModel.from_pretrained(...) - >>> model = FastLanguageModel.get_peft_model(model, r=16) - >>> - >>> # Preference dataset with chosen/rejected pairs - >>> dataset = [ - ... {"prompt": "...", "chosen": "...", "rejected": "..."}, - ... ] - >>> - >>> trainer = DPOTrainer( - ... model=model, - ... ref_model=None, # Uses stop_gradient by default - ... train_dataset=dataset, - ... tokenizer=tokenizer, - ... args=DPOConfig(beta=0.1), - ... ) - >>> trainer.train() - """ +class SimPOConfig(RLConfigBase): + def __init__( + self, + gamma: float = 0.5, + beta: float = 2.0, + output_dir: str = "./simpo_outputs", + learning_rate: float = 5e-7, + per_device_train_batch_size: int = 1, + num_train_epochs: int = 1, + max_steps: int = 100, + logging_steps: int = 10, + save_steps: int = 100, + max_seq_length: int = 2048, + dataset_mode: Optional[str] = None, + chat_template: Optional[Any] = None, + auto_detect_dataset: bool = True, + **kwargs, + ): + self.gamma = gamma + self.beta = beta + self.output_dir = output_dir + self.learning_rate = learning_rate + self.per_device_train_batch_size = per_device_train_batch_size + self.num_train_epochs = num_train_epochs + self.max_steps = max_steps + self.logging_steps = logging_steps + self.save_steps = save_steps + self.max_seq_length = max_seq_length + self.dataset_mode = dataset_mode + self.chat_template = chat_template + self.auto_detect_dataset = auto_detect_dataset + self._set_remaining(kwargs) + + +def _resolve_reward_evaluator( + reward_model: Optional[Any], + reward_fn: Optional[Any], + reward_sources: Optional[List[Any]] = None, +) -> Any: + if reward_sources: + return public_create_reward_function(rewards=reward_sources) + if reward_model is not None: + return reward_model + if reward_fn is not None: + return reward_fn + return None + + +def _normalize_reward_values( + rewards: mx.array, + prompt_group_indices: mx.array, + mode: str, +) -> mx.array: + if mode in {"none", "off", ""}: + return rewards + if mode not in {"center", "zscore"}: + raise ValueError(f"Unsupported reward normalization mode: {mode}") + reward_values = rewards.tolist() + groups = prompt_group_indices.tolist() + adjusted = [0.0] * len(reward_values) + grouped: Dict[int, List[int]] = {} + for index, group in enumerate(groups): + grouped.setdefault(int(group), []).append(index) + for positions in grouped.values(): + group_rewards = mx.array([reward_values[position] for position in positions], dtype=mx.float32) + mean_value = mx.mean(group_rewards) + centered = group_rewards - mean_value + if mode == "zscore": + std_value = mx.std(group_rewards) + centered = centered if float(std_value.item()) < 1e-6 else centered / std_value + for offset, position in enumerate(positions): + adjusted[position] = float(centered[offset].item()) + return mx.array(adjusted, dtype=mx.float32) + + +def _apply_truncation_mask_to_rollout(rollout_batch: RolloutBatch) -> RolloutBatch: + if rollout_batch.truncation_flags is None or not bool(mx.any(rollout_batch.truncation_flags).item()): + return rollout_batch + + keep_mask = (~rollout_batch.truncation_flags).astype(mx.float32) + zero_lengths = mx.where(rollout_batch.truncation_flags, mx.zeros_like(rollout_batch.completion_lengths), rollout_batch.completion_lengths) + rollout_batch.completion_lengths = zero_lengths + rollout_batch.policy_eval.completion_lengths = zero_lengths + rollout_batch.rollout_logprobs = rollout_batch.rollout_logprobs * keep_mask + rollout_batch.policy_eval.rollout_logprobs = rollout_batch.rollout_logprobs + rollout_batch.policy_eval.old_logprobs = rollout_batch.rollout_logprobs + if rollout_batch.old_logprobs is not None: + rollout_batch.old_logprobs = rollout_batch.old_logprobs * keep_mask + if rollout_batch.policy_eval.old_token_logprobs is not None: + rollout_batch.policy_eval.old_token_logprobs = rollout_batch.policy_eval.old_token_logprobs * keep_mask[:, None] + if rollout_batch.reference_logprobs is not None: + rollout_batch.reference_logprobs = rollout_batch.reference_logprobs * keep_mask + rollout_batch.policy_eval.reference_logprobs = rollout_batch.reference_logprobs + if rollout_batch.value_predictions is not None: + rollout_batch.value_predictions = rollout_batch.value_predictions * keep_mask + rollout_batch.policy_eval.value_predictions = rollout_batch.value_predictions + if rollout_batch.rewards is not None: + rollout_batch.rewards = rollout_batch.rewards * keep_mask + if rollout_batch.returns is not None: + rollout_batch.returns = rollout_batch.returns * keep_mask + rollout_batch.policy_eval.returns = rollout_batch.returns + if rollout_batch.advantages is not None: + rollout_batch.advantages = rollout_batch.advantages * keep_mask + rollout_batch.policy_eval.advantages = rollout_batch.advantages + return rollout_batch + + +def _prepare_on_policy_samples( + dataset: Any, + tokenizer: Any, + config: Any, +) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], Optional[str]]: + prompt_samples: List[Dict[str, Any]] = [] + rollout_samples: List[Dict[str, Any]] = [] + prepared = prepare_rl_dataset( + dataset, + mode=config.dataset_mode, + tokenizer=tokenizer, + chat_template=getattr(config, "chat_template", None), + ) + for sample_index, sample in enumerate(prepared): + if prepared.mode == "prompt": + prompt = sample.get("prompt", "") + if not prompt: + continue + prompt_samples.append( + { + "sample_index": sample_index, + "prompt": prompt, + "prompt_ids": _encode_text(tokenizer, prompt), + "reward_context": sample.get("reward_context", prompt), + } + ) + elif prepared.mode == "rollout": + rollout_samples.append( + { + "sample_index": sample_index, + "prompt": sample["prompt"], + "completion": sample["completion"], + "reward": sample.get("reward"), + "reward_context": sample.get("reward_context", sample["completion"]), + } + ) + return prompt_samples, rollout_samples, prepared.mode + + +def _next_cursor_batch( + samples: List[Dict[str, Any]], + count: int, + cursor: int, + algorithm: str, +) -> Tuple[List[Dict[str, Any]], int]: + if not samples: + raise ValueError(f"{algorithm} training dataset is empty.") + + batch: List[Dict[str, Any]] = [] + next_cursor = cursor + for _ in range(max(1, count)): + batch.append(samples[next_cursor]) + next_cursor = (next_cursor + 1) % len(samples) + return batch, next_cursor + + +def _rollout_score_batch_size(trainer: Any, num_generations: Optional[int] = None) -> int: + generations = trainer.num_generations if num_generations is None else num_generations + return max(1, trainer.rollout_batch_size * max(1, generations)) + + +def _fixed_rollout_cache_dataset_fingerprint(samples: List[Dict[str, Any]]) -> str: + return _hash_payload( + [ + { + "sample_index": sample.get("sample_index"), + "prompt": sample.get("prompt"), + "completion": sample.get("completion"), + "reward_context": sample.get("reward_context"), + } + for sample in samples + ] + ) + + +def _preference_cache_dataset_fingerprint(samples: List[Dict[str, Any]]) -> str: + return _hash_payload( + [ + { + "sample_index": sample.get("sample_index"), + "chosen_ids": sample.get("chosen_ids"), + "rejected_ids": sample.get("rejected_ids"), + } + for sample in samples + ] + ) + + +def _fixed_rollout_reference_cache_valid( + trainer: Any, + cache_key: str, + samples: List[Dict[str, Any]], +) -> bool: + cached_scores = trainer.runtime_cache_arrays.get(cache_key) + cached_indices = trainer.runtime_cache_arrays.get(f"{cache_key}.sample_indices") + if cached_scores is None or cached_indices is None or cached_scores.shape[0] != len(samples): + return False + cache_info = trainer.cache_metadata.get("reference_score_caches", {}).get(cache_key, {}) + return ( + cached_indices.tolist() == [sample["sample_index"] for sample in samples] + and cache_info.get("dataset_fingerprint") == _fixed_rollout_cache_dataset_fingerprint(samples) + ) + + +def _ensure_fixed_rollout_reference_cache( + trainer: Any, + cache_key: str, + samples: List[Dict[str, Any]], +) -> Optional[mx.array]: + if not getattr(trainer.config, "precompute_reference_scores", False): + return None + if trainer.reference_policy is None or not samples: + return None + if _fixed_rollout_reference_cache_valid(trainer, cache_key, samples): + return trainer.runtime_cache_arrays[cache_key] + + prompt_ids = [_encode_text(trainer.tokenizer, sample["prompt"]) for sample in samples] + completion_ids = [ + _encode_text(trainer.tokenizer, sample["completion"], add_special_tokens=False) + for sample in samples + ] + full_sequences = [prompt + completion for prompt, completion in zip(prompt_ids, completion_ids)] + prompt_lengths = [len(prompt) for prompt in prompt_ids] + completion_lengths = [len(completion) for completion in completion_ids] + eval_batch = make_policy_eval_batch( + full_sequences, + pad_id=_pad_token_id(trainer.tokenizer), + mode="completion", + prompt_lengths=prompt_lengths, + completion_lengths=completion_lengths, + sample_indices=mx.array([sample["sample_index"] for sample in samples]), + ) + reference_logprobs = score_policy_in_chunks( + _actual_model(trainer.reference_policy.model), + eval_batch, + batch_size=_rollout_score_batch_size(trainer), + token_budget=getattr(trainer.config, "score_chunk_size", None), + mode="completion", + ).summed_logprobs + reference_logprobs = mx.stop_gradient(reference_logprobs.astype(mx.float32)) + trainer.runtime_cache_arrays[cache_key] = reference_logprobs + trainer.runtime_cache_arrays[f"{cache_key}.sample_indices"] = mx.array( + [sample["sample_index"] for sample in samples], + dtype=mx.int32, + ) + trainer.cache_metadata.setdefault("reference_score_caches", {})[cache_key] = { + "num_samples": len(samples), + "sampling_config_fingerprint": trainer._sampling_config_fingerprint(), + "dataset_fingerprint": _fixed_rollout_cache_dataset_fingerprint(samples), + } + return reference_logprobs + + +def _preference_reference_cache_valid( + trainer: Any, + cache_key: str, + samples: List[Dict[str, Any]], +) -> bool: + chosen = trainer.runtime_cache_arrays.get(f"{cache_key}.chosen") + rejected = trainer.runtime_cache_arrays.get(f"{cache_key}.rejected") + sample_indices = trainer.runtime_cache_arrays.get(f"{cache_key}.sample_indices") + if chosen is None or rejected is None or sample_indices is None: + return False + if chosen.shape[0] != len(samples) or rejected.shape[0] != len(samples): + return False + cache_info = trainer.cache_metadata.get("reference_score_caches", {}).get(cache_key, {}) + return ( + sample_indices.tolist() == [sample["sample_index"] for sample in samples] + and cache_info.get("dataset_fingerprint") == _preference_cache_dataset_fingerprint(samples) + ) + + +def _ensure_preference_reference_cache( + trainer: Any, + cache_key: str, + samples: List[Dict[str, Any]], +) -> Tuple[Optional[mx.array], Optional[mx.array]]: + if not getattr(trainer.config, "precompute_reference_scores", False): + return None, None + if trainer.reference_policy is None or not samples: + return None, None + if _preference_reference_cache_valid(trainer, cache_key, samples): + return ( + trainer.runtime_cache_arrays[f"{cache_key}.chosen"], + trainer.runtime_cache_arrays[f"{cache_key}.rejected"], + ) + + preference_batch = make_preference_batch( + chosen_sequences=[sample["chosen_ids"] for sample in samples], + rejected_sequences=[sample["rejected_ids"] for sample in samples], + pad_id=_pad_token_id(trainer.tokenizer), + sample_indices=[sample["sample_index"] for sample in samples], + ) + reference_model = _actual_model(trainer.reference_policy.model) + chosen_reference = score_policy_in_chunks( + reference_model, + preference_batch.chosen, + batch_size=max(1, trainer.batch_size), + token_budget=getattr(trainer.config, "score_chunk_size", None), + mode="sequence", + ).summed_logprobs + rejected_reference = score_policy_in_chunks( + reference_model, + preference_batch.rejected, + batch_size=max(1, trainer.batch_size), + token_budget=getattr(trainer.config, "score_chunk_size", None), + mode="sequence", + ).summed_logprobs + trainer.runtime_cache_arrays[f"{cache_key}.chosen"] = mx.stop_gradient(chosen_reference.astype(mx.float32)) + trainer.runtime_cache_arrays[f"{cache_key}.rejected"] = mx.stop_gradient(rejected_reference.astype(mx.float32)) + trainer.runtime_cache_arrays[f"{cache_key}.sample_indices"] = mx.array( + [sample["sample_index"] for sample in samples], + dtype=mx.int32, + ) + trainer.cache_metadata.setdefault("reference_score_caches", {})[cache_key] = { + "num_samples": len(samples), + "sampling_config_fingerprint": trainer._sampling_config_fingerprint(), + "dataset_fingerprint": _preference_cache_dataset_fingerprint(samples), + } + return ( + trainer.runtime_cache_arrays[f"{cache_key}.chosen"], + trainer.runtime_cache_arrays[f"{cache_key}.rejected"], + ) + + +def _collect_fixed_rollout_batch( + trainer: Any, + samples: List[Dict[str, Any]], + cached_reference_logprobs: Optional[mx.array] = None, +) -> RolloutBatch: + prompt_ids = [] + completion_ids = [] + truncation_flags = [] + max_seq_length = getattr(trainer, "max_seq_length", getattr(trainer.config, "max_seq_length", None)) + max_completion_length = getattr( + trainer, + "max_completion_length", + getattr(trainer.config, "max_completion_length", None), + ) + for sample in samples: + raw_prompt_ids = _encode_text(trainer.tokenizer, sample["prompt"]) + raw_completion_ids = _encode_text(trainer.tokenizer, sample["completion"], add_special_tokens=False) + capped_prompt_ids, capped_completion_ids, truncated = cap_prompt_and_completion_lengths( + raw_prompt_ids, + raw_completion_ids, + max_seq_length=max_seq_length, + max_completion_length=max_completion_length, + ) + prompt_ids.append(capped_prompt_ids) + completion_ids.append(capped_completion_ids) + truncation_flags.append(bool(truncated)) + full_sequences = [prompt + completion for prompt, completion in zip(prompt_ids, completion_ids)] + prompt_lengths = [len(prompt) for prompt in prompt_ids] + completion_lengths = [len(completion) for completion in completion_ids] + prompt_group_map: Dict[str, int] = {} + grouped_indices = [] + for sample in samples: + grouped_indices.append(prompt_group_map.setdefault(sample["prompt"], len(prompt_group_map))) + eval_batch = make_policy_eval_batch( + full_sequences, + pad_id=_pad_token_id(trainer.tokenizer), + mode="completion", + prompt_lengths=prompt_lengths, + completion_lengths=completion_lengths, + sample_indices=mx.array([sample["sample_index"] for sample in samples]), + prompt_group_indices=mx.array(grouped_indices), + reference_logprobs=cached_reference_logprobs, + ) + scored = score_policy_in_chunks( + _actual_model(trainer.model), + eval_batch, + batch_size=_rollout_score_batch_size(trainer), + token_budget=getattr(trainer.config, "score_chunk_size", None), + mode="completion", + ) + reward_values = [float(sample.get("reward", 0.0) or 0.0) for sample in samples] + rollout_batch = RolloutBatch( + prompt_ids=prompt_ids, + prompt_lengths=mx.array(prompt_lengths), + completion_ids=completion_ids, + completion_lengths=mx.array(completion_lengths), + prompt_texts=[sample["prompt"] for sample in samples], + original_prompt_texts=[sample["prompt"] for sample in samples], + completion_texts=[sample["completion"] for sample in samples], + reward_contexts=[sample.get("reward_context") for sample in samples], + sampled_token_logprobs=mx.zeros( + (len(samples), max(completion_lengths) if completion_lengths else 0), + dtype=mx.float32, + ), + rollout_logprobs=scored.summed_logprobs, + eos_flags=mx.array([True] * len(samples)), + truncation_flags=mx.array(truncation_flags), + prompt_group_indices=mx.array(grouped_indices), + policy_eval=scored, + sample_indices=mx.array([sample["sample_index"] for sample in samples]), + old_logprobs=scored.summed_logprobs, + rewards=mx.array(reward_values, dtype=mx.float32), + reference_logprobs=cached_reference_logprobs, + ) + rollout_batch.policy_eval.old_logprobs = scored.summed_logprobs + rollout_batch.policy_eval.old_token_logprobs = scored.token_logprobs + if cached_reference_logprobs is not None: + rollout_batch.policy_eval.reference_logprobs = cached_reference_logprobs + return rollout_batch + + +class RewardTrainer(_RLTrainerBase): + algorithm = "reward" + + def __init__( + self, + model: Any, + train_dataset: Any, + tokenizer: Optional[Any] = None, + args: Optional[RewardConfig] = None, + reward_model: Optional[RewardModel] = None, + use_native: bool = True, + **kwargs, + ): + self.reward_model = model if isinstance(model, RewardModel) else reward_model or build_reward_model(model) + self.model = self.reward_model.base_model + self.train_dataset = train_dataset + self.tokenizer = tokenizer or getattr(self.reward_model, "tokenizer", None) or getattr(self.model, "tokenizer", None) + self.use_native = use_native and HAS_NATIVE_TRAINING + self.config = args or RewardConfig() + self.output_dir = Path(self.config.output_dir) + self.learning_rate = self.config.learning_rate + self.batch_size = self.config.per_device_train_batch_size + self.max_steps = self.config.max_steps + self.max_seq_length = self.config.max_seq_length + self.pairwise_margin = self.config.pairwise_margin + self.regression_loss_type = self.config.regression_loss_type + self.logging_steps = self.config.logging_steps + self.save_steps = self.config.save_steps + dataset_size = len(train_dataset) if hasattr(train_dataset, "__len__") else 100 + self.iters = self.max_steps if self.max_steps > 0 else max( + 1, (dataset_size // max(1, self.batch_size)) * self.config.num_train_epochs + ) + self.output_dir.mkdir(parents=True, exist_ok=True) + self.train_samples: List[Dict[str, Any]] = [] + self.dataset_type: Optional[str] = None + self._train_target = _ScalarRoleTrainTarget( + _actual_model(self.reward_model.base_model), + self.reward_model.head, + ) + self._init_native_state() + + def _primary_role_name(self) -> str: + return "reward_model" + + def _primary_role_weight_format(self) -> Any: + reward_base = getattr(self.reward_model, "base_model", None) + return { + "backbone": "weights.safetensors", + "head": "head.safetensors", + "adapters": "adapters.safetensors" + if hasattr(reward_base, "has_adapters") and reward_base.has_adapters() + else None, + } + + def _save_primary_role(self, checkpoint_dir: Optional[Path] = None) -> None: + self.reward_model.save_pretrained(str(self._role_dir("reward_model", checkpoint_dir))) + + def _load_primary_role(self, checkpoint_dir: Path) -> None: + role_dir = self._role_dir("reward_model", checkpoint_dir) + if role_dir.exists(): + self.reward_model.load_pretrained(str(role_dir)) + + def _prepare_training_samples(self) -> None: + records = list(self.train_dataset) + dataset_mode = self.config.dataset_mode + if dataset_mode is None and records: + first_sample = records[0] + dataset_mode = "reward_pairwise" if {"chosen", "rejected"} <= set(first_sample.keys()) else "reward_scalar" + prepared = prepare_rl_dataset( + records, + mode=dataset_mode, + tokenizer=self.tokenizer, + chat_template=getattr(self.config, "chat_template", None), + ) + if not prepared: + raise ValueError("RewardTrainer requires pairwise or scalar reward samples.") + if prepared.mode not in {"reward_scalar", "reward_pairwise"}: + raise ValueError("RewardTrainer requires reward_scalar or reward_pairwise datasets.") + normalized = list(prepared) + sample_types = {sample["type"] for sample in normalized} + if len(sample_types) != 1: + raise ValueError("RewardTrainer currently requires all reward samples to use the same supervision type.") + self.dataset_type = next(iter(sample_types)) + self.train_samples = [] + for sample in normalized: + if self.dataset_type == "pairwise": + self.train_samples.append( + _tokenize_reward_pairwise_sample(self.tokenizer, sample, self.max_seq_length) + ) + else: + if "score" not in sample: + raise ValueError("RewardTrainer scalar samples require a score field.") + self.train_samples.append( + _tokenize_reward_scalar_sample(self.tokenizer, sample, self.max_seq_length) + ) + + def _build_pairwise_batch(self, samples: List[Dict[str, Any]]) -> Dict[str, mx.array]: + pad_id = _pad_token_id(self.tokenizer) + chosen_ids, chosen_lengths = _pad_sequences([sample["chosen_ids"] for sample in samples], pad_id) + rejected_ids, rejected_lengths = _pad_sequences([sample["rejected_ids"] for sample in samples], pad_id) + return { + "chosen_ids": chosen_ids, + "rejected_ids": rejected_ids, + "chosen_lengths": chosen_lengths, + "rejected_lengths": rejected_lengths, + "chosen_prompt_lengths": mx.array([sample["chosen_prompt_length"] for sample in samples]), + "rejected_prompt_lengths": mx.array([sample["rejected_prompt_length"] for sample in samples]), + "chosen_completion_lengths": mx.array([sample["chosen_completion_length"] for sample in samples]), + "rejected_completion_lengths": mx.array([sample["rejected_completion_length"] for sample in samples]), + } + + def _build_scalar_batch(self, samples: List[Dict[str, Any]]) -> Dict[str, mx.array]: + pad_id = _pad_token_id(self.tokenizer) + input_ids, lengths = _pad_sequences([sample["ids"] for sample in samples], pad_id) + return { + "input_ids": input_ids, + "lengths": lengths, + "prompt_lengths": mx.array([sample["prompt_length"] for sample in samples]), + "completion_lengths": mx.array([sample["completion_length"] for sample in samples]), + "targets": mx.array([sample["score"] for sample in samples], dtype=mx.float32), + } + + def train(self, resume_from_checkpoint: Optional[str] = None): + if not self.use_native: + raise ValueError("RewardTrainer requires native MLX training support.") + return self._train_native(resume_from_checkpoint=resume_from_checkpoint) + + def _train_native(self, resume_from_checkpoint: Optional[str] = None): + self._apply_lora_if_needed() + self._prepare_training_samples() + + optimizer = self._optimizer_for_training() + self.optimizer = optimizer + self.optimizers = {self._primary_optimizer_name(): optimizer} + + if resume_from_checkpoint is not None: + self.load_state(optimizer=optimizer, checkpoint_dir=Path(resume_from_checkpoint)) + + if self.dataset_type == "pairwise": + def loss_fn(_, batch): + loss, outputs = reward_model_pairwise_loss( + self.reward_model, + chosen_input_ids=batch["chosen_ids"], + rejected_input_ids=batch["rejected_ids"], + chosen_sequence_lengths=batch["chosen_lengths"], + rejected_sequence_lengths=batch["rejected_lengths"], + chosen_prompt_lengths=batch["chosen_prompt_lengths"], + rejected_prompt_lengths=batch["rejected_prompt_lengths"], + chosen_completion_lengths=batch["chosen_completion_lengths"], + rejected_completion_lengths=batch["rejected_completion_lengths"], + margin=self.pairwise_margin, + ) + return loss, outputs + else: + def loss_fn(_, batch): + loss, predictions = reward_model_regression_loss( + self.reward_model, + input_ids=batch["input_ids"], + sequence_lengths=batch["lengths"], + targets=batch["targets"], + prompt_lengths=batch["prompt_lengths"], + completion_lengths=batch["completion_lengths"], + loss_type=self.regression_loss_type, + ) + return loss, predictions + + value_and_grad = nn.value_and_grad(self._train_target, lambda modules, batch: loss_fn(modules, batch)[0]) + running_loss = 0.0 + last_loss = None + + while self.global_step < self.iters: + batch_samples = self._next_samples(self.train_samples) + batch = ( + self._build_pairwise_batch(batch_samples) + if self.dataset_type == "pairwise" + else self._build_scalar_batch(batch_samples) + ) + loss, grads = value_and_grad(self._train_target, batch) + optimizer.update(self._train_target, grads) + mx.eval(self._train_target.parameters(), optimizer.state) + + last_loss = float(loss.item()) + metric_payload: Dict[str, Any] = {"loss": last_loss} + if self.dataset_type == "pairwise": + _, outputs = loss_fn(self._train_target, batch) + metric_payload["ranking_accuracy"] = pairwise_ranking_accuracy( + outputs["chosen_scores"], + outputs["rejected_scores"], + ) + else: + _, predictions = loss_fn(self._train_target, batch) + metric_payload.update( + scalar_loss_metrics(loss, predictions, batch["targets"]) + ) + + running_loss += last_loss + self.global_step += 1 + self._record_metric(**metric_payload) + + if self.global_step % self.logging_steps == 0: + print( + f"Reward step {self.global_step}/{self.iters} | " + f"loss={running_loss / self.logging_steps:.4f}" + ) + running_loss = 0.0 + + if self.global_step % self.save_steps == 0: + self.save_state(optimizer=optimizer) + + self.save_state(optimizer=optimizer) + return { + "status": "success", + "global_step": self.global_step, + "final_loss": last_loss, + "reward_model_path": str(self._role_dir("reward_model")), + } + + +class DPOTrainer(_RLTrainerBase): + algorithm = "dpo" + requires_reference_policy = True def __init__( self, model: Any, train_dataset: Any, ref_model: Optional[Any] = None, + reward_model: Optional[Any] = None, + value_model: Optional[Any] = None, tokenizer: Optional[Any] = None, args: Optional[DPOConfig] = None, use_native: bool = True, - **kwargs + **kwargs, ): self.model = model self.ref_model = ref_model + self.reward_model = reward_model + self.value_model = value_model self.train_dataset = train_dataset - self.tokenizer = tokenizer or getattr(model, 'tokenizer', None) + self.tokenizer = tokenizer or getattr(model, "tokenizer", None) self.use_native = use_native and HAS_NATIVE_TRAINING - - # Extract config - if args is None: - args = DPOConfig() - - self.config = args - self.beta = args.beta - self.loss_type = args.loss_type - self.label_smoothing = args.label_smoothing - self.output_dir = Path(args.output_dir) - self.learning_rate = args.learning_rate - self.batch_size = args.per_device_train_batch_size - self.max_steps = args.max_steps - self.max_seq_length = args.max_seq_length - self.max_prompt_length = args.max_prompt_length - self.gradient_accumulation_steps = args.gradient_accumulation_steps - self.warmup_steps = args.warmup_steps - self.logging_steps = args.logging_steps - self.save_steps = args.save_steps - - # Calculate iters - if self.max_steps > 0: - self.iters = self.max_steps - else: - dataset_size = len(train_dataset) if hasattr(train_dataset, '__len__') else 100 - self.iters = max(1, (dataset_size // self.batch_size) * args.num_train_epochs) - + self.config = args or DPOConfig() + self.beta = self.config.beta + self.loss_type = self.config.loss_type + self.label_smoothing = self.config.label_smoothing + self.output_dir = Path(self.config.output_dir) + self.learning_rate = self.config.learning_rate + self.batch_size = self.config.per_device_train_batch_size + self.max_steps = self.config.max_steps + self.max_seq_length = self.config.max_seq_length + self.max_prompt_length = self.config.max_prompt_length + self.gradient_accumulation_steps = self.config.gradient_accumulation_steps + self.warmup_steps = self.config.warmup_steps + self.logging_steps = self.config.logging_steps + self.save_steps = self.config.save_steps + dataset_size = len(train_dataset) if hasattr(train_dataset, "__len__") else 100 + self.iters = self.max_steps if self.max_steps > 0 else max( + 1, (dataset_size // max(1, self.batch_size)) * self.config.num_train_epochs + ) self.output_dir.mkdir(parents=True, exist_ok=True) - self.adapter_path = self.output_dir / "adapters" + self.adapter_path = self.output_dir / "policy" self.adapter_path.mkdir(parents=True, exist_ok=True) + self.train_samples: List[Dict[str, Any]] = [] + self._init_native_state() - print(f"DPOTrainer initialized:") - print(f" Beta: {self.beta}") - print(f" Loss type: {self.loss_type}") - print(f" Learning rate: {self.learning_rate}") - print(f" Iterations: {self.iters}") - print(f" Native training: {self.use_native}") - print(f" Using proper DPO loss: {self.use_native}") - - def _tokenize_preference_pair(self, sample: Dict) -> Dict: - """Tokenize a preference pair (prompt + chosen, prompt + rejected).""" - prompt = sample.get('prompt', '') - chosen = sample.get('chosen', '') - rejected = sample.get('rejected', '') - - # Tokenize chosen and rejected with prompt - chosen_text = prompt + chosen - rejected_text = prompt + rejected - - chosen_ids = self.tokenizer.encode(chosen_text) - rejected_ids = self.tokenizer.encode(rejected_text) - - # Truncate if needed - if len(chosen_ids) > self.max_seq_length: - chosen_ids = chosen_ids[:self.max_seq_length] - if len(rejected_ids) > self.max_seq_length: - rejected_ids = rejected_ids[:self.max_seq_length] + def _tokenize_preference_pair(self, sample: Dict[str, Any], sample_index: int) -> Dict[str, Any]: + prompt = sample.get("prompt", "") + chosen = sample.get("chosen", "") + rejected = sample.get("rejected", "") + chosen_ids = self.tokenizer.encode(prompt + chosen)[: self.max_seq_length] + rejected_ids = self.tokenizer.encode(prompt + rejected)[: self.max_seq_length] return { - 'chosen_ids': chosen_ids, - 'rejected_ids': rejected_ids, - 'chosen_length': len(chosen_ids), - 'rejected_length': len(rejected_ids), + "sample_index": sample_index, + "chosen_ids": chosen_ids, + "rejected_ids": rejected_ids, + "chosen_length": len(chosen_ids), + "rejected_length": len(rejected_ids), } - def _prepare_dpo_batches(self): - """Prepare batched DPO data for training.""" - tokenized_data = [] - for sample in self.train_dataset: - if 'prompt' in sample and 'chosen' in sample and 'rejected' in sample: - tokenized_data.append(self._tokenize_preference_pair(sample)) + def _prepare_training_samples(self) -> None: + self.train_samples = [] + prepared = prepare_rl_dataset( + self.train_dataset, + mode=self.config.dataset_mode or "preference", + tokenizer=self.tokenizer, + chat_template=getattr(self.config, "chat_template", None), + ) + for sample_index, sample in enumerate(prepared): + self.train_samples.append(self._tokenize_preference_pair(sample, sample_index)) + if not self.train_samples: + raise ValueError("DPOTrainer requires prompt/chosen/rejected samples.") + + def _precompute_reference_cache(self) -> None: + self._ensure_reference_policy() + preference_batch = make_preference_batch( + chosen_sequences=[sample["chosen_ids"] for sample in self.train_samples], + rejected_sequences=[sample["rejected_ids"] for sample in self.train_samples], + pad_id=_pad_token_id(self.tokenizer), + sample_indices=[sample["sample_index"] for sample in self.train_samples], + ) + ref_chosen = score_policy_in_chunks( + _actual_model(self.reference_policy.model), + preference_batch.chosen, + batch_size=max(1, self.batch_size), + mode="sequence", + ).summed_logprobs + ref_rejected = score_policy_in_chunks( + _actual_model(self.reference_policy.model), + preference_batch.rejected, + batch_size=max(1, self.batch_size), + mode="sequence", + ).summed_logprobs + ref_chosen = mx.stop_gradient(ref_chosen) + ref_rejected = mx.stop_gradient(ref_rejected) + for idx, sample in enumerate(self.train_samples): + sample["reference_chosen_logprobs"] = ref_chosen[idx] + sample["reference_rejected_logprobs"] = ref_rejected[idx] + self.cache_metadata = { + "type": "inline_preference_reference_logprobs", + "num_samples": len(self.train_samples), + } - return tokenized_data + def _restore_reference_cache(self, flat_state: Dict[str, mx.array]) -> None: + if "dpo.reference_chosen_logprobs" not in flat_state: + self._precompute_reference_cache() + return + ref_chosen = flat_state["dpo.reference_chosen_logprobs"] + ref_rejected = flat_state["dpo.reference_rejected_logprobs"] + if ref_chosen.shape[0] != len(self.train_samples): + raise ValueError("Saved DPO cache does not match current dataset ordering.") + for idx, sample in enumerate(self.train_samples): + sample["reference_chosen_logprobs"] = ref_chosen[idx] + sample["reference_rejected_logprobs"] = ref_rejected[idx] + + def _build_batch(self, samples: List[Dict[str, Any]]) -> PreferenceBatch: + return make_preference_batch( + chosen_sequences=[sample["chosen_ids"] for sample in samples], + rejected_sequences=[sample["rejected_ids"] for sample in samples], + pad_id=_pad_token_id(self.tokenizer), + sample_indices=[sample["sample_index"] for sample in samples], + chosen_reference_logprobs=mx.array( + [sample["reference_chosen_logprobs"] for sample in samples] + ), + rejected_reference_logprobs=mx.array( + [sample["reference_rejected_logprobs"] for sample in samples] + ), + ) - def _pad_to_length(self, ids: List[int], length: int, pad_id: int = 0) -> List[int]: - """Pad sequence to target length.""" - if len(ids) >= length: - return ids[:length] - return ids + [pad_id] * (length - len(ids)) + def _extra_state_arrays(self) -> Dict[str, mx.array]: + return { + "dpo.reference_chosen_logprobs": mx.array( + [sample["reference_chosen_logprobs"] for sample in self.train_samples] + ), + "dpo.reference_rejected_logprobs": mx.array( + [sample["reference_rejected_logprobs"] for sample in self.train_samples] + ), + } - def train(self): - """ - Train the model using DPO with proper loss computation. + def train(self, resume_from_checkpoint: Optional[str] = None): + if self.use_native: + return self._train_native(resume_from_checkpoint=resume_from_checkpoint) + return self._train_subprocess() + + def _train_native(self, resume_from_checkpoint: Optional[str] = None): + self._apply_lora_if_needed() + self._prepare_training_samples() + + actual_model = _actual_model(self.model) + optimizer = self._optimizer_for_training() + self.optimizer = optimizer + + if resume_from_checkpoint is not None: + flat_state = self.load_state(optimizer, Path(resume_from_checkpoint)) + self._restore_reference_cache(flat_state) + else: + self._precompute_reference_cache() + + def loss_fn(model, batch): + loss, _ = compute_dpo_loss( + model=model, + chosen_ids=batch.chosen.input_ids, + rejected_ids=batch.rejected.input_ids, + chosen_lengths=batch.chosen.sequence_lengths, + rejected_lengths=batch.rejected.sequence_lengths, + beta=self.beta, + reference_chosen_logprobs=batch.chosen.reference_logprobs, + reference_rejected_logprobs=batch.rejected.reference_logprobs, + label_smoothing=self.label_smoothing, + ) + return loss - Uses native MLX training with real DPO loss when available, - falls back to SFT approximation otherwise. - """ - print("=" * 70) - print("Starting DPO Training") - print("=" * 70) + value_and_grad = nn.value_and_grad(actual_model, loss_fn) + running_loss = 0.0 + last_loss = None + while self.global_step < self.iters: + batch_samples = self._next_samples(self.train_samples) + batch = self._build_batch(batch_samples) + loss, grads = value_and_grad(actual_model, batch) + optimizer.update(actual_model, grads) + mx.eval(actual_model.parameters(), optimizer.state) + + last_loss = loss.item() + running_loss += last_loss + self.global_step += 1 + self._record_metric(loss=last_loss) + + if self.global_step % self.logging_steps == 0: + print( + f"DPO step {self.global_step}/{self.iters} | " + f"loss={running_loss / self.logging_steps:.4f}" + ) + running_loss = 0.0 + + if self.global_step % self.save_steps == 0: + self.save_state(optimizer, self._extra_state_arrays()) + + self.save_state(optimizer, self._extra_state_arrays()) + return { + "status": "success", + "adapter_path": str(self.adapter_path), + "global_step": self.global_step, + "final_loss": last_loss, + } + + def _train_subprocess(self): + warnings.warn( + "Native DPO training not available. Using SFT-on-chosen approximation.", + UserWarning, + ) + train_file = self.output_dir / "train.jsonl" + valid_file = self.output_dir / "valid.jsonl" + with open(train_file, "w") as handle: + for sample in self.train_dataset: + if "prompt" in sample and "chosen" in sample: + messages = [ + {"role": "user", "content": sample["prompt"]}, + {"role": "assistant", "content": sample["chosen"]}, + ] + handle.write(json.dumps({"messages": messages}) + "\n") + valid_file.write_text(train_file.read_text()) + cmd = [ + "mlx_lm.lora", + "--model", + getattr(self.model, "model_name", "model"), + "--train", + "--data", + str(self.output_dir), + "--iters", + str(self.iters), + "--learning-rate", + str(self.learning_rate), + "--batch-size", + str(self.batch_size), + "--adapter-path", + str(self.adapter_path), + ] + subprocess.run(cmd, check=True) + return {"status": "success", "adapter_path": str(self.adapter_path)} + + +class ORPOTrainer: + def __init__( + self, + model: Any, + train_dataset: Any, + tokenizer: Optional[Any] = None, + args: Optional[ORPOConfig] = None, + use_native: bool = True, + **kwargs, + ): + self.model = model + self.train_dataset = train_dataset + self.tokenizer = tokenizer or getattr(model, "tokenizer", None) + self.use_native = use_native and HAS_NATIVE_TRAINING + self.config = args or ORPOConfig() + self.beta = self.config.beta + self.output_dir = Path(self.config.output_dir) + self.learning_rate = self.config.learning_rate + self.batch_size = self.config.per_device_train_batch_size + self.max_steps = self.config.max_steps + self.max_seq_length = self.config.max_seq_length + self.logging_steps = self.config.logging_steps + self.save_steps = self.config.save_steps + dataset_size = len(train_dataset) if hasattr(train_dataset, "__len__") else 100 + self.iters = self.max_steps if self.max_steps > 0 else max( + 1, (dataset_size // max(1, self.batch_size)) * self.config.num_train_epochs + ) + self.output_dir.mkdir(parents=True, exist_ok=True) + self.adapter_path = self.output_dir / "policy" + self.adapter_path.mkdir(parents=True, exist_ok=True) + + def _tokenize_preference_pair(self, sample: Dict[str, Any]) -> Dict[str, Any]: + prompt = sample.get("prompt", "") + chosen = sample.get("chosen", "") + rejected = sample.get("rejected", "") + chosen_ids = self.tokenizer.encode(prompt + chosen)[: self.max_seq_length] + rejected_ids = self.tokenizer.encode(prompt + rejected)[: self.max_seq_length] + return { + "chosen_ids": chosen_ids, + "rejected_ids": rejected_ids, + "chosen_length": len(chosen_ids), + "rejected_length": len(rejected_ids), + } + + def train(self): if self.use_native: return self._train_native() - else: - return self._train_subprocess() + return self._train_subprocess() def _train_native(self): - """Train using native MLX with proper DPO loss.""" - print("\n[Using Native DPO Training with Proper Loss]") - - # Apply LoRA if needed - if hasattr(self.model, '_apply_lora') and not getattr(self.model, '_lora_applied', False): - print("Applying LoRA adapters...") + if hasattr(self.model, "_apply_lora") and not getattr(self.model, "_lora_applied", False): self.model._apply_lora() - # Prepare data - print("Preparing preference data...") - tokenized_data = self._prepare_dpo_batches() - print(f"✓ Prepared {len(tokenized_data)} preference pairs") + prepared = prepare_rl_dataset( + self.train_dataset, + mode=self.config.dataset_mode or "preference", + tokenizer=self.tokenizer, + chat_template=getattr(self.config, "chat_template", None), + ) + tokenized_data = [self._tokenize_preference_pair(sample) for sample in prepared] + actual_model = _actual_model(self.model) + optimizer = optim.AdamW(learning_rate=optim.cosine_decay(self.learning_rate, self.iters)) - # Get actual model - actual_model = self.model.model if hasattr(self.model, 'model') else self.model + def loss_fn(model, batch): + chosen_ids, rejected_ids, chosen_lengths, rejected_lengths = batch + loss, _ = compute_orpo_loss( + model, + chosen_ids, + rejected_ids, + chosen_lengths, + rejected_lengths, + self.beta, + ) + return loss + + value_and_grad = nn.value_and_grad(actual_model, loss_fn) + last_loss = None + pad_id = _pad_token_id(self.tokenizer) + + for step in range(self.iters): + samples = tokenized_data[step % len(tokenized_data): step % len(tokenized_data) + self.batch_size] + if len(samples) < self.batch_size: + samples += tokenized_data[: self.batch_size - len(samples)] + chosen_ids, chosen_lengths = _pad_sequences([sample["chosen_ids"] for sample in samples], pad_id) + rejected_ids, rejected_lengths = _pad_sequences([sample["rejected_ids"] for sample in samples], pad_id) + loss, grads = value_and_grad(actual_model, (chosen_ids, rejected_ids, chosen_lengths, rejected_lengths)) + optimizer.update(actual_model, grads) + mx.eval(actual_model.parameters(), optimizer.state) + last_loss = loss.item() - # Create optimizer - lr_schedule = optim.cosine_decay(self.learning_rate, self.iters) - optimizer = optim.AdamW(learning_rate=lr_schedule) + _save_adapters_and_config(self.model, self.adapter_path) + return {"status": "success", "adapter_path": str(self.adapter_path), "final_loss": last_loss} - # Training loop - print(f"\nStarting training for {self.iters} iterations...") + def _train_subprocess(self): + warnings.warn("Using SFT approximation for ORPO.", UserWarning) + train_file = self.output_dir / "train.jsonl" + valid_file = self.output_dir / "valid.jsonl" + with open(train_file, "w") as handle: + for sample in self.train_dataset: + if "prompt" in sample and "chosen" in sample: + messages = [ + {"role": "user", "content": sample["prompt"]}, + {"role": "assistant", "content": sample["chosen"]}, + ] + handle.write(json.dumps({"messages": messages}) + "\n") + valid_file.write_text(train_file.read_text()) + cmd = [ + "mlx_lm.lora", + "--model", + getattr(self.model, "model_name", "model"), + "--train", + "--data", + str(self.output_dir), + "--iters", + str(self.iters), + "--learning-rate", + str(self.learning_rate), + "--batch-size", + str(self.batch_size), + "--adapter-path", + str(self.adapter_path), + ] + subprocess.run(cmd, check=True) + return {"status": "success", "adapter_path": str(self.adapter_path)} + + +class GRPOTrainer(_RLTrainerBase): + algorithm = "grpo" + requires_reference_policy = True + + def __init__( + self, + model: Any, + train_dataset: Any, + eval_dataset: Any = None, + eval_preference_dataset: Any = None, + tokenizer: Optional[Any] = None, + reward_fn: Optional[Callable] = None, + reward_model: Optional[Any] = None, + ref_model: Optional[Any] = None, + value_model: Optional[Any] = None, + args: Optional[GRPOConfig] = None, + use_native: bool = True, + **kwargs, + ): + self.model = model + self.train_dataset = train_dataset + self.eval_dataset = eval_dataset + self.eval_preference_dataset = eval_preference_dataset + self.tokenizer = tokenizer or getattr(model, "tokenizer", None) + self.ref_model = ref_model + self.reward_model = reward_model or getattr(args, "reward_model", None) + self.value_model = value_model or getattr(args, "value_model", None) + self.use_native = use_native and HAS_NATIVE_TRAINING + self.config = args or GRPOConfig() + self.loss_type = self.config.loss_type + self.phase1_loss_type = "phase1_shared_rollout_recompute" + self.resolved_loss_type = self._resolve_loss_type(self.loss_type) + self.advantage_mode = self.config.advantage_estimator + self.beta = self.config.kl_beta + self.num_generations = self.config.num_generations + self.max_completion_length = self.config.max_completion_length + self.reward_fn = reward_fn if reward_fn is not None else self.config.reward_fn + self.reward_sources = self.config.reward_sources + self.reward_source = self.config.reward_source + self.kl_target = self.config.kl_target + self.kl_penalty_mode = self.config.kl_penalty_mode + self.reward_normalization = self.config.reward_normalization + self.mask_truncated_completions = self.config.mask_truncated_completions + self.minibatch_reuse_steps = self.config.minibatch_reuse_steps + self.entropy_bonus = self.config.entropy_bonus + self.output_dir = Path(self.config.output_dir) + self.learning_rate = self.config.learning_rate + self.batch_size = self.config.per_device_train_batch_size + self.rollout_batch_size = self.config.rollout_batch_size or self.batch_size + self.max_steps = self.config.max_steps + self.temperature = self.config.temperature + self.clip_epsilon = self.config.clip_epsilon + self.epsilon_low = self.config.epsilon_low + self.epsilon_high = self.config.epsilon_high + self.scale_rewards = self.config.scale_rewards + self.eval_steps = self.config.eval_steps + self.eval_num_batches = self.config.eval_num_batches + self.eval_num_generations = self.config.eval_num_generations or self.num_generations + self.generation_batch_size = self.config.generation_batch_size + self.score_chunk_size = self.config.score_chunk_size + self.precompute_reference_scores = self.config.precompute_reference_scores + self.logging_steps = self.config.logging_steps + self.save_steps = self.config.save_steps + self.max_seq_length = self.config.max_seq_length + dataset_size = len(train_dataset) if hasattr(train_dataset, "__len__") else 100 + self.iters = self.max_steps if self.max_steps > 0 else max( + 1, (dataset_size // max(1, self.batch_size)) * self.config.num_train_epochs + ) + self.output_dir.mkdir(parents=True, exist_ok=True) + self.adapter_path = self.output_dir / "policy" + self.adapter_path.mkdir(parents=True, exist_ok=True) + self.prompt_samples: List[Dict[str, Any]] = [] + self.rollout_samples: List[Dict[str, Any]] = [] + self.eval_prompt_samples: List[Dict[str, Any]] = [] + self.eval_rollout_samples: List[Dict[str, Any]] = [] + self.eval_preference_samples: List[Dict[str, Any]] = [] + self.prepared_dataset_mode: Optional[str] = None + self.prompt_dataset_cursor = 0 + self.rollout_dataset_cursor = 0 + self._last_rollout_batch: Optional[RolloutBatch] = None + self._last_rollout_phase_metrics: Dict[str, float] = {} + self._init_native_state() + + if self.reward_model is None and self.reward_fn is None: + self.reward_fn = lambda response, context: len(response.split()) / 100.0 + + def _resolve_loss_type(self, loss_type: str) -> str: + if loss_type not in GRPO_LOSS_TYPES: + raise ValueError( + f"Unsupported GRPO loss_type '{loss_type}'. " + f"Supported values: {sorted(GRPO_LOSS_TYPES)}" + ) + return loss_type + + def _trainer_cursor_state(self) -> Dict[str, int]: + return { + "dataset": int(self.dataset_cursor), + "prompt_dataset": int(self.prompt_dataset_cursor), + "offline_rollout_dataset": int(self.rollout_dataset_cursor), + } + + def _restore_trainer_cursors(self, cursors: Dict[str, Any]) -> None: + super()._restore_trainer_cursors(cursors) + self.prompt_dataset_cursor = int(cursors.get("prompt_dataset", self.prompt_dataset_cursor)) + self.rollout_dataset_cursor = int(cursors.get("offline_rollout_dataset", self.rollout_dataset_cursor)) + + def _sampling_config_payload(self) -> Dict[str, Any]: + return { + "algorithm": self.algorithm, + "loss_type": self.resolved_loss_type, + "beta": self.beta, + "clip_epsilon": self.clip_epsilon, + "epsilon_low": self.epsilon_low, + "epsilon_high": self.epsilon_high, + "rollout_batch_size": self.rollout_batch_size, + "minibatch_reuse_steps": self.minibatch_reuse_steps, + "advantage_mode": self.advantage_mode, + "scale_rewards": self.scale_rewards, + "kl_target": self.kl_target, + "kl_penalty_mode": self.kl_penalty_mode, + "reward_source": self.reward_source, + "reward_normalization": self.reward_normalization, + "mask_truncated_completions": self.mask_truncated_completions, + "entropy_bonus": self.entropy_bonus, + "temperature": self.temperature, + "num_generations": self.num_generations, + "max_completion_length": self.max_completion_length, + "max_seq_length": self.max_seq_length, + "generation_batch_size": self.generation_batch_size, + "score_chunk_size": self.score_chunk_size, + } + + def _prepare_prompt_samples(self) -> None: + self.prompt_samples, self.rollout_samples, self.prepared_dataset_mode = _prepare_on_policy_samples( + self.train_dataset, + self.tokenizer, + self.config, + ) + if not self.prompt_samples and not self.rollout_samples: + raise ValueError("GRPOTrainer requires prompt or rollout samples.") + + def _prepare_eval_datasets(self) -> None: + self.eval_prompt_samples = [] + self.eval_rollout_samples = [] + self.eval_preference_samples = [] + if self.eval_dataset is not None: + self.eval_prompt_samples, self.eval_rollout_samples, _ = _prepare_on_policy_samples( + self.eval_dataset, + self.tokenizer, + self.config, + ) + if self.eval_preference_dataset is not None: + prepared = prepare_rl_dataset( + self.eval_preference_dataset, + mode="preference", + tokenizer=self.tokenizer, + chat_template=getattr(self.config, "chat_template", None), + ) + for sample_index, sample in enumerate(prepared): + prompt = sample.get("prompt", "") + prompt_ids = _encode_text(self.tokenizer, prompt) + chosen_ids = _encode_text(self.tokenizer, sample["chosen"], add_special_tokens=False) + rejected_ids = _encode_text(self.tokenizer, sample["rejected"], add_special_tokens=False) + self.eval_preference_samples.append( + { + "sample_index": sample_index, + "prompt_ids": prompt_ids, + "chosen_ids": prompt_ids + chosen_ids, + "rejected_ids": prompt_ids + rejected_ids, + "prompt_length": len(prompt_ids), + "chosen_completion_length": len(chosen_ids), + "rejected_completion_length": len(rejected_ids), + } + ) + + def _resolve_reward_evaluator(self) -> Any: + evaluator = _resolve_reward_evaluator( + self.reward_model, + self.reward_fn, + self.reward_sources, + ) + if evaluator is not None: + return evaluator + return lambda response, context: len(response.split()) / 100.0 + + def _next_prompt_batch(self) -> List[Dict[str, Any]]: + batch, self.prompt_dataset_cursor = _next_cursor_batch( + self.prompt_samples, + self.rollout_batch_size, + self.prompt_dataset_cursor, + self.algorithm, + ) + self.dataset_cursor = self.prompt_dataset_cursor + return batch + + def _next_offline_rollout_batch(self) -> List[Dict[str, Any]]: + batch, self.rollout_dataset_cursor = _next_cursor_batch( + self.rollout_samples, + self.rollout_batch_size * self.num_generations, + self.rollout_dataset_cursor, + self.algorithm, + ) + self.dataset_cursor = self.rollout_dataset_cursor + return batch + + def _collect_fixed_rollout_batch( + self, + samples: List[Dict[str, Any]], + cache_key: Optional[str] = None, + ) -> RolloutBatch: + cached_reference_logprobs = None + if cache_key is not None: + cached_reference_logprobs = _ensure_fixed_rollout_reference_cache(self, cache_key, samples) + return _collect_fixed_rollout_batch( + self, + samples, + cached_reference_logprobs=cached_reference_logprobs, + ) + + def _collect_rollout_batch( + self, + prompt_samples: List[Dict[str, Any]], + num_generations: Optional[int] = None, + cache_key: Optional[str] = None, + ) -> RolloutBatch: + generations = self.num_generations if num_generations is None else num_generations + reward_evaluator = self._resolve_reward_evaluator() + rollout_generate_started_at = time.perf_counter() + reward_eval_wall = 0.0 + reference_score_wall = 0.0 + returns_wall = 0.0 + if prompt_samples and "completion" in prompt_samples[0]: + rollout_batch = self._collect_fixed_rollout_batch(prompt_samples, cache_key=cache_key) + rollout_generate_wall = time.perf_counter() - rollout_generate_started_at + else: + rollout_batch = collect_rollouts( + _actual_model(self.model), + self.tokenizer, + prompt_samples=prompt_samples, + sampling_config={ + "num_generations": generations, + "temperature": self.temperature, + "max_completion_length": self.max_completion_length, + "max_seq_length": self.max_seq_length, + "generation_batch_size": self.generation_batch_size, + }, + reward_evaluator=None, + collect_sample_stats=self.entropy_bonus != 0.0, + ) + rollout_generate_wall = time.perf_counter() - rollout_generate_started_at + if self.reward_source == "offline": + raise ValueError("reward_source='offline' requires rollout samples with completion/reward fields.") + reward_eval_started_at = time.perf_counter() + reward_batch = evaluate_rewards(rollout_batch, reward_evaluator) + reward_eval_wall += time.perf_counter() - reward_eval_started_at + rollout_batch.rewards = reward_batch.scalar_rewards + if prompt_samples and "completion" in prompt_samples[0]: + if self.reward_source == "online": + reward_eval_started_at = time.perf_counter() + reward_batch = evaluate_rewards(rollout_batch, reward_evaluator) + reward_eval_wall += time.perf_counter() - reward_eval_started_at + rollout_batch.rewards = reward_batch.scalar_rewards + elif self.reward_source == "hybrid": + reward_eval_started_at = time.perf_counter() + reward_batch = evaluate_rewards(rollout_batch, reward_evaluator) + reward_eval_wall += time.perf_counter() - reward_eval_started_at + rollout_batch.rewards = rollout_batch.rewards + reward_batch.scalar_rewards + if rollout_batch.rewards is None: + rollout_batch.rewards = mx.zeros((len(rollout_batch.prompt_ids),), dtype=mx.float32) + if self.entropy_bonus and rollout_batch.token_entropies is not None: + entropy_bonus = mx.mean(rollout_batch.token_entropies, axis=-1) * self.entropy_bonus + rollout_batch.rewards = rollout_batch.rewards + entropy_bonus.astype(mx.float32) + if self.reward_normalization != "none": + rollout_batch.rewards = _normalize_reward_values( + rollout_batch.rewards, + rollout_batch.prompt_group_indices, + self.reward_normalization, + ) + reference_score_started_at = time.perf_counter() + rollout_batch = score_rollout_references( + _actual_model(self.reference_policy.model) if self.reference_policy is not None else None, + rollout_batch, + batch_size=_rollout_score_batch_size(self, num_generations=generations), + token_budget=self.score_chunk_size, + ) + reference_score_wall = time.perf_counter() - reference_score_started_at + if self.mask_truncated_completions: + rollout_batch = _apply_truncation_mask_to_rollout(rollout_batch) + returns_mode = "rloo" if self.advantage_mode == "rloo" else "group_zscore" + if returns_mode == "group_zscore" and self.scale_rewards is False: + returns_mode = "group_center" + returns_started_at = time.perf_counter() + returns, advantages = compute_returns_and_advantages( + rewards=rollout_batch.rewards, + prompt_group_indices=rollout_batch.prompt_group_indices, + mode=returns_mode, + ) + returns_wall = time.perf_counter() - returns_started_at + rollout_batch.returns = returns + rollout_batch.advantages = advantages + rollout_batch.policy_eval.returns = returns + rollout_batch.policy_eval.advantages = advantages + self._last_rollout_phase_metrics = { + "rollout_generate_wall": float(rollout_generate_wall), + "reward_eval_wall": float(reward_eval_wall), + "reference_score_wall": float(reference_score_wall), + "returns_wall": float(returns_wall), + } + return rollout_batch + + def _evaluate_rollout_metrics(self) -> Dict[str, Any]: + prompt_limit = max(1, self.eval_num_batches or 1) * self.rollout_batch_size + rollout_limit = max(1, self.eval_num_batches or 1) * self.rollout_batch_size * self.num_generations + if self.eval_rollout_samples: + rollout_batch = self._collect_rollout_batch( + self.eval_rollout_samples[:rollout_limit], + cache_key=f"{self.algorithm}.eval_rollout_reference_logprobs", + ) + elif self.eval_prompt_samples: + rollout_batch = self._collect_rollout_batch( + self.eval_prompt_samples[:prompt_limit], + num_generations=self.eval_num_generations, + ) + else: + return {} + reference_model = _actual_model(self.reference_policy.model) + effective_beta = self._effective_kl_beta(rollout_batch) + loss, _ = grpo_recompute_loss( + model=_actual_model(self.model), + reference_model=reference_model, + input_ids=rollout_batch.policy_eval.input_ids, + prompt_lengths=rollout_batch.policy_eval.prompt_lengths, + completion_lengths=rollout_batch.policy_eval.completion_lengths, + rollout_logprobs=rollout_batch.policy_eval.rollout_logprobs, + old_token_logprobs=rollout_batch.policy_eval.old_token_logprobs, + reference_logprobs=rollout_batch.policy_eval.reference_logprobs, + advantages=rollout_batch.policy_eval.advantages, + beta=effective_beta, + clip_epsilon=self.clip_epsilon, + epsilon_low=self.epsilon_low, + epsilon_high=self.epsilon_high, + temperature=self.temperature, + loss_type=self.resolved_loss_type, + max_completion_length=self.max_completion_length, + ) + return summarize_rollout_metrics(rollout_batch, policy_loss=float(loss.item())) + + def _evaluate_preference_metrics(self) -> Dict[str, Any]: + if not self.eval_preference_samples: + return {} + limit = max(1, self.eval_num_batches or 1) * self.batch_size + samples = self.eval_preference_samples[:limit] + chosen_batch = make_policy_eval_batch( + [sample["chosen_ids"] for sample in samples], + pad_id=_pad_token_id(self.tokenizer), + mode="completion", + prompt_lengths=[sample["prompt_length"] for sample in samples], + completion_lengths=[sample["chosen_completion_length"] for sample in samples], + sample_indices=mx.array([sample["sample_index"] for sample in samples]), + ) + rejected_batch = make_policy_eval_batch( + [sample["rejected_ids"] for sample in samples], + pad_id=_pad_token_id(self.tokenizer), + mode="completion", + prompt_lengths=[sample["prompt_length"] for sample in samples], + completion_lengths=[sample["rejected_completion_length"] for sample in samples], + sample_indices=mx.array([sample["sample_index"] for sample in samples]), + ) + chosen_scores = score_policy_in_chunks( + _actual_model(self.model), + chosen_batch, + batch_size=max(1, self.batch_size), + token_budget=self.score_chunk_size, + mode="completion", + ).summed_logprobs + rejected_scores = score_policy_in_chunks( + _actual_model(self.model), + rejected_batch, + batch_size=max(1, self.batch_size), + token_budget=self.score_chunk_size, + mode="completion", + ).summed_logprobs + return { + "preference_win_rate": float(mx.mean((chosen_scores > rejected_scores).astype(mx.float32)).item()) + } + + def evaluate(self) -> Dict[str, Any]: + self._prepare_eval_datasets() + if self.reference_policy is None: + self._ensure_reference_policy() + with self._preserve_rng_state(): + mx.random.seed(int(self.seed) + 100000 + int(self.global_step)) + metrics: Dict[str, Any] = {} + metrics.update(self._evaluate_rollout_metrics()) + metrics.update(self._evaluate_preference_metrics()) + if not metrics: + return {} + return self._record_metrics("eval", metrics) + + def train(self, resume_from_checkpoint: Optional[str] = None): + if self.use_native: + return self._train_native(resume_from_checkpoint=resume_from_checkpoint) + return self._train_subprocess() + + def _train_native(self, resume_from_checkpoint: Optional[str] = None): + self._apply_lora_if_needed() + self._prepare_prompt_samples() + self._prepare_eval_datasets() + if resume_from_checkpoint is None: + self._seed_training_run() + + actual_model = _actual_model(self.model) + optimizer = self._optimizer_for_training() + self.optimizer = optimizer + + if resume_from_checkpoint is not None: + self.load_state(optimizer=optimizer, checkpoint_dir=Path(resume_from_checkpoint)) + else: + self._ensure_reference_policy() + + reference_model = _actual_model(self.reference_policy.model) + if self.rollout_samples: + _ensure_fixed_rollout_reference_cache( + self, + f"{self.algorithm}.train_rollout_reference_logprobs", + self.rollout_samples, + ) - # Define loss and grad function - def loss_fn(model, batch_data): - chosen_ids, rejected_ids, chosen_lengths, rejected_lengths = batch_data + effective_beta = self.beta - loss, ntoks = compute_dpo_loss( + def loss_fn(model, batch): + loss, _ = grpo_recompute_loss( model=model, - chosen_ids=chosen_ids, - rejected_ids=rejected_ids, - chosen_lengths=chosen_lengths, - rejected_lengths=rejected_lengths, - beta=self.beta, - label_smoothing=self.label_smoothing, + reference_model=reference_model, + input_ids=batch.input_ids, + prompt_lengths=batch.prompt_lengths, + completion_lengths=batch.completion_lengths, + rollout_logprobs=batch.rollout_logprobs, + old_token_logprobs=batch.old_token_logprobs, + reference_logprobs=batch.reference_logprobs, + advantages=batch.advantages, + beta=effective_beta, + clip_epsilon=self.clip_epsilon, + epsilon_low=self.epsilon_low, + epsilon_high=self.epsilon_high, + temperature=self.temperature, + loss_type=self.resolved_loss_type, + max_completion_length=self.max_completion_length, ) return loss - loss_and_grad = nn.value_and_grad(actual_model, loss_fn) - - total_loss = 0.0 - for step in range(self.iters): - # Get batch - batch_idx = step % len(tokenized_data) - sample = tokenized_data[batch_idx] - - # Pad sequences - max_len = max(sample['chosen_length'], sample['rejected_length']) - pad_id = self.tokenizer.pad_token_id or 0 - - chosen_padded = self._pad_to_length(sample['chosen_ids'], max_len, pad_id) - rejected_padded = self._pad_to_length(sample['rejected_ids'], max_len, pad_id) - - # Create batch tensors - chosen_ids = mx.array([chosen_padded]) - rejected_ids = mx.array([rejected_padded]) - chosen_lengths = mx.array([sample['chosen_length']]) - rejected_lengths = mx.array([sample['rejected_length']]) - - batch_data = (chosen_ids, rejected_ids, chosen_lengths, rejected_lengths) - - # Compute loss and gradients - loss, grads = loss_and_grad(actual_model, batch_data) - optimizer.update(actual_model, grads) - mx.eval(actual_model.parameters(), optimizer.state) - - total_loss += loss.item() - - # Logging - if (step + 1) % self.logging_steps == 0: - avg_loss = total_loss / self.logging_steps - print(f" Step {step + 1}/{self.iters} | Loss: {avg_loss:.4f}") - total_loss = 0.0 - - # Save checkpoint - if (step + 1) % self.save_steps == 0: - self._save_adapters(step + 1) + value_and_grad = nn.value_and_grad(actual_model, loss_fn) + running_loss = 0.0 + last_loss = None + + while self.global_step < self.iters: + if self.reward_source == "offline" and self.rollout_samples: + prompt_samples = self._next_offline_rollout_batch() + rollout_batch = self._collect_rollout_batch( + prompt_samples, + cache_key=f"{self.algorithm}.train_rollout_reference_logprobs", + ) + else: + prompt_samples = self._next_prompt_batch() + rollout_batch = self._collect_rollout_batch(prompt_samples) + self._last_rollout_batch = rollout_batch + effective_beta = self._effective_kl_beta(rollout_batch) + step_loss = 0.0 + step_updates = 0 + policy_update_started_at = time.perf_counter() + for _ in range(max(1, self.minibatch_reuse_steps)): + minibatches = assemble_minibatches( + rollout_batch.policy_eval, + minibatch_size=_rollout_score_batch_size(self), + shuffle=False, + mode="completion", + token_budget=self.score_chunk_size, + ) + for minibatch in minibatches: + loss, grads = value_and_grad(actual_model, minibatch) + optimizer.update(actual_model, grads) + mx.eval(actual_model.parameters(), optimizer.state) + step_loss += loss.item() + step_updates += 1 + policy_update_wall = time.perf_counter() - policy_update_started_at + + last_loss = step_loss / max(1, step_updates) + running_loss += last_loss + self.global_step += 1 + train_row = self._record_metrics( + "train", + { + **summarize_rollout_metrics(rollout_batch, policy_loss=last_loss), + **self._last_rollout_phase_metrics, + "policy_update_wall": float(policy_update_wall), + "policy_update_steps": float(step_updates), + }, + ) - # Final save - self._save_adapters(self.iters) + if self.global_step % self.logging_steps == 0: + print(f"GRPO step {self.global_step}/{self.iters} | {self._format_metric_summary(train_row)}") + running_loss = 0.0 - print("\n" + "=" * 70) - print("DPO Training Complete!") - print("=" * 70) - print(f" Adapters saved to: {self.adapter_path}") + if self.eval_steps and self.global_step % self.eval_steps == 0: + eval_row = self.evaluate() + if eval_row: + print(f"GRPO eval | {self._format_metric_summary(eval_row, namespace='eval')}") - return {"status": "success", "adapter_path": str(self.adapter_path)} + if self.global_step % self.save_steps == 0: + self.save_state(optimizer=optimizer) - def _save_adapters(self, step: int): - """Save adapter weights and config.""" - if _save_adapters_and_config(self.model, self.adapter_path): - print(f" ✓ Saved checkpoint at step {step}") + self.save_state(optimizer=optimizer) + return { + "status": "success", + "adapter_path": str(self.adapter_path), + "global_step": self.global_step, + "final_loss": last_loss, + } def _train_subprocess(self): - """Fallback: Train using subprocess (SFT approximation).""" warnings.warn( - "Native DPO training not available. Using SFT on chosen responses. " - "Install mlx-lm[train] for proper DPO loss.", - UserWarning + "Native GRPO training not available. Using SFT approximation.", + UserWarning, ) - - print("\n[Using Subprocess Training (SFT Approximation)]") - - # Prepare SFT data from chosen responses train_file = self.output_dir / "train.jsonl" valid_file = self.output_dir / "valid.jsonl" - - with open(train_file, 'w') as f: + with open(train_file, "w") as handle: for sample in self.train_dataset: - if 'prompt' in sample and 'chosen' in sample: - messages = [ - {"role": "user", "content": sample['prompt']}, - {"role": "assistant", "content": sample['chosen']} - ] - f.write(json.dumps({"messages": messages}) + '\n') - - import shutil - shutil.copy(train_file, valid_file) - - model_name = getattr(self.model, 'model_name', 'model') - + prompt = sample.get("prompt", sample.get("question", "")) + if not prompt: + continue + messages = [{"role": "user", "content": prompt}] + if "response" in sample or "answer" in sample: + response = sample.get("response", sample.get("answer", "")) + messages.append({"role": "assistant", "content": response}) + handle.write(json.dumps({"messages": messages}) + "\n") + valid_file.write_text(train_file.read_text()) cmd = [ "mlx_lm.lora", - "--model", model_name, + "--model", + getattr(self.model, "model_name", "model"), "--train", - "--data", str(self.output_dir), - "--iters", str(self.iters), - "--learning-rate", str(self.learning_rate), - "--batch-size", str(self.batch_size), - "--adapter-path", str(self.adapter_path), + "--data", + str(self.output_dir), + "--iters", + str(self.iters), + "--adapter-path", + str(self.adapter_path), ] - - print(f"Running: {' '.join(cmd)}") subprocess.run(cmd, check=True) - - print("DPO Training Complete (SFT approximation)!") return {"status": "success", "adapter_path": str(self.adapter_path)} -class ORPOTrainer: - """ - Odds Ratio Preference Optimization Trainer. - - ORPO combines SFT and preference alignment in a single training step, - making it simpler and more memory-efficient than DPO. - - Compatible with TRL's ORPOTrainer API. - Now with PROPER ORPO loss implementation! - - Example: - >>> trainer = ORPOTrainer( - ... model=model, - ... train_dataset=preference_dataset, - ... tokenizer=tokenizer, - ... args=ORPOConfig(beta=0.1), - ... ) - >>> trainer.train() - """ +class PPOTrainer(_RLTrainerBase): + algorithm = "ppo" + requires_reference_policy = True def __init__( self, model: Any, train_dataset: Any, + eval_dataset: Any = None, + eval_preference_dataset: Any = None, tokenizer: Optional[Any] = None, - args: Optional[ORPOConfig] = None, + reward_fn: Optional[Callable] = None, + reward_model: Optional[Any] = None, + ref_model: Optional[Any] = None, + value_model: Optional[Any] = None, + args: Optional[PPOConfig] = None, use_native: bool = True, - **kwargs + **kwargs, ): self.model = model self.train_dataset = train_dataset - self.tokenizer = tokenizer or getattr(model, 'tokenizer', None) + self.eval_dataset = eval_dataset + self.eval_preference_dataset = eval_preference_dataset + self.tokenizer = tokenizer or getattr(model, "tokenizer", None) + self.ref_model = ref_model + self.config = args or PPOConfig() + self.reward_model = reward_model or self.config.reward_model + self.reward_fn = reward_fn if reward_fn is not None else self.config.reward_fn + self.reward_sources = self.config.reward_sources + self.reward_source = self.config.reward_source + self.value_model = value_model or self.config.value_model or build_value_model(model) self.use_native = use_native and HAS_NATIVE_TRAINING - - if args is None: - args = ORPOConfig() - - self.config = args - self.beta = args.beta - self.output_dir = Path(args.output_dir) - self.learning_rate = args.learning_rate - self.batch_size = args.per_device_train_batch_size - self.max_steps = args.max_steps - self.max_seq_length = args.max_seq_length - self.logging_steps = args.logging_steps - self.save_steps = args.save_steps - - if self.max_steps > 0: - self.iters = self.max_steps - else: - dataset_size = len(train_dataset) if hasattr(train_dataset, '__len__') else 100 - self.iters = max(1, (dataset_size // self.batch_size) * args.num_train_epochs) - + self.output_dir = Path(self.config.output_dir) + self.learning_rate = self.config.learning_rate + self.value_learning_rate = self.config.value_learning_rate + self.batch_size = self.config.per_device_train_batch_size + self.rollout_batch_size = self.config.rollout_batch_size or self.batch_size + self.max_steps = self.config.max_steps + self.max_seq_length = self.config.max_seq_length + self.max_completion_length = self.config.max_completion_length + self.num_generations = self.config.num_generations + self.ppo_epochs = self.config.ppo_epochs + self.minibatch_reuse_steps = self.config.minibatch_reuse_steps + self.temperature = self.config.temperature + self.clip_epsilon = self.config.clip_epsilon + self.beta = self.config.kl_beta + self.reward_normalization = self.config.reward_normalization + self.mask_truncated_completions = self.config.mask_truncated_completions + self.entropy_bonus = self.config.entropy_bonus + self.gamma = self.config.gamma + self.gae_lambda = self.config.gae_lambda + self.normalize_advantages = self.config.normalize_advantages + self.advantage_estimator = self.config.advantage_estimator + self.kl_target = self.config.kl_target + self.kl_penalty_mode = self.config.kl_penalty_mode + self.eval_steps = self.config.eval_steps + self.eval_num_batches = self.config.eval_num_batches + self.eval_num_generations = self.config.eval_num_generations or self.num_generations + self.generation_batch_size = self.config.generation_batch_size + self.score_chunk_size = self.config.score_chunk_size + self.precompute_reference_scores = self.config.precompute_reference_scores + self.logging_steps = self.config.logging_steps + self.save_steps = self.config.save_steps + dataset_size = len(train_dataset) if hasattr(train_dataset, "__len__") else 100 + self.iters = self.max_steps if self.max_steps > 0 else max( + 1, (dataset_size // max(1, self.batch_size)) * self.config.num_train_epochs + ) self.output_dir.mkdir(parents=True, exist_ok=True) - self.adapter_path = self.output_dir / "adapters" + self.adapter_path = self.output_dir / "policy" self.adapter_path.mkdir(parents=True, exist_ok=True) + self.prompt_samples: List[Dict[str, Any]] = [] + self.rollout_samples: List[Dict[str, Any]] = [] + self.eval_prompt_samples: List[Dict[str, Any]] = [] + self.eval_rollout_samples: List[Dict[str, Any]] = [] + self.eval_preference_samples: List[Dict[str, Any]] = [] + self.prepared_dataset_mode: Optional[str] = None + self.prompt_dataset_cursor = 0 + self.rollout_dataset_cursor = 0 + self._last_rollout_batch: Optional[RolloutBatch] = None + self._value_train_target = _ScalarRoleTrainTarget( + _actual_model(self.value_model.base_model), + self.value_model.head, + ) + self._init_native_state() - print(f"ORPOTrainer initialized:") - print(f" Beta: {self.beta}") - print(f" Learning rate: {self.learning_rate}") - print(f" Iterations: {self.iters}") - print(f" Native training: {self.use_native}") - - def _tokenize_preference_pair(self, sample: Dict) -> Dict: - """Tokenize a preference pair.""" - prompt = sample.get('prompt', '') - chosen = sample.get('chosen', '') - rejected = sample.get('rejected', '') + def _optimizer_learning_rates(self) -> Dict[str, float]: + return {"policy": self.learning_rate, "value": self.value_learning_rate} - chosen_ids = self.tokenizer.encode(prompt + chosen) - rejected_ids = self.tokenizer.encode(prompt + rejected) + def _trainer_cursor_state(self) -> Dict[str, int]: + return { + "dataset": int(self.dataset_cursor), + "prompt_dataset": int(self.prompt_dataset_cursor), + "offline_rollout_dataset": int(self.rollout_dataset_cursor), + } - if len(chosen_ids) > self.max_seq_length: - chosen_ids = chosen_ids[:self.max_seq_length] - if len(rejected_ids) > self.max_seq_length: - rejected_ids = rejected_ids[:self.max_seq_length] + def _restore_trainer_cursors(self, cursors: Dict[str, Any]) -> None: + super()._restore_trainer_cursors(cursors) + self.prompt_dataset_cursor = int(cursors.get("prompt_dataset", self.prompt_dataset_cursor)) + self.rollout_dataset_cursor = int(cursors.get("offline_rollout_dataset", self.rollout_dataset_cursor)) + def _sampling_config_payload(self) -> Dict[str, Any]: return { - 'chosen_ids': chosen_ids, - 'rejected_ids': rejected_ids, - 'chosen_length': len(chosen_ids), - 'rejected_length': len(rejected_ids), + "algorithm": self.algorithm, + "beta": self.beta, + "clip_epsilon": self.clip_epsilon, + "rollout_batch_size": self.rollout_batch_size, + "minibatch_reuse_steps": self.minibatch_reuse_steps, + "gamma": self.gamma, + "gae_lambda": self.gae_lambda, + "normalize_advantages": self.normalize_advantages, + "value_learning_rate": self.value_learning_rate, + "kl_target": self.kl_target, + "kl_penalty_mode": self.kl_penalty_mode, + "reward_source": self.reward_source, + "reward_normalization": self.reward_normalization, + "mask_truncated_completions": self.mask_truncated_completions, + "entropy_bonus": self.entropy_bonus, + "advantage_estimator": self.advantage_estimator, + "temperature": self.temperature, + "num_generations": self.num_generations, + "max_completion_length": self.max_completion_length, + "max_seq_length": self.max_seq_length, + "generation_batch_size": self.generation_batch_size, + "score_chunk_size": self.score_chunk_size, } - def _pad_to_length(self, ids: List[int], length: int, pad_id: int = 0) -> List[int]: - if len(ids) >= length: - return ids[:length] - return ids + [pad_id] * (length - len(ids)) + def _prepare_prompt_samples(self) -> None: + self.prompt_samples, self.rollout_samples, self.prepared_dataset_mode = _prepare_on_policy_samples( + self.train_dataset, + self.tokenizer, + self.config, + ) + if not self.prompt_samples and not self.rollout_samples: + raise ValueError("PPOTrainer requires prompt or rollout samples.") + + def _prepare_eval_datasets(self) -> None: + self.eval_prompt_samples = [] + self.eval_rollout_samples = [] + self.eval_preference_samples = [] + if self.eval_dataset is not None: + self.eval_prompt_samples, self.eval_rollout_samples, _ = _prepare_on_policy_samples( + self.eval_dataset, + self.tokenizer, + self.config, + ) + if self.eval_preference_dataset is not None: + prepared = prepare_rl_dataset( + self.eval_preference_dataset, + mode="preference", + tokenizer=self.tokenizer, + chat_template=getattr(self.config, "chat_template", None), + ) + for sample_index, sample in enumerate(prepared): + prompt = sample.get("prompt", "") + prompt_ids = _encode_text(self.tokenizer, prompt) + chosen_ids = _encode_text(self.tokenizer, sample["chosen"], add_special_tokens=False) + rejected_ids = _encode_text(self.tokenizer, sample["rejected"], add_special_tokens=False) + self.eval_preference_samples.append( + { + "sample_index": sample_index, + "prompt_ids": prompt_ids, + "chosen_ids": prompt_ids + chosen_ids, + "rejected_ids": prompt_ids + rejected_ids, + "prompt_length": len(prompt_ids), + "chosen_completion_length": len(chosen_ids), + "rejected_completion_length": len(rejected_ids), + } + ) + + def _resolve_reward_evaluator(self) -> Any: + evaluator = _resolve_reward_evaluator( + self.reward_model, + self.reward_fn, + self.reward_sources, + ) + if evaluator is not None: + return evaluator + return lambda response, context: len(response.split()) / 100.0 + + def _next_prompt_batch(self) -> List[Dict[str, Any]]: + batch, self.prompt_dataset_cursor = _next_cursor_batch( + self.prompt_samples, + self.rollout_batch_size, + self.prompt_dataset_cursor, + self.algorithm, + ) + self.dataset_cursor = self.prompt_dataset_cursor + return batch + + def _next_offline_rollout_batch(self) -> List[Dict[str, Any]]: + batch, self.rollout_dataset_cursor = _next_cursor_batch( + self.rollout_samples, + self.rollout_batch_size * self.num_generations, + self.rollout_dataset_cursor, + self.algorithm, + ) + self.dataset_cursor = self.rollout_dataset_cursor + return batch - def train(self): - """Train using ORPO with proper loss.""" - print("=" * 70) - print("Starting ORPO Training") - print("=" * 70) + def _collect_fixed_rollout_batch( + self, + samples: List[Dict[str, Any]], + cache_key: Optional[str] = None, + ) -> RolloutBatch: + cached_reference_logprobs = None + if cache_key is not None: + cached_reference_logprobs = _ensure_fixed_rollout_reference_cache(self, cache_key, samples) + return _collect_fixed_rollout_batch( + self, + samples, + cached_reference_logprobs=cached_reference_logprobs, + ) - if self.use_native: - return self._train_native() + def _collect_rollout_batch( + self, + prompt_samples: List[Dict[str, Any]], + num_generations: Optional[int] = None, + cache_key: Optional[str] = None, + ) -> RolloutBatch: + generations = self.num_generations if num_generations is None else num_generations + reward_evaluator = self._resolve_reward_evaluator() + if prompt_samples and "completion" in prompt_samples[0]: + rollout_batch = self._collect_fixed_rollout_batch(prompt_samples, cache_key=cache_key) else: - return self._train_subprocess() - - def _train_native(self): - """Train with native ORPO loss.""" - print("\n[Using Native ORPO Training with Proper Loss]") + rollout_batch = collect_rollouts( + _actual_model(self.model), + self.tokenizer, + prompt_samples=prompt_samples, + sampling_config={ + "num_generations": generations, + "temperature": self.temperature, + "max_completion_length": self.max_completion_length, + "max_seq_length": self.max_seq_length, + "generation_batch_size": self.generation_batch_size, + }, + reward_evaluator=None, + collect_sample_stats=self.entropy_bonus != 0.0, + ) + if self.reward_source == "offline": + raise ValueError("reward_source='offline' requires rollout samples with completion/reward fields.") + reward_batch = evaluate_rewards(rollout_batch, reward_evaluator) + rollout_batch.rewards = reward_batch.scalar_rewards + if prompt_samples and "completion" in prompt_samples[0]: + if self.reward_source == "online": + reward_batch = evaluate_rewards(rollout_batch, reward_evaluator) + rollout_batch.rewards = reward_batch.scalar_rewards + elif self.reward_source == "hybrid": + reward_batch = evaluate_rewards(rollout_batch, reward_evaluator) + rollout_batch.rewards = rollout_batch.rewards + reward_batch.scalar_rewards + if rollout_batch.rewards is None: + rollout_batch.rewards = mx.zeros((len(rollout_batch.prompt_ids),), dtype=mx.float32) + if self.entropy_bonus and rollout_batch.token_entropies is not None: + entropy_bonus = mx.mean(rollout_batch.token_entropies, axis=-1) * self.entropy_bonus + rollout_batch.rewards = rollout_batch.rewards + entropy_bonus.astype(mx.float32) + if self.reward_normalization != "none": + rollout_batch.rewards = _normalize_reward_values( + rollout_batch.rewards, + rollout_batch.prompt_group_indices, + self.reward_normalization, + ) + rollout_batch = score_rollout_references( + _actual_model(self.reference_policy.model), + rollout_batch, + batch_size=_rollout_score_batch_size(self, num_generations=generations), + token_budget=self.score_chunk_size, + ) + if self.mask_truncated_completions: + rollout_batch = _apply_truncation_mask_to_rollout(rollout_batch) + rollout_batch = predict_rollout_values( + self.value_model, + rollout_batch, + batch_size=_rollout_score_batch_size(self, num_generations=generations), + token_budget=self.score_chunk_size, + ) + advantage_mode = self.advantage_estimator + rollout_values = rollout_batch.value_predictions if advantage_mode == "gae" else None + returns, advantages = compute_returns_and_advantages( + rewards=rollout_batch.rewards, + values=rollout_values, + prompt_group_indices=rollout_batch.prompt_group_indices, + mode=advantage_mode, + gamma=self.gamma, + gae_lambda=self.gae_lambda, + normalize=self.normalize_advantages, + ) + rollout_batch.returns = returns + rollout_batch.advantages = advantages + rollout_batch.policy_eval.returns = returns + rollout_batch.policy_eval.advantages = advantages + rollout_batch.policy_eval.reference_logprobs = rollout_batch.reference_logprobs + rollout_batch.policy_eval.value_predictions = rollout_batch.value_predictions + return rollout_batch + + def _evaluate_rollout_metrics(self) -> Dict[str, Any]: + prompt_limit = max(1, self.eval_num_batches or 1) * self.rollout_batch_size + rollout_limit = max(1, self.eval_num_batches or 1) * self.rollout_batch_size * self.num_generations + if self.eval_rollout_samples: + rollout_batch = self._collect_rollout_batch( + self.eval_rollout_samples[:rollout_limit], + cache_key=f"{self.algorithm}.eval_rollout_reference_logprobs", + ) + elif self.eval_prompt_samples: + rollout_batch = self._collect_rollout_batch( + self.eval_prompt_samples[:prompt_limit], + num_generations=self.eval_num_generations, + ) + else: + return {} + effective_beta = self._effective_kl_beta(rollout_batch) + policy_loss, _ = ppo_sequence_loss( + model=_actual_model(self.model), + batch=rollout_batch.policy_eval, + beta=effective_beta, + clip_epsilon=self.clip_epsilon, + temperature=self.temperature, + ) + value_loss, _ = value_model_regression_loss( + self.value_model, + input_ids=rollout_batch.policy_eval.input_ids, + sequence_lengths=rollout_batch.policy_eval.sequence_lengths, + targets=rollout_batch.policy_eval.returns, + prompt_lengths=rollout_batch.policy_eval.prompt_lengths, + completion_lengths=rollout_batch.policy_eval.completion_lengths, + ) + return summarize_rollout_metrics( + rollout_batch, + policy_loss=float(policy_loss.item()), + value_loss=float(value_loss.item()), + ) - if hasattr(self.model, '_apply_lora') and not getattr(self.model, '_lora_applied', False): - self.model._apply_lora() + def _evaluate_preference_metrics(self) -> Dict[str, Any]: + if not self.eval_preference_samples: + return {} + limit = max(1, self.eval_num_batches or 1) * self.batch_size + samples = self.eval_preference_samples[:limit] + chosen_batch = make_policy_eval_batch( + [sample["chosen_ids"] for sample in samples], + pad_id=_pad_token_id(self.tokenizer), + mode="completion", + prompt_lengths=[sample["prompt_length"] for sample in samples], + completion_lengths=[sample["chosen_completion_length"] for sample in samples], + sample_indices=mx.array([sample["sample_index"] for sample in samples]), + ) + rejected_batch = make_policy_eval_batch( + [sample["rejected_ids"] for sample in samples], + pad_id=_pad_token_id(self.tokenizer), + mode="completion", + prompt_lengths=[sample["prompt_length"] for sample in samples], + completion_lengths=[sample["rejected_completion_length"] for sample in samples], + sample_indices=mx.array([sample["sample_index"] for sample in samples]), + ) + chosen_scores = score_policy_in_chunks( + _actual_model(self.model), + chosen_batch, + batch_size=max(1, self.batch_size), + token_budget=self.score_chunk_size, + mode="completion", + ).summed_logprobs + rejected_scores = score_policy_in_chunks( + _actual_model(self.model), + rejected_batch, + batch_size=max(1, self.batch_size), + token_budget=self.score_chunk_size, + mode="completion", + ).summed_logprobs + return { + "preference_win_rate": float(mx.mean((chosen_scores > rejected_scores).astype(mx.float32)).item()) + } - # Prepare data - tokenized_data = [] - for sample in self.train_dataset: - if 'prompt' in sample and 'chosen' in sample and 'rejected' in sample: - tokenized_data.append(self._tokenize_preference_pair(sample)) - print(f"✓ Prepared {len(tokenized_data)} preference pairs") + def evaluate(self) -> Dict[str, Any]: + self._prepare_eval_datasets() + if self.reference_policy is None: + self._ensure_reference_policy() + with self._preserve_rng_state(): + mx.random.seed(int(self.seed) + 100000 + int(self.global_step)) + metrics: Dict[str, Any] = {} + metrics.update(self._evaluate_rollout_metrics()) + metrics.update(self._evaluate_preference_metrics()) + if not metrics: + return {} + return self._record_metrics("eval", metrics) + + def train(self, resume_from_checkpoint: Optional[str] = None): + if not self.use_native: + raise ValueError("PPOTrainer requires native MLX training support.") + return self._train_native(resume_from_checkpoint=resume_from_checkpoint) + + def _train_native(self, resume_from_checkpoint: Optional[str] = None): + self._apply_lora_if_needed() + if hasattr(self.value_model.base_model, "_apply_lora") and not getattr(self.value_model.base_model, "_lora_applied", False): + self.value_model.base_model._apply_lora() + self._prepare_prompt_samples() + self._prepare_eval_datasets() + if resume_from_checkpoint is None: + self._seed_training_run() + + policy_model = _actual_model(self.model) + policy_optimizer = self._optimizer_for_training(self.learning_rate) + value_optimizer = self._optimizer_for_training(self.value_learning_rate) + self.optimizer = policy_optimizer + self.optimizers = {"policy": policy_optimizer, "value": value_optimizer} + + if resume_from_checkpoint is not None: + self.load_state( + checkpoint_dir=Path(resume_from_checkpoint), + optimizers=self.optimizers, + ) + else: + self._ensure_reference_policy() + if self.rollout_samples: + _ensure_fixed_rollout_reference_cache( + self, + f"{self.algorithm}.train_rollout_reference_logprobs", + self.rollout_samples, + ) - actual_model = self.model.model if hasattr(self.model, 'model') else self.model - lr_schedule = optim.cosine_decay(self.learning_rate, self.iters) - optimizer = optim.AdamW(learning_rate=lr_schedule) + effective_beta = self.beta - def loss_fn(model, batch_data): - chosen_ids, rejected_ids, chosen_lengths, rejected_lengths = batch_data - loss, _ = compute_orpo_loss( - model, chosen_ids, rejected_ids, chosen_lengths, rejected_lengths, self.beta + def policy_loss_fn(model, batch): + loss, _ = ppo_sequence_loss( + model=model, + batch=batch, + beta=effective_beta, + clip_epsilon=self.clip_epsilon, + temperature=self.temperature, ) return loss - loss_and_grad = nn.value_and_grad(actual_model, loss_fn) - - total_loss = 0.0 - for step in range(self.iters): - batch_idx = step % len(tokenized_data) - sample = tokenized_data[batch_idx] - - max_len = max(sample['chosen_length'], sample['rejected_length']) - pad_id = self.tokenizer.pad_token_id or 0 - - chosen_ids = mx.array([self._pad_to_length(sample['chosen_ids'], max_len, pad_id)]) - rejected_ids = mx.array([self._pad_to_length(sample['rejected_ids'], max_len, pad_id)]) - chosen_lengths = mx.array([sample['chosen_length']]) - rejected_lengths = mx.array([sample['rejected_length']]) - - loss, grads = loss_and_grad(actual_model, (chosen_ids, rejected_ids, chosen_lengths, rejected_lengths)) - optimizer.update(actual_model, grads) - mx.eval(actual_model.parameters(), optimizer.state) - - total_loss += loss.item() - - if (step + 1) % self.logging_steps == 0: - print(f" Step {step + 1}/{self.iters} | Loss: {total_loss / self.logging_steps:.4f}") - total_loss = 0.0 - - # Save adapters and config - _save_adapters_and_config(self.model, self.adapter_path) + def value_loss_fn(_, batch): + loss, _ = value_model_regression_loss( + self.value_model, + input_ids=batch.input_ids, + sequence_lengths=batch.sequence_lengths, + targets=batch.returns, + prompt_lengths=batch.prompt_lengths, + completion_lengths=batch.completion_lengths, + ) + return loss - print("\n" + "=" * 70) - print("ORPO Training Complete!") - print("=" * 70) - print(f" Adapters saved to: {self.adapter_path}") - return {"status": "success", "adapter_path": str(self.adapter_path)} + policy_value_and_grad = nn.value_and_grad(policy_model, policy_loss_fn) + value_value_and_grad = nn.value_and_grad(self._value_train_target, value_loss_fn) + running_loss = 0.0 + last_policy_loss = None + last_value_loss = None + + while self.global_step < self.iters: + if self.reward_source == "offline" and self.rollout_samples: + prompt_samples = self._next_offline_rollout_batch() + rollout_batch = self._collect_rollout_batch( + prompt_samples, + cache_key=f"{self.algorithm}.train_rollout_reference_logprobs", + ) + else: + prompt_samples = self._next_prompt_batch() + rollout_batch = self._collect_rollout_batch(prompt_samples) + self._last_rollout_batch = rollout_batch + effective_beta = self._effective_kl_beta(rollout_batch) + + total_policy_loss = 0.0 + total_value_loss = 0.0 + update_count = 0 + for _ in range(max(1, self.minibatch_reuse_steps)): + minibatches = assemble_minibatches( + rollout_batch.policy_eval, + minibatch_size=_rollout_score_batch_size(self), + shuffle=True, + mode="completion", + token_budget=self.score_chunk_size, + ) + for minibatch in minibatches: + policy_loss, policy_grads = policy_value_and_grad(policy_model, minibatch) + policy_optimizer.update(policy_model, policy_grads) + mx.eval(policy_model.parameters(), policy_optimizer.state) + + value_loss, value_grads = value_value_and_grad(self._value_train_target, minibatch) + value_optimizer.update(self._value_train_target, value_grads) + mx.eval(self._value_train_target.parameters(), value_optimizer.state) + + total_policy_loss += float(policy_loss.item()) + total_value_loss += float(value_loss.item()) + update_count += 1 + + last_policy_loss = total_policy_loss / max(1, update_count) + last_value_loss = total_value_loss / max(1, update_count) + running_loss += last_policy_loss + self.global_step += 1 + train_row = self._record_metrics( + "train", + summarize_rollout_metrics( + rollout_batch, + policy_loss=last_policy_loss, + value_loss=last_value_loss, + ), + ) - def _train_subprocess(self): - """Fallback subprocess training.""" - warnings.warn("Using SFT approximation for ORPO.", UserWarning) + if self.global_step % self.logging_steps == 0: + print(f"PPO step {self.global_step}/{self.iters} | {self._format_metric_summary(train_row)}") + running_loss = 0.0 - train_file = self.output_dir / "train.jsonl" - with open(train_file, 'w') as f: - for sample in self.train_dataset: - if 'prompt' in sample and 'chosen' in sample: - messages = [ - {"role": "user", "content": sample['prompt']}, - {"role": "assistant", "content": sample['chosen']} - ] - f.write(json.dumps({"messages": messages}) + '\n') + if self.eval_steps and self.global_step % self.eval_steps == 0: + eval_row = self.evaluate() + if eval_row: + print(f"PPO eval | {self._format_metric_summary(eval_row, namespace='eval')}") - import shutil - shutil.copy(train_file, self.output_dir / "valid.jsonl") + if self.global_step % self.save_steps == 0: + self.save_state(optimizers=self.optimizers) - cmd = [ - "mlx_lm.lora", "--model", getattr(self.model, 'model_name', 'model'), - "--train", "--data", str(self.output_dir), "--iters", str(self.iters), - "--learning-rate", str(self.learning_rate), "--batch-size", str(self.batch_size), - "--adapter-path", str(self.adapter_path), - ] - subprocess.run(cmd, check=True) - return {"status": "success"} + self.save_state(optimizers=self.optimizers) + return { + "status": "success", + "adapter_path": str(self.adapter_path), + "global_step": self.global_step, + "final_loss": last_policy_loss, + "final_value_loss": last_value_loss, + } -class GRPOTrainer: - """ - Group Relative Policy Optimization Trainer. - - GRPO is the technique used by DeepSeek to train reasoning models like R1. - It removes the need for a value model by using group statistics from - multiple generations and custom reward functions. - - Key features: - - No value model needed (uses group statistics) - - Custom reward functions (for math, code verification, etc.) - - Supports GRPO, Dr.GRPO, DAPO, BNPO variants - - NOW WITH FULL MULTI-GENERATION IMPLEMENTATION! - - Example: - >>> def math_reward(response, prompt): - ... # Custom reward for math problems - ... return 1.0 if "correct" in response.lower() else 0.0 - >>> - >>> trainer = GRPOTrainer( - ... model=model, - ... train_dataset=math_dataset, - ... tokenizer=tokenizer, - ... reward_fn=math_reward, - ... args=GRPOConfig( - ... loss_type='grpo', - ... num_generations=4, - ... ), - ... ) - >>> trainer.train() - """ +class OnlineDPOTrainer(_RLTrainerBase): + algorithm = "online_dpo" + requires_reference_policy = True def __init__( self, model: Any, train_dataset: Any, + eval_dataset: Any = None, + eval_preference_dataset: Any = None, tokenizer: Optional[Any] = None, reward_fn: Optional[Callable] = None, - args: Optional[GRPOConfig] = None, + reward_model: Optional[Any] = None, + ref_model: Optional[Any] = None, + args: Optional[OnlineDPOConfig] = None, use_native: bool = True, - **kwargs + **kwargs, ): self.model = model self.train_dataset = train_dataset - self.tokenizer = tokenizer or getattr(model, 'tokenizer', None) + self.eval_dataset = eval_dataset + self.eval_preference_dataset = eval_preference_dataset + self.tokenizer = tokenizer or getattr(model, "tokenizer", None) + self.ref_model = ref_model + self.config = args or OnlineDPOConfig() + self.reward_model = reward_model or self.config.reward_model + self.reward_fn = reward_fn if reward_fn is not None else self.config.reward_fn + self.reward_sources = self.config.reward_sources + self.reward_source = self.config.reward_source self.use_native = use_native and HAS_NATIVE_TRAINING - - if args is None: - args = GRPOConfig() - - self.config = args - self.loss_type = args.loss_type - self.beta = args.beta - self.num_generations = args.num_generations - self.max_completion_length = args.max_completion_length - self.reward_fn = reward_fn or args.reward_fn - self.output_dir = Path(args.output_dir) - self.learning_rate = args.learning_rate - self.batch_size = args.per_device_train_batch_size - self.max_steps = args.max_steps - self.temperature = args.temperature - self.logging_steps = args.logging_steps - self.save_steps = args.save_steps - - if self.max_steps > 0: - self.iters = self.max_steps - else: - dataset_size = len(train_dataset) if hasattr(train_dataset, '__len__') else 100 - self.iters = max(1, (dataset_size // self.batch_size) * args.num_train_epochs) - + self.beta = self.config.beta + self.kl_target = self.config.kl_target + self.kl_penalty_mode = self.config.kl_penalty_mode + self.label_smoothing = self.config.label_smoothing + self.output_dir = Path(self.config.output_dir) + self.learning_rate = self.config.learning_rate + self.batch_size = self.config.per_device_train_batch_size + self.rollout_batch_size = self.config.rollout_batch_size or self.batch_size + self.max_steps = self.config.max_steps + self.max_seq_length = self.config.max_seq_length + self.max_completion_length = self.config.max_completion_length + self.num_generations = self.config.num_generations + self.temperature = self.config.temperature + self.reward_normalization = self.config.reward_normalization + self.mask_truncated_completions = self.config.mask_truncated_completions + self.minibatch_reuse_steps = self.config.minibatch_reuse_steps + self.entropy_bonus = self.config.entropy_bonus + self.eval_steps = self.config.eval_steps + self.eval_num_batches = self.config.eval_num_batches + self.eval_num_generations = self.config.eval_num_generations or self.num_generations + self.generation_batch_size = self.config.generation_batch_size + self.score_chunk_size = self.config.score_chunk_size + self.precompute_reference_scores = self.config.precompute_reference_scores + self.logging_steps = self.config.logging_steps + self.save_steps = self.config.save_steps + dataset_size = len(train_dataset) if hasattr(train_dataset, "__len__") else 100 + self.iters = self.max_steps if self.max_steps > 0 else max( + 1, (dataset_size // max(1, self.batch_size)) * self.config.num_train_epochs + ) self.output_dir.mkdir(parents=True, exist_ok=True) - self.adapter_path = self.output_dir / "adapters" + self.adapter_path = self.output_dir / "policy" self.adapter_path.mkdir(parents=True, exist_ok=True) + self.prompt_samples: List[Dict[str, Any]] = [] + self.rollout_samples: List[Dict[str, Any]] = [] + self.eval_prompt_samples: List[Dict[str, Any]] = [] + self.eval_rollout_samples: List[Dict[str, Any]] = [] + self.eval_preference_samples: List[Dict[str, Any]] = [] + self.prepared_dataset_mode: Optional[str] = None + self.prompt_dataset_cursor = 0 + self.rollout_dataset_cursor = 0 + self._last_rollout_batch: Optional[RolloutBatch] = None + self._init_native_state() + if self.num_generations < 2: + raise ValueError("OnlineDPOTrainer requires num_generations >= 2.") + + def _prepare_prompt_samples(self) -> None: + self.prompt_samples, self.rollout_samples, self.prepared_dataset_mode = _prepare_on_policy_samples( + self.train_dataset, + self.tokenizer, + self.config, + ) + if not self.prompt_samples and not self.rollout_samples: + raise ValueError("OnlineDPOTrainer requires prompt or rollout samples.") - # Default reward function if none provided - if self.reward_fn is None: - self.reward_fn = lambda response, prompt: len(response.split()) / 100.0 - - print(f"GRPOTrainer initialized:") - print(f" Loss type: {self.loss_type}") - print(f" Beta: {self.beta}") - print(f" Num generations: {self.num_generations}") - print(f" Custom reward fn: {'Yes' if reward_fn else 'Default (length-based)'}") - print(f" Learning rate: {self.learning_rate}") - print(f" Iterations: {self.iters}") - print(f" Native GRPO: {self.use_native}") - - def train(self): - """ - Train using GRPO with multi-generation sampling. - """ - print("=" * 70) - print(f"Starting GRPO Training (loss_type={self.loss_type})") - print("=" * 70) - - if self.use_native: - return self._train_native() - else: - return self._train_subprocess() - - def _train_native(self): - """Train with native GRPO: multi-generation + reward + policy gradient.""" - print("\n[Using Native GRPO Training with Multi-Generation]") - - if hasattr(self.model, '_apply_lora') and not getattr(self.model, '_lora_applied', False): - self.model._apply_lora() - - # Prepare prompts - prompts = [] - for sample in self.train_dataset: - if 'prompt' in sample: - prompts.append(sample['prompt']) - elif 'question' in sample: - prompts.append(sample['question']) - print(f"✓ Prepared {len(prompts)} prompts") - - actual_model = self.model.model if hasattr(self.model, 'model') else self.model - lr_schedule = optim.cosine_decay(self.learning_rate, self.iters) - optimizer = optim.AdamW(learning_rate=lr_schedule) + def _trainer_cursor_state(self) -> Dict[str, int]: + return { + "dataset": int(self.dataset_cursor), + "prompt_dataset": int(self.prompt_dataset_cursor), + "offline_rollout_dataset": int(self.rollout_dataset_cursor), + } - print(f"\nStarting training for {self.iters} iterations...") - print(f" Generating {self.num_generations} completions per prompt") + def _restore_trainer_cursors(self, cursors: Dict[str, Any]) -> None: + super()._restore_trainer_cursors(cursors) + self.prompt_dataset_cursor = int(cursors.get("prompt_dataset", self.prompt_dataset_cursor)) + self.rollout_dataset_cursor = int(cursors.get("offline_rollout_dataset", self.rollout_dataset_cursor)) - total_loss = 0.0 - for step in range(self.iters): - # Get prompt for this step - prompt_idx = step % len(prompts) - prompt = prompts[prompt_idx] + def _sampling_config_payload(self) -> Dict[str, Any]: + return { + "algorithm": self.algorithm, + "beta": self.beta, + "label_smoothing": self.label_smoothing, + "kl_target": self.kl_target, + "kl_penalty_mode": self.kl_penalty_mode, + "rollout_batch_size": self.rollout_batch_size, + "minibatch_reuse_steps": self.minibatch_reuse_steps, + "reward_source": self.reward_source, + "reward_normalization": self.reward_normalization, + "mask_truncated_completions": self.mask_truncated_completions, + "entropy_bonus": self.entropy_bonus, + "temperature": self.temperature, + "num_generations": self.num_generations, + "max_completion_length": self.max_completion_length, + "max_seq_length": self.max_seq_length, + "generation_batch_size": self.generation_batch_size, + "score_chunk_size": self.score_chunk_size, + } - # Compute GRPO loss with multi-generation - loss, n_gen = grpo_batch_loss( - model=actual_model, + def _prepare_eval_datasets(self) -> None: + self.eval_prompt_samples = [] + self.eval_rollout_samples = [] + self.eval_preference_samples = [] + if self.eval_dataset is not None: + self.eval_prompt_samples, self.eval_rollout_samples, _ = _prepare_on_policy_samples( + self.eval_dataset, + self.tokenizer, + self.config, + ) + if self.eval_preference_dataset is not None: + prepared = prepare_rl_dataset( + self.eval_preference_dataset, + mode="preference", tokenizer=self.tokenizer, - prompts=[prompt], - reward_fn=self.reward_fn, - num_generations=self.num_generations, - temperature=self.temperature, - max_tokens=self.max_completion_length, - beta=self.beta, + chat_template=getattr(self.config, "chat_template", None), ) + for sample_index, sample in enumerate(prepared): + prompt = sample.get("prompt", "") + prompt_ids = _encode_text(self.tokenizer, prompt) + chosen_ids = prompt_ids + _encode_text(self.tokenizer, sample["chosen"], add_special_tokens=False) + rejected_ids = prompt_ids + _encode_text(self.tokenizer, sample["rejected"], add_special_tokens=False) + self.eval_preference_samples.append( + { + "sample_index": sample_index, + "chosen_ids": chosen_ids, + "rejected_ids": rejected_ids, + "chosen_length": len(chosen_ids), + "rejected_length": len(rejected_ids), + } + ) + + def _resolve_reward_evaluator(self) -> Any: + evaluator = _resolve_reward_evaluator( + self.reward_model, + self.reward_fn, + self.reward_sources, + ) + if evaluator is not None: + return evaluator + return lambda response, context: len(response.split()) / 100.0 + + def _next_prompt_batch(self) -> List[Dict[str, Any]]: + batch, self.prompt_dataset_cursor = _next_cursor_batch( + self.prompt_samples, + self.rollout_batch_size, + self.prompt_dataset_cursor, + self.algorithm, + ) + self.dataset_cursor = self.prompt_dataset_cursor + return batch + + def _next_offline_rollout_batch(self) -> List[Dict[str, Any]]: + batch, self.rollout_dataset_cursor = _next_cursor_batch( + self.rollout_samples, + self.rollout_batch_size * self.num_generations, + self.rollout_dataset_cursor, + self.algorithm, + ) + self.dataset_cursor = self.rollout_dataset_cursor + return batch + + def _collect_fixed_rollout_batch( + self, + samples: List[Dict[str, Any]], + cache_key: Optional[str] = None, + ) -> RolloutBatch: + cached_reference_logprobs = None + if cache_key is not None: + cached_reference_logprobs = _ensure_fixed_rollout_reference_cache(self, cache_key, samples) + return _collect_fixed_rollout_batch( + self, + samples, + cached_reference_logprobs=cached_reference_logprobs, + ) - # Manual backward pass since grpo_batch_loss generates internally - # For a proper implementation, we'd need to track gradients through generation - # This is a simplified version that uses the loss for logging - mx.eval(loss) - total_loss += loss.item() + def _build_online_preference_batch(self, rollout_batch: RolloutBatch) -> Optional[PreferenceBatch]: + rankings = rank_grouped_rollouts(rollout_batch) + chosen_sequences: List[List[int]] = [] + rejected_sequences: List[List[int]] = [] + sample_indices: List[int] = [] + for ranking in rankings: + if ranking["all_tied"]: + continue + best = ranking["best_position"] + worst = ranking["worst_position"] + chosen_sequences.append(rollout_batch.prompt_ids[best] + rollout_batch.completion_ids[best]) + rejected_sequences.append(rollout_batch.prompt_ids[worst] + rollout_batch.completion_ids[worst]) + sample_indices.append(int(rollout_batch.sample_indices[best].item())) + if not chosen_sequences: + return None + + preference_batch = make_preference_batch( + chosen_sequences=chosen_sequences, + rejected_sequences=rejected_sequences, + pad_id=_pad_token_id(self.tokenizer), + sample_indices=sample_indices, + ) + reference_model = _actual_model(self.reference_policy.model) + preference_batch.chosen_reference_logprobs = score_policy_in_chunks( + reference_model, + preference_batch.chosen, + batch_size=max(1, self.batch_size), + token_budget=self.score_chunk_size, + mode="sequence", + ).summed_logprobs + preference_batch.rejected_reference_logprobs = score_policy_in_chunks( + reference_model, + preference_batch.rejected, + batch_size=max(1, self.batch_size), + token_budget=self.score_chunk_size, + mode="sequence", + ).summed_logprobs + preference_batch.chosen.reference_logprobs = preference_batch.chosen_reference_logprobs + preference_batch.rejected.reference_logprobs = preference_batch.rejected_reference_logprobs + return preference_batch + + def _collect_rollout_batch( + self, + prompt_samples: List[Dict[str, Any]], + num_generations: Optional[int] = None, + cache_key: Optional[str] = None, + ) -> RolloutBatch: + generations = self.num_generations if num_generations is None else num_generations + if prompt_samples and "completion" in prompt_samples[0]: + rollout_batch = self._collect_fixed_rollout_batch(prompt_samples, cache_key=cache_key) + else: + rollout_batch = collect_rollouts( + _actual_model(self.model), + self.tokenizer, + prompt_samples=prompt_samples, + sampling_config={ + "num_generations": generations, + "temperature": self.temperature, + "max_completion_length": self.max_completion_length, + "max_seq_length": self.max_seq_length, + "generation_batch_size": self.generation_batch_size, + }, + reward_evaluator=None, + collect_sample_stats=self.entropy_bonus != 0.0, + ) + if self.reward_source == "offline": + raise ValueError("reward_source='offline' requires rollout samples with completion/reward fields.") + reward_batch = evaluate_rewards(rollout_batch, self._resolve_reward_evaluator()) + rollout_batch.rewards = reward_batch.scalar_rewards + if prompt_samples and "completion" in prompt_samples[0]: + if self.reward_source == "online": + reward_batch = evaluate_rewards(rollout_batch, self._resolve_reward_evaluator()) + rollout_batch.rewards = reward_batch.scalar_rewards + elif self.reward_source == "hybrid": + reward_batch = evaluate_rewards(rollout_batch, self._resolve_reward_evaluator()) + rollout_batch.rewards = rollout_batch.rewards + reward_batch.scalar_rewards + if self.entropy_bonus and rollout_batch.token_entropies is not None: + entropy_bonus = mx.mean(rollout_batch.token_entropies, axis=-1) * self.entropy_bonus + rollout_batch.rewards = rollout_batch.rewards + entropy_bonus.astype(mx.float32) + if self.reward_normalization != "none": + rollout_batch.rewards = _normalize_reward_values( + rollout_batch.rewards, + rollout_batch.prompt_group_indices, + self.reward_normalization, + ) + if self.kl_penalty_mode != "none" or self.kl_target is not None: + rollout_batch = score_rollout_references( + _actual_model(self.reference_policy.model), + rollout_batch, + batch_size=_rollout_score_batch_size(self, num_generations=generations), + token_budget=self.score_chunk_size, + ) + if self.mask_truncated_completions: + rollout_batch = _apply_truncation_mask_to_rollout(rollout_batch) + return rollout_batch + + def _evaluate_rollout_metrics(self) -> Dict[str, Any]: + prompt_limit = max(1, self.eval_num_batches or 1) * self.rollout_batch_size + rollout_limit = max(1, self.eval_num_batches or 1) * self.rollout_batch_size * self.num_generations + if self.eval_rollout_samples: + rollout_batch = self._collect_rollout_batch(self.eval_rollout_samples[:rollout_limit]) + elif self.eval_prompt_samples: + rollout_batch = self._collect_rollout_batch( + self.eval_prompt_samples[:prompt_limit], + num_generations=self.eval_num_generations, + ) + else: + return {} + preference_batch = self._build_online_preference_batch(rollout_batch) + metrics = summarize_rollout_metrics(rollout_batch) + if preference_batch is not None: + effective_beta = self._effective_kl_beta(rollout_batch) + loss, _ = compute_dpo_loss( + model=_actual_model(self.model), + chosen_ids=preference_batch.chosen.input_ids, + rejected_ids=preference_batch.rejected.input_ids, + chosen_lengths=preference_batch.chosen.sequence_lengths, + rejected_lengths=preference_batch.rejected.sequence_lengths, + beta=effective_beta, + reference_chosen_logprobs=preference_batch.chosen.reference_logprobs, + reference_rejected_logprobs=preference_batch.rejected.reference_logprobs, + label_smoothing=self.label_smoothing, + ) + metrics["policy_loss"] = float(loss.item()) + return metrics + + def _evaluate_preference_metrics(self) -> Dict[str, Any]: + if not self.eval_preference_samples: + return {} + limit = max(1, self.eval_num_batches or 1) * self.batch_size + samples = self.eval_preference_samples[:limit] + preference_batch = make_preference_batch( + chosen_sequences=[sample["chosen_ids"] for sample in samples], + rejected_sequences=[sample["rejected_ids"] for sample in samples], + pad_id=_pad_token_id(self.tokenizer), + sample_indices=[sample["sample_index"] for sample in samples], + ) + chosen_scores = score_policy_in_chunks( + _actual_model(self.model), + preference_batch.chosen, + batch_size=max(1, self.batch_size), + token_budget=self.score_chunk_size, + mode="sequence", + ).summed_logprobs + rejected_scores = score_policy_in_chunks( + _actual_model(self.model), + preference_batch.rejected, + batch_size=max(1, self.batch_size), + token_budget=self.score_chunk_size, + mode="sequence", + ).summed_logprobs + metrics = { + "preference_win_rate": float(mx.mean((chosen_scores > rejected_scores).astype(mx.float32)).item()) + } + chosen_reference, rejected_reference = _ensure_preference_reference_cache( + self, + f"{self.algorithm}.eval_preference_reference_logprobs", + samples, + ) + if chosen_reference is None or rejected_reference is None: + reference_model = _actual_model(self.reference_policy.model) + chosen_reference = score_policy_in_chunks( + reference_model, + preference_batch.chosen, + batch_size=max(1, self.batch_size), + token_budget=self.score_chunk_size, + mode="sequence", + ).summed_logprobs + rejected_reference = score_policy_in_chunks( + reference_model, + preference_batch.rejected, + batch_size=max(1, self.batch_size), + token_budget=self.score_chunk_size, + mode="sequence", + ).summed_logprobs + loss, _ = compute_dpo_loss( + model=_actual_model(self.model), + chosen_ids=preference_batch.chosen.input_ids, + rejected_ids=preference_batch.rejected.input_ids, + chosen_lengths=preference_batch.chosen.sequence_lengths, + rejected_lengths=preference_batch.rejected.sequence_lengths, + beta=self.beta, + reference_chosen_logprobs=chosen_reference, + reference_rejected_logprobs=rejected_reference, + label_smoothing=self.label_smoothing, + ) + metrics["policy_loss"] = float(loss.item()) + return metrics + + def evaluate(self) -> Dict[str, Any]: + self._prepare_eval_datasets() + if self.reference_policy is None: + self._ensure_reference_policy() + with self._preserve_rng_state(): + mx.random.seed(int(self.seed) + 100000 + int(self.global_step)) + metrics: Dict[str, Any] = {} + metrics.update(self._evaluate_rollout_metrics()) + metrics.update(self._evaluate_preference_metrics()) + if not metrics: + return {} + return self._record_metrics("eval", metrics) + + def train(self, resume_from_checkpoint: Optional[str] = None): + if not self.use_native: + raise ValueError("OnlineDPOTrainer requires native MLX training support.") + return self._train_native(resume_from_checkpoint=resume_from_checkpoint) + + def _train_native(self, resume_from_checkpoint: Optional[str] = None): + self._apply_lora_if_needed() + self._prepare_prompt_samples() + self._prepare_eval_datasets() + if resume_from_checkpoint is None: + self._seed_training_run() + + actual_model = _actual_model(self.model) + optimizer = self._optimizer_for_training() + self.optimizer = optimizer + self.optimizers = {"policy": optimizer} + + if resume_from_checkpoint is not None: + self.load_state(optimizer=optimizer, checkpoint_dir=Path(resume_from_checkpoint)) + else: + self._ensure_reference_policy() + if self.rollout_samples: + _ensure_fixed_rollout_reference_cache( + self, + f"{self.algorithm}.train_rollout_reference_logprobs", + self.rollout_samples, + ) - if (step + 1) % self.logging_steps == 0: - avg_loss = total_loss / self.logging_steps - print(f" Step {step + 1}/{self.iters} | Loss: {avg_loss:.4f}") - total_loss = 0.0 + effective_beta = self.beta - # Save adapters and config - _save_adapters_and_config(self.model, self.adapter_path) + def loss_fn(model, batch): + loss, _ = compute_dpo_loss( + model=model, + chosen_ids=batch.chosen.input_ids, + rejected_ids=batch.rejected.input_ids, + chosen_lengths=batch.chosen.sequence_lengths, + rejected_lengths=batch.rejected.sequence_lengths, + beta=effective_beta, + reference_chosen_logprobs=batch.chosen.reference_logprobs, + reference_rejected_logprobs=batch.rejected.reference_logprobs, + label_smoothing=self.label_smoothing, + ) + return loss - print("\n" + "=" * 70) - print("GRPO Training Complete!") - print("=" * 70) - print(f" Adapters saved to: {self.adapter_path}") - print(f"Note: Full GRPO with gradient flow through generation requires") - print(f" custom implementation. This version uses reward signals.") - return {"status": "success", "adapter_path": str(self.adapter_path)} + value_and_grad = nn.value_and_grad(actual_model, loss_fn) + running_loss = 0.0 + last_loss = None + + while self.global_step < self.iters: + if self.reward_source == "offline" and self.rollout_samples: + prompt_samples = self._next_offline_rollout_batch() + rollout_batch = self._collect_rollout_batch( + prompt_samples, + cache_key=f"{self.algorithm}.train_rollout_reference_logprobs", + ) + else: + prompt_samples = self._next_prompt_batch() + rollout_batch = self._collect_rollout_batch(prompt_samples) + self._last_rollout_batch = rollout_batch + effective_beta = self._effective_kl_beta(rollout_batch) + preference_batch = self._build_online_preference_batch(rollout_batch) + if preference_batch is None: + self.global_step += 1 + train_row = self._record_metrics( + "train", + {**summarize_rollout_metrics(rollout_batch, policy_loss=0.0), "skipped_pairs": True}, + ) + if self.global_step % self.logging_steps == 0: + print(f"Online DPO step {self.global_step}/{self.iters} | {self._format_metric_summary(train_row)}") + continue + + loss, grads = value_and_grad(actual_model, preference_batch) + optimizer.update(actual_model, grads) + mx.eval(actual_model.parameters(), optimizer.state) - def _train_subprocess(self): - """Fallback to SFT approximation.""" - warnings.warn( - "Native GRPO not available. Using SFT on provided responses.", - UserWarning - ) + last_loss = float(loss.item()) + running_loss += last_loss + self.global_step += 1 + train_row = self._record_metrics( + "train", + summarize_rollout_metrics(rollout_batch, policy_loss=last_loss), + ) - train_file = self.output_dir / "train.jsonl" - with open(train_file, 'w') as f: - for sample in self.train_dataset: - if 'prompt' in sample: - messages = [{"role": "user", "content": sample['prompt']}] - if 'response' in sample or 'answer' in sample: - response = sample.get('response', sample.get('answer', '')) - messages.append({"role": "assistant", "content": response}) - f.write(json.dumps({"messages": messages}) + '\n') + if self.global_step % self.logging_steps == 0: + print(f"Online DPO step {self.global_step}/{self.iters} | {self._format_metric_summary(train_row)}") + running_loss = 0.0 - import shutil - shutil.copy(train_file, self.output_dir / "valid.jsonl") + if self.eval_steps and self.global_step % self.eval_steps == 0: + eval_row = self.evaluate() + if eval_row: + print(f"Online DPO eval | {self._format_metric_summary(eval_row, namespace='eval')}") - cmd = [ - "mlx_lm.lora", "--model", getattr(self.model, 'model_name', 'model'), - "--train", "--data", str(self.output_dir), "--iters", str(self.iters), - "--adapter-path", str(self.adapter_path), - ] - subprocess.run(cmd, check=True) - return {"status": "success"} + if self.global_step % self.save_steps == 0: + self.save_state(optimizer=optimizer) + self.save_state(optimizer=optimizer) + return { + "status": "success", + "adapter_path": str(self.adapter_path), + "global_step": self.global_step, + "final_loss": last_loss, + } -class KTOTrainer: - """ - Kahneman-Tversky Optimization Trainer. - KTO uses prospect theory for preference optimization, - treating gains and losses asymmetrically. - Now with proper KTO loss implementation! - """ +class KTOTrainer(_RLTrainerBase): + algorithm = "kto" + requires_reference_policy = True def __init__( self, @@ -959,106 +4135,169 @@ def __init__( train_dataset: Any, tokenizer: Optional[Any] = None, beta: float = 0.1, + args: Optional[KTOConfig] = None, + ref_model: Optional[Any] = None, + reward_model: Optional[Any] = None, + value_model: Optional[Any] = None, use_native: bool = True, - **kwargs + **kwargs, ): self.model = model self.train_dataset = train_dataset - self.tokenizer = tokenizer or getattr(model, 'tokenizer', None) - self.beta = beta + self.tokenizer = tokenizer or getattr(model, "tokenizer", None) + self.config = args or KTOConfig(beta=beta, **kwargs) + self.beta = self.config.beta + self.ref_model = ref_model + self.reward_model = reward_model + self.value_model = value_model self.use_native = use_native and HAS_NATIVE_TRAINING - self.output_dir = Path(kwargs.get('output_dir', './kto_outputs')) - self.learning_rate = kwargs.get('learning_rate', 5e-7) - self.iters = kwargs.get('max_steps', 100) - self.max_seq_length = kwargs.get('max_seq_length', 2048) - self.logging_steps = kwargs.get('logging_steps', 10) - + self.output_dir = Path(self.config.output_dir) + self.learning_rate = self.config.learning_rate + self.iters = self.config.max_steps + self.max_seq_length = self.config.max_seq_length + self.batch_size = self.config.per_device_train_batch_size + self.logging_steps = self.config.logging_steps + self.save_steps = self.config.save_steps self.output_dir.mkdir(parents=True, exist_ok=True) self.adapter_path = self.output_dir / "adapters" self.adapter_path.mkdir(parents=True, exist_ok=True) + self.train_samples: List[Dict[str, Any]] = [] + self._init_native_state() + + def _prepare_training_samples(self) -> None: + self.train_samples = [] + for sample_index, sample in enumerate(self.train_dataset): + if "text" not in sample or "label" not in sample: + continue + ids = self.tokenizer.encode(sample["text"])[: self.max_seq_length] + self.train_samples.append( + { + "sample_index": sample_index, + "ids": ids, + "length": len(ids), + "label": float(sample["label"]), + } + ) + if not self.train_samples: + raise ValueError("KTOTrainer requires text/label samples.") + + def _precompute_reference_cache(self) -> None: + self._ensure_reference_policy() + eval_batch = make_policy_eval_batch( + [sample["ids"] for sample in self.train_samples], + pad_id=_pad_token_id(self.tokenizer), + mode="sequence", + labels=mx.array([sample["label"] for sample in self.train_samples]), + sample_indices=mx.array([sample["sample_index"] for sample in self.train_samples]), + ) + reference_logprobs = score_policy_in_chunks( + _actual_model(self.reference_policy.model), + eval_batch, + batch_size=max(1, self.batch_size), + mode="sequence", + ).summed_logprobs + reference_logprobs = mx.stop_gradient(reference_logprobs) + for idx, sample in enumerate(self.train_samples): + sample["reference_logprobs"] = reference_logprobs[idx] + self.cache_metadata = { + "type": "inline_kto_reference_logprobs", + "num_samples": len(self.train_samples), + } - print(f"KTOTrainer initialized (beta={self.beta}, native={self.use_native})") - - def train(self): - """Train using KTO with proper loss.""" - print("=" * 70) - print("Starting KTO Training") - print("=" * 70) - - if not self.use_native: - warnings.warn("KTO requires native training. Using SFT approximation.", UserWarning) - return {"status": "fallback"} + def _restore_reference_cache(self, flat_state: Dict[str, mx.array]) -> None: + if "kto.reference_logprobs" not in flat_state: + self._precompute_reference_cache() + return + reference_logprobs = flat_state["kto.reference_logprobs"] + if reference_logprobs.shape[0] != len(self.train_samples): + raise ValueError("Saved KTO cache does not match current dataset ordering.") + for idx, sample in enumerate(self.train_samples): + sample["reference_logprobs"] = reference_logprobs[idx] + + def _build_batch(self, samples: List[Dict[str, Any]]) -> PolicyEvalBatch: + return make_policy_eval_batch( + [sample["ids"] for sample in samples], + pad_id=_pad_token_id(self.tokenizer), + mode="sequence", + labels=mx.array([sample["label"] for sample in samples]), + reference_logprobs=mx.array([sample["reference_logprobs"] for sample in samples]), + sample_indices=mx.array([sample["sample_index"] for sample in samples]), + ) - print("\n[Using Native KTO Training with Proper Loss]") + def _extra_state_arrays(self) -> Dict[str, mx.array]: + return { + "kto.reference_logprobs": mx.array( + [sample["reference_logprobs"] for sample in self.train_samples] + ) + } - if hasattr(self.model, '_apply_lora') and not getattr(self.model, '_lora_applied', False): - self.model._apply_lora() + def train(self, resume_from_checkpoint: Optional[str] = None): + if self.use_native: + return self._train_native(resume_from_checkpoint=resume_from_checkpoint) + warnings.warn("KTO requires native training. Using SFT approximation.", UserWarning) + return {"status": "fallback"} - actual_model = self.model.model if hasattr(self.model, 'model') else self.model - lr_schedule = optim.cosine_decay(self.learning_rate, self.iters) - optimizer = optim.AdamW(learning_rate=lr_schedule) - - # Prepare data - KTO expects samples with 'text' and 'label' (1=positive, 0=negative) - tokenized_data = [] - for sample in self.train_dataset: - if 'text' in sample and 'label' in sample: - ids = self.tokenizer.encode(sample['text'])[:self.max_seq_length] - tokenized_data.append({ - 'ids': ids, - 'length': len(ids), - 'label': float(sample['label']), - }) - - print(f"✓ Prepared {len(tokenized_data)} samples") - - def loss_fn(model, batch_data): - input_ids, lengths, labels = batch_data - loss, _ = compute_kto_loss(model, input_ids, lengths, labels, self.beta) - return loss + def _train_native(self, resume_from_checkpoint: Optional[str] = None): + self._apply_lora_if_needed() + self._prepare_training_samples() - loss_and_grad = nn.value_and_grad(actual_model, loss_fn) + actual_model = _actual_model(self.model) + optimizer = self._optimizer_for_training() + self.optimizer = optimizer - total_loss = 0.0 - for step in range(self.iters): - sample = tokenized_data[step % len(tokenized_data)] - pad_id = self.tokenizer.pad_token_id or 0 + if resume_from_checkpoint is not None: + flat_state = self.load_state(optimizer, Path(resume_from_checkpoint)) + self._restore_reference_cache(flat_state) + else: + self._precompute_reference_cache() - max_len = sample['length'] - ids_padded = sample['ids'] + [pad_id] * (max_len - len(sample['ids'])) + def loss_fn(model, batch): + loss, _ = compute_kto_loss( + model=model, + input_ids=batch.input_ids, + lengths=batch.sequence_lengths, + labels=batch.labels, + beta=self.beta, + reference_logprobs=batch.reference_logprobs, + ) + return loss - input_ids = mx.array([ids_padded]) - lengths = mx.array([sample['length']]) - labels = mx.array([sample['label']]) + value_and_grad = nn.value_and_grad(actual_model, loss_fn) + running_loss = 0.0 + last_loss = None - loss, grads = loss_and_grad(actual_model, (input_ids, lengths, labels)) + while self.global_step < self.iters: + batch_samples = self._next_samples(self.train_samples) + batch = self._build_batch(batch_samples) + loss, grads = value_and_grad(actual_model, batch) optimizer.update(actual_model, grads) mx.eval(actual_model.parameters(), optimizer.state) - total_loss += loss.item() + last_loss = loss.item() + running_loss += last_loss + self.global_step += 1 + self._record_metric(loss=last_loss) - if (step + 1) % self.logging_steps == 0: - print(f" Step {step + 1}/{self.iters} | Loss: {total_loss / self.logging_steps:.4f}") - total_loss = 0.0 + if self.global_step % self.logging_steps == 0: + print( + f"KTO step {self.global_step}/{self.iters} | " + f"loss={running_loss / self.logging_steps:.4f}" + ) + running_loss = 0.0 - # Save adapters and config - _save_adapters_and_config(self.model, self.adapter_path) + if self.global_step % self.save_steps == 0: + self.save_state(optimizer, self._extra_state_arrays()) - print("\n" + "=" * 70) - print("KTO Training Complete!") - print("=" * 70) - print(f" Adapters saved to: {self.adapter_path}") - return {"status": "success", "adapter_path": str(self.adapter_path)} + self.save_state(optimizer, self._extra_state_arrays()) + return { + "status": "success", + "adapter_path": str(self.adapter_path), + "global_step": self.global_step, + "final_loss": last_loss, + } class SimPOTrainer: - """ - Simple Preference Optimization Trainer. - - SimPO simplifies DPO by removing the reference model requirement. - Uses length-normalized log probabilities as implicit rewards. - Now with proper SimPO loss implementation! - """ - def __init__( self, model: Any, @@ -1066,207 +4305,98 @@ def __init__( tokenizer: Optional[Any] = None, gamma: float = 0.5, beta: float = 2.0, + args: Optional[SimPOConfig] = None, use_native: bool = True, - **kwargs + **kwargs, ): self.model = model self.train_dataset = train_dataset - self.tokenizer = tokenizer or getattr(model, 'tokenizer', None) - self.gamma = gamma - self.beta = beta + self.tokenizer = tokenizer or getattr(model, "tokenizer", None) + self.config = args or SimPOConfig(gamma=gamma, beta=beta, **kwargs) + self.gamma = self.config.gamma + self.beta = self.config.beta self.use_native = use_native and HAS_NATIVE_TRAINING - self.output_dir = Path(kwargs.get('output_dir', './simpo_outputs')) - self.learning_rate = kwargs.get('learning_rate', 5e-7) - self.iters = kwargs.get('max_steps', 100) - self.max_seq_length = kwargs.get('max_seq_length', 2048) - self.logging_steps = kwargs.get('logging_steps', 10) - + self.output_dir = Path(self.config.output_dir) + self.learning_rate = self.config.learning_rate + self.batch_size = self.config.per_device_train_batch_size + self.iters = self.config.max_steps + self.max_seq_length = self.config.max_seq_length + self.logging_steps = self.config.logging_steps self.output_dir.mkdir(parents=True, exist_ok=True) self.adapter_path = self.output_dir / "adapters" self.adapter_path.mkdir(parents=True, exist_ok=True) - print(f"SimPOTrainer initialized (gamma={gamma}, beta={beta}, native={self.use_native})") - - def _tokenize_pair(self, sample): - prompt = sample.get('prompt', '') - chosen = sample.get('chosen', '') - rejected = sample.get('rejected', '') - - chosen_ids = self.tokenizer.encode(prompt + chosen)[:self.max_seq_length] - rejected_ids = self.tokenizer.encode(prompt + rejected)[:self.max_seq_length] - + def _tokenize_pair(self, sample: Dict[str, Any]) -> Dict[str, Any]: + prompt = sample.get("prompt", "") + chosen = sample.get("chosen", "") + rejected = sample.get("rejected", "") + chosen_ids = self.tokenizer.encode(prompt + chosen)[: self.max_seq_length] + rejected_ids = self.tokenizer.encode(prompt + rejected)[: self.max_seq_length] return { - 'chosen_ids': chosen_ids, - 'rejected_ids': rejected_ids, - 'chosen_length': len(chosen_ids), - 'rejected_length': len(rejected_ids), + "chosen_ids": chosen_ids, + "rejected_ids": rejected_ids, + "chosen_length": len(chosen_ids), + "rejected_length": len(rejected_ids), } - def _pad(self, ids, length, pad_id=0): - return ids + [pad_id] * (length - len(ids)) if len(ids) < length else ids[:length] - def train(self): - """Train using SimPO with proper loss.""" - print("=" * 70) - print("Starting SimPO Training") - print("=" * 70) - if not self.use_native: warnings.warn("SimPO requires native training. Using SFT approximation.", UserWarning) return {"status": "fallback"} - print("\n[Using Native SimPO Training with Proper Loss]") - - if hasattr(self.model, '_apply_lora') and not getattr(self.model, '_lora_applied', False): + if hasattr(self.model, "_apply_lora") and not getattr(self.model, "_lora_applied", False): self.model._apply_lora() - tokenized_data = [] - for sample in self.train_dataset: - if 'prompt' in sample and 'chosen' in sample and 'rejected' in sample: - tokenized_data.append(self._tokenize_pair(sample)) - print(f"✓ Prepared {len(tokenized_data)} preference pairs") - - actual_model = self.model.model if hasattr(self.model, 'model') else self.model - lr_schedule = optim.cosine_decay(self.learning_rate, self.iters) - optimizer = optim.AdamW(learning_rate=lr_schedule) + prepared = prepare_rl_dataset( + self.train_dataset, + mode=self.config.dataset_mode or "preference", + tokenizer=self.tokenizer, + chat_template=getattr(self.config, "chat_template", None), + ) + tokenized_data = [self._tokenize_pair(sample) for sample in prepared] + actual_model = _actual_model(self.model) + optimizer = optim.AdamW(learning_rate=optim.cosine_decay(self.learning_rate, self.iters)) + pad_id = _pad_token_id(self.tokenizer) - def loss_fn(model, batch_data): - chosen_ids, rejected_ids, chosen_lengths, rejected_lengths = batch_data + def loss_fn(model, batch): + chosen_ids, rejected_ids, chosen_lengths, rejected_lengths = batch loss, _ = compute_simpo_loss( - model, chosen_ids, rejected_ids, chosen_lengths, rejected_lengths, - self.beta, self.gamma + model, + chosen_ids, + rejected_ids, + chosen_lengths, + rejected_lengths, + self.beta, + self.gamma, ) return loss - loss_and_grad = nn.value_and_grad(actual_model, loss_fn) + value_and_grad = nn.value_and_grad(actual_model, loss_fn) + last_loss = None - total_loss = 0.0 for step in range(self.iters): - sample = tokenized_data[step % len(tokenized_data)] - max_len = max(sample['chosen_length'], sample['rejected_length']) - pad_id = self.tokenizer.pad_token_id or 0 - - chosen_ids = mx.array([self._pad(sample['chosen_ids'], max_len, pad_id)]) - rejected_ids = mx.array([self._pad(sample['rejected_ids'], max_len, pad_id)]) - chosen_lengths = mx.array([sample['chosen_length']]) - rejected_lengths = mx.array([sample['rejected_length']]) - - loss, grads = loss_and_grad(actual_model, (chosen_ids, rejected_ids, chosen_lengths, rejected_lengths)) + start = (step * max(1, self.batch_size)) % len(tokenized_data) + samples = tokenized_data[start:start + max(1, self.batch_size)] + if len(samples) < max(1, self.batch_size): + samples += tokenized_data[: max(1, self.batch_size) - len(samples)] + chosen_ids, chosen_lengths = _pad_sequences([sample["chosen_ids"] for sample in samples], pad_id) + rejected_ids, rejected_lengths = _pad_sequences([sample["rejected_ids"] for sample in samples], pad_id) + loss, grads = value_and_grad(actual_model, (chosen_ids, rejected_ids, chosen_lengths, rejected_lengths)) optimizer.update(actual_model, grads) mx.eval(actual_model.parameters(), optimizer.state) + last_loss = loss.item() - total_loss += loss.item() - - if (step + 1) % self.logging_steps == 0: - print(f" Step {step + 1}/{self.iters} | Loss: {total_loss / self.logging_steps:.4f}") - total_loss = 0.0 - - # Save adapters and config _save_adapters_and_config(self.model, self.adapter_path) - - print("\n" + "=" * 70) - print("SimPO Training Complete!") - print("=" * 70) - print(f" Adapters saved to: {self.adapter_path}") - return {"status": "success", "adapter_path": str(self.adapter_path)} + return {"status": "success", "adapter_path": str(self.adapter_path), "final_loss": last_loss} -# Utility functions for preference data - def prepare_preference_dataset( dataset: Any, tokenizer: Any, format_type: str = "dpo", -) -> List[Dict]: - """ - Prepare dataset for preference-based training (DPO, ORPO, etc.). +) -> List[Dict[str, Any]]: + return public_prepare_preference_dataset(dataset, tokenizer=tokenizer, format_type=format_type) - Args: - dataset: HuggingFace dataset with preference pairs - tokenizer: Tokenizer for formatting - format_type: 'dpo', 'orpo', or 'grpo' - - Returns: - Formatted dataset ready for training - - Example: - >>> from datasets import load_dataset - >>> dataset = load_dataset("Anthropic/hh-rlhf") - >>> formatted = prepare_preference_dataset(dataset, tokenizer, "dpo") - """ - - formatted_data = [] - - for sample in dataset: - if format_type in ["dpo", "orpo"]: - # Expect chosen/rejected format - if 'chosen' in sample and 'rejected' in sample: - formatted_data.append({ - "prompt": sample.get('prompt', ''), - "chosen": sample['chosen'], - "rejected": sample['rejected'], - }) - elif format_type == "grpo": - # Expect prompt + optional ground truth - formatted_data.append({ - "prompt": sample.get('prompt', sample.get('question', '')), - "answer": sample.get('answer', sample.get('response', '')), - }) - - return formatted_data - - -def create_reward_function(reward_type: str = "simple") -> Callable: - """ - Create a reward function for GRPO training. - - Args: - reward_type: Type of reward function - - 'simple': Binary correct/incorrect - - 'math': Extract and compare numerical answers - - 'code': Execute and verify code output - - 'length': Reward based on response length - - Returns: - Reward function callable - - Example: - >>> reward_fn = create_reward_function('math') - >>> trainer = GRPOTrainer(..., reward_fn=reward_fn) - """ - - if reward_type == "simple": - def simple_reward(response: str, ground_truth: str) -> float: - return 1.0 if ground_truth.lower() in response.lower() else 0.0 - return simple_reward - - elif reward_type == "math": - def math_reward(response: str, ground_truth: str) -> float: - import re - # Extract numbers from response - numbers = re.findall(r'-?\d+\.?\d*', response) - target = re.findall(r'-?\d+\.?\d*', ground_truth) - if numbers and target: - try: - return 1.0 if float(numbers[-1]) == float(target[-1]) else 0.0 - except: - return 0.0 - return 0.0 - return math_reward - - elif reward_type == "length": - def length_reward(response: str, _: str) -> float: - # Reward longer, more detailed responses (up to a point) - length = len(response.split()) - if length < 10: - return 0.2 - elif length < 50: - return 0.5 - elif length < 200: - return 1.0 - else: - return 0.8 # Penalize very long responses - return length_reward - else: - raise ValueError(f"Unknown reward type: {reward_type}") +def create_reward_function(reward_type: Any = "simple", *, rewards: Optional[List[Any]] = None) -> Callable: + return public_create_reward_function(reward_type, rewards=rewards) diff --git a/mlx_tune/trainer.py b/mlx_tune/trainer.py index ee346f3..7c5be87 100644 --- a/mlx_tune/trainer.py +++ b/mlx_tune/trainer.py @@ -261,11 +261,17 @@ def save_model_hf_format( # CRITICAL: Fuse LoRA adapters into base weights before saving # Without this, LoRA layers are saved as-is and won't load properly - fused_linears = [ - (n, m.fuse(dequantize=kwargs.get('dequantize', False))) - for n, m in actual_model.named_modules() - if hasattr(m, "fuse") - ] + fused_linears = [] + for n, m in actual_model.named_modules(): + if not hasattr(m, "fuse"): + continue + try: + fused = m.fuse(dequantize=kwargs.get('dequantize', False)) + except TypeError as exc: + if "dequantize" not in str(exc): + raise + fused = m.fuse() + fused_linears.append((n, fused)) if fused_linears: print(f" Fusing {len(fused_linears)} LoRA layers into base model...") diff --git a/mlx_tune/trl_compat.py b/mlx_tune/trl_compat.py new file mode 100644 index 0000000..dd20de8 --- /dev/null +++ b/mlx_tune/trl_compat.py @@ -0,0 +1,251 @@ +""" +Compatibility patching for TRL/Unsloth-style imports. + +This module exposes ``PatchFastRL`` so MLX-Tune can populate or mutate a +``trl`` module in-place with trainer/config shims backed by MLX-Tune's native +implementations. +""" + +from __future__ import annotations + +from dataclasses import asdict, is_dataclass +import inspect +import sys +from types import ModuleType +from typing import Any, Dict, Mapping, Tuple, Type + +from mlx_tune.rl_trainers import ( + DPOConfig as _DPOConfig, + DPOTrainer as _DPOTrainer, + GRPOConfig as _GRPOConfig, + GRPOTrainer as _GRPOTrainer, + KTOConfig as _KTOConfig, + KTOTrainer as _KTOTrainer, + OnlineDPOConfig as _OnlineDPOConfig, + OnlineDPOTrainer as _OnlineDPOTrainer, + ORPOConfig as _ORPOConfig, + ORPOTrainer as _ORPOTrainer, + PPOConfig as _PPOConfig, + PPOTrainer as _PPOTrainer, + RewardConfig as _RewardConfig, + RewardTrainer as _RewardTrainer, + SimPOConfig as _SimPOConfig, + SimPOTrainer as _SimPOTrainer, +) +from mlx_tune.sft_trainer import SFTConfig as _SFTConfig +from mlx_tune.sft_trainer import SFTTrainer as _SFTTrainer + + +def _normalize_alias_kwargs(kwargs: Mapping[str, Any]) -> Dict[str, Any]: + normalized = dict(kwargs) + + if "processing_class" in normalized: + normalized.setdefault("tokenizer", normalized["processing_class"]) + normalized.pop("processing_class", None) + + if "reward_funcs" in normalized: + normalized.setdefault("reward_sources", normalized["reward_funcs"]) + normalized.pop("reward_funcs", None) + + if "reward_func" in normalized: + normalized.setdefault("reward_fn", normalized["reward_func"]) + normalized.pop("reward_func", None) + + return normalized + + +def _extract_public_attrs(value: Any) -> Dict[str, Any]: + if value is None: + return {} + if isinstance(value, Mapping): + return dict(value) + if hasattr(value, "to_dict") and callable(value.to_dict): + return dict(value.to_dict()) + if is_dataclass(value) and not isinstance(value, type): + return dict(asdict(value)) + if hasattr(value, "__dict__"): + return { + key: item + for key, item in vars(value).items() + if not key.startswith("_") and not callable(item) + } + + public: Dict[str, Any] = {} + for key in dir(value): + if key.startswith("_"): + continue + try: + item = getattr(value, key) + except Exception: + continue + if callable(item): + continue + public[key] = item + return public + + +def _config_field_names(config_class: Type[Any]) -> set[str]: + parameters = inspect.signature(config_class.__init__).parameters + return { + name + for name, parameter in parameters.items() + if name != "self" and parameter.kind not in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD} + } + + +def _trainer_param_names(trainer_class: Type[Any]) -> set[str]: + parameters = inspect.signature(trainer_class.__init__).parameters + return { + name + for name, parameter in parameters.items() + if name != "self" and parameter.kind not in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD} + } + + +def _coerce_config( + config_value: Any, + config_class: Type[Any], + config_overrides: Mapping[str, Any] | None = None, +) -> Any: + if isinstance(config_value, config_class): + if not config_overrides: + return config_value + payload = _extract_public_attrs(config_value) + else: + payload = _extract_public_attrs(config_value) + + if config_overrides: + payload.update(config_overrides) + return config_class(**_normalize_alias_kwargs(payload)) + + +def _prepare_trainer_kwargs( + kwargs: Mapping[str, Any], + trainer_class: Type[Any], + config_class: Type[Any], +) -> Dict[str, Any]: + normalized = _normalize_alias_kwargs(kwargs) + trainer_params = _trainer_param_names(trainer_class) + config_fields = _config_field_names(config_class) + + config_overrides: Dict[str, Any] = {} + for key in list(normalized): + if key in {"args"} or key in trainer_params: + continue + if key in config_fields: + config_overrides[key] = normalized.pop(key) + + if "args" not in normalized: + if config_overrides: + normalized["args"] = config_class(**config_overrides) + return normalized + + args_value = normalized.get("args") + if args_value is None: + if config_overrides: + normalized["args"] = config_class(**config_overrides) + return normalized + + normalized["args"] = _coerce_config(args_value, config_class, config_overrides=config_overrides) + return normalized + + +def _build_compat_config_class(export_name: str, config_class: Type[Any]) -> Type[Any]: + class CompatConfig(config_class): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **_normalize_alias_kwargs(kwargs)) + + CompatConfig.__name__ = export_name + CompatConfig.__qualname__ = export_name + CompatConfig.__module__ = "trl" + CompatConfig.__doc__ = config_class.__doc__ + return CompatConfig + + +def _build_compat_trainer_class( + export_name: str, + trainer_class: Type[Any], + config_class: Type[Any], +) -> Type[Any]: + class CompatTrainer(trainer_class): + def __init__(self, *args: Any, **kwargs: Any) -> None: + prepared_kwargs = _prepare_trainer_kwargs(kwargs, trainer_class, config_class) + super().__init__(*args, **prepared_kwargs) + + CompatTrainer.__name__ = export_name + CompatTrainer.__qualname__ = export_name + CompatTrainer.__module__ = "trl" + CompatTrainer.__doc__ = trainer_class.__doc__ + return CompatTrainer + + +_COMPAT_PAIRS: Tuple[Tuple[str, Type[Any], Type[Any]], ...] = ( + ("SFT", _SFTTrainer, _SFTConfig), + ("Reward", _RewardTrainer, _RewardConfig), + ("DPO", _DPOTrainer, _DPOConfig), + ("ORPO", _ORPOTrainer, _ORPOConfig), + ("GRPO", _GRPOTrainer, _GRPOConfig), + ("PPO", _PPOTrainer, _PPOConfig), + ("OnlineDPO", _OnlineDPOTrainer, _OnlineDPOConfig), + ("KTO", _KTOTrainer, _KTOConfig), + ("SimPO", _SimPOTrainer, _SimPOConfig), +) + +_COMPAT_CLASSES: Dict[str, Type[Any]] = {} +for stem, trainer_class, config_class in _COMPAT_PAIRS: + compat_config = _build_compat_config_class(f"{stem}Config", config_class) + compat_trainer = _build_compat_trainer_class(f"{stem}Trainer", trainer_class, config_class) + _COMPAT_CLASSES[compat_config.__name__] = compat_config + _COMPAT_CLASSES[compat_trainer.__name__] = compat_trainer + +_COMPAT_EXPORT_NAMES = tuple(_COMPAT_CLASSES) + + +def _ensure_trl_modules() -> tuple[ModuleType, ModuleType]: + trl_module = sys.modules.get("trl") + if trl_module is None: + trl_module = ModuleType("trl") + trl_module.__package__ = "trl" + trl_module.__path__ = [] + trl_module.__version__ = "0.0.0-mlx-tune" + sys.modules["trl"] = trl_module + else: + if not hasattr(trl_module, "__package__"): + trl_module.__package__ = "trl" + if not hasattr(trl_module, "__path__"): + trl_module.__path__ = [] + if not hasattr(trl_module, "__version__"): + trl_module.__version__ = "0.0.0-mlx-tune" + + trainer_module = sys.modules.get("trl.trainer") + if trainer_module is None: + trainer_module = ModuleType("trl.trainer") + trainer_module.__package__ = "trl" + sys.modules["trl.trainer"] = trainer_module + else: + if not hasattr(trainer_module, "__package__"): + trainer_module.__package__ = "trl" + + trl_module.trainer = trainer_module + return trl_module, trainer_module + + +def PatchFastRL(algorithm: Any = None, FastLanguageModel: Any = None) -> None: + """ + Patch ``trl`` so top-level trainer/config imports resolve to MLX-Tune. + + ``algorithm`` and ``FastLanguageModel`` are accepted for source compatibility + with Unsloth's ``PatchFastRL`` signature. + """ + del algorithm, FastLanguageModel + + trl_module, trainer_module = _ensure_trl_modules() + for export_name, export_value in _COMPAT_CLASSES.items(): + setattr(trl_module, export_name, export_value) + setattr(trainer_module, export_name, export_value) + + trl_module.__all__ = list(_COMPAT_EXPORT_NAMES) + trainer_module.__all__ = list(_COMPAT_EXPORT_NAMES) + trl_module.__MLX_TUNE_PATCHED__ = True + trainer_module.__MLX_TUNE_PATCHED__ = True + diff --git a/tests/test_arithmetic_grpo_validation.py b/tests/test_arithmetic_grpo_validation.py new file mode 100644 index 0000000..70ebeb1 --- /dev/null +++ b/tests/test_arithmetic_grpo_validation.py @@ -0,0 +1,160 @@ +from pathlib import Path + +from mlx_tune.arithmetic_grpo_validation import ( + ArithmeticSolutionReward, + generate_benchmark_splits, + parse_solution_response, + run_baseline, + run_compare, + run_training, + score_solution_output, +) + + +class FakeTokenizer: + def apply_chat_template(self, messages, add_generation_prompt=True, tokenize=False): + prompt = "\n".join(message["content"] for message in messages) + return prompt + ("\nassistant:" if add_generation_prompt else "") + + def encode(self, text, add_special_tokens=False): + del add_special_tokens + return [ord(char) % 256 for char in text] + + +class FakeModel: + def __init__(self): + self.trained = False + self.loaded_adapter = None + + def load_adapter(self, path): + self.loaded_adapter = path + self.trained = True + + def generate(self, prompt, max_tokens, sampler=None, verbose=False): + del max_tokens, sampler, verbose + expression = prompt.split("Compute exactly:\n", 1)[1].split("\nassistant:", 1)[0].strip() + answer = str(int(eval(expression, {"__builtins__": {}}, {}))) + if self.trained: + return f"\nchecking\n\n{answer}" + return answer + + +class FakeTrainer: + def __init__(self, model, train_dataset, eval_dataset, tokenizer, reward_fn, args): + self.model = model + self.train_dataset = train_dataset + self.eval_dataset = eval_dataset + self.tokenizer = tokenizer + self.reward_fn = reward_fn + self.args = args + + def train(self): + self.model.trained = True + policy_dir = Path(self.args.output_dir) / "policy" + policy_dir.mkdir(parents=True, exist_ok=True) + (policy_dir / "adapters.safetensors").write_text("fake") + (policy_dir / "adapter_config.json").write_text("{}") + return {"status": "success", "global_step": self.args.max_steps} + + +def _fake_load_model_bundle(*args, **kwargs): + del args, kwargs + return FakeModel(), FakeTokenizer() + + +def test_generate_benchmark_splits_is_deterministic_and_disjoint(tmp_path): + first_dir = tmp_path / "first" + second_dir = tmp_path / "second" + generate_benchmark_splits(first_dir, train_size=12, val_size=4, test_size=4, seed=7) + generate_benchmark_splits(second_dir, train_size=12, val_size=4, test_size=4, seed=7) + + first_paths = {split: (first_dir / "datasets" / f"{split}.jsonl").read_text() for split in ("train", "val", "test")} + second_paths = {split: (second_dir / "datasets" / f"{split}.jsonl").read_text() for split in ("train", "val", "test")} + + assert first_paths == second_paths + + expressions = {} + for split in ("train", "val", "test"): + rows = [line for line in first_paths[split].strip().splitlines() if line] + expressions[split] = {__import__("json").loads(line)["expression"] for line in rows} + + assert expressions["train"].isdisjoint(expressions["val"]) + assert expressions["train"].isdisjoint(expressions["test"]) + assert expressions["val"].isdisjoint(expressions["test"]) + + +def test_parse_and_score_solution_output(): + parsed = parse_solution_response("x\n42\n") + assert parsed["single_solution_tag"] is True + assert parsed["parseable_solution"] is True + assert parsed["parsed_answer"] == 42 + + exact = score_solution_output("42", "42") + assert exact["exact_match"] is True + assert exact["reward"] == 1.1 + + wrong = score_solution_output("41", "42") + assert wrong["exact_match"] is False + assert wrong["reward"] == 0.1 + + missing = score_solution_output("42", "42") + assert missing["single_solution_tag"] is False + assert missing["reward"] == 0.0 + + multiple = score_solution_output("4142", "42") + assert multiple["multiple_solution_tags"] is True + assert multiple["reward"] == 0.0 + + +def test_reward_evaluator_returns_components(): + evaluator = ArithmeticSolutionReward() + result = evaluator.evaluate({"completion_text": "7", "reward_context": "7"}) + assert result["reward"] == 1.1 + assert result["components"]["solution_tag"] == 0.1 + assert result["components"]["correctness"] == 1.0 + + +def test_baseline_train_and_compare_smoke(tmp_path, monkeypatch): + monkeypatch.setattr("mlx_tune.arithmetic_grpo_validation.load_model_bundle", _fake_load_model_bundle) + monkeypatch.setattr("mlx_tune.arithmetic_grpo_validation.GRPOTrainer", FakeTrainer) + monkeypatch.setattr("mlx_tune.arithmetic_grpo_validation.FastLanguageModel.for_inference", lambda model: model) + + generate_benchmark_splits(tmp_path, train_size=8, val_size=3, test_size=3, seed=0) + + baseline = run_baseline( + tmp_path, + model_name="fake-qwen", + seed=0, + max_seq_length=64, + max_completion_length=32, + ) + assert baseline["aggregate"]["exact_match"] == 0.0 + assert (tmp_path / "baseline_outputs.jsonl").exists() + assert (tmp_path / "baseline_metrics.json").exists() + + trained = run_training( + tmp_path, + model_name="fake-qwen", + seed=0, + max_seq_length=64, + max_completion_length=32, + max_steps=3, + learning_rate=1e-6, + per_device_train_batch_size=2, + rollout_batch_size=2, + num_generations=2, + rl_temperature=0.9, + lora_rank=4, + logging_steps=1, + eval_steps=1, + save_steps=1, + ) + assert trained["training"]["status"] == "success" + assert trained["post_eval"]["aggregate"]["exact_match"] == 1.0 + assert (tmp_path / "post_rl_outputs.jsonl").exists() + assert (tmp_path / "post_rl_metrics.json").exists() + + comparison = run_compare(tmp_path) + assert comparison["aggregate_delta"]["exact_match"] > 0.0 + assert (tmp_path / "comparison.json").exists() + assert (tmp_path / "comparison.md").exists() diff --git a/tests/test_losses.py b/tests/test_losses.py index 10bb304..db18a20 100644 --- a/tests/test_losses.py +++ b/tests/test_losses.py @@ -1,220 +1,377 @@ """ -Unit tests for loss functions in mlx_tune.losses +Unit tests for mlx_tune.losses. """ -import pytest import mlx.core as mx import mlx.nn as nn -class TestComputeLogProbs: - """Test log probability computation.""" +class TinyModel(nn.Module): + def __init__(self, vocab_size: int = 32, hidden_size: int = 16): + super().__init__() + self.embedding = nn.Embedding(vocab_size, hidden_size) + self.output = nn.Linear(hidden_size, vocab_size) - def test_compute_log_probs_shape(self): - """Test output shape of compute_log_probs.""" - from mlx_tune.losses import compute_log_probs_with_lengths + def __call__(self, x): + return self.output(self.embedding(x)) - # Create a simple mock model - class MockModel(nn.Module): - def __init__(self): - super().__init__() - self.embedding = nn.Embedding(100, 64) - self.linear = nn.Linear(64, 100) - def __call__(self, x): - return self.linear(self.embedding(x)) +class TinyTokenizer: + eos_token_id = 1 - model = MockModel() - mx.eval(model.parameters()) - # Test input - batch_size = 2 - seq_len = 10 - input_ids = mx.random.randint(0, 100, (batch_size, seq_len)) - lengths = mx.array([8, 6]) +class ConstantRewardModel: + def score(self, input_ids, **kwargs): + del input_ids, kwargs + return mx.array([1.0, 2.0], dtype=mx.float32) - log_probs = compute_log_probs_with_lengths(model, input_ids, lengths) - assert log_probs.shape == (batch_size,), f"Expected shape {(batch_size,)}, got {log_probs.shape}" +def test_compute_log_probs_with_lengths_shape(): + from mlx_tune.losses import compute_log_probs_with_lengths - def test_compute_log_probs_values(self): - """Test that log probs are negative (as expected for probabilities).""" - from mlx_tune.losses import compute_log_probs_with_lengths + model = TinyModel() + mx.eval(model.parameters()) - class MockModel(nn.Module): - def __init__(self): - super().__init__() - self.embedding = nn.Embedding(100, 64) - self.linear = nn.Linear(64, 100) + input_ids = mx.array([[1, 2, 3, 4], [1, 5, 6, 0]]) + lengths = mx.array([3, 2]) - def __call__(self, x): - return self.linear(self.embedding(x)) + log_probs = compute_log_probs_with_lengths(model, input_ids, lengths) + assert log_probs.shape == (2,) - model = MockModel() - mx.eval(model.parameters()) - input_ids = mx.random.randint(0, 100, (2, 10)) - lengths = mx.array([8, 6]) +def test_precompute_preference_reference_logprobs_matches_direct(): + from mlx_tune.losses import ( + compute_reference_logprobs, + precompute_preference_reference_logprobs, + ) - log_probs = compute_log_probs_with_lengths(model, input_ids, lengths) - mx.eval(log_probs) + model = TinyModel() + mx.eval(model.parameters()) - # Log probabilities should be negative (or zero at maximum) - assert mx.all(log_probs <= 0), "Log probabilities should be non-positive" + chosen_ids = mx.array([[1, 2, 3, 4], [1, 4, 5, 6]]) + rejected_ids = mx.array([[1, 3, 2, 4], [1, 6, 5, 4]]) + chosen_lengths = mx.array([3, 3]) + rejected_lengths = mx.array([3, 3]) + direct = compute_reference_logprobs( + model, + chosen_ids, + rejected_ids, + chosen_lengths, + rejected_lengths, + ) + batched = precompute_preference_reference_logprobs( + model, + chosen_ids, + rejected_ids, + chosen_lengths, + rejected_lengths, + batch_size=1, + ) -class TestDPOLoss: - """Test DPO loss computation.""" + assert mx.allclose(direct[0], batched[0]) + assert mx.allclose(direct[1], batched[1]) - def test_dpo_loss_shape(self): - """Test DPO loss returns scalar.""" - from mlx_tune.losses import dpo_loss - class MockModel(nn.Module): - def __init__(self): - super().__init__() - self.embedding = nn.Embedding(100, 64) - self.linear = nn.Linear(64, 100) - - def __call__(self, x): - return self.linear(self.embedding(x)) - - model = MockModel() - mx.eval(model.parameters()) - - batch_size = 2 - seq_len = 10 - chosen_ids = mx.random.randint(0, 100, (batch_size, seq_len)) - rejected_ids = mx.random.randint(0, 100, (batch_size, seq_len)) - chosen_lengths = mx.array([8, 7]) - rejected_lengths = mx.array([9, 6]) - - loss, ntoks = dpo_loss( - model, chosen_ids, rejected_ids, - chosen_lengths, rejected_lengths, - beta=0.1 - ) - - assert loss.shape == (), f"Loss should be scalar, got shape {loss.shape}" - assert ntoks.shape == (), f"ntoks should be scalar, got shape {ntoks.shape}" - - def test_dpo_loss_beta_effect(self): - """Test that higher beta increases loss magnitude.""" - from mlx_tune.losses import dpo_loss - - class MockModel(nn.Module): - def __init__(self): - super().__init__() - self.embedding = nn.Embedding(100, 64) - self.linear = nn.Linear(64, 100) - - def __call__(self, x): - return self.linear(self.embedding(x)) - - model = MockModel() - mx.eval(model.parameters()) - - chosen_ids = mx.random.randint(0, 100, (2, 10)) - rejected_ids = mx.random.randint(0, 100, (2, 10)) - chosen_lengths = mx.array([8, 7]) - rejected_lengths = mx.array([9, 6]) - - loss_low_beta, _ = dpo_loss(model, chosen_ids, rejected_ids, - chosen_lengths, rejected_lengths, beta=0.01) - loss_high_beta, _ = dpo_loss(model, chosen_ids, rejected_ids, - chosen_lengths, rejected_lengths, beta=1.0) - - mx.eval(loss_low_beta, loss_high_beta) - - # Both losses should be finite - assert not mx.isnan(loss_low_beta), "Low beta loss should not be NaN" - assert not mx.isnan(loss_high_beta), "High beta loss should not be NaN" - - -class TestORPOLoss: - """Test ORPO loss computation.""" - - def test_orpo_loss_shape(self): - """Test ORPO loss returns scalar.""" - from mlx_tune.losses import orpo_loss - - class MockModel(nn.Module): - def __init__(self): - super().__init__() - self.embedding = nn.Embedding(100, 64) - self.linear = nn.Linear(64, 100) - - def __call__(self, x): - return self.linear(self.embedding(x)) - - model = MockModel() - mx.eval(model.parameters()) - - chosen_ids = mx.random.randint(0, 100, (2, 10)) - rejected_ids = mx.random.randint(0, 100, (2, 10)) - chosen_lengths = mx.array([8, 7]) - rejected_lengths = mx.array([9, 6]) - - loss, ntoks = orpo_loss(model, chosen_ids, rejected_ids, - chosen_lengths, rejected_lengths, beta=0.1) - - assert loss.shape == (), f"Loss should be scalar, got shape {loss.shape}" - - -class TestSimPOLoss: - """Test SimPO loss computation.""" - - def test_simpo_loss_shape(self): - """Test SimPO loss returns scalar.""" - from mlx_tune.losses import simpo_loss - - class MockModel(nn.Module): - def __init__(self): - super().__init__() - self.embedding = nn.Embedding(100, 64) - self.linear = nn.Linear(64, 100) - - def __call__(self, x): - return self.linear(self.embedding(x)) - - model = MockModel() - mx.eval(model.parameters()) - - chosen_ids = mx.random.randint(0, 100, (2, 10)) - rejected_ids = mx.random.randint(0, 100, (2, 10)) - chosen_lengths = mx.array([8, 7]) - rejected_lengths = mx.array([9, 6]) - - loss, ntoks = simpo_loss(model, chosen_ids, rejected_ids, - chosen_lengths, rejected_lengths, - beta=2.0, gamma=0.5) - - assert loss.shape == (), f"Loss should be scalar, got shape {loss.shape}" - - -class TestSFTLoss: - """Test SFT loss computation.""" - - def test_sft_loss_shape(self): - """Test SFT loss returns scalar.""" - from mlx_tune.losses import sft_loss - - class MockModel(nn.Module): - def __init__(self): - super().__init__() - self.embedding = nn.Embedding(100, 64) - self.linear = nn.Linear(64, 100) - - def __call__(self, x): - return self.linear(self.embedding(x)) - - model = MockModel() - mx.eval(model.parameters()) - - input_ids = mx.random.randint(0, 100, (2, 10)) - lengths = mx.array([8, 6]) - - loss, ntoks = sft_loss(model, input_ids, lengths) - - assert loss.shape == (), f"Loss should be scalar, got shape {loss.shape}" - assert loss.item() > 0, "Cross entropy loss should be positive" +def test_precompute_kto_reference_logprobs_matches_direct(): + from mlx_tune.losses import compute_log_probs_with_lengths, precompute_kto_reference_logprobs + + model = TinyModel() + mx.eval(model.parameters()) + + input_ids = mx.array([[1, 2, 3, 4], [1, 6, 7, 8]]) + lengths = mx.array([3, 3]) + + direct = compute_log_probs_with_lengths(model, input_ids, lengths) + cached = precompute_kto_reference_logprobs(model, input_ids, lengths, batch_size=1) + + assert mx.allclose(direct, cached) + + +def test_compute_completion_log_probs_masks_prompt_tokens(): + from mlx_tune.losses import compute_completion_log_probs + + model = TinyModel() + mx.eval(model.parameters()) + + input_ids = mx.array([[1, 2, 3, 4, 5]]) + prompt_lengths = mx.array([3]) + completion_lengths = mx.array([2]) + + completion_only = compute_completion_log_probs( + model, + input_ids, + prompt_lengths, + completion_lengths, + ) + + first_completion = compute_completion_log_probs( + model, + input_ids, + mx.array([4]), + mx.array([1]), + ) + + assert completion_only.shape == (1,) + assert not mx.allclose(completion_only, first_completion) + + +def test_grpo_recompute_loss_is_finite_and_trainable(): + from mlx_tune.losses import ( + compute_completion_log_probs, + grpo_recompute_loss, + ) + + policy = TinyModel() + reference = TinyModel() + mx.eval(policy.parameters(), reference.parameters()) + + input_ids = mx.array([[1, 2, 3, 4, 5], [1, 4, 3, 2, 1]]) + prompt_lengths = mx.array([3, 2]) + completion_lengths = mx.array([2, 3]) + rollout_logprobs = compute_completion_log_probs( + policy, + input_ids, + prompt_lengths, + completion_lengths, + ) + advantages = mx.array([1.0, -0.5]) + + loss, ntoks = grpo_recompute_loss( + model=policy, + reference_model=reference, + input_ids=input_ids, + prompt_lengths=prompt_lengths, + completion_lengths=completion_lengths, + rollout_logprobs=rollout_logprobs, + advantages=advantages, + beta=0.04, + ) + + mx.eval(loss, ntoks) + assert loss.shape == () + assert ntoks.item() == 5 + + +def test_grpo_recompute_kl_penalty_is_temperature_invariant_when_advantages_are_zero(): + from mlx_tune.losses import grpo_recompute_loss + + policy = TinyModel() + reference = TinyModel() + mx.eval(policy.parameters(), reference.parameters()) + + input_ids = mx.array([[1, 2, 3, 4, 5], [1, 4, 3, 2, 1]]) + prompt_lengths = mx.array([3, 2]) + completion_lengths = mx.array([2, 3]) + rollout_logprobs = mx.array([0.0, 0.0], dtype=mx.float32) + advantages = mx.array([0.0, 0.0], dtype=mx.float32) + + loss_at_one, _ = grpo_recompute_loss( + model=policy, + reference_model=reference, + input_ids=input_ids, + prompt_lengths=prompt_lengths, + completion_lengths=completion_lengths, + rollout_logprobs=rollout_logprobs, + advantages=advantages, + beta=0.04, + temperature=1.0, + ) + loss_at_point_seven, _ = grpo_recompute_loss( + model=policy, + reference_model=reference, + input_ids=input_ids, + prompt_lengths=prompt_lengths, + completion_lengths=completion_lengths, + rollout_logprobs=rollout_logprobs, + advantages=advantages, + beta=0.04, + temperature=0.7, + ) + + assert mx.allclose(loss_at_one, loss_at_point_seven) + + +def test_ppo_kl_penalty_is_temperature_invariant_when_advantages_are_zero(): + from mlx_tune._rl_runtime import make_policy_eval_batch, score_policy + from mlx_tune.losses import ppo_sequence_loss + + policy = TinyModel() + mx.eval(policy.parameters()) + + input_ids = [[1, 2, 3, 4, 5], [1, 4, 3, 2, 1]] + prompt_lengths = [3, 2] + completion_lengths = [2, 3] + batch = make_policy_eval_batch( + input_ids, + pad_id=0, + mode="completion", + prompt_lengths=prompt_lengths, + completion_lengths=completion_lengths, + old_logprobs=mx.array([0.0, 0.0], dtype=mx.float32), + advantages=mx.array([0.0, 0.0], dtype=mx.float32), + ) + raw_scores = score_policy(policy, batch, mode="completion", temperature=1.0) + batch.reference_logprobs = raw_scores.summed_logprobs + + loss_at_one, metrics_at_one = ppo_sequence_loss( + model=policy, + batch=batch, + beta=0.04, + temperature=1.0, + ) + loss_at_point_seven, metrics_at_point_seven = ppo_sequence_loss( + model=policy, + batch=batch, + beta=0.04, + temperature=0.7, + ) + + assert mx.allclose(loss_at_one, loss_at_point_seven) + assert mx.allclose(metrics_at_one["kl_penalty"], mx.zeros_like(metrics_at_one["kl_penalty"])) + assert mx.allclose( + metrics_at_point_seven["kl_penalty"], + mx.zeros_like(metrics_at_point_seven["kl_penalty"]), + ) + + +def test_grpo_rollout_and_recompute_logprobs_match_with_temperature(): + from mlx_tune.losses import compute_completion_log_probs, generate_with_log_probs + + model = TinyModel() + tokenizer = TinyTokenizer() + mx.eval(model.parameters()) + mx.random.seed(21) + + prompt_ids = mx.array([2, 7, 8]) + generated_ids, rollout_token_logprobs = generate_with_log_probs( + model, + tokenizer, + prompt_ids, + max_tokens=3, + temperature=0.7, + ) + completion_ids = generated_ids[len(prompt_ids):].tolist() + input_ids = mx.array([prompt_ids.tolist() + completion_ids]) + + recomputed = compute_completion_log_probs( + model, + input_ids, + mx.array([len(prompt_ids)]), + mx.array([len(completion_ids)]), + temperature=0.7, + ) + + assert mx.allclose(recomputed, mx.array([rollout_token_logprobs.sum()])) + + +def test_reward_model_regression_loss_returns_expected_predictions(): + from mlx_tune.losses import reward_model_regression_loss + + loss, predictions = reward_model_regression_loss( + ConstantRewardModel(), + input_ids=mx.array([[1, 2], [3, 4]], dtype=mx.int32), + sequence_lengths=mx.array([2, 2], dtype=mx.int32), + targets=mx.array([1.0, 2.0], dtype=mx.float32), + ) + + assert float(loss.item()) == 0.0 + assert predictions.tolist() == [1.0, 2.0] + + +def test_grpo_loss_routes_produce_distinct_losses_from_same_rollout(): + from mlx_tune._rl_runtime import make_policy_eval_batch, score_policy + from mlx_tune.losses import grpo_recompute_loss + + policy = TinyModel() + reference = TinyModel() + mx.eval(policy.parameters(), reference.parameters()) + + input_ids = mx.array([[1, 2, 3, 4, 5], [1, 4, 3, 2, 1]], dtype=mx.int32) + prompt_lengths = mx.array([3, 2], dtype=mx.int32) + completion_lengths = mx.array([2, 3], dtype=mx.int32) + batch = make_policy_eval_batch( + input_ids.tolist(), + pad_id=0, + mode="completion", + prompt_lengths=prompt_lengths.tolist(), + completion_lengths=completion_lengths.tolist(), + ) + scored = score_policy(policy, batch, mode="completion") + advantages = mx.array([1.0, -0.5], dtype=mx.float32) + + losses = { + name: grpo_recompute_loss( + model=policy, + reference_model=reference, + input_ids=input_ids, + prompt_lengths=prompt_lengths, + completion_lengths=completion_lengths, + rollout_logprobs=scored.summed_logprobs, + old_token_logprobs=scored.token_logprobs * batch.token_mask.astype(mx.float32), + advantages=advantages, + loss_type=name, + max_completion_length=4, + )[0] + for name in ["grpo", "dapo", "dr_grpo", "gspo"] + } + + unique_losses = {round(float(loss.item()), 6) for loss in losses.values()} + assert len(unique_losses) >= 3 + + +def test_grpo_recompute_loss_supports_asymmetric_clip_bounds(): + from mlx_tune._rl_runtime import make_policy_eval_batch, score_policy + from mlx_tune.losses import grpo_recompute_loss + + policy = TinyModel() + reference = TinyModel() + mx.eval(policy.parameters(), reference.parameters()) + + input_ids = mx.array([[1, 2, 3, 4, 5], [1, 4, 3, 2, 1]], dtype=mx.int32) + prompt_lengths = mx.array([3, 2], dtype=mx.int32) + completion_lengths = mx.array([2, 3], dtype=mx.int32) + batch = make_policy_eval_batch( + input_ids.tolist(), + pad_id=0, + mode="completion", + prompt_lengths=prompt_lengths.tolist(), + completion_lengths=completion_lengths.tolist(), + ) + scored = score_policy(policy, batch, mode="completion") + advantages = mx.array([1.0, 0.5], dtype=mx.float32) + token_mask = batch.token_mask.astype(mx.float32) + old_token_logprobs = (scored.token_logprobs - (0.24 * token_mask)) * token_mask + + symmetric_loss, _ = grpo_recompute_loss( + model=policy, + reference_model=reference, + input_ids=input_ids, + prompt_lengths=prompt_lengths, + completion_lengths=completion_lengths, + rollout_logprobs=scored.summed_logprobs, + old_token_logprobs=old_token_logprobs, + advantages=advantages, + loss_type="dapo", + clip_epsilon=0.2, + epsilon_low=0.2, + epsilon_high=0.2, + max_completion_length=4, + ) + asymmetric_loss, _ = grpo_recompute_loss( + model=policy, + reference_model=reference, + input_ids=input_ids, + prompt_lengths=prompt_lengths, + completion_lengths=completion_lengths, + rollout_logprobs=scored.summed_logprobs, + old_token_logprobs=old_token_logprobs, + advantages=advantages, + loss_type="dapo", + clip_epsilon=0.2, + epsilon_low=0.2, + epsilon_high=0.28, + max_completion_length=4, + ) + + assert not mx.allclose(symmetric_loss, asymmetric_loss) diff --git a/tests/test_model.py b/tests/test_model.py index 7beef89..cdd903e 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -222,6 +222,34 @@ def test_enable_inference_mode_method(self, wrapped_model): assert model.inference_mode is True assert model.use_cache is True + def test_wrapper_forwards_cache_argument(self): + """Test that the wrapper exposes and forwards cache for RL decoding.""" + from mlx_tune.model import MLXModelWrapper + + class MockCacheModel: + def __init__(self): + self.calls = [] + + def __call__(self, x, cache=None): + self.calls.append((x, cache)) + return ("logits", {"seen": True}) if cache is not None else "logits" + + wrapped = MLXModelWrapper( + model=MockCacheModel(), + tokenizer=None, + max_seq_length=128, + model_name="mock", + ) + + assert wrapped("tokens") == "logits" + assert wrapped("tokens", cache={"kv": 1}) == ("logits", {"seen": True}) + assert wrapped.forward_with_cache("tokens", cache={"kv": 2}) == ("logits", {"seen": True}) + assert wrapped.model.calls == [ + ("tokens", None), + ("tokens", {"kv": 1}), + ("tokens", {"kv": 2}), + ] + class TestGGUFExportFix: """Test cases for GGUF export fix (GitHub issue #3). diff --git a/tests/test_rl_api.py b/tests/test_rl_api.py new file mode 100644 index 0000000..e96710f --- /dev/null +++ b/tests/test_rl_api.py @@ -0,0 +1,128 @@ +import mlx.core as mx +import mlx.nn as nn +import pytest + + +class SmallBackbone(nn.Module): + def __init__(self, vocab_size: int = 64, hidden_size: int = 32, num_layers: int = 2): + super().__init__() + self.embedding = nn.Embedding(vocab_size, hidden_size) + self.layers = [nn.Linear(hidden_size, hidden_size) for _ in range(num_layers)] + + def __call__(self, x): + h = self.embedding(x) + for layer in self.layers: + h = mx.maximum(layer(h), 0) + return h + + +class SmallLanguageModel(nn.Module): + def __init__(self, vocab_size: int = 64, hidden_size: int = 32, num_layers: int = 2): + super().__init__() + self.model = SmallBackbone(vocab_size=vocab_size, hidden_size=hidden_size, num_layers=num_layers) + self.output = nn.Linear(hidden_size, vocab_size) + + def __call__(self, x): + return self.output(self.model(x)) + + +class MockTokenizer: + def __init__(self, vocab_size: int = 64): + self.vocab_size = vocab_size + self.pad_token_id = 0 + self.eos_token_id = 1 + self.bos_token_id = 2 + + def encode(self, text: str, add_special_tokens: bool = True): + ids = [((ord(char) % (self.vocab_size - 3)) + 3) for char in text[:32]] + if add_special_tokens: + ids = [self.bos_token_id] + ids + [self.eos_token_id] + return ids + + def decode(self, ids, skip_special_tokens: bool = True): + if skip_special_tokens: + ids = [token for token in ids if token not in (self.pad_token_id, self.eos_token_id, self.bos_token_id)] + return "".join(chr(65 + (token % 26)) for token in ids) + + +class MockModelWrapper: + def __init__(self, model: SmallLanguageModel): + self.model = model + self._lora_applied = False + + def __call__(self, x): + return self.model(x) + + def _apply_lora(self): + self._lora_applied = True + return True + + +def make_model(seed: int) -> MockModelWrapper: + mx.random.seed(seed) + model = SmallLanguageModel() + mx.eval(model.parameters()) + return MockModelWrapper(model) + + +def test_prepare_rl_dataset_auto_detect_and_chat_adaptation(): + from mlx_tune import prepare_rl_dataset + + prompt_dataset = prepare_rl_dataset([{"prompt": "Solve 2 + 2", "answer": "4"}]) + assert prompt_dataset.mode == "prompt" + assert prompt_dataset.samples[0]["reward_context"] == "4" + + chat_dataset = prepare_rl_dataset( + [ + { + "messages": [ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello"}, + ], + "score": 1.0, + } + ], + mode="reward_scalar", + tokenizer=MockTokenizer(), + ) + assert chat_dataset.mode == "reward_scalar" + assert chat_dataset.adapter_name == "chat_reward_scalar" + assert chat_dataset.samples[0]["response"] == "Hello" + + +def test_prepare_rl_dataset_raises_on_ambiguous_pairwise_schema(): + from mlx_tune import prepare_rl_dataset + + with pytest.raises(ValueError, match="Ambiguous RL dataset schema"): + prepare_rl_dataset([{"prompt": "Q", "chosen": "A", "rejected": "B"}]) + + +def test_resume_from_checkpoint_returns_bundle_with_manifest_fields(tmp_path): + from mlx_tune import RewardConfig, RewardTrainer, build_reward_model, resume_from_checkpoint + + tokenizer = MockTokenizer() + reward_model = build_reward_model(make_model(100)) + trainer = RewardTrainer( + model=reward_model, + train_dataset=[ + {"prompt": "Q:", "response": "good", "score": 1.0}, + {"prompt": "Q:", "response": "bad", "score": 0.0}, + ], + tokenizer=tokenizer, + args=RewardConfig( + learning_rate=1e-2, + max_steps=1, + logging_steps=1, + save_steps=1, + output_dir=str(tmp_path / "reward"), + ), + ) + + trainer.train() + bundle = resume_from_checkpoint(tmp_path / "reward") + + assert bundle.algorithm == "reward" + assert "reward_model" in bundle.restored_roles + assert bundle.trainer_state["global_step"] == 1 + assert bundle.metrics_history + assert bundle.source_format == "manifest" diff --git a/tests/test_rl_model_roles.py b/tests/test_rl_model_roles.py new file mode 100644 index 0000000..3c091ba --- /dev/null +++ b/tests/test_rl_model_roles.py @@ -0,0 +1,336 @@ +from pathlib import Path + +import mlx.core as mx +import mlx.nn as nn +from mlx.utils import tree_flatten + +from mlx_tune import ( + build_reference_policy, + build_reward_model, + build_value_model, + create_rl_model_roles, + pairwise_ranking_accuracy, + reward_model_pairwise_loss, + scalar_loss_metrics, + value_model_regression_loss, +) +from mlx_tune.model import MLXModelWrapper + + +class MockTokenizer: + pad_token_id = 0 + eos_token_id = 1 + bos_token_id = 2 + + def encode(self, text: str, add_special_tokens: bool = True): + ids = [((ord(char) % 20) + 3) for char in text[:16]] + if add_special_tokens: + return [self.bos_token_id] + ids + [self.eos_token_id] + return ids + + def decode(self, ids, skip_special_tokens: bool = True): + if skip_special_tokens: + ids = [token for token in ids if token not in (0, 1, 2)] + return "".join(chr(65 + (token % 26)) for token in ids) + + +class DeterministicBackbone(nn.Module): + def __init__(self): + super().__init__() + self.hidden_size = 2 + + def __call__(self, x): + values = x.astype(mx.float32) + return mx.stack([values, values * 10.0], axis=-1) + + +class DeterministicCausalLM(nn.Module): + def __init__(self, vocab_size: int = 64): + super().__init__() + self.model = DeterministicBackbone() + self.output = nn.Linear(2, vocab_size) + + def __call__(self, x): + return self.output(self.model(x)) + + +class WeightedBackbone(nn.Module): + def __init__(self): + super().__init__() + self.hidden_size = 2 + self.proj = nn.Linear(1, 2) + + def __call__(self, x): + return self.proj(x.astype(mx.float32)[..., None]) + + +class WeightedCausalLM(nn.Module): + def __init__(self, vocab_size: int = 64): + super().__init__() + self.model = WeightedBackbone() + self.output = nn.Linear(2, vocab_size) + + def __call__(self, x): + return self.output(self.model(x)) + + +def make_wrapper() -> MLXModelWrapper: + model = DeterministicCausalLM() + mx.eval(model.parameters()) + wrapper = MLXModelWrapper( + model=model, + tokenizer=MockTokenizer(), + max_seq_length=32, + model_name="deterministic-model", + ) + return wrapper + + +def make_weighted_wrapper(seed: int) -> MLXModelWrapper: + mx.random.seed(seed) + model = WeightedCausalLM() + mx.eval(model.parameters()) + return MLXModelWrapper( + model=model, + tokenizer=MockTokenizer(), + max_seq_length=32, + model_name=f"weighted-model-{seed}", + ) + + +def _set_scalar_head_to_first_feature(role_model) -> None: + role_model.head.update( + { + "weight": mx.array([[1.0, 0.0]], dtype=mx.float32), + "bias": mx.array([0.0], dtype=mx.float32), + }, + strict=False, + ) + mx.eval(role_model.head.parameters()) + + +def _parameter_snapshot(model) -> dict[str, mx.array]: + actual_model = model.model if hasattr(model, "model") else model + return {name: mx.array(value) for name, value in tree_flatten(actual_model.parameters())} + + +def _parameters_match(before: dict[str, mx.array], after_model) -> bool: + actual_model = after_model.model if hasattr(after_model, "model") else after_model + after = {name: value for name, value in tree_flatten(actual_model.parameters())} + return all(mx.allclose(before[name], after[name]) for name in before) + + +def test_reference_policy_clone_is_frozen_and_isolated(): + policy = make_wrapper() + policy.lora_enabled = True + policy._lora_applied = True + policy.set_adapter_path("/tmp/live-policy") + + roles = create_rl_model_roles(policy) + reference = roles.reference_policy + + assert reference.model is not policy + assert reference.model.get_adapter_path() is None + assert not tree_flatten(reference.model.model.trainable_parameters()) + + reference_before = _parameter_snapshot(reference.model) + policy.model.output.update( + { + "weight": mx.ones_like(policy.model.output.weight), + "bias": mx.zeros_like(policy.model.output.bias), + }, + strict=False, + ) + mx.eval(policy.model.parameters()) + + assert _parameters_match(reference_before, reference.model) + + +def test_scalar_roles_default_to_last_completion_token_and_support_mean_pooling(): + base_model = make_wrapper() + + reward_model = build_reward_model(base_model) + _set_scalar_head_to_first_feature(reward_model) + + input_ids = mx.array([[2, 5, 7, 9], [2, 4, 6, 8]], dtype=mx.int32) + sequence_lengths = mx.array([4, 4], dtype=mx.int32) + prompt_lengths = mx.array([2, 3], dtype=mx.int32) + completion_lengths = mx.array([2, 1], dtype=mx.int32) + + reward_scores = reward_model.score( + input_ids, + sequence_lengths=sequence_lengths, + prompt_lengths=prompt_lengths, + completion_lengths=completion_lengths, + ) + assert mx.allclose(reward_scores, mx.array([9.0, 8.0], dtype=mx.float32)) + + mean_completion_model = build_value_model(base_model, pooling="mean_completion") + _set_scalar_head_to_first_feature(mean_completion_model) + mean_completion = mean_completion_model.predict( + input_ids, + sequence_lengths=sequence_lengths, + prompt_lengths=prompt_lengths, + completion_lengths=completion_lengths, + ) + assert mx.allclose(mean_completion, mx.array([8.0, 8.0], dtype=mx.float32)) + + mean_sequence_model = build_value_model(base_model, pooling="mean_sequence", target="sequence") + _set_scalar_head_to_first_feature(mean_sequence_model) + mean_sequence = mean_sequence_model.predict( + input_ids, + sequence_lengths=sequence_lengths, + prompt_lengths=prompt_lengths, + completion_lengths=completion_lengths, + ) + assert mx.allclose(mean_sequence, mx.array([5.75, 5.0], dtype=mx.float32)) + + +def test_scalar_role_save_load_round_trip_preserves_head_and_adapter_state(tmp_path): + base_model = make_wrapper() + base_model.lora_enabled = True + base_model._lora_applied = True + + reward_model = build_reward_model(base_model) + _set_scalar_head_to_first_feature(reward_model) + reward_model.base_model.model.output.update( + { + "weight": mx.ones_like(reward_model.base_model.model.output.weight) * 0.5, + "bias": mx.ones_like(reward_model.base_model.model.output.bias) * 0.25, + }, + strict=False, + ) + mx.eval(reward_model.base_model.model.parameters()) + + output_dir = Path(tmp_path) / "reward_role" + reward_model.save_pretrained(str(output_dir)) + + restored = build_reward_model(make_wrapper()) + restored.load_pretrained(str(output_dir)) + + score_inputs = mx.array([[2, 4, 6]], dtype=mx.int32) + sequence_lengths = mx.array([3], dtype=mx.int32) + prompt_lengths = mx.array([1], dtype=mx.int32) + completion_lengths = mx.array([2], dtype=mx.int32) + + assert (output_dir / "head.safetensors").exists() + assert (output_dir / "head_config.json").exists() + assert (output_dir / "weights.safetensors").exists() + assert (output_dir / "adapters.safetensors").exists() + assert (output_dir / "adapter_config.json").exists() + assert mx.allclose( + reward_model.score( + score_inputs, + sequence_lengths=sequence_lengths, + prompt_lengths=prompt_lengths, + completion_lengths=completion_lengths, + ), + restored.score( + score_inputs, + sequence_lengths=sequence_lengths, + prompt_lengths=prompt_lengths, + completion_lengths=completion_lengths, + ), + ) + + +def test_scalar_role_save_load_round_trip_preserves_independent_backbone_weights(tmp_path): + original_base = make_weighted_wrapper(101) + + reward_model = build_reward_model(original_base) + _set_scalar_head_to_first_feature(reward_model) + + output_dir = Path(tmp_path) / "independent_reward_role" + reward_model.save_pretrained(str(output_dir)) + + restored = build_reward_model(make_weighted_wrapper(202)) + restored.load_pretrained(str(output_dir)) + + sequence = mx.array([[2, 5, 9]], dtype=mx.int32) + sequence_lengths = mx.array([3], dtype=mx.int32) + prompt_lengths = mx.array([1], dtype=mx.int32) + completion_lengths = mx.array([2], dtype=mx.int32) + + assert mx.allclose( + reward_model.base_model.model.model.proj.weight, + restored.base_model.model.model.proj.weight, + ) + assert mx.allclose( + reward_model.score( + sequence, + sequence_lengths=sequence_lengths, + prompt_lengths=prompt_lengths, + completion_lengths=completion_lengths, + ), + restored.score( + sequence, + sequence_lengths=sequence_lengths, + prompt_lengths=prompt_lengths, + completion_lengths=completion_lengths, + ), + ) + + +def test_snapshot_false_scalar_builders_preserve_caller_adapter_path(): + base_model = make_wrapper() + base_model.set_adapter_path("/tmp/existing-adapters") + + reward_model = build_reward_model(base_model, snapshot=False) + value_model = build_value_model(base_model, snapshot=False) + + assert reward_model.base_model is base_model + assert value_model.base_model is base_model + assert str(base_model.get_adapter_path()) == "/tmp/existing-adapters" + + +def test_scalar_objective_helpers_return_stable_losses_and_metrics(): + reward_model = build_reward_model(make_wrapper()) + _set_scalar_head_to_first_feature(reward_model) + + chosen_input_ids = mx.array([[2, 4, 9], [2, 3, 8]], dtype=mx.int32) + rejected_input_ids = mx.array([[2, 4, 5], [2, 3, 6]], dtype=mx.int32) + chosen_lengths = mx.array([3, 3], dtype=mx.int32) + rejected_lengths = mx.array([3, 3], dtype=mx.int32) + prompt_lengths = mx.array([1, 1], dtype=mx.int32) + completion_lengths = mx.array([2, 2], dtype=mx.int32) + + reward_loss, reward_outputs = reward_model_pairwise_loss( + reward_model, + chosen_input_ids=chosen_input_ids, + rejected_input_ids=rejected_input_ids, + chosen_sequence_lengths=chosen_lengths, + rejected_sequence_lengths=rejected_lengths, + chosen_prompt_lengths=prompt_lengths, + rejected_prompt_lengths=prompt_lengths, + chosen_completion_lengths=completion_lengths, + rejected_completion_lengths=completion_lengths, + ) + assert float(reward_loss.item()) > 0.0 + assert pairwise_ranking_accuracy( + reward_outputs["chosen_scores"], + reward_outputs["rejected_scores"], + ) == 1.0 + + value_model = build_value_model(make_wrapper()) + _set_scalar_head_to_first_feature(value_model) + value_targets = mx.array([9.0, 8.0], dtype=mx.float32) + value_loss, predictions = value_model_regression_loss( + value_model, + input_ids=chosen_input_ids, + sequence_lengths=chosen_lengths, + targets=value_targets, + prompt_lengths=prompt_lengths, + completion_lengths=completion_lengths, + ) + metrics = scalar_loss_metrics(value_loss, predictions, value_targets) + assert metrics["loss"] == 0.0 + assert metrics["mae"] == 0.0 + assert metrics["mse"] == 0.0 + + +def test_build_reference_policy_public_builder_returns_compat_wrapper(): + policy = make_wrapper() + reference = build_reference_policy(policy) + + assert reference.source == "policy_snapshot" + assert reference.metadata["snapshot_strategy"] == "clone_and_freeze" diff --git a/tests/test_rl_runtime.py b/tests/test_rl_runtime.py new file mode 100644 index 0000000..abdccda --- /dev/null +++ b/tests/test_rl_runtime.py @@ -0,0 +1,917 @@ +import math + +import mlx.core as mx +import mlx.nn as nn + + +class TinyModel(nn.Module): + def __init__(self, vocab_size: int = 32, hidden_size: int = 16): + super().__init__() + self.embedding = nn.Embedding(vocab_size, hidden_size) + self.output = nn.Linear(hidden_size, vocab_size) + + def __call__(self, x): + return self.output(self.embedding(x)) + + +class ScriptedModel(nn.Module): + def __init__(self, next_tokens, vocab_size: int = 16): + super().__init__() + self.next_tokens = dict(next_tokens) + self.vocab_size = vocab_size + + def __call__(self, x): + batch, seq_len = x.shape + logits = mx.full((batch, seq_len, self.vocab_size), -100.0) + token_id = self.next_tokens.get(seq_len, self.next_tokens.get("default", 0)) + logits[:, -1, token_id] = 100.0 + return logits + + +class CacheOnlyScriptedModel(nn.Module): + def __init__(self, next_tokens, vocab_size: int = 16): + super().__init__() + self.next_tokens = dict(next_tokens) + self.vocab_size = vocab_size + self.calls = [] + + def make_cache(self): + return [{"steps": 0}] + + def __call__(self, x, cache=None): + if cache is None: + raise ValueError("cache is required") + batch, seq_len = x.shape + cache[0]["steps"] += 1 + self.calls.append((seq_len, cache[0]["steps"])) + logits = mx.full((batch, seq_len, self.vocab_size), -100.0) + token_id = self.next_tokens.get(cache[0]["steps"], self.next_tokens.get("default", 0)) + logits[:, -1, token_id] = 100.0 + return logits + + +class BatchSizedCacheScriptedModel(nn.Module): + def __init__(self, vocab_size: int = 16): + super().__init__() + self.vocab_size = vocab_size + self.calls = [] + + def make_cache(self): + return [{"steps": 0, "batch_size": None}] + + def __call__(self, x, cache=None): + if cache is None: + raise ValueError("cache is required") + batch, seq_len = x.shape + if cache[0]["batch_size"] is None: + cache[0]["batch_size"] = batch + elif cache[0]["batch_size"] != batch: + raise ValueError("cache batch size mismatch") + cache[0]["steps"] += 1 + self.calls.append((batch, seq_len, cache[0]["steps"])) + logits = mx.full((batch, seq_len, self.vocab_size), -100.0) + for row_index in range(batch): + if cache[0]["steps"] == 1: + token_id = 1 if row_index == 0 else 5 + elif cache[0]["steps"] == 2: + token_id = 1 + else: + token_id = 1 + logits[row_index, -1, token_id] = 100.0 + return logits + + +class BatchSensitiveCacheModel(nn.Module): + def __init__(self, vocab_size: int = 16): + super().__init__() + self.vocab_size = vocab_size + self.cache_calls = [] + + def make_cache(self): + return [{"steps": 0}] + + def _token_for_sequence(self, row): + last_token = int(row[-1]) if len(row) else 0 + if last_token == 5: + return 1 + return 5 + + def __call__(self, x, cache=None): + batch, seq_len = x.shape + logits = mx.full((batch, seq_len, self.vocab_size), -100.0) + if cache is None: + for row_index in range(batch): + row = x[row_index].tolist() + for position in range(seq_len): + logits[row_index, position, self._token_for_sequence(row[: position + 1])] = 100.0 + return logits + + cache[0]["steps"] += 1 + self.cache_calls.append((batch, seq_len, cache[0]["steps"])) + for row_index in range(batch): + token_id = self._token_for_sequence(x[row_index].tolist()) + if batch > 1 and seq_len == 1: + token_id = 7 + logits[row_index, -1, token_id] = 100.0 + return logits + + +class TinyTokenizer: + pad_token_id = 0 + eos_token_id = 1 + bos_token_id = 2 + + def encode(self, text: str, add_special_tokens: bool = True): + if text == "<|im_end|>": + return [9] + ids = [((ord(char) % 10) + 3) for char in text] + if add_special_tokens: + ids = [self.bos_token_id] + ids + [self.eos_token_id] + return ids + + def decode(self, ids, skip_special_tokens: bool = True): + if skip_special_tokens: + ids = [token for token in ids if token not in (self.pad_token_id, self.eos_token_id, self.bos_token_id)] + return "".join(chr(65 + (token % 26)) for token in ids) + + def convert_tokens_to_ids(self, token: str): + if token == "<|im_end|>": + return 9 + return None + + def get_vocab(self): + return {"<|im_end|>": 9} + + +def test_collect_rollouts_returns_metadata_and_truncates_prompt_left(): + from mlx_tune._rl_runtime import collect_rollouts + + tokenizer = TinyTokenizer() + model = ScriptedModel({2: 5, 3: 1, "default": 1}) + + rollout = collect_rollouts( + policy=model, + tokenizer=tokenizer, + prompt_samples=[ + { + "sample_index": 7, + "prompt": "abcdef", + "prompt_ids": [3, 4, 5, 6, 7, 8], + "reward_context": "ctx", + } + ], + sampling_config={ + "num_generations": 2, + "temperature": 0.0, + "max_completion_length": 4, + "max_seq_length": 6, + }, + ) + + assert rollout.prompt_ids == [[4, 5, 6, 7, 8], [4, 5, 6, 7, 8]] + assert rollout.prompt_lengths.tolist() == [5, 5] + assert rollout.prompt_texts == [tokenizer.decode([4, 5, 6, 7, 8]), tokenizer.decode([4, 5, 6, 7, 8])] + assert rollout.original_prompt_texts == ["abcdef", "abcdef"] + assert rollout.completion_ids == [[1], [1]] + assert rollout.completion_lengths.tolist() == [1, 1] + assert rollout.sampled_token_logprobs.shape == (2, 1) + assert rollout.eos_flags.tolist() == [True, True] + assert rollout.truncation_flags.tolist() == [False, False] + assert rollout.prompt_group_indices.tolist() == [0, 0] + assert rollout.sample_indices.tolist() == [7, 7] + assert rollout.policy_eval.input_ids.shape == (2, 6) + assert mx.allclose( + rollout.rollout_logprobs, + rollout.sampled_token_logprobs.sum(axis=-1), + ) + + +def test_collect_rollouts_can_capture_sample_stats_and_truncation(): + from mlx_tune._rl_runtime import collect_rollouts + + tokenizer = TinyTokenizer() + model = ScriptedModel({2: 5, 3: 6, 4: 7, "default": 7}) + + rollout = collect_rollouts( + policy=model, + tokenizer=tokenizer, + prompt_samples=[ + { + "sample_index": 1, + "prompt": "ab", + "prompt_ids": [3, 4], + "reward_context": "ctx", + } + ], + sampling_config={ + "num_generations": 1, + "temperature": 0.0, + "max_completion_length": 2, + "max_seq_length": 8, + }, + collect_sample_stats=True, + ) + + assert rollout.completion_ids == [[5, 6]] + assert rollout.eos_flags.tolist() == [False] + assert rollout.truncation_flags.tolist() == [True] + assert rollout.sampled_token_logits.shape == (1, 2) + assert rollout.token_entropies.shape == (1, 2) + + +def test_collect_rollouts_sampled_logprobs_match_rescored_completion_logprobs(): + from mlx_tune._rl_runtime import collect_rollouts, score_policy + + tokenizer = TinyTokenizer() + mx.random.seed(0) + model = TinyModel() + mx.eval(model.parameters()) + + rollout = collect_rollouts( + policy=model, + tokenizer=tokenizer, + prompt_samples=[ + { + "sample_index": 0, + "prompt": "abc", + "prompt_ids": [tokenizer.bos_token_id, 3, 4, 5], + "reward_context": "ctx", + } + ], + sampling_config={ + "num_generations": 4, + "temperature": 1.0, + "max_completion_length": 5, + "max_seq_length": 16, + }, + ) + + rescored = score_policy(model, rollout.policy_eval, mode="completion", temperature=1.0) + masked_rescored = rescored.token_logprobs * rollout.policy_eval.token_mask.astype(mx.float32) + + assert mx.allclose(rollout.rollout_logprobs, rescored.summed_logprobs, atol=1e-5) + assert mx.allclose(rollout.policy_eval.old_token_logprobs, masked_rescored, atol=1e-5) + + +def test_collect_rollouts_respects_max_seq_length_during_generation(): + from mlx_tune._rl_runtime import collect_rollouts + + tokenizer = TinyTokenizer() + model = ScriptedModel({2: 5, 3: 6, 4: 7, 5: 8, 6: 9, "default": 9}) + + rollout = collect_rollouts( + policy=model, + tokenizer=tokenizer, + prompt_samples=[ + { + "sample_index": 0, + "prompt": "abcdef", + "prompt_ids": [3, 4, 5, 6, 7, 8], + "reward_context": "ctx", + } + ], + sampling_config={ + "num_generations": 1, + "temperature": 0.0, + "max_completion_length": 5, + "max_seq_length": 3, + }, + ) + + assert rollout.prompt_lengths.tolist() == [2] + assert rollout.completion_lengths.tolist() == [1] + assert rollout.policy_eval.input_ids.shape == (1, 3) + assert rollout.truncation_flags.tolist() == [True] + + +def test_collect_rollouts_reduces_completion_budget_before_collapsing_prompt(): + from mlx_tune._rl_runtime import collect_rollouts + + tokenizer = TinyTokenizer() + model = ScriptedModel({6: 5, 7: 6, "default": 6}) + + rollout = collect_rollouts( + policy=model, + tokenizer=tokenizer, + prompt_samples=[ + { + "sample_index": 0, + "prompt": "abcdef", + "prompt_ids": [3, 4, 5, 6, 7, 8], + "reward_context": "ctx", + } + ], + sampling_config={ + "num_generations": 1, + "temperature": 0.0, + "max_completion_length": 6, + "max_seq_length": 6, + }, + ) + + assert rollout.prompt_ids == [[4, 5, 6, 7, 8]] + assert rollout.prompt_lengths.tolist() == [5] + assert rollout.completion_ids == [[6]] + assert rollout.completion_lengths.tolist() == [1] + assert rollout.policy_eval.input_ids.shape == (1, 6) + assert rollout.truncation_flags.tolist() == [True] + + +def test_collect_rollouts_treats_unsloth_stop_token_as_terminal(): + from mlx_tune._rl_runtime import collect_rollouts, sample_completion + + tokenizer = TinyTokenizer() + tokenizer._unsloth_stop_token = "<|im_end|>" + model = ScriptedModel({2: 9, 3: 7, 4: 7, "default": 7}) + + sampled = sample_completion( + policy=model, + tokenizer=tokenizer, + prompt_ids=[3, 4], + max_tokens=4, + temperature=0.0, + ) + rollout = collect_rollouts( + policy=model, + tokenizer=tokenizer, + prompt_samples=[ + { + "sample_index": 0, + "prompt": "ab", + "prompt_ids": [3, 4], + "reward_context": "ctx", + } + ], + sampling_config={ + "num_generations": 1, + "temperature": 0.0, + "max_completion_length": 4, + }, + ) + + assert sampled["completion_ids"] == [9] + assert sampled["eos_flag"] is True + assert sampled["truncation_flag"] is False + assert rollout.completion_ids == [[9]] + assert rollout.eos_flags.tolist() == [True] + assert rollout.truncation_flags.tolist() == [False] + + +def test_collect_rollouts_initializes_and_reuses_prompt_cache_for_logits_only_models(): + from mlx_tune._rl_runtime import collect_rollouts + + tokenizer = TinyTokenizer() + model = CacheOnlyScriptedModel({1: 5, 2: 1, "default": 1}) + + rollout = collect_rollouts( + policy=model, + tokenizer=tokenizer, + prompt_samples=[ + { + "sample_index": 0, + "prompt": "ab", + "prompt_ids": [3, 4], + "reward_context": "ctx", + } + ], + sampling_config={ + "num_generations": 2, + "temperature": 0.0, + "max_completion_length": 4, + "generation_batch_size": 2, + }, + ) + + assert rollout.completion_ids == [[5, 1], [5, 1]] + assert rollout.eos_flags.tolist() == [True, True] + assert model.calls == [(2, 1), (2, 1), (1, 2), (1, 2)] + + +def test_collect_rollouts_rebuilds_prompt_cache_when_active_batch_shrinks(): + from mlx_tune._rl_runtime import collect_rollouts + + tokenizer = TinyTokenizer() + model = BatchSizedCacheScriptedModel() + + rollout = collect_rollouts( + policy=model, + tokenizer=tokenizer, + prompt_samples=[ + { + "sample_index": 0, + "prompt": "ab", + "prompt_ids": [3, 4], + "reward_context": "ctx-a", + }, + { + "sample_index": 1, + "prompt": "ac", + "prompt_ids": [3, 5], + "reward_context": "ctx-b", + }, + ], + sampling_config={ + "num_generations": 1, + "temperature": 0.0, + "max_completion_length": 3, + "generation_batch_size": 2, + }, + ) + + assert rollout.completion_ids == [[1], [1]] + assert rollout.eos_flags.tolist() == [True, True] + assert rollout.truncation_flags.tolist() == [False, False] + assert model.calls == [(1, 2, 1), (1, 2, 1)] + + +def test_collect_rollouts_uses_per_row_cache_to_preserve_rescored_logprobs(): + from mlx_tune._rl_runtime import collect_rollouts, score_policy + + tokenizer = TinyTokenizer() + model = BatchSensitiveCacheModel() + + rollout = collect_rollouts( + policy=model, + tokenizer=tokenizer, + prompt_samples=[ + { + "sample_index": 0, + "prompt": "ab", + "prompt_ids": [3, 4], + "reward_context": "ctx", + } + ], + sampling_config={ + "num_generations": 2, + "temperature": 0.0, + "max_completion_length": 3, + "generation_batch_size": 2, + }, + ) + + rescored = score_policy(model, rollout.policy_eval, mode="completion", temperature=1.0) + masked_rescored = rescored.token_logprobs * rollout.policy_eval.token_mask.astype(mx.float32) + + assert rollout.completion_ids == [[5, 1], [5, 1]] + assert all(batch == 1 for batch, _, _ in model.cache_calls) + assert mx.allclose(rollout.rollout_logprobs, rescored.summed_logprobs, atol=1e-5) + assert mx.allclose(rollout.policy_eval.old_token_logprobs, masked_rescored, atol=1e-5) + + +def test_score_policy_matches_public_logprob_helpers(): + from mlx_tune._rl_runtime import make_policy_eval_batch, score_policy + from mlx_tune.losses import compute_completion_log_probs, compute_log_probs_with_lengths + + model = TinyModel() + mx.eval(model.parameters()) + + sequence_batch = make_policy_eval_batch( + [[1, 2, 3, 4], [1, 5, 6]], + pad_id=0, + mode="sequence", + ) + sequence_scores = score_policy(model, sequence_batch, mode="sequence") + direct_sequence = compute_log_probs_with_lengths( + model, + sequence_batch.input_ids, + sequence_batch.sequence_lengths, + ) + + completion_batch = make_policy_eval_batch( + [[1, 2, 3, 4, 5], [1, 4, 3, 2]], + pad_id=0, + mode="completion", + prompt_lengths=[3, 2], + completion_lengths=[2, 2], + ) + completion_scores = score_policy(model, completion_batch, mode="completion") + direct_completion = compute_completion_log_probs( + model, + completion_batch.input_ids, + completion_batch.prompt_lengths, + completion_batch.completion_lengths, + ) + + assert mx.allclose(sequence_scores.summed_logprobs, direct_sequence) + assert mx.allclose(completion_scores.summed_logprobs, direct_completion) + + +def test_reference_precompute_helpers_match_runtime_scorer(): + from mlx_tune._rl_runtime import make_policy_eval_batch, score_policy_in_chunks + from mlx_tune.losses import ( + precompute_kto_reference_logprobs, + precompute_preference_reference_logprobs, + ) + + model = TinyModel() + mx.eval(model.parameters()) + + chosen = make_policy_eval_batch([[1, 2, 3, 4], [1, 4, 5]], pad_id=0, mode="sequence") + rejected = make_policy_eval_batch([[1, 3, 2, 4], [1, 6, 5]], pad_id=0, mode="sequence") + direct_chosen = score_policy_in_chunks(model, chosen, batch_size=1, mode="sequence").summed_logprobs + direct_rejected = score_policy_in_chunks(model, rejected, batch_size=1, mode="sequence").summed_logprobs + cached_chosen, cached_rejected = precompute_preference_reference_logprobs( + model, + chosen.input_ids, + rejected.input_ids, + chosen.sequence_lengths, + rejected.sequence_lengths, + batch_size=1, + ) + + kto_batch = make_policy_eval_batch([[1, 2, 3], [1, 4, 5, 6]], pad_id=0, mode="sequence") + direct_kto = score_policy_in_chunks(model, kto_batch, batch_size=1, mode="sequence").summed_logprobs + cached_kto = precompute_kto_reference_logprobs( + model, + kto_batch.input_ids, + kto_batch.sequence_lengths, + batch_size=1, + ) + + assert mx.allclose(direct_chosen, cached_chosen) + assert mx.allclose(direct_rejected, cached_rejected) + assert mx.allclose(direct_kto, cached_kto) + + +def test_reward_adapter_supports_legacy_and_structured_evaluators(): + from mlx_tune._rl_runtime import collect_rollouts, evaluate_rewards + + tokenizer = TinyTokenizer() + model = ScriptedModel({2: 5, 3: 1, "default": 1}) + rollout = collect_rollouts( + policy=model, + tokenizer=tokenizer, + prompt_samples=[ + { + "sample_index": 0, + "prompt": "ab", + "prompt_ids": [3, 4], + "reward_context": "ctx", + } + ], + sampling_config={"num_generations": 1, "temperature": 0.0, "max_completion_length": 2}, + ) + + legacy = evaluate_rewards( + rollout, + lambda response, context: float(len(response) + len(context)), + ) + + seen_payloads = [] + + class StructuredEvaluator: + def evaluate(self, payload): + seen_payloads.append(payload) + return { + "reward": float(payload["completion_length"]), + "components": {"length": float(payload["completion_length"])}, + "diagnostics": {"used_context": payload["reward_context"]}, + } + + structured = evaluate_rewards(rollout, StructuredEvaluator()) + + assert legacy.scalar_rewards.tolist() == [float(len(rollout.completion_texts[0]) + 3)] + assert structured.scalar_rewards.tolist() == [2.0] + assert structured.named_reward_components == [{"length": 2.0}] + assert structured.diagnostics == [{"used_context": "ctx"}] + assert seen_payloads[0]["prompt_text"] == tokenizer.decode([3, 4]) + assert seen_payloads[0]["original_prompt_text"] == "ab" + + +def test_reward_payload_exposes_effective_and_original_prompt_when_truncated(): + from mlx_tune._rl_runtime import collect_rollouts, evaluate_rewards + + tokenizer = TinyTokenizer() + model = ScriptedModel({2: 5, 3: 1, "default": 1}) + captured = [] + + rollout = collect_rollouts( + policy=model, + tokenizer=tokenizer, + prompt_samples=[ + { + "sample_index": 0, + "prompt": "abcdef", + "prompt_ids": [3, 4, 5, 6, 7, 8], + "reward_context": "ctx", + } + ], + sampling_config={ + "num_generations": 1, + "temperature": 0.0, + "max_completion_length": 2, + "max_seq_length": 4, + }, + ) + + evaluate_rewards( + rollout, + lambda payload: captured.append(payload) or float(payload["completion_length"]), + ) + + assert captured[0]["prompt_ids"] == [6, 7, 8] + assert captured[0]["prompt_text"] == tokenizer.decode([6, 7, 8]) + assert captured[0]["original_prompt_text"] == "abcdef" + + +def test_compute_advantages_uses_zero_variance_fallback_per_prompt(): + from mlx_tune._rl_runtime import RewardBatch, compute_advantages + + reward_batch = RewardBatch( + prompt_texts=["p0", "p0", "p1", "p1"], + completion_texts=["a", "b", "c", "d"], + reward_contexts=["c0", "c0", "c1", "c1"], + scalar_rewards=mx.array([2.0, 2.0, 1.0, 3.0], dtype=mx.float32), + prompt_group_indices=mx.array([0, 0, 1, 1]), + ) + + advantages = compute_advantages(reward_batch) + + assert mx.allclose(advantages, mx.array([0.0, 0.0, -1.0, 1.0], dtype=mx.float32)) + + +def test_assemble_minibatches_preserves_order_and_prompt_groups(): + from mlx_tune._rl_runtime import assemble_minibatches, make_policy_eval_batch + + batch = make_policy_eval_batch( + [[1, 2, 3], [1, 4, 5], [1, 6, 7]], + pad_id=0, + mode="sequence", + rollout_logprobs=mx.array([0.1, 0.2, 0.3], dtype=mx.float32), + advantages=mx.array([1.0, 2.0, 3.0], dtype=mx.float32), + prompt_group_indices=mx.array([9, 9, 10]), + sample_indices=mx.array([0, 1, 2]), + ) + + minibatches = list(assemble_minibatches(batch, minibatch_size=2, shuffle=False)) + + assert len(minibatches) == 2 + assert minibatches[0].input_ids.tolist() == [[1, 2, 3], [1, 4, 5]] + assert minibatches[0].prompt_group_indices.tolist() == [9, 9] + assert minibatches[1].input_ids.tolist() == [[1, 6, 7]] + assert minibatches[1].prompt_group_indices.tolist() == [10] + + +def test_post_rollout_helpers_attach_reference_values_and_rloo_advantages(): + from mlx_tune._rl_runtime import ( + collect_rollouts, + compute_returns_and_advantages, + predict_rollout_values, + rank_grouped_rollouts, + score_rollout_references, + ) + + class ConstantValueModel: + def predict(self, input_ids, **kwargs): + del kwargs + return mx.array([0.5] * input_ids.shape[0], dtype=mx.float32) + + tokenizer = TinyTokenizer() + policy = ScriptedModel({2: 5, 3: 1, "default": 1}) + reference = ScriptedModel({2: 5, 3: 1, "default": 1}) + + rollout = collect_rollouts( + policy=policy, + tokenizer=tokenizer, + prompt_samples=[ + {"sample_index": 0, "prompt": "ab", "prompt_ids": [3, 4], "reward_context": "x"}, + {"sample_index": 1, "prompt": "cd", "prompt_ids": [5, 6], "reward_context": "y"}, + ], + sampling_config={"num_generations": 2, "temperature": 0.0, "max_completion_length": 2}, + ) + rollout.rewards = mx.array([3.0, 1.0, 4.0, 2.0], dtype=mx.float32) + + rollout = score_rollout_references(reference, rollout, batch_size=1) + rollout = predict_rollout_values(ConstantValueModel(), rollout, batch_size=2) + returns, advantages = compute_returns_and_advantages( + rollout.rewards, + prompt_group_indices=rollout.prompt_group_indices, + mode="rloo", + ) + rollout.returns = returns + rollout.advantages = advantages + rankings = rank_grouped_rollouts(rollout) + + assert rollout.reference_logprobs.shape == (4,) + assert rollout.value_predictions.tolist() == [0.5, 0.5, 0.5, 0.5] + assert returns.tolist() == [3.0, 1.0, 4.0, 2.0] + assert advantages.tolist() == [2.0, -2.0, 2.0, -2.0] + assert rankings[0]["best_position"] == 0 + assert rankings[0]["worst_position"] == 1 + + +def test_length_normalization_and_kl_helpers_match_expected_math(): + from mlx_tune._rl_runtime import kl_against_reference, normalize_logprobs + + summed = mx.array([4.0, 6.0], dtype=mx.float32) + lengths = mx.array([2, 3]) + normalized = normalize_logprobs(summed, lengths, mode="mean") + kl = kl_against_reference( + mx.array([0.0, 1.0], dtype=mx.float32), + mx.array([0.0, 0.5], dtype=mx.float32), + ) + + assert mx.allclose(normalized, mx.array([2.0, 2.0], dtype=mx.float32)) + assert mx.allclose( + kl, + mx.array([0.0, float((mx.exp(mx.array(0.5)) - 0.5 - 1.0).item())], dtype=mx.float32), + ) + + +def test_collect_rollouts_batched_decode_matches_per_sequence_sampler(): + from mlx_tune._rl_runtime import collect_rollouts, sample_completion + + tokenizer = TinyTokenizer() + model = ScriptedModel({2: 5, 3: 6, 4: 1, "default": 1}) + prompt_samples = [ + {"sample_index": 0, "prompt": "ab", "prompt_ids": [3, 4], "reward_context": "x"}, + {"sample_index": 1, "prompt": "cd", "prompt_ids": [5, 6], "reward_context": "y"}, + ] + + expected = [] + for sample in prompt_samples: + for _ in range(2): + expected.append( + sample_completion( + policy=model, + tokenizer=tokenizer, + prompt_ids=sample["prompt_ids"], + max_tokens=3, + temperature=0.0, + collect_sample_stats=True, + ) + ) + + rollout = collect_rollouts( + policy=model, + tokenizer=tokenizer, + prompt_samples=prompt_samples, + sampling_config={ + "num_generations": 2, + "temperature": 0.0, + "max_completion_length": 3, + "generation_batch_size": 2, + }, + collect_sample_stats=True, + ) + + assert rollout.completion_ids == [item["completion_ids"] for item in expected] + assert rollout.eos_flags.tolist() == [item["eos_flag"] for item in expected] + assert rollout.truncation_flags.tolist() == [item["truncation_flag"] for item in expected] + assert mx.allclose( + rollout.sampled_token_logprobs, + mx.array([item["sampled_logprobs"] for item in expected], dtype=mx.float32), + ) + assert mx.allclose( + rollout.sampled_token_logits, + mx.array([item["sampled_logits"] for item in expected], dtype=mx.float32), + ) + assert mx.allclose( + rollout.token_entropies, + mx.array([item["token_entropies"] for item in expected], dtype=mx.float32), + ) + + +def test_reward_adapter_prefers_evaluate_batch_when_available(): + from mlx_tune._rl_runtime import collect_rollouts, evaluate_rewards + + tokenizer = TinyTokenizer() + model = ScriptedModel({2: 5, 3: 1, "default": 1}) + rollout = collect_rollouts( + policy=model, + tokenizer=tokenizer, + prompt_samples=[ + {"sample_index": 0, "prompt": "ab", "prompt_ids": [3, 4], "reward_context": "ctx0"}, + {"sample_index": 1, "prompt": "cd", "prompt_ids": [5, 6], "reward_context": "ctx1"}, + ], + sampling_config={"num_generations": 1, "temperature": 0.0, "max_completion_length": 2}, + ) + + calls = [] + + class BatchEvaluator: + def evaluate_batch(self, payloads): + calls.append([payload["reward_context"] for payload in payloads]) + return [ + {"reward": float(payload["completion_length"]), "components": {"len": float(payload["completion_length"])}} + for payload in payloads + ] + + reward_batch = evaluate_rewards(rollout, BatchEvaluator()) + + assert calls == [["ctx0", "ctx1"]] + assert reward_batch.scalar_rewards.tolist() == [2.0, 2.0] + assert reward_batch.named_reward_components == [{"len": 2.0}, {"len": 2.0}] + + +def test_token_budget_chunking_matches_unchunked_policy_and_value_scoring(): + from mlx_tune._rl_runtime import ( + collect_rollouts, + make_policy_eval_batch, + predict_rollout_values, + score_policy_in_chunks, + ) + + class LengthValueModel: + def predict(self, input_ids, sequence_lengths, **kwargs): + del input_ids, kwargs + return sequence_lengths.astype(mx.float32) + + model = TinyModel() + mx.eval(model.parameters()) + batch = make_policy_eval_batch( + [[1, 2, 3, 4, 5], [1, 4, 5], [1, 6, 7, 8], [1, 9, 10, 11, 12, 13]], + pad_id=0, + mode="sequence", + ) + + full_scores = score_policy_in_chunks(model, batch, batch_size=8, mode="sequence") + chunked_scores = score_policy_in_chunks( + model, + batch, + batch_size=8, + token_budget=4, + mode="sequence", + ) + + tokenizer = TinyTokenizer() + rollout = collect_rollouts( + policy=ScriptedModel({2: 5, 3: 1, "default": 1}), + tokenizer=tokenizer, + prompt_samples=[ + {"sample_index": 0, "prompt": "ab", "prompt_ids": [3, 4], "reward_context": "x"}, + {"sample_index": 1, "prompt": "cde", "prompt_ids": [5, 6, 7], "reward_context": "y"}, + ], + sampling_config={"num_generations": 2, "temperature": 0.0, "max_completion_length": 2}, + ) + full_values = predict_rollout_values(LengthValueModel(), rollout, batch_size=8).value_predictions + chunked_values = predict_rollout_values( + LengthValueModel(), + rollout, + batch_size=8, + token_budget=4, + ).value_predictions + + assert mx.allclose(full_scores.summed_logprobs, chunked_scores.summed_logprobs) + assert mx.allclose(full_scores.token_logprobs, chunked_scores.token_logprobs) + assert mx.allclose(full_values, chunked_values) + + +def test_summarize_rollout_metrics_includes_completion_and_stop_stats(): + from mlx_tune._rl_runtime import collect_rollouts, summarize_rollout_metrics + + tokenizer = TinyTokenizer() + tokenizer._unsloth_stop_token = "<|im_end|>" + rollout = collect_rollouts( + policy=ScriptedModel({2: 9, 3: 7, 4: 7, "default": 7}), + tokenizer=tokenizer, + prompt_samples=[ + {"sample_index": 0, "prompt": "ab", "prompt_ids": [3, 4], "reward_context": "x"}, + {"sample_index": 1, "prompt": "cd", "prompt_ids": [5, 6], "reward_context": "y"}, + ], + sampling_config={"num_generations": 1, "temperature": 0.0, "max_completion_length": 4}, + ) + rollout.rewards = mx.array([1.0, 0.0], dtype=mx.float32) + + metrics = summarize_rollout_metrics(rollout, policy_loss=0.5) + + assert metrics["completion_length_mean"] == 1.0 + assert metrics["completion_length_max"] == 1.0 + assert metrics["eos_rate"] == 1.0 + assert metrics["truncation_rate"] == 0.0 + assert metrics["reward_mean"] == 0.5 + assert metrics["policy_loss"] == 0.5 + + +def test_summarize_rollout_metrics_normalizes_kl_by_completion_length(): + from mlx_tune._rl_runtime import RolloutBatch, summarize_rollout_metrics + + rollout = RolloutBatch( + prompt_ids=[[1, 2], [3, 4]], + completion_ids=[[5], [6]], + prompt_texts=["p1", "p2"], + original_prompt_texts=None, + completion_texts=["c1", "c2"], + reward_contexts=["x", "y"], + prompt_lengths=mx.array([2, 2], dtype=mx.int32), + completion_lengths=mx.array([10, 1000], dtype=mx.int32), + sampled_token_logprobs=mx.zeros((2, 1), dtype=mx.float32), + rollout_logprobs=mx.array([0.0, 0.0], dtype=mx.float32), + eos_flags=mx.array([True, True]), + truncation_flags=mx.array([False, False]), + prompt_group_indices=mx.array([0, 1], dtype=mx.int32), + policy_eval=None, + sample_indices=None, + sampled_token_logits=None, + token_entropies=None, + old_logprobs=None, + rewards=mx.array([1.0, 0.0], dtype=mx.float32), + reference_logprobs=mx.array([-10.0, -1000.0], dtype=mx.float32), + returns=None, + advantages=None, + ) + + metrics = summarize_rollout_metrics(rollout) + + assert abs(metrics["logprob_delta_per_token_mean"] - 1.0) < 1e-6 + assert abs(metrics["kl_to_reference_mean"] - (math.e - 2.0)) < 1e-5 diff --git a/tests/test_rl_trainers_integration.py b/tests/test_rl_trainers_integration.py index cfadc55..c3e18cd 100644 --- a/tests/test_rl_trainers_integration.py +++ b/tests/test_rl_trainers_integration.py @@ -1,153 +1,218 @@ """ -Integration tests for RL trainers (DPO, ORPO, GRPO, KTO, SimPO). - -These tests verify that the trainers actually run and produce valid results, -not just that they can be imported or configured. - -Tests marked with @pytest.mark.integration require more time/resources. +Integration tests for RL trainers. """ -import pytest +import json +from pathlib import Path + import mlx.core as mx import mlx.nn as nn -from typing import Dict, Any, List - - -# ============================================================================= -# TEST FIXTURES - Small models and datasets for fast testing -# ============================================================================= +import pytest +from mlx.utils import tree_flatten -class SmallLanguageModel(nn.Module): - """A tiny language model for testing - fast to train.""" - def __init__(self, vocab_size: int = 100, hidden_size: int = 64, num_layers: int = 2): +class SmallBackbone(nn.Module): + def __init__(self, vocab_size: int = 64, hidden_size: int = 32, num_layers: int = 2): super().__init__() - self.vocab_size = vocab_size - self.hidden_size = hidden_size self.embedding = nn.Embedding(vocab_size, hidden_size) self.layers = [nn.Linear(hidden_size, hidden_size) for _ in range(num_layers)] - self.output = nn.Linear(hidden_size, vocab_size) def __call__(self, x): h = self.embedding(x) for layer in self.layers: - h = mx.maximum(layer(h), 0) # ReLU - return self.output(h) + h = mx.maximum(layer(h), 0) + return h -class MockTokenizer: - """Mock tokenizer for testing.""" +class SmallLanguageModel(nn.Module): + def __init__(self, vocab_size: int = 64, hidden_size: int = 32, num_layers: int = 2): + super().__init__() + self.model = SmallBackbone(vocab_size=vocab_size, hidden_size=hidden_size, num_layers=num_layers) + self.output = nn.Linear(hidden_size, vocab_size) - def __init__(self, vocab_size: int = 100): + def __call__(self, x): + return self.output(self.model(x)) + + +class MockTokenizer: + def __init__(self, vocab_size: int = 64): self.vocab_size = vocab_size self.pad_token_id = 0 self.eos_token_id = 1 self.bos_token_id = 2 - self.pad_token = "" - self.eos_token = "" - self.name_or_path = "mock-tokenizer" - def encode(self, text: str, add_special_tokens: bool = True) -> List[int]: - """Simple encoding: hash characters to vocab indices.""" - ids = [hash(c) % (self.vocab_size - 3) + 3 for c in text[:50]] + def encode(self, text: str, add_special_tokens: bool = True): + ids = [((ord(char) % (self.vocab_size - 3)) + 3) for char in text[:32]] if add_special_tokens: ids = [self.bos_token_id] + ids + [self.eos_token_id] return ids - def decode(self, ids: List[int], skip_special_tokens: bool = True) -> str: - """Simple decoding.""" + def decode(self, ids, skip_special_tokens: bool = True): if skip_special_tokens: - ids = [i for i in ids if i not in (self.pad_token_id, self.eos_token_id, self.bos_token_id)] - return "".join(chr(65 + (i % 26)) for i in ids) - - def __call__(self, text, return_tensors=None, padding=True, truncation=True, max_length=512): - """Tokenize text.""" - if isinstance(text, str): - ids = self.encode(text) - else: - ids = [self.encode(t) for t in text] - - if return_tensors == "mlx": - return {"input_ids": mx.array(ids)} - return {"input_ids": ids} + ids = [token for token in ids if token not in (self.pad_token_id, self.eos_token_id, self.bos_token_id)] + return "".join(chr(65 + (token % 26)) for token in ids) class MockModelWrapper: - """Wrapper to match FastLanguageModel interface.""" - def __init__(self, model: SmallLanguageModel): self.model = model self._lora_applied = False - self._lora_config = None + self.lora_config = None + self._adapter_path = None def __call__(self, x): return self.model(x) def _apply_lora(self): - """Mock LoRA application.""" self._lora_applied = True - print(" [Mock] LoRA applied") - - def parameters(self): - return self.model.parameters() - - def freeze(self): - pass - - def unfreeze(self): - pass + return True + + def set_adapter_path(self, path: str): + self._adapter_path = path + + +def write_legacy_rl_checkpoint(trainer, checkpoint_dir: Path): + checkpoint_dir.mkdir(parents=True, exist_ok=True) + adapter_dir = checkpoint_dir / "adapters" + adapter_dir.mkdir(parents=True, exist_ok=True) + adapter_weights = dict(tree_flatten(trainer.model.model.trainable_parameters())) + mx.save_safetensors(str(adapter_dir / "adapters.safetensors"), adapter_weights) + + state_arrays = { + f"optimizer.{key}": value + for key, value in tree_flatten(trainer.optimizer.state) + } + state_arrays.update({f"rng.{idx}": state for idx, state in enumerate(mx.random.state)}) + if hasattr(trainer, "_extra_state_arrays"): + state_arrays.update(trainer._extra_state_arrays()) + mx.save_safetensors(str(checkpoint_dir / "trainer_state.safetensors"), state_arrays) + + metadata = { + "algorithm": trainer.algorithm, + "config": trainer.config.to_dict() if hasattr(trainer.config, "to_dict") else dict(trainer.config), + "global_step": trainer.global_step, + "dataset_cursor": trainer.dataset_cursor, + "cache_metadata": trainer.cache_metadata, + } + (checkpoint_dir / "trainer_state.json").write_text(json.dumps(metadata, indent=2)) + + if trainer.reference_policy is not None: + reference_weights = dict(tree_flatten(trainer.reference_policy.model.model.parameters())) + mx.save_safetensors(str(checkpoint_dir / "reference_model.safetensors"), reference_weights) + (checkpoint_dir / "reference_metadata.json").write_text( + json.dumps( + { + "source": trainer.reference_policy.source, + "metadata": trainer.reference_policy.metadata, + }, + indent=2, + ) + ) -@pytest.fixture -def small_model(): - """Create a small model for testing.""" - model = SmallLanguageModel(vocab_size=100, hidden_size=64, num_layers=2) +def make_model(seed: int) -> MockModelWrapper: + mx.random.seed(seed) + model = SmallLanguageModel() mx.eval(model.parameters()) return MockModelWrapper(model) +def parameter_snapshot(model_wrapper: MockModelWrapper): + return {name: mx.array(value) for name, value in tree_flatten(model_wrapper.model.parameters())} + + +def parameters_changed(before, after_wrapper: MockModelWrapper) -> bool: + after = {name: value for name, value in tree_flatten(after_wrapper.model.parameters())} + for name, before_value in before.items(): + delta = mx.max(mx.abs(after[name] - before_value)).item() + if delta > 1e-6: + return True + return False + + +def rollout_sequence_tensors(rollout_batch, index: int): + sequence = rollout_batch.prompt_ids[index] + rollout_batch.completion_ids[index] + return ( + mx.array([sequence]), + mx.array([len(rollout_batch.prompt_ids[index])]), + mx.array([int(rollout_batch.completion_lengths[index].item())]), + ) + + +def test_format_metric_summary_includes_rollout_length_and_stop_stats(): + from mlx_tune import GRPOTrainer + + trainer = GRPOTrainer.__new__(GRPOTrainer) + row = { + "step": 3, + "train/policy_loss": 0.25, + "train/reward_mean": 1.0, + "train/logprob_delta_per_token_mean": 0.0125, + "train/logprob_delta_mean": 0.75, + "train/completion_length_mean": 12.0, + "train/completion_length_max": 32.0, + "train/eos_rate": 0.75, + "train/truncation_rate": 0.25, + "train/kl_to_reference_mean": 0.1, + "train/rollout_generate_wall": 4.5, + "train/reward_eval_wall": 0.2, + "train/reference_score_wall": 11.0, + "train/returns_wall": 0.01, + "train/policy_update_wall": 24.0, + "train/policy_update_steps": 3.0, + } + + summary = trainer._format_metric_summary(row) + + assert summary == ( + "step=3 | policy_loss=0.2500 | reward_mean=1.0000 | " + "logprob_delta_per_token_mean=0.0125 | logprob_delta_mean=0.7500 | " + "completion_length_mean=12.0000 | completion_length_max=32.0000 | " + "eos_rate=0.7500 | truncation_rate=0.2500 | kl_to_reference_mean=0.1000 | " + "rollout_generate_wall=4.5000 | reward_eval_wall=0.2000 | " + "reference_score_wall=11.0000 | returns_wall=0.0100 | " + "policy_update_wall=24.0000 | policy_update_steps=3.0000" + ) + + @pytest.fixture -def mock_tokenizer(): - """Create a mock tokenizer.""" - return MockTokenizer(vocab_size=100) +def tokenizer(): + return MockTokenizer() @pytest.fixture def preference_dataset(): - """Sample preference dataset for DPO/ORPO/SimPO.""" return [ { "prompt": "What is machine learning?", - "chosen": "Machine learning is a branch of AI that enables systems to learn from data.", - "rejected": "idk its computers doing stuff" + "chosen": "Machine learning is learning from data.", + "rejected": "Computers do stuff.", }, { "prompt": "Explain Python.", - "chosen": "Python is a high-level programming language known for readability.", - "rejected": "python is a snake" + "chosen": "Python is a high-level programming language.", + "rejected": "Python is only a snake.", }, { "prompt": "What is deep learning?", - "chosen": "Deep learning uses neural networks with many layers to learn patterns.", - "rejected": "its like machine learning but deeper i guess" + "chosen": "Deep learning uses many neural-network layers.", + "rejected": "It is just regular learning.", }, ] @pytest.fixture def kto_dataset(): - """Sample dataset for KTO (binary feedback).""" return [ - {"text": "Machine learning is a branch of AI.", "label": 1}, # Good - {"text": "idk computers stuff", "label": 0}, # Bad - {"text": "Python is a programming language.", "label": 1}, # Good - {"text": "python snake", "label": 0}, # Bad + {"text": "Machine learning uses data.", "label": 1}, + {"text": "Computers maybe stuff.", "label": 0}, + {"text": "Python is a programming language.", "label": 1}, + {"text": "Snake only.", "label": 0}, ] @pytest.fixture def grpo_dataset(): - """Sample dataset for GRPO (reasoning with answers).""" return [ {"prompt": "What is 2 + 2?", "answer": "4"}, {"prompt": "What is 5 * 3?", "answer": "15"}, @@ -155,681 +220,1562 @@ def grpo_dataset(): ] -# ============================================================================= -# DPO TRAINER INTEGRATION TESTS -# ============================================================================= - @pytest.mark.integration -class TestDPOTrainerIntegration: - """Integration tests for DPOTrainer.""" - - def test_dpo_trainer_init(self, small_model, mock_tokenizer, preference_dataset): - """Test DPOTrainer can be initialized.""" - from mlx_tune import DPOTrainer, DPOConfig - - config = DPOConfig( - beta=0.1, - learning_rate=1e-4, - max_steps=2, - output_dir="./test_dpo_output", +class TestRewardAndPPOIntegration: + def test_reward_trainer_and_scoring_helper_save_manifest_checkpoint( + self, + tmp_path, + tokenizer, + ): + from mlx_tune import RewardConfig, RewardTrainer, build_reward_model, score_reward_model + + reward_model = build_reward_model(make_model(50)) + trainer = RewardTrainer( + model=reward_model, + train_dataset=[ + {"prompt": "Q:", "response": "good", "score": 1.0}, + {"prompt": "Q:", "response": "bad", "score": 0.0}, + ], + tokenizer=tokenizer, + args=RewardConfig( + learning_rate=1e-2, + max_steps=1, + logging_steps=1, + save_steps=1, + output_dir=str(tmp_path / "reward"), + ), ) - trainer = DPOTrainer( - model=small_model, - train_dataset=preference_dataset, - tokenizer=mock_tokenizer, - args=config, + result = trainer.train() + scores = score_reward_model( + reward_model, + [{"prompt": "Q:", "response": "good"}], + batch_size=1, + tokenizer=tokenizer, ) - assert trainer is not None - assert trainer.beta == 0.1 - assert len(trainer.train_dataset) == 3 + assert result["status"] == "success" + assert len(scores) == 1 + assert (tmp_path / "reward" / "manifest.json").exists() + assert (tmp_path / "reward" / "reward_model" / "head.safetensors").exists() + + def test_ppo_trainer_persists_policy_and_value_optimizers( + self, + tmp_path, + tokenizer, + grpo_dataset, + ): + from mlx_tune import PPOConfig, PPOTrainer, build_value_model + + trainer = PPOTrainer( + model=make_model(51), + train_dataset=grpo_dataset, + tokenizer=tokenizer, + reward_fn=lambda response, context: float(len(response)), + value_model=build_value_model(make_model(52)), + args=PPOConfig( + learning_rate=1e-2, + value_learning_rate=2e-2, + max_steps=1, + logging_steps=1, + save_steps=1, + num_generations=2, + max_completion_length=4, + output_dir=str(tmp_path / "ppo"), + ), + ) - def test_dpo_trainer_train_runs(self, small_model, mock_tokenizer, preference_dataset): - """Test DPOTrainer.train() executes without errors.""" - from mlx_tune import DPOTrainer, DPOConfig + result = trainer.train() - config = DPOConfig( - beta=0.1, - learning_rate=1e-4, - max_steps=2, # Very short for testing - output_dir="./test_dpo_output", + assert result["status"] == "success" + assert trainer._last_rollout_batch is not None + assert trainer._last_rollout_batch.value_predictions is not None + assert trainer._last_rollout_batch.returns is not None + assert (tmp_path / "ppo" / "optimizers" / "policy" / "state.safetensors").exists() + assert (tmp_path / "ppo" / "optimizers" / "value" / "state.safetensors").exists() + + def test_ppo_trainer_uses_requested_advantage_estimator( + self, + tmp_path, + tokenizer, + grpo_dataset, + ): + from mlx_tune import PPOConfig, PPOTrainer, build_value_model + from mlx_tune._rl_runtime import compute_returns_and_advantages + + trainer = PPOTrainer( + model=make_model(60), + ref_model=make_model(61), + train_dataset=grpo_dataset, + tokenizer=tokenizer, + reward_fn=lambda response, context: float(len(response)), + value_model=build_value_model(make_model(62)), + args=PPOConfig( + advantage_estimator="rloo", + normalize_advantages=False, + temperature=0.0, + max_steps=1, + logging_steps=1, + save_steps=1, + num_generations=2, + max_completion_length=4, + output_dir=str(tmp_path / "ppo_rloo"), + ), ) - trainer = DPOTrainer( - model=small_model, - train_dataset=preference_dataset, - tokenizer=mock_tokenizer, - args=config, + trainer._ensure_reference_policy() + trainer._prepare_prompt_samples() + rollout_batch = trainer._collect_rollout_batch(trainer._next_prompt_batch()) + + _, expected_advantages = compute_returns_and_advantages( + rewards=rollout_batch.rewards, + prompt_group_indices=rollout_batch.prompt_group_indices, + mode="rloo", + gamma=trainer.gamma, + gae_lambda=trainer.gae_lambda, + normalize=False, ) - # This should run without raising exceptions - result = trainer.train() - - assert result is not None - # Verify training completed - check model's _lora_applied flag - assert small_model._lora_applied, "LoRA should have been applied during training" - - def test_dpo_loss_decreases(self, small_model, mock_tokenizer, preference_dataset): - """Test that DPO loss decreases or stays stable during training.""" - from mlx_tune import DPOTrainer, DPOConfig + assert mx.allclose(rollout_batch.advantages, expected_advantages) + + def test_on_policy_trainers_activate_kl_controls( + self, + tmp_path, + tokenizer, + grpo_dataset, + ): + from mlx_tune import ( + GRPOConfig, + GRPOTrainer, + OnlineDPOConfig, + OnlineDPOTrainer, + PPOConfig, + PPOTrainer, + build_value_model, + ) - config = DPOConfig( - beta=0.1, - learning_rate=1e-3, # Higher LR for visible change - max_steps=5, - output_dir="./test_dpo_output", + grpo = GRPOTrainer( + model=make_model(70), + ref_model=make_model(71), + train_dataset=grpo_dataset, + tokenizer=tokenizer, + reward_fn=lambda response, context: float(len(response)), + args=GRPOConfig( + beta=0.3, + kl_target=1e-4, + max_steps=1, + logging_steps=1, + save_steps=1, + num_generations=2, + max_completion_length=3, + output_dir=str(tmp_path / "grpo_kl"), + ), + ) + grpo._ensure_reference_policy() + grpo._prepare_prompt_samples() + grpo_rollout = grpo._collect_rollout_batch(grpo._next_prompt_batch()) + assert grpo._effective_kl_beta(grpo_rollout) != grpo.beta + + ppo = PPOTrainer( + model=make_model(72), + ref_model=make_model(73), + train_dataset=grpo_dataset, + tokenizer=tokenizer, + reward_fn=lambda response, context: float(len(response)), + value_model=build_value_model(make_model(74)), + args=PPOConfig( + beta=0.4, + kl_penalty_mode="none", + max_steps=1, + logging_steps=1, + save_steps=1, + num_generations=2, + max_completion_length=3, + output_dir=str(tmp_path / "ppo_kl_none"), + ), + ) + ppo._ensure_reference_policy() + ppo._prepare_prompt_samples() + ppo_rollout = ppo._collect_rollout_batch(ppo._next_prompt_batch()) + assert ppo._effective_kl_beta(ppo_rollout) == 0.0 + + online_dpo = OnlineDPOTrainer( + model=make_model(75), + ref_model=make_model(76), + train_dataset=grpo_dataset, + tokenizer=tokenizer, + reward_fn=lambda response, context: float(len(response)), + args=OnlineDPOConfig( + beta=0.2, + kl_target=1e-4, + max_steps=1, + logging_steps=1, + save_steps=1, + num_generations=2, + max_completion_length=3, + output_dir=str(tmp_path / "online_dpo_kl"), + ), ) + online_dpo._ensure_reference_policy() + online_dpo._prepare_prompt_samples() + online_rollout = online_dpo._collect_rollout_batch(online_dpo._next_prompt_batch()) + assert online_dpo._effective_kl_beta(online_rollout) != online_dpo.beta + +@pytest.mark.integration +class TestDPOTrainerIntegration: + def test_dpo_frozen_reference_stays_unchanged_across_policy_updates( + self, + tmp_path, + tokenizer, + preference_dataset, + ): + from mlx_tune import DPOConfig, DPOTrainer, compute_reference_logprobs + + model = make_model(0) trainer = DPOTrainer( - model=small_model, + model=model, train_dataset=preference_dataset, - tokenizer=mock_tokenizer, - args=config, + tokenizer=tokenizer, + args=DPOConfig( + learning_rate=5e-2, + max_steps=3, + logging_steps=1, + save_steps=1, + output_dir=str(tmp_path), + ), ) result = trainer.train() + assert result["global_step"] == 3 + assert trainer.reference_policy is not None + + pad_id = tokenizer.pad_token_id + sample = trainer.train_samples[0] + chosen = mx.array([sample["chosen_ids"]]) + rejected = mx.array([sample["rejected_ids"]]) + chosen_lengths = mx.array([sample["chosen_length"]]) + rejected_lengths = mx.array([sample["rejected_length"]]) + + ref_chosen, ref_rejected = compute_reference_logprobs( + trainer.reference_policy.model.model, + chosen, + rejected, + chosen_lengths, + rejected_lengths, + ) - # Check no NaN in result - if isinstance(result, dict) and 'final_loss' in result: - assert not mx.isnan(mx.array(result['final_loss'])), "Final loss should not be NaN" - - def test_dpo_different_loss_types(self, small_model, mock_tokenizer, preference_dataset): - """Test DPO with different loss types.""" - from mlx_tune import DPOTrainer, DPOConfig - - loss_types = ["sigmoid", "hinge", "ipo"] - - for loss_type in loss_types: - # Create fresh model for each test - model = MockModelWrapper(SmallLanguageModel()) - mx.eval(model.model.parameters()) - - config = DPOConfig( - beta=0.1, - loss_type=loss_type, - learning_rate=1e-4, - max_steps=2, - output_dir="./test_dpo_output", - ) - - trainer = DPOTrainer( - model=model, - train_dataset=preference_dataset, - tokenizer=mock_tokenizer, - args=config, - ) - - # Should not raise - result = trainer.train() - assert result is not None, f"DPO with loss_type={loss_type} failed" - - -# ============================================================================= -# ORPO TRAINER INTEGRATION TESTS -# ============================================================================= - -@pytest.mark.integration -class TestORPOTrainerIntegration: - """Integration tests for ORPOTrainer.""" - - def test_orpo_trainer_init(self, small_model, mock_tokenizer, preference_dataset): - """Test ORPOTrainer can be initialized.""" - from mlx_tune import ORPOTrainer, ORPOConfig + assert mx.allclose(ref_chosen, mx.array([sample["reference_chosen_logprobs"]])) + assert mx.allclose(ref_rejected, mx.array([sample["reference_rejected_logprobs"]])) + assert model._lora_applied + + def test_dpo_loss_changes_when_reference_cache_changes( + self, + tokenizer, + preference_dataset, + ): + from mlx_tune import dpo_loss, precompute_preference_reference_logprobs + + policy = make_model(1) + reference_a = make_model(2) + reference_b = make_model(3) + + prompt = preference_dataset[0]["prompt"] + chosen = tokenizer.encode(prompt + preference_dataset[0]["chosen"]) + rejected = tokenizer.encode(prompt + preference_dataset[0]["rejected"]) + + chosen_ids = mx.array([chosen]) + rejected_ids = mx.array([rejected]) + chosen_lengths = mx.array([len(chosen)]) + rejected_lengths = mx.array([len(rejected)]) + + ref_a = precompute_preference_reference_logprobs( + reference_a.model, + chosen_ids, + rejected_ids, + chosen_lengths, + rejected_lengths, + ) + ref_b = precompute_preference_reference_logprobs( + reference_b.model, + chosen_ids, + rejected_ids, + chosen_lengths, + rejected_lengths, + ) - config = ORPOConfig( + loss_a, _ = dpo_loss( + policy.model, + chosen_ids, + rejected_ids, + chosen_lengths, + rejected_lengths, beta=0.1, - learning_rate=1e-4, - max_steps=2, - output_dir="./test_orpo_output", + reference_chosen_logprobs=ref_a[0], + reference_rejected_logprobs=ref_a[1], ) - - trainer = ORPOTrainer( - model=small_model, - train_dataset=preference_dataset, - tokenizer=mock_tokenizer, - args=config, + loss_b, _ = dpo_loss( + policy.model, + chosen_ids, + rejected_ids, + chosen_lengths, + rejected_lengths, + beta=0.1, + reference_chosen_logprobs=ref_b[0], + reference_rejected_logprobs=ref_b[1], ) - assert trainer is not None - assert trainer.beta == 0.1 + assert abs(loss_a.item() - loss_b.item()) > 1e-6 - def test_orpo_trainer_train_runs(self, small_model, mock_tokenizer, preference_dataset): - """Test ORPOTrainer.train() executes without errors.""" - from mlx_tune import ORPOTrainer, ORPOConfig + def test_dpo_resume_restores_state_and_cache( + self, + tmp_path, + tokenizer, + preference_dataset, + ): + from mlx_tune import DPOConfig, DPOTrainer - config = ORPOConfig( - beta=0.1, - learning_rate=1e-4, - max_steps=2, - output_dir="./test_orpo_output", + output_dir = tmp_path / "dpo_resume" + trainer = DPOTrainer( + model=make_model(4), + train_dataset=preference_dataset, + tokenizer=tokenizer, + args=DPOConfig( + learning_rate=1e-2, + max_steps=2, + logging_steps=1, + save_steps=1, + output_dir=str(output_dir), + ), ) + trainer.train() - trainer = ORPOTrainer( - model=small_model, + resumed = DPOTrainer( + model=make_model(5), train_dataset=preference_dataset, - tokenizer=mock_tokenizer, - args=config, + tokenizer=tokenizer, + args=DPOConfig( + learning_rate=1e-2, + max_steps=4, + logging_steps=1, + save_steps=1, + output_dir=str(output_dir), + ), + ) + result = resumed.train(resume_from_checkpoint=str(output_dir)) + + assert result["global_step"] == 4 + assert resumed.cache_metadata == trainer.cache_metadata + assert resumed.optimizer is not None + assert resumed.optimizer.state["step"].item() == 4 + assert mx.allclose( + mx.array([sample["reference_chosen_logprobs"] for sample in resumed.train_samples]), + mx.array([sample["reference_chosen_logprobs"] for sample in trainer.train_samples]), ) - result = trainer.train() - assert result is not None - - def test_orpo_combines_sft_and_preference(self, small_model, mock_tokenizer, preference_dataset): - """Test that ORPO combines SFT and preference learning.""" - from mlx_tune import ORPOTrainer, ORPOConfig - config = ORPOConfig( - beta=0.1, - learning_rate=1e-3, - max_steps=3, - output_dir="./test_orpo_output", - ) +@pytest.mark.integration +class TestGRPOTrainerIntegration: + def test_grpo_training_changes_parameters_and_increases_rewarded_logprob( + self, + tmp_path, + tokenizer, + grpo_dataset, + ): + from mlx_tune import GRPOConfig, GRPOTrainer, compute_completion_log_probs + + mx.random.seed(7) + model = make_model(6) + before = parameter_snapshot(model) - trainer = ORPOTrainer( - model=small_model, - train_dataset=preference_dataset, - tokenizer=mock_tokenizer, - args=config, + trainer = GRPOTrainer( + model=model, + train_dataset=grpo_dataset, + tokenizer=tokenizer, + reward_fn=lambda response, context: float(len(response)), + args=GRPOConfig( + learning_rate=5e-2, + beta=0.01, + num_generations=3, + max_completion_length=4, + max_steps=1, + logging_steps=1, + save_steps=1, + output_dir=str(tmp_path), + ), ) result = trainer.train() + rollout = trainer._last_rollout_batch - # ORPO should complete successfully - assert result is not None + assert result["global_step"] == 1 + assert rollout is not None + assert parameters_changed(before, model) + best_index = int(mx.argmax(rollout.rewards).item()) + input_ids, prompt_lengths, completion_lengths = rollout_sequence_tensors(rollout, best_index) + updated_logprob = compute_completion_log_probs( + model.model, + input_ids, + prompt_lengths, + completion_lengths, + )[0].item() + rollout_logprob = rollout.rollout_logprobs[best_index].item() -# ============================================================================= -# GRPO TRAINER INTEGRATION TESTS (Most important for reasoning) -# ============================================================================= + assert updated_logprob > rollout_logprob -@pytest.mark.integration -class TestGRPOTrainerIntegration: - """Integration tests for GRPOTrainer - DeepSeek R1 style reasoning.""" + def test_grpo_prefers_answer_context_over_prompt( + self, + tmp_path, + tokenizer, + ): + from mlx_tune import GRPOConfig, GRPOTrainer - def test_grpo_trainer_init(self, small_model, mock_tokenizer, grpo_dataset): - """Test GRPOTrainer can be initialized.""" - from mlx_tune import GRPOTrainer, GRPOConfig, create_reward_function + seen_contexts = [] - reward_fn = create_reward_function("simple") - - config = GRPOConfig( - beta=0.04, - num_generations=2, # Small for testing - learning_rate=1e-5, - max_steps=2, - output_dir="./test_grpo_output", - ) + def reward_fn(response: str, context: str) -> float: + seen_contexts.append(context) + return float(len(response)) trainer = GRPOTrainer( - model=small_model, - train_dataset=grpo_dataset, - tokenizer=mock_tokenizer, + model=make_model(7), + train_dataset=[{"prompt": "Solve 2 + 2", "answer": "4"}], + tokenizer=tokenizer, reward_fn=reward_fn, - args=config, + args=GRPOConfig( + learning_rate=1e-2, + num_generations=2, + max_completion_length=3, + max_steps=1, + logging_steps=1, + save_steps=1, + output_dir=str(tmp_path), + ), ) + trainer.train() - assert trainer is not None - assert trainer.num_generations == 2 - assert trainer.reward_fn is not None + assert seen_contexts + assert all(context == "4" for context in seen_contexts) - def test_grpo_trainer_train_runs(self, small_model, mock_tokenizer, grpo_dataset): - """Test GRPOTrainer.train() executes without errors.""" - from mlx_tune import GRPOTrainer, GRPOConfig, create_reward_function - - reward_fn = create_reward_function("simple") + def test_grpo_phase1_accepts_documented_loss_type_aliases( + self, + tmp_path, + tokenizer, + grpo_dataset, + ): + from mlx_tune import GRPOConfig, GRPOTrainer + for loss_type in ["grpo", "dr_grpo", "dapo", "bnpo"]: + trainer = GRPOTrainer( + model=make_model(20), + train_dataset=grpo_dataset, + tokenizer=tokenizer, + reward_fn=lambda response, context: float(len(response)), + args=GRPOConfig( + loss_type=loss_type, + learning_rate=1e-2, + beta=0.01, + num_generations=2, + max_completion_length=3, + max_steps=1, + logging_steps=1, + save_steps=1, + output_dir=str(tmp_path / loss_type), + ), + ) + result = trainer.train() + assert result["status"] == "success" + assert trainer.phase1_loss_type == "phase1_shared_rollout_recompute" + + def test_grpo_resume_restores_rng_and_optimizer_state( + self, + tmp_path, + tokenizer, + grpo_dataset, + ): + from mlx_tune import GRPOConfig, GRPOTrainer + + output_dir = tmp_path / "grpo_resume" config = GRPOConfig( - beta=0.04, + learning_rate=1e-2, + beta=0.01, num_generations=2, - learning_rate=1e-5, - max_steps=2, - temperature=0.7, - output_dir="./test_grpo_output", + max_completion_length=4, + max_steps=1, + logging_steps=1, + save_steps=1, + output_dir=str(output_dir), ) + mx.random.seed(11) trainer = GRPOTrainer( - model=small_model, + model=make_model(8), train_dataset=grpo_dataset, - tokenizer=mock_tokenizer, - reward_fn=reward_fn, + tokenizer=tokenizer, + reward_fn=lambda response, context: float(len(response)), args=config, ) + trainer.train() - result = trainer.train() - assert result is not None - - def test_grpo_multi_generation(self, small_model, mock_tokenizer, grpo_dataset): - """Test that GRPO generates multiple completions per prompt.""" - from mlx_tune import GRPOTrainer, GRPOConfig, create_reward_function - - reward_fn = create_reward_function("simple") - num_gens = 3 - - config = GRPOConfig( - beta=0.04, - num_generations=num_gens, - learning_rate=1e-5, - max_steps=1, # Just one step to verify multi-gen - output_dir="./test_grpo_output", + def load_and_rollout(seed: int): + restored = GRPOTrainer( + model=make_model(seed), + train_dataset=grpo_dataset, + tokenizer=tokenizer, + reward_fn=lambda response, context: float(len(response)), + args=GRPOConfig( + learning_rate=1e-2, + beta=0.01, + num_generations=2, + max_completion_length=4, + max_steps=2, + logging_steps=1, + save_steps=1, + output_dir=str(output_dir), + ), + ) + restored._apply_lora_if_needed() + restored._prepare_prompt_samples() + optimizer = restored._optimizer_for_training() + restored.optimizer = optimizer + restored.load_state(optimizer, Path(output_dir)) + rollout = restored._collect_rollout_batch(restored._next_samples(restored.prompt_samples)) + return restored, rollout + + restored_a, rollout_a = load_and_rollout(9) + restored_b, rollout_b = load_and_rollout(9) + + assert restored_a.optimizer.state["step"].item() == 1 + assert rollout_a.completion_ids == rollout_b.completion_ids + assert mx.allclose(rollout_a.rollout_logprobs, rollout_b.rollout_logprobs) + + resumed = GRPOTrainer( + model=make_model(12), + train_dataset=grpo_dataset, + tokenizer=tokenizer, + reward_fn=lambda response, context: float(len(response)), + args=GRPOConfig( + learning_rate=1e-2, + beta=0.01, + num_generations=2, + max_completion_length=4, + max_steps=2, + logging_steps=1, + save_steps=1, + output_dir=str(output_dir), + ), + ) + result = resumed.train(resume_from_checkpoint=str(output_dir)) + assert result["global_step"] == 2 + + def test_grpo_prefers_learned_reward_model_over_reward_fn( + self, + tmp_path, + tokenizer, + grpo_dataset, + ): + from mlx_tune import GRPOConfig, GRPOTrainer, build_reward_model + + reward_model = build_reward_model(make_model(30)) + reward_model.head.update( + { + "weight": mx.zeros_like(reward_model.head.weight), + "bias": mx.array([1.0], dtype=mx.float32), + }, + strict=False, ) + mx.eval(reward_model.head.parameters()) trainer = GRPOTrainer( - model=small_model, + model=make_model(31), train_dataset=grpo_dataset, - tokenizer=mock_tokenizer, - reward_fn=reward_fn, - args=config, + tokenizer=tokenizer, + reward_model=reward_model, + reward_fn=lambda response, context: (_ for _ in ()).throw(RuntimeError("reward_fn should not run")), + args=GRPOConfig( + learning_rate=1e-2, + beta=0.01, + num_generations=2, + max_completion_length=4, + max_steps=1, + logging_steps=1, + save_steps=1, + output_dir=str(tmp_path), + ), ) - # The trainer should be configured for multi-generation - assert trainer.num_generations == num_gens - - # Training should work result = trainer.train() - assert result is not None - - def test_grpo_with_math_reward(self, small_model, mock_tokenizer, grpo_dataset): - """Test GRPO with math reward function.""" - from mlx_tune import GRPOTrainer, GRPOConfig, create_reward_function - - math_reward = create_reward_function("math") - - config = GRPOConfig( - beta=0.04, - num_generations=2, - learning_rate=1e-5, - max_steps=2, - output_dir="./test_grpo_output", + assert result["status"] == "success" + assert trainer._last_rollout_batch is not None + assert mx.allclose( + trainer._last_rollout_batch.rewards, + mx.ones_like(trainer._last_rollout_batch.rewards), ) + def test_grpo_offline_reward_source_uses_dataset_rewards_without_reward_evaluator( + self, + tmp_path, + tokenizer, + ): + from mlx_tune import GRPOConfig, GRPOTrainer + trainer = GRPOTrainer( - model=small_model, - train_dataset=grpo_dataset, - tokenizer=mock_tokenizer, - reward_fn=math_reward, - args=config, + model=make_model(63), + train_dataset=[ + {"prompt": "Solve 2 + 2", "completion": "4", "reward": 1.0}, + {"prompt": "Solve 2 + 2", "completion": "5", "reward": -1.0}, + ], + tokenizer=tokenizer, + reward_fn=lambda *_: (_ for _ in ()).throw(RuntimeError("reward_fn should not run")), + args=GRPOConfig( + learning_rate=1e-2, + beta=0.01, + num_generations=2, + reward_source="offline", + max_steps=1, + logging_steps=1, + save_steps=1, + output_dir=str(tmp_path / "grpo_offline"), + ), ) result = trainer.train() - assert result is not None - def test_grpo_different_loss_types(self, small_model, mock_tokenizer, grpo_dataset): - """Test GRPO with different loss types (grpo, dr_grpo, dapo, bnpo).""" - from mlx_tune import GRPOTrainer, GRPOConfig, create_reward_function + assert result["status"] == "success" + assert trainer._last_rollout_batch is not None + assert trainer._last_rollout_batch.rewards.tolist() == [1.0, -1.0] - reward_fn = create_reward_function("simple") - loss_types = ["grpo", "dr_grpo", "dapo", "bnpo"] + def test_grpo_single_reward_source_component_preserves_weight_semantics( + self, + tmp_path, + tokenizer, + ): + from mlx_tune import GRPOConfig, GRPOTrainer - for loss_type in loss_types: - model = MockModelWrapper(SmallLanguageModel()) - mx.eval(model.model.parameters()) - - config = GRPOConfig( - loss_type=loss_type, - beta=0.04, + trainer = GRPOTrainer( + model=make_model(64), + train_dataset=[{"prompt": "Solve 2 + 2", "answer": "4"}], + tokenizer=tokenizer, + args=GRPOConfig( + learning_rate=1e-2, + beta=0.01, num_generations=2, - learning_rate=1e-5, + max_completion_length=3, max_steps=1, - output_dir="./test_grpo_output", - ) + logging_steps=1, + save_steps=1, + reward_sources=[ + {"name": "zeroed", "source": "length", "weight": 0.0}, + ], + output_dir=str(tmp_path / "grpo_weighted_single_reward"), + ), + ) - trainer = GRPOTrainer( - model=model, - train_dataset=grpo_dataset, - tokenizer=mock_tokenizer, - reward_fn=reward_fn, - args=config, - ) + result = trainer.train() - result = trainer.train() - assert result is not None, f"GRPO with loss_type={loss_type} failed" + assert result["status"] == "success" + assert trainer._last_rollout_batch is not None + assert mx.allclose( + trainer._last_rollout_batch.rewards, + mx.zeros_like(trainer._last_rollout_batch.rewards), + ) - def test_grpo_custom_reward_function(self, small_model, mock_tokenizer, grpo_dataset): - """Test GRPO with a custom reward function.""" - from mlx_tune import GRPOTrainer, GRPOConfig + def test_grpo_prompt_rollouts_respect_rollout_batch_size( + self, + tmp_path, + tokenizer, + ): + from mlx_tune import GRPOConfig, GRPOTrainer - # Custom reward: reward longer responses - def length_reward(response: str, answer: str = None) -> float: - return len(response) / 100.0 # Normalize + trainer = GRPOTrainer( + model=make_model(65), + train_dataset=[ + {"prompt": "What is 1 + 1?", "answer": "2"}, + {"prompt": "What is 2 + 2?", "answer": "4"}, + {"prompt": "What is 3 + 3?", "answer": "6"}, + ], + tokenizer=tokenizer, + reward_fn=lambda response, context: float(len(response)), + args=GRPOConfig( + learning_rate=1e-2, + beta=0.01, + per_device_train_batch_size=1, + rollout_batch_size=3, + num_generations=2, + max_completion_length=3, + max_steps=1, + logging_steps=1, + save_steps=1, + output_dir=str(tmp_path / "grpo_rollout_batch_size"), + ), + ) - config = GRPOConfig( - beta=0.04, - num_generations=2, - learning_rate=1e-5, - max_steps=2, - output_dir="./test_grpo_output", + result = trainer.train() + + assert result["status"] == "success" + assert trainer._last_rollout_batch is not None + assert len(trainer._last_rollout_batch.prompt_ids) == 6 + assert sorted(set(trainer._last_rollout_batch.prompt_group_indices.tolist())) == [0, 1, 2] + + def test_grpo_manifest_checkpoint_persists_roles_and_metrics( + self, + tmp_path, + tokenizer, + grpo_dataset, + ): + from mlx_tune import GRPOConfig, GRPOTrainer, build_reward_model, build_value_model + + output_dir = tmp_path / "grpo_manifest" + reward_model = build_reward_model(make_model(33)) + value_model = build_value_model(make_model(34)) + score_sequence = mx.array([[2, 4, 7]], dtype=mx.int32) + score_sequence_lengths = mx.array([3], dtype=mx.int32) + score_prompt_lengths = mx.array([1], dtype=mx.int32) + score_completion_lengths = mx.array([2], dtype=mx.int32) + saved_reward_score = reward_model.score( + score_sequence, + sequence_lengths=score_sequence_lengths, + prompt_lengths=score_prompt_lengths, + completion_lengths=score_completion_lengths, + ) + saved_value_score = value_model.predict( + score_sequence, + sequence_lengths=score_sequence_lengths, + prompt_lengths=score_prompt_lengths, + completion_lengths=score_completion_lengths, ) trainer = GRPOTrainer( - model=small_model, + model=make_model(32), train_dataset=grpo_dataset, - tokenizer=mock_tokenizer, - reward_fn=length_reward, - args=config, + tokenizer=tokenizer, + reward_model=reward_model, + value_model=value_model, + args=GRPOConfig( + learning_rate=1e-2, + beta=0.01, + num_generations=2, + max_completion_length=4, + max_steps=1, + logging_steps=1, + save_steps=1, + output_dir=str(output_dir), + ), ) + trainer.train() + + assert (output_dir / "manifest.json").exists() + assert (output_dir / "policy" / "role.json").exists() + assert (output_dir / "reference" / "weights.safetensors").exists() + assert (output_dir / "reward_model" / "weights.safetensors").exists() + assert (output_dir / "reward_model" / "head.safetensors").exists() + assert (output_dir / "value_model" / "weights.safetensors").exists() + assert (output_dir / "value_model" / "head.safetensors").exists() + assert (output_dir / "optimizer" / "state.safetensors").exists() + assert (output_dir / "scheduler" / "state.json").exists() + assert (output_dir / "trainer" / "state.json").exists() + assert (output_dir / "trainer" / "rng.safetensors").exists() + assert (output_dir / "metrics" / "history.jsonl").exists() + assert trainer.metrics_history + + resumed = GRPOTrainer( + model=make_model(35), + train_dataset=grpo_dataset, + tokenizer=tokenizer, + args=GRPOConfig( + learning_rate=1e-2, + beta=0.01, + num_generations=2, + max_completion_length=4, + max_steps=2, + logging_steps=1, + save_steps=1, + output_dir=str(output_dir), + ), + ) + result = resumed.train(resume_from_checkpoint=str(output_dir)) + + assert result["global_step"] == 2 + assert resumed.reward_model is not None + assert resumed.value_model is not None + assert mx.allclose( + resumed.reward_model.score( + score_sequence, + sequence_lengths=score_sequence_lengths, + prompt_lengths=score_prompt_lengths, + completion_lengths=score_completion_lengths, + ), + saved_reward_score, + ) + assert mx.allclose( + resumed.value_model.predict( + score_sequence, + sequence_lengths=score_sequence_lengths, + prompt_lengths=score_prompt_lengths, + completion_lengths=score_completion_lengths, + ), + saved_value_score, + ) + assert len(resumed.metrics_history) >= len(trainer.metrics_history) + assert resumed.loaded_checkpoint_manifest is not None - result = trainer.train() - assert result is not None - - -# ============================================================================= -# KTO TRAINER INTEGRATION TESTS -# ============================================================================= @pytest.mark.integration class TestKTOTrainerIntegration: - """Integration tests for KTOTrainer.""" - - def test_kto_trainer_init(self, small_model, mock_tokenizer, kto_dataset): - """Test KTOTrainer can be initialized.""" - from mlx_tune import KTOTrainer + def test_kto_uses_cached_reference_logprobs_instead_of_live_policy_outputs( + self, + tmp_path, + tokenizer, + kto_dataset, + ): + from mlx_tune import KTOTrainer, compute_log_probs_with_lengths, kto_loss trainer = KTOTrainer( - model=small_model, + model=make_model(13), train_dataset=kto_dataset, - tokenizer=mock_tokenizer, - learning_rate=1e-4, + tokenizer=tokenizer, + learning_rate=2e-2, max_steps=2, + logging_steps=1, + save_steps=1, + output_dir=str(tmp_path), ) - - assert trainer is not None - - def test_kto_trainer_train_runs(self, small_model, mock_tokenizer, kto_dataset): - """Test KTOTrainer.train() executes without errors.""" + trainer.train() + + sample = trainer.train_samples[0] + batch = trainer._build_batch([sample]) + cached_loss, _ = kto_loss( + trainer.model.model, + batch.input_ids, + batch.sequence_lengths, + batch.labels, + beta=trainer.beta, + reference_logprobs=batch.reference_logprobs, + ) + live_policy_reference = compute_log_probs_with_lengths( + trainer.model.model, + batch.input_ids, + batch.sequence_lengths, + ) + live_loss, _ = kto_loss( + trainer.model.model, + batch.input_ids, + batch.sequence_lengths, + batch.labels, + beta=trainer.beta, + reference_logprobs=live_policy_reference, + ) + assert abs(cached_loss.item() - live_loss.item()) > 1e-6 + + def test_legacy_rl_checkpoint_loads_and_resaves_manifest_layout( + self, + tmp_path, + tokenizer, + kto_dataset, + ): from mlx_tune import KTOTrainer - trainer = KTOTrainer( - model=small_model, + source_trainer = KTOTrainer( + model=make_model(40), train_dataset=kto_dataset, - tokenizer=mock_tokenizer, - learning_rate=1e-4, - max_steps=2, + tokenizer=tokenizer, + learning_rate=2e-2, + max_steps=1, + logging_steps=1, + save_steps=1, + output_dir=str(tmp_path / "source"), ) + source_trainer.train() - result = trainer.train() - assert result is not None - - def test_kto_binary_feedback(self, small_model, mock_tokenizer, kto_dataset): - """Test KTO processes binary feedback correctly.""" - from mlx_tune import KTOTrainer - - # Ensure dataset has both positive and negative examples - assert any(d['label'] == 1 for d in kto_dataset), "Need positive examples" - assert any(d['label'] == 0 for d in kto_dataset), "Need negative examples" + legacy_dir = tmp_path / "legacy_kto" + write_legacy_rl_checkpoint(source_trainer, legacy_dir) - trainer = KTOTrainer( - model=small_model, + resumed = KTOTrainer( + model=make_model(41), train_dataset=kto_dataset, - tokenizer=mock_tokenizer, - learning_rate=1e-4, - max_steps=3, + tokenizer=tokenizer, + learning_rate=2e-2, + max_steps=2, + logging_steps=1, + save_steps=1, + output_dir=str(legacy_dir), ) + result = resumed.train(resume_from_checkpoint=str(legacy_dir)) - result = trainer.train() - assert result is not None + assert result["global_step"] == 2 + assert resumed.cache_metadata == source_trainer.cache_metadata + assert (legacy_dir / "manifest.json").exists() + assert (legacy_dir / "reference" / "weights.safetensors").exists() + assert (legacy_dir / "runtime" / "cache.safetensors").exists() -# ============================================================================= -# SIMPO TRAINER INTEGRATION TESTS -# ============================================================================= - @pytest.mark.integration -class TestSimPOTrainerIntegration: - """Integration tests for SimPOTrainer.""" - - def test_simpo_trainer_init(self, small_model, mock_tokenizer, preference_dataset): - """Test SimPOTrainer can be initialized.""" - from mlx_tune import SimPOTrainer - - trainer = SimPOTrainer( - model=small_model, +class TestOtherRLTrainers: + def test_orpo_and_simpo_still_train( + self, + tmp_path, + tokenizer, + preference_dataset, + ): + from mlx_tune import ORPOConfig, ORPOTrainer, SimPOTrainer + + orpo = ORPOTrainer( + model=make_model(14), train_dataset=preference_dataset, - tokenizer=mock_tokenizer, - learning_rate=1e-4, - max_steps=2, + tokenizer=tokenizer, + args=ORPOConfig( + learning_rate=1e-2, + max_steps=1, + output_dir=str(tmp_path / "orpo"), + ), ) - - assert trainer is not None - - def test_simpo_trainer_train_runs(self, small_model, mock_tokenizer, preference_dataset): - """Test SimPOTrainer.train() executes without errors.""" - from mlx_tune import SimPOTrainer - - trainer = SimPOTrainer( - model=small_model, + simpo = SimPOTrainer( + model=make_model(15), train_dataset=preference_dataset, - tokenizer=mock_tokenizer, - learning_rate=1e-4, - max_steps=2, + tokenizer=tokenizer, + learning_rate=1e-2, + max_steps=1, + output_dir=str(tmp_path / "simpo"), ) - result = trainer.train() - assert result is not None - - def test_simpo_no_reference_model(self, small_model, mock_tokenizer, preference_dataset): - """Test SimPO works without reference model.""" - from mlx_tune import SimPOTrainer + assert orpo.train()["status"] == "success" + assert simpo.train()["status"] == "success" + + def test_simpo_uses_configured_per_device_train_batch_size( + self, + tmp_path, + tokenizer, + preference_dataset, + monkeypatch, + ): + import mlx_tune.rl_trainers as rl_trainers_module + from mlx_tune import SimPOConfig, SimPOTrainer + + observed_batch_sizes = [] + original = rl_trainers_module.compute_simpo_loss + + def wrapped(model, chosen_ids, rejected_ids, chosen_lengths, rejected_lengths, beta, gamma): + observed_batch_sizes.append(int(chosen_ids.shape[0])) + return original( + model, + chosen_ids, + rejected_ids, + chosen_lengths, + rejected_lengths, + beta, + gamma, + ) - # SimPO is special: it doesn't require a reference model + monkeypatch.setattr(rl_trainers_module, "compute_simpo_loss", wrapped) trainer = SimPOTrainer( - model=small_model, + model=make_model(16), train_dataset=preference_dataset, - tokenizer=mock_tokenizer, - learning_rate=1e-4, - max_steps=3, + tokenizer=tokenizer, + args=SimPOConfig( + learning_rate=1e-2, + per_device_train_batch_size=2, + max_steps=1, + output_dir=str(tmp_path / "simpo_batch"), + ), ) result = trainer.train() - assert result is not None + assert result["status"] == "success" + assert observed_batch_sizes == [2] + + def test_online_dpo_uses_reward_model_preference_ordering( + self, + tmp_path, + tokenizer, + ): + from mlx_tune import OnlineDPOConfig, OnlineDPOTrainer, build_reward_model + + reward_model = build_reward_model(make_model(60)) + reward_model.head.update( + { + "weight": mx.zeros_like(reward_model.head.weight), + "bias": mx.array([1.0], dtype=mx.float32), + }, + strict=False, + ) + mx.eval(reward_model.head.parameters()) + + trainer = OnlineDPOTrainer( + model=make_model(61), + train_dataset=[{"prompt": "Solve 1 + 1", "reward_context": "2"}], + tokenizer=tokenizer, + reward_model=reward_model, + reward_fn=lambda response, context: (_ for _ in ()).throw(RuntimeError("reward_fn should not run")), + args=OnlineDPOConfig( + learning_rate=1e-2, + max_steps=1, + logging_steps=1, + save_steps=1, + num_generations=2, + max_completion_length=3, + output_dir=str(tmp_path / "online_dpo"), + ), + ) -# ============================================================================= -# GRADIENT FLOW TESTS -# ============================================================================= - -@pytest.mark.integration -class TestGradientFlow: - """Test that gradients flow correctly in RL trainers.""" - - def _check_grads_nonzero(self, grads): - """Recursively check if any gradients are non-zero.""" - if isinstance(grads, dict): - for value in grads.values(): - if self._check_grads_nonzero(value): - return True - return False - elif isinstance(grads, mx.array): - return float(mx.sum(mx.abs(grads)).item()) > 0 - return False - - def test_dpo_gradient_not_zero(self): - """Test DPO computes non-zero gradients.""" - from mlx_tune.losses import dpo_loss - - model = SmallLanguageModel(vocab_size=50, hidden_size=32) - mx.eval(model.parameters()) - - # Create sample data - chosen = mx.array([[1, 2, 3, 4, 5]]) - rejected = mx.array([[1, 2, 6, 7, 8]]) - chosen_len = mx.array([5]) - rejected_len = mx.array([5]) - - # Compute loss and gradients - def loss_fn(model): - loss, _ = dpo_loss(model, chosen, rejected, chosen_len, rejected_len, beta=0.1) - return loss - - loss, grads = nn.value_and_grad(model, loss_fn)(model) - mx.eval(loss, grads) - - # Check gradients are not all zero - has_nonzero_grad = self._check_grads_nonzero(grads) + result = trainer.train() - assert has_nonzero_grad, "DPO gradients should not all be zero" - assert not mx.isnan(loss), "DPO loss should not be NaN" + assert result["status"] == "success" + assert trainer._last_rollout_batch is not None + assert mx.allclose( + trainer._last_rollout_batch.rewards, + mx.ones_like(trainer._last_rollout_batch.rewards), + ) - def test_orpo_gradient_not_zero(self): - """Test ORPO computes non-zero gradients.""" - from mlx_tune.losses import orpo_loss + def test_online_dpo_prefers_answer_context_over_prompt( + self, + tmp_path, + tokenizer, + ): + from mlx_tune import OnlineDPOConfig, OnlineDPOTrainer - model = SmallLanguageModel(vocab_size=50, hidden_size=32) - mx.eval(model.parameters()) + seen_contexts = [] - chosen = mx.array([[1, 2, 3, 4, 5]]) - rejected = mx.array([[1, 2, 6, 7, 8]]) - chosen_len = mx.array([5]) - rejected_len = mx.array([5]) + def reward_fn(response: str, context: str) -> float: + seen_contexts.append(context) + return float(len(response)) - def loss_fn(model): - loss, _ = orpo_loss(model, chosen, rejected, chosen_len, rejected_len, beta=0.1) - return loss + trainer = OnlineDPOTrainer( + model=make_model(62), + train_dataset=[{"prompt": "Solve 2 + 2", "answer": "4"}], + tokenizer=tokenizer, + reward_fn=reward_fn, + args=OnlineDPOConfig( + learning_rate=1e-2, + max_steps=1, + logging_steps=1, + save_steps=1, + num_generations=2, + max_completion_length=3, + output_dir=str(tmp_path / "online_dpo_answer_context"), + ), + ) - loss, grads = nn.value_and_grad(model, loss_fn)(model) - mx.eval(loss, grads) + result = trainer.train() - has_nonzero_grad = self._check_grads_nonzero(grads) + assert result["status"] == "success" + assert seen_contexts + assert all(context == "4" for context in seen_contexts) + + def test_online_dpo_prompt_rollouts_respect_rollout_batch_size( + self, + tmp_path, + tokenizer, + ): + from mlx_tune import OnlineDPOConfig, OnlineDPOTrainer + + trainer = OnlineDPOTrainer( + model=make_model(66), + train_dataset=[ + {"prompt": "Solve 1 + 1", "answer": "2"}, + {"prompt": "Solve 2 + 2", "answer": "4"}, + {"prompt": "Solve 3 + 3", "answer": "6"}, + ], + tokenizer=tokenizer, + reward_fn=lambda response, context: float(len(response)), + args=OnlineDPOConfig( + learning_rate=1e-2, + per_device_train_batch_size=1, + rollout_batch_size=3, + max_steps=1, + logging_steps=1, + save_steps=1, + num_generations=2, + max_completion_length=3, + output_dir=str(tmp_path / "online_dpo_rollout_batch_size"), + ), + ) - assert has_nonzero_grad, "ORPO gradients should not all be zero" - assert not mx.isnan(loss), "ORPO loss should not be NaN" + result = trainer.train() + assert result["status"] == "success" + assert trainer._last_rollout_batch is not None + assert len(trainer._last_rollout_batch.prompt_ids) == 6 + assert sorted(set(trainer._last_rollout_batch.prompt_group_indices.tolist())) == [0, 1, 2] -# ============================================================================= -# LOSS STABILITY TESTS -# ============================================================================= @pytest.mark.integration -class TestLossStability: - """Test that losses remain stable (no NaN, Inf) during training.""" - - def test_dpo_loss_stability(self): - """Test DPO loss doesn't produce NaN or Inf.""" - from mlx_tune.losses import dpo_loss - - model = SmallLanguageModel(vocab_size=50, hidden_size=32) - mx.eval(model.parameters()) - - # Run multiple forward passes - for i in range(10): - chosen = mx.random.randint(0, 50, (2, 20)) - rejected = mx.random.randint(0, 50, (2, 20)) - chosen_len = mx.array([15, 18]) - rejected_len = mx.array([17, 16]) - - loss, ntoks = dpo_loss(model, chosen, rejected, chosen_len, rejected_len, beta=0.1) - mx.eval(loss) - - assert not mx.isnan(loss), f"DPO loss became NaN at iteration {i}" - assert not mx.isinf(loss), f"DPO loss became Inf at iteration {i}" - - def test_orpo_loss_stability(self): - """Test ORPO loss doesn't produce NaN or Inf.""" - from mlx_tune.losses import orpo_loss - - model = SmallLanguageModel(vocab_size=50, hidden_size=32) - mx.eval(model.parameters()) - - for i in range(10): - chosen = mx.random.randint(0, 50, (2, 20)) - rejected = mx.random.randint(0, 50, (2, 20)) - chosen_len = mx.array([15, 18]) - rejected_len = mx.array([17, 16]) - - loss, _ = orpo_loss(model, chosen, rejected, chosen_len, rejected_len, beta=0.1) - mx.eval(loss) - - assert not mx.isnan(loss), f"ORPO loss became NaN at iteration {i}" - assert not mx.isinf(loss), f"ORPO loss became Inf at iteration {i}" - - def test_simpo_loss_stability(self): - """Test SimPO loss doesn't produce NaN or Inf.""" - from mlx_tune.losses import simpo_loss - - model = SmallLanguageModel(vocab_size=50, hidden_size=32) - mx.eval(model.parameters()) - - for i in range(10): - chosen = mx.random.randint(0, 50, (2, 20)) - rejected = mx.random.randint(0, 50, (2, 20)) - chosen_len = mx.array([15, 18]) - rejected_len = mx.array([17, 16]) - - loss, _ = simpo_loss(model, chosen, rejected, chosen_len, rejected_len, - beta=0.1, gamma=0.5) - mx.eval(loss) - - assert not mx.isnan(loss), f"SimPO loss became NaN at iteration {i}" - assert not mx.isinf(loss), f"SimPO loss became Inf at iteration {i}" +def test_grpo_evaluate_records_namespaced_metrics_and_preference_accuracy(tmp_path, tokenizer): + from mlx_tune import GRPOConfig, GRPOTrainer + + output_dir = tmp_path / "grpo_eval" + trainer = GRPOTrainer( + model=make_model(80), + train_dataset=[{"prompt": "Solve 2 + 2", "answer": "4"}], + eval_dataset=[{"prompt": "Solve 3 + 3", "answer": "6"}], + eval_preference_dataset=[ + {"prompt": "Q", "chosen": "AAAA", "rejected": "B"}, + {"prompt": "Q", "chosen": "CCCC", "rejected": "D"}, + ], + tokenizer=tokenizer, + reward_fn=lambda response, context: float(len(response)), + args=GRPOConfig( + seed=123, + learning_rate=1e-2, + beta=0.01, + num_generations=2, + max_completion_length=3, + max_steps=1, + eval_steps=1, + logging_steps=1, + save_steps=1, + output_dir=str(output_dir), + ), + ) + + trainer.train() + eval_metrics = trainer.evaluate() + history_rows = [json.loads(line) for line in (output_dir / "metrics" / "history.jsonl").read_text().splitlines()] + + assert "eval/reward_mean" in eval_metrics + assert "eval/preference_win_rate" in eval_metrics + assert any("train/policy_loss" in row for row in history_rows) + assert any("eval/reward_mean" in row for row in history_rows) + assert any("eval/preference_win_rate" in row for row in history_rows) -# ============================================================================= -# REWARD FUNCTION TESTS -# ============================================================================= +@pytest.mark.integration +@pytest.mark.parametrize("trainer_kind", ["grpo", "ppo", "online_dpo"]) +def test_on_policy_trainers_with_same_seed_are_reproducible(tmp_path, tokenizer, grpo_dataset, trainer_kind): + from mlx_tune import ( + GRPOConfig, + GRPOTrainer, + OnlineDPOConfig, + OnlineDPOTrainer, + PPOConfig, + PPOTrainer, + build_value_model, + ) + + def build_trainer(output_dir): + if trainer_kind == "grpo": + return GRPOTrainer( + model=make_model(81), + train_dataset=grpo_dataset, + tokenizer=tokenizer, + reward_fn=lambda response, context: float(len(response)), + args=GRPOConfig( + seed=777, + learning_rate=1e-2, + beta=0.01, + num_generations=2, + max_completion_length=3, + max_steps=1, + logging_steps=1, + save_steps=1, + output_dir=str(output_dir), + ), + ) + if trainer_kind == "ppo": + return PPOTrainer( + model=make_model(81), + train_dataset=grpo_dataset, + tokenizer=tokenizer, + reward_fn=lambda response, context: float(len(response)), + value_model=build_value_model(make_model(82)), + args=PPOConfig( + seed=777, + learning_rate=1e-2, + value_learning_rate=1e-2, + num_generations=2, + max_completion_length=3, + max_steps=1, + logging_steps=1, + save_steps=1, + output_dir=str(output_dir), + ), + ) + return OnlineDPOTrainer( + model=make_model(81), + train_dataset=grpo_dataset, + tokenizer=tokenizer, + reward_fn=lambda response, context: float(len(response)), + args=OnlineDPOConfig( + seed=777, + learning_rate=1e-2, + num_generations=2, + max_completion_length=3, + max_steps=1, + logging_steps=1, + save_steps=1, + output_dir=str(output_dir), + ), + ) -class TestRewardFunctions: - """Test reward functions used in GRPO.""" + trainer_a = build_trainer(tmp_path / f"{trainer_kind}_a") + trainer_b = build_trainer(tmp_path / f"{trainer_kind}_b") + trainer_a.train() + trainer_b.train() + + def _strip_timing(metrics_history): + cleaned = [] + for row in metrics_history: + cleaned.append( + { + key: value + for key, value in row.items() + if not key.endswith("_wall") + } + ) + return cleaned - def test_simple_reward_function(self): - """Test simple reward function.""" - from mlx_tune import create_reward_function + assert _strip_timing(trainer_a.metrics_history) == _strip_timing(trainer_b.metrics_history) + assert trainer_a._last_rollout_batch is not None + assert trainer_b._last_rollout_batch is not None + assert trainer_a._last_rollout_batch.completion_ids == trainer_b._last_rollout_batch.completion_ids - reward_fn = create_reward_function("simple") - # Simple reward expects (response, ground_truth) - # Returns 1.0 if ground_truth is in response, else 0.0 - score_match = reward_fn("The answer is 42", "42") - score_no_match = reward_fn("The answer is something", "42") +@pytest.mark.integration +def test_grpo_checkpoint_records_step_boundary_and_independent_cursors(tmp_path, tokenizer): + from mlx_tune import GRPOConfig, GRPOTrainer, resume_from_checkpoint + + output_dir = tmp_path / "grpo_cursor_checkpoint" + trainer = GRPOTrainer( + model=make_model(83), + train_dataset=[ + {"prompt": "Q1", "completion": "A1", "reward": 1.0}, + {"prompt": "Q2", "completion": "A2", "reward": 0.5}, + {"prompt": "Q3", "completion": "A3", "reward": 0.2}, + {"prompt": "Q4", "completion": "A4", "reward": -0.1}, + ], + tokenizer=tokenizer, + args=GRPOConfig( + seed=99, + learning_rate=1e-2, + reward_source="offline", + num_generations=2, + max_steps=1, + logging_steps=1, + save_steps=1, + output_dir=str(output_dir), + ), + ) - assert score_match == 1.0, "Should return 1.0 when ground_truth is in response" - assert score_no_match == 0.0, "Should return 0.0 when ground_truth is not in response" + trainer.train() + bundle = resume_from_checkpoint(output_dir) + state = bundle.trainer_state - def test_math_reward_function(self): - """Test math reward function.""" - from mlx_tune import create_reward_function + assert bundle.manifest["format_version"] == 4 + assert state["seed"] == 99 + assert state["trainer_state"]["step_boundary"]["completed_optimizer_step"] == 1 + assert state["trainer_state"]["step_boundary"]["checkpoint_authoritative"] is True + assert state["trainer_state"]["cursors"]["prompt_dataset"] == 0 + assert state["trainer_state"]["cursors"]["offline_rollout_dataset"] == 2 - reward_fn = create_reward_function("math") - # Math reward expects (response, ground_truth) and compares extracted numbers - correct = reward_fn("The answer is 42", "42") - incorrect = reward_fn("The answer is 99", "42") +@pytest.mark.integration +def test_grpo_reuses_precomputed_rollout_reference_cache_after_resume(tmp_path, tokenizer, monkeypatch): + import mlx_tune.rl_trainers as rl_trainers_module + from mlx_tune import GRPOConfig, GRPOTrainer + + output_dir = tmp_path / "grpo_cache_resume" + train_dataset = [ + {"prompt": "Q1", "completion": "A1", "reward": 1.0}, + {"prompt": "Q2", "completion": "A2", "reward": 0.5}, + {"prompt": "Q3", "completion": "A3", "reward": -0.2}, + ] + trainer = GRPOTrainer( + model=make_model(84), + train_dataset=train_dataset, + tokenizer=tokenizer, + args=GRPOConfig( + seed=11, + learning_rate=1e-2, + reward_source="offline", + num_generations=2, + precompute_reference_scores=True, + max_steps=1, + logging_steps=1, + save_steps=1, + output_dir=str(output_dir), + ), + ) + trainer.train() + + resumed = GRPOTrainer( + model=make_model(85), + train_dataset=train_dataset, + tokenizer=tokenizer, + args=GRPOConfig( + seed=11, + learning_rate=1e-2, + reward_source="offline", + num_generations=2, + precompute_reference_scores=True, + max_steps=2, + logging_steps=1, + save_steps=1, + output_dir=str(output_dir), + ), + ) + resumed._apply_lora_if_needed() + resumed._prepare_prompt_samples() + optimizer = resumed._optimizer_for_training() + resumed.optimizer = optimizer + resumed.load_state(optimizer=optimizer, checkpoint_dir=output_dir) + resumed._ensure_reference_policy() + + original = rl_trainers_module.score_policy_in_chunks + reference_model = resumed.reference_policy.model.model + calls = {"reference": 0} + + def wrapped(model, *args, **kwargs): + if model is reference_model: + calls["reference"] += 1 + return original(model, *args, **kwargs) + + monkeypatch.setattr(rl_trainers_module, "score_policy_in_chunks", wrapped) + batch = resumed._collect_fixed_rollout_batch( + resumed.rollout_samples[:2], + cache_key="grpo.train_rollout_reference_logprobs", + ) + + assert calls["reference"] == 0 + assert batch.reference_logprobs is not None - assert correct == 1.0, "Correct math answer should get 1.0" - assert incorrect == 0.0, "Incorrect math answer should get 0.0" - def test_length_reward_function(self): - """Test length-based reward function.""" - from mlx_tune import create_reward_function +@pytest.mark.integration +def test_grpo_fixed_rollouts_respect_length_caps(tmp_path, tokenizer): + from mlx_tune import GRPOConfig, GRPOTrainer + + trainer = GRPOTrainer( + model=make_model(87), + train_dataset=[ + {"prompt": "abcdefghij", "completion": "klmnopqrst", "reward": 1.0}, + ], + tokenizer=tokenizer, + args=GRPOConfig( + seed=13, + learning_rate=1e-2, + reward_source="offline", + max_seq_length=4, + max_completion_length=2, + max_steps=1, + logging_steps=1, + save_steps=1, + output_dir=str(tmp_path / "grpo_fixed_rollout_caps"), + ), + ) + + trainer._apply_lora_if_needed() + trainer._prepare_prompt_samples() + batch = trainer._collect_fixed_rollout_batch(trainer.rollout_samples) + + assert batch.prompt_lengths.tolist() == [2] + assert batch.completion_lengths.tolist() == [2] + assert batch.policy_eval.input_ids.shape == (1, 4) + assert batch.truncation_flags.tolist() == [True] - reward_fn = create_reward_function("length") - # Length reward expects (response, _) where _ is ignored - short = reward_fn("Hi there", "") - medium = reward_fn("This is a longer response with about fifteen words in it here now", "") - long = reward_fn(" ".join(["word"] * 100), "") +@pytest.mark.integration +def test_grpo_invalidates_precomputed_rollout_reference_cache_when_dataset_changes( + tmp_path, + tokenizer, + monkeypatch, +): + import mlx_tune.rl_trainers as rl_trainers_module + from mlx_tune import GRPOConfig, GRPOTrainer + + output_dir = tmp_path / "grpo_cache_dataset_drift" + original_dataset = [ + {"prompt": "Q1", "completion": "A1", "reward": 1.0}, + {"prompt": "Q2", "completion": "A2", "reward": 0.5}, + ] + changed_dataset = [ + {"prompt": "!!!!!!!!", "completion": "zzzzzzzz", "reward": 1.0}, + {"prompt": "????????", "completion": "yyyyyyyy", "reward": 0.5}, + ] + trainer = GRPOTrainer( + model=make_model(88), + train_dataset=original_dataset, + tokenizer=tokenizer, + args=GRPOConfig( + seed=19, + learning_rate=1e-2, + reward_source="offline", + num_generations=2, + precompute_reference_scores=True, + max_steps=1, + logging_steps=1, + save_steps=1, + output_dir=str(output_dir), + ), + ) + trainer.train() + original_scores = mx.array(trainer.runtime_cache_arrays["grpo.train_rollout_reference_logprobs"]) + + resumed = GRPOTrainer( + model=make_model(89), + train_dataset=changed_dataset, + tokenizer=tokenizer, + args=GRPOConfig( + seed=19, + learning_rate=1e-2, + reward_source="offline", + num_generations=2, + precompute_reference_scores=True, + max_steps=2, + logging_steps=1, + save_steps=1, + output_dir=str(output_dir), + ), + ) + resumed._apply_lora_if_needed() + resumed._prepare_prompt_samples() + optimizer = resumed._optimizer_for_training() + resumed.optimizer = optimizer + resumed.load_state(optimizer=optimizer, checkpoint_dir=output_dir) + resumed._ensure_reference_policy() + + original = rl_trainers_module.score_policy_in_chunks + reference_model = resumed.reference_policy.model.model + calls = {"reference": 0} + + def wrapped(model, *args, **kwargs): + if model is reference_model: + calls["reference"] += 1 + return original(model, *args, **kwargs) + + monkeypatch.setattr(rl_trainers_module, "score_policy_in_chunks", wrapped) + batch = resumed._collect_fixed_rollout_batch( + resumed.rollout_samples, + cache_key="grpo.train_rollout_reference_logprobs", + ) + + assert calls["reference"] > 0 + assert not mx.allclose(batch.reference_logprobs, original_scores) - # Short (<10 words) = 0.2, Medium (10-50) = 0.5, Long (50-200) = 1.0 - assert short == 0.2, f"Short response should be 0.2, got {short}" - assert medium == 0.5, f"Medium response should be 0.5, got {medium}" - assert long == 1.0, f"Long response should be 1.0, got {long}" - def test_custom_reward_function(self): - """Test using a custom reward function.""" - def my_reward(response: str, ground_truth: str = "") -> float: - # Reward responses that contain "please" or "thank" - score = 0.0 - if "please" in response.lower(): - score += 0.5 - if "thank" in response.lower(): - score += 0.5 - return score +@pytest.mark.integration +def test_grpo_resume_rejects_objective_and_update_shape_drift(tmp_path, tokenizer, grpo_dataset): + from mlx_tune import GRPOConfig, GRPOTrainer + + output_dir = tmp_path / "grpo_resume_config_drift" + trainer = GRPOTrainer( + model=make_model(90), + train_dataset=grpo_dataset, + tokenizer=tokenizer, + reward_fn=lambda response, context: float(len(response)), + args=GRPOConfig( + seed=31, + learning_rate=1e-2, + beta=0.01, + clip_epsilon=0.2, + rollout_batch_size=1, + minibatch_reuse_steps=1, + num_generations=2, + max_completion_length=4, + max_steps=1, + logging_steps=1, + save_steps=1, + output_dir=str(output_dir), + ), + ) + trainer.train() + + resumed = GRPOTrainer( + model=make_model(91), + train_dataset=grpo_dataset, + tokenizer=tokenizer, + reward_fn=lambda response, context: float(len(response)), + args=GRPOConfig( + seed=31, + learning_rate=1e-2, + beta=0.99, + clip_epsilon=0.33, + rollout_batch_size=5, + minibatch_reuse_steps=7, + num_generations=2, + max_completion_length=4, + max_steps=2, + logging_steps=1, + save_steps=1, + output_dir=str(output_dir), + ), + ) + resumed._apply_lora_if_needed() + resumed._prepare_prompt_samples() + optimizer = resumed._optimizer_for_training() + resumed.optimizer = optimizer - assert my_reward("Please help me") == 0.5 - assert my_reward("Thank you") == 0.5 - assert my_reward("Please help, thank you!") == 1.0 - assert my_reward("Hello world") == 0.0 + with pytest.raises(ValueError, match="fingerprint"): + resumed.load_state(optimizer=optimizer, checkpoint_dir=output_dir) -# Run tests -if __name__ == "__main__": - pytest.main([__file__, "-v", "-m", "integration"]) +@pytest.mark.integration +def test_online_dpo_eval_preference_reference_cache_reused_after_resume(tmp_path, tokenizer, monkeypatch): + import mlx_tune.rl_trainers as rl_trainers_module + from mlx_tune import OnlineDPOConfig, OnlineDPOTrainer + + output_dir = tmp_path / "online_dpo_eval_cache" + eval_preference_dataset = [ + {"prompt": "Q", "chosen": "AA", "rejected": "B"}, + {"prompt": "Q", "chosen": "CC", "rejected": "D"}, + ] + trainer = OnlineDPOTrainer( + model=make_model(86), + train_dataset=[{"prompt": "Solve 2 + 2", "answer": "4"}], + eval_preference_dataset=eval_preference_dataset, + tokenizer=tokenizer, + reward_fn=lambda response, context: float(len(response)), + args=OnlineDPOConfig( + seed=22, + learning_rate=1e-2, + num_generations=2, + max_completion_length=3, + max_steps=1, + eval_steps=1, + precompute_reference_scores=True, + logging_steps=1, + save_steps=1, + output_dir=str(output_dir), + ), + ) + trainer.train() + trainer.evaluate() + trainer.save_state(optimizer=trainer.optimizer) + + resumed = OnlineDPOTrainer( + model=make_model(87), + train_dataset=[{"prompt": "Solve 2 + 2", "answer": "4"}], + eval_preference_dataset=eval_preference_dataset, + tokenizer=tokenizer, + reward_fn=lambda response, context: float(len(response)), + args=OnlineDPOConfig( + seed=22, + learning_rate=1e-2, + num_generations=2, + max_completion_length=3, + max_steps=2, + eval_steps=1, + precompute_reference_scores=True, + logging_steps=1, + save_steps=1, + output_dir=str(output_dir), + ), + ) + resumed._apply_lora_if_needed() + resumed._prepare_prompt_samples() + optimizer = resumed._optimizer_for_training() + resumed.optimizer = optimizer + resumed.load_state(optimizer=optimizer, checkpoint_dir=output_dir) + + original = rl_trainers_module.score_policy_in_chunks + reference_model = resumed.reference_policy.model.model + calls = {"reference": 0} + + def wrapped(model, *args, **kwargs): + if model is reference_model: + calls["reference"] += 1 + return original(model, *args, **kwargs) + + monkeypatch.setattr(rl_trainers_module, "score_policy_in_chunks", wrapped) + eval_metrics = resumed.evaluate() + + assert calls["reference"] == 0 + assert "eval/preference_win_rate" in eval_metrics diff --git a/tests/test_trainers.py b/tests/test_trainers.py index 61367ae..012c1e3 100644 --- a/tests/test_trainers.py +++ b/tests/test_trainers.py @@ -74,6 +74,17 @@ def test_dpoconfig_custom_beta(self): assert config.beta == 0.5 +class TestRewardConfig: + def test_rewardconfig_defaults(self): + from mlx_tune import RewardConfig + + config = RewardConfig() + + assert config.learning_rate == 5e-6 + assert config.regression_loss_type == "mse" + assert config.dataset_mode is None + + class TestGRPOConfig: """Test GRPOConfig class.""" @@ -84,9 +95,13 @@ def test_grpoconfig_defaults(self): config = GRPOConfig() assert config.loss_type == "grpo" + assert config.advantage_mode == "group_zscore" + assert config.advantage_estimator == "group_zscore" assert config.num_generations == 4 assert config.temperature == 0.7 assert config.beta == 0.04 + assert config.kl_beta == 0.04 + assert config.reward_source == "auto" def test_grpoconfig_with_reward_fn(self): """Test GRPOConfig with custom reward function.""" @@ -100,6 +115,62 @@ def custom_reward(response, prompt): assert config.reward_fn is not None assert config.num_generations == 8 + def test_grpoconfig_aliases_and_to_dict(self): + from mlx_tune import GRPOConfig + + config = GRPOConfig( + generations_per_prompt=6, + advantage_estimator="rloo", + reward_fn=lambda *_: 1.0, + ) + + assert config.num_generations == 6 + assert config.advantage_mode == "rloo" + assert config.to_dict()["num_generations"] == 6 + assert "reward_fn" not in config.to_dict() + + def test_grpoconfig_variant_defaults_match_loss_family(self): + from mlx_tune import GRPOConfig + + dapo = GRPOConfig(loss_type="dapo") + dr_grpo = GRPOConfig(loss_type="dr_grpo") + + assert dapo.mask_truncated_completions is True + assert dapo.epsilon_high == 0.28 + assert dr_grpo.scale_rewards is False + assert dr_grpo.epsilon_low == dr_grpo.clip_epsilon + assert dr_grpo.epsilon_high == dr_grpo.clip_epsilon + + +class TestPPOAndOnlineDPOConfig: + def test_ppoconfig_defaults(self): + from mlx_tune import PPOConfig + + config = PPOConfig() + + assert config.ppo_epochs == 2 + assert config.minibatch_reuse_steps == 2 + assert config.value_learning_rate == config.learning_rate + assert config.advantage_estimator == "gae" + assert config.kl_penalty_mode == "kl" + + def test_online_dpoconfig_defaults(self): + from mlx_tune import OnlineDPOConfig + + config = OnlineDPOConfig() + + assert config.num_generations == 4 + assert config.beta == 0.1 + + def test_new_offline_config_exports(self): + from mlx_tune import KTOConfig, SimPOConfig + + kto = KTOConfig() + simpo = SimPOConfig() + + assert kto.beta == 0.1 + assert simpo.gamma == 0.5 + class TestTrainerInitialization: """Test trainer initialization (without actual model loading).""" @@ -109,21 +180,30 @@ def test_imports_work(self): from mlx_tune import ( SFTTrainer, SFTConfig, + RewardTrainer, + RewardConfig, DPOTrainer, DPOConfig, ORPOTrainer, ORPOConfig, GRPOTrainer, GRPOConfig, + PPOTrainer, + PPOConfig, + OnlineDPOTrainer, + OnlineDPOConfig, KTOTrainer, SimPOTrainer, ) # Just verify imports work assert SFTTrainer is not None + assert RewardTrainer is not None assert DPOTrainer is not None assert ORPOTrainer is not None assert GRPOTrainer is not None + assert PPOTrainer is not None + assert OnlineDPOTrainer is not None assert KTOTrainer is not None assert SimPOTrainer is not None @@ -165,6 +245,11 @@ def test_prepare_preference_dataset_import(self): from mlx_tune import prepare_preference_dataset assert prepare_preference_dataset is not None + def test_prepare_rl_dataset_import(self): + from mlx_tune import prepare_rl_dataset + + assert prepare_rl_dataset is not None + def test_create_reward_function_simple(self): """Test create_reward_function with simple type.""" from mlx_tune import create_reward_function @@ -205,6 +290,21 @@ def test_create_reward_function_length(self): medium_result = reward_fn(" ".join(["word"] * 30), "") assert medium_result == 0.5 + def test_create_reward_function_composition(self): + from mlx_tune import create_reward_function + + reward_fn = create_reward_function( + rewards=[ + {"name": "simple", "source": "simple", "weight": 0.25}, + {"name": "length", "source": "length", "weight": 0.75}, + ] + ) + + result = reward_fn.evaluate({"completion_text": "The answer is 42", "reward_context": "42"}) + + assert result["reward"] > 0.0 + assert set(result["components"]) >= {"simple", "length"} + class TestExportFunctions: """Test export utility functions.""" diff --git a/tests/test_trl_compat.py b/tests/test_trl_compat.py new file mode 100644 index 0000000..d9f6395 --- /dev/null +++ b/tests/test_trl_compat.py @@ -0,0 +1,204 @@ +import importlib +import sys +from pathlib import Path +from types import ModuleType, SimpleNamespace + +import mlx.core as mx +import mlx.nn as nn +import pytest + + +class TinyModel(nn.Module): + def __init__(self, vocab_size: int = 32, hidden_size: int = 16): + super().__init__() + self.embedding = nn.Embedding(vocab_size, hidden_size) + self.output = nn.Linear(hidden_size, vocab_size) + + def __call__(self, x): + return self.output(self.embedding(x)) + + +class TinyTokenizer: + pad_token_id = 0 + eos_token_id = 1 + bos_token_id = 2 + + def encode(self, text: str, add_special_tokens: bool = True): + token_ids = [((ord(char) % 10) + 3) for char in text] + if add_special_tokens: + token_ids = [self.bos_token_id] + token_ids + [self.eos_token_id] + return token_ids + + def decode(self, ids, skip_special_tokens: bool = True): + if skip_special_tokens: + ids = [token for token in ids if token not in (self.pad_token_id, self.eos_token_id, self.bos_token_id)] + return "".join(chr(65 + (token % 26)) for token in ids) + + +@pytest.fixture(autouse=True) +def isolated_trl_modules(): + saved = { + name: module + for name, module in sys.modules.items() + if name == "trl" or name.startswith("trl.") + } + for name in list(saved): + sys.modules.pop(name, None) + yield + for name in list(sys.modules): + if name == "trl" or name.startswith("trl."): + sys.modules.pop(name, None) + sys.modules.update(saved) + + +def test_patch_fast_rl_creates_fallback_trl_module_when_missing(): + from mlx_tune import GRPOConfig as MLXGRPOConfig + from mlx_tune import GRPOTrainer as MLXGRPOTrainer + from mlx_tune import PatchFastRL + from mlx_tune import SFTConfig as MLXSFTConfig + from mlx_tune import SFTTrainer as MLXSFTTrainer + + PatchFastRL() + + trl = importlib.import_module("trl") + from trl import GRPOConfig, GRPOTrainer, SFTConfig, SFTTrainer + + assert trl.__MLX_TUNE_PATCHED__ is True + assert trl.trainer.__MLX_TUNE_PATCHED__ is True + assert issubclass(GRPOConfig, MLXGRPOConfig) + assert issubclass(GRPOTrainer, MLXGRPOTrainer) + assert issubclass(SFTConfig, MLXSFTConfig) + assert issubclass(SFTTrainer, MLXSFTTrainer) + assert trl.trainer.GRPOTrainer is GRPOTrainer + assert trl.trainer.SFTConfig is SFTConfig + + +def test_patch_fast_rl_mutates_existing_module_in_place_and_is_idempotent(): + from mlx_tune import PatchFastRL + + trl_module = ModuleType("trl") + trl_module.__package__ = "trl" + trl_module.__path__ = [] + trl_module.keep_me = "present" + trainer_module = ModuleType("trl.trainer") + trainer_module.keep_me_too = "present" + trl_module.trainer = trainer_module + sys.modules["trl"] = trl_module + sys.modules["trl.trainer"] = trainer_module + + PatchFastRL() + first_grpo_trainer = trl_module.GRPOTrainer + + PatchFastRL() + + assert sys.modules["trl"] is trl_module + assert sys.modules["trl.trainer"] is trainer_module + assert trl_module.trainer is trainer_module + assert trl_module.keep_me == "present" + assert trainer_module.keep_me_too == "present" + assert trl_module.GRPOTrainer is first_grpo_trainer + assert trainer_module.GRPOTrainer is first_grpo_trainer + + +def test_grpo_config_normalizes_reward_aliases(): + from mlx_tune import PatchFastRL + + PatchFastRL() + + from trl import GRPOConfig + + reward_fn = lambda response, context: float(len(response) + len(context)) + reward_funcs = [reward_fn, reward_fn] + + config = GRPOConfig( + reward_funcs=reward_funcs, + reward_func=reward_fn, + generations_per_prompt=3, + baseline_mode="rloo", + ) + + assert config.reward_sources[:2] == reward_funcs + assert config.reward_sources[2]["source"] is reward_fn + assert config.reward_fn is reward_fn + assert config.num_generations == 3 + assert config.advantage_estimator == "rloo" + + +def test_grpo_trainer_accepts_processing_class_and_foreign_args(tmp_path): + from mlx_tune import PatchFastRL + + PatchFastRL() + + from trl import GRPOTrainer + + reward_fn = lambda response, context: float(len(response) + len(context)) + tokenizer = TinyTokenizer() + model = TinyModel() + mx.eval(model.parameters()) + output_dir = tmp_path / "grpo" + + trainer = GRPOTrainer( + model=model, + processing_class=tokenizer, + train_dataset=[{"prompt": "hi"}], + eval_dataset=[{"prompt": "bye"}], + args=SimpleNamespace( + output_dir=str(output_dir), + learning_rate=1e-5, + per_device_train_batch_size=1, + num_train_epochs=1, + max_steps=1, + logging_steps=1, + save_steps=1, + max_seq_length=8, + max_completion_length=2, + generations_per_prompt=2, + reward_funcs=[reward_fn], + ), + ) + + assert trainer.tokenizer is tokenizer + assert trainer.eval_dataset == [{"prompt": "bye"}] + assert trainer.config.num_generations == 2 + assert trainer.reward_sources == [reward_fn] + assert trainer.output_dir == output_dir + + +def test_dpo_trainer_runs_via_patched_trl_import(tmp_path): + from mlx_tune import PatchFastRL + + PatchFastRL() + + from trl import DPOTrainer + + tokenizer = TinyTokenizer() + model = TinyModel() + mx.eval(model.parameters()) + + trainer = DPOTrainer( + model=model, + processing_class=tokenizer, + train_dataset=[ + {"prompt": "a", "chosen": "b", "rejected": "c"}, + {"prompt": "d", "chosen": "e", "rejected": "f"}, + ], + args=SimpleNamespace( + output_dir=str(tmp_path / "dpo"), + learning_rate=1e-4, + per_device_train_batch_size=1, + gradient_accumulation_steps=1, + num_train_epochs=1, + max_steps=1, + warmup_steps=0, + logging_steps=1, + save_steps=1, + max_seq_length=12, + max_prompt_length=6, + ), + ) + + result = trainer.train() + + assert result["status"] == "success" + assert result["global_step"] == 1 + assert Path(result["adapter_path"]).exists()