"""Abstract evaluator interface and structural model/tokenizer types."""

from abc import ABC, abstractmethod
# NOTE: the project declares requires-python >= 3.9 (see pyproject.toml), so
# typing.Protocol (added in 3.8) is always available; the previous
# typing_extensions fallback for Python < 3.8 was dead code (and
# typing_extensions is not a declared dependency, so it could never import).
from typing import List, Protocol, runtime_checkable


@runtime_checkable
class Model(Protocol):
    """Structural type for a language model callable.

    A model maps a sequence of token IDs to a 2-D sequence of logits with
    shape ``(len(input_ids), vocab_size)``.
    """

    def __call__(self, input_ids: List[int]) -> List[List[float]]: ...


@runtime_checkable
class Tokenizer(Protocol):
    """Structural type for a tokenizer: text in, token IDs out."""

    def encode(self, text: str) -> List[int]: ...


class BaseEvaluator(ABC):
    """Abstract base class for all dataset evaluators."""

    @abstractmethod
    def evaluate(self, model: Model, tokenizer: Tokenizer) -> dict:
        """
        Evaluate a language model using the given tokenizer.

        Parameters
        ----------
        model : callable
            Callable accepting a sequence of token IDs and returning a
            2-D sequence of logits with shape ``(len(input_ids), vocab_size)``.
        tokenizer : object
            Object with an ``encode(text: str) -> list[int]`` method.

        Returns
        -------
        dict
            Benchmark-specific evaluation results.
        """
