Skip to content

Commit 3af76d0

Browse files
committed
refactor: move LLM configs to llm_providers.py
1 parent f1a467a commit 3af76d0

File tree

4 files changed

+106
-166
lines changed

4 files changed

+106
-166
lines changed

src/utils/llm_providers.py

+61
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import os
2+
3+
# Human-readable labels for each LLM provider, keyed by the internal
# provider id used throughout the codebase. Key order is preserved as-is
# because consumers may iterate this mapping to build UI choices.
# NOTE(review): "ollama" appears in PROVIDER_CONFIGS and MODEL_NAMES but not
# here — confirm whether the omission is intentional (local provider with no
# API key) before adding it.
PROVIDER_DISPLAY_NAMES = {
    "openai": "OpenAI",
    "azure_openai": "Azure OpenAI",
    "anthropic": "Anthropic",
    "deepseek": "DeepSeek",
    "google": "Google",
    "mistral": "Mistral",
    "alibaba": "Alibaba",
    "moonshot": "MoonShot",
}

# Per-provider defaults consumed by get_config_value() via the
# "default_<key>" naming convention (default_model, default_base_url,
# default_api_version). Providers without a "default_base_url" (google)
# rely on their client library's built-in endpoint.
PROVIDER_CONFIGS = {
    "openai": {"default_model": "gpt-4o", "default_base_url": "https://api.openai.com/v1"},
    "azure_openai": {"default_model": "gpt-4o", "default_api_version": "2025-01-01-preview"},
    "anthropic": {"default_model": "claude-3-5-sonnet-20241022", "default_base_url": "https://api.anthropic.com"},
    "google": {"default_model": "gemini-2.0-flash"},
    "deepseek": {"default_model": "deepseek-chat", "default_base_url": "https://api.deepseek.com"},
    "mistral": {"default_model": "mistral-large-latest", "default_base_url": "https://api.mistral.ai/v1"},
    "alibaba": {"default_model": "qwen-plus", "default_base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1"},
    "moonshot": {"default_model": "moonshot-v1-32k-vision-preview", "default_base_url": "https://api.moonshot.cn/v1"},
    "ollama": {"default_model": "qwen2.5:7b", "default_base_url": "http://localhost:11434"},
}

# Predefined model names for common providers. The first entry of each list
# is used as the pre-selected value in the model dropdown, so list order
# matters — do not sort.
MODEL_NAMES = {
    "anthropic": ["claude-3-5-sonnet-20241022", "claude-3-5-sonnet-20240620", "claude-3-opus-20240229"],
    "openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo", "o3-mini"],
    "deepseek": ["deepseek-chat", "deepseek-reasoner"],
    "google": ["gemini-2.0-flash", "gemini-2.0-flash-thinking-exp", "gemini-1.5-flash-latest",
               "gemini-1.5-flash-8b-latest", "gemini-2.0-flash-thinking-exp-01-21", "gemini-2.0-pro-exp-02-05"],
    "ollama": ["qwen2.5:7b", "qwen2.5:14b", "qwen2.5:32b", "qwen2.5-coder:14b", "qwen2.5-coder:32b", "llama2:7b",
               "deepseek-r1:14b", "deepseek-r1:32b"],
    "azure_openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo"],
    "mistral": ["pixtral-large-latest", "mistral-large-latest", "mistral-small-latest", "ministral-8b-latest"],
    "alibaba": ["qwen-plus", "qwen-max", "qwen-turbo", "qwen-long"],
    "moonshot": ["moonshot-v1-32k-vision-preview", "moonshot-v1-8k-vision-preview"],
}
40+
41+
def get_config_value(provider: str, key: str, **kwargs):
    """Resolve a configuration value for *provider* and *key*.

    Resolution order:
      1. an explicit, truthy keyword argument ``kwargs[key]``;
      2. for ``key == "model"`` only, the legacy ``model_name`` keyword that
         callers such as ``get_llm_model`` pass through;
      3. a provider-specific environment variable
         (``<PROVIDER>_API_KEY`` / ``<PROVIDER>_ENDPOINT`` /
         ``<PROVIDER>_API_VERSION``);
      4. the ``default_<key>`` entry in ``PROVIDER_CONFIGS``.

    :param provider: internal provider id, e.g. ``"openai"``.
    :param key: configuration key, e.g. ``"api_key"``, ``"base_url"``,
        ``"model"``, ``"api_version"``.
    :return: the resolved value, or ``None`` if nothing is configured.
    """
    # 1. Explicit keyword argument wins. Falsy values ("" / None) fall
    #    through on purpose so they behave like "not provided".
    if kwargs.get(key):
        return kwargs[key]

    # 2. Callers supply the model under the legacy kwarg "model_name" while
    #    querying with key="model"; without this alias a caller-selected
    #    model would be silently ignored in favor of the provider default.
    if key == "model" and kwargs.get("model_name"):
        return kwargs["model_name"]

    # 3. Environment variable, e.g. OPENAI_API_KEY, MISTRAL_ENDPOINT,
    #    AZURE_OPENAI_API_VERSION. Keys not in this table (e.g. "model")
    #    have no env-var counterpart.
    env_suffixes = {
        "api_key": "API_KEY",
        "base_url": "ENDPOINT",
        "api_version": "API_VERSION",
    }
    suffix = env_suffixes.get(key)
    if suffix:
        env_value = os.getenv(f"{provider.upper()}_{suffix}")
        if env_value:
            return env_value

    # 4. Hard-coded per-provider default, if any.
    return PROVIDER_CONFIGS.get(provider, {}).get(f"default_{key}")

