Skip to content

Commit 3076f50

Browse files
authored
fix: non llm based metrics (#1268)
1) rename metrics 2) delay import of optional dependencies
1 parent c615a9f commit 3076f50

File tree

2 files changed

+33
-16
lines changed

2 files changed

+33
-16
lines changed

src/ragas/metrics/bleu_score.py renamed to src/ragas/metrics/_bleu_score.py

Lines changed: 14 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -2,8 +2,6 @@
22
from dataclasses import dataclass, field
33

44
from langchain_core.callbacks import Callbacks
5-
from nltk.tokenize import word_tokenize
6-
from nltk.translate.bleu_score import corpus_bleu
75

86
from ragas.dataset_schema import SingleTurnSample
97
from ragas.metrics._faithfulness import HasSegmentMethod
@@ -21,7 +19,16 @@ class BleuScore(SingleTurnMetric):
2119
sentence_segmenter: t.Optional[HasSegmentMethod] = None
2220

2321
def __post_init__(self):
22+
try:
23+
from nltk.tokenize import word_tokenize
24+
from nltk.translate.bleu_score import corpus_bleu
25+
except ImportError:
26+
raise ImportError(
27+
"nltk is required for bleu score. Please install it using `pip install nltk`"
28+
)
2429
self.segmenter = get_segmenter()
30+
self.word_tokenizer = word_tokenize
31+
self.corpus_bleu = corpus_bleu
2532

2633
def init(self, run_config: RunConfig):
2734
pass
@@ -32,9 +39,11 @@ async def _single_turn_ascore(
3239
reference_sentences = self.segmenter.segment(sample.reference)
3340
response_sentences = self.segmenter.segment(sample.response)
3441

35-
reference = [[word_tokenize(reference)] for reference in reference_sentences]
36-
response = [word_tokenize(response) for response in response_sentences]
37-
score = corpus_bleu(reference, response, weights=self.weights)
42+
reference = [
43+
[self.word_tokenizer(reference)] for reference in reference_sentences
44+
]
45+
response = [self.word_tokenizer(response) for response in response_sentences]
46+
score = self.corpus_bleu(reference, response, weights=self.weights)
3847
assert isinstance(score, float), "Expecting a float"
3948
return score
4049

src/ragas/metrics/_string.py

Lines changed: 19 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,6 @@
33
from enum import Enum
44

55
from langchain_core.callbacks import Callbacks
6-
from rapidfuzz import distance
76

87
from ragas.dataset_schema import SingleTurnSample
98
from ragas.metrics.base import MetricType, SingleTurnMetric
@@ -16,13 +15,6 @@ class DistanceMeasure(Enum):
1615
JARO = "jaro"
1716

1817

19-
DISTANCE_MEASURE_MAP = {
20-
DistanceMeasure.LEVENSHTEIN: distance.Levenshtein,
21-
DistanceMeasure.HAMMING: distance.Hamming,
22-
DistanceMeasure.JARO: distance.Jaro,
23-
}
24-
25-
2618
@dataclass
2719
class ExactMatch(SingleTurnMetric):
2820
name: str = "exact_match" # type: ignore
@@ -42,6 +34,7 @@ async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float:
4234
return await self._single_turn_ascore(SingleTurnSample(**row), callbacks)
4335

4436

37+
@dataclass
4538
class StringPresent(SingleTurnMetric):
4639
name: str = "string_present" # type: ignore
4740
_required_columns: t.Dict[MetricType, t.Set[str]] = field(
@@ -64,13 +57,28 @@ async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float:
6457
return await self._single_turn_ascore(SingleTurnSample(**row), callbacks)
6558

6659

67-
class StringDistance(SingleTurnMetric):
68-
name: str = "string_distance" # type: ignore
60+
@dataclass
61+
class NonLLMStringSimilarity(SingleTurnMetric):
62+
name: str = "non_llm_string_similarity" # type: ignore
6963
_required_columns: t.Dict[MetricType, t.Set[str]] = field(
7064
default_factory=lambda: {MetricType.SINGLE_TURN: {"reference", "response"}}
7165
)
7266
distance_measure: DistanceMeasure = DistanceMeasure.LEVENSHTEIN
7367

68+
def __post_init__(self):
69+
try:
70+
from rapidfuzz import distance
71+
except ImportError:
72+
raise ImportError(
73+
"rapidfuzz is required for string distance. Please install it using `pip install rapidfuzz`"
74+
)
75+
76+
self.distance_measure_map = {
77+
DistanceMeasure.LEVENSHTEIN: distance.Levenshtein,
78+
DistanceMeasure.HAMMING: distance.Hamming,
79+
DistanceMeasure.JARO: distance.Jaro,
80+
}
81+
7482
def init(self, run_config: RunConfig):
7583
pass
7684

@@ -81,7 +89,7 @@ async def _single_turn_ascore(
8189
response = sample.response
8290
assert isinstance(reference, str), "Expecting a string"
8391
assert isinstance(response, str), "Expecting a string"
84-
return 1 - DISTANCE_MEASURE_MAP[self.distance_measure].normalized_distance(
92+
return 1 - self.distance_measure_map[self.distance_measure].normalized_distance(
8593
reference, response
8694
)
8795

0 commit comments

Comments
 (0)