[CELEBORN-1838] Interrupted Spark task should not report fetch failure
Backport #3070 to main branch.
## What changes were proposed in this pull request?
Do not trigger a fetch failure if a Spark task attempt is interrupted (with speculation enabled). Do not trigger a fetch failure if the getReducerFileGroup RPC times out. The original PR targeted the celeborn-0.5 branch.
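
For context, here is a minimal, self-contained sketch of the guard (the object and method names here are illustrative, not part of the patch): a failure whose cause is an interrupt or an RPC timeout should not be reported as a fetch failure.

```scala
import java.util.concurrent.TimeoutException

// Illustrative helper: decide whether an updateFileGroup failure should be
// reported to the driver as a fetch failure.
object FetchFailureGuard {
  def shouldReportFetchFailure(failure: Throwable): Boolean =
    failure.getCause match {
      // The task attempt was interrupted (e.g. a killed speculative attempt).
      case _: InterruptedException => false
      // The getReducerFileGroup RPC timed out.
      case _: TimeoutException => false
      // Anything else is still treated as a genuine fetch failure.
      case _ => true
    }
}
```

The actual change in `CelebornShuffleReader` (see the diff below) rethrows the exception in these two cases instead of calling `handleFetchExceptions`.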

## Why are the changes needed?
Avoid unnecessary fetch failures and stage re-runs.

## Does this PR introduce any user-facing change?
NO.

## How was this patch tested?
1. GA.
2. Manually tested on a cluster with Spark speculative tasks.

Here is the test case:
```scala
sc.parallelize(1 to 100, 100)
  .flatMap(i => (1 to 150000).iterator.map(num => num))
  .groupBy(i => i, 100)
  .map(i => {
    if (i._1 < 5) {
      Thread.sleep(15000)
    }
    i
  })
  .repartition(400)
  .count
```

Screenshots from the manual test:

- [Screenshot 2025-01-18 16:16:16](https://github.com/user-attachments/assets/adf64857-5773-4081-a7d0-fa3439e751eb)
- [Screenshot 2025-01-18 16:16:22](https://github.com/user-attachments/assets/ac9bf172-1ab4-4669-a930-872d009f2530)
- [Screenshot 2025-01-18 16:19:15](https://github.com/user-attachments/assets/6a8ff3e1-c1fb-4ef2-84d8-b1fc6eb56fa6)
- [Screenshot 2025-01-18 16:17:27](https://github.com/user-attachments/assets/f9de3841-f7d4-4445-99a3-873235d4abd0)

Closes #3070 from FMX/branch-0.5-b1838.

Authored-by: mingji <fengmingxiao.fmx@alibaba-inc.com>

Closes #3080 from turboFei/b1838.

Lead-authored-by: mingji <[email protected]>
Co-authored-by: Wang, Fei <[email protected]>
Signed-off-by: Wang, Fei <[email protected]>
FMX and turboFei committed Jan 23, 2025
1 parent a77a64b commit 75b697d
Showing 9 changed files with 293 additions and 28 deletions.
@@ -24,7 +24,7 @@
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

import scala.Tuple2;
import scala.Tuple3;
import scala.reflect.ClassTag$;

import com.google.common.annotations.VisibleForTesting;
@@ -265,9 +265,9 @@ public ReduceFileGroups updateFileGroup(
int shuffleId, int partitionId, boolean isSegmentGranularityVisible)
throws CelebornIOException {
ReduceFileGroups reduceFileGroups =
reduceFileGroupsMap.computeIfAbsent(
shuffleId, (id) -> Tuple2.apply(new ReduceFileGroups(), null))
._1;
reduceFileGroupsMap
.computeIfAbsent(shuffleId, (id) -> Tuple3.apply(new ReduceFileGroups(), null, null))
._1();
if (reduceFileGroups.partitionIds != null
&& reduceFileGroups.partitionIds.contains(partitionId)) {
logger.debug(
@@ -281,12 +281,12 @@ public ReduceFileGroups updateFileGroup(
Utils.makeReducerKey(shuffleId, partitionId));
} else {
// refresh file groups
Tuple2<ReduceFileGroups, String> fileGroups =
Tuple3<ReduceFileGroups, String, Exception> fileGroups =
loadFileGroupInternal(shuffleId, isSegmentGranularityVisible);
ReduceFileGroups newGroups = fileGroups._1;
ReduceFileGroups newGroups = fileGroups._1();
if (newGroups == null) {
throw new CelebornIOException(
loadFileGroupException(shuffleId, partitionId, fileGroups._2));
loadFileGroupException(shuffleId, partitionId, fileGroups._2()));
} else if (!newGroups.partitionIds.contains(partitionId)) {
throw new CelebornIOException(
String.format(
5 changes: 5 additions & 0 deletions client-spark/spark-3-4/pom.xml
@@ -91,5 +91,10 @@
<artifactId>mockito-core</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-inline</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
</project>
@@ -18,12 +18,14 @@
package org.apache.spark.shuffle.celeborn

import java.io.IOException
import java.nio.file.Files
import java.util
import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor, TimeUnit}
import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor, TimeoutException, TimeUnit}
import java.util.concurrent.atomic.AtomicReference

import scala.collection.JavaConverters._

import com.google.common.annotations.VisibleForTesting
import org.apache.spark.{Aggregator, InterruptibleIterator, ShuffleDependency, TaskContext}
import org.apache.spark.celeborn.ExceptionMakerHelper
import org.apache.spark.internal.Logging
@@ -33,14 +35,14 @@ import org.apache.spark.shuffle.celeborn.CelebornShuffleReader.streamCreatorPool
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.ExternalSorter

import org.apache.celeborn.client.ShuffleClient
import org.apache.celeborn.client.{DummyShuffleClient, ShuffleClient}
import org.apache.celeborn.client.ShuffleClientImpl.ReduceFileGroups
import org.apache.celeborn.client.read.{CelebornInputStream, MetricsCallback}
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.exception.{CelebornIOException, PartitionUnRetryAbleException}
import org.apache.celeborn.common.network.client.TransportClient
import org.apache.celeborn.common.network.protocol.TransportMessage
import org.apache.celeborn.common.protocol.{MessageType, PartitionLocation, PbOpenStreamList, PbOpenStreamListResponse, PbStreamHandler}
import org.apache.celeborn.common.protocol._
import org.apache.celeborn.common.protocol.message.StatusCode
import org.apache.celeborn.common.util.{JavaUtils, ThreadUtils, Utils}

@@ -57,7 +59,9 @@ class CelebornShuffleReader[K, C](
extends ShuffleReader[K, C] with Logging {

private val dep = handle.dependency
private val shuffleClient = ShuffleClient.get(

@VisibleForTesting
val shuffleClient = ShuffleClient.get(
handle.appUniqueId,
handle.lifecycleManagerHost,
handle.lifecycleManagerPort,
@@ -111,7 +115,9 @@ class CelebornShuffleReader[K, C](
fileGroups = shuffleClient.updateFileGroup(shuffleId, startPartition)
} catch {
case ce @ (_: CelebornIOException | _: PartitionUnRetryAbleException) =>
handleFetchExceptions(handle.shuffleId, shuffleId, 0, ce)
// If the task was interrupted, do not report a fetch failure.
// If updateFileGroup timed out, do not report a fetch failure.
checkAndReportFetchFailureForUpdateFileGroupFailure(shuffleId, ce)
case e: Throwable => throw e
}

@@ -370,7 +376,22 @@ class CelebornShuffleReader[K, C](
}
}

private def handleFetchExceptions(
@VisibleForTesting
def checkAndReportFetchFailureForUpdateFileGroupFailure(
celebornShuffleId: Int,
ce: Throwable): Unit = {
if (ce.getCause != null &&
(ce.getCause.isInstanceOf[InterruptedException] || ce.getCause.isInstanceOf[
TimeoutException])) {
logWarning(s"fetch shuffle ${celebornShuffleId} timeout or interrupt", ce)
throw ce
} else {
handleFetchExceptions(handle.shuffleId, celebornShuffleId, 0, ce)
}
}

@VisibleForTesting
def handleFetchExceptions(
appShuffleId: Int,
shuffleId: Int,
partitionId: Int,
@@ -0,0 +1,95 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.shuffle.celeborn

import java.nio.file.Files
import java.util.concurrent.TimeoutException

import org.apache.spark.{Dependency, ShuffleDependency, TaskContext}
import org.apache.spark.shuffle.ShuffleReadMetricsReporter
import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito
import org.mockito.Mockito._
import org.scalatest.funsuite.AnyFunSuite

import org.apache.celeborn.client.{DummyShuffleClient, ShuffleClient}
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.exception.CelebornIOException
import org.apache.celeborn.common.identity.UserIdentifier

class CelebornShuffleReaderSuite extends AnyFunSuite {

/**
* Due to Spark limitations, speculative tasks cannot be exercised in local mode,
* so this test calls `checkAndReportFetchFailureForUpdateFileGroupFailure` directly.
*/
test("CELEBORN-1838 test check report fetch failure exceptions ") {
val dependency = Mockito.mock(classOf[ShuffleDependency[Int, Int, Int]])
val handler = new CelebornShuffleHandle[Int, Int, Int](
"APP",
"HOST1",
1,
UserIdentifier.apply("a", "b"),
0,
true,
1,
dependency)
val context = Mockito.mock(classOf[TaskContext])
val metricReporter = Mockito.mock(classOf[ShuffleReadMetricsReporter])
val conf = new CelebornConf()

val tmpFile = Files.createTempFile("test", ".tmp").toFile
mockStatic(classOf[ShuffleClient]).when(() =>
ShuffleClient.get(any(), any(), any(), any(), any(), any())).thenReturn(
new DummyShuffleClient(conf, tmpFile))

val shuffleReader =
new CelebornShuffleReader[Int, Int](handler, 0, 0, 0, 0, context, conf, metricReporter, null)

val exception1: Throwable = new CelebornIOException("test1", new InterruptedException("test1"))
val exception2: Throwable = new CelebornIOException("test2", new TimeoutException("test2"))
val exception3: Throwable = new CelebornIOException("test3")
val exception4: Throwable = new CelebornIOException("test4")

try {
shuffleReader.checkAndReportFetchFailureForUpdateFileGroupFailure(0, exception1)
} catch {
case _: Throwable =>
}
try {
shuffleReader.checkAndReportFetchFailureForUpdateFileGroupFailure(0, exception2)
} catch {
case _: Throwable =>
}
try {
shuffleReader.checkAndReportFetchFailureForUpdateFileGroupFailure(0, exception3)
} catch {
case _: Throwable =>
}
assert(
shuffleReader.shuffleClient.asInstanceOf[DummyShuffleClient].fetchFailureCount.get() === 1)
try {
shuffleReader.checkAndReportFetchFailureForUpdateFileGroupFailure(0, exception4)
} catch {
case _: Throwable =>
}
assert(
shuffleReader.shuffleClient.asInstanceOf[DummyShuffleClient].fetchFailureCount.get() === 2)

}
}
@@ -29,6 +29,7 @@
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -55,6 +56,8 @@ public class DummyShuffleClient extends ShuffleClient {
private final Map<Integer, ConcurrentHashMap<Integer, PartitionLocation>> reducePartitionMap =
new HashMap<>();

public AtomicInteger fetchFailureCount = new AtomicInteger();

public DummyShuffleClient(CelebornConf conf, File file) throws Exception {
this.os = new BufferedOutputStream(new FileOutputStream(file));
this.conf = conf;
@@ -181,6 +184,7 @@ public int getShuffleId(

@Override
public boolean reportShuffleFetchFailure(int appShuffleId, int shuffleId, long taskId) {
fetchFailureCount.incrementAndGet();
return true;
}

@@ -26,6 +26,7 @@
import java.util.concurrent.TimeUnit;

import scala.Tuple2;
import scala.Tuple3;
import scala.reflect.ClassTag$;

import com.google.common.annotations.VisibleForTesting;
@@ -170,7 +171,7 @@ public void update(ReduceFileGroups fileGroups) {
}

// key: shuffleId
protected final Map<Integer, Tuple2<ReduceFileGroups, String>> reduceFileGroupsMap =
protected final Map<Integer, Tuple3<ReduceFileGroups, String, Exception>> reduceFileGroupsMap =
JavaUtils.newConcurrentHashMap();

public ShuffleClientImpl(String appUniqueId, CelebornConf conf, UserIdentifier userIdentifier) {
@@ -1742,11 +1743,12 @@ public boolean cleanupShuffle(int shuffleId) {
return true;
}

protected Tuple2<ReduceFileGroups, String> loadFileGroupInternal(
protected Tuple3<ReduceFileGroups, String, Exception> loadFileGroupInternal(
int shuffleId, boolean isSegmentGranularityVisible) {
{
long getReducerFileGroupStartTime = System.nanoTime();
String exceptionMsg = null;
Exception exception = null;
try {
if (lifecycleManagerRef == null) {
exceptionMsg = "Driver endpoint is null!";
@@ -1768,9 +1770,10 @@ protected Tuple2<ReduceFileGroups, String> loadFileGroupInternal(
shuffleId,
TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - getReducerFileGroupStartTime),
response.fileGroup().size());
return Tuple2.apply(
return Tuple3.apply(
new ReduceFileGroups(
response.fileGroup(), response.attempts(), response.partitionIds()),
null,
null);
case SHUFFLE_NOT_REGISTERED:
logger.warn(
@@ -1779,9 +1782,10 @@ protected Tuple2<ReduceFileGroups, String> loadFileGroupInternal(
response.status(),
shuffleId);
// return empty result
return Tuple2.apply(
return Tuple3.apply(
new ReduceFileGroups(
response.fileGroup(), response.attempts(), response.partitionIds()),
null,
null);
case STAGE_END_TIME_OUT:
case SHUFFLE_DATA_LOST:
@@ -1800,8 +1804,9 @@ protected Tuple2<ReduceFileGroups, String> loadFileGroupInternal(
}
logger.error("Exception raised while call GetReducerFileGroup for {}.", shuffleId, e);
exceptionMsg = e.getMessage();
exception = e;
}
return Tuple2.apply(null, exceptionMsg);
return Tuple3.apply(null, exceptionMsg, exception);
}
}

@@ -1814,21 +1819,22 @@ public ReduceFileGroups updateFileGroup(int shuffleId, int partitionId)
public ReduceFileGroups updateFileGroup(
int shuffleId, int partitionId, boolean isSegmentGranularityVisible)
throws CelebornIOException {
Tuple2<ReduceFileGroups, String> fileGroupTuple =
Tuple3<ReduceFileGroups, String, Exception> fileGroupTuple =
reduceFileGroupsMap.compute(
shuffleId,
(id, existsTuple) -> {
if (existsTuple == null || existsTuple._1 == null) {
if (existsTuple == null || existsTuple._1() == null) {
return loadFileGroupInternal(shuffleId, isSegmentGranularityVisible);
} else {
return existsTuple;
}
});
if (fileGroupTuple._1 == null) {
if (fileGroupTuple._1() == null) {
throw new CelebornIOException(
loadFileGroupException(shuffleId, partitionId, (fileGroupTuple._2)));
loadFileGroupException(shuffleId, partitionId, (fileGroupTuple._2())),
fileGroupTuple._3());
} else {
return fileGroupTuple._1;
return fileGroupTuple._1();
}
}

@@ -1899,7 +1905,7 @@ public CelebornInputStream readPartition(
}

@VisibleForTesting
public Map<Integer, Tuple2<ReduceFileGroups, String>> getReduceFileGroupsMap() {
public Map<Integer, Tuple3<ReduceFileGroups, String, Exception>> getReduceFileGroupsMap() {
return reduceFileGroupsMap;
}
