Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

review: IQL aggregations #76

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benchmark/dbally_benchmark/e2e_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import dbally
from dbally.collection import Collection
from dbally.collection.exceptions import NoViewFoundError
from dbally.iql_generator.prompt import IQL_GENERATION_TEMPLATE, UnsupportedQueryError
from dbally.iql_generator.filters_prompt import IQL_GENERATION_TEMPLATE, UnsupportedQueryError
from dbally.llms.litellm import LiteLLM
from dbally.llms.local import LocalLLM
from dbally.view_selection.prompt import VIEW_SELECTION_TEMPLATE
Expand Down
10 changes: 5 additions & 5 deletions benchmark/dbally_benchmark/iql_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@
from sqlalchemy import create_engine

from dbally.audit.event_tracker import EventTracker
from dbally.iql_generator.iql_generator import IQLGenerator
from dbally.iql_generator.prompt import IQL_GENERATION_TEMPLATE, UnsupportedQueryError
from dbally.iql_generator.filters_prompt import IQL_GENERATION_TEMPLATE, UnsupportedQueryError
from dbally.iql_generator.iql_filters_generator import IQLFiltersGenerator
from dbally.llms.litellm import LiteLLM
from dbally.llms.local import LocalLLM
from dbally.views.structured import BaseStructuredView


