Skip to content

Commit 77fa5f1

Browse files
anmourchencopybara-github
authored andcommitted
test: add the unit test for EnterpriseWebSearchTool
PiperOrigin-RevId: 780722837
1 parent e545e5a commit 77fa5f1

File tree

1 file changed

+98
-0
lines changed

1 file changed

+98
-0
lines changed
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from google.adk.agents.invocation_context import InvocationContext
16+
from google.adk.agents.sequential_agent import SequentialAgent
17+
from google.adk.models.llm_request import LlmRequest
18+
from google.adk.sessions.in_memory_session_service import InMemorySessionService
19+
from google.adk.tools.enterprise_search_tool import EnterpriseWebSearchTool
20+
from google.adk.tools.tool_context import ToolContext
21+
from google.genai import types
22+
import pytest
23+
24+
25+
async def _create_tool_context() -> ToolContext:
26+
"""Creates a ToolContext for testing."""
27+
session_service = InMemorySessionService()
28+
session = await session_service.create_session(
29+
app_name='test_app', user_id='test_user'
30+
)
31+
agent = SequentialAgent(name='test_agent')
32+
invocation_context = InvocationContext(
33+
invocation_id='invocation_id',
34+
agent=agent,
35+
session=session,
36+
session_service=session_service,
37+
)
38+
return ToolContext(invocation_context)
39+
40+
41+
@pytest.mark.asyncio
42+
@pytest.mark.parametrize(
43+
'model_name',
44+
[
45+
'gemini-2.5-flash',
46+
'projects/test-project/locations/global/publishers/google/models/gemini-2.5-flash',
47+
],
48+
)
49+
async def test_process_llm_request_success_with_gemini_models(model_name):
50+
tool = EnterpriseWebSearchTool()
51+
llm_request = LlmRequest(
52+
model=model_name, config=types.GenerateContentConfig()
53+
)
54+
tool_context = await _create_tool_context()
55+
56+
await tool.process_llm_request(
57+
tool_context=tool_context, llm_request=llm_request
58+
)
59+
60+
assert (
61+
llm_request.config.tools[0].enterprise_web_search
62+
== types.EnterpriseWebSearch()
63+
)
64+
65+
66+
@pytest.mark.asyncio
67+
async def test_process_llm_request_failure_with_non_gemini_models():
68+
tool = EnterpriseWebSearchTool()
69+
llm_request = LlmRequest(model='gpt-4o', config=types.GenerateContentConfig())
70+
tool_context = await _create_tool_context()
71+
72+
with pytest.raises(ValueError) as exc_info:
73+
await tool.process_llm_request(
74+
tool_context=tool_context, llm_request=llm_request
75+
)
76+
assert 'is not supported for model' in str(exc_info.value)
77+
78+
79+
@pytest.mark.asyncio
80+
async def test_process_llm_request_failure_with_multiple_tools_gemini_1_models():
81+
tool = EnterpriseWebSearchTool()
82+
llm_request = LlmRequest(
83+
model='gemini-1.5-flash',
84+
config=types.GenerateContentConfig(
85+
tools=[
86+
types.Tool(google_search=types.GoogleSearch()),
87+
]
88+
),
89+
)
90+
tool_context = await _create_tool_context()
91+
92+
with pytest.raises(ValueError) as exc_info:
93+
await tool.process_llm_request(
94+
tool_context=tool_context, llm_request=llm_request
95+
)
96+
assert 'can not be used with other tools in Gemini 1.x.' in str(
97+
exc_info.value
98+
)

0 commit comments

Comments
 (0)