diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfo.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfo.java index 6fc2159431..e3da2b063e 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfo.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfo.java @@ -53,14 +53,18 @@ public class MutableShuffleHandleInfo extends ShuffleHandleInfoBase { private Map<Integer, Map<Integer, List<ShuffleServerInfo>>> partitionReplicaAssignedServers; private Map<String, Set<ShuffleServerInfo>> excludedServerToReplacements; + /** + * partitionId -> excluded server -> replacement servers. The replacement servers for exclude + * server of specific partition. + */ + private Map<Integer, Map<String, Set<ShuffleServerInfo>>> + excludedServerForPartitionToReplacements; public MutableShuffleHandleInfo( int shuffleId, Map<Integer, List<ShuffleServerInfo>> partitionToServers, RemoteStorageInfo storageInfo) { - super(shuffleId, storageInfo); - this.excludedServerToReplacements = new HashMap<>(); - this.partitionReplicaAssignedServers = toPartitionReplicaMapping(partitionToServers); + this(shuffleId, storageInfo, toPartitionReplicaMapping(partitionToServers)); } @VisibleForTesting @@ -70,6 +74,7 @@ protected MutableShuffleHandleInfo( Map<Integer, Map<Integer, List<ShuffleServerInfo>>> partitionReplicaAssignedServers) { super(shuffleId, storageInfo); this.excludedServerToReplacements = new HashMap<>(); + this.excludedServerForPartitionToReplacements = new HashMap<>(); this.partitionReplicaAssignedServers = partitionReplicaAssignedServers; } @@ -77,7 +82,7 @@ public MutableShuffleHandleInfo(int shuffleId, RemoteStorageInfo storageInfo) { super(shuffleId, storageInfo); } - private Map<Integer, Map<Integer, List<ShuffleServerInfo>>> toPartitionReplicaMapping( + private static Map<Integer, Map<Integer, 
List<ShuffleServerInfo>>> toPartitionReplicaMapping( Map<Integer, List<ShuffleServerInfo>> partitionToServers) { Map<Integer, Map<Integer, List<ShuffleServerInfo>>> partitionReplicaAssignedServers = new HashMap<>(); @@ -102,6 +107,21 @@ public Set<ShuffleServerInfo> getReplacements(String faultyServerId) { return excludedServerToReplacements.get(faultyServerId); } + public Set<ShuffleServerInfo> getReplacementsForPartition( + int partitionId, String excludedServerId) { + return excludedServerForPartitionToReplacements + .getOrDefault(partitionId, Collections.emptyMap()) + .getOrDefault(excludedServerId, Collections.emptySet()); + } + + /** + * Update the assignment for the receiving failure server of the given partition. + * + * @param partitionId the partition id + * @param receivingFailureServerId the id of the receiving failure server + * @param replacements the new assigned servers for replacing the receiving failure server + * @return the updated server list for receiving data + */ public Set<ShuffleServerInfo> updateAssignment( int partitionId, String receivingFailureServerId, Set<ShuffleServerInfo> replacements) { if (replacements == null || StringUtils.isEmpty(receivingFailureServerId)) { @@ -109,6 +129,11 @@ public Set<ShuffleServerInfo> updateAssignment( } excludedServerToReplacements.put(receivingFailureServerId, replacements); + return updateAssignmentInternal(partitionId, receivingFailureServerId, replacements); + } + + private Set<ShuffleServerInfo> updateAssignmentInternal( + int partitionId, String receivingFailureServerId, Set<ShuffleServerInfo> replacements) { Set<ShuffleServerInfo> updatedServers = new HashSet<>(); Map<Integer, List<ShuffleServerInfo>> replicaServers = partitionReplicaAssignedServers.get(partitionId); @@ -131,6 +156,26 @@ public Set<ShuffleServerInfo> updateAssignment( return updatedServers; } + /** + * Update the assignment for the receiving failure server of the need split partition. 
+ * + * @param partitionId the partition id + * @param receivingFailureServerId the id of the receiving failure server + * @param replacements the new assigned servers for replacing the receiving failure server + * @return the updated server list for receiving data + */ + public Set<ShuffleServerInfo> updateAssignmentOnPartitionSplit( + int partitionId, String receivingFailureServerId, Set<ShuffleServerInfo> replacements) { + if (replacements == null || StringUtils.isEmpty(receivingFailureServerId)) { + return Collections.emptySet(); + } + excludedServerForPartitionToReplacements + .computeIfAbsent(partitionId, x -> new HashMap<>()) + .put(receivingFailureServerId, replacements); + + return updateAssignmentInternal(partitionId, receivingFailureServerId, replacements); + } + @Override public Set<ShuffleServerInfo> getServers() { return partitionReplicaAssignedServers.values().stream() @@ -149,6 +194,7 @@ public Map<Integer, List<ShuffleServerInfo>> getAvailablePartitionServersForWrit replicaServers.entrySet()) { ShuffleServerInfo candidate; int candidateSize = replicaServerEntry.getValue().size(); + // Use the last one for each replica writing candidate = replicaServerEntry.getValue().get(candidateSize - 1); assignment.computeIfAbsent(partitionId, x -> new ArrayList<>()).add(candidate); } @@ -266,4 +312,10 @@ public static MutableShuffleHandleInfo fromProto(RssProtos.MutableShuffleHandleI handle.partitionReplicaAssignedServers = partitionToServers; return handle; } + + public Set<String> listExcludedServersForPartition(int partitionId) { + return excludedServerForPartitionToReplacements + .getOrDefault(partitionId, Collections.emptyMap()) + .keySet(); + } } diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/ShuffleHandleInfo.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/ShuffleHandleInfo.java index 99f7a7421b..c0d4b89f57 100644 --- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/ShuffleHandleInfo.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/ShuffleHandleInfo.java @@ -31,7 +31,8 @@ public interface ShuffleHandleInfo { /** * Get the assignment of available servers for writer to write partitioned blocks to corresponding - * shuffleServers. Implementations might return dynamic, up-to-date information here. + * shuffleServers. Implementations might return dynamic, up-to-date information here. Returns + * partitionId -> [replica1, replica2, ...] */ Map<Integer, List<ShuffleServerInfo>> getAvailablePartitionServersForWriter(); diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java index 11d81ed31d..8e921c66e7 100644 --- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java +++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java @@ -737,7 +737,8 @@ public MutableShuffleHandleInfo reassignOnBlockSendFailure( int stageId, int stageAttemptNumber, int shuffleId, - Map<Integer, List<ReceivingFailureServer>> partitionToFailureServers) { + Map<Integer, List<ReceivingFailureServer>> partitionToFailureServers, + boolean partitionSplit) { long startTime = System.currentTimeMillis(); ShuffleHandleInfo handleInfo = shuffleHandleInfoManager.get(shuffleId); MutableShuffleHandleInfo internalHandle = null; @@ -754,8 +755,11 @@ public MutableShuffleHandleInfo reassignOnBlockSendFailure( synchronized (internalHandle) { // If the reassignment servers for one partition exceeds the max reassign server num, // it should fast fail. 
- internalHandle.checkPartitionReassignServerNum( - partitionToFailureServers.keySet(), partitionReassignMaxServerNum); + if (!partitionSplit) { + // Do not check the partition reassign server num for partition split case + internalHandle.checkPartitionReassignServerNum( + partitionToFailureServers.keySet(), partitionReassignMaxServerNum); + } Map<ShuffleServerInfo, List<PartitionRange>> newServerToPartitions = new HashMap<>(); // receivingFailureServer -> partitionId -> replacementServerIds. For logging @@ -769,27 +773,44 @@ public MutableShuffleHandleInfo reassignOnBlockSendFailure( String serverId = receivingFailureServer.getServerId(); boolean serverHasReplaced = false; - Set<ShuffleServerInfo> replacements = internalHandle.getReplacements(serverId); - if (CollectionUtils.isEmpty(replacements)) { - final int requiredServerNum = 1; - Set<String> excludedServers = new HashSet<>(internalHandle.listExcludedServers()); - excludedServers.add(serverId); - replacements = - reassignServerForTask( - stageId, - stageAttemptNumber, - shuffleId, - Sets.newHashSet(partitionId), - excludedServers, - requiredServerNum, - true); + + Set<ShuffleServerInfo> updatedReassignServers; + if (!partitionSplit) { + Set<ShuffleServerInfo> replacements = internalHandle.getReplacements(serverId); + if (CollectionUtils.isEmpty(replacements)) { + replacements = + requestReassignServer( + stageId, + stageAttemptNumber, + shuffleId, + internalHandle, + partitionId, + serverId); + } else { + serverHasReplaced = true; + } + updatedReassignServers = + internalHandle.updateAssignment(partitionId, serverId, replacements); } else { - serverHasReplaced = true; + Set<ShuffleServerInfo> replacements = + internalHandle.getReplacementsForPartition(partitionId, serverId); + if (CollectionUtils.isEmpty(replacements)) { + replacements = + requestReassignServer( + stageId, + stageAttemptNumber, + shuffleId, + internalHandle, + partitionId, + serverId); + } else { + serverHasReplaced = true; + } + 
updatedReassignServers = + internalHandle.updateAssignmentOnPartitionSplit( + partitionId, serverId, replacements); } - Set<ShuffleServerInfo> updatedReassignServers = - internalHandle.updateAssignment(partitionId, serverId, replacements); - if (!updatedReassignServers.isEmpty()) { reassignResult .computeIfAbsent(serverId, x -> new HashMap<>()) @@ -825,6 +846,31 @@ public MutableShuffleHandleInfo reassignOnBlockSendFailure( } } + private Set<ShuffleServerInfo> requestReassignServer( + int stageId, + int stageAttemptNumber, + int shuffleId, + MutableShuffleHandleInfo internalHandle, + int partitionId, + String serverId) { + Set<ShuffleServerInfo> replacements; + final int requiredServerNum = 1; + Set<String> excludedServers = new HashSet<>(internalHandle.listExcludedServers()); + // Exclude the servers that have already been replaced for partition split case. + excludedServers.addAll(internalHandle.listExcludedServersForPartition(partitionId)); + excludedServers.add(serverId); + replacements = + reassignServerForTask( + stageId, + stageAttemptNumber, + shuffleId, + Sets.newHashSet(partitionId), + excludedServers, + requiredServerNum, + true); + return replacements; + } + @Override public void stop() { if (managerClientSupplier != null diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerInterface.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerInterface.java index 77379efb5f..5abb6b832c 100644 --- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerInterface.java +++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerInterface.java @@ -86,5 +86,6 @@ MutableShuffleHandleInfo reassignOnBlockSendFailure( int stageId, int stageAttemptNumber, int shuffleId, - Map<Integer, List<ReceivingFailureServer>> partitionToFailureServers); + Map<Integer, List<ReceivingFailureServer>> partitionToFailureServers, + boolean 
partitionSplit); } diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java index bc2812ae9d..cbc66105bb 100644 --- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java +++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java @@ -267,12 +267,13 @@ public void reassignOnBlockSendFailure( RssProtos.ReassignOnBlockSendFailureResponse reply; try { LOG.info( - "Accepted reassign request on block sent failure for shuffleId: {}, stageId: {}, stageAttemptNumber: {} from taskAttemptId: {} on executorId: {}", + "Accepted reassign request on block sent failure for shuffleId: {}, stageId: {}, stageAttemptNumber: {} from taskAttemptId: {} on executorId: {} while partition split:{}", request.getShuffleId(), request.getStageId(), request.getStageAttemptNumber(), request.getTaskAttemptId(), - request.getExecutorId()); + request.getExecutorId(), + request.getPartitionSplit()); MutableShuffleHandleInfo handle = shuffleManager.reassignOnBlockSendFailure( request.getStageId(), @@ -281,7 +282,8 @@ public void reassignOnBlockSendFailure( request.getFailurePartitionToServerIdsMap().entrySet().stream() .collect( Collectors.toMap( - Map.Entry::getKey, x -> ReceivingFailureServer.fromProto(x.getValue())))); + Map.Entry::getKey, x -> ReceivingFailureServer.fromProto(x.getValue()))), + request.getPartitionSplit()); code = RssProtos.StatusCode.SUCCESS; reply = RssProtos.ReassignOnBlockSendFailureResponse.newBuilder() diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfoTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfoTest.java index e861923584..3a24f301bf 100644 --- 
a/client-spark/common/src/test/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfoTest.java +++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfoTest.java @@ -20,6 +20,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; @@ -142,4 +143,65 @@ public void testCreatePartitionReplicaTracking() { assertEquals(b, inventory.get(1).get(1).get(0)); assertEquals(c, inventory.get(2).get(0).get(0)); } + + @Test + public void testUpdateAssignmentOnPartitionSplit() { + Map<Integer, List<ShuffleServerInfo>> partitionToServers = new HashMap<>(); + partitionToServers.put(1, Arrays.asList(createFakeServerInfo("a"), createFakeServerInfo("b"))); + partitionToServers.put(2, Arrays.asList(createFakeServerInfo("c"))); + + MutableShuffleHandleInfo handleInfo = + new MutableShuffleHandleInfo(1, partitionToServers, new RemoteStorageInfo("")); + + // case1: update the replacement servers but has existing servers + Set<ShuffleServerInfo> updated = + handleInfo.updateAssignment( + 1, "a", Sets.newHashSet(createFakeServerInfo("a"), createFakeServerInfo("d"))); + assertTrue(updated.stream().findFirst().get().getId().equals("d")); + + // case2: update when having multiple servers + Map<Integer, Map<Integer, List<ShuffleServerInfo>>> partitionReplicaAssignedServers = + new HashMap<>(); + List<ShuffleServerInfo> servers = + new ArrayList<>( + Arrays.asList( + createFakeServerInfo("a"), + createFakeServerInfo("b"), + createFakeServerInfo("c"), + createFakeServerInfo("d"))); + partitionReplicaAssignedServers + .computeIfAbsent(1, x -> new HashMap<>()) + .computeIfAbsent(0, x -> servers); + handleInfo = + new MutableShuffleHandleInfo(1, new RemoteStorageInfo(""), partitionReplicaAssignedServers); + + Map<Integer, List<ShuffleServerInfo>> availablePartitionServers = + handleInfo.getAvailablePartitionServersForWriter(); + 
assertEquals("d", availablePartitionServers.get(1).get(0).getHost()); + Map<Integer, List<ShuffleServerInfo>> assignment = handleInfo.getAllPartitionServersForReader(); + assertEquals(4, assignment.get(1).size()); + + int partitionId = 1; + + assertTrue(handleInfo.getReplacementsForPartition(partitionId, "a").isEmpty()); + HashSet<ShuffleServerInfo> replacements = + Sets.newHashSet( + createFakeServerInfo("b"), + createFakeServerInfo("d"), + createFakeServerInfo("e"), + createFakeServerInfo("f")); + updated = handleInfo.updateAssignmentOnPartitionSplit(partitionId, "a", replacements); + assertEquals(Sets.newHashSet(createFakeServerInfo("e"), createFakeServerInfo("f")), updated); + + Set<String> excludedServers = handleInfo.listExcludedServersForPartition(partitionId); + assertEquals(1, excludedServers.size()); + assertEquals("a", excludedServers.iterator().next()); + assertEquals(replacements, handleInfo.getReplacementsForPartition(1, "a")); + availablePartitionServers = handleInfo.getAvailablePartitionServersForWriter(); + // The current writer is the last one + assertEquals("f", availablePartitionServers.get(1).get(0).getHost()); + assignment = handleInfo.getAllPartitionServersForReader(); + // All the servers that were selected as writers remain available to readers + assertEquals(6, assignment.get(1).size()); + } } diff --git a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java index 317d0cd9ea..8e6a307c57 100644 --- a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java +++ b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java @@ -80,7 +80,8 @@ public MutableShuffleHandleInfo reassignOnBlockSendFailure( int stageId, int stageAttemptNumber, int shuffleId, - Map<Integer, List<ReceivingFailureServer>> partitionToFailureServers) { + Map<Integer, List<ReceivingFailureServer>> 
partitionToFailureServers, + boolean partitionSplit) { return null; } } diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java index b3a5ccf099..decefc6d37 100644 --- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java +++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java @@ -556,6 +556,8 @@ private void collectFailedBlocksToResend() { return; } + reassignOnPartitionNeedSplit(failedTracker); + Set<Long> failedBlockIds = failedTracker.getFailedBlockIds(); if (CollectionUtils.isEmpty(failedBlockIds)) { return; @@ -619,8 +621,26 @@ private void collectFailedBlocksToResend() { reassignAndResendBlocks(resendCandidates); } + private void reassignOnPartitionNeedSplit(FailedBlockSendTracker failedTracker) { + Map<Integer, List<ReceivingFailureServer>> failurePartitionToServers = new HashMap<>(); + + failedTracker + .removeAllTrackedPartitions() + .forEach( + partitionStatus -> + failurePartitionToServers + .computeIfAbsent(partitionStatus.getPartitionId(), x -> new ArrayList<>()) + .add( + new ReceivingFailureServer( + partitionStatus.getShuffleServerInfo().getId(), StatusCode.SUCCESS))); + if (!failurePartitionToServers.isEmpty()) { + doReassignOnBlockSendFailure(failurePartitionToServers, true); + } + } + private void doReassignOnBlockSendFailure( - Map<Integer, List<ReceivingFailureServer>> failurePartitionToServers) { + Map<Integer, List<ReceivingFailureServer>> failurePartitionToServers, + boolean partitionSplit) { LOG.info( "Initiate reassignOnBlockSendFailure. 
failure partition servers: {}", failurePartitionToServers); @@ -636,7 +656,8 @@ private void doReassignOnBlockSendFailure( executorId, taskAttemptId, stageId, - stageAttemptNum); + stageAttemptNum, + partitionSplit); RssReassignOnBlockSendFailureResponse response = managerClientSupplier.get().reassignOnBlockSendFailure(request); if (response.getStatusCode() != StatusCode.SUCCESS) { @@ -681,7 +702,12 @@ private void reassignAndResendBlocks(Set<TrackingBlockStatus> blocks) { serverBlocks.entrySet()) { String serverId = blockStatusEntry.getKey().getId(); // avoid duplicate reassign for the same failure server. - String latestServerId = getPartitionAssignedServers(partitionId).get(0).getId(); + // todo: getting the replacement should support multi replica. + List<ShuffleServerInfo> servers = getPartitionAssignedServers(partitionId); + // Gets the first replica for this partition for now. + // It can not work if we want to use multiple replicas. + ShuffleServerInfo replacement = servers.get(0); + String latestServerId = replacement.getId(); if (!serverId.equals(latestServerId)) { continue; } @@ -693,13 +719,16 @@ private void reassignAndResendBlocks(Set<TrackingBlockStatus> blocks) { } if (!failurePartitionToServers.isEmpty()) { - doReassignOnBlockSendFailure(failurePartitionToServers); + doReassignOnBlockSendFailure(failurePartitionToServers, false); } for (TrackingBlockStatus blockStatus : blocks) { ShuffleBlockInfo block = blockStatus.getShuffleBlockInfo(); // todo: getting the replacement should support multi replica. - ShuffleServerInfo replacement = getPartitionAssignedServers(block.getPartitionId()).get(0); + List<ShuffleServerInfo> servers = getPartitionAssignedServers(block.getPartitionId()); + // Gets the first replica for this partition for now. + // It can not work if we want to use multiple replicas. 
+ ShuffleServerInfo replacement = servers.get(0); if (blockStatus.getShuffleServerInfo().getId().equals(replacement.getId())) { throw new RssException( "No available replacement server for: " + blockStatus.getShuffleServerInfo().getId()); diff --git a/client/src/main/java/org/apache/uniffle/client/impl/FailedBlockSendTracker.java b/client/src/main/java/org/apache/uniffle/client/impl/FailedBlockSendTracker.java index c0ff6d5bdb..12856faf71 100644 --- a/client/src/main/java/org/apache/uniffle/client/impl/FailedBlockSendTracker.java +++ b/client/src/main/java/org/apache/uniffle/client/impl/FailedBlockSendTracker.java @@ -21,6 +21,8 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; import com.google.common.collect.Lists; import com.google.common.collect.Maps; @@ -28,6 +30,7 @@ import org.apache.uniffle.common.ShuffleBlockInfo; import org.apache.uniffle.common.ShuffleServerInfo; +import org.apache.uniffle.common.exception.RssException; import org.apache.uniffle.common.rpc.StatusCode; public class FailedBlockSendTracker { @@ -40,8 +43,11 @@ public class FailedBlockSendTracker { */ private Map<Long, List<TrackingBlockStatus>> trackingBlockStatusMap; + private final BlockingQueue<TrackingPartitionStatus> trackingNeedSplitPartitionStatusQueue; + public FailedBlockSendTracker() { this.trackingBlockStatusMap = Maps.newConcurrentMap(); + this.trackingNeedSplitPartitionStatusQueue = new LinkedBlockingQueue<>(); } public void add( @@ -56,6 +62,8 @@ public void add( public void merge(FailedBlockSendTracker failedBlockSendTracker) { this.trackingBlockStatusMap.putAll(failedBlockSendTracker.trackingBlockStatusMap); + this.trackingNeedSplitPartitionStatusQueue.addAll( + failedBlockSendTracker.trackingNeedSplitPartitionStatusQueue); } public void remove(long blockId) { @@ -72,6 +80,7 @@ public void clearAndReleaseBlockResources() { } }); trackingBlockStatusMap.clear(); + 
trackingNeedSplitPartitionStatusQueue.clear(); } public Set<Long> getFailedBlockIds() { @@ -94,4 +103,20 @@ public Set<ShuffleServerInfo> getFaultyShuffleServers() { }); return shuffleServerInfos; } + + public void addNeedSplitPartition(int partitionId, ShuffleServerInfo shuffleServerInfo) { + try { + trackingNeedSplitPartitionStatusQueue.put( + new TrackingPartitionStatus(partitionId, shuffleServerInfo)); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RssException(e); + } + } + + public List<TrackingPartitionStatus> removeAllTrackedPartitions() { + List<TrackingPartitionStatus> trackingPartitionStatusList = Lists.newArrayList(); + trackingNeedSplitPartitionStatusQueue.drainTo(trackingPartitionStatusList); + return trackingPartitionStatusList; + } } diff --git a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java index ac93d57b1e..b78f06bd3c 100644 --- a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java +++ b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java @@ -218,6 +218,8 @@ private boolean sendShuffleDataAsync( .forEach( blockId -> blockIdsSendSuccessTracker.get(blockId).incrementAndGet()); + recordNeedSplitPartition( + failedBlockSendTracker, ssi, response.getNeedSplitPartitionIds()); if (defectiveServers != null) { defectiveServers.remove(ssi); } @@ -281,6 +283,16 @@ void recordFailedBlocks( .forEach(block -> blockIdsSendFailTracker.add(block, shuffleServerInfo, statusCode)); } + void recordNeedSplitPartition( + FailedBlockSendTracker blockIdsSendFailTracker, + ShuffleServerInfo shuffleServerInfo, + Set<Integer> needSplitPartitions) { + if (needSplitPartitions != null) { + needSplitPartitions.forEach( + partition -> blockIdsSendFailTracker.addNeedSplitPartition(partition, shuffleServerInfo)); + } + } + void genServerToBlocks( ShuffleBlockInfo sbi, 
List<ShuffleServerInfo> serverList, @@ -322,6 +334,7 @@ void genServerToBlocks( } @Override + @VisibleForTesting public SendShuffleDataResult sendShuffleData( String appId, List<ShuffleBlockInfo> shuffleBlockInfoList, diff --git a/client/src/main/java/org/apache/uniffle/client/impl/TrackingPartitionStatus.java b/client/src/main/java/org/apache/uniffle/client/impl/TrackingPartitionStatus.java new file mode 100644 index 0000000000..d734661f5f --- /dev/null +++ b/client/src/main/java/org/apache/uniffle/client/impl/TrackingPartitionStatus.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.uniffle.client.impl; + +import org.apache.uniffle.common.ShuffleServerInfo; + +public class TrackingPartitionStatus { + private int partitionId; + private ShuffleServerInfo shuffleServerInfo; + + public TrackingPartitionStatus(int partitionId, ShuffleServerInfo shuffleServerInfo) { + this.shuffleServerInfo = shuffleServerInfo; + this.partitionId = partitionId; + } + + public ShuffleServerInfo getShuffleServerInfo() { + return shuffleServerInfo; + } + + public int getPartitionId() { + return partitionId; + } +} diff --git a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java index 98180d647c..b0720756ae 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java @@ -31,6 +31,7 @@ import com.google.protobuf.ByteString; import com.google.protobuf.UnsafeByteOperations; import io.netty.buffer.Unpooled; +import org.apache.commons.lang3.tuple.Pair; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -280,17 +281,18 @@ public long requirePreAllocation( int retryMax, long retryIntervalMax) { return requirePreAllocation( - appId, - shuffleId, - partitionIds, - partitionRequireSizes, - requireSize, - retryMax, - retryIntervalMax, - new AtomicReference<>(StatusCode.INTERNAL_ERROR)); + appId, + shuffleId, + partitionIds, + partitionRequireSizes, + requireSize, + retryMax, + retryIntervalMax, + new AtomicReference<>(StatusCode.INTERNAL_ERROR)) + .getLeft(); } - public long requirePreAllocation( + public Pair<Long, List<Integer>> requirePreAllocation( String appId, int shuffleId, List<Integer> partitionIds, @@ -311,6 +313,7 @@ public long requirePreAllocation( long start = System.currentTimeMillis(); int retry = 0; long result = FAILED_REQUIRE_ID; + 
List<Integer> needSplitPartitionIds = Collections.emptyList(); if (LOG.isDebugEnabled()) { LOG.debug( "Requiring buffer for appId: {}, shuffleId: {}, partitionIds: {} with {} bytes from {}:{}", @@ -328,7 +331,7 @@ public long requirePreAllocation( } catch (Exception e) { LOG.error( "Exception happened when requiring pre-allocated buffer from {}:{}", host, port, e); - return result; + return Pair.of(result, needSplitPartitionIds); } if (rpcResponse.getStatus() != NO_BUFFER && rpcResponse.getStatus() != RssProtos.StatusCode.NO_BUFFER_FOR_HUGE_PARTITION) { @@ -348,7 +351,7 @@ public long requirePreAllocation( + retryMax + " times, cost: {}(ms)", System.currentTimeMillis() - start); - return result; + return Pair.of(result, needSplitPartitionIds); } try { LOG.info( @@ -382,6 +385,7 @@ public long requirePreAllocation( System.currentTimeMillis() - start); } result = rpcResponse.getRequireBufferId(); + needSplitPartitionIds = rpcResponse.getNeedSplitPartitionIdsList(); } else if (NOT_RETRY_STATUS_CODES.contains( StatusCode.fromCode(rpcResponse.getStatus().getNumber()))) { failedStatusCodeRef.set(StatusCode.fromCode(rpcResponse.getStatus().getNumber())); @@ -398,7 +402,7 @@ public long requirePreAllocation( + rpcResponse.getRetMsg(); throw new NotRetryException(msg); } - return result; + return Pair.of(result, needSplitPartitionIds); } private RssProtos.ShuffleUnregisterByAppIdResponse doUnregisterShuffleByAppId( @@ -564,16 +568,18 @@ public RssSendShuffleDataResponse sendShuffleData(RssSendShuffleDataRequest requ try { RetryUtils.retryWithCondition( () -> { + // TODO(baoloongmao): support partition split follow netty client long requireId = requirePreAllocation( - appId, - shuffleId, - partitionIds, - partitionRequireSizes, - allocateSize, - request.getRetryMax() / maxRetryAttempts, - request.getRetryIntervalMax(), - failedStatusCode); + appId, + shuffleId, + partitionIds, + partitionRequireSizes, + allocateSize, + request.getRetryMax() / maxRetryAttempts, + 
request.getRetryIntervalMax(), + failedStatusCode) + .getLeft(); if (requireId == FAILED_REQUIRE_ID) { throw new RssException( String.format( diff --git a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java index 723b0ecf34..a89ea9c4c6 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java @@ -20,10 +20,13 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.Sets; +import org.apache.commons.lang3.tuple.Pair; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -127,7 +130,7 @@ public RssSendShuffleDataResponse sendShuffleData(RssSendShuffleDataRequest requ int stageAttemptNumber = request.getStageAttemptNumber(); boolean isSuccessful = true; AtomicReference<StatusCode> failedStatusCode = new AtomicReference<>(StatusCode.INTERNAL_ERROR); - + Set<Integer> needSplitPartitionIds = Sets.newHashSet(); for (Map.Entry<Integer, Map<Integer, List<ShuffleBlockInfo>>> stb : shuffleIdToBlocks.entrySet()) { int shuffleId = stb.getKey(); @@ -161,7 +164,7 @@ public RssSendShuffleDataResponse sendShuffleData(RssSendShuffleDataRequest requ RetryUtils.retryWithCondition( () -> { final TransportClient transportClient = getTransportClient(); - long requireId = + Pair<Long, List<Integer>> result = requirePreAllocation( request.getAppId(), shuffleId, @@ -171,6 +174,8 @@ public RssSendShuffleDataResponse sendShuffleData(RssSendShuffleDataRequest requ request.getRetryMax(), request.getRetryIntervalMax(), 
failedStatusCode); + long requireId = result.getLeft(); + needSplitPartitionIds.addAll(result.getRight()); if (requireId == FAILED_REQUIRE_ID) { throw new RssException( String.format( @@ -232,6 +237,7 @@ public RssSendShuffleDataResponse sendShuffleData(RssSendShuffleDataRequest requ } else { response = new RssSendShuffleDataResponse(failedStatusCode.get()); } + response.setNeedSplitPartitionIds(needSplitPartitionIds); return response; } diff --git a/internal-client/src/main/java/org/apache/uniffle/client/request/RssReassignOnBlockSendFailureRequest.java b/internal-client/src/main/java/org/apache/uniffle/client/request/RssReassignOnBlockSendFailureRequest.java index 303499fb46..cd3a2ab84f 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/request/RssReassignOnBlockSendFailureRequest.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/request/RssReassignOnBlockSendFailureRequest.java @@ -25,6 +25,7 @@ import org.apache.uniffle.proto.RssProtos; public class RssReassignOnBlockSendFailureRequest { + private final boolean partitionSplit; private int shuffleId; private Map<Integer, List<ReceivingFailureServer>> failurePartitionToServers; private String executorId; @@ -38,13 +39,15 @@ public RssReassignOnBlockSendFailureRequest( String executorId, long taskAttemptId, int stageId, - int stageAttemptNum) { + int stageAttemptNum, + boolean partitionSplit) { this.shuffleId = shuffleId; this.failurePartitionToServers = failurePartitionToServers; this.executorId = executorId; this.taskAttemptId = taskAttemptId; this.stageId = stageId; this.stageAttemptNumber = stageAttemptNum; + this.partitionSplit = partitionSplit; } public static RssProtos.RssReassignOnBlockSendFailureRequest toProto( @@ -60,6 +63,7 @@ public static RssProtos.RssReassignOnBlockSendFailureRequest toProto( .setStageId(request.stageId) .setStageAttemptNumber(request.stageAttemptNumber) .setTaskAttemptId(request.taskAttemptId) + .setPartitionSplit(request.partitionSplit) .build(); 
} } diff --git a/internal-client/src/main/java/org/apache/uniffle/client/response/RssSendShuffleDataResponse.java b/internal-client/src/main/java/org/apache/uniffle/client/response/RssSendShuffleDataResponse.java index eb20dd8a0d..917d740735 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/response/RssSendShuffleDataResponse.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/response/RssSendShuffleDataResponse.java @@ -18,6 +18,7 @@ package org.apache.uniffle.client.response; import java.util.List; +import java.util.Set; import org.apache.uniffle.common.rpc.StatusCode; @@ -25,6 +26,7 @@ public class RssSendShuffleDataResponse extends ClientResponse { private List<Long> successBlockIds; private List<Long> failedBlockIds; + private Set<Integer> needSplitPartitionIds; public RssSendShuffleDataResponse(StatusCode statusCode) { super(statusCode); @@ -45,4 +47,12 @@ public List<Long> getFailedBlockIds() { public void setFailedBlockIds(List<Long> failedBlockIds) { this.failedBlockIds = failedBlockIds; } + + public void setNeedSplitPartitionIds(Set<Integer> needSplitPartitionIds) { + this.needSplitPartitionIds = needSplitPartitionIds; + } + + public Set<Integer> getNeedSplitPartitionIds() { + return needSplitPartitionIds; + } } diff --git a/proto/src/main/proto/Rss.proto b/proto/src/main/proto/Rss.proto index 5e8cc632d5..e883eddc43 100644 --- a/proto/src/main/proto/Rss.proto +++ b/proto/src/main/proto/Rss.proto @@ -64,6 +64,8 @@ message RequireBufferResponse { int64 requireBufferId = 1; StatusCode status = 2; string retMsg = 3; + // need split partitions + repeated int32 needSplitPartitionIds = 4; } message ShuffleDataBlockSegment { @@ -669,6 +671,7 @@ message RssReassignOnBlockSendFailureRequest { int32 stageId = 4; int32 stageAttemptNumber = 5; string executorId = 6; + optional bool partitionSplit = 7; } message ReceivingFailureServers { diff --git a/server/src/main/java/org/apache/uniffle/server/HugePartitionUtils.java 
b/server/src/main/java/org/apache/uniffle/server/HugePartitionUtils.java index f2501cd420..59a010cfff 100644 --- a/server/src/main/java/org/apache/uniffle/server/HugePartitionUtils.java +++ b/server/src/main/java/org/apache/uniffle/server/HugePartitionUtils.java @@ -147,4 +147,9 @@ public static boolean limitHugePartition( } return false; } + + public static boolean hasExceedPartitionSplitLimit( + ShuffleBufferManager shuffleBufferManager, long usedPartitionDataSize) { + return usedPartitionDataSize > shuffleBufferManager.getHugePartitionSplitLimit(); + } } diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleServerConf.java b/server/src/main/java/org/apache/uniffle/server/ShuffleServerConf.java index 310afcb231..677127b3b4 100644 --- a/server/src/main/java/org/apache/uniffle/server/ShuffleServerConf.java +++ b/server/src/main/java/org/apache/uniffle/server/ShuffleServerConf.java @@ -516,6 +516,16 @@ public class ShuffleServerConf extends RssBaseConf { + "will be terminated. This helps to significantly improve the " + "stability of the cluster by preventing partitions from becoming too large."); + public static final ConfigOption<Long> HUGE_PARTITION_SPLIT_LIMIT = + ConfigOptions.key("rss.server.huge-partition.split.limit") + .longType() + .defaultValue(Long.MAX_VALUE) + .withDescription( + "This option sets the maximum partition slice size threshold. 
" + + "If the partition size exceeds this threshold, the rss client will " + + "receive the need split partition list and resend the failed blocks to " + + "new servers through reassign mechanism."); + public static final ConfigOption<Long> SERVER_DECOMMISSION_CHECK_INTERVAL = ConfigOptions.key("rss.server.decommission.check.interval") .longType() diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java index 85fe7fa6e7..303a4a7d20 100644 --- a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java +++ b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.nio.ByteBuffer; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Optional; @@ -34,6 +35,7 @@ import io.grpc.stub.StreamObserver; import io.netty.buffer.ByteBuf; import org.apache.commons.lang3.StringUtils; +import org.apache.commons.lang3.tuple.Pair; import org.roaringbitmap.longlong.Roaring64NavigableMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -727,21 +729,24 @@ public void requireBuffer( long requireBufferId = -1; String responseMessage = ""; String shuffleDataInfo = "appId[" + appId + "], shuffleId[" + request.getShuffleId() + "]"; + List<Integer> needSplitPartitionIds = Collections.emptyList(); try { if (StringUtils.isEmpty(appId)) { // To be compatible with older client version requireBufferId = shuffleServer.getShuffleTaskManager().requireBuffer(request.getRequireSize()); } else { - requireBufferId = + Pair<Long, List<Integer>> pair = shuffleServer .getShuffleTaskManager() - .requireBuffer( + .requireBufferReturnPair( appId, request.getShuffleId(), request.getPartitionIdsList(), request.getPartitionRequireSizesList(), request.getRequireSize()); + requireBufferId = pair.getLeft(); + needSplitPartitionIds = pair.getRight(); } } catch 
(NoBufferException e) { responseMessage = e.getMessage(); @@ -775,6 +780,7 @@ public void requireBuffer( .setStatus(status.toProto()) .setRequireBufferId(requireBufferId) .setRetMsg(responseMessage) + .addAllNeedSplitPartitionIds(needSplitPartitionIds) .build(); responseObserver.onNext(response); responseObserver.onCompleted(); diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java index af37646a78..18dee6b0d3 100644 --- a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java +++ b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java @@ -48,6 +48,7 @@ import com.google.common.collect.Range; import com.google.common.collect.Sets; import org.apache.commons.collections4.CollectionUtils; +import org.apache.commons.lang3.tuple.Pair; import org.roaringbitmap.longlong.Roaring64NavigableMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -581,7 +582,17 @@ public long getPartitionDataSize(String appId, int shuffleId, int partitionId) { return shuffleTaskInfo.getPartitionDataSize(shuffleId, partitionId); } - public long requireBuffer( + /** + * Require buffer for shuffle data + * + * @param appId the appId + * @param shuffleId the shuffleId + * @param partitionIds the partitionIds + * @param partitionRequireSizes the partitionRequireSizes + * @param requireSize the requireSize + * @return returns (requireId, splitPartitionIds) + */ + public Pair<Long, List<Integer>> requireBufferReturnPair( String appId, int shuffleId, List<Integer> partitionIds, @@ -592,6 +603,7 @@ public long requireBuffer( LOG.error("No such app is registered. appId: {}, shuffleId: {}", appId, shuffleId); throw new NoRegisterException("No such app is registered. 
appId: " + appId); } + List<Integer> splitPartitionIds = new ArrayList<>(); // To be compatible with legacy clients which have empty partitionRequireSizes if (partitionIds.size() == partitionRequireSizes.size()) { for (int i = 0; i < partitionIds.size(); i++) { @@ -610,9 +622,33 @@ public long requireBuffer( } HugePartitionUtils.checkExceedPartitionHardLimit( "requireBuffer", shuffleBufferManager, partitionUsedDataSize, partitionRequireSize); + if (HugePartitionUtils.hasExceedPartitionSplitLimit( + shuffleBufferManager, partitionUsedDataSize)) { + LOG.info( + "Need split partition. appId: {}, shuffleId: {}, partitionId: {}, partitionUsedDataSize: {}", + appId, + shuffleId, + partitionId, + partitionUsedDataSize); + splitPartitionIds.add(partitionId); + // We do not mind to reduce the partitionRequireSize from the requireSize for soft + // partition split + } } } - return requireBuffer(appId, requireSize); + return Pair.of(requireBuffer(appId, requireSize), splitPartitionIds); + } + + @VisibleForTesting + public long requireBuffer( + String appId, + int shuffleId, + List<Integer> partitionIds, + List<Integer> partitionRequireSizes, + int requireSize) { + return requireBufferReturnPair( + appId, shuffleId, partitionIds, partitionRequireSizes, requireSize) + .getLeft(); } public long requireBuffer(String appId, int requireSize) { diff --git a/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java b/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java index 7d7b98ac93..85cdf98b1b 100644 --- a/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java +++ b/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java @@ -81,6 +81,7 @@ public class ShuffleBufferManager { // Huge partition vars private ReconfigurableConfManager.Reconfigurable<Long> hugePartitionSizeThresholdRef; private ReconfigurableConfManager.Reconfigurable<Long> hugePartitionSizeHardLimitRef; + private 
ReconfigurableConfManager.Reconfigurable<Long> hugePartitionSplitLimitRef; private long hugePartitionMemoryLimitSize; protected AtomicLong preAllocatedSize = new AtomicLong(0L); protected AtomicLong inFlushSize = new AtomicLong(0L); @@ -141,6 +142,8 @@ public ShuffleBufferManager( conf.getReconfigurableConf(ShuffleServerConf.HUGE_PARTITION_SIZE_THRESHOLD); this.hugePartitionSizeHardLimitRef = conf.getReconfigurableConf(ShuffleServerConf.HUGE_PARTITION_SIZE_HARD_LIMIT); + this.hugePartitionSplitLimitRef = + conf.getReconfigurableConf(ShuffleServerConf.HUGE_PARTITION_SPLIT_LIMIT); this.hugePartitionMemoryLimitSize = Math.round( capacity * conf.get(ShuffleServerConf.HUGE_PARTITION_MEMORY_USAGE_LIMITATION_RATIO)); @@ -839,4 +842,8 @@ public void setBufferFlushThreshold(long bufferFlushThreshold) { public ShuffleBufferType getShuffleBufferType() { return shuffleBufferType; } + + public long getHugePartitionSplitLimit() { + return hugePartitionSplitLimitRef.getSizeAsBytes(); + } } diff --git a/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferManagerTest.java b/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferManagerTest.java index 13dbe94e12..9d9bdd194b 100644 --- a/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferManagerTest.java +++ b/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferManagerTest.java @@ -30,6 +30,7 @@ import com.google.common.collect.RangeMap; import com.google.common.util.concurrent.Uninterruptibles; import io.prometheus.client.Collector; +import org.apache.commons.lang3.tuple.Pair; import org.awaitility.Awaitility; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -795,4 +796,62 @@ public void blockSizeMetricsTest() { } }); } + + @Test + public void splitPartitionTest(@TempDir File tmpDir) throws Exception { + ShuffleServerConf shuffleConf = new ShuffleServerConf(); + File dataDir = new File(tmpDir, "data"); + 
shuffleConf.setString(ShuffleServerConf.RSS_STORAGE_TYPE.key(), StorageType.LOCALFILE.name()); + shuffleConf.set( + ShuffleServerConf.RSS_STORAGE_BASE_PATH, Arrays.asList(dataDir.getAbsolutePath())); + shuffleConf.set(ShuffleServerConf.HUGE_PARTITION_SPLIT_LIMIT, 200L); + + ShuffleServer mockShuffleServer = mock(ShuffleServer.class); + StorageManager storageManager = + StorageManagerFactory.getInstance().createStorageManager(shuffleConf); + ShuffleFlushManager shuffleFlushManager = + new ShuffleFlushManager(shuffleConf, mockShuffleServer, storageManager); + shuffleBufferManager = new ShuffleBufferManager(shuffleConf, shuffleFlushManager, false); + ShuffleTaskManager shuffleTaskManager = + new ShuffleTaskManager( + shuffleConf, shuffleFlushManager, shuffleBufferManager, storageManager); + + when(mockShuffleServer.getShuffleFlushManager()).thenReturn(shuffleFlushManager); + when(mockShuffleServer.getShuffleBufferManager()).thenReturn(shuffleBufferManager); + when(mockShuffleServer.getShuffleTaskManager()).thenReturn(shuffleTaskManager); + + String appId = "flushSingleBufferForHugePartitionTest_appId"; + int shuffleId = 1; + + shuffleTaskManager.registerShuffle( + appId, shuffleId, Arrays.asList(new PartitionRange(0, 0)), new RemoteStorageInfo(""), ""); + + // case1: its partition size does not exceed the split limit + shuffleBufferManager.registerBuffer(appId, shuffleId, 0, 0); + ShufflePartitionedData partitionedData = createData(0, 1); + shuffleTaskManager.cacheShuffleData(appId, shuffleId, false, partitionedData); + shuffleTaskManager.updateCachedBlockIds(appId, shuffleId, 0, partitionedData.getBlockList()); + long usedSize = shuffleTaskManager.getPartitionDataSize(appId, shuffleId, 0); + assertEquals(1 + 32, usedSize); + assertFalse( + HugePartitionUtils.hasExceedPartitionSplitLimit( + shuffleBufferManager, shuffleTaskManager.getPartitionDataSize(appId, shuffleId, 0))); + + // case2: its partition exceed the split limit + partitionedData = createData(0, 200); + 
shuffleTaskManager.cacheShuffleData(appId, shuffleId, false, partitionedData); + shuffleTaskManager.updateCachedBlockIds(appId, shuffleId, 0, partitionedData.getBlockList()); + usedSize = shuffleTaskManager.getPartitionDataSize(appId, shuffleId, 0); + assertEquals(1 + 32 + 200 + 32, usedSize); + assertTrue( + HugePartitionUtils.hasExceedPartitionSplitLimit( + shuffleBufferManager, shuffleTaskManager.getPartitionDataSize(appId, shuffleId, 0))); + + // check returned need split partitions + Pair<Long, List<Integer>> pair = + shuffleTaskManager.requireBufferReturnPair( + appId, shuffleId, Arrays.asList(0, 1), Arrays.asList(10, 10), 20); + assertEquals(1, pair.getRight().size()); + assertEquals(0, pair.getRight().get(0)); + } }