diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala
index 22af76f6e4..0fbbb3225f 100644
--- a/common/src/main/scala/org/apache/comet/CometConf.scala
+++ b/common/src/main/scala/org/apache/comet/CometConf.scala
@@ -601,6 +601,13 @@ object CometConf extends ShimCometConf {
       .toSequence
       .createWithDefault(Seq("Range,InMemoryTableScan,RDDScan"))
 
+  val COMET_PREFER_TO_ARROW_ENABLED: ConfigEntry[Boolean] =
+    conf("spark.comet.preferToColumnar.enabled")
+      .internal()
+      .doc("TODO: doc")
+      .booleanConf
+      .createWithDefault(true)
+
   val COMET_CASE_CONVERSION_ENABLED: ConfigEntry[Boolean] =
     conf("spark.comet.caseConversion.enabled")
       .doc(
diff --git a/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala b/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala
index 03333a14d4..c4ee96e418 100644
--- a/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala
+++ b/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala
@@ -154,7 +154,13 @@ object Utils extends CometTypeShim {
           name,
           fieldType,
           Seq(toArrowField("element", elementType, containsNull, timeZoneId)).asJava)
-      case StructType(fields) =>
+      case st @ StructType(fields) =>
+        if (st.names.toSet.size != fields.length) {
+          throw new SparkException(
+            "Duplicated field names in Arrow Struct are not allowed," +
+              s" got ${st.names.mkString("[", ", ", "]")}.")
+        }
+
         val fieldType = new FieldType(nullable, ArrowType.Struct.INSTANCE, null)
         new Field(
           name,
diff --git a/spark/src/main/scala/org/apache/comet/ExtendedExplainInfo.scala b/spark/src/main/scala/org/apache/comet/ExtendedExplainInfo.scala
index 755c345717..b5b8a53029 100644
--- a/spark/src/main/scala/org/apache/comet/ExtendedExplainInfo.scala
+++ b/spark/src/main/scala/org/apache/comet/ExtendedExplainInfo.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.ExtendedExplainGenerator
 import org.apache.spark.sql.catalyst.trees.{TreeNode, TreeNodeTag}
 import org.apache.spark.sql.execution.{InputAdapter, SparkPlan, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec}
+import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
 
 import org.apache.comet.CometExplainInfo.getActualPlan
 
@@ -158,6 +159,7 @@ object CometExplainInfo {
     case p: InputAdapter => getActualPlan(p.child)
     case p: QueryStageExec => getActualPlan(p.plan)
     case p: WholeStageCodegenExec => getActualPlan(p.child)
+    case p: ReusedExchangeExec => getActualPlan(p.child)
     case p => p
   }
 }
diff --git a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
index 091f70fdc2..54521e0325 100644
--- a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
+++ b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
@@ -137,11 +137,27 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
    */
   // spotless:on
   private def transform(plan: SparkPlan): SparkPlan = {
-    def operator2Proto(op: SparkPlan): Option[Operator] = {
-      if (op.children.forall(_.isInstanceOf[CometNativeExec])) {
-        QueryPlanSerde.operator2Proto(
-          op,
-          op.children.map(_.asInstanceOf[CometNativeExec].nativeOp): _*)
+    def operator2Proto[T <: SparkPlan](plan: T): Option[(T, Operator)] = {
+      val newPlan = if (CometConf.COMET_PREFER_TO_ARROW_ENABLED.get(conf)) {
+        val newChildren = plan.children.map {
+          case p: CometNativeExec => p
+          case op =>
+            val cometOp = CometSparkToColumnarExec(op)
+            QueryPlanSerde
+              .operator2Proto(cometOp)
+              .map(CometScanWrapper(_, cometOp))
+              .getOrElse(op)
+        }
+        plan.withNewChildren(newChildren).asInstanceOf[T]
+      } else {
+        plan
+      }
+      if (newPlan.children.forall(_.isInstanceOf[CometNativeExec])) {
+        QueryPlanSerde
+          .operator2Proto(
+            newPlan,
+            newPlan.children.map(_.asInstanceOf[CometNativeExec].nativeOp): _*)
+          .map(op => (newPlan, op))
       } else {
         None
       }
@@ -150,8 +166,8 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
     /**
      * Convert operator to proto and then apply a transformation to wrap the proto in a new plan.
      */
-    def newPlanWithProto(op: SparkPlan, fun: Operator => SparkPlan): SparkPlan = {
-      operator2Proto(op).map(fun).getOrElse(op)
+    def newPlanWithProto[T <: SparkPlan](op: T)(fun: (T, Operator) => SparkPlan): SparkPlan = {
+      operator2Proto(op).map(p => fun(p._1, p._2)).getOrElse(op)
     }
 
     def convertNode(op: SparkPlan): SparkPlan = op match {
@@ -171,34 +187,58 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
         CometScanWrapper(nativeOp.get, cometOp)
 
       case op: ProjectExec =>
-        newPlanWithProto(
-          op,
-          CometProjectExec(_, op, op.output, op.projectList, op.child, SerializedPlan(None)))
-
+        newPlanWithProto(op) { case (newPlan, operator) =>
+          CometProjectExec(
+            operator,
+            newPlan,
+            newPlan.output,
+            newPlan.projectList,
+            newPlan.child,
+            SerializedPlan(None))
+        }
       case op: FilterExec =>
-        newPlanWithProto(
-          op,
-          CometFilterExec(_, op, op.output, op.condition, op.child, SerializedPlan(None)))
+        newPlanWithProto(op) { case (newPlan, operator) =>
+          CometFilterExec(
+            operator,
+            newPlan,
+            newPlan.output,
+            newPlan.condition,
+            newPlan.child,
+            SerializedPlan(None))
+        }
 
       case op: SortExec =>
-        newPlanWithProto(
-          op,
+        newPlanWithProto(op) { case (newPlan, operator) =>
           CometSortExec(
-            _,
-            op,
-            op.output,
-            op.outputOrdering,
-            op.sortOrder,
-            op.child,
-            SerializedPlan(None)))
+            operator,
+            newPlan,
+            newPlan.output,
+            newPlan.outputOrdering,
+            newPlan.sortOrder,
+            newPlan.child,
+            SerializedPlan(None))
+        }
 
       case op: LocalLimitExec =>
-        newPlanWithProto(op, CometLocalLimitExec(_, op, op.limit, op.child, SerializedPlan(None)))
+        newPlanWithProto(op) { case (newPlan, operator) =>
+          CometLocalLimitExec(
+            operator,
+            newPlan,
+            newPlan.limit,
+            newPlan.child,
+            SerializedPlan(None))
+        }
 
       case op: GlobalLimitExec =>
-        newPlanWithProto(
-          op,
-          CometGlobalLimitExec(_, op, op.limit, op.offset, op.child, SerializedPlan(None)))
+        newPlanWithProto(op) { case (newPlan, operator) =>
+          CometGlobalLimitExec(
+            operator,
+            newPlan,
+            newPlan.limit,
+            newPlan.offset,
+            newPlan.child,
+            SerializedPlan(None))
+        }
 
       case op: CollectLimitExec =>
         val fallbackReasons = new ListBuffer[String]()
@@ -227,9 +267,15 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
         }
 
       case op: ExpandExec =>
-        newPlanWithProto(
-          op,
-          CometExpandExec(_, op, op.output, op.projections, op.child, SerializedPlan(None)))
+        newPlanWithProto(op) { case (newPlan, operator) =>
+          CometExpandExec(
+            operator,
+            newPlan,
+            newPlan.output,
+            newPlan.projections,
+            newPlan.child,
+            SerializedPlan(None))
+        }
 
       // When Comet shuffle is disabled, we don't want to transform the HashAggregate
       // to CometHashAggregate. Otherwise, we probably get partial Comet aggregation
@@ -248,46 +294,44 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
         if (multiMode || sparkFinalMode) {
           op
         } else {
-          newPlanWithProto(
-            op,
-            nativeOp => {
-              // The aggExprs could be empty. For example, if the aggregate functions only have
-              // distinct aggregate functions or only have group by, the aggExprs is empty and
-              // modes is empty too. If aggExprs is not empty, we need to verify all the
-              // aggregates have the same mode.
-              assert(modes.length == 1 || modes.isEmpty)
-              CometHashAggregateExec(
-                nativeOp,
-                op,
-                op.output,
-                op.groupingExpressions,
-                op.aggregateExpressions,
-                op.resultExpressions,
-                op.child.output,
-                modes.headOption,
-                op.child,
-                SerializedPlan(None))
-            })
+          newPlanWithProto(op) { case (newPlan, operator) =>
+            // The aggExprs could be empty. For example, if the aggregate functions only have
+            // distinct aggregate functions or only have group by, the aggExprs is empty and
+            // modes is empty too. If aggExprs is not empty, we need to verify all the
+            // aggregates have the same mode.
+            assert(modes.length == 1 || modes.isEmpty)
+            CometHashAggregateExec(
+              operator,
+              newPlan,
+              newPlan.output,
+              newPlan.groupingExpressions,
+              newPlan.aggregateExpressions,
+              newPlan.resultExpressions,
+              newPlan.child.output,
+              modes.headOption,
+              newPlan.child,
+              SerializedPlan(None))
+          }
         }
 
       case op: ShuffledHashJoinExec
           if CometConf.COMET_EXEC_HASH_JOIN_ENABLED.get(conf) &&
             op.children.forall(isCometNative) =>
-        newPlanWithProto(
-          op,
+        newPlanWithProto(op) { case (newPlan, operator) =>
           CometHashJoinExec(
-            _,
-            op,
-            op.output,
-            op.outputOrdering,
-            op.leftKeys,
-            op.rightKeys,
-            op.joinType,
-            op.condition,
-            op.buildSide,
-            op.left,
-            op.right,
-            SerializedPlan(None)))
+            operator,
+            newPlan,
+            newPlan.output,
+            newPlan.outputOrdering,
+            newPlan.leftKeys,
+            newPlan.rightKeys,
+            newPlan.joinType,
+            newPlan.condition,
+            newPlan.buildSide,
+            newPlan.left,
+            newPlan.right,
+            SerializedPlan(None))
+        }
 
       case op: ShuffledHashJoinExec if !CometConf.COMET_EXEC_HASH_JOIN_ENABLED.get(conf) =>
         withInfo(op, "ShuffleHashJoin is not enabled")
@@ -297,40 +341,44 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
 
       case op: BroadcastHashJoinExec
           if CometConf.COMET_EXEC_BROADCAST_HASH_JOIN_ENABLED.get(conf) &&
-            op.children.forall(isCometNative) =>
-        newPlanWithProto(
-          op,
+            // check has columnar broadcast child
+            op.children.exists {
+              case CometSinkPlaceHolder(_, _, _, true) => true
+              case _ => false
+            } =>
+        newPlanWithProto(op) { case (newPlan, operator) =>
           CometBroadcastHashJoinExec(
-            _,
-            op,
-            op.output,
-            op.outputOrdering,
-            op.leftKeys,
-            op.rightKeys,
-            op.joinType,
-            op.condition,
-            op.buildSide,
-            op.left,
-            op.right,
-            SerializedPlan(None)))
+            operator,
+            newPlan,
+            newPlan.output,
+            newPlan.outputOrdering,
+            newPlan.leftKeys,
+            newPlan.rightKeys,
+            newPlan.joinType,
+            newPlan.condition,
+            newPlan.buildSide,
+            newPlan.left,
+            newPlan.right,
+            SerializedPlan(None))
+        }
 
       case op: SortMergeJoinExec
           if CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.get(conf) &&
             op.children.forall(isCometNative) =>
-        newPlanWithProto(
-          op,
+        newPlanWithProto(op) { case (newPlan, operator) =>
           CometSortMergeJoinExec(
-            _,
-            op,
-            op.output,
-            op.outputOrdering,
-            op.leftKeys,
-            op.rightKeys,
-            op.joinType,
-            op.condition,
-            op.left,
-            op.right,
-            SerializedPlan(None)))
+            operator,
+            newPlan,
+            newPlan.output,
+            newPlan.outputOrdering,
+            newPlan.leftKeys,
+            newPlan.rightKeys,
+            newPlan.joinType,
+            newPlan.condition,
+            newPlan.left,
+            newPlan.right,
+            SerializedPlan(None))
+        }
 
       case op: SortMergeJoinExec
           if CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.get(conf) &&
@@ -391,26 +439,25 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
         withInfo(s, Seq(info1, info2).flatten.mkString(","))
 
       case w: WindowExec =>
-        newPlanWithProto(
-          w,
+        newPlanWithProto(w) { case (newPlan, operator) =>
           CometWindowExec(
-            _,
-            w,
-            w.output,
-            w.windowExpression,
-            w.partitionSpec,
-            w.orderSpec,
-            w.child,
-            SerializedPlan(None)))
+            operator,
+            newPlan,
+            newPlan.output,
+            newPlan.windowExpression,
+            newPlan.partitionSpec,
+            newPlan.orderSpec,
+            newPlan.child,
+            SerializedPlan(None))
+        }
 
       case u: UnionExec
           if CometConf.COMET_EXEC_UNION_ENABLED.get(conf) &&
             u.children.forall(isCometNative) =>
-        newPlanWithProto(
-          u, {
-            val cometOp = CometUnionExec(u, u.output, u.children)
-            CometSinkPlaceHolder(_, u, cometOp)
-          })
+        newPlanWithProto(u) { case (newPlan, operator) =>
+          val cometOp = CometUnionExec(newPlan, newPlan.output, newPlan.children)
+          CometSinkPlaceHolder(operator, newPlan, cometOp)
+        }
 
       case u: UnionExec if !CometConf.COMET_EXEC_UNION_ENABLED.get(conf) =>
         withInfo(u, "Union is not enabled")
@@ -420,13 +467,17 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
 
       // For AQE broadcast stage on a Comet broadcast exchange
      case s @ BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) =>
-        newPlanWithProto(s, CometSinkPlaceHolder(_, s, s))
+        newPlanWithProto(s) { case (newPlan, operator) =>
+          CometSinkPlaceHolder(operator, newPlan, newPlan, isBroadcast = true)
+        }
 
       case s @ BroadcastQueryStageExec(
             _,
             ReusedExchangeExec(_, _: CometBroadcastExchangeExec),
             _) =>
-        newPlanWithProto(s, CometSinkPlaceHolder(_, s, s))
+        newPlanWithProto(s) { case (newPlan, operator) =>
+          CometSinkPlaceHolder(operator, newPlan, newPlan, isBroadcast = true)
+        }
 
       // `CometBroadcastExchangeExec`'s broadcast output is not compatible with Spark's broadcast
       // exchange. It is only used for Comet native execution. We only transform Spark broadcast
@@ -440,7 +491,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
             QueryPlanSerde.operator2Proto(b) match {
               case Some(nativeOp) =>
                 val cometOp = CometBroadcastExchangeExec(b, b.output, b.mode, b.child)
-                CometSinkPlaceHolder(nativeOp, b, cometOp)
+                CometSinkPlaceHolder(nativeOp, b, cometOp, isBroadcast = true)
               case None => b
             }
           case other => other
@@ -473,26 +524,29 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
 
       // For AQE shuffle stage on a Comet shuffle exchange
       case s @ ShuffleQueryStageExec(_, _: CometShuffleExchangeExec, _) =>
-        newPlanWithProto(s, CometSinkPlaceHolder(_, s, s))
+        newPlanWithProto(s) { case (newPlan, operator) =>
+          CometSinkPlaceHolder(operator, newPlan, newPlan)
+        }
 
       // For AQE shuffle stage on a reused Comet shuffle exchange
       // Note that we don't need to handle `ReusedExchangeExec` for non-AQE case, because
       // the query plan won't be re-optimized/planned in non-AQE mode.
       case s @ ShuffleQueryStageExec(_, ReusedExchangeExec(_, _: CometShuffleExchangeExec), _) =>
-        newPlanWithProto(s, CometSinkPlaceHolder(_, s, s))
+        newPlanWithProto(s) { case (newPlan, operator) =>
+          CometSinkPlaceHolder(operator, newPlan, newPlan)
+        }
 
       // Native shuffle for Comet operators
       case s: ShuffleExchangeExec =>
         val nativeShuffle: Option[SparkPlan] =
           if (nativeShuffleSupported(s)) {
-            val newOp = operator2Proto(s)
-            newOp match {
-              case Some(nativeOp) =>
+            operator2Proto(s) match {
+              case Some((newPlan, newOp)) =>
                 // Switch to use Decimal128 regardless of precision, since Arrow native execution
                 // doesn't support Decimal32 and Decimal64 yet.
                 conf.setConfString(CometConf.COMET_USE_DECIMAL_128.key, "true")
-                val cometOp = CometShuffleExchangeExec(s, shuffleType = CometNativeShuffle)
-                Some(CometSinkPlaceHolder(nativeOp, s, cometOp))
+                val cometOp = CometShuffleExchangeExec(newPlan, shuffleType = CometNativeShuffle)
+                Some(CometSinkPlaceHolder(newOp, newPlan, cometOp))
               case None =>
                 None
             }
@@ -535,8 +589,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
 
       case op =>
         op match {
-          case _: CometExec | _: AQEShuffleReadExec | _: BroadcastExchangeExec |
-              _: CometBroadcastExchangeExec | _: CometShuffleExchangeExec |
+          case _: CometPlan | _: AQEShuffleReadExec | _: BroadcastExchangeExec |
               _: BroadcastQueryStageExec | _: AdaptiveSparkPlanExec =>
             // Some execs should never be replaced. We include
            // these cases specially here so we do not add a misleading 'info' message
@@ -555,9 +608,25 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
       }
     }
 
-    plan.transformUp { case op =>
+    val newPlan = plan.transformUp { case op =>
       convertNode(op)
     }
+
+    // insert CometColumnarToRowExec if necessary
+    newPlan.transformUp {
+      case c2r: ColumnarToRowTransition => c2r
+      case op if !op.supportsColumnar =>
+        val newChildren = op.children.map {
+          // CometExec already handles columnar-to-row conversion internally.
+          // Not adding an explicit CometColumnarToRowExec here helps broadcast reuse,
+          // e.g. for plans like BroadcastExchangeExec(CometExec)
+          case cometExec: CometExec => cometExec
+          case c if c.supportsColumnar => CometColumnarToRowExec(c)
+          case other => other
+        }
+        op.withNewChildren(newChildren)
+      case o => o
+    }
   }
 
   private def normalizePlan(plan: SparkPlan): SparkPlan = {
@@ -657,7 +726,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
 
     // Remove placeholders
     newPlan = newPlan.transform {
-      case CometSinkPlaceHolder(_, _, s) => s
+      case CometSinkPlaceHolder(_, _, s, _) => s
       case CometScanWrapper(_, s) => s
     }
 
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometColumnarToRowExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometColumnarToRowExec.scala
index 6d0a31236f..d965a6ff7b 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometColumnarToRowExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometColumnarToRowExec.scala
@@ -53,6 +53,7 @@ import org.apache.comet.vector.CometPlainVector
  */
 case class CometColumnarToRowExec(child: SparkPlan)
     extends ColumnarToRowTransition
+    with CometPlan
     with CodegenSupport {
   // supportsColumnar requires to be only called on driver side, see also SPARK-37779.
   assert(Utils.isInRunningSparkTask || child.supportsColumnar)
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometPlan.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometPlan.scala
index e5d268cd90..3ba8b8f4de 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometPlan.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometPlan.scala
@@ -19,9 +19,17 @@
 
 package org.apache.spark.sql.comet
 
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.SparkPlan
 
 /**
  * The base trait for physical Comet operators.
  */
-trait CometPlan extends SparkPlan
+trait CometPlan extends SparkPlan {
+
+  override def setLogicalLink(logicalPlan: LogicalPlan): Unit = {
+    // Don't propagate the logical plan to children, as they may not be CometPlan.
+    setTagValue(SparkPlan.LOGICAL_PLAN_TAG, logicalPlan)
+  }
+
+}
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index a7cfacc475..630aa7fa6e 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -998,7 +998,8 @@ case class CometScanWrapper(override val nativeOp: Operator, override val origin
 
 case class CometSinkPlaceHolder(
     override val nativeOp: Operator, // Must be a Scan
     override val originalPlan: SparkPlan,
-    child: SparkPlan)
+    child: SparkPlan,
+    isBroadcast: Boolean = false)
     extends CometUnaryExec {
   override val serializedPlanOpt: SerializedPlan = SerializedPlan(None)
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
index 6574d9568d..ec46538e4e 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
@@ -409,7 +409,8 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
         (0 until 100).map(i => (i, (i % 10).toString)),
         "tbl",
         dictionaryEnabled) {
-        val n = if (nativeShuffleEnabled) 2 else 1
+        val n =
+          if (nativeShuffleEnabled || CometConf.COMET_PREFER_TO_ARROW_ENABLED.get()) 2 else 1
         checkSparkAnswerAndNumOfAggregates("SELECT _2, SUM(_1) FROM tbl GROUP BY _2", n)
         checkSparkAnswerAndNumOfAggregates("SELECT _2, COUNT(_1) FROM tbl GROUP BY _2", n)
         checkSparkAnswerAndNumOfAggregates("SELECT _2, MIN(_1) FROM tbl GROUP BY _2", n)
@@ -517,13 +518,12 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       (-0.0.asInstanceOf[Float], 2),
       (0.0.asInstanceOf[Float], 3),
       (Float.NaN, 4))
-      withSQLConf(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "false") {
-        withParquetTable(data, "tbl", dictionaryEnabled) {
-          checkSparkAnswer("SELECT SUM(_2), MIN(_2), MAX(_2), _1 FROM tbl GROUP BY _1")
-          checkSparkAnswer("SELECT MIN(_1), MAX(_1), MIN(_2), MAX(_2) FROM tbl")
-          checkSparkAnswer("SELECT AVG(_2), _1 FROM tbl GROUP BY _1")
-          checkSparkAnswer("SELECT AVG(_1), AVG(_2) FROM tbl")
-        }
+      withParquetTable(data, "tbl", dictionaryEnabled) {
+        checkSparkAnswer("SELECT SUM(_2), MIN(_2), MAX(_2), _1 FROM tbl GROUP BY _1")
+        // FIXME: Add MIN(_1) once https://github.com/apache/datafusion-comet/issues/2448 is fixed
+        checkSparkAnswer("SELECT MAX(_1), MIN(_2), MAX(_2) FROM tbl")
+        checkSparkAnswer("SELECT AVG(_2), _1 FROM tbl GROUP BY _1")
+        checkSparkAnswer("SELECT AVG(_1), AVG(_2) FROM tbl")
       }
     }
   }
@@ -592,7 +592,9 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       val path = new Path(dir.toURI.toString, "test")
       makeParquetFile(path, 1000, 20, dictionaryEnabled)
       withParquetTable(path.toUri.toString, "tbl") {
-        val expectedNumOfCometAggregates = if (nativeShuffleEnabled) 2 else 1
+        val expectedNumOfCometAggregates =
+          if (nativeShuffleEnabled || CometConf.COMET_PREFER_TO_ARROW_ENABLED.get()) 2
+          else 1
 
         checkSparkAnswerAndNumOfAggregates(
           "SELECT _g2, SUM(_7) FROM tbl GROUP BY _g2",
@@ -734,7 +736,11 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
         sql("CREATE TABLE t(v VARCHAR(3), i INT) USING PARQUET")
         sql("INSERT INTO t VALUES ('c', 1)")
         withSQLConf(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "false") {
-          checkSparkAnswerAndNumOfAggregates("SELECT v, sum(i) FROM t GROUP BY v ORDER BY v", 1)
+          val expectedNumOfCometAggregates =
+            if (CometConf.COMET_PREFER_TO_ARROW_ENABLED.get()) 2 else 1
+          checkSparkAnswerAndNumOfAggregates(
+            "SELECT v, sum(i) FROM t GROUP BY v ORDER BY v",
+            expectedNumOfCometAggregates)
         }
       }
     }
@@ -1058,7 +1064,11 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       "tbl",
       dictionaryEnabled) {
       withSQLConf(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "false") {
-        checkSparkAnswerAndNumOfAggregates("SELECT _2 , AVG(_1) FROM tbl GROUP BY _2", 1)
+        val expectedNumOfCometAggregates =
+          if (CometConf.COMET_PREFER_TO_ARROW_ENABLED.get()) 2 else 1
+        checkSparkAnswerAndNumOfAggregates(
+          "SELECT _2 , AVG(_1) FROM tbl GROUP BY _2",
+          expectedNumOfCometAggregates)
       }
     }
   }
@@ -1095,7 +1105,9 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       val path = new Path(dir.toURI.toString, "test")
       makeParquetFile(path, 1000, 20, dictionaryEnabled)
       withParquetTable(path.toUri.toString, "tbl") {
-        val expectedNumOfCometAggregates = if (nativeShuffleEnabled) 2 else 1
+        val expectedNumOfCometAggregates =
+          if (nativeShuffleEnabled || CometConf.COMET_PREFER_TO_ARROW_ENABLED.get()) 2
+          else 1
 
         checkSparkAnswerAndNumOfAggregates(
           "SELECT _g2, AVG(_7) FROM tbl GROUP BY _g2",
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index 47d2205a08..37bf56c292 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -109,6 +109,10 @@ class CometExecSuite extends CometTestBase {
     Seq("parquet").foreach { v1List =>
       withSQLConf(
         SQLConf.USE_V1_SOURCE_LIST.key -> v1List,
+        // FIXME: preferToArrow breaks DPP: PlanAdaptiveDynamicPruningFilters and
+        // PlanDynamicPruningFilters look for a BroadcastHashJoinExec to reuse, but
+        // preferToArrow converts it into a CometBroadcastHashJoinExec
+        CometConf.COMET_PREFER_TO_ARROW_ENABLED.key -> "false",
         CometConf.COMET_DPP_FALLBACK_ENABLED.key -> "true") {
         spark.read.parquet(factPath).createOrReplaceTempView("dpp_fact")
         spark.read.parquet(dimPath).createOrReplaceTempView("dpp_dim")
@@ -468,8 +472,9 @@ class CometExecSuite extends CometTestBase {
       withSQLConf(
         SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
         CometConf.COMET_SHUFFLE_MODE.key -> columnarShuffleMode) {
-        val df = sql("SELECT * FROM v where c1 = 1 order by c1, c2")
-        val shuffle = find(df.queryExecution.executedPlan) {
+        val (_, cometPlan) = checkSparkAnswer("SELECT * FROM v where c1 = 1 order by c1, c2")
+        println(cometPlan)
+        val shuffle = find(cometPlan) {
           case _: CometShuffleExchangeExec if columnarShuffleMode.equalsIgnoreCase("jvm") => true
           case _: ShuffleExchangeExec if !columnarShuffleMode.equalsIgnoreCase("jvm") => true