diff --git a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_2.patch b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_2.patch new file mode 100644 index 00000000000..0cb1fc812dc --- /dev/null +++ b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_2.patch @@ -0,0 +1,315 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +index e469c9989f2..a4a68ef1b09 100644 +--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala ++++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +@@ -33,6 +33,7 @@ import org.apache.commons.io.output.{ByteArrayOutputStream => ApacheByteArrayOut + import org.roaringbitmap.RoaringBitmap + + import org.apache.spark.broadcast.{Broadcast, BroadcastManager} ++import org.apache.spark.celeborn.CelebornShuffleState + import org.apache.spark.internal.Logging + import org.apache.spark.internal.config._ + import org.apache.spark.io.CompressionCodec +@@ -839,6 +840,7 @@ private[spark] class MapOutputTrackerMaster( + shuffleStatus.invalidateSerializedMergeOutputStatusCache() + } + } ++ CelebornShuffleState.unregisterCelebornSkewedShuffle(shuffleId) + } + + /** +diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala +index 0388c7b576b..59fdc81b09d 100644 +--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala ++++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala +@@ -32,6 +32,7 @@ import org.apache.hadoop.conf.Configuration + import org.apache.spark.annotation.DeveloperApi + import org.apache.spark.api.python.PythonWorkerFactory + import org.apache.spark.broadcast.BroadcastManager ++import org.apache.spark.celeborn.CelebornShuffleState + import org.apache.spark.internal.{config, Logging} + import org.apache.spark.internal.config._ + import org.apache.spark.memory.{MemoryManager, UnifiedMemoryManager} +@@ -414,6 +415,7 @@ object SparkEnv extends Logging { + if (isDriver) { + val sparkFilesDir = Utils.createTempDir(Utils.getLocalDir(conf), "userFiles").getAbsolutePath + envInstance.driverTmpDir = Some(sparkFilesDir) ++ CelebornShuffleState.init(envInstance) + } + + envInstance +diff --git a/core/src/main/scala/org/apache/spark/celeborn/CelebornShuffleState.scala b/core/src/main/scala/org/apache/spark/celeborn/CelebornShuffleState.scala +new file mode 100644 +index 00000000000..5e190c512df +--- /dev/null ++++ b/core/src/main/scala/org/apache/spark/celeborn/CelebornShuffleState.scala +@@ -0,0 +1,75 @@ ++/* ++ * Licensed to the Apache Software Foundation (ASF) under one or more ++ * contributor license agreements. 
See the NOTICE file distributed with ++ * this work for additional information regarding copyright ownership. ++ * The ASF licenses this file to You under the Apache License, Version 2.0 ++ * (the "License"); you may not use this file except in compliance with ++ * the License. You may obtain a copy of the License at ++ * ++ * http://www.apache.org/licenses/LICENSE-2.0 ++ * ++ * Unless required by applicable law or agreed to in writing, software ++ * distributed under the License is distributed on an "AS IS" BASIS, ++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++ * See the License for the specific language governing permissions and ++ * limitations under the License. ++ */ ++ ++package org.apache.spark.celeborn ++ ++import java.util.concurrent.ConcurrentHashMap ++import java.util.concurrent.atomic.AtomicBoolean ++ ++import org.apache.spark.SparkEnv ++import org.apache.spark.internal.config.ConfigBuilder ++ ++object CelebornShuffleState { ++ ++ private val CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ = ++ ConfigBuilder("spark.celeborn.client.adaptive.optimizeSkewedPartitionRead.enabled") ++ .booleanConf ++ .createWithDefault(false) ++ ++ private val CELEBORN_STAGE_RERUN_ENABLED = ++ ConfigBuilder("spark.celeborn.client.spark.stageRerun.enabled") ++ .withAlternative("spark.celeborn.client.spark.fetch.throwsFetchFailure") ++ .booleanConf ++ .createWithDefault(false) ++ ++ private val celebornOptimizeSkewedPartitionReadEnabled = new AtomicBoolean() ++ private val stageRerunEnabled = new AtomicBoolean() ++ private val skewShuffleIds = ConcurrentHashMap.newKeySet[Int]() ++ ++ // call this from SparkEnv.create ++ def init(env: SparkEnv): Unit = { ++ // cleanup existing state (if required) - and initialize ++ skewShuffleIds.clear() ++ ++ // use env.conf for all initialization, and not SQLConf ++ celebornOptimizeSkewedPartitionReadEnabled.set( ++ env.conf.get("spark.shuffle.manager", "sort").contains("celeborn") && ++ env.conf.get(CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ)) ++ stageRerunEnabled.set(env.conf.get(CELEBORN_STAGE_RERUN_ENABLED)) ++ } ++ ++ def unregisterCelebornSkewedShuffle(shuffleId: Int): Unit = { ++ skewShuffleIds.remove(shuffleId) ++ } ++ ++ def registerCelebornSkewedShuffle(shuffleId: Int): Unit = { ++ skewShuffleIds.add(shuffleId) ++ } ++ ++ def isCelebornSkewedShuffle(shuffleId: Int): Boolean = { ++ skewShuffleIds.contains(shuffleId) ++ } ++ ++ def celebornAdaptiveOptimizeSkewedPartitionReadEnabled: Boolean = { ++ celebornOptimizeSkewedPartitionReadEnabled.get() ++ } ++ ++ def celebornStageRerunEnabled: Boolean = { ++ stageRerunEnabled.get() ++ } ++ ++} +diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +index b950c07f3d8..2cb430c3c3d 100644 +--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala ++++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +@@ -33,6 +33,7 @@ import com.google.common.util.concurrent.{Futures, SettableFuture} + + import org.apache.spark._ + import org.apache.spark.broadcast.Broadcast ++import org.apache.spark.celeborn.CelebornShuffleState + import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} + import org.apache.spark.internal.Logging + import org.apache.spark.internal.config +@@ -1780,7 +1781,7 @@ private[spark] class DAGScheduler( + failedStage.failedAttemptIds.add(task.stageAttemptId) + val shouldAbortStage = + failedStage.failedAttemptIds.size >= 
maxConsecutiveStageAttempts || +- disallowStageRetryForTest ++ disallowStageRetryForTest || CelebornShuffleState.isCelebornSkewedShuffle(shuffleId) + + // It is likely that we receive multiple FetchFailed for a single stage (because we have + // multiple tasks running concurrently on different executors). In that case, it is +diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala +new file mode 100644 +index 00000000000..3dc60678461 +--- /dev/null ++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala +@@ -0,0 +1,35 @@ ++/* ++ * Licensed to the Apache Software Foundation (ASF) under one or more ++ * contributor license agreements. See the NOTICE file distributed with ++ * this work for additional information regarding copyright ownership. ++ * The ASF licenses this file to You under the Apache License, Version 2.0 ++ * (the "License"); you may not use this file except in compliance with ++ * the License. You may obtain a copy of the License at ++ * ++ * http://www.apache.org/licenses/LICENSE-2.0 ++ * ++ * Unless required by applicable law or agreed to in writing, software ++ * distributed under the License is distributed on an "AS IS" BASIS, ++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++ * See the License for the specific language governing permissions and ++ * limitations under the License. ++ */ ++ ++package org.apache.spark.sql.execution.adaptive ++ ++import java.util.Locale ++ ++import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike} ++ ++object CelebornShuffleUtil { ++ ++ def isCelebornShuffle(shuffleExchangeLike: ShuffleExchangeLike): Boolean = { ++ shuffleExchangeLike match { ++ case exec: ShuffleExchangeExec => ++ exec.shuffleDependency.shuffleHandle ++ .getClass.getName.toLowerCase(Locale.ROOT).contains("celeborn") ++ case _ => false ++ } ++ } ++ ++} +diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala +index 1752907a9a5..2c6a49b78eb 100644 +--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala ++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala +@@ -50,12 +50,13 @@ object OptimizeSkewInRebalancePartitions extends AQEShuffleReadRule { + private def optimizeSkewedPartitions( + shuffleId: Int, + bytesByPartitionId: Array[Long], +- targetSize: Long): Seq[ShufflePartitionSpec] = { ++ targetSize: Long, ++ isCelebornShuffle: Boolean = false): Seq[ShufflePartitionSpec] = { + bytesByPartitionId.indices.flatMap { reduceIndex => + val bytes = bytesByPartitionId(reduceIndex) + if (bytes > targetSize) { + val newPartitionSpec = +- ShufflePartitionsUtil.createSkewPartitionSpecs(shuffleId, reduceIndex, targetSize) ++ ShufflePartitionsUtil.createSkewPartitionSpecs(shuffleId, reduceIndex, targetSize, isCelebornShuffle) + if (newPartitionSpec.isEmpty) { + CoalescedPartitionSpec(reduceIndex, reduceIndex + 1, bytes) :: Nil + } else { +@@ -77,8 +78,9 @@ object OptimizeSkewInRebalancePartitions extends AQEShuffleReadRule { + return shuffle + } + ++ val isCelebornShuffle = CelebornShuffleUtil.isCelebornShuffle(shuffle.shuffle) + val newPartitionsSpec = 
optimizeSkewedPartitions( +- mapStats.get.shuffleId, mapStats.get.bytesByPartitionId, advisorySize) ++ mapStats.get.shuffleId, mapStats.get.bytesByPartitionId, advisorySize, isCelebornShuffle) + // return origin plan if we can not optimize partitions + if (newPartitionsSpec.length == mapStats.get.bytesByPartitionId.length) { + shuffle +diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala +index 88abe68197b..150699a84a3 100644 +--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala ++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala +@@ -157,8 +157,10 @@ object OptimizeSkewedJoin extends AQEShuffleReadRule { + Seq(CoalescedPartitionSpec(partitionIndex, partitionIndex + 1, rightSize)) + + val leftParts = if (isLeftSkew) { ++ val isCelebornShuffle = CelebornShuffleUtil.isCelebornShuffle(left.shuffle) + val skewSpecs = ShufflePartitionsUtil.createSkewPartitionSpecs( +- left.mapStats.get.shuffleId, partitionIndex, leftTargetSize) ++ left.mapStats.get.shuffleId, partitionIndex, leftTargetSize, ++ isCelebornShuffle = isCelebornShuffle) + if (skewSpecs.isDefined) { + logDebug(s"Left side partition $partitionIndex " + + s"(${FileUtils.byteCountToDisplaySize(leftSize)}) is skewed, " + +@@ -171,8 +173,10 @@ object OptimizeSkewedJoin extends AQEShuffleReadRule { + } + + val rightParts = if (isRightSkew) { ++ val isCelebornShuffle = CelebornShuffleUtil.isCelebornShuffle(right.shuffle) + val skewSpecs = ShufflePartitionsUtil.createSkewPartitionSpecs( +- right.mapStats.get.shuffleId, partitionIndex, rightTargetSize) ++ right.mapStats.get.shuffleId, partitionIndex, rightTargetSize, ++ isCelebornShuffle = isCelebornShuffle) + if (skewSpecs.isDefined) { + logDebug(s"Right side partition $partitionIndex " + + s"(${FileUtils.byteCountToDisplaySize(rightSize)}) is skewed, " + +diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala +index 3609548f374..d34f43bf064 100644 +--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala ++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala +@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.adaptive + import scala.collection.mutable.ArrayBuffer + + import org.apache.spark.{MapOutputStatistics, MapOutputTrackerMaster, SparkEnv} ++import org.apache.spark.celeborn.CelebornShuffleState + import org.apache.spark.internal.Logging + import org.apache.spark.sql.execution.{CoalescedPartitionSpec, PartialReducerPartitionSpec, ShufflePartitionSpec} + +@@ -376,11 +377,20 @@ object ShufflePartitionsUtil extends Logging { + def createSkewPartitionSpecs( + shuffleId: Int, + reducerId: Int, +- targetSize: Long): Option[Seq[PartialReducerPartitionSpec]] = { ++ targetSize: Long, ++ isCelebornShuffle: Boolean = false): Option[Seq[PartialReducerPartitionSpec]] = { + val mapPartitionSizes = getMapSizesForReduceId(shuffleId, reducerId) + if (mapPartitionSizes.exists(_ < 0)) return None + val mapStartIndices = splitSizeListByTargetSize(mapPartitionSizes, targetSize) + if (mapStartIndices.length > 1) { ++ val celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled = ++ CelebornShuffleState.celebornAdaptiveOptimizeSkewedPartitionReadEnabled && 
isCelebornShuffle
++
++      val throwsFetchFailure = CelebornShuffleState.celebornStageRerunEnabled
++      if (throwsFetchFailure && celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) {
++        logInfo(s"Celeborn shuffle retry enabled and shuffle $shuffleId is skewed")
++        CelebornShuffleState.registerCelebornSkewedShuffle(shuffleId)
++      }
+       Some(mapStartIndices.indices.map { i =>
+         val startMapIndex = mapStartIndices(i)
+         val endMapIndex = if (i == mapStartIndices.length - 1) {
+@@ -388,8 +398,21 @@
+         } else {
+           mapStartIndices(i + 1)
+         }
+-        val dataSize = startMapIndex.until(endMapIndex).map(mapPartitionSizes(_)).sum
+-        PartialReducerPartitionSpec(reducerId, startMapIndex, endMapIndex, dataSize)
++        var dataSize = 0L
++        var mapIndex = startMapIndex
++        while (mapIndex < endMapIndex) {
++          dataSize += mapPartitionSizes(mapIndex)
++          mapIndex += 1
++        }
++
++        if (celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) {
++          // This `dataSize` may not be accurate: when the Celeborn optimize skewed partition
++          // read feature is enabled, the spec carries the sub-partition count and index rather
++          // than a real map-index range. Do not use this `dataSize` elsewhere in the codebase.
++          PartialReducerPartitionSpec(reducerId, mapStartIndices.length, i, dataSize)
++        } else {
++          PartialReducerPartitionSpec(reducerId, startMapIndex, endMapIndex, dataSize)
++        }
+       })
+     } else {
+       None
diff --git a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch
new file mode 100644
index 00000000000..f8e38615c6a
--- /dev/null
+++ b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch
@@ -0,0 +1,312 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +index b1974948430..a045c8646ba 100644 +--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala ++++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +@@ -33,6 +33,7 @@ import org.apache.commons.io.output.{ByteArrayOutputStream => ApacheByteArrayOut + import org.roaringbitmap.RoaringBitmap + + import org.apache.spark.broadcast.{Broadcast, BroadcastManager} ++import org.apache.spark.celeborn.CelebornShuffleState + import org.apache.spark.internal.Logging + import org.apache.spark.internal.config._ + import org.apache.spark.io.CompressionCodec +@@ -886,6 +887,7 @@ private[spark] class MapOutputTrackerMaster( + shuffleStatus.invalidateSerializedMergeOutputStatusCache() + } + } ++ CelebornShuffleState.unregisterCelebornSkewedShuffle(shuffleId) + } + + /** +diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala +index 19467e7eca1..0ae4990219c 100644 +--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala ++++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala +@@ -32,6 +32,7 @@ import org.apache.hadoop.conf.Configuration + import org.apache.spark.annotation.DeveloperApi + import org.apache.spark.api.python.PythonWorkerFactory + import org.apache.spark.broadcast.BroadcastManager ++import org.apache.spark.celeborn.CelebornShuffleState + import org.apache.spark.executor.ExecutorBackend + import org.apache.spark.internal.{config, Logging} + import org.apache.spark.internal.config._ +@@ -419,6 +420,7 @@ object SparkEnv extends Logging { + if (isDriver) { + val sparkFilesDir = Utils.createTempDir(Utils.getLocalDir(conf), "userFiles").getAbsolutePath + envInstance.driverTmpDir = Some(sparkFilesDir) ++ CelebornShuffleState.init(envInstance) + } + + envInstance +diff --git a/core/src/main/scala/org/apache/spark/celeborn/CelebornShuffleState.scala b/core/src/main/scala/org/apache/spark/celeborn/CelebornShuffleState.scala +new file mode 100644 +index 00000000000..5e190c512df +--- /dev/null ++++ b/core/src/main/scala/org/apache/spark/celeborn/CelebornShuffleState.scala +@@ -0,0 +1,75 @@ ++/* ++ * Licensed to the Apache Software Foundation (ASF) under one or more ++ * contributor license agreements. See the NOTICE file distributed with ++ * this work for additional information regarding copyright ownership. ++ * The ASF licenses this file to You under the Apache License, Version 2.0 ++ * (the "License"); you may not use this file except in compliance with ++ * the License. You may obtain a copy of the License at ++ * ++ * http://www.apache.org/licenses/LICENSE-2.0 ++ * ++ * Unless required by applicable law or agreed to in writing, software ++ * distributed under the License is distributed on an "AS IS" BASIS, ++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++ * See the License for the specific language governing permissions and ++ * limitations under the License. 
++ */ ++ ++package org.apache.spark.celeborn ++ ++import java.util.concurrent.ConcurrentHashMap ++import java.util.concurrent.atomic.AtomicBoolean ++ ++import org.apache.spark.SparkEnv ++import org.apache.spark.internal.config.ConfigBuilder ++ ++object CelebornShuffleState { ++ ++ private val CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ = ++ ConfigBuilder("spark.celeborn.client.adaptive.optimizeSkewedPartitionRead.enabled") ++ .booleanConf ++ .createWithDefault(false) ++ ++ private val CELEBORN_STAGE_RERUN_ENABLED = ++ ConfigBuilder("spark.celeborn.client.spark.stageRerun.enabled") ++ .withAlternative("spark.celeborn.client.spark.fetch.throwsFetchFailure") ++ .booleanConf ++ .createWithDefault(false) ++ ++ private val celebornOptimizeSkewedPartitionReadEnabled = new AtomicBoolean() ++ private val stageRerunEnabled = new AtomicBoolean() ++ private val skewShuffleIds = ConcurrentHashMap.newKeySet[Int]() ++ ++ // call this from SparkEnv.create ++ def init(env: SparkEnv): Unit = { ++ // cleanup existing state (if required) - and initialize ++ skewShuffleIds.clear() ++ ++ // use env.conf for all initialization, and not SQLConf ++ celebornOptimizeSkewedPartitionReadEnabled.set( ++ env.conf.get("spark.shuffle.manager", "sort").contains("celeborn") && ++ env.conf.get(CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ)) ++ stageRerunEnabled.set(env.conf.get(CELEBORN_STAGE_RERUN_ENABLED)) ++ } ++ ++ def unregisterCelebornSkewedShuffle(shuffleId: Int): Unit = { ++ skewShuffleIds.remove(shuffleId) ++ } ++ ++ def registerCelebornSkewedShuffle(shuffleId: Int): Unit = { ++ skewShuffleIds.add(shuffleId) ++ } ++ ++ def isCelebornSkewedShuffle(shuffleId: Int): Boolean = { ++ skewShuffleIds.contains(shuffleId) ++ } ++ ++ def celebornAdaptiveOptimizeSkewedPartitionReadEnabled: Boolean = { ++ celebornOptimizeSkewedPartitionReadEnabled.get() ++ } ++ ++ def celebornStageRerunEnabled: Boolean = { ++ stageRerunEnabled.get() ++ } ++ ++} +diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +index bd2823bcac1..d0c88081527 100644 +--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala ++++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +@@ -33,6 +33,7 @@ import com.google.common.util.concurrent.{Futures, SettableFuture} + + import org.apache.spark._ + import org.apache.spark.broadcast.Broadcast ++import org.apache.spark.celeborn.CelebornShuffleState + import org.apache.spark.errors.SparkCoreErrors + import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} + import org.apache.spark.internal.Logging +@@ -1851,7 +1852,7 @@ private[spark] class DAGScheduler( + failedStage.failedAttemptIds.add(task.stageAttemptId) + val shouldAbortStage = + failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts || +- disallowStageRetryForTest ++ disallowStageRetryForTest || CelebornShuffleState.isCelebornSkewedShuffle(shuffleId) + + // It is likely that we receive multiple FetchFailed for a single stage (because we have + // multiple tasks running concurrently on different executors). 
In that case, it is +diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala +new file mode 100644 +index 00000000000..3dc60678461 +--- /dev/null ++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala +@@ -0,0 +1,35 @@ ++/* ++ * Licensed to the Apache Software Foundation (ASF) under one or more ++ * contributor license agreements. See the NOTICE file distributed with ++ * this work for additional information regarding copyright ownership. ++ * The ASF licenses this file to You under the Apache License, Version 2.0 ++ * (the "License"); you may not use this file except in compliance with ++ * the License. You may obtain a copy of the License at ++ * ++ * http://www.apache.org/licenses/LICENSE-2.0 ++ * ++ * Unless required by applicable law or agreed to in writing, software ++ * distributed under the License is distributed on an "AS IS" BASIS, ++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++ * See the License for the specific language governing permissions and ++ * limitations under the License. ++ */ ++ ++package org.apache.spark.sql.execution.adaptive ++ ++import java.util.Locale ++ ++import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike} ++ ++object CelebornShuffleUtil { ++ ++ def isCelebornShuffle(shuffleExchangeLike: ShuffleExchangeLike): Boolean = { ++ shuffleExchangeLike match { ++ case exec: ShuffleExchangeExec => ++ exec.shuffleDependency.shuffleHandle ++ .getClass.getName.toLowerCase(Locale.ROOT).contains("celeborn") ++ case _ => false ++ } ++ } ++ ++} +diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala +index b34ab3e380b..cb0ed9d05a4 100644 +--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala ++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala +@@ -47,14 +47,15 @@ object OptimizeSkewInRebalancePartitions extends AQEShuffleReadRule { + private def optimizeSkewedPartitions( + shuffleId: Int, + bytesByPartitionId: Array[Long], +- targetSize: Long): Seq[ShufflePartitionSpec] = { ++ targetSize: Long, ++ isCelebornShuffle: Boolean = false): Seq[ShufflePartitionSpec] = { + val smallPartitionFactor = + conf.getConf(SQLConf.ADAPTIVE_REBALANCE_PARTITIONS_SMALL_PARTITION_FACTOR) + bytesByPartitionId.indices.flatMap { reduceIndex => + val bytes = bytesByPartitionId(reduceIndex) + if (bytes > targetSize) { + val newPartitionSpec = ShufflePartitionsUtil.createSkewPartitionSpecs( +- shuffleId, reduceIndex, targetSize, smallPartitionFactor) ++ shuffleId, reduceIndex, targetSize, smallPartitionFactor, isCelebornShuffle) + if (newPartitionSpec.isEmpty) { + CoalescedPartitionSpec(reduceIndex, reduceIndex + 1, bytes) :: Nil + } else { +@@ -76,8 +77,9 @@ object OptimizeSkewInRebalancePartitions extends AQEShuffleReadRule { + return shuffle + } + ++ val isCelebornShuffle = CelebornShuffleUtil.isCelebornShuffle(shuffle.shuffle) + val newPartitionsSpec = optimizeSkewedPartitions( +- mapStats.get.shuffleId, mapStats.get.bytesByPartitionId, advisorySize) ++ mapStats.get.shuffleId, mapStats.get.bytesByPartitionId, advisorySize, isCelebornShuffle) + // return origin plan if we can 
not optimize partitions + if (newPartitionsSpec.length == mapStats.get.bytesByPartitionId.length) { + shuffle +diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala +index d4a173bb9cc..21ef335e064 100644 +--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala ++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala +@@ -152,8 +152,10 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) + Seq(CoalescedPartitionSpec(partitionIndex, partitionIndex + 1, rightSize)) + + val leftParts = if (isLeftSkew) { ++ val isCelebornShuffle = CelebornShuffleUtil.isCelebornShuffle(left.shuffle) + val skewSpecs = ShufflePartitionsUtil.createSkewPartitionSpecs( +- left.mapStats.get.shuffleId, partitionIndex, leftTargetSize) ++ left.mapStats.get.shuffleId, partitionIndex, leftTargetSize, ++ isCelebornShuffle = isCelebornShuffle) + if (skewSpecs.isDefined) { + logDebug(s"Left side partition $partitionIndex " + + s"(${FileUtils.byteCountToDisplaySize(leftSize)}) is skewed, " + +@@ -166,8 +168,10 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) + } + + val rightParts = if (isRightSkew) { ++ val isCelebornShuffle = CelebornShuffleUtil.isCelebornShuffle(right.shuffle) + val skewSpecs = ShufflePartitionsUtil.createSkewPartitionSpecs( +- right.mapStats.get.shuffleId, partitionIndex, rightTargetSize) ++ right.mapStats.get.shuffleId, partitionIndex, rightTargetSize, ++ isCelebornShuffle = isCelebornShuffle) + if (skewSpecs.isDefined) { + logDebug(s"Right side partition $partitionIndex " + + s"(${FileUtils.byteCountToDisplaySize(rightSize)}) is skewed, " + +diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala +index af689db3379..39d0b3132ee 100644 +--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala ++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala +@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.adaptive + import scala.collection.mutable.ArrayBuffer + + import org.apache.spark.{MapOutputStatistics, MapOutputTrackerMaster, SparkEnv} ++import org.apache.spark.celeborn.CelebornShuffleState + import org.apache.spark.internal.Logging + import org.apache.spark.sql.execution.{CoalescedPartitionSpec, PartialReducerPartitionSpec, ShufflePartitionSpec} + +@@ -380,13 +381,21 @@ object ShufflePartitionsUtil extends Logging { + shuffleId: Int, + reducerId: Int, + targetSize: Long, +- smallPartitionFactor: Double = SMALL_PARTITION_FACTOR) +- : Option[Seq[PartialReducerPartitionSpec]] = { ++ smallPartitionFactor: Double = SMALL_PARTITION_FACTOR, ++ isCelebornShuffle: Boolean = false): Option[Seq[PartialReducerPartitionSpec]] = { + val mapPartitionSizes = getMapSizesForReduceId(shuffleId, reducerId) + if (mapPartitionSizes.exists(_ < 0)) return None + val mapStartIndices = splitSizeListByTargetSize( + mapPartitionSizes, targetSize, smallPartitionFactor) + if (mapStartIndices.length > 1) { ++ val celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled = ++ CelebornShuffleState.celebornAdaptiveOptimizeSkewedPartitionReadEnabled && isCelebornShuffle ++ ++ val throwsFetchFailure = CelebornShuffleState.celebornStageRerunEnabled ++ if 
(throwsFetchFailure && celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) {
++        logInfo(s"Celeborn shuffle retry enabled and shuffle $shuffleId is skewed")
++        CelebornShuffleState.registerCelebornSkewedShuffle(shuffleId)
++      }
+       Some(mapStartIndices.indices.map { i =>
+         val startMapIndex = mapStartIndices(i)
+         val endMapIndex = if (i == mapStartIndices.length - 1) {
+@@ -400,7 +409,15 @@
+           dataSize += mapPartitionSizes(mapIndex)
+           mapIndex += 1
+         }
+-        PartialReducerPartitionSpec(reducerId, startMapIndex, endMapIndex, dataSize)
++
++        if (celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) {
++          // This `dataSize` may not be accurate: when the Celeborn optimize skewed partition
++          // read feature is enabled, the spec carries the sub-partition count and index rather
++          // than a real map-index range. Do not use this `dataSize` elsewhere in the codebase.
++          PartialReducerPartitionSpec(reducerId, mapStartIndices.length, i, dataSize)
++        } else {
++          PartialReducerPartitionSpec(reducerId, startMapIndex, endMapIndex, dataSize)
++        }
+       })
+     } else {
+       None
diff --git a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_4.patch b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_4.patch
new file mode 100644
index 00000000000..9aed835fe96
--- /dev/null
+++ b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_4.patch
@@ -0,0 +1,312 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +index fade0b86dd8..ca0940a9251 100644 +--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala ++++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +@@ -34,6 +34,7 @@ import org.apache.commons.io.output.{ByteArrayOutputStream => ApacheByteArrayOut + import org.roaringbitmap.RoaringBitmap + + import org.apache.spark.broadcast.{Broadcast, BroadcastManager} ++import org.apache.spark.celeborn.CelebornShuffleState + import org.apache.spark.internal.Logging + import org.apache.spark.internal.config._ + import org.apache.spark.io.CompressionCodec +@@ -887,6 +888,7 @@ private[spark] class MapOutputTrackerMaster( + shuffleStatus.invalidateSerializedMergeOutputStatusCache() + } + } ++ CelebornShuffleState.unregisterCelebornSkewedShuffle(shuffleId) + } + + /** +diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala +index acab9a634fc..23eb72c49ac 100644 +--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala ++++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala +@@ -32,6 +32,7 @@ import org.apache.hadoop.conf.Configuration + import org.apache.spark.annotation.DeveloperApi + import org.apache.spark.api.python.PythonWorkerFactory + import org.apache.spark.broadcast.BroadcastManager ++import org.apache.spark.celeborn.CelebornShuffleState + import org.apache.spark.executor.ExecutorBackend + import org.apache.spark.internal.{config, Logging} + import org.apache.spark.internal.config._ +@@ -419,6 +420,7 @@ object SparkEnv extends Logging { + if (isDriver) { + val sparkFilesDir = Utils.createTempDir(Utils.getLocalDir(conf), "userFiles").getAbsolutePath + envInstance.driverTmpDir = Some(sparkFilesDir) ++ CelebornShuffleState.init(envInstance) + } + + envInstance +diff --git a/core/src/main/scala/org/apache/spark/celeborn/CelebornShuffleState.scala b/core/src/main/scala/org/apache/spark/celeborn/CelebornShuffleState.scala +new file mode 100644 +index 00000000000..5e190c512df +--- /dev/null ++++ b/core/src/main/scala/org/apache/spark/celeborn/CelebornShuffleState.scala +@@ -0,0 +1,75 @@ ++/* ++ * Licensed to the Apache Software Foundation (ASF) under one or more ++ * contributor license agreements. See the NOTICE file distributed with ++ * this work for additional information regarding copyright ownership. ++ * The ASF licenses this file to You under the Apache License, Version 2.0 ++ * (the "License"); you may not use this file except in compliance with ++ * the License. You may obtain a copy of the License at ++ * ++ * http://www.apache.org/licenses/LICENSE-2.0 ++ * ++ * Unless required by applicable law or agreed to in writing, software ++ * distributed under the License is distributed on an "AS IS" BASIS, ++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++ * See the License for the specific language governing permissions and ++ * limitations under the License. 
++ */ ++ ++package org.apache.spark.celeborn ++ ++import java.util.concurrent.ConcurrentHashMap ++import java.util.concurrent.atomic.AtomicBoolean ++ ++import org.apache.spark.SparkEnv ++import org.apache.spark.internal.config.ConfigBuilder ++ ++object CelebornShuffleState { ++ ++ private val CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ = ++ ConfigBuilder("spark.celeborn.client.adaptive.optimizeSkewedPartitionRead.enabled") ++ .booleanConf ++ .createWithDefault(false) ++ ++ private val CELEBORN_STAGE_RERUN_ENABLED = ++ ConfigBuilder("spark.celeborn.client.spark.stageRerun.enabled") ++ .withAlternative("spark.celeborn.client.spark.fetch.throwsFetchFailure") ++ .booleanConf ++ .createWithDefault(false) ++ ++ private val celebornOptimizeSkewedPartitionReadEnabled = new AtomicBoolean() ++ private val stageRerunEnabled = new AtomicBoolean() ++ private val skewShuffleIds = ConcurrentHashMap.newKeySet[Int]() ++ ++ // call this from SparkEnv.create ++ def init(env: SparkEnv): Unit = { ++ // cleanup existing state (if required) - and initialize ++ skewShuffleIds.clear() ++ ++ // use env.conf for all initialization, and not SQLConf ++ celebornOptimizeSkewedPartitionReadEnabled.set( ++ env.conf.get("spark.shuffle.manager", "sort").contains("celeborn") && ++ env.conf.get(CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ)) ++ stageRerunEnabled.set(env.conf.get(CELEBORN_STAGE_RERUN_ENABLED)) ++ } ++ ++ def unregisterCelebornSkewedShuffle(shuffleId: Int): Unit = { ++ skewShuffleIds.remove(shuffleId) ++ } ++ ++ def registerCelebornSkewedShuffle(shuffleId: Int): Unit = { ++ skewShuffleIds.add(shuffleId) ++ } ++ ++ def isCelebornSkewedShuffle(shuffleId: Int): Boolean = { ++ skewShuffleIds.contains(shuffleId) ++ } ++ ++ def celebornAdaptiveOptimizeSkewedPartitionReadEnabled: Boolean = { ++ celebornOptimizeSkewedPartitionReadEnabled.get() ++ } ++ ++ def celebornStageRerunEnabled: Boolean = { ++ stageRerunEnabled.get() ++ } ++ ++} +diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +index 26be8c72bbc..81feaba962c 100644 +--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala ++++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +@@ -34,6 +34,7 @@ import com.google.common.util.concurrent.{Futures, SettableFuture} + + import org.apache.spark._ + import org.apache.spark.broadcast.Broadcast ++import org.apache.spark.celeborn.CelebornShuffleState + import org.apache.spark.errors.SparkCoreErrors + import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} + import org.apache.spark.internal.Logging +@@ -1897,7 +1898,7 @@ private[spark] class DAGScheduler( + + val shouldAbortStage = + failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts || +- disallowStageRetryForTest ++ disallowStageRetryForTest || CelebornShuffleState.isCelebornSkewedShuffle(shuffleId) + + // It is likely that we receive multiple FetchFailed for a single stage (because we have + // multiple tasks running concurrently on different executors). 
In that case, it is +diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala +new file mode 100644 +index 00000000000..3dc60678461 +--- /dev/null ++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala +@@ -0,0 +1,35 @@ ++/* ++ * Licensed to the Apache Software Foundation (ASF) under one or more ++ * contributor license agreements. See the NOTICE file distributed with ++ * this work for additional information regarding copyright ownership. ++ * The ASF licenses this file to You under the Apache License, Version 2.0 ++ * (the "License"); you may not use this file except in compliance with ++ * the License. You may obtain a copy of the License at ++ * ++ * http://www.apache.org/licenses/LICENSE-2.0 ++ * ++ * Unless required by applicable law or agreed to in writing, software ++ * distributed under the License is distributed on an "AS IS" BASIS, ++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++ * See the License for the specific language governing permissions and ++ * limitations under the License. ++ */ ++ ++package org.apache.spark.sql.execution.adaptive ++ ++import java.util.Locale ++ ++import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike} ++ ++object CelebornShuffleUtil { ++ ++ def isCelebornShuffle(shuffleExchangeLike: ShuffleExchangeLike): Boolean = { ++ shuffleExchangeLike match { ++ case exec: ShuffleExchangeExec => ++ exec.shuffleDependency.shuffleHandle ++ .getClass.getName.toLowerCase(Locale.ROOT).contains("celeborn") ++ case _ => false ++ } ++ } ++ ++} +diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala +index b34ab3e380b..cb0ed9d05a4 100644 +--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala ++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala +@@ -47,14 +47,15 @@ object OptimizeSkewInRebalancePartitions extends AQEShuffleReadRule { + private def optimizeSkewedPartitions( + shuffleId: Int, + bytesByPartitionId: Array[Long], +- targetSize: Long): Seq[ShufflePartitionSpec] = { ++ targetSize: Long, ++ isCelebornShuffle: Boolean = false): Seq[ShufflePartitionSpec] = { + val smallPartitionFactor = + conf.getConf(SQLConf.ADAPTIVE_REBALANCE_PARTITIONS_SMALL_PARTITION_FACTOR) + bytesByPartitionId.indices.flatMap { reduceIndex => + val bytes = bytesByPartitionId(reduceIndex) + if (bytes > targetSize) { + val newPartitionSpec = ShufflePartitionsUtil.createSkewPartitionSpecs( +- shuffleId, reduceIndex, targetSize, smallPartitionFactor) ++ shuffleId, reduceIndex, targetSize, smallPartitionFactor, isCelebornShuffle) + if (newPartitionSpec.isEmpty) { + CoalescedPartitionSpec(reduceIndex, reduceIndex + 1, bytes) :: Nil + } else { +@@ -76,8 +77,9 @@ object OptimizeSkewInRebalancePartitions extends AQEShuffleReadRule { + return shuffle + } + ++ val isCelebornShuffle = CelebornShuffleUtil.isCelebornShuffle(shuffle.shuffle) + val newPartitionsSpec = optimizeSkewedPartitions( +- mapStats.get.shuffleId, mapStats.get.bytesByPartitionId, advisorySize) ++ mapStats.get.shuffleId, mapStats.get.bytesByPartitionId, advisorySize, isCelebornShuffle) + // return origin plan if we can 
not optimize partitions + if (newPartitionsSpec.length == mapStats.get.bytesByPartitionId.length) { + shuffle +diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala +index 37cdea084d8..4694a06919e 100644 +--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala ++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala +@@ -152,8 +152,10 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) + Seq(CoalescedPartitionSpec(partitionIndex, partitionIndex + 1, rightSize)) + + val leftParts = if (isLeftSkew) { ++ val isCelebornShuffle = CelebornShuffleUtil.isCelebornShuffle(left.shuffle) + val skewSpecs = ShufflePartitionsUtil.createSkewPartitionSpecs( +- left.mapStats.get.shuffleId, partitionIndex, leftTargetSize) ++ left.mapStats.get.shuffleId, partitionIndex, leftTargetSize, ++ isCelebornShuffle = isCelebornShuffle) + if (skewSpecs.isDefined) { + logDebug(s"Left side partition $partitionIndex " + + s"(${FileUtils.byteCountToDisplaySize(leftSize)}) is skewed, " + +@@ -166,8 +168,10 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) + } + + val rightParts = if (isRightSkew) { ++ val isCelebornShuffle = CelebornShuffleUtil.isCelebornShuffle(right.shuffle) + val skewSpecs = ShufflePartitionsUtil.createSkewPartitionSpecs( +- right.mapStats.get.shuffleId, partitionIndex, rightTargetSize) ++ right.mapStats.get.shuffleId, partitionIndex, rightTargetSize, ++ isCelebornShuffle = isCelebornShuffle) + if (skewSpecs.isDefined) { + logDebug(s"Right side partition $partitionIndex " + + s"(${FileUtils.byteCountToDisplaySize(rightSize)}) is skewed, " + +diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala +index dbed66683b0..d656c8af6b7 100644 +--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala ++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala +@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.adaptive + import scala.collection.mutable.ArrayBuffer + + import org.apache.spark.{MapOutputStatistics, MapOutputTrackerMaster, SparkEnv} ++import org.apache.spark.celeborn.CelebornShuffleState + import org.apache.spark.internal.Logging + import org.apache.spark.sql.execution.{CoalescedPartitionSpec, PartialReducerPartitionSpec, ShufflePartitionSpec} + +@@ -380,13 +381,21 @@ object ShufflePartitionsUtil extends Logging { + shuffleId: Int, + reducerId: Int, + targetSize: Long, +- smallPartitionFactor: Double = SMALL_PARTITION_FACTOR) +- : Option[Seq[PartialReducerPartitionSpec]] = { ++ smallPartitionFactor: Double = SMALL_PARTITION_FACTOR, ++ isCelebornShuffle: Boolean = false): Option[Seq[PartialReducerPartitionSpec]] = { + val mapPartitionSizes = getMapSizesForReduceId(shuffleId, reducerId) + if (mapPartitionSizes.exists(_ < 0)) return None + val mapStartIndices = splitSizeListByTargetSize( + mapPartitionSizes, targetSize, smallPartitionFactor) + if (mapStartIndices.length > 1) { ++ val celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled = ++ CelebornShuffleState.celebornAdaptiveOptimizeSkewedPartitionReadEnabled && isCelebornShuffle ++ ++ val throwsFetchFailure = CelebornShuffleState.celebornStageRerunEnabled ++ if 
(throwsFetchFailure && celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) {
++        logInfo(s"Celeborn shuffle retry enabled and shuffle $shuffleId is skewed")
++        CelebornShuffleState.registerCelebornSkewedShuffle(shuffleId)
++      }
+       Some(mapStartIndices.indices.map { i =>
+         val startMapIndex = mapStartIndices(i)
+         val endMapIndex = if (i == mapStartIndices.length - 1) {
+@@ -400,7 +409,15 @@
+           dataSize += mapPartitionSizes(mapIndex)
+           mapIndex += 1
+         }
+-        PartialReducerPartitionSpec(reducerId, startMapIndex, endMapIndex, dataSize)
++
++        if (celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) {
++          // This `dataSize` may not be accurate: when the Celeborn optimize skewed partition
++          // read feature is enabled, the spec carries the sub-partition count and index rather
++          // than a real map-index range. Do not use this `dataSize` elsewhere in the codebase.
++          PartialReducerPartitionSpec(reducerId, mapStartIndices.length, i, dataSize)
++        } else {
++          PartialReducerPartitionSpec(reducerId, startMapIndex, endMapIndex, dataSize)
++        }
+       })
+     } else {
+       None
diff --git a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_5.patch b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_5.patch
new file mode 100644
index 00000000000..553bdeae668
--- /dev/null
+++ b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_5.patch
@@ -0,0 +1,312 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +index 9a7a3b0c0e7..543423dadd9 100644 +--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala ++++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +@@ -34,6 +34,7 @@ import org.apache.commons.io.output.{ByteArrayOutputStream => ApacheByteArrayOut + import org.roaringbitmap.RoaringBitmap + + import org.apache.spark.broadcast.{Broadcast, BroadcastManager} ++import org.apache.spark.celeborn.CelebornShuffleState + import org.apache.spark.internal.Logging + import org.apache.spark.internal.config._ + import org.apache.spark.io.CompressionCodec +@@ -916,6 +917,7 @@ private[spark] class MapOutputTrackerMaster( + shuffleStatus.invalidateSerializedMergeOutputStatusCache() + } + } ++ CelebornShuffleState.unregisterCelebornSkewedShuffle(shuffleId) + } + + /** +diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala +index edad91a0c6f..76b377729a0 100644 +--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala ++++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala +@@ -32,6 +32,7 @@ import org.apache.hadoop.conf.Configuration + import org.apache.spark.annotation.DeveloperApi + import org.apache.spark.api.python.PythonWorkerFactory + import org.apache.spark.broadcast.BroadcastManager ++import org.apache.spark.celeborn.CelebornShuffleState + import org.apache.spark.executor.ExecutorBackend + import org.apache.spark.internal.{config, Logging} + import org.apache.spark.internal.config._ +@@ -419,6 +420,7 @@ object SparkEnv extends Logging { + if (isDriver) { + val sparkFilesDir = Utils.createTempDir(Utils.getLocalDir(conf), "userFiles").getAbsolutePath + envInstance.driverTmpDir = Some(sparkFilesDir) ++ CelebornShuffleState.init(envInstance) + } + + envInstance +diff --git a/core/src/main/scala/org/apache/spark/celeborn/CelebornShuffleState.scala b/core/src/main/scala/org/apache/spark/celeborn/CelebornShuffleState.scala +new file mode 100644 +index 00000000000..5e190c512df +--- /dev/null ++++ b/core/src/main/scala/org/apache/spark/celeborn/CelebornShuffleState.scala +@@ -0,0 +1,75 @@ ++/* ++ * Licensed to the Apache Software Foundation (ASF) under one or more ++ * contributor license agreements. See the NOTICE file distributed with ++ * this work for additional information regarding copyright ownership. ++ * The ASF licenses this file to You under the Apache License, Version 2.0 ++ * (the "License"); you may not use this file except in compliance with ++ * the License. You may obtain a copy of the License at ++ * ++ * http://www.apache.org/licenses/LICENSE-2.0 ++ * ++ * Unless required by applicable law or agreed to in writing, software ++ * distributed under the License is distributed on an "AS IS" BASIS, ++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++ * See the License for the specific language governing permissions and ++ * limitations under the License. 
++ */ ++ ++package org.apache.spark.celeborn ++ ++import java.util.concurrent.ConcurrentHashMap ++import java.util.concurrent.atomic.AtomicBoolean ++ ++import org.apache.spark.SparkEnv ++import org.apache.spark.internal.config.ConfigBuilder ++ ++object CelebornShuffleState { ++ ++ private val CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ = ++ ConfigBuilder("spark.celeborn.client.adaptive.optimizeSkewedPartitionRead.enabled") ++ .booleanConf ++ .createWithDefault(false) ++ ++ private val CELEBORN_STAGE_RERUN_ENABLED = ++ ConfigBuilder("spark.celeborn.client.spark.stageRerun.enabled") ++ .withAlternative("spark.celeborn.client.spark.fetch.throwsFetchFailure") ++ .booleanConf ++ .createWithDefault(false) ++ ++ private val celebornOptimizeSkewedPartitionReadEnabled = new AtomicBoolean() ++ private val stageRerunEnabled = new AtomicBoolean() ++ private val skewShuffleIds = ConcurrentHashMap.newKeySet[Int]() ++ ++ // call this from SparkEnv.create ++ def init(env: SparkEnv): Unit = { ++ // cleanup existing state (if required) - and initialize ++ skewShuffleIds.clear() ++ ++ // use env.conf for all initialization, and not SQLConf ++ celebornOptimizeSkewedPartitionReadEnabled.set( ++ env.conf.get("spark.shuffle.manager", "sort").contains("celeborn") && ++ env.conf.get(CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ)) ++ stageRerunEnabled.set(env.conf.get(CELEBORN_STAGE_RERUN_ENABLED)) ++ } ++ ++ def unregisterCelebornSkewedShuffle(shuffleId: Int): Unit = { ++ skewShuffleIds.remove(shuffleId) ++ } ++ ++ def registerCelebornSkewedShuffle(shuffleId: Int): Unit = { ++ skewShuffleIds.add(shuffleId) ++ } ++ ++ def isCelebornSkewedShuffle(shuffleId: Int): Boolean = { ++ skewShuffleIds.contains(shuffleId) ++ } ++ ++ def celebornAdaptiveOptimizeSkewedPartitionReadEnabled: Boolean = { ++ celebornOptimizeSkewedPartitionReadEnabled.get() ++ } ++ ++ def celebornStageRerunEnabled: Boolean = { ++ stageRerunEnabled.get() ++ } ++ ++} +diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +index 89d16e57934..3b9094f3254 100644 +--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala ++++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +@@ -34,6 +34,7 @@ import com.google.common.util.concurrent.{Futures, SettableFuture} + + import org.apache.spark._ + import org.apache.spark.broadcast.Broadcast ++import org.apache.spark.celeborn.CelebornShuffleState + import org.apache.spark.errors.SparkCoreErrors + import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} + import org.apache.spark.internal.Logging +@@ -1962,7 +1963,7 @@ private[spark] class DAGScheduler( + + val shouldAbortStage = + failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts || +- disallowStageRetryForTest ++ disallowStageRetryForTest || CelebornShuffleState.isCelebornSkewedShuffle(shuffleId) + + // It is likely that we receive multiple FetchFailed for a single stage (because we have + // multiple tasks running concurrently on different executors). 
In that case, it is +diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala +new file mode 100644 +index 00000000000..3dc60678461 +--- /dev/null ++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala +@@ -0,0 +1,35 @@ ++/* ++ * Licensed to the Apache Software Foundation (ASF) under one or more ++ * contributor license agreements. See the NOTICE file distributed with ++ * this work for additional information regarding copyright ownership. ++ * The ASF licenses this file to You under the Apache License, Version 2.0 ++ * (the "License"); you may not use this file except in compliance with ++ * the License. You may obtain a copy of the License at ++ * ++ * http://www.apache.org/licenses/LICENSE-2.0 ++ * ++ * Unless required by applicable law or agreed to in writing, software ++ * distributed under the License is distributed on an "AS IS" BASIS, ++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++ * See the License for the specific language governing permissions and ++ * limitations under the License. ++ */ ++ ++package org.apache.spark.sql.execution.adaptive ++ ++import java.util.Locale ++ ++import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike} ++ ++object CelebornShuffleUtil { ++ ++ def isCelebornShuffle(shuffleExchangeLike: ShuffleExchangeLike): Boolean = { ++ shuffleExchangeLike match { ++ case exec: ShuffleExchangeExec => ++ exec.shuffleDependency.shuffleHandle ++ .getClass.getName.toLowerCase(Locale.ROOT).contains("celeborn") ++ case _ => false ++ } ++ } ++ ++} +diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala +index abd096b9c7c..ff0363f87d8 100644 +--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala ++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala +@@ -47,14 +47,15 @@ object OptimizeSkewInRebalancePartitions extends AQEShuffleReadRule { + private def optimizeSkewedPartitions( + shuffleId: Int, + bytesByPartitionId: Array[Long], +- targetSize: Long): Seq[ShufflePartitionSpec] = { ++ targetSize: Long, ++ isCelebornShuffle: Boolean = false): Seq[ShufflePartitionSpec] = { + val smallPartitionFactor = + conf.getConf(SQLConf.ADAPTIVE_REBALANCE_PARTITIONS_SMALL_PARTITION_FACTOR) + bytesByPartitionId.indices.flatMap { reduceIndex => + val bytes = bytesByPartitionId(reduceIndex) + if (bytes > targetSize) { + val newPartitionSpec = ShufflePartitionsUtil.createSkewPartitionSpecs( +- shuffleId, reduceIndex, targetSize, smallPartitionFactor) ++ shuffleId, reduceIndex, targetSize, smallPartitionFactor, isCelebornShuffle) + if (newPartitionSpec.isEmpty) { + CoalescedPartitionSpec(reduceIndex, reduceIndex + 1, bytes) :: Nil + } else { +@@ -77,8 +78,9 @@ object OptimizeSkewInRebalancePartitions extends AQEShuffleReadRule { + return shuffle + } + ++ val isCelebornShuffle = CelebornShuffleUtil.isCelebornShuffle(shuffle.shuffle) + val newPartitionsSpec = optimizeSkewedPartitions( +- mapStats.get.shuffleId, mapStats.get.bytesByPartitionId, advisorySize) ++ mapStats.get.shuffleId, mapStats.get.bytesByPartitionId, advisorySize, isCelebornShuffle) + // return origin plan if we can 
not optimize partitions + if (newPartitionsSpec.length == mapStats.get.bytesByPartitionId.length) { + shuffle +diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala +index 37cdea084d8..4694a06919e 100644 +--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala ++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala +@@ -152,8 +152,10 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) + Seq(CoalescedPartitionSpec(partitionIndex, partitionIndex + 1, rightSize)) + + val leftParts = if (isLeftSkew) { ++ val isCelebornShuffle = CelebornShuffleUtil.isCelebornShuffle(left.shuffle) + val skewSpecs = ShufflePartitionsUtil.createSkewPartitionSpecs( +- left.mapStats.get.shuffleId, partitionIndex, leftTargetSize) ++ left.mapStats.get.shuffleId, partitionIndex, leftTargetSize, ++ isCelebornShuffle = isCelebornShuffle) + if (skewSpecs.isDefined) { + logDebug(s"Left side partition $partitionIndex " + + s"(${FileUtils.byteCountToDisplaySize(leftSize)}) is skewed, " + +@@ -166,8 +168,10 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) + } + + val rightParts = if (isRightSkew) { ++ val isCelebornShuffle = CelebornShuffleUtil.isCelebornShuffle(right.shuffle) + val skewSpecs = ShufflePartitionsUtil.createSkewPartitionSpecs( +- right.mapStats.get.shuffleId, partitionIndex, rightTargetSize) ++ right.mapStats.get.shuffleId, partitionIndex, rightTargetSize, ++ isCelebornShuffle = isCelebornShuffle) + if (skewSpecs.isDefined) { + logDebug(s"Right side partition $partitionIndex " + + s"(${FileUtils.byteCountToDisplaySize(rightSize)}) is skewed, " + +diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala +index 9370b3d8d1d..d36e26a1376 100644 +--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala ++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala +@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.adaptive + import scala.collection.mutable.ArrayBuffer + + import org.apache.spark.{MapOutputStatistics, MapOutputTrackerMaster, SparkEnv} ++import org.apache.spark.celeborn.CelebornShuffleState + import org.apache.spark.internal.Logging + import org.apache.spark.sql.execution.{CoalescedPartitionSpec, PartialReducerPartitionSpec, ShufflePartitionSpec} + +@@ -382,13 +383,21 @@ object ShufflePartitionsUtil extends Logging { + shuffleId: Int, + reducerId: Int, + targetSize: Long, +- smallPartitionFactor: Double = SMALL_PARTITION_FACTOR) +- : Option[Seq[PartialReducerPartitionSpec]] = { ++ smallPartitionFactor: Double = SMALL_PARTITION_FACTOR, ++ isCelebornShuffle: Boolean = false): Option[Seq[PartialReducerPartitionSpec]] = { + val mapPartitionSizes = getMapSizesForReduceId(shuffleId, reducerId) + if (mapPartitionSizes.exists(_ < 0)) return None + val mapStartIndices = splitSizeListByTargetSize( + mapPartitionSizes, targetSize, smallPartitionFactor) + if (mapStartIndices.length > 1) { ++ val celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled = ++ CelebornShuffleState.celebornAdaptiveOptimizeSkewedPartitionReadEnabled && isCelebornShuffle ++ ++ val throwsFetchFailure = CelebornShuffleState.celebornStageRerunEnabled ++ if 
(throwsFetchFailure && celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) { ++ logInfo(s"Celeborn shuffle retry enabled and shuffle $shuffleId is skewed") ++ CelebornShuffleState.registerCelebornSkewedShuffle(shuffleId) ++ } + Some(mapStartIndices.indices.map { i => + val startMapIndex = mapStartIndices(i) + val endMapIndex = if (i == mapStartIndices.length - 1) {
@@ -402,7 +411,15 @@ object ShufflePartitionsUtil extends Logging { + dataSize += mapPartitionSizes(mapIndex) + mapIndex += 1 + } +- PartialReducerPartitionSpec(reducerId, startMapIndex, endMapIndex, dataSize) ++ ++ if (celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) { ++ // When the Celeborn optimize skewed partition read feature is enabled, these `dataSize` ++ // values are only the sum over the original map range and no longer match the data this ++ // spec actually reads. Do not use these `dataSize` values anywhere else in the codebase. ++ PartialReducerPartitionSpec(reducerId, mapStartIndices.length, i, dataSize) ++ } else { ++ PartialReducerPartitionSpec(reducerId, startMapIndex, endMapIndex, dataSize) ++ } + }) + } else { + None
diff --git a/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/CelebornPartitionUtil.java b/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/CelebornPartitionUtil.java
new file mode 100644
index 00000000000..6bf47addca9
--- /dev/null
+++ b/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/CelebornPartitionUtil.java
@@ -0,0 +1,97 @@
+/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.celeborn; + +import java.util.*; + +import org.apache.commons.lang3.tuple.Pair; + +import org.apache.celeborn.common.protocol.PartitionLocation; + +public class CelebornPartitionUtil { + /** + * The general idea is to divide each skew partition into smaller partitions: + * + * <p>- Spark driver will calculate the number of sub-partitions: {@code subPartitionSize = + * skewPartitionTotalSize / subPartitionTargetSize} + * + * <p>- In Celeborn, we divide the skew partition into {@code subPartitionSize} small partitions + * by PartitionLocation chunk offsets. This allows them to run in parallel Spark tasks. + * + * <p>For example, one skewed partition has 2 PartitionLocations: + * + * <ul> + *   <li>PartitionLocation 0 with chunk offsets [0, 100, 200, 300, 500, 1000] + *   <li>PartitionLocation 1 with chunk offsets [0, 200, 500, 800, 1000] + * </ul> + * + * <p>If we want to divide it into 3 sub-partitions (each sub-partition target size is 2000/3), the + * result will be: + * + * <ul> + *   <li>sub-partition 0: PartitionLocation 0 -> chunks (0, 3) + *   <li>sub-partition 1: PartitionLocation 0 -> chunks (4, 4), PartitionLocation 1 -> chunks (0, 0) + *   <li>sub-partition 2: PartitionLocation 1 -> chunks (1, 3) + * </ul> + * + * <p>Note: (0, 3) means chunks with chunkIndex 0-1-2-3, four chunks. + * + * @param locations PartitionLocation information belonging to the reduce partition + * @param subPartitionSize the number of sub-partitions the reduce partition is divided into + * @param subPartitionIndex the index of the current sub-partition + * @return a map from partition uniqueId to the chunk range pair read by one sub-partition task + */ + public static Map<String, Pair<Integer, Integer>> splitSkewedPartitionLocations( + ArrayList<PartitionLocation> locations, int subPartitionSize, int subPartitionIndex) { + locations.sort(Comparator.comparing((PartitionLocation p) -> p.getUniqueId())); + long totalPartitionSize = + locations.stream().mapToLong((PartitionLocation p) -> p.getStorageInfo().fileSize).sum(); + long step = totalPartitionSize / subPartitionSize; + long startOffset = step * subPartitionIndex; + long endOffset = + subPartitionIndex < subPartitionSize - 1 + ? step * (subPartitionIndex + 1) + : totalPartitionSize + 1; // last subPartition should include all remaining data + + long partitionLocationOffset = 0; + Map<String, Pair<Integer, Integer>> chunkRange = new HashMap<>(); + for (PartitionLocation p : locations) { + int left = -1; + int right = -1; + Iterator<Long> chunkOffsets = p.getStorageInfo().getChunkOffsets().iterator(); + // Start from index 1 since the first chunk offset is always 0. + chunkOffsets.next(); + int j = 1; + while (chunkOffsets.hasNext()) { + long currentOffset = partitionLocationOffset + chunkOffsets.next(); + if (currentOffset > startOffset && left < 0) { + left = j - 1; + } + if (currentOffset <= endOffset) { + right = j - 1; + } + if (left >= 0 && right >= 0) { + chunkRange.put(p.getUniqueId(), Pair.of(left, right)); + } + j++; + } + partitionLocationOffset += p.getStorageInfo().getFileSize(); + } + return chunkRange; + } +}
diff --git a/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
index ab51e28d388..b35e549ae38 100644
--- a/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
+++ b/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
@@ -18,14 +18,14 @@ package org.apache.spark.shuffle.celeborn import java.io.IOException -import java.nio.file.Files -import java.util +import java.util.{ArrayList => JArrayList, HashMap => JHashMap, Map => JMap, Set => JSet} import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor, TimeoutException, TimeUnit} import java.util.concurrent.atomic.AtomicReference import scala.collection.JavaConverters._ import com.google.common.annotations.VisibleForTesting +import org.apache.commons.lang3.tuple.Pair import org.apache.spark.{Aggregator, InterruptibleIterator, ShuffleDependency, TaskContext} import org.apache.spark.celeborn.ExceptionMakerHelper import org.apache.spark.internal.Logging
@@ -35,7 +35,7 @@ import org.apache.spark.shuffle.celeborn.CelebornShuffleReader.streamCreatorPool import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter -import org.apache.celeborn.client.{DummyShuffleClient, ShuffleClient} +import org.apache.celeborn.client.{ClientUtils, ShuffleClient} import org.apache.celeborn.client.ShuffleClientImpl.ReduceFileGroups import org.apache.celeborn.client.read.{CelebornInputStream, MetricsCallback}
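The chunk-range computation in splitSkewedPartitionLocations above is easier to see with concrete numbers. The following self-contained sketch (plain Java, no Celeborn types; the class name and variables are illustrative, not part of this patch) re-implements the same arithmetic on the javadoc example and prints each sub-partition's chunk ranges:

public class SkewSplitSketch {
  public static void main(String[] args) {
    // Chunk offsets of the two PartitionLocations from the javadoc example.
    long[][] chunkOffsets = {{0, 100, 200, 300, 500, 1000}, {0, 200, 500, 800, 1000}};
    long totalSize = 2000; // sum of the two location file sizes (1000 + 1000)
    int subPartitions = 3; // subPartitionSize in the patch
    long step = totalSize / subPartitions;
    for (int idx = 0; idx < subPartitions; idx++) {
      long start = step * idx;
      // The last sub-partition takes all remaining data.
      long end = idx < subPartitions - 1 ? step * (idx + 1) : totalSize + 1;
      System.out.println("sub-partition " + idx + ":");
      long base = 0; // running byte offset of the current location within the whole partition
      for (int loc = 0; loc < chunkOffsets.length; loc++) {
        int left = -1, right = -1;
        // Offsets from index 1 onward are chunk end offsets; index 0 is always 0.
        for (int j = 1; j < chunkOffsets[loc].length; j++) {
          long current = base + chunkOffsets[loc][j];
          if (current > start && left < 0) left = j - 1;
          if (current <= end) right = j - 1;
        }
        if (left >= 0 && right >= 0 && left <= right) {
          System.out.println("  location " + loc + " -> chunks (" + left + ", " + right + ")");
        }
        base += chunkOffsets[loc][chunkOffsets[loc].length - 1];
      }
    }
  }
}

Running it prints exactly the ranges listed in the javadoc: sub-partition 0 reads chunks (0, 3) of location 0; sub-partition 1 reads (4, 4) of location 0 and (0, 0) of location 1; sub-partition 2 reads (1, 3) of location 1.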
import org.apache.celeborn.common.CelebornConf @@ -122,15 +122,41 @@ class CelebornShuffleReader[K, C]( } // host-port -> (TransportClient, PartitionLocation Array, PbOpenStreamList) - val workerRequestMap = new util.HashMap[ + val workerRequestMap = new JHashMap[ String, - (TransportClient, util.ArrayList[PartitionLocation], PbOpenStreamList.Builder)]() + (TransportClient, JArrayList[PartitionLocation], PbOpenStreamList.Builder)]() + // partitionId -> (partition uniqueId -> chunkRange pair) + val partitionId2ChunkRange = new JHashMap[Int, JMap[String, Pair[Integer, Integer]]]() + + val partitionId2PartitionLocations = new JHashMap[Int, JSet[PartitionLocation]]() var partCnt = 0 + // if startMapIndex > endMapIndex, means partition is skew partition and read by Celeborn implementation. + // locations will split to sub-partitions with startMapIndex size. + val splitSkewPartitionWithoutMapRange = + ClientUtils.readSkewPartitionWithoutMapRange(conf, startMapIndex, endMapIndex) + (startPartition until endPartition).foreach { partitionId => if (fileGroups.partitionGroups.containsKey(partitionId)) { - fileGroups.partitionGroups.get(partitionId).asScala.foreach { location => + var locations = fileGroups.partitionGroups.get(partitionId) + if (splitSkewPartitionWithoutMapRange) { + val partitionLocation2ChunkRange = CelebornPartitionUtil.splitSkewedPartitionLocations( + new JArrayList(locations), + startMapIndex, + endMapIndex) + partitionId2ChunkRange.put(partitionId, partitionLocation2ChunkRange) + // filter locations avoid OPEN_STREAM when split skew partition without map range + val filterLocations = locations.asScala + .filter { location => + null != partitionLocation2ChunkRange && + partitionLocation2ChunkRange.containsKey(location.getUniqueId) + } + locations = filterLocations.asJava + partitionId2PartitionLocations.put(partitionId, locations) + } + + locations.asScala.foreach { location => partCnt += 1 val hostPort = location.hostAndFetchPort if (!workerRequestMap.containsKey(hostPort)) { @@ -142,7 +168,7 @@ class CelebornShuffleReader[K, C]( pbOpenStreamList.setShuffleKey(shuffleKey) workerRequestMap.put( hostPort, - (client, new util.ArrayList[PartitionLocation], pbOpenStreamList)) + (client, new JArrayList[PartitionLocation], pbOpenStreamList)) } catch { case ex: Exception => shuffleClient.excludeFailedFetchLocation(location.hostAndFetchPort, ex) @@ -203,13 +229,22 @@ class CelebornShuffleReader[K, C]( def createInputStream(partitionId: Int): Unit = { val locations = - if (fileGroups.partitionGroups.containsKey(partitionId)) { - new util.ArrayList(fileGroups.partitionGroups.get(partitionId)) - } else new util.ArrayList[PartitionLocation]() + if (splitSkewPartitionWithoutMapRange) { + partitionId2PartitionLocations.get(partitionId) + } else { + fileGroups.partitionGroups.get(partitionId) + } + + val locationList = + if (null == locations) { + new JArrayList[PartitionLocation]() + } else { + new JArrayList[PartitionLocation](locations) + } val streamHandlers = if (locations != null) { - val streamHandlerArr = new util.ArrayList[PbStreamHandler](locations.size()) - locations.asScala.foreach { loc => + val streamHandlerArr = new JArrayList[PbStreamHandler](locationList.size) + locationList.asScala.foreach { loc => streamHandlerArr.add(locationStreamHandlerMap.get(loc)) } streamHandlerArr @@ -226,8 +261,10 @@ class CelebornShuffleReader[K, C]( endMapIndex, if (throwsFetchFailure) ExceptionMakerHelper.SHUFFLE_FETCH_FAILURE_EXCEPTION_MAKER else null, - locations, + locationList, streamHandlers, 
+ fileGroups.pushFailedBatches, + partitionId2ChunkRange.get(partitionId), fileGroups.mapAttempts, metricsCallback) streams.put(partitionId, inputStream) @@ -414,7 +451,6 @@ class CelebornShuffleReader[K, C]( def newSerializerInstance(dep: ShuffleDependency[K, _, C]): SerializerInstance = { dep.serializer.newInstance() } - } object CelebornShuffleReader { diff --git a/client-spark/spark-3-4/src/test/java/org/apache/spark/shuffle/celeborn/CelebornPartitionUtilSuiteJ.java b/client-spark/spark-3-4/src/test/java/org/apache/spark/shuffle/celeborn/CelebornPartitionUtilSuiteJ.java new file mode 100644 index 00000000000..989dd31a94a --- /dev/null +++ b/client-spark/spark-3-4/src/test/java/org/apache/spark/shuffle/celeborn/CelebornPartitionUtilSuiteJ.java @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.celeborn; + +import java.util.*; + +import org.apache.commons.lang3.tuple.Pair; +import org.junit.Assert; +import org.junit.Test; + +import org.apache.celeborn.common.protocol.PartitionLocation; +import org.apache.celeborn.common.protocol.StorageInfo; + +public class CelebornPartitionUtilSuiteJ { + @Test + public void testSkewPartitionSplit() { + + ArrayList locations = new ArrayList<>(); + for (int i = 0; i < 13; i++) { + locations.add(genPartitionLocation(i, new Long[] {0L, 100L, 200L, 300L, 500L, 1000L})); + } + locations.add(genPartitionLocation(91, new Long[] {0L, 1L})); + + int subPartitionSize = 3; + + Map> result1 = + CelebornPartitionUtil.splitSkewedPartitionLocations(locations, subPartitionSize, 0); + Map> expectResult1 = + genRanges( + new Object[][] { + {"0-0", 0, 4}, + {"0-1", 0, 4}, + {"0-10", 0, 4}, + {"0-11", 0, 4}, + {"0-12", 0, 2} + }); + Assert.assertEquals(expectResult1, result1); + + Map> result2 = + CelebornPartitionUtil.splitSkewedPartitionLocations(locations, subPartitionSize, 1); + Map> expectResult2 = + genRanges( + new Object[][] { + {"0-12", 3, 4}, + {"0-2", 0, 4}, + {"0-3", 0, 4}, + {"0-4", 0, 4}, + {"0-5", 0, 3} + }); + Assert.assertEquals(expectResult2, result2); + + Map> result3 = + CelebornPartitionUtil.splitSkewedPartitionLocations(locations, subPartitionSize, 2); + Map> expectResult3 = + genRanges( + new Object[][] { + {"0-5", 4, 4}, + {"0-6", 0, 4}, + {"0-7", 0, 4}, + {"0-8", 0, 4}, + {"0-9", 0, 4}, + {"0-91", 0, 0} + }); + Assert.assertEquals(expectResult3, result3); + } + + @Test + public void testBoundary() { + ArrayList locations = new ArrayList<>(); + locations.add(genPartitionLocation(0, new Long[] {0L, 100L, 200L, 300L, 400L, 500L})); + + for (int i = 0; i < 5; i++) { + Map> result = + CelebornPartitionUtil.splitSkewedPartitionLocations(locations, 5, i); + Map> expectResult = genRanges(new Object[][] {{"0-0", i, i}}); + Assert.assertEquals(expectResult, 
result); + } + } + + @Test + public void testSplitStable() { + ArrayList locations = new ArrayList<>(); + for (int i = 0; i < 13; i++) { + locations.add(genPartitionLocation(i, new Long[] {0L, 100L, 200L, 300L, 500L, 1000L})); + } + locations.add(genPartitionLocation(91, new Long[] {0L, 1L})); + + Collections.shuffle(locations); + + Map> result = + CelebornPartitionUtil.splitSkewedPartitionLocations(locations, 3, 0); + Map> expectResult = + genRanges( + new Object[][] { + {"0-0", 0, 4}, + {"0-1", 0, 4}, + {"0-10", 0, 4}, + {"0-11", 0, 4}, + {"0-12", 0, 2} + }); + Assert.assertEquals(expectResult, result); + } + + private ArrayList genPartitionLocations(Map epochToOffsets) { + ArrayList locations = new ArrayList<>(); + epochToOffsets.forEach( + (epoch, offsets) -> { + PartitionLocation location = + new PartitionLocation( + 0, epoch, "localhost", 0, 0, 0, 0, PartitionLocation.Mode.PRIMARY); + StorageInfo storageInfo = + new StorageInfo( + StorageInfo.Type.HDD, + "mountPoint", + false, + "filePath", + StorageInfo.LOCAL_DISK_MASK, + 1, + Arrays.asList(offsets)); + location.setStorageInfo(storageInfo); + locations.add(location); + }); + return locations; + } + + private PartitionLocation genPartitionLocation(int epoch, Long[] offsets) { + PartitionLocation location = + new PartitionLocation(0, epoch, "localhost", 0, 0, 0, 0, PartitionLocation.Mode.PRIMARY); + StorageInfo storageInfo = + new StorageInfo( + StorageInfo.Type.HDD, + "mountPoint", + false, + "filePath", + StorageInfo.LOCAL_DISK_MASK, + offsets[offsets.length - 1], + Arrays.asList(offsets)); + location.setStorageInfo(storageInfo); + return location; + } + + private Map> genRanges(Object[][] inputs) { + Map> ranges = new HashMap<>(); + for (Object[] idToChunkRange : inputs) { + String uid = (String) idToChunkRange[0]; + Pair range = Pair.of((int) idToChunkRange[1], (int) idToChunkRange[2]); + ranges.put(uid, range); + } + return ranges; + } +} diff --git a/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java b/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java index 2b933181faa..6b3673b1843 100644 --- a/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java +++ b/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java @@ -28,9 +28,11 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; +import org.apache.commons.lang3.tuple.Pair; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -44,6 +46,7 @@ import org.apache.celeborn.common.rpc.RpcEndpointRef; import org.apache.celeborn.common.util.ExceptionMaker; import org.apache.celeborn.common.util.JavaUtils; +import org.apache.celeborn.common.write.PushFailedBatch; import org.apache.celeborn.common.write.PushState; public class DummyShuffleClient extends ShuffleClient { @@ -139,6 +142,8 @@ public CelebornInputStream readPartition( ExceptionMaker exceptionMaker, ArrayList locations, ArrayList streamHandlers, + Map> failedBatchSetMap, + Map> chunksRange, int[] mapAttempts, MetricsCallback metricsCallback) throws IOException { diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java index 673c9382437..8690f5cebde 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java @@ -20,9 +20,11 @@ 
import java.io.IOException; import java.util.ArrayList; import java.util.Map; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.LongAdder; +import org.apache.commons.lang3.tuple.Pair; import org.apache.hadoop.fs.FileSystem; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -39,6 +41,7 @@ import org.apache.celeborn.common.rpc.RpcEndpointRef; import org.apache.celeborn.common.util.CelebornHadoopUtils; import org.apache.celeborn.common.util.ExceptionMaker; +import org.apache.celeborn.common.write.PushFailedBatch; import org.apache.celeborn.common.write.PushState; /** @@ -241,6 +244,8 @@ public CelebornInputStream readPartition( null, null, null, + null, + null, metricsCallback); } @@ -255,6 +260,8 @@ public abstract CelebornInputStream readPartition( ExceptionMaker exceptionMaker, ArrayList locations, ArrayList streamHandlers, + Map> failedBatchSetMap, + Map> chunksRange, int[] mapAttempts, MetricsCallback metricsCallback) throws IOException; diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java index fe373c8d22a..cf48f95bc72 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java @@ -69,6 +69,7 @@ import org.apache.celeborn.common.unsafe.Platform; import org.apache.celeborn.common.util.*; import org.apache.celeborn.common.write.DataBatches; +import org.apache.celeborn.common.write.PushFailedBatch; import org.apache.celeborn.common.write.PushState; public class ShuffleClientImpl extends ShuffleClient { @@ -146,30 +147,37 @@ protected Compressor initialValue() { private final ReviveManager reviveManager; + private final boolean dataPushFailureTrackingEnabled; + public static class ReduceFileGroups { public Map> partitionGroups; + public Map> pushFailedBatches; public int[] mapAttempts; public Set partitionIds; ReduceFileGroups( Map> partitionGroups, int[] mapAttempts, - Set partitionIds) { + Set partitionIds, + Map> pushFailedBatches) { this.partitionGroups = partitionGroups; this.mapAttempts = mapAttempts; this.partitionIds = partitionIds; + this.pushFailedBatches = pushFailedBatches; } public ReduceFileGroups() { this.partitionGroups = null; this.mapAttempts = null; this.partitionIds = null; + this.pushFailedBatches = null; } public void update(ReduceFileGroups fileGroups) { partitionGroups = fileGroups.partitionGroups; mapAttempts = fileGroups.mapAttempts; partitionIds = fileGroups.partitionIds; + pushFailedBatches = fileGroups.pushFailedBatches; } } @@ -199,6 +207,7 @@ public ShuffleClientImpl(String appUniqueId, CelebornConf conf, UserIdentifier u pushDataTimeout = conf.pushDataTimeoutMs(); } authEnabled = conf.authEnabledOnClient(); + dataPushFailureTrackingEnabled = conf.clientAdaptiveOptimizeSkewedPartitionReadEnabled(); // init rpc env rpcEnv = @@ -1110,6 +1119,10 @@ public void onSuccess(ByteBuffer response) { attemptId, partitionId, nextBatchId); + if (dataPushFailureTrackingEnabled) { + pushState.addFailedBatch( + latest.getUniqueId(), new PushFailedBatch(mapId, attemptId, nextBatchId)); + } ReviveRequest reviveRequest = new ReviveRequest( shuffleId, @@ -1176,6 +1189,10 @@ public void onSuccess(ByteBuffer response) { @Override public void onFailure(Throwable e) { + if (dataPushFailureTrackingEnabled) { + pushState.addFailedBatch( + latest.getUniqueId(), new PushFailedBatch(mapId, attemptId, nextBatchId)); + } if 
(pushState.exception.get() != null) { return; } @@ -1542,6 +1559,13 @@ public void onSuccess(ByteBuffer response) { pushState.onSuccess(hostPort); callback.onSuccess(ByteBuffer.wrap(new byte[] {StatusCode.SOFT_SPLIT.getValue()})); } else { + if (dataPushFailureTrackingEnabled) { + for (DataBatches.DataBatch resubmitBatch : batchesNeedResubmit) { + pushState.addFailedBatch( + resubmitBatch.loc.getUniqueId(), + new PushFailedBatch(mapId, attemptId, resubmitBatch.batchId)); + } + } ReviveRequest[] requests = addAndGetReviveRequests( shuffleId, mapId, attemptId, batchesNeedResubmit, StatusCode.HARD_SPLIT); @@ -1597,6 +1621,12 @@ public void onSuccess(ByteBuffer response) { @Override public void onFailure(Throwable e) { + if (dataPushFailureTrackingEnabled) { + for (int i = 0; i < numBatches; i++) { + pushState.addFailedBatch( + partitionUniqueIds[i], new PushFailedBatch(mapId, attemptId, batchIds[i])); + } + } if (pushState.exception.get() != null) { return; } @@ -1717,7 +1747,13 @@ private void mapEndInternal( MapperEndResponse response = lifecycleManagerRef.askSync( - new MapperEnd(shuffleId, mapId, attemptId, numMappers, partitionId), + new MapperEnd( + shuffleId, + mapId, + attemptId, + numMappers, + partitionId, + pushState.getFailedBatches()), rpcMaxRetries, rpcRetryWait, ClassTag$.MODULE$.apply(MapperEndResponse.class)); @@ -1782,7 +1818,10 @@ protected Tuple3 loadFileGroupInternal( response.fileGroup().size()); return Tuple3.apply( new ReduceFileGroups( - response.fileGroup(), response.attempts(), response.partitionIds()), + response.fileGroup(), + response.attempts(), + response.partitionIds(), + response.pushFailedBatches()), null, null); case SHUFFLE_NOT_REGISTERED: @@ -1791,7 +1830,10 @@ protected Tuple3 loadFileGroupInternal( // return empty result return Tuple3.apply( new ReduceFileGroups( - response.fileGroup(), response.attempts(), response.partitionIds()), + response.fileGroup(), + response.attempts(), + response.partitionIds(), + response.pushFailedBatches()), null, null); case STAGE_END_TIME_OUT: @@ -1863,6 +1905,8 @@ public CelebornInputStream readPartition( ExceptionMaker exceptionMaker, ArrayList locations, ArrayList streamHandlers, + Map> failedBatchSetMap, + Map> chunksRange, int[] mapAttempts, MetricsCallback metricsCallback) throws IOException { @@ -1895,6 +1939,8 @@ public CelebornInputStream readPartition( locations, streamHandlers, mapAttempts, + failedBatchSetMap, + chunksRange, attemptNumber, taskId, startMapIndex, diff --git a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java index e525a135ae2..ea3df88d95a 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java +++ b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java @@ -30,10 +30,12 @@ import com.google.common.util.concurrent.Uninterruptibles; import io.netty.buffer.ByteBuf; import net.jpountz.lz4.LZ4Exception; +import org.apache.commons.lang3.tuple.Pair; import org.roaringbitmap.RoaringBitmap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.celeborn.client.ClientUtils; import org.apache.celeborn.client.ShuffleClient; import org.apache.celeborn.client.compress.Decompressor; import org.apache.celeborn.common.CelebornConf; @@ -45,6 +47,7 @@ import org.apache.celeborn.common.unsafe.Platform; import org.apache.celeborn.common.util.ExceptionMaker; import org.apache.celeborn.common.util.Utils; +import 
org.apache.celeborn.common.write.PushFailedBatch; public abstract class CelebornInputStream extends InputStream { private static final Logger logger = LoggerFactory.getLogger(CelebornInputStream.class); @@ -56,6 +59,8 @@ public static CelebornInputStream create( ArrayList locations, ArrayList streamHandlers, int[] attempts, + Map> failedBatchSetMap, + Map> chunksRange, int attemptNumber, long taskId, int startMapIndex, @@ -68,27 +73,57 @@ public static CelebornInputStream create( ExceptionMaker exceptionMaker, MetricsCallback metricsCallback) throws IOException { - if (locations == null || locations.size() == 0) { + if (locations == null || locations.isEmpty()) { return emptyInputStream; } else { - return new CelebornInputStreamImpl( - conf, - clientFactory, - shuffleKey, - locations, - streamHandlers, - attempts, - attemptNumber, - taskId, - startMapIndex, - endMapIndex, - fetchExcludedWorkers, - shuffleClient, - appShuffleId, - shuffleId, - partitionId, - exceptionMaker, - metricsCallback); + // if startMapIndex > endMapIndex, means partition is skew partition and read by Celeborn + // implementation. + // locations will split to sub-partitions with startMapIndex size. + boolean readSkewPartitionWithoutMapRange = + ClientUtils.readSkewPartitionWithoutMapRange(conf, startMapIndex, endMapIndex); + if (readSkewPartitionWithoutMapRange) { + return new CelebornInputStreamImpl( + conf, + clientFactory, + shuffleKey, + locations, + streamHandlers, + attempts, + failedBatchSetMap, + attemptNumber, + taskId, + chunksRange, + fetchExcludedWorkers, + shuffleClient, + appShuffleId, + shuffleId, + partitionId, + exceptionMaker, + true, + metricsCallback); + } else { + return new CelebornInputStreamImpl( + conf, + clientFactory, + shuffleKey, + locations, + streamHandlers, + attempts, + failedBatchSetMap, + attemptNumber, + taskId, + startMapIndex, + endMapIndex, + /*partitionLocationToChunkRange = */ null, + fetchExcludedWorkers, + shuffleClient, + appShuffleId, + shuffleId, + partitionId, + exceptionMaker, + false, + metricsCallback); + } } } @@ -136,9 +171,12 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream { private final long taskId; private final int startMapIndex; private final int endMapIndex; + private final Map> partitionLocationToChunkRange; private Map> batchesRead = new HashMap<>(); + private final Map> failedBatches; + private byte[] compressedBuf; private byte[] rawDataBuf; private Decompressor decompressor; @@ -175,6 +213,8 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream { private ExceptionMaker exceptionMaker; private boolean closed = false; + private final boolean readSkewPartitionWithoutMapRange; + CelebornInputStreamImpl( CelebornConf conf, TransportClientFactory clientFactory, @@ -182,16 +222,62 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream { ArrayList locations, ArrayList streamHandlers, int[] attempts, + Map> failedBatchSet, + int attemptNumber, + long taskId, + Map> partitionLocationToChunkRange, + ConcurrentHashMap fetchExcludedWorkers, + ShuffleClient shuffleClient, + int appShuffleId, + int shuffleId, + int partitionId, + ExceptionMaker exceptionMaker, + boolean splitSkewPartitionWithoutMapRange, + MetricsCallback metricsCallback) + throws IOException { + this( + conf, + clientFactory, + shuffleKey, + locations, + streamHandlers, + attempts, + failedBatchSet, + attemptNumber, + taskId, + /*startMapIndex = */ -1, + /*endMapIndex = */ -1, + partitionLocationToChunkRange, + 
fetchExcludedWorkers, + shuffleClient, + appShuffleId, + shuffleId, + partitionId, + exceptionMaker, + splitSkewPartitionWithoutMapRange, + metricsCallback); + } + + CelebornInputStreamImpl( + CelebornConf conf, + TransportClientFactory clientFactory, + String shuffleKey, + ArrayList locations, + ArrayList streamHandlers, + int[] attempts, + Map> failedBatchSet, int attemptNumber, long taskId, int startMapIndex, int endMapIndex, + Map> partitionLocationToChunkRange, ConcurrentHashMap fetchExcludedWorkers, ShuffleClient shuffleClient, int appShuffleId, int shuffleId, int partitionId, ExceptionMaker exceptionMaker, + boolean readSkewPartitionWithoutMapRange, MetricsCallback metricsCallback) throws IOException { this.conf = conf; @@ -206,12 +292,15 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream { this.taskId = taskId; this.startMapIndex = startMapIndex; this.endMapIndex = endMapIndex; + this.partitionLocationToChunkRange = partitionLocationToChunkRange; this.rangeReadFilter = conf.shuffleRangeReadFilterEnabled(); this.enabledReadLocalShuffle = conf.enableReadLocalShuffleFile(); this.localHostAddress = Utils.localHostName(conf); this.shuffleCompressionEnabled = !conf.shuffleCompressionCodec().equals(CompressionCodec.NONE); this.fetchExcludedWorkerExpireTimeout = conf.clientFetchExcludedWorkerExpireTimeout(); + this.failedBatches = failedBatchSet; + this.readSkewPartitionWithoutMapRange = readSkewPartitionWithoutMapRange; this.fetchExcludedWorkers = fetchExcludedWorkers; if (conf.clientPushReplicateEnabled()) { @@ -260,7 +349,11 @@ private Tuple2 nextReadableLocation() { return null; } PartitionLocation currentLocation = locations.get(fileIndex); - while (skipLocation(startMapIndex, endMapIndex, currentLocation)) { + // if pushShuffleFailureTrackingEnabled is true, should not skip location + while ((readSkewPartitionWithoutMapRange + && !partitionLocationToChunkRange.containsKey(currentLocation.getUniqueId())) + || (!readSkewPartitionWithoutMapRange + && skipLocation(startMapIndex, endMapIndex, currentLocation))) { skipCount.increment(); fileIndex++; if (fileIndex == locationCount) { @@ -336,7 +429,7 @@ private PartitionReader createReaderWithRetry( lastException = e; shuffleClient.excludeFailedFetchLocation(location.hostAndFetchPort(), e); fetchChunkRetryCnt++; - if (location.hasPeer()) { + if (location.hasPeer() && !readSkewPartitionWithoutMapRange) { // fetchChunkRetryCnt % 2 == 0 means both replicas have been tried, // so sleep before next try. 
if (fetchChunkRetryCnt % 2 == 0) { @@ -407,7 +500,7 @@ private ByteBuf getNextChunk() throws IOException { + currentReader.getLocation(), e); } else { - if (currentReader.getLocation().hasPeer()) { + if (currentReader.getLocation().hasPeer() && !readSkewPartitionWithoutMapRange) { logger.warn( "Fetch chunk failed {}/{} times for location {}, change to peer", fetchChunkRetryCnt, @@ -444,6 +537,15 @@ private PartitionReader createReader( throws IOException, InterruptedException { StorageInfo storageInfo = location.getStorageInfo(); + + int startChunkIndex = -1; + int endChunkIndex = -1; + if (partitionLocationToChunkRange != null) { + Pair chunkRange = + partitionLocationToChunkRange.get(location.getUniqueId()); + startChunkIndex = chunkRange.getLeft(); + endChunkIndex = chunkRange.getRight(); + } switch (storageInfo.getType()) { case HDD: case SSD: @@ -473,7 +575,9 @@ private PartitionReader createReader( endMapIndex, fetchChunkRetryCnt, fetchChunkMaxRetry, - callback); + callback, + startChunkIndex, + endChunkIndex); } case S3: case HDFS: @@ -621,6 +725,7 @@ private boolean fillBuffer() throws IOException { return false; } + PushFailedBatch failedBatch = new PushFailedBatch(-1, -1, -1); boolean hasData = false; while (currentChunk.isReadable() || moveToNextChunk()) { currentChunk.readBytes(sizeBuf); @@ -645,6 +750,19 @@ private boolean fillBuffer() throws IOException { // de-duplicate if (attemptId == attempts[mapId]) { + if (readSkewPartitionWithoutMapRange) { + Set failedBatchSet = + this.failedBatches.get(currentReader.getLocation().getUniqueId()); + if (null != failedBatchSet) { + failedBatch.setMapId(mapId); + failedBatch.setAttemptId(attemptId); + failedBatch.setBatchId(batchId); + if (failedBatchSet.contains(failedBatch)) { + logger.warn("Skip duplicated batch: {}.", failedBatch); + continue; + } + } + } if (!batchesRead.containsKey(mapId)) { Set batchSet = new HashSet<>(); batchesRead.put(mapId, batchSet); diff --git a/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java b/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java index 3158aa12f72..29236273663 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java +++ b/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java @@ -55,6 +55,8 @@ public class WorkerPartitionReader implements PartitionReader { private int returnedChunks; private int chunkIndex; + private int startChunkIndex; + private int endChunkIndex; private final LinkedBlockingQueue results; private final ChunkReceivedCallback callback; @@ -80,7 +82,9 @@ public class WorkerPartitionReader implements PartitionReader { int endMapIndex, int fetchChunkRetryCnt, int fetchChunkMaxRetry, - MetricsCallback metricsCallback) + MetricsCallback metricsCallback, + int startChunkIndex, + int endChunkIndex) throws IOException, InterruptedException { this.shuffleKey = shuffleKey; fetchMaxReqsInFlight = conf.clientFetchMaxReqsInFlight(); @@ -133,7 +137,12 @@ public void onFailure(int chunkIndex, Throwable e) { } else { this.streamHandler = pbStreamHandler; } - + this.startChunkIndex = startChunkIndex == -1 ? 0 : startChunkIndex; + this.endChunkIndex = + endChunkIndex == -1 + ? 
streamHandler.getNumChunks() - 1 + : Math.min(streamHandler.getNumChunks() - 1, endChunkIndex); + this.chunkIndex = this.startChunkIndex; this.location = location; this.clientFactory = clientFactory; this.fetchChunkRetryCnt = fetchChunkRetryCnt;
@@ -144,13 +153,13 @@ public void onFailure(int chunkIndex, Throwable e) { @Override public boolean hasNext() { - return returnedChunks < streamHandler.getNumChunks(); + return returnedChunks < endChunkIndex - startChunkIndex + 1; } @Override public ByteBuf next() throws IOException, InterruptedException { checkException(); - if (chunkIndex < streamHandler.getNumChunks()) { + if (chunkIndex <= endChunkIndex) { fetchChunks(); } ByteBuf chunk = null;
@@ -202,10 +211,10 @@ public PartitionLocation getLocation() { } private void fetchChunks() throws IOException, InterruptedException { - final int inFlight = chunkIndex - returnedChunks; + final int inFlight = chunkIndex - startChunkIndex - returnedChunks; if (inFlight < fetchMaxReqsInFlight) { final int toFetch = - Math.min(fetchMaxReqsInFlight - inFlight + 1, streamHandler.getNumChunks() - chunkIndex); + Math.min(fetchMaxReqsInFlight - inFlight + 1, endChunkIndex + 1 - chunkIndex); for (int i = 0; i < toFetch; i++) { if (testFetch && fetchChunkRetryCnt < fetchChunkMaxRetry - 1 && chunkIndex == 3) { callback.onFailure(chunkIndex, new CelebornIOException("Test fetch chunk failure"));
diff --git a/client/src/main/scala/org/apache/celeborn/client/ClientUtils.scala b/client/src/main/scala/org/apache/celeborn/client/ClientUtils.scala
index d7dccb941b0..b071eff3bf4 100644
--- a/client/src/main/scala/org/apache/celeborn/client/ClientUtils.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/ClientUtils.scala
@@ -17,6 +17,8 @@ package org.apache.celeborn.client +import org.apache.celeborn.common.CelebornConf + object ClientUtils { /**
@@ -37,4 +39,20 @@ } true } + + /** + * If startMapIndex > endMapIndex, the partition is a skew partition that has been pre-split: + * its locations are divided into startMapIndex sub-partitions, and endMapIndex is the index + * of the sub-partition to read. + * + * @param conf celeborn conf + * @param startMapIndex shuffle start map index + * @param endMapIndex shuffle end map index + * @return true if reading a skew partition without a map range + */ + def readSkewPartitionWithoutMapRange( + conf: CelebornConf, + startMapIndex: Int, + endMapIndex: Int): Boolean = { + conf.clientAdaptiveOptimizeSkewedPartitionReadEnabled && startMapIndex > endMapIndex + } }
diff --git a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
index 201be286978..aaeb4462fbd 100644
--- a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
@@ -18,6 +18,7 @@ package org.apache.celeborn.client import java.util +import java.util.Collections import java.util.concurrent.{ConcurrentHashMap, ScheduledExecutorService, ScheduledFuture, TimeUnit} import java.util.concurrent.atomic.{AtomicInteger, LongAdder}
@@ -40,6 +41,7 @@ import org.apache.celeborn.common.rpc.RpcCallContext import org.apache.celeborn.common.util.FunctionConverter._ import org.apache.celeborn.common.util.JavaUtils import org.apache.celeborn.common.util.ThreadUtils +import org.apache.celeborn.common.write.PushFailedBatch case class ShuffleCommittedInfo( // partition id -> unique partition ids
@@ -215,13 +217,16 @@ class CommitManager(appUniqueId: String, val conf: CelebornConf, lifecycleManage mapId: Int, attemptId: Int, numMappers: Int, - partitionId: Int = -1): (Boolean, Boolean) = { + partitionId: Int = -1, + pushFailedBatches: util.Map[String, util.Set[PushFailedBatch]] = Collections.emptyMap()) + : (Boolean, Boolean) = { getCommitHandler(shuffleId).finishMapperAttempt( shuffleId, mapId, attemptId, numMappers, partitionId, + pushFailedBatches, r => lifecycleManager.workerStatusTracker.recordWorkerFailure(r)) }
diff --git a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
index 00dfd77c669..73b0dc7d728 100644
--- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
@@ -58,11 +58,15 @@ import org.apache.celeborn.common.util.{JavaUtils, PbSerDeUtils, ThreadUtils, Ut import org.apache.celeborn.common.util.FunctionConverter._ import org.apache.celeborn.common.util.ThreadUtils.awaitResult import org.apache.celeborn.common.util.Utils.UNKNOWN_APP_SHUFFLE_ID +import org.apache.celeborn.common.write.PushFailedBatch object LifecycleManager { // shuffle id -> partition id -> partition locations type ShuffleFileGroups = ConcurrentHashMap[Int, ConcurrentHashMap[Integer, util.Set[PartitionLocation]]] + // shuffle id -> partition uniqueId -> PushFailedBatch set + type ShufflePushFailedBatches = + ConcurrentHashMap[Int, util.HashMap[String, util.Set[PushFailedBatch]]] type ShuffleAllocatedWorkers = ConcurrentHashMap[Int, ConcurrentHashMap[String, ShufflePartitionLocationInfo]] type ShuffleFailedWorkers = ConcurrentHashMap[WorkerInfo, (StatusCode, Long)]
@@ -404,13 +408,13 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends oldPartition, isSegmentGranularityVisible = commitManager.isSegmentGranularityVisible(shuffleId)) - case MapperEnd(shuffleId, mapId, attemptId, numMappers, partitionId) => + case MapperEnd(shuffleId, mapId, attemptId, numMappers, partitionId, pushFailedBatch) => logTrace(s"Received MapperEnd TaskEnd request, " +
s"${Utils.makeMapKey(shuffleId, mapId, attemptId)}") val partitionType = getPartitionType(shuffleId) partitionType match { case PartitionType.REDUCE => - handleMapperEnd(context, shuffleId, mapId, attemptId, numMappers) + handleMapperEnd(context, shuffleId, mapId, attemptId, numMappers, pushFailedBatch) case PartitionType.MAP => handleMapPartitionEnd( context, @@ -802,10 +806,16 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends shuffleId: Int, mapId: Int, attemptId: Int, - numMappers: Int): Unit = { + numMappers: Int, + pushFailedBatches: util.Map[String, util.Set[PushFailedBatch]]): Unit = { val (mapperAttemptFinishedSuccess, allMapperFinished) = - commitManager.finishMapperAttempt(shuffleId, mapId, attemptId, numMappers) + commitManager.finishMapperAttempt( + shuffleId, + mapId, + attemptId, + numMappers, + pushFailedBatches = pushFailedBatches) if (mapperAttemptFinishedSuccess && allMapperFinished) { // last mapper finished. call mapper end logInfo(s"Last MapperEnd, call StageEnd with shuffleKey:" + diff --git a/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala b/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala index 65ba4dbab04..38c3b0c9c81 100644 --- a/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala +++ b/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala @@ -30,7 +30,7 @@ import scala.concurrent.duration.Duration import org.apache.celeborn.client.{ShuffleCommittedInfo, WorkerStatusTracker} import org.apache.celeborn.client.CommitManager.CommittedPartitionInfo -import org.apache.celeborn.client.LifecycleManager.{ShuffleFailedWorkers, ShuffleFileGroups} +import org.apache.celeborn.client.LifecycleManager.{ShuffleFailedWorkers, ShuffleFileGroups, ShufflePushFailedBatches} import org.apache.celeborn.common.CelebornConf import org.apache.celeborn.common.internal.Logging import org.apache.celeborn.common.meta.{ShufflePartitionLocationInfo, WorkerInfo} @@ -42,6 +42,7 @@ import org.apache.celeborn.common.util.{CollectionUtils, JavaUtils, Utils} // Can Remove this if celeborn don't support scala211 in future import org.apache.celeborn.common.util.FunctionConverter._ import org.apache.celeborn.common.util.ThreadUtils.awaitResult +import org.apache.celeborn.common.write.PushFailedBatch case class CommitFilesParam( worker: WorkerInfo, @@ -74,6 +75,7 @@ abstract class CommitHandler( private val totalWritten = new LongAdder private val fileCount = new LongAdder protected val reducerFileGroupsMap = new ShuffleFileGroups + protected val shufflePushFailedBatches = new ShufflePushFailedBatches val ec = ExecutionContext.fromExecutor(sharedRpcPool) @@ -82,6 +84,8 @@ abstract class CommitHandler( def getPartitionType(): PartitionType + def getShuffleFailedBatches(): ShufflePushFailedBatches = this.shufflePushFailedBatches + def isStageEnd(shuffleId: Int): Boolean = false def isStageEndOrInProcess(shuffleId: Int): Boolean = false @@ -178,6 +182,7 @@ abstract class CommitHandler( def removeExpiredShuffle(shuffleId: Int): Unit = { reducerFileGroupsMap.remove(shuffleId) + shufflePushFailedBatches.remove(shuffleId) } /** @@ -197,6 +202,7 @@ abstract class CommitHandler( attemptId: Int, numMappers: Int, partitionId: Int, + pushFailedBatches: util.Map[String, util.Set[PushFailedBatch]], recordWorkerFailure: ShuffleFailedWorkers => Unit): (Boolean, Boolean) def registerShuffle( diff --git 
a/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala b/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala index 9352dc99a72..a08f1e0d51f 100644 --- a/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala +++ b/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala @@ -39,6 +39,7 @@ import org.apache.celeborn.common.rpc.RpcCallContext import org.apache.celeborn.common.util.FunctionConverter._ import org.apache.celeborn.common.util.JavaUtils import org.apache.celeborn.common.util.Utils +import org.apache.celeborn.common.write.PushFailedBatch /** * This commit handler is for MapPartition ShuffleType, which means that a Map Partition contains all data produced @@ -184,6 +185,7 @@ class MapPartitionCommitHandler( attemptId: Int, numMappers: Int, partitionId: Int, + pushFailedBatches: util.Map[String, util.Set[PushFailedBatch]], recordWorkerFailure: ShuffleFailedWorkers => Unit): (Boolean, Boolean) = { val inProcessingPartitionIds = inProcessMapPartitionEndIds.computeIfAbsent( diff --git a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala index 951fc89e601..55639764c7c 100644 --- a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala +++ b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala @@ -20,11 +20,13 @@ package org.apache.celeborn.client.commit import java.nio.ByteBuffer import java.util import java.util.concurrent.{Callable, ConcurrentHashMap, ThreadPoolExecutor, TimeUnit} +import java.util.function import scala.collection.JavaConverters._ import scala.collection.mutable import com.google.common.cache.{Cache, CacheBuilder} +import com.google.common.collect.Sets import org.apache.celeborn.client.{ClientUtils, ShuffleCommittedInfo, WorkerStatusTracker} import org.apache.celeborn.client.CommitManager.CommittedPartitionInfo @@ -38,6 +40,7 @@ import org.apache.celeborn.common.protocol.message.StatusCode import org.apache.celeborn.common.rpc.RpcCallContext import org.apache.celeborn.common.rpc.netty.{LocalNettyRpcCallContext, RemoteNettyRpcCallContext} import org.apache.celeborn.common.util.JavaUtils +import org.apache.celeborn.common.write.PushFailedBatch /** * This commit handler is for ReducePartition ShuffleType, which means that a Reduce Partition contains all data @@ -82,6 +85,22 @@ class ReducePartitionCommitHandler( .maximumSize(rpcCacheSize) .build().asInstanceOf[Cache[Int, ByteBuffer]] + private val newShuffleId2PushFailedBatchMapFunc + : function.Function[Int, util.HashMap[String, util.Set[PushFailedBatch]]] = + new util.function.Function[Int, util.HashMap[String, util.Set[PushFailedBatch]]]() { + override def apply(s: Int): util.HashMap[String, util.Set[PushFailedBatch]] = { + new util.HashMap[String, util.Set[PushFailedBatch]]() + } + } + + private val uniqueId2PushFailedBatchMapFunc + : function.Function[String, util.Set[PushFailedBatch]] = + new util.function.Function[String, util.Set[PushFailedBatch]]() { + override def apply(s: String): util.Set[PushFailedBatch] = { + Sets.newHashSet[PushFailedBatch]() + } + } + override def getPartitionType(): PartitionType = { PartitionType.REDUCE } @@ -240,6 +259,7 @@ class ReducePartitionCommitHandler( attemptId: Int, numMappers: Int, partitionId: Int, + pushFailedBatches: util.Map[String, 
util.Set[PushFailedBatch]], recordWorkerFailure: ShuffleFailedWorkers => Unit): (Boolean, Boolean) = { shuffleMapperAttempts.synchronized { if (getMapperAttempts(shuffleId) == null) { @@ -250,6 +270,17 @@ class ReducePartitionCommitHandler( val attempts = shuffleMapperAttempts.get(shuffleId) if (attempts(mapId) < 0) { attempts(mapId) = attemptId + if (null != pushFailedBatches && !pushFailedBatches.isEmpty) { + val pushFailedBatchesMap = shufflePushFailedBatches.computeIfAbsent( + shuffleId, + newShuffleId2PushFailedBatchMapFunc) + for ((partitionUniqId, pushFailedBatchSet) <- pushFailedBatches.asScala) { + val partitionPushFailedBatches = pushFailedBatchesMap.computeIfAbsent( + partitionUniqId, + uniqueId2PushFailedBatchMapFunc) + partitionPushFailedBatches.addAll(pushFailedBatchSet) + } + } // Mapper with this attemptId finished, also check all other mapper finished or not. (true, ClientUtils.areAllMapperAttemptsFinished(attempts)) } else { @@ -301,7 +332,11 @@ class ReducePartitionCommitHandler( val returnedMsg = GetReducerFileGroupResponse( StatusCode.SUCCESS, reducerFileGroupsMap.getOrDefault(shuffleId, JavaUtils.newConcurrentHashMap()), - getMapperAttempts(shuffleId)) + getMapperAttempts(shuffleId), + pushFailedBatches = + shufflePushFailedBatches.getOrDefault( + shuffleId, + new util.HashMap[String, util.Set[PushFailedBatch]]())) context.asInstanceOf[RemoteNettyRpcCallContext].nettyEnv.serialize(returnedMsg) } }) diff --git a/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java b/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java index 886aff3dc8e..7053937a06a 100644 --- a/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java +++ b/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java @@ -423,7 +423,11 @@ public void testUpdateReducerFileGroupInterrupted() throws InterruptedException t -> { Thread.sleep(60 * 1000); return GetReducerFileGroupResponse$.MODULE$.apply( - StatusCode.SUCCESS, locations, new int[0], Collections.emptySet()); + StatusCode.SUCCESS, + locations, + new int[0], + Collections.emptySet(), + Collections.emptyMap()); }); when(endpointRef.askSync(any(), any(), any(Integer.class), any(Long.class), any())) @@ -470,7 +474,11 @@ public void testUpdateReducerFileGroupNonFetchFailureExceptions() { .thenAnswer( t -> { return GetReducerFileGroupResponse$.MODULE$.apply( - StatusCode.SHUFFLE_NOT_REGISTERED, locations, new int[0], Collections.emptySet()); + StatusCode.SHUFFLE_NOT_REGISTERED, + locations, + new int[0], + Collections.emptySet(), + Collections.emptyMap()); }); when(endpointRef.askSync(any(), any(), any(Integer.class), any(Long.class), any())) @@ -494,7 +502,11 @@ public void testUpdateReducerFileGroupNonFetchFailureExceptions() { .thenAnswer( t -> { return GetReducerFileGroupResponse$.MODULE$.apply( - StatusCode.STAGE_END_TIME_OUT, locations, new int[0], Collections.emptySet()); + StatusCode.STAGE_END_TIME_OUT, + locations, + new int[0], + Collections.emptySet(), + Collections.emptyMap()); }); when(endpointRef.askSync(any(), any(), any(Integer.class), any(Long.class), any())) @@ -518,7 +530,11 @@ public void testUpdateReducerFileGroupNonFetchFailureExceptions() { .thenAnswer( t -> { return GetReducerFileGroupResponse$.MODULE$.apply( - StatusCode.SHUFFLE_DATA_LOST, locations, new int[0], Collections.emptySet()); + StatusCode.SHUFFLE_DATA_LOST, + locations, + new int[0], + Collections.emptySet(), + Collections.emptyMap()); }); when(endpointRef.askSync(any(), any(), any(Integer.class), 
any(Long.class), any())) diff --git a/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala b/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala index 8cd8bedf76d..0570760ce14 100644 --- a/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala +++ b/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala @@ -41,7 +41,7 @@ trait WithShuffleClientSuite extends CelebornFunSuite { private val attemptId = 0 private var lifecycleManager: LifecycleManager = _ - private var shuffleClient: ShuffleClientImpl = _ + protected var shuffleClient: ShuffleClientImpl = _ var _shuffleId = 0 def nextShuffleId: Int = { @@ -164,6 +164,8 @@ trait WithShuffleClientSuite extends CelebornFunSuite { null, null, null, + null, + null, metricsCallback) Assert.assertEquals(stream.read(), -1) @@ -180,6 +182,8 @@ trait WithShuffleClientSuite extends CelebornFunSuite { null, null, null, + null, + null, metricsCallback) Assert.assertEquals(stream.read(), -1) } diff --git a/common/src/main/java/org/apache/celeborn/common/protocol/StorageInfo.java b/common/src/main/java/org/apache/celeborn/common/protocol/StorageInfo.java index 28cb652565b..8509d5717b2 100644 --- a/common/src/main/java/org/apache/celeborn/common/protocol/StorageInfo.java +++ b/common/src/main/java/org/apache/celeborn/common/protocol/StorageInfo.java @@ -68,6 +68,10 @@ public int getValue() { public int availableStorageTypes = 0; + public long fileSize; + + public List chunkOffsets; + public StorageInfo() {} public StorageInfo(Type type, boolean isFinal, String filePath) { @@ -95,6 +99,23 @@ public StorageInfo( this.availableStorageTypes = availableStorageTypes; } + public StorageInfo( + Type type, + String mountPoint, + boolean finalResult, + String filePath, + int availableStorageTypes, + long fileSize, + List chunkOffsets) { + this.type = type; + this.mountPoint = mountPoint; + this.finalResult = finalResult; + this.filePath = filePath; + this.availableStorageTypes = availableStorageTypes; + this.fileSize = fileSize; + this.chunkOffsets = chunkOffsets; + } + public boolean isFinalResult() { return finalResult; } @@ -119,6 +140,22 @@ public String getFilePath() { return filePath; } + public void setChunkOffsets(List chunkOffsets) { + this.chunkOffsets = chunkOffsets; + } + + public List getChunkOffsets() { + return this.chunkOffsets; + } + + public void setFileSize(long fileSize) { + this.fileSize = fileSize; + } + + public long getFileSize() { + return fileSize; + } + @Override public String toString() { return "StorageInfo{" @@ -131,6 +168,10 @@ public String toString() { + finalResult + ", filePath=" + filePath + + ", fileSize=" + + fileSize + + ", chunkOffsets=" + + chunkOffsets + '}'; } @@ -215,7 +256,11 @@ public static PbStorageInfo toPb(StorageInfo storageInfo) { .setType(storageInfo.type.value) .setFinalResult(storageInfo.finalResult) .setMountPoint(storageInfo.mountPoint) - .setAvailableStorageTypes(storageInfo.availableStorageTypes); + .setAvailableStorageTypes(storageInfo.availableStorageTypes) + .setFileSize(storageInfo.getFileSize()); + if (storageInfo.getChunkOffsets() != null) { + builder.addAllChunkOffsets(storageInfo.getChunkOffsets()); + } if (filePath != null) { builder.setFilePath(filePath); } @@ -228,7 +273,9 @@ public static StorageInfo fromPb(PbStorageInfo pbStorageInfo) { pbStorageInfo.getMountPoint(), pbStorageInfo.getFinalResult(), pbStorageInfo.getFilePath(), - pbStorageInfo.getAvailableStorageTypes()); + 
diff --git a/common/src/main/java/org/apache/celeborn/common/protocol/StorageInfo.java b/common/src/main/java/org/apache/celeborn/common/protocol/StorageInfo.java
index 28cb652565b..8509d5717b2 100644
--- a/common/src/main/java/org/apache/celeborn/common/protocol/StorageInfo.java
+++ b/common/src/main/java/org/apache/celeborn/common/protocol/StorageInfo.java
@@ -68,6 +68,10 @@ public int getValue() {
   public int availableStorageTypes = 0;
+  public long fileSize;
+
+  public List<Long> chunkOffsets;
+
   public StorageInfo() {}
   public StorageInfo(Type type, boolean isFinal, String filePath) {
@@ -95,6 +99,23 @@ public StorageInfo(
     this.availableStorageTypes = availableStorageTypes;
   }
+  public StorageInfo(
+      Type type,
+      String mountPoint,
+      boolean finalResult,
+      String filePath,
+      int availableStorageTypes,
+      long fileSize,
+      List<Long> chunkOffsets) {
+    this.type = type;
+    this.mountPoint = mountPoint;
+    this.finalResult = finalResult;
+    this.filePath = filePath;
+    this.availableStorageTypes = availableStorageTypes;
+    this.fileSize = fileSize;
+    this.chunkOffsets = chunkOffsets;
+  }
+
   public boolean isFinalResult() {
     return finalResult;
   }
@@ -119,6 +140,22 @@ public String getFilePath() {
     return filePath;
   }
+  public void setChunkOffsets(List<Long> chunkOffsets) {
+    this.chunkOffsets = chunkOffsets;
+  }
+
+  public List<Long> getChunkOffsets() {
+    return this.chunkOffsets;
+  }
+
+  public void setFileSize(long fileSize) {
+    this.fileSize = fileSize;
+  }
+
+  public long getFileSize() {
+    return fileSize;
+  }
+
   @Override
   public String toString() {
     return "StorageInfo{"
@@ -131,6 +168,10 @@ public String toString() {
         + finalResult
         + ", filePath="
         + filePath
+        + ", fileSize="
+        + fileSize
+        + ", chunkOffsets="
+        + chunkOffsets
         + '}';
   }
@@ -215,7 +256,11 @@ public static PbStorageInfo toPb(StorageInfo storageInfo) {
             .setType(storageInfo.type.value)
             .setFinalResult(storageInfo.finalResult)
             .setMountPoint(storageInfo.mountPoint)
-            .setAvailableStorageTypes(storageInfo.availableStorageTypes);
+            .setAvailableStorageTypes(storageInfo.availableStorageTypes)
+            .setFileSize(storageInfo.getFileSize());
+    if (storageInfo.getChunkOffsets() != null) {
+      builder.addAllChunkOffsets(storageInfo.getChunkOffsets());
+    }
     if (filePath != null) {
       builder.setFilePath(filePath);
     }
@@ -228,7 +273,9 @@ public static StorageInfo fromPb(PbStorageInfo pbStorageInfo) {
         pbStorageInfo.getMountPoint(),
         pbStorageInfo.getFinalResult(),
         pbStorageInfo.getFilePath(),
-        pbStorageInfo.getAvailableStorageTypes());
+        pbStorageInfo.getAvailableStorageTypes(),
+        pbStorageInfo.getFileSize(),
+        pbStorageInfo.getChunkOffsetsList());
   }
   public static int getAvailableTypes(List<Type> types) {
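
The two new StorageInfo fields describe the committed partition file: chunkOffsets holds chunk boundaries, beginning with 0 and ending at the file length, so chunk i spans bytes [offsets(i), offsets(i + 1)). A purely illustrative helper (not part of this patch) for mapping a byte range onto chunk indices:

    // Illustrative only: which chunks overlap bytes [start, end)?
    def chunksForByteRange(
        offsets: java.util.List[java.lang.Long],
        start: Long,
        end: Long): Seq[Int] =
      (0 until offsets.size() - 1).filter { i =>
        offsets.get(i) < end && offsets.get(i + 1) > start
      }
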
diff --git a/common/src/main/java/org/apache/celeborn/common/write/PushFailedBatch.java b/common/src/main/java/org/apache/celeborn/common/write/PushFailedBatch.java
new file mode 100644
index 00000000000..ccee8bf1113
--- /dev/null
+++ b/common/src/main/java/org/apache/celeborn/common/write/PushFailedBatch.java
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.write;
+
+import java.io.Serializable;
+
+import com.google.common.base.Objects;
+import org.apache.commons.lang3.builder.ToStringBuilder;
+import org.apache.commons.lang3.builder.ToStringStyle;
+
+public class PushFailedBatch implements Serializable {
+
+  private int mapId;
+  private int attemptId;
+  private int batchId;
+
+  public PushFailedBatch(int mapId, int attemptId, int batchId) {
+    this.mapId = mapId;
+    this.attemptId = attemptId;
+    this.batchId = batchId;
+  }
+
+  public int getMapId() {
+    return mapId;
+  }
+
+  public void setMapId(int mapId) {
+    this.mapId = mapId;
+  }
+
+  public int getAttemptId() {
+    return attemptId;
+  }
+
+  public void setAttemptId(int attemptId) {
+    this.attemptId = attemptId;
+  }
+
+  public int getBatchId() {
+    return batchId;
+  }
+
+  public void setBatchId(int batchId) {
+    this.batchId = batchId;
+  }
+
+  @Override
+  public boolean equals(Object other) {
+    if (!(other instanceof PushFailedBatch)) {
+      return false;
+    }
+    PushFailedBatch o = (PushFailedBatch) other;
+    return mapId == o.mapId && attemptId == o.attemptId && batchId == o.batchId;
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hashCode(mapId, attemptId, batchId);
+  }
+
+  @Override
+  public String toString() {
+    return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
+        .append("mapId", mapId)
+        .append("attemptId", attemptId)
+        .append("batchId", batchId)
+        .toString();
+  }
+}
diff --git a/common/src/main/java/org/apache/celeborn/common/write/PushState.java b/common/src/main/java/org/apache/celeborn/common/write/PushState.java
index 3979cafd632..9f691094bd9 100644
--- a/common/src/main/java/org/apache/celeborn/common/write/PushState.java
+++ b/common/src/main/java/org/apache/celeborn/common/write/PushState.java
@@ -18,9 +18,12 @@
 package org.apache.celeborn.common.write;
 
 import java.io.IOException;
+import java.util.Map;
+import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.AtomicReference;
 
+import com.google.common.collect.Sets;
 import org.apache.commons.lang3.tuple.Pair;
 
 import org.apache.celeborn.common.CelebornConf;
@@ -33,9 +36,12 @@ public class PushState {
   public AtomicReference<IOException> exception = new AtomicReference<>();
   private final InFlightRequestTracker inFlightRequestTracker;
+  private final Map<String, Set<PushFailedBatch>> failedBatchMap;
+
   public PushState(CelebornConf conf) {
     pushBufferMaxSize = conf.clientPushBufferMaxSize();
     inFlightRequestTracker = new InFlightRequestTracker(conf, this);
+    failedBatchMap = new ConcurrentHashMap<>();
   }
 
   public void cleanup() {
@@ -88,4 +94,14 @@ public boolean limitZeroInFlight() throws IOException {
   public int remainingAllowPushes(String hostAndPushPort) {
     return inFlightRequestTracker.remainingAllowPushes(hostAndPushPort);
   }
+
+  public void addFailedBatch(String partitionId, PushFailedBatch failedBatch) {
+    this.failedBatchMap
+        .computeIfAbsent(partitionId, (s) -> Sets.newConcurrentHashSet())
+        .add(failedBatch);
+  }
+
+  public Map<String, Set<PushFailedBatch>> getFailedBatches() {
+    return this.failedBatchMap;
+  }
 }
diff --git a/common/src/main/proto/TransportMessages.proto b/common/src/main/proto/TransportMessages.proto
index 8aa59bcb29b..553e95aa79e 100644
--- a/common/src/main/proto/TransportMessages.proto
+++ b/common/src/main/proto/TransportMessages.proto
@@ -134,6 +134,8 @@ message PbStorageInfo {
   bool finalResult = 3;
   string filePath = 4;
   int32 availableStorageTypes = 5;
+  int64 fileSize = 6;
+  repeated int64 chunkOffsets = 7;
 }
 message PbPartitionLocation {
@@ -354,6 +356,17 @@ message PbMapperEnd {
   int32 attemptId = 3;
   int32 numMappers = 4;
   int32 partitionId = 5;
+  map<string, PbPushFailedBatchSet> pushFailureBatches = 6;
+}
+
+message PbPushFailedBatchSet {
+  repeated PbPushFailedBatch failureBatches = 1;
+}
+
+message PbPushFailedBatch {
+  int32 mapId = 1;
+  int32 attemptId = 2;
+  int32 batchId = 3;
 }
 message PbMapperEndResponse {
@@ -375,6 +388,8 @@ message PbGetReducerFileGroupResponse {
   // only map partition mode has succeed partitionIds
   repeated int32 partitionIds = 4;
+
+  map<string, PbPushFailedBatchSet> pushFailedBatches = 5;
 }
 message PbGetShuffleId {
@@ -850,6 +865,8 @@ message PbPackedPartitionLocations {
   repeated string filePaths = 10;
   repeated int32 availableStorageTypes = 11;
   repeated int32 modes = 12;
+  repeated int64 fileSizes = 13;
+  repeated PbChunkOffsets chunksOffsets = 14;
 }
 message PbPackedPartitionLocationsPair {
@@ -883,3 +900,7 @@ message PbPushMergedDataSplitPartitionInfo {
   repeated int32 splitPartitionIndexes = 1;
   repeated int32 statusCodes = 2;
 }
+
+message PbChunkOffsets {
+  repeated int64 chunkOffset = 1;
+}
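
Client-side usage of the PushState additions above, as a minimal sketch (the "0-0" partition uniqueId and batch numbers are illustrative):

    import org.apache.celeborn.common.CelebornConf
    import org.apache.celeborn.common.write.{PushFailedBatch, PushState}

    val pushState = new PushState(new CelebornConf())
    // Record that batch 1 of map 0 / attempt 0 failed for partition "0-0".
    pushState.addFailedBatch("0-0", new PushFailedBatch(0, 0, 1))
    assert(pushState.getFailedBatches.get("0-0").size() == 1)
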
diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
index ad9b3538169..0eae8a3f762 100644
--- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
@@ -1038,6 +1038,8 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
   def clientPushSendBufferPoolExpireTimeout: Long = get(CLIENT_PUSH_SENDBUFFERPOOL_EXPIRETIMEOUT)
   def clientPushSendBufferPoolExpireCheckInterval: Long =
     get(CLIENT_PUSH_SENDBUFFERPOOL_CHECKEXPIREINTERVAL)
+  def clientAdaptiveOptimizeSkewedPartitionReadEnabled: Boolean =
+    get(CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ_ENABLED)
 
   // //////////////////////////////////////////////////////
   //                    Client Shuffle                    //
@@ -5932,6 +5934,15 @@ object CelebornConf extends Logging {
       .intConf
       .createWithDefault(10000)
+  val CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ_ENABLED: ConfigEntry[Boolean] =
+    buildConf("celeborn.client.adaptive.optimizeSkewedPartitionRead.enabled")
+      .categories("client")
+      .version("0.6.0")
+      .doc("If this is true, Celeborn will adaptively split skewed partitions instead of reading them by Spark map " +
+        "range. Please note that this feature requires the `Celeborn-Optimize-Skew-Partitions-spark3_3.patch`.")
+      .booleanConf
+      .createWithDefault(false)
+
   // SSL Configs
 
   val SSL_ENABLED: ConfigEntry[Boolean] =
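
A usage sketch for the new switch (it mirrors the integration tests later in this change; per the doc text it only takes effect with the corresponding Spark-side patch applied, and cluster-specific settings are omitted):

    import org.apache.spark.SparkConf

    val sparkConf = new SparkConf()
      .set("spark.shuffle.manager", "org.apache.spark.shuffle.celeborn.SparkShuffleManager")
      .set("spark.celeborn.client.adaptive.optimizeSkewedPartitionRead.enabled", "true")
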
diff --git a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
index 0e465196d28..e730b9b7afb 100644
--- a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
@@ -32,6 +32,7 @@ import org.apache.celeborn.common.protocol._
 import org.apache.celeborn.common.protocol.MessageType._
 import org.apache.celeborn.common.quota.ResourceConsumption
 import org.apache.celeborn.common.util.{PbSerDeUtils, Utils}
+import org.apache.celeborn.common.write.PushFailedBatch
 
 sealed trait Message extends Serializable
@@ -271,7 +272,8 @@ object ControlMessages extends Logging {
       mapId: Int,
       attemptId: Int,
       numMappers: Int,
-      partitionId: Int)
+      partitionId: Int,
+      failedBatchSet: util.Map[String, util.Set[PushFailedBatch]])
     extends MasterMessage
 
   case class MapperEndResponse(status: StatusCode) extends MasterMessage
@@ -285,7 +287,8 @@ object ControlMessages extends Logging {
       status: StatusCode,
       fileGroup: util.Map[Integer, util.Set[PartitionLocation]],
       attempts: Array[Int],
-      partitionIds: util.Set[Integer] = Collections.emptySet[Integer]())
+      partitionIds: util.Set[Integer] = Collections.emptySet[Integer](),
+      pushFailedBatches: util.Map[String, util.Set[PushFailedBatch]] = Collections.emptyMap())
     extends MasterMessage
 
   object WorkerExclude {
@@ -721,13 +724,18 @@ object ControlMessages extends Logging {
     case pb: PbChangeLocationResponse =>
       new TransportMessage(MessageType.CHANGE_LOCATION_RESPONSE, pb.toByteArray)
 
-    case MapperEnd(shuffleId, mapId, attemptId, numMappers, partitionId) =>
+    case MapperEnd(shuffleId, mapId, attemptId, numMappers, partitionId, pushFailedBatch) =>
+      val pushFailedMap = pushFailedBatch.asScala.map { case (k, v) =>
+        val resultValue = PbSerDeUtils.toPbPushFailedBatchSet(v)
+        (k, resultValue)
+      }.toMap.asJava
       val payload = PbMapperEnd.newBuilder()
         .setShuffleId(shuffleId)
         .setMapId(mapId)
         .setAttemptId(attemptId)
         .setNumMappers(numMappers)
         .setPartitionId(partitionId)
+        .putAllPushFailureBatches(pushFailedMap)
         .build().toByteArray
       new TransportMessage(MessageType.MAPPER_END, payload)
@@ -744,7 +752,7 @@ object ControlMessages extends Logging {
         .build().toByteArray
       new TransportMessage(MessageType.GET_REDUCER_FILE_GROUP, payload)
 
-    case GetReducerFileGroupResponse(status, fileGroup, attempts, partitionIds) =>
+    case GetReducerFileGroupResponse(status, fileGroup, attempts, partitionIds, failedBatches) =>
       val builder = PbGetReducerFileGroupResponse
         .newBuilder()
         .setStatus(status.getValue)
@@ -757,6 +765,11 @@ object ControlMessages extends Logging {
         }.asJava)
       builder.addAllAttempts(attempts.map(Integer.valueOf).toIterable.asJava)
       builder.addAllPartitionIds(partitionIds)
+      builder.putAllPushFailedBatches(
+        failedBatches.asScala.map {
+          case (uniqueId, pushFailedBatchSet) =>
+            (uniqueId, PbSerDeUtils.toPbPushFailedBatchSet(pushFailedBatchSet))
+        }.asJava)
       val payload = builder.build().toByteArray
       new TransportMessage(MessageType.GET_REDUCER_FILE_GROUP_RESPONSE, payload)
@@ -1141,7 +1154,11 @@ object ControlMessages extends Logging {
         pbMapperEnd.getMapId,
         pbMapperEnd.getAttemptId,
         pbMapperEnd.getNumMappers,
-        pbMapperEnd.getPartitionId)
+        pbMapperEnd.getPartitionId,
+        pbMapperEnd.getPushFailureBatchesMap.asScala.map {
+          case (partitionId, pushFailedBatchSet) =>
+            (partitionId, PbSerDeUtils.fromPbPushFailedBatchSet(pushFailedBatchSet))
+        }.toMap.asJava)
 
     case MAPPER_END_RESPONSE_VALUE =>
       val pbMapperEndResponse = PbMapperEndResponse.parseFrom(message.getPayload)
@@ -1177,11 +1194,16 @@ object ControlMessages extends Logging {
       val attempts = pbGetReducerFileGroupResponse.getAttemptsList.asScala.map(_.toInt).toArray
       val partitionIds = new util.HashSet(pbGetReducerFileGroupResponse.getPartitionIdsList)
+      val pushFailedBatches = pbGetReducerFileGroupResponse.getPushFailedBatchesMap.asScala.map {
+        case (uniqueId, pushFailedBatchSet) =>
+          (uniqueId, PbSerDeUtils.fromPbPushFailedBatchSet(pushFailedBatchSet))
+      }.toMap.asJava
       GetReducerFileGroupResponse(
         Utils.toStatusCode(pbGetReducerFileGroupResponse.getStatus),
         fileGroup,
         attempts,
-        partitionIds)
+        partitionIds,
+        pushFailedBatches)
 
     case GET_SHUFFLE_ID_VALUE =>
       message.getParsedPayload()
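
Round-trip sketch for the extended MapperEnd message (this is essentially what the UtilsSuite case later in this change verifies):

    import java.util.Collections

    import org.apache.celeborn.common.protocol.message.ControlMessages.MapperEnd
    import org.apache.celeborn.common.util.Utils

    val mapperEnd = MapperEnd(1, 1, 1, 2, 1, Collections.emptyMap())
    val decoded =
      Utils.fromTransportMessage(Utils.toTransportMessage(mapperEnd)).asInstanceOf[MapperEnd]
    assert(mapperEnd == decoded)
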
diff --git a/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala b/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala
index 8a038242e75..553a4b9f9e0 100644
--- a/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala
@@ -32,6 +32,7 @@ import org.apache.celeborn.common.protocol.PartitionLocation.Mode
 import org.apache.celeborn.common.protocol.message.ControlMessages.WorkerResource
 import org.apache.celeborn.common.quota.ResourceConsumption
 import org.apache.celeborn.common.util.{CollectionUtils => localCollectionUtils}
+import org.apache.celeborn.common.write.PushFailedBatch
 
 object PbSerDeUtils {
@@ -524,6 +525,14 @@ object PbSerDeUtils {
       pbPackedLocationsBuilder.addFilePaths("")
     }
     pbPackedLocationsBuilder.addAvailableStorageTypes(location.getStorageInfo.availableStorageTypes)
+    pbPackedLocationsBuilder.addFileSizes(location.getStorageInfo.getFileSize)
+    val chunkOffsets = PbChunkOffsets.newBuilder()
+    if (null != location.getStorageInfo.chunkOffsets && !location.getStorageInfo.chunkOffsets.isEmpty) {
+      chunkOffsets.addAllChunkOffset(location.getStorageInfo.chunkOffsets).build()
+      pbPackedLocationsBuilder.addChunksOffsets(chunkOffsets)
+    } else {
+      pbPackedLocationsBuilder.addChunksOffsets(chunkOffsets.build())
+    }
     pbPackedLocationsBuilder.addModes(location.getMode.mode())
   }
@@ -640,7 +649,9 @@ object PbSerDeUtils {
           pbPackedPartitionLocations.getMountPoints(index)),
         pbPackedPartitionLocations.getFinalResult(index),
         filePath,
-        pbPackedPartitionLocations.getAvailableStorageTypes(index)),
+        pbPackedPartitionLocations.getAvailableStorageTypes(index),
+        pbPackedPartitionLocations.getFileSizes(index),
+        pbPackedPartitionLocations.getChunksOffsets(index).getChunkOffsetList),
       Utils.byteStringToRoaringBitmap(pbPackedPartitionLocations.getMapIdBitMap(index)))
   }
@@ -670,4 +681,34 @@ object PbSerDeUtils {
     }.asJava
   }
+  def toPbPushFailedBatch(pushFailedBatch: PushFailedBatch): PbPushFailedBatch = {
+    PbPushFailedBatch.newBuilder()
+      .setMapId(pushFailedBatch.getMapId)
+      .setAttemptId(pushFailedBatch.getAttemptId)
+      .setBatchId(pushFailedBatch.getBatchId)
+      .build()
+  }
+
+  def fromPbPushFailedBatch(pbPushFailedBatch: PbPushFailedBatch): PushFailedBatch = {
+    new PushFailedBatch(
+      pbPushFailedBatch.getMapId,
+      pbPushFailedBatch.getAttemptId,
+      pbPushFailedBatch.getBatchId)
+  }
+
+  def toPbPushFailedBatchSet(failedBatchSet: util.Set[PushFailedBatch]): PbPushFailedBatchSet = {
+    val builder = PbPushFailedBatchSet.newBuilder()
+    failedBatchSet.asScala.foreach(batch => builder.addFailureBatches(toPbPushFailedBatch(batch)))
+
+    builder.build()
+  }
+
+  def fromPbPushFailedBatchSet(pbFailedBatchSet: PbPushFailedBatchSet)
+      : util.Set[PushFailedBatch] = {
+    val failedBatchSet = new util.HashSet[PushFailedBatch]()
+    pbFailedBatchSet.getFailureBatchesList.asScala.foreach(batch =>
+      failedBatchSet.add(fromPbPushFailedBatch(batch)))
+
+    failedBatchSet
+  }
 }
diff --git a/common/src/test/java/org/apache/celeborn/common/protocol/PartitionLocationSuiteJ.java b/common/src/test/java/org/apache/celeborn/common/protocol/PartitionLocationSuiteJ.java
index 6f63e487056..0d0613a33ec 100644
--- a/common/src/test/java/org/apache/celeborn/common/protocol/PartitionLocationSuiteJ.java
+++ b/common/src/test/java/org/apache/celeborn/common/protocol/PartitionLocationSuiteJ.java
@@ -209,7 +209,7 @@ public void testToStringOutput() {
             + "  host-rpcPort-pushPort-fetchPort-replicatePort:localhost-3-1-2-4\n"
             + "  mode:PRIMARY\n"
             + "  peer:(empty)\n"
-            + "  storage hint:StorageInfo{type=MEMORY, mountPoint='', finalResult=false, filePath=null}\n"
+            + "  storage hint:StorageInfo{type=MEMORY, mountPoint='', finalResult=false, filePath=null, fileSize=0, chunkOffsets=null}\n"
            + "  mapIdBitMap:{}]";
     String exp2 =
         "PartitionLocation[\n"
@@ -217,7 +217,7 @@ public void testToStringOutput() {
             + "  host-rpcPort-pushPort-fetchPort-replicatePort:localhost-3-1-2-4\n"
             + "  mode:PRIMARY\n"
             + "  peer:(host-rpcPort-pushPort-fetchPort-replicatePort:localhost-3-1-2-4)\n"
-            + "  storage hint:StorageInfo{type=MEMORY, mountPoint='', finalResult=false, filePath=null}\n"
+            + "  storage hint:StorageInfo{type=MEMORY, mountPoint='', finalResult=false, filePath=null, fileSize=0, chunkOffsets=null}\n"
            + "  mapIdBitMap:{}]";
     String exp3 =
         "PartitionLocation[\n"
@@ -225,7 +225,8 @@ public void testToStringOutput() {
             + "  host-rpcPort-pushPort-fetchPort-replicatePort:localhost-3-1-2-4\n"
             + "  mode:PRIMARY\n"
             + "  peer:(host-rpcPort-pushPort-fetchPort-replicatePort:localhost-3-1-2-4)\n"
-            + "  storage hint:StorageInfo{type=MEMORY, mountPoint='/mnt/disk/0', finalResult=false, filePath=null}\n"
+            + "  storage hint:StorageInfo{type=MEMORY, mountPoint='/mnt/disk/0', "
+            + "finalResult=false, filePath=null, fileSize=0, chunkOffsets=null}\n"
             + "  mapIdBitMap:{1,2,3}]";
     assertEquals(exp1, location1.toString());
     assertEquals(exp2, location2.toString());
diff --git a/common/src/test/java/org/apache/celeborn/common/write/PushFailedBatchSuiteJ.java b/common/src/test/java/org/apache/celeborn/common/write/PushFailedBatchSuiteJ.java
new file mode 100644
index 00000000000..fcfc6b79979
--- /dev/null
+++ b/common/src/test/java/org/apache/celeborn/common/write/PushFailedBatchSuiteJ.java
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.write;
+
+import java.util.HashSet;
+import java.util.Set;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class PushFailedBatchSuiteJ {
+
+  @Test
+  public void equalsReturnsTrueForIdenticalBatches() {
+    PushFailedBatch batch1 = new PushFailedBatch(1, 2, 3);
+    PushFailedBatch batch2 = new PushFailedBatch(1, 2, 3);
+    Assert.assertEquals(batch1, batch2);
+  }
+
+  @Test
+  public void equalsReturnsFalseForDifferentBatches() {
+    PushFailedBatch batch1 = new PushFailedBatch(1, 2, 3);
+    PushFailedBatch batch2 = new PushFailedBatch(4, 5, 6);
+    Assert.assertNotEquals(batch1, batch2);
+  }
+
+  @Test
+  public void hashCodeDiffersForDifferentBatches() {
+    PushFailedBatch batch1 = new PushFailedBatch(1, 2, 3);
+    PushFailedBatch batch2 = new PushFailedBatch(4, 5, 6);
+    Assert.assertNotEquals(batch1.hashCode(), batch2.hashCode());
+  }
+
+  @Test
+  public void hashCodeSameForIdenticalBatches() {
+    PushFailedBatch batch1 = new PushFailedBatch(1, 2, 3);
+    PushFailedBatch batch2 = new PushFailedBatch(1, 2, 3);
+    Assert.assertEquals(batch1.hashCode(), batch2.hashCode());
+  }
+
+  @Test
+  public void hashCodeIsConsistent() {
+    PushFailedBatch batch = new PushFailedBatch(1, 2, 3);
+    int hashCode1 = batch.hashCode();
+    int hashCode2 = batch.hashCode();
+    Assert.assertEquals(hashCode1, hashCode2);
+  }
+
+  @Test
+  public void toStringReturnsExpectedFormat() {
+    PushFailedBatch batch = new PushFailedBatch(1, 2, 3);
+    String expected = "PushFailedBatch[mapId=1,attemptId=2,batchId=3]";
+    Assert.assertEquals(expected, batch.toString());
+  }
+
+  @Test
+  public void hashCodeAndEqualsWorkInSet() {
+    Set<PushFailedBatch> set = new HashSet<>();
+    PushFailedBatch batch1 = new PushFailedBatch(1, 2, 3);
+    PushFailedBatch batch2 = new PushFailedBatch(1, 2, 3);
+    set.add(batch1);
+    Assert.assertTrue(set.contains(batch2));
+  }
+}
diff --git a/common/src/test/scala/org/apache/celeborn/common/util/PbSerDeUtilsTest.scala b/common/src/test/scala/org/apache/celeborn/common/util/PbSerDeUtilsTest.scala
index 9e1d442fb24..c2c6e47338b 100644
--- a/common/src/test/scala/org/apache/celeborn/common/util/PbSerDeUtilsTest.scala
+++ b/common/src/test/scala/org/apache/celeborn/common/util/PbSerDeUtilsTest.scala
@@ -24,6 +24,7 @@ import scala.collection.JavaConverters._
 import scala.collection.mutable
 import scala.util.Random
+import com.google.common.collect.{Lists, Sets}
 import org.apache.hadoop.shaded.org.apache.commons.lang3.RandomStringUtils
 
 import org.apache.celeborn.CelebornFunSuite
@@ -35,6 +36,7 @@ import org.apache.celeborn.common.protocol.message.{ControlMessages, StatusCode}
 import org.apache.celeborn.common.protocol.message.ControlMessages.{GetReducerFileGroupResponse, WorkerResource}
 import org.apache.celeborn.common.quota.ResourceConsumption
 import org.apache.celeborn.common.util.PbSerDeUtils.{fromPbPackedPartitionLocationsPair, toPbPackedPartitionLocationsPair}
+import org.apache.celeborn.common.write.PushFailedBatch
 
 class PbSerDeUtilsTest extends CelebornFunSuite {
@@ -172,6 +174,37 @@ class PbSerDeUtilsTest extends CelebornFunSuite {
     27,
     PartitionLocation.Mode.PRIMARY)
+  val partitionLocation5 =
+    new PartitionLocation(
+      4,
+      4,
+      "host5",
+      50,
+      49,
+      48,
+      47,
+      PartitionLocation.Mode.PRIMARY)
+
+  val partitionLocation6 =
+    new PartitionLocation(
+      5,
+      5,
+      "host6",
+      60,
+      59,
+      58,
+      57,
+      PartitionLocation.Mode.REPLICA,
+      null,
+      new StorageInfo(
+        StorageInfo.Type.HDD,
+        "",
+        false,
+        null,
+        StorageInfo.LOCAL_DISK_MASK,
+        5,
+        null),
+      null)
+
   val workerResource = new WorkerResource()
   workerResource.put(
     workerInfo1,
@@ -369,6 +402,70 @@ class PbSerDeUtilsTest extends CelebornFunSuite {
     assert(partitionLocation3 == loc1)
   }
+
+  test("testPackedPartitionLocationPairCase3") {
+    partitionLocation5.setStorageInfo(new StorageInfo(
+      StorageInfo.Type.HDD,
+      "",
+      false,
+      null,
+      StorageInfo.LOCAL_DISK_MASK,
+      5,
+      Lists.newArrayList[java.lang.Long](0L, 5L, 10L)))
+    partitionLocation5.setPeer(partitionLocation6)
+    val pairPb = PbSerDeUtils.toPbPackedPartitionLocationsPair(
+      List(partitionLocation5, partitionLocation6))
+    val rePb = PbSerDeUtils.fromPbPackedPartitionLocationsPair(pairPb)
+
+    val loc1 = rePb._1.get(0)
+    val loc2 = rePb._2.get(0)
+
+    assert(partitionLocation5 == loc1)
+    assert(partitionLocation6 == loc2)
+    assert(loc1.getStorageInfo.getFileSize == partitionLocation5.getStorageInfo.getFileSize)
+    assert(loc1.getStorageInfo.getChunkOffsets == partitionLocation5.getStorageInfo.getChunkOffsets)
+
+    assert(loc2.getStorageInfo.getFileSize == partitionLocation6.getStorageInfo.getFileSize)
+    assert(loc2.getStorageInfo.getChunkOffsets.isEmpty)
+  }
+
+  test("testPackedPartitionLocationPairCase4") {
+    partitionLocation5.setStorageInfo(new StorageInfo(
+      StorageInfo.Type.HDD,
+      "",
+      false,
+      null,
+      StorageInfo.LOCAL_DISK_MASK,
+      5,
+      null))
+    val pairPb = PbSerDeUtils.toPbPackedPartitionLocationsPair(
+      List(partitionLocation5))
+    val rePb = PbSerDeUtils.fromPbPackedPartitionLocationsPair(pairPb)
+
+    val loc1 = rePb._1.get(0)
+
+    assert(partitionLocation5 == loc1)
+    assert(loc1.getStorageInfo.getFileSize == partitionLocation5.getStorageInfo.getFileSize)
+    assert(loc1.getStorageInfo.getChunkOffsets.isEmpty)
+  }
+
+  test("testPackedPartitionLocationPairCase5") {
+    partitionLocation5.setStorageInfo(new StorageInfo(
+      StorageInfo.Type.HDD,
+      "",
+      false,
+      null,
+      StorageInfo.LOCAL_DISK_MASK))
+    val pairPb = PbSerDeUtils.toPbPackedPartitionLocationsPair(
+      List(partitionLocation5))
+    val rePb = PbSerDeUtils.fromPbPackedPartitionLocationsPair(pairPb)
+
+    val loc1 = rePb._1.get(0)
+
+    assert(partitionLocation5 == loc1)
+    assert(loc1.getStorageInfo.getFileSize == partitionLocation5.getStorageInfo.getFileSize)
+    assert(loc1.getStorageInfo.getChunkOffsets.isEmpty)
+  }
+
   test("testPackedPartitionLocationPairIPv6") {
     val pairPb = PbSerDeUtils.toPbPackedPartitionLocationsPair(
       List(partitionLocationIPv6))
@@ -565,4 +662,21 @@ class PbSerDeUtilsTest extends CelebornFunSuite {
     locations.asScala.foreach(p => uniqueIds.remove(p.getUniqueId))
     assert(uniqueIds.isEmpty)
   }
+
+  test("fromAndToPushFailedBatch") {
+    val failedBatch = new PushFailedBatch(1, 1, 2)
+    val pbPushFailedBatch = PbSerDeUtils.toPbPushFailedBatch(failedBatch)
+    val restoredFailedBatch = PbSerDeUtils.fromPbPushFailedBatch(pbPushFailedBatch)
+
+    assert(restoredFailedBatch.equals(failedBatch))
+  }
+
+  test("fromAndToPushFailedBatchSet") {
+    val failedBatchSet = Sets.newHashSet(new PushFailedBatch(1, 1, 2), new PushFailedBatch(2, 2, 3))
+    val pbPushFailedBatchSet = PbSerDeUtils.toPbPushFailedBatchSet(failedBatchSet)
+    val restoredFailedBatchSet = PbSerDeUtils.fromPbPushFailedBatchSet(pbPushFailedBatchSet)
+
+    assert(restoredFailedBatchSet.equals(failedBatchSet))
+  }
+
 }
diff --git a/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala b/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala
index afc74707f62..03c6176edd8 100644
--- a/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala
+++ b/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.celeborn.common.util
 
 import java.util
+import java.util.Collections
 
 import org.apache.celeborn.CelebornFunSuite
 import org.apache.celeborn.common.CelebornConf
@@ -144,7 +145,7 @@ class UtilsSuite extends CelebornFunSuite {
   }
 
   test("MapperEnd class convert with pb") {
-    val mapperEnd = MapperEnd(1, 1, 1, 2, 1)
+    val mapperEnd = MapperEnd(1, 1, 1, 2, 1, Collections.emptyMap())
     val mapperEndTrans =
       Utils.fromTransportMessage(Utils.toTransportMessage(mapperEnd)).asInstanceOf[MapperEnd]
     assert(mapperEnd == mapperEndTrans)
diff --git a/docs/configuration/client.md b/docs/configuration/client.md
index c576e0c4d27..d9e1a700c77 100644
--- a/docs/configuration/client.md
+++ b/docs/configuration/client.md
@@ -19,6 +19,7 @@ license: |
 | Key | Default | isDynamic | Description | Since | Deprecated |
 | --- | ------- | --------- | ----------- | ----- | ---------- |
+| celeborn.client.adaptive.optimizeSkewedPartitionRead.enabled | false | false | If this is true, Celeborn will adaptively split skewed partitions instead of reading them by Spark map range. Please note that this feature requires the `Celeborn-Optimize-Skew-Partitions-spark3_3.patch`. | 0.6.0 | |
 | celeborn.client.application.heartbeatInterval | 10s | false | Interval for client to send heartbeat message to master. | 0.3.0 | celeborn.application.heartbeatInterval |
 | celeborn.client.application.unregister.enabled | true | false | When true, Celeborn client will inform celeborn master the application is already shutdown during client exit, this allows the cluster to release resources immediately, resulting in resource savings. | 0.3.2 | |
 | celeborn.client.application.uuidSuffix.enabled | false | false | Whether to add UUID suffix for application id for unique. When `true`, add UUID suffix for unique application id. Currently, this only applies to Spark and MR. | 0.6.0 | |
diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerCommitFilesSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerCommitFilesSuite.scala
index b9b084b7b46..39b54b5f638 100644
--- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerCommitFilesSuite.scala
+++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerCommitFilesSuite.scala
@@ -17,12 +17,13 @@
 package org.apache.celeborn.tests.client
 
+import java.nio.charset.StandardCharsets
 import java.util
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.celeborn.client.{LifecycleManager, WithShuffleClientSuite}
+import org.apache.celeborn.client.{LifecycleManager, ShuffleClientImpl, WithShuffleClientSuite}
 import org.apache.celeborn.client.LifecycleManager.ShuffleFailedWorkers
 import org.apache.celeborn.client.commit.CommitFilesParam
 import org.apache.celeborn.common.CelebornConf
@@ -228,6 +229,70 @@ class LifecycleManagerCommitFilesSuite extends WithShuffleClientSuite with MiniC
     lifecycleManager.stop()
   }
+
+  test("CELEBORN-1319: test commit files and check commit info") {
+    val shuffleId = nextShuffleId
+    val conf = celebornConf.clone
+    conf.set(CelebornConf.TEST_MOCK_COMMIT_FILES_FAILURE.key, "false")
+    val lifecycleManager: LifecycleManager = new LifecycleManager(APP, conf)
+    val shuffleClient = new ShuffleClientImpl(APP, conf, userIdentifier)
+    shuffleClient.setupLifecycleManagerRef(lifecycleManager.self)
+
+    val ids = new util.ArrayList[Integer](3)
+    0 until 3 foreach {
+      ids.add(_)
+    }
+    val res = lifecycleManager.requestMasterRequestSlotsWithRetry(shuffleId, ids)
+    assert(res.status == StatusCode.SUCCESS)
+    assert(res.workerResource.keySet().size() == 3)
+
+    lifecycleManager.setupEndpoints(
+      res.workerResource.keySet,
+      shuffleId,
+      new ShuffleFailedWorkers())
+
+    lifecycleManager.reserveSlotsWithRetry(
+      shuffleId,
+      new util.HashSet(res.workerResource.keySet()),
+      res.workerResource,
+      updateEpoch = false)
+
+    lifecycleManager.commitManager.registerShuffle(shuffleId, 1, false)
+
+    val buffer = "hello world".getBytes(StandardCharsets.UTF_8)
+
+    var bufferLength = -1
+
+    0 until 3 foreach { partitionId =>
+      bufferLength =
+        shuffleClient.pushData(shuffleId, 0, 0, partitionId, buffer, 0, buffer.length, 1, 3)
+      lifecycleManager.commitManager.finishMapperAttempt(shuffleId, 0, 0, 1, partitionId)
+    }
+
+    val commitHandler = lifecycleManager.commitManager.getCommitHandler(shuffleId)
+    val params = new ArrayBuffer[CommitFilesParam](res.workerResource.size())
+    res.workerResource.asScala.foreach { case (workerInfo, (primaryIds, replicaIds)) =>
+      params += CommitFilesParam(
+        workerInfo,
+        primaryIds.asScala.map(_.getUniqueId).toList.asJava,
+        replicaIds.asScala.map(_.getUniqueId).toList.asJava)
+    }
+
+    val shuffleCommittedInfo = lifecycleManager.commitManager.committedPartitionInfo.get(shuffleId)
+    commitHandler.doParallelCommitFiles(
+      shuffleId,
+      shuffleCommittedInfo,
+      params,
+      new ShuffleFailedWorkers)
+
+    shuffleCommittedInfo.committedReplicaStorageInfos.values().asScala.foreach { storageInfo =>
+      assert(storageInfo.fileSize == bufferLength)
+      // chunkOffsets contains 0 by default, and bufferFlushOffset
+      assert(storageInfo.chunkOffsets.size() == 2)
+    }
+
+    lifecycleManager.stop()
+  }
+
   override def afterAll(): Unit = {
     logInfo("all test complete , stop celeborn mini cluster")
     shutdownMiniCluster()
diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/PushFailedBatchSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/PushFailedBatchSuite.scala
new file mode 100644
index 00000000000..716ad4e3693
--- /dev/null
+++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/PushFailedBatchSuite.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.tests.spark
+
+import com.google.common.collect.Sets
+import org.apache.spark.{SparkConf, SparkContext, SparkContextHelper}
+import org.apache.spark.shuffle.celeborn.SparkShuffleManager
+import org.junit.Assert
+import org.scalatest.BeforeAndAfterEach
+import org.scalatest.funsuite.AnyFunSuite
+
+import org.apache.celeborn.client.ShuffleClient
+import org.apache.celeborn.common.CelebornConf
+import org.apache.celeborn.common.protocol.ShuffleMode
+import org.apache.celeborn.common.write.PushFailedBatch
+import org.apache.celeborn.service.deploy.worker.PushDataHandler
+
+class PushFailedBatchSuite extends AnyFunSuite
+  with SparkTestBase
+  with BeforeAndAfterEach {
+
+  override def beforeAll(): Unit = {
+    val workerConf = Map(
+      CelebornConf.TEST_WORKER_PUSH_REPLICA_DATA_TIMEOUT.key -> "true")
+
+    setupMiniClusterWithRandomPorts(workerConf = workerConf, workerNum = 4)
+  }
+
+  override def beforeEach(): Unit = {
+    ShuffleClient.reset()
+    PushDataHandler.pushPrimaryDataTimeoutTested.set(false)
+    PushDataHandler.pushReplicaDataTimeoutTested.set(false)
+    PushDataHandler.pushPrimaryMergeDataTimeoutTested.set(false)
+    PushDataHandler.pushReplicaMergeDataTimeoutTested.set(false)
+  }
+
+  override protected def afterEach() {
+    System.gc()
+  }
+
+  test("CELEBORN-1319: check failed batch info by making push timeout") {
+    val sparkConf = new SparkConf()
+      .set(s"spark.${CelebornConf.TEST_CLIENT_RETRY_REVIVE.key}", "false")
+      .set(s"spark.${CelebornConf.CLIENT_PUSH_REPLICATE_ENABLED.key}", "true")
+      .set(s"spark.${CelebornConf.CLIENT_PUSH_DATA_TIMEOUT.key}", "3s")
+      .set(
+        s"spark.${CelebornConf.CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ_ENABLED.key}",
+        "true")
+      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+      .set("spark.shuffle.manager", "org.apache.spark.shuffle.celeborn.SparkShuffleManager")
+      .setAppName("celeborn-1319")
+      .setMaster("local[2]")
+    updateSparkConf(sparkConf, ShuffleMode.HASH)
+    val sc = new SparkContext(sparkConf)
+
+    sc.parallelize(1 to 1, 1).repartition(1).map(i => i + 1).collect()
+
+    val manager = SparkContextHelper.env
+      .shuffleManager
+      .asInstanceOf[SparkShuffleManager]
+      .getLifecycleManager
+
+    // only one batch failed due to push timeout, so shuffle id will be 0,
+    // and PartitionLocation uniqueId will be 0-0
+    val pushFailedBatch =
+      manager.commitManager.getCommitHandler(0).getShuffleFailedBatches()
+    assert(!pushFailedBatch.isEmpty)
+    Assert.assertEquals(
+      pushFailedBatch.get(0).get("0-0"),
+      Sets.newHashSet(new PushFailedBatch(0, 0, 1)))
+
+    sc.stop()
+  }
+
+}
diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SkewJoinSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SkewJoinSuite.scala
index 6cfbd227a94..f1a5d5e7769 100644
--- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SkewJoinSuite.scala
+++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SkewJoinSuite.scala
@@ -48,63 +48,69 @@ class SkewJoinSuite extends AnyFunSuite
   }
 
   CompressionCodec.values.foreach { codec =>
-    test(s"celeborn spark integration test - skew join - $codec") {
-      val sparkConf = new SparkConf().setAppName("celeborn-demo")
-        .setMaster("local[2]")
-        .set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")
-        .set("spark.sql.adaptive.skewJoin.enabled", "true")
-        .set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "16MB")
-        .set("spark.sql.adaptive.advisoryPartitionSizeInBytes", "12MB")
-        .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
-        .set("spark.sql.adaptive.autoBroadcastJoinThreshold", "-1")
-        .set(SQLConf.PARQUET_COMPRESSION.key, "gzip")
-        .set(s"spark.${CelebornConf.SHUFFLE_COMPRESSION_CODEC.key}", codec.name)
-        .set(s"spark.${CelebornConf.SHUFFLE_RANGE_READ_FILTER_ENABLED.key}", "true")
+    Seq(false, true).foreach { enabled =>
+      test(
+        s"celeborn spark integration test - skew join - with $codec - with client skew $enabled") {
+        val sparkConf = new SparkConf().setAppName("celeborn-demo")
+          .setMaster("local[2]")
+          .set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")
+          .set("spark.sql.adaptive.skewJoin.enabled", "true")
+          .set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "16MB")
+          .set("spark.sql.adaptive.advisoryPartitionSizeInBytes", "12MB")
+          .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+          .set("spark.sql.adaptive.autoBroadcastJoinThreshold", "-1")
+          .set(SQLConf.PARQUET_COMPRESSION.key, "gzip")
+          .set(s"spark.${CelebornConf.SHUFFLE_COMPRESSION_CODEC.key}", codec.name)
+          .set(s"spark.${CelebornConf.SHUFFLE_RANGE_READ_FILTER_ENABLED.key}", "true")
+          .set(
+            s"spark.${CelebornConf.CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ_ENABLED.key}",
+            s"$enabled")
 
-      enableCeleborn(sparkConf)
+        enableCeleborn(sparkConf)
 
-      val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate()
-      if (sparkSession.version.startsWith("3")) {
-        import sparkSession.implicits._
-        val df = sparkSession.sparkContext.parallelize(1 to 120000, 8)
-          .map(i => {
-            val random = new Random()
-            val oriKey = random.nextInt(64)
-            val key = if (oriKey < 32) 1 else oriKey
-            val fas = random.nextInt(1200000)
-            val fa = Range(fas, fas + 100).mkString(",")
-            val fbs = random.nextInt(1200000)
-            val fb = Range(fbs, fbs + 100).mkString(",")
-            val fcs = random.nextInt(1200000)
-            val fc = Range(fcs, fcs + 100).mkString(",")
-            val fds = random.nextInt(1200000)
-            val fd = Range(fds, fds + 100).mkString(",")
+        val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate()
+        if (sparkSession.version.startsWith("3")) {
+          import sparkSession.implicits._
+          val df = sparkSession.sparkContext.parallelize(1 to 120000, 8)
+            .map(i => {
+              val random = new Random()
+              val oriKey = random.nextInt(64)
+              val key = if (oriKey < 32) 1 else oriKey
+              val fas = random.nextInt(1200000)
+              val fa = Range(fas, fas + 100).mkString(",")
+              val fbs = random.nextInt(1200000)
+              val fb = Range(fbs, fbs + 100).mkString(",")
+              val fcs = random.nextInt(1200000)
+              val fc = Range(fcs, fcs + 100).mkString(",")
+              val fds = random.nextInt(1200000)
+              val fd = Range(fds, fds + 100).mkString(",")
 
-            (key, fa, fb, fc, fd)
-          })
-          .toDF("fa", "f1", "f2", "f3", "f4")
-        df.createOrReplaceTempView("view1")
-        val df2 = sparkSession.sparkContext.parallelize(1 to 8, 8)
-          .map(i => {
-            val random = new Random()
-            val oriKey = random.nextInt(64)
-            val key = if (oriKey < 32) 1 else oriKey
-            val fas = random.nextInt(1200000)
-            val fa = Range(fas, fas + 100).mkString(",")
-            val fbs = random.nextInt(1200000)
-            val fb = Range(fbs, fbs + 100).mkString(",")
-            val fcs = random.nextInt(1200000)
-            val fc = Range(fcs, fcs + 100).mkString(",")
-            val fds = random.nextInt(1200000)
-            val fd = Range(fds, fds + 100).mkString(",")
-            (key, fa, fb, fc, fd)
-          })
-          .toDF("fb", "f6", "f7", "f8", "f9")
-        df2.createOrReplaceTempView("view2")
-        sparkSession.sql("drop table if exists fres")
-        sparkSession.sql("create table fres using parquet as select * from view1 a inner join view2 b on a.fa=b.fb where a.fa=1 ")
-        sparkSession.sql("drop table fres")
-        sparkSession.stop()
+              (key, fa, fb, fc, fd)
+            })
+            .toDF("fa", "f1", "f2", "f3", "f4")
+          df.createOrReplaceTempView("view1")
+          val df2 = sparkSession.sparkContext.parallelize(1 to 8, 8)
+            .map(i => {
+              val random = new Random()
+              val oriKey = random.nextInt(64)
+              val key = if (oriKey < 32) 1 else oriKey
+              val fas = random.nextInt(1200000)
+              val fa = Range(fas, fas + 100).mkString(",")
+              val fbs = random.nextInt(1200000)
+              val fb = Range(fbs, fbs + 100).mkString(",")
+              val fcs = random.nextInt(1200000)
+              val fc = Range(fcs, fcs + 100).mkString(",")
+              val fds = random.nextInt(1200000)
+              val fd = Range(fds, fds + 100).mkString(",")
+              (key, fa, fb, fc, fd)
+            })
+            .toDF("fb", "f6", "f7", "f8", "f9")
+          df2.createOrReplaceTempView("view2")
+          sparkSession.sql("drop table if exists fres")
+          sparkSession.sql("create table fres using parquet as select * from view1 a inner join view2 b on a.fa=b.fb where a.fa=1 ")
+          sparkSession.sql("drop table fres")
+          sparkSession.stop()
+        }
       }
     }
   }
diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala
index 65285f64729..b2e7c4e844c 100644
--- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala
+++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala
@@ -32,7 +32,7 @@ import org.roaringbitmap.RoaringBitmap
 import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.identity.UserIdentifier
 import org.apache.celeborn.common.internal.Logging
-import org.apache.celeborn.common.meta.{WorkerInfo, WorkerPartitionLocationInfo}
+import org.apache.celeborn.common.meta.{ReduceFileMeta, WorkerInfo, WorkerPartitionLocationInfo}
 import org.apache.celeborn.common.metrics.MetricsSystem
 import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionSplitMode, PartitionType, StorageInfo}
 import org.apache.celeborn.common.protocol.message.ControlMessages._
@@ -337,7 +337,21 @@ private[deploy] class Controller(
           // Only HDFS can be null, means that this partition location is deleted.
           logDebug(s"Location $uniqueId is deleted.")
         } else {
-          committedStorageInfos.put(uniqueId, fileWriter.getStorageInfo)
+          val storageInfo = fileWriter.getStorageInfo
+          val fileInfo =
+            if (null != fileWriter.getDiskFileInfo) {
+              fileWriter.getDiskFileInfo
+            } else {
+              fileWriter.getMemoryFileInfo
+            }
+          val fileMeta = fileInfo.getFileMeta
+          fileMeta match {
+            case meta: ReduceFileMeta =>
+              storageInfo.setFileSize(bytes)
+              storageInfo.setChunkOffsets(meta.getChunkOffsets)
+            case _ =>
+          }
+          committedStorageInfos.put(uniqueId, storageInfo)
           if (fileWriter.getMapIdBitMap != null) {
             committedMapIdBitMap.put(uniqueId, fileWriter.getMapIdBitMap)
           }
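
Worked example of the metadata committed above, under the assumption of a single buffer flush: a reduce partition written with one flush of "hello world" yields fileSize == 11 and chunkOffsets == [0, 11] (a leading 0 plus one boundary per flush), which is what the LifecycleManagerCommitFilesSuite test earlier in this change asserts:

    import java.nio.charset.StandardCharsets

    val bytes = "hello world".getBytes(StandardCharsets.UTF_8)
    val expectedFileSize = bytes.length.toLong               // 11
    val expectedChunkOffsets = Seq(0L, bytes.length.toLong)  // [0, 11]
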
logDebug(s"Location $uniqueId is deleted.") } else { - committedStorageInfos.put(uniqueId, fileWriter.getStorageInfo) + val storageInfo = fileWriter.getStorageInfo + val fileInfo = + if (null != fileWriter.getDiskFileInfo) { + fileWriter.getDiskFileInfo + } else { + fileWriter.getMemoryFileInfo + } + val fileMeta = fileInfo.getFileMeta + fileMeta match { + case meta: ReduceFileMeta => + storageInfo.setFileSize(bytes) + storageInfo.setChunkOffsets(meta.getChunkOffsets) + case _ => + } + committedStorageInfos.put(uniqueId, storageInfo) if (fileWriter.getMapIdBitMap != null) { committedMapIdBitMap.put(uniqueId, fileWriter.getMapIdBitMap) } diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala index ff32b940173..3c99639eb84 100644 --- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala +++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala @@ -249,8 +249,8 @@ class FetchHandler( // 1. when the current request is a non-range openStream, but the original unsorted file // has been deleted by another range's openStream request. // 2. when the current request is a range openStream request. - if ((endIndex != Int.MaxValue) || (endIndex == Int.MaxValue - && !fileInfo.addStream(streamId))) { + if ((endIndex != Int.MaxValue && endIndex != -1 && endIndex >= startIndex) || (endIndex == Int.MaxValue && !fileInfo.addStream( + streamId))) { fileInfo = partitionsSorter.getSortedFileInfo( shuffleKey, fileName, diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala index 0ff646b8f8b..e81593e1d99 100644 --- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala @@ -116,6 +116,8 @@ trait ReadWriteTestBase extends AnyFunSuite null, null, null, + null, + null, metricsCallback) val outputStream = new ByteArrayOutputStream()