Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CELEBORN-1921] Broadcast large GetReducerFileGroupResponse to prevent Spark driver network exhausted #3158

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions LICENSE
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ Apache Spark
./common/src/test/java/org/apache/celeborn/common/network/protocol/EncryptedMessageWithHeaderSuiteJ.java
./common/src/test/java/org/apache/celeborn/common/network/ssl/ReloadingX509TrustManagerSuiteJ.java
./common/src/test/java/org/apache/celeborn/common/network/ssl/SslSampleConfigs.java
./common/src/main/scala/org/apache/celeborn/common/util/KeyLock.scala
./worker/src/main/java/org/apache/celeborn/service/deploy/worker/shuffledb/DB.java
./worker/src/main/java/org/apache/celeborn/service/deploy/worker/shuffledb/DBIterator.java
./worker/src/main/java/org/apache/celeborn/service/deploy/worker/shuffledb/LevelDB.java
Expand Down
1 change: 1 addition & 0 deletions client-spark/spark-2-shaded/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
<include>io.netty:*</include>
<include>org.apache.commons:commons-lang3</include>
<include>org.roaringbitmap:RoaringBitmap</include>
<include>commons-io:commons-io</include>
</includes>
</artifactSet>
<filters>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,15 @@ private void initializeLifecycleManager(String appId) {
lifecycleManager.registerShuffleTrackerCallback(
shuffleId -> mapOutputTracker.unregisterAllMapOutput(shuffleId));
}

if (celebornConf.getReducerFileGroupBroadcastEnabled()) {
lifecycleManager.registerBroadcastGetReducerFileGroupResponseCallback(
(shuffleId, getReducerFileGroupResponse) ->
SparkUtils.serializeGetReducerFileGroupResponse(
shuffleId, getReducerFileGroupResponse));
lifecycleManager.registerInvalidatedBroadcastCallback(
shuffleId -> SparkUtils.invalidateSerializedGetReducerFileGroupResponse(shuffleId));
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,18 @@

package org.apache.spark.shuffle.celeborn;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.LongAdder;
import java.util.stream.Collectors;

Expand All @@ -37,7 +41,12 @@
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.SparkContext$;
import org.apache.spark.SparkEnv;
import org.apache.spark.SparkEnv$;
import org.apache.spark.TaskContext;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.io.CompressionCodec;
import org.apache.spark.io.CompressionCodec$;
import org.apache.spark.scheduler.DAGScheduler;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.scheduler.MapStatus$;
Expand All @@ -54,7 +63,10 @@

import org.apache.celeborn.client.ShuffleClient;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.network.protocol.TransportMessage;
import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse;
import org.apache.celeborn.common.util.JavaUtils;
import org.apache.celeborn.common.util.KeyLock;
import org.apache.celeborn.common.util.Utils;
import org.apache.celeborn.reflect.DynFields;

Expand Down Expand Up @@ -346,4 +358,121 @@ public static void addSparkListener(SparkListener listener) {
sparkContext.addSparkListener(listener);
}
}

/**
 * A {@link KeyLock} whose key is a shuffle id, ensuring that only one thread at a time accesses
 * the broadcast belonging to a given shuffle id.
 */
private static final KeyLock<Integer> shuffleBroadcastLock = new KeyLock<>();

// Number of GetReducerFileGroupResponse broadcasts created so far; exposed for tests.
@VisibleForTesting
public static AtomicInteger getReducerFileGroupResponseBroadcastNum = new AtomicInteger();

// Cache of shuffleId -> (broadcast handle, serialized broadcast bytes). The handle is kept so
// the broadcast can be destroyed when the shuffle is invalidated; the bytes are the compressed,
// serialized form handed back to clients.
@VisibleForTesting
public static Map<Integer, Tuple2<Broadcast<TransportMessage>, byte[]>>
    getReducerFileGroupResponseBroadcasts = JavaUtils.newConcurrentHashMap();

/**
 * Serializes a {@link GetReducerFileGroupResponse} for the given shuffle by wrapping it in a
 * Spark broadcast and returning the compressed, serialized broadcast handle. The result is
 * cached per shuffle id under {@code shuffleBroadcastLock}, so repeated calls for the same
 * shuffle reuse the existing broadcast.
 *
 * @param shuffleId the shuffle whose response is being broadcast
 * @param response the response to broadcast to executors
 * @return the serialized broadcast bytes, or {@code null} when there is no active SparkContext
 *     or serialization fails (callers are expected to fall back to the non-broadcast path)
 */
