use config to pass parameters
SonglinLyu committed Jan 20, 2025
1 parent 4e7d1a7 commit 197ee57
Showing 7 changed files with 62 additions and 51 deletions.
4 changes: 1 addition & 3 deletions .env.template
@@ -160,11 +160,9 @@ EXECUTE_LOCAL_COMMANDS=False
#*******************************************************************#
VECTOR_STORE_TYPE=Chroma
GRAPH_STORE_TYPE=TuGraph
KNOWLEDGE_GRAPH_EXTRACT_SEARCH_TOP_SIZE=5
KNOWLEDGE_GRAPH_EXTRACT_SEARCH_TOP_SIZE=5 # the top size of knowledge graph search for triplets
KNOWLEDGE_GRAPH_EXTRACT_SEARCH_RECALL_SCORE=0.3
KNOWLEDGE_GRAPH_SIMILARITY_SEARCH_TOP_SIZE=5
KNOWLEDGE_GRAPH_SIMILARITY_SEARCH_RECALL_SCORE=0.7
KNOWLEDGE_GRAPH_TEXT_SEARCH_TOP_SIZE=5
KNOWLEDGE_GRAPH_COMMUNITY_SEARCH_TOP_SIZE=20
KNOWLEDGE_GRAPH_COMMUNITY_SEARCH_RECALL_SCORE=0.0
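
These knowledge-graph settings are now read with the usual environment-over-config precedence (see the retriever router hunk later in this diff). A minimal sketch of how the extract-search pair might be resolved, using a stand-in defaults object in place of the real config class:

    import os

    class _Defaults:
        """Stand-in for the real config; field names follow the router hunk below."""
        extract_topk = 5
        extract_score_threshold = 0.3

    config = _Defaults()

    # Environment values win; config fields supply the fallbacks.
    triplet_topk = int(
        os.getenv("KNOWLEDGE_GRAPH_EXTRACT_SEARCH_TOP_SIZE", config.extract_topk)
    )
    recall_score = float(
        os.getenv("KNOWLEDGE_GRAPH_EXTRACT_SEARCH_RECALL_SCORE", config.extract_score_threshold)
    )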

5 changes: 0 additions & 5 deletions dbgpt/storage/graph_store/base.py
@@ -31,10 +31,6 @@ class GraphStoreConfig(BaseModel):
default=False,
description="Enable similarity search or not.",
)
enable_text_search: bool = Field(
default=False,
description="Enable text search or not.",
)


class GraphStoreBase(ABC):
@@ -46,7 +42,6 @@ def __init__(self, config: GraphStoreConfig):
self._conn = None
self.enable_summary = config.enable_summary
self.enable_similarity_search = config.enable_similarity_search
self.enable_text_search = config.enable_text_search

@abstractmethod
def get_config(self) -> GraphStoreConfig:
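
A minimal sketch of the trimmed store config after this change, written with plain pydantic to stay self-contained (the repo uses its own BaseModel/Field wrappers, and the enable_summary default here is illustrative):

    from pydantic import BaseModel, Field

    class GraphStoreConfig(BaseModel):
        """Sketch: only the switches shown in the hunk above remain at store level."""

        enable_summary: bool = Field(
            default=False,  # illustrative default
            description="Enable summary or not.",
        )
        enable_similarity_search: bool = Field(
            default=False,
            description="Enable similarity search or not.",
        )
        # enable_text_search is removed; the text2gql toggle now lives on the
        # knowledge-graph config (see community_summary.py below).
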
5 changes: 0 additions & 5 deletions dbgpt/storage/graph_store/tugraph_store.py
@@ -97,11 +97,6 @@ def __init__(self, config: TuGraphStoreConfig) -> None:
if "SIMILARITY_SEARCH_ENABLED" in os.environ
else config.enable_similarity_search
)
self.enable_text_search = (
os.environ["TEXT_SEARCH_ENABLED"].lower() == "true"
if "TEXT_SEARCH_ENABLED" in os.environ
else config.enable_text_search
)
self._plugin_names = (
os.getenv("TUGRAPH_PLUGIN_NAMES", "leiden").split(",")
or config.plugin_names
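
The environment-over-config rule for boolean switches recurs both here and in the retriever router further down. A hypothetical helper capturing the pattern (not part of the repo):

    import os

    def env_bool(name: str, default: bool) -> bool:
        """Use the env var if set ("true"/"false" string), otherwise the config default."""
        return os.environ[name].lower() == "true" if name in os.environ else default

    # e.g. enable_similarity_search = env_bool(
    #     "SIMILARITY_SEARCH_ENABLED", config.enable_similarity_search
    # )
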
23 changes: 8 additions & 15 deletions dbgpt/storage/knowledge_graph/community_summary.py
@@ -17,7 +17,7 @@
from dbgpt.rag.transformer.text_embedder import TextEmbedder
from dbgpt.storage.knowledge_graph.base import ParagraphChunk
from dbgpt.storage.knowledge_graph.community.community_store import CommunityStore
from dbgpt.storage.knowledge_graph.graph_retriever.graph_retriever import GraphRetriever
from dbgpt.storage.knowledge_graph.graph_retriever.graph_retriever_router import GraphRetrieverRouter
from dbgpt.storage.knowledge_graph.knowledge_graph import (
GRAPH_PARAMETERS,
BuiltinKnowledgeGraph,
@@ -209,6 +209,10 @@ class CommunitySummaryKnowledgeGraphConfig(BuiltinKnowledgeGraphConfig):
default=0.7,
description="Recall score of similarity search",
)
enable_text_search: bool = Field(
default=False,
description="Enable text2gql search or not.",
)
text_search_topk: int = Field(
default=5,
description="Topk of text search",
@@ -363,20 +367,9 @@ def community_store_configure(name: str, cfg: VectorStoreConfig):

self._knowledge_graph_triplet_search_top_size = 5
self._knowledge_graph_document_search_top_size = 5
self._graph_retriever = GraphRetriever(
self._triplet_graph_enabled,
self._document_graph_enabled,
self._knowledge_graph_triplet_search_top_size,
self._knowledge_graph_document_search_top_size,
self._keyword_extractor,
self._graph_retriever_router = GraphRetrieverRouter(
config,
self._graph_store.enable_similarity_search,
self._config.embedding_fn,
self._triplet_embedding_batch_size,
self._similarity_search_topk,
self._similarity_search_score_threshold,
self._graph_store.enable_text_search,
self._llm_client,
self._model_name,
self._graph_store_apdater,
)
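
The call site shrinks to the config plus the store adapter; the keyword extractor, embedder, top-k values, and thresholds are now derived inside the router. A hypothetical standalone wiring, assuming the config and adapter are built as in the surrounding class:

    from dbgpt.storage.knowledge_graph.graph_retriever.graph_retriever_router import (
        GraphRetrieverRouter,
    )

    # kg_config: a CommunitySummaryKnowledgeGraphConfig; store_adapter: the graph
    # store adapter (the repo's parameter is spelled `graph_store_apdater`).
    router = GraphRetrieverRouter(kg_config, store_adapter)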

@@ -556,7 +549,7 @@ async def asimilar_search_with_scores(
]
context = "\n".join(summaries) if summaries else ""

subgraph, subgraph_for_doc, text2gql_query = await self._graph_retriever.retrieve(text)
subgraph, subgraph_for_doc, text2gql_query = await self._graph_retriever_router.retrieve(text)

knowledge_graph_str = subgraph.format() if subgraph else ""
knowledge_graph_for_doc_str = (
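
Taken together, the changes in this file mean the text2gql toggle is set on the knowledge-graph config, and a single router call returns the triplet subgraph, the document subgraph, and any generated query. A usage sketch, with required fields such as llm_client, model_name, and embedding_fn assumed to be supplied as usual:

    from dbgpt.storage.knowledge_graph.community_summary import (
        CommunitySummaryKnowledgeGraphConfig,
    )

    config = CommunitySummaryKnowledgeGraphConfig(
        name="kg_demo",            # hypothetical graph name
        enable_text_search=True,   # new flag lives here, not on the store config
        text_search_topk=5,
        # ... llm_client, model_name, embedding_fn, etc. as in existing setups
    )

    # Inside asimilar_search_with_scores, retrieval then reduces to one call:
    #   subgraph, subgraph_for_doc, text2gql_query = (
    #       await self._graph_retriever_router.retrieve(text)
    #   )
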
@@ -1,6 +1,7 @@
"""Graph Retriever."""

import logging
import os
from abc import ABC, abstractmethod
from typing import List, Optional, Union

@@ -31,39 +32,68 @@
from dbgpt.storage.knowledge_graph.graph_retriever.vector_based_graph_retriever import (
VectorBasedGraphRetriever,
)
from dbgpt.rag.transformer.keyword_extractor import KeywordExtractor

logger = logging.getLogger(__name__)


class GraphRetriever(GraphRetrieverBase):
"""Graph retriever class."""
class GraphRetrieverRouter:
"""Graph Retriever Router class."""

def __init__(
self,
triplet_graph_enabled,
document_graph_enabled,
triplet_topk,
document_topk,
keyword_extractor,
config,
enable_similarity_search,
embedding_fn,
embedding_batch_size,
similarity_search_topk,
similarity_search_score_threshold,
enable_text_search,
llm_client,
model_name,
graph_store_apdater,
):
self._triplet_graph_enabled = triplet_graph_enabled
self._document_graph_enabled = document_graph_enabled
self._keyword_extractor = keyword_extractor

self._triplet_graph_enabled = (
os.environ["TRIPLET_GRAPH_ENABLED"].lower() == "true"
if "TRIPLET_GRAPH_ENABLED" in os.environ
else config.triplet_graph_enabled
)
self._document_graph_enabled = (
os.environ["DOCUMENT_GRAPH_ENABLED"].lower() == "true"
if "DOCUMENT_GRAPH_ENABLED" in os.environ
else config.document_graph_enabled
)
triplet_topk = int(
os.getenv("KNOWLEDGE_GRAPH_EXTRACT_SEARCH_TOP_SIZE", config.extract_topk)
)
document_topk = int(
os.getenv(
"KNOWLEDGE_GRAPH_CHUNK_SEARCH_TOP_SIZE",
config.knowledge_graph_chunk_search_top_size,
)
)
llm_client = config.llm_client
model_name = config.model_name
self._enable_similarity_search = enable_similarity_search
self._text_embedder = TextEmbedder(embedding_fn)
self._embedding_batch_size = embedding_batch_size
self._embedding_batch_size = int(
os.getenv(
"KNOWLEDGE_GRAPH_EMBEDDING_BATCH_SIZE",
config.knowledge_graph_embedding_batch_size,
)
)
similarity_search_topk = int(
os.getenv(
"KNOWLEDGE_GRAPH_SIMILARITY_SEARCH_TOP_SIZE",
config.similarity_search_topk,
)
)
similarity_search_score_threshold = float(
os.getenv(
"KNOWLEDGE_GRAPH_EXTRACT_SEARCH_RECALL_SCORE",
config.extract_score_threshold,
)
)
self._enable_text_search = (
os.environ["TEXT_SEARCH_ENABLED"].lower() == "true"
if "TEXT_SEARCH_ENABLED" in os.environ
else config.enable_text_search
)

self._enable_text_search = enable_text_search
self._keyword_extractor = KeywordExtractor(llm_client, model_name)
self._text_embedder = TextEmbedder(config.embedding_fn)

self._keyword_based_graph_retriever = KeywordBasedGraphRetriever(
graph_store_apdater, triplet_topk
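
One property of the int()/float() casts above: os.getenv only falls back to the config value when the variable is unset, so an empty or non-numeric value raises ValueError at startup. A hedged defensive variant (not in the repo):

    import os

    def env_int(name: str, default: int) -> int:
        """Parse an integer env var, falling back to the default when unset or invalid."""
        raw = os.getenv(name)
        if raw is None or not raw.strip():
            return int(default)
        try:
            return int(raw)
        except ValueError:
            return int(default)

    # e.g. triplet_topk = env_int("KNOWLEDGE_GRAPH_EXTRACT_SEARCH_TOP_SIZE", config.extract_topk)
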
@@ -29,7 +29,7 @@ def __init__(self, graph_store_apdater, triplet_topk):
self._graph_store_apdater = graph_store_apdater
self._triplet_topk = triplet_topk

async def retrieve(self, keywords: List[str]) -> tuple[MemoryGraph, List[str]]:
async def retrieve(self, keywords: List[str]) -> MemoryGraph:
"""Retrieve from triplets graph with keywords."""

subgraph = self._graph_store_apdater.explore_trigraph(
@@ -40,7 +40,7 @@ def __init__(

async def retrieve(
self, vectors: List[List[float]]
) -> tuple[MemoryGraph, List[List[float]]]:
) -> MemoryGraph:
"""Retrieve from triplet graph with vectors."""

subgraph = self._graph_store_apdater.explore_trigraph(
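
Both low-level retrievers now return just the MemoryGraph; the keywords or vectors that were previously echoed back in a tuple are no longer part of the result. A usage sketch, assuming the retrievers were built as in the router hunk above:

    async def demo(keyword_retriever, vector_retriever, vectors):
        # Keyword path: explore the triplet graph from extracted keywords.
        subgraph = await keyword_retriever.retrieve(["knowledge graph", "retrieval"])

        # Vector path: explore the triplet graph from embedded query vectors.
        subgraph_by_vector = await vector_retriever.retrieve(vectors)

        # Each call returns a single MemoryGraph (no input echo).
        return subgraph.format(), subgraph_by_vector.format()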
