impl array_union
Signed-off-by: Dharan Aditya <[email protected]>
dharanad committed Jan 31, 2025
1 parent 996362e commit 059c037
Showing 5 changed files with 52 additions and 1 deletion.
17 changes: 16 additions & 1 deletion native/core/src/execution/planner.rs
@@ -67,7 +67,7 @@ use datafusion_comet_spark_expr::{create_comet_physical_fun, create_negate_expr}
 use datafusion_functions_nested::array_has::array_has_any_udf;
 use datafusion_functions_nested::concat::ArrayAppend;
 use datafusion_functions_nested::remove::array_remove_all_udf;
-use datafusion_functions_nested::set_ops::array_intersect_udf;
+use datafusion_functions_nested::set_ops::{array_intersect_udf, array_union_udf};
 use datafusion_functions_nested::string::array_to_string_udf;
 use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr};

@@ -829,6 +829,21 @@ impl PhysicalPlanner {
                 ));
                 Ok(array_has_any_expr)
             }
+            ExprStruct::ArrayUnion(expr) => {
+                let left_array_expr =
+                    self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
+                let right_array_expr =
+                    self.create_expr(expr.right.as_ref().unwrap(), Arc::clone(&input_schema))?;
+                let array_union_udf = array_union_udf();
+                let return_type = right_array_expr.data_type(&input_schema)?;
+                let args = vec![Arc::clone(&left_array_expr), right_array_expr];
+                Ok(Arc::new(ScalarFunctionExpr::new(
+                    "array_union",
+                    array_union_udf,
+                    args,
+                    return_type,
+                )))
+            }
             expr => Err(ExecutionError::GeneralError(format!(
                 "Not implemented: {:?}",
                 expr
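
For context, the new planner arm above resolves both child expressions and wraps DataFusion's built-in array_union UDF in a ScalarFunctionExpr, taking the return type from the right-hand child. The behavior it has to match is Spark's own array_union, which concatenates two arrays and drops duplicate elements. Below is a minimal, self-contained Spark sketch of those semantics; the object name, column names, and values are illustrative only and not part of the change.

// Minimal sketch (not part of the commit) of the Spark semantics the planner
// arm above must reproduce: array_union concatenates two arrays and removes
// duplicates, keeping the order of first occurrence.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{array_union, col}

object ArrayUnionSemanticsDemo extends App {
  val spark = SparkSession.builder().master("local[1]").appName("array-union-demo").getOrCreate()
  import spark.implicits._

  val df = Seq((Seq(1, 2, 2, 3), Seq(3, 4))).toDF("a", "b")
  // Expected result column: [1, 2, 3, 4]
  df.select(array_union(col("a"), col("b")).as("u")).show(false)

  spark.stop()
}
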
1 change: 1 addition & 0 deletions native/proto/src/proto/expr.proto
@@ -89,6 +89,7 @@ message Expr {
     BinaryExpr array_intersect = 62;
     ArrayJoin array_join = 63;
     BinaryExpr arrays_overlap = 64;
+    BinaryExpr array_union = 65;
   }
 }

@@ -2366,6 +2366,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
       case _: ArrayIntersect => convert(CometArrayIntersect)
       case _: ArrayJoin => convert(CometArrayJoin)
       case _: ArraysOverlap => convert(CometArraysOverlap)
+      case _: ArrayUnion => convert(CometArrayUnion)
       case _ =>
         withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*)
         None
16 changes: 16 additions & 0 deletions spark/src/main/scala/org/apache/comet/serde/arrays.scala
@@ -165,3 +165,19 @@ object CometArrayJoin extends CometExpressionSerde with IncompatExpr {
     }
   }
 }
+
+object CometArrayUnion extends CometExpressionSerde with IncompatExpr {
+
+  override def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    createBinaryExpr(
+      expr,
+      expr.children(0),
+      expr.children(1),
+      inputs,
+      binding,
+      (builder, binaryExpr) => builder.setArrayUnion(binaryExpr))
+  }
+}
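
For reference, CometArrayUnion delegates to the existing createBinaryExpr helper, passing a lambda that selects the new array_union oneof case. A rough, hypothetical sketch of the message this ends up producing is shown below; it assumes the generated protobuf classes live under org.apache.comet.serde and that BinaryExpr carries left and right child expressions, as the Rust planner reads them.

// Rough, hypothetical sketch (not part of the commit) of the protobuf message
// CometArrayUnion produces via createBinaryExpr. Package and field names of the
// generated classes are assumptions based on how the planner reads the message.
object ArrayUnionProtoSketch {
  import org.apache.comet.serde.ExprOuterClass // assumed package for the generated proto classes

  def buildArrayUnionExpr(
      leftChild: ExprOuterClass.Expr,
      rightChild: ExprOuterClass.Expr): ExprOuterClass.Expr = {
    val binary = ExprOuterClass.BinaryExpr
      .newBuilder()
      .setLeft(leftChild)
      .setRight(rightChild)
      .build()
    // Selects the new oneof case added in expr.proto (array_union = 65),
    // mirroring the lambda passed to createBinaryExpr.
    ExprOuterClass.Expr.newBuilder().setArrayUnion(binary).build()
  }
}
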
@@ -292,4 +292,22 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper
     }
   }

+  test("array_union") {
+    assume(isSpark34Plus)
+    withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") {
+      Seq(true, false).foreach { dictionaryEnabled =>
+        withTempDir { dir =>
+          val path = new Path(dir.toURI.toString, "test.parquet")
+          makeParquetFileAllTypes(path, dictionaryEnabled, 10000)
+          spark.read.parquet(path.toString).createOrReplaceTempView("t1")
+          checkSparkAnswerAndOperator(
+            sql("SELECT array_union(array(_2, _3, _4), array(_3, _4)) from t1"))
+          checkSparkAnswerAndOperator(
+            sql("SELECT array_union(array(_2 * -1), array(_9, _10)) from t1"))
+          checkSparkAnswerAndOperator(sql("SELECT array_union(array(_18), array(_19)) from t1"))
+        }
+      }
+    }
+  }
+
 }
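
Because CometArrayUnion is marked IncompatExpr, the test above enables the allow-incompatible flag via withSQLConf before checking that the Comet operator is used. A hypothetical end-to-end sketch of the same setup outside the test harness follows; it assumes a local SparkSession that already has the Comet plugin on the classpath and enabled.

// Hypothetical usage sketch (not part of the commit): offloading array_union is
// only attempted when the allow-incompatible flag is set, as the test does via
// withSQLConf. Assumes a SparkSession configured with the Comet plugin.
import org.apache.comet.CometConf
import org.apache.spark.sql.SparkSession

object ArrayUnionOffloadDemo extends App {
  val spark = SparkSession.builder().master("local[1]").appName("array-union-offload").getOrCreate()
  spark.conf.set(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key, "true")
  spark.sql("SELECT array_union(array(1, 2, 3), array(3, 4)) AS u").show(false)
  spark.stop()
}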
