Skip to content

Commit e04e1c6

Browse files
committed
Add message types to support COPY operations
This adds the infrastrucutre to decode messages needed for COPY operations. It does not implement the handling support itself yet. That will be added in a follow-up PR.
1 parent c17db2f commit e04e1c6

File tree

11 files changed

+299
-5
lines changed

11 files changed

+299
-5
lines changed

Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -752,6 +752,12 @@ struct ConnectionStateMachine {
752752
return self.modify(with: action)
753753
}
754754

755+
mutating func copyInResponseReceived(
756+
_ copyInResponse: PostgresBackendMessage.CopyInResponseMessage
757+
) -> ConnectionAction {
758+
return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.copyInResponse(copyInResponse)))
759+
}
760+
755761
mutating func emptyQueryResponseReceived() -> ConnectionAction {
756762
guard case .extendedQuery(var queryState, let connectionContext) = self.state, !queryState.isComplete else {
757763
return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.emptyQueryResponse))

Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ struct ExtendedQueryStateMachine {
9191
mutating func cancel() -> Action {
9292
switch self.state {
9393
case .initialized:
94-
preconditionFailure("Start must be called immediatly after the query was created")
94+
preconditionFailure("Start must be called immediately after the query was created")
9595

9696
case .messagesSent(let queryContext),
9797
.parseCompleteReceived(let queryContext),
@@ -322,6 +322,12 @@ struct ExtendedQueryStateMachine {
322322
}
323323
}
324324

325+
mutating func copyInResponseReceived(
326+
_ copyInResponse: PostgresBackendMessage.CopyInResponseMessage
327+
) -> Action {
328+
return self.setAndFireError(.unexpectedBackendMessage(.copyInResponse(copyInResponse)))
329+
}
330+
325331
mutating func emptyQueryResponseReceived() -> Action {
326332
guard case .bindCompleteReceived(let queryContext) = self.state else {
327333
return self.setAndFireError(.unexpectedBackendMessage(.emptyQueryResponse))
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
extension PostgresBackendMessage {
2+
struct CopyInResponseMessage: Hashable {
3+
enum Format: Int {
4+
case textual = 0
5+
case binary = 1
6+
}
7+
8+
let format: Format
9+
let columnFormats: [Format]
10+
11+
static func decode(from buffer: inout ByteBuffer) throws -> Self {
12+
guard let rawFormat = buffer.readInteger(endianness: .big, as: Int8.self) else {
13+
throw PSQLPartialDecodingError.expectedAtLeastNRemainingBytes(1, actual: buffer.readableBytes)
14+
}
15+
guard let format = Format(rawValue: Int(rawFormat)) else {
16+
throw PSQLPartialDecodingError.unexpectedValue(value: rawFormat)
17+
}
18+
19+
guard let numColumns = buffer.readInteger(endianness: .big, as: Int16.self) else {
20+
throw PSQLPartialDecodingError.expectedAtLeastNRemainingBytes(2, actual: buffer.readableBytes)
21+
}
22+
var columnFormatCodes: [Format] = []
23+
columnFormatCodes.reserveCapacity(Int(numColumns))
24+
25+
for _ in 0..<numColumns {
26+
guard let rawColumnFormat = buffer.readInteger(endianness: .big, as: Int16.self) else {
27+
throw PSQLPartialDecodingError.expectedAtLeastNRemainingBytes(2, actual: buffer.readableBytes)
28+
}
29+
guard let columnFormat = Format(rawValue: Int(rawColumnFormat)) else {
30+
throw PSQLPartialDecodingError.unexpectedValue(value: rawColumnFormat)
31+
}
32+
columnFormatCodes.append(columnFormat)
33+
}
34+
35+
return CopyInResponseMessage(format: format, columnFormats: columnFormatCodes)
36+
}
37+
}
38+
}
39+
40+
extension PostgresBackendMessage.CopyInResponseMessage: CustomDebugStringConvertible {
41+
var debugDescription: String {
42+
"format: \(format), columnFormats: \(columnFormats)"
43+
}
44+
}

Sources/PostgresNIO/New/PostgresBackendMessage.swift

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ enum PostgresBackendMessage: Hashable {
2929
case bindComplete
3030
case closeComplete
3131
case commandComplete(String)
32+
case copyInResponse(CopyInResponseMessage)
3233
case dataRow(DataRow)
3334
case emptyQueryResponse
3435
case error(ErrorResponse)
@@ -96,6 +97,9 @@ extension PostgresBackendMessage {
9697
}
9798
return .commandComplete(commandTag)
9899

100+
case .copyInResponse:
101+
return try .copyInResponse(.decode(from: &buffer))
102+
99103
case .dataRow:
100104
return try .dataRow(.decode(from: &buffer))
101105

@@ -131,9 +135,9 @@ extension PostgresBackendMessage {
131135

132136
case .rowDescription:
133137
return try .rowDescription(.decode(from: &buffer))
134-
135-
case .copyData, .copyDone, .copyInResponse, .copyOutResponse, .copyBothResponse, .functionCallResponse, .negotiateProtocolVersion:
136-
preconditionFailure()
138+
139+
case .copyData, .copyDone, .copyOutResponse, .copyBothResponse, .functionCallResponse, .negotiateProtocolVersion:
140+
preconditionFailure("Unknown message kind: \(messageID)")
137141
}
138142
}
139143
}
@@ -151,6 +155,8 @@ extension PostgresBackendMessage: CustomDebugStringConvertible {
151155
return ".closeComplete"
152156
case .commandComplete(let commandTag):
153157
return ".commandComplete(\(String(reflecting: commandTag)))"
158+
case .copyInResponse(let copyInResponse):
159+
return ".copyInResponse(\(String(reflecting: copyInResponse)))"
154160
case .dataRow(let dataRow):
155161
return ".dataRow(\(String(reflecting: dataRow)))"
156162
case .emptyQueryResponse:

Sources/PostgresNIO/New/PostgresChannelHandler.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,8 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
136136
action = self.state.closeCompletedReceived()
137137
case .commandComplete(let commandTag):
138138
action = self.state.commandCompletedReceived(commandTag)
139+
case .copyInResponse(let copyInResponse):
140+
action = self.state.copyInResponseReceived(copyInResponse)
139141
case .dataRow(let dataRow):
140142
action = self.state.dataRowReceived(dataRow)
141143
case .emptyQueryResponse:

Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,28 @@ struct PostgresFrontendMessageEncoder {
167167
self.buffer.writeMultipleIntegers(UInt32(8), Self.sslRequestCode)
168168
}
169169

170+
/// Adds the `CopyData` message ID and `dataLength` to the message buffer but not the actual data.
171+
///
172+
/// The caller of this function is expected to write the encoder's message buffer to the backend after calling this
173+
/// function, followed by sending the actual data to the backend.
174+
mutating func copyDataHeader(dataLength: UInt32) {
175+
self.clearIfNeeded()
176+
self.buffer.psqlWriteMultipleIntegers(id: .copyData, length: dataLength)
177+
}
178+
179+
mutating func copyDone() {
180+
self.clearIfNeeded()
181+
self.buffer.psqlWriteMultipleIntegers(id: .copyDone, length: 0)
182+
}
183+
184+
mutating func copyFail(message: String) {
185+
self.clearIfNeeded()
186+
var messageBuffer = ByteBuffer()
187+
messageBuffer.writeNullTerminatedString(message)
188+
self.buffer.psqlWriteMultipleIntegers(id: .copyFail, length: UInt32(messageBuffer.readableBytes))
189+
self.buffer.writeImmutableBuffer(messageBuffer)
190+
}
191+
170192
mutating func sync() {
171193
self.clearIfNeeded()
172194
self.buffer.psqlWriteMultipleIntegers(id: .sync, length: 0)
@@ -197,6 +219,9 @@ struct PostgresFrontendMessageEncoder {
197219
private enum FrontendMessageID: UInt8, Hashable, Sendable {
198220
case bind = 66 // B
199221
case close = 67 // C
222+
case copyData = 100 // d
223+
case copyDone = 99 // c
224+
case copyFail = 102 // f
200225
case describe = 68 // D
201226
case execute = 69 // E
202227
case flush = 72 // H

Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ class ExtendedQueryStateMachineTests: XCTestCase {
114114
.failQuery(promise, with: psqlError, cleanupContext: .init(action: .close, tasks: [], error: psqlError, closePromise: nil)))
115115
}
116116

117-
func testExtendedQueryIsCancelledImmediatly() {
117+
func testExtendedQueryIsCancelledImmediately() {
118118
var state = ConnectionStateMachine.readyForQuery()
119119

120120
let logger = Logger.psqlTest

Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ struct PSQLBackendMessageEncoder: MessageToByteEncoder {
2828
case .commandComplete(let string):
2929
self.encode(messageID: message.id, payload: StringPayload(string), into: &buffer)
3030

31+
case .copyInResponse(let copyInResponse):
32+
self.encode(messageID: message.id, payload: copyInResponse, into: &buffer)
3133
case .dataRow(let row):
3234
self.encode(messageID: message.id, payload: row, into: &buffer)
3335

@@ -99,6 +101,8 @@ extension PostgresBackendMessage {
99101
return .closeComplete
100102
case .commandComplete:
101103
return .commandComplete
104+
case .copyInResponse:
105+
return .copyInResponse
102106
case .dataRow:
103107
return .dataRow
104108
case .emptyQueryResponse:
@@ -184,6 +188,16 @@ extension PostgresBackendMessage.BackendKeyData: PSQLMessagePayloadEncodable {
184188
}
185189
}
186190

191+
extension PostgresBackendMessage.CopyInResponseMessage: PSQLMessagePayloadEncodable {
192+
public func encode(into buffer: inout ByteBuffer) {
193+
buffer.writeInteger(Int8(self.format.rawValue))
194+
buffer.writeInteger(Int16(self.columnFormats.count))
195+
for columnFormat in columnFormats {
196+
buffer.writeInteger(Int16(columnFormat.rawValue))
197+
}
198+
}
199+
}
200+
187201
extension DataRow: PSQLMessagePayloadEncodable {
188202
public func encode(into buffer: inout ByteBuffer) {
189203
buffer.writeInteger(self.columnCount, as: Int16.self)

Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,18 @@ extension PostgresFrontendMessage {
168168
)
169169
)
170170

171+
case .copyData:
172+
return .copyData(CopyData(data: buffer))
173+
174+
case .copyDone:
175+
return .copyDone
176+
177+
case .copyFail:
178+
guard let message = buffer.readNullTerminatedString() else {
179+
throw PSQLPartialDecodingError.fieldNotDecodable(type: String.self)
180+
}
181+
return .copyFail(CopyFail(message: message))
182+
171183
case .close:
172184
preconditionFailure("TODO: Unimplemented")
173185

Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,14 @@ enum PostgresFrontendMessage: Equatable {
3636
let secretKey: Int32
3737
}
3838

39+
struct CopyData: Equatable {
40+
let data: ByteBuffer
41+
}
42+
43+
struct CopyFail: Equatable {
44+
let message: String
45+
}
46+
3947
enum Close: Hashable {
4048
case preparedStatement(String)
4149
case portal(String)
@@ -170,6 +178,9 @@ enum PostgresFrontendMessage: Equatable {
170178

171179
case bind(Bind)
172180
case cancel(Cancel)
181+
case copyData(CopyData)
182+
case copyDone
183+
case copyFail(CopyFail)
173184
case close(Close)
174185
case describe(Describe)
175186
case execute(Execute)
@@ -186,6 +197,9 @@ enum PostgresFrontendMessage: Equatable {
186197
enum ID: UInt8, Equatable {
187198

188199
case bind
200+
case copyData
201+
case copyDone
202+
case copyFail
189203
case close
190204
case describe
191205
case execute
@@ -201,12 +215,18 @@ enum PostgresFrontendMessage: Equatable {
201215
switch rawValue {
202216
case UInt8(ascii: "B"):
203217
self = .bind
218+
case UInt8(ascii: "c"):
219+
self = .copyDone
204220
case UInt8(ascii: "C"):
205221
self = .close
222+
case UInt8(ascii: "d"):
223+
self = .copyData
206224
case UInt8(ascii: "D"):
207225
self = .describe
208226
case UInt8(ascii: "E"):
209227
self = .execute
228+
case UInt8(ascii: "f"):
229+
self = .copyFail
210230
case UInt8(ascii: "H"):
211231
self = .flush
212232
case UInt8(ascii: "P"):
@@ -230,6 +250,12 @@ enum PostgresFrontendMessage: Equatable {
230250
switch self {
231251
case .bind:
232252
return UInt8(ascii: "B")
253+
case .copyData:
254+
return UInt8(ascii: "d")
255+
case .copyDone:
256+
return UInt8(ascii: "c")
257+
case .copyFail:
258+
return UInt8(ascii: "f")
233259
case .close:
234260
return UInt8(ascii: "C")
235261
case .describe:
@@ -263,6 +289,12 @@ extension PostgresFrontendMessage {
263289
return .bind
264290
case .cancel:
265291
preconditionFailure("Cancel messages don't have an identifier")
292+
case .copyData:
293+
return .copyData
294+
case .copyDone:
295+
return .copyDone
296+
case .copyFail:
297+
return .copyFail
266298
case .close:
267299
return .close
268300
case .describe:

0 commit comments

Comments
 (0)