diff --git a/chain.py b/chain.py
index 8addf90..1896620 100644
--- a/chain.py
+++ b/chain.py
@@ -33,7 +33,8 @@ class ModelConfig(BaseModel):
 
     @validator("model_type", pre=True, always=True)
     def validate_model_type(cls, v):
-        if v not in ["gpt", "llama", "claude", "mixtral8x7b", "arctic"]:
+        valid_model_types = ["qwen", "llama", "claude", "mixtral8x7b", "arctic"]
+        if v not in valid_model_types:
             raise ValueError(f"Unsupported model type: {v}")
         return v
 
@@ -43,87 +44,59 @@ def __init__(self, config: ModelConfig):
         self.model_type = config.model_type
         self.secrets = config.secrets
         self.callback_handler = config.callback_handler
-        account_tag = self.secrets["CF_ACCOUNT_TAG"]
-        self.gateway_url = (
-            f"https://gateway.ai.cloudflare.com/v1/{account_tag}/k-1-gpt/openai"
-        )
-        self.setup()
-
-    def setup(self):
-        if self.model_type == "gpt":
-            self.setup_gpt()
-        elif self.model_type == "claude":
-            self.setup_claude()
-        elif self.model_type == "mixtral8x7b":
-            self.setup_mixtral_8x7b()
-        elif self.model_type == "llama":
-            self.setup_llama()
-        elif self.model_type == "arctic":
-            self.setup_arctic()
-
-    def setup_gpt(self):
-        self.llm = ChatOpenAI(
-            model_name="gpt-3.5-turbo",
-            temperature=0.2,
-            api_key=self.secrets["OPENAI_API_KEY"],
-            max_tokens=1000,
-            callbacks=[self.callback_handler],
-            streaming=True,
-            # base_url=self.gateway_url,
-        )
-
-    def setup_mixtral_8x7b(self):
-        self.llm = ChatOpenAI(
-            model_name="mixtral-8x7b-32768",
-            temperature=0.2,
-            api_key=self.secrets["GROQ_API_KEY"],
-            max_tokens=3000,
-            callbacks=[self.callback_handler],
-            streaming=True,
-            base_url="https://api.groq.com/openai/v1",
-        )
-
-    def setup_claude(self):
-        self.llm = ChatOpenAI(
-            model_name="anthropic/claude-3-haiku",
-            temperature=0.1,
-            api_key=self.secrets["OPENROUTER_API_KEY"],
-            max_tokens=700,
-            callbacks=[self.callback_handler],
-            streaming=True,
-            base_url="https://openrouter.ai/api/v1",
-            default_headers={
-                "HTTP-Referer": "https://snowchat.streamlit.app/",
-                "X-Title": "Snowchat",
-            },
-        )
-
-    def setup_llama(self):
-        self.llm = ChatOpenAI(
-            model_name="meta-llama/llama-3-70b-instruct",
-            temperature=0.1,
-            api_key=self.secrets["OPENROUTER_API_KEY"],
-            max_tokens=700,
-            callbacks=[self.callback_handler],
-            streaming=True,
-            base_url="https://openrouter.ai/api/v1",
-            default_headers={
-                "HTTP-Referer": "https://snowchat.streamlit.app/",
-                "X-Title": "Snowchat",
-            },
-        )
-
-    def setup_arctic(self):
-        self.llm = ChatOpenAI(
-            model_name="snowflake/snowflake-arctic-instruct",
-            temperature=0.1,
-            api_key=self.secrets["OPENROUTER_API_KEY"],
-            max_tokens=700,
-            callbacks=[self.callback_handler],
-            streaming=True,
-            base_url="https://openrouter.ai/api/v1",
-            default_headers={
-                "HTTP-Referer": "https://snowchat.streamlit.app/",
-                "X-Title": "Snowchat",
-            },
-        )
+        self.llm = self._setup_llm()
+
+    def _setup_llm(self):
+        # Per-model settings: temperature and max_tokens are carried per model
+        # so the mixtral values (0.2 / 3000) from the old setup_mixtral_8x7b
+        # are preserved instead of being flattened to 0.1 / 700.
+        model_config = {
+            "qwen": {
+                "model_name": "qwen/qwen-2-72b-instruct",
+                "api_key": self.secrets["OPENROUTER_API_KEY"],
+                "base_url": "https://openrouter.ai/api/v1",
+                "temperature": 0.1,
+                "max_tokens": 700,
+            },
+            "claude": {
+                "model_name": "anthropic/claude-3-haiku",
+                "api_key": self.secrets["OPENROUTER_API_KEY"],
+                "base_url": "https://openrouter.ai/api/v1",
+                "temperature": 0.1,
+                "max_tokens": 700,
+            },
+            "mixtral8x7b": {
+                "model_name": "mixtral-8x7b-32768",
+                "api_key": self.secrets["GROQ_API_KEY"],
+                "base_url": "https://api.groq.com/openai/v1",
+                "temperature": 0.2,
+                "max_tokens": 3000,
+            },
+            "llama": {
+                "model_name": "meta-llama/llama-3-70b-instruct",
+                "api_key": self.secrets["OPENROUTER_API_KEY"],
+                "base_url": "https://openrouter.ai/api/v1",
+                "temperature": 0.1,
+                "max_tokens": 700,
+            },
+            "arctic": {
+                "model_name": "snowflake/snowflake-arctic-instruct",
+                "api_key": self.secrets["OPENROUTER_API_KEY"],
+                "base_url": "https://openrouter.ai/api/v1",
+                "temperature": 0.1,
+                "max_tokens": 700,
+            },
+        }
+
+        config = model_config[self.model_type]
+
+        return ChatOpenAI(
+            model_name=config["model_name"],
+            temperature=config["temperature"],
+            api_key=config["api_key"],
+            max_tokens=config["max_tokens"],
+            callbacks=[self.callback_handler],
+            streaming=True,
+            base_url=config["base_url"],
+            default_headers={
+                "HTTP-Referer": "https://snowchat.streamlit.app/",
+                "X-Title": "Snowchat",
+            },
+        )
@@ -154,7 +127,7 @@ def _combine_documents(
     return conversational_qa_chain
 
 
-def load_chain(model_name="GPT-3.5", callback_handler=None):
+def load_chain(model_name="Qwen 2-72B", callback_handler=None):
     embeddings = OpenAIEmbeddings(
         openai_api_key=st.secrets["OPENAI_API_KEY"], model="text-embedding-ada-002"
     )
@@ -165,17 +138,16 @@ def load_chain(model_name="GPT-3.5", callback_handler=None):
         query_name="v_match_documents",
     )
 
