Skip to content

Commit 31b2d36

Browse files
committed
chore: implement floordiv_op compiler
1 parent ba93b5b commit 31b2d36

File tree

8 files changed

+386
-60
lines changed

8 files changed

+386
-60
lines changed

bigframes/core/compile/sqlglot/expressions/binary_compiler.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,12 @@
1414

1515
from __future__ import annotations
1616

17-
import bigframes_vendored.constants as constants
17+
import bigframes_vendored.constants as bf_constants
1818
import sqlglot.expressions as sge
1919

2020
from bigframes import dtypes
2121
from bigframes import operations as ops
22+
import bigframes.core.compile.sqlglot.expressions.constants as constants
2223
from bigframes.core.compile.sqlglot.expressions.op_registration import OpRegistration
2324
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
2425

@@ -69,7 +70,7 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
6970
return sge.Add(this=left.expr, expression=right.expr)
7071

7172
raise TypeError(
72-
f"Cannot add type {left.dtype} and {right.dtype}. {constants.FEEDBACK_LINK}"
73+
f"Cannot add type {left.dtype} and {right.dtype}. {bf_constants.FEEDBACK_LINK}"
7374
)
7475

7576

@@ -89,6 +90,43 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
8990
return result
9091

9192

93+
@BINARY_OP_REGISTRATION.register(ops.floordiv_op)
94+
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
95+
left_expr = left.expr
96+
if left.dtype == dtypes.BOOL_DTYPE:
97+
left_expr = sge.Cast(this=left_expr, to="INT64")
98+
right_expr = right.expr
99+
if right.dtype == dtypes.BOOL_DTYPE:
100+
right_expr = sge.Cast(this=right_expr, to="INT64")
101+
102+
result: sge.Expression = sge.Cast(
103+
this=sge.Floor(this=sge.func("IEEE_DIVIDE", left_expr, right_expr)), to="INT64"
104+
)
105+
106+
# DIV(N, 0) will error in bigquery, but needs to return `0` for int, and
107+
# `inf`` for float in BQ so we short-circuit in this case.
108+
# Multiplying left by zero propogates nulls.
109+
zero_result = (
110+
constants._INF
111+
if (left.dtype == dtypes.FLOAT_DTYPE or right.dtype == dtypes.FLOAT_DTYPE)
112+
else constants._ZERO
113+
)
114+
result = sge.Case(
115+
ifs=[
116+
sge.If(
117+
this=sge.EQ(this=right_expr, expression=constants._ZERO),
118+
true=zero_result * left_expr,
119+
)
120+
],
121+
default=result,
122+
)
123+
124+
if dtypes.is_numeric(right.dtype) and left.dtype == dtypes.TIMEDELTA_DTYPE:
125+
result = sge.Cast(this=sge.Floor(this=result), to="INT64")
126+
127+
return result
128+
129+
92130
@BINARY_OP_REGISTRATION.register(ops.ge_op)
93131
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
94132
return sge.GTE(this=left.expr, expression=right.expr)
@@ -156,5 +194,5 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
156194
return sge.Sub(this=left.expr, expression=right.expr)
157195

158196
raise TypeError(
159-
f"Cannot subtract type {left.dtype} and {right.dtype}. {constants.FEEDBACK_LINK}"
197+
f"Cannot subtract type {left.dtype} and {right.dtype}. {bf_constants.FEEDBACK_LINK}"
160198
)
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import sqlglot.expressions as sge
16+
17+
_ZERO = sge.Cast(this=sge.convert(0), to="INT64")
18+
_NAN = sge.Cast(this=sge.convert("NaN"), to="FLOAT64")
19+
_INF = sge.Cast(this=sge.convert("Infinity"), to="FLOAT64")
20+
21+
# Approx Highest number you can pass in to EXP function and get a valid FLOAT64 result
22+
# FLOAT64 has 11 exponent bits, so max values is about 2**(2**10)
23+
# ln(2**(2**10)) == (2**10)*ln(2) ~= 709.78, so EXP(x) for x>709.78 will overflow.
24+
_FLOAT64_EXP_BOUND = sge.convert(709.78)

bigframes/core/compile/sqlglot/expressions/unary_compiler.py

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,10 @@
2222
import sqlglot.expressions as sge
2323

2424
from bigframes import operations as ops
25+
import bigframes.core.compile.sqlglot.expressions.constants as constants
2526
from bigframes.core.compile.sqlglot.expressions.op_registration import OpRegistration
2627
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
2728

28-
_NAN = sge.Cast(this=sge.convert("NaN"), to="FLOAT64")
29-
_INF = sge.Cast(this=sge.convert("Infinity"), to="FLOAT64")
30-
31-
# Approx Highest number you can pass in to EXP function and get a valid FLOAT64 result
32-
# FLOAT64 has 11 exponent bits, so max values is about 2**(2**10)
33-
# ln(2**(2**10)) == (2**10)*ln(2) ~= 709.78, so EXP(x) for x>709.78 will overflow.
34-
_FLOAT64_EXP_BOUND = sge.convert(709.78)
35-
3629
UNARY_OP_REGISTRATION = OpRegistration()
3730

