Skip to content

Commit 865690d

Browse files
committed
[CELEBORN-1921] Broadcast large GetReducerFileGroupResponse to prevent exhausting the Spark driver network
1 parent 7174275 commit 865690d

File tree

21 files changed

+652
-21
lines changed

21 files changed

+652
-21
lines changed

LICENSE

+1
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ Apache Spark
223223
./common/src/test/java/org/apache/celeborn/common/network/protocol/EncryptedMessageWithHeaderSuiteJ.java
224224
./common/src/test/java/org/apache/celeborn/common/network/ssl/ReloadingX509TrustManagerSuiteJ.java
225225
./common/src/test/java/org/apache/celeborn/common/network/ssl/SslSampleConfigs.java
226+
./common/src/main/scala/org/apache/celeborn/common/util/KeyLock.scala
226227
./worker/src/main/java/org/apache/celeborn/service/deploy/worker/shuffledb/DB.java
227228
./worker/src/main/java/org/apache/celeborn/service/deploy/worker/shuffledb/DBIterator.java
228229
./worker/src/main/java/org/apache/celeborn/service/deploy/worker/shuffledb/LevelDB.java

client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java

+7
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,13 @@ private void initializeLifecycleManager(String appId) {
108108

109109
lifecycleManager.registerShuffleTrackerCallback(
110110
shuffleId -> mapOutputTracker.unregisterAllMapOutput(shuffleId));
111+
112+
lifecycleManager.registerBroadcastGetReducerFileGroupResponse(
113+
(shuffleId, getReducerFileGroupResponse) ->
114+
SparkUtils.serializeGetReducerFileGroupResponse(
115+
shuffleId, getReducerFileGroupResponse));
116+
lifecycleManager.registerInvalidatedBroadcastGetReducerFileGroupResponse(
117+
shuffleId -> SparkUtils.invalidateSerializedGetReducerFileGroupResponse(shuffleId));
111118
}
112119
}
113120
}

client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java

+118
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,18 @@
1717

1818
package org.apache.spark.shuffle.celeborn;
1919

20+
import java.io.ByteArrayInputStream;
2021
import java.io.IOException;
22+
import java.io.ObjectInputStream;
23+
import java.io.ObjectOutputStream;
2124
import java.lang.reflect.Field;
2225
import java.lang.reflect.Method;
2326
import java.util.HashSet;
2427
import java.util.List;
2528
import java.util.Map;
2629
import java.util.Set;
2730
import java.util.concurrent.ConcurrentHashMap;
31+
import java.util.concurrent.atomic.AtomicInteger;
2832
import java.util.concurrent.atomic.LongAdder;
2933
import java.util.stream.Collectors;
3034

@@ -38,6 +42,7 @@
3842
import org.apache.spark.SparkContext;
3943
import org.apache.spark.SparkContext$;
4044
import org.apache.spark.TaskContext;
45+
import org.apache.spark.broadcast.Broadcast;
4146
import org.apache.spark.scheduler.DAGScheduler;
4247
import org.apache.spark.scheduler.MapStatus;
4348
import org.apache.spark.scheduler.MapStatus$;
@@ -54,7 +59,9 @@
5459

5560
import org.apache.celeborn.client.ShuffleClient;
5661
import org.apache.celeborn.common.CelebornConf;
62+
import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse;
5763
import org.apache.celeborn.common.util.JavaUtils;
64+
import org.apache.celeborn.common.util.KeyLock;
5865
import org.apache.celeborn.common.util.Utils;
5966
import org.apache.celeborn.reflect.DynFields;
6067

