14
14
15
15
from .llm import DeepSeekR1ChatOpenAI , DeepSeekR1ChatOllama
16
16
17
- PROVIDER_DISPLAY_NAMES = {
18
- "openai" : "OpenAI" ,
19
- "azure_openai" : "Azure OpenAI" ,
20
- "anthropic" : "Anthropic" ,
21
- "deepseek" : "DeepSeek" ,
22
- "google" : "Google" ,
23
- "alibaba" : "Alibaba" ,
24
- "moonshot" : "MoonShot"
25
- }
17
+ from .llm_providers import (
18
+ get_provider_config , get_config_value ,
19
+ PROVIDER_DISPLAY_NAMES , MODEL_NAMES
20
+ )
26
21
27
22
28
23
def get_llm_model(provider: str, **kwargs):
    """
    Build and return a LangChain chat-model instance for the given provider.

    :param provider: provider key, e.g. "openai", "anthropic", "mistral",
        "deepseek", "google", "ollama", "azure_openai", "alibaba", "moonshot".
    :param kwargs: optional overrides (api_key, base_url, model, temperature,
        num_ctx, num_predict, api_version, ...); anything missing is resolved
        from the provider config / environment via get_config_value.
    :return: a configured chat-model object for the requested provider.
    :raises ValueError: if the provider is not supported.
    """
    # Every provider except local Ollama requires an API key.
    if provider not in {"ollama"}:
        api_key = get_config_value(provider, "api_key", **kwargs)
        if not api_key:
            handle_api_key_error(provider)

    # Settings shared by all provider branches, resolved once.
    base_url = get_config_value(provider, "base_url", **kwargs)
    model_name = get_config_value(provider, "model", **kwargs)
    temperature = kwargs.get("temperature", 0.0)

    if provider == "anthropic":
        return ChatAnthropic(
            model=model_name,
            temperature=temperature,
            base_url=base_url,
            api_key=api_key,
        )
    elif provider == "mistral":
        return ChatMistralAI(
            model=model_name,
            temperature=temperature,
            base_url=base_url,
            api_key=api_key,
        )
    elif provider == "openai":
        return ChatOpenAI(
            model=model_name,
            temperature=temperature,
            base_url=base_url,
            api_key=api_key,
        )
    elif provider == "deepseek":
        if model_name == "deepseek-reasoner":
            # The reasoner model needs the custom wrapper that understands
            # DeepSeek-R1's separate reasoning output.
            return DeepSeekR1ChatOpenAI(
                model=model_name,
                temperature=temperature,
                base_url=base_url,
                api_key=api_key,
            )
        else:
            return ChatOpenAI(
                model=model_name,
                temperature=temperature,
                base_url=base_url,
                api_key=api_key,
            )
    elif provider == "google":
        # Google's client takes no base_url.
        return ChatGoogleGenerativeAI(
            model=model_name,
            temperature=temperature,
            api_key=api_key,
        )
    elif provider == "ollama":
        num_ctx = kwargs.get("num_ctx", 32000)
        if "deepseek-r1" in model_name:
            return DeepSeekR1ChatOllama(
                # Bug fix: previously passed kwargs.get("model_name", ...),
                # which could silently differ from the resolved model_name
                # that the branch condition just tested.
                model=model_name,
                temperature=temperature,
                num_ctx=num_ctx,
                base_url=base_url,
            )
        else:
            return ChatOllama(
                model=model_name,
                temperature=temperature,
                num_ctx=num_ctx,
                num_predict=kwargs.get("num_predict", 1024),
                base_url=base_url,
            )
    elif provider == "azure_openai":
        api_version = get_config_value(provider, "api_version", **kwargs)
        return AzureChatOpenAI(
            model=model_name,
            temperature=temperature,
            api_version=api_version,
            azure_endpoint=base_url,
            api_key=api_key,
        )
    elif provider == "alibaba":
        # DashScope exposes an OpenAI-compatible endpoint.
        return ChatOpenAI(
            model=model_name,
            temperature=temperature,
            base_url=base_url,
            api_key=api_key,
        )
    elif provider == "moonshot":
        # Moonshot is OpenAI-compatible as well.
        return ChatOpenAI(
            model=model_name,
            temperature=temperature,
            base_url=base_url,
            api_key=api_key,
        )
    else:
        raise ValueError(f"Unsupported provider: {provider}")
