
Commit 0319c19

Feat/llamaindex: adding llamaindex (#999)
fixes: #557 it's been long 🙂
Parent: e2c57b1

File tree

12 files changed: +718 -200 lines

docs/howtos/integrations/llamaindex.ipynb (+444 -183)

Large diffs are not rendered by default.

src/ragas/embeddings/__init__.py (+2)

@@ -2,12 +2,14 @@
     BaseRagasEmbeddings,
     HuggingfaceEmbeddings,
     LangchainEmbeddingsWrapper,
+    LlamaIndexEmbeddingsWrapper,
     embedding_factory,
 )

 __all__ = [
     "HuggingfaceEmbeddings",
     "BaseRagasEmbeddings",
     "LangchainEmbeddingsWrapper",
+    "LlamaIndexEmbeddingsWrapper",
     "embedding_factory",
 ]

src/ragas/embeddings/base.py (+25)

@@ -13,6 +13,9 @@

 from ragas.run_config import RunConfig, add_async_retry, add_retry

+if t.TYPE_CHECKING:
+    from llama_index.core.base.embeddings.base import BaseEmbedding
+
 DEFAULT_MODEL_NAME = "BAAI/bge-small-en-v1.5"


@@ -153,6 +156,28 @@ def predict(self, texts: List[List[str]]) -> List[List[float]]:
         return predictions.tolist()


+class LlamaIndexEmbeddingsWrapper(BaseRagasEmbeddings):
+    def __init__(
+        self, embeddings: BaseEmbedding, run_config: t.Optional[RunConfig] = None
+    ):
+        self.embeddings = embeddings
+        if run_config is None:
+            run_config = RunConfig()
+        self.set_run_config(run_config)
+
+    def embed_query(self, text: str) -> t.List[float]:
+        return self.embeddings.get_query_embedding(text)
+
+    def embed_documents(self, texts: t.List[str]) -> t.List[t.List[float]]:
+        return self.embeddings.get_text_embedding_batch(texts)
+
+    async def aembed_query(self, text: str) -> t.List[float]:
+        return await self.embeddings.aget_query_embedding(text)
+
+    async def aembed_documents(self, texts: t.List[str]) -> t.List[t.List[float]]:
+        return await self.embeddings.aget_text_embedding_batch(texts)
+
+
 def embedding_factory(
     model: str = "text-embedding-ada-002", run_config: t.Optional[RunConfig] = None
 ) -> BaseRagasEmbeddings:
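The new wrapper exposes LlamaIndex embedding models through the interface ragas metrics already consume (embed_query / embed_documents plus async variants). A minimal usage sketch; the OpenAIEmbedding model is an illustrative assumption, not part of this diff:

from llama_index.embeddings.openai import OpenAIEmbedding

from ragas.embeddings import LlamaIndexEmbeddingsWrapper

# wrap a LlamaIndex embedding model (illustrative choice) for use in ragas
ragas_embeddings = LlamaIndexEmbeddingsWrapper(OpenAIEmbedding())
vector = ragas_embeddings.embed_query("What is ragas?")
print(len(vector))  # embedding dimension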

src/ragas/executor.py (-1)

@@ -17,7 +17,6 @@


 def runner_exception_hook(args: threading.ExceptHookArgs):
-    print(args)
     raise args.exc_type


src/ragas/integrations/llama_index.py (new file, +113)

@@ -0,0 +1,113 @@
+from __future__ import annotations
+
+import logging
+import typing as t
+from copy import deepcopy
+from uuid import uuid4
+
+from datasets import Dataset
+
+from ragas.embeddings import LlamaIndexEmbeddingsWrapper
+from ragas.evaluation import evaluate as ragas_evaluate
+from ragas.exceptions import ExceptionInRunner
+from ragas.executor import Executor
+from ragas.llms import LlamaIndexLLMWrapper
+from ragas.validation import EVALMODE_TO_COLUMNS, validate_evaluation_modes
+
+if t.TYPE_CHECKING:
+    from llama_index.core.base.embeddings.base import (
+        BaseEmbedding as LlamaIndexEmbeddings,
+    )
+    from llama_index.core.base.llms.base import BaseLLM as LlamaindexLLM
+
+    from ragas.evaluation import Result
+    from ragas.metrics.base import Metric
+
+
+logger = logging.getLogger(__name__)
+
+
+def validate_dataset(dataset: dict, metrics: list[Metric]):
+    # relax EVALMODE_TO_COLUMNS for the use case with no contexts and answer;
+    # deep-copy so the module-level mapping itself is not mutated
+    evalmod_to_columns_llamaindex = deepcopy(EVALMODE_TO_COLUMNS)
+    for mode in evalmod_to_columns_llamaindex:
+        if "answer" in evalmod_to_columns_llamaindex[mode]:
+            evalmod_to_columns_llamaindex[mode].remove("answer")
+        if "contexts" in evalmod_to_columns_llamaindex[mode]:
+            evalmod_to_columns_llamaindex[mode].remove("contexts")
+
+    hf_dataset = Dataset.from_dict(dataset)
+    validate_evaluation_modes(hf_dataset, metrics, evalmod_to_columns_llamaindex)
+
+
+def evaluate(
+    query_engine,
+    dataset: dict,
+    metrics: list[Metric],
+    llm: t.Optional[LlamaindexLLM] = None,
+    embeddings: t.Optional[LlamaIndexEmbeddings] = None,
+    raise_exceptions: bool = True,
+    column_map: t.Optional[t.Dict[str, str]] = None,
+) -> Result:
+    column_map = column_map or {}
+
+    # wrap LLMs and embeddings
+    li_llm = None
+    if llm is not None:
+        li_llm = LlamaIndexLLMWrapper(llm)
+    li_embeddings = None
+    if embeddings is not None:
+        li_embeddings = LlamaIndexEmbeddingsWrapper(embeddings)
+
+    # validate and transform dataset
+    if dataset is None:
+        raise ValueError("Provide dataset!")
+
+    exec = Executor(
+        desc="Running Query Engine",
+        keep_progress_bar=True,
+        raise_exceptions=raise_exceptions,
+    )
+
+    # submit each question to the query engine as an async task
+    queries = dataset["question"]
+    for i, q in enumerate(queries):
+        exec.submit(query_engine.aquery, q, name=f"query-{i}")
+
+    answers: t.List[str] = []
+    contexts: t.List[t.List[str]] = []
+    try:
+        results = exec.results()
+        if results == []:
+            raise ExceptionInRunner()
+    except Exception as e:
+        raise e
+    else:
+        for r in results:
+            answers.append(r.response)
+            contexts.append([n.node.text for n in r.source_nodes])
+
+    # create HF dataset
+    hf_dataset = Dataset.from_dict(
+        {
+            "question": queries,
+            "contexts": contexts,
+            "answer": answers,
+        }
+    )
+    if "ground_truth" in dataset:
+        hf_dataset = hf_dataset.add_column(
+            name="ground_truth",
+            column=dataset["ground_truth"],
+            new_fingerprint=str(uuid4()),
+        )
+
+    results = ragas_evaluate(
+        dataset=hf_dataset,
+        metrics=metrics,
+        llm=li_llm,
+        embeddings=li_embeddings,
+        raise_exceptions=raise_exceptions,
+    )
+
+    return results
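Putting the module together: evaluate runs every question through the query engine, collects answers and source-node texts as contexts, and hands the assembled dataset to the core ragas evaluator. A hedged end-to-end sketch; the index construction, data path, and metric choice are illustrative assumptions, not part of this diff:

from llama_index.core import SimpleDirectoryReader, VectorStoreIndex

from ragas.integrations.llama_index import evaluate
from ragas.metrics import answer_relevancy, faithfulness

# build a query engine over local documents (assumed setup)
documents = SimpleDirectoryReader("./data").load_data()
query_engine = VectorStoreIndex.from_documents(documents).as_query_engine()

# "question" is required; "ground_truth" is attached when present
dataset = {
    "question": ["What does ragas evaluate?"],
    "ground_truth": ["ragas evaluates RAG pipelines."],
}

result = evaluate(
    query_engine=query_engine,
    dataset=dataset,
    metrics=[faithfulness, answer_relevancy],
)
print(result)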

src/ragas/llms/__init__.py (+7 -1)

@@ -1,7 +1,13 @@
-from ragas.llms.base import BaseRagasLLM, LangchainLLMWrapper, llm_factory
+from ragas.llms.base import (
+    BaseRagasLLM,
+    LangchainLLMWrapper,
+    LlamaIndexLLMWrapper,
+    llm_factory,
+)

 __all__ = [
     "BaseRagasLLM",
     "LangchainLLMWrapper",
+    "LlamaIndexLLMWrapper",
     "llm_factory",
 ]

src/ragas/llms/base.py (+74 -1)

@@ -10,7 +10,7 @@
 from langchain_community.chat_models.vertexai import ChatVertexAI
 from langchain_community.llms import VertexAI
 from langchain_core.language_models import BaseLanguageModel
-from langchain_core.outputs import LLMResult
+from langchain_core.outputs import Generation, LLMResult
 from langchain_openai.chat_models import AzureChatOpenAI, ChatOpenAI
 from langchain_openai.llms import AzureOpenAI, OpenAI
 from langchain_openai.llms.base import BaseOpenAI
@@ -19,6 +19,7 @@

 if t.TYPE_CHECKING:
     from langchain_core.callbacks import Callbacks
+    from llama_index.core.base.llms.base import BaseLLM

     from ragas.llms.prompt import PromptValue

@@ -203,6 +204,78 @@ def set_run_config(self, run_config: RunConfig):
         self.run_config.exception_types = RateLimitError


+class LlamaIndexLLMWrapper(BaseRagasLLM):
+    """
+    An adaptor for LlamaIndex LLMs
+    """
+
+    def __init__(
+        self,
+        llm: BaseLLM,
+        run_config: t.Optional[RunConfig] = None,
+    ):
+        self.llm = llm
+
+        self._signature = ""
+        if type(self.llm).__name__.lower() == "bedrock":
+            self._signature = "bedrock"
+        if run_config is None:
+            run_config = RunConfig()
+        self.set_run_config(run_config)
+
+    def check_args(
+        self,
+        n: int,
+        temperature: float,
+        stop: t.Optional[t.List[str]],
+        callbacks: Callbacks,
+    ) -> dict[str, t.Any]:
+        if n != 1:
+            logger.warning("n values greater than 1 are not supported for LlamaIndex LLMs")
+        if temperature != 1e-8:
+            logger.info("temperature kwarg passed to LlamaIndex LLM")
+        if stop is not None:
+            logger.info("stop kwarg passed to LlamaIndex LLM")
+        if callbacks is not None:
+            logger.info(
+                "callbacks not supported for LlamaIndex LLMs, ignoring callbacks"
+            )
+        if self._signature == "bedrock":
+            return {"temperature": temperature}
+        else:
+            return {
+                "n": n,
+                "temperature": temperature,
+                "stop": stop,
+            }
+
+    def generate_text(
+        self,
+        prompt: PromptValue,
+        n: int = 1,
+        temperature: float = 1e-8,
+        stop: t.Optional[t.List[str]] = None,
+        callbacks: Callbacks = None,
+    ) -> LLMResult:
+        kwargs = self.check_args(n, temperature, stop, callbacks)
+        li_response = self.llm.complete(prompt.to_string(), **kwargs)
+
+        return LLMResult(generations=[[Generation(text=li_response.text)]])
+
+    async def agenerate_text(
+        self,
+        prompt: PromptValue,
+        n: int = 1,
+        temperature: float = 1e-8,
+        stop: t.Optional[t.List[str]] = None,
+        callbacks: Callbacks = None,
+    ) -> LLMResult:
+        kwargs = self.check_args(n, temperature, stop, callbacks)
+        li_response = await self.llm.acomplete(prompt.to_string(), **kwargs)
+
+        return LLMResult(generations=[[Generation(text=li_response.text)]])
+
+
 def llm_factory(
     model: str = "gpt-3.5-turbo", run_config: t.Optional[RunConfig] = None
 ) -> BaseRagasLLM:
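The wrapper lets any LlamaIndex LLM stand in wherever ragas expects a BaseRagasLLM; check_args downgrades unsupported kwargs to log messages, and Bedrock models only receive temperature. A short sketch, assuming the llama-index OpenAI client and ragas's PromptValue (which carries the prompt text as prompt_str); the model choice is illustrative, not part of this diff:

from llama_index.llms.openai import OpenAI

from ragas.llms import LlamaIndexLLMWrapper
from ragas.llms.prompt import PromptValue

ragas_llm = LlamaIndexLLMWrapper(OpenAI(model="gpt-3.5-turbo"))
result = ragas_llm.generate_text(PromptValue(prompt_str="Say hi."))
print(result.generations[0][0].text)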

src/ragas/llms/prompt.py (+3 -3)

@@ -44,12 +44,12 @@ class Prompt(BaseModel):
         language (str): The language of the prompt (default: "english").
     """

-    name: str
+    name: str = ""
     instruction: str
     output_format_instruction: str = ""
     examples: t.List[Example] = []
-    input_keys: t.List[str]
-    output_key: str
+    input_keys: t.List[str] = [""]
+    output_key: str = ""
     output_type: t.Literal["json", "str"] = "json"
     language: str = "english"
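With these defaults, a Prompt no longer has to be constructed with every field up front. A minimal sketch, assuming Prompt's validators accept the empty defaults:

from ragas.llms.prompt import Prompt

# only `instruction` remains without a default after this change
p = Prompt(instruction="Extract the key phrases from the given text.")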

src/ragas/metrics/_context_relevancy.py (-1)

@@ -60,7 +60,6 @@ def _compute_score(self, response: str, row: t.Dict) -> float:
             if response.lower() != "insufficient information."
             else []
         )
-        # print(len(indices))
         if len(context_sents) == 0:
             return 0
         else:

src/ragas/testset/generator.py (+45 -7)

@@ -7,16 +7,18 @@

 import pandas as pd
 from datasets import Dataset
-from langchain_core.embeddings import Embeddings
-from langchain_core.language_models import BaseLanguageModel
 from langchain_openai.chat_models import ChatOpenAI
 from langchain_openai.embeddings import OpenAIEmbeddings

 from ragas._analytics import TestsetGenerationEvent, track
-from ragas.embeddings.base import BaseRagasEmbeddings, LangchainEmbeddingsWrapper
+from ragas.embeddings.base import (
+    BaseRagasEmbeddings,
+    LangchainEmbeddingsWrapper,
+    LlamaIndexEmbeddingsWrapper,
+)
 from ragas.exceptions import ExceptionInRunner
 from ragas.executor import Executor
-from ragas.llms import BaseRagasLLM, LangchainLLMWrapper
+from ragas.llms import BaseRagasLLM, LangchainLLMWrapper, LlamaIndexLLMWrapper
 from ragas.run_config import RunConfig
 from ragas.testset.docstore import Document, DocumentStore, InMemoryDocumentStore
 from ragas.testset.evolutions import (
@@ -34,6 +36,12 @@

 if t.TYPE_CHECKING:
     from langchain_core.documents import Document as LCDocument
+    from langchain_core.embeddings import Embeddings as LangchainEmbeddings
+    from langchain_core.language_models import BaseLanguageModel as LangchainLLM
+    from llama_index.core.base.embeddings.base import (
+        BaseEmbedding as LlamaIndexEmbeddings,
+    )
+    from llama_index.core.base.llms.base import BaseLLM as LlamaindexLLM
     from llama_index.core.schema import Document as LlamaindexDocument

 logger = logging.getLogger(__name__)
@@ -75,9 +83,9 @@ class TestsetGenerator:
     @classmethod
     def from_langchain(
         cls,
-        generator_llm: BaseLanguageModel,
-        critic_llm: BaseLanguageModel,
-        embeddings: Embeddings,
+        generator_llm: LangchainLLM,
+        critic_llm: LangchainLLM,
+        embeddings: LangchainEmbeddings,
         docstore: t.Optional[DocumentStore] = None,
         run_config: t.Optional[RunConfig] = None,
         chunk_size: int = 1024,
@@ -104,6 +112,36 @@ def from_langchain(
             docstore=docstore,
         )

+    @classmethod
+    def from_llama_index(
+        cls,
+        generator_llm: LlamaindexLLM,
+        critic_llm: LlamaindexLLM,
+        embeddings: LlamaIndexEmbeddings,
+        docstore: t.Optional[DocumentStore] = None,
+        run_config: t.Optional[RunConfig] = None,
+    ) -> "TestsetGenerator":
+        generator_llm_model = LlamaIndexLLMWrapper(generator_llm)
+        critic_llm_model = LlamaIndexLLMWrapper(critic_llm)
+        embeddings_model = LlamaIndexEmbeddingsWrapper(embeddings)
+        keyphrase_extractor = KeyphraseExtractor(llm=generator_llm_model)
+        if docstore is None:
+            from langchain.text_splitter import TokenTextSplitter
+
+            splitter = TokenTextSplitter(chunk_size=1024, chunk_overlap=0)
+            docstore = InMemoryDocumentStore(
+                splitter=splitter,
+                embeddings=embeddings_model,
+                extractor=keyphrase_extractor,
+                run_config=run_config,
+            )
+        return cls(
+            generator_llm=generator_llm_model,
+            critic_llm=critic_llm_model,
+            embeddings=embeddings_model,
+            docstore=docstore,
+        )
+
     @classmethod
     @deprecated("0.1.4", removal="0.2.0", alternative="from_langchain")
     def with_openai(
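The new classmethod mirrors from_langchain: it wraps the LlamaIndex models and builds an in-memory docstore when none is supplied. A usage sketch; the model choices, data path, and the existing generate_with_llamaindex_docs helper are assumptions for illustration:

from llama_index.core import SimpleDirectoryReader
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI

from ragas.testset.generator import TestsetGenerator

documents = SimpleDirectoryReader("./data").load_data()

generator = TestsetGenerator.from_llama_index(
    generator_llm=OpenAI(model="gpt-3.5-turbo"),
    critic_llm=OpenAI(model="gpt-4"),
    embeddings=OpenAIEmbedding(),
)
testset = generator.generate_with_llamaindex_docs(documents, test_size=5)
print(testset.to_pandas())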
