Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions environments/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ This folder contains installable example environments that showcase common usage
### SingleTurnEnv (prompt → single response)
- **gsm8k**: Classic QA with exact-match reward and optional response-format reward.
- **reverse_text**: XML formatting with non-binary LCS reward + format reward.
- **swe_swiss_sft**: SWE-Swiss supervised code-repair transcripts scored against the target assistant response.
- **continuation_quality**: Completion-style generation (`message_type="completion"`) judged for prose quality with `JudgeRubric`.
- **mmmu**: Multimodal inputs (image + text) packed in chat content; single-turn boxed-answer check.

Expand Down
28 changes: 28 additions & 0 deletions environments/swe_swiss_sft/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# SWE-Swiss SFT

SWE-Swiss SFT is a train/eval environment over the public SWE-Swiss supervised
fine-tuning datasets. It converts each row's chat transcript into a single-turn
code-repair prompt with the first assistant response as the target answer.

By default the environment uses `SWE-Swiss/SWESwiss-SFT-Repair-4K`:

```bash
prime eval run swe-swiss-sft \
-a '{"num_eval_examples": 20}'
```

Useful arguments:

| Arg | Default | Description |
| --- | --- | --- |
| `dataset_name` | `SWE-Swiss/SWESwiss-SFT-Repair-4K` | Hugging Face dataset used for training (and for eval unless `eval_dataset_name` is set) |
| `eval_dataset_name` | same as `dataset_name` | Optional separate eval dataset |
| `split` | `train` | Training split |
| `eval_split` | same as `split` | Evaluation split |
| `num_train_examples` | `-1` | Limit train examples |
| `num_eval_examples` | `-1` | Limit eval examples |
| `shuffle_seed` | `777` | Deterministic shuffle seed, or `null` to preserve source order |

The reward is normalized text similarity against the reference assistant
response. This keeps the environment lightweight and usable for the existing
SWE-Swiss SFT data without requiring a sandboxed repository checkout.
24 changes: 24 additions & 0 deletions environments/swe_swiss_sft/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
[project]
name = "swe-swiss-sft"
version = "0.1.0"
description = "SWE-Swiss SFT train/eval environment."
tags = ["swe", "code", "sft", "train", "eval", "single-turn"]
requires-python = ">=3.10"
dependencies = [
"verifiers>=0.1.14",
"datasets",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.build]
include = ["swe_swiss_sft.py", "README.md", "pyproject.toml"]

[project.entry-points."verifiers.environments"]
swe-swiss-sft = "swe_swiss_sft:load_environment"

[tool.verifiers.eval]
num_examples = 20
rollouts_per_example = 1
129 changes: 129 additions & 0 deletions environments/swe_swiss_sft/swe_swiss_sft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import json
from difflib import SequenceMatcher
from typing import Any

from datasets import Dataset, load_dataset

import verifiers as vf

# Default Hugging Face dataset/split for the SWE-Swiss SFT repair transcripts.
DEFAULT_DATASET_NAME = "SWE-Swiss/SWESwiss-SFT-Repair-4K"
DEFAULT_SPLIT = "train"

# System prompt prepended to every example. This is a runtime string the model
# sees verbatim, so its wording is part of the environment's behavior.
SYSTEM_PROMPT = (
    "You are a code repair agent. Respond with the requested SEARCH/REPLACE "
    "edits or solution text only."
)


def parse_messages(raw_messages: Any) -> list[dict[str, str]]:
    """Normalize a SWE-Swiss ``messages`` field into clean role/content dicts.

    Accepts either a list of message dicts or a JSON-encoded string of such a
    list. Entries that are not dicts, or whose role/content is missing or
    empty after coercion, are silently dropped.

    Raises:
        ValueError: if the input is neither a list nor a JSON list string.
    """
    if isinstance(raw_messages, str):
        raw_messages = json.loads(raw_messages)
    if not isinstance(raw_messages, list):
        raise ValueError("SWE-Swiss rows must contain a messages list or JSON string")

    cleaned: list[dict[str, str]] = []
    for entry in raw_messages:
        if not isinstance(entry, dict):
            continue
        # `or ""` turns None into "" so a null role/content drops the entry
        # instead of producing the literal string "None".
        role = str(entry.get("role") or "").strip()
        text = str(entry.get("content") or "")
        if role and text:
            cleaned.append({"role": role, "content": text})
    return cleaned


