Feat: support array_compact function #1321

Status: Open · wants to merge 15 commits into base: main
19 changes: 19 additions & 0 deletions native/core/src/execution/planner.rs
@@ -834,6 +834,25 @@ impl PhysicalPlanner {
                ));
                Ok(array_has_any_expr)
            }
            ExprStruct::ArrayCompact(expr) => {
                let src_array_expr =
                    self.create_expr(expr.array_expr.as_ref().unwrap(), Arc::clone(&input_schema))?;
                let datatype = to_arrow_datatype(expr.item_datatype.as_ref().unwrap());

                // array_compact(arr) is expressed as array_remove_all(arr, null):
                // a null literal cast to the element type is removed from the array.
                let null_literal_expr: Arc<dyn PhysicalExpr> =
                    Arc::new(Literal::new(ScalarValue::Null.cast_to(&datatype)?));
                let args = vec![Arc::clone(&src_array_expr), null_literal_expr];
                let return_type = src_array_expr.data_type(&input_schema)?;

                let array_compact_expr = Arc::new(ScalarFunctionExpr::new(
                    "array_compact",
                    array_remove_all_udf(),
                    args,
                    return_type,
                ));

                Ok(array_compact_expr)
            }
            expr => Err(ExecutionError::GeneralError(format!(
                "Not implemented: {:?}",
                expr
            ))),
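
For context on the rewrite above: Spark (3.4+) treats array_compact as a runtime-replaceable alias for filtering out nulls (see the ArrayFilter match in QueryPlanSerde below), and the native side reuses DataFusion's array_remove_all with a null literal cast to the element type. A minimal Scala illustration of the intended semantics, with assumed example values:

// Illustrative only; both forms should return [1, 2].
spark.sql("SELECT array_compact(array(1, null, 2))").show(false)
spark.sql("SELECT filter(array(1, null, 2), x -> x IS NOT NULL)").show(false)
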
6 changes: 6 additions & 0 deletions native/proto/src/proto/expr.proto
@@ -89,6 +89,7 @@ message Expr {
    BinaryExpr array_intersect = 62;
    ArrayJoin array_join = 63;
    BinaryExpr arrays_overlap = 64;
    ArrayCompact array_compact = 65;
  }
}

@@ -423,6 +424,11 @@ message ArrayJoin {
  Expr null_replacement_expr = 3;
}

message ArrayCompact {
  // The source array expression.
  Expr array_expr = 1;
  // Element type of the array, used natively to build a typed null literal.
  DataType item_datatype = 2;
}

message DataType {
  enum DataTypeId {
    BOOL = 0;
spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -2366,6 +2366,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
      case _: ArrayIntersect => convert(CometArrayIntersect)
      case _: ArrayJoin => convert(CometArrayJoin)
      case _: ArraysOverlap => convert(CometArraysOverlap)
      case expr @ ArrayFilter(child, _) if ArrayCompact(child).replacement.sql == expr.sql =>
        convert(CometArrayCompact)

Review thread on the ArrayFilter case above (see the gating sketch after this hunk):

Contributor: Would you add a flag to enable and disable?

Member: Because the PR only contains basic tests, could you add a check to enable this expression only if CometConf.COMET_CAST_ALLOW_INCOMPATIBLE is enabled? We can remove this check in a future PR that adds comprehensive tests and demonstrates that we have Spark-compatible behavior for all supported data types.

Author (@kazantsev-maksim, Jan 30, 2025): Supported the latest interface.
      case _ =>
        withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*)
        None
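
For readers unfamiliar with the IncompatExpr mechanism the author adopted in place of a dedicated flag: a hypothetical sketch of the opt-in gate (the real check lives elsewhere in Comet's conversion path and may differ; CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE is the config the test below relies on):

// Hypothetical sketch: serde objects that mix in IncompatExpr are converted
// only when the user has opted in via configuration.
def shouldConvert(serde: CometExpressionSerde): Boolean =
  !serde.isInstanceOf[IncompatExpr] || CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.get()
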
27 changes: 26 additions & 1 deletion spark/src/main/scala/org/apache/comet/serde/arrays.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, ArrayRemove, Attrib
import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, DecimalType, StructType}

import org.apache.comet.CometSparkSessionExtensions.withInfo
-import org.apache.comet.serde.QueryPlanSerde.{createBinaryExpr, exprToProto}
+import org.apache.comet.serde.QueryPlanSerde.{createBinaryExpr, exprToProto, serializeDataType}
import org.apache.comet.shims.CometExprShim

object CometArrayRemove extends CometExpressionSerde with CometExprShim {
@@ -126,6 +126,31 @@ object CometArraysOverlap extends CometExpressionSerde with IncompatExpr {
  }
}

// Serializes Spark's array_compact (matched in QueryPlanSerde via its
// ArrayFilter replacement) into the ArrayCompact protobuf message.
object CometArrayCompact extends CometExpressionSerde with IncompatExpr {
  override def convert(
      expr: Expression,
      inputs: Seq[Attribute],
      binding: Boolean): Option[ExprOuterClass.Expr] = {
    val child = expr.children.head
    val elementType = serializeDataType(child.dataType.asInstanceOf[ArrayType].elementType)
    val srcExprProto = exprToProto(child, inputs, binding)
    if (elementType.isDefined && srcExprProto.isDefined) {
      val arrayCompactBuilder = ExprOuterClass.ArrayCompact
        .newBuilder()
        .setArrayExpr(srcExprProto.get)
        .setItemDatatype(elementType.get)
      Some(
        ExprOuterClass.Expr
          .newBuilder()
          .setArrayCompact(arrayCompactBuilder)
          .build())
    } else {
      withInfo(expr, "unsupported arguments for ArrayCompact", expr.children: _*)
      None
    }
  }
}

object CometArrayJoin extends CometExpressionSerde with IncompatExpr {
  override def convert(
      expr: Expression,
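
A usage sketch for exercising the new serde end to end (session setup assumed; the config key is referenced through CometConf rather than hard-coded, since the literal key name is not shown in this PR):

// Assumed: a SparkSession with the Comet extension enabled.
spark.conf.set(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key, "true")
spark.sql("SELECT array_compact(array(id, null)) FROM range(3)").show(false)
// Expected: [0], [1], [2] -- the null element is dropped from each array.
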
spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
@@ -292,4 +292,24 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
    }
  }

  test("array_compact") {
    assume(isSpark34Plus)
    withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") {
      Seq(true, false).foreach { dictionaryEnabled =>
        withTempDir { dir =>
          val path = new Path(dir.toURI.toString, "test.parquet")
          makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, n = 10000)
          spark.read.parquet(path.toString).createOrReplaceTempView("t1")

          checkSparkAnswerAndOperator(
            sql("SELECT array_compact(array(_2)) FROM t1 WHERE _2 IS NULL"))
          checkSparkAnswerAndOperator(
            sql("SELECT array_compact(array(_2)) FROM t1 WHERE _2 IS NOT NULL"))
          checkSparkAnswerAndOperator(
            sql("SELECT array_compact(array(_2, _3, null)) FROM t1 WHERE _2 IS NOT NULL"))
        }
      }
    }
  }

}
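
For a quick manual check of the edge case the first test query covers (illustrative, assuming a Spark 3.4+ session):

spark.sql("SELECT array_compact(array(CAST(NULL AS INT)))").show(false)
// Expected: [] -- an all-null array compacts to an empty array, not NULL.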