Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

review: aggregations in structured views #85

Merged
merged 7 commits into from
Aug 26, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion benchmarks/sql/bench/pipelines/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union

from dbally.iql._exceptions import IQLError
from dbally.iql._query import IQLQuery
from dbally.iql_generator.prompt import UnsupportedQueryError
from dbally.llms.base import LLM
from dbally.llms.clients.exceptions import LLMError
from dbally.llms.litellm import LiteLLM
from dbally.llms.local import LocalLLM

@@ -16,6 +20,25 @@ class IQL:
source: Optional[str] = None
unsupported: bool = False
valid: bool = True
generated: bool = True

@classmethod
def from_query(cls, query: Optional[Union[IQLQuery, Exception]]) -> "IQL":
"""
Creates an IQL object from the query.

Args:
query: The IQL query or exception.

Returns:
The IQL object.
"""
return cls(
source=query.source if isinstance(query, (IQLQuery, IQLError)) else None,
unsupported=isinstance(query, UnsupportedQueryError),
valid=not isinstance(query, IQLError),
generated=not isinstance(query, LLMError),
)


@dataclass
@@ -47,6 +70,7 @@ class EvaluationResult:
"""

db_id: str
question_id: str
question: str
reference: ExecutionResult
prediction: ExecutionResult
40 changes: 9 additions & 31 deletions benchmarks/sql/bench/pipelines/collection.py
Original file line number Diff line number Diff line change
@@ -5,10 +5,8 @@
import dbally
from dbally.collection.collection import Collection
from dbally.collection.exceptions import NoViewFoundError
from dbally.iql._exceptions import IQLError
from dbally.iql_generator.prompt import UnsupportedQueryError
from dbally.view_selection.llm_view_selector import LLMViewSelector
from dbally.views.exceptions import IQLGenerationError
from dbally.views.exceptions import ViewExecutionError

from ..views import VIEWS_REGISTRY
from .base import IQL, EvaluationPipeline, EvaluationResult, ExecutionResult, IQLResult
@@ -74,44 +72,23 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult:
return_natural_response=False,
)
except NoViewFoundError:
prediction = ExecutionResult(
view_name=None,
iql=None,
sql=None,
)
except IQLGenerationError as exc:
prediction = ExecutionResult()
except ViewExecutionError as exc:
prediction = ExecutionResult(
view_name=exc.view_name,
iql=IQLResult(
filters=IQL(
source=exc.filters,
unsupported=isinstance(exc.__cause__, UnsupportedQueryError),
valid=not (exc.filters and not exc.aggregation and isinstance(exc.__cause__, IQLError)),
),
aggregation=IQL(
source=exc.aggregation,
unsupported=isinstance(exc.__cause__, UnsupportedQueryError),
valid=not (exc.aggregation and isinstance(exc.__cause__, IQLError)),
),
filters=IQL.from_query(exc.iql.filters),
aggregation=IQL.from_query(exc.iql.aggregation),
),
sql=None,
)
else:
prediction = ExecutionResult(
view_name=result.view_name,
iql=IQLResult(
filters=IQL(
source=result.context.get("iql"),
unsupported=False,
valid=True,
),
aggregation=IQL(
source=None,
unsupported=False,
valid=True,
),
filters=IQL(source=result.context["iql"]["filters"]),
aggregation=IQL(source=result.context["iql"]["aggregation"]),
),
sql=result.context.get("sql"),
sql=result.context["sql"],
)

reference = ExecutionResult(
@@ -134,6 +111,7 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult:

return EvaluationResult(
db_id=data["db_id"],
question_id=data["question_id"],
question=data["question"],
reference=reference,
prediction=prediction,
35 changes: 8 additions & 27 deletions benchmarks/sql/bench/pipelines/view.py
Original file line number Diff line number Diff line change
@@ -5,9 +5,7 @@

from sqlalchemy import create_engine

from dbally.iql._exceptions import IQLError
from dbally.iql_generator.prompt import UnsupportedQueryError
from dbally.views.exceptions import IQLGenerationError
from dbally.views.exceptions import ViewExecutionError
from dbally.views.freeform.text2sql.view import BaseText2SQLView
from dbally.views.sqlalchemy_base import SqlAlchemyBaseView

@@ -94,37 +92,20 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult:
dry_run=True,
n_retries=0,
)
except IQLGenerationError as exc:
except ViewExecutionError as exc:
prediction = ExecutionResult(
view_name=data["view_name"],
iql=IQLResult(
filters=IQL(
source=exc.filters,
unsupported=isinstance(exc.__cause__, UnsupportedQueryError),
valid=not (exc.filters and not exc.aggregation and isinstance(exc.__cause__, IQLError)),
),
aggregation=IQL(
source=exc.aggregation,
unsupported=isinstance(exc.__cause__, UnsupportedQueryError),
valid=not (exc.aggregation and isinstance(exc.__cause__, IQLError)),
),
filters=IQL.from_query(exc.iql.filters),
aggregation=IQL.from_query(exc.iql.aggregation),
),
sql=None,
)
else:
prediction = ExecutionResult(
view_name=data["view_name"],
iql=IQLResult(
filters=IQL(
source=result.context["iql"],
unsupported=False,
valid=True,
),
aggregation=IQL(
source=None,
unsupported=False,
valid=True,
),
filters=IQL(source=result.context["iql"]["filters"]),
aggregation=IQL(source=result.context["iql"]["aggregation"]),
),
sql=result.context["sql"],
)
@@ -135,12 +116,10 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult:
filters=IQL(
source=data["iql_filters"],
unsupported=data["iql_filters_unsupported"],
valid=True,
),
aggregation=IQL(
source=data["iql_aggregation"],
unsupported=data["iql_aggregation_unsupported"],
valid=True,
),
context=data["iql_context"],
),
@@ -149,6 +128,7 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult:

return EvaluationResult(
db_id=data["db_id"],
question_id=data["question_id"],
question=data["question"],
reference=reference,
prediction=prediction,
@@ -209,6 +189,7 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult:

return EvaluationResult(
db_id=data["db_id"],
question_id=data["question_id"],
question=data["question"],
reference=reference,
prediction=prediction,
7 changes: 3 additions & 4 deletions benchmarks/sql/bench/views/structured/superhero.py
Original file line number Diff line number Diff line change
@@ -286,12 +286,11 @@ class SuperheroColourFilterMixin:
"""