164
124
165
125
166
- # Predefined model names for common providers
167
- model_names = {
168
- "anthropic" : ["claude-3-5-sonnet-20241022" , "claude-3-5-sonnet-20240620" , "claude-3-opus-20240229" ],
169
- "openai" : ["gpt-4o" , "gpt-4" , "gpt-3.5-turbo" , "o3-mini" ],
170
- "deepseek" : ["deepseek-chat" , "deepseek-reasoner" ],
171
- "google" : ["gemini-2.0-flash" , "gemini-2.0-flash-thinking-exp" , "gemini-1.5-flash-latest" ,
172
- "gemini-1.5-flash-8b-latest" , "gemini-2.0-flash-thinking-exp-01-21" , "gemini-2.0-pro-exp-02-05" ],
173
- "ollama" : ["qwen2.5:7b" , "qwen2.5:14b" , "qwen2.5:32b" , "qwen2.5-coder:14b" , "qwen2.5-coder:32b" , "llama2:7b" ,
174
- "deepseek-r1:14b" , "deepseek-r1:32b" ],
175
- "azure_openai" : ["gpt-4o" , "gpt-4" , "gpt-3.5-turbo" ],
176
- "mistral" : ["mixtral-large-latest" , "mistral-large-latest" , "mistral-small-latest" , "ministral-8b-latest" ],
177
- "alibaba" : ["qwen-plus" , "qwen-max" , "qwen-turbo" , "qwen-long" ],
178
- "moonshot" : ["moonshot-v1-32k-vision-preview" , "moonshot-v1-8k-vision-preview" ],
179
- }
180
-
181
-
182
- # Callback to update the model name dropdown based on the selected provider
183
126
def update_model_dropdown(llm_provider, api_key=None, base_url=None):
    """
    Update the model-name dropdown with predefined models for the selected
    provider.

    :param llm_provider: provider key used to look up models in MODEL_NAMES.
    :param api_key: optional API key; falls back to the provider config (.env).
    :param base_url: optional endpoint; falls back to the provider config.
    :return: a gr.Dropdown listing the provider's known models (first one
        selected), or an empty dropdown allowing custom values for unknown
        providers.
    """
    # Fall back to configured values when not supplied by the UI.
    # Bug fix: the config keys previously had trailing spaces ("api_key ",
    # "base_url ") and could never match the keys used by get_llm_model.
    if not api_key:
        api_key = get_config_value(llm_provider, "api_key")
    if not base_url:
        base_url = get_config_value(llm_provider, "base_url")

    # Use predefined models for the selected provider when available.
    if llm_provider in MODEL_NAMES:
        models = MODEL_NAMES[llm_provider]
        return gr.Dropdown(choices=models, value=models[0], interactive=True)
    return gr.Dropdown(choices=[], value="", interactive=True, allow_custom_value=True)
199
140
200
- def handle_api_key_error (provider : str , env_var : str ):
141
+ def handle_api_key_error (provider : str ):
201
142
"""
202
143
Handles the missing API key error by raising a gr.Error with a clear message.
203
144
"""
204
145
provider_display = PROVIDER_DISPLAY_NAMES .get (provider , provider .upper ())
146
+ env_var = get_provider_config (provider ).get ("api_key_env" )
205
147
raise gr .Error (
206
148
f"💥 { provider_display } API key not found! 🔑 Please set the "
207
149
f"`{ env_var } ` environment variable or provide it in the UI."
0 commit comments