From 46726552890d61174dcbacb367d5ccc6dfea4327 Mon Sep 17 00:00:00 2001 From: Saurabh Dubey Date: Sun, 2 Mar 2025 12:43:46 +0530 Subject: [PATCH 01/14] Allow skipping already read chunks to improve read performance in unreplicated scenario --- .../celeborn/CelebornShuffleReader.scala | 2 + .../client/read/CelebornInputStream.java | 22 +++- .../celeborn/client/read/MetricsCallback.java | 2 + .../celeborn/client/read/PartitionReader.java | 9 +- .../client/read/WorkerPartitionReader.java | 69 +++++++---- .../client/WithShuffleClientSuite.scala | 1 + .../apache/celeborn/common/CelebornConf.scala | 10 ++ .../deploy/cluster/ReadWriteTestBase.scala | 1 + .../cluster/ReadWriteTestWithFailures.scala | 112 ++++++++++++++++++ 9 files changed, 204 insertions(+), 24 deletions(-) create mode 100644 worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestWithFailures.scala diff --git a/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala index b35e549ae38..d84ea52761c 100644 --- a/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala +++ b/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala @@ -91,6 +91,8 @@ class CelebornShuffleReader[K, C]( override def incReadTime(time: Long): Unit = metrics.incFetchWaitTime(time) + + override def incDuplicateBytesRead(bytesRead: Long): Unit = {} } if (streamCreatorPool == null) { diff --git a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java index 0b9434a3b1b..172a49c61bd 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java +++ b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java @@ -24,6 +24,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.LongAdder; +import org.apache.celeborn.client.read.checkpoint.PartitionReaderCheckpointMetadata; import scala.Tuple2; import com.github.luben.zstd.ZstdException; @@ -415,16 +416,27 @@ private boolean isExcluded(PartitionLocation location) { } } + private PartitionReader createReaderWithRetry( PartitionLocation location, PbStreamHandler pbStreamHandler) + throws IOException { + return createReaderWithRetry(location, pbStreamHandler, Optional.empty()); + } + private PartitionReader createReaderWithRetry( - PartitionLocation location, PbStreamHandler pbStreamHandler) throws IOException { + PartitionLocation location, PbStreamHandler pbStreamHandler, + Optional checkpointMetadata) throws IOException { Exception lastException = null; + PartitionReader reader = null; while (fetchChunkRetryCnt < fetchChunkMaxRetry) { try { logger.debug("Create reader for location {}", location); if (isExcluded(location)) { throw new CelebornIOException("Fetch data from excluded worker! " + location); } - return createReader(location, pbStreamHandler, fetchChunkRetryCnt, fetchChunkMaxRetry); + reader = createReader(location, pbStreamHandler, fetchChunkRetryCnt, fetchChunkMaxRetry); + if (checkpointMetadata.isPresent()) { + reader.updateCheckpointMetadata(checkpointMetadata.get()); + } + return reader; } catch (Exception e) { lastException = e; shuffleClient.excludeFailedFetchLocation(location.hostAndFetchPort(), e); @@ -512,6 +524,7 @@ private ByteBuf getNextChunk() throws IOException { if (fetchChunkRetryCnt % 2 == 0) { Uninterruptibles.sleepUninterruptibly(retryWaitMs, TimeUnit.MILLISECONDS); } + // We must not use checkpoint for peer location since chunkIds don't always match across peers currentReader = createReaderWithRetry(currentReader.getLocation().getPeer(), null); } else { logger.warn( @@ -521,7 +534,9 @@ private ByteBuf getNextChunk() throws IOException { currentReader.getLocation(), e); Uninterruptibles.sleepUninterruptibly(retryWaitMs, TimeUnit.MILLISECONDS); - currentReader = createReaderWithRetry(currentReader.getLocation(), null); + // When reading from the same host again, it is possible to skip already read data chunks, + // improving read performance during retries. + currentReader = createReaderWithRetry(currentReader.getLocation(), null, Optional.ofNullable(currentReader.getPartitionReaderCheckpointMetadata())); } } } @@ -789,6 +804,7 @@ private boolean fillBuffer() throws IOException { hasData = true; break; } else { + callback.incDuplicateBytesRead(BATCH_HEADER_SIZE + size); logger.debug( "Skip duplicated batch: mapId {}, attemptId {}, batchId {}.", mapId, diff --git a/client/src/main/java/org/apache/celeborn/client/read/MetricsCallback.java b/client/src/main/java/org/apache/celeborn/client/read/MetricsCallback.java index aae07ff2348..f79ea4038a2 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/MetricsCallback.java +++ b/client/src/main/java/org/apache/celeborn/client/read/MetricsCallback.java @@ -21,4 +21,6 @@ public interface MetricsCallback { void incBytesRead(long bytesRead); void incReadTime(long time); + + void incDuplicateBytesRead(long bytesRead); } diff --git a/client/src/main/java/org/apache/celeborn/client/read/PartitionReader.java b/client/src/main/java/org/apache/celeborn/client/read/PartitionReader.java index cf8009d9052..9b7ff752961 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/PartitionReader.java +++ b/client/src/main/java/org/apache/celeborn/client/read/PartitionReader.java @@ -21,9 +21,10 @@ import io.netty.buffer.ByteBuf; +import org.apache.celeborn.client.read.checkpoint.PartitionReaderCheckpointMetadata; import org.apache.celeborn.common.protocol.PartitionLocation; -public interface PartitionReader { +public interface PartitionReader { boolean hasNext(); ByteBuf next() throws IOException, InterruptedException; @@ -31,4 +32,10 @@ public interface PartitionReader { void close(); PartitionLocation getLocation(); + + default T getPartitionReaderCheckpointMetadata() { + return null; + } + + default void updateCheckpointMetadata(T checkpointMetadata) {} } diff --git a/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java b/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java index 29236273663..041f8ee55a3 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java +++ b/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java @@ -17,18 +17,9 @@ package org.apache.celeborn.client.read; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.concurrent.LinkedBlockingQueue; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReference; - import io.netty.buffer.ByteBuf; -import io.netty.util.ReferenceCounted; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import org.apache.celeborn.client.ShuffleClient; +import org.apache.celeborn.client.read.checkpoint.WorkerPartitionReaderCheckpointMetadata; import org.apache.celeborn.common.CelebornConf; import org.apache.celeborn.common.exception.CelebornIOException; import org.apache.celeborn.common.network.buffer.ManagedBuffer; @@ -44,8 +35,19 @@ import org.apache.celeborn.common.protocol.PbStreamHandler; import org.apache.celeborn.common.protocol.StreamType; import org.apache.celeborn.common.util.ExceptionUtils; +import org.apache.commons.lang3.tuple.Pair; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.HashSet; +import java.util.Set; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; -public class WorkerPartitionReader implements PartitionReader { +public class WorkerPartitionReader implements PartitionReader { private final Logger logger = LoggerFactory.getLogger(WorkerPartitionReader.class); private PartitionLocation location; private final TransportClientFactory clientFactory; @@ -58,7 +60,7 @@ public class WorkerPartitionReader implements PartitionReader { private int startChunkIndex; private int endChunkIndex; - private final LinkedBlockingQueue results; + private final LinkedBlockingQueue> results; private final ChunkReceivedCallback callback; private final AtomicReference exception = new AtomicReference<>(); @@ -72,6 +74,10 @@ public class WorkerPartitionReader implements PartitionReader { private int fetchChunkMaxRetry; private final boolean testFetch; + // checkpoints + private final boolean isCheckpointEnabled; + private Set chunkIdsAlreadyReturned; + WorkerPartitionReader( CelebornConf conf, String shuffleKey, @@ -101,7 +107,7 @@ public void onSuccess(int chunkIndex, ManagedBuffer buffer) { ByteBuf buf = ((NettyManagedBuffer) buffer).getBuf(); if (!closed) { buf.retain(); - results.add(buf); + results.add(Pair.of(chunkIndex, buf)); } } } @@ -147,13 +153,15 @@ public void onFailure(int chunkIndex, Throwable e) { this.clientFactory = clientFactory; this.fetchChunkRetryCnt = fetchChunkRetryCnt; this.fetchChunkMaxRetry = fetchChunkMaxRetry; + this.chunkIdsAlreadyReturned = new HashSet<>(); + this.isCheckpointEnabled = conf.isWorkerPartitionReaderCheckpointEnabled(); testFetch = conf.testFetchFailure(); ShuffleClient.incrementTotalReadCounter(); } @Override public boolean hasNext() { - return returnedChunks < endChunkIndex - startChunkIndex + 1; + return chunkIdsAlreadyReturned.size() < endChunkIndex - startChunkIndex + 1; } @Override @@ -162,7 +170,7 @@ public ByteBuf next() throws IOException, InterruptedException { if (chunkIndex <= endChunkIndex) { fetchChunks(); } - ByteBuf chunk = null; + Pair chunk = null; try { while (chunk == null) { checkException(); @@ -176,7 +184,8 @@ public ByteBuf next() throws IOException, InterruptedException { throw e; } returnedChunks++; - return chunk; + chunkIdsAlreadyReturned.add(chunk.getLeft()); + return chunk.getRight(); } @Override @@ -185,7 +194,9 @@ public void close() { closed = true; } if (results.size() > 0) { - results.forEach(ReferenceCounted::release); + results.forEach(chunk -> { + chunk.getRight().release(); // + }); } results.clear(); closeStream(); @@ -210,14 +221,31 @@ public PartitionLocation getLocation() { return location; } + @Override + public WorkerPartitionReaderCheckpointMetadata getPartitionReaderCheckpointMetadata() { + return isCheckpointEnabled ? new WorkerPartitionReaderCheckpointMetadata(chunkIdsAlreadyReturned) : null; + } + + @Override + public void updateCheckpointMetadata(WorkerPartitionReaderCheckpointMetadata checkpointMetadata) { + chunkIdsAlreadyReturned = checkpointMetadata.getReturnedChunks(); + } + private void fetchChunks() throws IOException, InterruptedException { final int inFlight = chunkIndex - startChunkIndex - returnedChunks; if (inFlight < fetchMaxReqsInFlight) { - final int toFetch = + int toFetch = Math.min(fetchMaxReqsInFlight - inFlight + 1, endChunkIndex + 1 - chunkIndex); - for (int i = 0; i < toFetch; i++) { - if (testFetch && fetchChunkRetryCnt < fetchChunkMaxRetry - 1 && chunkIndex == 3) { + + while (toFetch > 0 && chunkIndex <= endChunkIndex) { + if (chunkIdsAlreadyReturned.contains(chunkIndex)) { + logger.info("Skipping chunk {} as it has already been returned," + + " likely by a previous reader for the same partition.", chunkIndex); + chunkIndex++; + returnedChunks++; + } else if (testFetch && fetchChunkRetryCnt < fetchChunkMaxRetry - 1 && chunkIndex == 3) { callback.onFailure(chunkIndex, new CelebornIOException("Test fetch chunk failure")); + toFetch--; } else { if (!client.isActive()) { try { @@ -237,6 +265,7 @@ private void fetchChunks() throws IOException, InterruptedException { } client.fetchChunk(streamHandler.getStreamId(), chunkIndex, fetchTimeoutMs, callback); chunkIndex++; + toFetch--; } } } diff --git a/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala b/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala index 0570760ce14..935faf922d2 100644 --- a/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala +++ b/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala @@ -149,6 +149,7 @@ trait WithShuffleClientSuite extends CelebornFunSuite { val metricsCallback = new MetricsCallback { override def incBytesRead(bytesWritten: Long): Unit = {} override def incReadTime(time: Long): Unit = {} + override def incDuplicateBytesRead(bytesRead: Long): Unit = {} } // reduce normal empty CelebornInputStream diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala index 9d6d2e445c7..159724a5ac5 100644 --- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala +++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala @@ -982,6 +982,7 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se def clientFetchTimeoutMs: Long = get(CLIENT_FETCH_TIMEOUT) def clientFetchBufferSize: Int = get(CLIENT_FETCH_BUFFER_SIZE).toInt def clientFetchMaxReqsInFlight: Int = get(CLIENT_FETCH_MAX_REQS_IN_FLIGHT) + def isWorkerPartitionReaderCheckpointEnabled: Boolean = get(WORKER_PARTITION_READER_CHECKPOINT_ENABLE) def clientFetchMaxRetriesForEachReplica: Int = get(CLIENT_FETCH_MAX_RETRIES_FOR_EACH_REPLICA) def clientStageRerunEnabled: Boolean = get(CLIENT_STAGE_RERUN_ENABLED) def clientFetchExcludeWorkerOnFailureEnabled: Boolean = @@ -4691,6 +4692,15 @@ object CelebornConf extends Logging { .bytesConf(ByteUnit.BYTE) .createWithDefaultString("64k") + val WORKER_PARTITION_READER_CHECKPOINT_ENABLE: ConfigEntry[Boolean] = + buildConf("celeborn.worker.partition.reader.checkpointEnabled") + .categories("client") + .version("0.5.0") + .doc("Whether or not checkpoint reads when re-creating a worker reader. Setting to true minimizes" + + " the amount of unnecessary reads during client read retries") + .booleanConf + .createWithDefault(false) + val CLIENT_FETCH_MAX_REQS_IN_FLIGHT: ConfigEntry[Int] = buildConf("celeborn.client.fetch.maxReqsInFlight") .withAlternative("celeborn.fetch.maxReqsInFlight") diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala index e81593e1d99..61a2e895b7f 100644 --- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala @@ -103,6 +103,7 @@ trait ReadWriteTestBase extends AnyFunSuite val metricsCallback = new MetricsCallback { override def incBytesRead(bytesWritten: Long): Unit = {} override def incReadTime(time: Long): Unit = {} + override def incDuplicateBytesRead(bytesRead: Long): Unit = {} } val inputStream = shuffleClient.readPartition( 1, diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestWithFailures.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestWithFailures.scala new file mode 100644 index 00000000000..32815378231 --- /dev/null +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestWithFailures.scala @@ -0,0 +1,112 @@ +package org.apache.celeborn.service.deploy.cluster + +import org.apache.celeborn.client.read.MetricsCallback +import org.apache.celeborn.client.{LifecycleManager, ShuffleClientImpl} +import org.apache.celeborn.common.CelebornConf +import org.apache.celeborn.common.identity.UserIdentifier +import org.apache.celeborn.common.internal.Logging +import org.apache.celeborn.common.protocol.CompressionCodec +import org.apache.celeborn.service.deploy.MiniClusterFeature +import org.junit.Assert +import org.scalatest.BeforeAndAfterAll +import org.scalatest.funsuite.AnyFunSuite + +import java.io.ByteArrayOutputStream +import java.nio.charset.StandardCharsets +import java.util +import java.util.UUID +import java.util.concurrent.atomic.AtomicLong + +class ReadWriteTestWithFailures extends AnyFunSuite + with Logging with MiniClusterFeature with BeforeAndAfterAll { + + var masterPort = 0 + + override def beforeAll(): Unit = { + logInfo("test initialized , setup Celeborn mini cluster") + val (m, _) = setupMiniClusterWithRandomPorts(workerConf = Map("celeborn.shuffle.chunk.size" -> "100B", + "celeborn.worker.flusher.buffer.size" -> "10B")) + masterPort = m.conf.masterPort + } + + override def afterAll(): Unit = { + logInfo("all test complete , stop Celeborn mini cluster") + shutdownMiniCluster() + } + + test(s"test MiniCluster with connection resets, ensure no duplicate reads") { + val APP = "app-1" + + val clientConf = new CelebornConf() + .set(CelebornConf.MASTER_ENDPOINTS.key, s"localhost:$masterPort") + .set(CelebornConf.CLIENT_PUSH_REPLICATE_ENABLED.key, "true") + .set(CelebornConf.CLIENT_PUSH_BUFFER_MAX_SIZE.key, "256K") + .set("celeborn.data.io.numConnectionsPerPeer", "1") + .set("celeborn.client.fetch.maxReqsInFlight", "1") + .set("celeborn.client.shuffle.compression.codec", CompressionCodec.NONE.toString) + .set(CelebornConf.TEST_CLIENT_FETCH_FAILURE.key, "true") + .set(CelebornConf.WORKER_PARTITION_READER_CHECKPOINT_ENABLE.key, "true") + .set(CelebornConf.CLIENT_PUSH_REPLICATE_ENABLED.key, "false") + + val lifecycleManager = new LifecycleManager(APP, clientConf) + val shuffleClient = new ShuffleClientImpl(APP, clientConf, UserIdentifier("mock", "mock")) + shuffleClient.setupLifecycleManagerRef(lifecycleManager.self) + + // push 100 random strings + val numuuids = 100 + val stringSet = new util.HashSet[String]() + for (i <- 0 until numuuids) { + val str = UUID.randomUUID().toString + stringSet.add(str) + val data = ("_" + str).getBytes(StandardCharsets.UTF_8) + shuffleClient.pushData(1, 0, 0, 0, data, 0, data.length, 1, 1) + } + shuffleClient.pushMergedData(1, 0, 0) + Thread.sleep(1000) + + shuffleClient.mapperEnd(1, 0, 0, 1) + + var duplicateBytesRead = new AtomicLong(0) + val metricsCallback = new MetricsCallback { + override def incBytesRead(bytesWritten: Long): Unit = {} + override def incReadTime(time: Long): Unit = {} + override def incDuplicateBytesRead(bytesRead: Long): Unit = { + duplicateBytesRead.addAndGet(bytesRead) + } + } + val inputStream = shuffleClient.readPartition( + 1, + 1, + 0, + 0, + 0, + 0, + Integer.MAX_VALUE, + null, + null, + null, + null, + null, + null, + metricsCallback) + + val outputStream = new ByteArrayOutputStream() + var b = inputStream.read() + while (b != -1) { + outputStream.write(b) + b = inputStream.read() + } + + val readStrings = new String(outputStream.toByteArray, StandardCharsets.UTF_8).substring(1).split("_") + Assert.assertEquals(readStrings.length, numuuids) + readStrings.foreach { str => + Assert.assertTrue(stringSet.contains(str)) + } + + // Assert no duplicate chunks read despite chunk fetch retries + Assert.assertEquals(duplicateBytesRead.get(), 0) + Thread.sleep(5000L) + shuffleClient.shutdown() + lifecycleManager.rpcEnv.shutdown() + } +} From a95dbb3c6b65e7cb68d74ef35d56f9dc99d3f9ab Mon Sep 17 00:00:00 2001 From: Saurabh Dubey Date: Tue, 4 Mar 2025 15:55:18 +0530 Subject: [PATCH 02/14] Lint fix --- .../client/read/CelebornInputStream.java | 24 ++++++---- .../celeborn/client/read/PartitionReader.java | 2 +- .../client/read/WorkerPartitionReader.java | 46 +++++++++++-------- .../apache/celeborn/common/CelebornConf.scala | 3 +- .../cluster/ReadWriteTestWithFailures.scala | 43 +++++++++-------- 5 files changed, 68 insertions(+), 50 deletions(-) diff --git a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java index 172a49c61bd..0573dbbba57 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java +++ b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java @@ -24,7 +24,6 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.LongAdder; -import org.apache.celeborn.client.read.checkpoint.PartitionReaderCheckpointMetadata; import scala.Tuple2; import com.github.luben.zstd.ZstdException; @@ -39,6 +38,7 @@ import org.apache.celeborn.client.ClientUtils; import org.apache.celeborn.client.ShuffleClient; import org.apache.celeborn.client.compress.Decompressor; +import org.apache.celeborn.client.read.checkpoint.PartitionReaderCheckpointMetadata; import org.apache.celeborn.common.CelebornConf; import org.apache.celeborn.common.exception.CelebornIOException; import org.apache.celeborn.common.network.client.TransportClient; @@ -416,14 +416,16 @@ private boolean isExcluded(PartitionLocation location) { } } - private PartitionReader createReaderWithRetry( PartitionLocation location, PbStreamHandler pbStreamHandler) - throws IOException { + private PartitionReader createReaderWithRetry( + PartitionLocation location, PbStreamHandler pbStreamHandler) throws IOException { return createReaderWithRetry(location, pbStreamHandler, Optional.empty()); } private PartitionReader createReaderWithRetry( - PartitionLocation location, PbStreamHandler pbStreamHandler, - Optional checkpointMetadata) throws IOException { + PartitionLocation location, + PbStreamHandler pbStreamHandler, + Optional checkpointMetadata) + throws IOException { Exception lastException = null; PartitionReader reader = null; while (fetchChunkRetryCnt < fetchChunkMaxRetry) { @@ -524,7 +526,8 @@ private ByteBuf getNextChunk() throws IOException { if (fetchChunkRetryCnt % 2 == 0) { Uninterruptibles.sleepUninterruptibly(retryWaitMs, TimeUnit.MILLISECONDS); } - // We must not use checkpoint for peer location since chunkIds don't always match across peers + // We must not use checkpoint for peer location since chunkIds don't always match + // across peers currentReader = createReaderWithRetry(currentReader.getLocation().getPeer(), null); } else { logger.warn( @@ -534,9 +537,14 @@ private ByteBuf getNextChunk() throws IOException { currentReader.getLocation(), e); Uninterruptibles.sleepUninterruptibly(retryWaitMs, TimeUnit.MILLISECONDS); - // When reading from the same host again, it is possible to skip already read data chunks, + // When reading from the same host again, it is possible to skip already read data + // chunks, // improving read performance during retries. - currentReader = createReaderWithRetry(currentReader.getLocation(), null, Optional.ofNullable(currentReader.getPartitionReaderCheckpointMetadata())); + currentReader = + createReaderWithRetry( + currentReader.getLocation(), + null, + Optional.ofNullable(currentReader.getPartitionReaderCheckpointMetadata())); } } } diff --git a/client/src/main/java/org/apache/celeborn/client/read/PartitionReader.java b/client/src/main/java/org/apache/celeborn/client/read/PartitionReader.java index 9b7ff752961..42d5d02459b 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/PartitionReader.java +++ b/client/src/main/java/org/apache/celeborn/client/read/PartitionReader.java @@ -24,7 +24,7 @@ import org.apache.celeborn.client.read.checkpoint.PartitionReaderCheckpointMetadata; import org.apache.celeborn.common.protocol.PartitionLocation; -public interface PartitionReader { +public interface PartitionReader { boolean hasNext(); ByteBuf next() throws IOException, InterruptedException; diff --git a/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java b/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java index 041f8ee55a3..91b9095b022 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java +++ b/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java @@ -17,7 +17,19 @@ package org.apache.celeborn.client.read; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.HashSet; +import java.util.Set; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + import io.netty.buffer.ByteBuf; +import org.apache.commons.lang3.tuple.Pair; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import org.apache.celeborn.client.ShuffleClient; import org.apache.celeborn.client.read.checkpoint.WorkerPartitionReaderCheckpointMetadata; import org.apache.celeborn.common.CelebornConf; @@ -35,19 +47,9 @@ import org.apache.celeborn.common.protocol.PbStreamHandler; import org.apache.celeborn.common.protocol.StreamType; import org.apache.celeborn.common.util.ExceptionUtils; -import org.apache.commons.lang3.tuple.Pair; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.HashSet; -import java.util.Set; -import java.util.concurrent.LinkedBlockingQueue; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReference; -public class WorkerPartitionReader implements PartitionReader { +public class WorkerPartitionReader + implements PartitionReader { private final Logger logger = LoggerFactory.getLogger(WorkerPartitionReader.class); private PartitionLocation location; private final TransportClientFactory clientFactory; @@ -194,9 +196,10 @@ public void close() { closed = true; } if (results.size() > 0) { - results.forEach(chunk -> { - chunk.getRight().release(); // - }); + results.forEach( + chunk -> { + chunk.getRight().release(); // + }); } results.clear(); closeStream(); @@ -223,7 +226,9 @@ public PartitionLocation getLocation() { @Override public WorkerPartitionReaderCheckpointMetadata getPartitionReaderCheckpointMetadata() { - return isCheckpointEnabled ? new WorkerPartitionReaderCheckpointMetadata(chunkIdsAlreadyReturned) : null; + return isCheckpointEnabled + ? new WorkerPartitionReaderCheckpointMetadata(chunkIdsAlreadyReturned) + : null; } @Override @@ -234,13 +239,14 @@ public void updateCheckpointMetadata(WorkerPartitionReaderCheckpointMetadata che private void fetchChunks() throws IOException, InterruptedException { final int inFlight = chunkIndex - startChunkIndex - returnedChunks; if (inFlight < fetchMaxReqsInFlight) { - int toFetch = - Math.min(fetchMaxReqsInFlight - inFlight + 1, endChunkIndex + 1 - chunkIndex); + int toFetch = Math.min(fetchMaxReqsInFlight - inFlight + 1, endChunkIndex + 1 - chunkIndex); while (toFetch > 0 && chunkIndex <= endChunkIndex) { if (chunkIdsAlreadyReturned.contains(chunkIndex)) { - logger.info("Skipping chunk {} as it has already been returned," + - " likely by a previous reader for the same partition.", chunkIndex); + logger.info( + "Skipping chunk {} as it has already been returned," + + " likely by a previous reader for the same partition.", + chunkIndex); chunkIndex++; returnedChunks++; } else if (testFetch && fetchChunkRetryCnt < fetchChunkMaxRetry - 1 && chunkIndex == 3) { diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala index 159724a5ac5..7f88622eacb 100644 --- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala +++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala @@ -982,7 +982,8 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se def clientFetchTimeoutMs: Long = get(CLIENT_FETCH_TIMEOUT) def clientFetchBufferSize: Int = get(CLIENT_FETCH_BUFFER_SIZE).toInt def clientFetchMaxReqsInFlight: Int = get(CLIENT_FETCH_MAX_REQS_IN_FLIGHT) - def isWorkerPartitionReaderCheckpointEnabled: Boolean = get(WORKER_PARTITION_READER_CHECKPOINT_ENABLE) + def isWorkerPartitionReaderCheckpointEnabled: Boolean = + get(WORKER_PARTITION_READER_CHECKPOINT_ENABLE) def clientFetchMaxRetriesForEachReplica: Int = get(CLIENT_FETCH_MAX_RETRIES_FOR_EACH_REPLICA) def clientStageRerunEnabled: Boolean = get(CLIENT_STAGE_RERUN_ENABLED) def clientFetchExcludeWorkerOnFailureEnabled: Boolean = diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestWithFailures.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestWithFailures.scala index 32815378231..85a5f16bf10 100644 --- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestWithFailures.scala +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestWithFailures.scala @@ -1,21 +1,22 @@ package org.apache.celeborn.service.deploy.cluster -import org.apache.celeborn.client.read.MetricsCallback +import java.io.ByteArrayOutputStream +import java.nio.charset.StandardCharsets +import java.util +import java.util.UUID +import java.util.concurrent.atomic.AtomicLong + +import org.junit.Assert +import org.scalatest.BeforeAndAfterAll +import org.scalatest.funsuite.AnyFunSuite + import org.apache.celeborn.client.{LifecycleManager, ShuffleClientImpl} +import org.apache.celeborn.client.read.MetricsCallback import org.apache.celeborn.common.CelebornConf import org.apache.celeborn.common.identity.UserIdentifier import org.apache.celeborn.common.internal.Logging import org.apache.celeborn.common.protocol.CompressionCodec import org.apache.celeborn.service.deploy.MiniClusterFeature -import org.junit.Assert -import org.scalatest.BeforeAndAfterAll -import org.scalatest.funsuite.AnyFunSuite - -import java.io.ByteArrayOutputStream -import java.nio.charset.StandardCharsets -import java.util -import java.util.UUID -import java.util.concurrent.atomic.AtomicLong class ReadWriteTestWithFailures extends AnyFunSuite with Logging with MiniClusterFeature with BeforeAndAfterAll { @@ -24,8 +25,8 @@ class ReadWriteTestWithFailures extends AnyFunSuite override def beforeAll(): Unit = { logInfo("test initialized , setup Celeborn mini cluster") - val (m, _) = setupMiniClusterWithRandomPorts(workerConf = Map("celeborn.shuffle.chunk.size" -> "100B", - "celeborn.worker.flusher.buffer.size" -> "10B")) + val (m, _) = setupMiniClusterWithRandomPorts(workerConf = + Map("celeborn.shuffle.chunk.size" -> "100B", "celeborn.worker.flusher.buffer.size" -> "10B")) masterPort = m.conf.masterPort } @@ -49,8 +50,9 @@ class ReadWriteTestWithFailures extends AnyFunSuite .set(CelebornConf.CLIENT_PUSH_REPLICATE_ENABLED.key, "false") val lifecycleManager = new LifecycleManager(APP, clientConf) - val shuffleClient = new ShuffleClientImpl(APP, clientConf, UserIdentifier("mock", "mock")) - shuffleClient.setupLifecycleManagerRef(lifecycleManager.self) + val shuffleClient1 = new ShuffleClientImpl(APP, clientConf, UserIdentifier("mock", "mock")) + val shuffleClient2 = new ShuffleClientImpl(APP, clientConf, UserIdentifier("mock", "mock")) + shuffleClient1.setupLifecycleManagerRef(lifecycleManager.self) // push 100 random strings val numuuids = 100 @@ -59,12 +61,12 @@ class ReadWriteTestWithFailures extends AnyFunSuite val str = UUID.randomUUID().toString stringSet.add(str) val data = ("_" + str).getBytes(StandardCharsets.UTF_8) - shuffleClient.pushData(1, 0, 0, 0, data, 0, data.length, 1, 1) + shuffleClient1.pushData(1, 0, 0, 0, data, 0, data.length, 1, 1) } - shuffleClient.pushMergedData(1, 0, 0) + shuffleClient1.pushMergedData(1, 0, 0) Thread.sleep(1000) - shuffleClient.mapperEnd(1, 0, 0, 1) + shuffleClient1.mapperEnd(1, 0, 0, 1) var duplicateBytesRead = new AtomicLong(0) val metricsCallback = new MetricsCallback { @@ -74,7 +76,7 @@ class ReadWriteTestWithFailures extends AnyFunSuite duplicateBytesRead.addAndGet(bytesRead) } } - val inputStream = shuffleClient.readPartition( + val inputStream = shuffleClient1.readPartition( 1, 1, 0, @@ -97,7 +99,8 @@ class ReadWriteTestWithFailures extends AnyFunSuite b = inputStream.read() } - val readStrings = new String(outputStream.toByteArray, StandardCharsets.UTF_8).substring(1).split("_") + val readStrings = + new String(outputStream.toByteArray, StandardCharsets.UTF_8).substring(1).split("_") Assert.assertEquals(readStrings.length, numuuids) readStrings.foreach { str => Assert.assertTrue(stringSet.contains(str)) @@ -106,7 +109,7 @@ class ReadWriteTestWithFailures extends AnyFunSuite // Assert no duplicate chunks read despite chunk fetch retries Assert.assertEquals(duplicateBytesRead.get(), 0) Thread.sleep(5000L) - shuffleClient.shutdown() + shuffleClient1.shutdown() lifecycleManager.rpcEnv.shutdown() } } From 3933c948e81028436b1b6398cee66f463711e3b8 Mon Sep 17 00:00:00 2001 From: Saurabh Dubey Date: Tue, 4 Mar 2025 16:09:14 +0530 Subject: [PATCH 03/14] Lint + negative test case --- .../apache/celeborn/common/CelebornConf.scala | 2 +- docs/configuration/client.md | 1 + .../cluster/ReadWriteTestWithFailures.scala | 29 ++++++++++++------- 3 files changed, 20 insertions(+), 12 deletions(-) diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala index 7f88622eacb..173d6cb8d45 100644 --- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala +++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala @@ -4696,7 +4696,7 @@ object CelebornConf extends Logging { val WORKER_PARTITION_READER_CHECKPOINT_ENABLE: ConfigEntry[Boolean] = buildConf("celeborn.worker.partition.reader.checkpointEnabled") .categories("client") - .version("0.5.0") + .version("0.6.0") .doc("Whether or not checkpoint reads when re-creating a worker reader. Setting to true minimizes" + " the amount of unnecessary reads during client read retries") .booleanConf diff --git a/docs/configuration/client.md b/docs/configuration/client.md index d9e1a700c77..76a53c29f4d 100644 --- a/docs/configuration/client.md +++ b/docs/configuration/client.md @@ -137,4 +137,5 @@ license: | | celeborn.storage.s3.endpoint.region | <undefined> | false | S3 endpoint for Celeborn to store shuffle data. | 0.6.0 | | | celeborn.storage.s3.secret.key | <undefined> | false | S3 secret key for Celeborn to store shuffle data. | 0.6.0 | | | celeborn.tags.tagsExpr | | true | Expression to filter workers by tags. The expression is a comma-separated list of tags. The expression is evaluated as a logical AND of all tags. For example, `prod,high-io` filters workers that have both the `prod` and `high-io` tags. | 0.6.0 | | +| celeborn.worker.partition.reader.checkpointEnabled | false | false | Whether or not checkpoint reads when re-creating a worker reader. Setting to true minimizes the amount of unnecessary reads during client read retries | 0.6.0 | | diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestWithFailures.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestWithFailures.scala index 85a5f16bf10..2920aa829cc 100644 --- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestWithFailures.scala +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestWithFailures.scala @@ -36,6 +36,14 @@ class ReadWriteTestWithFailures extends AnyFunSuite } test(s"test MiniCluster with connection resets, ensure no duplicate reads") { + Assert.assertEquals(performTest("true"), 0) + } + + test(s"test MiniCluster with connection resets, assert duplicate reads") { + Assert.assertTrue(performTest("false") > 0) + } + + def performTest(workerChunkLevelCheckpointEnabled: String): Long = { val APP = "app-1" val clientConf = new CelebornConf() @@ -46,13 +54,12 @@ class ReadWriteTestWithFailures extends AnyFunSuite .set("celeborn.client.fetch.maxReqsInFlight", "1") .set("celeborn.client.shuffle.compression.codec", CompressionCodec.NONE.toString) .set(CelebornConf.TEST_CLIENT_FETCH_FAILURE.key, "true") - .set(CelebornConf.WORKER_PARTITION_READER_CHECKPOINT_ENABLE.key, "true") + .set(CelebornConf.WORKER_PARTITION_READER_CHECKPOINT_ENABLE.key, workerChunkLevelCheckpointEnabled) .set(CelebornConf.CLIENT_PUSH_REPLICATE_ENABLED.key, "false") val lifecycleManager = new LifecycleManager(APP, clientConf) - val shuffleClient1 = new ShuffleClientImpl(APP, clientConf, UserIdentifier("mock", "mock")) - val shuffleClient2 = new ShuffleClientImpl(APP, clientConf, UserIdentifier("mock", "mock")) - shuffleClient1.setupLifecycleManagerRef(lifecycleManager.self) + val shuffleClient = new ShuffleClientImpl(APP, clientConf, UserIdentifier("mock", "mock")) + shuffleClient.setupLifecycleManagerRef(lifecycleManager.self) // push 100 random strings val numuuids = 100 @@ -61,12 +68,12 @@ class ReadWriteTestWithFailures extends AnyFunSuite val str = UUID.randomUUID().toString stringSet.add(str) val data = ("_" + str).getBytes(StandardCharsets.UTF_8) - shuffleClient1.pushData(1, 0, 0, 0, data, 0, data.length, 1, 1) + shuffleClient.pushData(1, 0, 0, 0, data, 0, data.length, 1, 1) } - shuffleClient1.pushMergedData(1, 0, 0) + shuffleClient.pushMergedData(1, 0, 0) Thread.sleep(1000) - shuffleClient1.mapperEnd(1, 0, 0, 1) + shuffleClient.mapperEnd(1, 0, 0, 1) var duplicateBytesRead = new AtomicLong(0) val metricsCallback = new MetricsCallback { @@ -76,7 +83,7 @@ class ReadWriteTestWithFailures extends AnyFunSuite duplicateBytesRead.addAndGet(bytesRead) } } - val inputStream = shuffleClient1.readPartition( + val inputStream = shuffleClient.readPartition( 1, 1, 0, @@ -106,10 +113,10 @@ class ReadWriteTestWithFailures extends AnyFunSuite Assert.assertTrue(stringSet.contains(str)) } - // Assert no duplicate chunks read despite chunk fetch retries - Assert.assertEquals(duplicateBytesRead.get(), 0) Thread.sleep(5000L) - shuffleClient1.shutdown() + shuffleClient.shutdown() lifecycleManager.rpcEnv.shutdown() + + duplicateBytesRead.get() } } From 9fea1b633630178785b4c772af3151c8cc887dd9 Mon Sep 17 00:00:00 2001 From: Saurabh Dubey Date: Tue, 4 Mar 2025 16:15:21 +0530 Subject: [PATCH 04/14] Commit missing files --- .../PartitionReaderCheckpointMetadata.java | 3 +++ .../WorkerPartitionReaderCheckpointMetadata.java | 15 +++++++++++++++ 2 files changed, 18 insertions(+) create mode 100644 client/src/main/java/org/apache/celeborn/client/read/checkpoint/PartitionReaderCheckpointMetadata.java create mode 100644 client/src/main/java/org/apache/celeborn/client/read/checkpoint/WorkerPartitionReaderCheckpointMetadata.java diff --git a/client/src/main/java/org/apache/celeborn/client/read/checkpoint/PartitionReaderCheckpointMetadata.java b/client/src/main/java/org/apache/celeborn/client/read/checkpoint/PartitionReaderCheckpointMetadata.java new file mode 100644 index 00000000000..402b64d6d93 --- /dev/null +++ b/client/src/main/java/org/apache/celeborn/client/read/checkpoint/PartitionReaderCheckpointMetadata.java @@ -0,0 +1,3 @@ +package org.apache.celeborn.client.read.checkpoint; + +public interface PartitionReaderCheckpointMetadata {} diff --git a/client/src/main/java/org/apache/celeborn/client/read/checkpoint/WorkerPartitionReaderCheckpointMetadata.java b/client/src/main/java/org/apache/celeborn/client/read/checkpoint/WorkerPartitionReaderCheckpointMetadata.java new file mode 100644 index 00000000000..bdaea53ade4 --- /dev/null +++ b/client/src/main/java/org/apache/celeborn/client/read/checkpoint/WorkerPartitionReaderCheckpointMetadata.java @@ -0,0 +1,15 @@ +package org.apache.celeborn.client.read.checkpoint; + +import java.util.Set; + +public class WorkerPartitionReaderCheckpointMetadata implements PartitionReaderCheckpointMetadata { + private final Set returnedChunks; + + public WorkerPartitionReaderCheckpointMetadata(Set returnedChunks) { + this.returnedChunks = returnedChunks; + } + + public Set getReturnedChunks() { + return returnedChunks; + } +} From 8ec61c5edaba08641645ea87f18fd94a263dd7ad Mon Sep 17 00:00:00 2001 From: Saurabh Dubey Date: Tue, 4 Mar 2025 16:19:23 +0530 Subject: [PATCH 05/14] Fix interface --- .../service/deploy/cluster/LocalReadByChunkOffsetsTest.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/LocalReadByChunkOffsetsTest.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/LocalReadByChunkOffsetsTest.scala index 33dcc12da90..75d95228879 100644 --- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/LocalReadByChunkOffsetsTest.scala +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/LocalReadByChunkOffsetsTest.scala @@ -121,6 +121,7 @@ class LocalReadByChunkOffsetsTest extends AnyFunSuite val metricsCallback = new MetricsCallback { override def incBytesRead(bytesWritten: Long): Unit = {} override def incReadTime(time: Long): Unit = {} + override def incDuplicateBytesRead(bytesRead: Long): Unit = ??? } // chunkOffset is [0, 9404, 25913, 35393, 49576] From b3c6eedec1c016b671988ee2ea92d71b10d63fe7 Mon Sep 17 00:00:00 2001 From: Saurabh Dubey Date: Tue, 4 Mar 2025 16:23:19 +0530 Subject: [PATCH 06/14] Move to default interface method --- .../apache/spark/shuffle/celeborn/CelebornShuffleReader.scala | 2 -- .../java/org/apache/celeborn/client/read/MetricsCallback.java | 2 +- .../org/apache/celeborn/client/WithShuffleClientSuite.scala | 1 - .../service/deploy/cluster/LocalReadByChunkOffsetsTest.scala | 1 - .../celeborn/service/deploy/cluster/ReadWriteTestBase.scala | 1 - 5 files changed, 1 insertion(+), 6 deletions(-) diff --git a/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala index d84ea52761c..b35e549ae38 100644 --- a/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala +++ b/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala @@ -91,8 +91,6 @@ class CelebornShuffleReader[K, C]( override def incReadTime(time: Long): Unit = metrics.incFetchWaitTime(time) - - override def incDuplicateBytesRead(bytesRead: Long): Unit = {} } if (streamCreatorPool == null) { diff --git a/client/src/main/java/org/apache/celeborn/client/read/MetricsCallback.java b/client/src/main/java/org/apache/celeborn/client/read/MetricsCallback.java index f79ea4038a2..ab563cb8863 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/MetricsCallback.java +++ b/client/src/main/java/org/apache/celeborn/client/read/MetricsCallback.java @@ -22,5 +22,5 @@ public interface MetricsCallback { void incReadTime(long time); - void incDuplicateBytesRead(long bytesRead); + default void incDuplicateBytesRead(long bytesRead) {} } diff --git a/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala b/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala index 935faf922d2..0570760ce14 100644 --- a/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala +++ b/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala @@ -149,7 +149,6 @@ trait WithShuffleClientSuite extends CelebornFunSuite { val metricsCallback = new MetricsCallback { override def incBytesRead(bytesWritten: Long): Unit = {} override def incReadTime(time: Long): Unit = {} - override def incDuplicateBytesRead(bytesRead: Long): Unit = {} } // reduce normal empty CelebornInputStream diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/LocalReadByChunkOffsetsTest.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/LocalReadByChunkOffsetsTest.scala index 75d95228879..33dcc12da90 100644 --- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/LocalReadByChunkOffsetsTest.scala +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/LocalReadByChunkOffsetsTest.scala @@ -121,7 +121,6 @@ class LocalReadByChunkOffsetsTest extends AnyFunSuite val metricsCallback = new MetricsCallback { override def incBytesRead(bytesWritten: Long): Unit = {} override def incReadTime(time: Long): Unit = {} - override def incDuplicateBytesRead(bytesRead: Long): Unit = ??? } // chunkOffset is [0, 9404, 25913, 35393, 49576] diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala index 61a2e895b7f..e81593e1d99 100644 --- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala @@ -103,7 +103,6 @@ trait ReadWriteTestBase extends AnyFunSuite val metricsCallback = new MetricsCallback { override def incBytesRead(bytesWritten: Long): Unit = {} override def incReadTime(time: Long): Unit = {} - override def incDuplicateBytesRead(bytesRead: Long): Unit = {} } val inputStream = shuffleClient.readPartition( 1, From ebb165bbaf7071f2c5b1161d5662e4b78ac9bc83 Mon Sep 17 00:00:00 2001 From: Saurabh Dubey Date: Tue, 4 Mar 2025 16:27:01 +0530 Subject: [PATCH 07/14] Add licenses --- .../PartitionReaderCheckpointMetadata.java | 20 +++++++++++++++ ...rkerPartitionReaderCheckpointMetadata.java | 25 +++++++++++++++++++ .../cluster/ReadWriteTestWithFailures.scala | 17 +++++++++++++ 3 files changed, 62 insertions(+) diff --git a/client/src/main/java/org/apache/celeborn/client/read/checkpoint/PartitionReaderCheckpointMetadata.java b/client/src/main/java/org/apache/celeborn/client/read/checkpoint/PartitionReaderCheckpointMetadata.java index 402b64d6d93..63e81739f0b 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/checkpoint/PartitionReaderCheckpointMetadata.java +++ b/client/src/main/java/org/apache/celeborn/client/read/checkpoint/PartitionReaderCheckpointMetadata.java @@ -1,3 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.celeborn.client.read.checkpoint; +/** + * Checkpoint metadata interface for a partition reader. + */ public interface PartitionReaderCheckpointMetadata {} diff --git a/client/src/main/java/org/apache/celeborn/client/read/checkpoint/WorkerPartitionReaderCheckpointMetadata.java b/client/src/main/java/org/apache/celeborn/client/read/checkpoint/WorkerPartitionReaderCheckpointMetadata.java index bdaea53ade4..057d9e8ac18 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/checkpoint/WorkerPartitionReaderCheckpointMetadata.java +++ b/client/src/main/java/org/apache/celeborn/client/read/checkpoint/WorkerPartitionReaderCheckpointMetadata.java @@ -1,10 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.celeborn.client.read.checkpoint; import java.util.Set; +/** + * Checkpoint metadata for a partition reader on the worker side. + * + */ public class WorkerPartitionReaderCheckpointMetadata implements PartitionReaderCheckpointMetadata { private final Set returnedChunks; + /** + * Create an instance of the checkpoint metadata. + * @param returnedChunks The set of chunks that have already been returned to the user. + */ public WorkerPartitionReaderCheckpointMetadata(Set returnedChunks) { this.returnedChunks = returnedChunks; } diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestWithFailures.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestWithFailures.scala index 2920aa829cc..89a64d7d766 100644 --- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestWithFailures.scala +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestWithFailures.scala @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.celeborn.service.deploy.cluster import java.io.ByteArrayOutputStream From 0a4dc460d5b9ced66c73913ee2e47ca24b03fd13 Mon Sep 17 00:00:00 2001 From: Saurabh Dubey Date: Tue, 4 Mar 2025 16:31:36 +0530 Subject: [PATCH 08/14] Lint --- .../read/checkpoint/PartitionReaderCheckpointMetadata.java | 4 +--- .../checkpoint/WorkerPartitionReaderCheckpointMetadata.java | 6 ++---- .../service/deploy/cluster/ReadWriteTestWithFailures.scala | 4 +++- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/client/src/main/java/org/apache/celeborn/client/read/checkpoint/PartitionReaderCheckpointMetadata.java b/client/src/main/java/org/apache/celeborn/client/read/checkpoint/PartitionReaderCheckpointMetadata.java index 63e81739f0b..d3f9bcf9ca7 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/checkpoint/PartitionReaderCheckpointMetadata.java +++ b/client/src/main/java/org/apache/celeborn/client/read/checkpoint/PartitionReaderCheckpointMetadata.java @@ -17,7 +17,5 @@ package org.apache.celeborn.client.read.checkpoint; -/** - * Checkpoint metadata interface for a partition reader. - */ +/** Checkpoint metadata interface for a partition reader. */ public interface PartitionReaderCheckpointMetadata {} diff --git a/client/src/main/java/org/apache/celeborn/client/read/checkpoint/WorkerPartitionReaderCheckpointMetadata.java b/client/src/main/java/org/apache/celeborn/client/read/checkpoint/WorkerPartitionReaderCheckpointMetadata.java index 057d9e8ac18..dc597df942f 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/checkpoint/WorkerPartitionReaderCheckpointMetadata.java +++ b/client/src/main/java/org/apache/celeborn/client/read/checkpoint/WorkerPartitionReaderCheckpointMetadata.java @@ -19,15 +19,13 @@ import java.util.Set; -/** - * Checkpoint metadata for a partition reader on the worker side. - * - */ +/** Checkpoint metadata for a partition reader on the worker side. */ public class WorkerPartitionReaderCheckpointMetadata implements PartitionReaderCheckpointMetadata { private final Set returnedChunks; /** * Create an instance of the checkpoint metadata. + * * @param returnedChunks The set of chunks that have already been returned to the user. */ public WorkerPartitionReaderCheckpointMetadata(Set returnedChunks) { diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestWithFailures.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestWithFailures.scala index 89a64d7d766..d35ce51d0af 100644 --- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestWithFailures.scala +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestWithFailures.scala @@ -71,7 +71,9 @@ class ReadWriteTestWithFailures extends AnyFunSuite .set("celeborn.client.fetch.maxReqsInFlight", "1") .set("celeborn.client.shuffle.compression.codec", CompressionCodec.NONE.toString) .set(CelebornConf.TEST_CLIENT_FETCH_FAILURE.key, "true") - .set(CelebornConf.WORKER_PARTITION_READER_CHECKPOINT_ENABLE.key, workerChunkLevelCheckpointEnabled) + .set( + CelebornConf.WORKER_PARTITION_READER_CHECKPOINT_ENABLE.key, + workerChunkLevelCheckpointEnabled) .set(CelebornConf.CLIENT_PUSH_REPLICATE_ENABLED.key, "false") val lifecycleManager = new LifecycleManager(APP, clientConf) From 65bb3fcdb5bd5810e315e252e90e2e4150c8ba90 Mon Sep 17 00:00:00 2001 From: Saurabh Dubey Date: Fri, 7 Mar 2025 16:53:50 +0530 Subject: [PATCH 09/14] Review comments --- .../celeborn/client/read/WorkerPartitionReader.java | 2 +- .../scala/org/apache/celeborn/common/CelebornConf.scala | 9 +++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java b/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java index 91b9095b022..a19c788278b 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java +++ b/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java @@ -198,7 +198,7 @@ public void close() { if (results.size() > 0) { results.forEach( chunk -> { - chunk.getRight().release(); // + chunk.getRight().release(); }); } results.clear(); diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala index 173d6cb8d45..03961458881 100644 --- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala +++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala @@ -982,8 +982,13 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se def clientFetchTimeoutMs: Long = get(CLIENT_FETCH_TIMEOUT) def clientFetchBufferSize: Int = get(CLIENT_FETCH_BUFFER_SIZE).toInt def clientFetchMaxReqsInFlight: Int = get(CLIENT_FETCH_MAX_REQS_IN_FLIGHT) - def isWorkerPartitionReaderCheckpointEnabled: Boolean = - get(WORKER_PARTITION_READER_CHECKPOINT_ENABLE) + def isWorkerPartitionReaderCheckpointEnabled: Boolean = { + if (clientPushReplicateEnabled) { + false + } else + get(WORKER_PARTITION_READER_CHECKPOINT_ENABLE) + } + def clientFetchMaxRetriesForEachReplica: Int = get(CLIENT_FETCH_MAX_RETRIES_FOR_EACH_REPLICA) def clientStageRerunEnabled: Boolean = get(CLIENT_STAGE_RERUN_ENABLED) def clientFetchExcludeWorkerOnFailureEnabled: Boolean = From ac7ec43ae207f1d91fae6046f156ad0358f2c6dc Mon Sep 17 00:00:00 2001 From: Saurabh Dubey Date: Sat, 8 Mar 2025 22:53:01 +0530 Subject: [PATCH 10/14] Add comment --- .../apache/celeborn/client/read/WorkerPartitionReader.java | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java b/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java index a19c788278b..853f939bad6 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java +++ b/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java @@ -249,6 +249,11 @@ private void fetchChunks() throws IOException, InterruptedException { chunkIndex); chunkIndex++; returnedChunks++; + // IMP Since we're skipping fetching this chunk, we must not decrement toFetch here + // Eg: If chunkIndex=1, toFetch=2, endChunkIndex = 4 and chunkIdsAlreadyReturned = {1,2} + // if we skip chunk 1 and 2, decrementing toFetch here would wrongly exit the loop + // without ever fetching chunk {3, 4}, and next() would end up waiting for chunk {3,4} + // infinitely. } else if (testFetch && fetchChunkRetryCnt < fetchChunkMaxRetry - 1 && chunkIndex == 3) { callback.onFailure(chunkIndex, new CelebornIOException("Test fetch chunk failure")); toFetch--; From 96fe400b44a1ee4d002f62ff9a3a24c1f9b554db Mon Sep 17 00:00:00 2001 From: Saurabh Dubey Date: Mon, 10 Mar 2025 11:50:09 +0530 Subject: [PATCH 11/14] Review comments --- .../client/read/CelebornInputStream.java | 10 ++--- .../client/read/DfsPartitionReader.java | 13 +++++++ .../client/read/LocalPartitionReader.java | 13 +++++++ .../celeborn/client/read/PartitionReader.java | 9 ++--- .../client/read/WorkerPartitionReader.java | 14 +++---- .../PartitionReaderCheckpointMetadata.java | 19 +++++++++- ...rkerPartitionReaderCheckpointMetadata.java | 38 ------------------- .../apache/celeborn/common/CelebornConf.scala | 12 ++---- docs/configuration/client.md | 2 +- .../cluster/ReadWriteTestWithFailures.scala | 2 +- 10 files changed, 65 insertions(+), 67 deletions(-) delete mode 100644 client/src/main/java/org/apache/celeborn/client/read/checkpoint/WorkerPartitionReaderCheckpointMetadata.java diff --git a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java index 0573dbbba57..09a066ea41f 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java +++ b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java @@ -427,17 +427,15 @@ private PartitionReader createReaderWithRetry( Optional checkpointMetadata) throws IOException { Exception lastException = null; - PartitionReader reader = null; while (fetchChunkRetryCnt < fetchChunkMaxRetry) { try { logger.debug("Create reader for location {}", location); if (isExcluded(location)) { throw new CelebornIOException("Fetch data from excluded worker! " + location); } - reader = createReader(location, pbStreamHandler, fetchChunkRetryCnt, fetchChunkMaxRetry); - if (checkpointMetadata.isPresent()) { - reader.updateCheckpointMetadata(checkpointMetadata.get()); - } + PartitionReader reader = + createReader(location, pbStreamHandler, fetchChunkRetryCnt, fetchChunkMaxRetry); + checkpointMetadata.ifPresent(reader::updateCheckpointMetadata); return reader; } catch (Exception e) { lastException = e; @@ -544,7 +542,7 @@ private ByteBuf getNextChunk() throws IOException { createReaderWithRetry( currentReader.getLocation(), null, - Optional.ofNullable(currentReader.getPartitionReaderCheckpointMetadata())); + currentReader.getPartitionReaderCheckpointMetadata()); } } } diff --git a/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java b/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java index 6cca0d47be0..f3e50ddc958 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java +++ b/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java @@ -21,6 +21,7 @@ import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.List; +import java.util.Optional; import java.util.concurrent.ExecutorService; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; @@ -36,6 +37,7 @@ import org.slf4j.LoggerFactory; import org.apache.celeborn.client.ShuffleClient; +import org.apache.celeborn.client.read.checkpoint.PartitionReaderCheckpointMetadata; import org.apache.celeborn.common.CelebornConf; import org.apache.celeborn.common.network.client.TransportClient; import org.apache.celeborn.common.network.client.TransportClientFactory; @@ -308,4 +310,15 @@ private void closeStream() { public PartitionLocation getLocation() { return location; } + + @Override + public Optional getPartitionReaderCheckpointMetadata() { + // TODO implement similar to {@link WorkerPartitionReader} + return Optional.empty(); + } + + @Override + public void updateCheckpointMetadata(PartitionReaderCheckpointMetadata checkpointMetadata) { + // TODO implement similar to {@link WorkerPartitionReader} + } } diff --git a/client/src/main/java/org/apache/celeborn/client/read/LocalPartitionReader.java b/client/src/main/java/org/apache/celeborn/client/read/LocalPartitionReader.java index 1de83eb4573..b5755beaef1 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/LocalPartitionReader.java +++ b/client/src/main/java/org/apache/celeborn/client/read/LocalPartitionReader.java @@ -22,6 +22,7 @@ import java.nio.channels.FileChannel; import java.util.ArrayList; import java.util.List; +import java.util.Optional; import java.util.concurrent.*; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; @@ -34,6 +35,7 @@ import org.slf4j.LoggerFactory; import org.apache.celeborn.client.ShuffleClient; +import org.apache.celeborn.client.read.checkpoint.PartitionReaderCheckpointMetadata; import org.apache.celeborn.common.CelebornConf; import org.apache.celeborn.common.exception.CelebornIOException; import org.apache.celeborn.common.network.client.TransportClient; @@ -277,4 +279,15 @@ private void closeStream() { public PartitionLocation getLocation() { return location; } + + @Override + public Optional getPartitionReaderCheckpointMetadata() { + // TODO implement similar to {@link WorkerPartitionReader} + return Optional.empty(); + } + + @Override + public void updateCheckpointMetadata(PartitionReaderCheckpointMetadata checkpointMetadata) { + // TODO implement similar to {@link WorkerPartitionReader} + } } diff --git a/client/src/main/java/org/apache/celeborn/client/read/PartitionReader.java b/client/src/main/java/org/apache/celeborn/client/read/PartitionReader.java index 42d5d02459b..7c14b66ad37 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/PartitionReader.java +++ b/client/src/main/java/org/apache/celeborn/client/read/PartitionReader.java @@ -18,13 +18,14 @@ package org.apache.celeborn.client.read; import java.io.IOException; +import java.util.Optional; import io.netty.buffer.ByteBuf; import org.apache.celeborn.client.read.checkpoint.PartitionReaderCheckpointMetadata; import org.apache.celeborn.common.protocol.PartitionLocation; -public interface PartitionReader { +public interface PartitionReader { boolean hasNext(); ByteBuf next() throws IOException, InterruptedException; @@ -33,9 +34,7 @@ public interface PartitionReader { PartitionLocation getLocation(); - default T getPartitionReaderCheckpointMetadata() { - return null; - } + Optional getPartitionReaderCheckpointMetadata(); - default void updateCheckpointMetadata(T checkpointMetadata) {} + void updateCheckpointMetadata(PartitionReaderCheckpointMetadata checkpointMetadata); } diff --git a/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java b/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java index 853f939bad6..b6ef4991786 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java +++ b/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.nio.ByteBuffer; import java.util.HashSet; +import java.util.Optional; import java.util.Set; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; @@ -31,7 +32,7 @@ import org.slf4j.LoggerFactory; import org.apache.celeborn.client.ShuffleClient; -import org.apache.celeborn.client.read.checkpoint.WorkerPartitionReaderCheckpointMetadata; +import org.apache.celeborn.client.read.checkpoint.PartitionReaderCheckpointMetadata; import org.apache.celeborn.common.CelebornConf; import org.apache.celeborn.common.exception.CelebornIOException; import org.apache.celeborn.common.network.buffer.ManagedBuffer; @@ -48,8 +49,7 @@ import org.apache.celeborn.common.protocol.StreamType; import org.apache.celeborn.common.util.ExceptionUtils; -public class WorkerPartitionReader - implements PartitionReader { +public class WorkerPartitionReader implements PartitionReader { private final Logger logger = LoggerFactory.getLogger(WorkerPartitionReader.class); private PartitionLocation location; private final TransportClientFactory clientFactory; @@ -225,14 +225,14 @@ public PartitionLocation getLocation() { } @Override - public WorkerPartitionReaderCheckpointMetadata getPartitionReaderCheckpointMetadata() { + public Optional getPartitionReaderCheckpointMetadata() { return isCheckpointEnabled - ? new WorkerPartitionReaderCheckpointMetadata(chunkIdsAlreadyReturned) - : null; + ? Optional.of(new PartitionReaderCheckpointMetadata(chunkIdsAlreadyReturned)) + : Optional.empty(); } @Override - public void updateCheckpointMetadata(WorkerPartitionReaderCheckpointMetadata checkpointMetadata) { + public void updateCheckpointMetadata(PartitionReaderCheckpointMetadata checkpointMetadata) { chunkIdsAlreadyReturned = checkpointMetadata.getReturnedChunks(); } diff --git a/client/src/main/java/org/apache/celeborn/client/read/checkpoint/PartitionReaderCheckpointMetadata.java b/client/src/main/java/org/apache/celeborn/client/read/checkpoint/PartitionReaderCheckpointMetadata.java index d3f9bcf9ca7..f5e9e8c035f 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/checkpoint/PartitionReaderCheckpointMetadata.java +++ b/client/src/main/java/org/apache/celeborn/client/read/checkpoint/PartitionReaderCheckpointMetadata.java @@ -17,5 +17,22 @@ package org.apache.celeborn.client.read.checkpoint; +import java.util.Set; + /** Checkpoint metadata interface for a partition reader. */ -public interface PartitionReaderCheckpointMetadata {} +public class PartitionReaderCheckpointMetadata { + private final Set returnedChunks; + + /** + * Create an instance of the checkpoint metadata. + * + * @param returnedChunks The set of chunks that have already been returned to the user. + */ + public PartitionReaderCheckpointMetadata(Set returnedChunks) { + this.returnedChunks = returnedChunks; + } + + public Set getReturnedChunks() { + return returnedChunks; + } +} diff --git a/client/src/main/java/org/apache/celeborn/client/read/checkpoint/WorkerPartitionReaderCheckpointMetadata.java b/client/src/main/java/org/apache/celeborn/client/read/checkpoint/WorkerPartitionReaderCheckpointMetadata.java deleted file mode 100644 index dc597df942f..00000000000 --- a/client/src/main/java/org/apache/celeborn/client/read/checkpoint/WorkerPartitionReaderCheckpointMetadata.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.celeborn.client.read.checkpoint; - -import java.util.Set; - -/** Checkpoint metadata for a partition reader on the worker side. */ -public class WorkerPartitionReaderCheckpointMetadata implements PartitionReaderCheckpointMetadata { - private final Set returnedChunks; - - /** - * Create an instance of the checkpoint metadata. - * - * @param returnedChunks The set of chunks that have already been returned to the user. - */ - public WorkerPartitionReaderCheckpointMetadata(Set returnedChunks) { - this.returnedChunks = returnedChunks; - } - - public Set getReturnedChunks() { - return returnedChunks; - } -} diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala index 03961458881..45ab9e8ea5d 100644 --- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala +++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala @@ -982,12 +982,8 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se def clientFetchTimeoutMs: Long = get(CLIENT_FETCH_TIMEOUT) def clientFetchBufferSize: Int = get(CLIENT_FETCH_BUFFER_SIZE).toInt def clientFetchMaxReqsInFlight: Int = get(CLIENT_FETCH_MAX_REQS_IN_FLIGHT) - def isWorkerPartitionReaderCheckpointEnabled: Boolean = { - if (clientPushReplicateEnabled) { - false - } else - get(WORKER_PARTITION_READER_CHECKPOINT_ENABLE) - } + def isWorkerPartitionReaderCheckpointEnabled: Boolean = + get(WORKER_PARTITION_READER_CHECKPOINT_ENABLED) def clientFetchMaxRetriesForEachReplica: Int = get(CLIENT_FETCH_MAX_RETRIES_FOR_EACH_REPLICA) def clientStageRerunEnabled: Boolean = get(CLIENT_STAGE_RERUN_ENABLED) @@ -4698,8 +4694,8 @@ object CelebornConf extends Logging { .bytesConf(ByteUnit.BYTE) .createWithDefaultString("64k") - val WORKER_PARTITION_READER_CHECKPOINT_ENABLE: ConfigEntry[Boolean] = - buildConf("celeborn.worker.partition.reader.checkpointEnabled") + val WORKER_PARTITION_READER_CHECKPOINT_ENABLED: ConfigEntry[Boolean] = + buildConf("celeborn.worker.partition.reader.checkpoint.enabled") .categories("client") .version("0.6.0") .doc("Whether or not checkpoint reads when re-creating a worker reader. Setting to true minimizes" + diff --git a/docs/configuration/client.md b/docs/configuration/client.md index 76a53c29f4d..3cc4aa64014 100644 --- a/docs/configuration/client.md +++ b/docs/configuration/client.md @@ -137,5 +137,5 @@ license: | | celeborn.storage.s3.endpoint.region | <undefined> | false | S3 endpoint for Celeborn to store shuffle data. | 0.6.0 | | | celeborn.storage.s3.secret.key | <undefined> | false | S3 secret key for Celeborn to store shuffle data. | 0.6.0 | | | celeborn.tags.tagsExpr | | true | Expression to filter workers by tags. The expression is a comma-separated list of tags. The expression is evaluated as a logical AND of all tags. For example, `prod,high-io` filters workers that have both the `prod` and `high-io` tags. | 0.6.0 | | -| celeborn.worker.partition.reader.checkpointEnabled | false | false | Whether or not checkpoint reads when re-creating a worker reader. Setting to true minimizes the amount of unnecessary reads during client read retries | 0.6.0 | | +| celeborn.worker.partition.reader.checkpoint.enabled | false | false | Whether or not checkpoint reads when re-creating a worker reader. Setting to true minimizes the amount of unnecessary reads during client read retries | 0.6.0 | | diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestWithFailures.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestWithFailures.scala index d35ce51d0af..959fc8dfb44 100644 --- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestWithFailures.scala +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestWithFailures.scala @@ -72,7 +72,7 @@ class ReadWriteTestWithFailures extends AnyFunSuite .set("celeborn.client.shuffle.compression.codec", CompressionCodec.NONE.toString) .set(CelebornConf.TEST_CLIENT_FETCH_FAILURE.key, "true") .set( - CelebornConf.WORKER_PARTITION_READER_CHECKPOINT_ENABLE.key, + CelebornConf.WORKER_PARTITION_READER_CHECKPOINT_ENABLED.key, workerChunkLevelCheckpointEnabled) .set(CelebornConf.CLIENT_PUSH_REPLICATE_ENABLED.key, "false") From 74706e9cadd5532838a5611799a73710ab59c1ee Mon Sep 17 00:00:00 2001 From: Saurabh Dubey Date: Wed, 12 Mar 2025 23:36:43 +0530 Subject: [PATCH 12/14] Impl for DFSPartitionReader --- .../client/read/DfsPartitionReader.java | 39 ++++++++++++++----- .../client/read/WorkerPartitionReader.java | 2 +- .../apache/celeborn/common/CelebornConf.scala | 12 +++--- docs/configuration/client.md | 2 +- .../cluster/ReadWriteTestWithFailures.scala | 2 +- 5 files changed, 38 insertions(+), 19 deletions(-) diff --git a/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java b/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java index f3e50ddc958..c3925be680d 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java +++ b/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java @@ -20,8 +20,10 @@ import java.io.IOException; import java.nio.ByteBuffer; import java.util.ArrayList; +import java.util.HashSet; import java.util.List; import java.util.Optional; +import java.util.Set; import java.util.concurrent.ExecutorService; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; @@ -29,7 +31,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; -import io.netty.util.ReferenceCounted; +import org.apache.commons.lang3.tuple.Pair; import org.apache.hadoop.fs.FSDataInputStream; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; @@ -59,7 +61,7 @@ public class DfsPartitionReader implements PartitionReader { PartitionLocation location; private final long shuffleChunkSize; private final int fetchMaxReqsInFlight; - private final LinkedBlockingQueue results; + private final LinkedBlockingQueue> results; private final AtomicReference exception = new AtomicReference<>(); private volatile boolean closed = false; private ExecutorService fetchThread; @@ -78,6 +80,10 @@ public class DfsPartitionReader implements PartitionReader { private Path dataFilePath; + // checkpoints + private final boolean isCheckpointEnabled; + private Set chunkIdsAlreadyReturned; + public DfsPartitionReader( CelebornConf conf, String shuffleKey, @@ -146,6 +152,8 @@ public DfsPartitionReader( : Math.min(chunkOffsets.size() - 2, endChunkIndex); this.currentChunkIndex = this.startChunkIndex; this.numChunks = this.endChunkIndex - this.startChunkIndex + 1; + this.chunkIdsAlreadyReturned = new HashSet<>(); + this.isCheckpointEnabled = conf.isPartitionReaderCheckpointEnabled(); logger.debug( "DFS {} total offset count:{} chunk count: {} " + "start chunk index:{} end chunk index:{} offsets:{}", @@ -204,18 +212,27 @@ private List getChunkOffsetsFromSortedIndex( @Override public boolean hasNext() { logger.debug("check has next current index: {} chunks {}", returnedChunks, numChunks); - return returnedChunks < numChunks; + return chunkIdsAlreadyReturned.size() < numChunks; } @Override public ByteBuf next() throws IOException, InterruptedException { - ByteBuf chunk = null; + Pair chunk = null; if (!fetchThreadStarted) { fetchThreadStarted = true; fetchThread.submit( () -> { try { while (!closed && currentChunkIndex <= endChunkIndex) { + if (chunkIdsAlreadyReturned.contains(currentChunkIndex)) { + logger.info( + "Skipping chunk {} as it has already been returned," + + " likely by a previous reader for the same partition.", + currentChunkIndex); + currentChunkIndex++; + returnedChunks++; + continue; + } while (results.size() >= fetchMaxReqsInFlight) { Thread.sleep(50); } @@ -240,7 +257,7 @@ public ByteBuf next() throws IOException, InterruptedException { break; } } - results.put(Unpooled.wrappedBuffer(buffer)); + results.put(Pair.of(currentChunkIndex, Unpooled.wrappedBuffer(buffer))); logger.debug("add index {} to results", currentChunkIndex++); } } catch (Exception e) { @@ -264,7 +281,8 @@ public ByteBuf next() throws IOException, InterruptedException { throw e; } returnedChunks++; - return chunk; + chunkIdsAlreadyReturned.add(chunk.getLeft()); + return chunk.getRight(); } private void checkException() throws IOException { @@ -286,7 +304,7 @@ public void close() { logger.warn("close DFS input stream failed.", e); } if (results.size() > 0) { - results.forEach(ReferenceCounted::release); + results.forEach(chunk -> chunk.getRight().release()); } results.clear(); closeStream(); @@ -313,12 +331,13 @@ public PartitionLocation getLocation() { @Override public Optional getPartitionReaderCheckpointMetadata() { - // TODO implement similar to {@link WorkerPartitionReader} - return Optional.empty(); + return isCheckpointEnabled + ? Optional.of(new PartitionReaderCheckpointMetadata(chunkIdsAlreadyReturned)) + : Optional.empty(); } @Override public void updateCheckpointMetadata(PartitionReaderCheckpointMetadata checkpointMetadata) { - // TODO implement similar to {@link WorkerPartitionReader} + chunkIdsAlreadyReturned = checkpointMetadata.getReturnedChunks(); } } diff --git a/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java b/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java index b6ef4991786..5bd5fc661ea 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java +++ b/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java @@ -156,7 +156,7 @@ public void onFailure(int chunkIndex, Throwable e) { this.fetchChunkRetryCnt = fetchChunkRetryCnt; this.fetchChunkMaxRetry = fetchChunkMaxRetry; this.chunkIdsAlreadyReturned = new HashSet<>(); - this.isCheckpointEnabled = conf.isWorkerPartitionReaderCheckpointEnabled(); + this.isCheckpointEnabled = conf.isPartitionReaderCheckpointEnabled(); testFetch = conf.testFetchFailure(); ShuffleClient.incrementTotalReadCounter(); } diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala index 45ab9e8ea5d..d7f23b37a75 100644 --- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala +++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala @@ -982,8 +982,8 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se def clientFetchTimeoutMs: Long = get(CLIENT_FETCH_TIMEOUT) def clientFetchBufferSize: Int = get(CLIENT_FETCH_BUFFER_SIZE).toInt def clientFetchMaxReqsInFlight: Int = get(CLIENT_FETCH_MAX_REQS_IN_FLIGHT) - def isWorkerPartitionReaderCheckpointEnabled: Boolean = - get(WORKER_PARTITION_READER_CHECKPOINT_ENABLED) + def isPartitionReaderCheckpointEnabled: Boolean = + get(PARTITION_READER_CHECKPOINT_ENABLED) def clientFetchMaxRetriesForEachReplica: Int = get(CLIENT_FETCH_MAX_RETRIES_FOR_EACH_REPLICA) def clientStageRerunEnabled: Boolean = get(CLIENT_STAGE_RERUN_ENABLED) @@ -4694,12 +4694,12 @@ object CelebornConf extends Logging { .bytesConf(ByteUnit.BYTE) .createWithDefaultString("64k") - val WORKER_PARTITION_READER_CHECKPOINT_ENABLED: ConfigEntry[Boolean] = - buildConf("celeborn.worker.partition.reader.checkpoint.enabled") + val PARTITION_READER_CHECKPOINT_ENABLED: ConfigEntry[Boolean] = + buildConf("celeborn.partition.reader.checkpoint.enabled") .categories("client") .version("0.6.0") - .doc("Whether or not checkpoint reads when re-creating a worker reader. Setting to true minimizes" + - " the amount of unnecessary reads during client read retries") + .doc("Whether or not checkpoint reads when re-creating a partition reader. Setting to true minimizes" + + " the amount of unnecessary reads during partition read retries") .booleanConf .createWithDefault(false) diff --git a/docs/configuration/client.md b/docs/configuration/client.md index 3cc4aa64014..77ca04023b3 100644 --- a/docs/configuration/client.md +++ b/docs/configuration/client.md @@ -128,6 +128,7 @@ license: | | celeborn.identity.user-specific.userName | default | false | User name if celeborn.identity.provider is org.apache.celeborn.common.identity.DefaultIdentityProvider. | 0.6.0 | celeborn.quota.identity.user-specific.userName | | celeborn.master.endpoints | <localhost>:9097 | false | Endpoints of master nodes for celeborn clients to connect. Client uses resolver provided by celeborn.master.endpoints.resolver to resolve the master endpoints. By default Celeborn uses `org.apache.celeborn.common.client.StaticMasterEndpointResolver` which take static master endpoints as input. Allowed pattern: `:[,:]*`, e.g. `clb1:9097,clb2:9098,clb3:9099`. If the port is omitted, 9097 will be used. If the master endpoints are not static then users can pass custom resolver implementation to discover master endpoints actively using celeborn.master.endpoints.resolver. | 0.2.0 | | | celeborn.master.endpoints.resolver | org.apache.celeborn.common.client.StaticMasterEndpointResolver | false | Resolver class that can be used for discovering and updating the master endpoints. This allows users to provide a custom master endpoint resolver implementation. This is useful in environments where the master nodes might change due to scaling operations or infrastructure updates. Clients need to ensure that provided resolver class should be present in the classpath. | 0.5.2 | | +| celeborn.partition.reader.checkpoint.enabled | false | false | Whether or not checkpoint reads when re-creating a partition reader. Setting to true minimizes the amount of unnecessary reads during partition read retries | 0.6.0 | | | celeborn.quota.enabled | true | false | When Master side sets to true, the master will enable to check the quota via QuotaManager. When Client side sets to true, LifecycleManager will request Master side to check whether the current user has enough quota before registration of shuffle. Fallback to the default shuffle service when Master side checks that there is no enough quota for current user. | 0.2.0 | | | celeborn.quota.interruptShuffle.enabled | false | false | Whether to enable interrupt shuffle when quota exceeds. | 0.6.0 | | | celeborn.storage.availableTypes | HDD | false | Enabled storages. Available options: MEMORY,HDD,SSD,HDFS,S3. Note: HDD and SSD would be treated as identical. | 0.3.0 | celeborn.storage.activeTypes | @@ -137,5 +138,4 @@ license: | | celeborn.storage.s3.endpoint.region | <undefined> | false | S3 endpoint for Celeborn to store shuffle data. | 0.6.0 | | | celeborn.storage.s3.secret.key | <undefined> | false | S3 secret key for Celeborn to store shuffle data. | 0.6.0 | | | celeborn.tags.tagsExpr | | true | Expression to filter workers by tags. The expression is a comma-separated list of tags. The expression is evaluated as a logical AND of all tags. For example, `prod,high-io` filters workers that have both the `prod` and `high-io` tags. | 0.6.0 | | -| celeborn.worker.partition.reader.checkpoint.enabled | false | false | Whether or not checkpoint reads when re-creating a worker reader. Setting to true minimizes the amount of unnecessary reads during client read retries | 0.6.0 | | diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestWithFailures.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestWithFailures.scala index 959fc8dfb44..17c21d0a683 100644 --- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestWithFailures.scala +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestWithFailures.scala @@ -72,7 +72,7 @@ class ReadWriteTestWithFailures extends AnyFunSuite .set("celeborn.client.shuffle.compression.codec", CompressionCodec.NONE.toString) .set(CelebornConf.TEST_CLIENT_FETCH_FAILURE.key, "true") .set( - CelebornConf.WORKER_PARTITION_READER_CHECKPOINT_ENABLED.key, + CelebornConf.PARTITION_READER_CHECKPOINT_ENABLED.key, workerChunkLevelCheckpointEnabled) .set(CelebornConf.CLIENT_PUSH_REPLICATE_ENABLED.key, "false") From b437dc3b684e11376cc5b3a64dbda9e3a2376707 Mon Sep 17 00:00:00 2001 From: Saurabh Dubey Date: Thu, 13 Mar 2025 22:50:01 +0530 Subject: [PATCH 13/14] Review comments --- .../client/read/CelebornInputStream.java | 17 ++++-- .../client/read/DfsPartitionReader.java | 47 ++++++++++------- .../client/read/LocalPartitionReader.java | 5 -- .../celeborn/client/read/PartitionReader.java | 2 - .../client/read/WorkerPartitionReader.java | 52 +++++++++++-------- .../PartitionReaderCheckpointMetadata.java | 19 ++++--- .../apache/celeborn/common/CelebornConf.scala | 2 +- docs/configuration/client.md | 2 +- 8 files changed, 84 insertions(+), 62 deletions(-) diff --git a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java index 09a066ea41f..cbcfae93808 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java +++ b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java @@ -434,8 +434,12 @@ private PartitionReader createReaderWithRetry( throw new CelebornIOException("Fetch data from excluded worker! " + location); } PartitionReader reader = - createReader(location, pbStreamHandler, fetchChunkRetryCnt, fetchChunkMaxRetry); - checkpointMetadata.ifPresent(reader::updateCheckpointMetadata); + createReader( + location, + pbStreamHandler, + fetchChunkRetryCnt, + fetchChunkMaxRetry, + checkpointMetadata); return reader; } catch (Exception e) { lastException = e; @@ -554,7 +558,8 @@ private PartitionReader createReader( PartitionLocation location, PbStreamHandler pbStreamHandler, int fetchChunkRetryCnt, - int fetchChunkMaxRetry) + int fetchChunkMaxRetry, + Optional checkpointMetadata) throws IOException, InterruptedException { StorageInfo storageInfo = location.getStorageInfo(); @@ -600,7 +605,8 @@ private PartitionReader createReader( fetchChunkMaxRetry, callback, startChunkIndex, - endChunkIndex); + endChunkIndex, + checkpointMetadata); } case S3: case HDFS: @@ -614,7 +620,8 @@ private PartitionReader createReader( endMapIndex, callback, startChunkIndex, - endChunkIndex); + endChunkIndex, + checkpointMetadata); default: throw new CelebornIOException( String.format("Unknown storage info %s to read location %s", storageInfo, location)); diff --git a/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java b/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java index c3925be680d..003c7cd1264 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java +++ b/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java @@ -20,10 +20,8 @@ import java.io.IOException; import java.nio.ByteBuffer; import java.util.ArrayList; -import java.util.HashSet; import java.util.List; import java.util.Optional; -import java.util.Set; import java.util.concurrent.ExecutorService; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; @@ -68,6 +66,7 @@ public class DfsPartitionReader implements PartitionReader { private boolean fetchThreadStarted; private FSDataInputStream dfsInputStream; private int numChunks = 0; + private int lastReturnedChunkId = -1; private int returnedChunks = 0; private int currentChunkIndex = 0; private int startChunkIndex; @@ -80,9 +79,7 @@ public class DfsPartitionReader implements PartitionReader { private Path dataFilePath; - // checkpoints - private final boolean isCheckpointEnabled; - private Set chunkIdsAlreadyReturned; + private Optional partitionReaderCheckpointMetadata; public DfsPartitionReader( CelebornConf conf, @@ -94,7 +91,8 @@ public DfsPartitionReader( int endMapIndex, MetricsCallback metricsCallback, int startChunkIndex, - int endChunkIndex) + int endChunkIndex, + Optional checkpointMetadata) throws IOException { this.conf = conf; shuffleChunkSize = conf.dfsReadChunkSize(); @@ -152,8 +150,16 @@ public DfsPartitionReader( : Math.min(chunkOffsets.size() - 2, endChunkIndex); this.currentChunkIndex = this.startChunkIndex; this.numChunks = this.endChunkIndex - this.startChunkIndex + 1; - this.chunkIdsAlreadyReturned = new HashSet<>(); - this.isCheckpointEnabled = conf.isPartitionReaderCheckpointEnabled(); + + if (checkpointMetadata.isPresent()) { + this.partitionReaderCheckpointMetadata = checkpointMetadata; + this.returnedChunks = checkpointMetadata.get().getReturnedChunks().size(); + } else { + this.partitionReaderCheckpointMetadata = + conf.isPartitionReaderCheckpointEnabled() + ? Optional.of(new PartitionReaderCheckpointMetadata()) + : Optional.empty(); + } logger.debug( "DFS {} total offset count:{} chunk count: {} " + "start chunk index:{} end chunk index:{} offsets:{}", @@ -212,25 +218,33 @@ private List getChunkOffsetsFromSortedIndex( @Override public boolean hasNext() { logger.debug("check has next current index: {} chunks {}", returnedChunks, numChunks); - return chunkIdsAlreadyReturned.size() < numChunks; + return returnedChunks < numChunks; + } + + private void checkpoint() { + if (lastReturnedChunkId != -1) { + partitionReaderCheckpointMetadata.ifPresent( + readerCheckpointMetadata -> readerCheckpointMetadata.checkpoint(lastReturnedChunkId)); + } } @Override public ByteBuf next() throws IOException, InterruptedException { Pair chunk = null; + checkpoint(); if (!fetchThreadStarted) { fetchThreadStarted = true; fetchThread.submit( () -> { try { while (!closed && currentChunkIndex <= endChunkIndex) { - if (chunkIdsAlreadyReturned.contains(currentChunkIndex)) { + if (partitionReaderCheckpointMetadata.isPresent() + && partitionReaderCheckpointMetadata.get().isCheckpointed(currentChunkIndex)) { logger.info( "Skipping chunk {} as it has already been returned," + " likely by a previous reader for the same partition.", currentChunkIndex); currentChunkIndex++; - returnedChunks++; continue; } while (results.size() >= fetchMaxReqsInFlight) { @@ -281,7 +295,7 @@ public ByteBuf next() throws IOException, InterruptedException { throw e; } returnedChunks++; - chunkIdsAlreadyReturned.add(chunk.getLeft()); + lastReturnedChunkId = chunk.getLeft(); return chunk.getRight(); } @@ -331,13 +345,6 @@ public PartitionLocation getLocation() { @Override public Optional getPartitionReaderCheckpointMetadata() { - return isCheckpointEnabled - ? Optional.of(new PartitionReaderCheckpointMetadata(chunkIdsAlreadyReturned)) - : Optional.empty(); - } - - @Override - public void updateCheckpointMetadata(PartitionReaderCheckpointMetadata checkpointMetadata) { - chunkIdsAlreadyReturned = checkpointMetadata.getReturnedChunks(); + return partitionReaderCheckpointMetadata; } } diff --git a/client/src/main/java/org/apache/celeborn/client/read/LocalPartitionReader.java b/client/src/main/java/org/apache/celeborn/client/read/LocalPartitionReader.java index b5755beaef1..56982f9147b 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/LocalPartitionReader.java +++ b/client/src/main/java/org/apache/celeborn/client/read/LocalPartitionReader.java @@ -285,9 +285,4 @@ public Optional getPartitionReaderCheckpointM // TODO implement similar to {@link WorkerPartitionReader} return Optional.empty(); } - - @Override - public void updateCheckpointMetadata(PartitionReaderCheckpointMetadata checkpointMetadata) { - // TODO implement similar to {@link WorkerPartitionReader} - } } diff --git a/client/src/main/java/org/apache/celeborn/client/read/PartitionReader.java b/client/src/main/java/org/apache/celeborn/client/read/PartitionReader.java index 7c14b66ad37..0ebb94e7467 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/PartitionReader.java +++ b/client/src/main/java/org/apache/celeborn/client/read/PartitionReader.java @@ -35,6 +35,4 @@ public interface PartitionReader { PartitionLocation getLocation(); Optional getPartitionReaderCheckpointMetadata(); - - void updateCheckpointMetadata(PartitionReaderCheckpointMetadata checkpointMetadata); } diff --git a/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java b/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java index 5bd5fc661ea..bd3a58170de 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java +++ b/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java @@ -19,9 +19,7 @@ import java.io.IOException; import java.nio.ByteBuffer; -import java.util.HashSet; import java.util.Optional; -import java.util.Set; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; @@ -57,11 +55,13 @@ public class WorkerPartitionReader implements PartitionReader { private TransportClient client; private MetricsCallback metricsCallback; + private int lastReturnedChunkId = -1; private int returnedChunks; private int chunkIndex; private int startChunkIndex; private int endChunkIndex; + private int inflightRequestCount; private final LinkedBlockingQueue> results; private final ChunkReceivedCallback callback; @@ -76,9 +76,7 @@ public class WorkerPartitionReader implements PartitionReader { private int fetchChunkMaxRetry; private final boolean testFetch; - // checkpoints - private final boolean isCheckpointEnabled; - private Set chunkIdsAlreadyReturned; + private Optional partitionReaderCheckpointMetadata; WorkerPartitionReader( CelebornConf conf, @@ -92,12 +90,14 @@ public class WorkerPartitionReader implements PartitionReader { int fetchChunkMaxRetry, MetricsCallback metricsCallback, int startChunkIndex, - int endChunkIndex) + int endChunkIndex, + Optional checkpointMetadata) throws IOException, InterruptedException { this.shuffleKey = shuffleKey; fetchMaxReqsInFlight = conf.clientFetchMaxReqsInFlight(); results = new LinkedBlockingQueue<>(); fetchTimeoutMs = conf.clientFetchTimeoutMs(); + inflightRequestCount = 0; this.metricsCallback = metricsCallback; // only add the buffer to results queue if this reader is not closed. callback = @@ -155,20 +155,35 @@ public void onFailure(int chunkIndex, Throwable e) { this.clientFactory = clientFactory; this.fetchChunkRetryCnt = fetchChunkRetryCnt; this.fetchChunkMaxRetry = fetchChunkMaxRetry; - this.chunkIdsAlreadyReturned = new HashSet<>(); - this.isCheckpointEnabled = conf.isPartitionReaderCheckpointEnabled(); + if (checkpointMetadata.isPresent()) { + this.partitionReaderCheckpointMetadata = checkpointMetadata; + this.returnedChunks = checkpointMetadata.get().getReturnedChunks().size(); + } else { + this.partitionReaderCheckpointMetadata = + conf.isPartitionReaderCheckpointEnabled() + ? Optional.of(new PartitionReaderCheckpointMetadata()) + : Optional.empty(); + } testFetch = conf.testFetchFailure(); ShuffleClient.incrementTotalReadCounter(); } @Override public boolean hasNext() { - return chunkIdsAlreadyReturned.size() < endChunkIndex - startChunkIndex + 1; + return returnedChunks < endChunkIndex - startChunkIndex + 1; + } + + private void checkpoint() { + if (lastReturnedChunkId != -1) { + partitionReaderCheckpointMetadata.ifPresent( + readerCheckpointMetadata -> readerCheckpointMetadata.checkpoint(lastReturnedChunkId)); + } } @Override public ByteBuf next() throws IOException, InterruptedException { checkException(); + checkpoint(); if (chunkIndex <= endChunkIndex) { fetchChunks(); } @@ -186,7 +201,8 @@ public ByteBuf next() throws IOException, InterruptedException { throw e; } returnedChunks++; - chunkIdsAlreadyReturned.add(chunk.getLeft()); + inflightRequestCount--; + lastReturnedChunkId = chunk.getLeft(); return chunk.getRight(); } @@ -226,29 +242,22 @@ public PartitionLocation getLocation() { @Override public Optional getPartitionReaderCheckpointMetadata() { - return isCheckpointEnabled - ? Optional.of(new PartitionReaderCheckpointMetadata(chunkIdsAlreadyReturned)) - : Optional.empty(); - } - - @Override - public void updateCheckpointMetadata(PartitionReaderCheckpointMetadata checkpointMetadata) { - chunkIdsAlreadyReturned = checkpointMetadata.getReturnedChunks(); + return partitionReaderCheckpointMetadata; } private void fetchChunks() throws IOException, InterruptedException { - final int inFlight = chunkIndex - startChunkIndex - returnedChunks; + final int inFlight = inflightRequestCount; if (inFlight < fetchMaxReqsInFlight) { int toFetch = Math.min(fetchMaxReqsInFlight - inFlight + 1, endChunkIndex + 1 - chunkIndex); while (toFetch > 0 && chunkIndex <= endChunkIndex) { - if (chunkIdsAlreadyReturned.contains(chunkIndex)) { + if (partitionReaderCheckpointMetadata.isPresent() + && partitionReaderCheckpointMetadata.get().isCheckpointed(chunkIndex)) { logger.info( "Skipping chunk {} as it has already been returned," + " likely by a previous reader for the same partition.", chunkIndex); chunkIndex++; - returnedChunks++; // IMP Since we're skipping fetching this chunk, we must not decrement toFetch here // Eg: If chunkIndex=1, toFetch=2, endChunkIndex = 4 and chunkIdsAlreadyReturned = {1,2} // if we skip chunk 1 and 2, decrementing toFetch here would wrongly exit the loop @@ -275,6 +284,7 @@ private void fetchChunks() throws IOException, InterruptedException { } } client.fetchChunk(streamHandler.getStreamId(), chunkIndex, fetchTimeoutMs, callback); + inflightRequestCount++; chunkIndex++; toFetch--; } diff --git a/client/src/main/java/org/apache/celeborn/client/read/checkpoint/PartitionReaderCheckpointMetadata.java b/client/src/main/java/org/apache/celeborn/client/read/checkpoint/PartitionReaderCheckpointMetadata.java index f5e9e8c035f..353fa9f8d79 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/checkpoint/PartitionReaderCheckpointMetadata.java +++ b/client/src/main/java/org/apache/celeborn/client/read/checkpoint/PartitionReaderCheckpointMetadata.java @@ -17,19 +17,24 @@ package org.apache.celeborn.client.read.checkpoint; +import java.util.HashSet; import java.util.Set; /** Checkpoint metadata interface for a partition reader. */ public class PartitionReaderCheckpointMetadata { private final Set returnedChunks; - /** - * Create an instance of the checkpoint metadata. - * - * @param returnedChunks The set of chunks that have already been returned to the user. - */ - public PartitionReaderCheckpointMetadata(Set returnedChunks) { - this.returnedChunks = returnedChunks; + /** Create an instance of the checkpoint metadata. */ + public PartitionReaderCheckpointMetadata() { + this.returnedChunks = new HashSet<>(); + } + + public void checkpoint(int chunkId) { + returnedChunks.add(chunkId); + } + + public boolean isCheckpointed(int chunkId) { + return returnedChunks.contains(chunkId); } public Set getReturnedChunks() { diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala index d7f23b37a75..703e0f33fa9 100644 --- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala +++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala @@ -4695,7 +4695,7 @@ object CelebornConf extends Logging { .createWithDefaultString("64k") val PARTITION_READER_CHECKPOINT_ENABLED: ConfigEntry[Boolean] = - buildConf("celeborn.partition.reader.checkpoint.enabled") + buildConf("celeborn.client.partition.reader.checkpoint.enabled") .categories("client") .version("0.6.0") .doc("Whether or not checkpoint reads when re-creating a partition reader. Setting to true minimizes" + diff --git a/docs/configuration/client.md b/docs/configuration/client.md index 77ca04023b3..c572308d8d8 100644 --- a/docs/configuration/client.md +++ b/docs/configuration/client.md @@ -45,6 +45,7 @@ license: | | celeborn.client.flink.shuffle.fallback.policy | AUTO | false | Celeborn supports the following kind of fallback policies. 1. ALWAYS: always use flink built-in shuffle implementation; 2. AUTO: prefer to use celeborn shuffle implementation, and fallback to use flink built-in shuffle implementation based on certain factors, e.g. availability of enough workers and quota; 3. NEVER: always use celeborn shuffle implementation, and fail fast when it it is concluded that fallback is required based on factors above. | 0.6.0 | | | celeborn.client.inputStream.creation.window | 16 | false | Window size that CelebornShuffleReader pre-creates CelebornInputStreams, for coalesced scenario where multiple Partitions are read | 0.5.1 | | | celeborn.client.mr.pushData.max | 32m | false | Max size for a push data sent from mr client. | 0.4.0 | | +| celeborn.client.partition.reader.checkpoint.enabled | false | false | Whether or not checkpoint reads when re-creating a partition reader. Setting to true minimizes the amount of unnecessary reads during partition read retries | 0.6.0 | | | celeborn.client.push.buffer.initial.size | 8k | false | | 0.3.0 | celeborn.push.buffer.initial.size | | celeborn.client.push.buffer.max.size | 64k | false | Max size of reducer partition buffer memory for shuffle hash writer. The pushed data will be buffered in memory before sending to Celeborn worker. For performance consideration keep this buffer size higher than 32K. Example: If reducer amount is 2000, buffer size is 64K, then each task will consume up to `64KiB * 2000 = 125MiB` heap memory. | 0.3.0 | celeborn.push.buffer.max.size | | celeborn.client.push.excludeWorkerOnFailure.enabled | false | false | Whether to enable shuffle client-side push exclude workers on failures. | 0.3.0 | | @@ -128,7 +129,6 @@ license: | | celeborn.identity.user-specific.userName | default | false | User name if celeborn.identity.provider is org.apache.celeborn.common.identity.DefaultIdentityProvider. | 0.6.0 | celeborn.quota.identity.user-specific.userName | | celeborn.master.endpoints | <localhost>:9097 | false | Endpoints of master nodes for celeborn clients to connect. Client uses resolver provided by celeborn.master.endpoints.resolver to resolve the master endpoints. By default Celeborn uses `org.apache.celeborn.common.client.StaticMasterEndpointResolver` which take static master endpoints as input. Allowed pattern: `:[,:]*`, e.g. `clb1:9097,clb2:9098,clb3:9099`. If the port is omitted, 9097 will be used. If the master endpoints are not static then users can pass custom resolver implementation to discover master endpoints actively using celeborn.master.endpoints.resolver. | 0.2.0 | | | celeborn.master.endpoints.resolver | org.apache.celeborn.common.client.StaticMasterEndpointResolver | false | Resolver class that can be used for discovering and updating the master endpoints. This allows users to provide a custom master endpoint resolver implementation. This is useful in environments where the master nodes might change due to scaling operations or infrastructure updates. Clients need to ensure that provided resolver class should be present in the classpath. | 0.5.2 | | -| celeborn.partition.reader.checkpoint.enabled | false | false | Whether or not checkpoint reads when re-creating a partition reader. Setting to true minimizes the amount of unnecessary reads during partition read retries | 0.6.0 | | | celeborn.quota.enabled | true | false | When Master side sets to true, the master will enable to check the quota via QuotaManager. When Client side sets to true, LifecycleManager will request Master side to check whether the current user has enough quota before registration of shuffle. Fallback to the default shuffle service when Master side checks that there is no enough quota for current user. | 0.2.0 | | | celeborn.quota.interruptShuffle.enabled | false | false | Whether to enable interrupt shuffle when quota exceeds. | 0.6.0 | | | celeborn.storage.availableTypes | HDD | false | Enabled storages. Available options: MEMORY,HDD,SSD,HDFS,S3. Note: HDD and SSD would be treated as identical. | 0.3.0 | celeborn.storage.activeTypes | From 6e88b60d05075db8faf19c15fbbc775f797dad0e Mon Sep 17 00:00:00 2001 From: Saurabh Dubey Date: Fri, 14 Mar 2025 09:06:52 +0530 Subject: [PATCH 14/14] review comments --- .../org/apache/celeborn/client/read/WorkerPartitionReader.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java b/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java index bd3a58170de..df6f3902ca9 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java +++ b/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java @@ -182,8 +182,8 @@ private void checkpoint() { @Override public ByteBuf next() throws IOException, InterruptedException { - checkException(); checkpoint(); + checkException(); if (chunkIndex <= endChunkIndex) { fetchChunks(); }