Commit 17fe28a (1 parent: 8e1d610). Showing 7 changed files with 222 additions and 183 deletions.
@@ -6,4 +6,6 @@ secrets.toml
 archived_logs/
 
 build/
-snowchat.egg-info/
+snowchat.egg-info/
+
+chroma_db
@@ -1,161 +1,141 @@
+from typing import Any, Callable, Dict, Optional
+
+import boto3
 import streamlit as st
 from langchain.chains import ConversationalRetrievalChain, LLMChain
 from langchain.chains.question_answering import load_qa_chain
 from langchain.chat_models import ChatOpenAI
 from langchain.embeddings.openai import OpenAIEmbeddings
 from langchain.llms import OpenAI, Replicate
-from langchain.prompts.prompt import PromptTemplate
+from langchain.llms.bedrock import Bedrock
 from langchain.vectorstores import SupabaseVectorStore
+from pydantic import BaseModel, validator
 from supabase.client import Client, create_client
 
-template = """You are an AI chatbot having a conversation with a human.
-Chat History:\"""
-{chat_history}
-\"""
-Human Input: \"""
-{question}
-\"""
-AI:"""
-
-condense_question_prompt = PromptTemplate.from_template(template)
-
-TEMPLATE = """
-You're an AI assistant specializing in data analysis with Snowflake SQL. When providing responses, strive to exhibit friendliness and adopt a conversational tone, similar to how a friend or tutor would communicate.
-When asked about your capabilities, provide a general overview of your ability to assist with data analysis tasks using Snowflake SQL, instead of performing specific SQL queries.
-Based on the question provided, if it pertains to data analysis or SQL tasks, generate SQL code that is compatible with the Snowflake environment. Additionally, offer a brief explanation about how you arrived at the SQL code. If the required column isn't explicitly stated in the context, suggest an alternative using available columns, but do not assume the existence of any columns that are not mentioned. Also, do not modify the database in any way (no insert, update, or delete operations). You are only allowed to query the database. Refrain from using the information schema.
-**You are only required to write one SQL query per question.**
-If the question or context does not clearly involve SQL or data analysis tasks, respond appropriately without generating SQL queries.
-When the user expresses gratitude or says "Thanks", interpret it as a signal to conclude the conversation. Respond with an appropriate closing statement without generating further SQL queries.
-If you don't know the answer, simply state, "I'm sorry, I don't know the answer to your question."
-Write your response in markdown format.
-Question: ```{question}```
-{context}
-Answer:
-"""
-B_INST, E_INST = "[INST]", "[/INST]"
-B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
-
-LLAMA_TEMPLATE = """
-You're specialized with Snowflake SQL. When providing answers, strive to exhibit friendliness and adopt a conversational tone, similar to how a friend or tutor would communicate.
-If the question or context does not clearly involve SQL or data analysis tasks, respond appropriately without generating SQL queries.
-If you don't know the answer, simply state, "I'm sorry, I don't know the answer to your question."
-Write SQL code for this Question based on the below context details: {question}
-<<CONTEXT>>
-context: \n {context}
-<</CONTEXT>>
-write responses in markdown format
-Answer:
-"""
-
-LLAMA_TEMPLATE = B_INST + B_SYS + LLAMA_TEMPLATE + E_SYS + E_INST
-
-QA_PROMPT = PromptTemplate(template=TEMPLATE, input_variables=["question", "context"])
-LLAMA_PROMPT = PromptTemplate(
-    template=LLAMA_TEMPLATE, input_variables=["question", "context"]
-)
+from template import CONDENSE_QUESTION_PROMPT, LLAMA_PROMPT, QA_PROMPT
 
 supabase_url = st.secrets["SUPABASE_URL"]
 supabase_key = st.secrets["SUPABASE_SERVICE_KEY"]
 supabase: Client = create_client(supabase_url, supabase_key)
 
-VERSION = "be553392065353425e0f0193d2a896d6a5ff201549f5d7cd9180c8dfdeac39ed"
+VERSION = "1f01a52ff933873dff339d5fb5e1fd6f24f77456836f514fa05e91c1a42699c7"
 LLAMA = "meta/codellama-13b-instruct:{}".format(VERSION)
 
 
-def get_chain_replicate(vectorstore, callback_handler=None):
-    """
-    Get a chain for chatting with a vector database.
-    """
-    q_llm = Replicate(
-        model=LLAMA,
-        input={"temperature": 0.2, "max_length": 200, "top_p": 1},
-        replicate_api_token=st.secrets["REPLICATE_API_TOKEN"],
-    )
-    llm = Replicate(
-        streaming=True,
-        callbacks=[callback_handler],
-        model=LLAMA,
-        input={"temperature": 0.2, "max_length": 300, "top_p": 1},
-        replicate_api_token=st.secrets["REPLICATE_API_TOKEN"],
-    )
-
-    question_generator = LLMChain(llm=q_llm, prompt=condense_question_prompt)
-
-    doc_chain = load_qa_chain(llm=llm, chain_type="stuff", prompt=QA_PROMPT)
-    conv_chain = ConversationalRetrievalChain(
-        callbacks=[callback_handler],
-        retriever=vectorstore.as_retriever(),
-        combine_docs_chain=doc_chain,
-        question_generator=question_generator,
-    )
-
-    return conv_chain
-
-
-def get_chain_gpt(vectorstore, callback_handler=None):
-    """
-    Get a chain for chatting with a vector database.
-    """
-    q_llm = OpenAI(
-        temperature=0.1,
-        openai_api_key=st.secrets["OPENAI_API_KEY"],
-        model_name="gpt-3.5-turbo-16k",
-        max_tokens=500,
-    )
-
-    llm = ChatOpenAI(
-        model_name="gpt-3.5-turbo",
-        temperature=0.5,
-        openai_api_key=st.secrets["OPENAI_API_KEY"],
-        max_tokens=500,
-        callbacks=[callback_handler],
-        streaming=True,
-    )
-    question_generator = LLMChain(llm=q_llm, prompt=condense_question_prompt)
-
-    doc_chain = load_qa_chain(llm=llm, chain_type="stuff", prompt=QA_PROMPT)
-    conv_chain = ConversationalRetrievalChain(
-        retriever=vectorstore.as_retriever(),
-        combine_docs_chain=doc_chain,
-        question_generator=question_generator,
-    )
-
-    return conv_chain
+class ModelConfig(BaseModel):
+    model_type: str
+    secrets: Dict[str, Any]
+    callback_handler: Optional[Callable] = None
+
+    @validator("model_type", pre=True, always=True)
+    def validate_model_type(cls, v):
+        if v not in ["code-llama", "gpt", "claude"]:
+            raise ValueError(f"Unsupported model type: {v}")
+        return v
+
+
+class ModelWrapper:
+    def __init__(self, config: ModelConfig):
+        self.model_type = config.model_type
+        self.secrets = config.secrets
+        self.callback_handler = config.callback_handler
+        self.setup()
+
+    def setup(self):
+        if self.model_type == "code-llama":
+            self.setup_llama()
+        elif self.model_type == "gpt":
+            self.setup_gpt()
+        elif self.model_type == "claude":
+            self.setup_claude()
+
+    def setup_llama(self):
+        self.q_llm = Replicate(
+            model=LLAMA,
+            input={"temperature": 0.2, "max_length": 200, "top_p": 1},
+            replicate_api_token=self.secrets["REPLICATE_API_TOKEN"],
+        )
+        self.llm = Replicate(
+            streaming=True,
+            callbacks=[self.callback_handler],
+            model=LLAMA,
+            input={"temperature": 0.2, "max_length": 300, "top_p": 1},
+            replicate_api_token=self.secrets["REPLICATE_API_TOKEN"],
+        )
+
+    def setup_gpt(self):
+        self.q_llm = OpenAI(
+            temperature=0.1,
+            openai_api_key=self.secrets["OPENAI_API_KEY"],
+            model_name="gpt-3.5-turbo-16k",
+            max_tokens=500,
+        )
+
+        self.llm = ChatOpenAI(
+            model_name="gpt-3.5-turbo-16k",
+            temperature=0.5,
+            openai_api_key=self.secrets["OPENAI_API_KEY"],
+            max_tokens=500,
+            callbacks=[self.callback_handler],
+            streaming=True,
+        )
+
+    def setup_claude(self):
+        bedrock_runtime = boto3.client(
+            service_name="bedrock-runtime",
+            aws_access_key_id=self.secrets["AWS_ACCESS_KEY_ID"],
+            aws_secret_access_key=self.secrets["AWS_SECRET_ACCESS_KEY"],
+            region_name="us-east-1",
+        )
+        parameters = {
+            "max_tokens_to_sample": 1000,
+            "stop_sequences": [],
+            "temperature": 0,
+            "top_p": 0.9,
+        }
+        self.q_llm = Bedrock(
+            model_id="anthropic.claude-instant-v1", client=bedrock_runtime
+        )
+        self.llm = Bedrock(
+            model_id="anthropic.claude-instant-v1",
+            client=bedrock_runtime,
+            callbacks=[self.callback_handler],
+            streaming=True,
+            model_kwargs=parameters,
+        )
+
+    def get_chain(self, vectorstore):
+        if not self.q_llm or not self.llm:
+            raise ValueError("Models have not been properly initialized.")
+        question_generator = LLMChain(llm=self.q_llm, prompt=CONDENSE_QUESTION_PROMPT)
+        doc_chain = load_qa_chain(llm=self.llm, chain_type="stuff", prompt=QA_PROMPT)
+        conv_chain = ConversationalRetrievalChain(
+            retriever=vectorstore.as_retriever(),
+            combine_docs_chain=doc_chain,
+            question_generator=question_generator,
+        )
+        return conv_chain
 
 
 def load_chain(model_name="GPT-3.5", callback_handler=None):
     """
     Load the chain from the local file system
     Returns:
         chain (Chain): The chain object
     """
 
     embeddings = OpenAIEmbeddings(
         openai_api_key=st.secrets["OPENAI_API_KEY"], model="text-embedding-ada-002"
     )
     vectorstore = SupabaseVectorStore(
-        embedding=embeddings, client=supabase, table_name="documents"
+        embedding=embeddings,
+        client=supabase,
+        table_name="documents",
+        query_name="v_match_documents",
     )
-    return (
-        get_chain_gpt(vectorstore, callback_handler=callback_handler)
-        if "GPT-3.5" in model_name
-        else get_chain_replicate(vectorstore, callback_handler=callback_handler)
-    )
+
+    if "claude" in model_name.lower():
+        model_type = "claude"
+    elif "GPT-3.5" in model_name:
+        model_type = "gpt"
+    else:
+        model_type = "code-llama"
+
+    config = ModelConfig(
+        model_type=model_type, secrets=st.secrets, callback_handler=callback_handler
+    )
+    model = ModelWrapper(config)
+    return model.get_chain(vectorstore)
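
For orientation, a minimal sketch of how the refactored entry point would be exercised by a caller. Only load_chain, ModelConfig, and ModelWrapper come from this commit; the module name "chain" and the sample question are illustrative assumptions.

# Hypothetical caller; the module name "chain" is an assumption, not part of this commit.
from chain import ModelConfig, load_chain

# Callers keep the old load_chain signature; model selection now routes through
# ModelConfig/ModelWrapper instead of the removed get_chain_* functions.
chain = load_chain(model_name="GPT-3.5", callback_handler=None)

# ConversationalRetrievalChain expects a question plus the running chat history
# and returns a dict containing the generated answer.
result = chain({"question": "How many rows does my_table have?", "chat_history": []})
print(result["answer"])

# The pydantic validator rejects unknown model types at construction time;
# pydantic v1 surfaces this as a ValidationError, a ValueError subclass.
try:
    ModelConfig(model_type="palm", secrets={})
except ValueError as err:
    print(err)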
@@ -1,12 +1,13 @@
-langchain==0.0.266
+langchain==0.0.305
 pandas==1.5.0
 pydantic==1.10.8
-snowflake_snowpark_python==1.5.0
+snowflake-snowpark-python[pandas]
-streamlit==1.24.0
+streamlit==1.27.1
 supabase==1.0.3
 unstructured==0.7.12
 tiktoken==0.4.0
 openai==0.27.8
 black==23.3.0
-replicate==0.8.4
+replicate==0.8.4
+boto3==1.28.57
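
The new boto3 pin matters because setup_claude above creates a bedrock-runtime client, which only ships in recent boto3 releases. As a point of reference, a minimal sketch of invoking that same model directly through boto3; the service name and model id come from the diff, while the region, credential setup, and the Claude prompt format are assumptions beyond this commit.

import json

import boto3  # needs a release recent enough to ship the "bedrock-runtime" service

# Direct call to the model that langchain's Bedrock class wraps in setup_claude.
# Assumes AWS credentials are already configured in the environment.
client = boto3.client("bedrock-runtime", region_name="us-east-1")
body = json.dumps(
    {
        # Legacy Claude text-completions request shape for claude-instant-v1.
        "prompt": "\n\nHuman: Say hello.\n\nAssistant:",
        "max_tokens_to_sample": 100,
    }
)
response = client.invoke_model(
    modelId="anthropic.claude-instant-v1",
    contentType="application/json",
    accept="application/json",
    body=body,
)
print(json.loads(response["body"].read())["completion"])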