Skip to content

Add message types to support COPY operations #569

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,12 @@ struct ConnectionStateMachine {
return self.modify(with: action)
}

mutating func copyInResponseReceived(
_ copyInResponse: PostgresBackendMessage.CopyInResponse
) -> ConnectionAction {
return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.copyInResponse(copyInResponse)))
}

mutating func emptyQueryResponseReceived() -> ConnectionAction {
guard case .extendedQuery(var queryState, let connectionContext) = self.state, !queryState.isComplete else {
return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.emptyQueryResponse))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ struct ExtendedQueryStateMachine {
mutating func cancel() -> Action {
switch self.state {
case .initialized:
preconditionFailure("Start must be called immediatly after the query was created")
preconditionFailure("Start must be called immediately after the query was created")

case .messagesSent(let queryContext),
.parseCompleteReceived(let queryContext),
Expand Down Expand Up @@ -322,6 +322,12 @@ struct ExtendedQueryStateMachine {
}
}

mutating func copyInResponseReceived(
_ copyInResponse: PostgresBackendMessage.CopyInResponse
) -> Action {
return self.setAndFireError(.unexpectedBackendMessage(.copyInResponse(copyInResponse)))
}

mutating func emptyQueryResponseReceived() -> Action {
guard case .bindCompleteReceived(let queryContext) = self.state else {
return self.setAndFireError(.unexpectedBackendMessage(.emptyQueryResponse))
Expand Down
44 changes: 44 additions & 0 deletions Sources/PostgresNIO/New/Messages/CopyInMessage.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
extension PostgresBackendMessage {
struct CopyInResponse: Hashable {
enum Format: Int8 {
case textual = 0
case binary = 1
}

let format: Format
let columnFormats: [Format]

static func decode(from buffer: inout ByteBuffer) throws -> Self {
guard let rawFormat = buffer.readInteger(endianness: .big, as: Int8.self) else {
throw PSQLPartialDecodingError.expectedAtLeastNRemainingBytes(1, actual: buffer.readableBytes)
}
guard let format = Format(rawValue: rawFormat) else {
throw PSQLPartialDecodingError.unexpectedValue(value: rawFormat)
}

guard let numColumns = buffer.readInteger(endianness: .big, as: Int16.self) else {
throw PSQLPartialDecodingError.expectedAtLeastNRemainingBytes(2, actual: buffer.readableBytes)
}
var columnFormatCodes: [Format] = []
columnFormatCodes.reserveCapacity(Int(numColumns))

for _ in 0..<numColumns {
guard let rawColumnFormat = buffer.readInteger(endianness: .big, as: Int16.self) else {
throw PSQLPartialDecodingError.expectedAtLeastNRemainingBytes(2, actual: buffer.readableBytes)
}
guard Int8.min <= rawColumnFormat, rawColumnFormat <= Int8.max, let columnFormat = Format(rawValue: Int8(rawColumnFormat)) else {
throw PSQLPartialDecodingError.unexpectedValue(value: rawColumnFormat)
}
columnFormatCodes.append(columnFormat)
}

return CopyInResponse(format: format, columnFormats: columnFormatCodes)
}
}
}

extension PostgresBackendMessage.CopyInResponse: CustomDebugStringConvertible {
var debugDescription: String {
"format: \(format), columnFormats: \(columnFormats)"
}
}
12 changes: 9 additions & 3 deletions Sources/PostgresNIO/New/PostgresBackendMessage.swift
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ enum PostgresBackendMessage: Hashable {
case bindComplete
case closeComplete
case commandComplete(String)
case copyInResponse(CopyInResponse)
case dataRow(DataRow)
case emptyQueryResponse
case error(ErrorResponse)
Expand Down Expand Up @@ -96,6 +97,9 @@ extension PostgresBackendMessage {
}
return .commandComplete(commandTag)

case .copyInResponse:
return try .copyInResponse(.decode(from: &buffer))

case .dataRow:
return try .dataRow(.decode(from: &buffer))

Expand Down Expand Up @@ -131,9 +135,9 @@ extension PostgresBackendMessage {

case .rowDescription:
return try .rowDescription(.decode(from: &buffer))
case .copyData, .copyDone, .copyInResponse, .copyOutResponse, .copyBothResponse, .functionCallResponse, .negotiateProtocolVersion:
preconditionFailure()

case .copyData, .copyDone, .copyOutResponse, .copyBothResponse, .functionCallResponse, .negotiateProtocolVersion:
throw PSQLPartialDecodingError.unknownMessageKind(messageID)
}
}
}
Expand All @@ -151,6 +155,8 @@ extension PostgresBackendMessage: CustomDebugStringConvertible {
return ".closeComplete"
case .commandComplete(let commandTag):
return ".commandComplete(\(String(reflecting: commandTag)))"
case .copyInResponse(let copyInResponse):
return ".copyInResponse(\(String(reflecting: copyInResponse)))"
case .dataRow(let dataRow):
return ".dataRow(\(String(reflecting: dataRow)))"
case .emptyQueryResponse:
Expand Down
6 changes: 6 additions & 0 deletions Sources/PostgresNIO/New/PostgresBackendMessageDecoder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,12 @@ struct PSQLPartialDecodingError: Error {
description: "Expected the integer to be positive or null, but got \(actual).",
file: file, line: line)
}

static func unknownMessageKind(_ messageID: PostgresBackendMessage.ID, file: String = #fileID, line: Int = #line) -> Self {
return PSQLPartialDecodingError(
description: "Unknown message kind: \(messageID)",
file: file, line: line)
}
}

extension ByteBuffer {
Expand Down
2 changes: 2 additions & 0 deletions Sources/PostgresNIO/New/PostgresChannelHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
action = self.state.closeCompletedReceived()
case .commandComplete(let commandTag):
action = self.state.commandCompletedReceived(commandTag)
case .copyInResponse(let copyInResponse):
action = self.state.copyInResponseReceived(copyInResponse)
case .dataRow(let dataRow):
action = self.state.dataRowReceived(dataRow)
case .emptyQueryResponse:
Expand Down
25 changes: 25 additions & 0 deletions Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,28 @@ struct PostgresFrontendMessageEncoder {
self.buffer.writeMultipleIntegers(UInt32(8), Self.sslRequestCode)
}

/// Adds the `CopyData` message ID and `dataLength` to the message buffer but not the actual data.
///
/// The caller of this function is expected to write the encoder's message buffer to the backend after calling this
/// function, followed by sending the actual data to the backend.
mutating func copyDataHeader(dataLength: UInt32) {
self.clearIfNeeded()
self.buffer.psqlWriteMultipleIntegers(id: .copyData, length: dataLength)
}

mutating func copyDone() {
self.clearIfNeeded()
self.buffer.psqlWriteMultipleIntegers(id: .copyDone, length: 0)
}

mutating func copyFail(message: String) {
self.clearIfNeeded()
var messageBuffer = ByteBuffer()
messageBuffer.writeNullTerminatedString(message)
self.buffer.psqlWriteMultipleIntegers(id: .copyFail, length: UInt32(messageBuffer.readableBytes))
self.buffer.writeImmutableBuffer(messageBuffer)
}

mutating func sync() {
self.clearIfNeeded()
self.buffer.psqlWriteMultipleIntegers(id: .sync, length: 0)
Expand Down Expand Up @@ -197,6 +219,9 @@ struct PostgresFrontendMessageEncoder {
private enum FrontendMessageID: UInt8, Hashable, Sendable {
case bind = 66 // B
case close = 67 // C
case copyData = 100 // d
case copyDone = 99 // c
case copyFail = 102 // f
case describe = 68 // D
case execute = 69 // E
case flush = 72 // H
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class ExtendedQueryStateMachineTests: XCTestCase {
.failQuery(promise, with: psqlError, cleanupContext: .init(action: .close, tasks: [], error: psqlError, closePromise: nil)))
}

func testExtendedQueryIsCancelledImmediatly() {
func testExtendedQueryIsCancelledImmediately() {
var state = ConnectionStateMachine.readyForQuery()

let logger = Logger.psqlTest
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ struct PSQLBackendMessageEncoder: MessageToByteEncoder {
case .commandComplete(let string):
self.encode(messageID: message.id, payload: StringPayload(string), into: &buffer)

case .copyInResponse(let copyInResponse):
self.encode(messageID: message.id, payload: copyInResponse, into: &buffer)
case .dataRow(let row):
self.encode(messageID: message.id, payload: row, into: &buffer)

Expand Down Expand Up @@ -99,6 +101,8 @@ extension PostgresBackendMessage {
return .closeComplete
case .commandComplete:
return .commandComplete
case .copyInResponse:
return .copyInResponse
case .dataRow:
return .dataRow
case .emptyQueryResponse:
Expand Down Expand Up @@ -184,6 +188,16 @@ extension PostgresBackendMessage.BackendKeyData: PSQLMessagePayloadEncodable {
}
}

extension PostgresBackendMessage.CopyInResponse: PSQLMessagePayloadEncodable {
public func encode(into buffer: inout ByteBuffer) {
buffer.writeInteger(Int8(self.format.rawValue))
buffer.writeInteger(Int16(self.columnFormats.count))
for columnFormat in columnFormats {
buffer.writeInteger(Int16(columnFormat.rawValue))
}
}
}

extension DataRow: PSQLMessagePayloadEncodable {
public func encode(into buffer: inout ByteBuffer) {
buffer.writeInteger(self.columnCount, as: Int16.self)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,18 @@ extension PostgresFrontendMessage {
)
)

case .copyData:
return .copyData(CopyData(data: buffer))

case .copyDone:
return .copyDone

case .copyFail:
guard let message = buffer.readNullTerminatedString() else {
throw PSQLPartialDecodingError.fieldNotDecodable(type: String.self)
}
return .copyFail(CopyFail(message: message))

case .close:
preconditionFailure("TODO: Unimplemented")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ enum PostgresFrontendMessage: Equatable {
let secretKey: Int32
}

struct CopyData: Hashable {
var data: ByteBuffer
}

struct CopyFail: Hashable {
var message: String
}

enum Close: Hashable {
case preparedStatement(String)
case portal(String)
Expand Down Expand Up @@ -170,6 +178,9 @@ enum PostgresFrontendMessage: Equatable {

case bind(Bind)
case cancel(Cancel)
case copyData(CopyData)
case copyDone
case copyFail(CopyFail)
case close(Close)
case describe(Describe)
case execute(Execute)
Expand All @@ -186,6 +197,9 @@ enum PostgresFrontendMessage: Equatable {
enum ID: UInt8, Equatable {

case bind
case copyData
case copyDone
case copyFail
case close
case describe
case execute
Expand All @@ -201,12 +215,18 @@ enum PostgresFrontendMessage: Equatable {
switch rawValue {
case UInt8(ascii: "B"):
self = .bind
case UInt8(ascii: "c"):
self = .copyDone
case UInt8(ascii: "C"):
self = .close
case UInt8(ascii: "d"):
self = .copyData
case UInt8(ascii: "D"):
self = .describe
case UInt8(ascii: "E"):
self = .execute
case UInt8(ascii: "f"):
self = .copyFail
case UInt8(ascii: "H"):
self = .flush
case UInt8(ascii: "P"):
Expand All @@ -230,6 +250,12 @@ enum PostgresFrontendMessage: Equatable {
switch self {
case .bind:
return UInt8(ascii: "B")
case .copyData:
return UInt8(ascii: "d")
case .copyDone:
return UInt8(ascii: "c")
case .copyFail:
return UInt8(ascii: "f")
case .close:
return UInt8(ascii: "C")
case .describe:
Expand Down Expand Up @@ -263,6 +289,12 @@ extension PostgresFrontendMessage {
return .bind
case .cancel:
preconditionFailure("Cancel messages don't have an identifier")
case .copyData:
return .copyData
case .copyDone:
return .copyDone
case .copyFail:
return .copyFail
case .close:
return .close
case .describe:
Expand Down
Loading
Loading