From 655c960349a1d201145b8ff048bd086d87e7fa76 Mon Sep 17 00:00:00 2001 From: Dharan Aditya Date: Fri, 31 Jan 2025 20:07:40 +0530 Subject: [PATCH 1/3] impl array_position Signed-off-by: Dharan Aditya --- native/core/src/execution/planner.rs | 15 +++++++++++++++ native/proto/src/proto/expr.proto | 1 + .../org/apache/comet/serde/QueryPlanSerde.scala | 1 + .../scala/org/apache/comet/serde/arrays.scala | 17 +++++++++++++++++ .../comet/CometArrayExpressionSuite.scala | 16 ++++++++++++++++ 5 files changed, 50 insertions(+) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 3dc59a9fd9..a50844e630 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -111,6 +111,7 @@ use datafusion_expr::{ WindowFunctionDefinition, }; use datafusion_functions_nested::array_has::ArrayHas; +use datafusion_functions_nested::position::array_position_udf; use datafusion_physical_expr::expressions::{Literal, StatsType}; use datafusion_physical_expr::window::WindowExpr; use datafusion_physical_expr::LexOrdering; @@ -829,6 +830,20 @@ impl PhysicalPlanner { )); Ok(array_has_any_expr) } + ExprStruct::ArrayPosition(expr) => { + let left_array_expr = + self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?; + let right_array_expr = + self.create_expr(expr.right.as_ref().unwrap(), Arc::clone(&input_schema))?; + let args = vec![left_array_expr, right_array_expr]; + let array_has_any_expr = Arc::new(ScalarFunctionExpr::new( + "array_position", + array_position_udf(), + args, + DataType::UInt64, + )); + Ok(array_has_any_expr) + } expr => Err(ExecutionError::GeneralError(format!( "Not implemented: {:?}", expr diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index fd928fd8a3..08f5450f23 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -89,6 +89,7 @@ message Expr { BinaryExpr array_intersect = 62; ArrayJoin array_join = 63; BinaryExpr arrays_overlap = 64; + BinaryExpr array_position = 66; } } diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index f4699af8de..c48db349c7 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -2366,6 +2366,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim case _: ArrayIntersect => convert(CometArrayIntersect) case _: ArrayJoin => convert(CometArrayJoin) case _: ArraysOverlap => convert(CometArraysOverlap) + case _: ArrayPosition => convert(CometArrayPosition) case _ => withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*) None diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala index db1679f22b..26240ba049 100644 --- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala +++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala @@ -165,3 +165,20 @@ object CometArrayJoin extends CometExpressionSerde with IncompatExpr { } } } + +object CometArrayPosition extends CometExpressionSerde with IncompatExpr { + + override def convert( + expr: Expression, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + createBinaryExpr( + expr, + expr.children(0), + expr.children(1), + inputs, + binding, + (builder, binaryExpr) => builder.setArraysOverlap(binaryExpr)) + } + +} diff --git a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala index df1fccb698..7f94e39ebc 100644 --- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala @@ -292,4 +292,20 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp } } + test("array_position") { + Seq(true, false).foreach { dictionaryEnabled => + withTempDir { dir => + val path = new Path(dir.toURI.toString, "test.parquet") + makeParquetFileAllTypes(path, dictionaryEnabled, 10000) + spark.read.parquet(path.toString).createOrReplaceTempView("t1") + checkSparkAnswerAndOperator( + sql("SELECT array_position(array(_2, _3,_4), _2) from t1 where _2 is null")) + checkSparkAnswerAndOperator( + sql("SELECT array_position(array(_2, _3,_4), _3) from t1 where _3 is not null")) +// checkSparkAnswerAndOperator(sql( +// "SELECT array_position(case when _2 = _3 THEN array(_2, _3,_4) ELSE null END, _3) from t1")) + } + } + } + } From 5fc4831bc59580b48e4957862166594e18fcf92c Mon Sep 17 00:00:00 2001 From: Dharan Aditya Date: Fri, 31 Jan 2025 20:26:49 +0530 Subject: [PATCH 2/3] fix cast Signed-off-by: Dharan Aditya --- native/core/src/execution/planner.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index a50844e630..058e5ff6df 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -842,7 +842,11 @@ impl PhysicalPlanner { args, DataType::UInt64, )); - Ok(array_has_any_expr) + Ok(Arc::new(Cast::new( + array_has_any_expr, + DataType::Int64, + SparkCastOptions::new_without_timezone(EvalMode::Legacy, false), + ))) } expr => Err(ExecutionError::GeneralError(format!( "Not implemented: {:?}", From 151dbd6e86c3c722edd090bae4949412a6d45888 Mon Sep 17 00:00:00 2001 From: Dharan Aditya Date: Fri, 31 Jan 2025 21:01:18 +0530 Subject: [PATCH 3/3] update ut Signed-off-by: Dharan Aditya --- .../comet/CometArrayExpressionSuite.scala | 25 +++++++++++-------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala index 7f94e39ebc..a3c975e9f6 100644 --- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala @@ -293,17 +293,20 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp } test("array_position") { - Seq(true, false).foreach { dictionaryEnabled => - withTempDir { dir => - val path = new Path(dir.toURI.toString, "test.parquet") - makeParquetFileAllTypes(path, dictionaryEnabled, 10000) - spark.read.parquet(path.toString).createOrReplaceTempView("t1") - checkSparkAnswerAndOperator( - sql("SELECT array_position(array(_2, _3,_4), _2) from t1 where _2 is null")) - checkSparkAnswerAndOperator( - sql("SELECT array_position(array(_2, _3,_4), _3) from t1 where _3 is not null")) -// checkSparkAnswerAndOperator(sql( -// "SELECT array_position(case when _2 = _3 THEN array(_2, _3,_4) ELSE null END, _3) from t1")) + withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") { + Seq(true, false).foreach { dictionaryEnabled => + withTempDir { dir => + val path = new Path(dir.toURI.toString, "test.parquet") + makeParquetFileAllTypes(path, dictionaryEnabled, 10000) + spark.read.parquet(path.toString).createOrReplaceTempView("t1") + checkSparkAnswerAndOperator( + sql("SELECT array_position(array(_2, _3,_4), _2) from t1 where _2 is null")) + checkSparkAnswerAndOperator( + sql("SELECT array_position(array(_2, _3,_4), _3) from t1 where _3 is not null")) + // checkSparkAnswerAndOperator(sql( + // "SELECT array_position(case when _2 = _3 THEN array(_2, _3,_4) ELSE null END, _3) from t1")) + } + } } }