Skip to content

Commit 4e52066

Browse files
address comments
1 parent c53123d commit 4e52066

File tree

4 files changed

+42
-68
lines changed

4 files changed

+42
-68
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -638,7 +638,6 @@ case class ShufflePartitionIdPassThrough(
638638
expr: DirectShufflePartitionID,
639639
numPartitions: Int) extends Expression with Partitioning with Unevaluable {
640640

641-
// We don't support creating partitioning for ShufflePartitionIdPassThrough.
642641
override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec = {
643642
ShufflePartitionIdPassThroughSpec(this, distribution)
644643
}
@@ -1009,6 +1008,7 @@ case class ShufflePartitionIdPassThroughSpec(
10091008
false
10101009
}
10111010

1011+
// We don't support creating partitioning for ShufflePartitionIdPassThrough.
10121012
override def canCreatePartitioning: Boolean = false
10131013

10141014
override def numPartitions: Int = partitioning.numPartitions

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -516,7 +516,7 @@ class ShuffleSpecSuite extends SparkFunSuite with SQLHelper {
516516
val dist3 = ClusteredDistribution(Seq($"e", $"b"))
517517
checkCompatible(
518518
p1.createShuffleSpec(dist3),
519-
p2.createShuffleSpec(dist2),
519+
p2.createShuffleSpec(dist),
520520
expected = false
521521
)
522522
}

sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -165,27 +165,28 @@ case class EnsureRequirements(
165165
// Check if the following conditions are satisfied:
166166
// 1. There are exactly two children (e.g., join). Note that Spark doesn't support
167167
// multi-way join at the moment, so this check should be sufficient.
168-
// 2. All children are of `KeyGroupedPartitioning`, and they are compatible with each other
168+
// 2. All children share the same partitioning, and they are compatible with each other
169169
// If both are true, skip shuffle.
170-
val isKeyGroupCompatible = parent.isDefined &&
170+
val areChildrenCompatible = parent.isDefined &&
171171
children.length == 2 && childrenIndexes.length == 2 && {
172172
val left = children.head
173173
val right = children(1)
174+
175+
// key group compatibility check
174176
val newChildren = checkKeyGroupCompatible(
175177
parent.get, left, right, requiredChildDistributions)
176178
if (newChildren.isDefined) {
177179
children = newChildren.get
180+
true
181+
} else {
182+
// If key group check fails, check ShufflePartitionIdPassThrough compatibility
183+
checkShufflePartitionIdPassThroughCompatible(
184+
left, right, requiredChildDistributions)
178185
}
179-
newChildren.isDefined
180186
}
181187

182-
val isShufflePassThroughCompatible = !isKeyGroupCompatible &&
183-
parent.isDefined && children.length == 2 && childrenIndexes.length == 2 &&
184-
checkShufflePartitionIdPassThroughCompatible(
185-
children.head, children(1), requiredChildDistributions)
186-
187188
children = children.zip(requiredChildDistributions).zipWithIndex.map {
188-
case ((child, _), idx) if isKeyGroupCompatible || isShufflePassThroughCompatible ||
189+
case ((child, _), idx) if areChildrenCompatible ||
189190
!childrenIndexes.contains(idx) =>
190191
child
191192
case ((child, dist), idx) =>

sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala

Lines changed: 30 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1197,7 +1197,7 @@ class EnsureRequirementsSuite extends SharedSparkSession {
11971197
TransformExpression(BucketFunction, expr, Some(numBuckets))
11981198
}
11991199

1200-
test("ShufflePartitionIdPassThrough - avoid necessary shuffle when they are compatible") {
1200+
test("ShufflePartitionIdPassThrough - avoid unnecessary shuffle when children are compatible") {
12011201
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
12021202
val plan1 = DummySparkPlan(
12031203
outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
@@ -1253,7 +1253,8 @@ class EnsureRequirementsSuite extends SharedSparkSession {
12531253
val plan2 = DummySparkPlan(
12541254
outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprC), 5))
12551255
// Join on different keys than partitioning keys
1256-
val smjExec = SortMergeJoinExec(exprB :: Nil, exprD :: Nil, Inner, None, plan1, plan2)
1256+
val smjExec = SortMergeJoinExec(exprA :: exprB :: Nil, exprD :: exprC :: Nil, Inner, None,
1257+
plan1, plan2)
12571258

12581259
EnsureRequirements.apply(smjExec) match {
12591260
case SortMergeJoinExec(_, _, _, _,
@@ -1306,32 +1307,6 @@ class EnsureRequirementsSuite extends SharedSparkSession {
13061307
}
13071308
}
13081309

1309-
test("ShufflePartitionIdPassThrough - incompatible due to different expressions " +
1310-
"with same base column") {
1311-
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
1312-
// Even though both use exprA as base and have same numPartitions,
1313-
// different Pmod operations make them incompatible
1314-
val plan1 = DummySparkPlan(
1315-
outputPartitioning = ShufflePartitionIdPassThrough(
1316-
DirectShufflePartitionID(Pmod(exprA, Literal(10))), 5))
1317-
val plan2 = DummySparkPlan(
1318-
outputPartitioning = ShufflePartitionIdPassThrough(
1319-
DirectShufflePartitionID(Pmod(exprA, Literal(5))), 5))
1320-
val smjExec = SortMergeJoinExec(exprA :: Nil, exprA :: Nil, Inner, None, plan1, plan2)
1321-
1322-
EnsureRequirements.apply(smjExec) match {
1323-
case SortMergeJoinExec(_, _, _, _,
1324-
SortExec(_, _, ShuffleExchangeExec(p1: HashPartitioning, _, _, _), _),
1325-
SortExec(_, _, ShuffleExchangeExec(p2: HashPartitioning, _, _, _), _), _) =>
1326-
// Both sides should be shuffled due to expression mismatch
1327-
assert(p1.numPartitions == 10)
1328-
assert(p2.numPartitions == 10)
1329-
assert(p1.expressions == Seq(exprA))
1330-
assert(p2.expressions == Seq(exprA))
1331-
case other => fail(s"Expected shuffles on both sides, but got: $other")
1332-
}
1333-
}
1334-
}
13351310

13361311
test("ShufflePartitionIdPassThrough - compatible with multiple clustering keys") {
13371312
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
@@ -1359,6 +1334,33 @@ class EnsureRequirementsSuite extends SharedSparkSession {
13591334
case other => fail(s"We don't expect shuffle on either side with multiple " +
13601335
s"clustering keys, but got: $other")
13611336
}
1337+
1338+
// Test case 2: partition key matches at position 1
1339+
// Both sides partitioned by exprB and join on (exprA, exprB)
1340+
// Should be compatible because partition key exprB matches at position 1 in join keys
1341+
val plan3 = DummySparkPlan(
1342+
outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprB), 5))
1343+
val plan4 = DummySparkPlan(
1344+
outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprB), 5))
1345+
val smjExec2 = SortMergeJoinExec(exprA :: exprB :: Nil, exprA :: exprB :: Nil, Inner, None,
1346+
plan3, plan4)
1347+
1348+
EnsureRequirements.apply(smjExec2) match {
1349+
case SortMergeJoinExec(
1350+
leftKeys,
1351+
rightKeys,
1352+
_,
1353+
_,
1354+
SortExec(_, _, DummySparkPlan(_, _, _: ShufflePartitionIdPassThrough, _, _), _),
1355+
SortExec(_, _, DummySparkPlan(_, _, _: ShufflePartitionIdPassThrough, _, _), _),
1356+
_
1357+
) =>
1358+
// No shuffles because exprB (partition key) appears at position 1 in join keys
1359+
assert(leftKeys === Seq(exprA, exprB))
1360+
assert(rightKeys === Seq(exprA, exprB))
1361+
case other => fail(s"Expected no shuffles due to position overlap at position 1, " +
1362+
s"but got: $other")
1363+
}
13621364
}
13631365
}
13641366

