Skip to content

Commit ef9ddec

Browse files
committed
Implement COPY … FROM STDIN queries
This implements support for COPY operations using `COPY … FROM STDIN` queries for fast data transfer from the client to the backend.
1 parent 5ae3ab0 commit ef9ddec

File tree

9 files changed

+1220
-86
lines changed

9 files changed

+1220
-86
lines changed
Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
/// Handle to send data for a `COPY ... FROM STDIN` query to the backend.
///
/// All interaction with the underlying `PostgresChannelHandler` is performed on the channel's event loop. The
/// methods on this type hop onto that event loop if they are called from elsewhere.
public struct PostgresCopyFromWriter: Sendable {
    /// The backend failed the copy data transfer, which means that no more data sent by the frontend would be processed.
    ///
    /// The `PostgresCopyFromWriter` should cancel the data transfer.
    public struct CopyCancellationError: Error {
        /// The error that the backend sent us which cancelled the data transfer.
        ///
        /// Note that this error is related to previous `write` calls since a `CopyCancellationError` is thrown before
        /// new data is written by `write`.
        public let underlyingError: PSQLError
    }

    /// The channel handler that performs the actual writes. Only accessed on `eventLoop`.
    private let channelHandler: NIOLoopBound<PostgresChannelHandler>

    /// The event loop on which `channelHandler` must be accessed.
    private let eventLoop: any EventLoop

    /// Create a writer that sends its data through `handler`, which is bound to `eventLoop`.
    init(handler: PostgresChannelHandler, eventLoop: any EventLoop) {
        self.channelHandler = NIOLoopBound(handler, eventLoop: eventLoop)
        self.eventLoop = eventLoop
    }

    /// Wait until the backend can receive more copy data, send `byteBuffer` and then resume `continuation`.
    ///
    /// If the check whether the backend can receive data fails, no data is sent and `continuation` is resumed with
    /// the failure.
    ///
    /// - Precondition: Must be called on `eventLoop`.
    private func writeAssumingInEventLoop(_ byteBuffer: ByteBuffer, _ continuation: CheckedContinuation<Void, any Error>) {
        precondition(eventLoop.inEventLoop, "writeAssumingInEventLoop must be called on the event loop")
        let promise = eventLoop.makePromise(of: Void.self)
        self.channelHandler.value.checkBackendCanReceiveCopyData(promise: promise)
        // The promise was created by `eventLoop`, so its callbacks are always executed on `eventLoop`. No additional
        // hop is needed before accessing the loop-bound channel handler, and the data is guaranteed to be handed to
        // the channel handler before the continuation is resumed.
        promise.futureResult.whenComplete { result in
            if case .success = result {
                self.channelHandler.value.sendCopyData(byteBuffer)
            }
            continuation.resume(with: result)
        }
    }

    /// Send data for a `COPY ... FROM STDIN` operation to the backend.
    ///
    /// If the backend encountered an error during the data transfer and thus cannot process any more data, this throws
    /// a `CopyCancellationError`.
    ///
    /// - Parameter byteBuffer: The data to send to the backend.
    /// - Throws: `CancellationError` if the current task has been cancelled before the write started, or the error
    ///   with which the check whether the backend can receive more data failed.
    public func write(_ byteBuffer: ByteBuffer) async throws {
        // Check for cancellation. This is cheap and makes sure that we regularly check for cancellation in the
        // `writeData` closure. It is likely that the user would forget to do so.
        try Task.checkCancellation()

        // TODO: Buffer data in here and only send a `CopyData` message to the backend once we have accumulated a
        // predefined amount of data.
        // TODO: Listen for task cancellation while we are waiting for backpressure to clear.
        try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, any Error>) in
            if eventLoop.inEventLoop {
                writeAssumingInEventLoop(byteBuffer, continuation)
            } else {
                eventLoop.execute {
                    writeAssumingInEventLoop(byteBuffer, continuation)
                }
            }
        }
    }

    /// Finalize the data transfer, putting the state machine out of the copy mode and sending a `CopyDone` message to
    /// the backend.
    ///
    /// Returns or throws once the channel handler resumes the continuation that is handed to it.
    func done() async throws {
        try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, any Error>) in
            if eventLoop.inEventLoop {
                self.channelHandler.value.sendCopyDone(continuation: continuation)
            } else {
                eventLoop.execute {
                    self.channelHandler.value.sendCopyDone(continuation: continuation)
                }
            }
        }
    }

    /// Finalize the data transfer, putting the state machine out of the copy mode and sending a `CopyFail` message to
    /// the backend.
    ///
    /// - Parameter error: The error that caused the data transfer to fail. Its string interpolation is sent to the
    ///   backend as the failure message.
    func failed(error: any Error) async throws {
        try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, any Error>) in
            // TODO: Is it OK to use string interpolation to construct an error description to be sent to the backend
            // here? We could also use a generic description, it doesn't really matter since we throw the user's error
            // in `copyFrom`.
            if eventLoop.inEventLoop {
                self.channelHandler.value.sendCopyFail(message: "\(error)", continuation: continuation)
            } else {
                eventLoop.execute {
                    self.channelHandler.value.sendCopyFail(message: "\(error)", continuation: continuation)
                }
            }
        }
    }
}
93+
94+
/// The format in which data is transferred to the backend during a COPY operation.
public enum PostgresCopyFromFormat: Sendable {
    /// Settings that customize the `text` format of a COPY operation.
    public struct TextOptions: Sendable {
        /// The character that separates columns in the data.
        ///
        /// See the `DELIMITER` option in Postgres's `COPY` command.
        ///
        /// When `nil`, no delimiter is specified in the query, so the default delimiter of the format is used.
        public var delimiter: UnicodeScalar?

