Skip to content

Commit 838227e

Browse files
committed
mypy
1 parent 89d2064 commit 838227e

File tree

4 files changed

+89
-23
lines changed

4 files changed

+89
-23
lines changed

asgiref/_asyncio.py

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,33 @@
66
]
77

88
import asyncio
9+
import concurrent.futures
910
import contextvars
1011
import functools
1112
import sys
13+
import types
1214
from asyncio import get_running_loop
13-
from collections.abc import Callable
14-
from typing import Any, TypeVar
15+
from collections.abc import Callable, Coroutine
16+
from typing import Any, Protocol, TypeVar, Union, Generic
1517

1618
from ._context import restore_context as _restore_context
1719

1820
_R = TypeVar("_R")
1921

22+
Coro = Coroutine[Any, Any, _R]
2023

21-
def create_task_threadsafe(loop, awaitable) -> None:
24+
25+
def create_task_threadsafe(
26+
loop: asyncio.AbstractEventLoop, awaitable: Coro[object]
27+
) -> None:
2228
loop.call_soon_threadsafe(loop.create_task, awaitable)
2329

2430

25-
async def wrap_task_context(loop, task_context, awaitable):
31+
async def wrap_task_context(
32+
loop: asyncio.AbstractEventLoop,
33+
task_context: list[asyncio.Task[Any]],
34+
awaitable: Coro[_R],
35+
) -> _R:
2636
if task_context is None:
2737
return await awaitable
2838

@@ -37,8 +47,30 @@ async def wrap_task_context(loop, task_context, awaitable):
3747
task_context.remove(current_task)
3848

3949

50+
ExcInfo = Union[
51+
tuple[type[BaseException], BaseException, types.TracebackType],
52+
tuple[None, None, None],
53+
]
54+
55+
56+
class ThreadHandlerType(Protocol, Generic[_R]):
57+
def __call__(
58+
self,
59+
loop: asyncio.AbstractEventLoop,
60+
exc_info: ExcInfo,
61+
task_context: list[asyncio.Task[Any]],
62+
func: Callable[[Callable[[], _R]], _R],
63+
child: Callable[[], _R],
64+
) -> _R:
65+
...
66+
67+
4068
async def run_in_executor(
41-
*, loop, executor, thread_handler, child: Callable[[], _R]
69+
*,
70+
loop: asyncio.AbstractEventLoop,
71+
executor: concurrent.futures.ThreadPoolExecutor,
72+
thread_handler: ThreadHandlerType[_R],
73+
child: Callable[[], _R],
4274
) -> _R:
4375
context = contextvars.copy_context()
4476
func = context.run

asgiref/_trio.py

Lines changed: 48 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22
import contextvars
33
import functools
44
import sys
5-
from typing import Any
5+
import types
6+
from collections.abc import Callable, Coroutine, Awaitable
7+
from typing import Any, Protocol, TypeVar, Union, Generic
8+
import concurrent.futures
69

710
import sniffio
811
import trio.lowlevel
@@ -11,12 +14,20 @@
1114
from . import _asyncio
1215
from ._context import restore_context as _restore_context
1316

17+
_R = TypeVar("_R")
18+
19+
Coro = Coroutine[Any, Any, _R]
20+
21+
Loop = Union[asyncio.AbstractEventLoop, trio.lowlevel.TrioToken]
22+
TaskContext = list[Any]
23+
1424

1525
class TrioThreadCancelled(BaseException):
1626
pass
1727

1828

19-
def get_running_loop():
29+
def get_running_loop() -> Loop:
30+
2031
try:
2132
asynclib = sniffio.current_async_library()
2233
except sniffio.AsyncLibraryNotFoundError:
@@ -25,16 +36,16 @@ def get_running_loop():
2536
if asynclib == "asyncio":
2637
return asyncio.get_running_loop()
2738
if asynclib == "trio":
28-
return trio.lowlevel.current_token()
39+
return trio.lowlevel.current_trio_token()
2940
raise RuntimeError(f"unsupported library {asynclib}")
3041

3142

3243
@trio.lowlevel.disable_ki_protection
33-
async def wrap_awaitable(awaitable):
44+
async def wrap_awaitable(awaitable: Awaitable[_R]) -> _R:
3445
return await awaitable
3546

3647

37-
def create_task_threadsafe(loop, awaitable):
48+
def create_task_threadsafe(loop: Loop, awaitable: Coro[_R]) -> None:
3849
if isinstance(loop, trio.lowlevel.TrioToken):
3950
try:
4051
loop.run_sync_soon(
@@ -44,15 +55,34 @@ def create_task_threadsafe(loop, awaitable):
4455
)
4556
except trio.RunFinishedError:
4657
raise RuntimeError("trio loop no-longer running")
58+
return
59+
60+
_asyncio.create_task_threadsafe(loop, awaitable)
61+
62+
63+
ExcInfo = Union[
64+
tuple[type[BaseException], BaseException, types.TracebackType],
65+
tuple[None, None, None],
66+
]
67+
4768

