diff --git a/environments/README.md b/environments/README.md index 503d65b37..1cf7ce320 100644 --- a/environments/README.md +++ b/environments/README.md @@ -11,6 +11,7 @@ This folder contains installable example environments that showcase common usage ### SingleTurnEnv (prompt → single response) - **gsm8k**: Classic QA with exact-match reward and optional response-format reward. +- **gsm_infinite**: GSM-Infinite long-context arithmetic datasets with exact final-answer scoring. - **reverse_text**: XML formatting with non-binary LCS reward + format reward. - **continuation_quality**: Completion-style generation (`message_type="completion"`) judged for prose quality with `JudgeRubric`. - **mmmu**: Multimodal inputs (image + text) packed in chat content; single-turn boxed-answer check. diff --git a/environments/gsm_infinite/README.md b/environments/gsm_infinite/README.md new file mode 100644 index 000000000..3fb9fc318 --- /dev/null +++ b/environments/gsm_infinite/README.md @@ -0,0 +1,28 @@ +# GSM-Infinite + +GSM-Infinite is a train/eval environment for Infini-AI-Lab's scalable +long-context math reasoning benchmark. It loads the public Hugging Face datasets +from the GSM-Infinite collection and normalizes rows into Verifiers +`question`/`answer`/`info` examples. + +By default the environment uses the medium zero-context dataset and the +`ops_2` split: + +```bash +prime eval run gsm-infinite \ + -a '{"subset": "medium", "context_length": "0", "operations": 2, "num_eval_examples": 20}' +``` + +Useful arguments: + +| Arg | Default | Description | +| --- | --- | --- | +| `subset` | `medium` | One of `symbolic`, `medium`, or `hard` | +| `context_length` | `0` | Dataset suffix length, e.g. `0`, `8k`, `16k`, `32k` | +| `operations` | `2` | Operation split loaded as `ops_<operations>` | +| `num_train_examples` | `-1` | Limit train examples | +| `num_eval_examples` | `-1` | Limit eval examples | + +The reward compares the model's final `Answer: ...` value to the normalized +dataset answer.
The environment does not require provider APIs during dataset +loading or scoring. diff --git a/environments/gsm_infinite/gsm_infinite.py b/environments/gsm_infinite/gsm_infinite.py new file mode 100644 index 000000000..b98d0bc64 --- /dev/null +++ b/environments/gsm_infinite/gsm_infinite.py @@ -0,0 +1,150 @@ +import re +from typing import Any + +from datasets import Dataset, load_dataset + +import verifiers as vf + +SYSTEM_PROMPT = ( + "Solve the GSM-Infinite problem step by step. End with the final result in " + "the format `Answer: ...`." +) + +DATASET_TEMPLATE = "InfiniAILab/gsm_infinite_{subset}_{context_length}" +SYMBOLIC_DATASET_TEMPLATE = "InfiniAILab/gsm_infinite_symbolic_{context_length}" + + +def dataset_name(subset: str = "medium", context_length: str = "0") -> str: + if subset not in {"symbolic", "medium", "hard"}: + raise ValueError("subset must be one of: symbolic, medium, hard") + if subset == "symbolic": + return SYMBOLIC_DATASET_TEMPLATE.format(context_length=context_length) + return DATASET_TEMPLATE.format(subset=subset, context_length=context_length) + + +def split_name(operations: int = 2) -> str: + return f"ops_{operations}" + + +def extract_answer(text: Any) -> str: + if isinstance(text, list): + return ", ".join(str(item) for item in text) + text = str(text) + matches = list(re.finditer(r"\banswer\s*:\s*([^\n]+)", text, flags=re.IGNORECASE)) + if matches: + match = matches[-1] + answer = match.group(1).strip() + leading_number = re.match(r"-?\d+(?:\.\d+)?", answer) + if leading_number is not None: + suffix = answer[leading_number.end() :].strip() + if suffix == "" or suffix.startswith("."): + return leading_number.group(0) + return answer.rstrip(".") + numbers = re.findall(r"-?\d+(?:\.\d+)?", text) + return numbers[-1] if numbers else text.strip() + + +def normalize_answer(text: Any) -> str: + if isinstance(text, list): + text = ", ".join(str(item) for item in text) + return re.sub(r"\s+", " ", str(text)).strip().lower() + + +def 
row_to_example(row: dict[str, Any]) -> dict[str, Any]: + messages = row.get("messages") or [] + question = "" + if messages: + question = "\n\n".join( + str(message.get("content", "")) + for message in messages + if isinstance(message, dict) and message.get("role") == "user" + ) + if not question: + question = ( + f"Problem: {row.get('problem', '')}\nQuestion: {row.get('question', '')}" + ) + + answer_source = row.get("answer_list") or row.get("solution", "") + return { + "question": question, + "answer": extract_answer(answer_source), + "info": { + "benchmark": "gsm-infinite", + "subset": row.get("d"), + "operations": row.get("op"), + "length": row.get("length"), + "source_id": row.get("id"), + }, + } + + +def load_gsm_infinite_dataset( + subset: str = "medium", + context_length: str = "0", + operations: int = 2, + num_examples: int = -1, +) -> Dataset: + ds = load_dataset( + dataset_name(subset=subset, context_length=context_length), + split=split_name(operations), + ) + ds = ds.map(row_to_example, remove_columns=ds.column_names) + if num_examples != -1: + ds = ds.select(range(min(num_examples, len(ds)))) + return ds + + +def _completion_text(completion: Any) -> str: + if isinstance(completion, str): + return completion + if isinstance(completion, list): + return "\n".join( + str(message.get("content", "")) + for message in completion + if isinstance(message, dict) and message.get("role") == "assistant" + ) + return "" + + +def exact_answer_reward(completion, answer, **kwargs) -> float: + return float( + normalize_answer(extract_answer(_completion_text(completion))) + == normalize_answer(answer) + ) + + +def load_environment( + subset: str = "medium", + eval_subset: str | None = None, + context_length: str = "0", + eval_context_length: str | None = None, + operations: int = 2, + eval_operations: int | None = None, + num_train_examples: int = -1, + num_eval_examples: int = -1, + system_prompt: str = SYSTEM_PROMPT, +) -> vf.Environment: + eval_subset = eval_subset 
if eval_subset is not None else subset + eval_context_length = ( + eval_context_length if eval_context_length is not None else context_length + ) + eval_operations = eval_operations if eval_operations is not None else operations + + rubric = vf.Rubric(funcs=[exact_answer_reward], weights=[1.0]) + return vf.SingleTurnEnv( + dataset=lambda: load_gsm_infinite_dataset( + subset=subset, + context_length=context_length, + operations=operations, + num_examples=num_train_examples, + ), + eval_dataset=lambda: load_gsm_infinite_dataset( + subset=eval_subset, + context_length=eval_context_length, + operations=eval_operations, + num_examples=num_eval_examples, + ), + system_prompt=system_prompt, + parser=vf.Parser(), + rubric=rubric, + ) diff --git a/environments/gsm_infinite/pyproject.toml b/environments/gsm_infinite/pyproject.toml new file mode 100644 index 000000000..7bbd04992 --- /dev/null +++ b/environments/gsm_infinite/pyproject.toml @@ -0,0 +1,24 @@ +[project] +name = "gsm-infinite" +version = "0.1.0" +description = "GSM-Infinite long-context math reasoning environment." 
+tags = ["math", "gsm-infinite", "long-context", "single-turn", "train", "eval"] +requires-python = ">=3.10" +dependencies = [ + "verifiers>=0.1.14", + "datasets", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build] +include = ["gsm_infinite.py", "README.md", "pyproject.toml"] + +[project.entry-points."verifiers.environments"] +gsm-infinite = "gsm_infinite:load_environment" + +[tool.verifiers.eval] +num_examples = 20 +rollouts_per_example = 3 diff --git a/tests/test_gsm_infinite_environment.py b/tests/test_gsm_infinite_environment.py new file mode 100644 index 000000000..0ce8c8c88 --- /dev/null +++ b/tests/test_gsm_infinite_environment.py @@ -0,0 +1,165 @@ +import importlib.util +from pathlib import Path + +from datasets import Dataset + + +ENV_PATH = ( + Path(__file__).resolve().parents[1] + / "environments" + / "gsm_infinite" + / "gsm_infinite.py" +) + + +def load_module(): + spec = importlib.util.spec_from_file_location("gsm_infinite_env", ENV_PATH) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def sample_dataset() -> Dataset: + return Dataset.from_list( + [ + { + "problem": "There are 2 foxes.", + "question": "How many foxes are there?", + "solution": "Define foxes as F. F = 2. Answer: 2.", + "op": 2, + "id": 4, + "length": "zero_context", + "d": 2, + "messages": [ + {"role": "system", "content": "You are helpful"}, + {"role": "user", "content": "Problem: There are 2 foxes."}, + {"role": "user", "content": "Question: How many foxes? 
Solution:"}, + ], + } + ] + ) + + +def test_dataset_name_and_split_validation(): + gsm_infinite = load_module() + + assert ( + gsm_infinite.dataset_name("medium", "8k") + == "InfiniAILab/gsm_infinite_medium_8k" + ) + assert ( + gsm_infinite.dataset_name("symbolic", "0") + == "InfiniAILab/gsm_infinite_symbolic_0" + ) + assert ( + gsm_infinite.dataset_name("symbolic", "8k") + == "InfiniAILab/gsm_infinite_symbolic_8k" + ) + assert gsm_infinite.split_name(12) == "ops_12" + + +def test_row_to_example_maps_messages_and_answer(): + gsm_infinite = load_module() + row = sample_dataset()[0] + + example = gsm_infinite.row_to_example(row) + + assert "Problem: There are 2 foxes." in example["question"] + assert "How many foxes?" in example["question"] + assert example["answer"] == "2" + assert example["info"]["benchmark"] == "gsm-infinite" + + +def test_row_to_example_falls_back_when_answer_list_is_none(): + gsm_infinite = load_module() + row = dict(sample_dataset()[0]) + row["answer_list"] = None + row["solution"] = "Solve it. Answer: 7." 
+ + example = gsm_infinite.row_to_example(row) + + assert example["answer"] == "7" + + +def test_load_dataset_uses_hf_name_split_and_limits(monkeypatch): + gsm_infinite = load_module() + calls = {} + + def fake_load_dataset(name, split): + calls["name"] = name + calls["split"] = split + return sample_dataset() + + monkeypatch.setattr(gsm_infinite, "load_dataset", fake_load_dataset) + + dataset = gsm_infinite.load_gsm_infinite_dataset( + subset="medium", + context_length="0", + operations=2, + num_examples=1, + ) + + assert calls == {"name": "InfiniAILab/gsm_infinite_medium_0", "split": "ops_2"} + assert len(dataset) == 1 + assert {"question", "answer", "info"}.issubset(dataset.column_names) + + +def test_reward_accepts_answer_prefix_and_exact_value(): + gsm_infinite = load_module() + completion = [{"role": "assistant", "content": "Reasoning...\nAnswer: 2."}] + + assert gsm_infinite.exact_answer_reward(completion, "2") == 1.0 + assert gsm_infinite.exact_answer_reward(completion, "3") == 0.0 + + +def test_reward_preserves_multi_value_list_answers(): + gsm_infinite = load_module() + completion = [{"role": "assistant", "content": "Reasoning...\nAnswer: 4, 9."}] + + assert gsm_infinite.extract_answer(["4", "9"]) == "4, 9" + assert gsm_infinite.exact_answer_reward(completion, "4, 9") == 1.0 + assert gsm_infinite.exact_answer_reward(completion, "9") == 0.0 + + +def test_extract_answer_uses_answer_prefix_without_greedy_fallback(): + gsm_infinite = load_module() + + assert gsm_infinite.extract_answer("work\nAnswer: 2. 
Therefore done") == "2" + assert gsm_infinite.extract_answer("work\nANSWER: V670487.") == "V670487" + assert ( + gsm_infinite.extract_answer("I think answer: 5 is wrong.\nFinal Answer: 10.") + == "10" + ) + + +def test_environment_loads_without_building_dataset(): + gsm_infinite = load_module() + + env = gsm_infinite.load_environment(num_train_examples=1, num_eval_examples=1) + + assert env.system_prompt == gsm_infinite.SYSTEM_PROMPT + assert callable(env.dataset_source) + assert callable(env.eval_dataset_source) + + +def test_environment_preserves_explicit_zero_eval_operations(monkeypatch): + gsm_infinite = load_module() + calls = [] + + def fake_load_dataset(name, split): + calls.append(split) + return sample_dataset() + + monkeypatch.setattr(gsm_infinite, "load_dataset", fake_load_dataset) + env = gsm_infinite.load_environment( + operations=2, + eval_operations=0, + num_train_examples=1, + num_eval_examples=1, + ) + + env.eval_dataset_source() + + assert calls == ["ops_0"]