Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(views): optional filtering for structured views #78

Merged
merged 7 commits into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions src/dbally/iql/_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,19 @@ def __init__(self, source: str) -> None:
super().__init__(message, source)


class IQLEmptyExpressionError(IQLError):
"""Raised when IQL expression is empty."""
class IQLNoStatementError(IQLError):
"""Raised when IQL does not have any statement."""

def __init__(self, source: str) -> None:
message = "Empty IQL expression"
message = "Empty IQL"
super().__init__(message, source)


class IQLMultipleExpressionsError(IQLError):
"""Raised when IQL contains multiple expressions."""
class IQLMultipleStatementsError(IQLError):
"""Raised when IQL contains multiple statements."""

def __init__(self, nodes: List[ast.stmt], source: str) -> None:
message = "Multiple expressions or statements in IQL are not supported"
message = "Multiple statements in IQL are not supported"
super().__init__(message, source)
self.nodes = nodes

Expand Down
8 changes: 4 additions & 4 deletions src/dbally/iql/_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
from dbally.iql._exceptions import (
IQLArgumentParsingError,
IQLArgumentValidationError,
IQLEmptyExpressionError,
IQLFunctionNotExists,
IQLIncorrectNumberArgumentsError,
IQLMultipleExpressionsError,
IQLMultipleStatementsError,
IQLNoExpressionError,
IQLNoStatementError,
IQLSyntaxError,
IQLUnsupportedSyntaxError,
)
Expand Down Expand Up @@ -50,10 +50,10 @@ async def process(self) -> syntax.Node:
raise IQLSyntaxError(self.source) from exc

if not ast_tree.body:
raise IQLEmptyExpressionError(self.source)
raise IQLNoStatementError(self.source)

if len(ast_tree.body) > 1:
raise IQLMultipleExpressionsError(ast_tree.body, self.source)
raise IQLMultipleStatementsError(ast_tree.body, self.source)

if not isinstance(ast_tree.body[0], ast.Expr):
raise IQLNoExpressionError(ast_tree.body[0], self.source)
Expand Down
110 changes: 104 additions & 6 deletions src/dbally/iql_generator/iql_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@

from dbally.audit.event_tracker import EventTracker
from dbally.iql import IQLError, IQLQuery
from dbally.iql_generator.prompt import IQL_GENERATION_TEMPLATE, IQLGenerationPromptFormat
from dbally.iql_generator.prompt import (
FILTERING_DECISION_TEMPLATE,
IQL_GENERATION_TEMPLATE,
FilteringDecisionPromptFormat,
IQLGenerationPromptFormat,
)
from dbally.llms.base import LLM
from dbally.llms.clients.base import LLMOptions
from dbally.llms.clients.exceptions import LLMError
Expand All @@ -25,17 +30,110 @@ class IQLGenerator:
It uses LLM to generate text-based responses, passing in the prompt template, formatted filters, and user question.
"""

def __init__(self, llm: LLM, prompt_template: Optional[PromptTemplate[IQLGenerationPromptFormat]] = None) -> None:
def __init__(
self,
llm: LLM,
*,
decision_prompt: Optional[PromptTemplate[FilteringDecisionPromptFormat]] = None,
generation_prompt: Optional[PromptTemplate[IQLGenerationPromptFormat]] = None,
) -> None:
"""
Constructs a new IQLGenerator instance.

