Feat: support array_compact function #1321

Open · wants to merge 15 commits into base: main
19 changes: 19 additions & 0 deletions native/core/src/execution/planner.rs
@@ -818,6 +818,25 @@ impl PhysicalPlanner {
                ));
                Ok(array_join_expr)
            }
            ExprStruct::ArrayCompact(expr) => {
                let src_array_expr =
                    self.create_expr(expr.array_expr.as_ref().unwrap(), Arc::clone(&input_schema))?;
                let datatype = to_arrow_datatype(expr.item_datatype.as_ref().unwrap());

                let null_literal_expr: Arc<dyn PhysicalExpr> =
                    Arc::new(Literal::new(ScalarValue::Null.cast_to(&datatype)?));
                let args = vec![Arc::clone(&src_array_expr), null_literal_expr];
                let return_type = src_array_expr.data_type(&input_schema)?;

                let array_compact_expr = Arc::new(ScalarFunctionExpr::new(
                    "array_compact",
                    array_remove_all_udf(),
                    args,
                    return_type,
                ));

                Ok(array_compact_expr)
            }
            expr => Err(ExecutionError::GeneralError(format!(
                "Not implemented: {:?}",
                expr
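For context: the native side implements array_compact by reusing DataFusion's array_remove_all kernel, passing the source array plus a NULL literal cast to the array's element type, so null entries are removed while order and duplicates are preserved. The expected Spark-level semantics, as an illustrative spark-shell snippet (Spark 3.4+, not part of this diff):

// array_compact drops NULL elements but keeps ordering and duplicates.
spark.sql("SELECT array_compact(array(1, NULL, 2, NULL, 2)) AS a").show()
// +---------+
// |        a|
// +---------+
// |[1, 2, 2]|
// +---------+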
6 changes: 6 additions & 0 deletions native/proto/src/proto/expr.proto
@@ -88,6 +88,7 @@ message Expr {
    BinaryExpr array_remove = 61;
    BinaryExpr array_intersect = 62;
    ArrayJoin array_join = 63;
    ArrayCompact array_compact = 64;
  }
}

@@ -422,6 +423,11 @@ message ArrayJoin {
  Expr null_replacement_expr = 3;
}

message ArrayCompact {
  Expr array_expr = 1;
  DataType item_datatype = 2;
}

message DataType {
  enum DataTypeId {
    BOOL = 0;
16 changes: 16 additions & 0 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -2428,6 +2428,22 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
withInfo(expr, "unsupported arguments for ArrayJoin", exprs: _*)
None
}
case expr @ ArrayFilter(child, _) if ArrayCompact(child).replacement.sql == expr.sql =>
Contributor: Would you add a flag to enable and disable this?

Member: Because the PR only contains basic tests, could you add a check to enable this expression only if CometConf.COMET_CAST_ALLOW_INCOMPATIBLE is enabled? We can remove this check in a future PR that adds comprehensive tests and demonstrates that we have Spark-compatible behavior for all supported data types.

@kazantsev-maksim (Author, Jan 30, 2025): Supported the latest interface.

        val elementType = serializeDataType(child.dataType.asInstanceOf[ArrayType].elementType)
        val srcExprProto = exprToProto(child, inputs, binding)
        if (elementType.isDefined && srcExprProto.isDefined) {
          val arrayCompactBuilder = ExprOuterClass.ArrayCompact
            .newBuilder()
            .setArrayExpr(srcExprProto.get)
            .setItemDatatype(elementType.get)
          Some(
            ExprOuterClass.Expr
              .newBuilder()
              .setArrayCompact(arrayCompactBuilder)
              .build())
        } else {
          None
        }
      case _ =>
        withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*)
        None
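Why the guard matches ArrayFilter rather than ArrayCompact: ArrayCompact is RuntimeReplaceable in Spark, so by the time the plan reaches Comet it has already been rewritten to an ArrayFilter with an `x -> x IS NOT NULL` lambda; comparing the SQL text of ArrayCompact(child).replacement against the incoming expression picks out exactly those ArrayFilter nodes that originated from array_compact. A hypothetical spark-shell sketch (Spark 3.4+ catalyst internals, not part of this diff):

import org.apache.spark.sql.catalyst.expressions.{ArrayCompact, CreateArray, Literal}
import org.apache.spark.sql.types.IntegerType

val arr = CreateArray(Seq(Literal(1), Literal(null, IntegerType)))
// ArrayCompact.replacement is the ArrayFilter form that Spark actually plans.
println(ArrayCompact(arr).replacement.sql)
// prints something like:
// filter(array(1, NULL), lambdafunction((namedlambdavariable() IS NOT NULL), namedlambdavariable()))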
16 changes: 16 additions & 0 deletions spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -2701,4 +2701,20 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
      }
    }
  }

  test("array_compact") {
    assume(isSpark34Plus)
    Seq(true, false).foreach { dictionaryEnabled =>
      withTempDir { dir =>
        val path = new Path(dir.toURI.toString, "test.parquet")
        makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, n = 10000)
        spark.read.parquet(path.toString).createOrReplaceTempView("t1")

        checkSparkAnswerAndOperator(
          sql("SELECT array_compact(array(_2)) FROM t1 WHERE _2 IS NULL"))
        checkSparkAnswerAndOperator(
          sql("SELECT array_compact(array(_2)) FROM t1 WHERE _2 IS NOT NULL"))
      }
    }
  }
}
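If the CometConf.COMET_CAST_ALLOW_INCOMPATIBLE gate requested in review is added, the test would need to opt in explicitly. A hedged sketch inside this suite, assuming the gate lands as requested:

withSQLConf(CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") {
  checkSparkAnswerAndOperator(
    sql("SELECT array_compact(array(_2)) FROM t1 WHERE _2 IS NULL"))
}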