3831

@@ -51,7 +44,7 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
5144
ifs=[
5245
sge.If(
5346
this=expr.expr < sge.convert(1),
54-
true=_NAN,
47+
true=constants._NAN,
5548
)
5649
],
5750
default=sge.func("ACOSH", expr.expr),
@@ -64,7 +57,7 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
6457
ifs=[
6558
sge.If(
6659
this=sge.func("ABS", expr.expr) > sge.convert(1),
67-
true=_NAN,
60+
true=constants._NAN,
6861
)
6962
],
7063
default=sge.func("ACOS", expr.expr),
@@ -77,7 +70,7 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
7770
ifs=[
7871
sge.If(
7972
this=sge.func("ABS", expr.expr) > sge.convert(1),
80-
true=_NAN,
73+
true=constants._NAN,
8174
)
8275
],
8376
default=sge.func("ASIN", expr.expr),
@@ -100,7 +93,7 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
10093
ifs=[
10194
sge.If(
10295
this=sge.func("ABS", expr.expr) > sge.convert(1),
103-
true=_NAN,
96+
true=constants._NAN,
10497
)
10598
],
10699
default=sge.func("ATANH", expr.expr),
@@ -176,7 +169,7 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
176169
ifs=[
177170
sge.If(
178171
this=sge.func("ABS", expr.expr) > sge.convert(709.78),
179-
true=_INF,
172+
true=constants._INF,
180173
)
181174
],
182175
default=sge.func("COSH", expr.expr),
@@ -221,8 +214,8 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
221214
return sge.Case(
222215
ifs=[
223216
sge.If(
224-
this=expr.expr > _FLOAT64_EXP_BOUND,
225-
true=_INF,
217+
this=expr.expr > constants._FLOAT64_EXP_BOUND,
218+
true=constants._INF,
226219
)
227220
],
228221
default=sge.func("EXP", expr.expr),
@@ -234,8 +227,8 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
234227
return sge.Case(
235228
ifs=[
236229
sge.If(
237-
this=expr.expr > _FLOAT64_EXP_BOUND,
238-
true=_INF,
230+
this=expr.expr > constants._FLOAT64_EXP_BOUND,
231+
true=constants._INF,
239232
)
240233
],
241234
default=sge.func("EXP", expr.expr),
@@ -382,7 +375,7 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
382375
ifs=[
383376
sge.If(
384377
this=expr.expr < sge.convert(0),
385-
true=_NAN,
378+
true=constants._NAN,
386379
)
387380
],
388381
default=sge.Ln(this=expr.expr),
@@ -395,7 +388,7 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
395388
ifs=[
396389
sge.If(
397390
this=expr.expr < sge.convert(0),
398-
true=_NAN,
391+
true=constants._NAN,
399392
)
400393
],
401394
default=sge.Log(this=expr.expr, expression=sge.convert(10)),
@@ -408,7 +401,7 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
408401
ifs=[
409402
sge.If(
410403
this=expr.expr < sge.convert(-1),
411-
true=_NAN,
404+
true=constants._NAN,
412405
)
413406
],
414407
default=sge.Ln(this=sge.convert(1) + expr.expr),
@@ -476,7 +469,7 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
476469
ifs=[
477470
sge.If(
478471
this=expr.expr < sge.convert(0),
479-
true=_NAN,
472+
true=constants._NAN,
480473
)
481474
],
482475
default=sge.Sqrt(this=expr.expr),
@@ -523,8 +516,8 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
523516
return sge.Case(
524517
ifs=[
525518
sge.If(
526-
this=sge.func("ABS", expr.expr) > _FLOAT64_EXP_BOUND,
527-
true=sge.func("SIGN", expr.expr) * _INF,
519+
this=sge.func("ABS", expr.expr) > constants._FLOAT64_EXP_BOUND,
520+
true=sge.func("SIGN", expr.expr) * constants._INF,
528521
)
529522
],
530523
default=sge.func("SINH", expr.expr),

tests/system/small/engines/test_numeric_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def test_engines_project_div_durations(
117117
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
118118

119119

120-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
120+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
121121
def test_engines_project_floordiv(
122122
scalars_array_value: array_value.ArrayValue,
123123
engine,
@@ -130,7 +130,7 @@ def test_engines_project_floordiv(
130130
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
131131

132132

133-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
133+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
134134
def test_engines_project_floordiv_durations(
135135
scalars_array_value: array_value.ArrayValue, engine
136136
):

0 commit comments

Comments
 (0)