diff --git a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java index 0b9434a3b1b..cbcfae93808 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java +++ b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java @@ -38,6 +38,7 @@ import org.apache.celeborn.client.ClientUtils; import org.apache.celeborn.client.ShuffleClient; import org.apache.celeborn.client.compress.Decompressor; +import org.apache.celeborn.client.read.checkpoint.PartitionReaderCheckpointMetadata; import org.apache.celeborn.common.CelebornConf; import org.apache.celeborn.common.exception.CelebornIOException; import org.apache.celeborn.common.network.client.TransportClient; @@ -417,6 +418,14 @@ private boolean isExcluded(PartitionLocation location) { private PartitionReader createReaderWithRetry( PartitionLocation location, PbStreamHandler pbStreamHandler) throws IOException { + return createReaderWithRetry(location, pbStreamHandler, Optional.empty()); + } + + private PartitionReader createReaderWithRetry( + PartitionLocation location, + PbStreamHandler pbStreamHandler, + Optional<PartitionReaderCheckpointMetadata> checkpointMetadata) + throws IOException { Exception lastException = null; while (fetchChunkRetryCnt < fetchChunkMaxRetry) { try { @@ -424,7 +433,14 @@ private PartitionReader createReaderWithRetry( if (isExcluded(location)) { throw new CelebornIOException("Fetch data from excluded worker! " + location); } - return createReader(location, pbStreamHandler, fetchChunkRetryCnt, fetchChunkMaxRetry); + PartitionReader reader = + createReader( + location, + pbStreamHandler, + fetchChunkRetryCnt, + fetchChunkMaxRetry, + checkpointMetadata); + return reader; } catch (Exception e) { lastException = e; shuffleClient.excludeFailedFetchLocation(location.hostAndFetchPort(), e); @@ -512,6 +528,8 @@ private ByteBuf getNextChunk() throws IOException { if (fetchChunkRetryCnt % 2 == 0) { Uninterruptibles.sleepUninterruptibly(retryWaitMs, TimeUnit.MILLISECONDS); } + // We must not use the checkpoint for the peer location, since chunkIds don't always match + // across peers currentReader = createReaderWithRetry(currentReader.getLocation().getPeer(), null); } else { logger.warn( @@ -521,7 +539,14 @@ private ByteBuf getNextChunk() throws IOException { currentReader.getLocation(), e); Uninterruptibles.sleepUninterruptibly(retryWaitMs, TimeUnit.MILLISECONDS); - currentReader = createReaderWithRetry(currentReader.getLocation(), null); + // When reading from the same host again, it is possible to skip already read data + // chunks, + // improving read performance during retries.
+ currentReader = + createReaderWithRetry( + currentReader.getLocation(), + null, + currentReader.getPartitionReaderCheckpointMetadata()); } } } @@ -533,7 +558,8 @@ private PartitionReader createReader( PartitionLocation location, PbStreamHandler pbStreamHandler, int fetchChunkRetryCnt, - int fetchChunkMaxRetry) + int fetchChunkMaxRetry, + Optional<PartitionReaderCheckpointMetadata> checkpointMetadata) throws IOException, InterruptedException { StorageInfo storageInfo = location.getStorageInfo(); @@ -579,7 +605,8 @@ private PartitionReader createReader( fetchChunkMaxRetry, callback, startChunkIndex, - endChunkIndex); + endChunkIndex, + checkpointMetadata); } case S3: case HDFS: @@ -593,7 +620,8 @@ private PartitionReader createReader( endMapIndex, callback, startChunkIndex, - endChunkIndex); + endChunkIndex, + checkpointMetadata); default: throw new CelebornIOException( String.format("Unknown storage info %s to read location %s", storageInfo, location)); @@ -789,6 +817,7 @@ private boolean fillBuffer() throws IOException { hasData = true; break; } else { + callback.incDuplicateBytesRead(BATCH_HEADER_SIZE + size); logger.debug( "Skip duplicated batch: mapId {}, attemptId {}, batchId {}.", mapId, diff --git a/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java b/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java index 6cca0d47be0..003c7cd1264 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java +++ b/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java @@ -21,6 +21,7 @@ import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.List; +import java.util.Optional; import java.util.concurrent.ExecutorService; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; @@ -28,7 +29,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; -import io.netty.util.ReferenceCounted; +import org.apache.commons.lang3.tuple.Pair; import org.apache.hadoop.fs.FSDataInputStream; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; @@ -36,6 +37,7 @@ import org.slf4j.LoggerFactory; import org.apache.celeborn.client.ShuffleClient; +import org.apache.celeborn.client.read.checkpoint.PartitionReaderCheckpointMetadata; import org.apache.celeborn.common.CelebornConf; import org.apache.celeborn.common.network.client.TransportClient; import org.apache.celeborn.common.network.client.TransportClientFactory; @@ -57,13 +59,14 @@ public class DfsPartitionReader implements PartitionReader { PartitionLocation location; private final long shuffleChunkSize; private final int fetchMaxReqsInFlight; - private final LinkedBlockingQueue<ByteBuf> results; + private final LinkedBlockingQueue<Pair<Integer, ByteBuf>> results; private final AtomicReference<IOException> exception = new AtomicReference<>(); private volatile boolean closed = false; private ExecutorService fetchThread; private boolean fetchThreadStarted; private FSDataInputStream dfsInputStream; private int numChunks = 0; + private int lastReturnedChunkId = -1; private int returnedChunks = 0; private int currentChunkIndex = 0; private int startChunkIndex; @@ -76,6 +79,8 @@ public class DfsPartitionReader implements PartitionReader { private Path dataFilePath; + private Optional<PartitionReaderCheckpointMetadata> partitionReaderCheckpointMetadata; + public DfsPartitionReader( CelebornConf conf, String shuffleKey, @@ -86,7 +91,8 @@ public DfsPartitionReader( int endMapIndex, MetricsCallback metricsCallback, int startChunkIndex, - int endChunkIndex) + int endChunkIndex, + Optional<PartitionReaderCheckpointMetadata> checkpointMetadata)
throws IOException { this.conf = conf; shuffleChunkSize = conf.dfsReadChunkSize(); @@ -144,6 +150,16 @@ public DfsPartitionReader( : Math.min(chunkOffsets.size() - 2, endChunkIndex); this.currentChunkIndex = this.startChunkIndex; this.numChunks = this.endChunkIndex - this.startChunkIndex + 1; + + if (checkpointMetadata.isPresent()) { + this.partitionReaderCheckpointMetadata = checkpointMetadata; + this.returnedChunks = checkpointMetadata.get().getReturnedChunks().size(); + } else { + this.partitionReaderCheckpointMetadata = + conf.isPartitionReaderCheckpointEnabled() + ? Optional.of(new PartitionReaderCheckpointMetadata()) + : Optional.empty(); + } logger.debug( "DFS {} total offset count:{} chunk count: {} " + "start chunk index:{} end chunk index:{} offsets:{}", @@ -205,15 +221,32 @@ public boolean hasNext() { return returnedChunks < numChunks; } + private void checkpoint() { + if (lastReturnedChunkId != -1) { + partitionReaderCheckpointMetadata.ifPresent( + readerCheckpointMetadata -> readerCheckpointMetadata.checkpoint(lastReturnedChunkId)); + } + } + @Override public ByteBuf next() throws IOException, InterruptedException { - ByteBuf chunk = null; + Pair<Integer, ByteBuf> chunk = null; + checkpoint(); if (!fetchThreadStarted) { fetchThreadStarted = true; fetchThread.submit( () -> { try { while (!closed && currentChunkIndex <= endChunkIndex) { + if (partitionReaderCheckpointMetadata.isPresent() + && partitionReaderCheckpointMetadata.get().isCheckpointed(currentChunkIndex)) { + logger.info( + "Skipping chunk {} as it has already been returned," + + " likely by a previous reader for the same partition.", + currentChunkIndex); + currentChunkIndex++; + continue; + } while (results.size() >= fetchMaxReqsInFlight) { Thread.sleep(50); } @@ -238,7 +271,7 @@ public ByteBuf next() throws IOException, InterruptedException { break; } } - results.put(Unpooled.wrappedBuffer(buffer)); + results.put(Pair.of(currentChunkIndex, Unpooled.wrappedBuffer(buffer))); logger.debug("add index {} to results", currentChunkIndex++); } } catch (Exception e) { @@ -262,7 +295,8 @@ public ByteBuf next() throws IOException, InterruptedException { throw e; } returnedChunks++; - return chunk; + lastReturnedChunkId = chunk.getLeft(); + return chunk.getRight(); } private void checkException() throws IOException { @@ -284,7 +318,7 @@ public void close() { logger.warn("close DFS input stream failed.", e); } if (results.size() > 0) { - results.forEach(ReferenceCounted::release); + results.forEach(chunk -> chunk.getRight().release()); } results.clear(); closeStream(); @@ -308,4 +342,9 @@ private void closeStream() { public PartitionLocation getLocation() { return location; } + + @Override + public Optional<PartitionReaderCheckpointMetadata> getPartitionReaderCheckpointMetadata() { + return partitionReaderCheckpointMetadata; + } } diff --git a/client/src/main/java/org/apache/celeborn/client/read/LocalPartitionReader.java b/client/src/main/java/org/apache/celeborn/client/read/LocalPartitionReader.java index 1de83eb4573..56982f9147b 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/LocalPartitionReader.java +++ b/client/src/main/java/org/apache/celeborn/client/read/LocalPartitionReader.java @@ -22,6 +22,7 @@ import java.nio.channels.FileChannel; import java.util.ArrayList; import java.util.List; +import java.util.Optional; import java.util.concurrent.*; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; @@ -34,6 +35,7 @@ import org.slf4j.LoggerFactory; import org.apache.celeborn.client.ShuffleClient; +import
org.apache.celeborn.client.read.checkpoint.PartitionReaderCheckpointMetadata; import org.apache.celeborn.common.CelebornConf; import org.apache.celeborn.common.exception.CelebornIOException; import org.apache.celeborn.common.network.client.TransportClient; @@ -277,4 +279,10 @@ private void closeStream() { public PartitionLocation getLocation() { return location; } + + @Override + public Optional<PartitionReaderCheckpointMetadata> getPartitionReaderCheckpointMetadata() { + // TODO implement similarly to {@link WorkerPartitionReader} + return Optional.empty(); + } } diff --git a/client/src/main/java/org/apache/celeborn/client/read/MetricsCallback.java b/client/src/main/java/org/apache/celeborn/client/read/MetricsCallback.java index aae07ff2348..ab563cb8863 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/MetricsCallback.java +++ b/client/src/main/java/org/apache/celeborn/client/read/MetricsCallback.java @@ -21,4 +21,6 @@ public interface MetricsCallback { void incBytesRead(long bytesRead); void incReadTime(long time); + + default void incDuplicateBytesRead(long bytesRead) {} } diff --git a/client/src/main/java/org/apache/celeborn/client/read/PartitionReader.java b/client/src/main/java/org/apache/celeborn/client/read/PartitionReader.java index cf8009d9052..0ebb94e7467 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/PartitionReader.java +++ b/client/src/main/java/org/apache/celeborn/client/read/PartitionReader.java @@ -18,9 +18,11 @@ package org.apache.celeborn.client.read; import java.io.IOException; +import java.util.Optional; import io.netty.buffer.ByteBuf; +import org.apache.celeborn.client.read.checkpoint.PartitionReaderCheckpointMetadata; import org.apache.celeborn.common.protocol.PartitionLocation; public interface PartitionReader { @@ -31,4 +33,6 @@ public interface PartitionReader { void close(); PartitionLocation getLocation(); + + Optional<PartitionReaderCheckpointMetadata> getPartitionReaderCheckpointMetadata(); } diff --git a/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java b/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java index 29236273663..df6f3902ca9 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java +++ b/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java @@ -19,16 +19,18 @@ import java.io.IOException; import java.nio.ByteBuffer; +import java.util.Optional; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import io.netty.buffer.ByteBuf; -import io.netty.util.ReferenceCounted; +import org.apache.commons.lang3.tuple.Pair; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.celeborn.client.ShuffleClient; +import org.apache.celeborn.client.read.checkpoint.PartitionReaderCheckpointMetadata; import org.apache.celeborn.common.CelebornConf; import org.apache.celeborn.common.exception.CelebornIOException; import org.apache.celeborn.common.network.buffer.ManagedBuffer; @@ -53,12 +55,14 @@ public class WorkerPartitionReader implements PartitionReader { private TransportClient client; private MetricsCallback metricsCallback; + private int lastReturnedChunkId = -1; private int returnedChunks; private int chunkIndex; private int startChunkIndex; private int endChunkIndex; - private final LinkedBlockingQueue<ByteBuf> results; + private int inflightRequestCount; + private final LinkedBlockingQueue<Pair<Integer, ByteBuf>> results; private final ChunkReceivedCallback callback; private final AtomicReference<IOException> exception =
new AtomicReference<>(); @@ -72,6 +76,8 @@ public class WorkerPartitionReader implements PartitionReader { private int fetchChunkMaxRetry; private final boolean testFetch; + private Optional<PartitionReaderCheckpointMetadata> partitionReaderCheckpointMetadata; + WorkerPartitionReader( CelebornConf conf, String shuffleKey, @@ -84,12 +90,14 @@ public class WorkerPartitionReader implements PartitionReader { int fetchChunkMaxRetry, MetricsCallback metricsCallback, int startChunkIndex, - int endChunkIndex) + int endChunkIndex, + Optional<PartitionReaderCheckpointMetadata> checkpointMetadata) throws IOException, InterruptedException { this.shuffleKey = shuffleKey; fetchMaxReqsInFlight = conf.clientFetchMaxReqsInFlight(); results = new LinkedBlockingQueue<>(); fetchTimeoutMs = conf.clientFetchTimeoutMs(); + inflightRequestCount = 0; this.metricsCallback = metricsCallback; // only add the buffer to results queue if this reader is not closed. callback = @@ -101,7 +109,7 @@ public void onSuccess(int chunkIndex, ManagedBuffer buffer) { ByteBuf buf = ((NettyManagedBuffer) buffer).getBuf(); if (!closed) { buf.retain(); - results.add(buf); + results.add(Pair.of(chunkIndex, buf)); } } } @@ -147,6 +155,15 @@ public void onFailure(int chunkIndex, Throwable e) { this.clientFactory = clientFactory; this.fetchChunkRetryCnt = fetchChunkRetryCnt; this.fetchChunkMaxRetry = fetchChunkMaxRetry; + if (checkpointMetadata.isPresent()) { + this.partitionReaderCheckpointMetadata = checkpointMetadata; + this.returnedChunks = checkpointMetadata.get().getReturnedChunks().size(); + } else { + this.partitionReaderCheckpointMetadata = + conf.isPartitionReaderCheckpointEnabled() + ? Optional.of(new PartitionReaderCheckpointMetadata()) + : Optional.empty(); + } testFetch = conf.testFetchFailure(); ShuffleClient.incrementTotalReadCounter(); } @@ -156,13 +173,21 @@ public boolean hasNext() { return returnedChunks < endChunkIndex - startChunkIndex + 1; } + private void checkpoint() { + if (lastReturnedChunkId != -1) { + partitionReaderCheckpointMetadata.ifPresent( + readerCheckpointMetadata -> readerCheckpointMetadata.checkpoint(lastReturnedChunkId)); + } + } + @Override public ByteBuf next() throws IOException, InterruptedException { + checkpoint(); checkException(); if (chunkIndex <= endChunkIndex) { fetchChunks(); } - ByteBuf chunk = null; + Pair<Integer, ByteBuf> chunk = null; try { while (chunk == null) { checkException(); @@ -176,7 +201,9 @@ public ByteBuf next() throws IOException, InterruptedException { throw e; } returnedChunks++; - return chunk; + inflightRequestCount--; + lastReturnedChunkId = chunk.getLeft(); + return chunk.getRight(); } @Override @@ -185,7 +212,10 @@ public void close() { closed = true; } if (results.size() > 0) { - results.forEach(ReferenceCounted::release); + results.forEach( + chunk -> { + chunk.getRight().release(); + }); } results.clear(); closeStream(); @@ -210,14 +240,32 @@ public PartitionLocation getLocation() { return location; } + @Override + public Optional<PartitionReaderCheckpointMetadata> getPartitionReaderCheckpointMetadata() { + return partitionReaderCheckpointMetadata; + } + private void fetchChunks() throws IOException, InterruptedException { - final int inFlight = chunkIndex - startChunkIndex - returnedChunks; + final int inFlight = inflightRequestCount; if (inFlight < fetchMaxReqsInFlight) { - final int toFetch = - Math.min(fetchMaxReqsInFlight - inFlight + 1, endChunkIndex + 1 - chunkIndex); - for (int i = 0; i < toFetch; i++) { - if (testFetch && fetchChunkRetryCnt < fetchChunkMaxRetry - 1 && chunkIndex == 3) { + int toFetch = Math.min(fetchMaxReqsInFlight - inFlight + 1, endChunkIndex + 1 -
chunkIndex); + + while (toFetch > 0 && chunkIndex <= endChunkIndex) { + if (partitionReaderCheckpointMetadata.isPresent() + && partitionReaderCheckpointMetadata.get().isCheckpointed(chunkIndex)) { + logger.info( + "Skipping chunk {} as it has already been returned," + + " likely by a previous reader for the same partition.", + chunkIndex); + chunkIndex++; + // Important: since we skip fetching this chunk, we must not decrement toFetch here. + // E.g. if chunkIndex=1, toFetch=2, endChunkIndex=4 and chunkIdsAlreadyReturned={1,2}, + // then skipping chunks 1 and 2 while also decrementing toFetch would wrongly exit the loop + // without ever fetching chunks {3, 4}, and next() would end up waiting for chunks {3, 4} + // forever. + } else if (testFetch && fetchChunkRetryCnt < fetchChunkMaxRetry - 1 && chunkIndex == 3) { callback.onFailure(chunkIndex, new CelebornIOException("Test fetch chunk failure")); + toFetch--; } else { if (!client.isActive()) { try { @@ -236,7 +284,9 @@ private void fetchChunks() throws IOException, InterruptedException { } } client.fetchChunk(streamHandler.getStreamId(), chunkIndex, fetchTimeoutMs, callback); + inflightRequestCount++; chunkIndex++; + toFetch--; } } } diff --git a/client/src/main/java/org/apache/celeborn/client/read/checkpoint/PartitionReaderCheckpointMetadata.java b/client/src/main/java/org/apache/celeborn/client/read/checkpoint/PartitionReaderCheckpointMetadata.java new file mode 100644 index 00000000000..353fa9f8d79 --- /dev/null +++ b/client/src/main/java/org/apache/celeborn/client/read/checkpoint/PartitionReaderCheckpointMetadata.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.client.read.checkpoint; + +import java.util.HashSet; +import java.util.Set; + +/** Checkpoint metadata for a partition reader. */ +public class PartitionReaderCheckpointMetadata { + private final Set<Integer> returnedChunks; + + /** Create an instance of the checkpoint metadata.
*/ + public PartitionReaderCheckpointMetadata() { + this.returnedChunks = new HashSet<>(); + } + + public void checkpoint(int chunkId) { + returnedChunks.add(chunkId); + } + + public boolean isCheckpointed(int chunkId) { + return returnedChunks.contains(chunkId); + } + + public Set<Integer> getReturnedChunks() { + return returnedChunks; + } +} diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala index 9d6d2e445c7..703e0f33fa9 100644 --- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala +++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala @@ -982,6 +982,9 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se def clientFetchTimeoutMs: Long = get(CLIENT_FETCH_TIMEOUT) def clientFetchBufferSize: Int = get(CLIENT_FETCH_BUFFER_SIZE).toInt def clientFetchMaxReqsInFlight: Int = get(CLIENT_FETCH_MAX_REQS_IN_FLIGHT) + def isPartitionReaderCheckpointEnabled: Boolean = + get(PARTITION_READER_CHECKPOINT_ENABLED) + def clientFetchMaxRetriesForEachReplica: Int = get(CLIENT_FETCH_MAX_RETRIES_FOR_EACH_REPLICA) def clientStageRerunEnabled: Boolean = get(CLIENT_STAGE_RERUN_ENABLED) def clientFetchExcludeWorkerOnFailureEnabled: Boolean = @@ -4691,6 +4694,15 @@ object CelebornConf extends Logging { .bytesConf(ByteUnit.BYTE) .createWithDefaultString("64k") + val PARTITION_READER_CHECKPOINT_ENABLED: ConfigEntry[Boolean] = + buildConf("celeborn.client.partition.reader.checkpoint.enabled") + .categories("client") + .version("0.6.0") + .doc("Whether to checkpoint reads when re-creating a partition reader. Setting this to true minimizes" + + " the amount of unnecessary reads during partition read retries") + .booleanConf + .createWithDefault(false) + val CLIENT_FETCH_MAX_REQS_IN_FLIGHT: ConfigEntry[Int] = buildConf("celeborn.client.fetch.maxReqsInFlight") .withAlternative("celeborn.fetch.maxReqsInFlight") diff --git a/docs/configuration/client.md b/docs/configuration/client.md index d9e1a700c77..c572308d8d8 100644 --- a/docs/configuration/client.md +++ b/docs/configuration/client.md @@ -45,6 +45,7 @@ license: | | celeborn.client.flink.shuffle.fallback.policy | AUTO | false | Celeborn supports the following kind of fallback policies. 1. ALWAYS: always use flink built-in shuffle implementation; 2. AUTO: prefer to use celeborn shuffle implementation, and fallback to use flink built-in shuffle implementation based on certain factors, e.g. availability of enough workers and quota; 3. NEVER: always use celeborn shuffle implementation, and fail fast when it it is concluded that fallback is required based on factors above. | 0.6.0 | | | celeborn.client.inputStream.creation.window | 16 | false | Window size that CelebornShuffleReader pre-creates CelebornInputStreams, for coalesced scenario where multiple Partitions are read | 0.5.1 | | | celeborn.client.mr.pushData.max | 32m | false | Max size for a push data sent from mr client. | 0.4.0 | | +| celeborn.client.partition.reader.checkpoint.enabled | false | false | Whether to checkpoint reads when re-creating a partition reader. Setting this to true minimizes the amount of unnecessary reads during partition read retries | 0.6.0 | | | celeborn.client.push.buffer.initial.size | 8k | false | | 0.3.0 | celeborn.push.buffer.initial.size | | celeborn.client.push.buffer.max.size | 64k | false | Max size of reducer partition buffer memory for shuffle hash writer.
The pushed data will be buffered in memory before sending to Celeborn worker. For performance consideration keep this buffer size higher than 32K. Example: If reducer amount is 2000, buffer size is 64K, then each task will consume up to `64KiB * 2000 = 125MiB` heap memory. | 0.3.0 | celeborn.push.buffer.max.size | | celeborn.client.push.excludeWorkerOnFailure.enabled | false | false | Whether to enable shuffle client-side push exclude workers on failures. | 0.3.0 | | diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestWithFailures.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestWithFailures.scala new file mode 100644 index 00000000000..17c21d0a683 --- /dev/null +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestWithFailures.scala @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.service.deploy.cluster + +import java.io.ByteArrayOutputStream +import java.nio.charset.StandardCharsets +import java.util +import java.util.UUID +import java.util.concurrent.atomic.AtomicLong + +import org.junit.Assert +import org.scalatest.BeforeAndAfterAll +import org.scalatest.funsuite.AnyFunSuite + +import org.apache.celeborn.client.{LifecycleManager, ShuffleClientImpl} +import org.apache.celeborn.client.read.MetricsCallback +import org.apache.celeborn.common.CelebornConf +import org.apache.celeborn.common.identity.UserIdentifier +import org.apache.celeborn.common.internal.Logging +import org.apache.celeborn.common.protocol.CompressionCodec +import org.apache.celeborn.service.deploy.MiniClusterFeature + +class ReadWriteTestWithFailures extends AnyFunSuite + with Logging with MiniClusterFeature with BeforeAndAfterAll { + + var masterPort = 0 + + override def beforeAll(): Unit = { + logInfo("test initialized, setup Celeborn mini cluster") + val (m, _) = setupMiniClusterWithRandomPorts(workerConf = + Map("celeborn.shuffle.chunk.size" -> "100B", "celeborn.worker.flusher.buffer.size" -> "10B")) + masterPort = m.conf.masterPort + } + + override def afterAll(): Unit = { + logInfo("all tests complete, stop Celeborn mini cluster") + shutdownMiniCluster() + } + + test(s"test MiniCluster with connection resets, ensure no duplicate reads") { + Assert.assertEquals(performTest("true"), 0) + } + + test(s"test MiniCluster with connection resets, assert duplicate reads") { + Assert.assertTrue(performTest("false") > 0) + } + + def performTest(workerChunkLevelCheckpointEnabled: String): Long = { + val APP = "app-1" + + val clientConf = new CelebornConf() + .set(CelebornConf.MASTER_ENDPOINTS.key, s"localhost:$masterPort") + .set(CelebornConf.CLIENT_PUSH_REPLICATE_ENABLED.key, "true") +
.set(CelebornConf.CLIENT_PUSH_BUFFER_MAX_SIZE.key, "256K") + .set("celeborn.data.io.numConnectionsPerPeer", "1") + .set("celeborn.client.fetch.maxReqsInFlight", "1") + .set("celeborn.client.shuffle.compression.codec", CompressionCodec.NONE.toString) + .set(CelebornConf.TEST_CLIENT_FETCH_FAILURE.key, "true") + .set( + CelebornConf.PARTITION_READER_CHECKPOINT_ENABLED.key, + workerChunkLevelCheckpointEnabled) + .set(CelebornConf.CLIENT_PUSH_REPLICATE_ENABLED.key, "false") + + val lifecycleManager = new LifecycleManager(APP, clientConf) + val shuffleClient = new ShuffleClientImpl(APP, clientConf, UserIdentifier("mock", "mock")) + shuffleClient.setupLifecycleManagerRef(lifecycleManager.self) + + // push 100 random strings + val numuuids = 100 + val stringSet = new util.HashSet[String]() + for (i <- 0 until numuuids) { + val str = UUID.randomUUID().toString + stringSet.add(str) + val data = ("_" + str).getBytes(StandardCharsets.UTF_8) + shuffleClient.pushData(1, 0, 0, 0, data, 0, data.length, 1, 1) + } + shuffleClient.pushMergedData(1, 0, 0) + Thread.sleep(1000) + + shuffleClient.mapperEnd(1, 0, 0, 1) + + val duplicateBytesRead = new AtomicLong(0) + val metricsCallback = new MetricsCallback { + override def incBytesRead(bytesWritten: Long): Unit = {} + override def incReadTime(time: Long): Unit = {} + override def incDuplicateBytesRead(bytesRead: Long): Unit = { + duplicateBytesRead.addAndGet(bytesRead) + } + } + val inputStream = shuffleClient.readPartition( + 1, + 1, + 0, + 0, + 0, + 0, + Integer.MAX_VALUE, + null, + null, + null, + null, + null, + null, + metricsCallback) + + val outputStream = new ByteArrayOutputStream() + var b = inputStream.read() + while (b != -1) { + outputStream.write(b) + b = inputStream.read() + } + + val readStrings = + new String(outputStream.toByteArray, StandardCharsets.UTF_8).substring(1).split("_") + Assert.assertEquals(readStrings.length, numuuids) + readStrings.foreach { str => + Assert.assertTrue(stringSet.contains(str)) + } + + Thread.sleep(5000L) + shuffleClient.shutdown() + lifecycleManager.rpcEnv.shutdown() + + duplicateBytesRead.get() + } +}
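Reviewer note: the sketch below is illustrative only and is not part of the patch. It shows how a re-created reader can consult the new PartitionReaderCheckpointMetadata to skip chunks that a previous reader for the same partition already returned. Only the constructor, checkpoint(int) and isCheckpointed(int) come from the class added above; the surrounding driver loop and the class name CheckpointResumeSketch are hypothetical.

// Minimal, self-contained sketch of the chunk-skip behavior enabled by
// celeborn.client.partition.reader.checkpoint.enabled (assumptions noted above).
import java.util.Optional;

import org.apache.celeborn.client.read.checkpoint.PartitionReaderCheckpointMetadata;

public class CheckpointResumeSketch {
  public static void main(String[] args) {
    // Chunks 0 and 1 were handed to the caller before a fetch failure occurred.
    PartitionReaderCheckpointMetadata checkpoint = new PartitionReaderCheckpointMetadata();
    checkpoint.checkpoint(0);
    checkpoint.checkpoint(1);

    // A retried reader for the same location receives the metadata and skips those chunks,
    // mirroring the skip loops added to WorkerPartitionReader and DfsPartitionReader.
    Optional<PartitionReaderCheckpointMetadata> resumed = Optional.of(checkpoint);
    int startChunkIndex = 0;
    int endChunkIndex = 4;
    for (int chunkIndex = startChunkIndex; chunkIndex <= endChunkIndex; chunkIndex++) {
      if (resumed.isPresent() && resumed.get().isCheckpointed(chunkIndex)) {
        System.out.println("skip chunk " + chunkIndex); // already returned, no duplicate read
      } else {
        System.out.println("fetch chunk " + chunkIndex); // the real readers issue a fetch here
      }
    }
  }
}

With the flag enabled, CelebornInputStream#getNextChunk passes currentReader.getPartitionReaderCheckpointMetadata() to the retried reader for the same location, which is the hand-off modeled above; duplicate reads that still occur are surfaced through MetricsCallback#incDuplicateBytesRead.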