Skip to content

Commit

Permalink
prioritise thunderbolt over WIFI
Browse files Browse the repository at this point in the history
add netifaces dependency

use psutil instead of platform

use psutil instead of platform

fix typo

prioritise thunderbolt over WIFI

optimize network discovery and dynamic interface handling

optimize network discovery and dynamic interface handling

implement network interface detection and prioritization

implement network interface detection and prioritization
  • Loading branch information
itsknk committed Jul 21, 2024
1 parent 4d962ff commit c31e124
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 16 deletions.
83 changes: 67 additions & 16 deletions exo/networking/grpc/grpc_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
import json
import socket
import time
import netifaces
import psutil
import subprocess
from typing import List, Dict, Callable, Tuple, Coroutine
from ..discovery import Discovery
from ..peer_handle import PeerHandle
Expand All @@ -21,7 +24,6 @@ def connection_made(self, transport):
def datagram_received(self, data, addr):
asyncio.create_task(self.on_message(data, addr))


class GRPCDiscovery(Discovery):
def __init__(self, node_id: str, node_port: int, listen_port: int, broadcast_port: int = None, broadcast_interval: int = 1, device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES):
self.node_id = node_id
Expand All @@ -35,6 +37,42 @@ def __init__(self, node_id: str, node_port: int, listen_port: int, broadcast_por
self.listen_task = None
self.cleanup_task = None

def get_network_interfaces(self):
if psutil.MACOS:
# Use system_profiler to get detailed network information
network_info = subprocess.check_output(['system_profiler', 'SPNetworkDataType', '-json']).decode('utf-8')
network_data = json.loads(network_info)

thunderbolt_interfaces = []
wifi_interfaces = []
other_interfaces = []

for interface in network_data.get('SPNetworkDataType', []):
interface_name = interface.get('interface')
interface_type = interface.get('type', '').lower()
hardware = interface.get('hardware', '').lower()

if not interface_name:
continue

# Check for Thunderbolt-related keywords
if 'thunderbolt' in hardware or 'thunderbolt' in interface_type:
thunderbolt_interfaces.append(interface_name)
elif interface_type == 'wi-fi':
wifi_interfaces.append(interface_name)
else:
other_interfaces.append(interface_name)

if self.DEBUG >= 2:
print(f"Thunderbolt interfaces: {thunderbolt_interfaces}")
print(f"WiFi interfaces: {wifi_interfaces}")
print(f"Other interfaces: {other_interfaces}")

# Prioritize Thunderbolt, then WiFi, then others
return thunderbolt_interfaces + wifi_interfaces + other_interfaces
else:
return netifaces.interfaces()

async def start(self):
self.device_capabilities = device_capabilities()
self.broadcast_task = asyncio.create_task(self.task_broadcast_presence())
Expand Down Expand Up @@ -76,13 +114,6 @@ async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
return [peer_handle for peer_handle, _, _ in self.known_peers.values()]

async def task_broadcast_presence(self):
transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(
lambda: asyncio.DatagramProtocol(),
local_addr=('0.0.0.0', 0),
family=socket.AF_INET)
sock = transport.get_extra_info('socket')
sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)

message = json.dumps({
"type": "discovery",
"node_id": self.node_id,
Expand All @@ -92,13 +123,26 @@ async def task_broadcast_presence(self):

while True:
try:
if DEBUG_DISCOVERY >= 3: print(f"Broadcast presence: {message}")
transport.sendto(message, ('<broadcast>', self.broadcast_port))
await asyncio.sleep(self.broadcast_interval)
# Update interfaces periodically
interfaces = self.get_network_interfaces()

for interface in interfaces:
try:
if DEBUG_DISCOVERY >= 3: print(f"Broadcasting on interface: {interface}")
transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(
lambda: asyncio.DatagramProtocol(),
local_addr=(netifaces.ifaddresses(interface)[netifaces.AF_INET][0]['addr'], 0),
family=socket.AF_INET)
sock = transport.get_extra_info('socket')
sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
transport.sendto(message, ('<broadcast>', self.broadcast_port))
transport.close()
except Exception as e:
if DEBUG_DISCOVERY >= 2: print(f"Error broadcasting on interface {interface}: {e}")
except Exception as e:
print(f"Error in broadcast presence: {e}")
import traceback
print(traceback.format_exc())
if DEBUG_DISCOVERY >= 2: print(f"Error updating or accessing interfaces: {e}")

await asyncio.sleep(self.broadcast_interval)

async def on_listen_message(self, data, addr):
message = json.loads(data.decode('utf-8'))
Expand All @@ -114,8 +158,15 @@ async def on_listen_message(self, data, addr):
self.known_peers[peer_id] = (self.known_peers[peer_id][0], self.known_peers[peer_id][1], time.time())

async def task_listen_for_peers(self):
await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message), local_addr=('0.0.0.0', self.listen_port))
if DEBUG_DISCOVERY >= 2: print("Started listen task")
try:
if DEBUG_DISCOVERY >= 2: print("Starting to listen on all interfaces")
await asyncio.get_event_loop().create_datagram_endpoint(
lambda: ListenProtocol(self.on_listen_message),
local_addr=('0.0.0.0', self.listen_port)
)
if DEBUG_DISCOVERY >= 2: print("Started listen task on all interfaces")
except Exception as e:
if DEBUG_DISCOVERY >= 2: print(f"Error setting up listening: {e}")

async def task_cleanup_peers(self):
while True:
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"tqdm==4.66.4",
"transformers==4.41.2",
"uuid==1.30",
"netifaces==0.11.0",
"tinygrad @ git+https://github.com/tinygrad/tinygrad.git@a9f5a764dc640a5e5cbaaeeee21df7c8ca37da38",
]

Expand Down

0 comments on commit c31e124

Please sign in to comment.