diff --git a/exo/networking/grpc/grpc_discovery.py b/exo/networking/grpc/grpc_discovery.py
index 3064dff48..40d59ef92 100644
--- a/exo/networking/grpc/grpc_discovery.py
+++ b/exo/networking/grpc/grpc_discovery.py
@@ -2,6 +2,9 @@
 import json
 import socket
 import time
+import netifaces
+import psutil
+import subprocess
 from typing import List, Dict, Callable, Tuple, Coroutine
 from ..discovery import Discovery
 from ..peer_handle import PeerHandle
@@ -9,6 +12,48 @@
 from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
 from exo import DEBUG_DISCOVERY
 
+def get_interface_priority(interface_name, interface_type='', hardware=''):
+  """Rank an interface for peer connections; a higher value is preferred.
+
+  3 = Thunderbolt, 2 = Wi-Fi/AirPort, 1 = any other `en*` interface, 0 = everything else.
+  `interface_type` and `hardware` are matched against lower-case literals, so pass them lower-cased.
+  """
+  if 'thunderbolt' in hardware or 'thunderbolt' in interface_type:
+    return 3
+  elif interface_type == 'wi-fi' or 'airport' in hardware:
+    return 2
+  elif interface_name.startswith('en'):
+    return 1
+  else:
+    return 0
+
+def get_network_interfaces():
+  """Return (interface_name, priority) tuples sorted by descending priority.
+
+  On macOS the list comes from `system_profiler SPNetworkDataType -json`;
+  elsewhere netifaces is used and the priority is derived from the name alone.
+  """
+  if psutil.MACOS:
+    network_info = subprocess.check_output(['system_profiler', 'SPNetworkDataType', '-json']).decode('utf-8')
+    network_data = json.loads(network_info)
+
+    interfaces = []
+    for interface in network_data.get('SPNetworkDataType', []):
+      interface_name = interface.get('interface')
+      interface_type = interface.get('type', '').lower()
+      hardware = interface.get('hardware', '').lower()
+
+      if not interface_name:
+        continue
+
+      priority = get_interface_priority(interface_name, interface_type, hardware)
+      interfaces.append((interface_name, priority))
+  else:
+    interfaces = [(iface, get_interface_priority(iface)) for iface in netifaces.interfaces()]
+
+  # Sort on every platform so callers can rely on highest-priority-first order.
+  return sorted(interfaces, key=lambda x: x[1], reverse=True)
+
 class ListenProtocol(asyncio.DatagramProtocol):
   def __init__(self, on_message: Callable[[bytes, Tuple[str, int]], Coroutine]):
     super().__init__()
@@ -21,7 +66,6 @@ def connection_made(self, transport):
   def datagram_received(self, data, addr):
     asyncio.create_task(self.on_message(data, addr))
 
-
 class GRPCDiscovery(Discovery):
   def __init__(self, node_id: str, node_port: int, listen_port: int, broadcast_port: int = None, broadcast_interval: int = 1, device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES):
     self.node_id = node_id
