diff --git a/postgres_da_ai_agent/agents/agents.py b/postgres_da_ai_agent/agents/agents.py index 90585d8..fe3dd30 100644 --- a/postgres_da_ai_agent/agents/agents.py +++ b/postgres_da_ai_agent/agents/agents.py @@ -4,15 +4,22 @@ from postgres_da_ai_agent.agents import agent_config import autogen import guidance +import agentops +from dotenv import load_dotenv +import os -# ------------------------ PROMPTS ------------------------ +# Load environment variables +load_dotenv() + +# Initialize AgentOps +agentops.init(os.getenv('AGENTOPS_API_KEY')) +# ------------------------ PROMPTS ------------------------ USER_PROXY_PROMPT = "A human admin. Interact with the Product Manager to discuss the plan. Plan execution needs to be approved by this admin." DATA_ENGINEER_PROMPT = "A Data Engineer. Generate the initial SQL based on the requirements provided. Send it to the Sr Data Analyst to be executed. " SR_DATA_ANALYST_PROMPT = "Sr Data Analyst. You run the SQL query using the run_sql function, send the raw response to the data viz team. You use the run_sql function exclusively." - GUIDANCE_SCRUM_MASTER_SQL_NLQ_PROMPT = """ Is the following block of text a SQL Natural Language Query (NLQ)? Please rank from 1 to 5, where: 1: Definitely not NLQ @@ -40,10 +47,8 @@ {{/geneach}}] ```""" - INSIGHTS_FILE_REPORTER_PROMPT = "You're a data reporter. You write json data you receive directly into a file using the write_innovation_file function." - # unused prompts COMPLETION_PROMPT = "If everything looks good, respond with APPROVED" PRODUCT_MANAGER_PROMPT = ( @@ -54,17 +59,13 @@ JSON_REPORT_ANALYST_PROMPT = "Json Report Analyst. You exclusively use the write_json_file function on the report." YML_REPORT_ANALYST_PROMPT = "Yaml Report Analyst. You exclusively use the write_yml_file function on the report." - # ------------------------ BUILD AGENT TEAMS ------------------------ - +@agentops.record_function('build_data_eng_team') def build_data_eng_team(instruments: PostgresAgentInstruments): """ Build a team of agents that can generate, execute, and report an SQL query """ - - # create a set of agents with specific roles - # admin user proxy agent - takes in the prompt and manages the group chat user_proxy = autogen.UserProxyAgent( name="Admin", system_message=USER_PROXY_PROMPT, @@ -72,7 +73,6 @@ def build_data_eng_team(instruments: PostgresAgentInstruments): human_input_mode="NEVER", ) - # data engineer agent - generates the sql query data_engineer = autogen.AssistantAgent( name="Engineer", llm_config=agent_config.base_config, @@ -98,9 +98,8 @@ def build_data_eng_team(instruments: PostgresAgentInstruments): sr_data_analyst, ] - +@agentops.record_function('build_data_viz_team') def build_data_viz_team(instruments: PostgresAgentInstruments): - # admin user proxy agent - takes in the prompt and manages the group chat user_proxy = autogen.UserProxyAgent( name="Admin", system_message=USER_PROXY_PROMPT, @@ -108,7 +107,6 @@ def build_data_viz_team(instruments: PostgresAgentInstruments): human_input_mode="NEVER", ) - # text report analyst - writes a summary report of the results and saves them to a local text file text_report_analyst = autogen.AssistantAgent( name="Text_Report_Analyst", llm_config=agent_config.write_file_config, @@ -119,7 +117,6 @@ def build_data_viz_team(instruments: PostgresAgentInstruments): }, ) - # json report analyst - writes a summary report of the results and saves them to a local json file json_report_analyst = autogen.AssistantAgent( name="Json_Report_Analyst", llm_config=agent_config.write_json_file_config, @@ -147,7 +144,7 @@ def build_data_viz_team(instruments: PostgresAgentInstruments): yaml_report_analyst, ] - +@agentops.record_function('build_scrum_master_team') def build_scrum_master_team(instruments: PostgresAgentInstruments): user_proxy = autogen.UserProxyAgent( name="Admin", @@ -165,7 +162,7 @@ def build_scrum_master_team(instruments: PostgresAgentInstruments): return [user_proxy, scrum_agent] - +@agentops.record_function('build_insights_team') def build_insights_team(instruments: PostgresAgentInstruments): user_proxy = autogen.UserProxyAgent( name="Admin", @@ -193,10 +190,9 @@ def build_insights_team(instruments: PostgresAgentInstruments): return [user_proxy, insights_agent, insights_data_reporter] - # ------------------------ ORCHESTRATION ------------------------ - +@agentops.record_function('build_team_orchestrator') def build_team_orchestrator( team: str, agent_instruments: PostgresAgentInstruments, @@ -235,20 +231,19 @@ def build_team_orchestrator( raise Exception("Unknown team: " + team) - # ------------------------ CUSTOM AGENTS ------------------------ - class DefensiveScrumMasterAgent(autogen.ConversableAgent): """ Custom agent that uses the guidance function to determine if a message is a SQL NLQ """ - + @agentops.record_function('DefensiveScrumMasterAgent_init') def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # Register the new reply function for this specific agent self.register_reply(self, self.check_sql_nlq, position=0) + @agentops.record_function('DefensiveScrumMasterAgent_check_sql_nlq') def check_sql_nlq( self, messages: Optional[List[Dict]] = None, @@ -269,16 +264,16 @@ def check_sql_nlq( return True, rank - class InsightsAgent(autogen.ConversableAgent): """ Custom agent that uses the guidance function to generate insights in JSON format """ - + @agentops.record_function('InsightsAgent_init') def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.register_reply(self, self.generate_insights, position=0) + @agentops.record_function('InsightsAgent_generate_insights') def generate_insights( self, messages: Optional[List[Dict]] = None, @@ -287,3 +282,6 @@ def generate_insights( ): insights = guidance(DATA_INSIGHTS_GUIDANCE_PROMPT) return True, insights + +# End of program +agentops.end_session('Success') \ No newline at end of file diff --git a/postgres_da_ai_agent/main.py b/postgres_da_ai_agent/main.py index 9c84ab1..1ee8c1b 100644 --- a/postgres_da_ai_agent/main.py +++ b/postgres_da_ai_agent/main.py @@ -14,29 +14,31 @@ import dotenv import argparse import autogen +import agentops +from dotenv import load_dotenv from postgres_da_ai_agent.types import ConversationResult # ---------------- Your Environment Variables ---------------- -dotenv.load_dotenv() +load_dotenv() assert os.environ.get("DATABASE_URL"), "POSTGRES_CONNECTION_URL not found in .env file" -assert os.environ.get( - "OPENAI_API_KEY" -), "POSTGRES_CONNECTION_URL not found in .env file" +assert os.environ.get("OPENAI_API_KEY"), "OPENAI_API_KEY not found in .env file" +assert os.environ.get("AGENTOPS_API_KEY"), "AGENTOPS_API_KEY not found in .env file" +# Initialize AgentOps +agentops.init(os.getenv('AGENTOPS_API_KEY')) # ---------------- Constants ---------------- - DB_URL = os.environ.get("DATABASE_URL") OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY") POSTGRES_TABLE_DEFINITIONS_CAP_REF = "TABLE_DEFINITIONS" - +@agentops.record_function('main') def main(): # ---------------- Parse '--prompt' CLI Parameter ---------------- @@ -59,52 +61,64 @@ def main(): with PostgresAgentInstruments(DB_URL, session_id) as (agent_instruments, db): # ----------- Gate Team: Prevent bad prompts from running and burning your $$$ ------------- - gate_orchestrator = agents.build_team_orchestrator( - "scrum_master", - agent_instruments, - validate_results=lambda: (True, ""), - ) + @agentops.record_function('gate_team') + def run_gate_team(): + gate_orchestrator = agents.build_team_orchestrator( + "scrum_master", + agent_instruments, + validate_results=lambda: (True, ""), + ) - gate_orchestrator: ConversationResult = ( - gate_orchestrator.sequential_conversation(prompt) - ) + gate_orchestrator: ConversationResult = ( + gate_orchestrator.sequential_conversation(prompt) + ) - print("gate_orchestrator.last_message_str", gate_orchestrator.last_message_str) + print("gate_orchestrator.last_message_str", gate_orchestrator.last_message_str) - nlq_confidence = int(gate_orchestrator.last_message_str) + nlq_confidence = int(gate_orchestrator.last_message_str) - match nlq_confidence: - case (1 | 2): - print(f"❌ Gate Team Rejected - Confidence too low: {nlq_confidence}") - return - case (3 | 4 | 5): - print(f"✅ Gate Team Approved - Valid confidence: {nlq_confidence}") - case _: - print("❌ Gate Team Rejected - Invalid response") - return + match nlq_confidence: + case (1 | 2): + print(f"❌ Gate Team Rejected - Confidence too low: {nlq_confidence}") + return False + case (3 | 4 | 5): + print(f"✅ Gate Team Approved - Valid confidence: {nlq_confidence}") + return True + case _: + print("❌ Gate Team Rejected - Invalid response") + return False + + if not run_gate_team(): + return # -------- BUILD TABLE DEFINITIONS ----------- - map_table_name_to_table_def = db.get_table_definition_map_for_embeddings() + @agentops.record_function('build_table_definitions') + def build_table_definitions(): + map_table_name_to_table_def = db.get_table_definition_map_for_embeddings() - database_embedder = embeddings.DatabaseEmbedder() + database_embedder = embeddings.DatabaseEmbedder() - for name, table_def in map_table_name_to_table_def.items(): - database_embedder.add_table(name, table_def) + for name, table_def in map_table_name_to_table_def.items(): + database_embedder.add_table(name, table_def) - similar_tables = database_embedder.get_similar_tables(raw_prompt, n=5) + similar_tables = database_embedder.get_similar_tables(raw_prompt, n=5) - table_definitions = database_embedder.get_table_definitions_from_names( - similar_tables - ) + table_definitions = database_embedder.get_table_definitions_from_names( + similar_tables + ) - related_table_names = db.get_related_tables(similar_tables, n=3) + related_table_names = db.get_related_tables(similar_tables, n=3) - core_and_related_table_definitions = ( - database_embedder.get_table_definitions_from_names( - related_table_names + similar_tables + core_and_related_table_definitions = ( + database_embedder.get_table_definitions_from_names( + related_table_names + similar_tables + ) ) - ) + + return table_definitions, core_and_related_table_definitions + + table_definitions, core_and_related_table_definitions = build_table_definitions() prompt = llm.add_cap_ref( prompt, @@ -115,69 +129,82 @@ def main(): # ----------- Data Eng Team: Based on a sql table definitions and a prompt create an sql statement and execute it ------------- - data_eng_orchestrator = agents.build_team_orchestrator( - "data_eng", - agent_instruments, - validate_results=agent_instruments.validate_run_sql, - ) + @agentops.record_function('data_eng_team') + def run_data_eng_team(): + data_eng_orchestrator = agents.build_team_orchestrator( + "data_eng", + agent_instruments, + validate_results=agent_instruments.validate_run_sql, + ) - data_eng_conversation_result: ConversationResult = ( - data_eng_orchestrator.sequential_conversation(prompt) - ) + data_eng_conversation_result: ConversationResult = ( + data_eng_orchestrator.sequential_conversation(prompt) + ) - match data_eng_conversation_result: - case ConversationResult( - success=True, cost=data_eng_cost, tokens=data_eng_tokens - ): - print( - f"✅ Orchestrator was successful. Team: {data_eng_orchestrator.name}" - ) - print( - f"💰📊🤖 {data_eng_orchestrator.name} Cost: {data_eng_cost}, tokens: {data_eng_tokens}" - ) - case _: - print( - f"❌ Orchestrator failed. Team: {data_eng_orchestrator.name} Failed" - ) + match data_eng_conversation_result: + case ConversationResult( + success=True, cost=data_eng_cost, tokens=data_eng_tokens + ): + print( + f"✅ Orchestrator was successful. Team: {data_eng_orchestrator.name}" + ) + print( + f"💰📊🤖 {data_eng_orchestrator.name} Cost: {data_eng_cost}, tokens: {data_eng_tokens}" + ) + return True + case _: + print( + f"❌ Orchestrator failed. Team: {data_eng_orchestrator.name} Failed" + ) + return False + + run_data_eng_team() # ----------- Data Insights Team: Based on sql table definitions and a prompt generate novel insights ------------- - innovation_prompt = f"Given this database query: '{raw_prompt}'. Generate novel insights and new database queries to give business insights." + @agentops.record_function('data_insights_team') + def run_data_insights_team(): + innovation_prompt = f"Given this database query: '{raw_prompt}'. Generate novel insights and new database queries to give business insights." - insights_prompt = llm.add_cap_ref( - innovation_prompt, - f"Use these {POSTGRES_TABLE_DEFINITIONS_CAP_REF} to satisfy the database query.", - POSTGRES_TABLE_DEFINITIONS_CAP_REF, - core_and_related_table_definitions, - ) - - data_insights_orchestrator = agents.build_team_orchestrator( - "data_insights", - agent_instruments, - validate_results=agent_instruments.validate_innovation_files, - ) + insights_prompt = llm.add_cap_ref( + innovation_prompt, + f"Use these {POSTGRES_TABLE_DEFINITIONS_CAP_REF} to satisfy the database query.", + POSTGRES_TABLE_DEFINITIONS_CAP_REF, + core_and_related_table_definitions, + ) - data_insights_conversation_result: ConversationResult = ( - data_insights_orchestrator.round_robin_conversation( - insights_prompt, loops=1 + data_insights_orchestrator = agents.build_team_orchestrator( + "data_insights", + agent_instruments, + validate_results=agent_instruments.validate_innovation_files, ) - ) - match data_insights_conversation_result: - case ConversationResult( - success=True, cost=data_insights_cost, tokens=data_insights_tokens - ): - print( - f"✅ Orchestrator was successful. Team: {data_insights_orchestrator.name}" - ) - print( - f"💰📊🤖 {data_insights_orchestrator.name} Cost: {data_insights_cost}, tokens: {data_insights_tokens}" - ) - case _: - print( - f"❌ Orchestrator failed. Team: {data_insights_orchestrator.name} Failed" + data_insights_conversation_result: ConversationResult = ( + data_insights_orchestrator.round_robin_conversation( + insights_prompt, loops=1 ) + ) + match data_insights_conversation_result: + case ConversationResult( + success=True, cost=data_insights_cost, tokens=data_insights_tokens + ): + print( + f"✅ Orchestrator was successful. Team: {data_insights_orchestrator.name}" + ) + print( + f"💰📊🤖 {data_insights_orchestrator.name} Cost: {data_insights_cost}, tokens: {data_insights_tokens}" + ) + return True + case _: + print( + f"❌ Orchestrator failed. Team: {data_insights_orchestrator.name} Failed" + ) + return False + + run_data_insights_team() if __name__ == "__main__": main() + # End of program + agentops.end_session('Success') \ No newline at end of file