216 changes: 204 additions & 12 deletions dspy/evaluate/metrics.py
@@ -9,24 +9,102 @@


def EM(prediction, answers_list): # noqa: N802
assert isinstance(answers_list, list)
"""Compute the Exact Match (EM) metric between a prediction and reference answers.

Returns True if any reference exactly matches the prediction after normalization;
otherwise False. Normalization applies Unicode NFD, lowercasing, punctuation
removal, English article removal ("a", "an", "the"), and whitespace collapse.

Args:
prediction (str): Predicted answer string.
answers_list (list[str]): List of reference answers.

Returns:
bool: Whether any reference exactly equals the prediction after normalization.

Example:
```python
EM("The Eiffel Tower", ["Eiffel Tower", "Louvre"]) # True

EM("paris", ["Paris"]) # True
EM("paris", ["Paris, France"]) # False
```
"""
if not isinstance(answers_list, list):
raise ValueError(f"`answers_list` must be a list, got {type(answers_list)}")

return max(em_score(prediction, ans) for ans in answers_list)
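# Illustrative note (not part of the original diff): since True > False in
# Python, `max` over the per-reference booleans behaves like any(), e.g.
#   max(em_score("paris", a) for a in ["Lyon", "Paris"])  # True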


def F1(prediction, answers_list): # noqa: N802
assert isinstance(answers_list, list)
"""Compute the maximum token-level F1 score against reference answers.

Strings are normalized (same as in `EM`) and whitespace-tokenized. The function
returns the maximum F1 over all provided references.

Args:
prediction (str): Predicted answer string.
answers_list (list[str]): List of reference answers.

Returns:
float: Highest F1 score in [0.0, 1.0].

Example:
```python
round(F1("Eiffel Tower is in Paris", ["Paris"]), 2) # 0.33
```
"""
if not isinstance(answers_list, list):
raise ValueError(f"`answers_list` must be a list, got {type(answers_list)}")

return max(f1_score(prediction, ans) for ans in answers_list)
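# Worked sketch of the docstring example above: the normalized prediction
# "eiffel tower is in paris" has 5 tokens, 1 of which overlaps the single
# reference token, so precision = 1/5, recall = 1/1, and
# F1 = 2 * (0.2 * 1.0) / (0.2 + 1.0) ≈ 0.33.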


def HotPotF1(prediction, answers_list): # noqa: N802
assert isinstance(answers_list, list)
"""Compute the maximum HotPotQA-style F1 score against reference answers.

Like `F1`, but if either normalized side is one of {"yes", "no", "noanswer"}
and they differ, the score is 0. Otherwise, standard token-level F1 is used.

Args:
prediction (str): Predicted answer.
answers_list (list[str]): List of reference answers.

Returns:
float: Highest HotPotQA-style F1 in [0.0, 1.0].

Example:
```python
HotPotF1("yes", ["no"]) # 0.0
```
"""
if not isinstance(answers_list, list):
raise ValueError(f"`answers_list` must be a list, got {type(answers_list)}")

return max(hotpot_f1_score(prediction, ans) for ans in answers_list)
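# Illustrative note: the max over references means one matching reference
# suffices, e.g. HotPotF1("yes", ["no", "yes"]) is 1.0 even though the
# "yes" vs "no" pair alone scores 0.0.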


def normalize_text(s):
"""Normalize text for string and token comparisons.

Steps:
1) Unicode NFD normalization
2) lowercasing
3) punctuation removal
4) English article removal ("a", "an", "the")
5) whitespace collapse

Args:
s (str): Input string.

Returns:
str: Normalized string.

Example:
```python
normalize_text("The, Eiffel Tower!") # "eiffel tower"
```
"""
s = unicodedata.normalize("NFD", s)

def remove_articles(text):
@@ -46,15 +124,42 @@ def lower(text):


def em_score(prediction, ground_truth):
return normalize_text(prediction) == normalize_text(ground_truth)
"""Compute boolean exact match after normalization.

Args:
prediction (str): Predicted answer.
ground_truth (str): Reference answer.

# See: https://github.com/hotpotqa/hotpot/blob/master/hotpot_evaluate_v1.py
# See: https://rajpurkar.github.io/SQuAD-explorer/ under Evaluation Script
# See: QReCC's
Returns:
bool: True if normalized strings are identical; otherwise False.

Example:
```python
em_score("Paris", "paris") # True
```
"""
return normalize_text(prediction) == normalize_text(ground_truth)


def f1_score(prediction, ground_truth):
"""Compute token-level F1 between prediction and reference (after normalization).

Strings are normalized (see `normalize_text`) and split by whitespace. F1 is
computed from token precision and recall. If there is no token overlap, returns 0.
If both sides are empty, a diagnostic message is printed; score remains 0.

Args:
prediction (str): Predicted answer.
ground_truth (str): Reference answer.

Returns:
float: F1 score in [0.0, 1.0].

Example:
```python
round(f1_score("the Eiffel Tower", "Eiffel Tower"), 2) # 1.0
```
"""
prediction_tokens = normalize_text(prediction).split()
ground_truth_tokens = normalize_text(ground_truth).split()

Expand All @@ -76,6 +181,23 @@ def f1_score(prediction, ground_truth):


def hotpot_f1_score(prediction, ground_truth):
"""Compute HotPotQA-style token F1 with special labels.

If either normalized string is in {"yes", "no", "noanswer"} and they differ,
the score is 0. Otherwise compute standard token F1 after normalization.

Args:
prediction (str): Predicted answer.
ground_truth (str): Reference answer.

Returns:
float: HotPotQA-style F1 score in [0.0, 1.0].

Example:
```python
hotpot_f1_score("no", "yes") # 0.0
```
"""
normalized_prediction = normalize_text(prediction)
normalized_ground_truth = normalize_text(ground_truth)

@@ -97,6 +219,24 @@


def precision_score(prediction, ground_truth):
"""Compute token-level precision of prediction against reference (after normalization).

Precision is (# overlapping tokens) / (# tokens in prediction). If there is no
token overlap, returns 0. If both sides are empty, a diagnostic message is printed;
precision remains 0.

Args:
prediction (str): Predicted answer.
ground_truth (str): Reference answer.

Returns:
float: Precision in [0.0, 1.0].

Example:
```python
precision_score("eiffel tower in paris", "eiffel tower") # 0.5
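# 2 overlapping tokens / 4 prediction tokens = 0.5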
```
"""
prediction_tokens = normalize_text(prediction).split()
ground_truth_tokens = normalize_text(ground_truth).split()

@@ -105,22 +245,23 @@ def precision_score(prediction, ground_truth):

if len(prediction_tokens) == len(ground_truth_tokens) == 0:
# Unlike most tasks, QReCC and SQuAD-2.0 assign 1.0 in this edge case. We don't for uniformity.
print_message("\n#> Precision Metric: Rare edge case of len(prediction_tokens) == len(ground_truth_tokens) == 0.\n")
print_message(
"\n#> Precision Metric: Rare edge case of len(prediction_tokens) == len(ground_truth_tokens) == 0.\n"
)

if num_same == 0:
return 0

precision = 1.0 * num_same / len(prediction_tokens)

return precision
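# Hedged note: the recall counterpart would be num_same / len(ground_truth_tokens);
# f1_score above combines the two as 2 * precision * recall / (precision + recall).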


def _passage_match(passages: list[str], answers: list[str]) -> bool:
"""Returns True if any of the passages contains the answer."""
"""Return True if any passage contains any answer (normalized & DPR-normalized)."""
from dspy.dsp.utils import DPR_normalize, has_answer

def passage_has_answers(passage: str, answers: list[str]) -> bool:
"""Returns True if the passage contains the answer."""
"""Return True if the passage contains any of the answers."""
return has_answer(
tokenized_answers=[DPR_normalize(normalize_text(ans)) for ans in answers],
text=normalize_text(passage),
@@ -130,15 +271,44 @@ def passage_has_answers(passage: str, answers: list[str]) -> bool:


def _answer_match(prediction, answers, frac=1.0):
"""Returns True if the prediction matches any of the answers."""
"""Return True if prediction matches any answer.

When `frac >= 1.0`, require exact match (EM). Otherwise, return whether the
maximum token-level F1 across answers is at least `frac`.
"""
if frac >= 1.0:
return EM(prediction, answers)

return F1(prediction, answers) >= frac
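# Illustrative (hypothetical values): _answer_match("Paris, France", ["Paris"], frac=0.5)
# returns True, since max token F1 = 2 * (1/2) * (1/1) / (1/2 + 1/1) ≈ 0.67 >= 0.5.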


def answer_exact_match(example, pred, trace=None, frac=1.0):
"""Evaluate exact match or F1-thresholded match for an example/prediction pair.

If `example.answer` is a string, compare `pred.answer` against it. If it's a list,
compare against any of the references. When `frac >= 1.0` (default), use EM;
otherwise require that the maximum F1 across references is at least `frac`.

Args:
example: `dspy.Example` object with field `answer` (str or list[str]).
pred: `dspy.Prediction` object with field `answer` (str).
trace: Unused; reserved for compatibility.
frac (float, optional): Threshold in [0.0, 1.0]. `1.0` means EM.

Returns:
bool: True if the match condition holds; otherwise False.

Example:
```python
import dspy

example = dspy.Example(answer=["Eiffel Tower", "Louvre"])
pred = dspy.Prediction(answer="The Eiffel Tower")

answer_exact_match(example, pred, frac=1.0) # equivalent to EM, True
answer_exact_match(example, pred, frac=0.5) # True
```
"""
if isinstance(example.answer, str):
return _answer_match(pred.answer, [example.answer], frac=frac)
elif isinstance(example.answer, list):
@@ -148,6 +318,28 @@ def answer_exact_match(example, pred, trace=None, frac=1.0):


def answer_passage_match(example, pred, trace=None):
"""Return True if any passage in `pred.context` contains the answer(s).

Strings are normalized (and passages also use DPR normalization internally).

Args:
example: `dspy.Example` object with field `answer` (str or list[str]).
pred: `dspy.Prediction` object with field `context` (list[str]) containing passages.
trace: Unused; reserved for compatibility.

Returns:
bool: True if any passage contains any reference answer; otherwise False.

Example:
```python
import dspy

example = dspy.Example(answer="Eiffel Tower")
pred = dspy.Prediction(context=["The Eiffel Tower is in Paris.", "..."])

answer_passage_match(example, pred) # True
```
"""
if isinstance(example.answer, str):
return _passage_match(pred.context, [example.answer])
elif isinstance(example.answer, list):