Skip to content

Commit 36b4895

Browse files
authored
Fix incorrect callback caching (#541)
Motivation: The certificate override callback incorrectly cached results on the SSLContext, instead of the SSLConnection. The result is that it would only fire once for a given SSLContext, instead of on every connection, incorrectly caching the result indefinitely. Modifications: Move the callback state to SSLConnection Write tests to validate that the callback is called the right number of times. Result: Better behaviour.
1 parent 9173d85 commit 36b4895

File tree

5 files changed

+135
-6
lines changed

5 files changed

+135
-6
lines changed

Sources/NIOSSL/NIOSSLHandler.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,7 @@ public class NIOSSLHandler: ChannelInboundHandler, ChannelOutboundHandler, Remov
411411
}
412412

413413
// If there's a failed custom context operation, we fire both errors.
414-
if let customContextError = self.connection.parentContext.customContextManager?.loadContextError {
414+
if let customContextError = self.connection.customContextManager?.loadContextError {
415415
context.fireErrorCaught(customContextError)
416416
}
417417

Sources/NIOSSL/SSLCallbacks.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ extension CustomContextManager {
342342
// Ensure we execute any completion on the next event loop tick
343343
// This ensures that we suspend before calling resume
344344
eventLoop.assumeIsolated().execute {
345-
connection.parentContext.customContextManager?.state = .complete(result)
345+
connection.customContextManager?.state = .complete(result)
346346
connection.parentHandler?.resumeHandshake()
347347
}
348348
}

Sources/NIOSSL/SSLConnection.swift

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ internal final class SSLConnection {
5858
private var verificationCallback: NIOSSLVerificationCallback?
5959
internal var customVerificationManager: CustomVerifyManager?
6060
internal var customPrivateKeyResult: Result<ByteBuffer, Error>?
61+
internal var customContextManager: CustomContextManager?
6162

6263
/// Whether certificate hostnames should be validated.
6364
var validateHostnames: Bool {
@@ -71,6 +72,10 @@ internal final class SSLConnection {
7172
self.ssl = ownedSSL
7273
self.parentContext = parentContext
7374

75+
if let customContextCallback = parentContext.configuration.sslContextCallback {
76+
self.customContextManager = CustomContextManager(callback: customContextCallback)
77+
}
78+
7479
// We pass the SSL object an unowned reference to this object.
7580
let pointerToSelf = Unmanaged.passUnretained(self).toOpaque()
7681
CNIOBoringSSL_SSL_set_ex_data(self.ssl, sslConnectionExDataIndex, pointerToSelf)

Sources/NIOSSL/SSLContext.swift

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ private func sslContextCallback(ssl: OpaquePointer?, arg: UnsafeMutableRawPointe
250250
)
251251
}
252252

253-
let parentSwiftContext = NIOSSLContext.lookupFromRawContext(ssl: ssl)
253+
let parentSwiftContext = SSLConnection.loadConnectionFromSSL(ssl)
254254

255255
// This is a safe force unwrap as this callback is only register directly after setting the manager
256256
var contextManager = parentSwiftContext.customContextManager!
@@ -295,7 +295,6 @@ public final class NIOSSLContext {
295295
fileprivate let sslContext: OpaquePointer
296296
private let callbackManager: CallbackManagerProtocol?
297297
private var keyLogManager: KeyLogCallbackManager?
298-
internal var customContextManager: CustomContextManager?
299298
internal var pskClientConfigurationCallback: _NIOPSKClientIdentityProvider?
300299
internal var pskServerConfigurationCallback: _NIOPSKServerIdentityProvider?
301300
internal let configuration: TLSConfiguration
@@ -372,8 +371,8 @@ public final class NIOSSLContext {
372371
}
373372

374373
// Set the SSL Context Configuration callback.
375-
if let sslContextConfigurationCallback = configuration.sslContextCallback {
376-
self.customContextManager = CustomContextManager(callback: sslContextConfigurationCallback)
374+
// The state is managed on the connection.
375+
if configuration.sslContextCallback != nil {
377376
CNIOBoringSSL_SSL_CTX_set_cert_cb(context, sslContextCallback, nil)
378377
}
379378

Tests/NIOSSLTests/TLSConfigurationTest.swift

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,24 @@ class TLSConfigurationTest: XCTestCase {
242242
) throws {
243243
let clientContext = try assertNoThrowWithValue(NIOSSLContext(configuration: clientConfig))
244244
let serverContext = try assertNoThrowWithValue(NIOSSLContext(configuration: serverConfig))
245+
try self.assertHandshakeSucceededInMemory(
246+
withClientContext: clientContext,
247+
andServerContext: serverContext,
248+
file: file,
249+
line: line
250+
)
251+
}
245252

253+
/// Performs a connection in memory and validates that the handshake was successful.
254+
///
255+
/// - NOTE: This function should only be used when you know that there is no custom verification
256+
/// callback in use, otherwise it will not be thread-safe.
257+
func assertHandshakeSucceededInMemory(
258+
withClientContext clientContext: NIOSSLContext,
259+
andServerContext serverContext: NIOSSLContext,
260+
file: StaticString = #filePath,
261+
line: UInt = #line
262+
) throws {
246263
let serverChannel = EmbeddedChannel()
247264
let clientChannel = EmbeddedChannel()
248265

@@ -294,7 +311,23 @@ class TLSConfigurationTest: XCTestCase {
294311
file: file,
295312
line: line
296313
)
314+
try self.assertHandshakeSucceededEventLoop(
315+
withClientContext: clientContext,
316+
andServerContext: serverContext,
317+
file: file,
318+
line: line
319+
)
320+
}
297321

322+
/// Performs a connection using a real event loop and validates that the handshake was successful.
323+
///
324+
/// This function is thread-safe in the presence of custom verification callbacks.
325+
func assertHandshakeSucceededEventLoop(
326+
withClientContext clientContext: NIOSSLContext,
327+
andServerContext serverContext: NIOSSLContext,
328+
file: StaticString = #filePath,
329+
line: UInt = #line
330+
) throws {
298331
let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
299332
defer {
300333
XCTAssertNoThrow(try group.syncShutdownGracefully())
@@ -358,6 +391,31 @@ class TLSConfigurationTest: XCTestCase {
358391
#endif
359392
}
360393

394+
func assertHandshakeSucceeded(
395+
withClientContext clientContext: NIOSSLContext,
396+
andServerContext serverContext: NIOSSLContext,
397+
file: StaticString = #filePath,
398+
line: UInt = #line
399+
) throws {
400+
// The only use of a custom callback is on Darwin...
401+
#if os(Linux)
402+
return try self.assertHandshakeSucceededInMemory(
403+
withClientContext: clientContext,
404+
andServerContext: serverContext,
405+
file: file,
406+
line: line
407+
)
408+
409+
#else
410+
return try self.assertHandshakeSucceededEventLoop(
411+
withClientContext: clientContext,
412+
andServerContext: serverContext,
413+
file: file,
414+
line: line
415+
)
416+
#endif
417+
}
418+
361419
func setupTLSLeafandClientIdentitiesFromCustomCARoot() throws -> (
362420
leafCert: NIOSSLCertificate, leafKey: NIOSSLPrivateKey,
363421
clientCert: NIOSSLCertificate, clientKey: NIOSSLPrivateKey
@@ -2062,6 +2120,73 @@ class TLSConfigurationTest: XCTestCase {
20622120
errorTextContains: "TLSV1_ALERT_INTERNAL_ERROR"
20632121
)
20642122
}
2123+
2124+
func testClientSideCertSelection_eachConnectionSelectsAgain() throws {
2125+
let callbackCount = NIOLockedValueBox(0)
2126+
var clientConfig = TLSConfiguration.makeClientConfiguration()
2127+
clientConfig.certificateVerification = .noHostnameVerification
2128+
clientConfig.trustRoots = .certificates([TLSConfigurationTest.cert1])
2129+
clientConfig.sslContextCallback = { _, promise in
2130+
callbackCount.withLockedValue { $0 += 1 }
2131+
2132+
var `override` = NIOSSLContextConfigurationOverride()
2133+
override.certificateChain = [.certificate(TLSConfigurationTest.cert2)]
2134+
override.privateKey = .privateKey(TLSConfigurationTest.key2)
2135+
promise.succeed(override)
2136+
}
2137+
2138+
var serverConfig = TLSConfiguration.makeServerConfiguration(
2139+
certificateChain: [.certificate(TLSConfigurationTest.cert1)],
2140+
privateKey: .privateKey(TLSConfigurationTest.key1)
2141+
)
2142+
serverConfig.certificateVerification = .noHostnameVerification
2143+
serverConfig.trustRoots = .certificates([TLSConfigurationTest.cert2])
2144+
2145+
let clientContext = try assertNoThrowWithValue(
2146+
NIOSSLContext(configuration: clientConfig)
2147+
)
2148+
let serverContext = try assertNoThrowWithValue(
2149+
NIOSSLContext(configuration: serverConfig)
2150+
)
2151+
2152+
for _ in 0..<5 {
2153+
try assertHandshakeSucceeded(withClientContext: clientContext, andServerContext: serverContext)
2154+
}
2155+
2156+
XCTAssertEqual(callbackCount.withLockedValue { $0 }, 5)
2157+
}
2158+
2159+
func testServerSideCertSelection_eachConnectionSelectsAgain() throws {
2160+
let callbackCount = NIOLockedValueBox(0)
2161+
var clientConfig = TLSConfiguration.makeClientConfiguration()
2162+
clientConfig.certificateVerification = .noHostnameVerification
2163+
clientConfig.trustRoots = .certificates([TLSConfigurationTest.cert1])
2164+
2165+
var serverConfig = TLSConfiguration.makeServerConfiguration(
2166+
certificateChain: [.certificate(TLSConfigurationTest.cert2)],
2167+
privateKey: .privateKey(TLSConfigurationTest.key2)
2168+
)
2169+
serverConfig.sslContextCallback = { _, promise in
2170+
var `override` = NIOSSLContextConfigurationOverride()
2171+
override.certificateChain = [.certificate(TLSConfigurationTest.cert1)]
2172+
override.privateKey = .privateKey(TLSConfigurationTest.key1)
2173+
callbackCount.withLockedValue { $0 += 1 }
2174+
promise.succeed(override)
2175+
}
2176+
2177+
let clientContext = try assertNoThrowWithValue(
2178+
NIOSSLContext(configuration: clientConfig)
2179+
)
2180+
let serverContext = try assertNoThrowWithValue(
2181+
NIOSSLContext(configuration: serverConfig)
2182+
)
2183+
2184+
for _ in 0..<5 {
2185+
try assertHandshakeSucceeded(withClientContext: clientContext, andServerContext: serverContext)
2186+
}
2187+
2188+
XCTAssertEqual(callbackCount.withLockedValue { $0 }, 5)
2189+
}
20652190
}
20662191

20672192
extension EmbeddedChannel {

0 commit comments

Comments
 (0)