Skip to content

Commit 3ae7ed8

Browse files
committed
[CELEBORN-1912] ReadClientHandler should support reconnection for ChannelInboundHandler#channelInactive
1 parent c1fb94d commit 3ae7ed8

File tree

8 files changed

+218
-9
lines changed

8 files changed

+218
-9
lines changed

common/src/main/java/org/apache/celeborn/common/network/TransportContext.java

+19-7
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
import javax.annotation.Nullable;
2525

26+
import io.netty.bootstrap.Bootstrap;
2627
import io.netty.channel.Channel;
2728
import io.netty.channel.ChannelDuplexHandler;
2829
import io.netty.channel.ChannelInboundHandlerAdapter;
@@ -35,6 +36,7 @@
3536
import org.slf4j.LoggerFactory;
3637

3738
import org.apache.celeborn.common.metrics.source.AbstractSource;
39+
import org.apache.celeborn.common.network.client.ReconnectHandler;
3840
import org.apache.celeborn.common.network.client.TransportClient;
3941
import org.apache.celeborn.common.network.client.TransportClientBootstrap;
4042
import org.apache.celeborn.common.network.client.TransportClientFactory;
@@ -73,6 +75,7 @@ public class TransportContext implements Closeable {
7375
@Nullable private final SSLFactory sslFactory;
7476
private final boolean enableHeartbeat;
7577
private final AbstractSource source;
78+
private ReconnectHandler reconnectHandler;
7679

7780
private static final MessageEncoder ENCODER = MessageEncoder.INSTANCE;
7881
private static final SslMessageEncoder SSL_ENCODER = SslMessageEncoder.INSTANCE;
@@ -173,20 +176,22 @@ public boolean sslEncryptionEnabled() {
173176
}
174177

175178
public TransportChannelHandler initializePipeline(
176-
SocketChannel channel, ChannelInboundHandlerAdapter decoder, boolean isClient) {
177-
return initializePipeline(channel, decoder, msgHandler, isClient);
179+
SocketChannel channel, ChannelInboundHandlerAdapter decoder, Bootstrap bootstrap) {
180+
return initializePipeline(channel, decoder, msgHandler, true, bootstrap);
178181
}
179182

180183
public TransportChannelHandler initializePipeline(
181-
SocketChannel channel, BaseMessageHandler resolvedMsgHandler, boolean isClient) {
182-
return initializePipeline(channel, new TransportFrameDecoder(), resolvedMsgHandler, isClient);
184+
SocketChannel channel, BaseMessageHandler resolvedMsgHandler) {
185+
return initializePipeline(
186+
channel, new TransportFrameDecoder(), resolvedMsgHandler, false, null);
183187
}
184188

185189
public TransportChannelHandler initializePipeline(
186190
SocketChannel channel,
187191
ChannelInboundHandlerAdapter decoder,
188192
BaseMessageHandler resolvedMsgHandler,
189-
boolean isClient) {
193+
boolean isClient,
194+
Bootstrap bootstrap) {
190195
try {
191196
ChannelPipeline pipeline = channel.pipeline();
192197
if (nettyLogger.getLoggingHandler() != null) {
@@ -221,8 +226,12 @@ public TransportChannelHandler initializePipeline(
221226
"idleStateHandler",
222227
enableHeartbeat
223228
? new IdleStateHandler(conf.connectionTimeoutMs() / 1000, 0, 0)
224-
: new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000))
225-
.addLast("handler", channelHandler);
229+
: new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000));
230+
if (isClient) {
231+
reconnectHandler = new ReconnectHandler(conf, bootstrap);
232+
pipeline.addLast(reconnectHandler);
233+
}
234+
pipeline.addLast("handler", channelHandler);
226235
return channelHandler;
227236
} catch (RuntimeException e) {
228237
logger.error("Error while initializing Netty pipeline", e);
@@ -264,5 +273,8 @@ public void close() {
264273
if (sslFactory != null) {
265274
sslFactory.destroy();
266275
}
276+
if (reconnectHandler != null) {
277+
reconnectHandler.stopReconnect();
278+
}
267279
}
268280
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.celeborn.common.network.client;
19+
20+
import java.util.concurrent.TimeUnit;
21+
import java.util.concurrent.atomic.AtomicBoolean;
22+
import java.util.concurrent.atomic.AtomicInteger;
23+
24+
import io.netty.bootstrap.Bootstrap;
25+
import io.netty.channel.ChannelFuture;
26+
import io.netty.channel.ChannelHandler;
27+
import io.netty.channel.ChannelHandlerContext;
28+
import io.netty.channel.ChannelInboundHandlerAdapter;
29+
import org.slf4j.Logger;
30+
import org.slf4j.LoggerFactory;
31+
32+
import org.apache.celeborn.common.network.util.TransportConf;
33+
34+
@ChannelHandler.Sharable
35+
public class ReconnectHandler extends ChannelInboundHandlerAdapter {
36+
37+
private static final Logger LOG = LoggerFactory.getLogger(ReconnectHandler.class);
38+
39+
private final int maxReconnectRetries;
40+
private final int reconnectRetryWaitTimeMs;
41+
private final Bootstrap bootstrap;
42+
43+
private final AtomicBoolean stopped = new AtomicBoolean(false);
44+
private final AtomicInteger reconnectRetries = new AtomicInteger(0);
45+
46+
public ReconnectHandler(TransportConf conf, Bootstrap bootstrap) {
47+
this.maxReconnectRetries = conf.maxReconnectRetries();
48+
this.reconnectRetryWaitTimeMs = conf.reconnectRetryWaitTimeMs();
49+
this.bootstrap = bootstrap;
50+
}
51+
52+
@Override
53+
public void channelActive(ChannelHandlerContext context) throws Exception {
54+
reconnectRetries.set(0);
55+
super.channelActive(context);
56+
}
57+
58+
@Override
59+
public void channelInactive(ChannelHandlerContext context) throws Exception {
60+
if (stopped.get()) {
61+
super.channelInactive(context);
62+
} else {
63+
scheduleReconnect(context);
64+
}
65+
}
66+
67+
private void scheduleReconnect(ChannelHandlerContext context) throws Exception {
68+
if (reconnectRetries.incrementAndGet() <= maxReconnectRetries) {
69+
LOG.warn(
70+
"Reconnect to {} {}/{} times.",
71+
context.channel().remoteAddress(),
72+
reconnectRetries,
73+
maxReconnectRetries);
74+
context
75+
.channel()
76+
.eventLoop()
77+
.schedule(
78+
() -> {
79+
bootstrap
80+
.connect()
81+
.addListener(
82+
(ChannelFuture future) -> {
83+
if (future.isSuccess()) {
84+
reconnectRetries.set(0);
85+
} else {
86+
scheduleReconnect(context);
87+
}
88+
});
89+
},
90+
reconnectRetryWaitTimeMs,
91+
TimeUnit.MILLISECONDS);
92+
} else {
93+
super.channelInactive(context);
94+
}
95+
}
96+
97+
public void stopReconnect() {
98+
stopped.set(true);
99+
}
100+
}

common/src/main/java/org/apache/celeborn/common/network/client/TransportClientFactory.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,8 @@ private TransportClient internalCreateClient(
294294
new ChannelInitializer<SocketChannel>() {
295295
@Override
296296
public void initChannel(SocketChannel ch) {
297-
TransportChannelHandler clientHandler = context.initializePipeline(ch, decoder, true);
297+
TransportChannelHandler clientHandler =
298+
context.initializePipeline(ch, decoder, bootstrap);
298299
clientRef.set(clientHandler.getClient());
299300
channelRef.set(ch);
300301
}

common/src/main/java/org/apache/celeborn/common/network/server/TransportServer.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ protected void initChannel(SocketChannel ch) {
141141
"Adding bootstrap to TransportServer {}.", bootstrap.getClass().getName());
142142
baseHandler = bootstrap.doBootstrap(ch, baseHandler);
143143
}
144-
context.initializePipeline(ch, baseHandler, false);
144+
context.initializePipeline(ch, baseHandler);
145145
}
146146
});
147147
}

