diff --git a/Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift b/Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift new file mode 100644 index 00000000..51319f41 --- /dev/null +++ b/Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift @@ -0,0 +1,227 @@ +/// Handle to send data for a `COPY ... FROM STDIN` query to the backend. +public struct PostgresCopyFromWriter: Sendable { + private let channelHandler: NIOLoopBound + private let eventLoop: any EventLoop + + init(handler: PostgresChannelHandler, eventLoop: any EventLoop) { + self.channelHandler = NIOLoopBound(handler, eventLoop: eventLoop) + self.eventLoop = eventLoop + } + + private func writeAssumingInEventLoop(_ byteBuffer: ByteBuffer, _ continuation: CheckedContinuation) { + precondition(eventLoop.inEventLoop) + let promise = eventLoop.makePromise(of: Void.self) + self.channelHandler.value.checkBackendCanReceiveCopyData(promise: promise) + promise.futureResult.flatMap { + if eventLoop.inEventLoop { + return eventLoop.makeCompletedFuture(withResultOf: { + try self.channelHandler.value.sendCopyData(byteBuffer) + }) + } else { + let promise = eventLoop.makePromise(of: Void.self) + eventLoop.execute { + promise.completeWith(Result(catching: { try self.channelHandler.value.sendCopyData(byteBuffer) })) + } + return promise.futureResult + } + }.whenComplete { result in + continuation.resume(with: result) + } + } + + #if compiler(>=6.0) + /// Send data for a `COPY ... FROM STDIN` operation to the backend. + /// + /// - Throws: If an error occurs during the write of if the backend sent an `ErrorResponse` during the copy + /// operation, eg. to indicate that a **previous** `write` call had an invalid format. + public func write(_ byteBuffer: ByteBuffer, isolation: isolated (any Actor)? = #isolation) async throws { + // Check for cancellation. This is cheap and makes sure that we regularly check for cancellation in the + // `writeData` closure. It is likely that the user would forget to do so. + try Task.checkCancellation() + + try await withTaskCancellationHandler { + do { + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + if eventLoop.inEventLoop { + writeAssumingInEventLoop(byteBuffer, continuation) + } else { + eventLoop.execute { + writeAssumingInEventLoop(byteBuffer, continuation) + } + } + } + } catch { + if Task.isCancelled { + // If the task was cancelled, we might receive a postgres error which is an artifact about how we + // communicate the cancellation to the state machine. Throw a `CancellationError` to the user + // instead, which looks more like native Swift Concurrency code. + throw CancellationError() + } + throw error + } + } onCancel: { + if eventLoop.inEventLoop { + self.channelHandler.value.cancel() + } else { + eventLoop.execute { + self.channelHandler.value.cancel() + } + } + } + } + + /// Finalize the data transfer, putting the state machine out of the copy mode and sending a `CopyDone` message to + /// the backend. + func done(isolation: isolated (any Actor)? = #isolation) async throws { + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + if eventLoop.inEventLoop { + self.channelHandler.value.sendCopyDone(continuation: continuation) + } else { + eventLoop.execute { + self.channelHandler.value.sendCopyDone(continuation: continuation) + } + } + } + } + + /// Finalize the data transfer, putting the state machine out of the copy mode and sending a `CopyFail` message to + /// the backend. + func failed(error: any Error, isolation: isolated (any Actor)? = #isolation) async throws { + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + if eventLoop.inEventLoop { + self.channelHandler.value.sendCopyFail(message: "Client failed copy", continuation: continuation) + } else { + eventLoop.execute { + self.channelHandler.value.sendCopyFail(message: "Client failed copy", continuation: continuation) + } + } + } + } + #endif +} + +/// Specifies the format in which data is transferred to the backend in a COPY operation. +/// +/// See the Postgres documentation at https://www.postgresql.org/docs/current/sql-copy.html for the option's meanings +/// and their default values. +public struct PostgresCopyFromFormat: Sendable { + /// Options that can be used to modify the `text` format of a COPY operation. + public struct TextOptions: Sendable { + /// The delimiter that separates columns in the data. + /// + /// See the `DELIMITER` option in Postgres's `COPY` command. + public var delimiter: UnicodeScalar? = nil + + public init() {} + } + + enum Format { + case text(TextOptions) + } + + var format: Format + + public static func text(_ options: TextOptions) -> PostgresCopyFromFormat { + return PostgresCopyFromFormat(format: .text(options)) + } +} + +#if compiler(>=6.0) +/// Create a `COPY ... FROM STDIN` query based on the given parameters. +/// +/// An empty `columns` array signifies that no columns should be specified in the query and that all columns will be +/// copied by the caller. +private func buildCopyFromQuery( + table: StaticString, + columns: [StaticString] = [], + format: PostgresCopyFromFormat +) -> PostgresQuery { + var query = """ + COPY "\(table)" + """ + if !columns.isEmpty { + query += "(" + query += columns.map { #"""# + $0.description + #"""# }.joined(separator: ",") + query += ")" + } + query += " FROM STDIN" + var queryOptions: [String] = [] + switch format.format { + case .text(let options): + queryOptions.append("FORMAT text") + if let delimiter = options.delimiter { + // Set the delimiter as a Unicode code point. This avoids the possibility of SQL injection. + queryOptions.append("DELIMITER U&'\\\(String(format: "%04x", delimiter.value))'") + } + } + precondition(!queryOptions.isEmpty) + query += " WITH (" + query += queryOptions.map { "\($0)" }.joined(separator: ",") + query += ")" + return "\(unescaped: query)" +} + +extension PostgresConnection { + /// Copy data into a table using a `COPY FROM STDIN` query. + /// + /// - Parameters: + /// - table: The name of the table into which to copy the data. + /// - columns: The name of the columns to copy. If an empty array is passed, all columns are assumed to be copied. + /// - format: Options that specify the format of the data that is produced by `writeData`. + /// - writeData: Closure that produces the data for the table, to be streamed to the backend. Call `write` on the + /// writer provided by the closure to send data to the backend and return from the closure once all data is sent. + /// Throw an error from the closure to fail the data transfer. The error thrown by the closure will be rethrown + /// by the `copyFrom` function. + /// + /// - Note: The table and column names are inserted into the SQL query verbatim. They are forced to be compile-time + /// specified to avoid runtime SQL injection attacks. + public func copyFrom( + table: StaticString, + columns: [StaticString] = [], + format: PostgresCopyFromFormat = .text(.init()), + logger: Logger, + isolation: isolated (any Actor)? = #isolation, + file: String = #fileID, + line: Int = #line, + writeData: @escaping @Sendable (PostgresCopyFromWriter) async throws -> Void + ) async throws { + var logger = logger + logger[postgresMetadataKey: .connectionID] = "\(self.id)" + let writer: PostgresCopyFromWriter = try await withCheckedThrowingContinuation { continuation in + let context = ExtendedQueryContext( + copyFromQuery: buildCopyFromQuery(table: table, columns: columns, format: format), + triggerCopy: continuation, + logger: logger + ) + self.channel.write(HandlerTask.extendedQuery(context), promise: nil) + } + + do { + try await writeData(writer) + } catch { + // We need to send a `CopyFail` to the backend to put it out of copy mode. This will most likely throw, most + // notably for the following two reasons. In both of them, it's better to ignore the error thrown by + // `writer.failed` and instead throw the error from `writeData`: + // - We send `CopyFail` and the backend replies with an `ErrorResponse` that relays the `CopyFail` message. + // This took the backend out of copy mode but it's more informative to the user to see the error they + // threw instead of the one that got relayed back, so it's better to ignore the error here. + // - The backend sent us an `ErrorResponse` during the copy, eg. because of an invalid format. This puts + // the `ExtendedQueryStateMachine` in the error state. Trying to send a `CopyFail` will throw but trigger + // a `Sync` that takes the backend out of copy mode. If `writeData` threw the error from from the + // `PostgresCopyFromWriter.write` call, `writer.failed` will throw with the same error, so it doesn't + // matter that we ignore the error here. If the user threw some other error, it's better to honor the + // user's error. + try? await writer.failed(error: error) + + throw error + } + + // `writer.done` may fail, eg. because the backend sends an error response after receiving `CopyDone` or during + // the transfer of the last bit of data so that the user didn't call `PostgresCopyFromWriter.write` again, which + // would have checked the error state. In either of these cases, calling `writer.done` puts the backend out of + // copy mode, so we don't need to send another `CopyFail`. Thus, this must not be handled in the `do` block + // above. + try await writer.done() + } +} +#endif diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index 8560b948..6b0a8059 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -88,7 +88,33 @@ struct ConnectionStateMachine { case sendParseDescribeBindExecuteSync(PostgresQuery) case sendBindExecuteSync(PSQLExecuteStatement) case failQuery(EventLoopPromise, with: PSQLError, cleanupContext: CleanUpContext?) + /// Fail a query's execution by resuming the continuation with the given error. When `sync` is `true`, send a + /// `Sync` message to the backend. + case failQueryContinuation(AnyErrorContinuation, with: PSQLError, sync: Bool, cleanupContext: CleanUpContext?) + /// Fail a query's execution by resuming the continuation with the given error and send a `Sync` message to the + /// backend. case succeedQuery(EventLoopPromise, with: QueryResult) + /// Succeed the continuation with a void result. When `sync` is `true`, send a `Sync` message to the backend. + case succeedQueryContinuation(CheckedContinuation, sync: Bool) + + /// Trigger a data transfer returning a `PostgresCopyFromWriter` to the given continuation. + /// + /// Once the data transfer is triggered, it will send `CopyData` messages to the backend. After that the state + /// machine needs to be prodded again to send a `CopyDone` or `CopyFail` by calling + /// `PostgresChannelHandler.sendCopyDone` or `PostgresChannelHandler.sendCopyFail`. + case triggerCopyData(CheckedContinuation) + + /// Send a `CopyDone` and `Sync` message to the backend. + case sendCopyDoneAndSync + + /// Send a `CopyFail` message to the backend with the given error message. + case sendCopyFail(message: String) + + /// Fail the promise with the given error and close the connection. + /// + /// This is used when we want to cancel a COPY operation while waiting for backpressure relieve. In that case we + /// can't recover the connection because we can't send any messages to the backend, so we need to close it. + case failPromiseAndCloseConnection(EventLoopPromise, error: PSQLError, cleanupContext: CleanUpContext) // --- streaming actions // actions if query has requested next row but we are waiting for backend @@ -107,6 +133,25 @@ struct ConnectionStateMachine { case failClose(CloseCommandContext, with: PSQLError, cleanupContext: CleanUpContext?) } + enum ChannelWritabilityChangedAction { + /// No action needs to be taken based on the writability change. + case none + + /// Resume the given continuation successfully. + case succeedPromise(EventLoopPromise) + } + + enum CheckBackendCanReceiveCopyDataAction { + /// Don't perform any action. + case none + + /// Succeed the promise with a Void result. + case succeedPromise(EventLoopPromise) + + /// Fail the promise with the given error. + case failPromise(EventLoopPromise, error: any Error) + } + private var state: State private let requireBackendKeyData: Bool private var taskQueue = CircularBuffer() @@ -587,6 +632,8 @@ struct ConnectionStateMachine { switch queryContext.query { case .executeStatement(_, let promise), .unnamed(_, let promise): return .failQuery(promise, with: psqlErrror, cleanupContext: nil) + case .copyFrom(_, let triggerCopy): + return .failQueryContinuation(.copyFromWriter(triggerCopy), with: psqlErrror, sync: false, cleanupContext: nil) case .prepareStatement(_, _, _, let promise): return .failPreparedStatementCreation(promise, with: psqlErrror, cleanupContext: nil) } @@ -660,6 +707,16 @@ struct ConnectionStateMachine { preconditionFailure("Invalid state: \(self.state)") } } + + mutating func channelWritabilityChanged(isWritable: Bool) -> ChannelWritabilityChangedAction { + guard case .extendedQuery(var queryState, let connectionContext) = state else { + return .none + } + self.state = .modifying // avoid CoW + let action = queryState.channelWritabilityChanged(isWritable: isWritable) + self.state = .extendedQuery(queryState, connectionContext) + return action + } // MARK: - Running Queries - @@ -752,10 +809,56 @@ struct ConnectionStateMachine { return self.modify(with: action) } - mutating func copyInResponseReceived( - _ copyInResponse: PostgresBackendMessage.CopyInResponse - ) -> ConnectionAction { - return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.copyInResponse(copyInResponse))) + mutating func copyInResponseReceived(_ copyInResponse: PostgresBackendMessage.CopyInResponse) -> ConnectionAction { + guard case .extendedQuery(var queryState, let connectionContext) = self.state, !queryState.isComplete else { + return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.copyInResponse(copyInResponse))) + } + + self.state = .modifying // avoid CoW + let action = queryState.copyInResponseReceived(copyInResponse) + self.state = .extendedQuery(queryState, connectionContext) + return self.modify(with: action) + } + + + /// Succeed the promise when the channel to the backend is writable and the backend is ready to receive more data. + /// + /// The promise may be failed if the backend indicated that it can't handle any more data by sending an + /// `ErrorResponse`. This is mostly the case when malformed data is sent to it. In that case, the data transfer + /// should be aborted to avoid unnecessary work. + mutating func checkBackendCanReceiveCopyData(channelIsWritable: Bool, promise: EventLoopPromise) -> CheckBackendCanReceiveCopyDataAction { + guard case .extendedQuery(var queryState, let connectionContext) = self.state else { + preconditionFailure("Copy mode is only supported for extended queries") + } + + self.state = .modifying // avoid CoW + let action = queryState.checkBackendCanReceiveCopyData(channelIsWritable: channelIsWritable, promise: promise) + self.state = .extendedQuery(queryState, connectionContext) + return action + } + + /// Put the state machine out of the copying mode and send a `CopyDone` message to the backend. + mutating func sendCopyDone(continuation: CheckedContinuation) -> ConnectionAction { + guard case .extendedQuery(var queryState, let connectionContext) = self.state else { + preconditionFailure("Copy mode is only supported for extended queries") + } + + self.state = .modifying // avoid CoW + let action = queryState.sendCopyDone(continuation: continuation) + self.state = .extendedQuery(queryState, connectionContext) + return self.modify(with: action) + } + + /// Put the state machine out of the copying mode and send a `CopyFail` message to the backend. + mutating func sendCopyFail(message: String, continuation: CheckedContinuation) -> ConnectionAction { + guard case .extendedQuery(var queryState, let connectionContext) = self.state else { + preconditionFailure("Copy mode is only supported for extended queries") + } + + self.state = .modifying // avoid CoW + let action = queryState.sendCopyFail(message: message, continuation: continuation) + self.state = .extendedQuery(queryState, connectionContext) + return self.modify(with: action) } mutating func emptyQueryResponseReceived() -> ConnectionAction { @@ -782,9 +885,10 @@ struct ConnectionStateMachine { // MARK: Consumer - mutating func cancelQueryStream() -> ConnectionAction { + mutating func cancel() -> ConnectionAction { guard case .extendedQuery(var queryState, let connectionContext) = self.state else { - preconditionFailure("Tried to cancel stream without active query") + // We are not in a state in which we can cancel. Do nothing. + return .wait } self.state = .modifying // avoid CoW @@ -866,14 +970,22 @@ struct ConnectionStateMachine { .forwardRows, .forwardStreamComplete, .wait, - .read: + .read, + .triggerCopyData, + .sendCopyDoneAndSync, + .sendCopyFail, + .succeedQueryContinuation, + .failPromiseAndCloseConnection: preconditionFailure("Invalid query state machine action in state: \(self.state), action: \(action)") case .evaluateErrorAtConnectionLevel: return .closeConnectionAndCleanup(cleanupContext) - case .failQuery(let queryContext, with: let error): - return .failQuery(queryContext, with: error, cleanupContext: cleanupContext) + case .failQuery(let promise, with: let error): + return .failQuery(promise, with: error, cleanupContext: cleanupContext) + + case .failQueryContinuation(let continuation, with: let error, let sync): + return .failQueryContinuation(continuation, with: error, sync: sync, cleanupContext: cleanupContext) case .forwardStreamError(let error, let read): return .forwardStreamError(error, read: read, cleanupContext: cleanupContext) @@ -1044,8 +1156,22 @@ extension ConnectionStateMachine { case .failQuery(let requestContext, with: let error): let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error) return .failQuery(requestContext, with: error, cleanupContext: cleanupContext) + case .failQueryContinuation(let continuation, with: let error, let sync): + let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error) + return .failQueryContinuation(continuation, with: error, sync: sync, cleanupContext: cleanupContext) case .succeedQuery(let requestContext, with: let result): return .succeedQuery(requestContext, with: result) + case .succeedQueryContinuation(let continuation, let sync): + return .succeedQueryContinuation(continuation, sync: sync) + case .triggerCopyData(let triggerCopy): + return .triggerCopyData(triggerCopy) + case .sendCopyDoneAndSync: + return .sendCopyDoneAndSync + case .sendCopyFail(message: let message): + return .sendCopyFail(message: message) + case .failPromiseAndCloseConnection(let promise, error: let error): + let cleanupContext = self.setErrorAndCreateCleanupContext(error) + return .failPromiseAndCloseConnection(promise, error: error, cleanupContext: cleanupContext) case .forwardRows(let buffer): return .forwardRows(buffer) case .forwardStreamComplete(let buffer, let commandTag): diff --git a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift index 5708b6b9..bba44f0f 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift @@ -2,6 +2,15 @@ import NIOCore struct ExtendedQueryStateMachine { + private enum CopyingDataState { + /// The write channel is ready to handle more data. + case readyToSend + + /// The write channel has backpressure. Once that is relieved, we should resume the attached continuation to + /// allow more data to be sent by the client. + case pendingBackpressureRelieve(EventLoopPromise) + } + private enum State { case initialized(ExtendedQueryContext) case messagesSent(ExtendedQueryContext) @@ -12,6 +21,19 @@ struct ExtendedQueryStateMachine { case noDataMessageReceived(ExtendedQueryContext) case emptyQueryResponseReceived + /// We are currently copying data to the backend using `CopyData` messages. + case copyingData(CopyingDataState) + + /// We copied data to the backend and are done with that, either by sending a `CopyDone` or `CopyFail` message. + /// We are now expecting a `CommandComplete` or `ErrorResponse`. + /// + /// Once that is received the continuation is resumed. + /// + /// `successful` identifies whether copying was finished with a `CopyDone` or a `CopyFail` message. This is + /// necessary because we send a `Sync` after `CopyDone` but only send the `Sync` for `CopyFail` once we receive + /// the `ErrorResponse` from the backend. + case copyingFinished(CheckedContinuation, successful: Bool) + /// A state that is used if a noData message was received before. If a row description was received `bufferingRows` is /// used after receiving a `bindComplete` message case bindCompleteReceived(ExtendedQueryContext) @@ -32,13 +54,37 @@ struct ExtendedQueryStateMachine { // --- general actions case failQuery(EventLoopPromise, with: PSQLError) + /// Fail a query's execution by resuming the continuation with the given error. When `sync` is `true`, send a + /// `Sync` message to the backend. + case failQueryContinuation(AnyErrorContinuation, with: PSQLError, sync: Bool) case succeedQuery(EventLoopPromise, with: QueryResult) + /// Succeed the continuation with a void result. When `sync` is `true`, send a `Sync` message to the backend. + case succeedQueryContinuation(CheckedContinuation, sync: Bool) case evaluateErrorAtConnectionLevel(PSQLError) case succeedPreparedStatementCreation(EventLoopPromise, with: RowDescription?) case failPreparedStatementCreation(EventLoopPromise, with: PSQLError) + /// Trigger a data transfer returning a `PostgresCopyFromWriter` to the given continuation. + /// + /// Once the data transfer is triggered, it will send `CopyData` messages to the backend. After that the state + /// machine needs to be prodded again to send a `CopyDone` or `CopyFail` by calling + /// `PostgresChannelHandler.sendCopyDone` or `PostgresChannelHandler.sendCopyFail`. + case triggerCopyData(CheckedContinuation) + + /// Send a `CopyDone` and `Sync` message to the backend. + case sendCopyDoneAndSync + + /// Send a `CopyFail` message to the backend with the given error message. + case sendCopyFail(message: String) + + /// Fail the promise with the given error and close the connection. + /// + /// This is used when we want to cancel a COPY operation while waiting for backpressure relieve. In that case we + /// can't recover the connection because we can't send any messages to the backend, so we need to close it. + case failPromiseAndCloseConnection(EventLoopPromise, error: PSQLError) + // --- streaming actions // actions if query has requested next row but we are waiting for backend case forwardRows([DataRow]) @@ -63,7 +109,7 @@ struct ExtendedQueryStateMachine { } switch queryContext.query { - case .unnamed(let query, _): + case .unnamed(let query, _), .copyFrom(let query, _): return self.avoidingStateMachineCoW { state -> Action in state = .messagesSent(queryContext) return .sendParseDescribeBindExecuteSync(query) @@ -108,10 +154,29 @@ struct ExtendedQueryStateMachine { case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise): return .failQuery(eventLoopPromise, with: .queryCancelled) + case .copyFrom(_, let triggerCopy): + return .failQueryContinuation(.copyFromWriter(triggerCopy), with: .queryCancelled, sync: false) + case .prepareStatement(_, _, _, let eventLoopPromise): return .failPreparedStatementCreation(eventLoopPromise, with: .queryCancelled) } + case .copyingData(.readyToSend): + // We can't initiate an exit from the copy state here because `copyingFinished`, which is the state that is + // reached after sending a `CopyFail` requires a continuation that waits for the `CommandComplete` or + // `ErrorResponse`. Instead, we assume that the next call to `CopyFromWriter.write` checks cancellation and + // initiates the `CopyFail` with the cancellation. + return .wait + + case .copyingData(.pendingBackpressureRelieve(let promise)): + self.state = .error(.queryCancelled) + return .failPromiseAndCloseConnection(promise, error: .queryCancelled) + + case .copyingFinished: + // We already finished the copy and are awaiting the `CommandComplete` or `ErrorResponse` from it. There's + // nothing we can do to cancel that. + return .wait + case .streaming(let columns, var streamStateMachine): precondition(!self.isCancelled) self.isCancelled = true @@ -160,7 +225,7 @@ struct ExtendedQueryStateMachine { } switch queryContext.query { - case .unnamed, .executeStatement: + case .unnamed, .copyFrom, .executeStatement: return self.avoidingStateMachineCoW { state -> Action in state = .noDataMessageReceived(queryContext) return .wait @@ -198,7 +263,7 @@ struct ExtendedQueryStateMachine { } switch queryContext.query { - case .unnamed, .executeStatement: + case .unnamed, .copyFrom, .executeStatement: return .wait case .prepareStatement(_, _, _, let eventLoopPromise): @@ -219,6 +284,10 @@ struct ExtendedQueryStateMachine { case .prepareStatement: return .evaluateErrorAtConnectionLevel(.unexpectedBackendMessage(.bindComplete)) + case .copyFrom: + // The COPY commands don't return row descriptions, so we should never be in the + // `rowDescriptionReceived` state. + return .evaluateErrorAtConnectionLevel(.unexpectedBackendMessage(.bindComplete)) } case .noDataMessageReceived(let queryContext): @@ -235,7 +304,9 @@ struct ExtendedQueryStateMachine { .streaming, .drain, .commandComplete, - .error: + .error, + .copyingData, + .copyingFinished: return self.setAndFireError(.unexpectedBackendMessage(.bindComplete)) case .modifying: @@ -274,7 +345,9 @@ struct ExtendedQueryStateMachine { .rowDescriptionReceived, .bindCompleteReceived, .commandComplete, - .error: + .error, + .copyingData, + .copyingFinished: return self.setAndFireError(.unexpectedBackendMessage(.dataRow(dataRow))) case .modifying: preconditionFailure("Invalid state") @@ -291,10 +364,19 @@ struct ExtendedQueryStateMachine { let result = QueryResult(value: .noRows(.tag(commandTag)), logger: context.logger) return .succeedQuery(eventLoopPromise, with: result) } - + case .copyFrom: + // We expect to transition through `copyingData` to `copyingFinished` before receiving a + // `CommandCompleted` message for copy queries. + return self.setAndFireError(.unexpectedBackendMessage(.commandComplete(commandTag))) case .prepareStatement: preconditionFailure("Invalid state: \(self.state)") } + + case .copyingFinished(let continuation, let successful): + return self.avoidingStateMachineCoW { state -> Action in + state = .commandComplete(commandTag: commandTag) + return .succeedQueryContinuation(continuation, sync: !successful) + } case .streaming(_, var demandStateMachine): return self.avoidingStateMachineCoW { state -> Action in @@ -315,17 +397,82 @@ struct ExtendedQueryStateMachine { .emptyQueryResponseReceived, .rowDescriptionReceived, .commandComplete, - .error: + .error, + .copyingData: return self.setAndFireError(.unexpectedBackendMessage(.commandComplete(commandTag))) case .modifying: preconditionFailure("Invalid state") } } - mutating func copyInResponseReceived( - _ copyInResponse: PostgresBackendMessage.CopyInResponse - ) -> Action { - return self.setAndFireError(.unexpectedBackendMessage(.copyInResponse(copyInResponse))) + mutating func copyInResponseReceived(_ copyInResponse: PostgresBackendMessage.CopyInResponse) -> Action { + guard case .bindCompleteReceived(let queryContext) = self.state, + case .copyFrom(_, let triggerCopy) = queryContext.query else { + return self.setAndFireError(.unexpectedBackendMessage(.copyInResponse(copyInResponse))) + } + return avoidingStateMachineCoW { state in + // We can assume that we have no backpressure here. Before sending data, `checkBackendCanReceiveCopyData` + // will be called, which checks if the channel to the backend is indeed writable. + state = .copyingData(.readyToSend) + return .triggerCopyData(triggerCopy) + } + } + + /// Succeed the promise when the channel to the backend is writable and the backend is ready to receive more data. + /// + /// The promise may be failed if the backend indicated that it can't handle any more data by sending an + /// `ErrorResponse`. This is mostly the case when malformed data is sent to it. In that case, the data transfer + /// should be aborted to avoid unnecessary work. + mutating func checkBackendCanReceiveCopyData(channelIsWritable: Bool, promise: EventLoopPromise) -> ConnectionStateMachine.CheckBackendCanReceiveCopyDataAction { + if case .error(let error) = self.state { + // The backend sent us an ErrorResponse during the copy operation. Indicate to the client that it should + // abort the data transfer. + promise.fail(error) + return . failPromise(promise, error: error) + } + guard case .copyingData(.readyToSend) = self.state else { + preconditionFailure("Not ready to send data") + } + if channelIsWritable { + return .succeedPromise(promise) + } + return avoidingStateMachineCoW { state in + state = .copyingData(.pendingBackpressureRelieve(promise)) + return .none + } + } + + /// Put the state machine out of the copying mode and send a `CopyDone` message to the backend. + mutating func sendCopyDone(continuation: CheckedContinuation) -> Action { + if case .error(let error) = self.state { + // The backend sent us an ErrorResponse during the copy operation. We need to send a `Sync` to get out of + // copy mode and communicate the error to the user. There's no need for `CopyDone` anymore. + return .failQueryContinuation(.void(continuation), with: error, sync: true) + } + guard case .copyingData = self.state else { + preconditionFailure("Must be in copy mode to send CopyDone") + } + return avoidingStateMachineCoW { state in + state = .copyingFinished(continuation, successful: true) + return .sendCopyDoneAndSync + } + } + + /// Put the state machine out of the copying mode and send a `CopyFail` message to the backend. + mutating func sendCopyFail(message: String, continuation: CheckedContinuation) -> Action { + if case .error(let error) = self.state { + // The backend sent us an ErrorResponse during the copy operation. We need to send a `Sync` to get out of + // copy mode and communicate the error to the user. There's no need for `CopyFail` anymore. + return .failQueryContinuation(.void(continuation), with: error, sync: true) + } + guard case .copyingData = self.state else { + preconditionFailure("Must be in copy mode to send CopyFail") + } + return avoidingStateMachineCoW { state in + state = .copyingFinished(continuation, successful: false) + return .sendCopyFail(message: message) + } + } mutating func emptyQueryResponseReceived() -> Action { @@ -342,7 +489,7 @@ struct ExtendedQueryStateMachine { return .succeedQuery(eventLoopPromise, with: result) } - case .prepareStatement(_, _, _, _): + case .prepareStatement, .copyFrom: return self.setAndFireError(.unexpectedBackendMessage(.emptyQueryResponse)) } } @@ -359,6 +506,8 @@ struct ExtendedQueryStateMachine { return self.setAndFireError(error) case .rowDescriptionReceived, .noDataMessageReceived: return self.setAndFireError(error) + case .copyingData, .copyingFinished: + return self.setAndFireError(error) case .streaming, .drain: return self.setAndFireError(error) case .commandComplete, .emptyQueryResponseReceived: @@ -409,7 +558,9 @@ struct ExtendedQueryStateMachine { .noDataMessageReceived, .emptyQueryResponseReceived, .rowDescriptionReceived, - .bindCompleteReceived: + .bindCompleteReceived, + .copyingData, + .copyingFinished: preconditionFailure("Requested to consume next row without anything going on.") case .commandComplete, .error: @@ -433,7 +584,9 @@ struct ExtendedQueryStateMachine { .noDataMessageReceived, .emptyQueryResponseReceived, .rowDescriptionReceived, - .bindCompleteReceived: + .bindCompleteReceived, + .copyingData, + .copyingFinished: return .wait case .streaming(let columns, var demandStateMachine): @@ -460,7 +613,9 @@ struct ExtendedQueryStateMachine { .parameterDescriptionReceived, .noDataMessageReceived, .rowDescriptionReceived, - .bindCompleteReceived: + .bindCompleteReceived, + .copyingData, + .copyingFinished: return .read case .streaming(let columns, var demandStateMachine): precondition(!self.isCancelled) @@ -486,6 +641,16 @@ struct ExtendedQueryStateMachine { preconditionFailure("Invalid state") } } + + mutating func channelWritabilityChanged(isWritable: Bool) -> ConnectionStateMachine.ChannelWritabilityChangedAction { + guard case .copyingData(.pendingBackpressureRelieve(let promise)) = state else { + return .none + } + return self.avoidingStateMachineCoW { state in + state = .copyingData(.readyToSend) + return .succeedPromise(promise) + } + } // MARK: Private Methods @@ -505,11 +670,22 @@ struct ExtendedQueryStateMachine { switch context.query { case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise): return .failQuery(eventLoopPromise, with: error) + case .copyFrom(_, let triggerCopy): + return .failQueryContinuation(.copyFromWriter(triggerCopy), with: error, sync: false) case .prepareStatement(_, _, _, let eventLoopPromise): return .failPreparedStatementCreation(eventLoopPromise, with: error) } } - + case .copyingData: + self.state = .error(error) + // Store the error. We expect the next chunk of data to be written almost immediately, which will call + // `checkBackendCanReceiveCopyData`, which handles the error. If the user is done writing data, we expect a + // `CopyDone` or `CopyFail` message soon, which also checks for the error case, so there's nothing that we + // need to actively do here. + return .wait + case .copyingFinished(let continuation, let successful): + self.state = .error(error) + return .failQueryContinuation(.void(continuation), with: error, sync: !successful) case .drain: self.state = .error(error) return .evaluateErrorAtConnectionLevel(error) @@ -542,11 +718,19 @@ struct ExtendedQueryStateMachine { switch context.query { case .prepareStatement: return true - case .unnamed, .executeStatement: + case .unnamed, .copyFrom, .executeStatement: return false } - case .initialized, .messagesSent, .parseCompleteReceived, .parameterDescriptionReceived, .bindCompleteReceived, .streaming, .drain: + case .initialized, + .messagesSent, + .parseCompleteReceived, + .parameterDescriptionReceived, + .bindCompleteReceived, + .streaming, + .drain, + .copyingData, + .copyingFinished: return false case .modifying: diff --git a/Sources/PostgresNIO/New/Extensions/AnyErrorContinuation.swift b/Sources/PostgresNIO/New/Extensions/AnyErrorContinuation.swift new file mode 100644 index 00000000..b141c928 --- /dev/null +++ b/Sources/PostgresNIO/New/Extensions/AnyErrorContinuation.swift @@ -0,0 +1,13 @@ +/// Enum that abstracts over continuations that have `any Error` as the failure type. Cases are expected to get added +/// for the success types that we care about. +enum AnyErrorContinuation { + case void(CheckedContinuation) + case copyFromWriter(CheckedContinuation) + + func resume(throwing error: any Error) { + switch self { + case .void(let continuation): continuation.resume(throwing: error) + case .copyFromWriter(let continuation): continuation.resume(throwing: error) + } + } +} diff --git a/Sources/PostgresNIO/New/PSQLTask.swift b/Sources/PostgresNIO/New/PSQLTask.swift index 6106fd21..820e8a20 100644 --- a/Sources/PostgresNIO/New/PSQLTask.swift +++ b/Sources/PostgresNIO/New/PSQLTask.swift @@ -19,6 +19,8 @@ enum PSQLTask { switch extendedQueryContext.query { case .unnamed(_, let eventLoopPromise): eventLoopPromise.fail(error) + case .copyFrom(_, let triggerCopy): + triggerCopy.resume(throwing: error) case .executeStatement(_, let eventLoopPromise): eventLoopPromise.fail(error) case .prepareStatement(_, _, _, let eventLoopPromise): @@ -34,6 +36,12 @@ enum PSQLTask { final class ExtendedQueryContext: Sendable { enum Query { case unnamed(PostgresQuery, EventLoopPromise) + /// A `COPY ... FROM STDIN` query that copies data from the frontend into a table. + /// + /// When `triggerCopy` is resumed, the `PostgresConnection` that created this query should send data to the + /// backend via `CopyData` messages and finalize the data transfer by calling `sendCopyDone` or `sendCopyFail` + /// on the `PostgresChannelHandler`. + case copyFrom(PostgresQuery, triggerCopy: CheckedContinuation) case executeStatement(PSQLExecuteStatement, EventLoopPromise) case prepareStatement(name: String, query: String, bindingDataTypes: [PostgresDataType], EventLoopPromise) } @@ -50,6 +58,15 @@ final class ExtendedQueryContext: Sendable { self.logger = logger } + init( + copyFromQuery query: PostgresQuery, + triggerCopy: CheckedContinuation, + logger: Logger + ) { + self.query = .copyFrom(query, triggerCopy: triggerCopy) + self.logger = logger + } + init( executeStatement: PSQLExecuteStatement, logger: Logger, diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index baf801e5..d1470dee 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -171,10 +171,80 @@ final class PostgresChannelHandler: ChannelDuplexHandler { self.run(action, with: context) } + /// Succeed the promise when the channel to the backend is writable and the backend is ready to receive more data. + /// + /// The promise may be failed if the backend indicated that it can't handle any more data by sending an + /// `ErrorResponse`. This is mostly the case when malformed data is sent to it. In that case, the data transfer + /// should be aborted to avoid unnecessary work. + func checkBackendCanReceiveCopyData(promise: EventLoopPromise) { + guard let handlerContext else { + promise.fail(PostgresError.connectionClosed) + return + } + let action = self.state.checkBackendCanReceiveCopyData(channelIsWritable: handlerContext.channel.isWritable, promise: promise) + switch action { + case .none: + break + case .succeedPromise(let promise): + promise.succeed() + case .failPromise(let promise, error: let error): + promise.fail(error) +} + } + + /// Cancel the currently executing operation, if it is cancellable. + func cancel() { + guard let handlerContext else { + return + } + let action = self.state.cancel() + self.run(action, with: handlerContext) + } + + /// Send a `CopyData` message to the backend using the given data. + func sendCopyData(_ data: ByteBuffer) throws { + guard let handlerContext else { + throw PostgresError.connectionClosed + } + self.encoder.copyDataHeader(dataLength: UInt32(data.readableBytes)) + handlerContext.write(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) + handlerContext.writeAndFlush(self.wrapOutboundOut(data), promise: nil) + } + + /// Put the state machine out of the copying mode and send a `CopyDone` message to the backend. + func sendCopyDone(continuation: CheckedContinuation) { + guard let handlerContext else { + continuation.resume(throwing: PostgresError.connectionClosed) + return + } + let action = self.state.sendCopyDone(continuation: continuation) + self.run(action, with: handlerContext) + } + + /// Put the state machine out of the copying mode and send a `CopyFail` message to the backend. + func sendCopyFail(message: String, continuation: CheckedContinuation) { + guard let handlerContext else { + continuation.resume(throwing: PostgresError.connectionClosed) + return + } + let action = self.state.sendCopyFail(message: message, continuation: continuation) + self.run(action, with: handlerContext) + } + func channelReadComplete(context: ChannelHandlerContext) { let action = self.state.channelReadComplete() self.run(action, with: context) } + + func channelWritabilityChanged(context: ChannelHandlerContext) { + let action = self.state.channelWritabilityChanged(isWritable: context.channel.isWritable) + switch action { + case .none: + break + case .succeedPromise(let promise): + promise.succeed() + } + } func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) { self.logger.trace("User inbound event received", metadata: [ @@ -355,12 +425,36 @@ final class PostgresChannelHandler: ChannelDuplexHandler { self.sendParseDescribeBindExecuteAndSyncMessage(query: query, context: context) case .succeedQuery(let promise, with: let result): self.succeedQuery(promise, result: result, context: context) + case .succeedQueryContinuation(let continuation, let sync): + if sync { + self.encoder.sync() + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) + } + continuation.resume() case .failQuery(let promise, with: let error, let cleanupContext): promise.fail(error) if let cleanupContext = cleanupContext { self.closeConnectionAndCleanup(cleanupContext, context: context) } - + case .failQueryContinuation(let continuation, with: let error, let sync, let cleanupContext): + if sync { + self.encoder.sync() + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) + } + if let cleanupContext = cleanupContext { + self.closeConnectionAndCleanup(cleanupContext, context: context) + } + continuation.resume(throwing: error) + case .triggerCopyData(let triggerCopy): + let writer = PostgresCopyFromWriter(handler: self, eventLoop: eventLoop) + triggerCopy.resume(returning: writer) + case .sendCopyDoneAndSync: + self.encoder.copyDone() + self.encoder.sync() + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) + case .sendCopyFail(message: let message): + self.encoder.copyFail(message: message) + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) case .forwardRows(let rows): self.rowStream!.receive(rows) @@ -427,6 +521,9 @@ final class PostgresChannelHandler: ChannelDuplexHandler { } case .forwardNotificationToListeners(let notification): self.forwardNotificationToListeners(notification, context: context) + case .failPromiseAndCloseConnection(let promise, let error, let cleanupContext): + promise.fail(error) + self.closeConnectionAndCleanup(cleanupContext, context: context) } } @@ -798,11 +895,10 @@ extension PostgresChannelHandler: PSQLRowsDataSource { } func cancel(for stream: PSQLRowStream) { - guard self.rowStream === stream, let handlerContext = self.handlerContext else { + guard self.rowStream === stream else { return } - let action = self.state.cancelQueryStream() - self.run(action, with: handlerContext) + self.cancel() } } diff --git a/Tests/IntegrationTests/PSQLIntegrationTests.swift b/Tests/IntegrationTests/PSQLIntegrationTests.swift index d541899b..267e70e9 100644 --- a/Tests/IntegrationTests/PSQLIntegrationTests.swift +++ b/Tests/IntegrationTests/PSQLIntegrationTests.swift @@ -379,4 +379,114 @@ final class IntegrationTests: XCTestCase { } } + #if compiler(>=6.0) + func testCopyIntoFrom() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 2) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + let conn = try await PostgresConnection.test(on: eventLoop).get() + defer { XCTAssertNoThrow(try conn.close().wait()) } + + _ = try? await conn.query("DROP TABLE copy_table", logger: .psqlTest).get() + _ = try await conn.query("CREATE TABLE copy_table (id INT, name VARCHAR(100))", logger: .psqlTest).get() + + var options = PostgresCopyFromFormat.TextOptions() + options.delimiter = "," + try await conn.copyFrom(table: "copy_table", columns: ["id", "name"], format: .text(options), logger: .psqlTest) { writer in + let records: [(id: Int, name: String)] = [ + (1, "Alice"), + (42, "Bob") + ] + for record in records { + var buffer = ByteBuffer() + buffer.writeString("\(record.id),\(record.name)\n") + try await writer.write(buffer) + } + } + let rows = try await conn.query("SELECT id, name FROM copy_table").get().rows.map { try $0.decode((Int, String).self) } + guard rows.count == 2 else { + XCTFail("Expected 2 columns, received \(rows.count)") + return + } + XCTAssertEqual(rows[0].0, 1) + XCTAssertEqual(rows[0].1, "Alice") + XCTAssertEqual(rows[1].0, 42) + XCTAssertEqual(rows[1].1, "Bob") + } + + func testCopyIntoFromIsTerminatedByThrowingErrorFromClosure() async throws { + struct MyError: Error, CustomStringConvertible { + var description: String { "My error" } + } + + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 2) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + let conn = try await PostgresConnection.test(on: eventLoop).get() + defer { XCTAssertNoThrow(try conn.close().wait()) } + + _ = try? await conn.query("DROP TABLE copy_table", logger: .psqlTest).get() + _ = try await conn.query("CREATE TABLE copy_table (id INT, name VARCHAR(100))", logger: .psqlTest).get() + + do { + try await conn.copyFrom(table: "copy_table", columns: ["id", "name"], logger: .psqlTest) { writer in + throw MyError() + } + XCTFail("Expected error to be thrown") + } catch { + XCTAssert(error is MyError, "Expected error of type MyError, got \(String(reflecting: error))") + } + } + + + func testCopyIntoFromHasBadFormat() async throws { + struct MyError: Error, CustomStringConvertible { + var description: String { "My error" } + } + + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 2) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + let conn = try await PostgresConnection.test(on: eventLoop).get() + defer { XCTAssertNoThrow(try conn.close().wait()) } + + _ = try? await conn.query("DROP TABLE copy_table", logger: .psqlTest).get() + _ = try await conn.query("CREATE TABLE copy_table (id INT, name VARCHAR(100))", logger: .psqlTest).get() + + do { + try await conn.copyFrom(table: "copy_table", columns: ["id", "name"], logger: .psqlTest) { writer in + try await writer.write(ByteBuffer(staticString: "1Alice\n")) + } + XCTFail("Expected error to be thrown") + } catch { + XCTAssertEqual((error as? PSQLError)?.serverInfo?[.sqlState], "22P02") // invalid_text_representation + } + } + + func testSyntaxErrorInGeneratedQuery() async throws { + struct MyError: Error, CustomStringConvertible { + var description: String { "My error" } + } + + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 2) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + let conn = try await PostgresConnection.test(on: eventLoop).get() + defer { XCTAssertNoThrow(try conn.close().wait()) } + + do { + // Use some form of input that generates an invalid query, the exact manner of its invalidness doesn't matter + try await conn.copyFrom(table: "", logger: .psqlTest) { writer in + XCTFail("Did not expect to call writeData") + } + XCTFail("Expected error to be thrown") + } catch { + XCTAssertEqual((error as? PSQLError)?.serverInfo?[.sqlState], "42601") // scanner_yyerror + } + } + #endif } diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift index 872664af..758f83e2 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift @@ -140,7 +140,7 @@ class ExtendedQueryStateMachineTests: XCTestCase { XCTAssertEqual(state.rowDescriptionReceived(.init(columns: input)), .wait) XCTAssertEqual(state.bindCompleteReceived(), .succeedQuery(promise, with: .init(value: .rowDescription(expected), logger: logger))) - XCTAssertEqual(state.cancelQueryStream(), .forwardStreamError(.queryCancelled, read: false, cleanupContext: nil)) + XCTAssertEqual(state.cancel(), .forwardStreamError(.queryCancelled, read: false, cleanupContext: nil)) XCTAssertEqual(state.dataRowReceived([ByteBuffer(string: "test1")]), .wait) XCTAssertEqual(state.channelReadComplete(), .wait) XCTAssertEqual(state.readEventCaught(), .read) @@ -188,7 +188,7 @@ class ExtendedQueryStateMachineTests: XCTestCase { XCTAssertEqual(state.dataRowReceived(row1), .wait) XCTAssertEqual(state.channelReadComplete(), .forwardRows([row1])) XCTAssertEqual(state.readEventCaught(), .wait) - XCTAssertEqual(state.cancelQueryStream(), .forwardStreamError(.queryCancelled, read: true, cleanupContext: nil)) + XCTAssertEqual(state.cancel(), .forwardStreamError(.queryCancelled, read: true, cleanupContext: nil)) XCTAssertEqual(state.dataRowReceived([ByteBuffer(string: "test2")]), .wait) XCTAssertEqual(state.dataRowReceived([ByteBuffer(string: "test3")]), .wait) @@ -287,7 +287,7 @@ class ExtendedQueryStateMachineTests: XCTestCase { XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendParseDescribeBindExecuteSync(query)) XCTAssertEqual(state.parseCompleteReceived(), .wait) XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) - XCTAssertEqual(state.cancelQueryStream(), .failQuery(promise, with: .queryCancelled, cleanupContext: .none)) + XCTAssertEqual(state.cancel(), .failQuery(promise, with: .queryCancelled, cleanupContext: .none)) let serverError = PostgresBackendMessage.ErrorResponse(fields: [.severity: "Error", .sqlState: "123"]) XCTAssertEqual(state.errorReceived(serverError), .wait) diff --git a/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift b/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift index 5fc8144b..d2dd5681 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift @@ -323,3 +323,125 @@ extension PostgresFrontendMessage { } } } + +/// Convenience accessors to get a specific case or `nil` if the enum is of a different case. +extension PostgresFrontendMessage { + var bind: Bind? { + guard case .bind(let bind) = self else { + return nil + } + return bind + } + + var cancel: Cancel? { + guard case .cancel(let cancel) = self else { + return nil + } + return cancel + } + + var copyData: CopyData? { + guard case .copyData(let copyData) = self else { + return nil + } + return copyData + } + + var copyDone: Void? { + guard case .copyDone = self else { + return nil + } + return () + } + + var copyFail: CopyFail? { + guard case .copyFail(let copyFail) = self else { + return nil + } + return copyFail + } + + var close: Close? { + guard case .close(let close) = self else { + return nil + } + return close + } + + var describe: Describe? { + guard case .describe(let describe) = self else { + return nil + } + return describe + } + + var execute: Execute? { + guard case .execute(let execute) = self else { + return nil + } + return execute + } + + var flush: Void? { + guard case .flush = self else { + return nil + } + return () + } + + var parse: Parse? { + guard case .parse(let parse) = self else { + return nil + } + return parse + } + + var password: Password? { + guard case .password(let password) = self else { + return nil + } + return password + } + + var saslInitialResponse: SASLInitialResponse? { + guard case .saslInitialResponse(let saslInitialResponse) = self else { + return nil + } + return saslInitialResponse + } + + var saslResponse: SASLResponse? { + guard case .saslResponse(let saslResponse) = self else { + return nil + } + return saslResponse + } + + var sslRequest: Void? { + guard case .sslRequest = self else { + return nil + } + return () + } + + var sync: Void? { + guard case .sync = self else { + return nil + } + return () + } + + var startup: Startup? { + guard case .startup(let startup) = self else { + return nil + } + return startup + } + + var terminate: Void? { + guard case .terminate = self else { + return nil + } + return () + } +} diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index d0f8e2b0..4e43ed38 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -4,6 +4,9 @@ import NIOEmbedded import XCTest import Logging @testable import PostgresNIO +#if canImport(Synchronization) +import Synchronization +#endif class PostgresConnectionTests: XCTestCase { @@ -70,8 +73,8 @@ class PostgresConnectionTests: XCTestCase { }() async let connectionPromise = PostgresConnection.connect(on: eventLoop, configuration: configuration, id: 1, logger: .psqlTest) - let message = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) - XCTAssertEqual(message, .startup(.versionThree(parameters: .init(user: "username", database: "database", options: configuration.options.additionalStartupParameters, replication: .false)))) + let message = try await channel.waitForPostgresFrontendMessage(\.startup) + XCTAssertEqual(message, .versionThree(parameters: .init(user: "username", database: "database", options: configuration.options.additionalStartupParameters, replication: .false))) try await channel.writeInbound(PostgresBackendMessage.authentication(.ok)) try await channel.writeInbound(PostgresBackendMessage.backendKeyData(.init(processID: 1234, secretKey: 5678))) try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) @@ -95,10 +98,7 @@ class PostgresConnectionTests: XCTestCase { let listenMessage = try await channel.waitForUnpreparedRequest() XCTAssertEqual(listenMessage.parse.query, #"LISTEN "foo";"#) - try await channel.writeInbound(PostgresBackendMessage.parseComplete) - try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) - try await channel.writeInbound(PostgresBackendMessage.noData) - try await channel.writeInbound(PostgresBackendMessage.bindComplete) + try await channel.sendUnpreparedRequestWithNoParametersBindResponse() try await channel.writeInbound(PostgresBackendMessage.commandComplete("LISTEN")) try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) @@ -107,10 +107,7 @@ class PostgresConnectionTests: XCTestCase { let unlistenMessage = try await channel.waitForUnpreparedRequest() XCTAssertEqual(unlistenMessage.parse.query, #"UNLISTEN "foo";"#) - try await channel.writeInbound(PostgresBackendMessage.parseComplete) - try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) - try await channel.writeInbound(PostgresBackendMessage.noData) - try await channel.writeInbound(PostgresBackendMessage.bindComplete) + try await channel.sendUnpreparedRequestWithNoParametersBindResponse() try await channel.writeInbound(PostgresBackendMessage.commandComplete("UNLISTEN")) try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) @@ -155,10 +152,7 @@ class PostgresConnectionTests: XCTestCase { let listenMessage = try await channel.waitForUnpreparedRequest() XCTAssertEqual(listenMessage.parse.query, #"LISTEN "foo";"#) - try await channel.writeInbound(PostgresBackendMessage.parseComplete) - try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) - try await channel.writeInbound(PostgresBackendMessage.noData) - try await channel.writeInbound(PostgresBackendMessage.bindComplete) + try await channel.sendUnpreparedRequestWithNoParametersBindResponse() try await channel.writeInbound(PostgresBackendMessage.commandComplete("LISTEN")) try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) @@ -168,10 +162,7 @@ class PostgresConnectionTests: XCTestCase { let unlistenMessage = try await channel.waitForUnpreparedRequest() XCTAssertEqual(unlistenMessage.parse.query, #"UNLISTEN "foo";"#) - try await channel.writeInbound(PostgresBackendMessage.parseComplete) - try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) - try await channel.writeInbound(PostgresBackendMessage.noData) - try await channel.writeInbound(PostgresBackendMessage.bindComplete) + try await channel.sendUnpreparedRequestWithNoParametersBindResponse() try await channel.writeInbound(PostgresBackendMessage.commandComplete("UNLISTEN")) try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) @@ -204,10 +195,7 @@ class PostgresConnectionTests: XCTestCase { let listenMessage = try await channel.waitForUnpreparedRequest() XCTAssertEqual(listenMessage.parse.query, #"LISTEN "foo";"#) - try await channel.writeInbound(PostgresBackendMessage.parseComplete) - try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) - try await channel.writeInbound(PostgresBackendMessage.noData) - try await channel.writeInbound(PostgresBackendMessage.bindComplete) + try await channel.sendUnpreparedRequestWithNoParametersBindResponse() try await channel.writeInbound(PostgresBackendMessage.commandComplete("LISTEN")) try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) @@ -267,8 +255,7 @@ class PostgresConnectionTests: XCTestCase { try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) } - let terminate = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) - XCTAssertEqual(terminate, .terminate) + try await channel.waitForPostgresFrontendMessage(\.terminate) try await channel.closeFuture.get() XCTAssertEqual(channel.isActive, false) @@ -283,7 +270,7 @@ class PostgresConnectionTests: XCTestCase { } } - func testCloseClosesImmediatly() async throws { + func testCloseClosesImmediately() async throws { let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup async throws -> () in @@ -638,6 +625,295 @@ class PostgresConnectionTests: XCTestCase { } } + #if compiler(>=6.0) + func testCopyFromSucceeds() async throws { + try await assertCopyFrom { writer in + try await writer.write(ByteBuffer(staticString: "1\tAlice\n")) + } validateCopyRequest: { copyRequest in + XCTAssertEqual(copyRequest.parse.query, #"COPY "copy_table"("id","name") FROM STDIN WITH (FORMAT text)"#) + XCTAssertEqual(copyRequest.bind.parameters, []) + } mockBackend: { channel, _ in + let data = try await channel.waitForCopyData() + XCTAssertEqual(String(buffer: data.data), "1\tAlice\n") + XCTAssertEqual(data.result, .done) + try await channel.writeInbound(PostgresBackendMessage.commandComplete("COPY 1")) + } + } + + + func testCopyFromWithOptions() async throws { + var options = PostgresCopyFromFormat.TextOptions() + options.delimiter = "," + try await assertCopyFrom(format: .text(options)) { writer in + try await writer.write(ByteBuffer(staticString: "1,Alice\n")) + } validateCopyRequest: { copyRequest in + XCTAssertEqual(copyRequest.parse.query, #"COPY "copy_table"("id","name") FROM STDIN WITH (FORMAT text,DELIMITER U&'\002c')"#) + XCTAssertEqual(copyRequest.bind.parameters, []) + } mockBackend: { channel, _ in + let data = try await channel.waitForCopyData() + XCTAssertEqual(String(buffer: data.data), "1,Alice\n") + XCTAssertEqual(data.result, .done) + try await channel.writeInbound(PostgresBackendMessage.commandComplete("COPY 1")) + } + } + + func testCopyFromWriterFails() async throws { + struct MyError: Error {} + + try await assertCopyFrom { writer in + throw MyError() + } validateCopyFromError: { error in + XCTAssert(error is MyError, "Expected error of type MyError, got \(error)") + } mockBackend: { channel, _ in + let data = try await channel.waitForCopyData() + XCTAssertEqual(data.result, .failed(message: "Client failed copy")) + try await channel.writeInbound(PostgresBackendMessage.error(.init(fields: [ + .message: "COPY from stdin failed: Client failed copy", + .sqlState : "57014" // query_canceled + ]))) + } + } + + func testCopyFromBackendSendsErrorBeforeCopyDone() async throws { + // Stream to indicate that the backend did send an error + let (signalStream, signalContinuation) = AsyncStream.makeStream(of: Void.self) + + try await assertCopyFrom { writer in + try await writer.write(ByteBuffer(staticString: "1Alice\n")) + var iterator = signalStream.makeAsyncIterator() + await iterator.next() + } validateCopyFromError: { error in + XCTAssertEqual((error as? PSQLError)?.serverInfo?[.sqlState], "22P02") + } mockBackend: { channel, _ in + let copyDataMessage = try await channel.waitForPostgresFrontendMessage(\.copyData) + XCTAssertEqual(copyDataMessage, PostgresFrontendMessage.CopyData(data: ByteBuffer(staticString: "1Alice\n"))) + + try await channel.writeInbound(PostgresBackendMessage.error(.init(fields: [ + .message: #"invalid input syntax for type integer: "1Alice""#, + .sqlState : "22P02" // invalid_text_representation + ]))) + signalContinuation.yield() + } + } + + func testCopyFromBackendSendsErrorAfterCopyDone() async throws { + try await assertCopyFrom { writer in + try await writer.write(ByteBuffer(staticString: "1Alice\n")) + } validateCopyFromError: { error in + XCTAssertEqual((error as? PSQLError)?.serverInfo?[.sqlState], "22P02") + } mockBackend: { channel, _ in + _ = try await channel.waitForCopyData() + try await channel.writeInbound(PostgresBackendMessage.error(.init(fields: [ + .message: #"invalid input syntax for type integer: "1Alice""#, + .sqlState : "22P02" // invalid_text_representation + ]))) + } + } + + func testCopyFromBackendSendsErrorBeforeUserThrowsUnrelatedErrorFromClosure() async throws { + struct MyError: Error, CustomStringConvertible { + var description: String { "My error" } + } + + // Stream to indicate that the backend did send an error + let (signalStream, signalContinuation) = AsyncStream.makeStream(of: Void.self) + + // If the user throws an error and we receive an error from the server, we should prefer throwing the user error + // from `copyFrom` since it's likely the more actionable for the user. + try await assertCopyFrom { writer in + try await writer.write(ByteBuffer(staticString: "1Alice\n")) + var iterator = signalStream.makeAsyncIterator() + await iterator.next() + throw MyError() + } validateCopyFromError: { error in + XCTAssert(error is MyError, "Expected MyError, got \(error)") + } mockBackend: { channel, _ in + let copyDataMessage = try await channel.waitForPostgresFrontendMessage(\.copyData) + XCTAssertEqual(copyDataMessage, PostgresFrontendMessage.CopyData(data: ByteBuffer(staticString: "1Alice\n"))) + + try await channel.writeInbound(PostgresBackendMessage.error(.init(fields: [ + .message: #"invalid input syntax for type integer: "1Alice""#, + .sqlState : "22P02" // invalid_text_representation + ]))) + signalContinuation.yield() + } + } + + func testCopyFromWriterThrowsErrorAfterBackendSentError() async throws { + // Stream to indicate that the backend did send an error + let (signalStream, signalContinuation) = AsyncStream.makeStream(of: Void.self) + + try await assertCopyFrom { writer in + try await writer.write(ByteBuffer(staticString: "1Alice\n")) + var iterator = signalStream.makeAsyncIterator() + await iterator.next() + do { + try await writer.write(ByteBuffer(staticString: "2\tBob\n")) + XCTFail("Expected error to be thrown") + } catch { + XCTAssertEqual((error as? PSQLError)?.serverInfo?[.sqlState], "22P02") + throw error + } + } validateCopyFromError: { error in + XCTAssertEqual((error as? PSQLError)?.serverInfo?[.sqlState], "22P02") + } mockBackend: { channel, _ in + let dataMessage = try await channel.waitForPostgresFrontendMessage(\.copyData) + XCTAssertEqual(dataMessage, PostgresFrontendMessage.CopyData(data: ByteBuffer(staticString: "1Alice\n"))) + + try await channel.writeInbound(PostgresBackendMessage.error(.init(fields: [ + .message: #"invalid input syntax for type integer: "1Alice""#, + .sqlState : "22P02" // invalid_text_representation + ]))) + signalContinuation.yield() + } + } + + func testCopyFromCallerDoesNotRethrowFromWriteCall() async throws { + struct MyError: Error, CustomStringConvertible { + var description: String { "My error" } + } + + // Stream to indicate that the backend did send an error + let (signalStream, signalContinuation) = AsyncStream.makeStream(of: Void.self) + + try await assertCopyFrom { writer in + try await writer.write(ByteBuffer(staticString: "1Alice\n")) + var iterator = signalStream.makeAsyncIterator() + await iterator.next() + do { + try await writer.write(ByteBuffer(staticString: "2\tBob\n")) + XCTFail("Expected error to be thrown") + } catch { + XCTAssertEqual((error as? PSQLError)?.serverInfo?[.sqlState], "22P02") + throw MyError() + } + } validateCopyFromError: { error in + XCTAssert(error is MyError, "Expected MyError, got \(error)") + } mockBackend: { channel, _ in + let dataMessage = try await channel.waitForPostgresFrontendMessage(\.copyData) + XCTAssertEqual(dataMessage, PostgresFrontendMessage.CopyData(data: ByteBuffer(staticString: "1Alice\n"))) + + try await channel.writeInbound(PostgresBackendMessage.error(.init(fields: [ + .message: #"invalid input syntax for type integer: "1Alice""#, + .sqlState : "22P02" // invalid_text_representation + ]))) + signalContinuation.yield() + } + } + + func testCopyFromQueryHasSyntaxError() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + taskGroup.addTask { + do { + try await connection.copyFrom(table: "", logger: .psqlTest) { _ in + XCTFail("Did not expect to call writeData") + } + + } catch { + XCTAssertEqual((error as? PSQLError)?.serverInfo?[.sqlState], "42601") + } + // Send another query to ensure that the state machine is back in the idle state afterwards and can + // handle new queries. We don't wait for this to finish, just to receive the initiation on the other + // side of the + _ = connection.simpleQuery("DUMMY") + } + + _ = try await channel.waitForUnpreparedRequest() + try await channel.writeInbound(PostgresBackendMessage.error(.init(fields: [ + .message: #"syntax error"#, + .sqlState : "42601" // scanner_yyerror + ]))) + + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + + _ = try await channel.waitForUnpreparedRequest() // Await the dummy query messages + } + } + + func testCopyFromHasWriteBackpressure() async throws { + #if !canImport(Synchronization) + throw XCTSkip("Test uses Synchronization which is not available") + #else + guard #available(macOS 15, *) else { + throw XCTSkip("Test uses Atomic which is not available") + } + // `true` while the `writeData` closure is executing the `PostgresCopyFromWriter.write` function, ie. while it + // is blocked for backpressure to be relieved. + let isWriting = Atomic(false) + + try await assertCopyFrom { writer in + isWriting.store(true, ordering: .sequentiallyConsistent) + try await writer.write(ByteBuffer(staticString: "1\tAlice\n")) + isWriting.store(false, ordering: .sequentiallyConsistent) + } preCopyInResponse: { channel in + channel.isWritable = false + } mockBackend: { channel, _ in + XCTAssert(isWriting.load(ordering: .sequentiallyConsistent)) + + channel.isWritable = true + channel.pipeline.fireChannelWritabilityChanged() + + let data = try await channel.waitForCopyData() + XCTAssertEqual(data.data, ByteBuffer(staticString: "1\tAlice\n")) + try await channel.writeInbound(PostgresBackendMessage.commandComplete("COPY 1")) + } + #endif + } + + func testCopyFromCancelled() async throws { + try await assertCopyFrom { writer in + while true { + try await writer.write(ByteBuffer(staticString: "1\tAlice\n")) + try await Task.sleep(for: .milliseconds(10)) + } + } validateCopyFromError: { error in + XCTAssert(error is CancellationError, "Expected CancellationError, got \(error)") + } mockBackend: { channel, cancelCopy in + cancelCopy() + + let data = try await channel.waitForCopyData() + XCTAssertEqual(data.result, .failed(message: "Client failed copy")) + + try await channel.writeInbound(PostgresBackendMessage.error(.init(fields: [ + .message: "COPY from stdin failed: Client failed copy", + .sqlState : "57014" // query_canceled + ]))) + } + } + + func testCopyFromCancelledWhileWaitingForBackpressureRelieve() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + taskGroup.addTask { + do { + try await connection.copyFrom(table: "test", logger: .psqlTest) { writer in + try await writer.write(ByteBuffer(staticString: "1\tAlice\n")) + } + XCTFail("Expected `copyFrom` to throw but it did not") + } catch { + XCTAssert(error is CancellationError, "Expected CancellationError, got \(error)") + } + } + + _ = try await channel.waitForUnpreparedRequest() + + try await channel.sendUnpreparedRequestWithNoParametersBindResponse() + channel.isWritable = false + try await channel.writeInbound(PostgresBackendMessage.copyInResponse(.init(format: .textual, columnFormats: Array(repeating: .textual, count: 2)))) + + // Wait for the `PostgresCopyFromWriter.write` call to execute and hit the write backpressure before we cancel the task. + try await Task.sleep(for: .milliseconds(200)) + taskGroup.cancelAll() + + // Check that the connection got closed because of the cancellation. + try await connection.closeFuture.get() + } + } + #endif + func makeTestConnectionWithAsyncTestingChannel() async throws -> (PostgresConnection, NIOAsyncTestingChannel) { let eventLoop = NIOAsyncTestingEventLoop() let channel = try await NIOAsyncTestingChannel(loop: eventLoop) { channel in @@ -655,8 +931,8 @@ class PostgresConnectionTests: XCTestCase { let logger = self.logger async let connectionPromise = PostgresConnection.connect(on: eventLoop, configuration: configuration, id: 1, logger: logger) - let message = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) - XCTAssertEqual(message, .startup(.versionThree(parameters: .init(user: "username", database: "database", options: [], replication: .false)))) + let message = try await channel.waitForPostgresFrontendMessage(\.startup) + XCTAssertEqual(message, .versionThree(parameters: .init(user: "username", database: "database", options: [], replication: .false))) try await channel.writeInbound(PostgresBackendMessage.authentication(.ok)) try await channel.writeInbound(PostgresBackendMessage.backendKeyData(.init(processID: 1234, secretKey: 5678))) try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) @@ -669,40 +945,143 @@ class PostgresConnectionTests: XCTestCase { return (connection, channel) } + + #if compiler(>=6.0) + /// Validate the behavior of a `COPY FROM` query. + /// + /// Also checks that the connection returns to an idle state after performing the copy and is capable + /// of handling another query. + /// + /// - Parameters: + /// - table: Forwarded to the `copyFrom` call in `PostgresConnection`. + /// - columns: Forwarded to the `copyFrom` call in `PostgresConnection`. + /// - format: Forwarded to the `copyFrom` call in `PostgresConnection`. + /// - writeData: Forwarded to the `copyFrom` call in `PostgresConnection`. + /// - validateCopyFromError: When not `nil`, we expect the `copyFrom` call to throw. This closure can be used to + /// inspect the thrown error and assert that it has the correct shape. + /// - preCopyInResponse: Called before the `CopyInResponse` is sent to the frontend. + /// - validateCopyRequest: Can be used to verify the shape of the `COPY` query that is received by the backend. + /// - mockBackend: determines how the backend behaves, starting after the point where the backend has sent the + /// `CopyInResponse` and ending in the state where the backend has sent a `CommandComplete` or `ErrorResponse` + /// and is now expecting a `Sync` to return back to the idle state. The closure may call the `cancelCopyFrom` + /// closure that is passed to it to cancel the COPY operation. + private func assertCopyFrom( + table: StaticString = "copy_table", + columns: [StaticString] = ["id", "name"], + format: PostgresCopyFromFormat = .text(.init()), + writeData: @escaping @Sendable (PostgresCopyFromWriter) async throws -> Void, + validateCopyFromError: (@Sendable (any Error) -> Void)? = nil, + preCopyInResponse: (_ channel: NIOAsyncTestingChannel) -> Void = { _ in }, + validateCopyRequest: (UnpreparedRequest) -> Void = { _ in }, + mockBackend: (_ channel: NIOAsyncTestingChannel, _ cancelCopy: () -> Void) async throws -> Void, + file: StaticString = #file, + line: UInt = #line + ) async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + taskGroup.addTask { + do { + try await connection.copyFrom(table: table, columns: columns, format: format, logger: .psqlTest, writeData: writeData) + if validateCopyFromError != nil { + XCTFail("Expected `copyFrom` to throw but it did not") + } + } catch { + if let validateCopyFromError { + validateCopyFromError(error) + } else { + throw error + } + } + // Send another query to ensure that the state machine is back in the idle state afterwards and can + // handle new queries. We don't wait for this to finish, just to receive the initiation on the other + // side of the + _ = connection.simpleQuery("DUMMY") + } + + let copyRequest = try await channel.waitForUnpreparedRequest() + validateCopyRequest(copyRequest) + + try await channel.sendUnpreparedRequestWithNoParametersBindResponse() + preCopyInResponse(channel) + try await channel.writeInbound(PostgresBackendMessage.copyInResponse(.init(format: .textual, columnFormats: Array(repeating: .textual, count: columns.count)))) + + try await mockBackend(channel, { taskGroup.cancelAll() }) + + try await channel.waitForPostgresFrontendMessage(\.sync) + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + + _ = try await channel.waitForUnpreparedRequest() // Await the dummy query messages + } + } + #endif } extension NIOAsyncTestingChannel { + /// Wait for a `PostgresFrontendMessage` such that `transform` returns a non-nil value. + /// + /// The intention of this is to be used with the convenience accessors on `PostgresFrontendMessage` for the + /// different cases, eg. to wait for a `parse` message + /// + /// ```swift + /// try await self.waitForPostgresFrontendMessage(\.parse) + /// ``` + func waitForPostgresFrontendMessage( + _ transform: (PostgresFrontendMessage) -> T?, + file: StaticString = #file, + line: UInt = #line + ) async throws -> T { + let message = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + let payload = try XCTUnwrap(transform(message), "Received unexpected payload: \(message)", file: file, line: line) + return payload + } func waitForUnpreparedRequest() async throws -> UnpreparedRequest { - let parse = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) - let describe = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) - let bind = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) - let execute = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) - let sync = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) - - guard case .parse(let parse) = parse, - case .describe(let describe) = describe, - case .bind(let bind) = bind, - case .execute(let execute) = execute, - case .sync = sync - else { - fatalError() - } + let parse = try await self.waitForPostgresFrontendMessage(\.parse) + let describe = try await self.waitForPostgresFrontendMessage(\.describe) + let bind = try await self.waitForPostgresFrontendMessage(\.bind) + let execute = try await self.waitForPostgresFrontendMessage(\.execute) + try await self.waitForPostgresFrontendMessage(\.sync) return UnpreparedRequest(parse: parse, describe: describe, bind: bind, execute: execute) } - func waitForPrepareRequest() async throws -> PrepareRequest { - let parse = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) - let describe = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) - let sync = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + struct CopyDataRequest { + enum Result: Equatable { + /// The data copy finished successfully with a `CopyDone` message. + case done + /// The data copy finished with a `CopyFail` message containing the following error message. + case failed(message: String) + } + + /// The data that was transferred. + var data: ByteBuffer + + /// The `CopyDone` or `CopyFail` message that finalized the data transfer. + var result: Result + } - guard case .parse(let parse) = parse, - case .describe(let describe) = describe, - case .sync = sync - else { - fatalError("Unexpected message") + func waitForCopyData() async throws -> CopyDataRequest { + var copiedData = ByteBuffer() + while true { + let message = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + switch message { + case .copyData(let data): + copiedData.writeImmutableBuffer(data.data) + case .copyDone: + return CopyDataRequest(data: copiedData, result: .done) + case .copyFail(let message): + return CopyDataRequest(data: copiedData, result: .failed(message: message.message)) + default: + fatalError("Unexpected message") + } } + } + + func waitForPrepareRequest() async throws -> PrepareRequest { + let parse = try await self.waitForPostgresFrontendMessage(\.parse) + let describe = try await self.waitForPostgresFrontendMessage(\.describe) + try await self.waitForPostgresFrontendMessage(\.sync) return PrepareRequest(parse: parse, describe: describe) } @@ -722,16 +1101,9 @@ extension NIOAsyncTestingChannel { } func waitForPreparedRequest() async throws -> PreparedRequest { - let bind = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) - let execute = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) - let sync = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) - - guard case .bind(let bind) = bind, - case .execute(let execute) = execute, - case .sync = sync - else { - fatalError() - } + let bind = try await self.waitForPostgresFrontendMessage(\.bind) + let execute = try await self.waitForPostgresFrontendMessage(\.execute) + try await self.waitForPostgresFrontendMessage(\.sync) return PreparedRequest(bind: bind, execute: execute) } @@ -751,6 +1123,14 @@ extension NIOAsyncTestingChannel { try await self.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) try await self.testingEventLoop.executeInContext { self.read() } } + + /// Send the messages up to `BindComplete` for an unnamed query that does not bind any parameters. + func sendUnpreparedRequestWithNoParametersBindResponse() async throws { + try await writeInbound(PostgresBackendMessage.parseComplete) + try await writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) + try await writeInbound(PostgresBackendMessage.noData) + try await writeInbound(PostgresBackendMessage.bindComplete) + } } struct UnpreparedRequest {