+ """ diff --git a/openverifiablellm/eval/factual/__init__.py b/openverifiablellm/eval/factual/__init__.py new file mode 100644 index 0000000..4ebbffe --- /dev/null +++ b/openverifiablellm/eval/factual/__init__.py @@ -0,0 +1,5 @@ +from .factual_consistency import WikipediaFactualEvaluator + +__all__ = [ + "WikipediaFactualEvaluator", +] diff --git a/openverifiablellm/eval/factual/factual_consistency.py b/openverifiablellm/eval/factual/factual_consistency.py new file mode 100644 index 0000000..2235547 --- /dev/null +++ b/openverifiablellm/eval/factual/factual_consistency.py @@ -0,0 +1,246 @@ +""" +openverifiablellm/eval/factual/factual_consistency.py + +Wikipedia-based factual consistency evaluator. +""" + +import math +import random +import re +from pathlib import Path +from typing import List, Optional, Union + +from ..base import BaseEvaluator +from ..perplexity import PerplexityEvaluator + +_ENTITY_RE = re.compile(r"\b([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)\b") + + +class WikipediaFactualEvaluator(BaseEvaluator): + """ + Evaluates factual consistency of a language model using Wikipedia passages. + + For each sentence extracted from a processed Wikipedia text file + (``wiki_clean.txt``), a counterfactual variant is generated by substituting + a named entity found in the sentence with a different named entity drawn + from the same passage. The model's perplexity is then compared on the + original (factual) vs the substituted (counterfactual) sentence. A + well-trained model should assign lower perplexity to factual sentences. + + The ``factual_score`` is the mean per-pair difference + ``(counterfactual_ppl - factual_ppl)``: positive values indicate the model + correctly prefers factual sentences, negative values indicate the model + prefers counterfactual sentences. + + Named entities are identified with the simple capitalized-word-sequence + regex ``r"\\b([A-Z][a-z]+(?:\\s+[A-Z][a-z]+)*)\\b"``. 
Evaluation is + fully deterministic: ``random.seed(42)`` is applied inside + :meth:`evaluate` before any entity selection. + + Parameters + ---------- + wiki_text_path : str or Path + Path to the processed ``wiki_clean.txt`` file produced by + :func:`openverifiablellm.utils.extract_text_from_xml`. + n_samples : int or None + Maximum number of sentence pairs to evaluate. ``None`` evaluates + all available pairs. Default ``None``. + """ + + def __init__( + self, + wiki_text_path: Union[str, Path], + n_samples: Optional[int] = None, + ): + self.wiki_text_path = Path(wiki_text_path) + self.n_samples = n_samples + + # ------------------------------------------------------------------ + # Static helpers + # ------------------------------------------------------------------ + + @staticmethod + def _substitute_entity(sentence: str, candidate_entities: List[str]) -> Optional[str]: + """ + Replace the first named entity in *sentence* with a random different + entity drawn from *candidate_entities*. + + Named entities are matched by + ``r"\\b([A-Z][a-z]+(?:\\s+[A-Z][a-z]+)*)\\b"``. + + Parameters + ---------- + sentence : str + Input sentence. + candidate_entities : list[str] + Pool of named entities to draw substitutes from (typically all + entities extracted from the enclosing passage). + + Returns + ------- + str or None + The modified sentence with the first entity replaced, or ``None`` + if no named entity was found in *sentence* or no differing + substitute is available in *candidate_entities*. 
+ """ + matches = _ENTITY_RE.findall(sentence) + if not matches: + return None + + found_entity = matches[0] + alternatives = [e for e in candidate_entities if e != found_entity] + if not alternatives: + return None + + substitute = random.choice(alternatives) + pattern = r"\b" + re.escape(found_entity) + r"\b" + return re.sub(pattern, substitute, sentence, count=1) + + @staticmethod + def _extract_passages( + wiki_text_path: Union[str, Path], + n_samples: Optional[int], + ) -> List[dict]: + """ + Build factual/counterfactual sentence pairs from *wiki_text_path*. + + The file is read line by line; consecutive non-empty lines are grouped + into passages (blank lines act as separators). For each passage the + lines are joined into a single string, split on ``". "``, and each + resulting sentence is tested for entity substitution via + :meth:`_substitute_entity`. A pair is emitted for every sentence that + yields a valid counterfactual. Collection halts early once *n_samples* + pairs have been gathered (if *n_samples* is not ``None``). + + Parameters + ---------- + wiki_text_path : str or Path + Path to the processed ``wiki_clean.txt`` file. + n_samples : int or None + Maximum number of pairs to return. + + Returns + ------- + list[dict] + Each element is ``{"original": str, "counterfactual": str}``. + """ + pairs: List[dict] = [] + current_lines: List[str] = [] + + def _process_passage(lines: List[str]) -> None: + passage_text = " ".join(lines) + all_entities = _ENTITY_RE.findall(passage_text) + if not all_entities: + return + sentences = passage_text.split(". 
") + for sentence in sentences: + if n_samples is not None and len(pairs) >= n_samples: + return + sentence = sentence.strip() + if not sentence: + continue + counterfactual = WikipediaFactualEvaluator._substitute_entity( + sentence, all_entities + ) + if counterfactual is not None and counterfactual != sentence: + pairs.append({"original": sentence, "counterfactual": counterfactual}) + + with open(wiki_text_path, encoding="utf-8") as fh: + for raw_line in fh: + line = raw_line.rstrip("\n") + if line.strip(): + current_lines.append(line.strip()) + else: + if current_lines: + _process_passage(current_lines) + current_lines = [] + if n_samples is not None and len(pairs) >= n_samples: + return pairs + + # Handle final passage if file has no trailing blank line + if current_lines: + _process_passage(current_lines) + + return pairs + + # ------------------------------------------------------------------ + # BaseEvaluator interface + # ------------------------------------------------------------------ + + def evaluate(self, model, tokenizer) -> dict: + """ + Compute factual consistency scores for *model*. + + Extracts sentence pairs from the configured Wikipedia text file, then + computes perplexity for each original and counterfactual sentence using + the same teacher-forced method as + :class:`~openverifiablellm.eval.perplexity.PerplexityEvaluator`. + + ``random.seed(42)`` is applied before any entity selection to ensure + fully reproducible results. + + Parameters + ---------- + model : callable + ``model(input_ids) -> 2-D sequence`` of shape + ``(len(input_ids), vocab_size)``, as described in + :meth:`~openverifiablellm.eval.perplexity.PerplexityEvaluator.compute_sentence_perplexity`. + tokenizer : object + Object with ``encode(text: str) -> list[int]``. + + Returns + ------- + dict + A dictionary with the following keys: + + * **factual_perplexity** (*float*) — mean perplexity on original + sentences. 
+ * **counterfactual_perplexity** (*float*) — mean perplexity on + counterfactual sentences. + * **factual_score** (*float*) — mean per-pair difference + ``(counterfactual_ppl - factual_ppl)``; positive means the model + correctly prefers factual sentences (good), negative means the + model prefers counterfactual sentences (bad). + """ + random.seed(42) + pairs = self._extract_passages(self.wiki_text_path, self.n_samples) + + if not pairs: + return { + "factual_perplexity": float("inf"), + "counterfactual_perplexity": float("inf"), + "factual_score": float("inf"), + } + + factual_ppls: List[float] = [] + counterfactual_ppls: List[float] = [] + score_diffs: List[float] = [] + + for pair in pairs: + factual_tokens = tokenizer.encode(pair["original"]) + cf_tokens = tokenizer.encode(pair["counterfactual"]) + + factual_ppl = PerplexityEvaluator.compute_sentence_perplexity( + model, factual_tokens + ) + cf_ppl = PerplexityEvaluator.compute_sentence_perplexity(model, cf_tokens) + + if not math.isfinite(factual_ppl) or not math.isfinite(cf_ppl): + continue + + factual_ppls.append(factual_ppl) + counterfactual_ppls.append(cf_ppl) + score_diffs.append(cf_ppl - factual_ppl) + + n = len(factual_ppls) + if n == 0: + return { + "factual_perplexity": float("nan"), + "counterfactual_perplexity": float("nan"), + "factual_score": float("nan"), + } + return { + "factual_perplexity": sum(factual_ppls) / n, + "counterfactual_perplexity": sum(counterfactual_ppls) / n, + "factual_score": sum(score_diffs) / n, + } diff --git a/openverifiablellm/eval/perplexity.py b/openverifiablellm/eval/perplexity.py new file mode 100644 index 0000000..fa1f0be --- /dev/null +++ b/openverifiablellm/eval/perplexity.py @@ -0,0 +1,231 @@ +""" +openverifiablellm/eval/perplexity.py + +Perplexity evaluator for language models. 
+""" + +import math +from typing import List, Optional + +from .base import BaseEvaluator + + +class PerplexityEvaluator(BaseEvaluator): + """ + Evaluates language-model perplexity on a HuggingFace benchmark dataset. + + Perplexity is computed with a teacher-forced sliding-window approach: + for each token position *i* the model receives tokens ``[0 .. i-1]`` + and the negative log-probability of token ``[i]`` is accumulated. + The final perplexity is ``exp(mean_NLL)``. + + Parameters + ---------- + benchmark : str + HuggingFace dataset identifier. Default ``"wikitext"``. + n_samples : int or None + Maximum number of non-empty samples to evaluate. ``None`` means + evaluate the whole dataset. Default ``50``. + stride : int + Window stride used when the sequence exceeds the model's context + window. Default ``512``. + """ + + def __init__( + self, + benchmark: str = "wikitext", + n_samples: Optional[int] = 50, + stride: int = 512, + split: Optional[str] = None, + ): + self.benchmark = benchmark + self.n_samples = n_samples + self.stride = stride + self.split = split + + # ------------------------------------------------------------------ + # Mock helpers + # ------------------------------------------------------------------ + + @staticmethod + def uniform_model(vocab_size: int = 1000): + """ + Return a mock model that produces uniform (all-zero) logits. + + Useful for unit testing: because all logits are equal, the + log-softmax is ``-log(vocab_size)`` at every position, giving a + predictable perplexity of exactly ``vocab_size``. + + Parameters + ---------- + vocab_size : int + Vocabulary size of the mock model. Default ``1000``. + + Returns + ------- + callable + ``model(input_ids) -> list[list[float]]`` of shape + ``(len(input_ids), vocab_size)``. 
+ """ + + def _model(input_ids): + return [[0.0] * vocab_size for _ in input_ids] + + return _model + + # ------------------------------------------------------------------ + # Core computation + # ------------------------------------------------------------------ + + @staticmethod + def compute_sentence_perplexity(model, token_ids: List[int]) -> float: + """ + Compute the perplexity of *token_ids* under *model*. + + Parameters + ---------- + model : callable + ``model(input_ids) -> 2-D sequence`` of shape + ``(len(input_ids), vocab_size)``. + token_ids : list[int] + Tokenised sentence. + + Returns + ------- + float + Perplexity (≥ 1). Returns ``float("inf")`` for sequences + shorter than 2 tokens. + """ + if len(token_ids) < 2: + return float("inf") + + inputs = token_ids[:-1] + targets = token_ids[1:] + + logits_batch = model(inputs) # shape: (n-1, vocab_size) + + if len(logits_batch) != len(targets): + raise ValueError( + f"Model returned {len(logits_batch)} logit vectors but expected " + f"{len(targets)} (one per target token)." + ) + + nll_sum = 0.0 + for logits, target in zip(logits_batch, targets): + # numerically-stable log-softmax + max_l = max(logits) + exp_shifted = [math.exp(v - max_l) for v in logits] + log_sum = math.log(sum(exp_shifted)) + log_prob_target = (logits[target] - max_l) - log_sum + nll_sum -= log_prob_target + + return math.exp(nll_sum / len(targets)) + + @staticmethod + def compute_sequence_perplexity(model, token_ids: List[int], stride: int = 512) -> float: + """ + Compute perplexity over a (possibly long) sequence using non-overlapping + stride-sized windows. + + The sequence is partitioned into windows of *stride* tokens. Each + window contributes its token predictions to a pooled NLL. The final + perplexity is ``exp(total_NLL / total_scored_tokens)``. + + For sequences shorter than *stride* + 1 tokens the result is + identical to :meth:`compute_sentence_perplexity`. 
+ + Parameters + ---------- + model : callable + ``model(input_ids) -> 2-D sequence`` of shape + ``(len(input_ids), vocab_size)``. + token_ids : list[int] + Tokenised sequence. + stride : int + Number of tokens scored per window. Default ``512``. + + Returns + ------- + float + Perplexity (≥ 1). Returns ``float("inf")`` for sequences + shorter than 2 tokens. + """ + if len(token_ids) < 2: + return float("inf") + + nll_sum = 0.0 + n_scored = 0 + n = len(token_ids) + + for start in range(0, n - 1, stride): + end = min(start + stride + 1, n) + window = token_ids[start:end] + if len(window) < 2: + break + inputs = window[:-1] + targets = window[1:] + logits_batch = model(inputs) + if len(logits_batch) != len(targets): + raise ValueError( + f"Model returned {len(logits_batch)} logit vectors but expected " + f"{len(targets)} (one per target token)." + ) + for logits, target in zip(logits_batch, targets): + max_l = max(logits) + exp_shifted = [math.exp(v - max_l) for v in logits] + log_sum = math.log(sum(exp_shifted)) + nll_sum -= (logits[target] - max_l) - log_sum + n_scored += 1 + + return math.exp(nll_sum / n_scored) if n_scored > 0 else float("inf") + + # ------------------------------------------------------------------ + # BaseEvaluator interface + # ------------------------------------------------------------------ + + def evaluate(self, model, tokenizer) -> dict: + """ + Compute mean perplexity on *self.benchmark*. + + Parameters + ---------- + model : callable + Callable as described in :meth:`compute_sentence_perplexity`. + tokenizer : object + Object with ``encode(text: str) -> list[int]``. + + Returns + ------- + dict + ``{"perplexity": float}`` — mean perplexity across evaluated + sentences. 
+ """ + import datasets as hf_datasets # deferred; runtime dep + + if self.split is not None: + ds = hf_datasets.load_dataset(self.benchmark, split=self.split, streaming=True) + else: + _splits_to_try = ("test", "validation", "train") + for _s in _splits_to_try: + try: + ds = hf_datasets.load_dataset(self.benchmark, split=_s, streaming=True) + break + except Exception: + continue + else: + raise ValueError( + f"Dataset {self.benchmark!r} has none of the expected splits: " + f"{_splits_to_try}. Pass split= explicitly." + ) + scores = [] + for row in ds: + text = row.get("text", "") + if not text.strip(): + continue + if self.n_samples is not None and len(scores) >= self.n_samples: + break + token_ids = tokenizer.encode(text) + scores.append(self.compute_sequence_perplexity(model, token_ids, self.stride)) + + mean_ppl = float(sum(scores) / len(scores)) if scores else float("inf") + return {"perplexity": mean_ppl} diff --git a/pyproject.toml b/pyproject.toml index 96523a0..55ba437 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,14 +12,25 @@ authors = [ requires-python = ">=3.9" dependencies = [ + "datasets", "defusedxml", "sentencepiece", "tokenizers==0.15.2" ] +# Intentionally duplicated from [dependency-groups] below. +# pip uses this section; uv/PEP 735 uses [dependency-groups]. Keep both in sync. +[project.optional-dependencies] +dev = [ + "pytest>=7.0", + "ruff>=0.15.4", +] + [tool.setuptools.packages.find] include = ["openverifiablellm*"] +# Intentionally duplicated from [project.optional-dependencies] above. +# uv/PEP 735 uses this section; pip uses [project.optional-dependencies]. Keep both in sync. [dependency-groups] dev = [ "pytest>=7.0", diff --git a/tests/test_factual_eval.py b/tests/test_factual_eval.py new file mode 100644 index 0000000..94d45c1 --- /dev/null +++ b/tests/test_factual_eval.py @@ -0,0 +1,167 @@ +""" +tests/test_factual_eval.py + +Tests for WikipediaFactualEvaluator. 
"""
tests/test_factual_eval.py

Tests for WikipediaFactualEvaluator.

Run with:
    pytest tests/test_factual_eval.py -v
"""

