diff --git a/ai_eval/llm_services.py b/ai_eval/llm_services.py index 84dcdee..59734d4 100644 --- a/ai_eval/llm_services.py +++ b/ai_eval/llm_services.py @@ -136,6 +136,45 @@ def _get_headers(self): self._ensure_token() return {'Authorization': f'Bearer {self._access_token}'} + @staticmethod + def _parse_models_field(raw_models): + """ + Parse the value of a top-level "models" field into a list of model ids/names. + """ + if isinstance(raw_models, list): + return [str(m) for m in raw_models] + if isinstance(raw_models, str): + return [str(raw_models)] + if isinstance(raw_models, dict): + parsed_models = [] + for key, val in raw_models.items(): + if isinstance(val, dict): + candidate = val.get("name") or val.get("id") or key + elif isinstance(val, str) and val.strip(): + candidate = val + else: + candidate = key + parsed_models.append(str(candidate)) + return parsed_models + return [] + + @classmethod + def _parse_models_response(cls, data): + """ + Parse a models endpoint JSON response into a list of model ids/names. + """ + if isinstance(data, dict): + if "models" in data: + return cls._parse_models_field(data["models"]) + if isinstance(data.get("data"), list): + return [str(m.get("id", str(m))) for m in data["data"]] + return [] + if isinstance(data, list): + return [str(m) for m in data] + if isinstance(data, str): + return [str(data)] + return [] + def get_response( self, model, @@ -196,20 +235,7 @@ def get_available_models(self): response.raise_for_status() data = response.json() - models = [] - - if isinstance(data, dict): - if "models" in data: - if isinstance(data["models"], list): - models = [str(m) for m in data["models"]] - elif isinstance(data["models"], str): - models = [str(data["models"])] - elif "data" in data and isinstance(data["data"], list): - models = [str(m.get("id", str(m))) for m in data["data"]] - elif isinstance(data, list): - models = [str(m) for m in data] - elif isinstance(data, str): - models = [str(data)] + models = self._parse_models_response(data) # Filter out non-string model names and empty strings models = [m for m in models if isinstance(m, str) and m.strip()] diff --git a/ai_eval/tests/test_ai_eval.py b/ai_eval/tests/test_ai_eval.py index 313bcfa..e5666b8 100644 --- a/ai_eval/tests/test_ai_eval.py +++ b/ai_eval/tests/test_ai_eval.py @@ -19,6 +19,7 @@ ) from ai_eval.base import AIEvalXBlock from ai_eval.supported_models import SupportedModels +from ai_eval.llm_services import CustomLLMService from ai_eval.backends.factory import BackendFactory from ai_eval.backends.judge0 import Judge0Backend from ai_eval.backends.custom import CustomServiceBackend @@ -220,6 +221,33 @@ def test_multiagent_block_evaluator_response(): assert resp["is_evaluator"] +def test_custom_llm_models_dict_response_parsed(): + """Custom service supports {"models": {id: {...}}} response bodies.""" + service = CustomLLMService( + models_url="https://example.com/models", + completions_url="https://example.com/completions", + token_url="https://example.com/token", + client_id="client", + client_secret="secret", + ) + service._get_headers = Mock(return_value={"Authorization": "Bearer token"}) # pylint: disable=protected-access + + mocked_response = Mock() + mocked_response.json.return_value = { + "models": { + "Meta-Llama4-Maverick": { + "name": "Meta-Llama4-Maverick", + "display_name": "Maverick (Llama 4)", + }, + "gpt-4o": {"display_name": "GPT-4o"}, + } + } + mocked_response.raise_for_status.return_value = None + + with patch("ai_eval.llm_services.requests.get", return_value=mocked_response): + assert service.get_available_models() == ["Meta-Llama4-Maverick", "gpt-4o"] + + @pytest.mark.parametrize( "backend_config, expected_backend_class", [