11 APIs to interact with Vanna RAG #13

Open
wants to merge 11 commits into main
9 changes: 6 additions & 3 deletions docker-compose.yml
@@ -51,13 +51,16 @@ services:
flower:
container_name: flower
build: .
command: celery -A main.celery flower --port=5555
command: celery -A main.celery flower
environment:
- CELERY_BROKER_URL=redis://redis:6379/0
- FLOWER_PORT=5555
ports:
- 5556:5555
depends_on:
- fastapi
- redis
- celery_worker
networks:
- llm-network

volumes:
redis_data:
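
This change moves Flower's port out of the Celery command and into the FLOWER_PORT environment variable, with container port 5555 published on host port 5556. A minimal smoke test from the host, assuming the stack is up and Flower's default HTTP API is reachable (the URL and port come straight from the compose entries above):

import requests

# Flower listens on 5555 inside the container (FLOWER_PORT) and is mapped
# to 5556 on the host via the "5556:5555" ports entry.
resp = requests.get("http://localhost:5556/api/workers", timeout=10)
resp.raise_for_status()
print(resp.json())  # one entry per connected Celery worker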
71 changes: 63 additions & 8 deletions requirements.txt
@@ -2,56 +2,107 @@ amqp==5.2.0
annotated-types==0.7.0
anyio==4.4.0
async-timeout==4.0.3
attrs==25.1.0
backoff==2.2.1
billiard==4.2.0
blinker==1.8.2
blinker==1.9.0
cachetools==5.5.2
celery==5.4.0
certifi==2024.7.4
certifi==2024.6.2
charset-normalizer==3.3.2
click==8.1.7
click-didyoumean==0.3.1
click-plugins==1.1.1
click-repl==0.3.0
db-dtypes==1.4.1
distro==1.9.0
dnspython==2.6.1
email_validator==2.1.2
exceptiongroup==1.2.1
fastapi==0.115.8
fastapi==0.111.0
fastapi-cli==0.0.4
flasgger==0.9.7.1
Flask==3.1.0
flask-sock==0.7.0
flower==2.0.1
google-api-core==2.24.1
google-auth==2.38.0
google-cloud-bigquery==3.29.0
google-cloud-core==2.4.2
google-crc32c==1.6.0
google-resumable-media==2.7.2
googleapis-common-protos==1.68.0
grpcio==1.70.0
grpcio-status==1.70.0
h11==0.14.0
httpcore==1.0.5
httptools==0.6.1
httpx==0.27.0
humanize==4.9.0
idna==3.7
itsdangerous==2.2.0
Jinja2==3.1.5
Jinja2==3.1.4
jiter==0.5.0
jsonpatch==1.33
jsonpointer==3.0.0
jsonschema==4.23.0
jsonschema-specifications==2024.10.1
kaleido==0.2.1
kombu==5.3.7
langchain-core==0.3.37
langchain-openai==0.3.6
langchain-postgres==0.0.13
langfuse==2.38.0
langsmith==0.3.10
markdown-it-py==3.0.0
MarkupSafe==2.1.5
mdurl==0.1.2
mistune==3.1.2
narwhals==1.27.1
numpy==2.0.0
openai==1.34.0
openai==1.64.0
orjson==3.10.5
packaging==23.2
pandas==2.2.2
pgvector==0.3.6
plotly==6.0.0
prometheus_client==0.20.0
prompt_toolkit==3.0.47
proto-plus==1.26.0
protobuf==5.29.3
psycopg==3.2.5
psycopg-pool==3.2.5
psycopg2-binary==2.9.10
pyarrow==19.0.1
pyasn1==0.6.1
pyasn1_modules==0.4.1
pydantic==2.7.4
pydantic_core==2.18.4
Pygments==2.18.0
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
python-multipart==0.0.20
python-multipart==0.0.9
pytz==2024.1
PyYAML==6.0.1
redis==5.0.6
referencing==0.36.2
regex==2024.11.6
requests==2.32.3
requests-toolbelt==1.0.0
rich==13.7.1
rpds-py==0.23.1
rsa==4.9
shellingham==1.5.4
simple-websocket==1.1.0
six==1.16.0
sniffio==1.3.1
starlette==0.45.3
tornado==6.4.2
SQLAlchemy==2.0.38
sqlparse==0.5.3
starlette==0.37.2
tabulate==0.9.0
tenacity==9.0.0
tiktoken==0.9.0
tornado==6.4.1
tqdm==4.66.4
typer==0.12.3
typing_extensions==4.12.2
@@ -60,8 +111,12 @@ ujson==5.10.0
urllib3==2.2.2
uvicorn==0.30.1
uvloop==0.19.0
vanna==0.7.6
vine==5.1.0
watchfiles==0.22.0
wcwidth==0.2.13
websockets==12.0
Werkzeug==3.1.3
wrapt==1.16.0
wsproto==1.2.0
zstandard==0.23.0
144 changes: 63 additions & 81 deletions src/api.py
@@ -1,97 +1,28 @@
import os
import requests
import uuid
import logging
import time
from typing import Optional
from pathlib import Path
from pydantic import BaseModel
from fastapi import APIRouter, HTTPException, UploadFile, Form
from celery import shared_task
from celery.result import AsyncResult
from celery.result import AsyncResult, states
from config.constants import TMP_UPLOAD_DIR_NAME


from src.file_search.openai_assistant import OpenAIFileAssistant, SessionStatusEnum
from src.celerytasks.file_search_tasks import close_file_search_session, query_file
from src.celerytasks.vanna_rag_tasks import train_vanna_on_warehouse, ask_vanna_rag
from src.file_search.openai_assistant import SessionStatusEnum
from src.file_search.session import FileSearchSession, OpenAISessionState
from src.custom_webhook import CustomWebhook, WebhookConfig

from src.file_search.schemas import FileQueryRequest
from src.vanna.schemas import (
TrainVannaRequest,
AskVannaRequest,
BaseVannaWarehouseConfig,
)
from src.vanna.sql_generation import SqlGeneration

router = APIRouter()

logger = logging.getLogger()


@shared_task(
bind=True,
autoretry_for=(Exception,),
retry_backoff=5, # tasks will retry after 5, 10, 15... seconds
retry_kwargs={"max_retries": 3},
name="query_file",
logger=logging.getLogger(),
)
def query_file(
self,
openai_key: str,
assistant_prompt: str,
queries: list[str],
session_id: str,
webhook_config: Optional[dict] = None,
):
fa = None
try:
results = []

fa = OpenAIFileAssistant(
openai_key,
session_id=session_id,
instructions=assistant_prompt,
)
for i, prompt in enumerate(queries):
logger.info("%s: %s", i, prompt)
response = fa.query(prompt)
results.append(response)

logger.info(f"Results generated in the session {fa.session.id}")

if webhook_config:
webhook = CustomWebhook(WebhookConfig(**webhook_config))
logger.info(
f"Posting results to the webhook configured at {webhook.config.endpoint}"
)
res = webhook.post_result({"results": results, "session_id": fa.session.id})
logger.info(f"Results posted to the webhook with res: {str(res)}")

return {"result": results, "session_id": fa.session.id}
except Exception as err:
logger.error(err)
raise Exception(str(err))


@shared_task(
bind=True,
autoretry_for=(Exception,),
retry_backoff=5, # tasks will retry after 5, 10, 15... seconds
retry_kwargs={"max_retries": 3},
name="close_file_search_session",
logger=logging.getLogger(),
)
def close_file_search_session(self, openai_key, session_id: str):
try:
fa = OpenAIFileAssistant(openai_key, session_id=session_id)
fa.close()
except Exception as err:
logger.error(err)
raise Exception(str(err))


class FileQueryRequest(BaseModel):
queries: list[str]
assistant_prompt: str = None
session_id: str
webhook_config: Optional[WebhookConfig] = None


@router.delete("/file/search/session/{session_id}")
async def delete_file_search_session(session_id: str):
"""
@@ -200,6 +131,57 @@ def get_summarize_job(task_id):
"id": task_id,
"status": task_result.status,
"result": task_result.result,
"error": str(task_result.info) if task_result.info else None,
"error": (
str(task_result.info)
if task_result.info and task_result.status != states.SUCCESS
else None
),
}
return result
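
The reworked "error" field relies on a Celery detail: AsyncResult.info is the same object as AsyncResult.result, so for a successful task it holds the return value rather than an exception. Guarding on states.SUCCESS keeps successful results from being echoed back as errors. A minimal sketch of the semantics, assuming a reachable result backend (the task id is a placeholder):

from celery.result import AsyncResult, states

task_result = AsyncResult("00000000-0000-0000-0000-000000000000")  # hypothetical id
# On FAILURE, .info is the raised exception; on SUCCESS it is the return
# value, so it should only be surfaced as an error for non-success states.
error = (
    str(task_result.info)
    if task_result.info and task_result.status != states.SUCCESS
    else None
)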


########################### vanna rag related ###########################


@router.post("/vanna/train")
async def post_train_vanna(payload: TrainVannaRequest):
"""Train the vanna RAG against a warehouse for a defined training plan"""
task = train_vanna_on_warehouse.apply_async(
kwargs={
"openai_api_key": os.getenv("OPENAI_API_KEY"),
"pg_vector_creds": payload.pg_vector_creds.model_dump(),
"warehouse_creds": payload.warehouse_creds,
"training_sql": payload.training_sql,
"reset": payload.reset,
"warehouse_type": payload.warehouse_type.value,
}
)
return {"task_id": task.id}


@router.post("/vanna/train/check")
def post_train_vanna_health_check(payload: BaseVannaWarehouseConfig):
"""Checks if the embeddings are generated or not for the warehouse"""
sql_generation_client = SqlGeneration(
openai_api_key=os.getenv("OPENAI_API_KEY"),
pg_vector_creds=payload.pg_vector_creds,
warehouse_creds=payload.warehouse_creds,
warehouse_type=payload.warehouse_type,
)

return sql_generation_client.is_trained()


@router.post("/vanna/ask")
async def post_generate_sql(payload: AskVannaRequest):
"""Run the question against the trained vanna RAG to generate a sql query"""
task = ask_vanna_rag.apply_async(
kwargs={
"openai_api_key": os.getenv("OPENAI_API_KEY"),
"pg_vector_creds": payload.pg_vector_creds.model_dump(),
"warehouse_creds": payload.warehouse_creds,
"warehouse_type": payload.warehouse_type.value,
"user_prompt": payload.user_prompt,
}
)
return {"task_id": task.id}
Empty file added src/celerytasks/__init__.py
Empty file.
72 changes: 72 additions & 0 deletions src/celerytasks/file_search_tasks.py
@@ -0,0 +1,72 @@
from typing import Optional
import logging

from celery import shared_task

from src.file_search.openai_assistant import OpenAIFileAssistant
from src.utils.custom_webhook import CustomWebhook, WebhookConfig


logger = logging.getLogger()


@shared_task(
bind=True,
autoretry_for=(Exception,),
retry_backoff=5, # tasks will retry after 5, 10, 15... seconds
retry_kwargs={"max_retries": 3},
name="query_file",
logger=logging.getLogger(),
)
def query_file(
self,
openai_key: str,
assistant_prompt: str,
queries: list[str],
session_id: str,
webhook_config: Optional[dict] = None,
):
fa = None
try:
results = []

fa = OpenAIFileAssistant(
openai_key,
session_id=session_id,
instructions=assistant_prompt,
)
for i, prompt in enumerate(queries):
logger.info("%s: %s", i, prompt)
response = fa.query(prompt)
results.append(response)

logger.info(f"Results generated in the session {fa.session.id}")

if webhook_config:
webhook = CustomWebhook(WebhookConfig(**webhook_config))
logger.info(
f"Posting results to the webhook configured at {webhook.config.endpoint}"
)
res = webhook.post_result({"results": results, "session_id": fa.session.id})
logger.info(f"Results posted to the webhook with res: {str(res)}")

return {"result": results, "session_id": fa.session.id}
except Exception as err:
logger.error(err)
raise Exception(str(err))
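
Because query_file is registered with @shared_task, callers can dispatch it through any configured Celery app and poll the AsyncResult for the dict it returns. A minimal sketch, assuming a running worker, a valid OpenAI key, and an existing file-search session (all values below are placeholders):

from src.celerytasks.file_search_tasks import query_file

async_res = query_file.apply_async(
    kwargs={
        "openai_key": "sk-...",  # placeholder key
        "assistant_prompt": "Answer questions about the uploaded file.",
        "queries": ["Summarize the document."],
        "session_id": "existing-session-id",  # must reference a live session
        "webhook_config": None,  # or a dict matching WebhookConfig
    }
)
print(async_res.get(timeout=300))  # {"result": [...], "session_id": "..."}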


@shared_task(
bind=True,
autoretry_for=(Exception,),
retry_backoff=5, # tasks will retry after 5, 10, 15... seconds
retry_kwargs={"max_retries": 3},
name="close_file_search_session",
logger=logging.getLogger(),
)
def close_file_search_session(self, openai_key, session_id: str):
try:
fa = OpenAIFileAssistant(openai_key, session_id=session_id)
fa.close()
except Exception as err:
logger.error(err)
        raise Exception(str(err))