Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions prompt_forge/api/proposals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""Refinement proposals endpoints."""

from __future__ import annotations


from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel

from prompt_forge.core.registry import PromptRegistry, get_registry
from prompt_forge.core.vcs import VersionControl, get_vcs
from prompt_forge.db.client import SupabaseClient, get_supabase_client

router = APIRouter()


class ProposalResponse(BaseModel):
"""Refinement proposal response."""

branch_name: str
created_at: str
target_section: str | None = None
source_patterns: list[str] = []


@router.get("/{slug}/proposals", response_model=list[ProposalResponse])
async def list_refinement_proposals(
slug: str,
registry: PromptRegistry = Depends(get_registry),
vcs: VersionControl = Depends(get_vcs),
db: SupabaseClient = Depends(get_supabase_client),
) -> list[ProposalResponse]:
"""List pending refinement branches for a prompt.

Returns branches with name prefix 'refinement/' and status 'active'.
"""
prompt = registry.get_prompt(slug)
if not prompt:
raise HTTPException(status_code=404, detail=f"Prompt '{slug}' not found")

prompt_id = str(prompt["id"])

# Get all active branches for this prompt
all_branches = vcs.list_branches(prompt_id)

# Filter for refinement branches with active status
refinement_branches = [
branch for branch in all_branches
if branch["name"].startswith("refinement/") and branch["status"] == "active"
]

proposals = []
for branch in refinement_branches:
# Extract target section from branch name (refinement/{section}/{timestamp})
name_parts = branch["name"].split("/")
target_section = name_parts[1] if len(name_parts) >= 2 else None

# Try to get source patterns from metadata (optional)
source_patterns = []
try:
metadata_rows = db.select(
"refinement_proposals",
filters={"branch_id": branch["id"]}
)
if metadata_rows:
source_patterns = metadata_rows[0].get("source_patterns", [])
except Exception:
# Table might not exist or have different structure
pass

proposals.append(ProposalResponse(
branch_name=branch["name"],
created_at=branch["created_at"],
target_section=target_section,
source_patterns=source_patterns,
))

# Sort by creation time (newest first)
proposals.sort(key=lambda p: p.created_at, reverse=True)

return proposals
2 changes: 2 additions & 0 deletions prompt_forge/api/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from prompt_forge.api.effectiveness import router as effectiveness_router
from prompt_forge.api.persona_prompts import router as persona_prompts_router
from prompt_forge.api.prompts import router as prompts_router
from prompt_forge.api.proposals import router as proposals_router
from prompt_forge.api.scan import router as scan_router
from prompt_forge.api.subscriptions import router as subscriptions_router
from prompt_forge.api.usage import router as usage_router
Expand All @@ -20,6 +21,7 @@
api_router.include_router(prompts_router, prefix="/prompts", tags=["prompts"])
api_router.include_router(versions_router, prefix="/prompts", tags=["versions"])
api_router.include_router(branches_router, prefix="/prompts", tags=["branches"])
api_router.include_router(proposals_router, prefix="/prompts", tags=["proposals"])
api_router.include_router(subscriptions_router, prefix="/prompts", tags=["subscriptions"])
api_router.include_router(
persona_prompts_router, prefix="/persona-prompts", tags=["persona-prompts"]
Expand Down
1 change: 1 addition & 0 deletions prompt_forge/core/refinement/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Refinement processing module."""
260 changes: 260 additions & 0 deletions prompt_forge/core/refinement/consumer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,260 @@
"""NATS subscriber for refinement proposals.

Listens to:
- pattern.refinement.proposed — creates refinement branches with proposed changes
"""

from __future__ import annotations

import json
import os
from datetime import datetime, timezone

import structlog

logger = structlog.get_logger()

_nats_available = False
try:
import nats as nats_lib

_nats_available = True
except ImportError:
pass


class RefinementConsumer:
"""Subscribes to NATS refinement events and creates refinement branches."""

def __init__(self, nats_url: str = "nats://localhost:4222") -> None:
self.nats_url = nats_url
self._nc = None
self._subs = []
self._connected = False

async def connect(self) -> bool:
"""Connect to NATS. Returns True if successful."""
if not _nats_available:
logger.info("refinement_consumer.nats_not_installed")
return False
try:
self._nc = await nats_lib.connect(self.nats_url)
self._connected = True
logger.info("refinement_consumer.connected", url=self.nats_url)
return True
except Exception as e:
logger.warning("refinement_consumer.connect_failed", error=str(e))
return False

async def start(self) -> None:
"""Start consuming refinement events."""
if not self._connected:
return

sub = await self._nc.subscribe("pattern.refinement.proposed", cb=self._handle_refinement_proposed)
self._subs = [sub]
logger.info(
"refinement_consumer.started",
subjects=["pattern.refinement.proposed"],
)

async def stop(self) -> None:
"""Stop consuming and disconnect."""
for sub in self._subs:
try:
await sub.unsubscribe()
except Exception:
pass
if self._nc and self._connected:
try:
await self._nc.close()
except Exception:
pass
self._connected = False
logger.info("refinement_consumer.stopped")

async def _handle_refinement_proposed(self, msg) -> None:
"""Handle pattern.refinement.proposed event.

Expected event data:
{
"target_slug": "kai-soul",
"section": "reasoning",
"proposed_change": "Updated reasoning instructions...",
"source_patterns": ["dredd-pattern-1", "dredd-pattern-2"]
}
"""
try:
payload = json.loads(msg.data.decode())
data = payload.get("data", payload)

