diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index 3dc59a9fd..058e5ff6d 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -111,6 +111,7 @@ use datafusion_expr::{
     WindowFunctionDefinition,
 };
 use datafusion_functions_nested::array_has::ArrayHas;
+use datafusion_functions_nested::position::array_position_udf;
 use datafusion_physical_expr::expressions::{Literal, StatsType};
 use datafusion_physical_expr::window::WindowExpr;
 use datafusion_physical_expr::LexOrdering;
@@ -829,6 +830,26 @@ impl PhysicalPlanner {
                 ));
                 Ok(array_has_any_expr)
             }
+            ExprStruct::ArrayPosition(expr) => {
+                // `left` carries the array to search, `right` the element to look for.
+                let array_expr =
+                    self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
+                let element_expr =
+                    self.create_expr(expr.right.as_ref().unwrap(), Arc::clone(&input_schema))?;
+                let args = vec![array_expr, element_expr];
+                let array_position_expr = Arc::new(ScalarFunctionExpr::new(
+                    "array_position",
+                    array_position_udf(),
+                    args,
+                    DataType::UInt64,
+                ));
+                // The function is declared as UInt64 above; cast the result to Int64.
+                Ok(Arc::new(Cast::new(
+                    array_position_expr,
+                    DataType::Int64,
+                    SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
+                )))
+            }
             expr => Err(ExecutionError::GeneralError(format!(
                 "Not implemented: {:?}",
                 expr
diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto
index fd928fd8a..08f5450f2 100644
--- a/native/proto/src/proto/expr.proto
+++ b/native/proto/src/proto/expr.proto
@@ -89,6 +89,7 @@ message Expr {
     BinaryExpr array_intersect = 62;
     ArrayJoin array_join = 63;
     BinaryExpr arrays_overlap = 64;
+    BinaryExpr array_position = 66;
   }
 }
 
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index f4699af8d..c48db349c 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -2366,6 +2366,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
       case _: ArrayIntersect => convert(CometArrayIntersect)
       case _: ArrayJoin => convert(CometArrayJoin)
       case _: ArraysOverlap => convert(CometArraysOverlap)
+      case _: ArrayPosition => convert(CometArrayPosition)
       case _ =>
         withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*)
         None
diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
index db1679f22..26240ba04 100644
--- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
@@ -165,3 +165,23 @@ object CometArrayJoin extends CometExpressionSerde with IncompatExpr {
     }
   }
 }
+
+/**
+ * Serializes an `ArrayPosition` expression into the `array_position` field of the Expr proto.
+ */
+object CometArrayPosition extends CometExpressionSerde with IncompatExpr {
+
+  override def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    createBinaryExpr(
+      expr,
+      expr.children(0),
+      expr.children(1),
+      inputs,
+      binding,
+      (builder, binaryExpr) => builder.setArrayPosition(binaryExpr))
+  }
+
+}
diff --git a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
index df1fccb69..a3c975e9f 100644
--- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
@@ -292,4 +292,23 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
     }
   }
 
+  test("array_position") {
+    withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") {
+      Seq(true, false).foreach { dictionaryEnabled =>
+        withTempDir { dir =>
+          val path = new Path(dir.toURI.toString, "test.parquet")
+          makeParquetFileAllTypes(path, dictionaryEnabled, 10000)
+          spark.read.parquet(path.toString).createOrReplaceTempView("t1")
+          checkSparkAnswerAndOperator(
+            sql("SELECT array_position(array(_2, _3,_4), _2) from t1 where _2 is null"))
+          checkSparkAnswerAndOperator(
+            sql("SELECT array_position(array(_2, _3,_4), _3) from t1 where _3 is not null"))
+          // checkSparkAnswerAndOperator(sql(
+          //   "SELECT array_position(case when _2 = _3 THEN array(_2, _3,_4) ELSE null END, _3) from t1"))
+        }
+
+      }
+    }
+  }
+
 }