48-
return _asyncio.create_task_threadsafe(loop, awaitable)
69+
class ThreadHandlerType(Protocol, Generic[_R]):
70+
def __call__(
71+
self,
72+
loop: Loop,
73+
exc_info: ExcInfo,
74+
task_context: TaskContext,
75+
func: Callable[[Callable[[], _R]], _R],
76+
child: Callable[[], _R],
77+
) -> _R:
78+
...
4979

5080

51-
async def run_in_executor(*, loop, executor, thread_handler, child):
81+
async def run_in_executor(*, loop: Loop, executor: concurrent.futures.ThreadPoolExecutor, thread_handler: ThreadHandlerType[_R], child: Callable[[], _R]) -> _R:
5282
if isinstance(loop, trio.lowlevel.TrioToken):
5383
context = contextvars.copy_context()
5484
func = context.run
55-
task_context: list[asyncio.Task[Any]] = []
85+
task_context: TaskContext = []
5686

5787
# Run the code in the right thread
5888
full_func = functools.partial(
@@ -66,7 +96,7 @@ async def run_in_executor(*, loop, executor, thread_handler, child):
6696
try:
6797
if executor is None:
6898

69-
async def handle_cancel():
99+
async def handle_cancel() -> None:
70100
try:
71101
await trio.sleep_forever()
72102
except trio.Cancelled:
@@ -84,16 +114,17 @@ async def handle_cancel():
84114
pass
85115
finally:
86116
nursery.cancel_scope.cancel()
117+
assert False
87118
else:
88119
event = trio.Event()
89120

90-
def callback(fut):
121+
def callback(fut: object) -> None:
91122
loop.run_sync_soon(event.set)
92123

93124
fut = executor.submit(full_func)
94125
fut.add_done_callback(callback)
95126

96-
async def handle_cancel_fut():
127+
async def handle_cancel_fut() -> None:
97128
try:
98129
await trio.sleep_forever()
99130
except trio.Cancelled:
@@ -111,15 +142,17 @@ async def handle_cancel_fut():
111142
return fut.result()
112143
except TrioThreadCancelled:
113144
pass
145+
assert False
114146
finally:
115147
_restore_context(context)
116148

117-
return await _asyncio.run_in_executor(
118-
loop=loop, executor=executor, thread_handler=thread_handler, func=func
149+
else:
150+
return await _asyncio.run_in_executor(
151+
loop=loop, executor=executor, thread_handler=thread_handler, child=child
119152
)
120153

121154

122-
async def wrap_task_context(loop, task_context, awaitable):
155+
async def wrap_task_context(loop: Loop, task_context: Union[TaskContext, None], awaitable: Coro[_R]) -> _R:
123156
if task_context is None:
124157
return await awaitable
125158

@@ -130,7 +163,6 @@ async def wrap_task_context(loop, task_context, awaitable):
130163
return await awaitable
131164
finally:
132165
task_context.remove(scope)
133-
if scope.cancelled_caught:
134-
raise TrioThreadCancelled
166+
raise TrioThreadCancelled
135167

136168
return await _asyncio.wrap_task_context(loop, task_context, awaitable)

asgiref/sync.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def __init__(
165165
],
166166
force_new_loop: Union[LoopType, bool] = False,
167167
):
168-
if force_new_loop and not isinstance(LoopType):
168+
if force_new_loop and not isinstance(force_new_loop, LoopType):
169169
force_new_loop = LoopType.ASYNCIO
170170

171171
if not callable(awaitable) or (
@@ -319,14 +319,15 @@ async def main_wrap(
319319
if context is not None:
320320
_restore_context(context[0])
321321

322+
result: _R
322323
try:
323324
# If we have an exception, run the function inside the except block
324325
# after raising it so exc_info is correctly populated.
325326
if exc_info[1]:
326327
try:
327328
raise exc_info[1]
328329
except BaseException:
329-
result = await wrap_task_context(loop, task_context, awaitable)
330+
result = await wrap_task_context(loop, task_context, awaitable)
330331
else:
331332
result = await wrap_task_context(loop, task_context, awaitable)
332333
except BaseException as e:

tox.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ commands =
1111
mypy: mypy . {posargs}
1212
deps =
1313
setuptools
14+
mypy: trio
1415

1516
[testenv:qa]
1617
skip_install = true

0 commit comments

Comments
 (0)