Skip to content

Commit

Permalink
prioritise thunderbolt over WIFI
Browse files Browse the repository at this point in the history
add netifaces dependency

use psutil instead of platform

use psutil instead of platform

fix typo

prioritise thunderbolt over WIFI

optimize network discovery and dynamic interface handling

optimize network discovery and dynamic interface handling

implement network interface detection and prioritization

implement network interface detection and prioritization

prioritize interfaces from device preference

prioritize interfaces from device preference

implemented assertions

implemented assertions

Implement pure functions outside the class. Write better test

prioritize interfaces from device preference
  • Loading branch information
itsknk committed Jul 22, 2024
1 parent 4d962ff commit 3355378
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 31 deletions.
122 changes: 91 additions & 31 deletions exo/networking/grpc/grpc_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,47 @@
import json
import socket
import time
import netifaces
import psutil
import subprocess
from typing import List, Dict, Callable, Tuple, Coroutine
from ..discovery import Discovery
from ..peer_handle import PeerHandle
from .grpc_peer_handle import GRPCPeerHandle
from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
from exo import DEBUG_DISCOVERY

def get_interface_priority(interface_name, interface_type='', hardware=''):
if 'thunderbolt' in hardware or 'thunderbolt' in interface_type:
return 3
elif interface_type == 'wi-fi' or 'airport' in hardware:
return 2
elif interface_name.startswith('en'):
return 1
else:
return 0

def get_network_interfaces():
if psutil.MACOS:
network_info = subprocess.check_output(['system_profiler', 'SPNetworkDataType', '-json']).decode('utf-8')
network_data = json.loads(network_info)

interfaces = []
for interface in network_data.get('SPNetworkDataType', []):
interface_name = interface.get('interface')
interface_type = interface.get('type', '').lower()
hardware = interface.get('hardware', '').lower()

if not interface_name:
continue

priority = get_interface_priority(interface_name, interface_type, hardware)
interfaces.append((interface_name, priority))

return sorted(interfaces, key=lambda x: x[1], reverse=True)
else:
return [(iface, get_interface_priority(iface)) for iface in netifaces.interfaces()]

class ListenProtocol(asyncio.DatagramProtocol):
def __init__(self, on_message: Callable[[bytes, Tuple[str, int]], Coroutine]):
super().__init__()
Expand All @@ -21,7 +55,6 @@ def connection_made(self, transport):
def datagram_received(self, data, addr):
asyncio.create_task(self.on_message(data, addr))


class GRPCDiscovery(Discovery):
def __init__(self, node_id: str, node_port: int, listen_port: int, broadcast_port: int = None, broadcast_interval: int = 1, device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES):
self.node_id = node_id
Expand All @@ -30,7 +63,7 @@ def __init__(self, node_id: str, node_port: int, listen_port: int, broadcast_por
self.listen_port = listen_port
self.broadcast_port = broadcast_port if broadcast_port is not None else listen_port
self.broadcast_interval = broadcast_interval
self.known_peers: Dict[str, Tuple[GRPCPeerHandle, float, float]] = {}
self.known_peers: Dict[str, Tuple[GRPCPeerHandle, float, float, int]] = {}
self.broadcast_task = None
self.listen_task = None
self.cleanup_task = None
Expand Down Expand Up @@ -73,32 +106,38 @@ async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
if DEBUG_DISCOVERY >= 2: print("No new peers discovered in the last grace period. Ending discovery process.")
break # No new peers found in the grace period, we are done

return [peer_handle for peer_handle, _, _ in self.known_peers.values()]
return [peer_handle for peer_handle, _, _, _ in self.known_peers.values()]

async def task_broadcast_presence(self):
transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(
lambda: asyncio.DatagramProtocol(),
local_addr=('0.0.0.0', 0),
family=socket.AF_INET)
sock = transport.get_extra_info('socket')
sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)

