Clone search indexes for demo repos (#136)
* Clone search indexes for demo repos too

* linter

* remove unnecessary logs, handle demo edge cases, comment out flow understanding temporarily

* linter
dhirenmathur authored Oct 29, 2024
1 parent 7d5c651 commit 0c7c7a0
Showing 7 changed files with 75 additions and 74 deletions.
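In outline: before this commit, duplicating a demo repository copied only the Neo4j graph, so the cloned project's Postgres-backed search index stayed empty and search returned nothing for it. The diffs below wire search-index cloning into that flow. A paraphrased sketch of the resulting sequence, condensed from the hunks that follow (not verbatim source):

    # Paraphrased from the diffs below, not the verbatim source.
    await project_manager.duplicate_project(...)  # register the user's copy
    await project_manager.update_project_status(new_project_id, ProjectStatusEnum.SUBMITTED)
    await parsing_service.duplicate_graph(old_repo_id, new_project_id)
    #   duplicate_graph now begins with:
    #       await self.search_service.clone_search_indices(old_repo_id, new_project_id)
    #   and then copies nodes and relationships in fixed-size batches
    await project_manager.update_project_status(new_project_id, ProjectStatusEnum.READY)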
21 changes: 1 addition & 20 deletions app/modules/github/github_service.py
@@ -41,7 +41,6 @@ def __init__(self, db: Session):
         self.redis = Redis.from_url(config_provider.get_redis_url())

     def get_github_repo_details(self, repo_name: str) -> Tuple[Github, Dict, str]:
-        logger.info(f"Getting GitHub repo details for: {repo_name}")
         private_key = (
             "-----BEGIN RSA PRIVATE KEY-----\n"
             + config_provider.get_github_key()
@@ -60,14 +59,10 @@ def get_github_repo_details(self, repo_name: str) -> Tuple[Github, Dict, str]:
         }
         response = requests.get(url, headers=headers)
         if response.status_code != 200:
-            logger.info(
-                f"Failed to get installation ID for {repo_name}. Status code: {response.status_code}, Response: {response.text}"
-            )
             raise HTTPException(
                 status_code=400, detail=f"Failed to get installation ID for {repo_name}"
             )

-        logger.info(f"Successfully got installation ID for {repo_name}")
         app_auth = auth.get_installation_auth(response.json()["id"])
         github = Github(auth=app_auth)

@@ -157,11 +152,9 @@ def get_github_oauth_token(self, uid: str) -> str:

     async def get_repos_for_user(self, user_id: str):
         try:
-            logger.info(f"Getting repositories for user: {user_id}")
             user = self.db.query(User).filter(User.uid == user_id).first()
             if user is None:
                 raise HTTPException(status_code=404, detail="User not found")
-            logger.info(f"User found: {user}")

             firebase_uid = user.uid  # Assuming `uid` is the Firebase UID
             github_username = user.provider_username
@@ -223,10 +216,6 @@ async def get_repos_for_user(self, user_id: str):
                 elif account_type == "Organization" and account_login in org_logins:
                     user_installations.append(installation)

-            logger.info(
-                f"Filtered installations: {[inst['id'] for inst in user_installations]}"
-            )
-
             repos = []
             for installation in user_installations:
                 app_auth = auth.get_installation_auth(installation["id"])
@@ -293,28 +282,20 @@ def get_public_github_instance(cls):
         return Github(token)

     def get_repo(self, repo_name: str) -> Tuple[Github, Any]:
-        logger.info(f"Attempting to access repo: {repo_name}")
         try:
             # Try authenticated access first
-            logger.info(f"Trying authenticated access for repo: {repo_name}")
             github, _, _ = self.get_github_repo_details(repo_name)
             repo = github.get_repo(repo_name)
-            logger.info(
-                f"Successfully accessed repo {repo_name} with authenticated access"
-            )
+
             return github, repo
         except Exception as private_error:
             logger.info(
                 f"Failed to access private repo {repo_name}: {str(private_error)}"
             )
             # If authenticated access fails, try public access
             try:
-                logger.info(f"Trying public access for repo: {repo_name}")
                 github = self.get_public_github_instance()
                 repo = github.get_repo(repo_name)
-                logger.info(
-                    f"Successfully accessed repo {repo_name} with public access"
-                )
                 return github, repo
             except Exception as public_error:
                 logger.error(
3 changes: 0 additions & 3 deletions app/modules/intelligence/prompts/prompt_service.py
@@ -294,9 +294,6 @@ async def create_or_update_system_prompt(
                 )
             else:
                 prompt_to_return = existing_prompt
-                logger.info(
-                    "Existing prompt is kept as it is. No changes detected."
-                )
         else:
             # Create new prompt
             new_prompt = Prompt(
2 changes: 1 addition & 1 deletion
@@ -56,7 +56,7 @@ def run_tool(self, project_id: str, node_ids: List[str]) -> Dict[str, Any]:
         except Exception as e:
             logging.exception(f"An unexpected error occurred: {str(e)}")
             return {"error": f"An unexpected error occurred: {str(e)}"}

     async def run(self, project_id: str, node_ids: List[str]) -> Dict[str, Any]:
         """
         Run the tool to retrieve neighbors of the specified nodes.
55 changes: 31 additions & 24 deletions app/modules/parsing/graph_construction/parsing_controller.py
@@ -42,6 +42,11 @@ async def parse_directory(
         project = await project_manager.get_project_from_db(
             repo_name, repo_details.branch_name, user_id
         )
+        duplicate_project = True
+        if project and project.repo_name in demo_repos:
+            if project.status == ProjectStatusEnum.READY.value:
+                duplicate_project = False
+            project = None  # Reset project to None if it's a demo repo so that we don't parse it again

         if project:
             project_id = project.id
@@ -85,30 +90,32 @@ async def parse_directory(
                 new_project_id = str(uuid7())

                 if existing_project:
-                    # Register the new project with status SUBMITTED
-                    await project_manager.duplicate_project(
-                        repo_name,
-                        repo_details.branch_name,
-                        user_id,
-                        new_project_id,
-                        existing_project.properties,
-                        existing_project.commit_id,
-                    )
-                    await project_manager.update_project_status(
-                        new_project_id, ProjectStatusEnum.SUBMITTED
-                    )
-
-                    old_repo_id = await project_manager.get_demo_repo_id(repo_name)
-
-                    # Duplicate the graph under the new repo ID
-                    await parsing_service.duplicate_graph(
-                        old_repo_id, new_project_id
-                    )
-
-                    # Update the project status to READY after copying
-                    await project_manager.update_project_status(
-                        new_project_id, ProjectStatusEnum.READY
-                    )
+                    if duplicate_project:
+                        await project_manager.duplicate_project(
+                            repo_name,
+                            repo_details.branch_name,
+                            user_id,
+                            new_project_id,
+                            existing_project.properties,
+                            existing_project.commit_id,
+                        )
+                        await project_manager.update_project_status(
+                            new_project_id, ProjectStatusEnum.SUBMITTED
+                        )
+
+                        old_repo_id = await project_manager.get_demo_repo_id(
+                            repo_name
+                        )
+
+                        # Duplicate the graph under the new repo ID
+                        await parsing_service.duplicate_graph(
+                            old_repo_id, new_project_id
+                        )
+
+                        # Update the project status to READY after copying
+                        await project_manager.update_project_status(
+                            new_project_id, ProjectStatusEnum.READY
+                        )

                 return {
                     "project_id": new_project_id,
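Condensed, the demo-repo handling above reduces to this control flow (a paraphrase of the two hunks, with unrelated branches elided):

    # Paraphrase of the hunks above; "..." elides unrelated branches.
    duplicate_project = True
    if project and project.repo_name in demo_repos:
        if project.status == ProjectStatusEnum.READY.value:
            duplicate_project = False  # the user already has a READY copy
        project = None  # skip the regular re-parse path for demo repos
    ...
    if existing_project and duplicate_project:
        # register the copy, mark SUBMITTED, clone graph + search indices, mark READY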
1 change: 1 addition & 0 deletions app/modules/parsing/graph_construction/parsing_service.py
@@ -218,6 +218,7 @@ async def analyze_directory(
         )

     async def duplicate_graph(self, old_repo_id: str, new_repo_id: str):
+        await self.search_service.clone_search_indices(old_repo_id, new_repo_id)
         node_batch_size = 3000  # Fixed batch size for nodes
         relationship_batch_size = 3000  # Fixed batch size for relationships
         try:
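The body of duplicate_graph is collapsed in this view; only the fixed batch sizes are visible. As a rough illustration of what a batched Neo4j node copy of this shape can look like with the official Python driver (the repoId property name and the Cypher text are assumptions for the sketch, not the repository's actual query):

    # Hypothetical sketch of one batched node-copy pass. The `repoId`
    # property and the Cypher are assumptions, not this repo's code;
    # labels and relationships are omitted for brevity.
    from neo4j import GraphDatabase

    def copy_nodes_in_batches(driver, old_repo_id, new_repo_id, batch_size=3000):
        with driver.session() as session:
            offset = 0
            while True:
                result = session.run(
                    "MATCH (n {repoId: $old_id}) RETURN n SKIP $offset LIMIT $limit",
                    old_id=old_repo_id, offset=offset, limit=batch_size,
                )
                batch = [dict(record["n"]) for record in result]  # node properties only
                if not batch:
                    break
                session.run(
                    "UNWIND $rows AS row CREATE (m) SET m = row, m.repoId = $new_id",
                    rows=batch, new_id=new_repo_id,
                )
                offset += batch_size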
29 changes: 15 additions & 14 deletions app/modules/parsing/knowledge_graph/inference_service.py
@@ -81,7 +81,7 @@ def num_tokens_from_string(self, string: str, model: str = "gpt-4") -> int:
         return len(encoding.encode(string, disallowed_special=set()))

     def fetch_graph(self, repo_id: str) -> List[Dict]:
-        batch_size = 100  # Define the batch size
+        batch_size = 500
         all_nodes = []
         with self.driver.session() as session:
             offset = 0
@@ -99,6 +99,7 @@ def fetch_graph(self, repo_id: str) -> List[Dict]:
                     break
                 all_nodes.extend(batch)
                 offset += batch_size
+        logger.info(f"DEBUGNEO4J: Fetched {len(all_nodes)} nodes for repo {repo_id}")
         return all_nodes

     def get_entry_points(self, repo_id: str) -> List[str]:
@@ -432,15 +433,15 @@ async def generate_docstrings(self, repo_id: str) -> Dict[str, DocstringResponse]:
             f"DEBUGNEO4J: After get entry points, Repo ID: {repo_id}, Entry points: {len(entry_points)}"
         )
         self.log_graph_stats(repo_id)
-        entry_points_neighbors = {}
-        for entry_point in entry_points:
-            neighbors = self.get_neighbours(entry_point, repo_id)
-            entry_points_neighbors[entry_point] = neighbors
-
-        logger.info(
-            f"DEBUGNEO4J: After get neighbours, Repo ID: {repo_id}, Entry points neighbors: {len(entry_points_neighbors)}"
-        )
-        self.log_graph_stats(repo_id)
+        # entry_points_neighbors = {}
+        # for entry_point in entry_points:
+        #     neighbors = self.get_neighbours(entry_point, repo_id)
+        #     entry_points_neighbors[entry_point] = neighbors
+
+        # logger.info(
+        #     f"DEBUGNEO4J: After get neighbours, Repo ID: {repo_id}, Entry points neighbors: {len(entry_points_neighbors)}"
+        # )
+        # self.log_graph_stats(repo_id)
         batches = self.batch_nodes(nodes)
         all_docstrings = {"docstrings": []}

@@ -464,10 +465,10 @@ async def process_batch(batch):
                 all_docstrings["docstrings"] + result.docstrings
             )

-        updated_docstrings = await self.generate_docstrings_for_entry_points(
-            all_docstrings, entry_points_neighbors
-        )
-
+        # updated_docstrings = await self.generate_docstrings_for_entry_points(
+        #     all_docstrings, entry_points_neighbors
+        # )
+        updated_docstrings = all_docstrings
         return updated_docstrings

     async def generate_response(
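For context, fetch_graph pages through a repo's nodes with the larger batch size above; the Cypher itself is collapsed in this view, but the visible offset bookkeeping implies a SKIP/LIMIT loop of roughly this shape (the query text is an inference, not the repository's code):

    # Inferred shape of the paginated fetch; the actual query is collapsed above.
    batch_size = 500
    all_nodes, offset = [], 0
    with self.driver.session() as session:
        while True:
            result = session.run(
                "MATCH (n {repoId: $repo_id}) RETURN n SKIP $offset LIMIT $limit",
                repo_id=repo_id, offset=offset, limit=batch_size,
            )
            batch = [dict(record["n"]) for record in result]
            if not batch:
                break
            all_nodes.extend(batch)
            offset += batch_size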
38 changes: 26 additions & 12 deletions app/modules/search/search_service.py
@@ -10,18 +10,6 @@ class SearchService:
     def __init__(self, db: Session):
         self.db = db

-    async def create_search_index(self, project_id: str, node: Dict):
-        # Create index entry for the node
-        self.db.add(
-            SearchIndex(
-                project_id=project_id,
-                node_id=node["node_id"],
-                name=node.get("name", ""),
-                file_path=node.get("file_path", ""),
-                content=f"{node.get('name', '')} {node.get('file_path', '')}",
-            )
-        )
-
     async def commit_indices(self):
         self.db.commit()

@@ -125,3 +113,29 @@ def delete_project_index(self, project_id: str):
     async def bulk_create_search_indices(self, nodes: List[Dict]):
         # Create index entries for all nodes in bulk
         self.db.bulk_insert_mappings(SearchIndex, nodes)
+
+    async def clone_search_indices(self, input_project_id: str, output_project_id: str):
+        """Clone all search indices from input project to output project."""
+        # Query all search indices for the input project
+        source_indices = (
+            self.db.query(SearchIndex)
+            .filter(SearchIndex.project_id == input_project_id)
+            .all()
+        )
+
+        # Prepare bulk insert data
+        cloned_indices = [
+            {
+                "project_id": output_project_id,
+                "node_id": index.node_id,
+                "name": index.name,
+                "file_path": index.file_path,
+                "content": index.content,
+            }
+            for index in source_indices
+        ]
+
+        # Bulk insert the cloned indices if there are any
+        if cloned_indices:
+            self.db.bulk_insert_mappings(SearchIndex, cloned_indices)
+            await self.commit_indices()
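A note on the design: clone_search_indices materializes every source row in Python before bulk-inserting the copies, which keeps the logic in the service layer. For very large projects the same copy could be pushed down to the database as a single INSERT ... FROM SELECT; a hedged alternative sketch using SQLAlchemy's insert().from_select() (not what this commit does):

    # Alternative, set-based copy in one statement. Shown only as a
    # design comparison; the commit uses the Python-side copy above.
    from sqlalchemy import insert, literal, select

    def clone_search_indices_sql(db, input_project_id: str, output_project_id: str):
        src = select(
            literal(output_project_id).label("project_id"),
            SearchIndex.node_id,
            SearchIndex.name,
            SearchIndex.file_path,
            SearchIndex.content,
        ).where(SearchIndex.project_id == input_project_id)
        db.execute(insert(SearchIndex).from_select(
            ["project_id", "node_id", "name", "file_path", "content"], src
        ))
        db.commit()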
