Skip to content

Commit f907b5b

Browse files
ckp
1 parent 69df835 commit f907b5b

File tree

2 files changed

+63
-70
lines changed

2 files changed

+63
-70
lines changed

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala

Lines changed: 18 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -482,68 +482,62 @@ class ShuffleSpecSuite extends SparkFunSuite with SQLHelper {
482482
}
483483

484484
test("compatibility: ShufflePartitionIdPassThroughSpec on both sides") {
485-
val dist = ClusteredDistribution(Seq($"a", $"b"))
486-
val p1 = ShufflePartitionIdPassThrough(DirectShufflePartitionID($"a"), 10)
487-
val p2 = ShufflePartitionIdPassThrough(DirectShufflePartitionID($"c"), 10)
485+
val ab = ClusteredDistribution(Seq($"a", $"b"))
486+
val cd = ClusteredDistribution(Seq($"c", $"d"))
487+
val passThrough_a_10 = ShufflePartitionIdPassThrough(DirectShufflePartitionID($"a"), 10)
488488

489489
// Identical specs should be compatible
490490
checkCompatible(
491-
p1.createShuffleSpec(dist),
492-
p2.createShuffleSpec(ClusteredDistribution(Seq($"c", $"d"))),
491+
passThrough_a_10.createShuffleSpec(ab),
492+
ShufflePartitionIdPassThrough(DirectShufflePartitionID($"c"), 10).createShuffleSpec(cd),
493493
expected = true
494494
)
495495

496496
// Different number of partitions should be incompatible
497-
val p3 = ShufflePartitionIdPassThrough(DirectShufflePartitionID($"c"), 5)
498497
checkCompatible(
499-
p1.createShuffleSpec(dist),
500-
p3.createShuffleSpec(ClusteredDistribution(Seq($"c", $"d"))),
498+
passThrough_a_10.createShuffleSpec(ab),
499+
ShufflePartitionIdPassThrough(DirectShufflePartitionID($"c"), 5).createShuffleSpec(cd),
501500
expected = false
502501
)
503502

504503
// Mismatched key positions should be incompatible
505-
val dist1 = ClusteredDistribution(Seq($"a", $"b"))
506-
val p4 = ShufflePartitionIdPassThrough(DirectShufflePartitionID($"b"), 10) // Key at pos 1
507-
val dist2 = ClusteredDistribution(Seq($"c", $"d"))
508-
val p5 = ShufflePartitionIdPassThrough(DirectShufflePartitionID($"c"), 10) // Key at pos 0
509504
checkCompatible(
510-
p4.createShuffleSpec(dist1),
511-
p5.createShuffleSpec(dist2),
505+
ShufflePartitionIdPassThrough(DirectShufflePartitionID($"b"), 10).createShuffleSpec(ab),
506+
ShufflePartitionIdPassThrough(DirectShufflePartitionID($"c"), 10).createShuffleSpec(cd),
512507
expected = false
513508
)
514509

515510
// Mismatched clustering keys
516-
val dist3 = ClusteredDistribution(Seq($"e", $"b"))
517511
checkCompatible(
518-
p1.createShuffleSpec(dist3),
519-
p2.createShuffleSpec(dist),
512+
passThrough_a_10.createShuffleSpec(ClusteredDistribution(Seq($"e", $"b"))),
513+
ShufflePartitionIdPassThrough(DirectShufflePartitionID($"c"), 10).createShuffleSpec(ab),
520514
expected = false
521515
)
522516
}
523517

