diff --git a/exo/api/chatgpt_api.py b/exo/api/chatgpt_api.py index 1e3eef6cc..fae8c813c 100644 --- a/exo/api/chatgpt_api.py +++ b/exo/api/chatgpt_api.py @@ -268,9 +268,10 @@ async def handle_post_chat_completions(self, request): callback = self.node.on_token.register(callback_id) if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=} {image_str=}") - asyncio.create_task(self.node.process_prompt(shard, prompt, image_str, request_id=request_id)) try: + await asyncio.wait_for(self.node.process_prompt(shard, prompt, image_str, request_id=request_id), timeout=self.response_timeout) + if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout}s") if stream: @@ -284,9 +285,9 @@ async def handle_post_chat_completions(self, request): ) await response.prepare(request) - async def stream_result(request_id: str, tokens: List[int], is_finished: bool): - prev_last_tokens_len = self.prev_token_lens.get(request_id, 0) - self.prev_token_lens[request_id] = max(prev_last_tokens_len, len(tokens)) + async def stream_result(_request_id: str, tokens: List[int], is_finished: bool): + prev_last_tokens_len = self.prev_token_lens.get(_request_id, 0) + self.prev_token_lens[_request_id] = max(prev_last_tokens_len, len(tokens)) new_tokens = tokens[prev_last_tokens_len:] finish_reason = None eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") and isinstance(tokenizer._tokenizer, @@ -316,7 +317,7 @@ async def stream_result(request_id: str, tokens: List[int], is_finished: bool): if DEBUG >= 2: traceback.print_exc() def on_result(_request_id: str, tokens: List[int], is_finished: bool): - self.stream_tasks[request_id] = asyncio.create_task(stream_result(request_id, tokens, is_finished)) + if _request_id == request_id: self.stream_tasks[_request_id] = asyncio.create_task(stream_result(_request_id, tokens, is_finished)) return _request_id == request_id and is_finished @@ -345,6 +346,9 @@ def on_result(_request_id: str, tokens: List[int], is_finished: bool): return web.json_response(generate_completion(chat_request, tokenizer, prompt, request_id, tokens, stream, finish_reason, "chat.completion")) except asyncio.TimeoutError: return web.json_response({"detail": "Response generation timed out"}, status=408) + except Exception as e: + if DEBUG >= 2: traceback.print_exc() + return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500) finally: deregistered_callback = self.node.on_token.deregister(callback_id) if DEBUG >= 2: print(f"Deregister {callback_id=} {deregistered_callback=}")