[CELEBORN-1894] Allow skipping already read chunks during unreplicated shuffle read retries #3132

Closed
wants to merge 14 commits
CelebornInputStream.java
@@ -38,6 +38,7 @@
import org.apache.celeborn.client.ClientUtils;
import org.apache.celeborn.client.ShuffleClient;
import org.apache.celeborn.client.compress.Decompressor;
import org.apache.celeborn.client.read.checkpoint.PartitionReaderCheckpointMetadata;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.exception.CelebornIOException;
import org.apache.celeborn.common.network.client.TransportClient;
@@ -417,14 +418,29 @@ private boolean isExcluded(PartitionLocation location) {

private PartitionReader createReaderWithRetry(
PartitionLocation location, PbStreamHandler pbStreamHandler) throws IOException {
return createReaderWithRetry(location, pbStreamHandler, Optional.empty());
}

private PartitionReader createReaderWithRetry(
PartitionLocation location,
PbStreamHandler pbStreamHandler,
Optional<PartitionReaderCheckpointMetadata> checkpointMetadata)
throws IOException {
Exception lastException = null;
while (fetchChunkRetryCnt < fetchChunkMaxRetry) {
try {
logger.debug("Create reader for location {}", location);
if (isExcluded(location)) {
throw new CelebornIOException("Fetch data from excluded worker! " + location);
}
return createReader(location, pbStreamHandler, fetchChunkRetryCnt, fetchChunkMaxRetry);
PartitionReader reader =
createReader(
location,
pbStreamHandler,
fetchChunkRetryCnt,
fetchChunkMaxRetry,
checkpointMetadata);
return reader;
} catch (Exception e) {
lastException = e;
shuffleClient.excludeFailedFetchLocation(location.hostAndFetchPort(), e);
@@ -512,6 +528,8 @@ private ByteBuf getNextChunk() throws IOException {
if (fetchChunkRetryCnt % 2 == 0) {
Uninterruptibles.sleepUninterruptibly(retryWaitMs, TimeUnit.MILLISECONDS);
}
// We must not use checkpoint for peer location since chunkIds don't always match
// across peers
currentReader = createReaderWithRetry(currentReader.getLocation().getPeer(), null);
} else {
logger.warn(
@@ -521,7 +539,14 @@ private ByteBuf getNextChunk() throws IOException {
currentReader.getLocation(),
e);
Uninterruptibles.sleepUninterruptibly(retryWaitMs, TimeUnit.MILLISECONDS);
currentReader = createReaderWithRetry(currentReader.getLocation(), null);
// When re-reading from the same location, chunks already returned before the
// failure can be skipped, improving read performance during retries.
currentReader =
createReaderWithRetry(
currentReader.getLocation(),
null,
currentReader.getPartitionReaderCheckpointMetadata());
}
}
}
@@ -533,7 +558,8 @@ private PartitionReader createReader(
PartitionLocation location,
PbStreamHandler pbStreamHandler,
int fetchChunkRetryCnt,
int fetchChunkMaxRetry)
int fetchChunkMaxRetry,
Optional<PartitionReaderCheckpointMetadata> checkpointMetadata)
throws IOException, InterruptedException {

StorageInfo storageInfo = location.getStorageInfo();
@@ -579,7 +605,8 @@ private PartitionReader createReader(
fetchChunkMaxRetry,
callback,
startChunkIndex,
endChunkIndex);
endChunkIndex,
checkpointMetadata);
}
case S3:
case HDFS:
@@ -593,7 +620,8 @@
endMapIndex,
callback,
startChunkIndex,
endChunkIndex);
endChunkIndex,
checkpointMetadata);
default:
throw new CelebornIOException(
String.format("Unknown storage info %s to read location %s", storageInfo, location));
@@ -789,6 +817,7 @@ private boolean fillBuffer() throws IOException {
hasData = true;
break;
} else {
callback.incDuplicateBytesRead(BATCH_HEADER_SIZE + size);
logger.debug(
"Skip duplicated batch: mapId {}, attemptId {}, batchId {}.",
mapId,
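
The `PartitionReaderCheckpointMetadata` class imported above is not shown in this excerpt. Conceptually it only needs to record which chunk IDs a previous reader already returned. A minimal sketch, assuming a simple set-backed implementation (the actual class in the PR may differ in detail):

import java.util.HashSet;
import java.util.Set;

// Set-backed sketch; the real class lives in
// org.apache.celeborn.client.read.checkpoint.
public class PartitionReaderCheckpointMetadata {
  private final Set<Integer> returnedChunks = new HashSet<>();

  // Record that this chunk was fully returned to the caller.
  public void checkpoint(int chunkId) {
    returnedChunks.add(chunkId);
  }

  // True if a previous reader for the same partition already returned this chunk.
  public boolean isCheckpointed(int chunkId) {
    return returnedChunks.contains(chunkId);
  }

  public Set<Integer> getReturnedChunks() {
    return returnedChunks;
  }
}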
DfsPartitionReader.java
@@ -21,21 +21,23 @@
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.util.ReferenceCounted;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.celeborn.client.ShuffleClient;
import org.apache.celeborn.client.read.checkpoint.PartitionReaderCheckpointMetadata;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.network.client.TransportClient;
import org.apache.celeborn.common.network.client.TransportClientFactory;
@@ -57,13 +59,14 @@ public class DfsPartitionReader implements PartitionReader {
PartitionLocation location;
private final long shuffleChunkSize;
private final int fetchMaxReqsInFlight;
private final LinkedBlockingQueue<ByteBuf> results;
private final LinkedBlockingQueue<Pair<Integer, ByteBuf>> results;
private final AtomicReference<IOException> exception = new AtomicReference<>();
private volatile boolean closed = false;
private ExecutorService fetchThread;
private boolean fetchThreadStarted;
private FSDataInputStream dfsInputStream;
private int numChunks = 0;
private int lastReturnedChunkId = -1;
private int returnedChunks = 0;
private int currentChunkIndex = 0;
private int startChunkIndex;
@@ -76,6 +79,8 @@ public class DfsPartitionReader implements PartitionReader {

private Path dataFilePath;

private Optional<PartitionReaderCheckpointMetadata> partitionReaderCheckpointMetadata;

public DfsPartitionReader(
CelebornConf conf,
String shuffleKey,
@@ -86,7 +91,8 @@ public DfsPartitionReader(
int endMapIndex,
MetricsCallback metricsCallback,
int startChunkIndex,
int endChunkIndex)
int endChunkIndex,
Optional<PartitionReaderCheckpointMetadata> checkpointMetadata)
throws IOException {
this.conf = conf;
shuffleChunkSize = conf.dfsReadChunkSize();
@@ -144,6 +150,16 @@ public DfsPartitionReader(
: Math.min(chunkOffsets.size() - 2, endChunkIndex);
this.currentChunkIndex = this.startChunkIndex;
this.numChunks = this.endChunkIndex - this.startChunkIndex + 1;

if (checkpointMetadata.isPresent()) {
this.partitionReaderCheckpointMetadata = checkpointMetadata;
this.returnedChunks = checkpointMetadata.get().getReturnedChunks().size();
} else {
this.partitionReaderCheckpointMetadata =
conf.isPartitionReaderCheckpointEnabled()
? Optional.of(new PartitionReaderCheckpointMetadata())
: Optional.empty();
}
logger.debug(
"DFS {} total offset count:{} chunk count: {} "
+ "start chunk index:{} end chunk index:{} offsets:{}",
@@ -205,15 +221,32 @@ public boolean hasNext() {
return returnedChunks < numChunks;
}

private void checkpoint() {
if (lastReturnedChunkId != -1) {
partitionReaderCheckpointMetadata.ifPresent(
readerCheckpointMetadata -> readerCheckpointMetadata.checkpoint(lastReturnedChunkId));
}
}

@Override
public ByteBuf next() throws IOException, InterruptedException {
ByteBuf chunk = null;
Pair<Integer, ByteBuf> chunk = null;
checkpoint();
if (!fetchThreadStarted) {
fetchThreadStarted = true;
fetchThread.submit(
() -> {
try {
while (!closed && currentChunkIndex <= endChunkIndex) {
if (partitionReaderCheckpointMetadata.isPresent()
&& partitionReaderCheckpointMetadata.get().isCheckpointed(currentChunkIndex)) {
logger.info(
"Skipping chunk {} as it has already been returned,"
+ " likely by a previous reader for the same partition.",
currentChunkIndex);
currentChunkIndex++;
continue;
}
while (results.size() >= fetchMaxReqsInFlight) {
Thread.sleep(50);
}
@@ -238,7 +271,7 @@ public ByteBuf next() throws IOException, InterruptedException {
break;
}
}
results.put(Unpooled.wrappedBuffer(buffer));
results.put(Pair.of(currentChunkIndex, Unpooled.wrappedBuffer(buffer)));
logger.debug("add index {} to results", currentChunkIndex++);
}
} catch (Exception e) {
@@ -262,7 +295,8 @@ public ByteBuf next() throws IOException, InterruptedException {
throw e;
}
returnedChunks++;
return chunk;
lastReturnedChunkId = chunk.getLeft();
return chunk.getRight();
}

private void checkException() throws IOException {
@@ -284,7 +318,7 @@ public void close() {
logger.warn("close DFS input stream failed.", e);
}
if (results.size() > 0) {
results.forEach(ReferenceCounted::release);
results.forEach(chunk -> chunk.getRight().release());
}
results.clear();
closeStream();
@@ -308,4 +342,9 @@ private void closeStream() {
public PartitionLocation getLocation() {
return location;
}

@Override
public Optional<PartitionReaderCheckpointMetadata> getPartitionReaderCheckpointMetadata() {
return partitionReaderCheckpointMetadata;
}
}
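
One detail worth noting above: `checkpoint()` runs at the start of `next()`, so a chunk ID is recorded only when the caller comes back for the following chunk. The most recently returned chunk is therefore never checkpointed and will be re-fetched on retry, which is the safe choice if the failure interrupted its consumption. A toy model of that ordering (all names illustrative, not Celeborn APIs):

import java.util.HashSet;
import java.util.Set;

class CheckpointOrdering {
  private final Set<Integer> returned = new HashSet<>();
  private int lastReturnedChunkId = -1;
  private int currentChunkId = 0;

  int next() {
    if (lastReturnedChunkId != -1) {
      // The previous chunk is checkpointed only now, i.e. once the caller
      // has survived consuming it and asked for more.
      returned.add(lastReturnedChunkId);
    }
    lastReturnedChunkId = currentChunkId++;
    return lastReturnedChunkId;
  }

  Set<Integer> safeToSkipOnRetry() {
    return returned; // never contains lastReturnedChunkId
  }
}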
LocalPartitionReader.java
@@ -22,6 +22,7 @@
import java.nio.channels.FileChannel;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
@@ -34,6 +35,7 @@
import org.slf4j.LoggerFactory;

import org.apache.celeborn.client.ShuffleClient;
import org.apache.celeborn.client.read.checkpoint.PartitionReaderCheckpointMetadata;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.exception.CelebornIOException;
import org.apache.celeborn.common.network.client.TransportClient;
@@ -277,4 +279,10 @@ private void closeStream() {
public PartitionLocation getLocation() {
return location;
}

@Override
public Optional<PartitionReaderCheckpointMetadata> getPartitionReaderCheckpointMetadata() {
// TODO implement similar to {@link WorkerPartitionReader}
return Optional.empty();
}
}
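
A possible shape for that TODO, mirroring the fields and the `checkpoint()` hook `DfsPartitionReader` gained above (a hypothetical fragment only; it would still need wiring into this reader's chunk fetch loop):

// Hypothetical fragment; mirrors DfsPartitionReader in this PR.
private int lastReturnedChunkId = -1;
private Optional<PartitionReaderCheckpointMetadata> partitionReaderCheckpointMetadata =
    Optional.empty();

private void checkpoint() {
  if (lastReturnedChunkId != -1) {
    partitionReaderCheckpointMetadata.ifPresent(m -> m.checkpoint(lastReturnedChunkId));
  }
}

With that in place, getPartitionReaderCheckpointMetadata() would return the field instead of Optional.empty().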
MetricsCallback.java
@@ -21,4 +21,6 @@ public interface MetricsCallback {
void incBytesRead(long bytesRead);

void incReadTime(long time);

default void incDuplicateBytesRead(long bytesRead) {}
}
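
Because `incDuplicateBytesRead` has a no-op default, existing implementations keep compiling unchanged; only callers that care about the new counter override it. An illustrative implementation (class and field names are assumptions, not part of the PR):

import java.util.concurrent.atomic.LongAdder;

class CountingMetricsCallback implements MetricsCallback {
  private final LongAdder bytesRead = new LongAdder();
  private final LongAdder readTime = new LongAdder();
  private final LongAdder duplicateBytesRead = new LongAdder();

  @Override
  public void incBytesRead(long bytes) {
    bytesRead.add(bytes);
  }

  @Override
  public void incReadTime(long time) {
    readTime.add(time);
  }

  @Override
  public void incDuplicateBytesRead(long bytes) {
    // Bytes that fillBuffer() skipped as duplicated batches after a retry;
    // useful for gauging how much re-reading the checkpoint avoids.
    duplicateBytesRead.add(bytes);
  }
}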
PartitionReader.java
@@ -18,9 +18,11 @@
package org.apache.celeborn.client.read;

import java.io.IOException;
import java.util.Optional;

import io.netty.buffer.ByteBuf;

import org.apache.celeborn.client.read.checkpoint.PartitionReaderCheckpointMetadata;
import org.apache.celeborn.common.protocol.PartitionLocation;

public interface PartitionReader {
@@ -31,4 +33,6 @@
void close();

PartitionLocation getLocation();

Optional<PartitionReaderCheckpointMetadata> getPartitionReaderCheckpointMetadata();
}
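
Taken together, this interface method is what lets `CelebornInputStream` hand a failed reader's progress to its replacement. Condensed from `getNextChunk()` above (the method and parameter names here are illustrative; `createReaderWithRetry` is the private helper from the first file):

PartitionReader retry(PartitionReader failed, boolean usePeer) throws IOException {
  if (usePeer) {
    // Chunk IDs are not comparable across replicas, so a peer retry starts fresh.
    return createReaderWithRetry(failed.getLocation().getPeer(), null, Optional.empty());
  }
  // Same location: resume from the checkpoint and skip already returned chunks.
  return createReaderWithRetry(
      failed.getLocation(), null, failed.getPartitionReaderCheckpointMetadata());
}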