public static byte[] serializeGetReducerFileGroupResponse(
    Integer shuffleId, GetReducerFileGroupResponse response) {
  SparkContext sparkContext = SparkContext$.MODULE$.getActive().getOrElse(null);
  if (sparkContext == null) {
    logger.error("Can not get active SparkContext.");
    return null;
  }

  return shuffleBroadcastLock.withLock(
      shuffleId,
      () -> {
        // Reuse an already-built broadcast for this shuffle if one exists.
        Tuple2<Broadcast<TransportMessage>, byte[]> cached =
            getReducerFileGroupResponseBroadcasts.get(shuffleId);
        if (cached != null) {
          return cached._2;
        }

        try {
          logger.info("Broadcasting GetReducerFileGroupResponse for shuffle: {}", shuffleId);
          TransportMessage transportMessage =
              (TransportMessage) Utils.toTransportMessage(response);
          Broadcast<TransportMessage> broadcast =
              sparkContext.broadcast(
                  transportMessage,
                  scala.reflect.ClassManifestFactory.fromClass(TransportMessage.class));

          CompressionCodec codec = CompressionCodec$.MODULE$.createCodec(sparkContext.conf());
          // Use `org.apache.commons.io.output.ByteArrayOutputStream` instead of the standard
          // one: it grows by chaining additional buffers rather than reallocating a single
          // block, so nothing needs to be garbage collected and existing contents are never
          // copied into a new buffer.
          org.apache.commons.io.output.ByteArrayOutputStream out =
              new org.apache.commons.io.output.ByteArrayOutputStream();
          try (ObjectOutputStream oos =
              new ObjectOutputStream(codec.compressedOutputStream(out))) {
            oos.writeObject(broadcast);
          }
          byte[] serializedResult = out.toByteArray();
          getReducerFileGroupResponseBroadcasts.put(
              shuffleId, Tuple2.apply(broadcast, serializedResult));
          getReducerFileGroupResponseBroadcastNum.incrementAndGet();
          return serializedResult;
        } catch (Throwable e) {
          logger.error(
              "Failed to serialize GetReducerFileGroupResponse for shuffle: " + shuffleId, e);
          return null;
        }
      });
}

/**
 * Deserializes broadcast bytes produced by {@code serializeGetReducerFileGroupResponse} back
 * into a {@link GetReducerFileGroupResponse}: decompresses and deserializes the {@link
 * Broadcast} handle, then reads the broadcast's {@link TransportMessage} value.
 *
 * @param shuffleId the shuffle the broadcast belongs to (also the lock key)
 * @param bytes the compressed, serialized broadcast handle received from the driver
 * @return the response, or {@code null} when SparkEnv is unavailable or deserialization fails
 */
public static GetReducerFileGroupResponse deserializeGetReducerFileGroupResponse(
    Integer shuffleId, byte[] bytes) {
  SparkEnv sparkEnv = SparkEnv$.MODULE$.get();
  if (sparkEnv == null) {
    logger.error("Can not get SparkEnv.");
    return null;
  }

  return shuffleBroadcastLock.withLock(
      shuffleId,
      () -> {
        logger.info(
            "Deserializing GetReducerFileGroupResponse broadcast for shuffle: {}", shuffleId);
        try {
          CompressionCodec codec = CompressionCodec$.MODULE$.createCodec(sparkEnv.conf());
          try (ObjectInputStream objIn =
              new ObjectInputStream(
                  codec.compressedInputStream(new ByteArrayInputStream(bytes)))) {
            // The stream contains exactly one object: the driver-created broadcast handle.
            @SuppressWarnings("unchecked")
            Broadcast<TransportMessage> broadcast =
                (Broadcast<TransportMessage>) objIn.readObject();
            return (GetReducerFileGroupResponse) Utils.fromTransportMessage(broadcast.value());
          }
        } catch (Throwable e) {
          logger.error(
              "Failed to deserialize GetReducerFileGroupResponse for shuffle: " + shuffleId, e);
          return null;
        }
      });
}

/**
 * Drops the cached broadcast for the given shuffle id, if one exists: removes it from the cache
 * and destroys the underlying Spark broadcast. Failures are logged and swallowed, since
 * invalidation is best-effort cleanup.
 *
 * @param shuffleId the shuffle whose cached broadcast should be released
 */
