Skip to content

Commit ba003e4

Browse files
committed
Broadcast GetReducerFileGroupResponse
1 parent 7174275 commit ba003e4

File tree

21 files changed

+568
-43
lines changed

21 files changed

+568
-43
lines changed

LICENSE

+1
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ Apache Spark
223223
./common/src/test/java/org/apache/celeborn/common/network/protocol/EncryptedMessageWithHeaderSuiteJ.java
224224
./common/src/test/java/org/apache/celeborn/common/network/ssl/ReloadingX509TrustManagerSuiteJ.java
225225
./common/src/test/java/org/apache/celeborn/common/network/ssl/SslSampleConfigs.java
226+
./common/src/main/scala/org/apache/celeborn/common/util/KeyLock.scala
226227
./worker/src/main/java/org/apache/celeborn/service/deploy/worker/shuffledb/DB.java
227228
./worker/src/main/java/org/apache/celeborn/service/deploy/worker/shuffledb/DBIterator.java
228229
./worker/src/main/java/org/apache/celeborn/service/deploy/worker/shuffledb/LevelDB.java

client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java

+7
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,13 @@ private void initializeLifecycleManager(String appId) {
108108

109109
lifecycleManager.registerShuffleTrackerCallback(
110110
shuffleId -> mapOutputTracker.unregisterAllMapOutput(shuffleId));
111+
112+
lifecycleManager.registerBroadcastGetReducerFileGroupResponse(
113+
(shuffleId, getReducerFileGroupResponse) ->
114+
SparkUtils.serializeGetReducerFileGroupResponse(
115+
shuffleId, getReducerFileGroupResponse));
116+
lifecycleManager.registerInvalidatedBroadcastGetReducerFileGroupResponse(
117+
shuffleId -> SparkUtils.invalidateSerializedGetReducerFileGroupResponse(shuffleId));
111118
}
112119
}
113120
}

client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java

+114
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@
1717

1818
package org.apache.spark.shuffle.celeborn;
1919

20+
import java.io.ByteArrayInputStream;
2021
import java.io.IOException;
22+
import java.io.ObjectInputStream;
23+
import java.io.ObjectOutputStream;
2124
import java.lang.reflect.Field;
2225
import java.lang.reflect.Method;
2326
import java.util.HashSet;
@@ -38,6 +41,7 @@
3841
import org.apache.spark.SparkContext;
3942
import org.apache.spark.SparkContext$;
4043
import org.apache.spark.TaskContext;
44+
import org.apache.spark.broadcast.Broadcast;
4145
import org.apache.spark.scheduler.DAGScheduler;
4246
import org.apache.spark.scheduler.MapStatus;
4347
import org.apache.spark.scheduler.MapStatus$;
@@ -54,7 +58,9 @@
5458

5559
import org.apache.celeborn.client.ShuffleClient;
5660
import org.apache.celeborn.common.CelebornConf;
61+
import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse;
5762
import org.apache.celeborn.common.util.JavaUtils;
63+
import org.apache.celeborn.common.util.KeyLock;
5864
import org.apache.celeborn.common.util.Utils;
5965
import org.apache.celeborn.reflect.DynFields;
6066