src/utils/utils.py

+39-146
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,7 @@
1212
from langchain_openai import AzureChatOpenAI, ChatOpenAI
1313

1414
from .llm import DeepSeekR1ChatOpenAI, DeepSeekR1ChatOllama
15-
16-
PROVIDER_DISPLAY_NAMES = {
17-
"openai": "OpenAI",
18-
"azure_openai": "Azure OpenAI",
19-
"anthropic": "Anthropic",
20-
"deepseek": "DeepSeek",
21-
"google": "Google",
22-
"alibaba": "Alibaba",
23-
"moonshot": "MoonShot"
24-
}
15+
from .llm_providers import MODEL_NAMES, PROVIDER_DISPLAY_NAMES, get_config_value
2516

2617

2718
def get_llm_model(provider: str, **kwargs):
@@ -31,153 +22,56 @@ def get_llm_model(provider: str, **kwargs):
3122
:param kwargs:
3223
:return:
3324
"""
34-
if provider not in ["ollama"]:
25+
api_key = None
26+
if provider not in {"ollama"}:
3527
env_var = f"{provider.upper()}_API_KEY"
36-
api_key = kwargs.get("api_key", "") or os.getenv(env_var, "")
28+
api_key = get_config_value(provider, "api_key", **kwargs)
3729
if not api_key:
3830
raise MissingAPIKeyError(provider, env_var)
39-
kwargs["api_key"] = api_key
4031

41-
if provider == "anthropic":
42-
if not kwargs.get("base_url", ""):
43-
base_url = "https://api.anthropic.com"
44-
else:
45-
base_url = kwargs.get("base_url")
32+
base_url = get_config_value(provider, "base_url", **kwargs)
33+
model_name = get_config_value(provider, "model", **kwargs)
34+
temperature = kwargs.get("temperature", 0.0)
4635

47-
return ChatAnthropic(
48-
model=kwargs.get("model_name", "claude-3-5-sonnet-20241022"),
49-
temperature=kwargs.get("temperature", 0.0),
50-
base_url=base_url,
51-
api_key=api_key,
52-
)
53-
elif provider == 'mistral':
54-
if not kwargs.get("base_url", ""):
55-
base_url = os.getenv("MISTRAL_ENDPOINT", "https://api.mistral.ai/v1")
56-
else:
57-
base_url = kwargs.get("base_url")
58-
if not kwargs.get("api_key", ""):
59-
api_key = os.getenv("MISTRAL_API_KEY", "")
60-
else:
61-
api_key = kwargs.get("api_key")
36+
common_params = {
37+
"model": model_name,
38+
"temperature": temperature,
39+
"base_url": base_url,
40+
"api_key": api_key,
41+
}
6242

63-
return ChatMistralAI(
64-
model=kwargs.get("model_name", "mistral-large-latest"),
65-
temperature=kwargs.get("temperature", 0.0),
66-
base_url=base_url,
67-
api_key=api_key,
68-
)
69-
elif provider == "openai":
70-
if not kwargs.get("base_url", ""):
71-
base_url = os.getenv("OPENAI_ENDPOINT", "https://api.openai.com/v1")
72-
else:
73-
base_url = kwargs.get("base_url")
74-
75-
return ChatOpenAI(
76-
model=kwargs.get("model_name", "gpt-4o"),
77-
temperature=kwargs.get("temperature", 0.0),
78-
base_url=base_url,
79-
api_key=api_key,
80-
)
43+
if provider == "anthropic":
44+
return ChatAnthropic(**common_params)
45+
elif provider == "mistral":
46+
return ChatMistralAI(**common_params)
47+
elif provider in {"openai", "alibaba", "moonshot"}:
48+
return ChatOpenAI(**common_params)
8149
elif provider == "deepseek":
82-
if not kwargs.get("base_url", ""):
83-
base_url = os.getenv("DEEPSEEK_ENDPOINT", "")
84-
else:
85-
base_url = kwargs.get("base_url")
50+
if model_name == "deepseek-reasoner":
51+
return DeepSeekR1ChatOpenAI(**common_params)
52+
return ChatOpenAI(**common_params)
8653

87-
if kwargs.get("model_name", "deepseek-chat") == "deepseek-reasoner":
88-
return DeepSeekR1ChatOpenAI(
89-
model=kwargs.get("model_name", "deepseek-reasoner"),
90-
temperature=kwargs.get("temperature", 0.0),
91-
base_url=base_url,
92-
api_key=api_key,
93-
)
94-
else:
95-
return ChatOpenAI(
96-
model=kwargs.get("model_name", "deepseek-chat"),
97-
temperature=kwargs.get("temperature", 0.0),
98-
base_url=base_url,
99-
api_key=api_key,
100-
)
10154
elif provider == "google":
102-
return ChatGoogleGenerativeAI(
103-
model=kwargs.get("model_name", "gemini-2.0-flash-exp"),
104-
temperature=kwargs.get("temperature", 0.0),
105-
api_key=api_key,
106-
)
55+
common_params.pop("base_url", None)
56+
return ChatGoogleGenerativeAI(**common_params)
10757
elif provider == "ollama":
108-
if not kwargs.get("base_url", ""):
109-
base_url = os.getenv("OLLAMA_ENDPOINT", "http://localhost:11434")
110-
else:
111-
base_url = kwargs.get("base_url")
58+
common_params.pop("api_key", None)
59+
common_params["num_ctx"] = kwargs.get("num_ctx", 32000)
11260

113-
if "deepseek-r1" in kwargs.get("model_name", "qwen2.5:7b"):
114-
return DeepSeekR1ChatOllama(
115-
model=kwargs.get("model_name", "deepseek-r1:14b"),
116-
temperature=kwargs.get("temperature", 0.0),
117-
num_ctx=kwargs.get("num_ctx", 32000),
118-
base_url=base_url,
119-
)
61+
if "deepseek-r1" in model_name:
62+
common_params["model"] = kwargs.get("model_name", "deepseek-r1:14b")
63+
return DeepSeekR1ChatOllama(**common_params)
12064
else:
121-
return ChatOllama(
122-
model=kwargs.get("model_name", "qwen2.5:7b"),
123-
temperature=kwargs.get("temperature", 0.0),
124-
num_ctx=kwargs.get("num_ctx", 32000),
125-
num_predict=kwargs.get("num_predict", 1024),
126-
base_url=base_url,
127-
)
65+
common_params["num_predict"] = kwargs.get("num_predict", 1024)
66+
return ChatOllama(**common_params)
12867
elif provider == "azure_openai":
129-
if not kwargs.get("base_url", ""):
130-
base_url = os.getenv("AZURE_OPENAI_ENDPOINT", "")
131-
else:
132-
base_url = kwargs.get("base_url")
133-
api_version = kwargs.get("api_version", "") or os.getenv("AZURE_OPENAI_API_VERSION", "2025-01-01-preview")
134-
return AzureChatOpenAI(
135-
model=kwargs.get("model_name", "gpt-4o"),
136-
temperature=kwargs.get("temperature", 0.0),
137-
api_version=api_version,
138-
azure_endpoint=base_url,
139-
api_key=api_key,
140-
)
141-
elif provider == "alibaba":
142-
if not kwargs.get("base_url", ""):
143-
base_url = os.getenv("ALIBABA_ENDPOINT", "https://dashscope.aliyuncs.com/compatible-mode/v1")
144-
else:
145-
base_url = kwargs.get("base_url")
146-
147-
return ChatOpenAI(
148-
model=kwargs.get("model_name", "qwen-plus"),
149-
temperature=kwargs.get("temperature", 0.0),
150-
base_url=base_url,
151-
api_key=api_key,
152-
)
153-
154-
elif provider == "moonshot":
155-
return ChatOpenAI(
156-
model=kwargs.get("model_name", "moonshot-v1-32k-vision-preview"),
157-
temperature=kwargs.get("temperature", 0.0),
158-
base_url=os.getenv("MOONSHOT_ENDPOINT"),
159-
api_key=os.getenv("MOONSHOT_API_KEY"),
160-
)
68+
common_params["api_version"] = get_config_value(provider, "api_version", **kwargs)
69+
common_params["azure_endpoint"] = common_params.pop("base_url", None)
70+
return AzureChatOpenAI(**common_params)
16171
else:
16272
raise ValueError(f"Unsupported provider: {provider}")
16373

16474

165-
# Predefined model names for common providers
166-
model_names = {
167-
"anthropic": ["claude-3-5-sonnet-20241022", "claude-3-5-sonnet-20240620", "claude-3-opus-20240229"],
168-
"openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo", "o3-mini"],
169-
"deepseek": ["deepseek-chat", "deepseek-reasoner"],
170-
"google": ["gemini-2.0-flash", "gemini-2.0-flash-thinking-exp", "gemini-1.5-flash-latest",
171-
"gemini-1.5-flash-8b-latest", "gemini-2.0-flash-thinking-exp-01-21", "gemini-2.0-pro-exp-02-05"],
172-
"ollama": ["qwen2.5:7b", "qwen2.5:14b", "qwen2.5:32b", "qwen2.5-coder:14b", "qwen2.5-coder:32b", "llama2:7b",
173-
"deepseek-r1:14b", "deepseek-r1:32b"],
174-
"azure_openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo"],
175-
"mistral": ["pixtral-large-latest", "mistral-large-latest", "mistral-small-latest", "ministral-8b-latest"],
176-
"alibaba": ["qwen-plus", "qwen-max", "qwen-turbo", "qwen-long"],
177-
"moonshot": ["moonshot-v1-32k-vision-preview", "moonshot-v1-8k-vision-preview"],
178-
}
179-
180-
18175
# Callback to update the model name dropdown based on the selected provider
18276
def update_model_dropdown(llm_provider, api_key=None, base_url=None):
18377
"""
@@ -186,15 +80,14 @@ def update_model_dropdown(llm_provider, api_key=None, base_url=None):
18680
import gradio as gr
18781
# Use API keys from .env if not provided
18882
if not api_key:
189-
api_key = os.getenv(f"{llm_provider.upper()}_API_KEY", "")
83+
api_key = get_config_value(llm_provider, "api_key")
19084
if not base_url:
191-
base_url = os.getenv(f"{llm_provider.upper()}_BASE_URL", "")
85+
base_url = get_config_value(llm_provider, "base_url")
19286