def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.eye_colour = aliased(Colour)
self.hair_colour = aliased(Colour)
self.skin_colour = aliased(Colour)

super().__init__(*args, **kwargs)

@view_filter()
def filter_by_eye_colour(self, eye_colour: str) -> ColumnElement:
"""
@@ -441,19 +440,19 @@ def count_superheroes(self) -> Select:
Returns:
The superheros count.
"""
return self.data.with_only_columns(func.count(Superhero.id).label("count_superheroes")).group_by(Superhero.id)
return self.select.with_only_columns(func.count(Superhero.id).label("count_superheroes")).group_by(Superhero.id)


class SuperheroView(
DBInitMixin,
SqlAlchemyBaseView,
SuperheroFilterMixin,
SuperheroAggregationMixin,
SuperheroColourFilterMixin,
AlignmentFilterMixin,
GenderFilterMixin,
PublisherFilterMixin,
RaceFilterMixin,
SqlAlchemyBaseView,
):
"""
View for querying only superheros data. Contains the superhero id, superhero name, full name, height, weight,
7 changes: 0 additions & 7 deletions src/dbally/exceptions.py
Original file line number Diff line number Diff line change
@@ -2,10 +2,3 @@ class DbAllyError(Exception):
"""
Base class for all exceptions raised by db-ally.
"""


class UnsupportedAggregationError(DbAllyError):
"""
Error raised when AggregationFormatter is unable to construct a query
with given aggregation.
"""
12 changes: 10 additions & 2 deletions src/dbally/iql/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
from . import syntax
from ._exceptions import IQLArgumentParsingError, IQLError, IQLUnsupportedSyntaxError
from ._query import IQLQuery
from ._query import IQLAggregationQuery, IQLFiltersQuery, IQLQuery

__all__ = ["IQLQuery", "syntax", "IQLError", "IQLArgumentParsingError", "IQLUnsupportedSyntaxError"]
__all__ = [
"IQLQuery",
"IQLFiltersQuery",
"IQLAggregationQuery",
"syntax",
"IQLError",
"IQLArgumentParsingError",
"IQLUnsupportedSyntaxError",
]
77 changes: 55 additions & 22 deletions src/dbally/iql/_processor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import ast
from typing import TYPE_CHECKING, Any, List, Optional, Union
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Generic, List, Optional, TypeVar, Union

from dbally.audit.event_tracker import EventTracker
from dbally.iql import syntax
@@ -19,10 +20,12 @@
if TYPE_CHECKING:
from dbally.views.structured import ExposedFunction

RootT = TypeVar("RootT", bound=syntax.Node)

class IQLProcessor:

class IQLProcessor(Generic[RootT], ABC):
"""
Parses IQL string to tree structure.
Base class for IQL processors.
"""

def __init__(
@@ -32,9 +35,9 @@ def __init__(
self.allowed_functions = {func.name: func for func in allowed_functions}
self._event_tracker = event_tracker or EventTracker()

async def process(self) -> syntax.Node:
async def process(self) -> RootT:
"""
Process IQL string to root IQL.Node.
Process IQL string to IQL root node.