@@ -346,4 +353,115 @@ public static void addSparkListener(SparkListener listener) {
346353
sparkContext.addSparkListener(listener);
347354
}
348355
}
356+
357+
/**
358+
* A [[KeyLock]] whose key is a shuffle id to ensure there is only one thread accessing the
359+
* broadcast belonging to the shuffle id at a time.
360+
*/
361+
private static KeyLock<Integer> shuffleBroadcastLock = new KeyLock();
362+
363+
@VisibleForTesting
364+
public static AtomicInteger getReducerFileGroupResponseBroadcastNum = new AtomicInteger();
365+
366+
protected static Map<Integer, Tuple2<Broadcast<GetReducerFileGroupResponse>, byte[]>>
367+
getReducerFileGroupResponseBroadcasts = JavaUtils.newConcurrentHashMap();
368+
369+
public static byte[] serializeGetReducerFileGroupResponse(
370+
Integer shuffleId, GetReducerFileGroupResponse response) {
371+
SparkContext sparkContext = SparkContext$.MODULE$.getActive().getOrElse(null);
372+
if (sparkContext == null) {
373+
logger.error("Can not get active SparkContext.");
374+
return null;
375+
}
376+
377+
return shuffleBroadcastLock.withLock(
378+
shuffleId,
379+
() -> {
380+
Tuple2<Broadcast<GetReducerFileGroupResponse>, byte[]>
381+
cachedSerializeGetReducerFileGroupResponse =
382+
getReducerFileGroupResponseBroadcasts.get(shuffleId);
383+
if (cachedSerializeGetReducerFileGroupResponse != null) {
384+
return cachedSerializeGetReducerFileGroupResponse._2;
385+
}
386+
387+
try {
388+
logger.info("Broadcasting GetReducerFileGroupResponse for shuffle: {}", shuffleId);
389+
Broadcast<GetReducerFileGroupResponse> broadcast =
390+
sparkContext.broadcast(
391+
response,
392+
scala.reflect.ClassManifestFactory.fromClass(
393+
GetReducerFileGroupResponse.class));
394+
395+
// Using `org.apache.commons.io.output.ByteArrayOutputStream` instead of the standard
396+
// one
397+
// This implementation doesn't reallocate the whole memory block but allocates
398+
// additional buffers. This way no buffers need to be garbage collected and
399+
// the contents don't have to be copied to the new buffer.
400+
org.apache.commons.io.output.ByteArrayOutputStream out =
401+
new org.apache.commons.io.output.ByteArrayOutputStream();
402+
try (ObjectOutputStream oos = new ObjectOutputStream(out)) {
403+
oos.writeObject(broadcast);
404+
}
405+
byte[] _serializeResult = out.toByteArray();
406+
getReducerFileGroupResponseBroadcasts.put(
407+
shuffleId, Tuple2.apply(broadcast, _serializeResult));
408+
getReducerFileGroupResponseBroadcastNum.incrementAndGet();
409+
return _serializeResult;
410+
} catch (Throwable e) {
411+
logger.error(
412+
"Failed to serialize GetReducerFileGroupResponse for shuffle: " + shuffleId, e);
413+
return null;
414+
}
415+
});
416+
}
417+
418+
public static GetReducerFileGroupResponse deserializeGetReducerFileGroupResponse(
419+
Integer shuffleId, byte[] bytes) {
420+
SparkContext sparkContext = SparkContext$.MODULE$.getActive().getOrElse(null);
421+
if (sparkContext == null) {
422+
logger.error("Can not get active SparkContext.");
423+
return null;
424+
}
425+
426+
return shuffleBroadcastLock.withLock(
427+
shuffleId,
428+
() -> {
429+
GetReducerFileGroupResponse response = null;
430+
logger.info(
431+
"Deserializing GetReducerFileGroupResponse broadcast for shuffle: {}", shuffleId);
432+
433+
try {
434+
try (ObjectInputStream objIn = new ObjectInputStream(new ByteArrayInputStream(bytes))) {
435+
Broadcast<GetReducerFileGroupResponse> broadcast =
436+
(Broadcast<GetReducerFileGroupResponse>) objIn.readObject();
437+
response = broadcast.value();
438+
}
439+
} catch (Throwable e) {
440+
logger.error(
441+
"Failed to deserialize GetReducerFileGroupResponse for shuffle: " + shuffleId, e);
442+
}
443+
return response;
444+
});
445+
}
446+
447+
public static void invalidateSerializedGetReducerFileGroupResponse(Integer shuffleId) {
448+
shuffleBroadcastLock.withLock(
449+
shuffleId,
450+
() -> {
451+
try {
452+
Tuple2<Broadcast<GetReducerFileGroupResponse>, byte[]>
453+
cachedSerializeGetReducerFileGroupResponse =
454+
getReducerFileGroupResponseBroadcasts.remove(shuffleId);
455+
if (cachedSerializeGetReducerFileGroupResponse != null) {
456+
cachedSerializeGetReducerFileGroupResponse._1().destroy();
457+
}
458+
} catch (Throwable e) {
459+
logger.error(
460+
"Failed to invalidate serialized GetReducerFileGroupResponse for shuffle: "
461+
+ shuffleId,
462+
e);
463+
}
464+
return null;
465+
});
466+
}
349467
}

client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala

+11
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ package org.apache.spark.shuffle.celeborn
2020
import java.io.IOException
2121
import java.util.concurrent.{ThreadPoolExecutor, TimeUnit}
2222
import java.util.concurrent.atomic.AtomicReference
23+
import java.util.function.BiFunction
2324