@@ -1414,35 +1416,6 @@ class EnsureRequirementsSuite extends SharedSparkSession {
14141416
}
14151417
}
14161418

1417-
test("ShufflePartitionIdPassThrough - compatible when partition key matches at any position") {
1418-
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
1419-
// Both sides partitioned by exprB and join on (exprA, exprB)
1420-
// Should be compatible because partition key exprB matches at position 1 in join keys
1421-
val plan1 = DummySparkPlan(
1422-
outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprB), 5))
1423-
val plan2 = DummySparkPlan(
1424-
outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprB), 5))
1425-
val smjExec = SortMergeJoinExec(exprA :: exprB :: Nil, exprA :: exprB :: Nil, Inner, None,
1426-
plan1, plan2)
1427-
1428-
EnsureRequirements.apply(smjExec) match {
1429-
case SortMergeJoinExec(
1430-
leftKeys,
1431-
rightKeys,
1432-
_,
1433-
_,
1434-
SortExec(_, _, DummySparkPlan(_, _, _: ShufflePartitionIdPassThrough, _, _), _),
1435-
SortExec(_, _, DummySparkPlan(_, _, _: ShufflePartitionIdPassThrough, _, _), _),
1436-
_
1437-
) =>
1438-
// No shuffles because exprB (partition key) appears at position 1 in join keys
1439-
assert(leftKeys === Seq(exprA, exprB))
1440-
assert(rightKeys === Seq(exprA, exprB))
1441-
case other => fail(s"Expected no shuffles due to position overlap at position 1, " +
1442-
s"but got: $other")
1443-
}
1444-
}
1445-
}
14461419

14471420
def years(expr: Expression): TransformExpression = {
14481421
TransformExpression(YearsFunction, Seq(expr))

0 commit comments

Comments
 (0)