Skip to content

Prompt management #206

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions backend/app/alembic/versions/8c3a36b508f1_add_prompts_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""add prompts table

Revision ID: 8c3a36b508f1
Revises: 904ed70e7dab
Create Date: 2025-06-24 10:20:21.933351

"""
from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes


# revision identifiers, used by Alembic.
revision = '8c3a36b508f1'
down_revision = '904ed70e7dab'
branch_labels = None
depends_on = None


def upgrade():
op.create_table('prompt',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('project_id', sa.Integer(), nullable=False),
sa.Column('organization_id', sa.Integer(), nullable=False),
sa.Column('inserted_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.ForeignKeyConstraint(['organization_id'], ['organization.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['project_id'], ['project.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_prompt_name'), 'prompt', ['name'], unique=True)
# ### end Alembic commands ###


def downgrade():
op.drop_index(op.f('ix_prompt_name'), table_name='prompt')
op.drop_table('prompt')
# ### end Alembic commands ###
2 changes: 2 additions & 0 deletions backend/app/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
utils,
onboarding,
credentials,
prompts,
)
from app.core.config import settings

Expand All @@ -30,6 +31,7 @@
api_router.include_router(threads.router)
api_router.include_router(users.router)
api_router.include_router(utils.router)
api_router.include_router(prompts.router)

if settings.ENVIRONMENT == "local":
api_router.include_router(private.router)
144 changes: 144 additions & 0 deletions backend/app/api/routes/prompts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlmodel import Session, select
from typing import List, Optional
from langfuse.client import Langfuse
from pydantic import BaseModel

from app.api.deps import get_current_user_org, get_db
from app.crud.credentials import get_provider_credential
from app.models import UserOrganization
from app.utils import APIResponse
from app.crud.prompt import add_prompt, get_prompt_by_name, list_prompts

router = APIRouter(prefix="/prompts", tags=["prompts"])

class PromptCreateRequest(BaseModel):
project_id: int
name: str
type: str
prompt: str
version: int
labels: Optional[List[str]] = None
tags: Optional[List[str]] = None
model: Optional[str] = None
temperature: Optional[float] = None
supported_languages: Optional[List[str]] = None

class PromptGetRequest(BaseModel):
project_id: int
name: str
type: Optional[str] = None
version: Optional[int] = None
labels: Optional[List[str]] = None
tags: Optional[List[str]] = None

class PromptListRequest(BaseModel):
project_id: int

class PromptUpdateRequest(BaseModel):
project_id: int
name: str
type: Optional[str] = None
prompt: Optional[str] = None
labels: Optional[List[str]] = None
tags: Optional[List[str]] = None
model: Optional[str] = None
temperature: Optional[float] = None
supported_languages: Optional[List[str]] = None

def initialize_langfuse(
project_id: int,
_session: Session,
_current_user: UserOrganization,
):
langfuse_credentials = get_provider_credential(
session=_session,
org_id=_current_user.organization_id,
provider="langfuse",
project_id=project_id,
)
if not langfuse_credentials or "api_key" not in langfuse_credentials:
raise HTTPException(status_code=400, detail="Langfuse API key not configured for this organization.")
return Langfuse(api_key=langfuse_credentials["api_key"])

@router.post("/", response_model=APIResponse[dict])
def create_new_prompt(
request: PromptCreateRequest,
_session: Session = Depends(get_db),
_current_user: UserOrganization = Depends(get_current_user_org)
):
langfuse_client = initialize_langfuse(request.project_id, _session, _current_user)
langfuse_client.create_prompt(
name=request.name,
type="text",
prompt=request.prompt,
labels=request.labels,
tags=request.tags,
config={
"model": request.model,
"temperature": request.temperature,
"supported_languages": request.supported_languages,
},
)
# Add prompt name to Prompt table if not exists using CRUD
if not get_prompt_by_name(_session, request.name, request.project_id, _current_user.organization_id):
add_prompt(_session, request.name, request.project_id, _current_user.organization_id)
return APIResponse.success_response(
message="Prompt created successfully",
data=request.dict(),
)

@router.get("/")
def get_prompt(
request: PromptGetRequest,
_session: Session = Depends(get_db),
_current_user: UserOrganization = Depends(get_current_user_org)
):
# Fetch prompt name from Prompt table using CRUD
prompt_row = get_prompt_by_name(_session, request.name, request.project_id, _current_user.organization_id)
if not prompt_row:
raise HTTPException(status_code=404, detail="Prompt not found in DB")
langfuse_client = initialize_langfuse(request.project_id, _session, _current_user)
prompt = langfuse_client.get_prompt(request.name, request.type, request.version)
return APIResponse.success_response(
message="Prompt fetched successfully",
data=prompt,
)

@router.put("/")
def update_prompt(
request: PromptUpdateRequest,
_session: Session = Depends(get_db),
_current_user: UserOrganization = Depends(get_current_user_org)
):
# Ensure prompt name exists in Prompt table using CRUD
prompt_row = get_prompt_by_name(_session, request.name, request.project_id, _current_user.organization_id)
if not prompt_row:
raise HTTPException(status_code=404, detail="Prompt not found in DB for update")
langfuse_client = initialize_langfuse(request.project_id, _session, _current_user)
langfuse_client.update_prompt(
request.name,
request.type,
request.version,
request.prompt,
request.labels,
)
prompt_row = update_prompt(_session, request.name, request.project_id, _current_user.organization_id)
return APIResponse.success_response(
message="Prompt updated successfully",
data=request.dict(),
)

# Optionally, add a list endpoint for all prompt names
@router.get("/list")
def list_prompt_names(
request: PromptListRequest,
_session: Session = Depends(get_db),
_current_user: UserOrganization = Depends(get_current_user_org)
):
prompts = list_prompts(_session, project_id=request.project_id, organization_id=_current_user.organization_id)
names = [p.name for p in prompts]
return APIResponse.success_response(
message="Prompt names fetched successfully",
data=names,
)
55 changes: 55 additions & 0 deletions backend/app/crud/prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from sqlmodel import Session, select
from app.models import Prompt
from typing import Optional, List
from app.core.util import now

def add_prompt(session: Session, name: str, project_id: int, organization_id: int) -> Prompt:
if not name:
raise ValueError("Name cannot be empty")
if project_id <= 0:
raise ValueError("Project ID must be a positive integer")
if organization_id <= 0:
raise ValueError("Organization ID must be a positive integer")

prompt = Prompt(name=name, project_id=project_id, organization_id=organization_id, inserted_at=now(), updated_at=now())
session.add(prompt)
try:
session.commit()
except Exception as e:
session.rollback()
raise e
session.refresh(prompt)
return prompt

def get_prompt_by_name(session: Session, name: str, project_id: int, organization_id: int) -> Optional[Prompt]:
if not name:
raise ValueError("Name cannot be empty")
if project_id <= 0:
raise ValueError("Project ID must be a positive integer")
if organization_id <= 0:
raise ValueError("Organization ID must be a positive integer")

statement = select(Prompt).where(Prompt.project_id == project_id & Prompt.organization_id == organization_id & Prompt.name == name)
return session.exec(statement).first()

def update_prompt(session: Session, name: str, project_id: int, organization_id: int) -> Optional[Prompt]:
try:
statement = select(Prompt).where(Prompt.project_id == project_id & Prompt.organization_id == organization_id & Prompt.name == name)
prompt = session.exec(statement).first()
if not prompt:
raise ValueError(f"No prompt found with name '{name}' for project_id '{project_id}' and organization_id '{organization_id}'")
prompt.updated_at = now()
session.commit()
session.refresh(prompt)
except Exception as e:
session.rollback()
raise e
return prompt

def list_prompts(session: Session, project_id: int, organization_id: int) -> List[Prompt]:
if project_id <= 0:
raise ValueError("Project ID must be a positive integer")
if organization_id <= 0:
raise ValueError("Organization ID must be a positive integer")
statement = select(Prompt).where(Prompt.project_id == project_id & Prompt.organization_id == organization_id)
return list(session.exec(statement))
2 changes: 2 additions & 0 deletions backend/app/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,5 @@
)

from .threads import OpenAI_Thread, OpenAIThreadBase, OpenAIThreadCreate

from .prompt import Prompt
14 changes: 14 additions & 0 deletions backend/app/models/prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from sqlmodel import SQLModel, Field
from typing import Optional
from datetime import datetime
from app.core.util import now


class Prompt(SQLModel, table=True):
id: int = Field(default=None, primary_key=True)
name: str = Field(index=True, unique=True)
project_id: int = Field(foreign_key="project.id", nullable=False, ondelete="CASCADE")
organization_id: int = Field(foreign_key="organization.id", nullable=False, ondelete="CASCADE")
inserted_at: datetime = Field(default_factory=now, nullable=False)
updated_at: datetime = Field(default_factory=now, nullable=False)