diff --git a/app/api/router.py b/app/api/router.py index e7c42258..c34005e7 100644 --- a/app/api/router.py +++ b/app/api/router.py @@ -16,7 +16,10 @@ CreateConversationRequest, CreateConversationResponse, ) -from app.modules.conversations.message.message_schema import MessageRequest +from app.modules.conversations.message.message_schema import ( + MessageRequest, + DirectMessageRequest, +) from app.modules.parsing.graph_construction.parsing_controller import ParsingController from app.modules.parsing.graph_construction.parsing_schema import ParsingRequest from app.modules.utils.APIRouter import APIRouter @@ -103,4 +106,42 @@ async def post_message( # Note: email is no longer available with API key auth controller = ConversationController(db, user_id, None) message_stream = controller.post_message(conversation_id, message, stream=False) - return StreamingResponse(message_stream, media_type="text/event-stream") + async for chunk in message_stream: + return chunk + + +@router.post("/project/{project_id}/message/") +async def create_conversation_and_message( + project_id: str, + message: DirectMessageRequest, + db: Session = Depends(get_db), + user=Depends(get_api_key_user), +): + if message.content == "" or message.content is None or message.content.isspace(): + raise HTTPException(status_code=400, detail="Message content cannot be empty") + + user_id = user["user_id"] + + # default agent_id to codebase_qna_agent + if message.agent_id is None: + message.agent_id = "codebase_qna_agent" + + controller = ConversationController(db, user_id, None) + res = await controller.create_conversation( + CreateConversationRequest( + user_id=user_id, + title=message.content, + project_ids=[project_id], + agent_ids=[message.agent_id], + status=ConversationStatus.ACTIVE, + ) + ) + + message_stream = controller.post_message( + conversation_id=res.conversation_id, + message=MessageRequest(content=message.content, node_ids=message.node_ids), + stream=False, + ) + + async for chunk in 
message_stream: + return chunk diff --git a/app/modules/conversations/conversation/conversation_controller.py b/app/modules/conversations/conversation/conversation_controller.py index 3e7e8b6a..3337fe31 100644 --- a/app/modules/conversations/conversation/conversation_controller.py +++ b/app/modules/conversations/conversation/conversation_controller.py @@ -4,6 +4,7 @@ from sqlalchemy.orm import Session from app.modules.conversations.conversation.conversation_schema import ( + ChatMessageResponse, ConversationInfoResponse, CreateConversationRequest, CreateConversationResponse, @@ -82,7 +83,7 @@ async def get_conversation_messages( async def post_message( self, conversation_id: str, message: MessageRequest, stream: bool = True - ) -> AsyncGenerator[str, None]: + ) -> AsyncGenerator[ChatMessageResponse, None]: try: async for chunk in self.service.store_message( conversation_id, message, MessageType.HUMAN, self.user_id, stream @@ -100,7 +101,7 @@ async def regenerate_last_message( conversation_id: str, node_ids: List[NodeContext] = [], stream: bool = True, - ) -> AsyncGenerator[str, None]: + ) -> AsyncGenerator[ChatMessageResponse, None]: try: async for chunk in self.service.regenerate_last_message( conversation_id, self.user_id, node_ids, stream diff --git a/app/modules/conversations/conversation/conversation_schema.py b/app/modules/conversations/conversation/conversation_schema.py index 535c5ec4..6cca8862 100644 --- a/app/modules/conversations/conversation/conversation_schema.py +++ b/app/modules/conversations/conversation/conversation_schema.py @@ -47,6 +47,11 @@ class Config: from_attributes = True +class ChatMessageResponse(BaseModel): + message: str + citations: List[str] + + # Resolve forward references ConversationInfoResponse.update_forward_refs() diff --git a/app/modules/conversations/conversation/conversation_service.py b/app/modules/conversations/conversation/conversation_service.py index 51e2599d..0fb4ac46 100644 --- 
a/app/modules/conversations/conversation/conversation_service.py +++ b/app/modules/conversations/conversation/conversation_service.py @@ -1,4 +1,5 @@ import asyncio +import json import logging from datetime import datetime, timezone from typing import Any, AsyncGenerator, Dict, List, Optional, TypedDict @@ -18,6 +19,7 @@ Visibility, ) from app.modules.conversations.conversation.conversation_schema import ( + ChatMessageResponse, ConversationAccessType, ConversationInfoResponse, CreateConversationRequest, @@ -75,7 +77,7 @@ def __init__(self, db, provider_service): self.db = db self.provider_service = provider_service self.agent = None - self.current_agent_id = None + self.current_agent_id = None self.classifier = None self.agents_service = AgentsService(db) self.agent_factory = AgentFactory(db, provider_service) @@ -148,17 +150,24 @@ async def classifier_node(self, state: State) -> Command: return Command(update={"response": "No query provided"}, goto=END) agent_list = {agent.id: agent.status for agent in self.available_agents} - + # First check - if this is a custom agent (non-SYSTEM), route directly - if state["agent_id"] in agent_list and agent_list[state["agent_id"]] != "SYSTEM": + if ( + state["agent_id"] in agent_list + and agent_list[state["agent_id"]] != "SYSTEM" + ): # Initialize the agent if needed if not self.agent or self.current_agent_id != state["agent_id"]: try: - self.agent = self.agent_factory.get_agent(state["agent_id"], state["user_id"]) + self.agent = self.agent_factory.get_agent( + state["agent_id"], state["user_id"] + ) self.current_agent_id = state["agent_id"] except Exception as e: logger.error(f"Failed to create agent {state['agent_id']}: {e}") - return Command(update={"response": "Failed to initialize agent"}, goto=END) + return Command( + update={"response": "Failed to initialize agent"}, goto=END + ) return Command(update={"agent_id": state["agent_id"]}, goto="agent_node") # For system agents, perform classification @@ -167,13 +176,17 @@ 
async def classifier_node(self, state: State) -> Command: agent_id=state["agent_id"], agent_descriptions=self.agent_descriptions, ) - + response = await self.llm.ainvoke(prompt) response = response.content.strip("`") try: agent_id, confidence = response.split("|") confidence = float(confidence) - selected_agent_id = agent_id if confidence >= 0.5 and agent_id in agent_list else state["agent_id"] + selected_agent_id = ( + agent_id + if confidence >= 0.5 and agent_id in agent_list + else state["agent_id"] + ) except (ValueError, TypeError): logger.error("Classification format error, falling back to current agent") selected_agent_id = state["agent_id"] @@ -181,11 +194,15 @@ async def classifier_node(self, state: State) -> Command: # Initialize the selected system agent if not self.agent or self.current_agent_id != selected_agent_id: try: - self.agent = self.agent_factory.get_agent(selected_agent_id, state["user_id"]) + self.agent = self.agent_factory.get_agent( + selected_agent_id, state["user_id"] + ) self.current_agent_id = selected_agent_id except Exception as e: logger.error(f"Failed to create agent {selected_agent_id}: {e}") - return Command(update={"response": "Failed to initialize agent"}, goto=END) + return Command( + update={"response": "Failed to initialize agent"}, goto=END + ) logger.info( f"Streaming AI response for conversation {state['conversation_id']} " @@ -198,7 +215,7 @@ async def agent_node(self, state: State, writer: StreamWriter): if not self.agent: logger.error("Agent not initialized before agent_node execution") return Command(update={"response": "Agent not initialized"}, goto=END) - + try: async for chunk in self.agent.run( query=state["query"], @@ -436,7 +453,7 @@ async def store_message( message_type: MessageType, user_id: str, stream: bool = True, - ) -> AsyncGenerator[str, None]: + ) -> AsyncGenerator[ChatMessageResponse, None]: try: access_level = await self.check_conversation_access( conversation_id, self.user_email @@ -487,12 +504,16 @@ 
async def store_message( ): yield chunk else: - # For non-streaming, collect all chunks and store as a single message - full_response = "" + full_message = "" + all_citations = [] async for chunk in self._generate_and_stream_ai_response( message.content, conversation_id, user_id, message.node_ids ): - full_response += chunk + + full_message += chunk.message + all_citations = all_citations + chunk.citations + + # TODO: what is this below comment for? # # Store the complete response as a single message # self.history_manager.add_message_chunk( # conversation_id, full_response, MessageType.AI, user_id @@ -500,7 +521,9 @@ async def store_message( # self.history_manager.flush_message_buffer( # conversation_id, MessageType.AI, user_id # ) - yield full_response + yield ChatMessageResponse( + message=full_message, citations=all_citations + ) except AccessTypeReadError: raise @@ -569,7 +592,7 @@ async def regenerate_last_message( user_id: str, node_ids: List[NodeContext] = [], stream: bool = True, - ) -> AsyncGenerator[str, None]: + ) -> AsyncGenerator[ChatMessageResponse, None]: try: access_level = await self.check_conversation_access( conversation_id, self.user_email @@ -595,12 +618,14 @@ async def regenerate_last_message( ): yield chunk else: - # For non-streaming, collect all chunks and store as a single message - full_response = "" + full_message = "" + all_citations = [] + async for chunk in self._generate_and_stream_ai_response( last_human_message.content, conversation_id, user_id, node_ids ): - full_response += chunk + full_message += chunk.message + all_citations = all_citations + chunk.citations # # Store the complete response as a single message # self.history_manager.add_message_chunk( # conversation_id, full_response, MessageType.AI, user_id @@ -608,7 +633,7 @@ async def regenerate_last_message( # self.history_manager.flush_message_buffer( # conversation_id, MessageType.AI, user_id # ) - yield full_response + yield ChatMessageResponse(message=full_message, 
citations=all_citations) except AccessTypeReadError: raise @@ -659,13 +684,26 @@ async def _archive_subsequent_messages( "Failed to archive subsequent messages." ) from e + def parse_str_to_message(self, chunk: str) -> ChatMessageResponse: + try: + data = json.loads(chunk) + except json.JSONDecodeError as e: + logger.error(f"Failed to parse chunk as JSON: {e}") + raise ConversationServiceError("Failed to parse AI response") from e + + # Extract the 'message' and 'citations' + message: str = data.get("message", "") + citations: List[str] = data.get("citations", []) + + return ChatMessageResponse(message=message, citations=citations) + async def _generate_and_stream_ai_response( self, query: str, conversation_id: str, user_id: str, node_ids: List[NodeContext], - ) -> AsyncGenerator[str, None]: + ) -> AsyncGenerator[ChatMessageResponse, None]: conversation = ( self.sql_db.query(Conversation).filter_by(id=conversation_id).first() ) @@ -690,13 +728,13 @@ async def _generate_and_stream_ai_response( response = await agent.run( agent_id, query, project_id, user_id, conversation.id, node_ids ) - yield response + yield self.parse_str_to_message(response) else: # For other agents that support streaming async for chunk in supervisor.process_query( query, project_id, conversation.id, user_id, node_ids, agent_id ): - yield chunk + yield self.parse_str_to_message(chunk) logger.info( f"Generated and streamed AI response for conversation {conversation.id} for user {user_id} using agent {agent_id}" diff --git a/app/modules/conversations/conversations_router.py b/app/modules/conversations/conversations_router.py index 0a6bdefd..496e1aff 100644 --- a/app/modules/conversations/conversations_router.py +++ b/app/modules/conversations/conversations_router.py @@ -1,5 +1,4 @@ -from typing import List - +from typing import List, AsyncGenerator, Any from fastapi import APIRouter, Depends, HTTPException, Query from fastapi.responses import StreamingResponse from sqlalchemy.orm import Session 
@@ -26,10 +25,16 @@ RenameConversationRequest, ) from .message.message_schema import MessageRequest, MessageResponse, RegenerateRequest +import json router = APIRouter() +async def get_stream(data_stream: AsyncGenerator[Any, None]): + async for chunk in data_stream: + yield json.dumps(chunk.dict()) + + class ConversationAPI: @staticmethod @router.post("/conversations/", response_model=CreateConversationResponse) @@ -98,13 +103,13 @@ async def post_message( controller = ConversationController(db, user_id, user_email) message_stream = controller.post_message(conversation_id, message, stream) if stream: - return StreamingResponse(message_stream, media_type="text/event-stream") + return StreamingResponse( + get_stream(message_stream), media_type="text/event-stream" + ) else: - # Collect all chunks into a complete response - full_response = "" + # TODO: fix this, add types. In below stream we have only one output. async for chunk in message_stream: - full_response += chunk - return {"content": full_response} + return chunk @staticmethod @router.post( @@ -124,13 +129,12 @@ async def regenerate_last_message( conversation_id, request.node_ids, stream ) if stream: - return StreamingResponse(message_stream, media_type="text/event-stream") + return StreamingResponse( + get_stream(message_stream), media_type="text/event-stream" + ) else: - # Collect all chunks into a complete response - full_response = "" async for chunk in message_stream: - full_response += chunk - return {"content": full_response} + return chunk @staticmethod @router.delete("/conversations/{conversation_id}/", response_model=dict) diff --git a/app/modules/conversations/message/message_schema.py b/app/modules/conversations/message/message_schema.py index 0535438c..2df8ad82 100644 --- a/app/modules/conversations/message/message_schema.py +++ b/app/modules/conversations/message/message_schema.py @@ -16,6 +16,12 @@ class MessageRequest(BaseModel): node_ids: Optional[List[NodeContext]] = None +class 
DirectMessageRequest(BaseModel): + content: str + node_ids: Optional[List[NodeContext]] = None + agent_id: Optional[str] = None + + class RegenerateRequest(BaseModel): node_ids: Optional[List[NodeContext]] = None