
Commit

Add Claude from Bedrock
kaarthik108 committed Sep 30, 2023
1 parent 8e1d610 commit 17fe28a
Showing 7 changed files with 222 additions and 183 deletions.
4 changes: 3 additions & 1 deletion .gitignore
@@ -6,4 +6,6 @@ secrets.toml
archived_logs/

build/
snowchat.egg-info/
snowchat.egg-info/

chroma_db
252 changes: 116 additions & 136 deletions chain.py
@@ -1,161 +1,141 @@
from typing import Any, Callable, Dict, Optional

import boto3
import streamlit as st
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.chat_models import ChatOpenAI
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms import OpenAI, Replicate
from langchain.prompts.prompt import PromptTemplate
from langchain.llms.bedrock import Bedrock
from langchain.vectorstores import SupabaseVectorStore
from pydantic import BaseModel, validator
from supabase.client import Client, create_client

template = """You are an AI chatbot having a conversation with a human.
Chat History:\"""
{chat_history}
\"""
Human Input: \"""
{question}
\"""
AI:"""

condense_question_prompt = PromptTemplate.from_template(template)

TEMPLATE = """
You're an AI assistant specializing in data analysis with Snowflake SQL. When providing responses, strive to exhibit friendliness and adopt a conversational tone, similar to how a friend or tutor would communicate.
When asked about your capabilities, provide a general overview of your ability to assist with data analysis tasks using Snowflake SQL, instead of performing specific SQL queries.
Based on the question provided, if it pertains to data analysis or SQL tasks, generate SQL code that is compatible with the Snowflake environment. Additionally, offer a brief explanation about how you arrived at the SQL code. If the required column isn't explicitly stated in the context, suggest an alternative using available columns, but do not assume the existence of any columns that are not mentioned. Also, do not modify the database in any way (no insert, update, or delete operations). You are only allowed to query the database. Refrain from using the information schema.
**You are only required to write one SQL query per question.**
If the question or context does not clearly involve SQL or data analysis tasks, respond appropriately without generating SQL queries.
When the user expresses gratitude or says "Thanks", interpret it as a signal to conclude the conversation. Respond with an appropriate closing statement without generating further SQL queries.
If you don't know the answer, simply state, "I'm sorry, I don't know the answer to your question."
Write your response in markdown format.
Question: ```{question}```
{context}
Answer:
"""
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

LLAMA_TEMPLATE = """
You're specialized with Snowflake SQL. When providing answers, strive to exhibit friendliness and adopt a conversational tone, similar to how a friend or tutor would communicate.
If the question or context does not clearly involve SQL or data analysis tasks, respond appropriately without generating SQL queries.
If you don't know the answer, simply state, "I'm sorry, I don't know the answer to your question."
Write SQL code for this Question based on the below context details: {question}
<<CONTEXT>>
context: \n {context}
<</CONTEXT>>
write responses in markdown format
Answer:
"""

LLAMA_TEMPLATE = B_INST + B_SYS + LLAMA_TEMPLATE + E_SYS + E_INST

QA_PROMPT = PromptTemplate(template=TEMPLATE, input_variables=["question", "context"])
LLAMA_PROMPT = PromptTemplate(
template=LLAMA_TEMPLATE, input_variables=["question", "context"]
)
from template import CONDENSE_QUESTION_PROMPT, LLAMA_PROMPT, QA_PROMPT

supabase_url = st.secrets["SUPABASE_URL"]
supabase_key = st.secrets["SUPABASE_SERVICE_KEY"]
supabase: Client = create_client(supabase_url, supabase_key)

VERSION = "be553392065353425e0f0193d2a896d6a5ff201549f5d7cd9180c8dfdeac39ed"
VERSION = "1f01a52ff933873dff339d5fb5e1fd6f24f77456836f514fa05e91c1a42699c7"
LLAMA = "meta/codellama-13b-instruct:{}".format(VERSION)


def get_chain_replicate(vectorstore, callback_handler=None):
"""
Get a chain for chatting with a vector database.
"""
q_llm = Replicate(
model=LLAMA,
input={"temperature": 0.2, "max_length": 200, "top_p": 1},
replicate_api_token=st.secrets["REPLICATE_API_TOKEN"],
)
llm = Replicate(
streaming=True,
callbacks=[callback_handler],
model=LLAMA,
input={"temperature": 0.2, "max_length": 300, "top_p": 1},
replicate_api_token=st.secrets["REPLICATE_API_TOKEN"],
)

question_generator = LLMChain(llm=q_llm, prompt=condense_question_prompt)

doc_chain = load_qa_chain(llm=llm, chain_type="stuff", prompt=QA_PROMPT)
conv_chain = ConversationalRetrievalChain(
callbacks=[callback_handler],
retriever=vectorstore.as_retriever(),
combine_docs_chain=doc_chain,
question_generator=question_generator,
)

return conv_chain


def get_chain_gpt(vectorstore, callback_handler=None):
"""
Get a chain for chatting with a vector database.
"""
q_llm = OpenAI(
temperature=0.1,
openai_api_key=st.secrets["OPENAI_API_KEY"],
model_name="gpt-3.5-turbo-16k",
max_tokens=500,
)

llm = ChatOpenAI(
model_name="gpt-3.5-turbo",
temperature=0.5,
openai_api_key=st.secrets["OPENAI_API_KEY"],
max_tokens=500,
callbacks=[callback_handler],
streaming=True,
)
question_generator = LLMChain(llm=q_llm, prompt=condense_question_prompt)

doc_chain = load_qa_chain(llm=llm, chain_type="stuff", prompt=QA_PROMPT)
conv_chain = ConversationalRetrievalChain(
retriever=vectorstore.as_retriever(),
combine_docs_chain=doc_chain,
question_generator=question_generator,
)

return conv_chain
class ModelConfig(BaseModel):
model_type: str
secrets: Dict[str, Any]
callback_handler: Optional[Callable] = None

@validator("model_type", pre=True, always=True)
def validate_model_type(cls, v):
if v not in ["code-llama", "gpt", "claude"]:
raise ValueError(f"Unsupported model type: {v}")
return v


class ModelWrapper:
def __init__(self, config: ModelConfig):
self.model_type = config.model_type
self.secrets = config.secrets
self.callback_handler = config.callback_handler
self.setup()

def setup(self):
if self.model_type == "code-llama":
self.setup_llama()
elif self.model_type == "gpt":
self.setup_gpt()
elif self.model_type == "claude":
self.setup_claude()

def setup_llama(self):
self.q_llm = Replicate(
model=LLAMA,
input={"temperature": 0.2, "max_length": 200, "top_p": 1},
replicate_api_token=self.secrets["REPLICATE_API_TOKEN"],
)
self.llm = Replicate(
streaming=True,
callbacks=[self.callback_handler],
model=LLAMA,
input={"temperature": 0.2, "max_length": 300, "top_p": 1},
replicate_api_token=self.secrets["REPLICATE_API_TOKEN"],
)

def setup_gpt(self):
self.q_llm = OpenAI(
temperature=0.1,
openai_api_key=self.secrets["OPENAI_API_KEY"],
model_name="gpt-3.5-turbo-16k",
max_tokens=500,
)

self.llm = ChatOpenAI(
model_name="gpt-3.5-turbo-16k",
temperature=0.5,
openai_api_key=self.secrets["OPENAI_API_KEY"],
max_tokens=500,
callbacks=[self.callback_handler],
streaming=True,
)

def setup_claude(self):
bedrock_runtime = boto3.client(
service_name="bedrock-runtime",
aws_access_key_id=self.secrets["AWS_ACCESS_KEY_ID"],
aws_secret_access_key=self.secrets["AWS_SECRET_ACCESS_KEY"],
region_name="us-east-1",
)
parameters = {
"max_tokens_to_sample": 1000,
"stop_sequences": [],
"temperature": 0,
"top_p": 0.9,
}
self.q_llm = Bedrock(
model_id="anthropic.claude-instant-v1", client=bedrock_runtime
)
self.llm = Bedrock(
model_id="anthropic.claude-instant-v1",
client=bedrock_runtime,
callbacks=[self.callback_handler],
streaming=True,
model_kwargs=parameters,
)

def get_chain(self, vectorstore):
if not self.q_llm or not self.llm:
raise ValueError("Models have not been properly initialized.")
question_generator = LLMChain(llm=self.q_llm, prompt=CONDENSE_QUESTION_PROMPT)
doc_chain = load_qa_chain(llm=self.llm, chain_type="stuff", prompt=QA_PROMPT)
conv_chain = ConversationalRetrievalChain(
retriever=vectorstore.as_retriever(),
combine_docs_chain=doc_chain,
question_generator=question_generator,
)
return conv_chain


def load_chain(model_name="GPT-3.5", callback_handler=None):
"""
Load the chain from the local file system
Returns:
chain (Chain): The chain object
"""

embeddings = OpenAIEmbeddings(
openai_api_key=st.secrets["OPENAI_API_KEY"], model="text-embedding-ada-002"
)
vectorstore = SupabaseVectorStore(
embedding=embeddings, client=supabase, table_name="documents"
embedding=embeddings,
client=supabase,
table_name="documents",
query_name="v_match_documents",
)
return (
get_chain_gpt(vectorstore, callback_handler=callback_handler)
if "GPT-3.5" in model_name
else get_chain_replicate(vectorstore, callback_handler=callback_handler)

if "claude" in model_name.lower():
model_type = "claude"
elif "GPT-3.5" in model_name:
model_type = "gpt"
else:
model_type = "code-llama"

config = ModelConfig(
model_type=model_type, secrets=st.secrets, callback_handler=callback_handler
)
model = ModelWrapper(config)
return model.get_chain(vectorstore)
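
For orientation, a minimal usage sketch of the new code path (not part of this commit): it assumes the Streamlit secrets referenced above (SUPABASE_URL, SUPABASE_SERVICE_KEY, OPENAI_API_KEY, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY) are configured, and it substitutes a stdout streaming handler for the app's own callback handler.

```python
# Hypothetical usage sketch -- not part of this commit.
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

from chain import load_chain

# Any model name containing "claude" routes to ModelConfig(model_type="claude"),
# which builds the Bedrock-backed anthropic.claude-instant-v1 models in
# ModelWrapper.setup_claude().
chain = load_chain(
    model_name="♾️ Claude",
    callback_handler=StreamingStdOutCallbackHandler(),
)

result = chain(
    {"question": "Show total sales by region", "chat_history": []}
)["answer"]
print(result)
```
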
18 changes: 8 additions & 10 deletions main.py
@@ -17,7 +17,7 @@
st.caption("Talk your way through data")
model = st.radio(
"",
options=["✨ GPT-3.5", "🐐 code-LLama"],
options=["✨ GPT-3.5", "🐐 code-LLama", "♾️ Claude"],
index=0,
horizontal=True,
)
@@ -94,8 +94,6 @@ def get_sql(text):

def append_message(content, role="assistant", display=False):
message = {"role": role, "content": content}
if model == "LLama-2": # unable to get streaming working with LLama-2
message_func(content, False, display)
st.session_state.messages.append(message)
if role != "data":
append_chat_history(st.session_state.messages[-2]["content"], content)
@@ -138,11 +136,11 @@ def execute_sql(query, conn, retries=2):
result = chain(
{"question": content, "chat_history": st.session_state["history"]}
)["answer"]
# print(result)
print(result)
append_message(result)
if get_sql(result):
conn = SnowflakeConnection().get_session()
df = execute_sql(get_sql(result), conn)
if df is not None:
callback_handler.display_dataframe(df)
append_message(df, "data", True)
# if get_sql(result):
# conn = SnowflakeConnection().get_session()
# df = execute_sql(get_sql(result), conn)
# if df is not None:
# callback_handler.display_dataframe(df)
# append_message(df, "data", True)
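
The block commented out above depended on get_sql pulling a SQL snippet out of the model's markdown answer before running it against Snowflake. A minimal sketch of such an extractor (hypothetical; the repository's get_sql may be implemented differently) could look like:

```python
# Hypothetical extractor -- the repository's get_sql may differ.
import re
from typing import Optional


def extract_sql(markdown_text: str) -> Optional[str]:
    """Return the contents of the first SQL code fence in a markdown answer, if any."""
    match = re.search(r"```sql\s*(.*?)```", markdown_text, re.DOTALL | re.IGNORECASE)
    return match.group(1).strip() if match else None
```
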
7 changes: 4 additions & 3 deletions requirements.txt
@@ -1,12 +1,13 @@
langchain==0.0.266
langchain==0.0.305
pandas==1.5.0
pydantic==1.10.8
snowflake_snowpark_python==1.5.0
snowflake-snowpark-python[pandas]
streamlit==1.24.0
streamlit==1.27.1
supabase==1.0.3
unstructured==0.7.12
tiktoken==0.4.0
openai==0.27.8
black==23.3.0
replicate==0.8.4
replicate==0.8.4
boto3==1.28.57
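
The new boto3 pin is what provides the bedrock-runtime client used by setup_claude in chain.py. A standalone sanity check (hypothetical, assuming AWS credentials with Bedrock access are available in the environment) might look like:

```python
# Hypothetical sanity check -- confirms the pinned boto3 can invoke Claude on Bedrock.
import json

import boto3

client = boto3.client("bedrock-runtime", region_name="us-east-1")
body = json.dumps(
    {
        "prompt": "\n\nHuman: Say hello in one word.\n\nAssistant:",
        "max_tokens_to_sample": 50,
        "temperature": 0,
    }
)
response = client.invoke_model(
    modelId="anthropic.claude-instant-v1",
    body=body,
    contentType="application/json",
    accept="application/json",
)
print(json.loads(response["body"].read())["completion"])
```
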