Skip to content

Commit

Permalink
prioritise thunderbolt over WIFI
Browse files Browse the repository at this point in the history
add netifaces dependency

use psutil instead of platform

use psutil instead of platform

fix typo

prioritise thunderbolt over WIFI
  • Loading branch information
itsknk committed Jul 20, 2024
1 parent 4d962ff commit 0222b2e
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 17 deletions.
52 changes: 35 additions & 17 deletions exo/networking/grpc/grpc_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import json
import socket
import time
import netifaces
import psutil
from typing import List, Dict, Callable, Tuple, Coroutine
from ..discovery import Discovery
from ..peer_handle import PeerHandle
Expand All @@ -21,7 +23,6 @@ def connection_made(self, transport):
def datagram_received(self, data, addr):
asyncio.create_task(self.on_message(data, addr))


class GRPCDiscovery(Discovery):
def __init__(self, node_id: str, node_port: int, listen_port: int, broadcast_port: int = None, broadcast_interval: int = 1, device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES):
self.node_id = node_id
Expand All @@ -34,6 +35,16 @@ def __init__(self, node_id: str, node_port: int, listen_port: int, broadcast_por
self.broadcast_task = None
self.listen_task = None
self.cleanup_task = None
self.interfaces = self.get_network_interfaces()

def get_network_interfaces(self):
interfaces = netifaces.interfaces()
if psutil.MACOS:
thunderbolt_interfaces = [iface for iface in interfaces if iface.startswith('thunderbolt')]
wifi_interfaces = [iface for iface in interfaces if iface.startswith('en')]
return thunderbolt_interfaces + wifi_interfaces # Prioritize Thunderbolt
else:
return interfaces

async def start(self):
self.device_capabilities = device_capabilities()
Expand Down Expand Up @@ -76,13 +87,6 @@ async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
return [peer_handle for peer_handle, _, _ in self.known_peers.values()]

async def task_broadcast_presence(self):
transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(
lambda: asyncio.DatagramProtocol(),
local_addr=('0.0.0.0', 0),
family=socket.AF_INET)
sock = transport.get_extra_info('socket')
sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)

message = json.dumps({
"type": "discovery",
"node_id": self.node_id,
Expand All @@ -91,14 +95,20 @@ async def task_broadcast_presence(self):
}).encode('utf-8')

while True:
try:
if DEBUG_DISCOVERY >= 3: print(f"Broadcast presence: {message}")
transport.sendto(message, ('<broadcast>', self.broadcast_port))
await asyncio.sleep(self.broadcast_interval)
except Exception as e:
print(f"Error in broadcast presence: {e}")
import traceback
print(traceback.format_exc())
for interface in self.interfaces:
try:
if DEBUG_DISCOVERY >= 3: print(f"Broadcasting on interface: {interface}")
transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(
lambda: asyncio.DatagramProtocol(),
local_addr=(netifaces.ifaddresses(interface)[netifaces.AF_INET][0]['addr'], 0),
family=socket.AF_INET)
sock = transport.get_extra_info('socket')
sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
transport.sendto(message, ('<broadcast>', self.broadcast_port))
transport.close()
except Exception as e:
if DEBUG_DISCOVERY >= 2: print(f"Error broadcasting on interface {interface}: {e}")
await asyncio.sleep(self.broadcast_interval)

async def on_listen_message(self, data, addr):
message = json.loads(data.decode('utf-8'))
Expand All @@ -114,7 +124,15 @@ async def on_listen_message(self, data, addr):
self.known_peers[peer_id] = (self.known_peers[peer_id][0], self.known_peers[peer_id][1], time.time())

async def task_listen_for_peers(self):
await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message), local_addr=('0.0.0.0', self.listen_port))
for interface in self.interfaces:
try:
if DEBUG_DISCOVERY >= 2: print(f"Listening on interface: {interface}")
await asyncio.get_event_loop().create_datagram_endpoint(
lambda: ListenProtocol(self.on_listen_message),
local_addr=(netifaces.ifaddresses(interface)[netifaces.AF_INET][0]['addr'], self.listen_port)
)
except Exception as e:
if DEBUG_DISCOVERY >= 2: print(f"Error listening on interface {interface}: {e}")
if DEBUG_DISCOVERY >= 2: print("Started listen task")

async def task_cleanup_peers(self):
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"tqdm==4.66.4",
"transformers==4.41.2",
"uuid==1.30",
"netifaces==0.11.0",
"tinygrad @ git+https://github.com/tinygrad/tinygrad.git@a9f5a764dc640a5e5cbaaeeee21df7c8ca37da38",
]

Expand Down

0 comments on commit 0222b2e

Please sign in to comment.