Skip to content

Commit 139ab9b

Browse files
feat: aggregations in structured views (#62)
1 parent 23e50ff commit 139ab9b

29 files changed

+1235
-423
lines changed

benchmarks/sql/bench/pipelines/base.py

+25-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
from abc import ABC, abstractmethod
22
from dataclasses import dataclass
3-
from typing import Any, Dict, Optional
3+
from typing import Any, Dict, Optional, Union
44

5+
from dbally.iql._exceptions import IQLError
6+
from dbally.iql._query import IQLQuery
7+
from dbally.iql_generator.prompt import UnsupportedQueryError
58
from dbally.llms.base import LLM
9+
from dbally.llms.clients.exceptions import LLMError
610
from dbally.llms.litellm import LiteLLM
711
from dbally.llms.local import LocalLLM
812

@@ -16,6 +20,25 @@ class IQL:
1620
source: Optional[str] = None
1721
unsupported: bool = False
1822
valid: bool = True
23+
generated: bool = True
24+
25+
@classmethod
26+
def from_query(cls, query: Optional[Union[IQLQuery, Exception]]) -> "IQL":
27+
"""
28+
Creates an IQL object from the query.
29+
30+
Args:
31+
query: The IQL query or exception.
32+
33+
Returns:
34+
The IQL object.
35+
"""
36+
return cls(
37+
source=query.source if isinstance(query, (IQLQuery, IQLError)) else None,
38+
unsupported=isinstance(query, UnsupportedQueryError),
39+
valid=not isinstance(query, IQLError),
40+
generated=not isinstance(query, LLMError),
41+
)
1942

2043

2144
@dataclass
@@ -47,6 +70,7 @@ class EvaluationResult:
4770
"""
4871

4972
db_id: str
73+
question_id: str
5074
question: str
5175
reference: ExecutionResult
5276
prediction: ExecutionResult

benchmarks/sql/bench/pipelines/collection.py

+9-31
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,8 @@
55
import dbally
66
from dbally.collection.collection import Collection
77
from dbally.collection.exceptions import NoViewFoundError
8-
from dbally.iql._exceptions import IQLError
9-
from dbally.iql_generator.prompt import UnsupportedQueryError
108
from dbally.view_selection.llm_view_selector import LLMViewSelector
11-
from dbally.views.exceptions import IQLGenerationError
9+
from dbally.views.exceptions import ViewExecutionError
1210

1311
from ..views import VIEWS_REGISTRY
1412
from .base import IQL, EvaluationPipeline, EvaluationResult, ExecutionResult, IQLResult
@@ -74,44 +72,23 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult:
7472
return_natural_response=False,
7573
)
7674
except NoViewFoundError:
77-
prediction = ExecutionResult(
78-
view_name=None,
79-
iql=None,
80-
sql=None,
81-
)
82-
except IQLGenerationError as exc:
75+
prediction = ExecutionResult()
76+
except ViewExecutionError as exc:
8377
prediction = ExecutionResult(
8478
view_name=exc.view_name,
8579
iql=IQLResult(
86-
filters=IQL(
87-
source=exc.filters,
88-
unsupported=isinstance(exc.__cause__, UnsupportedQueryError),
89-
valid=not (exc.filters and not exc.aggregation and isinstance(exc.__cause__, IQLError)),
90-
),
91-
aggregation=IQL(
92-
source=exc.aggregation,
93-
unsupported=isinstance(exc.__cause__, UnsupportedQueryError),
94-
valid=not (exc.aggregation and isinstance(exc.__cause__, IQLError)),
95-
),
80+
filters=IQL.from_query(exc.iql.filters),
81+
aggregation=IQL.from_query(exc.iql.aggregation),
9682
),
97-
sql=None,
9883
)
9984
else:
10085
prediction = ExecutionResult(
10186
view_name=result.view_name,
10287
iql=IQLResult(
103-
filters=IQL(
104-
source=result.context.get("iql"),
105-
unsupported=False,
106-
valid=True,
107-
),
108-
aggregation=IQL(
109-
source=None,
110-
unsupported=False,
111-
valid=True,
112-
),
88+
filters=IQL(source=result.context["iql"]["filters"]),
89+
aggregation=IQL(source=result.context["iql"]["aggregation"]),
11390
),
114-
sql=result.context.get("sql"),
91+
sql=result.context["sql"],
11592
)
11693

11794
reference = ExecutionResult(
@@ -134,6 +111,7 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult:
134111

135112
return EvaluationResult(
136113
db_id=data["db_id"],
114+
question_id=data["question_id"],
137115
question=data["question"],
138116
reference=reference,
139117
prediction=prediction,

benchmarks/sql/bench/pipelines/view.py

+8-27
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,7 @@
55

66
from sqlalchemy import create_engine
77

8-
from dbally.iql._exceptions import IQLError
9-
from dbally.iql_generator.prompt import UnsupportedQueryError
10-
from dbally.views.exceptions import IQLGenerationError
8+
from dbally.views.exceptions import ViewExecutionError
119
from dbally.views.freeform.text2sql.view import BaseText2SQLView
1210
from dbally.views.sqlalchemy_base import SqlAlchemyBaseView
1311

@@ -94,37 +92,20 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult:
9492
dry_run=True,
9593
n_retries=0,
9694
)
97-
except IQLGenerationError as exc:
95+
except ViewExecutionError as exc:
9896
prediction = ExecutionResult(
9997
view_name=data["view_name"],
10098
iql=IQLResult(
101-
filters=IQL(
102-
source=exc.filters,
103-
unsupported=isinstance(exc.__cause__, UnsupportedQueryError),
104-
valid=not (exc.filters and not exc.aggregation and isinstance(exc.__cause__, IQLError)),
105-
),
106-
aggregation=IQL(
107-
source=exc.aggregation,
108-
unsupported=isinstance(exc.__cause__, UnsupportedQueryError),
109-
valid=not (exc.aggregation and isinstance(exc.__cause__, IQLError)),
110-
),
99+
filters=IQL.from_query(exc.iql.filters),
100+
aggregation=IQL.from_query(exc.iql.aggregation),
111101
),
112-
sql=None,
113102
)
114103
else:
115104
prediction = ExecutionResult(
116105
view_name=data["view_name"],
117106
iql=IQLResult(
118-
filters=IQL(
119-
source=result.context["iql"],
120-
unsupported=False,
121-
valid=True,
122-
),
123-
aggregation=IQL(
124-
source=None,
125-
unsupported=False,
126-
valid=True,
127-
),
107+
filters=IQL(source=result.context["iql"]["filters"]),
108+
aggregation=IQL(source=result.context["iql"]["aggregation"]),
128109
),
129110
sql=result.context["sql"],
130111
)
@@ -135,12 +116,10 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult:
135116
filters=IQL(
136117
source=data["iql_filters"],
137118
unsupported=data["iql_filters_unsupported"],
138-
valid=True,
139119
),
140120
aggregation=IQL(
141121
source=data["iql_aggregation"],
142122
unsupported=data["iql_aggregation_unsupported"],
143-
valid=True,
144123
),
145124
context=data["iql_context"],
146125
),
@@ -149,6 +128,7 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult:
149128

150129
return EvaluationResult(
151130
db_id=data["db_id"],
131+
question_id=data["question_id"],
152132
question=data["question"],
153133
reference=reference,
154134
prediction=prediction,
@@ -209,6 +189,7 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult:
209189

210190
return EvaluationResult(
211191
db_id=data["db_id"],
192+
question_id=data["question_id"],
212193
question=data["question"],
213194
reference=reference,
214195
prediction=prediction,

benchmarks/sql/bench/views/structured/superhero.py

+20-3
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from sqlalchemy.ext.declarative import DeferredReflection
88
from sqlalchemy.orm import aliased, declarative_base
99

10-
from dbally.views.decorators import view_filter
10+
from dbally.views.decorators import view_aggregation, view_filter
1111
from dbally.views.sqlalchemy_base import SqlAlchemyBaseView
1212

1313
Base = declarative_base(cls=DeferredReflection)
@@ -285,8 +285,8 @@ class SuperheroColourFilterMixin:
285285
Mixin for filtering the view by the superhero colour attributes.
286286
"""
287287

288-
def __init__(self) -> None:
289-
super().__init__()
288+
def __init__(self, *args, **kwargs) -> None:
289+
super().__init__(*args, **kwargs)
290290
self.eye_colour = aliased(Colour)
291291
self.hair_colour = aliased(Colour)
292292
self.skin_colour = aliased(Colour)
@@ -427,10 +427,27 @@ def filter_by_race(self, race: str) -> ColumnElement:
427427
return Race.race == race
428428

429429

430+
class SuperheroAggregationMixin:
431+
"""
432+
Mixin for aggregating the view by the superhero attributes.
433+
"""
434+
435+
@view_aggregation()
436+
def count_superheroes(self) -> Select:
437+
"""
438+
Counts the number of superheros.
439+
440+
Returns:
441+
The superheros count.
442+
"""
443+
return self.select.with_only_columns(func.count(Superhero.id).label("count_superheroes")).group_by(Superhero.id)
444+
445+
430446
class SuperheroView(
431447
DBInitMixin,
432448
SqlAlchemyBaseView,
433449
SuperheroFilterMixin,
450+
SuperheroAggregationMixin,
434451
SuperheroColourFilterMixin,
435452
AlignmentFilterMixin,
436453
GenderFilterMixin,

docs/how-to/views/custom_views_code.py

+1
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def execute(self, dry_run: bool = False) -> ViewExecutionResult:
6666

6767
return ViewExecutionResult(results=filtered_data, context={})
6868

69+
6970
class CandidateView(FilteredIterableBaseView):
7071
def get_data(self) -> Iterable:
7172
return [

docs/quickstart/quickstart_code.py

-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from sqlalchemy import create_engine
77
from sqlalchemy.ext.automap import automap_base
88

9-
import dbally
109
from dbally import decorators, SqlAlchemyBaseView
1110
from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler
1211
from dbally.llms.litellm import LiteLLM

src/dbally/iql/__init__.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
11
from . import syntax
22
from ._exceptions import IQLArgumentParsingError, IQLError, IQLUnsupportedSyntaxError
3-
from ._query import IQLQuery
3+
from ._query import IQLAggregationQuery, IQLFiltersQuery, IQLQuery
44

5-
__all__ = ["IQLQuery", "syntax", "IQLError", "IQLArgumentParsingError", "IQLUnsupportedSyntaxError"]
5+
__all__ = [
6+
"IQLQuery",
7+
"IQLFiltersQuery",
8+
"IQLAggregationQuery",
9+
"syntax",
10+
"IQLError",
11+
"IQLArgumentParsingError",
12+
"IQLUnsupportedSyntaxError",
13+
]

src/dbally/iql/_processor.py

+55-22
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import ast
2-
from typing import TYPE_CHECKING, Any, List, Optional, Union
2+
from abc import ABC, abstractmethod
3+
from typing import TYPE_CHECKING, Any, Generic, List, Optional, TypeVar, Union
34

45
from dbally.audit.event_tracker import EventTracker
56
from dbally.iql import syntax
@@ -19,10 +20,12 @@
1920
if TYPE_CHECKING:
2021
from dbally.views.structured import ExposedFunction
2122

23+
RootT = TypeVar("RootT", bound=syntax.Node)
2224

23-
class IQLProcessor:
25+
26+
class IQLProcessor(Generic[RootT], ABC):
2427
"""
25-
Parses IQL string to tree structure.
28+
Base class for IQL processors.
2629
"""
2730

2831
def __init__(
@@ -32,9 +35,9 @@ def __init__(
3235
self.allowed_functions = {func.name: func for func in allowed_functions}
3336
self._event_tracker = event_tracker or EventTracker()
3437

35-
async def process(self) -> syntax.Node:
38+
async def process(self) -> RootT:
3639
"""
37-
Process IQL string to root IQL.Node.
40+
Process IQL string to IQL root node.
3841
3942
Returns:
4043
IQL node which is root of the tree representing IQL query.
@@ -60,25 +63,17 @@ async def process(self) -> syntax.Node:
6063

6164
return await self._parse_node(ast_tree.body[0].value)
6265

63-
async def _parse_node(self, node: Union[ast.expr, ast.Expr]) -> syntax.Node:
64-
if isinstance(node, ast.BoolOp):
65-
return await self._parse_bool_op(node)
66-
if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.Not):
67-
return syntax.Not(await self._parse_node(node.operand))
68-
if isinstance(node, ast.Call):
69-
return await self._parse_call(node)
70-
71-
raise IQLUnsupportedSyntaxError(node, self.source)
66+
@abstractmethod
67+
async def _parse_node(self, node: Union[ast.expr, ast.Expr]) -> RootT:
68+
"""
69+
Parses AST node to IQL node.
7270
73-
async def _parse_bool_op(self, node: ast.BoolOp) -> syntax.BoolOp:
74-
if isinstance(node.op, ast.Not):
75-
return syntax.Not(await self._parse_node(node.values[0]))
76-
if isinstance(node.op, ast.And):
77-
return syntax.And([await self._parse_node(x) for x in node.values])
78-
if isinstance(node.op, ast.Or):
79-
return syntax.Or([await self._parse_node(x) for x in node.values])
71+
Args:
72+
node: AST node to parse.
8073
81-
raise IQLUnsupportedSyntaxError(node, self.source, context="BoolOp")
74+
Returns:
75+
IQL node.
76+
"""
8277

8378
async def _parse_call(self, node: ast.Call) -> syntax.FunctionCall:
8479
func = node.func
@@ -153,3 +148,41 @@ def _to_lower_except_in_quotes(text: str, keywords: List[str]) -> str:
153148
converted_text = converted_text[: len(converted_text) - len(keyword)] + keyword.lower()
154149

155150
return converted_text
151+
152+
153+
class IQLFiltersProcessor(IQLProcessor[syntax.Node]):
154+
"""
155+
IQL processor for filters.
156+
"""
157+
158+
async def _parse_node(self, node: Union[ast.expr, ast.Expr]) -> syntax.Node:
159+
if isinstance(node, ast.BoolOp):
160+
return await self._parse_bool_op(node)
161+
if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.Not):
162+
return syntax.Not(await self._parse_node(node.operand))
163+
if isinstance(node, ast.Call):
164+
return await self._parse_call(node)
165+
166+
raise IQLUnsupportedSyntaxError(node, self.source)
167+
168+
async def _parse_bool_op(self, node: ast.BoolOp) -> syntax.BoolOp:
169+
if isinstance(node.op, ast.Not):
170+
return syntax.Not(await self._parse_node(node.values[0]))
171+
if isinstance(node.op, ast.And):
172+
return syntax.And([await self._parse_node(x) for x in node.values])
173+
if isinstance(node.op, ast.Or):
174+
return syntax.Or([await self._parse_node(x) for x in node.values])
175+
176+
raise IQLUnsupportedSyntaxError(node, self.source, context="BoolOp")
177+
178+
179+
class IQLAggregationProcessor(IQLProcessor[syntax.FunctionCall]):
180+
"""
181+
IQL processor for aggregation.
182+
"""
183+
184+
async def _parse_node(self, node: Union[ast.expr, ast.Expr]) -> syntax.FunctionCall:
185+
if isinstance(node, ast.Call):
186+
return await self._parse_call(node)
187+
188+
raise IQLUnsupportedSyntaxError(node, self.source)

0 commit comments

Comments
 (0)