[WIP][SPARK-50892][SQL] Add UnionLoopExec, physical operator for recursion, to perform execution of recursive queries #49955

Open
wants to merge 3 commits into base: master
12 changes: 12 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -4421,6 +4421,12 @@
],
"sqlState" : "38000"
},
"RECURSION_LEVEL_LIMIT_EXCEEDED" : {
"message" : [
"Recursion level limit <levelLimit> reached but query has not exhausted, try increasing CTE_RECURSION_LEVEL_LIMIT"
Reviewer comment (Contributor): shall we mention the config name here? Otherwise this error message is not actionable.

],
"sqlState" : "42836"
},
"RECURSIVE_CTE_IN_LEGACY_MODE" : {
"message" : [
"Recursive definitions cannot be used in legacy CTE precedence mode (spark.sql.legacy.ctePrecedencePolicy=LEGACY)."
@@ -5180,6 +5186,12 @@
],
"sqlState" : "42846"
},
"UNION_NOT_SUPPORTED_IN_RECURSIVE_CTE" : {
"message" : [
"The UNION operator is not yet supported within recursive common table expressions (WITH clauses that refer to themselves, directly or indirectly). Please use UNION ALL instead."
],
"sqlState" : "42836"
},
"UNKNOWN_PRIMITIVE_TYPE_IN_VARIANT" : {
"message" : [
"Unknown primitive type with id <id> was found in a variant value."
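For illustration, a minimal sketch of a query that is expected to hit the new UNION_NOT_SUPPORTED_IN_RECURSIVE_CTE error class added above. It assumes a running SparkSession named `spark` on a build containing this PR; it is not taken from the PR's tests.

// A recursive CTE that deduplicates across iterations with UNION (rather than
// UNION ALL) should be rejected during analysis with
// UNION_NOT_SUPPORTED_IN_RECURSIVE_CTE; rewriting it with UNION ALL is the workaround.
spark.sql(
  """WITH RECURSIVE t(n) AS (
    |  SELECT 1
    |  UNION
    |  SELECT n + 1 FROM t WHERE n < 5)
    |SELECT * FROM t""".stripMargin).show()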
@@ -105,6 +105,9 @@ object ResolveWithCTE extends Rule[LogicalPlan] {
// and we exclude those rows from the current iteration result.
case alias @ SubqueryAlias(_,
Distinct(Union(Seq(anchor, recursion), false, false))) =>
cteDef.failAnalysis(
errorClass = "UNION_NOT_SUPPORTED_IN_RECURSIVE_CTE",
messageParameters = Map.empty)
if (!anchor.resolved) {
cteDef
} else {
Expand All @@ -126,6 +129,9 @@ object ResolveWithCTE extends Rule[LogicalPlan] {
colNames,
Distinct(Union(Seq(anchor, recursion), false, false))
)) =>
cteDef.failAnalysis(
errorClass = "UNION_NOT_SUPPORTED_IN_RECURSIVE_CTE",
messageParameters = Map.empty)
if (!anchor.resolved) {
cteDef
} else {
@@ -848,6 +848,15 @@ object LimitPushDown extends Rule[LogicalPlan] {
case LocalLimit(exp, u: Union) =>
LocalLimit(exp, u.copy(children = u.children.map(maybePushLocalLimit(exp, _))))

// If a limit node is present, we should propagate it down to UnionLoop, so that it is later
// propagated to UnionLoopExec.
// A Limit node is constructed by placing a GlobalLimit over a LocalLimit (see the Limit apply
// method), which is why we match it this way.
case g @ GlobalLimit(IntegerLiteral(limit), l @ LocalLimit(_, p @ Project(_, ul: UnionLoop))) =>
Reviewer comment (Contributor): nit: case Limit(...) matches a GlobalLimit wrapping a LocalLimit with the same value, we should use it instead.

g.copy(child = l.copy(child = p.copy(child = ul.copy(limit = Some(limit)))))
case g @ GlobalLimit(IntegerLiteral(limit), l @ LocalLimit(_, u: UnionLoop)) =>
g.copy(child = l.copy(child = u.copy(limit = Some(limit))))
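For illustration, a hedged sketch of the reviewer's suggestion above: the existing Limit extractor already matches a GlobalLimit that directly wraps a LocalLimit with the same limit expression, so the two cases could be written against it. The object name is made up and it is shown as a standalone rule rather than as the actual LimitPushDown cases; Limit.apply rebuilds the GlobalLimit/LocalLimit pair that the current cases preserve by hand.

import org.apache.spark.sql.catalyst.expressions.IntegerLiteral
import org.apache.spark.sql.catalyst.plans.logical.{Limit, LogicalPlan, Project, UnionLoop}
import org.apache.spark.sql.catalyst.rules.Rule

// Hypothetical standalone variant of the two cases above, using the Limit extractor.
object PushLimitIntoUnionLoopSketch extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan.transform {
    // Limit(expr, child) matches GlobalLimit(expr, LocalLimit(expr, child)) and
    // Limit(expr, newChild) rebuilds the same two-node structure.
    case Limit(le @ IntegerLiteral(limit), p @ Project(_, ul: UnionLoop)) =>
      Limit(le, p.copy(child = ul.copy(limit = Some(limit))))
    case Limit(le @ IntegerLiteral(limit), ul: UnionLoop) =>
      Limit(le, ul.copy(limit = Some(limit)))
  }
}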

// Add extra limits below JOIN:
// 1. For LEFT OUTER and RIGHT OUTER JOIN, we push limits to the left and right sides
// respectively if join condition is not empty.
@@ -1031,6 +1040,9 @@ object ColumnPruning extends Rule[LogicalPlan] {
} else {
p
}
// TODO: Pruning `UnionLoop`s needs to take into account both the outer `Project` and the inner
// `UnionLoopRef` nodes.
Reviewer comment (Contributor): This might hurt performance a lot. Let's figure it out.

case p @ Project(_, _: UnionLoop) => p

// Prune unnecessary window expressions
case p @ Project(_, w: Window) if !w.windowOutputSet.subsetOf(p.references) =>
@@ -34,7 +34,9 @@ import org.apache.spark.sql.internal.SQLConf
* @param id The id of the loop, inherited from [[CTERelationDef]] within which the Union lived.
* @param anchor The plan of the initial element of the loop.
* @param recursion The plan that describes the recursion with an [[UnionLoopRef]] node.
* @param limit An optional limit that can be pushed down to the node to stop the loop earlier.
Reviewer comment (Contributor): I think the previous comment is good enough to describe this parameter.

* @param limit If the plan contains a Limit node, we first push it down to the
* UnionLoop, and it is then transferred to UnionLoopExec to stop the recursion
* after a specific number of rows has been generated.
*/
case class UnionLoop(
id: Long,
@@ -4516,6 +4516,14 @@ object SQLConf {
.checkValues(LegacyBehaviorPolicy.values.map(_.toString))
.createWithDefault(LegacyBehaviorPolicy.CORRECTED.toString)

val CTE_RECURSION_LEVEL_LIMIT = buildConf("spark.sql.cteRecursionLevelLimit")
.doc("Maximum level of recursion that is allowed wile executing a recursive CTE definition." +
"If a query does not get exhausted before reaching this limit it fails. Use -1 for " +
"unlimited.")
.version("4.0.0")
.intConf
.createWithDefault(100)
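For illustration, a small usage sketch of this config; it assumes a running SparkSession named `spark` on a build containing this PR, and the chosen values are arbitrary.

// Raise the recursion depth limit before running a deep recursive CTE.
// Setting it to -1 would disable the check entirely; the default above is 100.
spark.conf.set("spark.sql.cteRecursionLevelLimit", "1000")
spark.sql(
  """WITH RECURSIVE t(n) AS (
    |  SELECT 1
    |  UNION ALL
    |  SELECT n + 1 FROM t WHERE n < 500)
    |SELECT count(*) FROM t""".stripMargin).show()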

val LEGACY_INLINE_CTE_IN_COMMANDS = buildConf("spark.sql.legacy.inlineCTEInCommands")
.internal()
.doc("If true, always inline the CTE relations for the queries in commands. This is the " +
@@ -1031,6 +1031,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
GlobalLimitExec(child = planLater(child), offset = offset) :: Nil
case union: logical.Union =>
execution.UnionExec(union.children.map(planLater)) :: Nil
case u @ logical.UnionLoop(id, anchor, recursion, limit) =>
execution.UnionLoopExec(id, anchor, recursion, u.output, limit) :: Nil
case g @ logical.Generate(generator, _, outer, _, _, child) =>
execution.GenerateExec(
generator, g.requiredChildOutput, outer,
@@ -30,7 +30,10 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{Limit, LogicalPlan, Project, Union, UnionLoopRef}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.classic.Dataset
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
import org.apache.spark.sql.types.{LongType, StructType}
@@ -714,6 +717,177 @@ case class UnionExec(children: Seq[SparkPlan]) extends SparkPlan {
copy(children = newChildren)
}

/**
* The physical node for recursion. Currently only the UNION ALL case is supported.
* For details about the execution, see the comment above the doExecute function.
*
* A simple recursive query:
* {{{
* WITH RECURSIVE t(n) AS (
* SELECT 1
* UNION ALL
* SELECT n+1 FROM t WHERE n < 5)
* SELECT * FROM t;
* }}}
* Corresponding logical plan for the recursive query above:
* {{{
* WithCTE
* :- CTERelationDef 0, false
* : +- SubqueryAlias t
* : +- Project [1#0 AS n#3]
* : +- UnionLoop 0
* : :- Project [1 AS 1#0]
* : : +- OneRowRelation
* : +- Project [(n#1 + 1) AS (n + 1)#2]
* : +- Filter (n#1 < 5)
* : +- SubqueryAlias t
* : +- Project [1#0 AS n#1]
* : +- UnionLoopRef 0, [1#0], false
* +- Project [n#3]
* +- SubqueryAlias t
* +- CTERelationRef 0, true, [n#3], false, false
* }}}
*
* @param loopId This is the id of the CTERelationDef containing the recursive query. Its value is
* first passed down to UnionLoop when creating it, and then to UnionLoopExec in
* SparkStrategies.
* @param anchor The logical plan of the initial element of the loop.
* @param recursion The logical plan that describes the recursion with a [[UnionLoopRef]] node.
* The CTERelationRef, which is marked as recursive, gets substituted with
* [[UnionLoopRef]] in ResolveWithCTE.
* Both anchor and recursion are marked with the @transient annotation so that they
* are not serialized.
* @param output The output attributes of this loop.
* @param limit If defined, the total number of rows output by this operator will be bounded by
* limit.
* Its value is pushed down to UnionLoop in the Optimizer when a Limit node is present
* in the logical plan, and is then transferred to UnionLoopExec in SparkStrategies.
* Note that the limit can only be applied in the main query calling the recursive CTE,
* not inside the recursive term of the CTE.
*/
case class UnionLoopExec(
Reviewer comment (Contributor): nit: let's move it to a new file.

loopId: Long,
@transient anchor: LogicalPlan,
@transient recursion: LogicalPlan,
override val output: Seq[Attribute],
limit: Option[Int] = None) extends LeafExecNode {

override def innerChildren: Seq[QueryPlan[_]] = Seq(anchor, recursion)
Reviewer comment (Contributor): why do they have to be inner children?


override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
"numRecursiveLoops" -> SQLMetrics.createMetric(sparkContext, "number of recursive loops"))

/**
* This function executes the plan (optionally with an appended limit node) and caches the result,
* with the caching mode specified in the config.
*/
private def executeAndCacheAndCount(
plan: LogicalPlan, currentLimit: Int) = {
// In case limit is defined, we create a (global) limit node above the plan and execute
// the newly created plan.
// Note: a global limit requires coordination (a shuffle) between partitions.
Reviewer comment (Contributor): Then it's better to use local limit? It's just a best effort to reduce the generated records.

val planOrLimitedPlan = if (limit.isDefined) {
Limit(Literal(currentLimit), plan)
} else {
plan
}
val df = Dataset.ofRows(session, planOrLimitedPlan)
val cachedDF = df.repartition()
val count = cachedDF.count()
(cachedDF, count)
}

/**
* In the first iteration, the anchor term is executed.
* Then, in each following iteration, the UnionLoopRef node is substituted with the plan from the
* previous iteration, and the resulting plan is executed.
* After every iteration, the dataframe is repartitioned.
* The recursion stops when the generated dataframe is empty, or when either the limit or
* the maximum recursion depth specified in the config is reached.
*/
override protected def doExecute(): RDD[InternalRow] = {
val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
val numOutputRows = longMetric("numOutputRows")
val numRecursiveLoops = longMetric("numRecursiveLoops")
val levelLimit = conf.getConf(SQLConf.CTE_RECURSION_LEVEL_LIMIT)

// currentLimit is initialized from the limit argument, and in each step it is decreased by
// the number of rows generated in that step.
// If limit is not passed down, currentLimit is set to zero and won't be considered in the
// condition of the while loop below (limit.isEmpty will be true).
var currentLimit = limit.getOrElse(0)
val unionChildren = mutable.ArrayBuffer.empty[LogicalRDD]

var (prevDF, prevCount) = executeAndCacheAndCount(anchor, currentLimit)

var currentLevel = 1

// Main loop for obtaining the result of the recursive query.
while (prevCount > 0 && (limit.isEmpty || currentLimit > 0)) {
Reviewer comment (Contributor): one idea: the key here is to get the row count of the current iteration, so that we can decide if we should keep iterating or not. The shuffle is only to save recomputing of the query. But for very simple queries (e.g. local scan with simple filter/project), shuffle is probably more expensive than recomputing. We should detect such cases and avoid the shuffle.

if (levelLimit != -1 && currentLevel > levelLimit) {
throw new SparkException(
errorClass = "RECURSION_LEVEL_LIMIT_EXCEEDED",
messageParameters = Map("levelLimit" -> levelLimit.toString),
cause = null)
}

// Inherit stats and constraints from the dataset of the previous iteration.
val prevPlan = LogicalRDD.fromDataset(prevDF.queryExecution.toRdd, prevDF, prevDF.isStreaming)
.newInstance()
unionChildren += prevPlan

// Update metrics
numOutputRows += prevCount
numRecursiveLoops += 1
SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)

// The current plan is created by substituting the UnionLoopRef node with a Project over
// the plan from the previous iteration.
// This way we support only the UNION ALL case. An additional case should be added for UNION.
// One way of supporting the UNION case can be seen in the SPARK-24497 PR from Peter Toth.
val newRecursion = recursion.transform {
case r: UnionLoopRef =>
val prevPlanToRefMapping = prevPlan.output.zip(r.output).map {
case (fa, ta) => Alias(fa, ta.name)(ta.exprId)
}
Project(prevPlanToRefMapping, prevPlan)
}

val (df, count) = executeAndCacheAndCount(newRecursion, currentLimit)
prevDF = df
prevCount = count

currentLevel += 1
if (limit.isDefined) {
currentLimit -= count.toInt
}
}

if (unionChildren.isEmpty) {
new EmptyRDD[InternalRow](sparkContext)
} else if (unionChildren.length == 1) {
Dataset.ofRows(session, unionChildren.head).queryExecution.toRdd
} else {
Dataset.ofRows(session, Union(unionChildren.toSeq)).queryExecution.toRdd
}
}

override def doCanonicalize(): SparkPlan =
super.doCanonicalize().asInstanceOf[UnionLoopExec]
.copy(anchor = anchor.canonicalized, recursion = recursion.canonicalized)

override def verboseStringWithOperatorId(): String = {
s"""
|$formattedNodeName
|Loop id: $loopId
|${QueryPlan.generateFieldString("Output", output)}
|Limit: $limit
|""".stripMargin
}
}
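To make the driver-side control flow of doExecute easier to follow, here is a minimal, self-contained sketch in plain Scala (no Spark). The names, the collection-based step function, and the IllegalStateException are illustrative stand-ins rather than code from the PR; caching and repartitioning of intermediate results are deliberately left out.

import scala.collection.mutable

object UnionLoopSketch {
  // Mirror of the loop above: run the anchor once, then keep feeding the previous
  // iteration's rows into the recursive step until it produces nothing, the optional
  // row limit is reached, or the level limit is exceeded (which raises an error).
  def run(
      anchor: Seq[Int],
      step: Seq[Int] => Seq[Int],
      limit: Option[Int],
      levelLimit: Int = 100): Seq[Int] = {
    val out = mutable.ArrayBuffer.empty[Int]
    var current = anchor
    var level = 1
    while (current.nonEmpty && limit.forall(n => out.size < n)) {
      if (levelLimit != -1 && level > levelLimit) {
        throw new IllegalStateException(s"Recursion level limit $levelLimit exceeded")
      }
      out ++= current          // corresponds to numOutputRows += prevCount
      current = step(current)  // substitute the previous result into the recursion
      level += 1
    }
    // The real operator bounds its output via a pushed-down Limit node; here we just truncate.
    limit.map(n => out.take(n).toSeq).getOrElse(out.toSeq)
  }

  def main(args: Array[String]): Unit = {
    // WITH RECURSIVE t(n) AS (SELECT 1 UNION ALL SELECT n + 1 FROM t WHERE n < 5) SELECT * FROM t
    println(run(Seq(1), prev => prev.filter(_ < 5).map(_ + 1), limit = None))  // 1, 2, 3, 4, 5
  }
}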

/**
* Physical plan for returning a new RDD that has exactly `numPartitions` partitions.
* Similar to coalesce defined on an [[RDD]], this operation results in a narrow dependency, e.g.