diff --git a/lmdeploy/pytorch/engine/executor/base_worker.py b/lmdeploy/pytorch/engine/executor/base_worker.py index 93870c6bd..2cfa3dbd3 100644 --- a/lmdeploy/pytorch/engine/executor/base_worker.py +++ b/lmdeploy/pytorch/engine/executor/base_worker.py @@ -50,6 +50,7 @@ def __init__( logger.setLevel(log_level) self.out_que: asyncio.Queue = None self._output_loop: asyncio.Task = None + self._forward_event: asyncio.Event = None def init_process_group(self, rank: int, master_addr: str = None, master_port: str = None): """Initialize process group.""" @@ -136,7 +137,9 @@ def get_input_processor(self): def start(self): """Start engine loop.""" - self.model_agent.start() + self._forward_event = asyncio.Event() + self._forward_event.set() # Set the event to allow forward calls + self.model_agent.start(self._forward_event) event_loop = asyncio.get_event_loop() self.out_que = asyncio.Queue() self._output_loop = event_loop.create_task(self._get_outputs_loop(), name='GetOutputsLoop') @@ -148,6 +151,7 @@ def stop(self): self._output_loop.cancel() async def stop_async(self): + await self._forward_event.wait() # Ensure forward event is set before stopping await self.model_agent.stop_async() if self._output_loop is not None: self._output_loop.cancel()