diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
index 56b2103c555db..44ac6b5269480 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
@@ -222,7 +222,42 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
     val outerPlan = AnalysisContext.get.outerPlan
     if (outerPlan.isEmpty) return e
 
-    def resolve(nameParts: Seq[String]): Option[Expression] = try {
+    // Collects the plans of all subquery expressions nested, at any depth, inside `p`.
+    def findNestedSubqueryPlans(p: LogicalPlan): Seq[LogicalPlan] = {
+      if (!p.containsPattern(PLAN_EXPRESSION)) {
+        // There are no nested subquery plans in the current plan,
+        // so stop searching its children.
+        return Seq.empty
+      }
+
+      // NOTE(review): if `collect` also descends into `InSubquery.query` (i.e. the query is a
+      // child of `InSubquery`), it already matches the second case below and each IN
+      // subquery plan is collected twice -- TODO confirm and drop the first case if so.
+      val subqueriesInThisNode: Seq[SubqueryExpression] =
+        p.expressions.flatMap(_.collect {
+          case in: InSubquery => in.query
+          case s: SubqueryExpression => s
+        })
+
+      val subqueryPlansFromExpressions: Seq[LogicalPlan] =
+        subqueriesInThisNode.flatMap(s => findNestedSubqueryPlans(s.plan) :+ s.plan)
+
+      val subqueryPlansFromChildren: Seq[LogicalPlan] =
+        p.children.flatMap(findNestedSubqueryPlans)
+
+      // Subquery plans in more inner positions are collected first:
+      // being closer to the outer reference, they are more likely to have
+      // the original attributes.
+      // That said, as there are no conflicts, the order does not affect correctness.
+      subqueryPlansFromChildren ++ subqueryPlansFromExpressions
+    }
+
+    // The passed in `outerPlan` is the outermost plan.
+    // Outer references can come from `outerPlan` or from any of its nested subquery plans,
+    // so we need to try resolving them against all of these plans.
+    val outerPlans = outerPlan.get +: findNestedSubqueryPlans(outerPlan.get)
+
+    def resolve(nameParts: Seq[String], outerPlan: Option[LogicalPlan]): Option[Expression] = try {
       outerPlan.get match {
         // Subqueries in UnresolvedHaving can host grouping expressions and aggregate functions.
         // We should resolve columns with `agg.output` and the rule `ResolveAggregateFunctions` will
@@ -240,14 +275,19 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
         None
     }
 
-    e.transformWithPruning(
-      _.containsAnyPattern(UNRESOLVED_ATTRIBUTE, TEMP_RESOLVED_COLUMN)) {
+    // Tries each candidate outer plan in order and keeps the first successful resolution.
+    def resolveWithOuterPlans(nameParts: Seq[String]): Option[Expression] =
+      outerPlans.iterator.map(plan => resolve(nameParts, Some(plan))).collectFirst {
+        case Some(resolved) => resolved
+      }
+
+    e.transformWithPruning(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, TEMP_RESOLVED_COLUMN)) {
       case u: UnresolvedAttribute =>
-        resolve(u.nameParts).getOrElse(u)
+        resolveWithOuterPlans(u.nameParts).getOrElse(u)
       // Re-resolves `TempResolvedColumn` as outer references if it has tried to be resolved with
       // Aggregate but failed.
       case t: TempResolvedColumn if t.hasTried =>
-        resolve(t.nameParts).getOrElse(t)
+        resolveWithOuterPlans(t.nameParts).getOrElse(t)
     }
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index c92b38d73edc7..98b9da118b6a8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -5014,6 +5014,34 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
       }
     }
   }
+
+  test("SPARK-50983: Support nested correlated subquery for analyzer") {
+    val query =
+      """
+        | SELECT s.product_id
+        | FROM Sales s
+        | WHERE s.amount > (
+        |   SELECT AVG(s2.amount)
+        |   FROM Sales s2
+        |   WHERE s2.product_id IN (
+        |     SELECT p2.product_id
+        |     FROM Products p2
+        |     WHERE p2.product_id = s2.product_id
+        |   )
+        | );
+        |""".stripMargin
+
+    // Create the tables inside `withTable` so they are dropped even if setup fails.
+    withTable("Sales", "Products") {
+      sql("Create table Sales (product_id int, amount int)")
+      sql("create table Products (product_id int, category_id int, price double)")
+
+      withSQLConf("spark.sql.optimizer.supportNestedCorrelatedSubqueries.enabled" -> "true") {
+        // Both tables are empty, so the query must analyze, execute and return no rows.
+        assert(sql(query).collect().isEmpty)
+      }
+    }
+  }
 }
 
 case class Foo(bar: Option[String])