Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions openverifiablellm/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""
Configuration loader for benchmark contamination detection.

"""
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional

import yaml

logger = logging.getLogger(__name__)

DEFAULT_BENCHMARKS: List[str] = ["gsm8k", "cais/mmlu"]

BENCHMARKS_YAML = "benchmarks.yaml"


@dataclass
class BenchmarkConfig:
    """Holds all settings needed by the contamination-detection pipeline."""

    # Benchmark identifiers: HF dataset ids or local .jsonl/.csv paths.
    # A fresh copy of the defaults per instance, so callers may mutate freely.
    benchmarks: List[str] = field(default_factory=DEFAULT_BENCHMARKS.copy)
    # N-gram window size used for contamination matching.
    n: int = 13
    # Expected number of distinct n-grams the Bloom filter must hold.
    bloom_capacity: int = 10_000_000
    # Acceptable false-positive rate for the Bloom filter.
    bloom_error_rate: float = 0.001
    # Where the serialised Bloom filter is cached on disk.
    filter_path: Path = Path("data/filter.bin")
    # Forwarded to Hugging Face `load_dataset(trust_remote_code=...)`.
    trust_remote_code: bool = False


def _load_from_yaml(yaml_path: Path) -> Optional[List[str]]:
    """Attempt to read benchmark identifiers from a YAML file.

    Parameters
    ----------
    yaml_path:
        Path to a YAML file expected to contain a top-level
        ``benchmarks`` key holding a non-empty list of strings.

    Returns
    -------
    The benchmark identifiers (whitespace-stripped, non-string and
    empty entries dropped), or ``None`` when the file is missing,
    unreadable, or lacks a valid non-empty ``benchmarks`` list.
    """
    if not yaml_path.is_file():
        return None

    try:
        with yaml_path.open("r", encoding="utf-8") as fh:
            data = yaml.safe_load(fh)
        if isinstance(data, dict) and "benchmarks" in data:
            raw_benchmarks = data["benchmarks"]
            if isinstance(raw_benchmarks, list):
                benchmarks = [b.strip() for b in raw_benchmarks if isinstance(b, str) and b.strip()]
                if benchmarks:
                    return benchmarks
        # Fix: report the actual path instead of the hard-coded
        # "benchmarks.yaml" — callers may pass any location.
        logger.warning(
            "%s found but does not contain a valid "
            "non-empty 'benchmarks' list; falling back to defaults.",
            yaml_path,
        )
    except (OSError, yaml.YAMLError) as exc:
        logger.warning("Failed to read/parse %s: %s", yaml_path, exc)
    return None

def load_config(cli_benchmarks: Optional[str] = None) -> BenchmarkConfig:
    """Build a :class:`BenchmarkConfig` by resolving benchmark sources.

    Resolution priority: comma-separated CLI string, then the
    ``benchmarks.yaml`` file in the working directory, then the
    built-in defaults.
    """
    config = BenchmarkConfig()

    if cli_benchmarks:
        parsed = [item.strip() for item in cli_benchmarks.split(",") if item.strip()]
        if parsed:
            config.benchmarks = parsed
            logger.info("Benchmarks from CLI: %s", config.benchmarks)
            return config
        # A CLI string of only commas/whitespace counts as "not given".
        logger.warning("CLI provided empty benchmarks; falling back to next priority.")

    from_yaml = _load_from_yaml(Path(BENCHMARKS_YAML))
    if from_yaml:
        config.benchmarks = from_yaml
        logger.info("Benchmarks from %s: %s", BENCHMARKS_YAML, config.benchmarks)
        return config

    logger.info("Using default benchmarks: %s", config.benchmarks)
    return config
281 changes: 281 additions & 0 deletions openverifiablellm/contamination.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,281 @@
"""
Contamination detection via n-gram matching and Bloom filters.

