Skip to content

Commit 1445c98

Browse files
authored
chore: implement floordiv_op compiler (#1995)
Fixes internal issue 430133370
1 parent 9af7130 commit 1445c98

File tree

8 files changed

+387
-60
lines changed

8 files changed

+387
-60
lines changed

bigframes/core/compile/sqlglot/expressions/binary_compiler.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,12 @@
1414

1515
from __future__ import annotations
1616

17-
import bigframes_vendored.constants as constants
17+
import bigframes_vendored.constants as bf_constants
1818
import sqlglot.expressions as sge
1919

2020
from bigframes import dtypes
2121
from bigframes import operations as ops
22+
import bigframes.core.compile.sqlglot.expressions.constants as constants
2223
from bigframes.core.compile.sqlglot.expressions.op_registration import OpRegistration
2324
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
2425

@@ -69,7 +70,7 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
6970
return sge.Add(this=left.expr, expression=right.expr)
7071

7172
raise TypeError(
72-
f"Cannot add type {left.dtype} and {right.dtype}. {constants.FEEDBACK_LINK}"
73+
f"Cannot add type {left.dtype} and {right.dtype}. {bf_constants.FEEDBACK_LINK}"
7374
)
7475

7576

@@ -89,6 +90,43 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
8990
return result
9091

9192

93+
@BINARY_OP_REGISTRATION.register(ops.floordiv_op)
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
    """Compile ``floordiv_op`` into a SQLGlot expression.

    The quotient is built as ``CAST(FLOOR(IEEE_DIVIDE(left, right)) AS INT64)``
    and wrapped in a ``CASE`` that short-circuits rows whose right operand
    equals zero.

    Args:
        op: The operator being compiled; not used by this compiler.
        left: Typed expression for the dividend.
        right: Typed expression for the divisor.

    Returns:
        The SQLGlot expression for the floor division.
    """
    # BOOL operands are cast to INT64 before taking part in the division.
    dividend = (
        sge.Cast(this=left.expr, to="INT64")
        if left.dtype == dtypes.BOOL_DTYPE
        else left.expr
    )
    divisor = (
        sge.Cast(this=right.expr, to="INT64")
        if right.dtype == dtypes.BOOL_DTYPE
        else right.expr
    )

    quotient = sge.func("IEEE_DIVIDE", dividend, divisor)
    floored: sge.Expression = sge.Cast(this=sge.Floor(this=quotient), to="INT64")

    # DIV(N, 0) will error in BigQuery, but the result needs to be `0` for
    # int and `inf` for float, so the zero-divisor case is short-circuited.
    # Multiplying by the left operand propagates nulls.
    if left.dtype == dtypes.FLOAT_DTYPE or right.dtype == dtypes.FLOAT_DTYPE:
        on_zero_divisor = constants._INF
    else:
        on_zero_divisor = constants._ZERO

    divisor_is_zero = sge.EQ(this=divisor, expression=constants._ZERO)
    zero_branch = sge.If(this=divisor_is_zero, true=on_zero_divisor * dividend)
    guarded = sge.Case(ifs=[zero_branch], default=floored)

    is_timedelta_by_number = (
        dtypes.is_numeric(right.dtype) and left.dtype == dtypes.TIMEDELTA_DTYPE
    )
    if not is_timedelta_by_number:
        return guarded

    # timedelta // numeric: floor the guarded result and cast it to INT64.
    return sge.Cast(this=sge.Floor(this=guarded), to="INT64")
128+
129+
92130
@BINARY_OP_REGISTRATION.register(ops.ge_op)
93131
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
94132
return sge.GTE(this=left.expr, expression=right.expr)
@@ -156,7 +194,7 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
156194
return sge.Sub(this=left.expr, expression=right.expr)
157195

158196
raise TypeError(
159-
f"Cannot subtract type {left.dtype} and {right.dtype}. {constants.FEEDBACK_LINK}"
197+
f"Cannot subtract type {left.dtype} and {right.dtype}. {bf_constants.FEEDBACK_LINK}"
160198
)
161199

162200

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import sqlglot.expressions as sge
16+
17+
_ZERO = sge.Cast(this=sge.convert(0), to="INT64")
18+
_NAN = sge.Cast(this=sge.convert("NaN"), to="FLOAT64")
19+
_INF = sge.Cast(this=sge.convert("Infinity"), to="FLOAT64")
20+
_NEG_INF = sge.Cast(this=sge.convert("-Infinity"), to="FLOAT64")
21+
22+
# Approximate highest number you can pass to the EXP function and still get a valid FLOAT64 result.
23+
# FLOAT64 has 11 exponent bits, so the max value is about 2**(2**10)
24+
# ln(2**(2**10)) == (2**10)*ln(2) ~= 709.78, so EXP(x) for x>709.78 will overflow.
25+
_FLOAT64_EXP_BOUND = sge.convert(709.78)

bigframes/core/compile/sqlglot/expressions/unary_compiler.py

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,10 @@
2323

2424
from bigframes import operations as ops
2525
from bigframes.core.compile.constants import UNIT_TO_US_CONVERSION_FACTORS
26+
import bigframes.core.compile.sqlglot.expressions.constants as constants
2627
from bigframes.core.compile.sqlglot.expressions.op_registration import OpRegistration
2728
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
2829

29-
_NAN = sge.Cast(this=sge.convert("NaN"), to="FLOAT64")
30-
_INF = sge.Cast(this=sge.convert("Infinity"), to="FLOAT64")
31-
32-
# Approx Highest number you can pass in to EXP function and get a valid FLOAT64 result
33-
# FLOAT64 has 11 exponent bits, so max values is about 2**(2**10)
34-
# ln(2**(2**10)) == (2**10)*ln(2) ~= 709.78, so EXP(x) for x>709.78 will overflow.
35-
_FLOAT64_EXP_BOUND = sge.convert(709.78)
36-
3730
UNARY_OP_REGISTRATION = OpRegistration()
3831

3932

@@ -52,7 +45,7 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
5245
ifs=[
5346
sge.If(
5447
this=expr.expr < sge.convert(1),
55-
true=_NAN,
48+
true=constants._NAN,
5649
)
5750
],
5851
default=sge.func("ACOSH", expr.expr),
@@ -65,7 +58,7 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
6558
ifs=[
6659
sge.If(
6760
this=sge.func("ABS", expr.expr) > sge.convert(1),
68-
true=_NAN,
61+
true=constants._NAN,
6962
)
7063
],
7164
default=sge.func("ACOS", expr.expr),
@@ -78,7 +71,7 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
7871
ifs=[
7972
sge.If(
8073
this=sge.func("ABS", expr.expr) > sge.convert(1),
81-
true=_NAN,
74+
true=constants._NAN,
8275
)
8376
],
8477
default=sge.func("ASIN", expr.expr),
@@ -101,7 +94,7 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
10194
ifs=[
10295
sge.If(
10396
this=sge.func("ABS", expr.expr) > sge.convert(1),
104-
true=_NAN,
97+
true=constants._NAN,
10598
)
10699
],
107100
default=sge.func("ATANH", expr.expr),
@@ -177,7 +170,7 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
177170
ifs=[
178171
sge.If(
179172
this=sge.func("ABS", expr.expr) > sge.convert(709.78),
180-
true=_INF,
173+
true=constants._INF,
181174
)
182175
],
183176
default=sge.func("COSH", expr.expr),
@@ -222,8 +215,8 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
222215
return sge.Case(
223216
ifs=[
224217
sge.If(
225-
this=expr.expr > _FLOAT64_EXP_BOUND,
226-
true=_INF,
218+
this=expr.expr > constants._FLOAT64_EXP_BOUND,
219+
true=constants._INF,
227220
)
228221
],
229222
default=sge.func("EXP", expr.expr),
@@ -235,8 +228,8 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
235228
return sge.Case(
236229
ifs=[
237230
sge.If(
238-
this=expr.expr > _FLOAT64_EXP_BOUND,
239-
true=_INF,
231+
this=expr.expr > constants._FLOAT64_EXP_BOUND,
232+
true=constants._INF,
240233
)
241234
],
242235
default=sge.func("EXP", expr.expr),
@@ -403,7 +396,7 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
403396
ifs=[
404397
sge.If(
405398
this=expr.expr < sge.convert(0),
406-
true=_NAN,
399+
true=constants._NAN,
407400
)
408401
],
409402
default=sge.Ln(this=expr.expr),
@@ -416,7 +409,7 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
416409
ifs=[
417410
sge.If(
418411
this=expr.expr < sge.convert(0),
419-
true=_NAN,
412+
true=constants._NAN,
420413
)
421414
],
422415
default=sge.Log(this=expr.expr, expression=sge.convert(10)),
@@ -429,7 +422,7 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
429422
ifs=[
430423
sge.If(
431424
this=expr.expr < sge.convert(-1),
432-
true=_NAN,
425+
true=constants._NAN,
433426
)
434427
],
435428
default=sge.Ln(this=sge.convert(1) + expr.expr),
@@ -512,7 +505,7 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
512505
ifs=[
513506
sge.If(
514507
this=expr.expr < sge.convert(0),
515-
true=_NAN,
508+
true=constants._NAN,
516509
)
517510
],
518511
default=sge.Sqrt(this=expr.expr),
@@ -534,8 +527,8 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
534527
return sge.Case(
535528
ifs=[
536529
sge.If(
537-
this=sge.func("ABS", expr.expr) > _FLOAT64_EXP_BOUND,
538-
true=sge.func("SIGN", expr.expr) * _INF,
530+
this=sge.func("ABS", expr.expr) > constants._FLOAT64_EXP_BOUND,
531+
true=sge.func("SIGN", expr.expr) * constants._INF,
539532
)
540533
],
541534
default=sge.func("SINH", expr.expr),

tests/system/small/engines/test_numeric_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def test_engines_project_div_durations(
117117
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
118118

119119

120-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
120+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
121121
def test_engines_project_floordiv(
122122
scalars_array_value: array_value.ArrayValue,
123123
engine,
@@ -130,7 +130,7 @@ def test_engines_project_floordiv(
130130
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
131131

132132

133-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
133+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
134134
def test_engines_project_floordiv_durations(
135135
scalars_array_value: array_value.ArrayValue, engine
136136
):

0 commit comments

Comments
 (0)