        /// Create text options that use the defaults of the format.
        public init() {}
    }

    /// Transfer the data in the `text` format, customized by the given options.
    case text(TextOptions)
}
110+
111+
/// Create a `COPY ... FROM STDIN` query based on the given parameters.
///
/// An empty `columns` array signifies that no columns should be specified in the query and that all columns will be
/// copied by the caller.
///
/// - Parameters:
///   - table: The name of the table to copy into. Inserted into the query verbatim.
///   - columns: The names of the columns to copy. Inserted into the query verbatim.
///   - format: The format of the data to transfer, including its format-specific options.
/// - Returns: The query that puts the backend into copy mode.
private func buildCopyFromQuery(
    table: StaticString,
    columns: [StaticString] = [],
    format: PostgresCopyFromFormat
) -> PostgresQuery {
    // TODO: Should we put the table and column names in quotes to make them case-sensitive?
    var query = "COPY \(table)"
    if !columns.isEmpty {
        query += "(" + columns.map(\.description).joined(separator: ",") + ")"
    }
    query += " FROM STDIN"
    var queryOptions: [String] = []
    switch format {
    case .text(let options):
        queryOptions.append("FORMAT text")
        if let delimiter = options.delimiter {
            // Set the delimiter as a Unicode code point. This avoids the possibility of SQL injection.
            queryOptions.append("DELIMITER U&'\(postgresUnicodeEscape(for: delimiter))'")
        }
    }
    precondition(!queryOptions.isEmpty)
    query += " WITH (" + queryOptions.joined(separator: ",") + ")"
    return "\(unescaped: query)"
}