def row_to_example(row: dict[str, Any]) -> dict[str, Any]:
    """Convert one SWE-Swiss row into a single-turn prompt/answer example.

    The prompt is every message strictly before the first assistant turn; the
    first assistant message's content becomes the reference answer.

    Raises:
        ValueError: if the row has no assistant message, or no messages
            before the first assistant message.
    """
    messages = parse_messages(row["messages"])

    first_assistant = next(
        (i for i, m in enumerate(messages) if m.get("role") == "assistant"),
        None,
    )
    if first_assistant is None:
        raise ValueError("SWE-Swiss row does not contain an assistant target")

    prompt_messages = messages[:first_assistant]
    if not prompt_messages:
        raise ValueError("SWE-Swiss row does not contain prompt messages")

    return {
        "prompt": prompt_messages,
        # Convenience copy of the final user-visible turn for logging/display.
        "question": prompt_messages[-1]["content"],
        "answer": messages[first_assistant]["content"],
        "info": {
            "benchmark": "swe-swiss",
            "target_role": "assistant",
            # May be None if the source dataset carries no pandas index column.
            "source_index": row.get("__index_level_0__"),
        },
    }


def load_swe_swiss_dataset(
    dataset_name: str = DEFAULT_DATASET_NAME,
    split: str = DEFAULT_SPLIT,
    num_examples: int = -1,
    shuffle_seed: int | None = None,
) -> Dataset:
    """Load a SWE-Swiss SFT split and map it into prompt/answer examples.

    Args:
        dataset_name: Hugging Face dataset identifier.
        split: Split to load (e.g. ``"train"``).
        num_examples: Cap on examples kept after the optional shuffle. Any
            negative value (including the default ``-1``) keeps the full split.
        shuffle_seed: Deterministic shuffle seed, or ``None`` to preserve
            the source order.

    Returns:
        A ``datasets.Dataset`` with ``prompt``/``question``/``answer``/``info``
        columns only (source columns are removed by the map).
    """
    ds = load_dataset(dataset_name, split=split)
    if shuffle_seed is not None:
        ds = ds.shuffle(seed=shuffle_seed)
    # Fix: the original checked `num_examples != -1`, so any other negative
    # value (e.g. -2) crashed in range(); treat every negative as "no limit".
    if num_examples >= 0:
        ds = ds.select(range(min(num_examples, len(ds))))
    return ds.map(row_to_example, remove_columns=ds.column_names)


def _completion_text(completion: Any) -> str:
if isinstance(completion, str):
return completion
if isinstance(completion, list):
return "\n".join(
str(message.get("content", ""))
for message in completion
if isinstance(message, dict) and message.get("role") == "assistant"
)
return ""


def normalized_similarity_reward(completion, answer, **kwargs) -> float:
    """Score a completion by textual similarity to the reference answer.

    Both sides are whitespace-collapsed before comparison; the reward is the
    ``SequenceMatcher`` ratio in [0, 1], or 0.0 when either side is empty.
    """
    candidate = " ".join(_completion_text(completion).split())
    reference = " ".join(str(answer).split())
    if candidate and reference:
        return SequenceMatcher(None, candidate, reference).ratio()
    return 0.0


def load_environment(
    dataset_name: str = DEFAULT_DATASET_NAME,
    eval_dataset_name: str | None = None,
    split: str = DEFAULT_SPLIT,
    eval_split: str | None = None,
    num_train_examples: int = -1,
    num_eval_examples: int = -1,
    shuffle_seed: int | None = 777,
    system_prompt: str = SYSTEM_PROMPT,
) -> vf.Environment:
    """Build the SWE-Swiss SFT single-turn environment.

    Datasets are passed as zero-argument callables so constructing the
    environment performs no dataset download; the eval dataset/split fall
    back to the training values when not given.
    """
    resolved_eval_name = eval_dataset_name or dataset_name
    resolved_eval_split = eval_split or split

    def build_train_dataset():
        return load_swe_swiss_dataset(
            dataset_name=dataset_name,
            split=split,
            num_examples=num_train_examples,
            shuffle_seed=shuffle_seed,
        )

    def build_eval_dataset():
        return load_swe_swiss_dataset(
            dataset_name=resolved_eval_name,
            split=resolved_eval_split,
            num_examples=num_eval_examples,
            shuffle_seed=shuffle_seed,
        )

    return vf.SingleTurnEnv(
        dataset=build_train_dataset,
        eval_dataset=build_eval_dataset,
        system_prompt=system_prompt,
        parser=vf.Parser(),
        rubric=vf.Rubric(funcs=[normalized_similarity_reward], weights=[1.0]),
    )
132 changes: 132 additions & 0 deletions tests/test_swe_swiss_sft_environment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import importlib.util
import json
from pathlib import Path

from datasets import Dataset


