-
Notifications
You must be signed in to change notification settings - Fork 28.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP][SPARK-50892][SQL] Add UnionLoopExec, physical operator for recursion, to perform execution of recursive queries #49955
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -848,6 +848,15 @@ object LimitPushDown extends Rule[LogicalPlan] { | |
case LocalLimit(exp, u: Union) => | ||
LocalLimit(exp, u.copy(children = u.children.map(maybePushLocalLimit(exp, _)))) | ||
|
||
// If limit node is present, we should propagate it down to UnionLoop, so that it is later | ||
// propagated to UnionLoopExec. | ||
// Limit node is constructed by placing GlobalLimit over LocalLimit (look at Limit apply method) | ||
// that is the reason why we match it this way. | ||
case g @ GlobalLimit(IntegerLiteral(limit), l @ LocalLimit(_, p @ Project(_, ul: UnionLoop))) => | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: |
||
g.copy(child = l.copy(child = p.copy(child = ul.copy(limit = Some(limit))))) | ||
case g @ GlobalLimit(IntegerLiteral(limit), l @ LocalLimit(_, u: UnionLoop)) => | ||
g.copy(child = l.copy(child = u.copy(limit = Some(limit)))) | ||
|
||
// Add extra limits below JOIN: | ||
// 1. For LEFT OUTER and RIGHT OUTER JOIN, we push limits to the left and right sides | ||
// respectively if join condition is not empty. | ||
|
@@ -1031,6 +1040,9 @@ object ColumnPruning extends Rule[LogicalPlan] { | |
} else { | ||
p | ||
} | ||
// TODO: Pruning `UnionLoop`s needs to take into account both the outer `Project` and the inner | ||
// `UnionLoopRef` nodes. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This might hurt performance a lot. Let's figure it out |
||
case p @ Project(_, _: UnionLoop) => p | ||
|
||
// Prune unnecessary window expressions | ||
case p @ Project(_, w: Window) if !w.windowOutputSet.subsetOf(p.references) => | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,7 +34,9 @@ import org.apache.spark.sql.internal.SQLConf | |
* @param id The id of the loop, inherited from [[CTERelationDef]] within which the Union lived. | ||
* @param anchor The plan of the initial element of the loop. | ||
* @param recursion The plan that describes the recursion with an [[UnionLoopRef]] node. | ||
* @param limit An optional limit that can be pushed down to the node to stop the loop earlier. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the previous comment is good enough to describe this parameter. |
||
* @param limit In case we have a plan with the limit node, we can push it down first to the | ||
* UnionLoop, which will then be transferred to UnionLoopExec to stop the recursion | ||
* after a specific number of rows has been generated. | ||
*/ | ||
case class UnionLoop( | ||
id: Long, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,7 +30,10 @@ import org.apache.spark.sql.catalyst.InternalRow | |
import org.apache.spark.sql.catalyst.expressions._ | ||
import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences | ||
import org.apache.spark.sql.catalyst.expressions.codegen._ | ||
import org.apache.spark.sql.catalyst.plans.QueryPlan | ||
import org.apache.spark.sql.catalyst.plans.logical.{Limit, LogicalPlan, Project, Union, UnionLoopRef} | ||
import org.apache.spark.sql.catalyst.plans.physical._ | ||
import org.apache.spark.sql.classic.Dataset | ||
import org.apache.spark.sql.execution.metric.SQLMetrics | ||
import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} | ||
import org.apache.spark.sql.types.{LongType, StructType} | ||
|
@@ -714,6 +717,177 @@ case class UnionExec(children: Seq[SparkPlan]) extends SparkPlan { | |
copy(children = newChildren) | ||
} | ||
|
||
/** | ||
* The physical node for recursion. Currently only UNION ALL case is supported. | ||
* For the details about the execution, look at the comment above doExecute function. | ||
* | ||
* A simple recursive query: | ||
* {{{ | ||
* WITH RECURSIVE t(n) AS ( | ||
* SELECT 1 | ||
* UNION ALL | ||
* SELECT n+1 FROM t WHERE n < 5) | ||
* SELECT * FROM t; | ||
* }}} | ||
* Corresponding logical plan for the recursive query above: | ||
* {{{ | ||
* WithCTE | ||
* :- CTERelationDef 0, false | ||
* : +- SubqueryAlias t | ||
* : +- Project [1#0 AS n#3] | ||
* : +- UnionLoop 0 | ||
* : :- Project [1 AS 1#0] | ||
* : : +- OneRowRelation | ||
* : +- Project [(n#1 + 1) AS (n + 1)#2] | ||
* : +- Filter (n#1 < 5) | ||
* : +- SubqueryAlias t | ||
* : +- Project [1#0 AS n#1] | ||
* : +- UnionLoopRef 0, [1#0], false | ||
* +- Project [n#3] | ||
* +- SubqueryAlias t | ||
* +- CTERelationRef 0, true, [n#3], false, false | ||
* }}} | ||
* | ||
* @param loopId This is id of the CTERelationDef containing the recursive query. Its value is | ||
* first passed down to UnionLoop when creating it, and then to UnionLoopExec in | ||
* SparkStrategies. | ||
* @param anchor The logical plan of the initial element of the loop. | ||
* @param recursion The logical plan that describes the recursion with an [[UnionLoopRef]] node. | ||
* CTERelationRef, which is marked as recursive, gets substituted with | ||
* [[UnionLoopRef]] in ResolveWithCTE. | ||
* Both anchor and recursion are marked with @transient annotation, so that they | ||
* are not serialized. | ||
* @param output The output attributes of this loop. | ||
* @param limit If defined, the total number of rows output by this operator will be bounded by | ||
* limit. | ||
* Its value is pushed down to UnionLoop in Optimizer in case Limit node is present | ||
* in the logical plan and then transferred to UnionLoopExec in SparkStrategies. | ||
* Note here: limit can be applied in the main query calling the recursive CTE, and not | ||
* inside the recursive term of recursive CTE. | ||
*/ | ||
case class UnionLoopExec( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: let's move it to a new file |
||
loopId: Long, | ||
@transient anchor: LogicalPlan, | ||
@transient recursion: LogicalPlan, | ||
override val output: Seq[Attribute], | ||
limit: Option[Int] = None) extends LeafExecNode { | ||
|
||
override def innerChildren: Seq[QueryPlan[_]] = Seq(anchor, recursion) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why do they have to be inner children? |
||
|
||
override lazy val metrics = Map( | ||
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), | ||
"numRecursiveLoops" -> SQLMetrics.createMetric(sparkContext, "number of recursive loops")) | ||
|
||
/** | ||
* This function executes the plan (optionally with appended limit node) and caches the result, | ||
* with the caching mode specified in config. | ||
*/ | ||
private def executeAndCacheAndCount( | ||
plan: LogicalPlan, currentLimit: Int) = { | ||
// In case limit is defined, we create a (global) limit node above the plan and execute | ||
// the newly created plan. | ||
// Note here: global limit requires coordination (shuffle) between partitions. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Then it's better to use local limit? It's just a best effort to reduce the generated records. |
||
val planOrLimitedPlan = if (limit.isDefined) { | ||
Limit(Literal(currentLimit), plan) | ||
} else { | ||
plan | ||
} | ||
val df = Dataset.ofRows(session, planOrLimitedPlan) | ||
val cachedDF = df.repartition() | ||
val count = cachedDF.count() | ||
(cachedDF, count) | ||
} | ||
|
||
/** | ||
* In the first iteration, anchor term is executed. | ||
* Then, in each following iteration, the UnionLoopRef node is substituted with the plan from the | ||
* previous iteration, and such plan is executed. | ||
* After every iteration, the dataframe is repartitioned. | ||
* The recursion stops when the generated dataframe is empty, or when either the limit or | ||
* the specified maximum depth from the config is reached. | ||
*/ | ||
override protected def doExecute(): RDD[InternalRow] = { | ||
val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) | ||
val numOutputRows = longMetric("numOutputRows") | ||
val numRecursiveLoops = longMetric("numRecursiveLoops") | ||
val levelLimit = conf.getConf(SQLConf.CTE_RECURSION_LEVEL_LIMIT) | ||
|
||
// currentLimit is initialized from the limit argument, and in each step it is decreased by | ||
// the number of rows generated in that step. | ||
// If limit is not passed down, currentLimit is set to zero and won't be considered in the | ||
// condition of the while loop below (limit.isEmpty will be true). | ||
var currentLimit = limit.getOrElse(0) | ||
val unionChildren = mutable.ArrayBuffer.empty[LogicalRDD] | ||
|
||
var (prevDF, prevCount) = executeAndCacheAndCount(anchor, currentLimit) | ||
|
||
var currentLevel = 1 | ||
|
||
// Main loop for obtaining the result of the recursive query. | ||
while (prevCount > 0 && (limit.isEmpty || currentLimit > 0)) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. one idea: the key here is to get the row count of the current iteration, so that we can decide if we should keep iterating or not. The shuffle is only to save recomputing of the query. But for very simple queries (e.g. local scan with simple filter/project), shuffle is probably more expensive than recomputing. We should detect such case and avoid shuffle. |
||
|
||
if (levelLimit != -1 && currentLevel > levelLimit) { | ||
throw new SparkException( | ||
errorClass = "RECURSION_LEVEL_LIMIT_EXCEEDED", | ||
messageParameters = Map("levelLimit" -> levelLimit.toString), | ||
cause = null) | ||
} | ||
|
||
// Inherit stats and constraints from the dataset of the previous iteration. | ||
val prevPlan = LogicalRDD.fromDataset(prevDF.queryExecution.toRdd, prevDF, prevDF.isStreaming) | ||
.newInstance() | ||
unionChildren += prevPlan | ||
|
||
// Update metrics | ||
numOutputRows += prevCount | ||
numRecursiveLoops += 1 | ||
SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq) | ||
|
||
// the current plan is created by substituting UnionLoopRef node with the project node of | ||
// the previous plan. | ||
// This way we support only the UNION ALL case. An additional case should be added for the UNION case. | ||
// One way of supporting UNION case can be seen at SPARK-24497 PR from Peter Toth. | ||
val newRecursion = recursion.transform { | ||
case r: UnionLoopRef => | ||
val prevPlanToRefMapping = prevPlan.output.zip(r.output).map { | ||
case (fa, ta) => Alias(fa, ta.name)(ta.exprId) | ||
} | ||
Project(prevPlanToRefMapping, prevPlan) | ||
} | ||
|
||
val (df, count) = executeAndCacheAndCount(newRecursion, currentLimit) | ||
prevDF = df | ||
prevCount = count | ||
|
||
currentLevel += 1 | ||
if (limit.isDefined) { | ||
currentLimit -= count.toInt | ||
} | ||
} | ||
|
||
if (unionChildren.isEmpty) { | ||
new EmptyRDD[InternalRow](sparkContext) | ||
} else if (unionChildren.length == 1) { | ||
Dataset.ofRows(session, unionChildren.head).queryExecution.toRdd | ||
} else { | ||
Dataset.ofRows(session, Union(unionChildren.toSeq)).queryExecution.toRdd | ||
} | ||
} | ||
|
||
override def doCanonicalize(): SparkPlan = | ||
super.doCanonicalize().asInstanceOf[UnionLoopExec] | ||
.copy(anchor = anchor.canonicalized, recursion = recursion.canonicalized) | ||
|
||
override def verboseStringWithOperatorId(): String = { | ||
s""" | ||
|$formattedNodeName | ||
|Loop id: $loopId | ||
|${QueryPlan.generateFieldString("Output", output)} | ||
|Limit: $limit | ||
|""".stripMargin | ||
} | ||
} | ||
|
||
/** | ||
* Physical plan for returning a new RDD that has exactly `numPartitions` partitions. | ||
* Similar to coalesce defined on an [[RDD]], this operation results in a narrow dependency, e.g. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
shall we mention the config name here? Otherwise this error message is not actionable.