Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize skew partition #3105

Closed
wants to merge 47 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
2570b50
[CELEBORN-1319] Optimize skew partition logic for Reduce Mode to avoi…
wangshengjie123 Dec 13, 2024
43416c3
fix unit test
wangshengjie123 Mar 11, 2024
355f6b3
fix ut and refactor code
wangshengjie123 Mar 16, 2024
2b3ad58
refactor code and imports
wangshengjie123 Mar 16, 2024
1a0462d
fix NPE and remove unused code and fix scala 2.11 compile error
wangshengjie123 Mar 24, 2024
5d1bcb6
format code
wangshengjie123 Mar 24, 2024
2c4b03f
add spark 3.3 patch
wangshengjie123 Mar 24, 2024
ef2dacb
remove unused log and refactor code
wangshengjie123 Mar 26, 2024
b4dae1d
add unit tests and refactor code
wangshengjie123 Mar 28, 2024
52fdb96
split the skewed partition based on the chunk range
cfmcgrady Apr 5, 2024
deb2913
remove CelebornInputStreamSuiteJ
cfmcgrady Apr 7, 2024
f7abfc2
refactor code for batch open stream
wangshengjie123 Apr 8, 2024
2e3be7d
Revert "refactor code for batch open stream"
cfmcgrady Apr 8, 2024
cf8dcb2
`celeborn.client.dataPushFailure.tracking.enabled` -> `celeborn.clien…
cfmcgrady Apr 8, 2024
6ce0e11
refactor code for batch open stream
wangshengjie123 Apr 9, 2024
9b7cc7f
add ut and refactor split logic
wangshengjie123 Apr 12, 2024
8ac7172
refactor code according to review suggestions
wangshengjie123 Apr 12, 2024
c169aa2
add sbt dependency
wangshengjie123 Apr 12, 2024
4b9fab6
update spark patch
cfmcgrady Apr 15, 2024
2e3a1c1
address comment
cfmcgrady Apr 15, 2024
1794fd4
fix
cfmcgrady Apr 15, 2024
e7d8d73
add ut
wangshengjie123 Apr 16, 2024
14d5f10
update
cfmcgrady Apr 18, 2024
85f9445
refactor
cfmcgrady Apr 24, 2024
697b468
add CelebornPartitionUtil.java
cfmcgrady Apr 24, 2024
2b6349d
fix ut for jdk 8
cfmcgrady Apr 25, 2024
8155619
rebase main and fix npe
wangshengjie123 Jun 2, 2024
aa30ff6
fix ut
wangshengjie123 Jun 3, 2024
6b4bb40
fix npe when memory storage enabled
wangshengjie123 Jun 3, 2024
59bef45
fix code style check error
wangshengjie123 Jun 3, 2024
031fdb3
fix comile error
wangshengjie123 Oct 14, 2024
3966f49
add spark patch
wangshengjie123 Nov 26, 2024
8c17522
fix ut compile error and add license to spark patch
wangshengjie123 Nov 26, 2024
21f14d9
update spark patch to abort stage when rerun skew join stage
wangshengjie123 Nov 27, 2024
e9efddc
address comment, avoid reading replicate peer when read skew partitio…
wangshengjie123 Dec 22, 2024
23ca472
update ClientUtils
wangshengjie123 Dec 23, 2024
94cb56d
address review comment
wangshengjie123 Jan 19, 2025
60422d1
remove duplicate uts
wangshengjie123 Jan 19, 2025
4130227
add spark patch
wangshengjie123 Jan 23, 2025
e9dcfd6
Merge branch 'main' into optimize-skew-partition
wangshengjie123 Jan 24, 2025
1c58362
fix codestyle
wangshengjie123 Jan 24, 2025
cbceadf
fix uts
wangshengjie123 Jan 24, 2025
21a0635
add method javadoc comment
wangshengjie123 Feb 13, 2025
4307f0c
Merge branch 'main' into optimize-skew-partition
wangshengjie123 Feb 13, 2025
a9a4af1
Update common/src/main/scala/org/apache/celeborn/common/CelebornConf.…
RexXiong Feb 17, 2025
cbad6f8
Update docs/configuration/client.md
RexXiong Feb 17, 2025
400c9be
Merge branch 'main' into optimize-skew-partition
turboFei Feb 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
315 changes: 315 additions & 0 deletions assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_2.patch

Large diffs are not rendered by default.

312 changes: 312 additions & 0 deletions assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch

Large diffs are not rendered by default.

312 changes: 312 additions & 0 deletions assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_4.patch

Large diffs are not rendered by default.

312 changes: 312 additions & 0 deletions assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_5.patch

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.shuffle.celeborn;

import java.util.*;

import org.apache.commons.lang3.tuple.Pair;

import org.apache.celeborn.common.protocol.PartitionLocation;

public class CelebornPartitionUtil {
/**
* The general idea is to divide each skew partition into smaller partitions:
*
* <p>- Spark driver will calculate the number of sub-partitions: {@code subPartitionSize =
* skewPartitionTotalSize / subPartitionTargetSize}
*
* <p>- In Celeborn, we divide the skew partition into {@code subPartitionSize} small partitions
* by PartitionLocation chunk offsets. This allows them to run in parallel Spark tasks.
*
* <p>For example, one skewed partition has 2 PartitionLocation:
*
* <ul>
* <li>PartitionLocation 0 with chunk offset [0L, 100L, 200L, 300L, 500L, 1000L]
* <li>PartitionLocation 1 with chunk offset [0L, 200L, 500L, 800L, 900L, 1000L]
* </ul>
*
* If we want to divide it into 3 sub-partitions (each sub-partition target size is 2000/3), the
* result will be:
*
* <ul>
* <li>sub-partition 0: uniqueId0 -> (0, 3)
* <li>sub-partition 1: uniqueId0 -> (4, 4), uniqueId1 -> (0, 0)
* <li>sub-partition 2: uniqueId1 -> (1, 4)
* </ul>
*
* Note: (0, 3) means chunks with chunkIndex 0-1-2-3, four chunks.
*
* @param locations PartitionLocation information belonging to the reduce partition
* @param subPartitionSize the number of sub-partitions separated from the reduce partition
* @param subPartitionIndex current sub-partition index
* @return a map of partitionUniqueId to chunkRange pairs for one subtask of skew partitions
*/
public static Map<String, Pair<Integer, Integer>> splitSkewedPartitionLocations(
ArrayList<PartitionLocation> locations, int subPartitionSize, int subPartitionIndex) {
locations.sort(Comparator.comparing((PartitionLocation p) -> p.getUniqueId()));
long totalPartitionSize =
locations.stream().mapToLong((PartitionLocation p) -> p.getStorageInfo().fileSize).sum();
long step = totalPartitionSize / subPartitionSize;
long startOffset = step * subPartitionIndex;
long endOffset =
subPartitionIndex < subPartitionSize - 1
? step * (subPartitionIndex + 1)
: totalPartitionSize + 1; // last subPartition should include all remaining data

long partitionLocationOffset = 0;
Map<String, Pair<Integer, Integer>> chunkRange = new HashMap<>();
for (PartitionLocation p : locations) {
int left = -1;
int right = -1;
Iterator<Long> chunkOffsets = p.getStorageInfo().getChunkOffsets().iterator();
// Start from index 1 since the first chunk offset is always 0.
chunkOffsets.next();
int j = 1;
while (chunkOffsets.hasNext()) {
long currentOffset = partitionLocationOffset + chunkOffsets.next();
if (currentOffset > startOffset && left < 0) {
left = j - 1;
}
if (currentOffset <= endOffset) {
right = j - 1;
}
if (left >= 0 && right >= 0) {
chunkRange.put(p.getUniqueId(), Pair.of(left, right));
}
j++;
}
partitionLocationOffset += p.getStorageInfo().getFileSize();
}
return chunkRange;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@
package org.apache.spark.shuffle.celeborn

import java.io.IOException
import java.nio.file.Files
import java.util
import java.util.{ArrayList => JArrayList, HashMap => JHashMap, Map => JMap, Set => JSet}
import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor, TimeoutException, TimeUnit}
import java.util.concurrent.atomic.AtomicReference

import scala.collection.JavaConverters._

import com.google.common.annotations.VisibleForTesting
import org.apache.commons.lang3.tuple.Pair
import org.apache.spark.{Aggregator, InterruptibleIterator, ShuffleDependency, TaskContext}
import org.apache.spark.celeborn.ExceptionMakerHelper
import org.apache.spark.internal.Logging
Expand All @@ -35,7 +35,7 @@ import org.apache.spark.shuffle.celeborn.CelebornShuffleReader.streamCreatorPool
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.ExternalSorter

import org.apache.celeborn.client.{DummyShuffleClient, ShuffleClient}
import org.apache.celeborn.client.{ClientUtils, ShuffleClient}
import org.apache.celeborn.client.ShuffleClientImpl.ReduceFileGroups
import org.apache.celeborn.client.read.{CelebornInputStream, MetricsCallback}
import org.apache.celeborn.common.CelebornConf
Expand Down Expand Up @@ -122,15 +122,41 @@ class CelebornShuffleReader[K, C](
}

// host-port -> (TransportClient, PartitionLocation Array, PbOpenStreamList)
val workerRequestMap = new util.HashMap[
val workerRequestMap = new JHashMap[
String,
(TransportClient, util.ArrayList[PartitionLocation], PbOpenStreamList.Builder)]()
(TransportClient, JArrayList[PartitionLocation], PbOpenStreamList.Builder)]()
// partitionId -> (partition uniqueId -> chunkRange pair)
val partitionId2ChunkRange = new JHashMap[Int, JMap[String, Pair[Integer, Integer]]]()

val partitionId2PartitionLocations = new JHashMap[Int, JSet[PartitionLocation]]()

var partCnt = 0

// if startMapIndex > endMapIndex, means partition is skew partition and read by Celeborn implementation.
// locations will split to sub-partitions with startMapIndex size.
val splitSkewPartitionWithoutMapRange =
ClientUtils.readSkewPartitionWithoutMapRange(conf, startMapIndex, endMapIndex)

(startPartition until endPartition).foreach { partitionId =>
if (fileGroups.partitionGroups.containsKey(partitionId)) {
fileGroups.partitionGroups.get(partitionId).asScala.foreach { location =>
var locations = fileGroups.partitionGroups.get(partitionId)
if (splitSkewPartitionWithoutMapRange) {
val partitionLocation2ChunkRange = CelebornPartitionUtil.splitSkewedPartitionLocations(
new JArrayList(locations),
startMapIndex,
endMapIndex)
partitionId2ChunkRange.put(partitionId, partitionLocation2ChunkRange)
// filter locations avoid OPEN_STREAM when split skew partition without map range
val filterLocations = locations.asScala
.filter { location =>
null != partitionLocation2ChunkRange &&
partitionLocation2ChunkRange.containsKey(location.getUniqueId)
}
locations = filterLocations.asJava
partitionId2PartitionLocations.put(partitionId, locations)
}

locations.asScala.foreach { location =>
partCnt += 1
val hostPort = location.hostAndFetchPort
if (!workerRequestMap.containsKey(hostPort)) {
Expand All @@ -142,7 +168,7 @@ class CelebornShuffleReader[K, C](
pbOpenStreamList.setShuffleKey(shuffleKey)
workerRequestMap.put(
hostPort,
(client, new util.ArrayList[PartitionLocation], pbOpenStreamList))
(client, new JArrayList[PartitionLocation], pbOpenStreamList))
} catch {
case ex: Exception =>
shuffleClient.excludeFailedFetchLocation(location.hostAndFetchPort, ex)
Expand Down Expand Up @@ -203,13 +229,22 @@ class CelebornShuffleReader[K, C](

def createInputStream(partitionId: Int): Unit = {
val locations =
if (fileGroups.partitionGroups.containsKey(partitionId)) {
new util.ArrayList(fileGroups.partitionGroups.get(partitionId))
} else new util.ArrayList[PartitionLocation]()
if (splitSkewPartitionWithoutMapRange) {
partitionId2PartitionLocations.get(partitionId)
} else {
fileGroups.partitionGroups.get(partitionId)
}

val locationList =
if (null == locations) {
new JArrayList[PartitionLocation]()
} else {
new JArrayList[PartitionLocation](locations)
}
val streamHandlers =
if (locations != null) {
val streamHandlerArr = new util.ArrayList[PbStreamHandler](locations.size())
locations.asScala.foreach { loc =>
val streamHandlerArr = new JArrayList[PbStreamHandler](locationList.size)
locationList.asScala.foreach { loc =>
streamHandlerArr.add(locationStreamHandlerMap.get(loc))
}
streamHandlerArr
Expand All @@ -226,8 +261,10 @@ class CelebornShuffleReader[K, C](
endMapIndex,
if (throwsFetchFailure) ExceptionMakerHelper.SHUFFLE_FETCH_FAILURE_EXCEPTION_MAKER
else null,
locations,
locationList,
streamHandlers,
fileGroups.pushFailedBatches,
partitionId2ChunkRange.get(partitionId),
fileGroups.mapAttempts,
metricsCallback)
streams.put(partitionId, inputStream)
Expand Down Expand Up @@ -414,7 +451,6 @@ class CelebornShuffleReader[K, C](
def newSerializerInstance(dep: ShuffleDependency[K, _, C]): SerializerInstance = {
dep.serializer.newInstance()
}

}

object CelebornShuffleReader {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.shuffle.celeborn;

import java.util.*;

import org.apache.commons.lang3.tuple.Pair;
import org.junit.Assert;
import org.junit.Test;

import org.apache.celeborn.common.protocol.PartitionLocation;
import org.apache.celeborn.common.protocol.StorageInfo;

public class CelebornPartitionUtilSuiteJ {
@Test
public void testSkewPartitionSplit() {

ArrayList<PartitionLocation> locations = new ArrayList<>();
for (int i = 0; i < 13; i++) {
locations.add(genPartitionLocation(i, new Long[] {0L, 100L, 200L, 300L, 500L, 1000L}));
}
locations.add(genPartitionLocation(91, new Long[] {0L, 1L}));

int subPartitionSize = 3;

Map<String, Pair<Integer, Integer>> result1 =
CelebornPartitionUtil.splitSkewedPartitionLocations(locations, subPartitionSize, 0);
Map<String, Pair<Integer, Integer>> expectResult1 =
genRanges(
new Object[][] {
{"0-0", 0, 4},
{"0-1", 0, 4},
{"0-10", 0, 4},
{"0-11", 0, 4},
{"0-12", 0, 2}
});
Assert.assertEquals(expectResult1, result1);

Map<String, Pair<Integer, Integer>> result2 =
CelebornPartitionUtil.splitSkewedPartitionLocations(locations, subPartitionSize, 1);
Map<String, Pair<Integer, Integer>> expectResult2 =
genRanges(
new Object[][] {
{"0-12", 3, 4},
{"0-2", 0, 4},
{"0-3", 0, 4},
{"0-4", 0, 4},
{"0-5", 0, 3}
});
Assert.assertEquals(expectResult2, result2);

Map<String, Pair<Integer, Integer>> result3 =
CelebornPartitionUtil.splitSkewedPartitionLocations(locations, subPartitionSize, 2);
Map<String, Pair<Integer, Integer>> expectResult3 =
genRanges(
new Object[][] {
{"0-5", 4, 4},
{"0-6", 0, 4},
{"0-7", 0, 4},
{"0-8", 0, 4},
{"0-9", 0, 4},
{"0-91", 0, 0}
});
Assert.assertEquals(expectResult3, result3);
}

@Test
public void testBoundary() {
ArrayList<PartitionLocation> locations = new ArrayList<>();
locations.add(genPartitionLocation(0, new Long[] {0L, 100L, 200L, 300L, 400L, 500L}));

for (int i = 0; i < 5; i++) {
Map<String, Pair<Integer, Integer>> result =
CelebornPartitionUtil.splitSkewedPartitionLocations(locations, 5, i);
Map<String, Pair<Integer, Integer>> expectResult = genRanges(new Object[][] {{"0-0", i, i}});
Assert.assertEquals(expectResult, result);
}
}

@Test
public void testSplitStable() {
ArrayList<PartitionLocation> locations = new ArrayList<>();
for (int i = 0; i < 13; i++) {
locations.add(genPartitionLocation(i, new Long[] {0L, 100L, 200L, 300L, 500L, 1000L}));
}
locations.add(genPartitionLocation(91, new Long[] {0L, 1L}));

Collections.shuffle(locations);

Map<String, Pair<Integer, Integer>> result =
CelebornPartitionUtil.splitSkewedPartitionLocations(locations, 3, 0);
Map<String, Pair<Integer, Integer>> expectResult =
genRanges(
new Object[][] {
{"0-0", 0, 4},
{"0-1", 0, 4},
{"0-10", 0, 4},
{"0-11", 0, 4},
{"0-12", 0, 2}
});
Assert.assertEquals(expectResult, result);
}

private ArrayList<PartitionLocation> genPartitionLocations(Map<Integer, Long[]> epochToOffsets) {
ArrayList<PartitionLocation> locations = new ArrayList<>();
epochToOffsets.forEach(
(epoch, offsets) -> {
PartitionLocation location =
new PartitionLocation(
0, epoch, "localhost", 0, 0, 0, 0, PartitionLocation.Mode.PRIMARY);
StorageInfo storageInfo =
new StorageInfo(
StorageInfo.Type.HDD,
"mountPoint",
false,
"filePath",
StorageInfo.LOCAL_DISK_MASK,
1,
Arrays.asList(offsets));
location.setStorageInfo(storageInfo);
locations.add(location);
});
return locations;
}

private PartitionLocation genPartitionLocation(int epoch, Long[] offsets) {
PartitionLocation location =
new PartitionLocation(0, epoch, "localhost", 0, 0, 0, 0, PartitionLocation.Mode.PRIMARY);
StorageInfo storageInfo =
new StorageInfo(
StorageInfo.Type.HDD,
"mountPoint",
false,
"filePath",
StorageInfo.LOCAL_DISK_MASK,
offsets[offsets.length - 1],
Arrays.asList(offsets));
location.setStorageInfo(storageInfo);
return location;
}

private Map<String, Pair<Integer, Integer>> genRanges(Object[][] inputs) {
Map<String, Pair<Integer, Integer>> ranges = new HashMap<>();
for (Object[] idToChunkRange : inputs) {
String uid = (String) idToChunkRange[0];
Pair<Integer, Integer> range = Pair.of((int) idToChunkRange[1], (int) idToChunkRange[2]);
ranges.put(uid, range);
}
return ranges;
}
}
Loading
Loading