@@ -346,4 +352,112 @@ public static void addSparkListener(SparkListener listener) {
346352
sparkContext.addSparkListener(listener);
347353
}
348354
}
355+
356+
/**
357+
* A [[KeyLock]] whose key is a shuffle id to ensure there is only one thread accessing the
358+
* broadcast belonging to the shuffle id at a time.
359+
*/
360+
private static KeyLock<Integer> shuffleBroadcastLock = new KeyLock();
361+
362+
protected static Map<Integer, Tuple2<Broadcast<GetReducerFileGroupResponse>, byte[]>>
363+
getReducerFileGroupResponseBroadcasts = JavaUtils.newConcurrentHashMap();
364+
365+
public static byte[] serializeGetReducerFileGroupResponse(
366+
Integer shuffleId, GetReducerFileGroupResponse response) {
367+
SparkContext sparkContext = SparkContext$.MODULE$.getActive().getOrElse(null);
368+
if (sparkContext == null) {
369+
logger.error("Can not get active SparkContext.");
370+
return null;
371+
}
372+
373+
return shuffleBroadcastLock.withLock(
374+
shuffleId,
375+
() -> {
376+
Tuple2<Broadcast<GetReducerFileGroupResponse>, byte[]>
377+
cachedSerializeGetReducerFileGroupResponse =
378+
getReducerFileGroupResponseBroadcasts.get(shuffleId);
379+
if (cachedSerializeGetReducerFileGroupResponse != null) {
380+
return cachedSerializeGetReducerFileGroupResponse._2;
381+
}
382+
383+
try {
384+
Broadcast<GetReducerFileGroupResponse> broadcast =
385+
sparkContext.broadcast(
386+
response,
387+
scala.reflect.ClassManifestFactory.fromClass(
388+
GetReducerFileGroupResponse.class));
389+
390+
// Using `org.apache.commons.io.output.ByteArrayOutputStream` instead of the standard
391+
// one
392+
// This implementation doesn't reallocate the whole memory block but allocates
393+
// additional buffers. This way no buffers need to be garbage collected and
394+
// the contents don't have to be copied to the new buffer.
395+
org.apache.commons.io.output.ByteArrayOutputStream out =
396+
new org.apache.commons.io.output.ByteArrayOutputStream();
397+
try (ObjectOutputStream oos = new ObjectOutputStream(out)) {
398+
oos.writeObject(broadcast);
399+
}
400+
byte[] _serializeResult = out.toByteArray();
401+
getReducerFileGroupResponseBroadcasts.put(
402+
shuffleId, Tuple2.apply(broadcast, _serializeResult));
403+
return _serializeResult;
404+
} catch (Throwable e) {
405+
logger.error(
406+
"Failed to serialize GetReducerFileGroupResponse for shuffle: " + shuffleId, e);
407+
return null;
408+
}
409+
});
410+
}
411+
412+
public static GetReducerFileGroupResponse deserializeGetReducerFileGroupResponse(
413+
Integer shuffleId, byte[] bytes) {
414+
SparkContext sparkContext = SparkContext$.MODULE$.getActive().getOrElse(null);
415+
if (sparkContext == null) {
416+
logger.error("Can not get active SparkContext.");
417+
return null;
418+
}
419+
420+
return shuffleBroadcastLock.withLock(
421+
shuffleId,
422+
() -> {
423+
GetReducerFileGroupResponse response = null;
424+
425+
try {
426+
try (ObjectInputStream objIn = new ObjectInputStream(new ByteArrayInputStream(bytes))) {
427+
Broadcast<GetReducerFileGroupResponse> broadcast =
428+
(Broadcast<GetReducerFileGroupResponse>) objIn.readObject();
429+
response = broadcast.value();
430+
logger.info(
431+
"Shuffle {} broadcast GetReducerFileGroupResponse size={}.",
432+
shuffleId,
433+
bytes.length);
434+
}
435+
} catch (Throwable e) {
436+
logger.error(
437+
"Failed to deserialize GetReducerFileGroupResponse for shuffle: " + shuffleId, e);
438+
}
439+
return response;
440+
});
441+
}
442+
443+
public static void invalidateSerializedGetReducerFileGroupResponse(Integer shuffleId) {
444+
shuffleBroadcastLock.withLock(
445+
shuffleId,
446+
() -> {
447+
try {
448+
Tuple2<Broadcast<GetReducerFileGroupResponse>, byte[]>
449+
cachedSerializeGetReducerFileGroupResponse =
450+
getReducerFileGroupResponseBroadcasts.remove(shuffleId);
451+
if (cachedSerializeGetReducerFileGroupResponse != null) {
452+
cachedSerializeGetReducerFileGroupResponse._1().destroy();
453+
}
454+
} catch (Throwable e) {
455+
logger.error(
456+
"Failed to invalidate serialized GetReducerFileGroupResponse for shuffle: "
457+
+ shuffleId,
458+
e);
459+
}
460+
return null;
461+
});
462+
}
349463
}

client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala

+11
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ package org.apache.spark.shuffle.celeborn
2020
import java.io.IOException
2121
import java.util.concurrent.{ThreadPoolExecutor, TimeUnit}
2222
import java.util.concurrent.atomic.AtomicReference
23+
import java.util.function.BiFunction
2324

2425
import org.apache.spark.{Aggregator, InterruptibleIterator, TaskContext}
2526
import org.apache.spark.internal.Logging
@@ -33,6 +34,7 @@ import org.apache.celeborn.client.read.CelebornInputStream
3334
import org.apache.celeborn.client.read.MetricsCallback
3435
import org.apache.celeborn.common.CelebornConf
3536
import org.apache.celeborn.common.exception.{CelebornIOException, PartitionUnRetryAbleException}
37+
import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse
3638
import org.apache.celeborn.common.util.{JavaUtils, ThreadUtils}
3739