-    if "GPT-3.5" in model_name:
-        model_type = "gpt"
-    elif "mixtral 8x7b" in model_name.lower():
-        model_type = "mixtral8x7b"
-    elif "claude" in model_name.lower():
-        model_type = "claude"
-    elif "llama" in model_name.lower():
-        model_type = "llama"
-    elif "arctic" in model_name.lower():
-        model_type = "arctic"
-    else:
+    model_type_mapping = {
+        "qwen 2-72b": "qwen",
+        "mixtral 8x7b": "mixtral8x7b",
+        "claude-3 haiku": "claude",
+        "llama 3-70b": "llama",
+        "snowflake arctic": "arctic",
+    }
+
+    model_type = model_type_mapping.get(model_name.lower())
+    if model_type is None:
         raise ValueError(f"Unsupported model name: {model_name}")
 
     config = ModelConfig(
diff --git a/main.py b/main.py
index 275f9c2..e193f6e 100644
--- a/main.py
+++ b/main.py
@@ -34,7 +34,13 @@ st.caption("Talk your way through data")
 
 model = st.radio(
     "",
-    options=["Claude-3 Haiku", "Mixtral 8x7B", "Llama 3-70B", "GPT-3.5", "Snowflake Arctic"],
+    options=[
+        "Claude-3 Haiku",
+        "Mixtral 8x7B",
+        "Llama 3-70B",
+        "Qwen 2-72B",
+        "Snowflake Arctic",
+    ],
     index=0,
     horizontal=True,
 )
diff --git a/utils/snowchat_ui.py b/utils/snowchat_ui.py
index 46d8f8e..11ed874 100644
--- a/utils/snowchat_ui.py
+++ b/utils/snowchat_ui.py
@@ -7,7 +7,10 @@
 image_url = f"{st.secrets['SUPABASE_STORAGE_URL']}/storage/v1/object/public/snowchat/"
 gemini_url = image_url + "google-gemini-icon.png?t=2024-05-07T21%3A17%3A52.235Z"
-mistral_url = image_url + "mistral-ai-icon-logo-B3319DCA6B-seeklogo.com.png?t=2024-05-07T21%3A18%3A22.737Z"
+mistral_url = (
+    image_url
+    + "mistral-ai-icon-logo-B3319DCA6B-seeklogo.com.png?t=2024-05-07T21%3A18%3A22.737Z"
+)
 openai_url = (
     image_url
     + "png-transparent-openai-chatgpt-logo-thumbnail.png?t=2024-05-07T21%3A18%3A44.079Z"
@@ -16,10 +19,12 @@
 claude_url = image_url + "Claude.png?t=2024-05-07T21%3A16%3A17.252Z"
 meta_url = image_url + "meta-logo.webp?t=2024-05-07T21%3A18%3A12.286Z"
 snow_url = image_url + "Snowflake_idCkdSg0B6_6.png?t=2024-05-07T21%3A24%3A02.597Z"
+qwen_url = image_url + "qwen.png?t=2024-06-07T08%3A51%3A36.363Z"
+
 
 def get_model_url(model_name):
-    if "gpt" in model_name.lower():
-        return openai_url
+    if "qwen" in model_name.lower():
+        return qwen_url
     elif "claude" in model_name.lower():
         return claude_url
     elif "llama" in model_name.lower():