12
12
from langchain_openai import AzureChatOpenAI , ChatOpenAI
13
13
14
14
from .llm import DeepSeekR1ChatOpenAI , DeepSeekR1ChatOllama
15
-
16
- PROVIDER_DISPLAY_NAMES = {
17
- "openai" : "OpenAI" ,
18
- "azure_openai" : "Azure OpenAI" ,
19
- "anthropic" : "Anthropic" ,
20
- "deepseek" : "DeepSeek" ,
21
- "google" : "Google" ,
22
- "alibaba" : "Alibaba" ,
23
- "moonshot" : "MoonShot"
24
- }
15
+ from .llm_providers import MODEL_NAMES , PROVIDER_DISPLAY_NAMES , get_config_value
25
16
26
17
27
18
def get_llm_model(provider: str, **kwargs):
    """
    Build and return a configured chat-model client for the given provider.

    :param provider: provider key, e.g. "openai", "anthropic", "mistral",
        "deepseek", "google", "ollama", "azure_openai", "alibaba", "moonshot".
    :param kwargs: optional overrides consumed via ``get_config_value``
        (api_key, base_url, model, api_version) plus direct options
        (temperature, num_ctx, num_predict).
    :return: a chat model instance for the provider.
    :raises MissingAPIKeyError: when a required API key cannot be resolved.
    :raises ValueError: when the provider is not supported.
    """
    api_key = None
    # Ollama is a local runtime and needs no API key; every other provider does.
    if provider not in {"ollama"}:
        env_var = f"{provider.upper()}_API_KEY"
        api_key = get_config_value(provider, "api_key", **kwargs)
        if not api_key:
            raise MissingAPIKeyError(provider, env_var)

    base_url = get_config_value(provider, "base_url", **kwargs)
    model_name = get_config_value(provider, "model", **kwargs)
    temperature = kwargs.get("temperature", 0.0)

    # Parameters shared by all providers; branches below add or remove
    # keys as each client's constructor requires.
    common_params = {
        "model": model_name,
        "temperature": temperature,
        "base_url": base_url,
        "api_key": api_key,
    }

    if provider == "anthropic":
        return ChatAnthropic(**common_params)
    elif provider == "mistral":
        return ChatMistralAI(**common_params)
    elif provider in {"openai", "alibaba", "moonshot"}:
        # These providers expose OpenAI-compatible HTTP endpoints.
        return ChatOpenAI(**common_params)
    elif provider == "deepseek":
        # The reasoner model emits reasoning tokens and needs the dedicated wrapper.
        if model_name == "deepseek-reasoner":
            return DeepSeekR1ChatOpenAI(**common_params)
        return ChatOpenAI(**common_params)
    elif provider == "google":
        # ChatGoogleGenerativeAI does not take a base_url argument.
        common_params.pop("base_url", None)
        return ChatGoogleGenerativeAI(**common_params)
    elif provider == "ollama":
        # Local Ollama server: no API key; pass the context-window size instead.
        common_params.pop("api_key", None)
        common_params["num_ctx"] = kwargs.get("num_ctx", 32000)
        if "deepseek-r1" in model_name:
            # Fix: keep the already-resolved model_name. Re-reading
            # kwargs["model_name"] here used a different key than the
            # "model" key resolved above and could silently replace the
            # requested model with the "deepseek-r1:14b" default.
            return DeepSeekR1ChatOllama(**common_params)
        else:
            common_params["num_predict"] = kwargs.get("num_predict", 1024)
            return ChatOllama(**common_params)
    elif provider == "azure_openai":
        common_params["api_version"] = get_config_value(provider, "api_version", **kwargs)
        # Azure's client takes azure_endpoint rather than base_url.
        common_params["azure_endpoint"] = common_params.pop("base_url", None)
        return AzureChatOpenAI(**common_params)
    else:
        raise ValueError(f"Unsupported provider: {provider}")
163
73
164
74
165
- # Predefined model names for common providers
166
- model_names = {
167
- "anthropic" : ["claude-3-5-sonnet-20241022" , "claude-3-5-sonnet-20240620" , "claude-3-opus-20240229" ],
168
- "openai" : ["gpt-4o" , "gpt-4" , "gpt-3.5-turbo" , "o3-mini" ],
169
- "deepseek" : ["deepseek-chat" , "deepseek-reasoner" ],
170
- "google" : ["gemini-2.0-flash" , "gemini-2.0-flash-thinking-exp" , "gemini-1.5-flash-latest" ,
171
- "gemini-1.5-flash-8b-latest" , "gemini-2.0-flash-thinking-exp-01-21" , "gemini-2.0-pro-exp-02-05" ],
172
- "ollama" : ["qwen2.5:7b" , "qwen2.5:14b" , "qwen2.5:32b" , "qwen2.5-coder:14b" , "qwen2.5-coder:32b" , "llama2:7b" ,
173
- "deepseek-r1:14b" , "deepseek-r1:32b" ],
174
- "azure_openai" : ["gpt-4o" , "gpt-4" , "gpt-3.5-turbo" ],
175
- "mistral" : ["pixtral-large-latest" , "mistral-large-latest" , "mistral-small-latest" , "ministral-8b-latest" ],
176
- "alibaba" : ["qwen-plus" , "qwen-max" , "qwen-turbo" , "qwen-long" ],
177
- "moonshot" : ["moonshot-v1-32k-vision-preview" , "moonshot-v1-8k-vision-preview" ],
178
- }
179
-
180
-
181
75
# Callback to update the model name dropdown based on the selected provider
def update_model_dropdown(llm_provider, api_key=None, base_url=None):
    """
    Return a gradio Dropdown of known model names for the selected provider.

    :param llm_provider: provider key used to look up ``MODEL_NAMES``.
    :param api_key: optional API key; falls back to the configured value.
    :param base_url: optional base URL; falls back to the configured value.
    :return: a ``gr.Dropdown`` pre-populated with the provider's models, or
        an empty free-text dropdown when the provider has no predefined list.
    """
    import gradio as gr
    # Use API keys from .env if not provided.
    # Fix: the config keys previously carried trailing whitespace
    # ("api_key ", "base_url "), which made the lookups silently miss.
    if not api_key:
        api_key = get_config_value(llm_provider, "api_key")
    if not base_url:
        base_url = get_config_value(llm_provider, "base_url")

    # Use predefined models for the selected provider
    if llm_provider in MODEL_NAMES:
        return gr.Dropdown(choices=MODEL_NAMES[llm_provider], value=MODEL_NAMES[llm_provider][0], interactive=True)
    return gr.Dropdown(choices=[], value="", interactive=True, allow_custom_value=True)
198
91
199
92
class MissingAPIKeyError (Exception ):
200
93
"""Custom exception for missing API key."""
0 commit comments