From 015aeb646ee3cd6110ce3e35d2aaec8995ddd8a6 Mon Sep 17 00:00:00 2001 From: Eren Avsarogullari Date: Sun, 12 Jan 2025 16:50:34 -0800 Subject: [PATCH 1/2] Feat: Support array_intersect --- native/core/src/execution/planner.rs | 17 +++++++++++++++++ native/proto/src/proto/expr.proto | 1 + .../org/apache/comet/serde/QueryPlanSerde.scala | 6 ++++++ .../org/apache/comet/CometExpressionSuite.scala | 16 ++++++++++++++++ 4 files changed, 40 insertions(+) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index c43230f49..941e3d33a 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -67,6 +67,7 @@ use datafusion::{ use datafusion_comet_spark_expr::{create_comet_physical_fun, create_negate_expr}; use datafusion_functions_nested::concat::ArrayAppend; use datafusion_functions_nested::remove::array_remove_all_udf; +use datafusion_functions_nested::set_ops::array_intersect_udf; use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; use crate::execution::shuffle::CompressionCodec; @@ -765,6 +766,22 @@ impl PhysicalPlanner { Ok(Arc::new(case_expr)) } + ExprStruct::ArrayIntersect(expr) => { + let left_expr = + self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?; + let right_expr = + self.create_expr(expr.right.as_ref().unwrap(), Arc::clone(&input_schema))?; + let args = vec![Arc::clone(&left_expr), right_expr]; + let datafusion_array_intersect = array_intersect_udf(); + let return_type = left_expr.data_type(&input_schema)?; + let array_intersect_expr = Arc::new(ScalarFunctionExpr::new( + "array_intersect", + datafusion_array_intersect, + args, + return_type, + )); + Ok(array_intersect_expr) + } expr => Err(ExecutionError::GeneralError(format!( "Not implemented: {:?}", expr diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index 8e3bc60b0..0b7d24d9f 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -86,6 +86,7 @@ message Expr { ArrayInsert array_insert = 59; BinaryExpr array_contains = 60; BinaryExpr array_remove = 61; + BinaryExpr array_intersect = 62; } } diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 818648651..43e400d75 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -2284,6 +2284,12 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim expr.children(1), inputs, (builder, binaryExpr) => builder.setArrayAppend(binaryExpr)) + case _ if expr.prettyName == "array_intersect" => + createBinaryExpr( + expr.children(0), + expr.children(1), + inputs, + (builder, binaryExpr) => builder.setArrayIntersect(binaryExpr)) case _ => withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*) None diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 8c2759a38..7d63a313b 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -2545,4 +2545,20 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } } + + test("array_intersect") { + Seq(true, false).foreach { dictionaryEnabled => + withTempDir { dir => + val path = new Path(dir.toURI.toString, "test.parquet") + makeParquetFileAllTypes(path, dictionaryEnabled, 10000) + spark.read.parquet(path.toString).createOrReplaceTempView("t1") + checkSparkAnswerAndOperator( + sql("SELECT array_intersect(array(_2, _3, _4), array(_9, _10)) from t1")) + checkSparkAnswerAndOperator( + sql("SELECT array_intersect(array(_2 * -1), array(_9, _10)) from t1")) + checkSparkAnswerAndOperator(sql("SELECT array_intersect(array(_18), array(_19)) from t1")) + } + } + } + } From 7bcf6cb71e54d82be98e6805af36eeb34431e842 Mon Sep 17 00:00:00 2001 From: Eren Avsarogullari Date: Mon, 13 Jan 2025 19:29:22 -0800 Subject: [PATCH 2/2] Address review comment --- .../src/test/scala/org/apache/comet/CometExpressionSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 7d63a313b..e5356a2d6 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -2553,7 +2553,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { makeParquetFileAllTypes(path, dictionaryEnabled, 10000) spark.read.parquet(path.toString).createOrReplaceTempView("t1") checkSparkAnswerAndOperator( - sql("SELECT array_intersect(array(_2, _3, _4), array(_9, _10)) from t1")) + sql("SELECT array_intersect(array(_2, _3, _4), array(_3, _4)) from t1")) checkSparkAnswerAndOperator( sql("SELECT array_intersect(array(_2 * -1), array(_9, _10)) from t1")) checkSparkAnswerAndOperator(sql("SELECT array_intersect(array(_18), array(_19)) from t1"))