From 35a704b46af6b8cdb6c36bcf053623f8f8fde35d Mon Sep 17 00:00:00 2001 From: Marius Seufzer Date: Mon, 19 Jun 2023 23:54:57 +0200 Subject: [PATCH 1/2] add typed queries --- .../Connection/PostgresConnection.swift | 19 ++++++++ Sources/PostgresNIO/Data/PostgresRow.swift | 4 ++ Sources/PostgresNIO/New/PostgresQuery.swift | 6 +++ .../PostgresNIO/New/PostgresRowSequence.swift | 29 +++++++++++- .../IntegrationTests/TypedQueriesTests.swift | 47 +++++++++++++++++++ 5 files changed, 104 insertions(+), 1 deletion(-) create mode 100644 Tests/IntegrationTests/TypedQueriesTests.swift diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index c24041c9..ef324fbd 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -428,6 +428,25 @@ extension PostgresConnection { throw error // rethrow with more metadata } } + + /// Run a query on the Postgres server the connection is connected to. + /// + /// - Parameters: + /// - query: A ``PostgresTypedQuery`` to run + /// - logger: The `Logger` to log into for the query + /// - file: The file, the query was started in. Used for better error reporting. + /// - line: The line, the query was started in. Used for better error reporting. + /// - Returns: A ``PostgresTypedSequence`` containing typed rows the server sent as the query result. + @discardableResult + public func query( + _ query: T, + logger: Logger, + file: String = #fileID, + line: Int = #line + ) async throws -> PostgresTypedSequence { + let rowSequence = try await self.query(query.sql, logger: logger, file: file, line: line) + return PostgresTypedSequence(rowSequence: rowSequence) + } } // MARK: EventLoopFuture interface diff --git a/Sources/PostgresNIO/Data/PostgresRow.swift b/Sources/PostgresNIO/Data/PostgresRow.swift index e3aea692..e89ff566 100644 --- a/Sources/PostgresNIO/Data/PostgresRow.swift +++ b/Sources/PostgresNIO/Data/PostgresRow.swift @@ -22,6 +22,10 @@ public struct PostgresRow: Sendable { } } +public protocol PostgresTypedRow { + init(from row: PostgresRow) throws +} + extension PostgresRow: Equatable { public static func ==(lhs: Self, rhs: Self) -> Bool { // we don't need to compare the lookup table here, as the looup table is only derived diff --git a/Sources/PostgresNIO/New/PostgresQuery.swift b/Sources/PostgresNIO/New/PostgresQuery.swift index 381370e9..dc39a419 100644 --- a/Sources/PostgresNIO/New/PostgresQuery.swift +++ b/Sources/PostgresNIO/New/PostgresQuery.swift @@ -13,6 +13,12 @@ public struct PostgresQuery: Sendable, Hashable { } } +public protocol PostgresTypedQuery { + associatedtype Row: PostgresTypedRow + + var sql: PostgresQuery { get } +} + extension PostgresQuery: ExpressibleByStringInterpolation { public init(stringInterpolation: StringInterpolation) { self.sql = stringInterpolation.sql diff --git a/Sources/PostgresNIO/New/PostgresRowSequence.swift b/Sources/PostgresNIO/New/PostgresRowSequence.swift index ccf4f69c..7c2b4bb6 100644 --- a/Sources/PostgresNIO/New/PostgresRowSequence.swift +++ b/Sources/PostgresNIO/New/PostgresRowSequence.swift @@ -43,7 +43,7 @@ extension PostgresRowSequence { self.columns = columns } - public mutating func next() async throws -> PostgresRow? { + public func next() async throws -> PostgresRow? { if let dataRow = try await self.backing.next() { return PostgresRow( data: dataRow, @@ -56,6 +56,33 @@ extension PostgresRowSequence { } } +public struct PostgresTypedSequence: AsyncSequence { + public typealias Element = T + + let rowSequence: PostgresRowSequence + + init(rowSequence: PostgresRowSequence) { + self.rowSequence = rowSequence + } + + public func makeAsyncIterator() -> AsyncIterator { + AsyncIterator(rowSequence: rowSequence.makeAsyncIterator()) + } +} + +extension PostgresTypedSequence { + public struct AsyncIterator: AsyncIteratorProtocol { + let rowSequence: PostgresRowSequence.AsyncIterator + + public func next() async throws -> T? { + guard let row = try await self.rowSequence.next() else { + return nil + } + return try T.init(from: row) + } + } +} + extension PostgresRowSequence { public func collect() async throws -> [PostgresRow] { var result = [PostgresRow]() diff --git a/Tests/IntegrationTests/TypedQueriesTests.swift b/Tests/IntegrationTests/TypedQueriesTests.swift new file mode 100644 index 00000000..78ef0326 --- /dev/null +++ b/Tests/IntegrationTests/TypedQueriesTests.swift @@ -0,0 +1,47 @@ +import Logging +import XCTest +import PostgresNIO + +final class TypedQueriesTests: XCTestCase { + func testTypedPostgresQuery() async throws { + struct MyQuery: PostgresTypedQuery { + struct Row: PostgresTypedRow { + let id: Int + let name: String + + init(from row: PostgresRow) throws { + (id, name) = try row.decode((Int, String).self, context: .default) + } + } + + var sql: PostgresQuery { + "SELECT id, name FROM users" + } + } + + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + try await withTestConnection(on: eventLoop) { connection in + let createTableQuery = PostgresQuery(unsafeSQL: """ + CREATE TABLE users ( + id SERIAL PRIMARY KEY, + name character varying(255) NOT NULL + ); + """) + let name = "foobar" + + try await connection.query(createTableQuery, logger: .psqlTest) + try await connection.query("INSERT INTO users (name) VALUES (\(name));", logger: .psqlTest) + + let rows = try await connection.query(MyQuery(), logger: .psqlTest) + for try await row in rows { + XCTAssertEqual(row.name, name) + } + + let dropQuery = PostgresQuery(unsafeSQL: "DROP TABLE users") + try await connection.query(dropQuery, logger: .psqlTest) + } + } +} From 8a4f11a83bd87ab28dbd1b61613bbdb102c75245 Mon Sep 17 00:00:00 2001 From: Marius Seufzer Date: Wed, 21 Jun 2023 12:24:08 +0200 Subject: [PATCH 2/2] wip macros --- Package.swift | 14 +++- Sources/PostgresNIO/New/PostgresQuery.swift | 31 ++++++++ Sources/PostgresNIO/Utilities/Macros.swift | 5 ++ .../PostgresNIODiagnostic.swift | 21 +++++ .../PostgresNIOMacros/PostgresNIOMacro.swift | 76 +++++++++++++++++++ .../PostgresNIOTests/Macros/MacroTests.swift | 45 +++++++++++ 6 files changed, 191 insertions(+), 1 deletion(-) create mode 100644 Sources/PostgresNIO/Utilities/Macros.swift create mode 100644 Sources/PostgresNIOMacros/PostgresNIODiagnostic.swift create mode 100644 Sources/PostgresNIOMacros/PostgresNIOMacro.swift create mode 100644 Tests/PostgresNIOTests/Macros/MacroTests.swift diff --git a/Package.swift b/Package.swift index c1cb4bda..87e12046 100644 --- a/Package.swift +++ b/Package.swift @@ -1,5 +1,7 @@ -// swift-tools-version:5.6 +// swift-tools-version:5.9 + import PackageDescription +import CompilerPluginSupport let package = Package( name: "postgres-nio", @@ -20,6 +22,7 @@ let package = Package( .package(url: "https://github.com/apple/swift-crypto.git", "1.0.0" ..< "3.0.0"), .package(url: "https://github.com/apple/swift-metrics.git", from: "2.0.0"), .package(url: "https://github.com/apple/swift-log.git", from: "1.5.2"), + .package(url: "https://github.com/apple/swift-syntax.git", branch: "main") ], targets: [ .target( @@ -36,6 +39,14 @@ let package = Package( .product(name: "NIOTLS", package: "swift-nio"), .product(name: "NIOSSL", package: "swift-nio-ssl"), .product(name: "NIOFoundationCompat", package: "swift-nio"), + .target(name: "PostgresNIOMacros") + ] + ), + .macro( + name: "PostgresNIOMacros", + dependencies: [ + .product(name: "SwiftSyntaxMacros", package: "swift-syntax"), + .product(name: "SwiftCompilerPlugin", package: "swift-syntax") ] ), .testTarget( @@ -44,6 +55,7 @@ let package = Package( .target(name: "PostgresNIO"), .product(name: "NIOEmbedded", package: "swift-nio"), .product(name: "NIOTestUtils", package: "swift-nio"), + .product(name: "SwiftSyntaxMacrosTestSupport", package: "swift-syntax"), ] ), .testTarget( diff --git a/Sources/PostgresNIO/New/PostgresQuery.swift b/Sources/PostgresNIO/New/PostgresQuery.swift index dc39a419..10f5bf58 100644 --- a/Sources/PostgresNIO/New/PostgresQuery.swift +++ b/Sources/PostgresNIO/New/PostgresQuery.swift @@ -13,6 +13,37 @@ public struct PostgresQuery: Sendable, Hashable { } } +struct PostgresMacroQuery: ExpressibleByStringInterpolation { + var sql: String + + public init(stringInterpolation: StringInterpolation) { + sql = stringInterpolation.sql + } + + public init(stringLiteral value: String) { + sql = value + } + + struct StringInterpolation: StringInterpolationProtocol { + typealias StringLiteralType = String + + var sql: String + + init(literalCapacity: Int, interpolationCount: Int) { + sql = "" + } + + mutating func appendLiteral(_ literal: String) { + sql.append(contentsOf: literal) + } + + mutating func appendInterpolation(_ sql: String, type: T.Type) {} + } +} + +@Query("SELECT \("id", type: Int.self) FROM users") +struct GetAllUsersQuery {} + public protocol PostgresTypedQuery { associatedtype Row: PostgresTypedRow diff --git a/Sources/PostgresNIO/Utilities/Macros.swift b/Sources/PostgresNIO/Utilities/Macros.swift new file mode 100644 index 00000000..dd6109ee --- /dev/null +++ b/Sources/PostgresNIO/Utilities/Macros.swift @@ -0,0 +1,5 @@ +@attached(member) +macro Query(_ query: PostgresMacroQuery) = #externalMacro( + module: "PostgresNIOMacros", + type: "PostgresTypedQueryMacro" +) diff --git a/Sources/PostgresNIOMacros/PostgresNIODiagnostic.swift b/Sources/PostgresNIOMacros/PostgresNIODiagnostic.swift new file mode 100644 index 00000000..866fa482 --- /dev/null +++ b/Sources/PostgresNIOMacros/PostgresNIODiagnostic.swift @@ -0,0 +1,21 @@ +import SwiftDiagnostics + +enum PostgresNIODiagnostic: String, DiagnosticMessage { + case wrongArgument + case notAStruct + + var message: String { + switch self { + case .wrongArgument: + return "Invalid argument" + case .notAStruct: + return "Macro only works with structs" + } + } + + var diagnosticID: SwiftDiagnostics.MessageID { + MessageID(domain: "PostgresNIOMacros", id: rawValue) + } + + var severity: SwiftDiagnostics.DiagnosticSeverity { .error } +} diff --git a/Sources/PostgresNIOMacros/PostgresNIOMacro.swift b/Sources/PostgresNIOMacros/PostgresNIOMacro.swift new file mode 100644 index 00000000..2fdb53a0 --- /dev/null +++ b/Sources/PostgresNIOMacros/PostgresNIOMacro.swift @@ -0,0 +1,76 @@ +import SwiftCompilerPlugin +import SwiftSyntax +import SwiftSyntaxBuilder +import SwiftSyntaxMacros +import SwiftDiagnostics + +public struct PostgresTypedQueryMacro: MemberMacro { + public static func expansion( + of node: AttributeSyntax, + providingMembersOf declaration: some DeclGroupSyntax, + in context: some MacroExpansionContext + ) throws -> [DeclSyntax] { + guard declaration.is(StructDeclSyntax.self) else { + context.diagnose(Diagnostic(node: Syntax(node), message: PostgresNIODiagnostic.notAStruct)) + return [] + } + + guard let elements = node.argument?.as(TupleExprElementListSyntax.self)? + .first?.as(TupleExprElementSyntax.self)? + .expression.as(StringLiteralExprSyntax.self)?.segments else { + // TODO: Be more specific about this error + context.diagnose(Diagnostic(node: Syntax(node), message: PostgresNIODiagnostic.wrongArgument)) + return [] + } + + + + var outputTypes: [(String, String)] = [] + for tup in elements { + if let expression = tup.as(ExpressionSegmentSyntax.self) { + outputTypes.append(extractColumnTypes(expression)) + } + } + + let rowStruct = try StructDeclSyntax("struct Row") { + for outputType in outputTypes { + MemberDeclListItemSyntax.init(decl: DeclSyntax(stringLiteral: "let \(outputType.0): \(outputType.1)")) + } + try InitializerDeclSyntax("init(from: PostgresRow) throws") { + // TODO: (id, name) = try row.decode((Int, String).self, context: .default) + } + } + + return [ +// DeclSyntax(rowStruct) + ] + } + + /// Returns ("name", "String") + private static func extractColumnTypes(_ node: ExpressionSegmentSyntax) -> (String, String) { + let tupleElements = node.expressions + guard tupleElements.count == 2 else { + fatalError("Expected tuple with exactly two elements") + } + + // First element needs to be the column name + var iterator = tupleElements.makeIterator() + guard let columnName = iterator.next()?.expression.as(StringLiteralExprSyntax.self)? + .segments.first?.as(StringSegmentSyntax.self)?.content + .text else { + fatalError("Expected column name") + } + + guard let columnType = iterator.next()?.expression.as(MemberAccessExprSyntax.self)?.base?.as(IdentifierExprSyntax.self)?.identifier.text else { + fatalError("Expected column type") + } + return (columnName, columnType) + } +} + +@main +struct PostgresNIOMacros: CompilerPlugin { + let providingMacros: [Macro.Type] = [ + PostgresTypedQueryMacro.self + ] +} diff --git a/Tests/PostgresNIOTests/Macros/MacroTests.swift b/Tests/PostgresNIOTests/Macros/MacroTests.swift new file mode 100644 index 00000000..1ff68c76 --- /dev/null +++ b/Tests/PostgresNIOTests/Macros/MacroTests.swift @@ -0,0 +1,45 @@ +import SwiftSyntaxMacros +import SwiftSyntaxMacrosTestSupport +import XCTest +import PostgresNIOMacros + +let testMacros: [String: Macro.Type] = [ + "Query": PostgresTypedQueryMacro.self, +] + +final class MacrosTests: XCTestCase { + func testMacro() { + assertMacroExpansion( + #""" + @Query("SELECT \("id", Int.self), \("name", String.self) FROM users") + struct MyQuery {} + """#, + expandedSource: #""" + struct MyQuery { + struct Row: PostgresTypedRow { + let id: Int + let name: String + } + } + """#, +// expandedSource: #""" +// struct MyQuery: PostgresTypedQuery { +// struct Row: PostgresTypedRow { +// let id: Int +// let name: String +// +// init(from row: PostgresRow) throws { +// (id, name) = try row.decode((Int, String).self, context: .default) +// } +// } +// +// var sql: PostgresQuery { +// "SELECT id, name FROM users" +// } +// } +// """#, + macros: testMacros + ) + } +} +