Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat rag graph #1647

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions metagpt/rag/engines/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
from typing import Any, Optional, Union

from llama_index.core import SimpleDirectoryReader
from llama_index.core import Settings, SimpleDirectoryReader
from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.embeddings.mock_embed_model import MockEmbedding
Expand Down Expand Up @@ -89,6 +89,7 @@ def from_docs(
llm: LLM = None,
retriever_configs: list[BaseRetrieverConfig] = None,
ranker_configs: list[BaseRankerConfig] = None,
build_graph: bool = False,
) -> "SimpleEngine":
"""From docs.

Expand All @@ -102,6 +103,7 @@ def from_docs(
llm: Must be supported by llama index. Default OpenAI.
retriever_configs: Configuration for retrievers. If more than one config, will use SimpleHybridRetriever.
ranker_configs: Configuration for rankers.
build_graph: Whether to build a graph from scratch. Default False.
"""
if not input_dir and not input_files:
raise ValueError("Must provide either `input_dir` or `input_files`.")
Expand All @@ -122,6 +124,7 @@ def from_docs(
llm=llm,
retriever_configs=retriever_configs,
ranker_configs=ranker_configs,
build_graph=build_graph,
)

@classmethod
Expand All @@ -133,6 +136,7 @@ def from_objs(
llm: LLM = None,
retriever_configs: list[BaseRetrieverConfig] = None,
ranker_configs: list[BaseRankerConfig] = None,
build_graph: bool = False,
) -> "SimpleEngine":
"""From objs.

Expand All @@ -143,6 +147,7 @@ def from_objs(
llm: Must be supported by llama index. Default OpenAI.
retriever_configs: Configuration for retrievers. If more than one config, will use SimpleHybridRetriever.
ranker_configs: Configuration for rankers.
build_graph: Whether to build a graph from scratch. Default False.
"""
objs = objs or []
retriever_configs = retriever_configs or []
Expand All @@ -159,6 +164,7 @@ def from_objs(
llm=llm,
retriever_configs=retriever_configs,
ranker_configs=ranker_configs,
build_graph=build_graph,
)

@classmethod
Expand All @@ -178,16 +184,16 @@ async def asearch(self, content: str, **kwargs) -> str:
"""Implement tools.SearchInterface"""
return await self.aquery(content)

def retrieve(self, query: QueryType) -> list[NodeWithScore]:
query_bundle = QueryBundle(query) if isinstance(query, str) else query
def retrieve(self, query_bundle: QueryType) -> list[NodeWithScore]:
query_bundle = QueryBundle(query_bundle) if isinstance(query_bundle, str) else query_bundle

nodes = super().retrieve(query_bundle)
self._try_reconstruct_obj(nodes)
return nodes

async def aretrieve(self, query: QueryType) -> list[NodeWithScore]:
async def aretrieve(self, query_bundle: QueryType) -> list[NodeWithScore]:
"""Allow query_bundle to be str."""
query_bundle = QueryBundle(query) if isinstance(query, str) else query
query_bundle = QueryBundle(query_bundle) if isinstance(query_bundle, str) else query_bundle

nodes = await super().aretrieve(query_bundle)
self._try_reconstruct_obj(nodes)
Expand Down Expand Up @@ -225,11 +231,16 @@ def _from_nodes(
llm: LLM = None,
retriever_configs: list[BaseRetrieverConfig] = None,
ranker_configs: list[BaseRankerConfig] = None,
build_graph: bool = False,
) -> "SimpleEngine":
embed_model = cls._resolve_embed_model(embed_model, retriever_configs)
llm = llm or get_rag_llm()
Settings.llm = llm
Settings.embed_model = embed_model

retriever = get_retriever(configs=retriever_configs, nodes=nodes, embed_model=embed_model)
retriever = get_retriever(
configs=retriever_configs, nodes=nodes, embed_model=embed_model, build_graph=build_graph
)
rankers = get_rankers(configs=ranker_configs, llm=llm) # Default []

