Skip to content

Commit ea274ca

Browse files
committed
ut
1 parent 3cc6f87 commit ea274ca

File tree

7 files changed

+105
-14
lines changed

client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java

+5-4
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import java.util.Map;
2929
import java.util.Set;
3030
import java.util.concurrent.ConcurrentHashMap;
31+
import java.util.concurrent.atomic.AtomicInteger;
3132
import java.util.concurrent.atomic.LongAdder;
3233
import java.util.stream.Collectors;
3334

@@ -359,6 +360,9 @@ public static void addSparkListener(SparkListener listener) {
359360
*/
360361
private static KeyLock<Integer> shuffleBroadcastLock = new KeyLock();
361362

363+
@VisibleForTesting
364+
public static AtomicInteger getReducerFileGroupResponseBroadcastNum = new AtomicInteger();
365+
362366
protected static Map<Integer, Tuple2<Broadcast<GetReducerFileGroupResponse>, byte[]>>
363367
getReducerFileGroupResponseBroadcasts = JavaUtils.newConcurrentHashMap();
364368

@@ -401,6 +405,7 @@ public static byte[] serializeGetReducerFileGroupResponse(
401405
byte[] _serializeResult = out.toByteArray();
402406
getReducerFileGroupResponseBroadcasts.put(
403407
shuffleId, Tuple2.apply(broadcast, _serializeResult));
408+
getReducerFileGroupResponseBroadcastNum.incrementAndGet();
404409
return _serializeResult;
405410
} catch (Throwable e) {
406411
logger.error(
@@ -430,10 +435,6 @@ public static GetReducerFileGroupResponse deserializeGetReducerFileGroupResponse
430435
Broadcast<GetReducerFileGroupResponse> broadcast =
431436
(Broadcast<GetReducerFileGroupResponse>) objIn.readObject();
432437
response = broadcast.value();
433-
logger.info(
434-
"Shuffle {} broadcast GetReducerFileGroupResponse size={}.",
435-
shuffleId,
436-
bytes.length);
437438
}
438439
} catch (Throwable e) {
439440
logger.error(

client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java

+5-4
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import java.util.Map;
2626
import java.util.Set;
2727
import java.util.concurrent.ConcurrentHashMap;
28+
import java.util.concurrent.atomic.AtomicInteger;
2829
import java.util.concurrent.atomic.LongAdder;
2930
import java.util.stream.Collectors;
3031

@@ -476,6 +477,9 @@ public static void addSparkListener(SparkListener listener) {
476477
*/
477478
private static KeyLock<Integer> shuffleBroadcastLock = new KeyLock();
478479

480+
@VisibleForTesting
481+
public static AtomicInteger getReducerFileGroupResponseBroadcastNum = new AtomicInteger();
482+
479483
protected static Map<Integer, Tuple2<Broadcast<GetReducerFileGroupResponse>, byte[]>>
480484
getReducerFileGroupResponseBroadcasts = JavaUtils.newConcurrentHashMap();
481485

@@ -520,6 +524,7 @@ public static byte[] serializeGetReducerFileGroupResponse(
520524
byte[] _serializeResult = out.toByteArray();
521525
getReducerFileGroupResponseBroadcasts.put(
522526
shuffleId, Tuple2.apply(broadcast, _serializeResult));
527+
getReducerFileGroupResponseBroadcastNum.incrementAndGet();
523528
return _serializeResult;
524529
} catch (Throwable e) {
525530
LOG.error(
@@ -552,10 +557,6 @@ public static GetReducerFileGroupResponse deserializeGetReducerFileGroupResponse
552557
Broadcast<GetReducerFileGroupResponse> broadcast =
553558
(Broadcast<GetReducerFileGroupResponse>) objIn.readObject();
554559
response = broadcast.value();
555-
LOG.info(
556-
"Shuffle {} broadcast GetReducerFileGroupResponse size={}.",
557-
shuffleId,
558-
bytes.length);
559560
}
560561
} catch (Throwable e) {
561562
LOG.error(

client/src/main/java/org/apache/celeborn/client/ShuffleClient.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ public abstract int getShuffleId(
308308
public static void registerDeserializeReducerFileGroupResponseFunction(
309309
BiFunction<Integer, byte[], ControlMessages.GetReducerFileGroupResponse> function) {
310310
if (!deserializeReducerFileGroupResponseFunction.isPresent()) {
311-
deserializeReducerFileGroupResponseFunction = Optional.of(function);
311+
deserializeReducerFileGroupResponseFunction = Optional.ofNullable(function);
312312
}
313313
}
314314

client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala

+14-4
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,10 @@ class ReducePartitionCommitHandler(
7979
private val rpcCacheConcurrencyLevel = conf.clientRpcCacheConcurrencyLevel
8080
private val rpcCacheExpireTime = conf.clientRpcCacheExpireTime
8181

82+
private val getReducerFileGroupResponseBroadcastEnabled = conf.getReducerFileGroupBroadcastEnabled
83+
private val getReducerFileGroupResponseBroadcastMiniSize =
84+
conf.getReducerFileGroupBroadcastMiniSize
85+
8286
// noinspection UnstableApiUsage
8387
private val getReducerFileGroupRpcCache: Cache[Int, ByteBuffer] = CacheBuilder.newBuilder()
8488
.concurrencyLevel(rpcCacheConcurrencyLevel)
@@ -327,7 +331,7 @@ class ReducePartitionCommitHandler(
327331
getMapperAttempts(shuffleId))
328332

329333
// only check whether broadcast enabled for the UTs
330-
if (conf.getReducerFileGroupBroadcastEnabled) {
334+
if (getReducerFileGroupResponseBroadcastEnabled) {
331335
response = broadcastGetReducerFileGroup(shuffleId, response)
332336
}
333337

@@ -349,11 +353,17 @@ class ReducePartitionCommitHandler(
349353
val serializedMsg =
350354
context.asInstanceOf[RemoteNettyRpcCallContext].nettyEnv.serialize(returnedMsg)
351355

352-
if (conf.getReducerFileGroupBroadcastEnabled &&
353-
serializedMsg.capacity() >= conf.getReducerFileGroupBroadcastMiniSize) {
356+
if (getReducerFileGroupResponseBroadcastEnabled &&
357+
serializedMsg.capacity() >= getReducerFileGroupResponseBroadcastMiniSize) {
354358
val broadcastMsg = broadcastGetReducerFileGroup(shuffleId, returnedMsg)
355359
if (broadcastMsg != returnedMsg) {
356-
context.asInstanceOf[RemoteNettyRpcCallContext].nettyEnv.serialize(broadcastMsg)
360+
val serializedBroadcastMsg =
361+
context.asInstanceOf[RemoteNettyRpcCallContext].nettyEnv.serialize(broadcastMsg)
362+
logInfo(s"Shuffle $shuffleId GetReducerFileGroupResponse size" +
363+
s" ${serializedMsg.capacity()} reached the broadcast threshold" +
364+
s" $getReducerFileGroupResponseBroadcastMiniSize," +
365+
s" the broadcast response size is ${serializedBroadcastMsg.capacity()}.")
366+
serializedBroadcastMsg
357367
} else {
358368
serializedMsg
359369
}

common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala

+3-1
Original file line numberDiff line numberDiff line change
@@ -1207,12 +1207,14 @@ object ControlMessages extends Logging {
12071207
case (uniqueId, pushFailedBatchSet) =>
12081208
(uniqueId, PbSerDeUtils.fromPbPushFailedBatchSet(pushFailedBatchSet))
12091209
}.toMap.asJava
1210+
val broadcast = pbGetReducerFileGroupResponse.getBroadcast.toByteArray
12101211
GetReducerFileGroupResponse(
12111212
Utils.toStatusCode(pbGetReducerFileGroupResponse.getStatus),
12121213
fileGroup,
12131214
attempts,
12141215
partitionIds,
1215-
pushFailedBatches)
1216+
pushFailedBatches,
1217+
broadcast)
12161218

12171219
case GET_SHUFFLE_ID_VALUE =>
12181220
message.getParsedPayload()

tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornHashSuite.scala

+38
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,13 @@
1818
package org.apache.celeborn.tests.spark
1919

2020
import org.apache.spark.SparkConf
21+
import org.apache.spark.shuffle.celeborn.SparkUtils
2122
import org.apache.spark.sql.SparkSession
2223
import org.scalatest.BeforeAndAfterEach
2324
import org.scalatest.funsuite.AnyFunSuite
2425

2526
import org.apache.celeborn.client.ShuffleClient
27+
import org.apache.celeborn.common.CelebornConf
2628
import org.apache.celeborn.common.protocol.ShuffleMode
2729

2830
class CelebornHashSuite extends AnyFunSuite
@@ -64,4 +66,40 @@ class CelebornHashSuite extends AnyFunSuite
6466

6567
celebornSparkSession.stop()
6668
}
69+
70+
test("celeborn spark integration test - GetReducerFileGroupResponse broadcast") {
71+
SparkUtils.getReducerFileGroupResponseBroadcastNum.set(0)
72+
val sparkConf = new SparkConf().setAppName("celeborn-demo").setMaster("local[2]")
73+
.set(
74+
s"spark.${CelebornConf.CLIENT_SHUFFLE_GET_REDUCER_FILE_GROUP_BROADCAST_ENABLED.key}",
75+
"true")
76+
.set(
77+
s"spark.${CelebornConf.CLIENT_SHUFFLE_GET_REDUCER_FILE_GROUP_BROADCAST_MINI_SIZE.key}",
78+
"0")
79+
val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate()
80+
val combineResult = combine(sparkSession)
81+
val groupbyResult = groupBy(sparkSession)
82+
val repartitionResult = repartition(sparkSession)
83+
val sqlResult = runsql(sparkSession)
84+
85+
Thread.sleep(3000L)
86+
sparkSession.stop()
87+
88+
val celebornSparkSession = SparkSession.builder()
89+
.config(updateSparkConf(sparkConf, ShuffleMode.HASH))
90+
.getOrCreate()
91+
val celebornCombineResult = combine(celebornSparkSession)
92+
val celebornGroupbyResult = groupBy(celebornSparkSession)
93+
val celebornRepartitionResult = repartition(celebornSparkSession)
94+
val celebornSqlResult = runsql(celebornSparkSession)
95+
96+
assert(combineResult.equals(celebornCombineResult))
97+
assert(groupbyResult.equals(celebornGroupbyResult))
98+
assert(repartitionResult.equals(celebornRepartitionResult))
99+
assert(combineResult.equals(celebornCombineResult))
100+
assert(sqlResult.equals(celebornSqlResult))
101+
assert(SparkUtils.getReducerFileGroupResponseBroadcastNum.get() > 0)
102+
103+
celebornSparkSession.stop()
104+
}
67105
}

tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornSortSuite.scala

+39
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.celeborn.tests.spark
1919

2020
import org.apache.spark.SparkConf
21+
import org.apache.spark.shuffle.celeborn.SparkUtils
2122
import org.apache.spark.sql.SparkSession
2223
import org.scalatest.BeforeAndAfterEach
2324
import org.scalatest.funsuite.AnyFunSuite
@@ -66,4 +67,42 @@ class CelebornSortSuite extends AnyFunSuite
6667

6768
celebornSparkSession.stop()
6869
}
70+
71+
test("celeborn spark integration test - GetReducerFileGroupResponse broadcast") {
72+
SparkUtils.getReducerFileGroupResponseBroadcastNum.set(0)
73+
val sparkConf = new SparkConf().setAppName("celeborn-demo").setMaster("local[2]")
74+
.set(s"spark.${CelebornConf.CLIENT_PUSH_SORT_RANDOMIZE_PARTITION_ENABLED.key}", "false")
75+
.set(
76+
s"spark.${CelebornConf.CLIENT_SHUFFLE_GET_REDUCER_FILE_GROUP_BROADCAST_ENABLED.key}",
77+
"true")
78+
.set(
79+
s"spark.${CelebornConf.CLIENT_SHUFFLE_GET_REDUCER_FILE_GROUP_BROADCAST_MINI_SIZE.key}",
80+
"0")
81+
82+
val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate()
83+
val combineResult = combine(sparkSession)
84+
val groupbyResult = groupBy(sparkSession)
85+
val repartitionResult = repartition(sparkSession)
86+
val sqlResult = runsql(sparkSession)
87+
88+
Thread.sleep(3000L)
89+
sparkSession.stop()
90+
91+
val celebornSparkSession = SparkSession.builder()
92+
.config(updateSparkConf(sparkConf, ShuffleMode.SORT))
93+
.getOrCreate()
94+
val celebornCombineResult = combine(celebornSparkSession)
95+
val celebornGroupbyResult = groupBy(celebornSparkSession)
96+
val celebornRepartitionResult = repartition(celebornSparkSession)
97+
val celebornSqlResult = runsql(celebornSparkSession)
98+
99+
assert(combineResult.equals(celebornCombineResult))
100+
assert(groupbyResult.equals(celebornGroupbyResult))
101+
assert(repartitionResult.equals(celebornRepartitionResult))
102+
assert(combineResult.equals(celebornCombineResult))
103+
assert(sqlResult.equals(celebornSqlResult))
104+
assert(SparkUtils.getReducerFileGroupResponseBroadcastNum.get() > 0)
105+
106+
celebornSparkSession.stop()
107+
}
69108
}

0 commit comments