Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[#2380] Improvement: Eagerly cancel rpc request #2381

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,8 @@ private boolean sendShuffleDataAsync(
stageAttemptNumber,
retryMax,
retryIntervalMax,
shuffleIdToBlocks);
shuffleIdToBlocks,
needCancelRequest);
long s = System.currentTimeMillis();
RssSendShuffleDataResponse response =
getShuffleServerClient(ssi).sendShuffleData(request);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,18 @@
import java.io.File;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Supplier;
import java.util.stream.Stream;

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.awaitility.Awaitility;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
Expand All @@ -44,21 +50,16 @@
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.rpc.ServerType;
import org.apache.uniffle.coordinator.CoordinatorConf;
import org.apache.uniffle.coordinator.CoordinatorServer;
import org.apache.uniffle.server.MockedGrpcServer;
import org.apache.uniffle.server.MockedShuffleServer;
import org.apache.uniffle.server.ShuffleServer;
import org.apache.uniffle.server.ShuffleServerConf;
import org.apache.uniffle.storage.util.StorageType;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;

public class RpcClientRetryTest extends ShuffleReadWriteBase {

private static ShuffleServerInfo shuffleServerInfo0;
private static ShuffleServerInfo shuffleServerInfo1;
private static ShuffleServerInfo shuffleServerInfo2;
private static List<ShuffleServerInfo> grpcShuffleServerInfoList;
private static MockedShuffleWriteClientImpl shuffleWriteClientImpl;

private ShuffleClientFactory.ReadClientBuilder baseReadBuilder(StorageType storageType) {
Expand All @@ -73,9 +74,9 @@ private ShuffleClientFactory.ReadClientBuilder baseReadBuilder(StorageType stora
.readBufferSize(1000);
}

public static MockedShuffleServer createMockedShuffleServer(int id, File tmpDir)
throws Exception {
ShuffleServerConf shuffleServerConf = getShuffleServerConf(ServerType.GRPC);
public static MockedShuffleServer createMockedShuffleServer(
int id, File tmpDir, ServerType serverType) throws Exception {
ShuffleServerConf shuffleServerConf = getShuffleServerConf(serverType);
File dataDir1 = new File(tmpDir, id + "_1");
File dataDir2 = new File(tmpDir, id + "_2");
String basePath = dataDir1.getAbsolutePath() + "," + dataDir2.getAbsolutePath();
Expand All @@ -85,46 +86,69 @@ public static MockedShuffleServer createMockedShuffleServer(int id, File tmpDir)
shuffleServerConf.set(ShuffleServerConf.SERVER_MEMORY_SHUFFLE_HIGHWATERMARK_PERCENTAGE, 15.0);
shuffleServerConf.set(ShuffleServerConf.SERVER_BUFFER_CAPACITY, 600L);
shuffleServerConf.set(ShuffleServerConf.SINGLE_BUFFER_FLUSH_BLOCKS_NUM_THRESHOLD, 1);
shuffleServerConf.set(ShuffleServerConf.RPC_SERVER_PORT, 0);
return new MockedShuffleServer(shuffleServerConf);
}

@BeforeAll
public static void initCluster(@TempDir File tmpDir) throws Exception {
@BeforeEach
public void initCluster(@TempDir File tmpDir) throws Exception {
CoordinatorConf coordinatorConf = getCoordinatorConf();
createCoordinatorServer(coordinatorConf);

grpcShuffleServers.add(createMockedShuffleServer(0, tmpDir));
grpcShuffleServers.add(createMockedShuffleServer(1, tmpDir));
grpcShuffleServers.add(createMockedShuffleServer(2, tmpDir));

shuffleServerInfo0 =
new ShuffleServerInfo(
String.format("127.0.0.1-%s", grpcShuffleServers.get(0).getGrpcPort()),
grpcShuffleServers.get(0).getIp(),
grpcShuffleServers.get(0).getGrpcPort());
shuffleServerInfo1 =
new ShuffleServerInfo(
String.format("127.0.0.1-%s", grpcShuffleServers.get(1).getGrpcPort()),
grpcShuffleServers.get(1).getIp(),
grpcShuffleServers.get(1).getGrpcPort());
shuffleServerInfo2 =
new ShuffleServerInfo(
String.format("127.0.0.1-%s", grpcShuffleServers.get(2).getGrpcPort()),
grpcShuffleServers.get(2).getIp(),
grpcShuffleServers.get(2).getGrpcPort());
for (CoordinatorServer coordinator : coordinators) {
coordinator.start();
for (int i = 0; i < 3; i++) {
grpcShuffleServers.add(createMockedShuffleServer(i, tmpDir, ServerType.GRPC));
}
for (ShuffleServer shuffleServer : grpcShuffleServers) {
shuffleServer.start();

startServers();
grpcShuffleServerInfoList = Lists.newArrayList();
for (int i = 0; i < 3; i++) {
grpcShuffleServerInfoList.add(
new ShuffleServerInfo(
String.format("127.0.0.1-%s", grpcShuffleServers.get(i).getGrpcPort()),
grpcShuffleServers.get(i).getIp(),
grpcShuffleServers.get(i).getGrpcPort()));
}
}

@AfterAll
public static void cleanEnv() throws Exception {
/**
 * Checks that a pending send stops retrying once the caller's cancel signal flips to true.
 *
 * <p>The first send runs with the signal off and must deliver both blocks. The second send runs
 * asynchronously against a server told to fail its next two send-data requests; the signal is
 * raised while that send is still in flight and the send must finish with no successful blocks.
 */
@Test
public void testCancelGrpc() throws InterruptedException {
  final String appId = "testCancelGrpc";
  registerShuffleServer(appId, 1, 1, 1, false, 3000);

  Map<Long, byte[]> payloadByBlockId = Maps.newHashMap();
  Roaring64NavigableMap sentBlockIds = Roaring64NavigableMap.bitmapOf();
  List<ShuffleServerInfo> targetServers = Lists.newArrayList(grpcShuffleServerInfoList.get(0));
  List<ShuffleBlockInfo> blocks =
      createShuffleBlockList(0, 0, 0, 2, 25, sentBlockIds, payloadByBlockId, targetServers);

  AtomicBoolean cancelled = new AtomicBoolean(false);
  Supplier<Boolean> needCancelRequest = cancelled::get;

  // Baseline: with no cancellation requested, both blocks are sent successfully.
  SendShuffleDataResult baseline =
      shuffleWriteClientImpl.sendShuffleData(appId, blocks, needCancelRequest);
  assertEquals(2, baseline.getSuccessBlockIds().size());

  enableFirstNSendDataRequestsToFail(2);
  CompletableFuture<SendShuffleDataResult> pendingSend =
      CompletableFuture.supplyAsync(
          () -> shuffleWriteClientImpl.sendShuffleData(appId, blocks, needCancelRequest));

  // Wait before cancelling so the flag is flipped while the send is still in flight.
  Thread.sleep(1000L);
  cancelled.set(true);

  Awaitility.await()
      .atMost(5, TimeUnit.SECONDS)
      .until(
          () -> {
            if (!pendingSend.isDone()) {
              return false;
            }
            return pendingSend.get().getSuccessBlockIds().size() == 0;
          });
}

/**
 * Releases everything a single test created: the write client, the cached server-info list and
 * the started servers.
 *
 * <p>Server shutdown runs in a {@code finally} block so that a failure while closing the client
 * cannot leave servers running for the next test. The server-info list is null-checked because
 * it is only assigned part-way through setup, and this method may still be invoked after setup
 * has failed.
 *
 * @throws Exception if closing the client or shutting down the servers fails
 */
@AfterEach
public void cleanEnv() throws Exception {
  try {
    if (shuffleWriteClientImpl != null) {
      shuffleWriteClientImpl.close();
    }
  } finally {
    // Drop the closed client so a later test can never close or reuse a stale instance.
    shuffleWriteClientImpl = null;
    if (grpcShuffleServerInfoList != null) {
      grpcShuffleServerInfoList.clear();
    }
    shutdownServers();
  }
}

Expand All @@ -140,22 +164,16 @@ private static Stream<Arguments> testRpcRetryLogicProvider() {
@MethodSource("testRpcRetryLogicProvider")
public void testRpcRetryLogic(StorageType storageType) {
String testAppId = "testRpcRetryLogic";
registerShuffleServer(testAppId, 3, 2, 2, true);
registerShuffleServer(testAppId, 3, 2, 2, true, 1000);
Map<Long, byte[]> expectedData = Maps.newHashMap();
Roaring64NavigableMap blockIdBitmap = Roaring64NavigableMap.bitmapOf();

List<ShuffleBlockInfo> blocks =
createShuffleBlockList(
0,
0,
0,
3,
25,
blockIdBitmap,
expectedData,
Lists.newArrayList(shuffleServerInfo0, shuffleServerInfo1, shuffleServerInfo2));
0, 0, 0, 3, 25, blockIdBitmap, expectedData, grpcShuffleServerInfoList);

SendShuffleDataResult result = shuffleWriteClientImpl.sendShuffleData(testAppId, blocks);
SendShuffleDataResult result =
shuffleWriteClientImpl.sendShuffleData(testAppId, blocks, () -> false);
Roaring64NavigableMap failedBlockIdBitmap = Roaring64NavigableMap.bitmapOf();
Roaring64NavigableMap successfulBlockIdBitmap = Roaring64NavigableMap.bitmapOf();
for (Long blockId : result.getSuccessBlockIds()) {
Expand All @@ -174,8 +192,7 @@ public void testRpcRetryLogic(StorageType storageType) {
.appId(testAppId)
.blockIdBitmap(blockIdBitmap)
.taskIdBitmap(taskIdBitmap)
.shuffleServerInfoList(
Lists.newArrayList(shuffleServerInfo0, shuffleServerInfo1, shuffleServerInfo2))
.shuffleServerInfoList(grpcShuffleServerInfoList)
.retryMax(3)
.retryIntervalMax(1)
.build();
Expand All @@ -195,8 +212,7 @@ public void testRpcRetryLogic(StorageType storageType) {
.appId(testAppId)
.blockIdBitmap(blockIdBitmap)
.taskIdBitmap(taskIdBitmap)
.shuffleServerInfoList(
Lists.newArrayList(shuffleServerInfo0, shuffleServerInfo1, shuffleServerInfo2))
.shuffleServerInfoList(grpcShuffleServerInfoList)
.retryMax(3)
.retryIntervalMax(1)
.build();
Expand All @@ -208,39 +224,52 @@ public void testRpcRetryLogic(StorageType storageType) {
}

private static void enableFirstNReadRequestsToFail(int failedCount) {
for (ShuffleServer server : grpcShuffleServers) {
((MockedGrpcServer) server.getServer())
.getService()
.enableFirstNReadRequestToFail(failedCount);
}
grpcShuffleServers.stream()
.forEach(
server ->
((MockedGrpcServer) server.getServer())
.getService()
.enableFirstNReadRequestToFail(failedCount));
}

/**
 * Tells the mocked gRPC service of every started shuffle server to fail upcoming send-data
 * requests.
 *
 * @param failedCount value forwarded to each service's {@code enableFirstNSendDataRequestToFail}
 */
private static void enableFirstNSendDataRequestsToFail(int failedCount) {
  for (ShuffleServer shuffleServer : grpcShuffleServers) {
    MockedGrpcServer mockedGrpcServer = (MockedGrpcServer) shuffleServer.getServer();
    mockedGrpcServer.getService().enableFirstNSendDataRequestToFail(failedCount);
  }
}

private static void disableFirstNReadRequestsToFail() {
for (ShuffleServer server : grpcShuffleServers) {
((MockedGrpcServer) server.getServer()).getService().resetFirstNReadRequestToFail();
}
grpcShuffleServers.stream()
.forEach(
server ->
((MockedGrpcServer) server.getServer())
.getService()
.resetFirstNReadRequestToFail());
}

static class MockedShuffleWriteClientImpl extends ShuffleWriteClientImpl {
MockedShuffleWriteClientImpl(ShuffleClientFactory.WriteClientBuilder builder) {
super(builder);
}

public SendShuffleDataResult sendShuffleData(
String appId, List<ShuffleBlockInfo> shuffleBlockInfoList) {
return super.sendShuffleData(appId, shuffleBlockInfoList, () -> false);
}
}

private void registerShuffleServer(
String testAppId, int replica, int replicaWrite, int replicaRead, boolean replicaSkip) {
String testAppId,
int replica,
int replicaWrite,
int replicaRead,
boolean replicaSkip,
long retryIntervalMs) {

shuffleWriteClientImpl =
new MockedShuffleWriteClientImpl(
ShuffleClientFactory.newWriteBuilder()
.clientType(ClientType.GRPC.name())
.retryMax(3)
.retryIntervalMax(1000)
.retryIntervalMax(retryIntervalMs)
.heartBeatThreadNum(1)
.replica(replica)
.replicaWrite(replicaWrite)
Expand All @@ -252,12 +281,9 @@ private void registerShuffleServer(
.unregisterTimeSec(10)
.unregisterRequestTimeSec(10));

List<ShuffleServerInfo> allServers =
Lists.newArrayList(shuffleServerInfo0, shuffleServerInfo1, shuffleServerInfo2);

for (int i = 0; i < replica; i++) {
shuffleWriteClientImpl.registerShuffle(
allServers.get(i),
grpcShuffleServerInfoList.get(i),
testAppId,
0,
Lists.newArrayList(new PartitionRange(0, 0)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,10 @@ public RssSendShuffleDataResponse sendShuffleData(RssSendShuffleDataRequest requ
null,
request.getRetryIntervalMax(),
maxRetryAttempts,
t -> !(t instanceof OutOfMemoryError) && !(t instanceof NotRetryException));
t ->
!request.needCancel()
&& !(t instanceof OutOfMemoryError)
&& !(t instanceof NotRetryException));
} catch (Throwable throwable) {
LOG.warn("Failed to send shuffle data due to ", throwable);
isSuccessful = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,10 @@ public RssSendShuffleDataResponse sendShuffleData(RssSendShuffleDataRequest requ
null,
request.getRetryIntervalMax(),
maxRetryAttempts,
t -> !(t instanceof OutOfMemoryError) && !(t instanceof NotRetryException));
t ->
!request.needCancel()
&& !(t instanceof OutOfMemoryError)
&& !(t instanceof NotRetryException));
} catch (Throwable throwable) {
LOG.warn("Failed to send shuffle data due to ", throwable);
isSuccessful = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import java.util.List;
import java.util.Map;
import java.util.function.Supplier;

import org.apache.uniffle.common.ShuffleBlockInfo;

Expand All @@ -29,26 +30,29 @@ public class RssSendShuffleDataRequest {
private int retryMax;
private long retryIntervalMax;
private Map<Integer, Map<Integer, List<ShuffleBlockInfo>>> shuffleIdToBlocks;
private Supplier<Boolean> needCancel;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this leaks details, or at least makes RssSendShuffleDataRequest hold a reference to the sending class (for Spark, that is DataPusher). I'm not sure this is an elegant way to do it.

Is it possible for
boolean result = ClientUtils.waitUntilDoneOrFail(futures, allowFastFail); in ShuffleWriteClientImpl to be aware of interruption/Spark cancellation and cancel all the sending futures?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, the current DataPusher already leaks details to the sending class; this PR does not make that worse, but it achieves eager cancellation at the RPC retry level.
Being aware of interruption/Spark cancellation is a good idea; I'll follow that approach.


public RssSendShuffleDataRequest(
String appId,
int retryMax,
long retryIntervalMax,
Map<Integer, Map<Integer, List<ShuffleBlockInfo>>> shuffleIdToBlocks) {
this(appId, 0, retryMax, retryIntervalMax, shuffleIdToBlocks);
this(appId, 0, retryMax, retryIntervalMax, shuffleIdToBlocks, () -> false);
}

public RssSendShuffleDataRequest(
String appId,
int stageAttemptNumber,
int retryMax,
long retryIntervalMax,
Map<Integer, Map<Integer, List<ShuffleBlockInfo>>> shuffleIdToBlocks) {
Map<Integer, Map<Integer, List<ShuffleBlockInfo>>> shuffleIdToBlocks,
Supplier<Boolean> needCancel) {
this.appId = appId;
this.retryMax = retryMax;
this.retryIntervalMax = retryIntervalMax;
this.shuffleIdToBlocks = shuffleIdToBlocks;
this.stageAttemptNumber = stageAttemptNumber;
this.needCancel = needCancel;
}

public String getAppId() {
Expand All @@ -70,4 +74,8 @@ public int getStageAttemptNumber() {
public Map<Integer, Map<Integer, List<ShuffleBlockInfo>>> getShuffleIdToBlocks() {
return shuffleIdToBlocks;
}

/**
 * Reports whether the sender has asked for this request to be abandoned.
 *
 * <p>The result is evaluated by the clients' retry predicates ({@code !request.needCancel()}),
 * so this method must never throw: an exception there would replace the original send failure.
 * A missing supplier, or a supplier that yields {@code null}, is therefore treated as "no
 * cancellation requested".
 *
 * @return {@code true} only when the supplier exists and currently returns {@code true};
 *     never {@code null}
 */
public Boolean needCancel() {
  return needCancel != null && Boolean.TRUE.equals(needCancel.get());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ public class MockedShuffleServerGrpcService extends ShuffleServerGrpcService {

private boolean mockSendDataFailed = false;
private int mockSendDataFailedStageNumber = -1;
private AtomicInteger failedSendDataRequest = new AtomicInteger(0);

private boolean mockRequireBufferFailedWithNoBuffer = false;
private boolean isMockRequireBufferFailedWithNoBufferForHugePartition = false;
Expand Down Expand Up @@ -86,6 +87,10 @@ public void enableFirstNReadRequestToFail(int n) {
numOfFailedReadRequest = n;
}

/**
 * Arms the mocked send-data failure counter.
 *
 * <p>{@code sendShuffleData} reads and decrements this counter on each call and throws its
 * mocked {@code RuntimeException} whenever the value read is greater than zero.
 *
 * @param n new value of the failure counter
 */
public void enableFirstNSendDataRequestToFail(int n) {
failedSendDataRequest.set(n);
}

public void resetFirstNReadRequestToFail() {
numOfFailedReadRequest = 0;
failedGetShuffleResultRequest.set(0);
Expand Down Expand Up @@ -146,6 +151,10 @@ public void sendShuffleData(
mockSendDataFailedStageNumber);
throw new RuntimeException("This write request is failed as mocked failure!");
}
if (failedSendDataRequest.getAndDecrement() > 0) {
LOG.info("This request is failed as mocked failure");
throw new RuntimeException("This write request is failed as mocked failure!");
}
if (mockedTimeout > 0) {
LOG.info("Add a mocked timeout on sendShuffleData");
Uninterruptibles.sleepUninterruptibly(mockedTimeout, TimeUnit.MILLISECONDS);
Expand Down
Loading