import math
import random

import pytest

from openverifiablellm.eval.factual import WikipediaFactualEvaluator
from openverifiablellm.eval.perplexity import PerplexityEvaluator

# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------

_WIKI_TEXT = (
    "Albert Einstein was born in Germany.\n"
    "He developed the Theory of Relativity.\n"
    "Marie Curie was born in Poland.\n"
    "\n"
    "Isaac Newton discovered gravity in England.\n"
    "Newton worked at Cambridge University.\n"
)

_EXPECTED_KEYS = {
    "factual_perplexity",
    "counterfactual_perplexity",
    "factual_score",
}


class _MockTokenizer:
    """Maps spaces to underscores, then each character to its ASCII code mod 100."""

    def encode(self, text: str) -> list:
        underscored = text.replace(" ", "_")
        return [ord(ch) % 100 for ch in underscored]


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------


@pytest.fixture()
def wiki_file(tmp_path):
    path = tmp_path / "wiki_clean.txt"
    path.write_text(_WIKI_TEXT, encoding="utf-8")
    return path


@pytest.fixture()
def mock_model():
    return PerplexityEvaluator.uniform_model(vocab_size=100)


@pytest.fixture()
def mock_tokenizer():
    return _MockTokenizer()


@pytest.fixture()
def evaluator(wiki_file):
    return WikipediaFactualEvaluator(wiki_text_path=wiki_file)