/// Build the escape sequence that represents `scalar` inside a Postgres `U&'...'` string constant.
///
/// Code points up to `U+FFFF` use the four-digit form `\XXXX`. Larger code points need the six-digit form
/// `\+XXXXXX`, otherwise Postgres would interpret only the first four hex digits as the code point and treat the
/// remaining digits as literal characters.
private func postgresUnicodeEscape(for scalar: UnicodeScalar) -> String {
    let hexDigits = String(scalar.value, radix: 16)
    let isInBasicMultilingualPlane = scalar.value <= 0xFFFF
    let prefix = isInBasicMultilingualPlane ? "\\" : "\\+"
    let digitCount = isInBasicMultilingualPlane ? 4 : 6
    // A Unicode scalar is at most 0x10FFFF, so `hexDigits` never exceeds `digitCount` in either branch.
    let padding = String(repeating: "0", count: digitCount - hexDigits.count)
    return prefix + padding + hexDigits
}
141+
142+
extension PostgresConnection {
    /// Stream data into a table by issuing a `COPY <table name> FROM STDIN` query.
    ///
    /// - Parameters:
    ///   - table: The name of the table into which to copy the data.
    ///   - columns: The name of the columns to copy. If an empty array is passed, all columns are assumed to be copied.
    ///   - format: Options that specify the format of the data that is produced by `writeData`.
    ///   - logger: The logger used for the query. The connection ID is added to its metadata.
    ///   - writeData: Closure that produces the data for the table, to be streamed to the backend. Call `write` on the
    ///     writer provided by the closure to send data to the backend and return from the closure once all data is sent.
    ///     Throw an error from the closure to fail the data transfer. The error thrown by the closure will be rethrown
    ///     by the `copyFrom` function.
    ///
    /// - Note: The table and column names are inserted into the SQL query verbatim. They are forced to be compile-time
    ///   specified to avoid runtime SQL injection attacks.
    public func copyFrom(
        table: StaticString,
        columns: [StaticString] = [],
        format: PostgresCopyFromFormat = .text(.init()),
        logger: Logger,
        file: String = #fileID,
        line: Int = #line,
        writeData: @escaping @Sendable (PostgresCopyFromWriter) async throws -> Void
    ) async throws {
        var logger = logger
        logger[postgresMetadataKey: .connectionID] = "\(self.id)"

        // Put the backend into copy mode. The continuation is resumed with the writer once the data transfer may
        // start.
        let copyQuery = buildCopyFromQuery(table: table, columns: columns, format: format)
        let writer = try await withCheckedThrowingContinuation {
            (triggerCopy: CheckedContinuation<PostgresCopyFromWriter, any Error>) in
            let queryContext = ExtendedQueryContext(
                copyFromQuery: copyQuery,
                triggerCopy: triggerCopy,
                logger: logger
            )
            self.channel.write(HandlerTask.extendedQuery(queryContext), promise: nil)
        }

        do {
            try await writeData(writer)
        } catch {
            // A `CopyFail` must be sent to take the backend out of copy mode. Sending it will most likely throw, but
            // in both of the expected cases the error from `writeData` is the better one to surface, so the error
            // thrown by `writer.failed` is ignored:
            //  - The backend answers the `CopyFail` with an `ErrorResponse` that merely relays our own message. The
            //    backend has left copy mode and the user's original error is more informative than the relayed one.
            //  - The backend already sent an `ErrorResponse` during the copy, eg. because of an invalid format, which
            //    put the `ExtendedQueryStateMachine` into the error state. Sending `CopyFail` then throws but
            //    triggers a `Sync` that takes the backend out of copy mode. If `writeData` threw the
            //    `CopyCancellationError` from `PostgresCopyFromWriter.write`, `writer.failed` throws the same error,
            //    so nothing is lost. If the user threw a different error, that error should be honored.
            try? await writer.failed(error: error)

            guard let cancellation = error as? PostgresCopyFromWriter.CopyCancellationError else {
                throw error
            }
            // A `CopyCancellationError` almost certainly originates from `PostgresCopyFromWriter.write` - the only
            // alternative is that the user stored an error from an earlier writer, which is very unlikely. Surface
            // the underlying error since it carries the message sent by the backend and is the most actionable.
            throw cancellation.underlyingError
        }

        // `writer.done` may fail, eg. because the backend sends an error response after receiving `CopyDone` or during
        // the transfer of the last bit of data so that the user didn't call `PostgresCopyFromWriter.write` again, which
        // would have checked the error state. In either of these cases, calling `writer.done` puts the backend out of
        // copy mode, so we don't need to send another `CopyFail`. Thus, this must not be handled in the `do` block
        // above.
        try await writer.done()
    }
}

Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift

Lines changed: 110 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,27 @@ struct ConnectionStateMachine {
8888
case sendParseDescribeBindExecuteSync(PostgresQuery)
8989
case sendBindExecuteSync(PSQLExecuteStatement)
9090
case failQuery(EventLoopPromise<PSQLRowStream>, with: PSQLError, cleanupContext: CleanUpContext?)
91+
/// Fail a query's execution by resuming the continuation with the given error. When `sync` is `true`, send a
92+
/// `Sync` message to the backend.
93+
case failQueryContinuation(AnyErrorContinuation, with: PSQLError, sync: Bool, cleanupContext: CleanUpContext?)
94+
/// Succeed a query's execution by completing the promise with the given
95+
/// query result.
9196
case succeedQuery(EventLoopPromise<PSQLRowStream>, with: QueryResult)
97+
/// Succeed the continuation with a void result. When `sync` is `true`, send a `Sync` message to the backend.
98+
case succeedQueryContinuation(CheckedContinuation<Void, any Error>, sync: Bool)
99+
100+
/// Trigger a data transfer returning a `PostgresCopyFromWriter` to the given continuation.
101+
///
102+
/// Once the data transfer is triggered, it will send `CopyData` messages to the backend. After that the state
103+
/// machine needs to be prodded again to send a `CopyDone` or `CopyFail` by calling
104+
/// `PostgresChannelHandler.sendCopyDone` or `PostgresChannelHandler.sendCopyFail`.
105+
case triggerCopyData(CheckedContinuation<PostgresCopyFromWriter, any Error>)
106+
107+
/// Send a `CopyDone` and `Sync` message to the backend.
108+
case sendCopyDoneAndSync
109+
110+
/// Send a `CopyFail` message to the backend with the given error message.
111+
case sendCopyFail(message: String)
92112

93113
// --- streaming actions
94114
// actions if query has requested next row but we are waiting for backend
@@ -107,6 +127,14 @@ struct ConnectionStateMachine {
107127
case failClose(CloseCommandContext, with: PSQLError, cleanupContext: CleanUpContext?)
108128
}
109129

130+
/// The action to perform after the writability of the channel has changed.
enum ChannelWritabilityChangedAction {
131+
/// No action needs to be taken based on the writability change.
132+
case none
133+
134+
/// Succeed the given promise.
135+
case succeedPromise(EventLoopPromise<Void>)
136+
}
137+
110138
private var state: State
111139
private let requireBackendKeyData: Bool
112140
private var taskQueue = CircularBuffer<PSQLTask>()
@@ -587,6 +615,8 @@ struct ConnectionStateMachine {
587615
switch queryContext.query {
588616
case .executeStatement(_, let promise), .unnamed(_, let promise):
589617
return .failQuery(promise, with: psqlErrror, cleanupContext: nil)
618+
case .copyFrom(_, let triggerCopy):
619+
return .failQueryContinuation(.copyFromWriter(triggerCopy), with: psqlErrror, sync: false, cleanupContext: nil)
590620
case .prepareStatement(_, _, _, let promise):
591621
return .failPreparedStatementCreation(promise, with: psqlErrror, cleanupContext: nil)
592622
}
@@ -660,6 +690,16 @@ struct ConnectionStateMachine {
660690
preconditionFailure("Invalid state: \(self.state)")
661691
}
662692
}
693+
694+
/// Forward a change of the channel's writability to the extended query state machine.
///
/// Returns `.none` when the connection is not currently in the `.extendedQuery` state.
mutating func channelWritabilityChanged(isWritable: Bool) -> ChannelWritabilityChangedAction {
    if case .extendedQuery(var extendedQuery, let context) = self.state {
        self.state = .modifying // avoid CoW
        let writabilityAction = extendedQuery.channelWritabilityChanged(isWritable: isWritable)
        self.state = .extendedQuery(extendedQuery, context)
        return writabilityAction
    }
    return .none
}
663703

664704
// MARK: - Running Queries -
665705

@@ -752,10 +792,55 @@ struct ConnectionStateMachine {
752792
return self.modify(with: action)
753793
}
754794

755-
mutating func copyInResponseReceived(
756-
_ copyInResponse: PostgresBackendMessage.CopyInResponse
757-
) -> ConnectionAction {
758-
return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.copyInResponse(copyInResponse)))
795+
/// Handle a `CopyInResponse` message from the backend.
///
/// The message is forwarded to the extended query state machine if an extended query is running that has not
/// completed yet. In any other state the message is unexpected and the connection is closed.
mutating func copyInResponseReceived(_ copyInResponse: PostgresBackendMessage.CopyInResponse) -> ConnectionAction {
    if case .extendedQuery(var extendedQuery, let context) = self.state, !extendedQuery.isComplete {
        self.state = .modifying // avoid CoW
        let queryAction = extendedQuery.copyInResponseReceived(copyInResponse)
        self.state = .extendedQuery(extendedQuery, context)
        return self.modify(with: queryAction)
    }
    return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.copyInResponse(copyInResponse)))
}
805+
806+
807+
/// Complete `promise` successfully once the channel to the backend is writable and the backend is ready to receive
/// more data.
///
/// The promise may instead be failed if the backend signalled with an `ErrorResponse` that it cannot handle any
/// more data, which mostly happens when malformed data was sent to it. The data transfer should be aborted in
/// that case to avoid unnecessary work.
///
/// - Precondition: The connection must be in the `.extendedQuery` state.
mutating func checkBackendCanReceiveCopyData(channelIsWritable: Bool, promise: EventLoopPromise<Void>) {
    if case .extendedQuery(var extendedQuery, let context) = self.state {
        self.state = .modifying // avoid CoW
        extendedQuery.checkBackendCanReceiveCopyData(channelIsWritable: channelIsWritable, promise: promise)
        self.state = .extendedQuery(extendedQuery, context)
        return
    }
    preconditionFailure("Copy mode is only supported for extended queries")
}
821+
822+
/// Leave the copying mode and send a `CopyDone` message to the backend.
///
/// - Precondition: The connection must be in the `.extendedQuery` state.
mutating func sendCopyDone(continuation: CheckedContinuation<Void, any Error>) -> ConnectionAction {
    if case .extendedQuery(var extendedQuery, let context) = self.state {
        self.state = .modifying // avoid CoW
        let queryAction = extendedQuery.sendCopyDone(continuation: continuation)
        self.state = .extendedQuery(extendedQuery, context)
        return self.modify(with: queryAction)
    }
    preconditionFailure("Copy mode is only supported for extended queries")
}
833+
834+
/// Leave the copying mode and send a `CopyFail` message carrying `message` to the backend.
///
/// - Precondition: The connection must be in the `.extendedQuery` state.
mutating func sendCopyFail(message: String, continuation: CheckedContinuation<Void, any Error>) -> ConnectionAction {
    if case .extendedQuery(var extendedQuery, let context) = self.state {
        self.state = .modifying // avoid CoW
        let queryAction = extendedQuery.sendCopyFail(message: message, continuation: continuation)
        self.state = .extendedQuery(extendedQuery, context)
        return self.modify(with: queryAction)
    }
    preconditionFailure("Copy mode is only supported for extended queries")
}
760845

