Skip to content

Commit afe4331

Browse files
authored
refactor: support agg_ops.DenseRankOp and RankOp for sqlglot compiler (#2114)
1 parent 7ef667b commit afe4331

File tree

6 files changed

+105
-1
lines changed

6 files changed

+105
-1
lines changed

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,18 @@ def _(
4747
return apply_window_if_present(sge.func("COUNT", column.expr), window)
4848

4949

50+
@UNARY_OP_REGISTRATION.register(agg_ops.DenseRankOp)
51+
def _(
52+
op: agg_ops.DenseRankOp,
53+
column: typed_expr.TypedExpr,
54+
window: typing.Optional[window_spec.WindowSpec] = None,
55+
) -> sge.Expression:
56+
# Ranking functions do not support window framing clauses.
57+
return apply_window_if_present(
58+
sge.func("DENSE_RANK"), window, include_framing_clauses=False
59+
)
60+
61+
5062
@UNARY_OP_REGISTRATION.register(agg_ops.MaxOp)
5163
def _(
5264
op: agg_ops.MaxOp,
@@ -106,6 +118,18 @@ def _(
106118
return apply_window_if_present(sge.func("COUNT", sge.convert(1)), window)
107119

108120

121+
@UNARY_OP_REGISTRATION.register(agg_ops.RankOp)
122+
def _(
123+
op: agg_ops.RankOp,
124+
column: typed_expr.TypedExpr,
125+
window: typing.Optional[window_spec.WindowSpec] = None,
126+
) -> sge.Expression:
127+
# Ranking functions do not support window framing clauses.
128+
return apply_window_if_present(
129+
sge.func("RANK"), window, include_framing_clauses=False
130+
)
131+
132+
109133
@UNARY_OP_REGISTRATION.register(agg_ops.SumOp)
110134
def _(
111135
op: agg_ops.SumOp,

bigframes/core/compile/sqlglot/aggregations/windows.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
def apply_window_if_present(
2626
value: sge.Expression,
2727
window: typing.Optional[window_spec.WindowSpec] = None,
28+
include_framing_clauses: bool = True,
2829
) -> sge.Expression:
2930
if window is None:
3031
return value
@@ -64,6 +65,9 @@ def apply_window_if_present(
6465
if not window.bounds and not order:
6566
return sge.Window(this=value, partition_by=group_by)
6667

68+
if not window.bounds and not include_framing_clauses:
69+
return sge.Window(this=value, partition_by=group_by, order=order)
70+
6771
kind = (
6872
"ROWS" if isinstance(window.bounds, window_spec.RowsWindowBounds) else "RANGE"
6973
)

bigframes/operations/aggregations.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,8 @@ def implicitly_inherits_order(self):
519519

520520
@dataclasses.dataclass(frozen=True)
521521
class DenseRankOp(UnaryWindowOp):
522+
name: ClassVar[str] = "dense_rank"
523+
522524
@property
523525
def skips_nulls(self):
524526
return False
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
DENSE_RANK() OVER (ORDER BY `bfcol_0` IS NULL ASC NULLS LAST, `bfcol_0` ASC NULLS LAST) AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `agg_int64`
13+
FROM `bfcte_1`
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
RANK() OVER (ORDER BY `bfcol_0` IS NULL ASC NULLS LAST, `bfcol_0` ASC NULLS LAST) AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `agg_int64`
13+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,14 @@
1717
import pytest
1818

1919
from bigframes.core import agg_expressions as agg_exprs
20-
from bigframes.core import array_value, identifiers, nodes
20+
from bigframes.core import (
21+
array_value,
22+
expression,
23+
identifiers,
24+
nodes,
25+
ordering,
26+
window_spec,
27+
)
2128
from bigframes.operations import aggregations as agg_ops
2229
import bigframes.pandas as bpd
2330

@@ -38,6 +45,24 @@ def _apply_unary_agg_ops(
3845
return sql
3946

4047

48+
def _apply_unary_window_op(
49+
obj: bpd.DataFrame,
50+
op: agg_exprs.UnaryAggregation,
51+
window_spec: window_spec.WindowSpec,
52+
new_name: str,
53+
) -> str:
54+
win_node = nodes.WindowOpNode(
55+
obj._block.expr.node,
56+
expression=op,
57+
window_spec=window_spec,
58+
output_name=identifiers.ColumnId(new_name),
59+
)
60+
result = array_value.ArrayValue(win_node).select_columns([new_name])
61+
62+
sql = result.session._executor.to_sql(result, enable_cache=False)
63+
return sql
64+
65+
4166
def test_count(scalar_types_df: bpd.DataFrame, snapshot):
4267
col_name = "int64_col"
4368
bf_df = scalar_types_df[[col_name]]
@@ -47,6 +72,18 @@ def test_count(scalar_types_df: bpd.DataFrame, snapshot):
4772
snapshot.assert_match(sql, "out.sql")
4873

4974

75+
def test_dense_rank(scalar_types_df: bpd.DataFrame, snapshot):
76+
col_name = "int64_col"
77+
bf_df = scalar_types_df[[col_name]]
78+
agg_expr = agg_exprs.UnaryAggregation(
79+
agg_ops.DenseRankOp(), expression.deref(col_name)
80+
)
81+
window = window_spec.WindowSpec(ordering=(ordering.ascending_over(col_name),))
82+
sql = _apply_unary_window_op(bf_df, agg_expr, window, "agg_int64")
83+
84+
snapshot.assert_match(sql, "out.sql")
85+
86+
5087
def test_max(scalar_types_df: bpd.DataFrame, snapshot):
5188
col_name = "int64_col"
5289
bf_df = scalar_types_df[[col_name]]
@@ -104,6 +141,17 @@ def test_min(scalar_types_df: bpd.DataFrame, snapshot):
104141
snapshot.assert_match(sql, "out.sql")
105142

106143

144+
def test_rank(scalar_types_df: bpd.DataFrame, snapshot):
145+
col_name = "int64_col"
146+
bf_df = scalar_types_df[[col_name]]
147+
agg_expr = agg_exprs.UnaryAggregation(agg_ops.RankOp(), expression.deref(col_name))
148+
149+
window = window_spec.WindowSpec(ordering=(ordering.ascending_over(col_name),))
150+
sql = _apply_unary_window_op(bf_df, agg_expr, window, "agg_int64")
151+
152+
snapshot.assert_match(sql, "out.sql")
153+
154+
107155
def test_sum(scalar_types_df: bpd.DataFrame, snapshot):
108156
bf_df = scalar_types_df[["int64_col", "bool_col"]]
109157
agg_ops_map = {

0 commit comments

Comments
 (0)