22import contextvars
33import functools
44import sys
5- from typing import Any
5+ import types
6+ from collections .abc import Callable , Coroutine , Awaitable
7+ from typing import Any , Protocol , TypeVar , Union , Generic
8+ import concurrent .futures
69
710import sniffio
811import trio .lowlevel
1114from . import _asyncio
1215from ._context import restore_context as _restore_context
1316
17+ _R = TypeVar ("_R" )
18+
19+ Coro = Coroutine [Any , Any , _R ]
20+
21+ Loop = Union [asyncio .AbstractEventLoop , trio .lowlevel .TrioToken ]
22+ TaskContext = list [Any ]
23+
1424
1525class TrioThreadCancelled (BaseException ):
1626 pass
1727
1828
19- def get_running_loop ():
29+ def get_running_loop () -> Loop :
30+
2031 try :
2132 asynclib = sniffio .current_async_library ()
2233 except sniffio .AsyncLibraryNotFoundError :
@@ -25,16 +36,16 @@ def get_running_loop():
2536 if asynclib == "asyncio" :
2637 return asyncio .get_running_loop ()
2738 if asynclib == "trio" :
28- return trio .lowlevel .current_token ()
39+ return trio .lowlevel .current_trio_token ()
2940 raise RuntimeError (f"unsupported library { asynclib } " )
3041
3142
3243@trio .lowlevel .disable_ki_protection
33- async def wrap_awaitable (awaitable ) :
44+ async def wrap_awaitable (awaitable : Awaitable [ _R ]) -> _R :
3445 return await awaitable
3546
3647
37- def create_task_threadsafe (loop , awaitable ) :
48+ def create_task_threadsafe (loop : Loop , awaitable : Coro [ _R ]) -> None :
3849 if isinstance (loop , trio .lowlevel .TrioToken ):
3950 try :
4051 loop .run_sync_soon (
@@ -44,15 +55,34 @@ def create_task_threadsafe(loop, awaitable):
4455 )
4556 except trio .RunFinishedError :
4657 raise RuntimeError ("trio loop no-longer running" )
58+ return
59+
60+ _asyncio .create_task_threadsafe (loop , awaitable )
61+
62+
63+ ExcInfo = Union [
64+ tuple [type [BaseException ], BaseException , types .TracebackType ],
65+ tuple [None , None , None ],
66+ ]
67+
4768
48- return _asyncio .create_task_threadsafe (loop , awaitable )
69+ class ThreadHandlerType (Protocol , Generic [_R ]):
70+ def __call__ (
71+ self ,
72+ loop : Loop ,
73+ exc_info : ExcInfo ,
74+ task_context : TaskContext ,
75+ func : Callable [[Callable [[], _R ]], _R ],
76+ child : Callable [[], _R ],
77+ ) -> _R :
78+ ...
4979
5080
51- async def run_in_executor (* , loop , executor , thread_handler , child ) :
81+ async def run_in_executor (* , loop : Loop , executor : concurrent . futures . ThreadPoolExecutor , thread_handler : ThreadHandlerType [ _R ] , child : Callable [[], _R ]) -> _R :
5282 if isinstance (loop , trio .lowlevel .TrioToken ):
5383 context = contextvars .copy_context ()
5484 func = context .run
55- task_context : list [ asyncio . Task [ Any ]] = []
85+ task_context : TaskContext = []
5686
5787 # Run the code in the right thread
5888 full_func = functools .partial (
@@ -66,7 +96,7 @@ async def run_in_executor(*, loop, executor, thread_handler, child):
6696 try :
6797 if executor is None :
6898
69- async def handle_cancel ():
99+ async def handle_cancel () -> None :
70100 try :
71101 await trio .sleep_forever ()
72102 except trio .Cancelled :
@@ -84,16 +114,17 @@ async def handle_cancel():
84114 pass
85115 finally :
86116 nursery .cancel_scope .cancel ()
117+ assert False
87118 else :
88119 event = trio .Event ()
89120
90- def callback (fut ) :
121+ def callback (fut : object ) -> None :
91122 loop .run_sync_soon (event .set )
92123
93124 fut = executor .submit (full_func )
94125 fut .add_done_callback (callback )
95126
96- async def handle_cancel_fut ():
127+ async def handle_cancel_fut () -> None :
97128 try :
98129 await trio .sleep_forever ()
99130 except trio .Cancelled :
@@ -111,15 +142,17 @@ async def handle_cancel_fut():
111142 return fut .result ()
112143 except TrioThreadCancelled :
113144 pass
145+ assert False
114146 finally :
115147 _restore_context (context )
116148
117- return await _asyncio .run_in_executor (
118- loop = loop , executor = executor , thread_handler = thread_handler , func = func
149+ else :
150+ return await _asyncio .run_in_executor (
151+ loop = loop , executor = executor , thread_handler = thread_handler , child = child
119152 )
120153
121154
122- async def wrap_task_context (loop , task_context , awaitable ) :
155+ async def wrap_task_context (loop : Loop , task_context : Union [ TaskContext , None ], awaitable : Coro [ _R ]) -> _R :
123156 if task_context is None :
124157 return await awaitable
125158
@@ -130,7 +163,6 @@ async def wrap_task_context(loop, task_context, awaitable):
130163 return await awaitable
131164 finally :
132165 task_context .remove (scope )
133- if scope .cancelled_caught :
134- raise TrioThreadCancelled
166+ raise TrioThreadCancelled
135167
136168 return await _asyncio .wrap_task_context (loop , task_context , awaitable )
0 commit comments