@@ -30,7 +74,7 @@ def __init__(self, node_id: str, node_port: int, listen_port: int, broadcast_por
     self.listen_port = listen_port
     self.broadcast_port = broadcast_port if broadcast_port is not None else listen_port
     self.broadcast_interval = broadcast_interval
-    self.known_peers: Dict[str, Tuple[GRPCPeerHandle, float, float]] = {}
+    self.known_peers: Dict[str, Tuple[GRPCPeerHandle, float, float, int]] = {}
     self.broadcast_task = None
     self.listen_task = None
     self.cleanup_task = None
@@ -73,32 +117,43 @@ async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
           if DEBUG_DISCOVERY >= 2: print("No new peers discovered in the last grace period. Ending discovery process.")
           break  # No new peers found in the grace period, we are done
 
-    return [peer_handle for peer_handle, _, _ in self.known_peers.values()]
+    return [peer_handle for peer_handle, _, _, _ in self.known_peers.values()]
 
   async def task_broadcast_presence(self):
-    transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(
-      lambda: asyncio.DatagramProtocol(),
-      local_addr=('0.0.0.0', 0),
-      family=socket.AF_INET)
-    sock = transport.get_extra_info('socket')
-    sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
-
-    message = json.dumps({
-      "type": "discovery",
-      "node_id": self.node_id,
-      "grpc_port": self.node_port,
-      "device_capabilities": self.device_capabilities.to_dict()
-    }).encode('utf-8')
-
+    """Periodically broadcast a discovery message on each interface that has an IPv4 address, tagged with its priority."""
     while True:
       try:
-        if DEBUG_DISCOVERY >= 3: print(f"Broadcast presence: {message}")
-        transport.sendto(message, ('', self.broadcast_port))
-        await asyncio.sleep(self.broadcast_interval)
+        # system_profiler is a slow blocking subprocess on macOS; keep it off the event loop.
+        interfaces = await asyncio.get_event_loop().run_in_executor(None, get_network_interfaces)
+
+        for interface, priority in interfaces:
+          transport = None
+          try:
+            if DEBUG_DISCOVERY >= 3: print(f"Broadcasting on interface: {interface} with priority: {priority}")
+            message = json.dumps({
+              "type": "discovery",
+              "node_id": self.node_id,
+              "grpc_port": self.node_port,
+              "device_capabilities": self.device_capabilities.to_dict(),
+              "priority": priority
+            }).encode('utf-8')
+
+            transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(
+              lambda: asyncio.DatagramProtocol(),
+              local_addr=(netifaces.ifaddresses(interface)[netifaces.AF_INET][0]['addr'], 0),
+              family=socket.AF_INET)
+            sock = transport.get_extra_info('socket')
+            sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
+            transport.sendto(message, ('', self.broadcast_port))
+          except Exception as e:
+            if DEBUG_DISCOVERY >= 2: print(f"Error broadcasting on interface {interface}: {e}")
+          finally:
+            # Close even when setsockopt/sendto raised so endpoints are not leaked every interval.
+            if transport is not None: transport.close()
       except Exception as e:
-        print(f"Error in broadcast presence: {e}")
-        import traceback
-        print(traceback.format_exc())
+        if DEBUG_DISCOVERY >= 2: print(f"Error updating or accessing interfaces: {e}")
+
+      await asyncio.sleep(self.broadcast_interval)
 
   async def on_listen_message(self, data, addr):
     message = json.loads(data.decode('utf-8'))
@@ -107,15 +162,49 @@ async def on_listen_message(self, data, addr):
       peer_id = message['node_id']
       peer_host = addr[0]
       peer_port = message['grpc_port']
+      # Peers that do not send a priority are treated as lowest priority instead of raising KeyError.
+      new_priority = message.get('priority', 0)
       device_capabilities = DeviceCapabilities(**message['device_capabilities'])
-      if peer_id not in self.known_peers:
-        self.known_peers[peer_id] = (GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}", device_capabilities), time.time(), time.time())
-        if DEBUG_DISCOVERY >= 2: print(f"Discovered new peer {peer_id} at {peer_host}:{peer_port}")
-      self.known_peers[peer_id] = (self.known_peers[peer_id][0], self.known_peers[peer_id][1], time.time())
+
+      if peer_id in self.known_peers:
+        peer_handle, connected_at, _, current_priority = self.known_peers[peer_id]
+        if new_priority > current_priority:
+          await self.reconnect_peer(peer_id, peer_host, peer_port, new_priority, device_capabilities)
+        else:
+          # Refresh last_seen on every broadcast; otherwise task_cleanup_peers evicts live peers after the timeout.
+          self.known_peers[peer_id] = (peer_handle, connected_at, time.time(), current_priority)
+      else:
+        peer_handle = GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
+        self.known_peers[peer_id] = (peer_handle, time.time(), time.time(), new_priority)
+        if DEBUG_DISCOVERY >= 2: print(f"Discovered new peer {peer_id} at {peer_host}:{peer_port} with priority {new_priority}")
+
+  async def reconnect_peer(self, peer_id, new_host, new_port, new_priority, device_capabilities):
+    """Switch a known peer to the address it was just seen on via a higher-priority interface."""
+    old_handle, connected_at, _, old_priority = self.known_peers[peer_id]
+    new_handle = GRPCPeerHandle(peer_id, f"{new_host}:{new_port}", device_capabilities)
+    try:
+      # Connect first so a failed upgrade leaves the existing lower-priority handle registered.
+      await new_handle.connect()
+    except Exception as e:
+      if DEBUG_DISCOVERY >= 2: print(f"Error reconnecting to peer {peer_id} at {new_host}:{new_port}: {e}")
+      # The peer did just broadcast, so keep its last_seen fresh for task_cleanup_peers.
+      self.known_peers[peer_id] = (old_handle, connected_at, time.time(), old_priority)
+      return
+    self.known_peers[peer_id] = (new_handle, connected_at, time.time(), new_priority)
+    await old_handle.disconnect()
+    if DEBUG_DISCOVERY >= 2: print(f"Reconnected to peer {peer_id} at {new_host}:{new_port} with new priority {new_priority}")
 
   async def task_listen_for_peers(self):
-    await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message), local_addr=('0.0.0.0', self.listen_port))
-    if DEBUG_DISCOVERY >= 2: print("Started listen task")
+    try:
+      if DEBUG_DISCOVERY >= 2: print("Starting to listen on all interfaces")
+      await asyncio.get_event_loop().create_datagram_endpoint(
+        lambda: ListenProtocol(self.on_listen_message),
+        local_addr=('0.0.0.0', self.listen_port)
+      )
+      if DEBUG_DISCOVERY >= 2: print("Started listen task on all interfaces")
+    except Exception as e:
+      # Always report: without the listener this node can never discover peers.
+      print(f"Error setting up listening: {e}")
 
   async def task_cleanup_peers(self):
     while True:
