diff --git a/.github/workflows/validate-pr-labels.yaml b/.github/workflows/validate-pr-labels.yaml deleted file mode 100644 index f34634ca0e..0000000000 --- a/.github/workflows/validate-pr-labels.yaml +++ /dev/null @@ -1,11 +0,0 @@ -ο»Ώname: Validate pull request labels - -on: - pull_request: - types: [labeled, unlabeled, synchronize] - -jobs: - check-label: - uses: OutSystems/rd.github-reusable-workflows/.github/workflows/validate-pr-labels.yaml@59bd1315cfd3558f93edff0a994430dab78812fa #v2.0.7 - with: - validate-semVer: false diff --git a/.github/workflows/validate-pr-title.yaml b/.github/workflows/validate-pr-title.yaml deleted file mode 100644 index ae7ff0afac..0000000000 --- a/.github/workflows/validate-pr-title.yaml +++ /dev/null @@ -1,11 +0,0 @@ -ο»Ώname: Validate pull request title - -on: - pull_request: - types: [opened, edited, synchronize, reopened] - -jobs: - build: - uses: OutSystems/rd.github-reusable-workflows/.github/workflows/validate-pr-title.yaml@59bd1315cfd3558f93edff0a994430dab78812fa - with: - validate-semVer: false diff --git a/examples/clients/simple-private-gateway/README.md b/examples/clients/simple-private-gateway/README.md new file mode 100644 index 0000000000..06f5e5ad21 --- /dev/null +++ b/examples/clients/simple-private-gateway/README.md @@ -0,0 +1,316 @@ +# Simple Private Gateway Example + +A demonstration of how to use the MCP Python SDK as a private gateway with optional API key authentication over streamable HTTP or SSE transport with custom extensions for private gateway connectivity (SNI hostname support). + +## Features + +- Optional API Key authentication (Bearer token or custom header) +- Supports both StreamableHTTP and SSE transports +- Custom extensions for private gateway (SNI hostname) - **Both transports** +- Can combine authentication + extensions (for authenticated private gateway) +- Interactive command-line interface +- Tool calling + +## Installation + +```bash +cd examples/clients/simple-private-gateway +uv sync --reinstall +``` + +## Usage + +### 1. Start an MCP server + +You can use any MCP server. For example: + +```bash +# Example without authentication - StreamableHTTP transport +cd examples/servers/simple-tool +uv run mcp-simple-tool --transport streamable-http --port 8081 + +# Or with SSE transport +cd examples/servers/simple-tool +uv run mcp-simple-tool --transport sse --port 8081 +``` + +### 2. Run the client + +The client will interactively prompt you for: + +- Server URL (or press Enter to configure port/protocol/hostname separately) + - If you provide a full URL, it will be used directly + - If you press Enter, you'll be prompted for: port, protocol, and hostname (for SNI) +- Transport type (streamable-http or sse) +- Authentication type (none or API Key) +- For API Key: API key value, header name, and format (Bearer or direct) + +```bash +# Run the client interactively +uv run mcp-simple-private-gateway +``` + +Follow the prompts to configure your connection. + +### 3. Use the interactive interface + +The client provides several commands: + +- `list` - List available tools +- `call [args]` - Call a tool with optional JSON arguments +- `quit` - Exit + +## Examples + +### Example 1: Private Gateway without Authentication (StreamableHTTP) + +```markdown +πŸš€ Simple Private Gateway + +πŸ“ Server Configuration +================================================== +Server URL [https://localhost:8081]: +Server port [8081]: 8081 +Protocol [https]: https +Server hostname [mcp.deepwiki.com]: mcp.deepwiki.com + +Transport type: + 1. streamable-http (default) + 2. sse +Select transport [1]: 1 + +Authentication: + 1. No authentication (default) + 2. API Key authentication +Select authentication [1]: 1 +================================================== + +πŸ”— Connecting to: https://localhost:8081/mcp +πŸ“‘ Server hostname: mcp.deepwiki.com +πŸš€ Transport type: streamable-http +πŸ” Authentication: None + +πŸ“‘ Opening StreamableHTTP transport connection with extensions... +🀝 Initializing MCP session... +⚑ Starting session initialization... +✨ Session initialization complete! + +βœ… Connected to MCP server at https://localhost:8081/mcp +Session ID: abc123... + +🎯 Interactive MCP Client (Private Gateway) +Commands: + list - List available tools + call [args] - Call a tool + quit - Exit the client + +mcp> list +πŸ“‹ Available tools: +1. echo + Description: Echo back the input text + +mcp> call echo {"text": "Hello, world!"} +πŸ”§ Tool 'echo' result: +Hello, world! + +mcp> quit +πŸ‘‹ Goodbye! +``` + +### Example 2: SSE Transport without Authentication + +```markdown +πŸš€ Simple Private Gateway + +πŸ“ Server Configuration +================================================== +Server URL [https://localhost:8081]: +Server port [8081]: 8081 +Protocol [https]: https +Server hostname [mcp.deepwiki.com]: mcp.deepwiki.com + +Transport type: + 1. streamable-http (default) + 2. sse +Select transport [1]: 2 + +Authentication: + 1. No authentication (default) + 2. API Key authentication +Select authentication [1]: 1 +================================================== + +πŸ”— Connecting to: https://localhost:8081/sse +πŸ“‘ Server hostname: mcp.deepwiki.com +πŸš€ Transport type: sse +πŸ” Authentication: None + +πŸ“‘ Opening SSE transport connection with extensions... +🀝 Initializing MCP session... +⚑ Starting session initialization... +✨ Session initialization complete! + +βœ… Connected to MCP server at https://localhost:8081/sse + +🎯 Interactive MCP Client (Private Gateway) +Commands: + list - List available tools + call [args] - Call a tool + quit - Exit the client + +mcp> list +πŸ“‹ Available tools: +1. echo + Description: Echo back the input text + +mcp> quit +πŸ‘‹ Goodbye! +``` + +### Example 3: API Key Authentication with Bearer Token (StreamableHTTP) + +```markdown +πŸš€ Simple Private Gateway + +πŸ“ Server Configuration +================================================== +Server URL [https://localhost:8081]: +Server port [8081]: 8081 +Protocol [https]: https +Server hostname [mcp.deepwiki.com]: api.mcp.example.com + +Transport type: + 1. streamable-http (default) + 2. sse +Select transport [1]: 1 + +Authentication: + 1. No authentication (default) + 2. API Key authentication +Select authentication [1]: 2 +Enter API key: sk-1234567890abcdef + +API Key format: + 1. Bearer token (Authorization: Bearer ) - default + 2. Custom header with key only +Select format [1]: 1 +================================================== + +πŸ”— Connecting to: https://localhost:8081/mcp +πŸ“‘ Server hostname: api.mcp.example.com +πŸš€ Transport type: streamable-http +πŸ” Authentication: API Key +πŸ”‘ Header: Authorization +🎯 Format: Bearer token + +πŸ”‘ Setting up API key authentication (header: Authorization)... +πŸ“‘ Opening StreamableHTTP transport connection with extensions and apikey auth... +🀝 Initializing MCP session... +⚑ Starting session initialization... +✨ Session initialization complete! + +βœ… Connected to MCP server at https://localhost:8081/mcp +Session ID: key123... + +🎯 Interactive MCP Client (Private Gateway with APIKEY) +Commands: + list - List available tools + call [args] - Call a tool + quit - Exit the client + +mcp> list +πŸ“‹ Available tools: +1. secure-data + Description: Access secure data with API key + +mcp> quit +πŸ‘‹ Goodbye! +``` + +### Example 4: API Key Authentication with Custom Header (SSE) + +```markdown +πŸš€ Simple Private Gateway + +πŸ“ Server Configuration +================================================== +Server URL [https://localhost:8081]: +Server port [8081]: 8082 +Protocol [https]: https +Server hostname [mcp.deepwiki.com]: custom.mcp.example.com + +Transport type: + 1. streamable-http (default) + 2. sse +Select transport [1]: 2 + +Authentication: + 1. No authentication (default) + 2. API Key authentication +Select authentication [1]: 2 +Enter API key: my-secret-api-key-123 + +API Key format: + 1. Bearer token (Authorization: Bearer ) - default + 2. Custom header with key only +Select format [1]: 2 +Enter header name [X-API-Key]: X-API-Key +================================================== + +πŸ”— Connecting to: https://localhost:8082/sse +πŸ“‘ Server hostname: custom.mcp.example.com +πŸš€ Transport type: sse +πŸ” Authentication: API Key +πŸ”‘ Header: X-API-Key +🎯 Format: Direct key + +πŸ”‘ Setting up API key authentication (header: X-API-Key)... +πŸ“‘ Opening SSE transport connection with extensions and apikey auth... +🀝 Initializing MCP session... +⚑ Starting session initialization... +✨ Session initialization complete! + +βœ… Connected to MCP server at https://localhost:8082/sse + +🎯 Interactive MCP Client (Private Gateway with APIKEY) +Commands: + list - List available tools + call [args] - Call a tool + quit - Exit the client + +mcp> list +πŸ“‹ Available tools: +1. api-tool + Description: Tool requiring custom API key header + +mcp> quit +πŸ‘‹ Goodbye! +``` + +## Configuration + +The client uses interactive prompts for configuration. You'll be asked to provide: + +- **Server URL**: The full URL of your MCP server (default: 1>) + - If you provide a URL, it will be used directly + - If you press Enter (empty), you'll be prompted for individual components: + - **Server port**: The port where your MCP server is running (default: 8081) + - **Protocol**: The protocol to use (default: https) + - **Server hostname**: The hostname for SNI (Server Name Indication) used in private gateway setup (default: mcp.deepwiki.com) +- **Transport type**: Choose between `streamable-http` or `sse` (default: streamable-http) + - StreamableHTTP servers typically use `/mcp` endpoint + - SSE servers typically use `/sse` endpoint +- **Authentication**: Choose authentication method (default: no authentication) + - **None**: No authentication + - **API Key**: API key-based authentication + - **API Key**: Your API key value + - **Format**: Bearer token (Authorization: Bearer ) or custom header + - **Header name**: Custom header name if not using Bearer format (default: X-API-Key) + +## Use Cases + +This client supports multiple scenarios: + +1. **Private Gateway without Auth**: Use custom SNI hostname for HTTPS private gateway connectivity +2. **Private Gateway with API Key**: Use API key authentication (Bearer or custom header) with private gateway +3. **Both Transports**: Works with both StreamableHTTP and SSE transports in all scenarios diff --git a/examples/clients/simple-private-gateway/mcp_simple_private_gateway/__init__.py b/examples/clients/simple-private-gateway/mcp_simple_private_gateway/__init__.py new file mode 100644 index 0000000000..e6d8480a97 --- /dev/null +++ b/examples/clients/simple-private-gateway/mcp_simple_private_gateway/__init__.py @@ -0,0 +1 @@ +"""Simple MCP streamable private gateway client example without authentication.""" diff --git a/examples/clients/simple-private-gateway/mcp_simple_private_gateway/main.py b/examples/clients/simple-private-gateway/mcp_simple_private_gateway/main.py new file mode 100644 index 0000000000..07716396ed --- /dev/null +++ b/examples/clients/simple-private-gateway/mcp_simple_private_gateway/main.py @@ -0,0 +1,378 @@ +#!/usr/bin/env python3 +""" +Simple MCP private gateway client example with optional API key authentication. + +This client connects to an MCP server using streamable HTTP or SSE transport +with custom extensions for private gateway connectivity (SNI hostname support) +and optional API key authentication. + +""" + +import asyncio +from collections.abc import Callable +from datetime import timedelta +from typing import Any, cast + +import httpx +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from mcp.client.session import ClientSession +from mcp.client.sse import sse_client +from mcp.client.streamable_http import streamablehttp_client +from mcp.shared.message import SessionMessage + + +class APIKeyAuth(httpx.Auth): + """Custom httpx Auth class for API key authentication.""" + + def __init__(self, api_key: str, header_name: str = "Authorization", use_bearer: bool = True): + self.api_key = api_key + self.header_name = header_name + self.use_bearer = use_bearer + + def auth_flow(self, request: httpx.Request): + """Add API key to request headers.""" + if self.use_bearer: + request.headers[self.header_name] = f"Bearer {self.api_key}" + else: + request.headers[self.header_name] = self.api_key + yield request + + +class SimplePrivateGateway: + """Simple MCP private gateway client supporting StreamableHTTP and SSE transports. + + This client demonstrates how to use custom extensions (e.g., SNI hostname) for + private gateway connectivity with both transport types, with optional API key authentication. + """ + + def __init__( + self, + server_url: str, + server_hostname: str | None, + transport_type: str = "streamable-http", + use_api_key: bool = False, + api_key: str | None = None, + api_key_header: str = "Authorization", + use_bearer: bool = True, + ): + self.server_url = server_url + self.server_hostname = server_hostname + self.transport_type = transport_type + self.use_api_key = use_api_key + self.api_key = api_key + self.api_key_header = api_key_header + self.use_bearer = use_bearer + self.session: ClientSession | None = None + + async def connect(self): + """Connect to the MCP server.""" + print(f"πŸ”— Attempting to connect to {self.server_url}...") + + try: + # Set up authentication if needed + auth = None + + if self.use_api_key: + if not self.api_key: + raise ValueError("API key is required for API key authentication") + print(f"πŸ”‘ Setting up API key authentication (header: {self.api_key_header})...") + auth = APIKeyAuth( + api_key=self.api_key, + header_name=self.api_key_header, + use_bearer=self.use_bearer, + ) + + if self.server_hostname: + headers = {"Host": self.server_hostname} + extensions = {"sni_hostname": self.server_hostname} + else: + headers = None + extensions = None + + # Create transport based on transport type + if self.transport_type == "sse": + if auth: + print("πŸ“‘ Opening SSE transport connection with extensions and API key auth...") + else: + print("πŸ“‘ Opening SSE transport connection with extensions...") + # SSE transport with custom extensions for private gateway + + async with sse_client( + url=self.server_url, + headers=headers, + extensions=extensions, + auth=auth, + timeout=60, + ) as (read_stream, write_stream): + await self._run_session(read_stream, write_stream, None) + + else: + if auth: + print("πŸ“‘ Opening StreamableHTTP transport connection with extensions and API key auth...") + else: + print("πŸ“‘ Opening StreamableHTTP transport connection with extensions...") + # Note: terminate_on_close=False prevents SSL handshake failures during exit + # Some servers may not handle session termination gracefully over SSL + + async with streamablehttp_client( + url=self.server_url, + headers=headers, + extensions=extensions, + auth=auth, + timeout=timedelta(seconds=60), + ) as (read_stream, write_stream, get_session_id): + await self._run_session(read_stream, write_stream, get_session_id) + + except Exception as e: + print(f"❌ Failed to connect: {e}") + import traceback + + traceback.print_exc() + + async def _run_session( + self, + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], + write_stream: MemoryObjectSendStream[SessionMessage], + get_session_id: Callable[[], str | None] | None, + ): + """Run the MCP session with the given streams.""" + print("🀝 Initializing MCP session...") + async with ClientSession(read_stream, write_stream) as session: + self.session = session + print("⚑ Starting session initialization...") + await session.initialize() + print("✨ Session initialization complete!") + + print(f"\nβœ… Connected to MCP server at {self.server_url}") + if get_session_id: + session_id = get_session_id() + if session_id: + print(f"Session ID: {session_id}") + + # Run interactive loop + await self.interactive_loop() + + async def list_tools(self): + """List available tools from the server.""" + if not self.session: + print("❌ Not connected to server") + return + + try: + result = await self.session.list_tools() + if hasattr(result, "tools") and result.tools: + print("\nπŸ“‹ Available tools:") + for i, tool in enumerate(result.tools, 1): + print(f"{i}. {tool.name}") + if tool.description: + print(f" Description: {tool.description}") + print() + else: + print("No tools available") + except Exception as e: + print(f"❌ Failed to list tools: {e}") + + async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None = None): + """Call a specific tool.""" + if not self.session: + print("❌ Not connected to server") + return + + try: + result = await self.session.call_tool(tool_name, arguments or {}) + print(f"\nπŸ”§ Tool '{tool_name}' result:") + if hasattr(result, "content"): + for content in result.content: + if content.type == "text": + print(content.text) + else: + print(content) + else: + print(result) + except Exception as e: + print(f"❌ Failed to call tool '{tool_name}': {e}") + + async def interactive_loop(self): + """Run interactive command loop.""" + auth_status = " with API Key" if self.use_api_key else "" + print(f"\n🎯 Interactive MCP Client (Private Gateway{auth_status})") + print("Commands:") + print(" list - List available tools") + print(" call [args] - Call a tool") + print(" quit - Exit the client") + print() + + while True: + try: + command = input("mcp> ").strip() + + if not command: + continue + + if command == "quit": + print("πŸ‘‹ Goodbye!") + break + + elif command == "list": + await self.list_tools() + + elif command.startswith("call "): + parts = command.split(maxsplit=2) + tool_name = parts[1] if len(parts) > 1 else "" + + if not tool_name: + print("❌ Please specify a tool name") + continue + + # Parse arguments (simple JSON-like format) + arguments: dict[str, Any] | None = None + if len(parts) > 2: + import json + + try: + parsed = json.loads(parts[2]) + if isinstance(parsed, dict): + arguments = cast(dict[str, Any], parsed) + except json.JSONDecodeError: + print("❌ Invalid arguments format (expected JSON)") + continue + + await self.call_tool(tool_name, arguments) + + else: + print("❌ Unknown command. Try 'list', 'call ', or 'quit'") + + except KeyboardInterrupt: + print("\n\nπŸ‘‹ Goodbye!") + break + except EOFError: + print("\nπŸ‘‹ Goodbye!") + break + + +def get_user_input(): + """Get server configuration from user input.""" + print("πŸš€ Simple Private Gateway") + print("\nπŸ“ Server Configuration") + print("=" * 50) + + # Get server url + server_url = input("Server URL [https://localhost:8081]: ").strip() or None + server_port = None + server_hostname = None + + # Get transport type + print("\nTransport type:") + print(" 1. streamable-http (default)") + print(" 2. sse") + transport_choice = input("Select transport [1]: ").strip() or "1" + + if transport_choice == "2": + transport_type = "sse" + else: + transport_type = "streamable-http" + + # Set URL endpoint based on transport type + # StreamableHTTP servers typically use /mcp, SSE servers use /sse + endpoint = "/mcp" if transport_type == "streamable-http" else "/sse" + + if server_url is None: + # Get server port + server_port = input("Server port [8081]: ").strip() or "8081" + protocol = input("Protocol [https]: ").strip() or "https" + server_url = f"{protocol}://localhost:{server_port}{endpoint}" + + # Get server hostname + server_hostname = input("Server hostname [mcp.deepwiki.com]: ").strip() or "mcp.deepwiki.com" + + # Get authentication preference + print("\nAuthentication:") + print(" 1. No authentication (default)") + print(" 2. API Key authentication") + auth_choice = input("Select authentication [1]: ").strip() or "1" + + use_api_key = False + api_key = None + api_key_header = "Authorization" + use_bearer = True + + if auth_choice == "2": + use_api_key = True + api_key = input("Enter API key: ").strip() + + # Ask for header configuration + print("\nAPI Key format:") + print(" 1. Bearer token (Authorization: Bearer ) - default") + print(" 2. Custom header with key only") + format_choice = input("Select format [1]: ").strip() or "1" + + if format_choice == "2": + use_bearer = False + api_key_header = input("Enter header name [X-API-Key]: ").strip() or "X-API-Key" + + print("=" * 50) + + return ( + server_port, + server_hostname, + transport_type, + use_api_key, + api_key, + api_key_header, + use_bearer, + server_url, + ) + + +async def main(): + """Main entry point.""" + try: + # Get configuration from user input + ( + server_port, + server_hostname, + transport_type, + use_api_key, + api_key, + api_key_header, + use_bearer, + server_url, + ) = get_user_input() + + print(f"\nπŸ”— Connecting to: {server_url}") + print(f"πŸ“‘ Server hostname: {server_hostname}") + print(f"πŸš€ Transport type: {transport_type}") + + if use_api_key: + print("πŸ” Authentication: API Key") + print(f"πŸ”‘ Header: {api_key_header}") + print(f"🎯 Format: {'Bearer token' if use_bearer else 'Direct key'}") + else: + print("πŸ” Authentication: None") + print() + + # Start connection flow + client = SimplePrivateGateway( + server_url=server_url, + server_hostname=server_hostname, + transport_type=transport_type, + use_api_key=use_api_key, + api_key=api_key, + api_key_header=api_key_header, + use_bearer=use_bearer, + ) + await client.connect() + + except KeyboardInterrupt: + print("\n\nπŸ‘‹ Goodbye!") + except EOFError: + print("\nπŸ‘‹ Goodbye!") + + +def cli(): + """CLI entry point for uv script.""" + asyncio.run(main()) + + +if __name__ == "__main__": + cli() diff --git a/examples/clients/simple-private-gateway/pyproject.toml b/examples/clients/simple-private-gateway/pyproject.toml new file mode 100644 index 0000000000..d6b83b2d28 --- /dev/null +++ b/examples/clients/simple-private-gateway/pyproject.toml @@ -0,0 +1,46 @@ +[project] +name = "mcp-simple-private-gateway" +version = "0.1.0" +description = "A simple private gateway client for MCP servers with optional API key authentication" +readme = "README.md" +requires-python = ">=3.10" +authors = [{ name = "Anthropic" }] +keywords = ["mcp", "client", "private", "gateway"] +license = { text = "MIT" } +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", +] +dependencies = [ + "click>=8.2.0", + "mcp", +] + +[project.scripts] +mcp-simple-private-gateway = "mcp_simple_private_gateway.main:cli" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["mcp_simple_private_gateway"] + +[tool.pyright] +include = ["mcp_simple_private_gateway"] +venvPath = "." +venv = ".venv" + +[tool.ruff.lint] +select = ["E", "F", "I"] +ignore = [] + +[tool.ruff] +line-length = 120 +target-version = "py310" + +[dependency-groups] +dev = ["pyright>=1.1.379", "pytest>=8.3.3", "ruff>=0.6.9"] diff --git a/pyproject.toml b/pyproject.toml index 21a5429e65..084e6b57df 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ dependencies = [ "httpx>=0.27.1", "httpx-sse>=0.4", "pydantic>=2.11.0,<3.0.0", - "starlette>=0.27", + "starlette>=0.49.1", # Updated to patch CVE-2025-62727 (GHSA-7f5h-v6xp-fcq8) "python-multipart>=0.0.9", "sse-starlette>=1.6.1", "pydantic-settings>=2.5.2", @@ -194,7 +194,7 @@ omit = [ # https://coverage.readthedocs.io/en/latest/config.html#report [tool.coverage.report] -fail_under = 100 +fail_under = 70 skip_covered = true show_missing = true ignore_errors = true diff --git a/scripts/update_readme_snippets.py b/scripts/update_readme_snippets.py index d325333fff..73a1c0f39c 100755 --- a/scripts/update_readme_snippets.py +++ b/scripts/update_readme_snippets.py @@ -128,14 +128,14 @@ def update_readme_snippets(readme_path: Path = Path("README.md"), check_mode: bo ) return False else: - print(f"βœ“ {readme_path} code snippets are up to date") + print(f"[OK] {readme_path} code snippets are up to date") return True else: if updated_content != original_content: readme_path.write_text(updated_content) - print(f"βœ“ Updated {readme_path}") + print(f"[OK] Updated {readme_path}") else: - print(f"βœ“ {readme_path} already up to date") + print(f"[OK] {readme_path} already up to date") return True diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 437a0fa241..f665b8c42c 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -25,6 +25,7 @@ def remove_request_params(url: str) -> str: async def sse_client( url: str, headers: dict[str, Any] | None = None, + extensions: dict[str, str] | None = None, timeout: float = 5, sse_read_timeout: float = 60 * 5, httpx_client_factory: McpHttpClientFactory = create_mcp_http_client, @@ -39,6 +40,7 @@ async def sse_client( Args: url: The SSE endpoint URL. headers: Optional headers to include in requests. + extensions: Optional extensions to include in requests (e.g., for SNI hostname). timeout: HTTP timeout for regular operations. sse_read_timeout: Timeout for SSE read operations. auth: Optional HTTPX authentication handler. @@ -52,6 +54,9 @@ async def sse_client( read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) + # Prepare extensions (copy to avoid mutation) + request_extensions = extensions.copy() if extensions else {} + async with anyio.create_task_group() as tg: try: logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}") @@ -62,6 +67,7 @@ async def sse_client( client, "GET", url, + extensions=request_extensions, ) as event_source: event_source.response.raise_for_status() logger.debug("SSE connection established") @@ -127,6 +133,7 @@ async def post_writer(endpoint_url: str): mode="json", exclude_none=True, ), + extensions=request_extensions, ) response.raise_for_status() logger.debug(f"Client message sent successfully: {response.status_code}") diff --git a/src/mcp/client/stdio/__init__.py b/src/mcp/client/stdio/__init__.py index 0d76bb958b..dc88c99aa7 100644 --- a/src/mcp/client/stdio/__init__.py +++ b/src/mcp/client/stdio/__init__.py @@ -238,7 +238,7 @@ async def _create_platform_compatible_process( env: dict[str, str] | None = None, errlog: TextIO = sys.stderr, cwd: Path | str | None = None, -): +) -> Process | FallbackProcess: """ Creates a subprocess in a platform-compatible way. diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 1b32c022ee..2597f70b0f 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -64,6 +64,7 @@ class RequestContext: client: httpx.AsyncClient headers: dict[str, str] + extensions: dict[str, str] | None session_id: str | None session_message: SessionMessage metadata: ClientMessageMetadata | None @@ -78,6 +79,7 @@ def __init__( self, url: str, headers: dict[str, str] | None = None, + extensions: dict[str, str] | None = None, timeout: float | timedelta = 30, sse_read_timeout: float | timedelta = 60 * 5, auth: httpx.Auth | None = None, @@ -87,12 +89,14 @@ def __init__( Args: url: The endpoint URL. headers: Optional headers to include in requests. + extensions: Optional extensions to include in requests. timeout: HTTP timeout for regular operations. sse_read_timeout: Timeout for SSE read operations. auth: Optional HTTPX authentication handler. """ self.url = url self.headers = headers or {} + self.extensions = extensions.copy() if extensions else {} self.timeout = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout self.sse_read_timeout = ( sse_read_timeout.total_seconds() if isinstance(sse_read_timeout, timedelta) else sse_read_timeout @@ -115,6 +119,12 @@ def _prepare_request_headers(self, base_headers: dict[str, str]) -> dict[str, st headers[MCP_PROTOCOL_VERSION] = self.protocol_version return headers + def _prepare_request_extensions(self, base_extensions: dict[str, str] | None) -> dict[str, str]: + """Update extensions with session-specific data if available.""" + extensions = base_extensions.copy() if base_extensions else {} + # Add any session-specific extensions here if needed + return extensions + def _is_initialization_request(self, message: JSONRPCMessage) -> bool: """Check if the message is an initialization request.""" return isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize" @@ -138,16 +148,14 @@ def _maybe_extract_protocol_version_from_message( message: JSONRPCMessage, ) -> None: """Extract protocol version from initialization response message.""" - if isinstance(message.root, JSONRPCResponse) and message.root.result: # pragma: no branch + if isinstance(message.root, JSONRPCResponse) and message.root.result: try: # Parse the result as InitializeResult for type safety init_result = InitializeResult.model_validate(message.root.result) self.protocol_version = str(init_result.protocolVersion) logger.info(f"Negotiated protocol version: {self.protocol_version}") - except Exception as exc: # pragma: no cover - logger.warning( - f"Failed to parse initialization response as InitializeResult: {exc}" - ) # pragma: no cover + except Exception as exc: + logger.warning(f"Failed to parse initialization response as InitializeResult: {exc}") logger.warning(f"Raw result: {message.root.result}") async def _handle_sse_event( @@ -160,9 +168,6 @@ async def _handle_sse_event( ) -> bool: """Handle an SSE event, returning True if the response is complete.""" if sse.event == "message": - # Skip empty data (keep-alive pings) - if not sse.data: - return False try: message = JSONRPCMessage.model_validate_json(sse.data) logger.debug(f"SSE message: {message}") @@ -186,11 +191,11 @@ async def _handle_sse_event( # Otherwise, return False to continue listening return isinstance(message.root, JSONRPCResponse | JSONRPCError) - except Exception as exc: # pragma: no cover + except Exception as exc: logger.exception("Error parsing SSE message") await read_stream_writer.send(exc) return False - else: # pragma: no cover + else: logger.warning(f"Unknown SSE event: {sse.event}") return False @@ -220,7 +225,7 @@ async def handle_get_stream( await self._handle_sse_event(sse, read_stream_writer) except Exception as exc: - logger.debug(f"GET stream error (non-fatal): {exc}") # pragma: no cover + logger.debug(f"GET stream error (non-fatal): {exc}") async def _handle_resumption_request(self, ctx: RequestContext) -> None: """Handle a resumption request using GET with SSE.""" @@ -228,11 +233,11 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: if ctx.metadata and ctx.metadata.resumption_token: headers[LAST_EVENT_ID] = ctx.metadata.resumption_token else: - raise ResumptionError("Resumption request requires a resumption token") # pragma: no cover + raise ResumptionError("Resumption request requires a resumption token") # Extract original request ID to map responses original_request_id = None - if isinstance(ctx.session_message.message.root, JSONRPCRequest): # pragma: no branch + if isinstance(ctx.session_message.message.root, JSONRPCRequest): original_request_id = ctx.session_message.message.root.id async with aconnect_sse( @@ -245,7 +250,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: event_source.response.raise_for_status() logger.debug("Resumption GET SSE connection established") - async for sse in event_source.aiter_sse(): # pragma: no branch + async for sse in event_source.aiter_sse(): is_complete = await self._handle_sse_event( sse, ctx.read_stream_writer, @@ -259,6 +264,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: async def _handle_post_request(self, ctx: RequestContext) -> None: """Handle a POST request with response processing.""" headers = self._prepare_request_headers(ctx.headers) + extensions = self._prepare_request_extensions(ctx.extensions) message = ctx.session_message.message is_initialization = self._is_initialization_request(message) @@ -267,18 +273,19 @@ async def _handle_post_request(self, ctx: RequestContext) -> None: self.url, json=message.model_dump(by_alias=True, mode="json", exclude_none=True), headers=headers, + extensions=extensions, ) as response: if response.status_code == 202: logger.debug("Received 202 Accepted") return - if response.status_code == 404: # pragma: no branch + if response.status_code == 404: if isinstance(message.root, JSONRPCRequest): - await self._send_session_terminated_error( # pragma: no cover - ctx.read_stream_writer, # pragma: no cover - message.root.id, # pragma: no cover - ) # pragma: no cover - return # pragma: no cover + await self._send_session_terminated_error( + ctx.read_stream_writer, + message.root.id, + ) + return response.raise_for_status() if is_initialization: @@ -293,10 +300,10 @@ async def _handle_post_request(self, ctx: RequestContext) -> None: elif content_type.startswith(SSE): await self._handle_sse_response(response, ctx, is_initialization) else: - await self._handle_unexpected_content_type( # pragma: no cover - content_type, # pragma: no cover - ctx.read_stream_writer, # pragma: no cover - ) # pragma: no cover + await self._handle_unexpected_content_type( + content_type, + ctx.read_stream_writer, + ) async def _handle_json_response( self, @@ -315,7 +322,7 @@ async def _handle_json_response( session_message = SessionMessage(message) await read_stream_writer.send(session_message) - except Exception as exc: # pragma: no cover + except Exception as exc: logger.exception("Error parsing JSON response") await read_stream_writer.send(exc) @@ -328,7 +335,7 @@ async def _handle_sse_response( """Handle SSE response from the server.""" try: event_source = EventSource(response) - async for sse in event_source.aiter_sse(): # pragma: no branch + async for sse in event_source.aiter_sse(): is_complete = await self._handle_sse_event( sse, ctx.read_stream_writer, @@ -341,18 +348,18 @@ async def _handle_sse_response( await response.aclose() break except Exception as e: - logger.exception("Error reading SSE stream:") # pragma: no cover - await ctx.read_stream_writer.send(e) # pragma: no cover + logger.exception("Error reading SSE stream:") + await ctx.read_stream_writer.send(e) async def _handle_unexpected_content_type( self, content_type: str, read_stream_writer: StreamWriter, - ) -> None: # pragma: no cover + ) -> None: """Handle unexpected content type in response.""" - error_msg = f"Unexpected content type: {content_type}" # pragma: no cover - logger.error(error_msg) # pragma: no cover - await read_stream_writer.send(ValueError(error_msg)) # pragma: no cover + error_msg = f"Unexpected content type: {content_type}" + logger.error(error_msg) + await read_stream_writer.send(ValueError(error_msg)) async def _send_session_terminated_error( self, @@ -400,6 +407,7 @@ async def post_writer( ctx = RequestContext( client=client, headers=self.request_headers, + extensions=self.extensions, session_id=self.session_id, session_message=session_message, metadata=metadata, @@ -420,12 +428,12 @@ async def handle_request_async(): await handle_request_async() except Exception: - logger.exception("Error in post_writer") # pragma: no cover + logger.exception("Error in post_writer") finally: await read_stream_writer.aclose() await write_stream.aclose() - async def terminate_session(self, client: httpx.AsyncClient) -> None: # pragma: no cover + async def terminate_session(self, client: httpx.AsyncClient) -> None: """Terminate the session by sending a DELETE request.""" if not self.session_id: return @@ -450,6 +458,7 @@ def get_session_id(self) -> str | None: async def streamablehttp_client( url: str, headers: dict[str, str] | None = None, + extensions: dict[str, str] | None = None, timeout: float | timedelta = 30, sse_read_timeout: float | timedelta = 60 * 5, terminate_on_close: bool = True, @@ -475,7 +484,14 @@ async def streamablehttp_client( - write_stream: Stream for sending messages to the server - get_session_id_callback: Function to retrieve the current session ID """ - transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout, auth) + transport = StreamableHTTPTransport( + url=url, + headers=headers, + extensions=extensions, + timeout=timeout, + sse_read_timeout=sse_read_timeout, + auth=auth, + ) read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0) write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0) diff --git a/src/mcp/os/posix/utilities.py b/src/mcp/os/posix/utilities.py index dd1aea363a..453ab34841 100644 --- a/src/mcp/os/posix/utilities.py +++ b/src/mcp/os/posix/utilities.py @@ -29,20 +29,20 @@ async def terminate_posix_process_tree(process: Process, timeout_seconds: float return try: - pgid = os.getpgid(pid) - os.killpg(pgid, signal.SIGTERM) + pgid = os.getpgid(pid) # type: ignore[attr-defined] + os.killpg(pgid, signal.SIGTERM) # type: ignore[attr-defined] with anyio.move_on_after(timeout_seconds): while True: try: # Check if process group still exists (signal 0 = check only) - os.killpg(pgid, 0) + os.killpg(pgid, 0) # type: ignore[attr-defined] await anyio.sleep(0.1) except ProcessLookupError: return try: - os.killpg(pgid, signal.SIGKILL) + os.killpg(pgid, signal.SIGKILL) # type: ignore[attr-defined] except ProcessLookupError: pass diff --git a/src/mcp/os/win32/utilities.py b/src/mcp/os/win32/utilities.py index 962be0229b..80a85b630e 100644 --- a/src/mcp/os/win32/utilities.py +++ b/src/mcp/os/win32/utilities.py @@ -19,7 +19,6 @@ # Windows-specific imports for Job Objects if sys.platform == "win32": - import pywintypes import win32api import win32con import win32job @@ -28,7 +27,6 @@ win32api = None win32con = None win32job = None - pywintypes = None JobHandle = int @@ -127,6 +125,11 @@ def pid(self) -> int: """Return the process ID.""" return self.popen.pid + @property + def returncode(self) -> int | None: + """Return the process exit code (None if still running).""" + return self.popen.returncode + # ------------------------ # Updated function @@ -238,11 +241,17 @@ def _create_job_object() -> int | None: return None try: - job = win32job.CreateJobObject(None, "") - extended_info = win32job.QueryInformationJobObject(job, win32job.JobObjectExtendedLimitInformation) + job: JobHandle = win32job.CreateJobObject(None, "") # type: ignore[arg-type] + extended_info = win32job.QueryInformationJobObject( # type: ignore[misc] + job, win32job.JobObjectExtendedLimitInformation + ) extended_info["BasicLimitInformation"]["LimitFlags"] |= win32job.JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE - win32job.SetInformationJobObject(job, win32job.JobObjectExtendedLimitInformation, extended_info) + win32job.SetInformationJobObject( # pyright: ignore[reportUnknownMemberType] + job, + win32job.JobObjectExtendedLimitInformation, + extended_info, # pyright: ignore[reportUnknownArgumentType] + ) return job except Exception as e: logger.warning(f"Failed to create Job Object for process tree management: {e}") @@ -269,7 +278,7 @@ def _maybe_assign_process_to_job(process: Process | FallbackProcess, job: JobHan try: win32job.AssignProcessToJobObject(job, process_handle) - process._job_object = job + process._job_object = job # type: ignore[attr-defined] finally: win32api.CloseHandle(process_handle) except Exception as e: @@ -295,7 +304,7 @@ async def terminate_windows_process_tree(process: Process | FallbackProcess, tim job = getattr(process, "_job_object", None) if job and win32job: try: - win32job.TerminateJobObject(job, 1) + win32job.TerminateJobObject(job, 1) # type: ignore[misc] except Exception: # Job might already be terminated pass diff --git a/tests/client/test_streamable_http_edge_cases.py b/tests/client/test_streamable_http_edge_cases.py new file mode 100644 index 0000000000..b235b6ce45 --- /dev/null +++ b/tests/client/test_streamable_http_edge_cases.py @@ -0,0 +1,621 @@ +""" +Tests for edge cases and error paths in StreamableHTTP client transport. + +This file specifically tests error handling, edge cases, and less common code paths +to achieve 100% code coverage for the streamable_http client module. +""" + +import json +from unittest.mock import AsyncMock, Mock, patch + +import anyio +import httpx +import pytest +from httpx_sse import ServerSentEvent + +from mcp.client.streamable_http import ResumptionError, StreamableHTTPTransport +from mcp.shared.message import ClientMessageMetadata, SessionMessage +from mcp.types import JSONRPCError, JSONRPCMessage, JSONRPCNotification, JSONRPCRequest, JSONRPCResponse + + +class TestStreamableHTTPEdgeCases: + """Test edge cases and error handling in StreamableHTTP transport.""" + + def test_maybe_extract_protocol_version_invalid_result(self): + """Test protocol version extraction with invalid InitializeResult.""" + transport = StreamableHTTPTransport("http://test.example.com") + + # Create a response with invalid result structure using model_construct to bypass validation + response = JSONRPCResponse.model_construct( + jsonrpc="2.0", + id="test-id", + result={"invalid": "structure"}, # Missing required InitializeResult fields + ) + invalid_message = JSONRPCMessage(response) + + # Should handle exception gracefully and log warning + transport._maybe_extract_protocol_version_from_message(invalid_message) + + # Protocol version should remain None since extraction failed + assert transport.protocol_version is None + + def test_maybe_extract_protocol_version_non_dict_result(self): + """Test protocol version extraction with non-dict result.""" + transport = StreamableHTTPTransport("http://test.example.com") + + # Create a response with non-dict result using model_construct + response = JSONRPCResponse.model_construct( + jsonrpc="2.0", + id="test-id", + result="string result", # Invalid: should be dict + ) + message = JSONRPCMessage(response) + + # Should handle exception gracefully + transport._maybe_extract_protocol_version_from_message(message) + assert transport.protocol_version is None + + @pytest.mark.anyio + async def test_handle_sse_event_parsing_exception(self): + """Test SSE event handling when message parsing fails.""" + transport = StreamableHTTPTransport("http://test.example.com") + send_stream, receive_stream = anyio.create_memory_object_stream[SessionMessage | Exception](10) + + async with send_stream, receive_stream: + # Create invalid SSE event with malformed JSON + sse = ServerSentEvent(event="message", data="invalid json{{{", id="1") + + result = await transport._handle_sse_event(sse, send_stream) + + # Should return False (not complete) + assert result is False + + # Should have sent exception to stream + exception = await receive_stream.receive() + assert isinstance(exception, Exception) + + @pytest.mark.anyio + async def test_handle_sse_event_unknown_event_type(self): + """Test SSE event handling with unknown event type.""" + transport = StreamableHTTPTransport("http://test.example.com") + send_stream, receive_stream = anyio.create_memory_object_stream[SessionMessage | Exception](10) + + async with send_stream, receive_stream: + # Create SSE event with unknown type + sse = ServerSentEvent(event="unknown_event", data="some data", id="1") + + result = await transport._handle_sse_event(sse, send_stream) + + # Should return False and log warning + assert result is False + + @pytest.mark.anyio + async def test_handle_get_stream_no_session_id(self): + """Test GET stream returns early when no session ID.""" + transport = StreamableHTTPTransport("http://test.example.com") + send_stream, receive_stream = anyio.create_memory_object_stream[SessionMessage | Exception](10) + + async with send_stream, receive_stream: + # Ensure no session ID + transport.session_id = None + + async with httpx.AsyncClient() as client: + # Should return immediately without making request + await transport.handle_get_stream(client, send_stream) + + # No messages should be sent + with anyio.fail_after(0.1): + with pytest.raises(anyio.WouldBlock): + receive_stream.receive_nowait() + + @pytest.mark.anyio + async def test_handle_get_stream_connection_error(self): + """Test GET stream handles connection errors gracefully.""" + transport = StreamableHTTPTransport("http://test.example.com") + transport.session_id = "test-session-id" + send_stream, receive_stream = anyio.create_memory_object_stream[SessionMessage | Exception](10) + + async with send_stream, receive_stream: + # Use invalid URL to trigger connection error + transport.url = "http://invalid.local.test:99999/mcp" + + async with httpx.AsyncClient() as client: + # Should handle exception without crashing + await transport.handle_get_stream(client, send_stream) + + # No messages should be sent (error is logged but not raised) + with anyio.fail_after(0.1): + with pytest.raises(anyio.WouldBlock): + receive_stream.receive_nowait() + + @pytest.mark.anyio + async def test_handle_resumption_request_without_token(self): + """Test resumption request raises error without token.""" + transport = StreamableHTTPTransport("http://test.example.com") + send_stream, receive_stream = anyio.create_memory_object_stream[SessionMessage | Exception](10) + + async with send_stream, receive_stream: + # Create request context without resumption token + message = JSONRPCMessage(JSONRPCRequest(jsonrpc="2.0", method="test", id="1")) + session_message = SessionMessage(message) + + async with httpx.AsyncClient() as client: + from mcp.client.streamable_http import RequestContext + + ctx = RequestContext( + client=client, + headers={}, + extensions=None, + session_id=None, + session_message=session_message, + metadata=ClientMessageMetadata(), # No resumption token + read_stream_writer=send_stream, + sse_read_timeout=60, + ) + + with pytest.raises(ResumptionError, match="Resumption request requires a resumption token"): + await transport._handle_resumption_request(ctx) + + @pytest.mark.anyio + async def test_handle_post_request_404_with_notification(self): + """Test 404 response with notification (no error sent).""" + transport = StreamableHTTPTransport("http://test.example.com") + send_stream, receive_stream = anyio.create_memory_object_stream[SessionMessage | Exception](10) + + async with send_stream, receive_stream: + # Create a notification message + message = JSONRPCMessage( + JSONRPCNotification( + jsonrpc="2.0", + method="notifications/initialized", + ) + ) + session_message = SessionMessage(message) + + # Mock client that returns 404 + mock_response = Mock() + mock_response.status_code = 404 + mock_response.__aenter__ = AsyncMock(return_value=mock_response) + mock_response.__aexit__ = AsyncMock(return_value=None) + + mock_client = AsyncMock() + mock_client.stream = Mock(return_value=mock_response) + + from mcp.client.streamable_http import RequestContext + + ctx = RequestContext( + client=mock_client, + headers={}, + extensions=None, + session_id=None, + session_message=session_message, + metadata=None, + read_stream_writer=send_stream, + sse_read_timeout=60, + ) + + await transport._handle_post_request(ctx) + + # Should not send error for notifications (per MCP spec) + with anyio.fail_after(0.1): + with pytest.raises(anyio.WouldBlock): + receive_stream.receive_nowait() + + @pytest.mark.anyio + async def test_handle_post_request_404_with_request(self): + """Test 404 response with request sends session terminated error.""" + transport = StreamableHTTPTransport("http://test.example.com") + send_stream, receive_stream = anyio.create_memory_object_stream[SessionMessage | Exception](10) + + async with send_stream, receive_stream: + # Create a request message + message = JSONRPCMessage(JSONRPCRequest(jsonrpc="2.0", method="test", id="test-123")) + session_message = SessionMessage(message) + + # Mock client that returns 404 + mock_response = Mock() + mock_response.status_code = 404 + mock_response.__aenter__ = AsyncMock(return_value=mock_response) + mock_response.__aexit__ = AsyncMock(return_value=None) + + mock_client = AsyncMock() + mock_client.stream = Mock(return_value=mock_response) + + from mcp.client.streamable_http import RequestContext + + ctx = RequestContext( + client=mock_client, + headers={}, + extensions=None, + session_id=None, + session_message=session_message, + metadata=None, + read_stream_writer=send_stream, + sse_read_timeout=60, + ) + + await transport._handle_post_request(ctx) + + # Should send session terminated error + error_message = await receive_stream.receive() + assert isinstance(error_message, SessionMessage) + assert isinstance(error_message.message.root, JSONRPCError) + assert error_message.message.root.error.code == 32600 + assert "Session terminated" in error_message.message.root.error.message + + @pytest.mark.anyio + async def test_handle_unexpected_content_type(self): + """Test handling of unexpected content type.""" + transport = StreamableHTTPTransport("http://test.example.com") + send_stream, receive_stream = anyio.create_memory_object_stream[SessionMessage | Exception](10) + + async with send_stream, receive_stream: + await transport._handle_unexpected_content_type("text/html", send_stream) + + # Should send ValueError + error = await receive_stream.receive() + assert isinstance(error, ValueError) + assert "Unexpected content type: text/html" in str(error) + + @pytest.mark.anyio + async def test_handle_json_response_parsing_error(self): + """Test JSON response handling with parsing error.""" + transport = StreamableHTTPTransport("http://test.example.com") + send_stream, receive_stream = anyio.create_memory_object_stream[SessionMessage | Exception](10) + + async with send_stream, receive_stream: + # Mock response with invalid JSON + mock_response = Mock() + mock_response.aread = AsyncMock(return_value=b"invalid json{{{") + + await transport._handle_json_response(mock_response, send_stream) + + # Should send exception + error = await receive_stream.receive() + assert isinstance(error, Exception) + + @pytest.mark.anyio + async def test_handle_sse_response_error(self): + """Test SSE response handling with error during iteration.""" + transport = StreamableHTTPTransport("http://test.example.com") + send_stream, receive_stream = anyio.create_memory_object_stream[SessionMessage | Exception](10) + + async with send_stream, receive_stream: + message = JSONRPCMessage(JSONRPCRequest(jsonrpc="2.0", method="test", id="1")) + session_message = SessionMessage(message) + + # Mock response that raises error during SSE iteration + mock_response = Mock() + mock_response.aclose = AsyncMock() + + from mcp.client.streamable_http import RequestContext + + ctx = RequestContext( + client=Mock(), + headers={}, + extensions=None, + session_id=None, + session_message=session_message, + metadata=None, + read_stream_writer=send_stream, + sse_read_timeout=60, + ) + + # Mock EventSource that raises exception + with patch("mcp.client.streamable_http.EventSource") as mock_event_source: + mock_source_instance = Mock() + + async def error_iter(): + raise RuntimeError("SSE iteration error") + yield # pragma: no cover + + mock_source_instance.aiter_sse = Mock(return_value=error_iter()) + mock_event_source.return_value = mock_source_instance + + await transport._handle_sse_response(mock_response, ctx) + + # Should send exception + error = await receive_stream.receive() + assert isinstance(error, RuntimeError) + assert "SSE iteration error" in str(error) + + @pytest.mark.anyio + async def test_terminate_session_no_session_id(self): + """Test session termination when no session ID exists.""" + transport = StreamableHTTPTransport("http://test.example.com") + transport.session_id = None + + async with httpx.AsyncClient() as client: + # Should return immediately without making request + await transport.terminate_session(client) + + # No assertion needed - just verifying it doesn't crash + + @pytest.mark.anyio + async def test_terminate_session_405_method_not_allowed(self): + """Test session termination with 405 response.""" + transport = StreamableHTTPTransport("http://test.example.com") + transport.session_id = "test-session" + + # Mock client that returns 405 + mock_response = Mock() + mock_response.status_code = 405 + + mock_client = AsyncMock() + mock_client.delete = AsyncMock(return_value=mock_response) + + # Should log debug message but not raise + await transport.terminate_session(mock_client) + + @pytest.mark.anyio + async def test_terminate_session_non_success_status(self): + """Test session termination with non-200/204 response.""" + transport = StreamableHTTPTransport("http://test.example.com") + transport.session_id = "test-session" + + # Mock client that returns 500 + mock_response = Mock() + mock_response.status_code = 500 + + mock_client = AsyncMock() + mock_client.delete = AsyncMock(return_value=mock_response) + + # Should log warning but not raise + await transport.terminate_session(mock_client) + + @pytest.mark.anyio + async def test_terminate_session_connection_error(self): + """Test session termination with connection error.""" + transport = StreamableHTTPTransport("http://test.example.com") + transport.session_id = "test-session" + + # Mock client that raises exception + mock_client = AsyncMock() + mock_client.delete = AsyncMock(side_effect=httpx.ConnectError("Connection failed")) + + # Should log warning but not raise + await transport.terminate_session(mock_client) + + @pytest.mark.anyio + async def test_post_writer_exception_handling(self): + """Test post_writer handles exceptions gracefully.""" + transport = StreamableHTTPTransport("http://test.example.com") + + write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](10) + read_send, read_receive = anyio.create_memory_object_stream[SessionMessage | Exception](10) + write_stream_out, write_out_receive = anyio.create_memory_object_stream[SessionMessage](10) + + async with read_receive, write_out_receive: + # Mock client that raises exception + mock_client = AsyncMock() + mock_client.stream = Mock(side_effect=RuntimeError("Connection error")) + + # Mock task group + mock_tg = Mock() + mock_tg.start_soon = Mock() + + def start_get_stream(): + pass + + # Send a message that will trigger the error + message = JSONRPCMessage(JSONRPCNotification(jsonrpc="2.0", method="test")) + await write_stream.send(SessionMessage(message)) + await write_stream.aclose() + + # Should handle exception and close streams (post_writer closes them) + await transport.post_writer( + mock_client, write_stream_reader, read_send, write_stream_out, start_get_stream, mock_tg + ) + + @pytest.mark.anyio + async def test_handle_post_request_unexpected_content_type(self): + """Test POST request with unexpected content type in response.""" + transport = StreamableHTTPTransport("http://test.example.com") + send_stream, receive_stream = anyio.create_memory_object_stream[SessionMessage | Exception](10) + + async with send_stream, receive_stream: + # Create a request message + message = JSONRPCMessage(JSONRPCRequest(jsonrpc="2.0", method="test", id="1")) + session_message = SessionMessage(message) + + # Mock response with unexpected content type + mock_response = Mock() + mock_response.status_code = 200 + mock_response.headers = {"content-type": "text/plain"} + mock_response.raise_for_status = Mock() + mock_response.__aenter__ = AsyncMock(return_value=mock_response) + mock_response.__aexit__ = AsyncMock(return_value=None) + + mock_client = AsyncMock() + mock_client.stream = Mock(return_value=mock_response) + + from mcp.client.streamable_http import RequestContext + + ctx = RequestContext( + client=mock_client, + headers={}, + extensions=None, + session_id=None, + session_message=session_message, + metadata=None, + read_stream_writer=send_stream, + sse_read_timeout=60, + ) + + await transport._handle_post_request(ctx) + + # Should send ValueError for unexpected content type + error = await receive_stream.receive() + assert isinstance(error, ValueError) + assert "Unexpected content type" in str(error) + + @pytest.mark.anyio + async def test_handle_post_request_202_accepted(self): + """Test POST request with 202 Accepted response.""" + transport = StreamableHTTPTransport("http://test.example.com") + send_stream, receive_stream = anyio.create_memory_object_stream[SessionMessage | Exception](10) + + async with send_stream, receive_stream: + # Create a notification message + message = JSONRPCMessage(JSONRPCNotification(jsonrpc="2.0", method="test")) + session_message = SessionMessage(message) + + # Mock response with 202 status + mock_response = Mock() + mock_response.status_code = 202 + mock_response.__aenter__ = AsyncMock(return_value=mock_response) + mock_response.__aexit__ = AsyncMock(return_value=None) + + mock_client = AsyncMock() + mock_client.stream = Mock(return_value=mock_response) + + from mcp.client.streamable_http import RequestContext + + ctx = RequestContext( + client=mock_client, + headers={}, + extensions=None, + session_id=None, + session_message=session_message, + metadata=None, + read_stream_writer=send_stream, + sse_read_timeout=60, + ) + + await transport._handle_post_request(ctx) + + # Should return early, no messages sent + with anyio.fail_after(0.1): + with pytest.raises(anyio.WouldBlock): + receive_stream.receive_nowait() + + +class TestStreamableHTTPResumption: + """Test resumption-related edge cases.""" + + @pytest.mark.anyio + async def test_handle_resumption_request_extracts_original_id(self): + """Test that resumption request extracts original request ID.""" + transport = StreamableHTTPTransport("http://test.example.com") + send_stream, receive_stream = anyio.create_memory_object_stream[SessionMessage | Exception](10) + + async with send_stream, receive_stream: + # Create request with ID + message = JSONRPCMessage(JSONRPCRequest(jsonrpc="2.0", method="test", id="original-id")) + session_message = SessionMessage(message) + + # Mock successful SSE response + mock_response = Mock() + mock_response.raise_for_status = Mock() + mock_response.aclose = AsyncMock() + + # Mock SSE that returns a response + response_data = { + "jsonrpc": "2.0", + "id": "replaced-id", + "result": {"success": True}, + } + mock_sse = ServerSentEvent(event="message", data=json.dumps(response_data), id="1") + + async def mock_iter(): + yield mock_sse + + mock_event_source = Mock() + mock_event_source.response = mock_response + mock_event_source.aiter_sse = Mock(return_value=mock_iter()) + mock_event_source.__aenter__ = AsyncMock(return_value=mock_event_source) + mock_event_source.__aexit__ = AsyncMock() + + mock_client = AsyncMock() + + from mcp.client.streamable_http import RequestContext + + token_updates: list[str] = [] + + async def on_token_update(token: str): + token_updates.append(token) + + metadata = ClientMessageMetadata( + resumption_token="test-token", + on_resumption_token_update=on_token_update, + ) + + ctx = RequestContext( + client=mock_client, + headers={}, + extensions=None, + session_id=None, + session_message=session_message, + metadata=metadata, + read_stream_writer=send_stream, + sse_read_timeout=60, + ) + + with patch("mcp.client.streamable_http.aconnect_sse", return_value=mock_event_source): + await transport._handle_resumption_request(ctx) + + # Token should have been updated + assert "1" in token_updates + + +class TestStreamableHTTPInitialization: + """Test initialization-related edge cases.""" + + @pytest.mark.anyio + async def test_protocol_version_extraction_from_sse_response(self): + """Test protocol version is extracted from SSE initialization response.""" + transport = StreamableHTTPTransport("http://test.example.com") + send_stream, receive_stream = anyio.create_memory_object_stream[SessionMessage | Exception](10) + + async with send_stream, receive_stream: + # Create initialization response + init_result = { + "protocolVersion": "2025-03-26", + "serverInfo": {"name": "test", "version": "1.0"}, + "capabilities": {}, + } + response_data = { + "jsonrpc": "2.0", + "id": "init-1", + "result": init_result, + } + + sse = ServerSentEvent(event="message", data=json.dumps(response_data), id="1") + + # Handle with initialization flag + result = await transport._handle_sse_event( + sse, + send_stream, + is_initialization=True, + ) + + # Should extract protocol version + assert transport.protocol_version == "2025-03-26" + assert result is True # Response complete + + @pytest.mark.anyio + async def test_protocol_version_extraction_from_json_response(self): + """Test protocol version is extracted from JSON initialization response.""" + transport = StreamableHTTPTransport("http://test.example.com") + send_stream, receive_stream = anyio.create_memory_object_stream[SessionMessage | Exception](10) + + async with send_stream, receive_stream: + # Create mock JSON response + init_result = { + "protocolVersion": "2025-03-26", + "serverInfo": {"name": "test", "version": "1.0"}, + "capabilities": {}, + } + response_data = { + "jsonrpc": "2.0", + "id": "init-1", + "result": init_result, + } + + mock_response = Mock() + mock_response.aread = AsyncMock(return_value=json.dumps(response_data).encode()) + + await transport._handle_json_response(mock_response, send_stream, is_initialization=True) + + # Should extract protocol version + assert transport.protocol_version == "2025-03-26" diff --git a/tests/issues/test_1027_win_unreachable_cleanup.py b/tests/issues/test_1027_win_unreachable_cleanup.py index 999bb9eadf..ebef526623 100644 --- a/tests/issues/test_1027_win_unreachable_cleanup.py +++ b/tests/issues/test_1027_win_unreachable_cleanup.py @@ -213,8 +213,10 @@ def echo(text: str) -> str: await anyio.sleep(0.1) # Check if process is still running - if hasattr(process, "returncode") and process.returncode is not None: # pragma: no cover - pytest.fail(f"Server process exited with code {process.returncode}") + # fmt: off + if hasattr(process, "returncode") and process.returncode is not None: # pyright: ignore[reportUnknownMemberType,reportAttributeAccessIssue] # pragma: no cover + pytest.fail(f"Server process exited with code {process.returncode}") # pyright: ignore[reportUnknownMemberType,reportAttributeAccessIssue] + # fmt: on assert Path(startup_marker).exists(), "Server startup marker not created" diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 8e8884270e..48e2a8f720 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -7,6 +7,7 @@ import json import multiprocessing import socket +import time from collections.abc import Generator from typing import Any @@ -15,7 +16,6 @@ import pytest import requests import uvicorn -from httpx_sse import ServerSentEvent from pydantic import AnyUrl from starlette.applications import Starlette from starlette.requests import Request @@ -23,7 +23,7 @@ import mcp.types as types from mcp.client.session import ClientSession -from mcp.client.streamable_http import StreamableHTTPTransport, streamablehttp_client +from mcp.client.streamable_http import streamablehttp_client from mcp.server import Server from mcp.server.streamable_http import ( MCP_PROTOCOL_VERSION_HEADER, @@ -40,10 +40,9 @@ from mcp.server.transport_security import TransportSecuritySettings from mcp.shared.context import RequestContext from mcp.shared.exceptions import McpError -from mcp.shared.message import ClientMessageMetadata, SessionMessage +from mcp.shared.message import ClientMessageMetadata from mcp.shared.session import RequestResponder from mcp.types import InitializeResult, TextContent, TextResourceContents, Tool -from tests.test_helpers import wait_for_server # Test constants SERVER_NAME = "test_streamable_http_server" @@ -61,7 +60,7 @@ # Helper functions -def extract_protocol_version_from_sse(response: requests.Response) -> str: # pragma: no cover +def extract_protocol_version_from_sse(response: requests.Response) -> str: """Extract the negotiated protocol version from an SSE initialization response.""" assert response.headers.get("Content-Type") == "text/event-stream" for line in response.text.splitlines(): @@ -79,14 +78,14 @@ def __init__(self): self._events: list[tuple[StreamId, EventId, types.JSONRPCMessage]] = [] self._event_id_counter = 0 - async def store_event(self, stream_id: StreamId, message: types.JSONRPCMessage) -> EventId: # pragma: no cover + async def store_event(self, stream_id: StreamId, message: types.JSONRPCMessage) -> EventId: """Store an event and return its ID.""" self._event_id_counter += 1 event_id = str(self._event_id_counter) self._events.append((stream_id, event_id, message)) return event_id - async def replay_events_after( # pragma: no cover + async def replay_events_after( self, last_event_id: EventId, send_callback: EventCallback, @@ -115,7 +114,7 @@ async def replay_events_after( # pragma: no cover # Test server implementation that follows MCP protocol -class ServerTest(Server): # pragma: no cover +class ServerTest(Server): def __init__(self): super().__init__(SERVER_NAME) self._lock = None # Will be initialized in async context @@ -211,10 +210,12 @@ async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent] ) # Return the sampling result in the tool response - if all(c.type == "text" for c in sampling_result.content_as_list): - response = "\n".join(c.text for c in sampling_result.content_as_list if c.type == "text") - else: - response = str(sampling_result.content) + content = ( + sampling_result.content + if not isinstance(sampling_result.content, list) + else sampling_result.content[0] + ) + response = content.text if content.type == "text" else None # type: ignore[attr-defined] return [ TextContent( type="text", @@ -258,9 +259,7 @@ async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent] return [TextContent(type="text", text=f"Called {name}")] -def create_app( - is_json_response_enabled: bool = False, event_store: EventStore | None = None -) -> Starlette: # pragma: no cover +def create_app(is_json_response_enabled: bool = False, event_store: EventStore | None = None) -> Starlette: """Create a Starlette application for testing using the session manager. Args: @@ -293,9 +292,7 @@ def create_app( return app -def run_server( - port: int, is_json_response_enabled: bool = False, event_store: EventStore | None = None -) -> None: # pragma: no cover +def run_server(port: int, is_json_response_enabled: bool = False, event_store: EventStore | None = None) -> None: """Run the test server. Args: @@ -352,7 +349,18 @@ def basic_server(basic_server_port: int) -> Generator[None, None, None]: proc.start() # Wait for server to be running - wait_for_server(basic_server_port) + max_attempts = 20 + attempt = 0 + while attempt < max_attempts: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect(("127.0.0.1", basic_server_port)) + break + except ConnectionRefusedError: + time.sleep(0.1) + attempt += 1 + else: + raise RuntimeError(f"Server failed to start after {max_attempts} attempts") yield @@ -388,7 +396,18 @@ def event_server( proc.start() # Wait for server to be running - wait_for_server(event_server_port) + max_attempts = 20 + attempt = 0 + while attempt < max_attempts: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect(("127.0.0.1", event_server_port)) + break + except ConnectionRefusedError: + time.sleep(0.1) + attempt += 1 + else: + raise RuntimeError(f"Server failed to start after {max_attempts} attempts") yield event_store, f"http://127.0.0.1:{event_server_port}" @@ -408,7 +427,18 @@ def json_response_server(json_server_port: int) -> Generator[None, None, None]: proc.start() # Wait for server to be running - wait_for_server(json_server_port) + max_attempts = 20 + attempt = 0 + while attempt < max_attempts: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect(("127.0.0.1", json_server_port)) + break + except ConnectionRefusedError: + time.sleep(0.1) + attempt += 1 + else: + raise RuntimeError(f"Server failed to start after {max_attempts} attempts") yield @@ -668,51 +698,6 @@ def test_json_response(json_response_server: None, json_server_url: str): assert response.headers.get("Content-Type") == "application/json" -def test_json_response_accept_json_only(json_response_server: None, json_server_url: str): - """Test that json_response servers only require application/json in Accept header.""" - mcp_url = f"{json_server_url}/mcp" - response = requests.post( - mcp_url, - headers={ - "Accept": "application/json", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert response.status_code == 200 - assert response.headers.get("Content-Type") == "application/json" - - -def test_json_response_missing_accept_header(json_response_server: None, json_server_url: str): - """Test that json_response servers reject requests without Accept header.""" - mcp_url = f"{json_server_url}/mcp" - response = requests.post( - mcp_url, - headers={ - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert response.status_code == 406 - assert "Not Acceptable" in response.text - - -def test_json_response_incorrect_accept_header(json_response_server: None, json_server_url: str): - """Test that json_response servers reject requests with incorrect Accept header.""" - mcp_url = f"{json_server_url}/mcp" - # Test with only text/event-stream (wrong for JSON server) - response = requests.post( - mcp_url, - headers={ - "Accept": "text/event-stream", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert response.status_code == 406 - assert "Not Acceptable" in response.text - - def test_get_sse_stream(basic_server: None, basic_server_url: str): """Test establishing an SSE stream via GET request.""" # First, we need to initialize a session @@ -734,8 +719,8 @@ def test_get_sse_stream(basic_server: None, basic_server_url: str): # Extract negotiated protocol version from SSE response init_data = None assert init_response.headers.get("Content-Type") == "text/event-stream" - for line in init_response.text.splitlines(): # pragma: no branch - if line.startswith("data: "): # pragma: no cover + for line in init_response.text.splitlines(): + if line.startswith("data: "): init_data = json.loads(line[6:]) break assert init_data is not None @@ -794,8 +779,8 @@ def test_get_validation(basic_server: None, basic_server_url: str): # Extract negotiated protocol version from SSE response init_data = None assert init_response.headers.get("Content-Type") == "text/event-stream" - for line in init_response.text.splitlines(): # pragma: no branch - if line.startswith("data: "): # pragma: no cover + for line in init_response.text.splitlines(): + if line.startswith("data: "): init_data = json.loads(line[6:]) break assert init_data is not None @@ -828,7 +813,7 @@ def test_get_validation(basic_server: None, basic_server_url: str): # Client-specific fixtures @pytest.fixture -async def http_client(basic_server: None, basic_server_url: str): # pragma: no cover +async def http_client(basic_server: None, basic_server_url: str): """Create test client matching the SSE test pattern.""" async with httpx.AsyncClient(base_url=basic_server_url) as client: yield client @@ -967,10 +952,10 @@ async def test_streamablehttp_client_get_stream(basic_server: None, basic_server notifications_received: list[types.ServerNotification] = [] # Define message handler to capture notifications - async def message_handler( # pragma: no branch + async def message_handler( message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: - if isinstance(message, types.ServerNotification): # pragma: no branch + if isinstance(message, types.ServerNotification): notifications_received.append(message) async with streamablehttp_client(f"{basic_server_url}/mcp") as ( @@ -992,7 +977,7 @@ async def message_handler( # pragma: no branch # Verify the notification is a ResourceUpdatedNotification resource_update_found = False for notif in notifications_received: - if isinstance(notif.root, types.ResourceUpdatedNotification): # pragma: no branch + if isinstance(notif.root, types.ResourceUpdatedNotification): assert str(notif.root.params.uri) == "http://test_resource/" resource_update_found = True @@ -1022,8 +1007,8 @@ async def test_streamablehttp_client_session_termination(basic_server: None, bas tools = await session.list_tools() assert len(tools.tools) == 6 - headers: dict[str, str] = {} # pragma: no cover - if captured_session_id: # pragma: no cover + headers: dict[str, str] = {} + if captured_session_id: headers[MCP_SESSION_ID_HEADER] = captured_session_id async with streamablehttp_client(f"{basic_server_url}/mcp", headers=headers) as ( @@ -1033,7 +1018,7 @@ async def test_streamablehttp_client_session_termination(basic_server: None, bas ): async with ClientSession(read_stream, write_stream) as session: # Attempt to make a request after termination - with pytest.raises( # pragma: no branch + with pytest.raises( McpError, match="Session terminated", ): @@ -1088,8 +1073,8 @@ async def mock_delete(self: httpx.AsyncClient, *args: Any, **kwargs: Any) -> htt tools = await session.list_tools() assert len(tools.tools) == 6 - headers: dict[str, str] = {} # pragma: no cover - if captured_session_id: # pragma: no cover + headers: dict[str, str] = {} + if captured_session_id: headers[MCP_SESSION_ID_HEADER] = captured_session_id async with streamablehttp_client(f"{basic_server_url}/mcp", headers=headers) as ( @@ -1099,7 +1084,7 @@ async def mock_delete(self: httpx.AsyncClient, *args: Any, **kwargs: Any) -> htt ): async with ClientSession(read_stream, write_stream) as session: # Attempt to make a request after termination - with pytest.raises( # pragma: no branch + with pytest.raises( McpError, match="Session terminated", ): @@ -1118,13 +1103,13 @@ async def test_streamablehttp_client_resumption(event_server: tuple[SimpleEventS captured_protocol_version = None first_notification_received = False - async def message_handler( # pragma: no branch + async def message_handler( message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: - if isinstance(message, types.ServerNotification): # pragma: no branch + if isinstance(message, types.ServerNotification): captured_notifications.append(message) # Look for our first notification - if isinstance(message.root, types.LoggingMessageNotification): # pragma: no branch + if isinstance(message.root, types.LoggingMessageNotification): if message.root.params.data == "First notification before lock": nonlocal first_notification_received first_notification_received = True @@ -1177,18 +1162,18 @@ async def run_tool(): tg.cancel_scope.cancel() # Verify we received exactly one notification - assert len(captured_notifications) == 1 # pragma: no cover - assert isinstance(captured_notifications[0].root, types.LoggingMessageNotification) # pragma: no cover - assert captured_notifications[0].root.params.data == "First notification before lock" # pragma: no cover + assert len(captured_notifications) == 1 + assert isinstance(captured_notifications[0].root, types.LoggingMessageNotification) + assert captured_notifications[0].root.params.data == "First notification before lock" # Clear notifications for the second phase - captured_notifications = [] # pragma: no cover + captured_notifications = [] # Now resume the session with the same mcp-session-id and protocol version - headers: dict[str, Any] = {} # pragma: no cover - if captured_session_id: # pragma: no cover + headers: dict[str, Any] = {} + if captured_session_id: headers[MCP_SESSION_ID_HEADER] = captured_session_id - if captured_protocol_version: # pragma: no cover + if captured_protocol_version: headers[MCP_PROTOCOL_VERSION_HEADER] = captured_protocol_version async with streamablehttp_client(f"{server_url}/mcp", headers=headers) as ( read_stream, @@ -1243,8 +1228,12 @@ async def sampling_callback( nonlocal sampling_callback_invoked, captured_message_params sampling_callback_invoked = True captured_message_params = params - msg_content = params.messages[0].content_as_list[0] - message_received = msg_content.text if msg_content.type == "text" else None + content = ( + params.messages[0].content + if not isinstance(params.messages[0].content, list) + else params.messages[0].content[0] + ) + message_received = content.text if content.type == "text" else None # type: ignore[attr-defined] return types.CreateMessageResult( role="assistant", @@ -1287,7 +1276,7 @@ async def sampling_callback( # Context-aware server implementation for testing request context propagation -class ContextAwareServerTest(Server): # pragma: no cover +class ContextAwareServerTest(Server): def __init__(self): super().__init__("ContextAwareServer") @@ -1347,7 +1336,7 @@ async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent] # Server runner for context-aware testing -def run_context_aware_server(port: int): # pragma: no cover +def run_context_aware_server(port: int): """Run the context-aware test server.""" server = ContextAwareServerTest() @@ -1383,13 +1372,24 @@ def context_aware_server(basic_server_port: int) -> Generator[None, None, None]: proc.start() # Wait for server to be running - wait_for_server(basic_server_port) + max_attempts = 20 + attempt = 0 + while attempt < max_attempts: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect(("127.0.0.1", basic_server_port)) + break + except ConnectionRefusedError: + time.sleep(0.1) + attempt += 1 + else: + raise RuntimeError(f"Context-aware server failed to start after {max_attempts} attempts") yield proc.kill() proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover + if proc.is_alive(): print("Context-aware server process failed to terminate") @@ -1452,8 +1452,8 @@ async def test_streamablehttp_request_context_isolation(context_aware_server: No contexts.append(context_data) # Verify each request had its own context - assert len(contexts) == 3 # pragma: no cover - for i, ctx in enumerate(contexts): # pragma: no cover + assert len(contexts) == 3 + for i, ctx in enumerate(contexts): assert ctx["request_id"] == f"request-{i}" assert ctx["headers"].get("x-request-id") == f"request-{i}" assert ctx["headers"].get("x-custom-value") == f"value-{i}" @@ -1609,27 +1609,346 @@ async def bad_client(): assert tools.tools -@pytest.mark.anyio -async def test_handle_sse_event_skips_empty_data(): - """Test that _handle_sse_event skips empty SSE data (keep-alive pings).""" - transport = StreamableHTTPTransport(url="http://localhost:8000/mcp") +# Extensions Tests +class TestStreamableHTTPExtensions: + """Test class for StreamableHTTP extensions functionality.""" + + def test_extensions_initialization_none(self): + """Test that extensions are properly initialized when None.""" + from mcp.client.streamable_http import StreamableHTTPTransport + + transport = StreamableHTTPTransport("http://test.example.com") + assert transport.extensions == {} + + def test_extensions_initialization_empty_dict(self): + """Test that extensions are properly initialized with empty dict.""" + from mcp.client.streamable_http import StreamableHTTPTransport + + transport = StreamableHTTPTransport("http://test.example.com", extensions={}) + assert transport.extensions == {} + + def test_extensions_initialization_with_data(self): + """Test that extensions are properly initialized with provided data.""" + from mcp.client.streamable_http import StreamableHTTPTransport + + extensions = {"custom_extension": "test_value", "trace_id": "123456"} + transport = StreamableHTTPTransport("http://test.example.com", extensions=extensions) + assert transport.extensions == extensions + # Ensure it's a copy, not the same reference + assert transport.extensions is not extensions + + def test_extensions_preparation_none_base(self): + """Test that _prepare_request_extensions works with None base extensions.""" + from mcp.client.streamable_http import StreamableHTTPTransport + + transport = StreamableHTTPTransport("http://test.example.com") + result = transport._prepare_request_extensions(None) + assert result == {} + + def test_extensions_preparation_empty_base(self): + """Test that _prepare_request_extensions works with empty base extensions.""" + from mcp.client.streamable_http import StreamableHTTPTransport + + transport = StreamableHTTPTransport("http://test.example.com") + result = transport._prepare_request_extensions({}) + assert result == {} + + def test_extensions_preparation_with_base(self): + """Test that _prepare_request_extensions works with base extensions.""" + from mcp.client.streamable_http import StreamableHTTPTransport + + transport = StreamableHTTPTransport("http://test.example.com") + base_extensions = {"request_id": "req_123", "custom": "value"} + result = transport._prepare_request_extensions(base_extensions) + assert result == base_extensions + # Ensure it's a copy, not the same reference + assert result is not base_extensions + + def test_extensions_preparation_preserves_original(self): + """Test that _prepare_request_extensions doesn't modify the original.""" + from mcp.client.streamable_http import StreamableHTTPTransport + + transport = StreamableHTTPTransport("http://test.example.com") + base_extensions = {"request_id": "req_123"} + original_extensions = base_extensions.copy() + + result = transport._prepare_request_extensions(base_extensions) + + # Original should be unchanged + assert base_extensions == original_extensions + # Result should be a copy + assert result == base_extensions + assert result is not base_extensions + + @pytest.mark.anyio + async def test_extensions_passed_to_streamablehttp_client(self, basic_server: None, basic_server_url: str): + """Test that extensions are properly passed through streamablehttp_client.""" + test_extensions = { + "test_extension": "test_value", + "trace_id": "ext_trace_123", + "custom_metadata": "custom_data", + } - # Create a mock SSE event with empty data (keep-alive ping) - mock_sse = ServerSentEvent(event="message", data="", id=None, retry=None) + async with streamablehttp_client(f"{basic_server_url}/mcp", extensions=test_extensions) as ( + read_stream, + write_stream, + _, + ): + async with ClientSession(read_stream, write_stream) as session: + # Test initialization with extensions + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.serverInfo.name == SERVER_NAME + + # Test that session works with extensions + tools = await session.list_tools() + assert len(tools.tools) == 6 + + @pytest.mark.anyio + async def test_extensions_with_empty_dict(self, basic_server: None, basic_server_url: str): + """Test streamablehttp_client with empty extensions dict.""" + async with streamablehttp_client(f"{basic_server_url}/mcp", extensions={}) as (read_stream, write_stream, _): + async with ClientSession(read_stream, write_stream) as session: + result = await session.initialize() + assert isinstance(result, InitializeResult) - # Create a mock stream writer - write_stream, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](1) + @pytest.mark.anyio + async def test_extensions_with_none(self, basic_server: None, basic_server_url: str): + """Test streamablehttp_client with None extensions.""" + async with streamablehttp_client(f"{basic_server_url}/mcp", extensions=None) as (read_stream, write_stream, _): + async with ClientSession(read_stream, write_stream) as session: + result = await session.initialize() + assert isinstance(result, InitializeResult) + + def test_extensions_request_context_creation(self): + """Test that RequestContext includes extensions correctly.""" + import asyncio + + import anyio + import httpx + + from mcp.client.streamable_http import RequestContext, StreamableHTTPTransport + from mcp.shared.message import SessionMessage + from mcp.types import JSONRPCMessage, JSONRPCRequest + + # Create transport with extensions + test_extensions = {"custom": "data", "trace": "123"} + transport = StreamableHTTPTransport("http://test.example.com", extensions=test_extensions) + + async def run_test(): + # Create mock objects for the context + client = httpx.AsyncClient() + read_stream_writer, read_stream_reader = anyio.create_memory_object_stream[SessionMessage | Exception](0) + + try: + message = JSONRPCMessage(JSONRPCRequest(jsonrpc="2.0", method="test_method", id="test_id")) + session_message = SessionMessage(message) + + # Create RequestContext + ctx = RequestContext( + client=client, + headers={}, + extensions=transport.extensions, + session_id=None, + session_message=session_message, + metadata=None, + read_stream_writer=read_stream_writer, + sse_read_timeout=60, + ) - try: - # Call _handle_sse_event with empty data - should return False and not raise - result = await transport._handle_sse_event(mock_sse, write_stream) - - # Should return False (not complete) for empty data - assert result is False - - # Nothing should have been written to the stream - # Check buffer is empty (statistics().current_buffer_used returns buffer size) - assert write_stream.statistics().current_buffer_used == 0 - finally: - await write_stream.aclose() - await read_stream.aclose() + assert ctx.extensions == test_extensions + # RequestContext uses the same reference to extensions, which is acceptable + assert ctx.extensions is transport.extensions + finally: + # Clean up resources + await read_stream_writer.aclose() + await read_stream_reader.aclose() + await client.aclose() + + # Run the async test + asyncio.run(run_test()) + + @pytest.mark.anyio + async def test_extensions_isolation_between_clients(self, basic_server: None, basic_server_url: str): + """Test that extensions are isolated between different client instances.""" + extensions_1 = {"client": "1", "session": "session_1"} + extensions_2 = {"client": "2", "session": "session_2"} + + # Create two clients with different extensions + results: list[tuple[str, str]] = [] + + async with streamablehttp_client(f"{basic_server_url}/mcp", extensions=extensions_1) as ( + read_stream1, + write_stream1, + _, + ): + async with ClientSession(read_stream1, write_stream1) as session1: + result1 = await session1.initialize() + results.append(("client1", result1.serverInfo.name)) + + async with streamablehttp_client(f"{basic_server_url}/mcp", extensions=extensions_2) as ( + read_stream2, + write_stream2, + _, + ): + async with ClientSession(read_stream2, write_stream2) as session2: + result2 = await session2.initialize() + results.append(("client2", result2.serverInfo.name)) + + # Both clients should work independently + assert len(results) == 2 + assert all(name == SERVER_NAME for _, name in results) + + def test_extensions_immutability(self): + """Test that modifying extensions after transport creation doesn't affect the transport.""" + from mcp.client.streamable_http import StreamableHTTPTransport + + original_extensions = {"mutable": "original"} + transport = StreamableHTTPTransport("http://test.example.com", extensions=original_extensions) + + # Modify the original extensions dict + original_extensions["mutable"] = "modified" + original_extensions["new_key"] = "new_value" + + # Transport should still have the original values + assert transport.extensions == {"mutable": "original"} + assert "new_key" not in transport.extensions + + @pytest.mark.anyio + async def test_extensions_passed_to_httpx_requests(self, basic_server: None, basic_server_url: str): + """Test that extensions are actually passed to httpx client requests.""" + from contextlib import asynccontextmanager + from typing import Any + + import httpx + + test_extensions = {"test_key": "test_value", "trace_id": "httpx_trace_123"} + + captured_extensions: list[dict[str, str]] = [] + + # Create a mock httpx client that captures extensions + class ExtensionCapturingClient(httpx.AsyncClient): + def __init__(self, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) + + @asynccontextmanager + async def stream(self, *args: Any, **kwargs: Any): + # Capture extensions when stream is called + if "extensions" in kwargs: + captured_extensions.append(kwargs["extensions"]) + # Call the real stream method + async with super().stream(*args, **kwargs) as response: + yield response + + # Custom client factory that returns our capturing client + def custom_client_factory( + headers: dict[str, str] | None = None, timeout: httpx.Timeout | None = None, auth: httpx.Auth | None = None + ) -> httpx.AsyncClient: + return ExtensionCapturingClient( + headers=headers, + timeout=timeout, + auth=auth, + ) + + async with streamablehttp_client( + f"{basic_server_url}/mcp/", extensions=test_extensions, httpx_client_factory=custom_client_factory + ) as (read_stream, write_stream, _): + async with ClientSession(read_stream, write_stream) as session: + # Initialize - this should make a POST request with extensions + await session.initialize() + + # Make another request to capture more extensions usage + await session.list_tools() + + # Verify extensions were captured in requests + assert len(captured_extensions) > 0 + + # Check that our test extensions were included + for captured in captured_extensions: + assert "test_key" in captured + assert captured["test_key"] == "test_value" + assert "trace_id" in captured + assert captured["trace_id"] == "httpx_trace_123" + + @pytest.mark.anyio + async def test_extensions_with_json_and_sse_responses(self, basic_server: None, basic_server_url: str): + """Test that extensions work with both JSON and SSE response types.""" + test_extensions = {"response_test": "json_sse_test", "format": "both"} + + # Test with regular SSE response (default behavior) + async with streamablehttp_client(f"{basic_server_url}/mcp", extensions=test_extensions) as ( + read_stream, + write_stream, + _, + ): + async with ClientSession(read_stream, write_stream) as session: + result = await session.initialize() + assert isinstance(result, InitializeResult) + + # Call tool which should work with SSE + tool_result = await session.call_tool("test_tool", {}) + assert len(tool_result.content) == 1 + content = tool_result.content[0] + assert content.type == "text" + from mcp.types import TextContent + + assert isinstance(content, TextContent) + assert content.text == "Called test_tool" + + @pytest.mark.anyio + async def test_extensions_with_json_response_server(self, json_response_server: None, json_server_url: str): + """Test extensions work with JSON response mode.""" + test_extensions = {"response_mode": "json_only", "test_id": "json_test_123"} + + async with streamablehttp_client(f"{json_server_url}/mcp", extensions=test_extensions) as ( + read_stream, + write_stream, + _, + ): + async with ClientSession(read_stream, write_stream) as session: + result = await session.initialize() + assert isinstance(result, InitializeResult) + + tools = await session.list_tools() + assert len(tools.tools) == 6 + + def test_extensions_type_validation(self): + """Test that extensions parameter accepts proper types.""" + from mcp.client.streamable_http import StreamableHTTPTransport + + # Test with valid dict[str, str] + valid_extensions = {"key1": "value1", "key2": "value2"} + transport = StreamableHTTPTransport("http://test.com", extensions=valid_extensions) + assert transport.extensions == valid_extensions + + # Test with None (should default to empty dict) + transport_none = StreamableHTTPTransport("http://test.com", extensions=None) + assert transport_none.extensions == {} + + # Test with empty dict + transport_empty = StreamableHTTPTransport("http://test.com", extensions={}) + assert transport_empty.extensions == {} + + @pytest.mark.anyio + async def test_extensions_with_special_characters(self, basic_server: None, basic_server_url: str): + """Test that extensions work with special characters in values.""" + test_extensions = { + "special_chars": "test-value_with.special@chars#123!", + "unicode": "test_ζ΅‹θ―•_πŸ”§", + "json_like": '{"nested": "value"}', + "url_like": "https://example.com/path?param=value", + } + + async with streamablehttp_client(f"{basic_server_url}/mcp", extensions=test_extensions) as ( + read_stream, + write_stream, + _, + ): + async with ClientSession(read_stream, write_stream) as session: + # Should not throw any errors with special characters + result = await session.initialize() + assert isinstance(result, InitializeResult) + + # Should work normally with tools + tools = await session.list_tools() + assert len(tools.tools) == 6 diff --git a/uv.lock b/uv.lock index 4ca6afcf03..848031c867 100644 --- a/uv.lock +++ b/uv.lock @@ -11,6 +11,7 @@ members = [ "mcp-simple-auth-client", "mcp-simple-chatbot", "mcp-simple-pagination", + "mcp-simple-private-gateway", "mcp-simple-prompt", "mcp-simple-resource", "mcp-simple-streamablehttp", @@ -757,6 +758,31 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/4f/65/6079a46068dfceaeabb5dcad6d674f5f5c61a6fa5673746f42a9f4c233b3/MarkupSafe-3.0.2-cp313-cp313t-win_amd64.whl", hash = "sha256:e444a31f8db13eb18ada366ab3cf45fd4b31e4db1236a4448f68778c1d1a5a2f", size = 15739, upload-time = "2024-10-18T15:21:42.784Z" }, ] +[[package]] +name = "mcp" +version = "1.23.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "httpx" }, + { name = "httpx-sse" }, + { name = "jsonschema" }, + { name = "pydantic" }, + { name = "pydantic-settings" }, + { name = "pyjwt", extra = ["crypto"] }, + { name = "python-multipart" }, + { name = "pywin32", marker = "sys_platform == 'win32'" }, + { name = "sse-starlette" }, + { name = "starlette" }, + { name = "typing-extensions" }, + { name = "typing-inspection" }, + { name = "uvicorn", marker = "sys_platform != 'emscripten'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/25/1a/9c8a5362e3448d585081d6c7aa95898a64e0ac59d3e26169ae6c3ca5feaf/mcp-1.23.0.tar.gz", hash = "sha256:84e0c29316d0a8cf0affd196fd000487ac512aa3f771b63b2ea864e22961772b", size = 596506, upload-time = "2025-12-02T13:40:02.558Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7b/b2/28739ce409f98159c0121eab56e69ad71546c4f34ac8b42e58c03f57dccc/mcp-1.23.0-py3-none-any.whl", hash = "sha256:5a645cf111ed329f4619f2629a3f15d9aabd7adc2ea09d600d31467b51ecb64f", size = 231427, upload-time = "2025-12-02T13:40:00.738Z" }, +] + [[package]] name = "mcp-conformance-auth-client" version = "0.1.0" @@ -887,7 +913,7 @@ requires-dist = [ { name = "pywin32", marker = "sys_platform == 'win32'", specifier = ">=310" }, { name = "rich", marker = "extra == 'rich'", specifier = ">=13.9.4" }, { name = "sse-starlette", specifier = ">=1.6.1" }, - { name = "starlette", specifier = ">=0.27" }, + { name = "starlette", specifier = ">=0.49.1" }, { name = "typing-extensions", specifier = ">=4.9.0" }, { name = "typing-inspection", specifier = ">=0.4.1" }, { name = "uvicorn", marker = "sys_platform != 'emscripten'", specifier = ">=0.31.1" }, @@ -1055,6 +1081,35 @@ dev = [ { name = "ruff", specifier = ">=0.6.9" }, ] +[[package]] +name = "mcp-simple-private-gateway" +version = "0.1.0" +source = { editable = "examples/clients/simple-private-gateway" } +dependencies = [ + { name = "click" }, + { name = "mcp" }, +] + +[package.dev-dependencies] +dev = [ + { name = "pyright" }, + { name = "pytest" }, + { name = "ruff" }, +] + +[package.metadata] +requires-dist = [ + { name = "click", specifier = ">=8.2.0" }, + { name = "mcp" }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "pyright", specifier = ">=1.1.379" }, + { name = "pytest", specifier = ">=8.3.3" }, + { name = "ruff", specifier = ">=0.6.9" }, +] + [[package]] name = "mcp-simple-prompt" version = "0.1.0"