Skip to content

Commit 739ea49

Browse files
Allow async exception handlers to type-check (#2949)
Co-authored-by: Marcelo Trylesinski <[email protected]>
1 parent 78da9b9 commit 739ea49

File tree

8 files changed

+32
-16
lines changed

8 files changed

+32
-16
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ classifiers = [
2727
]
2828
dependencies = [
2929
"anyio>=3.6.2,<5",
30-
"typing_extensions>=3.10.0; python_version < '3.10'",
30+
"typing_extensions>=4.10.0; python_version < '3.13'",
3131
]
3232

3333
[project.optional-dependencies]

starlette/_exception_handler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ async def sender(message: Message) -> None:
5858
if is_async_callable(handler):
5959
response = await handler(conn, exc)
6060
else:
61-
response = await run_in_threadpool(handler, conn, exc) # type: ignore
61+
response = await run_in_threadpool(handler, conn, exc)
6262
if response is not None:
6363
await response(scope, receive, sender)
6464

starlette/_utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99

1010
from starlette.types import Scope
1111

12-
if sys.version_info >= (3, 10): # pragma: no cover
13-
from typing import TypeGuard
12+
if sys.version_info >= (3, 13): # pragma: no cover
13+
from typing import TypeIs
1414
else: # pragma: no cover
15-
from typing_extensions import TypeGuard
15+
from typing_extensions import TypeIs
1616

1717
has_exceptiongroups = True
1818
if sys.version_info < (3, 11): # pragma: no cover
@@ -26,11 +26,11 @@
2626

2727

2828
@overload
29-
def is_async_callable(obj: AwaitableCallable[T]) -> TypeGuard[AwaitableCallable[T]]: ...
29+
def is_async_callable(obj: AwaitableCallable[T]) -> TypeIs[AwaitableCallable[T]]: ...
3030

3131

3232
@overload
33-
def is_async_callable(obj: Any) -> TypeGuard[AwaitableCallable[Any]]: ...
33+
def is_async_callable(obj: Any) -> TypeIs[AwaitableCallable[Any]]: ...
3434

3535

3636
def is_async_callable(obj: Any) -> Any:

starlette/applications.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def __init__(
8080
def build_middleware_stack(self) -> ASGIApp:
8181
debug = self.debug
8282
error_handler = None
83-
exception_handlers: dict[Any, Callable[[Request, Exception], Response]] = {}
83+
exception_handlers: dict[Any, ExceptionHandler] = {}
8484

8585
for key, value in self.exception_handlers.items():
8686
if key in (500, Exception):

starlette/middleware/errors.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,12 @@
44
import inspect
55
import sys
66
import traceback
7-
from typing import Any, Callable
87

98
from starlette._utils import is_async_callable
109
from starlette.concurrency import run_in_threadpool
1110
from starlette.requests import Request
1211
from starlette.responses import HTMLResponse, PlainTextResponse, Response
13-
from starlette.types import ASGIApp, Message, Receive, Scope, Send
12+
from starlette.types import ASGIApp, ExceptionHandler, Message, Receive, Scope, Send
1413

1514
STYLES = """
1615
p {
@@ -140,7 +139,7 @@ class ServerErrorMiddleware:
140139
def __init__(
141140
self,
142141
app: ASGIApp,
143-
handler: Callable[[Request, Exception], Any] | None = None,
142+
handler: ExceptionHandler | None = None,
144143
debug: bool = False,
145144
) -> None:
146145
self.app = app

starlette/middleware/exceptions.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from collections.abc import Mapping
4-
from typing import Any, Callable
4+
from typing import Any
55

66
from starlette._exception_handler import (
77
ExceptionHandlers,
@@ -11,15 +11,15 @@
1111
from starlette.exceptions import HTTPException, WebSocketException
1212
from starlette.requests import Request
1313
from starlette.responses import PlainTextResponse, Response
14-
from starlette.types import ASGIApp, Receive, Scope, Send
14+
from starlette.types import ASGIApp, ExceptionHandler, Receive, Scope, Send
1515
from starlette.websockets import WebSocket
1616

1717

1818
class ExceptionMiddleware:
1919
def __init__(
2020
self,
2121
app: ASGIApp,
22-
handlers: Mapping[Any, Callable[[Request, Exception], Response]] | None = None,
22+
handlers: Mapping[Any, ExceptionHandler] | None = None,
2323
debug: bool = False,
2424
) -> None:
2525
self.app = app
@@ -36,7 +36,7 @@ def __init__(
3636
def add_exception_handler(
3737
self,
3838
exc_class_or_status_code: int | type[Exception],
39-
handler: Callable[[Request, Exception], Response],
39+
handler: ExceptionHandler,
4040
) -> None:
4141
if isinstance(exc_class_or_status_code, int):
4242
self._status_handlers[exc_class_or_status_code] = handler

starlette/routing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def request_response(
6565
and returns an ASGI application.
6666
"""
6767
f: Callable[[Request], Awaitable[Response]] = (
68-
func if is_async_callable(func) else functools.partial(run_in_threadpool, func) # type:ignore
68+
func if is_async_callable(func) else functools.partial(run_in_threadpool, func)
6969
)
7070

7171
async def app(scope: Scope, receive: Receive, send: Send) -> None:

tests/test_exceptions.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,3 +205,20 @@ def mock_run_in_threadpool(*args: Any, **kwargs: Any) -> None:
205205
# This should succeed because http_exception is async and won't use run_in_threadpool
206206
response = client.get("/not_acceptable")
207207
assert response.status_code == 406
208+
209+
210+
def test_handlers_annotations() -> None:
211+
"""Check that async exception handlers are accepted by type checkers.
212+
213+
We annotate the handlers' exceptions with plain `Exception` to avoid variance issues
214+
when using other exception types.
215+
"""
216+
217+
async def async_catch_all_handler(request: Request, exc: Exception) -> JSONResponse:
218+
raise NotImplementedError
219+
220+
def sync_catch_all_handler(request: Request, exc: Exception) -> JSONResponse:
221+
raise NotImplementedError
222+
223+
ExceptionMiddleware(router, handlers={Exception: sync_catch_all_handler})
224+
ExceptionMiddleware(router, handlers={Exception: async_catch_all_handler})

0 commit comments

Comments
 (0)