Skip to content

Commit 0f00cde

Browse files
committed
save
1 parent 7ebdc21 commit 0f00cde

File tree

4 files changed

+36
-8
lines changed

4 files changed

+36
-8
lines changed

client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java

+7
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,13 @@ private void initializeLifecycleManager(String appId) {
108108

109109
lifecycleManager.registerShuffleTrackerCallback(
110110
shuffleId -> mapOutputTracker.unregisterAllMapOutput(shuffleId));
111+
112+
lifecycleManager.registerBroadcastGetReducerFileGroupResponse(
113+
(shuffleId, getReducerFileGroupResponse) ->
114+
SparkUtils.serializeGetReducerFileGroupResponse(
115+
shuffleId, getReducerFileGroupResponse));
116+
lifecycleManager.registerInvalidatedBroadcastGetReducerFileGroupResponse(
117+
shuffleId -> SparkUtils.invalidateSerializedGetReducerFileGroupResponse(shuffleId));
111118
}
112119
}
113120
}

client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala

+11
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ package org.apache.spark.shuffle.celeborn
2020
import java.io.IOException
2121
import java.util.concurrent.{ThreadPoolExecutor, TimeUnit}
2222
import java.util.concurrent.atomic.AtomicReference
23+
import java.util.function.BiFunction
2324

2425
import org.apache.spark.{Aggregator, InterruptibleIterator, TaskContext}
2526
import org.apache.spark.internal.Logging
@@ -33,6 +34,7 @@ import org.apache.celeborn.client.read.CelebornInputStream
3334
import org.apache.celeborn.client.read.MetricsCallback
3435
import org.apache.celeborn.common.CelebornConf
3536
import org.apache.celeborn.common.exception.{CelebornIOException, PartitionUnRetryAbleException}
37+
import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse
3638
import org.apache.celeborn.common.util.{JavaUtils, ThreadUtils}
3739

3840
class CelebornShuffleReader[K, C](
@@ -254,4 +256,13 @@ class CelebornShuffleReader[K, C](
254256

255257
object CelebornShuffleReader {
256258
var streamCreatorPool: ThreadPoolExecutor = null
259+
// Register the deserializer for GetReducerFileGroupResponse
260+
ShuffleClient.registerDeserializeReducerFileGroupResponseFunction(new BiFunction[
261+
Integer,
262+
Array[Byte],
263+
GetReducerFileGroupResponse] {
264+
override def apply(shuffleId: Integer, broadcast: Array[Byte]): GetReducerFileGroupResponse = {
265+
SparkUtils.deserializeGetReducerFileGroupResponse(shuffleId, broadcast)
266+
}
267+
})
257268
}

client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala

+12-4
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import java.io.IOException
2121
import java.util.{ArrayList => JArrayList, HashMap => JHashMap, Map => JMap, Set => JSet}
2222
import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor, TimeoutException, TimeUnit}
2323
import java.util.concurrent.atomic.AtomicReference
24+
import java.util.function.BiFunction
2425

2526
import scala.collection.JavaConverters._
2627

@@ -43,7 +44,8 @@ import org.apache.celeborn.common.exception.{CelebornIOException, PartitionUnRet
4344
import org.apache.celeborn.common.network.client.TransportClient
4445
import org.apache.celeborn.common.network.protocol.TransportMessage
4546
import org.apache.celeborn.common.protocol._
46-
import org.apache.celeborn.common.protocol.message.StatusCode
47+
import org.apache.celeborn.common.protocol.message.{ControlMessages, StatusCode}
48+
import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse
4749
import org.apache.celeborn.common.util.{JavaUtils, ThreadUtils, Utils}
4850

4951
class CelebornShuffleReader[K, C](
@@ -60,9 +62,6 @@ class CelebornShuffleReader[K, C](
6062

6163
private val dep = handle.dependency
6264

63-
ShuffleClient.registerDeserializeReducerFileGroupResponseFunction(
64-
SparkUtils.deserializeGetReducerFileGroupResponse)
65-
6665
@VisibleForTesting
6766
val shuffleClient = ShuffleClient.get(
6867
handle.appUniqueId,
@@ -458,4 +457,13 @@ class CelebornShuffleReader[K, C](
458457

459458
object CelebornShuffleReader {
460459
var streamCreatorPool: ThreadPoolExecutor = null
460+
// Register the deserializer for GetReducerFileGroupResponse
461+
ShuffleClient.registerDeserializeReducerFileGroupResponseFunction(new BiFunction[
462+
Integer,
463+
Array[Byte],
464+
GetReducerFileGroupResponse] {
465+
override def apply(shuffleId: Integer, broadcast: Array[Byte]): GetReducerFileGroupResponse = {
466+
SparkUtils.deserializeGetReducerFileGroupResponse(shuffleId, broadcast)
467+
}
468+
})
461469
}

client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java

+6-4
Original file line numberDiff line numberDiff line change
@@ -1815,12 +1815,14 @@ protected Tuple3<ReduceFileGroups, String, Exception> loadFileGroupInternal(
18151815
switch (response.status()) {
18161816
case SUCCESS:
18171817
if (response.broadcast() != null && response.broadcast().length > 0) {
1818-
GetReducerFileGroupResponse broadcast =
1818+
logger.info(
1819+
"Deserializing broadcast GetReducerFileGroupResponse for shuffle: {}.", shuffleId);
1820+
response =
18191821
ShuffleClient.deserializeReducerFileGroupResponse(shuffleId, response.broadcast());
1820-
if (broadcast == null) {
1821-
throw new CelebornIOException("Broadcast response is null!");
1822+
if (response == null) {
1823+
throw new CelebornIOException(
1824+
"Failed to get broadcast GetReducerFileGroupResponse for shuffle: " + shuffleId);
18221825
}
1823-
response = broadcast;
18241826
}
18251827
logger.info(
18261828
"Shuffle {} request reducer file group success using {} ms, result partition size {}.",

0 commit comments

Comments
 (0)