diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 48a71cfbf615f..ce3f765851159 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -5906,6 +5906,11 @@
         "Correlated scalar subqueries must be aggregated to return at most one row."
       ]
     },
+    "NESTED_CORRELATED_SUBQUERIES_NOT_SUPPORTED" : {
+      "message" : [
+        "Nested correlated subqueries are not supported."
+      ]
+    },
     "NON_CORRELATED_COLUMNS_IN_GROUP_BY" : {
       "message" : [
         "A GROUP BY clause in a scalar correlated subquery cannot contain non-correlated columns: <value>."
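For context, this is the query shape the new error condition and the analyzer changes below target: a correlated subquery nested inside another correlated subquery, where the innermost predicate references a relation two levels up. A minimal illustration, mirroring the tables used by the tests at the end of this patch:

// The innermost filter references `p.category_id` from the outermost query. The
// middle subqueries over `sales` cannot supply it from their own inputs, so the
// reference must be recorded as an unresolved outer attribute and propagated
// upward. With the new feature flag off, analysis fails with the
// NESTED_CORRELATED_SUBQUERIES_NOT_SUPPORTED condition added above.
spark.sql("""
  SELECT p.product_id FROM products p
  WHERE p.product_id IN (
    SELECT s.product_id FROM sales s
    WHERE s.amount > (SELECT AVG(s2.amount) FROM sales s2
                      WHERE s2.product_id IN (SELECT p2.product_id FROM products p2
                                              WHERE p2.category_id = p.category_id)))
""")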
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 92cfc4119dd0c..dc67aa92c588f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2263,6 +2263,37 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
    * Note: CTEs are handled in CTESubstitution.
    */
   object ResolveSubquery extends Rule[LogicalPlan] {
+
+    private def getOuterAttrsNeedToBePropagated(plan: LogicalPlan): Seq[Expression] = {
+      plan.expressions.flatMap {
+        case subExpr: SubqueryExpression => subExpr.getUnresolvedOuterAttrs
+        case in: InSubquery => in.query.getUnresolvedOuterAttrs
+        case expr if expr.containsPattern(PLAN_EXPRESSION) =>
+          expr.collect {
+            case subExpr: SubqueryExpression => subExpr.getUnresolvedOuterAttrs
+          }.flatten
+        case _ => Seq.empty
+      } ++ plan.children.flatMap {
+        case p if p.containsPattern(PLAN_EXPRESSION) =>
+          getOuterAttrsNeedToBePropagated(p)
+        case _ => Seq.empty
+      }
+    }
+
+    private def getUnresolvedOuterReferences(
+        s: SubqueryExpression, p: LogicalPlan
+    ): Seq[Expression] = {
+      val outerReferencesInSubquery = s.getOuterAttrs
+
+      // Return the outer references that cannot be resolved by the current plan.
+      outerReferencesInSubquery.filter {
+        case a: AttributeReference => !p.inputSet.contains(a)
+        case _ => false
+      }
+    }
+
     /**
      * Resolves the subquery plan that is referenced in a subquery expression, by invoking the
      * entire analyzer recursively. We set outer plan in `AnalysisContext`, so that the analyzer
@@ -2274,15 +2305,23 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
         e: SubqueryExpression,
         outer: LogicalPlan)(
         f: (LogicalPlan, Seq[Expression]) => SubqueryExpression): SubqueryExpression = {
-      val newSubqueryPlan = AnalysisContext.withOuterPlan(outer) {
-        executeSameContext(e.plan)
+      val newSubqueryPlan = if (AnalysisContext.get.outerPlan.isDefined) {
+        val propagatedOuterPlan = AnalysisContext.get.outerPlan.get
+        AnalysisContext.withOuterPlan(propagatedOuterPlan) {
+          executeSameContext(e.plan)
+        }
+      } else {
+        AnalysisContext.withOuterPlan(outer) {
+          executeSameContext(e.plan)
+        }
       }
 
       // If the subquery plan is fully resolved, pull the outer references and record
       // them as children of SubqueryExpression.
       if (newSubqueryPlan.resolved) {
         // Record the outer references as children of subquery expression.
-        f(newSubqueryPlan, SubExprUtils.getOuterReferences(newSubqueryPlan))
+        f(newSubqueryPlan, SubExprUtils.getOuterReferences(newSubqueryPlan) ++
+          getOuterAttrsNeedToBePropagated(newSubqueryPlan))
       } else {
         e.withNewPlan(newSubqueryPlan)
       }
@@ -2299,18 +2338,45 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
      */
     private def resolveSubQueries(plan: LogicalPlan, outer: LogicalPlan): LogicalPlan = {
       plan.transformAllExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION), ruleId) {
-        case s @ ScalarSubquery(sub, _, exprId, _, _, _, _) if !sub.resolved =>
-          resolveSubQuery(s, outer)(ScalarSubquery(_, _, exprId))
-        case e @ Exists(sub, _, exprId, _, _) if !sub.resolved =>
-          resolveSubQuery(e, outer)(Exists(_, _, exprId))
-        case InSubquery(values, l @ ListQuery(_, _, exprId, _, _, _))
+        // There are four kinds of outer references here:
+        // 1. Outer references newly introduced in the subquery `res` that can be
+        //    resolved in the current `plan`. They are extracted by
+        //    `SubExprUtils.getOuterReferences(res.plan)` and stored in `res.outerAttrs`.
+        // 2. Outer references newly introduced in the subquery `res` that cannot be
+        //    resolved in the current `plan`. They are extracted by
+        //    `SubExprUtils.getOuterReferences(res.plan)`, filtered through
+        //    `getUnresolvedOuterReferences(res, plan)`, and stored in
+        //    `res.unresolvedOuterAttrs`.
+        // 3. Outer references introduced by a nested subquery within `res.plan` that
+        //    can be resolved in the current `plan`. They are extracted by
+        //    `getOuterAttrsNeedToBePropagated(res.plan)`, filtered by
+        //    `plan.inputSet.contains(_)`, and need to be stored in `res.outerAttrs`.
+        // 4. Outer references introduced by a nested subquery within `res.plan` that
+        //    cannot be resolved in the current `plan`. They are extracted by
+        //    `getOuterAttrsNeedToBePropagated(res.plan)`, filtered by
+        //    `!plan.inputSet.contains(_)`, and need to be stored in both
+        //    `res.outerAttrs` and `res.unresolvedOuterAttrs`.
+        case s @ ScalarSubquery(sub, _, _, exprId, _, _, _, _) if !sub.resolved =>
+          val res = resolveSubQuery(s, outer)(ScalarSubquery(_, _, Seq.empty, exprId))
+          val unresolvedOuterReferences = getUnresolvedOuterReferences(res, plan)
+          res.withNewUnresolvedOuterAttrs(unresolvedOuterReferences)
+        case e @ Exists(sub, _, _, exprId, _, _) if !sub.resolved =>
+          val res = resolveSubQuery(e, outer)(Exists(_, _, Seq.empty, exprId))
+          val unresolvedOuterReferences = getUnresolvedOuterReferences(res, plan)
+          res.withNewUnresolvedOuterAttrs(unresolvedOuterReferences)
+        case InSubquery(values, l @ ListQuery(_, _, _, exprId, _, _, _))
           if values.forall(_.resolved) && !l.resolved =>
           val expr = resolveSubQuery(l, outer)((plan, exprs) => {
-            ListQuery(plan, exprs, exprId, plan.output.length)
-          })
-          InSubquery(values, expr.asInstanceOf[ListQuery])
-        case s @ LateralSubquery(sub, _, exprId, _, _) if !sub.resolved =>
-          resolveSubQuery(s, outer)(LateralSubquery(_, _, exprId))
+            ListQuery(plan, exprs, Seq.empty, exprId, plan.output.length)
+          }).asInstanceOf[ListQuery]
+          val unresolvedOuterReferences = getUnresolvedOuterReferences(expr, plan)
+          val newExpr = expr.withNewUnresolvedOuterAttrs(unresolvedOuterReferences)
+          InSubquery(values, newExpr)
+        case s @ LateralSubquery(sub, _, _, exprId, _, _) if !sub.resolved =>
+          val res = resolveSubQuery(s, outer)(LateralSubquery(_, _, Seq.empty, exprId))
+          val unresolvedOuterReferences = getUnresolvedOuterReferences(res, plan)
+          res.withNewUnresolvedOuterAttrs(unresolvedOuterReferences)
         case a: FunctionTableSubqueryArgumentExpression if !a.plan.resolved =>
           resolveSubQuery(a, outer)(
             (plan, outerAttrs) => a.copy(plan = plan, outerAttrs = outerAttrs))
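To make the four kinds concrete, here is how the bookkeeping works out for the illustration after the error-condition hunk, as I read the change (attribute lists abridged):

// Innermost IN subquery, resolved against the aggregate over sales s2:
//   outerAttrs = [p.category_id], unresolvedOuterAttrs = [p.category_id]  // kind 2
// Scalar AVG subquery, resolved against the plan over sales s; the reference
// surfaces via getOuterAttrsNeedToBePropagated and still cannot be supplied:
//   outerAttrs = [p.category_id], unresolvedOuterAttrs = [p.category_id]  // kind 4
// Top-level ListQuery, resolved against products p, which can supply it:
//   outerAttrs = [p.category_id], unresolvedOuterAttrs = []               // kind 3
// A reference satisfiable by the immediate parent from the start (for example
// `s.product_id = p.product_id` one level down) is plain kind 1.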
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 269411a6fc616..ce89dd4aaa4ec 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -231,6 +231,32 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
     }
   }
 
+  def checkNoUnresolvedOuterReferencesInMainQuery(plan: LogicalPlan): Unit = {
+    plan.expressions.foreach {
+      case subExpr: SubqueryExpression if subExpr.getUnresolvedOuterAttrs.nonEmpty =>
+        subExpr.failAnalysis(
+          errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.CORRELATED_COLUMN_NOT_FOUND",
+          messageParameters = Map.empty)
+      case in: InSubquery if in.query.getUnresolvedOuterAttrs.nonEmpty =>
+        in.query.failAnalysis(
+          errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.CORRELATED_COLUMN_NOT_FOUND",
+          messageParameters = Map.empty)
+      case expr if expr.containsPattern(PLAN_EXPRESSION) =>
+        expr.collect {
+          case subExpr: SubqueryExpression if subExpr.getUnresolvedOuterAttrs.nonEmpty =>
+            subExpr.failAnalysis(
+              errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.CORRELATED_COLUMN_NOT_FOUND",
+              messageParameters = Map.empty)
+        }
+      case _ =>
+    }
+    plan.children.foreach {
+      case p: LogicalPlan if p.containsPattern(PLAN_EXPRESSION) =>
+        checkNoUnresolvedOuterReferencesInMainQuery(p)
+      case _ =>
+    }
+  }
+
   def checkAnalysis(plan: LogicalPlan): Unit = {
     // We should inline all CTE relations to restore the original plan shape, as the analysis check
     // may need to match certain plan shapes. For dangling CTE relations, they will still be kept
@@ -244,6 +270,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
     }
     preemptedError.clear()
     try {
+      checkNoUnresolvedOuterReferencesInMainQuery(inlinedPlan)
       checkAnalysis0(inlinedPlan)
       preemptedError.getErrorOpt().foreach(throw _) // throw preempted error if any
     } catch {
@@ -1137,6 +1164,20 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
           case _ =>
         }
 
+        def checkUnresolvedOuterReferences(expr: SubqueryExpression): Unit = {
+          if (!SQLConf.get.getConf(SQLConf.SUPPORT_NESTED_CORRELATED_SUBQUERIES) &&
+              expr.getUnresolvedOuterAttrs.nonEmpty) {
+            expr.failAnalysis(
+              errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." +
+                "NESTED_CORRELATED_SUBQUERIES_NOT_SUPPORTED",
+              messageParameters = Map.empty)
+          }
+        }
+
+        // Check if there are nested correlated subqueries in the plan.
+        checkUnresolvedOuterReferences(expr)
+
         // Validate the subquery plan.
         checkAnalysis0(expr.plan)
 
@@ -1144,7 +1185,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
         checkOuterReference(plan, expr)
 
         expr match {
-          case ScalarSubquery(query, outerAttrs, _, _, _, _, _) =>
+          case ScalarSubquery(query, outerAttrs, _, _, _, _, _, _) =>
             // Scalar subquery must return one column as output.
             if (query.output.size != 1) {
               throw QueryCompilationErrors.subqueryReturnMoreThanOneColumn(query.output.size,
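Two layers of enforcement are added in this file, and the distinction is easy to miss: `checkUnresolvedOuterReferences` runs while each subquery expression is validated and rejects lingering `unresolvedOuterAttrs` whenever the feature flag is off, while `checkNoUnresolvedOuterReferencesInMainQuery` runs once against the inlined top-level plan, where no enclosing query remains that could still satisfy a propagated reference, so anything left over is an error regardless of the flag. A hedged sketch of the observable behavior:

// Flag off: the nested query from the earlier illustration fails with
//   UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.NESTED_CORRELATED_SUBQUERIES_NOT_SUPPORTED
// Flag on: the same query analyzes, but a reference that no enclosing plan can
// ever supply would survive to the main query and fail there with
//   UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.CORRELATED_COLUMN_NOT_FOUND
spark.conf.set("spark.sql.optimizer.supportNestedCorrelatedSubqueries.enabled", false)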
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
index 56b2103c555db..44ac6b5269480 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
@@ -222,7 +222,38 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
     val outerPlan = AnalysisContext.get.outerPlan
     if (outerPlan.isEmpty) return e
 
-    def resolve(nameParts: Seq[String]): Option[Expression] = try {
+    def findNestedSubqueryPlans(p: LogicalPlan): Seq[LogicalPlan] = {
+      if (!p.containsPattern(PLAN_EXPRESSION)) {
+        // There are no nested subquery plans in the current plan,
+        // so stop searching its children.
+        return Seq.empty
+      }
+
+      val subqueriesInThisNode: Seq[SubqueryExpression] =
+        p.expressions.flatMap(_.collect {
+          case in: InSubquery => in.query
+          case s: SubqueryExpression => s
+        })
+
+      val subqueryPlansFromExpressions: Seq[LogicalPlan] =
+        subqueriesInThisNode.flatMap(s => findNestedSubqueryPlans(s.plan) :+ s.plan)
+
+      val subqueryPlansFromChildren: Seq[LogicalPlan] =
+        p.children.flatMap(findNestedSubqueryPlans)
+
+      // More deeply nested subquery plans are collected first. Being closer to the
+      // position of the outer reference, they are more likely to contain the
+      // original attributes. Since there are no conflicts, though, the order does
+      // not affect correctness.
+      subqueryPlansFromChildren ++ subqueryPlansFromExpressions
+    }
+
+    // The passed-in `outerPlan` is the outermost plan. Outer references can come
+    // from the `outerPlan` itself or from any of its nested subquery plans, so we
+    // try to resolve them against all of these plans.
+    val outerPlans = Seq(outerPlan.get) ++ findNestedSubqueryPlans(outerPlan.get)
+
+    def resolve(nameParts: Seq[String], outerPlan: Option[LogicalPlan]): Option[Expression] = try {
       outerPlan.get match {
         // Subqueries in UnresolvedHaving can host grouping expressions and aggregate functions.
         // We should resolve columns with `agg.output` and the rule `ResolveAggregateFunctions` will
@@ -240,14 +271,21 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
         None
     }
 
-    e.transformWithPruning(
-      _.containsAnyPattern(UNRESOLVED_ATTRIBUTE, TEMP_RESOLVED_COLUMN)) {
+    e.transformWithPruning(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, TEMP_RESOLVED_COLUMN)) {
       case u: UnresolvedAttribute =>
-        resolve(u.nameParts).getOrElse(u)
+        val maybeResolved = outerPlans.foldLeft(Option.empty[Expression]) { (acc, plan) =>
+          // Keep the first successful resolution; otherwise try the next plan.
+          acc.orElse(resolve(u.nameParts, Some(plan)))
+        }
+        maybeResolved.getOrElse(u)
       // Re-resolves `TempResolvedColumn` as outer references if it has tried to be resolved with
       // Aggregate but failed.
       case t: TempResolvedColumn if t.hasTried =>
-        resolve(t.nameParts).getOrElse(t)
+        val maybeResolved = outerPlans.foldLeft(Option.empty[Expression]) { (acc, plan) =>
+          // Keep the first successful resolution; otherwise try the next plan.
+          acc.orElse(resolve(t.nameParts, Some(plan)))
+        }
+        maybeResolved.getOrElse(t)
     }
   }
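This multi-plan lookup is what lets an inner subquery reference a relation defined by an intermediate query rather than the outermost one. In the second test added at the end of this patch, `s2` is introduced by the middle subquery, yet the innermost subquery filters on it; a hedged reading of the mechanics:

spark.sql("""
  SELECT s.product_id FROM sales s
  WHERE s.amount > (
    SELECT AVG(s2.amount) FROM sales s2
    WHERE s2.product_id IN (SELECT p2.product_id FROM products p2
                            WHERE p2.product_id = s2.product_id))
""")
// While the innermost subquery is analyzed, AnalysisContext still holds the
// outermost plan (over `sales s`), whose inputSet cannot supply `s2.product_id`.
// findNestedSubqueryPlans surfaces the middle plan over `sales s2`, and the
// foldLeft over `outerPlans` resolves the reference against it.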
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SQLFunction.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SQLFunction.scala
index 923373c1856a9..acb36a9db4549 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SQLFunction.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SQLFunction.scala
@@ -85,7 +85,8 @@ case class SQLFunction(
       case (None, Some(Project(expr :: Nil, _: OneRowRelation))) if !isTableFunc =>
         (Some(expr), None)
-      case (Some(ScalarSubquery(Project(expr :: Nil, _: OneRowRelation), _, _, _, _, _, _)), None)
+      case (Some(ScalarSubquery(
+          Project(expr :: Nil, _: OneRowRelation), _, _, _, _, _, _, _)), None)
           if !isTableFunc =>
         (Some(expr), None)
       case (_, _) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala
index b65576403e9d8..70774be98774e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala
@@ -47,7 +47,7 @@ case class DynamicPruningSubquery(
     onlyInBroadcast: Boolean,
     exprId: ExprId = NamedExpression.newExprId,
     hint: Option[HintInfo] = None)
-  extends SubqueryExpression(buildQuery, Seq(pruningKey), exprId, Seq.empty, hint)
+  extends SubqueryExpression(buildQuery, Seq(pruningKey), Seq.empty, exprId, Seq.empty, hint)
   with DynamicPruning
   with Unevaluable
   with UnaryLike[Expression] {
@@ -67,6 +67,14 @@ case class DynamicPruningSubquery(
     copy()
   }
 
+  override def withNewUnresolvedOuterAttrs(
+      unresolvedOuterAttrs: Seq[Expression]
+  ): DynamicPruningSubquery = {
+    // TODO(avery): create a suitable error subclass to throw;
+    // this method should never be called on DynamicPruningSubquery.
+    this
+  }
+
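Dynamic pruning subqueries are generated by the planner and are never correlated with an enclosing query, so the TODO above could plausibly be filled with an internal error rather than a user-facing condition. A sketch only; the final error subclass is still to be decided:

import org.apache.spark.SparkException

override def withNewUnresolvedOuterAttrs(
    unresolvedOuterAttrs: Seq[Expression]): DynamicPruningSubquery = {
  // Receiving unresolved outer attributes here would indicate an analyzer bug.
  if (unresolvedOuterAttrs.nonEmpty) {
    throw SparkException.internalError(
      "DynamicPruningSubquery cannot carry unresolved outer attributes: " +
        unresolvedOuterAttrs.mkString(", "))
  }
  this
}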
   override def withNewHint(hint: Option[HintInfo]): SubqueryExpression = copy(hint = hint)
 
   override lazy val resolved: Boolean = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FunctionTableSubqueryArgumentExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FunctionTableSubqueryArgumentExpression.scala
index bfd3bc8051dff..9e5bacc0a82fa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FunctionTableSubqueryArgumentExpression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FunctionTableSubqueryArgumentExpression.scala
@@ -67,12 +67,14 @@ import org.apache.spark.sql.types.DataType
 case class FunctionTableSubqueryArgumentExpression(
     plan: LogicalPlan,
     outerAttrs: Seq[Expression] = Seq.empty,
+    unresolvedOuterAttrs: Seq[Expression] = Seq.empty,
     exprId: ExprId = NamedExpression.newExprId,
     partitionByExpressions: Seq[Expression] = Seq.empty,
     withSinglePartition: Boolean = false,
     orderByExpressions: Seq[SortOrder] = Seq.empty,
     selectedInputExpressions: Seq[PythonUDTFSelectedExpression] = Seq.empty)
-  extends SubqueryExpression(plan, outerAttrs, exprId, Seq.empty, None) with Unevaluable {
+  extends SubqueryExpression(
+    plan, outerAttrs, unresolvedOuterAttrs, exprId, Seq.empty, None) with Unevaluable {
 
   assert(!(withSinglePartition && partitionByExpressions.nonEmpty),
     "WITH SINGLE PARTITION is mutually exclusive with PARTITION BY")
@@ -83,6 +85,14 @@ case class FunctionTableSubqueryArgumentExpression(
     copy(plan = plan)
   override def withNewOuterAttrs(outerAttrs: Seq[Expression])
     : FunctionTableSubqueryArgumentExpression = copy(outerAttrs = outerAttrs)
+  override def withNewUnresolvedOuterAttrs(
+      unresolvedOuterAttrs: Seq[Expression]
+  ): FunctionTableSubqueryArgumentExpression = {
+    if (!unresolvedOuterAttrs.forall(outerAttrs.contains(_))) {
+      // TODO(avery): create a suitable error subclass to throw
+    }
+    copy(unresolvedOuterAttrs = unresolvedOuterAttrs)
+  }
   override def hint: Option[HintInfo] = None
   override def withNewHint(hint: Option[HintInfo]): FunctionTableSubqueryArgumentExpression =
     copy()
@@ -91,6 +101,7 @@ case class FunctionTableSubqueryArgumentExpression(
     FunctionTableSubqueryArgumentExpression(
       plan.canonicalized,
       outerAttrs.map(_.canonicalized),
+      unresolvedOuterAttrs.map(_.canonicalized),
       ExprId(0),
       partitionByExpressions,
       withSinglePartition,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
index c0a2bf25fbe67..8d8e0cce0812a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
@@ -67,6 +67,8 @@ abstract class PlanExpression[T <: QueryPlan[_]] extends Expression {
  *
  * @param plan: the subquery plan
  * @param outerAttrs: the outer references in the subquery plan
+ * @param unresolvedOuterAttrs: the outer references in the subquery plan that
+ *                              cannot be resolved by its immediate parent plan
  * @param exprId: ID of the expression
  * @param joinCond: the join conditions with the outer query. It contains both inner and outer
  *                  query references.
@@ -76,18 +78,23 @@ abstract class PlanExpression[T <: QueryPlan[_]] extends Expression {
 abstract class SubqueryExpression(
     plan: LogicalPlan,
     outerAttrs: Seq[Expression],
+    unresolvedOuterAttrs: Seq[Expression],
     exprId: ExprId,
     joinCond: Seq[Expression],
     hint: Option[HintInfo]) extends PlanExpression[LogicalPlan] {
   override lazy val resolved: Boolean = childrenResolved && plan.resolved
   override lazy val references: AttributeSet =
-    AttributeSet.fromAttributeSets(outerAttrs.map(_.references))
+    AttributeSet.fromAttributeSets(outerAttrs.map(_.references)) --
+      AttributeSet.fromAttributeSets(unresolvedOuterAttrs.map(_.references))
   override def children: Seq[Expression] = outerAttrs ++ joinCond
   override def withNewPlan(plan: LogicalPlan): SubqueryExpression
   def withNewOuterAttrs(outerAttrs: Seq[Expression]): SubqueryExpression
   def isCorrelated: Boolean = outerAttrs.nonEmpty
   def hint: Option[HintInfo]
   def withNewHint(hint: Option[HintInfo]): SubqueryExpression
+  def withNewUnresolvedOuterAttrs(unresolvedOuterAttrs: Seq[Expression]): SubqueryExpression
+  def getUnresolvedOuterAttrs: Seq[Expression] = unresolvedOuterAttrs
+  def getOuterAttrs: Seq[Expression] = outerAttrs
 }
 
 object SubqueryExpression {
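The `references` change above is the crux of the propagation scheme: attributes in `unresolvedOuterAttrs` are excluded, so a plan node hosting the subquery does not demand them from its own children, and they stay pending until an enclosing query can produce them. Each concrete `withNewUnresolvedOuterAttrs` below also repeats the same unfilled subset-validation TODO; one hedged sketch of it, shown on `Exists` and assuming `SparkException.internalError` (the eventual error subclass is undecided):

import org.apache.spark.SparkException

override def withNewUnresolvedOuterAttrs(unresolvedOuterAttrs: Seq[Expression]): Exists = {
  // unresolvedOuterAttrs is by construction a subset of outerAttrs; anything else
  // means the analyzer recorded a reference it never extracted.
  if (!unresolvedOuterAttrs.forall(outerAttrs.contains(_))) {
    throw SparkException.internalError(
      "unresolvedOuterAttrs must be a subset of outerAttrs, got: " +
        unresolvedOuterAttrs.mkString(", "))
  }
  copy(unresolvedOuterAttrs = unresolvedOuterAttrs)
}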
@@ -395,12 +402,14 @@ object SubExprUtils extends PredicateHelper {
 case class ScalarSubquery(
     plan: LogicalPlan,
     outerAttrs: Seq[Expression] = Seq.empty,
+    unresolvedOuterAttrs: Seq[Expression] = Seq.empty,
     exprId: ExprId = NamedExpression.newExprId,
     joinCond: Seq[Expression] = Seq.empty,
     hint: Option[HintInfo] = None,
     mayHaveCountBug: Option[Boolean] = None,
     needSingleJoin: Option[Boolean] = None)
-  extends SubqueryExpression(plan, outerAttrs, exprId, joinCond, hint) with Unevaluable {
+  extends SubqueryExpression(
+    plan, outerAttrs, unresolvedOuterAttrs, exprId, joinCond, hint) with Unevaluable {
   override def dataType: DataType = {
     if (!plan.schema.fields.nonEmpty) {
       throw QueryCompilationErrors.subqueryReturnMoreThanOneColumn(plan.schema.fields.length,
@@ -412,12 +421,21 @@ case class ScalarSubquery(
   override def withNewPlan(plan: LogicalPlan): ScalarSubquery = copy(plan = plan)
   override def withNewOuterAttrs(outerAttrs: Seq[Expression]): ScalarSubquery = copy(
     outerAttrs = outerAttrs)
+  override def withNewUnresolvedOuterAttrs(
+      unresolvedOuterAttrs: Seq[Expression]
+  ): ScalarSubquery = {
+    if (!unresolvedOuterAttrs.forall(outerAttrs.contains(_))) {
+      // TODO(avery): create a suitable error subclass to throw
+    }
+    copy(unresolvedOuterAttrs = unresolvedOuterAttrs)
+  }
   override def withNewHint(hint: Option[HintInfo]): ScalarSubquery = copy(hint = hint)
   override def toString: String = s"scalar-subquery#${exprId.id} $conditionString"
   override lazy val canonicalized: Expression = {
     ScalarSubquery(
       plan.canonicalized,
       outerAttrs.map(_.canonicalized),
+      unresolvedOuterAttrs.map(_.canonicalized),
       ExprId(0),
       joinCond.map(_.canonicalized))
   }
@@ -457,21 +475,32 @@ case class UnresolvedScalarSubqueryPlanId(planId: Long)
 case class LateralSubquery(
     plan: LogicalPlan,
     outerAttrs: Seq[Expression] = Seq.empty,
+    unresolvedOuterAttrs: Seq[Expression] = Seq.empty,
     exprId: ExprId = NamedExpression.newExprId,
     joinCond: Seq[Expression] = Seq.empty,
     hint: Option[HintInfo] = None)
-  extends SubqueryExpression(plan, outerAttrs, exprId, joinCond, hint) with Unevaluable {
+  extends SubqueryExpression(
+    plan, outerAttrs, unresolvedOuterAttrs, exprId, joinCond, hint) with Unevaluable {
   override def dataType: DataType = plan.output.toStructType
   override def nullable: Boolean = true
   override def withNewPlan(plan: LogicalPlan): LateralSubquery = copy(plan = plan)
   override def withNewOuterAttrs(outerAttrs: Seq[Expression]): LateralSubquery = copy(
     outerAttrs = outerAttrs)
+  override def withNewUnresolvedOuterAttrs(
+      unresolvedOuterAttrs: Seq[Expression]
+  ): LateralSubquery = {
+    if (!unresolvedOuterAttrs.forall(outerAttrs.contains(_))) {
+      // TODO(avery): create a suitable error subclass to throw
+    }
+    copy(unresolvedOuterAttrs = unresolvedOuterAttrs)
+  }
   override def withNewHint(hint: Option[HintInfo]): LateralSubquery = copy(hint = hint)
   override def toString: String = s"lateral-subquery#${exprId.id} $conditionString"
   override lazy val canonicalized: Expression = {
     LateralSubquery(
       plan.canonicalized,
       outerAttrs.map(_.canonicalized),
+      unresolvedOuterAttrs.map(_.canonicalized),
       ExprId(0),
       joinCond.map(_.canonicalized))
   }
@@ -500,13 +529,15 @@ case class LateralSubquery(
 case class ListQuery(
     plan: LogicalPlan,
     outerAttrs: Seq[Expression] = Seq.empty,
+    unresolvedOuterAttrs: Seq[Expression] = Seq.empty,
     exprId: ExprId = NamedExpression.newExprId,
     // The plan of list query may have more columns after de-correlation, and we need to track the
     // number of the columns of the original plan, to report the data type properly.
     numCols: Int = -1,
     joinCond: Seq[Expression] = Seq.empty,
     hint: Option[HintInfo] = None)
-  extends SubqueryExpression(plan, outerAttrs, exprId, joinCond, hint) with Unevaluable {
+  extends SubqueryExpression(
+    plan, outerAttrs, unresolvedOuterAttrs, exprId, joinCond, hint) with Unevaluable {
   def childOutputs: Seq[Attribute] = plan.output.take(numCols)
   override def dataType: DataType = if (numCols > 1) {
     childOutputs.toStructType
@@ -526,12 +557,21 @@ case class ListQuery(
   override def withNewPlan(plan: LogicalPlan): ListQuery = copy(plan = plan)
   override def withNewOuterAttrs(outerAttrs: Seq[Expression]): ListQuery = copy(
     outerAttrs = outerAttrs)
+  override def withNewUnresolvedOuterAttrs(
+      unresolvedOuterAttrs: Seq[Expression]
+  ): ListQuery = {
+    if (!unresolvedOuterAttrs.forall(outerAttrs.contains(_))) {
+      // TODO(avery): create a suitable error subclass to throw
+    }
+    copy(unresolvedOuterAttrs = unresolvedOuterAttrs)
+  }
   override def withNewHint(hint: Option[HintInfo]): ListQuery = copy(hint = hint)
   override def toString: String = s"list#${exprId.id} $conditionString"
   override lazy val canonicalized: Expression = {
     ListQuery(
       plan.canonicalized,
       outerAttrs.map(_.canonicalized),
+      unresolvedOuterAttrs.map(_.canonicalized),
       ExprId(0),
       numCols,
       joinCond.map(_.canonicalized))
@@ -574,22 +614,32 @@ case class ListQuery(
 case class Exists(
     plan: LogicalPlan,
     outerAttrs: Seq[Expression] = Seq.empty,
+    unresolvedOuterAttrs: Seq[Expression] = Seq.empty,
     exprId: ExprId = NamedExpression.newExprId,
     joinCond: Seq[Expression] = Seq.empty,
     hint: Option[HintInfo] = None)
-  extends SubqueryExpression(plan, outerAttrs, exprId, joinCond, hint)
+  extends SubqueryExpression(plan, outerAttrs, unresolvedOuterAttrs, exprId, joinCond, hint)
   with Predicate
   with Unevaluable {
   override def nullable: Boolean = false
   override def withNewPlan(plan: LogicalPlan): Exists = copy(plan = plan)
   override def withNewOuterAttrs(outerAttrs: Seq[Expression]): Exists = copy(
     outerAttrs = outerAttrs)
+  override def withNewUnresolvedOuterAttrs(
+      unresolvedOuterAttrs: Seq[Expression]
+  ): Exists = {
+    if (!unresolvedOuterAttrs.forall(outerAttrs.contains(_))) {
+      // TODO(avery): create a suitable error subclass to throw
+    }
+    copy(unresolvedOuterAttrs = unresolvedOuterAttrs)
+  }
   override def withNewHint(hint: Option[HintInfo]): Exists = copy(hint = hint)
   override def toString: String = s"exists#${exprId.id} $conditionString"
   override lazy val canonicalized: Expression = {
     Exists(
       plan.canonicalized,
       outerAttrs.map(_.canonicalized),
+      unresolvedOuterAttrs.map(_.canonicalized),
       ExprId(0),
       joinCond.map(_.canonicalized))
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 9d269f37e58b9..c2738b736c4c4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -364,7 +364,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
       case d: DynamicPruningSubquery => d
       case s @ ScalarSubquery(
         PhysicalOperation(projections, predicates, a @ Aggregate(group, _, child, _)),
-        _, _, _, _, mayHaveCountBug, _)
+        _, _, _, _, _, mayHaveCountBug, _)
         if conf.getConf(SQLConf.DECORRELATE_SUBQUERY_PREVENT_CONSTANT_FOLDING_FOR_COUNT_BUG) &&
           mayHaveCountBug.nonEmpty && mayHaveCountBug.get =>
         // This is a subquery with an aggregate that may suffer from a COUNT bug.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index e867953bcf282..e6d553e80d4eb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -90,7 +90,7 @@ object ConstantFolding extends Rule[LogicalPlan] {
     }
 
     // Don't replace ScalarSubquery if its plan is an aggregate that may suffer from a COUNT bug.
-    case s @ ScalarSubquery(_, _, _, _, _, mayHaveCountBug, _)
+    case s @ ScalarSubquery(_, _, _, _, _, _, mayHaveCountBug, _)
       if conf.getConf(SQLConf.DECORRELATE_SUBQUERY_PREVENT_CONSTANT_FOLDING_FOR_COUNT_BUG) &&
         mayHaveCountBug.nonEmpty && mayHaveCountBug.get => s
@@ -891,7 +891,7 @@ object NullPropagation extends Rule[LogicalPlan] {
     case InSubquery(Seq(Literal(null, _)), _)
       if SQLConf.get.legacyNullInEmptyBehavior =>
       Literal.create(null, BooleanType)
-    case InSubquery(Seq(Literal(null, _)), ListQuery(sub, _, _, _, conditions, _))
+    case InSubquery(Seq(Literal(null, _)), ListQuery(sub, _, _, _, _, conditions, _))
       if !SQLConf.get.legacyNullInEmptyBehavior && conditions.isEmpty =>
       If(Exists(sub), Literal(null, BooleanType), FalseLiteral)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index 5a4e9f37c3951..6dda96118327b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -131,17 +131,17 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
 
     // Filter the plan by applying left semi and left anti joins.
     withSubquery.foldLeft(newFilter) {
-      case (p, Exists(sub, _, _, conditions, subHint)) =>
+      case (p, Exists(sub, _, _, _, conditions, subHint)) =>
         val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
         val join = buildJoin(outerPlan, rewriteDomainJoinsIfPresent(outerPlan, sub, joinCond),
           LeftSemi, joinCond, subHint)
         Project(p.output, join)
-      case (p, Not(Exists(sub, _, _, conditions, subHint))) =>
+      case (p, Not(Exists(sub, _, _, _, conditions, subHint))) =>
         val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
         val join = buildJoin(outerPlan, rewriteDomainJoinsIfPresent(outerPlan, sub, joinCond),
           LeftAnti, joinCond, subHint)
         Project(p.output, join)
-      case (p, InSubquery(values, ListQuery(sub, _, _, _, conditions, subHint))) =>
+      case (p, InSubquery(values, ListQuery(sub, _, _, _, _, conditions, subHint))) =>
         // Deduplicate conflicting attributes if any.
         val newSub = dedupSubqueryOnSelfJoin(p, sub, Some(values))
         val inConditions = values.zip(newSub.output).map(EqualTo.tupled)
@@ -149,7 +149,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
         val join = Join(outerPlan, rewriteDomainJoinsIfPresent(outerPlan, newSub, joinCond),
           LeftSemi, joinCond, JoinHint(None, subHint))
         Project(p.output, join)
-      case (p, Not(InSubquery(values, ListQuery(sub, _, _, _, conditions, subHint)))) =>
+      case (p, Not(InSubquery(values, ListQuery(sub, _, _, _, _, conditions, subHint)))) =>
         // This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr
         // Construct the condition. A NULL in one of the conditions is regarded as a positive
         // result; such a row will be filtered out by the Anti-Join operator.
@@ -319,7 +319,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
     val introducedAttrs = ArrayBuffer.empty[Attribute]
     val newExprs = exprs.map { e =>
       e.transformDownWithPruning(_.containsAnyPattern(EXISTS_SUBQUERY, IN_SUBQUERY)) {
-        case Exists(sub, _, _, conditions, subHint) =>
+        case Exists(sub, _, _, _, conditions, subHint) =>
           val exists = AttributeReference("exists", BooleanType, nullable = false)()
           val existenceJoin = ExistenceJoin(exists)
           val newCondition = conditions.reduceLeftOption(And)
@@ -328,7 +328,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
             existenceJoin, newCondition, subHint)
           introducedAttrs += exists
           exists
-        case Not(InSubquery(values, ListQuery(sub, _, _, _, conditions, subHint))) =>
+        case Not(InSubquery(values, ListQuery(sub, _, _, _, _, conditions, subHint))) =>
           val exists = AttributeReference("exists", BooleanType, nullable = false)()
           // Deduplicate conflicting attributes if any.
           val newSub = dedupSubqueryOnSelfJoin(newPlan, sub, Some(values))
@@ -353,7 +353,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
             ExistenceJoin(exists), Some(finalJoinCond), joinHint)
           introducedAttrs += exists
           Not(exists)
-        case InSubquery(values, ListQuery(sub, _, _, _, conditions, subHint)) =>
+        case InSubquery(values, ListQuery(sub, _, _, _, _, conditions, subHint)) =>
           val exists = AttributeReference("exists", BooleanType, nullable = false)()
           // Deduplicate conflicting attributes if any.
           val newSub = dedupSubqueryOnSelfJoin(newPlan, sub, Some(values))
@@ -506,7 +506,7 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
     }
 
     plan.transformExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION)) {
-      case ScalarSubquery(sub, children, exprId, conditions, hint,
+      case ScalarSubquery(sub, children, unresolvedOuterAttrs, exprId, conditions, hint,
           mayHaveCountBugOld, needSingleJoinOld)
           if children.nonEmpty =>
@@ -558,26 +558,32 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
         } else {
           conf.getConf(SQLConf.SCALAR_SUBQUERY_USE_SINGLE_JOIN) && !guaranteedToReturnOneRow(sub)
         }
-        ScalarSubquery(newPlan, children, exprId, getJoinCondition(newCond, conditions),
+        ScalarSubquery(newPlan, children, unresolvedOuterAttrs,
+          exprId, getJoinCondition(newCond, conditions),
           hint, Some(mayHaveCountBug), Some(needSingleJoin))
-      case Exists(sub, children, exprId, conditions, hint) if children.nonEmpty =>
+      case Exists(sub, children, unresolvedOuterAttrs, exprId, conditions, hint)
+          if children.nonEmpty =>
         val (newPlan, newCond) = if (SQLConf.get.decorrelateInnerQueryEnabledForExistsIn) {
           decorrelate(sub, plan, handleCountBug = true)
         } else {
           pullOutCorrelatedPredicates(sub, plan)
         }
-        Exists(newPlan, children, exprId, getJoinCondition(newCond, conditions), hint)
-      case ListQuery(sub, children, exprId, numCols, conditions, hint) if children.nonEmpty =>
+        Exists(newPlan, children, unresolvedOuterAttrs,
+          exprId, getJoinCondition(newCond, conditions), hint)
+      case ListQuery(sub, children, unresolvedOuterAttrs, exprId, numCols, conditions, hint)
+          if children.nonEmpty =>
         val (newPlan, newCond) = if (SQLConf.get.decorrelateInnerQueryEnabledForExistsIn) {
           decorrelate(sub, plan, handleCountBug = true)
         } else {
           pullOutCorrelatedPredicates(sub, plan)
         }
         val joinCond = getJoinCondition(newCond, conditions)
-        ListQuery(newPlan, children, exprId, numCols, joinCond, hint)
-      case LateralSubquery(sub, children, exprId, conditions, hint) if children.nonEmpty =>
+        ListQuery(newPlan, children, unresolvedOuterAttrs, exprId, numCols, joinCond, hint)
+      case LateralSubquery(sub, children, unresolvedOuterAttrs, exprId, conditions, hint)
+          if children.nonEmpty =>
         val (newPlan, newCond) = decorrelate(sub, plan, handleCountBug = true)
-        LateralSubquery(newPlan, children, exprId, getJoinCondition(newCond, conditions), hint)
+        LateralSubquery(newPlan, children,
+          unresolvedOuterAttrs, exprId, getJoinCondition(newCond, conditions), hint)
     }
   }
@@ -817,7 +823,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe
       subqueries: ArrayBuffer[ScalarSubquery]): (LogicalPlan, AttributeMap[Attribute]) = {
     val subqueryAttrMapping = ArrayBuffer[(Attribute, Attribute)]()
     val newChild = subqueries.foldLeft(child) {
-      case (currentChild, ScalarSubquery(sub, _, _, conditions, subHint, mayHaveCountBug,
+      case (currentChild, ScalarSubquery(sub, _, _, _, conditions, subHint, mayHaveCountBug,
           needSingleJoin)) =>
         val query = DecorrelateInnerQuery.rewriteDomainJoins(currentChild, sub, conditions)
         val origOutput = query.output.head
@@ -1004,7 +1010,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe
 object RewriteLateralSubquery extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
     _.containsPattern(LATERAL_JOIN)) {
-    case LateralJoin(left, LateralSubquery(sub, _, _, joinCond, subHint), joinType, condition) =>
+    case LateralJoin(left, LateralSubquery(sub, _, _, _, joinCond, subHint),
        joinType, condition) =>
       val newRight = DecorrelateInnerQuery.rewriteDomainJoins(left, sub, joinCond)
       val newCond = (condition ++ joinCond).reduceOption(And)
       // The subquery appears on the right side of the join, hence add the hint to the right side
@@ -1041,7 +1047,7 @@ object OptimizeOneRowRelationSubquery extends Rule[LogicalPlan] {
    */
   private def rewrite(plan: LogicalPlan): LogicalPlan = plan.transformUpWithSubqueries {
     case LateralJoin(
-        left, right @ LateralSubquery(OneRowSubquery(plan), _, _, _, _), _, None)
+        left, right @ LateralSubquery(OneRowSubquery(plan), _, _, _, _, _), _, None)
         if !hasCorrelatedSubquery(right.plan) && right.joinCond.isEmpty =>
       plan match {
         case Project(projectList, _: OneRowRelation) =>
@@ -1064,7 +1070,8 @@ object OptimizeOneRowRelationSubquery extends Rule[LogicalPlan] {
     case p: LogicalPlan => p.transformExpressionsUpWithPruning(
       _.containsPattern(SCALAR_SUBQUERY)) {
-      case s @ ScalarSubquery(OneRowSubquery(p @ Project(_, _: OneRowRelation)), _, _, _, _, _, _)
+      case s @ ScalarSubquery(
+          OneRowSubquery(p @ Project(_, _: OneRowRelation)), _, _, _, _, _, _, _)
         if !hasCorrelatedSubquery(s.plan) && s.joinCond.isEmpty =>
         assert(p.projectList.size == 1)
         stripOuterReferences(p.projectList).head
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index d4298087b2108..a12cbd991c208 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -4009,6 +4009,14 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
+  val SUPPORT_NESTED_CORRELATED_SUBQUERIES =
+    buildConf("spark.sql.optimizer.supportNestedCorrelatedSubqueries.enabled")
+      .internal()
+      .doc("If enabled, the analyzer supports nested correlated subqueries.")
+      .version("4.1.0")
+      .booleanConf
+      .createWithDefault(false)
+
   val PULL_OUT_NESTED_DATA_OUTER_REF_EXPRESSIONS_ENABLED =
     buildConf("spark.sql.optimizer.pullOutNestedDataOuterRefExpressions.enabled")
       .internal()
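The flag is internal and off by default, so the new path is strictly opt-in per session. A usage sketch, reusing the tables from the earlier illustration (the flag-on run mirrors the second test below):

// Enable the experimental path for this session only; with the default (false),
// the same statement fails with NESTED_CORRELATED_SUBQUERIES_NOT_SUPPORTED.
spark.conf.set("spark.sql.optimizer.supportNestedCorrelatedSubqueries.enabled", true)
spark.sql("""
  SELECT p.product_id FROM products p
  WHERE EXISTS (SELECT 1 FROM sales s
                WHERE s.amount > (SELECT AVG(s2.amount) FROM sales s2
                                  WHERE s2.product_id = p.product_id))
""").show()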
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala
@@ -30,11 +30,11 @@ case class PlanAdaptiveSubqueries(
   def apply(plan: SparkPlan): SparkPlan = {
     plan.transformAllExpressionsWithPruning(
       _.containsAnyPattern(SCALAR_SUBQUERY, IN_SUBQUERY, DYNAMIC_PRUNING_SUBQUERY)) {
-      case expressions.ScalarSubquery(_, _, exprId, _, _, _, _) =>
+      case expressions.ScalarSubquery(_, _, _, exprId, _, _, _, _) =>
         val subquery = SubqueryExec.createForScalarSubquery(
           s"subquery#${exprId.id}", subqueryMap(exprId.id))
         execution.ScalarSubquery(subquery, exprId)
-      case expressions.InSubquery(values, ListQuery(_, _, exprId, _, _, _)) =>
+      case expressions.InSubquery(values, ListQuery(_, _, _, exprId, _, _, _)) =>
         val expr = if (values.length == 1) {
           values.head
         } else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
index 771c743f70629..32468c84cc9df 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
@@ -187,7 +187,7 @@ case class PlanSubqueries(sparkSession: SparkSession) extends Rule[SparkPlan] {
           SubqueryExec.createForScalarSubquery(
             s"scalar-subquery#${subquery.exprId.id}", executedPlan),
           subquery.exprId)
-      case expressions.InSubquery(values, ListQuery(query, _, exprId, _, _, _)) =>
+      case expressions.InSubquery(values, ListQuery(query, _, _, exprId, _, _, _)) =>
         val expr = if (values.length == 1) {
           values.head
         } else {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 1b8f596a999b7..98b9da118b6a8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -4941,6 +4941,108 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
     Row(Array(0), Array(0)), Row(Array(1), Array(1)), Row(Array(2), Array(2)))
     checkAnswer(df, expectedAnswer)
   }
+
+  test("SPARK-50983: Support nested correlated subquery") {
+    withTable("Sales", "Products") {
+      sql("create table Sales (product_id int, amount int)")
+      sql("create table Products (product_id int, category_id int, price double)")
+      sql("insert into Sales values (1, 10), (2, 20), (3, 30), (4, 40), (5, 50)")
+      sql(
+        "insert into " +
+        "Products values (1, 1, 100.0), (2, 1, 200.0)," +
+        " (3, 2, 300.0), (4, 2, 400.0), (5, 3, 500.0)"
+      )
+      val query =
+        """
+          |SELECT p.product_id, p.category_id, p.price
+          |FROM Products p
+          |WHERE p.product_id IN (
+          |  SELECT s.product_id
+          |  FROM Sales s
+          |  WHERE s.amount > (
+          |    SELECT AVG(s2.amount)
+          |    FROM Sales s2
+          |    WHERE s2.product_id IN (
+          |      SELECT p2.product_id
+          |      FROM Products p2
+          |      WHERE p2.category_id = p.category_id
+          |    )
+          |  )
+          |);
+          |""".stripMargin
+
+      val query2 =
+        """
+          |SELECT p.product_id, p.category_id, p.price
+          |FROM Products p
+          |WHERE p.product_id IN (
+          |  SELECT s.product_id
+          |  FROM Sales s
+          |  WHERE s.product_id = p.product_id
+          |);
+          |""".stripMargin
+
+      val query3 =
+        """
+          |SELECT p.product_id, p.category_id, p.price
+          |FROM Products p
+          |WHERE p.product_id IN (
+          |  SELECT s.product_id
+          |  FROM Sales s
+          |  WHERE s.amount > (
+          |    SELECT AVG(s2.amount)
+          |    FROM Sales s2
+          |    WHERE s2.product_id IN (
+          |      SELECT p2.product_id
+          |      FROM Products p2
+          |    )
+          |  )
+          |);
+          |""".stripMargin
+
+      withSQLConf(
+        "spark.sql.planChangeLog.level" -> "info",
+        "spark.sql.optimizer.supportNestedCorrelatedSubqueries.enabled" -> "false"
+      ) {
+        val e = intercept[AnalysisException] {
+          sql(query).collect()
+        }
+        assert(e.errorClass.contains(
+          "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.NESTED_CORRELATED_SUBQUERIES_NOT_SUPPORTED"))
+      }
+    }
+  }
+
+  test("SPARK-50983: Support nested correlated subquery for analyzer") {
+    val query =
+      """
+        | SELECT s.product_id
+        | FROM Sales s
+        | WHERE s.amount > (
+        |   SELECT AVG(s2.amount)
+        |   FROM Sales s2
+        |   WHERE s2.product_id IN (
+        |     SELECT p2.product_id
+        |     FROM Products p2
+        |     WHERE p2.product_id = s2.product_id
+        |   )
+        | );
+        |""".stripMargin
+
+    withTable("Sales", "Products") {
+      sql("create table Sales (product_id int, amount int)")
+      sql("create table Products (product_id int, category_id int, price double)")
+      withSQLConf(
+        "spark.sql.optimizer.supportNestedCorrelatedSubqueries.enabled" -> "true",
+        "spark.sql.planChangeLog.level" -> "info"
+      ) {
+        sql(query).collect()
+      }
+    }
+  }
 }
 
 case class Foo(bar: Option[String])
diff
--git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
index 2cc203129817b..726b9920aef31 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
@@ -1008,7 +1008,7 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest {
       query match {
         case ListQuery(Project(projects, SubqueryAlias(AliasIdentifier("s", Seq()),
             UnresolvedSubqueryColumnAliases(outputColumnNames, Project(_, _: OneRowRelation)))),
-            _, _, _, _, _) =>
+            _, _, _, _, _, _) =>
           assert(projects.size == 1 && projects.head.name == "s.name")
           assert(outputColumnNames.size == 1 && outputColumnNames.head == "name")
         case o => fail("Unexpected subquery: \n" + o.treeString)
@@ -1089,7 +1089,7 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest {
       query match {
         case ListQuery(Project(projects, SubqueryAlias(AliasIdentifier("s", Seq()),
             UnresolvedSubqueryColumnAliases(outputColumnNames, Project(_, _: OneRowRelation)))),
-            _, _, _, _, _) =>
+            _, _, _, _, _, _) =>
           assert(projects.size == 1 && projects.head.name == "s.name")
           assert(outputColumnNames.size == 1 && outputColumnNames.head == "name")
         case o => fail("Unexpected subquery: \n" + o.treeString)