diff --git a/prompt_forge/api/branches.py b/prompt_forge/api/branches.py index 66f3fa8..1ee3160 100644 --- a/prompt_forge/api/branches.py +++ b/prompt_forge/api/branches.py @@ -27,6 +27,22 @@ class BranchMerge(BaseModel): author: str = "system" +class BranchReject(BaseModel): + """Reject a branch.""" + + reason: str | None = None + + +class BranchDiffResponse(BaseModel): + """Branch diff response.""" + + branch_name: str + target_section: str = "all" + current_content: dict[str, Any] + proposed_content: dict[str, Any] + diff_summary: str + + class BranchResponse(BaseModel): """Branch info.""" @@ -77,6 +93,90 @@ async def list_branches( return vcs.list_branches(str(prompt["id"])) +@router.get("/{slug}/branches/{branch_name}/diff") +async def branch_diff( + slug: str, + branch_name: str, + registry: PromptRegistry = Depends(get_registry), + vcs: VersionControl = Depends(get_vcs), +) -> dict[str, Any]: + """Compare branch head version content vs main branch latest version content.""" + prompt = registry.get_prompt(slug) + if not prompt: + raise HTTPException(status_code=404, detail=f"Prompt '{slug}' not found") + + prompt_id = str(prompt["id"]) + + # Check if branch exists + branches = vcs.list_branches(prompt_id) + branch = next((b for b in branches if b["name"] == branch_name), None) + if not branch: + raise HTTPException(status_code=404, detail=f"Branch '{branch_name}' not found") + + try: + # Get latest version from the branch + branch_history = vcs.history(prompt_id, branch=branch_name, limit=1) + if not branch_history: + raise HTTPException( + status_code=404, detail=f"No versions found on branch '{branch_name}'" + ) + + branch_version = branch_history[0] + proposed_content = branch_version["content"] + + # Get latest version from main branch + main_history = vcs.history(prompt_id, branch="main", limit=1) + if not main_history: + raise HTTPException(status_code=404, detail="No versions found on main branch") + + main_version = main_history[0] + current_content = main_version["content"] + + # Generate diff summary + diff_summary = _generate_diff_summary(current_content, proposed_content) + + return { + "branch_name": branch_name, + "target_section": "all", + "current_content": current_content, + "proposed_content": proposed_content, + "diff_summary": diff_summary, + } + + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error generating diff: {str(e)}") + + +@router.post("/{slug}/branches/{branch_name}/reject") +async def reject_branch( + slug: str, + branch_name: str, + data: BranchReject, + registry: PromptRegistry = Depends(get_registry), + vcs: VersionControl = Depends(get_vcs), +) -> dict[str, Any]: + """Reject a branch by updating its status.""" + prompt = registry.get_prompt(slug) + if not prompt: + raise HTTPException(status_code=404, detail=f"Prompt '{slug}' not found") + + prompt_id = str(prompt["id"]) + + # Check if branch exists + branches = vcs.list_branches(prompt_id) + branch = next((b for b in branches if b["name"] == branch_name), None) + if not branch: + raise HTTPException(status_code=404, detail=f"Branch '{branch_name}' not found") + + try: + # Update branch status to rejected + updated_branch = vcs.db.update("prompt_branches", branch["id"], {"status": "rejected"}) + return updated_branch + + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error rejecting branch: {str(e)}") + + @router.post("/{slug}/branches/{branch_name}/merge") async def merge_branch( slug: str, @@ -101,3 +201,66 @@ async def merge_branch( return version except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) + + +def _generate_diff_summary(current: dict[str, Any], proposed: dict[str, Any]) -> str: + """Generate a human-readable summary of differences.""" + changes = [] + + # Compare sections + current_sections = {s.get("id"): s for s in current.get("sections", [])} + proposed_sections = {s.get("id"): s for s in proposed.get("sections", [])} + + # New sections + new_sections = set(proposed_sections.keys()) - set(current_sections.keys()) + if new_sections: + changes.append(f"Added {len(new_sections)} new section(s): {', '.join(new_sections)}") + + # Removed sections + removed_sections = set(current_sections.keys()) - set(proposed_sections.keys()) + if removed_sections: + changes.append(f"Removed {len(removed_sections)} section(s): {', '.join(removed_sections)}") + + # Modified sections + modified_sections = [] + for section_id in set(current_sections.keys()) & set(proposed_sections.keys()): + if current_sections[section_id].get("content") != proposed_sections[section_id].get( + "content" + ): + modified_sections.append(section_id) + + if modified_sections: + changes.append( + f"Modified {len(modified_sections)} section(s): {', '.join(modified_sections)}" + ) + + # Compare variables + current_vars = current.get("variables", {}) + proposed_vars = proposed.get("variables", {}) + + new_vars = set(proposed_vars.keys()) - set(current_vars.keys()) + removed_vars = set(current_vars.keys()) - set(proposed_vars.keys()) + modified_vars = [ + k + for k in set(current_vars.keys()) & set(proposed_vars.keys()) + if current_vars[k] != proposed_vars[k] + ] + + if new_vars: + changes.append(f"Added {len(new_vars)} variable(s)") + if removed_vars: + changes.append(f"Removed {len(removed_vars)} variable(s)") + if modified_vars: + changes.append(f"Modified {len(modified_vars)} variable(s)") + + # Compare metadata + current_meta = current.get("metadata", {}) + proposed_meta = proposed.get("metadata", {}) + + if current_meta != proposed_meta: + changes.append("Modified metadata") + + if not changes: + return "No changes detected" + + return "; ".join(changes) diff --git a/prompt_forge/api/branches.py.bak b/prompt_forge/api/branches.py.bak new file mode 100644 index 0000000..66f3fa8 --- /dev/null +++ b/prompt_forge/api/branches.py.bak @@ -0,0 +1,103 @@ +"""Branch management endpoints.""" + +from __future__ import annotations + +from typing import Any + +from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel, Field + +from prompt_forge.core.registry import PromptRegistry, get_registry +from prompt_forge.core.vcs import VersionControl, get_vcs + +router = APIRouter() + + +class BranchCreate(BaseModel): + """Create a branch.""" + + name: str = Field(..., min_length=1, max_length=100) + from_branch: str = "main" + + +class BranchMerge(BaseModel): + """Merge a branch.""" + + strategy: str = Field(default="theirs", pattern=r"^(ours|theirs|section_merge)$") + author: str = "system" + + +class BranchResponse(BaseModel): + """Branch info.""" + + id: str + prompt_id: str + name: str + head_version_id: str | None + base_version_id: str | None + status: str + created_at: str + updated_at: str + + +@router.post("/{slug}/branches", status_code=201) +async def create_branch( + slug: str, + data: BranchCreate, + registry: PromptRegistry = Depends(get_registry), + vcs: VersionControl = Depends(get_vcs), +) -> dict[str, Any]: + """Create a new branch for a prompt.""" + prompt = registry.get_prompt(slug) + if not prompt: + raise HTTPException(status_code=404, detail=f"Prompt '{slug}' not found") + + try: + branch = vcs.create_branch( + prompt_id=str(prompt["id"]), + branch_name=data.name, + from_branch=data.from_branch, + ) + return branch + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + + +@router.get("/{slug}/branches") +async def list_branches( + slug: str, + registry: PromptRegistry = Depends(get_registry), + vcs: VersionControl = Depends(get_vcs), +) -> list[dict[str, Any]]: + """List all branches for a prompt.""" + prompt = registry.get_prompt(slug) + if not prompt: + raise HTTPException(status_code=404, detail=f"Prompt '{slug}' not found") + + return vcs.list_branches(str(prompt["id"])) + + +@router.post("/{slug}/branches/{branch_name}/merge") +async def merge_branch( + slug: str, + branch_name: str, + data: BranchMerge, + registry: PromptRegistry = Depends(get_registry), + vcs: VersionControl = Depends(get_vcs), +) -> dict[str, Any]: + """Merge a branch into main (or target).""" + prompt = registry.get_prompt(slug) + if not prompt: + raise HTTPException(status_code=404, detail=f"Prompt '{slug}' not found") + + try: + version = vcs.merge_branch( + prompt_id=str(prompt["id"]), + source_branch=branch_name, + target_branch="main", + strategy=data.strategy, + author=data.author, + ) + return version + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) diff --git a/prompt_forge/api/proposals.py b/prompt_forge/api/proposals.py index 56c2f90..a7998a2 100644 --- a/prompt_forge/api/proposals.py +++ b/prompt_forge/api/proposals.py @@ -15,7 +15,7 @@ class ProposalResponse(BaseModel): """Refinement proposal response.""" - + branch_name: str created_at: str target_section: str | None = None @@ -30,51 +30,51 @@ async def list_refinement_proposals( db: SupabaseClient = Depends(get_supabase_client), ) -> list[ProposalResponse]: """List pending refinement branches for a prompt. - + Returns branches with name prefix 'refinement/' and status 'active'. """ prompt = registry.get_prompt(slug) if not prompt: raise HTTPException(status_code=404, detail=f"Prompt '{slug}' not found") - + prompt_id = str(prompt["id"]) - + # Get all active branches for this prompt all_branches = vcs.list_branches(prompt_id) - + # Filter for refinement branches with active status refinement_branches = [ - branch for branch in all_branches + branch + for branch in all_branches if branch["name"].startswith("refinement/") and branch["status"] == "active" ] - + proposals = [] for branch in refinement_branches: # Extract target section from branch name (refinement/{section}/{timestamp}) name_parts = branch["name"].split("/") target_section = name_parts[1] if len(name_parts) >= 2 else None - + # Try to get source patterns from metadata (optional) source_patterns = [] try: - metadata_rows = db.select( - "refinement_proposals", - filters={"branch_id": branch["id"]} - ) + metadata_rows = db.select("refinement_proposals", filters={"branch_id": branch["id"]}) if metadata_rows: source_patterns = metadata_rows[0].get("source_patterns", []) except Exception: # Table might not exist or have different structure pass - - proposals.append(ProposalResponse( - branch_name=branch["name"], - created_at=branch["created_at"], - target_section=target_section, - source_patterns=source_patterns, - )) - + + proposals.append( + ProposalResponse( + branch_name=branch["name"], + created_at=branch["created_at"], + target_section=target_section, + source_patterns=source_patterns, + ) + ) + # Sort by creation time (newest first) proposals.sort(key=lambda p: p.created_at, reverse=True) - + return proposals diff --git a/prompt_forge/core/refinement/consumer.py b/prompt_forge/core/refinement/consumer.py index 78abb43..72bbe33 100644 --- a/prompt_forge/core/refinement/consumer.py +++ b/prompt_forge/core/refinement/consumer.py @@ -51,7 +51,9 @@ async def start(self) -> None: if not self._connected: return - sub = await self._nc.subscribe("pattern.refinement.proposed", cb=self._handle_refinement_proposed) + sub = await self._nc.subscribe( + "pattern.refinement.proposed", cb=self._handle_refinement_proposed + ) self._subs = [sub] logger.info( "refinement_consumer.started", @@ -75,7 +77,7 @@ async def stop(self) -> None: async def _handle_refinement_proposed(self, msg) -> None: """Handle pattern.refinement.proposed event. - + Expected event data: { "target_slug": "kai-soul", @@ -87,12 +89,12 @@ async def _handle_refinement_proposed(self, msg) -> None: try: payload = json.loads(msg.data.decode()) data = payload.get("data", payload) - + target_slug = data.get("target_slug") section = data.get("section") proposed_change = data.get("proposed_change") source_patterns = data.get("source_patterns", []) - + if not target_slug or not section or not proposed_change: logger.warning("refinement_consumer.incomplete_event", data=data) return @@ -100,63 +102,59 @@ async def _handle_refinement_proposed(self, msg) -> None: await self._create_refinement_branch( target_slug, section, proposed_change, source_patterns, payload ) - + except Exception as e: logger.warning("refinement_consumer.handle_error", error=str(e)) async def _create_refinement_branch( - self, - target_slug: str, - section: str, - proposed_change: str, + self, + target_slug: str, + section: str, + proposed_change: str, source_patterns: list[str], - event_payload: dict + event_payload: dict, ) -> None: """Create a refinement branch with the proposed change.""" try: from prompt_forge.core.registry import get_registry from prompt_forge.core.vcs import get_vcs - + registry = get_registry() vcs = get_vcs() - + # 1. Look up the target prompt by slug prompt = registry.get_prompt(target_slug) if not prompt: logger.warning("refinement_consumer.prompt_not_found", slug=target_slug) return - + prompt_id = str(prompt["id"]) - + # 2. Get latest version to base the change on latest_versions = vcs.history(prompt_id, "main", limit=1) latest_version = latest_versions[0] if latest_versions else None if not latest_version: logger.warning("refinement_consumer.no_versions", slug=target_slug) return - + # 3. Create branch name with timestamp timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S") branch_name = f"refinement/{section}/{timestamp}" - + # 4. Create the branch from latest version branch = vcs.create_branch( - prompt_id=prompt_id, - branch_name=branch_name, - from_branch="main" + prompt_id=prompt_id, branch_name=branch_name, from_branch="main" ) - + # 5. Apply the proposed change to the target section current_content = latest_version["content"] - updated_content = self._apply_section_change( - current_content, section, proposed_change - ) - + updated_content = self._apply_section_change(current_content, section, proposed_change) + # 6. Create new version on the branch with the change commit_message = f"Refinement proposal for {section} section" if source_patterns: commit_message += f" (from: {', '.join(source_patterns)})" - + vcs.commit( prompt_id=prompt_id, content=updated_content, @@ -164,12 +162,10 @@ async def _create_refinement_branch( author="refinement-system", branch=branch_name, ) - + # 7. Store metadata about the proposal - await self._store_proposal_metadata( - branch["id"], source_patterns, event_payload - ) - + await self._store_proposal_metadata(branch["id"], source_patterns, event_payload) + logger.info( "refinement_consumer.branch_created", slug=target_slug, @@ -177,26 +173,24 @@ async def _create_refinement_branch( section=section, patterns=source_patterns, ) - + except Exception as e: logger.warning( "refinement_consumer.create_branch_error", slug=target_slug, section=section, - error=str(e) + error=str(e), ) - def _apply_section_change( - self, content: dict, section_name: str, proposed_change: str - ) -> dict: + def _apply_section_change(self, content: dict, section_name: str, proposed_change: str) -> dict: """Apply the proposed change to the specified section.""" updated_content = dict(content) sections = updated_content.get("sections", []) - + # Find and update the target section updated_sections = [] section_found = False - + for section in sections: if section.get("id") == section_name or section.get("name") == section_name: # Update this section with the proposed change @@ -206,16 +200,16 @@ def _apply_section_change( section_found = True else: updated_sections.append(section) - + # If section doesn't exist, create it if not section_found: new_section = { "id": section_name, "name": section_name.replace("_", " ").title(), - "content": proposed_change + "content": proposed_change, } updated_sections.append(new_section) - + updated_content["sections"] = updated_sections return updated_content @@ -225,9 +219,9 @@ async def _store_proposal_metadata( """Store metadata about the refinement proposal.""" try: from prompt_forge.db.client import get_supabase_client - + db = get_supabase_client() - + # Store in a refinement_proposals table (if it exists) # This is optional - the briefing doesn't require a specific table metadata = { @@ -236,14 +230,14 @@ async def _store_proposal_metadata( "event_payload": event_payload, "created_at": datetime.now(timezone.utc).isoformat(), } - + # Try to store, but don't fail if table doesn't exist try: db.insert("refinement_proposals", metadata) except Exception: # Table might not exist, which is fine pass - + except Exception as e: logger.debug("refinement_consumer.metadata_store_error", error=str(e)) diff --git a/scripts/seed_persona_prompts.py b/scripts/seed_persona_prompts.py index 8190f9a..b6fd5c3 100755 --- a/scripts/seed_persona_prompts.py +++ b/scripts/seed_persona_prompts.py @@ -13,13 +13,13 @@ def main(): """Seed initial persona prompts.""" print("Seeding initial persona prompts...") - + store = get_persona_store() store.seed_initial_personas() - + print("✅ Initial persona prompts seeded successfully!") print("Available personas: researcher, developer, reviewer, tester, architect") if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/tests/test_branches.py b/tests/test_branches.py index d3f2032..40a8874 100644 --- a/tests/test_branches.py +++ b/tests/test_branches.py @@ -167,3 +167,113 @@ def test_merge_branch_api(self, client): def test_branch_not_found(self, client): resp = client.post("/api/v1/prompts/nonexistent/branches", json={"name": "x"}) assert resp.status_code == 404 + + +class TestNewBranchEndpoints: + def test_branch_diff_api(self, client): + # Create a prompt + client.post( + "/api/v1/prompts", + json={ + "slug": "branch-diff-test", + "name": "Branch Diff Test", + "type": "persona", + "content": { + "sections": [ + {"id": "identity", "label": "Identity", "content": "Original content"} + ], + "variables": {"var1": "value1"}, + "metadata": {}, + }, + }, + ) + + # Create a branch + client.post("/api/v1/prompts/branch-diff-test/branches", json={"name": "feature"}) + + # Update the branch with new content + client.post( + "/api/v1/prompts/branch-diff-test/versions", + json={ + "content": { + "sections": [{"id": "identity", "label": "Identity", "content": "New content"}], + "variables": {"var1": "value1", "var2": "value2"}, + "metadata": {}, + }, + "message": "Update feature branch", + "branch": "feature", + }, + ) + + # Test the diff endpoint + resp = client.get("/api/v1/prompts/branch-diff-test/branches/feature/diff") + assert resp.status_code == 200 + data = resp.json() + assert data["branch_name"] == "feature" + assert "diff_summary" in data + assert "current_content" in data + assert "proposed_content" in data + assert data["current_content"]["sections"][0]["content"] == "Original content" + assert data["proposed_content"]["sections"][0]["content"] == "New content" + + def test_branch_diff_branch_not_found(self, client): + client.post( + "/api/v1/prompts", + json={ + "slug": "diff-404-test", + "name": "Test", + "type": "persona", + "content": { + "sections": [{"id": "identity", "label": "Identity", "content": "Test"}], + "variables": {}, + "metadata": {}, + }, + }, + ) + resp = client.get("/api/v1/prompts/diff-404-test/branches/nonexistent/diff") + assert resp.status_code == 404 + + def test_branch_reject_api(self, client): + # Create a prompt + client.post( + "/api/v1/prompts", + json={ + "slug": "branch-reject-test", + "name": "Branch Reject Test", + "type": "persona", + "content": { + "sections": [{"id": "identity", "label": "Identity", "content": "Test"}], + "variables": {}, + "metadata": {}, + }, + }, + ) + + # Create a branch + client.post("/api/v1/prompts/branch-reject-test/branches", json={"name": "unwanted"}) + + # Reject the branch + resp = client.post( + "/api/v1/prompts/branch-reject-test/branches/unwanted/reject", + json={"reason": "Not needed anymore"}, + ) + assert resp.status_code == 200 + data = resp.json() + assert data["status"] == "rejected" + + def test_branch_reject_branch_not_found(self, client): + client.post( + "/api/v1/prompts", + json={ + "slug": "reject-404-test", + "name": "Test", + "type": "persona", + "content": { + "sections": [{"id": "identity", "label": "Identity", "content": "Test"}], + "variables": {}, + "metadata": {}, + }, + }, + ) + resp = client.post("/api/v1/prompts/reject-404-test/branches/nonexistent/reject", json={}) + assert resp.status_code == 404