Skip to content

Commit

Permalink
[SPARK-51259][SQL] Refactor natural and using join keys computation
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Refactor natural and using join key computation to a separate component so that it can be reused in single-pass resolver.

### Why are the changes needed?
To reuse code in single-pass resolver.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Existing tests.

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #50009 from mihailotim-db/mihailotim-db/join_refactor.

Authored-by: Mihailo Timotic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
mihailotim-db authored and cloud-fan committed Feb 20, 2025
1 parent f447c43 commit a661f9f
Show file tree
Hide file tree
Showing 2 changed files with 178 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ import org.apache.spark.sql.connector.catalog.TableChange.{After, ColumnPosition
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction
import org.apache.spark.sql.connector.catalog.procedures.{BoundProcedure, ProcedureParameter, UnboundProcedure}
import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{PartitionOverwriteMode, StoreAssignmentPolicy}
Expand Down Expand Up @@ -3564,55 +3564,17 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
hint: JoinHint): LogicalPlan = {
import org.apache.spark.sql.catalyst.util._

val leftKeys = joinNames.map { keyName =>
left.output.find(attr => resolver(attr.name, keyName)).getOrElse {
throw QueryCompilationErrors.unresolvedUsingColForJoinError(
keyName, left.schema.fieldNames.sorted.map(toSQLId).mkString(", "), "left")
}
}
val rightKeys = joinNames.map { keyName =>
right.output.find(attr => resolver(attr.name, keyName)).getOrElse {
throw QueryCompilationErrors.unresolvedUsingColForJoinError(
keyName, right.schema.fieldNames.sorted.map(toSQLId).mkString(", "), "right")
}
}
val joinPairs = leftKeys.zip(rightKeys)

val newCondition = (condition ++ joinPairs.map(EqualTo.tupled)).reduceOption(And)

// columns not in joinPairs
val lUniqueOutput = left.output.filterNot(att => leftKeys.contains(att))
val rUniqueOutput = right.output.filterNot(att => rightKeys.contains(att))

// the output list looks like: join keys, columns from left, columns from right
val (projectList, hiddenList) = joinType match {
case LeftOuter =>
(leftKeys ++ lUniqueOutput ++ rUniqueOutput.map(_.withNullability(true)),
rightKeys.map(_.withNullability(true)))
case LeftExistence(_) =>
(leftKeys ++ lUniqueOutput, Seq.empty)
case RightOuter =>
(rightKeys ++ lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput,
leftKeys.map(_.withNullability(true)))
case FullOuter =>
// In full outer join, we should return non-null values for the join columns
// if either side has non-null values for those columns. Therefore, for each
// join column pair, add a coalesce to return the non-null value, if it exists.
val joinedCols = joinPairs.map { case (l, r) =>
// Since this is a full outer join, either side could be null, so we explicitly
// set the nullability to true for both sides.
Alias(Coalesce(Seq(l.withNullability(true), r.withNullability(true))), l.name)()
}
(joinedCols ++
lUniqueOutput.map(_.withNullability(true)) ++
rUniqueOutput.map(_.withNullability(true)),
leftKeys.map(_.withNullability(true)) ++
rightKeys.map(_.withNullability(true)))
case _ : InnerLike =>
(leftKeys ++ lUniqueOutput ++ rUniqueOutput, rightKeys)
case _ =>
throw QueryExecutionErrors.unsupportedNaturalJoinTypeError(joinType)
}
val (projectList, hiddenList, newCondition) =
NaturalAndUsingJoinResolution.computeJoinOutputsAndNewCondition(
left,
left.output,
right,
right.output,
joinType,
joinNames,
condition,
(attributeName, keyName) => resolver(attributeName, keyName)
)

// use Project to hide duplicated common keys
// propagate hidden columns from nested USING/NATURAL JOINs
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.expressions.{
Alias,
And,
Attribute,
Coalesce,
EqualTo,
Expression,
NamedExpression
}
import org.apache.spark.sql.catalyst.plans.{
FullOuter,
InnerLike,
JoinType,
LeftExistence,
LeftOuter,
RightOuter
}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.errors.{
DataTypeErrorsBase,
QueryCompilationErrors,
QueryExecutionErrors
}

object NaturalAndUsingJoinResolution extends DataTypeErrorsBase with SQLConfHelper {

/**
* For a given [[Join]], computes output, hidden output and new condition, if such exists.
*/
def computeJoinOutputsAndNewCondition(
left: LogicalPlan,
leftOutput: Seq[Attribute],
right: LogicalPlan,
rightOutput: Seq[Attribute],
joinType: JoinType,
joinNames: Seq[String],
condition: Option[Expression],
resolveName: (String, String) => Boolean)
: (Seq[NamedExpression], Seq[Attribute], Option[Expression]) = {
val (leftKeys, rightKeys) = resolveKeysForNaturalAndUsingJoin(
left,
leftOutput,
right,
rightOutput,
joinNames,
resolveName
)
val joinPairs = leftKeys.zip(rightKeys)

val newCondition = (condition ++ joinPairs.map(EqualTo.tupled)).reduceOption(And)

// the output list looks like: join keys, columns from left, columns from right
val (output, hiddenOutput) = computeOutputAndHiddenOutput(
leftOutput,
leftKeys,
rightOutput,
rightKeys,
joinPairs,
joinType
)
(output, hiddenOutput, newCondition)
}

/**
* Returns resolved keys for joining based on the output of [[Join]]'s children or throws and
* error if a key name doesn't exist.
*/
private def resolveKeysForNaturalAndUsingJoin(
left: LogicalPlan,
leftOutput: Seq[Attribute],
right: LogicalPlan,
rightOutput: Seq[Attribute],
joinNames: Seq[String],
resolveName: (String, String) => Boolean): (Seq[Attribute], Seq[Attribute]) = {
val leftKeys = joinNames.map { keyName =>
leftOutput.find(attribute => resolveName(attribute.name, keyName)).getOrElse {
throw QueryCompilationErrors.unresolvedUsingColForJoinError(
keyName,
left.schema.fieldNames.sorted.map(toSQLId).mkString(", "),
"left"
)
}
}
val rightKeys = joinNames.map { keyName =>
rightOutput.find(attribute => resolveName(attribute.name, keyName)).getOrElse {
throw QueryCompilationErrors.unresolvedUsingColForJoinError(
keyName,
right.schema.fieldNames.sorted.map(toSQLId).mkString(", "),
"right"
)
}
}
(leftKeys, rightKeys)
}

/**
* Computes the output and hidden output for a given [[Join]], based on the output of its
* children.
*/
private def computeOutputAndHiddenOutput(
leftOutput: Seq[Attribute],
leftKeys: Seq[Attribute],
rightOutput: Seq[Attribute],
rightKeys: Seq[Attribute],
joinPairs: Seq[(Attribute, Attribute)],
joinType: JoinType): (Seq[NamedExpression], Seq[Attribute]) = {
// columns not in joinPairs
val lUniqueOutput = leftOutput.filterNot(att => leftKeys.contains(att))
val rUniqueOutput = rightOutput.filterNot(att => rightKeys.contains(att))
joinType match {
case LeftOuter =>
(
leftKeys ++ lUniqueOutput ++ rUniqueOutput.map(_.withNullability(true)),
rightKeys.map(_.withNullability(true))
)
case LeftExistence(_) =>
(leftKeys ++ lUniqueOutput, Seq.empty)
case RightOuter =>
(
rightKeys ++ lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput,
leftKeys.map(_.withNullability(true))
)
case FullOuter =>
// In full outer join, we should return non-null values for the join columns
// if either side has non-null values for those columns. Therefore, for each
// join column pair, add a coalesce to return the non-null value, if it exists.
val joinedCols = joinPairs.map {
case (l, r) =>
// Since this is a full outer join, either side could be null, so we explicitly
// set the nullability to true for both sides.
Alias(Coalesce(Seq(l.withNullability(true), r.withNullability(true))), l.name)()
}
(
joinedCols ++
lUniqueOutput.map(_.withNullability(true)) ++
rUniqueOutput.map(_.withNullability(true)),
leftKeys.map(_.withNullability(true)) ++
rightKeys.map(_.withNullability(true))
)
case _: InnerLike =>
(leftKeys ++ lUniqueOutput ++ rUniqueOutput, rightKeys)
case _ =>
throw QueryExecutionErrors.unsupportedNaturalJoinTypeError(joinType)
}
}
}

0 comments on commit a661f9f

Please sign in to comment.