Skip to content

Commit

Permalink
prioritize interfaces from device preference
Browse files Browse the repository at this point in the history
  • Loading branch information
itsknk committed Jul 21, 2024
1 parent c31e124 commit 9d1ef8a
Showing 1 changed file with 47 additions and 38 deletions.
85 changes: 47 additions & 38 deletions exo/networking/grpc/grpc_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,17 @@ def __init__(self, node_id: str, node_port: int, listen_port: int, broadcast_por
self.listen_port = listen_port
self.broadcast_port = broadcast_port if broadcast_port is not None else listen_port
self.broadcast_interval = broadcast_interval
self.known_peers: Dict[str, Tuple[GRPCPeerHandle, float, float]] = {}
self.known_peers: Dict[str, Tuple[GRPCPeerHandle, float, float, int]] = {} # Added priority
self.broadcast_task = None
self.listen_task = None
self.cleanup_task = None

def get_network_interfaces(self):
if psutil.MACOS:
# Use system_profiler to get detailed network information
network_info = subprocess.check_output(['system_profiler', 'SPNetworkDataType', '-json']).decode('utf-8')
network_data = json.loads(network_info)

thunderbolt_interfaces = []
wifi_interfaces = []
other_interfaces = []

interfaces = []
for interface in network_data.get('SPNetworkDataType', []):
interface_name = interface.get('interface')
interface_type = interface.get('type', '').lower()
Expand All @@ -55,23 +51,22 @@ def get_network_interfaces(self):
if not interface_name:
continue

# Check for Thunderbolt-related keywords
if 'thunderbolt' in hardware or 'thunderbolt' in interface_type:
thunderbolt_interfaces.append(interface_name)
elif interface_type == 'wi-fi':
wifi_interfaces.append(interface_name)
else:
other_interfaces.append(interface_name)

if self.DEBUG >= 2:
print(f"Thunderbolt interfaces: {thunderbolt_interfaces}")
print(f"WiFi interfaces: {wifi_interfaces}")
print(f"Other interfaces: {other_interfaces}")
priority = self.get_interface_priority(interface_name, interface_type, hardware)
interfaces.append((interface_name, priority))

# Prioritize Thunderbolt, then WiFi, then others
return thunderbolt_interfaces + wifi_interfaces + other_interfaces
return sorted(interfaces, key=lambda x: x[1], reverse=True)
else:
return netifaces.interfaces()
return [(iface, self.get_interface_priority(iface)) for iface in netifaces.interfaces()]

def get_interface_priority(self, interface_name, interface_type='', hardware=''):
if 'thunderbolt' in hardware or 'thunderbolt' in interface_type:
return 3
elif interface_type == 'wi-fi' or 'airport' in hardware:
return 2
elif interface_name.startswith('en'):
return 1
else:
return 0

async def start(self):
self.device_capabilities = device_capabilities()
Expand Down Expand Up @@ -111,24 +106,24 @@ async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
if DEBUG_DISCOVERY >= 2: print("No new peers discovered in the last grace period. Ending discovery process.")
break # No new peers found in the grace period, we are done

return [peer_handle for peer_handle, _, _ in self.known_peers.values()]
return [peer_handle for peer_handle, _, _, _ in self.known_peers.values()]

async def task_broadcast_presence(self):
message = json.dumps({
"type": "discovery",
"node_id": self.node_id,
"grpc_port": self.node_port,
"device_capabilities": self.device_capabilities.to_dict()
}).encode('utf-8')

while True:
try:
# Update interfaces periodically
interfaces = self.get_network_interfaces()

for interface in interfaces:
for interface, priority in interfaces:
try:
if DEBUG_DISCOVERY >= 3: print(f"Broadcasting on interface: {interface}")
if DEBUG_DISCOVERY >= 3: print(f"Broadcasting on interface: {interface} with priority: {priority}")
message = json.dumps({
"type": "discovery",
"node_id": self.node_id,
"grpc_port": self.node_port,
"device_capabilities": self.device_capabilities.to_dict(),
"priority": priority
}).encode('utf-8')

transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(
lambda: asyncio.DatagramProtocol(),
local_addr=(netifaces.ifaddresses(interface)[netifaces.AF_INET][0]['addr'], 0),
Expand All @@ -151,11 +146,25 @@ async def on_listen_message(self, data, addr):
peer_id = message['node_id']
peer_host = addr[0]
peer_port = message['grpc_port']
new_priority = message['priority']
device_capabilities = DeviceCapabilities(**message['device_capabilities'])
if peer_id not in self.known_peers:
self.known_peers[peer_id] = (GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}", device_capabilities), time.time(), time.time())
if DEBUG_DISCOVERY >= 2: print(f"Discovered new peer {peer_id} at {peer_host}:{peer_port}")
self.known_peers[peer_id] = (self.known_peers[peer_id][0], self.known_peers[peer_id][1], time.time())

if peer_id in self.known_peers:
_, _, _, current_priority = self.known_peers[peer_id]
if new_priority > current_priority:
await self.reconnect_peer(peer_id, peer_host, peer_port, new_priority, device_capabilities)
else:
peer_handle = GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
self.known_peers[peer_id] = (peer_handle, time.time(), time.time(), new_priority)
if DEBUG_DISCOVERY >= 2: print(f"Discovered new peer {peer_id} at {peer_host}:{peer_port} with priority {new_priority}")

async def reconnect_peer(self, peer_id, new_host, new_port, new_priority, device_capabilities):
old_handle, connected_at, _, _ = self.known_peers[peer_id]
await old_handle.disconnect()
new_handle = GRPCPeerHandle(peer_id, f"{new_host}:{new_port}", device_capabilities)
await new_handle.connect()
self.known_peers[peer_id] = (new_handle, connected_at, time.time(), new_priority)
if DEBUG_DISCOVERY >= 2: print(f"Reconnected to peer {peer_id} at {new_host}:{new_port} with new priority {new_priority}")

async def task_listen_for_peers(self):
try:
Expand All @@ -174,10 +183,10 @@ async def task_cleanup_peers(self):
current_time = time.time()
timeout = 15 * self.broadcast_interval
peers_to_remove = [
peer_handle.id() for peer_handle, connected_at, last_seen in self.known_peers.values() if
peer_handle.id() for peer_handle, connected_at, last_seen, _ in self.known_peers.values() if
(not await peer_handle.is_connected() and current_time - connected_at > timeout) or current_time - last_seen > timeout
]
if DEBUG_DISCOVERY >= 2: print("Peer statuses:", {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, {connected_at=}, {last_seen=}" for peer_handle, connected_at, last_seen in self.known_peers.values()})
if DEBUG_DISCOVERY >= 2: print("Peer statuses:", {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, {connected_at=}, {last_seen=}, priority={priority}" for peer_handle, connected_at, last_seen, priority in self.known_peers.values()})
if DEBUG_DISCOVERY >= 2 and len(peers_to_remove) > 0: print(f"Cleaning up peers: {peers_to_remove}")
for peer_id in peers_to_remove:
if peer_id in self.known_peers: del self.known_peers[peer_id]
Expand Down

0 comments on commit 9d1ef8a

Please sign in to comment.