public static void invalidateSerializedGetReducerFileGroupResponse(Integer shuffleId) {
  shuffleBroadcastLock.withLock(
      shuffleId,
      () -> {
        try {
          Tuple2<Broadcast<TransportMessage>, byte[]> cached =
              getReducerFileGroupResponseBroadcasts.remove(shuffleId);
          if (cached != null) {
            // Release the broadcast's driver- and executor-side storage.
            cached._1().destroy();
          }
        } catch (Throwable e) {
          logger.error(
              "Failed to invalidate serialized GetReducerFileGroupResponse for shuffle: "
                  + shuffleId,
              e);
        }
        return null;
      });
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.shuffle.celeborn
import java.io.IOException
import java.util.concurrent.{ThreadPoolExecutor, TimeUnit}
import java.util.concurrent.atomic.AtomicReference
import java.util.function.BiFunction

import org.apache.spark.{Aggregator, InterruptibleIterator, TaskContext}
import org.apache.spark.internal.Logging
Expand All @@ -33,6 +34,7 @@ import org.apache.celeborn.client.read.CelebornInputStream
import org.apache.celeborn.client.read.MetricsCallback
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.exception.{CelebornIOException, PartitionUnRetryAbleException}
import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse
import org.apache.celeborn.common.util.{JavaUtils, ThreadUtils}

class CelebornShuffleReader[K, C](
Expand Down Expand Up @@ -254,4 +256,13 @@ class CelebornShuffleReader[K, C](

object CelebornShuffleReader {
  // Shared pool used for creating read streams; null until assigned externally.
  var streamCreatorPool: ThreadPoolExecutor = null

  // Register the deserializer for GetReducerFileGroupResponse broadcast.
  // Runs once when this companion object is first loaded on an executor, so the
  // ShuffleClient can turn driver-broadcast bytes back into a
  // GetReducerFileGroupResponse via SparkUtils.deserializeGetReducerFileGroupResponse.
  ShuffleClient.registerDeserializeReducerFileGroupResponseFunction(new BiFunction[
    Integer,
    Array[Byte],
    GetReducerFileGroupResponse] {
    override def apply(shuffleId: Integer, broadcast: Array[Byte]): GetReducerFileGroupResponse = {
      SparkUtils.deserializeGetReducerFileGroupResponse(shuffleId, broadcast)
    }
  })
}
1 change: 1 addition & 0 deletions client-spark/spark-3-shaded/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
<include>io.netty:*</include>
<include>org.apache.commons:commons-lang3</include>
<include>org.roaringbitmap:RoaringBitmap</include>
<include>commons-io:commons-io</include>
Copy link
Member Author

@turboFei turboFei Mar 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To fix the class not found:

rn.ApplicationMaster$$anon$2.run(ApplicationMaster.scala:732)
Caused by: java.lang.NoClassDefFoundError: org/apache/celeborn/shaded/org/apache/commons/io/output/ByteArrayOutputStream
	at org.apache.spark.shuffle.celeborn.SparkShuffleManager.<init>(SparkShuffleManager.java:113)
	... 20 more
Caused by: java.lang.ClassNotFoundException: org.apache.celeborn.shaded.org.apache.commons.io.output.ByteArrayOutputStream
	at java.net.URLClassLoader.findClass(URLClassLoader.java:387)
	at java.lang.ClassLoader.loadClass(ClassLoader.java:419)
	at sun.misc.Launcher$AppClassLoader.loadClass(Launcher.java:365)
	at java.lang.ClassLoader.loadClass(ClassLoader.java:352)
	... 21 more
jar -tf client-spark/spark-3-shaded/target/celeborn-client-spark-3-shaded_2.12-0.6.0-SNAPSHOT.jar|grep org/apache/celeborn/shaded/org/apache/commons/io/output/ByteArrayOutputStream
org/apache/celeborn/shaded/org/apache/commons/io/output/ByteArrayOutputStream.class

Copy link
Member Author

@turboFei turboFei Mar 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The org.apache.commons package is relocated.

<pattern>org.apache.commons</pattern>
<shadedPattern>${shading.prefix}.org.apache.commons</shadedPattern>

</includes>
</artifactSet>
<filters>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,15 @@ private void initializeLifecycleManager(String appId) {
SparkUtils::isCelebornSkewShuffleOrChildShuffle);
}
}

if (celebornConf.getReducerFileGroupBroadcastEnabled()) {
lifecycleManager.registerBroadcastGetReducerFileGroupResponseCallback(
(shuffleId, getReducerFileGroupResponse) ->
SparkUtils.serializeGetReducerFileGroupResponse(
shuffleId, getReducerFileGroupResponse));
lifecycleManager.registerInvalidatedBroadcastCallback(
shuffleId -> SparkUtils.invalidateSerializedGetReducerFileGroupResponse(shuffleId));
}
}
}
}
Expand Down
Loading
Loading