diff --git a/Sources/ConnectionPoolModule/ConnectionLease.swift b/Sources/ConnectionPoolModule/ConnectionLease.swift new file mode 100644 index 00000000..77591a58 --- /dev/null +++ b/Sources/ConnectionPoolModule/ConnectionLease.swift @@ -0,0 +1,17 @@ +public struct ConnectionLease: Sendable { + public var connection: Connection + + @usableFromInline + let _release: @Sendable (Connection) -> () + + @inlinable + public init(connection: Connection, release: @escaping @Sendable (Connection) -> Void) { + self.connection = connection + self._release = release + } + + @inlinable + public func release() { + self._release(self.connection) + } +} diff --git a/Sources/ConnectionPoolModule/ConnectionPool.swift b/Sources/ConnectionPoolModule/ConnectionPool.swift index b460b263..ee72337d 100644 --- a/Sources/ConnectionPoolModule/ConnectionPool.swift +++ b/Sources/ConnectionPoolModule/ConnectionPool.swift @@ -88,7 +88,7 @@ public protocol ConnectionRequestProtocol: Sendable { /// A function that is called with a connection or a /// `PoolError`. - func complete(with: Result) + func complete(with: Result, ConnectionPoolError>) } @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) @@ -402,8 +402,11 @@ public final class ConnectionPool< /*private*/ func runRequestAction(_ action: StateMachine.RequestAction) { switch action { case .leaseConnection(let requests, let connection): + let lease = ConnectionLease(connection: connection) { connection in + self.releaseConnection(connection) + } for request in requests { - request.complete(with: .success(connection)) + request.complete(with: .success(lease)) } case .failRequest(let request, let error): diff --git a/Sources/ConnectionPoolModule/ConnectionRequest.swift b/Sources/ConnectionPoolModule/ConnectionRequest.swift index 1d1c55da..d6654a27 100644 --- a/Sources/ConnectionPoolModule/ConnectionRequest.swift +++ b/Sources/ConnectionPoolModule/ConnectionRequest.swift @@ -5,18 +5,18 @@ public struct ConnectionRequest: ConnectionRequest public var id: ID @usableFromInline - private(set) var continuation: CheckedContinuation + private(set) var continuation: CheckedContinuation, any Error> @inlinable init( id: Int, - continuation: CheckedContinuation + continuation: CheckedContinuation, any Error> ) { self.id = id self.continuation = continuation } - public func complete(with result: Result) { + public func complete(with result: Result, ConnectionPoolError>) { self.continuation.resume(with: result) } } @@ -46,7 +46,7 @@ extension ConnectionPool where Request == ConnectionRequest { } @inlinable - public func leaseConnection() async throws -> Connection { + public func leaseConnection() async throws -> ConnectionLease { let requestID = requestIDGenerator.next() let connection = try await withTaskCancellationHandler { @@ -54,7 +54,7 @@ extension ConnectionPool where Request == ConnectionRequest { throw CancellationError() } - return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation, Error>) in let request = Request( id: requestID, continuation: continuation @@ -71,8 +71,8 @@ extension ConnectionPool where Request == ConnectionRequest { @inlinable public func withConnection(_ closure: (Connection) async throws -> Result) async throws -> Result { - let connection = try await self.leaseConnection() - defer { self.releaseConnection(connection) } - return try await closure(connection) + let lease = try await self.leaseConnection() + defer { lease.release() } + return try await closure(lease.connection) } } diff --git a/Sources/ConnectionPoolTestUtils/MockRequest.swift b/Sources/ConnectionPoolTestUtils/MockRequest.swift index 5e4e2fc0..3dd8b0fb 100644 --- a/Sources/ConnectionPoolTestUtils/MockRequest.swift +++ b/Sources/ConnectionPoolTestUtils/MockRequest.swift @@ -1,8 +1,6 @@ import _ConnectionPoolModule -public final class MockRequest: ConnectionRequestProtocol, Hashable, Sendable { - public typealias Connection = MockConnection - +public final class MockRequest: ConnectionRequestProtocol, Hashable, Sendable { public struct ID: Hashable, Sendable { var objectID: ObjectIdentifier @@ -11,7 +9,7 @@ public final class MockRequest: ConnectionRequestProtocol, Hashable, Sendable { } } - public init() {} + public init(connectionType: Connection.Type = Connection.self) {} public var id: ID { ID(self) } @@ -23,7 +21,7 @@ public final class MockRequest: ConnectionRequestProtocol, Hashable, Sendable { hasher.combine(self.id) } - public func complete(with: Result) { + public func complete(with: Result, ConnectionPoolError>) { } } diff --git a/Sources/PostgresNIO/Pool/PostgresClient.swift b/Sources/PostgresNIO/Pool/PostgresClient.swift index d54e34eb..0279be07 100644 --- a/Sources/PostgresNIO/Pool/PostgresClient.swift +++ b/Sources/PostgresNIO/Pool/PostgresClient.swift @@ -301,11 +301,11 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { /// - Returns: The closure's return value. @_disfavoredOverload public func withConnection(_ closure: (PostgresConnection) async throws -> Result) async throws -> Result { - let connection = try await self.leaseConnection() + let lease = try await self.leaseConnection() - defer { self.pool.releaseConnection(connection) } + defer { lease.release() } - return try await closure(connection) + return try await closure(lease.connection) } #if compiler(>=6.0) @@ -319,11 +319,11 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { // DO NOT FIX THE WHITESPACE IN THE NEXT LINE UNTIL 5.10 IS UNSUPPORTED // https://github.com/swiftlang/swift/issues/79285 _ closure: (PostgresConnection) async throws -> sending Result) async throws -> sending Result { - let connection = try await self.leaseConnection() + let lease = try await self.leaseConnection() - defer { self.pool.releaseConnection(connection) } + defer { lease.release() } - return try await closure(connection) + return try await closure(lease.connection) } /// Lease a connection, which is in an open transaction state, for the provided `closure`'s lifetime. @@ -404,7 +404,8 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { throw PSQLError(code: .tooManyParameters, query: query, file: file, line: line) } - let connection = try await self.leaseConnection() + let lease = try await self.leaseConnection() + let connection = lease.connection var logger = logger logger[postgresMetadataKey: .connectionID] = "\(connection.id)" @@ -419,12 +420,12 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { connection.channel.write(HandlerTask.extendedQuery(context), promise: nil) promise.futureResult.whenFailure { _ in - self.pool.releaseConnection(connection) + lease.release() } return try await promise.futureResult.map { $0.asyncSequence(onFinish: { - self.pool.releaseConnection(connection) + lease.release() }) }.get() } catch var error as PSQLError { @@ -446,7 +447,8 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { let logger = logger ?? Self.loggingDisabled do { - let connection = try await self.leaseConnection() + let lease = try await self.leaseConnection() + let connection = lease.connection let promise = connection.channel.eventLoop.makePromise(of: PSQLRowStream.self) let task = HandlerTask.executePreparedStatement(.init( @@ -460,11 +462,11 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { connection.channel.write(task, promise: nil) promise.futureResult.whenFailure { _ in - self.pool.releaseConnection(connection) + lease.release() } return try await promise.futureResult - .map { $0.asyncSequence(onFinish: { self.pool.releaseConnection(connection) }) } + .map { $0.asyncSequence(onFinish: { lease.release() }) } .get() .map { try preparedStatement.decodeRow($0) } } catch var error as PSQLError { @@ -504,7 +506,7 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { // MARK: - Private Methods - - private func leaseConnection() async throws -> PostgresConnection { + private func leaseConnection() async throws -> ConnectionLease { if !self.runningAtomic.load(ordering: .relaxed) { self.backgroundLogger.warning("Trying to lease connection from `PostgresClient`, but `PostgresClient.run()` hasn't been called yet.") } diff --git a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift index c745b4a0..c1ba89cb 100644 --- a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift +++ b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift @@ -39,15 +39,13 @@ final class ConnectionPoolTests: XCTestCase { do { for _ in 0..<1000 { async let connectionFuture = try await pool.leaseConnection() - var leasedConnection: MockConnection? + var connectionLease: ConnectionLease? XCTAssertEqual(factory.pendingConnectionAttemptsCount, 0) - leasedConnection = try await connectionFuture - XCTAssertNotNil(leasedConnection) - XCTAssert(createdConnection === leasedConnection) + connectionLease = try await connectionFuture + XCTAssertNotNil(connectionLease) + XCTAssert(createdConnection === connectionLease?.connection) - if let leasedConnection { - pool.releaseConnection(leasedConnection) - } + connectionLease?.release() } } catch { XCTFail("Unexpected error: \(error)") @@ -195,8 +193,8 @@ final class ConnectionPoolTests: XCTestCase { for _ in 0..]() for request in requests { let connection = try await request.future.success - connections.append(connection) + connectionLeases.append(connection) } // Ensure that we got 4 distinct connections - XCTAssertEqual(Set(connections.lazy.map(\.id)).count, 4) + XCTAssertEqual(Set(connectionLeases.lazy.map(\.connection.id)).count, 4) // release all 4 leased connections - for connection in connections { - pool.releaseConnection(connection) + for lease in connectionLeases { + lease.release() } // shutdown @@ -727,7 +725,7 @@ final class ConnectionPoolTests: XCTestCase { // create 4 connection requests let requests = (0..<10).map { ConnectionFuture(id: $0) } pool.leaseConnections(requests) - var connections = [MockConnection]() + var connectionLeases = [ConnectionLease]() await factory.nextConnectAttempt { connectionID in return 10 @@ -735,15 +733,15 @@ final class ConnectionPoolTests: XCTestCase { for request in requests { let connection = try await request.future.success - connections.append(connection) + connectionLeases.append(connection) } // Ensure that all requests got the same connection - XCTAssertEqual(Set(connections.lazy.map(\.id)).count, 1) + XCTAssertEqual(Set(connectionLeases.lazy.map(\.connection.id)).count, 1) // release all 10 leased streams - for connection in connections { - pool.releaseConnection(connection) + for lease in connectionLeases { + lease.release() } for _ in 0..<9 { @@ -792,41 +790,41 @@ final class ConnectionPoolTests: XCTestCase { // create 4 connection requests var requests = (0..<21).map { ConnectionFuture(id: $0) } pool.leaseConnections(requests) - var connections = [MockConnection]() + var connectionLease = [ConnectionLease]() await factory.nextConnectAttempt { connectionID in return 1 } - let connection = try await requests.first!.future.success - connections.append(connection) + let lease = try await requests.first!.future.success + connectionLease.append(lease) requests.removeFirst() - pool.connectionReceivedNewMaxStreamSetting(connection, newMaxStreamSetting: 21) + pool.connectionReceivedNewMaxStreamSetting(lease.connection, newMaxStreamSetting: 21) for (_, request) in requests.enumerated() { let connection = try await request.future.success - connections.append(connection) + connectionLease.append(connection) } // Ensure that all requests got the same connection - XCTAssertEqual(Set(connections.lazy.map(\.id)).count, 1) + XCTAssertEqual(Set(connectionLease.lazy.map(\.connection.id)).count, 1) requests = (22..<42).map { ConnectionFuture(id: $0) } pool.leaseConnections(requests) // release all 21 leased streams in a single call - pool.releaseConnection(connection, streams: 21) + pool.releaseConnection(lease.connection, streams: 21) // ensure all 20 new requests got fulfilled for request in requests { let connection = try await request.future.success - connections.append(connection) + connectionLease.append(connection) } // release all 20 leased streams one by one for _ in requests { - pool.releaseConnection(connection, streams: 1) + pool.releaseConnection(lease.connection, streams: 1) } // shutdown @@ -840,14 +838,14 @@ final class ConnectionPoolTests: XCTestCase { struct ConnectionFuture: ConnectionRequestProtocol { let id: Int - let future: Future + let future: Future> init(id: Int) { self.id = id - self.future = Future(of: MockConnection.self) + self.future = Future(of: ConnectionLease.self) } - func complete(with result: Result) { + func complete(with result: Result, ConnectionPoolError>) { switch result { case .success(let success): self.future.yield(value: success) diff --git a/Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift b/Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift index 537efbd9..2952bf8b 100644 --- a/Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift +++ b/Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift @@ -6,13 +6,14 @@ final class ConnectionRequestTests: XCTestCase { func testHappyPath() async throws { let mockConnection = MockConnection(id: 1) - let connection = try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + let lease = try await withCheckedThrowingContinuation { (continuation: CheckedContinuation, any Error>) in let request = ConnectionRequest(id: 42, continuation: continuation) XCTAssertEqual(request.id, 42) - continuation.resume(with: .success(mockConnection)) + let lease = ConnectionLease(connection: mockConnection) { _ in } + continuation.resume(with: .success(lease)) } - XCTAssert(connection === mockConnection) + XCTAssert(lease.connection === mockConnection) } func testSadPath() async throws { diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift index b74b86cc..ddd6a71e 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift @@ -11,7 +11,7 @@ final class PoolStateMachine_RequestQueueTests: XCTestCase { var queue = TestQueue() XCTAssert(queue.isEmpty) - let request1 = MockRequest() + let request1 = MockRequest(connectionType: MockConnection.self) queue.queue(request1) XCTAssertEqual(queue.count, 1) XCTAssertFalse(queue.isEmpty) @@ -25,11 +25,11 @@ final class PoolStateMachine_RequestQueueTests: XCTestCase { var queue = TestQueue() XCTAssert(queue.isEmpty) - var request1 = MockRequest() + var request1 = MockRequest(connectionType: MockConnection.self) queue.queue(request1) - var request2 = MockRequest() + var request2 = MockRequest(connectionType: MockConnection.self) queue.queue(request2) - var request3 = MockRequest() + var request3 = MockRequest(connectionType: MockConnection.self) queue.queue(request3) do { @@ -49,11 +49,11 @@ final class PoolStateMachine_RequestQueueTests: XCTestCase { var queue = TestQueue() XCTAssert(queue.isEmpty) - var request1 = MockRequest() + var request1 = MockRequest(connectionType: MockConnection.self) queue.queue(request1) - var request2 = MockRequest() + var request2 = MockRequest(connectionType: MockConnection.self) queue.queue(request2) - var request3 = MockRequest() + var request3 = MockRequest(connectionType: MockConnection.self) queue.queue(request3) do { @@ -76,11 +76,11 @@ final class PoolStateMachine_RequestQueueTests: XCTestCase { var queue = TestQueue() XCTAssert(queue.isEmpty) - var request1 = MockRequest() + var request1 = MockRequest(connectionType: MockConnection.self) queue.queue(request1) - var request2 = MockRequest() + var request2 = MockRequest(connectionType: MockConnection.self) queue.queue(request2) - var request3 = MockRequest() + var request3 = MockRequest(connectionType: MockConnection.self) queue.queue(request3) do { @@ -113,11 +113,11 @@ final class PoolStateMachine_RequestQueueTests: XCTestCase { var queue = TestQueue() XCTAssert(queue.isEmpty) - var request1 = MockRequest() + var request1 = MockRequest(connectionType: MockConnection.self) queue.queue(request1) - var request2 = MockRequest() + var request2 = MockRequest(connectionType: MockConnection.self) queue.queue(request2) - var request3 = MockRequest() + var request3 = MockRequest(connectionType: MockConnection.self) queue.queue(request3) do { diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift index c0b6ddcd..08afdf8e 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift @@ -7,8 +7,8 @@ typealias TestPoolStateMachine = PoolStateMachine< MockConnection, ConnectionIDGenerator, MockConnection.ID, - MockRequest, - MockRequest.ID, + MockRequest, + MockRequest.ID, MockTimerCancellationToken > @@ -75,7 +75,7 @@ final class PoolStateMachineTests: XCTestCase { XCTAssertEqual(createdAction1.connection, .scheduleTimers([])) // lease connection 1 - let request1 = MockRequest() + let request1 = MockRequest(connectionType: MockConnection.self) let leaseRequest1 = stateMachine.leaseConnection(request1) XCTAssertEqual(leaseRequest1.connection, .cancelTimers([])) XCTAssertEqual(leaseRequest1.request, .leaseConnection(.init(element: request1), connection1)) @@ -84,13 +84,13 @@ final class PoolStateMachineTests: XCTestCase { XCTAssertEqual(stateMachine.releaseConnection(connection1, streams: 1), .none()) // lease connection 1 - let request2 = MockRequest() + let request2 = MockRequest(connectionType: MockConnection.self) let leaseRequest2 = stateMachine.leaseConnection(request2) XCTAssertEqual(leaseRequest2.connection, .cancelTimers([])) XCTAssertEqual(leaseRequest2.request, .leaseConnection(.init(element: request2), connection1)) // request connection while none is available - let request3 = MockRequest() + let request3 = MockRequest(connectionType: MockConnection.self) let leaseRequest3 = stateMachine.leaseConnection(request3) XCTAssertEqual(leaseRequest3.connection, .makeConnection(.init(connectionID: 1), [])) XCTAssertEqual(leaseRequest3.request, .none) @@ -132,7 +132,7 @@ final class PoolStateMachineTests: XCTestCase { XCTAssertEqual(requests.count, 0) // request connection while none exists - let request1 = MockRequest() + let request1 = MockRequest(connectionType: MockConnection.self) let leaseRequest1 = stateMachine.leaseConnection(request1) XCTAssertEqual(leaseRequest1.connection, .makeConnection(.init(connectionID: 0), [])) XCTAssertEqual(leaseRequest1.request, .none) @@ -144,7 +144,7 @@ final class PoolStateMachineTests: XCTestCase { XCTAssertEqual(createdAction1.connection, .none) // request connection while none is available - let request2 = MockRequest() + let request2 = MockRequest(connectionType: MockConnection.self) let leaseRequest2 = stateMachine.leaseConnection(request2) XCTAssertEqual(leaseRequest2.connection, .makeConnection(.init(connectionID: 1), [])) XCTAssertEqual(leaseRequest2.request, .none) @@ -195,13 +195,13 @@ final class PoolStateMachineTests: XCTestCase { XCTAssertEqual(createdAction1.connection, .scheduleTimers([])) // lease connection 1 - let request1 = MockRequest() + let request1 = MockRequest(connectionType: MockConnection.self) let leaseRequest1 = stateMachine.leaseConnection(request1) XCTAssertEqual(leaseRequest1.connection, .cancelTimers([])) XCTAssertEqual(leaseRequest1.request, .leaseConnection(.init(element: request1), connection1)) // request connection while none is available - let request2 = MockRequest() + let request2 = MockRequest(connectionType: MockConnection.self) let leaseRequest2 = stateMachine.leaseConnection(request2) XCTAssertEqual(leaseRequest2.connection, .makeConnection(.init(connectionID: 1), [])) XCTAssertEqual(leaseRequest2.request, .none) @@ -245,7 +245,7 @@ final class PoolStateMachineTests: XCTestCase { XCTAssertEqual(requests.count, 0) // request connection while none exists - let request1 = MockRequest() + let request1 = MockRequest(connectionType: MockConnection.self) let leaseRequest1 = stateMachine.leaseConnection(request1) XCTAssertEqual(leaseRequest1.connection, .makeConnection(.init(connectionID: 0), [])) XCTAssertEqual(leaseRequest1.request, .none) @@ -287,7 +287,7 @@ final class PoolStateMachineTests: XCTestCase { XCTAssertEqual(requests.count, 0) // request connection while none exists - let request1 = MockRequest() + let request1 = MockRequest(connectionType: MockConnection.self) let leaseRequest1 = stateMachine.leaseConnection(request1) XCTAssertEqual(leaseRequest1.connection, .makeConnection(.init(connectionID: 0), [])) XCTAssertEqual(leaseRequest1.request, .none) @@ -309,7 +309,7 @@ final class PoolStateMachineTests: XCTestCase { connection1.closeIfClosing() // request connection while none exists anymore - let request2 = MockRequest() + let request2 = MockRequest(connectionType: MockConnection.self) let leaseRequest2 = stateMachine.leaseConnection(request2) XCTAssertEqual(leaseRequest2.connection, .makeConnection(.init(connectionID: 1), [])) XCTAssertEqual(leaseRequest2.request, .none) @@ -354,7 +354,7 @@ final class PoolStateMachineTests: XCTestCase { XCTAssertEqual(requests.count, 1) // one connection should exist - let request = MockRequest() + let request = MockRequest(connectionType: MockConnection.self) let leaseRequest = stateMachine.leaseConnection(request) XCTAssertEqual(leaseRequest.connection, .none) XCTAssertEqual(leaseRequest.request, .none) diff --git a/Tests/IntegrationTests/PostgresClientTests.swift b/Tests/IntegrationTests/PostgresClientTests.swift index eaf3663f..9ac92754 100644 --- a/Tests/IntegrationTests/PostgresClientTests.swift +++ b/Tests/IntegrationTests/PostgresClientTests.swift @@ -338,7 +338,7 @@ final class PostgresClientTests: XCTestCase { ) var count = 0 - for try await (id, label) in rows.decode((Int, String).self) { + for try await _ in rows.decode((Int, String).self) { count += 1 } XCTAssertEqual(count, 1)