761846
mutating func emptyQueryResponseReceived() -> ConnectionAction {
@@ -866,14 +951,21 @@ struct ConnectionStateMachine {
866951
.forwardRows,
867952
.forwardStreamComplete,
868953
.wait,
869-
.read:
954+
.read,
955+
.triggerCopyData,
956+
.sendCopyDoneAndSync,
957+
.sendCopyFail,
958+
.succeedQueryContinuation:
870959
preconditionFailure("Invalid query state machine action in state: \(self.state), action: \(action)")
871960

872961
case .evaluateErrorAtConnectionLevel:
873962
return .closeConnectionAndCleanup(cleanupContext)
874963

875-
case .failQuery(let queryContext, with: let error):
876-
return .failQuery(queryContext, with: error, cleanupContext: cleanupContext)
964+
case .failQuery(let promise, with: let error):
965+
return .failQuery(promise, with: error, cleanupContext: cleanupContext)
966+
967+
case .failQueryContinuation(let continuation, with: let error, let sync):
968+
return .failQueryContinuation(continuation, with: error, sync: sync, cleanupContext: cleanupContext)
877969

878970
case .forwardStreamError(let error, let read):
879971
return .forwardStreamError(error, read: read, cleanupContext: cleanupContext)
@@ -1044,8 +1136,19 @@ extension ConnectionStateMachine {
10441136
case .failQuery(let requestContext, with: let error):
10451137
let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error)
10461138
return .failQuery(requestContext, with: error, cleanupContext: cleanupContext)
1139+
case .failQueryContinuation(let continuation, with: let error, let sync):
1140+
let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error)
1141+
return .failQueryContinuation(continuation, with: error, sync: sync, cleanupContext: cleanupContext)
10471142
case .succeedQuery(let requestContext, with: let result):
10481143
return .succeedQuery(requestContext, with: result)
1144+
case .succeedQueryContinuation(let continuation, let sync):
1145+
return .succeedQueryContinuation(continuation, sync: sync)
1146+
case .triggerCopyData(let triggerCopy):
1147+
return .triggerCopyData(triggerCopy)
1148+
case .sendCopyDoneAndSync:
1149+
return .sendCopyDoneAndSync
1150+
case .sendCopyFail(message: let message):
1151+
return .sendCopyFail(message: message)
10491152
case .forwardRows(let buffer):
10501153
return .forwardRows(buffer)
10511154
case .forwardStreamComplete(let buffer, let commandTag):

0 commit comments

Comments
 (0)