From 99a179c75326992758b94ff35ce01f497fc85cc4 Mon Sep 17 00:00:00 2001
From: kaarthik108 <kaarthikandavar@gmail.com>
Date: Sun, 13 Oct 2024 22:15:33 +1300
Subject: [PATCH 1/9] Use agents

---
 .github/workflows/lint.yml |  27 +++-
 Makefile                   |  57 +++++++
 agent.py                   | 102 ++++++++++++
 chain.py                   | 309 ++++++++++++++++++-------------------
 ingest.py                  |   1 +
 main.py                    |  77 +++++----
 template.py                |   1 -
 tools.py                   |  28 ++++
 utils/snow_connect.py      |   1 -
 utils/snowchat_ui.py       |  91 +++++------
 10 files changed, 452 insertions(+), 242 deletions(-)
 create mode 100644 Makefile
 create mode 100644 agent.py
 create mode 100644 tools.py

diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index 97e02502..6c7256e3 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -4,26 +4,39 @@ on:
   push:
     branches:
       - main
-      - prod
   pull_request:
     branches:
       - main
-      - prod
 
 jobs:
   lint:
+    name: Lint and Format Code
     runs-on: ubuntu-latest
+
     steps:
       - name: Check out repository
         uses: actions/checkout@v3
 
       - name: Set up Python
-        uses: actions/setup-python@v2
+        uses: actions/setup-python@v4
+        with:
+          python-version: "3.9"
+
+      - name: Cache pip dependencies
+        uses: actions/cache@v3
         with:
-          python-version: 3.9
+          path: ~/.cache/pip
+          key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
+          restore-keys: |
+            ${{ runner.os }}-pip-
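+          # restore-keys falls back to the newest cache whose key matches this
+          # prefix when the requirements hash has no exact match.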
 
       - name: Install dependencies
-        run: pip install black
+        run: |
+          python -m pip install --upgrade pip
+          pip install -r requirements.txt
+          pip install black ruff mypy codespell
 
-      - name: Lint with black
-        run: black --check .
+      - name: Run Formatting and Linting
+        run: |
+          make format
+          make lint
diff --git a/Makefile b/Makefile
new file mode 100644
index 00000000..a768e53a
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,57 @@
+.PHONY: all format lint lint_diff format_diff lint_package lint_tests spell_check spell_fix help lint-fix
+
+# Define a variable for Python and notebook files.
+PYTHON_FILES=src/
+MYPY_CACHE=.mypy_cache
+
+######################
+# LINTING AND FORMATTING
+######################
+
+lint format: PYTHON_FILES=.
+lint_diff format_diff: PYTHON_FILES=$(shell git diff --name-only --diff-filter=d main | grep -E '\.py$$|\.ipynb$$')
+lint_package: PYTHON_FILES=src
+lint_tests: PYTHON_FILES=tests
+lint_tests: MYPY_CACHE=.mypy_cache_test
+
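+# Each linting step below is guarded so it is skipped when PYTHON_FILES
+# expands to nothing (e.g. lint_diff with no changed Python files).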
+lint lint_diff lint_package lint_tests:
+	[ "$(PYTHON_FILES)" = "" ] || python -m ruff check $(PYTHON_FILES)
+	[ "$(PYTHON_FILES)" = "" ] || python -m ruff format $(PYTHON_FILES) --diff
+	[ "$(PYTHON_FILES)" = "" ] || python -m ruff check --select I,F401 --fix $(PYTHON_FILES)
+	[ "$(PYTHON_FILES)" = "" ] || { mkdir -p $(MYPY_CACHE) && python -m mypy --strict $(PYTHON_FILES) --cache-dir $(MYPY_CACHE); }
+
+format format_diff:
+	ruff format $(PYTHON_FILES)
+	ruff check --fix $(PYTHON_FILES)
+
+spell_check:
+	codespell --toml pyproject.toml
+
+spell_fix:
+	codespell --toml pyproject.toml -w
+
+######################
+# RUN ALL
+######################
+
+all: format lint spell_check
+
+######################
+# HELP
+######################
+
+help:
+	@echo '----'
+	@echo 'format                       - run code formatters'
+	@echo 'lint                         - run linters'
+	@echo 'spell_check                  - run spell check'
+	@echo 'all                          - run all tasks'
+	@echo 'lint-fix                     - run lint and fix issues'
+
+######################
+# LINT-FIX TARGET
+######################
+
+lint-fix: format lint
+	@echo "Linting and fixing completed successfully."
diff --git a/agent.py b/agent.py
new file mode 100644
index 00000000..a7a7462b
--- /dev/null
+++ b/agent.py
@@ -0,0 +1,102 @@
+import os
+from dataclasses import dataclass
+from typing import Annotated, Sequence, Optional
+
+from langchain.callbacks.base import BaseCallbackHandler
+from langchain_anthropic import ChatAnthropic
+from langchain_core.messages import SystemMessage
+from langchain_openai import ChatOpenAI
+from langgraph.checkpoint.memory import MemorySaver
+from langgraph.graph import START, StateGraph
+from langgraph.prebuilt import ToolNode, tools_condition
+from langgraph.graph.message import add_messages
+from langchain_core.messages import BaseMessage
+
+from template import TEMPLATE
+from tools import retriever_tool
+
+
+@dataclass
+class MessagesState:
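+    # add_messages is a reducer: messages returned by a node are appended to
+    # this list rather than replacing it.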
+    messages: Annotated[Sequence[BaseMessage], add_messages]
+
+
+memory = MemorySaver()
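+# In-process checkpointer: graph state persists across Streamlit reruns but
+# is lost when the process restarts.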
+
+
+@dataclass
+class ModelConfig:
+    model_name: str
+    api_key: str
+    base_url: Optional[str] = None
+
+
+def create_agent(callback_handler: BaseCallbackHandler, model_name: str):
+    model_configurations = {
+        "gpt-4o-mini": ModelConfig(
+            model_name="gpt-4o-mini", api_key=os.getenv("OPENAI_API_KEY")
+        ),
+        "gemma2-9b": ModelConfig(
+            model_name="gemma2-9b-it",
+            api_key=os.getenv("GROQ_API_KEY"),
+            base_url="https://api.groq.com/openai/v1",
+        ),
+        "claude3-haiku": ModelConfig(
+            model_name="claude-3-haiku-20240307", api_key=os.getenv("ANTHROPIC_API_KEY")
+        ),
+        "mixtral-8x22b": ModelConfig(
+            model_name="accounts/fireworks/models/mixtral-8x22b-instruct",
+            api_key=os.getenv("FIREWORKS_API_KEY"),
+            base_url="https://api.fireworks.ai/inference/v1",
+        ),
+        "llama-3.1-405b": ModelConfig(
+            model_name="accounts/fireworks/models/llama-v3p1-405b-instruct",
+            api_key=os.getenv("FIREWORKS_API_KEY"),
+            base_url="https://api.fireworks.ai/inference/v1",
+        ),
+    }
+    config = model_configurations.get(model_name)
+    if not config:
+        raise ValueError(f"Unsupported model name: {model_name}")
+
+    sys_msg = SystemMessage(
+        content="""You're an AI assistant specializing in data analysis with Snowflake SQL. When providing responses, strive to exhibit friendliness and adopt a conversational tone, similar to how a friend or tutor would communicate.
+        Call the tool "Database_Schema" to search for database schema details when needed to generate the SQL code.
+        """
+    )
+
+    llm = (
+        ChatOpenAI(
+            model=config.model_name,
+            api_key=config.api_key,
+            callbacks=[callback_handler],
+            streaming=True,
+            base_url=config.base_url,
+        )
+        if config.model_name != "claude-3-haiku-20240307"
+        else ChatAnthropic(
+            model=config.model_name,
+            api_key=config.api_key,
+            callbacks=[callback_handler],
+            streaming=True,
+        )
+    )
+
+    tools = [retriever_tool]
+
+    llm_with_tools = llm.bind_tools(tools)
+
+    def reasoner(state: MessagesState):
+        return {"messages": [llm_with_tools.invoke([sys_msg] + state.messages)]}
+
+    # Build the graph
+    builder = StateGraph(MessagesState)
+    builder.add_node("reasoner", reasoner)
+    builder.add_node("tools", ToolNode(tools))
+    builder.add_edge(START, "reasoner")
+    builder.add_conditional_edges("reasoner", tools_condition)
+    builder.add_edge("tools", "reasoner")
+
+    react_graph = builder.compile(checkpointer=memory)
+
+    return react_graph
diff --git a/chain.py b/chain.py
index a16b8627..a0c25952 100644
--- a/chain.py
+++ b/chain.py
@@ -1,155 +1,154 @@
-from typing import Any, Callable, Dict, Optional
-
-import streamlit as st
-from langchain_community.chat_models import ChatOpenAI
-from langchain.embeddings.openai import OpenAIEmbeddings
-from langchain.llms import OpenAI
-from langchain.vectorstores import SupabaseVectorStore
-from pydantic import BaseModel, validator
-from supabase.client import Client, create_client
-
-from template import CONDENSE_QUESTION_PROMPT, QA_PROMPT
-
-from operator import itemgetter
-
-from langchain.prompts.prompt import PromptTemplate
-from langchain.schema import format_document
-from langchain_core.messages import get_buffer_string
-from langchain_core.output_parsers import StrOutputParser
-from langchain_core.runnables import RunnableParallel, RunnablePassthrough
-from langchain_openai import ChatOpenAI, OpenAIEmbeddings
-from langchain_anthropic import ChatAnthropic
-
-DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
-
-supabase_url = st.secrets["SUPABASE_URL"]
-supabase_key = st.secrets["SUPABASE_SERVICE_KEY"]
-supabase: Client = create_client(supabase_url, supabase_key)
-
-
-class ModelConfig(BaseModel):
-    model_type: str
-    secrets: Dict[str, Any]
-    callback_handler: Optional[Callable] = None
-
-
-class ModelWrapper:
-    def __init__(self, config: ModelConfig):
-        self.model_type = config.model_type
-        self.secrets = config.secrets
-        self.callback_handler = config.callback_handler
-        self.llm = self._setup_llm()
-
-    def _setup_llm(self):
-        model_config = {
-            "gpt-4o-mini": {
-                "model_name": "gpt-4o-mini",
-                "api_key": self.secrets["OPENAI_API_KEY"],
-            },
-            "gemma2-9b": {
-                "model_name": "gemma2-9b-it",
-                "api_key": self.secrets["GROQ_API_KEY"],
-                "base_url": "https://api.groq.com/openai/v1",
-            },
-            "claude3-haiku": {
-                "model_name": "claude-3-haiku-20240307",
-                "api_key": self.secrets["ANTHROPIC_API_KEY"],
-            },
-            "mixtral-8x22b": {
-                "model_name": "accounts/fireworks/models/mixtral-8x22b-instruct",
-                "api_key": self.secrets["FIREWORKS_API_KEY"],
-                "base_url": "https://api.fireworks.ai/inference/v1",
-            },
-            "llama-3.1-405b": {
-                "model_name": "accounts/fireworks/models/llama-v3p1-405b-instruct",
-                "api_key": self.secrets["FIREWORKS_API_KEY"],
-                "base_url": "https://api.fireworks.ai/inference/v1",
-            },
-        }
-
-        config = model_config[self.model_type]
-
-        return (
-            ChatOpenAI(
-                model_name=config["model_name"],
-                temperature=0.1,
-                api_key=config["api_key"],
-                max_tokens=700,
-                callbacks=[self.callback_handler],
-                streaming=True,
-                base_url=config["base_url"]
-                if config["model_name"] != "gpt-4o-mini"
-                else None,
-                default_headers={
-                    "HTTP-Referer": "https://snowchat.streamlit.app/",
-                    "X-Title": "Snowchat",
-                },
-            )
-            if config["model_name"] != "claude-3-haiku-20240307"
-            else (
-                ChatAnthropic(
-                    model=config["model_name"],
-                    temperature=0.1,
-                    max_tokens=700,
-                    timeout=None,
-                    max_retries=2,
-                    callbacks=[self.callback_handler],
-                    streaming=True,
-                )
-            )
-        )
-
-    def get_chain(self, vectorstore):
-        def _combine_documents(
-            docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"
-        ):
-            doc_strings = [format_document(doc, document_prompt) for doc in docs]
-            return document_separator.join(doc_strings)
-
-        _inputs = RunnableParallel(
-            standalone_question=RunnablePassthrough.assign(
-                chat_history=lambda x: get_buffer_string(x["chat_history"])
-            )
-            | CONDENSE_QUESTION_PROMPT
-            | OpenAI()
-            | StrOutputParser(),
-        )
-        _context = {
-            "context": itemgetter("standalone_question")
-            | vectorstore.as_retriever()
-            | _combine_documents,
-            "question": lambda x: x["standalone_question"],
-        }
-        conversational_qa_chain = _inputs | _context | QA_PROMPT | self.llm
-
-        return conversational_qa_chain
-
-
-def load_chain(model_name="qwen", callback_handler=None):
-    embeddings = OpenAIEmbeddings(
-        openai_api_key=st.secrets["OPENAI_API_KEY"], model="text-embedding-ada-002"
-    )
-    vectorstore = SupabaseVectorStore(
-        embedding=embeddings,
-        client=supabase,
-        table_name="documents",
-        query_name="v_match_documents",
-    )
-
-    model_type_mapping = {
-        "gpt-4o-mini": "gpt-4o-mini",
-        "gemma2-9b": "gemma2-9b",
-        "claude3-haiku": "claude3-haiku",
-        "mixtral-8x22b": "mixtral-8x22b",
-        "llama-3.1-405b": "llama-3.1-405b",
-    }
-
-    model_type = model_type_mapping.get(model_name.lower())
-    if model_type is None:
-        raise ValueError(f"Unsupported model name: {model_name}")
-
-    config = ModelConfig(
-        model_type=model_type, secrets=st.secrets, callback_handler=callback_handler
-    )
-    model = ModelWrapper(config)
-    return model.get_chain(vectorstore)
+# from dataclasses import dataclass, field
+# from operator import itemgetter
+# from typing import Any, Callable, Dict, Optional
+
+# import streamlit as st
+# from langchain.embeddings.openai import OpenAIEmbeddings
+# from langchain.llms import OpenAI
+# from langchain.prompts.prompt import PromptTemplate
+# from langchain.schema import format_document
+# from langchain.vectorstores import SupabaseVectorStore
+# from langchain_anthropic import ChatAnthropic
+# from langchain_community.chat_models import ChatOpenAI
+# from langchain_core.messages import get_buffer_string
+# from langchain_core.output_parsers import StrOutputParser
+# from langchain_core.runnables import RunnableParallel, RunnablePassthrough
+# from langchain_openai import ChatOpenAI, OpenAIEmbeddings
+
+# from supabase.client import Client, create_client
+# from template import CONDENSE_QUESTION_PROMPT, QA_PROMPT
+
+# DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
+
+# supabase_url = st.secrets["SUPABASE_URL"]
+# supabase_key = st.secrets["SUPABASE_SERVICE_KEY"]
+# supabase: Client = create_client(supabase_url, supabase_key)
+
+
+# @dataclass
+# class ModelConfig:
+#     model_type: str
+#     secrets: Dict[str, Any]
+#     callback_handler: Optional[Callable] = field(default=None)
+
+
+# class ModelWrapper:
+#     def __init__(self, config: ModelConfig):
+#         self.model_type = config.model_type
+#         self.secrets = config.secrets
+#         self.callback_handler = config.callback_handler
+#         self.llm = self._setup_llm()
+
+#     def _setup_llm(self):
+#         model_config = {
+#             "gpt-4o-mini": {
+#                 "model_name": "gpt-4o-mini",
+#                 "api_key": self.secrets["OPENAI_API_KEY"],
+#             },
+#             "gemma2-9b": {
+#                 "model_name": "gemma2-9b-it",
+#                 "api_key": self.secrets["GROQ_API_KEY"],
+#                 "base_url": "https://api.groq.com/openai/v1",
+#             },
+#             "claude3-haiku": {
+#                 "model_name": "claude-3-haiku-20240307",
+#                 "api_key": self.secrets["ANTHROPIC_API_KEY"],
+#             },
+#             "mixtral-8x22b": {
+#                 "model_name": "accounts/fireworks/models/mixtral-8x22b-instruct",
+#                 "api_key": self.secrets["FIREWORKS_API_KEY"],
+#                 "base_url": "https://api.fireworks.ai/inference/v1",
+#             },
+#             "llama-3.1-405b": {
+#                 "model_name": "accounts/fireworks/models/llama-v3p1-405b-instruct",
+#                 "api_key": self.secrets["FIREWORKS_API_KEY"],
+#                 "base_url": "https://api.fireworks.ai/inference/v1",
+#             },
+#         }
+
+#         config = model_config[self.model_type]
+
+#         return (
+#             ChatOpenAI(
+#                 model_name=config["model_name"],
+#                 temperature=0.1,
+#                 api_key=config["api_key"],
+#                 max_tokens=700,
+#                 callbacks=[self.callback_handler],
+#                 streaming=True,
+#                 base_url=config["base_url"]
+#                 if config["model_name"] != "gpt-4o-mini"
+#                 else None,
+#                 default_headers={
+#                     "HTTP-Referer": "https://snowchat.streamlit.app/",
+#                     "X-Title": "Snowchat",
+#                 },
+#             )
+#             if config["model_name"] != "claude-3-haiku-20240307"
+#             else (
+#                 ChatAnthropic(
+#                     model=config["model_name"],
+#                     temperature=0.1,
+#                     max_tokens=700,
+#                     timeout=None,
+#                     max_retries=2,
+#                     callbacks=[self.callback_handler],
+#                     streaming=True,
+#                 )
+#             )
+#         )
+
+#     def get_chain(self, vectorstore):
+#         def _combine_documents(
+#             docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"
+#         ):
+#             doc_strings = [format_document(doc, document_prompt) for doc in docs]
+#             return document_separator.join(doc_strings)
+
+#         _inputs = RunnableParallel(
+#             standalone_question=RunnablePassthrough.assign(
+#                 chat_history=lambda x: get_buffer_string(x["chat_history"])
+#             )
+#             | CONDENSE_QUESTION_PROMPT
+#             | OpenAI()
+#             | StrOutputParser(),
+#         )
+#         _context = {
+#             "context": itemgetter("standalone_question")
+#             | vectorstore.as_retriever()
+#             | _combine_documents,
+#             "question": lambda x: x["standalone_question"],
+#         }
+#         conversational_qa_chain = _inputs | _context | QA_PROMPT | self.llm
+
+#         return conversational_qa_chain
+
+
+# def load_chain(model_name="qwen", callback_handler=None):
+#     embeddings = OpenAIEmbeddings(
+#         openai_api_key=st.secrets["OPENAI_API_KEY"], model="text-embedding-ada-002"
+#     )
+#     vectorstore = SupabaseVectorStore(
+#         embedding=embeddings,
+#         client=supabase,
+#         table_name="documents",
+#         query_name="v_match_documents",
+#     )
+
+#     model_type_mapping = {
+#         "gpt-4o-mini": "gpt-4o-mini",
+#         "gemma2-9b": "gemma2-9b",
+#         "claude3-haiku": "claude3-haiku",
+#         "mixtral-8x22b": "mixtral-8x22b",
+#         "llama-3.1-405b": "llama-3.1-405b",
+#     }
+
+#     model_type = model_type_mapping.get(model_name.lower())
+#     if model_type is None:
+#         raise ValueError(f"Unsupported model name: {model_name}")
+
+#     config = ModelConfig(
+#         model_type=model_type, secrets=st.secrets, callback_handler=callback_handler
+#     )
+#     model = ModelWrapper(config)
+#     return model.get_chain(vectorstore)
diff --git a/ingest.py b/ingest.py
index 67de0f33..c6669f35 100644
--- a/ingest.py
+++ b/ingest.py
@@ -6,6 +6,7 @@
 from langchain.text_splitter import CharacterTextSplitter
 from langchain.vectorstores import SupabaseVectorStore
 from pydantic import BaseModel
+
 from supabase.client import Client, create_client
 
 
diff --git a/main.py b/main.py
index 1a491be1..97a6ddce 100644
--- a/main.py
+++ b/main.py
@@ -2,9 +2,10 @@
 import warnings
 
 import streamlit as st
+from langchain_core.messages import HumanMessage
 from snowflake.snowpark.exceptions import SnowparkSQLException
 
-from chain import load_chain
+from agent import MessagesState, create_agent
 
 # from utils.snow_connect import SnowflakeConnection
 from utils.snowchat_ui import StreamlitUICallbackHandler, message_func
@@ -50,6 +51,9 @@
 )
 st.session_state["model"] = model
 
+if "assistant_response_processed" not in st.session_state:
+    st.session_state["assistant_response_processed"] = True  # Initialize to True
+
 if "toast_shown" not in st.session_state:
     st.session_state["toast_shown"] = False
 
@@ -76,6 +80,7 @@
         "content": "Hey there, I'm Chatty McQueryFace, your SQL-speaking sidekick, ready to chat up Snowflake and fetch answers faster than a snowball fight in summer! ❄️🔍",
     },
 ]
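+# MemorySaver checkpoints are keyed by thread_id, so a fixed id resumes the
+# same conversation history on every invocation; a per-session id (e.g. a
+# uuid stored in st.session_state) would isolate users.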
+config = {"configurable": {"thread_id": "42"}}
 
 with open("ui/sidebar.md", "r") as sidebar_file:
     sidebar_content = sidebar_file.read()
@@ -118,18 +123,28 @@
 # Prompt for user input and save
 if prompt := st.chat_input():
     st.session_state.messages.append({"role": "user", "content": prompt})
+    st.session_state["assistant_response_processed"] = (
+        False  # Assistant response not yet processed
+    )
 
-for message in st.session_state.messages:
+messages_to_display = st.session_state.messages.copy()
+# if not st.session_state["assistant_response_processed"]:
+#     # Exclude the last assistant message if assistant response not yet processed
+#     if messages_to_display and messages_to_display[-1]["role"] == "assistant":
+#         print("\n\nthis is messages_to_display \n\n", messages_to_display)
+#         messages_to_display = messages_to_display[:-1]
+
+for message in messages_to_display:
     message_func(
         message["content"],
-        True if message["role"] == "user" else False,
-        True if message["role"] == "data" else False,
-        model,
+        is_user=(message["role"] == "user"),
+        is_df=(message["role"] == "data"),
+        model=model,
     )
 
 callback_handler = StreamlitUICallbackHandler(model)
 
-chain = load_chain(st.session_state["model"], callback_handler)
+react_graph = create_agent(callback_handler, st.session_state["model"])
 
 
 def append_chat_history(question, answer):
@@ -148,20 +163,21 @@ def append_message(content, role="assistant"):
 
 
 def handle_sql_exception(query, conn, e, retries=2):
-    append_message("Uh oh, I made an error, let me try to fix it..")
-    error_message = (
-        "You gave me a wrong SQL. FIX The SQL query by searching the schema definition:  \n```sql\n"
-        + query
-        + "\n```\n Error message: \n "
-        + str(e)
-    )
-    new_query = chain({"question": error_message, "chat_history": ""})["answer"]
-    append_message(new_query)
-    if get_sql(new_query) and retries > 0:
-        return execute_sql(get_sql(new_query), conn, retries - 1)
-    else:
-        append_message("I'm sorry, I couldn't fix the error. Please try again.")
-        return None
+    # append_message("Uh oh, I made an error, let me try to fix it..")
+    # error_message = (
+    #     "You gave me a wrong SQL. FIX The SQL query by searching the schema definition:  \n```sql\n"
+    #     + query
+    #     + "\n```\n Error message: \n "
+    #     + str(e)
+    # )
+    # new_query = chain({"question": error_message, "chat_history": ""})["answer"]
+    # append_message(new_query)
+    # if get_sql(new_query) and retries > 0:
+    #     return execute_sql(get_sql(new_query), conn, retries - 1)
+    # else:
+    #     append_message("I'm sorry, I couldn't fix the error. Please try again.")
+    #     return None
+    pass
 
 
 def execute_sql(query, conn, retries=2):
@@ -176,20 +192,25 @@ def execute_sql(query, conn, retries=2):
 
 if (
     "messages" in st.session_state
-    and st.session_state["messages"][-1]["role"] != "assistant"
+    and st.session_state["messages"][-1]["role"] == "user"
+    and not st.session_state["assistant_response_processed"]
 ):
     user_input_content = st.session_state["messages"][-1]["content"]
 
     if isinstance(user_input_content, str):
+        # Start loading animation
         callback_handler.start_loading_message()
 
-        result = chain.invoke(
-            {
-                "question": user_input_content,
-                "chat_history": [h for h in st.session_state["history"]],
-            }
-        )
-        append_message(result.content)
+        messages = [HumanMessage(content=user_input_content)]
+
+        state = MessagesState(messages=messages)
+        result = react_graph.invoke(state, config=config)
+
+        if result["messages"]:
+            assistant_message = callback_handler.final_message
+            append_message(assistant_message)
+            st.session_state["assistant_response_processed"] = True
+
 
 if (
     st.session_state["model"] == "Mixtral 8x7B"
diff --git a/template.py b/template.py
index c8cd086c..5cc1759a 100644
--- a/template.py
+++ b/template.py
@@ -1,4 +1,3 @@
-from langchain.prompts.prompt import PromptTemplate
 from langchain_core.prompts import ChatPromptTemplate
 
 template = """You are an AI chatbot having a conversation with a human.
diff --git a/tools.py b/tools.py
new file mode 100644
index 00000000..5b5a4504
--- /dev/null
+++ b/tools.py
@@ -0,0 +1,28 @@
+import streamlit as st
+from langchain.prompts.prompt import PromptTemplate
+from supabase.client import Client, create_client
+from langchain.tools.retriever import create_retriever_tool
+from langchain_openai import OpenAIEmbeddings
+from langchain_community.vectorstores import SupabaseVectorStore
+
+supabase_url = st.secrets["SUPABASE_URL"]
+supabase_key = st.secrets["SUPABASE_SERVICE_KEY"]
+supabase: Client = create_client(supabase_url, supabase_key)
+
+
+embeddings = OpenAIEmbeddings(
+    openai_api_key=st.secrets["OPENAI_API_KEY"], model="text-embedding-ada-002"
+)
+vectorstore = SupabaseVectorStore(
+    embedding=embeddings,
+    client=supabase,
+    table_name="documents",
+    query_name="v_match_documents",
+)
+
+
+retriever_tool = create_retriever_tool(
+    vectorstore.as_retriever(),
+    name="Database_Schema",
+    description="Search for database schema details",
+)
diff --git a/utils/snow_connect.py b/utils/snow_connect.py
index 2268c8bd..d0b396a6 100644
--- a/utils/snow_connect.py
+++ b/utils/snow_connect.py
@@ -2,7 +2,6 @@
 
 import streamlit as st
 from snowflake.snowpark.session import Session
-from snowflake.snowpark.version import VERSION
 
 
 class SnowflakeConnection:
diff --git a/utils/snowchat_ui.py b/utils/snowchat_ui.py
index 03f5f58f..98a63370 100644
--- a/utils/snowchat_ui.py
+++ b/utils/snowchat_ui.py
@@ -1,10 +1,10 @@
 import html
 import re
 
 import streamlit as st
 from langchain.callbacks.base import BaseCallbackHandler
 
-
 image_url = f"{st.secrets['SUPABASE_STORAGE_URL']}/storage/v1/object/public/snowchat/"
 gemini_url = image_url + "google-gemini-icon.png?t=2024-05-07T21%3A17%3A52.235Z"
 mistral_url = (
@@ -61,7 +61,7 @@ def format_message(text):
 
 def message_func(text, is_user=False, is_df=False, model="gpt"):
     """
-    This function is used to display the messages in the chatbot UI.
+    This function displays messages in the chatbot UI, ensuring proper alignment and avatar positioning.
 
     Parameters:
     text (str): The text to be displayed.
@@ -69,52 +69,36 @@ def message_func(text, is_user=False, is_df=False, model="gpt"):
     is_df (bool): Whether the message is a dataframe or not.
     """
     model_url = get_model_url(model)
+    avatar_url = user_url if is_user else model_url
+    message_bg_color = (
+        "linear-gradient(135deg, #00B2FF 0%, #006AFF 100%)" if is_user else "#71797E"
+    )
+    avatar_class = "user-avatar" if is_user else "bot-avatar"
+    alignment = "flex-end" if is_user else "flex-start"
+    margin_side = "margin-left" if is_user else "margin-right"
+    message_text = html.escape(text.strip()).replace('\n', '<br>')
 
-    avatar_url = model_url
     if is_user:
-        avatar_url = user_url
-        message_alignment = "flex-end"
-        message_bg_color = "linear-gradient(135deg, #00B2FF 0%, #006AFF 100%)"
-        avatar_class = "user-avatar"
-        st.write(
-            f"""
-                <div style="display: flex; align-items: center; margin-bottom: 10px; justify-content: {message_alignment};">
-                    <div style="background: {message_bg_color}; color: white; border-radius: 20px; padding: 10px; margin-right: 5px; max-width: 75%; font-size: 14px;">
-                        {text} \n </div>
-                    <img src="{avatar_url}" class="{avatar_class}" alt="avatar" style="width: 40px; height: 40px;" />
-                </div>
-                """,
-            unsafe_allow_html=True,
-        )
+        container_html = f"""
+        <div style="display:flex; align-items:flex-start; justify-content:flex-end; margin:0; padding:0; margin-bottom:10px;">
+            <div style="background:{message_bg_color}; color:white; border-radius:20px; padding:10px; margin-right:5px; max-width:75%; font-size:14px; margin:0; line-height:1.2; word-wrap:break-word;">
+                {message_text}
+            </div>
+            <img src="{avatar_url}" class="{avatar_class}" alt="avatar" style="width:40px; height:40px; margin:0;" />
+        </div>
+        """
     else:
-        message_alignment = "flex-start"
-        message_bg_color = "#71797E"
-        avatar_class = "bot-avatar"
+        container_html = f"""
+        <div style="display:flex; align-items:flex-start; justify-content:flex-start; margin:0; padding:0; margin-bottom:10px;">
+            <img src="{avatar_url}" class="{avatar_class}" alt="avatar" style="width:30px; height:30px; margin:0; margin-right:5px; margin-top:5px;" />
+            <div style="background:{message_bg_color}; color:white; border-radius:20px; padding:10px; margin-left:5px; max-width:75%; font-size:14px; margin:0; line-height:1.2; word-wrap:break-word;">
+                {message_text}
+            </div>
+        </div>
+        """
 
-        if is_df:
-            st.write(
-                f"""
-                    <div style="display: flex; align-items: center; margin-bottom: 10px; justify-content: {message_alignment};">
-                        <img src="{model_url}" class="{avatar_class}" alt="avatar" style="width: 50px; height: 50px;" />
-                    </div>
-                    """,
-                unsafe_allow_html=True,
-            )
-            st.write(text)
-            return
-        else:
-            text = format_message(text)
+    st.write(container_html, unsafe_allow_html=True)
 
-        st.write(
-            f"""
-                <div style="display: flex; align-items: center; margin-bottom: 10px; justify-content: {message_alignment};">
-                    <img src="{avatar_url}" class="{avatar_class}" alt="avatar" style="width: 30px; height: 30px;" />
-                    <div style="background: {message_bg_color}; color: white; border-radius: 20px; padding: 10px; margin-right: 5px; margin-left: 5px; max-width: 75%; font-size: 14px;">
-                        {text} \n </div>
-                </div>
-                """,
-            unsafe_allow_html=True,
-        )
 
 
 class StreamlitUICallbackHandler(BaseCallbackHandler):
@@ -125,6 +109,7 @@ def __init__(self, model):
         self.has_streaming_started = False
         self.model = model
         self.avatar_url = get_model_url(model)
+        self.final_message = ""
 
     def start_loading_message(self):
         loading_message_content = self._get_bot_message_container("Thinking...")
@@ -138,6 +123,7 @@ def on_llm_new_token(self, token, run_id, parent_run_id=None, **kwargs):
         complete_message = "".join(self.token_buffer)
         container_content = self._get_bot_message_container(complete_message)
         self.placeholder.markdown(container_content, unsafe_allow_html=True)
+        self.final_message = "".join(self.token_buffer)
 
     def on_llm_end(self, response, run_id, parent_run_id=None, **kwargs):
         self.token_buffer = []
@@ -146,16 +132,20 @@ def on_llm_end(self, response, run_id, parent_run_id=None, **kwargs):
 
     def _get_bot_message_container(self, text):
         """Generate the bot's message container style for the given text."""
-        formatted_text = format_message(text)
+        formatted_text = format_message(text.strip())
         container_content = f"""
-            <div style="display: flex; align-items: center; margin-bottom: 10px; justify-content: flex-start;">
-                <img src="{self.avatar_url}" class="bot-avatar" alt="avatar" style="width: 30px; height: 30px;" />
-                <div style="background: #71797E; color: white; border-radius: 20px; padding: 10px; margin-right: 5px; margin-left: 5px; max-width: 75%; font-size: 14px;">
-                    {formatted_text} \n </div>
+        <div style="display:flex; align-items:flex-start; justify-content:flex-start; margin:0; padding:0;">
+            <img src="{self.avatar_url}" class="bot-avatar" alt="avatar" style="width:30px; height:30px; margin:0;" />
+            <div style="background:#71797E; color:white; border-radius:20px; padding:10px; margin-left:5px; max-width:75%; font-size:14px; line-height:1.2; word-wrap:break-word;">
+                {formatted_text}
             </div>
+        </div>
         """
         return container_content
 
+
+
+
     def display_dataframe(self, df):
         """
         Display the dataframe in Streamlit UI within the chat container.
@@ -165,13 +155,14 @@ def display_dataframe(self, df):
 
         st.write(
             f"""
-            <div style="display: flex; align-items: center; margin-bottom: 10px; justify-content: {message_alignment};">
-                <img src="{self.avatar_url}" class="{avatar_class}" alt="avatar" style="width: 30px; height: 30px;" />
+            <div style="display: flex; align-items: flex-start; margin-bottom: 10px; justify-content: {message_alignment};">
+                <img src="{self.avatar_url}" class="{avatar_class}" alt="avatar" style="width: 30px; height: 30px; margin-top: 0;" />
             </div>
             """,
             unsafe_allow_html=True,
         )
         st.write(df)
 
+
     def __call__(self, *args, **kwargs):
         pass

From 28cbf6af920ecff52e2730d0214d212d431bc54d Mon Sep 17 00:00:00 2001
From: kaarthik108 <kaarthikandavar@gmail.com>
Date: Sun, 13 Oct 2024 22:45:52 +1300
Subject: [PATCH 2/9] Use Llama 3.2

---
 agent.py | 8 ++++----
 main.py  | 4 ++--
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/agent.py b/agent.py
index a7a7462b..a13e8bd3 100644
--- a/agent.py
+++ b/agent.py
@@ -33,8 +33,8 @@ class ModelConfig:
 
 def create_agent(callback_handler: BaseCallbackHandler, model_name: str):
     model_configurations = {
-        "gpt-4o-mini": ModelConfig(
-            model_name="gpt-4o-mini", api_key=os.getenv("OPENAI_API_KEY")
+        "gpt-4o": ModelConfig(
+            model_name="gpt-4o", api_key=os.getenv("OPENAI_API_KEY")
         ),
         "gemma2-9b": ModelConfig(
             model_name="gemma2-9b-it",
@@ -44,8 +44,8 @@ def create_agent(callback_handler: BaseCallbackHandler, model_name: str):
         "claude3-haiku": ModelConfig(
             model_name="claude-3-haiku-20240307", api_key=os.getenv("ANTHROPIC_API_KEY")
         ),
-        "mixtral-8x22b": ModelConfig(
-            model_name="accounts/fireworks/models/mixtral-8x22b-instruct",
+        "llama-3.2-3b": ModelConfig(
+            model_name="accounts/fireworks/models/llama-v3p2-3b-instruct",
             api_key=os.getenv("FIREWORKS_API_KEY"),
             base_url="https://api.fireworks.ai/inference/v1",
         ),
diff --git a/main.py b/main.py
index 97a6ddce..fe540794 100644
--- a/main.py
+++ b/main.py
@@ -35,11 +35,11 @@
 st.caption("Talk your way through data")
 
 model_options = {
-    "gpt-4o-mini": "GPT-4o Mini",
+    "gpt-4o": "GPT-4o",
     "llama-3.1-405b": "Llama 3.1 405B",
     "gemma2-9b": "Gemma 2 9B",
     "claude3-haiku": "Claude 3 Haiku",
-    "mixtral-8x22b": "Mixtral 8x22B",
+    "llama-3.2-3b": "Llama 3.2 3B",
 }
 
 model = st.radio(

From 77fd93e907769f8297038ad71befb8472d5b164d Mon Sep 17 00:00:00 2001
From: kaarthik108 <kaarthikandavar@gmail.com>
Date: Thu, 17 Oct 2024 20:34:43 +1300
Subject: [PATCH 3/9] Fix loading UI

---
 utils/snowchat_ui.py | 41 ++++++++++++++++++++---------------------
 1 file changed, 20 insertions(+), 21 deletions(-)

diff --git a/utils/snowchat_ui.py b/utils/snowchat_ui.py
index 98a63370..83399cc4 100644
--- a/utils/snowchat_ui.py
+++ b/utils/snowchat_ui.py
@@ -78,26 +78,26 @@ def message_func(text, is_user=False, is_df=False, model="gpt"):
     margin_side = "margin-left" if is_user else "margin-right"
     message_text = html.escape(text.strip()).replace('\n', '<br>')
 
-    if is_user:
-        container_html = f"""
-        <div style="display:flex; align-items:flex-start; justify-content:flex-end; margin:0; padding:0; margin-bottom:10px;">
-            <div style="background:{message_bg_color}; color:white; border-radius:20px; padding:10px; margin-right:5px; max-width:75%; font-size:14px; margin:0; line-height:1.2; word-wrap:break-word;">
-                {message_text}
+    if message_text:  # Check if message_text is not empty
+        if is_user:
+            container_html = f"""
+            <div style="display:flex; align-items:flex-start; justify-content:flex-end; margin:0; padding:0; margin-bottom:10px;">
+                <div style="background:{message_bg_color}; color:white; border-radius:20px; padding:10px; margin-right:5px; max-width:75%; font-size:14px; margin:0; line-height:1.2; word-wrap:break-word;">
+                    {message_text}
+                </div>
+                <img src="{avatar_url}" class="{avatar_class}" alt="avatar" style="width:40px; height:40px; margin:0;" />
             </div>
-            <img src="{avatar_url}" class="{avatar_class}" alt="avatar" style="width:40px; height:40px; margin:0;" />
-        </div>
-        """
-    else:
-        container_html = f"""
-        <div style="display:flex; align-items:flex-start; justify-content:flex-start; margin:0; padding:0; margin-bottom:10px;">
-            <img src="{avatar_url}" class="{avatar_class}" alt="avatar" style="width:30px; height:30px; margin:0; margin-right:5px; margin-top:5px;" />
-            <div style="background:{message_bg_color}; color:white; border-radius:20px; padding:10px; margin-left:5px; max-width:75%; font-size:14px; margin:0; line-height:1.2; word-wrap:break-word;">
-                {message_text}
+            """
+        else:
+            container_html = f"""
+            <div style="display:flex; align-items:flex-start; justify-content:flex-start; margin:0; padding:0; margin-bottom:10px;">
+                <img src="{avatar_url}" class="{avatar_class}" alt="avatar" style="width:30px; height:30px; margin:0; margin-right:5px; margin-top:5px;" />
+                <div style="background:{message_bg_color}; color:white; border-radius:20px; padding:10px; margin-left:5px; max-width:75%; font-size:14px; margin:0; line-height:1.2; word-wrap:break-word;">
+                    {message_text}
+                </div>
             </div>
-        </div>
-        """
-
-    st.write(container_html, unsafe_allow_html=True)
+            """
+        st.write(container_html, unsafe_allow_html=True)
 
 
 
@@ -133,6 +133,8 @@ def on_llm_end(self, response, run_id, parent_run_id=None, **kwargs):
     def _get_bot_message_container(self, text):
         """Generate the bot's message container style for the given text."""
         formatted_text = format_message(text.strip())
+        if not formatted_text:  # If no formatted text, show "Thinking..."
+            formatted_text = "Thinking..."
         container_content = f"""
         <div style="display:flex; align-items:flex-start; justify-content:flex-start; margin:0; padding:0;">
             <img src="{self.avatar_url}" class="bot-avatar" alt="avatar" style="width:30px; height:30px; margin:0;" />
@@ -143,9 +145,6 @@ def _get_bot_message_container(self, text):
         """
         return container_content
 
-
-
-
     def display_dataframe(self, df):
         """
         Display the dataframe in Streamlit UI within the chat container.

From 3c6472c1ecae27f4948b251b0b41cba9d0a62fee Mon Sep 17 00:00:00 2001
From: kaarthik108 <kaarthikandavar@gmail.com>
Date: Thu, 17 Oct 2024 21:14:30 +1300
Subject: [PATCH 4/9] Update models

---
 agent.py             | 9 +++++----
 main.py              | 4 ++--
 utils/snowchat_ui.py | 2 ++
 3 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/agent.py b/agent.py
index a13e8bd3..d3324fdc 100644
--- a/agent.py
+++ b/agent.py
@@ -1,4 +1,5 @@
 import os
+import streamlit as st
 from dataclasses import dataclass
 from typing import Annotated, Sequence, Optional
 
@@ -36,10 +37,10 @@ def create_agent(callback_handler: BaseCallbackHandler, model_name: str):
         "gpt-4o": ModelConfig(
             model_name="gpt-4o", api_key=os.getenv("OPENAI_API_KEY")
         ),
-        "gemma2-9b": ModelConfig(
-            model_name="gemma2-9b-it",
-            api_key=os.getenv("GROQ_API_KEY"),
-            base_url="https://api.groq.com/openai/v1",
+        "Gemini Flash 1.5 8B": ModelConfig(
+            model_name="google/gemini-flash-1.5-8b",
+            api_key=st.secrets["OPENROUTER_API_KEY"],
+            base_url="https://openrouter.ai/api/v1",
         ),
         "claude3-haiku": ModelConfig(
             model_name="claude-3-haiku-20240307", api_key=os.getenv("ANTHROPIC_API_KEY")
diff --git a/main.py b/main.py
index fe540794..b75b3f56 100644
--- a/main.py
+++ b/main.py
@@ -37,9 +37,9 @@
 model_options = {
     "gpt-4o": "GPT-4o",
     "llama-3.1-405b": "Llama 3.1 405B",
-    "gemma2-9b": "Gemma 2 9B",
-    "claude3-haiku": "Claude 3 Haiku",
     "llama-3.2-3b": "Llama 3.2 3B",
+    "claude3-haiku": "Claude 3 Haiku",
+    "Gemini Flash 1.5 8B": "Gemini Flash 1.5 8B",
 }
 
 model = st.radio(
diff --git a/utils/snowchat_ui.py b/utils/snowchat_ui.py
index 83399cc4..2fe60d0e 100644
--- a/utils/snowchat_ui.py
+++ b/utils/snowchat_ui.py
@@ -35,6 +35,8 @@ def get_model_url(model_name):
         return snow_url
     elif "gpt" in model_name.lower():
         return openai_url
+    elif "gemini" in model_name.lower():
+        return gemini_url
     return mistral_url
 
 

From fde0064fd6c76506569ca09d21343e51563add01 Mon Sep 17 00:00:00 2001
From: kaarthik108 <kaarthikandavar@gmail.com>
Date: Fri, 18 Oct 2024 19:10:31 +1300
Subject: [PATCH 5/9] Add Cloudflare KV caching

---
 agent.py              |  88 +++++++++++++++++++++++-------------------
 graph.png             | Bin 0 -> 8039 bytes
 main.py               |   2 +-
 tools.py              |  19 +++++++--
 utils/snow_connect.py |  60 +++++++++++++++++++++++++---
 5 files changed, 120 insertions(+), 49 deletions(-)
 create mode 100644 graph.png

diff --git a/agent.py b/agent.py
index d3324fdc..fa0f2a69 100644
--- a/agent.py
+++ b/agent.py
@@ -13,9 +13,10 @@
 from langgraph.graph.message import add_messages
 from langchain_core.messages import BaseMessage
 
-from template import TEMPLATE
-from tools import retriever_tool
-
+from langgraph.graph.state import CompiledStateGraph
+from tools import retriever_tool, search, sql_executor_tool
+from PIL import Image
+from io import BytesIO
 
 @dataclass
 class MessagesState:
@@ -32,39 +33,43 @@ class ModelConfig:
     base_url: Optional[str] = None
 
 
-def create_agent(callback_handler: BaseCallbackHandler, model_name: str):
-    model_configurations = {
-        "gpt-4o": ModelConfig(
-            model_name="gpt-4o", api_key=os.getenv("OPENAI_API_KEY")
-        ),
-        "Gemini Flash 1.5 8B": ModelConfig(
-            model_name="google/gemini-flash-1.5-8b",
-            api_key=st.secrets["OPENROUTER_API_KEY"],
-            base_url="https://openrouter.ai/api/v1",
-        ),
-        "claude3-haiku": ModelConfig(
-            model_name="claude-3-haiku-20240307", api_key=os.getenv("ANTHROPIC_API_KEY")
-        ),
-        "llama-3.2-3b": ModelConfig(
-            model_name="accounts/fireworks/models/llama-v3p2-3b-instruct",
-            api_key=os.getenv("FIREWORKS_API_KEY"),
-            base_url="https://api.fireworks.ai/inference/v1",
-        ),
-        "llama-3.1-405b": ModelConfig(
-            model_name="accounts/fireworks/models/llama-v3p1-405b-instruct",
-            api_key=os.getenv("FIREWORKS_API_KEY"),
-            base_url="https://api.fireworks.ai/inference/v1",
-        ),
-    }
+model_configurations = {
+    "gpt-4o": ModelConfig(
+        model_name="gpt-4o", api_key=os.getenv("OPENAI_API_KEY")
+    ),
+    "Gemini Flash 1.5 8B": ModelConfig(
+        model_name="google/gemini-flash-1.5-8b",
+        api_key=st.secrets["OPENROUTER_API_KEY"],
+        base_url="https://openrouter.ai/api/v1",
+    ),
+    "claude3-haiku": ModelConfig(
+        model_name="claude-3-haiku-20240307", api_key=os.getenv("ANTHROPIC_API_KEY")
+    ),
+    "llama-3.2-3b": ModelConfig(
+        model_name="accounts/fireworks/models/llama-v3p2-3b-instruct",
+        api_key=os.getenv("FIREWORKS_API_KEY"),
+        base_url="https://api.fireworks.ai/inference/v1",
+    ),
+    "llama-3.1-405b": ModelConfig(
+        model_name="accounts/fireworks/models/llama-v3p1-405b-instruct",
+        api_key=os.getenv("FIREWORKS_API_KEY"),
+        base_url="https://api.fireworks.ai/inference/v1",
+    ),
+}
+sys_msg = SystemMessage(
+    content="""You're an AI assistant specializing in data analysis with Snowflake SQL. When providing responses, strive to exhibit friendliness and adopt a conversational tone, similar to how a friend or tutor would communicate. Do not ask the user for schema or database details. You have access to the following tools:
+    - Database_Schema: search for database schema details needed to generate SQL code.
+    - Internet_Search: search the internet for Snowflake SQL-related information.
+    - Snowflake_SQL_Executor: execute Snowflake SQL queries and return their results.
+    """
+)
+tools = [retriever_tool, search, sql_executor_tool]
+
+def create_agent(callback_handler: BaseCallbackHandler, model_name: str) -> CompiledStateGraph:
     config = model_configurations.get(model_name)
     if not config:
         raise ValueError(f"Unsupported model name: {model_name}")
 
-    sys_msg = SystemMessage(
-        content="""You're an AI assistant specializing in data analysis with Snowflake SQL. When providing responses, strive to exhibit friendliness and adopt a conversational tone, similar to how a friend or tutor would communicate.
-        Call the tool "Database_Schema" to search for database schema details when needed to generate the SQL code.
-        """
-    )
 
     llm = (
         ChatOpenAI(
@@ -73,6 +78,7 @@ def create_agent(callback_handler: BaseCallbackHandler, model_name: str):
             callbacks=[callback_handler],
             streaming=True,
             base_url=config.base_url,
+            temperature=0.01,
         )
         if config.model_name != "claude-3-haiku-20240307"
         else ChatAnthropic(
@@ -83,21 +89,25 @@ def create_agent(callback_handler: BaseCallbackHandler, model_name: str):
         )
     )
 
-    tools = [retriever_tool]
-
     llm_with_tools = llm.bind_tools(tools)
 
-    def reasoner(state: MessagesState):
+    def llm_agent(state: MessagesState):
         return {"messages": [llm_with_tools.invoke([sys_msg] + state.messages)]}
 
-    # Build the graph
     builder = StateGraph(MessagesState)
-    builder.add_node("reasoner", reasoner)
+    builder.add_node("llm_agent", llm_agent)
     builder.add_node("tools", ToolNode(tools))
-    builder.add_edge(START, "reasoner")
-    builder.add_conditional_edges("reasoner", tools_condition)
-    builder.add_edge("tools", "reasoner")
+    builder.add_edge(START, "llm_agent")
+    builder.add_conditional_edges("llm_agent", tools_condition)
+    builder.add_edge("tools", "llm_agent")
 
     react_graph = builder.compile(checkpointer=memory)
 
+    # png_data = react_graph.get_graph(xray=True).draw_mermaid_png()
+    # with open("graph.png", "wb") as f:
+    #     f.write(png_data)
+
+    # image = Image.open(BytesIO(png_data))
+    # st.image(image, caption="React Graph")
+
     return react_graph
diff --git a/graph.png b/graph.png
new file mode 100644
index 0000000000000000000000000000000000000000..dcb88cd09d3a1a3d86f8a5b8262ebe6adc87849f
GIT binary patch
literal 8039
[base85-encoded PNG payload omitted -- graph.png is the rendered diagram of the compiled agent graph]

literal 0
HcmV?d00001

diff --git a/main.py b/main.py
index b75b3f56..ddbe5d19 100644
--- a/main.py
+++ b/main.py
@@ -204,7 +204,7 @@ def execute_sql(query, conn, retries=2):
         messages = [HumanMessage(content=user_input_content)]
 
         state = MessagesState(messages=messages)
-        result = react_graph.invoke(state, config=config)
+        result = react_graph.invoke(state, config=config, debug=True)
 
         if result["messages"]:
             assistant_message = callback_handler.final_message
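
For reference, `debug=True` makes LangGraph print step-by-step traces for each node while the graph runs. A minimal sketch of gating it behind an environment variable so production runs stay quiet; the `SNOWCHAT_DEBUG` name is a hypothetical choice, not part of this patch:

    import os

    # Hypothetical toggle: only emit LangGraph's verbose traces when
    # SNOWCHAT_DEBUG=1 is set in the environment.
    DEBUG_GRAPH = os.getenv("SNOWCHAT_DEBUG", "0") == "1"
    result = react_graph.invoke(state, config=config, debug=DEBUG_GRAPH)
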
diff --git a/tools.py b/tools.py
index 5b5a4504..2be57ba5 100644
--- a/tools.py
+++ b/tools.py
@@ -1,15 +1,15 @@
 import streamlit as st
-from langchain.prompts.prompt import PromptTemplate
 from supabase.client import Client, create_client
-from langchain.tools.retriever import create_retriever_tool
 from langchain_openai import OpenAIEmbeddings
 from langchain_community.vectorstores import SupabaseVectorStore
+from langchain.tools.retriever import create_retriever_tool
+from langchain_community.tools import DuckDuckGoSearchRun
+from utils.snow_connect import SnowflakeConnection
 
 supabase_url = st.secrets["SUPABASE_URL"]
 supabase_key = st.secrets["SUPABASE_SERVICE_KEY"]
 supabase: Client = create_client(supabase_url, supabase_key)
 
-
 embeddings = OpenAIEmbeddings(
     openai_api_key=st.secrets["OPENAI_API_KEY"], model="text-embedding-ada-002"
 )
@@ -20,9 +20,20 @@
     query_name="v_match_documents",
 )
 
-
 retriever_tool = create_retriever_tool(
     vectorstore.as_retriever(),
     name="Database_Schema",
     description="Search for database schema details",
 )
+
+search = DuckDuckGoSearchRun()
+
+def sql_executor_tool(query: str, use_cache: bool = True) -> list:
+    """
+    Execute Snowflake SQL queries with optional caching.
+    """
+    conn = SnowflakeConnection()
+    return conn.execute_query(query, use_cache)
+
+if __name__ == "__main__":
+    print(sql_executor_tool("select * from STREAM_HACKATHON.STREAMLIT.CUSTOMER_DETAILS"))
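
Since `sql_executor_tool` is a plain function, here is a sketch of how it could be exposed to the agent as a structured tool via the `@tool` decorator from `langchain_core.tools`; the decorator usage and the `snowflake_sql_tool` name are illustrative assumptions, not part of this patch:

    from langchain_core.tools import tool

    @tool
    def snowflake_sql_tool(query: str) -> list:
        """Run a read-only Snowflake SQL query and return rows as dicts."""
        # Delegates to the cached executor defined above.
        return sql_executor_tool(query)
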
diff --git a/utils/snow_connect.py b/utils/snow_connect.py
index d0b396a6..f525adc3 100644
--- a/utils/snow_connect.py
+++ b/utils/snow_connect.py
@@ -1,12 +1,13 @@
 from typing import Any, Dict
-
+import json
+import requests
 import streamlit as st
 from snowflake.snowpark.session import Session
 
 
 class SnowflakeConnection:
     """
-    This class is used to establish a connection to Snowflake.
+    This class is used to establish a connection to Snowflake and execute queries with optional caching.
 
     Attributes
     ----------
@@ -19,16 +20,24 @@ class SnowflakeConnection:
     -------
     get_session()
         Establishes and returns the Snowflake connection session.
-
+    execute_query(query: str, use_cache: bool = True)
+        Executes a Snowflake SQL query with optional caching.
     """
 
     def __init__(self):
         self.connection_parameters = self._get_connection_parameters_from_env()
         self.session = None
+        self.cloudflare_account_id = st.secrets["CLOUDFLARE_ACCOUNT_ID"]
+        self.cloudflare_namespace_id = st.secrets["CLOUDFLARE_NAMESPACE_ID"]
+        self.cloudflare_api_token = st.secrets["CLOUDFLARE_API_TOKEN"]
+        self.headers = {
+            "Authorization": f"Bearer {self.cloudflare_api_token}",
+            "Content-Type": "application/json"
+        }
 
     @staticmethod
     def _get_connection_parameters_from_env() -> Dict[str, Any]:
-        connection_parameters = {
+        return {
             "account": st.secrets["ACCOUNT"],
             "user": st.secrets["USER_NAME"],
             "password": st.secrets["PASSWORD"],
@@ -37,7 +46,6 @@ def _get_connection_parameters_from_env() -> Dict[str, Any]:
             "schema": st.secrets["SCHEMA"],
             "role": st.secrets["ROLE"],
         }
-        return connection_parameters
 
     def get_session(self):
         """
@@ -49,3 +57,45 @@ def get_session(self):
             self.session = Session.builder.configs(self.connection_parameters).create()
             self.session.sql_simplifier_enabled = True
         return self.session
+
+    def _construct_kv_url(self, key: str) -> str:
+        return f"https://api.cloudflare.com/client/v4/accounts/{self.cloudflare_account_id}/storage/kv/namespaces/{self.cloudflare_namespace_id}/values/{key}"
+
+    def get_from_cache(self, key: str) -> str:
+        url = self._construct_kv_url(key)
+        try:
+            response = requests.get(url, headers=self.headers)
+            response.raise_for_status()
+            print("\n\n\nCache hit\n\n\n")
+            return response.text
+        except requests.exceptions.RequestException as e:
+            print(f"Cache miss or error: {e}")
+        return None
+
+    def set_to_cache(self, key: str, value: str) -> None:
+        url = self._construct_kv_url(key)
+        serialized_value = json.dumps(value)
+        try:
+            response = requests.put(url, headers=self.headers, data=serialized_value)
+            response.raise_for_status()
+            print("Cache set successfully")
+        except requests.exceptions.RequestException as e:
+            print(f"Failed to set cache: {e}")
+
+    def execute_query(self, query: str, use_cache: bool = True) -> list:
+        """
+        Execute a Snowflake SQL query with optional caching.
+        """
+        if use_cache:
+            cached_response = self.get_from_cache(query)
+            if cached_response:
+                return json.loads(cached_response)
+
+        session = self.get_session()
+        result = session.sql(query).collect()
+        result_list = [row.as_dict() for row in result]
+
+        if use_cache:
+            self.set_to_cache(query, result_list)
+
+        return result_list
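
The cache in `execute_query` is keyed by the raw SQL string, with the JSON-serialized result list as the value. A minimal usage sketch, assuming valid Snowflake and Cloudflare credentials in `secrets.toml`:

    from utils.snow_connect import SnowflakeConnection

    conn = SnowflakeConnection()
    # First call: cache miss; runs on Snowflake, then writes the rows to KV.
    rows = conn.execute_query("SELECT CURRENT_VERSION() AS V")
    # Same SQL string again: served from Cloudflare KV, no warehouse usage.
    rows = conn.execute_query("SELECT CURRENT_VERSION() AS V")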

From cc1bbd0d7e777169ba67d8ab20c4f5adb6a81531 Mon Sep 17 00:00:00 2001
From: kaarthik108 <kaarthikandavar@gmail.com>
Date: Fri, 18 Oct 2024 19:18:07 +1300
Subject: [PATCH 6/9] update reqs

---
 requirements.txt | 28 ++++++++++++++--------------
 tools.py         |  4 ++--
 2 files changed, 16 insertions(+), 16 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 98406b92..3398eaa4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,15 +1,15 @@
-langchain==0.2.12
-pandas==1.5.0
-pydantic==1.10.8
+langchain==0.3.3
+langchain_anthropic==0.2.3
+langchain_community==0.3.2
+langchain_core==0.3.12
+langchain_openai==0.2.2
+langchain-google-genai==2.0.1
+langgraph==0.2.38
+Pillow==11.0.0
+pydantic==2.9.2
+Requests==2.32.3
+snowflake_connector_python==3.1.0
 snowflake_snowpark_python==1.5.0
-snowflake-snowpark-python[pandas]
-streamlit==1.31.0
-supabase==2.4.1
-unstructured
-tiktoken
-openai
-black
-langchain_openai
-langchain-community
-langchain-core
-langchain-anthropic
\ No newline at end of file
+streamlit==1.33.0
+websocket_client==1.7.0
+duckduckgo_search==6.3.0
\ No newline at end of file
diff --git a/tools.py b/tools.py
index 2be57ba5..fa89599e 100644
--- a/tools.py
+++ b/tools.py
@@ -35,5 +35,5 @@ def sql_executor_tool(query: str, use_cache: bool = True) -> list:
     conn = SnowflakeConnection()
     return conn.execute_query(query, use_cache)
 
-if __name__ == "__main__":
-    print(sql_executor_tool("select * from STREAM_HACKATHON.STREAMLIT.CUSTOMER_DETAILS"))
+# if __name__ == "__main__":
+#     print(sql_executor_tool("select * from STREAM_HACKATHON.STREAMLIT.CUSTOMER_DETAILS"))

From 5b30c10fb57192e555b86969a0f3a2dba11b2c59 Mon Sep 17 00:00:00 2001
From: kaarthik108 <kaarthikandavar@gmail.com>
Date: Fri, 18 Oct 2024 19:28:17 +1300
Subject: [PATCH 7/9] update prompt

---
 agent.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/agent.py b/agent.py
index fa0f2a69..2dc8ded9 100644
--- a/agent.py
+++ b/agent.py
@@ -60,7 +60,7 @@ class ModelConfig:
     content="""You're an AI assistant specializing in data analysis with Snowflake SQL. When providing responses, strive to exhibit friendliness and adopt a conversational tone, similar to how a friend or tutor would communicate. Do not ask the user for schema or database details. You have access to the following tools:
     - Database_Schema: This tool allows you to search for database schema details when needed to generate the SQL code.
     - Internet_Search: This tool allows you to search the internet for snowflake sql related information when needed to generate the SQL code.
-    - Snowflake_SQL_Executor: This tool allows you to execute snowflake sql queries when needed to generate the SQL code.
+    - Snowflake_SQL_Executor: This tool allows you to execute Snowflake SQL queries when needed to generate the SQL code. You only have read access to the database; do not modify the database in any way.
     """
 )
 tools = [retriever_tool, search, sql_executor_tool]

From c5b26abffedc69e9c314c19fb94620fe244f2f85 Mon Sep 17 00:00:00 2001
From: kaarthik108 <kaarthikandavar@gmail.com>
Date: Fri, 18 Oct 2024 19:28:53 +1300
Subject: [PATCH 8/9] remove pillow

---
 requirements.txt | 1 -
 1 file changed, 1 deletion(-)

diff --git a/requirements.txt b/requirements.txt
index 3398eaa4..3a356826 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,7 +5,6 @@ langchain_core==0.3.12
 langchain_openai==0.2.2
 langchain-google-genai==2.0.1
 langgraph==0.2.38
-Pillow==11.0.0
 pydantic==2.9.2
 Requests==2.32.3
 snowflake_connector_python==3.1.0

From d62fbb2bed99e58bd9a4e73a8cd94029397c724e Mon Sep 17 00:00:00 2001
From: kaarthik108 <kaarthikandavar@gmail.com>
Date: Fri, 18 Oct 2024 19:38:34 +1300
Subject: [PATCH 9/9] add error messages

---
 .github/workflows/lint.yml | 42 --------------------------------------
 README.md                  | 21 +++++++++----------
 agent.py                   | 10 +++++----
 utils/snowchat_ui.py       |  1 -
 4 files changed, 16 insertions(+), 58 deletions(-)
 delete mode 100644 .github/workflows/lint.yml

diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
deleted file mode 100644
index 6c7256e3..00000000
--- a/.github/workflows/lint.yml
+++ /dev/null
@@ -1,42 +0,0 @@
-name: Lint
-
-on:
-  push:
-    branches:
-      - main
-  pull_request:
-    branches:
-      - main
-
-jobs:
-  lint:
-    name: Lint and Format Code
-    runs-on: ubuntu-latest
-
-    steps:
-      - name: Check out repository
-        uses: actions/checkout@v3
-
-      - name: Set up Python
-        uses: actions/setup-python@v4
-        with:
-          python-version: "3.9"
-
-      - name: Cache pip dependencies
-        uses: actions/cache@v3
-        with:
-          path: ~/.cache/pip
-          key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
-          restore-keys: |
-            ${{ runner.os }}-pip-
-
-      - name: Install dependencies
-        run: |
-          python -m pip install --upgrade pip
-          pip install -r requirements.txt
-          pip install black ruff mypy codespell
-
-      - name: Run Formatting and Linting
-        run: |
-          make format
-          make lint
diff --git a/README.md b/README.md
index afe05936..7590ad5e 100644
--- a/README.md
+++ b/README.md
@@ -15,9 +15,11 @@
 
 ## Supported LLM's
 
-- GPT-3.5-turbo-0125
-- CodeLlama-70B
-- Mistral Medium
+- GPT-4o
+- Gemini Flash 1.5 8B
+- Claude 3 Haiku
+- Llama 3.2 3B
+- Llama 3.1 405B
 
 #
 
@@ -27,11 +29,12 @@ https://github.com/kaarthik108/snowChat/assets/53030784/24105e23-69d3-4676-b6d6-
 
 ## 🌟 Features
 
-- **Conversational AI**: Harnesses ChatGPT to translate natural language into precise SQL queries.
+- **Conversational AI**: Use ChatGPT and other models to translate natural language into precise SQL queries.
 - **Conversational Memory**: Retains context for interactive, dynamic responses.
 - **Snowflake Integration**: Offers seamless, real-time data insights straight from your Snowflake database.
 - **Self-healing SQL**: Proactively suggests solutions for SQL errors, streamlining data access.
 - **Interactive User Interface**: Transforms data querying into an engaging conversation, complete with a chat reset option.
+- **Agent-based Architecture**: Utilizes an agent to manage interactions and tool usage.
 
 ## 🛠️ Installation
 
@@ -42,7 +45,9 @@ https://github.com/kaarthik108/snowChat/assets/53030784/24105e23-69d3-4676-b6d6-
    cd snowchat
    pip install -r requirements.txt
 
-3. Set up your `OPENAI_API_KEY`, `ACCOUNT`, `USER_NAME`, `PASSWORD`, `ROLE`, `DATABASE`, `SCHEMA`, `WAREHOUSE`, `SUPABASE_URL` , `SUPABASE_SERVICE_KEY` and `REPLICATE_API_TOKEN` in project directory `secrets.toml`.
+3. Set up your `OPENAI_API_KEY`, `ACCOUNT`, `USER_NAME`, `PASSWORD`, `ROLE`, `DATABASE`, `SCHEMA`, `WAREHOUSE`, `SUPABASE_URL`, `SUPABASE_SERVICE_KEY`, `SUPABASE_STORAGE_URL`, `CLOUDFLARE_ACCOUNT_ID`, `CLOUDFLARE_NAMESPACE_ID`,
+   and `CLOUDFLARE_API_TOKEN` in `secrets.toml` in the project directory.
+   Cloudflare KV is used to cache Snowflake query responses.
 
 4. Make your schemas and store them in the docs folder so they match your database.
 
@@ -53,12 +58,6 @@ https://github.com/kaarthik108/snowChat/assets/53030784/24105e23-69d3-4676-b6d6-
 7. Run the Streamlit app to start chatting:
    streamlit run main.py
 
-## 🚀 Additional Enhancements
-
-1. **Platform Integration**: Connect snowChat with popular communication platforms like Slack or Discord for seamless interaction.
-2. **Voice Integration**: Implement voice recognition and text-to-speech functionality to make the chatbot more interactive and user-friendly.
-3. **Advanced Analytics**: Integrate with popular data visualization libraries like Plotly or Matplotlib to generate interactive visualizations based on the user's queries (AutoGPT).
-
 ## Star History
 
 [![Star History Chart](https://api.star-history.com/svg?repos=kaarthik108/snowChat&type=Date)]
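
To make step 3 concrete, a hypothetical startup check (not part of the patch) that verifies the listed keys exist in `secrets.toml` before the app runs:

    import streamlit as st

    # Keys from step 3 of the README; st.secrets behaves like a read-only dict.
    REQUIRED_SECRETS = [
        "OPENAI_API_KEY", "ACCOUNT", "USER_NAME", "PASSWORD", "ROLE",
        "DATABASE", "SCHEMA", "WAREHOUSE", "SUPABASE_URL",
        "SUPABASE_SERVICE_KEY", "SUPABASE_STORAGE_URL",
        "CLOUDFLARE_ACCOUNT_ID", "CLOUDFLARE_NAMESPACE_ID",
        "CLOUDFLARE_API_TOKEN",
    ]
    missing = [key for key in REQUIRED_SECRETS if key not in st.secrets]
    if missing:
        raise RuntimeError(f"Missing entries in secrets.toml: {missing}")
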
diff --git a/agent.py b/agent.py
index 2dc8ded9..95a8c107 100644
--- a/agent.py
+++ b/agent.py
@@ -35,7 +35,7 @@ class ModelConfig:
 
 model_configurations = {
     "gpt-4o": ModelConfig(
-        model_name="gpt-4o", api_key=os.getenv("OPENAI_API_KEY")
+        model_name="gpt-4o", api_key=st.secrets["OPENAI_API_KEY"]
     ),
     "Gemini Flash 1.5 8B": ModelConfig(
         model_name="google/gemini-flash-1.5-8b",
@@ -43,16 +43,16 @@ class ModelConfig:
         base_url="https://openrouter.ai/api/v1",
     ),
     "claude3-haiku": ModelConfig(
-        model_name="claude-3-haiku-20240307", api_key=os.getenv("ANTHROPIC_API_KEY")
+        model_name="claude-3-haiku-20240307", api_key=st.secrets["ANTHROPIC_API_KEY"]
     ),
     "llama-3.2-3b": ModelConfig(
         model_name="accounts/fireworks/models/llama-v3p2-3b-instruct",
-        api_key=os.getenv("FIREWORKS_API_KEY"),
+        api_key=st.secrets["FIREWORKS_API_KEY"],
         base_url="https://api.fireworks.ai/inference/v1",
     ),
     "llama-3.1-405b": ModelConfig(
         model_name="accounts/fireworks/models/llama-v3p1-405b-instruct",
-        api_key=os.getenv("FIREWORKS_API_KEY"),
+        api_key=st.secrets["FIREWORKS_API_KEY"],
         base_url="https://api.fireworks.ai/inference/v1",
     ),
 }
@@ -70,6 +70,8 @@ def create_agent(callback_handler: BaseCallbackHandler, model_name: str) -> Stat
     if not config:
         raise ValueError(f"Unsupported model name: {model_name}")
 
+    if not config.api_key:
+        raise ValueError(f"API key for model '{model_name}' is not set. Please check your environment variables or secrets configuration.")
 
     llm = (
         ChatOpenAI(
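
With the new guard, a missing key fails fast at agent construction instead of at the first model call. An illustrative sketch (hypothetical; assumes the secret is present but empty, since `st.secrets` access itself would fail on an absent key):

    from agent import create_agent

    # An empty ANTHROPIC_API_KEY in secrets now triggers the guard
    # immediately, with a pointer to the secrets configuration.
    try:
        agent = create_agent(callback_handler=None, model_name="claude3-haiku")
    except ValueError as err:
        print(err)
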
diff --git a/utils/snowchat_ui.py b/utils/snowchat_ui.py
index 2fe60d0e..05db4b1e 100644
--- a/utils/snowchat_ui.py
+++ b/utils/snowchat_ui.py
@@ -1,6 +1,5 @@
 import html
 import re
-import textwrap
 
 import streamlit as st
 from langchain.callbacks.base import BaseCallbackHandler