524518
test("compatibility: ShufflePartitionIdPassThroughSpec vs other specs") {
525-
val dist = ClusteredDistribution(Seq($"a", $"b"))
519+
val ab = ClusteredDistribution(Seq($"a", $"b"))
520+
val cd = ClusteredDistribution(Seq($"c", $"d"))
521+
val passThrough_a_10 = ShufflePartitionIdPassThrough(DirectShufflePartitionID($"a"), 10)
526522

527523
// Compatibility with SinglePartitionShuffleSpec when numPartitions is 1
528-
val p1 = ShufflePartitionIdPassThrough(DirectShufflePartitionID($"a"), 1)
529524
checkCompatible(
530-
p1.createShuffleSpec(dist),
525+
ShufflePartitionIdPassThrough(DirectShufflePartitionID($"a"), 1).createShuffleSpec(ab),
531526
SinglePartitionShuffleSpec,
532527
expected = true
533528
)
534529

535530
// Incompatible with SinglePartitionShuffleSpec when numPartitions > 1
536531
checkCompatible(
537-
p.createShuffleSpec(dist),
532+
passThrough_a_10.createShuffleSpec(ab),
538533
SinglePartitionShuffleSpec,
539534
expected = false
540535
)
541536

542537
// Incompatible with HashShuffleSpec
543-
val p2 = HashPartitioning(Seq($"c"), 10)
544538
checkCompatible(
545-
ShufflePartitionIdPassThrough(DirectShufflePartitionID($"a"), 10).createShuffleSpec(dist),
546-
p2.createShuffleSpec(ClusteredDistribution(Seq($"c", $"d"))),
539+
passThrough_a_10.createShuffleSpec(ab),
540+
HashShuffleSpec(HashPartitioning(Seq($"c"), 10), cd),
547541
expected = false
548542
)
549543
}

sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala

Lines changed: 45 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1199,13 +1199,13 @@ class EnsureRequirementsSuite extends SharedSparkSession {
11991199

12001200
test("ShufflePartitionIdPassThrough - avoid unnecessary shuffle when children are compatible") {
12011201
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
1202-
val plan1 = DummySparkPlan(
1203-
outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
1204-
val plan2 = DummySparkPlan(
1205-
outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
1206-
val smjExec = SortMergeJoinExec(exprA :: Nil, exprA :: Nil, Inner, None, plan1, plan2)
1202+
val passThrough_a_5 = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5)
12071203

1208-
EnsureRequirements.apply(smjExec) match {
1204+
val leftPlan = DummySparkPlan(outputPartitioning = passThrough_a_5)
1205+
val rightPlan = DummySparkPlan(outputPartitioning = passThrough_a_5)
1206+
val join = SortMergeJoinExec(exprA :: Nil, exprA :: Nil, Inner, None, leftPlan, rightPlan)
1207+
1208+
EnsureRequirements.apply(join) match {
12091209
case SortMergeJoinExec(
12101210
leftKeys,
12111211
rightKeys,
@@ -1225,13 +1225,13 @@ class EnsureRequirementsSuite extends SharedSparkSession {
12251225
test("ShufflePartitionIdPassThrough incompatibility - different partitions") {
12261226
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
12271227
// Different number of partitions - should add shuffles
1228-
val plan1 = DummySparkPlan(
1228+
val leftPlan = DummySparkPlan(
12291229
outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
1230-
val plan2 = DummySparkPlan(
1230+
val rightPlan = DummySparkPlan(
12311231
outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprB), 8))
1232-
val smjExec = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, plan1, plan2)
1232+
val join = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, leftPlan, rightPlan)
12331233

1234-
EnsureRequirements.apply(smjExec) match {
1234+
EnsureRequirements.apply(join) match {
12351235
case SortMergeJoinExec(_, _, _, _,
12361236
SortExec(_, _, ShuffleExchangeExec(p1: HashPartitioning, _, _, _), _),
12371237
SortExec(_, _, ShuffleExchangeExec(p2: HashPartitioning, _, _, _), _), _) =>
@@ -1248,15 +1248,15 @@ class EnsureRequirementsSuite extends SharedSparkSession {
12481248
test("ShufflePartitionIdPassThrough incompatibility - key position mismatch") {
12491249
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
12501250
// Key position mismatch - should add shuffles
1251-
val plan1 = DummySparkPlan(
1251+
val leftPlan = DummySparkPlan(
12521252
outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
1253-
val plan2 = DummySparkPlan(
1253+
val rightPlan = DummySparkPlan(
12541254
outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprC), 5))
12551255
// Join on different keys than partitioning keys
1256-
val smjExec = SortMergeJoinExec(exprA :: exprB :: Nil, exprD :: exprC :: Nil, Inner, None,
1257-
plan1, plan2)
1256+
val join = SortMergeJoinExec(exprA :: exprB :: Nil, exprD :: exprC :: Nil, Inner, None,
1257+
leftPlan, rightPlan)
12581258

1259-
EnsureRequirements.apply(smjExec) match {
1259+
EnsureRequirements.apply(join) match {
12601260
case SortMergeJoinExec(_, _, _, _,
12611261
SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _), _),
12621262
SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _), _), _) =>
@@ -1269,13 +1269,13 @@ class EnsureRequirementsSuite extends SharedSparkSession {
12691269
test("ShufflePartitionIdPassThrough vs HashPartitioning - always shuffles") {
12701270
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
12711271
// ShufflePartitionIdPassThrough vs HashPartitioning - always adds shuffles
1272-
val plan1 = DummySparkPlan(
1272+
val leftPlan = DummySparkPlan(
12731273
outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
1274-
val plan2 = DummySparkPlan(
1274+
val rightPlan = DummySparkPlan(
12751275
outputPartitioning = HashPartitioning(exprB :: Nil, 5))
1276-
val smjExec = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, plan1, plan2)
1276+
val join = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, leftPlan, rightPlan)
12771277

1278-
EnsureRequirements.apply(smjExec) match {
1278+
EnsureRequirements.apply(join) match {
12791279
case SortMergeJoinExec(_, _, _, _,
12801280
SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _), _),
12811281
SortExec(_, _, _: DummySparkPlan, _), _) =>
@@ -1292,12 +1292,12 @@ class EnsureRequirementsSuite extends SharedSparkSession {
12921292
test("ShufflePartitionIdPassThrough vs SinglePartition - shuffles added") {
12931293
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "5") {
12941294
// Even when compatible (numPartitions=1), shuffles added due to canCreatePartitioning=false
1295-
val plan1 = DummySparkPlan(
1295+
val leftPlan = DummySparkPlan(
12961296
outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 1))
1297-
val plan2 = DummySparkPlan(outputPartitioning = SinglePartition)
1298-
val smjExec = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, plan1, plan2)
1297+
val rightPlan = DummySparkPlan(outputPartitioning = SinglePartition)
1298+
val join = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, leftPlan, rightPlan)
12991299

1300-
EnsureRequirements.apply(smjExec) match {
1300+
EnsureRequirements.apply(join) match {
13011301
case SortMergeJoinExec(_, _, _, _,
13021302
SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _), _),
13031303
SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _), _), _) =>
@@ -1310,16 +1310,17 @@ class EnsureRequirementsSuite extends SharedSparkSession {
13101310

13111311
test("ShufflePartitionIdPassThrough - compatible with multiple clustering keys") {
13121312
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
1313+
val passThrough_a_5 = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5)
1314+
val passThrough_b_5 = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprB), 5)
1315+
13131316
// Both partitioned by exprA, joined on (exprA, exprB)
13141317
// Should be compatible because exprA positions overlap
1315-
val plan1 = DummySparkPlan(
1316-
outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
1317-
val plan2 = DummySparkPlan(
1318-
outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
1319-
val smjExec = SortMergeJoinExec(exprA :: exprB :: Nil, exprA :: exprB :: Nil, Inner, None,
1320-
plan1, plan2)
1318+
val leftPlanA = DummySparkPlan(outputPartitioning = passThrough_a_5)
1319+
val rightPlanA = DummySparkPlan(outputPartitioning = passThrough_a_5)
1320+
val joinA = SortMergeJoinExec(exprA :: exprB :: Nil, exprA :: exprB :: Nil, Inner, None,
1321+
leftPlanA, rightPlanA)
13211322

1322-
EnsureRequirements.apply(smjExec) match {
1323+
EnsureRequirements.apply(joinA) match {
13231324
case SortMergeJoinExec(
13241325
leftKeys,
13251326
rightKeys,
@@ -1338,14 +1339,12 @@ class EnsureRequirementsSuite extends SharedSparkSession {
13381339
// Test case 2: partition key matches at position 1
13391340
// Both sides partitioned by exprB and join on (exprA, exprB)
13401341
// Should be compatible because partition key exprB matches at position 1 in join keys
1341-
val plan3 = DummySparkPlan(
1342-
outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprB), 5))
1343-
val plan4 = DummySparkPlan(
1344-
outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprB), 5))
1345-
val smjExec2 = SortMergeJoinExec(exprA :: exprB :: Nil, exprA :: exprB :: Nil, Inner, None,
1346-
plan3, plan4)
1342+
val leftPlanB = DummySparkPlan(outputPartitioning = passThrough_b_5)
1343+
val rightPlanB = DummySparkPlan(outputPartitioning = passThrough_b_5)
1344+
val joinB = SortMergeJoinExec(exprA :: exprB :: Nil, exprA :: exprB :: Nil, Inner, None,
1345+
leftPlanB, rightPlanB)
13471346

1348-
EnsureRequirements.apply(smjExec2) match {
1347+
EnsureRequirements.apply(joinB) match {
13491348
case SortMergeJoinExec(
13501349
leftKeys,
13511350
rightKeys,
@@ -1368,13 +1367,13 @@ class EnsureRequirementsSuite extends SharedSparkSession {
13681367
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
13691368
// Partitioned by exprA and exprB respectively, but joining on completely different keys
13701369
// Should require shuffles because partition keys don't match join keys
1371-
val plan1 = DummySparkPlan(
1370+
val leftPlan = DummySparkPlan(
13721371
outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
1373-
val plan2 = DummySparkPlan(
1372+
val rightPlan = DummySparkPlan(
13741373
outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprB), 5))
1375-
val smjExec = SortMergeJoinExec(exprC :: Nil, exprD :: Nil, Inner, None, plan1, plan2)
1374+
val join = SortMergeJoinExec(exprC :: Nil, exprD :: Nil, Inner, None, leftPlan, rightPlan)
13761375

1377-
EnsureRequirements.apply(smjExec) match {
1376+
EnsureRequirements.apply(join) match {
13781377
case SortMergeJoinExec(_, _, _, _,
13791378
SortExec(_, _, ShuffleExchangeExec(p1: HashPartitioning, _, _, _), _),
13801379
SortExec(_, _, ShuffleExchangeExec(p2: HashPartitioning, _, _, _), _), _) =>
@@ -1395,14 +1394,14 @@ class EnsureRequirementsSuite extends SharedSparkSession {
13951394
// Test if cross-position matching works: left partition key exprA matches right join key
13961395
// exprA (pos 0)
13971396
// and right partition key exprB matches left join key exprB (pos 1)
1398-
val plan1 = DummySparkPlan(
1397+
val leftPlan = DummySparkPlan(
13991398
outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
1400-
val plan2 = DummySparkPlan(
1399+
val rightPlan = DummySparkPlan(
14011400
outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprB), 5))
1402-
val smjExec = SortMergeJoinExec(exprA :: exprB :: Nil, exprA :: exprB :: Nil, Inner, None,
1403-
plan1, plan2)
1401+
val join = SortMergeJoinExec(exprA :: exprB :: Nil, exprA :: exprB :: Nil, Inner, None,
1402+
leftPlan, rightPlan)
14041403

1405-
EnsureRequirements.apply(smjExec) match {
1404+
EnsureRequirements.apply(join) match {
14061405
case SortMergeJoinExec(_, _, _, _,
14071406
SortExec(_, _, ShuffleExchangeExec(p1: HashPartitioning, _, _, _), _),
14081407
SortExec(_, _, ShuffleExchangeExec(p2: HashPartitioning, _, _, _), _), _) =>

0 commit comments

Comments (0)