@@ -1197,7 +1197,7 @@ class EnsureRequirementsSuite extends SharedSparkSession {
11971197 TransformExpression (BucketFunction , expr, Some (numBuckets))
11981198 }
11991199
1200- test(" ShufflePartitionIdPassThrough - avoid necessary shuffle when they are compatible" ) {
1200+ test(" ShufflePartitionIdPassThrough - avoid unnecessary shuffle when children are compatible" ) {
12011201 withSQLConf(SQLConf .SHUFFLE_PARTITIONS .key -> " 10" ) {
12021202 val plan1 = DummySparkPlan (
12031203 outputPartitioning = ShufflePartitionIdPassThrough (DirectShufflePartitionID (exprA), 5 ))
@@ -1253,7 +1253,8 @@ class EnsureRequirementsSuite extends SharedSparkSession {
12531253 val plan2 = DummySparkPlan (
12541254 outputPartitioning = ShufflePartitionIdPassThrough (DirectShufflePartitionID (exprC), 5 ))
12551255 // Join on different keys than partitioning keys
1256- val smjExec = SortMergeJoinExec (exprB :: Nil , exprD :: Nil , Inner , None , plan1, plan2)
1256+ val smjExec = SortMergeJoinExec (exprA :: exprB :: Nil , exprD :: exprC :: Nil , Inner , None ,
1257+ plan1, plan2)
12571258
12581259 EnsureRequirements .apply(smjExec) match {
12591260 case SortMergeJoinExec (_, _, _, _,
@@ -1306,32 +1307,6 @@ class EnsureRequirementsSuite extends SharedSparkSession {
13061307 }
13071308 }
13081309
1309- test(" ShufflePartitionIdPassThrough - incompatible due to different expressions " +
1310- " with same base column" ) {
1311- withSQLConf(SQLConf .SHUFFLE_PARTITIONS .key -> " 10" ) {
1312- // Even though both use exprA as base and have same numPartitions,
1313- // different Pmod operations make them incompatible
1314- val plan1 = DummySparkPlan (
1315- outputPartitioning = ShufflePartitionIdPassThrough (
1316- DirectShufflePartitionID (Pmod (exprA, Literal (10 ))), 5 ))
1317- val plan2 = DummySparkPlan (
1318- outputPartitioning = ShufflePartitionIdPassThrough (
1319- DirectShufflePartitionID (Pmod (exprA, Literal (5 ))), 5 ))
1320- val smjExec = SortMergeJoinExec (exprA :: Nil , exprA :: Nil , Inner , None , plan1, plan2)
1321-
1322- EnsureRequirements .apply(smjExec) match {
1323- case SortMergeJoinExec (_, _, _, _,
1324- SortExec (_, _, ShuffleExchangeExec (p1 : HashPartitioning , _, _, _), _),
1325- SortExec (_, _, ShuffleExchangeExec (p2 : HashPartitioning , _, _, _), _), _) =>
1326- // Both sides should be shuffled due to expression mismatch
1327- assert(p1.numPartitions == 10 )
1328- assert(p2.numPartitions == 10 )
1329- assert(p1.expressions == Seq (exprA))
1330- assert(p2.expressions == Seq (exprA))
1331- case other => fail(s " Expected shuffles on both sides, but got: $other" )
1332- }
1333- }
1334- }
13351310
13361311 test(" ShufflePartitionIdPassThrough - compatible with multiple clustering keys" ) {
13371312 withSQLConf(SQLConf .SHUFFLE_PARTITIONS .key -> " 10" ) {
@@ -1359,6 +1334,33 @@ class EnsureRequirementsSuite extends SharedSparkSession {
13591334 case other => fail(s " We don't expect shuffle on neither sides with multiple " +
13601335 s " clustering keys, but got: $other" )
13611336 }
1337+
1338+ // Test case 2: partition key matches at position 1
1339+ // Both sides partitioned by exprB and join on (exprA, exprB)
1340+ // Should be compatible because partition key exprB matches at position 1 in join keys
1341+ val plan3 = DummySparkPlan (
1342+ outputPartitioning = ShufflePartitionIdPassThrough (DirectShufflePartitionID (exprB), 5 ))
1343+ val plan4 = DummySparkPlan (
1344+ outputPartitioning = ShufflePartitionIdPassThrough (DirectShufflePartitionID (exprB), 5 ))
1345+ val smjExec2 = SortMergeJoinExec (exprA :: exprB :: Nil , exprA :: exprB :: Nil , Inner , None ,
1346+ plan3, plan4)
1347+
1348+ EnsureRequirements .apply(smjExec2) match {
1349+ case SortMergeJoinExec (
1350+ leftKeys,
1351+ rightKeys,
1352+ _,
1353+ _,
1354+ SortExec (_, _, DummySparkPlan (_, _, _ : ShufflePartitionIdPassThrough , _, _), _),
1355+ SortExec (_, _, DummySparkPlan (_, _, _ : ShufflePartitionIdPassThrough , _, _), _),
1356+ _
1357+ ) =>
1358+ // No shuffles because exprB (partition key) appears at position 1 in join keys
1359+ assert(leftKeys === Seq (exprA, exprB))
1360+ assert(rightKeys === Seq (exprA, exprB))
1361+ case other => fail(s " Expected no shuffles due to position overlap at position 1, " +
1362+ s " but got: $other" )
1363+ }
13621364 }
13631365 }
13641366
@@ -1414,35 +1416,6 @@ class EnsureRequirementsSuite extends SharedSparkSession {
14141416 }
14151417 }
14161418
1417- test(" ShufflePartitionIdPassThrough - compatible when partition key matches at any position" ) {
1418- withSQLConf(SQLConf .SHUFFLE_PARTITIONS .key -> " 10" ) {
1419- // Both sides partitioned by exprB and join on (exprA, exprB)
1420- // Should be compatible because partition key exprB matches at position 1 in join keys
1421- val plan1 = DummySparkPlan (
1422- outputPartitioning = ShufflePartitionIdPassThrough (DirectShufflePartitionID (exprB), 5 ))
1423- val plan2 = DummySparkPlan (
1424- outputPartitioning = ShufflePartitionIdPassThrough (DirectShufflePartitionID (exprB), 5 ))
1425- val smjExec = SortMergeJoinExec (exprA :: exprB :: Nil , exprA :: exprB :: Nil , Inner , None ,
1426- plan1, plan2)
1427-
1428- EnsureRequirements .apply(smjExec) match {
1429- case SortMergeJoinExec (
1430- leftKeys,
1431- rightKeys,
1432- _,
1433- _,
1434- SortExec (_, _, DummySparkPlan (_, _, _ : ShufflePartitionIdPassThrough , _, _), _),
1435- SortExec (_, _, DummySparkPlan (_, _, _ : ShufflePartitionIdPassThrough , _, _), _),
1436- _
1437- ) =>
1438- // No shuffles because exprB (partition key) appears at position 1 in join keys
1439- assert(leftKeys === Seq (exprA, exprB))
1440- assert(rightKeys === Seq (exprA, exprB))
1441- case other => fail(s " Expected no shuffles due to position overlap at position 1, " +
1442- s " but got: $other" )
1443- }
1444- }
1445- }
14461419
14471420 def years (expr : Expression ): TransformExpression = {
14481421 TransformExpression (YearsFunction , Seq (expr))
0 commit comments