diff --git a/docker-compose.yml b/docker-compose.yml
index d201d2f..4a443fd 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -51,13 +51,16 @@ services:
   flower:
     container_name: flower
     build: .
-    command: celery -A main.celery flower --port=5555
+    command: celery -A main.celery flower
+    environment:
+      - CELERY_BROKER_URL=redis://redis:6379/0
+      - FLOWER_PORT=5555
     ports:
       - 5556:5555
     depends_on:
-      - fastapi
       - redis
-      - celery_worker
+    networks:
+      - llm-network

 volumes:
   redis_data:
diff --git a/requirements.txt b/requirements.txt
index 177cdd8..e714541 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,22 +2,38 @@ amqp==5.2.0
 annotated-types==0.7.0
 anyio==4.4.0
 async-timeout==4.0.3
+attrs==25.1.0
+backoff==2.2.1
 billiard==4.2.0
-blinker==1.8.2
+blinker==1.9.0
+cachetools==5.5.2
 celery==5.4.0
-certifi==2024.7.4
+certifi==2024.6.2
 charset-normalizer==3.3.2
 click==8.1.7
 click-didyoumean==0.3.1
 click-plugins==1.1.1
 click-repl==0.3.0
+db-dtypes==1.4.1
 distro==1.9.0
 dnspython==2.6.1
 email_validator==2.1.2
 exceptiongroup==1.2.1
-fastapi==0.115.8
+fastapi==0.111.0
 fastapi-cli==0.0.4
+flasgger==0.9.7.1
+Flask==3.1.0
+flask-sock==0.7.0
 flower==2.0.1
+google-api-core==2.24.1
+google-auth==2.38.0
+google-cloud-bigquery==3.29.0
+google-cloud-core==2.4.2
+google-crc32c==1.6.0
+google-resumable-media==2.7.2
+googleapis-common-protos==1.68.0
+grpcio==1.70.0
+grpcio-status==1.70.0
 h11==0.14.0
 httpcore==1.0.5
 httptools==0.6.1
@@ -25,33 +41,68 @@ httpx==0.27.0
 humanize==4.9.0
 idna==3.7
 itsdangerous==2.2.0
-Jinja2==3.1.5
+Jinja2==3.1.4
+jiter==0.5.0
+jsonpatch==1.33
+jsonpointer==3.0.0
+jsonschema==4.23.0
+jsonschema-specifications==2024.10.1
+kaleido==0.2.1
 kombu==5.3.7
+langchain-core==0.3.37
+langchain-openai==0.3.6
+langchain-postgres==0.0.13
+langfuse==2.38.0
+langsmith==0.3.10
 markdown-it-py==3.0.0
 MarkupSafe==2.1.5
 mdurl==0.1.2
+mistune==3.1.2
+narwhals==1.27.1
 numpy==2.0.0
-openai==1.34.0
+openai==1.64.0
 orjson==3.10.5
+packaging==23.2
 pandas==2.2.2
+pgvector==0.3.6
+plotly==6.0.0
 prometheus_client==0.20.0
 prompt_toolkit==3.0.47
+proto-plus==1.26.0
+protobuf==5.29.3
+psycopg==3.2.5
+psycopg-pool==3.2.5
+psycopg2-binary==2.9.10
+pyarrow==19.0.1
+pyasn1==0.6.1
+pyasn1_modules==0.4.1
 pydantic==2.7.4
 pydantic_core==2.18.4
 Pygments==2.18.0
 python-dateutil==2.9.0.post0
 python-dotenv==1.0.1
-python-multipart==0.0.20
+python-multipart==0.0.9
 pytz==2024.1
 PyYAML==6.0.1
 redis==5.0.6
+referencing==0.36.2
+regex==2024.11.6
 requests==2.32.3
+requests-toolbelt==1.0.0
 rich==13.7.1
+rpds-py==0.23.1
+rsa==4.9
 shellingham==1.5.4
+simple-websocket==1.1.0
 six==1.16.0
 sniffio==1.3.1
-starlette==0.45.3
-tornado==6.4.2
+SQLAlchemy==2.0.38
+sqlparse==0.5.3
+starlette==0.37.2
+tabulate==0.9.0
+tenacity==9.0.0
+tiktoken==0.9.0
+tornado==6.4.1
 tqdm==4.66.4
 typer==0.12.3
 typing_extensions==4.12.2
@@ -60,8 +111,12 @@ ujson==5.10.0
 urllib3==2.2.2
 uvicorn==0.30.1
 uvloop==0.19.0
+vanna==0.7.6
 vine==5.1.0
 watchfiles==0.22.0
 wcwidth==0.2.13
 websockets==12.0
 Werkzeug==3.1.3
+wrapt==1.16.0
+wsproto==1.2.0
+zstandard==0.23.0
diff --git a/src/api.py b/src/api.py
index 297b5a1..22d2949 100644
--- a/src/api.py
+++ b/src/api.py
@@ -1,97 +1,28 @@
 import os
-import requests
 import uuid
 import logging
-import time
-from typing import Optional
 from pathlib import Path
-from pydantic import BaseModel
 from fastapi import APIRouter, HTTPException, UploadFile, Form
-from celery import shared_task
-from celery.result import AsyncResult
+from celery.result import AsyncResult, states

 from config.constants import TMP_UPLOAD_DIR_NAME
-
-from src.file_search.openai_assistant import OpenAIFileAssistant, SessionStatusEnum
+from src.celerytasks.file_search_tasks import close_file_search_session, query_file
+from src.celerytasks.vanna_rag_tasks import train_vanna_on_warehouse, ask_vanna_rag
+from src.file_search.openai_assistant import SessionStatusEnum
 from src.file_search.session import FileSearchSession, OpenAISessionState
-from src.custom_webhook import CustomWebhook, WebhookConfig
-
+from src.file_search.schemas import FileQueryRequest
+from src.vanna.schemas import (
+    TrainVannaRequest,
+    AskVannaRequest,
+    BaseVannaWarehouseConfig,
+)
+from src.vanna.sql_generation import SqlGeneration

 router = APIRouter()
 logger = logging.getLogger()

-
-@shared_task(
-    bind=True,
-    autoretry_for=(Exception,),
-    retry_backoff=5,  # tasks will retry after 5, 10, 15... seconds
-    retry_kwargs={"max_retries": 3},
-    name="query_file",
-    logger=logging.getLogger(),
-)
-def query_file(
-    self,
-    openai_key: str,
-    assistant_prompt: str,
-    queries: list[str],
-    session_id: str,
-    webhook_config: Optional[dict] = None,
-):
-    fa = None
-    try:
-        results = []
-
-        fa = OpenAIFileAssistant(
-            openai_key,
-            session_id=session_id,
-            instructions=assistant_prompt,
-        )
-        for i, prompt in enumerate(queries):
-            logger.info("%s: %s", i, prompt)
-            response = fa.query(prompt)
-            results.append(response)
-
-        logger.info(f"Results generated in the session {fa.session.id}")
-
-        if webhook_config:
-            webhook = CustomWebhook(WebhookConfig(**webhook_config))
-            logger.info(
-                f"Posting results to the webhook configured at {webhook.config.endpoint}"
-            )
-            res = webhook.post_result({"results": results, "session_id": fa.session.id})
-            logger.info(f"Results posted to the webhook with res: {str(res)}")
-
-        return {"result": results, "session_id": fa.session.id}
-    except Exception as err:
-        logger.error(err)
-        raise Exception(str(err))
-
-
-@shared_task(
-    bind=True,
-    autoretry_for=(Exception,),
-    retry_backoff=5,  # tasks will retry after 5, 10, 15... seconds
-    retry_kwargs={"max_retries": 3},
-    name="close_file_search_session",
-    logger=logging.getLogger(),
-)
-def close_file_search_session(self, openai_key, session_id: str):
-    try:
-        fa = OpenAIFileAssistant(openai_key, session_id=session_id)
-        fa.close()
-    except Exception as err:
-        logger.error(err)
-        raise Exception(str(err))
-
-
-class FileQueryRequest(BaseModel):
-    queries: list[str]
-    assistant_prompt: str = None
-    session_id: str
-    webhook_config: Optional[WebhookConfig] = None
-
-
 @router.delete("/file/search/session/{session_id}")
 async def delete_file_search_session(session_id: str):
     """
@@ -200,6 +131,57 @@ def get_summarize_job(task_id):
         "id": task_id,
         "status": task_result.status,
         "result": task_result.result,
-        "error": str(task_result.info) if task_result.info else None,
+        "error": (
+            str(task_result.info)
+            if task_result.info and task_result.status != states.SUCCESS
+            else None
+        ),
     }
     return result
+
+
+########################### vanna rag related ###########################
+
+
+@router.post("/vanna/train")
+async def post_train_vanna(payload: TrainVannaRequest):
+    """Train the vanna RAG against a warehouse for a defined training plan"""
+    task = train_vanna_on_warehouse.apply_async(
+        kwargs={
+            "openai_api_key": os.getenv("OPENAI_API_KEY"),
+            "pg_vector_creds": payload.pg_vector_creds.model_dump(),
+            "warehouse_creds": payload.warehouse_creds,
+            "training_sql": payload.training_sql,
+            "reset": payload.reset,
+            "warehouse_type": payload.warehouse_type.value,
+        }
+    )
+    return {"task_id": task.id}
+
+
+@router.post("/vanna/train/check")
+def post_train_vanna_health_check(payload: BaseVannaWarehouseConfig):
+    """Checks if the embeddings are generated or not for the warehouse"""
+    sql_generation_client = SqlGeneration(
+        openai_api_key=os.getenv("OPENAI_API_KEY"),
+        pg_vector_creds=payload.pg_vector_creds,
+        warehouse_creds=payload.warehouse_creds,
+        warehouse_type=payload.warehouse_type,
+    )
+
+    return sql_generation_client.is_trained()
+
+
+@router.post("/vanna/ask")
+async def post_generate_sql(payload: AskVannaRequest):
+    """Run the question against the trained vanna RAG to generate a sql query"""
+    task = ask_vanna_rag.apply_async(
+        kwargs={
+            "openai_api_key": os.getenv("OPENAI_API_KEY"),
+            "pg_vector_creds": payload.pg_vector_creds.model_dump(),
+            "warehouse_creds": payload.warehouse_creds,
+            "warehouse_type": payload.warehouse_type.value,
+            "user_prompt": payload.user_prompt,
+        }
+    )
+    return {"task_id": task.id}
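A minimal sketch of how the new `/vanna/train` endpoint could be exercised — illustrative only: the host, port, and credential values are placeholders, and the Postgres `warehouse_creds` keys follow what `SqlGeneration` reads further down in this diff.

```python
# Hypothetical smoke test for POST /vanna/train (not part of the patch).
import requests  # already pinned in requirements.txt

payload = {
    "pg_vector_creds": {  # where the RAG embeddings get stored
        "username": "postgres",
        "password": "secret",
        "host": "localhost",
        "port": 5432,
        "database": "vectors",
    },
    "warehouse_creds": {  # keys SqlGeneration expects for postgres
        "host": "warehouse.internal",
        "port": 5432,
        "database": "analytics",
        "username": "reader",
        "password": "secret",
    },
    "warehouse_type": "postgres",
    "training_sql": "SELECT * FROM information_schema.columns",
    "reset": True,
}

resp = requests.post("http://localhost:8000/vanna/train", json=payload)
task_id = resp.json()["task_id"]  # poll the existing task-status route with this id
```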
diff --git a/src/celerytasks/__init__.py b/src/celerytasks/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/celerytasks/file_search_tasks.py b/src/celerytasks/file_search_tasks.py
new file mode 100644
index 0000000..7825940
--- /dev/null
+++ b/src/celerytasks/file_search_tasks.py
@@ -0,0 +1,72 @@
+from typing import Optional
+import logging
+
+from celery import shared_task
+
+from src.file_search.openai_assistant import OpenAIFileAssistant
+from src.utils.custom_webhook import CustomWebhook, WebhookConfig
+
+
+logger = logging.getLogger()
+
+
+@shared_task(
+    bind=True,
+    autoretry_for=(Exception,),
+    retry_backoff=5,  # tasks will retry after 5, 10, 15... seconds
+    retry_kwargs={"max_retries": 3},
+    name="query_file",
+    logger=logging.getLogger(),
+)
+def query_file(
+    self,
+    openai_key: str,
+    assistant_prompt: str,
+    queries: list[str],
+    session_id: str,
+    webhook_config: Optional[dict] = None,
+):
+    fa = None
+    try:
+        results = []
+
+        fa = OpenAIFileAssistant(
+            openai_key,
+            session_id=session_id,
+            instructions=assistant_prompt,
+        )
+        for i, prompt in enumerate(queries):
+            logger.info("%s: %s", i, prompt)
+            response = fa.query(prompt)
+            results.append(response)
+
+        logger.info(f"Results generated in the session {fa.session.id}")
+
+        if webhook_config:
+            webhook = CustomWebhook(WebhookConfig(**webhook_config))
+            logger.info(
+                f"Posting results to the webhook configured at {webhook.config.endpoint}"
+            )
+            res = webhook.post_result({"results": results, "session_id": fa.session.id})
+            logger.info(f"Results posted to the webhook with res: {str(res)}")
+
+        return {"result": results, "session_id": fa.session.id}
+    except Exception as err:
+        logger.error(err)
+        raise Exception(str(err))
+
+
+@shared_task(
+    bind=True,
+    autoretry_for=(Exception,),
+    retry_backoff=5,  # tasks will retry after 5, 10, 15... seconds
+    retry_kwargs={"max_retries": 3},
+    name="close_file_search_session",
+    logger=logging.getLogger(),
+)
+def close_file_search_session(self, openai_key, session_id: str):
+    try:
+        fa = OpenAIFileAssistant(openai_key, session_id=session_id)
+        fa.close()
+    except Exception as err:
+        logger.error(err)
+        raise Exception(str(err))
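For reference, dispatching the relocated `query_file` task directly would look roughly like this — a sketch with placeholder values; the `WebhookConfig` shape is assumed from the `webhook.config.endpoint` usage above.

```python
# Illustrative dispatch of the relocated task (all values are placeholders).
from src.celerytasks.file_search_tasks import query_file

task = query_file.apply_async(
    kwargs={
        "openai_key": "sk-placeholder",
        "assistant_prompt": "Answer strictly from the uploaded files.",
        "queries": ["Summarize the key findings."],
        "session_id": "existing-file-search-session-id",
        "webhook_config": {"endpoint": "https://example.com/hooks/results"},
    }
)
print(task.id)
```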
diff --git a/src/celerytasks/vanna_rag_tasks.py b/src/celerytasks/vanna_rag_tasks.py
new file mode 100644
index 0000000..23a1610
--- /dev/null
+++ b/src/celerytasks/vanna_rag_tasks.py
@@ -0,0 +1,83 @@
+from typing import Optional
+from celery import shared_task
+import logging
+
+from src.vanna.schemas import PgVectorCreds
+from src.vanna.sql_generation import SqlGeneration
+
+
+logger = logging.getLogger()
+
+
+@shared_task(
+    bind=True,
+    retry_backoff=5,  # tasks will retry after 5, 10, 15... seconds
+    retry_kwargs={"max_retries": 1},
+    name="train_vanna_on_warehouse",
+    logger=logging.getLogger(),
+)
+def train_vanna_on_warehouse(
+    self,
+    openai_api_key: str,
+    pg_vector_creds: dict,
+    warehouse_creds: dict,
+    training_sql: str,
+    reset: bool,
+    warehouse_type: str,
+):
+
+    sql_generation_client = SqlGeneration(
+        openai_api_key=openai_api_key,
+        pg_vector_creds=PgVectorCreds(**pg_vector_creds),
+        warehouse_creds=warehouse_creds,
+        warehouse_type=warehouse_type,
+    )
+
+    if reset:
+        sql_generation_client.remove_training_data()
+        logger.info("Deleted training data successfully")
+
+    sql_generation_client.setup_training_plan_and_execute(training_sql)
+
+    logger.info(
+        f"Completed training successfully with the following plan {training_sql}"
+    )
+
+    return True
+
+
+@shared_task(
+    bind=True,
+    retry_backoff=5,  # tasks will retry after 5, 10, 15... seconds
+    retry_kwargs={"max_retries": 1},
+    name="ask_vanna_rag",
+    logger=logging.getLogger(),
+)
+def ask_vanna_rag(
+    self,
+    openai_api_key: str,
+    pg_vector_creds: dict,
+    warehouse_creds: dict,
+    warehouse_type: str,
+    user_prompt: str,
+):
+
+    sql_generation_client = SqlGeneration(
+        openai_api_key=openai_api_key,
+        pg_vector_creds=PgVectorCreds(**pg_vector_creds),
+        warehouse_creds=warehouse_creds,
+        warehouse_type=warehouse_type,
+    )
+
+    logger.info("Starting sql generation")
+
+    sql = sql_generation_client.generate_sql(
+        question=user_prompt, allow_llm_to_see_data=False
+    )
+
+    logger.info(f"Finished sql generation with result: {sql}")
+
+    if not sql_generation_client.is_sql_valid(sql):
+        raise Exception(f"Failed to get a valid sql from llm : {sql}")
+
+    return sql
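Downstream, the generated SQL can be read back off the task result, mirroring how `get_summarize_job` inspects task state in `src/api.py` — a sketch, assuming the Celery app is importable in the calling process:

```python
# Reading the result of ask_vanna_rag; illustrative only.
from celery.result import AsyncResult, states

res = AsyncResult(task_id)  # task_id as returned by POST /vanna/ask
if res.status == states.SUCCESS:
    sql = res.result  # the task returns the validated SQL string
elif res.status == states.FAILURE:
    print(res.info)  # the "Failed to get a valid sql" exception lands here
```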
diff --git a/src/file_search/openai_assistant.py b/src/file_search/openai_assistant.py
index a74b1ac..92f72e4 100644
--- a/src/file_search/openai_assistant.py
+++ b/src/file_search/openai_assistant.py
@@ -193,8 +193,10 @@ def close(self):
         logger.info("Closing the session %s", self.session.id)
         for doc in self.documents:
             self.client.files.delete(doc.id)
-        self.client.beta.threads.delete(self.thread.id)
-        self.client.beta.assistants.delete(self.assistant.id)
+        if self.thread.id:
+            self.client.beta.threads.delete(self.thread.id)
+        if self.assistant.id:
+            self.client.beta.assistants.delete(self.assistant.id)
         for local_fpath in self.session.local_fpaths:
             Path(local_fpath).unlink()
         # remove from redis
diff --git a/src/file_search/schemas.py b/src/file_search/schemas.py
new file mode 100644
index 0000000..93985bb
--- /dev/null
+++ b/src/file_search/schemas.py
@@ -0,0 +1,10 @@
+from src.utils.custom_webhook import WebhookConfig
+from typing import Optional
+from pydantic import BaseModel
+
+
+class FileQueryRequest(BaseModel):
+    queries: list[str]
+    assistant_prompt: Optional[str] = None
+    session_id: str
+    webhook_config: Optional[WebhookConfig] = None
diff --git a/src/utils/__init__.py b/src/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/custom_webhook.py b/src/utils/custom_webhook.py
similarity index 100%
rename from src/custom_webhook.py
rename to src/utils/custom_webhook.py
diff --git a/src/vanna/__init__.py b/src/vanna/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/vanna/schemas.py b/src/vanna/schemas.py
new file mode 100644
index 0000000..9b121b7
--- /dev/null
+++ b/src/vanna/schemas.py
@@ -0,0 +1,43 @@
+from enum import Enum
+from typing import Optional
+from pydantic import BaseModel
+
+
+class WarehouseType(str, Enum):
+    """
+    Warehouse types that the vanna model can work with
+    """
+
+    POSTGRES = "postgres"
+    BIGQUERY = "bigquery"
+
+
+class PgVectorCreds(BaseModel):
+    """Pgvector creds where the embeddings for the RAG will be stored"""
+
+    username: str
+    password: str
+    host: str
+    port: int
+    database: str
+
+
+class BaseVannaWarehouseConfig(BaseModel):
+    """Base config shared by all vanna requests that target a warehouse"""
+
+    pg_vector_creds: PgVectorCreds
+    warehouse_creds: dict
+    warehouse_type: WarehouseType
+
+
+class TrainVannaRequest(BaseVannaWarehouseConfig):
+    """Payload to train the vanna model against a warehouse"""
+
+    training_sql: str
+    reset: bool = True
+
+
+class AskVannaRequest(BaseVannaWarehouseConfig):
+    """Payload to ask vanna for sql corresponding to a user prompt"""
+
+    user_prompt: str
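To make the payload contract concrete, here is how the new schemas compose — a sketch with placeholder values; a BigQuery `warehouse_creds` dict would carry the full service-account JSON in practice, only the `project_id` key that `SqlGeneration` reads is shown.

```python
# Constructing a TrainVannaRequest by hand (placeholder values).
from src.vanna.schemas import PgVectorCreds, TrainVannaRequest, WarehouseType

req = TrainVannaRequest(
    pg_vector_creds=PgVectorCreds(
        username="postgres",
        password="secret",
        host="localhost",
        port=5432,
        database="vectors",
    ),
    warehouse_creds={"project_id": "my-gcp-project"},  # plus service-account fields
    warehouse_type=WarehouseType.BIGQUERY,
    training_sql="SELECT * FROM `region-us`.INFORMATION_SCHEMA.COLUMNS",
)  # reset defaults to True, wiping existing embeddings before retraining
```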
diff --git a/src/vanna/sql_generation.py b/src/vanna/sql_generation.py
new file mode 100644
index 0000000..8e52d8f
--- /dev/null
+++ b/src/vanna/sql_generation.py
@@ -0,0 +1,164 @@
+import os
+import tempfile
+import json
+from enum import Enum
+import logging
+from sqlalchemy import create_engine, text
+
+from openai import OpenAI
+from vanna.openai import OpenAI_Chat
+from vanna.pgvector import PG_VectorStore
+from langchain_openai import OpenAIEmbeddings
+
+from src.vanna.schemas import PgVectorCreds, WarehouseType
+
+
+logger = logging.getLogger()
+
+
+class CustomPG_VectorStore(PG_VectorStore):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def remove_all_training_data(self, **kwargs):
+        engine = create_engine(self.connection_string)
+
+        delete_statement = text(
+            """
+            DELETE FROM langchain_pg_embedding
+            """
+        )
+
+        with engine.connect() as connection:
+            with connection.begin() as transaction:
+                try:
+                    result = connection.execute(delete_statement)
+                    transaction.commit()
+                    return result.rowcount > 0
+                except Exception as e:
+                    logging.error(f"An error occurred: {e}")
+                    transaction.rollback()
+                    return False
+
+    def cnt_of_embeddings(self):
+        engine = create_engine(self.connection_string)
+
+        stmt = text(
+            """
+            SELECT count(*) FROM langchain_pg_embedding
+            """
+        )
+
+        with engine.connect() as connection:
+            with connection.begin() as transaction:
+                try:
+                    result = connection.execute(stmt)
+                    transaction.commit()
+                    return result.fetchone()[0]
+                except Exception as e:
+                    logging.error(f"An error occurred: {e}")
+                    transaction.rollback()
+                    return False
+
+
+class CustomVannaClient(CustomPG_VectorStore, OpenAI_Chat):
+    """
+    Vanna client with pgvector as its backend and openai as the service provider
+    All RAG related calls to the vanna model will be made via this client
+    """
+
+    def __init__(
+        self,
+        openai_api_key: str,
+        pg_vector_creds: PgVectorCreds,
+        openai_model: str = "gpt-4o-mini",
+        initial_prompt: str = None,
+    ):
+        CustomPG_VectorStore.__init__(
+            self,
+            config={
+                "connection_string": "postgresql+psycopg://{username}:{password}@{host}:{port}/{database}".format(
+                    **pg_vector_creds.model_dump()
+                ),
+                "embedding_function": OpenAIEmbeddings(),
+            },
+        )
+
+        OpenAI_Chat.__init__(
+            self,
+            config={
+                "api_key": openai_api_key,
+                "model": openai_model,
+                "initial_prompt": initial_prompt,
+            },
+        )
+
+
+class SqlGeneration:
+    def __init__(
+        self,
+        openai_api_key: str,
+        pg_vector_creds: PgVectorCreds,
+        warehouse_creds: dict,
+        warehouse_type: str,
+    ):
+        os.environ["OPENAI_API_KEY"] = openai_api_key
+
+        if warehouse_type == WarehouseType.POSTGRES:
+            required_creds = {
+                "host": warehouse_creds["host"],
+                "port": warehouse_creds["port"],
+                "dbname": warehouse_creds["database"],
+                "user": warehouse_creds["username"],
+                "password": warehouse_creds["password"],
+            }
+
+            self.vanna = CustomVannaClient(
+                openai_api_key=openai_api_key,
+                pg_vector_creds=pg_vector_creds,
+                initial_prompt="Please qualify all table names with their schema names in the generated SQL",
+            )
+            self.vanna.connect_to_postgres(**required_creds)
+        elif warehouse_type == WarehouseType.BIGQUERY:
+            cred_file_path = None
+            with tempfile.NamedTemporaryFile(
+                delete=False, mode="w", suffix=".json"
+            ) as temp_file:
+                json.dump(warehouse_creds, temp_file, indent=4)
+                cred_file_path = temp_file.name
+
+            self.vanna = CustomVannaClient(
+                openai_api_key=openai_api_key,
+                pg_vector_creds=pg_vector_creds,
+                initial_prompt="please include backticks for project names and table names if appropriate",
+            )
+            self.vanna.connect_to_bigquery(
+                project_id=warehouse_creds["project_id"],
+                cred_file_path=cred_file_path,
+            )
+        else:
+            raise ValueError("Invalid warehouse type")
+
+    def generate_sql(self, question: str, allow_llm_to_see_data=False):
+        return self.vanna.generate_sql(
+            question=question, allow_llm_to_see_data=allow_llm_to_see_data
+        )
+
+    def is_sql_valid(self, sql: str):
+        return self.vanna.is_sql_valid(sql=sql)
+
+    def run_sql(self, sql: str):
+        return self.vanna.run_sql(sql=sql)
+
+    def setup_training_plan_and_execute(self, training_sql: str):
+        df_information_schema = self.vanna.run_sql(training_sql)
+        plan = self.vanna.get_training_plan_generic(df_information_schema)
+        self.vanna.train(plan=plan)
+        return True
+
+    def remove_training_data(self):
+        self.vanna.remove_all_training_data()
+        return True
+
+    def is_trained(self) -> bool:
+        return self.vanna.cnt_of_embeddings() > 0
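End to end, the class composes as below — a minimal sketch against a Postgres warehouse, bypassing Celery; all credential values are placeholders.

```python
# Direct SqlGeneration usage (illustrative; placeholder credentials).
from src.vanna.schemas import PgVectorCreds
from src.vanna.sql_generation import SqlGeneration

client = SqlGeneration(
    openai_api_key="sk-placeholder",
    pg_vector_creds=PgVectorCreds(
        username="postgres",
        password="secret",
        host="localhost",
        port=5432,
        database="vectors",
    ),
    warehouse_creds={
        "host": "warehouse.internal",
        "port": 5432,
        "database": "analytics",
        "username": "reader",
        "password": "secret",
    },
    warehouse_type="postgres",  # WarehouseType is a str enum, so the raw value compares equal
)

if not client.is_trained():
    client.setup_training_plan_and_execute(
        "SELECT * FROM information_schema.columns WHERE table_schema = 'public'"
    )

sql = client.generate_sql(question="How many orders were placed last month?")
print(sql if client.is_sql_valid(sql) else "model did not return runnable SQL")
```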