From 76034a9ee2576ee9227b180e9a3d475b186a3259 Mon Sep 17 00:00:00 2001 From: Alex Hoppen Date: Mon, 7 Jul 2025 12:36:32 +0200 Subject: [PATCH] Add message types to support COPY operations This adds the infrastrucutre to decode messages needed for COPY operations. It does not implement the handling support itself yet. That will be added in a follow-up PR. --- .../ConnectionStateMachine.swift | 6 + .../ExtendedQueryStateMachine.swift | 8 +- .../New/Messages/CopyInMessage.swift | 44 ++++++ .../New/PostgresBackendMessage.swift | 12 +- .../New/PostgresBackendMessageDecoder.swift | 6 + .../New/PostgresChannelHandler.swift | 2 + .../New/PostgresFrontendMessageEncoder.swift | 25 +++ .../ExtendedQueryStateMachineTests.swift | 2 +- .../PSQLBackendMessageEncoder.swift | 14 ++ .../PSQLFrontendMessageDecoder.swift | 12 ++ .../Extensions/PostgresFrontendMessage.swift | 32 ++++ .../New/Messages/CopyTests.swift | 147 ++++++++++++++++++ 12 files changed, 305 insertions(+), 5 deletions(-) create mode 100644 Sources/PostgresNIO/New/Messages/CopyInMessage.swift create mode 100644 Tests/PostgresNIOTests/New/Messages/CopyTests.swift diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index 9d264bcc..8560b948 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -752,6 +752,12 @@ struct ConnectionStateMachine { return self.modify(with: action) } + mutating func copyInResponseReceived( + _ copyInResponse: PostgresBackendMessage.CopyInResponse + ) -> ConnectionAction { + return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.copyInResponse(copyInResponse))) + } + mutating func emptyQueryResponseReceived() -> ConnectionAction { guard case .extendedQuery(var queryState, let connectionContext) = self.state, !queryState.isComplete else { return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.emptyQueryResponse)) diff --git a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift index 087a6c24..5708b6b9 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift @@ -91,7 +91,7 @@ struct ExtendedQueryStateMachine { mutating func cancel() -> Action { switch self.state { case .initialized: - preconditionFailure("Start must be called immediatly after the query was created") + preconditionFailure("Start must be called immediately after the query was created") case .messagesSent(let queryContext), .parseCompleteReceived(let queryContext), @@ -322,6 +322,12 @@ struct ExtendedQueryStateMachine { } } + mutating func copyInResponseReceived( + _ copyInResponse: PostgresBackendMessage.CopyInResponse + ) -> Action { + return self.setAndFireError(.unexpectedBackendMessage(.copyInResponse(copyInResponse))) + } + mutating func emptyQueryResponseReceived() -> Action { guard case .bindCompleteReceived(let queryContext) = self.state else { return self.setAndFireError(.unexpectedBackendMessage(.emptyQueryResponse)) diff --git a/Sources/PostgresNIO/New/Messages/CopyInMessage.swift b/Sources/PostgresNIO/New/Messages/CopyInMessage.swift new file mode 100644 index 00000000..46dec648 --- /dev/null +++ b/Sources/PostgresNIO/New/Messages/CopyInMessage.swift @@ -0,0 +1,44 @@ +extension PostgresBackendMessage { + struct CopyInResponse: Hashable { + enum Format: Int8 { + case textual = 0 + case binary = 1 + } + + let format: Format + let columnFormats: [Format] + + static func decode(from buffer: inout ByteBuffer) throws -> Self { + guard let rawFormat = buffer.readInteger(endianness: .big, as: Int8.self) else { + throw PSQLPartialDecodingError.expectedAtLeastNRemainingBytes(1, actual: buffer.readableBytes) + } + guard let format = Format(rawValue: rawFormat) else { + throw PSQLPartialDecodingError.unexpectedValue(value: rawFormat) + } + + guard let numColumns = buffer.readInteger(endianness: .big, as: Int16.self) else { + throw PSQLPartialDecodingError.expectedAtLeastNRemainingBytes(2, actual: buffer.readableBytes) + } + var columnFormatCodes: [Format] = [] + columnFormatCodes.reserveCapacity(Int(numColumns)) + + for _ in 0.. Self { + return PSQLPartialDecodingError( + description: "Unknown message kind: \(messageID)", + file: file, line: line) + } } extension ByteBuffer { diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index 0a14849a..baf801e5 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -136,6 +136,8 @@ final class PostgresChannelHandler: ChannelDuplexHandler { action = self.state.closeCompletedReceived() case .commandComplete(let commandTag): action = self.state.commandCompletedReceived(commandTag) + case .copyInResponse(let copyInResponse): + action = self.state.copyInResponseReceived(copyInResponse) case .dataRow(let dataRow): action = self.state.dataRowReceived(dataRow) case .emptyQueryResponse: diff --git a/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift b/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift index 97805418..6ca4cc27 100644 --- a/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift +++ b/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift @@ -167,6 +167,28 @@ struct PostgresFrontendMessageEncoder { self.buffer.writeMultipleIntegers(UInt32(8), Self.sslRequestCode) } + /// Adds the `CopyData` message ID and `dataLength` to the message buffer but not the actual data. + /// + /// The caller of this function is expected to write the encoder's message buffer to the backend after calling this + /// function, followed by sending the actual data to the backend. + mutating func copyDataHeader(dataLength: UInt32) { + self.clearIfNeeded() + self.buffer.psqlWriteMultipleIntegers(id: .copyData, length: dataLength) + } + + mutating func copyDone() { + self.clearIfNeeded() + self.buffer.psqlWriteMultipleIntegers(id: .copyDone, length: 0) + } + + mutating func copyFail(message: String) { + self.clearIfNeeded() + var messageBuffer = ByteBuffer() + messageBuffer.writeNullTerminatedString(message) + self.buffer.psqlWriteMultipleIntegers(id: .copyFail, length: UInt32(messageBuffer.readableBytes)) + self.buffer.writeImmutableBuffer(messageBuffer) + } + mutating func sync() { self.clearIfNeeded() self.buffer.psqlWriteMultipleIntegers(id: .sync, length: 0) @@ -197,6 +219,9 @@ struct PostgresFrontendMessageEncoder { private enum FrontendMessageID: UInt8, Hashable, Sendable { case bind = 66 // B case close = 67 // C + case copyData = 100 // d + case copyDone = 99 // c + case copyFail = 102 // f case describe = 68 // D case execute = 69 // E case flush = 72 // H diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift index ae484acc..872664af 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift @@ -114,7 +114,7 @@ class ExtendedQueryStateMachineTests: XCTestCase { .failQuery(promise, with: psqlError, cleanupContext: .init(action: .close, tasks: [], error: psqlError, closePromise: nil))) } - func testExtendedQueryIsCancelledImmediatly() { + func testExtendedQueryIsCancelledImmediately() { var state = ConnectionStateMachine.readyForQuery() let logger = Logger.psqlTest diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift index 9614bf1e..0c6b37ef 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift @@ -28,6 +28,8 @@ struct PSQLBackendMessageEncoder: MessageToByteEncoder { case .commandComplete(let string): self.encode(messageID: message.id, payload: StringPayload(string), into: &buffer) + case .copyInResponse(let copyInResponse): + self.encode(messageID: message.id, payload: copyInResponse, into: &buffer) case .dataRow(let row): self.encode(messageID: message.id, payload: row, into: &buffer) @@ -99,6 +101,8 @@ extension PostgresBackendMessage { return .closeComplete case .commandComplete: return .commandComplete + case .copyInResponse: + return .copyInResponse case .dataRow: return .dataRow case .emptyQueryResponse: @@ -184,6 +188,16 @@ extension PostgresBackendMessage.BackendKeyData: PSQLMessagePayloadEncodable { } } +extension PostgresBackendMessage.CopyInResponse: PSQLMessagePayloadEncodable { + public func encode(into buffer: inout ByteBuffer) { + buffer.writeInteger(Int8(self.format.rawValue)) + buffer.writeInteger(Int16(self.columnFormats.count)) + for columnFormat in columnFormats { + buffer.writeInteger(Int16(columnFormat.rawValue)) + } + } +} + extension DataRow: PSQLMessagePayloadEncodable { public func encode(into buffer: inout ByteBuffer) { buffer.writeInteger(self.columnCount, as: Int16.self) diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift index 55ccd0a9..d913da22 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift @@ -168,6 +168,18 @@ extension PostgresFrontendMessage { ) ) + case .copyData: + return .copyData(CopyData(data: buffer)) + + case .copyDone: + return .copyDone + + case .copyFail: + guard let message = buffer.readNullTerminatedString() else { + throw PSQLPartialDecodingError.fieldNotDecodable(type: String.self) + } + return .copyFail(CopyFail(message: message)) + case .close: preconditionFailure("TODO: Unimplemented") diff --git a/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift b/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift index 2532959a..5fc8144b 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift @@ -36,6 +36,14 @@ enum PostgresFrontendMessage: Equatable { let secretKey: Int32 } + struct CopyData: Hashable { + var data: ByteBuffer + } + + struct CopyFail: Hashable { + var message: String + } + enum Close: Hashable { case preparedStatement(String) case portal(String) @@ -170,6 +178,9 @@ enum PostgresFrontendMessage: Equatable { case bind(Bind) case cancel(Cancel) + case copyData(CopyData) + case copyDone + case copyFail(CopyFail) case close(Close) case describe(Describe) case execute(Execute) @@ -186,6 +197,9 @@ enum PostgresFrontendMessage: Equatable { enum ID: UInt8, Equatable { case bind + case copyData + case copyDone + case copyFail case close case describe case execute @@ -201,12 +215,18 @@ enum PostgresFrontendMessage: Equatable { switch rawValue { case UInt8(ascii: "B"): self = .bind + case UInt8(ascii: "c"): + self = .copyDone case UInt8(ascii: "C"): self = .close + case UInt8(ascii: "d"): + self = .copyData case UInt8(ascii: "D"): self = .describe case UInt8(ascii: "E"): self = .execute + case UInt8(ascii: "f"): + self = .copyFail case UInt8(ascii: "H"): self = .flush case UInt8(ascii: "P"): @@ -230,6 +250,12 @@ enum PostgresFrontendMessage: Equatable { switch self { case .bind: return UInt8(ascii: "B") + case .copyData: + return UInt8(ascii: "d") + case .copyDone: + return UInt8(ascii: "c") + case .copyFail: + return UInt8(ascii: "f") case .close: return UInt8(ascii: "C") case .describe: @@ -263,6 +289,12 @@ extension PostgresFrontendMessage { return .bind case .cancel: preconditionFailure("Cancel messages don't have an identifier") + case .copyData: + return .copyData + case .copyDone: + return .copyDone + case .copyFail: + return .copyFail case .close: return .close case .describe: diff --git a/Tests/PostgresNIOTests/New/Messages/CopyTests.swift b/Tests/PostgresNIOTests/New/Messages/CopyTests.swift new file mode 100644 index 00000000..de686ae5 --- /dev/null +++ b/Tests/PostgresNIOTests/New/Messages/CopyTests.swift @@ -0,0 +1,147 @@ +import XCTest +import NIOCore +import NIOTestUtils +@testable import PostgresNIO + +class CopyTests: XCTestCase { + func testDecodeCopyInResponseMessage() throws { + let expected: [PostgresBackendMessage] = [ + .copyInResponse(.init(format: .textual, columnFormats: [.textual, .textual])), + .copyInResponse(.init(format: .binary, columnFormats: [.binary, .binary])), + .copyInResponse(.init(format: .binary, columnFormats: [.textual, .binary])) + ] + + var buffer = ByteBuffer() + + for message in expected { + guard case .copyInResponse(let message) = message else { + return XCTFail("Expected only to get copyInResponse here!") + } + buffer.writeBackendMessage(id: .copyInResponse ) { buffer in + buffer.writeInteger(Int8(message.format.rawValue)) + buffer.writeInteger(Int16(message.columnFormats.count)) + for columnFormat in message.columnFormats { + buffer.writeInteger(UInt16(columnFormat.rawValue)) + } + } + } + try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, expected)], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) } + ) + } + + func testDecodeFailureBecauseOfEmptyMessage() { + var buffer = ByteBuffer() + buffer.writeBackendMessage(id: .copyInResponse) { _ in} + + XCTAssertThrowsError( + try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, [])], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) } + ) + ) { + XCTAssert($0 is PostgresMessageDecodingError) + } + } + + + func testDecodeFailureBecauseOfInvalidFormat() { + var buffer = ByteBuffer() + buffer.writeBackendMessage(id: .copyInResponse) { buffer in + buffer.writeInteger(Int8(20)) // Only 0 and 1 are valid formats + } + + XCTAssertThrowsError( + try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, [])], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) } + ) + ) { + XCTAssert($0 is PostgresMessageDecodingError) + } + } + + func testDecodeFailureBecauseOfMissingColumnNumber() { + var buffer = ByteBuffer() + buffer.writeBackendMessage(id: .copyInResponse) { buffer in + buffer.writeInteger(Int8(0)) + } + + XCTAssertThrowsError( + try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, [])], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) } + ) + ) { + XCTAssert($0 is PostgresMessageDecodingError) + } + } + + + func testDecodeFailureBecauseOfMissingColumns() { + var buffer = ByteBuffer() + buffer.writeBackendMessage(id: .copyInResponse) { buffer in + buffer.writeInteger(Int8(0)) + buffer.writeInteger(Int16(20)) // 20 columns promised, none given + } + + XCTAssertThrowsError( + try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, [])], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) } + ) + ) { + XCTAssert($0 is PostgresMessageDecodingError) + } + } + + func testDecodeFailureBecauseOfInvalidColumnFormat() { + var buffer = ByteBuffer() + buffer.writeBackendMessage(id: .copyInResponse) { buffer in + buffer.writeInteger(Int8(0)) + buffer.writeInteger(Int16(1)) + buffer.writeInteger(Int8(20)) // Only 0 and 1 are valid formats + } + + XCTAssertThrowsError( + try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, [])], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) } + ) + ) { + XCTAssert($0 is PostgresMessageDecodingError) + } + } + + func testEncodeCopyDataHeader() { + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + encoder.copyDataHeader(dataLength: 3) + var byteBuffer = encoder.flushBuffer() + + XCTAssertEqual(byteBuffer.readableBytes, 5) + XCTAssertEqual(PostgresFrontendMessage.ID.copyData.rawValue, byteBuffer.readInteger(as: UInt8.self)) + XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), 7) + } + + func testEncodeCopyDone() { + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + encoder.copyDone() + var byteBuffer = encoder.flushBuffer() + + XCTAssertEqual(byteBuffer.readableBytes, 5) + XCTAssertEqual(PostgresFrontendMessage.ID.copyDone.rawValue, byteBuffer.readInteger(as: UInt8.self)) + XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), 4) + } + + func testEncodeCopyFail() { + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + encoder.copyFail(message: "Oh, no :(") + var byteBuffer = encoder.flushBuffer() + + XCTAssertEqual(byteBuffer.readableBytes, 15) + XCTAssertEqual(PostgresFrontendMessage.ID.copyFail.rawValue, byteBuffer.readInteger(as: UInt8.self)) + XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), 14) + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "Oh, no :(") + } +}