Clone search indexes for demo repos #136

Merged: 4 commits, Oct 29, 2024
21 changes: 1 addition & 20 deletions app/modules/github/github_service.py
@@ -41,7 +41,6 @@ def __init__(self, db: Session):
self.redis = Redis.from_url(config_provider.get_redis_url())

def get_github_repo_details(self, repo_name: str) -> Tuple[Github, Dict, str]:
logger.info(f"Getting GitHub repo details for: {repo_name}")
private_key = (
"-----BEGIN RSA PRIVATE KEY-----\n"
+ config_provider.get_github_key()
@@ -60,14 +59,10 @@ def get_github_repo_details(self, repo_name: str) -> Tuple[Github, Dict, str]:
}
response = requests.get(url, headers=headers)
if response.status_code != 200:
logger.info(
f"Failed to get installation ID for {repo_name}. Status code: {response.status_code}, Response: {response.text}"
)
raise HTTPException(
status_code=400, detail=f"Failed to get installation ID for {repo_name}"
)

logger.info(f"Successfully got installation ID for {repo_name}")
app_auth = auth.get_installation_auth(response.json()["id"])
github = Github(auth=app_auth)

@@ -157,11 +152,9 @@ def get_github_oauth_token(self, uid: str) -> str:

async def get_repos_for_user(self, user_id: str):
try:
logger.info(f"Getting repositories for user: {user_id}")
user = self.db.query(User).filter(User.uid == user_id).first()
if user is None:
raise HTTPException(status_code=404, detail="User not found")
logger.info(f"User found: {user}")

firebase_uid = user.uid # Assuming `uid` is the Firebase UID
github_username = user.provider_username
@@ -223,10 +216,6 @@ async def get_repos_for_user(self, user_id: str):
elif account_type == "Organization" and account_login in org_logins:
user_installations.append(installation)

logger.info(
f"Filtered installations: {[inst['id'] for inst in user_installations]}"
)

repos = []
for installation in user_installations:
app_auth = auth.get_installation_auth(installation["id"])
@@ -293,28 +282,20 @@ def get_public_github_instance(cls):
return Github(token)

def get_repo(self, repo_name: str) -> Tuple[Github, Any]:
logger.info(f"Attempting to access repo: {repo_name}")
try:
# Try authenticated access first
logger.info(f"Trying authenticated access for repo: {repo_name}")
github, _, _ = self.get_github_repo_details(repo_name)
repo = github.get_repo(repo_name)
logger.info(
f"Successfully accessed repo {repo_name} with authenticated access"
)

return github, repo
except Exception as private_error:
logger.info(
f"Failed to access private repo {repo_name}: {str(private_error)}"
)
# If authenticated access fails, try public access
try:
logger.info(f"Trying public access for repo: {repo_name}")
github = self.get_public_github_instance()
repo = github.get_repo(repo_name)
logger.info(
f"Successfully accessed repo {repo_name} with public access"
)
return github, repo
except Exception as public_error:
logger.error(
3 changes: 0 additions & 3 deletions app/modules/intelligence/prompts/prompt_service.py
@@ -294,9 +294,6 @@ async def create_or_update_system_prompt(
)
else:
prompt_to_return = existing_prompt
logger.info(
"Existing prompt is kept as it is. No changes detected."
)
else:
# Create new prompt
new_prompt = Prompt(
@@ -56,7 +56,7 @@ def run_tool(self, project_id: str, node_ids: List[str]) -> Dict[str, Any]:
except Exception as e:
logging.exception(f"An unexpected error occurred: {str(e)}")
return {"error": f"An unexpected error occurred: {str(e)}"}

async def run(self, project_id: str, node_ids: List[str]) -> Dict[str, Any]:
"""
Run the tool to retrieve neighbors of the specified nodes.
55 changes: 31 additions & 24 deletions app/modules/parsing/graph_construction/parsing_controller.py
@@ -42,6 +42,11 @@ async def parse_directory(
project = await project_manager.get_project_from_db(
repo_name, repo_details.branch_name, user_id
)
duplicate_project = True
if project and project.repo_name in demo_repos:
if project.status == ProjectStatusEnum.READY.value:
duplicate_project = False
project = None # Reset project to None if it's a demo repo so that we don't parse it again

if project:
project_id = project.id
@@ -85,30 +90,32 @@ async def parse_directory(
new_project_id = str(uuid7())

if existing_project:
# Register the new project with status SUBMITTED
await project_manager.duplicate_project(
repo_name,
repo_details.branch_name,
user_id,
new_project_id,
existing_project.properties,
existing_project.commit_id,
)
await project_manager.update_project_status(
new_project_id, ProjectStatusEnum.SUBMITTED
)

old_repo_id = await project_manager.get_demo_repo_id(repo_name)

# Duplicate the graph under the new repo ID
await parsing_service.duplicate_graph(
old_repo_id, new_project_id
)

# Update the project status to READY after copying
await project_manager.update_project_status(
new_project_id, ProjectStatusEnum.READY
)
if duplicate_project:
await project_manager.duplicate_project(
repo_name,
repo_details.branch_name,
user_id,
new_project_id,
existing_project.properties,
existing_project.commit_id,
)
await project_manager.update_project_status(
new_project_id, ProjectStatusEnum.SUBMITTED
)

old_repo_id = await project_manager.get_demo_repo_id(
repo_name
)

# Duplicate the graph under the new repo ID
await parsing_service.duplicate_graph(
old_repo_id, new_project_id
)

# Update the project status to READY after copying
await project_manager.update_project_status(
new_project_id, ProjectStatusEnum.READY
)

return {
"project_id": new_project_id,
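The controller change above is easiest to read as a fast path for demo repositories: when an already-parsed copy of a demo repo exists and the duplicate_project flag allows it, the controller duplicates the project record, its graph, and (with this PR) its search indices under a new project ID instead of re-parsing the repository. Below is a condensed sketch of that path, not the literal controller code; it reuses the project_manager, parsing_service, and ProjectStatusEnum names visible in the diff and omits the error handling and response shaping around them.

# Condensed sketch of the demo-repo fast path; names mirror the diff but the
# wiring is simplified and hypothetical.
async def duplicate_demo_project(
    project_manager,
    parsing_service,
    repo_name: str,
    branch_name: str,
    user_id: str,
    existing_project,
    new_project_id: str,
) -> dict:
    # Register a copy of the demo project for this user, marked SUBMITTED.
    await project_manager.duplicate_project(
        repo_name,
        branch_name,
        user_id,
        new_project_id,
        existing_project.properties,
        existing_project.commit_id,
    )
    await project_manager.update_project_status(
        new_project_id, ProjectStatusEnum.SUBMITTED
    )

    # Copy the already-parsed graph (and, via this PR, its search indices)
    # under the new project ID instead of re-parsing the repository.
    old_repo_id = await project_manager.get_demo_repo_id(repo_name)
    await parsing_service.duplicate_graph(old_repo_id, new_project_id)

    # Everything was copied, so the new project is immediately READY.
    await project_manager.update_project_status(
        new_project_id, ProjectStatusEnum.READY
    )
    return {"project_id": new_project_id}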
1 change: 1 addition & 0 deletions app/modules/parsing/graph_construction/parsing_service.py
@@ -218,6 +218,7 @@ async def analyze_directory(
)

async def duplicate_graph(self, old_repo_id: str, new_repo_id: str):
await self.search_service.clone_search_indices(old_repo_id, new_repo_id)
node_batch_size = 3000 # Fixed batch size for nodes
relationship_batch_size = 3000 # Fixed batch size for relationships
try:
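Since duplicate_graph now also clones the Postgres-backed search indices, a simple sanity check is to compare SearchIndex row counts between the demo project and the freshly duplicated one after duplicate_graph returns. The check below is hypothetical and not part of this PR; it assumes a SQLAlchemy session db and the SearchIndex model from search_service.py, whose import path is not shown in this diff and is therefore omitted.

def assert_search_indices_cloned(db, demo_project_id: str, new_project_id: str) -> None:
    # Count index rows for the source (demo) project and the cloned project.
    source_count = (
        db.query(SearchIndex)
        .filter(SearchIndex.project_id == demo_project_id)
        .count()
    )
    cloned_count = (
        db.query(SearchIndex)
        .filter(SearchIndex.project_id == new_project_id)
        .count()
    )
    assert cloned_count == source_count, (
        f"expected {source_count} cloned search index rows, found {cloned_count}"
    )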
29 changes: 15 additions & 14 deletions app/modules/parsing/knowledge_graph/inference_service.py
@@ -81,7 +81,7 @@ def num_tokens_from_string(self, string: str, model: str = "gpt-4") -> int:
return len(encoding.encode(string, disallowed_special=set()))

def fetch_graph(self, repo_id: str) -> List[Dict]:
batch_size = 100 # Define the batch size
batch_size = 500
all_nodes = []
with self.driver.session() as session:
offset = 0
@@ -99,6 +99,7 @@ def fetch_graph(self, repo_id: str) -> List[Dict]:
break
all_nodes.extend(batch)
offset += batch_size
logger.info(f"DEBUGNEO4J: Fetched {len(all_nodes)} nodes for repo {repo_id}")
return all_nodes

def get_entry_points(self, repo_id: str) -> List[str]:
@@ -432,15 +433,15 @@ async def generate_docstrings(self, repo_id: str) -> Dict[str, DocstringResponse
f"DEBUGNEO4J: After get entry points, Repo ID: {repo_id}, Entry points: {len(entry_points)}"
)
self.log_graph_stats(repo_id)
entry_points_neighbors = {}
for entry_point in entry_points:
neighbors = self.get_neighbours(entry_point, repo_id)
entry_points_neighbors[entry_point] = neighbors

logger.info(
f"DEBUGNEO4J: After get neighbours, Repo ID: {repo_id}, Entry points neighbors: {len(entry_points_neighbors)}"
)
self.log_graph_stats(repo_id)
# entry_points_neighbors = {}
# for entry_point in entry_points:
# neighbors = self.get_neighbours(entry_point, repo_id)
# entry_points_neighbors[entry_point] = neighbors

# logger.info(
# f"DEBUGNEO4J: After get neighbours, Repo ID: {repo_id}, Entry points neighbors: {len(entry_points_neighbors)}"
# )
# self.log_graph_stats(repo_id)
batches = self.batch_nodes(nodes)
all_docstrings = {"docstrings": []}

@@ -464,10 +465,10 @@ async def process_batch(batch):
all_docstrings["docstrings"] + result.docstrings
)

updated_docstrings = await self.generate_docstrings_for_entry_points(
all_docstrings, entry_points_neighbors
)

# updated_docstrings = await self.generate_docstrings_for_entry_points(
# all_docstrings, entry_points_neighbors
# )
updated_docstrings = all_docstrings
return updated_docstrings

async def generate_response(
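The fetch_graph change above only raises the page size from 100 to 500 nodes per round trip, but the underlying pattern is plain SKIP/LIMIT pagination against Neo4j. The standalone sketch below illustrates that pattern with the official neo4j Python driver; the :NODE label and repoId property are hypothetical, since the actual query and node schema are not shown in this diff.

from neo4j import GraphDatabase

def fetch_all_nodes(uri: str, auth: tuple, repo_id: str, batch_size: int = 500) -> list:
    # Page through all nodes for a repo in fixed-size batches.
    all_nodes = []
    driver = GraphDatabase.driver(uri, auth=auth)
    with driver.session() as session:
        offset = 0
        while True:
            result = session.run(
                "MATCH (n:NODE {repoId: $repo_id}) "
                "RETURN n SKIP $offset LIMIT $limit",
                repo_id=repo_id,
                offset=offset,
                limit=batch_size,
            )
            batch = [record["n"] for record in result]
            if not batch:
                break
            all_nodes.extend(batch)
            offset += batch_size
    driver.close()
    return all_nodes

Larger pages mean fewer round trips at the cost of more memory per batch, which is presumably the trade-off behind moving from 100 to 500 here.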
38 changes: 26 additions & 12 deletions app/modules/search/search_service.py
@@ -10,18 +10,6 @@ class SearchService:
def __init__(self, db: Session):
self.db = db

async def create_search_index(self, project_id: str, node: Dict):
# Create index entry for the node
self.db.add(
SearchIndex(
project_id=project_id,
node_id=node["node_id"],
name=node.get("name", ""),
file_path=node.get("file_path", ""),
content=f"{node.get('name', '')} {node.get('file_path', '')}",
)
)

async def commit_indices(self):
self.db.commit()

@@ -125,3 +113,29 @@ def delete_project_index(self, project_id: str):
async def bulk_create_search_indices(self, nodes: List[Dict]):
# Create index entries for all nodes in bulk
self.db.bulk_insert_mappings(SearchIndex, nodes)

async def clone_search_indices(self, input_project_id: str, output_project_id: str):
"""Clone all search indices from input project to output project."""
# Query all search indices for the input project
source_indices = (
self.db.query(SearchIndex)
.filter(SearchIndex.project_id == input_project_id)
.all()
)

# Prepare bulk insert data
cloned_indices = [
{
"project_id": output_project_id,
"node_id": index.node_id,
"name": index.name,
"file_path": index.file_path,
"content": index.content,
}
for index in source_indices
]

# Bulk insert the cloned indices if there are any
if cloned_indices:
self.db.bulk_insert_mappings(SearchIndex, cloned_indices)
await self.commit_indices()
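clone_search_indices loads every source row into Python and re-inserts it with bulk_insert_mappings, which is simple and adequate for demo-sized projects. For very large indices, an alternative (not what this PR does) would be to push the copy into the database as a single INSERT ... SELECT. The sketch below shows that variant with SQLAlchemy Core, assuming the same SearchIndex model and a database-generated primary key.

from sqlalchemy import insert, literal, select

def clone_search_indices_in_db(db, input_project_id: str, output_project_id: str) -> None:
    # Select the source rows, substituting the new project_id as a literal.
    source_rows = select(
        literal(output_project_id).label("project_id"),
        SearchIndex.node_id,
        SearchIndex.name,
        SearchIndex.file_path,
        SearchIndex.content,
    ).where(SearchIndex.project_id == input_project_id)

    # Copy the rows in one INSERT ... SELECT, without round-tripping them through Python.
    db.execute(
        insert(SearchIndex).from_select(
            ["project_id", "node_id", "name", "file_path", "content"],
            source_rows,
        )
    )
    db.commit()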