Skip to content

Commit

Permalink
feat: add schema-linking awel example (#1081)
Browse files Browse the repository at this point in the history
  • Loading branch information
junewgl authored Jan 21, 2024
1 parent 2d90519 commit 4f83363
Show file tree
Hide file tree
Showing 5 changed files with 445 additions and 0 deletions.
44 changes: 44 additions & 0 deletions dbgpt/rag/operator/schema_linking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from typing import Any, Optional

from dbgpt.core import LLMClient
from dbgpt.core.awel import MapOperator
from dbgpt.datasource.rdbms.base import RDBMSDatabase
from dbgpt.rag.schemalinker.schema_linking import SchemaLinking
from dbgpt.storage.vector_store.connector import VectorStoreConnector


class SchemaLinkingOperator(MapOperator[Any, Any]):
    """AWEL map operator that delegates schema linking to ``SchemaLinking``."""

    def __init__(
        self,
        top_k: int = 5,
        connection: Optional[RDBMSDatabase] = None,
        llm: Optional[LLMClient] = None,
        model_name: Optional[str] = None,
        vector_store_connector: Optional[VectorStoreConnector] = None,
        **kwargs
    ):
        """Build the operator and the ``SchemaLinking`` instance it wraps.

        Every named argument is forwarded unchanged to ``SchemaLinking``.

        Args:
            top_k (int): forwarded as ``top_k``.
            connection (Optional[RDBMSDatabase]): forwarded as ``connection``.
            llm (Optional[LLMClient]): forwarded as ``llm``.
            model_name (Optional[str]): forwarded as ``model_name``.
            vector_store_connector (Optional[VectorStoreConnector]): forwarded
                as ``vector_store_connector``.
            **kwargs: passed to the ``MapOperator`` initializer.
        """
        super().__init__(**kwargs)

        linker_options = {
            "top_k": top_k,
            "connection": connection,
            "llm": llm,
            "model_name": model_name,
            "vector_store_connector": vector_store_connector,
        }
        self._schema_linking = SchemaLinking(**linker_options)

    async def map(self, query: str) -> str:
        """Run LLM-based schema linking for ``query``.

        Args:
            query (str): the query handed to ``schema_linking_with_llm``.

        Returns:
            str: ``str()`` of whatever ``schema_linking_with_llm`` produced.
        """
        linked_schema = await self._schema_linking.schema_linking_with_llm(query)
        return str(linked_schema)
Empty file.
60 changes: 60 additions & 0 deletions dbgpt/rag/schemalinker/base_linker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from abc import ABC, abstractmethod
from typing import List


class BaseSchemaLinker(ABC):
    """Abstract base class for schema linkers.

    Each public method forwards ``query`` to the underscore-prefixed hook of
    the same name and returns the hook's result; subclasses implement the
    three hooks.
    """

    # ----- hooks implemented by subclasses ---------------------------------

    @abstractmethod
    def _schema_linking(self, query: str) -> List:
        """Hook behind ``schema_linking``.

        Args:
            query (str): query text
        Returns:
            List: list of schema
        """

    @abstractmethod
    def _schema_linking_with_vector_db(self, query: str) -> List:
        """Hook behind ``schema_linking_with_vector_db``.

        Args:
            query (str): query text
        Returns:
            List: list of schema
        """

    @abstractmethod
    async def _schema_linking_with_llm(self, query: str) -> List:
        """Hook behind ``schema_linking_with_llm`` (a coroutine).

        Args:
            query (str): query text
        Returns:
            List: list of schema
        """

    # ----- public entry points ----------------------------------------------

    def schema_linking(self, query: str) -> List:
        """Return the result of ``_schema_linking`` for ``query``.

        Args:
            query (str): query text
        Returns:
            List: list of schema
        """
        linked = self._schema_linking(query)
        return linked

    def schema_linking_with_vector_db(self, query: str) -> List:
        """Return the result of ``_schema_linking_with_vector_db`` for ``query``.

        Args:
            query (str): query text
        Returns:
            List: list of schema
        """
        linked = self._schema_linking_with_vector_db(query)
        return linked

    async def schema_linking_with_llm(self, query: str) -> List:
        """Await ``_schema_linking_with_llm`` for ``query`` and return its result.

        Args:
            query (str): query text
        Returns:
            List: list of schema
        """
        linked = await self._schema_linking_with_llm(query)
        return linked
78 changes: 78 additions & 0 deletions dbgpt/rag/schemalinker/schema_linking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from functools import reduce
from typing import List, Optional

from dbgpt.core import LLMClient, ModelMessage, ModelMessageRoleType, ModelRequest
from dbgpt.datasource.rdbms.base import RDBMSDatabase
from dbgpt.rag.chunk import Chunk
from dbgpt.rag.schemalinker.base_linker import BaseSchemaLinker
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from dbgpt.util.chat_util import run_async_tasks

# Prompt template for LLM-based schema linking. The single ``{}`` placeholder
# at the end receives the candidate schema list followed by the formatted
# ``INPUT_PROMPT``. The literal is split along sentence/section boundaries;
# the resulting string is unchanged.
INSTRUCTION = (
    # Task description.
    "You need to filter out the most relevant database table schema "
    "information (it may be a single table or multiple tables) required to "
    "generate the SQL of the question query from the given database schema "
    "information. "
    "First, I will show you an example of an instruction followed by the "
    "correct schema response. "
    "Then, I will give you a new instruction, and you should write the schema "
    "response that appropriately completes the request.\n"
    # One worked example.
    "### Example1 Instruction:\n"
    "['job(id, name, age)', 'user(id, name, age)', "
    "'student(id, name, age, info)']\n"
    "### Example1 Input:\n"
    "Find the age of student table\n"
    "### Example1 Response:\n"
    "['student(id, name, age, info)']\n"
    # Slot for the new instruction.
    "###New Instruction:\n"
    "{}"
)
# Suffix carrying the user's question; ``{}`` receives the query text.
INPUT_PROMPT = (
    "\n###New Input:\n"
    "{}"
    "\n###New Response:"
)


class SchemaLinking(BaseSchemaLinker):
    """Schema linker backed by a database connection, a vector store and an LLM.

    Implements the three ``BaseSchemaLinker`` hooks: returning every table
    summary, similarity search in a vector store, and LLM-based filtering of
    the table summaries.
    """

    def __init__(
        self,
        top_k: int = 5,
        connection: Optional[RDBMSDatabase] = None,
        llm: Optional[LLMClient] = None,
        model_name: Optional[str] = None,
        vector_store_connector: Optional[VectorStoreConnector] = None,
        **kwargs
    ):
        """Store the collaborators used by the linking strategies.

        Args:
            top_k (int): passed to ``similar_search`` on the vector store.
            connection (Optional[RDBMSDatabase]): RDBMSDatabase connection,
                passed to ``_parse_db_summary``.
            llm (Optional[LLMClient]): client whose ``generate`` is called for
                LLM-based linking.
            model_name (Optional[str]): used as ``model`` of the
                ``ModelRequest`` sent to the LLM.
            vector_store_connector (Optional[VectorStoreConnector]): vector
                store used for similarity search.
            **kwargs: forwarded to the base class initializer.
        """
        super().__init__(**kwargs)
        self._top_k = top_k
        self._connection = connection
        self._llm = llm
        self._model_name = model_name
        self._vector_store_connector = vector_store_connector

    def _schema_linking(self, query: str) -> List:
        """Return the content of every table summary of the connection.

        ``query`` is not used here: all summaries produced by
        ``_parse_db_summary`` are returned.

        Args:
            query (str): query text (ignored)
        Returns:
            List: the ``content`` of one ``Chunk`` per table summary
        """
        table_summaries = _parse_db_summary(self._connection)
        chunks = [Chunk(content=table_summary) for table_summary in table_summaries]
        chunks_content = [chunk.content for chunk in chunks]
        return chunks_content

    def _schema_linking_with_vector_db(self, query: str) -> List:
        """Return the ``top_k`` similarity-search results for ``query``.

        Args:
            query (str): query text
        Returns:
            List: concatenated ``similar_search`` results
        Raises:
            ValueError: if no ``vector_store_connector`` was configured.
        """
        if self._vector_store_connector is None:
            # Fail with a clear message instead of an AttributeError on None.
            raise ValueError(
                "vector_store_connector is required for schema linking with "
                "vector db"
            )
        queries = [query]
        # A distinct loop name keeps the ``query`` parameter unshadowed.
        candidates = [
            self._vector_store_connector.similar_search(sub_query, self._top_k)
            for sub_query in queries
        ]
        # Concatenate the per-query results into a single sequence.
        candidates = reduce(lambda x, y: x + y, candidates)
        return candidates

    async def _schema_linking_with_llm(self, query: str) -> List:
        """Ask the LLM to pick the table schemas relevant to ``query``.

        Builds a prompt from ``INSTRUCTION`` / ``INPUT_PROMPT`` containing all
        table summaries and the query, sends it as a single system message and
        returns the ``text`` attribute of the response.

        NOTE(review): the return annotation matches the base class (``List``),
        but the value returned is the response's ``text`` -- confirm intent.

        Args:
            query (str): query text
        Returns:
            The ``text`` attribute of the LLM response.
        Raises:
            ValueError: if no ``llm`` was configured.
        """
        if self._llm is None:
            # Fail before doing any work instead of an AttributeError on None.
            raise ValueError("llm is required for schema linking with llm")
        chunks_content = self.schema_linking(query)
        schema_prompt = INSTRUCTION.format(
            str(chunks_content) + INPUT_PROMPT.format(query)
        )
        messages = [
            ModelMessage(role=ModelMessageRoleType.SYSTEM, content=schema_prompt)
        ]
        request = ModelRequest(model=self._model_name, messages=messages)
        tasks = [self._llm.generate(request)]
        # get accurate schema info by llm
        schema = await run_async_tasks(tasks=tasks, concurrency_limit=1)
        schema_text = schema[0].text
        return schema_text
Loading

0 comments on commit 4f83363

Please sign in to comment.