diff --git a/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightServer.java b/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightServer.java index ac761457f5..728edd0442 100644 --- a/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightServer.java +++ b/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightServer.java @@ -190,9 +190,18 @@ public static final class Builder { private int maxInboundMessageSize = MAX_GRPC_MESSAGE_SIZE; private int maxHeaderListSize = MAX_GRPC_MESSAGE_SIZE; private int backpressureThreshold = DEFAULT_BACKPRESSURE_THRESHOLD; + + /* + TODO: Remove certChain/key/mTlsCACert as instance vars and make them local vars in build() after they are + no longer being set in the builder methods. + */ + private File certChainFile; private InputStream certChain; + private File keyFile; private InputStream key; + private File mTlsCACertFile; private InputStream mTlsCACert; + private SslContext sslContext; private final List> interceptors; // Keep track of inserted interceptors @@ -213,6 +222,16 @@ public static final class Builder { /** Create the server for this builder. */ public FlightServer build() { + // Get TLS info in order if the server is being configured to use mTLS. + try { + prepareTlsSettings(); + } catch (IOException e) { + closeMTlsCACert(); + closeCertChain(); + closeKey(); + throw new RuntimeException("Could not create FlightServer with mTLS", e); + } + // Add the auth middleware if applicable. if (headerAuthenticator != CallHeaderAuthenticator.NO_OP) { this.middleware( @@ -442,11 +461,8 @@ private void closeMTlsCACert() { * @param key The private key to use. */ public Builder useTls(final File certChain, final File key) throws IOException { - closeCertChain(); - this.certChain = new FileInputStream(certChain); - - closeKey(); - this.key = new FileInputStream(key); + this.certChainFile = certChain; + this.keyFile = key; return this; } @@ -457,8 +473,8 @@ public Builder useTls(final File certChain, final File key) throws IOException { * @param mTlsCACert The CA certificate to use for verifying clients. */ public Builder useMTlsClientVerification(final File mTlsCACert) throws IOException { - closeMTlsCACert(); - this.mTlsCACert = new FileInputStream(mTlsCACert); + this.mTlsCACertFile = mTlsCACert; + return this; } @@ -468,6 +484,7 @@ public Builder useMTlsClientVerification(final File mTlsCACert) throws IOExcepti * @param certChain The certificate chain to use. * @param key The private key to use. */ + @Deprecated(forRemoval = true, since = "18.4.0") public Builder useTls(final InputStream certChain, final InputStream key) throws IOException { closeCertChain(); this.certChain = certChain; @@ -483,6 +500,7 @@ public Builder useTls(final InputStream certChain, final InputStream key) throws * * @param mTlsCACert The CA certificate to use for verifying clients. */ + @Deprecated(forRemoval = true, since = "18.4.0") public Builder useMTlsClientVerification(final InputStream mTlsCACert) throws IOException { closeMTlsCACert(); this.mTlsCACert = mTlsCACert; @@ -552,5 +570,20 @@ public Builder producer(FlightProducer producer) { this.producer = Preconditions.checkNotNull(producer); return this; } + + private void prepareTlsSettings() throws IOException { + if (this.mTlsCACertFile != null) { + closeMTlsCACert(); + this.mTlsCACert = new FileInputStream(this.mTlsCACertFile); + } + if (this.certChainFile != null) { + closeCertChain(); + this.certChain = new FileInputStream(this.certChainFile); + } + if (this.keyFile != null) { + closeKey(); + this.key = new FileInputStream(this.keyFile); + } + } } }