diff --git a/cascadeflow/schema/config.py b/cascadeflow/schema/config.py index b67121db..c689b078 100644 --- a/cascadeflow/schema/config.py +++ b/cascadeflow/schema/config.py @@ -111,12 +111,14 @@ def validate_provider(cls, v): allowed = [ "openai", "anthropic", + "deepseek", "groq", "ollama", "huggingface", "together", "vllm", "replicate", + "openrouter", "custom", ] if v not in allowed: diff --git a/tests/test_model_config_provider_validation.py b/tests/test_model_config_provider_validation.py new file mode 100644 index 00000000..76760530 --- /dev/null +++ b/tests/test_model_config_provider_validation.py @@ -0,0 +1,29 @@ +import pytest + +from cascadeflow.config import ModelConfig +from cascadeflow.providers.base import PROVIDER_CAPABILITIES + + +@pytest.mark.parametrize( + "provider", + [ + "deepseek", + "DeepSeek", + "openrouter", + "OpenRouter", + ], +) +def test_model_config_provider_allows_supported_providers(provider: str) -> None: + config = ModelConfig(name="x", provider=provider, cost=0.0) + assert config.provider == provider.lower() + + +def test_model_config_provider_rejects_unknown_provider() -> None: + with pytest.raises(ValueError, match="Provider must be one of"): + ModelConfig(name="x", provider="unknown_provider", cost=0.0) + + +def test_model_config_provider_allows_all_provider_capability_keys() -> None: + for provider in PROVIDER_CAPABILITIES: + config = ModelConfig(name="x", provider=provider, cost=0.0) + assert config.provider == provider