Skip to content

Commit 6c91234

Browse files
committed
Implement cancellation during copy operations
1 parent ef9ddec commit 6c91234

File tree

6 files changed

+162
-36
lines changed

6 files changed

+162
-36
lines changed

Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,17 @@ public struct PostgresCopyFromWriter: Sendable {
2323
precondition(eventLoop.inEventLoop)
2424
let promise = eventLoop.makePromise(of: Void.self)
2525
self.channelHandler.value.checkBackendCanReceiveCopyData(promise: promise)
26-
promise.futureResult.map {
26+
promise.futureResult.flatMap {
2727
if eventLoop.inEventLoop {
28-
self.channelHandler.value.sendCopyData(byteBuffer)
28+
return eventLoop.makeCompletedFuture(withResultOf: {
29+
try self.channelHandler.value.sendCopyData(byteBuffer)
30+
})
2931
} else {
32+
let promise = eventLoop.makePromise(of: Void.self)
3033
eventLoop.execute {
31-
self.channelHandler.value.sendCopyData(byteBuffer)
34+
promise.completeWith(Result(catching: { try self.channelHandler.value.sendCopyData(byteBuffer) }))
3235
}
36+
return promise.futureResult
3337
}
3438
}.whenComplete { result in
3539
continuation.resume(with: result)
@@ -47,13 +51,32 @@ public struct PostgresCopyFromWriter: Sendable {
4751

4852
// TODO: Buffer data in here and only send a `CopyData` message to the backend once we have accumulated a
4953
// predefined amount of data.
50-
// TODO: Listen for task cancellation while we are waiting for backpressure to clear.
51-
try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, any Error>) in
54+
try await withTaskCancellationHandler {
55+
do {
56+
try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, any Error>) in
57+
if eventLoop.inEventLoop {
58+
writeAssumingInEventLoop(byteBuffer, continuation)
59+
} else {
60+
eventLoop.execute {
61+
writeAssumingInEventLoop(byteBuffer, continuation)
62+
}
63+
}
64+
}
65+
} catch {
66+
if Task.isCancelled {
67+
// If the task was cancelled, we might receive a postgres error which is an artifact about how we
68+
// communicate the cancellation to the state machine. Throw a `CancellationError` to the user
69+
// instead, which looks more like native Swift Concurrency code.
70+
throw CancellationError()
71+
}
72+
throw error
73+
}
74+
} onCancel: {
5275
if eventLoop.inEventLoop {
53-
writeAssumingInEventLoop(byteBuffer, continuation)
76+
self.channelHandler.value.cancel()
5477
} else {
5578
eventLoop.execute {
56-
writeAssumingInEventLoop(byteBuffer, continuation)
79+
self.channelHandler.value.cancel()
5780
}
5881
}
5982
}

Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,12 @@ struct ConnectionStateMachine {
110110
/// Send a `CopyFail` message to the backend with the given error message.
111111
case sendCopyFail(message: String)
112112

113+
/// Fail the promise with the given error and close the connection.
114+
///
115+
/// This is used when we want to cancel a COPY operation while waiting for backpressure relieve. In that case we
116+
/// can't recover the connection because we can't send any messages to the backend, so we need to close it.
117+
case failPromiseAndCloseConnection(EventLoopPromise<Void>, error: PSQLError, cleanupContext: CleanUpContext)
118+
113119
// --- streaming actions
114120
// actions if query has requested next row but we are waiting for backend
115121
case forwardRows([DataRow])
@@ -867,9 +873,10 @@ struct ConnectionStateMachine {
867873

868874
// MARK: Consumer
869875

870-
mutating func cancelQueryStream() -> ConnectionAction {
876+
mutating func cancel() -> ConnectionAction {
871877
guard case .extendedQuery(var queryState, let connectionContext) = self.state else {
872-
preconditionFailure("Tried to cancel stream without active query")
878+
// We are not in a state in which we can cancel. Do nothing.
879+
return .wait
873880
}
874881

875882
self.state = .modifying // avoid CoW
@@ -955,7 +962,8 @@ struct ConnectionStateMachine {
955962
.triggerCopyData,
956963
.sendCopyDoneAndSync,
957964
.sendCopyFail,
958-
.succeedQueryContinuation:
965+
.succeedQueryContinuation,
966+
.failPromiseAndCloseConnection:
959967
preconditionFailure("Invalid query state machine action in state: \(self.state), action: \(action)")
960968

961969
case .evaluateErrorAtConnectionLevel:
@@ -1149,6 +1157,9 @@ extension ConnectionStateMachine {
11491157
return .sendCopyDoneAndSync
11501158
case .sendCopyFail(message: let message):
11511159
return .sendCopyFail(message: message)
1160+
case .failPromiseAndCloseConnection(let promise, error: let error):
1161+
let cleanupContext = self.setErrorAndCreateCleanupContext(error)
1162+
return .failPromiseAndCloseConnection(promise, error: error, cleanupContext: cleanupContext)
11521163
case .forwardRows(let buffer):
11531164
return .forwardRows(buffer)
11541165
case .forwardStreamComplete(let buffer, let commandTag):

Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,12 @@ struct ExtendedQueryStateMachine {
7979
/// Send a `CopyFail` message to the backend with the given error message.
8080
case sendCopyFail(message: String)
8181

82+
/// Fail the promise with the given error and close the connection.
83+
///
84+
/// This is used when we want to cancel a COPY operation while waiting for backpressure relieve. In that case we
85+
/// can't recover the connection because we can't send any messages to the backend, so we need to close it.
86+
case failPromiseAndCloseConnection(EventLoopPromise<Void>, error: PSQLError)
87+
8288
// --- streaming actions
8389
// actions if query has requested next row but we are waiting for backend
8490
case forwardRows([DataRow])
@@ -155,8 +161,16 @@ struct ExtendedQueryStateMachine {
155161
return .failPreparedStatementCreation(eventLoopPromise, with: .queryCancelled)
156162
}
157163

158-
case .copyingData:
159-
return .sendCopyFail(message: "Copy cancelled")
164+
case .copyingData(.readyToSend):
165+
// We can't initiate an exit from the copy state here because `copyingFinished`, which is the state that is
166+
// reached after sending a `CopyFail` requires a continuation that waits for the `CommandComplete` or
167+
// `ErrorResponse`. Instead, we assume that the next call to `CopyFromWriter.write` checks cancellation and
168+
// initiates the `CopyFail` with the cancellation.
169+
return .wait
170+
171+
case .copyingData(.pendingBackpressureRelieve(let promise)):
172+
self.state = .error(.queryCancelled)
173+
return .failPromiseAndCloseConnection(promise, error: .queryCancelled)
160174

161175
case .copyingFinished:
162176
// We already finished the copy and are awaiting the `CommandComplete` or `ErrorResponse` from it. There's

Sources/PostgresNIO/New/PostgresChannelHandler.swift

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -177,26 +177,50 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
177177
/// `ErrorResponse`. This is mostly the case when malformed data is sent to it. In that case, the data transfer
178178
/// should be aborted to avoid unnecessary work.
179179
func checkBackendCanReceiveCopyData(promise: EventLoopPromise<Void>) {
180-
self.state.checkBackendCanReceiveCopyData(channelIsWritable: handlerContext!.channel.isWritable, promise: promise)
180+
guard let handlerContext else {
181+
promise.fail(PostgresError.connectionClosed)
182+
return
183+
}
184+
self.state.checkBackendCanReceiveCopyData(channelIsWritable: handlerContext.channel.isWritable, promise: promise)
185+
}
186+
187+
/// Cancel the currently executing operation, if it is cancellable.
188+
func cancel() {
189+
guard let handlerContext else {
190+
return
191+
}
192+
let action = self.state.cancel()
193+
self.run(action, with: handlerContext)
181194
}
182195

183196
/// Send a `CopyData` message to the backend using the given data.
184-
func sendCopyData(_ data: ByteBuffer) {
197+
func sendCopyData(_ data: ByteBuffer) throws {
198+
guard let handlerContext else {
199+
throw PostgresError.connectionClosed
200+
}
185201
self.encoder.copyDataHeader(dataLength: UInt32(data.readableBytes))
186-
self.handlerContext!.write(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil)
187-
self.handlerContext!.writeAndFlush(self.wrapOutboundOut(data), promise: nil)
202+
handlerContext.write(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil)
203+
handlerContext.writeAndFlush(self.wrapOutboundOut(data), promise: nil)
188204
}
189205

190206
/// Put the state machine out of the copying mode and send a `CopyDone` message to the backend.
191207
func sendCopyDone(continuation: CheckedContinuation<Void, any Error>) {
208+
guard let handlerContext else {
209+
continuation.resume(throwing: PostgresError.connectionClosed)
210+
return
211+
}
192212
let action = self.state.sendCopyDone(continuation: continuation)
193-
self.run(action, with: self.handlerContext!)
213+
self.run(action, with: handlerContext)
194214
}
195215

196216
/// Put the state machine out of the copying mode and send a `CopyFail` message to the backend.
197217
func sendCopyFail(message: String, continuation: CheckedContinuation<Void, any Error>) {
218+
guard let handlerContext else {
219+
continuation.resume(throwing: PostgresError.connectionClosed)
220+
return
221+
}
198222
let action = self.state.sendCopyFail(message: message, continuation: continuation)
199-
self.run(action, with: self.handlerContext!)
223+
self.run(action, with: handlerContext)
200224
}
201225

202226
func channelReadComplete(context: ChannelHandlerContext) {
@@ -489,6 +513,9 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
489513
}
490514
case .forwardNotificationToListeners(let notification):
491515
self.forwardNotificationToListeners(notification, context: context)
516+
case .failPromiseAndCloseConnection(let promise, let error, let cleanupContext):
517+
promise.fail(error)
518+
self.closeConnectionAndCleanup(cleanupContext, context: context)
492519
}
493520
}
494521

@@ -860,11 +887,10 @@ extension PostgresChannelHandler: PSQLRowsDataSource {
860887
}
861888

862889
func cancel(for stream: PSQLRowStream) {
863-
guard self.rowStream === stream, let handlerContext = self.handlerContext else {
890+
guard self.rowStream === stream else {
864891
return
865892
}
866-
let action = self.state.cancelQueryStream()
867-
self.run(action, with: handlerContext)
893+
self.cancel()
868894
}
869895
}
870896

Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ class ExtendedQueryStateMachineTests: XCTestCase {
140140

141141
XCTAssertEqual(state.rowDescriptionReceived(.init(columns: input)), .wait)
142142
XCTAssertEqual(state.bindCompleteReceived(), .succeedQuery(promise, with: .init(value: .rowDescription(expected), logger: logger)))
143-
XCTAssertEqual(state.cancelQueryStream(), .forwardStreamError(.queryCancelled, read: false, cleanupContext: nil))
143+
XCTAssertEqual(state.cancel(), .forwardStreamError(.queryCancelled, read: false, cleanupContext: nil))
144144
XCTAssertEqual(state.dataRowReceived([ByteBuffer(string: "test1")]), .wait)
145145
XCTAssertEqual(state.channelReadComplete(), .wait)
146146
XCTAssertEqual(state.readEventCaught(), .read)
@@ -188,7 +188,7 @@ class ExtendedQueryStateMachineTests: XCTestCase {
188188
XCTAssertEqual(state.dataRowReceived(row1), .wait)
189189
XCTAssertEqual(state.channelReadComplete(), .forwardRows([row1]))
190190
XCTAssertEqual(state.readEventCaught(), .wait)
191-
XCTAssertEqual(state.cancelQueryStream(), .forwardStreamError(.queryCancelled, read: true, cleanupContext: nil))
191+
XCTAssertEqual(state.cancel(), .forwardStreamError(.queryCancelled, read: true, cleanupContext: nil))
192192

193193
XCTAssertEqual(state.dataRowReceived([ByteBuffer(string: "test2")]), .wait)
194194
XCTAssertEqual(state.dataRowReceived([ByteBuffer(string: "test3")]), .wait)
@@ -287,7 +287,7 @@ class ExtendedQueryStateMachineTests: XCTestCase {
287287
XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendParseDescribeBindExecuteSync(query))
288288
XCTAssertEqual(state.parseCompleteReceived(), .wait)
289289
XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait)
290-
XCTAssertEqual(state.cancelQueryStream(), .failQuery(promise, with: .queryCancelled, cleanupContext: .none))
290+
XCTAssertEqual(state.cancel(), .failQuery(promise, with: .queryCancelled, cleanupContext: .none))
291291

292292
let serverError = PostgresBackendMessage.ErrorResponse(fields: [.severity: "Error", .sqlState: "123"])
293293
XCTAssertEqual(state.errorReceived(serverError), .wait)

0 commit comments

Comments
 (0)