Args:
llm: LLM used to generate IQL
llm: LLM used to generate IQL.
decision_prompt: Prompt template for filtering decision making.
generation_prompt: Prompt template for IQL generation.
"""
self._llm = llm
self._prompt_template = prompt_template or IQL_GENERATION_TEMPLATE
self._decision_prompt = decision_prompt or FILTERING_DECISION_TEMPLATE
self._generation_prompt = generation_prompt or IQL_GENERATION_TEMPLATE

async def generate(
    self,
    question: str,
    filters: List[ExposedFunction],
    event_tracker: EventTracker,
    examples: Optional[List[FewShotExample]] = None,
    llm_options: Optional[LLMOptions] = None,
    n_retries: int = 3,
) -> Optional[IQLQuery]:
    """
    Runs the two-step IQL generation: a filtering decision followed, if positive, by IQL generation.

    Args:
        question: User question.
        filters: List of filters exposed by the view.
        event_tracker: Event store used to audit the generation process.
        examples: List of examples to be injected into the conversation.
        llm_options: Options to use for the LLM client.
        n_retries: Number of retries passed to both the decision step and the generation step.

    Returns:
        The generated IQL query, or None when the decision step concludes that no filtering is needed.

    Raises:
        LLMError: If LLM text generation fails after all retries.
        IQLError: If IQL parsing fails after all retries.
        UnsupportedQueryError: If the question is not supported by the view.
    """
    # Step 1: ask the LLM whether the question needs any filtering at all.
    needs_filtering = await self._decide_on_generation(
        question=question,
        event_tracker=event_tracker,
        llm_options=llm_options,
        n_retries=n_retries,
    )

    # Step 2: only a positive decision triggers the actual IQL generation.
    if needs_filtering:
        return await self._generate_iql(
            question=question,
            filters=filters,
            event_tracker=event_tracker,
            examples=examples,
            llm_options=llm_options,
            n_retries=n_retries,
        )

    return None

async def _decide_on_generation(
    self,
    question: str,
    event_tracker: EventTracker,
    llm_options: Optional[LLMOptions] = None,
    n_retries: int = 3,
) -> bool:
    """
    Decides whether the question requires filtering or not.

    Args:
        question: User question.
        event_tracker: Event store used to audit the generation process.
        llm_options: Options to use for the LLM client.
        n_retries: Number of retries to LLM API in case of errors. Negative values are treated as 0,
            so the LLM is always called at least once.

    Returns:
        Decision whether to generate IQL or not.

    Raises:
        LLMError: If LLM text generation fails after all retries.
    """
    prompt_format = FilteringDecisionPromptFormat(question=question)
    formatted_prompt = self._decision_prompt.format_prompt(prompt_format)

    # Clamp so that at least one attempt is made; otherwise a negative n_retries would skip
    # the loop entirely and the coroutine would fall through, returning None instead of a bool.
    max_retries = max(n_retries, 0)

    for retry in range(max_retries + 1):
        try:
            response = await self._llm.generate_text(
                prompt=formatted_prompt,
                event_tracker=event_tracker,
                options=llm_options,
            )
            # TODO: Move response parsing to llm generate_text method
            return formatted_prompt.response_parser(response)
        except LLMError:
            # Swallow the error and try again, unless this was the last allowed attempt.
            if retry == max_retries:
                raise

async def generate_iql(
async def _generate_iql(
self,
question: str,
filters: List[ExposedFunction],
Expand Down Expand Up @@ -68,7 +166,7 @@ async def generate_iql(
filters=filters,
examples=examples,
)
formatted_prompt = self._prompt_template.format_prompt(prompt_format)
formatted_prompt = self._generation_prompt.format_prompt(prompt_format)

for retry in range(n_retries + 1):
try:
Expand Down
65 changes: 65 additions & 0 deletions src/dbally/iql_generator/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,41 @@ def _validate_iql_response(llm_response: str) -> str:
return llm_response


def _decision_iql_response_parser(response: str) -> bool:
    """
    Parses the response from the decision prompt.

    The text following the first "decision:" marker (case-insensitive) is scanned for the first
    standalone "true" or "false" token, and that token determines the result. Matching whole words
    avoids false positives from a later mention of "true" (e.g. "False, it is not true that ...")
    or from words that merely contain it (e.g. "construed").

    Args:
        response: Response from the LLM.

    Returns:
        True if the response is positive, False otherwise (including when the marker or
        a true/false token is missing).
    """
    import re  # local import: keeps this block self-contained without touching file-level imports

    _, marker, verdict = response.lower().partition("decision:")
    if not marker:
        return False

    token = re.search(r"\b(true|false)\b", verdict)
    return token is not None and token.group(1) == "true"


class FilteringDecisionPromptFormat(PromptFormat):
    """
    Filtering decision prompt format, providing the question to be used in the conversation.
    """

    def __init__(self, *, question: str, examples: List[FewShotExample] = None) -> None:
        """
        Constructs a new FilteringDecisionPromptFormat instance.

        Args:
            question: Question to be asked.
            examples: List of examples to be injected into the conversation.
                Forwarded unchanged to the PromptFormat constructor.
        """
        super().__init__(examples)
        self.question = question


class IQLGenerationPromptFormat(PromptFormat):
"""
IQL prompt format, providing a question and filters to be used in the conversation.
Expand Down Expand Up @@ -85,3 +120,33 @@ def __init__(
],
response_parser=_validate_iql_response,
)


# Prompt asking the LLM whether a question needs initial data filtering at all.
# The system message describes the expected answer layout (Question / Hint / Reasoning / Decision);
# the user message pre-fills the layout up to the start of the reasoning so the model continues it.
# The reply is parsed by `_decision_iql_response_parser`.
# NOTE(review): the doubled braces presumably render as literal `${...}` placeholders after
# formatting, while `{question}` is substituted - confirm against PromptTemplate.
FILTERING_DECISION_TEMPLATE = PromptTemplate[FilteringDecisionPromptFormat](
    [
        {
            "role": "system",
            "content": (
                "Given a question, determine whether the answer requires initial data filtering in order to compute it.\n"
                "Initial data filtering is a process in which the result set is reduced to only include the rows "
                "that meet certain criteria specified in the question.\n\n"
                "---\n\n"
                "Follow the following format.\n\n"
                "Question: ${{question}}\n"
                # The trailing newline keeps "Hint" and "Reasoning" on separate lines of the format spec.
                "Hint: ${{hint}}\n"
                "Reasoning: Let's think step by step in order to ${{produce the decision}}. We...\n"
                "Decision: indicates whether the answer to the question requires initial data filtering. "
                "(Respond with True or False)\n\n"
            ),
        },
        {
            "role": "user",
            "content": (
                "Question: {question}\n"
                "Hint: Look for words indicating data specific features.\n"
                "Reasoning: Let's think step by step in order to "
            ),
        },
    ],
    response_parser=_decision_iql_response_parser,
)
3 changes: 2 additions & 1 deletion src/dbally/views/pandas_base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
from functools import reduce
from typing import Optional

