Skip to content

Refactor LLM settings #217

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ MOONSHOT_API_KEY=
UNBOUND_ENDPOINT=https://api.getunbound.ai
UNBOUND_API_KEY=

SiliconFLOW_ENDPOINT=https://api.siliconflow.cn/v1/
SiliconFLOW_API_KEY=
SILICONFLOW_ENDPOINT=https://api.siliconflow.cn/v1/
SILICONFLOW_API_KEY=

IBM_ENDPOINT=https://us-south.ml.cloud.ibm.com
IBM_API_KEY=
Expand Down
241 changes: 87 additions & 154 deletions src/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@
"anthropic": "Anthropic",
"deepseek": "DeepSeek",
"google": "Google",
"mistral": "Mistral",
"alibaba": "Alibaba",
"moonshot": "MoonShot",
"unbound": "Unbound AI",
"siliconflow": "SiliconFlow",
"ibm": "IBM"
}

Expand All @@ -37,174 +39,82 @@ def get_llm_model(provider: str, **kwargs):
:param kwargs:
:return:
"""
if provider not in ["ollama"]:
api_key = None
if provider not in {"ollama"}:
env_var = f"{provider.upper()}_API_KEY"
api_key = kwargs.get("api_key", "") or os.getenv(env_var, "")
api_key = get_config_value(provider, "api_key", **kwargs)
if not api_key:
raise MissingAPIKeyError(provider, env_var)
kwargs["api_key"] = api_key

if provider == "anthropic":
if not kwargs.get("base_url", ""):
base_url = "https://api.anthropic.com"
else:
base_url = kwargs.get("base_url")

return ChatAnthropic(
model=kwargs.get("model_name", "claude-3-5-sonnet-20241022"),
temperature=kwargs.get("temperature", 0.0),
base_url=base_url,
api_key=api_key,
)
elif provider == 'mistral':
if not kwargs.get("base_url", ""):
base_url = os.getenv("MISTRAL_ENDPOINT", "https://api.mistral.ai/v1")
else:
base_url = kwargs.get("base_url")
if not kwargs.get("api_key", ""):
api_key = os.getenv("MISTRAL_API_KEY", "")
else:
api_key = kwargs.get("api_key")
base_url = get_config_value(provider, "base_url", **kwargs)
model_name = get_config_value(provider, "model_name", **kwargs)
temperature = kwargs.get("temperature", 0.0)
num_ctx = kwargs.get("num_ctx", 32000)
num_predict = kwargs.get("num_predict", 1024)

return ChatMistralAI(
model=kwargs.get("model_name", "mistral-large-latest"),
temperature=kwargs.get("temperature", 0.0),
base_url=base_url,
api_key=api_key,
)
elif provider == "openai":
if not kwargs.get("base_url", ""):
base_url = os.getenv("OPENAI_ENDPOINT", "https://api.openai.com/v1")
else:
base_url = kwargs.get("base_url")
common_params = {
"model": model_name,
"temperature": temperature,
"base_url": base_url,
"api_key": api_key,
}

return ChatOpenAI(
model=kwargs.get("model_name", "gpt-4o"),
temperature=kwargs.get("temperature", 0.0),
base_url=base_url,
api_key=api_key,
)
if provider == "anthropic":
return ChatAnthropic(**common_params)
elif provider == "mistral":
return ChatMistralAI(**common_params)
elif provider in {"openai", "alibaba", "moonshot", "unbound", "siliconflow"}:
return ChatOpenAI(**common_params)
elif provider == "deepseek":
if not kwargs.get("base_url", ""):
base_url = os.getenv("DEEPSEEK_ENDPOINT", "")
else:
base_url = kwargs.get("base_url")

if kwargs.get("model_name", "deepseek-chat") == "deepseek-reasoner":
return DeepSeekR1ChatOpenAI(
model=kwargs.get("model_name", "deepseek-reasoner"),
temperature=kwargs.get("temperature", 0.0),
base_url=base_url,
api_key=api_key,
)
else:
return ChatOpenAI(
model=kwargs.get("model_name", "deepseek-chat"),
temperature=kwargs.get("temperature", 0.0),
base_url=base_url,
api_key=api_key,
)
if model_name == "deepseek-reasoner":
return DeepSeekR1ChatOpenAI(**common_params)
return ChatOpenAI(**common_params)

elif provider == "google":
return ChatGoogleGenerativeAI(
model=kwargs.get("model_name", "gemini-2.0-flash-exp"),
temperature=kwargs.get("temperature", 0.0),
api_key=api_key,
)
common_params.pop("base_url", None)
return ChatGoogleGenerativeAI(**common_params)
elif provider == "ollama":
if not kwargs.get("base_url", ""):
base_url = os.getenv("OLLAMA_ENDPOINT", "http://localhost:11434")
else:
base_url = kwargs.get("base_url")

if "deepseek-r1" in kwargs.get("model_name", "qwen2.5:7b"):
return DeepSeekR1ChatOllama(
model=kwargs.get("model_name", "deepseek-r1:14b"),
temperature=kwargs.get("temperature", 0.0),
num_ctx=kwargs.get("num_ctx", 32000),
base_url=base_url,
)
common_params.pop("api_key", None)
common_params["num_ctx"] = num_ctx

if model_name and "deepseek-r1" in model_name:
model = kwargs.get("model_name", "deepseek-r1:14b")
return DeepSeekR1ChatOllama(**common_params, model=model)
else:
return ChatOllama(
model=kwargs.get("model_name", "qwen2.5:7b"),
temperature=kwargs.get("temperature", 0.0),
num_ctx=kwargs.get("num_ctx", 32000),
num_predict=kwargs.get("num_predict", 1024),
base_url=base_url,
)
return ChatOllama(**common_params, num_predict=num_predict)
elif provider == "azure_openai":
if not kwargs.get("base_url", ""):
base_url = os.getenv("AZURE_OPENAI_ENDPOINT", "")
else:
base_url = kwargs.get("base_url")
api_version = kwargs.get("api_version", "") or os.getenv("AZURE_OPENAI_API_VERSION", "2025-01-01-preview")
return AzureChatOpenAI(
model=kwargs.get("model_name", "gpt-4o"),
temperature=kwargs.get("temperature", 0.0),
api_version=api_version,
azure_endpoint=base_url,
api_key=api_key,
)
elif provider == "alibaba":
if not kwargs.get("base_url", ""):
base_url = os.getenv("ALIBABA_ENDPOINT", "https://dashscope.aliyuncs.com/compatible-mode/v1")
else:
base_url = kwargs.get("base_url")

return ChatOpenAI(
model=kwargs.get("model_name", "qwen-plus"),
temperature=kwargs.get("temperature", 0.0),
base_url=base_url,
api_key=api_key,
)
api_version = get_config_value(provider, "api_version", **kwargs)
azure_endpoint = common_params.pop("base_url", None)
return AzureChatOpenAI(**common_params, api_version=api_version, azure_endpoint=azure_endpoint)
elif provider == "ibm":
parameters = {
"temperature": kwargs.get("temperature", 0.0),
"max_tokens": kwargs.get("num_ctx", 32000)
ibm_params = {
"model_id": model_name,
"url": base_url,
"apikey": api_key,
"project_id": get_config_value(provider, "project_id", **kwargs),
"params": {
"temperature": temperature,
"max_tokens": num_ctx
}
}
if not kwargs.get("base_url", ""):
base_url = os.getenv("IBM_ENDPOINT", "https://us-south.ml.cloud.ibm.com")
else:
base_url = kwargs.get("base_url")

return ChatWatsonx(
model_id=kwargs.get("model_name", "ibm/granite-vision-3.1-2b-preview"),
url=base_url,
project_id=os.getenv("IBM_PROJECT_ID"),
apikey=os.getenv("IBM_API_KEY"),
params=parameters
)
elif provider == "moonshot":
return ChatOpenAI(
model=kwargs.get("model_name", "moonshot-v1-32k-vision-preview"),
temperature=kwargs.get("temperature", 0.0),
base_url=os.getenv("MOONSHOT_ENDPOINT"),
api_key=os.getenv("MOONSHOT_API_KEY"),
)
elif provider == "unbound":
return ChatOpenAI(
model=kwargs.get("model_name", "gpt-4o-mini"),
temperature=kwargs.get("temperature", 0.0),
base_url=os.getenv("UNBOUND_ENDPOINT", "https://api.getunbound.ai"),
api_key=api_key,
)
elif provider == "siliconflow":
if not kwargs.get("api_key", ""):
api_key = os.getenv("SiliconFLOW_API_KEY", "")
else:
api_key = kwargs.get("api_key")
if not kwargs.get("base_url", ""):
base_url = os.getenv("SiliconFLOW_ENDPOINT", "")
else:
base_url = kwargs.get("base_url")
return ChatOpenAI(
api_key=api_key,
base_url=base_url,
model_name=kwargs.get("model_name", "Qwen/QwQ-32B"),
temperature=kwargs.get("temperature", 0.0),
)
return ChatWatsonx(**ibm_params)
else:
raise ValueError(f"Unsupported provider: {provider}")

# Static per-provider defaults consumed by get_config_value(): each entry maps a
# provider id to "default_<key>" fallbacks (model name, endpoint, API version)
# used only when neither an explicit kwarg nor an environment variable is set.
# NOTE(review): "ollama" has a default_base_url but takes no API key — callers
# special-case it; confirm any new provider added here follows the same pattern.
PROVIDER_CONFIGS = {
"openai": {"default_model": "gpt-4o", "default_base_url": "https://api.openai.com/v1"},
"azure_openai": {"default_model": "gpt-4o", "default_api_version": "2025-01-01-preview"},
"anthropic": {"default_model": "claude-3-5-sonnet-20241022", "default_base_url": "https://api.anthropic.com"},
"google": {"default_model": "gemini-2.0-flash"},
"deepseek": {"default_model": "deepseek-chat", "default_base_url": "https://api.deepseek.com"},
"mistral": {"default_model": "mistral-large-latest", "default_base_url": "https://api.mistral.ai/v1"},
"alibaba": {"default_model": "qwen-plus", "default_base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1"},
"moonshot": {"default_model": "moonshot-v1-32k-vision-preview", "default_base_url": "https://api.moonshot.cn/v1"},
"unbound": {"default_model": "gpt-4o-mini", "default_base_url": "https://api.getunbound.ai"},
"siliconflow": {"default_model": "Qwen/QwQ-32B", "default_base_url": "https://api.siliconflow.cn/v1"},
"ibm": {"default_model": "ibm/granite-vision-3.1-2b-preview", "default_base_url": "https://us-south.ml.cloud.ibm.com"},
"ollama": {"default_model": "qwen2.5:7b", "default_base_url": "http://localhost:11434"}
}

# Predefined model names for common providers
model_names = {
Expand Down Expand Up @@ -256,6 +166,29 @@ def get_llm_model(provider: str, **kwargs):
"ibm": ["ibm/granite-vision-3.1-2b-preview", "meta-llama/llama-4-maverick-17b-128e-instruct-fp8","meta-llama/llama-3-2-90b-vision-instruct"]
}

# Maps a logical config key to the suffix of its environment variable,
# e.g. ("openai", "api_key") -> OPENAI_API_KEY, ("mistral", "base_url")
# -> MISTRAL_ENDPOINT. Keys not listed here have no env-var fallback.
_ENV_SUFFIXES = {
    "api_key": "API_KEY",
    "base_url": "ENDPOINT",
    "api_version": "API_VERSION",
    "project_id": "PROJECT_ID",
}

def get_config_value(provider: str, key: str, **kwargs):
    """Resolve a configuration value for a provider.

    Resolution order (first truthy value wins):
      1. An explicit, non-empty ``kwargs[key]`` passed by the caller.
      2. The environment variable ``<PROVIDER>_<SUFFIX>`` (see _ENV_SUFFIXES).
      3. The provider's ``default_<key>`` entry in PROVIDER_CONFIGS.

    :param provider: provider id, e.g. "openai", "mistral" (case-insensitive
        for the env-var lookup; must match PROVIDER_CONFIGS keys for defaults).
    :param key: logical config key, e.g. "api_key", "base_url", "model_name".
    :param kwargs: caller-supplied overrides; empty strings are ignored.
    :return: the resolved value, or None if nothing is configured.
    """
    # 1. Explicit caller-supplied value; falsy values (e.g. "") fall through,
    # matching the rest of the module's "empty means unset" convention.
    explicit = kwargs.get(key)
    if explicit:
        return explicit

    # 2. Environment variable, built from the provider name and key suffix.
    suffix = _ENV_SUFFIXES.get(key)
    if suffix:
        env_value = os.getenv(f"{provider.upper()}_{suffix}")
        if env_value:
            return env_value

    # 3. Static default; looked up lazily so steps 1–2 never touch
    # PROVIDER_CONFIGS. Unknown providers/keys yield None.
    return PROVIDER_CONFIGS.get(provider, {}).get(f"default_{key}")

# Callback to update the model name dropdown based on the selected provider
def update_model_dropdown(llm_provider, api_key=None, base_url=None):
Expand All @@ -265,9 +198,9 @@ def update_model_dropdown(llm_provider, api_key=None, base_url=None):
import gradio as gr
# Use API keys from .env if not provided
if not api_key:
api_key = os.getenv(f"{llm_provider.upper()}_API_KEY", "")
api_key = get_config_value(llm_provider, "api_key")
if not base_url:
base_url = os.getenv(f"{llm_provider.upper()}_BASE_URL", "")
base_url = get_config_value(llm_provider, "base_url")

# Use predefined models for the selected provider
if llm_provider in model_names:
Expand Down
20 changes: 2 additions & 18 deletions tests/test_llm_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,22 +32,6 @@ def create_message_content(text, image_path=None):
})
return content

def get_env_value(key, provider):
env_mappings = {
"openai": {"api_key": "OPENAI_API_KEY", "base_url": "OPENAI_ENDPOINT"},
"azure_openai": {"api_key": "AZURE_OPENAI_API_KEY", "base_url": "AZURE_OPENAI_ENDPOINT"},
"google": {"api_key": "GOOGLE_API_KEY"},
"deepseek": {"api_key": "DEEPSEEK_API_KEY", "base_url": "DEEPSEEK_ENDPOINT"},
"mistral": {"api_key": "MISTRAL_API_KEY", "base_url": "MISTRAL_ENDPOINT"},
"alibaba": {"api_key": "ALIBABA_API_KEY", "base_url": "ALIBABA_ENDPOINT"},
"moonshot":{"api_key": "MOONSHOT_API_KEY", "base_url": "MOONSHOT_ENDPOINT"},
"ibm": {"api_key": "IBM_API_KEY", "base_url": "IBM_ENDPOINT"}
}

if provider in env_mappings and key in env_mappings[provider]:
return os.getenv(env_mappings[provider][key], "")
return ""

def test_llm(config, query, image_path=None, system_message=None):
from src.utils import utils

Expand All @@ -70,8 +54,8 @@ def test_llm(config, query, image_path=None, system_message=None):
provider=config.provider,
model_name=config.model_name,
temperature=config.temperature,
base_url=config.base_url or get_env_value("base_url", config.provider),
api_key=config.api_key or get_env_value("api_key", config.provider)
base_url=config.base_url,
api_key=config.api_key
)

# Prepare messages for non-Ollama models
Expand Down