Skip to content

Commit

Permalink
Avoid creating tasks if the name is already in the cache (#27)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco authored Feb 14, 2025
1 parent 1bc7763 commit 878438d
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 20 deletions.
53 changes: 33 additions & 20 deletions src/aiohttp_asyncmdnsresolver/_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import socket
import sys
from ipaddress import IPv4Address, IPv6Address
from typing import Any
from typing import TYPE_CHECKING, Any, Union

from aiohttp.resolver import AsyncResolver, ResolveResult
from zeroconf import (
Expand All @@ -19,6 +19,8 @@

DEFAULT_TIMEOUT = 5.0

ResolverType = Union[AddressResolver, AddressResolverIPv4, AddressResolverIPv6]

_FAMILY_TO_RESOLVER_CLASS: dict[
socket.AddressFamily,
type[AddressResolver] | type[AddressResolverIPv4] | type[AddressResolverIPv6],
Expand Down Expand Up @@ -69,26 +71,31 @@ def __init__(
self._aiozc_owner = async_zeroconf is None
self._aiozc = async_zeroconf or AsyncZeroconf()

def _make_resolver(self, host: str, family: socket.AddressFamily) -> ResolverType:
"""Create an mDNS resolver."""
resolver_class = _FAMILY_TO_RESOLVER_CLASS[family]
return resolver_class(host if host[-1] == "." else f"{host}.")

def _addresses_from_info_or_raise(
self, info: ResolverType, port: int, family: socket.AddressFamily
) -> list[ResolveResult]:
"""Get addresses from info or raise OSError."""
ip_version = _FAMILY_TO_IP_VERSION[family]
if addresses := info.ip_addresses_by_version(ip_version):
if TYPE_CHECKING:
assert info.server is not None
return [
_to_resolve_result(info.server, port, address) for address in addresses
]
raise OSError(None, "MDNS lookup failed")

async def _resolve_mdns(
self, host: str, port: int, family: socket.AddressFamily
self, info: ResolverType, port: int, family: socket.AddressFamily
) -> list[ResolveResult]:
"""Resolve a host name to an IP address using mDNS."""
resolver_class = _FAMILY_TO_RESOLVER_CLASS[family]
ip_version: IPVersion = _FAMILY_TO_IP_VERSION[family]
if host[-1] != ".":
host += "."
info = resolver_class(host)
if (
info.load_from_cache(self._aiozc.zeroconf)
or (
self._mdns_timeout
and await info.async_request(
self._aiozc.zeroconf, self._mdns_timeout * 1000
)
)
) and (addresses := info.ip_addresses_by_version(ip_version)):
return [_to_resolve_result(host, port, address) for address in addresses]
raise OSError(None, "MDNS lookup failed")
if self._mdns_timeout:
await info.async_request(self._aiozc.zeroconf, self._mdns_timeout * 1000)
return self._addresses_from_info_or_raise(info, port, family)

async def close(self) -> None:
"""Close the resolver."""
Expand All @@ -107,7 +114,10 @@ async def resolve(
"""Resolve a host name to an IP address."""
if not host.endswith(".local") and not host.endswith(".local."):
return await super().resolve(host, port, family)
return await self._resolve_mdns(host, port, family)
info = self._make_resolver(host, family)
if info.load_from_cache(self._aiozc.zeroconf):
return self._addresses_from_info_or_raise(info, port, family)
return await self._resolve_mdns(info, port, family)


class AsyncDualMDNSResolver(_AsyncMDNSResolverBase):
Expand All @@ -128,7 +138,10 @@ async def resolve(
"""Resolve a host name to an IP address."""
if not host.endswith(".local") and not host.endswith(".local."):
return await super().resolve(host, port, family)
resolve_via_mdns = self._resolve_mdns(host, port, family)
info = self._make_resolver(host, family)
if info.load_from_cache(self._aiozc.zeroconf):
return self._addresses_from_info_or_raise(info, port, family)
resolve_via_mdns = self._resolve_mdns(info, port, family)
resolve_via_dns = super().resolve(host, port, family)
loop = asyncio.get_running_loop()
if sys.version_info >= (3, 12):
Expand Down
58 changes: 58 additions & 0 deletions tests/test_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,31 @@ async def _take_a_while_to_resolve(*args: Any, **kwargs: Any) -> bool:
assert result["host"] == "127.0.0.2"


@pytest.mark.asyncio
async def test_async_dual_mdns_resolver_from_cache(
dual_resolver: AsyncMDNSResolver,
) -> None:
"""Test AsyncDualMDNSResolver can resolve from cache."""
with (
patch(
"aiohttp_asyncmdnsresolver._impl.AsyncResolver.resolve",
side_effect=OSError,
),
patch.object(IPv4HostResolver, "load_from_cache", return_value=True),
patch.object(
IPv4HostResolver,
"ip_addresses_by_version",
return_value=[IPv4Address("127.0.0.2")],
),
):
results = await dual_resolver.resolve("localhost.local.")
assert results is not None
assert len(results) == 1
result = results[0]
assert result["hostname"] == "localhost.local."
assert result["host"] == "127.0.0.2"


@pytest.mark.asyncio
async def test_different_results_async_dual_mdns_resolver(
dual_resolver: AsyncMDNSResolver,
Expand Down Expand Up @@ -443,6 +468,39 @@ async def test_different_results_async_dual_mdns_resolver(
assert result["host"] == "127.0.0.1"


@pytest.mark.asyncio
async def test_different_results_async_dual_mdns_resolver_zero_timeout(
dual_resolver: AsyncMDNSResolver,
) -> None:
"""Test AsyncDualMDNSResolver resolves using mDNS and DNS.
Test when both resolvers return different results with zero timeout
for mDNS.
"""
dual_resolver._mdns_timeout = 0
with (
patch(
"aiohttp_asyncmdnsresolver._impl.AsyncResolver.resolve",
return_value=[
ResolveResult(hostname="localhost.local.", host="127.0.0.1", port=0) # type: ignore[typeddict-item]
],
),
patch.object(IPv4HostResolver, "load_from_cache", return_value=False),
patch.object(IPv4HostResolver, "async_request", return_value=True),
patch.object(
IPv4HostResolver,
"ip_addresses_by_version",
return_value=[],
),
):
results = await dual_resolver.resolve("localhost.local.")
assert results is not None
assert len(results) == 1
result = results[0]
assert result["hostname"] == "localhost.local."
assert result["host"] == "127.0.0.1"


@pytest.mark.asyncio
async def test_failed_mdns_async_dual_mdns_resolver(
dual_resolver: AsyncMDNSResolver,
Expand Down

0 comments on commit 878438d

Please sign in to comment.