Skip to content

Commit

Permalink
feat(RAG):add rag operators and rag awel examples (#1061)
Browse files Browse the repository at this point in the history
Co-authored-by: csunny <[email protected]>
  • Loading branch information
Aries-ckt and csunny authored Jan 13, 2024
1 parent 99ea6ac commit a035433
Show file tree
Hide file tree
Showing 29 changed files with 1,010 additions and 102 deletions.
36 changes: 30 additions & 6 deletions dbgpt/rag/extractor/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

SUMMARY_PROMPT_TEMPLATE_ZH = """请根据提供的上下文信息的进行精简地总结:
{context}
答案尽量精确和简单,不要过长,长度控制在100字左右
答案尽量精确和简单,不要过长,长度控制在100字左右, 注意:请用<中文>来进行总结。
"""

SUMMARY_PROMPT_TEMPLATE_EN = """
Expand All @@ -18,6 +18,13 @@
the summary should be as concise as possible and not overly lengthy.Please keep the answer within approximately 200 characters.
"""

# Chinese refine-stage prompt. ``{context}`` is the only placeholder; it is
# filled in with ``str.format(context=...)`` before the prompt is sent to the
# LLM. The text asks for a final, numbered (1.2.3.) summary written in Chinese.
REFINE_SUMMARY_TEMPLATE_ZH = """我们已经提供了一个到某一点的现有总结:{context}\n 请根据你之前推理的内容进行最终的总结,总结回答的时候最好按照1.2.3.进行. 注意:请用<中文>来进行总结。"""

# English refine-stage prompt. Same single ``{context}`` placeholder as the
# Chinese variant; the text asks for a final conclusion organised as numbered
# points.
REFINE_SUMMARY_TEMPLATE_EN = """
We have provided an existing summary up to a certain point: {context}, We have the opportunity to refine the existing summary (only if needed) with some more context below.
\nBased on the previous reasoning, please summarize the final conclusion in accordance with points 1.2.and 3.
"""


class SummaryExtractor(Extractor):
"""Summary Extractor, it can extract document summary."""
Expand All @@ -41,6 +48,11 @@ def __init__(
if language == "en"
else SUMMARY_PROMPT_TEMPLATE_ZH
)
self._refine_prompt_template = (
REFINE_SUMMARY_TEMPLATE_EN
if language == "en"
else REFINE_SUMMARY_TEMPLATE_ZH
)
self._concurrency_limit_with_llm = concurrency_limit_with_llm
self._max_iteration_with_llm = max_iteration_with_llm
self._concurrency_limit_with_llm = concurrency_limit_with_llm
Expand All @@ -64,15 +76,23 @@ async def _aextract(self, chunks: List[Chunk]) -> str:
texts = [doc.content for doc in chunks]
from dbgpt.util.prompt_util import PromptHelper

# repack chunk into prompt to adapt llm model max context window
prompt_helper = PromptHelper()
texts = prompt_helper.repack(
prompt_template=self._prompt_template, text_chunks=texts
)
if len(texts) == 1:
summary_outs = await self._llm_run_tasks(chunk_texts=texts)
summary_outs = await self._llm_run_tasks(
chunk_texts=texts, prompt_template=self._refine_prompt_template
)
return summary_outs[0]
else:
return await self._mapreduce_extract_summary(docs=texts)
map_reduce_texts = await self._mapreduce_extract_summary(docs=texts)
summary_outs = await self._llm_run_tasks(
chunk_texts=[map_reduce_texts],
prompt_template=self._refine_prompt_template,
)
return summary_outs[0]

def _extract(self, chunks: List[Chunk]) -> str:
"""document extract summary
Expand All @@ -98,7 +118,8 @@ async def _mapreduce_extract_summary(
return docs[0]
else:
summary_outs = await self._llm_run_tasks(
chunk_texts=docs[0 : self._max_iteration_with_llm]
chunk_texts=docs[0 : self._max_iteration_with_llm],
prompt_template=self._prompt_template,
)
from dbgpt.util.prompt_util import PromptHelper

Expand All @@ -108,18 +129,21 @@ async def _mapreduce_extract_summary(
)
return await self._mapreduce_extract_summary(docs=summary_outs)

async def _llm_run_tasks(self, chunk_texts: List[str]) -> List[str]:
async def _llm_run_tasks(
self, chunk_texts: List[str], prompt_template: str
) -> List[str]:
"""llm run tasks
Args:
chunk_texts: List[str]
prompt_template: str
Returns:
summary_outs: List[str]
"""
tasks = []
for chunk_text in chunk_texts:
from dbgpt.core import ModelMessage

prompt = self._prompt_template.format(context=chunk_text)
prompt = prompt_template.format(context=chunk_text)
messages = [ModelMessage(role=ModelMessageRoleType.SYSTEM, content=prompt)]
request = ModelRequest(model=self._model_name, messages=messages)
tasks.append(self._llm_client.generate(request))
Expand Down
37 changes: 37 additions & 0 deletions dbgpt/rag/operator/db_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from typing import Any, Optional

from dbgpt.core.awel.task.base import IN
from dbgpt.core.interface.retriever import RetrieverOperator
from dbgpt.datasource.rdbms.base import RDBMSDatabase
from dbgpt.rag.retriever.db_schema import DBSchemaRetriever
from dbgpt.storage.vector_store.connector import VectorStoreConnector


class DBSchemaRetrieverOperator(RetrieverOperator[Any, Any]):
    """Retriever operator that delegates to a ``DBSchemaRetriever``.

    A single ``DBSchemaRetriever`` is built at construction time and every
    call to :meth:`retrieve` is forwarded to it unchanged.

    Args:
        top_k (int, optional): Passed to ``DBSchemaRetriever`` as ``top_k``.
            Defaults to 4.
        connection (RDBMSDatabase, optional): Passed to ``DBSchemaRetriever``
            as ``connection``. Defaults to None.
        vector_store_connector (VectorStoreConnector, optional): Passed to
            ``DBSchemaRetriever`` as ``vector_store_connector``.
            Defaults to None.
        **kwargs: Forwarded as-is to the ``RetrieverOperator`` constructor.
    """

    def __init__(
        self,
        top_k: int = 4,
        connection: Optional[RDBMSDatabase] = None,
        vector_store_connector: Optional[VectorStoreConnector] = None,
        **kwargs
    ):
        super().__init__(**kwargs)
        # Collect the retriever settings first so the construction call below
        # stays a one-liner.
        retriever_options = {
            "top_k": top_k,
            "connection": connection,
            "vector_store_connector": vector_store_connector,
        }
        self._retriever = DBSchemaRetriever(**retriever_options)

    def retrieve(self, query: IN) -> Any:
        """Retrieve table schemas for ``query``.

        Args:
            query (IN): The query handed to the wrapped retriever.

        Returns:
            Any: Whatever ``DBSchemaRetriever.retrieve`` returns for ``query``.
        """
        schemas = self._retriever.retrieve(query)
        return schemas
39 changes: 39 additions & 0 deletions dbgpt/rag/operator/embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from functools import reduce
from typing import Any, Optional

from dbgpt.core.awel.task.base import IN
from dbgpt.core.interface.retriever import RetrieverOperator
from dbgpt.rag.retriever.embedding import EmbeddingRetriever
from dbgpt.rag.retriever.rerank import Ranker
from dbgpt.rag.retriever.rewrite import QueryRewrite
from dbgpt.storage.vector_store.connector import VectorStoreConnector


class EmbeddingRetrieverOperator(RetrieverOperator[Any, Any]):
    """Retriever operator that delegates to an ``EmbeddingRetriever``.

    Accepts either a single query string or a list of query strings and
    returns the score-filtered candidates produced by
    ``EmbeddingRetriever.retrieve_with_scores``.

    Args:
        top_k (int): Passed to ``EmbeddingRetriever`` as ``top_k``.
        score_threshold (float, optional): Threshold handed to
            ``retrieve_with_scores`` on every call. Defaults to 0.3.
        query_rewrite (QueryRewrite, optional): Passed to
            ``EmbeddingRetriever`` as ``query_rewrite``. Defaults to None.
        rerank (Ranker, optional): Passed to ``EmbeddingRetriever`` as
            ``rerank``. Defaults to None.
        vector_store_connector (VectorStoreConnector, optional): Passed to
            ``EmbeddingRetriever`` as ``vector_store_connector``.
            Defaults to None.
        **kwargs: Forwarded as-is to the ``RetrieverOperator`` constructor.
    """

    def __init__(
        self,
        top_k: int,
        score_threshold: Optional[float] = 0.3,
        query_rewrite: Optional[QueryRewrite] = None,
        rerank: Optional[Ranker] = None,
        vector_store_connector: Optional[VectorStoreConnector] = None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self._score_threshold = score_threshold
        self._retriever = EmbeddingRetriever(
            top_k=top_k,
            query_rewrite=query_rewrite,
            rerank=rerank,
            vector_store_connector=vector_store_connector,
        )

    def retrieve(self, query: IN) -> Any:
        """Retrieve candidates for one query or for a list of queries.

        Args:
            query (IN): A ``str`` query, or a ``list`` of queries.

        Returns:
            Any: For a ``str``, the result of ``retrieve_with_scores``. For a
            ``list``, the per-query results concatenated with ``+`` in input
            order (an empty list when no queries are given). ``None`` for any
            other input type.
        """
        if isinstance(query, str):
            return self._retriever.retrieve_with_scores(query, self._score_threshold)
        if isinstance(query, list):
            candidates = [
                self._retriever.retrieve_with_scores(q, self._score_threshold)
                for q in query
            ]
            # reduce() without an initializer raises TypeError on an empty
            # sequence, so an empty query list is answered directly.
            if not candidates:
                return []
            return reduce(lambda x, y: x + y, candidates)
        # NOTE(review): unsupported input types have always fallen through to
        # None here; kept for backward compatibility, but callers most likely
        # want an explicit error instead — confirm before tightening.
        return None
26 changes: 26 additions & 0 deletions dbgpt/rag/operator/knowledge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from typing import Any, List, Optional

from dbgpt.core.awel import MapOperator
from dbgpt.core.awel.task.base import IN
from dbgpt.rag.knowledge.base import KnowledgeType, Knowledge
from dbgpt.rag.knowledge.factory import KnowledgeFactory


class KnowledgeOperator(MapOperator[Any, Any]):
    """Map operator that builds a ``Knowledge`` object from a datasource."""

    def __init__(
        self, knowledge_type: Optional[KnowledgeType] = KnowledgeType.DOCUMENT, **kwargs
    ):
        """Create the knowledge operator.

        Args:
            knowledge_type (Optional[KnowledgeType]): The knowledge type handed
                to ``KnowledgeFactory.create`` on every ``map`` call.
                Defaults to ``KnowledgeType.DOCUMENT``.
            **kwargs: Forwarded as-is to the ``MapOperator`` constructor.
        """
        super().__init__(**kwargs)
        self._knowledge_type = knowledge_type

    async def map(self, datasource: IN) -> Knowledge:
        """Create the knowledge for ``datasource``.

        ``KnowledgeFactory.create`` is a plain (non-async) callable, so it is
        invoked through ``blocking_func_to_async`` with the datasource and the
        configured knowledge type as positional arguments.

        Args:
            datasource (IN): The datasource passed to the factory.

        Returns:
            Knowledge: The object returned by ``KnowledgeFactory.create``.
        """
        factory = KnowledgeFactory.create
        knowledge = await self.blocking_func_to_async(
            factory, datasource, self._knowledge_type
        )
        return knowledge
43 changes: 43 additions & 0 deletions dbgpt/rag/operator/rerank.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from typing import Any, Optional, List

from dbgpt.core import LLMClient
from dbgpt.core.awel import MapOperator
from dbgpt.core.awel.task.base import IN
from dbgpt.rag.chunk import Chunk
from dbgpt.rag.retriever.rerank import DefaultRanker
from dbgpt.rag.retriever.rewrite import QueryRewrite


class RerankOperator(MapOperator[Any, Any]):
    """Map operator that reranks retrieved candidates with a ``DefaultRanker``."""

    def __init__(
        self,
        topk: Optional[int] = 3,
        algorithm: Optional[str] = "default",
        rank_fn: Optional[callable] = None,
        **kwargs
    ):
        """Create the rerank operator.

        Args:
            topk (Optional[int]): Passed to ``DefaultRanker`` as ``topk``.
                Defaults to 3.
            algorithm (Optional[str]): The rerank algorithm name. It is stored
                on the operator but not consulted anywhere in this class.
                Defaults to "default".
            rank_fn (Optional[callable]): Passed to ``DefaultRanker`` as
                ``rank_fn``. Defaults to None.
            **kwargs: Forwarded as-is to the ``MapOperator`` constructor.
        """
        super().__init__(**kwargs)
        self._algorithm = algorithm
        ranker_options = {"topk": topk, "rank_fn": rank_fn}
        self._rerank = DefaultRanker(**ranker_options)

    async def map(self, candidates_with_scores: IN) -> List[Chunk]:
        """Rerank the candidates.

        The ranker's ``rank`` method is a plain (non-async) callable, so it is
        invoked through ``blocking_func_to_async``.

        Args:
            candidates_with_scores (IN): The candidates with scores.

        Returns:
            List[Chunk]: The result of ``DefaultRanker.rank``.
        """
        rank = self._rerank.rank
        reranked = await self.blocking_func_to_async(rank, candidates_with_scores)
        return reranked
41 changes: 41 additions & 0 deletions dbgpt/rag/operator/rewrite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from typing import Any, Optional, List

from dbgpt.core import LLMClient
from dbgpt.core.awel import MapOperator
from dbgpt.core.awel.task.base import IN
from dbgpt.rag.retriever.rewrite import QueryRewrite


class QueryRewriteOperator(MapOperator[Any, Any]):
    """Map operator that rewrites a query through a ``QueryRewrite`` helper."""

    def __init__(
        self,
        llm_client: Optional[LLMClient],
        model_name: Optional[str] = None,
        language: Optional[str] = "en",
        nums: Optional[int] = 1,
        **kwargs
    ):
        """Create the query rewrite operator.

        Args:
            llm_client (Optional[LLMClient]): The LLM client handed to
                ``QueryRewrite``.
            model_name (Optional[str]): The model name handed to
                ``QueryRewrite``. Defaults to None.
            language (Optional[str]): The prompt language handed to
                ``QueryRewrite``. Defaults to "en".
            nums (Optional[int]): The number of rewrite results requested on
                each ``map`` call. Defaults to 1.
            **kwargs: Forwarded as-is to the ``MapOperator`` constructor.
        """
        super().__init__(**kwargs)
        self._nums = nums
        rewrite_options = {
            "llm_client": llm_client,
            "model_name": model_name,
            "language": language,
        }
        self._rewrite = QueryRewrite(**rewrite_options)

    async def map(self, query_context: IN) -> List[str]:
        """Rewrite the query.

        Args:
            query_context (IN): A mapping read with ``.get``; the "query" and
                "context" entries are used (each is ``None`` when absent).

        Returns:
            List[str]: The result of ``QueryRewrite.rewrite``.
        """
        origin_query = query_context.get("query")
        rewrite_context = query_context.get("context")
        rewritten = await self._rewrite.rewrite(
            origin_query=origin_query, context=rewrite_context, nums=self._nums
        )
        return rewritten
49 changes: 49 additions & 0 deletions dbgpt/rag/operator/summary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from typing import Any, Optional

from dbgpt.core import LLMClient
from dbgpt.core.awel.task.base import IN
from dbgpt.serve.rag.assembler.summary import SummaryAssembler
from dbgpt.serve.rag.operators.base import AssemblerOperator


class SummaryAssemblerOperator(AssemblerOperator[Any, Any]):
    """Assembler operator that produces a summary via ``SummaryAssembler``."""

    def __init__(
        self,
        llm_client: Optional[LLMClient],
        model_name: Optional[str] = "gpt-3.5-turbo",
        language: Optional[str] = "en",
        max_iteration_with_llm: Optional[int] = 5,
        concurrency_limit_with_llm: Optional[int] = 3,
        **kwargs
    ):
        """Create the summary assembler operator.

        The arguments are only stored here; they are handed to
        ``SummaryAssembler.load_from_knowledge`` on each ``map`` call.

        Args:
            llm_client: (Optional[LLMClient]) The LLM client.
            model_name: (Optional[str]) The model name.
                Defaults to "gpt-3.5-turbo".
            language: (Optional[str]) The prompt language. Defaults to "en".
            max_iteration_with_llm: (Optional[int]) The max iteration with llm.
                Defaults to 5.
            concurrency_limit_with_llm: (Optional[int]) The concurrency limit
                with llm. Defaults to 3.
            **kwargs: Forwarded as-is to the ``AssemblerOperator`` constructor.
        """
        super().__init__(**kwargs)
        self._llm_client = llm_client
        self._model_name = model_name
        self._language = language
        self._max_iteration_with_llm = max_iteration_with_llm
        self._concurrency_limit_with_llm = concurrency_limit_with_llm

    async def map(self, knowledge: IN) -> Any:
        """Assemble the summary for ``knowledge``.

        A new ``SummaryAssembler`` is loaded from the knowledge on every call,
        using the settings captured by the constructor.

        Args:
            knowledge (IN): The knowledge passed to ``load_from_knowledge``.

        Returns:
            Any: The awaited result of ``SummaryAssembler.generate_summary``.
        """
        assembler_options = {
            "knowledge": knowledge,
            "llm_client": self._llm_client,
            "model_name": self._model_name,
            "language": self._language,
            "max_iteration_with_llm": self._max_iteration_with_llm,
            "concurrency_limit_with_llm": self._concurrency_limit_with_llm,
        }
        summary_assembler = SummaryAssembler.load_from_knowledge(**assembler_options)
        summary = await summary_assembler.generate_summary()
        return summary

    def assemble(self, knowledge: IN) -> Any:
        """Assemble knowledge for the input value.

        Intentionally a no-op in this operator: it does nothing and returns
        ``None``.
        """
        return None
Loading

0 comments on commit a035433

Please sign in to comment.