19387
# Use predefined models for the selected provider
194-
if llm_provider in model_names:
195-
return gr.Dropdown(choices=model_names[llm_provider], value=model_names[llm_provider][0], interactive=True)
196-
else:
197-
return gr.Dropdown(choices=[], value="", interactive=True, allow_custom_value=True)
88+
if llm_provider in MODEL_NAMES:
89+
return gr.Dropdown(choices=MODEL_NAMES[llm_provider], value=MODEL_NAMES[llm_provider][0], interactive=True)
90+
return gr.Dropdown(choices=[], value="", interactive=True, allow_custom_value=True)
19891

19992
class MissingAPIKeyError(Exception):
20093
"""Custom exception for missing API key."""

tests/test_llm_api.py

+3-18
Original file line numberDiff line numberDiff line change
@@ -32,21 +32,6 @@ def create_message_content(text, image_path=None):
3232
})
3333
return content
3434

35-
def get_env_value(key, provider):
36-
env_mappings = {
37-
"openai": {"api_key": "OPENAI_API_KEY", "base_url": "OPENAI_ENDPOINT"},
38-
"azure_openai": {"api_key": "AZURE_OPENAI_API_KEY", "base_url": "AZURE_OPENAI_ENDPOINT"},
39-
"google": {"api_key": "GOOGLE_API_KEY"},
40-
"deepseek": {"api_key": "DEEPSEEK_API_KEY", "base_url": "DEEPSEEK_ENDPOINT"},
41-
"mistral": {"api_key": "MISTRAL_API_KEY", "base_url": "MISTRAL_ENDPOINT"},
42-
"alibaba": {"api_key": "ALIBABA_API_KEY", "base_url": "ALIBABA_ENDPOINT"},
43-
"moonshot":{"api_key": "MOONSHOT_API_KEY", "base_url": "MOONSHOT_ENDPOINT"},
44-
}
45-
46-
if provider in env_mappings and key in env_mappings[provider]:
47-
return os.getenv(env_mappings[provider][key], "")
48-
return ""
49-
5035
def test_llm(config, query, image_path=None, system_message=None):
5136
from src.utils import utils
5237