async def _run_iql_for_single_example(
example: BIRDExample, view: BaseStructuredView, iql_generator: IQLGenerator
example: BIRDExample, view: BaseStructuredView, iql_generator: IQLFiltersGenerator
) -> IQLResult:
filter_list = view.list_filters()
event_tracker = EventTracker()
Expand All @@ -46,7 +46,7 @@ async def _run_iql_for_single_example(


async def run_iql_for_dataset(
dataset: BIRDDataset, view: BaseStructuredView, iql_generator: IQLGenerator
dataset: BIRDDataset, view: BaseStructuredView, iql_generator: IQLFiltersGenerator
) -> List[IQLResult]:
"""
Runs IQL predictions for a dataset.
Expand Down Expand Up @@ -102,7 +102,7 @@ async def evaluate(cfg: DictConfig) -> Any:
else:
llm = LiteLLM(api_key=benchmark_cfg.openai_api_key, model_name=cfg.model_name)

iql_generator = IQLGenerator(llm=llm)
iql_generator = IQLFiltersGenerator(llm=llm)

run = None
if cfg.neptune.log:
Expand Down
5 changes: 3 additions & 2 deletions src/dbally/assistants/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

from dbally.assistants.base import AssistantAdapter, FunctionCallingError, FunctionCallState
from dbally.collection import Collection
from dbally.iql_generator.prompt import UnsupportedQueryError
from dbally.iql_generator.aggregation_prompt import UnsupportedAggregationError
from dbally.iql_generator.filters_prompt import UnsupportedQueryError

_DBALLY_INFO = "Dbally has access to the following database views: "

Expand Down Expand Up @@ -114,7 +115,7 @@ async def process_response(
# In case of raise_exception use TaskGroup, otherwise asyncio.gather.
response_dbally = await self.collection.ask(question=function_args.get("query"))
response = json.dumps(response_dbally.results)
except UnsupportedQueryError:
except (UnsupportedQueryError, UnsupportedAggregationError):
state = FunctionCallState.UNSUPPORTED_QUERY
response = str(state)

Expand Down
2 changes: 1 addition & 1 deletion src/dbally/gradio/gradio_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from dbally.audit.event_handlers.buffer_event_handler import BufferEventHandler
from dbally.collection import Collection
from dbally.collection.exceptions import NoViewFoundError
from dbally.iql_generator.prompt import UnsupportedQueryError
from dbally.iql_generator.filters_prompt import UnsupportedQueryError
from dbally.prompt.template import PromptTemplateError


Expand Down
61 changes: 61 additions & 0 deletions src/dbally/iql_generator/aggregation_prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from typing import List

from dbally.exceptions import UnsupportedAggregationError
from dbally.prompt.template import PromptFormat, PromptTemplate
from dbally.views.exposed_functions import ExposedFunction


def _validate_agg_response(llm_response: str) -> str:
"""
Validates LLM response to IQL

Args:
llm_response: LLM response

Returns:
A string containing aggregations.

Raises:
UnsupportedAggregationError: When IQL generator is unable to construct a query
with given aggregation.
"""
if "unsupported query" in llm_response.lower():
raise UnsupportedAggregationError
return llm_response


class AggregationPromptFormat(PromptFormat):
"""
Aggregation prompt format, providing a question and aggregation to be used in the conversation.
"""

def __init__(
self,
question: str,
aggregations: List[ExposedFunction] = None,
) -> None:
super().__init__()
self.question = question
self.aggregations = "\n".join([str(aggregation) for aggregation in aggregations]) if aggregations else []


AGGREGATION_GENERATION_TEMPLATE = PromptTemplate[AggregationPromptFormat](
[
{
"role": "system",
"content": "You have access to an API that lets you query a database supporting a SINGLE aggregation.\n"
"When prompted for an aggregation, use the following methods: \n"
"{aggregations}"
"DO NOT INCLUDE arguments names in your response. Only the values.\n"
"You MUST use only these methods:\n"
"\n{aggregations}\n"
"It is VERY IMPORTANT not to use methods other than those listed above."
"""If you DON'T KNOW HOW TO ANSWER DON'T SAY anything other than `UNSUPPORTED QUERY`"""
"This is CRUCIAL to put `UNSUPPORTED QUERY` text only, otherwise the system will crash. "
"Structure output to resemble the following pattern:\n"
'aggregation1("arg1", arg2)\n',
},
{"role": "user", "content": "{question}"},
],
response_parser=_validate_agg_response,
)
81 changes: 81 additions & 0 deletions src/dbally/iql_generator/iql_aggregation_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from typing import List, Optional

from dbally.audit import EventTracker
from dbally.iql import IQLError, IQLQuery
from dbally.iql_generator.aggregation_prompt import AGGREGATION_GENERATION_TEMPLATE, AggregationPromptFormat
from dbally.llms.base import LLM
from dbally.llms.clients import LLMOptions
from dbally.prompt.template import PromptTemplate
from dbally.views.exposed_functions import ExposedFunction

ERROR_MESSAGE = "Unfortunately, generated IQL aggregation is not valid. Please try again, \
generation of correct IQL is very important. Below you have errors generated by the system:\n{error}"


class IQLAggregationGenerator:
"""
Class used to manage choice and formatting of aggregation based on natural language question.
"""

def __init__(self, llm: LLM, prompt_template: Optional[PromptTemplate[AggregationPromptFormat]] = None) -> None:
"""
Constructs a new AggregationFormatter instance.

Args:
llm: LLM used to generate IQL
prompt_template: If not provided by the users is set to `AGGREGATION_GENERATION_TEMPLATE`
"""
self._llm = llm
self._prompt_template = prompt_template or AGGREGATION_GENERATION_TEMPLATE

async def generate_iql(
self,
question: str,
aggregations: List[ExposedFunction],
event_tracker: EventTracker,
llm_options: Optional[LLMOptions] = None,
n_retries: int = 3,
) -> IQLQuery:
"""
Generates IQL in text form using LLM.

Args:
question: User question.
event_tracker: Event store used to audit the generation process.
aggregations: List of aggregations exposed by the view.
llm_options: Options to use for the LLM client.
n_retries: Number of retries to regenerate IQL in case of errors.

Returns:
Generated aggregation query.

Raises:
IQLError: If IQL generation fails after all retries.
"""
prompt_format = AggregationPromptFormat(
question=question,
aggregations=aggregations,
)

formatted_prompt = self._prompt_template.format_prompt(prompt_format)

for retry in range(n_retries + 1):
try:
response = await self._llm.generate_text(
prompt=formatted_prompt,
event_tracker=event_tracker,
options=llm_options,
)
# TODO: Move response parsing to llm generate_text method
iql = formatted_prompt.response_parser(response)
# TODO: Move IQL query parsing to prompt response parser
return await IQLQuery.parse(
source=iql,
allowed_functions=aggregations,
event_tracker=event_tracker,
)
except IQLError as exc:
if retry == n_retries:
raise exc
formatted_prompt = formatted_prompt.add_assistant_message(response)
formatted_prompt = formatted_prompt.add_user_message(ERROR_MESSAGE.format(error=exc))
91 changes: 91 additions & 0 deletions src/dbally/iql_generator/iql_filters_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from typing import List, Optional

from dbally.audit.event_tracker import EventTracker
from dbally.iql import IQLError, IQLQuery
from dbally.iql_generator.filters_prompt import IQL_GENERATION_TEMPLATE, IQLGenerationPromptFormat
from dbally.llms.base import LLM
from dbally.llms.clients.base import LLMOptions
from dbally.prompt.elements import FewShotExample
from dbally.prompt.template import PromptTemplate
from dbally.views.exposed_functions import ExposedFunction

ERROR_MESSAGE = "Unfortunately, generated IQL is not valid. Please try again, \
generation of correct IQL is very important. Below you have errors generated by the system:\n{error}"


class IQLFiltersGenerator:
"""
Class used to generate IQL from natural language question.

In db-ally, LLM uses IQL (Intermediate Query Language) to express complex queries in a simplified way.
The class used to generate IQL from natural language query is `IQLGenerator`.

IQL generation is done using the method `self.generate_iql`.
It uses LLM to generate text-based responses, passing in the prompt template, formatted filters, and user question.
"""

def __init__(self, llm: LLM, prompt_template: Optional[PromptTemplate[IQLGenerationPromptFormat]] = None) -> None:
"""
Constructs a new IQLGenerator instance.

Args:
llm: LLM used to generate IQL
prompt_template: If not provided by the users is set to `default_iql_template`
"""
self._llm = llm
self._prompt_template = prompt_template or IQL_GENERATION_TEMPLATE

async def generate_iql(
self,
question: str,
filters: List[ExposedFunction],
event_tracker: EventTracker,
examples: Optional[List[FewShotExample]] = None,
llm_options: Optional[LLMOptions] = None,
n_retries: int = 3,
) -> IQLQuery:
"""
Generates IQL in text form using LLM.

Args:
question: User question.
filters: List of filters exposed by the view.
event_tracker: Event store used to audit the generation process.
examples: List of examples to be injected into the conversation.
llm_options: Options to use for the LLM client.
n_retries: Number of retries to regenerate IQL in case of errors.

Returns:
Generated IQL query.

Raises:
IQLError: If IQL generation fails after all retries.
"""
prompt_format = IQLGenerationPromptFormat(
question=question,
filters=filters,
examples=examples,
)

formatted_prompt = self._prompt_template.format_prompt(prompt_format)

for retry in range(n_retries + 1):
try:
response = await self._llm.generate_text(
prompt=formatted_prompt,
event_tracker=event_tracker,
options=llm_options,
)
# TODO: Move response parsing to llm generate_text method
iql = formatted_prompt.response_parser(response)
# TODO: Move IQL query parsing to prompt response parser
return await IQLQuery.parse(
source=iql,
allowed_functions=filters or [],
event_tracker=event_tracker,
)
except IQLError as exc:
if retry == n_retries:
raise exc
formatted_prompt = formatted_prompt.add_assistant_message(response)
formatted_prompt = formatted_prompt.add_user_message(ERROR_MESSAGE.format(error=exc))
Loading
Loading