common/src/main/java/org/apache/celeborn/common/network/util/TransportConf.java

+16
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,22 @@ public int ioRetryWaitTimeMs() {
119119
return celebornConf.networkIoRetryWaitMs(module);
120120
}
121121

122+
/**
123+
* Max number of times we will try to reconnect an inactive channel. If set to 0, we will not do any
124+
* retries.
125+
*/
126+
public int maxReconnectRetries() {
127+
return celebornConf.networkReconnectMaxRetries(module);
128+
}
129+
130+
/**
131+
* Time (in milliseconds) that we will wait before each reconnection attempt.
132+
* Only relevant if maxReconnectRetries > 0.
133+
*/
134+
public int reconnectRetryWaitTimeMs() {
135+
return celebornConf.networkReconnectRetryWaitMs(module);
136+
}
137+
122138
/**
123139
* Minimum size of a block that we should start using memory map rather than reading in through
124140
* normal IO operations. This prevents Celeborn from memory mapping very small blocks. In general,

common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala

+39
Original file line numberDiff line numberDiff line change
@@ -594,6 +594,14 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
594594
getTransportConfTimeAsMs(module, NETWORK_IO_RETRY_WAIT).toInt
595595
}
596596

597+
def networkReconnectMaxRetries(module: String): Int = {
598+
getTransportConfInt(module, NETWORK_RECONNECT_MAX_RETRIES)
599+
}
600+
601+
def networkReconnectRetryWaitMs(module: String): Int = {
602+
getTransportConfTimeAsMs(module, NETWORK_RECONNECT_RETRY_WAIT).toInt
603+
}
604+
597605
def networkIoMemoryMapBytes(module: String): Int = {
598606
getTransportConfSizeAsBytes(module, NETWORK_IO_STORAGE_MEMORY_MAP_THRESHOLD).toInt
599607
}
@@ -2179,6 +2187,37 @@ object CelebornConf extends Logging {
21792187
.timeConf(TimeUnit.MILLISECONDS)
21802188
.createWithDefaultString("5s")
21812189

2190+
val NETWORK_RECONNECT_MAX_RETRIES: ConfigEntry[Int] =
2191+
buildConf("celeborn.<module>.reconnect.maxRetries")
2192+
.categories("network")
2193+
.version("0.6.0")
2194+
.doc(
2195+
"Max number of times we will try to reconnect per request. " +
2196+
"If set to 0, we will not do any retries. " +
2197+
s"If setting <module> to `${TransportModuleConstants.DATA_MODULE}`, " +
2198+
s"it works for shuffle client push and fetch data. " +
2199+
s"If setting <module> to `${TransportModuleConstants.REPLICATE_MODULE}`, " +
2200+
s"it works for replicate client of worker replicating data to peer worker. " +
2201+
s"If setting <module> to `${TransportModuleConstants.PUSH_MODULE}`, " +
2202+
s"it works for Flink shuffle client push data.")
2203+
.intConf
2204+
.createWithDefault(0)
2205+
2206+
val NETWORK_RECONNECT_RETRY_WAIT: ConfigEntry[Long] =
2207+
buildConf("celeborn.<module>.reconnect.retryWait")
2208+
.categories("network")
2209+
.version("0.6.0")
2210+
.doc("Time that we will wait in order to perform a retry after reconnection fails. " +
2211+
"Only relevant if maxReconnectRetries > 0. " +
2212+
s"If setting <module> to `${TransportModuleConstants.DATA_MODULE}`, " +
2213+
s"it works for shuffle client push and fetch data. " +
2214+
s"If setting <module> to `${TransportModuleConstants.REPLICATE_MODULE}`, " +
2215+
s"it works for replicate client of worker replicating data to peer worker. " +
2216+
s"If setting <module> to `${TransportModuleConstants.PUSH_MODULE}`, " +
2217+
s"it works for Flink shuffle client push data.")
2218+
.timeConf(TimeUnit.MILLISECONDS)
2219+
.createWithDefaultString("5s")
2220+
21822221
val NETWORK_IO_LAZY_FD: ConfigEntry[Boolean] =
21832222
buildConf("celeborn.<module>.io.lazyFD")
21842223
.categories("network")

common/src/test/java/org/apache/celeborn/common/network/TransportClientFactorySuiteJ.java

+39
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,12 @@
2626

2727
import java.io.IOException;
2828
import java.util.*;
29+
import java.util.concurrent.ExecutionException;
2930
import java.util.concurrent.atomic.AtomicInteger;
3031

32+
import io.netty.channel.Channel;
33+
import io.netty.channel.ChannelHandlerContext;
34+
import io.netty.channel.ChannelInboundHandlerAdapter;
3135
import org.junit.After;
3236
import org.junit.Assert;
3337
import org.junit.Before;
@@ -260,4 +264,39 @@ public void testRetryCreateClient() throws IOException, InterruptedException {
260264
factory.retryCreateClient("xxx", 10, 1, TransportFrameDecoder::new);
261265
Assert.assertEquals(transportClient, client);
262266
}
267+
268+
@Test
269+
public void testChannelInactiveReconnect()
270+
throws IOException, InterruptedException, ExecutionException {
271+
verifyChannelInactiveReconnect(false);
272+
verifyChannelInactiveReconnect(true);
273+
}
274+
275+
private void verifyChannelInactiveReconnect(boolean reconnectEnabled)
276+
throws IOException, InterruptedException, ExecutionException {
277+
CelebornConf _conf = new CelebornConf();
278+
if (reconnectEnabled) {
279+
_conf.set(String.format("celeborn.%s.reconnect.maxRetries", TEST_MODULE), "3");
280+
}
281+
TransportConf conf = new TransportConf(TEST_MODULE, _conf);
282+
try (TransportContext ctx = new TransportContext(conf, new BaseMessageHandler(), true);
283+
TransportClientFactory factory = ctx.createClientFactory()) {
284+
Channel channel = factory.createClient(getLocalHost(), server1.getPort()).getChannel();
285+
TestReconnectHandler handler = new TestReconnectHandler();
286+
channel.pipeline().addLast(handler);
287+
channel.disconnect().get();
288+
Thread.sleep(10000);
289+
assertEquals(reconnectEnabled, handler.active);
290+
}
291+
}
292+
293+
static class TestReconnectHandler extends ChannelInboundHandlerAdapter {
294+
295+
private boolean active = true;
296+
297+
@Override
298+
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
299+
active = false;
300+
}
301+
}
263302
}

docs/configuration/network.md

+2
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ license: |
3939
| celeborn.&lt;module&gt;.io.serverThreads | 0 | false | Number of threads used in the server thread pool. Default to 0, which is 2x#cores. If setting <module> to `rpc_app`, works for shuffle client. If setting <module> to `rpc_service`, works for master or worker. If setting <module> to `push`, it works for worker receiving push data. If setting <module> to `replicate`, it works for replicate server of worker replicating data to peer worker. If setting <module> to `fetch`, it works for worker fetch server. | | |
4040
| celeborn.&lt;module&gt;.push.timeoutCheck.interval | 5s | false | Interval for checking push data timeout. If setting <module> to `data`, it works for shuffle client push data. If setting <module> to `push`, it works for Flink shuffle client push data. If setting <module> to `replicate`, it works for replicate client of worker replicating data to peer worker. | 0.3.0 | |
4141
| celeborn.&lt;module&gt;.push.timeoutCheck.threads | 4 | false | Threads num for checking push data timeout. If setting <module> to `data`, it works for shuffle client push data. If setting <module> to `push`, it works for Flink shuffle client push data. If setting <module> to `replicate`, it works for replicate client of worker replicating data to peer worker. | 0.3.0 | |
42+
| celeborn.&lt;module&gt;.reconnect.maxRetries | 0 | false | Max number of times we will try to reconnect per request. If set to 0, we will not do any retries. If setting <module> to `data`, it works for shuffle client push and fetch data. If setting <module> to `replicate`, it works for replicate client of worker replicating data to peer worker. If setting <module> to `push`, it works for Flink shuffle client push data. | 0.6.0 | |
43+
| celeborn.&lt;module&gt;.reconnect.retryWait | 5s | false | Time that we will wait in order to perform a retry after reconnection fails. Only relevant if maxReconnectRetries > 0. If setting <module> to `data`, it works for shuffle client push and fetch data. If setting <module> to `replicate`, it works for replicate client of worker replicating data to peer worker. If setting <module> to `push`, it works for Flink shuffle client push data. | 0.6.0 | |
4244
| celeborn.&lt;role&gt;.rpc.dispatcher.threads | &lt;value of celeborn.rpc.dispatcher.threads&gt; | false | Threads number of message dispatcher event loop for roles | | |
4345
| celeborn.io.maxDefaultNettyThreads | 64 | false | Max default netty threads | 0.3.2 | |
4446
| celeborn.network.advertise.preferIpAddress | &lt;value of celeborn.network.bind.preferIpAddress&gt; | false | When `true`, prefer to use IP address, otherwise FQDN for advertise address. | 0.6.0 | |

0 commit comments

Comments
 (0)