From ef7d425d47ed450e61eb2389f690b786e2818014 Mon Sep 17 00:00:00 2001 From: PR Bot Date: Thu, 26 Mar 2026 21:35:58 +0800 Subject: [PATCH] feat: add MiniMax as alternative LLM provider Add MiniMax M2.7 as an alternative LLM backend alongside Google GenAI, configurable via LLM_PROVIDER and MINIMAX_API_KEY environment variables. - New common/llm_config.py: centralized provider detection and model config - New common/minimax_client.py: OpenAI-compatible client with function calling, JSON generation, think-tag stripping, and code-fence removal - Modified FunctionCallResolver: routes to Google or MiniMax based on provider - Modified BaseServerExecutor: lazy Google client creation for MiniMax support - Modified catalog_agent.py: provider-aware item generation - Modified all agent files: env-based model selection via get_model() - Added openai SDK dependency - Updated README.md with MiniMax setup instructions - 38 unit tests + 4 integration tests (42 total, all passing) --- README.md | 29 ++ samples/python/pyproject.toml | 1 + .../python/src/common/base_server_executor.py | 12 +- .../src/common/function_call_resolver.py | 71 +++- samples/python/src/common/llm_config.py | 67 ++++ samples/python/src/common/minimax_client.py | 142 ++++++++ .../sub_agents/catalog_agent.py | 57 ++- .../python/src/roles/shopping_agent/agent.py | 3 +- .../payment_method_collector/agent.py | 3 +- .../shipping_address_collector/agent.py | 3 +- .../shopping_agent/subagents/shopper/agent.py | 3 +- .../tests/test_function_call_resolver.py | 181 ++++++++++ .../python/tests/test_integration_minimax.py | 134 +++++++ samples/python/tests/test_llm_config.py | 98 +++++ samples/python/tests/test_minimax_client.py | 341 ++++++++++++++++++ 15 files changed, 1107 insertions(+), 38 deletions(-) create mode 100644 samples/python/src/common/llm_config.py create mode 100644 samples/python/src/common/minimax_client.py create mode 100644 samples/python/tests/test_function_call_resolver.py create mode 100644 samples/python/tests/test_integration_minimax.py create mode 100644 samples/python/tests/test_llm_config.py create mode 100644 samples/python/tests/test_minimax_client.py diff --git a/README.md b/README.md index dfa3d4f2..07770c91 100644 --- a/README.md +++ b/README.md @@ -104,6 +104,35 @@ For either method, you can set the required credentials as environment variables export GOOGLE_APPLICATION_CREDENTIALS='/path/to/your/service-account-key.json' ``` +#### Option 3: [MiniMax](https://www.minimaxi.com/) (Alternative LLM provider) + +The samples also support [MiniMax](https://www.minimaxi.com/) as an +alternative LLM backend. MiniMax offers powerful models such as MiniMax-M2.7 +with up to 1M context window, accessible via an OpenAI-compatible API. + +1. Obtain a MiniMax API key from [MiniMax Platform](https://platform.minimaxi.com/). +2. Set the required environment variables. + + - **As environment variables:** + + ```sh + export LLM_PROVIDER='minimax' + export MINIMAX_API_KEY='your_minimax_key' + ``` + + - **In a `.env` file:** + + ```sh + LLM_PROVIDER='minimax' + MINIMAX_API_KEY='your_minimax_key' + ``` + + You can optionally override the default model (MiniMax-M2.7): + + ```sh + export LLM_MODEL='MiniMax-M2.7-highspeed' + ``` + ### How to Run a Scenario To run a specific scenario, follow the instructions in its `README.md`. It will diff --git a/samples/python/pyproject.toml b/samples/python/pyproject.toml index 54f67bed..822874ee 100644 --- a/samples/python/pyproject.toml +++ b/samples/python/pyproject.toml @@ -15,6 +15,7 @@ dependencies = [ "google-adk", "google-genai", "httpx", + "openai", "requests", "ap2" ] diff --git a/samples/python/src/common/base_server_executor.py b/samples/python/src/common/base_server_executor.py index 4b0410c5..53042325 100644 --- a/samples/python/src/common/base_server_executor.py +++ b/samples/python/src/common/base_server_executor.py @@ -37,12 +37,13 @@ from a2a.types import TextPart from a2a.utils import message from ap2.types.mandate import PAYMENT_MANDATE_DATA_KEY -from google import genai from ap2.types.mandate import PaymentMandate from common import message_utils from common import watch_log from common.a2a_extension_utils import EXTENSION_URI from common.function_call_resolver import FunctionCallResolver +from common.llm_config import LLMProvider +from common.llm_config import get_provider from common.validation import validate_payment_mandate_signature DataPartContent = dict[str, Any] @@ -68,7 +69,14 @@ def __init__( self._supported_extension_uris = {ext.uri for ext in supported_extensions} else: self._supported_extension_uris = set() - self._client = genai.Client() + + provider = get_provider() + if provider == LLMProvider.GOOGLE: + from google import genai + self._client = genai.Client() + else: + self._client = None + self._tools = tools self._tool_resolver = FunctionCallResolver( self._client, self._tools, system_prompt diff --git a/samples/python/src/common/function_call_resolver.py b/samples/python/src/common/function_call_resolver.py index 49356429..6b813c82 100644 --- a/samples/python/src/common/function_call_resolver.py +++ b/samples/python/src/common/function_call_resolver.py @@ -16,6 +16,9 @@ The FunctionCallResolver uses a LLM to determine which tool to use based on the instructions provided. + +Supports both Google GenAI and MiniMax (via OpenAI-compatible API) backends. +Set ``LLM_PROVIDER=minimax`` and ``MINIMAX_API_KEY`` to use MiniMax. """ import logging @@ -26,6 +29,10 @@ from google import genai from google.genai import types +from common.llm_config import LLMProvider +from common.llm_config import get_model +from common.llm_config import get_provider + DataPartContent = dict[str, Any] Tool = Callable[[list[DataPartContent], TaskUpdater, Task | None], Any] @@ -36,35 +43,42 @@ class FunctionCallResolver: def __init__( self, - llm_client: genai.Client, + llm_client: genai.Client | None, tools: list[Tool], instructions: str = "You are a helpful assistant.", ): """Initialization. Args: - llm_client: The LLM client. + llm_client: The LLM client. May be ``None`` when using a non-Google + provider (e.g. MiniMax). tools: The list of tools that a request can be resolved to. instructions: The instructions to guide the LLM. """ + self._provider = get_provider() + self._model = get_model() + self._tools = tools + self._instructions = instructions self._client = llm_client - function_declarations = [ - types.FunctionDeclaration( - name=tool.__name__, description=tool.__doc__ - ) - for tool in tools - ] - self._config = types.GenerateContentConfig( - system_instruction=instructions, - tools=[types.Tool(function_declarations=function_declarations)], - automatic_function_calling=types.AutomaticFunctionCallingConfig( - disable=True - ), - # Force the model to call 'any' function, instead of chatting. - tool_config=types.ToolConfig( - function_calling_config=types.FunctionCallingConfig(mode="ANY") - ), - ) + + if self._provider == LLMProvider.GOOGLE: + function_declarations = [ + types.FunctionDeclaration( + name=tool.__name__, description=tool.__doc__ + ) + for tool in tools + ] + self._config = types.GenerateContentConfig( + system_instruction=instructions, + tools=[types.Tool(function_declarations=function_declarations)], + automatic_function_calling=types.AutomaticFunctionCallingConfig( + disable=True + ), + # Force the model to call 'any' function, instead of chatting. + tool_config=types.ToolConfig( + function_calling_config=types.FunctionCallingConfig(mode="ANY") + ), + ) def determine_tool_to_use(self, prompt: str) -> str: """Determines which tool to use based on a user's prompt. @@ -79,9 +93,15 @@ def determine_tool_to_use(self, prompt: str) -> str: The name of the tool function that the model has determined should be called. If no suitable tool is found, it returns "Unknown". """ + if self._provider == LLMProvider.MINIMAX: + return self._determine_tool_minimax(prompt) + return self._determine_tool_google(prompt) + + def _determine_tool_google(self, prompt: str) -> str: + """Resolve the tool using Google GenAI.""" response = self._client.models.generate_content( - model="gemini-2.5-flash", + model=self._model, contents=prompt, config=self._config, ) @@ -98,3 +118,14 @@ def determine_tool_to_use(self, prompt: str) -> str: return part.function_call.name return "Unknown" + + def _determine_tool_minimax(self, prompt: str) -> str: + """Resolve the tool using MiniMax via OpenAI-compatible API.""" + from common.minimax_client import minimax_resolve_function_call + + return minimax_resolve_function_call( + model=self._model, + tools=self._tools, + system_prompt=self._instructions, + user_prompt=prompt, + ) diff --git a/samples/python/src/common/llm_config.py b/samples/python/src/common/llm_config.py new file mode 100644 index 00000000..80efab5d --- /dev/null +++ b/samples/python/src/common/llm_config.py @@ -0,0 +1,67 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Centralized LLM provider configuration. + +Reads the LLM_PROVIDER and LLM_MODEL environment variables to determine which +LLM backend to use. Supported providers: + +* ``google`` – Google GenAI / Gemini (default) +* ``minimax`` – MiniMax via OpenAI-compatible API +""" + +import enum +import os + + +class LLMProvider(enum.Enum): + """Supported LLM provider backends.""" + + GOOGLE = "google" + MINIMAX = "minimax" + + +# Default model names per provider. +_DEFAULT_MODELS: dict[LLMProvider, str] = { + LLMProvider.GOOGLE: "gemini-2.5-flash", + LLMProvider.MINIMAX: "MiniMax-M2.7", +} + + +def get_provider() -> LLMProvider: + """Return the configured LLM provider. + + Reads the ``LLM_PROVIDER`` environment variable (case-insensitive). + Falls back to ``LLMProvider.GOOGLE`` when unset. + """ + raw = os.environ.get("LLM_PROVIDER", "google").strip().lower() + try: + return LLMProvider(raw) + except ValueError: + raise ValueError( + f"Unsupported LLM_PROVIDER '{raw}'. " + f"Supported values: {[p.value for p in LLMProvider]}" + ) + + +def get_model() -> str: + """Return the configured model name. + + Uses ``LLM_MODEL`` when set, otherwise falls back to the default model + for the active provider. + """ + explicit = os.environ.get("LLM_MODEL", "").strip() + if explicit: + return explicit + return _DEFAULT_MODELS[get_provider()] diff --git a/samples/python/src/common/minimax_client.py b/samples/python/src/common/minimax_client.py new file mode 100644 index 00000000..88a903e2 --- /dev/null +++ b/samples/python/src/common/minimax_client.py @@ -0,0 +1,142 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MiniMax LLM client using the OpenAI-compatible API. + +Provides two helpers consumed by the AP2 agent infrastructure: + +* :func:`minimax_resolve_function_call` – picks a tool via function-calling. +* :func:`minimax_generate_json` – generates structured JSON output. + +Both talk to ``https://api.minimax.io/v1`` and require the +``MINIMAX_API_KEY`` environment variable. +""" + +import json +import logging +import os +import re +from typing import Any, Callable + +from a2a.server.tasks.task_updater import TaskUpdater +from a2a.types import Task +from openai import OpenAI + +MINIMAX_BASE_URL = "https://api.minimax.io/v1" + +DataPartContent = dict[str, Any] +Tool = Callable[[list[DataPartContent], TaskUpdater, Task | None], Any] + + +def _get_client() -> OpenAI: + """Create an OpenAI client pointed at MiniMax.""" + api_key = os.environ.get("MINIMAX_API_KEY", "") + if not api_key: + raise ValueError( + "MINIMAX_API_KEY environment variable is required when " + "LLM_PROVIDER is set to 'minimax'." + ) + return OpenAI(api_key=api_key, base_url=MINIMAX_BASE_URL) + + +def _strip_think_tags(text: str) -> str: + """Remove blocks that MiniMax M2 models may emit.""" + return re.sub(r".*?", "", text, flags=re.DOTALL).strip() + + +def _strip_code_fences(text: str) -> str: + """Remove markdown code fences (```json ... ```) from LLM output.""" + text = text.strip() + text = re.sub(r"^```(?:json)?\s*\n?", "", text) + text = re.sub(r"\n?```\s*$", "", text) + return text.strip() + + +def minimax_resolve_function_call( + model: str, + tools: list[Tool], + system_prompt: str, + user_prompt: str, +) -> str: + """Use MiniMax to pick the best tool for *user_prompt*. + + Converts the Python callables in *tools* into OpenAI-style + function-calling tool definitions and forces the model to call one. + + Returns the name of the chosen tool, or ``"Unknown"`` on failure. + """ + client = _get_client() + + openai_tools = [ + { + "type": "function", + "function": { + "name": tool.__name__, + "description": tool.__doc__ or "", + "parameters": { + "type": "object", + "properties": {}, + }, + }, + } + for tool in tools + ] + + response = client.chat.completions.create( + model=model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ], + tools=openai_tools, + tool_choice="required", + temperature=0.1, + ) + + logging.debug("\nMiniMax Determine Tool Response: %s\n", response) + + choice = response.choices[0] if response.choices else None + if choice and choice.message and choice.message.tool_calls: + return choice.message.tool_calls[0].function.name + + return "Unknown" + + +def minimax_generate_json( + model: str, + prompt: str, + system_prompt: str = "", +) -> Any: + """Ask MiniMax for a JSON response and return the parsed object. + + Uses ``response_format={"type": "json_object"}`` to ensure valid JSON. + """ + client = _get_client() + + messages: list[dict[str, str]] = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + messages.append({"role": "user", "content": prompt}) + + response = client.chat.completions.create( + model=model, + messages=messages, + response_format={"type": "json_object"}, + temperature=0.1, + ) + + raw = response.choices[0].message.content or "{}" + raw = _strip_think_tags(raw) + raw = _strip_code_fences(raw) + return json.loads(raw) diff --git a/samples/python/src/roles/merchant_agent/sub_agents/catalog_agent.py b/samples/python/src/roles/merchant_agent/sub_agents/catalog_agent.py index b40f72e0..e8b758b2 100644 --- a/samples/python/src/roles/merchant_agent/sub_agents/catalog_agent.py +++ b/samples/python/src/roles/merchant_agent/sub_agents/catalog_agent.py @@ -28,7 +28,6 @@ from a2a.types import Part from a2a.types import Task from a2a.types import TextPart -from google import genai from pydantic import ValidationError from .. import storage @@ -43,6 +42,9 @@ from ap2.types.payment_request import PaymentOptions from ap2.types.payment_request import PaymentRequest from common import message_utils +from common.llm_config import LLMProvider +from common.llm_config import get_model +from common.llm_config import get_provider from common.system_utils import DEBUG_MODE_INSTRUCTIONS @@ -52,7 +54,8 @@ async def find_items_workflow( current_task: Task | None, ) -> None: """Finds products that match the user's IntentMandate.""" - llm_client = genai.Client() + provider = get_provider() + model = get_model() intent_mandate = message_utils.parse_canonical_object( INTENT_MANDATE_DATA_KEY, data_parts, IntentMandate @@ -67,17 +70,12 @@ async def find_items_workflow( %s """ % DEBUG_MODE_INSTRUCTIONS - llm_response = llm_client.models.generate_content( - model="gemini-2.5-flash", - contents=prompt, - config={ - "response_mime_type": "application/json", - "response_schema": list[PaymentItem], - } - ) - try: - items: list[PaymentItem] = llm_response.parsed + if provider == LLMProvider.MINIMAX: + items = _generate_items_minimax(model, prompt) + else: + items = _generate_items_google(model, prompt) + try: current_time = datetime.now(timezone.utc) item_count = 0 for item in items: @@ -102,6 +100,41 @@ async def find_items_workflow( return +def _generate_items_google(model: str, prompt: str) -> list[PaymentItem]: + """Generate items using Google GenAI.""" + from google import genai + + llm_client = genai.Client() + llm_response = llm_client.models.generate_content( + model=model, + contents=prompt, + config={ + "response_mime_type": "application/json", + "response_schema": list[PaymentItem], + } + ) + return llm_response.parsed + + +def _generate_items_minimax(model: str, prompt: str) -> list[PaymentItem]: + """Generate items using MiniMax via OpenAI-compatible API.""" + from common.minimax_client import minimax_generate_json + + schema_hint = ( + "Return a JSON object with a key 'items' containing a list of exactly " + "3 PaymentItem objects. Each PaymentItem must have: " + "'label' (string, product name without branding), " + "'amount' (object with 'currency' string e.g. 'USD' and 'value' float)." + ) + full_prompt = f"{prompt}\n\n{schema_hint}" + raw = minimax_generate_json(model=model, prompt=full_prompt) + + items_data = raw.get("items", raw) if isinstance(raw, dict) else raw + if isinstance(items_data, dict): + items_data = [items_data] + return [PaymentItem.model_validate(item) for item in items_data] + + async def _create_and_add_cart_mandate_artifact( item: PaymentItem, item_count: int, diff --git a/samples/python/src/roles/shopping_agent/agent.py b/samples/python/src/roles/shopping_agent/agent.py index 37c91a31..e40e35e7 100644 --- a/samples/python/src/roles/shopping_agent/agent.py +++ b/samples/python/src/roles/shopping_agent/agent.py @@ -26,13 +26,14 @@ from .subagents.payment_method_collector.agent import payment_method_collector from .subagents.shipping_address_collector.agent import shipping_address_collector from .subagents.shopper.agent import shopper +from common.llm_config import get_model from common.retrying_llm_agent import RetryingLlmAgent from common.system_utils import DEBUG_MODE_INSTRUCTIONS root_agent = RetryingLlmAgent( max_retries=5, - model="gemini-2.5-flash", + model=get_model(), name="root_agent", instruction=""" You are a shopping agent responsible for helping users find and diff --git a/samples/python/src/roles/shopping_agent/subagents/payment_method_collector/agent.py b/samples/python/src/roles/shopping_agent/subagents/payment_method_collector/agent.py index 446f83fe..7a460730 100644 --- a/samples/python/src/roles/shopping_agent/subagents/payment_method_collector/agent.py +++ b/samples/python/src/roles/shopping_agent/subagents/payment_method_collector/agent.py @@ -26,12 +26,13 @@ """ from . import tools +from common.llm_config import get_model from common.retrying_llm_agent import RetryingLlmAgent from common.system_utils import DEBUG_MODE_INSTRUCTIONS payment_method_collector = RetryingLlmAgent( - model="gemini-2.5-flash", + model=get_model(), name="payment_method_collector", max_retries=5, instruction=""" diff --git a/samples/python/src/roles/shopping_agent/subagents/shipping_address_collector/agent.py b/samples/python/src/roles/shopping_agent/subagents/shipping_address_collector/agent.py index 0407b2b6..90b434ac 100644 --- a/samples/python/src/roles/shopping_agent/subagents/shipping_address_collector/agent.py +++ b/samples/python/src/roles/shopping_agent/subagents/shipping_address_collector/agent.py @@ -27,11 +27,12 @@ """ from . import tools +from common.llm_config import get_model from common.retrying_llm_agent import RetryingLlmAgent from common.system_utils import DEBUG_MODE_INSTRUCTIONS shipping_address_collector = RetryingLlmAgent( - model="gemini-2.5-flash", + model=get_model(), name="shipping_address_collector", max_retries=5, instruction=""" diff --git a/samples/python/src/roles/shopping_agent/subagents/shopper/agent.py b/samples/python/src/roles/shopping_agent/subagents/shopper/agent.py index d380fac7..d90e5ab5 100644 --- a/samples/python/src/roles/shopping_agent/subagents/shopper/agent.py +++ b/samples/python/src/roles/shopping_agent/subagents/shopper/agent.py @@ -25,12 +25,13 @@ """ from . import tools +from common.llm_config import get_model from common.retrying_llm_agent import RetryingLlmAgent from common.system_utils import DEBUG_MODE_INSTRUCTIONS shopper = RetryingLlmAgent( - model="gemini-2.5-flash", + model=get_model(), name="shopper", max_retries=5, instruction=""" diff --git a/samples/python/tests/test_function_call_resolver.py b/samples/python/tests/test_function_call_resolver.py new file mode 100644 index 00000000..13aa7936 --- /dev/null +++ b/samples/python/tests/test_function_call_resolver.py @@ -0,0 +1,181 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for common.function_call_resolver with provider routing.""" + +import os +import unittest +from unittest import mock + +from common.llm_config import LLMProvider + + +class TestFunctionCallResolverProviderRouting(unittest.TestCase): + """Tests that FunctionCallResolver routes to the correct backend.""" + + @mock.patch.dict(os.environ, {"LLM_PROVIDER": "google"}, clear=False) + @mock.patch("common.function_call_resolver.get_provider", return_value=LLMProvider.GOOGLE) + @mock.patch("common.function_call_resolver.get_model", return_value="gemini-2.5-flash") + def test_google_provider_uses_genai(self, mock_model, mock_provider): + """When provider is Google, _determine_tool_google is called.""" + mock_client = mock.MagicMock() + + from common.function_call_resolver import FunctionCallResolver + + def dummy_tool(): + """A test tool.""" + pass + + resolver = FunctionCallResolver( + llm_client=mock_client, + tools=[dummy_tool], + instructions="test", + ) + + mock_part = mock.MagicMock() + mock_part.function_call.name = "dummy_tool" + mock_content = mock.MagicMock() + mock_content.parts = [mock_part] + mock_candidate = mock.MagicMock() + mock_candidate.content = mock_content + mock_response = mock.MagicMock() + mock_response.candidates = [mock_candidate] + mock_client.models.generate_content.return_value = mock_response + + result = resolver.determine_tool_to_use("test prompt") + self.assertEqual(result, "dummy_tool") + mock_client.models.generate_content.assert_called_once() + + @mock.patch.dict(os.environ, {"LLM_PROVIDER": "minimax", "MINIMAX_API_KEY": "test"}, clear=False) + @mock.patch("common.function_call_resolver.get_provider", return_value=LLMProvider.MINIMAX) + @mock.patch("common.function_call_resolver.get_model", return_value="MiniMax-M2.7") + @mock.patch("common.minimax_client.minimax_resolve_function_call", return_value="my_tool") + def test_minimax_provider_uses_openai(self, mock_resolve, mock_model, mock_provider): + """When provider is MiniMax, minimax_resolve_function_call is called.""" + from common.function_call_resolver import FunctionCallResolver + + def my_tool(): + """A test tool.""" + pass + + resolver = FunctionCallResolver( + llm_client=None, + tools=[my_tool], + instructions="test prompt", + ) + + result = resolver.determine_tool_to_use("find me items") + self.assertEqual(result, "my_tool") + mock_resolve.assert_called_once() + + @mock.patch.dict(os.environ, {"LLM_PROVIDER": "google"}, clear=False) + @mock.patch("common.function_call_resolver.get_provider", return_value=LLMProvider.GOOGLE) + @mock.patch("common.function_call_resolver.get_model", return_value="gemini-2.5-flash") + def test_google_returns_unknown_on_empty(self, mock_model, mock_provider): + """When Google returns no function call, Unknown is returned.""" + mock_client = mock.MagicMock() + + from common.function_call_resolver import FunctionCallResolver + + def dummy_tool(): + """A test tool.""" + pass + + resolver = FunctionCallResolver( + llm_client=mock_client, + tools=[dummy_tool], + instructions="test", + ) + + mock_response = mock.MagicMock() + mock_response.candidates = [] + mock_client.models.generate_content.return_value = mock_response + + result = resolver.determine_tool_to_use("test prompt") + self.assertEqual(result, "Unknown") + + @mock.patch.dict(os.environ, {"LLM_PROVIDER": "google"}, clear=False) + @mock.patch("common.function_call_resolver.get_provider", return_value=LLMProvider.GOOGLE) + @mock.patch("common.function_call_resolver.get_model", return_value="gemini-2.5-flash") + def test_google_uses_configured_model(self, mock_model, mock_provider): + """Verifies the model from get_model() is used in the API call.""" + mock_client = mock.MagicMock() + + from common.function_call_resolver import FunctionCallResolver + + def dummy_tool(): + """A test tool.""" + pass + + resolver = FunctionCallResolver( + llm_client=mock_client, + tools=[dummy_tool], + ) + + mock_part = mock.MagicMock() + mock_part.function_call.name = "dummy_tool" + mock_content = mock.MagicMock() + mock_content.parts = [mock_part] + mock_candidate = mock.MagicMock() + mock_candidate.content = mock_content + mock_response = mock.MagicMock() + mock_response.candidates = [mock_candidate] + mock_client.models.generate_content.return_value = mock_response + + resolver.determine_tool_to_use("test") + + call_args = mock_client.models.generate_content.call_args + self.assertEqual(call_args.kwargs["model"], "gemini-2.5-flash") + + +class TestFunctionCallResolverInit(unittest.TestCase): + """Tests for FunctionCallResolver initialization.""" + + @mock.patch.dict(os.environ, {"LLM_PROVIDER": "minimax"}, clear=False) + @mock.patch("common.function_call_resolver.get_provider", return_value=LLMProvider.MINIMAX) + @mock.patch("common.function_call_resolver.get_model", return_value="MiniMax-M2.7") + def test_minimax_init_skips_genai_config(self, mock_model, mock_provider): + """MiniMax provider should not create Google-specific config.""" + from common.function_call_resolver import FunctionCallResolver + + def dummy(): + """test""" + pass + + resolver = FunctionCallResolver( + llm_client=None, tools=[dummy], instructions="test" + ) + self.assertFalse(hasattr(resolver, "_config")) + + @mock.patch.dict(os.environ, {"LLM_PROVIDER": "google"}, clear=False) + @mock.patch("common.function_call_resolver.get_provider", return_value=LLMProvider.GOOGLE) + @mock.patch("common.function_call_resolver.get_model", return_value="gemini-2.5-flash") + def test_google_init_creates_config(self, mock_model, mock_provider): + """Google provider should create the function calling config.""" + mock_client = mock.MagicMock() + + from common.function_call_resolver import FunctionCallResolver + + def dummy(): + """test""" + pass + + resolver = FunctionCallResolver( + llm_client=mock_client, tools=[dummy], instructions="test" + ) + self.assertTrue(hasattr(resolver, "_config")) + + +if __name__ == "__main__": + unittest.main() diff --git a/samples/python/tests/test_integration_minimax.py b/samples/python/tests/test_integration_minimax.py new file mode 100644 index 00000000..450ce94b --- /dev/null +++ b/samples/python/tests/test_integration_minimax.py @@ -0,0 +1,134 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration tests for MiniMax LLM provider. + +These tests require a live MiniMax API key set via MINIMAX_API_KEY. +They are skipped automatically when the key is not available. +""" + +import json +import os +import unittest + +_HAS_KEY = bool(os.environ.get("MINIMAX_API_KEY")) +_SKIP_REASON = "MINIMAX_API_KEY not set" + + +@unittest.skipUnless(_HAS_KEY, _SKIP_REASON) +class TestMiniMaxFunctionCalling(unittest.TestCase): + """Integration: function-calling via MiniMax API.""" + + def test_resolve_function_call(self): + from common.minimax_client import minimax_resolve_function_call + + def search_products(): + """Search for products in the catalog based on user query.""" + pass + + def process_payment(): + """Process a payment transaction for the user's cart.""" + pass + + def get_shipping_info(): + """Get shipping options and delivery estimates.""" + pass + + result = minimax_resolve_function_call( + model="MiniMax-M2.7", + tools=[search_products, process_payment, get_shipping_info], + system_prompt="You help users shop for products.", + user_prompt="I want to find some running shoes", + ) + self.assertIn(result, ["search_products", "process_payment", "get_shipping_info"]) + self.assertEqual(result, "search_products") + + def test_resolve_payment_tool(self): + from common.minimax_client import minimax_resolve_function_call + + def search_products(): + """Search for products in the catalog.""" + pass + + def process_payment(): + """Process payment for the user's selected items and charge the credit card.""" + pass + + result = minimax_resolve_function_call( + model="MiniMax-M2.7", + tools=[search_products, process_payment], + system_prompt="You help users complete purchases.", + user_prompt="Please charge my credit card for the order", + ) + # Model should return a valid tool name (non-deterministic). + self.assertIsInstance(result, str) + self.assertTrue(len(result) > 0) + + +@unittest.skipUnless(_HAS_KEY, _SKIP_REASON) +class TestMiniMaxJsonGeneration(unittest.TestCase): + """Integration: JSON generation via MiniMax API.""" + + def test_generate_json_items(self): + from common.minimax_client import minimax_generate_json + + result = minimax_generate_json( + model="MiniMax-M2.7", + prompt=( + "Generate a JSON object with a key 'items' containing a list " + "of exactly 2 product items. Each item must have 'label' " + "(string, product name) and 'amount' (object with 'currency' " + "string and 'value' number). The items should be running shoes." + ), + ) + self.assertIsInstance(result, dict) + self.assertIn("items", result) + items = result["items"] + self.assertEqual(len(items), 2) + for item in items: + self.assertIn("label", item) + self.assertIn("amount", item) + self.assertIn("currency", item["amount"]) + self.assertIn("value", item["amount"]) + + +@unittest.skipUnless(_HAS_KEY, _SKIP_REASON) +class TestMiniMaxProviderEndToEnd(unittest.TestCase): + """Integration: end-to-end provider configuration.""" + + def test_provider_config_flow(self): + """Full flow: set env vars -> get provider -> get model -> call API.""" + with unittest.mock.patch.dict( + os.environ, + {"LLM_PROVIDER": "minimax", "MINIMAX_API_KEY": os.environ.get("MINIMAX_API_KEY", "")}, + ): + from common.llm_config import get_model, get_provider, LLMProvider + + provider = get_provider() + self.assertEqual(provider, LLMProvider.MINIMAX) + + model = get_model() + self.assertEqual(model, "MiniMax-M2.7") + + from common.minimax_client import minimax_generate_json + + result = minimax_generate_json( + model=model, + prompt='Return a JSON object with key "status" set to "ok".', + ) + self.assertEqual(result.get("status"), "ok") + + +if __name__ == "__main__": + unittest.main() diff --git a/samples/python/tests/test_llm_config.py b/samples/python/tests/test_llm_config.py new file mode 100644 index 00000000..0ab2477e --- /dev/null +++ b/samples/python/tests/test_llm_config.py @@ -0,0 +1,98 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for common.llm_config.""" + +import os +import unittest +from unittest import mock + +from common.llm_config import LLMProvider +from common.llm_config import get_model +from common.llm_config import get_provider + + +class TestGetProvider(unittest.TestCase): + """Tests for get_provider().""" + + @mock.patch.dict(os.environ, {}, clear=True) + def test_default_provider_is_google(self): + self.assertEqual(get_provider(), LLMProvider.GOOGLE) + + @mock.patch.dict(os.environ, {"LLM_PROVIDER": "google"}) + def test_explicit_google(self): + self.assertEqual(get_provider(), LLMProvider.GOOGLE) + + @mock.patch.dict(os.environ, {"LLM_PROVIDER": "minimax"}) + def test_minimax_provider(self): + self.assertEqual(get_provider(), LLMProvider.MINIMAX) + + @mock.patch.dict(os.environ, {"LLM_PROVIDER": "MINIMAX"}) + def test_case_insensitive(self): + self.assertEqual(get_provider(), LLMProvider.MINIMAX) + + @mock.patch.dict(os.environ, {"LLM_PROVIDER": " minimax "}) + def test_strips_whitespace(self): + self.assertEqual(get_provider(), LLMProvider.MINIMAX) + + @mock.patch.dict(os.environ, {"LLM_PROVIDER": "unsupported"}) + def test_unsupported_raises(self): + with self.assertRaises(ValueError) as ctx: + get_provider() + self.assertIn("unsupported", str(ctx.exception).lower()) + + +class TestGetModel(unittest.TestCase): + """Tests for get_model().""" + + @mock.patch.dict(os.environ, {}, clear=True) + def test_default_google_model(self): + self.assertEqual(get_model(), "gemini-2.5-flash") + + @mock.patch.dict(os.environ, {"LLM_PROVIDER": "minimax"}, clear=True) + def test_default_minimax_model(self): + self.assertEqual(get_model(), "MiniMax-M2.7") + + @mock.patch.dict( + os.environ, + {"LLM_PROVIDER": "minimax", "LLM_MODEL": "MiniMax-M2.7-highspeed"}, + ) + def test_explicit_model_overrides(self): + self.assertEqual(get_model(), "MiniMax-M2.7-highspeed") + + @mock.patch.dict(os.environ, {"LLM_MODEL": "gemini-2.5-pro"}) + def test_explicit_model_with_google(self): + self.assertEqual(get_model(), "gemini-2.5-pro") + + @mock.patch.dict(os.environ, {"LLM_MODEL": " "}, clear=True) + def test_blank_model_falls_back(self): + self.assertEqual(get_model(), "gemini-2.5-flash") + + +class TestLLMProviderEnum(unittest.TestCase): + """Tests for LLMProvider enum.""" + + def test_google_value(self): + self.assertEqual(LLMProvider.GOOGLE.value, "google") + + def test_minimax_value(self): + self.assertEqual(LLMProvider.MINIMAX.value, "minimax") + + def test_from_string(self): + self.assertEqual(LLMProvider("google"), LLMProvider.GOOGLE) + self.assertEqual(LLMProvider("minimax"), LLMProvider.MINIMAX) + + +if __name__ == "__main__": + unittest.main() diff --git a/samples/python/tests/test_minimax_client.py b/samples/python/tests/test_minimax_client.py new file mode 100644 index 00000000..b3666ed9 --- /dev/null +++ b/samples/python/tests/test_minimax_client.py @@ -0,0 +1,341 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for common.minimax_client.""" + +import json +import os +import unittest +from unittest import mock + +from common.minimax_client import _strip_think_tags +from common.minimax_client import MINIMAX_BASE_URL + + +class TestStripThinkTags(unittest.TestCase): + """Tests for the _strip_think_tags helper.""" + + def test_no_tags(self): + self.assertEqual(_strip_think_tags("Hello world"), "Hello world") + + def test_single_tag(self): + self.assertEqual( + _strip_think_tags("reasoningAnswer"), + "Answer", + ) + + def test_multiline_tag(self): + text = "\nline1\nline2\n\nResult" + self.assertEqual(_strip_think_tags(text), "Result") + + def test_multiple_tags(self): + text = "aXbY" + self.assertEqual(_strip_think_tags(text), "XY") + + def test_empty_tag(self): + self.assertEqual(_strip_think_tags("clean"), "clean") + + def test_strips_whitespace(self): + self.assertEqual( + _strip_think_tags(" x result "), + "result", + ) + + +class TestGetClient(unittest.TestCase): + """Tests for MiniMax client creation.""" + + @mock.patch.dict(os.environ, {}, clear=True) + def test_missing_api_key_raises(self): + from common.minimax_client import _get_client + + with self.assertRaises(ValueError) as ctx: + _get_client() + self.assertIn("MINIMAX_API_KEY", str(ctx.exception)) + + @mock.patch.dict(os.environ, {"MINIMAX_API_KEY": "test-key"}) + def test_client_created_with_base_url(self): + from common.minimax_client import _get_client + + client = _get_client() + self.assertEqual(str(client.base_url).rstrip("/"), MINIMAX_BASE_URL) + + +class TestMinimaxResolveFunction(unittest.TestCase): + """Tests for minimax_resolve_function_call.""" + + @mock.patch("common.minimax_client._get_client") + def test_returns_tool_name(self, mock_get_client): + """Verify the function extracts the tool name from the response.""" + mock_client = mock.MagicMock() + mock_get_client.return_value = mock_client + + mock_tool_call = mock.MagicMock() + mock_tool_call.function.name = "find_items_workflow" + + mock_message = mock.MagicMock() + mock_message.tool_calls = [mock_tool_call] + + mock_choice = mock.MagicMock() + mock_choice.message = mock_message + + mock_response = mock.MagicMock() + mock_response.choices = [mock_choice] + + mock_client.chat.completions.create.return_value = mock_response + + def find_items_workflow(): + """Finds items.""" + pass + + def process_payment(): + """Processes payment.""" + pass + + from common.minimax_client import minimax_resolve_function_call + + result = minimax_resolve_function_call( + model="MiniMax-M2.7", + tools=[find_items_workflow, process_payment], + system_prompt="You are helpful.", + user_prompt="Find me shoes", + ) + self.assertEqual(result, "find_items_workflow") + + call_args = mock_client.chat.completions.create.call_args + self.assertEqual(call_args.kwargs["model"], "MiniMax-M2.7") + self.assertEqual(call_args.kwargs["tool_choice"], "required") + tools_sent = call_args.kwargs["tools"] + self.assertEqual(len(tools_sent), 2) + self.assertEqual(tools_sent[0]["function"]["name"], "find_items_workflow") + self.assertEqual(tools_sent[1]["function"]["name"], "process_payment") + + @mock.patch("common.minimax_client._get_client") + def test_returns_unknown_on_no_tool_calls(self, mock_get_client): + mock_client = mock.MagicMock() + mock_get_client.return_value = mock_client + + mock_message = mock.MagicMock() + mock_message.tool_calls = None + + mock_choice = mock.MagicMock() + mock_choice.message = mock_message + + mock_response = mock.MagicMock() + mock_response.choices = [mock_choice] + + mock_client.chat.completions.create.return_value = mock_response + + from common.minimax_client import minimax_resolve_function_call + + result = minimax_resolve_function_call( + model="MiniMax-M2.7", + tools=[lambda: None], + system_prompt="test", + user_prompt="test", + ) + self.assertEqual(result, "Unknown") + + @mock.patch("common.minimax_client._get_client") + def test_returns_unknown_on_empty_choices(self, mock_get_client): + mock_client = mock.MagicMock() + mock_get_client.return_value = mock_client + + mock_response = mock.MagicMock() + mock_response.choices = [] + + mock_client.chat.completions.create.return_value = mock_response + + from common.minimax_client import minimax_resolve_function_call + + result = minimax_resolve_function_call( + model="MiniMax-M2.7", + tools=[lambda: None], + system_prompt="test", + user_prompt="test", + ) + self.assertEqual(result, "Unknown") + + +class TestMinimaxGenerateJson(unittest.TestCase): + """Tests for minimax_generate_json.""" + + @mock.patch("common.minimax_client._get_client") + def test_returns_parsed_json(self, mock_get_client): + mock_client = mock.MagicMock() + mock_get_client.return_value = mock_client + + expected = {"items": [{"label": "Shoes", "amount": {"currency": "USD", "value": 99.99}}]} + mock_message = mock.MagicMock() + mock_message.content = json.dumps(expected) + + mock_choice = mock.MagicMock() + mock_choice.message = mock_message + + mock_response = mock.MagicMock() + mock_response.choices = [mock_choice] + + mock_client.chat.completions.create.return_value = mock_response + + from common.minimax_client import minimax_generate_json + + result = minimax_generate_json( + model="MiniMax-M2.7", + prompt="Generate items", + ) + self.assertEqual(result, expected) + + call_args = mock_client.chat.completions.create.call_args + self.assertEqual(call_args.kwargs["response_format"], {"type": "json_object"}) + + @mock.patch("common.minimax_client._get_client") + def test_strips_think_tags(self, mock_get_client): + mock_client = mock.MagicMock() + mock_get_client.return_value = mock_client + + mock_message = mock.MagicMock() + mock_message.content = 'reasoning here{"result": true}' + + mock_choice = mock.MagicMock() + mock_choice.message = mock_message + + mock_response = mock.MagicMock() + mock_response.choices = [mock_choice] + + mock_client.chat.completions.create.return_value = mock_response + + from common.minimax_client import minimax_generate_json + + result = minimax_generate_json( + model="MiniMax-M2.7", + prompt="test", + ) + self.assertEqual(result, {"result": True}) + + @mock.patch("common.minimax_client._get_client") + def test_system_prompt_included(self, mock_get_client): + mock_client = mock.MagicMock() + mock_get_client.return_value = mock_client + + mock_message = mock.MagicMock() + mock_message.content = "{}" + + mock_choice = mock.MagicMock() + mock_choice.message = mock_message + + mock_response = mock.MagicMock() + mock_response.choices = [mock_choice] + + mock_client.chat.completions.create.return_value = mock_response + + from common.minimax_client import minimax_generate_json + + minimax_generate_json( + model="MiniMax-M2.7", + prompt="test", + system_prompt="You are a catalog agent.", + ) + + call_args = mock_client.chat.completions.create.call_args + messages = call_args.kwargs["messages"] + self.assertEqual(messages[0]["role"], "system") + self.assertEqual(messages[0]["content"], "You are a catalog agent.") + + @mock.patch("common.minimax_client._get_client") + def test_no_system_prompt(self, mock_get_client): + mock_client = mock.MagicMock() + mock_get_client.return_value = mock_client + + mock_message = mock.MagicMock() + mock_message.content = "{}" + + mock_choice = mock.MagicMock() + mock_choice.message = mock_message + + mock_response = mock.MagicMock() + mock_response.choices = [mock_choice] + + mock_client.chat.completions.create.return_value = mock_response + + from common.minimax_client import minimax_generate_json + + minimax_generate_json( + model="MiniMax-M2.7", + prompt="test", + ) + + call_args = mock_client.chat.completions.create.call_args + messages = call_args.kwargs["messages"] + self.assertEqual(len(messages), 1) + self.assertEqual(messages[0]["role"], "user") + + @mock.patch("common.minimax_client._get_client") + def test_empty_content_returns_empty_dict(self, mock_get_client): + mock_client = mock.MagicMock() + mock_get_client.return_value = mock_client + + mock_message = mock.MagicMock() + mock_message.content = None + + mock_choice = mock.MagicMock() + mock_choice.message = mock_message + + mock_response = mock.MagicMock() + mock_response.choices = [mock_choice] + + mock_client.chat.completions.create.return_value = mock_response + + from common.minimax_client import minimax_generate_json + + result = minimax_generate_json( + model="MiniMax-M2.7", + prompt="test", + ) + self.assertEqual(result, {}) + + @mock.patch("common.minimax_client._get_client") + def test_temperature_clamped(self, mock_get_client): + mock_client = mock.MagicMock() + mock_get_client.return_value = mock_client + + mock_message = mock.MagicMock() + mock_message.content = "{}" + + mock_choice = mock.MagicMock() + mock_choice.message = mock_message + + mock_response = mock.MagicMock() + mock_response.choices = [mock_choice] + + mock_client.chat.completions.create.return_value = mock_response + + from common.minimax_client import minimax_generate_json + + minimax_generate_json(model="MiniMax-M2.7", prompt="test") + + call_args = mock_client.chat.completions.create.call_args + temp = call_args.kwargs["temperature"] + self.assertGreater(temp, 0.0) + self.assertLessEqual(temp, 1.0) + + +class TestConstants(unittest.TestCase): + """Tests for module-level constants.""" + + def test_base_url(self): + self.assertEqual(MINIMAX_BASE_URL, "https://api.minimax.io/v1") + + +if __name__ == "__main__": + unittest.main()