3840
class CelebornShuffleReader[K, C](
@@ -254,4 +256,13 @@ class CelebornShuffleReader[K, C](
254256

255257
object CelebornShuffleReader {
  var streamCreatorPool: ThreadPoolExecutor = null

  // Register the deserializer used by ShuffleClient to turn broadcast bytes back into a
  // GetReducerFileGroupResponse.
  ShuffleClient.registerDeserializeReducerFileGroupResponseFunction(
    new BiFunction[Integer, Array[Byte], GetReducerFileGroupResponse] {
      override def apply(
          shuffleId: Integer,
          broadcast: Array[Byte]): GetReducerFileGroupResponse =
        SparkUtils.deserializeGetReducerFileGroupResponse(shuffleId, broadcast)
    })
}

client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java

+7
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,13 @@ private void initializeLifecycleManager(String appId) {
150150

151151
lifecycleManager.registerShuffleTrackerCallback(
152152
shuffleId -> SparkUtils.unregisterAllMapOutput(mapOutputTracker, shuffleId));
153+
154+
lifecycleManager.registerBroadcastGetReducerFileGroupResponse(
155+
(shuffleId, getReducerFileGroupResponse) ->
156+
SparkUtils.serializeGetReducerFileGroupResponse(
157+
shuffleId, getReducerFileGroupResponse));
158+
lifecycleManager.registerInvalidatedBroadcastGetReducerFileGroupResponse(
159+
shuffleId -> SparkUtils.invalidateSerializedGetReducerFileGroupResponse(shuffleId));
153160
}
154161
}
155162
}

client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java

+120
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717

1818
package org.apache.spark.shuffle.celeborn;
1919

20+
import java.io.ByteArrayInputStream;
21+
import java.io.ObjectInputStream;
22+
import java.io.ObjectOutputStream;
2023
import java.util.HashSet;
2124
import java.util.List;
2225
import java.util.Map;
@@ -36,6 +39,8 @@
3639
import org.apache.spark.SparkContext;
3740
import org.apache.spark.SparkContext$;
3841
import org.apache.spark.TaskContext;
42+
import org.apache.spark.broadcast.Broadcast;
43+
import org.apache.spark.io.CompressionCodec;
3944
import org.apache.spark.scheduler.DAGScheduler;
4045
import org.apache.spark.scheduler.MapStatus;
4146
import org.apache.spark.scheduler.MapStatus$;
@@ -57,7 +62,9 @@
5762

