Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions environments/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ This folder contains installable example environments that showcase common usage

### SingleTurnEnv (prompt → single response)
- **gsm8k**: Classic QA with exact-match reward and optional response-format reward.
- **gsm_infinite**: GSM-Infinite long-context arithmetic datasets with exact final-answer scoring.
- **reverse_text**: XML formatting with non-binary LCS reward + format reward.
- **continuation_quality**: Completion-style generation (`message_type="completion"`) judged for prose quality with `JudgeRubric`.
- **mmmu**: Multimodal inputs (image + text) packed in chat content; single-turn boxed-answer check.
Expand Down
28 changes: 28 additions & 0 deletions environments/gsm_infinite/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# GSM-Infinite

GSM-Infinite is a train/eval environment for Infini-AI-Lab's scalable
long-context math reasoning benchmark. It loads the public Hugging Face datasets
from the GSM-Infinite collection and normalizes rows into Verifiers
`question`/`answer`/`info` examples.

By default the environment uses the medium zero-context dataset and the
`ops_2` split:

```bash
prime eval run gsm-infinite \
-a '{"subset": "medium", "context_length": "0", "operations": 2, "num_eval_examples": 20}'
```

Useful arguments:

| Arg | Default | Description |
| --- | --- | --- |
| `subset` | `medium` | One of `symbolic`, `medium`, or `hard` |
| `context_length` | `0` | Dataset suffix length, e.g. `0`, `8k`, `16k`, `32k` |
| `operations` | `2` | Operation split loaded as `ops_<operations>` |
| `num_train_examples` | `-1` | Limit train examples |
| `num_eval_examples` | `-1` | Limit eval examples |

The reward compares the model's final `Answer: ...` value to the normalized
dataset answer. The environment does not require provider APIs during dataset
loading or scoring.
150 changes: 150 additions & 0 deletions environments/gsm_infinite/gsm_infinite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import re
from typing import Any

from datasets import Dataset, load_dataset

import verifiers as vf

# Instruction prepended to every example; the reward function parses the
# trailing `Answer: ...` line that this prompt requests.
SYSTEM_PROMPT = (
    "Solve the GSM-Infinite problem step by step. End with the final result in "
    "the format `Answer: ...`."
)

# Hugging Face repo-id templates for the GSM-Infinite collection.
DATASET_TEMPLATE = "InfiniAILab/gsm_infinite_{subset}_{context_length}"
# The symbolic subset hard-codes its name instead of using the subset placeholder.
SYMBOLIC_DATASET_TEMPLATE = "InfiniAILab/gsm_infinite_symbolic_{context_length}"


def dataset_name(subset: str = "medium", context_length: str = "0") -> str:
    """Return the Hugging Face repo id for the requested subset and context length.

    Raises ValueError for an unknown subset.
    """
    if subset not in ("symbolic", "medium", "hard"):
        raise ValueError("subset must be one of: symbolic, medium, hard")
    # The symbolic collection uses its own repo-name template.
    template = SYMBOLIC_DATASET_TEMPLATE if subset == "symbolic" else DATASET_TEMPLATE
    return template.format(subset=subset, context_length=context_length)


def split_name(operations: int = 2) -> str:
    """Return the dataset split name (`ops_<operations>`) for an operation count."""
    return "ops_" + str(operations)


def extract_answer(text: Any) -> str:
    """Pull the final answer out of *text* (model output or dataset field).

    Lists are joined into a comma-separated string. Otherwise the value after
    the last `Answer:` marker wins; without a marker, the last number in the
    text is used, and failing that the stripped text itself.
    """
    if isinstance(text, list):
        # Multi-value answers stay as a single comma-separated string.
        return ", ".join(str(item) for item in text)
    text = str(text)

    last_match = None
    for last_match in re.finditer(r"\banswer\s*:\s*([^\n]+)", text, flags=re.IGNORECASE):
        pass  # keep only the final occurrence
    if last_match is not None:
        tail = last_match.group(1).strip()
        number = re.match(r"-?\d+(?:\.\d+)?", tail)
        if number is not None:
            rest = tail[number.end() :].strip()
            # Accept a bare number, optionally followed by a sentence-ending period.
            if not rest or rest.startswith("."):
                return number.group(0)
        return tail.rstrip(".")

    # No `Answer:` marker: fall back to the last number anywhere in the text.
    all_numbers = re.findall(r"-?\d+(?:\.\d+)?", text)
    return all_numbers[-1] if all_numbers else text.strip()


def normalize_answer(text: Any) -> str:
    """Canonicalize an answer for comparison: join lists, collapse whitespace, lowercase."""
    if isinstance(text, list):
        text = ", ".join(str(item) for item in text)
    squashed = re.sub(r"\s+", " ", str(text))
    return squashed.strip().lower()