@@ -69,8 +54,8 @@ def test_llm(config, query, image_path=None, system_message=None):
6954
provider=config.provider,
7055
model_name=config.model_name,
7156
temperature=config.temperature,
72-
base_url=config.base_url or get_env_value("base_url", config.provider),
73-
api_key=config.api_key or get_env_value("api_key", config.provider)
57+
base_url=config.base_url,
58+
api_key=config.api_key
7459
)
7560

7661
# Prepare messages for non-Ollama models
@@ -130,7 +115,7 @@ def test_moonshot_model():
130115
# test_openai_model()
131116
# test_google_model()
132117
# test_azure_openai_model()
133-
#test_deepseek_model()
118+
# test_deepseek_model()
134119
# test_ollama_model()
135120
test_deepseek_r1_model()
136121
# test_deepseek_r1_ollama_model()

webui.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from langchain_ollama import ChatOllama
2525
from playwright.async_api import async_playwright
2626
from src.utils.agent_state import AgentState
27+
from src.utils.llm_providers import MODEL_NAMES
2728

2829
from src.utils import utils
2930
from src.agent.custom_agent import CustomAgent
@@ -798,14 +799,14 @@ def create_ui(config, theme_name="Ocean"):
798799
with gr.TabItem("🔧 LLM Settings", id=2):
799800
with gr.Group():
800801
llm_provider = gr.Dropdown(
801-
choices=[provider for provider, model in utils.model_names.items()],
802+
choices=[provider for provider,model in MODEL_NAMES.items()],
802803
label="LLM Provider",
803804
value=config['llm_provider'],
804805
info="Select your preferred language model provider"
805806
)
806807
llm_model_name = gr.Dropdown(
807808
label="Model Name",
808-
choices=utils.model_names['openai'],
809+
choices=MODEL_NAMES['openai'],
809810
value=config['llm_model_name'],
810811
interactive=True,
811812
allow_custom_value=True, # Allow users to input custom model names

0 commit comments

Comments
 (0)