Skip to content

Commit

Permalink
Tool Standardisation and related refactors
Browse files Browse the repository at this point in the history
  • Loading branch information
vineetshar committed Oct 29, 2024
1 parent 4041d9e commit 0ea6ced
Show file tree
Hide file tree
Showing 12 changed files with 39 additions and 105 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -302,20 +302,15 @@ def get_parameters() -> List[ToolParameter]:
),
]

async def arun(self, project_id):
return asyncio.run(self.get_code_changes(project_id))

def run(self, project_id: str) -> str:
return self.get_code_changes(project_id)

async def arun(self, project_id: str) -> str:
return await self.get_code_changes(project_id)

def get_change_detection_tool(user_id: str) -> Tool:
"""
Get a list of LangChain Tool objects for use in agents.
"""
change_detection_tool = ChangeDetectionTool(next(get_db()), user_id)
return StructuredTool.from_function(
func=change_detection_tool.run,
coroutine=change_detection_tool.arun,
name="Get code changes",
description="""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from langchain_core.tools import StructuredTool
from pydantic import BaseModel
from sqlalchemy.orm import Session
import asyncio

from app.modules.github.github_service import GithubService

Expand All @@ -15,19 +16,15 @@ def __init__(self, db: Session):

def fetch_repo_structure(self, repo_id: str) -> str:
return self.github_service.get_project_structure(repo_id)

async def run(self, repo_id: str) -> str:
return self.fetch_repo_structure(repo_id)

def run_tool(self, repo_id: str) -> str:
return self.fetch_repo_structure(repo_id)

async def arun(self, repo_id: str) -> str:
    """Asynchronous entry point used as the tool's coroutine.

    Args:
        repo_id: Identifier of the repository whose structure is requested.

    Returns:
        Whatever ``fetch_repo_structure`` returns for ``repo_id``.
    """
    # fetch_repo_structure is a plain synchronous method that returns the
    # result directly; awaiting its return value would raise TypeError,
    # so it is called without ``await``.
    return self.fetch_repo_structure(repo_id)


def get_code_file_structure_tool(db: Session) -> StructuredTool:
return StructuredTool(
name="get_code_file_structure",
description="Retrieve the hierarchical file structure of a specified repository.",
coroutine=RepoStructureService(db).run,
func=RepoStructureService(db).run_tool,
coroutine=RepoStructureService(db).arun,
args_schema=RepoStructureRequest,
)
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,10 @@ def _create_neo4j_driver(self) -> GraphDatabase.driver:
auth=(neo4j_config["username"], neo4j_config["password"]),
)

def run(self, repo_id: str, node_name: str) -> Dict[str, Any]:
project = asyncio.run(
ProjectService(self.sql_db).get_project_repo_details_from_db(
repo_id, self.user_id
)
)
async def arun(self, repo_id: str, node_name: str) -> Dict[str, Any]:
project = await ProjectService(self.sql_db).get_project_repo_details_from_db(
repo_id, self.user_id)

if not project:
raise ValueError(
f"Project with ID '{repo_id}' not found in database for user '{self.user_id}'"
Expand Down Expand Up @@ -126,9 +124,6 @@ def __del__(self):
if hasattr(self, "neo4j_driver"):
self.neo4j_driver.close()

async def arun(self, repo_id: str, node_name: str) -> Dict[str, Any]:
return self.run(repo_id, node_name)

@staticmethod
def get_parameters() -> List[ToolParameter]:
return [
Expand All @@ -151,7 +146,6 @@ def get_code_from_node_name_tool(sql_db: Session, user_id: str) -> Tool:
tool_instance = GetCodeFromNodeNameTool(sql_db, user_id)
return StructuredTool.from_function(
coroutine=tool_instance.arun,
func=tool_instance.run,
name="Get Code From Node Name",
description="Retrieves code for a specific node in a repository given its node name",
)
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import logging
from typing import Any, Dict, List, Optional

Expand Down Expand Up @@ -36,7 +37,7 @@ def _create_neo4j_driver(self) -> GraphDatabase.driver:
auth=(neo4j_config["username"], neo4j_config["password"]),
)

def run(self, repo_id: str, node_id: str) -> Dict[str, Any]:
async def arun(self, repo_id: str, node_id: str) -> Dict[str, Any]:
"""
Run the tool to retrieve the code graph.
Expand Down Expand Up @@ -185,10 +186,6 @@ def __del__(self):
if hasattr(self, "neo4j_driver"):
self.neo4j_driver.close()

async def arun(self, repo_id: str, node_id: str) -> Dict[str, Any]:
"""Asynchronous version of the run method."""
return self.run(repo_id, node_id)

@staticmethod
def get_parameters() -> List[ToolParameter]:
return [
Expand All @@ -210,7 +207,6 @@ def get_parameters() -> List[ToolParameter]:
def get_code_graph_from_node_id_tool(sql_db: Session) -> Tool:
tool_instance = GetCodeGraphFromNodeIdTool(sql_db)
return StructuredTool.from_function(
func=tool_instance.run,
coroutine=tool_instance.arun,
name="Get Code Graph From Node ID",
description="Retrieves a code graph for a specific node in a repository given its node ID",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import logging
from typing import Any, Dict, List, Optional

Expand Down Expand Up @@ -36,7 +37,7 @@ def _create_neo4j_driver(self) -> GraphDatabase.driver:
auth=(neo4j_config["username"], neo4j_config["password"]),
)

def run(self, repo_id: str, node_name: str) -> Dict[str, Any]:
async def arun(self, repo_id: str, node_name: str) -> Dict[str, Any]:
"""
Run the tool to retrieve the code graph.
Expand Down Expand Up @@ -187,10 +188,6 @@ def __del__(self):
if hasattr(self, "neo4j_driver"):
self.neo4j_driver.close()

async def arun(self, repo_id: str, node_name: str) -> Dict[str, Any]:
"""Asynchronous version of the run method."""
return self.run(repo_id, node_name)

@staticmethod
def get_parameters() -> List[ToolParameter]:
return [
Expand All @@ -213,7 +210,6 @@ def get_code_graph_from_node_name_tool(sql_db: Session) -> Tool:
tool_instance = GetCodeGraphFromNodeNameTool(sql_db)
return StructuredTool.from_function(
coroutine=tool_instance.arun,
func=tool_instance.run,
name="Get Code Graph From Node Name",
description="Retrieves a code graph for a specific node in a repository given its node name",
)
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from langchain_core.tools import StructuredTool, Tool
from neo4j import GraphDatabase
from sqlalchemy.orm import Session
import asyncio

from app.core.config_provider import config_provider

Expand Down Expand Up @@ -34,7 +35,7 @@ def _create_neo4j_driver(self) -> GraphDatabase.driver:
auth=(neo4j_config["username"], neo4j_config["password"]),
)

def run_tool(self, project_id: str, node_ids: List[str]) -> Dict[str, Any]:
async def arun(self, project_id: str, node_ids: List[str]) -> Dict[str, Any]:
"""
Run the tool to retrieve neighbors of the specified nodes.
Expand All @@ -56,20 +57,7 @@ def run_tool(self, project_id: str, node_ids: List[str]) -> Dict[str, Any]:
except Exception as e:
logging.exception(f"An unexpected error occurred: {str(e)}")
return {"error": f"An unexpected error occurred: {str(e)}"}

async def run(self, project_id: str, node_ids: List[str]) -> Dict[str, Any]:
"""
Run the tool to retrieve neighbors of the specified nodes.
Args:
project_id (str): Project ID.
node_ids (List[str]): List of node IDs to retrieve neighbors for. Should contain atleast one node ID.
Returns:
Dict[str, Any]: Neighbor data or error message.
"""
return self.run_tool(project_id, node_ids)


def _get_neighbors(
self, project_id: str, node_ids: List[str]
) -> Optional[List[Dict[str, Any]]]:
Expand Down Expand Up @@ -112,8 +100,7 @@ def __del__(self):
def get_node_neighbours_from_node_id_tool(sql_db: Session) -> Tool:
tool_instance = GetNodeNeighboursFromNodeIdTool(sql_db)
return StructuredTool.from_function(
coroutine=tool_instance.run,
func=tool_instance.run_tool,
coroutine=tool_instance.arun,
name="Get Node Neighbours From Node ID",
description="Retrieves inbound and outbound neighbors of a specific node in a repository given its node ID. This is helpful to find which functions are called by a specific function and which functions are calling the specific function. Works best with Pythoon, JS and TS code.",
)
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,9 @@ async def process_query(query_request: QueryRequest) -> List[QueryResponse]:
results = await asyncio.gather(*tasks)

return results

async def arun(
self, queries: List[str], project_id: str, node_ids: List[str] = []
) -> Dict[str, str]:
"""Asynchronous version of the run method."""
return self.run(queries=queries, project_id=project_id, node_ids=node_ids)

def run(
self, queries: List[str], project_id: str, node_ids: List[str] = []
) -> Dict[str, str]:
"""
Query the code knowledge graph using multiple natural language questions.
Expand All @@ -114,11 +108,8 @@ def run(
Returns:
- Dict[str, str]: A dictionary where keys are the original queries and values are the corresponding responses.
"""
project = asyncio.run(
ProjectService(self.sql_db).get_project_repo_details_from_db(
project_id, self.user_id
)
)
project = await ProjectService(self.sql_db).get_project_repo_details_from_db(
project_id, self.user_id )
if not project:
raise ValueError(
f"Project with ID '{project_id}' not found in database for user '{self.user_id}'"
Expand All @@ -128,12 +119,11 @@ def run(
QueryRequest(query=query, project_id=project_id, node_ids=node_ids)
for query in queries
]
return asyncio.run(self.ask_multiple_knowledge_graph_queries(query_list))
return await self.ask_multiple_knowledge_graph_queries(query_list)


def get_ask_knowledge_graph_queries_tool(sql_db, user_id) -> StructuredTool:
return StructuredTool.from_function(
func=KnowledgeGraphQueryTool(sql_db, user_id).run,
coroutine=KnowledgeGraphQueryTool(sql_db, user_id).arun,
name="Ask Knowledge Graph Queries",
description="""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,8 @@ def _create_neo4j_driver(self) -> GraphDatabase.driver:
auth=(neo4j_config["username"], neo4j_config["password"]),
)

def arun(self, repo_id: str, node_ids: List[str]) -> Dict[str, Any]:
"""Asynchronous version of the run method."""
return self.run(repo_id, node_ids)

def run(self, repo_id: str, node_ids: List[str]) -> Dict[str, Any]:
return asyncio.run(self.run_multiple(repo_id, node_ids))
async def arun(self, repo_id: str, node_ids: List[str]) -> Dict[str, Any]:
return await self.run_multiple(repo_id, node_ids)

async def run_multiple(self, repo_id: str, node_ids: List[str]) -> Dict[str, Any]:
try:
Expand Down Expand Up @@ -159,7 +155,6 @@ def get_code_from_multiple_node_ids_tool(
) -> StructuredTool:
tool_instance = GetCodeFromMultipleNodeIdsTool(sql_db, user_id)
return StructuredTool.from_function(
func=tool_instance.run,
coroutine=tool_instance.arun,
name="Get Code and docstring From Multiple Node IDs",
description="""Retrieves code and docstring for multiple node ids in a repository given their node IDs
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import logging
from typing import Any, Dict, List

Expand Down Expand Up @@ -36,10 +37,7 @@ def _create_neo4j_driver(self) -> GraphDatabase.driver:
)

async def arun(self, repo_id: str, node_id: str) -> Dict[str, Any]:
"""Asynchronous version of the run method."""
return self.run(repo_id, node_id)

def run(self, repo_id: str, node_id: str) -> Dict[str, Any]:
"""Synchronous version that handles the core logic"""
try:
node_data = self._get_node_data(repo_id, node_id)
if not node_data:
Expand Down Expand Up @@ -139,7 +137,6 @@ def get_parameters() -> List[ToolParameter]:
def get_code_from_node_id_tool(sql_db: Session, user_id: str) -> StructuredTool:
tool_instance = GetCodeFromNodeIdTool(sql_db, user_id)
return StructuredTool.from_function(
func=tool_instance.run,
coroutine=tool_instance.arun,
name="Get Code and docstring From Node ID",
description="""Retrieves code and docstring for a specific node id in a repository given its node ID
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,27 +75,18 @@ async def find_node_from_probable_name(
for name in probable_node_names
]
return await asyncio.gather(*tasks)


async def arun(
self, project_id: str, probable_node_names: List[str]
) -> List[Dict[str, Any]]:
return await self.run(project_id, probable_node_names)

def run(
self, project_id: str, probable_node_names: List[str]
) -> List[Dict[str, Any]]:
project = asyncio.run(
ProjectService(self.sql_db).get_project_repo_details_from_db(
project_id, self.user_id
)
)
project = await ProjectService(self.sql_db).get_project_repo_details_from_db(
project_id, self.user_id)
if not project:
raise ValueError(
f"Project with ID '{project_id}' not found in database for user '{self.user_id}'"
)
return asyncio.run(
self.find_node_from_probable_name(project_id, probable_node_names)
)
return await self.find_node_from_probable_name(project_id, probable_node_names)

async def async_code_from_node(self, repo_id: str, node_id: str) -> Dict[str, Any]:
return self.code_from_node(repo_id, node_id)
Expand Down Expand Up @@ -200,7 +191,6 @@ def get_code_from_probable_node_name_tool(
) -> StructuredTool:
tool_instance = GetCodeFromProbableNodeNameTool(sql_db, user_id)
return StructuredTool.from_function(
func=tool_instance.run,
coroutine=tool_instance.arun,
name="Get Code and docstring From Probable Node Name",
description="""Retrieves code and docstring for the closest node name in a repository. Node names are in the format of 'file_path:function_name' or 'file_path:class_name' or 'file_path',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,6 @@ def __init__(self, sql_db, user_id):
self.user_id = user_id

async def arun(self, tags: List[str], project_id: str) -> str:
"""Asynchronous version of the run method."""
return self.run(tags, project_id)

def run(self, tags: List[str], project_id: str) -> str:
"""
Get nodes from the knowledge graph based on the provided tags.
Inputs for the fetch_nodes method:
Expand All @@ -43,11 +39,8 @@ def run(self, tags: List[str], project_id: str) -> str:
* EXTERNAL_SERVICE: Does the code make HTTP requests to external services? Check for HTTP client usage or request handling.
- project_id (str): The ID of the project being evaluated, this is a UUID.
"""
project = asyncio.run(
ProjectService(self.sql_db).get_project_repo_details_from_db(
project_id, self.user_id
)
)
project = await ProjectService(self.sql_db).get_project_repo_details_from_db(
project_id, self.user_id)
if not project:
raise ValueError(
f"Project with ID '{project_id}' not found in database for user '{self.user_id}'"
Expand Down Expand Up @@ -86,7 +79,6 @@ def get_parameters() -> List[ToolParameter]:

def get_nodes_from_tags_tool(sql_db, user_id) -> StructuredTool:
return StructuredTool.from_function(
func=GetNodesFromTags(sql_db, user_id).run,
coroutine=GetNodesFromTags(sql_db, user_id).arun,
name="Get Nodes from Tags",
description="""
Expand Down
7 changes: 6 additions & 1 deletion app/modules/intelligence/tools/tool_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,12 @@ async def run_tool(self, tool_id: str, params: Dict[str, Any]) -> Dict[str, Any]
tool = self.tools.get(tool_id)
if not tool:
raise ValueError(f"Invalid tool_id: {tool_id}")
return await tool.run(**params)

# If the tool has an arun method, use it
if hasattr(tool, 'arun'):
return await tool.arun(**params)
else:
raise ValueError(f"Tool {tool.__class__.__name__} has no arun method")

def list_tools(self) -> List[ToolInfo]:
return [
Expand Down

0 comments on commit 0ea6ced

Please sign in to comment.