Skip to content

Commit

Permalink
feat: Support for custom agent in conv (#173)
Browse files Browse the repository at this point in the history
* Serve tools through API  (#110)

* server tools through api

* update kg tool

* wip tools

* Async/sync behaviour fixed

* fix bug in code changes agent + linting + sonar fixes

* rename variable

* remove unused fn

* Support for custom agents (#114)

* Basics for integration

* Creation of custom agent

* Updates and smell fixes

* remove

* Update auth_service.py

* Fixes

* fixes

* Naming fixe

* Timestamp addition

* Refactoring tools

* Update agents_schema.py

* pre-commit fixes

* Update agents_schema.py

* Update agents_schema.py

* Update list tools API  (#118)

* server tools through api

* update kg tool

* wip tools

* Async/sync behaviour fixed

* fix bug in code changes agent + linting + sonar fixes

* rename variable

* remove unused fn

* update list tool api

* update to include user_id check in other tools

* updates

* Refactors to seperate custom agent service

* Cleanup round 2

* Memory Implementation

* payload fixes

* Update conversation_service.py

* Tool parameter integration

* Listing of custom agents

* Circular import fixe

* Update agent_injector_service.py

* Agentic Tool Refactors

* custom agent restructuring

* Tools improvements

* Update get_code_from_probable_node_name_tool.py

* Pre-commit

* Crew Refactors

* Update unit_test_agent.py

* Update integration_test_crew.py

* Update rag_crew.py

* Support-for-custom-agent-deployment (#133)

* Fixing base url

* fetching system prompt dynamically

* lint

* Update custom_agent.py

* Update change_detection.py

* change_detection_tool

* get_code_From_node_name_tool

* Update get_code_graph_from_node_id_tool.py

* Update get_code_graph_from_node_name_tool.py

* Update ask_knowledge_graph_queries_tool.py

* Update get_code_from_multiple_node_ids_tool.py

* Update get_nodes_from_tags_tool.py

* Update get_code_from_node_id_tool.py

* Update get_code_from_probable_node_name_tool.py

* Update get_code_from_probable_node_name_tool.py

* Pre-commit fixes

* Update inference_service.py

* Update inference_service.py

* resolving conflicts

* rename fixes

* pre-commit fixes

* Update projects_service.py

* Sonar fixes

* Update tool_service.py

* fixes

* Tool Standardisation and related refactors

* Update tool_service.py

* Moving all the tools to align with keyword project_id instead of repo_id

* Exposing new tools + adding back sync execution + renaming change_detection as tool

* desc fixes

* fixes from lint

* Improved description

* Cleanup

* setting custom flag as true by default for now (to be adjusted)

* Update tool_router.py

* Update debug_rag_crew.py

* adding status

* Enabling HMAC

* validatiom for agent id

* Update auth_service.py

* Update auth_service.py

* hmac prints to debug

* commenting out hmac validation (temp)

* Update auth_service.py

* Update custom_agents_service.py

* Update .env.template

* print removal

* print removal

* Agent Naming Convention Fixes

* Sonnet  Upgrade

* Added exception handling

* Fixes

* Adding Hmac

---------

Co-authored-by: Dhiren Mathur <[email protected]>
  • Loading branch information
vineetshar and dhirenmathur authored Nov 11, 2024
1 parent d472467 commit 0c36e39
Show file tree
Hide file tree
Showing 47 changed files with 1,074 additions and 289 deletions.
6 changes: 4 additions & 2 deletions .env.template
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
ENV= development
OPENAI_API_KEY=
OPENAI_MODEL_REASONING=
# POSTGRES_SERVER=postgresql://postgres:[email protected]:5432/momentum Use this when using WSL
# POSTGRES_SERVER=postgresql://postgres:[email protected]:5432/momentum #for use with wsgl
POSTGRES_SERVER=postgresql://postgres:mysecretpassword@localhost:5432/momentum
MONGO_URI= mongodb://127.0.0.1:27017
MONGODB_DB_NAME= momentum
Expand All @@ -26,4 +26,6 @@ EMAIL_FROM_ADDRESS=
RESEND_API_KEY=
ANTHROPIC_API_KEY=
POSTHOG_API_KEY=
POSTHOG_HOST=
POSTHOG_HOST=
POTPIE_PLUS_BASE_URL=http://localhost:8080
POTPIE_PLUS_HMAC_KEY=
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,4 @@ def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("conversations", "visibility")
op.execute("DROP TYPE visibility")
# ### end Alembic commands ###
# ### end Alembic commands ###
2 changes: 2 additions & 0 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from app.modules.intelligence.prompts.prompt_router import router as prompt_router
from app.modules.intelligence.prompts.system_prompt_setup import SystemPromptSetup
from app.modules.intelligence.provider.provider_router import router as provider_router
from app.modules.intelligence.tools.tool_router import router as tool_router
from app.modules.key_management.secret_manager import router as secret_manager_router
from app.modules.parsing.graph_construction.parsing_router import (
router as parsing_router,
Expand Down Expand Up @@ -96,6 +97,7 @@ def include_routers(self):
self.app.include_router(agent_router, prefix="/api/v1", tags=["Agents"])

self.app.include_router(provider_router, prefix="/api/v1", tags=["Providers"])
self.app.include_router(tool_router, prefix="/api/v1", tags=["Tools"])

def add_health_check(self):
@self.app.get("/health", tags=["Health"])
Expand Down
42 changes: 41 additions & 1 deletion app/modules/auth/auth_service.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
import hashlib
import hmac
import json
import logging
import os
from typing import Union

from dotenv import load_dotenv
import requests
from fastapi import Depends, HTTPException, Request, Response, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from firebase_admin import auth


load_dotenv(override=True)
class AuthService:
def login(self, email, password):
log_prefix = "AuthService::login:"
Expand Down Expand Up @@ -64,6 +69,41 @@ async def check_auth(
)
res.headers["WWW-Authenticate"] = 'Bearer realm="auth_required"'
return decoded_token

@staticmethod
def generate_hmac_signature(message: str) -> str:
"""Generate HMAC signature for a message string"""
hmac_key = AuthService.get_hmac_secret_key()
if not hmac_key:
raise ValueError("HMAC secret key not configured")
hmac_obj = hmac.new(
key=hmac_key,
msg=message.encode("utf-8"),
digestmod=hashlib.sha256
)
return hmac_obj.hexdigest()

@staticmethod
def verify_hmac_signature(payload_body: Union[str, dict], hmac_signature: str) -> bool:
"""Verify HMAC signature matches the payload"""
hmac_key = AuthService.get_hmac_secret_key()
if not hmac_key:
raise ValueError("HMAC secret key not configured")
payload_str = payload_body if isinstance(payload_body, str) else json.dumps(payload_body, sort_keys=True)
expected_signature = hmac.new(
key=hmac_key,
msg=payload_str.encode("utf-8"),
digestmod=hashlib.sha256
).hexdigest()
return hmac.compare_digest(hmac_signature, expected_signature)

@staticmethod
def get_hmac_secret_key() -> bytes:
"""Get HMAC secret key from environment"""
key = os.getenv("POTPIE_PLUS_HMAC_KEY", "")
if not key:
return b""
return key.encode("utf-8")


auth_handler = AuthService()
58 changes: 33 additions & 25 deletions app/modules/conversations/conversation/conversation_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
)
from app.modules.github.github_service import GithubService
from app.modules.intelligence.agents.agent_injector_service import AgentInjectorService
from app.modules.intelligence.agents.custom_agents.custom_agents_service import (
CustomAgentsService,
)
from app.modules.intelligence.memory.chat_history_service import ChatHistoryService
from app.modules.intelligence.provider.provider_service import ProviderService
from app.modules.projects.projects_service import ProjectService
Expand Down Expand Up @@ -70,6 +73,7 @@ def __init__(
history_manager: ChatHistoryService,
provider_service: ProviderService,
agent_injector_service: AgentInjectorService,
custom_agent_service: CustomAgentsService,
):
self.sql_db = db
self.user_id = user_id
Expand All @@ -78,13 +82,15 @@ def __init__(
self.history_manager = history_manager
self.provider_service = provider_service
self.agent_injector_service = agent_injector_service
self.custom_agent_service = custom_agent_service

@classmethod
def create(cls, db: Session, user_id: str, user_email: str):
project_service = ProjectService(db)
history_manager = ChatHistoryService(db)
provider_service = ProviderService(db, user_id)
agent_injector_service = AgentInjectorService(db, provider_service)
agent_injector_service = AgentInjectorService(db, provider_service, user_id)
custom_agent_service = CustomAgentsService()
return cls(
db,
user_id,
Expand All @@ -93,6 +99,7 @@ def create(cls, db: Session, user_id: str, user_email: str):
history_manager,
provider_service,
agent_injector_service,
custom_agent_service,
)

async def check_conversation_access(
Expand Down Expand Up @@ -135,6 +142,7 @@ async def create_conversation(
) -> tuple[str, str]:
try:
if not self.agent_injector_service.validate_agent_id(
user_id,
conversation.agent_ids[0]
):
raise ConversationServiceError(
Expand Down Expand Up @@ -274,25 +282,16 @@ async def store_message(
)
await self._update_conversation_title(conversation_id, new_title)

repo_id = (
project_id = (
conversation.project_ids[0] if conversation.project_ids else None
)
if not repo_id:
if not project_id:
raise ConversationServiceError(
"No project associated with this conversation"
)

agent = self.agent_injector_service.get_agent(conversation.agent_ids[0])
if not agent:
raise ConversationServiceError(
f"Invalid agent_id: {conversation.agent_ids[0]}"
)

logger.info(
f"Running agent for repo_id: {repo_id} conversation_id: {conversation_id}"
)
async for chunk in agent.run(
message.content, repo_id, user_id, conversation.id, message.node_ids
async for chunk in self._generate_and_stream_ai_response(
message.content, conversation_id, user_id, message.node_ids
):
yield chunk

Expand Down Expand Up @@ -446,23 +445,32 @@ async def _generate_and_stream_ai_response(
raise ConversationNotFoundError(
f"Conversation with id {conversation_id} not found"
)
agent = self.agent_injector_service.get_agent(conversation.agent_ids[0])
if not agent:
raise ConversationServiceError(
f"Invalid agent_id: {conversation.agent_ids[0]}"
)

agent_id = conversation.agent_ids[0]
project_id = conversation.project_ids[0] if conversation.project_ids else None

try:
agent = self.agent_injector_service.get_agent(agent_id)

logger.info(
f"conversation_id: {conversation_id}Running agent {conversation.agent_ids[0]} with query: {query} "
f"conversation_id: {conversation_id} Running agent {agent_id} with query: {query}"
)
async for chunk in agent.run(
query, conversation.project_ids[0], user_id, conversation.id, node_ids
):
if chunk:

if isinstance(agent, CustomAgentsService):
# Custom agent doesn't support streaming, so we'll yield the entire response at once
response = await agent.run(
agent_id, query, project_id, user_id, conversation.id, node_ids
)
yield response
else:
# For other agents that support streaming
async for chunk in agent.run(
query, project_id, user_id, conversation.id, node_ids
):
yield chunk

logger.info(
f"Generated and streamed AI response for conversation {conversation.id} for user {user_id} using agent {conversation.agent_ids[0]}"
f"Generated and streamed AI response for conversation {conversation.id} for user {user_id} using agent {agent_id}"
)
except Exception as e:
logger.error(
Expand Down
58 changes: 35 additions & 23 deletions app/modules/intelligence/agents/agent_injector_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,50 +3,62 @@

from sqlalchemy.orm import Session

from app.modules.intelligence.agents.chat_agents.code_changes_agent import (
CodeChangesAgent,
from app.modules.intelligence.agents.chat_agents.code_changes_chat_agent import (
CodeChangesChatAgent,
)
from app.modules.intelligence.agents.chat_agents.debugging_agent import DebuggingAgent
from app.modules.intelligence.agents.chat_agents.integration_test_agent import (
IntegrationTestAgent,
from app.modules.intelligence.agents.chat_agents.debugging_chat_agent import DebuggingChatAgent
from app.modules.intelligence.agents.chat_agents.integration_test_chat_agent import (
IntegrationTestChatAgent,
)
from app.modules.intelligence.agents.chat_agents.lld_chat_agent import LLDChatAgent
from app.modules.intelligence.agents.chat_agents.qna_chat_agent import QNAChatAgent
from app.modules.intelligence.agents.chat_agents.unit_test_chat_agent import UnitTestAgent
from app.modules.intelligence.agents.custom_agents.custom_agent import CustomAgent
from app.modules.intelligence.agents.custom_agents.custom_agents_service import (
CustomAgentsService,
)
from app.modules.intelligence.agents.chat_agents.lld_agent import LLDAgent
from app.modules.intelligence.agents.chat_agents.qna_agent import QNAAgent
from app.modules.intelligence.agents.chat_agents.unit_test_agent import UnitTestAgent
from app.modules.intelligence.provider.provider_service import ProviderService

logger = logging.getLogger(__name__)


class AgentInjectorService:
def __init__(self, db: Session, provider_service: ProviderService):
def __init__(self, db: Session, provider_service: ProviderService, user_id: str):
self.sql_db = db
self.provider_service = provider_service
self.custom_agent_service = CustomAgentsService()
self.agents = self._initialize_agents()
self.user_id = user_id

def _initialize_agents(self) -> Dict[str, Any]:
mini_llm = self.provider_service.get_small_llm()
reasoning_llm = self.provider_service.get_large_llm()
return {
"debugging_agent": DebuggingAgent(mini_llm, reasoning_llm, self.sql_db),
"codebase_qna_agent": QNAAgent(mini_llm, reasoning_llm, self.sql_db),
"debugging_agent": DebuggingChatAgent(mini_llm, reasoning_llm, self.sql_db),
"codebase_qna_agent": QNAChatAgent(mini_llm, reasoning_llm, self.sql_db),
"unit_test_agent": UnitTestAgent(mini_llm, reasoning_llm, self.sql_db),
"integration_test_agent": IntegrationTestAgent(
"integration_test_agent": IntegrationTestChatAgent(
mini_llm, reasoning_llm, self.sql_db
),
"code_changes_agent": CodeChangesAgent(
"code_changes_agent": CodeChangesChatAgent(
mini_llm, reasoning_llm, self.sql_db
),
"LLD_agent": LLDAgent(mini_llm, reasoning_llm, self.sql_db),
"LLD_agent": LLDChatAgent(mini_llm, reasoning_llm, self.sql_db),
}

def get_agent(self, agent_id: str) -> Any:
agent = self.agents.get(agent_id)
if not agent:
logger.error(f"Invalid agent_id: {agent_id}")
raise ValueError(f"Invalid agent_id: {agent_id}")
return agent

def validate_agent_id(self, agent_id: str) -> bool:
logger.info(f"Validating agent_id: {agent_id}")
return agent_id in self.agents
if agent_id in self.agents:
return self.agents[agent_id]
else:
reasoning_llm = self.provider_service.get_large_llm()
return CustomAgent(
llm=reasoning_llm,
db=self.sql_db,
agent_id=agent_id,
user_id=self.user_id,
)

def validate_agent_id(self, user_id: str, agent_id: str) -> bool:
return agent_id in self.agents or self.custom_agent_service.validate_agent(
self.sql_db, user_id, agent_id
)
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from pydantic import BaseModel, Field

from app.modules.conversations.message.message_schema import NodeContext
from app.modules.intelligence.tools.change_detection.change_detection import (
from app.modules.intelligence.tools.change_detection.change_detection_tool import (
ChangeDetectionResponse,
get_blast_radius_tool,
get_change_detection_tool,
)
from app.modules.intelligence.tools.kg_based_tools.ask_knowledge_graph_queries_tool import (
get_ask_knowledge_graph_queries_tool,
Expand Down Expand Up @@ -103,11 +103,12 @@ async def create_tasks(
expected_output=f"Comprehensive impact analysis of the code changes on the codebase and answers to the users query about them. Ensure that your output ALWAYS follows the structure outlined in the following pydantic model : {self.BlastRadiusAgentResponse.model_json_schema()}",
agent=blast_radius_agent,
tools=[
get_blast_radius_tool(self.user_id),
get_change_detection_tool(self.user_id),
self.get_nodes_from_tags,
self.ask_knowledge_graph_queries,
],
output_pydantic=self.BlastRadiusAgentResponse,
async_execution=True,
)

return analyze_changes_task
Expand All @@ -134,7 +135,7 @@ async def run(
return result


async def kickoff_blast_radius_crew(
async def kickoff_blast_radius_agent(
query: str, project_id: str, node_ids: List[NodeContext], sql_db, user_id, llm
) -> Dict[str, str]:
blast_radius_agent = BlastRadiusAgent(sql_db, user_id, llm)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class RAGResponse(BaseModel):
response: List[NodeResponse]


class DebugAgent:
class DebugRAGAgent:
def __init__(self, sql_db, llm, mini_llm, user_id):
self.openai_api_key = os.getenv("OPENAI_API_KEY")
self.max_iter = os.getenv("MAX_ITER", 5)
Expand Down Expand Up @@ -225,7 +225,7 @@ async def run(
return result


async def kickoff_debug_crew(
async def kickoff_debug_rag_agent(
query: str,
project_id: str,
chat_history: List,
Expand All @@ -235,7 +235,7 @@ async def kickoff_debug_crew(
mini_llm,
user_id: str,
) -> str:
debug_agent = DebugAgent(sql_db, llm, mini_llm, user_id)
debug_agent = DebugRAGAgent(sql_db, llm, mini_llm, user_id)
file_structure = await GithubService(sql_db).get_project_structure_async(project_id)
result = await debug_agent.run(
query, project_id, chat_history, node_ids, file_structure
Expand Down
Loading

0 comments on commit 0c36e39

Please sign in to comment.