@@ -123,10 +212,10 @@ async def task_cleanup_peers(self):
         current_time = time.time()
         timeout = 15 * self.broadcast_interval
         peers_to_remove = [
-          peer_handle.id() for peer_handle, connected_at, last_seen in self.known_peers.values() if
+          peer_handle.id() for peer_handle, connected_at, last_seen, _ in self.known_peers.values() if
           (not await peer_handle.is_connected() and current_time - connected_at > timeout) or current_time - last_seen > timeout
         ]
-        if DEBUG_DISCOVERY >= 2: print("Peer statuses:", {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, {connected_at=}, {last_seen=}" for peer_handle, connected_at, last_seen in self.known_peers.values()})
+        if DEBUG_DISCOVERY >= 2: print("Peer statuses:", {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, {connected_at=}, {last_seen=}, priority={priority}" for peer_handle, connected_at, last_seen, priority in self.known_peers.values()})
         if DEBUG_DISCOVERY >= 2 and len(peers_to_remove) > 0: print(f"Cleaning up peers: {peers_to_remove}")
         for peer_id in peers_to_remove:
           if peer_id in self.known_peers: del self.known_peers[peer_id]
diff --git a/exo/networking/grpc/test_network_interfaces.py b/exo/networking/grpc/test_network_interfaces.py
new file mode 100644
index 000000000..83b9724b8
--- /dev/null
+++ b/exo/networking/grpc/test_network_interfaces.py
@@ -0,0 +1,58 @@
+import os
+import sys
+import time
+
+current_dir = os.path.dirname(os.path.abspath(__file__))
+parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(current_dir)))
+sys.path.insert(0, parent_dir)
+
+from exo.networking.grpc.grpc_discovery import get_network_interfaces, get_interface_priority
+
+def simulate_broadcast_receive(peer_id, interface, priority, known_peers):
+  print(f"\nReceived broadcast from peer {peer_id} on interface {interface} with priority {priority}")
+
+  if peer_id in known_peers:
+    current_interface, current_priority = known_peers[peer_id]
+    if priority > current_priority:
+      print(f"Higher priority detected. Disconnecting from {current_interface} and reconnecting on {interface}")
+      known_peers[peer_id] = (interface, priority)
+    else:
+      print(f"Keeping existing connection on {current_interface} with priority {current_priority}")
+  else:
+    print(f"New peer discovered. Connecting on {interface}")
+    known_peers[peer_id] = (interface, priority)
+
+def run_tests():
+  print("Testing network interface detection, prioritization, and reconnection logic")
+  interfaces = get_network_interfaces()
+
+  # Test 1: Check if interfaces are detected and sorted
+  assert len(interfaces) > 0, "No interfaces detected"
+  assert all(interfaces[i][1] >= interfaces[i+1][1] for i in range(len(interfaces)-1)), "Interfaces are not correctly sorted by priority"
+
+  print("\nPrioritized list of interfaces:")
+  for i, (interface, priority) in enumerate(interfaces, 1):
+    print(f"{i}. {interface} (Priority: {priority})")
+
+  # Test 2: Simulate broadcast receives and reconnection
+  known_peers = {}
+
+  # Simulate receiving broadcasts from the same peer on different interfaces
+  peer_id = "test_peer_1"
+  for interface, priority in interfaces:
+    simulate_broadcast_receive(peer_id, interface, priority, known_peers)
+    time.sleep(0.1)  # Add a small delay
+
+  # Test 3: Check if the peer is connected to the highest priority interface
+  assert peer_id in known_peers, "Peer was not added to known_peers"
+  final_interface, final_priority = known_peers[peer_id]
+  assert final_priority == max(priority for _, priority in interfaces), "Peer is not connected to the highest priority interface"
+
+  print("\nFinal known peers:")
+  for peer_id, (interface, priority) in known_peers.items():
+    print(f"Peer {peer_id}: Connected on {interface} with priority {priority}")
+
+  print("\nAll tests passed successfully!")
+
+if __name__ == "__main__":
+  run_tests()
diff --git a/setup.py b/setup.py
index 4b1db878c..ac653a717 100644
--- a/setup.py
+++ b/setup.py
@@ -22,6 +22,7 @@
   "tqdm==4.66.4",
   "transformers==4.41.2",
   "uuid==1.30",
+  "netifaces==0.11.0",
   "tinygrad @ git+https://github.com/tinygrad/tinygrad.git@a9f5a764dc640a5e5cbaaeeee21df7c8ca37da38",
 ]
 