diff --git a/docs/mcp.md b/docs/mcp.md index eef61a047..d49791975 100644 --- a/docs/mcp.md +++ b/docs/mcp.md @@ -169,9 +169,9 @@ agent = Agent( ## Caching -Every time an Agent runs, it calls `list_tools()` on the MCP server. This can be a latency hit, especially if the server is a remote server. To automatically cache the list of tools, you can pass `cache_tools_list=True` to [`MCPServerStdio`][agents.mcp.server.MCPServerStdio], [`MCPServerSse`][agents.mcp.server.MCPServerSse], and [`MCPServerStreamableHttp`][agents.mcp.server.MCPServerStreamableHttp]. You should only do this if you're certain the tool list will not change. +Every time an Agent runs, it calls `list_tools()` and `list_prompts()` on the MCP server. This can be a latency hit, especially if the server is a remote server. To automatically cache the list of tools and prompts, you can pass `cache_tools_list=True` and `cache_prompts_list=True` to [`MCPServerStdio`][agents.mcp.server.MCPServerStdio], [`MCPServerSse`][agents.mcp.server.MCPServerSse], and [`MCPServerStreamableHttp`][agents.mcp.server.MCPServerStreamableHttp]. You should only do this if you're certain the tools and the prompts lists will not change. -If you want to invalidate the cache, you can call `invalidate_tools_cache()` on the servers. +If you want to invalidate the cache, you can call `invalidate_tools_cache()` and `invalidate_prompts_cache()` on the servers. ## End-to-end examples diff --git a/examples/mcp/caching/README.md b/examples/mcp/caching/README.md new file mode 100644 index 000000000..667cc5714 --- /dev/null +++ b/examples/mcp/caching/README.md @@ -0,0 +1,13 @@ +# Caching Example + +This example shows how to integrate tools and prompts caching using a Streamable HTTP server in [server.py](server.py). + +Run the example via: + +``` +uv run python examples/mcp/caching/main.py +``` + +## Details + +The example uses the `MCPServerStreamableHttp` class from `agents.mcp`. 
The server runs in a sub-process at `http://localhost:8000/mcp`. diff --git a/examples/mcp/caching/main.py b/examples/mcp/caching/main.py new file mode 100644 index 000000000..b292b1910 --- /dev/null +++ b/examples/mcp/caching/main.py @@ -0,0 +1,83 @@ +import asyncio +import os +import shutil +import subprocess +import time +from typing import Any + +from agents import gen_trace_id, trace +from agents.mcp import MCPServerStreamableHttp + + +async def run(mcp_server: MCPServerStreamableHttp): + print("Cached tools before invoking list_tools") + print(mcp_server._tools_list) + + print("Cached tool names after invoking list_tools") + await mcp_server.list_tools() + cached_tools_list = mcp_server._tools_list + if cached_tools_list: + for tool in cached_tools_list: + print(f"name: {tool.name}") + + else: + print("Failed to cache list_tools") + + print("Cached prompts before invoking list_prompts") + print(mcp_server._prompts_list) + + print("Cached prompts after invoking list_prompts") + await mcp_server.list_prompts() + cached_prompts_list = mcp_server._prompts_list + if cached_prompts_list: + for prompt in cached_prompts_list.prompts: + print(f"name: {prompt.name}") + else: + print("Failed to cache list_prompts") + +async def main(): + async with MCPServerStreamableHttp( + name="Streamable HTTP Python Server", + cache_tools_list=True, + cache_prompts_list=True, + params={ + "url": "http://localhost:8000/mcp", + }, + ) as server: + trace_id = gen_trace_id() + with trace(workflow_name="Caching Example", trace_id=trace_id): + print(f"View trace: https://platform.openai.com/traces/trace?trace_id={trace_id}\n") + await run(server) + + +if __name__ == "__main__": + # Let's make sure the user has uv installed + if not shutil.which("uv"): + raise RuntimeError( + "uv is not installed. Please install it: https://docs.astral.sh/uv/getting-started/installation/" + ) + + # We'll run the Streamable HTTP server in a subprocess. 
Usually this would be a remote server, but for this + # demo, we'll run it locally at http://localhost:8000/mcp + process: subprocess.Popen[Any] | None = None + try: + this_dir = os.path.dirname(os.path.abspath(__file__)) + server_file = os.path.join(this_dir, "server.py") + + print("Starting Streamable HTTP server at http://localhost:8000/mcp ...") + + # Run `uv run server.py` to start the Streamable HTTP server + process = subprocess.Popen(["uv", "run", server_file]) + # Give it 3 seconds to start + time.sleep(3) + + print("Streamable HTTP server started. Running example...\n\n") + except Exception as e: + print(f"Error starting Streamable HTTP server: {e}") + exit(1) + + try: + asyncio.run(main()) + finally: + if process: + process.terminate() diff --git a/examples/mcp/caching/server.py b/examples/mcp/caching/server.py new file mode 100644 index 000000000..0a031dc8d --- /dev/null +++ b/examples/mcp/caching/server.py @@ -0,0 +1,37 @@ +import random + +import requests +from mcp.server.fastmcp import FastMCP + +# Create server +mcp = FastMCP("Echo Server") + + +@mcp.tool() +def add(a: int, b: int) -> int: + """Add two numbers""" + print(f"[debug-server] add({a}, {b})") + return a + b + + +@mcp.tool() +def get_secret_word() -> str: + print("[debug-server] get_secret_word()") + return random.choice(["apple", "banana", "cherry"]) + + +@mcp.tool() +def get_current_weather(city: str) -> str: + print(f"[debug-server] get_current_weather({city})") + + endpoint = "https://wttr.in" + response = requests.get(f"{endpoint}/{city}") + return response.text + +@mcp.prompt() +def system_prompt() -> str: + return "Use the tools to answer the questions." 
+ + +if __name__ == "__main__": + mcp.run(transport="streamable-http") diff --git a/src/agents/mcp/server.py b/src/agents/mcp/server.py index 91a9274fc..05c3404b5 100644 --- a/src/agents/mcp/server.py +++ b/src/agents/mcp/server.py @@ -84,17 +84,26 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC): def __init__( self, cache_tools_list: bool, + cache_prompts_list: bool, client_session_timeout_seconds: float | None, tool_filter: ToolFilter = None, ): """ Args: cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be - cached and only fetched from the server once. If `False`, the tools list will be - fetched from the server on each call to `list_tools()`. The cache can be invalidated - by calling `invalidate_tools_cache()`. You should set this to `True` if you know the - server will not change its tools list, because it can drastically improve latency - (by avoiding a round-trip to the server every time). + cached and only fetched from the server once. If `False`, the tools list will be + fetched from the server on each call to `list_tools()`. The cache can be invalidated + by calling `invalidate_tools_cache()`. You should set this to `True` if you know the + server will not change its tools list, because it can drastically improve latency + (by avoiding a round-trip to the server every time). + + cache_prompts_list: Whether to cache the prompts list. If `True`, the prompts list + will be cached and only fetched from the server once. If `False`, the prompts + list will be fetched from the server on each call to `list_prompts()`. + The cache can be invalidated by calling `invalidate_prompts_cache()`. + You should set this to `True` if you know the server will not change + its prompts list, because it can drastically improve latency + (by avoiding a round-trip to the server every time). client_session_timeout_seconds: the read timeout passed to the MCP ClientSession. tool_filter: The tool filter to use for filtering tools. 
@@ -103,13 +112,16 @@ def __init__( self.exit_stack: AsyncExitStack = AsyncExitStack() self._cleanup_lock: asyncio.Lock = asyncio.Lock() self.cache_tools_list = cache_tools_list + self.cache_prompts_list = cache_prompts_list self.server_initialize_result: InitializeResult | None = None self.client_session_timeout_seconds = client_session_timeout_seconds - # The cache is always dirty at startup, so that we fetch tools at least once - self._cache_dirty = True + # The cache is always dirty at startup, so that we fetch tools and prompts at least once + self._cache_dirty_tools = True self._tools_list: list[MCPTool] | None = None + self._cache_dirty_prompts = True + self._prompts_list: ListPromptsResult | None = None self.tool_filter = tool_filter @@ -213,7 +225,11 @@ async def __aexit__(self, exc_type, exc_value, traceback): def invalidate_tools_cache(self): """Invalidate the tools cache.""" - self._cache_dirty = True + self._cache_dirty_tools = True + + def invalidate_prompts_cache(self): + """Invalidate the prompts cache.""" + self._cache_dirty_prompts = True async def connect(self): """Connect to the server.""" @@ -251,11 +267,11 @@ async def list_tools( raise UserError("Server not initialized. Make sure you call `connect()` first.") # Return from cache if caching is enabled, we have tools, and the cache is not dirty - if self.cache_tools_list and not self._cache_dirty and self._tools_list: + if self.cache_tools_list and not self._cache_dirty_tools and self._tools_list: tools = self._tools_list else: # Reset the cache dirty to False - self._cache_dirty = False + self._cache_dirty_tools = False # Fetch the tools from the server self._tools_list = (await self.session.list_tools()).tools tools = self._tools_list @@ -282,7 +298,16 @@ async def list_prompts( if not self.session: raise UserError("Server not initialized. 
Make sure you call `connect()` first.") - return await self.session.list_prompts() + if self.cache_prompts_list and not self._cache_dirty_prompts and self._prompts_list: + prompts = self._prompts_list + else: + # Reset the cache dirty to False + self._cache_dirty_prompts = False + # Fetch the prompts from the server + self._prompts_list = await self.session.list_prompts() + prompts = self._prompts_list + + return prompts async def get_prompt( self, name: str, arguments: dict[str, Any] | None = None @@ -343,6 +368,7 @@ def __init__( self, params: MCPServerStdioParams, cache_tools_list: bool = False, + cache_prompts_list: bool = False, name: str | None = None, client_session_timeout_seconds: float | None = 5, tool_filter: ToolFilter = None, @@ -354,21 +380,32 @@ def __init__( start the server, the args to pass to the command, the environment variables to set for the server, the working directory to use when spawning the process, and the text encoding used when sending/receiving messages to the server. + cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be cached and only fetched from the server once. If `False`, the tools list will be fetched from the server on each call to `list_tools()`. The cache can be invalidated by calling `invalidate_tools_cache()`. You should set this to `True` if you know the server will not change its tools list, because it can drastically improve latency (by avoiding a round-trip to the server every time). + + cache_prompts_list: Whether to cache the prompts list. If `True`, the prompts list + will be cached and only fetched from the server once. If `False`, the prompts + list will be fetched from the server on each call to `list_prompts()`. + The cache can be invalidated by calling `invalidate_prompts_cache()`. + You should set this to `True` if you know the server will not change + its prompts list, because it can drastically improve latency + (by avoiding a round-trip to the server every time). 
+ name: A readable name for the server. If not provided, we'll create one from the command. client_session_timeout_seconds: the read timeout passed to the MCP ClientSession. tool_filter: The tool filter to use for filtering tools. """ super().__init__( - cache_tools_list, - client_session_timeout_seconds, - tool_filter, + cache_tools_list=cache_tools_list, + cache_prompts_list=cache_prompts_list, + client_session_timeout_seconds=client_session_timeout_seconds, + tool_filter=tool_filter, ) self.params = StdioServerParameters( @@ -426,6 +463,7 @@ def __init__( self, params: MCPServerSseParams, cache_tools_list: bool = False, + cache_prompts_list: bool = False, name: str | None = None, client_session_timeout_seconds: float | None = 5, tool_filter: ToolFilter = None, @@ -444,6 +482,14 @@ def __init__( if you know the server will not change its tools list, because it can drastically improve latency (by avoiding a round-trip to the server every time). + cache_prompts_list: Whether to cache the prompts list. If `True`, the prompts list + will be cached and only fetched from the server once. If `False`, the prompts + list will be fetched from the server on each call to `list_prompts()`. + The cache can be invalidated by calling `invalidate_prompts_cache()`. + You should set this to `True` if you know the server will not change + its prompts list, because it can drastically improve latency + (by avoiding a round-trip to the server every time). + name: A readable name for the server. If not provided, we'll create one from the URL. @@ -451,9 +497,10 @@ def __init__( tool_filter: The tool filter to use for filtering tools. 
""" super().__init__( - cache_tools_list, - client_session_timeout_seconds, - tool_filter, + cache_tools_list=cache_tools_list, + cache_prompts_list=cache_prompts_list, + client_session_timeout_seconds=client_session_timeout_seconds, + tool_filter=tool_filter, ) self.params = params @@ -511,6 +558,7 @@ def __init__( self, params: MCPServerStreamableHttpParams, cache_tools_list: bool = False, + cache_prompts_list: bool = False, name: str | None = None, client_session_timeout_seconds: float | None = 5, tool_filter: ToolFilter = None, @@ -530,6 +578,14 @@ def __init__( if you know the server will not change its tools list, because it can drastically improve latency (by avoiding a round-trip to the server every time). + cache_prompts_list: Whether to cache the prompts list. If `True`, the prompts list + will be cached and only fetched from the server once. If `False`, the prompts + list will be fetched from the server on each call to `list_prompts()`. + The cache can be invalidated by calling `invalidate_prompts_cache()`. + You should set this to `True` if you know the server will not change + its prompts list, because it can drastically improve latency + (by avoiding a round-trip to the server every time). + name: A readable name for the server. If not provided, we'll create one from the URL. @@ -537,9 +593,10 @@ def __init__( tool_filter: The tool filter to use for filtering tools. 
""" super().__init__( - cache_tools_list, - client_session_timeout_seconds, - tool_filter, + cache_tools_list=cache_tools_list, + cache_prompts_list=cache_prompts_list, + client_session_timeout_seconds=client_session_timeout_seconds, + tool_filter=tool_filter, ) self.params = params diff --git a/tests/mcp/helpers.py b/tests/mcp/helpers.py index 31d43c228..b6b259e2d 100644 --- a/tests/mcp/helpers.py +++ b/tests/mcp/helpers.py @@ -38,6 +38,7 @@ def __init__(self, tool_filter: ToolFilter, server_name: str): # Initialize parent class properly to avoid type errors super().__init__( cache_tools_list=False, + cache_prompts_list=False, client_session_timeout_seconds=None, tool_filter=tool_filter, ) diff --git a/tests/mcp/test_caching.py b/tests/mcp/test_caching.py index f31cdf951..15d5a6992 100644 --- a/tests/mcp/test_caching.py +++ b/tests/mcp/test_caching.py @@ -1,7 +1,7 @@ from unittest.mock import AsyncMock, patch import pytest -from mcp.types import ListToolsResult, Tool as MCPTool +from mcp.types import ListPromptsResult, ListToolsResult, Prompt, Tool as MCPTool from agents import Agent from agents.mcp import MCPServerStdio @@ -14,7 +14,7 @@ @patch("mcp.client.stdio.stdio_client", return_value=DummyStreamsContextManager()) @patch("mcp.client.session.ClientSession.initialize", new_callable=AsyncMock, return_value=None) @patch("mcp.client.session.ClientSession.list_tools") -async def test_server_caching_works( +async def test_server_caching_tools_works( mock_list_tools: AsyncMock, mock_initialize: AsyncMock, mock_stdio_client ): """Test that if we turn caching on, the list of tools is cached and not fetched from the server @@ -61,3 +61,56 @@ async def test_server_caching_works( # Without invalidating the cache, calling list_tools() again should return the cached value result_tools = await server.list_tools(run_context, agent) assert result_tools == tools + +@pytest.mark.asyncio +@patch("mcp.client.stdio.stdio_client", return_value=DummyStreamsContextManager()) 
+@patch("mcp.client.session.ClientSession.initialize", new_callable=AsyncMock, return_value=None) +@patch("mcp.client.session.ClientSession.list_prompts") +async def test_server_caching_prompts_works( + mock_list_prompts: AsyncMock, mock_initialize: AsyncMock, mock_stdio_client +): + """Test that if we turn caching on, the list of prompts is cached and not fetched + from the server on each call to `list_prompts()`. + """ + server = MCPServerStdio( + params={ + "command": tee, + }, + cache_prompts_list=True, + ) + + prompts = [ + Prompt(name="prompt1"), + Prompt(name="prompt2"), + ] + + list_prompts = ListPromptsResult(prompts=prompts) + mock_list_prompts.return_value = list_prompts + + async with server: + + # Call list_prompts() multiple times + result_prompts = await server.list_prompts() + assert result_prompts == list_prompts + + assert mock_list_prompts.call_count == 1, "list_prompts() should have been called once" + + # Call list_prompts() again, should return the cached value + result_prompts = await server.list_prompts() + assert result_prompts == list_prompts + + assert mock_list_prompts.call_count == 1, ("list_prompts() " + "should not have been called again") + + # Invalidate the cache and call list_prompts() again + server.invalidate_prompts_cache() + result_prompts = await server.list_prompts() + assert result_prompts == list_prompts + + assert mock_list_prompts.call_count == 2, ("list_prompts() " + "should be called again") + + # Without invalidating the cache, calling list_prompts() + # again should return the cached value + result_prompts = await server.list_prompts() + assert result_prompts == list_prompts diff --git a/tests/mcp/test_server_errors.py b/tests/mcp/test_server_errors.py index 9e0455115..03f3bcdb9 100644 --- a/tests/mcp/test_server_errors.py +++ b/tests/mcp/test_server_errors.py @@ -8,7 +8,11 @@ class CrashingClientSessionServer(_MCPServerWithClientSession): def __init__(self): - super().__init__(cache_tools_list=False, 
client_session_timeout_seconds=5) + super().__init__( + cache_tools_list=False, + cache_prompts_list=False, + client_session_timeout_seconds=5 + ) self.cleanup_called = False def create_streams(self):