diff --git a/dbgpt/rag/extractor/summary.py b/dbgpt/rag/extractor/summary.py index 936240b15..97977d61f 100644 --- a/dbgpt/rag/extractor/summary.py +++ b/dbgpt/rag/extractor/summary.py @@ -9,7 +9,7 @@ SUMMARY_PROMPT_TEMPLATE_ZH = """请根据提供的上下文信息的进行精简地总结: {context} -答案尽量精确和简单,不要过长,长度控制在100字左右 +答案尽量精确和简单,不要过长,长度控制在100字左右, 注意:请用<中文>来进行总结。 """ SUMMARY_PROMPT_TEMPLATE_EN = """ @@ -18,6 +18,13 @@ the summary should be as concise as possible and not overly lengthy.Please keep the answer within approximately 200 characters. """ +REFINE_SUMMARY_TEMPLATE_ZH = """我们已经提供了一个到某一点的现有总结:{context}\n 请根据你之前推理的内容进行最终的总结,总结回答的时候最好按照1.2.3.进行. 注意:请用<中文>来进行总结。""" + +REFINE_SUMMARY_TEMPLATE_EN = """ +We have provided an existing summary up to a certain point: {context}, We have the opportunity to refine the existing summary (only if needed) with some more context below. +\nBased on the previous reasoning, please summarize the final conclusion in accordance with points 1.2.and 3. +""" + class SummaryExtractor(Extractor): """Summary Extractor, it can extract document summary.""" @@ -41,6 +48,11 @@ def __init__( if language == "en" else SUMMARY_PROMPT_TEMPLATE_ZH ) + self._refine_prompt_template = ( + REFINE_SUMMARY_TEMPLATE_EN + if language == "en" + else REFINE_SUMMARY_TEMPLATE_ZH + ) self._concurrency_limit_with_llm = concurrency_limit_with_llm self._max_iteration_with_llm = max_iteration_with_llm self._concurrency_limit_with_llm = concurrency_limit_with_llm @@ -64,15 +76,23 @@ async def _aextract(self, chunks: List[Chunk]) -> str: texts = [doc.content for doc in chunks] from dbgpt.util.prompt_util import PromptHelper + # repack chunk into prompt to adapt llm model max context window prompt_helper = PromptHelper() texts = prompt_helper.repack( prompt_template=self._prompt_template, text_chunks=texts ) if len(texts) == 1: - summary_outs = await self._llm_run_tasks(chunk_texts=texts) + summary_outs = await self._llm_run_tasks( + chunk_texts=texts, prompt_template=self._refine_prompt_template 
+ ) return summary_outs[0] else: - return await self._mapreduce_extract_summary(docs=texts) + map_reduce_texts = await self._mapreduce_extract_summary(docs=texts) + summary_outs = await self._llm_run_tasks( + chunk_texts=[map_reduce_texts], + prompt_template=self._refine_prompt_template, + ) + return summary_outs[0] def _extract(self, chunks: List[Chunk]) -> str: """document extract summary @@ -98,7 +118,8 @@ async def _mapreduce_extract_summary( return docs[0] else: summary_outs = await self._llm_run_tasks( - chunk_texts=docs[0 : self._max_iteration_with_llm] + chunk_texts=docs[0 : self._max_iteration_with_llm], + prompt_template=self._prompt_template, ) from dbgpt.util.prompt_util import PromptHelper @@ -108,10 +129,13 @@ async def _mapreduce_extract_summary( ) return await self._mapreduce_extract_summary(docs=summary_outs) - async def _llm_run_tasks(self, chunk_texts: List[str]) -> List[str]: + async def _llm_run_tasks( + self, chunk_texts: List[str], prompt_template: str + ) -> List[str]: """llm run tasks Args: chunk_texts: List[str] + prompt_template: str Returns: summary_outs: List[str] """ @@ -119,7 +143,7 @@ async def _llm_run_tasks(self, chunk_texts: List[str]) -> List[str]: for chunk_text in chunk_texts: from dbgpt.core import ModelMessage - prompt = self._prompt_template.format(context=chunk_text) + prompt = prompt_template.format(context=chunk_text) messages = [ModelMessage(role=ModelMessageRoleType.SYSTEM, content=prompt)] request = ModelRequest(model=self._model_name, messages=messages) tasks.append(self._llm_client.generate(request)) diff --git a/dbgpt/rag/operator/db_schema.py b/dbgpt/rag/operator/db_schema.py new file mode 100644 index 000000000..988b1674a --- /dev/null +++ b/dbgpt/rag/operator/db_schema.py @@ -0,0 +1,37 @@ +from typing import Any, Optional + +from dbgpt.core.awel.task.base import IN +from dbgpt.core.interface.retriever import RetrieverOperator +from dbgpt.datasource.rdbms.base import RDBMSDatabase +from 
dbgpt.rag.retriever.db_schema import DBSchemaRetriever +from dbgpt.storage.vector_store.connector import VectorStoreConnector + + +class DBSchemaRetrieverOperator(RetrieverOperator[Any, Any]): + """The DBSchema Retriever Operator. + Args: + connection (RDBMSDatabase): The connection. + top_k (int, optional): The top k. Defaults to 4. + vector_store_connector (VectorStoreConnector, optional): The vector store connector. Defaults to None. + """ + + def __init__( + self, + top_k: int = 4, + connection: Optional[RDBMSDatabase] = None, + vector_store_connector: Optional[VectorStoreConnector] = None, + **kwargs + ): + super().__init__(**kwargs) + self._retriever = DBSchemaRetriever( + top_k=top_k, + connection=connection, + vector_store_connector=vector_store_connector, + ) + + def retrieve(self, query: IN) -> Any: + """retrieve table schemas. + Args: + query (IN): query. + """ + return self._retriever.retrieve(query) diff --git a/dbgpt/rag/operator/embedding.py b/dbgpt/rag/operator/embedding.py new file mode 100644 index 000000000..99e3ab341 --- /dev/null +++ b/dbgpt/rag/operator/embedding.py @@ -0,0 +1,39 @@ +from functools import reduce +from typing import Any, Optional + +from dbgpt.core.awel.task.base import IN +from dbgpt.core.interface.retriever import RetrieverOperator +from dbgpt.rag.retriever.embedding import EmbeddingRetriever +from dbgpt.rag.retriever.rerank import Ranker +from dbgpt.rag.retriever.rewrite import QueryRewrite +from dbgpt.storage.vector_store.connector import VectorStoreConnector + + +class EmbeddingRetrieverOperator(RetrieverOperator[Any, Any]): + def __init__( + self, + top_k: int, + score_threshold: Optional[float] = 0.3, + query_rewrite: Optional[QueryRewrite] = None, + rerank: Ranker = None, + vector_store_connector: VectorStoreConnector = None, + **kwargs + ): + super().__init__(**kwargs) + self._score_threshold = score_threshold + self._retriever = EmbeddingRetriever( + top_k=top_k, + query_rewrite=query_rewrite, + rerank=rerank, + 
vector_store_connector=vector_store_connector, + ) + + def retrieve(self, query: IN) -> Any: + if isinstance(query, str): + return self._retriever.retrieve_with_scores(query, self._score_threshold) + elif isinstance(query, list): + candidates = [ + self._retriever.retrieve_with_scores(q, self._score_threshold) + for q in query + ] + return reduce(lambda x, y: x + y, candidates) diff --git a/dbgpt/rag/operator/knowledge.py b/dbgpt/rag/operator/knowledge.py new file mode 100644 index 000000000..01869a3a4 --- /dev/null +++ b/dbgpt/rag/operator/knowledge.py @@ -0,0 +1,26 @@ +from typing import Any, List, Optional + +from dbgpt.core.awel import MapOperator +from dbgpt.core.awel.task.base import IN +from dbgpt.rag.knowledge.base import KnowledgeType, Knowledge +from dbgpt.rag.knowledge.factory import KnowledgeFactory + + +class KnowledgeOperator(MapOperator[Any, Any]): + """Knowledge Operator.""" + + def __init__( + self, knowledge_type: Optional[KnowledgeType] = KnowledgeType.DOCUMENT, **kwargs + ): + """Init the query rewrite operator. + Args: + knowledge_type: (Optional[KnowledgeType]) The knowledge type. 
+ """ + super().__init__(**kwargs) + self._knowledge_type = knowledge_type + + async def map(self, datasource: IN) -> Knowledge: + """knowledge operator.""" + return await self.blocking_func_to_async( + KnowledgeFactory.create, datasource, self._knowledge_type + ) diff --git a/dbgpt/rag/operator/rerank.py b/dbgpt/rag/operator/rerank.py new file mode 100644 index 000000000..1641e4744 --- /dev/null +++ b/dbgpt/rag/operator/rerank.py @@ -0,0 +1,43 @@ +from typing import Any, Optional, List + +from dbgpt.core import LLMClient +from dbgpt.core.awel import MapOperator +from dbgpt.core.awel.task.base import IN +from dbgpt.rag.chunk import Chunk +from dbgpt.rag.retriever.rerank import DefaultRanker +from dbgpt.rag.retriever.rewrite import QueryRewrite + + +class RerankOperator(MapOperator[Any, Any]): + """The Rewrite Operator.""" + + def __init__( + self, + topk: Optional[int] = 3, + algorithm: Optional[str] = "default", + rank_fn: Optional[callable] = None, + **kwargs + ): + """Init the query rewrite operator. + Args: + topk (int): The number of the candidates. + algorithm (Optional[str]): The rerank algorithm name. + rank_fn (Optional[callable]): The rank function. + """ + super().__init__(**kwargs) + self._algorithm = algorithm + self._rerank = DefaultRanker( + topk=topk, + rank_fn=rank_fn, + ) + + async def map(self, candidates_with_scores: IN) -> List[Chunk]: + """rerank the candidates. + Args: + candidates_with_scores (IN): The candidates with scores. + Returns: + List[Chunk]: The reranked candidates. 
+ """ + return await self.blocking_func_to_async( + self._rerank.rank, candidates_with_scores + ) diff --git a/dbgpt/rag/operator/rewrite.py b/dbgpt/rag/operator/rewrite.py new file mode 100644 index 000000000..9d63b3540 --- /dev/null +++ b/dbgpt/rag/operator/rewrite.py @@ -0,0 +1,41 @@ +from typing import Any, Optional, List + +from dbgpt.core import LLMClient +from dbgpt.core.awel import MapOperator +from dbgpt.core.awel.task.base import IN +from dbgpt.rag.retriever.rewrite import QueryRewrite + + +class QueryRewriteOperator(MapOperator[Any, Any]): + """The Rewrite Operator.""" + + def __init__( + self, + llm_client: Optional[LLMClient], + model_name: Optional[str] = None, + language: Optional[str] = "en", + nums: Optional[int] = 1, + **kwargs + ): + """Init the query rewrite operator. + Args: + llm_client (Optional[LLMClient]): The LLM client. + model_name (Optional[str]): The model name. + language (Optional[str]): The prompt language. + nums (Optional[int]): The number of the rewrite results. 
+ """ + super().__init__(**kwargs) + self._nums = nums + self._rewrite = QueryRewrite( + llm_client=llm_client, + model_name=model_name, + language=language, + ) + + async def map(self, query_context: IN) -> List[str]: + """Rewrite the query.""" + query = query_context.get("query") + context = query_context.get("context") + return await self._rewrite.rewrite( + origin_query=query, context=context, nums=self._nums + ) diff --git a/dbgpt/rag/operator/summary.py b/dbgpt/rag/operator/summary.py new file mode 100644 index 000000000..fefee07fc --- /dev/null +++ b/dbgpt/rag/operator/summary.py @@ -0,0 +1,49 @@ +from typing import Any, Optional + +from dbgpt.core import LLMClient +from dbgpt.core.awel.task.base import IN +from dbgpt.serve.rag.assembler.summary import SummaryAssembler +from dbgpt.serve.rag.operators.base import AssemblerOperator + + +class SummaryAssemblerOperator(AssemblerOperator[Any, Any]): + def __init__( + self, + llm_client: Optional[LLMClient], + model_name: Optional[str] = "gpt-3.5-turbo", + language: Optional[str] = "en", + max_iteration_with_llm: Optional[int] = 5, + concurrency_limit_with_llm: Optional[int] = 3, + **kwargs + ): + """ + Init the summary assemble operator. + Args: + llm_client: (Optional[LLMClient]) The LLM client. + model_name: (Optional[str]) The model name. + language: (Optional[str]) The prompt language. + max_iteration_with_llm: (Optional[int]) The max iteration with llm. + concurrency_limit_with_llm: (Optional[int]) The concurrency limit with llm. 
+ """ + super().__init__(**kwargs) + self._llm_client = llm_client + self._model_name = model_name + self._language = language + self._max_iteration_with_llm = max_iteration_with_llm + self._concurrency_limit_with_llm = concurrency_limit_with_llm + + async def map(self, knowledge: IN) -> Any: + """Assemble the summary.""" + assembler = SummaryAssembler.load_from_knowledge( + knowledge=knowledge, + llm_client=self._llm_client, + model_name=self._model_name, + language=self._language, + max_iteration_with_llm=self._max_iteration_with_llm, + concurrency_limit_with_llm=self._concurrency_limit_with_llm, + ) + return await assembler.generate_summary() + + def assemble(self, knowledge: IN) -> Any: + """assemble knowledge for input value.""" + pass diff --git a/dbgpt/rag/retriever/db_struct.py b/dbgpt/rag/retriever/db_schema.py similarity index 90% rename from dbgpt/rag/retriever/db_struct.py rename to dbgpt/rag/retriever/db_schema.py index ad1c84e3f..8e09ea85f 100644 --- a/dbgpt/rag/retriever/db_struct.py +++ b/dbgpt/rag/retriever/db_schema.py @@ -1,6 +1,7 @@ from functools import reduce from typing import List, Optional +from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary from dbgpt.util.chat_util import run_async_tasks from dbgpt.datasource.rdbms.base import RDBMSDatabase from dbgpt.rag.chunk import Chunk @@ -9,14 +10,13 @@ from dbgpt.storage.vector_store.connector import VectorStoreConnector -class DBStructRetriever(BaseRetriever): - """DBStruct retriever.""" +class DBSchemaRetriever(BaseRetriever): + """DBSchema retriever.""" def __init__( self, top_k: int = 4, connection: Optional[RDBMSDatabase] = None, - is_embeddings: bool = True, query_rewrite: bool = False, rerank: Ranker = None, vector_store_connector: Optional[VectorStoreConnector] = None, @@ -26,14 +26,13 @@ def __init__( Args: top_k (int): top k connection (Optional[RDBMSDatabase]): RDBMSDatabase connection. 
- is_embeddings (bool): Whether to query by embeddings in the vector store, Defaults to True. query_rewrite (bool): query rewrite rerank (Ranker): rerank vector_store_connector (VectorStoreConnector): vector store connector code example: .. code-block:: python >>> from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect - >>> from dbgpt.serve.rag.assembler.db_struct import DBStructAssembler + >>> from dbgpt.serve.rag.assembler.db_schema import DBSchemaAssembler >>> from dbgpt.storage.vector_store.connector import VectorStoreConnector >>> from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig >>> from dbgpt.rag.retriever.embedding import EmbeddingRetriever @@ -71,16 +70,18 @@ def _create_temporary_connection(): embedding_fn=embedding_fn ) # get db struct retriever - retriever = DBStructRetriever(top_k=3, vector_store_connector=vector_connector) + retriever = DBSchemaRetriever(top_k=3, vector_store_connector=vector_connector) chunks = retriever.retrieve("show columns from table") print(f"db struct rag example results:{[chunk.content for chunk in chunks]}") """ self._top_k = top_k - self._is_embeddings = is_embeddings self._connection = connection self._query_rewrite = query_rewrite self._vector_store_connector = vector_store_connector + self._need_embeddings = False + if self._vector_store_connector: + self._need_embeddings = True self._rerank = rerank or DefaultRanker(self._top_k) def _retrieve(self, query: str) -> List[Chunk]: @@ -88,7 +89,7 @@ def _retrieve(self, query: str) -> List[Chunk]: Args: query (str): query text """ - if self._is_embeddings: + if self._need_embeddings: queries = [query] candidates = [ self._vector_store_connector.similar_search(query, self._top_k) @@ -97,8 +98,6 @@ def _retrieve(self, query: str) -> List[Chunk]: candidates = reduce(lambda x, y: x + y, candidates) return candidates else: - from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary - table_summaries = _parse_db_summary(self._connection) return 
[Chunk(content=table_summary) for table_summary in table_summaries] @@ -115,7 +114,7 @@ async def _aretrieve(self, query: str) -> List[Chunk]: Args: query (str): query text """ - if self._is_embeddings: + if self._need_embeddings: queries = [query] candidates = [self._similarity_search(query) for query in queries] candidates = await run_async_tasks(tasks=candidates, concurrency_limit=1) @@ -145,7 +144,7 @@ async def _similarity_search(self, query) -> List[Chunk]: self._top_k, ) - async def _aparse_db_summary(self) -> List[Chunk]: + async def _aparse_db_summary(self) -> List[str]: """Similar search.""" from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary diff --git a/dbgpt/rag/retriever/embedding.py b/dbgpt/rag/retriever/embedding.py index 48b738d22..d7607e506 100644 --- a/dbgpt/rag/retriever/embedding.py +++ b/dbgpt/rag/retriever/embedding.py @@ -99,7 +99,12 @@ async def _aretrieve(self, query: str) -> List[Chunk]: """ queries = [query] if self._query_rewrite: - new_queries = await self._query_rewrite.rewrite(origin_query=query, nums=1) + candidates_tasks = [self._similarity_search(query) for query in queries] + chunks = await self._run_async_tasks(candidates_tasks) + context = "\n".join([chunk.content for chunk in chunks]) + new_queries = await self._query_rewrite.rewrite( + origin_query=query, context=context, nums=1 + ) queries.extend(new_queries) candidates = [self._similarity_search(query) for query in queries] candidates = await run_async_tasks(tasks=candidates, concurrency_limit=1) @@ -117,7 +122,12 @@ async def _aretrieve_with_score( """ queries = [query] if self._query_rewrite: - new_queries = await self._query_rewrite.rewrite(origin_query=query, nums=1) + candidates_tasks = [self._similarity_search(query) for query in queries] + chunks = await self._run_async_tasks(candidates_tasks) + context = "\n".join([chunk.content for chunk in chunks]) + new_queries = await self._query_rewrite.rewrite( + origin_query=query, context=context, nums=1 + ) 
queries.extend(new_queries) candidates_with_score = [ self._similarity_search_with_score(query, score_threshold) @@ -137,6 +147,12 @@ async def _similarity_search(self, query) -> List[Chunk]: self._top_k, ) + async def _run_async_tasks(self, tasks) -> List[Chunk]: + """Run async tasks.""" + candidates = await run_async_tasks(tasks=tasks, concurrency_limit=1) + candidates = reduce(lambda x, y: x + y, candidates) + return candidates + async def _similarity_search_with_score( self, query, score_threshold ) -> List[Chunk]: diff --git a/dbgpt/rag/retriever/rewrite.py b/dbgpt/rag/retriever/rewrite.py index fdee0046e..82b460647 100644 --- a/dbgpt/rag/retriever/rewrite.py +++ b/dbgpt/rag/retriever/rewrite.py @@ -2,14 +2,14 @@ from dbgpt.core import LLMClient, ModelMessage, ModelRequest, ModelMessageRoleType REWRITE_PROMPT_TEMPLATE_EN = """ -Generate {nums} search queries related to: {original_query}, Provide following comma-separated format: 'queries: '\n": - "original query:: {original_query}\n" - "queries:\n" +Based on the given context {context}, Generate {nums} search queries related to: {original_query}, Provide following comma-separated format: 'queries: '": + "original query:{original_query}\n" + "queries:" """ -REWRITE_PROMPT_TEMPLATE_ZH = """请根据原问题优化生成{nums}个相关的搜索查询,这些查询应与原始查询相似并且是人们可能会提出的可回答的搜索问题。请勿使用任何示例中提到的内容,确保所有生成的查询均独立于示例,仅基于提供的原始查询。请按照以下逗号分隔的格式提供: 'queries:': -"original_query:{original_query}\n" -"queries:\n" +REWRITE_PROMPT_TEMPLATE_ZH = """请根据上下文{context}, 将原问题优化生成{nums}个相关的搜索查询,这些查询应与原始查询相似并且是人们可能会提出的可回答的搜索问题。请勿使用任何示例中提到的内容,确保所有生成的查询均独立于示例,仅基于提供的原始查询。请按照以下逗号分隔的格式提供: 'queries:' +"original_query:{original_query}\n" +"queries:" """ @@ -29,6 +29,7 @@ def __init__( - query: (str), user query - model_name: (str), llm model name - llm_client: (Optional[LLMClient]) + - language: (Optional[str]), language """ self._model_name = model_name self._llm_client = llm_client @@ -39,17 +40,22 @@ def __init__( else REWRITE_PROMPT_TEMPLATE_ZH ) - async def rewrite(self, 
origin_query: str, nums: Optional[int] = 1) -> List[str]: + async def rewrite( + self, origin_query: str, context: Optional[str], nums: Optional[int] = 1 + ) -> List[str]: """query rewrite Args: origin_query: str original query + context: Optional[str] context nums: Optional[int] rewrite nums Returns: queries: List[str] """ from dbgpt.util.chat_util import run_async_tasks - prompt = self._prompt_template.format(original_query=origin_query, nums=nums) + prompt = self._prompt_template.format( + context=context, original_query=origin_query, nums=nums + ) messages = [ModelMessage(role=ModelMessageRoleType.SYSTEM, content=prompt)] request = ModelRequest(model=self._model_name, messages=messages) tasks = [self._llm_client.generate(request)] @@ -61,8 +67,12 @@ async def rewrite(self, origin_query: str, nums: Optional[int] = 1) -> List[str] queries, ) ) - print("rewrite queries:", queries) - return self._parse_llm_output(output=queries[0]) + if len(queries) == 0: + print("llm generate no rewrite queries.") + return queries + new_queries = self._parse_llm_output(output=queries[0])[0:nums] + print(f"rewrite queries: {new_queries}") + return new_queries def correct(self) -> List[str]: pass @@ -81,6 +91,8 @@ def _parse_llm_output(self, output: str) -> List[str]: if response.startswith("queries:"): response = response[len("queries:") :] + if response.startswith("queries:"): + response = response[len("queries:") :] queries = response.split(",") if len(queries) == 1: @@ -90,6 +102,10 @@ def _parse_llm_output(self, output: str) -> List[str]: if len(queries) == 1: queries = response.split("?") for k in queries: + if k.startswith("queries:"): + k = k[len("queries:") :] + if k.startswith("queries:"): + k = response[len("queries:") :] rk = k if lowercase: rk = rk.lower() diff --git a/dbgpt/rag/retriever/tests/test_db_struct.py b/dbgpt/rag/retriever/tests/test_db_struct.py index 0596c15fe..842aac037 100644 --- a/dbgpt/rag/retriever/tests/test_db_struct.py +++ 
b/dbgpt/rag/retriever/tests/test_db_struct.py @@ -4,7 +4,7 @@ import dbgpt from dbgpt.rag.chunk import Chunk -from dbgpt.rag.retriever.db_struct import DBStructRetriever +from dbgpt.rag.retriever.db_schema import DBSchemaRetriever from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary @@ -22,7 +22,7 @@ def mock_vector_store_connector(): @pytest.fixture def dbstruct_retriever(mock_db_connection, mock_vector_store_connector): - return DBStructRetriever( + return DBSchemaRetriever( connection=mock_db_connection, vector_store_connector=mock_vector_store_connector, ) diff --git a/dbgpt/rag/summary/db_summary_client.py b/dbgpt/rag/summary/db_summary_client.py index a58beed51..34947ff19 100644 --- a/dbgpt/rag/summary/db_summary_client.py +++ b/dbgpt/rag/summary/db_summary_client.py @@ -53,9 +53,9 @@ def get_db_summary(self, dbname, query, topk): embedding_fn=self.embeddings, vector_store_config=vector_store_config, ) - from dbgpt.rag.retriever.db_struct import DBStructRetriever + from dbgpt.rag.retriever.db_schema import DBSchemaRetriever - retriever = DBStructRetriever( + retriever = DBSchemaRetriever( top_k=topk, vector_store_connector=vector_connector ) table_docs = retriever.retrieve(query) @@ -92,9 +92,9 @@ def init_db_profile(self, db_summary_client, dbname): vector_store_config=vector_store_config, ) if not vector_connector.vector_name_exists(): - from dbgpt.serve.rag.assembler.db_struct import DBStructAssembler + from dbgpt.serve.rag.assembler.db_schema import DBSchemaAssembler - db_assembler = DBStructAssembler.load_from_connection( + db_assembler = DBSchemaAssembler.load_from_connection( connection=db_summary_client.db, vector_store_connector=vector_connector ) if len(db_assembler.get_chunks()) > 0: diff --git a/dbgpt/rag/text_splitter/token_splitter.py b/dbgpt/rag/text_splitter/token_splitter.py index 8c347e612..15605ae04 100644 --- a/dbgpt/rag/text_splitter/token_splitter.py +++ b/dbgpt/rag/text_splitter/token_splitter.py @@ -1,7 +1,7 @@ """Token 
splitter.""" from typing import Callable, List, Optional -from dbgpt._private.pydantic import Field, PrivateAttr, BaseModel +from pydantic import BaseModel, Field, PrivateAttr from dbgpt.util.global_helper import globals_helper from dbgpt.util.splitter_utils import split_by_sep, split_by_char diff --git a/dbgpt/serve/rag/assembler/db_struct.py b/dbgpt/serve/rag/assembler/db_schema.py similarity index 82% rename from dbgpt/serve/rag/assembler/db_struct.py rename to dbgpt/serve/rag/assembler/db_schema.py index 9d85efbf9..2d1a98bc2 100644 --- a/dbgpt/serve/rag/assembler/db_struct.py +++ b/dbgpt/serve/rag/assembler/db_schema.py @@ -7,24 +7,24 @@ from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory from dbgpt.rag.knowledge.base import Knowledge, ChunkStrategy from dbgpt.rag.knowledge.factory import KnowledgeFactory -from dbgpt.rag.retriever.db_struct import DBStructRetriever +from dbgpt.rag.retriever.db_schema import DBSchemaRetriever from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary from dbgpt.serve.rag.assembler.base import BaseAssembler from dbgpt.storage.vector_store.connector import VectorStoreConnector -class DBStructAssembler(BaseAssembler): - """DBStructAssembler +class DBSchemaAssembler(BaseAssembler): + """DBSchemaAssembler Example: .. 
code-block:: python from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect - from dbgpt.serve.rag.assembler.db_struct import DBStructAssembler + from dbgpt.serve.rag.assembler.db_struct import DBSchemaAssembler from dbgpt.storage.vector_store.connector import VectorStoreConnector from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig connection = SQLiteTempConnect.create_temporary_db() - assembler = DBStructAssembler.load_from_connection( + assembler = DBSchemaAssembler.load_from_connection( connection=connection, embedding_model=embedding_model_path, ) @@ -53,18 +53,21 @@ def __init__( """ if connection is None: raise ValueError("datasource connection must be provided.") + self._connection = connection + self._vector_store_connector = vector_store_connector from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory - embedding_factory = embedding_factory or DefaultEmbeddingFactory( - default_model_name=os.getenv("EMBEDDING_MODEL") - ) - self._connection = connection - if embedding_model: - embedding_fn = embedding_factory.create(model_name=embedding_model) - self._vector_store_connector = ( - vector_store_connector - or VectorStoreConnector.from_default(embedding_fn=embedding_fn) - ) + self._embedding_model = embedding_model + if self._embedding_model: + embedding_factory = embedding_factory or DefaultEmbeddingFactory( + default_model_name=self._embedding_model + ) + self.embedding_fn = embedding_factory.create(self._embedding_model) + if self._vector_store_connector.vector_store_config.embedding_fn is None: + self._vector_store_connector.vector_store_config.embedding_fn = ( + self.embedding_fn + ) + super().__init__( chunk_parameters=chunk_parameters, **kwargs, @@ -79,7 +82,7 @@ def load_from_connection( embedding_model: Optional[str] = None, embedding_factory: Optional[EmbeddingFactory] = None, vector_store_connector: Optional[VectorStoreConnector] = None, - ) -> "DBStructAssembler": + ) -> "DBSchemaAssembler": """Load 
document embedding into vector store from path. Args: connection: (RDBMSDatabase) RDBMSDatabase connection. @@ -89,13 +92,9 @@ def load_from_connection( embedding_factory: (Optional[EmbeddingFactory]) EmbeddingFactory to use. vector_store_connector: (Optional[VectorStoreConnector]) VectorStoreConnector to use. Returns: - DBStructAssembler + DBSchemaAssembler """ - from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory - - embedding_factory = embedding_factory or DefaultEmbeddingFactory( - default_model_name=embedding_model or os.getenv("EMBEDDING_MODEL_PATH") - ) + embedding_factory = embedding_factory chunk_parameters = chunk_parameters or ChunkParameters( chunk_strategy=ChunkStrategy.CHUNK_BY_SIZE.name, chunk_overlap=0 ) @@ -136,14 +135,14 @@ def persist(self) -> List[str]: def _extract_info(self, chunks) -> List[Chunk]: """Extract info from chunks.""" - def as_retriever(self, top_k: Optional[int] = 4) -> DBStructRetriever: + def as_retriever(self, top_k: Optional[int] = 4) -> DBSchemaRetriever: """ Args: top_k:(Optional[int]), default 4 Returns: - DBStructRetriever + DBSchemaRetriever """ - return DBStructRetriever( + return DBSchemaRetriever( top_k=top_k, connection=self._connection, is_embeddings=True, diff --git a/dbgpt/serve/rag/assembler/embedding.py b/dbgpt/serve/rag/assembler/embedding.py index 10138c39e..61f4196a6 100644 --- a/dbgpt/serve/rag/assembler/embedding.py +++ b/dbgpt/serve/rag/assembler/embedding.py @@ -46,17 +46,19 @@ def __init__( """ if knowledge is None: raise ValueError("knowledge datasource must be provided.") + self._vector_store_connector = vector_store_connector from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory - embedding_factory = embedding_factory or DefaultEmbeddingFactory( - default_model_name=os.getenv("EMBEDDING_MODEL") - ) - if embedding_model: - embedding_fn = embedding_factory.create(model_name=embedding_model) - self._vector_store_connector = ( - vector_store_connector - or 
VectorStoreConnector.from_default(embedding_fn=embedding_fn) - ) + self._embedding_model = embedding_model + if self._embedding_model: + embedding_factory = embedding_factory or DefaultEmbeddingFactory( + default_model_name=self._embedding_model + ) + self.embedding_fn = embedding_factory.create(self._embedding_model) + if self._vector_store_connector.vector_store_config.embedding_fn is None: + self._vector_store_connector.vector_store_config.embedding_fn = ( + self.embedding_fn + ) super().__init__( knowledge=knowledge, diff --git a/dbgpt/serve/rag/assembler/summary.py b/dbgpt/serve/rag/assembler/summary.py index a9d4af00e..2294cb990 100644 --- a/dbgpt/serve/rag/assembler/summary.py +++ b/dbgpt/serve/rag/assembler/summary.py @@ -57,9 +57,8 @@ def __init__( from dbgpt.rag.extractor.summary import SummaryExtractor self._extractor = extractor or SummaryExtractor( - llm_client=self._llm_client, model_name=self._model_name + llm_client=self._llm_client, model_name=self._model_name, language=language ) - self._language = language super().__init__( knowledge=knowledge, chunk_parameters=chunk_parameters, diff --git a/dbgpt/serve/rag/assembler/tests/test_embedding_assembler.py b/dbgpt/serve/rag/assembler/tests/test_embedding_assembler.py index 0a5d7d21f..fc77c18c9 100644 --- a/dbgpt/serve/rag/assembler/tests/test_embedding_assembler.py +++ b/dbgpt/serve/rag/assembler/tests/test_embedding_assembler.py @@ -7,7 +7,7 @@ from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory from dbgpt.rag.knowledge.base import Knowledge from dbgpt.rag.text_splitter.text_splitter import CharacterTextSplitter -from dbgpt.serve.rag.assembler.db_struct import DBStructAssembler +from dbgpt.serve.rag.assembler.db_schema import DBSchemaAssembler from dbgpt.storage.vector_store.connector import VectorStoreConnector @@ -66,7 +66,7 @@ def test_load_knowledge( mock_chunk_parameters.chunk_strategy = "CHUNK_BY_SIZE" mock_chunk_parameters.text_splitter = CharacterTextSplitter() 
mock_chunk_parameters.splitter_type = SplitterType.USER_DEFINE - assembler = DBStructAssembler( + assembler = DBSchemaAssembler( connection=mock_db_connection, chunk_parameters=mock_chunk_parameters, embedding_factory=mock_embedding_factory, diff --git a/dbgpt/serve/rag/operators/__init__.py b/dbgpt/serve/rag/operators/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/dbgpt/serve/rag/operators/base.py b/dbgpt/serve/rag/operators/base.py new file mode 100644 index 000000000..8f2cca315 --- /dev/null +++ b/dbgpt/serve/rag/operators/base.py @@ -0,0 +1,23 @@ +from abc import abstractmethod + +from dbgpt.core.awel import MapOperator +from dbgpt.core.awel.task.base import IN, OUT + + +class AssemblerOperator(MapOperator[IN, OUT]): + """The Base Assembler Operator.""" + + async def map(self, input_value: IN) -> OUT: + """Map input value to output value. + + Args: + input_value (IN): The input value. + + Returns: + OUT: The output value. + """ + return await self.blocking_func_to_async(self.assemble, input_value) + + @abstractmethod + def assemble(self, input_value: IN) -> OUT: + """assemble knowledge for input value.""" diff --git a/dbgpt/serve/rag/operators/db_schema.py b/dbgpt/serve/rag/operators/db_schema.py new file mode 100644 index 000000000..995a05f7e --- /dev/null +++ b/dbgpt/serve/rag/operators/db_schema.py @@ -0,0 +1,36 @@ +from typing import Any, Optional + +from dbgpt.core.awel.task.base import IN +from dbgpt.datasource.rdbms.base import RDBMSDatabase +from dbgpt.serve.rag.assembler.db_schema import DBSchemaAssembler +from dbgpt.serve.rag.operators.base import AssemblerOperator +from dbgpt.storage.vector_store.connector import VectorStoreConnector + + +class DBSchemaAssemblerOperator(AssemblerOperator[Any, Any]): + """The DBSchema Assembler Operator. + Args: + connection (RDBMSDatabase): The connection. + chunk_parameters (Optional[ChunkParameters], optional): The chunk parameters. Defaults to None. 
+    vector_store_connector (VectorStoreConnector, optional): The vector store connector. Defaults to None.
+    """
+
+    def __init__(
+        self,
+        connection: RDBMSDatabase = None,
+        vector_store_connector: Optional[VectorStoreConnector] = None,
+        **kwargs
+    ):
+        self._connection = connection
+        self._vector_store_connector = vector_store_connector
+        self._assembler = DBSchemaAssembler.load_from_connection(
+            connection=self._connection,
+            vector_store_connector=self._vector_store_connector,
+        )
+        super().__init__(**kwargs)
+
+    def assemble(self, input_value: IN) -> Any:
+        """assemble knowledge for input value."""
+        if self._vector_store_connector:
+            self._assembler.persist()
+        return self._assembler.get_chunks()
diff --git a/dbgpt/serve/rag/operators/embedding.py b/dbgpt/serve/rag/operators/embedding.py
new file mode 100644
index 000000000..3dd32e29b
--- /dev/null
+++ b/dbgpt/serve/rag/operators/embedding.py
@@ -0,0 +1,44 @@
+from typing import Any, Optional
+
+from dbgpt.core.awel.task.base import IN
+from dbgpt.rag.chunk_manager import ChunkParameters
+from dbgpt.rag.knowledge.base import Knowledge
+from dbgpt.serve.rag.assembler.embedding import EmbeddingAssembler
+from dbgpt.serve.rag.operators.base import AssemblerOperator
+from dbgpt.storage.vector_store.connector import VectorStoreConnector
+
+
+class EmbeddingAssemblerOperator(AssemblerOperator[Any, Any]):
+    """The Embedding Assembler Operator.
+    Args:
+        chunk_parameters (Optional[ChunkParameters], optional): The chunk parameters.
+            Defaults to ChunkParameters(chunk_strategy="CHUNK_BY_SIZE").
+        vector_store_connector (VectorStoreConnector, optional): The vector store connector. Defaults to None.
+    """
+
+    def __init__(
+        self,
+        chunk_parameters: Optional[ChunkParameters] = ChunkParameters(
+            chunk_strategy="CHUNK_BY_SIZE"
+        ),
+        vector_store_connector: VectorStoreConnector = None,
+        **kwargs
+    ):
+        """
+        Args:
+            chunk_parameters (Optional[ChunkParameters], optional): The chunk parameters. 
Defaults to ChunkParameters(chunk_strategy="CHUNK_BY_SIZE"). + vector_store_connector (VectorStoreConnector, optional): The vector store connector. Defaults to None. + """ + self._chunk_parameters = chunk_parameters + self._vector_store_connector = vector_store_connector + super().__init__(**kwargs) + + def assemble(self, knowledge: IN) -> Any: + """assemble knowledge for input value.""" + assembler = EmbeddingAssembler.load_from_knowledge( + knowledge=knowledge, + chunk_parameters=self._chunk_parameters, + vector_store_connector=self._vector_store_connector, + ) + assembler.persist() + return assembler.get_chunks() diff --git a/dbgpt/storage/vector_store/connector.py b/dbgpt/storage/vector_store/connector.py index 27ee1b871..50e7148a0 100644 --- a/dbgpt/storage/vector_store/connector.py +++ b/dbgpt/storage/vector_store/connector.py @@ -92,6 +92,11 @@ def similar_search_with_scores( """ return self.client.similar_search_with_scores(doc, topk, score_threshold) + @property + def vector_store_config(self) -> VectorStoreConfig: + """vector store config.""" + return self._vector_store_config + def vector_name_exists(self): """is vector store name exist.""" return self.client.vector_name_exists() diff --git a/examples/awel/simple_dbschema_retriever_example.py b/examples/awel/simple_dbschema_retriever_example.py new file mode 100644 index 000000000..744d7b763 --- /dev/null +++ b/examples/awel/simple_dbschema_retriever_example.py @@ -0,0 +1,130 @@ +import os +from typing import Dict, List + +from pydantic import BaseModel, Field + +from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH +from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator +from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect +from dbgpt.rag.chunk import Chunk +from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory +from dbgpt.rag.operator.db_schema import DBSchemaRetrieverOperator +from dbgpt.serve.rag.operators.db_schema import 
DBSchemaAssemblerOperator
+from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
+from dbgpt.storage.vector_store.connector import VectorStoreConnector
+
+"""AWEL: Simple rag db schema retriever operator example
+
+    If you do not set vector_store_connector, it will return all table schemas in the database.
+    ```
+    retriever_task = DBSchemaRetrieverOperator(
+        connection=_create_temporary_connection()
+    )
+    ```
+    If you set vector_store_connector, it will recall the top-k most similar table schemas from the database.
+    ```
+    retriever_task = DBSchemaRetrieverOperator(
+        connection=_create_temporary_connection(),
+        top_k=1,
+        vector_store_connector=vector_store_connector
+    )
+    ```
+
+    Examples:
+    .. code-block:: shell
+        curl --location 'http://127.0.0.1:5555/api/v1/awel/trigger/examples/rag/dbschema' \
+        --header 'Content-Type: application/json' \
+        --data '{"query": "what is user name?"}'
+"""
+
+
+def _create_vector_connector():
+    """Create vector connector."""
+    return VectorStoreConnector.from_default(
+        "Chroma",
+        vector_store_config=ChromaVectorConfig(
+            name="vector_name",
+            persist_path=os.path.join(PILOT_PATH, "data"),
+        ),
+        embedding_fn=DefaultEmbeddingFactory(
+            default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"),
+        ).create(),
+    )
+
+
+def _create_temporary_connection():
+    """Create a temporary database connection for testing."""
+    connect = SQLiteTempConnect.create_temporary_db()
+    connect.create_temp_tables(
+        {
+            "user": {
+                "columns": {
+                    "id": "INTEGER PRIMARY KEY",
+                    "name": "TEXT",
+                    "age": "INTEGER",
+                },
+                "data": [
+                    (1, "Tom", 10),
+                    (2, "Jerry", 16),
+                    (3, "Jack", 18),
+                    (4, "Alice", 20),
+                    (5, "Bob", 22),
+                ],
+            }
+        }
+    )
+    return connect
+
+
+def _join_fn(chunks: List[Chunk], query: str) -> str:
+    print(f"db schema info is {[chunk.content for chunk in chunks]}")
+    return query
+
+
+class TriggerReqBody(BaseModel):
+    query: str = Field(..., description="User query")
+
+
+class RequestHandleOperator(MapOperator[TriggerReqBody, 
Dict]): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + async def map(self, input_value: TriggerReqBody) -> Dict: + params = { + "query": input_value.query, + } + print(f"Receive input value: {input_value}") + return params + + +with DAG("simple_rag_db_schema_example") as dag: + trigger = HttpTrigger( + "/examples/rag/dbschema", methods="POST", request_body=TriggerReqBody + ) + request_handle_task = RequestHandleOperator() + query_operator = MapOperator(lambda request: request["query"]) + vector_store_connector = _create_vector_connector() + assembler_task = DBSchemaAssemblerOperator( + connection=_create_temporary_connection(), + vector_store_connector=vector_store_connector, + ) + join_operator = JoinOperator(combine_function=_join_fn) + retriever_task = DBSchemaRetrieverOperator( + connection=_create_temporary_connection(), + top_k=1, + vector_store_connector=vector_store_connector, + ) + result_parse_task = MapOperator(lambda chunks: [chunk.content for chunk in chunks]) + trigger >> request_handle_task >> assembler_task >> join_operator + trigger >> request_handle_task >> query_operator >> join_operator + join_operator >> retriever_task >> result_parse_task + + +if __name__ == "__main__": + if dag.leaf_nodes[0].dev_mode: + # Development mode, you can run the dag locally for debugging. 
+ from dbgpt.core.awel import setup_dev_environment + + setup_dev_environment([dag], port=5555) + else: + pass diff --git a/examples/awel/simple_rag_embedding_example.py b/examples/awel/simple_rag_embedding_example.py new file mode 100644 index 000000000..268d4bcb2 --- /dev/null +++ b/examples/awel/simple_rag_embedding_example.py @@ -0,0 +1,86 @@ +import asyncio +import os +from typing import Dict, List + +from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH +from dbgpt.core.awel import DAG, InputOperator, MapOperator, SimpleCallDataInputSource +from dbgpt.rag.chunk import Chunk +from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory +from dbgpt.rag.operator.knowledge import KnowledgeOperator +from dbgpt.serve.rag.operators.embedding import EmbeddingAssemblerOperator +from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig +from dbgpt.storage.vector_store.connector import VectorStoreConnector + +"""AWEL: Simple rag embedding operator example + + pre-requirements: + set your file path in your example code. + Examples: + ..code-block:: shell + python examples/awel/simple_rag_embedding_example.py +""" + + +def _context_join_fn(context_dict: Dict, chunks: List[Chunk]) -> Dict: + """context Join function for JoinOperator. 
+ + Args: + context_dict (Dict): context dict + chunks (List[Chunk]): chunks + Returns: + Dict: context dict + """ + context_dict["context"] = "\n".join([chunk.content for chunk in chunks]) + return context_dict + + +def _create_vector_connector(): + """Create vector connector.""" + return VectorStoreConnector.from_default( + "Chroma", + vector_store_config=ChromaVectorConfig( + name="vector_name", + persist_path=os.path.join(PILOT_PATH, "data"), + ), + embedding_fn=DefaultEmbeddingFactory( + default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"), + ).create(), + ) + + +class ResultOperator(MapOperator): + """The Result Operator.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + async def map(self, chunks: List) -> str: + result = f"embedding success, there are {len(chunks)} chunks." + print(result) + return result + + +with DAG("simple_sdk_rag_embedding_example") as dag: + knowledge_operator = KnowledgeOperator() + vector_connector = _create_vector_connector() + input_task = InputOperator(input_source=SimpleCallDataInputSource()) + file_path_parser = MapOperator(map_function=lambda x: x["file_path"]) + embedding_operator = EmbeddingAssemblerOperator( + vector_store_connector=vector_connector, + ) + output_task = ResultOperator() + ( + input_task + >> file_path_parser + >> knowledge_operator + >> embedding_operator + >> output_task + ) + +if __name__ == "__main__": + input_data = { + "data": { + "file_path": "docs/docs/awel.md", + } + } + output = asyncio.run(output_task.call(call_data=input_data)) diff --git a/examples/awel/simple_rag_retriever_example.py b/examples/awel/simple_rag_retriever_example.py new file mode 100644 index 000000000..1d48d5478 --- /dev/null +++ b/examples/awel/simple_rag_retriever_example.py @@ -0,0 +1,127 @@ +import asyncio +import os +from typing import Dict, List + +from pydantic import BaseModel, Field + +from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH +from dbgpt.core.awel import DAG, 
HttpTrigger, JoinOperator, MapOperator +from dbgpt.model import OpenAILLMClient +from dbgpt.rag.chunk import Chunk +from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory +from dbgpt.rag.operator.embedding import EmbeddingRetrieverOperator +from dbgpt.rag.operator.rerank import RerankOperator +from dbgpt.rag.operator.rewrite import QueryRewriteOperator +from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig +from dbgpt.storage.vector_store.connector import VectorStoreConnector + +"""AWEL: Simple rag embedding operator example + + pre-requirements: + 1. install openai python sdk + + ``` + pip install openai + ``` + 2. set openai key and base + ``` + export OPENAI_API_KEY={your_openai_key} + export OPENAI_API_BASE={your_openai_base} + ``` + 3. make sure you have vector store. + if there are no data in vector store, please run examples/awel/simple_rag_embedding_example.py + + + ensure your embedding model in DB-GPT/models/. + + Examples: + ..code-block:: shell + DBGPT_SERVER="http://127.0.0.1:5555" + curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/rag/retrieve \ + -H "Content-Type: application/json" -d '{ + "query": "what is awel talk about?" + }' +""" + + +class TriggerReqBody(BaseModel): + query: str = Field(..., description="User query") + + +class RequestHandleOperator(MapOperator[TriggerReqBody, Dict]): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + async def map(self, input_value: TriggerReqBody) -> Dict: + params = { + "query": input_value.query, + } + print(f"Receive input value: {input_value}") + return params + + +def _context_join_fn(context_dict: Dict, chunks: List[Chunk]) -> Dict: + """context Join function for JoinOperator. 
+ + Args: + context_dict (Dict): context dict + chunks (List[Chunk]): chunks + Returns: + Dict: context dict + """ + context_dict["context"] = "\n".join([chunk.content for chunk in chunks]) + return context_dict + + +def _create_vector_connector(): + """Create vector connector.""" + return VectorStoreConnector.from_default( + "Chroma", + vector_store_config=ChromaVectorConfig( + name="vector_name", + persist_path=os.path.join(PILOT_PATH, "data"), + ), + embedding_fn=DefaultEmbeddingFactory( + default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"), + ).create(), + ) + + +with DAG("simple_sdk_rag_retriever_example") as dag: + vector_connector = _create_vector_connector() + trigger = HttpTrigger( + "/examples/rag/retrieve", methods="POST", request_body=TriggerReqBody + ) + request_handle_task = RequestHandleOperator() + query_parser = MapOperator(map_function=lambda x: x["query"]) + context_join_operator = JoinOperator(combine_function=_context_join_fn) + rewrite_operator = QueryRewriteOperator(llm_client=OpenAILLMClient()) + retriever_context_operator = EmbeddingRetrieverOperator( + top_k=3, + vector_store_connector=vector_connector, + ) + retriever_operator = EmbeddingRetrieverOperator( + top_k=3, + vector_store_connector=vector_connector, + ) + rerank_operator = RerankOperator() + model_parse_task = MapOperator(lambda out: out.to_dict()) + + trigger >> request_handle_task >> context_join_operator + ( + trigger + >> request_handle_task + >> query_parser + >> retriever_context_operator + >> context_join_operator + ) + context_join_operator >> rewrite_operator >> retriever_operator >> rerank_operator + +if __name__ == "__main__": + if dag.leaf_nodes[0].dev_mode: + # Development mode, you can run the dag locally for debugging. 
+ from dbgpt.core.awel import setup_dev_environment + + setup_dev_environment([dag], port=5555) + else: + pass diff --git a/examples/awel/simple_rag_rewrite_example.py b/examples/awel/simple_rag_rewrite_example.py new file mode 100644 index 000000000..b09929bd4 --- /dev/null +++ b/examples/awel/simple_rag_rewrite_example.py @@ -0,0 +1,74 @@ +"""AWEL: Simple rag rewrite example + + pre-requirements: + 1. install openai python sdk + ``` + pip install openai + ``` + 2. set openai key and base + ``` + export OPENAI_API_KEY={your_openai_key} + export OPENAI_API_BASE={your_openai_base} + ``` + or + ``` + import os + os.environ["OPENAI_API_KEY"] = {your_openai_key} + os.environ["OPENAI_API_BASE"] = {your_openai_base} + ``` + python examples/awel/simple_rag_rewrite_example.py + Example: + + .. code-block:: shell + + DBGPT_SERVER="http://127.0.0.1:5000" + curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/rag/rewrite \ + -H "Content-Type: application/json" -d '{ + "query": "compare curry and james", + "context":"steve curry and lebron james are nba all-stars" + }' +""" +from typing import Dict + +from dbgpt._private.pydantic import BaseModel, Field +from dbgpt.core.awel import DAG, HttpTrigger, MapOperator +from dbgpt.model import OpenAILLMClient +from dbgpt.rag.operator.rewrite import QueryRewriteOperator + + +class TriggerReqBody(BaseModel): + query: str = Field(..., description="User query") + context: str = Field(..., description="context") + + +class RequestHandleOperator(MapOperator[TriggerReqBody, Dict]): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + async def map(self, input_value: TriggerReqBody) -> Dict: + params = { + "query": input_value.query, + "context": input_value.context, + } + print(f"Receive input value: {input_value}") + return params + + +with DAG("dbgpt_awel_simple_rag_rewrite_example") as dag: + trigger = HttpTrigger( + "/examples/rag/rewrite", methods="POST", request_body=TriggerReqBody + ) + request_handle_task = 
RequestHandleOperator() + # build query rewrite operator + rewrite_task = QueryRewriteOperator(llm_client=OpenAILLMClient(), nums=2) + trigger >> request_handle_task >> rewrite_task + + +if __name__ == "__main__": + if dag.leaf_nodes[0].dev_mode: + # Development mode, you can run the dag locally for debugging. + from dbgpt.core.awel import setup_dev_environment + + setup_dev_environment([dag], port=5555) + else: + pass diff --git a/examples/awel/simple_rag_summary_example.py b/examples/awel/simple_rag_summary_example.py new file mode 100644 index 000000000..f16cbec73 --- /dev/null +++ b/examples/awel/simple_rag_summary_example.py @@ -0,0 +1,84 @@ +"""AWEL: +This example shows how to use AWEL to build a simple rag summary example. + pre-requirements: + 1. install openai python sdk + ``` + pip install openai + ``` + 2. set openai key and base + ``` + export OPENAI_API_KEY={your_openai_key} + export OPENAI_API_BASE={your_openai_base} + ``` + or + ``` + import os + os.environ["OPENAI_API_KEY"] = {your_openai_key} + os.environ["OPENAI_API_BASE"] = {your_openai_base} + ``` + python examples/awel/simple_rag_summary_example.py + Example: + + .. 
code-block:: shell + + DBGPT_SERVER="http://127.0.0.1:5000" + FILE_PATH="{your_file_path}" + curl -X POST http://127.0.0.1:5555/api/v1/awel/trigger/examples/rag/summary \ + -H "Content-Type: application/json" -d '{ + "file_path": $FILE_PATH + }' +""" +from typing import Dict + +from dbgpt._private.pydantic import BaseModel, Field +from dbgpt.core.awel import DAG, HttpTrigger, MapOperator +from dbgpt.model import OpenAILLMClient +from dbgpt.rag.operator.knowledge import KnowledgeOperator +from dbgpt.rag.operator.summary import SummaryAssemblerOperator + + +class TriggerReqBody(BaseModel): + file_path: str = Field(..., description="file_path") + + +class RequestHandleOperator(MapOperator[TriggerReqBody, Dict]): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + async def map(self, input_value: TriggerReqBody) -> Dict: + params = { + "file_path": input_value.file_path, + } + print(f"Receive input value: {input_value}") + return params + + +with DAG("dbgpt_awel_simple_rag_rewrite_example") as dag: + trigger = HttpTrigger( + "/examples/rag/summary", methods="POST", request_body=TriggerReqBody + ) + request_handle_task = RequestHandleOperator() + path_operator = MapOperator(lambda request: request["file_path"]) + # build knowledge operator + knowledge_operator = KnowledgeOperator() + # build summary assembler operator + summary_operator = SummaryAssemblerOperator( + llm_client=OpenAILLMClient(), language="en" + ) + ( + trigger + >> request_handle_task + >> path_operator + >> knowledge_operator + >> summary_operator + ) + + +if __name__ == "__main__": + if dag.leaf_nodes[0].dev_mode: + # Development mode, you can run the dag locally for debugging. 
+ from dbgpt.core.awel import setup_dev_environment + + setup_dev_environment([dag], port=5555) + else: + pass diff --git a/examples/rag/db_struct_rag_example.py b/examples/rag/db_schema_rag_example.py similarity index 67% rename from examples/rag/db_struct_rag_example.py rename to examples/rag/db_schema_rag_example.py index 01a97341c..c101163aa 100644 --- a/examples/rag/db_struct_rag_example.py +++ b/examples/rag/db_schema_rag_example.py @@ -1,6 +1,9 @@ +import os + +from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory -from dbgpt.serve.rag.assembler.db_struct import DBStructAssembler +from dbgpt.serve.rag.assembler.db_schema import DBSchemaAssembler from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig from dbgpt.storage.vector_store.connector import VectorStoreConnector @@ -13,7 +16,7 @@ Examples: ..code-block:: shell - python examples/rag/db_struct_rag_example.py + python examples/rag/db_schema_rag_example.py """ @@ -41,28 +44,29 @@ def _create_temporary_connection(): return connect -if __name__ == "__main__": - connection = _create_temporary_connection() - - embedding_model_path = "{your_embedding_model_path}" - vector_persist_path = "{your_persist_path}" - embedding_fn = DefaultEmbeddingFactory( - default_model_name=embedding_model_path - ).create() - vector_connector = VectorStoreConnector.from_default( +def _create_vector_connector(): + """Create vector connector.""" + return VectorStoreConnector.from_default( "Chroma", vector_store_config=ChromaVectorConfig( - name="vector_name", - persist_path=vector_persist_path, + name="db_schema_vector_store_name", + persist_path=os.path.join(PILOT_PATH, "data"), ), - embedding_fn=embedding_fn, + embedding_fn=DefaultEmbeddingFactory( + default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"), + ).create(), ) - assembler = 
DBStructAssembler.load_from_connection( + + +if __name__ == "__main__": + connection = _create_temporary_connection() + vector_connector = _create_vector_connector() + assembler = DBSchemaAssembler.load_from_connection( connection=connection, vector_store_connector=vector_connector, ) assembler.persist() - # get db struct retriever + # get db schema retriever retriever = assembler.as_retriever(top_k=1) chunks = retriever.retrieve("show columns from user") - print(f"db struct rag example results:{[chunk.content for chunk in chunks]}") + print(f"db schema rag example results:{[chunk.content for chunk in chunks]}") diff --git a/examples/rag/embedding_rag_example.py b/examples/rag/embedding_rag_example.py index b51a60444..ef2b5c591 100644 --- a/examples/rag/embedding_rag_example.py +++ b/examples/rag/embedding_rag_example.py @@ -1,5 +1,7 @@ import asyncio +import os +from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH from dbgpt.rag.chunk_manager import ChunkParameters from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory from dbgpt.rag.knowledge.factory import KnowledgeFactory @@ -20,21 +22,24 @@ """ -async def main(): - file_path = "./docs/docs/awel.md" - vector_persist_path = "{your_persist_path}" - embedding_model_path = "{your_embedding_model_path}" - knowledge = KnowledgeFactory.from_file_path(file_path) - vector_connector = VectorStoreConnector.from_default( +def _create_vector_connector(): + """Create vector connector.""" + return VectorStoreConnector.from_default( "Chroma", vector_store_config=ChromaVectorConfig( - name="vector_name", - persist_path=vector_persist_path, + name="db_schema_vector_store_name", + persist_path=os.path.join(PILOT_PATH, "data"), ), embedding_fn=DefaultEmbeddingFactory( - default_model_name=embedding_model_path + default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"), ).create(), ) + + +async def main(): + file_path = "docs/docs/awel.md" + knowledge = 
KnowledgeFactory.from_file_path(file_path) + vector_connector = _create_vector_connector() chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE") # get embedding assembler assembler = EmbeddingAssembler.load_from_knowledge( @@ -45,7 +50,7 @@ async def main(): assembler.persist() # get embeddings retriever retriever = assembler.as_retriever(3) - chunks = await retriever.aretrieve_with_scores("RAG", 0.3) + chunks = await retriever.aretrieve_with_scores("what is awel talk about", 0.3) print(f"embedding rag example results:{chunks}")