[SPARK-51008][SQL][WIP] Add ResultStage for AQE #49715

Open · wants to merge 1 commit into base: master
StaticSQLConf.scala
@@ -210,6 +210,15 @@ object StaticSQLConf {
.checkValue(thres => thres > 0 && thres <= 128, "The threshold must be in (0,128].")
.createWithDefault(16)

val RESULT_QUERY_STAGE_MAX_THREAD_THRESHOLD =
buildStaticConf("spark.sql.resultQueryStage.maxThreadThreshold")
.internal()
.doc("The maximum degree of parallelism to execute ResultQueryStageExec in AQE")
.version("4.0.0")
.intConf
.checkValue(thres => thres > 0 && thres <= 1024, "The threshold must be in (0,1024].")
.createWithDefault(1024)

val SQL_EVENT_TRUNCATE_LENGTH = buildStaticConf("spark.sql.event.truncate.length")
.doc("Threshold of SQL length beyond which it will be truncated before adding to " +
"event. Defaults to no truncation. If set to 0, callsite will be logged instead.")
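This new entry is a static conf, so it must be set before the SparkSession starts and cannot be changed with SET at runtime. A minimal usage sketch (the value 256 is illustrative):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .master("local[*]")
  .appName("aqe-result-stage-demo")
  // Caps the thread pool that runs ResultQueryStageExec result handlers.
  .config("spark.sql.resultQueryStage.maxThreadThreshold", "256")
  .getOrCreate()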
SQLExecution.scala
@@ -17,7 +17,7 @@

package org.apache.spark.sql.execution

import java.util.concurrent.{ConcurrentHashMap, ExecutorService, Future => JFuture}
import java.util.concurrent.{CompletableFuture, ConcurrentHashMap, ExecutorService}
import java.util.concurrent.atomic.AtomicLong

import scala.jdk.CollectionConverters._
@@ -301,15 +301,15 @@ object SQLExecution extends Logging {
* SparkContext local properties are forwarded to execution thread
*/
def withThreadLocalCaptured[T](
sparkSession: SparkSession, exec: ExecutorService) (body: => T): JFuture[T] = {
sparkSession: SparkSession, exec: ExecutorService) (body: => T): CompletableFuture[T] = {
val activeSession = sparkSession
val sc = sparkSession.sparkContext
val localProps = Utils.cloneProperties(sc.getLocalProperties)
// `getCurrentJobArtifactState` will return a state only in Spark Connect mode. In non-Connect
// mode, we default back to the resources of the current Spark session.
val artifactState = JobArtifactSet.getCurrentJobArtifactState.getOrElse(
activeSession.artifactManager.state)
exec.submit(() => JobArtifactSet.withActiveJobArtifactState(artifactState) {
CompletableFuture.supplyAsync(() => JobArtifactSet.withActiveJobArtifactState(artifactState) {
val originalSession = SparkSession.getActiveSession
val originalLocalProps = sc.getLocalProperties
SparkSession.setActiveSession(activeSession)
@@ -326,6 +326,6 @@
SparkSession.clearActiveSession()
}
res
})
}, exec)
}
}
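The switch from `exec.submit` returning a plain `JFuture` to `CompletableFuture.supplyAsync(..., exec)` keeps the body on the caller-supplied pool while letting callers attach non-blocking completion callbacks, which the new result stage relies on. A standalone sketch of the pattern (not Spark code):

import java.util.concurrent.{CompletableFuture, Executors}

val exec = Executors.newFixedThreadPool(2)
// Run the body on the supplied executor, as withThreadLocalCaptured does.
val future: CompletableFuture[Int] = CompletableFuture.supplyAsync(() => 21 * 2, exec)
// React to completion without blocking; a plain ExecutorService future
// would only offer a blocking get().
future.whenComplete { (result: Int, error: Throwable) =>
  if (error != null) println(s"failed: $error") else println(s"result: $result")
}
exec.shutdown()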
AdaptiveSparkPlanExec.scala
@@ -236,6 +236,14 @@ case class AdaptiveSparkPlanExec(

@volatile private var currentPhysicalPlan = initialPlan

// Use the `inputPlan` logical link here because some top-level physical nodes may be removed
// during `initialPlan`
@transient @volatile private var currentLogicalPlan: LogicalPlan = {
inputPlan.logicalLink.get
}

// Query stages created since the last plan update; before each re-planning, the corresponding
// nodes in the current logical plan are replaced with logical query stages to keep it in sync.
val stagesToReplace = mutable.ArrayBuffer.empty[QueryStageExec]

@volatile private var _isFinalPlan = false

private var currentStageId = 0
@@ -289,26 +297,24 @@

def finalPhysicalPlan: SparkPlan = withFinalPlanUpdate(identity)

private def getFinalPhysicalPlan(): SparkPlan = lock.synchronized {
if (isFinalPlan) return currentPhysicalPlan

/**
* Runs `fun` on the finalized physical plan.
*/
def withFinalPlanUpdate[T](fun: SparkPlan => T): T = lock.synchronized {
_isFinalPlan = false
// In case of this adaptive plan being executed out of `withActive` scoped functions, e.g.,
// `plan.queryExecution.rdd`, we need to set active session here as new plan nodes can be
// created in the middle of the execution.
context.session.withActive {
val executionId = getExecutionId
// Use inputPlan logicalLink here in case some top level physical nodes may be removed
// during `initialPlan`
var currentLogicalPlan = inputPlan.logicalLink.get
var result = createQueryStages(currentPhysicalPlan)
var result = createQueryStages(fun, currentPhysicalPlan, firstRun = true)
val events = new LinkedBlockingQueue[StageMaterializationEvent]()
val errors = new mutable.ArrayBuffer[Throwable]()
var stagesToReplace = Seq.empty[QueryStageExec]
while (!result.allChildStagesMaterialized) {
ruleContext.clearConfigs()
currentPhysicalPlan = result.newPlan
if (result.newStages.nonEmpty) {
stagesToReplace = result.newStages ++ stagesToReplace
stagesToReplace ++= result.newStages
executionId.foreach(onUpdatePlan(_, result.newStages.map(_.plan)))

// SPARK-33933: we should submit tasks of broadcast stages first, to avoid waiting
@@ -366,50 +372,44 @@ case class AdaptiveSparkPlanExec(
if (errors.nonEmpty) {
cleanUpAndThrowException(errors.toSeq, None)
}

// Try re-optimizing and re-planning. Adopt the new plan if its cost is equal to or less
// than that of the current plan; otherwise keep the current physical plan together with
// the current logical plan since the physical plan's logical links point to the logical
// plan it has originated from.
// Meanwhile, we keep a list of the query stages that have been created since last plan
// update, which stands for the "semantic gap" between the current logical and physical
// plans. And each time before re-planning, we replace the corresponding nodes in the
// current logical plan with logical query stages to make it semantically in sync with
// the current physical plan. Once a new plan is adopted and both logical and physical
// plans are updated, we can clear the query stage list because at this point the two plans
// are semantically and physically in sync again.
val logicalPlan = replaceWithQueryStagesInLogicalPlan(currentLogicalPlan, stagesToReplace)
val afterReOptimize = reOptimize(logicalPlan)
if (afterReOptimize.isDefined) {
val (newPhysicalPlan, newLogicalPlan) = afterReOptimize.get
val origCost = costEvaluator.evaluateCost(currentPhysicalPlan)
val newCost = costEvaluator.evaluateCost(newPhysicalPlan)
if (newCost < origCost ||
(newCost == origCost && currentPhysicalPlan != newPhysicalPlan)) {
lazy val plans =
sideBySide(currentPhysicalPlan.treeString, newPhysicalPlan.treeString).mkString("\n")
logOnLevel(log"Plan changed:\n${MDC(QUERY_PLAN, plans)}")
cleanUpTempTags(newPhysicalPlan)
currentPhysicalPlan = newPhysicalPlan
currentLogicalPlan = newLogicalPlan
stagesToReplace = Seq.empty[QueryStageExec]
if (!currentPhysicalPlan.isInstanceOf[ResultQueryStageExec]) {
// Try re-optimizing and re-planning. Adopt the new plan if its cost is equal to or less
// than that of the current plan; otherwise keep the current physical plan together with
// the current logical plan since the physical plan's logical links point to the logical
// plan it has originated from.
// Meanwhile, we keep a list of the query stages that have been created since last plan
// update, which stands for the "semantic gap" between the current logical and physical
// plans. And each time before re-planning, we replace the corresponding nodes in the
// current logical plan with logical query stages to make it semantically in sync with
// the current physical plan. Once a new plan is adopted and both logical and physical
// plans are updated, we can clear the query stage list because at this point the two
// plans are semantically and physically in sync again.
val logicalPlan = replaceWithQueryStagesInLogicalPlan(currentLogicalPlan,
stagesToReplace.toSeq)
val afterReOptimize = reOptimize(logicalPlan)
if (afterReOptimize.isDefined) {
val (newPhysicalPlan, newLogicalPlan) = afterReOptimize.get
val origCost = costEvaluator.evaluateCost(currentPhysicalPlan)
val newCost = costEvaluator.evaluateCost(newPhysicalPlan)
if (newCost < origCost ||
(newCost == origCost && currentPhysicalPlan != newPhysicalPlan)) {
lazy val plans = sideBySide(
currentPhysicalPlan.treeString, newPhysicalPlan.treeString).mkString("\n")
logOnLevel(log"Plan changed:\n${MDC(QUERY_PLAN, plans)}")
cleanUpTempTags(newPhysicalPlan)
currentPhysicalPlan = newPhysicalPlan
currentLogicalPlan = newLogicalPlan
stagesToReplace.clear()
}
}
// Now that some stages have finished, we can try creating new stages.
result = createQueryStages(fun, currentPhysicalPlan, firstRun = false)
}
// Now that some stages have finished, we can try creating new stages.
result = createQueryStages(currentPhysicalPlan)
}

ruleContext = ruleContext.withFinalStage(isFinalStage = true)
// Run the final plan when there's no more unfinished stages.
currentPhysicalPlan = applyPhysicalRulesWithRuleContext(
optimizeQueryStage(result.newPlan, isFinalStage = true),
postStageCreationRules(supportsColumnar),
Some((planChangeLogger, "AQE Post Stage Creation")))
ruleContext.clearConfigs()
_isFinalPlan = true
executionId.foreach(onUpdatePlan(_, Seq(currentPhysicalPlan)))
currentPhysicalPlan
}
_isFinalPlan = true
finalPlanUpdate
currentPhysicalPlan.asInstanceOf[ResultQueryStageExec].resultOption.get().get.asInstanceOf[T]
}

// Use a lazy val to avoid this being called more than once.
@@ -450,13 +450,6 @@ case class AdaptiveSparkPlanExec(
}
}

private def withFinalPlanUpdate[T](fun: SparkPlan => T): T = {
val plan = getFinalPhysicalPlan()
val result = fun(plan)
finalPlanUpdate
result
}

protected override def stringArgs: Iterator[Any] = Iterator(s"isFinalPlan=$isFinalPlan")

override def generateTreeString(
@@ -545,6 +538,66 @@ case class AdaptiveSparkPlanExec(
this.inputPlan == obj.asInstanceOf[AdaptiveSparkPlanExec].inputPlan
}

/**
* A wrapper around `createQueryStagesInternal` that additionally handles result stage creation.
*/
private def createQueryStages(
resultHandler: SparkPlan => Any,
plan: SparkPlan,
firstRun: Boolean): CreateStageResult = {
plan match {
case resultStage @ ResultQueryStageExec(_, optimizedPlan, _) =>
return if (firstRun) {
// There is already an existing ResultQueryStage created in previous `withFinalPlanUpdate`
val newResultStage = ResultQueryStageExec(currentStageId, optimizedPlan, resultHandler)
currentStageId += 1
setLogicalLinkForNewQueryStage(newResultStage, optimizedPlan)
stagesToReplace.append(newResultStage)
CreateStageResult(newPlan = newResultStage,
allChildStagesMaterialized = false,
newStages = Seq(newResultStage))
} else {
// result stage already created, do nothing
CreateStageResult(newPlan = plan,
allChildStagesMaterialized = resultStage.isMaterialized,
newStages = Seq.empty)
}
case _ =>
}
val result = createQueryStagesInternal(plan)
var allNewStages = result.newStages
var newPlan = result.newPlan
var allChildStagesMaterialized = result.allChildStagesMaterialized
// Create result stage
if (allNewStages.isEmpty && allChildStagesMaterialized) {
val resultStage = createResultQueryStage(resultHandler, newPlan)
stagesToReplace.append(resultStage)
newPlan = resultStage
allChildStagesMaterialized = false
allNewStages :+= resultStage
}
CreateStageResult(
newPlan = newPlan,
allChildStagesMaterialized = allChildStagesMaterialized,
newStages = allNewStages)
}

private def createResultQueryStage(
resultHandler: SparkPlan => Any,
plan: SparkPlan): ResultQueryStageExec = {
ruleContext = ruleContext.withFinalStage(isFinalStage = true)
// Run the final plan when there are no more unfinished stages.
val optimizedRootPlan = applyPhysicalRulesWithRuleContext(
optimizeQueryStage(plan, isFinalStage = true),
postStageCreationRules(supportsColumnar),
Some((planChangeLogger, "AQE Post Stage Creation")))
ruleContext.clearConfigs()
val resultStage = ResultQueryStageExec(currentStageId, optimizedRootPlan, resultHandler)
currentStageId += 1
setLogicalLinkForNewQueryStage(resultStage, plan)
resultStage
}

/**
* This method is called recursively to traverse the plan tree bottom-up and create a new query
* stage or try reusing an existing stage if the current node is an [[Exchange]] node and all of
@@ -555,7 +608,7 @@
* 2) Whether the child query stages (if any) of the current node have all been materialized.
* 3) A list of the new query stages that have been created.
*/
private def createQueryStages(plan: SparkPlan): CreateStageResult = plan match {
private def createQueryStagesInternal(plan: SparkPlan): CreateStageResult = plan match {
case e: Exchange =>
// First have a quick check in the `stageCache` without having to traverse down the node.
context.stageCache.get(e.canonicalized) match {
@@ -568,7 +621,7 @@
newStages = if (isMaterialized) Seq.empty else Seq(stage))

case _ =>
val result = createQueryStages(e.child)
val result = createQueryStagesInternal(e.child)
val newPlan = e.withNewChildren(Seq(result.newPlan)).asInstanceOf[Exchange]
// Create a query stage only when all the child query stages are ready.
if (result.allChildStagesMaterialized) {
@@ -612,7 +665,7 @@
if (plan.children.isEmpty) {
CreateStageResult(newPlan = plan, allChildStagesMaterialized = true, newStages = Seq.empty)
} else {
val results = plan.children.map(createQueryStages)
val results = plan.children.map(createQueryStagesInternal)
CreateStageResult(
newPlan = plan.withNewChildren(results.map(_.newPlan)),
allChildStagesMaterialized = results.forall(_.allChildStagesMaterialized),
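With this restructuring, the result computation itself is modeled as a query stage: `withFinalPlanUpdate` threads the caller's `fun` into `createQueryStages`, which wraps the fully materialized plan in a `ResultQueryStageExec` carrying `fun` as its result handler. The existing `AdaptiveSparkPlanExec` entry points already funnel through this path; illustratively (method bodies are a sketch, not part of this diff):

override def executeCollect(): Array[InternalRow] =
  withFinalPlanUpdate { finalPlan => finalPlan.executeCollect() }

override def executeTake(n: Int): Array[InternalRow] =
  withFinalPlanUpdate { finalPlan => finalPlan.executeTake(n) }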
AdaptiveSparkPlanHelper.scala
@@ -129,10 +129,12 @@ trait AdaptiveSparkPlanHelper {
}

/**
* Strip the executePlan of AdaptiveSparkPlanExec leaf node.
* Strip the top [[AdaptiveSparkPlanExec]] and [[ResultQueryStageExec]] nodes off
* the [[SparkPlan]].
*/
def stripAQEPlan(p: SparkPlan): SparkPlan = p match {
case a: AdaptiveSparkPlanExec => a.executedPlan
case a: AdaptiveSparkPlanExec => stripAQEPlan(a.executedPlan)
case ResultQueryStageExec(_, plan, _) => plan
case other => other
}
}
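The recursive call matters because after this change the executed plan can itself be a `ResultQueryStageExec` wrapping the real root, so a single-level strip would return the wrapper. A toy sketch of the unwrap pattern (the case classes are illustrative stand-ins, not Spark types):

sealed trait Plan
case class Adaptive(executed: Plan) extends Plan    // stands in for AdaptiveSparkPlanExec
case class ResultStage(inner: Plan) extends Plan    // stands in for ResultQueryStageExec
case class Scan(table: String) extends Plan

def strip(p: Plan): Plan = p match {
  case Adaptive(inner)    => strip(inner) // recurse: executed plan may be a ResultStage
  case ResultStage(inner) => inner
  case other              => other
}

assert(strip(Adaptive(ResultStage(Scan("t")))) == Scan("t"))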
QueryStageExec.scala
@@ -19,7 +19,9 @@ package org.apache.spark.sql.execution.adaptive

import java.util.concurrent.atomic.AtomicReference

import scala.concurrent.ExecutionContext
import scala.concurrent.Future
import scala.concurrent.Promise

import org.apache.spark.{MapOutputStatistics, SparkException}
import org.apache.spark.broadcast.Broadcast
@@ -32,7 +34,10 @@ import org.apache.spark.sql.columnar.CachedBatch
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.columnar.InMemoryTableScanLike
import org.apache.spark.sql.execution.exchange._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.StaticSQLConf
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.ThreadUtils

/**
* A query stage is an independent subgraph of the query plan. AQE framework will materialize its
@@ -303,3 +308,43 @@ case class TableCacheQueryStageExec(

override def getRuntimeStatistics: Statistics = inMemoryTableScan.runtimeStatistics
}

case class ResultQueryStageExec(
override val id: Int,
override val plan: SparkPlan,
resultHandler: SparkPlan => Any) extends QueryStageExec {

override def resetMetrics(): Unit = {
plan.resetMetrics()
}

override protected def doMaterialize(): Future[Any] = {
val javaFuture = SQLExecution.withThreadLocalCaptured(
session,
ResultQueryStageExec.executionContext) {
resultHandler(plan)
}
val scalaPromise: Promise[Any] = Promise()
javaFuture.whenComplete { (result: Any, exception: Throwable) =>
if (exception != null) {
scalaPromise.failure(exception match {
case completionException: java.util.concurrent.CompletionException =>
completionException.getCause
case ex => ex
})
} else {
scalaPromise.success(result)
}
}
scalaPromise.future
}

// The result stage can wrap any SparkPlan, so we don't have specific runtime statistics for it.
override def getRuntimeStatistics: Statistics = Statistics(sizeInBytes = 0, rowCount = None)
}

object ResultQueryStageExec {
private[execution] val executionContext = ExecutionContext.fromExecutorService(
ThreadUtils.newDaemonCachedThreadPool("ResultQueryStageExecution",
SQLConf.get.getConf(StaticSQLConf.RESULT_QUERY_STAGE_MAX_THREAD_THRESHOLD)))
}
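`doMaterialize` bridges the Java `CompletableFuture` returned by `withThreadLocalCaptured` into the Scala `Future` the AQE loop expects, unwrapping `CompletionException` so failures surface with their original cause. The same bridge as a self-contained sketch (the helper name `toScalaFuture` is ours, not Spark's):

import java.util.concurrent.{CompletableFuture, CompletionException}
import scala.concurrent.{Future, Promise}

def toScalaFuture[T](cf: CompletableFuture[T]): Future[T] = {
  val p = Promise[T]()
  cf.whenComplete { (result: T, error: Throwable) =>
    if (error != null) {
      // supplyAsync wraps exceptions thrown by the body in CompletionException;
      // unwrap so the Scala Future fails with the original cause.
      p.failure(error match {
        case ce: CompletionException => ce.getCause
        case other => other
      })
    } else {
      p.success(result)
    }
  }
  p.future
}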