Skip to content

Commit a528171

Browse files
authored
add postgres query metadata (#95)
1 parent 97f2778 commit a528171

File tree

3 files changed

+118
-11
lines changed

3 files changed

+118
-11
lines changed

Sources/PostgresNIO/Message/PostgresMessage+CommandComplete.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import NIO
22

33
extension PostgresMessage {
44
/// Identifies the message as a Close command.
5-
public struct CommandComplete {
5+
public struct CommandComplete: PostgresMessageType {
66
/// Parses an instance of this message type from a byte buffer.
77
public static func parse(from buffer: inout ByteBuffer) throws -> CommandComplete {
88
guard let string = buffer.readNullTerminatedString() else {

Sources/PostgresNIO/PostgresDatabase+Query.swift

Lines changed: 100 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,43 +2,128 @@ import NIO
22
import Logging
33

44
extension PostgresDatabase {
5-
public func query(_ string: String, _ binds: [PostgresData] = []) -> EventLoopFuture<[PostgresRow]> {
5+
public func query(
6+
_ string: String,
7+
_ binds: [PostgresData] = []
8+
) -> EventLoopFuture<PostgresQueryResult> {
69
var rows: [PostgresRow] = []
7-
return query(string, binds) { rows.append($0) }.map { rows }
10+
var metadata: PostgresQueryMetadata?
11+
return self.query(string, binds, onMetadata: {
12+
metadata = $0
13+
}) {
14+
rows.append($0)
15+
}.map {
16+
.init(metadata: metadata!, rows: rows)
17+
}
818
}
9-
10-
public func query(_ string: String, _ binds: [PostgresData] = [], _ onRow: @escaping (PostgresRow) throws -> ()) -> EventLoopFuture<Void> {
11-
let query = PostgresParameterizedQuery(query: string, binds: binds, onRow: onRow)
19+
20+
public func query(
21+
_ string: String,
22+
_ binds: [PostgresData] = [],
23+
onMetadata: @escaping (PostgresQueryMetadata) -> () = { _ in },
24+
onRow: @escaping (PostgresRow) throws -> ()
25+
) -> EventLoopFuture<Void> {
26+
let query = PostgresParameterizedQuery(
27+
query: string,
28+
binds: binds,
29+
onMetadata: onMetadata,
30+
onRow: onRow
31+
)
1232
return self.send(query, logger: self.logger)
1333
}
1434
}
1535

36+
public struct PostgresQueryResult {
37+
public let metadata: PostgresQueryMetadata
38+
public let rows: [PostgresRow]
39+
}
40+
41+
extension PostgresQueryResult: Collection {
42+
public typealias Index = Int
43+
public typealias Element = PostgresRow
44+
45+
public var startIndex: Int {
46+
self.rows.startIndex
47+
}
48+
49+
public var endIndex: Int {
50+
self.rows.endIndex
51+
}
52+
53+
public subscript(position: Int) -> PostgresRow {
54+
self.rows[position]
55+
}
56+
57+
public func index(after i: Int) -> Int {
58+
self.rows.index(after: i)
59+
}
60+
}
61+
62+
public struct PostgresQueryMetadata {
63+
public let command: String
64+
public var oid: Int?
65+
public var rows: Int?
66+
67+
init?(string: String) {
68+
let parts = string.split(separator: " ")
69+
guard parts.count >= 1 else {
70+
return nil
71+
}
72+
switch parts[0] {
73+
case "INSERT":
74+
// INSERT oid rows
75+
guard parts.count == 3 else {
76+
return nil
77+
}
78+
self.command = .init(parts[0])
79+
self.oid = Int(parts[1])
80+
self.rows = Int(parts[2])
81+
case "DELETE", "UPDATE", "SELECT", "MOVE", "FETCH", "COPY":
82+
// <cmd> rows
83+
guard parts.count == 2 else {
84+
return nil
85+
}
86+
self.command = .init(parts[0])
87+
self.oid = nil
88+
self.rows = Int(parts[1])
89+
default:
90+
// <cmd>
91+
self.command = string
92+
self.oid = nil
93+
self.rows = nil
94+
}
95+
}
96+
}
97+
1698
// MARK: Private
1799

18100
private final class PostgresParameterizedQuery: PostgresRequest {
19101
let query: String
20102
let binds: [PostgresData]
103+
var onMetadata: (PostgresQueryMetadata) -> ()
21104
var onRow: (PostgresRow) throws -> ()
22105
var rowLookupTable: PostgresRow.LookupTable?
23106
var resultFormatCodes: [PostgresFormatCode]
24107
var logger: Logger?
25-
108+
26109
init(
27110
query: String,
28111
binds: [PostgresData],
112+
onMetadata: @escaping (PostgresQueryMetadata) -> (),
29113
onRow: @escaping (PostgresRow) throws -> ()
30114
) {
31115
self.query = query
32116
self.binds = binds
117+
self.onMetadata = onMetadata
33118
self.onRow = onRow
34119
self.resultFormatCodes = [.binary]
35120
}
36-
121+
37122
func log(to logger: Logger) {
38123
self.logger = logger
39124
logger.debug("\(self.query) \(self.binds)")
40125
}
41-
126+
42127
func respond(to message: PostgresMessage) throws -> [PostgresMessage]? {
43128
if case .error = message.identifier {
44129
// we should continue after errors
@@ -77,6 +162,11 @@ private final class PostgresParameterizedQuery: PostgresRequest {
77162
}
78163
return []
79164
case .commandComplete:
165+
let complete = try PostgresMessage.CommandComplete(message: message)
166+
guard let metadata = PostgresQueryMetadata(string: complete.tag) else {
167+
throw PostgresError.protocol("Unexpected query metadata: \(complete.tag)")
168+
}
169+
self.onMetadata(metadata)
80170
return []
81171
case .notice:
82172
return []
@@ -87,7 +177,7 @@ private final class PostgresParameterizedQuery: PostgresRequest {
87177
default: throw PostgresError.protocol("Unexpected message during query: \(message)")
88178
}
89179
}
90-
180+
91181
func start() throws -> [PostgresMessage] {
92182
let parse = PostgresMessage.Parse(
93183
statementName: "",
@@ -109,7 +199,7 @@ private final class PostgresParameterizedQuery: PostgresRequest {
109199
portalName: "",
110200
maxRows: 0
111201
)
112-
202+
113203
let sync = PostgresMessage.Sync()
114204
return try [parse.message(), describe.message(), bind.message(), execute.message(), sync.message()]
115205
}

Tests/PostgresNIOTests/PostgresNIOTests.swift

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -855,6 +855,23 @@ final class PostgresNIOTests: XCTestCase {
855855
let res = try conn.query("SELECT $1::text as foo", [String?.none.postgresData!]).wait()
856856
XCTAssertEqual(res[0].column("foo")?.string, nil)
857857
}
858+
859+
func testUpdateMetadata() throws {
860+
let conn = try PostgresConnection.test(on: eventLoop).wait()
861+
defer { try! conn.close().wait() }
862+
_ = try conn.simpleQuery("DROP TABLE IF EXISTS test_table").wait()
863+
_ = try conn.simpleQuery("CREATE TABLE test_table(pk int PRIMARY KEY)").wait()
864+
_ = try conn.simpleQuery("INSERT INTO test_table VALUES(1)").wait()
865+
try conn.query("DELETE FROM test_table", onMetadata: { metadata in
866+
XCTAssertEqual(metadata.command, "DELETE")
867+
XCTAssertEqual(metadata.oid, nil)
868+
XCTAssertEqual(metadata.rows, 1)
869+
}, onRow: { _ in }).wait()
870+
let rows = try conn.query("DELETE FROM test_table").wait()
871+
XCTAssertEqual(rows.metadata.command, "DELETE")
872+
XCTAssertEqual(rows.metadata.oid, nil)
873+
XCTAssertEqual(rows.metadata.rows, 0)
874+
}
858875
}
859876

860877
let isLoggingConfigured: Bool = {

0 commit comments

Comments
 (0)