Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 40 additions & 14 deletions ai_eval/llm_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,45 @@ def _get_headers(self):
self._ensure_token()
return {'Authorization': f'Bearer {self._access_token}'}

@staticmethod
def _parse_models_field(raw_models):
"""
Parse the value of a top-level "models" field into a list of model ids/names.
"""
if isinstance(raw_models, list):
return [str(m) for m in raw_models]
if isinstance(raw_models, str):
return [str(raw_models)]
if isinstance(raw_models, dict):
parsed_models = []
for key, val in raw_models.items():
if isinstance(val, dict):
candidate = val.get("name") or val.get("id") or key
elif isinstance(val, str) and val.strip():
candidate = val
else:
candidate = key
parsed_models.append(str(candidate))
return parsed_models
return []

@classmethod
def _parse_models_response(cls, data):
"""
Parse a models endpoint JSON response into a list of model ids/names.
"""
if isinstance(data, dict):
if "models" in data:
return cls._parse_models_field(data["models"])
if isinstance(data.get("data"), list):
return [str(m.get("id", str(m))) for m in data["data"]]
return []
if isinstance(data, list):
return [str(m) for m in data]
if isinstance(data, str):
return [str(data)]
return []

def get_response(
self,
model,
Expand Down Expand Up @@ -196,20 +235,7 @@ def get_available_models(self):
response.raise_for_status()
data = response.json()

models = []

if isinstance(data, dict):
if "models" in data:
if isinstance(data["models"], list):
models = [str(m) for m in data["models"]]
elif isinstance(data["models"], str):
models = [str(data["models"])]
elif "data" in data and isinstance(data["data"], list):
models = [str(m.get("id", str(m))) for m in data["data"]]
elif isinstance(data, list):
models = [str(m) for m in data]
elif isinstance(data, str):
models = [str(data)]
models = self._parse_models_response(data)

# Filter out non-string model names and empty strings
models = [m for m in models if isinstance(m, str) and m.strip()]
Expand Down
28 changes: 28 additions & 0 deletions ai_eval/tests/test_ai_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
)
from ai_eval.base import AIEvalXBlock
from ai_eval.supported_models import SupportedModels
from ai_eval.llm_services import CustomLLMService
from ai_eval.backends.factory import BackendFactory
from ai_eval.backends.judge0 import Judge0Backend
from ai_eval.backends.custom import CustomServiceBackend
Expand Down Expand Up @@ -220,6 +221,33 @@ def test_multiagent_block_evaluator_response():
assert resp["is_evaluator"]


def test_custom_llm_models_dict_response_parsed():
    """Custom service supports {"models": {id: {...}}} response bodies."""
    llm_service = CustomLLMService(
        models_url="https://example.com/models",
        completions_url="https://example.com/completions",
        token_url="https://example.com/token",
        client_id="client",
        client_secret="secret",
    )
    # Bypass real token negotiation; the header content is irrelevant here.
    llm_service._get_headers = Mock(return_value={"Authorization": "Bearer token"})  # pylint: disable=protected-access

    fake_response = Mock()
    fake_response.raise_for_status.return_value = None
    fake_response.json.return_value = {
        "models": {
            "Meta-Llama4-Maverick": {
                "name": "Meta-Llama4-Maverick",
                "display_name": "Maverick (Llama 4)",
            },
            "gpt-4o": {"display_name": "GPT-4o"},
        }
    }

    with patch("ai_eval.llm_services.requests.get", return_value=fake_response):
        models = llm_service.get_available_models()

    assert models == ["Meta-Llama4-Maverick", "gpt-4o"]


@pytest.mark.parametrize(
"backend_config, expected_backend_class",
[
Expand Down