import pandas as pd
from datasets import Dataset
- from langchain_core.embeddings import Embeddings
- from langchain_core.language_models import BaseLanguageModel
from langchain_openai.chat_models import ChatOpenAI
from langchain_openai.embeddings import OpenAIEmbeddings

from ragas._analytics import TestsetGenerationEvent, track
- from ragas.embeddings.base import BaseRagasEmbeddings, LangchainEmbeddingsWrapper
+ from ragas.embeddings.base import (
+     BaseRagasEmbeddings,
+     LangchainEmbeddingsWrapper,
+     LlamaIndexEmbeddingsWrapper,
+ )
from ragas.exceptions import ExceptionInRunner
from ragas.executor import Executor
- from ragas.llms import BaseRagasLLM, LangchainLLMWrapper
+ from ragas.llms import BaseRagasLLM, LangchainLLMWrapper, LlamaIndexLLMWrapper
from ragas.run_config import RunConfig
from ragas.testset.docstore import Document, DocumentStore, InMemoryDocumentStore
from ragas.testset.evolutions import (
...

if t.TYPE_CHECKING:
    from langchain_core.documents import Document as LCDocument
+     from langchain_core.embeddings import Embeddings as LangchainEmbeddings
+     from langchain_core.language_models import BaseLanguageModel as LangchainLLM
+     from llama_index.core.base.embeddings.base import (
+         BaseEmbedding as LlamaIndexEmbeddings,
+     )
+     from llama_index.core.base.llms.base import BaseLLM as LlamaindexLLM
    from llama_index.core.schema import Document as LlamaindexDocument

logger = logging.getLogger(__name__)
@@ -75,9 +83,9 @@ class TestsetGenerator:
    @classmethod
    def from_langchain(
        cls,
-         generator_llm: BaseLanguageModel,
-         critic_llm: BaseLanguageModel,
-         embeddings: Embeddings,
+         generator_llm: LangchainLLM,
+         critic_llm: LangchainLLM,
+         embeddings: LangchainEmbeddings,
        docstore: t.Optional[DocumentStore] = None,
        run_config: t.Optional[RunConfig] = None,
        chunk_size: int = 1024,
@@ -104,6 +112,36 @@ def from_langchain(
            docstore=docstore,
        )

+     @classmethod
+     def from_llama_index(
+         cls,
+         generator_llm: LlamaindexLLM,
+         critic_llm: LlamaindexLLM,
+         embeddings: LlamaIndexEmbeddings,
+         docstore: t.Optional[DocumentStore] = None,
+         run_config: t.Optional[RunConfig] = None,
+     ) -> "TestsetGenerator":
+         generator_llm_model = LlamaIndexLLMWrapper(generator_llm)
+         critic_llm_model = LlamaIndexLLMWrapper(critic_llm)
+         embeddings_model = LlamaIndexEmbeddingsWrapper(embeddings)
+         keyphrase_extractor = KeyphraseExtractor(llm=generator_llm_model)
+         if docstore is None:
+             from langchain.text_splitter import TokenTextSplitter
+
+             splitter = TokenTextSplitter(chunk_size=1024, chunk_overlap=0)
+             docstore = InMemoryDocumentStore(
+                 splitter=splitter,
+                 embeddings=embeddings_model,
+                 extractor=keyphrase_extractor,
+                 run_config=run_config,
+             )
+         return cls(
+             generator_llm=generator_llm_model,
+             critic_llm=critic_llm_model,
+             embeddings=embeddings_model,
+             docstore=docstore,
+         )
+
    @classmethod
    @deprecated("0.1.4", removal="0.2.0", alternative="from_langchain")
    def with_openai(
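Below is a minimal usage sketch of the new from_llama_index constructor added in this commit. It assumes TestsetGenerator is importable from ragas.testset.generator and that the llama-index OpenAI integration packages are installed; the model classes, model names, and environment setup are illustrative assumptions, not part of this diff.

# Sketch: wiring LlamaIndex-native models into the testset generator.
# Assumes OPENAI_API_KEY is set and the llama-index OpenAI integrations
# (llama_index.llms.openai / llama_index.embeddings.openai) are installed.
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI

from ragas.testset.generator import TestsetGenerator

generator_llm = OpenAI(model="gpt-3.5-turbo")  # illustrative model choice
critic_llm = OpenAI(model="gpt-4")             # illustrative model choice
embeddings = OpenAIEmbedding()

# from_llama_index wraps these in LlamaIndexLLMWrapper /
# LlamaIndexEmbeddingsWrapper and, when no docstore is passed, builds an
# InMemoryDocumentStore with a TokenTextSplitter and KeyphraseExtractor.
generator = TestsetGenerator.from_llama_index(
    generator_llm=generator_llm,
    critic_llm=critic_llm,
    embeddings=embeddings,
)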