Skip to content

Commit 8b98efb

Browse files
committed
refactor: move LLM configs to llm_providers.py
1 parent 0c29506 commit 8b98efb

File tree

4 files changed

+138
-126
lines changed

4 files changed

+138
-126
lines changed

src/utils/llm_providers.py

+84
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
import os
2+
3+
# Human-readable provider labels, used in UI text and error messages
# (see handle_api_key_error in src/utils/utils.py).
# Keys mirror PROVIDER_CONFIGS below; "ollama" is included for consistency
# even though it never triggers an API-key error (it is a local server).
PROVIDER_DISPLAY_NAMES = {
    "openai": "OpenAI",
    "azure_openai": "Azure OpenAI",
    "anthropic": "Anthropic",
    "deepseek": "DeepSeek",
    "google": "Google",
    "mistral": "Mistral",
    "alibaba": "Alibaba",
    "moonshot": "MoonShot",
    "ollama": "Ollama",
}
13+
14+
# Per-provider connection settings consumed by get_config_value().
# Recognized keys:
#   api_key_env      - environment variable that holds the API key
#   base_url_env     - environment variable that overrides the endpoint
#   default_base_url - endpoint used when the env var is unset/empty
#   default_model    - model used when the caller supplies none
#   api_version_env / default_api_version - Azure-only API-version lookup
# Notes grounded in this file:
#   - "google" has no base_url_env/default_base_url (SDK default endpoint).
#   - "azure_openai" has no default_base_url: the endpoint must come from
#     AZURE_OPENAI_ENDPOINT (or be passed explicitly by the caller).
#   - "ollama" has no api_key_env: it is a local server and needs no key.
PROVIDER_CONFIGS = {
    "openai": {
        "api_key_env": "OPENAI_API_KEY", "base_url_env": "OPENAI_ENDPOINT",
        "default_base_url": "https://api.openai.com/v1", "default_model": "gpt-4o"
    },
    "azure_openai": {
        "api_key_env": "AZURE_OPENAI_API_KEY", "base_url_env": "AZURE_OPENAI_ENDPOINT",
        "api_version_env": "AZURE_OPENAI_API_VERSION",
        "default_api_version": "2025-01-01-preview", "default_model": "gpt-4o"
    },
    "anthropic": {
        "api_key_env": "ANTHROPIC_API_KEY", "base_url_env": "ANTHROPIC_ENDPOINT",
        "default_base_url": "https://api.anthropic.com", "default_model": "claude-3-5-sonnet-20241022"
    },
    "google": {
        "api_key_env": "GOOGLE_API_KEY",
        "default_model": "gemini-2.0-flash-exp"
    },
    "deepseek": {
        "api_key_env": "DEEPSEEK_API_KEY", "base_url_env": "DEEPSEEK_ENDPOINT",
        "default_base_url": "https://api.deepseek.com", "default_model": "deepseek-chat"
    },
    "mistral": {
        "api_key_env": "MISTRAL_API_KEY", "base_url_env": "MISTRAL_ENDPOINT",
        "default_base_url": "https://api.mistral.ai/v1", "default_model": "mistral-large-latest"
    },
    "alibaba": {
        "api_key_env": "ALIBABA_API_KEY", "base_url_env": "ALIBABA_ENDPOINT",
        "default_base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1", "default_model": "qwen-plus"
    },
    "moonshot": {
        "api_key_env": "MOONSHOT_API_KEY", "base_url_env": "MOONSHOT_ENDPOINT",
        "default_base_url": "https://api.moonshot.cn/v1", "default_model": "moonshot-v1-32k-vision-preview"
    },
    "ollama": {
        "base_url_env": "OLLAMA_ENDPOINT",
        "default_base_url": "http://localhost:11434", "default_model": "qwen2.5:7b"
    }
}
53+
54+
# Predefined model names for common providers.
# Used to populate the model-name dropdown (update_model_dropdown in
# src/utils/utils.py and the LLM Settings tab in webui.py); the first
# entry of each list is the dropdown's initial value. Providers absent
# here fall back to a free-text dropdown.
MODEL_NAMES = {
    "anthropic": ["claude-3-5-sonnet-20241022", "claude-3-5-sonnet-20240620", "claude-3-opus-20240229"],
    "openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo", "o3-mini"],
    "deepseek": ["deepseek-chat", "deepseek-reasoner"],
    "google": ["gemini-2.0-flash", "gemini-2.0-flash-thinking-exp", "gemini-1.5-flash-latest",
               "gemini-1.5-flash-8b-latest", "gemini-2.0-flash-thinking-exp-01-21", "gemini-2.0-pro-exp-02-05"],
    "ollama": ["qwen2.5:7b", "qwen2.5:14b", "qwen2.5:32b", "qwen2.5-coder:14b", "qwen2.5-coder:32b", "llama2:7b",
               "deepseek-r1:14b", "deepseek-r1:32b"],
    "azure_openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo"],
    "mistral": ["pixtral-large-latest", "mistral-large-latest", "mistral-small-latest", "ministral-8b-latest"],
    "alibaba": ["qwen-plus", "qwen-max", "qwen-turbo", "qwen-long"],
    "moonshot": ["moonshot-v1-32k-vision-preview", "moonshot-v1-8k-vision-preview"],
}
68+
69+
def get_provider_config(provider: str):
    """Return the configuration dict for *provider*.

    Unknown providers yield an empty dict so callers can chain .get()
    lookups without special-casing.
    """
    try:
        return PROVIDER_CONFIGS[provider]
    except KeyError:
        return {}