target_slug = data.get("target_slug")
section = data.get("section")
proposed_change = data.get("proposed_change")
source_patterns = data.get("source_patterns", [])

if not target_slug or not section or not proposed_change:
logger.warning("refinement_consumer.incomplete_event", data=data)
return

await self._create_refinement_branch(
target_slug, section, proposed_change, source_patterns, payload
)

except Exception as e:
logger.warning("refinement_consumer.handle_error", error=str(e))

async def _create_refinement_branch(
self,
target_slug: str,
section: str,
proposed_change: str,
source_patterns: list[str],
event_payload: dict
) -> None:
"""Create a refinement branch with the proposed change."""
try:
from prompt_forge.core.registry import get_registry
from prompt_forge.core.vcs import get_vcs

registry = get_registry()
vcs = get_vcs()

# 1. Look up the target prompt by slug
prompt = registry.get_prompt(target_slug)
if not prompt:
logger.warning("refinement_consumer.prompt_not_found", slug=target_slug)
return

prompt_id = str(prompt["id"])

# 2. Get latest version to base the change on
latest_versions = vcs.history(prompt_id, "main", limit=1)
latest_version = latest_versions[0] if latest_versions else None
if not latest_version:
logger.warning("refinement_consumer.no_versions", slug=target_slug)
return

# 3. Create branch name with timestamp
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
branch_name = f"refinement/{section}/{timestamp}"

# 4. Create the branch from latest version
branch = vcs.create_branch(
prompt_id=prompt_id,
branch_name=branch_name,
from_branch="main"
)

# 5. Apply the proposed change to the target section
current_content = latest_version["content"]
updated_content = self._apply_section_change(
current_content, section, proposed_change
)

# 6. Create new version on the branch with the change
commit_message = f"Refinement proposal for {section} section"
if source_patterns:
commit_message += f" (from: {', '.join(source_patterns)})"

vcs.commit(
prompt_id=prompt_id,
content=updated_content,
message=commit_message,
author="refinement-system",
branch=branch_name,
)

# 7. Store metadata about the proposal
await self._store_proposal_metadata(
branch["id"], source_patterns, event_payload
)

logger.info(
"refinement_consumer.branch_created",
slug=target_slug,
branch=branch_name,
section=section,
patterns=source_patterns,
)

except Exception as e:
logger.warning(
"refinement_consumer.create_branch_error",
slug=target_slug,
section=section,
error=str(e)
)

def _apply_section_change(
self, content: dict, section_name: str, proposed_change: str
) -> dict:
"""Apply the proposed change to the specified section."""
updated_content = dict(content)
sections = updated_content.get("sections", [])

# Find and update the target section
updated_sections = []
section_found = False

for section in sections:
if section.get("id") == section_name or section.get("name") == section_name:
# Update this section with the proposed change
updated_section = dict(section)
updated_section["content"] = proposed_change
updated_sections.append(updated_section)
section_found = True
else:
updated_sections.append(section)

# If section doesn't exist, create it
if not section_found:
new_section = {
"id": section_name,
"name": section_name.replace("_", " ").title(),
"content": proposed_change
}
updated_sections.append(new_section)

updated_content["sections"] = updated_sections
return updated_content

async def _store_proposal_metadata(
self, branch_id: str, source_patterns: list[str], event_payload: dict
) -> None:
"""Store metadata about the refinement proposal."""
try:
from prompt_forge.db.client import get_supabase_client

db = get_supabase_client()

# Store in a refinement_proposals table (if it exists)
# This is optional - the briefing doesn't require a specific table
metadata = {
"branch_id": branch_id,
"source_patterns": source_patterns,
"event_payload": event_payload,
"created_at": datetime.now(timezone.utc).isoformat(),
}

# Try to store, but don't fail if table doesn't exist
try:
db.insert("refinement_proposals", metadata)
except Exception:
# Table might not exist, which is fine
pass

except Exception as e:
logger.debug("refinement_consumer.metadata_store_error", error=str(e))


_consumer: RefinementConsumer | None = None


def get_refinement_consumer() -> RefinementConsumer:
"""Get the global refinement consumer (lazy init)."""
global _consumer
if _consumer is None:
nats_url = os.getenv("NATS_URL", "nats://localhost:4222")
_consumer = RefinementConsumer(nats_url)
return _consumer
19 changes: 19 additions & 0 deletions prompt_forge/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,16 @@ async def lifespan(app: FastAPI):
except Exception as e:
logger.info("promptforge.subscribers_skipped", reason=str(e))

# Initialize NATS refinement consumer (optional)
try:
from prompt_forge.core.refinement.consumer import get_refinement_consumer

consumer = get_refinement_consumer()
if await consumer.connect():
await consumer.start()
except Exception as e:
logger.info("promptforge.refinement_consumer_skipped", reason=str(e))

# Start TTL cleanup background task
_cleanup_task = asyncio.create_task(subscription_ttl_cleanup())

Expand Down Expand Up @@ -111,6 +121,15 @@ async def lifespan(app: FastAPI):
except Exception:
pass

# Disconnect NATS refinement consumer
try:
from prompt_forge.core.refinement.consumer import get_refinement_consumer

consumer = get_refinement_consumer()
await consumer.stop()
except Exception:
pass

# Disconnect NATS
try:
from prompt_forge.core.events import get_event_publisher
Expand Down
Loading