message = json.dumps({
"type": "discovery",
"node_id": self.node_id,
"grpc_port": self.node_port,
"device_capabilities": self.device_capabilities.to_dict()
}).encode('utf-8')

while True:
try:
if DEBUG_DISCOVERY >= 3: print(f"Broadcast presence: {message}")
transport.sendto(message, ('<broadcast>', self.broadcast_port))
await asyncio.sleep(self.broadcast_interval)
interfaces = get_network_interfaces()

for interface, priority in interfaces:
try:
if DEBUG_DISCOVERY >= 3: print(f"Broadcasting on interface: {interface} with priority: {priority}")
message = json.dumps({
"type": "discovery",
"node_id": self.node_id,
"grpc_port": self.node_port,
"device_capabilities": self.device_capabilities.to_dict(),
"priority": priority
}).encode('utf-8')

transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(
lambda: asyncio.DatagramProtocol(),
local_addr=(netifaces.ifaddresses(interface)[netifaces.AF_INET][0]['addr'], 0),
family=socket.AF_INET)
sock = transport.get_extra_info('socket')
sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
transport.sendto(message, ('<broadcast>', self.broadcast_port))
transport.close()
except Exception as e:
if DEBUG_DISCOVERY >= 2: print(f"Error broadcasting on interface {interface}: {e}")
except Exception as e:
print(f"Error in broadcast presence: {e}")
import traceback
print(traceback.format_exc())
if DEBUG_DISCOVERY >= 2: print(f"Error updating or accessing interfaces: {e}")

await asyncio.sleep(self.broadcast_interval)

async def on_listen_message(self, data, addr):
message = json.loads(data.decode('utf-8'))
Expand All @@ -107,26 +146,47 @@ async def on_listen_message(self, data, addr):
peer_id = message['node_id']
peer_host = addr[0]
peer_port = message['grpc_port']
new_priority = message['priority']
device_capabilities = DeviceCapabilities(**message['device_capabilities'])
if peer_id not in self.known_peers:
self.known_peers[peer_id] = (GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}", device_capabilities), time.time(), time.time())
if DEBUG_DISCOVERY >= 2: print(f"Discovered new peer {peer_id} at {peer_host}:{peer_port}")
self.known_peers[peer_id] = (self.known_peers[peer_id][0], self.known_peers[peer_id][1], time.time())

if peer_id in self.known_peers:
_, _, _, current_priority = self.known_peers[peer_id]
if new_priority > current_priority:
await self.reconnect_peer(peer_id, peer_host, peer_port, new_priority, device_capabilities)
else:
peer_handle = GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
self.known_peers[peer_id] = (peer_handle, time.time(), time.time(), new_priority)
if DEBUG_DISCOVERY >= 2: print(f"Discovered new peer {peer_id} at {peer_host}:{peer_port} with priority {new_priority}")

async def reconnect_peer(self, peer_id, new_host, new_port, new_priority, device_capabilities):
old_handle, connected_at, _, _ = self.known_peers[peer_id]
await old_handle.disconnect()
new_handle = GRPCPeerHandle(peer_id, f"{new_host}:{new_port}", device_capabilities)
await new_handle.connect()
self.known_peers[peer_id] = (new_handle, connected_at, time.time(), new_priority)
if DEBUG_DISCOVERY >= 2: print(f"Reconnected to peer {peer_id} at {new_host}:{new_port} with new priority {new_priority}")

async def task_listen_for_peers(self):
await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message), local_addr=('0.0.0.0', self.listen_port))
if DEBUG_DISCOVERY >= 2: print("Started listen task")
try:
if DEBUG_DISCOVERY >= 2: print("Starting to listen on all interfaces")
await asyncio.get_event_loop().create_datagram_endpoint(
lambda: ListenProtocol(self.on_listen_message),
local_addr=('0.0.0.0', self.listen_port)
)
if DEBUG_DISCOVERY >= 2: print("Started listen task on all interfaces")
except Exception as e:
if DEBUG_DISCOVERY >= 2: print(f"Error setting up listening: {e}")