return cls(
Expand All @@ -248,6 +259,7 @@ def _from_index(
ranker_configs: list[BaseRankerConfig] = None,
) -> "SimpleEngine":
llm = llm or get_rag_llm()
Settings.llm = llm

retriever = get_retriever(configs=retriever_configs, index=index) # Default index.as_retriever
rankers = get_rankers(configs=ranker_configs, llm=llm) # Default []
Expand Down
61 changes: 51 additions & 10 deletions metagpt/rag/factories/retriever.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
"""RAG Retriever Factory."""


from functools import wraps

import chromadb
import faiss
from llama_index.core import StorageContext, VectorStoreIndex
from llama_index.core import PropertyGraphIndex, StorageContext, VectorStoreIndex
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.indices.property_graph import PGRetriever
from llama_index.core.schema import BaseNode
from llama_index.core.vector_stores.types import BasePydanticVectorStore
from llama_index.graph_stores.neo4j import Neo4jPropertyGraphStore
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
from llama_index.vector_stores.faiss import FaissVectorStore
Expand All @@ -30,6 +31,7 @@
ElasticsearchRetrieverConfig,
FAISSRetrieverConfig,
MilvusRetrieverConfig,
Neo4jPGRetrieverConfig,
)


Expand Down Expand Up @@ -60,25 +62,42 @@ def __init__(self):
ElasticsearchRetrieverConfig: self._create_es_retriever,
ElasticsearchKeywordRetrieverConfig: self._create_es_retriever,
MilvusRetrieverConfig: self._create_milvus_retriever,
Neo4jPGRetrieverConfig: self._create_neo4j_pg_retriever,
}
super().__init__(creators)

def get_retriever(self, configs: list[BaseRetrieverConfig] = None, **kwargs) -> RAGRetriever:
def get_retriever(
self, configs: list[BaseRetrieverConfig] = None, build_graph: bool = False, **kwargs
) -> RAGRetriever:
"""Creates and returns a retriever instance based on the provided configurations.

If multiple retrievers, using SimpleHybridRetriever.
If build_graph is True and no graph-related retriever_config is provided,
the default is to use memory for graph attribute storage and retrieval.
If a graph-related retriever_config is provided,
the build_graph parameter is ignored and the configured one is used directly.
"""
if not configs:
return self._create_default(**kwargs)
return self._create_default(build_graph=build_graph, **kwargs)

retrievers = super().get_instances(configs, **kwargs)
has_pg_retriever = any(isinstance(retriever, PGRetriever) for retriever in retrievers)

if build_graph and not has_pg_retriever:
pg_index = self._extract_index(None, **kwargs) or self._build_default_pg_index(**kwargs)
retrievers.append(pg_index.as_retriever())

return SimpleHybridRetriever(*retrievers) if len(retrievers) > 1 else retrievers[0]

def _create_default(self, **kwargs) -> RAGRetriever:
index = self._extract_index(None, **kwargs) or self._build_default_index(**kwargs)
def _create_default(self, build_graph: bool = False, **kwargs) -> RAGRetriever:
vector_index = self._extract_index(None, **kwargs) or self._build_default_vector_index(**kwargs)
retrievers = [vector_index.as_retriever()]

return index.as_retriever()
if build_graph:
pg_index = self._extract_index(None, **kwargs) or self._build_default_pg_index(**kwargs)
retrievers.append(pg_index.as_retriever())

return SimpleHybridRetriever(*retrievers) if len(retrievers) > 1 else retrievers[0]

def _create_milvus_retriever(self, config: MilvusRetrieverConfig, **kwargs) -> MilvusRetriever:
config.index = self._build_milvus_index(config, **kwargs)
Expand Down Expand Up @@ -106,6 +125,10 @@ def _create_es_retriever(self, config: ElasticsearchRetrieverConfig, **kwargs) -

return ElasticsearchRetriever(**config.model_dump())

def _create_neo4j_pg_retriever(self, config: Neo4jPGRetrieverConfig, **kwargs) -> PGRetriever:
graph_index = self._build_neo4j_pg_index(config, **kwargs)
return graph_index.as_retriever(**config.model_dump())

def _extract_index(self, config: BaseRetrieverConfig = None, **kwargs) -> VectorStoreIndex:
return self._val_from_config_or_kwargs("index", config, **kwargs)

