diff --git a/tools/url_safety.py b/tools/url_safety.py index ae610d0f78..6746ba964c 100644 --- a/tools/url_safety.py +++ b/tools/url_safety.py @@ -15,6 +15,7 @@ SDKs (Firecrawl/Tavily) where redirect handling is on their servers. """ +import asyncio import ipaddress import logging import socket @@ -47,50 +48,89 @@ def _is_blocked_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool: return False +def _check_hostname(hostname: str, url: str) -> bool: + """Resolve hostname and verify it is not a private/internal address. + + This is the synchronous core used by both is_safe_url() and + async_is_safe_url(). Callers in async contexts should use + async_is_safe_url() to avoid blocking the event loop during DNS lookup. + """ + # Block known internal hostnames + if hostname in _BLOCKED_HOSTNAMES: + logger.warning("Blocked request to internal hostname: %s", hostname) + return False + + # Try to resolve and check IP + try: + addr_info = socket.getaddrinfo(hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM) + except socket.gaierror: + # DNS resolution failed — fail closed. If DNS can't resolve it, + # the HTTP client will also fail, so blocking loses nothing. + logger.warning("Blocked request — DNS resolution failed for: %s", hostname) + return False + + for family, _, _, _, sockaddr in addr_info: + ip_str = sockaddr[0] + try: + ip = ipaddress.ip_address(ip_str) + except ValueError: + continue + + if _is_blocked_ip(ip): + logger.warning( + "Blocked request to private/internal address: %s -> %s", + hostname, ip_str, + ) + return False + + return True + + def is_safe_url(url: str) -> bool: """Return True if the URL target is not a private/internal address. Resolves the hostname to an IP and checks against private ranges. Fails closed: DNS errors and unexpected exceptions block the request. + + WARNING: This function calls socket.getaddrinfo() which is synchronous + and will block the calling thread during DNS resolution. In async + contexts (e.g. inside async def functions), use async_is_safe_url() + instead to avoid blocking the event loop. """ try: parsed = urlparse(url) hostname = (parsed.hostname or "").strip().lower() if not hostname: return False + return _check_hostname(hostname, url) + except Exception as exc: + logger.warning("Blocked request — URL safety check error for %s: %s", url, exc) + return False - # Block known internal hostnames - if hostname in _BLOCKED_HOSTNAMES: - logger.warning("Blocked request to internal hostname: %s", hostname) - return False - # Try to resolve and check IP - try: - addr_info = socket.getaddrinfo(hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM) - except socket.gaierror: - # DNS resolution failed — fail closed. If DNS can't resolve it, - # the HTTP client will also fail, so blocking loses nothing. - logger.warning("Blocked request — DNS resolution failed for: %s", hostname) - return False +async def async_is_safe_url(url: str) -> bool: + """Async-safe version of is_safe_url() for use in async contexts. - for family, _, _, _, sockaddr in addr_info: - ip_str = sockaddr[0] - try: - ip = ipaddress.ip_address(ip_str) - except ValueError: - continue + Runs the blocking socket.getaddrinfo() DNS lookup in a thread pool + executor so the event loop is not blocked during DNS resolution. + This is important in the gateway where a slow DNS response would + otherwise freeze all message handling for all users. + """ + try: + parsed = urlparse(url) + hostname = (parsed.hostname or "").strip().lower() + if not hostname: + return False - if _is_blocked_ip(ip): - logger.warning( - "Blocked request to private/internal address: %s -> %s", - hostname, ip_str, - ) - return False + # Block known internal hostnames without DNS (no I/O needed) + if hostname in _BLOCKED_HOSTNAMES: + logger.warning("Blocked request to internal hostname: %s", hostname) + return False - return True + # Run blocking DNS resolution in thread pool to avoid blocking event loop + loop = asyncio.get_event_loop() + return await loop.run_in_executor(None, _check_hostname, hostname, url) except Exception as exc: - # Fail closed on unexpected errors — don't let parsing edge cases - # become SSRF bypass vectors logger.warning("Blocked request — URL safety check error for %s: %s", url, exc) return False