71+
72+
def get_config_value(provider: str, key: str, **kwargs):
    """Resolve a provider setting with precedence: kwargs > env var > default.

    :param provider: provider id, a key of PROVIDER_CONFIGS.
    :param key: setting name, e.g. "api_key", "base_url", "model",
        "api_version".
    :param kwargs: caller overrides. Both the bare key and the
        "<key>_name" alias are honored, because existing callers
        (get_llm_model, test_llm) pass ``model_name=...`` while this
        resolver is invoked with ``key="model"``; without the alias the
        user-selected model would be silently ignored in favor of
        ``default_model``.
    :return: the resolved value, or None when nothing is configured.
    """
    config = get_provider_config(provider)

    # 1) Explicit caller override (truthy values only, so "" falls through).
    for candidate in (key, f"{key}_name"):
        value = kwargs.get(candidate)
        if value:
            return value

    # 2) Environment variable named by "<key>_env" in the provider config.
    env_key = config.get(f"{key}_env")
    if env_key:
        env_value = os.getenv(env_key)
        if env_value:
            return env_value

    # 3) Static fallback from "default_<key>" (None if absent).
    return config.get(f"default_{key}")

src/utils/utils.py

+48-106
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,10 @@
1414

1515
from .llm import DeepSeekR1ChatOpenAI, DeepSeekR1ChatOllama
1616

17-
PROVIDER_DISPLAY_NAMES = {
18-
"openai": "OpenAI",
19-
"azure_openai": "Azure OpenAI",
20-
"anthropic": "Anthropic",
21-
"deepseek": "DeepSeek",
22-
"google": "Google",
23-
"alibaba": "Alibaba",
24-
"moonshot": "MoonShot"
25-
}
17+
from .llm_providers import (
18+
get_provider_config, get_config_value,
19+
PROVIDER_DISPLAY_NAMES, MODEL_NAMES
20+
)
2621

2722

