diff --git a/src/agents/mcp/server.py b/src/agents/mcp/server.py
index 66332549c..61ad16abf 100644
--- a/src/agents/mcp/server.py
+++ b/src/agents/mcp/server.py
@@ -209,6 +209,37 @@ async def _apply_dynamic_tool_filter(
 
         return filtered_tools
 
+    async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult:
+        """Invoke a tool on the server.
+
+        Args:
+            tool_name: The name of the tool to call. This can be either the prefixed name
+                (server_name_tool_name) or the original tool name.
+            arguments: The arguments to pass to the tool.
+
+        Raises:
+            UserError: If the server session has not been initialized via `connect()`.
+        """
+        if not self.session:
+            raise UserError("Server not initialized. Make sure you call `connect()` first.")
+
+        # Original (unprefixed) names of the tools fetched by the last `list_tools()` call
+        known_tools = getattr(self, "_tools_list", None) or []
+        original_names = {getattr(tool, "original_name", None) for tool in known_tools}
+
+        prefix = f"{self.name}_"
+        if tool_name in original_names:
+            # Already an original name. It must be forwarded untouched even when it happens
+            # to start with the server prefix (e.g. server "git" exposing tool "git_status").
+            original_tool_name = tool_name
+        elif tool_name.startswith(prefix):
+            # Remove the server name prefix and the underscore
+            original_tool_name = tool_name[len(prefix) :]
+        else:
+            original_tool_name = tool_name
+
+        return await self.session.call_tool(original_tool_name, arguments)
+
     @abc.abstractmethod
     def create_streams(
         self,
@@ -275,8 +306,16 @@ async def list_tools(
             # Reset the cache dirty to False
             self._cache_dirty = False
             # Fetch the tools from the server
-            self._tools_list = (await self.session.list_tools()).tools
-            tools = self._tools_list
+            tools = (await self.session.list_tools()).tools
+            # Add server name prefix to each tool's name to ensure global uniqueness
+            for tool in tools:
+                # Store original name for actual tool calls
+                tool.original_name = tool.name  # type: ignore[attr-defined]
+                # Prefix tool name with server name using underscore separator
+                tool.name = f"{self.name}_{tool.name}"
+            # NOTE(review): the tool_filter applied below now sees the prefixed names; confirm
+            # that existing tool filters are expected to use prefixed names.
+            self._tools_list = tools
 
         # Filter tools based on tool_filter
         filtered_tools = tools
@@ -286,13 +325,6 @@ async def list_tools(
             filtered_tools = await self._apply_tool_filter(filtered_tools, run_context, agent)
         return filtered_tools
 
-    async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult:
-        """Invoke a tool on the server."""
-        if not self.session:
-            raise UserError("Server not initialized. Make sure you call `connect()` first.")
-
-        return await self.session.call_tool(tool_name, arguments)
-
     async def list_prompts(
         self,
     ) -> ListPromptsResult:
diff --git a/src/agents/mcp/util.py b/src/agents/mcp/util.py
index 6b2b4679f..4c19dbaaa 100644
--- a/src/agents/mcp/util.py
+++ b/src/agents/mcp/util.py
@@ -178,21 +178,27 @@ async def invoke_mcp_tool(
                     f"Invalid JSON input for tool {tool.name}: {input_json}"
                 ) from e
 
+        # Use original tool name for server call (strip server prefix if present)
+        original_name = getattr(tool, "original_name", tool.name)
+
         if _debug.DONT_LOG_TOOL_DATA:
-            logger.debug(f"Invoking MCP tool {tool.name}")
+            logger.debug(f"Invoking MCP tool {tool.name} (original: {original_name})")
         else:
-            logger.debug(f"Invoking MCP tool {tool.name} with input {input_json}")
+            logger.debug(
+                f"Invoking MCP tool {tool.name} (original: {original_name}) "
+                f"with input {input_json}"
+            )
 
         try:
-            result = await server.call_tool(tool.name, json_data)
+            result = await server.call_tool(original_name, json_data)
         except Exception as e:
             logger.error(f"Error invoking MCP tool {tool.name}: {e}")
             raise AgentsException(f"Error invoking MCP tool {tool.name}: {e}") from e
 
         if _debug.DONT_LOG_TOOL_DATA:
-            logger.debug(f"MCP tool {tool.name} completed.")
+            logger.debug(f"MCP tool {tool.name} (original: {original_name}) completed.")
         else:
-            logger.debug(f"MCP tool {tool.name} returned {result}")
+            logger.debug(f"MCP tool {tool.name} (original: {original_name}) returned {result}")
 
         # The MCP tool result is a list of content items, whereas OpenAI tool outputs are a single
         # string. We'll try to convert.
diff --git a/src/agents/models/chatcmpl_stream_handler.py b/src/agents/models/chatcmpl_stream_handler.py
index eab4c291b..475d4f93f 100644
--- a/src/agents/models/chatcmpl_stream_handler.py
+++ b/src/agents/models/chatcmpl_stream_handler.py
@@ -288,10 +288,11 @@ async def handle_stream(
                     function_call = state.function_calls[tc_delta.index]
 
                     # Start streaming as soon as we have function name and call_id
-                    if (not state.function_call_streaming[tc_delta.index] and
-                        function_call.name and
-                        function_call.call_id):
-
+                    if (
+                        not state.function_call_streaming[tc_delta.index]
+                        and function_call.name
+                        and function_call.call_id
+                    ):
                         # Calculate the output index for this function call
                         function_call_starting_index = 0
                         if state.reasoning_content_index_and_output:
@@ -308,9 +309,9 @@ async def handle_stream(
 
                         # Mark this function call as streaming and store its output index
                         state.function_call_streaming[tc_delta.index] = True
-                        state.function_call_output_idx[
-                            tc_delta.index
-                        ] = function_call_starting_index
+                        state.function_call_output_idx[tc_delta.index] = (
+                            function_call_starting_index
+                        )
 
                         # Send initial function call added event
                         yield ResponseOutputItemAddedEvent(
@@ -327,10 +328,11 @@ async def handle_stream(
                         )
 
                     # Stream arguments if we've started streaming this function call
-                    if (state.function_call_streaming.get(tc_delta.index, False) and
-                        tc_function and
-                        tc_function.arguments):
-
+                    if (
+                        state.function_call_streaming.get(tc_delta.index, False)
+                        and tc_function
+                        and tc_function.arguments
+                    ):
                         output_index = state.function_call_output_idx[tc_delta.index]
                         yield ResponseFunctionCallArgumentsDeltaEvent(
                             delta=tc_function.arguments,
diff --git a/tests/mcp/test_tool_name_conflicts.py b/tests/mcp/test_tool_name_conflicts.py
new file mode 100644
index 000000000..40442ec97
--- /dev/null
+++ b/tests/mcp/test_tool_name_conflicts.py
@@ -0,0 +1,383 @@
+"""
+Tests for MCP tool name conflict resolution.
+
+This test file specifically tests the functionality that resolves tool name conflicts
+when multiple MCP servers have tools with the same name by adding server name prefixes.
+"""
+
+from typing import Any, Union
+
+import pytest
+from mcp import Tool as MCPTool
+from mcp.types import CallToolResult, TextContent
+
+from agents import Agent
+from agents.agent import AgentBase
+from agents.exceptions import UserError
+from agents.mcp import MCPUtil
+from agents.mcp.server import MCPServer
+from agents.run_context import RunContextWrapper
+
+
+def create_test_agent(name: str = "test_agent") -> Agent:
+    """Create a test agent for tool name conflict tests."""
+    return Agent(name=name, instructions="Test agent")
+
+
+def create_test_context() -> RunContextWrapper:
+    """Create a test run context for tool name conflict tests."""
+    return RunContextWrapper(context=None)
+
+
+class MockMCPServer(MCPServer):
+    """Mock MCP server for testing tool name prefixing functionality."""
+
+    def __init__(self, name: str, tools: list[tuple[str, dict]]):
+        super().__init__()
+        self._name = name
+        # Raw tool names received by call_tool(), in call order
+        self.called_tool_names: list[str] = []
+        self._tools = [
+            MCPTool(name=tool_name, description=f"Tool {tool_name}", inputSchema=schema)
+            for tool_name, schema in tools
+        ]
+
+    @property
+    def name(self) -> str:
+        return self._name
+
+    async def connect(self):
+        pass
+
+    async def cleanup(self):
+        pass
+
+    async def list_tools(
+        self,
+        run_context: Union[RunContextWrapper[Any], None] = None,
+        agent: Union["AgentBase", None] = None,
+    ) -> list[MCPTool]:
+        """Return tools with server name prefix to simulate the actual server behavior."""
+        tools = []
+        for tool in self._tools:
+            # Simulate the server adding prefix behavior
+            tool_copy = MCPTool(
+                name=tool.name, description=tool.description, inputSchema=tool.inputSchema
+            )
+            # Store original name
+            tool_copy.original_name = tool.name  # type: ignore[attr-defined]
+            # Add server name prefix
+            tool_copy.name = f"{self.name}_{tool.name}"
+            tools.append(tool_copy)
+        return tools
+
+    async def call_tool(
+        self, tool_name: str, arguments: Union[dict[str, Any], None]
+    ) -> CallToolResult:
+        """Mock tool invocation."""
+        # Record the raw name so tests can assert what the caller actually passed
+        self.called_tool_names.append(tool_name)
+        prefix = f"{self.name}_"
+        original_names = {tool.name for tool in self._tools}
+        if tool_name not in original_names and tool_name.startswith(prefix):
+            # Slice by prefix length; splitting on "_" breaks for names like "my_server"
+            original_tool_name = tool_name[len(prefix) :]
+        else:
+            original_tool_name = tool_name
+        return CallToolResult(
+            content=[
+                TextContent(type="text", text=f"Result from {self.name}.{original_tool_name}")
+            ]
+        )
+
+    async def list_prompts(self):
+        return {"prompts": []}
+
+    async def get_prompt(self, name: str, arguments: Union[dict[str, Any], None] = None):
+        return {"messages": []}
+
+
+@pytest.mark.asyncio
+async def test_tool_name_prefixing_single_server():
+    """Test tool name prefixing functionality for a single server."""
+    server = MockMCPServer("server1", [("run", {}), ("echo", {})])
+
+    run_context = create_test_context()
+    agent = create_test_agent()
+
+    tools = await server.list_tools(run_context, agent)
+
+    # Verify tool names have correct prefixes
+    assert len(tools) == 2
+    tool_names = [tool.name for tool in tools]
+    assert "server1_run" in tool_names
+    assert "server1_echo" in tool_names
+
+    # Verify original names are preserved
+    for tool in tools:
+        assert hasattr(tool, "original_name")
+        if tool.name == "server1_run":
+            assert tool.original_name == "run"
+        elif tool.name == "server1_echo":
+            assert tool.original_name == "echo"
+
+
+@pytest.mark.asyncio
+async def test_tool_name_prefixing_multiple_servers():
+    """Test tool name prefixing functionality with multiple servers having conflicting names."""
+    server1 = MockMCPServer("server1", [("run", {}), ("echo", {})])
+    server2 = MockMCPServer("server2", [("run", {}), ("list", {})])
+
+    run_context = create_test_context()
+    agent = create_test_agent()
+
+    # Get all tools
+    tools1 = await server1.list_tools(run_context, agent)
+    tools2 = await server2.list_tools(run_context, agent)
+
+    all_tools = tools1 + tools2
+
+    # Verify no duplicate tool names
+    tool_names = [tool.name for tool in all_tools]
+    assert len(tool_names) == len(set(tool_names)), "Tool names should be unique"
+
+    # Verify specific tool names
+    expected_names = ["server1_run", "server1_echo", "server2_run", "server2_list"]
+    assert set(tool_names) == set(expected_names)
+
+    # Verify original names are correctly preserved
+    for tool in all_tools:
+        assert hasattr(tool, "original_name")
+        if tool.name == "server1_run":
+            assert tool.original_name == "run"
+        elif tool.name == "server2_run":
+            assert tool.original_name == "run"
+
+
+@pytest.mark.asyncio
+async def test_mcp_util_get_all_function_tools_no_conflicts():
+    """Test MCPUtil.get_all_function_tools with no conflicting tool names."""
+    server1 = MockMCPServer("server1", [("tool1", {}), ("tool2", {})])
+    server2 = MockMCPServer("server2", [("tool3", {}), ("tool4", {})])
+
+    run_context = create_test_context()
+    agent = create_test_agent()
+
+    # Since tool names are now prefixed, there should be no conflicts
+    function_tools = await MCPUtil.get_all_function_tools(
+        [server1, server2],
+        convert_schemas_to_strict=False,
+        run_context=run_context,
+        agent=agent,
+    )
+
+    assert len(function_tools) == 4
+    tool_names = [tool.name for tool in function_tools]
+    assert "server1_tool1" in tool_names
+    assert "server1_tool2" in tool_names
+    assert "server2_tool3" in tool_names
+    assert "server2_tool4" in tool_names
+
+
+@pytest.mark.asyncio
+async def test_mcp_util_get_all_function_tools_with_resolved_conflicts():
+    """Test MCPUtil.get_all_function_tools with originally conflicting tool names."""
+    # Create two servers with same tool names
+    server1 = MockMCPServer("server1", [("run", {}), ("echo", {})])
+    server2 = MockMCPServer("server2", [("run", {}), ("list", {})])
+
+    run_context = create_test_context()
+    agent = create_test_agent()
+
+    # Since tool names are now prefixed, this should not raise an exception
+    function_tools = await MCPUtil.get_all_function_tools(
+        [server1, server2],
+        convert_schemas_to_strict=False,
+        run_context=run_context,
+        agent=agent,
+    )
+
+    assert len(function_tools) == 4
+    tool_names = [tool.name for tool in function_tools]
+    assert "server1_run" in tool_names
+    assert "server1_echo" in tool_names
+    assert "server2_run" in tool_names
+    assert "server2_list" in tool_names
+
+
+class LegacyMockMCPServer(MCPServer):
+    """Mock MCP server that simulates legacy behavior without name prefixing."""
+
+    def __init__(self, name: str, tools: list[tuple[str, dict]]):
+        super().__init__()
+        self._name = name
+        self._tools = [
+            MCPTool(name=tool_name, description=f"Tool {tool_name}", inputSchema=schema)
+            for tool_name, schema in tools
+        ]
+
+    @property
+    def name(self) -> str:
+        return self._name
+
+    async def connect(self):
+        pass
+
+    async def cleanup(self):
+        pass
+
+    async def list_tools(
+        self,
+        run_context: Union[RunContextWrapper[Any], None] = None,
+        agent: Union["AgentBase", None] = None,
+    ) -> list[MCPTool]:
+        """Return tools without prefixes (simulating legacy behavior)."""
+        return self._tools.copy()
+
+    async def call_tool(
+        self, tool_name: str, arguments: Union[dict[str, Any], None]
+    ) -> CallToolResult:
+        return CallToolResult(
+            content=[TextContent(type="text", text=f"Result from {self.name}.{tool_name}")]
+        )
+
+    async def list_prompts(self):
+        return {"prompts": []}
+
+    async def get_prompt(self, name: str, arguments: Union[dict[str, Any], None] = None):
+        return {"messages": []}
+
+
+@pytest.mark.asyncio
+async def test_legacy_behavior_with_conflicts():
+    """Test legacy behavior where conflicting tool names should raise UserError."""
+    # Use servers without prefixing functionality
+    server1 = LegacyMockMCPServer("server1", [("run", {}), ("echo", {})])
+    server2 = LegacyMockMCPServer("server2", [("run", {}), ("list", {})])
+
+    run_context = create_test_context()
+    agent = create_test_agent()
+
+    # Should raise UserError due to tool name conflicts
+    with pytest.raises(UserError, match="Duplicate tool names found"):
+        await MCPUtil.get_all_function_tools(
+            [server1, server2],
+            convert_schemas_to_strict=False,
+            run_context=run_context,
+            agent=agent,
+        )
+
+
+@pytest.mark.asyncio
+async def test_tool_invocation_uses_original_name():
+    """Test that tool invocation uses the original name rather than the prefixed name."""
+    server = MockMCPServer("server1", [("run", {})])
+
+    run_context = create_test_context()
+    agent = create_test_agent()
+
+    # Get tools
+    tools = await server.list_tools(run_context, agent)
+    tool = tools[0]
+
+    # Verify tool has both prefixed name and original name
+    assert tool.name == "server1_run"
+    assert tool.original_name == "run"
+
+    # Create function tool via MCPUtil
+    function_tool = MCPUtil.to_function_tool(tool, server, convert_schemas_to_strict=False)
+
+    # Verify function tool uses prefixed name
+    assert function_tool.name == "server1_run"
+
+    # Simulate tool invocation
+    result = await MCPUtil.invoke_mcp_tool(server, tool, run_context, "{}")
+
+    # Verify invocation succeeds
+    assert "Result from server1.run" in result
+
+    # Verify the server received the original (unprefixed) tool name
+    assert server.called_tool_names == ["run"]
+
+
+@pytest.mark.asyncio
+async def test_empty_server_name():
+    """Test handling of empty server names."""
+    server = MockMCPServer("", [("run", {})])
+
+    run_context = create_test_context()
+    agent = create_test_agent()
+
+    tools = await server.list_tools(run_context, agent)
+
+    # Verify even with empty server name, prefix is added to avoid empty names
+    assert len(tools) == 1
+    tool = tools[0]
+    assert tool.name == "_run"  # empty prefix + "_" + tool name
+    assert tool.original_name == "run"
+
+
+@pytest.mark.asyncio
+async def test_special_characters_in_server_name():
+    """Test handling of server names with special characters."""
+    server = MockMCPServer("server-1.test", [("run", {})])
+
+    run_context = create_test_context()
+    agent = create_test_agent()
+
+    tools = await server.list_tools(run_context, agent)
+
+    # Verify special characters in server names are handled correctly
+    assert len(tools) == 1
+    tool = tools[0]
+    assert tool.name == "server-1.test_run"
+    assert tool.original_name == "run"
+
+
+@pytest.mark.asyncio
+async def test_tool_description_preserved():
+    """Test that tool descriptions are preserved after adding name prefixes."""
+    server = MockMCPServer("server1", [("run", {})])
+
+    run_context = create_test_context()
+    agent = create_test_agent()
+
+    tools = await server.list_tools(run_context, agent)
+    tool = tools[0]
+
+    # MockMCPServer generates the description as "Tool <original name>"; prefixing the
+    # name must not rewrite it
+    assert tool.description == "Tool run"
+    assert tool.name == "server1_run"
+    assert tool.original_name == "run"
+
+
+@pytest.mark.asyncio
+async def test_server_name_with_underscore():
+    """Test that prefixed names resolve correctly when the server name has underscores."""
+    server = MockMCPServer("my_server", [("run", {})])
+
+    result = await server.call_tool("my_server_run", None)
+    content = result.content[0]
+
+    assert isinstance(content, TextContent)
+    assert content.text == "Result from my_server.run"
+
+
+@pytest.mark.asyncio
+async def test_original_name_starting_with_server_prefix():
+    """Test that an original tool name that starts with the server prefix is not stripped."""
+    server = MockMCPServer("git", [("git_status", {})])
+
+    run_context = create_test_context()
+    agent = create_test_agent()
+
+    tools = await server.list_tools(run_context, agent)
+    tool = tools[0]
+    assert tool.name == "git_git_status"
+    assert tool.original_name == "git_status"
+
+    result = await MCPUtil.invoke_mcp_tool(server, tool, run_context, "{}")
+
+    assert server.called_tool_names == ["git_status"]
+    assert "Result from git.git_status" in result