[SPARK-51008][SQL] Add ResultStage for AQE #49715

Closed · wants to merge 20 commits

@@ -210,6 +210,15 @@ object StaticSQLConf {
.checkValue(thres => thres > 0 && thres <= 128, "The threshold must be in (0,128].")
.createWithDefault(16)

+val RESULT_QUERY_STAGE_MAX_THREAD_THRESHOLD =
+buildStaticConf("spark.sql.resultQueryStage.maxThreadThreshold")
+.internal()
+.doc("The maximum degree of parallelism to execute ResultQueryStageExec in AQE")
+.version("4.0.0")
+.intConf
+.checkValue(thres => thres > 0 && thres <= 1024, "The threshold must be in (0,1024].")
+.createWithDefault(1024)

val SQL_EVENT_TRUNCATE_LENGTH = buildStaticConf("spark.sql.event.truncate.length")
.doc("Threshold of SQL length beyond which it will be truncated before adding to " +
"event. Defaults to no truncation. If set to 0, callsite will be logged instead.")
@@ -17,7 +17,7 @@

package org.apache.spark.sql.execution

-import java.util.concurrent.{ConcurrentHashMap, ExecutorService, Future => JFuture}
+import java.util.concurrent.{CompletableFuture, ConcurrentHashMap, ExecutorService}
import java.util.concurrent.atomic.AtomicLong

import scala.jdk.CollectionConverters._
@@ -301,15 +301,15 @@ object SQLExecution extends Logging {
* SparkContext local properties are forwarded to execution thread
*/
def withThreadLocalCaptured[T](
-sparkSession: SparkSession, exec: ExecutorService) (body: => T): JFuture[T] = {
+sparkSession: SparkSession, exec: ExecutorService) (body: => T): CompletableFuture[T] = {
val activeSession = sparkSession
val sc = sparkSession.sparkContext
val localProps = Utils.cloneProperties(sc.getLocalProperties)
// `getCurrentJobArtifactState` will return a state only in Spark Connect mode. In non-Connect
// mode, we default back to the resources of the current Spark session.
val artifactState = JobArtifactSet.getCurrentJobArtifactState.getOrElse(
activeSession.artifactManager.state)
-exec.submit(() => JobArtifactSet.withActiveJobArtifactState(artifactState) {
+CompletableFuture.supplyAsync(() => JobArtifactSet.withActiveJobArtifactState(artifactState) {
val originalSession = SparkSession.getActiveSession
val originalLocalProps = sc.getLocalProperties
SparkSession.setActiveSession(activeSession)
@@ -326,6 +326,6 @@
SparkSession.clearActiveSession()
}
res
-})
+}, exec)
}
}
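
For context on the change above: a plain `java.util.concurrent.Future` can only be polled or blocked on, while `CompletableFuture` supports non-blocking completion callbacks, which the new `ResultQueryStageExec` (later in this diff) uses to bridge into a Scala `Future`. A self-contained sketch of the pattern, with illustrative values:

```scala
import java.util.concurrent.{CompletableFuture, Executors}
import scala.concurrent.{ExecutionContext, Promise}

val pool = Executors.newFixedThreadPool(2)

// supplyAsync runs the body on the given executor, much like exec.submit, but
// the returned CompletableFuture also accepts completion callbacks.
val javaFuture: CompletableFuture[Int] = CompletableFuture.supplyAsync(() => 40 + 2, pool)

// Bridge into a Scala Future without blocking a thread on javaFuture.get().
val promise = Promise[Int]()
javaFuture.whenComplete { (result, error) =>
  if (error != null) promise.failure(error) else promise.success(result)
}
promise.future.foreach(v => println(s"got $v"))(ExecutionContext.global)

pool.shutdown()
```
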
@@ -268,9 +268,11 @@ case class AdaptiveSparkPlanExec(

def finalPhysicalPlan: SparkPlan = withFinalPlanUpdate(identity)

-private def getFinalPhysicalPlan(): SparkPlan = lock.synchronized {
-if (isFinalPlan) return currentPhysicalPlan
-
+/**
+* Run `fun` on the finalized physical plan.
+*/
+def withFinalPlanUpdate[T](fun: SparkPlan => T): T = lock.synchronized {
+_isFinalPlan = false

Review comment (Contributor): so when we call df.collect multiple times, we will re-optimize the final stage each time, because each call needs to wrap a new ResultQueryStageExec.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this case we construct QueryResultStageExec directly and won't re-optimize it: https://github.com/apache/spark/pull/49715/files#diff-ec42cd27662f3f528832c298a60fffa1d341feb04aa1d8c80044b70cbe0ebbfcR536
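
A hypothetical illustration of the scenario discussed in this thread (not code from the PR):

```scala
import spark.implicits._

val df = spark.range(0, 1000).groupBy(($"id" % 10).as("k")).count()

df.collect() // materializes the shuffle stages, then wraps and runs a first ResultQueryStageExec
df.collect() // shuffle stages are already materialized; only a fresh result stage is wrapped,
             // since the second call brings its own result handler
```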

// In case of this adaptive plan being executed out of `withActive` scoped functions, e.g.,
// `plan.queryExecution.rdd`, we need to set active session here as new plan nodes can be
// created in the middle of the execution.
@@ -279,7 +281,7 @@
// Use inputPlan logicalLink here in case some top level physical nodes may be removed
// during `initialPlan`
var currentLogicalPlan = inputPlan.logicalLink.get
-var result = createQueryStages(currentPhysicalPlan)
+var result = createQueryStages(fun, currentPhysicalPlan, firstRun = true)
val events = new LinkedBlockingQueue[StageMaterializationEvent]()
val errors = new mutable.ArrayBuffer[Throwable]()
var stagesToReplace = Seq.empty[QueryStageExec]
@@ -344,56 +346,53 @@
if (errors.nonEmpty) {
cleanUpAndThrowException(errors.toSeq, None)
}

-// Try re-optimizing and re-planning. Adopt the new plan if its cost is equal to or less
-// than that of the current plan; otherwise keep the current physical plan together with
-// the current logical plan since the physical plan's logical links point to the logical
-// plan it has originated from.
-// Meanwhile, we keep a list of the query stages that have been created since last plan
-// update, which stands for the "semantic gap" between the current logical and physical
-// plans. And each time before re-planning, we replace the corresponding nodes in the
-// current logical plan with logical query stages to make it semantically in sync with
-// the current physical plan. Once a new plan is adopted and both logical and physical
-// plans are updated, we can clear the query stage list because at this point the two plans
-// are semantically and physically in sync again.
-val logicalPlan = replaceWithQueryStagesInLogicalPlan(currentLogicalPlan, stagesToReplace)
-val afterReOptimize = reOptimize(logicalPlan)
-if (afterReOptimize.isDefined) {
-val (newPhysicalPlan, newLogicalPlan) = afterReOptimize.get
-val origCost = costEvaluator.evaluateCost(currentPhysicalPlan)
-val newCost = costEvaluator.evaluateCost(newPhysicalPlan)
-if (newCost < origCost ||
-(newCost == origCost && currentPhysicalPlan != newPhysicalPlan)) {
-lazy val plans =
-sideBySide(currentPhysicalPlan.treeString, newPhysicalPlan.treeString).mkString("\n")
-logOnLevel(log"Plan changed:\n${MDC(QUERY_PLAN, plans)}")
-cleanUpTempTags(newPhysicalPlan)
-currentPhysicalPlan = newPhysicalPlan
-currentLogicalPlan = newLogicalPlan
-stagesToReplace = Seq.empty[QueryStageExec]
+if (!currentPhysicalPlan.isInstanceOf[ResultQueryStageExec]) {

Review comment (Contributor): why do we need to skip ResultQueryStageExec?

Review comment (cloud-fan, Feb 6, 2025): Result stage is already the last step; there is nothing to re-optimize.

+// Try re-optimizing and re-planning. Adopt the new plan if its cost is equal to or less
+// than that of the current plan; otherwise keep the current physical plan together with
+// the current logical plan since the physical plan's logical links point to the logical
+// plan it has originated from.
+// Meanwhile, we keep a list of the query stages that have been created since last plan
+// update, which stands for the "semantic gap" between the current logical and physical
+// plans. And each time before re-planning, we replace the corresponding nodes in the
+// current logical plan with logical query stages to make it semantically in sync with
+// the current physical plan. Once a new plan is adopted and both logical and physical
+// plans are updated, we can clear the query stage list because at this point the two
+// plans are semantically and physically in sync again.
+val logicalPlan = replaceWithQueryStagesInLogicalPlan(currentLogicalPlan, stagesToReplace)
+val afterReOptimize = reOptimize(logicalPlan)
+if (afterReOptimize.isDefined) {
+val (newPhysicalPlan, newLogicalPlan) = afterReOptimize.get
+val origCost = costEvaluator.evaluateCost(currentPhysicalPlan)
+val newCost = costEvaluator.evaluateCost(newPhysicalPlan)
+if (newCost < origCost ||
+(newCost == origCost && currentPhysicalPlan != newPhysicalPlan)) {
+lazy val plans = sideBySide(
+currentPhysicalPlan.treeString, newPhysicalPlan.treeString).mkString("\n")
+logOnLevel(log"Plan changed:\n${MDC(QUERY_PLAN, plans)}")
+cleanUpTempTags(newPhysicalPlan)
+currentPhysicalPlan = newPhysicalPlan
+currentLogicalPlan = newLogicalPlan
+stagesToReplace = Seq.empty[QueryStageExec]
+}
+}
+}
// Now that some stages have finished, we can try creating new stages.
-result = createQueryStages(currentPhysicalPlan)
+result = createQueryStages(fun, currentPhysicalPlan, firstRun = false)
}

-// Run the final plan when there's no more unfinished stages.
-currentPhysicalPlan = applyPhysicalRules(
-optimizeQueryStage(result.newPlan, isFinalStage = true),
-postStageCreationRules(supportsColumnar),
-Some((planChangeLogger, "AQE Post Stage Creation")))
-_isFinalPlan = true
-executionId.foreach(onUpdatePlan(_, Seq(currentPhysicalPlan)))
-currentPhysicalPlan
-}
+_isFinalPlan = true
+finalPlanUpdate
+// Dereference the result so it can be GCed. After this, resultStage.isMaterialized will return
+// false, which is expected. To collect the result again, invoke `withFinalPlanUpdate` with
+// another result handler, and a new result stage will be created.
+currentPhysicalPlan.asInstanceOf[ResultQueryStageExec].resultOption.getAndUpdate(_ => None)
+.get.asInstanceOf[T]
}
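
The dereference above leans on `AtomicReference.getAndUpdate`, which atomically swaps in the new value and returns the old one, so the result is handed out exactly once. A standalone sketch of the same pattern (values illustrative):

```scala
import java.util.concurrent.atomic.AtomicReference

val slot = new AtomicReference[Option[String]](Some("result"))

// Returns the previous value while clearing the slot, so the stored result
// can be garbage-collected and later reads see None.
val taken: Option[String] = slot.getAndUpdate(_ => None)
assert(taken.contains("result"))
assert(slot.get().isEmpty)
```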

// Use a lazy val to avoid this being called more than once.
@transient private lazy val finalPlanUpdate: Unit = {
-// Subqueries that don't belong to any query stage of the main query will execute after the
-// last UI update in `getFinalPhysicalPlan`, so we need to update UI here again to make sure
-// the newly generated nodes of those subqueries are updated.
-if (shouldUpdatePlan && currentPhysicalPlan.exists(_.subqueries.nonEmpty)) {
+// Do final plan update after result stage has materialized.
+if (shouldUpdatePlan) {
getExecutionId.foreach(onUpdatePlan(_, Seq.empty))
}
logOnLevel(log"Final plan:\n${MDC(QUERY_PLAN, currentPhysicalPlan)}")
@@ -426,13 +425,6 @@
}
}

-private def withFinalPlanUpdate[T](fun: SparkPlan => T): T = {
-val plan = getFinalPhysicalPlan()
-val result = fun(plan)
-finalPlanUpdate
-result
-}

protected override def stringArgs: Iterator[Any] = Iterator(s"isFinalPlan=$isFinalPlan")

override def generateTreeString(
@@ -521,6 +513,66 @@
this.inputPlan == obj.asInstanceOf[AdaptiveSparkPlanExec].inputPlan
}

+/**
+* We separate stage creation of result and non-result stages because there are several edge cases
+* of result stage creation:
+* - an existing ResultQueryStage was created in a previous `withFinalPlanUpdate`.
+* - the root node is a non-result query stage and we have to create a result query stage on top
+* of it.
+* - we create a non-result query stage as the root node and the stage is immediately materialized
+* due to stage reuse, therefore we have to create a result stage right after.
+*
+* This method wraps around `createNonResultQueryStages`; the general logic is:
+* - Early return if a ResultQueryStageExec was already created before.
+* - Create non-result query stages if possible.
+* - Try to create the result query stage when no new non-result query stage was created and all
+* stages are materialized.
+*/
+private def createQueryStages(
+resultHandler: SparkPlan => Any,
+plan: SparkPlan,
+firstRun: Boolean): CreateStageResult = {
+plan match {
+// 1. ResultQueryStageExec is already created, no need to create non-result stages
+case resultStage @ ResultQueryStageExec(_, optimizedPlan, _) =>
+assertStageNotFailed(resultStage)
+if (firstRun) {
+// There is already an existing ResultQueryStage created in a previous `withFinalPlanUpdate`,
+// e.g., when we do `df.collect` multiple times. Here we create a new result stage to
+// execute it again, as the handler function can be different.
+val newResultStage = ResultQueryStageExec(currentStageId, optimizedPlan, resultHandler)
+currentStageId += 1
+setLogicalLinkForNewQueryStage(newResultStage, optimizedPlan)
+CreateStageResult(newPlan = newResultStage,
+allChildStagesMaterialized = false,
+newStages = Seq(newResultStage))
+} else {
+// We will hit this branch after we've created the result query stage in the AQE loop;
+// we should do nothing.
+CreateStageResult(newPlan = resultStage,
+allChildStagesMaterialized = resultStage.isMaterialized,
+newStages = Seq.empty)
+}
+case _ =>
+// 2. Create non-result query stages
+val result = createNonResultQueryStages(plan)
+var allNewStages = result.newStages
+var newPlan = result.newPlan
+var allChildStagesMaterialized = result.allChildStagesMaterialized
+// 3. Create the result stage
+if (allNewStages.isEmpty && allChildStagesMaterialized) {
+val resultStage = newResultQueryStage(resultHandler, newPlan)
+newPlan = resultStage
+allChildStagesMaterialized = false
+allNewStages :+= resultStage
+}
+CreateStageResult(
+newPlan = newPlan,
+allChildStagesMaterialized = allChildStagesMaterialized,
+newStages = allNewStages)
+}
+}
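
To make the branches concrete, a hedged trace of how they fire across the AQE loop for a simple shuffle query (stage ids are made up):

```scala
// Illustrative trace for df.groupBy(...).count().collect():
//
// Iteration 1 (firstRun = true):
//   createNonResultQueryStages wraps the exchange in ShuffleQueryStageExec #0.
//   A new non-result stage exists, so no result stage is created yet.
//
// Iteration 2 (stage #0 materialized):
//   createNonResultQueryStages creates nothing new and all children are
//   materialized, so newResultQueryStage wraps the root in ResultQueryStageExec #1.
//
// Iteration 3 (firstRun = false, root is already ResultQueryStageExec):
//   The first case returns the existing stage; the loop exits once the
//   result stage itself is materialized.
```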

/**
* This method is called recursively to traverse the plan tree bottom-up and create a new query
* stage or try reusing an existing stage if the current node is an [[Exchange]] node and all of
@@ -531,7 +583,7 @@
* 2) Whether the child query stages (if any) of the current node have all been materialized.
* 3) A list of the new query stages that have been created.
*/
-private def createQueryStages(plan: SparkPlan): CreateStageResult = plan match {
+private def createNonResultQueryStages(plan: SparkPlan): CreateStageResult = plan match {
case e: Exchange =>
// First have a quick check in the `stageCache` without having to traverse down the node.
context.stageCache.get(e.canonicalized) match {
@@ -544,7 +596,7 @@
newStages = if (isMaterialized) Seq.empty else Seq(stage))

case _ =>
-val result = createQueryStages(e.child)
+val result = createNonResultQueryStages(e.child)
val newPlan = e.withNewChildren(Seq(result.newPlan)).asInstanceOf[Exchange]
// Create a query stage only when all the child query stages are ready.
if (result.allChildStagesMaterialized) {
@@ -588,14 +640,28 @@
if (plan.children.isEmpty) {
CreateStageResult(newPlan = plan, allChildStagesMaterialized = true, newStages = Seq.empty)
} else {
-val results = plan.children.map(createQueryStages)
+val results = plan.children.map(createNonResultQueryStages)
CreateStageResult(
newPlan = plan.withNewChildren(results.map(_.newPlan)),
allChildStagesMaterialized = results.forall(_.allChildStagesMaterialized),
newStages = results.flatMap(_.newStages))
}
}

+private def newResultQueryStage(
+resultHandler: SparkPlan => Any,
+plan: SparkPlan): ResultQueryStageExec = {
+// Run the final plan when there are no more unfinished stages.
+val optimizedRootPlan = applyPhysicalRules(
+optimizeQueryStage(plan, isFinalStage = true),
+postStageCreationRules(supportsColumnar),
+Some((planChangeLogger, "AQE Post Stage Creation")))
+val resultStage = ResultQueryStageExec(currentStageId, optimizedRootPlan, resultHandler)
+currentStageId += 1
+setLogicalLinkForNewQueryStage(resultStage, plan)
+resultStage
+}

private def newQueryStage(plan: SparkPlan): QueryStageExec = {
val queryStage = plan match {
case e: Exchange =>
@@ -129,10 +129,12 @@ trait AdaptiveSparkPlanHelper {
}

/**
-* Strip the executePlan of AdaptiveSparkPlanExec leaf node.
+* Strip the top [[AdaptiveSparkPlanExec]] and [[ResultQueryStageExec]] nodes off
+* the [[SparkPlan]].
*/
def stripAQEPlan(p: SparkPlan): SparkPlan = p match {
-case a: AdaptiveSparkPlanExec => a.executedPlan
+case a: AdaptiveSparkPlanExec => stripAQEPlan(a.executedPlan)
+case ResultQueryStageExec(_, plan, _) => plan
case other => other
}
}
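
A sketch of test-style usage, assuming a suite that mixes in `AdaptiveSparkPlanHelper` (the query is illustrative):

```scala
// After this PR, an executed adaptive plan is rooted at ResultQueryStageExec,
// so a single-level strip would hand assertions the result stage wrapper
// rather than the real root operator; the recursion plus the new case fix that.
val df = spark.range(100).repartition(10)
df.collect()

val adaptivePlan = df.queryExecution.executedPlan // AdaptiveSparkPlanExec
val root = stripAQEPlan(adaptivePlan)             // the actual root operator
```
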
@@ -19,7 +19,9 @@ package org.apache.spark.sql.execution.adaptive

import java.util.concurrent.atomic.AtomicReference

+import scala.concurrent.ExecutionContext
import scala.concurrent.Future
+import scala.concurrent.Promise

import org.apache.spark.{MapOutputStatistics, SparkException}
import org.apache.spark.broadcast.Broadcast
@@ -32,7 +34,10 @@ import org.apache.spark.sql.columnar.CachedBatch
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.columnar.InMemoryTableScanLike
import org.apache.spark.sql.execution.exchange._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.StaticSQLConf
import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.util.ThreadUtils

/**
* A query stage is an independent subgraph of the query plan. AQE framework will materialize its
@@ -303,3 +308,43 @@

override def getRuntimeStatistics: Statistics = inMemoryTableScan.runtimeStatistics
}

+case class ResultQueryStageExec(
+override val id: Int,
+override val plan: SparkPlan,
+resultHandler: SparkPlan => Any) extends QueryStageExec {
+
+override def resetMetrics(): Unit = {
+plan.resetMetrics()
+}
+
+override protected def doMaterialize(): Future[Any] = {
+val javaFuture = SQLExecution.withThreadLocalCaptured(
+session,
+ResultQueryStageExec.executionContext) {
+resultHandler(plan)
+}
+val scalaPromise: Promise[Any] = Promise()
+javaFuture.whenComplete { (result: Any, exception: Throwable) =>
+if (exception != null) {
+scalaPromise.failure(exception match {
+case completionException: java.util.concurrent.CompletionException =>
+completionException.getCause
+case ex => ex
+})
+} else {
+scalaPromise.success(result)
+}
+}
+scalaPromise.future
+}
+
+// The result stage could be any SparkPlan, so we don't have specific runtime statistics for it.
+override def getRuntimeStatistics: Statistics = Statistics.DUMMY
+}

+object ResultQueryStageExec {
+private[execution] val executionContext = ExecutionContext.fromExecutorService(
+ThreadUtils.newDaemonCachedThreadPool("ResultQueryStageExecution",
+SQLConf.get.getConf(StaticSQLConf.RESULT_QUERY_STAGE_MAX_THREAD_THRESHOLD)))
+}
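
For intuition, a hypothetical sketch of how a caller hands work to the result stage; in the PR the handler actually comes from `withFinalPlanUpdate` call sites, and `finalPlan` below stands in for the optimized root plan:

```scala
import scala.concurrent.ExecutionContext
import scala.util.{Failure, Success}

// A result handler is just a function from the finalized plan to whatever the
// action needs; materialization runs it on the dedicated thread pool above.
val handler: SparkPlan => Any = _.executeCollect()
val resultStage = ResultQueryStageExec(id = 0, plan = finalPlan, resultHandler = handler)

// materialize() delegates to doMaterialize, which bridges the CompletableFuture
// from withThreadLocalCaptured into the Scala Future consumed here.
resultStage.materialize().onComplete {
  case Success(rows) => println(s"result ready: $rows")
  case Failure(e)    => println(s"result stage failed: $e")
}(ExecutionContext.global)
```
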
@@ -106,7 +106,7 @@ object SparkPlanGraph {
buildSparkPlanGraphNode(
planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null, exchanges)
}
case "TableCacheQueryStage" =>
case "TableCacheQueryStage" | "ResultQueryStage" =>
buildSparkPlanGraphNode(
planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null, exchanges)
case "Subquery" if subgraph != null =>
@@ -1659,7 +1659,9 @@ class CachedTableSuite extends QueryTest with SQLTestUtils
_.nodeName.contains("TableCacheQueryStage"))
val aqeNode = findNodeInSparkPlanInfo(inMemoryScanNode.get,
_.nodeName.contains("AdaptiveSparkPlan"))
-aqeNode.get.children.head.nodeName == "AQEShuffleRead"
+val aqePlanRoot = findNodeInSparkPlanInfo(inMemoryScanNode.get,
+_.nodeName.contains("ResultQueryStage"))
+aqePlanRoot.get.children.head.nodeName == "AQEShuffleRead"
}

withTempView("t0", "t1", "t2") {