Skip to content

Commit a3be786

Browse files
authored
Dalgo Migration: Add threads endpoints (thread creation+polling) (#171)
* endpoints and test cases * test cases * test cases * initial test cases * added test cases * added test cases * added test cases * added test cases * added test cases * added test cases * added test cases * added test cases * added test cases * alembic fix * changes * test cases failure * clean db after test * clean db after test * test cases * removing sqlite session
1 parent e8eb23c commit a3be786

File tree

10 files changed

+491
-16
lines changed

10 files changed

+491
-16
lines changed
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
"""add threads table
2+
3+
Revision ID: 79e47bc3aac6
4+
Revises: f23675767ed2
5+
Create Date: 2025-05-12 15:49:39.142806
6+
7+
"""
8+
from alembic import op
9+
import sqlalchemy as sa
10+
import sqlmodel.sql.sqltypes
11+
12+
13+
# revision identifiers, used by Alembic.
14+
revision = "79e47bc3aac6"
15+
down_revision = "f23675767ed2"
16+
branch_labels = None
17+
depends_on = None
18+
19+
20+
def upgrade():
21+
# ### commands auto generated by Alembic - please adjust! ###
22+
op.create_table(
23+
"openai_thread",
24+
sa.Column("id", sa.Integer(), nullable=False),
25+
sa.Column("thread_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
26+
sa.Column("prompt", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
27+
sa.Column("response", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
28+
sa.Column("status", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
29+
sa.Column("error", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
30+
sa.Column("inserted_at", sa.DateTime(), nullable=False),
31+
sa.Column("updated_at", sa.DateTime(), nullable=False),
32+
sa.PrimaryKeyConstraint("id"),
33+
)
34+
op.create_index(
35+
op.f("ix_openai_thread_thread_id"), "openai_thread", ["thread_id"], unique=True
36+
)
37+
op.drop_constraint(
38+
"credential_organization_id_fkey", "credential", type_="foreignkey"
39+
)
40+
op.create_foreign_key(
41+
None, "credential", "organization", ["organization_id"], ["id"]
42+
)
43+
op.drop_constraint("project_organization_id_fkey", "project", type_="foreignkey")
44+
op.create_foreign_key(None, "project", "organization", ["organization_id"], ["id"])
45+
# ### end Alembic commands ###
46+
47+
48+
def downgrade():
49+
# ### commands auto generated by Alembic - please adjust! ###
50+
op.drop_constraint(None, "project", type_="foreignkey")
51+
op.create_foreign_key(
52+
"project_organization_id_fkey",
53+
"project",
54+
"organization",
55+
["organization_id"],
56+
["id"],
57+
ondelete="CASCADE",
58+
)
59+
op.drop_constraint(None, "credential", type_="foreignkey")
60+
op.create_foreign_key(
61+
"credential_organization_id_fkey",
62+
"credential",
63+
"organization",
64+
["organization_id"],
65+
["id"],
66+
ondelete="CASCADE",
67+
)
68+
op.drop_index(op.f("ix_openai_thread_thread_id"), table_name="openai_thread")
69+
op.drop_table("openai_thread")
70+
# ### end Alembic commands ###

backend/app/api/routes/threads.py

Lines changed: 123 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99

1010
from app.api.deps import get_current_user_org, get_db
1111
from app.core import logging, settings
12-
from app.models import UserOrganization
12+
from app.models import UserOrganization, OpenAIThreadCreate
13+
from app.crud import upsert_thread_result, get_thread_result
1314
from app.utils import APIResponse
1415

1516
logger = logging.getLogger(__name__)
@@ -113,6 +114,24 @@ def create_success_response(request: dict, message: str) -> APIResponse:
113114
)
114115

115116

117+
def run_and_poll_thread(client: OpenAI, thread_id: str, assistant_id: str):
118+
"""Runs and polls a thread with the specified assistant using the OpenAI client."""
119+
return client.beta.threads.runs.create_and_poll(
120+
thread_id=thread_id,
121+
assistant_id=assistant_id,
122+
)
123+
124+
125+
def extract_response_from_thread(
126+
client: OpenAI, thread_id: str, remove_citation: bool = False
127+
) -> str:
128+
"""Fetches and processes the latest message from a thread."""
129+
messages = client.beta.threads.messages.list(thread_id=thread_id)
130+
latest_message = messages.data[0]
131+
message_content = latest_message.content[0].text.value
132+
return process_message_content(message_content, remove_citation)
133+
134+
116135
@observe(as_type="generation")
117136
def process_run(request: dict, client: OpenAI):
118137
"""Process a run and send callback with results."""
@@ -159,6 +178,40 @@ def process_run(request: dict, client: OpenAI):
159178
send_callback(request["callback_url"], callback_response.model_dump())
160179

161180

181+
def poll_run_and_prepare_response(request: dict, client: OpenAI, db: Session):
182+
"""Handles a thread run, processes the response, and upserts the result to the database."""
183+
thread_id = request["thread_id"]
184+
prompt = request["question"]
185+
186+
try:
187+
run = run_and_poll_thread(client, thread_id, request["assistant_id"])
188+
189+
status = run.status or "unknown"
190+
response = None
191+
error = None
192+
193+
if status == "completed":
194+
response = extract_response_from_thread(
195+
client, thread_id, request.get("remove_citation", False)
196+
)
197+
198+
except openai.OpenAIError as e:
199+
status = "failed"
200+
error = str(e)
201+
response = None
202+
203+
upsert_thread_result(
204+
db,
205+
OpenAIThreadCreate(
206+
thread_id=thread_id,
207+
prompt=prompt,
208+
response=response,
209+
status=status,
210+
error=error,
211+
),
212+
)
213+
214+
162215
@router.post("/threads")
163216
async def threads(
164217
request: dict,
@@ -240,3 +293,72 @@ async def threads_sync(
240293

241294
except openai.OpenAIError as e:
242295
return APIResponse.failure_response(error=handle_openai_error(e))
296+
297+
298+
@router.post("/threads/start")
299+
async def start_thread(
300+
request: OpenAIThreadCreate,
301+
background_tasks: BackgroundTasks,
302+
db: Session = Depends(get_db),
303+
_current_user: UserOrganization = Depends(get_current_user_org),
304+
):
305+
"""
306+
Create a new OpenAI thread for the given question and start polling in the background.
307+
"""
308+
prompt = request["question"]
309+
client = OpenAI(api_key=settings.OPENAI_API_KEY)
310+
311+
is_success, error = setup_thread(client, request)
312+
if not is_success:
313+
return APIResponse.failure_response(error=error)
314+
315+
thread_id = request["thread_id"]
316+
317+
upsert_thread_result(
318+
db,
319+
OpenAIThreadCreate(
320+
thread_id=thread_id,
321+
prompt=prompt,
322+
response=None,
323+
status="processing",
324+
error=None,
325+
),
326+
)
327+
328+
background_tasks.add_task(poll_run_and_prepare_response, request, client, db)
329+
330+
return APIResponse.success_response(
331+
data={
332+
"thread_id": thread_id,
333+
"prompt": prompt,
334+
"status": "processing",
335+
"message": "Thread created and polling started in background.",
336+
}
337+
)
338+
339+
340+
@router.get("/threads/result/{thread_id}")
341+
async def get_thread(
342+
thread_id: str,
343+
db: Session = Depends(get_db),
344+
_current_user: UserOrganization = Depends(get_current_user_org),
345+
):
346+
"""
347+
Retrieve the result of a previously started OpenAI thread using its thread ID.
348+
"""
349+
result = get_thread_result(db, thread_id)
350+
351+
if not result:
352+
return APIResponse.failure_response(error="Thread not found.")
353+
354+
status = result.status or ("success" if result.response else "processing")
355+
356+
return APIResponse.success_response(
357+
data={
358+
"thread_id": result.thread_id,
359+
"prompt": result.prompt,
360+
"status": status,
361+
"response": result.response,
362+
"error": result.error,
363+
}
364+
)

backend/app/crud/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,5 @@
2929
get_api_keys_by_organization,
3030
delete_api_key,
3131
)
32+
33+
from .thread_results import upsert_thread_result, get_thread_result

backend/app/crud/thread_results.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from sqlmodel import Session, select
2+
from datetime import datetime
3+
from app.models import OpenAIThreadCreate, OpenAI_Thread
4+
5+
6+
def upsert_thread_result(session: Session, data: OpenAIThreadCreate):
7+
statement = select(OpenAI_Thread).where(OpenAI_Thread.thread_id == data.thread_id)
8+
existing = session.exec(statement).first()
9+
10+
if existing:
11+
existing.prompt = data.prompt
12+
existing.response = data.response
13+
existing.status = data.status
14+
existing.error = data.error
15+
existing.updated_at = datetime.utcnow()
16+
else:
17+
new_thread = OpenAI_Thread(**data.dict())
18+
session.add(new_thread)
19+
20+
session.commit()
21+
22+
23+
def get_thread_result(session: Session, thread_id: str) -> OpenAI_Thread | None:
24+
statement = select(OpenAI_Thread).where(OpenAI_Thread.thread_id == thread_id)
25+
return session.exec(statement).first()

backend/app/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,5 @@
5151
CredsPublic,
5252
CredsUpdate,
5353
)
54+
55+
from .threads import OpenAI_Thread, OpenAIThreadBase, OpenAIThreadCreate

backend/app/models/threads.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from sqlmodel import SQLModel, Field
2+
from typing import Optional
3+
from datetime import datetime
4+
5+
6+
class OpenAIThreadBase(SQLModel):
7+
thread_id: str = Field(index=True, unique=True)
8+
prompt: str
9+
response: Optional[str] = None
10+
status: Optional[str] = None
11+
error: Optional[str] = None
12+
13+
14+
class OpenAIThreadCreate(OpenAIThreadBase):
15+
pass # Used for requests, no `id` or timestamps
16+
17+
18+
class OpenAI_Thread(OpenAIThreadBase, table=True):
19+
id: int = Field(default=None, primary_key=True)
20+
inserted_at: datetime = Field(default_factory=datetime.utcnow)
21+
updated_at: datetime = Field(default_factory=datetime.utcnow)

backend/app/tests/api/routes/test_creds.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -43,21 +43,16 @@ def create_organization_and_creds(db: Session, superuser_token_headers: dict[str
4343

4444

4545
def test_set_creds_for_org(db: Session, superuser_token_headers: dict[str, str]):
46-
unique_org_id = 2
47-
existing_org = (
48-
db.query(Organization).filter(Organization.id == unique_org_id).first()
49-
)
46+
unique_name = "Test Organization " + generate_random_string(5)
5047

51-
if not existing_org:
52-
new_org = Organization(
53-
id=unique_org_id, name="Test Organization", is_active=True
54-
)
55-
db.add(new_org)
56-
db.commit()
48+
new_org = Organization(name=unique_name, is_active=True)
49+
db.add(new_org)
50+
db.commit()
51+
db.refresh(new_org)
5752

5853
api_key = "sk-" + generate_random_string(10)
5954
creds_data = {
60-
"organization_id": unique_org_id,
55+
"organization_id": new_org.id,
6156
"is_active": True,
6257
"credential": {"openai": {"api_key": api_key}},
6358
}
@@ -69,10 +64,9 @@ def test_set_creds_for_org(db: Session, superuser_token_headers: dict[str, str])
6964
)
7065

7166
assert response.status_code == 200
72-
7367
created_creds = response.json()
7468
assert "data" in created_creds
75-
assert created_creds["data"]["organization_id"] == unique_org_id
69+
assert created_creds["data"]["organization_id"] == new_org.id
7670
assert created_creds["data"]["credential"]["openai"]["api_key"] == api_key
7771

7872

0 commit comments

Comments
 (0)