From e742113a430cb7536b94877d1fe96aa6ec5e1bcf Mon Sep 17 00:00:00 2001 From: Nandan Date: Wed, 5 Feb 2025 18:28:09 +0530 Subject: [PATCH 1/6] fix: collect message from stream --- .../conversation/conversation_schema.py | 3 ++ .../conversation/conversation_service.py | 34 ++++++++++++++----- .../conversations/conversations_router.py | 15 ++++---- 3 files changed, 38 insertions(+), 14 deletions(-) diff --git a/app/modules/conversations/conversation/conversation_schema.py b/app/modules/conversations/conversation/conversation_schema.py index 535c5ec4..b8f685f8 100644 --- a/app/modules/conversations/conversation/conversation_schema.py +++ b/app/modules/conversations/conversation/conversation_schema.py @@ -46,6 +46,9 @@ class ConversationInfoResponse(BaseModel): class Config: from_attributes = True +class ChatMessageResponse(BaseModel): + message: str + citations: List[str] # Resolve forward references ConversationInfoResponse.update_forward_refs() diff --git a/app/modules/conversations/conversation/conversation_service.py b/app/modules/conversations/conversation/conversation_service.py index 51e2599d..d8ce690a 100644 --- a/app/modules/conversations/conversation/conversation_service.py +++ b/app/modules/conversations/conversation/conversation_service.py @@ -21,6 +21,7 @@ ConversationAccessType, ConversationInfoResponse, CreateConversationRequest, + ChatMessageResponse ) from app.modules.conversations.message.message_model import ( Message, @@ -46,6 +47,7 @@ from app.modules.projects.projects_service import ProjectService from app.modules.users.user_service import UserService from app.modules.utils.posthog_helper import PostHogClient +import json logger = logging.getLogger(__name__) @@ -487,12 +489,20 @@ async def store_message( ): yield chunk else: - # For non-streaming, collect all chunks and store as a single message - full_response = "" + full_message = "" + all_citations = [] async for chunk in self._generate_and_stream_ai_response( message.content, conversation_id, user_id, message.node_ids ): - full_response += chunk + data = json.loads(chunk) + + # Extract the 'message' and 'citations' + message: str = data.get('message', '') + citations: List[str] = data.get('citations', []) + + full_message += message + all_citations = all_citations + citations + # # Store the complete response as a single message # self.history_manager.add_message_chunk( # conversation_id, full_response, MessageType.AI, user_id @@ -500,7 +510,7 @@ async def store_message( # self.history_manager.flush_message_buffer( # conversation_id, MessageType.AI, user_id # ) - yield full_response + yield ChatMessageResponse(message=full_message,citations=all_citations).json() except AccessTypeReadError: raise @@ -595,12 +605,20 @@ async def regenerate_last_message( ): yield chunk else: - # For non-streaming, collect all chunks and store as a single message - full_response = "" + full_message = "" + all_citations = [] + async for chunk in self._generate_and_stream_ai_response( last_human_message.content, conversation_id, user_id, node_ids ): - full_response += chunk + data = json.loads(chunk) + + # Extract the 'message' and 'citations' + message: str = data.get('message', '') + citations: List[str] = data.get('citations', []) + + full_message += message + all_citations = all_citations + citations # # Store the complete response as a single message # self.history_manager.add_message_chunk( # conversation_id, full_response, MessageType.AI, user_id @@ -608,7 +626,7 @@ async def regenerate_last_message( # self.history_manager.flush_message_buffer( # conversation_id, MessageType.AI, user_id # ) - yield full_response + yield ChatMessageResponse(message=full_message,citations=all_citations).json() except AccessTypeReadError: raise diff --git a/app/modules/conversations/conversations_router.py b/app/modules/conversations/conversations_router.py index 0a6bdefd..df5e0915 100644 --- a/app/modules/conversations/conversations_router.py +++ b/app/modules/conversations/conversations_router.py @@ -1,7 +1,7 @@ from typing import List from fastapi import APIRouter, Depends, HTTPException, Query -from fastapi.responses import StreamingResponse +from fastapi.responses import StreamingResponse, JSONResponse from sqlalchemy.orm import Session from app.core.database import get_db @@ -26,6 +26,7 @@ RenameConversationRequest, ) from .message.message_schema import MessageRequest, MessageResponse, RegenerateRequest +import json router = APIRouter() @@ -100,11 +101,12 @@ async def post_message( if stream: return StreamingResponse(message_stream, media_type="text/event-stream") else: - # Collect all chunks into a complete response + # TODO: fix this, add types. In below stream we have only one output. + # no need of stream here full_response = "" async for chunk in message_stream: full_response += chunk - return {"content": full_response} + return json.loads(full_response) @staticmethod @router.post( @@ -126,11 +128,12 @@ async def regenerate_last_message( if stream: return StreamingResponse(message_stream, media_type="text/event-stream") else: - # Collect all chunks into a complete response + # TODO: fix this, add types. In below stream we have only one output. + # no need of stream here full_response = "" async for chunk in message_stream: full_response += chunk - return {"content": full_response} + return json.loads(full_response) @staticmethod @router.delete("/conversations/{conversation_id}/", response_model=dict) @@ -222,4 +225,4 @@ async def remove_access( ) return {"message": "Access removed successfully"} except ShareChatServiceError as e: - raise HTTPException(status_code=400, detail=str(e)) + raise HTTPException(status_code=400, detail=str(e)) \ No newline at end of file From e5c2c220d52b27bde91edce56d60e12398b066c7 Mon Sep 17 00:00:00 2001 From: Nandan Date: Thu, 6 Feb 2025 10:16:03 +0530 Subject: [PATCH 2/6] chore: add type for chat streaming --- .../conversation/conversation_controller.py | 26 +++++--------- .../conversation/conversation_schema.py | 5 ++- .../conversations/conversations_router.py | 35 ++++++++----------- 3 files changed, 27 insertions(+), 39 deletions(-) diff --git a/app/modules/conversations/conversation/conversation_controller.py b/app/modules/conversations/conversation/conversation_controller.py index 3e7e8b6a..8fac703f 100644 --- a/app/modules/conversations/conversation/conversation_controller.py +++ b/app/modules/conversations/conversation/conversation_controller.py @@ -4,23 +4,15 @@ from sqlalchemy.orm import Session from app.modules.conversations.conversation.conversation_schema import ( - ConversationInfoResponse, - CreateConversationRequest, - CreateConversationResponse, -) + ChatMessageResponse, ConversationInfoResponse, CreateConversationRequest, + CreateConversationResponse) from app.modules.conversations.conversation.conversation_service import ( - AccessTypeNotFoundError, - AccessTypeReadError, - ConversationNotFoundError, - ConversationService, - ConversationServiceError, -) + AccessTypeNotFoundError, AccessTypeReadError, ConversationNotFoundError, + ConversationService, ConversationServiceError) from app.modules.conversations.message.message_model import MessageType -from app.modules.conversations.message.message_schema import ( - MessageRequest, - MessageResponse, - NodeContext, -) +from app.modules.conversations.message.message_schema import (MessageRequest, + MessageResponse, + NodeContext) class ConversationController: @@ -82,7 +74,7 @@ async def get_conversation_messages( async def post_message( self, conversation_id: str, message: MessageRequest, stream: bool = True - ) -> AsyncGenerator[str, None]: + ) -> AsyncGenerator[ChatMessageResponse, None]: try: async for chunk in self.service.store_message( conversation_id, message, MessageType.HUMAN, self.user_id, stream @@ -100,7 +92,7 @@ async def regenerate_last_message( conversation_id: str, node_ids: List[NodeContext] = [], stream: bool = True, - ) -> AsyncGenerator[str, None]: + ) -> AsyncGenerator[ChatMessageResponse, None]: try: async for chunk in self.service.regenerate_last_message( conversation_id, self.user_id, node_ids, stream diff --git a/app/modules/conversations/conversation/conversation_schema.py b/app/modules/conversations/conversation/conversation_schema.py index b8f685f8..93f67198 100644 --- a/app/modules/conversations/conversation/conversation_schema.py +++ b/app/modules/conversations/conversation/conversation_schema.py @@ -4,7 +4,8 @@ from pydantic import BaseModel -from app.modules.conversations.conversation.conversation_model import ConversationStatus +from app.modules.conversations.conversation.conversation_model import \ + ConversationStatus class CreateConversationRequest(BaseModel): @@ -46,10 +47,12 @@ class ConversationInfoResponse(BaseModel): class Config: from_attributes = True + class ChatMessageResponse(BaseModel): message: str citations: List[str] + # Resolve forward references ConversationInfoResponse.update_forward_refs() diff --git a/app/modules/conversations/conversations_router.py b/app/modules/conversations/conversations_router.py index df5e0915..b0acbf32 100644 --- a/app/modules/conversations/conversations_router.py +++ b/app/modules/conversations/conversations_router.py @@ -1,32 +1,25 @@ +import json from typing import List from fastapi import APIRouter, Depends, HTTPException, Query -from fastapi.responses import StreamingResponse, JSONResponse +from fastapi.responses import StreamingResponse from sqlalchemy.orm import Session from app.core.database import get_db from app.modules.auth.auth_service import AuthService from app.modules.conversations.access.access_schema import ( - RemoveAccessRequest, - ShareChatRequest, - ShareChatResponse, -) + RemoveAccessRequest, ShareChatRequest, ShareChatResponse) from app.modules.conversations.access.access_service import ( - ShareChatService, - ShareChatServiceError, -) -from app.modules.conversations.conversation.conversation_controller import ( - ConversationController, -) - -from .conversation.conversation_schema import ( - ConversationInfoResponse, - CreateConversationRequest, - CreateConversationResponse, - RenameConversationRequest, -) -from .message.message_schema import MessageRequest, MessageResponse, RegenerateRequest -import json + ShareChatService, ShareChatServiceError) +from app.modules.conversations.conversation.conversation_controller import \ + ConversationController + +from .conversation.conversation_schema import (ConversationInfoResponse, + CreateConversationRequest, + CreateConversationResponse, + RenameConversationRequest) +from .message.message_schema import (MessageRequest, MessageResponse, + RegenerateRequest) router = APIRouter() @@ -225,4 +218,4 @@ async def remove_access( ) return {"message": "Access removed successfully"} except ShareChatServiceError as e: - raise HTTPException(status_code=400, detail=str(e)) \ No newline at end of file + raise HTTPException(status_code=400, detail=str(e)) From 291fc9f4c219e1f0871629175d3eb2edbaf4ad72 Mon Sep 17 00:00:00 2001 From: Nandan Date: Thu, 6 Feb 2025 10:17:22 +0530 Subject: [PATCH 3/6] chore: add type for chat streaming --- .../conversation/conversation_service.py | 123 ++++++++++-------- 1 file changed, 68 insertions(+), 55 deletions(-) diff --git a/app/modules/conversations/conversation/conversation_service.py b/app/modules/conversations/conversation/conversation_service.py index d8ce690a..0fa553b3 100644 --- a/app/modules/conversations/conversation/conversation_service.py +++ b/app/modules/conversations/conversation/conversation_service.py @@ -1,4 +1,5 @@ import asyncio +import json import logging from datetime import datetime, timezone from typing import Any, AsyncGenerator, Dict, List, Optional, TypedDict @@ -13,41 +14,29 @@ from app.modules.code_provider.code_provider_service import CodeProviderService from app.modules.conversations.conversation.conversation_model import ( - Conversation, - ConversationStatus, - Visibility, -) + Conversation, ConversationStatus, Visibility) from app.modules.conversations.conversation.conversation_schema import ( - ConversationAccessType, - ConversationInfoResponse, - CreateConversationRequest, - ChatMessageResponse -) -from app.modules.conversations.message.message_model import ( - Message, - MessageStatus, - MessageType, -) -from app.modules.conversations.message.message_schema import ( - MessageRequest, - MessageResponse, - NodeContext, -) + ChatMessageResponse, ConversationAccessType, ConversationInfoResponse, + CreateConversationRequest) +from app.modules.conversations.message.message_model import (Message, + MessageStatus, + MessageType) +from app.modules.conversations.message.message_schema import (MessageRequest, + MessageResponse, + NodeContext) from app.modules.intelligence.agents.agent_factory import AgentFactory -from app.modules.intelligence.agents.agent_injector_service import AgentInjectorService +from app.modules.intelligence.agents.agent_injector_service import \ + AgentInjectorService from app.modules.intelligence.agents.agents_service import AgentsService -from app.modules.intelligence.agents.custom_agents.custom_agents_service import ( - CustomAgentsService, -) -from app.modules.intelligence.memory.chat_history_service import ChatHistoryService +from app.modules.intelligence.agents.custom_agents.custom_agents_service import \ + CustomAgentsService +from app.modules.intelligence.memory.chat_history_service import \ + ChatHistoryService from app.modules.intelligence.provider.provider_service import ( - AgentType, - ProviderService, -) + AgentType, ProviderService) from app.modules.projects.projects_service import ProjectService from app.modules.users.user_service import UserService from app.modules.utils.posthog_helper import PostHogClient -import json logger = logging.getLogger(__name__) @@ -77,7 +66,7 @@ def __init__(self, db, provider_service): self.db = db self.provider_service = provider_service self.agent = None - self.current_agent_id = None + self.current_agent_id = None self.classifier = None self.agents_service = AgentsService(db) self.agent_factory = AgentFactory(db, provider_service) @@ -150,17 +139,24 @@ async def classifier_node(self, state: State) -> Command: return Command(update={"response": "No query provided"}, goto=END) agent_list = {agent.id: agent.status for agent in self.available_agents} - + # First check - if this is a custom agent (non-SYSTEM), route directly - if state["agent_id"] in agent_list and agent_list[state["agent_id"]] != "SYSTEM": + if ( + state["agent_id"] in agent_list + and agent_list[state["agent_id"]] != "SYSTEM" + ): # Initialize the agent if needed if not self.agent or self.current_agent_id != state["agent_id"]: try: - self.agent = self.agent_factory.get_agent(state["agent_id"], state["user_id"]) + self.agent = self.agent_factory.get_agent( + state["agent_id"], state["user_id"] + ) self.current_agent_id = state["agent_id"] except Exception as e: logger.error(f"Failed to create agent {state['agent_id']}: {e}") - return Command(update={"response": "Failed to initialize agent"}, goto=END) + return Command( + update={"response": "Failed to initialize agent"}, goto=END + ) return Command(update={"agent_id": state["agent_id"]}, goto="agent_node") # For system agents, perform classification @@ -169,13 +165,17 @@ async def classifier_node(self, state: State) -> Command: agent_id=state["agent_id"], agent_descriptions=self.agent_descriptions, ) - + response = await self.llm.ainvoke(prompt) response = response.content.strip("`") try: agent_id, confidence = response.split("|") confidence = float(confidence) - selected_agent_id = agent_id if confidence >= 0.5 and agent_id in agent_list else state["agent_id"] + selected_agent_id = ( + agent_id + if confidence >= 0.5 and agent_id in agent_list + else state["agent_id"] + ) except (ValueError, TypeError): logger.error("Classification format error, falling back to current agent") selected_agent_id = state["agent_id"] @@ -183,11 +183,15 @@ async def classifier_node(self, state: State) -> Command: # Initialize the selected system agent if not self.agent or self.current_agent_id != selected_agent_id: try: - self.agent = self.agent_factory.get_agent(selected_agent_id, state["user_id"]) + self.agent = self.agent_factory.get_agent( + selected_agent_id, state["user_id"] + ) self.current_agent_id = selected_agent_id except Exception as e: logger.error(f"Failed to create agent {selected_agent_id}: {e}") - return Command(update={"response": "Failed to initialize agent"}, goto=END) + return Command( + update={"response": "Failed to initialize agent"}, goto=END + ) logger.info( f"Streaming AI response for conversation {state['conversation_id']} " @@ -200,7 +204,7 @@ async def agent_node(self, state: State, writer: StreamWriter): if not self.agent: logger.error("Agent not initialized before agent_node execution") return Command(update={"response": "Agent not initialized"}, goto=END) - + try: async for chunk in self.agent.run( query=state["query"], @@ -438,7 +442,7 @@ async def store_message( message_type: MessageType, user_id: str, stream: bool = True, - ) -> AsyncGenerator[str, None]: + ) -> AsyncGenerator[ChatMessageResponse, None]: try: access_level = await self.check_conversation_access( conversation_id, self.user_email @@ -494,15 +498,11 @@ async def store_message( async for chunk in self._generate_and_stream_ai_response( message.content, conversation_id, user_id, message.node_ids ): - data = json.loads(chunk) - # Extract the 'message' and 'citations' - message: str = data.get('message', '') - citations: List[str] = data.get('citations', []) + full_message += chunk.message + all_citations = all_citations + chunk.citations - full_message += message - all_citations = all_citations + citations - + # TODO: what is this below comment for? # # Store the complete response as a single message # self.history_manager.add_message_chunk( # conversation_id, full_response, MessageType.AI, user_id @@ -510,7 +510,9 @@ async def store_message( # self.history_manager.flush_message_buffer( # conversation_id, MessageType.AI, user_id # ) - yield ChatMessageResponse(message=full_message,citations=all_citations).json() + yield ChatMessageResponse( + message=full_message, citations=all_citations + ) except AccessTypeReadError: raise @@ -579,7 +581,7 @@ async def regenerate_last_message( user_id: str, node_ids: List[NodeContext] = [], stream: bool = True, - ) -> AsyncGenerator[str, None]: + ) -> AsyncGenerator[ChatMessageResponse, None]: try: access_level = await self.check_conversation_access( conversation_id, self.user_email @@ -607,15 +609,15 @@ async def regenerate_last_message( else: full_message = "" all_citations = [] - + async for chunk in self._generate_and_stream_ai_response( last_human_message.content, conversation_id, user_id, node_ids ): data = json.loads(chunk) # Extract the 'message' and 'citations' - message: str = data.get('message', '') - citations: List[str] = data.get('citations', []) + message: str = data.get("message", "") + citations: List[str] = data.get("citations", []) full_message += message all_citations = all_citations + citations @@ -626,7 +628,9 @@ async def regenerate_last_message( # self.history_manager.flush_message_buffer( # conversation_id, MessageType.AI, user_id # ) - yield ChatMessageResponse(message=full_message,citations=all_citations).json() + yield ChatMessageResponse( + message=full_message, citations=all_citations + ).json() except AccessTypeReadError: raise @@ -677,13 +681,22 @@ async def _archive_subsequent_messages( "Failed to archive subsequent messages." ) from e + def parse_str_to_message(self, chunk: str) -> ChatMessageResponse: + data = json.loads(chunk) + + # Extract the 'message' and 'citations' + message: str = data.get("message", "") + citations: List[str] = data.get("citations", []) + + return ChatMessageResponse(message=message, citations=citations) + async def _generate_and_stream_ai_response( self, query: str, conversation_id: str, user_id: str, node_ids: List[NodeContext], - ) -> AsyncGenerator[str, None]: + ) -> AsyncGenerator[ChatMessageResponse, None]: conversation = ( self.sql_db.query(Conversation).filter_by(id=conversation_id).first() ) @@ -708,13 +721,13 @@ async def _generate_and_stream_ai_response( response = await agent.run( agent_id, query, project_id, user_id, conversation.id, node_ids ) - yield response + yield self.parse_str_to_message(response) else: # For other agents that support streaming async for chunk in supervisor.process_query( query, project_id, conversation.id, user_id, node_ids, agent_id ): - yield chunk + yield self.parse_str_to_message(chunk) logger.info( f"Generated and streamed AI response for conversation {conversation.id} for user {user_id} using agent {agent_id}" From 30305e557356fb5ab540003933fc94c0132be1d8 Mon Sep 17 00:00:00 2001 From: Nandan Date: Thu, 6 Feb 2025 10:47:37 +0530 Subject: [PATCH 4/6] fix: generate byte stream --- .../conversation/conversation_controller.py | 23 +++++--- .../conversation/conversation_schema.py | 3 +- .../conversation/conversation_service.py | 57 +++++++++--------- .../conversations/conversations_router.py | 58 +++++++++++-------- 4 files changed, 80 insertions(+), 61 deletions(-) diff --git a/app/modules/conversations/conversation/conversation_controller.py b/app/modules/conversations/conversation/conversation_controller.py index 8fac703f..3337fe31 100644 --- a/app/modules/conversations/conversation/conversation_controller.py +++ b/app/modules/conversations/conversation/conversation_controller.py @@ -4,15 +4,24 @@ from sqlalchemy.orm import Session from app.modules.conversations.conversation.conversation_schema import ( - ChatMessageResponse, ConversationInfoResponse, CreateConversationRequest, - CreateConversationResponse) + ChatMessageResponse, + ConversationInfoResponse, + CreateConversationRequest, + CreateConversationResponse, +) from app.modules.conversations.conversation.conversation_service import ( - AccessTypeNotFoundError, AccessTypeReadError, ConversationNotFoundError, - ConversationService, ConversationServiceError) + AccessTypeNotFoundError, + AccessTypeReadError, + ConversationNotFoundError, + ConversationService, + ConversationServiceError, +) from app.modules.conversations.message.message_model import MessageType -from app.modules.conversations.message.message_schema import (MessageRequest, - MessageResponse, - NodeContext) +from app.modules.conversations.message.message_schema import ( + MessageRequest, + MessageResponse, + NodeContext, +) class ConversationController: diff --git a/app/modules/conversations/conversation/conversation_schema.py b/app/modules/conversations/conversation/conversation_schema.py index 93f67198..6cca8862 100644 --- a/app/modules/conversations/conversation/conversation_schema.py +++ b/app/modules/conversations/conversation/conversation_schema.py @@ -4,8 +4,7 @@ from pydantic import BaseModel -from app.modules.conversations.conversation.conversation_model import \ - ConversationStatus +from app.modules.conversations.conversation.conversation_model import ConversationStatus class CreateConversationRequest(BaseModel): diff --git a/app/modules/conversations/conversation/conversation_service.py b/app/modules/conversations/conversation/conversation_service.py index 0fa553b3..44f55595 100644 --- a/app/modules/conversations/conversation/conversation_service.py +++ b/app/modules/conversations/conversation/conversation_service.py @@ -14,26 +14,37 @@ from app.modules.code_provider.code_provider_service import CodeProviderService from app.modules.conversations.conversation.conversation_model import ( - Conversation, ConversationStatus, Visibility) + Conversation, + ConversationStatus, + Visibility, +) from app.modules.conversations.conversation.conversation_schema import ( - ChatMessageResponse, ConversationAccessType, ConversationInfoResponse, - CreateConversationRequest) -from app.modules.conversations.message.message_model import (Message, - MessageStatus, - MessageType) -from app.modules.conversations.message.message_schema import (MessageRequest, - MessageResponse, - NodeContext) + ChatMessageResponse, + ConversationAccessType, + ConversationInfoResponse, + CreateConversationRequest, +) +from app.modules.conversations.message.message_model import ( + Message, + MessageStatus, + MessageType, +) +from app.modules.conversations.message.message_schema import ( + MessageRequest, + MessageResponse, + NodeContext, +) from app.modules.intelligence.agents.agent_factory import AgentFactory -from app.modules.intelligence.agents.agent_injector_service import \ - AgentInjectorService +from app.modules.intelligence.agents.agent_injector_service import AgentInjectorService from app.modules.intelligence.agents.agents_service import AgentsService -from app.modules.intelligence.agents.custom_agents.custom_agents_service import \ - CustomAgentsService -from app.modules.intelligence.memory.chat_history_service import \ - ChatHistoryService +from app.modules.intelligence.agents.custom_agents.custom_agents_service import ( + CustomAgentsService, +) +from app.modules.intelligence.memory.chat_history_service import ChatHistoryService from app.modules.intelligence.provider.provider_service import ( - AgentType, ProviderService) + AgentType, + ProviderService, +) from app.modules.projects.projects_service import ProjectService from app.modules.users.user_service import UserService from app.modules.utils.posthog_helper import PostHogClient @@ -613,14 +624,8 @@ async def regenerate_last_message( async for chunk in self._generate_and_stream_ai_response( last_human_message.content, conversation_id, user_id, node_ids ): - data = json.loads(chunk) - - # Extract the 'message' and 'citations' - message: str = data.get("message", "") - citations: List[str] = data.get("citations", []) - - full_message += message - all_citations = all_citations + citations + full_message += chunk.message + all_citations = all_citations + chunk.citations # # Store the complete response as a single message # self.history_manager.add_message_chunk( # conversation_id, full_response, MessageType.AI, user_id @@ -628,9 +633,7 @@ async def regenerate_last_message( # self.history_manager.flush_message_buffer( # conversation_id, MessageType.AI, user_id # ) - yield ChatMessageResponse( - message=full_message, citations=all_citations - ).json() + yield ChatMessageResponse(message=full_message, citations=all_citations) except AccessTypeReadError: raise diff --git a/app/modules/conversations/conversations_router.py b/app/modules/conversations/conversations_router.py index b0acbf32..496e1aff 100644 --- a/app/modules/conversations/conversations_router.py +++ b/app/modules/conversations/conversations_router.py @@ -1,6 +1,4 @@ -import json -from typing import List - +from typing import List, AsyncGenerator, Any from fastapi import APIRouter, Depends, HTTPException, Query from fastapi.responses import StreamingResponse from sqlalchemy.orm import Session @@ -8,22 +6,35 @@ from app.core.database import get_db from app.modules.auth.auth_service import AuthService from app.modules.conversations.access.access_schema import ( - RemoveAccessRequest, ShareChatRequest, ShareChatResponse) + RemoveAccessRequest, + ShareChatRequest, + ShareChatResponse, +) from app.modules.conversations.access.access_service import ( - ShareChatService, ShareChatServiceError) -from app.modules.conversations.conversation.conversation_controller import \ - ConversationController - -from .conversation.conversation_schema import (ConversationInfoResponse, - CreateConversationRequest, - CreateConversationResponse, - RenameConversationRequest) -from .message.message_schema import (MessageRequest, MessageResponse, - RegenerateRequest) + ShareChatService, + ShareChatServiceError, +) +from app.modules.conversations.conversation.conversation_controller import ( + ConversationController, +) + +from .conversation.conversation_schema import ( + ConversationInfoResponse, + CreateConversationRequest, + CreateConversationResponse, + RenameConversationRequest, +) +from .message.message_schema import MessageRequest, MessageResponse, RegenerateRequest +import json router = APIRouter() +async def get_stream(data_stream: AsyncGenerator[Any, None]): + async for chunk in data_stream: + yield json.dumps(chunk.dict()) + + class ConversationAPI: @staticmethod @router.post("/conversations/", response_model=CreateConversationResponse) @@ -92,14 +103,13 @@ async def post_message( controller = ConversationController(db, user_id, user_email) message_stream = controller.post_message(conversation_id, message, stream) if stream: - return StreamingResponse(message_stream, media_type="text/event-stream") + return StreamingResponse( + get_stream(message_stream), media_type="text/event-stream" + ) else: # TODO: fix this, add types. In below stream we have only one output. - # no need of stream here - full_response = "" async for chunk in message_stream: - full_response += chunk - return json.loads(full_response) + return chunk @staticmethod @router.post( @@ -119,14 +129,12 @@ async def regenerate_last_message( conversation_id, request.node_ids, stream ) if stream: - return StreamingResponse(message_stream, media_type="text/event-stream") + return StreamingResponse( + get_stream(message_stream), media_type="text/event-stream" + ) else: - # TODO: fix this, add types. In below stream we have only one output. - # no need of stream here - full_response = "" async for chunk in message_stream: - full_response += chunk - return json.loads(full_response) + return chunk @staticmethod @router.delete("/conversations/{conversation_id}/", response_model=dict) From 45c3988524ac3ab27875d1ec4f7c07255145bf63 Mon Sep 17 00:00:00 2001 From: Nandan Date: Thu, 6 Feb 2025 16:28:54 +0530 Subject: [PATCH 5/6] chore: handle json parsing error case --- .../conversations/conversation/conversation_service.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/app/modules/conversations/conversation/conversation_service.py b/app/modules/conversations/conversation/conversation_service.py index 44f55595..0fb4ac46 100644 --- a/app/modules/conversations/conversation/conversation_service.py +++ b/app/modules/conversations/conversation/conversation_service.py @@ -685,7 +685,11 @@ async def _archive_subsequent_messages( ) from e def parse_str_to_message(self, chunk: str) -> ChatMessageResponse: - data = json.loads(chunk) + try: + data = json.loads(chunk) + except json.JSONDecodeError as e: + logger.error(f"Failed to parse chunk as JSON: {e}") + raise ConversationServiceError("Failed to parse AI response") from e # Extract the 'message' and 'citations' message: str = data.get("message", "") From 2854195ebccccc64ff40677bab865edf7b020290 Mon Sep 17 00:00:00 2001 From: Nandan Date: Tue, 11 Feb 2025 17:20:15 +0530 Subject: [PATCH 6/6] feat: add direct messaging api --- app/api/router.py | 45 ++++++++++++++++++- .../conversations/message/message_schema.py | 6 +++ 2 files changed, 49 insertions(+), 2 deletions(-) diff --git a/app/api/router.py b/app/api/router.py index e7c42258..c34005e7 100644 --- a/app/api/router.py +++ b/app/api/router.py @@ -16,7 +16,10 @@ CreateConversationRequest, CreateConversationResponse, ) -from app.modules.conversations.message.message_schema import MessageRequest +from app.modules.conversations.message.message_schema import ( + MessageRequest, + DirectMessageRequest, +) from app.modules.parsing.graph_construction.parsing_controller import ParsingController from app.modules.parsing.graph_construction.parsing_schema import ParsingRequest from app.modules.utils.APIRouter import APIRouter @@ -103,4 +106,42 @@ async def post_message( # Note: email is no longer available with API key auth controller = ConversationController(db, user_id, None) message_stream = controller.post_message(conversation_id, message, stream=False) - return StreamingResponse(message_stream, media_type="text/event-stream") + async for chunk in message_stream: + return chunk + + +@router.post("/project/{project_id}/message/") +async def create_conversation_and_message( + project_id: str, + message: DirectMessageRequest, + db: Session = Depends(get_db), + user=Depends(get_api_key_user), +): + if message.content == "" or message.content is None or message.content.isspace(): + raise HTTPException(status_code=400, detail="Message content cannot be empty") + + user_id = user["user_id"] + + # default agent_id to codebase_qna_agent + if message.agent_id is None: + message.agent_id = "codebase_qna_agent" + + controller = ConversationController(db, user_id, None) + res = await controller.create_conversation( + CreateConversationRequest( + user_id=user_id, + title=message.content, + project_ids=[project_id], + agent_ids=[message.agent_id], + status=ConversationStatus.ACTIVE, + ) + ) + + message_stream = controller.post_message( + conversation_id=res.conversation_id, + message=MessageRequest(content=message.content, node_ids=message.node_ids), + stream=False, + ) + + async for chunk in message_stream: + return chunk diff --git a/app/modules/conversations/message/message_schema.py b/app/modules/conversations/message/message_schema.py index 0535438c..2df8ad82 100644 --- a/app/modules/conversations/message/message_schema.py +++ b/app/modules/conversations/message/message_schema.py @@ -16,6 +16,12 @@ class MessageRequest(BaseModel): node_ids: Optional[List[NodeContext]] = None +class DirectMessageRequest(BaseModel): + content: str + node_ids: Optional[List[NodeContext]] = None + agent_id: str | None = None + + class RegenerateRequest(BaseModel): node_ids: Optional[List[NodeContext]] = None