Skip to content

Commit

Permalink
Fix TLS initialization (#1426)
Browse files Browse the repository at this point in the history
  • Loading branch information
gaoran10 authored Aug 25, 2024
1 parent 7fe8245 commit 9817118
Show file tree
Hide file tree
Showing 11 changed files with 456 additions and 103 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,71 +22,56 @@
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import io.netty.handler.codec.mqtt.MqttDecoder;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.timeout.IdleStateHandler;
import io.streamnative.pulsar.handlers.mqtt.adapter.CombineAdapterHandler;
import io.streamnative.pulsar.handlers.mqtt.adapter.MqttAdapterDecoder;
import io.streamnative.pulsar.handlers.mqtt.adapter.MqttAdapterEncoder;
import io.streamnative.pulsar.handlers.mqtt.codec.MqttWebSocketCodec;
import io.streamnative.pulsar.handlers.mqtt.support.psk.PSKUtils;
import org.apache.pulsar.common.util.NettyServerSslContextBuilder;
import org.apache.pulsar.common.util.SslContextAutoRefreshBuilder;
import org.apache.pulsar.common.util.keystoretls.NettySSLContextAutoRefreshBuilder;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import lombok.extern.slf4j.Slf4j;
import org.apache.pulsar.common.util.PulsarSslConfiguration;
import org.apache.pulsar.common.util.PulsarSslFactory;

/**
* A channel initializer that initialize channels for MQTT protocol.
*/
@Slf4j
public class MQTTChannelInitializer extends ChannelInitializer<SocketChannel> {

private final MQTTServerConfiguration mqttConfig;
private final MQTTService mqttService;
private final boolean enableTls;
private final boolean enableTlsPsk;
private final boolean enableWs;
private final boolean tlsEnabledWithKeyStore;
private PulsarSslFactory sslFactory;

private SslContextAutoRefreshBuilder<SslContext> sslCtxRefresher;
private NettySSLContextAutoRefreshBuilder nettySSLContextAutoRefreshBuilder;

public MQTTChannelInitializer(MQTTService mqttService, boolean enableTls, boolean enableWs) {
this(mqttService, enableTls, false, enableWs);
public MQTTChannelInitializer(MQTTService mqttService, boolean enableTls, boolean enableWs,
ScheduledExecutorService sslContextRefresher) throws Exception {
this(mqttService, enableTls, false, enableWs, sslContextRefresher);
}

public MQTTChannelInitializer(MQTTService mqttService, boolean enableTls, boolean enableTlsPsk, boolean enableWs) {
public MQTTChannelInitializer(
MQTTService mqttService, boolean enableTls, boolean enableTlsPsk, boolean enableWs,
ScheduledExecutorService sslContextRefresher) throws Exception {
super();
this.mqttService = mqttService;
this.mqttConfig = mqttService.getServerConfiguration();
this.enableTls = enableTls;
this.enableTlsPsk = enableTlsPsk;
this.enableWs = enableWs;
this.tlsEnabledWithKeyStore = mqttConfig.isMqttTlsEnabledWithKeyStore();
if (this.enableTls) {
if (tlsEnabledWithKeyStore) {
nettySSLContextAutoRefreshBuilder = new NettySSLContextAutoRefreshBuilder(
mqttConfig.getMqttTlsProvider(),
mqttConfig.getMqttTlsKeyStoreType(),
mqttConfig.getMqttTlsKeyStore(),
mqttConfig.getMqttTlsKeyStorePassword(),
mqttConfig.isMqttTlsAllowInsecureConnection(),
mqttConfig.getMqttTlsTrustStoreType(),
mqttConfig.getMqttTlsTrustStore(),
mqttConfig.getMqttTlsTrustStorePassword(),
mqttConfig.isMqttTlsRequireTrustedClientCertOnConnect(),
mqttConfig.getMqttTlsCiphers(),
mqttConfig.getMqttTlsProtocols(),
mqttConfig.getMqttTlsCertRefreshCheckDurationSec());
} else {
sslCtxRefresher = new NettyServerSslContextBuilder(
null,
mqttConfig.isMqttTlsAllowInsecureConnection(),
mqttConfig.getMqttTlsTrustCertsFilePath(),
mqttConfig.getMqttTlsCertificateFilePath(),
mqttConfig.getMqttTlsKeyFilePath(),
mqttConfig.getMqttTlsCiphers(),
mqttConfig.getMqttTlsProtocols(),
mqttConfig.isMqttTlsRequireTrustedClientCertOnConnect(),
mqttConfig.getMqttTlsCertRefreshCheckDurationSec());
PulsarSslConfiguration sslConfiguration = buildSslConfiguration(mqttConfig);
this.sslFactory = (PulsarSslFactory) Class.forName(mqttConfig.getSslFactoryPlugin())
.getConstructor().newInstance();
this.sslFactory.initialize(sslConfiguration);
this.sslFactory.createInternalSslContext();
if (mqttConfig.getTlsCertRefreshCheckDurationSec() > 0) {
sslContextRefresher.scheduleWithFixedDelay(this::refreshSslContext,
mqttConfig.getTlsCertRefreshCheckDurationSec(),
mqttConfig.getTlsCertRefreshCheckDurationSec(), TimeUnit.SECONDS);
}
}
}
Expand All @@ -95,12 +80,7 @@ public MQTTChannelInitializer(MQTTService mqttService, boolean enableTls, boolea
public void initChannel(SocketChannel ch) throws Exception {
ch.pipeline().addFirst("idleStateHandler", new IdleStateHandler(0, 0, 120));
if (this.enableTls) {
if (this.tlsEnabledWithKeyStore) {
ch.pipeline().addLast(TLS_HANDLER,
new SslHandler(nettySSLContextAutoRefreshBuilder.get().createSSLEngine()));
} else {
ch.pipeline().addLast(TLS_HANDLER, sslCtxRefresher.get().newHandler(ch.alloc()));
}
ch.pipeline().addLast(TLS_HANDLER, new SslHandler(sslFactory.createServerSslEngine(ch.alloc())));
} else if (this.enableTlsPsk) {
ch.pipeline().addLast(TLS_HANDLER,
new SslHandler(PSKUtils.createServerEngine(ch, mqttService.getPskConfiguration())));
Expand Down Expand Up @@ -138,4 +118,36 @@ private void addWsHandler(ChannelPipeline pipeline) {
true, mqttConfig.getWebSocketMaxFrameSize()));
pipeline.addLast(Constants.HANDLER_MQTT_WEB_SOCKET_CODEC, new MqttWebSocketCodec());
}

protected PulsarSslConfiguration buildSslConfiguration(MQTTServerConfiguration config) {
return PulsarSslConfiguration.builder()
.tlsProvider(config.getMqttTlsProvider())
.tlsKeyStoreType(config.getMqttTlsKeyStoreType())
.tlsKeyStorePath(config.getMqttTlsKeyStore())
.tlsKeyStorePassword(config.getMqttTlsKeyStorePassword())
.tlsTrustStoreType(config.getMqttTlsTrustStoreType())
.tlsTrustStorePath(config.getMqttTlsTrustStore())
.tlsTrustStorePassword(config.getMqttTlsTrustStorePassword())
.tlsCiphers(config.getMqttTlsCiphers())
.tlsProtocols(config.getMqttTlsProtocols())
.tlsTrustCertsFilePath(config.getMqttTlsTrustCertsFilePath())
.tlsCertificateFilePath(config.getMqttTlsCertificateFilePath())
.tlsKeyFilePath(config.getMqttTlsKeyFilePath())
.allowInsecureConnection(config.isMqttTlsAllowInsecureConnection())
.requireTrustedClientCertOnConnect(config.isMqttTlsRequireTrustedClientCertOnConnect())
.tlsEnabledWithKeystore(config.isMqttTlsEnabledWithKeyStore())
.tlsCustomParams(config.getSslFactoryPluginParams())
.authData(null)
.serverMode(true)
.build();
}

protected void refreshSslContext() {
try {
this.sslFactory.update();
} catch (Exception e) {
log.error("Failed to refresh SSL context for mqtt channel.", e);
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,21 @@
import com.google.common.collect.ImmutableMap;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.socket.SocketChannel;
import io.netty.util.concurrent.DefaultThreadFactory;
import io.streamnative.pulsar.handlers.mqtt.proxy.MQTTProxyConfiguration;
import io.streamnative.pulsar.handlers.mqtt.proxy.MQTTProxyService;
import io.streamnative.pulsar.handlers.mqtt.utils.ConfigurationUtils;
import java.net.InetSocketAddress;
import java.util.Map;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import org.apache.pulsar.broker.ServiceConfiguration;
import org.apache.pulsar.broker.ServiceConfigurationUtils;
import org.apache.pulsar.broker.protocol.ProtocolHandler;
import org.apache.pulsar.broker.service.BrokerService;

/**
* MQTT Protocol Handler load and run by Pulsar Service.
*/
Expand All @@ -56,6 +60,8 @@ public class MQTTProtocolHandler implements ProtocolHandler {
@Getter
private MQTTService mqttService;

private ScheduledExecutorService sslContextRefresher;

@Override
public String protocolName() {
return PROTOCOL_NAME;
Expand Down Expand Up @@ -112,6 +118,9 @@ public Map<InetSocketAddress, ChannelInitializer<SocketChannel>> newChannelIniti
checkArgument(mqttConfig.getMqttListeners() != null);
checkArgument(brokerService != null);

this.sslContextRefresher = Executors.newSingleThreadScheduledExecutor(
new DefaultThreadFactory("mop-ssl-context-refresher"));

String listeners = mqttConfig.getMqttListeners();
String[] parts = listeners.split(LISTENER_DEL);
try {
Expand All @@ -122,27 +131,28 @@ public Map<InetSocketAddress, ChannelInitializer<SocketChannel>> newChannelIniti
if (listener.startsWith(PLAINTEXT_PREFIX)) {
builder.put(
new InetSocketAddress(brokerService.pulsar().getBindAddress(), getListenerPort(listener)),
new MQTTChannelInitializer(mqttService, false, false));
new MQTTChannelInitializer(mqttService, false, false, sslContextRefresher));

} else if (listener.startsWith(SSL_PREFIX)) {
builder.put(
new InetSocketAddress(brokerService.pulsar().getBindAddress(), getListenerPort(listener)),
new MQTTChannelInitializer(mqttService, true, false));
new MQTTChannelInitializer(mqttService, true, false, sslContextRefresher));

} else if (listener.startsWith(SSL_PSK_PREFIX) && mqttConfig.isMqttTlsPskEnabled()) {
builder.put(
new InetSocketAddress(brokerService.pulsar().getBindAddress(), getListenerPort(listener)),
new MQTTChannelInitializer(mqttService, false, true, false));
new MQTTChannelInitializer(
mqttService, false, true, false, sslContextRefresher));

} else if (listener.startsWith(WS_PLAINTEXT_PREFIX)) {
builder.put(
new InetSocketAddress(brokerService.pulsar().getBindAddress(), getListenerPort(listener)),
new MQTTChannelInitializer(mqttService, false, true));
new MQTTChannelInitializer(mqttService, false, true, sslContextRefresher));

} else if (listener.startsWith(WS_SSL_PREFIX)) {
builder.put(
new InetSocketAddress(brokerService.pulsar().getBindAddress(), getListenerPort(listener)),
new MQTTChannelInitializer(mqttService, true, true));
new MQTTChannelInitializer(mqttService, true, true, sslContextRefresher));

} else {
log.error("MQTT listener {} not supported. supports {}, {} or {}",
Expand All @@ -159,6 +169,9 @@ public Map<InetSocketAddress, ChannelInitializer<SocketChannel>> newChannelIniti

@Override
public void close() {
if (sslContextRefresher != null) {
sslContextRefresher.shutdownNow();
}
if (proxyService != null) {
proxyService.close();
}
Expand Down
Loading

0 comments on commit 9817118

Please sign in to comment.