From 2570b50c50648f26b865d8f3b84adce6d4fa3c10 Mon Sep 17 00:00:00 2001 From: wangshengjie3 Date: Fri, 13 Dec 2024 15:07:48 +0800 Subject: [PATCH 01/44] [CELEBORN-1319] Optimize skew partition logic for Reduce Mode to avoid sorting shuffle files --- .../celeborn/CelebornShuffleReader.scala | 1 + .../apache/celeborn/client/ShuffleClient.java | 4 + .../celeborn/client/ShuffleClientImpl.java | 58 +++++++++++- .../client/read/CelebornInputStream.java | 66 ++++++++++++- .../client/read/DfsPartitionReader.java | 1 + .../client/read/LocalPartitionReader.java | 1 + .../client/read/WorkerPartitionReader.java | 1 + .../celeborn/client/CommitManager.scala | 6 +- .../celeborn/client/LifecycleManager.scala | 17 +++- .../client/commit/CommitHandler.scala | 6 +- .../commit/MapPartitionCommitHandler.scala | 2 + .../commit/ReducePartitionCommitHandler.scala | 17 +++- .../celeborn/client/DummyShuffleClient.java | 7 +- .../client/WithShuffleClientSuite.scala | 2 + .../celeborn/common/protocol/StorageInfo.java | 20 ++++ .../common/write/PushFailedBatch.java | 94 +++++++++++++++++++ .../celeborn/common/write/PushState.java | 13 +++ common/src/main/proto/TransportMessages.proto | 13 +++ .../apache/celeborn/common/CelebornConf.scala | 10 ++ .../protocol/message/ControlMessages.scala | 27 ++++-- .../celeborn/common/util/PbSerDeUtils.scala | 21 +++++ .../celeborn/common/util/UtilsSuite.scala | 3 +- .../service/deploy/worker/Controller.scala | 10 +- .../service/deploy/worker/FetchHandler.scala | 15 ++- .../deploy/cluster/ReadWriteTestBase.scala | 1 + 25 files changed, 386 insertions(+), 30 deletions(-) create mode 100644 common/src/main/java/org/apache/celeborn/common/write/PushFailedBatch.java diff --git a/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala index ad802f8cef9..1d9a00587cb 100644 --- a/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala +++ b/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala @@ -222,6 +222,7 @@ class CelebornShuffleReader[K, C]( else null, locations, streamHandlers, + fileGroups.pushFailedBatchSet, fileGroups.mapAttempts, metricsCallback) streams.put(partitionId, inputStream) diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java index 673c9382437..5d57d2971d5 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Map; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.LongAdder; @@ -39,6 +40,7 @@ import org.apache.celeborn.common.rpc.RpcEndpointRef; import org.apache.celeborn.common.util.CelebornHadoopUtils; import org.apache.celeborn.common.util.ExceptionMaker; +import org.apache.celeborn.common.write.PushFailedBatch; import org.apache.celeborn.common.write.PushState; /** @@ -241,6 +243,7 @@ public CelebornInputStream readPartition( null, null, null, + null, metricsCallback); } @@ -255,6 +258,7 @@ public abstract CelebornInputStream readPartition( ExceptionMaker exceptionMaker, ArrayList locations, ArrayList streamHandlers, + Set pushFailedBatchSet, int[] mapAttempts, MetricsCallback metricsCallback) throws IOException; diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java index 48412f81f12..f03afdc78e7 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java @@ -67,6 +67,7 @@ import org.apache.celeborn.common.unsafe.Platform; import org.apache.celeborn.common.util.*; import org.apache.celeborn.common.write.DataBatches; +import org.apache.celeborn.common.write.PushFailedBatch; import org.apache.celeborn.common.write.PushState; public class ShuffleClientImpl extends ShuffleClient { @@ -142,30 +143,37 @@ protected Compressor initialValue() { private final ReviveManager reviveManager; - public static class ReduceFileGroups { + private final boolean dataPushFailureTrackingEnabled; + + protected static class ReduceFileGroups { public Map> partitionGroups; + public Set pushFailedBatchSet; public int[] mapAttempts; public Set partitionIds; ReduceFileGroups( Map> partitionGroups, int[] mapAttempts, - Set partitionIds) { + Set partitionIds, + Set pushFailedBatches) { this.partitionGroups = partitionGroups; this.mapAttempts = mapAttempts; this.partitionIds = partitionIds; + this.pushFailedBatchSet = pushFailedBatches; } public ReduceFileGroups() { this.partitionGroups = null; this.mapAttempts = null; this.partitionIds = null; + this.pushFailedBatchSet = null; } public void update(ReduceFileGroups fileGroups) { partitionGroups = fileGroups.partitionGroups; mapAttempts = fileGroups.mapAttempts; partitionIds = fileGroups.partitionIds; + pushFailedBatchSet = fileGroups.pushFailedBatchSet; } } @@ -193,6 +201,7 @@ public ShuffleClientImpl(String appUniqueId, CelebornConf conf, UserIdentifier u pushDataTimeout = conf.pushDataTimeoutMs(); } authEnabled = conf.authEnabledOnClient(); + dataPushFailureTrackingEnabled = conf.clientPushFailureTrackingEnabled(); // init rpc env rpcEnv = @@ -1102,6 +1111,11 @@ public void onSuccess(ByteBuffer response) { attemptId, partitionId, nextBatchId); + if (dataPushFailureTrackingEnabled) { + pushState.addFailedBatch( + new PushFailedBatch( + mapId, attemptId, nextBatchId, partitionId, latest.getEpoch())); + } ReviveRequest reviveRequest = new ReviveRequest( shuffleId, @@ -1168,6 +1182,11 @@ public void onSuccess(ByteBuffer response) { @Override public void onFailure(Throwable e) { + if (dataPushFailureTrackingEnabled) { + pushState.addFailedBatch( + new PushFailedBatch( + mapId, attemptId, nextBatchId, partitionId, latest.getEpoch())); + } if (pushState.exception.get() != null) { return; } @@ -1390,6 +1409,7 @@ private void doPushMergedData( final String[] partitionUniqueIds = new String[numBatches]; final int[] offsets = new int[numBatches]; final int[] batchIds = new int[numBatches]; + final int[] epochs = new int[numBatches]; int currentSize = 0; CompositeByteBuf byteBuf = Unpooled.compositeBuffer(); for (int i = 0; i < numBatches; i++) { @@ -1398,6 +1418,7 @@ private void doPushMergedData( partitionUniqueIds[i] = batch.loc.getUniqueId(); offsets[i] = currentSize; batchIds[i] = batch.batchId; + epochs[i] = batch.loc.getEpoch(); currentSize += batch.body.length; byteBuf.addComponent(true, Unpooled.wrappedBuffer(batch.body)); } @@ -1534,6 +1555,13 @@ public void onSuccess(ByteBuffer response) { pushState.onSuccess(hostPort); callback.onSuccess(ByteBuffer.wrap(new byte[] {StatusCode.SOFT_SPLIT.getValue()})); } else { + if (dataPushFailureTrackingEnabled) { + for (int i = 0; i < numBatches; i++) { + pushState.addFailedBatch( + new PushFailedBatch( + mapId, attemptId, batchIds[i], partitionIds[i], epochs[i])); + } + } ReviveRequest[] requests = addAndGetReviveRequests( shuffleId, mapId, attemptId, batchesNeedResubmit, StatusCode.HARD_SPLIT); @@ -1589,6 +1617,12 @@ public void onSuccess(ByteBuffer response) { @Override public void onFailure(Throwable e) { + if (dataPushFailureTrackingEnabled) { + for (int i = 0; i < numBatches; i++) { + pushState.addFailedBatch( + new PushFailedBatch(mapId, attemptId, batchIds[i], partitionIds[i], epochs[i])); + } + } if (pushState.exception.get() != null) { return; } @@ -1709,7 +1743,13 @@ private void mapEndInternal( MapperEndResponse response = lifecycleManagerRef.askSync( - new MapperEnd(shuffleId, mapId, attemptId, numMappers, partitionId), + new MapperEnd( + shuffleId, + mapId, + attemptId, + numMappers, + partitionId, + pushState.getFailedBatches()), ClassTag$.MODULE$.apply(MapperEndResponse.class)); if (response.status() != StatusCode.SUCCESS) { throw new CelebornIOException("MapperEnd failed! StatusCode: " + response.status()); @@ -1770,7 +1810,10 @@ protected Tuple2 loadFileGroupInternal( response.fileGroup().size()); return Tuple2.apply( new ReduceFileGroups( - response.fileGroup(), response.attempts(), response.partitionIds()), + response.fileGroup(), + response.attempts(), + response.partitionIds(), + response.pushFailedBatches()), null); case SHUFFLE_NOT_REGISTERED: logger.warn( @@ -1781,7 +1824,10 @@ protected Tuple2 loadFileGroupInternal( // return empty result return Tuple2.apply( new ReduceFileGroups( - response.fileGroup(), response.attempts(), response.partitionIds()), + response.fileGroup(), + response.attempts(), + response.partitionIds(), + response.pushFailedBatches()), null); case STAGE_END_TIME_OUT: case SHUFFLE_DATA_LOST: @@ -1852,6 +1898,7 @@ public CelebornInputStream readPartition( ExceptionMaker exceptionMaker, ArrayList locations, ArrayList streamHandlers, + Set pushFailedBatchSet, int[] mapAttempts, MetricsCallback metricsCallback) throws IOException { @@ -1884,6 +1931,7 @@ public CelebornInputStream readPartition( locations, streamHandlers, mapAttempts, + pushFailedBatchSet, attemptNumber, taskId, startMapIndex, diff --git a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java index fc295a746c8..7610229834e 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java +++ b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java @@ -43,6 +43,7 @@ import org.apache.celeborn.common.unsafe.Platform; import org.apache.celeborn.common.util.ExceptionMaker; import org.apache.celeborn.common.util.Utils; +import org.apache.celeborn.common.write.PushFailedBatch; public abstract class CelebornInputStream extends InputStream { private static final Logger logger = LoggerFactory.getLogger(CelebornInputStream.class); @@ -54,6 +55,7 @@ public static CelebornInputStream create( ArrayList locations, ArrayList streamHandlers, int[] attempts, + Set failedBatchSet, int attemptNumber, long taskId, int startMapIndex, @@ -69,13 +71,20 @@ public static CelebornInputStream create( if (locations == null || locations.size() == 0) { return emptyInputStream; } else { + // if startMapIndex > endMapIndex, means partition is skew partition. + // locations will split to sub-partitions with startMapIndex size. + ArrayList filterLocations = locations; + if (conf.clientPushFailureTrackingEnabled() && startMapIndex > endMapIndex) { + filterLocations = getSkewPartitionLocations(locations, startMapIndex, endMapIndex); + } return new CelebornInputStreamImpl( conf, clientFactory, shuffleKey, - locations, + filterLocations, streamHandlers, attempts, + failedBatchSet, attemptNumber, taskId, startMapIndex, @@ -90,6 +99,41 @@ public static CelebornInputStream create( } } + public static ArrayList getSkewPartitionLocations( + List locations, int subPartitionSize, int subPartitionIndex) { + Set sortSet = + new TreeSet<>( + (o1, o2) -> { + if (o1.getStorageInfo().fileSize > o2.getStorageInfo().fileSize) { + return 1; + } else if (o1.getStorageInfo().fileSize < o2.getStorageInfo().fileSize) { + return -1; + } else { + return o1.hashCode() - o2.hashCode(); + } + }); + sortSet.addAll(locations); + PartitionLocation[] orderedPartitionLocations = sortSet.toArray(new PartitionLocation[0]); + + ArrayList result = new ArrayList<>(); + + int step = locations.size() / subPartitionSize; + + // if partition location is [1,2,3,4,5,6,7,8,9,10], and skew partition split to 3 task: + // task 0: 1, 6, 7 + // task 1: 2, 5, 8 + // task 2: 3, 4, 9, 10 + for (int i = 0; i < step + 1; i++) { + if (i % 2 == 0 && (i * 3 + subPartitionIndex) < locations.size()) { + result.add(orderedPartitionLocations[i * subPartitionSize + subPartitionIndex]); + } else if (((i + 1) * subPartitionSize - subPartitionIndex - 1) < locations.size()) { + result.add(orderedPartitionLocations[(i + 1) * subPartitionSize - subPartitionIndex - 1]); + } + } + + return result; + } + public static CelebornInputStream empty() { return emptyInputStream; } @@ -137,6 +181,8 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream { private Map> batchesRead = new HashMap<>(); + private final Set failedBatches; + private byte[] compressedBuf; private byte[] rawDataBuf; private Decompressor decompressor; @@ -173,6 +219,8 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream { private ExceptionMaker exceptionMaker; private boolean closed = false; + private final boolean pushShuffleFailureTrackingEnabled; + CelebornInputStreamImpl( CelebornConf conf, TransportClientFactory clientFactory, @@ -180,6 +228,7 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream { ArrayList locations, ArrayList streamHandlers, int[] attempts, + Set failedBatchSet, int attemptNumber, long taskId, int startMapIndex, @@ -210,6 +259,8 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream { this.shuffleCompressionEnabled = !conf.shuffleCompressionCodec().equals(CompressionCodec.NONE); this.fetchExcludedWorkerExpireTimeout = conf.clientFetchExcludedWorkerExpireTimeout(); + this.failedBatches = failedBatchSet; + this.pushShuffleFailureTrackingEnabled = conf.clientPushFailureTrackingEnabled(); this.fetchExcludedWorkers = fetchExcludedWorkers; if (conf.clientPushReplicateEnabled()) { @@ -612,6 +663,19 @@ private boolean fillBuffer() throws IOException { // de-duplicate if (attemptId == attempts[mapId]) { + if (pushShuffleFailureTrackingEnabled) { + PushFailedBatch failedBatch = + new PushFailedBatch( + mapId, + attemptId, + batchId, + currentReader.getLocation().getId(), + currentReader.getLocation().getEpoch()); + if (this.failedBatches.contains(failedBatch)) { + logger.warn("Skip duplicated batch: {}.", failedBatch); + continue; + } + } if (!batchesRead.containsKey(mapId)) { Set batchSet = new HashSet<>(); batchesRead.put(mapId, batchSet); diff --git a/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java b/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java index b69cf580fa7..18a2ca9418e 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java +++ b/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java @@ -107,6 +107,7 @@ public DfsPartitionReader( .setFileName(location.getFileName()) .setStartIndex(startMapIndex) .setEndIndex(endMapIndex) + .setShuffleDataNeedSort(conf.clientPushFailureTrackingEnabled()) .build() .toByteArray()); ByteBuffer response = client.sendRpcSync(openStream.toByteBuffer(), fetchTimeoutMs); diff --git a/client/src/main/java/org/apache/celeborn/client/read/LocalPartitionReader.java b/client/src/main/java/org/apache/celeborn/client/read/LocalPartitionReader.java index a769687c8df..5a760a88bf6 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/LocalPartitionReader.java +++ b/client/src/main/java/org/apache/celeborn/client/read/LocalPartitionReader.java @@ -104,6 +104,7 @@ public LocalPartitionReader( .setStartIndex(startMapIndex) .setEndIndex(endMapIndex) .setReadLocalShuffle(true) + .setShuffleDataNeedSort(conf.clientPushFailureTrackingEnabled()) .build() .toByteArray()); ByteBuffer response = client.sendRpcSync(openStreamMsg.toByteBuffer(), fetchTimeoutMs); diff --git a/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java b/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java index 3158aa12f72..5ab01e90afd 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java +++ b/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java @@ -126,6 +126,7 @@ public void onFailure(int chunkIndex, Throwable e) { .setFileName(location.getFileName()) .setStartIndex(startMapIndex) .setEndIndex(endMapIndex) + .setShuffleDataNeedSort(conf.clientPushFailureTrackingEnabled()) .build() .toByteArray()); ByteBuffer response = client.sendRpcSync(openStreamMsg.toByteBuffer(), fetchTimeoutMs); diff --git a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala index 201be286978..e7516f9be86 100644 --- a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala +++ b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala @@ -24,6 +24,7 @@ import java.util.concurrent.atomic.{AtomicInteger, LongAdder} import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer +import com.google.common.collect.Sets import org.roaringbitmap.RoaringBitmap import org.apache.celeborn.client.CommitManager.CommittedPartitionInfo @@ -40,6 +41,7 @@ import org.apache.celeborn.common.rpc.RpcCallContext import org.apache.celeborn.common.util.FunctionConverter._ import org.apache.celeborn.common.util.JavaUtils import org.apache.celeborn.common.util.ThreadUtils +import org.apache.celeborn.common.write.PushFailedBatch case class ShuffleCommittedInfo( // partition id -> unique partition ids @@ -215,13 +217,15 @@ class CommitManager(appUniqueId: String, val conf: CelebornConf, lifecycleManage mapId: Int, attemptId: Int, numMappers: Int, - partitionId: Int = -1): (Boolean, Boolean) = { + partitionId: Int = -1, + pushFailedBatches: util.Set[PushFailedBatch] = Sets.newHashSet()): (Boolean, Boolean) = { getCommitHandler(shuffleId).finishMapperAttempt( shuffleId, mapId, attemptId, numMappers, partitionId, + pushFailedBatches, r => lifecycleManager.workerStatusTracker.recordWorkerFailure(r)) } diff --git a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala index 94b5a0676bf..0f95b04cfed 100644 --- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala +++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala @@ -58,11 +58,14 @@ import org.apache.celeborn.common.util.{JavaUtils, PbSerDeUtils, ThreadUtils, Ut import org.apache.celeborn.common.util.FunctionConverter._ import org.apache.celeborn.common.util.ThreadUtils.awaitResult import org.apache.celeborn.common.util.Utils.UNKNOWN_APP_SHUFFLE_ID +import org.apache.celeborn.common.write.PushFailedBatch object LifecycleManager { // shuffle id -> partition id -> partition locations type ShuffleFileGroups = ConcurrentHashMap[Int, ConcurrentHashMap[Integer, util.Set[PartitionLocation]]] + type ShufflePushFailedBatches = + ConcurrentHashMap[Int, ConcurrentHashMap[Integer, util.Set[PushFailedBatch]]] type ShuffleAllocatedWorkers = ConcurrentHashMap[Int, ConcurrentHashMap[String, ShufflePartitionLocationInfo]] type ShuffleFailedWorkers = ConcurrentHashMap[WorkerInfo, (StatusCode, Long)] @@ -404,13 +407,13 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends oldPartition, isSegmentGranularityVisible = commitManager.isSegmentGranularityVisible(shuffleId)) - case MapperEnd(shuffleId, mapId, attemptId, numMappers, partitionId) => + case MapperEnd(shuffleId, mapId, attemptId, numMappers, partitionId, pushFailedBatch) => logTrace(s"Received MapperEnd TaskEnd request, " + s"${Utils.makeMapKey(shuffleId, mapId, attemptId)}") val partitionType = getPartitionType(shuffleId) partitionType match { case PartitionType.REDUCE => - handleMapperEnd(context, shuffleId, mapId, attemptId, numMappers) + handleMapperEnd(context, shuffleId, mapId, attemptId, numMappers, pushFailedBatch) case PartitionType.MAP => handleMapPartitionEnd( context, @@ -802,10 +805,16 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends shuffleId: Int, mapId: Int, attemptId: Int, - numMappers: Int): Unit = { + numMappers: Int, + pushFailedBatches: util.Set[PushFailedBatch]): Unit = { val (mapperAttemptFinishedSuccess, allMapperFinished) = - commitManager.finishMapperAttempt(shuffleId, mapId, attemptId, numMappers) + commitManager.finishMapperAttempt( + shuffleId, + mapId, + attemptId, + numMappers, + pushFailedBatches = pushFailedBatches) if (mapperAttemptFinishedSuccess && allMapperFinished) { // last mapper finished. call mapper end logInfo(s"Last MapperEnd, call StageEnd with shuffleKey:" + diff --git a/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala b/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala index 65ba4dbab04..d76ecc35339 100644 --- a/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala +++ b/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala @@ -30,7 +30,7 @@ import scala.concurrent.duration.Duration import org.apache.celeborn.client.{ShuffleCommittedInfo, WorkerStatusTracker} import org.apache.celeborn.client.CommitManager.CommittedPartitionInfo -import org.apache.celeborn.client.LifecycleManager.{ShuffleFailedWorkers, ShuffleFileGroups} +import org.apache.celeborn.client.LifecycleManager.{ShuffleFailedWorkers, ShuffleFileGroups, ShufflePushFailedBatches} import org.apache.celeborn.common.CelebornConf import org.apache.celeborn.common.internal.Logging import org.apache.celeborn.common.meta.{ShufflePartitionLocationInfo, WorkerInfo} @@ -42,6 +42,7 @@ import org.apache.celeborn.common.util.{CollectionUtils, JavaUtils, Utils} // Can Remove this if celeborn don't support scala211 in future import org.apache.celeborn.common.util.FunctionConverter._ import org.apache.celeborn.common.util.ThreadUtils.awaitResult +import org.apache.celeborn.common.write.PushFailedBatch case class CommitFilesParam( worker: WorkerInfo, @@ -74,6 +75,7 @@ abstract class CommitHandler( private val totalWritten = new LongAdder private val fileCount = new LongAdder protected val reducerFileGroupsMap = new ShuffleFileGroups + protected val shufflePushFailedBatches = new ShufflePushFailedBatches val ec = ExecutionContext.fromExecutor(sharedRpcPool) @@ -178,6 +180,7 @@ abstract class CommitHandler( def removeExpiredShuffle(shuffleId: Int): Unit = { reducerFileGroupsMap.remove(shuffleId) + shufflePushFailedBatches.remove(shuffleId) } /** @@ -197,6 +200,7 @@ abstract class CommitHandler( attemptId: Int, numMappers: Int, partitionId: Int, + pushFailedBatches: util.Set[PushFailedBatch], recordWorkerFailure: ShuffleFailedWorkers => Unit): (Boolean, Boolean) def registerShuffle( diff --git a/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala b/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala index 9352dc99a72..e4bd94dbc81 100644 --- a/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala +++ b/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala @@ -39,6 +39,7 @@ import org.apache.celeborn.common.rpc.RpcCallContext import org.apache.celeborn.common.util.FunctionConverter._ import org.apache.celeborn.common.util.JavaUtils import org.apache.celeborn.common.util.Utils +import org.apache.celeborn.common.write.PushFailedBatch /** * This commit handler is for MapPartition ShuffleType, which means that a Map Partition contains all data produced @@ -184,6 +185,7 @@ class MapPartitionCommitHandler( attemptId: Int, numMappers: Int, partitionId: Int, + pushFailedBatches: util.Set[PushFailedBatch], recordWorkerFailure: ShuffleFailedWorkers => Unit): (Boolean, Boolean) = { val inProcessingPartitionIds = inProcessMapPartitionEndIds.computeIfAbsent( diff --git a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala index 951fc89e601..e0289508d52 100644 --- a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala +++ b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala @@ -38,6 +38,7 @@ import org.apache.celeborn.common.protocol.message.StatusCode import org.apache.celeborn.common.rpc.RpcCallContext import org.apache.celeborn.common.rpc.netty.{LocalNettyRpcCallContext, RemoteNettyRpcCallContext} import org.apache.celeborn.common.util.JavaUtils +import org.apache.celeborn.common.write.PushFailedBatch /** * This commit handler is for ReducePartition ShuffleType, which means that a Reduce Partition contains all data @@ -240,6 +241,7 @@ class ReducePartitionCommitHandler( attemptId: Int, numMappers: Int, partitionId: Int, + pushFailedBatches: util.Set[PushFailedBatch], recordWorkerFailure: ShuffleFailedWorkers => Unit): (Boolean, Boolean) = { shuffleMapperAttempts.synchronized { if (getMapperAttempts(shuffleId) == null) { @@ -250,6 +252,14 @@ class ReducePartitionCommitHandler( val attempts = shuffleMapperAttempts.get(shuffleId) if (attempts(mapId) < 0) { attempts(mapId) = attemptId + if (null != pushFailedBatches && !pushFailedBatches.isEmpty) { + val pushFailedBatchesMap = shufflePushFailedBatches.computeIfAbsent( + shuffleId, + _ => { + JavaUtils.newConcurrentHashMap[Integer, util.Set[PushFailedBatch]]() + }) + pushFailedBatchesMap.put(mapId, pushFailedBatches) + } // Mapper with this attemptId finished, also check all other mapper finished or not. (true, ClientUtils.areAllMapperAttemptsFinished(attempts)) } else { @@ -301,7 +311,12 @@ class ReducePartitionCommitHandler( val returnedMsg = GetReducerFileGroupResponse( StatusCode.SUCCESS, reducerFileGroupsMap.getOrDefault(shuffleId, JavaUtils.newConcurrentHashMap()), - getMapperAttempts(shuffleId)) + getMapperAttempts(shuffleId), + pushFailedBatches = + shufflePushFailedBatches.getOrDefault( + shuffleId, + JavaUtils.newConcurrentHashMap()).values().asScala.flatMap(x => + x.asScala.toSet[PushFailedBatch]).toSet.asJava) context.asInstanceOf[RemoteNettyRpcCallContext].nettyEnv.serialize(returnedMsg) } }) diff --git a/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java b/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java index 77a9c784c4a..c2fef84470f 100644 --- a/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java +++ b/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java @@ -24,10 +24,7 @@ import java.io.FileOutputStream; import java.io.IOException; import java.io.OutputStream; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import java.util.*; import java.util.concurrent.ConcurrentHashMap; import org.slf4j.Logger; @@ -43,6 +40,7 @@ import org.apache.celeborn.common.rpc.RpcEndpointRef; import org.apache.celeborn.common.util.ExceptionMaker; import org.apache.celeborn.common.util.JavaUtils; +import org.apache.celeborn.common.write.PushFailedBatch; import org.apache.celeborn.common.write.PushState; public class DummyShuffleClient extends ShuffleClient { @@ -136,6 +134,7 @@ public CelebornInputStream readPartition( ExceptionMaker exceptionMaker, ArrayList locations, ArrayList streamHandlers, + Set pushFailedBatchSet, int[] mapAttempts, MetricsCallback metricsCallback) throws IOException { diff --git a/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala b/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala index 8cd8bedf76d..ad5b722b92a 100644 --- a/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala +++ b/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala @@ -164,6 +164,7 @@ trait WithShuffleClientSuite extends CelebornFunSuite { null, null, null, + null, metricsCallback) Assert.assertEquals(stream.read(), -1) @@ -180,6 +181,7 @@ trait WithShuffleClientSuite extends CelebornFunSuite { null, null, null, + null, metricsCallback) Assert.assertEquals(stream.read(), -1) } diff --git a/common/src/main/java/org/apache/celeborn/common/protocol/StorageInfo.java b/common/src/main/java/org/apache/celeborn/common/protocol/StorageInfo.java index 28cb652565b..a96d332f6be 100644 --- a/common/src/main/java/org/apache/celeborn/common/protocol/StorageInfo.java +++ b/common/src/main/java/org/apache/celeborn/common/protocol/StorageInfo.java @@ -68,6 +68,10 @@ public int getValue() { public int availableStorageTypes = 0; + public long fileSize; + + public int numChunks; + public StorageInfo() {} public StorageInfo(Type type, boolean isFinal, String filePath) { @@ -119,6 +123,22 @@ public String getFilePath() { return filePath; } + public void setNumChunks(int numChunks) { + this.numChunks = numChunks; + } + + public int getNumChunks() { + return this.numChunks; + } + + public void setFileSize(long fileSize) { + this.fileSize = fileSize; + } + + public long getFileSize() { + return fileSize; + } + @Override public String toString() { return "StorageInfo{" diff --git a/common/src/main/java/org/apache/celeborn/common/write/PushFailedBatch.java b/common/src/main/java/org/apache/celeborn/common/write/PushFailedBatch.java new file mode 100644 index 00000000000..9388c3c581d --- /dev/null +++ b/common/src/main/java/org/apache/celeborn/common/write/PushFailedBatch.java @@ -0,0 +1,94 @@ +package org.apache.celeborn.common.write; + +import java.io.Serializable; + +import com.google.common.base.Objects; +import org.apache.commons.lang3.builder.ToStringBuilder; +import org.apache.commons.lang3.builder.ToStringStyle; + +public class PushFailedBatch implements Serializable { + + private int mapId; + private int attemptId; + private int batchId; + private int epoch; + private int reduceId; + + public int getMapId() { + return mapId; + } + + public void setMapId(int mapId) { + this.mapId = mapId; + } + + public int getAttemptId() { + return attemptId; + } + + public void setAttemptId(int attemptId) { + this.attemptId = attemptId; + } + + public int getBatchId() { + return batchId; + } + + public void setBatchId(int batchId) { + this.batchId = batchId; + } + + public int getReduceId() { + return reduceId; + } + + public void setReduceId(int reduceId) { + this.reduceId = reduceId; + } + + public int getEpoch() { + return epoch; + } + + public void setEpoch(int epoch) { + this.epoch = epoch; + } + + public PushFailedBatch(int mapId, int attemptId, int batchId, int reduceId, int epoch) { + this.mapId = mapId; + this.attemptId = attemptId; + this.batchId = batchId; + this.reduceId = reduceId; + this.epoch = epoch; + } + + @Override + public boolean equals(Object other) { + if (!(other instanceof PushFailedBatch)) { + return false; + } + PushFailedBatch o = (PushFailedBatch) other; + return super.equals(o) + && mapId == o.mapId + && attemptId == o.attemptId + && batchId == o.batchId + && reduceId == o.reduceId + && epoch == o.epoch; + } + + @Override + public int hashCode() { + return Objects.hashCode(super.hashCode(), mapId, attemptId, batchId, reduceId, epoch); + } + + @Override + public String toString() { + return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE) + .append("mapId", mapId) + .append("attemptId", attemptId) + .append("batchId", batchId) + .append("reduceId", reduceId) + .append("epoch", epoch) + .toString(); + } +} diff --git a/common/src/main/java/org/apache/celeborn/common/write/PushState.java b/common/src/main/java/org/apache/celeborn/common/write/PushState.java index 3979cafd632..7c390e3bb48 100644 --- a/common/src/main/java/org/apache/celeborn/common/write/PushState.java +++ b/common/src/main/java/org/apache/celeborn/common/write/PushState.java @@ -18,6 +18,8 @@ package org.apache.celeborn.common.write; import java.io.IOException; +import java.util.HashSet; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicReference; @@ -33,9 +35,12 @@ public class PushState { public AtomicReference exception = new AtomicReference<>(); private final InFlightRequestTracker inFlightRequestTracker; + private Set failedBatchSet; + public PushState(CelebornConf conf) { pushBufferMaxSize = conf.clientPushBufferMaxSize(); inFlightRequestTracker = new InFlightRequestTracker(conf, this); + failedBatchSet = new HashSet<>(); } public void cleanup() { @@ -88,4 +93,12 @@ public boolean limitZeroInFlight() throws IOException { public int remainingAllowPushes(String hostAndPushPort) { return inFlightRequestTracker.remainingAllowPushes(hostAndPushPort); } + + public void addFailedBatch(PushFailedBatch failedBatch) { + this.failedBatchSet.add(failedBatch); + } + + public Set getFailedBatches() { + return this.failedBatchSet; + } } diff --git a/common/src/main/proto/TransportMessages.proto b/common/src/main/proto/TransportMessages.proto index 8aa59bcb29b..56885ccbfbe 100644 --- a/common/src/main/proto/TransportMessages.proto +++ b/common/src/main/proto/TransportMessages.proto @@ -354,6 +354,15 @@ message PbMapperEnd { int32 attemptId = 3; int32 numMappers = 4; int32 partitionId = 5; + repeated PbPushFailedBatch pushFailedBatches = 6; +} + +message PbPushFailedBatch { + int32 mapId = 1; + int32 attemptId = 2; + int32 batchId = 3; + int32 epoch = 4; + int32 reduceId =5; } message PbMapperEndResponse { @@ -375,6 +384,8 @@ message PbGetReducerFileGroupResponse { // only map partition mode has succeed partitionIds repeated int32 partitionIds = 4; + + repeated PbPushFailedBatch pushFailedBatches = 5; } message PbGetShuffleId { @@ -691,6 +702,8 @@ message PbOpenStream { int32 endIndex = 4; int32 initialCredit = 5; bool readLocalShuffle = 6; + bool requireSubpartitionId = 7; + bool shuffleDataNeedSort = 8; } message PbStreamHandler { diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala index 791c6fc2fc4..0eb0fe9e657 100644 --- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala +++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala @@ -1035,6 +1035,7 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se def clientPushSendBufferPoolExpireTimeout: Long = get(CLIENT_PUSH_SENDBUFFERPOOL_EXPIRETIMEOUT) def clientPushSendBufferPoolExpireCheckInterval: Long = get(CLIENT_PUSH_SENDBUFFERPOOL_CHECKEXPIREINTERVAL) + def clientPushFailureTrackingEnabled: Boolean = get(CLIENT_DATA_PUSH_FAILURE_TRACKING_ENABLED) // ////////////////////////////////////////////////////// // Client Shuffle // @@ -5875,6 +5876,15 @@ object CelebornConf extends Logging { .intConf .createWithDefault(10000) + val CLIENT_DATA_PUSH_FAILURE_TRACKING_ENABLED: ConfigEntry[Boolean] = + buildConf("celeborn.client.dataPushFailure.tracking.enabled") + .categories("client") + .version("0.5.0") + .doc("When client push data to worker failed, client will record the failed batch info. " + + "Feature used to optimize skew join by avoid data sorting") + .booleanConf + .createWithDefault(false) + // SSL Configs val SSL_ENABLED: ConfigEntry[Boolean] = diff --git a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala index 0e465196d28..1c9e7831546 100644 --- a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala +++ b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala @@ -19,6 +19,7 @@ package org.apache.celeborn.common.protocol.message import java.util import java.util.{Collections, UUID} +import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConverters._ @@ -31,7 +32,8 @@ import org.apache.celeborn.common.network.protocol.TransportMessage import org.apache.celeborn.common.protocol._ import org.apache.celeborn.common.protocol.MessageType._ import org.apache.celeborn.common.quota.ResourceConsumption -import org.apache.celeborn.common.util.{PbSerDeUtils, Utils} +import org.apache.celeborn.common.util.{JavaUtils, PbSerDeUtils, Utils} +import org.apache.celeborn.common.write.PushFailedBatch sealed trait Message extends Serializable @@ -271,7 +273,8 @@ object ControlMessages extends Logging { mapId: Int, attemptId: Int, numMappers: Int, - partitionId: Int) + partitionId: Int, + failedBatchSet: util.Set[PushFailedBatch]) extends MasterMessage case class MapperEndResponse(status: StatusCode) extends MasterMessage @@ -285,7 +288,8 @@ object ControlMessages extends Logging { status: StatusCode, fileGroup: util.Map[Integer, util.Set[PartitionLocation]], attempts: Array[Int], - partitionIds: util.Set[Integer] = Collections.emptySet[Integer]()) + partitionIds: util.Set[Integer] = Collections.emptySet[Integer](), + pushFailedBatches: util.Set[PushFailedBatch] = new util.HashSet[PushFailedBatch]) extends MasterMessage object WorkerExclude { @@ -721,13 +725,15 @@ object ControlMessages extends Logging { case pb: PbChangeLocationResponse => new TransportMessage(MessageType.CHANGE_LOCATION_RESPONSE, pb.toByteArray) - case MapperEnd(shuffleId, mapId, attemptId, numMappers, partitionId) => + case MapperEnd(shuffleId, mapId, attemptId, numMappers, partitionId, pushFailedBatch) => val payload = PbMapperEnd.newBuilder() .setShuffleId(shuffleId) .setMapId(mapId) .setAttemptId(attemptId) .setNumMappers(numMappers) .setPartitionId(partitionId) + .addAllPushFailedBatches(pushFailedBatch.asScala.map( + PbSerDeUtils.toPbPushFailedBatch).asJava) .build().toByteArray new TransportMessage(MessageType.MAPPER_END, payload) @@ -744,7 +750,7 @@ object ControlMessages extends Logging { .build().toByteArray new TransportMessage(MessageType.GET_REDUCER_FILE_GROUP, payload) - case GetReducerFileGroupResponse(status, fileGroup, attempts, partitionIds) => + case GetReducerFileGroupResponse(status, fileGroup, attempts, partitionIds, failedBatches) => val builder = PbGetReducerFileGroupResponse .newBuilder() .setStatus(status.getValue) @@ -757,6 +763,8 @@ object ControlMessages extends Logging { }.asJava) builder.addAllAttempts(attempts.map(Integer.valueOf).toIterable.asJava) builder.addAllPartitionIds(partitionIds) + builder.addAllPushFailedBatches( + failedBatches.asScala.map(PbSerDeUtils.toPbPushFailedBatch).asJava) val payload = builder.build().toByteArray new TransportMessage(MessageType.GET_REDUCER_FILE_GROUP_RESPONSE, payload) @@ -1141,7 +1149,9 @@ object ControlMessages extends Logging { pbMapperEnd.getMapId, pbMapperEnd.getAttemptId, pbMapperEnd.getNumMappers, - pbMapperEnd.getPartitionId) + pbMapperEnd.getPartitionId, + pbMapperEnd.getPushFailedBatchesList.asScala.toSet.map( + PbSerDeUtils.fromPbPushFailedBatch).asJava) case MAPPER_END_RESPONSE_VALUE => val pbMapperEndResponse = PbMapperEndResponse.parseFrom(message.getPayload) @@ -1177,11 +1187,14 @@ object ControlMessages extends Logging { val attempts = pbGetReducerFileGroupResponse.getAttemptsList.asScala.map(_.toInt).toArray val partitionIds = new util.HashSet(pbGetReducerFileGroupResponse.getPartitionIdsList) + val pushFailedBatches = pbGetReducerFileGroupResponse + .getPushFailedBatchesList.asScala.map(PbSerDeUtils.fromPbPushFailedBatch).toSet.asJava GetReducerFileGroupResponse( Utils.toStatusCode(pbGetReducerFileGroupResponse.getStatus), fileGroup, attempts, - partitionIds) + partitionIds, + pushFailedBatches) case GET_SHUFFLE_ID_VALUE => message.getParsedPayload() diff --git a/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala b/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala index 8a038242e75..d3229c1c4fa 100644 --- a/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala +++ b/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala @@ -32,6 +32,7 @@ import org.apache.celeborn.common.protocol.PartitionLocation.Mode import org.apache.celeborn.common.protocol.message.ControlMessages.WorkerResource import org.apache.celeborn.common.quota.ResourceConsumption import org.apache.celeborn.common.util.{CollectionUtils => localCollectionUtils} +import org.apache.celeborn.common.write.PushFailedBatch object PbSerDeUtils { @@ -670,4 +671,24 @@ object PbSerDeUtils { }.asJava } + def toPbPushFailedBatch(pushFailedBatch: PushFailedBatch): PbPushFailedBatch = { + val builder = PbPushFailedBatch.newBuilder() + .setMapId(pushFailedBatch.getMapId) + .setAttemptId(pushFailedBatch.getAttemptId) + .setBatchId(pushFailedBatch.getBatchId) + .setReduceId(pushFailedBatch.getReduceId) + .setEpoch(pushFailedBatch.getEpoch) + + builder.build() + } + + def fromPbPushFailedBatch(pbPushFailedBatch: PbPushFailedBatch): PushFailedBatch = { + new PushFailedBatch( + pbPushFailedBatch.getMapId, + pbPushFailedBatch.getAttemptId, + pbPushFailedBatch.getBatchId, + pbPushFailedBatch.getReduceId, + pbPushFailedBatch.getEpoch) + } + } diff --git a/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala b/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala index afc74707f62..7cc2872e400 100644 --- a/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala +++ b/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala @@ -27,6 +27,7 @@ import org.apache.celeborn.common.identity.DefaultIdentityProvider import org.apache.celeborn.common.protocol.{PartitionLocation, TransportModuleConstants} import org.apache.celeborn.common.protocol.message.ControlMessages.{GetReducerFileGroupResponse, MapperEnd} import org.apache.celeborn.common.protocol.message.StatusCode +import org.apache.celeborn.common.write.PushFailedBatch class UtilsSuite extends CelebornFunSuite { @@ -144,7 +145,7 @@ class UtilsSuite extends CelebornFunSuite { } test("MapperEnd class convert with pb") { - val mapperEnd = MapperEnd(1, 1, 1, 2, 1) + val mapperEnd = MapperEnd(1, 1, 1, 2, 1, new util.HashSet[PushFailedBatch]()) val mapperEndTrans = Utils.fromTransportMessage(Utils.toTransportMessage(mapperEnd)).asInstanceOf[MapperEnd] assert(mapperEnd == mapperEndTrans) diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala index 65285f64729..11ef735b6f3 100644 --- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala +++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala @@ -337,7 +337,15 @@ private[deploy] class Controller( // Only HDFS can be null, means that this partition location is deleted. logDebug(s"Location $uniqueId is deleted.") } else { - committedStorageInfos.put(uniqueId, fileWriter.getStorageInfo) + val storageInfo = fileWriter.getStorageInfo + val fileMeta = fileWriter.getDiskFileInfo.getFileMeta + fileMeta match { + case meta: ReduceFileMeta => + storageInfo.setNumChunks(meta.getNumChunks) + storageInfo.setFileSize(bytes) + case _ => + } + committedStorageInfos.put(uniqueId, storageInfo) if (fileWriter.getMapIdBitMap != null) { committedMapIdBitMap.put(uniqueId, fileWriter.getMapIdBitMap) } diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala index ff32b940173..c1f1db5b0c0 100644 --- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala +++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala @@ -136,6 +136,7 @@ class FetchHandler( rpcRequest.requestId, isLegacy = false, openStream.getReadLocalShuffle, + openStream.getShuffleDataNeedSort, callback) case openStreamList: PbOpenStreamList => val shuffleKey = openStreamList.getShuffleKey() @@ -207,6 +208,7 @@ class FetchHandler( isLegacy = true, // legacy [[OpenStream]] doesn't support read local shuffle readLocalShuffle = false, + shuffleNeedSort = true, callback) case Message.Type.OPEN_STREAM_WITH_CREDIT => val openStreamWithCredit = message.asInstanceOf[OpenStreamWithCredit] @@ -220,6 +222,7 @@ class FetchHandler( rpcRequestId = rpcRequest.requestId, isLegacy = true, readLocalShuffle = false, + shuffleNeedSort = true, callback) case _ => logError(s"Received an unknown message type id: ${message.`type`.id}") @@ -237,7 +240,8 @@ class FetchHandler( fileName: String, startIndex: Int, endIndex: Int, - readLocalShuffle: Boolean = false): PbStreamHandlerOpt = { + readLocalShuffle: Boolean = false, + shuffleNeedSort: Boolean = true): PbStreamHandlerOpt = { try { logDebug(s"Received open stream request $shuffleKey $fileName $startIndex " + s"$endIndex get file name $fileName from client channel " + @@ -249,8 +253,9 @@ class FetchHandler( // 1. when the current request is a non-range openStream, but the original unsorted file // has been deleted by another range's openStream request. // 2. when the current request is a range openStream request. - if ((endIndex != Int.MaxValue) || (endIndex == Int.MaxValue - && !fileInfo.addStream(streamId))) { + if (shuffleNeedSort && ((endIndex != Int.MaxValue) || (endIndex == Int.MaxValue && !fileInfo + .addStream( + streamId)))) { fileInfo = partitionsSorter.getSortedFileInfo( shuffleKey, fileName, @@ -341,6 +346,7 @@ class FetchHandler( rpcRequestId: Long, isLegacy: Boolean, readLocalShuffle: Boolean = false, + shuffleNeedSort: Boolean = true, callback: RpcResponseCallback): Unit = { checkAuth(client, Utils.splitShuffleKey(shuffleKey)._1) workerSource.recordAppActiveConnection(client, shuffleKey) @@ -356,7 +362,8 @@ class FetchHandler( fileName, startIndex, endIndex, - readLocalShuffle) + readLocalShuffle, + shuffleNeedSort) if (pbStreamHandlerOpt.getStatus != StatusCode.SUCCESS.getValue) { throw new CelebornIOException(pbStreamHandlerOpt.getErrorMsg) diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala index 0ff646b8f8b..eee7b46fb0f 100644 --- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala @@ -116,6 +116,7 @@ trait ReadWriteTestBase extends AnyFunSuite null, null, null, + null, metricsCallback) val outputStream = new ByteArrayOutputStream() From 43416c33d5c2ea96013d409f8ac61e10d6becfed Mon Sep 17 00:00:00 2001 From: wangshengjie Date: Mon, 11 Mar 2024 15:03:35 +0800 Subject: [PATCH 02/44] fix unit test --- .../commit/ReducePartitionCommitHandler.scala | 12 +++++++++--- .../celeborn/common/write/PushFailedBatch.java | 17 +++++++++++++++++ docs/configuration/client.md | 1 + 3 files changed, 27 insertions(+), 3 deletions(-) diff --git a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala index e0289508d52..49c8828a056 100644 --- a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala +++ b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala @@ -20,6 +20,7 @@ package org.apache.celeborn.client.commit import java.nio.ByteBuffer import java.util import java.util.concurrent.{Callable, ConcurrentHashMap, ThreadPoolExecutor, TimeUnit} +import java.util.function import scala.collection.JavaConverters._ import scala.collection.mutable @@ -83,6 +84,13 @@ class ReducePartitionCommitHandler( .maximumSize(rpcCacheSize) .build().asInstanceOf[Cache[Int, ByteBuffer]] + val newMapFunc: function.Function[Int, ConcurrentHashMap[Integer, util.Set[PushFailedBatch]]] = + new util.function.Function[Int, ConcurrentHashMap[Integer, util.Set[PushFailedBatch]]]() { + override def apply(s: Int): ConcurrentHashMap[Integer, util.Set[PushFailedBatch]] = { + JavaUtils.newConcurrentHashMap[Integer, util.Set[PushFailedBatch]]() + } + } + override def getPartitionType(): PartitionType = { PartitionType.REDUCE } @@ -255,9 +263,7 @@ class ReducePartitionCommitHandler( if (null != pushFailedBatches && !pushFailedBatches.isEmpty) { val pushFailedBatchesMap = shufflePushFailedBatches.computeIfAbsent( shuffleId, - _ => { - JavaUtils.newConcurrentHashMap[Integer, util.Set[PushFailedBatch]]() - }) + newMapFunc) pushFailedBatchesMap.put(mapId, pushFailedBatches) } // Mapper with this attemptId finished, also check all other mapper finished or not. diff --git a/common/src/main/java/org/apache/celeborn/common/write/PushFailedBatch.java b/common/src/main/java/org/apache/celeborn/common/write/PushFailedBatch.java index 9388c3c581d..edc2de1aa9a 100644 --- a/common/src/main/java/org/apache/celeborn/common/write/PushFailedBatch.java +++ b/common/src/main/java/org/apache/celeborn/common/write/PushFailedBatch.java @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.celeborn.common.write; import java.io.Serializable; diff --git a/docs/configuration/client.md b/docs/configuration/client.md index f035713eb90..c1c252ae1d0 100644 --- a/docs/configuration/client.md +++ b/docs/configuration/client.md @@ -25,6 +25,7 @@ license: | | celeborn.client.chunk.prefetch.enabled | false | false | Whether to enable chunk prefetch when creating CelebornInputStream. | 0.6.0 | | | celeborn.client.closeIdleConnections | true | false | Whether client will close idle connections. | 0.3.0 | | | celeborn.client.commitFiles.ignoreExcludedWorker | false | false | When true, LifecycleManager will skip workers which are in the excluded list. | 0.3.0 | | +| celeborn.client.dataPushFailure.tracking.enabled | false | false | When client push data to worker failed, client will record the failed batch info. Feature used to optimize skew join by avoid data sorting | 0.5.0 | | | celeborn.client.eagerlyCreateInputStream.threads | 32 | false | Threads count for streamCreatorPool in CelebornShuffleReader. | 0.3.1 | | | celeborn.client.excludePeerWorkerOnFailure.enabled | true | false | When true, Celeborn will exclude partition's peer worker on failure when push data to replica failed. | 0.3.0 | | | celeborn.client.excludedWorker.expireTimeout | 180s | false | Timeout time for LifecycleManager to clear reserved excluded worker. Default to be 1.5 * `celeborn.master.heartbeat.worker.timeout` to cover worker heartbeat timeout check period | 0.3.0 | celeborn.worker.excluded.expireTimeout | From 355f6b334f11d4e1971a2f57e1f3508d8591d731 Mon Sep 17 00:00:00 2001 From: wangshengjie Date: Sat, 16 Mar 2024 16:10:55 +0800 Subject: [PATCH 03/44] fix ut and refactor code --- .../celeborn/client/ShuffleClientImpl.java | 25 +++++----- .../client/read/CelebornInputStream.java | 41 +++++++++-------- .../client/read/DfsPartitionReader.java | 1 - .../client/read/LocalPartitionReader.java | 1 - .../client/read/WorkerPartitionReader.java | 1 - .../celeborn/client/CommitManager.scala | 4 +- .../celeborn/client/LifecycleManager.scala | 5 +- .../client/commit/CommitHandler.scala | 2 +- .../commit/MapPartitionCommitHandler.scala | 2 +- .../commit/ReducePartitionCommitHandler.scala | 33 +++++++++---- .../common/write/PushFailedBatch.java | 22 ++------- .../celeborn/common/write/PushState.java | 17 ++++--- common/src/main/proto/TransportMessages.proto | 9 ++-- .../protocol/message/ControlMessages.scala | 46 ++++++++++++++----- .../celeborn/common/util/PbSerDeUtils.scala | 2 - .../celeborn/common/util/UtilsSuite.scala | 4 +- .../service/deploy/worker/FetchHandler.scala | 7 +-- 17 files changed, 125 insertions(+), 97 deletions(-) diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java index f03afdc78e7..57ca8faefaa 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java @@ -147,7 +147,7 @@ protected Compressor initialValue() { protected static class ReduceFileGroups { public Map> partitionGroups; - public Set pushFailedBatchSet; + public Map> pushFailedBatches; public int[] mapAttempts; public Set partitionIds; @@ -155,25 +155,25 @@ protected static class ReduceFileGroups { Map> partitionGroups, int[] mapAttempts, Set partitionIds, - Set pushFailedBatches) { + Map> pushFailedBatches) { this.partitionGroups = partitionGroups; this.mapAttempts = mapAttempts; this.partitionIds = partitionIds; - this.pushFailedBatchSet = pushFailedBatches; + this.pushFailedBatches = pushFailedBatches; } public ReduceFileGroups() { this.partitionGroups = null; this.mapAttempts = null; this.partitionIds = null; - this.pushFailedBatchSet = null; + this.pushFailedBatches = null; } public void update(ReduceFileGroups fileGroups) { partitionGroups = fileGroups.partitionGroups; mapAttempts = fileGroups.mapAttempts; partitionIds = fileGroups.partitionIds; - pushFailedBatchSet = fileGroups.pushFailedBatchSet; + pushFailedBatches = fileGroups.pushFailedBatches; } } @@ -1113,8 +1113,8 @@ public void onSuccess(ByteBuffer response) { nextBatchId); if (dataPushFailureTrackingEnabled) { pushState.addFailedBatch( - new PushFailedBatch( - mapId, attemptId, nextBatchId, partitionId, latest.getEpoch())); + latest.getUniqueId(), + new PushFailedBatch(mapId, attemptId, nextBatchId, latest.getEpoch())); } ReviveRequest reviveRequest = new ReviveRequest( @@ -1184,8 +1184,8 @@ public void onSuccess(ByteBuffer response) { public void onFailure(Throwable e) { if (dataPushFailureTrackingEnabled) { pushState.addFailedBatch( - new PushFailedBatch( - mapId, attemptId, nextBatchId, partitionId, latest.getEpoch())); + latest.getUniqueId(), + new PushFailedBatch(mapId, attemptId, nextBatchId, latest.getEpoch())); } if (pushState.exception.get() != null) { return; @@ -1558,8 +1558,8 @@ public void onSuccess(ByteBuffer response) { if (dataPushFailureTrackingEnabled) { for (int i = 0; i < numBatches; i++) { pushState.addFailedBatch( - new PushFailedBatch( - mapId, attemptId, batchIds[i], partitionIds[i], epochs[i])); + partitionUniqueIds[i], + new PushFailedBatch(mapId, attemptId, batchIds[i], epochs[i])); } } ReviveRequest[] requests = @@ -1620,7 +1620,8 @@ public void onFailure(Throwable e) { if (dataPushFailureTrackingEnabled) { for (int i = 0; i < numBatches; i++) { pushState.addFailedBatch( - new PushFailedBatch(mapId, attemptId, batchIds[i], partitionIds[i], epochs[i])); + partitionUniqueIds[i], + new PushFailedBatch(mapId, attemptId, batchIds[i], epochs[i])); } } if (pushState.exception.get() != null) { diff --git a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java index 7610229834e..34a9c42e210 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java +++ b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java @@ -55,7 +55,7 @@ public static CelebornInputStream create( ArrayList locations, ArrayList streamHandlers, int[] attempts, - Set failedBatchSet, + Map> failedBatchSet, int attemptNumber, long taskId, int startMapIndex, @@ -74,8 +74,11 @@ public static CelebornInputStream create( // if startMapIndex > endMapIndex, means partition is skew partition. // locations will split to sub-partitions with startMapIndex size. ArrayList filterLocations = locations; - if (conf.clientPushFailureTrackingEnabled() && startMapIndex > endMapIndex) { + boolean splitSkewPartitionWithoutMapRange = + conf.clientPushFailureTrackingEnabled() && startMapIndex > endMapIndex; + if (splitSkewPartitionWithoutMapRange) { filterLocations = getSkewPartitionLocations(locations, startMapIndex, endMapIndex); + endMapIndex = Integer.MAX_VALUE; } return new CelebornInputStreamImpl( conf, @@ -95,6 +98,7 @@ public static CelebornInputStream create( shuffleId, partitionId, exceptionMaker, + splitSkewPartitionWithoutMapRange, metricsCallback); } } @@ -115,18 +119,18 @@ public static ArrayList getSkewPartitionLocations( sortSet.addAll(locations); PartitionLocation[] orderedPartitionLocations = sortSet.toArray(new PartitionLocation[0]); - ArrayList result = new ArrayList<>(); - int step = locations.size() / subPartitionSize; + ArrayList result = new ArrayList<>(step + 1); // if partition location is [1,2,3,4,5,6,7,8,9,10], and skew partition split to 3 task: // task 0: 1, 6, 7 // task 1: 2, 5, 8 // task 2: 3, 4, 9, 10 for (int i = 0; i < step + 1; i++) { - if (i % 2 == 0 && (i * 3 + subPartitionIndex) < locations.size()) { + if (i % 2 == 0 && (i * subPartitionSize + subPartitionIndex) < locations.size()) { result.add(orderedPartitionLocations[i * subPartitionSize + subPartitionIndex]); - } else if (((i + 1) * subPartitionSize - subPartitionIndex - 1) < locations.size()) { + } else if (i % 2 == 1 + && ((i + 1) * subPartitionSize - subPartitionIndex - 1) < locations.size()) { result.add(orderedPartitionLocations[(i + 1) * subPartitionSize - subPartitionIndex - 1]); } } @@ -181,7 +185,7 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream { private Map> batchesRead = new HashMap<>(); - private final Set failedBatches; + private final Map> failedBatches; private byte[] compressedBuf; private byte[] rawDataBuf; @@ -219,7 +223,7 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream { private ExceptionMaker exceptionMaker; private boolean closed = false; - private final boolean pushShuffleFailureTrackingEnabled; + private final boolean splitSkewPartitionWithoutMapRange; CelebornInputStreamImpl( CelebornConf conf, @@ -228,7 +232,7 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream { ArrayList locations, ArrayList streamHandlers, int[] attempts, - Set failedBatchSet, + Map> failedBatchSet, int attemptNumber, long taskId, int startMapIndex, @@ -239,6 +243,7 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream { int shuffleId, int partitionId, ExceptionMaker exceptionMaker, + boolean splitSkewPartitionWithoutMapRange, MetricsCallback metricsCallback) throws IOException { this.conf = conf; @@ -260,7 +265,7 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream { !conf.shuffleCompressionCodec().equals(CompressionCodec.NONE); this.fetchExcludedWorkerExpireTimeout = conf.clientFetchExcludedWorkerExpireTimeout(); this.failedBatches = failedBatchSet; - this.pushShuffleFailureTrackingEnabled = conf.clientPushFailureTrackingEnabled(); + this.splitSkewPartitionWithoutMapRange = splitSkewPartitionWithoutMapRange; this.fetchExcludedWorkers = fetchExcludedWorkers; if (conf.clientPushReplicateEnabled()) { @@ -309,7 +314,9 @@ private Tuple2 nextReadableLocation() { return null; } PartitionLocation currentLocation = locations.get(fileIndex); - while (skipLocation(startMapIndex, endMapIndex, currentLocation)) { + // if pushShuffleFailureTrackingEnabled is true, should not skip location + while (!splitSkewPartitionWithoutMapRange + && skipLocation(startMapIndex, endMapIndex, currentLocation)) { skipCount.increment(); fileIndex++; if (fileIndex == locationCount) { @@ -663,15 +670,13 @@ private boolean fillBuffer() throws IOException { // de-duplicate if (attemptId == attempts[mapId]) { - if (pushShuffleFailureTrackingEnabled) { + if (splitSkewPartitionWithoutMapRange) { PushFailedBatch failedBatch = new PushFailedBatch( - mapId, - attemptId, - batchId, - currentReader.getLocation().getId(), - currentReader.getLocation().getEpoch()); - if (this.failedBatches.contains(failedBatch)) { + mapId, attemptId, batchId, currentReader.getLocation().getEpoch()); + if (this.failedBatches + .get(currentReader.getLocation().getUniqueId()) + .contains(failedBatch)) { logger.warn("Skip duplicated batch: {}.", failedBatch); continue; } diff --git a/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java b/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java index 18a2ca9418e..b69cf580fa7 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java +++ b/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java @@ -107,7 +107,6 @@ public DfsPartitionReader( .setFileName(location.getFileName()) .setStartIndex(startMapIndex) .setEndIndex(endMapIndex) - .setShuffleDataNeedSort(conf.clientPushFailureTrackingEnabled()) .build() .toByteArray()); ByteBuffer response = client.sendRpcSync(openStream.toByteBuffer(), fetchTimeoutMs); diff --git a/client/src/main/java/org/apache/celeborn/client/read/LocalPartitionReader.java b/client/src/main/java/org/apache/celeborn/client/read/LocalPartitionReader.java index 5a760a88bf6..a769687c8df 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/LocalPartitionReader.java +++ b/client/src/main/java/org/apache/celeborn/client/read/LocalPartitionReader.java @@ -104,7 +104,6 @@ public LocalPartitionReader( .setStartIndex(startMapIndex) .setEndIndex(endMapIndex) .setReadLocalShuffle(true) - .setShuffleDataNeedSort(conf.clientPushFailureTrackingEnabled()) .build() .toByteArray()); ByteBuffer response = client.sendRpcSync(openStreamMsg.toByteBuffer(), fetchTimeoutMs); diff --git a/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java b/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java index 5ab01e90afd..3158aa12f72 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java +++ b/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java @@ -126,7 +126,6 @@ public void onFailure(int chunkIndex, Throwable e) { .setFileName(location.getFileName()) .setStartIndex(startMapIndex) .setEndIndex(endMapIndex) - .setShuffleDataNeedSort(conf.clientPushFailureTrackingEnabled()) .build() .toByteArray()); ByteBuffer response = client.sendRpcSync(openStreamMsg.toByteBuffer(), fetchTimeoutMs); diff --git a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala index e7516f9be86..6e65a9c0a17 100644 --- a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala +++ b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala @@ -18,6 +18,7 @@ package org.apache.celeborn.client import java.util +import java.util.Collections import java.util.concurrent.{ConcurrentHashMap, ScheduledExecutorService, ScheduledFuture, TimeUnit} import java.util.concurrent.atomic.{AtomicInteger, LongAdder} @@ -218,7 +219,8 @@ class CommitManager(appUniqueId: String, val conf: CelebornConf, lifecycleManage attemptId: Int, numMappers: Int, partitionId: Int = -1, - pushFailedBatches: util.Set[PushFailedBatch] = Sets.newHashSet()): (Boolean, Boolean) = { + pushFailedBatches: util.Map[String, util.Set[PushFailedBatch]] = Collections.emptyMap()) + : (Boolean, Boolean) = { getCommitHandler(shuffleId).finishMapperAttempt( shuffleId, mapId, diff --git a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala index 0f95b04cfed..712c84a764f 100644 --- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala +++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala @@ -64,8 +64,9 @@ object LifecycleManager { // shuffle id -> partition id -> partition locations type ShuffleFileGroups = ConcurrentHashMap[Int, ConcurrentHashMap[Integer, util.Set[PartitionLocation]]] + // shuffle id -> partition uniqueId -> PushFailedBatch set type ShufflePushFailedBatches = - ConcurrentHashMap[Int, ConcurrentHashMap[Integer, util.Set[PushFailedBatch]]] + ConcurrentHashMap[Int, util.HashMap[String, util.Set[PushFailedBatch]]] type ShuffleAllocatedWorkers = ConcurrentHashMap[Int, ConcurrentHashMap[String, ShufflePartitionLocationInfo]] type ShuffleFailedWorkers = ConcurrentHashMap[WorkerInfo, (StatusCode, Long)] @@ -806,7 +807,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends mapId: Int, attemptId: Int, numMappers: Int, - pushFailedBatches: util.Set[PushFailedBatch]): Unit = { + pushFailedBatches: util.Map[String, util.Set[PushFailedBatch]]): Unit = { val (mapperAttemptFinishedSuccess, allMapperFinished) = commitManager.finishMapperAttempt( diff --git a/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala b/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala index d76ecc35339..63371601b95 100644 --- a/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala +++ b/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala @@ -200,7 +200,7 @@ abstract class CommitHandler( attemptId: Int, numMappers: Int, partitionId: Int, - pushFailedBatches: util.Set[PushFailedBatch], + pushFailedBatches: util.Map[String, util.Set[PushFailedBatch]], recordWorkerFailure: ShuffleFailedWorkers => Unit): (Boolean, Boolean) def registerShuffle( diff --git a/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala b/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala index e4bd94dbc81..a08f1e0d51f 100644 --- a/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala +++ b/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala @@ -185,7 +185,7 @@ class MapPartitionCommitHandler( attemptId: Int, numMappers: Int, partitionId: Int, - pushFailedBatches: util.Set[PushFailedBatch], + pushFailedBatches: util.Map[String, util.Set[PushFailedBatch]], recordWorkerFailure: ShuffleFailedWorkers => Unit): (Boolean, Boolean) = { val inProcessingPartitionIds = inProcessMapPartitionEndIds.computeIfAbsent( diff --git a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala index 49c8828a056..d4257c386f6 100644 --- a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala +++ b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala @@ -26,6 +26,7 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import com.google.common.cache.{Cache, CacheBuilder} +import com.google.common.collect.Sets import org.apache.celeborn.client.{ClientUtils, ShuffleCommittedInfo, WorkerStatusTracker} import org.apache.celeborn.client.CommitManager.CommittedPartitionInfo @@ -84,10 +85,19 @@ class ReducePartitionCommitHandler( .maximumSize(rpcCacheSize) .build().asInstanceOf[Cache[Int, ByteBuffer]] - val newMapFunc: function.Function[Int, ConcurrentHashMap[Integer, util.Set[PushFailedBatch]]] = - new util.function.Function[Int, ConcurrentHashMap[Integer, util.Set[PushFailedBatch]]]() { - override def apply(s: Int): ConcurrentHashMap[Integer, util.Set[PushFailedBatch]] = { - JavaUtils.newConcurrentHashMap[Integer, util.Set[PushFailedBatch]]() + private val newShuffleId2PushFailedBatchMapFunc + : function.Function[Int, util.HashMap[String, util.Set[PushFailedBatch]]] = + new util.function.Function[Int, util.HashMap[String, util.Set[PushFailedBatch]]]() { + override def apply(s: Int): util.HashMap[String, util.Set[PushFailedBatch]] = { + new util.HashMap[String, util.Set[PushFailedBatch]]() + } + } + + private val uniqueId2PushFailedBatchMapFunc + : function.Function[String, util.Set[PushFailedBatch]] = + new util.function.Function[String, util.Set[PushFailedBatch]]() { + override def apply(s: String): util.Set[PushFailedBatch] = { + Sets.newHashSet[PushFailedBatch]() } } @@ -249,7 +259,7 @@ class ReducePartitionCommitHandler( attemptId: Int, numMappers: Int, partitionId: Int, - pushFailedBatches: util.Set[PushFailedBatch], + pushFailedBatches: util.Map[String, util.Set[PushFailedBatch]], recordWorkerFailure: ShuffleFailedWorkers => Unit): (Boolean, Boolean) = { shuffleMapperAttempts.synchronized { if (getMapperAttempts(shuffleId) == null) { @@ -263,8 +273,14 @@ class ReducePartitionCommitHandler( if (null != pushFailedBatches && !pushFailedBatches.isEmpty) { val pushFailedBatchesMap = shufflePushFailedBatches.computeIfAbsent( shuffleId, - newMapFunc) - pushFailedBatchesMap.put(mapId, pushFailedBatches) + newShuffleId2PushFailedBatchMapFunc) + pushFailedBatches.forEach((k, v) => { + val partitionPushFailedBatches = pushFailedBatchesMap.computeIfAbsent( + k, + uniqueId2PushFailedBatchMapFunc) + partitionPushFailedBatches.addAll(v) + }) + pushFailedBatchesMap.get(pushFailedBatches) } // Mapper with this attemptId finished, also check all other mapper finished or not. (true, ClientUtils.areAllMapperAttemptsFinished(attempts)) @@ -321,8 +337,7 @@ class ReducePartitionCommitHandler( pushFailedBatches = shufflePushFailedBatches.getOrDefault( shuffleId, - JavaUtils.newConcurrentHashMap()).values().asScala.flatMap(x => - x.asScala.toSet[PushFailedBatch]).toSet.asJava) + new util.HashMap[String, util.Set[PushFailedBatch]]())) context.asInstanceOf[RemoteNettyRpcCallContext].nettyEnv.serialize(returnedMsg) } }) diff --git a/common/src/main/java/org/apache/celeborn/common/write/PushFailedBatch.java b/common/src/main/java/org/apache/celeborn/common/write/PushFailedBatch.java index edc2de1aa9a..ab0fb340852 100644 --- a/common/src/main/java/org/apache/celeborn/common/write/PushFailedBatch.java +++ b/common/src/main/java/org/apache/celeborn/common/write/PushFailedBatch.java @@ -29,7 +29,6 @@ public class PushFailedBatch implements Serializable { private int attemptId; private int batchId; private int epoch; - private int reduceId; public int getMapId() { return mapId; @@ -55,14 +54,6 @@ public void setBatchId(int batchId) { this.batchId = batchId; } - public int getReduceId() { - return reduceId; - } - - public void setReduceId(int reduceId) { - this.reduceId = reduceId; - } - public int getEpoch() { return epoch; } @@ -71,11 +62,10 @@ public void setEpoch(int epoch) { this.epoch = epoch; } - public PushFailedBatch(int mapId, int attemptId, int batchId, int reduceId, int epoch) { + public PushFailedBatch(int mapId, int attemptId, int batchId, int epoch) { this.mapId = mapId; this.attemptId = attemptId; this.batchId = batchId; - this.reduceId = reduceId; this.epoch = epoch; } @@ -85,17 +75,12 @@ public boolean equals(Object other) { return false; } PushFailedBatch o = (PushFailedBatch) other; - return super.equals(o) - && mapId == o.mapId - && attemptId == o.attemptId - && batchId == o.batchId - && reduceId == o.reduceId - && epoch == o.epoch; + return mapId == o.mapId && attemptId == o.attemptId && batchId == o.batchId && epoch == o.epoch; } @Override public int hashCode() { - return Objects.hashCode(super.hashCode(), mapId, attemptId, batchId, reduceId, epoch); + return Objects.hashCode(mapId, attemptId, batchId, epoch); } @Override @@ -104,7 +89,6 @@ public String toString() { .append("mapId", mapId) .append("attemptId", attemptId) .append("batchId", batchId) - .append("reduceId", reduceId) .append("epoch", epoch) .toString(); } diff --git a/common/src/main/java/org/apache/celeborn/common/write/PushState.java b/common/src/main/java/org/apache/celeborn/common/write/PushState.java index 7c390e3bb48..9f691094bd9 100644 --- a/common/src/main/java/org/apache/celeborn/common/write/PushState.java +++ b/common/src/main/java/org/apache/celeborn/common/write/PushState.java @@ -18,11 +18,12 @@ package org.apache.celeborn.common.write; import java.io.IOException; -import java.util.HashSet; +import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicReference; +import com.google.common.collect.Sets; import org.apache.commons.lang3.tuple.Pair; import org.apache.celeborn.common.CelebornConf; @@ -35,12 +36,12 @@ public class PushState { public AtomicReference exception = new AtomicReference<>(); private final InFlightRequestTracker inFlightRequestTracker; - private Set failedBatchSet; + private final Map> failedBatchMap; public PushState(CelebornConf conf) { pushBufferMaxSize = conf.clientPushBufferMaxSize(); inFlightRequestTracker = new InFlightRequestTracker(conf, this); - failedBatchSet = new HashSet<>(); + failedBatchMap = new ConcurrentHashMap<>(); } public void cleanup() { @@ -94,11 +95,13 @@ public int remainingAllowPushes(String hostAndPushPort) { return inFlightRequestTracker.remainingAllowPushes(hostAndPushPort); } - public void addFailedBatch(PushFailedBatch failedBatch) { - this.failedBatchSet.add(failedBatch); + public void addFailedBatch(String partitionId, PushFailedBatch failedBatch) { + this.failedBatchMap + .computeIfAbsent(partitionId, (s) -> Sets.newConcurrentHashSet()) + .add(failedBatch); } - public Set getFailedBatches() { - return this.failedBatchSet; + public Map> getFailedBatches() { + return this.failedBatchMap; } } diff --git a/common/src/main/proto/TransportMessages.proto b/common/src/main/proto/TransportMessages.proto index 56885ccbfbe..f7166d0b144 100644 --- a/common/src/main/proto/TransportMessages.proto +++ b/common/src/main/proto/TransportMessages.proto @@ -354,7 +354,11 @@ message PbMapperEnd { int32 attemptId = 3; int32 numMappers = 4; int32 partitionId = 5; - repeated PbPushFailedBatch pushFailedBatches = 6; + map pushFailureBatches= 6; +} + +message PbPushFailedBatchSet { + repeated PbPushFailedBatch failureBatches = 1; } message PbPushFailedBatch { @@ -362,7 +366,6 @@ message PbPushFailedBatch { int32 attemptId = 2; int32 batchId = 3; int32 epoch = 4; - int32 reduceId =5; } message PbMapperEndResponse { @@ -385,7 +388,7 @@ message PbGetReducerFileGroupResponse { // only map partition mode has succeed partitionIds repeated int32 partitionIds = 4; - repeated PbPushFailedBatch pushFailedBatches = 5; + map pushFailedBatches = 5; } message PbGetShuffleId { diff --git a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala index 1c9e7831546..2a6edd9a867 100644 --- a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala +++ b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala @@ -19,7 +19,6 @@ package org.apache.celeborn.common.protocol.message import java.util import java.util.{Collections, UUID} -import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConverters._ @@ -32,7 +31,7 @@ import org.apache.celeborn.common.network.protocol.TransportMessage import org.apache.celeborn.common.protocol._ import org.apache.celeborn.common.protocol.MessageType._ import org.apache.celeborn.common.quota.ResourceConsumption -import org.apache.celeborn.common.util.{JavaUtils, PbSerDeUtils, Utils} +import org.apache.celeborn.common.util.{PbSerDeUtils, Utils} import org.apache.celeborn.common.write.PushFailedBatch sealed trait Message extends Serializable @@ -274,7 +273,7 @@ object ControlMessages extends Logging { attemptId: Int, numMappers: Int, partitionId: Int, - failedBatchSet: util.Set[PushFailedBatch]) + failedBatchSet: util.Map[String, util.Set[PushFailedBatch]]) extends MasterMessage case class MapperEndResponse(status: StatusCode) extends MasterMessage @@ -289,7 +288,8 @@ object ControlMessages extends Logging { fileGroup: util.Map[Integer, util.Set[PartitionLocation]], attempts: Array[Int], partitionIds: util.Set[Integer] = Collections.emptySet[Integer](), - pushFailedBatches: util.Set[PushFailedBatch] = new util.HashSet[PushFailedBatch]) + pushFailedBatches: util.Map[String, util.Set[PushFailedBatch]] = + new util.HashMap[String, util.Set[PushFailedBatch]]()) extends MasterMessage object WorkerExclude { @@ -726,14 +726,19 @@ object ControlMessages extends Logging { new TransportMessage(MessageType.CHANGE_LOCATION_RESPONSE, pb.toByteArray) case MapperEnd(shuffleId, mapId, attemptId, numMappers, partitionId, pushFailedBatch) => + val pushFailedMap = pushFailedBatch.asScala.map { case (k, v) => + val resultValue = + PbPushFailedBatchSet.newBuilder().addAllFailureBatches(v.asScala.map(PbSerDeUtils + .toPbPushFailedBatch).asJava).build() + (k, resultValue) + }.toMap.asJava val payload = PbMapperEnd.newBuilder() .setShuffleId(shuffleId) .setMapId(mapId) .setAttemptId(attemptId) .setNumMappers(numMappers) .setPartitionId(partitionId) - .addAllPushFailedBatches(pushFailedBatch.asScala.map( - PbSerDeUtils.toPbPushFailedBatch).asJava) + .putAllPushFailureBatches(pushFailedMap) .build().toByteArray new TransportMessage(MessageType.MAPPER_END, payload) @@ -763,8 +768,15 @@ object ControlMessages extends Logging { }.asJava) builder.addAllAttempts(attempts.map(Integer.valueOf).toIterable.asJava) builder.addAllPartitionIds(partitionIds) - builder.addAllPushFailedBatches( - failedBatches.asScala.map(PbSerDeUtils.toPbPushFailedBatch).asJava) + builder.putAllPushFailedBatches( + failedBatches.asScala.map { + case (uniqueId, pushFailedBatchSet) => + ( + uniqueId, + PbPushFailedBatchSet.newBuilder().addAllFailureBatches( + pushFailedBatchSet.asScala.map(PbSerDeUtils + .toPbPushFailedBatch).asJava).build()) + }.asJava) val payload = builder.build().toByteArray new TransportMessage(MessageType.GET_REDUCER_FILE_GROUP_RESPONSE, payload) @@ -1150,8 +1162,13 @@ object ControlMessages extends Logging { pbMapperEnd.getAttemptId, pbMapperEnd.getNumMappers, pbMapperEnd.getPartitionId, - pbMapperEnd.getPushFailedBatchesList.asScala.toSet.map( - PbSerDeUtils.fromPbPushFailedBatch).asJava) + pbMapperEnd.getPushFailureBatchesMap.asScala.map { + case (partitionId, pushFailedBatchSet) => + ( + partitionId, + pushFailedBatchSet.getFailureBatchesList.asScala.map(PbSerDeUtils + .fromPbPushFailedBatch).toSet.asJava) + }.toMap.asJava) case MAPPER_END_RESPONSE_VALUE => val pbMapperEndResponse = PbMapperEndResponse.parseFrom(message.getPayload) @@ -1187,8 +1204,13 @@ object ControlMessages extends Logging { val attempts = pbGetReducerFileGroupResponse.getAttemptsList.asScala.map(_.toInt).toArray val partitionIds = new util.HashSet(pbGetReducerFileGroupResponse.getPartitionIdsList) - val pushFailedBatches = pbGetReducerFileGroupResponse - .getPushFailedBatchesList.asScala.map(PbSerDeUtils.fromPbPushFailedBatch).toSet.asJava + val pushFailedBatches = pbGetReducerFileGroupResponse.getPushFailedBatchesMap.asScala.map { + case (uniqueId, pushFailedBatchSet) => + ( + uniqueId, + pushFailedBatchSet.getFailureBatchesList.asScala.map(PbSerDeUtils + .fromPbPushFailedBatch).toSet.asJava) + }.toMap.asJava GetReducerFileGroupResponse( Utils.toStatusCode(pbGetReducerFileGroupResponse.getStatus), fileGroup, diff --git a/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala b/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala index d3229c1c4fa..ceae38e877f 100644 --- a/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala +++ b/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala @@ -676,7 +676,6 @@ object PbSerDeUtils { .setMapId(pushFailedBatch.getMapId) .setAttemptId(pushFailedBatch.getAttemptId) .setBatchId(pushFailedBatch.getBatchId) - .setReduceId(pushFailedBatch.getReduceId) .setEpoch(pushFailedBatch.getEpoch) builder.build() @@ -687,7 +686,6 @@ object PbSerDeUtils { pbPushFailedBatch.getMapId, pbPushFailedBatch.getAttemptId, pbPushFailedBatch.getBatchId, - pbPushFailedBatch.getReduceId, pbPushFailedBatch.getEpoch) } diff --git a/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala b/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala index 7cc2872e400..79c09dc2150 100644 --- a/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala +++ b/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala @@ -18,6 +18,8 @@ package org.apache.celeborn.common.util import java.util +import java.util.Collections +import java.util.stream.Collectors import org.apache.celeborn.CelebornFunSuite import org.apache.celeborn.common.CelebornConf @@ -145,7 +147,7 @@ class UtilsSuite extends CelebornFunSuite { } test("MapperEnd class convert with pb") { - val mapperEnd = MapperEnd(1, 1, 1, 2, 1, new util.HashSet[PushFailedBatch]()) + val mapperEnd = MapperEnd(1, 1, 1, 2, 1, Collections.emptyMap()) val mapperEndTrans = Utils.fromTransportMessage(Utils.toTransportMessage(mapperEnd)).asInstanceOf[MapperEnd] assert(mapperEnd == mapperEndTrans) diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala index c1f1db5b0c0..0d0d7063e83 100644 --- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala +++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala @@ -136,7 +136,6 @@ class FetchHandler( rpcRequest.requestId, isLegacy = false, openStream.getReadLocalShuffle, - openStream.getShuffleDataNeedSort, callback) case openStreamList: PbOpenStreamList => val shuffleKey = openStreamList.getShuffleKey() @@ -208,7 +207,6 @@ class FetchHandler( isLegacy = true, // legacy [[OpenStream]] doesn't support read local shuffle readLocalShuffle = false, - shuffleNeedSort = true, callback) case Message.Type.OPEN_STREAM_WITH_CREDIT => val openStreamWithCredit = message.asInstanceOf[OpenStreamWithCredit] @@ -222,7 +220,6 @@ class FetchHandler( rpcRequestId = rpcRequest.requestId, isLegacy = true, readLocalShuffle = false, - shuffleNeedSort = true, callback) case _ => logError(s"Received an unknown message type id: ${message.`type`.id}") @@ -346,7 +343,6 @@ class FetchHandler( rpcRequestId: Long, isLegacy: Boolean, readLocalShuffle: Boolean = false, - shuffleNeedSort: Boolean = true, callback: RpcResponseCallback): Unit = { checkAuth(client, Utils.splitShuffleKey(shuffleKey)._1) workerSource.recordAppActiveConnection(client, shuffleKey) @@ -362,8 +358,7 @@ class FetchHandler( fileName, startIndex, endIndex, - readLocalShuffle, - shuffleNeedSort) + readLocalShuffle) if (pbStreamHandlerOpt.getStatus != StatusCode.SUCCESS.getValue) { throw new CelebornIOException(pbStreamHandlerOpt.getErrorMsg) From 2b3ad581e7554454877c1f46503079963fb0b348 Mon Sep 17 00:00:00 2001 From: wangshengjie Date: Sat, 16 Mar 2024 16:38:43 +0800 Subject: [PATCH 04/44] refactor code and imports --- .../client/read/CelebornInputStream.java | 27 ++++++------------- .../celeborn/client/CommitManager.scala | 1 - .../celeborn/common/util/UtilsSuite.scala | 2 -- 3 files changed, 8 insertions(+), 22 deletions(-) diff --git a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java index 34a9c42e210..b0d5c916e0b 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java +++ b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java @@ -77,7 +77,7 @@ public static CelebornInputStream create( boolean splitSkewPartitionWithoutMapRange = conf.clientPushFailureTrackingEnabled() && startMapIndex > endMapIndex; if (splitSkewPartitionWithoutMapRange) { - filterLocations = getSkewPartitionLocations(locations, startMapIndex, endMapIndex); + filterLocations = getSubSkewPartitionLocations(locations, startMapIndex, endMapIndex); endMapIndex = Integer.MAX_VALUE; } return new CelebornInputStreamImpl( @@ -103,22 +103,11 @@ public static CelebornInputStream create( } } - public static ArrayList getSkewPartitionLocations( - List locations, int subPartitionSize, int subPartitionIndex) { - Set sortSet = - new TreeSet<>( - (o1, o2) -> { - if (o1.getStorageInfo().fileSize > o2.getStorageInfo().fileSize) { - return 1; - } else if (o1.getStorageInfo().fileSize < o2.getStorageInfo().fileSize) { - return -1; - } else { - return o1.hashCode() - o2.hashCode(); - } - }); - sortSet.addAll(locations); - PartitionLocation[] orderedPartitionLocations = sortSet.toArray(new PartitionLocation[0]); - + public static ArrayList getSubSkewPartitionLocations( + ArrayList locations, int subPartitionSize, int subPartitionIndex) { + locations.sort( + Comparator.comparingLong((PartitionLocation o) -> o.getStorageInfo().fileSize) + .thenComparing(PartitionLocation::getUniqueId)); int step = locations.size() / subPartitionSize; ArrayList result = new ArrayList<>(step + 1); @@ -128,10 +117,10 @@ public static ArrayList getSkewPartitionLocations( // task 2: 3, 4, 9, 10 for (int i = 0; i < step + 1; i++) { if (i % 2 == 0 && (i * subPartitionSize + subPartitionIndex) < locations.size()) { - result.add(orderedPartitionLocations[i * subPartitionSize + subPartitionIndex]); + result.add(locations.get(i * subPartitionSize + subPartitionIndex)); } else if (i % 2 == 1 && ((i + 1) * subPartitionSize - subPartitionIndex - 1) < locations.size()) { - result.add(orderedPartitionLocations[(i + 1) * subPartitionSize - subPartitionIndex - 1]); + result.add(locations.get((i + 1) * subPartitionSize - subPartitionIndex - 1)); } } diff --git a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala index 6e65a9c0a17..aaeb4462fbd 100644 --- a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala +++ b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala @@ -25,7 +25,6 @@ import java.util.concurrent.atomic.{AtomicInteger, LongAdder} import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer -import com.google.common.collect.Sets import org.roaringbitmap.RoaringBitmap import org.apache.celeborn.client.CommitManager.CommittedPartitionInfo diff --git a/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala b/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala index 79c09dc2150..03c6176edd8 100644 --- a/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala +++ b/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala @@ -19,7 +19,6 @@ package org.apache.celeborn.common.util import java.util import java.util.Collections -import java.util.stream.Collectors import org.apache.celeborn.CelebornFunSuite import org.apache.celeborn.common.CelebornConf @@ -29,7 +28,6 @@ import org.apache.celeborn.common.identity.DefaultIdentityProvider import org.apache.celeborn.common.protocol.{PartitionLocation, TransportModuleConstants} import org.apache.celeborn.common.protocol.message.ControlMessages.{GetReducerFileGroupResponse, MapperEnd} import org.apache.celeborn.common.protocol.message.StatusCode -import org.apache.celeborn.common.write.PushFailedBatch class UtilsSuite extends CelebornFunSuite { From 1a0462da7aee6692b44b07c38b18d5693dc8af59 Mon Sep 17 00:00:00 2001 From: wangshengjie Date: Sun, 24 Mar 2024 17:07:37 +0800 Subject: [PATCH 05/44] fix NPE and remove unused code and fix scala 2.11 compile error --- .../celeborn/client/ShuffleClientImpl.java | 10 ++++----- .../client/read/CelebornInputStream.java | 22 ++++++++++++------- .../commit/ReducePartitionCommitHandler.scala | 9 ++++---- .../common/write/PushFailedBatch.java | 17 +++----------- common/src/main/proto/TransportMessages.proto | 1 - .../celeborn/common/util/PbSerDeUtils.scala | 4 +--- 6 files changed, 26 insertions(+), 37 deletions(-) diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java index 57ca8faefaa..5515e8dc5c6 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java @@ -1114,7 +1114,7 @@ public void onSuccess(ByteBuffer response) { if (dataPushFailureTrackingEnabled) { pushState.addFailedBatch( latest.getUniqueId(), - new PushFailedBatch(mapId, attemptId, nextBatchId, latest.getEpoch())); + new PushFailedBatch(mapId, attemptId, nextBatchId)); } ReviveRequest reviveRequest = new ReviveRequest( @@ -1185,7 +1185,7 @@ public void onFailure(Throwable e) { if (dataPushFailureTrackingEnabled) { pushState.addFailedBatch( latest.getUniqueId(), - new PushFailedBatch(mapId, attemptId, nextBatchId, latest.getEpoch())); + new PushFailedBatch(mapId, attemptId, nextBatchId)); } if (pushState.exception.get() != null) { return; @@ -1409,7 +1409,6 @@ private void doPushMergedData( final String[] partitionUniqueIds = new String[numBatches]; final int[] offsets = new int[numBatches]; final int[] batchIds = new int[numBatches]; - final int[] epochs = new int[numBatches]; int currentSize = 0; CompositeByteBuf byteBuf = Unpooled.compositeBuffer(); for (int i = 0; i < numBatches; i++) { @@ -1418,7 +1417,6 @@ private void doPushMergedData( partitionUniqueIds[i] = batch.loc.getUniqueId(); offsets[i] = currentSize; batchIds[i] = batch.batchId; - epochs[i] = batch.loc.getEpoch(); currentSize += batch.body.length; byteBuf.addComponent(true, Unpooled.wrappedBuffer(batch.body)); } @@ -1559,7 +1557,7 @@ public void onSuccess(ByteBuffer response) { for (int i = 0; i < numBatches; i++) { pushState.addFailedBatch( partitionUniqueIds[i], - new PushFailedBatch(mapId, attemptId, batchIds[i], epochs[i])); + new PushFailedBatch(mapId, attemptId, batchIds[i])); } } ReviveRequest[] requests = @@ -1621,7 +1619,7 @@ public void onFailure(Throwable e) { for (int i = 0; i < numBatches; i++) { pushState.addFailedBatch( partitionUniqueIds[i], - new PushFailedBatch(mapId, attemptId, batchIds[i], epochs[i])); + new PushFailedBatch(mapId, attemptId, batchIds[i])); } } if (pushState.exception.get() != null) { diff --git a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java index b0d5c916e0b..bf4e2b44b01 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java +++ b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java @@ -74,10 +74,15 @@ public static CelebornInputStream create( // if startMapIndex > endMapIndex, means partition is skew partition. // locations will split to sub-partitions with startMapIndex size. ArrayList filterLocations = locations; + logger.error("current split info: {},{}", startMapIndex, endMapIndex); boolean splitSkewPartitionWithoutMapRange = conf.clientPushFailureTrackingEnabled() && startMapIndex > endMapIndex; if (splitSkewPartitionWithoutMapRange) { + logger.error("use new mode to handle skew partition without map range"); filterLocations = getSubSkewPartitionLocations(locations, startMapIndex, endMapIndex); + + logger.error("current partition locations: {}", filterLocations); + endMapIndex = Integer.MAX_VALUE; } return new CelebornInputStreamImpl( @@ -660,14 +665,15 @@ private boolean fillBuffer() throws IOException { // de-duplicate if (attemptId == attempts[mapId]) { if (splitSkewPartitionWithoutMapRange) { - PushFailedBatch failedBatch = - new PushFailedBatch( - mapId, attemptId, batchId, currentReader.getLocation().getEpoch()); - if (this.failedBatches - .get(currentReader.getLocation().getUniqueId()) - .contains(failedBatch)) { - logger.warn("Skip duplicated batch: {}.", failedBatch); - continue; + Set failedBatchSet = this.failedBatches + .get(currentReader.getLocation().getUniqueId()); + if (null != failedBatchSet) { + PushFailedBatch failedBatch = + new PushFailedBatch(mapId, attemptId, batchId); + if (failedBatchSet.contains(failedBatch)) { + logger.warn("Skip duplicated batch: {}.", failedBatch); + continue; + } } } if (!batchesRead.containsKey(mapId)) { diff --git a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala index d4257c386f6..55639764c7c 100644 --- a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala +++ b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala @@ -274,13 +274,12 @@ class ReducePartitionCommitHandler( val pushFailedBatchesMap = shufflePushFailedBatches.computeIfAbsent( shuffleId, newShuffleId2PushFailedBatchMapFunc) - pushFailedBatches.forEach((k, v) => { + for ((partitionUniqId, pushFailedBatchSet) <- pushFailedBatches.asScala) { val partitionPushFailedBatches = pushFailedBatchesMap.computeIfAbsent( - k, + partitionUniqId, uniqueId2PushFailedBatchMapFunc) - partitionPushFailedBatches.addAll(v) - }) - pushFailedBatchesMap.get(pushFailedBatches) + partitionPushFailedBatches.addAll(pushFailedBatchSet) + } } // Mapper with this attemptId finished, also check all other mapper finished or not. (true, ClientUtils.areAllMapperAttemptsFinished(attempts)) diff --git a/common/src/main/java/org/apache/celeborn/common/write/PushFailedBatch.java b/common/src/main/java/org/apache/celeborn/common/write/PushFailedBatch.java index ab0fb340852..0e872611e13 100644 --- a/common/src/main/java/org/apache/celeborn/common/write/PushFailedBatch.java +++ b/common/src/main/java/org/apache/celeborn/common/write/PushFailedBatch.java @@ -28,7 +28,6 @@ public class PushFailedBatch implements Serializable { private int mapId; private int attemptId; private int batchId; - private int epoch; public int getMapId() { return mapId; @@ -54,19 +53,10 @@ public void setBatchId(int batchId) { this.batchId = batchId; } - public int getEpoch() { - return epoch; - } - - public void setEpoch(int epoch) { - this.epoch = epoch; - } - - public PushFailedBatch(int mapId, int attemptId, int batchId, int epoch) { + public PushFailedBatch(int mapId, int attemptId, int batchId) { this.mapId = mapId; this.attemptId = attemptId; this.batchId = batchId; - this.epoch = epoch; } @Override @@ -75,12 +65,12 @@ public boolean equals(Object other) { return false; } PushFailedBatch o = (PushFailedBatch) other; - return mapId == o.mapId && attemptId == o.attemptId && batchId == o.batchId && epoch == o.epoch; + return mapId == o.mapId && attemptId == o.attemptId && batchId == o.batchId; } @Override public int hashCode() { - return Objects.hashCode(mapId, attemptId, batchId, epoch); + return Objects.hashCode(mapId, attemptId, batchId); } @Override @@ -89,7 +79,6 @@ public String toString() { .append("mapId", mapId) .append("attemptId", attemptId) .append("batchId", batchId) - .append("epoch", epoch) .toString(); } } diff --git a/common/src/main/proto/TransportMessages.proto b/common/src/main/proto/TransportMessages.proto index f7166d0b144..0514e5e1538 100644 --- a/common/src/main/proto/TransportMessages.proto +++ b/common/src/main/proto/TransportMessages.proto @@ -365,7 +365,6 @@ message PbPushFailedBatch { int32 mapId = 1; int32 attemptId = 2; int32 batchId = 3; - int32 epoch = 4; } message PbMapperEndResponse { diff --git a/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala b/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala index ceae38e877f..ba4e4cdc03c 100644 --- a/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala +++ b/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala @@ -676,7 +676,6 @@ object PbSerDeUtils { .setMapId(pushFailedBatch.getMapId) .setAttemptId(pushFailedBatch.getAttemptId) .setBatchId(pushFailedBatch.getBatchId) - .setEpoch(pushFailedBatch.getEpoch) builder.build() } @@ -685,8 +684,7 @@ object PbSerDeUtils { new PushFailedBatch( pbPushFailedBatch.getMapId, pbPushFailedBatch.getAttemptId, - pbPushFailedBatch.getBatchId, - pbPushFailedBatch.getEpoch) + pbPushFailedBatch.getBatchId) } } From 5d1bcb6366ffb0a8c8400ce0c363f825f152e347 Mon Sep 17 00:00:00 2001 From: wangshengjie Date: Sun, 24 Mar 2024 17:28:00 +0800 Subject: [PATCH 06/44] format code --- .../apache/celeborn/client/ShuffleClientImpl.java | 12 ++++-------- .../celeborn/client/read/CelebornInputStream.java | 7 +++---- 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java index 5515e8dc5c6..2ff08e83770 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java @@ -1113,8 +1113,7 @@ public void onSuccess(ByteBuffer response) { nextBatchId); if (dataPushFailureTrackingEnabled) { pushState.addFailedBatch( - latest.getUniqueId(), - new PushFailedBatch(mapId, attemptId, nextBatchId)); + latest.getUniqueId(), new PushFailedBatch(mapId, attemptId, nextBatchId)); } ReviveRequest reviveRequest = new ReviveRequest( @@ -1184,8 +1183,7 @@ public void onSuccess(ByteBuffer response) { public void onFailure(Throwable e) { if (dataPushFailureTrackingEnabled) { pushState.addFailedBatch( - latest.getUniqueId(), - new PushFailedBatch(mapId, attemptId, nextBatchId)); + latest.getUniqueId(), new PushFailedBatch(mapId, attemptId, nextBatchId)); } if (pushState.exception.get() != null) { return; @@ -1556,8 +1554,7 @@ public void onSuccess(ByteBuffer response) { if (dataPushFailureTrackingEnabled) { for (int i = 0; i < numBatches; i++) { pushState.addFailedBatch( - partitionUniqueIds[i], - new PushFailedBatch(mapId, attemptId, batchIds[i])); + partitionUniqueIds[i], new PushFailedBatch(mapId, attemptId, batchIds[i])); } } ReviveRequest[] requests = @@ -1618,8 +1615,7 @@ public void onFailure(Throwable e) { if (dataPushFailureTrackingEnabled) { for (int i = 0; i < numBatches; i++) { pushState.addFailedBatch( - partitionUniqueIds[i], - new PushFailedBatch(mapId, attemptId, batchIds[i])); + partitionUniqueIds[i], new PushFailedBatch(mapId, attemptId, batchIds[i])); } } if (pushState.exception.get() != null) { diff --git a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java index bf4e2b44b01..5ed095f2fa6 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java +++ b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java @@ -665,11 +665,10 @@ private boolean fillBuffer() throws IOException { // de-duplicate if (attemptId == attempts[mapId]) { if (splitSkewPartitionWithoutMapRange) { - Set failedBatchSet = this.failedBatches - .get(currentReader.getLocation().getUniqueId()); + Set failedBatchSet = + this.failedBatches.get(currentReader.getLocation().getUniqueId()); if (null != failedBatchSet) { - PushFailedBatch failedBatch = - new PushFailedBatch(mapId, attemptId, batchId); + PushFailedBatch failedBatch = new PushFailedBatch(mapId, attemptId, batchId); if (failedBatchSet.contains(failedBatch)) { logger.warn("Skip duplicated batch: {}.", failedBatch); continue; From 2c4b03fe75e4b2441ecdb4b8eb94c8f3cd2590b1 Mon Sep 17 00:00:00 2001 From: wangshengjie Date: Sun, 24 Mar 2024 19:55:08 +0800 Subject: [PATCH 07/44] add spark 3.3 patch --- ...rn-Optimize-Skew-Partitions-spark3_3.patch | 78 +++++++++++++++++++ 1 file changed, 78 insertions(+) create mode 100644 assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch diff --git a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch new file mode 100644 index 00000000000..3fa0649d9f1 --- /dev/null +++ b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch @@ -0,0 +1,78 @@ +From 39eeab2426f9676580e4e19c8b079e1967081c7d Mon Sep 17 00:00:00 2001 +From: wangshengjie +Date: Sun, 24 Mar 2024 19:51:05 +0800 +Subject: [PATCH] [SQL] Handle skew partitions with Celeborn + +--- + .../org/apache/spark/sql/internal/SQLConf.scala | 10 ++++++++++ + .../execution/adaptive/ShufflePartitionsUtil.scala | 12 +++++++++++- + 2 files changed, 21 insertions(+), 1 deletion(-) + +diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +index af03ad9a4cb..1e55af89160 100644 +--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala ++++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +@@ -3784,6 +3784,13 @@ object SQLConf { + .booleanConf + .createWithDefault(false) + ++ val CELEBORN_CLIENT_DATA_PUSH_FAILURE_TRACKING_ENABLED = ++ buildConf("spark.celeborn.client.dataPushFailure.tracking.enabled") ++ .withAlternative("celeborn.client.dataPushFailure.tracking.enabled") ++ .version("3.1.2-mdh") ++ .booleanConf ++ .createWithDefault(false) ++ + /** + * Holds information about keys that have been deprecated. + * +@@ -4549,6 +4556,9 @@ class SQLConf extends Serializable with Logging { + def histogramNumericPropagateInputType: Boolean = + getConf(SQLConf.HISTOGRAM_NUMERIC_PROPAGATE_INPUT_TYPE) + ++ def isCelebornClientPushFailedTrackingEnabled: Boolean = getConf( ++ SQLConf.CELEBORN_CLIENT_DATA_PUSH_FAILURE_TRACKING_ENABLED) ++ + /** ********************** SQLConf functionality methods ************ */ + + /** Set Spark SQL configuration properties. */ +diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala +index af689db3379..7da6211e509 100644 +--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala ++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala +@@ -22,6 +22,8 @@ import scala.collection.mutable.ArrayBuffer + import org.apache.spark.{MapOutputStatistics, MapOutputTrackerMaster, SparkEnv} + import org.apache.spark.internal.Logging + import org.apache.spark.sql.execution.{CoalescedPartitionSpec, PartialReducerPartitionSpec, ShufflePartitionSpec} ++import org.apache.spark.sql.internal.SQLConf ++import org.apache.spark.util.Utils + + object ShufflePartitionsUtil extends Logging { + final val SMALL_PARTITION_FACTOR = 0.2 +@@ -387,6 +389,10 @@ object ShufflePartitionsUtil extends Logging { + val mapStartIndices = splitSizeListByTargetSize( + mapPartitionSizes, targetSize, smallPartitionFactor) + if (mapStartIndices.length > 1) { ++ // If Celeborn is enabled, split skew partitions without shuffle mapper-range reading ++ val splitSkewPartitionWithCeleborn = Utils.isCelebornEnabled(SparkEnv.get.conf) && ++ SQLConf.get.isCelebornClientPushFailedTrackingEnabled ++ + Some(mapStartIndices.indices.map { i => + val startMapIndex = mapStartIndices(i) + val endMapIndex = if (i == mapStartIndices.length - 1) { +@@ -400,7 +406,11 @@ object ShufflePartitionsUtil extends Logging { + dataSize += mapPartitionSizes(mapIndex) + mapIndex += 1 + } +- PartialReducerPartitionSpec(reducerId, startMapIndex, endMapIndex, dataSize) ++ if (splitSkewPartitionWithCeleborn) { ++ PartialReducerPartitionSpec(reducerId, mapStartIndices.length, i, dataSize) ++ } else { ++ PartialReducerPartitionSpec(reducerId, startMapIndex, endMapIndex, dataSize) ++ } + }) + } else { + None +-- +2.25.1 + From ef2dacbaa5d34ddfbb078d82304b24a0100776ab Mon Sep 17 00:00:00 2001 From: wangshengjie Date: Tue, 26 Mar 2024 22:11:49 +0800 Subject: [PATCH 08/44] remove unused log and refactor code --- .../celeborn/CelebornShuffleReader.scala | 2 +- .../apache/celeborn/client/ShuffleClient.java | 2 +- .../celeborn/client/ShuffleClientImpl.java | 4 ++-- .../client/read/CelebornInputStream.java | 10 +++------- .../celeborn/client/DummyShuffleClient.java | 8 ++++++-- .../protocol/message/ControlMessages.scala | 20 ++++--------------- .../celeborn/common/util/PbSerDeUtils.scala | 20 ++++++++++++++++--- .../common/util/PbSerDeUtilsTest.scala | 20 +++++++++++++++++++ 8 files changed, 54 insertions(+), 32 deletions(-) diff --git a/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala index 1d9a00587cb..66b55ed4b0f 100644 --- a/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala +++ b/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala @@ -222,7 +222,7 @@ class CelebornShuffleReader[K, C]( else null, locations, streamHandlers, - fileGroups.pushFailedBatchSet, + fileGroups.pushFailedBatches, fileGroups.mapAttempts, metricsCallback) streams.put(partitionId, inputStream) diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java index 5d57d2971d5..6b3146b864d 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java @@ -258,7 +258,7 @@ public abstract CelebornInputStream readPartition( ExceptionMaker exceptionMaker, ArrayList locations, ArrayList streamHandlers, - Set pushFailedBatchSet, + Map> failedBatchSetMap, int[] mapAttempts, MetricsCallback metricsCallback) throws IOException; diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java index 2ff08e83770..d5104564508 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java @@ -1893,7 +1893,7 @@ public CelebornInputStream readPartition( ExceptionMaker exceptionMaker, ArrayList locations, ArrayList streamHandlers, - Set pushFailedBatchSet, + Map> failedBatchSetMap, int[] mapAttempts, MetricsCallback metricsCallback) throws IOException { @@ -1926,7 +1926,7 @@ public CelebornInputStream readPartition( locations, streamHandlers, mapAttempts, - pushFailedBatchSet, + failedBatchSetMap, attemptNumber, taskId, startMapIndex, diff --git a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java index 5ed095f2fa6..f66ef7c6c74 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java +++ b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java @@ -55,7 +55,7 @@ public static CelebornInputStream create( ArrayList locations, ArrayList streamHandlers, int[] attempts, - Map> failedBatchSet, + Map> failedBatchSetMap, int attemptNumber, long taskId, int startMapIndex, @@ -74,15 +74,11 @@ public static CelebornInputStream create( // if startMapIndex > endMapIndex, means partition is skew partition. // locations will split to sub-partitions with startMapIndex size. ArrayList filterLocations = locations; - logger.error("current split info: {},{}", startMapIndex, endMapIndex); boolean splitSkewPartitionWithoutMapRange = conf.clientPushFailureTrackingEnabled() && startMapIndex > endMapIndex; if (splitSkewPartitionWithoutMapRange) { - logger.error("use new mode to handle skew partition without map range"); filterLocations = getSubSkewPartitionLocations(locations, startMapIndex, endMapIndex); - - logger.error("current partition locations: {}", filterLocations); - + logger.debug("Current sub-partition locations: {}", filterLocations); endMapIndex = Integer.MAX_VALUE; } return new CelebornInputStreamImpl( @@ -92,7 +88,7 @@ public static CelebornInputStream create( filterLocations, streamHandlers, attempts, - failedBatchSet, + failedBatchSetMap, attemptNumber, taskId, startMapIndex, diff --git a/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java b/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java index c2fef84470f..3cc99fc4872 100644 --- a/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java +++ b/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java @@ -24,7 +24,11 @@ import java.io.FileOutputStream; import java.io.IOException; import java.io.OutputStream; -import java.util.*; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import org.slf4j.Logger; @@ -134,7 +138,7 @@ public CelebornInputStream readPartition( ExceptionMaker exceptionMaker, ArrayList locations, ArrayList streamHandlers, - Set pushFailedBatchSet, + Map> failedBatchSetMap, int[] mapAttempts, MetricsCallback metricsCallback) throws IOException { diff --git a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala index 2a6edd9a867..7bbe66ec54c 100644 --- a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala +++ b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala @@ -727,9 +727,7 @@ object ControlMessages extends Logging { case MapperEnd(shuffleId, mapId, attemptId, numMappers, partitionId, pushFailedBatch) => val pushFailedMap = pushFailedBatch.asScala.map { case (k, v) => - val resultValue = - PbPushFailedBatchSet.newBuilder().addAllFailureBatches(v.asScala.map(PbSerDeUtils - .toPbPushFailedBatch).asJava).build() + val resultValue = PbSerDeUtils.toPbPushFailedBatchSet(v) (k, resultValue) }.toMap.asJava val payload = PbMapperEnd.newBuilder() @@ -771,11 +769,7 @@ object ControlMessages extends Logging { builder.putAllPushFailedBatches( failedBatches.asScala.map { case (uniqueId, pushFailedBatchSet) => - ( - uniqueId, - PbPushFailedBatchSet.newBuilder().addAllFailureBatches( - pushFailedBatchSet.asScala.map(PbSerDeUtils - .toPbPushFailedBatch).asJava).build()) + (uniqueId, PbSerDeUtils.toPbPushFailedBatchSet(pushFailedBatchSet)) }.asJava) val payload = builder.build().toByteArray new TransportMessage(MessageType.GET_REDUCER_FILE_GROUP_RESPONSE, payload) @@ -1164,10 +1158,7 @@ object ControlMessages extends Logging { pbMapperEnd.getPartitionId, pbMapperEnd.getPushFailureBatchesMap.asScala.map { case (partitionId, pushFailedBatchSet) => - ( - partitionId, - pushFailedBatchSet.getFailureBatchesList.asScala.map(PbSerDeUtils - .fromPbPushFailedBatch).toSet.asJava) + (partitionId, PbSerDeUtils.fromPbPushFailedBatchSet(pushFailedBatchSet)) }.toMap.asJava) case MAPPER_END_RESPONSE_VALUE => @@ -1206,10 +1197,7 @@ object ControlMessages extends Logging { val partitionIds = new util.HashSet(pbGetReducerFileGroupResponse.getPartitionIdsList) val pushFailedBatches = pbGetReducerFileGroupResponse.getPushFailedBatchesMap.asScala.map { case (uniqueId, pushFailedBatchSet) => - ( - uniqueId, - pushFailedBatchSet.getFailureBatchesList.asScala.map(PbSerDeUtils - .fromPbPushFailedBatch).toSet.asJava) + (uniqueId, PbSerDeUtils.fromPbPushFailedBatchSet(pushFailedBatchSet)) }.toMap.asJava GetReducerFileGroupResponse( Utils.toStatusCode(pbGetReducerFileGroupResponse.getStatus), diff --git a/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala b/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala index ba4e4cdc03c..87766e71e49 100644 --- a/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala +++ b/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala @@ -672,12 +672,11 @@ object PbSerDeUtils { } def toPbPushFailedBatch(pushFailedBatch: PushFailedBatch): PbPushFailedBatch = { - val builder = PbPushFailedBatch.newBuilder() + PbPushFailedBatch.newBuilder() .setMapId(pushFailedBatch.getMapId) .setAttemptId(pushFailedBatch.getAttemptId) .setBatchId(pushFailedBatch.getBatchId) - - builder.build() + .build() } def fromPbPushFailedBatch(pbPushFailedBatch: PbPushFailedBatch): PushFailedBatch = { @@ -687,4 +686,19 @@ object PbSerDeUtils { pbPushFailedBatch.getBatchId) } + def toPbPushFailedBatchSet(failedBatchSet: util.Set[PushFailedBatch]): PbPushFailedBatchSet = { + val builder = PbPushFailedBatchSet.newBuilder() + failedBatchSet.asScala.foreach(batch => builder.addFailureBatches(toPbPushFailedBatch(batch))) + + builder.build() + } + + def fromPbPushFailedBatchSet(pbFailedBatchSet: PbPushFailedBatchSet) + : util.Set[PushFailedBatch] = { + val failedBatchSet = new util.HashSet[PushFailedBatch]() + pbFailedBatchSet.getFailureBatchesList.asScala.foreach(batch => + failedBatchSet.add(fromPbPushFailedBatch(batch))) + + failedBatchSet + } } diff --git a/common/src/test/scala/org/apache/celeborn/common/util/PbSerDeUtilsTest.scala b/common/src/test/scala/org/apache/celeborn/common/util/PbSerDeUtilsTest.scala index 9e1d442fb24..280e459fc77 100644 --- a/common/src/test/scala/org/apache/celeborn/common/util/PbSerDeUtilsTest.scala +++ b/common/src/test/scala/org/apache/celeborn/common/util/PbSerDeUtilsTest.scala @@ -26,6 +26,8 @@ import scala.util.Random import org.apache.hadoop.shaded.org.apache.commons.lang3.RandomStringUtils +import com.google.common.collect.Sets + import org.apache.celeborn.CelebornFunSuite import org.apache.celeborn.common.identity.UserIdentifier import org.apache.celeborn.common.meta._ @@ -35,6 +37,7 @@ import org.apache.celeborn.common.protocol.message.{ControlMessages, StatusCode} import org.apache.celeborn.common.protocol.message.ControlMessages.{GetReducerFileGroupResponse, WorkerResource} import org.apache.celeborn.common.quota.ResourceConsumption import org.apache.celeborn.common.util.PbSerDeUtils.{fromPbPackedPartitionLocationsPair, toPbPackedPartitionLocationsPair} +import org.apache.celeborn.common.write.PushFailedBatch class PbSerDeUtilsTest extends CelebornFunSuite { @@ -565,4 +568,21 @@ class PbSerDeUtilsTest extends CelebornFunSuite { locations.asScala.foreach(p => uniqueIds.remove(p.getUniqueId)) assert(uniqueIds.isEmpty) } + + test("fromAndToPushFailedBatch") { + val failedBatch = new PushFailedBatch(1, 1, 2) + val pbPushFailedBatch = PbSerDeUtils.toPbPushFailedBatch(failedBatch) + val restoredFailedBatch = PbSerDeUtils.fromPbPushFailedBatch(pbPushFailedBatch) + + assert(restoredFailedBatch.equals(failedBatch)) + } + + test("fromAndToPushFailedBatchSet") { + val failedBatchSet = Sets.newHashSet(new PushFailedBatch(1, 1, 2), new PushFailedBatch(2, 2, 3)) + val pbPushFailedBatchSet = PbSerDeUtils.toPbPushFailedBatchSet(failedBatchSet) + val restoredFailedBatchSet = PbSerDeUtils.fromPbPushFailedBatchSet(pbPushFailedBatchSet) + + assert(restoredFailedBatchSet.equals(failedBatchSet)) + } + } From b4dae1dbba095a72d4fab5b1cec1b111e552b23b Mon Sep 17 00:00:00 2001 From: wangshengjie Date: Thu, 28 Mar 2024 21:36:32 +0800 Subject: [PATCH 09/44] add unit tests and refactor code --- .../client/read/CelebornInputStream.java | 5 +- .../read/CelebornInputStreamSuiteJ.java | 73 +++++++++++++++++ .../client/WithShuffleClientSuite.scala | 2 +- .../common/write/PushFailedBatchSuiteJ.java | 79 +++++++++++++++++++ .../LifecycleManagerCommitFilesSuite.scala | 69 +++++++++++++++- .../service/deploy/worker/FetchHandler.scala | 8 +- 6 files changed, 228 insertions(+), 8 deletions(-) create mode 100644 client/src/test/java/org/apache/celeborn/client/read/CelebornInputStreamSuiteJ.java create mode 100644 common/src/test/java/org/apache/celeborn/common/write/PushFailedBatchSuiteJ.java diff --git a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java index f66ef7c6c74..5ca3030f795 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java +++ b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java @@ -636,6 +636,7 @@ private boolean fillBuffer() throws IOException { return false; } + PushFailedBatch failedBatch = new PushFailedBatch(-1, -1, -1); boolean hasData = false; while (currentChunk.isReadable() || moveToNextChunk()) { currentChunk.readBytes(sizeBuf); @@ -664,7 +665,9 @@ private boolean fillBuffer() throws IOException { Set failedBatchSet = this.failedBatches.get(currentReader.getLocation().getUniqueId()); if (null != failedBatchSet) { - PushFailedBatch failedBatch = new PushFailedBatch(mapId, attemptId, batchId); + failedBatch.setMapId(mapId); + failedBatch.setAttemptId(attemptId); + failedBatch.setBatchId(batchId); if (failedBatchSet.contains(failedBatch)) { logger.warn("Skip duplicated batch: {}.", failedBatch); continue; diff --git a/client/src/test/java/org/apache/celeborn/client/read/CelebornInputStreamSuiteJ.java b/client/src/test/java/org/apache/celeborn/client/read/CelebornInputStreamSuiteJ.java new file mode 100644 index 00000000000..b88a3e0712d --- /dev/null +++ b/client/src/test/java/org/apache/celeborn/client/read/CelebornInputStreamSuiteJ.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.client.read; + +import java.util.ArrayList; + +import org.junit.Assert; +import org.junit.Test; + +import org.apache.celeborn.common.protocol.PartitionLocation; + +public class CelebornInputStreamSuiteJ { + + @Test + public void returnsCorrectSubSkewPartitionLocationsForIndex() { + ArrayList locations = createMockLocations(10); + ArrayList subPartition0 = + CelebornInputStream.getSubSkewPartitionLocations(locations, 3, 0); + Assert.assertEquals(3, subPartition0.size()); + Assert.assertEquals("10-1", subPartition0.get(0).getUniqueId()); + Assert.assertEquals("5-1", subPartition0.get(1).getUniqueId()); + Assert.assertEquals("4-1", subPartition0.get(2).getUniqueId()); + + ArrayList subPartition1 = + CelebornInputStream.getSubSkewPartitionLocations(locations, 3, 1); + Assert.assertEquals(3, subPartition1.size()); + Assert.assertEquals("9-1", subPartition1.get(0).getUniqueId()); + Assert.assertEquals("6-1", subPartition1.get(1).getUniqueId()); + Assert.assertEquals("3-1", subPartition1.get(2).getUniqueId()); + + ArrayList subPartition2 = + CelebornInputStream.getSubSkewPartitionLocations(locations, 3, 2); + Assert.assertEquals(4, subPartition2.size()); + Assert.assertEquals("8-1", subPartition2.get(0).getUniqueId()); + Assert.assertEquals("7-1", subPartition2.get(1).getUniqueId()); + Assert.assertEquals("2-1", subPartition2.get(2).getUniqueId()); + Assert.assertEquals("1-1", subPartition2.get(3).getUniqueId()); + } + + @Test + public void returnsEmptyListForEmptyLocations() { + ArrayList locations = new ArrayList<>(); + ArrayList result = + CelebornInputStream.getSubSkewPartitionLocations(locations, 3, 0); + Assert.assertTrue(result.isEmpty()); + } + + private ArrayList createMockLocations(int size) { + ArrayList locations = new ArrayList<>(); + for (int i = 1; i <= size; i++) { + PartitionLocation location = + new PartitionLocation(i, 1, "mock", -1, -1, -1, -1, PartitionLocation.Mode.PRIMARY); + location.getStorageInfo().setFileSize(size - i); + locations.add(location); + } + return locations; + } +} diff --git a/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala b/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala index ad5b722b92a..afc91028653 100644 --- a/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala +++ b/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala @@ -41,7 +41,7 @@ trait WithShuffleClientSuite extends CelebornFunSuite { private val attemptId = 0 private var lifecycleManager: LifecycleManager = _ - private var shuffleClient: ShuffleClientImpl = _ + protected var shuffleClient: ShuffleClientImpl = _ var _shuffleId = 0 def nextShuffleId: Int = { diff --git a/common/src/test/java/org/apache/celeborn/common/write/PushFailedBatchSuiteJ.java b/common/src/test/java/org/apache/celeborn/common/write/PushFailedBatchSuiteJ.java new file mode 100644 index 00000000000..fcfc6b79979 --- /dev/null +++ b/common/src/test/java/org/apache/celeborn/common/write/PushFailedBatchSuiteJ.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.common.write; + +import java.util.HashSet; +import java.util.Set; + +import org.junit.Assert; +import org.junit.Test; + +public class PushFailedBatchSuiteJ { + + @Test + public void equalsReturnsTrueForIdenticalBatches() { + PushFailedBatch batch1 = new PushFailedBatch(1, 2, 3); + PushFailedBatch batch2 = new PushFailedBatch(1, 2, 3); + Assert.assertEquals(batch1, batch2); + } + + @Test + public void equalsReturnsFalseForDifferentBatches() { + PushFailedBatch batch1 = new PushFailedBatch(1, 2, 3); + PushFailedBatch batch2 = new PushFailedBatch(4, 5, 6); + Assert.assertNotEquals(batch1, batch2); + } + + @Test + public void hashCodeDiffersForDifferentBatches() { + PushFailedBatch batch1 = new PushFailedBatch(1, 2, 3); + PushFailedBatch batch2 = new PushFailedBatch(4, 5, 6); + Assert.assertNotEquals(batch1.hashCode(), batch2.hashCode()); + } + + @Test + public void hashCodeSameForIdenticalBatches() { + PushFailedBatch batch1 = new PushFailedBatch(1, 2, 3); + PushFailedBatch batch2 = new PushFailedBatch(1, 2, 3); + Assert.assertEquals(batch1.hashCode(), batch2.hashCode()); + } + + @Test + public void hashCodeIsConsistent() { + PushFailedBatch batch = new PushFailedBatch(1, 2, 3); + int hashCode1 = batch.hashCode(); + int hashCode2 = batch.hashCode(); + Assert.assertEquals(hashCode1, hashCode2); + } + + @Test + public void toStringReturnsExpectedFormat() { + PushFailedBatch batch = new PushFailedBatch(1, 2, 3); + String expected = "PushFailedBatch[mapId=1,attemptId=2,batchId=3]"; + Assert.assertEquals(expected, batch.toString()); + } + + @Test + public void hashCodeAndEqualsWorkInSet() { + Set set = new HashSet<>(); + PushFailedBatch batch1 = new PushFailedBatch(1, 2, 3); + PushFailedBatch batch2 = new PushFailedBatch(1, 2, 3); + set.add(batch1); + Assert.assertTrue(set.contains(batch2)); + } +} diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerCommitFilesSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerCommitFilesSuite.scala index b9b084b7b46..a772316f318 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerCommitFilesSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerCommitFilesSuite.scala @@ -17,15 +17,18 @@ package org.apache.celeborn.tests.client +import java.nio.charset.StandardCharsets import java.util import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer -import org.apache.celeborn.client.{LifecycleManager, WithShuffleClientSuite} +import org.apache.celeborn.client.{LifecycleManager, ShuffleClientImpl, WithShuffleClientSuite} import org.apache.celeborn.client.LifecycleManager.ShuffleFailedWorkers import org.apache.celeborn.client.commit.CommitFilesParam import org.apache.celeborn.common.CelebornConf +import org.apache.celeborn.common.protocol.CompressionCodec +import org.apache.celeborn.common.protocol.message.ControlMessages.WorkerResource import org.apache.celeborn.common.protocol.message.StatusCode import org.apache.celeborn.common.util.Utils import org.apache.celeborn.service.deploy.MiniClusterFeature @@ -228,6 +231,70 @@ class LifecycleManagerCommitFilesSuite extends WithShuffleClientSuite with MiniC lifecycleManager.stop() } + test("CELEBORN-1319: test commit files and check commit info") { + val shuffleId = nextShuffleId + val conf = celebornConf.clone + conf.set(CelebornConf.TEST_MOCK_COMMIT_FILES_FAILURE.key, "false") + val lifecycleManager: LifecycleManager = new LifecycleManager(APP, conf) + val shuffleClient = new ShuffleClientImpl(APP, conf, userIdentifier) + shuffleClient.setupLifecycleManagerRef(lifecycleManager.self) + + val ids = new util.ArrayList[Integer](3) + 0 until 3 foreach { + ids.add(_) + } + val res = lifecycleManager.requestMasterRequestSlotsWithRetry(shuffleId, ids) + assert(res.status == StatusCode.SUCCESS) + assert(res.workerResource.keySet().size() == 3) + + lifecycleManager.setupEndpoints( + res.workerResource.keySet, + shuffleId, + new ShuffleFailedWorkers()) + + lifecycleManager.reserveSlotsWithRetry( + shuffleId, + new util.HashSet(res.workerResource.keySet()), + res.workerResource, + updateEpoch = false) + + lifecycleManager.commitManager.registerShuffle(shuffleId, 1, false) + + val buffer = "hello world".getBytes(StandardCharsets.UTF_8) + + var bufferLength = -1 + + 0 until 3 foreach { partitionId => + bufferLength = + shuffleClient.pushData(shuffleId, 0, 0, partitionId, buffer, 0, buffer.length, 1, 3) + lifecycleManager.commitManager.finishMapperAttempt(shuffleId, 0, 0, 1, partitionId) + } + + val commitHandler = lifecycleManager.commitManager.getCommitHandler(shuffleId) + val params = new ArrayBuffer[CommitFilesParam](res.workerResource.size()) + res.workerResource.asScala.foreach { case (workerInfo, (primaryIds, replicaIds)) => + params += CommitFilesParam( + workerInfo, + primaryIds.asScala.map(_.getUniqueId).toList.asJava, + replicaIds.asScala.map(_.getUniqueId).toList.asJava) + } + + val shuffleCommittedInfo = lifecycleManager.commitManager.committedPartitionInfo.get(shuffleId) + commitHandler.doParallelCommitFiles( + shuffleId, + shuffleCommittedInfo, + params, + new ShuffleFailedWorkers) + + shuffleCommittedInfo.committedReplicaStorageInfos.values().asScala.foreach { storageInfo => + assert(storageInfo.fileSize == bufferLength) + // chunkOffsets contains 0 by default, and bufferFlushOffset + assert(storageInfo.chunkOffsets.size() == 2) + } + + lifecycleManager.stop() + } + override def afterAll(): Unit = { logInfo("all test complete , stop celeborn mini cluster") shutdownMiniCluster() diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala index 0d0d7063e83..350870bc777 100644 --- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala +++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala @@ -237,8 +237,7 @@ class FetchHandler( fileName: String, startIndex: Int, endIndex: Int, - readLocalShuffle: Boolean = false, - shuffleNeedSort: Boolean = true): PbStreamHandlerOpt = { + readLocalShuffle: Boolean = false): PbStreamHandlerOpt = { try { logDebug(s"Received open stream request $shuffleKey $fileName $startIndex " + s"$endIndex get file name $fileName from client channel " + @@ -250,9 +249,8 @@ class FetchHandler( // 1. when the current request is a non-range openStream, but the original unsorted file // has been deleted by another range's openStream request. // 2. when the current request is a range openStream request. - if (shuffleNeedSort && ((endIndex != Int.MaxValue) || (endIndex == Int.MaxValue && !fileInfo - .addStream( - streamId)))) { + if ((endIndex != Int.MaxValue) || (endIndex == Int.MaxValue && !fileInfo.addStream( + streamId))) { fileInfo = partitionsSorter.getSortedFileInfo( shuffleKey, fileName, From 52fdb96596471d156cf1c9b111db48368360d0dd Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Sat, 6 Apr 2024 01:30:24 +0800 Subject: [PATCH 10/44] split the skewed partition based on the chunk range --- .../celeborn/CelebornShuffleReader.scala | 2 +- .../client/read/CelebornInputStream.java | 184 +++++++++++++----- .../client/read/WorkerPartitionReader.java | 21 +- .../celeborn/common/protocol/StorageInfo.java | 37 +++- common/src/main/proto/TransportMessages.proto | 2 + .../LifecycleManagerCommitFilesSuite.scala | 2 +- .../service/deploy/worker/Controller.scala | 2 +- .../service/deploy/worker/FetchHandler.scala | 2 +- 8 files changed, 188 insertions(+), 64 deletions(-) diff --git a/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala index 66b55ed4b0f..f66b5d34f2f 100644 --- a/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala +++ b/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala @@ -201,7 +201,7 @@ class CelebornShuffleReader[K, C]( new util.ArrayList(fileGroups.partitionGroups.get(partitionId)) } else new util.ArrayList[PartitionLocation]() val streamHandlers = - if (locations != null) { + if (locations != null && !conf.clientPushFailureTrackingEnabled) { val streamHandlerArr = new util.ArrayList[PbStreamHandler](locations.size()) locations.asScala.foreach { loc => streamHandlerArr.add(locationStreamHandlerMap.get(loc)) diff --git a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java index 5ca3030f795..34ce0f47e78 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java +++ b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java @@ -30,6 +30,7 @@ import com.google.common.util.concurrent.Uninterruptibles; import io.netty.buffer.ByteBuf; import net.jpountz.lz4.LZ4Exception; +import org.apache.commons.lang3.tuple.Pair; import org.roaringbitmap.RoaringBitmap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -73,59 +74,87 @@ public static CelebornInputStream create( } else { // if startMapIndex > endMapIndex, means partition is skew partition. // locations will split to sub-partitions with startMapIndex size. - ArrayList filterLocations = locations; boolean splitSkewPartitionWithoutMapRange = conf.clientPushFailureTrackingEnabled() && startMapIndex > endMapIndex; if (splitSkewPartitionWithoutMapRange) { - filterLocations = getSubSkewPartitionLocations(locations, startMapIndex, endMapIndex); - logger.debug("Current sub-partition locations: {}", filterLocations); - endMapIndex = Integer.MAX_VALUE; + Map> partitionLocationToChunkRange = + splitSkewedPartitionLocations(locations, startMapIndex, endMapIndex); + logger.debug("Current sub-partition locations: {}", locations); + return new CelebornInputStreamImpl( + conf, + clientFactory, + shuffleKey, + locations, + streamHandlers, + attempts, + failedBatchSetMap, + attemptNumber, + taskId, + partitionLocationToChunkRange, + fetchExcludedWorkers, + shuffleClient, + appShuffleId, + shuffleId, + partitionId, + exceptionMaker, + splitSkewPartitionWithoutMapRange, + metricsCallback); + } else { + return new CelebornInputStreamImpl( + conf, + clientFactory, + shuffleKey, + locations, + streamHandlers, + attempts, + failedBatchSetMap, + attemptNumber, + taskId, + startMapIndex, + endMapIndex, + /*partitionLocationToChunkRange = */ null, + fetchExcludedWorkers, + shuffleClient, + appShuffleId, + shuffleId, + partitionId, + exceptionMaker, + splitSkewPartitionWithoutMapRange, + metricsCallback); } - return new CelebornInputStreamImpl( - conf, - clientFactory, - shuffleKey, - filterLocations, - streamHandlers, - attempts, - failedBatchSetMap, - attemptNumber, - taskId, - startMapIndex, - endMapIndex, - fetchExcludedWorkers, - shuffleClient, - appShuffleId, - shuffleId, - partitionId, - exceptionMaker, - splitSkewPartitionWithoutMapRange, - metricsCallback); } } - public static ArrayList getSubSkewPartitionLocations( + public static Map> splitSkewedPartitionLocations( ArrayList locations, int subPartitionSize, int subPartitionIndex) { - locations.sort( - Comparator.comparingLong((PartitionLocation o) -> o.getStorageInfo().fileSize) - .thenComparing(PartitionLocation::getUniqueId)); - int step = locations.size() / subPartitionSize; - ArrayList result = new ArrayList<>(step + 1); - - // if partition location is [1,2,3,4,5,6,7,8,9,10], and skew partition split to 3 task: - // task 0: 1, 6, 7 - // task 1: 2, 5, 8 - // task 2: 3, 4, 9, 10 - for (int i = 0; i < step + 1; i++) { - if (i % 2 == 0 && (i * subPartitionSize + subPartitionIndex) < locations.size()) { - result.add(locations.get(i * subPartitionSize + subPartitionIndex)); - } else if (i % 2 == 1 - && ((i + 1) * subPartitionSize - subPartitionIndex - 1) < locations.size()) { - result.add(locations.get((i + 1) * subPartitionSize - subPartitionIndex - 1)); + locations.sort(Comparator.comparing((PartitionLocation p) -> p.getUniqueId())); + + long totalPartitionSize = + locations.stream().mapToLong((PartitionLocation p) -> p.getStorageInfo().fileSize).sum(); + long step = totalPartitionSize / subPartitionSize; + long startOffset = step * subPartitionIndex; + long endOffset = step * (subPartitionIndex + 1); + long partitionLocationOffset = 0; + Map> chunkRange = new HashMap<>(); + for (int i = 0; i < locations.size(); i++) { + PartitionLocation p = locations.get(i); + int left = -1; + int right = -1; + for (int j = 0; j < p.getStorageInfo().getChunkOffsets().size(); j++) { + long currentOffset = partitionLocationOffset + p.getStorageInfo().getChunkOffsets().get(j); + if (currentOffset >= startOffset && left < 0) { + left = j; + } + if (currentOffset < endOffset) { + right = j; + } + if (left >= 0 && right >= 0) { + chunkRange.put(p.getUniqueId(), Pair.of(left, right)); + } } + partitionLocationOffset += p.getStorageInfo().getFileSize(); } - - return result; + return chunkRange; } public static CelebornInputStream empty() { @@ -172,6 +201,7 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream { private final long taskId; private final int startMapIndex; private final int endMapIndex; + private final Map> partitionLocationToChunkRange; private Map> batchesRead = new HashMap<>(); @@ -215,6 +245,49 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream { private final boolean splitSkewPartitionWithoutMapRange; + CelebornInputStreamImpl( + CelebornConf conf, + TransportClientFactory clientFactory, + String shuffleKey, + ArrayList locations, + ArrayList streamHandlers, + int[] attempts, + Map> failedBatchSet, + int attemptNumber, + long taskId, + Map> partitionLocationToChunkRange, + ConcurrentHashMap fetchExcludedWorkers, + ShuffleClient shuffleClient, + int appShuffleId, + int shuffleId, + int partitionId, + ExceptionMaker exceptionMaker, + boolean splitSkewPartitionWithoutMapRange, + MetricsCallback metricsCallback) + throws IOException { + this( + conf, + clientFactory, + shuffleKey, + locations, + streamHandlers, + attempts, + failedBatchSet, + attemptNumber, + taskId, + /*startMapIndex = */ -1, + /*endMapIndex = */ -1, + partitionLocationToChunkRange, + fetchExcludedWorkers, + shuffleClient, + appShuffleId, + shuffleId, + partitionId, + exceptionMaker, + splitSkewPartitionWithoutMapRange, + metricsCallback); + } + CelebornInputStreamImpl( CelebornConf conf, TransportClientFactory clientFactory, @@ -227,6 +300,7 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream { long taskId, int startMapIndex, int endMapIndex, + Map> partitionLocationToChunkRange, ConcurrentHashMap fetchExcludedWorkers, ShuffleClient shuffleClient, int appShuffleId, @@ -248,6 +322,7 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream { this.taskId = taskId; this.startMapIndex = startMapIndex; this.endMapIndex = endMapIndex; + this.partitionLocationToChunkRange = partitionLocationToChunkRange; this.rangeReadFilter = conf.shuffleRangeReadFilterEnabled(); this.enabledReadLocalShuffle = conf.enableReadLocalShuffleFile(); this.localHostAddress = Utils.localHostName(conf); @@ -305,8 +380,10 @@ private Tuple2 nextReadableLocation() { } PartitionLocation currentLocation = locations.get(fileIndex); // if pushShuffleFailureTrackingEnabled is true, should not skip location - while (!splitSkewPartitionWithoutMapRange - && skipLocation(startMapIndex, endMapIndex, currentLocation)) { + while ((splitSkewPartitionWithoutMapRange + && !partitionLocationToChunkRange.containsKey(currentLocation.getUniqueId())) + || (!splitSkewPartitionWithoutMapRange + && skipLocation(startMapIndex, endMapIndex, currentLocation))) { skipCount.increment(); fileIndex++; if (fileIndex == locationCount) { @@ -473,6 +550,15 @@ private PartitionReader createReader( logger.debug("Create reader for location {}", location); StorageInfo storageInfo = location.getStorageInfo(); + + int startChunkIndex = -1; + int endChunkIndex = -1; + if (partitionLocationToChunkRange != null) { + Pair chunkRange = + partitionLocationToChunkRange.get(location.getUniqueId()); + startChunkIndex = chunkRange.getLeft(); + endChunkIndex = chunkRange.getRight(); + } switch (storageInfo.getType()) { case HDD: case SSD: @@ -495,7 +581,9 @@ private PartitionReader createReader( endMapIndex, fetchChunkRetryCnt, fetchChunkMaxRetry, - callback); + callback, + startChunkIndex, + endChunkIndex); } case S3: case HDFS: @@ -668,7 +756,9 @@ private boolean fillBuffer() throws IOException { failedBatch.setMapId(mapId); failedBatch.setAttemptId(attemptId); failedBatch.setBatchId(batchId); - if (failedBatchSet.contains(failedBatch)) { + String uid = currentReader.getLocation().getUniqueId(); + if (failedBatches.containsKey(uid) + && failedBatches.get(uid).contains(failedBatch)) { logger.warn("Skip duplicated batch: {}.", failedBatch); continue; } diff --git a/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java b/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java index 3158aa12f72..29236273663 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java +++ b/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java @@ -55,6 +55,8 @@ public class WorkerPartitionReader implements PartitionReader { private int returnedChunks; private int chunkIndex; + private int startChunkIndex; + private int endChunkIndex; private final LinkedBlockingQueue results; private final ChunkReceivedCallback callback; @@ -80,7 +82,9 @@ public class WorkerPartitionReader implements PartitionReader { int endMapIndex, int fetchChunkRetryCnt, int fetchChunkMaxRetry, - MetricsCallback metricsCallback) + MetricsCallback metricsCallback, + int startChunkIndex, + int endChunkIndex) throws IOException, InterruptedException { this.shuffleKey = shuffleKey; fetchMaxReqsInFlight = conf.clientFetchMaxReqsInFlight(); @@ -133,7 +137,12 @@ public void onFailure(int chunkIndex, Throwable e) { } else { this.streamHandler = pbStreamHandler; } - + this.startChunkIndex = startChunkIndex == -1 ? 0 : startChunkIndex; + this.endChunkIndex = + endChunkIndex == -1 + ? streamHandler.getNumChunks() - 1 + : Math.min(streamHandler.getNumChunks() - 1, endChunkIndex); + this.chunkIndex = this.startChunkIndex; this.location = location; this.clientFactory = clientFactory; this.fetchChunkRetryCnt = fetchChunkRetryCnt; @@ -144,13 +153,13 @@ public void onFailure(int chunkIndex, Throwable e) { @Override public boolean hasNext() { - return returnedChunks < streamHandler.getNumChunks(); + return returnedChunks < endChunkIndex - startChunkIndex + 1; } @Override public ByteBuf next() throws IOException, InterruptedException { checkException(); - if (chunkIndex < streamHandler.getNumChunks()) { + if (chunkIndex <= endChunkIndex) { fetchChunks(); } ByteBuf chunk = null; @@ -202,10 +211,10 @@ public PartitionLocation getLocation() { } private void fetchChunks() throws IOException, InterruptedException { - final int inFlight = chunkIndex - returnedChunks; + final int inFlight = chunkIndex - startChunkIndex - returnedChunks; if (inFlight < fetchMaxReqsInFlight) { final int toFetch = - Math.min(fetchMaxReqsInFlight - inFlight + 1, streamHandler.getNumChunks() - chunkIndex); + Math.min(fetchMaxReqsInFlight - inFlight + 1, endChunkIndex + 1 - chunkIndex); for (int i = 0; i < toFetch; i++) { if (testFetch && fetchChunkRetryCnt < fetchChunkMaxRetry - 1 && chunkIndex == 3) { callback.onFailure(chunkIndex, new CelebornIOException("Test fetch chunk failure")); diff --git a/common/src/main/java/org/apache/celeborn/common/protocol/StorageInfo.java b/common/src/main/java/org/apache/celeborn/common/protocol/StorageInfo.java index a96d332f6be..038219b3389 100644 --- a/common/src/main/java/org/apache/celeborn/common/protocol/StorageInfo.java +++ b/common/src/main/java/org/apache/celeborn/common/protocol/StorageInfo.java @@ -70,7 +70,7 @@ public int getValue() { public long fileSize; - public int numChunks; + public List chunkOffsets; public StorageInfo() {} @@ -99,6 +99,23 @@ public StorageInfo( this.availableStorageTypes = availableStorageTypes; } + public StorageInfo( + Type type, + String mountPoint, + boolean finalResult, + String filePath, + int availableStorageTypes, + long fileSize, + List chunkOffsets) { + this.type = type; + this.mountPoint = mountPoint; + this.finalResult = finalResult; + this.filePath = filePath; + this.availableStorageTypes = availableStorageTypes; + this.fileSize = fileSize; + this.chunkOffsets = chunkOffsets; + } + public boolean isFinalResult() { return finalResult; } @@ -123,12 +140,12 @@ public String getFilePath() { return filePath; } - public void setNumChunks(int numChunks) { - this.numChunks = numChunks; + public void setChunkOffsets(List chunkOffsets) { + this.chunkOffsets = chunkOffsets; } - public int getNumChunks() { - return this.numChunks; + public List getChunkOffsets() { + return this.chunkOffsets; } public void setFileSize(long fileSize) { @@ -235,7 +252,11 @@ public static PbStorageInfo toPb(StorageInfo storageInfo) { .setType(storageInfo.type.value) .setFinalResult(storageInfo.finalResult) .setMountPoint(storageInfo.mountPoint) - .setAvailableStorageTypes(storageInfo.availableStorageTypes); + .setAvailableStorageTypes(storageInfo.availableStorageTypes) + .setFileSize(storageInfo.getFileSize()); + if (storageInfo.getChunkOffsets() != null) { + builder.addAllChunkOffsets(storageInfo.getChunkOffsets()); + } if (filePath != null) { builder.setFilePath(filePath); } @@ -248,7 +269,9 @@ public static StorageInfo fromPb(PbStorageInfo pbStorageInfo) { pbStorageInfo.getMountPoint(), pbStorageInfo.getFinalResult(), pbStorageInfo.getFilePath(), - pbStorageInfo.getAvailableStorageTypes()); + pbStorageInfo.getAvailableStorageTypes(), + pbStorageInfo.getFileSize(), + pbStorageInfo.getChunkOffsetsList()); } public static int getAvailableTypes(List types) { diff --git a/common/src/main/proto/TransportMessages.proto b/common/src/main/proto/TransportMessages.proto index 0514e5e1538..f8dd0824223 100644 --- a/common/src/main/proto/TransportMessages.proto +++ b/common/src/main/proto/TransportMessages.proto @@ -134,6 +134,8 @@ message PbStorageInfo { bool finalResult = 3; string filePath = 4; int32 availableStorageTypes = 5; + int64 fileSize = 6; + repeated int64 chunkOffsets = 7; } message PbPartitionLocation { diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerCommitFilesSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerCommitFilesSuite.scala index a772316f318..ead015832fa 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerCommitFilesSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerCommitFilesSuite.scala @@ -266,7 +266,7 @@ class LifecycleManagerCommitFilesSuite extends WithShuffleClientSuite with MiniC 0 until 3 foreach { partitionId => bufferLength = - shuffleClient.pushData(shuffleId, 0, 0, partitionId, buffer, 0, buffer.length, 1, 3) + shuffleClient.pushData(shuffleId, 0, 0, partitionId, buffer, 0, buffer.length, 1, 1) lifecycleManager.commitManager.finishMapperAttempt(shuffleId, 0, 0, 1, partitionId) } diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala index 11ef735b6f3..53b73da91d0 100644 --- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala +++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala @@ -341,8 +341,8 @@ private[deploy] class Controller( val fileMeta = fileWriter.getDiskFileInfo.getFileMeta fileMeta match { case meta: ReduceFileMeta => - storageInfo.setNumChunks(meta.getNumChunks) storageInfo.setFileSize(bytes) + storageInfo.setChunkOffsets(meta.getChunkOffsets) case _ => } committedStorageInfos.put(uniqueId, storageInfo) diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala index 350870bc777..3c99639eb84 100644 --- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala +++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala @@ -249,7 +249,7 @@ class FetchHandler( // 1. when the current request is a non-range openStream, but the original unsorted file // has been deleted by another range's openStream request. // 2. when the current request is a range openStream request. - if ((endIndex != Int.MaxValue) || (endIndex == Int.MaxValue && !fileInfo.addStream( + if ((endIndex != Int.MaxValue && endIndex != -1 && endIndex >= startIndex) || (endIndex == Int.MaxValue && !fileInfo.addStream( streamId))) { fileInfo = partitionsSorter.getSortedFileInfo( shuffleKey, From deb29135dad38e6306a2866b1c2d9912954dd14f Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Sun, 7 Apr 2024 11:37:12 +0800 Subject: [PATCH 11/44] remove CelebornInputStreamSuiteJ --- .../read/CelebornInputStreamSuiteJ.java | 73 ------------------- 1 file changed, 73 deletions(-) delete mode 100644 client/src/test/java/org/apache/celeborn/client/read/CelebornInputStreamSuiteJ.java diff --git a/client/src/test/java/org/apache/celeborn/client/read/CelebornInputStreamSuiteJ.java b/client/src/test/java/org/apache/celeborn/client/read/CelebornInputStreamSuiteJ.java deleted file mode 100644 index b88a3e0712d..00000000000 --- a/client/src/test/java/org/apache/celeborn/client/read/CelebornInputStreamSuiteJ.java +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.celeborn.client.read; - -import java.util.ArrayList; - -import org.junit.Assert; -import org.junit.Test; - -import org.apache.celeborn.common.protocol.PartitionLocation; - -public class CelebornInputStreamSuiteJ { - - @Test - public void returnsCorrectSubSkewPartitionLocationsForIndex() { - ArrayList locations = createMockLocations(10); - ArrayList subPartition0 = - CelebornInputStream.getSubSkewPartitionLocations(locations, 3, 0); - Assert.assertEquals(3, subPartition0.size()); - Assert.assertEquals("10-1", subPartition0.get(0).getUniqueId()); - Assert.assertEquals("5-1", subPartition0.get(1).getUniqueId()); - Assert.assertEquals("4-1", subPartition0.get(2).getUniqueId()); - - ArrayList subPartition1 = - CelebornInputStream.getSubSkewPartitionLocations(locations, 3, 1); - Assert.assertEquals(3, subPartition1.size()); - Assert.assertEquals("9-1", subPartition1.get(0).getUniqueId()); - Assert.assertEquals("6-1", subPartition1.get(1).getUniqueId()); - Assert.assertEquals("3-1", subPartition1.get(2).getUniqueId()); - - ArrayList subPartition2 = - CelebornInputStream.getSubSkewPartitionLocations(locations, 3, 2); - Assert.assertEquals(4, subPartition2.size()); - Assert.assertEquals("8-1", subPartition2.get(0).getUniqueId()); - Assert.assertEquals("7-1", subPartition2.get(1).getUniqueId()); - Assert.assertEquals("2-1", subPartition2.get(2).getUniqueId()); - Assert.assertEquals("1-1", subPartition2.get(3).getUniqueId()); - } - - @Test - public void returnsEmptyListForEmptyLocations() { - ArrayList locations = new ArrayList<>(); - ArrayList result = - CelebornInputStream.getSubSkewPartitionLocations(locations, 3, 0); - Assert.assertTrue(result.isEmpty()); - } - - private ArrayList createMockLocations(int size) { - ArrayList locations = new ArrayList<>(); - for (int i = 1; i <= size; i++) { - PartitionLocation location = - new PartitionLocation(i, 1, "mock", -1, -1, -1, -1, PartitionLocation.Mode.PRIMARY); - location.getStorageInfo().setFileSize(size - i); - locations.add(location); - } - return locations; - } -} From f7abfc26286ac25d36af131048824d11b4abad08 Mon Sep 17 00:00:00 2001 From: wangshengjie Date: Mon, 8 Apr 2024 23:26:10 +0800 Subject: [PATCH 12/44] refactor code for batch open stream --- .../celeborn/CelebornShuffleReader.scala | 77 ++++++++++++++++--- .../apache/celeborn/client/ShuffleClient.java | 3 + .../celeborn/client/ShuffleClientImpl.java | 2 + .../client/read/CelebornInputStream.java | 44 +---------- .../celeborn/client/DummyShuffleClient.java | 2 + .../client/WithShuffleClientSuite.scala | 2 + .../deploy/cluster/ReadWriteTestBase.scala | 1 + 7 files changed, 82 insertions(+), 49 deletions(-) diff --git a/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala index f66b5d34f2f..9a53b8e84f4 100644 --- a/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala +++ b/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala @@ -23,7 +23,9 @@ import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor, TimeUnit} import java.util.concurrent.atomic.AtomicReference import scala.collection.JavaConverters._ +import scala.collection.mutable +import org.apache.commons.lang3.tuple.Pair import org.apache.spark.{Aggregator, InterruptibleIterator, ShuffleDependency, TaskContext} import org.apache.spark.celeborn.ExceptionMakerHelper import org.apache.spark.internal.Logging @@ -69,6 +71,12 @@ class CelebornShuffleReader[K, C]( private val throwsFetchFailure = handle.throwsFetchFailure private val encodedAttemptId = SparkCommonUtils.getEncodedAttemptNumber(context) + private val comparator = new util.Comparator[PartitionLocation] { + override def compare(o1: PartitionLocation, o2: PartitionLocation): Int = { + o1.getUniqueId.compareTo(o2.getUniqueId) + } + } + override def read(): Iterator[Product2[K, C]] = { val serializerInstance = newSerializerInstance(dep) @@ -119,12 +127,37 @@ class CelebornShuffleReader[K, C]( val workerRequestMap = new util.HashMap[ String, (TransportClient, util.ArrayList[PartitionLocation], PbOpenStreamList.Builder)]() + // partitionId -> (partition uniqueId -> chunkRange pair) + val partitionId2ChunkRange = + new util.HashMap[Int, util.Map[String, Pair[Integer, Integer]]]() + + val partitionId2PartitionLocations = new util.HashMap[Int, mutable.Set[PartitionLocation]]() var partCnt = 0 + // if startMapIndex > endMapIndex, means partition is skew partition. + // locations will split to sub-partitions with startMapIndex size. + val splitSkewPartitionWithoutMapRange = conf.clientPushFailureTrackingEnabled && + startMapIndex > endMapIndex + (startPartition until endPartition).foreach { partitionId => if (fileGroups.partitionGroups.containsKey(partitionId)) { - fileGroups.partitionGroups.get(partitionId).asScala.foreach { location => + val locations = fileGroups.partitionGroups.get(partitionId) + var partitionLocationToChunkRange: util.Map[String, Pair[Integer, Integer]] = null + if (splitSkewPartitionWithoutMapRange) { + partitionLocationToChunkRange = + splitSkewedPartitionLocations(new util.ArrayList(locations), startMapIndex, endMapIndex) + partitionId2ChunkRange.put(partitionId, partitionLocationToChunkRange) + } + // filter locations avoid OPEN_STREAM when split skew partition without map range + val filterLocations = locations.asScala + .filter { location => + null != partitionLocationToChunkRange && + partitionLocationToChunkRange.containsKey(location.getUniqueId) + } + partitionId2PartitionLocations.put(partitionId, filterLocations) + + filterLocations.foreach { location => partCnt += 1 val hostPort = location.hostAndFetchPort if (!workerRequestMap.containsKey(hostPort)) { @@ -196,14 +229,12 @@ class CelebornShuffleReader[K, C]( val streams = JavaUtils.newConcurrentHashMap[Integer, CelebornInputStream]() def createInputStream(partitionId: Int): Unit = { - val locations = - if (fileGroups.partitionGroups.containsKey(partitionId)) { - new util.ArrayList(fileGroups.partitionGroups.get(partitionId)) - } else new util.ArrayList[PartitionLocation]() + val locations = partitionId2PartitionLocations.get(partitionId) + val streamHandlers = - if (locations != null && !conf.clientPushFailureTrackingEnabled) { - val streamHandlerArr = new util.ArrayList[PbStreamHandler](locations.size()) - locations.asScala.foreach { loc => + if (locations != null) { + val streamHandlerArr = new util.ArrayList[PbStreamHandler](locations.size) + locations.foreach { loc => streamHandlerArr.add(locationStreamHandlerMap.get(loc)) } streamHandlerArr @@ -220,9 +251,10 @@ class CelebornShuffleReader[K, C]( endMapIndex, if (throwsFetchFailure) ExceptionMakerHelper.SHUFFLE_FETCH_FAILURE_EXCEPTION_MAKER else null, - locations, + new util.ArrayList[PartitionLocation](locations.asJava), streamHandlers, fileGroups.pushFailedBatches, + partitionId2ChunkRange.get(partitionId), fileGroups.mapAttempts, metricsCallback) streams.put(partitionId, inputStream) @@ -395,6 +427,33 @@ class CelebornShuffleReader[K, C]( dep.serializer.newInstance() } + private def splitSkewedPartitionLocations( + locations: util.ArrayList[PartitionLocation], + subPartitionSize: Int, + subPartitionIndex: Int): util.Map[String, Pair[Integer, Integer]] = { + locations.sort(comparator) + val totalPartitionSize = + locations.stream.mapToLong((p: PartitionLocation) => p.getStorageInfo.fileSize).sum + val step = totalPartitionSize / subPartitionSize + val startOffset = step * subPartitionIndex + val endOffset = step * (subPartitionIndex + 1) + var partitionLocationOffset: Long = 0 + val chunkRange = new util.HashMap[String, Pair[Integer, Integer]] + for (i <- 0 until locations.size) { + val p = locations.get(i) + var left = -1 + var right = -1 + for (j <- 0 until p.getStorageInfo.getChunkOffsets.size) { + val currentOffset = partitionLocationOffset + p.getStorageInfo.getChunkOffsets.get(j) + if (currentOffset >= startOffset && left < 0) left = j + if (currentOffset < endOffset) right = j + if (left >= 0 && right >= 0) chunkRange.put(p.getUniqueId, Pair.of(left, right)) + } + partitionLocationOffset = partitionLocationOffset + p.getStorageInfo.getFileSize + } + chunkRange + } + } object CelebornShuffleReader { diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java index 6b3146b864d..8690f5cebde 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java @@ -24,6 +24,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.LongAdder; +import org.apache.commons.lang3.tuple.Pair; import org.apache.hadoop.fs.FileSystem; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -244,6 +245,7 @@ public CelebornInputStream readPartition( null, null, null, + null, metricsCallback); } @@ -259,6 +261,7 @@ public abstract CelebornInputStream readPartition( ArrayList locations, ArrayList streamHandlers, Map> failedBatchSetMap, + Map> chunksRange, int[] mapAttempts, MetricsCallback metricsCallback) throws IOException; diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java index d5104564508..43930d9d776 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java @@ -1894,6 +1894,7 @@ public CelebornInputStream readPartition( ArrayList locations, ArrayList streamHandlers, Map> failedBatchSetMap, + Map> chunksRange, int[] mapAttempts, MetricsCallback metricsCallback) throws IOException { @@ -1927,6 +1928,7 @@ public CelebornInputStream readPartition( streamHandlers, mapAttempts, failedBatchSetMap, + chunksRange, attemptNumber, taskId, startMapIndex, diff --git a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java index 34ce0f47e78..1b27f99c756 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java +++ b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java @@ -57,6 +57,7 @@ public static CelebornInputStream create( ArrayList streamHandlers, int[] attempts, Map> failedBatchSetMap, + Map> chunksRange, int attemptNumber, long taskId, int startMapIndex, @@ -69,7 +70,7 @@ public static CelebornInputStream create( ExceptionMaker exceptionMaker, MetricsCallback metricsCallback) throws IOException { - if (locations == null || locations.size() == 0) { + if (locations == null || locations.isEmpty()) { return emptyInputStream; } else { // if startMapIndex > endMapIndex, means partition is skew partition. @@ -77,9 +78,6 @@ public static CelebornInputStream create( boolean splitSkewPartitionWithoutMapRange = conf.clientPushFailureTrackingEnabled() && startMapIndex > endMapIndex; if (splitSkewPartitionWithoutMapRange) { - Map> partitionLocationToChunkRange = - splitSkewedPartitionLocations(locations, startMapIndex, endMapIndex); - logger.debug("Current sub-partition locations: {}", locations); return new CelebornInputStreamImpl( conf, clientFactory, @@ -90,7 +88,7 @@ public static CelebornInputStream create( failedBatchSetMap, attemptNumber, taskId, - partitionLocationToChunkRange, + chunksRange, fetchExcludedWorkers, shuffleClient, appShuffleId, @@ -125,38 +123,6 @@ public static CelebornInputStream create( } } - public static Map> splitSkewedPartitionLocations( - ArrayList locations, int subPartitionSize, int subPartitionIndex) { - locations.sort(Comparator.comparing((PartitionLocation p) -> p.getUniqueId())); - - long totalPartitionSize = - locations.stream().mapToLong((PartitionLocation p) -> p.getStorageInfo().fileSize).sum(); - long step = totalPartitionSize / subPartitionSize; - long startOffset = step * subPartitionIndex; - long endOffset = step * (subPartitionIndex + 1); - long partitionLocationOffset = 0; - Map> chunkRange = new HashMap<>(); - for (int i = 0; i < locations.size(); i++) { - PartitionLocation p = locations.get(i); - int left = -1; - int right = -1; - for (int j = 0; j < p.getStorageInfo().getChunkOffsets().size(); j++) { - long currentOffset = partitionLocationOffset + p.getStorageInfo().getChunkOffsets().get(j); - if (currentOffset >= startOffset && left < 0) { - left = j; - } - if (currentOffset < endOffset) { - right = j; - } - if (left >= 0 && right >= 0) { - chunkRange.put(p.getUniqueId(), Pair.of(left, right)); - } - } - partitionLocationOffset += p.getStorageInfo().getFileSize(); - } - return chunkRange; - } - public static CelebornInputStream empty() { return emptyInputStream; } @@ -756,9 +722,7 @@ private boolean fillBuffer() throws IOException { failedBatch.setMapId(mapId); failedBatch.setAttemptId(attemptId); failedBatch.setBatchId(batchId); - String uid = currentReader.getLocation().getUniqueId(); - if (failedBatches.containsKey(uid) - && failedBatches.get(uid).contains(failedBatch)) { + if (failedBatchSet.contains(failedBatch)) { logger.warn("Skip duplicated batch: {}.", failedBatch); continue; } diff --git a/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java b/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java index 3cc99fc4872..7ec097392fd 100644 --- a/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java +++ b/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java @@ -31,6 +31,7 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import org.apache.commons.lang3.tuple.Pair; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -139,6 +140,7 @@ public CelebornInputStream readPartition( ArrayList locations, ArrayList streamHandlers, Map> failedBatchSetMap, + Map> chunksRange, int[] mapAttempts, MetricsCallback metricsCallback) throws IOException { diff --git a/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala b/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala index afc91028653..0570760ce14 100644 --- a/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala +++ b/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala @@ -165,6 +165,7 @@ trait WithShuffleClientSuite extends CelebornFunSuite { null, null, null, + null, metricsCallback) Assert.assertEquals(stream.read(), -1) @@ -182,6 +183,7 @@ trait WithShuffleClientSuite extends CelebornFunSuite { null, null, null, + null, metricsCallback) Assert.assertEquals(stream.read(), -1) } diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala index eee7b46fb0f..e81593e1d99 100644 --- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala @@ -117,6 +117,7 @@ trait ReadWriteTestBase extends AnyFunSuite null, null, null, + null, metricsCallback) val outputStream = new ByteArrayOutputStream() From 2e3be7d79d1b996e08ba1f6978b7afb65d9ac15e Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Tue, 9 Apr 2024 01:35:46 +0800 Subject: [PATCH 13/44] Revert "refactor code for batch open stream" This reverts commit 7393d45cc88a9d17eaf59620cfa06538eaa510c6. --- .../celeborn/CelebornShuffleReader.scala | 75 ++----------------- .../apache/celeborn/client/ShuffleClient.java | 3 - .../celeborn/client/ShuffleClientImpl.java | 2 - .../client/read/CelebornInputStream.java | 44 ++++++++++- .../celeborn/client/DummyShuffleClient.java | 2 - .../client/WithShuffleClientSuite.scala | 2 - .../deploy/cluster/ReadWriteTestBase.scala | 1 - 7 files changed, 48 insertions(+), 81 deletions(-) diff --git a/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala index 9a53b8e84f4..00ff8cc5ecb 100644 --- a/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala +++ b/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala @@ -23,9 +23,7 @@ import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor, TimeUnit} import java.util.concurrent.atomic.AtomicReference import scala.collection.JavaConverters._ -import scala.collection.mutable -import org.apache.commons.lang3.tuple.Pair import org.apache.spark.{Aggregator, InterruptibleIterator, ShuffleDependency, TaskContext} import org.apache.spark.celeborn.ExceptionMakerHelper import org.apache.spark.internal.Logging @@ -71,12 +69,6 @@ class CelebornShuffleReader[K, C]( private val throwsFetchFailure = handle.throwsFetchFailure private val encodedAttemptId = SparkCommonUtils.getEncodedAttemptNumber(context) - private val comparator = new util.Comparator[PartitionLocation] { - override def compare(o1: PartitionLocation, o2: PartitionLocation): Int = { - o1.getUniqueId.compareTo(o2.getUniqueId) - } - } - override def read(): Iterator[Product2[K, C]] = { val serializerInstance = newSerializerInstance(dep) @@ -127,37 +119,12 @@ class CelebornShuffleReader[K, C]( val workerRequestMap = new util.HashMap[ String, (TransportClient, util.ArrayList[PartitionLocation], PbOpenStreamList.Builder)]() - // partitionId -> (partition uniqueId -> chunkRange pair) - val partitionId2ChunkRange = - new util.HashMap[Int, util.Map[String, Pair[Integer, Integer]]]() - - val partitionId2PartitionLocations = new util.HashMap[Int, mutable.Set[PartitionLocation]]() var partCnt = 0 - // if startMapIndex > endMapIndex, means partition is skew partition. - // locations will split to sub-partitions with startMapIndex size. - val splitSkewPartitionWithoutMapRange = conf.clientPushFailureTrackingEnabled && - startMapIndex > endMapIndex - (startPartition until endPartition).foreach { partitionId => if (fileGroups.partitionGroups.containsKey(partitionId)) { - val locations = fileGroups.partitionGroups.get(partitionId) - var partitionLocationToChunkRange: util.Map[String, Pair[Integer, Integer]] = null - if (splitSkewPartitionWithoutMapRange) { - partitionLocationToChunkRange = - splitSkewedPartitionLocations(new util.ArrayList(locations), startMapIndex, endMapIndex) - partitionId2ChunkRange.put(partitionId, partitionLocationToChunkRange) - } - // filter locations avoid OPEN_STREAM when split skew partition without map range - val filterLocations = locations.asScala - .filter { location => - null != partitionLocationToChunkRange && - partitionLocationToChunkRange.containsKey(location.getUniqueId) - } - partitionId2PartitionLocations.put(partitionId, filterLocations) - - filterLocations.foreach { location => + fileGroups.partitionGroups.get(partitionId).asScala.foreach { location => partCnt += 1 val hostPort = location.hostAndFetchPort if (!workerRequestMap.containsKey(hostPort)) { @@ -229,12 +196,14 @@ class CelebornShuffleReader[K, C]( val streams = JavaUtils.newConcurrentHashMap[Integer, CelebornInputStream]() def createInputStream(partitionId: Int): Unit = { - val locations = partitionId2PartitionLocations.get(partitionId) - + val locations = + if (fileGroups.partitionGroups.containsKey(partitionId)) { + new util.ArrayList(fileGroups.partitionGroups.get(partitionId)) + } else new util.ArrayList[PartitionLocation]() val streamHandlers = - if (locations != null) { - val streamHandlerArr = new util.ArrayList[PbStreamHandler](locations.size) - locations.foreach { loc => + if (locations != null && !conf.clientPushFailureTrackingEnabled) { + val streamHandlerArr = new util.ArrayList[PbStreamHandler](locations.size()) + locations.asScala.foreach { loc => streamHandlerArr.add(locationStreamHandlerMap.get(loc)) } streamHandlerArr @@ -254,7 +223,6 @@ class CelebornShuffleReader[K, C]( new util.ArrayList[PartitionLocation](locations.asJava), streamHandlers, fileGroups.pushFailedBatches, - partitionId2ChunkRange.get(partitionId), fileGroups.mapAttempts, metricsCallback) streams.put(partitionId, inputStream) @@ -427,33 +395,6 @@ class CelebornShuffleReader[K, C]( dep.serializer.newInstance() } - private def splitSkewedPartitionLocations( - locations: util.ArrayList[PartitionLocation], - subPartitionSize: Int, - subPartitionIndex: Int): util.Map[String, Pair[Integer, Integer]] = { - locations.sort(comparator) - val totalPartitionSize = - locations.stream.mapToLong((p: PartitionLocation) => p.getStorageInfo.fileSize).sum - val step = totalPartitionSize / subPartitionSize - val startOffset = step * subPartitionIndex - val endOffset = step * (subPartitionIndex + 1) - var partitionLocationOffset: Long = 0 - val chunkRange = new util.HashMap[String, Pair[Integer, Integer]] - for (i <- 0 until locations.size) { - val p = locations.get(i) - var left = -1 - var right = -1 - for (j <- 0 until p.getStorageInfo.getChunkOffsets.size) { - val currentOffset = partitionLocationOffset + p.getStorageInfo.getChunkOffsets.get(j) - if (currentOffset >= startOffset && left < 0) left = j - if (currentOffset < endOffset) right = j - if (left >= 0 && right >= 0) chunkRange.put(p.getUniqueId, Pair.of(left, right)) - } - partitionLocationOffset = partitionLocationOffset + p.getStorageInfo.getFileSize - } - chunkRange - } - } object CelebornShuffleReader { diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java index 8690f5cebde..6b3146b864d 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java @@ -24,7 +24,6 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.LongAdder; -import org.apache.commons.lang3.tuple.Pair; import org.apache.hadoop.fs.FileSystem; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -245,7 +244,6 @@ public CelebornInputStream readPartition( null, null, null, - null, metricsCallback); } @@ -261,7 +259,6 @@ public abstract CelebornInputStream readPartition( ArrayList locations, ArrayList streamHandlers, Map> failedBatchSetMap, - Map> chunksRange, int[] mapAttempts, MetricsCallback metricsCallback) throws IOException; diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java index 43930d9d776..d5104564508 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java @@ -1894,7 +1894,6 @@ public CelebornInputStream readPartition( ArrayList locations, ArrayList streamHandlers, Map> failedBatchSetMap, - Map> chunksRange, int[] mapAttempts, MetricsCallback metricsCallback) throws IOException { @@ -1928,7 +1927,6 @@ public CelebornInputStream readPartition( streamHandlers, mapAttempts, failedBatchSetMap, - chunksRange, attemptNumber, taskId, startMapIndex, diff --git a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java index 1b27f99c756..34ce0f47e78 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java +++ b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java @@ -57,7 +57,6 @@ public static CelebornInputStream create( ArrayList streamHandlers, int[] attempts, Map> failedBatchSetMap, - Map> chunksRange, int attemptNumber, long taskId, int startMapIndex, @@ -70,7 +69,7 @@ public static CelebornInputStream create( ExceptionMaker exceptionMaker, MetricsCallback metricsCallback) throws IOException { - if (locations == null || locations.isEmpty()) { + if (locations == null || locations.size() == 0) { return emptyInputStream; } else { // if startMapIndex > endMapIndex, means partition is skew partition. @@ -78,6 +77,9 @@ public static CelebornInputStream create( boolean splitSkewPartitionWithoutMapRange = conf.clientPushFailureTrackingEnabled() && startMapIndex > endMapIndex; if (splitSkewPartitionWithoutMapRange) { + Map> partitionLocationToChunkRange = + splitSkewedPartitionLocations(locations, startMapIndex, endMapIndex); + logger.debug("Current sub-partition locations: {}", locations); return new CelebornInputStreamImpl( conf, clientFactory, @@ -88,7 +90,7 @@ public static CelebornInputStream create( failedBatchSetMap, attemptNumber, taskId, - chunksRange, + partitionLocationToChunkRange, fetchExcludedWorkers, shuffleClient, appShuffleId, @@ -123,6 +125,38 @@ public static CelebornInputStream create( } } + public static Map> splitSkewedPartitionLocations( + ArrayList locations, int subPartitionSize, int subPartitionIndex) { + locations.sort(Comparator.comparing((PartitionLocation p) -> p.getUniqueId())); + + long totalPartitionSize = + locations.stream().mapToLong((PartitionLocation p) -> p.getStorageInfo().fileSize).sum(); + long step = totalPartitionSize / subPartitionSize; + long startOffset = step * subPartitionIndex; + long endOffset = step * (subPartitionIndex + 1); + long partitionLocationOffset = 0; + Map> chunkRange = new HashMap<>(); + for (int i = 0; i < locations.size(); i++) { + PartitionLocation p = locations.get(i); + int left = -1; + int right = -1; + for (int j = 0; j < p.getStorageInfo().getChunkOffsets().size(); j++) { + long currentOffset = partitionLocationOffset + p.getStorageInfo().getChunkOffsets().get(j); + if (currentOffset >= startOffset && left < 0) { + left = j; + } + if (currentOffset < endOffset) { + right = j; + } + if (left >= 0 && right >= 0) { + chunkRange.put(p.getUniqueId(), Pair.of(left, right)); + } + } + partitionLocationOffset += p.getStorageInfo().getFileSize(); + } + return chunkRange; + } + public static CelebornInputStream empty() { return emptyInputStream; } @@ -722,7 +756,9 @@ private boolean fillBuffer() throws IOException { failedBatch.setMapId(mapId); failedBatch.setAttemptId(attemptId); failedBatch.setBatchId(batchId); - if (failedBatchSet.contains(failedBatch)) { + String uid = currentReader.getLocation().getUniqueId(); + if (failedBatches.containsKey(uid) + && failedBatches.get(uid).contains(failedBatch)) { logger.warn("Skip duplicated batch: {}.", failedBatch); continue; } diff --git a/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java b/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java index 7ec097392fd..3cc99fc4872 100644 --- a/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java +++ b/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java @@ -31,7 +31,6 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; -import org.apache.commons.lang3.tuple.Pair; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -140,7 +139,6 @@ public CelebornInputStream readPartition( ArrayList locations, ArrayList streamHandlers, Map> failedBatchSetMap, - Map> chunksRange, int[] mapAttempts, MetricsCallback metricsCallback) throws IOException { diff --git a/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala b/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala index 0570760ce14..afc91028653 100644 --- a/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala +++ b/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala @@ -165,7 +165,6 @@ trait WithShuffleClientSuite extends CelebornFunSuite { null, null, null, - null, metricsCallback) Assert.assertEquals(stream.read(), -1) @@ -183,7 +182,6 @@ trait WithShuffleClientSuite extends CelebornFunSuite { null, null, null, - null, metricsCallback) Assert.assertEquals(stream.read(), -1) } diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala index e81593e1d99..eee7b46fb0f 100644 --- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala @@ -117,7 +117,6 @@ trait ReadWriteTestBase extends AnyFunSuite null, null, null, - null, metricsCallback) val outputStream = new ByteArrayOutputStream() From cf8dcb265f737d0488f00d5c4d8518de515979ce Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Tue, 9 Apr 2024 02:25:52 +0800 Subject: [PATCH 14/44] `celeborn.client.dataPushFailure.tracking.enabled` -> `celeborn.client.adaptive.optimizeSkewedPartitionRead.enabled` --- ...rn-Optimize-Skew-Partitions-spark3_3.patch | 53 ++++++++++--------- .../celeborn/CelebornShuffleReader.scala | 2 +- .../celeborn/client/ShuffleClientImpl.java | 2 +- .../client/read/CelebornInputStream.java | 2 +- .../apache/celeborn/common/CelebornConf.scala | 11 ++-- docs/configuration/client.md | 2 +- 6 files changed, 37 insertions(+), 35 deletions(-) diff --git a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch index 3fa0649d9f1..a2a9a4bd83f 100644 --- a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch +++ b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch @@ -1,43 +1,47 @@ -From 39eeab2426f9676580e4e19c8b079e1967081c7d Mon Sep 17 00:00:00 2001 -From: wangshengjie -Date: Sun, 24 Mar 2024 19:51:05 +0800 -Subject: [PATCH] [SQL] Handle skew partitions with Celeborn - ---- - .../org/apache/spark/sql/internal/SQLConf.scala | 10 ++++++++++ - .../execution/adaptive/ShufflePartitionsUtil.scala | 12 +++++++++++- - 2 files changed, 21 insertions(+), 1 deletion(-) +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala -index af03ad9a4cb..1e55af89160 100644 +index af03ad9a4cb..3b5c7ce4fce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala -@@ -3784,6 +3784,13 @@ object SQLConf { +@@ -3784,6 +3784,12 @@ object SQLConf { .booleanConf .createWithDefault(false) -+ val CELEBORN_CLIENT_DATA_PUSH_FAILURE_TRACKING_ENABLED = -+ buildConf("spark.celeborn.client.dataPushFailure.tracking.enabled") -+ .withAlternative("celeborn.client.dataPushFailure.tracking.enabled") -+ .version("3.1.2-mdh") ++ val CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ = ++ buildConf("spark.celeborn.client.adaptive.optimizeSkewedPartitionRead.enabled") ++ .version("3.3.0") + .booleanConf + .createWithDefault(false) + /** * Holds information about keys that have been deprecated. * -@@ -4549,6 +4556,9 @@ class SQLConf extends Serializable with Logging { +@@ -4549,6 +4555,9 @@ class SQLConf extends Serializable with Logging { def histogramNumericPropagateInputType: Boolean = getConf(SQLConf.HISTOGRAM_NUMERIC_PROPAGATE_INPUT_TYPE) -+ def isCelebornClientPushFailedTrackingEnabled: Boolean = getConf( -+ SQLConf.CELEBORN_CLIENT_DATA_PUSH_FAILURE_TRACKING_ENABLED) ++ def celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled: Boolean = ++ getConf(SQLConf.CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala -index af689db3379..7da6211e509 100644 +index af689db3379..38e54b3ed0a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala @@ -22,6 +22,8 @@ import scala.collection.mutable.ArrayBuffer @@ -54,8 +58,8 @@ index af689db3379..7da6211e509 100644 mapPartitionSizes, targetSize, smallPartitionFactor) if (mapStartIndices.length > 1) { + // If Celeborn is enabled, split skew partitions without shuffle mapper-range reading -+ val splitSkewPartitionWithCeleborn = Utils.isCelebornEnabled(SparkEnv.get.conf) && -+ SQLConf.get.isCelebornClientPushFailedTrackingEnabled ++ val isCelebornClientAdaptiveOptimizeSkewedPartitionReadEnabled = Utils.isCelebornEnabled(SparkEnv.get.conf) && ++ SQLConf.get.celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled + Some(mapStartIndices.indices.map { i => val startMapIndex = mapStartIndices(i) @@ -65,14 +69,11 @@ index af689db3379..7da6211e509 100644 mapIndex += 1 } - PartialReducerPartitionSpec(reducerId, startMapIndex, endMapIndex, dataSize) -+ if (splitSkewPartitionWithCeleborn) { -+ PartialReducerPartitionSpec(reducerId, mapStartIndices.length, i, dataSize) ++ if (isCelebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) { ++ PartialReducerPartitionSpec(reducerId, mapStartIndices.length, i, -1) + } else { + PartialReducerPartitionSpec(reducerId, startMapIndex, endMapIndex, dataSize) + } }) } else { None --- -2.25.1 - diff --git a/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala index 00ff8cc5ecb..1c8b151ce04 100644 --- a/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala +++ b/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala @@ -201,7 +201,7 @@ class CelebornShuffleReader[K, C]( new util.ArrayList(fileGroups.partitionGroups.get(partitionId)) } else new util.ArrayList[PartitionLocation]() val streamHandlers = - if (locations != null && !conf.clientPushFailureTrackingEnabled) { + if (locations != null && !conf.clientAdaptiveOptimizeSkewedPartitionReadEnabled) { val streamHandlerArr = new util.ArrayList[PbStreamHandler](locations.size()) locations.asScala.foreach { loc => streamHandlerArr.add(locationStreamHandlerMap.get(loc)) diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java index d5104564508..e3a479ee777 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java @@ -201,7 +201,7 @@ public ShuffleClientImpl(String appUniqueId, CelebornConf conf, UserIdentifier u pushDataTimeout = conf.pushDataTimeoutMs(); } authEnabled = conf.authEnabledOnClient(); - dataPushFailureTrackingEnabled = conf.clientPushFailureTrackingEnabled(); + dataPushFailureTrackingEnabled = conf.clientAdaptiveOptimizeSkewedPartitionReadEnabled(); // init rpc env rpcEnv = diff --git a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java index 34ce0f47e78..2739d014336 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java +++ b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java @@ -75,7 +75,7 @@ public static CelebornInputStream create( // if startMapIndex > endMapIndex, means partition is skew partition. // locations will split to sub-partitions with startMapIndex size. boolean splitSkewPartitionWithoutMapRange = - conf.clientPushFailureTrackingEnabled() && startMapIndex > endMapIndex; + conf.clientAdaptiveOptimizeSkewedPartitionReadEnabled() && startMapIndex > endMapIndex; if (splitSkewPartitionWithoutMapRange) { Map> partitionLocationToChunkRange = splitSkewedPartitionLocations(locations, startMapIndex, endMapIndex); diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala index 0eb0fe9e657..5cd1a7a1023 100644 --- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala +++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala @@ -1035,7 +1035,8 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se def clientPushSendBufferPoolExpireTimeout: Long = get(CLIENT_PUSH_SENDBUFFERPOOL_EXPIRETIMEOUT) def clientPushSendBufferPoolExpireCheckInterval: Long = get(CLIENT_PUSH_SENDBUFFERPOOL_CHECKEXPIREINTERVAL) - def clientPushFailureTrackingEnabled: Boolean = get(CLIENT_DATA_PUSH_FAILURE_TRACKING_ENABLED) + def clientAdaptiveOptimizeSkewedPartitionReadEnabled: Boolean = + get(CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ_ENABLED) // ////////////////////////////////////////////////////// // Client Shuffle // @@ -5876,12 +5877,12 @@ object CelebornConf extends Logging { .intConf .createWithDefault(10000) - val CLIENT_DATA_PUSH_FAILURE_TRACKING_ENABLED: ConfigEntry[Boolean] = - buildConf("celeborn.client.dataPushFailure.tracking.enabled") + val CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ_ENABLED: ConfigEntry[Boolean] = + buildConf("celeborn.client.adaptive.optimizeSkewedPartitionRead.enabled") .categories("client") .version("0.5.0") - .doc("When client push data to worker failed, client will record the failed batch info. " + - "Feature used to optimize skew join by avoid data sorting") + .doc("If this is true, Celeborn will adaptively split skewed partitions instead of reading them by Spark map " + + "range. Please note that this feature requires the `Celeborn-Optimize-Skew-Partitions-spark3_3.patch`. ") .booleanConf .createWithDefault(false) diff --git a/docs/configuration/client.md b/docs/configuration/client.md index c1c252ae1d0..8c03fc042fe 100644 --- a/docs/configuration/client.md +++ b/docs/configuration/client.md @@ -19,13 +19,13 @@ license: | | Key | Default | isDynamic | Description | Since | Deprecated | | --- | ------- | --------- | ----------- | ----- | ---------- | +| celeborn.client.adaptive.optimizeSkewedPartitionRead.enabled | false | false | If this is true, Celeborn will adaptively split skewed partitions instead of reading them by Spark map range. Please note that this feature requires the `Celeborn-Optimize-Skew-Partitions-spark3_3.patch`. | 0.5.0 | | | celeborn.client.application.heartbeatInterval | 10s | false | Interval for client to send heartbeat message to master. | 0.3.0 | celeborn.application.heartbeatInterval | | celeborn.client.application.unregister.enabled | true | false | When true, Celeborn client will inform celeborn master the application is already shutdown during client exit, this allows the cluster to release resources immediately, resulting in resource savings. | 0.3.2 | | | celeborn.client.application.uuidSuffix.enabled | false | false | Whether to add UUID suffix for application id for unique. When `true`, add UUID suffix for unique application id. Currently, this only applies to Spark and MR. | 0.6.0 | | | celeborn.client.chunk.prefetch.enabled | false | false | Whether to enable chunk prefetch when creating CelebornInputStream. | 0.6.0 | | | celeborn.client.closeIdleConnections | true | false | Whether client will close idle connections. | 0.3.0 | | | celeborn.client.commitFiles.ignoreExcludedWorker | false | false | When true, LifecycleManager will skip workers which are in the excluded list. | 0.3.0 | | -| celeborn.client.dataPushFailure.tracking.enabled | false | false | When client push data to worker failed, client will record the failed batch info. Feature used to optimize skew join by avoid data sorting | 0.5.0 | | | celeborn.client.eagerlyCreateInputStream.threads | 32 | false | Threads count for streamCreatorPool in CelebornShuffleReader. | 0.3.1 | | | celeborn.client.excludePeerWorkerOnFailure.enabled | true | false | When true, Celeborn will exclude partition's peer worker on failure when push data to replica failed. | 0.3.0 | | | celeborn.client.excludedWorker.expireTimeout | 180s | false | Timeout time for LifecycleManager to clear reserved excluded worker. Default to be 1.5 * `celeborn.master.heartbeat.worker.timeout` to cover worker heartbeat timeout check period | 0.3.0 | celeborn.worker.excluded.expireTimeout | From 6ce0e119bf4274b41764b87227c7dc1e32b57560 Mon Sep 17 00:00:00 2001 From: wangshengjie Date: Tue, 9 Apr 2024 15:42:53 +0800 Subject: [PATCH 15/44] refactor code for batch open stream --- .../celeborn/CelebornShuffleReader.scala | 86 +++++++++++++++++-- .../apache/celeborn/client/ShuffleClient.java | 3 + .../celeborn/client/ShuffleClientImpl.java | 2 + .../client/read/CelebornInputStream.java | 44 +--------- .../celeborn/client/DummyShuffleClient.java | 2 + .../client/WithShuffleClientSuite.scala | 2 + .../deploy/cluster/ReadWriteTestBase.scala | 1 + 7 files changed, 92 insertions(+), 48 deletions(-) diff --git a/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala index 1c8b151ce04..baacd2cde9e 100644 --- a/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala +++ b/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala @@ -23,7 +23,9 @@ import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor, TimeUnit} import java.util.concurrent.atomic.AtomicReference import scala.collection.JavaConverters._ +import scala.collection.mutable +import org.apache.commons.lang3.tuple.Pair import org.apache.spark.{Aggregator, InterruptibleIterator, ShuffleDependency, TaskContext} import org.apache.spark.celeborn.ExceptionMakerHelper import org.apache.spark.internal.Logging @@ -119,12 +121,38 @@ class CelebornShuffleReader[K, C]( val workerRequestMap = new util.HashMap[ String, (TransportClient, util.ArrayList[PartitionLocation], PbOpenStreamList.Builder)]() + // partitionId -> (partition uniqueId -> chunkRange pair) + val partitionId2ChunkRange = + new util.HashMap[Int, util.Map[String, Pair[Integer, Integer]]]() + + val partitionId2PartitionLocations = new util.HashMap[Int, util.Set[PartitionLocation]]() var partCnt = 0 + // if startMapIndex > endMapIndex, means partition is skew partition. + // locations will split to sub-partitions with startMapIndex size. + val splitSkewPartitionWithoutMapRange = + conf.clientAdaptiveOptimizeSkewedPartitionReadEnabled && startMapIndex > endMapIndex + (startPartition until endPartition).foreach { partitionId => if (fileGroups.partitionGroups.containsKey(partitionId)) { - fileGroups.partitionGroups.get(partitionId).asScala.foreach { location => + var locations = fileGroups.partitionGroups.get(partitionId) + if (splitSkewPartitionWithoutMapRange) { + var partitionLocation2ChunkRange: util.Map[String, Pair[Integer, Integer]] = null + partitionLocation2ChunkRange = + splitSkewedPartitionLocations(new util.ArrayList(locations), startMapIndex, endMapIndex) + partitionId2ChunkRange.put(partitionId, partitionLocation2ChunkRange) + // filter locations avoid OPEN_STREAM when split skew partition without map range + val filterLocations = locations.asScala + .filter { location => + null != partitionLocation2ChunkRange && + partitionLocation2ChunkRange.containsKey(location.getUniqueId) + } + locations = filterLocations.asJava + partitionId2PartitionLocations.put(partitionId, locations) + } + + locations.asScala.foreach { location => partCnt += 1 val hostPort = location.hostAndFetchPort if (!workerRequestMap.containsKey(hostPort)) { @@ -197,13 +225,22 @@ class CelebornShuffleReader[K, C]( def createInputStream(partitionId: Int): Unit = { val locations = - if (fileGroups.partitionGroups.containsKey(partitionId)) { - new util.ArrayList(fileGroups.partitionGroups.get(partitionId)) - } else new util.ArrayList[PartitionLocation]() + if (splitSkewPartitionWithoutMapRange) { + partitionId2PartitionLocations.get(partitionId) + } else { + fileGroups.partitionGroups.get(partitionId) + } + + val locationList = + if (null == locations) { + new util.ArrayList[PartitionLocation]() + } else { + new util.ArrayList[PartitionLocation](locations) + } val streamHandlers = - if (locations != null && !conf.clientAdaptiveOptimizeSkewedPartitionReadEnabled) { - val streamHandlerArr = new util.ArrayList[PbStreamHandler](locations.size()) - locations.asScala.foreach { loc => + if (locations != null) { + val streamHandlerArr = new util.ArrayList[PbStreamHandler](locationList.size) + locationList.asScala.foreach { loc => streamHandlerArr.add(locationStreamHandlerMap.get(loc)) } streamHandlerArr @@ -220,9 +257,10 @@ class CelebornShuffleReader[K, C]( endMapIndex, if (throwsFetchFailure) ExceptionMakerHelper.SHUFFLE_FETCH_FAILURE_EXCEPTION_MAKER else null, - new util.ArrayList[PartitionLocation](locations.asJava), + locationList, streamHandlers, fileGroups.pushFailedBatches, + partitionId2ChunkRange.get(partitionId), fileGroups.mapAttempts, metricsCallback) streams.put(partitionId, inputStream) @@ -395,8 +433,40 @@ class CelebornShuffleReader[K, C]( dep.serializer.newInstance() } + private def splitSkewedPartitionLocations( + locations: util.ArrayList[PartitionLocation], + subPartitionSize: Int, + subPartitionIndex: Int): util.Map[String, Pair[Integer, Integer]] = { + locations.sort(CelebornShuffleReader.comparator) + val totalPartitionSize = + locations.stream.mapToLong((p: PartitionLocation) => p.getStorageInfo.fileSize).sum + val step = totalPartitionSize / subPartitionSize + val startOffset = step * subPartitionIndex + val endOffset = step * (subPartitionIndex + 1) + var partitionLocationOffset: Long = 0 + val chunkRange = new util.HashMap[String, Pair[Integer, Integer]] + for (i <- 0 until locations.size) { + val p = locations.get(i) + var left = -1 + var right = -1 + for (j <- 0 until p.getStorageInfo.getChunkOffsets.size) { + val currentOffset = partitionLocationOffset + p.getStorageInfo.getChunkOffsets.get(j) + if (currentOffset >= startOffset && left < 0) left = j + if (currentOffset < endOffset) right = j + if (left >= 0 && right >= 0) chunkRange.put(p.getUniqueId, Pair.of(left, right)) + } + partitionLocationOffset = partitionLocationOffset + p.getStorageInfo.getFileSize + } + chunkRange + } + } object CelebornShuffleReader { var streamCreatorPool: ThreadPoolExecutor = null + val comparator = new util.Comparator[PartitionLocation] { + override def compare(o1: PartitionLocation, o2: PartitionLocation): Int = { + o1.getUniqueId.compareTo(o2.getUniqueId) + } + } } diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java index 6b3146b864d..8690f5cebde 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java @@ -24,6 +24,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.LongAdder; +import org.apache.commons.lang3.tuple.Pair; import org.apache.hadoop.fs.FileSystem; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -244,6 +245,7 @@ public CelebornInputStream readPartition( null, null, null, + null, metricsCallback); } @@ -259,6 +261,7 @@ public abstract CelebornInputStream readPartition( ArrayList locations, ArrayList streamHandlers, Map> failedBatchSetMap, + Map> chunksRange, int[] mapAttempts, MetricsCallback metricsCallback) throws IOException; diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java index e3a479ee777..54c7f0bd2ea 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java @@ -1894,6 +1894,7 @@ public CelebornInputStream readPartition( ArrayList locations, ArrayList streamHandlers, Map> failedBatchSetMap, + Map> chunksRange, int[] mapAttempts, MetricsCallback metricsCallback) throws IOException { @@ -1927,6 +1928,7 @@ public CelebornInputStream readPartition( streamHandlers, mapAttempts, failedBatchSetMap, + chunksRange, attemptNumber, taskId, startMapIndex, diff --git a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java index 2739d014336..96b64497bdc 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java +++ b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java @@ -57,6 +57,7 @@ public static CelebornInputStream create( ArrayList streamHandlers, int[] attempts, Map> failedBatchSetMap, + Map> chunksRange, int attemptNumber, long taskId, int startMapIndex, @@ -69,7 +70,7 @@ public static CelebornInputStream create( ExceptionMaker exceptionMaker, MetricsCallback metricsCallback) throws IOException { - if (locations == null || locations.size() == 0) { + if (locations == null || locations.isEmpty()) { return emptyInputStream; } else { // if startMapIndex > endMapIndex, means partition is skew partition. @@ -77,9 +78,6 @@ public static CelebornInputStream create( boolean splitSkewPartitionWithoutMapRange = conf.clientAdaptiveOptimizeSkewedPartitionReadEnabled() && startMapIndex > endMapIndex; if (splitSkewPartitionWithoutMapRange) { - Map> partitionLocationToChunkRange = - splitSkewedPartitionLocations(locations, startMapIndex, endMapIndex); - logger.debug("Current sub-partition locations: {}", locations); return new CelebornInputStreamImpl( conf, clientFactory, @@ -90,7 +88,7 @@ public static CelebornInputStream create( failedBatchSetMap, attemptNumber, taskId, - partitionLocationToChunkRange, + chunksRange, fetchExcludedWorkers, shuffleClient, appShuffleId, @@ -125,38 +123,6 @@ public static CelebornInputStream create( } } - public static Map> splitSkewedPartitionLocations( - ArrayList locations, int subPartitionSize, int subPartitionIndex) { - locations.sort(Comparator.comparing((PartitionLocation p) -> p.getUniqueId())); - - long totalPartitionSize = - locations.stream().mapToLong((PartitionLocation p) -> p.getStorageInfo().fileSize).sum(); - long step = totalPartitionSize / subPartitionSize; - long startOffset = step * subPartitionIndex; - long endOffset = step * (subPartitionIndex + 1); - long partitionLocationOffset = 0; - Map> chunkRange = new HashMap<>(); - for (int i = 0; i < locations.size(); i++) { - PartitionLocation p = locations.get(i); - int left = -1; - int right = -1; - for (int j = 0; j < p.getStorageInfo().getChunkOffsets().size(); j++) { - long currentOffset = partitionLocationOffset + p.getStorageInfo().getChunkOffsets().get(j); - if (currentOffset >= startOffset && left < 0) { - left = j; - } - if (currentOffset < endOffset) { - right = j; - } - if (left >= 0 && right >= 0) { - chunkRange.put(p.getUniqueId(), Pair.of(left, right)); - } - } - partitionLocationOffset += p.getStorageInfo().getFileSize(); - } - return chunkRange; - } - public static CelebornInputStream empty() { return emptyInputStream; } @@ -756,9 +722,7 @@ private boolean fillBuffer() throws IOException { failedBatch.setMapId(mapId); failedBatch.setAttemptId(attemptId); failedBatch.setBatchId(batchId); - String uid = currentReader.getLocation().getUniqueId(); - if (failedBatches.containsKey(uid) - && failedBatches.get(uid).contains(failedBatch)) { + if (failedBatchSet.contains(failedBatch)) { logger.warn("Skip duplicated batch: {}.", failedBatch); continue; } diff --git a/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java b/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java index 3cc99fc4872..7ec097392fd 100644 --- a/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java +++ b/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java @@ -31,6 +31,7 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import org.apache.commons.lang3.tuple.Pair; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -139,6 +140,7 @@ public CelebornInputStream readPartition( ArrayList locations, ArrayList streamHandlers, Map> failedBatchSetMap, + Map> chunksRange, int[] mapAttempts, MetricsCallback metricsCallback) throws IOException { diff --git a/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala b/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala index afc91028653..0570760ce14 100644 --- a/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala +++ b/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala @@ -165,6 +165,7 @@ trait WithShuffleClientSuite extends CelebornFunSuite { null, null, null, + null, metricsCallback) Assert.assertEquals(stream.read(), -1) @@ -182,6 +183,7 @@ trait WithShuffleClientSuite extends CelebornFunSuite { null, null, null, + null, metricsCallback) Assert.assertEquals(stream.read(), -1) } diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala index eee7b46fb0f..e81593e1d99 100644 --- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala @@ -117,6 +117,7 @@ trait ReadWriteTestBase extends AnyFunSuite null, null, null, + null, metricsCallback) val outputStream = new ByteArrayOutputStream() From 9b7cc7f0b774343ca1c999de10c0dc61db49bf1a Mon Sep 17 00:00:00 2001 From: wangshengjie Date: Fri, 12 Apr 2024 14:50:17 +0800 Subject: [PATCH 16/44] add ut and refactor split logic --- client-spark/spark-3-4/pom.xml | 5 + .../celeborn/CelebornShuffleReader.scala | 10 +- .../celeborn/CelebornShuffleReaderSuiteJ.java | 130 ++++++++++++++++++ .../LifecycleManagerCommitFilesSuite.scala | 2 +- 4 files changed, 144 insertions(+), 3 deletions(-) create mode 100644 client-spark/spark-3-4/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleReaderSuiteJ.java diff --git a/client-spark/spark-3-4/pom.xml b/client-spark/spark-3-4/pom.xml index 8ddfd6f2267..1f89fad9cd6 100644 --- a/client-spark/spark-3-4/pom.xml +++ b/client-spark/spark-3-4/pom.xml @@ -91,5 +91,10 @@ mockito-core test + + org.mockito + mockito-inline + test + diff --git a/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala index baacd2cde9e..23ef03e88a3 100644 --- a/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala +++ b/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala @@ -433,7 +433,7 @@ class CelebornShuffleReader[K, C]( dep.serializer.newInstance() } - private def splitSkewedPartitionLocations( + def splitSkewedPartitionLocations( locations: util.ArrayList[PartitionLocation], subPartitionSize: Int, subPartitionIndex: Int): util.Map[String, Pair[Integer, Integer]] = { @@ -442,7 +442,13 @@ class CelebornShuffleReader[K, C]( locations.stream.mapToLong((p: PartitionLocation) => p.getStorageInfo.fileSize).sum val step = totalPartitionSize / subPartitionSize val startOffset = step * subPartitionIndex - val endOffset = step * (subPartitionIndex + 1) + val endOffset = + if (subPartitionIndex == subPartitionSize - 1) { + // last subPartition should include all remaining data + totalPartitionSize + 1 + } else { + step * (subPartitionIndex + 1) + } var partitionLocationOffset: Long = 0 val chunkRange = new util.HashMap[String, Pair[Integer, Integer]] for (i <- 0 until locations.size) { diff --git a/client-spark/spark-3-4/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleReaderSuiteJ.java b/client-spark/spark-3-4/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleReaderSuiteJ.java new file mode 100644 index 00000000000..3ba875f707a --- /dev/null +++ b/client-spark/spark-3-4/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleReaderSuiteJ.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.celeborn; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Map; + +import com.google.common.collect.Maps; +import org.apache.commons.lang3.tuple.Pair; +import org.junit.Assert; +import org.junit.Test; +import org.mockito.MockedStatic; +import org.mockito.Mockito; + +import org.apache.celeborn.client.ShuffleClient; +import org.apache.celeborn.common.CelebornConf; +import org.apache.celeborn.common.identity.UserIdentifier; +import org.apache.celeborn.common.protocol.PartitionLocation; +import org.apache.celeborn.common.protocol.StorageInfo; + +public class CelebornShuffleReaderSuiteJ { + @Test + public void testSkewPartitionSplit() { + CelebornShuffleHandle handle = + new CelebornShuffleHandle( + "appId", "host", 0, new UserIdentifier("mock", "mock"), 0, false, 10, null); + + MockedStatic shuffleClient = null; + try { + shuffleClient = Mockito.mockStatic(ShuffleClient.class); + CelebornShuffleReader shuffleReader = + new CelebornShuffleReader( + handle, 0, 10, 0, 10, null, new CelebornConf(), null, new ExecutorShuffleIdTracker()); + + ArrayList locations = new ArrayList<>(); + for (int i = 0; i < 13; i++) { + PartitionLocation location = + new PartitionLocation(0, i, "localhost", 0, 0, 0, 0, PartitionLocation.Mode.PRIMARY); + StorageInfo storageInfo = + new StorageInfo( + StorageInfo.Type.HDD, + "mountPoint", + false, + "filePath", + StorageInfo.LOCAL_DISK_MASK, + 1000, + Arrays.asList(0L, 100L, 200L, 300L, 500L, 1000L)); + location.setStorageInfo(storageInfo); + locations.add(location); + } + + PartitionLocation location = + new PartitionLocation(0, 91, "localhost", 0, 0, 0, 0, PartitionLocation.Mode.PRIMARY); + StorageInfo storageInfo = + new StorageInfo( + StorageInfo.Type.HDD, + "mountPoint", + false, + "filePath", + StorageInfo.LOCAL_DISK_MASK, + 1, + Arrays.asList(0L, 1L)); + location.setStorageInfo(storageInfo); + locations.add(location); + + Map> expectResult = Maps.newHashMap(); + + for (int i = 0; i < 5; i++) { + int subPartitionSize = 3; + + int subPartitionIndex = 0; + Map> result1 = + shuffleReader.splitSkewedPartitionLocations( + locations, subPartitionSize, subPartitionIndex); + expectResult.clear(); + expectResult.put("0-0", Pair.of(0, 5)); + expectResult.put("0-1", Pair.of(0, 5)); + expectResult.put("0-10", Pair.of(0, 5)); + expectResult.put("0-11", Pair.of(0, 5)); + expectResult.put("0-12", Pair.of(0, 3)); + Assert.assertEquals(expectResult, result1); + + subPartitionIndex = 1; + Map> result2 = + shuffleReader.splitSkewedPartitionLocations( + locations, subPartitionSize, subPartitionIndex); + expectResult.clear(); + expectResult.put("0-12", Pair.of(4, 5)); + expectResult.put("0-2", Pair.of(0, 5)); + expectResult.put("0-3", Pair.of(0, 5)); + expectResult.put("0-4", Pair.of(0, 5)); + expectResult.put("0-5", Pair.of(0, 4)); + Assert.assertEquals(expectResult, result2); + + subPartitionIndex = 2; + Map> result3 = + shuffleReader.splitSkewedPartitionLocations( + locations, subPartitionSize, subPartitionIndex); + expectResult.clear(); + expectResult.put("0-5", Pair.of(5, 5)); + expectResult.put("0-6", Pair.of(0, 5)); + expectResult.put("0-7", Pair.of(0, 5)); + expectResult.put("0-8", Pair.of(0, 5)); + expectResult.put("0-9", Pair.of(0, 5)); + expectResult.put("0-91", Pair.of(0, 1)); + Assert.assertEquals(expectResult, result3); + } + } finally { + if (null != shuffleClient) { + shuffleClient.close(); + } + } + } +} diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerCommitFilesSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerCommitFilesSuite.scala index ead015832fa..a772316f318 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerCommitFilesSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerCommitFilesSuite.scala @@ -266,7 +266,7 @@ class LifecycleManagerCommitFilesSuite extends WithShuffleClientSuite with MiniC 0 until 3 foreach { partitionId => bufferLength = - shuffleClient.pushData(shuffleId, 0, 0, partitionId, buffer, 0, buffer.length, 1, 1) + shuffleClient.pushData(shuffleId, 0, 0, partitionId, buffer, 0, buffer.length, 1, 3) lifecycleManager.commitManager.finishMapperAttempt(shuffleId, 0, 0, 1, partitionId) } From 8ac717252b01c86a0d7339101f06adcb279571da Mon Sep 17 00:00:00 2001 From: wangshengjie Date: Fri, 12 Apr 2024 15:25:50 +0800 Subject: [PATCH 17/44] refactor code according to review suggestions --- .../celeborn/common/write/PushFailedBatch.java | 12 ++++++------ .../common/protocol/message/ControlMessages.scala | 3 +-- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/common/src/main/java/org/apache/celeborn/common/write/PushFailedBatch.java b/common/src/main/java/org/apache/celeborn/common/write/PushFailedBatch.java index 0e872611e13..ccee8bf1113 100644 --- a/common/src/main/java/org/apache/celeborn/common/write/PushFailedBatch.java +++ b/common/src/main/java/org/apache/celeborn/common/write/PushFailedBatch.java @@ -29,6 +29,12 @@ public class PushFailedBatch implements Serializable { private int attemptId; private int batchId; + public PushFailedBatch(int mapId, int attemptId, int batchId) { + this.mapId = mapId; + this.attemptId = attemptId; + this.batchId = batchId; + } + public int getMapId() { return mapId; } @@ -53,12 +59,6 @@ public void setBatchId(int batchId) { this.batchId = batchId; } - public PushFailedBatch(int mapId, int attemptId, int batchId) { - this.mapId = mapId; - this.attemptId = attemptId; - this.batchId = batchId; - } - @Override public boolean equals(Object other) { if (!(other instanceof PushFailedBatch)) { diff --git a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala index 7bbe66ec54c..e730b9b7afb 100644 --- a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala +++ b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala @@ -288,8 +288,7 @@ object ControlMessages extends Logging { fileGroup: util.Map[Integer, util.Set[PartitionLocation]], attempts: Array[Int], partitionIds: util.Set[Integer] = Collections.emptySet[Integer](), - pushFailedBatches: util.Map[String, util.Set[PushFailedBatch]] = - new util.HashMap[String, util.Set[PushFailedBatch]]()) + pushFailedBatches: util.Map[String, util.Set[PushFailedBatch]] = Collections.emptyMap()) extends MasterMessage object WorkerExclude { From c169aa2bfbd266a8698abbc9da1b56ddced4f79c Mon Sep 17 00:00:00 2001 From: wangshengjie Date: Fri, 12 Apr 2024 17:46:10 +0800 Subject: [PATCH 18/44] add sbt dependency --- project/CelebornBuild.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/project/CelebornBuild.scala b/project/CelebornBuild.scala index 63fe4249c1a..16f14f7044f 100644 --- a/project/CelebornBuild.scala +++ b/project/CelebornBuild.scala @@ -942,7 +942,7 @@ trait SparkClientProjects { libraryDependencies ++= Seq( "org.apache.spark" %% "spark-core" % sparkVersion % "provided", "org.apache.spark" %% "spark-sql" % sparkVersion % "provided", - ) ++ commonUnitTestDependencies + ) ++ commonUnitTestDependencies ++ Seq(Dependencies.mockitoInline % "test") ) } From 4b9fab6d1747bf475c0196265c5cd668f20874bb Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Tue, 16 Apr 2024 01:42:45 +0800 Subject: [PATCH 19/44] update spark patch --- ...rn-Optimize-Skew-Partitions-spark3_3.patch | 26 +++++++++++-------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch index a2a9a4bd83f..e8562f18751 100644 --- a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch +++ b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch @@ -14,7 +14,7 @@ # limitations under the License. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala -index af03ad9a4cb..3b5c7ce4fce 100644 +index af03ad9a4cb..7a3ee9ebfaf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -3784,6 +3784,12 @@ object SQLConf { @@ -23,7 +23,7 @@ index af03ad9a4cb..3b5c7ce4fce 100644 + val CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ = + buildConf("spark.celeborn.client.adaptive.optimizeSkewedPartitionRead.enabled") -+ .version("3.3.0") ++ .version("3.0.0") + .booleanConf + .createWithDefault(false) + @@ -41,36 +41,40 @@ index af03ad9a4cb..3b5c7ce4fce 100644 /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala -index af689db3379..38e54b3ed0a 100644 +index af689db3379..393ac402363 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala -@@ -22,6 +22,8 @@ import scala.collection.mutable.ArrayBuffer +@@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.{MapOutputStatistics, MapOutputTrackerMaster, SparkEnv} import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.{CoalescedPartitionSpec, PartialReducerPartitionSpec, ShufflePartitionSpec} +import org.apache.spark.sql.internal.SQLConf -+import org.apache.spark.util.Utils object ShufflePartitionsUtil extends Logging { final val SMALL_PARTITION_FACTOR = 0.2 -@@ -387,6 +389,10 @@ object ShufflePartitionsUtil extends Logging { +@@ -387,6 +388,12 @@ object ShufflePartitionsUtil extends Logging { val mapStartIndices = splitSizeListByTargetSize( mapPartitionSizes, targetSize, smallPartitionFactor) if (mapStartIndices.length > 1) { + // If Celeborn is enabled, split skew partitions without shuffle mapper-range reading -+ val isCelebornClientAdaptiveOptimizeSkewedPartitionReadEnabled = Utils.isCelebornEnabled(SparkEnv.get.conf) && -+ SQLConf.get.celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled -+ ++ val isCelebornClientAdaptiveOptimizeSkewedPartitionReadEnabled = { ++ // TODO: check fallback or not. ++ SparkEnv.get.conf.get("spark.shuffle.manager", "sort").contains("celeborn") && ++ SQLConf.get.celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled ++ } Some(mapStartIndices.indices.map { i => val startMapIndex = mapStartIndices(i) val endMapIndex = if (i == mapStartIndices.length - 1) { -@@ -400,7 +406,11 @@ object ShufflePartitionsUtil extends Logging { +@@ -400,7 +407,14 @@ object ShufflePartitionsUtil extends Logging { dataSize += mapPartitionSizes(mapIndex) mapIndex += 1 } - PartialReducerPartitionSpec(reducerId, startMapIndex, endMapIndex, dataSize) + if (isCelebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) { -+ PartialReducerPartitionSpec(reducerId, mapStartIndices.length, i, -1) ++ // These `dataSize` variables may not be accurate as they only represent the sum of ++ // `dataSize` when the Celeborn optimize skewed partition read feature is enabled. ++ // Please not to use these dataSize variables in any other part of the codebase. ++ PartialReducerPartitionSpec(reducerId, mapStartIndices.length, i, dataSize) + } else { + PartialReducerPartitionSpec(reducerId, startMapIndex, endMapIndex, dataSize) + } From 2e3a1c148e46353bb1ec10194c0294eddbebbb86 Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Tue, 16 Apr 2024 02:31:19 +0800 Subject: [PATCH 20/44] address comment --- ...rn-Optimize-Skew-Partitions-spark3_3.patch | 26 ++++++++++++++----- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch index e8562f18751..b5f0c8f2dce 100644 --- a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch +++ b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch @@ -41,7 +41,7 @@ index af03ad9a4cb..7a3ee9ebfaf 100644 /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala -index af689db3379..393ac402363 100644 +index af689db3379..d421ca80a3d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala @@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer @@ -52,20 +52,32 @@ index af689db3379..393ac402363 100644 object ShufflePartitionsUtil extends Logging { final val SMALL_PARTITION_FACTOR = 0.2 -@@ -387,6 +388,12 @@ object ShufflePartitionsUtil extends Logging { +@@ -387,6 +388,24 @@ object ShufflePartitionsUtil extends Logging { val mapStartIndices = splitSizeListByTargetSize( mapPartitionSizes, targetSize, smallPartitionFactor) if (mapStartIndices.length > 1) { -+ // If Celeborn is enabled, split skew partitions without shuffle mapper-range reading + val isCelebornClientAdaptiveOptimizeSkewedPartitionReadEnabled = { -+ // TODO: check fallback or not. -+ SparkEnv.get.conf.get("spark.shuffle.manager", "sort").contains("celeborn") && -+ SQLConf.get.celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled ++ val throwsFetchFailure = SparkEnv.get ++ .conf ++ .get("spark.celeborn.client.spark.fetch.throwsFetchFailure", "false") ++ .toBoolean ++ if (throwsFetchFailure) { ++ throw new UnsupportedOperationException( ++ "Currently, the 'Optimize Skewed Partition Read' feature cannot be used " + ++ "together with the 'Stage Re-run' feature. (The configuration parameters " + ++ "`spark.celeborn.client.spark.fetch.throwsFetchFailure` and " + ++ "`spark.celeborn.client.adaptive.optimizeSkewedPartitionRead.enabled` cannot be set " + ++ "to `true` at the same time.)") ++ } else { ++ // TODO: check fallback or not. ++ SparkEnv.get.conf.get("spark.shuffle.manager", "sort").contains("celeborn") && ++ SQLConf.get.celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled ++ } + } Some(mapStartIndices.indices.map { i => val startMapIndex = mapStartIndices(i) val endMapIndex = if (i == mapStartIndices.length - 1) { -@@ -400,7 +407,14 @@ object ShufflePartitionsUtil extends Logging { +@@ -400,7 +419,14 @@ object ShufflePartitionsUtil extends Logging { dataSize += mapPartitionSizes(mapIndex) mapIndex += 1 } From 1794fd43584bda53ffd76f0b7c693f10b38e2af9 Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Tue, 16 Apr 2024 02:34:49 +0800 Subject: [PATCH 21/44] fix --- ...rn-Optimize-Skew-Partitions-spark3_3.patch | 31 +++++++++---------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch index b5f0c8f2dce..86c8ca4fa21 100644 --- a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch +++ b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch @@ -41,7 +41,7 @@ index af03ad9a4cb..7a3ee9ebfaf 100644 /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala -index af689db3379..d421ca80a3d 100644 +index af689db3379..f277bc396d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala @@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer @@ -52,32 +52,31 @@ index af689db3379..d421ca80a3d 100644 object ShufflePartitionsUtil extends Logging { final val SMALL_PARTITION_FACTOR = 0.2 -@@ -387,6 +388,24 @@ object ShufflePartitionsUtil extends Logging { +@@ -387,6 +388,23 @@ object ShufflePartitionsUtil extends Logging { val mapStartIndices = splitSizeListByTargetSize( mapPartitionSizes, targetSize, smallPartitionFactor) if (mapStartIndices.length > 1) { -+ val isCelebornClientAdaptiveOptimizeSkewedPartitionReadEnabled = { -+ val throwsFetchFailure = SparkEnv.get -+ .conf -+ .get("spark.celeborn.client.spark.fetch.throwsFetchFailure", "false") -+ .toBoolean -+ if (throwsFetchFailure) { -+ throw new UnsupportedOperationException( -+ "Currently, the 'Optimize Skewed Partition Read' feature cannot be used " + ++ val isCelebornClientAdaptiveOptimizeSkewedPartitionReadEnabled = ++ // TODO: check fallback or not. ++ SparkEnv.get.conf.get("spark.shuffle.manager", "sort").contains("celeborn") && ++ SQLConf.get.celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled ++ ++ val throwsFetchFailure = SparkEnv.get ++ .conf ++ .get("spark.celeborn.client.spark.fetch.throwsFetchFailure", "false") ++ .toBoolean ++ if (throwsFetchFailure && isCelebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) { ++ throw new UnsupportedOperationException( ++ "Currently, the 'Optimize Skewed Partition Read' feature cannot be used " + + "together with the 'Stage Re-run' feature. (The configuration parameters " + + "`spark.celeborn.client.spark.fetch.throwsFetchFailure` and " + + "`spark.celeborn.client.adaptive.optimizeSkewedPartitionRead.enabled` cannot be set " + + "to `true` at the same time.)") -+ } else { -+ // TODO: check fallback or not. -+ SparkEnv.get.conf.get("spark.shuffle.manager", "sort").contains("celeborn") && -+ SQLConf.get.celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled -+ } + } Some(mapStartIndices.indices.map { i => val startMapIndex = mapStartIndices(i) val endMapIndex = if (i == mapStartIndices.length - 1) { -@@ -400,7 +419,14 @@ object ShufflePartitionsUtil extends Logging { +@@ -400,7 +418,14 @@ object ShufflePartitionsUtil extends Logging { dataSize += mapPartitionSizes(mapIndex) mapIndex += 1 } From e7d8d73f85f50e5aee529d16c829933810019b74 Mon Sep 17 00:00:00 2001 From: wangshengjie Date: Tue, 16 Apr 2024 21:41:40 +0800 Subject: [PATCH 22/44] add ut --- .../client/commit/CommitHandler.scala | 2 + .../tests/spark/PushFailedBatchSuite.scala | 89 ++++++++++++++ .../celeborn/tests/spark/SkewJoinSuite.scala | 114 +++++++++--------- 3 files changed, 151 insertions(+), 54 deletions(-) create mode 100644 tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/PushFailedBatchSuite.scala diff --git a/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala b/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala index 63371601b95..38c3b0c9c81 100644 --- a/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala +++ b/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala @@ -84,6 +84,8 @@ abstract class CommitHandler( def getPartitionType(): PartitionType + def getShuffleFailedBatches(): ShufflePushFailedBatches = this.shufflePushFailedBatches + def isStageEnd(shuffleId: Int): Boolean = false def isStageEndOrInProcess(shuffleId: Int): Boolean = false diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/PushFailedBatchSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/PushFailedBatchSuite.scala new file mode 100644 index 00000000000..716ad4e3693 --- /dev/null +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/PushFailedBatchSuite.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.tests.spark + +import com.google.common.collect.Sets +import org.apache.spark.{SparkConf, SparkContext, SparkContextHelper} +import org.apache.spark.shuffle.celeborn.SparkShuffleManager +import org.junit.Assert +import org.scalatest.BeforeAndAfterEach +import org.scalatest.funsuite.AnyFunSuite + +import org.apache.celeborn.client.ShuffleClient +import org.apache.celeborn.common.CelebornConf +import org.apache.celeborn.common.protocol.ShuffleMode +import org.apache.celeborn.common.write.PushFailedBatch +import org.apache.celeborn.service.deploy.worker.PushDataHandler + +class PushFailedBatchSuite extends AnyFunSuite + with SparkTestBase + with BeforeAndAfterEach { + + override def beforeAll(): Unit = { + val workerConf = Map( + CelebornConf.TEST_WORKER_PUSH_REPLICA_DATA_TIMEOUT.key -> "true") + + setupMiniClusterWithRandomPorts(workerConf = workerConf, workerNum = 4) + } + + override def beforeEach(): Unit = { + ShuffleClient.reset() + PushDataHandler.pushPrimaryDataTimeoutTested.set(false) + PushDataHandler.pushReplicaDataTimeoutTested.set(false) + PushDataHandler.pushPrimaryMergeDataTimeoutTested.set(false) + PushDataHandler.pushReplicaMergeDataTimeoutTested.set(false) + } + + override protected def afterEach() { + System.gc() + } + + test("CELEBORN-1319: check failed batch info by making push timeout") { + val sparkConf = new SparkConf() + .set(s"spark.${CelebornConf.TEST_CLIENT_RETRY_REVIVE.key}", "false") + .set(s"spark.${CelebornConf.CLIENT_PUSH_REPLICATE_ENABLED.key}", "true") + .set(s"spark.${CelebornConf.CLIENT_PUSH_DATA_TIMEOUT.key}", "3s") + .set( + s"spark.${CelebornConf.CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ_ENABLED.key}", + "true") + .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + .set("spark.shuffle.manager", "org.apache.spark.shuffle.celeborn.SparkShuffleManager") + .setAppName("celeborn-1319") + .setMaster("local[2]") + updateSparkConf(sparkConf, ShuffleMode.HASH) + val sc = new SparkContext(sparkConf) + + sc.parallelize(1 to 1, 1).repartition(1).map(i => i + 1).collect() + + val manager = SparkContextHelper.env + .shuffleManager + .asInstanceOf[SparkShuffleManager] + .getLifecycleManager + + // only one batch failed due to push timeout, so shuffle id will be 0, + // and PartitionLocation uniqueId will be 0-0 + val pushFailedBatch = manager.commitManager.getCommitHandler(0).getShuffleFailedBatches() + assert(!pushFailedBatch.isEmpty) + Assert.assertEquals( + pushFailedBatch.get(0).get("0-0"), + Sets.newHashSet(new PushFailedBatch(0, 0, 1))) + + sc.stop() + } + +} diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SkewJoinSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SkewJoinSuite.scala index 6cfbd227a94..f1a5d5e7769 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SkewJoinSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SkewJoinSuite.scala @@ -48,63 +48,69 @@ class SkewJoinSuite extends AnyFunSuite } CompressionCodec.values.foreach { codec => - test(s"celeborn spark integration test - skew join - $codec") { - val sparkConf = new SparkConf().setAppName("celeborn-demo") - .setMaster("local[2]") - .set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true") - .set("spark.sql.adaptive.skewJoin.enabled", "true") - .set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "16MB") - .set("spark.sql.adaptive.advisoryPartitionSizeInBytes", "12MB") - .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - .set("spark.sql.adaptive.autoBroadcastJoinThreshold", "-1") - .set(SQLConf.PARQUET_COMPRESSION.key, "gzip") - .set(s"spark.${CelebornConf.SHUFFLE_COMPRESSION_CODEC.key}", codec.name) - .set(s"spark.${CelebornConf.SHUFFLE_RANGE_READ_FILTER_ENABLED.key}", "true") + Seq(false, true).foreach { enabled => + test( + s"celeborn spark integration test - skew join - with $codec - with client skew $enabled") { + val sparkConf = new SparkConf().setAppName("celeborn-demo") + .setMaster("local[2]") + .set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true") + .set("spark.sql.adaptive.skewJoin.enabled", "true") + .set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "16MB") + .set("spark.sql.adaptive.advisoryPartitionSizeInBytes", "12MB") + .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + .set("spark.sql.adaptive.autoBroadcastJoinThreshold", "-1") + .set(SQLConf.PARQUET_COMPRESSION.key, "gzip") + .set(s"spark.${CelebornConf.SHUFFLE_COMPRESSION_CODEC.key}", codec.name) + .set(s"spark.${CelebornConf.SHUFFLE_RANGE_READ_FILTER_ENABLED.key}", "true") + .set( + s"spark.${CelebornConf.CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ_ENABLED.key}", + s"$enabled") - enableCeleborn(sparkConf) + enableCeleborn(sparkConf) - val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate() - if (sparkSession.version.startsWith("3")) { - import sparkSession.implicits._ - val df = sparkSession.sparkContext.parallelize(1 to 120000, 8) - .map(i => { - val random = new Random() - val oriKey = random.nextInt(64) - val key = if (oriKey < 32) 1 else oriKey - val fas = random.nextInt(1200000) - val fa = Range(fas, fas + 100).mkString(",") - val fbs = random.nextInt(1200000) - val fb = Range(fbs, fbs + 100).mkString(",") - val fcs = random.nextInt(1200000) - val fc = Range(fcs, fcs + 100).mkString(",") - val fds = random.nextInt(1200000) - val fd = Range(fds, fds + 100).mkString(",") + val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate() + if (sparkSession.version.startsWith("3")) { + import sparkSession.implicits._ + val df = sparkSession.sparkContext.parallelize(1 to 120000, 8) + .map(i => { + val random = new Random() + val oriKey = random.nextInt(64) + val key = if (oriKey < 32) 1 else oriKey + val fas = random.nextInt(1200000) + val fa = Range(fas, fas + 100).mkString(",") + val fbs = random.nextInt(1200000) + val fb = Range(fbs, fbs + 100).mkString(",") + val fcs = random.nextInt(1200000) + val fc = Range(fcs, fcs + 100).mkString(",") + val fds = random.nextInt(1200000) + val fd = Range(fds, fds + 100).mkString(",") - (key, fa, fb, fc, fd) - }) - .toDF("fa", "f1", "f2", "f3", "f4") - df.createOrReplaceTempView("view1") - val df2 = sparkSession.sparkContext.parallelize(1 to 8, 8) - .map(i => { - val random = new Random() - val oriKey = random.nextInt(64) - val key = if (oriKey < 32) 1 else oriKey - val fas = random.nextInt(1200000) - val fa = Range(fas, fas + 100).mkString(",") - val fbs = random.nextInt(1200000) - val fb = Range(fbs, fbs + 100).mkString(",") - val fcs = random.nextInt(1200000) - val fc = Range(fcs, fcs + 100).mkString(",") - val fds = random.nextInt(1200000) - val fd = Range(fds, fds + 100).mkString(",") - (key, fa, fb, fc, fd) - }) - .toDF("fb", "f6", "f7", "f8", "f9") - df2.createOrReplaceTempView("view2") - sparkSession.sql("drop table if exists fres") - sparkSession.sql("create table fres using parquet as select * from view1 a inner join view2 b on a.fa=b.fb where a.fa=1 ") - sparkSession.sql("drop table fres") - sparkSession.stop() + (key, fa, fb, fc, fd) + }) + .toDF("fa", "f1", "f2", "f3", "f4") + df.createOrReplaceTempView("view1") + val df2 = sparkSession.sparkContext.parallelize(1 to 8, 8) + .map(i => { + val random = new Random() + val oriKey = random.nextInt(64) + val key = if (oriKey < 32) 1 else oriKey + val fas = random.nextInt(1200000) + val fa = Range(fas, fas + 100).mkString(",") + val fbs = random.nextInt(1200000) + val fb = Range(fbs, fbs + 100).mkString(",") + val fcs = random.nextInt(1200000) + val fc = Range(fcs, fcs + 100).mkString(",") + val fds = random.nextInt(1200000) + val fd = Range(fds, fds + 100).mkString(",") + (key, fa, fb, fc, fd) + }) + .toDF("fb", "f6", "f7", "f8", "f9") + df2.createOrReplaceTempView("view2") + sparkSession.sql("drop table if exists fres") + sparkSession.sql("create table fres using parquet as select * from view1 a inner join view2 b on a.fa=b.fb where a.fa=1 ") + sparkSession.sql("drop table fres") + sparkSession.stop() + } } } } From 14d5f107ef7e24d64b4b504c2f3d0f6dd0e337bb Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Thu, 18 Apr 2024 14:46:03 +0800 Subject: [PATCH 23/44] update --- .../celeborn/CelebornShuffleReader.scala | 64 ++++++++++++++----- .../celeborn/client/ShuffleClientImpl.java | 2 +- 2 files changed, 48 insertions(+), 18 deletions(-) diff --git a/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala index 23ef03e88a3..9e91c3461ae 100644 --- a/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala +++ b/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala @@ -18,7 +18,7 @@ package org.apache.spark.shuffle.celeborn import java.io.IOException -import java.util +import java.util.{ArrayList => JArrayList, Comparator, Map => JMap, HashMap => JHashMap, Set => JSet} import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor, TimeUnit} import java.util.concurrent.atomic.AtomicReference @@ -118,14 +118,13 @@ class CelebornShuffleReader[K, C]( } // host-port -> (TransportClient, PartitionLocation Array, PbOpenStreamList) - val workerRequestMap = new util.HashMap[ + val workerRequestMap = new JHashMap[ String, - (TransportClient, util.ArrayList[PartitionLocation], PbOpenStreamList.Builder)]() + (TransportClient, JArrayList[PartitionLocation], PbOpenStreamList.Builder)]() // partitionId -> (partition uniqueId -> chunkRange pair) - val partitionId2ChunkRange = - new util.HashMap[Int, util.Map[String, Pair[Integer, Integer]]]() + val partitionId2ChunkRange = new JHashMap[Int, JMap[String, Pair[Integer, Integer]]]() - val partitionId2PartitionLocations = new util.HashMap[Int, util.Set[PartitionLocation]]() + val partitionId2PartitionLocations = new JHashMap[Int, JSet[PartitionLocation]]() var partCnt = 0 @@ -138,9 +137,8 @@ class CelebornShuffleReader[K, C]( if (fileGroups.partitionGroups.containsKey(partitionId)) { var locations = fileGroups.partitionGroups.get(partitionId) if (splitSkewPartitionWithoutMapRange) { - var partitionLocation2ChunkRange: util.Map[String, Pair[Integer, Integer]] = null - partitionLocation2ChunkRange = - splitSkewedPartitionLocations(new util.ArrayList(locations), startMapIndex, endMapIndex) + val partitionLocation2ChunkRange = splitSkewedPartitionLocations( + new JArrayList(locations), startMapIndex, endMapIndex) partitionId2ChunkRange.put(partitionId, partitionLocation2ChunkRange) // filter locations avoid OPEN_STREAM when split skew partition without map range val filterLocations = locations.asScala @@ -164,7 +162,7 @@ class CelebornShuffleReader[K, C]( pbOpenStreamList.setShuffleKey(shuffleKey) workerRequestMap.put( hostPort, - (client, new util.ArrayList[PartitionLocation], pbOpenStreamList)) + (client, new JArrayList[PartitionLocation], pbOpenStreamList)) } catch { case ex: Exception => shuffleClient.excludeFailedFetchLocation(location.hostAndFetchPort, ex) @@ -233,13 +231,13 @@ class CelebornShuffleReader[K, C]( val locationList = if (null == locations) { - new util.ArrayList[PartitionLocation]() + new JArrayList[PartitionLocation]() } else { - new util.ArrayList[PartitionLocation](locations) + new JArrayList[PartitionLocation](locations) } val streamHandlers = if (locations != null) { - val streamHandlerArr = new util.ArrayList[PbStreamHandler](locationList.size) + val streamHandlerArr = new JArrayList[PbStreamHandler](locationList.size) locationList.asScala.foreach { loc => streamHandlerArr.add(locationStreamHandlerMap.get(loc)) } @@ -434,9 +432,9 @@ class CelebornShuffleReader[K, C]( } def splitSkewedPartitionLocations( - locations: util.ArrayList[PartitionLocation], + locations: JArrayList[PartitionLocation], subPartitionSize: Int, - subPartitionIndex: Int): util.Map[String, Pair[Integer, Integer]] = { + subPartitionIndex: Int): JMap[String, Pair[Integer, Integer]] = { locations.sort(CelebornShuffleReader.comparator) val totalPartitionSize = locations.stream.mapToLong((p: PartitionLocation) => p.getStorageInfo.fileSize).sum @@ -450,7 +448,7 @@ class CelebornShuffleReader[K, C]( step * (subPartitionIndex + 1) } var partitionLocationOffset: Long = 0 - val chunkRange = new util.HashMap[String, Pair[Integer, Integer]] + val chunkRange = new JHashMap[String, Pair[Integer, Integer]] for (i <- 0 until locations.size) { val p = locations.get(i) var left = -1 @@ -470,9 +468,41 @@ class CelebornShuffleReader[K, C]( object CelebornShuffleReader { var streamCreatorPool: ThreadPoolExecutor = null - val comparator = new util.Comparator[PartitionLocation] { + val comparator = new Comparator[PartitionLocation] { override def compare(o1: PartitionLocation, o2: PartitionLocation): Int = { o1.getUniqueId.compareTo(o2.getUniqueId) } } + + def splitSkewedPartitionLocations( + locations: JArrayList[PartitionLocation], + subPartitionSize: Int, + subPartitionIndex: Int): JMap[String, Pair[Integer, Integer]] = { + locations.sort(comparator) + val totalPartitionSize = locations.stream.mapToLong((p: PartitionLocation) => p.getStorageInfo.fileSize).sum + val step = totalPartitionSize / subPartitionSize + val startOffset = step * subPartitionIndex + val endOffset = + if (subPartitionIndex == subPartitionSize - 1) { + // last subPartition should include all remaining data + totalPartitionSize + 1 + } else { + step * (subPartitionIndex + 1) + } + var partitionLocationOffset: Long = 0 + val chunkRange = new JHashMap[String, Pair[Integer, Integer]] + for (i <- 0 until locations.size) { + val p = locations.get(i) + var left = -1 + var right = -1 + for (j <- 0 until p.getStorageInfo.getChunkOffsets.size) { + val currentOffset = partitionLocationOffset + p.getStorageInfo.getChunkOffsets.get(j) + if (currentOffset >= startOffset && left < 0) left = j + if (currentOffset < endOffset) right = j + if (left >= 0 && right >= 0) chunkRange.put(p.getUniqueId, Pair.of(left, right)) + } + partitionLocationOffset = partitionLocationOffset + p.getStorageInfo.getFileSize + } + chunkRange + } } diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java index 54c7f0bd2ea..6cd4efa12fe 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java @@ -145,7 +145,7 @@ protected Compressor initialValue() { private final boolean dataPushFailureTrackingEnabled; - protected static class ReduceFileGroups { + public static class ReduceFileGroups { public Map> partitionGroups; public Map> pushFailedBatches; public int[] mapAttempts; From 85f9445ab4136488494f336fbfd03267395cbaa7 Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Thu, 25 Apr 2024 02:38:24 +0800 Subject: [PATCH 24/44] refactor --- client-spark/spark-3-4/pom.xml | 5 - .../celeborn/CelebornShuffleReader.scala | 80 +------- .../celeborn/CelebornPartitionUtilSuiteJ.java | 182 ++++++++++++++++++ project/CelebornBuild.scala | 2 +- 4 files changed, 188 insertions(+), 81 deletions(-) create mode 100644 client-spark/spark-3-4/src/test/java/org/apache/spark/shuffle/celeborn/CelebornPartitionUtilSuiteJ.java diff --git a/client-spark/spark-3-4/pom.xml b/client-spark/spark-3-4/pom.xml index 1f89fad9cd6..8ddfd6f2267 100644 --- a/client-spark/spark-3-4/pom.xml +++ b/client-spark/spark-3-4/pom.xml @@ -91,10 +91,5 @@ mockito-core test - - org.mockito - mockito-inline - test - diff --git a/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala index 9e91c3461ae..a9f21f75d1f 100644 --- a/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala +++ b/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala @@ -18,12 +18,11 @@ package org.apache.spark.shuffle.celeborn import java.io.IOException -import java.util.{ArrayList => JArrayList, Comparator, Map => JMap, HashMap => JHashMap, Set => JSet} +import java.util.{ArrayList => JArrayList, HashMap => JHashMap, Map => JMap, Set => JSet} import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor, TimeUnit} import java.util.concurrent.atomic.AtomicReference import scala.collection.JavaConverters._ -import scala.collection.mutable import org.apache.commons.lang3.tuple.Pair import org.apache.spark.{Aggregator, InterruptibleIterator, ShuffleDependency, TaskContext} @@ -137,8 +136,10 @@ class CelebornShuffleReader[K, C]( if (fileGroups.partitionGroups.containsKey(partitionId)) { var locations = fileGroups.partitionGroups.get(partitionId) if (splitSkewPartitionWithoutMapRange) { - val partitionLocation2ChunkRange = splitSkewedPartitionLocations( - new JArrayList(locations), startMapIndex, endMapIndex) + val partitionLocation2ChunkRange = CelebornPartitionUtil.splitSkewedPartitionLocations( + new JArrayList(locations), + startMapIndex, + endMapIndex) partitionId2ChunkRange.put(partitionId, partitionLocation2ChunkRange) // filter locations avoid OPEN_STREAM when split skew partition without map range val filterLocations = locations.asScala @@ -430,79 +431,8 @@ class CelebornShuffleReader[K, C]( def newSerializerInstance(dep: ShuffleDependency[K, _, C]): SerializerInstance = { dep.serializer.newInstance() } - - def splitSkewedPartitionLocations( - locations: JArrayList[PartitionLocation], - subPartitionSize: Int, - subPartitionIndex: Int): JMap[String, Pair[Integer, Integer]] = { - locations.sort(CelebornShuffleReader.comparator) - val totalPartitionSize = - locations.stream.mapToLong((p: PartitionLocation) => p.getStorageInfo.fileSize).sum - val step = totalPartitionSize / subPartitionSize - val startOffset = step * subPartitionIndex - val endOffset = - if (subPartitionIndex == subPartitionSize - 1) { - // last subPartition should include all remaining data - totalPartitionSize + 1 - } else { - step * (subPartitionIndex + 1) - } - var partitionLocationOffset: Long = 0 - val chunkRange = new JHashMap[String, Pair[Integer, Integer]] - for (i <- 0 until locations.size) { - val p = locations.get(i) - var left = -1 - var right = -1 - for (j <- 0 until p.getStorageInfo.getChunkOffsets.size) { - val currentOffset = partitionLocationOffset + p.getStorageInfo.getChunkOffsets.get(j) - if (currentOffset >= startOffset && left < 0) left = j - if (currentOffset < endOffset) right = j - if (left >= 0 && right >= 0) chunkRange.put(p.getUniqueId, Pair.of(left, right)) - } - partitionLocationOffset = partitionLocationOffset + p.getStorageInfo.getFileSize - } - chunkRange - } - } object CelebornShuffleReader { var streamCreatorPool: ThreadPoolExecutor = null - val comparator = new Comparator[PartitionLocation] { - override def compare(o1: PartitionLocation, o2: PartitionLocation): Int = { - o1.getUniqueId.compareTo(o2.getUniqueId) - } - } - - def splitSkewedPartitionLocations( - locations: JArrayList[PartitionLocation], - subPartitionSize: Int, - subPartitionIndex: Int): JMap[String, Pair[Integer, Integer]] = { - locations.sort(comparator) - val totalPartitionSize = locations.stream.mapToLong((p: PartitionLocation) => p.getStorageInfo.fileSize).sum - val step = totalPartitionSize / subPartitionSize - val startOffset = step * subPartitionIndex - val endOffset = - if (subPartitionIndex == subPartitionSize - 1) { - // last subPartition should include all remaining data - totalPartitionSize + 1 - } else { - step * (subPartitionIndex + 1) - } - var partitionLocationOffset: Long = 0 - val chunkRange = new JHashMap[String, Pair[Integer, Integer]] - for (i <- 0 until locations.size) { - val p = locations.get(i) - var left = -1 - var right = -1 - for (j <- 0 until p.getStorageInfo.getChunkOffsets.size) { - val currentOffset = partitionLocationOffset + p.getStorageInfo.getChunkOffsets.get(j) - if (currentOffset >= startOffset && left < 0) left = j - if (currentOffset < endOffset) right = j - if (left >= 0 && right >= 0) chunkRange.put(p.getUniqueId, Pair.of(left, right)) - } - partitionLocationOffset = partitionLocationOffset + p.getStorageInfo.getFileSize - } - chunkRange - } } diff --git a/client-spark/spark-3-4/src/test/java/org/apache/spark/shuffle/celeborn/CelebornPartitionUtilSuiteJ.java b/client-spark/spark-3-4/src/test/java/org/apache/spark/shuffle/celeborn/CelebornPartitionUtilSuiteJ.java new file mode 100644 index 00000000000..9f84b576e2b --- /dev/null +++ b/client-spark/spark-3-4/src/test/java/org/apache/spark/shuffle/celeborn/CelebornPartitionUtilSuiteJ.java @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.celeborn; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Map; + +import org.apache.commons.lang3.tuple.Pair; +import org.junit.Assert; +import org.junit.Test; + +import org.apache.celeborn.common.protocol.PartitionLocation; +import org.apache.celeborn.common.protocol.StorageInfo; + +public class CelebornPartitionUtilSuiteJ { + @Test + public void testSkewPartitionSplit() { + + ArrayList locations = new ArrayList<>(); + for (int i = 0; i < 13; i++) { + PartitionLocation location = + new PartitionLocation(0, i, "localhost", 0, 0, 0, 0, PartitionLocation.Mode.PRIMARY); + StorageInfo storageInfo = + new StorageInfo( + StorageInfo.Type.HDD, + "mountPoint", + false, + "filePath", + StorageInfo.LOCAL_DISK_MASK, + 1000, + Arrays.asList(0L, 100L, 200L, 300L, 500L, 1000L)); + location.setStorageInfo(storageInfo); + locations.add(location); + } + + PartitionLocation location = + new PartitionLocation(0, 91, "localhost", 0, 0, 0, 0, PartitionLocation.Mode.PRIMARY); + StorageInfo storageInfo = + new StorageInfo( + StorageInfo.Type.HDD, + "mountPoint", + false, + "filePath", + StorageInfo.LOCAL_DISK_MASK, + 1, + Arrays.asList(0L, 1L)); + location.setStorageInfo(storageInfo); + locations.add(location); + + int subPartitionSize = 3; + + int subPartitionIndex = 0; + Map> result1 = + CelebornPartitionUtil.splitSkewedPartitionLocations( + locations, subPartitionSize, subPartitionIndex); + Map> expectResult1 = + Map.ofEntries( + Map.entry("0-0", Pair.of(0, 4)), + Map.entry("0-1", Pair.of(0, 4)), + Map.entry("0-10", Pair.of(0, 4)), + Map.entry("0-11", Pair.of(0, 4)), + Map.entry("0-12", Pair.of(0, 2))); + Assert.assertEquals(expectResult1, result1); + + subPartitionIndex = 1; + Map> result2 = + CelebornPartitionUtil.splitSkewedPartitionLocations( + locations, subPartitionSize, subPartitionIndex); + Map> expectResult2 = + Map.ofEntries( + Map.entry("0-12", Pair.of(3, 4)), + Map.entry("0-2", Pair.of(0, 4)), + Map.entry("0-3", Pair.of(0, 4)), + Map.entry("0-4", Pair.of(0, 4)), + Map.entry("0-5", Pair.of(0, 3))); + Assert.assertEquals(expectResult2, result2); + + subPartitionIndex = 2; + Map> result3 = + CelebornPartitionUtil.splitSkewedPartitionLocations( + locations, subPartitionSize, subPartitionIndex); + Map> expectResult3 = + Map.ofEntries( + Map.entry("0-5", Pair.of(4, 4)), + Map.entry("0-6", Pair.of(0, 4)), + Map.entry("0-7", Pair.of(0, 4)), + Map.entry("0-8", Pair.of(0, 4)), + Map.entry("0-9", Pair.of(0, 4)), + Map.entry("0-91", Pair.of(0, 0))); + Assert.assertEquals(expectResult3, result3); + } + + @Test + public void testBoundary() { + ArrayList locations = new ArrayList<>(); + PartitionLocation location = + new PartitionLocation(0, 0, "localhost", 0, 0, 0, 0, PartitionLocation.Mode.PRIMARY); + StorageInfo storageInfo = + new StorageInfo( + StorageInfo.Type.HDD, + "mountPoint", + false, + "filePath", + StorageInfo.LOCAL_DISK_MASK, + 500, + Arrays.asList(0L, 100L, 200L, 300L, 400L, 500L)); + location.setStorageInfo(storageInfo); + locations.add(location); + + for (int i = 0; i < 5; i++) { + Map> result = + CelebornPartitionUtil.splitSkewedPartitionLocations(locations, 5, i); + Map> expectResult = + Map.ofEntries(Map.entry("0-0", Pair.of(i, i))); + Assert.assertEquals(expectResult, result); + } + } + + @Test + public void testSplitStable() { + ArrayList locations = new ArrayList<>(); + for (int i = 0; i < 13; i++) { + PartitionLocation location = + new PartitionLocation(0, i, "localhost", 0, 0, 0, 0, PartitionLocation.Mode.PRIMARY); + StorageInfo storageInfo = + new StorageInfo( + StorageInfo.Type.HDD, + "mountPoint", + false, + "filePath", + StorageInfo.LOCAL_DISK_MASK, + 1000, + Arrays.asList(0L, 100L, 200L, 300L, 500L, 1000L)); + location.setStorageInfo(storageInfo); + locations.add(location); + } + + PartitionLocation location = + new PartitionLocation(0, 91, "localhost", 0, 0, 0, 0, PartitionLocation.Mode.PRIMARY); + StorageInfo storageInfo = + new StorageInfo( + StorageInfo.Type.HDD, + "mountPoint", + false, + "filePath", + StorageInfo.LOCAL_DISK_MASK, + 1, + Arrays.asList(0L, 1L)); + location.setStorageInfo(storageInfo); + locations.add(location); + + Collections.shuffle(locations); + + Map> result = + CelebornPartitionUtil.splitSkewedPartitionLocations(locations, 3, 0); + Map> expectResult = + Map.ofEntries( + Map.entry("0-0", Pair.of(0, 4)), + Map.entry("0-1", Pair.of(0, 4)), + Map.entry("0-10", Pair.of(0, 4)), + Map.entry("0-11", Pair.of(0, 4)), + Map.entry("0-12", Pair.of(0, 2))); + Assert.assertEquals(expectResult, result); + } +} diff --git a/project/CelebornBuild.scala b/project/CelebornBuild.scala index 16f14f7044f..63fe4249c1a 100644 --- a/project/CelebornBuild.scala +++ b/project/CelebornBuild.scala @@ -942,7 +942,7 @@ trait SparkClientProjects { libraryDependencies ++= Seq( "org.apache.spark" %% "spark-core" % sparkVersion % "provided", "org.apache.spark" %% "spark-sql" % sparkVersion % "provided", - ) ++ commonUnitTestDependencies ++ Seq(Dependencies.mockitoInline % "test") + ) ++ commonUnitTestDependencies ) } From 697b468c749c38030c68884f401de1cce4bb07d5 Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Thu, 25 Apr 2024 02:41:42 +0800 Subject: [PATCH 25/44] add CelebornPartitionUtil.java --- .../celeborn/CelebornPartitionUtil.java | 68 +++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/CelebornPartitionUtil.java diff --git a/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/CelebornPartitionUtil.java b/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/CelebornPartitionUtil.java new file mode 100644 index 00000000000..52a912d081c --- /dev/null +++ b/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/CelebornPartitionUtil.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.celeborn; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.HashMap; +import java.util.Map; + +import org.apache.commons.lang3.tuple.Pair; + +import org.apache.celeborn.common.protocol.PartitionLocation; + +public class CelebornPartitionUtil { + public static Map> splitSkewedPartitionLocations( + ArrayList locations, int subPartitionSize, int subPartitionIndex) { + locations.sort(Comparator.comparing((PartitionLocation p) -> p.getUniqueId())); + long totalPartitionSize = + locations.stream().mapToLong((PartitionLocation p) -> p.getStorageInfo().fileSize).sum(); + long step = totalPartitionSize / subPartitionSize; + long startOffset = step * subPartitionIndex; + long endOffset = 0; + if (subPartitionIndex == subPartitionSize - 1) { + // last subPartition should include all remaining data + endOffset = totalPartitionSize + 1; + } else { + endOffset = step * (subPartitionIndex + 1); + } + + long partitionLocationOffset = 0; + Map> chunkRange = new HashMap<>(); + for (int i = 0; i < locations.size(); i++) { + PartitionLocation p = locations.get(i); + int left = -1; + int right = -1; + // Start from index 1 since the first chunk offset is always 0. + for (int j = 1; j < p.getStorageInfo().getChunkOffsets().size(); j++) { + long currentOffset = partitionLocationOffset + p.getStorageInfo().getChunkOffsets().get(j); + if (currentOffset > startOffset && left < 0) { + left = j - 1; + } + if (currentOffset <= endOffset) { + right = j - 1; + } + if (left >= 0 && right >= 0) { + chunkRange.put(p.getUniqueId(), Pair.of(left, right)); + } + } + partitionLocationOffset += p.getStorageInfo().getFileSize(); + } + return chunkRange; + } +} From 2b6349d3553519ea2433b4c28ee7f840628209dd Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Fri, 26 Apr 2024 01:17:36 +0800 Subject: [PATCH 26/44] fix ut for jdk 8 --- .../celeborn/CelebornPartitionUtilSuiteJ.java | 190 ++++++++---------- 1 file changed, 87 insertions(+), 103 deletions(-) diff --git a/client-spark/spark-3-4/src/test/java/org/apache/spark/shuffle/celeborn/CelebornPartitionUtilSuiteJ.java b/client-spark/spark-3-4/src/test/java/org/apache/spark/shuffle/celeborn/CelebornPartitionUtilSuiteJ.java index 9f84b576e2b..989dd31a94a 100644 --- a/client-spark/spark-3-4/src/test/java/org/apache/spark/shuffle/celeborn/CelebornPartitionUtilSuiteJ.java +++ b/client-spark/spark-3-4/src/test/java/org/apache/spark/shuffle/celeborn/CelebornPartitionUtilSuiteJ.java @@ -17,10 +17,7 @@ package org.apache.spark.shuffle.celeborn; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.Map; +import java.util.*; import org.apache.commons.lang3.tuple.Pair; import org.junit.Assert; @@ -35,100 +32,62 @@ public void testSkewPartitionSplit() { ArrayList locations = new ArrayList<>(); for (int i = 0; i < 13; i++) { - PartitionLocation location = - new PartitionLocation(0, i, "localhost", 0, 0, 0, 0, PartitionLocation.Mode.PRIMARY); - StorageInfo storageInfo = - new StorageInfo( - StorageInfo.Type.HDD, - "mountPoint", - false, - "filePath", - StorageInfo.LOCAL_DISK_MASK, - 1000, - Arrays.asList(0L, 100L, 200L, 300L, 500L, 1000L)); - location.setStorageInfo(storageInfo); - locations.add(location); + locations.add(genPartitionLocation(i, new Long[] {0L, 100L, 200L, 300L, 500L, 1000L})); } - - PartitionLocation location = - new PartitionLocation(0, 91, "localhost", 0, 0, 0, 0, PartitionLocation.Mode.PRIMARY); - StorageInfo storageInfo = - new StorageInfo( - StorageInfo.Type.HDD, - "mountPoint", - false, - "filePath", - StorageInfo.LOCAL_DISK_MASK, - 1, - Arrays.asList(0L, 1L)); - location.setStorageInfo(storageInfo); - locations.add(location); + locations.add(genPartitionLocation(91, new Long[] {0L, 1L})); int subPartitionSize = 3; - int subPartitionIndex = 0; Map> result1 = - CelebornPartitionUtil.splitSkewedPartitionLocations( - locations, subPartitionSize, subPartitionIndex); + CelebornPartitionUtil.splitSkewedPartitionLocations(locations, subPartitionSize, 0); Map> expectResult1 = - Map.ofEntries( - Map.entry("0-0", Pair.of(0, 4)), - Map.entry("0-1", Pair.of(0, 4)), - Map.entry("0-10", Pair.of(0, 4)), - Map.entry("0-11", Pair.of(0, 4)), - Map.entry("0-12", Pair.of(0, 2))); + genRanges( + new Object[][] { + {"0-0", 0, 4}, + {"0-1", 0, 4}, + {"0-10", 0, 4}, + {"0-11", 0, 4}, + {"0-12", 0, 2} + }); Assert.assertEquals(expectResult1, result1); - subPartitionIndex = 1; Map> result2 = - CelebornPartitionUtil.splitSkewedPartitionLocations( - locations, subPartitionSize, subPartitionIndex); + CelebornPartitionUtil.splitSkewedPartitionLocations(locations, subPartitionSize, 1); Map> expectResult2 = - Map.ofEntries( - Map.entry("0-12", Pair.of(3, 4)), - Map.entry("0-2", Pair.of(0, 4)), - Map.entry("0-3", Pair.of(0, 4)), - Map.entry("0-4", Pair.of(0, 4)), - Map.entry("0-5", Pair.of(0, 3))); + genRanges( + new Object[][] { + {"0-12", 3, 4}, + {"0-2", 0, 4}, + {"0-3", 0, 4}, + {"0-4", 0, 4}, + {"0-5", 0, 3} + }); Assert.assertEquals(expectResult2, result2); - subPartitionIndex = 2; Map> result3 = - CelebornPartitionUtil.splitSkewedPartitionLocations( - locations, subPartitionSize, subPartitionIndex); + CelebornPartitionUtil.splitSkewedPartitionLocations(locations, subPartitionSize, 2); Map> expectResult3 = - Map.ofEntries( - Map.entry("0-5", Pair.of(4, 4)), - Map.entry("0-6", Pair.of(0, 4)), - Map.entry("0-7", Pair.of(0, 4)), - Map.entry("0-8", Pair.of(0, 4)), - Map.entry("0-9", Pair.of(0, 4)), - Map.entry("0-91", Pair.of(0, 0))); + genRanges( + new Object[][] { + {"0-5", 4, 4}, + {"0-6", 0, 4}, + {"0-7", 0, 4}, + {"0-8", 0, 4}, + {"0-9", 0, 4}, + {"0-91", 0, 0} + }); Assert.assertEquals(expectResult3, result3); } @Test public void testBoundary() { ArrayList locations = new ArrayList<>(); - PartitionLocation location = - new PartitionLocation(0, 0, "localhost", 0, 0, 0, 0, PartitionLocation.Mode.PRIMARY); - StorageInfo storageInfo = - new StorageInfo( - StorageInfo.Type.HDD, - "mountPoint", - false, - "filePath", - StorageInfo.LOCAL_DISK_MASK, - 500, - Arrays.asList(0L, 100L, 200L, 300L, 400L, 500L)); - location.setStorageInfo(storageInfo); - locations.add(location); + locations.add(genPartitionLocation(0, new Long[] {0L, 100L, 200L, 300L, 400L, 500L})); for (int i = 0; i < 5; i++) { Map> result = CelebornPartitionUtil.splitSkewedPartitionLocations(locations, 5, i); - Map> expectResult = - Map.ofEntries(Map.entry("0-0", Pair.of(i, i))); + Map> expectResult = genRanges(new Object[][] {{"0-0", i, i}}); Assert.assertEquals(expectResult, result); } } @@ -137,23 +96,51 @@ public void testBoundary() { public void testSplitStable() { ArrayList locations = new ArrayList<>(); for (int i = 0; i < 13; i++) { - PartitionLocation location = - new PartitionLocation(0, i, "localhost", 0, 0, 0, 0, PartitionLocation.Mode.PRIMARY); - StorageInfo storageInfo = - new StorageInfo( - StorageInfo.Type.HDD, - "mountPoint", - false, - "filePath", - StorageInfo.LOCAL_DISK_MASK, - 1000, - Arrays.asList(0L, 100L, 200L, 300L, 500L, 1000L)); - location.setStorageInfo(storageInfo); - locations.add(location); + locations.add(genPartitionLocation(i, new Long[] {0L, 100L, 200L, 300L, 500L, 1000L})); } + locations.add(genPartitionLocation(91, new Long[] {0L, 1L})); + Collections.shuffle(locations); + + Map> result = + CelebornPartitionUtil.splitSkewedPartitionLocations(locations, 3, 0); + Map> expectResult = + genRanges( + new Object[][] { + {"0-0", 0, 4}, + {"0-1", 0, 4}, + {"0-10", 0, 4}, + {"0-11", 0, 4}, + {"0-12", 0, 2} + }); + Assert.assertEquals(expectResult, result); + } + + private ArrayList genPartitionLocations(Map epochToOffsets) { + ArrayList locations = new ArrayList<>(); + epochToOffsets.forEach( + (epoch, offsets) -> { + PartitionLocation location = + new PartitionLocation( + 0, epoch, "localhost", 0, 0, 0, 0, PartitionLocation.Mode.PRIMARY); + StorageInfo storageInfo = + new StorageInfo( + StorageInfo.Type.HDD, + "mountPoint", + false, + "filePath", + StorageInfo.LOCAL_DISK_MASK, + 1, + Arrays.asList(offsets)); + location.setStorageInfo(storageInfo); + locations.add(location); + }); + return locations; + } + + private PartitionLocation genPartitionLocation(int epoch, Long[] offsets) { PartitionLocation location = - new PartitionLocation(0, 91, "localhost", 0, 0, 0, 0, PartitionLocation.Mode.PRIMARY); + new PartitionLocation(0, epoch, "localhost", 0, 0, 0, 0, PartitionLocation.Mode.PRIMARY); StorageInfo storageInfo = new StorageInfo( StorageInfo.Type.HDD, @@ -161,22 +148,19 @@ public void testSplitStable() { false, "filePath", StorageInfo.LOCAL_DISK_MASK, - 1, - Arrays.asList(0L, 1L)); + offsets[offsets.length - 1], + Arrays.asList(offsets)); location.setStorageInfo(storageInfo); - locations.add(location); - - Collections.shuffle(locations); + return location; + } - Map> result = - CelebornPartitionUtil.splitSkewedPartitionLocations(locations, 3, 0); - Map> expectResult = - Map.ofEntries( - Map.entry("0-0", Pair.of(0, 4)), - Map.entry("0-1", Pair.of(0, 4)), - Map.entry("0-10", Pair.of(0, 4)), - Map.entry("0-11", Pair.of(0, 4)), - Map.entry("0-12", Pair.of(0, 2))); - Assert.assertEquals(expectResult, result); + private Map> genRanges(Object[][] inputs) { + Map> ranges = new HashMap<>(); + for (Object[] idToChunkRange : inputs) { + String uid = (String) idToChunkRange[0]; + Pair range = Pair.of((int) idToChunkRange[1], (int) idToChunkRange[2]); + ranges.put(uid, range); + } + return ranges; } } From 8155619e97d35670f05cad842172acd522bf4190 Mon Sep 17 00:00:00 2001 From: wangshengjie3 Date: Sun, 2 Jun 2024 22:27:01 +0800 Subject: [PATCH 27/44] rebase main and fix npe --- .../celeborn/common/protocol/StorageInfo.java | 4 + common/src/main/proto/TransportMessages.proto | 6 ++ .../celeborn/common/util/PbSerDeUtils.scala | 12 ++- .../common/util/PbSerDeUtilsTest.scala | 98 ++++++++++++++++++- 4 files changed, 117 insertions(+), 3 deletions(-) diff --git a/common/src/main/java/org/apache/celeborn/common/protocol/StorageInfo.java b/common/src/main/java/org/apache/celeborn/common/protocol/StorageInfo.java index 038219b3389..8509d5717b2 100644 --- a/common/src/main/java/org/apache/celeborn/common/protocol/StorageInfo.java +++ b/common/src/main/java/org/apache/celeborn/common/protocol/StorageInfo.java @@ -168,6 +168,10 @@ public String toString() { + finalResult + ", filePath=" + filePath + + ", fileSize=" + + fileSize + + ", chunkOffsets=" + + chunkOffsets + '}'; } diff --git a/common/src/main/proto/TransportMessages.proto b/common/src/main/proto/TransportMessages.proto index f8dd0824223..22d2a6150f1 100644 --- a/common/src/main/proto/TransportMessages.proto +++ b/common/src/main/proto/TransportMessages.proto @@ -867,6 +867,8 @@ message PbPackedPartitionLocations { repeated string filePaths = 10; repeated int32 availableStorageTypes = 11; repeated int32 modes = 12; + repeated int64 fileSizes = 13; + repeated PbChunkOffsets chunksOffsets = 14; } message PbPackedPartitionLocationsPair { @@ -900,3 +902,7 @@ message PbPushMergedDataSplitPartitionInfo { repeated int32 splitPartitionIndexes = 1; repeated int32 statusCodes = 2; } + +message PbChunkOffsets { + repeated int64 chunkOffset = 1; +} diff --git a/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala b/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala index 87766e71e49..553a4b9f9e0 100644 --- a/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala +++ b/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala @@ -525,6 +525,14 @@ object PbSerDeUtils { pbPackedLocationsBuilder.addFilePaths("") } pbPackedLocationsBuilder.addAvailableStorageTypes(location.getStorageInfo.availableStorageTypes) + pbPackedLocationsBuilder.addFileSizes(location.getStorageInfo.getFileSize) + val chunkOffsets = PbChunkOffsets.newBuilder() + if (null != location.getStorageInfo.chunkOffsets && !location.getStorageInfo.chunkOffsets.isEmpty) { + chunkOffsets.addAllChunkOffset(location.getStorageInfo.chunkOffsets).build() + pbPackedLocationsBuilder.addChunksOffsets(chunkOffsets) + } else { + pbPackedLocationsBuilder.addChunksOffsets(chunkOffsets.build()) + } pbPackedLocationsBuilder.addModes(location.getMode.mode()) } @@ -641,7 +649,9 @@ object PbSerDeUtils { pbPackedPartitionLocations.getMountPoints(index)), pbPackedPartitionLocations.getFinalResult(index), filePath, - pbPackedPartitionLocations.getAvailableStorageTypes(index)), + pbPackedPartitionLocations.getAvailableStorageTypes(index), + pbPackedPartitionLocations.getFileSizes(index), + pbPackedPartitionLocations.getChunksOffsets(index).getChunkOffsetList), Utils.byteStringToRoaringBitmap(pbPackedPartitionLocations.getMapIdBitMap(index))) } diff --git a/common/src/test/scala/org/apache/celeborn/common/util/PbSerDeUtilsTest.scala b/common/src/test/scala/org/apache/celeborn/common/util/PbSerDeUtilsTest.scala index 280e459fc77..c2c6e47338b 100644 --- a/common/src/test/scala/org/apache/celeborn/common/util/PbSerDeUtilsTest.scala +++ b/common/src/test/scala/org/apache/celeborn/common/util/PbSerDeUtilsTest.scala @@ -24,10 +24,9 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.Random +import com.google.common.collect.{Lists, Sets} import org.apache.hadoop.shaded.org.apache.commons.lang3.RandomStringUtils -import com.google.common.collect.Sets - import org.apache.celeborn.CelebornFunSuite import org.apache.celeborn.common.identity.UserIdentifier import org.apache.celeborn.common.meta._ @@ -175,6 +174,37 @@ class PbSerDeUtilsTest extends CelebornFunSuite { 27, PartitionLocation.Mode.PRIMARY) + val partitionLocation5 = + new PartitionLocation( + 4, + 4, + "host5", + 50, + 49, + 48, + 47, + PartitionLocation.Mode.PRIMARY) + val partitionLocation6 = + new PartitionLocation( + 5, + 5, + "host6", + 60, + 59, + 58, + 57, + PartitionLocation.Mode.REPLICA, + null, + new StorageInfo( + StorageInfo.Type.HDD, + "", + false, + null, + StorageInfo.LOCAL_DISK_MASK, + 5, + null), + null) + val workerResource = new WorkerResource() workerResource.put( workerInfo1, @@ -372,6 +402,70 @@ class PbSerDeUtilsTest extends CelebornFunSuite { assert(partitionLocation3 == loc1) } + test("testPackedPartitionLocationPairCase3") { + partitionLocation5.setStorageInfo(new StorageInfo( + StorageInfo.Type.HDD, + "", + false, + null, + StorageInfo.LOCAL_DISK_MASK, + 5, + Lists.newArrayList(0, 5, 10))) + partitionLocation5.setPeer(partitionLocation6) + val pairPb = PbSerDeUtils.toPbPackedPartitionLocationsPair( + List(partitionLocation5, partitionLocation6)) + val rePb = PbSerDeUtils.fromPbPackedPartitionLocationsPair(pairPb) + + val loc1 = rePb._1.get(0) + val loc2 = rePb._2.get(0) + + assert(partitionLocation5 == loc1) + assert(partitionLocation6 == loc2) + assert(loc1.getStorageInfo.getFileSize == partitionLocation5.getStorageInfo.getFileSize) + assert(loc1.getStorageInfo.getChunkOffsets == partitionLocation5.getStorageInfo.getChunkOffsets) + + assert(loc2.getStorageInfo.getFileSize == partitionLocation6.getStorageInfo.getFileSize) + assert(loc2.getStorageInfo.getChunkOffsets.isEmpty) + } + + test("testPackedPartitionLocationPairCase4") { + partitionLocation5.setStorageInfo(new StorageInfo( + StorageInfo.Type.HDD, + "", + false, + null, + StorageInfo.LOCAL_DISK_MASK, + 5, + null)) + val pairPb = PbSerDeUtils.toPbPackedPartitionLocationsPair( + List(partitionLocation5)) + val rePb = PbSerDeUtils.fromPbPackedPartitionLocationsPair(pairPb) + + val loc1 = rePb._1.get(0) + + assert(partitionLocation5 == loc1) + assert(loc1.getStorageInfo.getFileSize == partitionLocation5.getStorageInfo.getFileSize) + assert(loc1.getStorageInfo.getChunkOffsets.isEmpty) + } + + test("testPackedPartitionLocationPairCase5") { + partitionLocation5.setStorageInfo(new StorageInfo( + StorageInfo.Type.HDD, + "", + false, + null, + StorageInfo.LOCAL_DISK_MASK)) + val pairPb = PbSerDeUtils.toPbPackedPartitionLocationsPair( + List(partitionLocation5)) + val rePb = PbSerDeUtils.fromPbPackedPartitionLocationsPair(pairPb) + + val loc1 = rePb._1.get(0) + + assert(partitionLocation5 == loc1) + assert(loc1.getStorageInfo.getFileSize == partitionLocation5.getStorageInfo.getFileSize) + assert(loc1.getStorageInfo.getChunkOffsets.isEmpty) + } + test("testPackedPartitionLocationPairIPv6") { val pairPb = PbSerDeUtils.toPbPackedPartitionLocationsPair( List(partitionLocationIPv6)) From aa30ff62d0b60d6ee7183ec42f2c072929e19b58 Mon Sep 17 00:00:00 2001 From: wangshengjie Date: Mon, 3 Jun 2024 09:41:50 +0800 Subject: [PATCH 28/44] fix ut --- .../celeborn/common/protocol/PartitionLocationSuiteJ.java | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/common/src/test/java/org/apache/celeborn/common/protocol/PartitionLocationSuiteJ.java b/common/src/test/java/org/apache/celeborn/common/protocol/PartitionLocationSuiteJ.java index 6f63e487056..0d0613a33ec 100644 --- a/common/src/test/java/org/apache/celeborn/common/protocol/PartitionLocationSuiteJ.java +++ b/common/src/test/java/org/apache/celeborn/common/protocol/PartitionLocationSuiteJ.java @@ -209,7 +209,7 @@ public void testToStringOutput() { + " host-rpcPort-pushPort-fetchPort-replicatePort:localhost-3-1-2-4\n" + " mode:PRIMARY\n" + " peer:(empty)\n" - + " storage hint:StorageInfo{type=MEMORY, mountPoint='', finalResult=false, filePath=null}\n" + + " storage hint:StorageInfo{type=MEMORY, mountPoint='', finalResult=false, filePath=null, fileSize=0, chunkOffsets=null}\n" + " mapIdBitMap:{}]"; String exp2 = "PartitionLocation[\n" @@ -217,7 +217,7 @@ public void testToStringOutput() { + " host-rpcPort-pushPort-fetchPort-replicatePort:localhost-3-1-2-4\n" + " mode:PRIMARY\n" + " peer:(host-rpcPort-pushPort-fetchPort-replicatePort:localhost-3-1-2-4)\n" - + " storage hint:StorageInfo{type=MEMORY, mountPoint='', finalResult=false, filePath=null}\n" + + " storage hint:StorageInfo{type=MEMORY, mountPoint='', finalResult=false, filePath=null, fileSize=0, chunkOffsets=null}\n" + " mapIdBitMap:{}]"; String exp3 = "PartitionLocation[\n" @@ -225,7 +225,8 @@ public void testToStringOutput() { + " host-rpcPort-pushPort-fetchPort-replicatePort:localhost-3-1-2-4\n" + " mode:PRIMARY\n" + " peer:(host-rpcPort-pushPort-fetchPort-replicatePort:localhost-3-1-2-4)\n" - + " storage hint:StorageInfo{type=MEMORY, mountPoint='/mnt/disk/0', finalResult=false, filePath=null}\n" + + " storage hint:StorageInfo{type=MEMORY, mountPoint='/mnt/disk/0', " + + "finalResult=false, filePath=null, fileSize=0, chunkOffsets=null}\n" + " mapIdBitMap:{1,2,3}]"; assertEquals(exp1, location1.toString()); assertEquals(exp2, location2.toString()); From 6b4bb4058e8304c0b549078a0e3a44a9ea24c7ca Mon Sep 17 00:00:00 2001 From: wangshengjie Date: Mon, 3 Jun 2024 11:51:06 +0800 Subject: [PATCH 29/44] fix npe when memory storage enabled --- .../apache/celeborn/service/deploy/worker/Controller.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala index 53b73da91d0..e6e8bed5e4d 100644 --- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala +++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala @@ -338,7 +338,12 @@ private[deploy] class Controller( logDebug(s"Location $uniqueId is deleted.") } else { val storageInfo = fileWriter.getStorageInfo - val fileMeta = fileWriter.getDiskFileInfo.getFileMeta + val fileInfo = if (null != fileWriter.getDiskFileInfo) { + fileWriter.getDiskFileInfo + } else { + fileWriter.getMemoryFileInfo + } + val fileMeta = fileInfo.getFileMeta fileMeta match { case meta: ReduceFileMeta => storageInfo.setFileSize(bytes) From 59bef4588704cd76a44759cefb785e9982b8acc1 Mon Sep 17 00:00:00 2001 From: wangshengjie Date: Mon, 3 Jun 2024 11:55:16 +0800 Subject: [PATCH 30/44] fix code style check error --- .../celeborn/service/deploy/worker/Controller.scala | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala index e6e8bed5e4d..32b6523af54 100644 --- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala +++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala @@ -338,11 +338,12 @@ private[deploy] class Controller( logDebug(s"Location $uniqueId is deleted.") } else { val storageInfo = fileWriter.getStorageInfo - val fileInfo = if (null != fileWriter.getDiskFileInfo) { - fileWriter.getDiskFileInfo - } else { - fileWriter.getMemoryFileInfo - } + val fileInfo = + if (null != fileWriter.getDiskFileInfo) { + fileWriter.getDiskFileInfo + } else { + fileWriter.getMemoryFileInfo + } val fileMeta = fileInfo.getFileMeta fileMeta match { case meta: ReduceFileMeta => From 031fdb3bc416d04822ca9dd75cd44a2bb0296f2f Mon Sep 17 00:00:00 2001 From: wangshengjie3 Date: Mon, 14 Oct 2024 16:22:32 +0800 Subject: [PATCH 31/44] fix comile error --- .../org/apache/celeborn/service/deploy/worker/Controller.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala index 32b6523af54..b2e7c4e844c 100644 --- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala +++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala @@ -32,7 +32,7 @@ import org.roaringbitmap.RoaringBitmap import org.apache.celeborn.common.CelebornConf import org.apache.celeborn.common.identity.UserIdentifier import org.apache.celeborn.common.internal.Logging -import org.apache.celeborn.common.meta.{WorkerInfo, WorkerPartitionLocationInfo} +import org.apache.celeborn.common.meta.{ReduceFileMeta, WorkerInfo, WorkerPartitionLocationInfo} import org.apache.celeborn.common.metrics.MetricsSystem import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionSplitMode, PartitionType, StorageInfo} import org.apache.celeborn.common.protocol.message.ControlMessages._ From 3966f490d1e25894efc143514d6da9b11815cdfe Mon Sep 17 00:00:00 2001 From: wangshengjie3 Date: Tue, 26 Nov 2024 21:11:12 +0800 Subject: [PATCH 32/44] add spark patch --- ...rn-Optimize-Skew-Partitions-spark3_2.patch | 237 ++++++++++++++++++ ...rn-Optimize-Skew-Partitions-spark3_3.patch | 199 ++++++++++++--- ...rn-Optimize-Skew-Partitions-spark3_4.patch | 233 +++++++++++++++++ ...rn-Optimize-Skew-Partitions-spark3_5.patch | 233 +++++++++++++++++ 4 files changed, 868 insertions(+), 34 deletions(-) create mode 100644 assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_2.patch create mode 100644 assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_4.patch create mode 100644 assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_5.patch diff --git a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_2.patch b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_2.patch new file mode 100644 index 00000000000..eadcfa4668e --- /dev/null +++ b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_2.patch @@ -0,0 +1,237 @@ +diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +index e469c9989f2..245d9b3b9de 100644 +--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala ++++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +@@ -661,6 +661,8 @@ private[spark] class MapOutputTrackerMaster( + pool + } + ++ val skewShuffleIds = ConcurrentHashMap.newKeySet[Int]() ++ + // Make sure that we aren't going to exceed the max RPC message size by making sure + // we use broadcast to send large map output statuses. + if (minSizeForBroadcast > maxRpcMessageSize) { +@@ -839,6 +841,7 @@ private[spark] class MapOutputTrackerMaster( + shuffleStatus.invalidateSerializedMergeOutputStatusCache() + } + } ++ skewShuffleIds.remove(shuffleId) + } + + /** +diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +index b950c07f3d8..e9e10bb647f 100644 +--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala ++++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +@@ -1369,7 +1369,10 @@ private[spark] class DAGScheduler( + // The operation here can make sure for the partially completed intermediate stage, + // `findMissingPartitions()` returns all partitions every time. + stage match { +- case sms: ShuffleMapStage if stage.isIndeterminate && !sms.isAvailable => ++ case sms: ShuffleMapStage if (stage.isIndeterminate && !sms.isAvailable) || ++ mapOutputTracker.skewShuffleIds.contains(sms.shuffleDep.shuffleId) => ++ logInfo(s"Unregistering shuffle output for stage ${stage.id}" + ++ s" shuffle ${sms.shuffleDep.shuffleId}") + mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId) + sms.shuffleDep.newShuffleMergeState() + case _ => +@@ -1780,7 +1783,8 @@ private[spark] class DAGScheduler( + failedStage.failedAttemptIds.add(task.stageAttemptId) + val shouldAbortStage = + failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts || +- disallowStageRetryForTest ++ disallowStageRetryForTest || (failedStage.isInstanceOf[ResultStage] ++ && mapOutputTracker.skewShuffleIds.contains(shuffleId)) + + // It is likely that we receive multiple FetchFailed for a single stage (because we have + // multiple tasks running concurrently on different executors). In that case, it is +diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +index 6bc8ba4eebb..2e7d87c96eb 100644 +--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala ++++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +@@ -3431,6 +3431,12 @@ object SQLConf { + .booleanConf + .createWithDefault(false) + ++ val CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ = ++ buildConf("spark.celeborn.client.adaptive.optimizeSkewedPartitionRead.enabled") ++ .version("3.0.0") ++ .booleanConf ++ .createWithDefault(false) ++ + /** + * Holds information about keys that have been deprecated. + * +@@ -4154,6 +4160,9 @@ class SQLConf extends Serializable with Logging { + + def legacyParquetNanosAsLong: Boolean = getConf(SQLConf.LEGACY_PARQUET_NANOS_AS_LONG) + ++ def celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled: Boolean = ++ getConf(SQLConf.CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ) ++ + /** ********************** SQLConf functionality methods ************ */ + + /** Set Spark SQL configuration properties. */ +diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala +new file mode 100644 +index 00000000000..3dc60678461 +--- /dev/null ++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala +@@ -0,0 +1,35 @@ ++/* ++ * Licensed to the Apache Software Foundation (ASF) under one or more ++ * contributor license agreements. See the NOTICE file distributed with ++ * this work for additional information regarding copyright ownership. ++ * The ASF licenses this file to You under the Apache License, Version 2.0 ++ * (the "License"); you may not use this file except in compliance with ++ * the License. You may obtain a copy of the License at ++ * ++ * http://www.apache.org/licenses/LICENSE-2.0 ++ * ++ * Unless required by applicable law or agreed to in writing, software ++ * distributed under the License is distributed on an "AS IS" BASIS, ++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++ * See the License for the specific language governing permissions and ++ * limitations under the License. ++ */ ++ ++package org.apache.spark.sql.execution.adaptive ++ ++import java.util.Locale ++ ++import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike} ++ ++object CelebornShuffleUtil { ++ ++ def isCelebornShuffle(shuffleExchangeLike: ShuffleExchangeLike): Boolean = { ++ shuffleExchangeLike match { ++ case exec: ShuffleExchangeExec => ++ exec.shuffleDependency.shuffleHandle ++ .getClass.getName.toLowerCase(Locale.ROOT).contains("celeborn") ++ case _ => false ++ } ++ } ++ ++} +diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala +index 1752907a9a5..2a0fe20c104 100644 +--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala ++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala +@@ -50,12 +50,13 @@ object OptimizeSkewInRebalancePartitions extends AQEShuffleReadRule { + private def optimizeSkewedPartitions( + shuffleId: Int, + bytesByPartitionId: Array[Long], +- targetSize: Long): Seq[ShufflePartitionSpec] = { ++ targetSize: Long, ++ isCelebornShuffle: Boolean = false): Seq[ShufflePartitionSpec] = { + bytesByPartitionId.indices.flatMap { reduceIndex => + val bytes = bytesByPartitionId(reduceIndex) + if (bytes > targetSize) { +- val newPartitionSpec = +- ShufflePartitionsUtil.createSkewPartitionSpecs(shuffleId, reduceIndex, targetSize) ++ val newPartitionSpec = ShufflePartitionsUtil.createSkewPartitionSpecs( ++ shuffleId, reduceIndex, targetSize, isCelebornShuffle) + if (newPartitionSpec.isEmpty) { + CoalescedPartitionSpec(reduceIndex, reduceIndex + 1, bytes) :: Nil + } else { +@@ -77,8 +78,9 @@ object OptimizeSkewInRebalancePartitions extends AQEShuffleReadRule { + return shuffle + } + ++ val isCelebornShuffle = CelebornShuffleUtil.isCelebornShuffle(shuffle.shuffle) + val newPartitionsSpec = optimizeSkewedPartitions( +- mapStats.get.shuffleId, mapStats.get.bytesByPartitionId, advisorySize) ++ mapStats.get.shuffleId, mapStats.get.bytesByPartitionId, advisorySize, isCelebornShuffle) + // return origin plan if we can not optimize partitions + if (newPartitionsSpec.length == mapStats.get.bytesByPartitionId.length) { + shuffle +diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala +index 88abe68197b..150699a84a3 100644 +--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala ++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala +@@ -157,8 +157,10 @@ object OptimizeSkewedJoin extends AQEShuffleReadRule { + Seq(CoalescedPartitionSpec(partitionIndex, partitionIndex + 1, rightSize)) + + val leftParts = if (isLeftSkew) { ++ val isCelebornShuffle = CelebornShuffleUtil.isCelebornShuffle(left.shuffle) + val skewSpecs = ShufflePartitionsUtil.createSkewPartitionSpecs( +- left.mapStats.get.shuffleId, partitionIndex, leftTargetSize) ++ left.mapStats.get.shuffleId, partitionIndex, leftTargetSize, ++ isCelebornShuffle = isCelebornShuffle) + if (skewSpecs.isDefined) { + logDebug(s"Left side partition $partitionIndex " + + s"(${FileUtils.byteCountToDisplaySize(leftSize)}) is skewed, " + +@@ -171,8 +173,10 @@ object OptimizeSkewedJoin extends AQEShuffleReadRule { + } + + val rightParts = if (isRightSkew) { ++ val isCelebornShuffle = CelebornShuffleUtil.isCelebornShuffle(right.shuffle) + val skewSpecs = ShufflePartitionsUtil.createSkewPartitionSpecs( +- right.mapStats.get.shuffleId, partitionIndex, rightTargetSize) ++ right.mapStats.get.shuffleId, partitionIndex, rightTargetSize, ++ isCelebornShuffle = isCelebornShuffle) + if (skewSpecs.isDefined) { + logDebug(s"Right side partition $partitionIndex " + + s"(${FileUtils.byteCountToDisplaySize(rightSize)}) is skewed, " + +diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala +index 3609548f374..f7c6d5dda90 100644 +--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala ++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala +@@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer + import org.apache.spark.{MapOutputStatistics, MapOutputTrackerMaster, SparkEnv} + import org.apache.spark.internal.Logging + import org.apache.spark.sql.execution.{CoalescedPartitionSpec, PartialReducerPartitionSpec, ShufflePartitionSpec} ++import org.apache.spark.sql.internal.SQLConf + + object ShufflePartitionsUtil extends Logging { + final val SMALL_PARTITION_FACTOR = 0.2 +@@ -376,11 +377,25 @@ object ShufflePartitionsUtil extends Logging { + def createSkewPartitionSpecs( + shuffleId: Int, + reducerId: Int, +- targetSize: Long): Option[Seq[PartialReducerPartitionSpec]] = { ++ targetSize: Long, ++ isCelebornShuffle: Boolean = false): Option[Seq[PartialReducerPartitionSpec]] = { + val mapPartitionSizes = getMapSizesForReduceId(shuffleId, reducerId) + if (mapPartitionSizes.exists(_ < 0)) return None + val mapStartIndices = splitSizeListByTargetSize(mapPartitionSizes, targetSize) + if (mapStartIndices.length > 1) { ++ val isCelebornClientAdaptiveOptimizeSkewedPartitionReadEnabled = ++ SparkEnv.get.conf.get("spark.shuffle.manager", "sort").contains("celeborn") && ++ SQLConf.get.celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled && isCelebornShuffle ++ ++ val throwsFetchFailure = SparkEnv.get ++ .conf ++ .get("spark.celeborn.client.spark.fetch.throwsFetchFailure", "false") ++ .toBoolean ++ if (throwsFetchFailure && isCelebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) { ++ logInfo(s"Celeborn shuffle retry enabled and shuffle $shuffleId is skewed") ++ val mapOutputTracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] ++ mapOutputTracker.skewShuffleIds.add(shuffleId) ++ } + Some(mapStartIndices.indices.map { i => + val startMapIndex = mapStartIndices(i) + val endMapIndex = if (i == mapStartIndices.length - 1) { +@@ -388,8 +403,20 @@ object ShufflePartitionsUtil extends Logging { + } else { + mapStartIndices(i + 1) + } +- val dataSize = startMapIndex.until(endMapIndex).map(mapPartitionSizes(_)).sum +- PartialReducerPartitionSpec(reducerId, startMapIndex, endMapIndex, dataSize) ++ var dataSize = 0L ++ var mapIndex = startMapIndex ++ while (mapIndex < endMapIndex) { ++ dataSize += mapPartitionSizes(mapIndex) ++ mapIndex += 1 ++ } ++ if (isCelebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) { ++ // These `dataSize` variables may not be accurate as they only represent the sum of ++ // `dataSize` when the Celeborn optimize skewed partition read feature is enabled. ++ // Please not to use these dataSize variables in any other part of the codebase. ++ PartialReducerPartitionSpec(reducerId, mapStartIndices.length, i, dataSize) ++ } else { ++ PartialReducerPartitionSpec(reducerId, startMapIndex, endMapIndex, dataSize) ++ } + }) + } else { + None diff --git a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch index 86c8ca4fa21..3a6ac262ea3 100644 --- a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch +++ b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch @@ -1,18 +1,50 @@ -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - +diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +index b1974948430..0dc92ec44a8 100644 +--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala ++++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +@@ -696,6 +696,8 @@ private[spark] class MapOutputTrackerMaster( + pool + } + ++ val skewShuffleIds = ConcurrentHashMap.newKeySet[Int]() ++ + // Make sure that we aren't going to exceed the max RPC message size by making sure + // we use broadcast to send large map output statuses. + if (minSizeForBroadcast > maxRpcMessageSize) { +@@ -886,6 +888,7 @@ private[spark] class MapOutputTrackerMaster( + shuffleStatus.invalidateSerializedMergeOutputStatusCache() + } + } ++ skewShuffleIds.remove(shuffleId) + } + + /** +diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +index bd2823bcac1..5d81b9de5b6 100644 +--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala ++++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +@@ -1404,7 +1404,10 @@ private[spark] class DAGScheduler( + // The operation here can make sure for the partially completed intermediate stage, + // `findMissingPartitions()` returns all partitions every time. + stage match { +- case sms: ShuffleMapStage if stage.isIndeterminate && !sms.isAvailable => ++ case sms: ShuffleMapStage if (stage.isIndeterminate && !sms.isAvailable) || ++ mapOutputTracker.skewShuffleIds.contains(sms.shuffleDep.shuffleId) => ++ logInfo(s"Unregistering shuffle output for stage ${stage.id}" + ++ s" shuffle ${sms.shuffleDep.shuffleId}") + mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId) + sms.shuffleDep.newShuffleMergeState() + case _ => +@@ -1851,7 +1854,8 @@ private[spark] class DAGScheduler( + failedStage.failedAttemptIds.add(task.stageAttemptId) + val shouldAbortStage = + failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts || +- disallowStageRetryForTest ++ disallowStageRetryForTest || (failedStage.isInstanceOf[ResultStage] ++ && mapOutputTracker.skewShuffleIds.contains(shuffleId)) + + // It is likely that we receive multiple FetchFailed for a single stage (because we have + // multiple tasks running concurrently on different executors). In that case, it is diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index af03ad9a4cb..7a3ee9ebfaf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -40,43 +72,142 @@ index af03ad9a4cb..7a3ee9ebfaf 100644 /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ +diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala +new file mode 100644 +index 00000000000..3dc60678461 +--- /dev/null ++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala +@@ -0,0 +1,35 @@ ++/* ++ * Licensed to the Apache Software Foundation (ASF) under one or more ++ * contributor license agreements. See the NOTICE file distributed with ++ * this work for additional information regarding copyright ownership. ++ * The ASF licenses this file to You under the Apache License, Version 2.0 ++ * (the "License"); you may not use this file except in compliance with ++ * the License. You may obtain a copy of the License at ++ * ++ * http://www.apache.org/licenses/LICENSE-2.0 ++ * ++ * Unless required by applicable law or agreed to in writing, software ++ * distributed under the License is distributed on an "AS IS" BASIS, ++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++ * See the License for the specific language governing permissions and ++ * limitations under the License. ++ */ ++ ++package org.apache.spark.sql.execution.adaptive ++ ++import java.util.Locale ++ ++import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike} ++ ++object CelebornShuffleUtil { ++ ++ def isCelebornShuffle(shuffleExchangeLike: ShuffleExchangeLike): Boolean = { ++ shuffleExchangeLike match { ++ case exec: ShuffleExchangeExec => ++ exec.shuffleDependency.shuffleHandle ++ .getClass.getName.toLowerCase(Locale.ROOT).contains("celeborn") ++ case _ => false ++ } ++ } ++ ++} +diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala +index b34ab3e380b..cb0ed9d05a4 100644 +--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala ++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala +@@ -47,14 +47,15 @@ object OptimizeSkewInRebalancePartitions extends AQEShuffleReadRule { + private def optimizeSkewedPartitions( + shuffleId: Int, + bytesByPartitionId: Array[Long], +- targetSize: Long): Seq[ShufflePartitionSpec] = { ++ targetSize: Long, ++ isCelebornShuffle: Boolean = false): Seq[ShufflePartitionSpec] = { + val smallPartitionFactor = + conf.getConf(SQLConf.ADAPTIVE_REBALANCE_PARTITIONS_SMALL_PARTITION_FACTOR) + bytesByPartitionId.indices.flatMap { reduceIndex => + val bytes = bytesByPartitionId(reduceIndex) + if (bytes > targetSize) { + val newPartitionSpec = ShufflePartitionsUtil.createSkewPartitionSpecs( +- shuffleId, reduceIndex, targetSize, smallPartitionFactor) ++ shuffleId, reduceIndex, targetSize, smallPartitionFactor, isCelebornShuffle) + if (newPartitionSpec.isEmpty) { + CoalescedPartitionSpec(reduceIndex, reduceIndex + 1, bytes) :: Nil + } else { +@@ -76,8 +77,9 @@ object OptimizeSkewInRebalancePartitions extends AQEShuffleReadRule { + return shuffle + } + ++ val isCelebornShuffle = CelebornShuffleUtil.isCelebornShuffle(shuffle.shuffle) + val newPartitionsSpec = optimizeSkewedPartitions( +- mapStats.get.shuffleId, mapStats.get.bytesByPartitionId, advisorySize) ++ mapStats.get.shuffleId, mapStats.get.bytesByPartitionId, advisorySize, isCelebornShuffle) + // return origin plan if we can not optimize partitions + if (newPartitionsSpec.length == mapStats.get.bytesByPartitionId.length) { + shuffle +diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala +index d4a173bb9cc..21ef335e064 100644 +--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala ++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala +@@ -152,8 +152,10 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) + Seq(CoalescedPartitionSpec(partitionIndex, partitionIndex + 1, rightSize)) + + val leftParts = if (isLeftSkew) { ++ val isCelebornShuffle = CelebornShuffleUtil.isCelebornShuffle(left.shuffle) + val skewSpecs = ShufflePartitionsUtil.createSkewPartitionSpecs( +- left.mapStats.get.shuffleId, partitionIndex, leftTargetSize) ++ left.mapStats.get.shuffleId, partitionIndex, leftTargetSize, ++ isCelebornShuffle = isCelebornShuffle) + if (skewSpecs.isDefined) { + logDebug(s"Left side partition $partitionIndex " + + s"(${FileUtils.byteCountToDisplaySize(leftSize)}) is skewed, " + +@@ -166,8 +168,10 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) + } + + val rightParts = if (isRightSkew) { ++ val isCelebornShuffle = CelebornShuffleUtil.isCelebornShuffle(right.shuffle) + val skewSpecs = ShufflePartitionsUtil.createSkewPartitionSpecs( +- right.mapStats.get.shuffleId, partitionIndex, rightTargetSize) ++ right.mapStats.get.shuffleId, partitionIndex, rightTargetSize, ++ isCelebornShuffle = isCelebornShuffle) + if (skewSpecs.isDefined) { + logDebug(s"Right side partition $partitionIndex " + + s"(${FileUtils.byteCountToDisplaySize(rightSize)}) is skewed, " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala -index af689db3379..f277bc396d4 100644 +index af689db3379..9d9f9c994b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala -@@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer - import org.apache.spark.{MapOutputStatistics, MapOutputTrackerMaster, SparkEnv} - import org.apache.spark.internal.Logging - import org.apache.spark.sql.execution.{CoalescedPartitionSpec, PartialReducerPartitionSpec, ShufflePartitionSpec} -+import org.apache.spark.sql.internal.SQLConf - - object ShufflePartitionsUtil extends Logging { - final val SMALL_PARTITION_FACTOR = 0.2 -@@ -387,6 +388,23 @@ object ShufflePartitionsUtil extends Logging { +@@ -380,13 +380,27 @@ object ShufflePartitionsUtil extends Logging { + shuffleId: Int, + reducerId: Int, + targetSize: Long, +- smallPartitionFactor: Double = SMALL_PARTITION_FACTOR) ++ smallPartitionFactor: Double = SMALL_PARTITION_FACTOR, ++ isCelebornShuffle: Boolean = false) + : Option[Seq[PartialReducerPartitionSpec]] = { + val mapPartitionSizes = getMapSizesForReduceId(shuffleId, reducerId) + if (mapPartitionSizes.exists(_ < 0)) return None val mapStartIndices = splitSizeListByTargetSize( mapPartitionSizes, targetSize, smallPartitionFactor) if (mapStartIndices.length > 1) { + val isCelebornClientAdaptiveOptimizeSkewedPartitionReadEnabled = -+ // TODO: check fallback or not. + SparkEnv.get.conf.get("spark.shuffle.manager", "sort").contains("celeborn") && -+ SQLConf.get.celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled ++ SQLConf.get.celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled && isCelebornShuffle + + val throwsFetchFailure = SparkEnv.get + .conf + .get("spark.celeborn.client.spark.fetch.throwsFetchFailure", "false") + .toBoolean + if (throwsFetchFailure && isCelebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) { -+ throw new UnsupportedOperationException( -+ "Currently, the 'Optimize Skewed Partition Read' feature cannot be used " + -+ "together with the 'Stage Re-run' feature. (The configuration parameters " + -+ "`spark.celeborn.client.spark.fetch.throwsFetchFailure` and " + -+ "`spark.celeborn.client.adaptive.optimizeSkewedPartitionRead.enabled` cannot be set " + -+ "to `true` at the same time.)") ++ logInfo(s"Celeborn shuffle retry enabled and shuffle $shuffleId is skewed") ++ val mapOutputTracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] ++ mapOutputTracker.skewShuffleIds.add(shuffleId) + } Some(mapStartIndices.indices.map { i => val startMapIndex = mapStartIndices(i) val endMapIndex = if (i == mapStartIndices.length - 1) { -@@ -400,7 +418,14 @@ object ShufflePartitionsUtil extends Logging { +@@ -400,7 +414,14 @@ object ShufflePartitionsUtil extends Logging { dataSize += mapPartitionSizes(mapIndex) mapIndex += 1 } diff --git a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_4.patch b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_4.patch new file mode 100644 index 00000000000..69854c91013 --- /dev/null +++ b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_4.patch @@ -0,0 +1,233 @@ +diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +index fade0b86dd8..3290d9fdf23 100644 +--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala ++++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +@@ -697,6 +697,8 @@ private[spark] class MapOutputTrackerMaster( + pool + } + ++ val skewShuffleIds = ConcurrentHashMap.newKeySet[Int]() ++ + // Make sure that we aren't going to exceed the max RPC message size by making sure + // we use broadcast to send large map output statuses. + if (minSizeForBroadcast > maxRpcMessageSize) { +@@ -887,6 +889,7 @@ private[spark] class MapOutputTrackerMaster( + shuffleStatus.invalidateSerializedMergeOutputStatusCache() + } + } ++ skewShuffleIds.remove(shuffleId) + } + + /** +diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +index 26be8c72bbc..c3e1d98c06f 100644 +--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala ++++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +@@ -1435,7 +1435,10 @@ private[spark] class DAGScheduler( + // The operation here can make sure for the partially completed intermediate stage, + // `findMissingPartitions()` returns all partitions every time. + stage match { +- case sms: ShuffleMapStage if stage.isIndeterminate && !sms.isAvailable => ++ case sms: ShuffleMapStage if (stage.isIndeterminate && !sms.isAvailable) || ++ mapOutputTracker.skewShuffleIds.contains(sms.shuffleDep.shuffleId) => ++ logInfo(s"Unregistering shuffle output for stage ${stage.id}" + ++ s" shuffle ${sms.shuffleDep.shuffleId}") + mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId) + sms.shuffleDep.newShuffleMergeState() + case _ => +@@ -1897,7 +1900,8 @@ private[spark] class DAGScheduler( + + val shouldAbortStage = + failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts || +- disallowStageRetryForTest ++ disallowStageRetryForTest || (failedStage.isInstanceOf[ResultStage] ++ && mapOutputTracker.skewShuffleIds.contains(shuffleId)) + + // It is likely that we receive multiple FetchFailed for a single stage (because we have + // multiple tasks running concurrently on different executors). In that case, it is +diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +index be9a7c82828..195fdffd501 100644 +--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala ++++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +@@ -4208,6 +4208,12 @@ object SQLConf { + .booleanConf + .createWithDefault(false) + ++ val CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ = ++ buildConf("spark.celeborn.client.adaptive.optimizeSkewedPartitionRead.enabled") ++ .version("3.0.0") ++ .booleanConf ++ .createWithDefault(false) ++ + /** + * Holds information about keys that have been deprecated. + * +@@ -5040,6 +5046,9 @@ class SQLConf extends Serializable with Logging { + getConf(SQLConf.LEGACY_NEGATIVE_INDEX_IN_ARRAY_INSERT) + } + ++ def celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled: Boolean = ++ getConf(SQLConf.CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ) ++ + /** ********************** SQLConf functionality methods ************ */ + + /** Set Spark SQL configuration properties. */ +diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala +new file mode 100644 +index 00000000000..3dc60678461 +--- /dev/null ++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala +@@ -0,0 +1,35 @@ ++/* ++ * Licensed to the Apache Software Foundation (ASF) under one or more ++ * contributor license agreements. See the NOTICE file distributed with ++ * this work for additional information regarding copyright ownership. ++ * The ASF licenses this file to You under the Apache License, Version 2.0 ++ * (the "License"); you may not use this file except in compliance with ++ * the License. You may obtain a copy of the License at ++ * ++ * http://www.apache.org/licenses/LICENSE-2.0 ++ * ++ * Unless required by applicable law or agreed to in writing, software ++ * distributed under the License is distributed on an "AS IS" BASIS, ++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++ * See the License for the specific language governing permissions and ++ * limitations under the License. ++ */ ++ ++package org.apache.spark.sql.execution.adaptive ++ ++import java.util.Locale ++ ++import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike} ++ ++object CelebornShuffleUtil { ++ ++ def isCelebornShuffle(shuffleExchangeLike: ShuffleExchangeLike): Boolean = { ++ shuffleExchangeLike match { ++ case exec: ShuffleExchangeExec => ++ exec.shuffleDependency.shuffleHandle ++ .getClass.getName.toLowerCase(Locale.ROOT).contains("celeborn") ++ case _ => false ++ } ++ } ++ ++} +diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala +index b34ab3e380b..cb0ed9d05a4 100644 +--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala ++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala +@@ -47,14 +47,15 @@ object OptimizeSkewInRebalancePartitions extends AQEShuffleReadRule { + private def optimizeSkewedPartitions( + shuffleId: Int, + bytesByPartitionId: Array[Long], +- targetSize: Long): Seq[ShufflePartitionSpec] = { ++ targetSize: Long, ++ isCelebornShuffle: Boolean = false): Seq[ShufflePartitionSpec] = { + val smallPartitionFactor = + conf.getConf(SQLConf.ADAPTIVE_REBALANCE_PARTITIONS_SMALL_PARTITION_FACTOR) + bytesByPartitionId.indices.flatMap { reduceIndex => + val bytes = bytesByPartitionId(reduceIndex) + if (bytes > targetSize) { + val newPartitionSpec = ShufflePartitionsUtil.createSkewPartitionSpecs( +- shuffleId, reduceIndex, targetSize, smallPartitionFactor) ++ shuffleId, reduceIndex, targetSize, smallPartitionFactor, isCelebornShuffle) + if (newPartitionSpec.isEmpty) { + CoalescedPartitionSpec(reduceIndex, reduceIndex + 1, bytes) :: Nil + } else { +@@ -76,8 +77,9 @@ object OptimizeSkewInRebalancePartitions extends AQEShuffleReadRule { + return shuffle + } + ++ val isCelebornShuffle = CelebornShuffleUtil.isCelebornShuffle(shuffle.shuffle) + val newPartitionsSpec = optimizeSkewedPartitions( +- mapStats.get.shuffleId, mapStats.get.bytesByPartitionId, advisorySize) ++ mapStats.get.shuffleId, mapStats.get.bytesByPartitionId, advisorySize, isCelebornShuffle) + // return origin plan if we can not optimize partitions + if (newPartitionsSpec.length == mapStats.get.bytesByPartitionId.length) { + shuffle +diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala +index 37cdea084d8..4694a06919e 100644 +--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala ++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala +@@ -152,8 +152,10 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) + Seq(CoalescedPartitionSpec(partitionIndex, partitionIndex + 1, rightSize)) + + val leftParts = if (isLeftSkew) { ++ val isCelebornShuffle = CelebornShuffleUtil.isCelebornShuffle(left.shuffle) + val skewSpecs = ShufflePartitionsUtil.createSkewPartitionSpecs( +- left.mapStats.get.shuffleId, partitionIndex, leftTargetSize) ++ left.mapStats.get.shuffleId, partitionIndex, leftTargetSize, ++ isCelebornShuffle = isCelebornShuffle) + if (skewSpecs.isDefined) { + logDebug(s"Left side partition $partitionIndex " + + s"(${FileUtils.byteCountToDisplaySize(leftSize)}) is skewed, " + +@@ -166,8 +168,10 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) + } + + val rightParts = if (isRightSkew) { ++ val isCelebornShuffle = CelebornShuffleUtil.isCelebornShuffle(right.shuffle) + val skewSpecs = ShufflePartitionsUtil.createSkewPartitionSpecs( +- right.mapStats.get.shuffleId, partitionIndex, rightTargetSize) ++ right.mapStats.get.shuffleId, partitionIndex, rightTargetSize, ++ isCelebornShuffle = isCelebornShuffle) + if (skewSpecs.isDefined) { + logDebug(s"Right side partition $partitionIndex " + + s"(${FileUtils.byteCountToDisplaySize(rightSize)}) is skewed, " + +diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala +index dbed66683b0..c017c6d1229 100644 +--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala ++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala +@@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer + import org.apache.spark.{MapOutputStatistics, MapOutputTrackerMaster, SparkEnv} + import org.apache.spark.internal.Logging + import org.apache.spark.sql.execution.{CoalescedPartitionSpec, PartialReducerPartitionSpec, ShufflePartitionSpec} ++import org.apache.spark.sql.internal.SQLConf + + object ShufflePartitionsUtil extends Logging { + final val SMALL_PARTITION_FACTOR = 0.2 +@@ -380,13 +381,27 @@ object ShufflePartitionsUtil extends Logging { + shuffleId: Int, + reducerId: Int, + targetSize: Long, +- smallPartitionFactor: Double = SMALL_PARTITION_FACTOR) ++ smallPartitionFactor: Double = SMALL_PARTITION_FACTOR, ++ isCelebornShuffle: Boolean = false) + : Option[Seq[PartialReducerPartitionSpec]] = { + val mapPartitionSizes = getMapSizesForReduceId(shuffleId, reducerId) + if (mapPartitionSizes.exists(_ < 0)) return None + val mapStartIndices = splitSizeListByTargetSize( + mapPartitionSizes, targetSize, smallPartitionFactor) + if (mapStartIndices.length > 1) { ++ val isCelebornClientAdaptiveOptimizeSkewedPartitionReadEnabled = ++ SparkEnv.get.conf.get("spark.shuffle.manager", "sort").contains("celeborn") && ++ SQLConf.get.celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled && isCelebornShuffle ++ ++ val throwsFetchFailure = SparkEnv.get ++ .conf ++ .get("spark.celeborn.client.spark.fetch.throwsFetchFailure", "false") ++ .toBoolean ++ if (throwsFetchFailure && isCelebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) { ++ logInfo(s"Celeborn shuffle retry enabled and shuffle $shuffleId is skewed") ++ val mapOutputTracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] ++ mapOutputTracker.skewShuffleIds.add(shuffleId) ++ } + Some(mapStartIndices.indices.map { i => + val startMapIndex = mapStartIndices(i) + val endMapIndex = if (i == mapStartIndices.length - 1) { +@@ -400,7 +415,14 @@ object ShufflePartitionsUtil extends Logging { + dataSize += mapPartitionSizes(mapIndex) + mapIndex += 1 + } +- PartialReducerPartitionSpec(reducerId, startMapIndex, endMapIndex, dataSize) ++ if (isCelebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) { ++ // These `dataSize` variables may not be accurate as they only represent the sum of ++ // `dataSize` when the Celeborn optimize skewed partition read feature is enabled. ++ // Please not to use these dataSize variables in any other part of the codebase. ++ PartialReducerPartitionSpec(reducerId, mapStartIndices.length, i, dataSize) ++ } else { ++ PartialReducerPartitionSpec(reducerId, startMapIndex, endMapIndex, dataSize) ++ } + }) + } else { + None diff --git a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_5.patch b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_5.patch new file mode 100644 index 00000000000..0ff8a02719c --- /dev/null +++ b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_5.patch @@ -0,0 +1,233 @@ +diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +index 9a7a3b0c0e7..c886263b3eb 100644 +--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala ++++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +@@ -726,6 +726,8 @@ private[spark] class MapOutputTrackerMaster( + + private val availableProcessors = Runtime.getRuntime.availableProcessors() + ++ val skewShuffleIds = ConcurrentHashMap.newKeySet[Int]() ++ + // Make sure that we aren't going to exceed the max RPC message size by making sure + // we use broadcast to send large map output statuses. + if (minSizeForBroadcast > maxRpcMessageSize) { +@@ -916,6 +918,7 @@ private[spark] class MapOutputTrackerMaster( + shuffleStatus.invalidateSerializedMergeOutputStatusCache() + } + } ++ skewShuffleIds.remove(shuffleId) + } + + /** +diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +index 89d16e57934..dab1eca457e 100644 +--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala ++++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +@@ -1480,7 +1480,10 @@ private[spark] class DAGScheduler( + // The operation here can make sure for the partially completed intermediate stage, + // `findMissingPartitions()` returns all partitions every time. + stage match { +- case sms: ShuffleMapStage if stage.isIndeterminate && !sms.isAvailable => ++ case sms: ShuffleMapStage if (stage.isIndeterminate && !sms.isAvailable) || ++ mapOutputTracker.skewShuffleIds.contains(sms.shuffleDep.shuffleId) => ++ logInfo(s"Unregistering shuffle output for stage ${stage.id}" + ++ s" shuffle ${sms.shuffleDep.shuffleId}") + mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId) + sms.shuffleDep.newShuffleMergeState() + case _ => +@@ -1962,7 +1965,8 @@ private[spark] class DAGScheduler( + + val shouldAbortStage = + failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts || +- disallowStageRetryForTest ++ disallowStageRetryForTest || (failedStage.isInstanceOf[ResultStage] ++ && mapOutputTracker.skewShuffleIds.contains(shuffleId)) + + // It is likely that we receive multiple FetchFailed for a single stage (because we have + // multiple tasks running concurrently on different executors). In that case, it is +diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +index 6f2f0088fcc..706c3ce70db 100644 +--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala ++++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +@@ -4423,6 +4423,12 @@ object SQLConf { + .booleanConf + .createWithDefault(false) + ++ val CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ = ++ buildConf("spark.celeborn.client.adaptive.optimizeSkewedPartitionRead.enabled") ++ .version("3.0.0") ++ .booleanConf ++ .createWithDefault(false) ++ + /** + * Holds information about keys that have been deprecated. + * +@@ -5278,6 +5284,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf { + getConf(SQLConf.LEGACY_NEGATIVE_INDEX_IN_ARRAY_INSERT) + } + ++ def celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled: Boolean = ++ getConf(SQLConf.CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ) ++ + /** ********************** SQLConf functionality methods ************ */ + + /** Set Spark SQL configuration properties. */ +diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala +new file mode 100644 +index 00000000000..3dc60678461 +--- /dev/null ++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala +@@ -0,0 +1,35 @@ ++/* ++ * Licensed to the Apache Software Foundation (ASF) under one or more ++ * contributor license agreements. See the NOTICE file distributed with ++ * this work for additional information regarding copyright ownership. ++ * The ASF licenses this file to You under the Apache License, Version 2.0 ++ * (the "License"); you may not use this file except in compliance with ++ * the License. You may obtain a copy of the License at ++ * ++ * http://www.apache.org/licenses/LICENSE-2.0 ++ * ++ * Unless required by applicable law or agreed to in writing, software ++ * distributed under the License is distributed on an "AS IS" BASIS, ++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++ * See the License for the specific language governing permissions and ++ * limitations under the License. ++ */ ++ ++package org.apache.spark.sql.execution.adaptive ++ ++import java.util.Locale ++ ++import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike} ++ ++object CelebornShuffleUtil { ++ ++ def isCelebornShuffle(shuffleExchangeLike: ShuffleExchangeLike): Boolean = { ++ shuffleExchangeLike match { ++ case exec: ShuffleExchangeExec => ++ exec.shuffleDependency.shuffleHandle ++ .getClass.getName.toLowerCase(Locale.ROOT).contains("celeborn") ++ case _ => false ++ } ++ } ++ ++} +diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala +index abd096b9c7c..ff0363f87d8 100644 +--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala ++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala +@@ -47,14 +47,15 @@ object OptimizeSkewInRebalancePartitions extends AQEShuffleReadRule { + private def optimizeSkewedPartitions( + shuffleId: Int, + bytesByPartitionId: Array[Long], +- targetSize: Long): Seq[ShufflePartitionSpec] = { ++ targetSize: Long, ++ isCelebornShuffle: Boolean = false): Seq[ShufflePartitionSpec] = { + val smallPartitionFactor = + conf.getConf(SQLConf.ADAPTIVE_REBALANCE_PARTITIONS_SMALL_PARTITION_FACTOR) + bytesByPartitionId.indices.flatMap { reduceIndex => + val bytes = bytesByPartitionId(reduceIndex) + if (bytes > targetSize) { + val newPartitionSpec = ShufflePartitionsUtil.createSkewPartitionSpecs( +- shuffleId, reduceIndex, targetSize, smallPartitionFactor) ++ shuffleId, reduceIndex, targetSize, smallPartitionFactor, isCelebornShuffle) + if (newPartitionSpec.isEmpty) { + CoalescedPartitionSpec(reduceIndex, reduceIndex + 1, bytes) :: Nil + } else { +@@ -77,8 +78,9 @@ object OptimizeSkewInRebalancePartitions extends AQEShuffleReadRule { + return shuffle + } + ++ val isCelebornShuffle = CelebornShuffleUtil.isCelebornShuffle(shuffle.shuffle) + val newPartitionsSpec = optimizeSkewedPartitions( +- mapStats.get.shuffleId, mapStats.get.bytesByPartitionId, advisorySize) ++ mapStats.get.shuffleId, mapStats.get.bytesByPartitionId, advisorySize, isCelebornShuffle) + // return origin plan if we can not optimize partitions + if (newPartitionsSpec.length == mapStats.get.bytesByPartitionId.length) { + shuffle +diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala +index 37cdea084d8..4694a06919e 100644 +--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala ++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala +@@ -152,8 +152,10 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) + Seq(CoalescedPartitionSpec(partitionIndex, partitionIndex + 1, rightSize)) + + val leftParts = if (isLeftSkew) { ++ val isCelebornShuffle = CelebornShuffleUtil.isCelebornShuffle(left.shuffle) + val skewSpecs = ShufflePartitionsUtil.createSkewPartitionSpecs( +- left.mapStats.get.shuffleId, partitionIndex, leftTargetSize) ++ left.mapStats.get.shuffleId, partitionIndex, leftTargetSize, ++ isCelebornShuffle = isCelebornShuffle) + if (skewSpecs.isDefined) { + logDebug(s"Left side partition $partitionIndex " + + s"(${FileUtils.byteCountToDisplaySize(leftSize)}) is skewed, " + +@@ -166,8 +168,10 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) + } + + val rightParts = if (isRightSkew) { ++ val isCelebornShuffle = CelebornShuffleUtil.isCelebornShuffle(right.shuffle) + val skewSpecs = ShufflePartitionsUtil.createSkewPartitionSpecs( +- right.mapStats.get.shuffleId, partitionIndex, rightTargetSize) ++ right.mapStats.get.shuffleId, partitionIndex, rightTargetSize, ++ isCelebornShuffle = isCelebornShuffle) + if (skewSpecs.isDefined) { + logDebug(s"Right side partition $partitionIndex " + + s"(${FileUtils.byteCountToDisplaySize(rightSize)}) is skewed, " + +diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala +index 9370b3d8d1d..a00383e9b83 100644 +--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala ++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala +@@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer + import org.apache.spark.{MapOutputStatistics, MapOutputTrackerMaster, SparkEnv} + import org.apache.spark.internal.Logging + import org.apache.spark.sql.execution.{CoalescedPartitionSpec, PartialReducerPartitionSpec, ShufflePartitionSpec} ++import org.apache.spark.sql.internal.SQLConf + + object ShufflePartitionsUtil extends Logging { + final val SMALL_PARTITION_FACTOR = 0.2 +@@ -382,13 +383,26 @@ object ShufflePartitionsUtil extends Logging { + shuffleId: Int, + reducerId: Int, + targetSize: Long, +- smallPartitionFactor: Double = SMALL_PARTITION_FACTOR) +- : Option[Seq[PartialReducerPartitionSpec]] = { ++ smallPartitionFactor: Double = SMALL_PARTITION_FACTOR, ++ isCelebornShuffle: Boolean = false): Option[Seq[PartialReducerPartitionSpec]] = { + val mapPartitionSizes = getMapSizesForReduceId(shuffleId, reducerId) + if (mapPartitionSizes.exists(_ < 0)) return None + val mapStartIndices = splitSizeListByTargetSize( + mapPartitionSizes, targetSize, smallPartitionFactor) + if (mapStartIndices.length > 1) { ++ val isCelebornClientAdaptiveOptimizeSkewedPartitionReadEnabled = ++ SparkEnv.get.conf.get("spark.shuffle.manager", "sort").contains("celeborn") && ++ SQLConf.get.celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled && isCelebornShuffle ++ ++ val throwsFetchFailure = SparkEnv.get ++ .conf ++ .get("spark.celeborn.client.spark.fetch.throwsFetchFailure", "false") ++ .toBoolean ++ if (throwsFetchFailure && isCelebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) { ++ logInfo(s"Celeborn shuffle retry enabled and shuffle $shuffleId is skewed") ++ val mapOutputTracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] ++ mapOutputTracker.skewShuffleIds.add(shuffleId) ++ } + Some(mapStartIndices.indices.map { i => + val startMapIndex = mapStartIndices(i) + val endMapIndex = if (i == mapStartIndices.length - 1) { +@@ -402,7 +416,14 @@ object ShufflePartitionsUtil extends Logging { + dataSize += mapPartitionSizes(mapIndex) + mapIndex += 1 + } +- PartialReducerPartitionSpec(reducerId, startMapIndex, endMapIndex, dataSize) ++ if (isCelebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) { ++ // These `dataSize` variables may not be accurate as they only represent the sum of ++ // `dataSize` when the Celeborn optimize skewed partition read feature is enabled. ++ // Please not to use these dataSize variables in any other part of the codebase. ++ PartialReducerPartitionSpec(reducerId, mapStartIndices.length, i, dataSize) ++ } else { ++ PartialReducerPartitionSpec(reducerId, startMapIndex, endMapIndex, dataSize) ++ } + }) + } else { + None From 8c17522f0c2aa0db8da09cf327ca432cb7555b17 Mon Sep 17 00:00:00 2001 From: wangshengjie3 Date: Tue, 26 Nov 2024 21:56:35 +0800 Subject: [PATCH 33/44] fix ut compile error and add license to spark patch --- ...leborn-Optimize-Skew-Partitions-spark3_2.patch | 15 +++++++++++++++ ...leborn-Optimize-Skew-Partitions-spark3_3.patch | 15 +++++++++++++++ ...leborn-Optimize-Skew-Partitions-spark3_4.patch | 15 +++++++++++++++ ...leborn-Optimize-Skew-Partitions-spark3_5.patch | 15 +++++++++++++++ .../client/LifecycleManagerCommitFilesSuite.scala | 2 -- 5 files changed, 60 insertions(+), 2 deletions(-) diff --git a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_2.patch b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_2.patch index eadcfa4668e..536422af2ac 100644 --- a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_2.patch +++ b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_2.patch @@ -1,3 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index e469c9989f2..245d9b3b9de 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala diff --git a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch index 3a6ac262ea3..8ba4978b382 100644 --- a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch +++ b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch @@ -1,3 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index b1974948430..0dc92ec44a8 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala diff --git a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_4.patch b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_4.patch index 69854c91013..b906e7f37ea 100644 --- a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_4.patch +++ b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_4.patch @@ -1,3 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index fade0b86dd8..3290d9fdf23 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala diff --git a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_5.patch b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_5.patch index 0ff8a02719c..65efae79589 100644 --- a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_5.patch +++ b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_5.patch @@ -1,3 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 9a7a3b0c0e7..c886263b3eb 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerCommitFilesSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerCommitFilesSuite.scala index a772316f318..39b54b5f638 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerCommitFilesSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerCommitFilesSuite.scala @@ -27,8 +27,6 @@ import org.apache.celeborn.client.{LifecycleManager, ShuffleClientImpl, WithShuf import org.apache.celeborn.client.LifecycleManager.ShuffleFailedWorkers import org.apache.celeborn.client.commit.CommitFilesParam import org.apache.celeborn.common.CelebornConf -import org.apache.celeborn.common.protocol.CompressionCodec -import org.apache.celeborn.common.protocol.message.ControlMessages.WorkerResource import org.apache.celeborn.common.protocol.message.StatusCode import org.apache.celeborn.common.util.Utils import org.apache.celeborn.service.deploy.MiniClusterFeature From 21f14d947afed837a01ff7aca47c3088cfd2f470 Mon Sep 17 00:00:00 2001 From: wangshengjie3 Date: Wed, 27 Nov 2024 16:51:47 +0800 Subject: [PATCH 34/44] update spark patch to abort stage when rerun skew join stage --- ...rn-Optimize-Skew-Partitions-spark3_2.patch | 33 ++++++----- ...rn-Optimize-Skew-Partitions-spark3_3.patch | 46 ++++++++++----- ...rn-Optimize-Skew-Partitions-spark3_4.patch | 58 ++++++++++--------- ...rn-Optimize-Skew-Partitions-spark3_5.patch | 34 ++++++----- 4 files changed, 101 insertions(+), 70 deletions(-) diff --git a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_2.patch b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_2.patch index 536422af2ac..dcaec1178be 100644 --- a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_2.patch +++ b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_2.patch @@ -35,7 +35,7 @@ index e469c9989f2..245d9b3b9de 100644 /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala -index b950c07f3d8..e9e10bb647f 100644 +index b950c07f3d8..d081b4642c9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1369,7 +1369,10 @@ private[spark] class DAGScheduler( @@ -50,21 +50,20 @@ index b950c07f3d8..e9e10bb647f 100644 mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId) sms.shuffleDep.newShuffleMergeState() case _ => -@@ -1780,7 +1783,8 @@ private[spark] class DAGScheduler( +@@ -1780,7 +1783,7 @@ private[spark] class DAGScheduler( failedStage.failedAttemptIds.add(task.stageAttemptId) val shouldAbortStage = failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts || - disallowStageRetryForTest -+ disallowStageRetryForTest || (failedStage.isInstanceOf[ResultStage] -+ && mapOutputTracker.skewShuffleIds.contains(shuffleId)) ++ disallowStageRetryForTest || mapOutputTracker.skewShuffleIds.contains(shuffleId) // It is likely that we receive multiple FetchFailed for a single stage (because we have // multiple tasks running concurrently on different executors). In that case, it is diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala -index 6bc8ba4eebb..2e7d87c96eb 100644 +index 6bc8ba4eebb..44db30dbaec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala -@@ -3431,6 +3431,12 @@ object SQLConf { +@@ -3431,6 +3431,19 @@ object SQLConf { .booleanConf .createWithDefault(false) @@ -73,16 +72,25 @@ index 6bc8ba4eebb..2e7d87c96eb 100644 + .version("3.0.0") + .booleanConf + .createWithDefault(false) ++ ++ val CELEBORN_STAGE_RERUN_ENABLED = ++ buildConf("spark.celeborn.client.spark.stageRerun.enabled") ++ .withAlternative("spark.celeborn.client.spark.fetch.throwsFetchFailure") ++ .version("3.0.0") ++ .booleanConf ++ .createWithDefault(false) + /** * Holds information about keys that have been deprecated. * -@@ -4154,6 +4160,9 @@ class SQLConf extends Serializable with Logging { +@@ -4154,6 +4167,11 @@ class SQLConf extends Serializable with Logging { def legacyParquetNanosAsLong: Boolean = getConf(SQLConf.LEGACY_PARQUET_NANOS_AS_LONG) + def celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled: Boolean = + getConf(SQLConf.CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ) ++ ++ def celebornStageRerunEnabled: Boolean = getConf(SQLConf.CELEBORN_STAGE_RERUN_ENABLED) + /** ********************** SQLConf functionality methods ************ */ @@ -189,7 +197,7 @@ index 88abe68197b..150699a84a3 100644 logDebug(s"Right side partition $partitionIndex " + s"(${FileUtils.byteCountToDisplaySize(rightSize)}) is skewed, " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala -index 3609548f374..f7c6d5dda90 100644 +index 3609548f374..59c80198f19 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala @@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer @@ -200,7 +208,7 @@ index 3609548f374..f7c6d5dda90 100644 object ShufflePartitionsUtil extends Logging { final val SMALL_PARTITION_FACTOR = 0.2 -@@ -376,11 +377,25 @@ object ShufflePartitionsUtil extends Logging { +@@ -376,11 +377,22 @@ object ShufflePartitionsUtil extends Logging { def createSkewPartitionSpecs( shuffleId: Int, reducerId: Int, @@ -215,10 +223,7 @@ index 3609548f374..f7c6d5dda90 100644 + SparkEnv.get.conf.get("spark.shuffle.manager", "sort").contains("celeborn") && + SQLConf.get.celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled && isCelebornShuffle + -+ val throwsFetchFailure = SparkEnv.get -+ .conf -+ .get("spark.celeborn.client.spark.fetch.throwsFetchFailure", "false") -+ .toBoolean ++ val throwsFetchFailure = SQLConf.get.celebornStageRerunEnabled + if (throwsFetchFailure && isCelebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) { + logInfo(s"Celeborn shuffle retry enabled and shuffle $shuffleId is skewed") + val mapOutputTracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] @@ -227,7 +232,7 @@ index 3609548f374..f7c6d5dda90 100644 Some(mapStartIndices.indices.map { i => val startMapIndex = mapStartIndices(i) val endMapIndex = if (i == mapStartIndices.length - 1) { -@@ -388,8 +403,20 @@ object ShufflePartitionsUtil extends Logging { +@@ -388,8 +400,20 @@ object ShufflePartitionsUtil extends Logging { } else { mapStartIndices(i + 1) } diff --git a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch index 8ba4978b382..3c194f8ce8d 100644 --- a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch +++ b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch @@ -35,7 +35,7 @@ index b1974948430..0dc92ec44a8 100644 /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala -index bd2823bcac1..5d81b9de5b6 100644 +index bd2823bcac1..4f40becadc7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1404,7 +1404,10 @@ private[spark] class DAGScheduler( @@ -50,21 +50,20 @@ index bd2823bcac1..5d81b9de5b6 100644 mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId) sms.shuffleDep.newShuffleMergeState() case _ => -@@ -1851,7 +1854,8 @@ private[spark] class DAGScheduler( +@@ -1851,7 +1854,7 @@ private[spark] class DAGScheduler( failedStage.failedAttemptIds.add(task.stageAttemptId) val shouldAbortStage = failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts || - disallowStageRetryForTest -+ disallowStageRetryForTest || (failedStage.isInstanceOf[ResultStage] -+ && mapOutputTracker.skewShuffleIds.contains(shuffleId)) ++ disallowStageRetryForTest || mapOutputTracker.skewShuffleIds.contains(shuffleId) // It is likely that we receive multiple FetchFailed for a single stage (because we have // multiple tasks running concurrently on different executors). In that case, it is diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala -index af03ad9a4cb..7a3ee9ebfaf 100644 +index af03ad9a4cb..6c36fb96d58 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala -@@ -3784,6 +3784,12 @@ object SQLConf { +@@ -3784,6 +3784,19 @@ object SQLConf { .booleanConf .createWithDefault(false) @@ -73,16 +72,25 @@ index af03ad9a4cb..7a3ee9ebfaf 100644 + .version("3.0.0") + .booleanConf + .createWithDefault(false) ++ ++ val CELEBORN_STAGE_RERUN_ENABLED = ++ buildConf("spark.celeborn.client.spark.stageRerun.enabled") ++ .withAlternative("spark.celeborn.client.spark.fetch.throwsFetchFailure") ++ .version("3.0.0") ++ .booleanConf ++ .createWithDefault(false) + /** * Holds information about keys that have been deprecated. * -@@ -4549,6 +4555,9 @@ class SQLConf extends Serializable with Logging { +@@ -4549,6 +4562,11 @@ class SQLConf extends Serializable with Logging { def histogramNumericPropagateInputType: Boolean = getConf(SQLConf.HISTOGRAM_NUMERIC_PROPAGATE_INPUT_TYPE) + def celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled: Boolean = + getConf(SQLConf.CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ) ++ ++ def celebornStageRerunEnabled: Boolean = getConf(SQLConf.CELEBORN_STAGE_RERUN_ENABLED) + /** ********************** SQLConf functionality methods ************ */ @@ -190,17 +198,25 @@ index d4a173bb9cc..21ef335e064 100644 logDebug(s"Right side partition $partitionIndex " + s"(${FileUtils.byteCountToDisplaySize(rightSize)}) is skewed, " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala -index af689db3379..9d9f9c994b9 100644 +index af689db3379..529097549ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala -@@ -380,13 +380,27 @@ object ShufflePartitionsUtil extends Logging { +@@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer + import org.apache.spark.{MapOutputStatistics, MapOutputTrackerMaster, SparkEnv} + import org.apache.spark.internal.Logging + import org.apache.spark.sql.execution.{CoalescedPartitionSpec, PartialReducerPartitionSpec, ShufflePartitionSpec} ++import org.apache.spark.sql.internal.SQLConf + + object ShufflePartitionsUtil extends Logging { + final val SMALL_PARTITION_FACTOR = 0.2 +@@ -380,13 +381,23 @@ object ShufflePartitionsUtil extends Logging { shuffleId: Int, reducerId: Int, targetSize: Long, - smallPartitionFactor: Double = SMALL_PARTITION_FACTOR) +- : Option[Seq[PartialReducerPartitionSpec]] = { + smallPartitionFactor: Double = SMALL_PARTITION_FACTOR, -+ isCelebornShuffle: Boolean = false) - : Option[Seq[PartialReducerPartitionSpec]] = { ++ isCelebornShuffle: Boolean = false): Option[Seq[PartialReducerPartitionSpec]] = { val mapPartitionSizes = getMapSizesForReduceId(shuffleId, reducerId) if (mapPartitionSizes.exists(_ < 0)) return None val mapStartIndices = splitSizeListByTargetSize( @@ -210,10 +226,7 @@ index af689db3379..9d9f9c994b9 100644 + SparkEnv.get.conf.get("spark.shuffle.manager", "sort").contains("celeborn") && + SQLConf.get.celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled && isCelebornShuffle + -+ val throwsFetchFailure = SparkEnv.get -+ .conf -+ .get("spark.celeborn.client.spark.fetch.throwsFetchFailure", "false") -+ .toBoolean ++ val throwsFetchFailure = SQLConf.get.celebornStageRerunEnabled + if (throwsFetchFailure && isCelebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) { + logInfo(s"Celeborn shuffle retry enabled and shuffle $shuffleId is skewed") + val mapOutputTracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] @@ -222,11 +235,12 @@ index af689db3379..9d9f9c994b9 100644 Some(mapStartIndices.indices.map { i => val startMapIndex = mapStartIndices(i) val endMapIndex = if (i == mapStartIndices.length - 1) { -@@ -400,7 +414,14 @@ object ShufflePartitionsUtil extends Logging { +@@ -400,7 +411,15 @@ object ShufflePartitionsUtil extends Logging { dataSize += mapPartitionSizes(mapIndex) mapIndex += 1 } - PartialReducerPartitionSpec(reducerId, startMapIndex, endMapIndex, dataSize) ++ + if (isCelebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) { + // These `dataSize` variables may not be accurate as they only represent the sum of + // `dataSize` when the Celeborn optimize skewed partition read feature is enabled. diff --git a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_4.patch b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_4.patch index b906e7f37ea..3c194f8ce8d 100644 --- a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_4.patch +++ b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_4.patch @@ -14,10 +14,10 @@ # limitations under the License. diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala -index fade0b86dd8..3290d9fdf23 100644 +index b1974948430..0dc92ec44a8 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala -@@ -697,6 +697,8 @@ private[spark] class MapOutputTrackerMaster( +@@ -696,6 +696,8 @@ private[spark] class MapOutputTrackerMaster( pool } @@ -26,7 +26,7 @@ index fade0b86dd8..3290d9fdf23 100644 // Make sure that we aren't going to exceed the max RPC message size by making sure // we use broadcast to send large map output statuses. if (minSizeForBroadcast > maxRpcMessageSize) { -@@ -887,6 +889,7 @@ private[spark] class MapOutputTrackerMaster( +@@ -886,6 +888,7 @@ private[spark] class MapOutputTrackerMaster( shuffleStatus.invalidateSerializedMergeOutputStatusCache() } } @@ -35,10 +35,10 @@ index fade0b86dd8..3290d9fdf23 100644 /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala -index 26be8c72bbc..c3e1d98c06f 100644 +index bd2823bcac1..4f40becadc7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala -@@ -1435,7 +1435,10 @@ private[spark] class DAGScheduler( +@@ -1404,7 +1404,10 @@ private[spark] class DAGScheduler( // The operation here can make sure for the partially completed intermediate stage, // `findMissingPartitions()` returns all partitions every time. stage match { @@ -50,39 +50,47 @@ index 26be8c72bbc..c3e1d98c06f 100644 mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId) sms.shuffleDep.newShuffleMergeState() case _ => -@@ -1897,7 +1900,8 @@ private[spark] class DAGScheduler( - +@@ -1851,7 +1854,7 @@ private[spark] class DAGScheduler( + failedStage.failedAttemptIds.add(task.stageAttemptId) val shouldAbortStage = failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts || - disallowStageRetryForTest -+ disallowStageRetryForTest || (failedStage.isInstanceOf[ResultStage] -+ && mapOutputTracker.skewShuffleIds.contains(shuffleId)) ++ disallowStageRetryForTest || mapOutputTracker.skewShuffleIds.contains(shuffleId) // It is likely that we receive multiple FetchFailed for a single stage (because we have // multiple tasks running concurrently on different executors). In that case, it is diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala -index be9a7c82828..195fdffd501 100644 +index af03ad9a4cb..6c36fb96d58 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala -@@ -4208,6 +4208,12 @@ object SQLConf { - .booleanConf - .createWithDefault(false) +@@ -3784,6 +3784,19 @@ object SQLConf { + .booleanConf + .createWithDefault(false) + val CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ = + buildConf("spark.celeborn.client.adaptive.optimizeSkewedPartitionRead.enabled") + .version("3.0.0") + .booleanConf + .createWithDefault(false) ++ ++ val CELEBORN_STAGE_RERUN_ENABLED = ++ buildConf("spark.celeborn.client.spark.stageRerun.enabled") ++ .withAlternative("spark.celeborn.client.spark.fetch.throwsFetchFailure") ++ .version("3.0.0") ++ .booleanConf ++ .createWithDefault(false) + /** * Holds information about keys that have been deprecated. * -@@ -5040,6 +5046,9 @@ class SQLConf extends Serializable with Logging { - getConf(SQLConf.LEGACY_NEGATIVE_INDEX_IN_ARRAY_INSERT) - } +@@ -4549,6 +4562,11 @@ class SQLConf extends Serializable with Logging { + def histogramNumericPropagateInputType: Boolean = + getConf(SQLConf.HISTOGRAM_NUMERIC_PROPAGATE_INPUT_TYPE) + def celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled: Boolean = + getConf(SQLConf.CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ) ++ ++ def celebornStageRerunEnabled: Boolean = getConf(SQLConf.CELEBORN_STAGE_RERUN_ENABLED) + /** ********************** SQLConf functionality methods ************ */ @@ -162,7 +170,7 @@ index b34ab3e380b..cb0ed9d05a4 100644 if (newPartitionsSpec.length == mapStats.get.bytesByPartitionId.length) { shuffle diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala -index 37cdea084d8..4694a06919e 100644 +index d4a173bb9cc..21ef335e064 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala @@ -152,8 +152,10 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) @@ -190,7 +198,7 @@ index 37cdea084d8..4694a06919e 100644 logDebug(s"Right side partition $partitionIndex " + s"(${FileUtils.byteCountToDisplaySize(rightSize)}) is skewed, " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala -index dbed66683b0..c017c6d1229 100644 +index af689db3379..529097549ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala @@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer @@ -201,14 +209,14 @@ index dbed66683b0..c017c6d1229 100644 object ShufflePartitionsUtil extends Logging { final val SMALL_PARTITION_FACTOR = 0.2 -@@ -380,13 +381,27 @@ object ShufflePartitionsUtil extends Logging { +@@ -380,13 +381,23 @@ object ShufflePartitionsUtil extends Logging { shuffleId: Int, reducerId: Int, targetSize: Long, - smallPartitionFactor: Double = SMALL_PARTITION_FACTOR) +- : Option[Seq[PartialReducerPartitionSpec]] = { + smallPartitionFactor: Double = SMALL_PARTITION_FACTOR, -+ isCelebornShuffle: Boolean = false) - : Option[Seq[PartialReducerPartitionSpec]] = { ++ isCelebornShuffle: Boolean = false): Option[Seq[PartialReducerPartitionSpec]] = { val mapPartitionSizes = getMapSizesForReduceId(shuffleId, reducerId) if (mapPartitionSizes.exists(_ < 0)) return None val mapStartIndices = splitSizeListByTargetSize( @@ -218,10 +226,7 @@ index dbed66683b0..c017c6d1229 100644 + SparkEnv.get.conf.get("spark.shuffle.manager", "sort").contains("celeborn") && + SQLConf.get.celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled && isCelebornShuffle + -+ val throwsFetchFailure = SparkEnv.get -+ .conf -+ .get("spark.celeborn.client.spark.fetch.throwsFetchFailure", "false") -+ .toBoolean ++ val throwsFetchFailure = SQLConf.get.celebornStageRerunEnabled + if (throwsFetchFailure && isCelebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) { + logInfo(s"Celeborn shuffle retry enabled and shuffle $shuffleId is skewed") + val mapOutputTracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] @@ -230,11 +235,12 @@ index dbed66683b0..c017c6d1229 100644 Some(mapStartIndices.indices.map { i => val startMapIndex = mapStartIndices(i) val endMapIndex = if (i == mapStartIndices.length - 1) { -@@ -400,7 +415,14 @@ object ShufflePartitionsUtil extends Logging { +@@ -400,7 +411,15 @@ object ShufflePartitionsUtil extends Logging { dataSize += mapPartitionSizes(mapIndex) mapIndex += 1 } - PartialReducerPartitionSpec(reducerId, startMapIndex, endMapIndex, dataSize) ++ + if (isCelebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) { + // These `dataSize` variables may not be accurate as they only represent the sum of + // `dataSize` when the Celeborn optimize skewed partition read feature is enabled. diff --git a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_5.patch b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_5.patch index 65efae79589..66544751433 100644 --- a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_5.patch +++ b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_5.patch @@ -35,7 +35,7 @@ index 9a7a3b0c0e7..c886263b3eb 100644 /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala -index 89d16e57934..dab1eca457e 100644 +index 89d16e57934..24aad10c2a8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1480,7 +1480,10 @@ private[spark] class DAGScheduler( @@ -50,21 +50,20 @@ index 89d16e57934..dab1eca457e 100644 mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId) sms.shuffleDep.newShuffleMergeState() case _ => -@@ -1962,7 +1965,8 @@ private[spark] class DAGScheduler( +@@ -1962,7 +1965,7 @@ private[spark] class DAGScheduler( val shouldAbortStage = failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts || - disallowStageRetryForTest -+ disallowStageRetryForTest || (failedStage.isInstanceOf[ResultStage] -+ && mapOutputTracker.skewShuffleIds.contains(shuffleId)) ++ disallowStageRetryForTest || mapOutputTracker.skewShuffleIds.contains(shuffleId) // It is likely that we receive multiple FetchFailed for a single stage (because we have // multiple tasks running concurrently on different executors). In that case, it is diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala -index 6f2f0088fcc..706c3ce70db 100644 +index 6f2f0088fcc..3a7b1aabbbc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala -@@ -4423,6 +4423,12 @@ object SQLConf { +@@ -4423,6 +4423,19 @@ object SQLConf { .booleanConf .createWithDefault(false) @@ -73,16 +72,25 @@ index 6f2f0088fcc..706c3ce70db 100644 + .version("3.0.0") + .booleanConf + .createWithDefault(false) ++ ++ val CELEBORN_STAGE_RERUN_ENABLED = ++ buildConf("spark.celeborn.client.spark.stageRerun.enabled") ++ .withAlternative("spark.celeborn.client.spark.fetch.throwsFetchFailure") ++ .version("3.0.0") ++ .booleanConf ++ .createWithDefault(false) + /** * Holds information about keys that have been deprecated. * -@@ -5278,6 +5284,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf { +@@ -5278,6 +5291,11 @@ class SQLConf extends Serializable with Logging with SqlApiConf { getConf(SQLConf.LEGACY_NEGATIVE_INDEX_IN_ARRAY_INSERT) } + def celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled: Boolean = + getConf(SQLConf.CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ) ++ ++ def celebornStageRerunEnabled: Boolean = getConf(SQLConf.CELEBORN_STAGE_RERUN_ENABLED) + /** ********************** SQLConf functionality methods ************ */ @@ -190,7 +198,7 @@ index 37cdea084d8..4694a06919e 100644 logDebug(s"Right side partition $partitionIndex " + s"(${FileUtils.byteCountToDisplaySize(rightSize)}) is skewed, " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala -index 9370b3d8d1d..a00383e9b83 100644 +index 9370b3d8d1d..28d4b5b8d8b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala @@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer @@ -201,7 +209,7 @@ index 9370b3d8d1d..a00383e9b83 100644 object ShufflePartitionsUtil extends Logging { final val SMALL_PARTITION_FACTOR = 0.2 -@@ -382,13 +383,26 @@ object ShufflePartitionsUtil extends Logging { +@@ -382,13 +383,23 @@ object ShufflePartitionsUtil extends Logging { shuffleId: Int, reducerId: Int, targetSize: Long, @@ -218,10 +226,7 @@ index 9370b3d8d1d..a00383e9b83 100644 + SparkEnv.get.conf.get("spark.shuffle.manager", "sort").contains("celeborn") && + SQLConf.get.celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled && isCelebornShuffle + -+ val throwsFetchFailure = SparkEnv.get -+ .conf -+ .get("spark.celeborn.client.spark.fetch.throwsFetchFailure", "false") -+ .toBoolean ++ val throwsFetchFailure = SQLConf.get.celebornStageRerunEnabled + if (throwsFetchFailure && isCelebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) { + logInfo(s"Celeborn shuffle retry enabled and shuffle $shuffleId is skewed") + val mapOutputTracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] @@ -230,11 +235,12 @@ index 9370b3d8d1d..a00383e9b83 100644 Some(mapStartIndices.indices.map { i => val startMapIndex = mapStartIndices(i) val endMapIndex = if (i == mapStartIndices.length - 1) { -@@ -402,7 +416,14 @@ object ShufflePartitionsUtil extends Logging { +@@ -402,7 +413,15 @@ object ShufflePartitionsUtil extends Logging { dataSize += mapPartitionSizes(mapIndex) mapIndex += 1 } - PartialReducerPartitionSpec(reducerId, startMapIndex, endMapIndex, dataSize) ++ + if (isCelebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) { + // These `dataSize` variables may not be accurate as they only represent the sum of + // `dataSize` when the Celeborn optimize skewed partition read feature is enabled. From e9efddc4480aada31631b2b113f7e1eeee3e5954 Mon Sep 17 00:00:00 2001 From: wangshengjie3 Date: Sun, 22 Dec 2024 14:57:13 +0800 Subject: [PATCH 35/44] address comment, avoid reading replicate peer when read skew partition without mapRange --- .../client/read/CelebornInputStream.java | 36 +++++++++++-------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java index 96b64497bdc..67e27a7cd84 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java +++ b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java @@ -75,9 +75,9 @@ public static CelebornInputStream create( } else { // if startMapIndex > endMapIndex, means partition is skew partition. // locations will split to sub-partitions with startMapIndex size. - boolean splitSkewPartitionWithoutMapRange = - conf.clientAdaptiveOptimizeSkewedPartitionReadEnabled() && startMapIndex > endMapIndex; - if (splitSkewPartitionWithoutMapRange) { + boolean readSkewPartitionWithoutMapRange = + readSkewPartitionWithoutMapRange(conf, startMapIndex, endMapIndex); + if (readSkewPartitionWithoutMapRange) { return new CelebornInputStreamImpl( conf, clientFactory, @@ -95,7 +95,7 @@ public static CelebornInputStream create( shuffleId, partitionId, exceptionMaker, - splitSkewPartitionWithoutMapRange, + true, metricsCallback); } else { return new CelebornInputStreamImpl( @@ -117,7 +117,7 @@ public static CelebornInputStream create( shuffleId, partitionId, exceptionMaker, - splitSkewPartitionWithoutMapRange, + false, metricsCallback); } } @@ -127,6 +127,11 @@ public static CelebornInputStream empty() { return emptyInputStream; } + public static boolean readSkewPartitionWithoutMapRange( + CelebornConf conf, int startMapIndex, int endMapIndex) { + return conf.clientAdaptiveOptimizeSkewedPartitionReadEnabled() && startMapIndex > endMapIndex; + } + private static final CelebornInputStream emptyInputStream = new CelebornInputStream() { @Override @@ -209,7 +214,7 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream { private ExceptionMaker exceptionMaker; private boolean closed = false; - private final boolean splitSkewPartitionWithoutMapRange; + private final boolean readSkewPartitionWithoutMapRange; CelebornInputStreamImpl( CelebornConf conf, @@ -273,7 +278,7 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream { int shuffleId, int partitionId, ExceptionMaker exceptionMaker, - boolean splitSkewPartitionWithoutMapRange, + boolean readSkewPartitionWithoutMapRange, MetricsCallback metricsCallback) throws IOException { this.conf = conf; @@ -296,7 +301,7 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream { !conf.shuffleCompressionCodec().equals(CompressionCodec.NONE); this.fetchExcludedWorkerExpireTimeout = conf.clientFetchExcludedWorkerExpireTimeout(); this.failedBatches = failedBatchSet; - this.splitSkewPartitionWithoutMapRange = splitSkewPartitionWithoutMapRange; + this.readSkewPartitionWithoutMapRange = readSkewPartitionWithoutMapRange; this.fetchExcludedWorkers = fetchExcludedWorkers; if (conf.clientPushReplicateEnabled()) { @@ -346,9 +351,9 @@ private Tuple2 nextReadableLocation() { } PartitionLocation currentLocation = locations.get(fileIndex); // if pushShuffleFailureTrackingEnabled is true, should not skip location - while ((splitSkewPartitionWithoutMapRange + while ((readSkewPartitionWithoutMapRange && !partitionLocationToChunkRange.containsKey(currentLocation.getUniqueId())) - || (!splitSkewPartitionWithoutMapRange + || (!readSkewPartitionWithoutMapRange && skipLocation(startMapIndex, endMapIndex, currentLocation))) { skipCount.increment(); fileIndex++; @@ -414,7 +419,10 @@ private boolean isExcluded(PartitionLocation location) { private PartitionReader createReaderWithRetry( PartitionLocation location, PbStreamHandler pbStreamHandler) throws IOException { // For the first time, the location will be selected according to attemptNumber - if (fetchChunkRetryCnt == 0 && attemptNumber % 2 == 1 && location.hasPeer()) { + if (fetchChunkRetryCnt == 0 + && attemptNumber % 2 == 1 + && location.hasPeer() + && !readSkewPartitionWithoutMapRange) { location = location.getPeer(); logger.debug("Read peer {} for attempt {}.", location, attemptNumber); } @@ -429,7 +437,7 @@ private PartitionReader createReaderWithRetry( lastException = e; shuffleClient.excludeFailedFetchLocation(location.hostAndFetchPort(), e); fetchChunkRetryCnt++; - if (location.hasPeer()) { + if (location.hasPeer() && !readSkewPartitionWithoutMapRange) { // fetchChunkRetryCnt % 2 == 0 means both replicas have been tried, // so sleep before next try. if (fetchChunkRetryCnt % 2 == 0) { @@ -478,7 +486,7 @@ private ByteBuf getNextChunk() throws IOException { + currentReader.getLocation(), e); } else { - if (currentReader.getLocation().hasPeer()) { + if (currentReader.getLocation().hasPeer() && !readSkewPartitionWithoutMapRange) { logger.warn( "Fetch chunk failed {}/{} times for location {}, change to peer", fetchChunkRetryCnt, @@ -715,7 +723,7 @@ private boolean fillBuffer() throws IOException { // de-duplicate if (attemptId == attempts[mapId]) { - if (splitSkewPartitionWithoutMapRange) { + if (readSkewPartitionWithoutMapRange) { Set failedBatchSet = this.failedBatches.get(currentReader.getLocation().getUniqueId()); if (null != failedBatchSet) { From 23ca472b886d2f0addf5ab30a803f60c60a04122 Mon Sep 17 00:00:00 2001 From: wangshengjie3 Date: Mon, 23 Dec 2024 10:30:57 +0800 Subject: [PATCH 36/44] update ClientUtils --- .../celeborn/CelebornShuffleReader.scala | 7 +++---- .../client/read/CelebornInputStream.java | 8 ++------ .../apache/celeborn/client/ClientUtils.scala | 18 ++++++++++++++++++ 3 files changed, 23 insertions(+), 10 deletions(-) diff --git a/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala index a9f21f75d1f..7c14d11bcb9 100644 --- a/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala +++ b/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala @@ -34,7 +34,7 @@ import org.apache.spark.shuffle.celeborn.CelebornShuffleReader.streamCreatorPool import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter -import org.apache.celeborn.client.ShuffleClient +import org.apache.celeborn.client.{ClientUtils, ShuffleClient} import org.apache.celeborn.client.ShuffleClientImpl.ReduceFileGroups import org.apache.celeborn.client.read.{CelebornInputStream, MetricsCallback} import org.apache.celeborn.common.CelebornConf @@ -130,7 +130,7 @@ class CelebornShuffleReader[K, C]( // if startMapIndex > endMapIndex, means partition is skew partition. // locations will split to sub-partitions with startMapIndex size. val splitSkewPartitionWithoutMapRange = - conf.clientAdaptiveOptimizeSkewedPartitionReadEnabled && startMapIndex > endMapIndex + ClientUtils.readSkewPartitionWithoutMapRange(conf, startMapIndex, endMapIndex) (startPartition until endPartition).foreach { partitionId => if (fileGroups.partitionGroups.containsKey(partitionId)) { @@ -154,7 +154,7 @@ class CelebornShuffleReader[K, C]( locations.asScala.foreach { location => partCnt += 1 val hostPort = location.hostAndFetchPort - if (!workerRequestMap.containsKey(hostPort)) { + if (!workerRequestMap.containsKey(hostPort)) try { val client = shuffleClient.getDataClientFactory().createClient( location.getHost, @@ -171,7 +171,6 @@ class CelebornShuffleReader[K, C]( s"Failed to create client for $shuffleKey-$partitionId from host: ${location.hostAndFetchPort}. " + s"Shuffle reader will try its replica if exists.") } - } workerRequestMap.get(hostPort) match { case (_, locArr, pbOpenStreamListBuilder) => locArr.add(location) diff --git a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java index 67e27a7cd84..72cc6262c18 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java +++ b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java @@ -35,6 +35,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.celeborn.client.ClientUtils; import org.apache.celeborn.client.ShuffleClient; import org.apache.celeborn.client.compress.Decompressor; import org.apache.celeborn.common.CelebornConf; @@ -76,7 +77,7 @@ public static CelebornInputStream create( // if startMapIndex > endMapIndex, means partition is skew partition. // locations will split to sub-partitions with startMapIndex size. boolean readSkewPartitionWithoutMapRange = - readSkewPartitionWithoutMapRange(conf, startMapIndex, endMapIndex); + ClientUtils.readSkewPartitionWithoutMapRange(conf, startMapIndex, endMapIndex); if (readSkewPartitionWithoutMapRange) { return new CelebornInputStreamImpl( conf, @@ -127,11 +128,6 @@ public static CelebornInputStream empty() { return emptyInputStream; } - public static boolean readSkewPartitionWithoutMapRange( - CelebornConf conf, int startMapIndex, int endMapIndex) { - return conf.clientAdaptiveOptimizeSkewedPartitionReadEnabled() && startMapIndex > endMapIndex; - } - private static final CelebornInputStream emptyInputStream = new CelebornInputStream() { @Override diff --git a/client/src/main/scala/org/apache/celeborn/client/ClientUtils.scala b/client/src/main/scala/org/apache/celeborn/client/ClientUtils.scala index d7dccb941b0..b071eff3bf4 100644 --- a/client/src/main/scala/org/apache/celeborn/client/ClientUtils.scala +++ b/client/src/main/scala/org/apache/celeborn/client/ClientUtils.scala @@ -17,6 +17,8 @@ package org.apache.celeborn.client +import org.apache.celeborn.common.CelebornConf + object ClientUtils { /** @@ -37,4 +39,20 @@ object ClientUtils { } true } + + /** + * If startMapIndex > endMapIndex, means partition is skew partition. + * locations will split to sub-partitions with startMapIndex size. + * + * @param conf cleborn conf + * @param startMapIndex shuffle start map index + * @param endMapIndex shuffle end map index + * @return true if read skew partition without map range + */ + def readSkewPartitionWithoutMapRange( + conf: CelebornConf, + startMapIndex: Int, + endMapIndex: Int): Boolean = { + conf.clientAdaptiveOptimizeSkewedPartitionReadEnabled && startMapIndex > endMapIndex + } } From 94cb56d0432b7549c04a7d4af59766e75d54311c Mon Sep 17 00:00:00 2001 From: wangshengjie3 Date: Sun, 19 Jan 2025 17:09:25 +0800 Subject: [PATCH 37/44] address review comment --- .../celeborn/CelebornPartitionUtil.java | 27 +++++++++---------- .../celeborn/CelebornShuffleReader.scala | 5 ++-- .../celeborn/client/ShuffleClientImpl.java | 5 ++-- .../client/read/CelebornInputStream.java | 3 ++- common/src/main/proto/TransportMessages.proto | 2 -- 5 files changed, 20 insertions(+), 22 deletions(-) diff --git a/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/CelebornPartitionUtil.java b/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/CelebornPartitionUtil.java index 52a912d081c..0fbef891b69 100644 --- a/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/CelebornPartitionUtil.java +++ b/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/CelebornPartitionUtil.java @@ -17,10 +17,7 @@ package org.apache.spark.shuffle.celeborn; -import java.util.ArrayList; -import java.util.Comparator; -import java.util.HashMap; -import java.util.Map; +import java.util.*; import org.apache.commons.lang3.tuple.Pair; @@ -34,23 +31,22 @@ public static Map> splitSkewedPartitionLocations( locations.stream().mapToLong((PartitionLocation p) -> p.getStorageInfo().fileSize).sum(); long step = totalPartitionSize / subPartitionSize; long startOffset = step * subPartitionIndex; - long endOffset = 0; - if (subPartitionIndex == subPartitionSize - 1) { - // last subPartition should include all remaining data - endOffset = totalPartitionSize + 1; - } else { - endOffset = step * (subPartitionIndex + 1); - } + long endOffset = + subPartitionIndex < subPartitionSize - 1 + ? step * (subPartitionIndex + 1) + : totalPartitionSize + 1; // last subPartition should include all remaining data long partitionLocationOffset = 0; Map> chunkRange = new HashMap<>(); - for (int i = 0; i < locations.size(); i++) { - PartitionLocation p = locations.get(i); + for (PartitionLocation p : locations) { int left = -1; int right = -1; + Iterator chunkOffsets = p.getStorageInfo().getChunkOffsets().iterator(); // Start from index 1 since the first chunk offset is always 0. - for (int j = 1; j < p.getStorageInfo().getChunkOffsets().size(); j++) { - long currentOffset = partitionLocationOffset + p.getStorageInfo().getChunkOffsets().get(j); + chunkOffsets.next(); + int j = 1; + while (chunkOffsets.hasNext()) { + long currentOffset = partitionLocationOffset + chunkOffsets.next(); if (currentOffset > startOffset && left < 0) { left = j - 1; } @@ -60,6 +56,7 @@ public static Map> splitSkewedPartitionLocations( if (left >= 0 && right >= 0) { chunkRange.put(p.getUniqueId(), Pair.of(left, right)); } + j++; } partitionLocationOffset += p.getStorageInfo().getFileSize(); } diff --git a/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala index 7c14d11bcb9..b0a063b2638 100644 --- a/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala +++ b/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala @@ -127,7 +127,7 @@ class CelebornShuffleReader[K, C]( var partCnt = 0 - // if startMapIndex > endMapIndex, means partition is skew partition. + // if startMapIndex > endMapIndex, means partition is skew partition and read by Celeborn implementation. // locations will split to sub-partitions with startMapIndex size. val splitSkewPartitionWithoutMapRange = ClientUtils.readSkewPartitionWithoutMapRange(conf, startMapIndex, endMapIndex) @@ -154,7 +154,7 @@ class CelebornShuffleReader[K, C]( locations.asScala.foreach { location => partCnt += 1 val hostPort = location.hostAndFetchPort - if (!workerRequestMap.containsKey(hostPort)) + if (!workerRequestMap.containsKey(hostPort)) { try { val client = shuffleClient.getDataClientFactory().createClient( location.getHost, @@ -171,6 +171,7 @@ class CelebornShuffleReader[K, C]( s"Failed to create client for $shuffleKey-$partitionId from host: ${location.hostAndFetchPort}. " + s"Shuffle reader will try its replica if exists.") } + } workerRequestMap.get(hostPort) match { case (_, locArr, pbOpenStreamListBuilder) => locArr.add(location) diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java index 6cd4efa12fe..5da7d2fa33e 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java @@ -1552,9 +1552,10 @@ public void onSuccess(ByteBuffer response) { callback.onSuccess(ByteBuffer.wrap(new byte[] {StatusCode.SOFT_SPLIT.getValue()})); } else { if (dataPushFailureTrackingEnabled) { - for (int i = 0; i < numBatches; i++) { + for (DataBatches.DataBatch resubmitBatch : batchesNeedResubmit) { pushState.addFailedBatch( - partitionUniqueIds[i], new PushFailedBatch(mapId, attemptId, batchIds[i])); + resubmitBatch.loc.getUniqueId(), + new PushFailedBatch(mapId, attemptId, resubmitBatch.batchId)); } } ReviveRequest[] requests = diff --git a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java index 72cc6262c18..2604001a076 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java +++ b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java @@ -74,7 +74,8 @@ public static CelebornInputStream create( if (locations == null || locations.isEmpty()) { return emptyInputStream; } else { - // if startMapIndex > endMapIndex, means partition is skew partition. + // if startMapIndex > endMapIndex, means partition is skew partition and read by Celeborn + // implementation. // locations will split to sub-partitions with startMapIndex size. boolean readSkewPartitionWithoutMapRange = ClientUtils.readSkewPartitionWithoutMapRange(conf, startMapIndex, endMapIndex); diff --git a/common/src/main/proto/TransportMessages.proto b/common/src/main/proto/TransportMessages.proto index 22d2a6150f1..553e95aa79e 100644 --- a/common/src/main/proto/TransportMessages.proto +++ b/common/src/main/proto/TransportMessages.proto @@ -706,8 +706,6 @@ message PbOpenStream { int32 endIndex = 4; int32 initialCredit = 5; bool readLocalShuffle = 6; - bool requireSubpartitionId = 7; - bool shuffleDataNeedSort = 8; } message PbStreamHandler { From 60422d1e1677d5fe635a6adc20b7d962cbc54949 Mon Sep 17 00:00:00 2001 From: wangshengjie3 Date: Sun, 19 Jan 2025 19:49:40 +0800 Subject: [PATCH 38/44] remove duplicate uts --- .../celeborn/CelebornShuffleReaderSuiteJ.java | 130 ------------------ 1 file changed, 130 deletions(-) delete mode 100644 client-spark/spark-3-4/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleReaderSuiteJ.java diff --git a/client-spark/spark-3-4/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleReaderSuiteJ.java b/client-spark/spark-3-4/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleReaderSuiteJ.java deleted file mode 100644 index 3ba875f707a..00000000000 --- a/client-spark/spark-3-4/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleReaderSuiteJ.java +++ /dev/null @@ -1,130 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle.celeborn; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Map; - -import com.google.common.collect.Maps; -import org.apache.commons.lang3.tuple.Pair; -import org.junit.Assert; -import org.junit.Test; -import org.mockito.MockedStatic; -import org.mockito.Mockito; - -import org.apache.celeborn.client.ShuffleClient; -import org.apache.celeborn.common.CelebornConf; -import org.apache.celeborn.common.identity.UserIdentifier; -import org.apache.celeborn.common.protocol.PartitionLocation; -import org.apache.celeborn.common.protocol.StorageInfo; - -public class CelebornShuffleReaderSuiteJ { - @Test - public void testSkewPartitionSplit() { - CelebornShuffleHandle handle = - new CelebornShuffleHandle( - "appId", "host", 0, new UserIdentifier("mock", "mock"), 0, false, 10, null); - - MockedStatic shuffleClient = null; - try { - shuffleClient = Mockito.mockStatic(ShuffleClient.class); - CelebornShuffleReader shuffleReader = - new CelebornShuffleReader( - handle, 0, 10, 0, 10, null, new CelebornConf(), null, new ExecutorShuffleIdTracker()); - - ArrayList locations = new ArrayList<>(); - for (int i = 0; i < 13; i++) { - PartitionLocation location = - new PartitionLocation(0, i, "localhost", 0, 0, 0, 0, PartitionLocation.Mode.PRIMARY); - StorageInfo storageInfo = - new StorageInfo( - StorageInfo.Type.HDD, - "mountPoint", - false, - "filePath", - StorageInfo.LOCAL_DISK_MASK, - 1000, - Arrays.asList(0L, 100L, 200L, 300L, 500L, 1000L)); - location.setStorageInfo(storageInfo); - locations.add(location); - } - - PartitionLocation location = - new PartitionLocation(0, 91, "localhost", 0, 0, 0, 0, PartitionLocation.Mode.PRIMARY); - StorageInfo storageInfo = - new StorageInfo( - StorageInfo.Type.HDD, - "mountPoint", - false, - "filePath", - StorageInfo.LOCAL_DISK_MASK, - 1, - Arrays.asList(0L, 1L)); - location.setStorageInfo(storageInfo); - locations.add(location); - - Map> expectResult = Maps.newHashMap(); - - for (int i = 0; i < 5; i++) { - int subPartitionSize = 3; - - int subPartitionIndex = 0; - Map> result1 = - shuffleReader.splitSkewedPartitionLocations( - locations, subPartitionSize, subPartitionIndex); - expectResult.clear(); - expectResult.put("0-0", Pair.of(0, 5)); - expectResult.put("0-1", Pair.of(0, 5)); - expectResult.put("0-10", Pair.of(0, 5)); - expectResult.put("0-11", Pair.of(0, 5)); - expectResult.put("0-12", Pair.of(0, 3)); - Assert.assertEquals(expectResult, result1); - - subPartitionIndex = 1; - Map> result2 = - shuffleReader.splitSkewedPartitionLocations( - locations, subPartitionSize, subPartitionIndex); - expectResult.clear(); - expectResult.put("0-12", Pair.of(4, 5)); - expectResult.put("0-2", Pair.of(0, 5)); - expectResult.put("0-3", Pair.of(0, 5)); - expectResult.put("0-4", Pair.of(0, 5)); - expectResult.put("0-5", Pair.of(0, 4)); - Assert.assertEquals(expectResult, result2); - - subPartitionIndex = 2; - Map> result3 = - shuffleReader.splitSkewedPartitionLocations( - locations, subPartitionSize, subPartitionIndex); - expectResult.clear(); - expectResult.put("0-5", Pair.of(5, 5)); - expectResult.put("0-6", Pair.of(0, 5)); - expectResult.put("0-7", Pair.of(0, 5)); - expectResult.put("0-8", Pair.of(0, 5)); - expectResult.put("0-9", Pair.of(0, 5)); - expectResult.put("0-91", Pair.of(0, 1)); - Assert.assertEquals(expectResult, result3); - } - } finally { - if (null != shuffleClient) { - shuffleClient.close(); - } - } - } -} From 4130227a0acfb924026b92c67b3172ea0e27c8ae Mon Sep 17 00:00:00 2001 From: wangshengjie3 Date: Thu, 23 Jan 2025 15:51:46 +0800 Subject: [PATCH 39/44] add spark patch --- ...rn-Optimize-Skew-Partitions-spark3_2.patch | 218 +++++++++++------- ...rn-Optimize-Skew-Partitions-spark3_3.patch | 210 +++++++++++------ ...rn-Optimize-Skew-Partitions-spark3_4.patch | 212 ++++++++++------- ...rn-Optimize-Skew-Partitions-spark3_5.patch | 210 +++++++++++------ 4 files changed, 541 insertions(+), 309 deletions(-) diff --git a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_2.patch b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_2.patch index dcaec1178be..0cb1fc812dc 100644 --- a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_2.patch +++ b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_2.patch @@ -14,87 +14,147 @@ # limitations under the License. diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala -index e469c9989f2..245d9b3b9de 100644 +index e469c9989f2..a4a68ef1b09 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala -@@ -661,6 +661,8 @@ private[spark] class MapOutputTrackerMaster( - pool - } +@@ -33,6 +33,7 @@ import org.apache.commons.io.output.{ByteArrayOutputStream => ApacheByteArrayOut + import org.roaringbitmap.RoaringBitmap -+ val skewShuffleIds = ConcurrentHashMap.newKeySet[Int]() -+ - // Make sure that we aren't going to exceed the max RPC message size by making sure - // we use broadcast to send large map output statuses. - if (minSizeForBroadcast > maxRpcMessageSize) { -@@ -839,6 +841,7 @@ private[spark] class MapOutputTrackerMaster( + import org.apache.spark.broadcast.{Broadcast, BroadcastManager} ++import org.apache.spark.celeborn.CelebornShuffleState + import org.apache.spark.internal.Logging + import org.apache.spark.internal.config._ + import org.apache.spark.io.CompressionCodec +@@ -839,6 +840,7 @@ private[spark] class MapOutputTrackerMaster( shuffleStatus.invalidateSerializedMergeOutputStatusCache() } } -+ skewShuffleIds.remove(shuffleId) ++ CelebornShuffleState.unregisterCelebornSkewedShuffle(shuffleId) } /** -diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala -index b950c07f3d8..d081b4642c9 100644 ---- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala -+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala -@@ -1369,7 +1369,10 @@ private[spark] class DAGScheduler( - // The operation here can make sure for the partially completed intermediate stage, - // `findMissingPartitions()` returns all partitions every time. - stage match { -- case sms: ShuffleMapStage if stage.isIndeterminate && !sms.isAvailable => -+ case sms: ShuffleMapStage if (stage.isIndeterminate && !sms.isAvailable) || -+ mapOutputTracker.skewShuffleIds.contains(sms.shuffleDep.shuffleId) => -+ logInfo(s"Unregistering shuffle output for stage ${stage.id}" + -+ s" shuffle ${sms.shuffleDep.shuffleId}") - mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId) - sms.shuffleDep.newShuffleMergeState() - case _ => -@@ -1780,7 +1783,7 @@ private[spark] class DAGScheduler( - failedStage.failedAttemptIds.add(task.stageAttemptId) - val shouldAbortStage = - failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts || -- disallowStageRetryForTest -+ disallowStageRetryForTest || mapOutputTracker.skewShuffleIds.contains(shuffleId) - - // It is likely that we receive multiple FetchFailed for a single stage (because we have - // multiple tasks running concurrently on different executors). In that case, it is -diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala -index 6bc8ba4eebb..44db30dbaec 100644 ---- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala -+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala -@@ -3431,6 +3431,19 @@ object SQLConf { - .booleanConf - .createWithDefault(false) +diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala +index 0388c7b576b..59fdc81b09d 100644 +--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala ++++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala +@@ -32,6 +32,7 @@ import org.apache.hadoop.conf.Configuration + import org.apache.spark.annotation.DeveloperApi + import org.apache.spark.api.python.PythonWorkerFactory + import org.apache.spark.broadcast.BroadcastManager ++import org.apache.spark.celeborn.CelebornShuffleState + import org.apache.spark.internal.{config, Logging} + import org.apache.spark.internal.config._ + import org.apache.spark.memory.{MemoryManager, UnifiedMemoryManager} +@@ -414,6 +415,7 @@ object SparkEnv extends Logging { + if (isDriver) { + val sparkFilesDir = Utils.createTempDir(Utils.getLocalDir(conf), "userFiles").getAbsolutePath + envInstance.driverTmpDir = Some(sparkFilesDir) ++ CelebornShuffleState.init(envInstance) + } -+ val CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ = -+ buildConf("spark.celeborn.client.adaptive.optimizeSkewedPartitionRead.enabled") -+ .version("3.0.0") + envInstance +diff --git a/core/src/main/scala/org/apache/spark/celeborn/CelebornShuffleState.scala b/core/src/main/scala/org/apache/spark/celeborn/CelebornShuffleState.scala +new file mode 100644 +index 00000000000..5e190c512df +--- /dev/null ++++ b/core/src/main/scala/org/apache/spark/celeborn/CelebornShuffleState.scala +@@ -0,0 +1,75 @@ ++/* ++ * Licensed to the Apache Software Foundation (ASF) under one or more ++ * contributor license agreements. See the NOTICE file distributed with ++ * this work for additional information regarding copyright ownership. ++ * The ASF licenses this file to You under the Apache License, Version 2.0 ++ * (the "License"); you may not use this file except in compliance with ++ * the License. You may obtain a copy of the License at ++ * ++ * http://www.apache.org/licenses/LICENSE-2.0 ++ * ++ * Unless required by applicable law or agreed to in writing, software ++ * distributed under the License is distributed on an "AS IS" BASIS, ++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++ * See the License for the specific language governing permissions and ++ * limitations under the License. ++ */ ++ ++package org.apache.spark.celeborn ++ ++import java.util.concurrent.ConcurrentHashMap ++import java.util.concurrent.atomic.AtomicBoolean ++ ++import org.apache.spark.SparkEnv ++import org.apache.spark.internal.config.ConfigBuilder ++ ++object CelebornShuffleState { ++ ++ private val CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ = ++ ConfigBuilder("spark.celeborn.client.adaptive.optimizeSkewedPartitionRead.enabled") + .booleanConf + .createWithDefault(false) + -+ val CELEBORN_STAGE_RERUN_ENABLED = -+ buildConf("spark.celeborn.client.spark.stageRerun.enabled") ++ private val CELEBORN_STAGE_RERUN_ENABLED = ++ ConfigBuilder("spark.celeborn.client.spark.stageRerun.enabled") + .withAlternative("spark.celeborn.client.spark.fetch.throwsFetchFailure") -+ .version("3.0.0") + .booleanConf + .createWithDefault(false) + - /** - * Holds information about keys that have been deprecated. - * -@@ -4154,6 +4167,11 @@ class SQLConf extends Serializable with Logging { - - def legacyParquetNanosAsLong: Boolean = getConf(SQLConf.LEGACY_PARQUET_NANOS_AS_LONG) - -+ def celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled: Boolean = -+ getConf(SQLConf.CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ) ++ private val celebornOptimizeSkewedPartitionReadEnabled = new AtomicBoolean() ++ private val stageRerunEnabled = new AtomicBoolean() ++ private val skewShuffleIds = ConcurrentHashMap.newKeySet[Int]() ++ ++ // call this from SparkEnv.create ++ def init(env: SparkEnv): Unit = { ++ // cleanup existing state (if required) - and initialize ++ skewShuffleIds.clear() ++ ++ // use env.conf for all initialization, and not SQLConf ++ celebornOptimizeSkewedPartitionReadEnabled.set( ++ env.conf.get("spark.shuffle.manager", "sort").contains("celeborn") && ++ env.conf.get(CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ)) ++ stageRerunEnabled.set(env.conf.get(CELEBORN_STAGE_RERUN_ENABLED)) ++ } ++ ++ def unregisterCelebornSkewedShuffle(shuffleId: Int): Unit = { ++ skewShuffleIds.remove(shuffleId) ++ } ++ ++ def registerCelebornSkewedShuffle(shuffleId: Int): Unit = { ++ skewShuffleIds.add(shuffleId) ++ } ++ ++ def isCelebornSkewedShuffle(shuffleId: Int): Boolean = { ++ skewShuffleIds.contains(shuffleId) ++ } + -+ def celebornStageRerunEnabled: Boolean = getConf(SQLConf.CELEBORN_STAGE_RERUN_ENABLED) ++ def celebornAdaptiveOptimizeSkewedPartitionReadEnabled: Boolean = { ++ celebornOptimizeSkewedPartitionReadEnabled.get() ++ } ++ ++ def celebornStageRerunEnabled: Boolean = { ++ stageRerunEnabled.get() ++ } + - /** ********************** SQLConf functionality methods ************ */ ++} +diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +index b950c07f3d8..2cb430c3c3d 100644 +--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala ++++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +@@ -33,6 +33,7 @@ import com.google.common.util.concurrent.{Futures, SettableFuture} - /** Set Spark SQL configuration properties. */ + import org.apache.spark._ + import org.apache.spark.broadcast.Broadcast ++import org.apache.spark.celeborn.CelebornShuffleState + import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} + import org.apache.spark.internal.Logging + import org.apache.spark.internal.config +@@ -1780,7 +1781,7 @@ private[spark] class DAGScheduler( + failedStage.failedAttemptIds.add(task.stageAttemptId) + val shouldAbortStage = + failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts || +- disallowStageRetryForTest ++ disallowStageRetryForTest || CelebornShuffleState.isCelebornSkewedShuffle(shuffleId) + + // It is likely that we receive multiple FetchFailed for a single stage (because we have + // multiple tasks running concurrently on different executors). In that case, it is diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala new file mode 100644 index 00000000000..3dc60678461 @@ -137,7 +197,7 @@ index 00000000000..3dc60678461 + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala -index 1752907a9a5..2a0fe20c104 100644 +index 1752907a9a5..2c6a49b78eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala @@ -50,12 +50,13 @@ object OptimizeSkewInRebalancePartitions extends AQEShuffleReadRule { @@ -150,10 +210,9 @@ index 1752907a9a5..2a0fe20c104 100644 bytesByPartitionId.indices.flatMap { reduceIndex => val bytes = bytesByPartitionId(reduceIndex) if (bytes > targetSize) { -- val newPartitionSpec = + val newPartitionSpec = - ShufflePartitionsUtil.createSkewPartitionSpecs(shuffleId, reduceIndex, targetSize) -+ val newPartitionSpec = ShufflePartitionsUtil.createSkewPartitionSpecs( -+ shuffleId, reduceIndex, targetSize, isCelebornShuffle) ++ ShufflePartitionsUtil.createSkewPartitionSpecs(shuffleId, reduceIndex, targetSize, isCelebornShuffle) if (newPartitionSpec.isEmpty) { CoalescedPartitionSpec(reduceIndex, reduceIndex + 1, bytes) :: Nil } else { @@ -197,18 +256,18 @@ index 88abe68197b..150699a84a3 100644 logDebug(s"Right side partition $partitionIndex " + s"(${FileUtils.byteCountToDisplaySize(rightSize)}) is skewed, " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala -index 3609548f374..59c80198f19 100644 +index 3609548f374..d34f43bf064 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala -@@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer +@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.adaptive + import scala.collection.mutable.ArrayBuffer + import org.apache.spark.{MapOutputStatistics, MapOutputTrackerMaster, SparkEnv} ++import org.apache.spark.celeborn.CelebornShuffleState import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.{CoalescedPartitionSpec, PartialReducerPartitionSpec, ShufflePartitionSpec} -+import org.apache.spark.sql.internal.SQLConf - object ShufflePartitionsUtil extends Logging { - final val SMALL_PARTITION_FACTOR = 0.2 -@@ -376,11 +377,22 @@ object ShufflePartitionsUtil extends Logging { +@@ -376,11 +377,20 @@ object ShufflePartitionsUtil extends Logging { def createSkewPartitionSpecs( shuffleId: Int, reducerId: Int, @@ -219,20 +278,18 @@ index 3609548f374..59c80198f19 100644 if (mapPartitionSizes.exists(_ < 0)) return None val mapStartIndices = splitSizeListByTargetSize(mapPartitionSizes, targetSize) if (mapStartIndices.length > 1) { -+ val isCelebornClientAdaptiveOptimizeSkewedPartitionReadEnabled = -+ SparkEnv.get.conf.get("spark.shuffle.manager", "sort").contains("celeborn") && -+ SQLConf.get.celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled && isCelebornShuffle ++ val celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled = ++ CelebornShuffleState.celebornAdaptiveOptimizeSkewedPartitionReadEnabled && isCelebornShuffle + -+ val throwsFetchFailure = SQLConf.get.celebornStageRerunEnabled -+ if (throwsFetchFailure && isCelebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) { ++ val throwsFetchFailure = CelebornShuffleState.celebornStageRerunEnabled ++ if (throwsFetchFailure && celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) { + logInfo(s"Celeborn shuffle retry enabled and shuffle $shuffleId is skewed") -+ val mapOutputTracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] -+ mapOutputTracker.skewShuffleIds.add(shuffleId) ++ CelebornShuffleState.registerCelebornSkewedShuffle(shuffleId) + } Some(mapStartIndices.indices.map { i => val startMapIndex = mapStartIndices(i) val endMapIndex = if (i == mapStartIndices.length - 1) { -@@ -388,8 +400,20 @@ object ShufflePartitionsUtil extends Logging { +@@ -388,8 +398,21 @@ object ShufflePartitionsUtil extends Logging { } else { mapStartIndices(i + 1) } @@ -244,7 +301,8 @@ index 3609548f374..59c80198f19 100644 + dataSize += mapPartitionSizes(mapIndex) + mapIndex += 1 + } -+ if (isCelebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) { ++ ++ if (celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) { + // These `dataSize` variables may not be accurate as they only represent the sum of + // `dataSize` when the Celeborn optimize skewed partition read feature is enabled. + // Please not to use these dataSize variables in any other part of the codebase. diff --git a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch index 3c194f8ce8d..f8e38615c6a 100644 --- a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch +++ b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch @@ -14,87 +14,147 @@ # limitations under the License. diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala -index b1974948430..0dc92ec44a8 100644 +index b1974948430..a045c8646ba 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala -@@ -696,6 +696,8 @@ private[spark] class MapOutputTrackerMaster( - pool - } +@@ -33,6 +33,7 @@ import org.apache.commons.io.output.{ByteArrayOutputStream => ApacheByteArrayOut + import org.roaringbitmap.RoaringBitmap -+ val skewShuffleIds = ConcurrentHashMap.newKeySet[Int]() -+ - // Make sure that we aren't going to exceed the max RPC message size by making sure - // we use broadcast to send large map output statuses. - if (minSizeForBroadcast > maxRpcMessageSize) { -@@ -886,6 +888,7 @@ private[spark] class MapOutputTrackerMaster( + import org.apache.spark.broadcast.{Broadcast, BroadcastManager} ++import org.apache.spark.celeborn.CelebornShuffleState + import org.apache.spark.internal.Logging + import org.apache.spark.internal.config._ + import org.apache.spark.io.CompressionCodec +@@ -886,6 +887,7 @@ private[spark] class MapOutputTrackerMaster( shuffleStatus.invalidateSerializedMergeOutputStatusCache() } } -+ skewShuffleIds.remove(shuffleId) ++ CelebornShuffleState.unregisterCelebornSkewedShuffle(shuffleId) } /** -diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala -index bd2823bcac1..4f40becadc7 100644 ---- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala -+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala -@@ -1404,7 +1404,10 @@ private[spark] class DAGScheduler( - // The operation here can make sure for the partially completed intermediate stage, - // `findMissingPartitions()` returns all partitions every time. - stage match { -- case sms: ShuffleMapStage if stage.isIndeterminate && !sms.isAvailable => -+ case sms: ShuffleMapStage if (stage.isIndeterminate && !sms.isAvailable) || -+ mapOutputTracker.skewShuffleIds.contains(sms.shuffleDep.shuffleId) => -+ logInfo(s"Unregistering shuffle output for stage ${stage.id}" + -+ s" shuffle ${sms.shuffleDep.shuffleId}") - mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId) - sms.shuffleDep.newShuffleMergeState() - case _ => -@@ -1851,7 +1854,7 @@ private[spark] class DAGScheduler( - failedStage.failedAttemptIds.add(task.stageAttemptId) - val shouldAbortStage = - failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts || -- disallowStageRetryForTest -+ disallowStageRetryForTest || mapOutputTracker.skewShuffleIds.contains(shuffleId) - - // It is likely that we receive multiple FetchFailed for a single stage (because we have - // multiple tasks running concurrently on different executors). In that case, it is -diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala -index af03ad9a4cb..6c36fb96d58 100644 ---- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala -+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala -@@ -3784,6 +3784,19 @@ object SQLConf { - .booleanConf - .createWithDefault(false) +diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala +index 19467e7eca1..0ae4990219c 100644 +--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala ++++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala +@@ -32,6 +32,7 @@ import org.apache.hadoop.conf.Configuration + import org.apache.spark.annotation.DeveloperApi + import org.apache.spark.api.python.PythonWorkerFactory + import org.apache.spark.broadcast.BroadcastManager ++import org.apache.spark.celeborn.CelebornShuffleState + import org.apache.spark.executor.ExecutorBackend + import org.apache.spark.internal.{config, Logging} + import org.apache.spark.internal.config._ +@@ -419,6 +420,7 @@ object SparkEnv extends Logging { + if (isDriver) { + val sparkFilesDir = Utils.createTempDir(Utils.getLocalDir(conf), "userFiles").getAbsolutePath + envInstance.driverTmpDir = Some(sparkFilesDir) ++ CelebornShuffleState.init(envInstance) + } -+ val CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ = -+ buildConf("spark.celeborn.client.adaptive.optimizeSkewedPartitionRead.enabled") -+ .version("3.0.0") + envInstance +diff --git a/core/src/main/scala/org/apache/spark/celeborn/CelebornShuffleState.scala b/core/src/main/scala/org/apache/spark/celeborn/CelebornShuffleState.scala +new file mode 100644 +index 00000000000..5e190c512df +--- /dev/null ++++ b/core/src/main/scala/org/apache/spark/celeborn/CelebornShuffleState.scala +@@ -0,0 +1,75 @@ ++/* ++ * Licensed to the Apache Software Foundation (ASF) under one or more ++ * contributor license agreements. See the NOTICE file distributed with ++ * this work for additional information regarding copyright ownership. ++ * The ASF licenses this file to You under the Apache License, Version 2.0 ++ * (the "License"); you may not use this file except in compliance with ++ * the License. You may obtain a copy of the License at ++ * ++ * http://www.apache.org/licenses/LICENSE-2.0 ++ * ++ * Unless required by applicable law or agreed to in writing, software ++ * distributed under the License is distributed on an "AS IS" BASIS, ++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++ * See the License for the specific language governing permissions and ++ * limitations under the License. ++ */ ++ ++package org.apache.spark.celeborn ++ ++import java.util.concurrent.ConcurrentHashMap ++import java.util.concurrent.atomic.AtomicBoolean ++ ++import org.apache.spark.SparkEnv ++import org.apache.spark.internal.config.ConfigBuilder ++ ++object CelebornShuffleState { ++ ++ private val CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ = ++ ConfigBuilder("spark.celeborn.client.adaptive.optimizeSkewedPartitionRead.enabled") + .booleanConf + .createWithDefault(false) + -+ val CELEBORN_STAGE_RERUN_ENABLED = -+ buildConf("spark.celeborn.client.spark.stageRerun.enabled") ++ private val CELEBORN_STAGE_RERUN_ENABLED = ++ ConfigBuilder("spark.celeborn.client.spark.stageRerun.enabled") + .withAlternative("spark.celeborn.client.spark.fetch.throwsFetchFailure") -+ .version("3.0.0") + .booleanConf + .createWithDefault(false) + - /** - * Holds information about keys that have been deprecated. - * -@@ -4549,6 +4562,11 @@ class SQLConf extends Serializable with Logging { - def histogramNumericPropagateInputType: Boolean = - getConf(SQLConf.HISTOGRAM_NUMERIC_PROPAGATE_INPUT_TYPE) - -+ def celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled: Boolean = -+ getConf(SQLConf.CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ) ++ private val celebornOptimizeSkewedPartitionReadEnabled = new AtomicBoolean() ++ private val stageRerunEnabled = new AtomicBoolean() ++ private val skewShuffleIds = ConcurrentHashMap.newKeySet[Int]() + -+ def celebornStageRerunEnabled: Boolean = getConf(SQLConf.CELEBORN_STAGE_RERUN_ENABLED) ++ // call this from SparkEnv.create ++ def init(env: SparkEnv): Unit = { ++ // cleanup existing state (if required) - and initialize ++ skewShuffleIds.clear() ++ ++ // use env.conf for all initialization, and not SQLConf ++ celebornOptimizeSkewedPartitionReadEnabled.set( ++ env.conf.get("spark.shuffle.manager", "sort").contains("celeborn") && ++ env.conf.get(CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ)) ++ stageRerunEnabled.set(env.conf.get(CELEBORN_STAGE_RERUN_ENABLED)) ++ } + - /** ********************** SQLConf functionality methods ************ */ ++ def unregisterCelebornSkewedShuffle(shuffleId: Int): Unit = { ++ skewShuffleIds.remove(shuffleId) ++ } ++ ++ def registerCelebornSkewedShuffle(shuffleId: Int): Unit = { ++ skewShuffleIds.add(shuffleId) ++ } ++ ++ def isCelebornSkewedShuffle(shuffleId: Int): Boolean = { ++ skewShuffleIds.contains(shuffleId) ++ } ++ ++ def celebornAdaptiveOptimizeSkewedPartitionReadEnabled: Boolean = { ++ celebornOptimizeSkewedPartitionReadEnabled.get() ++ } ++ ++ def celebornStageRerunEnabled: Boolean = { ++ stageRerunEnabled.get() ++ } ++ ++} +diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +index bd2823bcac1..d0c88081527 100644 +--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala ++++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +@@ -33,6 +33,7 @@ import com.google.common.util.concurrent.{Futures, SettableFuture} + + import org.apache.spark._ + import org.apache.spark.broadcast.Broadcast ++import org.apache.spark.celeborn.CelebornShuffleState + import org.apache.spark.errors.SparkCoreErrors + import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} + import org.apache.spark.internal.Logging +@@ -1851,7 +1852,7 @@ private[spark] class DAGScheduler( + failedStage.failedAttemptIds.add(task.stageAttemptId) + val shouldAbortStage = + failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts || +- disallowStageRetryForTest ++ disallowStageRetryForTest || CelebornShuffleState.isCelebornSkewedShuffle(shuffleId) - /** Set Spark SQL configuration properties. */ + // It is likely that we receive multiple FetchFailed for a single stage (because we have + // multiple tasks running concurrently on different executors). In that case, it is diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala new file mode 100644 index 00000000000..3dc60678461 @@ -198,18 +258,18 @@ index d4a173bb9cc..21ef335e064 100644 logDebug(s"Right side partition $partitionIndex " + s"(${FileUtils.byteCountToDisplaySize(rightSize)}) is skewed, " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala -index af689db3379..529097549ba 100644 +index af689db3379..39d0b3132ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala -@@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer +@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.adaptive + import scala.collection.mutable.ArrayBuffer + import org.apache.spark.{MapOutputStatistics, MapOutputTrackerMaster, SparkEnv} ++import org.apache.spark.celeborn.CelebornShuffleState import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.{CoalescedPartitionSpec, PartialReducerPartitionSpec, ShufflePartitionSpec} -+import org.apache.spark.sql.internal.SQLConf - object ShufflePartitionsUtil extends Logging { - final val SMALL_PARTITION_FACTOR = 0.2 -@@ -380,13 +381,23 @@ object ShufflePartitionsUtil extends Logging { +@@ -380,13 +381,21 @@ object ShufflePartitionsUtil extends Logging { shuffleId: Int, reducerId: Int, targetSize: Long, @@ -222,26 +282,24 @@ index af689db3379..529097549ba 100644 val mapStartIndices = splitSizeListByTargetSize( mapPartitionSizes, targetSize, smallPartitionFactor) if (mapStartIndices.length > 1) { -+ val isCelebornClientAdaptiveOptimizeSkewedPartitionReadEnabled = -+ SparkEnv.get.conf.get("spark.shuffle.manager", "sort").contains("celeborn") && -+ SQLConf.get.celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled && isCelebornShuffle ++ val celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled = ++ CelebornShuffleState.celebornAdaptiveOptimizeSkewedPartitionReadEnabled && isCelebornShuffle + -+ val throwsFetchFailure = SQLConf.get.celebornStageRerunEnabled -+ if (throwsFetchFailure && isCelebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) { ++ val throwsFetchFailure = CelebornShuffleState.celebornStageRerunEnabled ++ if (throwsFetchFailure && celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) { + logInfo(s"Celeborn shuffle retry enabled and shuffle $shuffleId is skewed") -+ val mapOutputTracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] -+ mapOutputTracker.skewShuffleIds.add(shuffleId) ++ CelebornShuffleState.registerCelebornSkewedShuffle(shuffleId) + } Some(mapStartIndices.indices.map { i => val startMapIndex = mapStartIndices(i) val endMapIndex = if (i == mapStartIndices.length - 1) { -@@ -400,7 +411,15 @@ object ShufflePartitionsUtil extends Logging { +@@ -400,7 +409,15 @@ object ShufflePartitionsUtil extends Logging { dataSize += mapPartitionSizes(mapIndex) mapIndex += 1 } - PartialReducerPartitionSpec(reducerId, startMapIndex, endMapIndex, dataSize) + -+ if (isCelebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) { ++ if (celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) { + // These `dataSize` variables may not be accurate as they only represent the sum of + // `dataSize` when the Celeborn optimize skewed partition read feature is enabled. + // Please not to use these dataSize variables in any other part of the codebase. diff --git a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_4.patch b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_4.patch index 3c194f8ce8d..9aed835fe96 100644 --- a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_4.patch +++ b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_4.patch @@ -14,87 +14,147 @@ # limitations under the License. diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala -index b1974948430..0dc92ec44a8 100644 +index fade0b86dd8..ca0940a9251 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala -@@ -696,6 +696,8 @@ private[spark] class MapOutputTrackerMaster( - pool - } +@@ -34,6 +34,7 @@ import org.apache.commons.io.output.{ByteArrayOutputStream => ApacheByteArrayOut + import org.roaringbitmap.RoaringBitmap -+ val skewShuffleIds = ConcurrentHashMap.newKeySet[Int]() -+ - // Make sure that we aren't going to exceed the max RPC message size by making sure - // we use broadcast to send large map output statuses. - if (minSizeForBroadcast > maxRpcMessageSize) { -@@ -886,6 +888,7 @@ private[spark] class MapOutputTrackerMaster( + import org.apache.spark.broadcast.{Broadcast, BroadcastManager} ++import org.apache.spark.celeborn.CelebornShuffleState + import org.apache.spark.internal.Logging + import org.apache.spark.internal.config._ + import org.apache.spark.io.CompressionCodec +@@ -887,6 +888,7 @@ private[spark] class MapOutputTrackerMaster( shuffleStatus.invalidateSerializedMergeOutputStatusCache() } } -+ skewShuffleIds.remove(shuffleId) ++ CelebornShuffleState.unregisterCelebornSkewedShuffle(shuffleId) } /** -diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala -index bd2823bcac1..4f40becadc7 100644 ---- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala -+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala -@@ -1404,7 +1404,10 @@ private[spark] class DAGScheduler( - // The operation here can make sure for the partially completed intermediate stage, - // `findMissingPartitions()` returns all partitions every time. - stage match { -- case sms: ShuffleMapStage if stage.isIndeterminate && !sms.isAvailable => -+ case sms: ShuffleMapStage if (stage.isIndeterminate && !sms.isAvailable) || -+ mapOutputTracker.skewShuffleIds.contains(sms.shuffleDep.shuffleId) => -+ logInfo(s"Unregistering shuffle output for stage ${stage.id}" + -+ s" shuffle ${sms.shuffleDep.shuffleId}") - mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId) - sms.shuffleDep.newShuffleMergeState() - case _ => -@@ -1851,7 +1854,7 @@ private[spark] class DAGScheduler( - failedStage.failedAttemptIds.add(task.stageAttemptId) - val shouldAbortStage = - failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts || -- disallowStageRetryForTest -+ disallowStageRetryForTest || mapOutputTracker.skewShuffleIds.contains(shuffleId) - - // It is likely that we receive multiple FetchFailed for a single stage (because we have - // multiple tasks running concurrently on different executors). In that case, it is -diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala -index af03ad9a4cb..6c36fb96d58 100644 ---- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala -+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala -@@ -3784,6 +3784,19 @@ object SQLConf { - .booleanConf - .createWithDefault(false) +diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala +index acab9a634fc..23eb72c49ac 100644 +--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala ++++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala +@@ -32,6 +32,7 @@ import org.apache.hadoop.conf.Configuration + import org.apache.spark.annotation.DeveloperApi + import org.apache.spark.api.python.PythonWorkerFactory + import org.apache.spark.broadcast.BroadcastManager ++import org.apache.spark.celeborn.CelebornShuffleState + import org.apache.spark.executor.ExecutorBackend + import org.apache.spark.internal.{config, Logging} + import org.apache.spark.internal.config._ +@@ -419,6 +420,7 @@ object SparkEnv extends Logging { + if (isDriver) { + val sparkFilesDir = Utils.createTempDir(Utils.getLocalDir(conf), "userFiles").getAbsolutePath + envInstance.driverTmpDir = Some(sparkFilesDir) ++ CelebornShuffleState.init(envInstance) + } -+ val CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ = -+ buildConf("spark.celeborn.client.adaptive.optimizeSkewedPartitionRead.enabled") -+ .version("3.0.0") + envInstance +diff --git a/core/src/main/scala/org/apache/spark/celeborn/CelebornShuffleState.scala b/core/src/main/scala/org/apache/spark/celeborn/CelebornShuffleState.scala +new file mode 100644 +index 00000000000..5e190c512df +--- /dev/null ++++ b/core/src/main/scala/org/apache/spark/celeborn/CelebornShuffleState.scala +@@ -0,0 +1,75 @@ ++/* ++ * Licensed to the Apache Software Foundation (ASF) under one or more ++ * contributor license agreements. See the NOTICE file distributed with ++ * this work for additional information regarding copyright ownership. ++ * The ASF licenses this file to You under the Apache License, Version 2.0 ++ * (the "License"); you may not use this file except in compliance with ++ * the License. You may obtain a copy of the License at ++ * ++ * http://www.apache.org/licenses/LICENSE-2.0 ++ * ++ * Unless required by applicable law or agreed to in writing, software ++ * distributed under the License is distributed on an "AS IS" BASIS, ++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++ * See the License for the specific language governing permissions and ++ * limitations under the License. ++ */ ++ ++package org.apache.spark.celeborn ++ ++import java.util.concurrent.ConcurrentHashMap ++import java.util.concurrent.atomic.AtomicBoolean ++ ++import org.apache.spark.SparkEnv ++import org.apache.spark.internal.config.ConfigBuilder ++ ++object CelebornShuffleState { ++ ++ private val CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ = ++ ConfigBuilder("spark.celeborn.client.adaptive.optimizeSkewedPartitionRead.enabled") + .booleanConf + .createWithDefault(false) + -+ val CELEBORN_STAGE_RERUN_ENABLED = -+ buildConf("spark.celeborn.client.spark.stageRerun.enabled") ++ private val CELEBORN_STAGE_RERUN_ENABLED = ++ ConfigBuilder("spark.celeborn.client.spark.stageRerun.enabled") + .withAlternative("spark.celeborn.client.spark.fetch.throwsFetchFailure") -+ .version("3.0.0") + .booleanConf + .createWithDefault(false) + - /** - * Holds information about keys that have been deprecated. - * -@@ -4549,6 +4562,11 @@ class SQLConf extends Serializable with Logging { - def histogramNumericPropagateInputType: Boolean = - getConf(SQLConf.HISTOGRAM_NUMERIC_PROPAGATE_INPUT_TYPE) - -+ def celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled: Boolean = -+ getConf(SQLConf.CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ) ++ private val celebornOptimizeSkewedPartitionReadEnabled = new AtomicBoolean() ++ private val stageRerunEnabled = new AtomicBoolean() ++ private val skewShuffleIds = ConcurrentHashMap.newKeySet[Int]() ++ ++ // call this from SparkEnv.create ++ def init(env: SparkEnv): Unit = { ++ // cleanup existing state (if required) - and initialize ++ skewShuffleIds.clear() ++ ++ // use env.conf for all initialization, and not SQLConf ++ celebornOptimizeSkewedPartitionReadEnabled.set( ++ env.conf.get("spark.shuffle.manager", "sort").contains("celeborn") && ++ env.conf.get(CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ)) ++ stageRerunEnabled.set(env.conf.get(CELEBORN_STAGE_RERUN_ENABLED)) ++ } ++ ++ def unregisterCelebornSkewedShuffle(shuffleId: Int): Unit = { ++ skewShuffleIds.remove(shuffleId) ++ } ++ ++ def registerCelebornSkewedShuffle(shuffleId: Int): Unit = { ++ skewShuffleIds.add(shuffleId) ++ } ++ ++ def isCelebornSkewedShuffle(shuffleId: Int): Boolean = { ++ skewShuffleIds.contains(shuffleId) ++ } + -+ def celebornStageRerunEnabled: Boolean = getConf(SQLConf.CELEBORN_STAGE_RERUN_ENABLED) ++ def celebornAdaptiveOptimizeSkewedPartitionReadEnabled: Boolean = { ++ celebornOptimizeSkewedPartitionReadEnabled.get() ++ } ++ ++ def celebornStageRerunEnabled: Boolean = { ++ stageRerunEnabled.get() ++ } + - /** ********************** SQLConf functionality methods ************ */ ++} +diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +index 26be8c72bbc..81feaba962c 100644 +--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala ++++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +@@ -34,6 +34,7 @@ import com.google.common.util.concurrent.{Futures, SettableFuture} + + import org.apache.spark._ + import org.apache.spark.broadcast.Broadcast ++import org.apache.spark.celeborn.CelebornShuffleState + import org.apache.spark.errors.SparkCoreErrors + import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} + import org.apache.spark.internal.Logging +@@ -1897,7 +1898,7 @@ private[spark] class DAGScheduler( + + val shouldAbortStage = + failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts || +- disallowStageRetryForTest ++ disallowStageRetryForTest || CelebornShuffleState.isCelebornSkewedShuffle(shuffleId) - /** Set Spark SQL configuration properties. */ + // It is likely that we receive multiple FetchFailed for a single stage (because we have + // multiple tasks running concurrently on different executors). In that case, it is diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala new file mode 100644 index 00000000000..3dc60678461 @@ -170,7 +230,7 @@ index b34ab3e380b..cb0ed9d05a4 100644 if (newPartitionsSpec.length == mapStats.get.bytesByPartitionId.length) { shuffle diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala -index d4a173bb9cc..21ef335e064 100644 +index 37cdea084d8..4694a06919e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala @@ -152,8 +152,10 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) @@ -198,18 +258,18 @@ index d4a173bb9cc..21ef335e064 100644 logDebug(s"Right side partition $partitionIndex " + s"(${FileUtils.byteCountToDisplaySize(rightSize)}) is skewed, " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala -index af689db3379..529097549ba 100644 +index dbed66683b0..d656c8af6b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala -@@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer +@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.adaptive + import scala.collection.mutable.ArrayBuffer + import org.apache.spark.{MapOutputStatistics, MapOutputTrackerMaster, SparkEnv} ++import org.apache.spark.celeborn.CelebornShuffleState import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.{CoalescedPartitionSpec, PartialReducerPartitionSpec, ShufflePartitionSpec} -+import org.apache.spark.sql.internal.SQLConf - object ShufflePartitionsUtil extends Logging { - final val SMALL_PARTITION_FACTOR = 0.2 -@@ -380,13 +381,23 @@ object ShufflePartitionsUtil extends Logging { +@@ -380,13 +381,21 @@ object ShufflePartitionsUtil extends Logging { shuffleId: Int, reducerId: Int, targetSize: Long, @@ -222,26 +282,24 @@ index af689db3379..529097549ba 100644 val mapStartIndices = splitSizeListByTargetSize( mapPartitionSizes, targetSize, smallPartitionFactor) if (mapStartIndices.length > 1) { -+ val isCelebornClientAdaptiveOptimizeSkewedPartitionReadEnabled = -+ SparkEnv.get.conf.get("spark.shuffle.manager", "sort").contains("celeborn") && -+ SQLConf.get.celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled && isCelebornShuffle ++ val celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled = ++ CelebornShuffleState.celebornAdaptiveOptimizeSkewedPartitionReadEnabled && isCelebornShuffle + -+ val throwsFetchFailure = SQLConf.get.celebornStageRerunEnabled -+ if (throwsFetchFailure && isCelebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) { ++ val throwsFetchFailure = CelebornShuffleState.celebornStageRerunEnabled ++ if (throwsFetchFailure && celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) { + logInfo(s"Celeborn shuffle retry enabled and shuffle $shuffleId is skewed") -+ val mapOutputTracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] -+ mapOutputTracker.skewShuffleIds.add(shuffleId) ++ CelebornShuffleState.registerCelebornSkewedShuffle(shuffleId) + } Some(mapStartIndices.indices.map { i => val startMapIndex = mapStartIndices(i) val endMapIndex = if (i == mapStartIndices.length - 1) { -@@ -400,7 +411,15 @@ object ShufflePartitionsUtil extends Logging { +@@ -400,7 +409,15 @@ object ShufflePartitionsUtil extends Logging { dataSize += mapPartitionSizes(mapIndex) mapIndex += 1 } - PartialReducerPartitionSpec(reducerId, startMapIndex, endMapIndex, dataSize) + -+ if (isCelebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) { ++ if (celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) { + // These `dataSize` variables may not be accurate as they only represent the sum of + // `dataSize` when the Celeborn optimize skewed partition read feature is enabled. + // Please not to use these dataSize variables in any other part of the codebase. diff --git a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_5.patch b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_5.patch index 66544751433..553bdeae668 100644 --- a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_5.patch +++ b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_5.patch @@ -14,87 +14,147 @@ # limitations under the License. diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala -index 9a7a3b0c0e7..c886263b3eb 100644 +index 9a7a3b0c0e7..543423dadd9 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala -@@ -726,6 +726,8 @@ private[spark] class MapOutputTrackerMaster( +@@ -34,6 +34,7 @@ import org.apache.commons.io.output.{ByteArrayOutputStream => ApacheByteArrayOut + import org.roaringbitmap.RoaringBitmap - private val availableProcessors = Runtime.getRuntime.availableProcessors() - -+ val skewShuffleIds = ConcurrentHashMap.newKeySet[Int]() -+ - // Make sure that we aren't going to exceed the max RPC message size by making sure - // we use broadcast to send large map output statuses. - if (minSizeForBroadcast > maxRpcMessageSize) { -@@ -916,6 +918,7 @@ private[spark] class MapOutputTrackerMaster( + import org.apache.spark.broadcast.{Broadcast, BroadcastManager} ++import org.apache.spark.celeborn.CelebornShuffleState + import org.apache.spark.internal.Logging + import org.apache.spark.internal.config._ + import org.apache.spark.io.CompressionCodec +@@ -916,6 +917,7 @@ private[spark] class MapOutputTrackerMaster( shuffleStatus.invalidateSerializedMergeOutputStatusCache() } } -+ skewShuffleIds.remove(shuffleId) ++ CelebornShuffleState.unregisterCelebornSkewedShuffle(shuffleId) } /** -diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala -index 89d16e57934..24aad10c2a8 100644 ---- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala -+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala -@@ -1480,7 +1480,10 @@ private[spark] class DAGScheduler( - // The operation here can make sure for the partially completed intermediate stage, - // `findMissingPartitions()` returns all partitions every time. - stage match { -- case sms: ShuffleMapStage if stage.isIndeterminate && !sms.isAvailable => -+ case sms: ShuffleMapStage if (stage.isIndeterminate && !sms.isAvailable) || -+ mapOutputTracker.skewShuffleIds.contains(sms.shuffleDep.shuffleId) => -+ logInfo(s"Unregistering shuffle output for stage ${stage.id}" + -+ s" shuffle ${sms.shuffleDep.shuffleId}") - mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId) - sms.shuffleDep.newShuffleMergeState() - case _ => -@@ -1962,7 +1965,7 @@ private[spark] class DAGScheduler( - - val shouldAbortStage = - failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts || -- disallowStageRetryForTest -+ disallowStageRetryForTest || mapOutputTracker.skewShuffleIds.contains(shuffleId) - - // It is likely that we receive multiple FetchFailed for a single stage (because we have - // multiple tasks running concurrently on different executors). In that case, it is -diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala -index 6f2f0088fcc..3a7b1aabbbc 100644 ---- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala -+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala -@@ -4423,6 +4423,19 @@ object SQLConf { - .booleanConf - .createWithDefault(false) +diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala +index edad91a0c6f..76b377729a0 100644 +--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala ++++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala +@@ -32,6 +32,7 @@ import org.apache.hadoop.conf.Configuration + import org.apache.spark.annotation.DeveloperApi + import org.apache.spark.api.python.PythonWorkerFactory + import org.apache.spark.broadcast.BroadcastManager ++import org.apache.spark.celeborn.CelebornShuffleState + import org.apache.spark.executor.ExecutorBackend + import org.apache.spark.internal.{config, Logging} + import org.apache.spark.internal.config._ +@@ -419,6 +420,7 @@ object SparkEnv extends Logging { + if (isDriver) { + val sparkFilesDir = Utils.createTempDir(Utils.getLocalDir(conf), "userFiles").getAbsolutePath + envInstance.driverTmpDir = Some(sparkFilesDir) ++ CelebornShuffleState.init(envInstance) + } -+ val CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ = -+ buildConf("spark.celeborn.client.adaptive.optimizeSkewedPartitionRead.enabled") -+ .version("3.0.0") + envInstance +diff --git a/core/src/main/scala/org/apache/spark/celeborn/CelebornShuffleState.scala b/core/src/main/scala/org/apache/spark/celeborn/CelebornShuffleState.scala +new file mode 100644 +index 00000000000..5e190c512df +--- /dev/null ++++ b/core/src/main/scala/org/apache/spark/celeborn/CelebornShuffleState.scala +@@ -0,0 +1,75 @@ ++/* ++ * Licensed to the Apache Software Foundation (ASF) under one or more ++ * contributor license agreements. See the NOTICE file distributed with ++ * this work for additional information regarding copyright ownership. ++ * The ASF licenses this file to You under the Apache License, Version 2.0 ++ * (the "License"); you may not use this file except in compliance with ++ * the License. You may obtain a copy of the License at ++ * ++ * http://www.apache.org/licenses/LICENSE-2.0 ++ * ++ * Unless required by applicable law or agreed to in writing, software ++ * distributed under the License is distributed on an "AS IS" BASIS, ++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++ * See the License for the specific language governing permissions and ++ * limitations under the License. ++ */ ++ ++package org.apache.spark.celeborn ++ ++import java.util.concurrent.ConcurrentHashMap ++import java.util.concurrent.atomic.AtomicBoolean ++ ++import org.apache.spark.SparkEnv ++import org.apache.spark.internal.config.ConfigBuilder ++ ++object CelebornShuffleState { ++ ++ private val CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ = ++ ConfigBuilder("spark.celeborn.client.adaptive.optimizeSkewedPartitionRead.enabled") + .booleanConf + .createWithDefault(false) + -+ val CELEBORN_STAGE_RERUN_ENABLED = -+ buildConf("spark.celeborn.client.spark.stageRerun.enabled") ++ private val CELEBORN_STAGE_RERUN_ENABLED = ++ ConfigBuilder("spark.celeborn.client.spark.stageRerun.enabled") + .withAlternative("spark.celeborn.client.spark.fetch.throwsFetchFailure") -+ .version("3.0.0") + .booleanConf + .createWithDefault(false) + - /** - * Holds information about keys that have been deprecated. - * -@@ -5278,6 +5291,11 @@ class SQLConf extends Serializable with Logging with SqlApiConf { - getConf(SQLConf.LEGACY_NEGATIVE_INDEX_IN_ARRAY_INSERT) - } - -+ def celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled: Boolean = -+ getConf(SQLConf.CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ) ++ private val celebornOptimizeSkewedPartitionReadEnabled = new AtomicBoolean() ++ private val stageRerunEnabled = new AtomicBoolean() ++ private val skewShuffleIds = ConcurrentHashMap.newKeySet[Int]() + -+ def celebornStageRerunEnabled: Boolean = getConf(SQLConf.CELEBORN_STAGE_RERUN_ENABLED) ++ // call this from SparkEnv.create ++ def init(env: SparkEnv): Unit = { ++ // cleanup existing state (if required) - and initialize ++ skewShuffleIds.clear() ++ ++ // use env.conf for all initialization, and not SQLConf ++ celebornOptimizeSkewedPartitionReadEnabled.set( ++ env.conf.get("spark.shuffle.manager", "sort").contains("celeborn") && ++ env.conf.get(CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ)) ++ stageRerunEnabled.set(env.conf.get(CELEBORN_STAGE_RERUN_ENABLED)) ++ } + - /** ********************** SQLConf functionality methods ************ */ ++ def unregisterCelebornSkewedShuffle(shuffleId: Int): Unit = { ++ skewShuffleIds.remove(shuffleId) ++ } ++ ++ def registerCelebornSkewedShuffle(shuffleId: Int): Unit = { ++ skewShuffleIds.add(shuffleId) ++ } ++ ++ def isCelebornSkewedShuffle(shuffleId: Int): Boolean = { ++ skewShuffleIds.contains(shuffleId) ++ } ++ ++ def celebornAdaptiveOptimizeSkewedPartitionReadEnabled: Boolean = { ++ celebornOptimizeSkewedPartitionReadEnabled.get() ++ } ++ ++ def celebornStageRerunEnabled: Boolean = { ++ stageRerunEnabled.get() ++ } ++ ++} +diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +index 89d16e57934..3b9094f3254 100644 +--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala ++++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +@@ -34,6 +34,7 @@ import com.google.common.util.concurrent.{Futures, SettableFuture} + + import org.apache.spark._ + import org.apache.spark.broadcast.Broadcast ++import org.apache.spark.celeborn.CelebornShuffleState + import org.apache.spark.errors.SparkCoreErrors + import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} + import org.apache.spark.internal.Logging +@@ -1962,7 +1963,7 @@ private[spark] class DAGScheduler( + + val shouldAbortStage = + failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts || +- disallowStageRetryForTest ++ disallowStageRetryForTest || CelebornShuffleState.isCelebornSkewedShuffle(shuffleId) - /** Set Spark SQL configuration properties. */ + // It is likely that we receive multiple FetchFailed for a single stage (because we have + // multiple tasks running concurrently on different executors). In that case, it is diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala new file mode 100644 index 00000000000..3dc60678461 @@ -198,18 +258,18 @@ index 37cdea084d8..4694a06919e 100644 logDebug(s"Right side partition $partitionIndex " + s"(${FileUtils.byteCountToDisplaySize(rightSize)}) is skewed, " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala -index 9370b3d8d1d..28d4b5b8d8b 100644 +index 9370b3d8d1d..d36e26a1376 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala -@@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer +@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.adaptive + import scala.collection.mutable.ArrayBuffer + import org.apache.spark.{MapOutputStatistics, MapOutputTrackerMaster, SparkEnv} ++import org.apache.spark.celeborn.CelebornShuffleState import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.{CoalescedPartitionSpec, PartialReducerPartitionSpec, ShufflePartitionSpec} -+import org.apache.spark.sql.internal.SQLConf - object ShufflePartitionsUtil extends Logging { - final val SMALL_PARTITION_FACTOR = 0.2 -@@ -382,13 +383,23 @@ object ShufflePartitionsUtil extends Logging { +@@ -382,13 +383,21 @@ object ShufflePartitionsUtil extends Logging { shuffleId: Int, reducerId: Int, targetSize: Long, @@ -222,26 +282,24 @@ index 9370b3d8d1d..28d4b5b8d8b 100644 val mapStartIndices = splitSizeListByTargetSize( mapPartitionSizes, targetSize, smallPartitionFactor) if (mapStartIndices.length > 1) { -+ val isCelebornClientAdaptiveOptimizeSkewedPartitionReadEnabled = -+ SparkEnv.get.conf.get("spark.shuffle.manager", "sort").contains("celeborn") && -+ SQLConf.get.celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled && isCelebornShuffle ++ val celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled = ++ CelebornShuffleState.celebornAdaptiveOptimizeSkewedPartitionReadEnabled && isCelebornShuffle + -+ val throwsFetchFailure = SQLConf.get.celebornStageRerunEnabled -+ if (throwsFetchFailure && isCelebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) { ++ val throwsFetchFailure = CelebornShuffleState.celebornStageRerunEnabled ++ if (throwsFetchFailure && celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) { + logInfo(s"Celeborn shuffle retry enabled and shuffle $shuffleId is skewed") -+ val mapOutputTracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] -+ mapOutputTracker.skewShuffleIds.add(shuffleId) ++ CelebornShuffleState.registerCelebornSkewedShuffle(shuffleId) + } Some(mapStartIndices.indices.map { i => val startMapIndex = mapStartIndices(i) val endMapIndex = if (i == mapStartIndices.length - 1) { -@@ -402,7 +413,15 @@ object ShufflePartitionsUtil extends Logging { +@@ -402,7 +411,15 @@ object ShufflePartitionsUtil extends Logging { dataSize += mapPartitionSizes(mapIndex) mapIndex += 1 } - PartialReducerPartitionSpec(reducerId, startMapIndex, endMapIndex, dataSize) + -+ if (isCelebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) { ++ if (celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) { + // These `dataSize` variables may not be accurate as they only represent the sum of + // `dataSize` when the Celeborn optimize skewed partition read feature is enabled. + // Please not to use these dataSize variables in any other part of the codebase. From 1c5836224a54e1b02d70acff65bb76bef7742090 Mon Sep 17 00:00:00 2001 From: wangshengjie3 Date: Fri, 24 Jan 2025 11:12:04 +0800 Subject: [PATCH 40/44] fix codestyle --- .../apache/spark/shuffle/celeborn/CelebornShuffleReader.scala | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala index d2f8f58dd38..b35e549ae38 100644 --- a/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala +++ b/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala @@ -19,15 +19,13 @@ package org.apache.spark.shuffle.celeborn import java.io.IOException import java.util.{ArrayList => JArrayList, HashMap => JHashMap, Map => JMap, Set => JSet} -import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor, TimeUnit} import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor, TimeoutException, TimeUnit} import java.util.concurrent.atomic.AtomicReference import scala.collection.JavaConverters._ -import org.apache.commons.lang3.tuple.Pair import com.google.common.annotations.VisibleForTesting - +import org.apache.commons.lang3.tuple.Pair import org.apache.spark.{Aggregator, InterruptibleIterator, ShuffleDependency, TaskContext} import org.apache.spark.celeborn.ExceptionMakerHelper import org.apache.spark.internal.Logging From cbceadf63c30899f0a00b4c67f1676d1daf46532 Mon Sep 17 00:00:00 2001 From: wangshengjie3 Date: Fri, 24 Jan 2025 11:27:05 +0800 Subject: [PATCH 41/44] fix uts --- .../celeborn/client/ShuffleClientSuiteJ.java | 24 +++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java b/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java index ba0a7c39b9c..477a3c19235 100644 --- a/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java +++ b/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java @@ -417,7 +417,11 @@ public void testUpdateReducerFileGroupInterrupted() throws InterruptedException t -> { Thread.sleep(60 * 1000); return GetReducerFileGroupResponse$.MODULE$.apply( - StatusCode.SUCCESS, locations, new int[0], Collections.emptySet()); + StatusCode.SUCCESS, + locations, + new int[0], + Collections.emptySet(), + Collections.emptyMap()); }); shuffleClient = @@ -456,7 +460,11 @@ public void testUpdateReducerFileGroupNonFetchFailureExceptions() { .thenAnswer( t -> { return GetReducerFileGroupResponse$.MODULE$.apply( - StatusCode.SHUFFLE_NOT_REGISTERED, locations, new int[0], Collections.emptySet()); + StatusCode.SHUFFLE_NOT_REGISTERED, + locations, + new int[0], + Collections.emptySet(), + Collections.emptyMap()); }); shuffleClient = @@ -473,7 +481,11 @@ public void testUpdateReducerFileGroupNonFetchFailureExceptions() { .thenAnswer( t -> { return GetReducerFileGroupResponse$.MODULE$.apply( - StatusCode.STAGE_END_TIME_OUT, locations, new int[0], Collections.emptySet()); + StatusCode.STAGE_END_TIME_OUT, + locations, + new int[0], + Collections.emptySet(), + Collections.emptyMap()); }); shuffleClient = @@ -490,7 +502,11 @@ public void testUpdateReducerFileGroupNonFetchFailureExceptions() { .thenAnswer( t -> { return GetReducerFileGroupResponse$.MODULE$.apply( - StatusCode.SHUFFLE_DATA_LOST, locations, new int[0], Collections.emptySet()); + StatusCode.SHUFFLE_DATA_LOST, + locations, + new int[0], + Collections.emptySet(), + Collections.emptyMap()); }); shuffleClient = From 21a06357cc238faac94b6c8a66c074ac78fb132e Mon Sep 17 00:00:00 2001 From: wangshengjie3 Date: Thu, 13 Feb 2025 19:45:34 +0800 Subject: [PATCH 42/44] add method javadoc comment --- .../celeborn/CelebornPartitionUtil.java | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/CelebornPartitionUtil.java b/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/CelebornPartitionUtil.java index 0fbef891b69..6bf47addca9 100644 --- a/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/CelebornPartitionUtil.java +++ b/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/CelebornPartitionUtil.java @@ -24,6 +24,38 @@ import org.apache.celeborn.common.protocol.PartitionLocation; public class CelebornPartitionUtil { + /** + * The general idea is to divide each skew partition into smaller partitions: + * + *

- Spark driver will calculate the number of sub-partitions: {@code subPartitionSize = + * skewPartitionTotalSize / subPartitionTargetSize} + * + *

- In Celeborn, we divide the skew partition into {@code subPartitionSize} small partitions + * by PartitionLocation chunk offsets. This allows them to run in parallel Spark tasks. + * + *

For example, one skewed partition has 2 PartitionLocation: + * + *

    + *
  • PartitionLocation 0 with chunk offset [0L, 100L, 200L, 300L, 500L, 1000L] + *
  • PartitionLocation 1 with chunk offset [0L, 200L, 500L, 800L, 900L, 1000L] + *
+ * + * If we want to divide it into 3 sub-partitions (each sub-partition target size is 2000/3), the + * result will be: + * + *
    + *
  • sub-partition 0: uniqueId0 -> (0, 3) + *
  • sub-partition 1: uniqueId0 -> (4, 4), uniqueId1 -> (0, 0) + *
  • sub-partition 2: uniqueId1 -> (1, 4) + *
+ * + * Note: (0, 3) means chunks with chunkIndex 0-1-2-3, four chunks. + * + * @param locations PartitionLocation information belonging to the reduce partition + * @param subPartitionSize the number of sub-partitions separated from the reduce partition + * @param subPartitionIndex current sub-partition index + * @return a map of partitionUniqueId to chunkRange pairs for one subtask of skew partitions + */ public static Map> splitSkewedPartitionLocations( ArrayList locations, int subPartitionSize, int subPartitionIndex) { locations.sort(Comparator.comparing((PartitionLocation p) -> p.getUniqueId())); From a9a4af13288f7687e058c4b08c2cb75cbf7bc29d Mon Sep 17 00:00:00 2001 From: Shuang Date: Mon, 17 Feb 2025 15:37:58 +0800 Subject: [PATCH 43/44] Update common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala Co-authored-by: Cheng Pan --- .../main/scala/org/apache/celeborn/common/CelebornConf.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala index 9cf080bb457..eb750daeb87 100644 --- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala +++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala @@ -5919,7 +5919,7 @@ object CelebornConf extends Logging { val CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ_ENABLED: ConfigEntry[Boolean] = buildConf("celeborn.client.adaptive.optimizeSkewedPartitionRead.enabled") .categories("client") - .version("0.5.0") + .version("0.6.0") .doc("If this is true, Celeborn will adaptively split skewed partitions instead of reading them by Spark map " + "range. Please note that this feature requires the `Celeborn-Optimize-Skew-Partitions-spark3_3.patch`. ") .booleanConf From cbad6f848f3e602c42f4d2032a9cd4b836ad045e Mon Sep 17 00:00:00 2001 From: Shuang Date: Mon, 17 Feb 2025 15:39:58 +0800 Subject: [PATCH 44/44] Update docs/configuration/client.md --- docs/configuration/client.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/configuration/client.md b/docs/configuration/client.md index ce0dcd11204..197dbb7908e 100644 --- a/docs/configuration/client.md +++ b/docs/configuration/client.md @@ -19,7 +19,7 @@ license: | | Key | Default | isDynamic | Description | Since | Deprecated | | --- | ------- | --------- | ----------- | ----- | ---------- | -| celeborn.client.adaptive.optimizeSkewedPartitionRead.enabled | false | false | If this is true, Celeborn will adaptively split skewed partitions instead of reading them by Spark map range. Please note that this feature requires the `Celeborn-Optimize-Skew-Partitions-spark3_3.patch`. | 0.5.0 | | +| celeborn.client.adaptive.optimizeSkewedPartitionRead.enabled | false | false | If this is true, Celeborn will adaptively split skewed partitions instead of reading them by Spark map range. Please note that this feature requires the `Celeborn-Optimize-Skew-Partitions-spark3_3.patch`. | 0.6.0 | | | celeborn.client.application.heartbeatInterval | 10s | false | Interval for client to send heartbeat message to master. | 0.3.0 | celeborn.application.heartbeatInterval | | celeborn.client.application.unregister.enabled | true | false | When true, Celeborn client will inform celeborn master the application is already shutdown during client exit, this allows the cluster to release resources immediately, resulting in resource savings. | 0.3.2 | | | celeborn.client.application.uuidSuffix.enabled | false | false | Whether to add UUID suffix for application id for unique. When `true`, add UUID suffix for unique application id. Currently, this only applies to Spark and MR. | 0.6.0 | |