Skip to content

Allow async exception handlers to type-check #2949

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jun 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ classifiers = [
]
dependencies = [
"anyio>=3.6.2,<5",
"typing_extensions>=3.10.0; python_version < '3.10'",
"typing_extensions>=4.10.0; python_version < '3.13'",
]

[project.optional-dependencies]
Expand Down
2 changes: 1 addition & 1 deletion starlette/_exception_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ async def sender(message: Message) -> None:
if is_async_callable(handler):
response = await handler(conn, exc)
else:
response = await run_in_threadpool(handler, conn, exc) # type: ignore
response = await run_in_threadpool(handler, conn, exc)
if response is not None:
await response(scope, receive, sender)

Expand Down
10 changes: 5 additions & 5 deletions starlette/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@

from starlette.types import Scope

if sys.version_info >= (3, 10): # pragma: no cover
from typing import TypeGuard
if sys.version_info >= (3, 13): # pragma: no cover
from typing import TypeIs
else: # pragma: no cover
from typing_extensions import TypeGuard
from typing_extensions import TypeIs

has_exceptiongroups = True
if sys.version_info < (3, 11): # pragma: no cover
Expand All @@ -26,11 +26,11 @@


@overload
def is_async_callable(obj: AwaitableCallable[T]) -> TypeGuard[AwaitableCallable[T]]: ...
def is_async_callable(obj: AwaitableCallable[T]) -> TypeIs[AwaitableCallable[T]]: ...


@overload
def is_async_callable(obj: Any) -> TypeGuard[AwaitableCallable[Any]]: ...
def is_async_callable(obj: Any) -> TypeIs[AwaitableCallable[Any]]: ...


def is_async_callable(obj: Any) -> Any:
Expand Down
2 changes: 1 addition & 1 deletion starlette/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def __init__(
def build_middleware_stack(self) -> ASGIApp:
debug = self.debug
error_handler = None
exception_handlers: dict[Any, Callable[[Request, Exception], Response]] = {}
exception_handlers: dict[Any, ExceptionHandler] = {}

for key, value in self.exception_handlers.items():
if key in (500, Exception):
Expand Down
5 changes: 2 additions & 3 deletions starlette/middleware/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@
import inspect
import sys
import traceback
from typing import Any, Callable

from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette.requests import Request
from starlette.responses import HTMLResponse, PlainTextResponse, Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from starlette.types import ASGIApp, ExceptionHandler, Message, Receive, Scope, Send

STYLES = """
p {
Expand Down Expand Up @@ -140,7 +139,7 @@ class ServerErrorMiddleware:
def __init__(
self,
app: ASGIApp,
handler: Callable[[Request, Exception], Any] | None = None,
handler: ExceptionHandler | None = None,
debug: bool = False,
) -> None:
self.app = app
Expand Down
8 changes: 4 additions & 4 deletions starlette/middleware/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from collections.abc import Mapping
from typing import Any, Callable
from typing import Any

from starlette._exception_handler import (
ExceptionHandlers,
Expand All @@ -11,15 +11,15 @@
from starlette.exceptions import HTTPException, WebSocketException
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.types import ASGIApp, ExceptionHandler, Receive, Scope, Send
from starlette.websockets import WebSocket


class ExceptionMiddleware:
def __init__(
self,
app: ASGIApp,
handlers: Mapping[Any, Callable[[Request, Exception], Response]] | None = None,
handlers: Mapping[Any, ExceptionHandler] | None = None,
debug: bool = False,
) -> None:
self.app = app
Expand All @@ -36,7 +36,7 @@ def __init__(
def add_exception_handler(
self,
exc_class_or_status_code: int | type[Exception],
handler: Callable[[Request, Exception], Response],
handler: ExceptionHandler,
) -> None:
if isinstance(exc_class_or_status_code, int):
self._status_handlers[exc_class_or_status_code] = handler
Expand Down
2 changes: 1 addition & 1 deletion starlette/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def request_response(
and returns an ASGI application.
"""
f: Callable[[Request], Awaitable[Response]] = (
func if is_async_callable(func) else functools.partial(run_in_threadpool, func) # type:ignore
func if is_async_callable(func) else functools.partial(run_in_threadpool, func)
)

async def app(scope: Scope, receive: Receive, send: Send) -> None:
Expand Down
17 changes: 17 additions & 0 deletions tests/test_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,3 +205,20 @@ def mock_run_in_threadpool(*args: Any, **kwargs: Any) -> None:
# This should succeed because http_exception is async and won't use run_in_threadpool
response = client.get("/not_acceptable")
assert response.status_code == 406


def test_handlers_annotations() -> None:
"""Check that async exception handlers are accepted by type checkers.

We annotate the handlers' exceptions with plain `Exception` to avoid variance issues
when using other exception types.
"""

async def async_catch_all_handler(request: Request, exc: Exception) -> JSONResponse:
raise NotImplementedError

def sync_catch_all_handler(request: Request, exc: Exception) -> JSONResponse:
raise NotImplementedError

ExceptionMiddleware(router, handlers={Exception: sync_catch_all_handler})
ExceptionMiddleware(router, handlers={Exception: async_catch_all_handler})