diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs
index 6f41bf0ad..d80e22f1b 100644
--- a/native/core/src/execution/datafusion/planner.rs
+++ b/native/core/src/execution/datafusion/planner.rs
@@ -82,6 +82,7 @@ use datafusion::{
     },
     prelude::SessionContext,
 };
+use datafusion_functions_nested::concat::ArrayAppend;
 use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr};
 
 use datafusion_comet_proto::{
@@ -107,7 +108,8 @@ use datafusion_common::{
 };
 use datafusion_expr::expr::find_df_window_func;
 use datafusion_expr::{
-    AggregateUDF, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
+    AggregateUDF, ScalarUDF, WindowFrame, WindowFrameBound, WindowFrameUnits,
+    WindowFunctionDefinition,
 };
 use datafusion_physical_expr::expressions::{Literal, StatsType};
 use datafusion_physical_expr::window::WindowExpr;
@@ -691,6 +693,33 @@ impl PhysicalPlanner {
                     expr.ordinal as usize,
                 )))
             }
+            ExprStruct::ArrayAppend(expr) => {
+                let left =
+                    self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
+                let right =
+                    self.create_expr(expr.right.as_ref().unwrap(), Arc::clone(&input_schema))?;
+                let return_type = left.data_type(&input_schema)?;
+                let args = vec![Arc::clone(&left), right];
+                let datafusion_array_append =
+                    Arc::new(ScalarUDF::new_from_impl(ArrayAppend::new()));
+                let array_append_expr: Arc<dyn PhysicalExpr> = Arc::new(ScalarFunctionExpr::new(
+                    "array_append",
+                    datafusion_array_append,
+                    args,
+                    return_type,
+                ));
+
+                let is_null_expr: Arc<dyn PhysicalExpr> = Arc::new(IsNullExpr::new(left));
+                let null_literal_expr: Arc<dyn PhysicalExpr> =
+                    Arc::new(Literal::new(ScalarValue::Null));
+
+                let case_expr = CaseExpr::try_new(
+                    None,
+                    vec![(is_null_expr, null_literal_expr)],
+                    Some(array_append_expr),
+                )?;
+                Ok(Arc::new(case_expr))
+            }
             expr => Err(ExecutionError::GeneralError(format!(
                 "Not implemented: {:?}",
                 expr
diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto
index 3a8193f4a..220f5e521 100644
--- a/native/proto/src/proto/expr.proto
+++ b/native/proto/src/proto/expr.proto
@@ -82,6 +82,7 @@ message Expr {
     ToJson to_json = 55;
     ListExtract list_extract = 56;
     GetArrayStructFields get_array_struct_fields = 57;
+    BinaryExpr array_append = 58;
   }
 }
 
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 2a86c5c36..4a130ad0d 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -2237,7 +2237,12 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
             withInfo(expr, "unsupported arguments for GetArrayStructFields", child)
             None
         }
-
+      case _ if expr.prettyName == "array_append" =>
+        createBinaryExpr(
+          expr.children(0),
+          expr.children(1),
+          inputs,
+          (builder, binaryExpr) => builder.setArrayAppend(binaryExpr))
       case _ =>
         withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*)
         None
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 0d00867d1..5079f1910 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -2313,4 +2313,28 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       }
     }
   }
+
+  test("array_append") {
+    // array_append was added in Spark 3.4; in Spark 4.0 it is rewritten to ArrayInsert
+    assume(isSpark34Plus && !isSpark40Plus)
+    Seq(true, false).foreach { dictionaryEnabled =>
+      withTempDir { dir =>
+        val path = new Path(dir.toURI.toString, "test.parquet")
+        makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000)
+        spark.read.parquet(path.toString).createOrReplaceTempView("t1")
+        checkSparkAnswerAndOperator(
+          spark.sql("SELECT array_append(array(_1), false) FROM t1"))
+        checkSparkAnswerAndOperator(
+          spark.sql("SELECT array_append(array(_2, _3, _4), 4) FROM t1"))
+        checkSparkAnswerAndOperator(
+          spark.sql("SELECT array_append(array(_2, _3, _4), null) FROM t1"))
+        checkSparkAnswerAndOperator(
+          spark.sql("SELECT array_append(array(_6, _7), CAST(6.5 AS DOUBLE)) FROM t1"))
+        checkSparkAnswerAndOperator(spark.sql("SELECT array_append(array(_8), 'test') FROM t1"))
+        checkSparkAnswerAndOperator(spark.sql("SELECT array_append(array(_19), _19) FROM t1"))
+        checkSparkAnswerAndOperator(
+          spark.sql("SELECT array_append((CASE WHEN _2 = _3 THEN array(_4) END), _4) FROM t1"))
+      }
+    }
+  }
 }
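
Note on the `CaseExpr` wrapper in the `planner.rs` hunk above: Spark defines `array_append` to return NULL when the array argument itself is NULL, so the planner guards the DataFusion `array_append` UDF with the equivalent of `CASE WHEN arr IS NULL THEN NULL ELSE array_append(arr, elem) END` rather than calling the UDF directly. A minimal sketch of the user-visible semantics this preserves, assuming a Comet-enabled `SparkSession` bound to `spark` (the session setup is illustrative, not part of this patch):

```scala
// Illustrative spark-shell snippet; assumes a Comet-enabled SparkSession `spark`.

// Appending to a non-NULL array behaves as expected.
spark.sql("SELECT array_append(array(1, 2), 3)").show()
// [1, 2, 3]

// A NULL array argument yields NULL under Spark semantics; the IS NULL guard
// in planner.rs returns this NULL before DataFusion's array_append is invoked.
spark.sql("SELECT array_append(CAST(NULL AS ARRAY<INT>), 3)").show()
// NULL

// Appending a NULL element to a non-NULL array keeps the array and appends NULL.
spark.sql("SELECT array_append(array(1, 2), NULL)").show()
// [1, 2, NULL]
```

The final test case in `CometExpressionSuite` exercises exactly this guard, since its `CASE WHEN _2 = _3 THEN array(_4) END` argument produces a NULL array whenever the condition is false.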