2425
import org.apache.spark.{Aggregator, InterruptibleIterator, TaskContext}
2526
import org.apache.spark.internal.Logging
@@ -33,6 +34,7 @@ import org.apache.celeborn.client.read.CelebornInputStream
3334
import org.apache.celeborn.client.read.MetricsCallback
3435
import org.apache.celeborn.common.CelebornConf
3536
import org.apache.celeborn.common.exception.{CelebornIOException, PartitionUnRetryAbleException}
37+
import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse
3638
import org.apache.celeborn.common.util.{JavaUtils, ThreadUtils}
3739

3840
class CelebornShuffleReader[K, C](
@@ -254,4 +256,13 @@ class CelebornShuffleReader[K, C](
254256

255257
object CelebornShuffleReader {
256258
var streamCreatorPool: ThreadPoolExecutor = null
259+
// Register the deserializer for GetReducerFileGroupResponse broadcast
260+
ShuffleClient.registerDeserializeReducerFileGroupResponseFunction(new BiFunction[
261+
Integer,
262+
Array[Byte],
263+
GetReducerFileGroupResponse] {
264+
override def apply(shuffleId: Integer, broadcast: Array[Byte]): GetReducerFileGroupResponse = {
265+
SparkUtils.deserializeGetReducerFileGroupResponse(shuffleId, broadcast)
266+
}
267+
})
257268
}

client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java

+7
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,13 @@ private void initializeLifecycleManager(String appId) {
150150

151151
lifecycleManager.registerShuffleTrackerCallback(
152152
shuffleId -> SparkUtils.unregisterAllMapOutput(mapOutputTracker, shuffleId));
153+
154+
lifecycleManager.registerBroadcastGetReducerFileGroupResponse(
155+
(shuffleId, getReducerFileGroupResponse) ->
156+
SparkUtils.serializeGetReducerFileGroupResponse(
157+
shuffleId, getReducerFileGroupResponse));
158+
lifecycleManager.registerInvalidatedBroadcastGetReducerFileGroupResponse(
159+
shuffleId -> SparkUtils.invalidateSerializedGetReducerFileGroupResponse(shuffleId));
153160
}
154161
}
155162
}

client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java

+124
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,15 @@
1717

1818
package org.apache.spark.shuffle.celeborn;
1919

20+
import java.io.ByteArrayInputStream;
21+
import java.io.ObjectInputStream;
22+
import java.io.ObjectOutputStream;
2023
import java.util.HashSet;
2124
import java.util.List;
2225
import java.util.Map;
2326
import java.util.Set;
2427
import java.util.concurrent.ConcurrentHashMap;
28+
import java.util.concurrent.atomic.AtomicInteger;
2529
import java.util.concurrent.atomic.LongAdder;
2630
import java.util.stream.Collectors;
2731

@@ -36,6 +40,8 @@
3640
import org.apache.spark.SparkContext;
3741
import org.apache.spark.SparkContext$;
3842
import org.apache.spark.TaskContext;
43+
import org.apache.spark.broadcast.Broadcast;
44+
import org.apache.spark.io.CompressionCodec;
3945
import org.apache.spark.scheduler.DAGScheduler;
4046
import org.apache.spark.scheduler.MapStatus;
4147
import org.apache.spark.scheduler.MapStatus$;
@@ -57,7 +63,9 @@
5763

