66from dataclasses import dataclass
77from queue import Empty as QueueEmpty
88from queue import Queue
9+ from threading import Event
910from typing import (
1011 Any ,
1112 Generic ,
@@ -109,9 +110,11 @@ async def resolve(
109110 ...
110111
111112 async def get_request_time (
112- self , times_queue : Queue [WorkerProcessRequestTime ]
113+ self ,
114+ times_queue : Queue [WorkerProcessRequestTime ],
115+ timeout : Optional [int ] = None ,
113116 ) -> WorkerProcessRequestTime :
114- return await asyncio .to_thread (times_queue .get ) # type: ignore[attr-defined]
117+ return await asyncio .to_thread (times_queue .get , timeout = timeout ) # type: ignore[attr-defined]
115118
116119 async def send_result (
117120 self ,
@@ -181,6 +184,7 @@ async def resolve_scheduler_request(
181184 def process_loop_asynchronous (
182185 self ,
183186 queues : MPQueues [RequestT , ResponseT ],
187+ stop_event : Event ,
184188 prioritize_sessions : bool ,
185189 max_concurrency : int ,
186190 process_id : int ,
@@ -189,7 +193,7 @@ async def _process_runner():
189193 lock = asyncio .Semaphore (max_concurrency )
190194 pending_sessions : list [RequestSession [RequestT , ResponseT ]] = []
191195
192- while True : # TODO: Exit condition
196+ while True :
193197 await asyncio .sleep (0 ) # Yield control to the event loop
194198 await lock .acquire ()
195199
@@ -201,13 +205,16 @@ async def _process_runner():
201205 else queues .requests .get_nowait ()
202206 )
203207 dequeued_time = time .time ()
204- request_times = await self .get_request_time (queues .times )
208+ request_times = await self .get_request_time (queues .times , 5 )
205209 except (QueueEmpty , IndexError ):
206210 # Requeue the session if we don't have a next time yet
207211 if request_session is not None :
208212 pending_sessions .append (request_session )
209213 lock .release ()
210- continue
214+ if stop_event .is_set ():
215+ return # Exit if stop event is set
216+ else :
217+ continue
211218
212219 async def wait_then_requeue (
213220 session : RequestSession [RequestT , ResponseT ],
@@ -309,13 +316,15 @@ async def prepare_multiprocessing(self):
309316 def process_loop_asynchronous (
310317 self ,
311318 queues : MPQueues [GenerationRequest , ResponseSummary ],
319+ stop_event : Event ,
312320 prioritize_sessions : bool ,
313321 max_concurrency : int ,
314322 process_id : int ,
315323 ):
316324 asyncio .run (self .backend .validate ())
317325 super ().process_loop_asynchronous (
318326 queues = queues ,
327+ stop_event = stop_event ,
319328 prioritize_sessions = prioritize_sessions ,
320329 max_concurrency = max_concurrency ,
321330 process_id = process_id ,
0 commit comments