diff --git a/common/src/main/java/org/apache/celeborn/common/network/TransportContext.java b/common/src/main/java/org/apache/celeborn/common/network/TransportContext.java index e8a203be797..afbbcbfe23f 100644 --- a/common/src/main/java/org/apache/celeborn/common/network/TransportContext.java +++ b/common/src/main/java/org/apache/celeborn/common/network/TransportContext.java @@ -23,6 +23,7 @@ import javax.annotation.Nullable; +import io.netty.bootstrap.Bootstrap; import io.netty.channel.Channel; import io.netty.channel.ChannelDuplexHandler; import io.netty.channel.ChannelInboundHandlerAdapter; @@ -35,6 +36,7 @@ import org.slf4j.LoggerFactory; import org.apache.celeborn.common.metrics.source.AbstractSource; +import org.apache.celeborn.common.network.client.ReconnectHandler; import org.apache.celeborn.common.network.client.TransportClient; import org.apache.celeborn.common.network.client.TransportClientBootstrap; import org.apache.celeborn.common.network.client.TransportClientFactory; @@ -73,6 +75,7 @@ public class TransportContext implements Closeable { @Nullable private final SSLFactory sslFactory; private final boolean enableHeartbeat; private final AbstractSource source; + private ReconnectHandler reconnectHandler; private static final MessageEncoder ENCODER = MessageEncoder.INSTANCE; private static final SslMessageEncoder SSL_ENCODER = SslMessageEncoder.INSTANCE; @@ -173,20 +176,22 @@ public boolean sslEncryptionEnabled() { } public TransportChannelHandler initializePipeline( - SocketChannel channel, ChannelInboundHandlerAdapter decoder, boolean isClient) { - return initializePipeline(channel, decoder, msgHandler, isClient); + SocketChannel channel, ChannelInboundHandlerAdapter decoder, Bootstrap bootstrap) { + return initializePipeline(channel, decoder, msgHandler, true, bootstrap); } public TransportChannelHandler initializePipeline( - SocketChannel channel, BaseMessageHandler resolvedMsgHandler, boolean isClient) { - return initializePipeline(channel, new TransportFrameDecoder(), resolvedMsgHandler, isClient); + SocketChannel channel, BaseMessageHandler resolvedMsgHandler) { + return initializePipeline( + channel, new TransportFrameDecoder(), resolvedMsgHandler, false, null); } public TransportChannelHandler initializePipeline( SocketChannel channel, ChannelInboundHandlerAdapter decoder, BaseMessageHandler resolvedMsgHandler, - boolean isClient) { + boolean isClient, + Bootstrap bootstrap) { try { ChannelPipeline pipeline = channel.pipeline(); if (nettyLogger.getLoggingHandler() != null) { @@ -221,8 +226,12 @@ public TransportChannelHandler initializePipeline( "idleStateHandler", enableHeartbeat ? new IdleStateHandler(conf.connectionTimeoutMs() / 1000, 0, 0) - : new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000)) - .addLast("handler", channelHandler); + : new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000)); + if (isClient) { + reconnectHandler = new ReconnectHandler(conf, bootstrap); + pipeline.addLast(reconnectHandler); + } + pipeline.addLast("handler", channelHandler); return channelHandler; } catch (RuntimeException e) { logger.error("Error while initializing Netty pipeline", e); @@ -264,5 +273,8 @@ public void close() { if (sslFactory != null) { sslFactory.destroy(); } + if (reconnectHandler != null) { + reconnectHandler.stopReconnect(); + } } } diff --git a/common/src/main/java/org/apache/celeborn/common/network/client/ReconnectHandler.java b/common/src/main/java/org/apache/celeborn/common/network/client/ReconnectHandler.java new file mode 100644 index 00000000000..a31d55e3cad --- /dev/null +++ b/common/src/main/java/org/apache/celeborn/common/network/client/ReconnectHandler.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.common.network.client; + +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; + +import io.netty.bootstrap.Bootstrap; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.celeborn.common.network.util.TransportConf; + +@ChannelHandler.Sharable +public class ReconnectHandler extends ChannelInboundHandlerAdapter { + + private static final Logger LOG = LoggerFactory.getLogger(ReconnectHandler.class); + + private final int maxReconnectRetries; + private final int reconnectRetryWaitTimeMs; + private final Bootstrap bootstrap; + + private final AtomicBoolean stopped = new AtomicBoolean(false); + private final AtomicInteger reconnectRetries = new AtomicInteger(0); + + public ReconnectHandler(TransportConf conf, Bootstrap bootstrap) { + this.maxReconnectRetries = conf.maxReconnectRetries(); + this.reconnectRetryWaitTimeMs = conf.reconnectRetryWaitTimeMs(); + this.bootstrap = bootstrap; + } + + @Override + public void channelActive(ChannelHandlerContext context) throws Exception { + reconnectRetries.set(0); + super.channelActive(context); + } + + @Override + public void channelInactive(ChannelHandlerContext context) throws Exception { + if (stopped.get()) { + super.channelInactive(context); + } else { + scheduleReconnect(context); + } + } + + private void scheduleReconnect(ChannelHandlerContext context) throws Exception { + if (reconnectRetries.incrementAndGet() <= maxReconnectRetries) { + LOG.warn( + "Reconnect to {} {}/{} times.", + context.channel().remoteAddress(), + reconnectRetries, + maxReconnectRetries); + context + .channel() + .eventLoop() + .schedule( + () -> { + bootstrap + .connect() + .addListener( + (ChannelFuture future) -> { + if (future.isSuccess()) { + reconnectRetries.set(0); + } else { + scheduleReconnect(context); + } + }); + }, + reconnectRetryWaitTimeMs, + TimeUnit.MILLISECONDS); + } else { + super.channelInactive(context); + } + } + + public void stopReconnect() { + stopped.set(true); + } +} diff --git a/common/src/main/java/org/apache/celeborn/common/network/client/TransportClientFactory.java b/common/src/main/java/org/apache/celeborn/common/network/client/TransportClientFactory.java index fe261f03f5e..dd00f9d0a30 100644 --- a/common/src/main/java/org/apache/celeborn/common/network/client/TransportClientFactory.java +++ b/common/src/main/java/org/apache/celeborn/common/network/client/TransportClientFactory.java @@ -294,7 +294,8 @@ private TransportClient internalCreateClient( new ChannelInitializer() { @Override public void initChannel(SocketChannel ch) { - TransportChannelHandler clientHandler = context.initializePipeline(ch, decoder, true); + TransportChannelHandler clientHandler = + context.initializePipeline(ch, decoder, bootstrap); clientRef.set(clientHandler.getClient()); channelRef.set(ch); } diff --git a/common/src/main/java/org/apache/celeborn/common/network/server/TransportServer.java b/common/src/main/java/org/apache/celeborn/common/network/server/TransportServer.java index 1808a000ce5..c939038f1e5 100644 --- a/common/src/main/java/org/apache/celeborn/common/network/server/TransportServer.java +++ b/common/src/main/java/org/apache/celeborn/common/network/server/TransportServer.java @@ -141,7 +141,7 @@ protected void initChannel(SocketChannel ch) { "Adding bootstrap to TransportServer {}.", bootstrap.getClass().getName()); baseHandler = bootstrap.doBootstrap(ch, baseHandler); } - context.initializePipeline(ch, baseHandler, false); + context.initializePipeline(ch, baseHandler); } }); } diff --git a/common/src/main/java/org/apache/celeborn/common/network/util/TransportConf.java b/common/src/main/java/org/apache/celeborn/common/network/util/TransportConf.java index effafee4355..b5b02d5e454 100644 --- a/common/src/main/java/org/apache/celeborn/common/network/util/TransportConf.java +++ b/common/src/main/java/org/apache/celeborn/common/network/util/TransportConf.java @@ -119,6 +119,22 @@ public int ioRetryWaitTimeMs() { return celebornConf.networkIoRetryWaitMs(module); } + /** + * Max number of times we will try to reconnect per request. If set to 0, we will not do any + * retries. + */ + public int maxReconnectRetries() { + return celebornConf.networkReconnectMaxRetries(module); + } + + /** + * Time (in milliseconds) that we will wait in order to perform a retry after reconnection fails. + * Only relevant if maxReconnectRetries > 0. + */ + public int reconnectRetryWaitTimeMs() { + return celebornConf.networkReconnectRetryWaitMs(module); + } + /** * Minimum size of a block that we should start using memory map rather than reading in through * normal IO operations. This prevents Celeborn from memory mapping very small blocks. In general, diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala index c4b8ffb54a2..a73c6e8da2e 100644 --- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala +++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala @@ -594,6 +594,14 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se getTransportConfTimeAsMs(module, NETWORK_IO_RETRY_WAIT).toInt } + def networkReconnectMaxRetries(module: String): Int = { + getTransportConfInt(module, NETWORK_RECONNECT_MAX_RETRIES) + } + + def networkReconnectRetryWaitMs(module: String): Int = { + getTransportConfTimeAsMs(module, NETWORK_RECONNECT_RETRY_WAIT).toInt + } + def networkIoMemoryMapBytes(module: String): Int = { getTransportConfSizeAsBytes(module, NETWORK_IO_STORAGE_MEMORY_MAP_THRESHOLD).toInt } @@ -2179,6 +2187,37 @@ object CelebornConf extends Logging { .timeConf(TimeUnit.MILLISECONDS) .createWithDefaultString("5s") + val NETWORK_RECONNECT_MAX_RETRIES: ConfigEntry[Int] = + buildConf("celeborn..reconnect.maxRetries") + .categories("network") + .version("0.6.0") + .doc( + "Max number of times we will try to reconnect per request. " + + "If set to 0, we will not do any retries. " + + s"If setting to `${TransportModuleConstants.DATA_MODULE}`, " + + s"it works for shuffle client push and fetch data. " + + s"If setting to `${TransportModuleConstants.REPLICATE_MODULE}`, " + + s"it works for replicate client of worker replicating data to peer worker. " + + s"If setting to `${TransportModuleConstants.PUSH_MODULE}`, " + + s"it works for Flink shuffle client push data.") + .intConf + .createWithDefault(0) + + val NETWORK_RECONNECT_RETRY_WAIT: ConfigEntry[Long] = + buildConf("celeborn..reconnect.retryWait") + .categories("network") + .version("0.6.0") + .doc("Time that we will wait in order to perform a retry after reconnection fails. " + + "Only relevant if maxReconnectRetries > 0. " + + s"If setting to `${TransportModuleConstants.DATA_MODULE}`, " + + s"it works for shuffle client push and fetch data. " + + s"If setting to `${TransportModuleConstants.REPLICATE_MODULE}`, " + + s"it works for replicate client of worker replicating data to peer worker. " + + s"If setting to `${TransportModuleConstants.PUSH_MODULE}`, " + + s"it works for Flink shuffle client push data.") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefaultString("5s") + val NETWORK_IO_LAZY_FD: ConfigEntry[Boolean] = buildConf("celeborn..io.lazyFD") .categories("network") diff --git a/common/src/test/java/org/apache/celeborn/common/network/TransportClientFactorySuiteJ.java b/common/src/test/java/org/apache/celeborn/common/network/TransportClientFactorySuiteJ.java index 7ec7190aa06..a61bebc6e02 100644 --- a/common/src/test/java/org/apache/celeborn/common/network/TransportClientFactorySuiteJ.java +++ b/common/src/test/java/org/apache/celeborn/common/network/TransportClientFactorySuiteJ.java @@ -26,8 +26,12 @@ import java.io.IOException; import java.util.*; +import java.util.concurrent.ExecutionException; import java.util.concurrent.atomic.AtomicInteger; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -260,4 +264,39 @@ public void testRetryCreateClient() throws IOException, InterruptedException { factory.retryCreateClient("xxx", 10, 1, TransportFrameDecoder::new); Assert.assertEquals(transportClient, client); } + + @Test + public void testChannelInactiveReconnect() + throws IOException, InterruptedException, ExecutionException { + verifyChannelInactiveReconnect(false); + verifyChannelInactiveReconnect(true); + } + + private void verifyChannelInactiveReconnect(boolean reconnectEnabled) + throws IOException, InterruptedException, ExecutionException { + CelebornConf _conf = new CelebornConf(); + if (reconnectEnabled) { + _conf.set(String.format("celeborn.%s.reconnect.maxRetries", TEST_MODULE), "3"); + } + TransportConf conf = new TransportConf(TEST_MODULE, _conf); + try (TransportContext ctx = new TransportContext(conf, new BaseMessageHandler(), true); + TransportClientFactory factory = ctx.createClientFactory()) { + Channel channel = factory.createClient(getLocalHost(), server1.getPort()).getChannel(); + TestReconnectHandler handler = new TestReconnectHandler(); + channel.pipeline().addLast(handler); + channel.disconnect().get(); + Thread.sleep(10000); + assertEquals(reconnectEnabled, handler.active); + } + } + + static class TestReconnectHandler extends ChannelInboundHandlerAdapter { + + private boolean active = true; + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + active = false; + } + } } diff --git a/docs/configuration/network.md b/docs/configuration/network.md index 199a9328fb7..b894553654c 100644 --- a/docs/configuration/network.md +++ b/docs/configuration/network.md @@ -39,6 +39,8 @@ license: | | celeborn.<module>.io.serverThreads | 0 | false | Number of threads used in the server thread pool. Default to 0, which is 2x#cores. If setting to `rpc_app`, works for shuffle client. If setting to `rpc_service`, works for master or worker. If setting to `push`, it works for worker receiving push data. If setting to `replicate`, it works for replicate server of worker replicating data to peer worker. If setting to `fetch`, it works for worker fetch server. | | | | celeborn.<module>.push.timeoutCheck.interval | 5s | false | Interval for checking push data timeout. If setting to `data`, it works for shuffle client push data. If setting to `push`, it works for Flink shuffle client push data. If setting to `replicate`, it works for replicate client of worker replicating data to peer worker. | 0.3.0 | | | celeborn.<module>.push.timeoutCheck.threads | 4 | false | Threads num for checking push data timeout. If setting to `data`, it works for shuffle client push data. If setting to `push`, it works for Flink shuffle client push data. If setting to `replicate`, it works for replicate client of worker replicating data to peer worker. | 0.3.0 | | +| celeborn.<module>.reconnect.maxRetries | 0 | false | Max number of times we will try to reconnect per request. If set to 0, we will not do any retries. If setting to `data`, it works for shuffle client push and fetch data. If setting to `replicate`, it works for replicate client of worker replicating data to peer worker. If setting to `push`, it works for Flink shuffle client push data. | 0.6.0 | | +| celeborn.<module>.reconnect.retryWait | 5s | false | Time that we will wait in order to perform a retry after reconnection fails. Only relevant if maxReconnectRetries > 0. If setting to `data`, it works for shuffle client push and fetch data. If setting to `replicate`, it works for replicate client of worker replicating data to peer worker. If setting to `push`, it works for Flink shuffle client push data. | 0.6.0 | | | celeborn.<role>.rpc.dispatcher.threads | <value of celeborn.rpc.dispatcher.threads> | false | Threads number of message dispatcher event loop for roles | | | | celeborn.io.maxDefaultNettyThreads | 64 | false | Max default netty threads | 0.3.2 | | | celeborn.network.advertise.preferIpAddress | <value of celeborn.network.bind.preferIpAddress> | false | When `true`, prefer to use IP address, otherwise FQDN for advertise address. | 0.6.0 | |