Skip to content

Commit

Permalink
cohere: refactor embedding methods and improve configuration handling…
Browse files Browse the repository at this point in the history
… for graph stores
  • Loading branch information
Appointat committed Jan 7, 2025
1 parent 38b48df commit dd8efc3
Show file tree
Hide file tree
Showing 10 changed files with 167 additions and 109 deletions.
7 changes: 3 additions & 4 deletions dbgpt/rag/transformer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
from abc import ABC, abstractmethod
from typing import Any, List, Optional
from typing import List, Optional

from tenacity import retry, stop_after_attempt, wait_fixed

Expand Down Expand Up @@ -30,11 +30,10 @@ def __init__(self, embedding_fn: Embeddings):
"""Initialize the Embedder."""
self._embedding_fn = embedding_fn

@abstractmethod
@retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
async def embed(self, input: Any) -> Any:
async def embed(self, text: str) -> List[float]:
"""Embed vector from text."""
return await self._embedding_fn.aembed_query(input)
return await self._embedding_fn.aembed_query(text=text)

@abstractmethod
async def batch_embed(
Expand Down
13 changes: 1 addition & 12 deletions dbgpt/rag/transformer/graph_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import logging
from typing import List

from dbgpt.core.interface.embeddings import Embeddings
from dbgpt.rag.transformer.base import EmbedderBase
from dbgpt.storage.graph_store.graph import Graph, GraphElemType

Expand All @@ -14,22 +13,13 @@
class GraphEmbedder(EmbedderBase):
"""GraphEmbedder class."""

def __init__(self, embedding_fn: Embeddings):
"""Initialize the GraphEmbedder."""
super().__init__(embedding_fn)

async def embed(self, input: str) -> List[float]:
"""Embed vector from text."""
return await super().embed(input)

async def batch_embed(
self,
inputs: List[Graph],
batch_size: int = 1,
) -> List[Graph]:
"""Embed graph from graphs in batches."""
for graph in inputs:

texts = []
vectors = []

Expand Down Expand Up @@ -62,8 +52,7 @@ async def batch_embed(
for idx, vector in enumerate(batch_results):
if isinstance(vector, Exception):
raise RuntimeError(f"Failed to embed text{idx}")
else:
vectors.append(vector)
vectors.append(vector)

# Push vectors back into Graph
for vertex, vector in zip(graph.vertices(), vectors):
Expand Down
12 changes: 1 addition & 11 deletions dbgpt/rag/transformer/text_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import logging
from typing import List

from dbgpt.core.interface.embeddings import Embeddings
from dbgpt.rag.transformer.base import EmbedderBase

logger = logging.getLogger(__name__)
Expand All @@ -13,14 +12,6 @@
class TextEmbedder(EmbedderBase):
"""TextEmbedder class."""

def __init__(self, embedding_fn: Embeddings):
"""Initialize the Embedder."""
super().__init__(embedding_fn)

async def embed(self, input: str) -> List[float]:
"""Embed vector from text."""
return await super().embed(input)

async def batch_embed(
self,
inputs: List[str],
Expand Down Expand Up @@ -48,8 +39,7 @@ async def batch_embed(
for idx, vector in enumerate(batch_results):
if isinstance(vector, Exception):
raise RuntimeError(f"Failed to embed text{idx}")
else:
vectors.append(vector)
vectors.append(vector)

return vectors

Expand Down
12 changes: 10 additions & 2 deletions dbgpt/storage/graph_store/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@ class GraphStoreConfig(BaseModel):
default=None,
description="The embedding function of graph store, optional.",
)
enable_summary: bool = Field(
default=False,
description="Enable graph community summary or not.",
)
enable_similarity_search: bool = Field(
default=False,
description="Enable similarity search or not.",
)


class GraphStoreBase(ABC):
Expand All @@ -32,8 +40,8 @@ def __init__(self, config: GraphStoreConfig):
"""Initialize graph store."""
self._config = config
self._conn = None
self.enable_summary = False
self.enable_similarity_search = True
self.enable_summary = config.enable_summary
self.enable_similarity_search = config.enable_similarity_search

@abstractmethod
def get_config(self) -> GraphStoreConfig:
Expand Down
5 changes: 3 additions & 2 deletions dbgpt/storage/graph_store/tugraph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class TuGraphStoreConfig(GraphStoreConfig):
),
)
enable_summary: bool = Field(
default=False,
default=True,
description="Enable graph community summary or not.",
)
enable_similarity_search: bool = Field(
Expand All @@ -89,7 +89,8 @@ def __init__(self, config: TuGraphStoreConfig) -> None:
self._password = os.getenv("TUGRAPH_PASSWORD", config.password)
self.enable_summary = (
os.getenv("GRAPH_COMMUNITY_SUMMARY_ENABLED", "").lower() == "true"
or config.enable_summary
if "GRAPH_COMMUNITY_SUMMARY_ENABLED" in os.environ
else config.enable_summary
)
self.enable_similarity_search = (
os.environ["SIMILARITY_SEARCH_ENABLED"].lower() == "true"
Expand Down
50 changes: 46 additions & 4 deletions dbgpt/storage/knowledge_graph/community/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import AsyncGenerator, Dict, Iterator, List, Literal, Optional, Union
from typing import AsyncGenerator, Dict, Iterator, List, Optional, Union

from dbgpt.storage.graph_store.base import GraphStoreBase
from dbgpt.storage.graph_store.graph import (
Expand Down Expand Up @@ -186,7 +186,20 @@ def explore_trigraph(
fan: Optional[int] = None,
limit: Optional[int] = None,
) -> MemoryGraph:
"""Explore the triple graph from given subjects up to a depth."""
"""Explore the graph from given subjects up to a depth.
Args:
subs (Union[List[str], List[List[float]]): The list of the subjects (keywords or embedding vectors).
topk (Optional[int]): The number of the top similar entities.
score_threshold (Optional[float]): The threshold of the similarity score.
direct (Direction): The direction of the graph that will be explored.
depth (int): The depth of the graph that will be explored.
fan (Optional[int]): Not used.
limit (Optional[int]): The limit number of the queried entities.
Returns:
MemoryGraph: The triplet graph that includes the entities and the relations.
"""

@abstractmethod
def explore_docgraph_with_entities(
Expand All @@ -199,7 +212,22 @@ def explore_docgraph_with_entities(
fan: Optional[int] = None,
limit: Optional[int] = None,
) -> MemoryGraph:
"""Explore the document graph after get entities from triple graph."""
"""Explore the graph from given subjects up to a depth.
Args:
subs (List[str]): The list of the entities.
topk (Optional[int]): The number of the top similar chunks.
score_threshold (Optional[float]): The threshold of the similarity score.
direct (Direction): The direction of the graph that will be explored.
depth (int): The depth of the graph that will be explored.
fan (Optional[int]): Not used.
limit (Optional[int]): The limit number of the queried chunks.
Returns:
MemoryGraph: The document graph that includes the leaf chunks that connect to the
entities, the chains from documents to the leaf chunks, and the chain
from documents to chunks.
"""

@abstractmethod
def explore_docgraph_without_entities(
Expand All @@ -212,7 +240,21 @@ def explore_docgraph_without_entities(
fan: Optional[int] = None,
limit: Optional[int] = None,
) -> MemoryGraph:
"""Explore the document graph only from given subjects up to a depth."""
"""Explore the graph from given subjects up to a depth.
Args:
subs (Union[List[str], List[List[float]]): The list of the subjects (keywords or embedding vectors).
topk (Optional[int]): The number of the top similar chunks.
score_threshold (Optional[float]): The threshold of the similarity score.
direct (Direction): The direction of the graph that will be explored.
depth (int): The depth of the graph that will be explored.
fan (Optional[int]): Not used.
limit (Optional[int]): The limit number of the queried chunks.
Returns:
MemoryGraph: The document graph that includes the chains from documents to chunks
that contain the subs (keywords) or similar chunks (embedding vectors).
"""

@abstractmethod
def query(self, query: str, **kwargs) -> MemoryGraph:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ class MemGraphStoreAdapter(GraphStoreAdapter):
def __init__(self, enable_summary: bool = False):
"""Initialize MemGraph Community Store Adapter."""
self._graph_store = MemoryGraphStore(MemoryGraphStoreConfig())
self.enable_summary = enable_summary

super().__init__(self._graph_store)

Expand Down
Loading

0 comments on commit dd8efc3

Please sign in to comment.