Skip to content

Commit

Permalink
prioritise thunderbolt over WIFI
Browse files Browse the repository at this point in the history
add netifaces dependency

use psutil instead of platform

use psutil instead of platform

fix typo

prioritise thunderbolt over WIFI

optimize network discovery and dynamic interface handling

optimize network discovery and dynamic interface handling
  • Loading branch information
itsknk committed Jul 20, 2024
1 parent 4d962ff commit 95f22ed
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 16 deletions.
55 changes: 39 additions & 16 deletions exo/networking/grpc/grpc_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import json
import socket
import time
import netifaces
import psutil
from typing import List, Dict, Callable, Tuple, Coroutine
from ..discovery import Discovery
from ..peer_handle import PeerHandle
Expand All @@ -21,7 +23,6 @@ def connection_made(self, transport):
def datagram_received(self, data, addr):
    """Schedule ``self.on_message(data, addr)`` as a task without awaiting it."""
    pending = self.on_message(data, addr)
    asyncio.create_task(pending)


class GRPCDiscovery(Discovery):
def __init__(self, node_id: str, node_port: int, listen_port: int, broadcast_port: int = None, broadcast_interval: int = 1, device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES):
self.node_id = node_id
Expand All @@ -35,6 +36,15 @@ def __init__(self, node_id: str, node_port: int, listen_port: int, broadcast_por
self.listen_task = None
self.cleanup_task = None

def get_network_interfaces(self):
    """Return the interface names to use, preferred ones first.

    On macOS only names starting with 'thunderbolt' or 'en' are kept, and
    the 'thunderbolt' group is placed ahead of the 'en' group. On any other
    platform the list reported by ``netifaces.interfaces()`` is returned
    unchanged.
    """
    all_names = netifaces.interfaces()
    if not psutil.MACOS:
        return all_names

    # NOTE(review): confirm that netifaces really reports Thunderbolt links
    # under a 'thunderbolt*' name on macOS; if they show up as 'en*' or
    # 'bridge*' instead, the first group below is always empty.
    ordered = []
    for prefix in ('thunderbolt', 'en'):  # Thunderbolt first, then 'en*'
        ordered.extend(name for name in all_names if name.startswith(prefix))
    return ordered

async def start(self):
self.device_capabilities = device_capabilities()
self.broadcast_task = asyncio.create_task(self.task_broadcast_presence())
Expand Down Expand Up @@ -76,13 +86,6 @@ async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
return [peer_handle for peer_handle, _, _ in self.known_peers.values()]

async def task_broadcast_presence(self):
transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(
lambda: asyncio.DatagramProtocol(),
local_addr=('0.0.0.0', 0),
family=socket.AF_INET)
sock = transport.get_extra_info('socket')
sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)

message = json.dumps({
"type": "discovery",
"node_id": self.node_id,
Expand All @@ -92,13 +95,26 @@ async def task_broadcast_presence(self):

while True:
try:
if DEBUG_DISCOVERY >= 3: print(f"Broadcast presence: {message}")
transport.sendto(message, ('<broadcast>', self.broadcast_port))
await asyncio.sleep(self.broadcast_interval)
# Update interfaces periodically
interfaces = self.get_network_interfaces()

for interface in interfaces:
try:
if DEBUG_DISCOVERY >= 3: print(f"Broadcasting on interface: {interface}")
transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(
lambda: asyncio.DatagramProtocol(),
local_addr=(netifaces.ifaddresses(interface)[netifaces.AF_INET][0]['addr'], 0),
family=socket.AF_INET)
sock = transport.get_extra_info('socket')
sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
transport.sendto(message, ('<broadcast>', self.broadcast_port))
transport.close()
except Exception as e:
if DEBUG_DISCOVERY >= 2: print(f"Error broadcasting on interface {interface}: {e}")
except Exception as e:
print(f"Error in broadcast presence: {e}")
import traceback
print(traceback.format_exc())
if DEBUG_DISCOVERY >= 2: print(f"Error updating or accessing interfaces: {e}")

await asyncio.sleep(self.broadcast_interval)

async def on_listen_message(self, data, addr):
message = json.loads(data.decode('utf-8'))
Expand All @@ -114,8 +130,15 @@ async def on_listen_message(self, data, addr):
self.known_peers[peer_id] = (self.known_peers[peer_id][0], self.known_peers[peer_id][1], time.time())

async def task_listen_for_peers(self):
# Open a datagram endpoint on ('0.0.0.0', self.listen_port) using a
# ListenProtocol constructed with self.on_listen_message.
# NOTE(review): the two statements directly below and the try block after
# them request the same bind. This looks like the pre-change and post-change
# bodies of a diff shown together -- confirm which one the real file keeps.
await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message), local_addr=('0.0.0.0', self.listen_port))
if DEBUG_DISCOVERY >= 2: print("Started listen task")
try:
if DEBUG_DISCOVERY >= 2: print("Starting to listen on all interfaces")
await asyncio.get_event_loop().create_datagram_endpoint(
lambda: ListenProtocol(self.on_listen_message),
local_addr=('0.0.0.0', self.listen_port)
)
if DEBUG_DISCOVERY >= 2: print("Started listen task on all interfaces")
# Any failure while setting up the endpoint is caught here and not re-raised;
# it is printed only when DEBUG_DISCOVERY >= 2.
except Exception as e:
if DEBUG_DISCOVERY >= 2: print(f"Error setting up listening: {e}")

async def task_cleanup_peers(self):
while True:
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"tqdm==4.66.4",
"transformers==4.41.2",
"uuid==1.30",
"netifaces==0.11.0",
"tinygrad @ git+https://github.com/tinygrad/tinygrad.git@a9f5a764dc640a5e5cbaaeeee21df7c8ca37da38",
]

Expand Down

0 comments on commit 95f22ed

Please sign in to comment.