2823
def get_llm_model(provider: str, **kwargs):
@@ -32,176 +27,123 @@ def get_llm_model(provider: str, **kwargs):
3227
:param kwargs:
3328
:return:
3429
"""
35-
if provider not in ["ollama"]:
36-
env_var = f"{provider.upper()}_API_KEY"
37-
api_key = kwargs.get("api_key", "") or os.getenv(env_var, "")
30+
if provider not in {"ollama"}:
31+
api_key = get_config_value(provider, "api_key", **kwargs)
3832
if not api_key:
39-
handle_api_key_error(provider, env_var)
40-
kwargs["api_key"] = api_key
33+
handle_api_key_error(provider)
4134

42-
if provider == "anthropic":
43-
if not kwargs.get("base_url", ""):
44-
base_url = "https://api.anthropic.com"
45-
else:
46-
base_url = kwargs.get("base_url")
35+
base_url = get_config_value(provider, "base_url", **kwargs)
36+
model_name = get_config_value(provider, "model", **kwargs)
37+
temperature = kwargs.get("temperature", 0.0)
4738

39+
if provider == "anthropic":
4840
return ChatAnthropic(
49-
model=kwargs.get("model_name", "claude-3-5-sonnet-20241022"),
50-
temperature=kwargs.get("temperature", 0.0),
41+
model=model_name,
42+
temperature=temperature,
5143
base_url=base_url,
5244
api_key=api_key,
5345
)
54-
elif provider == 'mistral':
55-
if not kwargs.get("base_url", ""):
56-
base_url = os.getenv("MISTRAL_ENDPOINT", "https://api.mistral.ai/v1")
57-
else:
58-
base_url = kwargs.get("base_url")
59-
if not kwargs.get("api_key", ""):
60-
api_key = os.getenv("MISTRAL_API_KEY", "")
61-
else:
62-
api_key = kwargs.get("api_key")
63-
46+
elif provider == "mistral":
6447
return ChatMistralAI(
65-
model=kwargs.get("model_name", "mistral-large-latest"),
66-
temperature=kwargs.get("temperature", 0.0),
48+
model=model_name,
49+
temperature=temperature,
6750
base_url=base_url,
6851
api_key=api_key,
6952
)
7053
elif provider == "openai":
71-
if not kwargs.get("base_url", ""):
72-
base_url = os.getenv("OPENAI_ENDPOINT", "https://api.openai.com/v1")
73-
else:
74-
base_url = kwargs.get("base_url")
75-
7654
return ChatOpenAI(
77-
model=kwargs.get("model_name", "gpt-4o"),
78-
temperature=kwargs.get("temperature", 0.0),
55+
model=model_name,
56+
temperature=temperature,
7957
base_url=base_url,
8058
api_key=api_key,
8159
)
8260
elif provider == "deepseek":
83-
if not kwargs.get("base_url", ""):
84-
base_url = os.getenv("DEEPSEEK_ENDPOINT", "")
85-
else:
86-
base_url = kwargs.get("base_url")
87-
88-
if kwargs.get("model_name", "deepseek-chat") == "deepseek-reasoner":
61+
if model_name == "deepseek-reasoner":
8962
return DeepSeekR1ChatOpenAI(
90-
model=kwargs.get("model_name", "deepseek-reasoner"),
91-
temperature=kwargs.get("temperature", 0.0),
63+
model=model_name,
64+
temperature=temperature,
9265
base_url=base_url,
9366
api_key=api_key,
9467
)
9568
else:
9669
return ChatOpenAI(
97-
model=kwargs.get("model_name", "deepseek-chat"),
98-
temperature=kwargs.get("temperature", 0.0),
70+
model=model_name,
71+
temperature=temperature,
9972
base_url=base_url,
10073
api_key=api_key,
10174
)
10275
elif provider == "google":
10376
return ChatGoogleGenerativeAI(
104-
model=kwargs.get("model_name", "gemini-2.0-flash-exp"),
105-
temperature=kwargs.get("temperature", 0.0),
77+
model=model_name,
78+
temperature=temperature,
10679
api_key=api_key,
10780
)
10881
elif provider == "ollama":
109-
if not kwargs.get("base_url", ""):
110-
base_url = os.getenv("OLLAMA_ENDPOINT", "http://localhost:11434")
111-
else:
112-
base_url = kwargs.get("base_url")
113-
114-
if "deepseek-r1" in kwargs.get("model_name", "qwen2.5:7b"):
82+
num_ctx = kwargs.get("num_ctx", 32000)
83+
if "deepseek-r1" in model_name:
11584
return DeepSeekR1ChatOllama(
11685
model=kwargs.get("model_name", "deepseek-r1:14b"),
117-
temperature=kwargs.get("temperature", 0.0),
118-
num_ctx=kwargs.get("num_ctx", 32000),
86+
temperature=temperature,
87+
num_ctx=num_ctx,
11988
base_url=base_url,
12089
)
12190
else:
12291
return ChatOllama(
123-
model=kwargs.get("model_name", "qwen2.5:7b"),
124-
temperature=kwargs.get("temperature", 0.0),
125-
num_ctx=kwargs.get("num_ctx", 32000),
92+
model=model_name,
93+
temperature=temperature,
94+
num_ctx=num_ctx,
12695
num_predict=kwargs.get("num_predict", 1024),
12796
base_url=base_url,
12897
)
12998
elif provider == "azure_openai":
130-
if not kwargs.get("base_url", ""):
131-
base_url = os.getenv("AZURE_OPENAI_ENDPOINT", "")
132-
else:
133-
base_url = kwargs.get("base_url")
134-
api_version = kwargs.get("api_version", "") or os.getenv("AZURE_OPENAI_API_VERSION", "2025-01-01-preview")
99+
api_version = get_config_value(provider, "api_version", **kwargs)
135100
return AzureChatOpenAI(
136-
model=kwargs.get("model_name", "gpt-4o"),
137-
temperature=kwargs.get("temperature", 0.0),
101+
model=model_name,
102+
temperature=temperature,
138103
api_version=api_version,
139104
azure_endpoint=base_url,
140105
api_key=api_key,
141106
)
142107
elif provider == "alibaba":
143-
if not kwargs.get("base_url", ""):
144-
base_url = os.getenv("ALIBABA_ENDPOINT", "https://dashscope.aliyuncs.com/compatible-mode/v1")
145-
else:
146-
base_url = kwargs.get("base_url")
147-
148108
return ChatOpenAI(
149-
model=kwargs.get("model_name", "qwen-plus"),
150-
temperature=kwargs.get("temperature", 0.0),
109+
model=model_name,
110+
temperature=temperature,
151111
base_url=base_url,
152112
api_key=api_key,
153113
)
154114

155115
elif provider == "moonshot":
156116
return ChatOpenAI(
157-
model=kwargs.get("model_name", "moonshot-v1-32k-vision-preview"),
158-
temperature=kwargs.get("temperature", 0.0),
159-
base_url=os.getenv("MOONSHOT_ENDPOINT"),
160-
api_key=os.getenv("MOONSHOT_API_KEY"),
117+
model=model_name,
118+
temperature=temperature,
119+
base_url=base_url,
120+
api_key=api_key,
161121
)
162122
else:
163123
raise ValueError(f"Unsupported provider: {provider}")
164124

165125

166-
# Predefined model names for common providers
167-
model_names = {
168-
"anthropic": ["claude-3-5-sonnet-20241022", "claude-3-5-sonnet-20240620", "claude-3-opus-20240229"],
169-
"openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo", "o3-mini"],
170-
"deepseek": ["deepseek-chat", "deepseek-reasoner"],
171-
"google": ["gemini-2.0-flash", "gemini-2.0-flash-thinking-exp", "gemini-1.5-flash-latest",
172-
"gemini-1.5-flash-8b-latest", "gemini-2.0-flash-thinking-exp-01-21", "gemini-2.0-pro-exp-02-05"],
173-
"ollama": ["qwen2.5:7b", "qwen2.5:14b", "qwen2.5:32b", "qwen2.5-coder:14b", "qwen2.5-coder:32b", "llama2:7b",
174-
"deepseek-r1:14b", "deepseek-r1:32b"],
175-
"azure_openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo"],
176-
"mistral": ["mixtral-large-latest", "mistral-large-latest", "mistral-small-latest", "ministral-8b-latest"],
177-
"alibaba": ["qwen-plus", "qwen-max", "qwen-turbo", "qwen-long"],
178-
"moonshot": ["moonshot-v1-32k-vision-preview", "moonshot-v1-8k-vision-preview"],
179-
}
180-
181-
182-
# Callback to update the model name dropdown based on the selected provider
183126
def update_model_dropdown(llm_provider, api_key=None, base_url=None):
184127
"""
185128
Update the model name dropdown with predefined models for the selected provider.
186129
"""
187130
# Use API keys from .env if not provided
188131
if not api_key:
189-
api_key = os.getenv(f"{llm_provider.upper()}_API_KEY", "")
132+
api_key = get_config_value(llm_provider, "api_key")
190133
if not base_url:
191-
base_url = os.getenv(f"{llm_provider.upper()}_BASE_URL", "")
134+
base_url = get_config_value(llm_provider, "base_url")
192135

193136
# Use predefined models for the selected provider
194-
if llm_provider in model_names:
195-
return gr.Dropdown(choices=model_names[llm_provider], value=model_names[llm_provider][0], interactive=True)
196-
else:
197-
return gr.Dropdown(choices=[], value="", interactive=True, allow_custom_value=True)
198-
137+
if llm_provider in MODEL_NAMES:
138+
return gr.Dropdown(choices=MODEL_NAMES[llm_provider], value=MODEL_NAMES[llm_provider][0], interactive=True)
139+
return gr.Dropdown(choices=[], value="", interactive=True, allow_custom_value=True)
199140

200-
def handle_api_key_error(provider: str, env_var: str):
141+
def handle_api_key_error(provider: str):
201142
"""
202143
Handles the missing API key error by raising a gr.Error with a clear message.
203144
"""
204145
provider_display = PROVIDER_DISPLAY_NAMES.get(provider, provider.upper())
146+
env_var = get_provider_config(provider).get("api_key_env")
205147
raise gr.Error(
206148
f"💥 {provider_display} API key not found! 🔑 Please set the "
207149
f"`{env_var}` environment variable or provide it in the UI."

tests/test_llm_api.py

+3-18
Original file line numberDiff line numberDiff line change
@@ -32,21 +32,6 @@ def create_message_content(text, image_path=None):
3232
})
3333
return content
3434

35-
def get_env_value(key, provider):
36-
env_mappings = {
37-
"openai": {"api_key": "OPENAI_API_KEY", "base_url": "OPENAI_ENDPOINT"},
38-
"azure_openai": {"api_key": "AZURE_OPENAI_API_KEY", "base_url": "AZURE_OPENAI_ENDPOINT"},
39-
"google": {"api_key": "GOOGLE_API_KEY"},
40-
"deepseek": {"api_key": "DEEPSEEK_API_KEY", "base_url": "DEEPSEEK_ENDPOINT"},
41-
"mistral": {"api_key": "MISTRAL_API_KEY", "base_url": "MISTRAL_ENDPOINT"},
42-
"alibaba": {"api_key": "ALIBABA_API_KEY", "base_url": "ALIBABA_ENDPOINT"},
43-
"moonshot":{"api_key": "MOONSHOT_API_KEY", "base_url": "MOONSHOT_ENDPOINT"},
44-
}
45-
46-
if provider in env_mappings and key in env_mappings[provider]:
47-
return os.getenv(env_mappings[provider][key], "")
48-
return ""
49-
5035
def test_llm(config, query, image_path=None, system_message=None):
5136
from src.utils import utils
5237

@@ -69,8 +54,8 @@ def test_llm(config, query, image_path=None, system_message=None):
6954
provider=config.provider,
7055
model_name=config.model_name,
7156
temperature=config.temperature,
72-
base_url=config.base_url or get_env_value("base_url", config.provider),
73-
api_key=config.api_key or get_env_value("api_key", config.provider)
57+
base_url=config.base_url,
58+
api_key=config.api_key
7459
)
7560

7661
# Prepare messages for non-Ollama models
@@ -130,7 +115,7 @@ def test_moonshot_model():
130115
# test_openai_model()
131116
# test_google_model()
132117
# test_azure_openai_model()
133-
#test_deepseek_model()
118+
# test_deepseek_model()
134119
# test_ollama_model()
135120
test_deepseek_r1_model()
136121
# test_deepseek_r1_ollama_model()

webui.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from langchain_ollama import ChatOllama
2525
from playwright.async_api import async_playwright
2626
from src.utils.agent_state import AgentState
27+
from src.utils.llm_providers import MODEL_NAMES
2728

2829
from src.utils import utils
2930
from src.agent.custom_agent import CustomAgent
@@ -796,14 +797,14 @@ def create_ui(config, theme_name="Ocean"):
796797
with gr.TabItem("🔧 LLM Settings", id=2):
797798
with gr.Group():
798799
llm_provider = gr.Dropdown(
799-
choices=[provider for provider, model in utils.model_names.items()],
800+
choices=[provider for provider,model in MODEL_NAMES.items()],
800801
label="LLM Provider",
801802
value=config['llm_provider'],
802803
info="Select your preferred language model provider"
803804
)
804805
llm_model_name = gr.Dropdown(
805806
label="Model Name",
806-
choices=utils.model_names['openai'],
807+
choices=MODEL_NAMES['openai'],
807808
value=config['llm_model_name'],
808809
interactive=True,
809810
allow_custom_value=True, # Allow users to input custom model names

0 commit comments

Comments
 (0)