Skip to content

Commit

Permalink
Merge branch 'main' into optimize-skew-partition
Browse files: browse the repository at this point in the history
  • Loading branch information
turboFei authored Feb 17, 2025
2 parents cbad6f8 + fc459c0 commit 400c9be
Show file tree
Hide file tree
Showing 34 changed files with 1,466 additions and 675 deletions.
145 changes: 76 additions & 69 deletions client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import org.apache.celeborn.client.read.MetricsCallback;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.exception.CelebornIOException;
import org.apache.celeborn.common.exception.CelebornRuntimeException;
import org.apache.celeborn.common.identity.UserIdentifier;
import org.apache.celeborn.common.metrics.source.Role;
import org.apache.celeborn.common.network.TransportContext;
Expand Down Expand Up @@ -84,6 +85,8 @@ public class ShuffleClientImpl extends ShuffleClient {

private final int registerShuffleMaxRetries;
private final long registerShuffleRetryWaitMs;
private final int rpcMaxRetries;
private final long rpcRetryWait;
private final int maxReviveTimes;
private final boolean testRetryRevive;
private final int pushBufferMaxSize;
Expand Down Expand Up @@ -189,6 +192,8 @@ public ShuffleClientImpl(String appUniqueId, CelebornConf conf, UserIdentifier u
this.userIdentifier = userIdentifier;
registerShuffleMaxRetries = conf.clientRegisterShuffleMaxRetry();
registerShuffleRetryWaitMs = conf.clientRegisterShuffleRetryWaitMs();
rpcMaxRetries = conf.clientRpcMaxRetries();
rpcRetryWait = conf.clientRpcRetryWait();
maxReviveTimes = conf.clientPushMaxReviveTimes();
testRetryRevive = conf.testRetryRevive();
pushBufferMaxSize = conf.clientPushBufferMaxSize();
Expand Down Expand Up @@ -546,6 +551,8 @@ private ConcurrentHashMap<Integer, PartitionLocation> registerShuffle(
lifecycleManagerRef.askSync(
RegisterShuffle$.MODULE$.apply(shuffleId, numMappers, numPartitions),
conf.clientRpcRegisterShuffleAskTimeout(),
rpcMaxRetries,
rpcRetryWait,
ClassTag$.MODULE$.apply(PbRegisterShuffleResponse.class)));
}

Expand Down Expand Up @@ -1747,7 +1754,9 @@ private void mapEndInternal(
numMappers,
partitionId,
pushState.getFailedBatches()),
ClassTag$.MODULE$.apply(MapperEndResponse.class));
rpcMaxRetries,
rpcRetryWait,
ClassTag$.MODULE$.apply(MapperEndResponse.class));
if (response.status() != StatusCode.SUCCESS) {
throw new CelebornIOException("MapperEnd failed! StatusCode: " + response.status());
}
Expand Down Expand Up @@ -1781,75 +1790,64 @@ public boolean cleanupShuffle(int shuffleId) {

protected Tuple3<ReduceFileGroups, String, Exception> loadFileGroupInternal(
int shuffleId, boolean isSegmentGranularityVisible) {
{
long getReducerFileGroupStartTime = System.nanoTime();
String exceptionMsg = null;
Exception exception = null;
try {
if (lifecycleManagerRef == null) {
exceptionMsg = "Driver endpoint is null!";
long getReducerFileGroupStartTime = System.nanoTime();
String exceptionMsg = null;
Exception exception = null;
if (lifecycleManagerRef == null) {
exceptionMsg = "Driver endpoint is null!";
logger.warn(exceptionMsg);
return Tuple3.apply(null, exceptionMsg, exception);
}
try {
GetReducerFileGroup getReducerFileGroup =
new GetReducerFileGroup(shuffleId, isSegmentGranularityVisible);
GetReducerFileGroupResponse response =
lifecycleManagerRef.askSync(
getReducerFileGroup,
conf.clientRpcGetReducerFileGroupAskTimeout(),
rpcMaxRetries,
rpcRetryWait,
ClassTag$.MODULE$.apply(GetReducerFileGroupResponse.class));
switch (response.status()) {
case SUCCESS:
logger.info(
"Shuffle {} request reducer file group success using {} ms, result partition size {}.",
shuffleId,
TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - getReducerFileGroupStartTime),
response.fileGroup().size());
return Tuple3.apply(
new ReduceFileGroups(
response.fileGroup(), response.attempts(), response.partitionIds(),response.pushFailedBatches()),
null,
null);
case SHUFFLE_NOT_REGISTERED:
logger.warn(
"Request {} return {} for {}.", getReducerFileGroup, response.status(), shuffleId);
// return empty result
return Tuple3.apply(
new ReduceFileGroups(
response.fileGroup(), response.attempts(), response.partitionIds(),response.pushFailedBatches()),
null,
null);
case STAGE_END_TIME_OUT:
case SHUFFLE_DATA_LOST:
exceptionMsg =
String.format(
"Request %s return %s for %s.",
getReducerFileGroup, response.status(), shuffleId);
logger.warn(exceptionMsg);
} else {
GetReducerFileGroup getReducerFileGroup =
new GetReducerFileGroup(shuffleId, isSegmentGranularityVisible);

GetReducerFileGroupResponse response =
lifecycleManagerRef.askSync(
getReducerFileGroup,
conf.clientRpcGetReducerFileGroupAskTimeout(),
ClassTag$.MODULE$.apply(GetReducerFileGroupResponse.class));

switch (response.status()) {
case SUCCESS:
logger.info(
"Shuffle {} request reducer file group success using {} ms, result partition size {}.",
shuffleId,
TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - getReducerFileGroupStartTime),
response.fileGroup().size());
return Tuple3.apply(
new ReduceFileGroups(
response.fileGroup(),
response.attempts(),
response.partitionIds(),
response.pushFailedBatches()),
null,
null);
case SHUFFLE_NOT_REGISTERED:
logger.warn(
"Request {} return {} for {}.",
getReducerFileGroup,
response.status(),
shuffleId);
// return empty result
return Tuple3.apply(
new ReduceFileGroups(
response.fileGroup(),
response.attempts(),
response.partitionIds(),
response.pushFailedBatches()),
null,
null);
case STAGE_END_TIME_OUT:
case SHUFFLE_DATA_LOST:
exceptionMsg =
String.format(
"Request %s return %s for %s.",
getReducerFileGroup, response.status(), shuffleId);
logger.warn(exceptionMsg);
break;
default: // fall out
}
}
} catch (Exception e) {
if (e instanceof InterruptedException) {
Thread.currentThread().interrupt();
}
logger.error("Exception raised while call GetReducerFileGroup for {}.", shuffleId, e);
exceptionMsg = e.getMessage();
exception = e;
break;
default: // fall out
}
return Tuple3.apply(null, exceptionMsg, exception);
} catch (Exception e) {
if (e instanceof InterruptedException) {
Thread.currentThread().interrupt();
}
logger.error("Exception raised while call GetReducerFileGroup for {}.", shuffleId, e);
exceptionMsg = e.getMessage();
exception = e;
}
return Tuple3.apply(null, exceptionMsg, exception);
}

@Override
Expand Down Expand Up @@ -1985,8 +1983,17 @@ public void shutdown() {
@Override
public void setupLifecycleManagerRef(String host, int port) {
logger.info("setupLifecycleManagerRef: host = {}, port = {}", host, port);
lifecycleManagerRef =
rpcEnv.setupEndpointRef(new RpcAddress(host, port), RpcNameConstants.LIFECYCLE_MANAGER_EP);
try {
lifecycleManagerRef =
rpcEnv.setupEndpointRef(
new RpcAddress(host, port),
RpcNameConstants.LIFECYCLE_MANAGER_EP,
rpcMaxRetries,
rpcRetryWait);
} catch (Exception e) {
throw new CelebornRuntimeException("setupLifecycleManagerRef failed!", e);
}

initDataClientFactoryIfNeeded();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,14 @@ private PartitionReader createReader(
logger.debug("Read local shuffle file {}", localHostAddress);
containLocalRead = true;
return new LocalPartitionReader(
conf, shuffleKey, location, clientFactory, startMapIndex, endMapIndex, callback);
conf,
shuffleKey,
location,
pbStreamHandler,
clientFactory,
startMapIndex,
endMapIndex,
callback);
} else {
return new WorkerPartitionReader(
conf,
Expand All @@ -575,7 +582,14 @@ private PartitionReader createReader(
case S3:
case HDFS:
return new DfsPartitionReader(
conf, shuffleKey, location, clientFactory, startMapIndex, endMapIndex, callback);
conf,
shuffleKey,
location,
pbStreamHandler,
clientFactory,
startMapIndex,
endMapIndex,
callback);
default:
throw new CelebornIOException(
String.format("Unknown storage info %s to read location %s", storageInfo, location));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ public DfsPartitionReader(
CelebornConf conf,
String shuffleKey,
PartitionLocation location,
PbStreamHandler pbStreamHandler,
TransportClientFactory clientFactory,
int startMapIndex,
int endMapIndex,
Expand All @@ -95,10 +96,10 @@ public DfsPartitionReader(
this.hadoopFs = ShuffleClient.getHadoopFs(conf).get(StorageInfo.Type.HDFS);
}

if (endMapIndex != Integer.MAX_VALUE) {
long fetchTimeoutMs = conf.clientFetchTimeoutMs();
try {
client = clientFactory.createClient(location.getHost(), location.getFetchPort());
long fetchTimeoutMs = conf.clientFetchTimeoutMs();
try {
client = clientFactory.createClient(location.getHost(), location.getFetchPort());
if (pbStreamHandler == null) {
TransportMessage openStream =
new TransportMessage(
MessageType.OPEN_STREAM,
Expand All @@ -112,13 +113,16 @@ public DfsPartitionReader(
ByteBuffer response = client.sendRpcSync(openStream.toByteBuffer(), fetchTimeoutMs);
streamHandler = TransportMessage.fromByteBuffer(response).getParsedPayload();
// Parse this message to ensure sort is done.
} catch (IOException | InterruptedException e) {
throw new IOException(
"read shuffle file from DFS failed, filePath: "
+ location.getStorageInfo().getFilePath(),
e);
} else {
streamHandler = pbStreamHandler;
}
} catch (IOException | InterruptedException e) {
throw new IOException(
"read shuffle file from DFS failed, filePath: " + location.getStorageInfo().getFilePath(),
e);
}

if (endMapIndex != Integer.MAX_VALUE) {
dfsInputStream =
hadoopFs.open(new Path(Utils.getSortedFilePath(location.getStorageInfo().getFilePath())));
chunkOffsets.addAll(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ public LocalPartitionReader(
CelebornConf conf,
String shuffleKey,
PartitionLocation location,
PbStreamHandler pbStreamHandler,
TransportClientFactory clientFactory,
int startMapIndex,
int endMapIndex,
Expand All @@ -95,19 +96,23 @@ public LocalPartitionReader(
long fetchTimeoutMs = conf.clientFetchTimeoutMs();
try {
client = clientFactory.createClient(location.getHost(), location.getFetchPort(), 0);
TransportMessage openStreamMsg =
new TransportMessage(
MessageType.OPEN_STREAM,
PbOpenStream.newBuilder()
.setShuffleKey(shuffleKey)
.setFileName(location.getFileName())
.setStartIndex(startMapIndex)
.setEndIndex(endMapIndex)
.setReadLocalShuffle(true)
.build()
.toByteArray());
ByteBuffer response = client.sendRpcSync(openStreamMsg.toByteBuffer(), fetchTimeoutMs);
streamHandler = TransportMessage.fromByteBuffer(response).getParsedPayload();
if (pbStreamHandler == null) {
TransportMessage openStreamMsg =
new TransportMessage(
MessageType.OPEN_STREAM,
PbOpenStream.newBuilder()
.setShuffleKey(shuffleKey)
.setFileName(location.getFileName())
.setStartIndex(startMapIndex)
.setEndIndex(endMapIndex)
.setReadLocalShuffle(true)
.build()
.toByteArray());
ByteBuffer response = client.sendRpcSync(openStreamMsg.toByteBuffer(), fetchTimeoutMs);
streamHandler = TransportMessage.fromByteBuffer(response).getParsedPayload();
} else {
this.streamHandler = pbStreamHandler;
}
} catch (IOException | InterruptedException e) {
throw new IOException(
"Read shuffle file from local file failed, partition location: "
Expand Down
Loading

0 comments on commit 400c9be

Please sign in to comment.