diff --git a/newsfragments/279.feature.rst b/newsfragments/279.feature.rst new file mode 100644 index 000000000..7c5820e11 --- /dev/null +++ b/newsfragments/279.feature.rst @@ -0,0 +1 @@ +Add ``trio.open_unix_listener``, ``trio.serve_unix``, and ``trio.UnixSocketListener`` to support ``SOCK_STREAM`` `Unix domain sockets `__ diff --git a/src/trio/__init__.py b/src/trio/__init__.py index 34fda8452..fc6fd90ab 100644 --- a/src/trio/__init__.py +++ b/src/trio/__init__.py @@ -63,6 +63,10 @@ serve_tcp as serve_tcp, ) from ._highlevel_open_tcp_stream import open_tcp_stream as open_tcp_stream +from ._highlevel_open_unix_listeners import ( + open_unix_listener as open_unix_listener, + serve_unix as serve_unix, +) from ._highlevel_open_unix_stream import open_unix_socket as open_unix_socket from ._highlevel_serve_listeners import serve_listeners as serve_listeners from ._highlevel_socket import ( diff --git a/src/trio/_highlevel_open_unix_listeners.py b/src/trio/_highlevel_open_unix_listeners.py new file mode 100644 index 000000000..6d2f6bb11 --- /dev/null +++ b/src/trio/_highlevel_open_unix_listeners.py @@ -0,0 +1,133 @@ +from __future__ import annotations + +import os +from typing import TYPE_CHECKING + +import trio +import trio.socket as tsocket +from trio import TaskStatus + +from ._highlevel_open_tcp_listeners import _compute_backlog + +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable + + +try: + from trio.socket import AF_UNIX + + HAS_UNIX = True +except ImportError: + HAS_UNIX = False + + +async def open_unix_listener( + path: str | bytes | os.PathLike[str] | os.PathLike[bytes], + *, + mode: int | None = None, + backlog: int | None = None, +) -> trio.SocketListener: + """Create :class:`SocketListener` objects to listen for connections. + Opens a connection to the specified + `Unix domain socket `__. + + You must have read/write permission on the specified file to connect. + + Args: + + path (str): Filename of UNIX socket to create and listen on. + Absolute or relative paths may be used. + + mode (int or None): The socket file permissions. + UNIX permissions are usually specified in octal numbers. If + you leave this as ``None``, Trio will not change the mode from + the operating system's default. + + backlog (int or None): The listen backlog to use. If you leave this as + ``None`` then Trio will pick a good default. (Currently: + whatever your system has configured as the maximum backlog.) + + Returns: + :class:`UnixSocketListener` + + Raises: + :class:`ValueError` If invalid arguments. + :class:`RuntimeError`: If AF_UNIX sockets are not supported. + :class:`FileNotFoundError`: If folder socket file is to be created in does not exist. + """ + if not HAS_UNIX: + raise RuntimeError("Unix sockets are not supported on this platform") + + computed_backlog = _compute_backlog(backlog) + + fspath = await trio.Path(os.fsdecode(path)).absolute() + + folder = fspath.parent + if not await folder.exists(): + raise FileNotFoundError(f"Socket folder does not exist: {folder!r}") + + str_path = str(fspath) + + # much more simplified logic vs tcp sockets - one socket family and only one + # possible location to connect to + sock = tsocket.socket(AF_UNIX, tsocket.SOCK_STREAM) + try: + await sock.bind(str_path) + + if mode is not None: + await fspath.chmod(mode) + + sock.listen(computed_backlog) + + return trio.SocketListener(sock) + except BaseException: + sock.close() + if os.path.exists(str_path): + os.unlink(str_path) + raise + + +async def serve_unix( + handler: Callable[[trio.SocketStream], Awaitable[object]], + path: str | bytes | os.PathLike[str] | os.PathLike[bytes], + *, + backlog: int | None = None, + handler_nursery: trio.Nursery | None = None, + task_status: TaskStatus[list[trio.SocketListener]] = trio.TASK_STATUS_IGNORED, +) -> None: + """Listen for incoming UNIX connections, and for each one start a task + running ``handler(stream)``. + This is a thin convenience wrapper around :func:`open_unix_listener` and + :func:`serve_listeners` – see them for full details. + .. warning:: + If ``handler`` raises an exception, then this function doesn't do + anything special to catch it – so by default the exception will + propagate out and crash your server. If you don't want this, then catch + exceptions inside your ``handler``, or use a ``handler_nursery`` object + that responds to exceptions in some other way. + When used with ``nursery.start`` you get back the newly opened listeners. + Args: + handler: The handler to start for each incoming connection. Passed to + :func:`serve_listeners`. + path: The socket file name. + Passed to :func:`open_unix_listener`. + backlog: The listen backlog, or None to have a good default picked. + Passed to :func:`open_tcp_listener`. + handler_nursery: The nursery to start handlers in, or None to use an + internal nursery. Passed to :func:`serve_listeners`. + task_status: This function can be used with ``nursery.start``. + Returns: + This function only returns when cancelled. + Raises: + RuntimeError: If AF_UNIX sockets are not supported. + """ + if not HAS_UNIX: + raise RuntimeError("Unix sockets are not supported on this platform") + + listener = await open_unix_listener(path, backlog=backlog) + await trio.serve_listeners( + handler, + [listener], + handler_nursery=handler_nursery, + task_status=task_status, + ) diff --git a/src/trio/_highlevel_socket.py b/src/trio/_highlevel_socket.py index c04e66e1b..f7a23364d 100644 --- a/src/trio/_highlevel_socket.py +++ b/src/trio/_highlevel_socket.py @@ -3,7 +3,10 @@ import errno from contextlib import contextmanager, suppress -from typing import TYPE_CHECKING, overload +from os import stat, unlink +from os.path import exists +from stat import S_ISSOCK +from typing import TYPE_CHECKING, Final, overload import trio @@ -16,7 +19,7 @@ from typing_extensions import Buffer - from ._socket import SocketType + from ._socket import AddressFormat, SocketType # XX TODO: this number was picked arbitrarily. We should do experiments to # tune it. (Or make it dynamic -- one idea is to start small and increase it @@ -31,6 +34,8 @@ errno.ENOTSOCK, } +HAS_UNIX: Final = hasattr(tsocket, "AF_UNIX") + @contextmanager def _translate_socket_errors_to_stream_errors() -> Generator[None, None, None]: @@ -68,13 +73,15 @@ class SocketStream(HalfCloseableStream): """ + __slots__ = ("_send_conflict_detector", "socket") + def __init__(self, socket: SocketType) -> None: if not isinstance(socket, tsocket.SocketType): raise TypeError("SocketStream requires a Trio socket object") if socket.type != tsocket.SOCK_STREAM: raise ValueError("SocketStream requires a SOCK_STREAM socket") - self.socket = socket + self.socket: SocketType = socket self._send_conflict_detector = ConflictDetector( "another task is currently sending data on this SocketStream", ) @@ -356,7 +363,9 @@ class SocketListener(Listener[SocketStream]): and be listening. Note that the :class:`SocketListener` "takes ownership" of the given - socket; closing the :class:`SocketListener` will also close the socket. + socket; closing the :class:`SocketListener` will also close the + socket, and if it's a Unix socket, it will also unlink the leftover + socket file that the Unix socket is bound to. .. attribute:: socket @@ -364,7 +373,12 @@ class SocketListener(Listener[SocketStream]): """ - def __init__(self, socket: SocketType) -> None: + __slots__ = ("socket",) + + def __init__( + self, + socket: SocketType, + ) -> None: if not isinstance(socket, tsocket.SocketType): raise TypeError("SocketListener requires a Trio socket object") if socket.type != tsocket.SOCK_STREAM: @@ -378,7 +392,7 @@ def __init__(self, socket: SocketType) -> None: if not listening: raise ValueError("SocketListener requires a listening socket") - self.socket = socket + self.socket: SocketType = socket async def accept(self) -> SocketStream: """Accept an incoming connection. @@ -409,6 +423,21 @@ async def accept(self) -> SocketStream: return SocketStream(sock) async def aclose(self) -> None: - """Close this listener and its underlying socket.""" + """Close this listener, its underlying socket, and for Unix sockets unlink the socket file.""" + is_unix_socket = self.socket.family == getattr(tsocket, "AF_UNIX", None) + + path: AddressFormat | None = None + if is_unix_socket: + # If unix socket, need to get path before we close socket + # or OS errors + path = self.socket.getsockname() self.socket.close() + # If unix socket, clean up socket file that gets left behind. + if ( + is_unix_socket + and path is not None + and exists(path) + and S_ISSOCK(stat(path).st_mode) + ): + unlink(path) await trio.lowlevel.checkpoint() diff --git a/src/trio/_tests/test_exports.py b/src/trio/_tests/test_exports.py index bad9ebec3..5cc7d3a35 100644 --- a/src/trio/_tests/test_exports.py +++ b/src/trio/_tests/test_exports.py @@ -452,8 +452,6 @@ def lookup_symbol(symbol: str) -> dict[str, str]: trio.Process: {"args", "pid", "stderr", "stdin", "stdio", "stdout"}, trio.SSLListener: {"transport_listener"}, trio.SSLStream: {"transport_stream"}, - trio.SocketListener: {"socket"}, - trio.SocketStream: {"socket"}, trio.testing.MemoryReceiveStream: {"close_hook", "receive_some_hook"}, trio.testing.MemorySendStream: { "close_hook", @@ -527,6 +525,12 @@ def lookup_symbol(symbol: str) -> dict[str, str]: print(f"\n{tool} can't see the following symbols in {module_name}:") pprint(errors) + print( + f""" +If there are extra attributes listed, try checking to make sure this test +isn't ignoring them. If there are missing attributes, try looking for why +{tool} isn't seeing them compared to `inspect.getmembers`.""" + ) assert not errors diff --git a/src/trio/_tests/test_highlevel_open_unix_listeners.py b/src/trio/_tests/test_highlevel_open_unix_listeners.py new file mode 100644 index 000000000..8c8721e45 --- /dev/null +++ b/src/trio/_tests/test_highlevel_open_unix_listeners.py @@ -0,0 +1,140 @@ +from __future__ import annotations + +import socket as stdlib_socket +import sys +import tempfile +from os import unlink +from os.path import exists +from typing import TYPE_CHECKING, cast + +import pytest + +import trio +import trio.socket as tsocket +from trio import ( + SocketListener, + open_unix_listener, + serve_unix, +) +from trio.testing import open_stream_to_socket_listener + +if TYPE_CHECKING: + from collections.abc import Generator + from pathlib import Path + + from trio.abc import SendStream + +assert not TYPE_CHECKING or sys.platform != "win32" + + +skip_if_not_unix = pytest.mark.skipif( + not hasattr(tsocket, "AF_UNIX"), + reason="Needs unix socket support", +) + + +@pytest.fixture +def temp_unix_socket_path(tmp_path: Path) -> Generator[str, None, None]: + """Fixture to create a temporary Unix socket path.""" + if sys.platform == "darwin": + # On macos, opening unix socket will fail if name is too long + temp_socket_path = tempfile.mkstemp(suffix=".sock")[1] + # mkstemp makes a file, we just wanted a unique name + unlink(temp_socket_path) + else: + temp_socket_path = str(tmp_path / "socket.sock") + yield temp_socket_path + # If test failed to delete file at the end, do it for them. + if exists(temp_socket_path): + unlink(temp_socket_path) + + +@skip_if_not_unix +async def test_open_unix_listener_basic(temp_unix_socket_path: str) -> None: + listener = await open_unix_listener(temp_unix_socket_path) + + assert isinstance(listener, SocketListener) + # Check that the listener is using the Unix socket family + assert listener.socket.family == tsocket.AF_UNIX + assert listener.socket.getsockname() == temp_unix_socket_path + + # Make sure the backlog is at least 2 + c1 = await open_stream_to_socket_listener(listener) + c2 = await open_stream_to_socket_listener(listener) + + s1 = await listener.accept() + s2 = await listener.accept() + + # Note that we don't know which client stream is connected to which server + # stream + await s1.send_all(b"x") + await s2.send_all(b"x") + assert await c1.receive_some(1) == b"x" + assert await c2.receive_some(1) == b"x" + + for resource in [c1, c2, s1, s2, listener]: + await resource.aclose() + + +@skip_if_not_unix +async def test_open_unix_listener_specific_path(temp_unix_socket_path: str) -> None: + listener = await open_unix_listener(temp_unix_socket_path) + async with listener: + assert listener.socket.getsockname() == temp_unix_socket_path + + +@skip_if_not_unix +async def test_open_unix_listener_rebind(temp_unix_socket_path: str) -> None: + listener = await open_unix_listener(temp_unix_socket_path) + sockaddr1 = listener.socket.getsockname() + + # Attempt to bind again to the same socket should fail + with stdlib_socket.socket(tsocket.AF_UNIX) as probe: + with pytest.raises( + OSError, + match=r"(Address (already )?in use|An attempt was made to access a socket in a way forbidden by its access permissions)$", + ): + probe.bind(temp_unix_socket_path) + + # Now use the listener to set up some connections + c_established = await open_stream_to_socket_listener(listener) + s_established = await listener.accept() + await listener.aclose() + + # Attempt to bind again should succeed after closing the listener + listener2 = await open_unix_listener(temp_unix_socket_path) + sockaddr2 = listener2.socket.getsockname() + + assert sockaddr1 == sockaddr2 + assert s_established.socket.getsockname() == sockaddr2 + + for resource in [listener2, c_established, s_established]: + await resource.aclose() + + +@skip_if_not_unix +async def test_serve_unix(temp_unix_socket_path: str) -> None: + async def handler(stream: SendStream) -> None: + await stream.send_all(b"x") + + async with trio.open_nursery() as nursery: + # nursery.start is incorrectly typed, awaiting #2773 + value = await nursery.start(serve_unix, handler, temp_unix_socket_path) + assert isinstance(value, list) + listeners = cast("list[SocketListener]", value) + stream = await open_stream_to_socket_listener(listeners[0]) + async with stream: + assert await stream.receive_some(1) == b"x" + nursery.cancel_scope.cancel() + for listener in listeners: + await listener.aclose() + + +@pytest.mark.skipif(hasattr(tsocket, "AF_UNIX"), reason="Test for non-unix platforms") +async def test_error_on_no_unix(temp_unix_socket_path: str) -> None: + with pytest.raises( + RuntimeError, + match=r"^Unix sockets are not supported on this platform$", + ): + async with await open_unix_listener(temp_unix_socket_path): + pass