Skip to content

Commit

Permalink
custom agent restructuring
Browse files Browse the repository at this point in the history
  • Loading branch information
vineetshar committed Oct 27, 2024
1 parent 7194199 commit d2efd88
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 68 deletions.
33 changes: 8 additions & 25 deletions app/modules/intelligence/agents/custom_agents/custom_agent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import logging
from functools import lru_cache
from typing import Any, AsyncGenerator, Dict, List
from typing import AsyncGenerator, Dict, List, Any

import httpx
from langchain.schema import HumanMessage, SystemMessage
Expand All @@ -19,17 +19,18 @@
from app.modules.intelligence.memory.chat_history_service import ChatHistoryService
from app.modules.intelligence.prompts.prompt_schema import PromptResponse, PromptType
from app.modules.intelligence.prompts.prompt_service import PromptService
from app.modules.intelligence.agents.custom_agents.custom_agents_service import CustomAgentsService

logger = logging.getLogger(__name__)


class CustomAgent:
def __init__(self, llm, db: Session, agent_id: str):
self.llm = llm
self.db = db
self.agent_id = agent_id
self.history_manager = ChatHistoryService(db)
self.prompt_service = PromptService(db)
self.custom_agents_service = CustomAgentsService()
self.chain = None

@lru_cache(maxsize=2)
Expand Down Expand Up @@ -71,22 +72,16 @@ async def run(

history = self.history_manager.get_session_history(user_id, conversation_id)
validated_history = [
(
HumanMessage(content=str(msg))
if isinstance(msg, (str, int, float))
else msg
)
HumanMessage(content=str(msg)) if isinstance(msg, (str, int, float)) else msg
for msg in history
]

custom_agent_result = await self.custom_agent_service.run(
self.agent_id, query, node_ids
custom_agent_result = await self.custom_agents_service.run_agent(
self.agent_id, query, project_id, user_id, node_ids
)

tool_results = [
SystemMessage(
content=f"Custom Agent result: {json.dumps(custom_agent_result)}"
)
SystemMessage(content=f"Custom Agent result: {json.dumps(custom_agent_result)}")
]

inputs = {
Expand Down Expand Up @@ -118,16 +113,4 @@ async def run(
yield f"An error occurred: {str(e)}"

async def is_valid(self) -> bool:
validate_url = f"{self.base_url}/deployment/{self.agent_id}/validate"

async with httpx.AsyncClient() as client:
response = await client.get(validate_url)
return response.status_code == 200

async def run(self, payload: Dict[str, Any]) -> str:
run_url = f"{self.base_url}/deployment/{self.agent_id}/run"

async with httpx.AsyncClient() as client:
response = await client.post(run_url, json=payload)
response.raise_for_status()
return response.json()["response"]
return await self.custom_agents_service.validate_agent(self.agent_id)
Original file line number Diff line number Diff line change
@@ -1,59 +1,51 @@
from typing import List, Union

from sqlalchemy.orm import Session
import httpx
import logging
from typing import Dict, Any, List

from app.modules.conversations.message.message_schema import NodeContext

logger = logging.getLogger(__name__)

class CustomAgentService:
def __init__(self, db: Session):
self.db = db
self.base_url = "http://localhost:8080"
class CustomAgentsService:
def __init__(self):
self.base_url = "https://your-custom-agent-service-url.com" # Replace with actual URL

async def run(
async def run_agent(
self,
agent_id: str,
query: str,
project_id: str,
user_id: str,
conversation_id: str,
node_ids: Union[List[NodeContext], List[str]],
) -> str:
# Import CustomAgent here to avoid circular import
from app.modules.intelligence.agents.custom_agents.custom_agent import (
CustomAgent,
)

custom_agent = CustomAgent(agent_id)

# Convert node_ids to a list of dictionaries or strings
node_ids_payload = [
node.dict() if isinstance(node, NodeContext) else node for node in node_ids
]

node_ids: List[NodeContext],
) -> Dict[str, Any]:
run_url = f"{self.base_url}/api/v1/agents/{agent_id}/run"
payload = {
"query": query,
"project_id": project_id,
"user_id": user_id,
"conversation_id": conversation_id,
"node_ids": node_ids_payload,
"node_ids": [node.dict() for node in node_ids],
}

return await custom_agent.run(payload)

# async def get_system_prompt(self, agent_id: str) -> str:
# system_prompt_url = f"{self.base_url}/deployment/{agent_id}/system_prompt"

# async with httpx.AsyncClient() as client:
# response = await client.get(system_prompt_url)
# response.raise_for_status()
# return response.text

async def is_valid_agent(self, agent_id: str) -> bool:
# Import CustomAgent here to avoid circular import
from app.modules.intelligence.agents.custom_agents.custom_agent import (
CustomAgent,
)

custom_agent = CustomAgent(agent_id)
return await custom_agent.is_valid()
async with httpx.AsyncClient() as client:
try:
response = await client.post(run_url, json=payload)
response.raise_for_status()
return response.json()
except httpx.HTTPStatusError as e:
logger.error(f"HTTP error occurred while running agent {agent_id}: {e}")
raise
except Exception as e:
logger.error(f"Unexpected error occurred while running agent {agent_id}: {e}")
raise

async def validate_agent(self, agent_id: str) -> bool:
return True
# validate_url = f"{self.base_url}/api/v1/agents/{agent_id}/validate"

# async with httpx.AsyncClient() as client:
# try:
# response = await client.get(validate_url)
# return response.status_code == 200
# except Exception as e:
# logger.error(f"Error validating agent {agent_id}: {e}")
# return False

0 comments on commit d2efd88

Please sign in to comment.