Skip to content

Commit

Permalink
fix analyzer
Browse files Browse the repository at this point in the history
  • Loading branch information
AveryQi115 committed Jan 30, 2025
1 parent ae31b98 commit 0d40ac0
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,38 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
val outerPlan = AnalysisContext.get.outerPlan
if (outerPlan.isEmpty) return e

def resolve(nameParts: Seq[String]): Option[Expression] = try {
/**
 * Collects the plans of all subquery expressions reachable from `p`: subqueries hosted by
 * `p`'s own expressions, subqueries nested inside those subqueries, and subqueries hosted
 * by any descendant operator of `p`.
 *
 * Plans found under the children of `p` are listed before plans found in `p`'s own
 * expressions. For each subquery expression, the plans nested inside it precede its own plan.
 *
 * @param p the plan to search
 * @return the subquery plans found; empty when `p` does not contain the `PLAN_EXPRESSION`
 *         tree pattern
 */
def findNestedSubqueryPlans(p: LogicalPlan): Seq[LogicalPlan] = {
  if (!p.containsPattern(PLAN_EXPRESSION)) {
    // There are no nested subquery plans in the current plan,
    // so there is no need to search its expressions or children.
    Seq.empty
  } else {
    // `collect` visits every node of each expression tree, so the subquery wrapped by an
    // `InSubquery` is reached as a descendant node. Matching `InSubquery` separately would
    // record the same subquery plan (and everything nested in it) twice.
    val subqueriesInThisNode: Seq[SubqueryExpression] =
      p.expressions.flatMap(_.collect { case s: SubqueryExpression => s })

    val subqueryPlansFromExpressions: Seq[LogicalPlan] =
      subqueriesInThisNode.flatMap(s => findNestedSubqueryPlans(s.plan) :+ s.plan)

    val subqueryPlansFromChildren: Seq[LogicalPlan] =
      p.children.flatMap(findNestedSubqueryPlans)

    // Subquery plans in more inner positions are collected first: they are closer to the
    // outer reference and therefore more likely to still expose the original attributes.
    // As there are no conflicts, the order does not affect correctness.
    subqueryPlansFromChildren ++ subqueryPlansFromExpressions
  }
}

// `outerPlan` holds only the outermost plan, but an outer reference may also point at
// any subquery plan nested inside it. Gather every candidate plan, outermost first, so
// that each of them can be tried when resolving outer references.
val outerPlans = outerPlan.get +: findNestedSubqueryPlans(outerPlan.get)

def resolve(nameParts: Seq[String], outerPlan: Option[LogicalPlan]): Option[Expression] = try {
outerPlan.get match {
// Subqueries in UnresolvedHaving can host grouping expressions and aggregate functions.
// We should resolve columns with `agg.output` and the rule `ResolveAggregateFunctions` will
Expand All @@ -240,14 +271,21 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
None
}

e.transformWithPruning(
_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, TEMP_RESOLVED_COLUMN)) {
e.transformWithPruning(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, TEMP_RESOLVED_COLUMN)) {
case u: UnresolvedAttribute =>
resolve(u.nameParts).getOrElse(u)
val maybeResolved = outerPlans.foldLeft(Option.empty[Expression]) { (acc, plan) =>
// If we've already resolved, keep that; otherwise try this plan
acc.orElse(resolve(u.nameParts, Some(plan)))
}
maybeResolved.getOrElse(u)
// Re-resolves `TempResolvedColumn` as outer references if it has tried to be resolved with
// Aggregate but failed.
case t: TempResolvedColumn if t.hasTried =>
resolve(t.nameParts).getOrElse(t)
val maybeResolved = outerPlans.foldLeft(Option.empty[Expression]) { (acc, plan) =>
// If we've already resolved, keep that; otherwise try this plan
acc.orElse(resolve(t.nameParts, Some(plan)))
}
maybeResolved.getOrElse(t)
}
}

Expand Down
29 changes: 29 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5014,6 +5014,35 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
}
}
}

// Checks that a query whose innermost subquery references a column of an intermediate
// subquery (`s2.product_id`) can be analyzed and executed when nested correlated
// subqueries are enabled.
test("SPARK-50983: Support nested correlated subquery for analyzer") {
  // Create the tables inside `withTable` so that they are cleaned up even when a later
  // statement in this test fails.
  withTable("Sales", "Products") {
    sql("Create table Sales (product_id int, amount int)")
    sql("create table Products (product_id int, category_id int, price double)")

    val query =
      """
        | SELECT s.product_id
        | FROM Sales s
        | WHERE s.amount > (
        | SELECT AVG(s2.amount)
        | FROM Sales s2
        | WHERE s2.product_id IN (
        | SELECT p2.product_id
        | FROM Products p2
        | WHERE p2.product_id = s2.product_id
        | )
        | );
        |""".stripMargin

    withSQLConf(
      "spark.sql.optimizer.supportNestedCorrelatedSubqueries.enabled" -> "true"
    ) {
      // Both tables are empty, so the query must run to completion and return no rows.
      assert(sql(query).collect().isEmpty)
    }
  }
}
}

case class Foo(bar: Option[String])

0 comments on commit 0d40ac0

Please sign in to comment.