diff --git a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_2.patch b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_2.patch
new file mode 100644
index 00000000000..0cb1fc812dc
--- /dev/null
+++ b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_2.patch
@@ -0,0 +1,315 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+index e469c9989f2..a4a68ef1b09 100644
+--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
++++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+@@ -33,6 +33,7 @@ import org.apache.commons.io.output.{ByteArrayOutputStream => ApacheByteArrayOut
+ import org.roaringbitmap.RoaringBitmap
+
+ import org.apache.spark.broadcast.{Broadcast, BroadcastManager}
++import org.apache.spark.celeborn.CelebornShuffleState
+ import org.apache.spark.internal.Logging
+ import org.apache.spark.internal.config._
+ import org.apache.spark.io.CompressionCodec
+@@ -839,6 +840,7 @@ private[spark] class MapOutputTrackerMaster(
+ shuffleStatus.invalidateSerializedMergeOutputStatusCache()
+ }
+ }
++ CelebornShuffleState.unregisterCelebornSkewedShuffle(shuffleId)
+ }
+
+ /**
+diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
+index 0388c7b576b..59fdc81b09d 100644
+--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
++++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
+@@ -32,6 +32,7 @@ import org.apache.hadoop.conf.Configuration
+ import org.apache.spark.annotation.DeveloperApi
+ import org.apache.spark.api.python.PythonWorkerFactory
+ import org.apache.spark.broadcast.BroadcastManager
++import org.apache.spark.celeborn.CelebornShuffleState
+ import org.apache.spark.internal.{config, Logging}
+ import org.apache.spark.internal.config._
+ import org.apache.spark.memory.{MemoryManager, UnifiedMemoryManager}
+@@ -414,6 +415,7 @@ object SparkEnv extends Logging {
+ if (isDriver) {
+ val sparkFilesDir = Utils.createTempDir(Utils.getLocalDir(conf), "userFiles").getAbsolutePath
+ envInstance.driverTmpDir = Some(sparkFilesDir)
++ CelebornShuffleState.init(envInstance)
+ }
+
+ envInstance
+diff --git a/core/src/main/scala/org/apache/spark/celeborn/CelebornShuffleState.scala b/core/src/main/scala/org/apache/spark/celeborn/CelebornShuffleState.scala
+new file mode 100644
+index 00000000000..5e190c512df
+--- /dev/null
++++ b/core/src/main/scala/org/apache/spark/celeborn/CelebornShuffleState.scala
+@@ -0,0 +1,75 @@
++/*
++ * Licensed to the Apache Software Foundation (ASF) under one or more
++ * contributor license agreements. See the NOTICE file distributed with
++ * this work for additional information regarding copyright ownership.
++ * The ASF licenses this file to You under the Apache License, Version 2.0
++ * (the "License"); you may not use this file except in compliance with
++ * the License. You may obtain a copy of the License at
++ *
++ * http://www.apache.org/licenses/LICENSE-2.0
++ *
++ * Unless required by applicable law or agreed to in writing, software
++ * distributed under the License is distributed on an "AS IS" BASIS,
++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
++ * See the License for the specific language governing permissions and
++ * limitations under the License.
++ */
++
++package org.apache.spark.celeborn
++
++import java.util.concurrent.ConcurrentHashMap
++import java.util.concurrent.atomic.AtomicBoolean
++
++import org.apache.spark.SparkEnv
++import org.apache.spark.internal.config.ConfigBuilder
++
++object CelebornShuffleState {
++  // SparkConf entries read in init(); both default to false.
++  private val CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ =
++    ConfigBuilder("spark.celeborn.client.adaptive.optimizeSkewedPartitionRead.enabled")
++      .booleanConf
++      .createWithDefault(false)
++
++  private val CELEBORN_STAGE_RERUN_ENABLED =
++    ConfigBuilder("spark.celeborn.client.spark.stageRerun.enabled")
++      .withAlternative("spark.celeborn.client.spark.fetch.throwsFetchFailure")
++      .booleanConf
++      .createWithDefault(false)
++
++  private val celebornOptimizeSkewedPartitionReadEnabled = new AtomicBoolean()
++  private val stageRerunEnabled = new AtomicBoolean()
++  private val skewShuffleIds = ConcurrentHashMap.newKeySet[Int]()
++
++  // Call this from SparkEnv.create; it resets the skewed-shuffle set and caches both flags.
++  def init(env: SparkEnv): Unit = {
++    // Clean up any previously registered shuffle ids (if present) before initializing.
++    skewShuffleIds.clear()
++
++    // Read all settings from env.conf rather than SQLConf.
++    celebornOptimizeSkewedPartitionReadEnabled.set(
++      env.conf.get("spark.shuffle.manager", "sort").contains("celeborn") &&
++      env.conf.get(CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ))
++    stageRerunEnabled.set(env.conf.get(CELEBORN_STAGE_RERUN_ENABLED))
++  }
++
++  def unregisterCelebornSkewedShuffle(shuffleId: Int): Unit = {
++    skewShuffleIds.remove(shuffleId)
++  }
++
++  def registerCelebornSkewedShuffle(shuffleId: Int): Unit = {
++    skewShuffleIds.add(shuffleId)
++  }
++
++  def isCelebornSkewedShuffle(shuffleId: Int): Boolean = {
++    skewShuffleIds.contains(shuffleId)
++  }
++
++  def celebornAdaptiveOptimizeSkewedPartitionReadEnabled: Boolean = {
++    celebornOptimizeSkewedPartitionReadEnabled.get()
++  }
++
++  def celebornStageRerunEnabled: Boolean = {
++    stageRerunEnabled.get()
++  }
++
++}
+diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+index b950c07f3d8..2cb430c3c3d 100644
+--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
++++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+@@ -33,6 +33,7 @@ import com.google.common.util.concurrent.{Futures, SettableFuture}
+
+ import org.apache.spark._
+ import org.apache.spark.broadcast.Broadcast
++import org.apache.spark.celeborn.CelebornShuffleState
+ import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics}
+ import org.apache.spark.internal.Logging
+ import org.apache.spark.internal.config
+@@ -1780,7 +1781,7 @@ private[spark] class DAGScheduler(
+ failedStage.failedAttemptIds.add(task.stageAttemptId)
+ val shouldAbortStage =
+ failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts ||
+- disallowStageRetryForTest
++ disallowStageRetryForTest || CelebornShuffleState.isCelebornSkewedShuffle(shuffleId)
+
+ // It is likely that we receive multiple FetchFailed for a single stage (because we have
+ // multiple tasks running concurrently on different executors). In that case, it is
+diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala
+new file mode 100644
+index 00000000000..3dc60678461
+--- /dev/null
++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala
+@@ -0,0 +1,35 @@
++/*
++ * Licensed to the Apache Software Foundation (ASF) under one or more
++ * contributor license agreements. See the NOTICE file distributed with
++ * this work for additional information regarding copyright ownership.
++ * The ASF licenses this file to You under the Apache License, Version 2.0
++ * (the "License"); you may not use this file except in compliance with
++ * the License. You may obtain a copy of the License at
++ *
++ * http://www.apache.org/licenses/LICENSE-2.0
++ *
++ * Unless required by applicable law or agreed to in writing, software
++ * distributed under the License is distributed on an "AS IS" BASIS,
++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
++ * See the License for the specific language governing permissions and
++ * limitations under the License.
++ */
++
++package org.apache.spark.sql.execution.adaptive
++
++import java.util.Locale
++
++import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike}
++
++object CelebornShuffleUtil {
++  // True only for a ShuffleExchangeExec whose shuffle handle class name contains "celeborn".
++  def isCelebornShuffle(shuffleExchangeLike: ShuffleExchangeLike): Boolean = {
++    shuffleExchangeLike match {
++      case exec: ShuffleExchangeExec =>
++        exec.shuffleDependency.shuffleHandle
++          .getClass.getName.toLowerCase(Locale.ROOT).contains("celeborn")
++      case _ => false
++    }
++  }
++
++}
+diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala
+index 1752907a9a5..2c6a49b78eb 100644
+--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala
++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala
+@@ -50,12 +50,13 @@ object OptimizeSkewInRebalancePartitions extends AQEShuffleReadRule {
+ private def optimizeSkewedPartitions(
+ shuffleId: Int,
+ bytesByPartitionId: Array[Long],
+- targetSize: Long): Seq[ShufflePartitionSpec] = {
++ targetSize: Long,
++ isCelebornShuffle: Boolean = false): Seq[ShufflePartitionSpec] = {
+ bytesByPartitionId.indices.flatMap { reduceIndex =>
+ val bytes = bytesByPartitionId(reduceIndex)
+ if (bytes > targetSize) {
+ val newPartitionSpec =
+- ShufflePartitionsUtil.createSkewPartitionSpecs(shuffleId, reduceIndex, targetSize)
++ ShufflePartitionsUtil.createSkewPartitionSpecs(shuffleId, reduceIndex, targetSize, isCelebornShuffle)
+ if (newPartitionSpec.isEmpty) {
+ CoalescedPartitionSpec(reduceIndex, reduceIndex + 1, bytes) :: Nil
+ } else {
+@@ -77,8 +78,9 @@ object OptimizeSkewInRebalancePartitions extends AQEShuffleReadRule {
+ return shuffle
+ }
+
++ val isCelebornShuffle = CelebornShuffleUtil.isCelebornShuffle(shuffle.shuffle)
+ val newPartitionsSpec = optimizeSkewedPartitions(
+- mapStats.get.shuffleId, mapStats.get.bytesByPartitionId, advisorySize)
++ mapStats.get.shuffleId, mapStats.get.bytesByPartitionId, advisorySize, isCelebornShuffle)
+ // return origin plan if we can not optimize partitions
+ if (newPartitionsSpec.length == mapStats.get.bytesByPartitionId.length) {
+ shuffle
+diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
+index 88abe68197b..150699a84a3 100644
+--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
+@@ -157,8 +157,10 @@ object OptimizeSkewedJoin extends AQEShuffleReadRule {
+ Seq(CoalescedPartitionSpec(partitionIndex, partitionIndex + 1, rightSize))
+
+ val leftParts = if (isLeftSkew) {
++ val isCelebornShuffle = CelebornShuffleUtil.isCelebornShuffle(left.shuffle)
+ val skewSpecs = ShufflePartitionsUtil.createSkewPartitionSpecs(
+- left.mapStats.get.shuffleId, partitionIndex, leftTargetSize)
++ left.mapStats.get.shuffleId, partitionIndex, leftTargetSize,
++ isCelebornShuffle = isCelebornShuffle)
+ if (skewSpecs.isDefined) {
+ logDebug(s"Left side partition $partitionIndex " +
+ s"(${FileUtils.byteCountToDisplaySize(leftSize)}) is skewed, " +
+@@ -171,8 +173,10 @@ object OptimizeSkewedJoin extends AQEShuffleReadRule {
+ }
+
+ val rightParts = if (isRightSkew) {
++ val isCelebornShuffle = CelebornShuffleUtil.isCelebornShuffle(right.shuffle)
+ val skewSpecs = ShufflePartitionsUtil.createSkewPartitionSpecs(
+- right.mapStats.get.shuffleId, partitionIndex, rightTargetSize)
++ right.mapStats.get.shuffleId, partitionIndex, rightTargetSize,
++ isCelebornShuffle = isCelebornShuffle)
+ if (skewSpecs.isDefined) {
+ logDebug(s"Right side partition $partitionIndex " +
+ s"(${FileUtils.byteCountToDisplaySize(rightSize)}) is skewed, " +
+diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala
+index 3609548f374..d34f43bf064 100644
+--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala
++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala
+@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.adaptive
+ import scala.collection.mutable.ArrayBuffer
+
+ import org.apache.spark.{MapOutputStatistics, MapOutputTrackerMaster, SparkEnv}
++import org.apache.spark.celeborn.CelebornShuffleState
+ import org.apache.spark.internal.Logging
+ import org.apache.spark.sql.execution.{CoalescedPartitionSpec, PartialReducerPartitionSpec, ShufflePartitionSpec}
+
+@@ -376,11 +377,20 @@ object ShufflePartitionsUtil extends Logging {
+ def createSkewPartitionSpecs(
+ shuffleId: Int,
+ reducerId: Int,
+- targetSize: Long): Option[Seq[PartialReducerPartitionSpec]] = {
++ targetSize: Long,
++ isCelebornShuffle: Boolean = false): Option[Seq[PartialReducerPartitionSpec]] = {
+ val mapPartitionSizes = getMapSizesForReduceId(shuffleId, reducerId)
+ if (mapPartitionSizes.exists(_ < 0)) return None
+ val mapStartIndices = splitSizeListByTargetSize(mapPartitionSizes, targetSize)
+ if (mapStartIndices.length > 1) {
++ val celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled =
++ CelebornShuffleState.celebornAdaptiveOptimizeSkewedPartitionReadEnabled && isCelebornShuffle
++
++ val throwsFetchFailure = CelebornShuffleState.celebornStageRerunEnabled
++ if (throwsFetchFailure && celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) {
++ logInfo(s"Celeborn shuffle retry enabled and shuffle $shuffleId is skewed")
++ CelebornShuffleState.registerCelebornSkewedShuffle(shuffleId)
++ }
+ Some(mapStartIndices.indices.map { i =>
+ val startMapIndex = mapStartIndices(i)
+ val endMapIndex = if (i == mapStartIndices.length - 1) {
+@@ -388,8 +398,21 @@ object ShufflePartitionsUtil extends Logging {
+ } else {
+ mapStartIndices(i + 1)
+ }
+- val dataSize = startMapIndex.until(endMapIndex).map(mapPartitionSizes(_)).sum
+- PartialReducerPartitionSpec(reducerId, startMapIndex, endMapIndex, dataSize)
++ var dataSize = 0L
++ var mapIndex = startMapIndex
++ while (mapIndex < endMapIndex) {
++ dataSize += mapPartitionSizes(mapIndex)
++ mapIndex += 1
++ }
++
++ if (celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) {
++          // Here the spec carries (number of splits, split index) instead of a map index range.
++          // `dataSize` is still summed over the map range, so it may be inaccurate in this mode.
++          // Please do not use these dataSize values in any other part of the codebase.
++ PartialReducerPartitionSpec(reducerId, mapStartIndices.length, i, dataSize)
++ } else {
++ PartialReducerPartitionSpec(reducerId, startMapIndex, endMapIndex, dataSize)
++ }
+ })
+ } else {
+ None
diff --git a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch
new file mode 100644
index 00000000000..f8e38615c6a
--- /dev/null
+++ b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch
@@ -0,0 +1,312 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+index b1974948430..a045c8646ba 100644
+--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
++++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+@@ -33,6 +33,7 @@ import org.apache.commons.io.output.{ByteArrayOutputStream => ApacheByteArrayOut
+ import org.roaringbitmap.RoaringBitmap
+
+ import org.apache.spark.broadcast.{Broadcast, BroadcastManager}
++import org.apache.spark.celeborn.CelebornShuffleState
+ import org.apache.spark.internal.Logging
+ import org.apache.spark.internal.config._
+ import org.apache.spark.io.CompressionCodec
+@@ -886,6 +887,7 @@ private[spark] class MapOutputTrackerMaster(
+ shuffleStatus.invalidateSerializedMergeOutputStatusCache()
+ }
+ }
++ CelebornShuffleState.unregisterCelebornSkewedShuffle(shuffleId)
+ }
+
+ /**
+diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
+index 19467e7eca1..0ae4990219c 100644
+--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
++++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
+@@ -32,6 +32,7 @@ import org.apache.hadoop.conf.Configuration
+ import org.apache.spark.annotation.DeveloperApi
+ import org.apache.spark.api.python.PythonWorkerFactory
+ import org.apache.spark.broadcast.BroadcastManager
++import org.apache.spark.celeborn.CelebornShuffleState
+ import org.apache.spark.executor.ExecutorBackend
+ import org.apache.spark.internal.{config, Logging}
+ import org.apache.spark.internal.config._
+@@ -419,6 +420,7 @@ object SparkEnv extends Logging {
+ if (isDriver) {
+ val sparkFilesDir = Utils.createTempDir(Utils.getLocalDir(conf), "userFiles").getAbsolutePath
+ envInstance.driverTmpDir = Some(sparkFilesDir)
++ CelebornShuffleState.init(envInstance)
+ }
+
+ envInstance
+diff --git a/core/src/main/scala/org/apache/spark/celeborn/CelebornShuffleState.scala b/core/src/main/scala/org/apache/spark/celeborn/CelebornShuffleState.scala
+new file mode 100644
+index 00000000000..5e190c512df
+--- /dev/null
++++ b/core/src/main/scala/org/apache/spark/celeborn/CelebornShuffleState.scala
+@@ -0,0 +1,75 @@
++/*
++ * Licensed to the Apache Software Foundation (ASF) under one or more
++ * contributor license agreements. See the NOTICE file distributed with
++ * this work for additional information regarding copyright ownership.
++ * The ASF licenses this file to You under the Apache License, Version 2.0
++ * (the "License"); you may not use this file except in compliance with
++ * the License. You may obtain a copy of the License at
++ *
++ * http://www.apache.org/licenses/LICENSE-2.0
++ *
++ * Unless required by applicable law or agreed to in writing, software
++ * distributed under the License is distributed on an "AS IS" BASIS,
++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
++ * See the License for the specific language governing permissions and
++ * limitations under the License.
++ */
++
++package org.apache.spark.celeborn
++
++import java.util.concurrent.ConcurrentHashMap
++import java.util.concurrent.atomic.AtomicBoolean
++
++import org.apache.spark.SparkEnv
++import org.apache.spark.internal.config.ConfigBuilder
++
++object CelebornShuffleState {
++  // SparkConf entries read in init(); both default to false.
++  private val CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ =
++    ConfigBuilder("spark.celeborn.client.adaptive.optimizeSkewedPartitionRead.enabled")
++      .booleanConf
++      .createWithDefault(false)
++
++  private val CELEBORN_STAGE_RERUN_ENABLED =
++    ConfigBuilder("spark.celeborn.client.spark.stageRerun.enabled")
++      .withAlternative("spark.celeborn.client.spark.fetch.throwsFetchFailure")
++      .booleanConf
++      .createWithDefault(false)
++
++  private val celebornOptimizeSkewedPartitionReadEnabled = new AtomicBoolean()
++  private val stageRerunEnabled = new AtomicBoolean()
++  private val skewShuffleIds = ConcurrentHashMap.newKeySet[Int]()
++
++  // Call this from SparkEnv.create; it resets the skewed-shuffle set and caches both flags.
++  def init(env: SparkEnv): Unit = {
++    // Clean up any previously registered shuffle ids (if present) before initializing.
++    skewShuffleIds.clear()
++
++    // Read all settings from env.conf rather than SQLConf.
++    celebornOptimizeSkewedPartitionReadEnabled.set(
++      env.conf.get("spark.shuffle.manager", "sort").contains("celeborn") &&
++      env.conf.get(CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ))
++    stageRerunEnabled.set(env.conf.get(CELEBORN_STAGE_RERUN_ENABLED))
++  }
++
++  def unregisterCelebornSkewedShuffle(shuffleId: Int): Unit = {
++    skewShuffleIds.remove(shuffleId)
++  }
++
++  def registerCelebornSkewedShuffle(shuffleId: Int): Unit = {
++    skewShuffleIds.add(shuffleId)
++  }
++
++  def isCelebornSkewedShuffle(shuffleId: Int): Boolean = {
++    skewShuffleIds.contains(shuffleId)
++  }
++
++  def celebornAdaptiveOptimizeSkewedPartitionReadEnabled: Boolean = {
++    celebornOptimizeSkewedPartitionReadEnabled.get()
++  }
++
++  def celebornStageRerunEnabled: Boolean = {
++    stageRerunEnabled.get()
++  }
++
++}
+diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+index bd2823bcac1..d0c88081527 100644
+--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
++++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+@@ -33,6 +33,7 @@ import com.google.common.util.concurrent.{Futures, SettableFuture}
+
+ import org.apache.spark._
+ import org.apache.spark.broadcast.Broadcast
++import org.apache.spark.celeborn.CelebornShuffleState
+ import org.apache.spark.errors.SparkCoreErrors
+ import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics}
+ import org.apache.spark.internal.Logging
+@@ -1851,7 +1852,7 @@ private[spark] class DAGScheduler(
+ failedStage.failedAttemptIds.add(task.stageAttemptId)
+ val shouldAbortStage =
+ failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts ||
+- disallowStageRetryForTest
++ disallowStageRetryForTest || CelebornShuffleState.isCelebornSkewedShuffle(shuffleId)
+
+ // It is likely that we receive multiple FetchFailed for a single stage (because we have
+ // multiple tasks running concurrently on different executors). In that case, it is
+diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala
+new file mode 100644
+index 00000000000..3dc60678461
+--- /dev/null
++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala
+@@ -0,0 +1,35 @@
++/*
++ * Licensed to the Apache Software Foundation (ASF) under one or more
++ * contributor license agreements. See the NOTICE file distributed with
++ * this work for additional information regarding copyright ownership.
++ * The ASF licenses this file to You under the Apache License, Version 2.0
++ * (the "License"); you may not use this file except in compliance with
++ * the License. You may obtain a copy of the License at
++ *
++ * http://www.apache.org/licenses/LICENSE-2.0
++ *
++ * Unless required by applicable law or agreed to in writing, software
++ * distributed under the License is distributed on an "AS IS" BASIS,
++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
++ * See the License for the specific language governing permissions and
++ * limitations under the License.
++ */
++
++package org.apache.spark.sql.execution.adaptive
++
++import java.util.Locale
++
++import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike}
++
++object CelebornShuffleUtil {
++  // True only for a ShuffleExchangeExec whose shuffle handle class name contains "celeborn".
++  def isCelebornShuffle(shuffleExchangeLike: ShuffleExchangeLike): Boolean = {
++    shuffleExchangeLike match {
++      case exec: ShuffleExchangeExec =>
++        exec.shuffleDependency.shuffleHandle
++          .getClass.getName.toLowerCase(Locale.ROOT).contains("celeborn")
++      case _ => false
++    }
++  }
++
++}
+diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala
+index b34ab3e380b..cb0ed9d05a4 100644
+--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala
++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala
+@@ -47,14 +47,15 @@ object OptimizeSkewInRebalancePartitions extends AQEShuffleReadRule {
+ private def optimizeSkewedPartitions(
+ shuffleId: Int,
+ bytesByPartitionId: Array[Long],
+- targetSize: Long): Seq[ShufflePartitionSpec] = {
++ targetSize: Long,
++ isCelebornShuffle: Boolean = false): Seq[ShufflePartitionSpec] = {
+ val smallPartitionFactor =
+ conf.getConf(SQLConf.ADAPTIVE_REBALANCE_PARTITIONS_SMALL_PARTITION_FACTOR)
+ bytesByPartitionId.indices.flatMap { reduceIndex =>
+ val bytes = bytesByPartitionId(reduceIndex)
+ if (bytes > targetSize) {
+ val newPartitionSpec = ShufflePartitionsUtil.createSkewPartitionSpecs(
+- shuffleId, reduceIndex, targetSize, smallPartitionFactor)
++ shuffleId, reduceIndex, targetSize, smallPartitionFactor, isCelebornShuffle)
+ if (newPartitionSpec.isEmpty) {
+ CoalescedPartitionSpec(reduceIndex, reduceIndex + 1, bytes) :: Nil
+ } else {
+@@ -76,8 +77,9 @@ object OptimizeSkewInRebalancePartitions extends AQEShuffleReadRule {
+ return shuffle
+ }
+
++ val isCelebornShuffle = CelebornShuffleUtil.isCelebornShuffle(shuffle.shuffle)
+ val newPartitionsSpec = optimizeSkewedPartitions(
+- mapStats.get.shuffleId, mapStats.get.bytesByPartitionId, advisorySize)
++ mapStats.get.shuffleId, mapStats.get.bytesByPartitionId, advisorySize, isCelebornShuffle)
+ // return origin plan if we can not optimize partitions
+ if (newPartitionsSpec.length == mapStats.get.bytesByPartitionId.length) {
+ shuffle
+diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
+index d4a173bb9cc..21ef335e064 100644
+--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
+@@ -152,8 +152,10 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements)
+ Seq(CoalescedPartitionSpec(partitionIndex, partitionIndex + 1, rightSize))
+
+ val leftParts = if (isLeftSkew) {
++ val isCelebornShuffle = CelebornShuffleUtil.isCelebornShuffle(left.shuffle)
+ val skewSpecs = ShufflePartitionsUtil.createSkewPartitionSpecs(
+- left.mapStats.get.shuffleId, partitionIndex, leftTargetSize)
++ left.mapStats.get.shuffleId, partitionIndex, leftTargetSize,
++ isCelebornShuffle = isCelebornShuffle)
+ if (skewSpecs.isDefined) {
+ logDebug(s"Left side partition $partitionIndex " +
+ s"(${FileUtils.byteCountToDisplaySize(leftSize)}) is skewed, " +
+@@ -166,8 +168,10 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements)
+ }
+
+ val rightParts = if (isRightSkew) {
++ val isCelebornShuffle = CelebornShuffleUtil.isCelebornShuffle(right.shuffle)
+ val skewSpecs = ShufflePartitionsUtil.createSkewPartitionSpecs(
+- right.mapStats.get.shuffleId, partitionIndex, rightTargetSize)
++ right.mapStats.get.shuffleId, partitionIndex, rightTargetSize,
++ isCelebornShuffle = isCelebornShuffle)
+ if (skewSpecs.isDefined) {
+ logDebug(s"Right side partition $partitionIndex " +
+ s"(${FileUtils.byteCountToDisplaySize(rightSize)}) is skewed, " +
+diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala
+index af689db3379..39d0b3132ee 100644
+--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala
++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala
+@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.adaptive
+ import scala.collection.mutable.ArrayBuffer
+
+ import org.apache.spark.{MapOutputStatistics, MapOutputTrackerMaster, SparkEnv}
++import org.apache.spark.celeborn.CelebornShuffleState
+ import org.apache.spark.internal.Logging
+ import org.apache.spark.sql.execution.{CoalescedPartitionSpec, PartialReducerPartitionSpec, ShufflePartitionSpec}
+
+@@ -380,13 +381,21 @@ object ShufflePartitionsUtil extends Logging {
+ shuffleId: Int,
+ reducerId: Int,
+ targetSize: Long,
+- smallPartitionFactor: Double = SMALL_PARTITION_FACTOR)
+- : Option[Seq[PartialReducerPartitionSpec]] = {
++ smallPartitionFactor: Double = SMALL_PARTITION_FACTOR,
++ isCelebornShuffle: Boolean = false): Option[Seq[PartialReducerPartitionSpec]] = {
+ val mapPartitionSizes = getMapSizesForReduceId(shuffleId, reducerId)
+ if (mapPartitionSizes.exists(_ < 0)) return None
+ val mapStartIndices = splitSizeListByTargetSize(
+ mapPartitionSizes, targetSize, smallPartitionFactor)
+ if (mapStartIndices.length > 1) {
++ val celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled =
++ CelebornShuffleState.celebornAdaptiveOptimizeSkewedPartitionReadEnabled && isCelebornShuffle
++
++ val throwsFetchFailure = CelebornShuffleState.celebornStageRerunEnabled
++ if (throwsFetchFailure && celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) {
++ logInfo(s"Celeborn shuffle retry enabled and shuffle $shuffleId is skewed")
++ CelebornShuffleState.registerCelebornSkewedShuffle(shuffleId)
++ }
+ Some(mapStartIndices.indices.map { i =>
+ val startMapIndex = mapStartIndices(i)
+ val endMapIndex = if (i == mapStartIndices.length - 1) {
+@@ -400,7 +409,15 @@ object ShufflePartitionsUtil extends Logging {
+ dataSize += mapPartitionSizes(mapIndex)
+ mapIndex += 1
+ }
+- PartialReducerPartitionSpec(reducerId, startMapIndex, endMapIndex, dataSize)
++
++ if (celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) {
++          // Here the spec carries (number of splits, split index) instead of a map index range.
++          // `dataSize` is still summed over the map range, so it may be inaccurate in this mode.
++          // Please do not use these dataSize values in any other part of the codebase.
++ PartialReducerPartitionSpec(reducerId, mapStartIndices.length, i, dataSize)
++ } else {
++ PartialReducerPartitionSpec(reducerId, startMapIndex, endMapIndex, dataSize)
++ }
+ })
+ } else {
+ None
diff --git a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_4.patch b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_4.patch
new file mode 100644
index 00000000000..9aed835fe96
--- /dev/null
+++ b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_4.patch
@@ -0,0 +1,312 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+index fade0b86dd8..ca0940a9251 100644
+--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
++++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+@@ -34,6 +34,7 @@ import org.apache.commons.io.output.{ByteArrayOutputStream => ApacheByteArrayOut
+ import org.roaringbitmap.RoaringBitmap
+
+ import org.apache.spark.broadcast.{Broadcast, BroadcastManager}
++import org.apache.spark.celeborn.CelebornShuffleState
+ import org.apache.spark.internal.Logging
+ import org.apache.spark.internal.config._
+ import org.apache.spark.io.CompressionCodec
+@@ -887,6 +888,7 @@ private[spark] class MapOutputTrackerMaster(
+ shuffleStatus.invalidateSerializedMergeOutputStatusCache()
+ }
+ }
++ CelebornShuffleState.unregisterCelebornSkewedShuffle(shuffleId)
+ }
+
+ /**
+diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
+index acab9a634fc..23eb72c49ac 100644
+--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
++++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
+@@ -32,6 +32,7 @@ import org.apache.hadoop.conf.Configuration
+ import org.apache.spark.annotation.DeveloperApi
+ import org.apache.spark.api.python.PythonWorkerFactory
+ import org.apache.spark.broadcast.BroadcastManager
++import org.apache.spark.celeborn.CelebornShuffleState
+ import org.apache.spark.executor.ExecutorBackend
+ import org.apache.spark.internal.{config, Logging}
+ import org.apache.spark.internal.config._
+@@ -419,6 +420,7 @@ object SparkEnv extends Logging {
+ if (isDriver) {
+ val sparkFilesDir = Utils.createTempDir(Utils.getLocalDir(conf), "userFiles").getAbsolutePath
+ envInstance.driverTmpDir = Some(sparkFilesDir)
++ CelebornShuffleState.init(envInstance)
+ }
+
+ envInstance
+diff --git a/core/src/main/scala/org/apache/spark/celeborn/CelebornShuffleState.scala b/core/src/main/scala/org/apache/spark/celeborn/CelebornShuffleState.scala
+new file mode 100644
+index 00000000000..5e190c512df
+--- /dev/null
++++ b/core/src/main/scala/org/apache/spark/celeborn/CelebornShuffleState.scala
+@@ -0,0 +1,75 @@
++/*
++ * Licensed to the Apache Software Foundation (ASF) under one or more
++ * contributor license agreements. See the NOTICE file distributed with
++ * this work for additional information regarding copyright ownership.
++ * The ASF licenses this file to You under the Apache License, Version 2.0
++ * (the "License"); you may not use this file except in compliance with
++ * the License. You may obtain a copy of the License at
++ *
++ * http://www.apache.org/licenses/LICENSE-2.0
++ *
++ * Unless required by applicable law or agreed to in writing, software
++ * distributed under the License is distributed on an "AS IS" BASIS,
++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
++ * See the License for the specific language governing permissions and
++ * limitations under the License.
++ */
++
++package org.apache.spark.celeborn
++
++import java.util.concurrent.ConcurrentHashMap
++import java.util.concurrent.atomic.AtomicBoolean
++
++import org.apache.spark.SparkEnv
++import org.apache.spark.internal.config.ConfigBuilder
++
++/** Celeborn flags loaded by `init` plus the set of shuffle ids registered as skewed. */
++object CelebornShuffleState {
++  // Switch for the optimized skewed partition read; defaults to false.
++  private val CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ =
++    ConfigBuilder("spark.celeborn.client.adaptive.optimizeSkewedPartitionRead.enabled")
++      .booleanConf.createWithDefault(false)
++
++  // Stage rerun switch; fetch.throwsFetchFailure is declared as its alternative key.
++  private val CELEBORN_STAGE_RERUN_ENABLED =
++    ConfigBuilder("spark.celeborn.client.spark.stageRerun.enabled")
++      .withAlternative("spark.celeborn.client.spark.fetch.throwsFetchFailure")
++      .booleanConf.createWithDefault(false)
++
++  private val celebornOptimizeSkewedPartitionReadEnabled = new AtomicBoolean()
++  private val stageRerunEnabled = new AtomicBoolean()
++  private val skewShuffleIds = ConcurrentHashMap.newKeySet[Int]()
++
++  /** Resets state and reloads the flags from env.conf. Call this from SparkEnv.create. */
++  def init(env: SparkEnv): Unit = {
++    // drop shuffle ids left over from any earlier initialization
++    skewShuffleIds.clear()
++    // every setting comes from env.conf, never from SQLConf
++    val sparkConf = env.conf
++    val usesCeleborn = sparkConf.get("spark.shuffle.manager", "sort").contains("celeborn")
++    celebornOptimizeSkewedPartitionReadEnabled.set(
++      usesCeleborn && sparkConf.get(CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ))
++    stageRerunEnabled.set(sparkConf.get(CELEBORN_STAGE_RERUN_ENABLED))
++  }
++
++  /** Removes `shuffleId` from the set of skewed shuffles. */
++  def unregisterCelebornSkewedShuffle(shuffleId: Int): Unit =
++    skewShuffleIds.remove(shuffleId)
++
++  /** Adds `shuffleId` to the set of skewed shuffles. */
++  def registerCelebornSkewedShuffle(shuffleId: Int): Unit =
++    skewShuffleIds.add(shuffleId)
++
++  /** True while `shuffleId` is present in the set of skewed shuffles. */
++  def isCelebornSkewedShuffle(shuffleId: Int): Boolean =
++    skewShuffleIds.contains(shuffleId)
++
++  /** Whether `init` found a celeborn shuffle manager with the optimized skew read enabled. */
++  def celebornAdaptiveOptimizeSkewedPartitionReadEnabled: Boolean =
++    celebornOptimizeSkewedPartitionReadEnabled.get()
++
++  /** Whether `init` read the stage rerun setting as enabled. */
++  def celebornStageRerunEnabled: Boolean =
++    stageRerunEnabled.get()
++
++}
+diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+index 26be8c72bbc..81feaba962c 100644
+--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
++++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+@@ -34,6 +34,7 @@ import com.google.common.util.concurrent.{Futures, SettableFuture}
+
+ import org.apache.spark._
+ import org.apache.spark.broadcast.Broadcast
++import org.apache.spark.celeborn.CelebornShuffleState
+ import org.apache.spark.errors.SparkCoreErrors
+ import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics}
+ import org.apache.spark.internal.Logging
+@@ -1897,7 +1898,7 @@ private[spark] class DAGScheduler(
+
+ val shouldAbortStage =
+ failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts ||
+- disallowStageRetryForTest
++ disallowStageRetryForTest || CelebornShuffleState.isCelebornSkewedShuffle(shuffleId)
+
+ // It is likely that we receive multiple FetchFailed for a single stage (because we have
+ // multiple tasks running concurrently on different executors). In that case, it is
+diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala
+new file mode 100644
+index 00000000000..3dc60678461
+--- /dev/null
++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala
+@@ -0,0 +1,35 @@
++/*
++ * Licensed to the Apache Software Foundation (ASF) under one or more
++ * contributor license agreements. See the NOTICE file distributed with
++ * this work for additional information regarding copyright ownership.
++ * The ASF licenses this file to You under the Apache License, Version 2.0
++ * (the "License"); you may not use this file except in compliance with
++ * the License. You may obtain a copy of the License at
++ *
++ * http://www.apache.org/licenses/LICENSE-2.0
++ *
++ * Unless required by applicable law or agreed to in writing, software
++ * distributed under the License is distributed on an "AS IS" BASIS,
++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
++ * See the License for the specific language governing permissions and
++ * limitations under the License.
++ */
++
++package org.apache.spark.sql.execution.adaptive
++
++import java.util.Locale
++
++import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike}
++
++/** Helper for recognizing exchanges whose shuffle handle class belongs to Celeborn. */
++object CelebornShuffleUtil {
++
++  /** True only for a ShuffleExchangeExec whose handle class name contains "celeborn". */
++  def isCelebornShuffle(shuffleExchangeLike: ShuffleExchangeLike): Boolean =
++    shuffleExchangeLike match {
++      case exchange: ShuffleExchangeExec =>
++        val handleClassName = exchange.shuffleDependency.shuffleHandle.getClass.getName
++        handleClassName.toLowerCase(Locale.ROOT).contains("celeborn")
++      case _ => false
++    }
++}
+diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala
+index b34ab3e380b..cb0ed9d05a4 100644
+--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala
++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala
+@@ -47,14 +47,15 @@ object OptimizeSkewInRebalancePartitions extends AQEShuffleReadRule {
+ private def optimizeSkewedPartitions(
+ shuffleId: Int,
+ bytesByPartitionId: Array[Long],
+- targetSize: Long): Seq[ShufflePartitionSpec] = {
++ targetSize: Long,
++ isCelebornShuffle: Boolean = false): Seq[ShufflePartitionSpec] = {
+ val smallPartitionFactor =
+ conf.getConf(SQLConf.ADAPTIVE_REBALANCE_PARTITIONS_SMALL_PARTITION_FACTOR)
+ bytesByPartitionId.indices.flatMap { reduceIndex =>
+ val bytes = bytesByPartitionId(reduceIndex)
+ if (bytes > targetSize) {
+ val newPartitionSpec = ShufflePartitionsUtil.createSkewPartitionSpecs(
+- shuffleId, reduceIndex, targetSize, smallPartitionFactor)
++ shuffleId, reduceIndex, targetSize, smallPartitionFactor, isCelebornShuffle)
+ if (newPartitionSpec.isEmpty) {
+ CoalescedPartitionSpec(reduceIndex, reduceIndex + 1, bytes) :: Nil
+ } else {
+@@ -76,8 +77,9 @@ object OptimizeSkewInRebalancePartitions extends AQEShuffleReadRule {
+ return shuffle
+ }
+
++ val isCelebornShuffle = CelebornShuffleUtil.isCelebornShuffle(shuffle.shuffle)
+ val newPartitionsSpec = optimizeSkewedPartitions(
+- mapStats.get.shuffleId, mapStats.get.bytesByPartitionId, advisorySize)
++ mapStats.get.shuffleId, mapStats.get.bytesByPartitionId, advisorySize, isCelebornShuffle)
+ // return origin plan if we can not optimize partitions
+ if (newPartitionsSpec.length == mapStats.get.bytesByPartitionId.length) {
+ shuffle
+diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
+index 37cdea084d8..4694a06919e 100644
+--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
+@@ -152,8 +152,10 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements)
+ Seq(CoalescedPartitionSpec(partitionIndex, partitionIndex + 1, rightSize))
+
+ val leftParts = if (isLeftSkew) {
++ val isCelebornShuffle = CelebornShuffleUtil.isCelebornShuffle(left.shuffle)
+ val skewSpecs = ShufflePartitionsUtil.createSkewPartitionSpecs(
+- left.mapStats.get.shuffleId, partitionIndex, leftTargetSize)
++ left.mapStats.get.shuffleId, partitionIndex, leftTargetSize,
++ isCelebornShuffle = isCelebornShuffle)
+ if (skewSpecs.isDefined) {
+ logDebug(s"Left side partition $partitionIndex " +
+ s"(${FileUtils.byteCountToDisplaySize(leftSize)}) is skewed, " +
+@@ -166,8 +168,10 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements)
+ }
+
+ val rightParts = if (isRightSkew) {
++ val isCelebornShuffle = CelebornShuffleUtil.isCelebornShuffle(right.shuffle)
+ val skewSpecs = ShufflePartitionsUtil.createSkewPartitionSpecs(
+- right.mapStats.get.shuffleId, partitionIndex, rightTargetSize)
++ right.mapStats.get.shuffleId, partitionIndex, rightTargetSize,
++ isCelebornShuffle = isCelebornShuffle)
+ if (skewSpecs.isDefined) {
+ logDebug(s"Right side partition $partitionIndex " +
+ s"(${FileUtils.byteCountToDisplaySize(rightSize)}) is skewed, " +
+diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala
+index dbed66683b0..d656c8af6b7 100644
+--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala
++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala
+@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.adaptive
+ import scala.collection.mutable.ArrayBuffer
+
+ import org.apache.spark.{MapOutputStatistics, MapOutputTrackerMaster, SparkEnv}
++import org.apache.spark.celeborn.CelebornShuffleState
+ import org.apache.spark.internal.Logging
+ import org.apache.spark.sql.execution.{CoalescedPartitionSpec, PartialReducerPartitionSpec, ShufflePartitionSpec}
+
+@@ -380,13 +381,21 @@ object ShufflePartitionsUtil extends Logging {
+ shuffleId: Int,
+ reducerId: Int,
+ targetSize: Long,
+- smallPartitionFactor: Double = SMALL_PARTITION_FACTOR)
+- : Option[Seq[PartialReducerPartitionSpec]] = {
++ smallPartitionFactor: Double = SMALL_PARTITION_FACTOR,
++ isCelebornShuffle: Boolean = false): Option[Seq[PartialReducerPartitionSpec]] = {
+ val mapPartitionSizes = getMapSizesForReduceId(shuffleId, reducerId)
+ if (mapPartitionSizes.exists(_ < 0)) return None
+ val mapStartIndices = splitSizeListByTargetSize(
+ mapPartitionSizes, targetSize, smallPartitionFactor)
+ if (mapStartIndices.length > 1) {
++ val celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled =
++ CelebornShuffleState.celebornAdaptiveOptimizeSkewedPartitionReadEnabled && isCelebornShuffle
++
++ val throwsFetchFailure = CelebornShuffleState.celebornStageRerunEnabled
++ if (throwsFetchFailure && celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) {
++ logInfo(s"Celeborn shuffle retry enabled and shuffle $shuffleId is skewed")
++ CelebornShuffleState.registerCelebornSkewedShuffle(shuffleId)
++ }
+ Some(mapStartIndices.indices.map { i =>
+ val startMapIndex = mapStartIndices(i)
+ val endMapIndex = if (i == mapStartIndices.length - 1) {
+@@ -400,7 +409,15 @@ object ShufflePartitionsUtil extends Logging {
+ dataSize += mapPartitionSizes(mapIndex)
+ mapIndex += 1
+ }
+- PartialReducerPartitionSpec(reducerId, startMapIndex, endMapIndex, dataSize)
++
++ if (celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) {
++          // With Celeborn optimized skewed partition read enabled, the spec carries the split
++          // count and this split's index in place of a map-index range, so `dataSize` (a sum of
++          // map partition sizes) may not be accurate. Do not use it elsewhere in the codebase.
++ PartialReducerPartitionSpec(reducerId, mapStartIndices.length, i, dataSize)
++ } else {
++ PartialReducerPartitionSpec(reducerId, startMapIndex, endMapIndex, dataSize)
++ }
+ })
+ } else {
+ None
diff --git a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_5.patch b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_5.patch
new file mode 100644
index 00000000000..553bdeae668
--- /dev/null
+++ b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_5.patch
@@ -0,0 +1,312 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+index 9a7a3b0c0e7..543423dadd9 100644
+--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
++++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+@@ -34,6 +34,7 @@ import org.apache.commons.io.output.{ByteArrayOutputStream => ApacheByteArrayOut
+ import org.roaringbitmap.RoaringBitmap
+
+ import org.apache.spark.broadcast.{Broadcast, BroadcastManager}
++import org.apache.spark.celeborn.CelebornShuffleState
+ import org.apache.spark.internal.Logging
+ import org.apache.spark.internal.config._
+ import org.apache.spark.io.CompressionCodec
+@@ -916,6 +917,7 @@ private[spark] class MapOutputTrackerMaster(
+ shuffleStatus.invalidateSerializedMergeOutputStatusCache()
+ }
+ }
++ CelebornShuffleState.unregisterCelebornSkewedShuffle(shuffleId)
+ }
+
+ /**
+diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
+index edad91a0c6f..76b377729a0 100644
+--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
++++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
+@@ -32,6 +32,7 @@ import org.apache.hadoop.conf.Configuration
+ import org.apache.spark.annotation.DeveloperApi
+ import org.apache.spark.api.python.PythonWorkerFactory
+ import org.apache.spark.broadcast.BroadcastManager
++import org.apache.spark.celeborn.CelebornShuffleState
+ import org.apache.spark.executor.ExecutorBackend
+ import org.apache.spark.internal.{config, Logging}
+ import org.apache.spark.internal.config._
+@@ -419,6 +420,7 @@ object SparkEnv extends Logging {
+ if (isDriver) {
+ val sparkFilesDir = Utils.createTempDir(Utils.getLocalDir(conf), "userFiles").getAbsolutePath
+ envInstance.driverTmpDir = Some(sparkFilesDir)
++ CelebornShuffleState.init(envInstance)
+ }
+
+ envInstance
+diff --git a/core/src/main/scala/org/apache/spark/celeborn/CelebornShuffleState.scala b/core/src/main/scala/org/apache/spark/celeborn/CelebornShuffleState.scala
+new file mode 100644
+index 00000000000..5e190c512df
+--- /dev/null
++++ b/core/src/main/scala/org/apache/spark/celeborn/CelebornShuffleState.scala
+@@ -0,0 +1,75 @@
++/*
++ * Licensed to the Apache Software Foundation (ASF) under one or more
++ * contributor license agreements. See the NOTICE file distributed with
++ * this work for additional information regarding copyright ownership.
++ * The ASF licenses this file to You under the Apache License, Version 2.0
++ * (the "License"); you may not use this file except in compliance with
++ * the License. You may obtain a copy of the License at
++ *
++ * http://www.apache.org/licenses/LICENSE-2.0
++ *
++ * Unless required by applicable law or agreed to in writing, software
++ * distributed under the License is distributed on an "AS IS" BASIS,
++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
++ * See the License for the specific language governing permissions and
++ * limitations under the License.
++ */
++
++package org.apache.spark.celeborn
++
++import java.util.concurrent.ConcurrentHashMap
++import java.util.concurrent.atomic.AtomicBoolean
++
++import org.apache.spark.SparkEnv
++import org.apache.spark.internal.config.ConfigBuilder
++
++/** Celeborn flags loaded by `init` plus the set of shuffle ids registered as skewed. */
++object CelebornShuffleState {
++  // Switch for the optimized skewed partition read; defaults to false.
++  private val CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ =
++    ConfigBuilder("spark.celeborn.client.adaptive.optimizeSkewedPartitionRead.enabled")
++      .booleanConf.createWithDefault(false)
++
++  // Stage rerun switch; fetch.throwsFetchFailure is declared as its alternative key.
++  private val CELEBORN_STAGE_RERUN_ENABLED =
++    ConfigBuilder("spark.celeborn.client.spark.stageRerun.enabled")
++      .withAlternative("spark.celeborn.client.spark.fetch.throwsFetchFailure")
++      .booleanConf.createWithDefault(false)
++
++  private val celebornOptimizeSkewedPartitionReadEnabled = new AtomicBoolean()
++  private val stageRerunEnabled = new AtomicBoolean()
++  private val skewShuffleIds = ConcurrentHashMap.newKeySet[Int]()
++
++  /** Resets state and reloads the flags from env.conf. Call this from SparkEnv.create. */
++  def init(env: SparkEnv): Unit = {
++    // drop shuffle ids left over from any earlier initialization
++    skewShuffleIds.clear()
++    // every setting comes from env.conf, never from SQLConf
++    val sparkConf = env.conf
++    val usesCeleborn = sparkConf.get("spark.shuffle.manager", "sort").contains("celeborn")
++    celebornOptimizeSkewedPartitionReadEnabled.set(
++      usesCeleborn && sparkConf.get(CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ))
++    stageRerunEnabled.set(sparkConf.get(CELEBORN_STAGE_RERUN_ENABLED))
++  }
++
++  /** Removes `shuffleId` from the set of skewed shuffles. */
++  def unregisterCelebornSkewedShuffle(shuffleId: Int): Unit =
++    skewShuffleIds.remove(shuffleId)
++
++  /** Adds `shuffleId` to the set of skewed shuffles. */
++  def registerCelebornSkewedShuffle(shuffleId: Int): Unit =
++    skewShuffleIds.add(shuffleId)
++
++  /** True while `shuffleId` is present in the set of skewed shuffles. */
++  def isCelebornSkewedShuffle(shuffleId: Int): Boolean =
++    skewShuffleIds.contains(shuffleId)
++
++  /** Whether `init` found a celeborn shuffle manager with the optimized skew read enabled. */
++  def celebornAdaptiveOptimizeSkewedPartitionReadEnabled: Boolean =
++    celebornOptimizeSkewedPartitionReadEnabled.get()
++
++  /** Whether `init` read the stage rerun setting as enabled. */
++  def celebornStageRerunEnabled: Boolean =
++    stageRerunEnabled.get()
++
++}
+diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+index 89d16e57934..3b9094f3254 100644
+--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
++++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+@@ -34,6 +34,7 @@ import com.google.common.util.concurrent.{Futures, SettableFuture}
+
+ import org.apache.spark._
+ import org.apache.spark.broadcast.Broadcast
++import org.apache.spark.celeborn.CelebornShuffleState
+ import org.apache.spark.errors.SparkCoreErrors
+ import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics}
+ import org.apache.spark.internal.Logging
+@@ -1962,7 +1963,7 @@ private[spark] class DAGScheduler(
+
+ val shouldAbortStage =
+ failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts ||
+- disallowStageRetryForTest
++ disallowStageRetryForTest || CelebornShuffleState.isCelebornSkewedShuffle(shuffleId)
+
+ // It is likely that we receive multiple FetchFailed for a single stage (because we have
+ // multiple tasks running concurrently on different executors). In that case, it is
+diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala
+new file mode 100644
+index 00000000000..3dc60678461
+--- /dev/null
++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala
+@@ -0,0 +1,35 @@
++/*
++ * Licensed to the Apache Software Foundation (ASF) under one or more
++ * contributor license agreements. See the NOTICE file distributed with
++ * this work for additional information regarding copyright ownership.
++ * The ASF licenses this file to You under the Apache License, Version 2.0
++ * (the "License"); you may not use this file except in compliance with
++ * the License. You may obtain a copy of the License at
++ *
++ * http://www.apache.org/licenses/LICENSE-2.0
++ *
++ * Unless required by applicable law or agreed to in writing, software
++ * distributed under the License is distributed on an "AS IS" BASIS,
++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
++ * See the License for the specific language governing permissions and
++ * limitations under the License.
++ */
++
++package org.apache.spark.sql.execution.adaptive
++
++import java.util.Locale
++
++import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike}
++
++/** Helper for recognizing exchanges whose shuffle handle class belongs to Celeborn. */
++object CelebornShuffleUtil {
++
++  /** True only for a ShuffleExchangeExec whose handle class name contains "celeborn". */
++  def isCelebornShuffle(shuffleExchangeLike: ShuffleExchangeLike): Boolean =
++    shuffleExchangeLike match {
++      case exchange: ShuffleExchangeExec =>
++        val handleClassName = exchange.shuffleDependency.shuffleHandle.getClass.getName
++        handleClassName.toLowerCase(Locale.ROOT).contains("celeborn")
++      case _ => false
++    }
++}
+diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala
+index abd096b9c7c..ff0363f87d8 100644
+--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala
++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala
+@@ -47,14 +47,15 @@ object OptimizeSkewInRebalancePartitions extends AQEShuffleReadRule {
+ private def optimizeSkewedPartitions(
+ shuffleId: Int,
+ bytesByPartitionId: Array[Long],
+- targetSize: Long): Seq[ShufflePartitionSpec] = {
++ targetSize: Long,
++ isCelebornShuffle: Boolean = false): Seq[ShufflePartitionSpec] = {
+ val smallPartitionFactor =
+ conf.getConf(SQLConf.ADAPTIVE_REBALANCE_PARTITIONS_SMALL_PARTITION_FACTOR)
+ bytesByPartitionId.indices.flatMap { reduceIndex =>
+ val bytes = bytesByPartitionId(reduceIndex)
+ if (bytes > targetSize) {
+ val newPartitionSpec = ShufflePartitionsUtil.createSkewPartitionSpecs(
+- shuffleId, reduceIndex, targetSize, smallPartitionFactor)
++ shuffleId, reduceIndex, targetSize, smallPartitionFactor, isCelebornShuffle)
+ if (newPartitionSpec.isEmpty) {
+ CoalescedPartitionSpec(reduceIndex, reduceIndex + 1, bytes) :: Nil
+ } else {
+@@ -77,8 +78,9 @@ object OptimizeSkewInRebalancePartitions extends AQEShuffleReadRule {
+ return shuffle
+ }
+
++ val isCelebornShuffle = CelebornShuffleUtil.isCelebornShuffle(shuffle.shuffle)
+ val newPartitionsSpec = optimizeSkewedPartitions(
+- mapStats.get.shuffleId, mapStats.get.bytesByPartitionId, advisorySize)
++ mapStats.get.shuffleId, mapStats.get.bytesByPartitionId, advisorySize, isCelebornShuffle)
+ // return origin plan if we can not optimize partitions
+ if (newPartitionsSpec.length == mapStats.get.bytesByPartitionId.length) {
+ shuffle
+diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
+index 37cdea084d8..4694a06919e 100644
+--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
+@@ -152,8 +152,10 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements)
+ Seq(CoalescedPartitionSpec(partitionIndex, partitionIndex + 1, rightSize))
+
+ val leftParts = if (isLeftSkew) {
++ val isCelebornShuffle = CelebornShuffleUtil.isCelebornShuffle(left.shuffle)
+ val skewSpecs = ShufflePartitionsUtil.createSkewPartitionSpecs(
+- left.mapStats.get.shuffleId, partitionIndex, leftTargetSize)
++ left.mapStats.get.shuffleId, partitionIndex, leftTargetSize,
++ isCelebornShuffle = isCelebornShuffle)
+ if (skewSpecs.isDefined) {
+ logDebug(s"Left side partition $partitionIndex " +
+ s"(${FileUtils.byteCountToDisplaySize(leftSize)}) is skewed, " +
+@@ -166,8 +168,10 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements)
+ }
+
+ val rightParts = if (isRightSkew) {
++ val isCelebornShuffle = CelebornShuffleUtil.isCelebornShuffle(right.shuffle)
+ val skewSpecs = ShufflePartitionsUtil.createSkewPartitionSpecs(
+- right.mapStats.get.shuffleId, partitionIndex, rightTargetSize)
++ right.mapStats.get.shuffleId, partitionIndex, rightTargetSize,
++ isCelebornShuffle = isCelebornShuffle)
+ if (skewSpecs.isDefined) {
+ logDebug(s"Right side partition $partitionIndex " +
+ s"(${FileUtils.byteCountToDisplaySize(rightSize)}) is skewed, " +
+diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala
+index 9370b3d8d1d..d36e26a1376 100644
+--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala
++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala
+@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.adaptive
+ import scala.collection.mutable.ArrayBuffer
+
+ import org.apache.spark.{MapOutputStatistics, MapOutputTrackerMaster, SparkEnv}
++import org.apache.spark.celeborn.CelebornShuffleState
+ import org.apache.spark.internal.Logging
+ import org.apache.spark.sql.execution.{CoalescedPartitionSpec, PartialReducerPartitionSpec, ShufflePartitionSpec}
+
+@@ -382,13 +383,21 @@ object ShufflePartitionsUtil extends Logging {
+ shuffleId: Int,
+ reducerId: Int,
+ targetSize: Long,
+- smallPartitionFactor: Double = SMALL_PARTITION_FACTOR)
+- : Option[Seq[PartialReducerPartitionSpec]] = {
++ smallPartitionFactor: Double = SMALL_PARTITION_FACTOR,
++ isCelebornShuffle: Boolean = false): Option[Seq[PartialReducerPartitionSpec]] = {
+ val mapPartitionSizes = getMapSizesForReduceId(shuffleId, reducerId)
+ if (mapPartitionSizes.exists(_ < 0)) return None
+ val mapStartIndices = splitSizeListByTargetSize(
+ mapPartitionSizes, targetSize, smallPartitionFactor)
+ if (mapStartIndices.length > 1) {
++ val celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled =
++ CelebornShuffleState.celebornAdaptiveOptimizeSkewedPartitionReadEnabled && isCelebornShuffle
++
++ val throwsFetchFailure = CelebornShuffleState.celebornStageRerunEnabled
++ if (throwsFetchFailure && celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) {
++ logInfo(s"Celeborn shuffle retry enabled and shuffle $shuffleId is skewed")
++ CelebornShuffleState.registerCelebornSkewedShuffle(shuffleId)
++ }
+ Some(mapStartIndices.indices.map { i =>
+ val startMapIndex = mapStartIndices(i)
+ val endMapIndex = if (i == mapStartIndices.length - 1) {
+@@ -402,7 +411,15 @@ object ShufflePartitionsUtil extends Logging {
+ dataSize += mapPartitionSizes(mapIndex)
+ mapIndex += 1
+ }
+- PartialReducerPartitionSpec(reducerId, startMapIndex, endMapIndex, dataSize)
++
++ if (celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) {
++          // With Celeborn optimized skewed partition read enabled, the spec carries the split
++          // count and this split's index in place of a map-index range, so `dataSize` (a sum of
++          // map partition sizes) may not be accurate. Do not use it elsewhere in the codebase.
++ PartialReducerPartitionSpec(reducerId, mapStartIndices.length, i, dataSize)
++ } else {
++ PartialReducerPartitionSpec(reducerId, startMapIndex, endMapIndex, dataSize)
++ }
+ })
+ } else {
+ None
diff --git a/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/CelebornPartitionUtil.java b/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/CelebornPartitionUtil.java
new file mode 100644
index 00000000000..6bf47addca9
--- /dev/null
+++ b/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/CelebornPartitionUtil.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.celeborn;
+
+import java.util.*;
+
+import org.apache.commons.lang3.tuple.Pair;
+
+import org.apache.celeborn.common.protocol.PartitionLocation;
+
+public class CelebornPartitionUtil {
+  /**
+   * The general idea is to divide each skew partition into smaller partitions:
+   *
+   * <p>- Spark driver will calculate the number of sub-partitions: {@code subPartitionSize =
+   * skewPartitionTotalSize / subPartitionTargetSize}
+   *
+   * <p>- In Celeborn, we divide the skew partition into {@code subPartitionSize} small partitions
+   * by PartitionLocation chunk offsets. This allows them to run in parallel Spark tasks.
+   *
+   * <p>For example, one skewed partition has 2 PartitionLocation:
+   *
+   * <ul>
+   *   <li>PartitionLocation 0 with chunk offset [0L, 100L, 200L, 300L, 500L, 1000L]
+   *   <li>PartitionLocation 1 with chunk offset [0L, 200L, 500L, 800L, 900L, 1000L]
+   * </ul>
+   *
+   * If we want to divide it into 3 sub-partitions (each sub-partition target size is 2000/3), the
+   * result will be:
+   *
+   * <ul>
+   *   <li>sub-partition 0: uniqueId0 -> (0, 3)
+   *   <li>sub-partition 1: uniqueId0 -> (4, 4), uniqueId1 -> (0, 0)
+   *   <li>sub-partition 2: uniqueId1 -> (1, 4)
+   * </ul>
+   *
+   * Note: (0, 3) means chunks with chunkIndex 0-1-2-3, four chunks.
+   *
+   * @param locations PartitionLocation information belonging to the reduce partition
+   * @param subPartitionSize the number of sub-partitions separated from the reduce partition
+   * @param subPartitionIndex current sub-partition index
+   * @return a map of partitionUniqueId to chunkRange pairs for one subtask of skew partitions
+   */
+  public static Map<String, Pair<Integer, Integer>> splitSkewedPartitionLocations(
+      ArrayList<PartitionLocation> locations, int subPartitionSize, int subPartitionIndex) {
+    locations.sort(Comparator.comparing((PartitionLocation p) -> p.getUniqueId()));
+    long totalPartitionSize =
+        locations.stream().mapToLong((PartitionLocation p) -> p.getStorageInfo().fileSize).sum();
+    long step = totalPartitionSize / subPartitionSize;
+    long startOffset = step * subPartitionIndex;
+    long endOffset =
+        subPartitionIndex < subPartitionSize - 1
+            ? step * (subPartitionIndex + 1)
+            : totalPartitionSize + 1; // last subPartition should include all remaining data
+
+    long partitionLocationOffset = 0;
+    Map<String, Pair<Integer, Integer>> chunkRange = new HashMap<>();
+    for (PartitionLocation p : locations) {
+      int left = -1;
+      int right = -1;
+      Iterator<Long> chunkOffsets = p.getStorageInfo().getChunkOffsets().iterator();
+      // Start from index 1 since the first chunk offset is always 0.
+      chunkOffsets.next();
+      int j = 1;
+      while (chunkOffsets.hasNext()) {
+        long currentOffset = partitionLocationOffset + chunkOffsets.next();
+        if (currentOffset > startOffset && left < 0) {
+          left = j - 1;
+        }
+        if (currentOffset <= endOffset) {
+          right = j - 1;
+        }
+        if (left >= 0 && right >= 0) {
+          chunkRange.put(p.getUniqueId(), Pair.of(left, right));
+        }
+        j++;
+      }
+      partitionLocationOffset += p.getStorageInfo().getFileSize();
+    }
+    return chunkRange;
+  }
+}
diff --git a/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
index ab51e28d388..b35e549ae38 100644
--- a/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
+++ b/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
@@ -18,14 +18,14 @@
package org.apache.spark.shuffle.celeborn
import java.io.IOException
-import java.nio.file.Files
-import java.util
+import java.util.{ArrayList => JArrayList, HashMap => JHashMap, Map => JMap, Set => JSet}
import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor, TimeoutException, TimeUnit}
import java.util.concurrent.atomic.AtomicReference
import scala.collection.JavaConverters._
import com.google.common.annotations.VisibleForTesting
+import org.apache.commons.lang3.tuple.Pair
import org.apache.spark.{Aggregator, InterruptibleIterator, ShuffleDependency, TaskContext}
import org.apache.spark.celeborn.ExceptionMakerHelper
import org.apache.spark.internal.Logging
@@ -35,7 +35,7 @@ import org.apache.spark.shuffle.celeborn.CelebornShuffleReader.streamCreatorPool
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.ExternalSorter
-import org.apache.celeborn.client.{DummyShuffleClient, ShuffleClient}
+import org.apache.celeborn.client.{ClientUtils, ShuffleClient}
import org.apache.celeborn.client.ShuffleClientImpl.ReduceFileGroups
import org.apache.celeborn.client.read.{CelebornInputStream, MetricsCallback}
import org.apache.celeborn.common.CelebornConf
@@ -122,15 +122,41 @@ class CelebornShuffleReader[K, C](
}
// host-port -> (TransportClient, PartitionLocation Array, PbOpenStreamList)
- val workerRequestMap = new util.HashMap[
+ val workerRequestMap = new JHashMap[
String,
- (TransportClient, util.ArrayList[PartitionLocation], PbOpenStreamList.Builder)]()
+ (TransportClient, JArrayList[PartitionLocation], PbOpenStreamList.Builder)]()
+ // partitionId -> (partition uniqueId -> chunkRange pair)
+ val partitionId2ChunkRange = new JHashMap[Int, JMap[String, Pair[Integer, Integer]]]()
+
+ val partitionId2PartitionLocations = new JHashMap[Int, JSet[PartitionLocation]]()
var partCnt = 0
+ // If startMapIndex > endMapIndex, this is a skewed partition split by Celeborn itself:
+ // startMapIndex is the number of sub-partitions and endMapIndex is the sub-partition index.
+ val splitSkewPartitionWithoutMapRange =
+ ClientUtils.readSkewPartitionWithoutMapRange(conf, startMapIndex, endMapIndex)
+
(startPartition until endPartition).foreach { partitionId =>
if (fileGroups.partitionGroups.containsKey(partitionId)) {
- fileGroups.partitionGroups.get(partitionId).asScala.foreach { location =>
+ var locations = fileGroups.partitionGroups.get(partitionId)
+ if (splitSkewPartitionWithoutMapRange) {
+ val partitionLocation2ChunkRange = CelebornPartitionUtil.splitSkewedPartitionLocations(
+ new JArrayList(locations),
+ startMapIndex,
+ endMapIndex)
+ partitionId2ChunkRange.put(partitionId, partitionLocation2ChunkRange)
+ // Keep only locations that have a chunk range, to avoid OPEN_STREAM for unread locations.
+ val filterLocations = locations.asScala
+ .filter { location =>
+ null != partitionLocation2ChunkRange &&
+ partitionLocation2ChunkRange.containsKey(location.getUniqueId)
+ }
+ locations = filterLocations.asJava
+ partitionId2PartitionLocations.put(partitionId, locations)
+ }
+
+ locations.asScala.foreach { location =>
partCnt += 1
val hostPort = location.hostAndFetchPort
if (!workerRequestMap.containsKey(hostPort)) {
@@ -142,7 +168,7 @@ class CelebornShuffleReader[K, C](
pbOpenStreamList.setShuffleKey(shuffleKey)
workerRequestMap.put(
hostPort,
- (client, new util.ArrayList[PartitionLocation], pbOpenStreamList))
+ (client, new JArrayList[PartitionLocation], pbOpenStreamList))
} catch {
case ex: Exception =>
shuffleClient.excludeFailedFetchLocation(location.hostAndFetchPort, ex)
@@ -203,13 +229,22 @@ class CelebornShuffleReader[K, C](
def createInputStream(partitionId: Int): Unit = {
val locations =
- if (fileGroups.partitionGroups.containsKey(partitionId)) {
- new util.ArrayList(fileGroups.partitionGroups.get(partitionId))
- } else new util.ArrayList[PartitionLocation]()
+ if (splitSkewPartitionWithoutMapRange) {
+ partitionId2PartitionLocations.get(partitionId)
+ } else {
+ fileGroups.partitionGroups.get(partitionId)
+ }
+
+ val locationList =
+ if (null == locations) {
+ new JArrayList[PartitionLocation]()
+ } else {
+ new JArrayList[PartitionLocation](locations)
+ }
val streamHandlers =
if (locations != null) {
- val streamHandlerArr = new util.ArrayList[PbStreamHandler](locations.size())
- locations.asScala.foreach { loc =>
+ val streamHandlerArr = new JArrayList[PbStreamHandler](locationList.size)
+ locationList.asScala.foreach { loc =>
streamHandlerArr.add(locationStreamHandlerMap.get(loc))
}
streamHandlerArr
@@ -226,8 +261,10 @@ class CelebornShuffleReader[K, C](
endMapIndex,
if (throwsFetchFailure) ExceptionMakerHelper.SHUFFLE_FETCH_FAILURE_EXCEPTION_MAKER
else null,
- locations,
+ locationList,
streamHandlers,
+ fileGroups.pushFailedBatches,
+ partitionId2ChunkRange.get(partitionId),
fileGroups.mapAttempts,
metricsCallback)
streams.put(partitionId, inputStream)
@@ -414,7 +451,6 @@ class CelebornShuffleReader[K, C](
def newSerializerInstance(dep: ShuffleDependency[K, _, C]): SerializerInstance = {
dep.serializer.newInstance()
}
-
}
object CelebornShuffleReader {
diff --git a/client-spark/spark-3-4/src/test/java/org/apache/spark/shuffle/celeborn/CelebornPartitionUtilSuiteJ.java b/client-spark/spark-3-4/src/test/java/org/apache/spark/shuffle/celeborn/CelebornPartitionUtilSuiteJ.java
new file mode 100644
index 00000000000..989dd31a94a
--- /dev/null
+++ b/client-spark/spark-3-4/src/test/java/org/apache/spark/shuffle/celeborn/CelebornPartitionUtilSuiteJ.java
@@ -0,0 +1,166 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.celeborn;
+
+import java.util.*;
+
+import org.apache.commons.lang3.tuple.Pair;
+import org.junit.Assert;
+import org.junit.Test;
+
+import org.apache.celeborn.common.protocol.PartitionLocation;
+import org.apache.celeborn.common.protocol.StorageInfo;
+
+public class CelebornPartitionUtilSuiteJ {
+  @Test
+  public void testSkewPartitionSplit() {
+
+    ArrayList<PartitionLocation> locations = new ArrayList<>();
+    for (int i = 0; i < 13; i++) {
+      locations.add(genPartitionLocation(i, new Long[] {0L, 100L, 200L, 300L, 500L, 1000L}));
+    }
+    locations.add(genPartitionLocation(91, new Long[] {0L, 1L}));
+
+    int subPartitionSize = 3;
+
+    Map<String, Pair<Integer, Integer>> result1 =
+        CelebornPartitionUtil.splitSkewedPartitionLocations(locations, subPartitionSize, 0);
+    Map<String, Pair<Integer, Integer>> expectResult1 =
+        genRanges(
+            new Object[][] {
+              {"0-0", 0, 4},
+              {"0-1", 0, 4},
+              {"0-10", 0, 4},
+              {"0-11", 0, 4},
+              {"0-12", 0, 2}
+            });
+    Assert.assertEquals(expectResult1, result1);
+
+    Map<String, Pair<Integer, Integer>> result2 =
+        CelebornPartitionUtil.splitSkewedPartitionLocations(locations, subPartitionSize, 1);
+    Map<String, Pair<Integer, Integer>> expectResult2 =
+        genRanges(
+            new Object[][] {
+              {"0-12", 3, 4},
+              {"0-2", 0, 4},
+              {"0-3", 0, 4},
+              {"0-4", 0, 4},
+              {"0-5", 0, 3}
+            });
+    Assert.assertEquals(expectResult2, result2);
+
+    Map<String, Pair<Integer, Integer>> result3 =
+        CelebornPartitionUtil.splitSkewedPartitionLocations(locations, subPartitionSize, 2);
+    Map<String, Pair<Integer, Integer>> expectResult3 =
+        genRanges(
+            new Object[][] {
+              {"0-5", 4, 4},
+              {"0-6", 0, 4},
+              {"0-7", 0, 4},
+              {"0-8", 0, 4},
+              {"0-9", 0, 4},
+              {"0-91", 0, 0}
+            });
+    Assert.assertEquals(expectResult3, result3);
+  }
+
+  @Test
+  public void testBoundary() {
+    ArrayList<PartitionLocation> locations = new ArrayList<>();
+    locations.add(genPartitionLocation(0, new Long[] {0L, 100L, 200L, 300L, 400L, 500L}));
+
+    for (int i = 0; i < 5; i++) {
+      Map<String, Pair<Integer, Integer>> result =
+          CelebornPartitionUtil.splitSkewedPartitionLocations(locations, 5, i);
+      Map<String, Pair<Integer, Integer>> expectResult = genRanges(new Object[][] {{"0-0", i, i}});
+      Assert.assertEquals(expectResult, result);
+    }
+  }
+
+  @Test
+  public void testSplitStable() {
+    ArrayList<PartitionLocation> locations = new ArrayList<>();
+    for (int i = 0; i < 13; i++) {
+      locations.add(genPartitionLocation(i, new Long[] {0L, 100L, 200L, 300L, 500L, 1000L}));
+    }
+    locations.add(genPartitionLocation(91, new Long[] {0L, 1L}));
+
+    Collections.shuffle(locations);
+
+    Map<String, Pair<Integer, Integer>> result =
+        CelebornPartitionUtil.splitSkewedPartitionLocations(locations, 3, 0);
+    Map<String, Pair<Integer, Integer>> expectResult =
+        genRanges(
+            new Object[][] {
+              {"0-0", 0, 4},
+              {"0-1", 0, 4},
+              {"0-10", 0, 4},
+              {"0-11", 0, 4},
+              {"0-12", 0, 2}
+            });
+    Assert.assertEquals(expectResult, result);
+  }
+
+  private ArrayList<PartitionLocation> genPartitionLocations(Map<Integer, Long[]> epochToOffsets) {
+    ArrayList<PartitionLocation> locations = new ArrayList<>();
+    epochToOffsets.forEach(
+        (epoch, offsets) -> {
+          PartitionLocation location =
+              new PartitionLocation(
+                  0, epoch, "localhost", 0, 0, 0, 0, PartitionLocation.Mode.PRIMARY);
+          StorageInfo storageInfo =
+              new StorageInfo(
+                  StorageInfo.Type.HDD,
+                  "mountPoint",
+                  false,
+                  "filePath",
+                  StorageInfo.LOCAL_DISK_MASK,
+                  1,
+                  Arrays.asList(offsets));
+          location.setStorageInfo(storageInfo);
+          locations.add(location);
+        });
+    return locations;
+  }
+
+  private PartitionLocation genPartitionLocation(int epoch, Long[] offsets) {
+    PartitionLocation location =
+        new PartitionLocation(0, epoch, "localhost", 0, 0, 0, 0, PartitionLocation.Mode.PRIMARY);
+    StorageInfo storageInfo =
+        new StorageInfo(
+            StorageInfo.Type.HDD,
+            "mountPoint",
+            false,
+            "filePath",
+            StorageInfo.LOCAL_DISK_MASK,
+            offsets[offsets.length - 1],
+            Arrays.asList(offsets));
+    location.setStorageInfo(storageInfo);
+    return location;
+  }
+
+  private Map<String, Pair<Integer, Integer>> genRanges(Object[][] inputs) {
+    Map<String, Pair<Integer, Integer>> ranges = new HashMap<>();
+    for (Object[] idToChunkRange : inputs) {
+      String uid = (String) idToChunkRange[0];
+      Pair<Integer, Integer> range = Pair.of((int) idToChunkRange[1], (int) idToChunkRange[2]);
+      ranges.put(uid, range);
+    }
+    return ranges;
+  }
+}
diff --git a/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java b/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java
index 2b933181faa..6b3673b1843 100644
--- a/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java
+++ b/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java
@@ -28,9 +28,11 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
+import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -44,6 +46,7 @@
import org.apache.celeborn.common.rpc.RpcEndpointRef;
import org.apache.celeborn.common.util.ExceptionMaker;
import org.apache.celeborn.common.util.JavaUtils;
+import org.apache.celeborn.common.write.PushFailedBatch;
import org.apache.celeborn.common.write.PushState;
public class DummyShuffleClient extends ShuffleClient {
@@ -139,6 +142,8 @@ public CelebornInputStream readPartition(
ExceptionMaker exceptionMaker,
ArrayList locations,
ArrayList streamHandlers,
+ Map> failedBatchSetMap,
+ Map> chunksRange,
int[] mapAttempts,
MetricsCallback metricsCallback)
throws IOException {
diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
index 673c9382437..8690f5cebde 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
@@ -20,9 +20,11 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.Map;
+import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.LongAdder;
+import org.apache.commons.lang3.tuple.Pair;
import org.apache.hadoop.fs.FileSystem;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -39,6 +41,7 @@
import org.apache.celeborn.common.rpc.RpcEndpointRef;
import org.apache.celeborn.common.util.CelebornHadoopUtils;
import org.apache.celeborn.common.util.ExceptionMaker;
+import org.apache.celeborn.common.write.PushFailedBatch;
import org.apache.celeborn.common.write.PushState;
/**
@@ -241,6 +244,8 @@ public CelebornInputStream readPartition(
null,
null,
null,
+ null,
+ null,
metricsCallback);
}
@@ -255,6 +260,8 @@ public abstract CelebornInputStream readPartition(
ExceptionMaker exceptionMaker,
ArrayList locations,
ArrayList streamHandlers,
+ Map> failedBatchSetMap,
+ Map> chunksRange,
int[] mapAttempts,
MetricsCallback metricsCallback)
throws IOException;
diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index fe373c8d22a..cf48f95bc72 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -69,6 +69,7 @@
import org.apache.celeborn.common.unsafe.Platform;
import org.apache.celeborn.common.util.*;
import org.apache.celeborn.common.write.DataBatches;
+import org.apache.celeborn.common.write.PushFailedBatch;
import org.apache.celeborn.common.write.PushState;
public class ShuffleClientImpl extends ShuffleClient {
@@ -146,30 +147,37 @@ protected Compressor initialValue() {
private final ReviveManager reviveManager;
+ private final boolean dataPushFailureTrackingEnabled;
+
public static class ReduceFileGroups {
public Map> partitionGroups;
+ public Map> pushFailedBatches;
public int[] mapAttempts;
public Set partitionIds;
ReduceFileGroups(
Map> partitionGroups,
int[] mapAttempts,
- Set partitionIds) {
+ Set partitionIds,
+ Map> pushFailedBatches) {
this.partitionGroups = partitionGroups;
this.mapAttempts = mapAttempts;
this.partitionIds = partitionIds;
+ this.pushFailedBatches = pushFailedBatches;
}
public ReduceFileGroups() {
this.partitionGroups = null;
this.mapAttempts = null;
this.partitionIds = null;
+ this.pushFailedBatches = null;
}
public void update(ReduceFileGroups fileGroups) {
partitionGroups = fileGroups.partitionGroups;
mapAttempts = fileGroups.mapAttempts;
partitionIds = fileGroups.partitionIds;
+ pushFailedBatches = fileGroups.pushFailedBatches;
}
}
@@ -199,6 +207,7 @@ public ShuffleClientImpl(String appUniqueId, CelebornConf conf, UserIdentifier u
pushDataTimeout = conf.pushDataTimeoutMs();
}
authEnabled = conf.authEnabledOnClient();
+ dataPushFailureTrackingEnabled = conf.clientAdaptiveOptimizeSkewedPartitionReadEnabled();
// init rpc env
rpcEnv =
@@ -1110,6 +1119,10 @@ public void onSuccess(ByteBuffer response) {
attemptId,
partitionId,
nextBatchId);
+ if (dataPushFailureTrackingEnabled) {
+ pushState.addFailedBatch(
+ latest.getUniqueId(), new PushFailedBatch(mapId, attemptId, nextBatchId));
+ }
ReviveRequest reviveRequest =
new ReviveRequest(
shuffleId,
@@ -1176,6 +1189,10 @@ public void onSuccess(ByteBuffer response) {
@Override
public void onFailure(Throwable e) {
+ if (dataPushFailureTrackingEnabled) {
+ pushState.addFailedBatch(
+ latest.getUniqueId(), new PushFailedBatch(mapId, attemptId, nextBatchId));
+ }
if (pushState.exception.get() != null) {
return;
}
@@ -1542,6 +1559,13 @@ public void onSuccess(ByteBuffer response) {
pushState.onSuccess(hostPort);
callback.onSuccess(ByteBuffer.wrap(new byte[] {StatusCode.SOFT_SPLIT.getValue()}));
} else {
+ if (dataPushFailureTrackingEnabled) {
+ for (DataBatches.DataBatch resubmitBatch : batchesNeedResubmit) {
+ pushState.addFailedBatch(
+ resubmitBatch.loc.getUniqueId(),
+ new PushFailedBatch(mapId, attemptId, resubmitBatch.batchId));
+ }
+ }
ReviveRequest[] requests =
addAndGetReviveRequests(
shuffleId, mapId, attemptId, batchesNeedResubmit, StatusCode.HARD_SPLIT);
@@ -1597,6 +1621,12 @@ public void onSuccess(ByteBuffer response) {
@Override
public void onFailure(Throwable e) {
+ if (dataPushFailureTrackingEnabled) {
+ for (int i = 0; i < numBatches; i++) {
+ pushState.addFailedBatch(
+ partitionUniqueIds[i], new PushFailedBatch(mapId, attemptId, batchIds[i]));
+ }
+ }
if (pushState.exception.get() != null) {
return;
}
@@ -1717,7 +1747,13 @@ private void mapEndInternal(
MapperEndResponse response =
lifecycleManagerRef.askSync(
- new MapperEnd(shuffleId, mapId, attemptId, numMappers, partitionId),
+ new MapperEnd(
+ shuffleId,
+ mapId,
+ attemptId,
+ numMappers,
+ partitionId,
+ pushState.getFailedBatches()),
rpcMaxRetries,
rpcRetryWait,
ClassTag$.MODULE$.apply(MapperEndResponse.class));
@@ -1782,7 +1818,10 @@ protected Tuple3 loadFileGroupInternal(
response.fileGroup().size());
return Tuple3.apply(
new ReduceFileGroups(
- response.fileGroup(), response.attempts(), response.partitionIds()),
+ response.fileGroup(),
+ response.attempts(),
+ response.partitionIds(),
+ response.pushFailedBatches()),
null,
null);
case SHUFFLE_NOT_REGISTERED:
@@ -1791,7 +1830,10 @@ protected Tuple3 loadFileGroupInternal(
// return empty result
return Tuple3.apply(
new ReduceFileGroups(
- response.fileGroup(), response.attempts(), response.partitionIds()),
+ response.fileGroup(),
+ response.attempts(),
+ response.partitionIds(),
+ response.pushFailedBatches()),
null,
null);
case STAGE_END_TIME_OUT:
@@ -1863,6 +1905,8 @@ public CelebornInputStream readPartition(
ExceptionMaker exceptionMaker,
ArrayList locations,
ArrayList streamHandlers,
+ Map> failedBatchSetMap,
+ Map> chunksRange,
int[] mapAttempts,
MetricsCallback metricsCallback)
throws IOException {
@@ -1895,6 +1939,8 @@ public CelebornInputStream readPartition(
locations,
streamHandlers,
mapAttempts,
+ failedBatchSetMap,
+ chunksRange,
attemptNumber,
taskId,
startMapIndex,
diff --git a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java
index e525a135ae2..ea3df88d95a 100644
--- a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java
+++ b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java
@@ -30,10 +30,12 @@
import com.google.common.util.concurrent.Uninterruptibles;
import io.netty.buffer.ByteBuf;
import net.jpountz.lz4.LZ4Exception;
+import org.apache.commons.lang3.tuple.Pair;
import org.roaringbitmap.RoaringBitmap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import org.apache.celeborn.client.ClientUtils;
import org.apache.celeborn.client.ShuffleClient;
import org.apache.celeborn.client.compress.Decompressor;
import org.apache.celeborn.common.CelebornConf;
@@ -45,6 +47,7 @@
import org.apache.celeborn.common.unsafe.Platform;
import org.apache.celeborn.common.util.ExceptionMaker;
import org.apache.celeborn.common.util.Utils;
+import org.apache.celeborn.common.write.PushFailedBatch;
public abstract class CelebornInputStream extends InputStream {
private static final Logger logger = LoggerFactory.getLogger(CelebornInputStream.class);
@@ -56,6 +59,8 @@ public static CelebornInputStream create(
ArrayList locations,
ArrayList streamHandlers,
int[] attempts,
+ Map> failedBatchSetMap,
+ Map> chunksRange,
int attemptNumber,
long taskId,
int startMapIndex,
@@ -68,27 +73,57 @@ public static CelebornInputStream create(
ExceptionMaker exceptionMaker,
MetricsCallback metricsCallback)
throws IOException {
- if (locations == null || locations.size() == 0) {
+ if (locations == null || locations.isEmpty()) {
return emptyInputStream;
} else {
- return new CelebornInputStreamImpl(
- conf,
- clientFactory,
- shuffleKey,
- locations,
- streamHandlers,
- attempts,
- attemptNumber,
- taskId,
- startMapIndex,
- endMapIndex,
- fetchExcludedWorkers,
- shuffleClient,
- appShuffleId,
- shuffleId,
- partitionId,
- exceptionMaker,
- metricsCallback);
+ // If startMapIndex > endMapIndex, this is a skewed partition split by Celeborn itself:
+ // startMapIndex is the number of sub-partitions and endMapIndex is the sub-partition
+ // index, so locations are read by chunk range rather than by map range.
+ boolean readSkewPartitionWithoutMapRange =
+ ClientUtils.readSkewPartitionWithoutMapRange(conf, startMapIndex, endMapIndex);
+ if (readSkewPartitionWithoutMapRange) {
+ return new CelebornInputStreamImpl(
+ conf,
+ clientFactory,
+ shuffleKey,
+ locations,
+ streamHandlers,
+ attempts,
+ failedBatchSetMap,
+ attemptNumber,
+ taskId,
+ chunksRange,
+ fetchExcludedWorkers,
+ shuffleClient,
+ appShuffleId,
+ shuffleId,
+ partitionId,
+ exceptionMaker,
+ true,
+ metricsCallback);
+ } else {
+ return new CelebornInputStreamImpl(
+ conf,
+ clientFactory,
+ shuffleKey,
+ locations,
+ streamHandlers,
+ attempts,
+ failedBatchSetMap,
+ attemptNumber,
+ taskId,
+ startMapIndex,
+ endMapIndex,
+ /*partitionLocationToChunkRange = */ null,
+ fetchExcludedWorkers,
+ shuffleClient,
+ appShuffleId,
+ shuffleId,
+ partitionId,
+ exceptionMaker,
+ false,
+ metricsCallback);
+ }
}
}
@@ -136,9 +171,12 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream {
private final long taskId;
private final int startMapIndex;
private final int endMapIndex;
+ private final Map> partitionLocationToChunkRange;
private Map> batchesRead = new HashMap<>();
+ private final Map> failedBatches;
+
private byte[] compressedBuf;
private byte[] rawDataBuf;
private Decompressor decompressor;
@@ -175,6 +213,8 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream {
private ExceptionMaker exceptionMaker;
private boolean closed = false;
+ private final boolean readSkewPartitionWithoutMapRange;
+
CelebornInputStreamImpl(
CelebornConf conf,
TransportClientFactory clientFactory,
@@ -182,16 +222,62 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream {
ArrayList locations,
ArrayList streamHandlers,
int[] attempts,
+ Map> failedBatchSet,
+ int attemptNumber,
+ long taskId,
+ Map> partitionLocationToChunkRange,
+ ConcurrentHashMap fetchExcludedWorkers,
+ ShuffleClient shuffleClient,
+ int appShuffleId,
+ int shuffleId,
+ int partitionId,
+ ExceptionMaker exceptionMaker,
+ boolean splitSkewPartitionWithoutMapRange,
+ MetricsCallback metricsCallback)
+ throws IOException {
+ this(
+ conf,
+ clientFactory,
+ shuffleKey,
+ locations,
+ streamHandlers,
+ attempts,
+ failedBatchSet,
+ attemptNumber,
+ taskId,
+ /*startMapIndex = */ -1,
+ /*endMapIndex = */ -1,
+ partitionLocationToChunkRange,
+ fetchExcludedWorkers,
+ shuffleClient,
+ appShuffleId,
+ shuffleId,
+ partitionId,
+ exceptionMaker,
+ splitSkewPartitionWithoutMapRange,
+ metricsCallback);
+ }
+
+ CelebornInputStreamImpl(
+ CelebornConf conf,
+ TransportClientFactory clientFactory,
+ String shuffleKey,
+ ArrayList locations,
+ ArrayList streamHandlers,
+ int[] attempts,
+ Map> failedBatchSet,
int attemptNumber,
long taskId,
int startMapIndex,
int endMapIndex,
+ Map> partitionLocationToChunkRange,
ConcurrentHashMap fetchExcludedWorkers,
ShuffleClient shuffleClient,
int appShuffleId,
int shuffleId,
int partitionId,
ExceptionMaker exceptionMaker,
+ boolean readSkewPartitionWithoutMapRange,
MetricsCallback metricsCallback)
throws IOException {
this.conf = conf;
@@ -206,12 +292,15 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream {
this.taskId = taskId;
this.startMapIndex = startMapIndex;
this.endMapIndex = endMapIndex;
+ this.partitionLocationToChunkRange = partitionLocationToChunkRange;
this.rangeReadFilter = conf.shuffleRangeReadFilterEnabled();
this.enabledReadLocalShuffle = conf.enableReadLocalShuffleFile();
this.localHostAddress = Utils.localHostName(conf);
this.shuffleCompressionEnabled =
!conf.shuffleCompressionCodec().equals(CompressionCodec.NONE);
this.fetchExcludedWorkerExpireTimeout = conf.clientFetchExcludedWorkerExpireTimeout();
+ this.failedBatches = failedBatchSet;
+ this.readSkewPartitionWithoutMapRange = readSkewPartitionWithoutMapRange;
this.fetchExcludedWorkers = fetchExcludedWorkers;
if (conf.clientPushReplicateEnabled()) {
@@ -260,7 +349,11 @@ private Tuple2 nextReadableLocation() {
return null;
}
PartitionLocation currentLocation = locations.get(fileIndex);
- while (skipLocation(startMapIndex, endMapIndex, currentLocation)) {
+ // In skew-split mode, skip locations that have no chunk range; otherwise skip by map range.
+ while ((readSkewPartitionWithoutMapRange
+ && !partitionLocationToChunkRange.containsKey(currentLocation.getUniqueId()))
+ || (!readSkewPartitionWithoutMapRange
+ && skipLocation(startMapIndex, endMapIndex, currentLocation))) {
skipCount.increment();
fileIndex++;
if (fileIndex == locationCount) {
@@ -336,7 +429,7 @@ private PartitionReader createReaderWithRetry(
lastException = e;
shuffleClient.excludeFailedFetchLocation(location.hostAndFetchPort(), e);
fetchChunkRetryCnt++;
- if (location.hasPeer()) {
+ if (location.hasPeer() && !readSkewPartitionWithoutMapRange) {
// fetchChunkRetryCnt % 2 == 0 means both replicas have been tried,
// so sleep before next try.
if (fetchChunkRetryCnt % 2 == 0) {
@@ -407,7 +500,7 @@ private ByteBuf getNextChunk() throws IOException {
+ currentReader.getLocation(),
e);
} else {
- if (currentReader.getLocation().hasPeer()) {
+ if (currentReader.getLocation().hasPeer() && !readSkewPartitionWithoutMapRange) {
logger.warn(
"Fetch chunk failed {}/{} times for location {}, change to peer",
fetchChunkRetryCnt,
@@ -444,6 +537,15 @@ private PartitionReader createReader(
throws IOException, InterruptedException {
StorageInfo storageInfo = location.getStorageInfo();
+
+ int startChunkIndex = -1;
+ int endChunkIndex = -1;
+ if (partitionLocationToChunkRange != null) {
+ Pair chunkRange =
+ partitionLocationToChunkRange.get(location.getUniqueId());
+ startChunkIndex = chunkRange.getLeft();
+ endChunkIndex = chunkRange.getRight();
+ }
switch (storageInfo.getType()) {
case HDD:
case SSD:
@@ -473,7 +575,9 @@ private PartitionReader createReader(
endMapIndex,
fetchChunkRetryCnt,
fetchChunkMaxRetry,
- callback);
+ callback,
+ startChunkIndex,
+ endChunkIndex);
}
case S3:
case HDFS:
@@ -621,6 +725,7 @@ private boolean fillBuffer() throws IOException {
return false;
}
+ PushFailedBatch failedBatch = new PushFailedBatch(-1, -1, -1);
boolean hasData = false;
while (currentChunk.isReadable() || moveToNextChunk()) {
currentChunk.readBytes(sizeBuf);
@@ -645,6 +750,19 @@ private boolean fillBuffer() throws IOException {
// de-duplicate
if (attemptId == attempts[mapId]) {
+ if (readSkewPartitionWithoutMapRange) {
+ Set failedBatchSet =
+ this.failedBatches.get(currentReader.getLocation().getUniqueId());
+ if (null != failedBatchSet) {
+ failedBatch.setMapId(mapId);
+ failedBatch.setAttemptId(attemptId);
+ failedBatch.setBatchId(batchId);
+ if (failedBatchSet.contains(failedBatch)) {
+ logger.warn("Skip duplicated batch: {}.", failedBatch);
+ continue;
+ }
+ }
+ }
if (!batchesRead.containsKey(mapId)) {
Set batchSet = new HashSet<>();
batchesRead.put(mapId, batchSet);
diff --git a/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java b/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java
index 3158aa12f72..29236273663 100644
--- a/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java
+++ b/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java
@@ -55,6 +55,8 @@ public class WorkerPartitionReader implements PartitionReader {
private int returnedChunks;
private int chunkIndex;
+ private int startChunkIndex;
+ private int endChunkIndex;
private final LinkedBlockingQueue results;
private final ChunkReceivedCallback callback;
@@ -80,7 +82,9 @@ public class WorkerPartitionReader implements PartitionReader {
int endMapIndex,
int fetchChunkRetryCnt,
int fetchChunkMaxRetry,
- MetricsCallback metricsCallback)
+ MetricsCallback metricsCallback,
+ int startChunkIndex,
+ int endChunkIndex)
throws IOException, InterruptedException {
this.shuffleKey = shuffleKey;
fetchMaxReqsInFlight = conf.clientFetchMaxReqsInFlight();
@@ -133,7 +137,12 @@ public void onFailure(int chunkIndex, Throwable e) {
} else {
this.streamHandler = pbStreamHandler;
}
-
+ this.startChunkIndex = startChunkIndex == -1 ? 0 : startChunkIndex;
+ this.endChunkIndex =
+ endChunkIndex == -1
+ ? streamHandler.getNumChunks() - 1
+ : Math.min(streamHandler.getNumChunks() - 1, endChunkIndex);
+ this.chunkIndex = this.startChunkIndex;
this.location = location;
this.clientFactory = clientFactory;
this.fetchChunkRetryCnt = fetchChunkRetryCnt;
@@ -144,13 +153,13 @@ public void onFailure(int chunkIndex, Throwable e) {
@Override
public boolean hasNext() {
- return returnedChunks < streamHandler.getNumChunks();
+ return returnedChunks < endChunkIndex - startChunkIndex + 1;
}
@Override
public ByteBuf next() throws IOException, InterruptedException {
checkException();
- if (chunkIndex < streamHandler.getNumChunks()) {
+ if (chunkIndex <= endChunkIndex) {
fetchChunks();
}
ByteBuf chunk = null;
@@ -202,10 +211,10 @@ public PartitionLocation getLocation() {
}
private void fetchChunks() throws IOException, InterruptedException {
- final int inFlight = chunkIndex - returnedChunks;
+ final int inFlight = chunkIndex - startChunkIndex - returnedChunks;
if (inFlight < fetchMaxReqsInFlight) {
final int toFetch =
- Math.min(fetchMaxReqsInFlight - inFlight + 1, streamHandler.getNumChunks() - chunkIndex);
+ Math.min(fetchMaxReqsInFlight - inFlight + 1, endChunkIndex + 1 - chunkIndex);
for (int i = 0; i < toFetch; i++) {
if (testFetch && fetchChunkRetryCnt < fetchChunkMaxRetry - 1 && chunkIndex == 3) {
callback.onFailure(chunkIndex, new CelebornIOException("Test fetch chunk failure"));
diff --git a/client/src/main/scala/org/apache/celeborn/client/ClientUtils.scala b/client/src/main/scala/org/apache/celeborn/client/ClientUtils.scala
index d7dccb941b0..b071eff3bf4 100644
--- a/client/src/main/scala/org/apache/celeborn/client/ClientUtils.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/ClientUtils.scala
@@ -17,6 +17,8 @@
package org.apache.celeborn.client
+import org.apache.celeborn.common.CelebornConf
+
object ClientUtils {
/**
@@ -37,4 +39,20 @@ object ClientUtils {
}
true
}
+
+ /**
+ * If startMapIndex > endMapIndex, it means the partition is a skewed partition.
+ * Its locations will be split into sub-partitions with startMapIndex size.
+ *
+ * @param conf celeborn conf
+ * @param startMapIndex shuffle start map index
+ * @param endMapIndex shuffle end map index
+ * @return true if the skewed partition should be read without a map range
+ */
+ def readSkewPartitionWithoutMapRange(
+ conf: CelebornConf,
+ startMapIndex: Int,
+ endMapIndex: Int): Boolean = {
+ conf.clientAdaptiveOptimizeSkewedPartitionReadEnabled && startMapIndex > endMapIndex
+ }
}
diff --git a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
index 201be286978..aaeb4462fbd 100644
--- a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
@@ -18,6 +18,7 @@
package org.apache.celeborn.client
import java.util
+import java.util.Collections
import java.util.concurrent.{ConcurrentHashMap, ScheduledExecutorService, ScheduledFuture, TimeUnit}
import java.util.concurrent.atomic.{AtomicInteger, LongAdder}
@@ -40,6 +41,7 @@ import org.apache.celeborn.common.rpc.RpcCallContext
import org.apache.celeborn.common.util.FunctionConverter._
import org.apache.celeborn.common.util.JavaUtils
import org.apache.celeborn.common.util.ThreadUtils
+import org.apache.celeborn.common.write.PushFailedBatch
case class ShuffleCommittedInfo(
// partition id -> unique partition ids
@@ -215,13 +217,16 @@ class CommitManager(appUniqueId: String, val conf: CelebornConf, lifecycleManage
mapId: Int,
attemptId: Int,
numMappers: Int,
- partitionId: Int = -1): (Boolean, Boolean) = {
+ partitionId: Int = -1,
+ pushFailedBatches: util.Map[String, util.Set[PushFailedBatch]] = Collections.emptyMap())
+ : (Boolean, Boolean) = {
getCommitHandler(shuffleId).finishMapperAttempt(
shuffleId,
mapId,
attemptId,
numMappers,
partitionId,
+ pushFailedBatches,
r => lifecycleManager.workerStatusTracker.recordWorkerFailure(r))
}
diff --git a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
index 00dfd77c669..73b0dc7d728 100644
--- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
@@ -58,11 +58,15 @@ import org.apache.celeborn.common.util.{JavaUtils, PbSerDeUtils, ThreadUtils, Ut
import org.apache.celeborn.common.util.FunctionConverter._
import org.apache.celeborn.common.util.ThreadUtils.awaitResult
import org.apache.celeborn.common.util.Utils.UNKNOWN_APP_SHUFFLE_ID
+import org.apache.celeborn.common.write.PushFailedBatch
object LifecycleManager {
// shuffle id -> partition id -> partition locations
type ShuffleFileGroups =
ConcurrentHashMap[Int, ConcurrentHashMap[Integer, util.Set[PartitionLocation]]]
+ // shuffle id -> partition uniqueId -> PushFailedBatch set
+ type ShufflePushFailedBatches =
+ ConcurrentHashMap[Int, util.HashMap[String, util.Set[PushFailedBatch]]]
type ShuffleAllocatedWorkers =
ConcurrentHashMap[Int, ConcurrentHashMap[String, ShufflePartitionLocationInfo]]
type ShuffleFailedWorkers = ConcurrentHashMap[WorkerInfo, (StatusCode, Long)]
@@ -404,13 +408,13 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
oldPartition,
isSegmentGranularityVisible = commitManager.isSegmentGranularityVisible(shuffleId))
- case MapperEnd(shuffleId, mapId, attemptId, numMappers, partitionId) =>
+ case MapperEnd(shuffleId, mapId, attemptId, numMappers, partitionId, pushFailedBatch) =>
logTrace(s"Received MapperEnd TaskEnd request, " +
s"${Utils.makeMapKey(shuffleId, mapId, attemptId)}")
val partitionType = getPartitionType(shuffleId)
partitionType match {
case PartitionType.REDUCE =>
- handleMapperEnd(context, shuffleId, mapId, attemptId, numMappers)
+ handleMapperEnd(context, shuffleId, mapId, attemptId, numMappers, pushFailedBatch)
case PartitionType.MAP =>
handleMapPartitionEnd(
context,
@@ -802,10 +806,16 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
shuffleId: Int,
mapId: Int,
attemptId: Int,
- numMappers: Int): Unit = {
+ numMappers: Int,
+ pushFailedBatches: util.Map[String, util.Set[PushFailedBatch]]): Unit = {
val (mapperAttemptFinishedSuccess, allMapperFinished) =
- commitManager.finishMapperAttempt(shuffleId, mapId, attemptId, numMappers)
+ commitManager.finishMapperAttempt(
+ shuffleId,
+ mapId,
+ attemptId,
+ numMappers,
+ pushFailedBatches = pushFailedBatches)
if (mapperAttemptFinishedSuccess && allMapperFinished) {
// last mapper finished. call mapper end
logInfo(s"Last MapperEnd, call StageEnd with shuffleKey:" +
diff --git a/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala b/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala
index 65ba4dbab04..38c3b0c9c81 100644
--- a/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala
@@ -30,7 +30,7 @@ import scala.concurrent.duration.Duration
import org.apache.celeborn.client.{ShuffleCommittedInfo, WorkerStatusTracker}
import org.apache.celeborn.client.CommitManager.CommittedPartitionInfo
-import org.apache.celeborn.client.LifecycleManager.{ShuffleFailedWorkers, ShuffleFileGroups}
+import org.apache.celeborn.client.LifecycleManager.{ShuffleFailedWorkers, ShuffleFileGroups, ShufflePushFailedBatches}
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.internal.Logging
import org.apache.celeborn.common.meta.{ShufflePartitionLocationInfo, WorkerInfo}
@@ -42,6 +42,7 @@ import org.apache.celeborn.common.util.{CollectionUtils, JavaUtils, Utils}
// Can Remove this if celeborn don't support scala211 in future
import org.apache.celeborn.common.util.FunctionConverter._
import org.apache.celeborn.common.util.ThreadUtils.awaitResult
+import org.apache.celeborn.common.write.PushFailedBatch
case class CommitFilesParam(
worker: WorkerInfo,
@@ -74,6 +75,7 @@ abstract class CommitHandler(
private val totalWritten = new LongAdder
private val fileCount = new LongAdder
protected val reducerFileGroupsMap = new ShuffleFileGroups
+ protected val shufflePushFailedBatches = new ShufflePushFailedBatches
val ec = ExecutionContext.fromExecutor(sharedRpcPool)
@@ -82,6 +84,8 @@ abstract class CommitHandler(
def getPartitionType(): PartitionType
+ def getShuffleFailedBatches(): ShufflePushFailedBatches = this.shufflePushFailedBatches
+
def isStageEnd(shuffleId: Int): Boolean = false
def isStageEndOrInProcess(shuffleId: Int): Boolean = false
@@ -178,6 +182,7 @@ abstract class CommitHandler(
def removeExpiredShuffle(shuffleId: Int): Unit = {
reducerFileGroupsMap.remove(shuffleId)
+ shufflePushFailedBatches.remove(shuffleId)
}
/**
@@ -197,6 +202,7 @@ abstract class CommitHandler(
attemptId: Int,
numMappers: Int,
partitionId: Int,
+ pushFailedBatches: util.Map[String, util.Set[PushFailedBatch]],
recordWorkerFailure: ShuffleFailedWorkers => Unit): (Boolean, Boolean)
def registerShuffle(
diff --git a/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala b/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
index 9352dc99a72..a08f1e0d51f 100644
--- a/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
@@ -39,6 +39,7 @@ import org.apache.celeborn.common.rpc.RpcCallContext
import org.apache.celeborn.common.util.FunctionConverter._
import org.apache.celeborn.common.util.JavaUtils
import org.apache.celeborn.common.util.Utils
+import org.apache.celeborn.common.write.PushFailedBatch
/**
* This commit handler is for MapPartition ShuffleType, which means that a Map Partition contains all data produced
@@ -184,6 +185,7 @@ class MapPartitionCommitHandler(
attemptId: Int,
numMappers: Int,
partitionId: Int,
+ pushFailedBatches: util.Map[String, util.Set[PushFailedBatch]],
recordWorkerFailure: ShuffleFailedWorkers => Unit): (Boolean, Boolean) = {
val inProcessingPartitionIds =
inProcessMapPartitionEndIds.computeIfAbsent(
diff --git a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
index 951fc89e601..55639764c7c 100644
--- a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
@@ -20,11 +20,13 @@ package org.apache.celeborn.client.commit
import java.nio.ByteBuffer
import java.util
import java.util.concurrent.{Callable, ConcurrentHashMap, ThreadPoolExecutor, TimeUnit}
+import java.util.function
import scala.collection.JavaConverters._
import scala.collection.mutable
import com.google.common.cache.{Cache, CacheBuilder}
+import com.google.common.collect.Sets
import org.apache.celeborn.client.{ClientUtils, ShuffleCommittedInfo, WorkerStatusTracker}
import org.apache.celeborn.client.CommitManager.CommittedPartitionInfo
@@ -38,6 +40,7 @@ import org.apache.celeborn.common.protocol.message.StatusCode
import org.apache.celeborn.common.rpc.RpcCallContext
import org.apache.celeborn.common.rpc.netty.{LocalNettyRpcCallContext, RemoteNettyRpcCallContext}
import org.apache.celeborn.common.util.JavaUtils
+import org.apache.celeborn.common.write.PushFailedBatch
/**
* This commit handler is for ReducePartition ShuffleType, which means that a Reduce Partition contains all data
@@ -82,6 +85,22 @@ class ReducePartitionCommitHandler(
.maximumSize(rpcCacheSize)
.build().asInstanceOf[Cache[Int, ByteBuffer]]
+ private val newShuffleId2PushFailedBatchMapFunc
+ : function.Function[Int, util.HashMap[String, util.Set[PushFailedBatch]]] =
+ new util.function.Function[Int, util.HashMap[String, util.Set[PushFailedBatch]]]() {
+ override def apply(s: Int): util.HashMap[String, util.Set[PushFailedBatch]] = {
+ new util.HashMap[String, util.Set[PushFailedBatch]]()
+ }
+ }
+
+ private val uniqueId2PushFailedBatchMapFunc
+ : function.Function[String, util.Set[PushFailedBatch]] =
+ new util.function.Function[String, util.Set[PushFailedBatch]]() {
+ override def apply(s: String): util.Set[PushFailedBatch] = {
+ Sets.newHashSet[PushFailedBatch]()
+ }
+ }
+
override def getPartitionType(): PartitionType = {
PartitionType.REDUCE
}
@@ -240,6 +259,7 @@ class ReducePartitionCommitHandler(
attemptId: Int,
numMappers: Int,
partitionId: Int,
+ pushFailedBatches: util.Map[String, util.Set[PushFailedBatch]],
recordWorkerFailure: ShuffleFailedWorkers => Unit): (Boolean, Boolean) = {
shuffleMapperAttempts.synchronized {
if (getMapperAttempts(shuffleId) == null) {
@@ -250,6 +270,17 @@ class ReducePartitionCommitHandler(
val attempts = shuffleMapperAttempts.get(shuffleId)
if (attempts(mapId) < 0) {
attempts(mapId) = attemptId
+ if (null != pushFailedBatches && !pushFailedBatches.isEmpty) {
+ val pushFailedBatchesMap = shufflePushFailedBatches.computeIfAbsent(
+ shuffleId,
+ newShuffleId2PushFailedBatchMapFunc)
+ for ((partitionUniqId, pushFailedBatchSet) <- pushFailedBatches.asScala) {
+ val partitionPushFailedBatches = pushFailedBatchesMap.computeIfAbsent(
+ partitionUniqId,
+ uniqueId2PushFailedBatchMapFunc)
+ partitionPushFailedBatches.addAll(pushFailedBatchSet)
+ }
+ }
// Mapper with this attemptId finished, also check all other mapper finished or not.
(true, ClientUtils.areAllMapperAttemptsFinished(attempts))
} else {
@@ -301,7 +332,11 @@ class ReducePartitionCommitHandler(
val returnedMsg = GetReducerFileGroupResponse(
StatusCode.SUCCESS,
reducerFileGroupsMap.getOrDefault(shuffleId, JavaUtils.newConcurrentHashMap()),
- getMapperAttempts(shuffleId))
+ getMapperAttempts(shuffleId),
+ pushFailedBatches =
+ shufflePushFailedBatches.getOrDefault(
+ shuffleId,
+ new util.HashMap[String, util.Set[PushFailedBatch]]()))
context.asInstanceOf[RemoteNettyRpcCallContext].nettyEnv.serialize(returnedMsg)
}
})
diff --git a/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java b/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java
index 886aff3dc8e..7053937a06a 100644
--- a/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java
+++ b/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java
@@ -423,7 +423,11 @@ public void testUpdateReducerFileGroupInterrupted() throws InterruptedException
t -> {
Thread.sleep(60 * 1000);
return GetReducerFileGroupResponse$.MODULE$.apply(
- StatusCode.SUCCESS, locations, new int[0], Collections.emptySet());
+ StatusCode.SUCCESS,
+ locations,
+ new int[0],
+ Collections.emptySet(),
+ Collections.emptyMap());
});
when(endpointRef.askSync(any(), any(), any(Integer.class), any(Long.class), any()))
@@ -470,7 +474,11 @@ public void testUpdateReducerFileGroupNonFetchFailureExceptions() {
.thenAnswer(
t -> {
return GetReducerFileGroupResponse$.MODULE$.apply(
- StatusCode.SHUFFLE_NOT_REGISTERED, locations, new int[0], Collections.emptySet());
+ StatusCode.SHUFFLE_NOT_REGISTERED,
+ locations,
+ new int[0],
+ Collections.emptySet(),
+ Collections.emptyMap());
});
when(endpointRef.askSync(any(), any(), any(Integer.class), any(Long.class), any()))
@@ -494,7 +502,11 @@ public void testUpdateReducerFileGroupNonFetchFailureExceptions() {
.thenAnswer(
t -> {
return GetReducerFileGroupResponse$.MODULE$.apply(
- StatusCode.STAGE_END_TIME_OUT, locations, new int[0], Collections.emptySet());
+ StatusCode.STAGE_END_TIME_OUT,
+ locations,
+ new int[0],
+ Collections.emptySet(),
+ Collections.emptyMap());
});
when(endpointRef.askSync(any(), any(), any(Integer.class), any(Long.class), any()))
@@ -518,7 +530,11 @@ public void testUpdateReducerFileGroupNonFetchFailureExceptions() {
.thenAnswer(
t -> {
return GetReducerFileGroupResponse$.MODULE$.apply(
- StatusCode.SHUFFLE_DATA_LOST, locations, new int[0], Collections.emptySet());
+ StatusCode.SHUFFLE_DATA_LOST,
+ locations,
+ new int[0],
+ Collections.emptySet(),
+ Collections.emptyMap());
});
when(endpointRef.askSync(any(), any(), any(Integer.class), any(Long.class), any()))
diff --git a/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala b/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala
index 8cd8bedf76d..0570760ce14 100644
--- a/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala
+++ b/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala
@@ -41,7 +41,7 @@ trait WithShuffleClientSuite extends CelebornFunSuite {
private val attemptId = 0
private var lifecycleManager: LifecycleManager = _
- private var shuffleClient: ShuffleClientImpl = _
+ protected var shuffleClient: ShuffleClientImpl = _
var _shuffleId = 0
def nextShuffleId: Int = {
@@ -164,6 +164,8 @@ trait WithShuffleClientSuite extends CelebornFunSuite {
null,
null,
null,
+ null,
+ null,
metricsCallback)
Assert.assertEquals(stream.read(), -1)
@@ -180,6 +182,8 @@ trait WithShuffleClientSuite extends CelebornFunSuite {
null,
null,
null,
+ null,
+ null,
metricsCallback)
Assert.assertEquals(stream.read(), -1)
}
diff --git a/common/src/main/java/org/apache/celeborn/common/protocol/StorageInfo.java b/common/src/main/java/org/apache/celeborn/common/protocol/StorageInfo.java
index 28cb652565b..8509d5717b2 100644
--- a/common/src/main/java/org/apache/celeborn/common/protocol/StorageInfo.java
+++ b/common/src/main/java/org/apache/celeborn/common/protocol/StorageInfo.java
@@ -68,6 +68,10 @@ public int getValue() {
public int availableStorageTypes = 0;
+ public long fileSize;
+
+ public List chunkOffsets;
+
public StorageInfo() {}
public StorageInfo(Type type, boolean isFinal, String filePath) {
@@ -95,6 +99,23 @@ public StorageInfo(
this.availableStorageTypes = availableStorageTypes;
}
+ public StorageInfo(
+ Type type,
+ String mountPoint,
+ boolean finalResult,
+ String filePath,
+ int availableStorageTypes,
+ long fileSize,
+ List chunkOffsets) {
+ this.type = type;
+ this.mountPoint = mountPoint;
+ this.finalResult = finalResult;
+ this.filePath = filePath;
+ this.availableStorageTypes = availableStorageTypes;
+ this.fileSize = fileSize;
+ this.chunkOffsets = chunkOffsets;
+ }
+
public boolean isFinalResult() {
return finalResult;
}
@@ -119,6 +140,22 @@ public String getFilePath() {
return filePath;
}
+ public void setChunkOffsets(List chunkOffsets) {
+ this.chunkOffsets = chunkOffsets;
+ }
+
+ public List getChunkOffsets() {
+ return this.chunkOffsets;
+ }
+
+ public void setFileSize(long fileSize) {
+ this.fileSize = fileSize;
+ }
+
+ public long getFileSize() {
+ return fileSize;
+ }
+
@Override
public String toString() {
return "StorageInfo{"
@@ -131,6 +168,10 @@ public String toString() {
+ finalResult
+ ", filePath="
+ filePath
+ + ", fileSize="
+ + fileSize
+ + ", chunkOffsets="
+ + chunkOffsets
+ '}';
}
@@ -215,7 +256,11 @@ public static PbStorageInfo toPb(StorageInfo storageInfo) {
.setType(storageInfo.type.value)
.setFinalResult(storageInfo.finalResult)
.setMountPoint(storageInfo.mountPoint)
- .setAvailableStorageTypes(storageInfo.availableStorageTypes);
+ .setAvailableStorageTypes(storageInfo.availableStorageTypes)
+ .setFileSize(storageInfo.getFileSize());
+ if (storageInfo.getChunkOffsets() != null) {
+ builder.addAllChunkOffsets(storageInfo.getChunkOffsets());
+ }
if (filePath != null) {
builder.setFilePath(filePath);
}
@@ -228,7 +273,9 @@ public static StorageInfo fromPb(PbStorageInfo pbStorageInfo) {
pbStorageInfo.getMountPoint(),
pbStorageInfo.getFinalResult(),
pbStorageInfo.getFilePath(),
- pbStorageInfo.getAvailableStorageTypes());
+ pbStorageInfo.getAvailableStorageTypes(),
+ pbStorageInfo.getFileSize(),
+ pbStorageInfo.getChunkOffsetsList());
}
public static int getAvailableTypes(List types) {
diff --git a/common/src/main/java/org/apache/celeborn/common/write/PushFailedBatch.java b/common/src/main/java/org/apache/celeborn/common/write/PushFailedBatch.java
new file mode 100644
index 00000000000..ccee8bf1113
--- /dev/null
+++ b/common/src/main/java/org/apache/celeborn/common/write/PushFailedBatch.java
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.write;
+
+import java.io.Serializable;
+
+import com.google.common.base.Objects;
+import org.apache.commons.lang3.builder.ToStringBuilder;
+import org.apache.commons.lang3.builder.ToStringStyle;
+
+public class PushFailedBatch implements Serializable {
+
+ private int mapId;
+ private int attemptId;
+ private int batchId;
+
+ public PushFailedBatch(int mapId, int attemptId, int batchId) {
+ this.mapId = mapId;
+ this.attemptId = attemptId;
+ this.batchId = batchId;
+ }
+
+ public int getMapId() {
+ return mapId;
+ }
+
+ public void setMapId(int mapId) {
+ this.mapId = mapId;
+ }
+
+ public int getAttemptId() {
+ return attemptId;
+ }
+
+ public void setAttemptId(int attemptId) {
+ this.attemptId = attemptId;
+ }
+
+ public int getBatchId() {
+ return batchId;
+ }
+
+ public void setBatchId(int batchId) {
+ this.batchId = batchId;
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (!(other instanceof PushFailedBatch)) {
+ return false;
+ }
+ PushFailedBatch o = (PushFailedBatch) other;
+ return mapId == o.mapId && attemptId == o.attemptId && batchId == o.batchId;
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(mapId, attemptId, batchId);
+ }
+
+ @Override
+ public String toString() {
+ return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
+ .append("mapId", mapId)
+ .append("attemptId", attemptId)
+ .append("batchId", batchId)
+ .toString();
+ }
+}
diff --git a/common/src/main/java/org/apache/celeborn/common/write/PushState.java b/common/src/main/java/org/apache/celeborn/common/write/PushState.java
index 3979cafd632..9f691094bd9 100644
--- a/common/src/main/java/org/apache/celeborn/common/write/PushState.java
+++ b/common/src/main/java/org/apache/celeborn/common/write/PushState.java
@@ -18,9 +18,12 @@
package org.apache.celeborn.common.write;
import java.io.IOException;
+import java.util.Map;
+import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference;
+import com.google.common.collect.Sets;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.celeborn.common.CelebornConf;
@@ -33,9 +36,12 @@ public class PushState {
public AtomicReference exception = new AtomicReference<>();
private final InFlightRequestTracker inFlightRequestTracker;
+ private final Map> failedBatchMap;
+
public PushState(CelebornConf conf) {
pushBufferMaxSize = conf.clientPushBufferMaxSize();
inFlightRequestTracker = new InFlightRequestTracker(conf, this);
+ failedBatchMap = new ConcurrentHashMap<>();
}
public void cleanup() {
@@ -88,4 +94,14 @@ public boolean limitZeroInFlight() throws IOException {
public int remainingAllowPushes(String hostAndPushPort) {
return inFlightRequestTracker.remainingAllowPushes(hostAndPushPort);
}
+
+ public void addFailedBatch(String partitionId, PushFailedBatch failedBatch) {
+ this.failedBatchMap
+ .computeIfAbsent(partitionId, (s) -> Sets.newConcurrentHashSet())
+ .add(failedBatch);
+ }
+
+ public Map> getFailedBatches() {
+ return this.failedBatchMap;
+ }
}
diff --git a/common/src/main/proto/TransportMessages.proto b/common/src/main/proto/TransportMessages.proto
index 8aa59bcb29b..553e95aa79e 100644
--- a/common/src/main/proto/TransportMessages.proto
+++ b/common/src/main/proto/TransportMessages.proto
@@ -134,6 +134,8 @@ message PbStorageInfo {
bool finalResult = 3;
string filePath = 4;
int32 availableStorageTypes = 5;
+ int64 fileSize = 6;
+ repeated int64 chunkOffsets = 7;
}
message PbPartitionLocation {
@@ -354,6 +356,17 @@ message PbMapperEnd {
int32 attemptId = 3;
int32 numMappers = 4;
int32 partitionId = 5;
+ map pushFailureBatches= 6;
+}
+
+message PbPushFailedBatchSet {
+ repeated PbPushFailedBatch failureBatches = 1;
+}
+
+message PbPushFailedBatch {
+ int32 mapId = 1;
+ int32 attemptId = 2;
+ int32 batchId = 3;
}
message PbMapperEndResponse {
@@ -375,6 +388,8 @@ message PbGetReducerFileGroupResponse {
// only map partition mode has succeed partitionIds
repeated int32 partitionIds = 4;
+
+ map pushFailedBatches = 5;
}
message PbGetShuffleId {
@@ -850,6 +865,8 @@ message PbPackedPartitionLocations {
repeated string filePaths = 10;
repeated int32 availableStorageTypes = 11;
repeated int32 modes = 12;
+ repeated int64 fileSizes = 13;
+ repeated PbChunkOffsets chunksOffsets = 14;
}
message PbPackedPartitionLocationsPair {
@@ -883,3 +900,7 @@ message PbPushMergedDataSplitPartitionInfo {
repeated int32 splitPartitionIndexes = 1;
repeated int32 statusCodes = 2;
}
+
+message PbChunkOffsets {
+ repeated int64 chunkOffset = 1;
+}
diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
index ad9b3538169..0eae8a3f762 100644
--- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
@@ -1038,6 +1038,8 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
def clientPushSendBufferPoolExpireTimeout: Long = get(CLIENT_PUSH_SENDBUFFERPOOL_EXPIRETIMEOUT)
def clientPushSendBufferPoolExpireCheckInterval: Long =
get(CLIENT_PUSH_SENDBUFFERPOOL_CHECKEXPIREINTERVAL)
+ def clientAdaptiveOptimizeSkewedPartitionReadEnabled: Boolean =
+ get(CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ_ENABLED)
// //////////////////////////////////////////////////////
// Client Shuffle //
@@ -5932,6 +5934,15 @@ object CelebornConf extends Logging {
.intConf
.createWithDefault(10000)
+ val CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ_ENABLED: ConfigEntry[Boolean] =
+ buildConf("celeborn.client.adaptive.optimizeSkewedPartitionRead.enabled")
+ .categories("client")
+ .version("0.6.0")
+ .doc("If this is true, Celeborn will adaptively split skewed partitions instead of reading them by Spark map " +
+ "range. Please note that this feature requires the `Celeborn-Optimize-Skew-Partitions-spark3_3.patch`. ")
+ .booleanConf
+ .createWithDefault(false)
+
// SSL Configs
val SSL_ENABLED: ConfigEntry[Boolean] =
diff --git a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
index 0e465196d28..e730b9b7afb 100644
--- a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
@@ -32,6 +32,7 @@ import org.apache.celeborn.common.protocol._
import org.apache.celeborn.common.protocol.MessageType._
import org.apache.celeborn.common.quota.ResourceConsumption
import org.apache.celeborn.common.util.{PbSerDeUtils, Utils}
+import org.apache.celeborn.common.write.PushFailedBatch
sealed trait Message extends Serializable
@@ -271,7 +272,8 @@ object ControlMessages extends Logging {
mapId: Int,
attemptId: Int,
numMappers: Int,
- partitionId: Int)
+ partitionId: Int,
+ failedBatchSet: util.Map[String, util.Set[PushFailedBatch]])
extends MasterMessage
case class MapperEndResponse(status: StatusCode) extends MasterMessage
@@ -285,7 +287,8 @@ object ControlMessages extends Logging {
status: StatusCode,
fileGroup: util.Map[Integer, util.Set[PartitionLocation]],
attempts: Array[Int],
- partitionIds: util.Set[Integer] = Collections.emptySet[Integer]())
+ partitionIds: util.Set[Integer] = Collections.emptySet[Integer](),
+ pushFailedBatches: util.Map[String, util.Set[PushFailedBatch]] = Collections.emptyMap())
extends MasterMessage
object WorkerExclude {
@@ -721,13 +724,18 @@ object ControlMessages extends Logging {
case pb: PbChangeLocationResponse =>
new TransportMessage(MessageType.CHANGE_LOCATION_RESPONSE, pb.toByteArray)
- case MapperEnd(shuffleId, mapId, attemptId, numMappers, partitionId) =>
+ case MapperEnd(shuffleId, mapId, attemptId, numMappers, partitionId, pushFailedBatch) =>
+ val pushFailedMap = pushFailedBatch.asScala.map { case (k, v) =>
+ val resultValue = PbSerDeUtils.toPbPushFailedBatchSet(v)
+ (k, resultValue)
+ }.toMap.asJava
val payload = PbMapperEnd.newBuilder()
.setShuffleId(shuffleId)
.setMapId(mapId)
.setAttemptId(attemptId)
.setNumMappers(numMappers)
.setPartitionId(partitionId)
+ .putAllPushFailureBatches(pushFailedMap)
.build().toByteArray
new TransportMessage(MessageType.MAPPER_END, payload)
@@ -744,7 +752,7 @@ object ControlMessages extends Logging {
.build().toByteArray
new TransportMessage(MessageType.GET_REDUCER_FILE_GROUP, payload)
- case GetReducerFileGroupResponse(status, fileGroup, attempts, partitionIds) =>
+ case GetReducerFileGroupResponse(status, fileGroup, attempts, partitionIds, failedBatches) =>
val builder = PbGetReducerFileGroupResponse
.newBuilder()
.setStatus(status.getValue)
@@ -757,6 +765,11 @@ object ControlMessages extends Logging {
}.asJava)
builder.addAllAttempts(attempts.map(Integer.valueOf).toIterable.asJava)
builder.addAllPartitionIds(partitionIds)
+ builder.putAllPushFailedBatches(
+ failedBatches.asScala.map {
+ case (uniqueId, pushFailedBatchSet) =>
+ (uniqueId, PbSerDeUtils.toPbPushFailedBatchSet(pushFailedBatchSet))
+ }.asJava)
val payload = builder.build().toByteArray
new TransportMessage(MessageType.GET_REDUCER_FILE_GROUP_RESPONSE, payload)
@@ -1141,7 +1154,11 @@ object ControlMessages extends Logging {
pbMapperEnd.getMapId,
pbMapperEnd.getAttemptId,
pbMapperEnd.getNumMappers,
- pbMapperEnd.getPartitionId)
+ pbMapperEnd.getPartitionId,
+ pbMapperEnd.getPushFailureBatchesMap.asScala.map {
+ case (partitionId, pushFailedBatchSet) =>
+ (partitionId, PbSerDeUtils.fromPbPushFailedBatchSet(pushFailedBatchSet))
+ }.toMap.asJava)
case MAPPER_END_RESPONSE_VALUE =>
val pbMapperEndResponse = PbMapperEndResponse.parseFrom(message.getPayload)
@@ -1177,11 +1194,16 @@ object ControlMessages extends Logging {
val attempts = pbGetReducerFileGroupResponse.getAttemptsList.asScala.map(_.toInt).toArray
val partitionIds = new util.HashSet(pbGetReducerFileGroupResponse.getPartitionIdsList)
+ val pushFailedBatches = pbGetReducerFileGroupResponse.getPushFailedBatchesMap.asScala.map {
+ case (uniqueId, pushFailedBatchSet) =>
+ (uniqueId, PbSerDeUtils.fromPbPushFailedBatchSet(pushFailedBatchSet))
+ }.toMap.asJava
GetReducerFileGroupResponse(
Utils.toStatusCode(pbGetReducerFileGroupResponse.getStatus),
fileGroup,
attempts,
- partitionIds)
+ partitionIds,
+ pushFailedBatches)
case GET_SHUFFLE_ID_VALUE =>
message.getParsedPayload()
diff --git a/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala b/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala
index 8a038242e75..553a4b9f9e0 100644
--- a/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala
@@ -32,6 +32,7 @@ import org.apache.celeborn.common.protocol.PartitionLocation.Mode
import org.apache.celeborn.common.protocol.message.ControlMessages.WorkerResource
import org.apache.celeborn.common.quota.ResourceConsumption
import org.apache.celeborn.common.util.{CollectionUtils => localCollectionUtils}
+import org.apache.celeborn.common.write.PushFailedBatch
object PbSerDeUtils {
@@ -524,6 +525,14 @@ object PbSerDeUtils {
pbPackedLocationsBuilder.addFilePaths("")
}
pbPackedLocationsBuilder.addAvailableStorageTypes(location.getStorageInfo.availableStorageTypes)
+ pbPackedLocationsBuilder.addFileSizes(location.getStorageInfo.getFileSize)
+ val chunkOffsets = PbChunkOffsets.newBuilder()
+ if (null != location.getStorageInfo.chunkOffsets && !location.getStorageInfo.chunkOffsets.isEmpty) {
+ chunkOffsets.addAllChunkOffset(location.getStorageInfo.chunkOffsets).build()
+ pbPackedLocationsBuilder.addChunksOffsets(chunkOffsets)
+ } else {
+ pbPackedLocationsBuilder.addChunksOffsets(chunkOffsets.build())
+ }
pbPackedLocationsBuilder.addModes(location.getMode.mode())
}
@@ -640,7 +649,9 @@ object PbSerDeUtils {
pbPackedPartitionLocations.getMountPoints(index)),
pbPackedPartitionLocations.getFinalResult(index),
filePath,
- pbPackedPartitionLocations.getAvailableStorageTypes(index)),
+ pbPackedPartitionLocations.getAvailableStorageTypes(index),
+ pbPackedPartitionLocations.getFileSizes(index),
+ pbPackedPartitionLocations.getChunksOffsets(index).getChunkOffsetList),
Utils.byteStringToRoaringBitmap(pbPackedPartitionLocations.getMapIdBitMap(index)))
}
@@ -670,4 +681,34 @@ object PbSerDeUtils {
}.asJava
}
+  def toPbPushFailedBatch(pushFailedBatch: PushFailedBatch): PbPushFailedBatch = {
+    // Copy the three identifying fields onto a fresh protobuf builder.
+    val pbBuilder = PbPushFailedBatch.newBuilder()
+    pbBuilder.setMapId(pushFailedBatch.getMapId)
+    pbBuilder.setAttemptId(pushFailedBatch.getAttemptId)
+    pbBuilder.setBatchId(pushFailedBatch.getBatchId).build()
+  }
+
+  def fromPbPushFailedBatch(pbPushFailedBatch: PbPushFailedBatch): PushFailedBatch = {
+    // Constructor arguments are passed as (mapId, attemptId, batchId).
+    val mapId = pbPushFailedBatch.getMapId
+    val attemptId = pbPushFailedBatch.getAttemptId
+    new PushFailedBatch(mapId, attemptId, pbPushFailedBatch.getBatchId)
+  }
+
+  def toPbPushFailedBatchSet(failedBatchSet: util.Set[PushFailedBatch]): PbPushFailedBatchSet = {
+    // Append the protobuf form of every batch, following the set's iteration order.
+    val pbSet = PbPushFailedBatchSet.newBuilder()
+    for (batch <- failedBatchSet.asScala) pbSet.addFailureBatches(toPbPushFailedBatch(batch))
+    pbSet.build()
+  }
+
+  def fromPbPushFailedBatchSet(pbFailedBatchSet: PbPushFailedBatchSet)
+      : util.Set[PushFailedBatch] = {
+    val restored = new util.HashSet[PushFailedBatch]() // mutable result, filled below
+    for (pbBatch <- pbFailedBatchSet.getFailureBatchesList.asScala) {
+      restored.add(fromPbPushFailedBatch(pbBatch))
+    }
+    restored
+  }
}
diff --git a/common/src/test/java/org/apache/celeborn/common/protocol/PartitionLocationSuiteJ.java b/common/src/test/java/org/apache/celeborn/common/protocol/PartitionLocationSuiteJ.java
index 6f63e487056..0d0613a33ec 100644
--- a/common/src/test/java/org/apache/celeborn/common/protocol/PartitionLocationSuiteJ.java
+++ b/common/src/test/java/org/apache/celeborn/common/protocol/PartitionLocationSuiteJ.java
@@ -209,7 +209,7 @@ public void testToStringOutput() {
+ " host-rpcPort-pushPort-fetchPort-replicatePort:localhost-3-1-2-4\n"
+ " mode:PRIMARY\n"
+ " peer:(empty)\n"
- + " storage hint:StorageInfo{type=MEMORY, mountPoint='', finalResult=false, filePath=null}\n"
+ + " storage hint:StorageInfo{type=MEMORY, mountPoint='', finalResult=false, filePath=null, fileSize=0, chunkOffsets=null}\n"
+ " mapIdBitMap:{}]";
String exp2 =
"PartitionLocation[\n"
@@ -217,7 +217,7 @@ public void testToStringOutput() {
+ " host-rpcPort-pushPort-fetchPort-replicatePort:localhost-3-1-2-4\n"
+ " mode:PRIMARY\n"
+ " peer:(host-rpcPort-pushPort-fetchPort-replicatePort:localhost-3-1-2-4)\n"
- + " storage hint:StorageInfo{type=MEMORY, mountPoint='', finalResult=false, filePath=null}\n"
+ + " storage hint:StorageInfo{type=MEMORY, mountPoint='', finalResult=false, filePath=null, fileSize=0, chunkOffsets=null}\n"
+ " mapIdBitMap:{}]";
String exp3 =
"PartitionLocation[\n"
@@ -225,7 +225,8 @@ public void testToStringOutput() {
+ " host-rpcPort-pushPort-fetchPort-replicatePort:localhost-3-1-2-4\n"
+ " mode:PRIMARY\n"
+ " peer:(host-rpcPort-pushPort-fetchPort-replicatePort:localhost-3-1-2-4)\n"
- + " storage hint:StorageInfo{type=MEMORY, mountPoint='/mnt/disk/0', finalResult=false, filePath=null}\n"
+ + " storage hint:StorageInfo{type=MEMORY, mountPoint='/mnt/disk/0', "
+ + "finalResult=false, filePath=null, fileSize=0, chunkOffsets=null}\n"
+ " mapIdBitMap:{1,2,3}]";
assertEquals(exp1, location1.toString());
assertEquals(exp2, location2.toString());
diff --git a/common/src/test/java/org/apache/celeborn/common/write/PushFailedBatchSuiteJ.java b/common/src/test/java/org/apache/celeborn/common/write/PushFailedBatchSuiteJ.java
new file mode 100644
index 00000000000..fcfc6b79979
--- /dev/null
+++ b/common/src/test/java/org/apache/celeborn/common/write/PushFailedBatchSuiteJ.java
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.write;
+
+import java.util.HashSet;
+import java.util.Set;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class PushFailedBatchSuiteJ {
+
+ @Test
+ public void equalsReturnsTrueForIdenticalBatches() {
+ PushFailedBatch batch1 = new PushFailedBatch(1, 2, 3);
+ PushFailedBatch batch2 = new PushFailedBatch(1, 2, 3);
+ Assert.assertEquals(batch1, batch2);
+ }
+
+ @Test
+ public void equalsReturnsFalseForDifferentBatches() {
+ PushFailedBatch batch1 = new PushFailedBatch(1, 2, 3);
+ PushFailedBatch batch2 = new PushFailedBatch(4, 5, 6);
+ Assert.assertNotEquals(batch1, batch2);
+ }
+
+ @Test
+ public void hashCodeDiffersForDifferentBatches() {
+ PushFailedBatch batch1 = new PushFailedBatch(1, 2, 3);
+ PushFailedBatch batch2 = new PushFailedBatch(4, 5, 6);
+ Assert.assertNotEquals(batch1.hashCode(), batch2.hashCode());
+ }
+
+ @Test
+ public void hashCodeSameForIdenticalBatches() {
+ PushFailedBatch batch1 = new PushFailedBatch(1, 2, 3);
+ PushFailedBatch batch2 = new PushFailedBatch(1, 2, 3);
+ Assert.assertEquals(batch1.hashCode(), batch2.hashCode());
+ }
+
+ @Test
+ public void hashCodeIsConsistent() {
+ PushFailedBatch batch = new PushFailedBatch(1, 2, 3);
+ int hashCode1 = batch.hashCode();
+ int hashCode2 = batch.hashCode();
+ Assert.assertEquals(hashCode1, hashCode2);
+ }
+
+ @Test
+ public void toStringReturnsExpectedFormat() {
+ PushFailedBatch batch = new PushFailedBatch(1, 2, 3);
+ String expected = "PushFailedBatch[mapId=1,attemptId=2,batchId=3]";
+ Assert.assertEquals(expected, batch.toString());
+ }
+
+  @Test
+  public void hashCodeAndEqualsWorkInSet() {
+    Set<PushFailedBatch> set = new HashSet<>(); // typed, no raw-type warnings
+    PushFailedBatch batch1 = new PushFailedBatch(1, 2, 3);
+    PushFailedBatch batch2 = new PushFailedBatch(1, 2, 3);
+    set.add(batch1);
+    Assert.assertTrue(set.contains(batch2)); // distinct instance, found via equals/hashCode
+  }
+}
diff --git a/common/src/test/scala/org/apache/celeborn/common/util/PbSerDeUtilsTest.scala b/common/src/test/scala/org/apache/celeborn/common/util/PbSerDeUtilsTest.scala
index 9e1d442fb24..c2c6e47338b 100644
--- a/common/src/test/scala/org/apache/celeborn/common/util/PbSerDeUtilsTest.scala
+++ b/common/src/test/scala/org/apache/celeborn/common/util/PbSerDeUtilsTest.scala
@@ -24,6 +24,7 @@ import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.util.Random
+import com.google.common.collect.{Lists, Sets}
import org.apache.hadoop.shaded.org.apache.commons.lang3.RandomStringUtils
import org.apache.celeborn.CelebornFunSuite
@@ -35,6 +36,7 @@ import org.apache.celeborn.common.protocol.message.{ControlMessages, StatusCode}
import org.apache.celeborn.common.protocol.message.ControlMessages.{GetReducerFileGroupResponse, WorkerResource}
import org.apache.celeborn.common.quota.ResourceConsumption
import org.apache.celeborn.common.util.PbSerDeUtils.{fromPbPackedPartitionLocationsPair, toPbPackedPartitionLocationsPair}
+import org.apache.celeborn.common.write.PushFailedBatch
class PbSerDeUtilsTest extends CelebornFunSuite {
@@ -172,6 +174,37 @@ class PbSerDeUtilsTest extends CelebornFunSuite {
27,
PartitionLocation.Mode.PRIMARY)
+ val partitionLocation5 =
+ new PartitionLocation(
+ 4,
+ 4,
+ "host5",
+ 50,
+ 49,
+ 48,
+ 47,
+ PartitionLocation.Mode.PRIMARY)
+ val partitionLocation6 =
+ new PartitionLocation(
+ 5,
+ 5,
+ "host6",
+ 60,
+ 59,
+ 58,
+ 57,
+ PartitionLocation.Mode.REPLICA,
+ null,
+ new StorageInfo(
+ StorageInfo.Type.HDD,
+ "",
+ false,
+ null,
+ StorageInfo.LOCAL_DISK_MASK,
+ 5,
+ null),
+ null)
+
val workerResource = new WorkerResource()
workerResource.put(
workerInfo1,
@@ -369,6 +402,70 @@ class PbSerDeUtilsTest extends CelebornFunSuite {
assert(partitionLocation3 == loc1)
}
+ test("testPackedPartitionLocationPairCase3") {
+ partitionLocation5.setStorageInfo(new StorageInfo(
+ StorageInfo.Type.HDD,
+ "",
+ false,
+ null,
+ StorageInfo.LOCAL_DISK_MASK,
+ 5,
+ Lists.newArrayList(0, 5, 10)))
+ partitionLocation5.setPeer(partitionLocation6)
+ val pairPb = PbSerDeUtils.toPbPackedPartitionLocationsPair(
+ List(partitionLocation5, partitionLocation6))
+ val rePb = PbSerDeUtils.fromPbPackedPartitionLocationsPair(pairPb)
+
+ val loc1 = rePb._1.get(0)
+ val loc2 = rePb._2.get(0)
+
+ assert(partitionLocation5 == loc1)
+ assert(partitionLocation6 == loc2)
+ assert(loc1.getStorageInfo.getFileSize == partitionLocation5.getStorageInfo.getFileSize)
+ assert(loc1.getStorageInfo.getChunkOffsets == partitionLocation5.getStorageInfo.getChunkOffsets)
+
+ assert(loc2.getStorageInfo.getFileSize == partitionLocation6.getStorageInfo.getFileSize)
+ assert(loc2.getStorageInfo.getChunkOffsets.isEmpty)
+ }
+
+ test("testPackedPartitionLocationPairCase4") {
+ partitionLocation5.setStorageInfo(new StorageInfo(
+ StorageInfo.Type.HDD,
+ "",
+ false,
+ null,
+ StorageInfo.LOCAL_DISK_MASK,
+ 5,
+ null))
+ val pairPb = PbSerDeUtils.toPbPackedPartitionLocationsPair(
+ List(partitionLocation5))
+ val rePb = PbSerDeUtils.fromPbPackedPartitionLocationsPair(pairPb)
+
+ val loc1 = rePb._1.get(0)
+
+ assert(partitionLocation5 == loc1)
+ assert(loc1.getStorageInfo.getFileSize == partitionLocation5.getStorageInfo.getFileSize)
+ assert(loc1.getStorageInfo.getChunkOffsets.isEmpty)
+ }
+
+ test("testPackedPartitionLocationPairCase5") {
+ partitionLocation5.setStorageInfo(new StorageInfo(
+ StorageInfo.Type.HDD,
+ "",
+ false,
+ null,
+ StorageInfo.LOCAL_DISK_MASK))
+ val pairPb = PbSerDeUtils.toPbPackedPartitionLocationsPair(
+ List(partitionLocation5))
+ val rePb = PbSerDeUtils.fromPbPackedPartitionLocationsPair(pairPb)
+
+ val loc1 = rePb._1.get(0)
+
+ assert(partitionLocation5 == loc1)
+ assert(loc1.getStorageInfo.getFileSize == partitionLocation5.getStorageInfo.getFileSize)
+ assert(loc1.getStorageInfo.getChunkOffsets.isEmpty)
+ }
+
test("testPackedPartitionLocationPairIPv6") {
val pairPb = PbSerDeUtils.toPbPackedPartitionLocationsPair(
List(partitionLocationIPv6))
@@ -565,4 +662,21 @@ class PbSerDeUtilsTest extends CelebornFunSuite {
locations.asScala.foreach(p => uniqueIds.remove(p.getUniqueId))
assert(uniqueIds.isEmpty)
}
+
+  test("fromAndToPushFailedBatch") {
+    // Round-trip a single batch through its protobuf representation.
+    val original = new PushFailedBatch(1, 1, 2)
+    val restored = PbSerDeUtils.fromPbPushFailedBatch(PbSerDeUtils.toPbPushFailedBatch(original))
+
+    assert(restored == original)
+  }
+
+  test("fromAndToPushFailedBatchSet") {
+    val original = Sets.newHashSet(new PushFailedBatch(1, 1, 2), new PushFailedBatch(2, 2, 3))
+    val restored =
+      PbSerDeUtils.fromPbPushFailedBatchSet(PbSerDeUtils.toPbPushFailedBatchSet(original))
+
+    assert(restored == original)
+  }
+
}
diff --git a/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala b/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala
index afc74707f62..03c6176edd8 100644
--- a/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala
+++ b/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala
@@ -18,6 +18,7 @@
package org.apache.celeborn.common.util
import java.util
+import java.util.Collections
import org.apache.celeborn.CelebornFunSuite
import org.apache.celeborn.common.CelebornConf
@@ -144,7 +145,7 @@ class UtilsSuite extends CelebornFunSuite {
}
test("MapperEnd class convert with pb") {
- val mapperEnd = MapperEnd(1, 1, 1, 2, 1)
+ val mapperEnd = MapperEnd(1, 1, 1, 2, 1, Collections.emptyMap())
val mapperEndTrans =
Utils.fromTransportMessage(Utils.toTransportMessage(mapperEnd)).asInstanceOf[MapperEnd]
assert(mapperEnd == mapperEndTrans)
diff --git a/docs/configuration/client.md b/docs/configuration/client.md
index c576e0c4d27..d9e1a700c77 100644
--- a/docs/configuration/client.md
+++ b/docs/configuration/client.md
@@ -19,6 +19,7 @@ license: |
| Key | Default | isDynamic | Description | Since | Deprecated |
| --- | ------- | --------- | ----------- | ----- | ---------- |
+| celeborn.client.adaptive.optimizeSkewedPartitionRead.enabled | false | false | If this is true, Celeborn will adaptively split skewed partitions instead of reading them by Spark map range. Please note that this feature requires the `Celeborn-Optimize-Skew-Partitions-spark3_3.patch`. | 0.6.0 | |
| celeborn.client.application.heartbeatInterval | 10s | false | Interval for client to send heartbeat message to master. | 0.3.0 | celeborn.application.heartbeatInterval |
| celeborn.client.application.unregister.enabled | true | false | When true, Celeborn client will inform celeborn master the application is already shutdown during client exit, this allows the cluster to release resources immediately, resulting in resource savings. | 0.3.2 | |
| celeborn.client.application.uuidSuffix.enabled | false | false | Whether to add UUID suffix for application id for unique. When `true`, add UUID suffix for unique application id. Currently, this only applies to Spark and MR. | 0.6.0 | |
diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerCommitFilesSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerCommitFilesSuite.scala
index b9b084b7b46..39b54b5f638 100644
--- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerCommitFilesSuite.scala
+++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerCommitFilesSuite.scala
@@ -17,12 +17,13 @@
package org.apache.celeborn.tests.client
+import java.nio.charset.StandardCharsets
import java.util
import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
-import org.apache.celeborn.client.{LifecycleManager, WithShuffleClientSuite}
+import org.apache.celeborn.client.{LifecycleManager, ShuffleClientImpl, WithShuffleClientSuite}
import org.apache.celeborn.client.LifecycleManager.ShuffleFailedWorkers
import org.apache.celeborn.client.commit.CommitFilesParam
import org.apache.celeborn.common.CelebornConf
@@ -228,6 +229,70 @@ class LifecycleManagerCommitFilesSuite extends WithShuffleClientSuite with MiniC
lifecycleManager.stop()
}
+ test("CELEBORN-1319: test commit files and check commit info") {
+ val shuffleId = nextShuffleId
+ val conf = celebornConf.clone
+ conf.set(CelebornConf.TEST_MOCK_COMMIT_FILES_FAILURE.key, "false")
+ val lifecycleManager: LifecycleManager = new LifecycleManager(APP, conf)
+ val shuffleClient = new ShuffleClientImpl(APP, conf, userIdentifier)
+ shuffleClient.setupLifecycleManagerRef(lifecycleManager.self)
+
+ val ids = new util.ArrayList[Integer](3)
+ 0 until 3 foreach {
+ ids.add(_)
+ }
+ val res = lifecycleManager.requestMasterRequestSlotsWithRetry(shuffleId, ids)
+ assert(res.status == StatusCode.SUCCESS)
+ assert(res.workerResource.keySet().size() == 3)
+
+ lifecycleManager.setupEndpoints(
+ res.workerResource.keySet,
+ shuffleId,
+ new ShuffleFailedWorkers())
+
+ lifecycleManager.reserveSlotsWithRetry(
+ shuffleId,
+ new util.HashSet(res.workerResource.keySet()),
+ res.workerResource,
+ updateEpoch = false)
+
+ lifecycleManager.commitManager.registerShuffle(shuffleId, 1, false)
+
+ val buffer = "hello world".getBytes(StandardCharsets.UTF_8)
+
+ var bufferLength = -1
+
+ 0 until 3 foreach { partitionId =>
+ bufferLength =
+ shuffleClient.pushData(shuffleId, 0, 0, partitionId, buffer, 0, buffer.length, 1, 3)
+ lifecycleManager.commitManager.finishMapperAttempt(shuffleId, 0, 0, 1, partitionId)
+ }
+
+ val commitHandler = lifecycleManager.commitManager.getCommitHandler(shuffleId)
+ val params = new ArrayBuffer[CommitFilesParam](res.workerResource.size())
+ res.workerResource.asScala.foreach { case (workerInfo, (primaryIds, replicaIds)) =>
+ params += CommitFilesParam(
+ workerInfo,
+ primaryIds.asScala.map(_.getUniqueId).toList.asJava,
+ replicaIds.asScala.map(_.getUniqueId).toList.asJava)
+ }
+
+ val shuffleCommittedInfo = lifecycleManager.commitManager.committedPartitionInfo.get(shuffleId)
+ commitHandler.doParallelCommitFiles(
+ shuffleId,
+ shuffleCommittedInfo,
+ params,
+ new ShuffleFailedWorkers)
+
+ shuffleCommittedInfo.committedReplicaStorageInfos.values().asScala.foreach { storageInfo =>
+ assert(storageInfo.fileSize == bufferLength)
+ // chunkOffsets contains 0 by default, and bufferFlushOffset
+ assert(storageInfo.chunkOffsets.size() == 2)
+ }
+
+ lifecycleManager.stop()
+ }
+
override def afterAll(): Unit = {
logInfo("all test complete , stop celeborn mini cluster")
shutdownMiniCluster()
diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/PushFailedBatchSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/PushFailedBatchSuite.scala
new file mode 100644
index 00000000000..716ad4e3693
--- /dev/null
+++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/PushFailedBatchSuite.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.tests.spark
+
+import com.google.common.collect.Sets
+import org.apache.spark.{SparkConf, SparkContext, SparkContextHelper}
+import org.apache.spark.shuffle.celeborn.SparkShuffleManager
+import org.junit.Assert
+import org.scalatest.BeforeAndAfterEach
+import org.scalatest.funsuite.AnyFunSuite
+
+import org.apache.celeborn.client.ShuffleClient
+import org.apache.celeborn.common.CelebornConf
+import org.apache.celeborn.common.protocol.ShuffleMode
+import org.apache.celeborn.common.write.PushFailedBatch
+import org.apache.celeborn.service.deploy.worker.PushDataHandler
+
+class PushFailedBatchSuite extends AnyFunSuite
+ with SparkTestBase
+ with BeforeAndAfterEach {
+
+ override def beforeAll(): Unit = {
+ val workerConf = Map(
+ CelebornConf.TEST_WORKER_PUSH_REPLICA_DATA_TIMEOUT.key -> "true")
+
+ setupMiniClusterWithRandomPorts(workerConf = workerConf, workerNum = 4)
+ }
+
+ override def beforeEach(): Unit = {
+ ShuffleClient.reset()
+ PushDataHandler.pushPrimaryDataTimeoutTested.set(false)
+ PushDataHandler.pushReplicaDataTimeoutTested.set(false)
+ PushDataHandler.pushPrimaryMergeDataTimeoutTested.set(false)
+ PushDataHandler.pushReplicaMergeDataTimeoutTested.set(false)
+ }
+
+  override protected def afterEach(): Unit = {
+    System.gc()
+  }
+
+  test("CELEBORN-1319: check failed batch info by making push timeout") {
+    val sparkConf = new SparkConf()
+      .set(s"spark.${CelebornConf.TEST_CLIENT_RETRY_REVIVE.key}", "false")
+      .set(s"spark.${CelebornConf.CLIENT_PUSH_REPLICATE_ENABLED.key}", "true")
+      .set(s"spark.${CelebornConf.CLIENT_PUSH_DATA_TIMEOUT.key}", "3s")
+      .set(
+        s"spark.${CelebornConf.CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ_ENABLED.key}",
+        "true")
+      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+      .set("spark.shuffle.manager", "org.apache.spark.shuffle.celeborn.SparkShuffleManager")
+      .setAppName("celeborn-1319")
+      .setMaster("local[2]")
+    updateSparkConf(sparkConf, ShuffleMode.HASH)
+    val sc = new SparkContext(sparkConf)
+    // Stop the context even when an assertion fails so it cannot leak into later tests.
+    try {
+      sc.parallelize(1 to 1, 1).repartition(1).map(i => i + 1).collect()
+      val manager = SparkContextHelper.env
+        .shuffleManager
+        .asInstanceOf[SparkShuffleManager]
+        .getLifecycleManager
+      // The job runs a single shuffle, so its shuffle id is 0 and the only PartitionLocation
+      // uniqueId is "0-0"; exactly one batch is expected to fail because of the push timeout.
+      val pushFailedBatch = manager.commitManager.getCommitHandler(0).getShuffleFailedBatches()
+      assert(!pushFailedBatch.isEmpty)
+      Assert.assertEquals(
+        Sets.newHashSet(new PushFailedBatch(0, 0, 1)),
+        pushFailedBatch.get(0).get("0-0"))
+    } finally {
+      sc.stop()
+    }
+  }
+
+}
diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SkewJoinSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SkewJoinSuite.scala
index 6cfbd227a94..f1a5d5e7769 100644
--- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SkewJoinSuite.scala
+++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SkewJoinSuite.scala
@@ -48,63 +48,69 @@ class SkewJoinSuite extends AnyFunSuite
}
CompressionCodec.values.foreach { codec =>
- test(s"celeborn spark integration test - skew join - $codec") {
- val sparkConf = new SparkConf().setAppName("celeborn-demo")
- .setMaster("local[2]")
- .set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")
- .set("spark.sql.adaptive.skewJoin.enabled", "true")
- .set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "16MB")
- .set("spark.sql.adaptive.advisoryPartitionSizeInBytes", "12MB")
- .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
- .set("spark.sql.adaptive.autoBroadcastJoinThreshold", "-1")
- .set(SQLConf.PARQUET_COMPRESSION.key, "gzip")
- .set(s"spark.${CelebornConf.SHUFFLE_COMPRESSION_CODEC.key}", codec.name)
- .set(s"spark.${CelebornConf.SHUFFLE_RANGE_READ_FILTER_ENABLED.key}", "true")
+ Seq(false, true).foreach { enabled =>
+ test(
+ s"celeborn spark integration test - skew join - with $codec - with client skew $enabled") {
+ val sparkConf = new SparkConf().setAppName("celeborn-demo")
+ .setMaster("local[2]")
+ .set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")
+ .set("spark.sql.adaptive.skewJoin.enabled", "true")
+ .set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "16MB")
+ .set("spark.sql.adaptive.advisoryPartitionSizeInBytes", "12MB")
+ .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+ .set("spark.sql.adaptive.autoBroadcastJoinThreshold", "-1")
+ .set(SQLConf.PARQUET_COMPRESSION.key, "gzip")
+ .set(s"spark.${CelebornConf.SHUFFLE_COMPRESSION_CODEC.key}", codec.name)
+ .set(s"spark.${CelebornConf.SHUFFLE_RANGE_READ_FILTER_ENABLED.key}", "true")
+ .set(
+ s"spark.${CelebornConf.CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ_ENABLED.key}",
+ s"$enabled")
- enableCeleborn(sparkConf)
+ enableCeleborn(sparkConf)
- val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate()
- if (sparkSession.version.startsWith("3")) {
- import sparkSession.implicits._
- val df = sparkSession.sparkContext.parallelize(1 to 120000, 8)
- .map(i => {
- val random = new Random()
- val oriKey = random.nextInt(64)
- val key = if (oriKey < 32) 1 else oriKey
- val fas = random.nextInt(1200000)
- val fa = Range(fas, fas + 100).mkString(",")
- val fbs = random.nextInt(1200000)
- val fb = Range(fbs, fbs + 100).mkString(",")
- val fcs = random.nextInt(1200000)
- val fc = Range(fcs, fcs + 100).mkString(",")
- val fds = random.nextInt(1200000)
- val fd = Range(fds, fds + 100).mkString(",")
+ val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate()
+ if (sparkSession.version.startsWith("3")) {
+ import sparkSession.implicits._
+ val df = sparkSession.sparkContext.parallelize(1 to 120000, 8)
+ .map(i => {
+ val random = new Random()
+ val oriKey = random.nextInt(64)
+ val key = if (oriKey < 32) 1 else oriKey
+ val fas = random.nextInt(1200000)
+ val fa = Range(fas, fas + 100).mkString(",")
+ val fbs = random.nextInt(1200000)
+ val fb = Range(fbs, fbs + 100).mkString(",")
+ val fcs = random.nextInt(1200000)
+ val fc = Range(fcs, fcs + 100).mkString(",")
+ val fds = random.nextInt(1200000)
+ val fd = Range(fds, fds + 100).mkString(",")
- (key, fa, fb, fc, fd)
- })
- .toDF("fa", "f1", "f2", "f3", "f4")
- df.createOrReplaceTempView("view1")
- val df2 = sparkSession.sparkContext.parallelize(1 to 8, 8)
- .map(i => {
- val random = new Random()
- val oriKey = random.nextInt(64)
- val key = if (oriKey < 32) 1 else oriKey
- val fas = random.nextInt(1200000)
- val fa = Range(fas, fas + 100).mkString(",")
- val fbs = random.nextInt(1200000)
- val fb = Range(fbs, fbs + 100).mkString(",")
- val fcs = random.nextInt(1200000)
- val fc = Range(fcs, fcs + 100).mkString(",")
- val fds = random.nextInt(1200000)
- val fd = Range(fds, fds + 100).mkString(",")
- (key, fa, fb, fc, fd)
- })
- .toDF("fb", "f6", "f7", "f8", "f9")
- df2.createOrReplaceTempView("view2")
- sparkSession.sql("drop table if exists fres")
- sparkSession.sql("create table fres using parquet as select * from view1 a inner join view2 b on a.fa=b.fb where a.fa=1 ")
- sparkSession.sql("drop table fres")
- sparkSession.stop()
+ (key, fa, fb, fc, fd)
+ })
+ .toDF("fa", "f1", "f2", "f3", "f4")
+ df.createOrReplaceTempView("view1")
+ val df2 = sparkSession.sparkContext.parallelize(1 to 8, 8)
+ .map(i => {
+ val random = new Random()
+ val oriKey = random.nextInt(64)
+ val key = if (oriKey < 32) 1 else oriKey
+ val fas = random.nextInt(1200000)
+ val fa = Range(fas, fas + 100).mkString(",")
+ val fbs = random.nextInt(1200000)
+ val fb = Range(fbs, fbs + 100).mkString(",")
+ val fcs = random.nextInt(1200000)
+ val fc = Range(fcs, fcs + 100).mkString(",")
+ val fds = random.nextInt(1200000)
+ val fd = Range(fds, fds + 100).mkString(",")
+ (key, fa, fb, fc, fd)
+ })
+ .toDF("fb", "f6", "f7", "f8", "f9")
+ df2.createOrReplaceTempView("view2")
+ sparkSession.sql("drop table if exists fres")
+ sparkSession.sql("create table fres using parquet as select * from view1 a inner join view2 b on a.fa=b.fb where a.fa=1 ")
+ sparkSession.sql("drop table fres")
+ sparkSession.stop()
+ }
}
}
}
diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala
index 65285f64729..b2e7c4e844c 100644
--- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala
+++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala
@@ -32,7 +32,7 @@ import org.roaringbitmap.RoaringBitmap
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.identity.UserIdentifier
import org.apache.celeborn.common.internal.Logging
-import org.apache.celeborn.common.meta.{WorkerInfo, WorkerPartitionLocationInfo}
+import org.apache.celeborn.common.meta.{ReduceFileMeta, WorkerInfo, WorkerPartitionLocationInfo}
import org.apache.celeborn.common.metrics.MetricsSystem
import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionSplitMode, PartitionType, StorageInfo}
import org.apache.celeborn.common.protocol.message.ControlMessages._
@@ -337,7 +337,21 @@ private[deploy] class Controller(
// Only HDFS can be null, means that this partition location is deleted.
logDebug(s"Location $uniqueId is deleted.")
} else {
- committedStorageInfos.put(uniqueId, fileWriter.getStorageInfo)
+ val storageInfo = fileWriter.getStorageInfo
+ val fileInfo =
+ if (null != fileWriter.getDiskFileInfo) {
+ fileWriter.getDiskFileInfo
+ } else {
+ fileWriter.getMemoryFileInfo
+ }
+ val fileMeta = fileInfo.getFileMeta
+ fileMeta match {
+ case meta: ReduceFileMeta =>
+ storageInfo.setFileSize(bytes)
+ storageInfo.setChunkOffsets(meta.getChunkOffsets)
+ case _ =>
+ }
+ committedStorageInfos.put(uniqueId, storageInfo)
if (fileWriter.getMapIdBitMap != null) {
committedMapIdBitMap.put(uniqueId, fileWriter.getMapIdBitMap)
}
diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
index ff32b940173..3c99639eb84 100644
--- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
+++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
@@ -249,8 +249,8 @@ class FetchHandler(
// 1. when the current request is a non-range openStream, but the original unsorted file
// has been deleted by another range's openStream request.
// 2. when the current request is a range openStream request.
- if ((endIndex != Int.MaxValue) || (endIndex == Int.MaxValue
- && !fileInfo.addStream(streamId))) {
+ if ((endIndex != Int.MaxValue && endIndex != -1 && endIndex >= startIndex) || (endIndex == Int.MaxValue && !fileInfo.addStream(
+ streamId))) {
fileInfo = partitionsSorter.getSortedFileInfo(
shuffleKey,
fileName,
diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala
index 0ff646b8f8b..e81593e1d99 100644
--- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala
+++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala
@@ -116,6 +116,8 @@ trait ReadWriteTestBase extends AnyFunSuite
null,
null,
null,
+ null,
+ null,
metricsCallback)
val outputStream = new ByteArrayOutputStream()