diff --git a/environments/README.md b/environments/README.md
index 503d65b37..5e7a5338e 100644
--- a/environments/README.md
+++ b/environments/README.md
@@ -12,6 +12,7 @@ This folder contains installable example environments that showcase common usage
 
 ### SingleTurnEnv (prompt → single response)
 - **gsm8k**: Classic QA with exact-match reward and optional response-format reward.
 - **reverse_text**: XML formatting with non-binary LCS reward + format reward.
+- **swe_swiss_sft**: SWE-Swiss supervised code-repair transcripts scored against the target assistant response.
 - **continuation_quality**: Completion-style generation (`message_type="completion"`) judged for prose quality with `JudgeRubric`.
 - **mmmu**: Multimodal inputs (image + text) packed in chat content; single-turn boxed-answer check.
diff --git a/environments/swe_swiss_sft/README.md b/environments/swe_swiss_sft/README.md
new file mode 100644
index 000000000..63d3ae2a9
--- /dev/null
+++ b/environments/swe_swiss_sft/README.md
@@ -0,0 +1,28 @@
+# SWE-Swiss SFT
+
+SWE-Swiss SFT is a train/eval environment over the public SWE-Swiss supervised
+fine-tuning datasets. It converts each row's chat transcript into a single-turn
+code-repair prompt with the first assistant response as the target answer.
+
+By default the environment uses `SWE-Swiss/SWESwiss-SFT-Repair-4K`:
+
+```bash
+prime eval run swe-swiss-sft \
+    -a '{"num_eval_examples": 20}'
+```
+
+Useful arguments:
+
+| Arg | Default | Description |
+| --- | --- | --- |
+| `dataset_name` | `SWE-Swiss/SWESwiss-SFT-Repair-4K` | Hugging Face dataset to use for training |
+| `eval_dataset_name` | same as `dataset_name` | Optional separate eval dataset |
+| `split` | `train` | Training split |
+| `eval_split` | same as `split` | Evaluation split |
+| `num_train_examples` | `-1` | Limit train examples |
+| `num_eval_examples` | `-1` | Limit eval examples |
+| `shuffle_seed` | `777` | Deterministic shuffle seed, or `null` to preserve source order |
+
+The reward is normalized text similarity against the reference assistant
+response. This keeps the environment lightweight and usable for the existing
+SWE-Swiss SFT data without requiring a sandboxed repository checkout.
diff --git a/environments/swe_swiss_sft/pyproject.toml b/environments/swe_swiss_sft/pyproject.toml
new file mode 100644
index 000000000..397df32b8
--- /dev/null
+++ b/environments/swe_swiss_sft/pyproject.toml
@@ -0,0 +1,24 @@
+[project]
+name = "swe-swiss-sft"
+version = "0.1.0"
+description = "SWE-Swiss SFT train/eval environment."
+tags = ["swe", "code", "sft", "train", "eval", "single-turn"]
+requires-python = ">=3.10"
+dependencies = [
+    "verifiers>=0.1.14",
+    "datasets",
+]
+
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[tool.hatch.build]
+include = ["swe_swiss_sft.py", "README.md", "pyproject.toml"]
+
+[project.entry-points."verifiers.environments"]
+swe-swiss-sft = "swe_swiss_sft:load_environment"
+
+[tool.verifiers.eval]
+num_examples = 20
+rollouts_per_example = 1
diff --git a/environments/swe_swiss_sft/swe_swiss_sft.py b/environments/swe_swiss_sft/swe_swiss_sft.py
new file mode 100644
index 000000000..c1f7c3532
--- /dev/null
+++ b/environments/swe_swiss_sft/swe_swiss_sft.py
@@ -0,0 +1,129 @@
+import json
+from difflib import SequenceMatcher
+from typing import Any
+
+from datasets import Dataset, load_dataset
+
+import verifiers as vf
+
+DEFAULT_DATASET_NAME = "SWE-Swiss/SWESwiss-SFT-Repair-4K"
+DEFAULT_SPLIT = "train"
+
+SYSTEM_PROMPT = (
+    "You are a code repair agent. Respond with the requested SEARCH/REPLACE "
+    "edits or solution text only."
+)
+
+
+def parse_messages(raw_messages: Any) -> list[dict[str, str]]:
+    if isinstance(raw_messages, str):
+        raw_messages = json.loads(raw_messages)
+    if not isinstance(raw_messages, list):
+        raise ValueError("SWE-Swiss rows must contain a messages list or JSON string")
+
+    messages: list[dict[str, str]] = []
+    for message in raw_messages:
+        if not isinstance(message, dict):
+            continue
+        role = str(message.get("role") or "").strip()
+        content = str(message.get("content") or "")
+        if role and content:
+            messages.append({"role": role, "content": content})
+    return messages
+
+
+def row_to_example(row: dict[str, Any]) -> dict[str, Any]:
+    messages = parse_messages(row["messages"])
+    assistant_messages = [
+        message for message in messages if message.get("role") == "assistant"
+    ]
+    if not assistant_messages:
+        raise ValueError("SWE-Swiss row does not contain an assistant target")
+
+    prompt_messages = []
+    for message in messages:
+        if message.get("role") == "assistant":
+            break
+        prompt_messages.append(message)
+
+    if not prompt_messages:
+        raise ValueError("SWE-Swiss row does not contain prompt messages")
+
+    answer = assistant_messages[0]["content"]
+    return {
+        "prompt": prompt_messages,
+        "question": prompt_messages[-1]["content"],
+        "answer": answer,
+        "info": {
+            "benchmark": "swe-swiss",
+            "target_role": "assistant",
+            "source_index": row.get("__index_level_0__"),
+        },
+    }
+
+
+def load_swe_swiss_dataset(
+    dataset_name: str = DEFAULT_DATASET_NAME,
+    split: str = DEFAULT_SPLIT,
+    num_examples: int = -1,
+    shuffle_seed: int | None = None,
+) -> Dataset:
+    ds = load_dataset(dataset_name, split=split)
+    if shuffle_seed is not None:
+        ds = ds.shuffle(seed=shuffle_seed)
+    if num_examples != -1:
+        ds = ds.select(range(min(num_examples, len(ds))))
+    return ds.map(row_to_example, remove_columns=ds.column_names)
+
+
+def _completion_text(completion: Any) -> str:
+    if isinstance(completion, str):
+        return completion
+    if isinstance(completion, list):
+        return "\n".join(
+            str(message.get("content", ""))
+            for message in completion
+            if isinstance(message, dict) and message.get("role") == "assistant"
+        )
+    return ""
+
+
+def normalized_similarity_reward(completion, answer, **kwargs) -> float:
+    response = " ".join(_completion_text(completion).split())
+    target = " ".join(str(answer).split())
+    if not response or not target:
+        return 0.0
+    return SequenceMatcher(None, response, target).ratio()
+
+
+def load_environment(
+    dataset_name: str = DEFAULT_DATASET_NAME,
+    eval_dataset_name: str | None = None,
+    split: str = DEFAULT_SPLIT,
+    eval_split: str | None = None,
+    num_train_examples: int = -1,
+    num_eval_examples: int = -1,
+    shuffle_seed: int | None = 777,
+    system_prompt: str = SYSTEM_PROMPT,
+) -> vf.Environment:
+    eval_dataset_name = eval_dataset_name or dataset_name
+    eval_split = eval_split or split
+
+    rubric = vf.Rubric(funcs=[normalized_similarity_reward], weights=[1.0])
+    return vf.SingleTurnEnv(
+        dataset=lambda: load_swe_swiss_dataset(
+            dataset_name=dataset_name,
+            split=split,
+            num_examples=num_train_examples,
+            shuffle_seed=shuffle_seed,
+        ),
+        eval_dataset=lambda: load_swe_swiss_dataset(
+            dataset_name=eval_dataset_name,
+            split=eval_split,
+            num_examples=num_eval_examples,
+            shuffle_seed=shuffle_seed,
+        ),
+        system_prompt=system_prompt,
+        parser=vf.Parser(),
+        rubric=rubric,
+    )
diff --git a/tests/test_swe_swiss_sft_environment.py b/tests/test_swe_swiss_sft_environment.py
new file mode 100644
index 000000000..712d286a6
--- /dev/null
+++ b/tests/test_swe_swiss_sft_environment.py
@@ -0,0 +1,132 @@
+import importlib.util
+import json
+from pathlib import Path
+
+from datasets import Dataset
+
+
+ENV_PATH = (
+    Path(__file__).resolve().parents[1]
+    / "environments"
+    / "swe_swiss_sft"
+    / "swe_swiss_sft.py"
+)
+
+
+def load_module():
+    spec = importlib.util.spec_from_file_location("swe_swiss_sft_env", ENV_PATH)
+    assert spec is not None
+    assert spec.loader is not None
+    module = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(module)
+    return module
+
+
+def sample_messages():
+    return [
+        {"role": "system", "content": "You are a repair agent."},
+        {"role": "user", "content": "Fix the failing test."},
+        {
+            "role": "assistant",
+            "content": "### app.py\n<<<<<<< SEARCH\nbad\n=======\ngood\n>>>>>>> REPLACE",
+        },
+    ]
+
+
+def sample_dataset(messages_as_json=False) -> Dataset:
+    messages = sample_messages()
+    return Dataset.from_list(
+        [
+            {
+                "messages": json.dumps(messages) if messages_as_json else messages,
+                "__index_level_0__": 7,
+            }
+        ]
+    )
+
+
+def test_parse_messages_accepts_list_and_json_string():
+    swe_swiss = load_module()
+
+    assert swe_swiss.parse_messages(sample_messages()) == sample_messages()
+    assert swe_swiss.parse_messages(json.dumps(sample_messages())) == sample_messages()
+
+
+def test_parse_messages_skips_null_content():
+    swe_swiss = load_module()
+    messages = [
+        {"role": "system", "content": None},
+        {"role": "user", "content": "Fix this."},
+        {"role": "assistant", "content": "Done."},
+    ]
+
+    parsed = swe_swiss.parse_messages(messages)
+
+    assert parsed == [
+        {"role": "user", "content": "Fix this."},
+        {"role": "assistant", "content": "Done."},
+    ]
+    assert all(message["content"] != "None" for message in parsed)
+
+
+def test_row_to_example_uses_prompt_before_first_assistant():
+    swe_swiss = load_module()
+    row = sample_dataset()[0]
+
+    example = swe_swiss.row_to_example(row)
+
+    assert example["prompt"] == sample_messages()[:2]
+    assert example["question"] == "Fix the failing test."
+    assert ">>>>>>> REPLACE" in example["answer"]
+    assert example["info"]["benchmark"] == "swe-swiss"
+    assert example["info"]["source_index"] == 7
+
+
+def test_load_dataset_uses_requested_hf_dataset_and_limits(monkeypatch):
+    swe_swiss = load_module()
+    calls = {}
+
+    def fake_load_dataset(name, split):
+        calls["name"] = name
+        calls["split"] = split
+        return sample_dataset(messages_as_json=True)
+
+    monkeypatch.setattr(swe_swiss, "load_dataset", fake_load_dataset)
+
+    dataset = swe_swiss.load_swe_swiss_dataset(
+        dataset_name="SWE-Swiss/SWESwiss-SFT-Merged-10K",
+        split="train",
+        num_examples=1,
+        shuffle_seed=None,
+    )
+
+    assert calls == {"name": "SWE-Swiss/SWESwiss-SFT-Merged-10K", "split": "train"}
+    assert len(dataset) == 1
+    assert {"prompt", "question", "answer", "info"}.issubset(dataset.column_names)
+
+
+def test_similarity_reward_scores_exact_higher_than_wrong_answer():
+    swe_swiss = load_module()
+    answer = "replace bad with good"
+
+    exact = swe_swiss.normalized_similarity_reward(
+        [{"role": "assistant", "content": "replace bad with good"}],
+        answer,
+    )
+    wrong = swe_swiss.normalized_similarity_reward(
+        [{"role": "assistant", "content": "unrelated text"}],
+        answer,
+    )
+
+    assert exact == 1.0
+    assert wrong < exact
+
+
+def test_environment_loads_without_building_dataset():
+    swe_swiss = load_module()
+
+    env = swe_swiss.load_environment(num_train_examples=1, num_eval_examples=1)
+
+    assert env.system_prompt == swe_swiss.SYSTEM_PROMPT
+    assert callable(env.dataset_source)
+    assert callable(env.eval_dataset_source)
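For a quick local sanity check of the new module, a sketch along these lines should work. It is not part of the diff above; it assumes `environments/swe_swiss_sft/swe_swiss_sft.py` is importable (for example by running from that directory) and only exercises functions defined in the diff:

```python
# Minimal sketch: convert one SWE-Swiss-style row and score two completions.
# Assumes swe_swiss_sft.py is on the import path; not part of the PR itself.
from swe_swiss_sft import normalized_similarity_reward, row_to_example

row = {
    "messages": [
        {"role": "user", "content": "Fix the failing test."},
        {"role": "assistant", "content": "replace bad with good"},
    ]
}
example = row_to_example(row)  # prompt = messages before the first assistant turn

# An exact match against the reference assistant response scores 1.0;
# an unrelated completion scores well below that.
print(normalized_similarity_reward(
    [{"role": "assistant", "content": "replace bad with good"}], example["answer"]
))
print(normalized_similarity_reward(
    [{"role": "assistant", "content": "unrelated text"}], example["answer"]
))
```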