5863
import org.apache.celeborn.client.ShuffleClient;
5964
import org.apache.celeborn.common.CelebornConf;
65+
import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse;
6066
import org.apache.celeborn.common.util.JavaUtils;
67+
import org.apache.celeborn.common.util.KeyLock;
6168
import org.apache.celeborn.reflect.DynConstructors;
6269
import org.apache.celeborn.reflect.DynFields;
6370
import org.apache.celeborn.reflect.DynMethods;
@@ -462,4 +469,117 @@ public static void addSparkListener(SparkListener listener) {
462469
sparkContext.addSparkListener(listener);
463470
}
464471
}
472+
473+
/**
474+
* A [[KeyLock]] whose key is a shuffle id to ensure there is only one thread accessing the
475+
* broadcast belonging to the shuffle id at a time.
476+
*/
477+
private static KeyLock<Integer> shuffleBroadcastLock = new KeyLock();
478+
479+
protected static Map<Integer, Tuple2<Broadcast<GetReducerFileGroupResponse>, byte[]>>
480+
getReducerFileGroupResponseBroadcasts = JavaUtils.newConcurrentHashMap();
481+
482+
public static byte[] serializeGetReducerFileGroupResponse(
483+
Integer shuffleId, GetReducerFileGroupResponse response) {
484+
SparkContext sparkContext = SparkContext$.MODULE$.getActive().getOrElse(null);
485+
if (sparkContext == null) {
486+
LOG.error("Can not get active SparkContext.");
487+
return null;
488+
}
489+
490+
return shuffleBroadcastLock.withLock(
491+
shuffleId,
492+
() -> {
493+
Tuple2<Broadcast<GetReducerFileGroupResponse>, byte[]>
494+
cachedSerializeGetReducerFileGroupResponse =
495+
getReducerFileGroupResponseBroadcasts.get(shuffleId);
496+
if (cachedSerializeGetReducerFileGroupResponse != null) {
497+
return cachedSerializeGetReducerFileGroupResponse._2;
498+
}
499+
500+
try {
501+
Broadcast<GetReducerFileGroupResponse> broadcast =
502+
sparkContext.broadcast(
503+
response,
504+
scala.reflect.ClassManifestFactory.fromClass(
505+
GetReducerFileGroupResponse.class));
506+
507+
CompressionCodec codec = CompressionCodec.createCodec(sparkContext.conf());
508+
// Using `org.apache.commons.io.output.ByteArrayOutputStream` instead of the standard
509+
// one
510+
// This implementation doesn't reallocate the whole memory block but allocates
511+
// additional buffers. This way no buffers need to be garbage collected and
512+
// the contents don't have to be copied to the new buffer.
513+
org.apache.commons.io.output.ByteArrayOutputStream out =
514+
new org.apache.commons.io.output.ByteArrayOutputStream();
515+
try (ObjectOutputStream oos =
516+
new ObjectOutputStream(codec.compressedOutputStream(out))) {
517+
oos.writeObject(broadcast);
518+
}
519+
byte[] _serializeResult = out.toByteArray();
520+
getReducerFileGroupResponseBroadcasts.put(
521+
shuffleId, Tuple2.apply(broadcast, _serializeResult));
522+
return _serializeResult;
523+
} catch (Throwable e) {
524+
LOG.error(
525+
"Failed to serialize GetReducerFileGroupResponse for shuffle: " + shuffleId, e);
526+
return null;
527+
}
528+
});
529+
}
530+
531+
public static GetReducerFileGroupResponse deserializeGetReducerFileGroupResponse(
532+
Integer shuffleId, byte[] bytes) {
533+
SparkContext sparkContext = SparkContext$.MODULE$.getActive().getOrElse(null);
534+
if (sparkContext == null) {
535+
LOG.error("Can not get active SparkContext.");
536+
return null;
537+
}
538+
539+
return shuffleBroadcastLock.withLock(
540+
shuffleId,
541+
() -> {
542+
GetReducerFileGroupResponse response = null;
543+
544+
try {
545+
CompressionCodec codec = CompressionCodec.createCodec(sparkContext.conf());
546+
try (ObjectInputStream objIn =
547+
new ObjectInputStream(
548+
codec.compressedInputStream(new ByteArrayInputStream(bytes)))) {
549+
Broadcast<GetReducerFileGroupResponse> broadcast =
550+
(Broadcast<GetReducerFileGroupResponse>) objIn.readObject();
551+
response = broadcast.value();
552+
LOG.info(
553+
"Shuffle {} broadcast GetReducerFileGroupResponse size={}.",
554+
shuffleId,
555+
bytes.length);
556+
}
557+
} catch (Throwable e) {
558+
LOG.error(
559+
"Failed to deserialize GetReducerFileGroupResponse for shuffle: " + shuffleId, e);
560+
}
561+
return response;
562+
});
563+
}
564+
565+
public static void invalidateSerializedGetReducerFileGroupResponse(Integer shuffleId) {
566+
shuffleBroadcastLock.withLock(
567+
shuffleId,
568+
() -> {
569+
try {
570+
Tuple2<Broadcast<GetReducerFileGroupResponse>, byte[]>
571+
cachedSerializeGetReducerFileGroupResponse =
572+
getReducerFileGroupResponseBroadcasts.remove(shuffleId);
573+
if (cachedSerializeGetReducerFileGroupResponse != null) {
574+
cachedSerializeGetReducerFileGroupResponse._1().destroy();
575+
}
576+
} catch (Throwable e) {
577+
LOG.error(
578+
"Failed to invalidate serialized GetReducerFileGroupResponse for shuffle: "
579+
+ shuffleId,
580+
e);
581+
}
582+
return null;
583+
});
584+
}
465585
}

client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala

+12-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import java.io.IOException
2121
import java.util.{ArrayList => JArrayList, HashMap => JHashMap, Map => JMap, Set => JSet}
2222
import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor, TimeoutException, TimeUnit}
2323
import java.util.concurrent.atomic.AtomicReference
24+
import java.util.function.BiFunction
2425

2526
import scala.collection.JavaConverters._
2627

@@ -43,7 +44,8 @@ import org.apache.celeborn.common.exception.{CelebornIOException, PartitionUnRet
4344
import org.apache.celeborn.common.network.client.TransportClient
4445
import org.apache.celeborn.common.network.protocol.TransportMessage
4546
import org.apache.celeborn.common.protocol._
46-
import org.apache.celeborn.common.protocol.message.StatusCode
47+
import org.apache.celeborn.common.protocol.message.{ControlMessages, StatusCode}
48+
import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse
4749
import org.apache.celeborn.common.util.{JavaUtils, ThreadUtils, Utils}
4850

4951
class CelebornShuffleReader[K, C](
@@ -455,4 +457,13 @@ class CelebornShuffleReader[K, C](
455457

456458
object CelebornShuffleReader {
  var streamCreatorPool: ThreadPoolExecutor = null

  // Register the deserializer used by ShuffleClient to turn broadcast bytes back into a
  // GetReducerFileGroupResponse.
  ShuffleClient.registerDeserializeReducerFileGroupResponseFunction(
    new BiFunction[Integer, Array[Byte], GetReducerFileGroupResponse] {
      override def apply(
          shuffleId: Integer,
          broadcast: Array[Byte]): GetReducerFileGroupResponse =
        SparkUtils.deserializeGetReducerFileGroupResponse(shuffleId, broadcast)
    })
}

0 commit comments

Comments
 (0)