diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfo.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfo.java
index 6fc2159431..e3da2b063e 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfo.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfo.java
@@ -53,14 +53,18 @@ public class MutableShuffleHandleInfo extends ShuffleHandleInfoBase {
   private Map<Integer, Map<Integer, List<ShuffleServerInfo>>> partitionReplicaAssignedServers;
 
   private Map<String, Set<ShuffleServerInfo>> excludedServerToReplacements;
+  /**
+   * partitionId -> excluded server id -> replacement servers. Tracks the replacements assigned
+   * to a server that is excluded for one specific partition only (the partition split case).
+   */
+  private Map<Integer, Map<String, Set<ShuffleServerInfo>>>
+      excludedServerForPartitionToReplacements;
 
   public MutableShuffleHandleInfo(
       int shuffleId,
       Map<Integer, List<ShuffleServerInfo>> partitionToServers,
       RemoteStorageInfo storageInfo) {
-    super(shuffleId, storageInfo);
-    this.excludedServerToReplacements = new HashMap<>();
-    this.partitionReplicaAssignedServers = toPartitionReplicaMapping(partitionToServers);
+    this(shuffleId, storageInfo, toPartitionReplicaMapping(partitionToServers));
   }
 
   @VisibleForTesting
@@ -70,6 +74,7 @@ protected MutableShuffleHandleInfo(
       Map<Integer, Map<Integer, List<ShuffleServerInfo>>> partitionReplicaAssignedServers) {
     super(shuffleId, storageInfo);
     this.excludedServerToReplacements = new HashMap<>();
+    this.excludedServerForPartitionToReplacements = new HashMap<>();
     this.partitionReplicaAssignedServers = partitionReplicaAssignedServers;
   }
 
@@ -77,7 +82,7 @@ public MutableShuffleHandleInfo(int shuffleId, RemoteStorageInfo storageInfo) {
     super(shuffleId, storageInfo);
   }
 
-  private Map<Integer, Map<Integer, List<ShuffleServerInfo>>> toPartitionReplicaMapping(
+  private static Map<Integer, Map<Integer, List<ShuffleServerInfo>>> toPartitionReplicaMapping(
       Map<Integer, List<ShuffleServerInfo>> partitionToServers) {
     Map<Integer, Map<Integer, List<ShuffleServerInfo>>> partitionReplicaAssignedServers =
         new HashMap<>();
@@ -102,6 +107,21 @@ public Set<ShuffleServerInfo> getReplacements(String faultyServerId) {
     return excludedServerToReplacements.get(faultyServerId);
   }
 
+  public Set<ShuffleServerInfo> getReplacementsForPartition(
+      int partitionId, String excludedServerId) {
+    Map<String, Set<ShuffleServerInfo>> serverToReplacements =
+        excludedServerForPartitionToReplacements.getOrDefault(partitionId, Collections.emptyMap());
+    return serverToReplacements.getOrDefault(excludedServerId, Collections.emptySet());
+  }
+
+  /**
+   * Update the assignment for the receiving failure server of the given partition.
+   *
+   * @param partitionId the partition id
+   * @param receivingFailureServerId the id of the receiving failure server
+   * @param replacements the new assigned servers for replacing the receiving failure server
+   * @return the updated server list for receiving data
+   */
   public Set<ShuffleServerInfo> updateAssignment(
       int partitionId, String receivingFailureServerId, Set<ShuffleServerInfo> replacements) {
     if (replacements == null || StringUtils.isEmpty(receivingFailureServerId)) {
@@ -109,6 +129,11 @@ public Set<ShuffleServerInfo> updateAssignment(
     }
     excludedServerToReplacements.put(receivingFailureServerId, replacements);
 
+    return updateAssignmentInternal(partitionId, receivingFailureServerId, replacements);
+  }
+
+  private Set<ShuffleServerInfo> updateAssignmentInternal(
+      int partitionId, String receivingFailureServerId, Set<ShuffleServerInfo> replacements) {
     Set<ShuffleServerInfo> updatedServers = new HashSet<>();
     Map<Integer, List<ShuffleServerInfo>> replicaServers =
         partitionReplicaAssignedServers.get(partitionId);
@@ -131,6 +156,26 @@ public Set<ShuffleServerInfo> updateAssignment(
     return updatedServers;
   }
 
+  /**
+   * Update the assignment of a need-split partition and record the replacements per partition.
+   *
+   * @param partitionId the id of the partition that needs to be split
+   * @param receivingFailureServerId the id of the server to be replaced for this partition
+   * @param replacements the newly assigned servers that replace the given server
+   * @return the updated servers for receiving data; empty if an argument is null or empty
+   */
+  public Set<ShuffleServerInfo> updateAssignmentOnPartitionSplit(
+      int partitionId, String receivingFailureServerId, Set<ShuffleServerInfo> replacements) {
+    if (replacements == null || StringUtils.isEmpty(receivingFailureServerId)) {
+      return Collections.emptySet();
+    }
+    excludedServerForPartitionToReplacements
+        .computeIfAbsent(partitionId, x -> new HashMap<>())
+        .put(receivingFailureServerId, replacements);
+
+    return updateAssignmentInternal(partitionId, receivingFailureServerId, replacements);
+  }
+
   @Override
   public Set<ShuffleServerInfo> getServers() {
     return partitionReplicaAssignedServers.values().stream()
@@ -149,6 +194,7 @@ public Map<Integer, List<ShuffleServerInfo>> getAvailablePartitionServersForWrit
           replicaServers.entrySet()) {
         ShuffleServerInfo candidate;
         int candidateSize = replicaServerEntry.getValue().size();
+        // Write to the last server in the list of each replica
         candidate = replicaServerEntry.getValue().get(candidateSize - 1);
         assignment.computeIfAbsent(partitionId, x -> new ArrayList<>()).add(candidate);
       }
@@ -266,4 +312,10 @@ public static MutableShuffleHandleInfo fromProto(RssProtos.MutableShuffleHandleI
     handle.partitionReplicaAssignedServers = partitionToServers;
     return handle;
   }
+
+  public Set<String> listExcludedServersForPartition(int partitionId) {
+    Map<String, Set<ShuffleServerInfo>> serverToReplacements =
+        excludedServerForPartitionToReplacements.getOrDefault(partitionId, Collections.emptyMap());
+    return serverToReplacements.keySet();
+  }
 }
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/ShuffleHandleInfo.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/ShuffleHandleInfo.java
index 99f7a7421b..c0d4b89f57 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/ShuffleHandleInfo.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/ShuffleHandleInfo.java
@@ -31,7 +31,8 @@ public interface ShuffleHandleInfo {
 
   /**
    * Get the assignment of available servers for writer to write partitioned blocks to corresponding
-   * shuffleServers. Implementations might return dynamic, up-to-date information here.
+   * shuffleServers. Implementations might return dynamic, up-to-date information here. Returns
+   * partitionId -> [replica1, replica2, ...]
    */
   Map<Integer, List<ShuffleServerInfo>> getAvailablePartitionServersForWriter();
 
diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
index 11d81ed31d..8e921c66e7 100644
--- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
+++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
@@ -737,7 +737,8 @@ public MutableShuffleHandleInfo reassignOnBlockSendFailure(
       int stageId,
       int stageAttemptNumber,
       int shuffleId,
-      Map<Integer, List<ReceivingFailureServer>> partitionToFailureServers) {
+      Map<Integer, List<ReceivingFailureServer>> partitionToFailureServers,
+      boolean partitionSplit) {
     long startTime = System.currentTimeMillis();
     ShuffleHandleInfo handleInfo = shuffleHandleInfoManager.get(shuffleId);
     MutableShuffleHandleInfo internalHandle = null;
@@ -754,8 +755,11 @@ public MutableShuffleHandleInfo reassignOnBlockSendFailure(
     synchronized (internalHandle) {
       // If the reassignment servers for one partition exceeds the max reassign server num,
       // it should fast fail.
-      internalHandle.checkPartitionReassignServerNum(
-          partitionToFailureServers.keySet(), partitionReassignMaxServerNum);
+      if (!partitionSplit) {
+        // Do not check the partition reassign server num for partition split case
+        internalHandle.checkPartitionReassignServerNum(
+            partitionToFailureServers.keySet(), partitionReassignMaxServerNum);
+      }
 
       Map<ShuffleServerInfo, List<PartitionRange>> newServerToPartitions = new HashMap<>();
       // receivingFailureServer -> partitionId -> replacementServerIds. For logging
@@ -769,27 +773,44 @@ public MutableShuffleHandleInfo reassignOnBlockSendFailure(
           String serverId = receivingFailureServer.getServerId();
 
           boolean serverHasReplaced = false;
-          Set<ShuffleServerInfo> replacements = internalHandle.getReplacements(serverId);
-          if (CollectionUtils.isEmpty(replacements)) {
-            final int requiredServerNum = 1;
-            Set<String> excludedServers = new HashSet<>(internalHandle.listExcludedServers());
-            excludedServers.add(serverId);
-            replacements =
-                reassignServerForTask(
-                    stageId,
-                    stageAttemptNumber,
-                    shuffleId,
-                    Sets.newHashSet(partitionId),
-                    excludedServers,
-                    requiredServerNum,
-                    true);
+
+          Set<ShuffleServerInfo> updatedReassignServers;
+          if (!partitionSplit) {
+            Set<ShuffleServerInfo> replacements = internalHandle.getReplacements(serverId);
+            if (CollectionUtils.isEmpty(replacements)) {
+              replacements =
+                  requestReassignServer(
+                      stageId,
+                      stageAttemptNumber,
+                      shuffleId,
+                      internalHandle,
+                      partitionId,
+                      serverId);
+            } else {
+              serverHasReplaced = true;
+            }
+            updatedReassignServers =
+                internalHandle.updateAssignment(partitionId, serverId, replacements);
           } else {
-            serverHasReplaced = true;
+            Set<ShuffleServerInfo> replacements =
+                internalHandle.getReplacementsForPartition(partitionId, serverId);
+            if (CollectionUtils.isEmpty(replacements)) {
+              replacements =
+                  requestReassignServer(
+                      stageId,
+                      stageAttemptNumber,
+                      shuffleId,
+                      internalHandle,
+                      partitionId,
+                      serverId);
+            } else {
+              serverHasReplaced = true;
+            }
+            updatedReassignServers =
+                internalHandle.updateAssignmentOnPartitionSplit(
+                    partitionId, serverId, replacements);
           }
 
-          Set<ShuffleServerInfo> updatedReassignServers =
-              internalHandle.updateAssignment(partitionId, serverId, replacements);
-
           if (!updatedReassignServers.isEmpty()) {
             reassignResult
                 .computeIfAbsent(serverId, x -> new HashMap<>())
@@ -825,6 +846,31 @@ public MutableShuffleHandleInfo reassignOnBlockSendFailure(
     }
   }
 
+  /** Requests one replacement for the given server, skipping all already excluded servers. */
+  private Set<ShuffleServerInfo> requestReassignServer(
+      int stageId,
+      int stageAttemptNumber,
+      int shuffleId,
+      MutableShuffleHandleInfo internalHandle,
+      int partitionId,
+      String serverId) {
+    // Exactly one replacement server is requested per call.
+    final int requiredServerNum = 1;
+    // Exclude the servers that have already been replaced, both globally and for this
+    // partition (partition split case), plus the server that is being replaced right now.
+    Set<String> excludedServers = new HashSet<>(internalHandle.listExcludedServers());
+    excludedServers.addAll(internalHandle.listExcludedServersForPartition(partitionId));
+    excludedServers.add(serverId);
+    return reassignServerForTask(
+        stageId,
+        stageAttemptNumber,
+        shuffleId,
+        Sets.newHashSet(partitionId),
+        excludedServers,
+        requiredServerNum,
+        true);
+  }
+
   @Override
   public void stop() {
     if (managerClientSupplier != null
diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerInterface.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerInterface.java
index 77379efb5f..5abb6b832c 100644
--- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerInterface.java
+++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerInterface.java
@@ -86,5 +86,6 @@ MutableShuffleHandleInfo reassignOnBlockSendFailure(
       int stageId,
       int stageAttemptNumber,
       int shuffleId,
-      Map<Integer, List<ReceivingFailureServer>> partitionToFailureServers);
+      Map<Integer, List<ReceivingFailureServer>> partitionToFailureServers,
+      boolean partitionSplit);
 }
diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java
index bc2812ae9d..cbc66105bb 100644
--- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java
+++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java
@@ -267,12 +267,13 @@ public void reassignOnBlockSendFailure(
     RssProtos.ReassignOnBlockSendFailureResponse reply;
     try {
       LOG.info(
-          "Accepted reassign request on block sent failure for shuffleId: {}, stageId: {}, stageAttemptNumber: {} from taskAttemptId: {} on executorId: {}",
+          "Accepted reassign request on block sent failure for shuffleId: {}, stageId: {}, stageAttemptNumber: {} from taskAttemptId: {} on executorId: {} while partition split:{}",
           request.getShuffleId(),
           request.getStageId(),
           request.getStageAttemptNumber(),
           request.getTaskAttemptId(),
-          request.getExecutorId());
+          request.getExecutorId(),
+          request.getPartitionSplit());
       MutableShuffleHandleInfo handle =
           shuffleManager.reassignOnBlockSendFailure(
               request.getStageId(),
@@ -281,7 +282,8 @@ public void reassignOnBlockSendFailure(
               request.getFailurePartitionToServerIdsMap().entrySet().stream()
                   .collect(
                       Collectors.toMap(
-                          Map.Entry::getKey, x -> ReceivingFailureServer.fromProto(x.getValue()))));
+                          Map.Entry::getKey, x -> ReceivingFailureServer.fromProto(x.getValue()))),
+              request.getPartitionSplit());
       code = RssProtos.StatusCode.SUCCESS;
       reply =
           RssProtos.ReassignOnBlockSendFailureResponse.newBuilder()
diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfoTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfoTest.java
index e861923584..3a24f301bf 100644
--- a/client-spark/common/src/test/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfoTest.java
+++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfoTest.java
@@ -20,6 +20,7 @@
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
@@ -142,4 +143,65 @@ public void testCreatePartitionReplicaTracking() {
     assertEquals(b, inventory.get(1).get(1).get(0));
     assertEquals(c, inventory.get(2).get(0).get(0));
   }
+
+  @Test
+  public void testUpdateAssignmentOnPartitionSplit() {
+    Map<Integer, List<ShuffleServerInfo>> partitionToServers = new HashMap<>();
+    partitionToServers.put(1, Arrays.asList(createFakeServerInfo("a"), createFakeServerInfo("b")));
+    partitionToServers.put(2, Arrays.asList(createFakeServerInfo("c")));
+
+    MutableShuffleHandleInfo handleInfo =
+        new MutableShuffleHandleInfo(1, partitionToServers, new RemoteStorageInfo(""));
+
+    // case1: the replacements contain a server that is already assigned to the partition
+    Set<ShuffleServerInfo> updated =
+        handleInfo.updateAssignment(
+            1, "a", Sets.newHashSet(createFakeServerInfo("a"), createFakeServerInfo("d")));
+    assertEquals("d", updated.stream().findFirst().get().getId());
+
+    // case2: update on partition split when the replica already has multiple servers
+    Map<Integer, Map<Integer, List<ShuffleServerInfo>>> partitionReplicaAssignedServers =
+        new HashMap<>();
+    List<ShuffleServerInfo> servers =
+        new ArrayList<>(
+            Arrays.asList(
+                createFakeServerInfo("a"),
+                createFakeServerInfo("b"),
+                createFakeServerInfo("c"),
+                createFakeServerInfo("d")));
+    partitionReplicaAssignedServers
+        .computeIfAbsent(1, x -> new HashMap<>())
+        .computeIfAbsent(0, x -> servers);
+    handleInfo =
+        new MutableShuffleHandleInfo(1, new RemoteStorageInfo(""), partitionReplicaAssignedServers);
+
+    Map<Integer, List<ShuffleServerInfo>> availablePartitionServers =
+        handleInfo.getAvailablePartitionServersForWriter();
+    assertEquals("d", availablePartitionServers.get(1).get(0).getHost());
+    Map<Integer, List<ShuffleServerInfo>> assignment = handleInfo.getAllPartitionServersForReader();
+    assertEquals(4, assignment.get(1).size());
+
+    int partitionId = 1;
+
+    assertTrue(handleInfo.getReplacementsForPartition(partitionId, "a").isEmpty());
+    HashSet<ShuffleServerInfo> replacements =
+        Sets.newHashSet(
+            createFakeServerInfo("b"),
+            createFakeServerInfo("d"),
+            createFakeServerInfo("e"),
+            createFakeServerInfo("f"));
+    updated = handleInfo.updateAssignmentOnPartitionSplit(partitionId, "a", replacements);
+    assertEquals(Sets.newHashSet(createFakeServerInfo("e"), createFakeServerInfo("f")), updated);
+
+    Set<String> excludedServers = handleInfo.listExcludedServersForPartition(partitionId);
+    assertEquals(1, excludedServers.size());
+    assertEquals("a", excludedServers.iterator().next());
+    assertEquals(replacements, handleInfo.getReplacementsForPartition(1, "a"));
+    availablePartitionServers = handleInfo.getAvailablePartitionServersForWriter();
+    // The current writer is the last server in the list
+    assertEquals("f", availablePartitionServers.get(1).get(0).getHost());
+    assignment = handleInfo.getAllPartitionServersForReader();
+    // All the servers that were ever selected as writers remain available to readers
+    assertEquals(6, assignment.get(1).size());
+  }
 }
diff --git a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java
index 317d0cd9ea..8e6a307c57 100644
--- a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java
+++ b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java
@@ -80,7 +80,8 @@ public MutableShuffleHandleInfo reassignOnBlockSendFailure(
       int stageId,
       int stageAttemptNumber,
       int shuffleId,
-      Map<Integer, List<ReceivingFailureServer>> partitionToFailureServers) {
+      Map<Integer, List<ReceivingFailureServer>> partitionToFailureServers,
+      boolean partitionSplit) {
     return null;
   }
 }
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index b3a5ccf099..decefc6d37 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -556,6 +556,8 @@ private void collectFailedBlocksToResend() {
       return;
     }
 
+    reassignOnPartitionNeedSplit(failedTracker);
+
     Set<Long> failedBlockIds = failedTracker.getFailedBlockIds();
     if (CollectionUtils.isEmpty(failedBlockIds)) {
       return;
@@ -619,8 +621,26 @@ private void collectFailedBlocksToResend() {
     reassignAndResendBlocks(resendCandidates);
   }
 
+  private void reassignOnPartitionNeedSplit(FailedBlockSendTracker failedTracker) {
+    Map<Integer, List<ReceivingFailureServer>> failurePartitionToServers = new HashMap<>();
+    // Drain the tracked need-split partitions and group their servers by partition id.
+    failedTracker
+        .removeAllTrackedPartitions()
+        .forEach(
+            partitionStatus ->
+                failurePartitionToServers
+                    .computeIfAbsent(partitionStatus.getPartitionId(), x -> new ArrayList<>())
+                    .add(
+                        new ReceivingFailureServer(
+                            partitionStatus.getShuffleServerInfo().getId(), StatusCode.SUCCESS)));
+    if (!failurePartitionToServers.isEmpty()) {
+      doReassignOnBlockSendFailure(failurePartitionToServers, true);
+    }
+  }
+
   private void doReassignOnBlockSendFailure(
-      Map<Integer, List<ReceivingFailureServer>> failurePartitionToServers) {
+      Map<Integer, List<ReceivingFailureServer>> failurePartitionToServers,
+      boolean partitionSplit) {
     LOG.info(
         "Initiate reassignOnBlockSendFailure. failure partition servers: {}",
         failurePartitionToServers);
@@ -636,7 +656,8 @@ private void doReassignOnBlockSendFailure(
               executorId,
               taskAttemptId,
               stageId,
-              stageAttemptNum);
+              stageAttemptNum,
+              partitionSplit);
       RssReassignOnBlockSendFailureResponse response =
           managerClientSupplier.get().reassignOnBlockSendFailure(request);
       if (response.getStatusCode() != StatusCode.SUCCESS) {
@@ -681,7 +702,12 @@ private void reassignAndResendBlocks(Set<TrackingBlockStatus> blocks) {
           serverBlocks.entrySet()) {
         String serverId = blockStatusEntry.getKey().getId();
         // avoid duplicate reassign for the same failure server.
-        String latestServerId = getPartitionAssignedServers(partitionId).get(0).getId();
+        // todo: getting the replacement should support multi replica.
+        List<ShuffleServerInfo> servers = getPartitionAssignedServers(partitionId);
+        // Gets the first replica for this partition for now.
+        // It can not work if we want to use multiple replicas.
+        ShuffleServerInfo replacement = servers.get(0);
+        String latestServerId = replacement.getId();
         if (!serverId.equals(latestServerId)) {
           continue;
         }
@@ -693,13 +719,16 @@ private void reassignAndResendBlocks(Set<TrackingBlockStatus> blocks) {
     }
 
     if (!failurePartitionToServers.isEmpty()) {
-      doReassignOnBlockSendFailure(failurePartitionToServers);
+      doReassignOnBlockSendFailure(failurePartitionToServers, false);
     }
 
     for (TrackingBlockStatus blockStatus : blocks) {
       ShuffleBlockInfo block = blockStatus.getShuffleBlockInfo();
       // todo: getting the replacement should support multi replica.
-      ShuffleServerInfo replacement = getPartitionAssignedServers(block.getPartitionId()).get(0);
+      List<ShuffleServerInfo> servers = getPartitionAssignedServers(block.getPartitionId());
+      // Gets the first replica for this partition for now.
+      // It can not work if we want to use multiple replicas.
+      ShuffleServerInfo replacement = servers.get(0);
       if (blockStatus.getShuffleServerInfo().getId().equals(replacement.getId())) {
         throw new RssException(
             "No available replacement server for: " + blockStatus.getShuffleServerInfo().getId());
diff --git a/client/src/main/java/org/apache/uniffle/client/impl/FailedBlockSendTracker.java b/client/src/main/java/org/apache/uniffle/client/impl/FailedBlockSendTracker.java
index c0ff6d5bdb..12856faf71 100644
--- a/client/src/main/java/org/apache/uniffle/client/impl/FailedBlockSendTracker.java
+++ b/client/src/main/java/org/apache/uniffle/client/impl/FailedBlockSendTracker.java
@@ -21,6 +21,8 @@
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.LinkedBlockingQueue;
 
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
@@ -28,6 +30,7 @@
 
 import org.apache.uniffle.common.ShuffleBlockInfo;
 import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.rpc.StatusCode;
 
 public class FailedBlockSendTracker {
@@ -40,8 +43,11 @@ public class FailedBlockSendTracker {
    */
   private Map<Long, List<TrackingBlockStatus>> trackingBlockStatusMap;
 
+  private final BlockingQueue<TrackingPartitionStatus> trackingNeedSplitPartitionStatusQueue;
+
   public FailedBlockSendTracker() {
     this.trackingBlockStatusMap = Maps.newConcurrentMap();
+    this.trackingNeedSplitPartitionStatusQueue = new LinkedBlockingQueue<>();
   }
 
   public void add(
@@ -56,6 +62,8 @@ public void add(
 
   public void merge(FailedBlockSendTracker failedBlockSendTracker) {
     this.trackingBlockStatusMap.putAll(failedBlockSendTracker.trackingBlockStatusMap);
+    this.trackingNeedSplitPartitionStatusQueue.addAll(
+        failedBlockSendTracker.trackingNeedSplitPartitionStatusQueue);
   }
 
   public void remove(long blockId) {
@@ -72,6 +80,7 @@ public void clearAndReleaseBlockResources() {
               }
             });
     trackingBlockStatusMap.clear();
+    trackingNeedSplitPartitionStatusQueue.clear();
   }
 
   public Set<Long> getFailedBlockIds() {
@@ -94,4 +103,20 @@ public Set<ShuffleServerInfo> getFaultyShuffleServers() {
             });
     return shuffleServerInfos;
   }
+
+  public void addNeedSplitPartition(int partitionId, ShuffleServerInfo shuffleServerInfo) {
+    TrackingPartitionStatus status = new TrackingPartitionStatus(partitionId, shuffleServerInfo);
+    try {
+      trackingNeedSplitPartitionStatusQueue.put(status);
+    } catch (InterruptedException interrupted) {
+      Thread.currentThread().interrupt();
+      throw new RssException(interrupted);
+    }
+  }
+
+  public List<TrackingPartitionStatus> removeAllTrackedPartitions() {
+    List<TrackingPartitionStatus> drained = Lists.newArrayList();
+    trackingNeedSplitPartitionStatusQueue.drainTo(drained);
+    return drained;
+  }
 }
diff --git a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
index ac93d57b1e..b78f06bd3c 100644
--- a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
+++ b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
@@ -218,6 +218,8 @@ private boolean sendShuffleDataAsync(
                             .forEach(
                                 blockId ->
                                     blockIdsSendSuccessTracker.get(blockId).incrementAndGet());
+                        recordNeedSplitPartition(
+                            failedBlockSendTracker, ssi, response.getNeedSplitPartitionIds());
                         if (defectiveServers != null) {
                           defectiveServers.remove(ssi);
                         }
@@ -281,6 +283,16 @@ void recordFailedBlocks(
         .forEach(block -> blockIdsSendFailTracker.add(block, shuffleServerInfo, statusCode));
   }
 
+  void recordNeedSplitPartition(
+      FailedBlockSendTracker blockIdsSendFailTracker,
+      ShuffleServerInfo shuffleServerInfo,
+      Set<Integer> needSplitPartitions) {
+    if (needSplitPartitions != null) {
+      needSplitPartitions.forEach(
+          partition -> blockIdsSendFailTracker.addNeedSplitPartition(partition, shuffleServerInfo));
+    }
+  }
+
   void genServerToBlocks(
       ShuffleBlockInfo sbi,
       List<ShuffleServerInfo> serverList,
@@ -322,6 +334,7 @@ void genServerToBlocks(
   }
 
   @Override
+  @VisibleForTesting
   public SendShuffleDataResult sendShuffleData(
       String appId,
       List<ShuffleBlockInfo> shuffleBlockInfoList,
diff --git a/client/src/main/java/org/apache/uniffle/client/impl/TrackingPartitionStatus.java b/client/src/main/java/org/apache/uniffle/client/impl/TrackingPartitionStatus.java
new file mode 100644
index 0000000000..d734661f5f
--- /dev/null
+++ b/client/src/main/java/org/apache/uniffle/client/impl/TrackingPartitionStatus.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.client.impl;
+
+import org.apache.uniffle.common.ShuffleServerInfo;
+
+public class TrackingPartitionStatus {
+  private final int partitionId;
+  private final ShuffleServerInfo shuffleServerInfo;
+
+  public TrackingPartitionStatus(int partitionId, ShuffleServerInfo shuffleServerInfo) {
+    this.partitionId = partitionId;
+    this.shuffleServerInfo = shuffleServerInfo;
+  }
+
+  public ShuffleServerInfo getShuffleServerInfo() {
+    return shuffleServerInfo;
+  }
+
+  public int getPartitionId() {
+    return partitionId;
+  }
+}
diff --git a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
index 98180d647c..b0720756ae 100644
--- a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
+++ b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
@@ -31,6 +31,7 @@
 import com.google.protobuf.ByteString;
 import com.google.protobuf.UnsafeByteOperations;
 import io.netty.buffer.Unpooled;
+import org.apache.commons.lang3.tuple.Pair;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -280,17 +281,18 @@ public long requirePreAllocation(
       int retryMax,
       long retryIntervalMax) {
     return requirePreAllocation(
-        appId,
-        shuffleId,
-        partitionIds,
-        partitionRequireSizes,
-        requireSize,
-        retryMax,
-        retryIntervalMax,
-        new AtomicReference<>(StatusCode.INTERNAL_ERROR));
+            appId,
+            shuffleId,
+            partitionIds,
+            partitionRequireSizes,
+            requireSize,
+            retryMax,
+            retryIntervalMax,
+            new AtomicReference<>(StatusCode.INTERNAL_ERROR))
+        .getLeft();
   }
 
-  public long requirePreAllocation(
+  public Pair<Long, List<Integer>> requirePreAllocation(
       String appId,
       int shuffleId,
       List<Integer> partitionIds,
@@ -311,6 +313,7 @@ public long requirePreAllocation(
     long start = System.currentTimeMillis();
     int retry = 0;
     long result = FAILED_REQUIRE_ID;
+    List<Integer> needSplitPartitionIds = Collections.emptyList();
     if (LOG.isDebugEnabled()) {
       LOG.debug(
           "Requiring buffer for appId: {}, shuffleId: {}, partitionIds: {} with {} bytes from {}:{}",
@@ -328,7 +331,7 @@ public long requirePreAllocation(
       } catch (Exception e) {
         LOG.error(
             "Exception happened when requiring pre-allocated buffer from {}:{}", host, port, e);
-        return result;
+        return Pair.of(result, needSplitPartitionIds);
       }
       if (rpcResponse.getStatus() != NO_BUFFER
           && rpcResponse.getStatus() != RssProtos.StatusCode.NO_BUFFER_FOR_HUGE_PARTITION) {
@@ -348,7 +351,7 @@ public long requirePreAllocation(
                 + retryMax
                 + " times, cost: {}(ms)",
             System.currentTimeMillis() - start);
-        return result;
+        return Pair.of(result, needSplitPartitionIds);
       }
       try {
         LOG.info(
@@ -382,6 +385,7 @@ public long requirePreAllocation(
             System.currentTimeMillis() - start);
       }
       result = rpcResponse.getRequireBufferId();
+      needSplitPartitionIds = rpcResponse.getNeedSplitPartitionIdsList();
     } else if (NOT_RETRY_STATUS_CODES.contains(
         StatusCode.fromCode(rpcResponse.getStatus().getNumber()))) {
       failedStatusCodeRef.set(StatusCode.fromCode(rpcResponse.getStatus().getNumber()));
@@ -398,7 +402,7 @@ public long requirePreAllocation(
               + rpcResponse.getRetMsg();
       throw new NotRetryException(msg);
     }
-    return result;
+    return Pair.of(result, needSplitPartitionIds);
   }
 
   private RssProtos.ShuffleUnregisterByAppIdResponse doUnregisterShuffleByAppId(
@@ -564,16 +568,18 @@ public RssSendShuffleDataResponse sendShuffleData(RssSendShuffleDataRequest requ
       try {
         RetryUtils.retryWithCondition(
             () -> {
+              // TODO(baoloongmao): support partition split follow netty client
               long requireId =
                   requirePreAllocation(
-                      appId,
-                      shuffleId,
-                      partitionIds,
-                      partitionRequireSizes,
-                      allocateSize,
-                      request.getRetryMax() / maxRetryAttempts,
-                      request.getRetryIntervalMax(),
-                      failedStatusCode);
+                          appId,
+                          shuffleId,
+                          partitionIds,
+                          partitionRequireSizes,
+                          allocateSize,
+                          request.getRetryMax() / maxRetryAttempts,
+                          request.getRetryIntervalMax(),
+                          failedStatusCode)
+                      .getLeft();
               if (requireId == FAILED_REQUIRE_ID) {
                 throw new RssException(
                     String.format(
diff --git a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java
index 723b0ecf34..a89ea9c4c6 100644
--- a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java
+++ b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java
@@ -20,10 +20,13 @@
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.concurrent.atomic.AtomicReference;
 
 import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.Sets;
+import org.apache.commons.lang3.tuple.Pair;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -127,7 +130,7 @@ public RssSendShuffleDataResponse sendShuffleData(RssSendShuffleDataRequest requ
     int stageAttemptNumber = request.getStageAttemptNumber();
     boolean isSuccessful = true;
     AtomicReference<StatusCode> failedStatusCode = new AtomicReference<>(StatusCode.INTERNAL_ERROR);
-
+    Set<Integer> needSplitPartitionIds = Sets.newHashSet();
     for (Map.Entry<Integer, Map<Integer, List<ShuffleBlockInfo>>> stb :
         shuffleIdToBlocks.entrySet()) {
       int shuffleId = stb.getKey();
@@ -161,7 +164,7 @@ public RssSendShuffleDataResponse sendShuffleData(RssSendShuffleDataRequest requ
         RetryUtils.retryWithCondition(
             () -> {
               final TransportClient transportClient = getTransportClient();
-              long requireId =
+              Pair<Long, List<Integer>> result =
                   requirePreAllocation(
                       request.getAppId(),
                       shuffleId,
@@ -171,6 +174,8 @@ public RssSendShuffleDataResponse sendShuffleData(RssSendShuffleDataRequest requ
                       request.getRetryMax(),
                       request.getRetryIntervalMax(),
                       failedStatusCode);
+              long requireId = result.getLeft();
+              needSplitPartitionIds.addAll(result.getRight());
               if (requireId == FAILED_REQUIRE_ID) {
                 throw new RssException(
                     String.format(
@@ -232,6 +237,7 @@ public RssSendShuffleDataResponse sendShuffleData(RssSendShuffleDataRequest requ
     } else {
       response = new RssSendShuffleDataResponse(failedStatusCode.get());
     }
+    response.setNeedSplitPartitionIds(needSplitPartitionIds);
     return response;
   }
 
diff --git a/internal-client/src/main/java/org/apache/uniffle/client/request/RssReassignOnBlockSendFailureRequest.java b/internal-client/src/main/java/org/apache/uniffle/client/request/RssReassignOnBlockSendFailureRequest.java
index 303499fb46..cd3a2ab84f 100644
--- a/internal-client/src/main/java/org/apache/uniffle/client/request/RssReassignOnBlockSendFailureRequest.java
+++ b/internal-client/src/main/java/org/apache/uniffle/client/request/RssReassignOnBlockSendFailureRequest.java
@@ -25,6 +25,7 @@
 import org.apache.uniffle.proto.RssProtos;
 
 public class RssReassignOnBlockSendFailureRequest {
+  private final boolean partitionSplit;
   private int shuffleId;
   private Map<Integer, List<ReceivingFailureServer>> failurePartitionToServers;
   private String executorId;
@@ -38,13 +39,15 @@ public RssReassignOnBlockSendFailureRequest(
       String executorId,
       long taskAttemptId,
       int stageId,
-      int stageAttemptNum) {
+      int stageAttemptNum,
+      boolean partitionSplit) {
     this.shuffleId = shuffleId;
     this.failurePartitionToServers = failurePartitionToServers;
     this.executorId = executorId;
     this.taskAttemptId = taskAttemptId;
     this.stageId = stageId;
     this.stageAttemptNumber = stageAttemptNum;
+    this.partitionSplit = partitionSplit;
   }
 
   public static RssProtos.RssReassignOnBlockSendFailureRequest toProto(
@@ -60,6 +63,7 @@ public static RssProtos.RssReassignOnBlockSendFailureRequest toProto(
         .setStageId(request.stageId)
         .setStageAttemptNumber(request.stageAttemptNumber)
         .setTaskAttemptId(request.taskAttemptId)
+        .setPartitionSplit(request.partitionSplit)
         .build();
   }
 }
diff --git a/internal-client/src/main/java/org/apache/uniffle/client/response/RssSendShuffleDataResponse.java b/internal-client/src/main/java/org/apache/uniffle/client/response/RssSendShuffleDataResponse.java
index eb20dd8a0d..917d740735 100644
--- a/internal-client/src/main/java/org/apache/uniffle/client/response/RssSendShuffleDataResponse.java
+++ b/internal-client/src/main/java/org/apache/uniffle/client/response/RssSendShuffleDataResponse.java
@@ -18,6 +18,7 @@
 package org.apache.uniffle.client.response;
 
 import java.util.List;
+import java.util.Set;
 
 import org.apache.uniffle.common.rpc.StatusCode;
 
@@ -25,6 +26,7 @@ public class RssSendShuffleDataResponse extends ClientResponse {
 
   private List<Long> successBlockIds;
   private List<Long> failedBlockIds;
+  private Set<Integer> needSplitPartitionIds;
 
   public RssSendShuffleDataResponse(StatusCode statusCode) {
     super(statusCode);
@@ -45,4 +47,26 @@ public List<Long> getFailedBlockIds() {
   public void setFailedBlockIds(List<Long> failedBlockIds) {
     this.failedBlockIds = failedBlockIds;
   }
+
+  /**
+   * Records the ids of the partitions that were reported as needing a split.
+   *
+   * @param needSplitPartitionIds ids of the partitions to split
+   */
+  public void setNeedSplitPartitionIds(Set<Integer> needSplitPartitionIds) {
+    this.needSplitPartitionIds = needSplitPartitionIds;
+  }
+
+  /**
+   * Returns the ids of the partitions that were reported as needing a split.
+   *
+   * @return the recorded partition ids, or an empty set when none were recorded; never null
+   */
+  public Set<Integer> getNeedSplitPartitionIds() {
+    // The field has no initializer, so it is still null if the setter was never called.
+    if (needSplitPartitionIds == null) {
+      return java.util.Collections.emptySet();
+    }
+    return needSplitPartitionIds;
+  }
 }
diff --git a/proto/src/main/proto/Rss.proto b/proto/src/main/proto/Rss.proto
index 5e8cc632d5..e883eddc43 100644
--- a/proto/src/main/proto/Rss.proto
+++ b/proto/src/main/proto/Rss.proto
@@ -64,6 +64,8 @@ message RequireBufferResponse {
   int64 requireBufferId = 1;
   StatusCode status = 2;
   string retMsg = 3;
+  // need split partitions
+  repeated int32 needSplitPartitionIds = 4;
 }
 
 message ShuffleDataBlockSegment {
@@ -669,6 +671,7 @@ message RssReassignOnBlockSendFailureRequest {
   int32 stageId = 4;
   int32 stageAttemptNumber = 5;
   string executorId = 6;
+  optional bool partitionSplit = 7;
 }
 
 message ReceivingFailureServers {
diff --git a/server/src/main/java/org/apache/uniffle/server/HugePartitionUtils.java b/server/src/main/java/org/apache/uniffle/server/HugePartitionUtils.java
index f2501cd420..59a010cfff 100644
--- a/server/src/main/java/org/apache/uniffle/server/HugePartitionUtils.java
+++ b/server/src/main/java/org/apache/uniffle/server/HugePartitionUtils.java
@@ -147,4 +147,17 @@ public static boolean limitHugePartition(
     }
     return false;
   }
+
+  /**
+   * Checks whether a partition's used data size is above the huge-partition split limit.
+   *
+   * @param shuffleBufferManager supplies the limit via {@code getHugePartitionSplitLimit()}
+   * @param usedPartitionDataSize the partition's used data size to compare against the limit
+   * @return true only when the used size is strictly greater than the split limit
+   */
+  public static boolean hasExceedPartitionSplitLimit(
+      ShuffleBufferManager shuffleBufferManager, long usedPartitionDataSize) {
+    final long splitLimit = shuffleBufferManager.getHugePartitionSplitLimit();
+    return splitLimit < usedPartitionDataSize;
+  }
 }
diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleServerConf.java b/server/src/main/java/org/apache/uniffle/server/ShuffleServerConf.java
index 310afcb231..677127b3b4 100644
--- a/server/src/main/java/org/apache/uniffle/server/ShuffleServerConf.java
+++ b/server/src/main/java/org/apache/uniffle/server/ShuffleServerConf.java
@@ -516,6 +516,16 @@ public class ShuffleServerConf extends RssBaseConf {
                   + "will be terminated. This helps to significantly improve the "
                   + "stability of the cluster by preventing partitions from becoming too large.");
 
+  public static final ConfigOption<Long> HUGE_PARTITION_SPLIT_LIMIT =
+      ConfigOptions.key("rss.server.huge-partition.split.limit")
+          .longType()
+          .defaultValue(Long.MAX_VALUE)
+          .withDescription(
+              "This option sets the maximum partition slice size threshold. "
+                  + "If the partition size exceeds this threshold, the rss client will "
+                  + "receive the need split partition list and resend the failed blocks to "
+                  + "new servers through reassign mechanism.");
+
   public static final ConfigOption<Long> SERVER_DECOMMISSION_CHECK_INTERVAL =
       ConfigOptions.key("rss.server.decommission.check.interval")
           .longType()
diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java
index 85fe7fa6e7..303a4a7d20 100644
--- a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java
+++ b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java
@@ -19,6 +19,7 @@
 
 import java.io.IOException;
 import java.nio.ByteBuffer;
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
@@ -34,6 +35,7 @@
 import io.grpc.stub.StreamObserver;
 import io.netty.buffer.ByteBuf;
 import org.apache.commons.lang3.StringUtils;
+import org.apache.commons.lang3.tuple.Pair;
 import org.roaringbitmap.longlong.Roaring64NavigableMap;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -727,21 +729,24 @@ public void requireBuffer(
       long requireBufferId = -1;
       String responseMessage = "";
       String shuffleDataInfo = "appId[" + appId + "], shuffleId[" + request.getShuffleId() + "]";
+      List<Integer> needSplitPartitionIds = Collections.emptyList();
       try {
         if (StringUtils.isEmpty(appId)) {
           // To be compatible with older client version
           requireBufferId =
               shuffleServer.getShuffleTaskManager().requireBuffer(request.getRequireSize());
         } else {
-          requireBufferId =
+          Pair<Long, List<Integer>> pair =
               shuffleServer
                   .getShuffleTaskManager()
-                  .requireBuffer(
+                  .requireBufferReturnPair(
                       appId,
                       request.getShuffleId(),
                       request.getPartitionIdsList(),
                       request.getPartitionRequireSizesList(),
                       request.getRequireSize());
+          requireBufferId = pair.getLeft();
+          needSplitPartitionIds = pair.getRight();
         }
       } catch (NoBufferException e) {
         responseMessage = e.getMessage();
@@ -775,6 +780,7 @@ public void requireBuffer(
               .setStatus(status.toProto())
               .setRequireBufferId(requireBufferId)
               .setRetMsg(responseMessage)
+              .addAllNeedSplitPartitionIds(needSplitPartitionIds)
               .build();
       responseObserver.onNext(response);
       responseObserver.onCompleted();
diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java
index af37646a78..18dee6b0d3 100644
--- a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java
+++ b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java
@@ -48,6 +48,7 @@
 import com.google.common.collect.Range;
 import com.google.common.collect.Sets;
 import org.apache.commons.collections4.CollectionUtils;
+import org.apache.commons.lang3.tuple.Pair;
 import org.roaringbitmap.longlong.Roaring64NavigableMap;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -581,7 +582,17 @@ public long getPartitionDataSize(String appId, int shuffleId, int partitionId) {
     return shuffleTaskInfo.getPartitionDataSize(shuffleId, partitionId);
   }
 
-  public long requireBuffer(
+  /**
+   * Require buffer for shuffle data
+   *
+   * @param appId the appId
+   * @param shuffleId the shuffleId
+   * @param partitionIds the partitionIds
+   * @param partitionRequireSizes the partitionRequireSizes
+   * @param requireSize the requireSize
+   * @return returns (requireId, splitPartitionIds)
+   */
+  public Pair<Long, List<Integer>> requireBufferReturnPair(
       String appId,
       int shuffleId,
       List<Integer> partitionIds,
@@ -592,6 +603,7 @@ public long requireBuffer(
       LOG.error("No such app is registered. appId: {}, shuffleId: {}", appId, shuffleId);
       throw new NoRegisterException("No such app is registered. appId: " + appId);
     }
+    List<Integer> splitPartitionIds = new ArrayList<>();
     // To be compatible with legacy clients which have empty partitionRequireSizes
     if (partitionIds.size() == partitionRequireSizes.size()) {
       for (int i = 0; i < partitionIds.size(); i++) {
@@ -610,9 +622,39 @@ public long requireBuffer(
         }
         HugePartitionUtils.checkExceedPartitionHardLimit(
             "requireBuffer", shuffleBufferManager, partitionUsedDataSize, partitionRequireSize);
+        if (HugePartitionUtils.hasExceedPartitionSplitLimit(
+            shuffleBufferManager, partitionUsedDataSize)) {
+          LOG.info(
+              "Need split partition. appId: {}, shuffleId: {}, partitionId: {}, partitionUsedDataSize: {}",
+              appId,
+              shuffleId,
+              partitionId,
+              partitionUsedDataSize);
+          splitPartitionIds.add(partitionId);
+          // Soft limit: the partition is only reported for splitting here; requireSize is left
+          // unchanged and the buffer is still requested below.
+        }
       }
     }
-    return requireBuffer(appId, requireSize);
+    return Pair.of(requireBuffer(appId, requireSize), splitPartitionIds);
+  }
+
+  /**
+   * Requires buffer for shuffle data, returning only the require id.
+   *
+   * <p>Delegates to {@link #requireBufferReturnPair} and discards the returned list of
+   * partitions that need to be split.
+   */
+  @VisibleForTesting
+  public long requireBuffer(
+      String appId,
+      int shuffleId,
+      List<Integer> partitionIds,
+      List<Integer> partitionRequireSizes,
+      int requireSize) {
+    return requireBufferReturnPair(
+            appId, shuffleId, partitionIds, partitionRequireSizes, requireSize)
+        .getLeft();
   }
 
   public long requireBuffer(String appId, int requireSize) {
diff --git a/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java b/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java
index 7d7b98ac93..85cdf98b1b 100644
--- a/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java
+++ b/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java
@@ -81,6 +81,7 @@ public class ShuffleBufferManager {
   // Huge partition vars
   private ReconfigurableConfManager.Reconfigurable<Long> hugePartitionSizeThresholdRef;
   private ReconfigurableConfManager.Reconfigurable<Long> hugePartitionSizeHardLimitRef;
+  private ReconfigurableConfManager.Reconfigurable<Long> hugePartitionSplitLimitRef;
   private long hugePartitionMemoryLimitSize;
   protected AtomicLong preAllocatedSize = new AtomicLong(0L);
   protected AtomicLong inFlushSize = new AtomicLong(0L);
@@ -141,6 +142,8 @@ public ShuffleBufferManager(
         conf.getReconfigurableConf(ShuffleServerConf.HUGE_PARTITION_SIZE_THRESHOLD);
     this.hugePartitionSizeHardLimitRef =
         conf.getReconfigurableConf(ShuffleServerConf.HUGE_PARTITION_SIZE_HARD_LIMIT);
+    this.hugePartitionSplitLimitRef =
+        conf.getReconfigurableConf(ShuffleServerConf.HUGE_PARTITION_SPLIT_LIMIT);
     this.hugePartitionMemoryLimitSize =
         Math.round(
             capacity * conf.get(ShuffleServerConf.HUGE_PARTITION_MEMORY_USAGE_LIMITATION_RATIO));
@@ -839,4 +842,14 @@ public void setBufferFlushThreshold(long bufferFlushThreshold) {
   public ShuffleBufferType getShuffleBufferType() {
     return shuffleBufferType;
   }
+
+  /**
+   * Reads the current huge-partition split limit from its reconfigurable config reference.
+   *
+   * @return the split limit as reported by {@code getSizeAsBytes()}
+   */
+  public long getHugePartitionSplitLimit() {
+    final long splitLimit = hugePartitionSplitLimitRef.getSizeAsBytes();
+    return splitLimit;
+  }
 }
diff --git a/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferManagerTest.java b/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferManagerTest.java
index 13dbe94e12..9d9bdd194b 100644
--- a/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferManagerTest.java
+++ b/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferManagerTest.java
@@ -30,6 +30,7 @@
 import com.google.common.collect.RangeMap;
 import com.google.common.util.concurrent.Uninterruptibles;
 import io.prometheus.client.Collector;
+import org.apache.commons.lang3.tuple.Pair;
 import org.awaitility.Awaitility;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
@@ -795,4 +796,65 @@ public void blockSizeMetricsTest() {
               }
             });
   }
+
+  /**
+   * Verifies that a partition is flagged for splitting once its data size exceeds
+   * {@link ShuffleServerConf#HUGE_PARTITION_SPLIT_LIMIT}, and not before.
+   */
+  @Test
+  public void splitPartitionTest(@TempDir File tmpDir) throws Exception {
+    // Local-file storage under the temp dir, with a split limit of 200.
+    final String basePath = new File(tmpDir, "data").getAbsolutePath();
+    ShuffleServerConf conf = new ShuffleServerConf();
+    conf.setString(ShuffleServerConf.RSS_STORAGE_TYPE.key(), StorageType.LOCALFILE.name());
+    conf.set(ShuffleServerConf.RSS_STORAGE_BASE_PATH, Arrays.asList(basePath));
+    conf.set(ShuffleServerConf.HUGE_PARTITION_SPLIT_LIMIT, 200L);
+
+    // Wire real managers behind a mocked ShuffleServer.
+    ShuffleServer server = mock(ShuffleServer.class);
+    StorageManager storage = StorageManagerFactory.getInstance().createStorageManager(conf);
+    ShuffleFlushManager flushManager = new ShuffleFlushManager(conf, server, storage);
+    shuffleBufferManager = new ShuffleBufferManager(conf, flushManager, false);
+    ShuffleTaskManager taskManager =
+        new ShuffleTaskManager(conf, flushManager, shuffleBufferManager, storage);
+
+    when(server.getShuffleFlushManager()).thenReturn(flushManager);
+    when(server.getShuffleBufferManager()).thenReturn(shuffleBufferManager);
+    when(server.getShuffleTaskManager()).thenReturn(taskManager);
+
+    final String appId = "flushSingleBufferForHugePartitionTest_appId";
+    final int shuffleId = 1;
+
+    taskManager.registerShuffle(
+        appId, shuffleId, Arrays.asList(new PartitionRange(0, 0)), new RemoteStorageInfo(""), "");
+
+    // case1: its partition size does not exceed the split limit
+    shuffleBufferManager.registerBuffer(appId, shuffleId, 0, 0);
+    ShufflePartitionedData smallData = createData(0, 1);
+    taskManager.cacheShuffleData(appId, shuffleId, false, smallData);
+    taskManager.updateCachedBlockIds(appId, shuffleId, 0, smallData.getBlockList());
+    long sizeAfterSmallData = taskManager.getPartitionDataSize(appId, shuffleId, 0);
+    assertEquals(1 + 32, sizeAfterSmallData);
+    assertFalse(
+        HugePartitionUtils.hasExceedPartitionSplitLimit(
+            shuffleBufferManager, taskManager.getPartitionDataSize(appId, shuffleId, 0)));
+
+    // case2: its partition exceed the split limit
+    ShufflePartitionedData largeData = createData(0, 200);
+    taskManager.cacheShuffleData(appId, shuffleId, false, largeData);
+    taskManager.updateCachedBlockIds(appId, shuffleId, 0, largeData.getBlockList());
+    long sizeAfterLargeData = taskManager.getPartitionDataSize(appId, shuffleId, 0);
+    assertEquals(1 + 32 + 200 + 32, sizeAfterLargeData);
+    assertTrue(
+        HugePartitionUtils.hasExceedPartitionSplitLimit(
+            shuffleBufferManager, taskManager.getPartitionDataSize(appId, shuffleId, 0)));
+
+    // check returned need split partitions
+    Pair<Long, List<Integer>> requireResult =
+        taskManager.requireBufferReturnPair(
+            appId, shuffleId, Arrays.asList(0, 1), Arrays.asList(10, 10), 20);
+    List<Integer> splitPartitionIds = requireResult.getRight();
+    assertEquals(1, splitPartitionIds.size());
+    assertEquals(0, splitPartitionIds.get(0));
+  }
 }