diff --git a/bigframes/core/compile/sqlglot/compiler.py b/bigframes/core/compile/sqlglot/compiler.py index b4dc6174be..8364e757a1 100644 --- a/bigframes/core/compile/sqlglot/compiler.py +++ b/bigframes/core/compile/sqlglot/compiler.py @@ -244,6 +244,28 @@ def compile_join( joins_nulls=node.joins_nulls, ) + @_compile_node.register + def compile_isin_join( + self, node: nodes.InNode, left: ir.SQLGlotIR, right: ir.SQLGlotIR + ) -> ir.SQLGlotIR: + conditions = ( + typed_expr.TypedExpr( + scalar_compiler.compile_scalar_expression(node.left_col), + node.left_col.output_type, + ), + typed_expr.TypedExpr( + scalar_compiler.compile_scalar_expression(node.right_col), + node.right_col.output_type, + ), + ) + + return left.isin_join( + right, + indicator_col=node.indicator_col.sql, + conditions=conditions, + joins_nulls=node.joins_nulls, + ) + @_compile_node.register def compile_concat( self, node: nodes.ConcatNode, *children: ir.SQLGlotIR diff --git a/bigframes/core/compile/sqlglot/sqlglot_ir.py b/bigframes/core/compile/sqlglot/sqlglot_ir.py index 1a00cd0a93..5f3f07dd3b 100644 --- a/bigframes/core/compile/sqlglot/sqlglot_ir.py +++ b/bigframes/core/compile/sqlglot/sqlglot_ir.py @@ -336,6 +336,67 @@ def join( return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen) + def isin_join( + self, + right: SQLGlotIR, + indicator_col: str, + conditions: tuple[typed_expr.TypedExpr, typed_expr.TypedExpr], + joins_nulls: bool = True, + ) -> SQLGlotIR: + """Joins the current query with another SQLGlotIR instance.""" + left_cte_name = sge.to_identifier( + next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted + ) + right_cte_name = sge.to_identifier( + next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted + ) + + left_select = _select_to_cte(self.expr, left_cte_name) + right_select = _select_to_cte(right.expr, right_cte_name) + + left_ctes = left_select.args.pop("with", []) + right_ctes = right_select.args.pop("with", []) + merged_ctes = [*left_ctes, *right_ctes] + + left_condition = typed_expr.TypedExpr( + sge.Column(this=conditions[0].expr, table=left_cte_name), + conditions[0].dtype, + ) + right_condition = typed_expr.TypedExpr( + sge.Column(this=conditions[1].expr, table=right_cte_name), + conditions[1].dtype, + ) + + new_column: sge.Expression + if joins_nulls: + new_column = sge.Exists( + this=sge.Select() + .select(sge.convert(1)) + .from_(sge.Table(this=right_cte_name)) + .where( + _join_condition(left_condition, right_condition, joins_nulls=True) + ) + ) + else: + new_column = sge.In( + this=left_condition.expr, + expressions=[right_condition.expr], + ) + + new_column = sge.Alias( + this=new_column, + alias=sge.to_identifier(indicator_col, quoted=self.quoted), + ) + + new_expr = ( + sge.Select() + .select(sge.Column(this=sge.Star(), table=left_cte_name), new_column) + .from_(sge.Table(this=left_cte_name)) + ) + new_expr.set("with", sge.With(expressions=merged_ctes)) + + return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen) + def explode( self, column_names: tuple[str, ...], diff --git a/tests/unit/core/compile/sqlglot/conftest.py b/tests/unit/core/compile/sqlglot/conftest.py index f65343fd66..3279b3a259 100644 --- a/tests/unit/core/compile/sqlglot/conftest.py +++ b/tests/unit/core/compile/sqlglot/conftest.py @@ -85,7 +85,7 @@ def scalar_types_table_schema() -> typing.Sequence[bigquery.SchemaField]: bigquery.SchemaField("numeric_col", "NUMERIC"), bigquery.SchemaField("float64_col", "FLOAT"), bigquery.SchemaField("rowindex", "INTEGER"), - bigquery.SchemaField("rowindex_2", "INTEGER"), + bigquery.SchemaField("rowindex_2", "INTEGER", mode="REQUIRED"), bigquery.SchemaField("string_col", "STRING"), bigquery.SchemaField("time_col", "TIME"), bigquery.SchemaField("timestamp_col", "TIMESTAMP"), diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_isin/test_compile_isin/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_isin/test_compile_isin/out.sql new file mode 100644 index 0000000000..d6b2a9f167 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_isin/test_compile_isin/out.sql @@ -0,0 +1,37 @@ +WITH `bfcte_1` AS ( + SELECT + `int64_col` AS `bfcol_0`, + `rowindex` AS `bfcol_1` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_2` AS ( + SELECT + `bfcol_1` AS `bfcol_2`, + `bfcol_0` AS `bfcol_3` + FROM `bfcte_1` +), `bfcte_0` AS ( + SELECT + `int64_too` AS `bfcol_4` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_3` AS ( + SELECT + `bfcol_4` + FROM `bfcte_0` + GROUP BY + `bfcol_4` +), `bfcte_4` AS ( + SELECT + `bfcte_2`.*, + EXISTS( + SELECT + 1 + FROM `bfcte_3` + WHERE + COALESCE(`bfcte_2`.`bfcol_3`, 0) = COALESCE(`bfcte_3`.`bfcol_4`, 0) + AND COALESCE(`bfcte_2`.`bfcol_3`, 1) = COALESCE(`bfcte_3`.`bfcol_4`, 1) + ) AS `bfcol_5` + FROM `bfcte_2` +) +SELECT + `bfcol_2` AS `rowindex`, + `bfcol_5` AS `int64_col` +FROM `bfcte_4` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/test_compile_isin.py b/tests/unit/core/compile/sqlglot/test_compile_isin.py new file mode 100644 index 0000000000..8b3e7f7291 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/test_compile_isin.py @@ -0,0 +1,31 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +import bigframes.pandas as bpd + +pytest.importorskip("pytest_snapshot") + + +def test_compile_isin(scalar_types_df: bpd.DataFrame, snapshot): + bf_isin = scalar_types_df["int64_col"].isin(scalar_types_df["int64_too"]).to_frame() + snapshot.assert_match(bf_isin.sql, "out.sql") + + +def test_compile_isin_not_nullable(scalar_types_df: bpd.DataFrame, snapshot): + bf_isin = ( + scalar_types_df["rowindex_2"].isin(scalar_types_df["rowindex_2"]).to_frame() + ) + snapshot.assert_match(bf_isin.sql, "out.sql")