# ---------------------------------------------------------------------------
# _extract_passages
# ---------------------------------------------------------------------------


def test_extract_passages_returns_pairs(wiki_file):
    random.seed(42)
    extracted = WikipediaFactualEvaluator._extract_passages(wiki_file, n_samples=None)
    assert extracted
    assert all("original" in p for p in extracted)
    assert all("counterfactual" in p for p in extracted)


def test_counterfactual_differs_from_original(wiki_file):
    random.seed(42)
    extracted = WikipediaFactualEvaluator._extract_passages(wiki_file, n_samples=None)
    assert extracted
    assert all(p["original"] != p["counterfactual"] for p in extracted)


def test_n_samples_limits_pairs(wiki_file):
    random.seed(42)
    limited = WikipediaFactualEvaluator._extract_passages(wiki_file, n_samples=2)
    assert len(limited) <= 2


# ---------------------------------------------------------------------------
# _substitute_entity
# ---------------------------------------------------------------------------


def test_substitute_entity_replaces_entity():
    random.seed(42)
    original = "Albert Einstein was born in Germany"
    pool = ["Albert Einstein", "Germany", "Marie Curie", "Poland"]
    swapped = WikipediaFactualEvaluator._substitute_entity(original, pool)
    assert swapped is not None
    assert swapped != original
    assert "Albert Einstein" not in swapped


