diff --git a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift index 78f0d202..087a6c24 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift @@ -10,7 +10,8 @@ struct ExtendedQueryStateMachine { case parameterDescriptionReceived(ExtendedQueryContext) case rowDescriptionReceived(ExtendedQueryContext, [RowDescription.Column]) case noDataMessageReceived(ExtendedQueryContext) - + case emptyQueryResponseReceived + /// A state that is used if a noData message was received before. If a row description was received `bufferingRows` is /// used after receiving a `bindComplete` message case bindCompleteReceived(ExtendedQueryContext) @@ -122,7 +123,7 @@ struct ExtendedQueryStateMachine { return .forwardStreamError(.queryCancelled, read: true) } - case .commandComplete, .error, .drain: + case .commandComplete, .emptyQueryResponseReceived, .error, .drain: // the stream has already finished. return .wait @@ -229,6 +230,7 @@ struct ExtendedQueryStateMachine { .messagesSent, .parseCompleteReceived, .parameterDescriptionReceived, + .emptyQueryResponseReceived, .bindCompleteReceived, .streaming, .drain, @@ -268,6 +270,7 @@ struct ExtendedQueryStateMachine { .parseCompleteReceived, .parameterDescriptionReceived, .noDataMessageReceived, + .emptyQueryResponseReceived, .rowDescriptionReceived, .bindCompleteReceived, .commandComplete, @@ -285,7 +288,7 @@ struct ExtendedQueryStateMachine { case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise): return self.avoidingStateMachineCoW { state -> Action in state = .commandComplete(commandTag: commandTag) - let result = QueryResult(value: .noRows(commandTag), logger: context.logger) + let result = QueryResult(value: .noRows(.tag(commandTag)), logger: context.logger) return .succeedQuery(eventLoopPromise, with: result) } @@ -309,6 +312,7 @@ struct ExtendedQueryStateMachine { .parseCompleteReceived, .parameterDescriptionReceived, .noDataMessageReceived, + .emptyQueryResponseReceived, .rowDescriptionReceived, .commandComplete, .error: @@ -319,7 +323,22 @@ struct ExtendedQueryStateMachine { } mutating func emptyQueryResponseReceived() -> Action { - preconditionFailure("Unimplemented") + guard case .bindCompleteReceived(let queryContext) = self.state else { + return self.setAndFireError(.unexpectedBackendMessage(.emptyQueryResponse)) + } + + switch queryContext.query { + case .unnamed(_, let eventLoopPromise), + .executeStatement(_, let eventLoopPromise): + return self.avoidingStateMachineCoW { state -> Action in + state = .emptyQueryResponseReceived + let result = QueryResult(value: .noRows(.emptyResponse), logger: queryContext.logger) + return .succeedQuery(eventLoopPromise, with: result) + } + + case .prepareStatement(_, _, _, _): + return self.setAndFireError(.unexpectedBackendMessage(.emptyQueryResponse)) + } } mutating func errorReceived(_ errorMessage: PostgresBackendMessage.ErrorResponse) -> Action { @@ -336,7 +355,7 @@ struct ExtendedQueryStateMachine { return self.setAndFireError(error) case .streaming, .drain: return self.setAndFireError(error) - case .commandComplete: + case .commandComplete, .emptyQueryResponseReceived: return self.setAndFireError(.unexpectedBackendMessage(.error(errorMessage))) case .error: preconditionFailure(""" @@ -382,6 +401,7 @@ struct ExtendedQueryStateMachine { .parseCompleteReceived, .parameterDescriptionReceived, .noDataMessageReceived, + .emptyQueryResponseReceived, .rowDescriptionReceived, .bindCompleteReceived: preconditionFailure("Requested to consume next row without anything going on.") @@ -405,6 +425,7 @@ struct ExtendedQueryStateMachine { .parseCompleteReceived, .parameterDescriptionReceived, .noDataMessageReceived, + .emptyQueryResponseReceived, .rowDescriptionReceived, .bindCompleteReceived: return .wait @@ -449,6 +470,7 @@ struct ExtendedQueryStateMachine { } case .initialized, .commandComplete, + .emptyQueryResponseReceived, .drain, .error: // we already have the complete stream received, now we are waiting for a @@ -495,7 +517,7 @@ struct ExtendedQueryStateMachine { return .forwardStreamError(error, read: true) } - case .commandComplete, .error: + case .commandComplete, .emptyQueryResponseReceived, .error: preconditionFailure(""" This state must not be reached. If the query `.isComplete`, the ConnectionStateMachine must not send any further events to the substate machine. @@ -507,7 +529,7 @@ struct ExtendedQueryStateMachine { var isComplete: Bool { switch self.state { - case .commandComplete, .error: + case .commandComplete, .emptyQueryResponseReceived, .error: return true case .noDataMessageReceived(let context), .rowDescriptionReceived(let context, _): diff --git a/Sources/PostgresNIO/New/PSQLRowStream.swift b/Sources/PostgresNIO/New/PSQLRowStream.swift index b7f2d4fb..ee925d0e 100644 --- a/Sources/PostgresNIO/New/PSQLRowStream.swift +++ b/Sources/PostgresNIO/New/PSQLRowStream.swift @@ -3,7 +3,7 @@ import Logging struct QueryResult { enum Value: Equatable { - case noRows(String) + case noRows(PSQLRowStream.StatementSummary) case rowDescription([RowDescription.Column]) } @@ -16,25 +16,30 @@ struct QueryResult { final class PSQLRowStream: @unchecked Sendable { private typealias AsyncSequenceSource = NIOThrowingAsyncSequenceProducer.Source + enum StatementSummary: Equatable { + case tag(String) + case emptyResponse + } + enum Source { case stream([RowDescription.Column], PSQLRowsDataSource) - case noRows(Result) + case noRows(Result) } let eventLoop: EventLoop let logger: Logger - + private enum BufferState { case streaming(buffer: CircularBuffer, dataSource: PSQLRowsDataSource) - case finished(buffer: CircularBuffer, commandTag: String) + case finished(buffer: CircularBuffer, summary: StatementSummary) case failure(Error) } - + private enum DownstreamState { case waitingForConsumer(BufferState) case iteratingRows(onRow: (PostgresRow) throws -> (), EventLoopPromise, PSQLRowsDataSource) case waitingForAll([PostgresRow], EventLoopPromise<[PostgresRow]>, PSQLRowsDataSource) - case consumed(Result) + case consumed(Result) case asyncSequence(AsyncSequenceSource, PSQLRowsDataSource, onFinish: @Sendable () -> ()) } @@ -52,9 +57,9 @@ final class PSQLRowStream: @unchecked Sendable { case .stream(let rowDescription, let dataSource): self.rowDescription = rowDescription bufferState = .streaming(buffer: .init(), dataSource: dataSource) - case .noRows(.success(let commandTag)): + case .noRows(.success(let summary)): self.rowDescription = [] - bufferState = .finished(buffer: .init(), commandTag: commandTag) + bufferState = .finished(buffer: .init(), summary: summary) case .noRows(.failure(let error)): self.rowDescription = [] bufferState = .failure(error) @@ -98,12 +103,12 @@ final class PSQLRowStream: @unchecked Sendable { self.downstreamState = .asyncSequence(source, dataSource, onFinish: onFinish) self.executeActionBasedOnYieldResult(yieldResult, source: dataSource) - case .finished(let buffer, let commandTag): + case .finished(let buffer, let summary): _ = source.yield(contentsOf: buffer) source.finish() onFinish() - self.downstreamState = .consumed(.success(commandTag)) - + self.downstreamState = .consumed(.success(summary)) + case .failure(let error): source.finish(error) self.downstreamState = .consumed(.failure(error)) @@ -190,12 +195,12 @@ final class PSQLRowStream: @unchecked Sendable { dataSource.request(for: self) return promise.futureResult - case .finished(let buffer, let commandTag): + case .finished(let buffer, let summary): let rows = buffer.map { PostgresRow(data: $0, lookupTable: self.lookupTable, columns: self.rowDescription) } - self.downstreamState = .consumed(.success(commandTag)) + self.downstreamState = .consumed(.success(summary)) return self.eventLoop.makeSucceededFuture(rows) case .failure(let error): @@ -247,8 +252,8 @@ final class PSQLRowStream: @unchecked Sendable { } return promise.futureResult - - case .finished(let buffer, let commandTag): + + case .finished(let buffer, let summary): do { for data in buffer { let row = PostgresRow( @@ -259,7 +264,7 @@ final class PSQLRowStream: @unchecked Sendable { try onRow(row) } - self.downstreamState = .consumed(.success(commandTag)) + self.downstreamState = .consumed(.success(summary)) return self.eventLoop.makeSucceededVoidFuture() } catch { self.downstreamState = .consumed(.failure(error)) @@ -292,7 +297,7 @@ final class PSQLRowStream: @unchecked Sendable { case .waitingForConsumer(.finished), .waitingForConsumer(.failure): preconditionFailure("How can new rows be received, if an end was already signalled?") - + case .iteratingRows(let onRow, let promise, let dataSource): do { for data in newRows { @@ -347,25 +352,25 @@ final class PSQLRowStream: @unchecked Sendable { private func receiveEnd(_ commandTag: String) { switch self.downstreamState { case .waitingForConsumer(.streaming(buffer: let buffer, _)): - self.downstreamState = .waitingForConsumer(.finished(buffer: buffer, commandTag: commandTag)) - - case .waitingForConsumer(.finished), .waitingForConsumer(.failure): + self.downstreamState = .waitingForConsumer(.finished(buffer: buffer, summary: .tag(commandTag))) + + case .waitingForConsumer(.finished), .waitingForConsumer(.failure), .consumed(.success(.emptyResponse)): preconditionFailure("How can we get another end, if an end was already signalled?") case .iteratingRows(_, let promise, _): - self.downstreamState = .consumed(.success(commandTag)) + self.downstreamState = .consumed(.success(.tag(commandTag))) promise.succeed(()) case .waitingForAll(let rows, let promise, _): - self.downstreamState = .consumed(.success(commandTag)) + self.downstreamState = .consumed(.success(.tag(commandTag))) promise.succeed(rows) case .asyncSequence(let source, _, let onFinish): - self.downstreamState = .consumed(.success(commandTag)) + self.downstreamState = .consumed(.success(.tag(commandTag))) source.finish() onFinish() - case .consumed: + case .consumed(.success(.tag)), .consumed(.failure): break } } @@ -375,7 +380,7 @@ final class PSQLRowStream: @unchecked Sendable { case .waitingForConsumer(.streaming): self.downstreamState = .waitingForConsumer(.failure(error)) - case .waitingForConsumer(.finished), .waitingForConsumer(.failure): + case .waitingForConsumer(.finished), .waitingForConsumer(.failure), .consumed(.success(.emptyResponse)): preconditionFailure("How can we get another end, if an end was already signalled?") case .iteratingRows(_, let promise, _): @@ -391,7 +396,7 @@ final class PSQLRowStream: @unchecked Sendable { consumer.finish(error) onFinish() - case .consumed: + case .consumed(.success(.tag)), .consumed(.failure): break } } @@ -413,10 +418,15 @@ final class PSQLRowStream: @unchecked Sendable { } var commandTag: String { - guard case .consumed(.success(let commandTag)) = self.downstreamState else { + guard case .consumed(.success(let consumed)) = self.downstreamState else { preconditionFailure("commandTag may only be called if all rows have been consumed") } - return commandTag + switch consumed { + case .tag(let tag): + return tag + case .emptyResponse: + return "" + } } } diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index a3190aa7..ee2af0fe 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -550,9 +550,9 @@ final class PostgresChannelHandler: ChannelDuplexHandler { ) self.rowStream = rows - case .noRows(let commandTag): + case .noRows(let summary): rows = PSQLRowStream( - source: .noRows(.success(commandTag)), + source: .noRows(.success(summary)), eventLoop: context.channel.eventLoop, logger: result.logger ) diff --git a/Sources/PostgresNIO/PostgresDatabase+Query.swift b/Sources/PostgresNIO/PostgresDatabase+Query.swift index 01a7e61f..483d5a7b 100644 --- a/Sources/PostgresNIO/PostgresDatabase+Query.swift +++ b/Sources/PostgresNIO/PostgresDatabase+Query.swift @@ -73,10 +73,7 @@ public struct PostgresQueryMetadata: Sendable { init?(string: String) { let parts = string.split(separator: " ") - guard parts.count >= 1 else { - return nil - } - switch parts[0] { + switch parts.first { case "INSERT": // INSERT oid rows guard parts.count == 3 else { diff --git a/Tests/IntegrationTests/PSQLIntegrationTests.swift b/Tests/IntegrationTests/PSQLIntegrationTests.swift index 57939c06..d541899b 100644 --- a/Tests/IntegrationTests/PSQLIntegrationTests.swift +++ b/Tests/IntegrationTests/PSQLIntegrationTests.swift @@ -123,6 +123,25 @@ final class IntegrationTests: XCTestCase { XCTAssertEqual(foo, "hello") } + func testQueryNothing() throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow(try conn?.close().wait()) } + + var _result: PostgresQueryResult? + XCTAssertNoThrow(_result = try conn?.query(""" + -- Some comments + """, logger: .psqlTest).wait()) + + let result = try XCTUnwrap(_result) + XCTAssertEqual(result.rows, []) + XCTAssertEqual(result.metadata.command, "") + } + func testDecodeIntegers() { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift index 40e32468..ae484acc 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift @@ -20,7 +20,7 @@ class ExtendedQueryStateMachineTests: XCTestCase { XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) XCTAssertEqual(state.noDataReceived(), .wait) XCTAssertEqual(state.bindCompleteReceived(), .wait) - XCTAssertEqual(state.commandCompletedReceived("DELETE 1"), .succeedQuery(promise, with: .init(value: .noRows("DELETE 1"), logger: logger))) + XCTAssertEqual(state.commandCompletedReceived("DELETE 1"), .succeedQuery(promise, with: .init(value: .noRows(.tag("DELETE 1")), logger: logger))) XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) } @@ -77,7 +77,25 @@ class ExtendedQueryStateMachineTests: XCTestCase { XCTAssertEqual(state.commandCompletedReceived("SELECT 2"), .forwardStreamComplete([row5, row6], commandTag: "SELECT 2")) XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) } - + + func testExtendedQueryWithNoQuery() { + var state = ConnectionStateMachine.readyForQuery() + + let logger = Logger.psqlTest + let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) + promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. + let query: PostgresQuery = "-- some comments" + let queryContext = ExtendedQueryContext(query: query, logger: logger, promise: promise) + + XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendParseDescribeBindExecuteSync(query)) + XCTAssertEqual(state.parseCompleteReceived(), .wait) + XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) + XCTAssertEqual(state.noDataReceived(), .wait) + XCTAssertEqual(state.bindCompleteReceived(), .wait) + XCTAssertEqual(state.emptyQueryResponseReceived(), .succeedQuery(promise, with: .init(value: .noRows(.emptyResponse), logger: logger))) + XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) + } + func testReceiveTotallyUnexpectedMessageInQuery() { var state = ConnectionStateMachine.readyForQuery() diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/PreparedStatementStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/PreparedStatementStateMachineTests.swift index f6c1ddf7..e35e93f7 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/PreparedStatementStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/PreparedStatementStateMachineTests.swift @@ -28,7 +28,7 @@ class PreparedStatementStateMachineTests: XCTestCase { XCTAssertEqual(preparationCompleteAction.statements.count, 1) XCTAssertNil(preparationCompleteAction.rowDescription) firstPreparedStatement.promise.succeed(PSQLRowStream( - source: .noRows(.success("tag")), + source: .noRows(.success(.tag("tag"))), eventLoop: eventLoop, logger: .psqlTest )) @@ -46,7 +46,7 @@ class PreparedStatementStateMachineTests: XCTestCase { return } secondPreparedStatement.promise.succeed(PSQLRowStream( - source: .noRows(.success("tag")), + source: .noRows(.success(.tag("tag"))), eventLoop: eventLoop, logger: .psqlTest )) @@ -135,12 +135,12 @@ class PreparedStatementStateMachineTests: XCTestCase { XCTAssertNil(preparationCompleteAction.rowDescription) firstPreparedStatement.promise.succeed(PSQLRowStream( - source: .noRows(.success("tag")), + source: .noRows(.success(.tag("tag"))), eventLoop: eventLoop, logger: .psqlTest )) secondPreparedStatement.promise.succeed(PSQLRowStream( - source: .noRows(.success("tag")), + source: .noRows(.success(.tag("tag"))), eventLoop: eventLoop, logger: .psqlTest )) diff --git a/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift b/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift index 9a1e9e41..65ca26c3 100644 --- a/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift @@ -12,7 +12,7 @@ final class PSQLRowStreamTests: XCTestCase { func testEmptyStream() { let stream = PSQLRowStream( - source: .noRows(.success("INSERT 0 1")), + source: .noRows(.success(.tag("INSERT 0 1"))), eventLoop: self.eventLoop, logger: self.logger )