diff --git a/Packages/ChainServices/NodeService/NodeService.swift b/Packages/ChainServices/NodeService/NodeService.swift index 43110fa74..95218e77a 100644 --- a/Packages/ChainServices/NodeService/NodeService.swift +++ b/Packages/ChainServices/NodeService/NodeService.swift @@ -17,24 +17,22 @@ public final class NodeService: Sendable { } public func getNodeSelected(chain: Chain) -> ChainNode { - guard let node = try? nodeStore.selectedNode(chain: chain.rawValue) else { + guard + let url = try? nodeStore.selectedNodeUrl(chain: chain), + let node = try? nodes(for: chain).first(where: { $0.node.url == url }) + else { return chain.defaultChainNode } return node } - + public func setNodeSelected(chain: Chain, node: Primitives.Node) throws { - if node.url.contains("gemnodes.com") { - return try nodeStore.deleteNodeSelected(chain: chain.rawValue) - } - guard let recordNode = try nodeStore.nodeRecord(chain: chain, url: node.url) else { - throw AnyError("Node node record") - } - try nodeStore.setNodeSelected(node: recordNode) + try nodeStore.setNodeSelected(chain: chain, url: node.url) } public func delete(chain: Chain, node: Primitives.Node) throws { - try nodeStore.deleteNode(chain: chain.rawValue, url: node.url) + try nodeStore.deleteNode(chain: chain, url: node.url) + try nodeStore.deleteNodeSelected(chain: chain) } public func nodes(for chain: Chain) throws -> [ChainNode] { @@ -45,7 +43,7 @@ public final class NodeService: Sendable { chain.europeChainNode, ] + nodes).unique() } - + public func update(chain: Chain, force: Bool = false) throws { // TODO: - implement later /* @@ -62,12 +60,10 @@ public final class NodeService: Sendable { // MARK: - NodeURLFetchable -extension NodeService: NodeURLFetchable { +extension NodeService: NodeURLFetchable { public func node(for chain: Chain) -> URL { - guard - let node = try? nodeStore.selectedNode(chain: chain.rawValue), - let url = URL(string: node.node.url) else { - return chain.defaultBaseUrl + guard let url = try? nodeStore.selectedNodeUrl(chain: chain)?.asURL else { + return chain.defaultBaseUrl } return url } diff --git a/Packages/ChainServices/NodeService/Tests/NodeServiceTests.swift b/Packages/ChainServices/NodeService/Tests/NodeServiceTests.swift new file mode 100644 index 000000000..79a867ee8 --- /dev/null +++ b/Packages/ChainServices/NodeService/Tests/NodeServiceTests.swift @@ -0,0 +1,53 @@ +// Copyright (c). Gem Wallet. All rights reserved. + +import Testing +import Primitives +import Store +import StoreTestKit +import NodeServiceTestKit + +@testable import NodeService + +struct NodeServiceTests { + + @Test + func getNodeSelectedReturnsDefaultWhenNotSet() { + #expect(NodeService.mock().getNodeSelected(chain: .ethereum).node.url == Chain.ethereum.defaultChainNode.node.url) + } + + @Test + func setNodeSelectedPersistsNode() throws { + let service = NodeService.mock() + + try service.setNodeSelected(chain: .ethereum, node: Chain.ethereum.asiaChainNode.node) + + #expect(service.getNodeSelected(chain: .ethereum).node.url == Chain.ethereum.asiaChainNode.node.url) + } + + @Test + func switchNode() throws { + let service = NodeService.mock() + + try service.setNodeSelected(chain: .ethereum, node: Chain.ethereum.asiaChainNode.node) + #expect(service.getNodeSelected(chain: .ethereum).node.url == Chain.ethereum.asiaChainNode.node.url) + + try service.setNodeSelected(chain: .ethereum, node: Chain.ethereum.europeChainNode.node) + #expect(service.getNodeSelected(chain: .ethereum).node.url == Chain.ethereum.europeChainNode.node.url) + } + + @Test + func nodeURLFetchableReturnsSelectedUrl() throws { + let service = NodeService.mock() + + try service.setNodeSelected(chain: .ethereum, node: Chain.ethereum.asiaChainNode.node) + + #expect(service.node(for: .ethereum) == Chain.ethereum.asiaChainNode.node.url.asURL) + } + + @Test + func nodeURLFetchableReturnsDefaultWhenNotSet() { + let service = NodeService.mock() + + #expect(service.node(for: .ethereum) == Chain.ethereum.defaultBaseUrl) + } +} diff --git a/Packages/ChainServices/Package.swift b/Packages/ChainServices/Package.swift index 55cfdf7d1..742185269 100644 --- a/Packages/ChainServices/Package.swift +++ b/Packages/ChainServices/Package.swift @@ -86,7 +86,7 @@ let package = Package( "ChainService", ], path: "NodeService", - exclude: ["TestKit"] + exclude: ["TestKit", "Tests"] ), .target( name: "NodeServiceTestKit", @@ -96,6 +96,16 @@ let package = Package( ], path: "NodeService/TestKit" ), + .testTarget( + name: "NodeServiceTests", + dependencies: [ + "NodeService", + "NodeServiceTestKit", + "Primitives", + .product(name: "StoreTestKit", package: "Store"), + ], + path: "NodeService/Tests" + ), .target( name: "WalletConnectorService", dependencies: [ diff --git a/Packages/Store/Sources/Migrations.swift b/Packages/Store/Sources/Migrations.swift index 1928fac44..dfe7b8258 100644 --- a/Packages/Store/Sources/Migrations.swift +++ b/Packages/Store/Sources/Migrations.swift @@ -346,6 +346,18 @@ public struct Migrations { } } + migrator.registerMigration("Migrate nodes_selected_v1 to \(NodeSelectedRecord.databaseTableName)") { db in + try? db.drop(table: NodeSelectedRecord.databaseTableName) + try? NodeSelectedRecord.create(db: db) + try? db.execute(sql: """ + INSERT INTO \(NodeSelectedRecord.databaseTableName) (chain, nodeUrl) + SELECT ns.chain, n.url + FROM nodes_selected_v1 ns + INNER JOIN \(NodeRecord.databaseTableName) n ON ns.nodeId = n.id + """) + try? db.drop(table: "nodes_selected_v1") + } + try migrator.migrate(dbQueue) } } diff --git a/Packages/Store/Sources/Models/NodeSelectedRecord.swift b/Packages/Store/Sources/Models/NodeSelectedRecord.swift index 2aea7ea54..c77f4be4f 100644 --- a/Packages/Store/Sources/Models/NodeSelectedRecord.swift +++ b/Packages/Store/Sources/Models/NodeSelectedRecord.swift @@ -4,43 +4,24 @@ import Foundation import Primitives import GRDB -struct NodeSelectedRecordInfo: FetchableRecord, Codable { - var node: NodeRecord - var nodeSelected: NodeSelectedRecord -} +public struct NodeSelectedRecord: Codable, FetchableRecord, PersistableRecord, TableRecord { + public static let databaseTableName: String = "nodes_selected" -public struct NodeSelectedRecord: Codable, FetchableRecord, PersistableRecord, TableRecord { - public static let databaseTableName: String = "nodes_selected_v1" - public enum Columns { - static let nodeId = Column("nodeId") static let chain = Column("chain") - static let auto = Column("auto") + static let nodeUrl = Column("nodeUrl") } - public var nodeId: Int public var chain: Chain - public var auto: Bool - - static let node = belongsTo(NodeRecord.self, key: "node") + public var nodeUrl: String + } extension NodeSelectedRecord: CreateTable { static func create(db: Database) throws { try db.create(table: Self.databaseTableName, ifNotExists: true) { - $0.column(Columns.nodeId.name, .text) - .notNull() - .indexed() - .references(NodeRecord.databaseTableName, onDelete: .cascade) - $0.column(Columns.chain.name, .text) - .primaryKey() - $0.column(Columns.auto.name, .boolean) + $0.column(Columns.chain.name, .text).primaryKey() + $0.column(Columns.nodeUrl.name, .text).notNull() } } } - -extension NodeSelectedRecordInfo { - func mapToChainNode() -> ChainNode { - return node.mapToChainNode() - } -} diff --git a/Packages/Store/Sources/Stores/NodeStore.swift b/Packages/Store/Sources/Stores/NodeStore.swift index 3f294914f..3d181ee09 100644 --- a/Packages/Store/Sources/Stores/NodeStore.swift +++ b/Packages/Store/Sources/Stores/NodeStore.swift @@ -26,88 +26,43 @@ public struct NodeStore: Sendable { } } - public func nodes() throws -> [ChainNode] { - try db.read { db in - try NodeRecord - .fetchAll(db) - } - .map { $0.mapToChainNode() } - } - public func nodes(chain: Chain) throws -> [ChainNode] { - return try nodeRecords(chain: chain) - .map { $0.mapToChainNode() } - } - - public func nodeRecord(chain: Chain, url: String) throws -> NodeRecord? { - try db.read { db in - try NodeRecord - .filter(NodeRecord.Columns.chain == chain.rawValue) - .filter(NodeRecord.Columns.url == url) - .fetchOne(db) - } - } - - public func nodeRecords(chain: Chain) throws -> [NodeRecord] { try db.read { db in try NodeRecord .filter(NodeRecord.Columns.chain == chain.rawValue) .fetchAll(db) + .map { $0.mapToChainNode() } } } - public func setNodeSelected(node: NodeRecord) throws { + public func setNodeSelected(chain: Chain, url: String) throws { try db.write { (db: Database) in - guard let nodeId = node.id else { - throw AnyError("no node id") - } - try NodeSelectedRecord(nodeId: nodeId, chain: node.chain, auto: true) - .upsert(db) - } - } - - public func deleteNodeSelected(chain: String) throws { - return try db.write { (db: Database) in - try NodeSelectedRecord - .filter(NodeRecord.Columns.chain == chain) - .deleteAll(db) + try NodeSelectedRecord(chain: chain, nodeUrl: url).upsert(db) } } - public func deleteNode(chain: String, url: String) throws { - return try db.write { (db: Database) in + public func deleteNode(chain: Chain, url: String) throws { + _ = try db.write { db in try NodeRecord - .filter(NodeRecord.Columns.chain == chain && NodeRecord.Columns.url == url) + .filter(NodeRecord.Columns.chain == chain.rawValue && NodeRecord.Columns.url == url) .deleteAll(db) } } - public func selectedNodes() throws -> [ChainNode] { + public func selectedNodeUrl(chain: Chain) throws -> String? { try db.read { db in try NodeSelectedRecord - .including(required: NodeSelectedRecord.node) - .asRequest(of: NodeSelectedRecordInfo.self) - .fetchAll(db) - .map { $0.mapToChainNode() } + .filter(NodeSelectedRecord.Columns.chain == chain.rawValue) + .fetchOne(db)? + .nodeUrl } } - public func selectedNode(chain: String) throws -> ChainNode? { - try db.read { db in + public func deleteNodeSelected(chain: Chain) throws { + _ = try db.write { (db: Database) in try NodeSelectedRecord - .including(required: NodeSelectedRecord.node) - .filter(NodeRecord.Columns.chain == chain) - .asRequest(of: NodeSelectedRecordInfo.self) - .fetchOne(db) - .map { $0.mapToChainNode() } - } - } - - public func allNodes() throws -> [Node] { - try db.read { db in - try NodeRecord - .fetchAll(db) - .map { $0.mapToNode() } + .filter(NodeRecord.Columns.chain == chain.rawValue) + .deleteAll(db) } } }