Skip to content

Commit 76034a9

Browse files
committed
Add message types to support COPY operations
This adds the infrastrucutre to decode messages needed for COPY operations. It does not implement the handling support itself yet. That will be added in a follow-up PR.
1 parent d50aade commit 76034a9

12 files changed

+305
-5
lines changed

Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -752,6 +752,12 @@ struct ConnectionStateMachine {
752752
return self.modify(with: action)
753753
}
754754

755+
mutating func copyInResponseReceived(
756+
_ copyInResponse: PostgresBackendMessage.CopyInResponse
757+
) -> ConnectionAction {
758+
return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.copyInResponse(copyInResponse)))
759+
}
760+
755761
mutating func emptyQueryResponseReceived() -> ConnectionAction {
756762
guard case .extendedQuery(var queryState, let connectionContext) = self.state, !queryState.isComplete else {
757763
return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.emptyQueryResponse))

Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ struct ExtendedQueryStateMachine {
9191
mutating func cancel() -> Action {
9292
switch self.state {
9393
case .initialized:
94-
preconditionFailure("Start must be called immediatly after the query was created")
94+
preconditionFailure("Start must be called immediately after the query was created")
9595

9696
case .messagesSent(let queryContext),
9797
.parseCompleteReceived(let queryContext),
@@ -322,6 +322,12 @@ struct ExtendedQueryStateMachine {
322322
}
323323
}
324324

325+
mutating func copyInResponseReceived(
326+
_ copyInResponse: PostgresBackendMessage.CopyInResponse
327+
) -> Action {
328+
return self.setAndFireError(.unexpectedBackendMessage(.copyInResponse(copyInResponse)))
329+
}
330+
325331
mutating func emptyQueryResponseReceived() -> Action {
326332
guard case .bindCompleteReceived(let queryContext) = self.state else {
327333
return self.setAndFireError(.unexpectedBackendMessage(.emptyQueryResponse))
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
extension PostgresBackendMessage {
2+
struct CopyInResponse: Hashable {
3+
enum Format: Int8 {
4+
case textual = 0
5+
case binary = 1
6+
}
7+
8+
let format: Format
9+
let columnFormats: [Format]
10+
11+
static func decode(from buffer: inout ByteBuffer) throws -> Self {
12+
guard let rawFormat = buffer.readInteger(endianness: .big, as: Int8.self) else {
13+
throw PSQLPartialDecodingError.expectedAtLeastNRemainingBytes(1, actual: buffer.readableBytes)
14+
}
15+
guard let format = Format(rawValue: rawFormat) else {
16+
throw PSQLPartialDecodingError.unexpectedValue(value: rawFormat)
17+
}
18+
19+
guard let numColumns = buffer.readInteger(endianness: .big, as: Int16.self) else {
20+
throw PSQLPartialDecodingError.expectedAtLeastNRemainingBytes(2, actual: buffer.readableBytes)
21+
}
22+
var columnFormatCodes: [Format] = []
23+
columnFormatCodes.reserveCapacity(Int(numColumns))
24+
25+
for _ in 0..<numColumns {
26+
guard let rawColumnFormat = buffer.readInteger(endianness: .big, as: Int16.self) else {
27+
throw PSQLPartialDecodingError.expectedAtLeastNRemainingBytes(2, actual: buffer.readableBytes)
28+
}
29+
guard Int8.min <= rawColumnFormat, rawColumnFormat <= Int8.max, let columnFormat = Format(rawValue: Int8(rawColumnFormat)) else {
30+
throw PSQLPartialDecodingError.unexpectedValue(value: rawColumnFormat)
31+
}
32+
columnFormatCodes.append(columnFormat)
33+
}
34+
35+
return CopyInResponse(format: format, columnFormats: columnFormatCodes)
36+
}
37+
}
38+
}
39+
40+
extension PostgresBackendMessage.CopyInResponse: CustomDebugStringConvertible {
41+
var debugDescription: String {
42+
"format: \(format), columnFormats: \(columnFormats)"
43+
}
44+
}

Sources/PostgresNIO/New/PostgresBackendMessage.swift

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ enum PostgresBackendMessage: Hashable {
2929
case bindComplete
3030
case closeComplete
3131
case commandComplete(String)
32+
case copyInResponse(CopyInResponse)
3233
case dataRow(DataRow)
3334
case emptyQueryResponse
3435
case error(ErrorResponse)
@@ -96,6 +97,9 @@ extension PostgresBackendMessage {
9697
}
9798
return .commandComplete(commandTag)
9899

100+
case .copyInResponse:
101+
return try .copyInResponse(.decode(from: &buffer))
102+
99103
case .dataRow:
100104
return try .dataRow(.decode(from: &buffer))
101105

@@ -131,9 +135,9 @@ extension PostgresBackendMessage {
131135

132136
case .rowDescription:
133137
return try .rowDescription(.decode(from: &buffer))
134-
135-
case .copyData, .copyDone, .copyInResponse, .copyOutResponse, .copyBothResponse, .functionCallResponse, .negotiateProtocolVersion:
136-
preconditionFailure()
138+
139+
case .copyData, .copyDone, .copyOutResponse, .copyBothResponse, .functionCallResponse, .negotiateProtocolVersion:
140+
throw PSQLPartialDecodingError.unknownMessageKind(messageID)
137141
}
138142
}
139143
}
@@ -151,6 +155,8 @@ extension PostgresBackendMessage: CustomDebugStringConvertible {
151155
return ".closeComplete"
152156
case .commandComplete(let commandTag):
153157
return ".commandComplete(\(String(reflecting: commandTag)))"
158+
case .copyInResponse(let copyInResponse):
159+
return ".copyInResponse(\(String(reflecting: copyInResponse)))"
154160
case .dataRow(let dataRow):
155161
return ".dataRow(\(String(reflecting: dataRow)))"
156162
case .emptyQueryResponse:

Sources/PostgresNIO/New/PostgresBackendMessageDecoder.swift

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,12 @@ struct PSQLPartialDecodingError: Error {
189189
description: "Expected the integer to be positive or null, but got \(actual).",
190190
file: file, line: line)
191191
}
192+
193+
static func unknownMessageKind(_ messageID: PostgresBackendMessage.ID, file: String = #fileID, line: Int = #line) -> Self {
194+
return PSQLPartialDecodingError(
195+
description: "Unknown message kind: \(messageID)",
196+
file: file, line: line)
197+
}
192198
}
193199

194200
extension ByteBuffer {

Sources/PostgresNIO/New/PostgresChannelHandler.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,8 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
136136
action = self.state.closeCompletedReceived()
137137
case .commandComplete(let commandTag):
138138
action = self.state.commandCompletedReceived(commandTag)
139+
case .copyInResponse(let copyInResponse):
140+
action = self.state.copyInResponseReceived(copyInResponse)
139141
case .dataRow(let dataRow):
140142
action = self.state.dataRowReceived(dataRow)
141143
case .emptyQueryResponse:

Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,28 @@ struct PostgresFrontendMessageEncoder {
167167
self.buffer.writeMultipleIntegers(UInt32(8), Self.sslRequestCode)
168168
}
169169

170+
/// Adds the `CopyData` message ID and `dataLength` to the message buffer but not the actual data.
171+
///
172+
/// The caller of this function is expected to write the encoder's message buffer to the backend after calling this
173+
/// function, followed by sending the actual data to the backend.
174+
mutating func copyDataHeader(dataLength: UInt32) {
175+
self.clearIfNeeded()
176+
self.buffer.psqlWriteMultipleIntegers(id: .copyData, length: dataLength)
177+
}
178+
179+
mutating func copyDone() {
180+
self.clearIfNeeded()
181+
self.buffer.psqlWriteMultipleIntegers(id: .copyDone, length: 0)
182+
}
183+
184+
mutating func copyFail(message: String) {
185+
self.clearIfNeeded()
186+
var messageBuffer = ByteBuffer()
187+
messageBuffer.writeNullTerminatedString(message)
188+
self.buffer.psqlWriteMultipleIntegers(id: .copyFail, length: UInt32(messageBuffer.readableBytes))
189+
self.buffer.writeImmutableBuffer(messageBuffer)
190+
}
191+
170192
mutating func sync() {
171193
self.clearIfNeeded()
172194
self.buffer.psqlWriteMultipleIntegers(id: .sync, length: 0)
@@ -197,6 +219,9 @@ struct PostgresFrontendMessageEncoder {
197219
private enum FrontendMessageID: UInt8, Hashable, Sendable {
198220
case bind = 66 // B
199221
case close = 67 // C
222+
case copyData = 100 // d
223+
case copyDone = 99 // c
224+
case copyFail = 102 // f
200225
case describe = 68 // D
201226
case execute = 69 // E
202227
case flush = 72 // H

Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ class ExtendedQueryStateMachineTests: XCTestCase {
114114
.failQuery(promise, with: psqlError, cleanupContext: .init(action: .close, tasks: [], error: psqlError, closePromise: nil)))
115115
}
116116

117-
func testExtendedQueryIsCancelledImmediatly() {
117+
func testExtendedQueryIsCancelledImmediately() {
118118
var state = ConnectionStateMachine.readyForQuery()
119119

120120
let logger = Logger.psqlTest

Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ struct PSQLBackendMessageEncoder: MessageToByteEncoder {
2828
case .commandComplete(let string):
2929
self.encode(messageID: message.id, payload: StringPayload(string), into: &buffer)
3030

31+
case .copyInResponse(let copyInResponse):
32+
self.encode(messageID: message.id, payload: copyInResponse, into: &buffer)
3133
case .dataRow(let row):
3234
self.encode(messageID: message.id, payload: row, into: &buffer)
3335

@@ -99,6 +101,8 @@ extension PostgresBackendMessage {
99101
return .closeComplete
100102
case .commandComplete:
101103
return .commandComplete
104+
case .copyInResponse:
105+
return .copyInResponse
102106
case .dataRow:
103107
return .dataRow
104108
case .emptyQueryResponse:
@@ -184,6 +188,16 @@ extension PostgresBackendMessage.BackendKeyData: PSQLMessagePayloadEncodable {
184188
}
185189
}
186190

191+
extension PostgresBackendMessage.CopyInResponse: PSQLMessagePayloadEncodable {
192+
public func encode(into buffer: inout ByteBuffer) {
193+
buffer.writeInteger(Int8(self.format.rawValue))
194+
buffer.writeInteger(Int16(self.columnFormats.count))
195+
for columnFormat in columnFormats {
196+
buffer.writeInteger(Int16(columnFormat.rawValue))
197+
}
198+
}
199+
}
200+
187201
extension DataRow: PSQLMessagePayloadEncodable {
188202
public func encode(into buffer: inout ByteBuffer) {
189203
buffer.writeInteger(self.columnCount, as: Int16.self)

Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,18 @@ extension PostgresFrontendMessage {
168168
)
169169
)
170170

171+
case .copyData:
172+
return .copyData(CopyData(data: buffer))
173+
174+
case .copyDone:
175+
return .copyDone
176+
177+
case .copyFail:
178+
guard let message = buffer.readNullTerminatedString() else {
179+
throw PSQLPartialDecodingError.fieldNotDecodable(type: String.self)
180+
}
181+
return .copyFail(CopyFail(message: message))
182+
171183
case .close:
172184
preconditionFailure("TODO: Unimplemented")
173185

0 commit comments

Comments
 (0)