@@ -1199,13 +1199,13 @@ class EnsureRequirementsSuite extends SharedSparkSession {
11991199
12001200 test(" ShufflePartitionIdPassThrough - avoid unnecessary shuffle when children are compatible" ) {
12011201 withSQLConf(SQLConf .SHUFFLE_PARTITIONS .key -> " 10" ) {
1202- val plan1 = DummySparkPlan (
1203- outputPartitioning = ShufflePartitionIdPassThrough (DirectShufflePartitionID (exprA), 5 ))
1204- val plan2 = DummySparkPlan (
1205- outputPartitioning = ShufflePartitionIdPassThrough (DirectShufflePartitionID (exprA), 5 ))
1206- val smjExec = SortMergeJoinExec (exprA :: Nil , exprA :: Nil , Inner , None , plan1, plan2)
1202+ val passThrough_a_5 = ShufflePartitionIdPassThrough (DirectShufflePartitionID (exprA), 5 )
12071203
1208- EnsureRequirements .apply(smjExec) match {
1204+ val leftPlan = DummySparkPlan (outputPartitioning = passThrough_a_5)
1205+ val rightPlan = DummySparkPlan (outputPartitioning = passThrough_a_5)
1206+ val join = SortMergeJoinExec (exprA :: Nil , exprA :: Nil , Inner , None , leftPlan, rightPlan)
1207+
1208+ EnsureRequirements .apply(join) match {
12091209 case SortMergeJoinExec (
12101210 leftKeys,
12111211 rightKeys,
@@ -1225,13 +1225,13 @@ class EnsureRequirementsSuite extends SharedSparkSession {
12251225 test(" ShufflePartitionIdPassThrough incompatibility - different partitions" ) {
12261226 withSQLConf(SQLConf .SHUFFLE_PARTITIONS .key -> " 10" ) {
12271227 // Different number of partitions - should add shuffles
1228- val plan1 = DummySparkPlan (
1228+ val leftPlan = DummySparkPlan (
12291229 outputPartitioning = ShufflePartitionIdPassThrough (DirectShufflePartitionID (exprA), 5 ))
1230- val plan2 = DummySparkPlan (
1230+ val rightPlan = DummySparkPlan (
12311231 outputPartitioning = ShufflePartitionIdPassThrough (DirectShufflePartitionID (exprB), 8 ))
1232- val smjExec = SortMergeJoinExec (exprA :: Nil , exprB :: Nil , Inner , None , plan1, plan2 )
1232+ val join = SortMergeJoinExec (exprA :: Nil , exprB :: Nil , Inner , None , leftPlan, rightPlan )
12331233
1234- EnsureRequirements .apply(smjExec ) match {
1234+ EnsureRequirements .apply(join ) match {
12351235 case SortMergeJoinExec (_, _, _, _,
12361236 SortExec (_, _, ShuffleExchangeExec (p1 : HashPartitioning , _, _, _), _),
12371237 SortExec (_, _, ShuffleExchangeExec (p2 : HashPartitioning , _, _, _), _), _) =>
@@ -1248,15 +1248,15 @@ class EnsureRequirementsSuite extends SharedSparkSession {
12481248 test(" ShufflePartitionIdPassThrough incompatibility - key position mismatch" ) {
12491249 withSQLConf(SQLConf .SHUFFLE_PARTITIONS .key -> " 10" ) {
12501250 // Key position mismatch - should add shuffles
1251- val plan1 = DummySparkPlan (
1251+ val leftPlan = DummySparkPlan (
12521252 outputPartitioning = ShufflePartitionIdPassThrough (DirectShufflePartitionID (exprA), 5 ))
1253- val plan2 = DummySparkPlan (
1253+ val rightPlan = DummySparkPlan (
12541254 outputPartitioning = ShufflePartitionIdPassThrough (DirectShufflePartitionID (exprC), 5 ))
12551255 // Join on different keys than partitioning keys
1256- val smjExec = SortMergeJoinExec (exprA :: exprB :: Nil , exprD :: exprC :: Nil , Inner , None ,
1257- plan1, plan2 )
1256+ val join = SortMergeJoinExec (exprA :: exprB :: Nil , exprD :: exprC :: Nil , Inner , None ,
1257+ leftPlan, rightPlan )
12581258
1259- EnsureRequirements .apply(smjExec ) match {
1259+ EnsureRequirements .apply(join ) match {
12601260 case SortMergeJoinExec (_, _, _, _,
12611261 SortExec (_, _, ShuffleExchangeExec (_ : HashPartitioning , _, _, _), _),
12621262 SortExec (_, _, ShuffleExchangeExec (_ : HashPartitioning , _, _, _), _), _) =>
@@ -1269,13 +1269,13 @@ class EnsureRequirementsSuite extends SharedSparkSession {
12691269 test(" ShufflePartitionIdPassThrough vs HashPartitioning - always shuffles" ) {
12701270 withSQLConf(SQLConf .SHUFFLE_PARTITIONS .key -> " 10" ) {
12711271 // ShufflePartitionIdPassThrough vs HashPartitioning - always adds shuffles
1272- val plan1 = DummySparkPlan (
1272+ val leftPlan = DummySparkPlan (
12731273 outputPartitioning = ShufflePartitionIdPassThrough (DirectShufflePartitionID (exprA), 5 ))
1274- val plan2 = DummySparkPlan (
1274+ val rightPlan = DummySparkPlan (
12751275 outputPartitioning = HashPartitioning (exprB :: Nil , 5 ))
1276- val smjExec = SortMergeJoinExec (exprA :: Nil , exprB :: Nil , Inner , None , plan1, plan2 )
1276+ val join = SortMergeJoinExec (exprA :: Nil , exprB :: Nil , Inner , None , leftPlan, rightPlan )
12771277
1278- EnsureRequirements .apply(smjExec ) match {
1278+ EnsureRequirements .apply(join ) match {
12791279 case SortMergeJoinExec (_, _, _, _,
12801280 SortExec (_, _, ShuffleExchangeExec (_ : HashPartitioning , _, _, _), _),
12811281 SortExec (_, _, _ : DummySparkPlan , _), _) =>
@@ -1292,12 +1292,12 @@ class EnsureRequirementsSuite extends SharedSparkSession {
12921292 test(" ShufflePartitionIdPassThrough vs SinglePartition - shuffles added" ) {
12931293 withSQLConf(SQLConf .SHUFFLE_PARTITIONS .key -> " 5" ) {
12941294 // Even when compatible (numPartitions=1), shuffles added due to canCreatePartitioning=false
1295- val plan1 = DummySparkPlan (
1295+ val leftPlan = DummySparkPlan (
12961296 outputPartitioning = ShufflePartitionIdPassThrough (DirectShufflePartitionID (exprA), 1 ))
1297- val plan2 = DummySparkPlan (outputPartitioning = SinglePartition )
1298- val smjExec = SortMergeJoinExec (exprA :: Nil , exprB :: Nil , Inner , None , plan1, plan2 )
1297+ val rightPlan = DummySparkPlan (outputPartitioning = SinglePartition )
1298+ val join = SortMergeJoinExec (exprA :: Nil , exprB :: Nil , Inner , None , leftPlan, rightPlan )
12991299
1300- EnsureRequirements .apply(smjExec ) match {
1300+ EnsureRequirements .apply(join ) match {
13011301 case SortMergeJoinExec (_, _, _, _,
13021302 SortExec (_, _, ShuffleExchangeExec (_ : HashPartitioning , _, _, _), _),
13031303 SortExec (_, _, ShuffleExchangeExec (_ : HashPartitioning , _, _, _), _), _) =>
@@ -1310,16 +1310,17 @@ class EnsureRequirementsSuite extends SharedSparkSession {
13101310
13111311 test(" ShufflePartitionIdPassThrough - compatible with multiple clustering keys" ) {
13121312 withSQLConf(SQLConf .SHUFFLE_PARTITIONS .key -> " 10" ) {
1313+ val passThrough_a_5 = ShufflePartitionIdPassThrough (DirectShufflePartitionID (exprA), 5 )
1314+ val passThrough_b_5 = ShufflePartitionIdPassThrough (DirectShufflePartitionID (exprB), 5 )
1315+
13131316 // Both partitioned by exprA, joined on (exprA, exprB)
13141317 // Should be compatible because exprA positions overlap
1315- val plan1 = DummySparkPlan (
1316- outputPartitioning = ShufflePartitionIdPassThrough (DirectShufflePartitionID (exprA), 5 ))
1317- val plan2 = DummySparkPlan (
1318- outputPartitioning = ShufflePartitionIdPassThrough (DirectShufflePartitionID (exprA), 5 ))
1319- val smjExec = SortMergeJoinExec (exprA :: exprB :: Nil , exprA :: exprB :: Nil , Inner , None ,
1320- plan1, plan2)
1318+ val leftPlanA = DummySparkPlan (outputPartitioning = passThrough_a_5)
1319+ val rightPlanA = DummySparkPlan (outputPartitioning = passThrough_a_5)
1320+ val joinA = SortMergeJoinExec (exprA :: exprB :: Nil , exprA :: exprB :: Nil , Inner , None ,
1321+ leftPlanA, rightPlanA)
13211322
1322- EnsureRequirements .apply(smjExec ) match {
1323+ EnsureRequirements .apply(joinA ) match {
13231324 case SortMergeJoinExec (
13241325 leftKeys,
13251326 rightKeys,
@@ -1338,14 +1339,12 @@ class EnsureRequirementsSuite extends SharedSparkSession {
13381339 // Test case 2: partition key matches at position 1
13391340 // Both sides partitioned by exprB and join on (exprA, exprB)
13401341 // Should be compatible because partition key exprB matches at position 1 in join keys
1341- val plan3 = DummySparkPlan (
1342- outputPartitioning = ShufflePartitionIdPassThrough (DirectShufflePartitionID (exprB), 5 ))
1343- val plan4 = DummySparkPlan (
1344- outputPartitioning = ShufflePartitionIdPassThrough (DirectShufflePartitionID (exprB), 5 ))
1345- val smjExec2 = SortMergeJoinExec (exprA :: exprB :: Nil , exprA :: exprB :: Nil , Inner , None ,
1346- plan3, plan4)
1342+ val leftPlanB = DummySparkPlan (outputPartitioning = passThrough_b_5)
1343+ val rightPlanB = DummySparkPlan (outputPartitioning = passThrough_b_5)
1344+ val joinB = SortMergeJoinExec (exprA :: exprB :: Nil , exprA :: exprB :: Nil , Inner , None ,
1345+ leftPlanB, rightPlanB)
13471346
1348- EnsureRequirements .apply(smjExec2 ) match {
1347+ EnsureRequirements .apply(joinB ) match {
13491348 case SortMergeJoinExec (
13501349 leftKeys,
13511350 rightKeys,
@@ -1368,13 +1367,13 @@ class EnsureRequirementsSuite extends SharedSparkSession {
13681367 withSQLConf(SQLConf .SHUFFLE_PARTITIONS .key -> " 10" ) {
13691368 // Partitioned by exprA and exprB respectively, but joining on completely different keys
13701369 // Should require shuffles because partition keys don't match join keys
1371- val plan1 = DummySparkPlan (
1370+ val leftPlan = DummySparkPlan (
13721371 outputPartitioning = ShufflePartitionIdPassThrough (DirectShufflePartitionID (exprA), 5 ))
1373- val plan2 = DummySparkPlan (
1372+ val rightPlan = DummySparkPlan (
13741373 outputPartitioning = ShufflePartitionIdPassThrough (DirectShufflePartitionID (exprB), 5 ))
1375- val smjExec = SortMergeJoinExec (exprC :: Nil , exprD :: Nil , Inner , None , plan1, plan2 )
1374+ val join = SortMergeJoinExec (exprC :: Nil , exprD :: Nil , Inner , None , leftPlan, rightPlan )
13761375
1377- EnsureRequirements .apply(smjExec ) match {
1376+ EnsureRequirements .apply(join ) match {
13781377 case SortMergeJoinExec (_, _, _, _,
13791378 SortExec (_, _, ShuffleExchangeExec (p1 : HashPartitioning , _, _, _), _),
13801379 SortExec (_, _, ShuffleExchangeExec (p2 : HashPartitioning , _, _, _), _), _) =>
@@ -1395,14 +1394,14 @@ class EnsureRequirementsSuite extends SharedSparkSession {
13951394 // Test if cross-position matching works: left partition key exprA matches right join key
13961395 // exprA (pos 0)
13971396 // and right partition key exprB matches left join key exprB (pos 1)
1398- val plan1 = DummySparkPlan (
1397+ val leftPlan = DummySparkPlan (
13991398 outputPartitioning = ShufflePartitionIdPassThrough (DirectShufflePartitionID (exprA), 5 ))
1400- val plan2 = DummySparkPlan (
1399+ val rightPlan = DummySparkPlan (
14011400 outputPartitioning = ShufflePartitionIdPassThrough (DirectShufflePartitionID (exprB), 5 ))
1402- val smjExec = SortMergeJoinExec (exprA :: exprB :: Nil , exprA :: exprB :: Nil , Inner , None ,
1403- plan1, plan2 )
1401+ val join = SortMergeJoinExec (exprA :: exprB :: Nil , exprA :: exprB :: Nil , Inner , None ,
1402+ leftPlan, rightPlan )
14041403
1405- EnsureRequirements .apply(smjExec ) match {
1404+ EnsureRequirements .apply(join ) match {
14061405 case SortMergeJoinExec (_, _, _, _,
14071406 SortExec (_, _, ShuffleExchangeExec (p1 : HashPartitioning , _, _, _), _),
14081407 SortExec (_, _, ShuffleExchangeExec (p2 : HashPartitioning , _, _, _), _), _) =>
0 commit comments