async def task_cleanup_peers(self):
while True:
try:
current_time = time.time()
timeout = 15 * self.broadcast_interval
peers_to_remove = [
peer_handle.id() for peer_handle, connected_at, last_seen in self.known_peers.values() if
peer_handle.id() for peer_handle, connected_at, last_seen, _ in self.known_peers.values() if
(not await peer_handle.is_connected() and current_time - connected_at > timeout) or current_time - last_seen > timeout
]
if DEBUG_DISCOVERY >= 2: print("Peer statuses:", {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, {connected_at=}, {last_seen=}" for peer_handle, connected_at, last_seen in self.known_peers.values()})
if DEBUG_DISCOVERY >= 2: print("Peer statuses:", {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, {connected_at=}, {last_seen=}, priority={priority}" for peer_handle, connected_at, last_seen, priority in self.known_peers.values()})
if DEBUG_DISCOVERY >= 2 and len(peers_to_remove) > 0: print(f"Cleaning up peers: {peers_to_remove}")
for peer_id in peers_to_remove:
if peer_id in self.known_peers: del self.known_peers[peer_id]
Expand Down
58 changes: 58 additions & 0 deletions exo/networking/grpc/test_network_interfaces.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import os
import sys
import time

current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(current_dir)))
sys.path.insert(0, parent_dir)

from exo.networking.grpc.grpc_discovery import get_network_interfaces, get_interface_priority

def simulate_broadcast_receive(peer_id, interface, priority, known_peers):
print(f"\nReceived broadcast from peer {peer_id} on interface {interface} with priority {priority}")

if peer_id in known_peers:
current_interface, current_priority = known_peers[peer_id]
if priority > current_priority:
print(f"Higher priority detected. Disconnecting from {current_interface} and reconnecting on {interface}")
known_peers[peer_id] = (interface, priority)
else:
print(f"Keeping existing connection on {current_interface} with priority {current_priority}")
else:
print(f"New peer discovered. Connecting on {interface}")
known_peers[peer_id] = (interface, priority)

def run_tests():
print("Testing network interface detection, prioritization, and reconnection logic")
interfaces = get_network_interfaces()

# Test 1: Check if interfaces are detected and sorted
assert len(interfaces) > 0, "No interfaces detected"
assert all(interfaces[i][1] >= interfaces[i+1][1] for i in range(len(interfaces)-1)), "Interfaces are not correctly sorted by priority"

print("\nPrioritized list of interfaces:")
for i, (interface, priority) in enumerate(interfaces, 1):
print(f"{i}. {interface} (Priority: {priority})")

# Test 2: Simulate broadcast receives and reconnection
known_peers = {}

# Simulate receiving broadcasts from the same peer on different interfaces
peer_id = "test_peer_1"
for interface, priority in interfaces:
simulate_broadcast_receive(peer_id, interface, priority, known_peers)
time.sleep(0.1) # Add a small delay

# Test 3: Check if the peer is connected to the highest priority interface
assert peer_id in known_peers, "Peer was not added to known_peers"
final_interface, final_priority = known_peers[peer_id]
assert final_priority == max(priority for _, priority in interfaces), "Peer is not connected to the highest priority interface"

print("\nFinal known peers:")
for peer_id, (interface, priority) in known_peers.items():
print(f"Peer {peer_id}: Connected on {interface} with priority {priority}")

print("\nAll tests passed successfully!")

if __name__ == "__main__":
run_tests()
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"tqdm==4.66.4",
"transformers==4.41.2",
"uuid==1.30",
"netifaces==0.11.0",
"tinygrad @ git+https://github.com/tinygrad/tinygrad.git@a9f5a764dc640a5e5cbaaeeee21df7c8ca37da38",
]

Expand Down

0 comments on commit 3355378

Please sign in to comment.