def test_substitute_entity_returns_none_when_no_entity():
    # All-lowercase sentence has no capitalized sequences
    outcome = WikipediaFactualEvaluator._substitute_entity(
        "the cat sat on the mat", ["Germany", "Poland"]
    )
    assert outcome is None


# ---------------------------------------------------------------------------
# evaluate()
# ---------------------------------------------------------------------------


def test_evaluate_returns_correct_keys(evaluator, mock_model, mock_tokenizer):
    scores = evaluator.evaluate(mock_model, mock_tokenizer)
    assert set(scores.keys()) == _EXPECTED_KEYS


def test_evaluate_scores_are_finite(evaluator, mock_model, mock_tokenizer):
    scores = evaluator.evaluate(mock_model, mock_tokenizer)
    assert all(
        math.isfinite(scores[key])
        for key in ("factual_perplexity", "counterfactual_perplexity", "factual_score")
    )


def test_factual_score_is_difference(evaluator, mock_model, mock_tokenizer):
    """factual_score must equal mean(cf_ppl - factual_ppl) per pair,
    which by linearity of expectation equals counterfactual_perplexity
    minus factual_perplexity."""
    scores = evaluator.evaluate(mock_model, mock_tokenizer)
    difference = scores["counterfactual_perplexity"] - scores["factual_perplexity"]
    assert abs(scores["factual_score"] - difference) < 1e-9


def test_n_samples_limits_pairs_via_evaluate(wiki_file, mock_model, mock_tokenizer):
    """n_samples=2 must cause evaluate() to process at most 2 pairs."""
    capped = WikipediaFactualEvaluator(wiki_text_path=wiki_file, n_samples=2)
    # Verify by checking _extract_passages directly with n_samples=2
    random.seed(42)
    extracted = WikipediaFactualEvaluator._extract_passages(wiki_file, n_samples=2)
    assert len(extracted) <= 2
    # evaluate() must still complete without error
    scores = capped.evaluate(mock_model, mock_tokenizer)
    assert set(scores.keys()) == _EXPECTED_KEYS


def test_determinism(evaluator, mock_model, mock_tokenizer):
    """Calling evaluate() twice on the same input must return identical results."""
    first = evaluator.evaluate(mock_model, mock_tokenizer)
    second = evaluator.evaluate(mock_model, mock_tokenizer)
    assert first == second