diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index 294922f2f..c367a5235 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -108,6 +108,7 @@ use jni::objects::GlobalRef;
 use num::{BigInt, ToPrimitive};
 use std::cmp::max;
 use std::{collections::HashMap, sync::Arc};
+use datafusion::functions_aggregate::count::count_udaf;
 
 // For clippy error on type_complexity.
 type PhyAggResult = Result<Vec<AggregateFunctionExpr>, ExecutionError>;
@@ -1521,7 +1522,7 @@ impl PhysicalPlanner {
                     .schema(schema)
                     .alias("count")
                     .with_ignore_nulls(false)
-                    .with_distinct(false)
+                    .with_distinct(spark_expr.distinct)
                     .build()
                     .map_err(|e| ExecutionError::DataFusionError(e.to_string()))
             }
@@ -1534,7 +1535,7 @@ impl PhysicalPlanner {
                     .schema(schema)
                     .alias("min")
                     .with_ignore_nulls(false)
-                    .with_distinct(false)
+                    .with_distinct(spark_expr.distinct)
                     .build()
                     .map_err(|e| ExecutionError::DataFusionError(e.to_string()))
             }
@@ -1547,7 +1548,7 @@ impl PhysicalPlanner {
                     .schema(schema)
                     .alias("max")
                     .with_ignore_nulls(false)
-                    .with_distinct(false)
+                    .with_distinct(spark_expr.distinct)
                     .build()
                     .map_err(|e| ExecutionError::DataFusionError(e.to_string()))
             }
@@ -1572,7 +1573,7 @@ impl PhysicalPlanner {
                     .schema(schema)
                     .alias("sum")
                     .with_ignore_nulls(false)
-                    .with_distinct(false)
+                    .with_distinct(spark_expr.distinct)
                     .build()
                     .map_err(|e| e.into())
             }
@@ -1600,7 +1601,7 @@ impl PhysicalPlanner {
                     .schema(schema)
                     .alias("avg")
                     .with_ignore_nulls(false)
-                    .with_distinct(false)
+                    .with_distinct(spark_expr.distinct)
                     .build()
                     .map_err(|e| e.into())
             }
@@ -1612,7 +1613,7 @@ impl PhysicalPlanner {
                     .schema(schema)
                     .alias("first")
                     .with_ignore_nulls(false)
-                    .with_distinct(false)
+                    .with_distinct(spark_expr.distinct)
                     .build()
                     .map_err(|e| e.into())
             }
@@ -1624,7 +1625,7 @@ impl PhysicalPlanner {
                     .schema(schema)
                     .alias("last")
                     .with_ignore_nulls(false)
-                    .with_distinct(false)
+                    .with_distinct(spark_expr.distinct)
                     .build()
                     .map_err(|e| e.into())
             }
@@ -1635,7 +1636,7 @@ impl PhysicalPlanner {
                     .schema(schema)
                     .alias("bit_and")
                     .with_ignore_nulls(false)
-                    .with_distinct(false)
+                    .with_distinct(spark_expr.distinct)
                     .build()
                     .map_err(|e| e.into())
             }
@@ -1646,7 +1647,7 @@ impl PhysicalPlanner {
                     .schema(schema)
                     .alias("bit_or")
                     .with_ignore_nulls(false)
-                    .with_distinct(false)
+                    .with_distinct(spark_expr.distinct)
                     .build()
                     .map_err(|e| e.into())
             }
@@ -1657,7 +1658,7 @@ impl PhysicalPlanner {
                     .schema(schema)
                     .alias("bit_xor")
                     .with_ignore_nulls(false)
-                    .with_distinct(false)
+                    .with_distinct(spark_expr.distinct)
                     .build()
                     .map_err(|e| e.into())
             }
diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto
index e76ecdccf..0c9c9a0a4 100644
--- a/native/proto/src/proto/expr.proto
+++ b/native/proto/src/proto/expr.proto
@@ -89,6 +89,7 @@ message Expr {
 }
 
 message AggExpr {
+  bool distinct = 1;
   oneof expr_struct {
     Count count = 2;
     Sum sum = 3;
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 7a69e6309..abb6df917 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -386,6 +386,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
           ExprOuterClass.AggExpr
             .newBuilder()
             .setSum(sumBuilder)
+            .setDistinct(aggExpr.isDistinct)
             .build())
       } else {
         if (dataType.isEmpty) {
@@ -432,6 +433,11 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
             None
           }
       case Count(children) =>
+        if (children.length > 1 && aggExpr.isDistinct) {
+          withInfo(aggExpr, "no support for count distinct with multiple expressions")
+          return None
+        }
+
         val exprChildren = children.map(exprToProto(_, inputs, binding))
 
         if (exprChildren.forall(_.isDefined)) {
@@ -442,6 +448,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
           ExprOuterClass.AggExpr
             .newBuilder()
             .setCount(countBuilder)
+            .setDistinct(aggExpr.isDistinct)
             .build())
       } else {
         withInfo(aggExpr, children: _*)
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
index 8170230bc..886454990 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
@@ -108,7 +108,7 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
         val cometShuffles = collect(df2.queryExecution.executedPlan) {
           case _: CometShuffleExchangeExec => true
         }
-        if (shuffleMode == "jvm") {
+        if (shuffleMode == "jvm" || shuffleMode == "auto") {
           assert(cometShuffles.length == 1)
         } else {
           // we fall back to Spark for shuffle because we do not support
@@ -608,8 +608,8 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       withView("v") {
         sql("CREATE TEMP VIEW v AS SELECT _1, _2 FROM tbl ORDER BY _1")
         checkSparkAnswer(
-          "SELECT _2, SUM(_1), SUM(DISTINCT _1), MIN(_1), MAX(_1), COUNT(_1)," +
-            " COUNT(DISTINCT _1), AVG(_1), FIRST(_1), LAST(_1) FROM v GROUP BY _2")
+          "SELECT _2, SUM(_1), SUM(DISTINCT _2), MIN(_1), MAX(_1), COUNT(_1)," +
+            " COUNT(DISTINCT _2), AVG(_1), FIRST(_1), LAST(_1) FROM v GROUP BY _2 ORDER BY _2")
       }
     }
   }
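
Note on the native-side mechanics (reviewer context, not part of the diff): the planner change works because DataFusion's `AggregateExprBuilder` plans the distinct variant of an aggregate whenever `with_distinct(true)` is set, so Comet only needs to forward Spark's `isDistinct` flag through the new `bool distinct = 1` proto field into `spark_expr.distinct`. The sketch below is a minimal standalone illustration of that builder API, assuming a recent DataFusion release; the schema, column name, and `main` wrapper are invented for the example and are not Comet code.

```rust
// Hedged sketch: plan COUNT(DISTINCT a) via DataFusion's AggregateExprBuilder,
// mirroring what planner.rs now does with `spark_expr.distinct`.
// Module paths assume a recent DataFusion release.
use std::sync::Arc;

use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::error::Result;
use datafusion::functions_aggregate::count::count_udaf;
use datafusion::physical_expr::aggregate::AggregateExprBuilder;
use datafusion::physical_expr::expressions::col;

fn main() -> Result<()> {
    // Hypothetical single-column schema standing in for the input batch schema.
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)]));

    // Stand-in for `spark_expr.distinct`, which is deserialized from the
    // `bool distinct = 1` field added to the AggExpr proto message.
    let distinct = true;

    let agg = AggregateExprBuilder::new(count_udaf(), vec![col("a", &schema)?])
        .schema(Arc::clone(&schema))
        .alias("count")
        .with_ignore_nulls(false)
        // Hard-coding `false` here is what previously forced every Comet
        // aggregate down the non-distinct path.
        .with_distinct(distinct)
        .build()?;

    println!("planned aggregate: {}", agg.name());
    Ok(())
}
```

The JVM-side guard in `QueryPlanSerde` complements this: only a single child expression is handed to the native distinct count, so a multi-column `COUNT(DISTINCT a, b)` is rejected with `withInfo` and the aggregate falls back to Spark.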