11import os
22from langchain_core .language_models .chat_models import BaseChatModel
3- from langchain_anthropic import ChatAnthropic
4- from langchain_openai import ChatOpenAI
5- from langchain_google_genai import ChatGoogleGenerativeAI
6- from langchain_mistralai import ChatMistralAI
7- from langchain_fireworks import ChatFireworks
8- from langchain_together import ChatTogether
9- from langchain_google_vertexai import ChatVertexAI
10- from langchain_groq import ChatGroq
11- from langchain_nvidia_ai_endpoints import ChatNVIDIA
12- from langchain_ollama import ChatOllama
13- from langchain_ai21 import ChatAI21
14- from langchain_upstage import ChatUpstage
15- from langchain_databricks import ChatDatabricks
16- from langchain_ibm import ChatWatsonx
17- from langchain_xai import ChatXAI
183from typing import List
194
205
@@ -24,12 +9,12 @@ def get_supported_providers() -> List[str]:
249 "anthropic" ,
2510 "openai" ,
2611 "google" ,
27- "mistralai " ,
12+ "mistral " ,
2813 "fireworks" ,
2914 "together" ,
30- "vertexai " ,
15+ "vertex " ,
3116 "groq" ,
32- "nvidia_ai " ,
17+ "nvidia " ,
3318 "ollama" ,
3419 "ai21" ,
3520 "upstage" ,
@@ -41,30 +26,55 @@ def get_supported_providers() -> List[str]:
4126
4227def get_model (provider_name : str | None , model_name : str | None ) -> BaseChatModel :
4328 """Initialize and configure the LangChain components with specified model."""
44- providers = {
45- "anthropic" : (ChatAnthropic , "ANTHROPIC_API_KEY" , "https://www.anthropic.com" , "anthropic_api_key" ),
46- "openai" : (ChatOpenAI , "OPENAI_API_KEY" , "https://platform.openai.com/account/api-keys" , "openai_api_key" ),
47- "google" : (ChatGoogleGenerativeAI , "GOOGLE_API_KEY" , "https://developers.generativeai.google/" , "google_api_key" ),
48- "mistral" : (ChatMistralAI , "MISTRAL_API_KEY" , "https://console.mistral.ai/api-keys/" , "mistral_api_key" ),
49- "fireworks" : (ChatFireworks , "FIREWORKS_API_KEY" , "https://app.fireworks.ai/" , "fireworks_api_key" ),
50- "together" : (ChatTogether , "TOGETHER_API_KEY" , "https://api.together.xyz/" , "together_api_key" ),
51- "vertex" : (ChatVertexAI , "GOOGLE_APPLICATION_CREDENTIALS" , "https://cloud.google.com/vertex-ai" , None ),
52- "groq" : (ChatGroq , "GROQ_API_KEY" , "https://console.groq.com/" , "groq_api_key" ),
53- "nvidia" : (ChatNVIDIA , "NVIDIA_API_KEY" , "https://api.nvidia.com/" , "nvidia_api_key" ),
54- "ollama" : (ChatOllama , None , "https://ollama.ai/" , None ),
55- "ai21" : (ChatAI21 , "AI21_API_KEY" , "https://www.ai21.com/studio" , "ai21_api_key" ),
56- "upstage" : (ChatUpstage , "UPSTAGE_API_KEY" , "https://upstage.ai/" , "upstage_api_key" ),
57- "databricks" : (ChatDatabricks , "DATABRICKS_TOKEN" , "https://www.databricks.com/" , "databricks_token" ),
58- "watsonx" : (ChatWatsonx , "WATSONX_API_KEY" , "https://www.ibm.com/watsonx" , "watsonx_api_key" ),
59- "xai" : (ChatXAI , "XAI_API_KEY" , "https://xai.com/" , "xai_api_key" ),
60- }
61-
62- max_tokens = 500
63-
64- if provider_name not in providers :
29+ if provider_name == 'anthropic' :
30+ from langchain_anthropic import ChatAnthropic
31+ model_class , api_key_env , api_url , api_key_param = ChatAnthropic , "ANTHROPIC_API_KEY" , "https://www.anthropic.com" , "anthropic_api_key"
32+ elif provider_name == 'openai' :
33+ from langchain_openai import ChatOpenAI
34+ model_class , api_key_env , api_url , api_key_param = ChatOpenAI , "OPENAI_API_KEY" , "https://platform.openai.com/account/api-keys" , "openai_api_key"
35+ elif provider_name == 'google' :
36+ from langchain_google_genai import ChatGoogleGenerativeAI
37+ model_class , api_key_env , api_url , api_key_param = ChatGoogleGenerativeAI , "GOOGLE_API_KEY" , "https://developers.generativeai.google/" , "google_api_key"
38+ elif provider_name == 'mistral' :
39+ from langchain_mistralai import ChatMistralAI
40+ model_class , api_key_env , api_url , api_key_param = ChatMistralAI , "MISTRAL_API_KEY" , "https://console.mistral.ai/api-keys/" , "mistral_api_key"
41+ elif provider_name == 'fireworks' :
42+ from langchain_fireworks import ChatFireworks
43+ model_class , api_key_env , api_url , api_key_param = ChatFireworks , "FIREWORKS_API_KEY" , "https://app.fireworks.ai/" , "fireworks_api_key"
44+ elif provider_name == 'together' :
45+ from langchain_together import ChatTogether
46+ model_class , api_key_env , api_url , api_key_param = ChatTogether , "TOGETHER_API_KEY" , "https://api.together.xyz/" , "together_api_key"
47+ elif provider_name == 'vertex' :
48+ from langchain_google_vertexai import ChatVertexAI
49+ model_class , api_key_env , api_url , api_key_param = ChatVertexAI , "GOOGLE_APPLICATION_CREDENTIALS" , "https://cloud.google.com/vertex-ai" , None
50+ elif provider_name == 'groq' :
51+ from langchain_groq import ChatGroq
52+ model_class , api_key_env , api_url , api_key_param = ChatGroq , "GROQ_API_KEY" , "https://console.groq.com/" , "groq_api_key"
53+ elif provider_name == 'nvidia' :
54+ from langchain_nvidia_ai_endpoints import ChatNVIDIA
55+ model_class , api_key_env , api_url , api_key_param = ChatNVIDIA , "NVIDIA_API_KEY" , "https://api.nvidia.com/" , "nvidia_api_key"
56+ elif provider_name == 'ollama' :
57+ from langchain_ollama import ChatOllama
58+ model_class , api_key_env , api_url , api_key_param = ChatOllama , None , "https://ollama.ai/" , None
59+ elif provider_name == 'ai21' :
60+ from langchain_ai21 import ChatAI21
61+ model_class , api_key_env , api_url , api_key_param = ChatAI21 , "AI21_API_KEY" , "https://www.ai21.com/studio" , "ai21_api_key"
62+ elif provider_name == 'upstage' :
63+ from langchain_upstage import ChatUpstage
64+ model_class , api_key_env , api_url , api_key_param = ChatUpstage , "UPSTAGE_API_KEY" , "https://upstage.ai/" , "upstage_api_key"
65+ elif provider_name == 'databricks' :
66+ from langchain_databricks import ChatDatabricks
67+ model_class , api_key_env , api_url , api_key_param = ChatDatabricks , "DATABRICKS_TOKEN" , "https://www.databricks.com/" , "databricks_token"
68+ elif provider_name == 'watsonx' :
69+ from langchain_watsonx import ChatWatsonx
70+ model_class , api_key_env , api_url , api_key_param = ChatWatsonx , "WATSONX_API_KEY" , "https://www.ibm.com/watsonx" , "watsonx_api_key"
71+ elif provider_name == 'xai' :
72+ from langchain_xai import ChatXAI
73+ model_class , api_key_env , api_url , api_key_param = ChatXAI , "XAI_API_KEY" , "https://xai.com/" , "xai_api_key"
74+ else :
6575 raise ValueError (f"Unsupported LLM provider: { provider_name } " )
6676
67- model_class , api_key_env , api_url , api_key_param = providers [ provider_name ]
77+ max_tokens = 500
6878
6979 if api_key_env :
7080 api_key = os .getenv (api_key_env )
0 commit comments