# Add GSM-Infinite training environment #1336

Status: Closed

Commits (6, all by poofeth):

- `0f1f9f2` Add GSM-Infinite training environment
- `a9d549a` Register GSM-Infinite environment
- `203cfb1` Fix GSM-Infinite answer normalization
- `4f732d8` Use final GSM-Infinite answer marker
- `e1e9a26` Fix GSM-Infinite symbolic dataset names
- `76ab25b` Handle missing GSM-Infinite answer lists
**File: `README.md`** (new, 28 lines)

# GSM-Infinite

GSM-Infinite is a train/eval environment for Infini-AI-Lab's scalable
long-context math reasoning benchmark. It loads the public Hugging Face
datasets from the GSM-Infinite collection and normalizes rows into Verifiers
`question`/`answer`/`info` examples.

By default the environment uses the medium zero-context dataset and the
`ops_2` split:

```bash
prime eval run gsm-infinite \
  -a '{"subset": "medium", "context_length": "0", "operations": 2, "num_eval_examples": 20}'
```

Useful arguments:

| Arg | Default | Description |
| --- | --- | --- |
| `subset` | `medium` | One of `symbolic`, `medium`, or `hard` |
| `context_length` | `0` | Dataset suffix length, e.g. `0`, `8k`, `16k`, `32k` |
| `operations` | `2` | Operation split loaded as `ops_<operations>` |
| `num_train_examples` | `-1` | Limit train examples |
| `num_eval_examples` | `-1` | Limit eval examples |

The reward compares the model's final `Answer: ...` value to the normalized
dataset answer. The environment does not require provider APIs during dataset
loading or scoring.
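The matching rule the README describes can be sketched in a few lines. This is a simplified standalone version of the idea, not the exact code shipped in this PR; the real implementation also handles list-valued answers and a numeric fallback when no `Answer:` marker is present:

```python
import re


def extract_final_answer(text: str) -> str:
    """Take the value after the last `Answer:` marker (simplified sketch)."""
    matches = re.findall(r"\banswer\s*:\s*([^\n]+)", text, flags=re.IGNORECASE)
    return matches[-1].strip().rstrip(".") if matches else text.strip()


def normalize(text: str) -> str:
    # Collapse whitespace and lowercase so "2 " and "2" compare equal.
    return re.sub(r"\s+", " ", text).strip().lower()


# Binary reward: 1.0 on an exact normalized match, else 0.0.
reward = float(normalize(extract_final_answer("Work...\nAnswer: 42.")) == normalize("42"))
print(reward)  # 1.0
```

Because both sides pass through the same normalization, cosmetic differences in whitespace, casing, or a trailing period do not cost reward.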
**File: `gsm_infinite.py`** (new, 150 lines)

```python
import re
from typing import Any

from datasets import Dataset, load_dataset

import verifiers as vf

SYSTEM_PROMPT = (
    "Solve the GSM-Infinite problem step by step. End with the final result in "
    "the format `Answer: ...`."
)

DATASET_TEMPLATE = "InfiniAILab/gsm_infinite_{subset}_{context_length}"
SYMBOLIC_DATASET_TEMPLATE = "InfiniAILab/gsm_infinite_symbolic_{context_length}"


def dataset_name(subset: str = "medium", context_length: str = "0") -> str:
    if subset not in {"symbolic", "medium", "hard"}:
        raise ValueError("subset must be one of: symbolic, medium, hard")
    if subset == "symbolic":
        return SYMBOLIC_DATASET_TEMPLATE.format(context_length=context_length)
    return DATASET_TEMPLATE.format(subset=subset, context_length=context_length)


def split_name(operations: int = 2) -> str:
    return f"ops_{operations}"


def extract_answer(text: Any) -> str:
    # List answers (multi-part questions) are joined verbatim.
    if isinstance(text, list):
        return ", ".join(str(item) for item in text)
    text = str(text)
    # Prefer the last `Answer: ...` marker; fall back to the last number.
    matches = list(re.finditer(r"\banswer\s*:\s*([^\n]+)", text, flags=re.IGNORECASE))
    if matches:
        match = matches[-1]
        answer = match.group(1).strip()
        leading_number = re.match(r"-?\d+(?:\.\d+)?", answer)
        if leading_number is not None:
            suffix = answer[leading_number.end() :].strip()
            if suffix == "" or suffix.startswith("."):
                return leading_number.group(0)
        return answer.rstrip(".")
    numbers = re.findall(r"-?\d+(?:\.\d+)?", text)
    return numbers[-1] if numbers else text.strip()


def normalize_answer(text: Any) -> str:
    if isinstance(text, list):
        text = ", ".join(str(item) for item in text)
    return re.sub(r"\s+", " ", str(text)).strip().lower()


def row_to_example(row: dict[str, Any]) -> dict[str, Any]:
    # Build the question from user messages when present, otherwise from
    # the raw problem/question fields.
    messages = row.get("messages") or []
    question = ""
    if messages:
        question = "\n\n".join(
            str(message.get("content", ""))
            for message in messages
            if isinstance(message, dict) and message.get("role") == "user"
        )
    if not question:
        question = (
            f"Problem: {row.get('problem', '')}\nQuestion: {row.get('question', '')}"
        )

    answer_source = row.get("answer_list") or row.get("solution", "")
    return {
        "question": question,
        "answer": extract_answer(answer_source),
        "info": {
            "benchmark": "gsm-infinite",
            "subset": row.get("d"),
            "operations": row.get("op"),
            "length": row.get("length"),
            "source_id": row.get("id"),
        },
    }


def load_gsm_infinite_dataset(
    subset: str = "medium",
    context_length: str = "0",
    operations: int = 2,
    num_examples: int = -1,
) -> Dataset:
    ds = load_dataset(
        dataset_name(subset=subset, context_length=context_length),
        split=split_name(operations),
    )
    ds = ds.map(row_to_example, remove_columns=ds.column_names)
    if num_examples != -1:
        ds = ds.select(range(min(num_examples, len(ds))))
    return ds


def _completion_text(completion: Any) -> str:
    if isinstance(completion, str):
        return completion
    if isinstance(completion, list):
        return "\n".join(
            str(message.get("content", ""))
            for message in completion
            if isinstance(message, dict) and message.get("role") == "assistant"
        )
    return ""


def exact_answer_reward(completion, answer, **kwargs) -> float:
    return float(
        normalize_answer(extract_answer(_completion_text(completion)))
        == normalize_answer(answer)
    )


def load_environment(
    subset: str = "medium",
    eval_subset: str | None = None,
    context_length: str = "0",
    eval_context_length: str | None = None,
    operations: int = 2,
    eval_operations: int | None = None,
    num_train_examples: int = -1,
    num_eval_examples: int = -1,
    system_prompt: str = SYSTEM_PROMPT,
) -> vf.Environment:
    # Explicit `is not None` checks so falsy overrides such as 0 survive.
    eval_subset = eval_subset if eval_subset is not None else subset
    eval_context_length = (
        eval_context_length if eval_context_length is not None else context_length
    )
    eval_operations = eval_operations if eval_operations is not None else operations

    rubric = vf.Rubric(funcs=[exact_answer_reward], weights=[1.0])
    return vf.SingleTurnEnv(
        dataset=lambda: load_gsm_infinite_dataset(
            subset=subset,
            context_length=context_length,
            operations=operations,
            num_examples=num_train_examples,
        ),
        eval_dataset=lambda: load_gsm_infinite_dataset(
            subset=eval_subset,
            context_length=eval_context_length,
            operations=eval_operations,
            num_examples=num_eval_examples,
        ),
        system_prompt=system_prompt,
        parser=vf.Parser(),
        rubric=rubric,
    )
```
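As a quick sanity check of the naming helpers above, the `subset` and `operations` arguments map onto Hugging Face dataset names and splits as follows. This sketch inlines the two small helpers so it runs standalone; the templates are copied from the module above:

```python
SYMBOLIC_TEMPLATE = "InfiniAILab/gsm_infinite_symbolic_{context_length}"
DEFAULT_TEMPLATE = "InfiniAILab/gsm_infinite_{subset}_{context_length}"


def dataset_name(subset: str, context_length: str) -> str:
    # `symbolic` has its own template; `medium`/`hard` share the generic one.
    if subset == "symbolic":
        return SYMBOLIC_TEMPLATE.format(context_length=context_length)
    return DEFAULT_TEMPLATE.format(subset=subset, context_length=context_length)


def split_name(operations: int) -> str:
    return f"ops_{operations}"


print(dataset_name("medium", "8k"), split_name(2))
# InfiniAILab/gsm_infinite_medium_8k ops_2
print(dataset_name("symbolic", "0"), split_name(3))
# InfiniAILab/gsm_infinite_symbolic_0 ops_3
```

The dedicated symbolic template is what commit `e1e9a26` ("Fix GSM-Infinite symbolic dataset names") addresses: the symbolic datasets do not follow the generic `{subset}_{context_length}` pattern.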
**File: `pyproject.toml`** (new, 24 lines)

```toml
[project]
name = "gsm-infinite"
version = "0.1.0"
description = "GSM-Infinite long-context math reasoning environment."
tags = ["math", "gsm-infinite", "long-context", "single-turn", "train", "eval"]
requires-python = ">=3.10"
dependencies = [
    "verifiers>=0.1.14",
    "datasets",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.build]
include = ["gsm_infinite.py", "README.md", "pyproject.toml"]

[project.entry-points."verifiers.environments"]
gsm-infinite = "gsm_infinite:load_environment"

[tool.verifiers.eval]
num_examples = 20
rollouts_per_example = 3
```
**File: tests (pytest)** (new, 165 lines)

```python
import importlib.util
from pathlib import Path

from datasets import Dataset


ENV_PATH = (
    Path(__file__).resolve().parents[1]
    / "environments"
    / "gsm_infinite"
    / "gsm_infinite.py"
)


def load_module():
    spec = importlib.util.spec_from_file_location("gsm_infinite_env", ENV_PATH)
    assert spec is not None
    assert spec.loader is not None
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


def sample_dataset() -> Dataset:
    return Dataset.from_list(
        [
            {
                "problem": "There are 2 foxes.",
                "question": "How many foxes are there?",
                "solution": "Define foxes as F. F = 2. Answer: 2.",
                "op": 2,
                "id": 4,
                "length": "zero_context",
                "d": 2,
                "messages": [
                    {"role": "system", "content": "You are helpful"},
                    {"role": "user", "content": "Problem: There are 2 foxes."},
                    {"role": "user", "content": "Question: How many foxes? Solution:"},
                ],
            }
        ]
    )


def test_dataset_name_and_split_validation():
    gsm_infinite = load_module()

    assert (
        gsm_infinite.dataset_name("medium", "8k")
        == "InfiniAILab/gsm_infinite_medium_8k"
    )
    assert (
        gsm_infinite.dataset_name("symbolic", "0")
        == "InfiniAILab/gsm_infinite_symbolic_0"
    )
    assert (
        gsm_infinite.dataset_name("symbolic", "8k")
        == "InfiniAILab/gsm_infinite_symbolic_8k"
    )
    assert gsm_infinite.split_name(12) == "ops_12"


def test_row_to_example_maps_messages_and_answer():
    gsm_infinite = load_module()
    row = sample_dataset()[0]

    example = gsm_infinite.row_to_example(row)

    assert "Problem: There are 2 foxes." in example["question"]
    assert "How many foxes?" in example["question"]
    assert example["answer"] == "2"
    assert example["info"]["benchmark"] == "gsm-infinite"


def test_row_to_example_falls_back_when_answer_list_is_none():
    gsm_infinite = load_module()
    row = dict(sample_dataset()[0])
    row["answer_list"] = None
    row["solution"] = "Solve it. Answer: 7."

    example = gsm_infinite.row_to_example(row)

    assert example["answer"] == "7"


def test_load_dataset_uses_hf_name_split_and_limits(monkeypatch):
    gsm_infinite = load_module()
    calls = {}

    def fake_load_dataset(name, split):
        calls["name"] = name
        calls["split"] = split
        return sample_dataset()

    monkeypatch.setattr(gsm_infinite, "load_dataset", fake_load_dataset)

    dataset = gsm_infinite.load_gsm_infinite_dataset(
        subset="medium",
        context_length="0",
        operations=2,
        num_examples=1,
    )

    assert calls == {"name": "InfiniAILab/gsm_infinite_medium_0", "split": "ops_2"}
    assert len(dataset) == 1
    assert {"question", "answer", "info"}.issubset(dataset.column_names)


def test_reward_accepts_answer_prefix_and_exact_value():
    gsm_infinite = load_module()
    completion = [{"role": "assistant", "content": "Reasoning...\nAnswer: 2."}]

    assert gsm_infinite.exact_answer_reward(completion, "2") == 1.0
    assert gsm_infinite.exact_answer_reward(completion, "3") == 0.0


def test_reward_preserves_multi_value_list_answers():
    gsm_infinite = load_module()
    completion = [{"role": "assistant", "content": "Reasoning...\nAnswer: 4, 9."}]

    assert gsm_infinite.extract_answer(["4", "9"]) == "4, 9"
    assert gsm_infinite.exact_answer_reward(completion, "4, 9") == 1.0
    assert gsm_infinite.exact_answer_reward(completion, "9") == 0.0


def test_extract_answer_uses_answer_prefix_without_greedy_fallback():
    gsm_infinite = load_module()

    assert gsm_infinite.extract_answer("work\nAnswer: 2. Therefore done") == "2"
    assert gsm_infinite.extract_answer("work\nANSWER: V670487.") == "V670487"
    assert (
        gsm_infinite.extract_answer("I think answer: 5 is wrong.\nFinal Answer: 10.")
        == "10"
    )


def test_environment_loads_without_building_dataset():
    gsm_infinite = load_module()

    env = gsm_infinite.load_environment(num_train_examples=1, num_eval_examples=1)

    assert env.system_prompt == gsm_infinite.SYSTEM_PROMPT
    assert callable(env.dataset_source)
    assert callable(env.eval_dataset_source)


def test_environment_preserves_explicit_zero_eval_operations(monkeypatch):
    gsm_infinite = load_module()
    calls = []

    def fake_load_dataset(name, split):
        calls.append(split)
        return sample_dataset()

    monkeypatch.setattr(gsm_infinite, "load_dataset", fake_load_dataset)
    env = gsm_infinite.load_environment(
        operations=2,
        eval_operations=0,
        num_train_examples=1,
        num_eval_examples=1,
    )

    env.eval_dataset_source()

    assert calls == ["ops_0"]
```
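The final test guards a subtle default-resolution bug: resolving an override with `or` silently discards falsy values such as `0`, whereas an explicit `is not None` check preserves them. A minimal standalone illustration (the helper names here are hypothetical, chosen only to contrast the two patterns):

```python
def resolve_with_or(override, default):
    # Buggy pattern: 0 is falsy, so an explicit 0 override is dropped.
    return override or default


def resolve_with_none_check(override, default):
    # Correct pattern: only a missing (None) override falls back.
    return override if override is not None else default


print(resolve_with_or(0, 2))          # 2  (wrong: the explicit 0 is lost)
print(resolve_with_none_check(0, 2))  # 0  (the explicit override survives)
print(resolve_with_none_check(None, 2))  # 2
```

This is why `load_environment` spells out `eval_operations if eval_operations is not None else operations` rather than the shorter `eval_operations or operations`: an evaluation run explicitly requesting the `ops_0` split must not be redirected to the training default.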