From a4faa91657eb52ef8f952ade740b4143943a126d Mon Sep 17 00:00:00 2001 From: Avery Qi Date: Fri, 24 Jan 2025 13:49:15 -0800 Subject: [PATCH 01/13] sync --- .../resources/error/error-conditions.json | 5 + .../sql/catalyst/analysis/Analyzer.scala | 114 ++++++++++++++---- .../sql/catalyst/analysis/CheckAnalysis.scala | 55 +++++++-- ...ctionTableSubqueryArgumentExpression.scala | 13 +- .../sql/catalyst/expressions/subquery.scala | 59 ++++++++- .../sql/catalyst/optimizer/Optimizer.scala | 2 +- .../sql/catalyst/optimizer/expressions.scala | 4 +- .../sql/catalyst/optimizer/subquery.scala | 37 +++--- .../apache/spark/sql/internal/SQLConf.scala | 8 ++ .../adaptive/PlanAdaptiveSubqueries.scala | 4 +- .../apache/spark/sql/execution/subquery.scala | 2 +- 11 files changed, 245 insertions(+), 58 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 48a71cfbf615f..ce3f765851159 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -5906,6 +5906,11 @@ "Correlated scalar subqueries must be aggregated to return at most one row." ] }, + "NESTED_CORRELATED_SUBQUERIES_NOT_SUPPORTED" : { + "message" : [ + "Nested correlated subqueries are not supported." + ] + }, "NON_CORRELATED_COLUMNS_IN_GROUP_BY" : { "message" : [ "A GROUP BY clause in a scalar correlated subquery cannot contain non-correlated columns: ." 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 92cfc4119dd0c..2bcaca9f7bcc2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2263,6 +2263,37 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor * Note: CTEs are handled in CTESubstitution. */ object ResolveSubquery extends Rule[LogicalPlan] { + + private def getOuterAttrsNeedToBePropagated(plan: LogicalPlan): Seq[Expression] = { + plan.expressions.flatMap { + case subExpr: SubqueryExpression => subExpr.getUnresolvedOuterAttrs + case in: InSubquery => in.query.getUnresolvedOuterAttrs + case expr if expr.containsPattern(PLAN_EXPRESSION) => + expr.collect { + case subExpr: SubqueryExpression => subExpr.getUnresolvedOuterAttrs + }.flatten + case _ => Seq.empty + } ++ plan.children.flatMap{ + case p if p.containsPattern(PLAN_EXPRESSION) => + getOuterAttrsNeedToBePropagated(p) + case _ => Seq.empty + } + } + + private def getUnresolvedOuterReferences( + s: SubqueryExpression, p: LogicalPlan + ): Seq[Expression] = { + val outerReferencesInSubquery = s.getOuterAttrs + + // return outer references cannot be handled in current plan + outerReferencesInSubquery.filter( + _ match { + case a: AttributeReference => !p.inputSet.contains(a) + case _ => false + } + ) + } + /** * Resolves the subquery plan that is referenced in a subquery expression, by invoking the * entire analyzer recursively. 
We set outer plan in `AnalysisContext`, so that the analyzer @@ -2274,15 +2305,23 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor e: SubqueryExpression, outer: LogicalPlan)( f: (LogicalPlan, Seq[Expression]) => SubqueryExpression): SubqueryExpression = { - val newSubqueryPlan = AnalysisContext.withOuterPlan(outer) { - executeSameContext(e.plan) + val newSubqueryPlan = if (AnalysisContext.get.outerPlan.isDefined) { + val propogatedOuterPlan = AnalysisContext.get.outerPlan.get + AnalysisContext.withOuterPlan(propogatedOuterPlan) { + executeSameContext(e.plan) + } + } else { + AnalysisContext.withOuterPlan(outer) { + executeSameContext(e.plan) + } } // If the subquery plan is fully resolved, pull the outer references and record // them as children of SubqueryExpression. if (newSubqueryPlan.resolved) { // Record the outer references as children of subquery expression. - f(newSubqueryPlan, SubExprUtils.getOuterReferences(newSubqueryPlan)) + f(newSubqueryPlan, SubExprUtils.getOuterReferences(newSubqueryPlan) ++ + getOuterAttrsNeedToBePropagated(newSubqueryPlan)) } else { e.withNewPlan(newSubqueryPlan) } @@ -2299,18 +2338,45 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor */ private def resolveSubQueries(plan: LogicalPlan, outer: LogicalPlan): LogicalPlan = { plan.transformAllExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION), ruleId) { - case s @ ScalarSubquery(sub, _, exprId, _, _, _, _) if !sub.resolved => - resolveSubQuery(s, outer)(ScalarSubquery(_, _, exprId)) - case e @ Exists(sub, _, exprId, _, _) if !sub.resolved => - resolveSubQuery(e, outer)(Exists(_, _, exprId)) + // There are four kinds of outer references here: + // 1. Outer references which are newly introduced in the subquery `res` + // which can be resolved in current `plan`. + // It is extracted by `SubExprUtils.getOuterReferences(res.plan)` and + // stored among res.outerAttrs + // 2. 
Outer references which are newly introduced in the subquery `res` + // which cannot be resolved in current `plan` + // It is extracted by `SubExprUtils.getOuterReferences(res.plan)` with + // `getUnresolvedOuterReferences(res, plan)` filter and stored in + // res.unresolvedOuterAttrs + // 3. Outer references which are introduced by nested subquery within `res.plan` + // which can be resolved in current `plan` + // It is extracted by `getOuterAttrsNeedToBePropagated(res.plan)`, filtered + // by `plan.inputSet.contains(_)`, need to be stored in res.outerAttrs + // 4. Outer references which are introduced by nested subquery within `res.plan` + // which cannot be resolved in current `plan` + // It is extracted by `getOuterAttrsNeedToBePropagated(res.plan)`, filtered + // by `!plan.inputSet.contains(_)`, need to be stored in + // res.outerAttrs and res.unresolvedOuterAttrs + case s @ ScalarSubquery(sub, _, _, exprId, _, _, _, _) if !sub.resolved => + val res = resolveSubQuery(s, outer)(ScalarSubquery(_, _, Seq.empty, exprId)) + val unresolvedOuterReferences = getUnresolvedOuterReferences(res, plan) + res.withNewUnresolvedOuterAttrs(unresolvedOuterReferences) + case e @ Exists(sub, _, _, exprId, _, _) if !sub.resolved => + val res = resolveSubQuery(e, outer)(Exists(_, _, Seq.empty, exprId)) + val unresolvedOuterReferences = getUnresolvedOuterReferences(res, plan) + res.withNewUnresolvedOuterAttrs(unresolvedOuterReferences) case InSubquery(values, l @ ListQuery(_, _, exprId, _, _, _)) if values.forall(_.resolved) && !l.resolved => val expr = resolveSubQuery(l, outer)((plan, exprs) => { - ListQuery(plan, exprs, exprId, plan.output.length) - }) - InSubquery(values, expr.asInstanceOf[ListQuery]) - case s @ LateralSubquery(sub, _, exprId, _, _) if !sub.resolved => - resolveSubQuery(s, outer)(LateralSubquery(_, _, exprId)) + ListQuery(plan, exprs, Seq.empty, exprId, plan.output.length) + }).asInstanceOf[ListQuery] + val unresolvedOuterReferences = 
getUnresolvedOuterReferences(expr, plan) + val newExpr = expr.withNewUnresolvedOuterAttrs(unresolvedOuterReferences) + InSubquery(values, newExpr) + case s @ LateralSubquery(sub, _, _, exprId, _, _) if !sub.resolved => + val res = resolveSubQuery(s, outer)(LateralSubquery(_, _, Seq.empty, exprId)) + val unresolvedOuterReferences = getUnresolvedOuterReferences(res, plan) + res.withNewUnresolvedOuterAttrs(unresolvedOuterReferences) case a: FunctionTableSubqueryArgumentExpression if !a.plan.resolved => resolveSubQuery(a, outer)( (plan, outerAttrs) => a.copy(plan = plan, outerAttrs = outerAttrs)) @@ -4043,25 +4109,25 @@ object ResolveExpressionsWithNamePlaceholders extends Rule[LogicalPlan] { * +- Filter exists#245 [min(b#227)#249] * : +- Project [1 AS 1#247] * : +- Filter (d#238 < min(outer(b#227))) <----- - * : +- SubqueryAlias r - * : +- Project [_1#234 AS c#237, _2#235 AS d#238] - * : +- LocalRelation [_1#234, _2#235] + * :- SubqueryAlias r + * : - Project [_1#234 AS c#237, _2#235 AS d#238] + * : - LocalRelation [_1#234, _2#235] * +- Aggregate [a#226], [a#226, min(b#227) AS min(b#227)#249] - * +- SubqueryAlias l - * +- Project [_1#223 AS a#226, _2#224 AS b#227] - * +- LocalRelation [_1#223, _2#224] + * - SubqueryAlias l + * - Project [_1#223 AS a#226, _2#224 AS b#227] + * - LocalRelation [_1#223, _2#224] * Plan after the rule. 
* Project [a#226] * +- Filter exists#245 [min(b#227)#249] * : +- Project [1 AS 1#247] * : +- Filter (d#238 < outer(min(b#227)#249)) <----- - * : +- SubqueryAlias r - * : +- Project [_1#234 AS c#237, _2#235 AS d#238] - * : +- LocalRelation [_1#234, _2#235] + * :- SubqueryAlias r + * : - Project [_1#234 AS c#237, _2#235 AS d#238] + * : - LocalRelation [_1#234, _2#235] * +- Aggregate [a#226], [a#226, min(b#227) AS min(b#227)#249] - * +- SubqueryAlias l - * +- Project [_1#223 AS a#226, _2#224 AS b#227] - * +- LocalRelation [_1#223, _2#224] + * - SubqueryAlias l + * - Project [_1#223 AS a#226, _2#224 AS b#227] + * - LocalRelation [_1#223, _2#224] */ object UpdateOuterReferences extends Rule[LogicalPlan] { private def stripAlias(expr: Expression): Expression = expr match { case a: Alias => a.child } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 269411a6fc616..1bc7275d3e8b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -231,6 +231,32 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB } } + def checkNoUnresolvedOuterReferencesInMainQuery(plan: LogicalPlan): Unit = { + plan.expressions.foreach { + case subExpr: SubqueryExpression if subExpr.getUnresolvedOuterAttrs.nonEmpty => + subExpr.failAnalysis( + errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.CORRELATED_COLUMN_NOT_FOUND", + messageParameters = Map.empty) + case in: InSubquery if in.query.getUnresolvedOuterAttrs.nonEmpty => + in.query.failAnalysis( + errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.CORRELATED_COLUMN_NOT_FOUND", + messageParameters = Map.empty) + case expr if expr.containsPattern(PLAN_EXPRESSION) => + expr.collect { + case subExpr: SubqueryExpression if 
subExpr.getUnresolvedOuterAttrs.nonEmpty => + subExpr.failAnalysis( + errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.CORRELATED_COLUMN_NOT_FOUND", + messageParameters = Map.empty) + } + case _ => + } + plan.children.foreach { + case p: LogicalPlan if p.containsPattern(PLAN_EXPRESSION) => + checkNoUnresolvedOuterReferencesInMainQuery(p) + case _ => + } + } + def checkAnalysis(plan: LogicalPlan): Unit = { // We should inline all CTE relations to restore the original plan shape, as the analysis check // may need to match certain plan shapes. For dangling CTE relations, they will still be kept @@ -244,6 +270,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB } preemptedError.clear() try { + checkNoUnresolvedOuterReferencesInMainQuery(inlinedPlan) checkAnalysis0(inlinedPlan) preemptedError.getErrorOpt().foreach(throw _) // throw preempted error if any } catch { @@ -1137,6 +1164,20 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB case _ => } + def checkUnresolvedOuterReferences(expr: SubqueryExpression): Unit = { + if ((!SQLConf.get.getConf(SQLConf.SUPPORT_NESTED_CORRELATED_SUBQUERIES)) && + expr.getUnresolvedOuterAttrs.nonEmpty) { + expr.failAnalysis( + errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + + "NESTED_CORRELATED_SUBQUERIES_NOT_SUPPORTED", + messageParameters = Map.empty) + } + } + + // Check if there are nested correlated subqueries in the plan. + checkUnresolvedOuterReferences(expr) + + // Validate the subquery plan. checkAnalysis0(expr.plan) @@ -1144,7 +1185,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB checkOuterReference(plan, expr) expr match { - case ScalarSubquery(query, outerAttrs, _, _, _, _, _) => + case ScalarSubquery(query, outerAttrs, _, _, _, _, _, _) => // Scalar subquery must return one column as output. 
if (query.output.size != 1) { throw QueryCompilationErrors.subqueryReturnMoreThanOneColumn(query.output.size, @@ -1354,9 +1395,9 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB // +- Project [c1#87, c2#88] // : (Aggregate or Window operator) // : +- Filter [outer(c2#77) >= c2#88)] - // : +- SubqueryAlias t2, `t2` - // : +- Project [_1#84 AS c1#87, _2#85 AS c2#88] - // : +- LocalRelation [_1#84, _2#85] + // : - SubqueryAlias t2, `t2` + // : - Project [_1#84 AS c1#87, _2#85 AS c2#88] + // : - LocalRelation [_1#84, _2#85] // +- SubqueryAlias t1, `t1` // +- Project [_1#73 AS c1#76, _2#74 AS c2#77] // +- LocalRelation [_1#73, _2#74] @@ -1373,7 +1414,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB // Original subquery plan: // Aggregate [count(1)] // +- Filter ((a + b) = outer(c)) - // +- LocalRelation [a, b] + //- LocalRelation [a, b] // // Plan after pulling up correlated predicates: // Aggregate [a, b] [count(1), a, b] @@ -1383,8 +1424,8 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB // Project [c1, count(1)] // +- Join LeftOuter ((a + b) = c) // :- LocalRelation [c] - // +- Aggregate [a, b] [count(1), a, b] - // +- LocalRelation [a, b] + //- Aggregate [a, b] [count(1), a, b] + // - LocalRelation [a, b] // // The right hand side of the join transformed from the subquery will output // count(1) | a | b diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FunctionTableSubqueryArgumentExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FunctionTableSubqueryArgumentExpression.scala index bfd3bc8051dff..844f1c984507a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FunctionTableSubqueryArgumentExpression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FunctionTableSubqueryArgumentExpression.scala @@ -67,12 +67,14 @@ import 
org.apache.spark.sql.types.DataType case class FunctionTableSubqueryArgumentExpression( plan: LogicalPlan, outerAttrs: Seq[Expression] = Seq.empty, + unresolvedOuterAttrs: Seq[Expression] = Seq.empty, exprId: ExprId = NamedExpression.newExprId, partitionByExpressions: Seq[Expression] = Seq.empty, withSinglePartition: Boolean = false, orderByExpressions: Seq[SortOrder] = Seq.empty, selectedInputExpressions: Seq[PythonUDTFSelectedExpression] = Seq.empty) - extends SubqueryExpression(plan, outerAttrs, exprId, Seq.empty, None) with Unevaluable { + extends SubqueryExpression( + plan, outerAttrs, unresolvedOuterAttrs, exprId, Seq.empty, None) with Unevaluable { assert(!(withSinglePartition && partitionByExpressions.nonEmpty), "WITH SINGLE PARTITION is mutually exclusive with PARTITION BY") @@ -83,6 +85,14 @@ case class FunctionTableSubqueryArgumentExpression( copy(plan = plan) override def withNewOuterAttrs(outerAttrs: Seq[Expression]) : FunctionTableSubqueryArgumentExpression = copy(outerAttrs = outerAttrs) + override def withNewUnresolvedOuterAttrs( + unresolvedOuterAttrs: Seq[Expression] + ): FunctionTableSubqueryArgumentExpression = { + if (!unresolvedOuterAttrs.forall(outerAttrs.contains(_))) { + // throw + } + copy(unresolvedOuterAttrs = unresolvedOuterAttrs) + } override def hint: Option[HintInfo] = None override def withNewHint(hint: Option[HintInfo]): FunctionTableSubqueryArgumentExpression = copy() @@ -91,6 +101,7 @@ case class FunctionTableSubqueryArgumentExpression( FunctionTableSubqueryArgumentExpression( plan.canonicalized, outerAttrs.map(_.canonicalized), + unresolvedOuterAttrs.map(_.canonicalized), ExprId(0), partitionByExpressions, withSinglePartition, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index c0a2bf25fbe67..684cad88176aa 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -67,6 +67,8 @@ abstract class PlanExpression[T <: QueryPlan[_]] extends Expression { * * @param plan: the subquery plan * @param outerAttrs: the outer references in the subquery plan + * @param unresolvedOuterAttrs: the outer references in the subquery plan + * but can't be resolved by its immediate parent plan * @param exprId: ID of the expression * @param joinCond: the join conditions with the outer query. It contains both inner and outer * query references. @@ -76,18 +78,22 @@ abstract class PlanExpression[T <: QueryPlan[_]] extends Expression { abstract class SubqueryExpression( plan: LogicalPlan, outerAttrs: Seq[Expression], + unresolvedOuterAttrs: Seq[Expression], exprId: ExprId, joinCond: Seq[Expression], hint: Option[HintInfo]) extends PlanExpression[LogicalPlan] { override lazy val resolved: Boolean = childrenResolved && plan.resolved override lazy val references: AttributeSet = - AttributeSet.fromAttributeSets(outerAttrs.map(_.references)) + AttributeSet.fromAttributeSets(outerAttrs.map(_.references)) -- + AttributeSet.fromAttributeSets(unresolvedOuterAttrs.map(_.references)) override def children: Seq[Expression] = outerAttrs ++ joinCond override def withNewPlan(plan: LogicalPlan): SubqueryExpression def withNewOuterAttrs(outerAttrs: Seq[Expression]): SubqueryExpression def isCorrelated: Boolean = outerAttrs.nonEmpty def hint: Option[HintInfo] def withNewHint(hint: Option[HintInfo]): SubqueryExpression + def withNewUnresolvedOuterAttrs(unresolvedOuterAttrs: Seq[Expression]): SubqueryExpression + def getUnresolvedOuterAttrs: Seq[Expression] = unresolvedOuterAttrs } object SubqueryExpression { @@ -395,12 +401,14 @@ object SubExprUtils extends PredicateHelper { case class ScalarSubquery( plan: LogicalPlan, outerAttrs: Seq[Expression] = Seq.empty, + unresolvedOuterAttrs: 
Seq[Expression] = Seq.empty, exprId: ExprId = NamedExpression.newExprId, joinCond: Seq[Expression] = Seq.empty, hint: Option[HintInfo] = None, mayHaveCountBug: Option[Boolean] = None, needSingleJoin: Option[Boolean] = None) - extends SubqueryExpression(plan, outerAttrs, exprId, joinCond, hint) with Unevaluable { + extends SubqueryExpression( + plan, outerAttrs, unresolvedOuterAttrs, exprId, joinCond, hint) with Unevaluable { override def dataType: DataType = { if (!plan.schema.fields.nonEmpty) { throw QueryCompilationErrors.subqueryReturnMoreThanOneColumn(plan.schema.fields.length, @@ -412,12 +420,21 @@ case class ScalarSubquery( override def withNewPlan(plan: LogicalPlan): ScalarSubquery = copy(plan = plan) override def withNewOuterAttrs(outerAttrs: Seq[Expression]): ScalarSubquery = copy( outerAttrs = outerAttrs) + override def withNewUnresolvedOuterAttrs( + unresolvedOuterAttrs: Seq[Expression] + ): ScalarSubquery = { + if (!unresolvedOuterAttrs.forall(outerAttrs.contains(_))) { + // throw + } + copy(unresolvedOuterAttrs = unresolvedOuterAttrs) + } override def withNewHint(hint: Option[HintInfo]): ScalarSubquery = copy(hint = hint) override def toString: String = s"scalar-subquery#${exprId.id} $conditionString" override lazy val canonicalized: Expression = { ScalarSubquery( plan.canonicalized, outerAttrs.map(_.canonicalized), + unresolvedOuterAttrs.map(_.canonicalized), ExprId(0), joinCond.map(_.canonicalized)) } @@ -457,21 +474,32 @@ case class UnresolvedScalarSubqueryPlanId(planId: Long) case class LateralSubquery( plan: LogicalPlan, outerAttrs: Seq[Expression] = Seq.empty, + unresolvedOuterAttrs: Seq[Expression] = Seq.empty, exprId: ExprId = NamedExpression.newExprId, joinCond: Seq[Expression] = Seq.empty, hint: Option[HintInfo] = None) - extends SubqueryExpression(plan, outerAttrs, exprId, joinCond, hint) with Unevaluable { + extends SubqueryExpression( + plan, outerAttrs, unresolvedOuterAttrs, exprId, joinCond, hint) with Unevaluable { override def 
dataType: DataType = plan.output.toStructType override def nullable: Boolean = true override def withNewPlan(plan: LogicalPlan): LateralSubquery = copy(plan = plan) override def withNewOuterAttrs(outerAttrs: Seq[Expression]): LateralSubquery = copy( outerAttrs = outerAttrs) + override def withNewUnresolvedOuterAttrs( + unresolvedOuterAttrs: Seq[Expression] + ): LateralSubquery = { + if (!unresolvedOuterAttrs.forall(outerAttrs.contains(_))) { + // throw + } + copy(unresolvedOuterAttrs = unresolvedOuterAttrs) + } override def withNewHint(hint: Option[HintInfo]): LateralSubquery = copy(hint = hint) override def toString: String = s"lateral-subquery#${exprId.id} $conditionString" override lazy val canonicalized: Expression = { LateralSubquery( plan.canonicalized, outerAttrs.map(_.canonicalized), + unresolvedOuterAttrs.map(_.canonicalized), ExprId(0), joinCond.map(_.canonicalized)) } @@ -500,13 +528,15 @@ case class LateralSubquery( case class ListQuery( plan: LogicalPlan, outerAttrs: Seq[Expression] = Seq.empty, + unresolvedOuterAttrs: Seq[Expression] = Seq.empty, exprId: ExprId = NamedExpression.newExprId, // The plan of list query may have more columns after de-correlation, and we need to track the // number of the columns of the original plan, to report the data type properly. 
numCols: Int = -1, joinCond: Seq[Expression] = Seq.empty, hint: Option[HintInfo] = None) - extends SubqueryExpression(plan, outerAttrs, exprId, joinCond, hint) with Unevaluable { + extends SubqueryExpression( + plan, outerAttrs, unresolvedOuterAttrs, exprId, joinCond, hint) with Unevaluable { def childOutputs: Seq[Attribute] = plan.output.take(numCols) override def dataType: DataType = if (numCols > 1) { childOutputs.toStructType @@ -526,12 +556,21 @@ case class ListQuery( override def withNewPlan(plan: LogicalPlan): ListQuery = copy(plan = plan) override def withNewOuterAttrs(outerAttrs: Seq[Expression]): ListQuery = copy( outerAttrs = outerAttrs) + override def withNewUnresolvedOuterAttrs( + unresolvedOuterAttrs: Seq[Expression] + ): ListQuery = { + if (!unresolvedOuterAttrs.forall(outerAttrs.contains(_))) { + // throw + } + copy(unresolvedOuterAttrs = unresolvedOuterAttrs) + } override def withNewHint(hint: Option[HintInfo]): ListQuery = copy(hint = hint) override def toString: String = s"list#${exprId.id} $conditionString" override lazy val canonicalized: Expression = { ListQuery( plan.canonicalized, outerAttrs.map(_.canonicalized), + unresolvedOuterAttrs.map(_.canonicalized), ExprId(0), numCols, joinCond.map(_.canonicalized)) @@ -574,22 +613,32 @@ case class ListQuery( case class Exists( plan: LogicalPlan, outerAttrs: Seq[Expression] = Seq.empty, + unresolvedOuterAttrs: Seq[Expression] = Seq.empty, exprId: ExprId = NamedExpression.newExprId, joinCond: Seq[Expression] = Seq.empty, hint: Option[HintInfo] = None) - extends SubqueryExpression(plan, outerAttrs, exprId, joinCond, hint) + extends SubqueryExpression(plan, outerAttrs, unresolvedOuterAttrs, exprId, joinCond, hint) with Predicate with Unevaluable { override def nullable: Boolean = false override def withNewPlan(plan: LogicalPlan): Exists = copy(plan = plan) override def withNewOuterAttrs(outerAttrs: Seq[Expression]): Exists = copy( outerAttrs = outerAttrs) + override def withNewUnresolvedOuterAttrs( + 
unresolvedOuterAttrs: Seq[Expression] + ): Exists = { + if (!unresolvedOuterAttrs.forall(outerAttrs.contains(_))) { + // throw + } + copy(unresolvedOuterAttrs = unresolvedOuterAttrs) + } override def withNewHint(hint: Option[HintInfo]): Exists = copy(hint = hint) override def toString: String = s"exists#${exprId.id} $conditionString" override lazy val canonicalized: Expression = { Exists( plan.canonicalized, outerAttrs.map(_.canonicalized), + unresolvedOuterAttrs.map(_.canonicalized), ExprId(0), joinCond.map(_.canonicalized)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 9d269f37e58b9..c2738b736c4c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -364,7 +364,7 @@ abstract class Optimizer(catalogManager: CatalogManager) case d: DynamicPruningSubquery => d case s @ ScalarSubquery( PhysicalOperation(projections, predicates, a @ Aggregate(group, _, child, _)), - _, _, _, _, mayHaveCountBug, _) + _, _, _, _, _, mayHaveCountBug, _) if conf.getConf(SQLConf.DECORRELATE_SUBQUERY_PREVENT_CONSTANT_FOLDING_FOR_COUNT_BUG) && mayHaveCountBug.nonEmpty && mayHaveCountBug.get => // This is a subquery with an aggregate that may suffer from a COUNT bug. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index e867953bcf282..e6d553e80d4eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -90,7 +90,7 @@ object ConstantFolding extends Rule[LogicalPlan] { } // Don't replace ScalarSubquery if its plan is an aggregate that may suffer from a COUNT bug. - case s @ ScalarSubquery(_, _, _, _, _, mayHaveCountBug, _) + case s @ ScalarSubquery(_, _, _, _, _, _, mayHaveCountBug, _) if conf.getConf(SQLConf.DECORRELATE_SUBQUERY_PREVENT_CONSTANT_FOLDING_FOR_COUNT_BUG) && mayHaveCountBug.nonEmpty && mayHaveCountBug.get => s @@ -891,7 +891,7 @@ object NullPropagation extends Rule[LogicalPlan] { case InSubquery(Seq(Literal(null, _)), _) if SQLConf.get.legacyNullInEmptyBehavior => Literal.create(null, BooleanType) - case InSubquery(Seq(Literal(null, _)), ListQuery(sub, _, _, _, conditions, _)) + case InSubquery(Seq(Literal(null, _)), ListQuery(sub, _, _, _, _, conditions, _)) if !SQLConf.get.legacyNullInEmptyBehavior && conditions.isEmpty => If(Exists(sub), Literal(null, BooleanType), FalseLiteral) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 5a4e9f37c3951..a46b5d83ff35d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -319,7 +319,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { val introducedAttrs = ArrayBuffer.empty[Attribute] val newExprs = exprs.map { e => e.transformDownWithPruning(_.containsAnyPattern(EXISTS_SUBQUERY, IN_SUBQUERY)) { - case 
Exists(sub, _, _, conditions, subHint) => + case Exists(sub, _, _, _, conditions, subHint) => val exists = AttributeReference("exists", BooleanType, nullable = false)() val existenceJoin = ExistenceJoin(exists) val newCondition = conditions.reduceLeftOption(And) @@ -328,7 +328,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { existenceJoin, newCondition, subHint) introducedAttrs += exists exists - case Not(InSubquery(values, ListQuery(sub, _, _, _, conditions, subHint))) => + case Not(InSubquery(values, ListQuery(sub, _, _, _, _, conditions, subHint))) => val exists = AttributeReference("exists", BooleanType, nullable = false)() // Deduplicate conflicting attributes if any. val newSub = dedupSubqueryOnSelfJoin(newPlan, sub, Some(values)) @@ -353,7 +353,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { ExistenceJoin(exists), Some(finalJoinCond), joinHint) introducedAttrs += exists Not(exists) - case InSubquery(values, ListQuery(sub, _, _, _, conditions, subHint)) => + case InSubquery(values, ListQuery(sub, _, _, _, _, conditions, subHint)) => val exists = AttributeReference("exists", BooleanType, nullable = false)() // Deduplicate conflicting attributes if any. 
val newSub = dedupSubqueryOnSelfJoin(newPlan, sub, Some(values)) @@ -506,7 +506,7 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper } plan.transformExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION)) { - case ScalarSubquery(sub, children, exprId, conditions, hint, + case ScalarSubquery(sub, children, unresolvedOuterAttrs, exprId, conditions, hint, mayHaveCountBugOld, needSingleJoinOld) if children.nonEmpty => @@ -558,26 +558,32 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper } else { conf.getConf(SQLConf.SCALAR_SUBQUERY_USE_SINGLE_JOIN) && !guaranteedToReturnOneRow(sub) } - ScalarSubquery(newPlan, children, exprId, getJoinCondition(newCond, conditions), + ScalarSubquery(newPlan, children, unresolvedOuterAttrs, + exprId, getJoinCondition(newCond, conditions), hint, Some(mayHaveCountBug), Some(needSingleJoin)) - case Exists(sub, children, exprId, conditions, hint) if children.nonEmpty => + case Exists(sub, children, unresolvedOuterAttrs, exprId, conditions, hint) + if children.nonEmpty => val (newPlan, newCond) = if (SQLConf.get.decorrelateInnerQueryEnabledForExistsIn) { decorrelate(sub, plan, handleCountBug = true) } else { pullOutCorrelatedPredicates(sub, plan) } - Exists(newPlan, children, exprId, getJoinCondition(newCond, conditions), hint) - case ListQuery(sub, children, exprId, numCols, conditions, hint) if children.nonEmpty => + Exists(newPlan, children, unresolvedOuterAttrs, + exprId, getJoinCondition(newCond, conditions), hint) + case ListQuery(sub, children, unresolvedOuterAttrs, exprId, numCols, conditions, hint) + if children.nonEmpty => val (newPlan, newCond) = if (SQLConf.get.decorrelateInnerQueryEnabledForExistsIn) { decorrelate(sub, plan, handleCountBug = true) } else { pullOutCorrelatedPredicates(sub, plan) } val joinCond = getJoinCondition(newCond, conditions) - ListQuery(newPlan, children, exprId, numCols, joinCond, hint) - case LateralSubquery(sub, children, exprId, 
conditions, hint) if children.nonEmpty => + ListQuery(newPlan, children, unresolvedOuterAttrs, exprId, numCols, joinCond, hint) + case LateralSubquery(sub, children, unresolvedOuterAttrs, exprId, conditions, hint) + if children.nonEmpty => val (newPlan, newCond) = decorrelate(sub, plan, handleCountBug = true) - LateralSubquery(newPlan, children, exprId, getJoinCondition(newCond, conditions), hint) + LateralSubquery(newPlan, children, + unresolvedOuterAttrs, exprId, getJoinCondition(newCond, conditions), hint) } } @@ -817,7 +823,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe subqueries: ArrayBuffer[ScalarSubquery]): (LogicalPlan, AttributeMap[Attribute]) = { val subqueryAttrMapping = ArrayBuffer[(Attribute, Attribute)]() val newChild = subqueries.foldLeft(child) { - case (currentChild, ScalarSubquery(sub, _, _, conditions, subHint, mayHaveCountBug, + case (currentChild, ScalarSubquery(sub, _, _, _, conditions, subHint, mayHaveCountBug, needSingleJoin)) => val query = DecorrelateInnerQuery.rewriteDomainJoins(currentChild, sub, conditions) val origOutput = query.output.head @@ -1004,7 +1010,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe object RewriteLateralSubquery extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning( _.containsPattern(LATERAL_JOIN)) { - case LateralJoin(left, LateralSubquery(sub, _, _, joinCond, subHint), joinType, condition) => + case LateralJoin(left, LateralSubquery(sub, _, _, _, joinCond, subHint), joinType, condition) => val newRight = DecorrelateInnerQuery.rewriteDomainJoins(left, sub, joinCond) val newCond = (condition ++ joinCond).reduceOption(And) // The subquery appears on the right side of the join, hence add the hint to the right side @@ -1041,7 +1047,7 @@ object OptimizeOneRowRelationSubquery extends Rule[LogicalPlan] { */ private def rewrite(plan: LogicalPlan): LogicalPlan = plan.transformUpWithSubqueries { case 
LateralJoin( - left, right @ LateralSubquery(OneRowSubquery(plan), _, _, _, _), _, None) + left, right @ LateralSubquery(OneRowSubquery(plan), _, _, _, _, _), _, None) if !hasCorrelatedSubquery(right.plan) && right.joinCond.isEmpty => plan match { case Project(projectList, _: OneRowRelation) => @@ -1064,7 +1070,8 @@ object OptimizeOneRowRelationSubquery extends Rule[LogicalPlan] { case p: LogicalPlan => p.transformExpressionsUpWithPruning( _.containsPattern(SCALAR_SUBQUERY)) { - case s @ ScalarSubquery(OneRowSubquery(p @ Project(_, _: OneRowRelation)), _, _, _, _, _, _) + case s @ ScalarSubquery( + OneRowSubquery(p @ Project(_, _: OneRowRelation)), _, _, _, _, _, _, _) if !hasCorrelatedSubquery(s.plan) && s.joinCond.isEmpty => assert(p.projectList.size == 1) stripOuterReferences(p.projectList).head diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index d4298087b2108..58584f00563d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -4009,6 +4009,14 @@ object SQLConf { .booleanConf .createWithDefault(true) + val SUPPORT_NESTED_CORRELATED_SUBQUERIES = + buildConf("spark.sql.optimizer.supportNestedCorrelatedSubqueries.enabled") + .internal() + .doc("If enabled, support nested correlated subqueries") + .version("4.0.0") + .booleanConf + .createWithDefault(false) + val PULL_OUT_NESTED_DATA_OUTER_REF_EXPRESSIONS_ENABLED = buildConf("spark.sql.optimizer.pullOutNestedDataOuterRefExpressions.enabled") .internal() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala index 5f2638655c37c..dd38c1e94e481 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala @@ -30,11 +30,11 @@ case class PlanAdaptiveSubqueries( def apply(plan: SparkPlan): SparkPlan = { plan.transformAllExpressionsWithPruning( _.containsAnyPattern(SCALAR_SUBQUERY, IN_SUBQUERY, DYNAMIC_PRUNING_SUBQUERY)) { - case expressions.ScalarSubquery(_, _, exprId, _, _, _, _) => + case expressions.ScalarSubquery(_, _, _, exprId, _, _, _, _) => val subquery = SubqueryExec.createForScalarSubquery( s"subquery#${exprId.id}", subqueryMap(exprId.id)) execution.ScalarSubquery(subquery, exprId) - case expressions.InSubquery(values, ListQuery(_, _, exprId, _, _, _)) => + case expressions.InSubquery(values, ListQuery(_, _, _, exprId, _, _, _)) => val expr = if (values.length == 1) { values.head } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala index 771c743f70629..32468c84cc9df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala @@ -187,7 +187,7 @@ case class PlanSubqueries(sparkSession: SparkSession) extends Rule[SparkPlan] { SubqueryExec.createForScalarSubquery( s"scalar-subquery#${subquery.exprId.id}", executedPlan), subquery.exprId) - case expressions.InSubquery(values, ListQuery(query, _, exprId, _, _, _)) => + case expressions.InSubquery(values, ListQuery(query, _, _, exprId, _, _, _)) => val expr = if (values.length == 1) { values.head } else { From ea57fe5de82717e34d04afa366c15ff9ece2d08d Mon Sep 17 00:00:00 2001 From: Avery Qi Date: Fri, 24 Jan 2025 13:50:04 -0800 Subject: [PATCH 02/13] change version --- .../src/main/scala/org/apache/spark/sql/internal/SQLConf.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 58584f00563d3..a12cbd991c208 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -4013,7 +4013,7 @@ object SQLConf { buildConf("spark.sql.optimizer.supportNestedCorrelatedSubqueries.enabled") .internal() .doc("If enabled, support nested correlated subqueries") - .version("4.0.0") + .version("4.1.0") .booleanConf .createWithDefault(false) From d7deba6ea7fb0d2d36fdfbb24157c49b9300f559 Mon Sep 17 00:00:00 2001 From: Avery Qi Date: Fri, 24 Jan 2025 13:52:24 -0800 Subject: [PATCH 03/13] test --- .../org/apache/spark/sql/SQLQuerySuite.scala | 73 +++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 1b8f596a999b7..c92b38d73edc7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -4941,6 +4941,79 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark Row(Array(0), Array(0)), Row(Array(1), Array(1)), Row(Array(2), Array(2))) checkAnswer(df, expectedAnswer) } + + test("SPARK-50983: Support nested correlated subquery") { + sql("Create table Sales (product_id int, amount int)") + sql("create table Products (product_id int, category_id int, price double)") + + withTable("Sales", "Products") { + sql("insert into Sales values (1, 10), (2, 20), (3, 30), (4, 40), (5, 50)") + sql( + "insert into " + + "Products values (1, 1, 100.0), (2, 1, 200.0)," + + " (3, 2, 300.0), (4, 2, 400.0), (5, 3, 500.0)" + ) + val query = + """ + |SELECT p.product_id, p.category_id, p.price + |FROM Products p + |WHERE p.product_id IN ( + | SELECT s.product_id + | FROM Sales s + | WHERE s.amount > ( + | SELECT AVG(s2.amount) + | FROM Sales s2 
+ | WHERE s2.product_id IN ( + | SELECT p2.product_id + | FROM Products p2 + | WHERE p2.category_id = p.category_id + | ) + | ) + |); + |""".stripMargin + + val query2 = + """ + |SELECT p.product_id, p.category_id, p.price + |FROM Products p + |WHERE p.product_id IN ( + | SELECT s.product_id + | FROM Sales s + | WHERE s.product_id = p.product_id + |); + |""".stripMargin + + val query3 = + """ + |SELECT p.product_id, p.category_id, p.price + |FROM Products p + |WHERE p.product_id IN ( + | SELECT s.product_id + | FROM Sales s + | WHERE s.amount > ( + | SELECT AVG(s2.amount) + | FROM Sales s2 + | WHERE s2.product_id IN ( + | SELECT p2.product_id + | FROM Products p2 + | ) + | ) + |); + |""".stripMargin + withSQLConf( + "spark.sql.planChangeLog.level" -> "info", + "spark.sql.optimizer.supportNestedCorrelatedSubqueries.enabled" -> "false" + ) { + try { + sql(query).collect() + } catch { + case e: AnalysisException => + assert(e.errorClass.isDefined && e.errorClass.get == + "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.NESTED_CORRELATED_SUBQUERIES_NOT_SUPPORTED") + } + } + } + } } case class Foo(bar: Option[String]) From ad073fffca9bef1d1c5511a32592166cc7e61b75 Mon Sep 17 00:00:00 2001 From: Avery Qi Date: Fri, 24 Jan 2025 13:54:41 -0800 Subject: [PATCH 04/13] add todo --- .../FunctionTableSubqueryArgumentExpression.scala | 2 +- .../apache/spark/sql/catalyst/expressions/subquery.scala | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FunctionTableSubqueryArgumentExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FunctionTableSubqueryArgumentExpression.scala index 844f1c984507a..9e5bacc0a82fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FunctionTableSubqueryArgumentExpression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FunctionTableSubqueryArgumentExpression.scala @@ 
-89,7 +89,7 @@ case class FunctionTableSubqueryArgumentExpression( unresolvedOuterAttrs: Seq[Expression] ): FunctionTableSubqueryArgumentExpression = { if (!unresolvedOuterAttrs.forall(outerAttrs.contains(_))) { - // throw + // TODO(avery): create suitable error subclass to throw } copy(unresolvedOuterAttrs = unresolvedOuterAttrs) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index 684cad88176aa..9d987171446f7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -424,7 +424,7 @@ case class ScalarSubquery( unresolvedOuterAttrs: Seq[Expression] ): ScalarSubquery = { if (!unresolvedOuterAttrs.forall(outerAttrs.contains(_))) { - // throw + // TODO(avery): create suitable error subclass to throw } copy(unresolvedOuterAttrs = unresolvedOuterAttrs) } @@ -489,7 +489,7 @@ case class LateralSubquery( unresolvedOuterAttrs: Seq[Expression] ): LateralSubquery = { if (!unresolvedOuterAttrs.forall(outerAttrs.contains(_))) { - // throw + // TODO(avery): create suitable error subclass to throw } copy(unresolvedOuterAttrs = unresolvedOuterAttrs) } @@ -560,7 +560,7 @@ case class ListQuery( unresolvedOuterAttrs: Seq[Expression] ): ListQuery = { if (!unresolvedOuterAttrs.forall(outerAttrs.contains(_))) { - // throw + // TODO(avery): create suitable error subclass to throw } copy(unresolvedOuterAttrs = unresolvedOuterAttrs) } @@ -628,7 +628,7 @@ case class Exists( unresolvedOuterAttrs: Seq[Expression] ): Exists = { if (!unresolvedOuterAttrs.forall(outerAttrs.contains(_))) { - // throw + // TODO(avery): create suitable error subclass to throw } copy(unresolvedOuterAttrs = unresolvedOuterAttrs) } From 52dc80262cc73c8cfa46b7821523ddcca86b3de4 Mon Sep 17 00:00:00 2001 From: Avery Date: Fri, 24 Jan 2025 14:17:09 
-0800 Subject: [PATCH 05/13] restore unrelated comments --- .../sql/catalyst/analysis/Analyzer.scala | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 2bcaca9f7bcc2..bcef55a9a5554 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -359,7 +359,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor // Please do not insert any other rules in between. See the TODO comments in rule // ResolveLateralColumnAliasReference for more details. ResolveLateralColumnAliasReference :: - ResolveExpressionsWithNamePlaceholders :: + :: ResolveDeserializer :: ResolveNewInstance :: ResolveUpCast :: @@ -4109,25 +4109,25 @@ object ResolveExpressionsWithNamePlaceholders extends Rule[LogicalPlan] { * +- Filter exists#245 [min(b#227)#249] * : +- Project [1 AS 1#247] * : +- Filter (d#238 < min(outer(b#227))) <----- - * :- SubqueryAlias r - * : - Project [_1#234 AS c#237, _2#235 AS d#238] - * : - LocalRelation [_1#234, _2#235] + * : +- SubqueryAlias r + * : +- Project [_1#234 AS c#237, _2#235 AS d#238] + * : +- LocalRelation [_1#234, _2#235] * +- Aggregate [a#226], [a#226, min(b#227) AS min(b#227)#249] - * - SubqueryAlias l - * - Project [_1#223 AS a#226, _2#224 AS b#227] - * - LocalRelation [_1#223, _2#224] + * +- SubqueryAlias l + * +- Project [_1#223 AS a#226, _2#224 AS b#227] + * +- LocalRelation [_1#223, _2#224] * Plan after the rule. 
* Project [a#226] * +- Filter exists#245 [min(b#227)#249] * : +- Project [1 AS 1#247] * : +- Filter (d#238 < outer(min(b#227)#249)) <----- - * :- SubqueryAlias r - * : - Project [_1#234 AS c#237, _2#235 AS d#238] - * : - LocalRelation [_1#234, _2#235] + * : +- SubqueryAlias r + * : +- Project [_1#234 AS c#237, _2#235 AS d#238] + * : +- LocalRelation [_1#234, _2#235] * +- Aggregate [a#226], [a#226, min(b#227) AS min(b#227)#249] - * - SubqueryAlias l - * - Project [_1#223 AS a#226, _2#224 AS b#227] - * - LocalRelation [_1#223, _2#224] + * +- SubqueryAlias l + * +- Project [_1#223 AS a#226, _2#224 AS b#227] + * +- LocalRelation [_1#223, _2#224] */ object UpdateOuterReferences extends Rule[LogicalPlan] { private def stripAlias(expr: Expression): Expression = expr match { case a: Alias => a.child } From ae31b98208636693c0daebab746c933dd982c93a Mon Sep 17 00:00:00 2001 From: Avery Date: Fri, 24 Jan 2025 14:19:15 -0800 Subject: [PATCH 06/13] restore unrelated comments --- .../scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index bcef55a9a5554..782eba43452bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -359,7 +359,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor // Please do not insert any other rules in between. See the TODO comments in rule // ResolveLateralColumnAliasReference for more details. 
ResolveLateralColumnAliasReference :: - :: + ResolveExpressionsWithNamePlaceholders :: ResolveDeserializer :: ResolveNewInstance :: ResolveUpCast :: From 0d40ac05e412ffe46d0c8ea5367ec7e075089a35 Mon Sep 17 00:00:00 2001 From: Avery Qi Date: Thu, 30 Jan 2025 15:08:17 -0800 Subject: [PATCH 07/13] fix analyzer --- .../analysis/ColumnResolutionHelper.scala | 48 +++++++++++++++++-- .../org/apache/spark/sql/SQLQuerySuite.scala | 29 +++++++++++ 2 files changed, 72 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala index 56b2103c555db..44ac6b5269480 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala @@ -222,7 +222,38 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase { val outerPlan = AnalysisContext.get.outerPlan if (outerPlan.isEmpty) return e - def resolve(nameParts: Seq[String]): Option[Expression] = try { + def findNestedSubqueryPlans(p: LogicalPlan): Seq[LogicalPlan] = { + if (!p.containsPattern(PLAN_EXPRESSION)) { + // There are no nested subquery plans in the current plan, + // so stop searching its child plans + return Seq.empty + } + + val subqueriesInThisNode: Seq[SubqueryExpression] = + p.expressions.flatMap(_.collect { + case in: InSubquery => in.query + case s: SubqueryExpression => s + }) + + val subqueryPlansFromExpressions: Seq[LogicalPlan] = + subqueriesInThisNode.flatMap(s => findNestedSubqueryPlans(s.plan) :+ s.plan) + + val subqueryPlansFromChildren: Seq[LogicalPlan] = + p.children.flatMap(findNestedSubqueryPlans) + + // Subquery plans in more deeply nested positions are collected first. + // As such a plan is closer to the position of the outer reference, it is more likely to have + // the original
attributes. + // Though as there are no conflicts, the order does not affect correctness. + subqueryPlansFromChildren ++ subqueryPlansFromExpressions + } + + // The passed in `outerPlan` is the outermost plan + // Outer references can be from the `outerPlan` or any of its nested subquery plans + // So we need to try resolving the outer references by using all the plans + val outerPlans = Seq(outerPlan.get) ++ findNestedSubqueryPlans(outerPlan.get) + + def resolve(nameParts: Seq[String], outerPlan: Option[LogicalPlan]): Option[Expression] = try { outerPlan.get match { // Subqueries in UnresolvedHaving can host grouping expressions and aggregate functions. // We should resolve columns with `agg.output` and the rule `ResolveAggregateFunctions` will @@ -240,14 +271,21 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase { None } - e.transformWithPruning( - _.containsAnyPattern(UNRESOLVED_ATTRIBUTE, TEMP_RESOLVED_COLUMN)) { + e.transformWithPruning(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, TEMP_RESOLVED_COLUMN)) { case u: UnresolvedAttribute => - resolve(u.nameParts).getOrElse(u) + val maybeResolved = outerPlans.foldLeft(Option.empty[Expression]) { (acc, plan) => + // If we've already resolved, keep that; otherwise try this plan + acc.orElse(resolve(u.nameParts, Some(plan))) + } + maybeResolved.getOrElse(u) // Re-resolves `TempResolvedColumn` as outer references if it has tried to be resolved with // Aggregate but failed. 
case t: TempResolvedColumn if t.hasTried => - resolve(t.nameParts).getOrElse(t) + val maybeResolved = outerPlans.foldLeft(Option.empty[Expression]) { (acc, plan) => + // If we've already resolved, keep that; otherwise try this plan + acc.orElse(resolve(t.nameParts, Some(plan))) + } + maybeResolved.getOrElse(t) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index c92b38d73edc7..98b9da118b6a8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -5014,6 +5014,35 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark } } } + + test("SPARK-50983: Support nested correlated subquery for analyzer") { + sql("Create table Sales (product_id int, amount int)") + sql("create table Products (product_id int, category_id int, price double)") + + val query = + """ + | SELECT s.product_id + | FROM Sales s + | WHERE s.amount > ( + | SELECT AVG(s2.amount) + | FROM Sales s2 + | WHERE s2.product_id IN ( + | SELECT p2.product_id + | FROM Products p2 + | WHERE p2.product_id = s2.product_id + | ) + | ); + |""".stripMargin + + withTable("Sales", "Products") { + withSQLConf( + "spark.sql.optimizer.supportNestedCorrelatedSubqueries.enabled" -> "true", + "spark.sql.planChangeLog.level" -> "info" + ) { + val res = sql(query).collect() + } + } + } } case class Foo(bar: Option[String]) From 1f749ab66c216bb12d7dd82536ba8bcb87dc1456 Mon Sep 17 00:00:00 2001 From: Avery Qi Date: Thu, 30 Jan 2025 15:14:11 -0800 Subject: [PATCH 08/13] remove unnecessary modifications --- .../spark/sql/catalyst/analysis/CheckAnalysis.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala 
index 1bc7275d3e8b8..ce89dd4aaa4ec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -1395,9 +1395,9 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB // +- Project [c1#87, c2#88] // : (Aggregate or Window operator) // : +- Filter [outer(c2#77) >= c2#88)] - // : - SubqueryAlias t2, `t2` - // : - Project [_1#84 AS c1#87, _2#85 AS c2#88] - // : - LocalRelation [_1#84, _2#85] + // : +- SubqueryAlias t2, `t2` + // : +- Project [_1#84 AS c1#87, _2#85 AS c2#88] + // : +- LocalRelation [_1#84, _2#85] // +- SubqueryAlias t1, `t1` // +- Project [_1#73 AS c1#76, _2#74 AS c2#77] // +- LocalRelation [_1#73, _2#74] @@ -1414,7 +1414,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB // Original subquery plan: // Aggregate [count(1)] // +- Filter ((a + b) = outer(c)) - //- LocalRelation [a, b] + // +- LocalRelation [a, b] // // Plan after pulling up correlated predicates: // Aggregate [a, b] [count(1), a, b] @@ -1424,8 +1424,8 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB // Project [c1, count(1)] // +- Join LeftOuter ((a + b) = c) // :- LocalRelation [c] - //- Aggregate [a, b] [count(1), a, b] - // - LocalRelation [a, b] + // +- Aggregate [a, b] [count(1), a, b] + // +- LocalRelation [a, b] // // The right hand side of the join transformed from the subquery will output // count(1) | a | b From 261c002fc74a8c63deb9bf860c208e1513b2a6eb Mon Sep 17 00:00:00 2001 From: Avery Date: Thu, 30 Jan 2025 15:37:33 -0800 Subject: [PATCH 09/13] fix wrong number of arguments --- .../scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 782eba43452bb..dc67aa92c588f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2365,7 +2365,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor val res = resolveSubQuery(e, outer)(Exists(_, _, Seq.empty, exprId)) val unresolvedOuterReferences = getUnresolvedOuterReferences(res, plan) res.withNewUnresolvedOuterAttrs(unresolvedOuterReferences) - case InSubquery(values, l @ ListQuery(_, _, exprId, _, _, _)) + case InSubquery(values, l @ ListQuery(_, _, _, exprId, _, _, _)) if values.forall(_.resolved) && !l.resolved => val expr = resolveSubQuery(l, outer)((plan, exprs) => { ListQuery(plan, exprs, Seq.empty, exprId, plan.output.length) From 1aa444143945dc3f32a8461b52dc2514e637beb8 Mon Sep 17 00:00:00 2001 From: Avery Qi Date: Thu, 30 Jan 2025 15:39:48 -0800 Subject: [PATCH 10/13] add getOuterAttrs function to SubqueryExpression --- .../org/apache/spark/sql/catalyst/expressions/subquery.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index 9d987171446f7..8d8e0cce0812a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -94,6 +94,7 @@ abstract class SubqueryExpression( def withNewHint(hint: Option[HintInfo]): SubqueryExpression def withNewUnresolvedOuterAttrs(unresolvedOuterAttrs: Seq[Expression]): SubqueryExpression def getUnresolvedOuterAttrs: Seq[Expression] = unresolvedOuterAttrs + def getOuterAttrs: Seq[Expression] = outerAttrs } object SubqueryExpression { From 
e0300e8f52789db4fa4e483dfac2146d4535883c Mon Sep 17 00:00:00 2001 From: Avery Qi Date: Thu, 30 Jan 2025 16:10:34 -0800 Subject: [PATCH 11/13] add arguments changes for DynamicPruningSubquery --- .../spark/sql/catalyst/catalog/SQLFunction.scala | 3 ++- .../sql/catalyst/expressions/DynamicPruning.scala | 11 ++++++++++- .../spark/sql/catalyst/optimizer/subquery.scala | 8 ++++---- 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SQLFunction.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SQLFunction.scala index 923373c1856a9..acb36a9db4549 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SQLFunction.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SQLFunction.scala @@ -85,7 +85,8 @@ case class SQLFunction( case (None, Some(Project(expr :: Nil, _: OneRowRelation))) if !isTableFunc => (Some(expr), None) - case (Some(ScalarSubquery(Project(expr :: Nil, _: OneRowRelation), _, _, _, _, _, _)), None) + case (Some(ScalarSubquery( + Project(expr :: Nil, _: OneRowRelation), _, _, _, _, _, _, _)), None) if !isTableFunc => (Some(expr), None) case (_, _) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala index b65576403e9d8..197439e46ca12 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCo import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, LogicalPlan} import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.catalyst.trees.UnaryLike +import org.apache.spark.sql.errors.QueryExecutionErrors 
trait DynamicPruning extends Predicate @@ -47,7 +48,7 @@ case class DynamicPruningSubquery( onlyInBroadcast: Boolean, exprId: ExprId = NamedExpression.newExprId, hint: Option[HintInfo] = None) - extends SubqueryExpression(buildQuery, Seq(pruningKey), exprId, Seq.empty, hint) + extends SubqueryExpression(buildQuery, Seq(pruningKey), Seq.empty, exprId, Seq.empty, hint) with DynamicPruning with Unevaluable with UnaryLike[Expression] { @@ -67,6 +68,14 @@ case class DynamicPruningSubquery( copy() } + override def withNewUnresolvedOuterAttrs( + unresolvedOuterAttrs: Seq[Expression] + ): DynamicPruningSubquery = { + // TODO(avery): create suitable error subclass to throw + // DynamicPruningSubquery should not have this method called on it. + return this + } + override def withNewHint(hint: Option[HintInfo]): SubqueryExpression = copy(hint = hint) override lazy val resolved: Boolean = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index a46b5d83ff35d..6dda96118327b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -131,17 +131,17 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { // Filter the plan by applying left semi and left anti joins. 
withSubquery.foldLeft(newFilter) { - case (p, Exists(sub, _, _, conditions, subHint)) => + case (p, Exists(sub, _, _, _, conditions, subHint)) => val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p) val join = buildJoin(outerPlan, rewriteDomainJoinsIfPresent(outerPlan, sub, joinCond), LeftSemi, joinCond, subHint) Project(p.output, join) - case (p, Not(Exists(sub, _, _, conditions, subHint))) => + case (p, Not(Exists(sub, _, _, _, conditions, subHint))) => val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p) val join = buildJoin(outerPlan, rewriteDomainJoinsIfPresent(outerPlan, sub, joinCond), LeftAnti, joinCond, subHint) Project(p.output, join) - case (p, InSubquery(values, ListQuery(sub, _, _, _, conditions, subHint))) => + case (p, InSubquery(values, ListQuery(sub, _, _, _, _, conditions, subHint))) => // Deduplicate conflicting attributes if any. val newSub = dedupSubqueryOnSelfJoin(p, sub, Some(values)) val inConditions = values.zip(newSub.output).map(EqualTo.tupled) @@ -149,7 +149,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { val join = Join(outerPlan, rewriteDomainJoinsIfPresent(outerPlan, newSub, joinCond), LeftSemi, joinCond, JoinHint(None, subHint)) Project(p.output, join) - case (p, Not(InSubquery(values, ListQuery(sub, _, _, _, conditions, subHint)))) => + case (p, Not(InSubquery(values, ListQuery(sub, _, _, _, _, conditions, subHint)))) => // This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr // Construct the condition. A NULL in one of the conditions is regarded as a positive // result; such a row will be filtered out by the Anti-Join operator. 
From b003c6b4ca60cfaf613b48d8d9a7ffb2624c035b Mon Sep 17 00:00:00 2001 From: Avery Qi Date: Thu, 30 Jan 2025 16:26:50 -0800 Subject: [PATCH 12/13] remove unused import --- .../apache/spark/sql/catalyst/expressions/DynamicPruning.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala index 197439e46ca12..70774be98774e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala @@ -22,7 +22,6 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCo import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, LogicalPlan} import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.catalyst.trees.UnaryLike -import org.apache.spark.sql.errors.QueryExecutionErrors trait DynamicPruning extends Predicate From ebbf523cab36e28f6559fefde504e4985efb1140 Mon Sep 17 00:00:00 2001 From: Avery Qi Date: Thu, 30 Jan 2025 17:19:27 -0800 Subject: [PATCH 13/13] fix wrong number of arguments --- .../spark/sql/execution/command/PlanResolutionSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index 2cc203129817b..726b9920aef31 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -1008,7 +1008,7 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest { query match { case ListQuery(Project(projects, SubqueryAlias(AliasIdentifier("s", Seq()), 
UnresolvedSubqueryColumnAliases(outputColumnNames, Project(_, _: OneRowRelation)))), - _, _, _, _, _) => + _, _, _, _, _, _) => assert(projects.size == 1 && projects.head.name == "s.name") assert(outputColumnNames.size == 1 && outputColumnNames.head == "name") case o => fail("Unexpected subquery: \n" + o.treeString) @@ -1089,7 +1089,7 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest { query match { case ListQuery(Project(projects, SubqueryAlias(AliasIdentifier("s", Seq()), UnresolvedSubqueryColumnAliases(outputColumnNames, Project(_, _: OneRowRelation)))), - _, _, _, _, _) => + _, _, _, _, _, _) => assert(projects.size == 1 && projects.head.name == "s.name") assert(outputColumnNames.size == 1 && outputColumnNames.head == "name") case o => fail("Unexpected subquery: \n" + o.treeString)