Expand All @@ -115,13 +138,20 @@ def _extract_nodes(self, config: BaseRetrieverConfig = None, **kwargs) -> list[B
def _extract_embed_model(self, config: BaseRetrieverConfig = None, **kwargs) -> BaseEmbedding:
return self._val_from_config_or_kwargs("embed_model", config, **kwargs)

def _build_default_index(self, **kwargs) -> VectorStoreIndex:
index = VectorStoreIndex(
def _build_default_vector_index(self, **kwargs) -> VectorStoreIndex:
vector_index = VectorStoreIndex(
nodes=self._extract_nodes(**kwargs),
embed_model=self._extract_embed_model(**kwargs),
)
return vector_index

return index
def _build_default_pg_index(self, **kwargs):
# build default PropertyGraphIndex
pg_index = PropertyGraphIndex(
nodes=self._extract_nodes(**kwargs),
embed_model=self._extract_embed_model(**kwargs),
)
return pg_index

@get_or_build_index
def _build_faiss_index(self, config: FAISSRetrieverConfig, **kwargs) -> VectorStoreIndex:
Expand Down Expand Up @@ -151,6 +181,17 @@ def _build_es_index(self, config: ElasticsearchRetrieverConfig, **kwargs) -> Vec

return self._build_index_from_vector_store(config, vector_store, **kwargs)

@get_or_build_index
def _build_neo4j_pg_index(self, config: Neo4jPGRetrieverConfig, **kwargs) -> PropertyGraphIndex:
graph_store = Neo4jPropertyGraphStore(**config.store_config.model_dump())
graph_index = PropertyGraphIndex(
nodes=self._extract_nodes(**kwargs),
property_graph_store=graph_store,
embed_model=self._extract_embed_model(**kwargs),
**config.model_dump(),
)
return graph_index

def _build_index_from_vector_store(
self, config: BaseRetrieverConfig, vector_store: BasePydanticVectorStore, **kwargs
) -> VectorStoreIndex:
Expand Down
6 changes: 3 additions & 3 deletions metagpt/rag/retrievers/hybrid_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,17 @@ def __init__(self, *retrievers):
self.retrievers: list[RAGRetriever] = retrievers
super().__init__()

async def _aretrieve(self, query: QueryType, **kwargs):
async def _aretrieve(self, query_bundle: QueryType, **kwargs):
"""Asynchronously retrieves and aggregates search results from all configured retrievers.

This method queries each retriever in the `retrievers` list with the given query and
This method queries each retriever in the `retrievers` list with the given query_bundle and
additional keyword arguments. It then combines the results, ensuring that each node is
unique, based on the node's ID.
"""
all_nodes = []
for retriever in self.retrievers:
# Prevent retriever changing query
query_copy = copy.deepcopy(query)
query_copy = copy.deepcopy(query_bundle)
nodes = await retriever.aretrieve(query_copy, **kwargs)
all_nodes.extend(nodes)

Expand Down
16 changes: 15 additions & 1 deletion metagpt/rag/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from chromadb.api.types import CollectionMetadata
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.indices.base import BaseIndex
from llama_index.core.schema import TextNode
from llama_index.core.schema import TextNode, TransformComponent
from llama_index.core.vector_stores.types import VectorStoreQueryMode
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator

Expand Down Expand Up @@ -102,6 +102,20 @@ class ChromaRetrieverConfig(IndexRetrieverConfig):
)


class Neo4jPGRetrieverStoreConfig(BaseModel):
username: str = Field(default="neo4j", description="The username for neo4j.")
password: str = Field(default="<password>", description="The password for neo4j.")
url: str = Field(default="bolt://localhost:7687", description="The neo4j server to save data.")
database: str = Field(default="neo4j", description="The database to save data.")


class Neo4jPGRetrieverConfig(IndexRetrieverConfig):
store_config: Neo4jPGRetrieverStoreConfig = Field(
default=Neo4jPGRetrieverStoreConfig(), description="Neo4jPGRetrieverStoreConfig"
)
kg_extractors: Optional[List[TransformComponent]] = Field(default=None, description="property graph extractors.")


class ElasticsearchStoreConfig(BaseModel):
index_name: str = Field(default="metagpt", description="Name of the Elasticsearch index.")
es_url: str = Field(default=None, description="Elasticsearch URL.")
Expand Down
10 changes: 5 additions & 5 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
aiohttp==3.8.6
aiohttp==3.9.5
#azure_storage==0.37.0
channels==4.0.0
# Django==4.1.5
Expand All @@ -13,7 +13,7 @@ lancedb==0.4.0
loguru==0.6.0
meilisearch==0.21.0
numpy~=1.26.4
openai~=1.39.0
openai~=1.40.0
openpyxl~=3.1.5
beautifulsoup4==4.12.3
pandas==2.1.1
Expand All @@ -24,7 +24,7 @@ pydantic>=2.5.3
python_docx==0.8.11
PyYAML==6.0.1
# sentence_transformers==2.2.2
setuptools==65.6.3
# setuptools==65.6.3
tenacity==8.2.3
tiktoken==0.7.0
tqdm==4.66.2
Expand Down Expand Up @@ -59,13 +59,13 @@ nbformat==5.9.2
ipython==8.17.2
ipykernel==6.27.1
scikit_learn==1.3.2
typing-extensions==4.9.0
typing-extensions==4.11
socksio~=1.0.0
gitignore-parser==0.1.9
# connexion[uvicorn]~=3.0.5 # Used by metagpt/tools/openapi_v3_hello.py
websockets>=10.0,<12.0
networkx~=3.2.1
google-generativeai==0.4.1
google-generativeai==0.5.2
playwright>=1.26 # used at metagpt/tools/libs/web_scraping.py
anytree
ipywidgets==8.1.1
Expand Down
31 changes: 16 additions & 15 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,22 +29,23 @@ def run(self):
"search-ddg": ["duckduckgo-search~=4.1.1"],
# "ocr": ["paddlepaddle==2.4.2", "paddleocr~=2.7.3", "tabulate==0.9.0"],
"rag": [
"llama-index-core==0.10.15",
"llama-index-embeddings-azure-openai==0.1.6",
"llama-index-embeddings-openai==0.1.5",
"llama-index-embeddings-gemini==0.1.6",
"llama-index-embeddings-ollama==0.1.2",
"llama-index-llms-azure-openai==0.1.4",
"llama-index-readers-file==0.1.4",
"llama-index-retrievers-bm25==0.1.3",
"llama-index-vector-stores-faiss==0.1.1",
"llama-index-vector-stores-elasticsearch==0.1.6",
"llama-index-vector-stores-chroma==0.1.6",
"llama-index-postprocessor-cohere-rerank==0.1.4",
"llama-index-postprocessor-colbert-rerank==0.1.1",
"llama-index-postprocessor-flag-embedding-reranker==0.1.2",
# "llama-index-vector-stores-milvus==0.1.23",
"llama-index-core==0.12.5",
"llama-index-embeddings-azure-openai==0.3.0",
"llama-index-embeddings-openai==0.3.1",
"llama-index-embeddings-gemini==0.3.0",
"llama-index-embeddings-ollama==0.5.0",
"llama-index-llms-azure-openai==0.3.0",
"llama-index-readers-file==0.4.1",
"llama-index-retrievers-bm25==0.5.0",
"llama-index-vector-stores-faiss==0.3.0",
"llama-index-vector-stores-elasticsearch==0.4.0",
"llama-index-vector-stores-chroma==0.4.1",
"llama-index-postprocessor-cohere-rerank==0.3.0",
"llama-index-postprocessor-colbert-rerank==0.3.0",
"llama-index-postprocessor-flag-embedding-reranker==0.3.0",
"llama-index-vector-stores-milvus==0.4.0",
"docx2txt==0.8",
"llama-index-graph-stores-neo4j==0.4.4",
],
}

Expand Down
3 changes: 2 additions & 1 deletion tests/metagpt/rag/engines/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def test_from_docs(
input_files = ["test_file1", "test_file2"]
transformations = [mocker.MagicMock()]
embed_model = mocker.MagicMock()
llm = mocker.MagicMock()
llm = MockLLM()
retriever_configs = [mocker.MagicMock()]
ranker_configs = [mocker.MagicMock()]

Expand All @@ -80,6 +80,7 @@ def test_from_docs(
llm=llm,
retriever_configs=retriever_configs,
ranker_configs=ranker_configs,
build_graph=True,
)

# Assert
Expand Down
Loading
Loading