import pandas as pd

Expand All @@ -25,7 +26,7 @@ def __init__(self, df: pd.DataFrame) -> None:
self.df = df

# The mask to be applied to the dataframe to filter the data
self._filter_mask: pd.Series = None
self._filter_mask: Optional[pd.Series] = None

async def apply_filters(self, filters: IQLQuery) -> None:
"""
Expand Down
13 changes: 9 additions & 4 deletions src/dbally/views/sqlalchemy_base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import abc
import asyncio
from typing import Optional

import sqlalchemy

Expand All @@ -13,10 +14,11 @@ class SqlAlchemyBaseView(MethodsBaseView):
Base class for views that use SQLAlchemy to generate SQL queries.
"""

def __init__(self, sqlalchemy_engine: sqlalchemy.engine.Engine) -> None:
def __init__(self, sqlalchemy_engine: sqlalchemy.Engine) -> None:
super().__init__()
self._select = self.get_select()
self._sqlalchemy_engine = sqlalchemy_engine
self._select = self.get_select()
self._where_clause: Optional[sqlalchemy.ColumnElement] = None

@abc.abstractmethod
def get_select(self) -> sqlalchemy.Select:
Expand All @@ -34,7 +36,7 @@ async def apply_filters(self, filters: IQLQuery) -> None:
Args:
filters: IQLQuery object representing the filters to apply
"""
self._select = self._select.where(await self._build_filter_node(filters.root))
self._where_clause = await self._build_filter_node(filters.root)

async def _build_filter_node(self, node: syntax.Node) -> sqlalchemy.ColumnElement:
"""
Expand Down Expand Up @@ -75,8 +77,11 @@ def execute(self, dry_run: bool = False) -> ViewExecutionResult:
Results of the query where `results` will be a list of dictionaries representing retrieved rows or an empty\
list if `dry_run` is set to `True`. Inside the `context` field the generated sql will be stored.
"""

results = []

if self._where_clause is not None:
self._select = self._select.where(self._where_clause)

sql = str(self._select.compile(bind=self._sqlalchemy_engine, compile_kwargs={"literal_binds": True}))

if not dry_run:
Expand Down
7 changes: 4 additions & 3 deletions src/dbally/views/structured.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ async def ask(
examples = self.list_few_shots()

try:
iql = await iql_generator.generate_iql(
iql = await iql_generator.generate(
question=query,
filters=filters,
examples=examples,
Expand All @@ -90,10 +90,11 @@ async def ask(
aggregation=None,
) from exc

await self.apply_filters(iql)
if iql:
await self.apply_filters(iql)

result = self.execute(dry_run=dry_run)
result.context["iql"] = f"{iql}"
result.context["iql"] = str(iql) if iql else None

return result

Expand Down
12 changes: 6 additions & 6 deletions tests/unit/iql/test_iql_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
from dbally.iql import IQLArgumentParsingError, IQLQuery, IQLUnsupportedSyntaxError, syntax
from dbally.iql._exceptions import (
IQLArgumentValidationError,
IQLEmptyExpressionError,
IQLFunctionNotExists,
IQLIncorrectNumberArgumentsError,
IQLMultipleExpressionsError,
IQLMultipleStatementsError,
IQLNoExpressionError,
IQLNoStatementError,
IQLSyntaxError,
)
from dbally.iql._processor import IQLProcessor
Expand Down Expand Up @@ -95,7 +95,7 @@ async def test_iql_parser_syntax_error():


async def test_iql_parser_multiple_expression_error():
with pytest.raises(IQLMultipleExpressionsError) as exc_info:
with pytest.raises(IQLMultipleStatementsError) as exc_info:
await IQLQuery.parse(
"filter_by_age\nfilter_by_age",
allowed_functions=[
Expand All @@ -109,11 +109,11 @@ async def test_iql_parser_multiple_expression_error():
],
)

assert exc_info.match(re.escape("Multiple expressions or statements in IQL are not supported"))
assert exc_info.match(re.escape("Multiple statements in IQL are not supported"))


async def test_iql_parser_empty_expression_error():
with pytest.raises(IQLEmptyExpressionError) as exc_info:
with pytest.raises(IQLNoStatementError) as exc_info:
await IQLQuery.parse(
"",
allowed_functions=[
Expand All @@ -127,7 +127,7 @@ async def test_iql_parser_empty_expression_error():
],
)

assert exc_info.match(re.escape("Empty IQL expression"))
assert exc_info.match(re.escape("Empty IQL"))


async def test_iql_parser_no_expression_error():
Expand Down
Loading
Loading