diff --git a/.env.example b/.env.example index ad0bc6ae..db4335cd 100644 --- a/.env.example +++ b/.env.example @@ -27,8 +27,8 @@ MOONSHOT_API_KEY= UNBOUND_ENDPOINT=https://api.getunbound.ai UNBOUND_API_KEY= -SiliconFLOW_ENDPOINT=https://api.siliconflow.cn/v1/ -SiliconFLOW_API_KEY= +SILICONFLOW_ENDPOINT=https://api.siliconflow.cn/v1/ +SILICONFLOW_API_KEY= IBM_ENDPOINT=https://us-south.ml.cloud.ibm.com IBM_API_KEY= diff --git a/src/utils/utils.py b/src/utils/utils.py index 10ebf7ac..dd60ee86 100644 --- a/src/utils/utils.py +++ b/src/utils/utils.py @@ -23,9 +23,11 @@ "anthropic": "Anthropic", "deepseek": "DeepSeek", "google": "Google", + "mistral": "Mistral", "alibaba": "Alibaba", "moonshot": "MoonShot", "unbound": "Unbound AI", + "siliconflow": "SiliconFlow", "ibm": "IBM" } @@ -37,174 +39,82 @@ def get_llm_model(provider: str, **kwargs): :param kwargs: :return: """ - if provider not in ["ollama"]: + api_key = None + if provider not in {"ollama"}: env_var = f"{provider.upper()}_API_KEY" - api_key = kwargs.get("api_key", "") or os.getenv(env_var, "") + api_key = get_config_value(provider, "api_key", **kwargs) if not api_key: raise MissingAPIKeyError(provider, env_var) - kwargs["api_key"] = api_key - if provider == "anthropic": - if not kwargs.get("base_url", ""): - base_url = "https://api.anthropic.com" - else: - base_url = kwargs.get("base_url") - - return ChatAnthropic( - model=kwargs.get("model_name", "claude-3-5-sonnet-20241022"), - temperature=kwargs.get("temperature", 0.0), - base_url=base_url, - api_key=api_key, - ) - elif provider == 'mistral': - if not kwargs.get("base_url", ""): - base_url = os.getenv("MISTRAL_ENDPOINT", "https://api.mistral.ai/v1") - else: - base_url = kwargs.get("base_url") - if not kwargs.get("api_key", ""): - api_key = os.getenv("MISTRAL_API_KEY", "") - else: - api_key = kwargs.get("api_key") + base_url = get_config_value(provider, "base_url", **kwargs) + model_name = get_config_value(provider, "model_name", **kwargs) + temperature = 
kwargs.get("temperature", 0.0) + num_ctx = kwargs.get("num_ctx", 32000) + num_predict = kwargs.get("num_predict", 1024) - return ChatMistralAI( - model=kwargs.get("model_name", "mistral-large-latest"), - temperature=kwargs.get("temperature", 0.0), - base_url=base_url, - api_key=api_key, - ) - elif provider == "openai": - if not kwargs.get("base_url", ""): - base_url = os.getenv("OPENAI_ENDPOINT", "https://api.openai.com/v1") - else: - base_url = kwargs.get("base_url") + common_params = { + "model": model_name, + "temperature": temperature, + "base_url": base_url, + "api_key": api_key, + } - return ChatOpenAI( - model=kwargs.get("model_name", "gpt-4o"), - temperature=kwargs.get("temperature", 0.0), - base_url=base_url, - api_key=api_key, - ) + if provider == "anthropic": + return ChatAnthropic(**common_params) + elif provider == "mistral": + return ChatMistralAI(**common_params) + elif provider in {"openai", "alibaba", "moonshot", "unbound", "siliconflow"}: + return ChatOpenAI(**common_params) elif provider == "deepseek": - if not kwargs.get("base_url", ""): - base_url = os.getenv("DEEPSEEK_ENDPOINT", "") - else: - base_url = kwargs.get("base_url") - - if kwargs.get("model_name", "deepseek-chat") == "deepseek-reasoner": - return DeepSeekR1ChatOpenAI( - model=kwargs.get("model_name", "deepseek-reasoner"), - temperature=kwargs.get("temperature", 0.0), - base_url=base_url, - api_key=api_key, - ) - else: - return ChatOpenAI( - model=kwargs.get("model_name", "deepseek-chat"), - temperature=kwargs.get("temperature", 0.0), - base_url=base_url, - api_key=api_key, - ) + if model_name == "deepseek-reasoner": + return DeepSeekR1ChatOpenAI(**common_params) + return ChatOpenAI(**common_params) + elif provider == "google": - return ChatGoogleGenerativeAI( - model=kwargs.get("model_name", "gemini-2.0-flash-exp"), - temperature=kwargs.get("temperature", 0.0), - api_key=api_key, - ) + common_params.pop("base_url", None) + return ChatGoogleGenerativeAI(**common_params) elif provider 
== "ollama": - if not kwargs.get("base_url", ""): - base_url = os.getenv("OLLAMA_ENDPOINT", "http://localhost:11434") - else: - base_url = kwargs.get("base_url") - - if "deepseek-r1" in kwargs.get("model_name", "qwen2.5:7b"): - return DeepSeekR1ChatOllama( - model=kwargs.get("model_name", "deepseek-r1:14b"), - temperature=kwargs.get("temperature", 0.0), - num_ctx=kwargs.get("num_ctx", 32000), - base_url=base_url, - ) + common_params.pop("api_key", None) + common_params["num_ctx"] = num_ctx + + if model_name and "deepseek-r1" in model_name: + # model is already carried in common_params via model_name + return DeepSeekR1ChatOllama(**common_params) else: - return ChatOllama( - model=kwargs.get("model_name", "qwen2.5:7b"), - temperature=kwargs.get("temperature", 0.0), - num_ctx=kwargs.get("num_ctx", 32000), - num_predict=kwargs.get("num_predict", 1024), - base_url=base_url, - ) + return ChatOllama(**common_params, num_predict=num_predict) elif provider == "azure_openai": - if not kwargs.get("base_url", ""): - base_url = os.getenv("AZURE_OPENAI_ENDPOINT", "") - else: - base_url = kwargs.get("base_url") - api_version = kwargs.get("api_version", "") or os.getenv("AZURE_OPENAI_API_VERSION", "2025-01-01-preview") - return AzureChatOpenAI( - model=kwargs.get("model_name", "gpt-4o"), - temperature=kwargs.get("temperature", 0.0), - api_version=api_version, - azure_endpoint=base_url, - api_key=api_key, - ) - elif provider == "alibaba": - if not kwargs.get("base_url", ""): - base_url = os.getenv("ALIBABA_ENDPOINT", "https://dashscope.aliyuncs.com/compatible-mode/v1") - else: - base_url = kwargs.get("base_url") - - return ChatOpenAI( - model=kwargs.get("model_name", "qwen-plus"), - temperature=kwargs.get("temperature", 0.0), - base_url=base_url, - api_key=api_key, - ) + api_version = get_config_value(provider, "api_version", **kwargs) + azure_endpoint = common_params.pop("base_url", None) + return AzureChatOpenAI(**common_params, api_version=api_version, 
azure_endpoint=azure_endpoint) elif provider == "ibm": - parameters = { - "temperature": kwargs.get("temperature", 0.0), - "max_tokens": kwargs.get("num_ctx", 32000) + ibm_params = { + "model_id": model_name, + "url": base_url, + "apikey": api_key, + "project_id": get_config_value(provider, "project_id", **kwargs), + "params": { + "temperature": temperature, + "max_tokens": num_ctx + } } - if not kwargs.get("base_url", ""): - base_url = os.getenv("IBM_ENDPOINT", "https://us-south.ml.cloud.ibm.com") - else: - base_url = kwargs.get("base_url") - - return ChatWatsonx( - model_id=kwargs.get("model_name", "ibm/granite-vision-3.1-2b-preview"), - url=base_url, - project_id=os.getenv("IBM_PROJECT_ID"), - apikey=os.getenv("IBM_API_KEY"), - params=parameters - ) - elif provider == "moonshot": - return ChatOpenAI( - model=kwargs.get("model_name", "moonshot-v1-32k-vision-preview"), - temperature=kwargs.get("temperature", 0.0), - base_url=os.getenv("MOONSHOT_ENDPOINT"), - api_key=os.getenv("MOONSHOT_API_KEY"), - ) - elif provider == "unbound": - return ChatOpenAI( - model=kwargs.get("model_name", "gpt-4o-mini"), - temperature=kwargs.get("temperature", 0.0), - base_url=os.getenv("UNBOUND_ENDPOINT", "https://api.getunbound.ai"), - api_key=api_key, - ) - elif provider == "siliconflow": - if not kwargs.get("api_key", ""): - api_key = os.getenv("SiliconFLOW_API_KEY", "") - else: - api_key = kwargs.get("api_key") - if not kwargs.get("base_url", ""): - base_url = os.getenv("SiliconFLOW_ENDPOINT", "") - else: - base_url = kwargs.get("base_url") - return ChatOpenAI( - api_key=api_key, - base_url=base_url, - model_name=kwargs.get("model_name", "Qwen/QwQ-32B"), - temperature=kwargs.get("temperature", 0.0), - ) + return ChatWatsonx(**ibm_params) else: raise ValueError(f"Unsupported provider: {provider}") +PROVIDER_CONFIGS = { + "openai": {"default_model": "gpt-4o", "default_base_url": "https://api.openai.com/v1"}, + "azure_openai": {"default_model": "gpt-4o", "default_api_version": 
"2025-01-01-preview"}, + "anthropic": {"default_model": "claude-3-5-sonnet-20241022", "default_base_url": "https://api.anthropic.com"}, + "google": {"default_model": "gemini-2.0-flash"}, + "deepseek": {"default_model": "deepseek-chat", "default_base_url": "https://api.deepseek.com"}, + "mistral": {"default_model": "mistral-large-latest", "default_base_url": "https://api.mistral.ai/v1"}, + "alibaba": {"default_model": "qwen-plus", "default_base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1"}, + "moonshot": {"default_model": "moonshot-v1-32k-vision-preview", "default_base_url": "https://api.moonshot.cn/v1"}, + "unbound": {"default_model": "gpt-4o-mini", "default_base_url": "https://api.getunbound.ai"}, + "siliconflow": {"default_model": "Qwen/QwQ-32B", "default_base_url": "https://api.siliconflow.cn/v1"}, + "ibm": {"default_model": "ibm/granite-vision-3.1-2b-preview", "default_base_url": "https://us-south.ml.cloud.ibm.com"}, + "ollama": {"default_model": "qwen2.5:7b", "default_base_url": "http://localhost:11434"} +} # Predefined model names for common providers model_names = { @@ -256,6 +166,29 @@ def get_llm_model(provider: str, **kwargs): "ibm": ["ibm/granite-vision-3.1-2b-preview", "meta-llama/llama-4-maverick-17b-128e-instruct-fp8","meta-llama/llama-3-2-90b-vision-instruct"] } +def get_config_value(provider: str, key: str, **kwargs): + """Retrieves a configuration value for a given provider and key.""" + config = PROVIDER_CONFIGS.get(provider, {}) + + if key in kwargs and kwargs[key]: + return kwargs[key] + + env_key_name = None + if key == "api_key": + env_key_name = f"{provider.upper()}_API_KEY" + elif key == "base_url": + env_key_name = f"{provider.upper()}_ENDPOINT" + elif key == "api_version": + env_key_name = f"{provider.upper()}_API_VERSION" + elif key == "project_id": + env_key_name = f"{provider.upper()}_PROJECT_ID" + + if env_key_name: + env_value = os.getenv(env_key_name) + if env_value: + return env_value + + return 
config.get(f"default_{key}") # Callback to update the model name dropdown based on the selected provider def update_model_dropdown(llm_provider, api_key=None, base_url=None): @@ -265,9 +198,9 @@ def update_model_dropdown(llm_provider, api_key=None, base_url=None): import gradio as gr # Use API keys from .env if not provided if not api_key: - api_key = os.getenv(f"{llm_provider.upper()}_API_KEY", "") + api_key = get_config_value(llm_provider, "api_key") if not base_url: - base_url = os.getenv(f"{llm_provider.upper()}_BASE_URL", "") + base_url = get_config_value(llm_provider, "base_url") # Use predefined models for the selected provider if llm_provider in model_names: diff --git a/tests/test_llm_api.py b/tests/test_llm_api.py index 05bc06e1..e13ab6ff 100644 --- a/tests/test_llm_api.py +++ b/tests/test_llm_api.py @@ -32,22 +32,6 @@ def create_message_content(text, image_path=None): }) return content -def get_env_value(key, provider): - env_mappings = { - "openai": {"api_key": "OPENAI_API_KEY", "base_url": "OPENAI_ENDPOINT"}, - "azure_openai": {"api_key": "AZURE_OPENAI_API_KEY", "base_url": "AZURE_OPENAI_ENDPOINT"}, - "google": {"api_key": "GOOGLE_API_KEY"}, - "deepseek": {"api_key": "DEEPSEEK_API_KEY", "base_url": "DEEPSEEK_ENDPOINT"}, - "mistral": {"api_key": "MISTRAL_API_KEY", "base_url": "MISTRAL_ENDPOINT"}, - "alibaba": {"api_key": "ALIBABA_API_KEY", "base_url": "ALIBABA_ENDPOINT"}, - "moonshot":{"api_key": "MOONSHOT_API_KEY", "base_url": "MOONSHOT_ENDPOINT"}, - "ibm": {"api_key": "IBM_API_KEY", "base_url": "IBM_ENDPOINT"} - } - - if provider in env_mappings and key in env_mappings[provider]: - return os.getenv(env_mappings[provider][key], "") - return "" - def test_llm(config, query, image_path=None, system_message=None): from src.utils import utils @@ -70,8 +54,8 @@ def test_llm(config, query, image_path=None, system_message=None): provider=config.provider, model_name=config.model_name, temperature=config.temperature, - base_url=config.base_url or 
get_env_value("base_url", config.provider), - api_key=config.api_key or get_env_value("api_key", config.provider) + base_url=config.base_url, + api_key=config.api_key ) # Prepare messages for non-Ollama models