From 9d1ef8a6a1394e38d41af252a24b7008eb876acc Mon Sep 17 00:00:00 2001 From: itsknk Date: Sun, 21 Jul 2024 16:03:33 -0700 Subject: [PATCH] prioritize interfaces from device preference --- exo/networking/grpc/grpc_discovery.py | 85 +++++++++++++++------------ 1 file changed, 47 insertions(+), 38 deletions(-) diff --git a/exo/networking/grpc/grpc_discovery.py b/exo/networking/grpc/grpc_discovery.py index 9d80cdc36..fce45cd27 100644 --- a/exo/networking/grpc/grpc_discovery.py +++ b/exo/networking/grpc/grpc_discovery.py @@ -32,21 +32,17 @@ def __init__(self, node_id: str, node_port: int, listen_port: int, broadcast_por self.listen_port = listen_port self.broadcast_port = broadcast_port if broadcast_port is not None else listen_port self.broadcast_interval = broadcast_interval - self.known_peers: Dict[str, Tuple[GRPCPeerHandle, float, float]] = {} + self.known_peers: Dict[str, Tuple[GRPCPeerHandle, float, float, int]] = {} # Added priority self.broadcast_task = None self.listen_task = None self.cleanup_task = None def get_network_interfaces(self): if psutil.MACOS: - # Use system_profiler to get detailed network information network_info = subprocess.check_output(['system_profiler', 'SPNetworkDataType', '-json']).decode('utf-8') network_data = json.loads(network_info) - thunderbolt_interfaces = [] - wifi_interfaces = [] - other_interfaces = [] - + interfaces = [] for interface in network_data.get('SPNetworkDataType', []): interface_name = interface.get('interface') interface_type = interface.get('type', '').lower() @@ -55,23 +51,22 @@ def get_network_interfaces(self): if not interface_name: continue - # Check for Thunderbolt-related keywords - if 'thunderbolt' in hardware or 'thunderbolt' in interface_type: - thunderbolt_interfaces.append(interface_name) - elif interface_type == 'wi-fi': - wifi_interfaces.append(interface_name) - else: - other_interfaces.append(interface_name) - - if self.DEBUG >= 2: - print(f"Thunderbolt interfaces: {thunderbolt_interfaces}") - print(f"WiFi interfaces: {wifi_interfaces}") - print(f"Other interfaces: {other_interfaces}") + priority = self.get_interface_priority(interface_name, interface_type, hardware) + interfaces.append((interface_name, priority)) - # Prioritize Thunderbolt, then WiFi, then others - return thunderbolt_interfaces + wifi_interfaces + other_interfaces + return sorted(interfaces, key=lambda x: x[1], reverse=True) else: - return netifaces.interfaces() + return [(iface, self.get_interface_priority(iface)) for iface in netifaces.interfaces()] + + def get_interface_priority(self, interface_name, interface_type='', hardware=''): + if 'thunderbolt' in hardware or 'thunderbolt' in interface_type: + return 3 + elif interface_type == 'wi-fi' or 'airport' in hardware: + return 2 + elif interface_name.startswith('en'): + return 1 + else: + return 0 async def start(self): self.device_capabilities = device_capabilities() @@ -111,24 +106,24 @@ async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]: if DEBUG_DISCOVERY >= 2: print("No new peers discovered in the last grace period. Ending discovery process.") break # No new peers found in the grace period, we are done - return [peer_handle for peer_handle, _, _ in self.known_peers.values()] + return [peer_handle for peer_handle, _, _, _ in self.known_peers.values()] async def task_broadcast_presence(self): - message = json.dumps({ - "type": "discovery", - "node_id": self.node_id, - "grpc_port": self.node_port, - "device_capabilities": self.device_capabilities.to_dict() - }).encode('utf-8') - while True: try: - # Update interfaces periodically interfaces = self.get_network_interfaces() - for interface in interfaces: + for interface, priority in interfaces: try: - if DEBUG_DISCOVERY >= 3: print(f"Broadcasting on interface: {interface}") + if DEBUG_DISCOVERY >= 3: print(f"Broadcasting on interface: {interface} with priority: {priority}") + message = json.dumps({ + "type": "discovery", + "node_id": self.node_id, + "grpc_port": self.node_port, + "device_capabilities": self.device_capabilities.to_dict(), + "priority": priority + }).encode('utf-8') + transport, _ = await asyncio.get_event_loop().create_datagram_endpoint( lambda: asyncio.DatagramProtocol(), local_addr=(netifaces.ifaddresses(interface)[netifaces.AF_INET][0]['addr'], 0), @@ -151,11 +146,25 @@ async def on_listen_message(self, data, addr): peer_id = message['node_id'] peer_host = addr[0] peer_port = message['grpc_port'] + new_priority = message['priority'] device_capabilities = DeviceCapabilities(**message['device_capabilities']) - if peer_id not in self.known_peers: - self.known_peers[peer_id] = (GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}", device_capabilities), time.time(), time.time()) - if DEBUG_DISCOVERY >= 2: print(f"Discovered new peer {peer_id} at {peer_host}:{peer_port}") - self.known_peers[peer_id] = (self.known_peers[peer_id][0], self.known_peers[peer_id][1], time.time()) + + if peer_id in self.known_peers: + _, _, _, current_priority = self.known_peers[peer_id] + if new_priority > current_priority: + await self.reconnect_peer(peer_id, peer_host, peer_port, new_priority, device_capabilities) + else: + peer_handle = GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}", device_capabilities) + self.known_peers[peer_id] = (peer_handle, time.time(), time.time(), new_priority) + if DEBUG_DISCOVERY >= 2: print(f"Discovered new peer {peer_id} at {peer_host}:{peer_port} with priority {new_priority}") + + async def reconnect_peer(self, peer_id, new_host, new_port, new_priority, device_capabilities): + old_handle, connected_at, _, _ = self.known_peers[peer_id] + await old_handle.disconnect() + new_handle = GRPCPeerHandle(peer_id, f"{new_host}:{new_port}", device_capabilities) + await new_handle.connect() + self.known_peers[peer_id] = (new_handle, connected_at, time.time(), new_priority) + if DEBUG_DISCOVERY >= 2: print(f"Reconnected to peer {peer_id} at {new_host}:{new_port} with new priority {new_priority}") async def task_listen_for_peers(self): try: @@ -174,10 +183,10 @@ async def task_cleanup_peers(self): current_time = time.time() timeout = 15 * self.broadcast_interval peers_to_remove = [ - peer_handle.id() for peer_handle, connected_at, last_seen in self.known_peers.values() if + peer_handle.id() for peer_handle, connected_at, last_seen, _ in self.known_peers.values() if (not await peer_handle.is_connected() and current_time - connected_at > timeout) or current_time - last_seen > timeout ] - if DEBUG_DISCOVERY >= 2: print("Peer statuses:", {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, {connected_at=}, {last_seen=}" for peer_handle, connected_at, last_seen in self.known_peers.values()}) + if DEBUG_DISCOVERY >= 2: print("Peer statuses:", {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, {connected_at=}, {last_seen=}, priority={priority}" for peer_handle, connected_at, last_seen, priority in self.known_peers.values()}) if DEBUG_DISCOVERY >= 2 and len(peers_to_remove) > 0: print(f"Cleaning up peers: {peers_to_remove}") for peer_id in peers_to_remove: if peer_id in self.known_peers: del self.known_peers[peer_id]