Provides utilities to:
- Load benchmark datasets (Hugging Face or local files)
- Generate n-grams from text
- Build / load a Bloom filter populated with benchmark n-grams
- Check whether a text chunk is contaminated
"""

import csv
import hashlib
import json
import logging
import re
from pathlib import Path
from typing import Dict, List, Optional, Set

from rbloom import Bloom

from openverifiablellm.config import BenchmarkConfig

logger = logging.getLogger(__name__)

# Text fields we look for when extracting questions from benchmarks
_TEXT_FIELDS = ("question", "prompt", "problem", "input", "text")


def _bloom_hash(item) -> int:
    """Deterministic hash function for rbloom.

    ``str`` inputs are UTF-8 encoded; anything else (bytes-like) is
    fed to SHA-256 as-is. The first 8 digest bytes are read as a
    little-endian unsigned integer.
    """
    if isinstance(item, str):
        payload = item.encode("utf-8")
    else:
        payload = item
    digest64 = hashlib.sha256(payload).digest()[:8]
    return int.from_bytes(digest64, "little")

# Benchmark loading
def fetch_benchmarks(config: BenchmarkConfig) -> List[str]:
    """Collect benchmark texts from every configured source.

    Sources ending in ``.jsonl`` or ``.csv`` are read from local disk;
    everything else is treated as a Hugging Face dataset identifier.

    Raises FileNotFoundError when a local benchmark file is missing.
    """
    collected: List[str] = []

    for source in config.benchmarks:
        source_path = Path(source)
        if source_path.suffix in (".jsonl", ".csv"):
            if not source_path.is_file():
                raise FileNotFoundError(f"Local benchmark file not found: {source}")
            collected.extend(_load_local(source_path))
        else:
            collected.extend(_load_hf(source, config.trust_remote_code))

    logger.info(
        "Loaded %d benchmark text(s) from %d source(s).",
        len(collected),
        len(config.benchmarks),
    )
    return collected


class DatasetLoadError(Exception):
    """Raised when a benchmark dataset cannot be loaded from its source."""


def _load_hf(
    dataset_id: str,
    trust_remote_code: bool = False,
    split_map: Optional[Dict[str, List[str]]] = None
) -> List[str]:
    """Load benchmark texts from a Hugging Face dataset.

    For datasets listed in *split_map* only the named splits are
    fetched; otherwise every available split is used. From each row
    the first non-empty field among ``_TEXT_FIELDS`` is kept.

    Raises ImportError when the ``datasets`` package is absent and
    DatasetLoadError when the download/parse fails.
    """
    try:
        from datasets import load_dataset
    except ImportError as exc:
        raise ImportError(
            "The 'datasets' package is required to fetch HF benchmarks. "
            "Install it with: pip install datasets"
        ) from exc

    if split_map is None:
        # Known benchmarks only need their evaluation splits.
        split_map = {"gsm8k": ["test"], "cais/mmlu": ["test", "validation"]}

    wanted_splits = split_map.get(dataset_id)

    try:
        if wanted_splits:
            splits = [
                load_dataset(dataset_id, split=name, trust_remote_code=trust_remote_code)
                for name in wanted_splits
            ]
        else:
            loaded = load_dataset(dataset_id, trust_remote_code=trust_remote_code)
            # A DatasetDict exposes .values(); a bare Dataset does not.
            splits = loaded.values() if hasattr(loaded, "values") else [loaded]
    except Exception as exc:
        logger.error("Could not load HF dataset '%s': %s", dataset_id, exc)
        raise DatasetLoadError(f"Failed to load HF dataset '{dataset_id}'") from exc

    texts: List[str] = []
    for split in splits:
        for row in split:
            for field_name in _TEXT_FIELDS:
                if field_name in row and row[field_name]:
                    texts.append(str(row[field_name]))
                    break

    return texts


def _load_local(path: Path) -> List[str]:
    """Load benchmark texts from a local JSONL or CSV file.

    For each record the first non-empty field among the recognised
    question-like fields is kept; records without any such field are
    skipped. Unsupported extensions are logged and yield no texts.
    """
    # Same candidate fields as the module-level _TEXT_FIELDS constant,
    # inlined here so the function is self-contained.
    candidate_fields = ("question", "prompt", "problem", "input", "text")
    texts: List[str] = []

    def keep_first_field(record) -> None:
        # Append the first truthy candidate field, if any.
        for name in candidate_fields:
            if name in record and record[name]:
                texts.append(str(record[name]))
                return

    if path.suffix == ".jsonl":
        with path.open("r", encoding="utf-8") as fh:
            for raw_line in fh:
                stripped = raw_line.strip()
                if stripped:
                    keep_first_field(json.loads(stripped))
    elif path.suffix == ".csv":
        with path.open("r", encoding="utf-8", newline="") as fh:
            for record in csv.DictReader(fh):
                keep_first_field(record)
    else:
        logger.warning("Unsupported local file format: %s", path)

    return texts


# N-gram utilities

def get_ngrams(text: str, n: int = 13) -> List[str]:
    """
    Generate whitespace-tokenised n-grams from text.

    The text is normalised first (lowercased, punctuation stripped,
    whitespace collapsed). Returns an empty list when the text has
    fewer than *n* tokens.

    Raises ValueError if *n* is not a positive integer.
    """
    if n <= 0:
        raise ValueError("n must be a positive integer")

    tokens = _normalise(text).split()
    window_count = len(tokens) - n + 1
    if window_count <= 0:
        return []

    return [" ".join(tokens[start : start + n]) for start in range(window_count)]


def _normalise(text: str) -> str:
    """Lowercase *text*, strip punctuation, and collapse whitespace."""
    no_punct = re.sub(r"[^\w\s]", "", text.lower())
    return re.sub(r"\s+", " ", no_punct).strip()

# Bloom filter construction

def build_bloom_filter(
    benchmark_texts: List[str],
    config: BenchmarkConfig,
) -> Bloom:
    """
    Build (or load) a Bloom filter populated with benchmark n-grams.

    A metadata sidecar (``<filter_path>.meta``) records the parameters
    and an input digest used to build the cached filter. The cached
    filter at ``config.filter_path`` is reused only when that metadata
    matches exactly; otherwise a new filter is built and persisted.

    Raises ValueError when the inputs produce no n-grams at all.
    """
    filter_path = Path(config.filter_path)
    meta_path = filter_path.with_suffix(filter_path.suffix + ".meta")

    # Length-prefixed digest over all texts: sensitive to both content
    # and text boundaries, so any corpus change invalidates the cache.
    digest = hashlib.sha256()
    for text in benchmark_texts:
        payload = text.encode("utf-8")
        digest.update(len(payload).to_bytes(8, "little"))
        digest.update(payload)

    expected_meta = {
        "benchmarks": sorted(config.benchmarks),
        "n": config.n,
        "bloom_capacity": config.bloom_capacity,
        "bloom_error_rate": config.bloom_error_rate,
        "inputs_hash": digest.hexdigest(),
        "version": 1,
    }

    # --- load from cache if available ---
    if filter_path.is_file() and meta_path.is_file():
        try:
            with meta_path.open("r", encoding="utf-8") as f:
                stored_meta = json.load(f)
            if stored_meta == expected_meta:
                logger.info("Loading existing Bloom filter from %s", filter_path)
                return Bloom.load_bytes(filter_path.read_bytes(), _bloom_hash)
            logger.info("Bloom filter metadata mismatch; ignoring stale cache.")
        except Exception as exc:
            # Corrupt/unreadable sidecar: fall through and rebuild.
            logger.warning("Failed to read Bloom filter metadata: %s", exc)

    logger.info(
        "Building Bloom filter (capacity=%s, error_rate=%s, n=%s) …",
        config.bloom_capacity,
        config.bloom_error_rate,
        config.n,
    )

    bloom = Bloom(config.bloom_capacity, config.bloom_error_rate, hash_func=_bloom_hash)

    inserted = 0
    for text in benchmark_texts:
        for gram in get_ngrams(text, n=config.n):
            bloom.add(gram)
            inserted += 1

    if inserted == 0:
        raise ValueError(
            "No benchmark n-grams were generated; check the selected sources and n."
        )

    logger.info("Inserted %d n-grams into the Bloom filter.", inserted)

    # Persist the filter and its metadata for the next run.
    filter_path.parent.mkdir(parents=True, exist_ok=True)
    filter_path.write_bytes(bloom.save_bytes())
    with meta_path.open("w", encoding="utf-8") as f:
        json.dump(expected_meta, f, indent=2)

    logger.info("Bloom filter and metadata saved to %s", filter_path)

    return bloom


# Contamination checking

def check_contamination(
    text: str,
    bloom_filter: Bloom,
    benchmark_texts: Optional[List[str]] = None,
    n: int = 13,
    precomputed_normalised_benchmarks: Optional[Set[str]] = None,
) -> bool:
    """
    Determine whether text is contaminated by benchmark data.

    Two stages: (1) a fast probabilistic Bloom-filter membership test
    on the text's n-grams; (2) on a Bloom hit, an exact substring
    search against the normalised benchmark texts to eliminate false
    positives. At least one of *benchmark_texts* or
    *precomputed_normalised_benchmarks* must be supplied.

    Raises ValueError when neither benchmark source is provided.
    """
    if benchmark_texts is None and precomputed_normalised_benchmarks is None:
        raise ValueError("Must provide either benchmark_texts or precomputed_normalised_benchmarks")

    candidate_grams = get_ngrams(text, n=n)
    if not candidate_grams:
        return False

    normalised_refs = precomputed_normalised_benchmarks
    for gram in candidate_grams:
        if gram not in bloom_filter:
            continue
        # Build the normalised reference corpus lazily, once, so that
        # texts with no Bloom hits never pay for it.
        if normalised_refs is None:
            normalised_refs = {_normalise(bt) for bt in (benchmark_texts or [])}
        # Space padding enforces whole-token matches at both ends.
        padded = f" {gram} "
        if any(padded in f" {ref} " for ref in normalised_refs):
            return True

    return False
Loading
Loading