# Absolute path to the environment module under test, resolved from the
# repository root (one directory above tests/).
_REPO_ROOT = Path(__file__).resolve().parents[1]
ENV_PATH = _REPO_ROOT / "environments" / "swe_swiss_sft" / "swe_swiss_sft.py"


def load_module():
    """Import the environment module directly from its source file.

    A fresh module object is produced on every call so tests never share
    mutated module state.
    """
    spec = importlib.util.spec_from_file_location("swe_swiss_sft_env", ENV_PATH)
    assert spec is not None
    assert spec.loader is not None
    loaded = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(loaded)
    return loaded


def sample_messages():
    """Return a minimal system/user/assistant repair transcript fixture."""
    repair_edit = "### app.py\n<<<<<<< SEARCH\nbad\n=======\ngood\n>>>>>>> REPLACE"
    return [
        {"role": "system", "content": "You are a repair agent."},
        {"role": "user", "content": "Fix the failing test."},
        {"role": "assistant", "content": repair_edit},
    ]


def sample_dataset(messages_as_json=False) -> Dataset:
    """Build a one-row dataset fixture, optionally JSON-encoding the messages.

    The ``__index_level_0__`` column mimics the pandas index column that
    ``row_to_example`` reads into ``info["source_index"]``.
    """
    transcript = sample_messages()
    payload = json.dumps(transcript) if messages_as_json else transcript
    return Dataset.from_list([{"messages": payload, "__index_level_0__": 7}])


def test_parse_messages_accepts_list_and_json_string():
    """Both the list form and its JSON encoding round-trip unchanged."""
    swe_swiss = load_module()
    expected = sample_messages()

    assert swe_swiss.parse_messages(sample_messages()) == expected
    assert swe_swiss.parse_messages(json.dumps(expected)) == expected


def test_parse_messages_skips_null_content():
    """Messages with None content are dropped, not stringified to 'None'."""
    swe_swiss = load_module()
    raw = [
        {"role": "system", "content": None},
        {"role": "user", "content": "Fix this."},
        {"role": "assistant", "content": "Done."},
    ]

    parsed = swe_swiss.parse_messages(raw)

    assert parsed == [
        {"role": "user", "content": "Fix this."},
        {"role": "assistant", "content": "Done."},
    ]
    assert "None" not in [m["content"] for m in parsed]


def test_row_to_example_uses_prompt_before_first_assistant():
    """The prompt stops at the first assistant turn, which becomes the answer."""
    swe_swiss = load_module()

    example = swe_swiss.row_to_example(sample_dataset()[0])

    assert example["prompt"] == sample_messages()[:2]
    assert example["question"] == "Fix the failing test."
    assert ">>>>>>> REPLACE" in example["answer"]
    info = example["info"]
    assert info["benchmark"] == "swe-swiss"
    assert info["source_index"] == 7


def test_load_dataset_uses_requested_hf_dataset_and_limits(monkeypatch):
    """The loader forwards name/split to HF and honors the example cap."""
    swe_swiss = load_module()
    seen = {}

    def fake_load_dataset(name, split):
        seen["name"] = name
        seen["split"] = split
        # JSON-encoded messages exercise the string branch of parse_messages.
        return sample_dataset(messages_as_json=True)

    monkeypatch.setattr(swe_swiss, "load_dataset", fake_load_dataset)

    dataset = swe_swiss.load_swe_swiss_dataset(
        dataset_name="SWE-Swiss/SWESwiss-SFT-Merged-10K",
        split="train",
        num_examples=1,
        shuffle_seed=None,
    )

    assert seen == {"name": "SWE-Swiss/SWESwiss-SFT-Merged-10K", "split": "train"}
    assert len(dataset) == 1
    assert {"prompt", "question", "answer", "info"} <= set(dataset.column_names)


def test_similarity_reward_scores_exact_higher_than_wrong_answer():
    """An exact match scores 1.0 and strictly beats an unrelated response."""
    swe_swiss = load_module()
    answer = "replace bad with good"

    def score(text):
        return swe_swiss.normalized_similarity_reward(
            [{"role": "assistant", "content": text}],
            answer,
        )

    exact = score("replace bad with good")
    wrong = score("unrelated text")

    assert exact == 1.0
    assert wrong < exact


def test_environment_loads_without_building_dataset():
    """Environment construction is lazy: dataset sources remain callables."""
    module = load_module()

    env = module.load_environment(num_train_examples=1, num_eval_examples=1)

    assert env.system_prompt == module.SYSTEM_PROMPT
    for source in (env.dataset_source, env.eval_dataset_source):
        assert callable(source)