From c31e124170f9d84f4ed6234c8d88b61da2693393 Mon Sep 17 00:00:00 2001 From: itsknk Date: Fri, 19 Jul 2024 23:42:26 -0700 Subject: [PATCH] prioritise thunderbolt over WIFI add netifaces dependency use psutil instead of platform use psutil instead of platform fix typo prioritise thunderbolt over WIFI optimize network discovery and dynamic interface handling optimize network discovery and dynamic interface handling implement network interface detection and prioritization implement network interface detection and prioritization --- exo/networking/grpc/grpc_discovery.py | 83 +++++++++++++++++++++------ setup.py | 1 + 2 files changed, 68 insertions(+), 16 deletions(-) diff --git a/exo/networking/grpc/grpc_discovery.py b/exo/networking/grpc/grpc_discovery.py index 3064dff48..9d80cdc36 100644 --- a/exo/networking/grpc/grpc_discovery.py +++ b/exo/networking/grpc/grpc_discovery.py @@ -2,6 +2,9 @@ import json import socket import time +import netifaces +import psutil +import subprocess from typing import List, Dict, Callable, Tuple, Coroutine from ..discovery import Discovery from ..peer_handle import PeerHandle @@ -21,7 +24,6 @@ def connection_made(self, transport): def datagram_received(self, data, addr): asyncio.create_task(self.on_message(data, addr)) - class GRPCDiscovery(Discovery): def __init__(self, node_id: str, node_port: int, listen_port: int, broadcast_port: int = None, broadcast_interval: int = 1, device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES): self.node_id = node_id @@ -35,6 +37,42 @@ def __init__(self, node_id: str, node_port: int, listen_port: int, broadcast_por self.listen_task = None self.cleanup_task = None + def get_network_interfaces(self): + if psutil.MACOS: + # Use system_profiler to get detailed network information + network_info = subprocess.check_output(['system_profiler', 'SPNetworkDataType', '-json']).decode('utf-8') + network_data = json.loads(network_info) + + thunderbolt_interfaces = [] + wifi_interfaces = [] + other_interfaces = [] + + for interface in network_data.get('SPNetworkDataType', []): + interface_name = interface.get('interface') + interface_type = interface.get('type', '').lower() + hardware = interface.get('hardware', '').lower() + + if not interface_name: + continue + + # Check for Thunderbolt-related keywords + if 'thunderbolt' in hardware or 'thunderbolt' in interface_type: + thunderbolt_interfaces.append(interface_name) + elif interface_type == 'wi-fi': + wifi_interfaces.append(interface_name) + else: + other_interfaces.append(interface_name) + + if self.DEBUG >= 2: + print(f"Thunderbolt interfaces: {thunderbolt_interfaces}") + print(f"WiFi interfaces: {wifi_interfaces}") + print(f"Other interfaces: {other_interfaces}") + + # Prioritize Thunderbolt, then WiFi, then others + return thunderbolt_interfaces + wifi_interfaces + other_interfaces + else: + return netifaces.interfaces() + async def start(self): self.device_capabilities = device_capabilities() self.broadcast_task = asyncio.create_task(self.task_broadcast_presence()) @@ -76,13 +114,6 @@ async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]: return [peer_handle for peer_handle, _, _ in self.known_peers.values()] async def task_broadcast_presence(self): - transport, _ = await asyncio.get_event_loop().create_datagram_endpoint( - lambda: asyncio.DatagramProtocol(), - local_addr=('0.0.0.0', 0), - family=socket.AF_INET) - sock = transport.get_extra_info('socket') - sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1) - message = json.dumps({ "type": "discovery", "node_id": self.node_id, @@ -92,13 +123,26 @@ async def task_broadcast_presence(self): while True: try: - if DEBUG_DISCOVERY >= 3: print(f"Broadcast presence: {message}") - transport.sendto(message, ('', self.broadcast_port)) - await asyncio.sleep(self.broadcast_interval) + # Update interfaces periodically + interfaces = self.get_network_interfaces() + + for interface in interfaces: + try: + if DEBUG_DISCOVERY >= 3: print(f"Broadcasting on interface: {interface}") + transport, _ = await asyncio.get_event_loop().create_datagram_endpoint( + lambda: asyncio.DatagramProtocol(), + local_addr=(netifaces.ifaddresses(interface)[netifaces.AF_INET][0]['addr'], 0), + family=socket.AF_INET) + sock = transport.get_extra_info('socket') + sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1) + transport.sendto(message, ('', self.broadcast_port)) + transport.close() + except Exception as e: + if DEBUG_DISCOVERY >= 2: print(f"Error broadcasting on interface {interface}: {e}") except Exception as e: - print(f"Error in broadcast presence: {e}") - import traceback - print(traceback.format_exc()) + if DEBUG_DISCOVERY >= 2: print(f"Error updating or accessing interfaces: {e}") + + await asyncio.sleep(self.broadcast_interval) async def on_listen_message(self, data, addr): message = json.loads(data.decode('utf-8')) @@ -114,8 +158,15 @@ async def on_listen_message(self, data, addr): self.known_peers[peer_id] = (self.known_peers[peer_id][0], self.known_peers[peer_id][1], time.time()) async def task_listen_for_peers(self): - await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message), local_addr=('0.0.0.0', self.listen_port)) - if DEBUG_DISCOVERY >= 2: print("Started listen task") + try: + if DEBUG_DISCOVERY >= 2: print("Starting to listen on all interfaces") + await asyncio.get_event_loop().create_datagram_endpoint( + lambda: ListenProtocol(self.on_listen_message), + local_addr=('0.0.0.0', self.listen_port) + ) + if DEBUG_DISCOVERY >= 2: print("Started listen task on all interfaces") + except Exception as e: + if DEBUG_DISCOVERY >= 2: print(f"Error setting up listening: {e}") async def task_cleanup_peers(self): while True: diff --git a/setup.py b/setup.py index 4b1db878c..ac653a717 100644 --- a/setup.py +++ b/setup.py @@ -22,6 +22,7 @@ "tqdm==4.66.4", "transformers==4.41.2", "uuid==1.30", + "netifaces==0.11.0", "tinygrad @ git+https://github.com/tinygrad/tinygrad.git@a9f5a764dc640a5e5cbaaeeee21df7c8ca37da38", ]