Feat/direct message api #252

Merged
merged 6 commits on Feb 12, 2025
45 changes: 43 additions & 2 deletions app/api/router.py
@@ -16,7 +16,10 @@
CreateConversationRequest,
CreateConversationResponse,
)
from app.modules.conversations.message.message_schema import MessageRequest
from app.modules.conversations.message.message_schema import (
MessageRequest,
DirectMessageRequest,
)
from app.modules.parsing.graph_construction.parsing_controller import ParsingController
from app.modules.parsing.graph_construction.parsing_schema import ParsingRequest
from app.modules.utils.APIRouter import APIRouter
@@ -103,4 +106,42 @@ async def post_message(
# Note: email is no longer available with API key auth
controller = ConversationController(db, user_id, None)
message_stream = controller.post_message(conversation_id, message, stream=False)
return StreamingResponse(message_stream, media_type="text/event-stream")
async for chunk in message_stream:
return chunk


@router.post("/project/{project_id}/message/")
async def create_conversation_and_message(
project_id: str,
message: DirectMessageRequest,
db: Session = Depends(get_db),
user=Depends(get_api_key_user),
):
if message.content == "" or message.content is None or message.content.isspace():
raise HTTPException(status_code=400, detail="Message content cannot be empty")

user_id = user["user_id"]

# default agent_id to codebase_qna_agent
if message.agent_id is None:
message.agent_id = "codebase_qna_agent"

controller = ConversationController(db, user_id, None)
res = await controller.create_conversation(
CreateConversationRequest(
user_id=user_id,
title=message.content,
project_ids=[project_id],
agent_ids=[message.agent_id],
status=ConversationStatus.ACTIVE,
)
)

message_stream = controller.post_message(
conversation_id=res.conversation_id,
message=MessageRequest(content=message.content, node_ids=message.node_ids),
stream=False,
)

async for chunk in message_stream:
return chunk
Comment on lines +140 to +147

⚠️ Potential issue

Add error handling for message posting.

The message posting call lacks error handling. Consider wrapping it in a try-except block to handle potential errors.

+    try:
         message_stream = controller.post_message(
             conversation_id=res.conversation_id,
             message=MessageRequest(content=message.content, node_ids=message.node_ids),
             stream=False,
         )

         async for chunk in message_stream:
             return chunk
+    except Exception as e:
+        # Clean up the created conversation on failure
+        await controller.delete_conversation(res.conversation_id)
+        raise HTTPException(status_code=500, detail=str(e))
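For context, this is roughly how a client would exercise the new endpoint. A minimal sketch, not part of the PR: the base URL, the API-key header name, and the project ID are invented; only the route shape, the DirectMessageRequest fields, and the single-JSON-body response come from the diff above.

import requests  # assumption: any HTTP client works

BASE_URL = "http://localhost:8001"  # assumed deployment address
HEADERS = {"x-api-key": "YOUR_API_KEY"}  # assumed header consumed by get_api_key_user

payload = {
    "content": "Where is request authentication implemented?",
    "node_ids": None,  # optional list of NodeContext objects
    "agent_id": None,  # the server defaults this to codebase_qna_agent
}

resp = requests.post(
    f"{BASE_URL}/project/my-project-id/message/",  # route added in this PR
    json=payload,
    headers=HEADERS,
)
print(resp.json())  # {"message": "...", "citations": [...]}

Because the handler calls post_message with stream=False and returns the first aggregated chunk, the client receives one JSON object rather than an event stream.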

app/modules/conversations/conversation/conversation_controller.py
@@ -4,6 +4,7 @@
from sqlalchemy.orm import Session

from app.modules.conversations.conversation.conversation_schema import (
ChatMessageResponse,
ConversationInfoResponse,
CreateConversationRequest,
CreateConversationResponse,
@@ -82,7 +83,7 @@ async def get_conversation_messages(

async def post_message(
self, conversation_id: str, message: MessageRequest, stream: bool = True
) -> AsyncGenerator[str, None]:
) -> AsyncGenerator[ChatMessageResponse, None]:
try:
async for chunk in self.service.store_message(
conversation_id, message, MessageType.HUMAN, self.user_id, stream
@@ -100,7 +101,7 @@ async def regenerate_last_message(
conversation_id: str,
node_ids: List[NodeContext] = [],
stream: bool = True,
) -> AsyncGenerator[str, None]:
) -> AsyncGenerator[ChatMessageResponse, None]:
try:
async for chunk in self.service.regenerate_last_message(
conversation_id, self.user_id, node_ids, stream
5 changes: 5 additions & 0 deletions app/modules/conversations/conversation/conversation_schema.py
@@ -47,6 +47,11 @@ class Config:
from_attributes = True


class ChatMessageResponse(BaseModel):
message: str
citations: List[str]


# Resolve forward references
ConversationInfoResponse.update_forward_refs()

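To make the new model concrete (values invented): each chunk the service yields now carries the aggregated message text plus its citations, and the router below serializes it with .dict() and json.dumps.

from typing import List

from pydantic import BaseModel

class ChatMessageResponse(BaseModel):
    message: str
    citations: List[str]

# Invented example values:
chunk = ChatMessageResponse(
    message="Parsing is kicked off from ParsingController.",
    citations=["app/modules/parsing/graph_construction/parsing_controller.py"],
)
print(chunk.dict())
# {'message': 'Parsing is kicked off from ParsingController.', 'citations': [...]}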
84 changes: 61 additions & 23 deletions app/modules/conversations/conversation/conversation_service.py
@@ -1,4 +1,5 @@
import asyncio
import json
import logging
from datetime import datetime, timezone
from typing import Any, AsyncGenerator, Dict, List, Optional, TypedDict
@@ -18,6 +19,7 @@
Visibility,
)
from app.modules.conversations.conversation.conversation_schema import (
ChatMessageResponse,
ConversationAccessType,
ConversationInfoResponse,
CreateConversationRequest,
@@ -75,7 +77,7 @@ def __init__(self, db, provider_service):
self.db = db
self.provider_service = provider_service
self.agent = None
self.current_agent_id = None
self.current_agent_id = None
self.classifier = None
self.agents_service = AgentsService(db)
self.agent_factory = AgentFactory(db, provider_service)
@@ -148,17 +150,24 @@ async def classifier_node(self, state: State) -> Command:
return Command(update={"response": "No query provided"}, goto=END)

agent_list = {agent.id: agent.status for agent in self.available_agents}

# First check - if this is a custom agent (non-SYSTEM), route directly
if state["agent_id"] in agent_list and agent_list[state["agent_id"]] != "SYSTEM":
if (
state["agent_id"] in agent_list
and agent_list[state["agent_id"]] != "SYSTEM"
):
# Initialize the agent if needed
if not self.agent or self.current_agent_id != state["agent_id"]:
try:
self.agent = self.agent_factory.get_agent(state["agent_id"], state["user_id"])
self.agent = self.agent_factory.get_agent(
state["agent_id"], state["user_id"]
)
self.current_agent_id = state["agent_id"]
except Exception as e:
logger.error(f"Failed to create agent {state['agent_id']}: {e}")
return Command(update={"response": "Failed to initialize agent"}, goto=END)
return Command(
update={"response": "Failed to initialize agent"}, goto=END
)
return Command(update={"agent_id": state["agent_id"]}, goto="agent_node")

# For system agents, perform classification
@@ -167,25 +176,33 @@
agent_id=state["agent_id"],
agent_descriptions=self.agent_descriptions,
)

response = await self.llm.ainvoke(prompt)
response = response.content.strip("`")
try:
agent_id, confidence = response.split("|")
confidence = float(confidence)
selected_agent_id = agent_id if confidence >= 0.5 and agent_id in agent_list else state["agent_id"]
selected_agent_id = (
agent_id
if confidence >= 0.5 and agent_id in agent_list
else state["agent_id"]
)
except (ValueError, TypeError):
logger.error("Classification format error, falling back to current agent")
selected_agent_id = state["agent_id"]

# Initialize the selected system agent
if not self.agent or self.current_agent_id != selected_agent_id:
try:
self.agent = self.agent_factory.get_agent(selected_agent_id, state["user_id"])
self.agent = self.agent_factory.get_agent(
selected_agent_id, state["user_id"]
)
self.current_agent_id = selected_agent_id
except Exception as e:
logger.error(f"Failed to create agent {selected_agent_id}: {e}")
return Command(update={"response": "Failed to initialize agent"}, goto=END)
return Command(
update={"response": "Failed to initialize agent"}, goto=END
)

logger.info(
f"Streaming AI response for conversation {state['conversation_id']} "
@@ -198,7 +215,7 @@ async def agent_node(self, state: State, writer: StreamWriter):
if not self.agent:
logger.error("Agent not initialized before agent_node execution")
return Command(update={"response": "Agent not initialized"}, goto=END)

try:
async for chunk in self.agent.run(
query=state["query"],
@@ -436,7 +453,7 @@ async def store_message(
message_type: MessageType,
user_id: str,
stream: bool = True,
) -> AsyncGenerator[str, None]:
) -> AsyncGenerator[ChatMessageResponse, None]:
try:
access_level = await self.check_conversation_access(
conversation_id, self.user_email
@@ -487,20 +504,26 @@ async def store_message(
):
yield chunk
else:
# For non-streaming, collect all chunks and store as a single message
full_response = ""
full_message = ""
all_citations = []
async for chunk in self._generate_and_stream_ai_response(
message.content, conversation_id, user_id, message.node_ids
):
full_response += chunk

full_message += chunk.message
all_citations = all_citations + chunk.citations

# TODO: what is this below comment for?
# # Store the complete response as a single message
# self.history_manager.add_message_chunk(
# conversation_id, full_response, MessageType.AI, user_id
# )
# self.history_manager.flush_message_buffer(
# conversation_id, MessageType.AI, user_id
# )
yield full_response
yield ChatMessageResponse(
message=full_message, citations=all_citations
)

except AccessTypeReadError:
raise
@@ -569,7 +592,7 @@ async def regenerate_last_message(
user_id: str,
node_ids: List[NodeContext] = [],
stream: bool = True,
) -> AsyncGenerator[str, None]:
) -> AsyncGenerator[ChatMessageResponse, None]:
try:
access_level = await self.check_conversation_access(
conversation_id, self.user_email
@@ -595,20 +618,22 @@
):
yield chunk
else:
# For non-streaming, collect all chunks and store as a single message
full_response = ""
full_message = ""
all_citations = []

async for chunk in self._generate_and_stream_ai_response(
last_human_message.content, conversation_id, user_id, node_ids
):
full_response += chunk
full_message += chunk.message
all_citations = all_citations + chunk.citations
# # Store the complete response as a single message
# self.history_manager.add_message_chunk(
# conversation_id, full_response, MessageType.AI, user_id
# )
# self.history_manager.flush_message_buffer(
# conversation_id, MessageType.AI, user_id
# )
yield full_response
yield ChatMessageResponse(message=full_message, citations=all_citations)

except AccessTypeReadError:
raise
@@ -659,13 +684,26 @@ async def _archive_subsequent_messages(
"Failed to archive subsequent messages."
) from e

def parse_str_to_message(self, chunk: str) -> ChatMessageResponse:
try:
data = json.loads(chunk)
except json.JSONDecodeError as e:
logger.error(f"Failed to parse chunk as JSON: {e}")
raise ConversationServiceError("Failed to parse AI response") from e

# Extract the 'message' and 'citations'
message: str = data.get("message", "")
citations: List[str] = data.get("citations", [])

return ChatMessageResponse(message=message, citations=citations)

Comment on lines +687 to +699

🛠️ Refactor suggestion

Add type validation for extracted fields.

The function doesn't validate the types of the extracted 'message' and 'citations' fields, which could lead to runtime errors.

Apply this diff to add type validation:

     def parse_str_to_message(self, chunk: str) -> ChatMessageResponse:
         try:
             data = json.loads(chunk)
         except json.JSONDecodeError as e:
             logger.error(f"Failed to parse chunk as JSON: {e}")
             raise ConversationServiceError("Failed to parse AI response") from e

-        # Extract the 'message' and 'citations'
-        message: str = data.get("message", "")
-        citations: List[str] = data.get("citations", [])
+        # Extract and validate 'message' and 'citations'
+        message = data.get("message", "")
+        citations = data.get("citations", [])
+
+        if not isinstance(message, str):
+            raise ConversationServiceError("Message must be a string")
+        if not isinstance(citations, list) or not all(isinstance(c, str) for c in citations):
+            raise ConversationServiceError("Citations must be a list of strings")

         return ChatMessageResponse(message=message, citations=citations)
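To pin down the contract: parse_str_to_message assumes every raw chunk arriving from the agent layer is a JSON object with exactly these two keys. A small round trip with invented values:

import json

raw_chunk = '{"message": "Auth lives in app/api/router.py.", "citations": ["app/api/router.py"]}'

data = json.loads(raw_chunk)
message = data.get("message", "")      # same defaults as parse_str_to_message
citations = data.get("citations", [])
assert isinstance(message, str)
assert isinstance(citations, list)

# Anything that is not valid JSON raises json.JSONDecodeError, which the
# service wraps in ConversationServiceError:
try:
    json.loads("plain text chunk")
except json.JSONDecodeError:
    pass  # surfaced as ConversationServiceError upstream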

async def _generate_and_stream_ai_response(
self,
query: str,
conversation_id: str,
user_id: str,
node_ids: List[NodeContext],
) -> AsyncGenerator[str, None]:
) -> AsyncGenerator[ChatMessageResponse, None]:
conversation = (
self.sql_db.query(Conversation).filter_by(id=conversation_id).first()
)
@@ -690,13 +728,13 @@ async def _generate_and_stream_ai_response(
response = await agent.run(
agent_id, query, project_id, user_id, conversation.id, node_ids
)
yield response
yield self.parse_str_to_message(response)
else:
# For other agents that support streaming
async for chunk in supervisor.process_query(
query, project_id, conversation.id, user_id, node_ids, agent_id
):
yield chunk
yield self.parse_str_to_message(chunk)
Comment on lines +731 to +737

🛠️ Refactor suggestion

Add error handling for parse_str_to_message calls.

The function doesn't handle potential errors from parse_str_to_message, which could lead to unhandled exceptions.

Apply this diff to add error handling:

             if isinstance(agent, CustomAgentsService):
                 # Custom agent doesn't support streaming, so we'll yield the entire response at once
                 response = await agent.run(
                     agent_id, query, project_id, user_id, conversation.id, node_ids
                 )
-                yield self.parse_str_to_message(response)
+                try:
+                    yield self.parse_str_to_message(response)
+                except ConversationServiceError as e:
+                    logger.error(f"Failed to parse custom agent response: {e}")
+                    raise
             else:
                 # For other agents that support streaming
                 async for chunk in supervisor.process_query(
                     query, project_id, conversation.id, user_id, node_ids, agent_id
                 ):
-                    yield self.parse_str_to_message(chunk)
+                    try:
+                        yield self.parse_str_to_message(chunk)
+                    except ConversationServiceError as e:
+                        logger.error(f"Failed to parse agent response chunk: {e}")
+                        raise


logger.info(
f"Generated and streamed AI response for conversation {conversation.id} for user {user_id} using agent {agent_id}"
28 changes: 16 additions & 12 deletions app/modules/conversations/conversations_router.py
@@ -1,5 +1,4 @@
from typing import List

from typing import List, AsyncGenerator, Any
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
@@ -26,10 +25,16 @@
RenameConversationRequest,
)
from .message.message_schema import MessageRequest, MessageResponse, RegenerateRequest
import json

router = APIRouter()


async def get_stream(data_stream: AsyncGenerator[Any, None]):
async for chunk in data_stream:
yield json.dumps(chunk.dict())

Comment on lines +33 to +35

🛠️ Refactor suggestion

Add error handling for dict() method call.

The function assumes that each chunk has a dict() method. Add error handling to gracefully handle cases where the chunk doesn't support the dict() method.

Apply this diff to add error handling:

 async def get_stream(data_stream: AsyncGenerator[Any, None]):
     async for chunk in data_stream:
-        yield json.dumps(chunk.dict())
+        try:
+            yield json.dumps(chunk.dict())
+        except (AttributeError, TypeError) as e:
+            logger.error(f"Failed to serialize chunk: {e}")
+            raise ConversationServiceError("Failed to serialize response") from e
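On the consuming side, each streamed chunk is a json.dumps-encoded ChatMessageResponse. A client sketch follows; httpx and the header name are assumptions, and since get_stream emits the JSON objects back-to-back with no delimiter, the sketch assumes each received chunk holds exactly one object (a real client may need explicit framing, e.g. newline separators).

import json

import httpx  # assumption: any streaming-capable HTTP client works

def read_message_stream(url: str, payload: dict, api_key: str) -> None:
    headers = {"x-api-key": api_key}  # assumed header name
    with httpx.stream("POST", url, json=payload, headers=headers, timeout=None) as resp:
        for text in resp.iter_text():
            if text.strip():
                chunk = json.loads(text)  # assumes one JSON object per chunk
                print(chunk["message"], chunk["citations"])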


class ConversationAPI:
@staticmethod
@router.post("/conversations/", response_model=CreateConversationResponse)
@@ -98,13 +103,13 @@ async def post_message(
controller = ConversationController(db, user_id, user_email)
message_stream = controller.post_message(conversation_id, message, stream)
if stream:
return StreamingResponse(message_stream, media_type="text/event-stream")
return StreamingResponse(
get_stream(message_stream), media_type="text/event-stream"
)
else:
# Collect all chunks into a complete response
full_response = ""
# TODO: fix this, add types. In below stream we have only one output.
async for chunk in message_stream:
full_response += chunk
return {"content": full_response}
return chunk

Comment on lines +106 to +112

⚠️ Potential issue

Fix non-streaming response handling.

The non-streaming case only returns the first chunk from the stream, which might miss important data. Also, the TODO comment indicates that types need to be fixed.

Apply this diff to handle the non-streaming case properly:

-            # TODO: fix this, add types. In below stream we have only one output.
-            async for chunk in message_stream:
-                return chunk
+            # For non-streaming case, collect all chunks into a single response
+            chunks = []
+            async for chunk in message_stream:
+                chunks.append(chunk)
+            return chunks[-1] if chunks else None  # Return the last chunk or None if empty

@staticmethod
@router.post(
@@ -124,13 +129,12 @@ async def regenerate_last_message(
conversation_id, request.node_ids, stream
)
if stream:
return StreamingResponse(message_stream, media_type="text/event-stream")
return StreamingResponse(
get_stream(message_stream), media_type="text/event-stream"
)
else:
# Collect all chunks into a complete response
full_response = ""
async for chunk in message_stream:
full_response += chunk
return {"content": full_response}
return chunk

Comment on lines +132 to +137

⚠️ Potential issue

Fix non-streaming response handling.

Similar to post_message, the non-streaming case only returns the first chunk from the stream, which might miss important data.

Apply this diff to handle the non-streaming case properly:

-            async for chunk in message_stream:
-                return chunk
+            # For non-streaming case, collect all chunks into a single response
+            chunks = []
+            async for chunk in message_stream:
+                chunks.append(chunk)
+            return chunks[-1] if chunks else None  # Return the last chunk or None if empty

@staticmethod
@router.delete("/conversations/{conversation_id}/", response_model=dict)
6 changes: 6 additions & 0 deletions app/modules/conversations/message/message_schema.py
@@ -16,6 +16,12 @@ class MessageRequest(BaseModel):
node_ids: Optional[List[NodeContext]] = None


class DirectMessageRequest(BaseModel):
content: str
node_ids: Optional[List[NodeContext]] = None
agent_id: str | None = None


class RegenerateRequest(BaseModel):
node_ids: Optional[List[NodeContext]] = None