Returns:
IQL node which is root of the tree representing IQL query.
@@ -60,25 +63,17 @@ async def process(self) -> syntax.Node:

return await self._parse_node(ast_tree.body[0].value)

async def _parse_node(self, node: Union[ast.expr, ast.Expr]) -> syntax.Node:
if isinstance(node, ast.BoolOp):
return await self._parse_bool_op(node)
if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.Not):
return syntax.Not(await self._parse_node(node.operand))
if isinstance(node, ast.Call):
return await self._parse_call(node)

raise IQLUnsupportedSyntaxError(node, self.source)
@abstractmethod
async def _parse_node(self, node: Union[ast.expr, ast.Expr]) -> RootT:
"""
Parses AST node to IQL node.

async def _parse_bool_op(self, node: ast.BoolOp) -> syntax.BoolOp:
if isinstance(node.op, ast.Not):
return syntax.Not(await self._parse_node(node.values[0]))
if isinstance(node.op, ast.And):
return syntax.And([await self._parse_node(x) for x in node.values])
if isinstance(node.op, ast.Or):
return syntax.Or([await self._parse_node(x) for x in node.values])
Args:
node: AST node to parse.

raise IQLUnsupportedSyntaxError(node, self.source, context="BoolOp")
Returns:
IQL node.
"""

async def _parse_call(self, node: ast.Call) -> syntax.FunctionCall:
func = node.func
@@ -153,3 +148,41 @@ def _to_lower_except_in_quotes(text: str, keywords: List[str]) -> str:
converted_text = converted_text[: len(converted_text) - len(keyword)] + keyword.lower()

return converted_text


class IQLFiltersProcessor(IQLProcessor[syntax.Node]):
"""
IQL processor for filters.
"""

async def _parse_node(self, node: Union[ast.expr, ast.Expr]) -> syntax.Node:
if isinstance(node, ast.BoolOp):
return await self._parse_bool_op(node)
if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.Not):
return syntax.Not(await self._parse_node(node.operand))
if isinstance(node, ast.Call):
return await self._parse_call(node)

raise IQLUnsupportedSyntaxError(node, self.source)

async def _parse_bool_op(self, node: ast.BoolOp) -> syntax.BoolOp:
if isinstance(node.op, ast.Not):
return syntax.Not(await self._parse_node(node.values[0]))
if isinstance(node.op, ast.And):
return syntax.And([await self._parse_node(x) for x in node.values])
if isinstance(node.op, ast.Or):
return syntax.Or([await self._parse_node(x) for x in node.values])

raise IQLUnsupportedSyntaxError(node, self.source, context="BoolOp")


class IQLAggregationProcessor(IQLProcessor[syntax.FunctionCall]):
"""
IQL processor for aggregation.
"""

async def _parse_node(self, node: Union[ast.expr, ast.Expr]) -> syntax.FunctionCall:
if isinstance(node, ast.Call):
return await self._parse_call(node)

raise IQLUnsupportedSyntaxError(node, self.source)
Loading