Skip to content

Commit

Permalink
prioritise thunderbolt over WIFI
Browse files Browse the repository at this point in the history
add netifaces dependency

use psutil instead of platform

use psutil instead of platform

fix typo

prioritise thunderbolt over WIFI

optimize network discovery and dynamic interface handling

optimize network discovery and dynamic interface handling

implement network interface detection and prioritization

implement network interface detection and prioritization

prioritize interfaces from device preference

prioritize interfaces from device preference
  • Loading branch information
itsknk committed Jul 21, 2024
1 parent 4d962ff commit c17cf8d
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 31 deletions.
122 changes: 91 additions & 31 deletions exo/networking/grpc/grpc_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
import json
import socket
import time
import netifaces
import psutil
import subprocess
from typing import List, Dict, Callable, Tuple, Coroutine
from ..discovery import Discovery
from ..peer_handle import PeerHandle
Expand All @@ -21,7 +24,6 @@ def connection_made(self, transport):
def datagram_received(self, data, addr):
asyncio.create_task(self.on_message(data, addr))


class GRPCDiscovery(Discovery):
def __init__(self, node_id: str, node_port: int, listen_port: int, broadcast_port: int = None, broadcast_interval: int = 1, device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES):
self.node_id = node_id
Expand All @@ -30,11 +32,42 @@ def __init__(self, node_id: str, node_port: int, listen_port: int, broadcast_por
self.listen_port = listen_port
self.broadcast_port = broadcast_port if broadcast_port is not None else listen_port
self.broadcast_interval = broadcast_interval
self.known_peers: Dict[str, Tuple[GRPCPeerHandle, float, float]] = {}
self.known_peers: Dict[str, Tuple[GRPCPeerHandle, float, float, int]] = {} # Added priority
self.broadcast_task = None
self.listen_task = None
self.cleanup_task = None

def get_network_interfaces(self):
if psutil.MACOS:
network_info = subprocess.check_output(['system_profiler', 'SPNetworkDataType', '-json']).decode('utf-8')
network_data = json.loads(network_info)

interfaces = []
for interface in network_data.get('SPNetworkDataType', []):
interface_name = interface.get('interface')
interface_type = interface.get('type', '').lower()
hardware = interface.get('hardware', '').lower()

if not interface_name:
continue

priority = self.get_interface_priority(interface_name, interface_type, hardware)
interfaces.append((interface_name, priority))

return sorted(interfaces, key=lambda x: x[1], reverse=True)
else:
return [(iface, self.get_interface_priority(iface)) for iface in netifaces.interfaces()]

def get_interface_priority(self, interface_name, interface_type='', hardware=''):
if 'thunderbolt' in hardware or 'thunderbolt' in interface_type:
return 3
elif interface_type == 'wi-fi' or 'airport' in hardware:
return 2
elif interface_name.startswith('en'):
return 1
else:
return 0

async def start(self):
self.device_capabilities = device_capabilities()
self.broadcast_task = asyncio.create_task(self.task_broadcast_presence())
Expand Down Expand Up @@ -73,32 +106,38 @@ async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
if DEBUG_DISCOVERY >= 2: print("No new peers discovered in the last grace period. Ending discovery process.")
break # No new peers found in the grace period, we are done

return [peer_handle for peer_handle, _, _ in self.known_peers.values()]
return [peer_handle for peer_handle, _, _, _ in self.known_peers.values()]

async def task_broadcast_presence(self):
transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(
lambda: asyncio.DatagramProtocol(),
local_addr=('0.0.0.0', 0),
family=socket.AF_INET)
sock = transport.get_extra_info('socket')
sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)

message = json.dumps({
"type": "discovery",
"node_id": self.node_id,
"grpc_port": self.node_port,
"device_capabilities": self.device_capabilities.to_dict()
}).encode('utf-8')

while True:
try:
if DEBUG_DISCOVERY >= 3: print(f"Broadcast presence: {message}")
transport.sendto(message, ('<broadcast>', self.broadcast_port))
await asyncio.sleep(self.broadcast_interval)
interfaces = self.get_network_interfaces()

for interface, priority in interfaces:
try:
if DEBUG_DISCOVERY >= 3: print(f"Broadcasting on interface: {interface} with priority: {priority}")
message = json.dumps({
"type": "discovery",
"node_id": self.node_id,
"grpc_port": self.node_port,
"device_capabilities": self.device_capabilities.to_dict(),
"priority": priority
}).encode('utf-8')

transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(
lambda: asyncio.DatagramProtocol(),
local_addr=(netifaces.ifaddresses(interface)[netifaces.AF_INET][0]['addr'], 0),
family=socket.AF_INET)
sock = transport.get_extra_info('socket')
sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
transport.sendto(message, ('<broadcast>', self.broadcast_port))
transport.close()
except Exception as e:
if DEBUG_DISCOVERY >= 2: print(f"Error broadcasting on interface {interface}: {e}")
except Exception as e:
print(f"Error in broadcast presence: {e}")
import traceback
print(traceback.format_exc())
if DEBUG_DISCOVERY >= 2: print(f"Error updating or accessing interfaces: {e}")

await asyncio.sleep(self.broadcast_interval)

async def on_listen_message(self, data, addr):
message = json.loads(data.decode('utf-8'))
Expand All @@ -107,26 +146,47 @@ async def on_listen_message(self, data, addr):
peer_id = message['node_id']
peer_host = addr[0]
peer_port = message['grpc_port']
new_priority = message['priority']
device_capabilities = DeviceCapabilities(**message['device_capabilities'])
if peer_id not in self.known_peers:
self.known_peers[peer_id] = (GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}", device_capabilities), time.time(), time.time())
if DEBUG_DISCOVERY >= 2: print(f"Discovered new peer {peer_id} at {peer_host}:{peer_port}")
self.known_peers[peer_id] = (self.known_peers[peer_id][0], self.known_peers[peer_id][1], time.time())

if peer_id in self.known_peers:
_, _, _, current_priority = self.known_peers[peer_id]
if new_priority > current_priority:
await self.reconnect_peer(peer_id, peer_host, peer_port, new_priority, device_capabilities)
else:
peer_handle = GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
self.known_peers[peer_id] = (peer_handle, time.time(), time.time(), new_priority)
if DEBUG_DISCOVERY >= 2: print(f"Discovered new peer {peer_id} at {peer_host}:{peer_port} with priority {new_priority}")

async def reconnect_peer(self, peer_id, new_host, new_port, new_priority, device_capabilities):
old_handle, connected_at, _, _ = self.known_peers[peer_id]
await old_handle.disconnect()
new_handle = GRPCPeerHandle(peer_id, f"{new_host}:{new_port}", device_capabilities)
await new_handle.connect()
self.known_peers[peer_id] = (new_handle, connected_at, time.time(), new_priority)
if DEBUG_DISCOVERY >= 2: print(f"Reconnected to peer {peer_id} at {new_host}:{new_port} with new priority {new_priority}")

async def task_listen_for_peers(self):
await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message), local_addr=('0.0.0.0', self.listen_port))
if DEBUG_DISCOVERY >= 2: print("Started listen task")
try:
if DEBUG_DISCOVERY >= 2: print("Starting to listen on all interfaces")
await asyncio.get_event_loop().create_datagram_endpoint(
lambda: ListenProtocol(self.on_listen_message),
local_addr=('0.0.0.0', self.listen_port)
)
if DEBUG_DISCOVERY >= 2: print("Started listen task on all interfaces")
except Exception as e:
if DEBUG_DISCOVERY >= 2: print(f"Error setting up listening: {e}")

async def task_cleanup_peers(self):
while True:
try:
current_time = time.time()
timeout = 15 * self.broadcast_interval
peers_to_remove = [
peer_handle.id() for peer_handle, connected_at, last_seen in self.known_peers.values() if
peer_handle.id() for peer_handle, connected_at, last_seen, _ in self.known_peers.values() if
(not await peer_handle.is_connected() and current_time - connected_at > timeout) or current_time - last_seen > timeout
]
if DEBUG_DISCOVERY >= 2: print("Peer statuses:", {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, {connected_at=}, {last_seen=}" for peer_handle, connected_at, last_seen in self.known_peers.values()})
if DEBUG_DISCOVERY >= 2: print("Peer statuses:", {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, {connected_at=}, {last_seen=}, priority={priority}" for peer_handle, connected_at, last_seen, priority in self.known_peers.values()})
if DEBUG_DISCOVERY >= 2 and len(peers_to_remove) > 0: print(f"Cleaning up peers: {peers_to_remove}")
for peer_id in peers_to_remove:
if peer_id in self.known_peers: del self.known_peers[peer_id]
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"tqdm==4.66.4",
"transformers==4.41.2",
"uuid==1.30",
"netifaces==0.11.0",
"tinygrad @ git+https://github.com/tinygrad/tinygrad.git@a9f5a764dc640a5e5cbaaeeee21df7c8ca37da38",
]

Expand Down

0 comments on commit c17cf8d

Please sign in to comment.