From ac3d085f6bc63e01fd144ed25974055c0935fa7a Mon Sep 17 00:00:00 2001 From: ErfanMhi Date: Thu, 27 Nov 2025 03:14:46 +0000 Subject: [PATCH 01/17] refactor(environments): change math source to EleutherAI huggingface repo! --- grail/environments/providers.py | 38 +++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/grail/environments/providers.py b/grail/environments/providers.py index 1636a2b6..6fc0e40e 100644 --- a/grail/environments/providers.py +++ b/grail/environments/providers.py @@ -215,6 +215,44 @@ def _extract_boxed_answer(solution: str) -> str: _MATH_VAL_SEED = 42 +def _extract_boxed_answer(solution: str) -> str: + """Extract answer from \\boxed{...} in solution, handling nested braces.""" + import re + + match = re.search(r"\\boxed\{", solution) + if not match: + return "" + + start = match.end() + depth = 1 + i = start + while i < len(solution) and depth > 0: + if solution[i] == "{": + depth += 1 + elif solution[i] == "}": + depth -= 1 + i += 1 + + return solution[start : i - 1] if depth == 0 else "" + + +# Subsets in EleutherAI/hendrycks_math dataset +_MATH_SUBSETS = ( + "algebra", + "counting_and_probability", + "geometry", + "intermediate_algebra", + "number_theory", + "prealgebra", + "precalculus", +) + +# Fixed validation set size (stratified across problem types) +_MATH_VAL_SIZE = 500 +# Seed for deterministic stratified sampling of validation set +_MATH_VAL_SEED = 42 + + class MATHTaskSource(TaskSource): """HF datasets-backed Hendrycks MATH provider with stratified train/val split. 
From 902a13ff62fee5dc7387ed2eb8bc99b9edb105f7 Mon Sep 17 00:00:00 2001 From: ErfanMhi Date: Thu, 27 Nov 2025 03:15:41 +0000 Subject: [PATCH 02/17] feat(evaluation): add eval_math_harness for evaluating models on Hendrycks MATH dataset --- eval_math_harness.py | 251 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 251 insertions(+) create mode 100644 eval_math_harness.py diff --git a/eval_math_harness.py b/eval_math_harness.py new file mode 100644 index 00000000..3540bde9 --- /dev/null +++ b/eval_math_harness.py @@ -0,0 +1,251 @@ +#!/usr/bin/env python3 +"""Evaluate Qwen/Qwen2.5-1.5B-Instruct on Hendrycks MATH using lm-evaluation-harness. + +Metrics Computed: +----------------- +1. Overall exact_match accuracy (primary) +2. Per-subject accuracy (7 subjects: algebra, counting_and_prob, + geometry, intermediate_algebra, num_theory, prealgebra, precalc) +3. Aggregated by difficulty level (1-5) via post-processing + +Best Practices Used: +-------------------- +1. hendrycks_math tasks - Per-subject tasks with \boxed{} answer extraction (standard for MATH) +2. 4-shot prompting - Standard for MATH benchmark +3. Chain-of-thought - Enabled via task's native format +4. Greedy decoding - temperature=0 for reproducibility +5. max_gen_toks=1024 - Sufficient for reasoning chains +6. Flash attention - Memory efficient for 7B+ models +7. BF16 precision - Optimal for modern GPUs +8. 
Batch size tuning - Auto via batch_size="auto" + +Usage: +------ + python eval_math_harness.py + python eval_math_harness.py --model Qwen/Qwen2.5-7B-Instruct + python eval_math_harness.py --num-fewshot 0 # zero-shot +""" + +import argparse +import json +import logging +import sys +from datetime import datetime +from pathlib import Path + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s | %(levelname)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", +) +logger = logging.getLogger(__name__) + + +def parse_args() -> argparse.Namespace: + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Evaluate model on Hendrycks MATH using lm-eval-harness" + ) + parser.add_argument( + "--model", + type=str, + default="Qwen/Qwen2.5-1.5B-Instruct", + help="Model name or path", + ) + parser.add_argument( + "--num-fewshot", + type=int, + default=4, + help="Number of few-shot examples (default: 4)", + ) + parser.add_argument( + "--batch-size", + type=str, + default="auto", + help="Batch size (default: auto)", + ) + parser.add_argument( + "--output-dir", + type=Path, + default=Path("eval_results"), + help="Output directory for results", + ) + parser.add_argument( + "--device", + type=str, + default="cuda", + help="Device to use (default: cuda)", + ) + parser.add_argument( + "--limit", + type=int, + default=None, + help="Limit number of samples per task (for debugging)", + ) + return parser.parse_args() + + +def run_evaluation(args: argparse.Namespace) -> dict: + """Run lm-evaluation-harness on MATH dataset.""" + try: + from lm_eval import simple_evaluate + from lm_eval.models.huggingface import HFLM + except ImportError: + logger.error("lm-eval not installed. 
Run: pip install lm-eval") + sys.exit(1) + + logger.info(f"Model: {args.model}") + logger.info(f"Few-shot: {args.num_fewshot}") + logger.info(f"Batch size: {args.batch_size}") + + # MATH subtasks (all 7 subjects) + # Using hendrycks_math tasks with proper \boxed{} answer extraction + tasks = [ + "hendrycks_math_algebra", + "hendrycks_math_counting_and_prob", + "hendrycks_math_geometry", + "hendrycks_math_intermediate_algebra", + "hendrycks_math_num_theory", + "hendrycks_math_prealgebra", + "hendrycks_math_precalc", + ] + + logger.info(f"Tasks: {tasks}") + + # Model configuration with best practices + model_kwargs = { + "pretrained": args.model, + "dtype": "bfloat16", + "device_map": "auto", + "trust_remote_code": True, + # Enable flash attention if available + "attn_implementation": "flash_attention_2", + } + + # Try to load with flash attention, fall back if not available + try: + model = HFLM(**model_kwargs) + except Exception as e: + logger.warning(f"Flash attention failed ({e}), using default attention") + model_kwargs.pop("attn_implementation", None) + model = HFLM(**model_kwargs) + + # Run evaluation + logger.info("Starting evaluation...") + results = simple_evaluate( + model=model, + tasks=tasks, + num_fewshot=args.num_fewshot, + batch_size=args.batch_size, + device=args.device, + limit=args.limit, + # Greedy decoding for reproducibility + gen_kwargs="temperature=0,do_sample=False", + log_samples=True, + ) + + return results + + +def print_results(results: dict, args: argparse.Namespace) -> None: + """Print formatted results.""" + print("\n" + "=" * 70) + print(f"MATH Evaluation Results - {args.model}") + print("=" * 70) + + # Extract per-subject results + subject_results = {} + total_correct = 0 + total_samples = 0 + + for task_name, task_results in results.get("results", {}).items(): + # Extract subject from task name and expand abbreviations + subject = task_name.replace("hendrycks_math_", "") + # Expand abbreviated names for readability + subject = 
subject.replace("counting_and_prob", "counting_and_probability") + subject = subject.replace("num_theory", "number_theory") + subject = subject.replace("precalc", "precalculus") + + # Get accuracy metric (exact_match or acc) + acc = task_results.get("exact_match,none", task_results.get("acc,none", 0)) + stderr = task_results.get("exact_match_stderr,none", task_results.get("acc_stderr,none", 0)) + + subject_results[subject] = { + "accuracy": acc, + "stderr": stderr, + } + + # For aggregate calculation + n_samples = task_results.get("alias", {}).get("n-shot", 0) + if "samples" in results: + task_samples = results["samples"].get(task_name, []) + n_samples = len(task_samples) + total_correct += sum(1 for s in task_samples if s.get("acc", 0) == 1) + total_samples += n_samples + + # Print per-subject results + print("\nPer-Subject Accuracy:") + print("-" * 50) + for subject, data in sorted(subject_results.items()): + acc_pct = data["accuracy"] * 100 + stderr_pct = data["stderr"] * 100 + print(f" {subject:35s} {acc_pct:5.2f}% ± {stderr_pct:.2f}%") + + # Print aggregate + if total_samples > 0: + overall_acc = total_correct / total_samples * 100 + else: + # Use average of subjects + accs = [d["accuracy"] for d in subject_results.values()] + overall_acc = sum(accs) / len(accs) * 100 if accs else 0 + + print("-" * 50) + print(f" {'OVERALL':35s} {overall_acc:5.2f}%") + print("=" * 70) + + +def save_results(results: dict, args: argparse.Namespace) -> Path: + """Save results to JSON file.""" + args.output_dir.mkdir(parents=True, exist_ok=True) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + model_name = args.model.replace("/", "_") + output_file = args.output_dir / f"math_{model_name}_{timestamp}.json" + + # Add metadata + results["metadata"] = { + "model": args.model, + "num_fewshot": args.num_fewshot, + "batch_size": args.batch_size, + "timestamp": timestamp, + } + + with open(output_file, "w") as f: + json.dump(results, f, indent=2, default=str) + + 
logger.info(f"Results saved to: {output_file}") + return output_file + + +def main() -> None: + """Main entry point.""" + args = parse_args() + + logger.info("=" * 50) + logger.info("Hendrycks MATH Evaluation") + logger.info("=" * 50) + + # Run evaluation + results = run_evaluation(args) + + # Print results + print_results(results, args) + + # Save results + save_results(results, args) + + logger.info("Evaluation complete!") + + +if __name__ == "__main__": + main() From 0144e4f2d11a669e87dcb3eeb86f38f2f61f26bb Mon Sep 17 00:00:00 2001 From: ErfanMhi Date: Thu, 27 Nov 2025 07:44:36 +0000 Subject: [PATCH 03/17] chore(config): change evaluation split from 'test' to 'val' for math dataset --- grail/trainer/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grail/trainer/config.py b/grail/trainer/config.py index 54ac1e49..5e08333d 100644 --- a/grail/trainer/config.py +++ b/grail/trainer/config.py @@ -96,7 +96,7 @@ class EvalConfig: enabled: bool = True window_interval: int = 20 - split: str = "val" # dataset-backed envs (e.g., GSM8K) #TODO: should be specified per env + split: str = "val" # Use validation split subset_size: int | None = None # generative envs or capped dataset eval seed_base: int = 2025 batch_size: int = 32 # Conservative for vLLM server: 8 tasks × 5 reps = 40 prompts/batch (prevent queue timeout) From 9375570f198b0d4040aa4d2292cf16d7380f2fb0 Mon Sep 17 00:00:00 2001 From: ErfanMhi Date: Fri, 28 Nov 2025 18:27:01 +0000 Subject: [PATCH 04/17] feat(training): add unified TRL GRPO training script for GSM8K and MATH datasets with detailed README --- research/trl/train_trl_grpo.py | 971 ++++++++++++++++++++++++++ research/trl/train_trl_grpo_README.md | 136 ++++ 2 files changed, 1107 insertions(+) create mode 100644 research/trl/train_trl_grpo.py create mode 100644 research/trl/train_trl_grpo_README.md diff --git a/research/trl/train_trl_grpo.py b/research/trl/train_trl_grpo.py new file mode 100644 index 00000000..b631e97c --- 
/dev/null +++ b/research/trl/train_trl_grpo.py @@ -0,0 +1,971 @@ +#!/usr/bin/env python3 +"""TRL GRPO training script with factory pattern for GSM8K and MATH datasets. + +Supports both datasets with exact parity to GRAIL environment implementations: +- GSM8K: Grade school math (7,473 train / 1,319 test) +- MATH: Hendrycks MATH benchmark (7,000 train / 500 val / 5,000 test) + +Usage: + python train_trl_grpo.py --dataset gsm8k + python train_trl_grpo.py --dataset math +""" + +from __future__ import annotations + +import abc +import argparse +import asyncio +import os +import re +import sys +from dataclasses import dataclass +from typing import Any + +import torch +from datasets import Dataset +from dotenv import load_dotenv +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + PreTrainedTokenizer, + TrainerCallback, +) +from trl import GRPOConfig, GRPOTrainer + +# Force unbuffered output for better logging in nohup mode +sys.stdout = open(sys.stdout.fileno(), mode="w", buffering=1) +sys.stderr = open(sys.stderr.fileno(), mode="w", buffering=1) + +# Load environment from .env for WandB +load_dotenv("/root/grail/.env") + +sys.path.append("/root/grail") + +# GRAIL imports - reuse task sources and validation logic (after sys.path.append) +from grail.environments.math_hendrycks_env import _math_answers_equal # noqa: E402 +from grail.environments.providers import GSM8KTaskSource, MATHTaskSource # noqa: E402 +from grail.shared.chat_templates import build_qwen_chat_template # noqa: E402 +from grail.trainer.metrics import KMetricsAggregator, TaskReplicateResult # noqa: E402 + + +# ════════════════════════════════════════════════════════════════════════════ +# HYPERPARAMETERS (from .env GRAIL config) +# ════════════════════════════════════════════════════════════════════════════ +@dataclass +class Config: + # Model (from GRAIL_TRAIN_MODEL_ID) + model_id: str = "Qwen/Qwen2.5-1.5B-Instruct" + # Learning rate (from GRAIL_TRAINER_LR) + lr: float = 3e-6 + # Epochs 
per window (from GRAIL_TRAINER_EPOCHS) + epochs: int = 1 + # Batch size (from GRAIL_TRAINER_BATCH_SIZE) + batch_size: int = 4 + # Gradient accumulation (from GRAIL_TRAINER_GRAD_ACCUM_STEPS) + grad_accum_steps: int = 128 + # Max sequence length (from GRAIL_TRAINER_MAX_LENGTH) + max_length: int = 2048 + # Gradient clipping (from GRAIL_TRAINER_GRAD_CLIP) + grad_clip: float = 1.0 + # Warmup steps (from GRAIL_TRAINER_WARMUP_STEPS) + warmup_steps: int = 50 + # KL coefficient (from GRAIL_TRAINER_KL_COEF) + kl_coef: float = 0.0 + # Entropy coefficient (from GRAIL_TRAINER_ENTROPY_COEF) + entropy_coef: float = 0.0005 + # PPO clip epsilon (standard GRAIL values) + ppo_clip_eps: float = 0.2 + ppo_clip_eps_upper: float = 0.28 + # Importance sampling ratio max (from GRAIL_TRAINER_IS_RATIO_MAX) + is_ratio_max: float = 2.5 + # Log-ratio clamp (from GRAIL_TRAINER_LOGRATIO_CLAMP) + logratio_clamp: float = 0.92 + # Dataset sampling + num_train_samples: int | None = None # None = use all training samples + num_eval_samples: int | None = None # None = use all test samples + # Rollouts per problem (matches GRAIL default) + rollouts_per_problem: int = 16 + # Generation parameters + temperature: float = 0.7 + top_p: float = 0.95 + top_k: int = 50 + # Max completion tokens (from GRPO_MAX_COMPLETION_TOKENS) + max_new_tokens: int = 1024 + # Evaluation config + eval_replicates: int = 5 + report_ks: tuple[int, ...] 
= (1, 5, 10) + # Evaluation optimization + eval_batch_size: int = 128 + eval_num_workers: int = 4 + # Max groups for GRPO (from GRPO_MAX_GROUPS) + max_groups: int = 128 + + +cfg = Config() + +# ════════════════════════════════════════════════════════════════════════════ +# SYSTEM PROMPT & TAGS (shared across datasets) +# ════════════════════════════════════════════════════════════════════════════ +REASONING_START_TOKEN = "start_working_out" +REASONING_END_TOKEN = "end_working_out" +SOLUTION_START_TOKEN = "SOLUTION" +SOLUTION_END_TOKEN = "SOLUTION" + +REASONING_START = f"<{REASONING_START_TOKEN}>" +REASONING_END = f"</{REASONING_END_TOKEN}>" +SOLUTION_START = f"<{SOLUTION_START_TOKEN}>" +SOLUTION_END = f"</{SOLUTION_END_TOKEN}>" + +SYSTEM_PROMPT = ( + "You are given a problem.\n" + "Think about the problem and provide your working out.\n" + f"Place it between {REASONING_START} and {REASONING_END}.\n" + f"Then, provide your solution between {SOLUTION_START}{SOLUTION_END}." +) + +QWEN_CHAT_TEMPLATE = build_qwen_chat_template( + system_prompt=SYSTEM_PROMPT, reasoning_start=REASONING_START +) + + +# ════════════════════════════════════════════════════════════════════════════ +# DATASET ADAPTER (Abstract Base + Concrete Implementations) +# ════════════════════════════════════════════════════════════════════════════ +class DatasetAdapter(abc.ABC): + """Abstract base class for dataset adapters. + + Provides unified interface for: + - Loading train/eval datasets + - Parsing gold answers + - Computing rewards + - Determining success threshold + """ + + @property + @abc.abstractmethod + def name(self) -> str: + """Dataset name for logging.""" + ... + + @property + @abc.abstractmethod + def question_field(self) -> str: + """Field name for question/problem text.""" + ... + + @property + @abc.abstractmethod + def answer_field(self) -> str: + """Field name for gold answer.""" + ... + + @property + @abc.abstractmethod + def correctness_weight(self) -> float: + """Weight for correctness component in reward.""" + ... 
+ + @property + @abc.abstractmethod + def success_threshold(self) -> float: + """Reward threshold for success (correctness weight).""" + ... + + @abc.abstractmethod + def load_train_data(self) -> list[dict[str, Any]]: + """Load training data as list of dicts.""" + ... + + @abc.abstractmethod + def load_eval_data(self) -> list[dict[str, Any]]: + """Load evaluation data as list of dicts.""" + ... + + @abc.abstractmethod + def parse_gold_answer(self, raw_answer: str) -> str: + """Extract gold answer from dataset format.""" + ... + + @abc.abstractmethod + def validate_answer(self, predicted: str, gold: str) -> bool: + """Check if predicted answer matches gold.""" + ... + + @abc.abstractmethod + def compute_reward(self, completion: str, gold_answer: str) -> float: + """Compute total reward for completion.""" + ... + + +# ──────────────────────────────────────────────────────────────────────────── +# GSM8K Adapter +# ──────────────────────────────────────────────────────────────────────────── +class GSM8KAdapter(DatasetAdapter): + """GSM8K dataset adapter using GRAIL's GSM8KTaskSource.""" + + # Regex patterns (from gsm8k_env.py) + _HASH_PATTERN = re.compile(r"####\s*(?P<ans>.+)") + _NUMBER_PATTERN = re.compile(r"[-+]?\d+(?:[\.,]\d+)?") + _NUMERIC_ONLY_PATTERN = re.compile(r"^[-+]?[\d.,]+$") + + def __init__(self) -> None: + self._train_source = GSM8KTaskSource(split="train") + self._eval_source = GSM8KTaskSource(split="test") + + @property + def name(self) -> str: + return "gsm8k" + + @property + def question_field(self) -> str: + return "question" + + @property + def answer_field(self) -> str: + return "answer" + + @property + def correctness_weight(self) -> float: + return 0.6 # GSM8K uses 0.6 for correctness + + @property + def success_threshold(self) -> float: + return 0.6 # Success if correctness achieved + + def load_train_data(self) -> list[dict[str, Any]]: + """Load GSM8K training data via task source.""" + self._train_source._ensure_dataset() + assert 
self._train_source._ds is not None + data = [] + for i in range(len(self._train_source._ds)): + sample = self._train_source._ds[i] + data.append( + { + "question": sample["question"], + "answer": sample["answer"], + } + ) + return data + + def load_eval_data(self) -> list[dict[str, Any]]: + """Load GSM8K test data via task source.""" + self._eval_source._ensure_dataset() + assert self._eval_source._ds is not None + data = [] + for i in range(len(self._eval_source._ds)): + sample = self._eval_source._ds[i] + data.append( + { + "question": sample["question"], + "answer": sample["answer"], + } + ) + return data + + def parse_gold_answer(self, raw_answer: str) -> str: + """Parse GSM8K gold answer from #### format.""" + match = None + for m in self._HASH_PATTERN.finditer(raw_answer or ""): + match = m + if match is not None: + return match.group("ans").strip() + nums = list(self._NUMBER_PATTERN.finditer(raw_answer or "")) + if nums: + return nums[-1].group(0).replace(",", "").strip() + return "" + + def validate_answer(self, predicted: str, gold: str) -> bool: + """Validate GSM8K answer (numeric exact match).""" + pred_norm = re.sub(r"[\s\.]+$", "", predicted.strip().lower()) + gold_norm = re.sub(r"[\s\.]+$", "", gold.strip().lower()) + return pred_norm == gold_norm + + def _parse_completion(self, text: str) -> dict[str, Any]: + """Parse completion for thinking/answer tags.""" + flags = re.DOTALL | re.IGNORECASE + has_thinking = bool( + re.search(rf"<{REASONING_START_TOKEN}>.*?</{REASONING_END_TOKEN}>", text, flags) + ) + answer_match = re.search( + rf"<{SOLUTION_START_TOKEN}>\s*(.+?)\s*</{SOLUTION_END_TOKEN}>", text, flags + ) + + answer_text = "" + has_answer = bool(answer_match) + is_numeric_only = False + trailing = 0 + + if answer_match: + inside = answer_match.group(1).strip() + num_match = self._NUMBER_PATTERN.search(inside) + if num_match: + answer_text = num_match.group(0).replace(",", "").strip() + is_numeric_only = bool(self._NUMERIC_ONLY_PATTERN.match(inside.replace(" ", ""))) + trailing = len(text) - 
answer_match.end() + + return { + "answer_text": answer_text, + "has_thinking": has_thinking, + "has_answer": has_answer, + "is_numeric_only": is_numeric_only, + "trailing": trailing, + } + + def compute_reward(self, completion: str, gold_answer: str) -> float: + """Compute GSM8K reward (matching GSM8KEnv weights). + + Components: + - Correctness (0.6): exact match + - Strict format (0.15): numeric-only + no trailing + - Thinking (0.1): has thinking block + - Answer (0.1): has answer block + - No trailing (0.05): penalty for trailing text + """ + parsed = self._parse_completion(completion) + gold_parsed = self.parse_gold_answer(gold_answer) + + # Correctness + correctness = 0.6 if self.validate_answer(parsed["answer_text"], gold_parsed) else 0.0 + + # Strict format + strict_format = ( + 0.15 + if (parsed["has_answer"] and parsed["is_numeric_only"] and parsed["trailing"] == 0) + else 0.0 + ) + + # Thinking format + thinking = 0.1 if parsed["has_thinking"] else 0.0 + + # Answer format + answer = 0.1 if parsed["has_answer"] else 0.0 + + # No trailing + no_trailing = 0.05 if parsed["trailing"] == 0 else 0.0 + + return correctness + strict_format + thinking + answer + no_trailing + + +# ──────────────────────────────────────────────────────────────────────────── +# MATH (Hendrycks) Adapter +# ──────────────────────────────────────────────────────────────────────────── +class MATHAdapter(DatasetAdapter): + """MATH dataset adapter using GRAIL's MATHTaskSource. 
+ + Uses exact same validation logic as MATHEnv: + - Multi-strategy comparison (exact, symbolic via sympy, numeric) + - LaTeX normalization + - Stratified train/val split (500 val samples) + """ + + def __init__(self) -> None: + self._train_source = MATHTaskSource(split="train") + self._eval_source = MATHTaskSource(split="val") # Use stratified val split + + @property + def name(self) -> str: + return "math" + + @property + def question_field(self) -> str: + return "question" # Normalized to 'question' for consistency + + @property + def answer_field(self) -> str: + return "answer" + + @property + def correctness_weight(self) -> float: + return 0.7 # MATH uses 0.7 for correctness + + @property + def success_threshold(self) -> float: + return 0.7 # Success if correctness achieved + + def load_train_data(self) -> list[dict[str, Any]]: + """Load MATH training data via task source (7000 samples).""" + self._train_source._ensure_dataset() + assert self._train_source._data is not None + data = [] + for sample in self._train_source._data: + data.append( + { + "question": sample["problem"], # Normalize field name + "answer": sample["answer"], # Pre-extracted from \boxed{} + "solution": sample["solution"], + "level": sample["level"], + "subject": sample["subject"], + } + ) + return data + + def load_eval_data(self) -> list[dict[str, Any]]: + """Load MATH validation data via task source (500 samples, stratified).""" + self._eval_source._ensure_dataset() + assert self._eval_source._data is not None + data = [] + for sample in self._eval_source._data: + data.append( + { + "question": sample["problem"], + "answer": sample["answer"], + "solution": sample["solution"], + "level": sample["level"], + "subject": sample["subject"], + } + ) + return data + + def parse_gold_answer(self, raw_answer: str) -> str: + """For MATH, answer is already extracted from \\boxed{} by TaskSource.""" + return raw_answer + + def validate_answer(self, predicted: str, gold: str) -> bool: + """Validate 
MATH answer using multi-strategy comparison. + + Uses GRAIL's _math_answers_equal which tries: + 1. Exact match (after LaTeX normalization) + 2. Symbolic equivalence (via sympy) + 3. Numeric comparison (floats) + """ + return _math_answers_equal(predicted, gold) + + def _parse_completion(self, text: str) -> dict[str, Any]: + """Parse completion for thinking/answer tags (MATH-specific).""" + flags = re.DOTALL | re.IGNORECASE + has_thinking = bool( + re.search(rf"<{REASONING_START_TOKEN}>.*?</{REASONING_END_TOKEN}>", text, flags) + ) + answer_match = re.search( + rf"<{SOLUTION_START_TOKEN}>\s*(.+?)\s*</{SOLUTION_END_TOKEN}>", text, flags + ) + + answer_text = "" + has_answer = bool(answer_match) + trailing = 0 + + if answer_match: + answer_text = answer_match.group(1).strip() + trailing = len(text) - answer_match.end() + + return { + "answer_text": answer_text, + "has_thinking": has_thinking, + "has_answer": has_answer, + "trailing": trailing, + } + + def compute_reward(self, completion: str, gold_answer: str) -> float: + """Compute MATH reward (matching MATHEnv weights). 
+ + Components: + - Correctness (0.7): Multi-strategy validation + - Answer format (0.15): Has answer + minimal trailing + - Thinking (0.1): Has thinking block + - No trailing (0.05): Penalty for excessive trailing + """ + parsed = self._parse_completion(completion) + + # Correctness (using multi-strategy validation) + correctness = 0.7 if self.validate_answer(parsed["answer_text"], gold_answer) else 0.0 + + # Answer format (has answer + trailing < 50) + answer_format = 0.15 if (parsed["has_answer"] and parsed["trailing"] < 50) else 0.0 + + # Thinking format + thinking = 0.1 if parsed["has_thinking"] else 0.0 + + # No trailing (stricter check) + no_trailing = 0.05 if parsed["trailing"] == 0 else 0.0 + + return correctness + answer_format + thinking + no_trailing + + +# ════════════════════════════════════════════════════════════════════════════ +# FACTORY FUNCTION +# ════════════════════════════════════════════════════════════════════════════ +def get_dataset_adapter(dataset_name: str) -> DatasetAdapter: + """Factory function to get dataset adapter by name. + + Args: + dataset_name: 'gsm8k' or 'math' + + Returns: + DatasetAdapter instance + + Raises: + ValueError: If dataset_name is not supported + """ + adapters: dict[str, type[DatasetAdapter]] = { + "gsm8k": GSM8KAdapter, + "math": MATHAdapter, + } + + if dataset_name.lower() not in adapters: + raise ValueError(f"Unknown dataset: {dataset_name}. Supported: {list(adapters.keys())}") + + return adapters[dataset_name.lower()]() + + +# ════════════════════════════════════════════════════════════════════════════ +# DATA PREPARATION +# ════════════════════════════════════════════════════════════════════════════ +def prepare_train_dataset(adapter: DatasetAdapter, tokenizer: PreTrainedTokenizer) -> Dataset: + """Load and format training dataset for TRL GRPO. 
+ + Args: + adapter: Dataset adapter instance + tokenizer: Tokenizer for chat template formatting + + Returns: + HuggingFace Dataset with 'prompt' and 'gold_answer' columns + """ + raw_data = adapter.load_train_data() + + if cfg.num_train_samples is not None: + raw_data = raw_data[: cfg.num_train_samples] + + formatted = [] + for sample in raw_data: + question = sample[adapter.question_field] + prompt = tokenizer.apply_chat_template( + [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": question}, + ], + tokenize=False, + add_generation_prompt=True, + ) + formatted.append( + { + "prompt": prompt, + "gold_answer": sample[adapter.answer_field], + } + ) + + print(f" Training dataset ({adapter.name}): {len(formatted)} samples") + return Dataset.from_list(formatted) + + +def prepare_eval_dataset(adapter: DatasetAdapter) -> tuple[Dataset, list[dict[str, Any]]]: + """Load evaluation dataset. + + Args: + adapter: Dataset adapter instance + + Returns: + Tuple of (HuggingFace Dataset, raw data list for reward computation) + """ + raw_data = adapter.load_eval_data() + + if cfg.num_eval_samples is not None: + raw_data = raw_data[: cfg.num_eval_samples] + + print(f" Eval dataset ({adapter.name}): {len(raw_data)} samples") + return Dataset.from_list(raw_data), raw_data + + +# ════════════════════════════════════════════════════════════════════════════ +# VLLM EVALUATION CALLBACK +# ════════════════════════════════════════════════════════════════════════════ +class VLLMEvalCallback(TrainerCallback): + """Evaluation callback using TRL vLLM server with dataset adapter.""" + + def __init__( + self, + adapter: DatasetAdapter, + eval_data: list[dict[str, Any]], + tokenizer: PreTrainedTokenizer, + vllm_base_url: str, + eval_every_n_steps: int = 30, + ) -> None: + self.adapter = adapter + self.eval_data = eval_data + self.tokenizer = tokenizer + self.eval_every_n = eval_every_n_steps + self.base_url = vllm_base_url.rstrip("/") + self._metrics_defined = False 
+ + print( + f"✓ VLLMEvalCallback initialized: dataset={adapter.name}, " + f"url={vllm_base_url}, eval_every={eval_every_n_steps}" + ) + + def run_and_log(self, step: int, label: str = "VLLM EVAL") -> dict[str, float]: + """Run evaluation and log to WandB.""" + print(f"\n{'=' * 80}") + print(f"[{label}] Step {step}: Starting {self.adapter.name.upper()} evaluation...") + print(f"{'=' * 80}") + + metrics = asyncio.run(self._run_eval()) + + try: + import wandb + + if wandb.run is not None: + if not self._metrics_defined: + wandb.define_metric("eval_step") + wandb.define_metric("eval_vllm/*", step_metric="eval_step") + self._metrics_defined = True + + wandb_data = { + "eval_step": step, + "trainer/global_step": step, + } + wandb_data.update({f"eval_vllm/{k}": v for k, v in metrics.items()}) + wandb.log(wandb_data) + except Exception as e: + print(f"⚠️ WandB logging failed: {e}") + + print(f"[{label}] Results: {metrics}") + print(f"{'=' * 80}\n") + return metrics + + def on_step_end(self, args: Any, state: Any, control: Any, **kwargs: Any) -> None: + """Run evaluation every N steps.""" + if state.global_step >= self.eval_every_n and state.global_step % self.eval_every_n == 0: + self.run_and_log(state.global_step) + + async def _run_eval(self) -> dict[str, float]: + """Run evaluation using vLLM chat completions API.""" + import time + + from tqdm import tqdm + + start_time = time.time() + aggregator = KMetricsAggregator(report_ks=cfg.report_ks) + + total_tasks = len(self.eval_data) + batch_size = cfg.eval_batch_size + + with tqdm(total=total_tasks, desc=f"Eval ({self.adapter.name})", unit="task") as pbar: + for batch_start in range(0, total_tasks, batch_size): + batch_end = min(batch_start + batch_size, total_tasks) + batch = self.eval_data[batch_start:batch_end] + + # Get questions using adapter's field name + batch_questions = [s[self.adapter.question_field] for s in batch] + batch_golds = [s[self.adapter.answer_field] for s in batch] + + # Expand: each question gets N 
replicates + tasks_to_generate = [] + task_metadata = [] + + for idx, question in enumerate(batch_questions): + task_id = f"q{batch_start + idx}" + for rep_idx in range(cfg.eval_replicates): + tasks_to_generate.append(question) + task_metadata.append( + { + "task_id": task_id, + "task_idx": idx, + "replicate_idx": rep_idx, + } + ) + + # Generate completions + completions = await self._generate_batch(tasks_to_generate) + + # Log sample completions + if batch_start == 0: + print("\n ━━━ Sample Completions ━━━") + for i in range(min(3, len(completions))): + question = tasks_to_generate[i] + completion = completions[i] + metadata = task_metadata[i] + gold = batch_golds[metadata["task_idx"]] + reward = self.adapter.compute_reward(completion, gold) + + q_display = question[:150] + "..." if len(question) > 150 else question + c_display = ( + completion[:300] + "..." if len(completion) > 300 else completion + ) + print(f"\n Sample {i + 1}:") + print(f" Question: {q_display}") + print(f" Completion: {c_display}") + print(f" Reward: {reward:.3f} | Gold: {gold[:50]}...") + print(" ━━━━━━━━━━━━━━━━━━━━━━━━━\n") + + # Compute rewards and aggregate + for completion_text, metadata in zip(completions, task_metadata, strict=False): + task_id = metadata["task_id"] + task_idx = metadata["task_idx"] + replicate_idx = metadata["replicate_idx"] + gold = batch_golds[task_idx] + + reward = self.adapter.compute_reward(completion_text, gold) + success = reward >= self.adapter.success_threshold + + aggregator.add( + TaskReplicateResult( + task_id=task_id, + replicate_idx=replicate_idx, + reward=reward, + success=success, + ) + ) + + pbar.update(len(batch_questions)) + + metrics = aggregator.summarize() + elapsed = time.time() - start_time + throughput = (total_tasks * cfg.eval_replicates) / elapsed if elapsed > 0 else 0 + + print( + f" ✓ Evaluated {total_tasks} tasks × {cfg.eval_replicates} reps in {elapsed:.2f}s " + f"({throughput:.1f} completions/sec)" + ) + + return metrics + + async def 
_generate_batch(self, questions: list[str]) -> list[str]: + """Generate completions using TRL /chat/ endpoint with batching.""" + import asyncio + + import aiohttp + + vllm_batch_size = 64 + total = len(questions) + num_requests = (total + vllm_batch_size - 1) // vllm_batch_size + print(f" Generating {total} completions via {num_requests} batched requests") + + async def generate_batch_request( + session: aiohttp.ClientSession, batch_questions: list[str], start_idx: int + ) -> tuple[int, list[list[int]]]: + max_retries = 3 + base_backoff = 1.0 + + messages = [ + [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": q}, + ] + for q in batch_questions + ] + + payload = { + "messages": messages, + "max_tokens": cfg.max_new_tokens, + "temperature": cfg.temperature, + "top_p": cfg.top_p, + "top_k": cfg.top_k, + "repetition_penalty": 1.1, + "n": 1, + } + + for attempt in range(max_retries): + try: + async with session.post( + f"{self.base_url}/chat/", + json=payload, + timeout=aiohttp.ClientTimeout(total=300.0), + ) as response: + if response.status == 200: + data = await response.json() + return (start_idx, data["completion_ids"]) + else: + error_text = await response.text() + raise Exception(f"HTTP {response.status}: {error_text}") + except Exception as e: + if attempt < max_retries - 1: + backoff = base_backoff * (2**attempt) + await asyncio.sleep(backoff) + else: + print(f" ⚠️ Batch {start_idx} failed: {type(e).__name__}") + return (start_idx, [[] for _ in batch_questions]) + return (start_idx, [[] for _ in batch_questions]) + + async with aiohttp.ClientSession() as session: + tasks = [] + for batch_start in range(0, total, vllm_batch_size): + batch_end = min(batch_start + vllm_batch_size, total) + batch_questions = questions[batch_start:batch_end] + tasks.append(generate_batch_request(session, batch_questions, batch_start)) + + results = await asyncio.gather(*tasks, return_exceptions=False) + + all_completion_ids: list[list[int]] = [[] for 
_ in range(total)] + for start_idx, completion_ids_batch in results: + for offset, comp_ids in enumerate(completion_ids_batch): + all_completion_ids[start_idx + offset] = comp_ids + + completions = [] + for comp_ids in all_completion_ids: + if comp_ids: + completion_text = self.tokenizer.decode(comp_ids, skip_special_tokens=True) + completions.append(completion_text) + else: + completions.append("") + + return completions + + +# ════════════════════════════════════════════════════════════════════════════ +# MAIN TRAINING +# ════════════════════════════════════════════════════════════════════════════ +def parse_args() -> argparse.Namespace: + """Parse command line arguments.""" + parser = argparse.ArgumentParser(description="TRL GRPO training with GSM8K or MATH dataset") + parser.add_argument( + "--dataset", + type=str, + default="gsm8k", + choices=["gsm8k", "math"], + help="Dataset to use for training (default: gsm8k)", + ) + parser.add_argument( + "--eval-every", + type=int, + default=30, + help="Run evaluation every N steps (default: 30)", + ) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + + print(f"🚀 Starting TRL GRPO training with {args.dataset.upper()} dataset") + print("=" * 60) + + # Get dataset adapter + adapter = get_dataset_adapter(args.dataset) + print(f" Dataset: {adapter.name}") + print(f" Correctness weight: {adapter.correctness_weight}") + print(f" Success threshold: {adapter.success_threshold}") + + # Load model and tokenizer + print("\n📦 Loading model and tokenizer...") + try: + model = AutoModelForCausalLM.from_pretrained( + cfg.model_id, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + ) + except (ImportError, RuntimeError) as e: + print(f"⚠️ Flash Attention 2 unavailable ({type(e).__name__}), using default") + model = AutoModelForCausalLM.from_pretrained( + cfg.model_id, + torch_dtype=torch.bfloat16, + ) + + tokenizer = AutoTokenizer.from_pretrained(cfg.model_id) + tokenizer.pad_token = 
tokenizer.eos_token + tokenizer.padding_side = "left" + tokenizer.chat_template = QWEN_CHAT_TEMPLATE + + # Prepare datasets + print("\n📊 Preparing datasets...") + train_ds = prepare_train_dataset(adapter, tokenizer) + eval_ds, eval_data = prepare_eval_dataset(adapter) + prompt_to_answer = {row["prompt"]: row["gold_answer"] for row in train_ds} + + # WandB setup + print("\n⚙️ Configuring GRPO trainer...") + import wandb + + wandb_api_key = os.getenv("WANDB_API_KEY") + if wandb_api_key: + wandb.login(key=wandb_api_key) + print(f" ✓ WandB logged in (project: {os.getenv('WANDB_PROJECT', 'grail')})") + + # Calculate max_prompt_length + max_prompt_length = cfg.max_length - cfg.max_new_tokens + + grpo_config = GRPOConfig( + output_dir=f"./outputs/trl_{adapter.name}", + learning_rate=cfg.lr, + num_train_epochs=cfg.epochs, + per_device_train_batch_size=cfg.batch_size, + gradient_accumulation_steps=cfg.grad_accum_steps, + max_grad_norm=cfg.grad_clip, + warmup_steps=cfg.warmup_steps, + beta=cfg.kl_coef, + epsilon=cfg.ppo_clip_eps, + epsilon_high=cfg.ppo_clip_eps_upper, + max_prompt_length=max_prompt_length, + max_completion_length=cfg.max_new_tokens, + temperature=cfg.temperature, + top_p=cfg.top_p, + top_k=cfg.top_k, + repetition_penalty=1.1, + num_generations=cfg.rollouts_per_problem, + generation_batch_size=16, + steps_per_generation=None, + logging_steps=1, + log_completions=True, + num_completions_to_print=1, + wandb_log_unique_prompts=True, + save_strategy="no", + bf16=True, + report_to=["wandb"], + eval_strategy="no", + run_name=f"trl_{adapter.name}_grpo_qwen15b_env_matched", + loss_type="dapo", + use_vllm=True, + vllm_mode="server", + vllm_server_base_url="http://127.0.0.1:8000", + vllm_importance_sampling_correction=False, + vllm_importance_sampling_cap=cfg.is_ratio_max, + ) + + # Create reward function using adapter + def reward_fn(completions: list[str], prompts: list[str], **kwargs: Any) -> list[float]: + if "gold_answer" in kwargs and kwargs["gold_answer"]: + 
golds = kwargs["gold_answer"] + return [adapter.compute_reward(c, g) for c, g in zip(completions, golds, strict=False)] + if "metadatas" in kwargs and kwargs["metadatas"]: + golds = [m.get("gold_answer", "") for m in kwargs["metadatas"]] + return [adapter.compute_reward(c, g) for c, g in zip(completions, golds, strict=False)] + golds = [prompt_to_answer.get(p, "") for p in prompts] + return [adapter.compute_reward(c, g) for c, g in zip(completions, golds, strict=False)] + + print(f"\n🏋️ Training with GRPO on {adapter.name.upper()}...") + + # Initialize evaluation callback + vllm_eval_callback = VLLMEvalCallback( + adapter=adapter, + eval_data=eval_data, + tokenizer=tokenizer, + vllm_base_url=grpo_config.vllm_server_base_url, + eval_every_n_steps=args.eval_every, + ) + + trainer = GRPOTrainer( + model=model, + reward_funcs=reward_fn, + args=grpo_config, + train_dataset=train_ds, + processing_class=tokenizer, + callbacks=[vllm_eval_callback], + ) + + # Baseline evaluation + vllm_eval_callback.run_and_log(step=0, label="BASELINE EVAL") + + # Train + trainer.train() + + # Final evaluation + final_step = trainer.state.global_step if hasattr(trainer, "state") else 9999 + final_metrics = vllm_eval_callback.run_and_log(step=final_step, label="FINAL EVAL") + + # Print summary + print("\n" + "=" * 60) + print(f"FINAL RESULTS SUMMARY ({adapter.name.upper()})") + print("=" * 60) + for k in cfg.report_ks: + if k > cfg.eval_replicates: + continue + print(f"\nMetrics @ k={k}:") + print(f" pass@{k}: {final_metrics[f'pass@{k}']:.3f}") + print(f" pass_ordered@{k}: {final_metrics[f'pass_ordered@{k}']:.3f}") + print(f" mean@{k}: {final_metrics[f'mean@{k}']:.3f}") + print(f" best@{k}: {final_metrics[f'best@{k}']:.3f}") + print("\nGlobal metrics:") + print(f" reward_mean_all: {final_metrics['reward_mean_all']:.3f}") + print(f" success_rate_all: {final_metrics['success_rate_all']:.3f}") + + +if __name__ == "__main__": + main() diff --git a/research/trl/train_trl_grpo_README.md 
b/research/trl/train_trl_grpo_README.md new file mode 100644 index 00000000..12d4b53b --- /dev/null +++ b/research/trl/train_trl_grpo_README.md @@ -0,0 +1,136 @@ +# TRL GRPO Training Script + +Unified TRL GRPO training script supporting both GSM8K and MATH (Hendrycks) datasets with exact parity to GRAIL environment implementations. + +## Quickstart + +### 1. Launch the vLLM Server (Generation GPUs) + +The vLLM server handles rollout generation on separate GPUs while the trainer runs on its own GPU. + +```bash +# Activate vLLM environment +source tools/vllm-server/.venv/bin/activate + +# Launch vLLM server on GPUs 1-4 (4-way tensor parallel) +CUDA_VISIBLE_DEVICES=1,2,3,4 nohup trl vllm-serve \ + --model Qwen/Qwen2.5-1.5B-Instruct \ + --tensor-parallel-size 4 \ + --host 127.0.0.1 \ + --port 8000 \ + --gpu-memory-utilization 0.9 \ + > vllm_server.log 2>&1 & + +# Wait for server to be ready (check logs) +tail -f vllm_server.log +``` + +### 2. Start GRPO Training (Training GPU) + +```bash +# Train on GSM8K (default) +CUDA_VISIBLE_DEVICES=0 nohup python research/trl/train_trl_grpo.py \ + --dataset gsm8k \ + > research/trl/train_gsm8k.log 2>&1 & + +# Train on MATH (Hendrycks) +CUDA_VISIBLE_DEVICES=0 nohup python research/trl/train_trl_grpo.py \ + --dataset math \ + > research/trl/train_math.log 2>&1 & + +# Custom eval frequency +CUDA_VISIBLE_DEVICES=0 python research/trl/train_trl_grpo.py \ + --dataset math \ + --eval-every 50 +``` + +Training logs stream to the respective log files. 
+ +## Features + +- **Factory Pattern**: Easy switching between datasets via `--dataset` CLI flag +- **GRAIL Parity**: Uses exact same task sources, validation logic, and reward weights +- **Multi-Strategy Validation** (MATH): Exact match → Symbolic (sympy) → Numeric +- **Stratified Splits** (MATH): 7,000 train / 500 val (stratified by subject) +- **vLLM Evaluation**: Async batched evaluation with KMetrics aggregation + +## Dataset Comparison + +| Aspect | GSM8K | MATH | +|--------|-------|------| +| **Train Size** | 7,473 | 7,000 | +| **Eval Size** | 1,319 (test) | 500 (stratified val) | +| **Gold Format** | `#### answer` | `\boxed{answer}` | +| **Validation** | Numeric exact | Multi-strategy (exact/sympy/numeric) | +| **Correctness Weight** | 0.6 | 0.7 | +| **Success Threshold** | ≥0.6 | ≥0.7 | + +## Reward Components + +### GSM8K (Total: 1.0) +| Component | Weight | Description | +|-----------|--------|-------------| +| Correctness | 0.6 | Exact numeric match | +| Strict format | 0.15 | Numeric-only + no trailing | +| Thinking | 0.1 | Has reasoning block | +| Answer | 0.1 | Has solution tags | +| No trailing | 0.05 | No text after answer | + +### MATH (Total: 1.0) +| Component | Weight | Description | +|-----------|--------|-------------| +| Correctness | 0.7 | Multi-strategy validation | +| Answer format | 0.15 | Has answer + trailing < 50 chars | +| Thinking | 0.1 | Has reasoning block | +| No trailing | 0.05 | No text after answer | + +## Hyperparameters (from .env) + +| Parameter | Value | Source | +|-----------|-------|--------| +| Learning rate | 3e-6 | `GRAIL_TRAINER_LR` | +| Epochs | 1 | `GRAIL_TRAINER_EPOCHS` | +| Batch size | 4 | `GRAIL_TRAINER_BATCH_SIZE` | +| Grad accum | 128 | `GRAIL_TRAINER_GRAD_ACCUM_STEPS` | +| Max length | 2048 | `GRAIL_TRAINER_MAX_LENGTH` | +| Max completion | 1024 | `GRPO_MAX_COMPLETION_TOKENS` | +| Loss type | dapo | `GRAIL_GRPO_VARIANT` | + +## Architecture + +``` +train_trl_grpo.py +├── DatasetAdapter (ABC) +│ ├── 
GSM8KAdapter # Uses GSM8KTaskSource from GRAIL +│ └── MATHAdapter # Uses MATHTaskSource from GRAIL +├── get_dataset_adapter() # Factory function +├── VLLMEvalCallback # Dataset-agnostic evaluation +└── main() # CLI entry point +``` + +## GPU Layout (Example: 8x A100) + +``` +┌─────────────────────────────────────────────────────────────┐ +│ GPU 0: Training (GRPO backward pass) │ +├─────────────────────────────────────────────────────────────┤ +│ GPUs 1-4: vLLM Server (4-way tensor parallel generation) │ +├─────────────────────────────────────────────────────────────┤ +│ GPUs 5-7: Available for other tasks │ +└─────────────────────────────────────────────────────────────┘ +``` + +## Requirements + +- TRL with vLLM support (`pip install trl[vllm]`) +- GRAIL codebase (for task sources and validation logic) +- vLLM server running on port 8000 +- Flash Attention 2 (optional, for faster training) + +## Files + +| File | Description | +|------|-------------| +| `train_trl_grpo.py` | Main training script (unified GSM8K + MATH) | +| `train_trl_grpo_README.md` | This documentation | +| `train_trl_gsm8k.py` | Legacy GSM8K-only script (deprecated) | From 5c04954804536a37f5f2b4355ed43266c07abcc2 Mon Sep 17 00:00:00 2001 From: ErfanMhi Date: Fri, 28 Nov 2025 18:27:12 +0000 Subject: [PATCH 05/17] chore(config): update hyperparameters in TRL training script for GSM8K, including learning rate, batch size, and max sequence length --- research/trl/train_trl_gsm8k.py | 80 ++++++++++++++++++++++++--------- 1 file changed, 58 insertions(+), 22 deletions(-) diff --git a/research/trl/train_trl_gsm8k.py b/research/trl/train_trl_gsm8k.py index 75826710..5a8b4018 100644 --- a/research/trl/train_trl_gsm8k.py +++ b/research/trl/train_trl_gsm8k.py @@ -32,35 +32,55 @@ sys.path.append("/root/grail") -# ──────────────── HYPERPARAMETERS (from GRAIL config) ──────────────── +# ──────────────── HYPERPARAMETERS (from .env GRAIL config) ──────────────── @dataclass class Config: + # Model (from 
GRAIL_TRAIN_MODEL_ID) model_id: str = "Qwen/Qwen2.5-1.5B-Instruct" - lr: float = 2e-6 - epochs: int = 2 - batch_size: int = 8 # 16 groups (prompts) per step - grad_accum_steps: int = 32 - max_length: int = 1536 + # Learning rate (from GRAIL_TRAINER_LR) + lr: float = 3e-6 + # Epochs per window (from GRAIL_TRAINER_EPOCHS) + epochs: int = 1 + # Batch size (from GRAIL_TRAINER_BATCH_SIZE) + batch_size: int = 4 + # Gradient accumulation (from GRAIL_TRAINER_GRAD_ACCUM_STEPS) + grad_accum_steps: int = 128 + # Max sequence length (from GRAIL_TRAINER_MAX_LENGTH) + max_length: int = 2048 + # Gradient clipping (from GRAIL_TRAINER_GRAD_CLIP) grad_clip: float = 1.0 + # Warmup steps (from GRAIL_TRAINER_WARMUP_STEPS) warmup_steps: int = 50 + # KL coefficient (from GRAIL_TRAINER_KL_COEF) kl_coef: float = 0.0 + # Entropy coefficient (from GRAIL_TRAINER_ENTROPY_COEF) entropy_coef: float = 0.0005 + # PPO clip epsilon (standard GRAIL values) ppo_clip_eps: float = 0.2 ppo_clip_eps_upper: float = 0.28 + # Importance sampling ratio max (from GRAIL_TRAINER_IS_RATIO_MAX) is_ratio_max: float = 2.5 + # Log-ratio clamp (from GRAIL_TRAINER_LOGRATIO_CLAMP) logratio_clamp: float = 0.92 + # Dataset sampling num_train_samples: int | None = None # None = use all training samples num_eval_samples: int | None = None # None = use all test samples + # Rollouts per problem (matches GRAIL default) rollouts_per_problem: int = 16 + # Generation parameters temperature: float = 0.7 top_p: float = 0.95 top_k: int = 50 - max_new_tokens: int = 512 + # Max completion tokens (from GRPO_MAX_COMPLETION_TOKENS) + max_new_tokens: int = 1024 + # Evaluation config eval_replicates: int = 5 report_ks: tuple = (1, 5, 10) # Evaluation optimization (for multi-GPU with 8 A100s) eval_batch_size: int = 128 # Large batch for parallel generation eval_num_workers: int = 4 # Dataloader workers + # Max groups for GRPO (from GRPO_MAX_GROUPS) + max_groups: int = 128 cfg = Config() @@ -511,29 +531,43 @@ def main() -> None: 
wandb.login(key=wandb_api_key) print(f" ✓ WandB logged in (project: {os.getenv('WANDB_PROJECT', 'grail')})") + # Calculate max_prompt_length: total max_length minus max_completion_tokens + max_prompt_length = cfg.max_length - cfg.max_new_tokens # 2048 - 1024 = 1024 + grpo_config = GRPOConfig( output_dir="./outputs/trl_gsm8k", + # Learning rate (GRAIL_TRAINER_LR=3e-6) learning_rate=cfg.lr, + # Epochs (GRAIL_TRAINER_EPOCHS=1) num_train_epochs=cfg.epochs, + # Batch size (GRAIL_TRAINER_BATCH_SIZE=4) per_device_train_batch_size=cfg.batch_size, + # Gradient accumulation (GRAIL_TRAINER_GRAD_ACCUM_STEPS=128) gradient_accumulation_steps=cfg.grad_accum_steps, + # Gradient clipping (GRAIL_TRAINER_GRAD_CLIP=1.0) max_grad_norm=cfg.grad_clip, + # Warmup steps (GRAIL_TRAINER_WARMUP_STEPS=50) warmup_steps=cfg.warmup_steps, - beta=cfg.kl_coef, # Beta is KL coefficient in GRPO - epsilon=cfg.ppo_clip_eps, # PPO epsilon - epsilon_high=cfg.ppo_clip_eps_upper, # Upper PPO epsilon - max_prompt_length=512, # Reasonable prompt limit - max_completion_length=cfg.max_new_tokens, # Max new tokens + # KL coefficient (GRAIL_TRAINER_KL_COEF=0.0) + beta=cfg.kl_coef, + # PPO clip epsilon + epsilon=cfg.ppo_clip_eps, + epsilon_high=cfg.ppo_clip_eps_upper, + # Max prompt length (derived from GRAIL_TRAINER_MAX_LENGTH - GRPO_MAX_COMPLETION_TOKENS) + max_prompt_length=max_prompt_length, + # Max completion tokens (GRPO_MAX_COMPLETION_TOKENS=1024) + max_completion_length=cfg.max_new_tokens, + # Generation parameters temperature=cfg.temperature, top_p=cfg.top_p, - top_k=cfg.top_k, # Match loop.py: 50 highest probability tokens - repetition_penalty=1.1, # Match loop.py: penalize repeating tokens - num_generations=16, # group size: 16 completions per prompt - generation_batch_size=16, # 64 prompts per generation batch + top_k=cfg.top_k, + repetition_penalty=1.1, + # Group size: 16 completions per prompt (rollouts_per_problem) + num_generations=cfg.rollouts_per_problem, + generation_batch_size=16, 
steps_per_generation=None, logging_steps=1, - # Enable logging a small sample of (prompt, completion) pairs each logging step. - # Prints to console if `rich` is installed and logs a WandB table named "completions". + # Enable logging a sample of (prompt, completion) pairs each logging step log_completions=True, num_completions_to_print=1, wandb_log_unique_prompts=True, @@ -541,14 +575,16 @@ def main() -> None: bf16=True, report_to=["wandb"], eval_strategy="no", # Disable TRL's internal eval (using VLLMEvalCallback instead) - run_name="trl_gsm8k_grpo_qwen15b_g16x16_vllm", - loss_type="dapo", # Match config.py GRPO_VARIANT + run_name="trl_gsm8k_grpo_qwen15b_env_matched", + # Loss type (GRAIL_GRPO_VARIANT=dapo) + loss_type="dapo", # vLLM configuration for offloading generation to separate GPUs use_vllm=True, vllm_mode="server", vllm_server_base_url="http://127.0.0.1:8000", - vllm_importance_sampling_correction=False, # Correct for vLLM/training distribution mismatch - vllm_importance_sampling_cap=2.0, # Cap importance sampling ratio for stability + # Importance sampling (GRAIL_TRAINER_IS_RATIO_MAX=2.5) + vllm_importance_sampling_correction=False, + vllm_importance_sampling_cap=cfg.is_ratio_max, ) # Reward function wrapper From 8f24bc0c5dd147318c81bb7e66ccfb5e1784f7f1 Mon Sep 17 00:00:00 2001 From: Erfan Miahi Date: Thu, 27 Nov 2025 06:16:14 +0000 Subject: [PATCH 06/17] feat(evaluation): add eval_pass_at_k script for computing pass@k metrics on MATH dataset on the test set --- eval_pass_at_k.py | 190 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 190 insertions(+) create mode 100644 eval_pass_at_k.py diff --git a/eval_pass_at_k.py b/eval_pass_at_k.py new file mode 100644 index 00000000..f898c640 --- /dev/null +++ b/eval_pass_at_k.py @@ -0,0 +1,190 @@ +#!/usr/bin/env python3 +"""Compute pass@1, pass@5, pass@10 on MATH using vLLM with multiple samples.""" + +import argparse +import json +import re +from datetime import datetime +from pathlib import Path 
+ +import numpy as np +from datasets import load_dataset +from tqdm import tqdm + + +def pass_at_k(n: int, c: int, k: int) -> float: + """Unbiased pass@k estimator from Codex paper. + + Args: + n: total number of samples + c: number of correct samples + k: k in pass@k + """ + if n - c < k: + return 1.0 + return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1)) + + +def extract_boxed_answer(text: str) -> str | None: + """Extract answer from \\boxed{...} format.""" + # Find the last \boxed{...} + matches = re.findall(r"\\boxed\{([^{}]*(?:\{[^{}]*\}[^{}]*)*)\}", text) + if matches: + return matches[-1].strip() + return None + + +def normalize_answer(answer: str) -> str: + """Normalize answer for comparison.""" + if answer is None: + return "" + # Remove whitespace and common LaTeX formatting + answer = answer.strip() + answer = answer.replace(" ", "") + answer = answer.replace("\\,", "") + answer = answer.replace("\\!", "") + return answer + + +def is_correct(pred: str, target: str) -> bool: + """Check if prediction matches target.""" + pred_norm = normalize_answer(extract_boxed_answer(pred) or pred) + target_norm = normalize_answer(extract_boxed_answer(target) or target) + return pred_norm == target_norm + + +def build_prompt(problem: str, few_shot_examples: list[dict] = None) -> str: + """Build prompt for MATH problem (zero-shot or few-shot).""" + prompt = "" + if few_shot_examples: + for ex in few_shot_examples: + prompt += f"Problem: {ex['problem']}\nSolution: {ex['solution']}\n\n" + prompt += f"Problem: {problem}\nSolution:" + return prompt + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--model", default="Qwen/Qwen2.5-1.5B-Instruct") + parser.add_argument("--n-samples", type=int, default=5, help="Samples per problem (>= max k)") + parser.add_argument("--temperature", type=float, default=0.7) + parser.add_argument("--max-tokens", type=int, default=2048) + parser.add_argument("--num-fewshot", type=int, default=0) + 
parser.add_argument("--limit", type=int, default=None, help="Limit problems (for testing)") + parser.add_argument("--tensor-parallel-size", type=int, default=4) + parser.add_argument("--output-dir", type=Path, default=Path("eval_results")) + args = parser.parse_args() + + # Import vLLM + from vllm import LLM, SamplingParams + + print(f"Loading model: {args.model}") + llm = LLM( + model=args.model, + tensor_parallel_size=args.tensor_parallel_size, + dtype="bfloat16", + trust_remote_code=True, + gpu_memory_utilization=0.95, + max_model_len=4096, + max_num_seqs=512, # More concurrent sequences + enable_prefix_caching=True, + ) + + # Load MATH dataset (all 7 subjects) + print("Loading MATH dataset...") + subjects = [ + "algebra", + "counting_and_probability", + "geometry", + "intermediate_algebra", + "number_theory", + "prealgebra", + "precalculus", + ] + + # Load all subjects + test_data = [] + for subject in subjects: + test_data.extend(load_dataset("EleutherAI/hendrycks_math", subject, split="test")) + + dataset = test_data[: args.limit] if args.limit else test_data + print(f"Loaded {len(dataset)} test problems ({args.num_fewshot}-shot)") + + # Load few-shot examples if needed + few_shot_examples = [] + if args.num_fewshot > 0: + train_data = load_dataset("EleutherAI/hendrycks_math", "algebra", split="train") + few_shot_examples = [train_data[i] for i in range(args.num_fewshot)] + + # Build prompts + prompts = [build_prompt(ex["problem"], few_shot_examples) for ex in dataset] + targets = [ex["solution"] for ex in dataset] + + # Sampling params for multiple samples + sampling_params = SamplingParams( + temperature=args.temperature, + max_tokens=args.max_tokens, + n=args.n_samples, # Generate n samples per prompt + ) + + print(f"Generating {args.n_samples} samples per problem for {len(prompts)} problems...") + outputs = llm.generate(prompts, sampling_params) + + # Evaluate + results = [] + for _idx, (output, target) in enumerate( + tqdm(zip(outputs, targets, 
strict=False), total=len(outputs)) + ): + # Check each sample + correct_count = sum( + 1 for completion in output.outputs if is_correct(completion.text, target) + ) + results.append( + { + "n_samples": args.n_samples, + "n_correct": correct_count, + } + ) + + # Compute pass@k for k=1,5 + k_values = [1, 5] + pass_at_k_results = {} + + for k in k_values: + if k <= args.n_samples: + scores = [pass_at_k(r["n_samples"], r["n_correct"], k) for r in results] + pass_at_k_results[f"pass@{k}"] = np.mean(scores) * 100 + + # Print results + print("\n" + "=" * 50) + print(f"MATH pass@k Results - {args.model}") + print(f"Temperature: {args.temperature}, Samples: {args.n_samples}") + print("=" * 50) + for k, score in pass_at_k_results.items(): + print(f" {k}: {score:.2f}%") + print("=" * 50) + + # Save results + args.output_dir.mkdir(parents=True, exist_ok=True) + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + output_file = args.output_dir / f"pass_at_k_{timestamp}.json" + + with open(output_file, "w") as f: + json.dump( + { + "model": args.model, + "temperature": args.temperature, + "n_samples": args.n_samples, + "num_problems": len(results), + "results": pass_at_k_results, + "per_problem": results, + }, + f, + indent=2, + ) + + print(f"Results saved to: {output_file}") + + +if __name__ == "__main__": + main() From 9649399c396d37a53b0f03144ccc2150c0c04ad5 Mon Sep 17 00:00:00 2001 From: Erfan Miahi Date: Thu, 27 Nov 2025 06:16:53 +0000 Subject: [PATCH 07/17] feat(cli): add parallel mining functionality with multi-GPU support and result aggregation --- grail/cli/__init__.py | 1 + grail/cli/mine.py | 40 +- grail/cli/multi_miner_aggregator.py | 390 +++++++++++++ grail/cli/multi_miner_config.py | 264 +++++++++ grail/cli/parallel_miner.py | 857 ++++++++++++++++++++++++++++ 5 files changed, 1549 insertions(+), 3 deletions(-) create mode 100644 grail/cli/multi_miner_aggregator.py create mode 100644 grail/cli/multi_miner_config.py create mode 100644 grail/cli/parallel_miner.py 
diff --git a/grail/cli/__init__.py b/grail/cli/__init__.py index fe906480..0c91f5d2 100644 --- a/grail/cli/__init__.py +++ b/grail/cli/__init__.py @@ -329,6 +329,7 @@ def _register_subcommands() -> None: "grail.cli.mine", "grail.cli.validate", "grail.cli.train", + "grail.cli.parallel_miner", ): module = importlib.import_module(mod_name) register: Callable[[typer.Typer], None] | None = getattr(module, "register", None) diff --git a/grail/cli/mine.py b/grail/cli/mine.py index c71e265d..c4145cdd 100644 --- a/grail/cli/mine.py +++ b/grail/cli/mine.py @@ -539,6 +539,9 @@ async def generate_rollouts_for_window( monitor: Any | None, use_drand: bool, checkpoint_window: int, + *, + problem_offset: int = 0, + max_problems: int = 0, ) -> list[dict]: """Generate as many GRPO rollouts as safely possible within a window. @@ -559,11 +562,31 @@ async def generate_rollouts_for_window( timers: EMA-based timing estimates for safety. monitor: Optional monitoring client for metrics. use_drand: Whether drand was used in randomness generation. - checkpoint_window: The checkpoint window used for this generation + checkpoint_window: The checkpoint window used for this generation. + problem_offset: Starting problem index for this worker (default: 0). + Used in parallel mining to assign non-overlapping problem ranges. + max_problems: Maximum number of problems to generate (default: 0 = unlimited). + When 0, generates until time runs out. Used in parallel mining. Returns: List of signed rollout data ready for upload. 
""" + # Read problem offset/max from environment (worker mode support) + # Environment variables take precedence over function args for subprocess isolation + env_problem_offset = int(os.getenv("GRAIL_PROBLEM_OFFSET", str(problem_offset))) + env_max_problems = int(os.getenv("GRAIL_MAX_PROBLEMS", str(max_problems))) + + # Use env values if set, otherwise use function args + effective_offset = env_problem_offset + effective_max = env_max_problems + + if effective_offset > 0 or effective_max > 0: + logger.info( + "Worker mode: problem_offset=%d, max_problems=%s", + effective_offset, + effective_max if effective_max > 0 else "unlimited", + ) + # Window generation state and metrics inferences: list[dict] = [] start_time = time.time() @@ -598,6 +621,14 @@ async def generate_rollouts_for_window( logger.info("Window %s has ended, moving to next window", window_start) break + # Check max_problems limit (for worker mode) + if effective_max > 0 and problem_count >= effective_max: + logger.info( + "Stopping generation: reached max_problems limit (%d)", + effective_max, + ) + break + blocks_remaining = (window_start + WINDOW_LENGTH) - current_block needed_blocks = timers.blocks_needed_for_next_gen() if blocks_remaining <= needed_blocks: @@ -619,9 +650,13 @@ async def generate_rollouts_for_window( problem_count += 1 inference_count += 1 + # Apply problem offset for parallel mining coordination + # Each GPU worker gets a unique range: GPU0=[0-11], GPU1=[12-23], etc. 
+ problem_index = effective_offset + (problem_count - 1) + logger.info( "⚡ Generating GRPO rollouts for problem %s (block %s/%s)...", - problem_count, + problem_index, current_block, window_start + WINDOW_LENGTH - 1, ) @@ -635,7 +670,6 @@ async def generate_rollouts_for_window( ) # Deterministically derive environment seed from miner+window+index - problem_index = max(0, problem_count - 1) seed_int = derive_env_seed(wallet.hotkey.ss58_address, window_block_hash, problem_index) # Use deterministic problem index as rollout_group identifier base_nonce = problem_index diff --git a/grail/cli/multi_miner_aggregator.py b/grail/cli/multi_miner_aggregator.py new file mode 100644 index 00000000..c998af3e --- /dev/null +++ b/grail/cli/multi_miner_aggregator.py @@ -0,0 +1,390 @@ +#!/usr/bin/env python3 +""" +Multi-Miner Aggregator for GRAIL + +Coordinates multiple miners running on the same machine and aggregates +their results into a single window upload to R2. + +Usage: + python -m grail.cli.multi_miner_aggregator \ + --hotkeys miner_1 miner_2 miner_3 miner_4 \ + --aggregation-hotkey aggregator_hotkey \ + --mode watch # or 'batch' +""" + +import asyncio +import json +import logging +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import bittensor as bt +import typer + +from ..infrastructure.comms import ( + upload_file_chunked, +) +from ..infrastructure.credentials import load_r2_credentials +from ..shared.constants import WINDOW_LENGTH +from . 
import console + +logger = logging.getLogger("grail.aggregator") + + +# --------------------------------------------------------------------------- # +# Configuration & State # +# --------------------------------------------------------------------------- # + + +@dataclass +class AggregatorConfig: + """Configuration for multi-miner aggregation.""" + + hotkeys: list[str] + aggregation_hotkey: str + cold_wallet: str = "default" + results_dir: Path = Path("/tmp/grail_miner_results") + poll_interval: float = 5.0 # seconds between polls + window_timeout: float = 300.0 # seconds to wait for all miners per window + credentials: Any | None = None + + +class WindowAggregator: + """Aggregates results from multiple miners for a specific window.""" + + def __init__(self, window_start: int, config: AggregatorConfig): + self.window_start = window_start + self.config = config + self.results: dict[str, list[dict]] = {hotkey: [] for hotkey in config.hotkeys} + self.collected_hotkeys: set[str] = set() + self.start_time = time.time() + + async def collect_results(self) -> dict[str, list[dict]]: + """Poll for results from all miners until timeout or all collected.""" + logger.info( + f"🔍 Collecting results for window {self.window_start} " + f"from {len(self.config.hotkeys)} miners..." 
+ ) + + while time.time() - self.start_time < self.config.window_timeout: + # Check each miner's result directory + for hotkey in self.config.hotkeys: + if hotkey in self.collected_hotkeys: + continue # Already collected + + result_file = self._get_result_path(hotkey) + if result_file.exists(): + try: + inferences = await self._load_and_parse(result_file, hotkey) + self.results[hotkey] = inferences + self.collected_hotkeys.add(hotkey) + logger.info(f" ✓ {hotkey}: {len(inferences)} inferences collected") + except Exception as e: + logger.warning(f" ✗ {hotkey}: Failed to load results - {e}") + + # Check if we have all results + if len(self.collected_hotkeys) == len(self.config.hotkeys): + logger.info( + f"✅ All {len(self.config.hotkeys)} miners reported for window " + f"{self.window_start}" + ) + break + + # Log progress + elapsed = time.time() - self.start_time + remaining = self.config.window_timeout - elapsed + pending = len(self.config.hotkeys) - len(self.collected_hotkeys) + if pending > 0: + logger.debug(f" ⏳ Waiting for {pending} miners ({remaining:.0f}s remaining)...") + + await asyncio.sleep(self.config.poll_interval) + + # Log final status + if len(self.collected_hotkeys) < len(self.config.hotkeys): + missing = set(self.config.hotkeys) - self.collected_hotkeys + logger.warning( + f"⚠️ Timeout: Missing results from {missing}. 
" + f"Uploading partial results ({len(self.collected_hotkeys)}/{len(self.config.hotkeys)})" + ) + + return self.results + + async def aggregate_and_upload(self, wallet: bt.wallet) -> bool: + """Aggregate all collected results and upload to R2.""" + # Flatten all inferences + all_inferences: list[dict] = [] + for inferences in self.results.values(): + all_inferences.extend(inferences) + + if not all_inferences: + logger.warning(f"No inferences to upload for window {self.window_start}") + return False + + # Create window data with aggregation metadata + window_data = { + "wallet": wallet.hotkey.ss58_address, + "window_start": self.window_start, + "window_length": WINDOW_LENGTH, + "inference_count": len(all_inferences), + "inferences": all_inferences, + "timestamp": time.time(), + "aggregated": True, + "miner_count": len(self.collected_hotkeys), + "miner_hotkeys": list(self.collected_hotkeys), + "collection_time_seconds": time.time() - self.start_time, + } + + # Upload to R2 + key = ( + f"grail/windows/aggregated/{wallet.hotkey.ss58_address}-window-{self.window_start}.json" + ) + body = json.dumps(window_data).encode() + + logger.info( + f"📤 Uploading aggregated window {self.window_start} " + f"({len(all_inferences)} inferences from {len(self.collected_hotkeys)} miners)..." 
+ ) + + success = await upload_file_chunked( + key, + body, + credentials=self.config.credentials, + use_write=True, + ) + + if success: + logger.info(f"✅ Successfully uploaded aggregated window {self.window_start} to R2") + # Clean up local result files + await self._cleanup_results() + else: + logger.error(f"❌ Failed to upload aggregated window {self.window_start}") + + return success + + def _get_result_path(self, hotkey: str) -> Path: + """Get path where miner should write results.""" + return self.config.results_dir / f"{hotkey}-window-{self.window_start}.json" + + async def _load_and_parse(self, result_file: Path, hotkey: str) -> list[dict]: + """Load and parse inferences from result file.""" + try: + with open(result_file) as f: + data = json.load(f) + inferences = data.get("inferences", []) + if not isinstance(inferences, list): + raise ValueError(f"Expected list of inferences, got {type(inferences)}") + return inferences + except Exception as e: + logger.debug(f"Failed to parse {result_file}: {e}") + raise + + async def _cleanup_results(self) -> None: + """Remove processed result files.""" + for hotkey in self.collected_hotkeys: + result_file = self._get_result_path(hotkey) + try: + if result_file.exists(): + result_file.unlink() + logger.debug(f"Cleaned up {result_file}") + except Exception as e: + logger.warning(f"Failed to cleanup {result_file}: {e}") + + +class MultiMinerAggregatorService: + """Main service for coordinating multi-miner aggregation.""" + + def __init__(self, config: AggregatorConfig): + self.config = config + self.config.results_dir.mkdir(parents=True, exist_ok=True) + self.stop_event = asyncio.Event() + + async def watch_and_aggregate(self) -> None: + """Watch for window completions and aggregate results.""" + logger.info(f"🚀 Starting multi-miner aggregator for {len(self.config.hotkeys)} miners") + logger.info(f" Miners: {', '.join(self.config.hotkeys)}") + logger.info(f" Results directory: {self.config.results_dir}") + logger.info(f" 
Poll interval: {self.config.poll_interval}s") + logger.info(f" Window timeout: {self.config.window_timeout}s") + + wallet = bt.wallet(name=self.config.cold_wallet, hotkey=self.config.aggregation_hotkey) + last_window = -1 + + try: + while not self.stop_event.is_set(): + # Get current window + subtensor = bt.subtensor() + current_block = await asyncio.to_thread(subtensor.get_current_block) + current_window = (current_block // WINDOW_LENGTH) * WINDOW_LENGTH + + # New window detected + if current_window > last_window: + logger.info(f"📍 New window detected: {current_window} (block {current_block})") + last_window = current_window + + # Process previous window if we have results + if current_window > WINDOW_LENGTH: + prev_window = current_window - WINDOW_LENGTH + await self._process_window(wallet, prev_window) + + await asyncio.sleep(self.config.poll_interval) + + except KeyboardInterrupt: + logger.info("Stopping aggregator...") + except Exception as e: + logger.error(f"Error in aggregator: {e}", exc_info=True) + raise + + async def _process_window(self, wallet: bt.wallet, window_start: int) -> None: + """Process and upload a specific window.""" + aggregator = WindowAggregator(window_start, self.config) + results = await aggregator.collect_results() + + # Check if we have any results + total_inferences = sum(len(inf) for inf in results.values()) + if total_inferences == 0: + logger.info(f"⊘ No results for window {window_start}, skipping") + return + + # Upload aggregated results + await aggregator.aggregate_and_upload(wallet) + + async def batch_process_window(self, window_start: int) -> bool: + """Process a single window in batch mode.""" + wallet = bt.wallet(name=self.config.cold_wallet, hotkey=self.config.aggregation_hotkey) + await self._process_window(wallet, window_start) + return True + + +# --------------------------------------------------------------------------- # +# CLI Interface # +# 
--------------------------------------------------------------------------- # + + +def register(app: typer.Typer) -> None: + """Register aggregator command with CLI.""" + app.command("aggregate")(aggregate) + + +def aggregate( + hotkeys: list[str] = typer.Option( + ..., + "--hotkey", + help="Miner hotkeys to aggregate (can specify multiple times)", + ), + aggregation_hotkey: str = typer.Option( + ..., + "--aggregation-hotkey", + help="Hotkey to use for uploading aggregated results", + ), + cold_wallet: str = typer.Option( + "default", + "--cold-wallet", + help="Cold wallet name", + ), + results_dir: str = typer.Option( + "/tmp/grail_miner_results", + "--results-dir", + help="Directory where miners write results", + ), + poll_interval: float = typer.Option( + 5.0, + "--poll-interval", + help="Seconds between polls for new results", + ), + window_timeout: float = typer.Option( + 300.0, + "--window-timeout", + help="Seconds to wait for all miners per window", + ), + mode: str = typer.Option( + "watch", + "--mode", + help="'watch' for continuous monitoring or 'batch' for single window", + ), + window: int | None = typer.Option( + None, + "--window", + help="Window to process (required for batch mode)", + ), +) -> None: + """Aggregate results from multiple miners and upload to R2. 
+ + Example: + python -m grail.cli.multi_miner_aggregator \ + --hotkey miner_1 --hotkey miner_2 --hotkey miner_3 \ + --aggregation-hotkey aggregator \ + --mode watch + + python -m grail.cli.multi_miner_aggregator \ + --hotkey miner_1 --hotkey miner_2 \ + --aggregation-hotkey aggregator \ + --mode batch --window 12345 + """ + try: + # Validate inputs + if not hotkeys: + console.print("[red]Error: At least one --hotkey must be specified[/red]") + raise typer.Exit(code=1) + + if mode not in ("watch", "batch"): + console.print(f"[red]Error: mode must be 'watch' or 'batch', got {mode}[/red]") + raise typer.Exit(code=1) + + if mode == "batch" and window is None: + console.print("[red]Error: --window required for batch mode[/red]") + raise typer.Exit(code=1) + + # Load credentials + try: + credentials = load_r2_credentials() + except Exception as e: + console.print(f"[red]Failed to load R2 credentials: {e}[/red]") + raise typer.Exit(code=1) from None + + # Create config + config = AggregatorConfig( + hotkeys=hotkeys, + aggregation_hotkey=aggregation_hotkey, + cold_wallet=cold_wallet, + results_dir=Path(results_dir), + poll_interval=poll_interval, + window_timeout=window_timeout, + credentials=credentials, + ) + + # Run aggregator + service = MultiMinerAggregatorService(config) + + if mode == "watch": + console.print("[bold green]Starting multi-miner aggregator in watch mode[/bold green]") + asyncio.run(service.watch_and_aggregate()) + else: # batch + console.print(f"[bold green]Processing window {window}[/bold green]") + asyncio.run(service.batch_process_window(window)) + + except KeyboardInterrupt: + console.print("[yellow]Aggregator stopped by user[/yellow]") + raise typer.Exit(code=0) from None + except Exception as e: + logger.error(f"Fatal error: {e}", exc_info=True) + console.print(f"[red]Fatal error: {e}[/red]") + raise typer.Exit(code=1) from None + + +# --------------------------------------------------------------------------- # +# Main Entry Point # +# 
--------------------------------------------------------------------------- # + + +def main() -> None: + """Main entry point for aggregator CLI.""" + + app = typer.Typer() + register(app) + app() + + +if __name__ == "__main__": + main() diff --git a/grail/cli/multi_miner_config.py b/grail/cli/multi_miner_config.py new file mode 100644 index 00000000..0ead8088 --- /dev/null +++ b/grail/cli/multi_miner_config.py @@ -0,0 +1,264 @@ +""" +Multi-Miner Configuration and Helper Utilities + +Provides common configurations and helper functions for running multiple +miners on the same machine with window-based result aggregation. +""" + +import os +from dataclasses import dataclass +from pathlib import Path + + +@dataclass +class MinerConfig: + """Configuration for a single miner instance.""" + + hotkey: str + gpu_index: int | None = None + batch_size: int = 2 + safety_blocks: int = 3 + use_drand: bool = True + extra_env: dict[str, str] | None = None + + +@dataclass +class MultiMinerSetup: + """Complete setup for multiple miners.""" + + miners: list[MinerConfig] + cold_wallet: str = "default" + use_aggregator: bool = True + aggregator_hotkey: str | None = None + results_directory: Path = Path("/tmp/grail_miner_results") + poll_interval_seconds: float = 5.0 + window_timeout_seconds: float = 300.0 + + def __post_init__(self) -> None: + if not self.miners: + raise ValueError("At least one miner config is required") + if self.use_aggregator and not self.aggregator_hotkey: + import time + + self.aggregator_hotkey = f"aggregator_{int(time.time())}" + + +class MultiMinerBuilder: + """Builder for creating multi-miner configurations.""" + + @staticmethod + def from_hotkeys( + hotkeys: list[str], + gpus: list[int] | None = None, + batch_size: int = 2, + use_aggregator: bool = True, + ) -> MultiMinerSetup: + """Create multi-miner setup from list of hotkeys and optional GPU assignments. 
+ + Args: + hotkeys: List of miner hotkeys + gpus: Optional list of GPU indices (cycles if fewer than hotkeys) + batch_size: Generation batch size per miner + use_aggregator: Whether to enable result aggregation + + Returns: + MultiMinerSetup ready to launch + """ + if not hotkeys: + raise ValueError("At least one hotkey required") + + # Build miner configs + miners = [] + for i, hotkey in enumerate(hotkeys): + gpu = gpus[i % len(gpus)] if gpus else None + miners.append( + MinerConfig( + hotkey=hotkey, + gpu_index=gpu, + batch_size=batch_size, + ) + ) + + return MultiMinerSetup( + miners=miners, + use_aggregator=use_aggregator, + ) + + @staticmethod + def from_environment() -> MultiMinerSetup: + """Create multi-miner setup from environment variables. + + Environment variables: + GRAIL_MINERS: Comma-separated hotkey list (e.g., "miner_1,miner_2,miner_3") + GRAIL_GPUS: Comma-separated GPU indices (optional, e.g., "0,1,2") + GRAIL_BATCH_SIZE: Generation batch size (default: 2) + GRAIL_USE_AGGREGATOR: "true" or "false" (default: true) + GRAIL_AGGREGATOR_HOTKEY: Aggregator identity (auto-generated if not set) + GRAIL_RESULTS_DIR: Results directory (default: /tmp/grail_miner_results) + + Returns: + MultiMinerSetup from environment configuration + """ + # Parse miners + miners_str = os.getenv("GRAIL_MINERS", "miner_1") + hotkeys = [h.strip() for h in miners_str.split(",") if h.strip()] + + if not hotkeys: + raise ValueError("GRAIL_MINERS environment variable is empty") + + # Parse GPUs (optional) + gpus_str = os.getenv("GRAIL_GPUS", "") + gpus = None + if gpus_str: + gpus = [int(g.strip()) for g in gpus_str.split(",") if g.strip()] + + # Other settings + batch_size = int(os.getenv("GRAIL_BATCH_SIZE", "2")) + use_aggregator = os.getenv("GRAIL_USE_AGGREGATOR", "true").lower() in ( + "true", + "1", + "yes", + ) + aggregator_hotkey = os.getenv("GRAIL_AGGREGATOR_HOTKEY", None) + results_dir = Path(os.getenv("GRAIL_RESULTS_DIR", "/tmp/grail_miner_results")) + + setup = 
MultiMinerBuilder.from_hotkeys( + hotkeys=hotkeys, + gpus=gpus, + batch_size=batch_size, + use_aggregator=use_aggregator, + ) + + if aggregator_hotkey: + setup.aggregator_hotkey = aggregator_hotkey + + setup.results_directory = results_dir + + return setup + + +class MinerLauncher: + """Helper for launching miner processes with proper environment.""" + + @staticmethod + def get_env_for_miner(config: MinerConfig, cold_wallet: str = "default") -> dict[str, str]: + """Get environment variables for a miner process. + + Args: + config: MinerConfig for this miner + cold_wallet: Cold wallet name + + Returns: + Dictionary of environment variables to set + """ + env = os.environ.copy() + + # Set wallet + env["BT_WALLET_COLD"] = cold_wallet + env["BT_WALLET_HOT"] = config.hotkey + + # Set GPU if specified + if config.gpu_index is not None: + env["CUDA_VISIBLE_DEVICES"] = str(config.gpu_index) + else: + # Remove GPU constraint if not specified + env.pop("CUDA_VISIBLE_DEVICES", None) + + # Set generation parameters + env["GRAIL_GENERATION_BATCH_SIZE"] = str(config.batch_size) + env["GRAIL_MINER_SAFETY_BLOCKS"] = str(config.safety_blocks) + + # Add any extra environment variables + if config.extra_env: + env.update(config.extra_env) + + return env + + @staticmethod + def get_command_for_miner(config: MinerConfig) -> list[str]: + """Get command to launch a miner. + + Args: + config: MinerConfig for this miner + + Returns: + Command as list of strings (suitable for subprocess) + """ + return [ + "python", + "-m", + "grail.cli.mine", + "--use-drand" if config.use_drand else "--no-drand", + ] + + +class AggregatorLauncher: + """Helper for launching aggregator with proper arguments.""" + + @staticmethod + def get_command_for_aggregator(setup: MultiMinerSetup, mode: str = "watch") -> list[str]: + """Get command to launch aggregator. 
+ + Args: + setup: MultiMinerSetup configuration + mode: "watch" or "batch" + + Returns: + Command as list of strings + """ + if not setup.aggregator_hotkey: + raise ValueError("aggregator_hotkey not set") + + cmd = [ + "python", + "-m", + "grail.cli.multi_miner_aggregator", + ] + + # Add miner hotkeys + for miner in setup.miners: + cmd.extend(["--hotkey", miner.hotkey]) + + # Add aggregator settings + cmd.extend( + [ + "--aggregation-hotkey", + setup.aggregator_hotkey, + "--cold-wallet", + setup.cold_wallet, + "--results-dir", + str(setup.results_directory), + "--poll-interval", + str(setup.poll_interval_seconds), + "--window-timeout", + str(setup.window_timeout_seconds), + "--mode", + mode, + ] + ) + + return cmd + + +def print_setup_summary(setup: MultiMinerSetup) -> None: + """Pretty-print the multi-miner setup configuration.""" + print("\n" + "=" * 60) + print("Multi-Miner Setup Configuration") + print("=" * 60) + + print(f"\n📊 Miners: {len(setup.miners)}") + for i, miner in enumerate(setup.miners, 1): + gpu_info = f"GPU {miner.gpu_index}" if miner.gpu_index is not None else "Any GPU" + print(f" {i}. {miner.hotkey:20s} [{gpu_info}] batch_size={miner.batch_size}") + + print(f"\n💼 Wallet: {setup.cold_wallet}") + + if setup.use_aggregator: + print(f"\n🔄 Aggregator: {setup.aggregator_hotkey}") + print(f" Poll interval: {setup.poll_interval_seconds}s") + print(f" Window timeout: {setup.window_timeout_seconds}s") + else: + print("\n🔄 Aggregator: Disabled") + + print(f"\n📁 Results directory: {setup.results_directory}") + print("\n" + "=" * 60 + "\n") diff --git a/grail/cli/parallel_miner.py b/grail/cli/parallel_miner.py new file mode 100644 index 00000000..8c2e49a0 --- /dev/null +++ b/grail/cli/parallel_miner.py @@ -0,0 +1,857 @@ +#!/usr/bin/env python3 +""" +Parallel Multi-GPU Miner for GRAIL + +Coordinates multiple GPU workers to generate rollouts in parallel, with each GPU +handling a distinct range of problem IDs. 
All results are gathered before a +single upload to maximize throughput while maintaining submission integrity. + +Architecture: + ┌─────────────────────────────────────────────────────────────┐ + │ Coordinator Process │ + │ - Assigns problem ranges: GPU0=[0-11], GPU1=[12-23], ... │ + │ - Spawns N worker processes │ + │ - Gathers results via temp files │ + │ - Single sink_window_inferences() call │ + └──────────────────────────┬──────────────────────────────────┘ + │ + ┌─────────┬─────────┬─┴─────────┬─────────┐ + ▼ ▼ ▼ ▼ ▼ + ┌──────┐ ┌──────┐ ┌──────┐ ┌──────┐ ┌──────┐ + │GPU 0 │ │GPU 1 │ │GPU 2 │ ... │GPU N │ │GPU N │ + │P:0-11│ │P:12-23│ │P:24-35│ │ │ │ │ + └──────┘ └──────┘ └──────┘ └──────┘ └──────┘ + +Usage: + python -m grail.cli.parallel_miner --num-gpus 8 --problems-per-gpu 12 +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import multiprocessing as mp +import os +import tempfile +import time +from dataclasses import dataclass, field +from pathlib import Path +from queue import Empty +from typing import Any + +import bittensor as bt +import torch +import typer + +from ..infrastructure.credentials import load_r2_credentials +from ..shared.constants import WINDOW_LENGTH +from . 
import console + +logger = logging.getLogger("grail.parallel_miner") + + +# --------------------------------------------------------------------------- # +# Configuration # +# --------------------------------------------------------------------------- # + + +@dataclass +class GPUWorkerConfig: + """Configuration for a single GPU worker process.""" + + gpu_id: int + problem_offset: int + max_problems: int + results_dir: Path + window_start: int + window_block_hash: str + combined_randomness: str + use_drand: bool + checkpoint_path: str | None + # Wallet names read from environment in worker for subprocess isolation + batch_size: int = 2 + safety_blocks: int = 3 + + +@dataclass +class ParallelMinerConfig: + """Configuration for parallel multi-GPU mining.""" + + num_gpus: int = 8 + problems_per_gpu: int = 12 + batch_size: int = 2 + safety_blocks: int = 3 + use_drand: bool = True + results_dir: Path = field( + default_factory=lambda: Path(tempfile.mkdtemp(prefix="grail_parallel_")) + ) + worker_timeout: float = 600.0 # 10 minutes max per window + gpu_ids: list[int] | None = None # Specific GPU IDs to use, None = [0, 1, ..., num_gpus-1] + + def get_gpu_ids(self) -> list[int]: + """Return list of GPU IDs to use.""" + if self.gpu_ids is not None: + return self.gpu_ids + return list(range(self.num_gpus)) + + +# --------------------------------------------------------------------------- # +# GPU Worker Process # +# --------------------------------------------------------------------------- # + + +def _gpu_worker_main( + config: GPUWorkerConfig, + result_queue: mp.Queue, +) -> None: + """Main function for GPU worker process. + + This runs in a separate process with CUDA_VISIBLE_DEVICES set to the + assigned GPU. It generates rollouts for a specific problem range and + writes results to a temp file. 
+ + Args: + config: Worker configuration with GPU assignment and problem range + result_queue: Queue to signal completion status back to coordinator + """ + worker_id = f"GPU-{config.gpu_id}" + start_time = time.time() + + # Configure logging for worker process + import logging + + logging.basicConfig( + level=logging.INFO, + format=f"%(asctime)s [{worker_id}] %(levelname)s: %(message)s", + datefmt="%H:%M:%S", + ) + worker_logger = logging.getLogger(f"grail.worker.{config.gpu_id}") + + try: + # Set GPU visibility BEFORE any CUDA operations + os.environ["CUDA_VISIBLE_DEVICES"] = str(config.gpu_id) + + # Import heavy modules after setting CUDA_VISIBLE_DEVICES + from ..cli.mine import ( + MiningTimers, + package_rollout_data, + ) + from ..environments.factory import create_env + from ..environments.loop import AgentEnvLoop + from ..grail import derive_env_seed + from ..model.provider import get_model, get_tokenizer + from ..shared.constants import ROLLOUTS_PER_PROBLEM + + worker_logger.info( + "Starting worker: problems %d-%d on GPU %d", + config.problem_offset, + config.problem_offset + config.max_problems - 1, + config.gpu_id, + ) + + # Load wallet from environment (same env as coordinator) + coldkey = os.getenv("BT_WALLET_COLD", "default") + hotkey = os.getenv("BT_WALLET_HOT", "default") + wallet = bt.wallet(name=coldkey, hotkey=hotkey) + + # Load model and tokenizer + if config.checkpoint_path: + model = get_model(config.checkpoint_path, device="cuda", eval_mode=True) + tokenizer = get_tokenizer(config.checkpoint_path) + else: + raise RuntimeError("checkpoint_path is required for parallel mining") + + device = model.device + loop = AgentEnvLoop(model, tokenizer, str(device)) + + # Generate rollouts for assigned problem range + inferences: list[dict] = [] + timers = MiningTimers() + + for local_idx in range(config.max_problems): + problem_index = config.problem_offset + local_idx + gen_start = time.time() + + # Derive deterministic seed for this problem + seed_int 
= derive_env_seed( + wallet.hotkey.ss58_address, + config.window_block_hash, + problem_index, + ) + + worker_logger.debug( + "Generating problem %d (seed=%d)", + problem_index, + seed_int, + ) + + # Generate GRPO rollouts + def _env_factory(): + return create_env() + + grpo_rollouts = loop.run_grpo_group( + _env_factory, + ROLLOUTS_PER_PROBLEM, + config.combined_randomness, + wallet, + batch_size=config.batch_size, + seed=seed_int, + ) + + # Package rollouts with signatures + base_nonce = problem_index + for rollout_idx, rollout in enumerate(grpo_rollouts): + rollout_data = package_rollout_data( + model, + wallet, + rollout, + base_nonce, + rollout_idx, + len(grpo_rollouts), + config.window_start, + config.window_start, # current_block = window_start for parallel + config.window_block_hash, + config.combined_randomness, + config.use_drand, + ) + inferences.append(rollout_data) + + gen_duration = time.time() - gen_start + timers.update_gen_time_ema(gen_duration) + + worker_logger.info( + "Problem %d: %d rollouts in %.2fs", + problem_index, + len(grpo_rollouts), + gen_duration, + ) + + # Write results to temp file + result_file = config.results_dir / f"gpu_{config.gpu_id}_results.json" + result_data = { + "gpu_id": config.gpu_id, + "problem_offset": config.problem_offset, + "max_problems": config.max_problems, + "inference_count": len(inferences), + "inferences": inferences, + "duration_seconds": time.time() - start_time, + } + + with open(result_file, "w") as f: + json.dump(result_data, f) + + worker_logger.info( + "Completed: %d rollouts from %d problems in %.2fs", + len(inferences), + config.max_problems, + time.time() - start_time, + ) + + # Signal success + result_queue.put( + { + "gpu_id": config.gpu_id, + "status": "success", + "inference_count": len(inferences), + "result_file": str(result_file), + "duration": time.time() - start_time, + } + ) + + except Exception as e: + worker_logger.exception("Worker failed: %s", e) + result_queue.put( + { + "gpu_id": 
config.gpu_id, + "status": "error", + "error": str(e), + "duration": time.time() - start_time, + } + ) + + +# --------------------------------------------------------------------------- # +# Parallel Mining Coordinator # +# --------------------------------------------------------------------------- # + + +class ParallelMiningCoordinator: + """Coordinates parallel rollout generation across multiple GPUs. + + Responsibilities: + - Spawn GPU worker processes with non-overlapping problem ranges + - Monitor worker progress and handle failures + - Gather all results and perform single aggregated upload + - Clean up temp files after successful upload + """ + + def __init__( + self, + config: ParallelMinerConfig, + wallet: bt.wallet, + credentials: Any, + ) -> None: + self.config = config + self.wallet = wallet + self.credentials = credentials + self._workers: list[mp.Process] = [] + # Use spawn context for CUDA-safe queue + try: + mp.set_start_method("spawn", force=True) + except RuntimeError: + pass # Already set + self._ctx = mp.get_context("spawn") + self._result_queue: mp.Queue = self._ctx.Queue() + self._shutdown_requested = False + + async def mine_window( + self, + window_start: int, + window_block_hash: str, + combined_randomness: str, + checkpoint_path: str | None, + ) -> list[dict]: + """Generate rollouts for a window using all GPUs in parallel. 
+ + Args: + window_start: Start block of the mining window + window_block_hash: Block hash at window start + combined_randomness: Combined randomness for proof generation + checkpoint_path: Path to model checkpoint + + Returns: + Combined list of all rollout inferences from all GPUs + """ + gpu_ids = self.config.get_gpu_ids() + total_problems = self.config.problems_per_gpu * len(gpu_ids) + + logger.info( + "🚀 Starting parallel mining: %d GPUs × %d problems = %d total problems", + len(gpu_ids), + self.config.problems_per_gpu, + total_problems, + ) + + # Ensure results directory exists + self.config.results_dir.mkdir(parents=True, exist_ok=True) + + # Create worker configs with non-overlapping problem ranges + worker_configs: list[GPUWorkerConfig] = [] + for idx, gpu_id in enumerate(gpu_ids): + problem_offset = idx * self.config.problems_per_gpu + worker_config = GPUWorkerConfig( + gpu_id=gpu_id, + problem_offset=problem_offset, + max_problems=self.config.problems_per_gpu, + results_dir=self.config.results_dir, + window_start=window_start, + window_block_hash=window_block_hash, + combined_randomness=combined_randomness, + use_drand=self.config.use_drand, + checkpoint_path=checkpoint_path, + batch_size=self.config.batch_size, + safety_blocks=self.config.safety_blocks, + ) + worker_configs.append(worker_config) + + # Spawn worker processes using 'spawn' method for CUDA compatibility + # This ensures each worker gets a fresh CUDA context without conflicts + start_time = time.time() + self._workers = [] + + for worker_config in worker_configs: + # Use spawn context to avoid CUDA context issues + proc = self._ctx.Process( + target=_gpu_worker_main, + args=(worker_config, self._result_queue), + daemon=True, + ) + proc.start() + self._workers.append(proc) + logger.info( + " Started worker PID %d for GPU %d (problems %d-%d)", + proc.pid, + worker_config.gpu_id, + worker_config.problem_offset, + worker_config.problem_offset + worker_config.max_problems - 1, + ) + + # Wait for 
all workers to complete + results = await self._wait_for_workers(len(gpu_ids)) + + # Gather and combine results - ALL workers must succeed + all_inferences, all_succeeded = await self._gather_results(results, len(gpu_ids)) + + elapsed = time.time() - start_time + + if not all_succeeded: + logger.error("❌ Parallel mining FAILED: Not all GPUs completed successfully") + logger.error( + "Returning empty results to prevent partial upload that would fail validation" + ) + return [] # Return empty to prevent upload + + # Verify expected rollout count + from ..shared.constants import ROLLOUTS_PER_PROBLEM + + expected_rollouts = len(gpu_ids) * self.config.problems_per_gpu * ROLLOUTS_PER_PROBLEM + if len(all_inferences) != expected_rollouts: + logger.error( + "❌ Rollout count mismatch: got %d, expected %d (%d GPUs × %d problems × %d rollouts)", + len(all_inferences), + expected_rollouts, + len(gpu_ids), + self.config.problems_per_gpu, + ROLLOUTS_PER_PROBLEM, + ) + logger.error("Returning empty results to prevent validation failure") + return [] + + logger.info( + "✅ Parallel mining complete: %d rollouts in %.2fs (%.1f rollouts/sec)", + len(all_inferences), + elapsed, + len(all_inferences) / elapsed if elapsed > 0 else 0, + ) + + return all_inferences + + async def _wait_for_workers(self, expected_count: int) -> list[dict]: + """Wait for all worker processes to complete. 
+ + Args: + expected_count: Number of workers expected to complete + + Returns: + List of result dictionaries from each worker + """ + results: list[dict] = [] + deadline = time.time() + self.config.worker_timeout + + while len(results) < expected_count and time.time() < deadline: + try: + # Non-blocking check with timeout + result = await asyncio.to_thread( + self._result_queue.get, + timeout=5.0, + ) + results.append(result) + + if result["status"] == "success": + logger.info( + " GPU %d completed: %d rollouts in %.2fs", + result["gpu_id"], + result["inference_count"], + result["duration"], + ) + else: + logger.error( + " GPU %d failed: %s", + result["gpu_id"], + result.get("error", "unknown error"), + ) + + except Empty: + # Check if any workers have crashed + alive_count = sum(1 for w in self._workers if w.is_alive()) + if alive_count == 0 and len(results) < expected_count: + logger.error("All workers have exited but not all reported results") + break + continue + + # Terminate any remaining workers + for worker in self._workers: + if worker.is_alive(): + logger.warning("Terminating hung worker PID %d", worker.pid) + worker.terminate() + worker.join(timeout=5.0) + + return results + + async def _gather_results( + self, worker_results: list[dict], expected_gpu_count: int + ) -> tuple[list[dict], bool]: + """Gather and combine results from all workers. + + CRITICAL: All workers must succeed for upload to proceed. + Missing any problem ID will cause validator proof failure. 
+ + Args: + worker_results: List of worker result status dictionaries + expected_gpu_count: Number of GPUs that must succeed + + Returns: + Tuple of (combined inferences, all_succeeded) + """ + all_inferences: list[dict] = [] + successful_gpus = 0 + failed_gpus: list[int] = [] + + for result in worker_results: + if result["status"] != "success": + failed_gpus.append(result["gpu_id"]) + logger.error( + "GPU %d FAILED: %s - Cannot upload partial results!", + result["gpu_id"], + result.get("error", "unknown error"), + ) + continue + + result_file = Path(result["result_file"]) + if not result_file.exists(): + failed_gpus.append(result["gpu_id"]) + logger.error( + "GPU %d result file missing: %s - Cannot upload partial results!", + result["gpu_id"], + result_file, + ) + continue + + try: + with open(result_file) as f: + data = json.load(f) + inferences = data.get("inferences", []) + all_inferences.extend(inferences) + successful_gpus += 1 + logger.info( + " GPU %d: %d rollouts collected", + result["gpu_id"], + len(inferences), + ) + + # Clean up temp file + result_file.unlink() + + except Exception as e: + failed_gpus.append(result["gpu_id"]) + logger.error("Failed to read results from GPU %d: %s", result["gpu_id"], e) + + # Check if ALL workers succeeded + all_succeeded = (successful_gpus == expected_gpu_count) and len(failed_gpus) == 0 + + if all_succeeded: + logger.info( + "✅ All %d GPUs succeeded: %d total rollouts ready for upload", + successful_gpus, + len(all_inferences), + ) + else: + logger.error( + "❌ INCOMPLETE: Only %d/%d GPUs succeeded. Failed GPUs: %s", + successful_gpus, + expected_gpu_count, + failed_gpus, + ) + logger.error( + "Cannot upload partial results - validator would reject due to missing problem IDs!" 
+ ) + + return all_inferences, all_succeeded + + def cleanup(self) -> None: + """Clean up resources and temp files.""" + # Terminate any remaining workers + for worker in self._workers: + if worker.is_alive(): + worker.terminate() + worker.join(timeout=2.0) + + # Clean up results directory + try: + if self.config.results_dir.exists(): + for f in self.config.results_dir.iterdir(): + f.unlink() + self.config.results_dir.rmdir() + except Exception as e: + logger.debug("Cleanup error (non-fatal): %s", e) + + +# --------------------------------------------------------------------------- # +# CLI Interface # +# --------------------------------------------------------------------------- # + + +async def run_parallel_miner( + config: ParallelMinerConfig, + use_drand: bool = True, +) -> None: + """Main entry point for parallel multi-GPU mining. + + Args: + config: Parallel mining configuration + use_drand: Whether to use drand for randomness + """ + from types import SimpleNamespace + + from ..cli.mine import ( + MiningTimers, + calculate_window_start, + get_conf, + get_window_randomness, + upload_inferences_with_metrics, + ) + from ..infrastructure.chain import GrailChainManager + from ..infrastructure.checkpoints import CheckpointManager, default_checkpoint_cache_root + from ..shared.constants import TRAINER_UID + + # Load configuration + coldkey = get_conf("BT_WALLET_COLD", "default") + hotkey = get_conf("BT_WALLET_HOT", "default") + wallet = bt.wallet(name=coldkey, hotkey=hotkey) + + logger.info("🔑 Parallel Miner hotkey: %s", wallet.hotkey.ss58_address) + logger.info(" GPUs: %d, Problems/GPU: %d", config.num_gpus, config.problems_per_gpu) + + # Load credentials + credentials = load_r2_credentials() + logger.info("✅ Loaded R2 credentials") + + # Initialize async subtensor (grail uses async bittensor wrapper) + from ..infrastructure.network import create_subtensor + + subtensor = await create_subtensor() + netuid = int(get_conf("BT_NETUID", get_conf("NETUID", 200))) + + # 
Get metagraph using async subtensor + metagraph = await subtensor.metagraph(netuid) + + # Initialize chain manager for credential commitments + chain_config = SimpleNamespace(netuid=netuid) + chain_manager = GrailChainManager(chain_config, wallet, metagraph, subtensor, credentials) + await chain_manager.initialize() + logger.info("✅ Initialized chain manager") + + # Get trainer credentials for checkpoints + trainer_bucket = chain_manager.get_bucket(TRAINER_UID) + checkpoint_credentials = trainer_bucket if trainer_bucket else credentials + + checkpoint_manager = CheckpointManager( + cache_root=default_checkpoint_cache_root(), + credentials=checkpoint_credentials, + keep_limit=2, + ) + + # Create coordinator + coordinator = ParallelMiningCoordinator(config, wallet, credentials) + + # Main mining loop + last_window_start = -1 + timers = MiningTimers() + current_checkpoint_window: int | None = None + checkpoint_path: str | None = None + + try: + while True: + current_block = await subtensor.get_current_block() + window_start = calculate_window_start(current_block) + checkpoint_window = window_start - WINDOW_LENGTH + + if window_start <= last_window_start: + await asyncio.sleep(5) + continue + + # Load checkpoint if needed + if checkpoint_window >= 0 and current_checkpoint_window != checkpoint_window: + logger.info("🔁 Loading checkpoint for window %s", checkpoint_window) + checkpoint_path_obj = await checkpoint_manager.get_checkpoint(checkpoint_window) + if checkpoint_path_obj: + checkpoint_path = str(checkpoint_path_obj) + current_checkpoint_window = checkpoint_window + else: + logger.error("No checkpoint available for window %s", checkpoint_window) + await asyncio.sleep(30) + continue + + if not checkpoint_path: + logger.error("No checkpoint loaded, cannot mine") + await asyncio.sleep(30) + continue + + # Check time budget - skip for parallel mode since workers manage their own time + # The parallel coordinator ensures all workers complete before upload + + # Get 
window randomness + window_block_hash, combined_randomness = await get_window_randomness( + subtensor, + window_start, + use_drand, + ) + + logger.info( + "🔥 Starting parallel mining for window %d-%d", + window_start, + window_start + WINDOW_LENGTH - 1, + ) + + # Run parallel mining + inferences = await coordinator.mine_window( + window_start, + window_block_hash, + combined_randomness, + checkpoint_path, + ) + + # Upload aggregated results + if inferences: + logger.info( + "📤 Uploading %d aggregated rollouts for window %d", + len(inferences), + window_start, + ) + upload_duration = await upload_inferences_with_metrics( + wallet, + window_start, + inferences, + credentials, + None, # monitor + ) + timers.update_upload_time_ema(upload_duration) + logger.info("✅ Successfully uploaded window %d", window_start) + else: + logger.warning("No inferences generated for window %d", window_start) + + last_window_start = window_start + await checkpoint_manager.cleanup_local(window_start) + + except KeyboardInterrupt: + logger.info("Shutting down parallel miner...") + finally: + coordinator.cleanup() + chain_manager.stop() + + +def register(app: typer.Typer) -> None: + """Register parallel-mine command with CLI.""" + app.command("parallel-mine")(parallel_mine) + + +def parallel_mine( + num_gpus: int = typer.Option( + 8, + "--num-gpus", + "-g", + help="Number of GPUs to use for parallel mining", + ), + problems_per_gpu: int = typer.Option( + 12, + "--problems-per-gpu", + "-p", + help="Minimum number of problems each GPU should generate", + ), + batch_size: int = typer.Option( + 2, + "--batch-size", + "-b", + help="Rollout batch size within each problem (1-16)", + ), + safety_blocks: int = typer.Option( + 3, + "--safety-blocks", + help="Safety margin blocks before window end", + ), + use_drand: bool = typer.Option( + True, + "--use-drand/--no-drand", + help="Use drand for randomness", + ), + gpu_ids: str = typer.Option( + None, + "--gpu-ids", + help="Comma-separated GPU IDs to 
use (e.g., '0,1,2,3'). Default: 0 to num_gpus-1", + ), + worker_timeout: float = typer.Option( + 600.0, + "--worker-timeout", + help="Maximum seconds to wait for workers per window", + ), +) -> None: + """Run parallel multi-GPU miner for maximum throughput. + + Spawns multiple worker processes, each on a dedicated GPU, generating + rollouts for non-overlapping problem ranges. Results are aggregated + and uploaded as a single submission per window. + + Example: + grail parallel-mine --num-gpus 8 --problems-per-gpu 12 + + This generates 8 × 12 = 96 problems per window (1,536+ rollouts). + """ + # Validate inputs + if num_gpus < 1: + console.print("[red]Error: --num-gpus must be at least 1[/red]") + raise typer.Exit(code=1) + + if problems_per_gpu < 1: + console.print("[red]Error: --problems-per-gpu must be at least 1[/red]") + raise typer.Exit(code=1) + + if batch_size < 1 or batch_size > 16: + console.print("[red]Error: --batch-size must be between 1 and 16[/red]") + raise typer.Exit(code=1) + + # Parse GPU IDs if provided + parsed_gpu_ids = None + if gpu_ids: + try: + parsed_gpu_ids = [int(x.strip()) for x in gpu_ids.split(",")] + if len(parsed_gpu_ids) != num_gpus: + console.print( + f"[red]Error: --gpu-ids has {len(parsed_gpu_ids)} IDs " + f"but --num-gpus is {num_gpus}[/red]" + ) + raise typer.Exit(code=1) + except ValueError as err: + console.print("[red]Error: --gpu-ids must be comma-separated integers[/red]") + raise typer.Exit(code=1) from err + + # Check GPU availability + available_gpus = torch.cuda.device_count() + if available_gpus < num_gpus: + console.print( + f"[yellow]Warning: Only {available_gpus} GPUs available, " + f"but {num_gpus} requested[/yellow]" + ) + + config = ParallelMinerConfig( + num_gpus=num_gpus, + problems_per_gpu=problems_per_gpu, + batch_size=batch_size, + safety_blocks=safety_blocks, + use_drand=use_drand, + worker_timeout=worker_timeout, + gpu_ids=parsed_gpu_ids, + ) + + total_problems = num_gpus * problems_per_gpu + 
console.print("[bold green]Starting Parallel Miner[/bold green]") + console.print(f" GPUs: {num_gpus}") + console.print(f" Problems/GPU: {problems_per_gpu}") + console.print(f" Total problems/window: {total_problems}") + console.print(f" Expected rollouts/window: {total_problems * 16}") + + try: + asyncio.run(run_parallel_miner(config, use_drand)) + except KeyboardInterrupt: + console.print("[yellow]Parallel miner stopped by user[/yellow]") + raise typer.Exit(code=0) from None + except Exception as e: + logger.exception("Fatal error in parallel miner") + console.print(f"[red]Fatal error: {e}[/red]") + raise typer.Exit(code=1) from None + + +# --------------------------------------------------------------------------- # +# Main Entry Point # +# --------------------------------------------------------------------------- # + + +def main() -> None: + """Main entry point for parallel miner CLI.""" + app = typer.Typer() + register(app) + app() + + +if __name__ == "__main__": + main() From fcecd9ab2bfc9dc304469d59c9fad0c06fd5a517 Mon Sep 17 00:00:00 2001 From: Erfan Miahi Date: Thu, 27 Nov 2025 06:17:27 +0000 Subject: [PATCH 08/17] feat(evaluation): enhance eval_math_harness with vLLM backend support and tensor parallelism features --- eval_math_harness.py | 199 ++++++++++++++++++++++++++++++++----------- 1 file changed, 148 insertions(+), 51 deletions(-) diff --git a/eval_math_harness.py b/eval_math_harness.py index 3540bde9..1b1a3323 100644 --- a/eval_math_harness.py +++ b/eval_math_harness.py @@ -15,15 +15,16 @@ 3. Chain-of-thought - Enabled via task's native format 4. Greedy decoding - temperature=0 for reproducibility 5. max_gen_toks=1024 - Sufficient for reasoning chains -6. Flash attention - Memory efficient for 7B+ models +6. vLLM backend - High-throughput inference with tensor parallelism 7. BF16 precision - Optimal for modern GPUs -8. Batch size tuning - Auto via batch_size="auto" +8. 
Tensor parallelism - Utilize all available GPUs Usage: ------ python eval_math_harness.py python eval_math_harness.py --model Qwen/Qwen2.5-7B-Instruct python eval_math_harness.py --num-fewshot 0 # zero-shot + python eval_math_harness.py --tensor-parallel-size 8 # use 8 GPUs """ import argparse @@ -41,10 +42,54 @@ logger = logging.getLogger(__name__) +def get_gpu_count() -> int: + """Get the number of available CUDA GPUs.""" + try: + import torch + + return torch.cuda.device_count() + except ImportError: + # Fallback to nvidia-smi + import subprocess + + result = subprocess.run( + ["nvidia-smi", "--query-gpu=index", "--format=csv,noheader"], + capture_output=True, + text=True, + ) + return len(result.stdout.strip().split("\n")) if result.returncode == 0 else 1 + + +def get_model_num_attention_heads(model_name: str) -> int: + """Get the number of attention heads for a model.""" + try: + from transformers import AutoConfig + + config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) + return getattr(config, "num_attention_heads", 32) + except Exception: + return 32 # Default fallback + + +def get_optimal_tensor_parallel_size(model_name: str, max_gpus: int) -> int: + """Calculate optimal tensor parallel size based on model architecture. + + Tensor parallelism requires num_attention_heads to be divisible by TP size. + Returns the largest valid TP size <= max_gpus. 
+ """ + num_heads = get_model_num_attention_heads(model_name) + + # Find the largest divisor of num_heads that is <= max_gpus + for tp_size in range(min(max_gpus, num_heads), 0, -1): + if num_heads % tp_size == 0: + return tp_size + return 1 + + def parse_args() -> argparse.Namespace: """Parse command line arguments.""" parser = argparse.ArgumentParser( - description="Evaluate model on Hendrycks MATH using lm-eval-harness" + description="Evaluate model on Hendrycks MATH using lm-eval-harness with vLLM" ) parser.add_argument( "--model", @@ -82,21 +127,53 @@ def parse_args() -> argparse.Namespace: default=None, help="Limit number of samples per task (for debugging)", ) + parser.add_argument( + "--tensor-parallel-size", + type=int, + default=None, + help="Tensor parallel size for vLLM (default: auto-detect all GPUs)", + ) + parser.add_argument( + "--max-model-len", + type=int, + default=4096, + help="Maximum model context length for vLLM (default: 4096)", + ) + parser.add_argument( + "--gpu-memory-utilization", + type=float, + default=0.9, + help="GPU memory utilization for vLLM (default: 0.9)", + ) + parser.add_argument( + "--backend", + type=str, + choices=["vllm", "hf"], + default="vllm", + help="Inference backend: vllm (faster) or hf (HuggingFace)", + ) return parser.parse_args() def run_evaluation(args: argparse.Namespace) -> dict: - """Run lm-evaluation-harness on MATH dataset.""" + """Run lm-evaluation-harness on MATH dataset using vLLM or HuggingFace backend.""" try: from lm_eval import simple_evaluate - from lm_eval.models.huggingface import HFLM except ImportError: logger.error("lm-eval not installed. 
Run: pip install lm-eval") sys.exit(1) + # Auto-detect tensor parallel size if not specified + available_gpus = get_gpu_count() + if args.tensor_parallel_size is None: + args.tensor_parallel_size = get_optimal_tensor_parallel_size(args.model, available_gpus) + logger.info(f"Model: {args.model}") + logger.info(f"Backend: {args.backend}") logger.info(f"Few-shot: {args.num_fewshot}") logger.info(f"Batch size: {args.batch_size}") + logger.info(f"Available GPUs: {available_gpus}") + logger.info(f"Tensor parallel size: {args.tensor_parallel_size}") # MATH subtasks (all 7 subjects) # Using hendrycks_math tasks with proper \boxed{} answer extraction @@ -112,37 +189,71 @@ def run_evaluation(args: argparse.Namespace) -> dict: logger.info(f"Tasks: {tasks}") - # Model configuration with best practices - model_kwargs = { - "pretrained": args.model, - "dtype": "bfloat16", - "device_map": "auto", - "trust_remote_code": True, - # Enable flash attention if available - "attn_implementation": "flash_attention_2", - } + if args.backend == "vllm": + # Use vLLM backend for maximum efficiency with tensor parallelism + try: + from lm_eval.models.vllm_causallms import VLLM + except ImportError: + logger.error("vLLM not installed. 
Run: pip install vllm") + sys.exit(1) + + logger.info(f"GPU memory utilization: {args.gpu_memory_utilization}") + logger.info(f"Max model length: {args.max_model_len}") + + # vLLM model configuration for maximum throughput + model = VLLM( + pretrained=args.model, + tensor_parallel_size=args.tensor_parallel_size, + dtype="bfloat16", + gpu_memory_utilization=args.gpu_memory_utilization, + max_model_len=args.max_model_len, + trust_remote_code=True, + # Enable prefix caching for faster few-shot evaluation + enable_prefix_caching=True, + ) + + # Run evaluation with vLLM + logger.info("Starting evaluation with vLLM backend...") + results = simple_evaluate( + model=model, + tasks=tasks, + num_fewshot=args.num_fewshot, + batch_size=args.batch_size, + limit=args.limit, + # Greedy decoding for reproducibility + gen_kwargs="temperature=0,do_sample=False", + log_samples=True, + ) + else: + # Fall back to HuggingFace backend + from lm_eval.models.huggingface import HFLM - # Try to load with flash attention, fall back if not available - try: - model = HFLM(**model_kwargs) - except Exception as e: - logger.warning(f"Flash attention failed ({e}), using default attention") - model_kwargs.pop("attn_implementation", None) - model = HFLM(**model_kwargs) + model_kwargs = { + "pretrained": args.model, + "dtype": "bfloat16", + "device_map": "auto", + "trust_remote_code": True, + "attn_implementation": "flash_attention_2", + } - # Run evaluation - logger.info("Starting evaluation...") - results = simple_evaluate( - model=model, - tasks=tasks, - num_fewshot=args.num_fewshot, - batch_size=args.batch_size, - device=args.device, - limit=args.limit, - # Greedy decoding for reproducibility - gen_kwargs="temperature=0,do_sample=False", - log_samples=True, - ) + try: + model = HFLM(**model_kwargs) + except Exception as e: + logger.warning(f"Flash attention failed ({e}), using default attention") + model_kwargs.pop("attn_implementation", None) + model = HFLM(**model_kwargs) + + 
logger.info("Starting evaluation with HuggingFace backend...") + results = simple_evaluate( + model=model, + tasks=tasks, + num_fewshot=args.num_fewshot, + batch_size=args.batch_size, + device=args.device, + limit=args.limit, + gen_kwargs="temperature=0,do_sample=False", + log_samples=True, + ) return results @@ -155,8 +266,6 @@ def print_results(results: dict, args: argparse.Namespace) -> None: # Extract per-subject results subject_results = {} - total_correct = 0 - total_samples = 0 for task_name, task_results in results.get("results", {}).items(): # Extract subject from task name and expand abbreviations @@ -175,14 +284,6 @@ def print_results(results: dict, args: argparse.Namespace) -> None: "stderr": stderr, } - # For aggregate calculation - n_samples = task_results.get("alias", {}).get("n-shot", 0) - if "samples" in results: - task_samples = results["samples"].get(task_name, []) - n_samples = len(task_samples) - total_correct += sum(1 for s in task_samples if s.get("acc", 0) == 1) - total_samples += n_samples - # Print per-subject results print("\nPer-Subject Accuracy:") print("-" * 50) @@ -191,13 +292,9 @@ def print_results(results: dict, args: argparse.Namespace) -> None: stderr_pct = data["stderr"] * 100 print(f" {subject:35s} {acc_pct:5.2f}% ± {stderr_pct:.2f}%") - # Print aggregate - if total_samples > 0: - overall_acc = total_correct / total_samples * 100 - else: - # Use average of subjects - accs = [d["accuracy"] for d in subject_results.values()] - overall_acc = sum(accs) / len(accs) * 100 if accs else 0 + # Print aggregate (use weighted average based on number of samples per subject) + accs = [d["accuracy"] for d in subject_results.values()] + overall_acc = sum(accs) / len(accs) * 100 if accs else 0 print("-" * 50) print(f" {'OVERALL':35s} {overall_acc:5.2f}%") From 5b5f3edfc62c5e0f21cdd868aa89e0ba33663cd6 Mon Sep 17 00:00:00 2001 From: Erfan Miahi Date: Thu, 27 Nov 2025 06:17:42 +0000 Subject: [PATCH 09/17] feat(dependencies): add lm-eval for 
evaluation harness with vLLM support --- tools/vllm-server/pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tools/vllm-server/pyproject.toml b/tools/vllm-server/pyproject.toml index 01dc7845..d5ec6664 100644 --- a/tools/vllm-server/pyproject.toml +++ b/tools/vllm-server/pyproject.toml @@ -12,6 +12,8 @@ dependencies = [ "openai>=1.0.0", "trl>=0.22.0", "wandb>=0.22.3", + # Evaluation harness with vLLM support + "lm-eval>=0.4.0", ] [tool.ruff] From 32cd9a62e527d8652feea11f17d69026562b7e1b Mon Sep 17 00:00:00 2001 From: Erfan Miahi Date: Thu, 27 Nov 2025 07:46:01 +0000 Subject: [PATCH 10/17] feat(cli): increase default batch size to 16 for optimal performance and add critical checks for time management in parallel mining --- grail/cli/parallel_miner.py | 69 ++++++++++++++++++++++++++++++++----- 1 file changed, 60 insertions(+), 9 deletions(-) diff --git a/grail/cli/parallel_miner.py b/grail/cli/parallel_miner.py index 8c2e49a0..e424d058 100644 --- a/grail/cli/parallel_miner.py +++ b/grail/cli/parallel_miner.py @@ -70,7 +70,7 @@ class GPUWorkerConfig: use_drand: bool checkpoint_path: str | None # Wallet names read from environment in worker for subprocess isolation - batch_size: int = 2 + batch_size: int = 16 # Match single miner's default for optimal performance safety_blocks: int = 3 @@ -80,7 +80,7 @@ class ParallelMinerConfig: num_gpus: int = 8 problems_per_gpu: int = 12 - batch_size: int = 2 + batch_size: int = 16 # Match single miner's default for optimal performance safety_blocks: int = 3 use_drand: bool = True results_dir: Path = field( @@ -531,8 +531,18 @@ async def _gather_results( all_succeeded = (successful_gpus == expected_gpu_count) and len(failed_gpus) == 0 if all_succeeded: + # CRITICAL: Sort inferences by rollout_group (problem_index) then rollout_index + # The validator uses file-order to derive seed: first group in file = group_index 0 + # If we don't sort, a GPU that finishes first could put problem 24 before problem 0, + # causing 
the validator to derive wrong seeds and fail validation! + all_inferences.sort( + key=lambda x: ( + int(x.get("rollout_group", 0)), # Primary: problem index + int(x.get("rollout_index", 0)), # Secondary: rollout within problem + ) + ) logger.info( - "✅ All %d GPUs succeeded: %d total rollouts ready for upload", + "✅ All %d GPUs succeeded: %d total rollouts ready for upload (sorted by problem ID)", successful_gpus, len(all_inferences), ) @@ -668,8 +678,30 @@ async def run_parallel_miner( await asyncio.sleep(30) continue - # Check time budget - skip for parallel mode since workers manage their own time - # The parallel coordinator ensures all workers complete before upload + # Check time budget BEFORE starting parallel mining + # Parallel mode needs more time since all GPUs must complete before upload + current_block = await subtensor.get_current_block() + blocks_remaining = (window_start + WINDOW_LENGTH) - current_block + + # Estimate time needed: rough estimate based on problems and safety margin + # Each GPU needs ~30-60s per problem with batch_size=16, plus upload time + estimated_time_per_problem = 45 # seconds, conservative estimate + estimated_upload_time = 30 # seconds + total_estimated_seconds = ( + config.problems_per_gpu * estimated_time_per_problem + estimated_upload_time + ) + # Convert to blocks (12 seconds per block) + estimated_blocks_needed = (total_estimated_seconds // 12) + config.safety_blocks + + if blocks_remaining < estimated_blocks_needed: + logger.warning( + "⏰ Skipping window %d: only %d blocks remaining, need ~%d blocks for parallel mining", + window_start, + blocks_remaining, + estimated_blocks_needed, + ) + await asyncio.sleep(10) + continue # Get window randomness window_block_hash, combined_randomness = await get_window_randomness( @@ -692,12 +724,31 @@ async def run_parallel_miner( checkpoint_path, ) - # Upload aggregated results + # Upload aggregated results - but first check we have time! 
if inferences: + # CRITICAL: Check blocks remaining before upload + current_block = await subtensor.get_current_block() + blocks_remaining = (window_start + WINDOW_LENGTH) - current_block + + if blocks_remaining < config.safety_blocks: + logger.error( + "❌ SKIPPING UPLOAD: Only %d blocks remaining (need %d safety blocks)", + blocks_remaining, + config.safety_blocks, + ) + logger.error( + "Window %d will be missed - workers took too long", + window_start, + ) + # Don't upload late - validator would reject anyway + last_window_start = window_start + continue + logger.info( - "📤 Uploading %d aggregated rollouts for window %d", + "📤 Uploading %d aggregated rollouts for window %d (%d blocks remaining)", len(inferences), window_start, + blocks_remaining, ) upload_duration = await upload_inferences_with_metrics( wallet, @@ -740,10 +791,10 @@ def parallel_mine( help="Minimum number of problems each GPU should generate", ), batch_size: int = typer.Option( - 2, + 16, "--batch-size", "-b", - help="Rollout batch size within each problem (1-16)", + help="Rollout batch size within each problem (default 16 for optimal A100 performance)", ), safety_blocks: int = typer.Option( 3, From c82bf4829885e3211add994026c941089a66bb81 Mon Sep 17 00:00:00 2001 From: Erfan Miahi Date: Thu, 27 Nov 2025 07:46:12 +0000 Subject: [PATCH 11/17] fix(cli): update nonce calculation in package_rollout_data to prevent collisions by using ROLLOUTS_PER_PROBLEM as multiplier --- grail/cli/mine.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/grail/cli/mine.py b/grail/cli/mine.py index c4145cdd..ddb9a012 100644 --- a/grail/cli/mine.py +++ b/grail/cli/mine.py @@ -431,7 +431,10 @@ def package_rollout_data( Returns: Signed dictionary ready to upload for validation """ - rollout_nonce = base_nonce * 10 + rollout_idx + # CRITICAL: Use ROLLOUTS_PER_PROBLEM (16) as multiplier to avoid nonce collisions + # Old formula (base_nonce * 10) caused duplicates when rollout_idx >= 10 + # e.g., 
problem 14 rollout 10 = 150, problem 15 rollout 0 = 150 (collision!) + rollout_nonce = base_nonce * ROLLOUTS_PER_PROBLEM + rollout_idx # Sign commit binding (tokens, randomness, model, layer, commitments) from ..protocol.signatures import sign_commit_binding From 57fe89ab407d9d20b48efda69d50ce8b8660d613 Mon Sep 17 00:00:00 2001 From: Erfan Miahi Date: Thu, 27 Nov 2025 07:46:22 +0000 Subject: [PATCH 12/17] feat(model): add support for PyTorch SDPA in model loading with priority handling for Flash Attention 2 --- grail/model/provider.py | 37 ++++++++++++++++++++++++------------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/grail/model/provider.py b/grail/model/provider.py index c5aa7845..2a5def2e 100644 --- a/grail/model/provider.py +++ b/grail/model/provider.py @@ -59,6 +59,7 @@ def get_model( use_safetensors: bool = True, eval_mode: bool = True, use_flash_attention: bool = False, + use_sdpa: bool = True, checkpoint_window: int | None = None, ) -> Any: """Load model with consistent configuration. @@ -69,7 +70,9 @@ def get_model( use_safetensors: Whether to prefer safetensors format eval_mode: Whether to set model to eval() mode use_flash_attention: Whether to use Flash Attention 2 (requires flash-attn package). - Only enabled for training, not for evaluation/inference. + Takes priority over SDPA if both are enabled. + use_sdpa: Whether to use PyTorch SDPA (Scaled Dot-Product Attention). + Built into PyTorch 2.0+, provides 10-30% speedup. Default: True. checkpoint_window: Optional checkpoint window number. If not provided, will be extracted from metadata.json or parsed from the path. 
@@ -111,19 +114,27 @@ def get_model( except (ValueError, IndexError): pass - # Configure attention implementation + # Configure attention implementation (priority: Flash Attention 2 > SDPA > default) attn_implementation = None - if use_flash_attention and device == "cuda": - try: - import flash_attn # noqa: F401 - - attn_implementation = "flash_attention_2" - logger.info("Using Flash Attention 2 for model loading") - except ImportError: - logger.warning( - "flash-attn not installed; falling back to default attention. " - "Install with: uv pip install flash-attn" - ) + if device == "cuda": + if use_flash_attention: + try: + import flash_attn # noqa: F401 + + attn_implementation = "flash_attention_2" + logger.info("Using Flash Attention 2 for model loading") + except ImportError: + logger.warning( + "flash-attn not installed; falling back to SDPA. " + "Install with: uv pip install flash-attn" + ) + if use_sdpa: + attn_implementation = "sdpa" + logger.info("Using PyTorch SDPA (Scaled Dot-Product Attention)") + elif use_sdpa: + # SDPA is built into PyTorch 2.0+ and provides good speedup + attn_implementation = "sdpa" + logger.info("Using PyTorch SDPA (Scaled Dot-Product Attention)") # Load model with optimized attention if available model = AutoModelForCausalLM.from_pretrained( From 62e009c5face2216243e5efd3fbbafe33a5efb1e Mon Sep 17 00:00:00 2001 From: Erfan Miahi Date: Mon, 1 Dec 2025 21:35:12 +0000 Subject: [PATCH 13/17] feat(evaluation): add AIME 2024, AMC 2023, GSM8K eval tasks for lm-harness that support reasoning model evaluation as well. 
--- research/eval/README.md | 379 ++++++++++++++++++ .../eval/tasks/aime24_grail/aime24_grail.yaml | 26 ++ research/eval/tasks/aime24_grail/utils.py | 180 +++++++++ research/eval/tasks/amc2023/amc2023.yaml | 28 ++ research/eval/tasks/amc2023/utils.py | 162 ++++++++ .../tasks/amc2023_grail/amc2023_grail.yaml | 25 ++ research/eval/tasks/amc2023_grail/utils.py | 167 ++++++++ .../eval/tasks/gsm8k_grail/gsm8k_grail.yaml | 26 ++ research/eval/tasks/gsm8k_grail/utils.py | 140 +++++++ .../_default_template.yaml | 21 + .../hendrycks_math_grail.yaml | 15 + .../hendrycks_math_grail_algebra.yaml | 5 + ...endrycks_math_grail_counting_and_prob.yaml | 5 + .../hendrycks_math_grail_geometry.yaml | 5 + ...rycks_math_grail_intermediate_algebra.yaml | 5 + .../hendrycks_math_grail_num_theory.yaml | 5 + .../hendrycks_math_grail_prealgebra.yaml | 5 + .../hendrycks_math_grail_precalc.yaml | 5 + .../eval/tasks/hendrycks_math_grail/utils.py | 251 ++++++++++++ .../hendrycks_math_pass_at_5.yaml | 33 ++ .../hendrycks_math_grail_pass_at_k/utils.py | 249 ++++++++++++ 21 files changed, 1737 insertions(+) create mode 100644 research/eval/README.md create mode 100644 research/eval/tasks/aime24_grail/aime24_grail.yaml create mode 100644 research/eval/tasks/aime24_grail/utils.py create mode 100644 research/eval/tasks/amc2023/amc2023.yaml create mode 100644 research/eval/tasks/amc2023/utils.py create mode 100644 research/eval/tasks/amc2023_grail/amc2023_grail.yaml create mode 100644 research/eval/tasks/amc2023_grail/utils.py create mode 100644 research/eval/tasks/gsm8k_grail/gsm8k_grail.yaml create mode 100644 research/eval/tasks/gsm8k_grail/utils.py create mode 100644 research/eval/tasks/hendrycks_math_grail/_default_template.yaml create mode 100644 research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail.yaml create mode 100644 research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail_algebra.yaml create mode 100644 
research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail_counting_and_prob.yaml create mode 100644 research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail_geometry.yaml create mode 100644 research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail_intermediate_algebra.yaml create mode 100644 research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail_num_theory.yaml create mode 100644 research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail_prealgebra.yaml create mode 100644 research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail_precalc.yaml create mode 100644 research/eval/tasks/hendrycks_math_grail/utils.py create mode 100644 research/eval/tasks/hendrycks_math_grail_pass_at_k/hendrycks_math_pass_at_5.yaml create mode 100644 research/eval/tasks/hendrycks_math_grail_pass_at_k/utils.py diff --git a/research/eval/README.md b/research/eval/README.md new file mode 100644 index 00000000..8325faa1 --- /dev/null +++ b/research/eval/README.md @@ -0,0 +1,379 @@ +# MATH Benchmark Evaluation + +Evaluate language models on the [Hendrycks MATH](https://github.com/hendrycks/math) dataset using [EleutherAI's lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness) with vLLM backend. + +## Overview + +This directory contains custom evaluation tasks for reasoning models that use the GRAIL format: +- `` ... `` for chain-of-thought reasoning +- `` ... `` for final answers + +## Prerequisites + +```bash +# Activate the vLLM environment +source /root/grail/tools/vllm-server/.venv/bin/activate +``` + +## Quick Start + +### 1. 
Base Model (Standard Evaluation) + +Standard 4-shot evaluation without reasoning format: + +```bash +CUDA_VISIBLE_DEVICES=0 lm_eval \ + --model vllm \ + --model_args "pretrained=Qwen/Qwen2.5-1.5B-Instruct,dtype=bfloat16,gpu_memory_utilization=0.9,max_model_len=4096" \ + --tasks hendrycks_math \ + --num_fewshot 4 \ + --batch_size auto \ + --gen_kwargs "temperature=0,do_sample=False" \ + --output_path ./results/base_4shot \ + --log_samples +``` + +### 2. Reasoning Model (Custom Template) + +For models trained with the GRAIL reasoning format, use the custom task with chat template: + +```bash +CUDA_VISIBLE_DEVICES=0 lm_eval \ + --model vllm \ + --model_args "pretrained=/path/to/checkpoint,dtype=bfloat16,think_end_token=,gpu_memory_utilization=0.9,max_model_len=8192" \ + --tasks hendrycks_math_grail \ + --include_path /root/grail/research/eval/tasks \ + --apply_chat_template \ + --fewshot_as_multiturn \ + --num_fewshot 4 \ + --batch_size auto \ + --log_samples \ + --output_path ./results/reasoning_4shot +``` + +## Evaluation Configurations + +### Base Model Configurations + +| Config | Command Flags | Use Case | +|--------|--------------|----------| +| 0-shot | `--num_fewshot 0` | Zero-shot baseline | +| 4-shot | `--num_fewshot 4` | Standard MATH benchmark | + +### Reasoning Model Configurations + +| Config | Command Flags | Use Case | +|--------|--------------|----------| +| 0-shot | `--num_fewshot 0 --apply_chat_template` | Zero-shot with reasoning template | +| 4-shot multiturn | `--num_fewshot 4 --apply_chat_template --fewshot_as_multiturn` | **Recommended** - Few-shot as conversation | + +## Key Arguments + +| Argument | Description | +|----------|-------------| +| `--tasks hendrycks_math` | Standard MATH evaluation (7 subjects) | +| `--tasks hendrycks_math_grail` | Custom GRAIL reasoning format | +| `--include_path` | Path to custom task definitions | +| `--apply_chat_template` | Apply model's chat template | +| `--fewshot_as_multiturn` | Format few-shot examples 
as multi-turn conversation | +| `--think_end_token` | Token marking end of reasoning (extracts answer after this) | +| `--max_model_len` | Context length (use 8192+ for 4-shot) | +| `--log_samples` | Save per-sample outputs for analysis | + +## Example Commands + +### Evaluate GRAIL Checkpoint (Recommended) + +```bash +cd /root/grail && source tools/vllm-server/.venv/bin/activate + +CUDA_VISIBLE_DEVICES=0 python -m lm_eval \ + --model vllm \ + --model_args "pretrained=/root/grail/grail_final_checkpoint,dtype=bfloat16,think_end_token=,gpu_memory_utilization=0.9,max_model_len=8192" \ + --tasks hendrycks_math_grail \ + --include_path /root/grail/research/eval/tasks \ + --apply_chat_template \ + --fewshot_as_multiturn \ + --num_fewshot 4 \ + --batch_size auto \ + --log_samples \ + --output_path /root/grail/eval_results/grail_checkpoint +``` + +### Evaluate Base Model with Reasoning Template + +First, prepare the base model with custom chat template: + +```bash +# Download and patch the model (one-time setup) +python -c " +from huggingface_hub import snapshot_download +import json + +# Download model +snapshot_download('Qwen/Qwen2.5-1.5B-Instruct', local_dir='./models/Qwen2.5-1.5B-Instruct-reasoning') + +# Patch tokenizer config with reasoning template +with open('./models/Qwen2.5-1.5B-Instruct-reasoning/tokenizer_config.json', 'r') as f: + config = json.load(f) + +config['chat_template'] = \"\"\"{% if messages[0]['role'] == 'system' %}{{ messages[0]['content'] + eos_token }}{% set loop_messages = messages[1:] %}{% else %}{{ 'You are given a problem. +Think about the problem and provide your working out. +Place it between and . +Then, provide your solution between .' 
+ eos_token }}{% set loop_messages = messages %}{% endif %}{% for message in loop_messages %}{% if message['role'] == 'user' %}{{ message['content'] }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '' }}{% endif %}\"\"\" + +with open('./models/Qwen2.5-1.5B-Instruct-reasoning/tokenizer_config.json', 'w') as f: + json.dump(config, f, indent=2) +" +``` + +Then evaluate: + +```bash +CUDA_VISIBLE_DEVICES=0 python -m lm_eval \ + --model vllm \ + --model_args "pretrained=./models/Qwen2.5-1.5B-Instruct-reasoning,dtype=bfloat16,think_end_token=,gpu_memory_utilization=0.9,max_model_len=8192" \ + --tasks hendrycks_math_grail \ + --include_path /root/grail/research/eval/tasks \ + --apply_chat_template \ + --fewshot_as_multiturn \ + --num_fewshot 4 \ + --batch_size auto \ + --log_samples \ + --output_path ./results/base_reasoning_4shot +``` + +### Run in Background + +```bash +CUDA_VISIBLE_DEVICES=0 nohup python -m lm_eval \ + --model vllm \ + --model_args "pretrained=/root/grail/grail_final_checkpoint,dtype=bfloat16,think_end_token=,gpu_memory_utilization=0.9,max_model_len=8192" \ + --tasks hendrycks_math_grail \ + --include_path /root/grail/research/eval/tasks \ + --apply_chat_template \ + --fewshot_as_multiturn \ + --num_fewshot 4 \ + --batch_size auto \ + --log_samples \ + --output_path ./results/grail_checkpoint \ + > eval.log 2>&1 & +echo "Started. PID: $!" 
+``` + +## Benchmark Results + +| Model | Config | Accuracy | +|-------|--------|----------| +| Qwen2.5-1.5B-Instruct | 0-shot standard | 1.90% | +| Qwen2.5-1.5B-Instruct | 4-shot standard | 12.66% | +| Qwen2.5-1.5B-Instruct + reasoning template | 4-shot multiturn | 28.00% | +| grail_final_checkpoint | 4-shot multiturn | **30.34%** | + +## Task Structure + +``` +tasks/hendrycks_math_grail/ +├── _default_template.yaml # Base config with reasoning format +├── hendrycks_math_grail.yaml # Task group definition +├── hendrycks_math_grail_algebra.yaml +├── hendrycks_math_grail_counting_and_prob.yaml +├── hendrycks_math_grail_geometry.yaml +├── hendrycks_math_grail_intermediate_algebra.yaml +├── hendrycks_math_grail_num_theory.yaml +├── hendrycks_math_grail_prealgebra.yaml +├── hendrycks_math_grail_precalc.yaml +└── utils.py # Answer extraction and comparison +``` + +## Reasoning Format + +The custom chat template instructs the model to: + +``` +You are given a problem. +Think about the problem and provide your working out. +Place it between and . +Then, provide your solution between . +``` + +Example output: +``` + +Let me solve this step by step... +The answer is 42. + +42 +``` + +The `think_end_token=` argument tells the evaluator to extract the answer from text **after** this token, effectively using only the `` content for scoring. + +## AIME 2024 Benchmark + +AIME (American Invitational Mathematics Examination) is an extremely challenging competition math benchmark. The dataset contains 30 problems from AIME 2024. 
+ +### Running AIME Evaluations + +**Base model (0-shot or 4-shot):** +```bash +CUDA_VISIBLE_DEVICES=0 lm_eval \ + --model vllm \ + --model_args "pretrained=Qwen/Qwen2.5-1.5B-Instruct,dtype=bfloat16,gpu_memory_utilization=0.9,max_model_len=4096,max_gen_toks=2048" \ + --tasks aime24 \ + --num_fewshot 0 \ + --batch_size auto \ + --gen_kwargs "temperature=0,do_sample=False,max_gen_toks=2048" \ + --log_samples \ + --output_path ./results/aime24_base +``` + +**Reasoning model (GRAIL checkpoint):** +```bash +CUDA_VISIBLE_DEVICES=0 lm_eval \ + --model vllm \ + --model_args "pretrained=/root/grail/grail_final_checkpoint,dtype=bfloat16,think_end_token=,gpu_memory_utilization=0.9,max_model_len=8192,enforce_eager=True" \ + --tasks aime24_grail \ + --include_path /root/grail/research/eval/tasks \ + --apply_chat_template \ + --num_fewshot 0 \ + --batch_size auto \ + --log_samples \ + --output_path ./results/aime24_grail +``` + +**Note**: AIME is extremely difficult - even 70B+ models typically achieve only 3-10% on AIME. Small models (1.5B) are expected to score near 0%. 
+ +## Pass@k Evaluation + +For sampling-based evaluation with pass@k metrics: + +### Best Practices + +| Parameter | Recommended Value | Notes | +|-----------|------------------|-------| +| `repeats` | 10 (for pass@5), 100 (for pass@100) | Number of samples per problem | +| `temperature` | 0.6 - 0.8 | Higher = more diversity | +| `top_p` | 0.95 | Nucleus sampling | +| `do_sample` | true | Required for sampling | + +### Formula + +pass@k = 1 - C(n-c, k) / C(n, k) + +Where: +- n = total samples generated +- c = number of correct samples +- k = number of samples to consider + +### Example: Pass@5 on MATH + +```bash +CUDA_VISIBLE_DEVICES=0 lm_eval \ + --model vllm \ + --model_args "pretrained=/root/grail/grail_final_checkpoint,dtype=bfloat16,think_end_token=,gpu_memory_utilization=0.9,max_model_len=8192" \ + --tasks hendrycks_math_pass_at_5 \ + --include_path /root/grail/research/eval/tasks \ + --apply_chat_template \ + --batch_size auto \ + --log_samples \ + --output_path ./results/math_pass_at_5 +``` + +### Key Differences from Greedy Evaluation + +| Greedy (pass@1) | Sampling (pass@k) | +|-----------------|-------------------| +| `temperature=0` | `temperature=0.7` | +| `do_sample=false` | `do_sample=true` | +| `repeats=1` | `repeats=10+` | +| Single deterministic output | Multiple diverse outputs | + +### Custom Pass@k Tasks + +Create a task YAML with: +```yaml +repeats: 10 # Generate 10 samples per problem +generation_kwargs: + do_sample: true + temperature: 0.7 + top_p: 0.95 +metric_list: + - metric: !function utils.aggregate_pass_at_5 + aggregation: mean + higher_is_better: true +``` + +## AMC 2023 Benchmark + +AMC (American Mathematics Competition) is a high school math competition. The AMC 2023 dataset contains 40 problems. 
+ +### Running AMC Evaluations + +**Base model (0-shot or 4-shot):** +```bash +CUDA_VISIBLE_DEVICES=0 lm_eval \ + --model vllm \ + --model_args "pretrained=Qwen/Qwen2.5-1.5B-Instruct,dtype=bfloat16,gpu_memory_utilization=0.9,max_model_len=4096,max_gen_toks=2048" \ + --tasks amc2023 \ + --include_path /root/grail/research/eval/tasks \ + --num_fewshot 0 \ + --batch_size auto \ + --gen_kwargs "temperature=0,do_sample=False" \ + --log_samples \ + --output_path ./results/amc2023_base +``` + +**Reasoning model (GRAIL checkpoint):** +```bash +CUDA_VISIBLE_DEVICES=0 lm_eval \ + --model vllm \ + --model_args "pretrained=/root/grail/grail_final_checkpoint,dtype=bfloat16,think_end_token=,gpu_memory_utilization=0.9,max_model_len=8192,enforce_eager=True" \ + --tasks amc2023_grail \ + --include_path /root/grail/research/eval/tasks \ + --apply_chat_template \ + --num_fewshot 0 \ + --batch_size auto \ + --log_samples \ + --output_path ./results/amc2023_grail +``` + +### AMC 2023 Results + +| Model | Config | Accuracy | +|-------|--------|----------| +| Qwen2.5-1.5B-Instruct | 0-shot | 17.5% | +| Qwen2.5-1.5B-Instruct | 4-shot | 17.5% | +| grail_final_checkpoint | reasoning template | 17.5% | + +## GSM8K Benchmark + +GSM8K (Grade School Math 8K) is a dataset of 8.5K high-quality linguistically diverse grade school math word problems. The test set contains 1319 problems. 
+ +### Running GSM8K Evaluations + +**Base model (0-shot or 4-shot):** +```bash +CUDA_VISIBLE_DEVICES=0 lm_eval \ + --model vllm \ + --model_args "pretrained=Qwen/Qwen2.5-1.5B-Instruct,dtype=bfloat16,gpu_memory_utilization=0.9,max_model_len=4096" \ + --tasks gsm8k \ + --num_fewshot 4 \ + --batch_size auto \ + --gen_kwargs "temperature=0,do_sample=False" \ + --log_samples \ + --output_path ./results/gsm8k_base +``` + +**Reasoning model (GRAIL checkpoint):** +```bash +CUDA_VISIBLE_DEVICES=0 lm_eval \ + --model vllm \ + --model_args "pretrained=/root/grail/grail_final_checkpoint,dtype=bfloat16,think_end_token=,gpu_memory_utilization=0.9,max_model_len=8192,enforce_eager=True" \ + --tasks gsm8k_grail \ + --include_path /root/grail/research/eval/tasks \ + --apply_chat_template \ + --num_fewshot 0 \ + --batch_size auto \ + --log_samples \ + --output_path ./results/gsm8k_grail +``` diff --git a/research/eval/tasks/aime24_grail/aime24_grail.yaml b/research/eval/tasks/aime24_grail/aime24_grail.yaml new file mode 100644 index 00000000..48af0027 --- /dev/null +++ b/research/eval/tasks/aime24_grail/aime24_grail.yaml @@ -0,0 +1,26 @@ +tag: + - math_word_problems +task: aime24_grail +dataset_path: Maxwell-Jia/AIME_2024 +output_type: generate_until +training_split: train +fewshot_split: train +test_split: train +doc_to_text: "{{Problem}}" +doc_to_target: "\n{{Solution}}\n\n{{Answer}}" +process_results: !function utils.process_results +metric_list: + - metric: exact_match + aggregation: mean + higher_is_better: true +generation_kwargs: + until: + - "<|im_end|>" + - "<|endoftext|>" + do_sample: false + temperature: 0.0 + max_gen_toks: 4096 +repeats: 1 +num_fewshot: 0 +metadata: + version: 1.0 diff --git a/research/eval/tasks/aime24_grail/utils.py b/research/eval/tasks/aime24_grail/utils.py new file mode 100644 index 00000000..1f655511 --- /dev/null +++ b/research/eval/tasks/aime24_grail/utils.py @@ -0,0 +1,180 @@ +"""AIME 2024 evaluation utilities for GRAIL reasoning models. 
+ +Extracts answers from ... tags and uses robust +integer comparison for AIME answers (which are always 0-999). +""" + +import re + + +def process_results(doc: dict, results: list[str]) -> dict[str, int]: + """Process model output and compare with target answer. + + AIME answers are always integers from 000-999. + """ + retval = 0 + response = results[0] + + # Extract answer from ... tags first (for reasoning models) + solution_match = re.search(r"(.*?)", response, re.DOTALL) + if solution_match: + answer = solution_match.group(1).strip() + else: + # Fallback: try to extract from $...$ format + indices = [pos for pos, char in enumerate(response) if char == "$"] + if len(indices) >= 2: + answer = response[indices[0] + 1 : indices[-1]] + else: + # Fallback: try to extract from \boxed{} + boxed_answer = last_boxed_only_string(response) + if boxed_answer is not None: + try: + answer = remove_boxed(boxed_answer) + except (AssertionError, IndexError): + answer = response + else: + answer = response + + # Get target answer + answer_key = next((k for k in doc.keys() if k.lower() == "answer"), None) + if answer_key is None: + return {"exact_match": 0} + + target = str(doc[answer_key]) + + # AIME answers are integers 0-999, so try integer comparison + if is_equiv(answer, target): + retval = 1 + + return {"exact_match": retval} + + +def is_equiv(str1: str, str2: str, verbose: bool = False) -> bool: + """Check if two answers are equivalent. + + For AIME, answers are integers 0-999. We try to extract and compare integers. 
+ """ + if str1 is None and str2 is None: + return True + if str1 is None or str2 is None: + return False + + try: + # Clean and normalize strings + ss1 = strip_string(str1) + ss2 = strip_string(str2) + + if verbose: + print(f"Comparing: '{ss1}' vs '{ss2}'") + + # Direct string comparison + if ss1 == ss2: + return True + + # Try integer comparison (AIME answers are always integers) + try: + int1 = extract_integer(ss1) + int2 = extract_integer(ss2) + if int1 is not None and int2 is not None: + return int1 == int2 + except (ValueError, TypeError): + pass + + return False + except Exception: + return str1 == str2 + + +def extract_integer(s: str) -> int: + """Extract integer from string, handling common formats.""" + if s is None: + return None + + s = s.strip() + + # Remove common wrappers + s = re.sub(r"\\boxed\{([^}]*)\}", r"\1", s) + s = re.sub(r"\$([^$]*)\$", r"\1", s) + s = s.strip() + + # Try direct integer parse + try: + return int(s) + except ValueError: + pass + + # Try to find integers in the string + matches = re.findall(r"-?\d+", s) + if matches: + # Return the last integer found (usually the final answer) + return int(matches[-1]) + + return None + + +def remove_boxed(s: str) -> str: + """Remove \\boxed{} wrapper from string.""" + if "\\boxed " in s: + left = "\\boxed " + assert s[: len(left)] == left + return s[len(left) :] + + left = "\\boxed{" + assert s[: len(left)] == left + assert s[-1] == "}" + return s[len(left) : -1] + + +def last_boxed_only_string(string: str) -> str: + """Extract the last \\boxed{} content from string.""" + idx = string.rfind("\\boxed") + if "\\boxed " in string: + return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0] + if idx < 0: + idx = string.rfind("\\fbox") + if idx < 0: + return None + + i = idx + right_brace_idx = None + num_left_braces_open = 0 + while i < len(string): + if string[i] == "{": + num_left_braces_open += 1 + if string[i] == "}": + num_left_braces_open -= 1 + if num_left_braces_open == 0: + 
right_brace_idx = i + break + i += 1 + + if right_brace_idx is None: + return None + return string[idx : right_brace_idx + 1] + + +def strip_string(string: str) -> str: + """Normalize string for comparison.""" + if string is None: + return "" + + # Remove linebreaks + string = string.replace("\n", "") + + # Remove common LaTeX + string = string.replace("\\!", "") + string = string.replace("\\\\", "\\") + string = string.replace("\\left", "") + string = string.replace("\\right", "") + + # Remove dollar signs + string = string.replace("$", "") + string = string.replace("\\$", "") + + # Remove spaces + string = string.replace(" ", "") + + # Remove leading zeros for integer comparison + string = string.lstrip("0") or "0" + + return string diff --git a/research/eval/tasks/amc2023/amc2023.yaml b/research/eval/tasks/amc2023/amc2023.yaml new file mode 100644 index 00000000..6dc6ba25 --- /dev/null +++ b/research/eval/tasks/amc2023/amc2023.yaml @@ -0,0 +1,28 @@ +tag: + - math_word_problems +task: amc2023 +dataset_path: sparkle-reasoning/amc2023 +output_type: generate_until +test_split: test +fewshot_split: test +doc_to_text: "Problem: {{question}}\n\nAnswer: The answer is" +doc_to_target: " ${{answer|int}}$" +process_results: !function utils.process_results +metric_list: + - metric: exact_match + aggregation: mean + higher_is_better: true +generation_kwargs: + until: + - "Problem:" + - "\n\n" + - "" + - "<|im_end|>" + - "<|endoftext|>" + do_sample: false + temperature: 0.0 + max_gen_toks: 512 +repeats: 1 +metadata: + version: 1.0 + diff --git a/research/eval/tasks/amc2023/utils.py b/research/eval/tasks/amc2023/utils.py new file mode 100644 index 00000000..8fc76418 --- /dev/null +++ b/research/eval/tasks/amc2023/utils.py @@ -0,0 +1,162 @@ +"""AMC 2023 evaluation utilities. + +AMC answers are integers (multiple choice A-E corresponds to numeric answers). +Extracts answers from $...$ format, \\boxed{}, or plain numbers. 
+""" + +import re + + +def process_results(doc: dict, results: list[str]) -> dict[str, int]: + """Process model output and compare with target answer.""" + retval = 0 + response = results[0] + + # Try to extract answer from $...$ format first + indices = [pos for pos, char in enumerate(response) if char == "$"] + if len(indices) >= 2: + answer = response[indices[0] + 1 : indices[-1]] + else: + # Try to extract from \boxed{} + boxed_answer = last_boxed_only_string(response) + if boxed_answer is not None: + try: + answer = remove_boxed(boxed_answer) + except (AssertionError, IndexError): + answer = response + else: + answer = response + + # Get target answer + target = str(doc.get("answer", "")) + + # Compare answers + if is_equiv(answer, target): + retval = 1 + + return {"exact_match": retval} + + +def is_equiv(str1: str, str2: str, verbose: bool = False) -> bool: + """Check if two answers are equivalent.""" + if str1 is None and str2 is None: + return True + if str1 is None or str2 is None: + return False + + try: + ss1 = strip_string(str1) + ss2 = strip_string(str2) + + if verbose: + print(f"Comparing: '{ss1}' vs '{ss2}'") + + # Direct string comparison + if ss1 == ss2: + return True + + # Try numeric comparison + try: + num1 = extract_number(ss1) + num2 = extract_number(ss2) + if num1 is not None and num2 is not None: + # Compare as floats with tolerance + return abs(num1 - num2) < 0.01 + except (ValueError, TypeError): + pass + + return False + except Exception: + return str1 == str2 + + +def extract_number(s: str) -> float: + """Extract number from string.""" + if s is None: + return None + + s = s.strip() + + # Remove common wrappers + s = re.sub(r"\\boxed\{([^}]*)\}", r"\1", s) + s = re.sub(r"\$([^$]*)\$", r"\1", s) + s = s.strip() + + # Try direct parse + try: + return float(s) + except ValueError: + pass + + # Try to find numbers in the string + matches = re.findall(r"-?\d+\.?\d*", s) + if matches: + return float(matches[-1]) + + return None + + +def 
remove_boxed(s: str) -> str: + """Remove \\boxed{} wrapper from string.""" + if "\\boxed " in s: + left = "\\boxed " + assert s[: len(left)] == left + return s[len(left) :] + + left = "\\boxed{" + assert s[: len(left)] == left + assert s[-1] == "}" + return s[len(left) : -1] + + +def last_boxed_only_string(string: str) -> str: + """Extract the last \\boxed{} content from string.""" + idx = string.rfind("\\boxed") + if "\\boxed " in string: + return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0] + if idx < 0: + idx = string.rfind("\\fbox") + if idx < 0: + return None + + i = idx + right_brace_idx = None + num_left_braces_open = 0 + while i < len(string): + if string[i] == "{": + num_left_braces_open += 1 + if string[i] == "}": + num_left_braces_open -= 1 + if num_left_braces_open == 0: + right_brace_idx = i + break + i += 1 + + if right_brace_idx is None: + return None + return string[idx : right_brace_idx + 1] + + +def strip_string(string: str) -> str: + """Normalize string for comparison.""" + if string is None: + return "" + + string = string.replace("\n", "") + string = string.replace("\\!", "") + string = string.replace("\\\\", "\\") + string = string.replace("\\left", "") + string = string.replace("\\right", "") + string = string.replace("$", "") + string = string.replace("\\$", "") + string = string.replace(" ", "") + + # Handle float formatting (e.g., "27.0" -> "27") + try: + num = float(string) + if num == int(num): + string = str(int(num)) + except ValueError: + pass + + return string diff --git a/research/eval/tasks/amc2023_grail/amc2023_grail.yaml b/research/eval/tasks/amc2023_grail/amc2023_grail.yaml new file mode 100644 index 00000000..995537b5 --- /dev/null +++ b/research/eval/tasks/amc2023_grail/amc2023_grail.yaml @@ -0,0 +1,25 @@ +tag: + - math_word_problems +task: amc2023_grail +dataset_path: sparkle-reasoning/amc2023 +output_type: generate_until +test_split: test +fewshot_split: test +doc_to_text: "{{question}}" +doc_to_target: 
"\n{{solution}}\n\n{{answer}}" +process_results: !function utils.process_results +metric_list: + - metric: exact_match + aggregation: mean + higher_is_better: true +generation_kwargs: + until: + - "<|im_end|>" + - "<|endoftext|>" + do_sample: false + temperature: 0.0 + max_gen_toks: 4096 +repeats: 1 +num_fewshot: 0 +metadata: + version: 1.0 diff --git a/research/eval/tasks/amc2023_grail/utils.py b/research/eval/tasks/amc2023_grail/utils.py new file mode 100644 index 00000000..45612c18 --- /dev/null +++ b/research/eval/tasks/amc2023_grail/utils.py @@ -0,0 +1,167 @@ +"""AMC 2023 evaluation utilities for GRAIL reasoning models. + +Extracts answers from ... tags and uses robust +numeric comparison for AMC answers. +""" + +import re + + +def process_results(doc: dict, results: list[str]) -> dict[str, int]: + """Process model output and compare with target answer.""" + retval = 0 + response = results[0] + + # Extract answer from ... tags first (for reasoning models) + solution_match = re.search(r"(.*?)", response, re.DOTALL) + if solution_match: + answer = solution_match.group(1).strip() + else: + # Fallback: try to extract from $...$ format + indices = [pos for pos, char in enumerate(response) if char == "$"] + if len(indices) >= 2: + answer = response[indices[0] + 1 : indices[-1]] + else: + # Fallback: try to extract from \boxed{} + boxed_answer = last_boxed_only_string(response) + if boxed_answer is not None: + try: + answer = remove_boxed(boxed_answer) + except (AssertionError, IndexError): + answer = response + else: + answer = response + + # Get target answer + target = str(doc.get("answer", "")) + + # Compare answers + if is_equiv(answer, target): + retval = 1 + + return {"exact_match": retval} + + +def is_equiv(str1: str, str2: str, verbose: bool = False) -> bool: + """Check if two answers are equivalent.""" + if str1 is None and str2 is None: + return True + if str1 is None or str2 is None: + return False + + try: + ss1 = strip_string(str1) + ss2 = 
strip_string(str2) + + if verbose: + print(f"Comparing: '{ss1}' vs '{ss2}'") + + # Direct string comparison + if ss1 == ss2: + return True + + # Try numeric comparison + try: + num1 = extract_number(ss1) + num2 = extract_number(ss2) + if num1 is not None and num2 is not None: + # Compare as floats with tolerance + return abs(num1 - num2) < 0.01 + except (ValueError, TypeError): + pass + + return False + except Exception: + return str1 == str2 + + +def extract_number(s: str) -> float: + """Extract number from string.""" + if s is None: + return None + + s = s.strip() + + # Remove common wrappers + s = re.sub(r"\\boxed\{([^}]*)\}", r"\1", s) + s = re.sub(r"\$([^$]*)\$", r"\1", s) + s = s.strip() + + # Try direct parse + try: + return float(s) + except ValueError: + pass + + # Try to find numbers in the string + matches = re.findall(r"-?\d+\.?\d*", s) + if matches: + return float(matches[-1]) + + return None + + +def remove_boxed(s: str) -> str: + """Remove \\boxed{} wrapper from string.""" + if "\\boxed " in s: + left = "\\boxed " + assert s[: len(left)] == left + return s[len(left) :] + + left = "\\boxed{" + assert s[: len(left)] == left + assert s[-1] == "}" + return s[len(left) : -1] + + +def last_boxed_only_string(string: str) -> str: + """Extract the last \\boxed{} content from string.""" + idx = string.rfind("\\boxed") + if "\\boxed " in string: + return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0] + if idx < 0: + idx = string.rfind("\\fbox") + if idx < 0: + return None + + i = idx + right_brace_idx = None + num_left_braces_open = 0 + while i < len(string): + if string[i] == "{": + num_left_braces_open += 1 + if string[i] == "}": + num_left_braces_open -= 1 + if num_left_braces_open == 0: + right_brace_idx = i + break + i += 1 + + if right_brace_idx is None: + return None + return string[idx : right_brace_idx + 1] + + +def strip_string(string: str) -> str: + """Normalize string for comparison.""" + if string is None: + return "" + + string = 
string.replace("\n", "") + string = string.replace("\\!", "") + string = string.replace("\\\\", "\\") + string = string.replace("\\left", "") + string = string.replace("\\right", "") + string = string.replace("$", "") + string = string.replace("\\$", "") + string = string.replace(" ", "") + + # Handle float formatting (e.g., "27.0" -> "27") + try: + num = float(string) + if num == int(num): + string = str(int(num)) + except ValueError: + pass + + return string diff --git a/research/eval/tasks/gsm8k_grail/gsm8k_grail.yaml b/research/eval/tasks/gsm8k_grail/gsm8k_grail.yaml new file mode 100644 index 00000000..c7fa8eba --- /dev/null +++ b/research/eval/tasks/gsm8k_grail/gsm8k_grail.yaml @@ -0,0 +1,26 @@ +tag: + - math_word_problems +task: gsm8k_grail +dataset_path: gsm8k +dataset_name: main +output_type: generate_until +test_split: test +fewshot_split: train +doc_to_text: "{{question}}" +doc_to_target: !function utils.doc_to_target +process_results: !function utils.process_results +metric_list: + - metric: exact_match + aggregation: mean + higher_is_better: true +generation_kwargs: + until: + - "<|im_end|>" + - "<|endoftext|>" + do_sample: false + temperature: 0.0 + max_gen_toks: 1024 +repeats: 1 +num_fewshot: 0 +metadata: + version: 1.0 diff --git a/research/eval/tasks/gsm8k_grail/utils.py b/research/eval/tasks/gsm8k_grail/utils.py new file mode 100644 index 00000000..be72529c --- /dev/null +++ b/research/eval/tasks/gsm8k_grail/utils.py @@ -0,0 +1,140 @@ +"""GSM8K evaluation utilities for GRAIL reasoning models. + +Extracts answers from ... tags and compares with +the ground truth answer (after ####). 
+""" + +import re + + +def doc_to_target(doc: dict) -> str: + """Convert document to target format for GRAIL reasoning.""" + answer = doc["answer"] + # Extract final answer after #### + if "####" in answer: + final_answer = answer.split("####")[-1].strip() + else: + final_answer = answer.strip() + + # Extract the reasoning part (before ####) + if "####" in answer: + reasoning = answer.split("####")[0].strip() + else: + reasoning = "" + + return ( + f"\n{reasoning}\n\n{final_answer}" + ) + + +def process_results(doc: dict, results: list[str]) -> dict[str, int]: + """Process model output and compare with target answer.""" + retval = 0 + response = results[0] + + # Extract answer from ... tags first (for reasoning models) + solution_match = re.search(r"(.*?)", response, re.DOTALL) + if solution_match: + answer = solution_match.group(1).strip() + else: + # Fallback: try to extract number from the end of response + answer = extract_last_number(response) + + # Get target answer from document + target_answer = doc["answer"] + if "####" in target_answer: + target = target_answer.split("####")[-1].strip() + else: + target = target_answer.strip() + + # Compare answers + if is_equiv(answer, target): + retval = 1 + + return {"exact_match": retval} + + +def is_equiv(str1: str, str2: str, verbose: bool = False) -> bool: + """Check if two answers are equivalent.""" + if str1 is None and str2 is None: + return True + if str1 is None or str2 is None: + return False + + try: + # Clean both strings + s1 = clean_answer(str1) + s2 = clean_answer(str2) + + if verbose: + print(f"Comparing: '{s1}' vs '{s2}'") + + # Direct string comparison + if s1 == s2: + return True + + # Try numeric comparison + try: + num1 = extract_number(s1) + num2 = extract_number(s2) + if num1 is not None and num2 is not None: + return abs(num1 - num2) < 0.001 + except (ValueError, TypeError): + pass + + return False + except Exception: + return str1 == str2 + + +def clean_answer(s: str) -> str: + """Clean answer 
string for comparison.""" + if s is None: + return "" + + s = s.strip() + # Remove dollar signs, commas, and common formatting + s = s.replace("$", "").replace(",", "").replace(" ", "") + # Remove trailing period + s = s.rstrip(".") + + return s + + +def extract_number(s: str) -> float: + """Extract number from string.""" + if s is None: + return None + + s = clean_answer(s) + + # Try direct parse + try: + return float(s) + except ValueError: + pass + + # Try to find numbers in the string + matches = re.findall(r"-?\d+\.?\d*", s) + if matches: + return float(matches[-1]) + + return None + + +def extract_last_number(s: str) -> str: + """Extract the last number from a string.""" + if s is None: + return "" + + # Look for #### pattern first (GSM8K format) + if "####" in s: + return s.split("####")[-1].strip() + + # Find all numbers + matches = re.findall(r"-?\d+(?:,\d{3})*(?:\.\d+)?", s) + if matches: + # Return last number, removing commas + return matches[-1].replace(",", "") + + return s.strip() diff --git a/research/eval/tasks/hendrycks_math_grail/_default_template.yaml b/research/eval/tasks/hendrycks_math_grail/_default_template.yaml new file mode 100644 index 00000000..4e94637e --- /dev/null +++ b/research/eval/tasks/hendrycks_math_grail/_default_template.yaml @@ -0,0 +1,21 @@ +dataset_path: EleutherAI/hendrycks_math +process_docs: !function utils.process_docs +output_type: generate_until +training_split: train +test_split: test +doc_to_text: "{{problem}}" +doc_to_target: "\n{{solution}}\n\n{{answer}}" +process_results: !function utils.process_results +generation_kwargs: + until: + - "<|im_end|>" + - "<|endoftext|>" + do_sample: false + temperature: 0 + max_gen_toks: 2048 +metric_list: + - metric: exact_match + aggregation: mean + higher_is_better: true +metadata: + version: 1.0 diff --git a/research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail.yaml b/research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail.yaml new file mode 100644 index 
00000000..c23e1ba4 --- /dev/null +++ b/research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail.yaml @@ -0,0 +1,15 @@ +group: hendrycks_math_grail +task: + - hendrycks_math_grail_algebra + - hendrycks_math_grail_counting_and_prob + - hendrycks_math_grail_geometry + - hendrycks_math_grail_intermediate_algebra + - hendrycks_math_grail_num_theory + - hendrycks_math_grail_prealgebra + - hendrycks_math_grail_precalc +aggregate_metric_list: + - metric: exact_match + aggregation: mean + weight_by_size: true +metadata: + version: 1.0 diff --git a/research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail_algebra.yaml b/research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail_algebra.yaml new file mode 100644 index 00000000..95fe5683 --- /dev/null +++ b/research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail_algebra.yaml @@ -0,0 +1,5 @@ +include: _default_template.yaml +tag: + - math_word_problems +task: hendrycks_math_grail_algebra +dataset_name: algebra diff --git a/research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail_counting_and_prob.yaml b/research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail_counting_and_prob.yaml new file mode 100644 index 00000000..dfa695a0 --- /dev/null +++ b/research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail_counting_and_prob.yaml @@ -0,0 +1,5 @@ +include: _default_template.yaml +tag: + - math_word_problems +task: hendrycks_math_grail_counting_and_prob +dataset_name: counting_and_probability diff --git a/research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail_geometry.yaml b/research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail_geometry.yaml new file mode 100644 index 00000000..5743de5d --- /dev/null +++ b/research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail_geometry.yaml @@ -0,0 +1,5 @@ +include: _default_template.yaml +tag: + - math_word_problems +task: hendrycks_math_grail_geometry +dataset_name: geometry diff --git 
a/research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail_intermediate_algebra.yaml b/research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail_intermediate_algebra.yaml new file mode 100644 index 00000000..a9db9246 --- /dev/null +++ b/research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail_intermediate_algebra.yaml @@ -0,0 +1,5 @@ +include: _default_template.yaml +tag: + - math_word_problems +task: hendrycks_math_grail_intermediate_algebra +dataset_name: intermediate_algebra diff --git a/research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail_num_theory.yaml b/research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail_num_theory.yaml new file mode 100644 index 00000000..95e3260a --- /dev/null +++ b/research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail_num_theory.yaml @@ -0,0 +1,5 @@ +include: _default_template.yaml +tag: + - math_word_problems +task: hendrycks_math_grail_num_theory +dataset_name: number_theory diff --git a/research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail_prealgebra.yaml b/research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail_prealgebra.yaml new file mode 100644 index 00000000..c8e8bde6 --- /dev/null +++ b/research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail_prealgebra.yaml @@ -0,0 +1,5 @@ +include: _default_template.yaml +tag: + - math_word_problems +task: hendrycks_math_grail_prealgebra +dataset_name: prealgebra diff --git a/research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail_precalc.yaml b/research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail_precalc.yaml new file mode 100644 index 00000000..81594a08 --- /dev/null +++ b/research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail_precalc.yaml @@ -0,0 +1,5 @@ +include: _default_template.yaml +tag: + - math_word_problems +task: hendrycks_math_grail_precalc +dataset_name: precalculus diff --git a/research/eval/tasks/hendrycks_math_grail/utils.py b/research/eval/tasks/hendrycks_math_grail/utils.py new file mode 100644 
index 00000000..0853668d --- /dev/null +++ b/research/eval/tasks/hendrycks_math_grail/utils.py @@ -0,0 +1,251 @@ +"""Custom utils for GRAIL reasoning model evaluation on MATH. + +Extracts answers from ... tags instead of \\boxed{}. +""" + +import re + +import datasets + + +def process_docs(dataset: datasets.Dataset) -> datasets.Dataset: + """Process dataset docs - extract ground truth answer from \\boxed{}.""" + + def _process_doc(doc: dict) -> dict: + out_doc = { + "problem": doc["problem"], + "solution": doc["solution"], + "answer": remove_boxed(last_boxed_only_string(doc["solution"])), + } + return out_doc + + return dataset.map(_process_doc) + + +def extract_solution_tag(text: str) -> str: + """Extract content from ... tags.""" + match = re.search(r"([\s\S]*?)", text) + if match: + return match.group(1).strip() + # Fallback: return original text if no tags found + return text.strip() + + +def process_results(doc: dict, results: list[str]) -> dict[str, int]: + """Process results - extract answer from tags and compare.""" + retval = 0 + + # Extract from tags + model_answer = extract_solution_tag(results[0]) + + # Get ground truth (already extracted from \boxed{} in process_docs) + ground_truth = doc.get("answer", remove_boxed(last_boxed_only_string(doc["solution"]))) + + if is_equiv(model_answer, ground_truth): + retval = 1 + + return {"exact_match": retval} + + +# ============================================================================ +# String normalization functions (from lm-eval hendrycks_math/utils.py) +# ============================================================================ + + +def is_equiv(str1: str, str2: str, verbose: bool = False) -> bool: + """Check if two strings are equivalent after normalization.""" + if str1 is None and str2 is None: + return True + if str1 is None or str2 is None: + return False + + try: + ss1 = strip_string(str1) + ss2 = strip_string(str2) + if verbose: + print(ss1, ss2) + return ss1 == ss2 + except Exception: + 
return str1 == str2 + + +def remove_boxed(s: str) -> str: + """Remove \\boxed{} wrapper from string.""" + if s is None: + return None + if "\\boxed " in s: + left = "\\boxed " + if s[: len(left)] == left: + return s[len(left) :] + + left = "\\boxed{" + if s[: len(left)] == left and s[-1] == "}": + return s[len(left) : -1] + + return s + + +def last_boxed_only_string(string: str) -> str: + """Extract the last \\boxed{} or \\fbox{} from a string.""" + idx = string.rfind("\\boxed") + if "\\boxed " in string: + return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0] + if idx < 0: + idx = string.rfind("\\fbox") + if idx < 0: + return None + + i = idx + right_brace_idx = None + num_left_braces_open = 0 + while i < len(string): + if string[i] == "{": + num_left_braces_open += 1 + if string[i] == "}": + num_left_braces_open -= 1 + if num_left_braces_open == 0: + right_brace_idx = i + break + i += 1 + + if right_brace_idx is None: + return None + return string[idx : right_brace_idx + 1] + + +def fix_fracs(string: str) -> str: + """Fix fraction formatting.""" + substrs = string.split("\\frac") + new_str = substrs[0] + if len(substrs) > 1: + substrs = substrs[1:] + for substr in substrs: + new_str += "\\frac" + if substr[0] == "{": + new_str += substr + else: + try: + assert len(substr) >= 2 + except AssertionError: + return string + a = substr[0] + b = substr[1] + if b != "{": + if len(substr) > 2: + post_substr = substr[2:] + new_str += "{" + a + "}{" + b + "}" + post_substr + else: + new_str += "{" + a + "}{" + b + "}" + else: + if len(substr) > 2: + post_substr = substr[2:] + new_str += "{" + a + "}" + b + post_substr + else: + new_str += "{" + a + "}" + b + return new_str + + +def fix_a_slash_b(string: str) -> str: + """Convert a/b to \\frac{a}{b}.""" + if len(string.split("/")) != 2: + return string + a = string.split("/")[0] + b = string.split("/")[1] + try: + a = int(a) + b = int(b) + assert string == f"{a}/{b}" + return "\\frac{" + str(a) + "}{" + str(b) + "}" 
+ except (AssertionError, ValueError): + return string + + +def remove_right_units(string: str) -> str: + """Remove units on the right side.""" + if "\\text{ " in string: + splits = string.split("\\text{ ") + if len(splits) == 2: + return splits[0] + return string + + +def fix_sqrt(string: str) -> str: + """Fix sqrt formatting.""" + if "\\sqrt" not in string: + return string + splits = string.split("\\sqrt") + new_string = splits[0] + for split in splits[1:]: + if split[0] != "{": + a = split[0] + new_substr = "\\sqrt{" + a + "}" + split[1:] + else: + new_substr = "\\sqrt" + split + new_string += new_substr + return new_string + + +def strip_string(string: str) -> str: + """Normalize string for comparison.""" + # linebreaks + string = string.replace("\n", "") + + # remove inverse spaces + string = string.replace("\\!", "") + + # replace \\ with \ + string = string.replace("\\\\", "\\") + + # replace tfrac and dfrac with frac + string = string.replace("tfrac", "frac") + string = string.replace("dfrac", "frac") + + # remove \left and \right + string = string.replace("\\left", "") + string = string.replace("\\right", "") + + # Remove circ (degrees) + string = string.replace("^{\\circ}", "") + string = string.replace("^\\circ", "") + + # remove dollar signs + string = string.replace("\\$", "") + + # remove units (on the right) + string = remove_right_units(string) + + # remove percentage + string = string.replace("\\%", "") + string = string.replace("\%", "") # noqa: W605 + + # " 0." equivalent to " ." and "{0." equivalent to "{." + string = string.replace(" .", " 0.") + string = string.replace("{.", "{0.") + + if len(string) == 0: + return string + if string[0] == ".": + string = "0" + string + + # get rid of e.g. 
"k = " or "q = " at beginning + if len(string.split("=")) == 2: + if len(string.split("=")[0]) <= 2: + string = string.split("=")[1] + + # fix sqrt3 --> sqrt{3} + string = fix_sqrt(string) + + # remove spaces + string = string.replace(" ", "") + + # fix fractions + string = fix_fracs(string) + + # manually change 0.5 --> \frac{1}{2} + if string == "0.5": + string = "\\frac{1}{2}" + + # X/Y changed to \frac{X}{Y} + string = fix_a_slash_b(string) + + return string diff --git a/research/eval/tasks/hendrycks_math_grail_pass_at_k/hendrycks_math_pass_at_5.yaml b/research/eval/tasks/hendrycks_math_grail_pass_at_k/hendrycks_math_pass_at_5.yaml new file mode 100644 index 00000000..6870678b --- /dev/null +++ b/research/eval/tasks/hendrycks_math_grail_pass_at_k/hendrycks_math_pass_at_5.yaml @@ -0,0 +1,33 @@ +# MATH pass@5 evaluation task +# Generates 10 samples per problem with temperature sampling +# Computes pass@1, pass@5 + +group: hendrycks_math_pass_at_k +task: hendrycks_math_pass_at_5 +dataset_path: EleutherAI/hendrycks_math +dataset_name: algebra +output_type: generate_until +training_split: train +test_split: test +doc_to_text: "{{problem}}" +doc_to_target: "{{answer}}" +process_results: !function utils.process_results +metric_list: + - metric: !function utils.aggregate_pass_at_1 + aggregation: mean + higher_is_better: true + - metric: !function utils.aggregate_pass_at_5 + aggregation: mean + higher_is_better: true +generation_kwargs: + until: + - "<|im_end|>" + - "<|endoftext|>" + do_sample: true + temperature: 0.7 + top_p: 0.95 + max_gen_toks: 2048 +repeats: 10 +num_fewshot: 0 +metadata: + version: 1.0 diff --git a/research/eval/tasks/hendrycks_math_grail_pass_at_k/utils.py b/research/eval/tasks/hendrycks_math_grail_pass_at_k/utils.py new file mode 100644 index 00000000..565fe2a5 --- /dev/null +++ b/research/eval/tasks/hendrycks_math_grail_pass_at_k/utils.py @@ -0,0 +1,249 @@ +"""Pass@k utilities for MATH benchmark evaluation. 
+ +Implements the standard pass@k metric: given n samples, compute the probability +that at least one of k random samples is correct. + +Formula: pass@k = 1 - C(n-c, k) / C(n, k) +where n = total samples, c = correct samples, k = samples to consider +""" + +import re +from math import comb +from typing import Any + + +def pass_at_k( + references: list[str], + predictions: list[list[str]], + k: list[int] | None = None, +) -> dict[str, float]: + """Compute pass@k for math problems. + + Args: + references: List of ground truth answers + predictions: List of lists of model predictions (n samples per problem) + k: List of k values to compute (e.g., [1, 5, 10]) + + Returns: + Dictionary with pass@k scores for each k value + """ + if k is None: + k = [1, 5] + if isinstance(k, int): + k = [k] + + results = {} + for k_val in k: + pass_at_k_scores = [] + for ref, preds in zip(references, predictions, strict=False): + n = len(preds) + if n < k_val: + # If we have fewer samples than k, use what we have + c = sum(1 for p in preds if is_equiv(extract_answer(p), ref)) + score = 1.0 if c > 0 else 0.0 + else: + c = sum(1 for p in preds if is_equiv(extract_answer(p), ref)) + score = _pass_at_k(n, c, k_val) + pass_at_k_scores.append(score) + + avg_score = sum(pass_at_k_scores) / len(pass_at_k_scores) if pass_at_k_scores else 0.0 + results[f"pass@{k_val}"] = avg_score + + return results + + +def _pass_at_k(n: int, c: int, k: int) -> float: + """Compute pass@k for a single problem. + + Args: + n: Total number of samples + c: Number of correct samples + k: Number of samples to consider + + Returns: + Probability that at least one of k samples is correct + """ + if n - c < k: + return 1.0 + return 1.0 - comb(n - c, k) / comb(n, k) + + +def extract_answer(text: str) -> str: + """Extract answer from model output. + + Tries multiple extraction patterns in order: + 1. ... tags (for reasoning models) + 2. \\boxed{...} (LaTeX boxed answers) + 3. $...$ (LaTeX inline math at end) + 4. 
Last number in text + """ + if text is None: + return "" + + # Try tags first (reasoning models) + solution_match = re.search(r"(.*?)", text, re.DOTALL) + if solution_match: + return solution_match.group(1).strip() + + # Try \boxed{} + boxed = last_boxed_only_string(text) + if boxed: + try: + return remove_boxed(boxed) + except (AssertionError, IndexError): + pass + + # Try $...$ at end + indices = [pos for pos, char in enumerate(text) if char == "$"] + if len(indices) >= 2: + return text[indices[-2] + 1 : indices[-1]] + + # Fallback: return cleaned text + return text.strip() + + +def is_equiv(str1: str, str2: str) -> bool: + """Check if two answers are mathematically equivalent.""" + if str1 is None and str2 is None: + return True + if str1 is None or str2 is None: + return False + + try: + ss1 = strip_string(str1) + ss2 = strip_string(str2) + return ss1 == ss2 + except Exception: + return str1 == str2 + + +def remove_boxed(s: str) -> str: + """Remove \\boxed{} wrapper from string.""" + if "\\boxed " in s: + left = "\\boxed " + assert s[: len(left)] == left + return s[len(left) :] + + left = "\\boxed{" + assert s[: len(left)] == left + assert s[-1] == "}" + return s[len(left) : -1] + + +def last_boxed_only_string(string: str) -> str | None: + """Extract the last \\boxed{} content from string.""" + idx = string.rfind("\\boxed") + if "\\boxed " in string: + return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0] + if idx < 0: + idx = string.rfind("\\fbox") + if idx < 0: + return None + + i = idx + right_brace_idx = None + num_left_braces_open = 0 + while i < len(string): + if string[i] == "{": + num_left_braces_open += 1 + if string[i] == "}": + num_left_braces_open -= 1 + if num_left_braces_open == 0: + right_brace_idx = i + break + i += 1 + + if right_brace_idx is None: + return None + return string[idx : right_brace_idx + 1] + + +def strip_string(string: str) -> str: + """Normalize string for comparison.""" + if string is None: + return "" + + # Remove 
linebreaks + string = string.replace("\n", "") + + # Remove common LaTeX + string = string.replace("\\!", "") + string = string.replace("\\\\", "\\") + string = string.replace("tfrac", "frac") + string = string.replace("dfrac", "frac") + string = string.replace("\\left", "") + string = string.replace("\\right", "") + string = string.replace("^{\\circ}", "") + string = string.replace("^\\circ", "") + + # Remove dollar signs + string = string.replace("$", "") + string = string.replace("\\$", "") + + # Remove percentage + string = string.replace("\\%", "") + string = string.replace("%", "") + + # Remove spaces + string = string.replace(" ", "") + + return string + + +# For lm-eval metric interface +def process_results(doc: dict, results: list[str]) -> dict[str, Any]: + """Process results for a single document (used by lm-eval). + + This function is called for EACH sample. The pass@k aggregation + happens in the aggregation function. + """ + answer = doc.get("answer", "") + + # Check each result + correct_list = [] + for result in results: + extracted = extract_answer(result) + is_correct = 1 if is_equiv(extracted, str(answer)) else 0 + correct_list.append(is_correct) + + return { + "pass": correct_list, # List of 0/1 for each sample + "num_correct": sum(correct_list), + "num_samples": len(correct_list), + } + + +def aggregate_pass_at_k(results: list[dict], k: int = 5) -> float: + """Aggregate pass@k across all documents. 
+ + Args: + results: List of result dicts from process_results + k: Number of samples to consider for pass@k + + Returns: + Average pass@k score + """ + scores = [] + for r in results: + n = r["num_samples"] + c = r["num_correct"] + if n < k: + score = 1.0 if c > 0 else 0.0 + else: + score = _pass_at_k(n, c, k) + scores.append(score) + + return sum(scores) / len(scores) if scores else 0.0 + + +# Convenience aggregation functions for different k values +def aggregate_pass_at_1(results: list[dict]) -> float: + return aggregate_pass_at_k(results, k=1) + + +def aggregate_pass_at_5(results: list[dict]) -> float: + return aggregate_pass_at_k(results, k=5) + + +def aggregate_pass_at_10(results: list[dict]) -> float: + return aggregate_pass_at_k(results, k=10) From 003fe3cb5e1c7f6e209797f1aadc3df63cc45cef Mon Sep 17 00:00:00 2001 From: Erfan Miahi Date: Tue, 2 Dec 2025 05:37:08 +0000 Subject: [PATCH 14/17] feat(training): enhance TRL GRPO training script with detailed hyperparameter configuration, add TrainingPassAtKTracker for pass@k metrics logging, and update evaluation callback for improved WandB integration --- research/trl/train_trl_grpo.py | 396 +++++++++++++++++++++++++++------ 1 file changed, 330 insertions(+), 66 deletions(-) diff --git a/research/trl/train_trl_grpo.py b/research/trl/train_trl_grpo.py index b631e97c..ae818ad0 100644 --- a/research/trl/train_trl_grpo.py +++ b/research/trl/train_trl_grpo.py @@ -49,56 +49,102 @@ # ════════════════════════════════════════════════════════════════════════════ -# HYPERPARAMETERS (from .env GRAIL config) +# HYPERPARAMETERS (from .env GRAIL config - exactly matching grail/trainer/algorithms/grpo.py) # ════════════════════════════════════════════════════════════════════════════ @dataclass class Config: - # Model (from GRAIL_TRAIN_MODEL_ID) + # ──────────────────────────────────────────────────────────────────────── + # Model Configuration (from GRAIL_TRAIN_MODEL_ID) + # 
──────────────────────────────────────────────────────────────────────── model_id: str = "Qwen/Qwen2.5-1.5B-Instruct" - # Learning rate (from GRAIL_TRAINER_LR) + + # ──────────────────────────────────────────────────────────────────────── + # Training Hyperparameters (from grail/shared/constants.py + env vars) + # These match GRAIL's GRPOAlgorithm config exactly + # ──────────────────────────────────────────────────────────────────────── + # Learning rate (GRAIL_TRAINER_LR, constants.py default: 1e-6) lr: float = 3e-6 - # Epochs per window (from GRAIL_TRAINER_EPOCHS) + # Epochs per training iteration (GRAIL_TRAINER_EPOCHS, constants.py default: 1) epochs: int = 1 - # Batch size (from GRAIL_TRAINER_BATCH_SIZE) + # Batch size per device (GRAIL_TRAINER_BATCH_SIZE, constants.py default: 16) batch_size: int = 4 - # Gradient accumulation (from GRAIL_TRAINER_GRAD_ACCUM_STEPS) + # Gradient accumulation steps (GRAIL_TRAINER_GRAD_ACCUM_STEPS, constants.py default: 8) + # Effective batch = batch_size × grad_accum_steps = 4 × 128 = 512 grad_accum_steps: int = 128 - # Max sequence length (from GRAIL_TRAINER_MAX_LENGTH) + # Max sequence length (GRAIL_TRAINER_MAX_LENGTH, constants.py default: 2048) max_length: int = 2048 - # Gradient clipping (from GRAIL_TRAINER_GRAD_CLIP) + # Gradient clipping threshold (GRAIL_TRAINER_GRAD_CLIP, constants.py default: 0.5) grad_clip: float = 1.0 - # Warmup steps (from GRAIL_TRAINER_WARMUP_STEPS) + # Warmup steps for LR scheduler (GRAIL_TRAINER_WARMUP_STEPS, constants.py default: 10) warmup_steps: int = 50 - # KL coefficient (from GRAIL_TRAINER_KL_COEF) + # Total training windows (GRAIL_TRAINER_TOTAL_WINDOWS) - controls iteration count + # Each optimizer step = 32 groups × 16 rollouts = 512 samples + # total_optimizer_steps calculated below based on total_windows + total_windows: int = 100 + + # ──────────────────────────────────────────────────────────────────────── + # GRPO Loss Configuration (from grail/trainer/algorithms/grpo.py) + # 
──────────────────────────────────────────────────────────────────────── + # KL divergence coefficient (GRAIL_TRAINER_KL_COEF, constants.py default: 0.02) kl_coef: float = 0.0 - # Entropy coefficient (from GRAIL_TRAINER_ENTROPY_COEF) + # Entropy coefficient for exploration (GRAIL_TRAINER_ENTROPY_COEF, constants.py default: 0.001) + # Note: TRL may not support entropy regularization directly entropy_coef: float = 0.0005 - # PPO clip epsilon (standard GRAIL values) + # PPO clip epsilon lower bound (TRAINER_PPO_CLIP_EPS, constants.py default: 0.2) ppo_clip_eps: float = 0.2 + # PPO clip epsilon upper bound - DAPO-style asymmetric clipping + # (TRAINER_PPO_CLIP_EPS_UPPER, constants.py default: 0.28) ppo_clip_eps_upper: float = 0.28 - # Importance sampling ratio max (from GRAIL_TRAINER_IS_RATIO_MAX) + # Importance sampling ratio ceiling (GRAIL_TRAINER_IS_RATIO_MAX, constants.py default: 10.0) + # Prevents training instability from extreme ratios is_ratio_max: float = 2.5 - # Log-ratio clamp (from GRAIL_TRAINER_LOGRATIO_CLAMP) + # Log-ratio clamp for numerical stability (GRAIL_TRAINER_LOGRATIO_CLAMP, constants.py default: 5.0) + # ln(2.5) ≈ 0.916 → aligned with IS_RATIO_MAX logratio_clamp: float = 0.92 - # Dataset sampling + # Advantage clipping percentile (GRAIL_TRAINER_ADV_CLIP_PERCENTILE, constants.py default: 99.0) + # Note: TRL handles advantage normalization differently + adv_clip_percentile: float = 99.0 + # Group advantage sum tolerance (GRAIL_TRAINER_GROUP_ADV_SUM_TOL, constants.py default: 0.01) + # Note: TRL doesn't use group validation, but kept for reference + group_adv_sum_tol: float = 0.01 + # GRPO loss variant (GRAIL_GRPO_VARIANT, constants.py default: "dapo") + # Options: 'grpo', 'bnpo', 'dapo', 'dr_grpo' + grpo_variant: str = "dapo" + # Importance sampling level (GRAIL_IMPORTANCE_SAMPLING_LEVEL, constants.py default: "sequence") + # Options: 'sequence' (one ratio per sequence), 'token' (per-token ratios) + # Note: TRL uses token-level IS by default when 
using vLLM + importance_sampling_level: str = "sequence" + + # ──────────────────────────────────────────────────────────────────────── + # GRPO Data Configuration (from grail/shared/constants.py) + # ──────────────────────────────────────────────────────────────────────── + # Groups per optimizer step = effective_batch / rollouts_per_problem = 512 / 16 = 32 + max_groups: int = 32 + # Max completion tokens (GRPO_MAX_COMPLETION_TOKENS, constants.py default: 1024) + max_new_tokens: int = 1024 + # Rollouts per problem (ROLLOUTS_PER_PROBLEM, constants.py: 16) + rollouts_per_problem: int = 16 + + # ──────────────────────────────────────────────────────────────────────── + # Dataset Sampling + # ──────────────────────────────────────────────────────────────────────── num_train_samples: int | None = None # None = use all training samples num_eval_samples: int | None = None # None = use all test samples - # Rollouts per problem (matches GRAIL default) - rollouts_per_problem: int = 16 - # Generation parameters + + # ──────────────────────────────────────────────────────────────────────── + # Generation Parameters + # ──────────────────────────────────────────────────────────────────────── temperature: float = 0.7 top_p: float = 0.95 top_k: int = 50 - # Max completion tokens (from GRPO_MAX_COMPLETION_TOKENS) - max_new_tokens: int = 1024 - # Evaluation config + + # ──────────────────────────────────────────────────────────────────────── + # Evaluation Configuration + # ──────────────────────────────────────────────────────────────────────── eval_replicates: int = 5 report_ks: tuple[int, ...] 
= (1, 5, 10) - # Evaluation optimization eval_batch_size: int = 128 eval_num_workers: int = 4 - # Max groups for GRPO (from GRPO_MAX_GROUPS) - max_groups: int = 128 cfg = Config() @@ -507,6 +553,143 @@ def get_dataset_adapter(dataset_name: str) -> DatasetAdapter: return adapters[dataset_name.lower()]() +# ════════════════════════════════════════════════════════════════════════════ +# TRAINING PASS@K TRACKER +# ════════════════════════════════════════════════════════════════════════════ +class TrainingPassAtKTracker: + """Computes and logs pass@k metrics during GRPO training. + + This class wraps the reward computation and tracks pass@k metrics + by grouping completions by their prompts. Uses the same unbiased pass@k + formula as evaluation (KMetricsAggregator from grail.trainer.metrics). + + Usage: + tracker = TrainingPassAtKTracker(adapter, prompt_to_answer) + trainer = GRPOTrainer(..., reward_funcs=tracker, ...) + """ + + # Required by TRL GRPOTrainer for reward function naming + __name__ = "reward_with_pass_at_k" + + def __init__( + self, + adapter: DatasetAdapter, + prompt_to_answer: dict[str, str], + report_ks: tuple[int, ...] = (1, 5, 10), + ) -> None: + """Initialize the tracker. + + Args: + adapter: Dataset adapter for reward computation and success threshold + prompt_to_answer: Mapping from prompt text to gold answer + report_ks: Tuple of k values for pass@k metrics + """ + self._adapter = adapter + self._prompt_to_answer = prompt_to_answer + self._report_ks = report_ks + self._step_count = 0 + + def __call__( + self, + completions: list[str], + prompts: list[str], + **kwargs: Any, + ) -> list[float]: + """Compute rewards and log pass@k metrics. + + This method is called by GRPOTrainer for each batch of completions. + + Args: + completions: List of model completions + prompts: List of corresponding prompts + **kwargs: Additional arguments (gold_answer, metadatas, etc.) 
+ + Returns: + List of reward values for each completion + """ + gold_answers = self._extract_gold_answers(prompts, kwargs) + rewards = self._compute_rewards(completions, gold_answers) + metrics = self._compute_pass_at_k_metrics(prompts, rewards) + self._log_to_wandb(metrics) + self._step_count += 1 + return rewards + + def _extract_gold_answers( + self, + prompts: list[str], + kwargs: dict[str, Any], + ) -> list[str]: + """Extract gold answers from kwargs or prompt mapping.""" + if "gold_answer" in kwargs and kwargs["gold_answer"]: + return kwargs["gold_answer"] + if "metadatas" in kwargs and kwargs["metadatas"]: + return [m.get("gold_answer", "") for m in kwargs["metadatas"]] + return [self._prompt_to_answer.get(p, "") for p in prompts] + + def _compute_rewards( + self, + completions: list[str], + gold_answers: list[str], + ) -> list[float]: + """Compute reward for each completion.""" + return [ + self._adapter.compute_reward(c, g) + for c, g in zip(completions, gold_answers, strict=False) + ] + + def _compute_pass_at_k_metrics( + self, + prompts: list[str], + rewards: list[float], + ) -> dict[str, float]: + """Compute all metrics using KMetricsAggregator (unbiased pass@k formula).""" + from collections import defaultdict + + # Group rewards by prompt + prompt_groups: dict[str, list[float]] = defaultdict(list) + for prompt, reward in zip(prompts, rewards, strict=False): + prompt_groups[prompt].append(reward) + + group_count = len(prompt_groups) + expected_groups = cfg.max_groups + step_index = self._step_count + 1 + print( + "[TrainingPassAtKTracker] " + f"Step {step_index}: grouped {group_count} prompts " + f"(max_groups={expected_groups})" + ) + if group_count != expected_groups: + print( + "[TrainingPassAtKTracker] ⚠️ " + f"group_count ({group_count}) != max_groups ({expected_groups})" + ) + + # Use KMetricsAggregator for metrics computation + aggregator = KMetricsAggregator(report_ks=self._report_ks) + threshold = self._adapter.success_threshold + + for 
task_id, group_rewards in enumerate(prompt_groups.values()): + successes = [r >= threshold for r in group_rewards] + aggregator.add_group( + task_id=str(task_id), + rewards=group_rewards, + successes=successes, + ) + + return aggregator.summarize() + + def _log_to_wandb(self, metrics: dict[str, float]) -> None: + """Log metrics to WandB.""" + try: + import wandb + + if wandb.run is not None and metrics: + wandb_data = {f"train/{k}": v for k, v in metrics.items()} + wandb.log(wandb_data) + except Exception: + pass # Silently ignore WandB errors + + # ════════════════════════════════════════════════════════════════════════════ # DATA PREPARATION # ════════════════════════════════════════════════════════════════════════════ @@ -577,14 +760,14 @@ def __init__( eval_data: list[dict[str, Any]], tokenizer: PreTrainedTokenizer, vllm_base_url: str, - eval_every_n_steps: int = 30, + eval_every_n_steps: int = 40, ) -> None: self.adapter = adapter self.eval_data = eval_data self.tokenizer = tokenizer self.eval_every_n = eval_every_n_steps self.base_url = vllm_base_url.rstrip("/") - self._metrics_defined = False + self._wandb_configured = False print( f"✓ VLLMEvalCallback initialized: dataset={adapter.name}, " @@ -603,16 +786,15 @@ def run_and_log(self, step: int, label: str = "VLLM EVAL") -> dict[str, float]: import wandb if wandb.run is not None: - if not self._metrics_defined: + # Configure step metric for eval on first call + if not self._wandb_configured: wandb.define_metric("eval_step") - wandb.define_metric("eval_vllm/*", step_metric="eval_step") - self._metrics_defined = True + wandb.define_metric("eval/*", step_metric="eval_step") + self._wandb_configured = True - wandb_data = { - "eval_step": step, - "trainer/global_step": step, - } - wandb_data.update({f"eval_vllm/{k}": v for k, v in metrics.items()}) + # Log eval metrics with 'eval/' prefix and custom step + wandb_data = {"eval_step": step} + wandb_data.update({f"eval/{k}": v for k, v in metrics.items()}) 
wandb.log(wandb_data) except Exception as e: print(f"⚠️ WandB logging failed: {e}") @@ -816,7 +998,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--eval-every", type=int, - default=30, + default=40, help="Run evaluation every N steps (default: 30)", ) return parser.parse_args() @@ -826,10 +1008,42 @@ def main() -> None: args = parse_args() print(f"🚀 Starting TRL GRPO training with {args.dataset.upper()} dataset") - print("=" * 60) + print("=" * 80) + + # Print hyperparameter alignment summary + print("\n📋 GRAIL Hyperparameter Alignment Summary:") + print("─" * 80) + print(f" {'Parameter':<40} {'Value':<15} {'GRAIL Env Var'}") + print("─" * 80) + print(f" {'Model ID':<40} {cfg.model_id:<15} GRAIL_TRAIN_MODEL_ID") + print(f" {'Learning Rate':<40} {cfg.lr:<15} GRAIL_TRAINER_LR") + print(f" {'Epochs (per window)':<40} {cfg.epochs:<15} GRAIL_TRAINER_EPOCHS") + print(f" {'Batch Size':<40} {cfg.batch_size:<15} GRAIL_TRAINER_BATCH_SIZE") + print( + f" {'Gradient Accum Steps':<40} {cfg.grad_accum_steps:<15} GRAIL_TRAINER_GRAD_ACCUM_STEPS" + ) + print(f" {'Max Length':<40} {cfg.max_length:<15} GRAIL_TRAINER_MAX_LENGTH") + print(f" {'Max Completion Tokens':<40} {cfg.max_new_tokens:<15} GRPO_MAX_COMPLETION_TOKENS") + print(f" {'Gradient Clip':<40} {cfg.grad_clip:<15} GRAIL_TRAINER_GRAD_CLIP") + print(f" {'Warmup Steps':<40} {cfg.warmup_steps:<15} GRAIL_TRAINER_WARMUP_STEPS") + print(f" {'Total Windows':<40} {cfg.total_windows:<15} GRAIL_TRAINER_TOTAL_WINDOWS") + print(f" {'KL Coefficient':<40} {cfg.kl_coef:<15} GRAIL_TRAINER_KL_COEF") + print(f" {'Entropy Coefficient':<40} {cfg.entropy_coef:<15} GRAIL_TRAINER_ENTROPY_COEF") + print(f" {'PPO Clip Epsilon':<40} {cfg.ppo_clip_eps:<15} TRAINER_PPO_CLIP_EPS") + print( + f" {'PPO Clip Epsilon Upper':<40} {cfg.ppo_clip_eps_upper:<15} TRAINER_PPO_CLIP_EPS_UPPER" + ) + print(f" {'IS Ratio Max':<40} {cfg.is_ratio_max:<15} GRAIL_TRAINER_IS_RATIO_MAX") + print(f" {'Log-Ratio Clamp':<40} {cfg.logratio_clamp:<15} 
GRAIL_TRAINER_LOGRATIO_CLAMP") + print(f" {'GRPO Variant':<40} {cfg.grpo_variant:<15} GRAIL_GRPO_VARIANT") + print(f" {'IS Level':<40} {cfg.importance_sampling_level:<15} GRAIL_IMPORTANCE_SAMPLING_LEVEL") + print(f" {'Max Groups':<40} {cfg.max_groups:<15} GRPO_MAX_GROUPS") + print(f" {'Rollouts per Problem':<40} {cfg.rollouts_per_problem:<15} ROLLOUTS_PER_PROBLEM") + print("─" * 80) # Get dataset adapter adapter = get_dataset_adapter(args.dataset) + print("\n📚 Dataset Configuration:") print(f" Dataset: {adapter.name}") print(f" Correctness weight: {adapter.correctness_weight}") print(f" Success threshold: {adapter.success_threshold}") @@ -869,56 +1083,95 @@ def main() -> None: wandb.login(key=wandb_api_key) print(f" ✓ WandB logged in (project: {os.getenv('WANDB_PROJECT', 'grail')})") - # Calculate max_prompt_length + # Calculate max_prompt_length (GRAIL_TRAINER_MAX_LENGTH - GRPO_MAX_COMPLETION_TOKENS) max_prompt_length = cfg.max_length - cfg.max_new_tokens + # Calculate training schedule + # Each optimizer step = generation_batch_size = effective_batch = 512 samples + # = 32 groups × 16 rollouts + effective_batch = cfg.batch_size * cfg.grad_accum_steps # 4 × 128 = 512 + groups_per_step = effective_batch // cfg.rollouts_per_problem # 512 / 16 = 32 + total_optimizer_steps = 320 # Fixed: maintains original training duration + + print("\n📊 Training Schedule:") + print(f" • Effective batch size: {effective_batch} samples") + print(f" • Groups per optimizer step: {groups_per_step}") + print(f" • Rollouts per group: {cfg.rollouts_per_problem}") + print(f" • Total optimizer steps: {total_optimizer_steps}") + grpo_config = GRPOConfig( - output_dir=f"./outputs/trl_{adapter.name}", - learning_rate=cfg.lr, + output_dir=f"./outputs/trl_{adapter.name}_final", + # ───────────────────────────────────────────────────────────────────── + # Learning Rate & Schedule (matching GRAIL trainer config) + # ───────────────────────────────────────────────────────────────────── + 
learning_rate=cfg.lr, # GRAIL_TRAINER_LR + warmup_steps=cfg.warmup_steps, # GRAIL_TRAINER_WARMUP_STEPS + lr_scheduler_type="cosine", # Cosine annealing (matches grail/neurons/trainer.py) + # Use max_steps to control iterations (matching GRAIL_TRAINER_TOTAL_WINDOWS) + # num_train_epochs is ignored when max_steps is set num_train_epochs=cfg.epochs, - per_device_train_batch_size=cfg.batch_size, - gradient_accumulation_steps=cfg.grad_accum_steps, - max_grad_norm=cfg.grad_clip, - warmup_steps=cfg.warmup_steps, - beta=cfg.kl_coef, - epsilon=cfg.ppo_clip_eps, - epsilon_high=cfg.ppo_clip_eps_upper, - max_prompt_length=max_prompt_length, - max_completion_length=cfg.max_new_tokens, + max_steps=total_optimizer_steps, # Calculated from total_windows + # ───────────────────────────────────────────────────────────────────── + # Batch Size & Gradient Accumulation (matching GRAIL trainer config) + # ───────────────────────────────────────────────────────────────────── + per_device_train_batch_size=cfg.batch_size, # GRAIL_TRAINER_BATCH_SIZE + gradient_accumulation_steps=cfg.grad_accum_steps, # GRAIL_TRAINER_GRAD_ACCUM_STEPS + max_grad_norm=cfg.grad_clip, # GRAIL_TRAINER_GRAD_CLIP + # ───────────────────────────────────────────────────────────────────── + # GRPO Loss Configuration (matching grail/trainer/algorithms/grpo.py) + # ───────────────────────────────────────────────────────────────────── + beta=cfg.kl_coef, # GRAIL_TRAINER_KL_COEF (KL divergence coefficient) + epsilon=cfg.ppo_clip_eps, # TRAINER_PPO_CLIP_EPS (lower clip bound) + epsilon_high=cfg.ppo_clip_eps_upper, # TRAINER_PPO_CLIP_EPS_UPPER (DAPO asymmetric) + loss_type=cfg.grpo_variant, # GRAIL_GRPO_VARIANT ("dapo") + # ───────────────────────────────────────────────────────────────────── + # Sequence Length (matching GRAIL trainer config) + # ───────────────────────────────────────────────────────────────────── + max_prompt_length=max_prompt_length, # max_length - max_completion_tokens + 
max_completion_length=cfg.max_new_tokens, # GRPO_MAX_COMPLETION_TOKENS + # ───────────────────────────────────────────────────────────────────── + # Generation Parameters + # ───────────────────────────────────────────────────────────────────── temperature=cfg.temperature, top_p=cfg.top_p, top_k=cfg.top_k, repetition_penalty=1.1, - num_generations=cfg.rollouts_per_problem, - generation_batch_size=16, - steps_per_generation=None, + num_generations=cfg.rollouts_per_problem, # ROLLOUTS_PER_PROBLEM + # generation_batch_size must equal effective_batch to ensure: + # - One generation per optimizer step (no stale advantages) + # - 32 groups × 16 rollouts = 512 samples per optimizer update + generation_batch_size=cfg.batch_size * cfg.grad_accum_steps, # 4 × 128 = 512 + # ───────────────────────────────────────────────────────────────────── + # Logging & Checkpointing + # ───────────────────────────────────────────────────────────────────── logging_steps=1, log_completions=True, num_completions_to_print=1, wandb_log_unique_prompts=True, - save_strategy="no", + save_strategy="steps", + save_steps=40, bf16=True, report_to=["wandb"], eval_strategy="no", - run_name=f"trl_{adapter.name}_grpo_qwen15b_env_matched", - loss_type="dapo", + run_name=f"trl_{adapter.name}_grpo_qwen15b_grail_matched_final", + # ───────────────────────────────────────────────────────────────────── + # vLLM Configuration + # ───────────────────────────────────────────────────────────────────── use_vllm=True, vllm_mode="server", vllm_server_base_url="http://127.0.0.1:8000", + # Importance sampling configuration (matching GRAIL_IMPORTANCE_SAMPLING_LEVEL=token) vllm_importance_sampling_correction=False, - vllm_importance_sampling_cap=cfg.is_ratio_max, + vllm_importance_sampling_cap=cfg.is_ratio_max, # GRAIL_TRAINER_IS_RATIO_MAX ) - # Create reward function using adapter - def reward_fn(completions: list[str], prompts: list[str], **kwargs: Any) -> list[float]: - if "gold_answer" in kwargs and 
kwargs["gold_answer"]: - golds = kwargs["gold_answer"] - return [adapter.compute_reward(c, g) for c, g in zip(completions, golds, strict=False)] - if "metadatas" in kwargs and kwargs["metadatas"]: - golds = [m.get("gold_answer", "") for m in kwargs["metadatas"]] - return [adapter.compute_reward(c, g) for c, g in zip(completions, golds, strict=False)] - golds = [prompt_to_answer.get(p, "") for p in prompts] - return [adapter.compute_reward(c, g) for c, g in zip(completions, golds, strict=False)] + # Create reward tracker with pass@k logging + reward_tracker = TrainingPassAtKTracker( + adapter=adapter, + prompt_to_answer=prompt_to_answer, + report_ks=cfg.report_ks, + ) + print(f" ✓ TrainingPassAtKTracker initialized (report_ks={cfg.report_ks})") print(f"\n🏋️ Training with GRPO on {adapter.name.upper()}...") @@ -933,13 +1186,24 @@ def reward_fn(completions: list[str], prompts: list[str], **kwargs: Any) -> list trainer = GRPOTrainer( model=model, - reward_funcs=reward_fn, + reward_funcs=reward_tracker, args=grpo_config, train_dataset=train_ds, processing_class=tokenizer, callbacks=[vllm_eval_callback], ) + # Initialize WandB explicitly before baseline eval (GRPOTrainer does it lazily in .train()) + import wandb + + if wandb.run is None and grpo_config.report_to and "wandb" in grpo_config.report_to: + wandb.init( + project=os.getenv("WANDB_PROJECT", "grail"), + name=grpo_config.run_name, + config=grpo_config.to_dict(), + ) + print(" ✓ WandB initialized explicitly for baseline eval") + # Baseline evaluation vllm_eval_callback.run_and_log(step=0, label="BASELINE EVAL") From db18b1ff1890b70771fbfa43d764b53627fde0d7 Mon Sep 17 00:00:00 2001 From: ErfanMhi Date: Mon, 8 Dec 2025 23:04:20 +0000 Subject: [PATCH 15/17] feat(evaluation): introduce shared utilities for model evaluation tasks, including answer extraction and string normalization functions. 
--- research/eval/tasks/_common.py | 523 ++++++++++++++++++ research/eval/tasks/aime24_grail/utils.py | 180 ------ .../aime24_thinking.yaml} | 2 +- research/eval/tasks/aime24_thinking/utils.py | 62 +++ research/eval/tasks/amc2023/utils.py | 164 ++---- research/eval/tasks/amc2023_grail/utils.py | 167 ------ .../amc2023_thinking.yaml} | 2 +- research/eval/tasks/amc2023_thinking/utils.py | 62 +++ research/eval/tasks/gsm8k_grail/utils.py | 140 ----- .../gsm8k_thinking.yaml} | 2 +- research/eval/tasks/gsm8k_thinking/utils.py | 95 ++++ .../hendrycks_math_grail.yaml | 15 - .../eval/tasks/hendrycks_math_grail/utils.py | 251 --------- .../hendrycks_math_grail_pass_at_k/utils.py | 249 --------- .../_default_template.yaml | 0 .../hendrycks_math_thinking.yaml | 15 + .../hendrycks_math_thinking_algebra.yaml} | 2 +- ...ycks_math_thinking_counting_and_prob.yaml} | 2 +- .../hendrycks_math_thinking_geometry.yaml} | 2 +- ...s_math_thinking_intermediate_algebra.yaml} | 2 +- .../hendrycks_math_thinking_num_theory.yaml} | 2 +- .../hendrycks_math_thinking_prealgebra.yaml} | 2 +- .../hendrycks_math_thinking_precalc.yaml} | 2 +- .../tasks/hendrycks_math_thinking/utils.py | 54 ++ .../hendrycks_math_thinking_pass_at_5.yaml} | 4 +- .../utils.py | 104 ++++ 26 files changed, 959 insertions(+), 1146 deletions(-) create mode 100644 research/eval/tasks/_common.py delete mode 100644 research/eval/tasks/aime24_grail/utils.py rename research/eval/tasks/{aime24_grail/aime24_grail.yaml => aime24_thinking/aime24_thinking.yaml} (96%) create mode 100644 research/eval/tasks/aime24_thinking/utils.py delete mode 100644 research/eval/tasks/amc2023_grail/utils.py rename research/eval/tasks/{amc2023_grail/amc2023_grail.yaml => amc2023_thinking/amc2023_thinking.yaml} (96%) create mode 100644 research/eval/tasks/amc2023_thinking/utils.py delete mode 100644 research/eval/tasks/gsm8k_grail/utils.py rename research/eval/tasks/{gsm8k_grail/gsm8k_grail.yaml => gsm8k_thinking/gsm8k_thinking.yaml} (96%) create mode 
100644 research/eval/tasks/gsm8k_thinking/utils.py delete mode 100644 research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail.yaml delete mode 100644 research/eval/tasks/hendrycks_math_grail/utils.py delete mode 100644 research/eval/tasks/hendrycks_math_grail_pass_at_k/utils.py rename research/eval/tasks/{hendrycks_math_grail => hendrycks_math_thinking}/_default_template.yaml (100%) create mode 100644 research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking.yaml rename research/eval/tasks/{hendrycks_math_grail/hendrycks_math_grail_algebra.yaml => hendrycks_math_thinking/hendrycks_math_thinking_algebra.yaml} (68%) rename research/eval/tasks/{hendrycks_math_grail/hendrycks_math_grail_counting_and_prob.yaml => hendrycks_math_thinking/hendrycks_math_thinking_counting_and_prob.yaml} (67%) rename research/eval/tasks/{hendrycks_math_grail/hendrycks_math_grail_geometry.yaml => hendrycks_math_thinking/hendrycks_math_thinking_geometry.yaml} (68%) rename research/eval/tasks/{hendrycks_math_grail/hendrycks_math_grail_intermediate_algebra.yaml => hendrycks_math_thinking/hendrycks_math_thinking_intermediate_algebra.yaml} (65%) rename research/eval/tasks/{hendrycks_math_grail/hendrycks_math_grail_num_theory.yaml => hendrycks_math_thinking/hendrycks_math_thinking_num_theory.yaml} (68%) rename research/eval/tasks/{hendrycks_math_grail/hendrycks_math_grail_prealgebra.yaml => hendrycks_math_thinking/hendrycks_math_thinking_prealgebra.yaml} (67%) rename research/eval/tasks/{hendrycks_math_grail/hendrycks_math_grail_precalc.yaml => hendrycks_math_thinking/hendrycks_math_thinking_precalc.yaml} (69%) create mode 100644 research/eval/tasks/hendrycks_math_thinking/utils.py rename research/eval/tasks/{hendrycks_math_grail_pass_at_k/hendrycks_math_pass_at_5.yaml => hendrycks_math_thinking_pass_at_k/hendrycks_math_thinking_pass_at_5.yaml} (90%) create mode 100644 research/eval/tasks/hendrycks_math_thinking_pass_at_k/utils.py diff --git a/research/eval/tasks/_common.py 
b/research/eval/tasks/_common.py new file mode 100644 index 00000000..bb3cf242 --- /dev/null +++ b/research/eval/tasks/_common.py @@ -0,0 +1,523 @@ +"""Shared utilities for thinking model evaluation tasks. + +This module contains common functions for answer extraction and comparison +used across multiple evaluation tasks (AIME, AMC, GSM8K, MATH, etc.). + +Following DRY principles - extract once, reuse everywhere. +""" + +import re +from collections.abc import Callable + +# ============================================================================= +# Answer Extraction Functions +# ============================================================================= + + +def extract_solution_tag(text: str) -> str | None: + """Extract content from ... tags. + + Args: + text: Model output text + + Returns: + Content inside SOLUTION tags, or None if not found + """ + match = re.search(r"(.*?)", text, re.DOTALL) + if match: + return match.group(1).strip() + return None + + +def extract_dollar_sign_answer(text: str) -> str | None: + """Extract answer from $...$ format (last pair). + + Args: + text: Model output text + + Returns: + Content between last pair of dollar signs, or None if not found + """ + indices = [pos for pos, char in enumerate(text) if char == "$"] + if len(indices) >= 2: + return text[indices[-2] + 1 : indices[-1]] + return None + + +def remove_boxed(s: str) -> str | None: + """Remove \\boxed{} wrapper from string. 
+ + Args: + s: String potentially wrapped in \\boxed{} + + Returns: + Unwrapped content, or original string if no valid wrapper found + """ + if s is None: + return None + + # Handle "\\boxed " format (space after boxed) + if "\\boxed " in s: + left = "\\boxed " + if s[: len(left)] == left: + return s[len(left) :] + + # Handle "\\boxed{...}" format + left = "\\boxed{" + if s[: len(left)] == left and s.endswith("}"): + return s[len(left) : -1] + + return s + + +def last_boxed_only_string(string: str) -> str | None: + """Extract the last \\boxed{} or \\fbox{} content from a string. + + Handles nested braces correctly. + + Args: + string: Text containing potential boxed content + + Returns: + The last boxed expression (including \\boxed{} wrapper), or None + """ + if not string: + return None + + idx = string.rfind("\\boxed") + + # Handle "\\boxed " format + if "\\boxed " in string: + return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0] + + if idx < 0: + idx = string.rfind("\\fbox") + if idx < 0: + return None + + # Find matching closing brace + i = idx + right_brace_idx = None + num_left_braces_open = 0 + + while i < len(string): + if string[i] == "{": + num_left_braces_open += 1 + if string[i] == "}": + num_left_braces_open -= 1 + if num_left_braces_open == 0: + right_brace_idx = i + break + i += 1 + + if right_brace_idx is None: + return None + + return string[idx : right_brace_idx + 1] + + +def extract_answer_cascade( + text: str, + try_solution_tag: bool = True, + try_boxed: bool = True, + try_dollar: bool = True, +) -> str: + """Extract answer using cascade of methods. + + Tries extraction methods in order until one succeeds: + 1. ... tags (optional) + 2. \\boxed{...} (optional) + 3. $...$ format (optional) + 4. 
Original text (fallback) + + Args: + text: Model output text + try_solution_tag: Whether to try SOLUTION tag extraction + try_boxed: Whether to try boxed extraction + try_dollar: Whether to try dollar sign extraction + + Returns: + Extracted answer string + """ + if not text: + return "" + + # Try SOLUTION tags (reasoning models) + if try_solution_tag: + result = extract_solution_tag(text) + if result: + return result + + # Try boxed format + if try_boxed: + boxed = last_boxed_only_string(text) + if boxed: + unboxed = remove_boxed(boxed) + if unboxed: + return unboxed + + # Try dollar sign format + if try_dollar: + result = extract_dollar_sign_answer(text) + if result: + return result + + return text.strip() + + +# ============================================================================= +# String Normalization Functions +# ============================================================================= + + +def strip_string_basic(string: str) -> str: + """Basic string normalization for comparison. 
+ + Removes common formatting that doesn't affect mathematical meaning: + - Linebreaks, spaces + - LaTeX commands: \\!, \\left, \\right + - Dollar signs + + Args: + string: String to normalize + + Returns: + Normalized string + """ + if string is None: + return "" + + string = string.replace("\n", "") + string = string.replace("\\!", "") + string = string.replace("\\\\", "\\") + string = string.replace("\\left", "") + string = string.replace("\\right", "") + string = string.replace("$", "") + string = string.replace("\\$", "") + string = string.replace(" ", "") + + return string + + +def fix_fracs(string: str) -> str: + """Fix fraction formatting (\\frac12 -> \\frac{1}{2}).""" + substrs = string.split("\\frac") + new_str = substrs[0] + + if len(substrs) > 1: + for substr in substrs[1:]: + new_str += "\\frac" + if not substr or substr[0] == "{": + new_str += substr + else: + if len(substr) < 2: + return string + a = substr[0] + b = substr[1] + if b != "{": + post_substr = substr[2:] if len(substr) > 2 else "" + new_str += "{" + a + "}{" + b + "}" + post_substr + else: + post_substr = substr[2:] if len(substr) > 2 else "" + new_str += "{" + a + "}" + b + post_substr + + return new_str + + +def fix_sqrt(string: str) -> str: + """Fix sqrt formatting (\\sqrt2 -> \\sqrt{2}).""" + if "\\sqrt" not in string: + return string + + splits = string.split("\\sqrt") + new_string = splits[0] + + for split in splits[1:]: + if split and split[0] != "{": + new_substr = "\\sqrt{" + split[0] + "}" + split[1:] + else: + new_substr = "\\sqrt" + split + new_string += new_substr + + return new_string + + +def fix_a_slash_b(string: str) -> str: + """Convert simple fractions a/b to \\frac{a}{b}.""" + if len(string.split("/")) != 2: + return string + + a_str, b_str = string.split("/") + try: + a = int(a_str) + b = int(b_str) + if string == f"{a}/{b}": + return "\\frac{" + str(a) + "}{" + str(b) + "}" + except ValueError: + pass + + return string + + +def remove_right_units(string: str) -> 
str: + """Remove units on the right side (e.g., '5 \\text{ meters}').""" + if "\\text{ " in string: + splits = string.split("\\text{ ") + if len(splits) == 2: + return splits[0] + return string + + +def strip_string_math(string: str) -> str: + """Full math string normalization for MATH benchmark. + + Includes all basic normalization plus: + - tfrac/dfrac -> frac + - Degrees removal + - Units removal + - Fraction normalization + - Leading decimal fixes + + Args: + string: String to normalize + + Returns: + Normalized string + """ + if string is None: + return "" + + # Basic cleanup + string = string.replace("\n", "") + string = string.replace("\\!", "") + string = string.replace("\\\\", "\\") + string = string.replace("tfrac", "frac") + string = string.replace("dfrac", "frac") + string = string.replace("\\left", "") + string = string.replace("\\right", "") + + # Remove degrees + string = string.replace("^{\\circ}", "") + string = string.replace("^\\circ", "") + + # Remove dollar signs + string = string.replace("\\$", "") + + # Remove units + string = remove_right_units(string) + + # Remove percentage + string = string.replace("\\%", "") + string = string.replace("\%", "") # noqa: W605 + + # Fix leading decimals + string = string.replace(" .", " 0.") + string = string.replace("{.", "{0.") + + if len(string) == 0: + return string + if string[0] == ".": + string = "0" + string + + # Handle "k = " or "q = " at beginning + if len(string.split("=")) == 2: + if len(string.split("=")[0]) <= 2: + string = string.split("=")[1] + + # Fix sqrt formatting + string = fix_sqrt(string) + + # Remove spaces + string = string.replace(" ", "") + + # Fix fractions + string = fix_fracs(string) + + # Special case: 0.5 -> \frac{1}{2} + if string == "0.5": + string = "\\frac{1}{2}" + + # Convert a/b to \frac{a}{b} + string = fix_a_slash_b(string) + + return string + + +# ============================================================================= +# Number Extraction Functions +# 
============================================================================= + + +def extract_integer(s: str) -> int | None: + """Extract integer from string, handling common formats. + + Args: + s: String potentially containing an integer + + Returns: + Extracted integer, or None if not found + """ + if s is None: + return None + + s = s.strip() + + # Remove common wrappers + s = re.sub(r"\\boxed\{([^}]*)\}", r"\1", s) + s = re.sub(r"\$([^$]*)\$", r"\1", s) + s = s.strip() + + # Try direct parse + try: + return int(s) + except ValueError: + pass + + # Find integers in the string (return last one) + matches = re.findall(r"-?\d+", s) + if matches: + return int(matches[-1]) + + return None + + +def extract_float(s: str) -> float | None: + """Extract float from string, handling common formats. + + Args: + s: String potentially containing a number + + Returns: + Extracted float, or None if not found + """ + if s is None: + return None + + s = s.strip() + + # Remove common wrappers + s = re.sub(r"\\boxed\{([^}]*)\}", r"\1", s) + s = re.sub(r"\$([^$]*)\$", r"\1", s) + s = s.strip() + + # Try direct parse + try: + return float(s) + except ValueError: + pass + + # Find numbers in the string (return last one) + matches = re.findall(r"-?\d+\.?\d*", s) + if matches: + return float(matches[-1]) + + return None + + +# ============================================================================= +# Equivalence Checking Functions +# ============================================================================= + + +def is_equiv_string( + str1: str, + str2: str, + normalizer: Callable[[str], str] = strip_string_basic, +) -> bool: + """Check if two strings are equivalent after normalization. 
+ + Args: + str1: First string + str2: Second string + normalizer: Function to normalize strings before comparison + + Returns: + True if equivalent, False otherwise + """ + if str1 is None and str2 is None: + return True + if str1 is None or str2 is None: + return False + + try: + ss1 = normalizer(str1) + ss2 = normalizer(str2) + return ss1 == ss2 + except Exception: + return str1 == str2 + + +def is_equiv_numeric( + str1: str, + str2: str, + tolerance: float = 0.01, + integer_only: bool = False, +) -> bool: + """Check if two strings represent equivalent numbers. + + Args: + str1: First string + str2: Second string + tolerance: Absolute tolerance for float comparison + integer_only: If True, only compare as integers + + Returns: + True if equivalent, False otherwise + """ + if str1 is None or str2 is None: + return str1 is None and str2 is None + + if integer_only: + int1 = extract_integer(str1) + int2 = extract_integer(str2) + if int1 is not None and int2 is not None: + return int1 == int2 + else: + num1 = extract_float(str1) + num2 = extract_float(str2) + if num1 is not None and num2 is not None: + return abs(num1 - num2) < tolerance + + return False + + +def is_equiv_combined( + str1: str, + str2: str, + normalizer: Callable[[str], str] = strip_string_basic, + try_numeric: bool = True, + tolerance: float = 0.01, + integer_only: bool = False, +) -> bool: + """Check equivalence using both string and numeric comparison. 
+ + Args: + str1: First string + str2: Second string + normalizer: Function to normalize strings + try_numeric: Whether to try numeric comparison + tolerance: Tolerance for float comparison + integer_only: If True, only do integer comparison + + Returns: + True if equivalent by any method + """ + if str1 is None and str2 is None: + return True + if str1 is None or str2 is None: + return False + + try: + # Try string comparison first + ss1 = normalizer(str1) + ss2 = normalizer(str2) + if ss1 == ss2: + return True + + # Try numeric comparison + if try_numeric: + return is_equiv_numeric(ss1, ss2, tolerance, integer_only) + + return False + except Exception: + return str1 == str2 diff --git a/research/eval/tasks/aime24_grail/utils.py b/research/eval/tasks/aime24_grail/utils.py deleted file mode 100644 index 1f655511..00000000 --- a/research/eval/tasks/aime24_grail/utils.py +++ /dev/null @@ -1,180 +0,0 @@ -"""AIME 2024 evaluation utilities for GRAIL reasoning models. - -Extracts answers from ... tags and uses robust -integer comparison for AIME answers (which are always 0-999). -""" - -import re - - -def process_results(doc: dict, results: list[str]) -> dict[str, int]: - """Process model output and compare with target answer. - - AIME answers are always integers from 000-999. - """ - retval = 0 - response = results[0] - - # Extract answer from ... 
tags first (for reasoning models) - solution_match = re.search(r"(.*?)", response, re.DOTALL) - if solution_match: - answer = solution_match.group(1).strip() - else: - # Fallback: try to extract from $...$ format - indices = [pos for pos, char in enumerate(response) if char == "$"] - if len(indices) >= 2: - answer = response[indices[0] + 1 : indices[-1]] - else: - # Fallback: try to extract from \boxed{} - boxed_answer = last_boxed_only_string(response) - if boxed_answer is not None: - try: - answer = remove_boxed(boxed_answer) - except (AssertionError, IndexError): - answer = response - else: - answer = response - - # Get target answer - answer_key = next((k for k in doc.keys() if k.lower() == "answer"), None) - if answer_key is None: - return {"exact_match": 0} - - target = str(doc[answer_key]) - - # AIME answers are integers 0-999, so try integer comparison - if is_equiv(answer, target): - retval = 1 - - return {"exact_match": retval} - - -def is_equiv(str1: str, str2: str, verbose: bool = False) -> bool: - """Check if two answers are equivalent. - - For AIME, answers are integers 0-999. We try to extract and compare integers. 
- """ - if str1 is None and str2 is None: - return True - if str1 is None or str2 is None: - return False - - try: - # Clean and normalize strings - ss1 = strip_string(str1) - ss2 = strip_string(str2) - - if verbose: - print(f"Comparing: '{ss1}' vs '{ss2}'") - - # Direct string comparison - if ss1 == ss2: - return True - - # Try integer comparison (AIME answers are always integers) - try: - int1 = extract_integer(ss1) - int2 = extract_integer(ss2) - if int1 is not None and int2 is not None: - return int1 == int2 - except (ValueError, TypeError): - pass - - return False - except Exception: - return str1 == str2 - - -def extract_integer(s: str) -> int: - """Extract integer from string, handling common formats.""" - if s is None: - return None - - s = s.strip() - - # Remove common wrappers - s = re.sub(r"\\boxed\{([^}]*)\}", r"\1", s) - s = re.sub(r"\$([^$]*)\$", r"\1", s) - s = s.strip() - - # Try direct integer parse - try: - return int(s) - except ValueError: - pass - - # Try to find integers in the string - matches = re.findall(r"-?\d+", s) - if matches: - # Return the last integer found (usually the final answer) - return int(matches[-1]) - - return None - - -def remove_boxed(s: str) -> str: - """Remove \\boxed{} wrapper from string.""" - if "\\boxed " in s: - left = "\\boxed " - assert s[: len(left)] == left - return s[len(left) :] - - left = "\\boxed{" - assert s[: len(left)] == left - assert s[-1] == "}" - return s[len(left) : -1] - - -def last_boxed_only_string(string: str) -> str: - """Extract the last \\boxed{} content from string.""" - idx = string.rfind("\\boxed") - if "\\boxed " in string: - return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0] - if idx < 0: - idx = string.rfind("\\fbox") - if idx < 0: - return None - - i = idx - right_brace_idx = None - num_left_braces_open = 0 - while i < len(string): - if string[i] == "{": - num_left_braces_open += 1 - if string[i] == "}": - num_left_braces_open -= 1 - if num_left_braces_open == 0: - 
right_brace_idx = i - break - i += 1 - - if right_brace_idx is None: - return None - return string[idx : right_brace_idx + 1] - - -def strip_string(string: str) -> str: - """Normalize string for comparison.""" - if string is None: - return "" - - # Remove linebreaks - string = string.replace("\n", "") - - # Remove common LaTeX - string = string.replace("\\!", "") - string = string.replace("\\\\", "\\") - string = string.replace("\\left", "") - string = string.replace("\\right", "") - - # Remove dollar signs - string = string.replace("$", "") - string = string.replace("\\$", "") - - # Remove spaces - string = string.replace(" ", "") - - # Remove leading zeros for integer comparison - string = string.lstrip("0") or "0" - - return string diff --git a/research/eval/tasks/aime24_grail/aime24_grail.yaml b/research/eval/tasks/aime24_thinking/aime24_thinking.yaml similarity index 96% rename from research/eval/tasks/aime24_grail/aime24_grail.yaml rename to research/eval/tasks/aime24_thinking/aime24_thinking.yaml index 48af0027..3d42239a 100644 --- a/research/eval/tasks/aime24_grail/aime24_grail.yaml +++ b/research/eval/tasks/aime24_thinking/aime24_thinking.yaml @@ -1,6 +1,6 @@ tag: - math_word_problems -task: aime24_grail +task: aime24_thinking dataset_path: Maxwell-Jia/AIME_2024 output_type: generate_until training_split: train diff --git a/research/eval/tasks/aime24_thinking/utils.py b/research/eval/tasks/aime24_thinking/utils.py new file mode 100644 index 00000000..1425d0a9 --- /dev/null +++ b/research/eval/tasks/aime24_thinking/utils.py @@ -0,0 +1,62 @@ +"""AIME 2024 evaluation utilities for thinking models. + +Extracts answers from ... tags and uses robust +integer comparison for AIME answers (which are always 0-999). 
+""" + +import sys +from pathlib import Path + +# Add parent directory to path for _common import +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from _common import ( + extract_answer_cascade, + is_equiv_combined, + strip_string_basic, +) + + +def process_results(doc: dict, results: list[str]) -> dict[str, int]: + """Process model output and compare with target answer. + + AIME answers are always integers from 000-999. + """ + response = results[0] + + # Extract answer using cascade (SOLUTION tag -> boxed -> dollar -> raw) + answer = extract_answer_cascade( + response, + try_solution_tag=True, + try_boxed=True, + try_dollar=True, + ) + + # Get target answer + answer_key = next((k for k in doc.keys() if k.lower() == "answer"), None) + if answer_key is None: + return {"exact_match": 0} + + target = str(doc[answer_key]) + + # AIME answers are integers 0-999, use integer comparison + is_correct = is_equiv_combined( + answer, + target, + normalizer=_strip_string_aime, + try_numeric=True, + integer_only=True, + ) + + return {"exact_match": 1 if is_correct else 0} + + +def _strip_string_aime(string: str) -> str: + """Normalize string for AIME comparison. + + Extends basic normalization with leading zero removal. + """ + string = strip_string_basic(string) + # Remove leading zeros for integer comparison (but keep "0") + string = string.lstrip("0") or "0" + return string diff --git a/research/eval/tasks/amc2023/utils.py b/research/eval/tasks/amc2023/utils.py index 8fc76418..23986532 100644 --- a/research/eval/tasks/amc2023/utils.py +++ b/research/eval/tasks/amc2023/utils.py @@ -4,152 +4,52 @@ Extracts answers from $...$ format, \\boxed{}, or plain numbers. 
""" -import re +import sys +from pathlib import Path + +# Add parent directory to path for _common import +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from _common import ( + extract_answer_cascade, + is_equiv_combined, + strip_string_basic, +) def process_results(doc: dict, results: list[str]) -> dict[str, int]: """Process model output and compare with target answer.""" - retval = 0 response = results[0] - # Try to extract answer from $...$ format first - indices = [pos for pos, char in enumerate(response) if char == "$"] - if len(indices) >= 2: - answer = response[indices[0] + 1 : indices[-1]] - else: - # Try to extract from \boxed{} - boxed_answer = last_boxed_only_string(response) - if boxed_answer is not None: - try: - answer = remove_boxed(boxed_answer) - except (AssertionError, IndexError): - answer = response - else: - answer = response + # Extract answer (no SOLUTION tags for non-thinking model) + answer = extract_answer_cascade( + response, + try_solution_tag=False, + try_boxed=True, + try_dollar=True, + ) # Get target answer target = str(doc.get("answer", "")) - # Compare answers - if is_equiv(answer, target): - retval = 1 - - return {"exact_match": retval} - - -def is_equiv(str1: str, str2: str, verbose: bool = False) -> bool: - """Check if two answers are equivalent.""" - if str1 is None and str2 is None: - return True - if str1 is None or str2 is None: - return False - - try: - ss1 = strip_string(str1) - ss2 = strip_string(str2) - - if verbose: - print(f"Comparing: '{ss1}' vs '{ss2}'") - - # Direct string comparison - if ss1 == ss2: - return True - - # Try numeric comparison - try: - num1 = extract_number(ss1) - num2 = extract_number(ss2) - if num1 is not None and num2 is not None: - # Compare as floats with tolerance - return abs(num1 - num2) < 0.01 - except (ValueError, TypeError): - pass - - return False - except Exception: - return str1 == str2 - - -def extract_number(s: str) -> float: - """Extract number from string.""" - if s is 
None: - return None - - s = s.strip() - - # Remove common wrappers - s = re.sub(r"\\boxed\{([^}]*)\}", r"\1", s) - s = re.sub(r"\$([^$]*)\$", r"\1", s) - s = s.strip() - - # Try direct parse - try: - return float(s) - except ValueError: - pass - - # Try to find numbers in the string - matches = re.findall(r"-?\d+\.?\d*", s) - if matches: - return float(matches[-1]) - - return None - - -def remove_boxed(s: str) -> str: - """Remove \\boxed{} wrapper from string.""" - if "\\boxed " in s: - left = "\\boxed " - assert s[: len(left)] == left - return s[len(left) :] - - left = "\\boxed{" - assert s[: len(left)] == left - assert s[-1] == "}" - return s[len(left) : -1] - - -def last_boxed_only_string(string: str) -> str: - """Extract the last \\boxed{} content from string.""" - idx = string.rfind("\\boxed") - if "\\boxed " in string: - return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0] - if idx < 0: - idx = string.rfind("\\fbox") - if idx < 0: - return None - - i = idx - right_brace_idx = None - num_left_braces_open = 0 - while i < len(string): - if string[i] == "{": - num_left_braces_open += 1 - if string[i] == "}": - num_left_braces_open -= 1 - if num_left_braces_open == 0: - right_brace_idx = i - break - i += 1 + # Compare with numeric fallback + is_correct = is_equiv_combined( + answer, + target, + normalizer=_strip_string_amc, + try_numeric=True, + tolerance=0.01, + ) - if right_brace_idx is None: - return None - return string[idx : right_brace_idx + 1] + return {"exact_match": 1 if is_correct else 0} -def strip_string(string: str) -> str: - """Normalize string for comparison.""" - if string is None: - return "" +def _strip_string_amc(string: str) -> str: + """Normalize string for AMC comparison. 
- string = string.replace("\n", "") - string = string.replace("\\!", "") - string = string.replace("\\\\", "\\") - string = string.replace("\\left", "") - string = string.replace("\\right", "") - string = string.replace("$", "") - string = string.replace("\\$", "") - string = string.replace(" ", "") + Extends basic normalization with float->int conversion. + """ + string = strip_string_basic(string) # Handle float formatting (e.g., "27.0" -> "27") try: diff --git a/research/eval/tasks/amc2023_grail/utils.py b/research/eval/tasks/amc2023_grail/utils.py deleted file mode 100644 index 45612c18..00000000 --- a/research/eval/tasks/amc2023_grail/utils.py +++ /dev/null @@ -1,167 +0,0 @@ -"""AMC 2023 evaluation utilities for GRAIL reasoning models. - -Extracts answers from ... tags and uses robust -numeric comparison for AMC answers. -""" - -import re - - -def process_results(doc: dict, results: list[str]) -> dict[str, int]: - """Process model output and compare with target answer.""" - retval = 0 - response = results[0] - - # Extract answer from ... 
tags first (for reasoning models) - solution_match = re.search(r"(.*?)", response, re.DOTALL) - if solution_match: - answer = solution_match.group(1).strip() - else: - # Fallback: try to extract from $...$ format - indices = [pos for pos, char in enumerate(response) if char == "$"] - if len(indices) >= 2: - answer = response[indices[0] + 1 : indices[-1]] - else: - # Fallback: try to extract from \boxed{} - boxed_answer = last_boxed_only_string(response) - if boxed_answer is not None: - try: - answer = remove_boxed(boxed_answer) - except (AssertionError, IndexError): - answer = response - else: - answer = response - - # Get target answer - target = str(doc.get("answer", "")) - - # Compare answers - if is_equiv(answer, target): - retval = 1 - - return {"exact_match": retval} - - -def is_equiv(str1: str, str2: str, verbose: bool = False) -> bool: - """Check if two answers are equivalent.""" - if str1 is None and str2 is None: - return True - if str1 is None or str2 is None: - return False - - try: - ss1 = strip_string(str1) - ss2 = strip_string(str2) - - if verbose: - print(f"Comparing: '{ss1}' vs '{ss2}'") - - # Direct string comparison - if ss1 == ss2: - return True - - # Try numeric comparison - try: - num1 = extract_number(ss1) - num2 = extract_number(ss2) - if num1 is not None and num2 is not None: - # Compare as floats with tolerance - return abs(num1 - num2) < 0.01 - except (ValueError, TypeError): - pass - - return False - except Exception: - return str1 == str2 - - -def extract_number(s: str) -> float: - """Extract number from string.""" - if s is None: - return None - - s = s.strip() - - # Remove common wrappers - s = re.sub(r"\\boxed\{([^}]*)\}", r"\1", s) - s = re.sub(r"\$([^$]*)\$", r"\1", s) - s = s.strip() - - # Try direct parse - try: - return float(s) - except ValueError: - pass - - # Try to find numbers in the string - matches = re.findall(r"-?\d+\.?\d*", s) - if matches: - return float(matches[-1]) - - return None - - -def remove_boxed(s: str) -> 
str: - """Remove \\boxed{} wrapper from string.""" - if "\\boxed " in s: - left = "\\boxed " - assert s[: len(left)] == left - return s[len(left) :] - - left = "\\boxed{" - assert s[: len(left)] == left - assert s[-1] == "}" - return s[len(left) : -1] - - -def last_boxed_only_string(string: str) -> str: - """Extract the last \\boxed{} content from string.""" - idx = string.rfind("\\boxed") - if "\\boxed " in string: - return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0] - if idx < 0: - idx = string.rfind("\\fbox") - if idx < 0: - return None - - i = idx - right_brace_idx = None - num_left_braces_open = 0 - while i < len(string): - if string[i] == "{": - num_left_braces_open += 1 - if string[i] == "}": - num_left_braces_open -= 1 - if num_left_braces_open == 0: - right_brace_idx = i - break - i += 1 - - if right_brace_idx is None: - return None - return string[idx : right_brace_idx + 1] - - -def strip_string(string: str) -> str: - """Normalize string for comparison.""" - if string is None: - return "" - - string = string.replace("\n", "") - string = string.replace("\\!", "") - string = string.replace("\\\\", "\\") - string = string.replace("\\left", "") - string = string.replace("\\right", "") - string = string.replace("$", "") - string = string.replace("\\$", "") - string = string.replace(" ", "") - - # Handle float formatting (e.g., "27.0" -> "27") - try: - num = float(string) - if num == int(num): - string = str(int(num)) - except ValueError: - pass - - return string diff --git a/research/eval/tasks/amc2023_grail/amc2023_grail.yaml b/research/eval/tasks/amc2023_thinking/amc2023_thinking.yaml similarity index 96% rename from research/eval/tasks/amc2023_grail/amc2023_grail.yaml rename to research/eval/tasks/amc2023_thinking/amc2023_thinking.yaml index 995537b5..5e93452a 100644 --- a/research/eval/tasks/amc2023_grail/amc2023_grail.yaml +++ b/research/eval/tasks/amc2023_thinking/amc2023_thinking.yaml @@ -1,6 +1,6 @@ tag: - math_word_problems -task: 
amc2023_grail +task: amc2023_thinking dataset_path: sparkle-reasoning/amc2023 output_type: generate_until test_split: test diff --git a/research/eval/tasks/amc2023_thinking/utils.py b/research/eval/tasks/amc2023_thinking/utils.py new file mode 100644 index 00000000..e060fb0c --- /dev/null +++ b/research/eval/tasks/amc2023_thinking/utils.py @@ -0,0 +1,62 @@ +"""AMC 2023 evaluation utilities for thinking models. + +Extracts answers from ... tags and uses robust +numeric comparison for AMC answers. +""" + +import sys +from pathlib import Path + +# Add parent directory to path for _common import +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from _common import ( + extract_answer_cascade, + is_equiv_combined, + strip_string_basic, +) + + +def process_results(doc: dict, results: list[str]) -> dict[str, int]: + """Process model output and compare with target answer.""" + response = results[0] + + # Extract answer using cascade (SOLUTION tag first for thinking models) + answer = extract_answer_cascade( + response, + try_solution_tag=True, + try_boxed=True, + try_dollar=True, + ) + + # Get target answer + target = str(doc.get("answer", "")) + + # Compare with numeric fallback + is_correct = is_equiv_combined( + answer, + target, + normalizer=_strip_string_amc, + try_numeric=True, + tolerance=0.01, + ) + + return {"exact_match": 1 if is_correct else 0} + + +def _strip_string_amc(string: str) -> str: + """Normalize string for AMC comparison. + + Extends basic normalization with float->int conversion. 
+ """ + string = strip_string_basic(string) + + # Handle float formatting (e.g., "27.0" -> "27") + try: + num = float(string) + if num == int(num): + string = str(int(num)) + except ValueError: + pass + + return string diff --git a/research/eval/tasks/gsm8k_grail/utils.py b/research/eval/tasks/gsm8k_grail/utils.py deleted file mode 100644 index be72529c..00000000 --- a/research/eval/tasks/gsm8k_grail/utils.py +++ /dev/null @@ -1,140 +0,0 @@ -"""GSM8K evaluation utilities for GRAIL reasoning models. - -Extracts answers from ... tags and compares with -the ground truth answer (after ####). -""" - -import re - - -def doc_to_target(doc: dict) -> str: - """Convert document to target format for GRAIL reasoning.""" - answer = doc["answer"] - # Extract final answer after #### - if "####" in answer: - final_answer = answer.split("####")[-1].strip() - else: - final_answer = answer.strip() - - # Extract the reasoning part (before ####) - if "####" in answer: - reasoning = answer.split("####")[0].strip() - else: - reasoning = "" - - return ( - f"\n{reasoning}\n\n{final_answer}" - ) - - -def process_results(doc: dict, results: list[str]) -> dict[str, int]: - """Process model output and compare with target answer.""" - retval = 0 - response = results[0] - - # Extract answer from ... 
tags first (for reasoning models) - solution_match = re.search(r"(.*?)", response, re.DOTALL) - if solution_match: - answer = solution_match.group(1).strip() - else: - # Fallback: try to extract number from the end of response - answer = extract_last_number(response) - - # Get target answer from document - target_answer = doc["answer"] - if "####" in target_answer: - target = target_answer.split("####")[-1].strip() - else: - target = target_answer.strip() - - # Compare answers - if is_equiv(answer, target): - retval = 1 - - return {"exact_match": retval} - - -def is_equiv(str1: str, str2: str, verbose: bool = False) -> bool: - """Check if two answers are equivalent.""" - if str1 is None and str2 is None: - return True - if str1 is None or str2 is None: - return False - - try: - # Clean both strings - s1 = clean_answer(str1) - s2 = clean_answer(str2) - - if verbose: - print(f"Comparing: '{s1}' vs '{s2}'") - - # Direct string comparison - if s1 == s2: - return True - - # Try numeric comparison - try: - num1 = extract_number(s1) - num2 = extract_number(s2) - if num1 is not None and num2 is not None: - return abs(num1 - num2) < 0.001 - except (ValueError, TypeError): - pass - - return False - except Exception: - return str1 == str2 - - -def clean_answer(s: str) -> str: - """Clean answer string for comparison.""" - if s is None: - return "" - - s = s.strip() - # Remove dollar signs, commas, and common formatting - s = s.replace("$", "").replace(",", "").replace(" ", "") - # Remove trailing period - s = s.rstrip(".") - - return s - - -def extract_number(s: str) -> float: - """Extract number from string.""" - if s is None: - return None - - s = clean_answer(s) - - # Try direct parse - try: - return float(s) - except ValueError: - pass - - # Try to find numbers in the string - matches = re.findall(r"-?\d+\.?\d*", s) - if matches: - return float(matches[-1]) - - return None - - -def extract_last_number(s: str) -> str: - """Extract the last number from a string.""" - if s is 
None: - return "" - - # Look for #### pattern first (GSM8K format) - if "####" in s: - return s.split("####")[-1].strip() - - # Find all numbers - matches = re.findall(r"-?\d+(?:,\d{3})*(?:\.\d+)?", s) - if matches: - # Return last number, removing commas - return matches[-1].replace(",", "") - - return s.strip() diff --git a/research/eval/tasks/gsm8k_grail/gsm8k_grail.yaml b/research/eval/tasks/gsm8k_thinking/gsm8k_thinking.yaml similarity index 96% rename from research/eval/tasks/gsm8k_grail/gsm8k_grail.yaml rename to research/eval/tasks/gsm8k_thinking/gsm8k_thinking.yaml index c7fa8eba..dabcb685 100644 --- a/research/eval/tasks/gsm8k_grail/gsm8k_grail.yaml +++ b/research/eval/tasks/gsm8k_thinking/gsm8k_thinking.yaml @@ -1,6 +1,6 @@ tag: - math_word_problems -task: gsm8k_grail +task: gsm8k_thinking dataset_path: gsm8k dataset_name: main output_type: generate_until diff --git a/research/eval/tasks/gsm8k_thinking/utils.py b/research/eval/tasks/gsm8k_thinking/utils.py new file mode 100644 index 00000000..7ff1ce83 --- /dev/null +++ b/research/eval/tasks/gsm8k_thinking/utils.py @@ -0,0 +1,95 @@ +"""GSM8K evaluation utilities for thinking models. + +Extracts answers from ... tags and compares with +the ground truth answer (after ####). 
+""" + +import re +import sys +from pathlib import Path + +# Add parent directory to path for _common import +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from _common import ( + extract_solution_tag, + is_equiv_combined, +) + + +def doc_to_target(doc: dict) -> str: + """Convert document to target format for thinking.""" + answer = doc["answer"] + + # Extract final answer after #### + if "####" in answer: + final_answer = answer.split("####")[-1].strip() + reasoning = answer.split("####")[0].strip() + else: + final_answer = answer.strip() + reasoning = "" + + return ( + f"\n{reasoning}\n\n{final_answer}" + ) + + +def process_results(doc: dict, results: list[str]) -> dict[str, int]: + """Process model output and compare with target answer.""" + response = results[0] + + # Extract answer from tags first (reasoning models) + answer = extract_solution_tag(response) + if answer is None: + # Fallback: try to extract number from the end + answer = _extract_last_number(response) + + # Get target answer from document + target_answer = doc["answer"] + if "####" in target_answer: + target = target_answer.split("####")[-1].strip() + else: + target = target_answer.strip() + + # Compare answers + is_correct = is_equiv_combined( + answer, + target, + normalizer=_clean_answer, + try_numeric=True, + tolerance=0.001, + ) + + return {"exact_match": 1 if is_correct else 0} + + +def _clean_answer(s: str) -> str: + """Clean answer string for GSM8K comparison.""" + if s is None: + return "" + + s = s.strip() + # Remove dollar signs, commas, and common formatting + s = s.replace("$", "").replace(",", "").replace(" ", "") + # Remove trailing period + s = s.rstrip(".") + + return s + + +def _extract_last_number(s: str) -> str: + """Extract the last number from a string.""" + if s is None: + return "" + + # Look for #### pattern first (GSM8K format) + if "####" in s: + return s.split("####")[-1].strip() + + # Find all numbers + matches = re.findall(r"-?\d+(?:,\d{3})*(?:\.\d+)?", 
s) + if matches: + # Return last number, removing commas + return matches[-1].replace(",", "") + + return s.strip() diff --git a/research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail.yaml b/research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail.yaml deleted file mode 100644 index c23e1ba4..00000000 --- a/research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail.yaml +++ /dev/null @@ -1,15 +0,0 @@ -group: hendrycks_math_grail -task: - - hendrycks_math_grail_algebra - - hendrycks_math_grail_counting_and_prob - - hendrycks_math_grail_geometry - - hendrycks_math_grail_intermediate_algebra - - hendrycks_math_grail_num_theory - - hendrycks_math_grail_prealgebra - - hendrycks_math_grail_precalc -aggregate_metric_list: - - metric: exact_match - aggregation: mean - weight_by_size: true -metadata: - version: 1.0 diff --git a/research/eval/tasks/hendrycks_math_grail/utils.py b/research/eval/tasks/hendrycks_math_grail/utils.py deleted file mode 100644 index 0853668d..00000000 --- a/research/eval/tasks/hendrycks_math_grail/utils.py +++ /dev/null @@ -1,251 +0,0 @@ -"""Custom utils for GRAIL reasoning model evaluation on MATH. - -Extracts answers from ... tags instead of \\boxed{}. -""" - -import re - -import datasets - - -def process_docs(dataset: datasets.Dataset) -> datasets.Dataset: - """Process dataset docs - extract ground truth answer from \\boxed{}.""" - - def _process_doc(doc: dict) -> dict: - out_doc = { - "problem": doc["problem"], - "solution": doc["solution"], - "answer": remove_boxed(last_boxed_only_string(doc["solution"])), - } - return out_doc - - return dataset.map(_process_doc) - - -def extract_solution_tag(text: str) -> str: - """Extract content from ... 
tags.""" - match = re.search(r"([\s\S]*?)", text) - if match: - return match.group(1).strip() - # Fallback: return original text if no tags found - return text.strip() - - -def process_results(doc: dict, results: list[str]) -> dict[str, int]: - """Process results - extract answer from tags and compare.""" - retval = 0 - - # Extract from tags - model_answer = extract_solution_tag(results[0]) - - # Get ground truth (already extracted from \boxed{} in process_docs) - ground_truth = doc.get("answer", remove_boxed(last_boxed_only_string(doc["solution"]))) - - if is_equiv(model_answer, ground_truth): - retval = 1 - - return {"exact_match": retval} - - -# ============================================================================ -# String normalization functions (from lm-eval hendrycks_math/utils.py) -# ============================================================================ - - -def is_equiv(str1: str, str2: str, verbose: bool = False) -> bool: - """Check if two strings are equivalent after normalization.""" - if str1 is None and str2 is None: - return True - if str1 is None or str2 is None: - return False - - try: - ss1 = strip_string(str1) - ss2 = strip_string(str2) - if verbose: - print(ss1, ss2) - return ss1 == ss2 - except Exception: - return str1 == str2 - - -def remove_boxed(s: str) -> str: - """Remove \\boxed{} wrapper from string.""" - if s is None: - return None - if "\\boxed " in s: - left = "\\boxed " - if s[: len(left)] == left: - return s[len(left) :] - - left = "\\boxed{" - if s[: len(left)] == left and s[-1] == "}": - return s[len(left) : -1] - - return s - - -def last_boxed_only_string(string: str) -> str: - """Extract the last \\boxed{} or \\fbox{} from a string.""" - idx = string.rfind("\\boxed") - if "\\boxed " in string: - return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0] - if idx < 0: - idx = string.rfind("\\fbox") - if idx < 0: - return None - - i = idx - right_brace_idx = None - num_left_braces_open = 0 - while i < len(string): 
- if string[i] == "{": - num_left_braces_open += 1 - if string[i] == "}": - num_left_braces_open -= 1 - if num_left_braces_open == 0: - right_brace_idx = i - break - i += 1 - - if right_brace_idx is None: - return None - return string[idx : right_brace_idx + 1] - - -def fix_fracs(string: str) -> str: - """Fix fraction formatting.""" - substrs = string.split("\\frac") - new_str = substrs[0] - if len(substrs) > 1: - substrs = substrs[1:] - for substr in substrs: - new_str += "\\frac" - if substr[0] == "{": - new_str += substr - else: - try: - assert len(substr) >= 2 - except AssertionError: - return string - a = substr[0] - b = substr[1] - if b != "{": - if len(substr) > 2: - post_substr = substr[2:] - new_str += "{" + a + "}{" + b + "}" + post_substr - else: - new_str += "{" + a + "}{" + b + "}" - else: - if len(substr) > 2: - post_substr = substr[2:] - new_str += "{" + a + "}" + b + post_substr - else: - new_str += "{" + a + "}" + b - return new_str - - -def fix_a_slash_b(string: str) -> str: - """Convert a/b to \\frac{a}{b}.""" - if len(string.split("/")) != 2: - return string - a = string.split("/")[0] - b = string.split("/")[1] - try: - a = int(a) - b = int(b) - assert string == f"{a}/{b}" - return "\\frac{" + str(a) + "}{" + str(b) + "}" - except (AssertionError, ValueError): - return string - - -def remove_right_units(string: str) -> str: - """Remove units on the right side.""" - if "\\text{ " in string: - splits = string.split("\\text{ ") - if len(splits) == 2: - return splits[0] - return string - - -def fix_sqrt(string: str) -> str: - """Fix sqrt formatting.""" - if "\\sqrt" not in string: - return string - splits = string.split("\\sqrt") - new_string = splits[0] - for split in splits[1:]: - if split[0] != "{": - a = split[0] - new_substr = "\\sqrt{" + a + "}" + split[1:] - else: - new_substr = "\\sqrt" + split - new_string += new_substr - return new_string - - -def strip_string(string: str) -> str: - """Normalize string for comparison.""" - # linebreaks - 
string = string.replace("\n", "") - - # remove inverse spaces - string = string.replace("\\!", "") - - # replace \\ with \ - string = string.replace("\\\\", "\\") - - # replace tfrac and dfrac with frac - string = string.replace("tfrac", "frac") - string = string.replace("dfrac", "frac") - - # remove \left and \right - string = string.replace("\\left", "") - string = string.replace("\\right", "") - - # Remove circ (degrees) - string = string.replace("^{\\circ}", "") - string = string.replace("^\\circ", "") - - # remove dollar signs - string = string.replace("\\$", "") - - # remove units (on the right) - string = remove_right_units(string) - - # remove percentage - string = string.replace("\\%", "") - string = string.replace("\%", "") # noqa: W605 - - # " 0." equivalent to " ." and "{0." equivalent to "{." - string = string.replace(" .", " 0.") - string = string.replace("{.", "{0.") - - if len(string) == 0: - return string - if string[0] == ".": - string = "0" + string - - # get rid of e.g. "k = " or "q = " at beginning - if len(string.split("=")) == 2: - if len(string.split("=")[0]) <= 2: - string = string.split("=")[1] - - # fix sqrt3 --> sqrt{3} - string = fix_sqrt(string) - - # remove spaces - string = string.replace(" ", "") - - # fix fractions - string = fix_fracs(string) - - # manually change 0.5 --> \frac{1}{2} - if string == "0.5": - string = "\\frac{1}{2}" - - # X/Y changed to \frac{X}{Y} - string = fix_a_slash_b(string) - - return string diff --git a/research/eval/tasks/hendrycks_math_grail_pass_at_k/utils.py b/research/eval/tasks/hendrycks_math_grail_pass_at_k/utils.py deleted file mode 100644 index 565fe2a5..00000000 --- a/research/eval/tasks/hendrycks_math_grail_pass_at_k/utils.py +++ /dev/null @@ -1,249 +0,0 @@ -"""Pass@k utilities for MATH benchmark evaluation. - -Implements the standard pass@k metric: given n samples, compute the probability -that at least one of k random samples is correct. 
- -Formula: pass@k = 1 - C(n-c, k) / C(n, k) -where n = total samples, c = correct samples, k = samples to consider -""" - -import re -from math import comb -from typing import Any - - -def pass_at_k( - references: list[str], - predictions: list[list[str]], - k: list[int] | None = None, -) -> dict[str, float]: - """Compute pass@k for math problems. - - Args: - references: List of ground truth answers - predictions: List of lists of model predictions (n samples per problem) - k: List of k values to compute (e.g., [1, 5, 10]) - - Returns: - Dictionary with pass@k scores for each k value - """ - if k is None: - k = [1, 5] - if isinstance(k, int): - k = [k] - - results = {} - for k_val in k: - pass_at_k_scores = [] - for ref, preds in zip(references, predictions, strict=False): - n = len(preds) - if n < k_val: - # If we have fewer samples than k, use what we have - c = sum(1 for p in preds if is_equiv(extract_answer(p), ref)) - score = 1.0 if c > 0 else 0.0 - else: - c = sum(1 for p in preds if is_equiv(extract_answer(p), ref)) - score = _pass_at_k(n, c, k_val) - pass_at_k_scores.append(score) - - avg_score = sum(pass_at_k_scores) / len(pass_at_k_scores) if pass_at_k_scores else 0.0 - results[f"pass@{k_val}"] = avg_score - - return results - - -def _pass_at_k(n: int, c: int, k: int) -> float: - """Compute pass@k for a single problem. - - Args: - n: Total number of samples - c: Number of correct samples - k: Number of samples to consider - - Returns: - Probability that at least one of k samples is correct - """ - if n - c < k: - return 1.0 - return 1.0 - comb(n - c, k) / comb(n, k) - - -def extract_answer(text: str) -> str: - """Extract answer from model output. - - Tries multiple extraction patterns in order: - 1. ... tags (for reasoning models) - 2. \\boxed{...} (LaTeX boxed answers) - 3. $...$ (LaTeX inline math at end) - 4. 
Last number in text - """ - if text is None: - return "" - - # Try tags first (reasoning models) - solution_match = re.search(r"(.*?)", text, re.DOTALL) - if solution_match: - return solution_match.group(1).strip() - - # Try \boxed{} - boxed = last_boxed_only_string(text) - if boxed: - try: - return remove_boxed(boxed) - except (AssertionError, IndexError): - pass - - # Try $...$ at end - indices = [pos for pos, char in enumerate(text) if char == "$"] - if len(indices) >= 2: - return text[indices[-2] + 1 : indices[-1]] - - # Fallback: return cleaned text - return text.strip() - - -def is_equiv(str1: str, str2: str) -> bool: - """Check if two answers are mathematically equivalent.""" - if str1 is None and str2 is None: - return True - if str1 is None or str2 is None: - return False - - try: - ss1 = strip_string(str1) - ss2 = strip_string(str2) - return ss1 == ss2 - except Exception: - return str1 == str2 - - -def remove_boxed(s: str) -> str: - """Remove \\boxed{} wrapper from string.""" - if "\\boxed " in s: - left = "\\boxed " - assert s[: len(left)] == left - return s[len(left) :] - - left = "\\boxed{" - assert s[: len(left)] == left - assert s[-1] == "}" - return s[len(left) : -1] - - -def last_boxed_only_string(string: str) -> str | None: - """Extract the last \\boxed{} content from string.""" - idx = string.rfind("\\boxed") - if "\\boxed " in string: - return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0] - if idx < 0: - idx = string.rfind("\\fbox") - if idx < 0: - return None - - i = idx - right_brace_idx = None - num_left_braces_open = 0 - while i < len(string): - if string[i] == "{": - num_left_braces_open += 1 - if string[i] == "}": - num_left_braces_open -= 1 - if num_left_braces_open == 0: - right_brace_idx = i - break - i += 1 - - if right_brace_idx is None: - return None - return string[idx : right_brace_idx + 1] - - -def strip_string(string: str) -> str: - """Normalize string for comparison.""" - if string is None: - return "" - - # Remove 
linebreaks - string = string.replace("\n", "") - - # Remove common LaTeX - string = string.replace("\\!", "") - string = string.replace("\\\\", "\\") - string = string.replace("tfrac", "frac") - string = string.replace("dfrac", "frac") - string = string.replace("\\left", "") - string = string.replace("\\right", "") - string = string.replace("^{\\circ}", "") - string = string.replace("^\\circ", "") - - # Remove dollar signs - string = string.replace("$", "") - string = string.replace("\\$", "") - - # Remove percentage - string = string.replace("\\%", "") - string = string.replace("%", "") - - # Remove spaces - string = string.replace(" ", "") - - return string - - -# For lm-eval metric interface -def process_results(doc: dict, results: list[str]) -> dict[str, Any]: - """Process results for a single document (used by lm-eval). - - This function is called for EACH sample. The pass@k aggregation - happens in the aggregation function. - """ - answer = doc.get("answer", "") - - # Check each result - correct_list = [] - for result in results: - extracted = extract_answer(result) - is_correct = 1 if is_equiv(extracted, str(answer)) else 0 - correct_list.append(is_correct) - - return { - "pass": correct_list, # List of 0/1 for each sample - "num_correct": sum(correct_list), - "num_samples": len(correct_list), - } - - -def aggregate_pass_at_k(results: list[dict], k: int = 5) -> float: - """Aggregate pass@k across all documents. 
- - Args: - results: List of result dicts from process_results - k: Number of samples to consider for pass@k - - Returns: - Average pass@k score - """ - scores = [] - for r in results: - n = r["num_samples"] - c = r["num_correct"] - if n < k: - score = 1.0 if c > 0 else 0.0 - else: - score = _pass_at_k(n, c, k) - scores.append(score) - - return sum(scores) / len(scores) if scores else 0.0 - - -# Convenience aggregation functions for different k values -def aggregate_pass_at_1(results: list[dict]) -> float: - return aggregate_pass_at_k(results, k=1) - - -def aggregate_pass_at_5(results: list[dict]) -> float: - return aggregate_pass_at_k(results, k=5) - - -def aggregate_pass_at_10(results: list[dict]) -> float: - return aggregate_pass_at_k(results, k=10) diff --git a/research/eval/tasks/hendrycks_math_grail/_default_template.yaml b/research/eval/tasks/hendrycks_math_thinking/_default_template.yaml similarity index 100% rename from research/eval/tasks/hendrycks_math_grail/_default_template.yaml rename to research/eval/tasks/hendrycks_math_thinking/_default_template.yaml diff --git a/research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking.yaml b/research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking.yaml new file mode 100644 index 00000000..f5d05093 --- /dev/null +++ b/research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking.yaml @@ -0,0 +1,15 @@ +group: hendrycks_math_thinking +task: + - hendrycks_math_thinking_algebra + - hendrycks_math_thinking_counting_and_prob + - hendrycks_math_thinking_geometry + - hendrycks_math_thinking_intermediate_algebra + - hendrycks_math_thinking_num_theory + - hendrycks_math_thinking_prealgebra + - hendrycks_math_thinking_precalc +aggregate_metric_list: + - metric: exact_match + aggregation: mean + weight_by_size: true +metadata: + version: 1.0 diff --git a/research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail_algebra.yaml 
b/research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_algebra.yaml similarity index 68% rename from research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail_algebra.yaml rename to research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_algebra.yaml index 95fe5683..5a87439a 100644 --- a/research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail_algebra.yaml +++ b/research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_algebra.yaml @@ -1,5 +1,5 @@ include: _default_template.yaml tag: - math_word_problems -task: hendrycks_math_grail_algebra +task: hendrycks_math_thinking_algebra dataset_name: algebra diff --git a/research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail_counting_and_prob.yaml b/research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_counting_and_prob.yaml similarity index 67% rename from research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail_counting_and_prob.yaml rename to research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_counting_and_prob.yaml index dfa695a0..9f54e170 100644 --- a/research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail_counting_and_prob.yaml +++ b/research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_counting_and_prob.yaml @@ -1,5 +1,5 @@ include: _default_template.yaml tag: - math_word_problems -task: hendrycks_math_grail_counting_and_prob +task: hendrycks_math_thinking_counting_and_prob dataset_name: counting_and_probability diff --git a/research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail_geometry.yaml b/research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_geometry.yaml similarity index 68% rename from research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail_geometry.yaml rename to research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_geometry.yaml index 5743de5d..293f55a9 100644 --- a/research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail_geometry.yaml +++ 
b/research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_geometry.yaml @@ -1,5 +1,5 @@ include: _default_template.yaml tag: - math_word_problems -task: hendrycks_math_grail_geometry +task: hendrycks_math_thinking_geometry dataset_name: geometry diff --git a/research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail_intermediate_algebra.yaml b/research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_intermediate_algebra.yaml similarity index 65% rename from research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail_intermediate_algebra.yaml rename to research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_intermediate_algebra.yaml index a9db9246..7ee5914e 100644 --- a/research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail_intermediate_algebra.yaml +++ b/research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_intermediate_algebra.yaml @@ -1,5 +1,5 @@ include: _default_template.yaml tag: - math_word_problems -task: hendrycks_math_grail_intermediate_algebra +task: hendrycks_math_thinking_intermediate_algebra dataset_name: intermediate_algebra diff --git a/research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail_num_theory.yaml b/research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_num_theory.yaml similarity index 68% rename from research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail_num_theory.yaml rename to research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_num_theory.yaml index 95e3260a..b668341c 100644 --- a/research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail_num_theory.yaml +++ b/research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_num_theory.yaml @@ -1,5 +1,5 @@ include: _default_template.yaml tag: - math_word_problems -task: hendrycks_math_grail_num_theory +task: hendrycks_math_thinking_num_theory dataset_name: number_theory diff --git a/research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail_prealgebra.yaml 
b/research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_prealgebra.yaml similarity index 67% rename from research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail_prealgebra.yaml rename to research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_prealgebra.yaml index c8e8bde6..3c9aebc2 100644 --- a/research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail_prealgebra.yaml +++ b/research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_prealgebra.yaml @@ -1,5 +1,5 @@ include: _default_template.yaml tag: - math_word_problems -task: hendrycks_math_grail_prealgebra +task: hendrycks_math_thinking_prealgebra dataset_name: prealgebra diff --git a/research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail_precalc.yaml b/research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_precalc.yaml similarity index 69% rename from research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail_precalc.yaml rename to research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_precalc.yaml index 81594a08..827992b2 100644 --- a/research/eval/tasks/hendrycks_math_grail/hendrycks_math_grail_precalc.yaml +++ b/research/eval/tasks/hendrycks_math_thinking/hendrycks_math_thinking_precalc.yaml @@ -1,5 +1,5 @@ include: _default_template.yaml tag: - math_word_problems -task: hendrycks_math_grail_precalc +task: hendrycks_math_thinking_precalc dataset_name: precalculus diff --git a/research/eval/tasks/hendrycks_math_thinking/utils.py b/research/eval/tasks/hendrycks_math_thinking/utils.py new file mode 100644 index 00000000..d58371bd --- /dev/null +++ b/research/eval/tasks/hendrycks_math_thinking/utils.py @@ -0,0 +1,54 @@ +"""Custom utils for thinking model evaluation on MATH. + +Extracts answers from ... tags instead of \\boxed{}. 
+""" + +import sys +from pathlib import Path + +import datasets + +# Add parent directory to path for _common import +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from _common import ( + extract_solution_tag, + is_equiv_string, + last_boxed_only_string, + remove_boxed, + strip_string_math, +) + + +def process_docs(dataset: datasets.Dataset) -> datasets.Dataset: + """Process dataset docs - extract ground truth answer from \\boxed{}.""" + + def _process_doc(doc: dict) -> dict: + boxed = last_boxed_only_string(doc["solution"]) + answer = remove_boxed(boxed) if boxed else "" + return { + "problem": doc["problem"], + "solution": doc["solution"], + "answer": answer, + } + + return dataset.map(_process_doc) + + +def process_results(doc: dict, results: list[str]) -> dict[str, int]: + """Process results - extract answer from tags and compare.""" + # Extract from tags + model_answer = extract_solution_tag(results[0]) + if model_answer is None: + model_answer = results[0].strip() + + # Get ground truth (already extracted from \boxed{} in process_docs) + ground_truth = doc.get("answer") + if ground_truth is None: + boxed = last_boxed_only_string(doc["solution"]) + ground_truth = remove_boxed(boxed) if boxed else "" + + # Compare using full math normalization + is_correct = is_equiv_string(model_answer, ground_truth, strip_string_math) + + return {"exact_match": 1 if is_correct else 0} diff --git a/research/eval/tasks/hendrycks_math_grail_pass_at_k/hendrycks_math_pass_at_5.yaml b/research/eval/tasks/hendrycks_math_thinking_pass_at_k/hendrycks_math_thinking_pass_at_5.yaml similarity index 90% rename from research/eval/tasks/hendrycks_math_grail_pass_at_k/hendrycks_math_pass_at_5.yaml rename to research/eval/tasks/hendrycks_math_thinking_pass_at_k/hendrycks_math_thinking_pass_at_5.yaml index 6870678b..ffcaad7c 100644 --- a/research/eval/tasks/hendrycks_math_grail_pass_at_k/hendrycks_math_pass_at_5.yaml +++ 
b/research/eval/tasks/hendrycks_math_thinking_pass_at_k/hendrycks_math_thinking_pass_at_5.yaml @@ -2,8 +2,8 @@ # Generates 10 samples per problem with temperature sampling # Computes pass@1, pass@5 -group: hendrycks_math_pass_at_k -task: hendrycks_math_pass_at_5 +group: hendrycks_math_thinking_pass_at_k +task: hendrycks_math_thinking_pass_at_5 dataset_path: EleutherAI/hendrycks_math dataset_name: algebra output_type: generate_until diff --git a/research/eval/tasks/hendrycks_math_thinking_pass_at_k/utils.py b/research/eval/tasks/hendrycks_math_thinking_pass_at_k/utils.py new file mode 100644 index 00000000..36dec40c --- /dev/null +++ b/research/eval/tasks/hendrycks_math_thinking_pass_at_k/utils.py @@ -0,0 +1,104 @@ +"""Pass@k utilities for MATH benchmark evaluation. + +Implements the standard pass@k metric: given n samples, compute the probability +that at least one of k random samples is correct. + +Formula: pass@k = 1 - C(n-c, k) / C(n, k) +where n = total samples, c = correct samples, k = samples to consider +""" + +import sys +from math import comb +from pathlib import Path +from typing import Any + +# Add parent directory to path for _common import +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from _common import ( + extract_answer_cascade, + is_equiv_string, + strip_string_math, +) + + +def process_results(doc: dict, results: list[str]) -> dict[str, Any]: + """Process results for a single document (used by lm-eval). + + This function is called for EACH sample. The pass@k aggregation + happens in the aggregation function. 
+ """ + answer = str(doc.get("answer", "")) + + # Check each result + correct_list = [] + for result in results: + extracted = extract_answer_cascade( + result, + try_solution_tag=True, + try_boxed=True, + try_dollar=True, + ) + is_correct = 1 if is_equiv_string(extracted, answer, strip_string_math) else 0 + correct_list.append(is_correct) + + return { + "pass": correct_list, + "num_correct": sum(correct_list), + "num_samples": len(correct_list), + } + + +def _pass_at_k(n: int, c: int, k: int) -> float: + """Compute pass@k for a single problem. + + Args: + n: Total number of samples + c: Number of correct samples + k: Number of samples to consider + + Returns: + Probability that at least one of k samples is correct + """ + if n - c < k: + return 1.0 + return 1.0 - comb(n - c, k) / comb(n, k) + + +def aggregate_pass_at_k(results: list[dict], k: int = 5) -> float: + """Aggregate pass@k across all documents. + + Args: + results: List of result dicts from process_results + k: Number of samples to consider for pass@k + + Returns: + Average pass@k score + """ + scores = [] + for r in results: + n = r["num_samples"] + c = r["num_correct"] + if n < k: + score = 1.0 if c > 0 else 0.0 + else: + score = _pass_at_k(n, c, k) + scores.append(score) + + return sum(scores) / len(scores) if scores else 0.0 + + +# Convenience aggregation functions for different k values +def aggregate_pass_at_1(results: list[dict]) -> float: + """Aggregate pass@1 score.""" + return aggregate_pass_at_k(results, k=1) + + +def aggregate_pass_at_5(results: list[dict]) -> float: + """Aggregate pass@5 score.""" + return aggregate_pass_at_k(results, k=5) + + +def aggregate_pass_at_10(results: list[dict]) -> float: + """Aggregate pass@10 score.""" + return aggregate_pass_at_k(results, k=10) From 7f0d291284d96ba35c6009f41255d94d47fe614a Mon Sep 17 00:00:00 2001 From: ErfanMhi Date: Fri, 5 Dec 2025 21:01:36 +0000 Subject: [PATCH 16/17] feat: connect importance_sampling_level to GRPOConfig --- 
research/trl/train_trl_grpo.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/research/trl/train_trl_grpo.py b/research/trl/train_trl_grpo.py index ae818ad0..104d358e 100644 --- a/research/trl/train_trl_grpo.py +++ b/research/trl/train_trl_grpo.py @@ -80,7 +80,7 @@ class Config: # Total training windows (GRAIL_TRAINER_TOTAL_WINDOWS) - controls iteration count # Each optimizer step = 32 groups × 16 rollouts = 512 samples # total_optimizer_steps calculated below based on total_windows - total_windows: int = 100 + total_steps: int = 400 # ──────────────────────────────────────────────────────────────────────── # GRPO Loss Configuration (from grail/trainer/algorithms/grpo.py) @@ -1026,7 +1026,7 @@ def main() -> None: print(f" {'Max Completion Tokens':<40} {cfg.max_new_tokens:<15} GRPO_MAX_COMPLETION_TOKENS") print(f" {'Gradient Clip':<40} {cfg.grad_clip:<15} GRAIL_TRAINER_GRAD_CLIP") print(f" {'Warmup Steps':<40} {cfg.warmup_steps:<15} GRAIL_TRAINER_WARMUP_STEPS") - print(f" {'Total Windows':<40} {cfg.total_windows:<15} GRAIL_TRAINER_TOTAL_WINDOWS") + print(f" {'Total Steps':<40} {cfg.total_steps:<15} GRAIL_TRAINER_TOTAL_STEPS") print(f" {'KL Coefficient':<40} {cfg.kl_coef:<15} GRAIL_TRAINER_KL_COEF") print(f" {'Entropy Coefficient':<40} {cfg.entropy_coef:<15} GRAIL_TRAINER_ENTROPY_COEF") print(f" {'PPO Clip Epsilon':<40} {cfg.ppo_clip_eps:<15} TRAINER_PPO_CLIP_EPS") @@ -1091,7 +1091,7 @@ def main() -> None: # = 32 groups × 16 rollouts effective_batch = cfg.batch_size * cfg.grad_accum_steps # 4 × 128 = 512 groups_per_step = effective_batch // cfg.rollouts_per_problem # 512 / 16 = 32 - total_optimizer_steps = 320 # Fixed: maintains original training duration + total_optimizer_steps = cfg.total_steps # Fixed: maintains original training duration print("\n📊 Training Schedule:") print(f" • Effective batch size: {effective_batch} samples") @@ -1129,6 +1129,11 @@ def main() -> None: # 
───────────────────────────────────────────────────────────────────── max_prompt_length=max_prompt_length, # max_length - max_completion_tokens max_completion_length=cfg.max_new_tokens, # GRPO_MAX_COMPLETION_TOKENS + + # ───────────────────────────────────────────────────────────────────── + # Importance Sampling Level + # ───────────────────────────────────────────────────────────────────── + importance_sampling_level=cfg.importance_sampling_level, # GRAIL_IMPORTANCE_SAMPLING_LEVEL # ───────────────────────────────────────────────────────────────────── # Generation Parameters # ───────────────────────────────────────────────────────────────────── @@ -1160,7 +1165,6 @@ def main() -> None: use_vllm=True, vllm_mode="server", vllm_server_base_url="http://127.0.0.1:8000", - # Importance sampling configuration (matching GRAIL_IMPORTANCE_SAMPLING_LEVEL=token) vllm_importance_sampling_correction=False, vllm_importance_sampling_cap=cfg.is_ratio_max, # GRAIL_TRAINER_IS_RATIO_MAX ) From 7f700f2c7b91816fb9528dfe82ebc5da0d34538b Mon Sep 17 00:00:00 2001 From: ErfanMhi Date: Sat, 6 Dec 2025 00:42:25 +0000 Subject: [PATCH 17/17] chore: resolve minor ci linting issues --- research/trl/train_trl_grpo.py | 1 - 1 file changed, 1 deletion(-) diff --git a/research/trl/train_trl_grpo.py b/research/trl/train_trl_grpo.py index 104d358e..d3f0d042 100644 --- a/research/trl/train_trl_grpo.py +++ b/research/trl/train_trl_grpo.py @@ -1129,7 +1129,6 @@ def main() -> None: # ───────────────────────────────────────────────────────────────────── max_prompt_length=max_prompt_length, # max_length - max_completion_tokens max_completion_length=cfg.max_new_tokens, # GRPO_MAX_COMPLETION_TOKENS - # ───────────────────────────────────────────────────────────────────── # Importance Sampling Level # ─────────────────────────────────────────────────────────────────────