
Commit 0d40ac0

fix analyzer
1 parent ae31b98 commit 0d40ac0

File tree

2 files changed: +72 additions, -5 deletions

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala

Lines changed: 43 additions & 5 deletions
@@ -222,7 +222,38 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
     val outerPlan = AnalysisContext.get.outerPlan
     if (outerPlan.isEmpty) return e

-    def resolve(nameParts: Seq[String]): Option[Expression] = try {
+    def findNestedSubqueryPlans(p: LogicalPlan): Seq[LogicalPlan] = {
+      if (!p.containsPattern(PLAN_EXPRESSION)) {
+        // There are no nested subquery plans in the current plan,
+        // so stop searching its children.
+        return Seq.empty
+      }
+
+      val subqueriesInThisNode: Seq[SubqueryExpression] =
+        p.expressions.flatMap(_.collect {
+          case in: InSubquery => in.query
+          case s: SubqueryExpression => s
+        })
+
+      val subqueryPlansFromExpressions: Seq[LogicalPlan] =
+        subqueriesInThisNode.flatMap(s => findNestedSubqueryPlans(s.plan) :+ s.plan)
+
+      val subqueryPlansFromChildren: Seq[LogicalPlan] =
+        p.children.flatMap(findNestedSubqueryPlans)
+
+      // Subquery plans in more deeply nested positions are collected first.
+      // Being closer to the position of the outer reference, they are more likely
+      // to carry the original attributes.
+      // Since there are no conflicts, though, the order does not affect correctness.
+      subqueryPlansFromChildren ++ subqueryPlansFromExpressions
+    }
+
+    // The passed-in `outerPlan` is the outermost plan.
+    // Outer references can come from the `outerPlan` or any of its nested subquery plans,
+    // so we need to try resolving the outer references against all of these plans.
+    val outerPlans = Seq(outerPlan.get) ++ findNestedSubqueryPlans(outerPlan.get)
+
+    def resolve(nameParts: Seq[String], outerPlan: Option[LogicalPlan]): Option[Expression] = try {
       outerPlan.get match {
         // Subqueries in UnresolvedHaving can host grouping expressions and aggregate functions.
         // We should resolve columns with `agg.output` and the rule `ResolveAggregateFunctions` will
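The helper above collects nested subquery plans depth-first: plans reached through a node's children are emitted before the plans held in the node's own expressions, so more deeply nested plans come out first. A minimal standalone sketch of that ordering on a toy tree, not Catalyst's LogicalPlan (all names below are hypothetical):

```scala
// Toy stand-in for a plan node: it can hold nested "subquery" trees in its
// expressions and has ordinary child nodes.
object CollectNestedSketch {
  final case class Node(name: String, subqueries: Seq[Node] = Nil, children: Seq[Node] = Nil)

  // Children first, then this node's own subqueries, mirroring the inner-first
  // ordering of findNestedSubqueryPlans above.
  def collectNested(n: Node): Seq[Node] = {
    val fromSubqueries = n.subqueries.flatMap(s => collectNested(s) :+ s)
    val fromChildren = n.children.flatMap(collectNested)
    fromChildren ++ fromSubqueries
  }

  def main(args: Array[String]): Unit = {
    val inner = Node("innerSubquery")
    val middle = Node("middleSubquery", subqueries = Seq(inner))
    val root = Node("outerPlan", subqueries = Seq(middle))
    // Prints List(innerSubquery, middleSubquery): inner plans before the plans that contain them.
    println(collectNested(root).map(_.name))
  }
}
```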
@@ -240,14 +271,21 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
         None
     }

-    e.transformWithPruning(
-      _.containsAnyPattern(UNRESOLVED_ATTRIBUTE, TEMP_RESOLVED_COLUMN)) {
+    e.transformWithPruning(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, TEMP_RESOLVED_COLUMN)) {
       case u: UnresolvedAttribute =>
-        resolve(u.nameParts).getOrElse(u)
+        val maybeResolved = outerPlans.foldLeft(Option.empty[Expression]) { (acc, plan) =>
+          // If we've already resolved, keep that; otherwise try this plan.
+          acc.orElse(resolve(u.nameParts, Some(plan)))
+        }
+        maybeResolved.getOrElse(u)
       // Re-resolves `TempResolvedColumn` as outer references if it has tried to be resolved with
       // Aggregate but failed.
       case t: TempResolvedColumn if t.hasTried =>
-        resolve(t.nameParts).getOrElse(t)
+        val maybeResolved = outerPlans.foldLeft(Option.empty[Expression]) { (acc, plan) =>
+          // If we've already resolved, keep that; otherwise try this plan.
+          acc.orElse(resolve(t.nameParts, Some(plan)))
+        }
+        maybeResolved.getOrElse(t)
     }
   }

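With the candidate plans in hand, the two resolution cases above simply try each plan in order and keep the first successful hit. A minimal standalone sketch of that foldLeft/orElse pattern outside of Catalyst (the maps standing in for plans and the resolveIn helper are made up for illustration):

```scala
object FirstHitSketch {
  // Hypothetical per-plan resolver: succeeds only if the plan "knows" the name.
  def resolveIn(plan: Map[String, Int], name: String): Option[Int] = plan.get(name)

  def main(args: Array[String]): Unit = {
    val outer = Map("a" -> 1)
    val subquery1 = Map("b" -> 2)
    val subquery2 = Map("b" -> 3, "c" -> 4)
    val plans = Seq(outer, subquery1, subquery2)

    // Once acc is a Some, orElse keeps it; since its argument is by-name,
    // resolveIn is not even invoked for the remaining plans.
    val resolved = plans.foldLeft(Option.empty[Int]) { (acc, plan) =>
      acc.orElse(resolveIn(plan, "b"))
    }
    println(resolved) // Some(2): the earliest plan that can resolve "b" wins
  }
}
```

An alternative that stops iterating at the first hit (for example, a lazy view with collectFirst) would behave the same here; with only a handful of candidate plans the fold is just as clear.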

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 29 additions & 0 deletions
@@ -5014,6 +5014,35 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
       }
     }
   }
+
+  test("SPARK-50983: Support nested correlated subquery for analyzer") {
+    sql("CREATE TABLE Sales (product_id INT, amount INT)")
+    sql("CREATE TABLE Products (product_id INT, category_id INT, price DOUBLE)")
+
+    val query =
+      """
+        | SELECT s.product_id
+        | FROM Sales s
+        | WHERE s.amount > (
+        |   SELECT AVG(s2.amount)
+        |   FROM Sales s2
+        |   WHERE s2.product_id IN (
+        |     SELECT p2.product_id
+        |     FROM Products p2
+        |     WHERE p2.product_id = s2.product_id
+        |   )
+        | );
+        |""".stripMargin
+
+    withTable("Sales", "Products") {
+      withSQLConf(
+        "spark.sql.optimizer.supportNestedCorrelatedSubqueries.enabled" -> "true",
+        "spark.sql.planChangeLog.level" -> "info"
+      ) {
+        val res = sql(query).collect()
+      }
+    }
+  }
 }

 case class Foo(bar: Option[String])
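The new test only collects the result; a variant that seeds a few rows so the answer can actually be verified might look like the sketch below. The inserted data, the expected row, and the use of `checkAnswer` are illustrative assumptions, not part of this commit; the sketch reuses the same `query` string as the test above.

```scala
withTable("Sales", "Products") {
  sql("CREATE TABLE Sales (product_id INT, amount INT)")
  sql("CREATE TABLE Products (product_id INT, category_id INT, price DOUBLE)")
  // Hypothetical rows: every Sales row matches a product, so the subquery's
  // average is (10 + 30 + 50) / 3 = 30 and only the (2, 50) row passes the outer filter.
  sql("INSERT INTO Sales VALUES (1, 10), (1, 30), (2, 50)")
  sql("INSERT INTO Products VALUES (1, 100, 1.0), (2, 100, 2.0)")

  withSQLConf("spark.sql.optimizer.supportNestedCorrelatedSubqueries.enabled" -> "true") {
    checkAnswer(sql(query), Row(2) :: Nil)
  }
}
```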