def row_to_example(row: dict[str, Any]) -> dict[str, Any]:
    """Convert a raw GSM-Infinite dataset row into a Verifiers example.

    The question is the concatenation of all user-role chat messages; when the
    row carries no usable messages, the raw problem/question fields are used.
    """
    user_chunks = [
        str(message.get("content", ""))
        for message in (row.get("messages") or [])
        if isinstance(message, dict) and message.get("role") == "user"
    ]
    question = "\n\n".join(user_chunks)
    if not question:
        # Fallback for rows without chat-formatted messages.
        question = (
            f"Problem: {row.get('problem', '')}\nQuestion: {row.get('question', '')}"
        )

    raw_answer = row.get("answer_list") or row.get("solution", "")
    return {
        "question": question,
        "answer": extract_answer(raw_answer),
        "info": {
            "benchmark": "gsm-infinite",
            "subset": row.get("d"),
            "operations": row.get("op"),
            "length": row.get("length"),
            "source_id": row.get("id"),
        },
    }


def load_gsm_infinite_dataset(
    subset: str = "medium",
    context_length: str = "0",
    operations: int = 2,
    num_examples: int = -1,
) -> Dataset:
    """Load one GSM-Infinite split and map rows to question/answer/info columns.

    `num_examples` of -1 (the default) keeps the full split; any other value
    truncates to at most that many rows.
    """
    raw = load_dataset(
        dataset_name(subset=subset, context_length=context_length),
        split=split_name(operations),
    )
    mapped = raw.map(row_to_example, remove_columns=raw.column_names)
    if num_examples == -1:
        return mapped
    return mapped.select(range(min(num_examples, len(mapped))))


def _completion_text(completion: Any) -> str:
if isinstance(completion, str):
return completion
if isinstance(completion, list):
return "\n".join(
str(message.get("content", ""))
for message in completion
if isinstance(message, dict) and message.get("role") == "assistant"
)
return ""


def exact_answer_reward(completion, answer, **kwargs) -> float:
    """Binary reward: 1.0 when the extracted completion answer equals the target."""
    predicted = normalize_answer(extract_answer(_completion_text(completion)))
    target = normalize_answer(answer)
    return 1.0 if predicted == target else 0.0
Comment thread
cursor[bot] marked this conversation as resolved.


def load_environment(
    subset: str = "medium",
    eval_subset: str | None = None,
    context_length: str = "0",
    eval_context_length: str | None = None,
    operations: int = 2,
    eval_operations: int | None = None,
    num_train_examples: int = -1,
    num_eval_examples: int = -1,
    system_prompt: str = SYSTEM_PROMPT,
) -> vf.Environment:
    """Build the GSM-Infinite single-turn environment.

    Eval-side settings fall back to the train-side values only when left as
    None, so an explicit falsy value (e.g. `eval_operations=0`) is preserved.
    Dataset loading is deferred via callables, so no network access happens here.
    """
    if eval_subset is None:
        eval_subset = subset
    if eval_context_length is None:
        eval_context_length = context_length
    if eval_operations is None:
        eval_operations = operations

    def train_source() -> Dataset:
        # Deferred: only called when the environment materializes its dataset.
        return load_gsm_infinite_dataset(
            subset=subset,
            context_length=context_length,
            operations=operations,
            num_examples=num_train_examples,
        )

    def eval_source() -> Dataset:
        return load_gsm_infinite_dataset(
            subset=eval_subset,
            context_length=eval_context_length,
            operations=eval_operations,
            num_examples=num_eval_examples,
        )

    return vf.SingleTurnEnv(
        dataset=train_source,
        eval_dataset=eval_source,
        system_prompt=system_prompt,
        parser=vf.Parser(),
        rubric=vf.Rubric(funcs=[exact_answer_reward], weights=[1.0]),
    )
24 changes: 24 additions & 0 deletions environments/gsm_infinite/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
[project]
name = "gsm-infinite"
version = "0.1.0"
description = "GSM-Infinite long-context math reasoning environment."
tags = ["math", "gsm-infinite", "long-context", "single-turn", "train", "eval"]
requires-python = ">=3.10"
dependencies = [
"verifiers>=0.1.14",
"datasets",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.build]
include = ["gsm_infinite.py", "README.md", "pyproject.toml"]

[project.entry-points."verifiers.environments"]
gsm-infinite = "gsm_infinite:load_environment"

[tool.verifiers.eval]
num_examples = 20
rollouts_per_example = 3
165 changes: 165 additions & 0 deletions tests/test_gsm_infinite_environment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
import importlib.util
from pathlib import Path

from datasets import Dataset


# Absolute path to the environment module under test, resolved relative to
# this test file so the suite runs from any working directory.
ENV_PATH = (
    Path(__file__).resolve().parents[1]
    / "environments"
    / "gsm_infinite"
    / "gsm_infinite.py"
)


def load_module():
    """Import gsm_infinite.py from its file path as a fresh, isolated module."""
    spec = importlib.util.spec_from_file_location("gsm_infinite_env", ENV_PATH)
    assert spec is not None
    assert spec.loader is not None
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod


def sample_dataset() -> Dataset:
    """Return a one-row dataset mirroring the GSM-Infinite row schema."""
    row = {
        "problem": "There are 2 foxes.",
        "question": "How many foxes are there?",
        "solution": "Define foxes as F. F = 2. Answer: 2.",
        "op": 2,
        "id": 4,
        "length": "zero_context",
        "d": 2,
        "messages": [
            {"role": "system", "content": "You are helpful"},
            {"role": "user", "content": "Problem: There are 2 foxes."},
            {"role": "user", "content": "Question: How many foxes? Solution:"},
        ],
    }
    return Dataset.from_list([row])


def test_dataset_name_and_split_validation():
    """dataset_name/split_name compose the expected HF repo ids and splits."""
    mod = load_module()

    expected = {
        ("medium", "8k"): "InfiniAILab/gsm_infinite_medium_8k",
        ("symbolic", "0"): "InfiniAILab/gsm_infinite_symbolic_0",
        ("symbolic", "8k"): "InfiniAILab/gsm_infinite_symbolic_8k",
    }
    for (subset, length), repo_id in expected.items():
        assert mod.dataset_name(subset, length) == repo_id
    assert mod.split_name(12) == "ops_12"


def test_row_to_example_maps_messages_and_answer():
    """row_to_example joins user messages and extracts the numeric answer."""
    mod = load_module()

    example = mod.row_to_example(sample_dataset()[0])

    question = example["question"]
    assert "Problem: There are 2 foxes." in question
    assert "How many foxes?" in question
    assert example["answer"] == "2"
    assert example["info"]["benchmark"] == "gsm-infinite"


def test_row_to_example_falls_back_when_answer_list_is_none():
    """A None answer_list makes row_to_example fall back to the solution text."""
    mod = load_module()
    row = {
        **sample_dataset()[0],
        "answer_list": None,
        "solution": "Solve it. Answer: 7.",
    }

    assert mod.row_to_example(row)["answer"] == "7"


def test_load_dataset_uses_hf_name_split_and_limits(monkeypatch):
    """load_gsm_infinite_dataset forwards repo/split and honors num_examples."""
    mod = load_module()
    seen = []

    def fake_load_dataset(name, split):
        seen.append((name, split))
        return sample_dataset()

    monkeypatch.setattr(mod, "load_dataset", fake_load_dataset)

    dataset = mod.load_gsm_infinite_dataset(
        subset="medium",
        context_length="0",
        operations=2,
        num_examples=1,
    )

    assert seen == [("InfiniAILab/gsm_infinite_medium_0", "ops_2")]
    assert len(dataset) == 1
    assert {"question", "answer", "info"} <= set(dataset.column_names)


def test_reward_accepts_answer_prefix_and_exact_value():
    """Reward is 1.0 on an exact extracted match and 0.0 otherwise."""
    mod = load_module()
    completion = [{"role": "assistant", "content": "Reasoning...\nAnswer: 2."}]

    for target, expected in (("2", 1.0), ("3", 0.0)):
        assert mod.exact_answer_reward(completion, target) == expected


def test_reward_preserves_multi_value_list_answers():
    """List answers are joined, and matching requires the full joined value."""
    mod = load_module()
    completion = [{"role": "assistant", "content": "Reasoning...\nAnswer: 4, 9."}]

    assert mod.extract_answer(["4", "9"]) == "4, 9"
    for target, expected in (("4, 9", 1.0), ("9", 0.0)):
        assert mod.exact_answer_reward(completion, target) == expected


def test_extract_answer_uses_answer_prefix_without_greedy_fallback():
    """extract_answer prefers the last `Answer:` marker over any loose numbers."""
    mod = load_module()

    cases = {
        "work\nAnswer: 2. Therefore done": "2",
        "work\nANSWER: V670487.": "V670487",
        "I think answer: 5 is wrong.\nFinal Answer: 10.": "10",
    }
    for text, expected in cases.items():
        assert mod.extract_answer(text) == expected


def test_environment_loads_without_building_dataset():
    """load_environment wires deferred dataset callables without touching HF."""
    mod = load_module()

    env = mod.load_environment(num_train_examples=1, num_eval_examples=1)

    assert env.system_prompt == mod.SYSTEM_PROMPT
    assert callable(env.dataset_source)
    assert callable(env.eval_dataset_source)


def test_environment_preserves_explicit_zero_eval_operations(monkeypatch):
    """eval_operations=0 must not be clobbered by the train-side default."""
    mod = load_module()
    splits = []

    def fake_load_dataset(name, split):
        splits.append(split)
        return sample_dataset()

    monkeypatch.setattr(mod, "load_dataset", fake_load_dataset)
    env = mod.load_environment(
        operations=2,
        eval_operations=0,
        num_train_examples=1,
        num_eval_examples=1,
    )

    env.eval_dataset_source()

    assert splits == ["ops_0"]