Skip to content

Commit

Permalink
Add AsyncDualMDNSResolver class to resolve via DNS and mDNS (#23)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco authored Feb 6, 2025
1 parent cc9042f commit 657b2b7
Show file tree
Hide file tree
Showing 6 changed files with 417 additions and 17 deletions.
1 change: 1 addition & 0 deletions CHANGES/23.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Created the :class:`aiohttp_asyncmdnsresolver.api.AsyncDualMDNSResolver` class to resolve ``.local`` names using both mDNS and DNS -- by :user:`bdraco`.
15 changes: 14 additions & 1 deletion docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ Reference

.. module:: aiohttp_asyncmdnsresolver.api

The only public *aiohttp_asyncmdnsresolver.api* class is :class:`AsyncMDNSResolver`:
The only public *aiohttp_asyncmdnsresolver.api* classes are :class:`AsyncMDNSResolver`
and :class:`AsyncDualMDNSResolver`:

.. doctest::

Expand Down Expand Up @@ -35,3 +36,15 @@ The only public *aiohttp_asyncmdnsresolver.api* class is :class:`AsyncMDNSResolv
async with aiohttp.ClientSession(connector=connector) as session:
async with session.get("http://KNKSADE41945.local.") as response:
print(response.status)


.. class:: AsyncDualMDNSResolver(*args, *, async_zeroconf=None, mdns_timeout=5.0, **kwargs)

This resolver is a variant of :class:`AsyncMDNSResolver` that resolves ``.local``
names with both mDNS and regular DNS. It takes the same arguments as
:class:`AsyncMDNSResolver`, and is used in the same way.

- The first successful result from either resolver is returned.
- If both resolvers fail, an exception is raised.
- If both resolvers return results at the same time, the results are
combined and duplicates are removed.
6 changes: 3 additions & 3 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ Introduction
Usage
-----

The API provides a single ``AsyncMDNSResolver`` class that can be
used to resolve mDNS queries and fallback to ``AsyncResolver`` for
non-MDNS hosts.
The API provides the :class:`aiohttp_asyncmdnsresolver.api.AsyncMDNSResolver` and
:class:`aiohttp_asyncmdnsresolver.api.AsyncDualMDNSResolver` classes that can be
used to resolve mDNS queries and fallback to ``AsyncResolver`` for non-MDNS hosts.

API documentation
-----------------
Expand Down
98 changes: 89 additions & 9 deletions src/aiohttp_asyncmdnsresolver/_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

from __future__ import annotations

import asyncio
import socket
import sys
from ipaddress import IPv4Address, IPv6Address
from typing import Any

Expand Down Expand Up @@ -51,7 +53,7 @@ def _to_resolve_result(
)


class AsyncMDNSResolver(AsyncResolver):
class _AsyncMDNSResolverBase(AsyncResolver):
"""Use the `aiodns`/`zeroconf` packages to make asynchronous DNS lookups."""

def __init__(
Expand All @@ -67,14 +69,6 @@ def __init__(
self._aiozc_owner = async_zeroconf is None
self._aiozc = async_zeroconf or AsyncZeroconf()

async def resolve(
self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET
) -> list[ResolveResult]:
"""Resolve a host name to an IP address."""
if host.endswith(".local") or host.endswith(".local."):
return await self._resolve_mdns(host, port, family)
return await super().resolve(host, port, family)

async def _resolve_mdns(
self, host: str, port: int, family: socket.AddressFamily
) -> list[ResolveResult]:
Expand Down Expand Up @@ -102,3 +96,89 @@ async def close(self) -> None:
await self._aiozc.async_close()
await super().close()
self._aiozc = None # type: ignore[assignment] # break ref cycles early


class AsyncMDNSResolver(_AsyncMDNSResolverBase):
"""Use the `aiodns`/`zeroconf` packages to make asynchronous DNS lookups."""

async def resolve(
self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET
) -> list[ResolveResult]:
"""Resolve a host name to an IP address."""
if not host.endswith(".local") and not host.endswith(".local."):
return await super().resolve(host, port, family)
return await self._resolve_mdns(host, port, family)


class AsyncDualMDNSResolver(_AsyncMDNSResolverBase):
"""Use the `aiodns`/`zeroconf` packages to make asynchronous DNS lookups.
This resolver is a variant of `AsyncMDNSResolver` that resolves .local
names with both mDNS and regular DNS.
- The first successful result from either resolver is returned.
- If both resolvers fail, an exception is raised.
- If both resolvers return results at the same time, the results are
combined and duplicates are removed.
"""

async def resolve(
self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET
) -> list[ResolveResult]:
"""Resolve a host name to an IP address."""
if not host.endswith(".local") and not host.endswith(".local."):
return await super().resolve(host, port, family)
resolve_via_mdns = self._resolve_mdns(host, port, family)
resolve_via_dns = super().resolve(host, port, family)
loop = asyncio.get_running_loop()
if sys.version_info >= (3, 12):
mdns_task = asyncio.Task(resolve_via_mdns, loop=loop, eager_start=True)
dns_task = asyncio.Task(resolve_via_dns, loop=loop, eager_start=True)
else:
mdns_task = loop.create_task(resolve_via_mdns)
dns_task = loop.create_task(resolve_via_dns)
await asyncio.wait((mdns_task, dns_task), return_when=asyncio.FIRST_COMPLETED)
if mdns_task.done() and mdns_task.exception():
await asyncio.wait((dns_task,), return_when=asyncio.ALL_COMPLETED)
elif dns_task.done() and dns_task.exception():
await asyncio.wait((mdns_task,), return_when=asyncio.ALL_COMPLETED)
resolve_results: list[ResolveResult] = []
exceptions: list[BaseException] = []
seen_results: set[tuple[str, int, str]] = set()
for task in (mdns_task, dns_task):
if task.done():
if exc := task.exception():
exceptions.append(exc)
else:
# If we have multiple results, we need to remove duplicates
# and combine the results. We put the mDNS results first
# to prioritize them.
for result in task.result():
result_key = (
result["hostname"],
result["port"],
result["host"],
)
if result_key not in seen_results:
seen_results.add(result_key)
resolve_results.append(result)
else:
task.cancel()
try:
await task # clear log traceback
except asyncio.CancelledError:
if (
sys.version_info >= (3, 11)
and (current_task := asyncio.current_task())
and current_task.cancelling()
):
raise

if resolve_results:
return resolve_results

exception_strings = ", ".join(
exc.strerror or str(exc) if isinstance(exc, OSError) else str(exc)
for exc in exceptions
)
raise OSError(None, exception_strings)
4 changes: 2 additions & 2 deletions src/aiohttp_asyncmdnsresolver/api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Public API of the property caching library."""

from ._impl import AsyncMDNSResolver
from ._impl import AsyncDualMDNSResolver, AsyncMDNSResolver

__all__ = ("AsyncMDNSResolver",)
__all__ = ("AsyncMDNSResolver", "AsyncDualMDNSResolver")
Loading

0 comments on commit 657b2b7

Please sign in to comment.