5864
import org.apache.celeborn.client.ShuffleClient;
5965
import org.apache.celeborn.common.CelebornConf;
66+
import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse;
6067
import org.apache.celeborn.common.util.JavaUtils;
68+
import org.apache.celeborn.common.util.KeyLock;
6169
import org.apache.celeborn.reflect.DynConstructors;
6270
import org.apache.celeborn.reflect.DynFields;
6371
import org.apache.celeborn.reflect.DynMethods;
@@ -462,4 +470,120 @@ public static void addSparkListener(SparkListener listener) {
462470
sparkContext.addSparkListener(listener);
463471
}
464472
}
473+
474+
/**
475+
* A [[KeyLock]] whose key is a shuffle id to ensure there is only one thread accessing the
476+
* broadcast belonging to the shuffle id at a time.
477+
*/
478+
private static KeyLock<Integer> shuffleBroadcastLock = new KeyLock();
479+
480+
@VisibleForTesting
481+
public static AtomicInteger getReducerFileGroupResponseBroadcastNum = new AtomicInteger();
482+
483+
protected static Map<Integer, Tuple2<Broadcast<GetReducerFileGroupResponse>, byte[]>>
484+
getReducerFileGroupResponseBroadcasts = JavaUtils.newConcurrentHashMap();
485+
486+
public static byte[] serializeGetReducerFileGroupResponse(
487+
Integer shuffleId, GetReducerFileGroupResponse response) {
488+
SparkContext sparkContext = SparkContext$.MODULE$.getActive().getOrElse(null);
489+
if (sparkContext == null) {
490+
LOG.error("Can not get active SparkContext.");
491+
return null;
492+
}
493+
494+
return shuffleBroadcastLock.withLock(
495+
shuffleId,
496+
() -> {
497+
Tuple2<Broadcast<GetReducerFileGroupResponse>, byte[]>
498+
cachedSerializeGetReducerFileGroupResponse =
499+
getReducerFileGroupResponseBroadcasts.get(shuffleId);
500+
if (cachedSerializeGetReducerFileGroupResponse != null) {
501+
return cachedSerializeGetReducerFileGroupResponse._2;
502+
}
503+
504+
try {
505+
LOG.info("Broadcasting GetReducerFileGroupResponse for shuffle: {}", shuffleId);
506+
Broadcast<GetReducerFileGroupResponse> broadcast =
507+
sparkContext.broadcast(
508+
response,
509+
scala.reflect.ClassManifestFactory.fromClass(
510+
GetReducerFileGroupResponse.class));
511+
512+
CompressionCodec codec = CompressionCodec.createCodec(sparkContext.conf());
513+
// Using `org.apache.commons.io.output.ByteArrayOutputStream` instead of the standard
514+
// one
515+
// This implementation doesn't reallocate the whole memory block but allocates
516+
// additional buffers. This way no buffers need to be garbage collected and
517+
// the contents don't have to be copied to the new buffer.
518+
org.apache.commons.io.output.ByteArrayOutputStream out =
519+
new org.apache.commons.io.output.ByteArrayOutputStream();
520+
try (ObjectOutputStream oos =
521+
new ObjectOutputStream(codec.compressedOutputStream(out))) {
522+
oos.writeObject(broadcast);
523+
}
524+
byte[] _serializeResult = out.toByteArray();
525+
getReducerFileGroupResponseBroadcasts.put(
526+
shuffleId, Tuple2.apply(broadcast, _serializeResult));
527+
getReducerFileGroupResponseBroadcastNum.incrementAndGet();
528+
return _serializeResult;
529+
} catch (Throwable e) {
530+
LOG.error(
531+
"Failed to serialize GetReducerFileGroupResponse for shuffle: " + shuffleId, e);
532+
return null;
533+
}
534+
});
535+
}
536+
537+
public static GetReducerFileGroupResponse deserializeGetReducerFileGroupResponse(
538+
Integer shuffleId, byte[] bytes) {
539+
SparkContext sparkContext = SparkContext$.MODULE$.getActive().getOrElse(null);
540+
if (sparkContext == null) {
541+
LOG.error("Can not get active SparkContext.");
542+
return null;
543+
}
544+
545+
return shuffleBroadcastLock.withLock(
546+
shuffleId,
547+
() -> {
548+
GetReducerFileGroupResponse response = null;
549+
LOG.info(
550+
"Deserializing GetReducerFileGroupResponse broadcast for shuffle: {}", shuffleId);
551+
552+
try {
553+
CompressionCodec codec = CompressionCodec.createCodec(sparkContext.conf());
554+
try (ObjectInputStream objIn =
555+
new ObjectInputStream(
556+
codec.compressedInputStream(new ByteArrayInputStream(bytes)))) {
557+
Broadcast<GetReducerFileGroupResponse> broadcast =
558+
(Broadcast<GetReducerFileGroupResponse>) objIn.readObject();
559+
response = broadcast.value();
560+
}
561+
} catch (Throwable e) {
562+
LOG.error(
563+
"Failed to deserialize GetReducerFileGroupResponse for shuffle: " + shuffleId, e);
564+
}
565+
return response;
566+
});
567+
}
568+
569+
public static void invalidateSerializedGetReducerFileGroupResponse(Integer shuffleId) {
570+
shuffleBroadcastLock.withLock(
571+
shuffleId,
572+
() -> {
573+
try {
574+
Tuple2<Broadcast<GetReducerFileGroupResponse>, byte[]>
575+
cachedSerializeGetReducerFileGroupResponse =
576+
getReducerFileGroupResponseBroadcasts.remove(shuffleId);
577+
if (cachedSerializeGetReducerFileGroupResponse != null) {
578+
cachedSerializeGetReducerFileGroupResponse._1().destroy();
579+
}
580+
} catch (Throwable e) {
581+
LOG.error(
582+
"Failed to invalidate serialized GetReducerFileGroupResponse for shuffle: "
583+
+ shuffleId,
584+
e);
585+
}
586+
return null;
587+
});
588+
}
465589
}

0 commit comments

Comments
 (0)