diff --git a/examples/auth/03_client_credentials/client.py b/examples/auth/03_client_credentials/client.py deleted file mode 100644 index f66d27c..0000000 --- a/examples/auth/03_client_credentials/client.py +++ /dev/null @@ -1,192 +0,0 @@ -# Copyright (c) 2025 Dedalus Labs, Inc. and its contributors -# SPDX-License-Identifier: MIT - -"""Client-credentials MCP client for the Supabase OAuth demo. - -This example demonstrates fully headless authentication: the CLI exchanges a -confidential ``client_id``/``client_secret`` pair for an access token using the -OAuth 2.1 ``client_credentials`` grant, then calls the protected MCP resource -server from ``examples/auth/02_as``. - -Quick start:: - - # 1. Seed the "dedalus-m2m" client by exporting a secret before - # starting the Go authorization server - $ export AS_M2M_CLIENT_SECRET="dev-m2m-secret" - $ cd ~/Desktop/dedalus-labs/codebase/mcp-knox/openmcp-authorization-server - $ go run ./cmd/serve - - # 2. Start the protected Supabase resource server - $ cd ~/Desktop/dedalus-labs/codebase/openmcp - $ uv run python examples/auth/02_as/server.py - - # 3. Run this client with matching credentials (env or flags) - $ export MCP_CLIENT_SECRET="dev-m2m-secret" - $ uv run python examples/auth/03_client_credentials/client.py \ - --client-id dedalus-m2m --table users --limit 5 - -Because this grant type never involves a browser or Clerk, it is ideal for CI/CD -pipelines and other machine-to-machine workflows. -""" - -from __future__ import annotations - -import argparse -import asyncio -import json -import os -from typing import Any - -import httpx -from pydantic import ValidationError - -from dedalus_mcp.client import open_connection -from dedalus_mcp.types import ( - CallToolRequest, - CallToolRequestParams, - CallToolResult, - ClientRequest, - ListToolsRequest, - ListToolsResult, -) -from dedalus_mcp.utils import to_json - -DEFAULT_SERVER_URL = os.getenv("MCP_SERVER_URL", "http://127.0.0.1:8000/mcp") -DEFAULT_RESOURCE = os.getenv("MCP_RESOURCE_URL", "http://127.0.0.1:8000") -DEFAULT_ISSUER = os.getenv("AS_ISSUER", "http://localhost:4444") -DEFAULT_SCOPE = os.getenv("MCP_REQUIRED_SCOPES", "mcp:tools:call") -DEFAULT_CLIENT_ID = os.getenv("MCP_CLIENT_ID", "dedalus-m2m") -DEFAULT_CLIENT_SECRET = os.getenv("MCP_CLIENT_SECRET") - - -class OAuthError(RuntimeError): - """Raised when the OAuth handshake fails.""" - - -async def fetch_access_token(args: argparse.Namespace) -> dict[str, Any]: - """Exchange client credentials for an access token.""" - - token_url = f"{args.issuer.rstrip('/')}/oauth2/token" - data = { - "grant_type": "client_credentials", - "scope": args.scope, - "resource": args.resource, - } - auth = httpx.BasicAuth(args.client_id, args.client_secret) - - async with httpx.AsyncClient(timeout=30.0) as client: - try: - token_response = await client.post(token_url, data=data, auth=auth) - except httpx.ConnectError as exc: - raise OAuthError( - f"Failed to reach token endpoint at {token_url}. " - "Confirm AS_ISSUER is correct and network access is available." - ) from exc - - if token_response.status_code != 200: - raise OAuthError( - f"Token request failed: HTTP {token_response.status_code} {token_response.text}" - ) - return token_response.json() - - -def build_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser(description="Supabase OAuth client demo (client_credentials grant)") - parser.add_argument("--url", default=DEFAULT_SERVER_URL, help="MCP endpoint (default: %(default)s)") - parser.add_argument("--resource", default=DEFAULT_RESOURCE, help="Resource/audience URI (default: %(default)s)") - parser.add_argument("--issuer", default=DEFAULT_ISSUER, help="Authorization Server issuer (default: %(default)s)") - parser.add_argument("--client-id", default=DEFAULT_CLIENT_ID, help="OAuth client_id (default: %(default)s)") - parser.add_argument( - "--client-secret", - default=DEFAULT_CLIENT_SECRET, - help="OAuth client_secret (default: MCP_CLIENT_SECRET env)", - ) - parser.add_argument("--scope", default=DEFAULT_SCOPE, help="Space-separated scopes (default: %(default)s)") - parser.add_argument("--table", default="users", help="Supabase table to query") - parser.add_argument("--columns", default="*", help="Column projection for Supabase") - parser.add_argument("--limit", type=int, default=5, help="Row limit (default: %(default)s)") - parser.add_argument( - "--transport", - default="streamable-http", - choices=["streamable-http", "lambda-http"], - help="MCP transport (default: %(default)s)", - ) - parser.add_argument("--access-token", help="Skip OAuth flow and use an existing access token") - return parser - - -async def call_supabase_tool(args: argparse.Namespace, access_token: str) -> None: - headers = {"Authorization": f"Bearer {access_token}"} - async with open_connection(url=args.url, transport=args.transport, headers=headers) as client: - init = client.initialize_result - if init is None: - raise RuntimeError("MCP initialize handshake failed") - print( - f"Connected to {init.serverInfo.name} v{init.serverInfo.version or '0.0.0'} via {args.transport}" - ) - print(f"Negotiated MCP protocol version: {init.protocolVersion}\n") - - list_request = ClientRequest(ListToolsRequest()) - tools_result = await client.send_request(list_request, ListToolsResult) - - print("Available tools:") - for idx, tool in enumerate(tools_result.tools, start=1): - desc = tool.description or "(no description)" - print(f" {idx:>2}. {tool.name} — {desc}") - - expected_tool = "supabase_select_live" - available = {tool.name for tool in tools_result.tools} - if expected_tool not in available: - print( - f"Tool '{expected_tool}' is not available on server '{init.serverInfo.name}'.\n" - "Ensure the protected server (examples/auth/02_as/server.py) is running " - "or register an equivalent tool before retrying." - ) - return - - arguments: dict[str, Any] = {"table": args.table, "columns": args.columns} - if args.limit is not None: - arguments["limit"] = args.limit - - try: - request = ClientRequest( - CallToolRequest(params=CallToolRequestParams(name="supabase_select_live", arguments=arguments)) - ) - except ValidationError as exc: - raise SystemExit(f"Invalid tool arguments: {exc}") from exc - - result = await client.send_request(request, CallToolResult) - status = "error" if result.isError else "success" - print(f"\nTool call status: {status}") - - payload = to_json(result) - print(json.dumps(payload, indent=2)) - - -async def main() -> None: - parser = build_parser() - args = parser.parse_args() - - if args.limit is not None and args.limit < 0: - args.limit = None - - if not args.client_secret: - raise SystemExit("Provide --client-secret or set MCP_CLIENT_SECRET before running this example.") - - token_data: dict[str, Any] - if args.access_token: - token_data = {"access_token": args.access_token} - else: - print("Requesting OAuth token (client_credentials)…") - token_data = await fetch_access_token(args) - print("Received access token; calling MCP server…\n") - - access_token = token_data.get("access_token") - if not access_token: - raise SystemExit("Authorization Server response lacked access_token") - - await call_supabase_tool(args, access_token) - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/examples/auth/07_client_auth_module/README.md b/examples/auth/07_client_auth_module/README.md new file mode 100644 index 0000000..8398875 --- /dev/null +++ b/examples/auth/07_client_auth_module/README.md @@ -0,0 +1,77 @@ +# Client Auth Module Examples + +This directory demonstrates the new `dedalus_mcp.client.auth` module for spec-compliant OAuth authentication. + +## Overview + +The auth module provides: + +- **`ClientCredentialsAuth`**: M2M / backend service authentication +- **`TokenExchangeAuth`**: User delegation via token exchange (RFC 8693) +- **`DeviceCodeAuth`**: CLI tool authentication (stub) +- **`AuthorizationCodeAuth`**: Browser-based authentication (stub, planned for Clerk) + +## Quick Start + +```bash +# Set the M2M secret +export MCP_CLIENT_SECRET="your-m2m-secret" + +# Run the example +uv run python examples/auth/07_client_auth_module/client_credentials_example.py +``` + +## Usage with MCPClient + +```python +from dedalus_mcp.client import MCPClient +from dedalus_mcp.client.auth import ClientCredentialsAuth + +# Option 1: Auto-discovery from protected resource +auth = await ClientCredentialsAuth.from_resource( + resource_url="https://mcp.example.com/mcp", + client_id="m2m", + client_secret=os.environ["M2M_SECRET"], +) +await auth.get_token() +client = await MCPClient.connect("https://mcp.example.com/mcp", auth=auth) + +# Option 2: Direct construction (when you know the AS) +from dedalus_mcp.client.auth import fetch_authorization_server_metadata + +async with httpx.AsyncClient() as http: + server_metadata = await fetch_authorization_server_metadata(http, "https://as.example.com") + +auth = ClientCredentialsAuth( + server_metadata=server_metadata, + client_id="m2m", + client_secret=secret, +) +await auth.get_token() +client = await MCPClient.connect("https://mcp.example.com/mcp", auth=auth) +``` + +## Token Exchange (User Delegation) + +```python +from dedalus_mcp.client.auth import TokenExchangeAuth + +# Exchange a Clerk/Auth0 token for an MCP-scoped token +auth = await TokenExchangeAuth.from_resource( + resource_url="https://mcp.example.com/mcp", + client_id="dedalus-sdk", + subject_token=clerk_session_token, +) +await auth.get_token() +client = await MCPClient.connect("https://mcp.example.com/mcp", auth=auth) +``` + +## Architecture + +The module implements: + +- **RFC 9728**: OAuth 2.0 Protected Resource Metadata discovery +- **RFC 8414**: OAuth 2.0 Authorization Server Metadata discovery +- **RFC 6749 Section 4.4**: Client Credentials Grant +- **RFC 8693**: Token Exchange Grant +- **RFC 8707**: Resource Indicators diff --git a/examples/auth/07_client_auth_module/client_credentials_example.py b/examples/auth/07_client_auth_module/client_credentials_example.py new file mode 100644 index 0000000..93b7ec8 --- /dev/null +++ b/examples/auth/07_client_auth_module/client_credentials_example.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python3 +# Copyright (c) 2025 Dedalus Labs, Inc. and its contributors +# SPDX-License-Identifier: MIT + +"""Example: Client Credentials Authentication with the new auth module. + +This demonstrates the simplified OAuth 2.0 client credentials flow using +the new `dedalus_mcp.client.auth` module. + +Usage: + # Set the M2M secret (or pass via --client-secret) + export MCP_CLIENT_SECRET="your-m2m-secret" + + # Run the example + uv run python examples/auth/07_client_auth_module/client_credentials_example.py + + # Or with explicit args + uv run python examples/auth/07_client_auth_module/client_credentials_example.py \ + --issuer https://preview.as.dedaluslabs.ai \ + --client-id m2m \ + --client-secret "your-secret" +""" + +from __future__ import annotations + +import argparse +import asyncio +import os + +from dedalus_mcp.client.auth import ( + AuthorizationServerMetadata, + ClientCredentialsAuth, + TokenError, + discover_authorization_server, + fetch_authorization_server_metadata, +) + +# Defaults for preview environment +DEFAULT_ISSUER = os.getenv("AS_ISSUER", "https://preview.as.dedaluslabs.ai") +DEFAULT_CLIENT_ID = os.getenv("MCP_CLIENT_ID", "m2m") +DEFAULT_CLIENT_SECRET = os.getenv("MCP_CLIENT_SECRET") + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description="Client Credentials Auth Example", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument( + "--issuer", + default=DEFAULT_ISSUER, + help=f"Authorization Server issuer URL (default: {DEFAULT_ISSUER})", + ) + parser.add_argument( + "--client-id", + default=DEFAULT_CLIENT_ID, + help=f"OAuth client ID (default: {DEFAULT_CLIENT_ID})", + ) + parser.add_argument( + "--client-secret", + default=DEFAULT_CLIENT_SECRET, + help="OAuth client secret (default: MCP_CLIENT_SECRET env)", + ) + parser.add_argument( + "--scope", + default=None, + help="Optional scope to request", + ) + return parser + + +async def main() -> None: + parser = build_parser() + args = parser.parse_args() + + if not args.client_secret: + parser.error("Provide --client-secret or set MCP_CLIENT_SECRET") + + print(f"Fetching AS metadata from {args.issuer}...") + print() + + # Approach 1: Direct construction with AS metadata + # Use this when you know the AS URL upfront + import httpx + + async with httpx.AsyncClient() as client: + server_metadata = await fetch_authorization_server_metadata(client, args.issuer) + + print("Authorization Server Metadata:") + print(f" Issuer: {server_metadata.issuer}") + print(f" Token Endpoint: {server_metadata.token_endpoint}") + print(f" Supported Grants: {server_metadata.grant_types_supported}") + print() + + # Create auth instance + auth = ClientCredentialsAuth( + server_metadata=server_metadata, + client_id=args.client_id, + client_secret=args.client_secret, + scope=args.scope, + ) + + print(f"Requesting token for client '{auth.client_id}'...") + + try: + token = await auth.get_token() + except TokenError as e: + print(f"Token request failed: {e}") + return + + print() + print("Token acquired successfully!") + print(f" Token Type: {token.token_type}") + print(f" Expires In: {token.expires_in} seconds") + print(f" Access Token: {token.access_token[:50]}...") + print() + + # Demonstrate token caching + print("Requesting token again (should be cached)...") + token2 = await auth.get_token() + assert token.access_token == token2.access_token + print("Token was cached correctly.") + print() + + # Show how to use with MCPClient + print("=" * 60) + print("Integration with MCPClient:") + print("=" * 60) + print(""" +# With the new auth module, connecting to a protected MCP server is simple: + +from dedalus_mcp.client import MCPClient +from dedalus_mcp.client.auth import ClientCredentialsAuth + +# Option 1: Auto-discovery from protected resource +auth = await ClientCredentialsAuth.from_resource( + resource_url="https://mcp.example.com/mcp", + client_id="m2m", + client_secret=os.environ["M2M_SECRET"], +) +await auth.get_token() +client = await MCPClient.connect("https://mcp.example.com/mcp", auth=auth) + +# Option 2: Direct construction (when you know the AS) +server_metadata = await fetch_authorization_server_metadata(http, "https://as.example.com") +auth = ClientCredentialsAuth( + server_metadata=server_metadata, + client_id="m2m", + client_secret=secret, +) +await auth.get_token() +client = await MCPClient.connect("https://mcp.example.com/mcp", auth=auth) +""") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/pyproject.toml b/pyproject.toml index 03244f1..4c58919 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,7 @@ classifiers = [ dependencies = [ "pydantic>=2.12.0", - "mcp>=1.20.0", + "mcp>=1.24.0", "pyjwt[crypto]>=2.10.1", "typing_extensions>=4.0.0", ] diff --git a/src/dedalus_mcp/client/auth/__init__.py b/src/dedalus_mcp/client/auth/__init__.py new file mode 100644 index 0000000..0f58190 --- /dev/null +++ b/src/dedalus_mcp/client/auth/__init__.py @@ -0,0 +1,74 @@ +# Copyright (c) 2025 Dedalus Labs, Inc. and its contributors +# SPDX-License-Identifier: MIT + +"""OAuth authentication for MCP clients. + +This module provides spec-compliant OAuth authentication per MCP authorization spec: +- ClientCredentialsAuth: M2M/backend service authentication +- TokenExchangeAuth: User delegation via token exchange (RFC 8693) +- DeviceCodeAuth: CLI tool authentication (stub) +- AuthorizationCodeAuth: Browser-based authentication (stub) + +Example usage: + + # M2M / Backend service + auth = await ClientCredentialsAuth.from_resource( + resource_url="https://mcp.example.com/mcp", + client_id="m2m", + client_secret=os.environ["M2M_SECRET"], + ) + await auth.get_token() + client = await MCPClient.connect("https://mcp.example.com/mcp", auth=auth) + + # User delegation (e.g., from Clerk token) + auth = await TokenExchangeAuth.from_resource( + resource_url="https://mcp.example.com/mcp", + client_id="dedalus-sdk", + subject_token=clerk_session_token, + ) + await auth.get_token() + client = await MCPClient.connect("https://mcp.example.com/mcp", auth=auth) +""" + +from .authorization_code import AuthorizationCodeAuth +from .client_credentials import AuthConfigError, ClientCredentialsAuth, TokenError +from .device_code import DeviceCodeAuth +from .discovery import ( + DiscoveryError, + DiscoveryResult, + discover_authorization_server, + fetch_authorization_server_metadata, + fetch_resource_metadata, +) +from .models import ( + AuthorizationServerMetadata, + ResourceMetadata, + TokenResponse, + WWWAuthenticateChallenge, + parse_www_authenticate, +) +from .token_exchange import TokenExchangeAuth + +__all__ = [ + # Primary auth classes + "ClientCredentialsAuth", + "TokenExchangeAuth", + # Stubs for future + "DeviceCodeAuth", + "AuthorizationCodeAuth", + # Discovery + "discover_authorization_server", + "fetch_authorization_server_metadata", + "fetch_resource_metadata", + "DiscoveryResult", + "DiscoveryError", + # Models + "ResourceMetadata", + "AuthorizationServerMetadata", + "TokenResponse", + "WWWAuthenticateChallenge", + "parse_www_authenticate", + # Errors + "AuthConfigError", + "TokenError", +] diff --git a/src/dedalus_mcp/client/auth/authorization_code.py b/src/dedalus_mcp/client/auth/authorization_code.py new file mode 100644 index 0000000..d013397 --- /dev/null +++ b/src/dedalus_mcp/client/auth/authorization_code.py @@ -0,0 +1,25 @@ +# Copyright (c) 2025 Dedalus Labs, Inc. and its contributors +# SPDX-License-Identifier: MIT + +"""OAuth 2.0 Authorization Code Grant (RFC 6749 Section 4.1). + +AuthorizationCodeAuth is for browser-based flows where the user +authenticates via redirect. + +TODO: Implement with Clerk integration for browser flows. +""" + +from __future__ import annotations + + +class AuthorizationCodeAuth: + """OAuth 2.0 Authorization Code Grant (RFC 6749 Section 4.1). + + Not yet implemented. Planned for Clerk integration. + """ + + def __init__(self) -> None: + raise NotImplementedError( + "AuthorizationCodeAuth is not yet implemented. " + "Planned for Clerk integration." + ) diff --git a/src/dedalus_mcp/client/auth/client_credentials.py b/src/dedalus_mcp/client/auth/client_credentials.py new file mode 100644 index 0000000..ff68e1d --- /dev/null +++ b/src/dedalus_mcp/client/auth/client_credentials.py @@ -0,0 +1,190 @@ +# Copyright (c) 2025 Dedalus Labs, Inc. and its contributors +# SPDX-License-Identifier: MIT + +"""OAuth 2.0 Client Credentials Auth (RFC 6749 Section 4.4). + +ClientCredentialsAuth is the primary auth mechanism for M2M (machine-to-machine) +communication, CI/CD pipelines, and backend services. +""" + +from __future__ import annotations + +from typing import Generator + +import httpx + +from .discovery import discover_authorization_server +from .models import AuthorizationServerMetadata, TokenResponse + + +class AuthConfigError(Exception): + """Configuration error for authentication.""" + + +class TokenError(Exception): + """Error acquiring token.""" + + +class ClientCredentialsAuth(httpx.Auth): + """OAuth 2.0 Client Credentials authentication. + + Implements httpx.Auth for transparent token injection into HTTP requests. + Acquires tokens using the client_credentials grant type. + """ + + def __init__( + self, + *, + server_metadata: AuthorizationServerMetadata, + client_id: str, + client_secret: str, + scope: str | None = None, + resource: str | None = None, + ) -> None: + """Initialize ClientCredentialsAuth. + + Args: + server_metadata: Authorization Server metadata. + client_id: OAuth client ID. + client_secret: OAuth client secret. + scope: Optional scope to request. + resource: Optional resource indicator (RFC 8707). + + Raises: + AuthConfigError: If AS doesn't support client_credentials grant. + """ + if not server_metadata.supports_grant_type("client_credentials"): + raise AuthConfigError( + "Authorization server does not support client_credentials grant type" + ) + + self._server_metadata = server_metadata + self._client_id = client_id + self._client_secret = client_secret + self._scope = scope + self._resource = resource + self._cached_token: TokenResponse | None = None + + @property + def client_id(self) -> str: + """Return the client ID.""" + return self._client_id + + @property + def token_endpoint(self) -> str: + """Return the token endpoint URL.""" + return self._server_metadata.token_endpoint + + @property + def scope(self) -> str | None: + """Return the configured scope.""" + return self._scope + + @classmethod + async def from_resource( + cls, + *, + resource_url: str, + client_id: str, + client_secret: str, + scope: str | None = None, + ) -> ClientCredentialsAuth: + """Create ClientCredentialsAuth via OAuth discovery. + + Performs the full discovery flow: + 1. Probes resource for 401 + 2. Fetches Protected Resource Metadata + 3. Fetches Authorization Server Metadata + 4. Returns configured auth instance + + Args: + resource_url: URL of the protected resource. + client_id: OAuth client ID. + client_secret: OAuth client secret. + scope: Optional scope to request. + + Returns: + Configured ClientCredentialsAuth instance. + + Raises: + DiscoveryError: If discovery fails. + AuthConfigError: If AS doesn't support client_credentials. + """ + async with httpx.AsyncClient() as client: + result = await discover_authorization_server(client, resource_url) + + return cls( + server_metadata=result.authorization_server_metadata, + client_id=client_id, + client_secret=client_secret, + scope=scope, + resource=result.resource_metadata.resource, + ) + + async def get_token(self) -> TokenResponse: + """Acquire an access token using client credentials. + + Caches the token for reuse. + + Returns: + TokenResponse with access token. + + Raises: + TokenError: If token acquisition fails. + """ + if self._cached_token is not None: + return self._cached_token + + data = { + "grant_type": "client_credentials", + } + if self._scope: + data["scope"] = self._scope + if self._resource: + data["resource"] = self._resource + + async with httpx.AsyncClient() as client: + response = await client.post( + self._server_metadata.token_endpoint, + data=data, + auth=(self._client_id, self._client_secret), + ) + + if response.status_code != 200: + try: + error_data = response.json() + error = error_data.get("error", "unknown_error") + description = error_data.get("error_description", "") + raise TokenError(f"{error}: {description}") + except TokenError: + raise + except Exception: + raise TokenError(f"Token request failed: {response.status_code}") + + token = TokenResponse.from_dict(response.json()) + self._cached_token = token + return token + + def sync_auth_flow( + self, request: httpx.Request + ) -> Generator[httpx.Request, httpx.Response, None]: + """Synchronous auth flow for httpx.Auth interface. + + Injects the Bearer token into the request. Token must be + pre-fetched via get_token() for sync usage. + """ + if self._cached_token is not None: + request.headers["Authorization"] = f"Bearer {self._cached_token.access_token}" + yield request + + async def async_auth_flow( + self, request: httpx.Request + ) -> Generator[httpx.Request, httpx.Response, None]: + """Async auth flow for httpx.Auth interface. + + Injects the Bearer token into the request. Token must be + pre-fetched via get_token() for proper operation. + """ + if self._cached_token is not None: + request.headers["Authorization"] = f"Bearer {self._cached_token.access_token}" + yield request diff --git a/src/dedalus_mcp/client/auth/device_code.py b/src/dedalus_mcp/client/auth/device_code.py new file mode 100644 index 0000000..dfe245d --- /dev/null +++ b/src/dedalus_mcp/client/auth/device_code.py @@ -0,0 +1,22 @@ +# Copyright (c) 2025 Dedalus Labs, Inc. and its contributors +# SPDX-License-Identifier: MIT + +"""OAuth 2.0 Device Authorization Grant (RFC 8628). + +DeviceCodeAuth is for CLI tools and devices without browsers. +The user authorizes on a separate device. + +TODO: Implement for CLI usage. +""" + +from __future__ import annotations + + +class DeviceCodeAuth: + """OAuth 2.0 Device Authorization Grant (RFC 8628). + + Not yet implemented. + """ + + def __init__(self) -> None: + raise NotImplementedError("DeviceCodeAuth is not yet implemented") diff --git a/src/dedalus_mcp/client/auth/discovery.py b/src/dedalus_mcp/client/auth/discovery.py new file mode 100644 index 0000000..fb6689f --- /dev/null +++ b/src/dedalus_mcp/client/auth/discovery.py @@ -0,0 +1,185 @@ +# Copyright (c) 2025 Dedalus Labs, Inc. and its contributors +# SPDX-License-Identifier: MIT + +"""OAuth discovery (RFC 9728, RFC 8414). + +This module implements the discovery flow for MCP OAuth: +1. Probe resource for 401 with WWW-Authenticate header +2. Fetch Protected Resource Metadata (PRM) per RFC 9728 +3. Fetch Authorization Server Metadata per RFC 8414 +""" + +from __future__ import annotations + +from dataclasses import dataclass +from urllib.parse import urljoin, urlparse + +import httpx + +from .models import ( + AuthorizationServerMetadata, + ResourceMetadata, + parse_www_authenticate, +) + + +class DiscoveryError(Exception): + """Error during OAuth discovery.""" + + +@dataclass +class DiscoveryResult: + """Result of OAuth discovery.""" + + resource_metadata: ResourceMetadata + authorization_server_metadata: AuthorizationServerMetadata + + +def build_resource_metadata_url(base_url: str, resource_metadata_path: str) -> str: + """Build the full URL for Protected Resource Metadata. + + Args: + base_url: The base resource URL (e.g., https://mcp.example.com/mcp) + resource_metadata_path: The path from WWW-Authenticate header + + Returns: + Full URL to fetch PRM from. + """ + # If it's already a full URL, return as-is + if resource_metadata_path.startswith("http://") or resource_metadata_path.startswith("https://"): + return resource_metadata_path + + # If it starts with /, it's absolute path from origin + if resource_metadata_path.startswith("/"): + parsed = urlparse(base_url) + return f"{parsed.scheme}://{parsed.netloc}{resource_metadata_path}" + + # Otherwise, relative path - resolve against base URL + return urljoin(base_url, resource_metadata_path) + + +def build_authorization_server_metadata_url(issuer: str) -> str: + """Build the well-known URL for Authorization Server Metadata. + + Per RFC 8414, if issuer has a path, insert .well-known between + the origin and the path. + + Args: + issuer: The authorization server issuer URL. + + Returns: + The .well-known URL for AS metadata. + """ + issuer = issuer.rstrip("/") + parsed = urlparse(issuer) + + # If there's a path component, insert .well-known before it + if parsed.path and parsed.path != "/": + path = parsed.path.lstrip("/") + return f"{parsed.scheme}://{parsed.netloc}/.well-known/oauth-authorization-server/{path}" + + return f"{issuer}/.well-known/oauth-authorization-server" + + +async def fetch_resource_metadata(client: httpx.AsyncClient, url: str) -> ResourceMetadata: + """Fetch Protected Resource Metadata (RFC 9728). + + Args: + client: HTTP client to use for the request. + url: Full URL to the PRM endpoint. + + Returns: + Parsed ResourceMetadata. + + Raises: + DiscoveryError: If the fetch fails or response is invalid. + """ + response = await client.get(url) + + if response.status_code != 200: + raise DiscoveryError(f"Failed to fetch resource metadata: {response.status_code}") + + try: + data = response.json() + except Exception as e: + raise DiscoveryError(f"Invalid JSON in resource metadata response: {e}") from e + + return ResourceMetadata.from_dict(data) + + +async def fetch_authorization_server_metadata(client: httpx.AsyncClient, issuer: str) -> AuthorizationServerMetadata: + """Fetch Authorization Server Metadata (RFC 8414). + + Args: + client: HTTP client to use for the request. + issuer: The authorization server issuer URL. + + Returns: + Parsed AuthorizationServerMetadata. + + Raises: + DiscoveryError: If the fetch fails or response is invalid. + """ + url = build_authorization_server_metadata_url(issuer) + response = await client.get(url) + + if response.status_code != 200: + raise DiscoveryError(f"Failed to fetch AS metadata: {response.status_code}") + + try: + data = response.json() + except Exception as e: + raise DiscoveryError(f"Invalid JSON in AS metadata response: {e}") from e + + return AuthorizationServerMetadata.from_dict(data) + + +async def discover_authorization_server( + client: httpx.AsyncClient, + resource_url: str, +) -> DiscoveryResult: + """Perform full OAuth discovery flow. + + This implements the MCP spec-compliant discovery: + 1. Probe the resource URL, expecting 401 + 2. Parse WWW-Authenticate header for resource_metadata path + 3. Fetch Protected Resource Metadata + 4. Fetch Authorization Server Metadata + + Args: + client: HTTP client to use for requests. + resource_url: URL of the protected resource (e.g., MCP endpoint). + + Returns: + DiscoveryResult with both resource and AS metadata. + + Raises: + DiscoveryError: If discovery fails at any step. + """ + # Step 1: Probe resource for 401 + response = await client.get(resource_url) + + if response.status_code != 401: + raise DiscoveryError(f"Resource is not protected (got {response.status_code}, expected 401)") + + # Step 2: Parse WWW-Authenticate header + www_auth = response.headers.get("WWW-Authenticate") + if not www_auth: + raise DiscoveryError("401 response missing WWW-Authenticate header") + + challenge = parse_www_authenticate(www_auth) + if not challenge.resource_metadata: + raise DiscoveryError("WWW-Authenticate header missing resource_metadata parameter") + + # Step 3: Fetch Protected Resource Metadata + prm_url = build_resource_metadata_url(resource_url, challenge.resource_metadata) + resource_metadata = await fetch_resource_metadata(client, prm_url) + + # Step 4: Fetch AS Metadata (use first/primary AS) + as_url = resource_metadata.primary_authorization_server + authorization_server_metadata = await fetch_authorization_server_metadata(client, as_url) + + return DiscoveryResult( + resource_metadata=resource_metadata, + authorization_server_metadata=authorization_server_metadata, + ) diff --git a/src/dedalus_mcp/client/auth/models.py b/src/dedalus_mcp/client/auth/models.py new file mode 100644 index 0000000..2b8d651 --- /dev/null +++ b/src/dedalus_mcp/client/auth/models.py @@ -0,0 +1,196 @@ +# Copyright (c) 2025 Dedalus Labs, Inc. and its contributors +# SPDX-License-Identifier: MIT + +"""OAuth metadata models (RFC 9728, RFC 8414). + +This module defines data models for: +- Protected Resource Metadata (RFC 9728) +- Authorization Server Metadata (RFC 8414) +- Token responses +- WWW-Authenticate header parsing +""" + +from __future__ import annotations + +import re +from dataclasses import dataclass +from typing import Any + + +@dataclass +class ResourceMetadata: + """OAuth 2.0 Protected Resource Metadata (RFC 9728). + + Describes the OAuth configuration of a protected resource, + including which authorization servers can issue tokens for it. + """ + + resource: str + authorization_servers: list[str] + scopes_supported: list[str] | None = None + bearer_methods_supported: list[str] | None = None + resource_signing_alg_values_supported: list[str] | None = None + + @property + def primary_authorization_server(self) -> str: + """Return the first (primary) authorization server.""" + return self.authorization_servers[0] + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> ResourceMetadata: + """Create ResourceMetadata from a dictionary. + + Raises: + ValueError: If required fields are missing. + """ + if "resource" not in data: + raise ValueError("Missing required field: resource") + if "authorization_servers" not in data: + raise ValueError("Missing required field: authorization_servers") + + return cls( + resource=data["resource"], + authorization_servers=data["authorization_servers"], + scopes_supported=data.get("scopes_supported"), + bearer_methods_supported=data.get("bearer_methods_supported"), + resource_signing_alg_values_supported=data.get("resource_signing_alg_values_supported"), + ) + + +@dataclass +class AuthorizationServerMetadata: + """OAuth 2.0 Authorization Server Metadata (RFC 8414). + + Describes the OAuth configuration of an authorization server, + including endpoints and supported features. + """ + + issuer: str + token_endpoint: str + authorization_endpoint: str | None = None + registration_endpoint: str | None = None + jwks_uri: str | None = None + scopes_supported: list[str] | None = None + response_types_supported: list[str] | None = None + grant_types_supported: list[str] | None = None + token_endpoint_auth_methods_supported: list[str] | None = None + code_challenge_methods_supported: list[str] | None = None + + def supports_grant_type(self, grant_type: str) -> bool: + """Check if the AS supports a specific grant type.""" + if self.grant_types_supported is None: + return False + return grant_type in self.grant_types_supported + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> AuthorizationServerMetadata: + """Create AuthorizationServerMetadata from a dictionary. + + Raises: + ValueError: If required fields are missing. + """ + if "issuer" not in data: + raise ValueError("Missing required field: issuer") + if "token_endpoint" not in data: + raise ValueError("Missing required field: token_endpoint") + + return cls( + issuer=data["issuer"], + token_endpoint=data["token_endpoint"], + authorization_endpoint=data.get("authorization_endpoint"), + registration_endpoint=data.get("registration_endpoint"), + jwks_uri=data.get("jwks_uri"), + scopes_supported=data.get("scopes_supported"), + response_types_supported=data.get("response_types_supported"), + grant_types_supported=data.get("grant_types_supported"), + token_endpoint_auth_methods_supported=data.get("token_endpoint_auth_methods_supported"), + code_challenge_methods_supported=data.get("code_challenge_methods_supported"), + ) + + +@dataclass +class TokenResponse: + """OAuth token response. + + Contains the access token and related metadata returned + from the token endpoint. + """ + + access_token: str + token_type: str + expires_in: int | None = None + refresh_token: str | None = None + scope: str | None = None + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> TokenResponse: + """Create TokenResponse from a dictionary. + + Raises: + ValueError: If required fields are missing. + """ + if "access_token" not in data: + raise ValueError("Missing required field: access_token") + if "token_type" not in data: + raise ValueError("Missing required field: token_type") + + return cls( + access_token=data["access_token"], + token_type=data["token_type"], + expires_in=data.get("expires_in"), + refresh_token=data.get("refresh_token"), + scope=data.get("scope"), + ) + + +@dataclass +class WWWAuthenticateChallenge: + """Parsed WWW-Authenticate header challenge.""" + + scheme: str + error: str | None = None + error_description: str | None = None + resource_metadata: str | None = None + + +# Regex to parse WWW-Authenticate parameters +_PARAM_PATTERN = re.compile(r'(\w+)="([^"]*)"') + + +def parse_www_authenticate(header: str) -> WWWAuthenticateChallenge: + """Parse a WWW-Authenticate header value. + + Args: + header: The WWW-Authenticate header value. + + Returns: + Parsed challenge with scheme and parameters. + + Raises: + ValueError: If the header is empty or malformed. + """ + if not header: + raise ValueError("Empty WWW-Authenticate header") + + parts = header.split(None, 1) + if not parts: + raise ValueError("Malformed WWW-Authenticate header") + + scheme = parts[0] + params_str = parts[1] if len(parts) > 1 else "" + + # Validate scheme looks reasonable + if not scheme.isalpha(): + raise ValueError("Malformed WWW-Authenticate header: invalid scheme") + + # Parse parameters + params: dict[str, str] = {} + for match in _PARAM_PATTERN.finditer(params_str): + params[match.group(1)] = match.group(2) + + return WWWAuthenticateChallenge( + scheme=scheme, + error=params.get("error"), + error_description=params.get("error_description"), + resource_metadata=params.get("resource_metadata"), + ) diff --git a/src/dedalus_mcp/client/auth/token_exchange.py b/src/dedalus_mcp/client/auth/token_exchange.py new file mode 100644 index 0000000..97ad888 --- /dev/null +++ b/src/dedalus_mcp/client/auth/token_exchange.py @@ -0,0 +1,221 @@ +# Copyright (c) 2025 Dedalus Labs, Inc. and its contributors +# SPDX-License-Identifier: MIT + +"""OAuth 2.0 Token Exchange Auth (RFC 8693). + +TokenExchangeAuth exchanges an existing token (e.g., from Clerk, Auth0) +for an MCP-scoped access token. Used for user delegation flows. +""" + +from __future__ import annotations + +from typing import Generator + +import httpx + +from .discovery import discover_authorization_server +from .models import AuthorizationServerMetadata, TokenResponse + +# RFC 8693 grant type +TOKEN_EXCHANGE_GRANT = "urn:ietf:params:oauth:grant-type:token-exchange" + +# RFC 8693 token types +ACCESS_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:access_token" +ID_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:id_token" + + +class AuthConfigError(Exception): + """Configuration error for authentication.""" + + +class TokenError(Exception): + """Error acquiring token.""" + + +class TokenExchangeAuth(httpx.Auth): + """OAuth 2.0 Token Exchange authentication (RFC 8693). + + Implements httpx.Auth for transparent token injection into HTTP requests. + Exchanges an existing token for a new access token. + """ + + def __init__( + self, + *, + server_metadata: AuthorizationServerMetadata, + client_id: str, + subject_token: str, + subject_token_type: str = ACCESS_TOKEN_TYPE, + actor_token: str | None = None, + actor_token_type: str | None = None, + scope: str | None = None, + resource: str | None = None, + ) -> None: + """Initialize TokenExchangeAuth. + + Args: + server_metadata: Authorization Server metadata. + client_id: OAuth client ID. + subject_token: The token to exchange. + subject_token_type: Type of subject token (default: access_token). + actor_token: Optional actor token for delegation. + actor_token_type: Type of actor token. + scope: Optional scope to request. + resource: Optional resource indicator (RFC 8707). + + Raises: + AuthConfigError: If AS doesn't support token-exchange grant. + """ + if not server_metadata.supports_grant_type(TOKEN_EXCHANGE_GRANT): + raise AuthConfigError( + "Authorization server does not support token-exchange grant type" + ) + + self._server_metadata = server_metadata + self._client_id = client_id + self._subject_token = subject_token + self._subject_token_type = subject_token_type + self._actor_token = actor_token + self._actor_token_type = actor_token_type or ACCESS_TOKEN_TYPE + self._scope = scope + self._resource = resource + self._cached_token: TokenResponse | None = None + + @property + def client_id(self) -> str: + """Return the client ID.""" + return self._client_id + + @property + def token_endpoint(self) -> str: + """Return the token endpoint URL.""" + return self._server_metadata.token_endpoint + + @property + def subject_token_type(self) -> str: + """Return the subject token type.""" + return self._subject_token_type + + @property + def actor_token(self) -> str | None: + """Return the actor token.""" + return self._actor_token + + @classmethod + async def from_resource( + cls, + *, + resource_url: str, + client_id: str, + subject_token: str, + subject_token_type: str = ACCESS_TOKEN_TYPE, + scope: str | None = None, + ) -> TokenExchangeAuth: + """Create TokenExchangeAuth via OAuth discovery. + + Performs the full discovery flow: + 1. Probes resource for 401 + 2. Fetches Protected Resource Metadata + 3. Fetches Authorization Server Metadata + 4. Returns configured auth instance + + Args: + resource_url: URL of the protected resource. + client_id: OAuth client ID. + subject_token: The token to exchange. + subject_token_type: Type of subject token. + scope: Optional scope to request. + + Returns: + Configured TokenExchangeAuth instance. + + Raises: + DiscoveryError: If discovery fails. + AuthConfigError: If AS doesn't support token-exchange. + """ + async with httpx.AsyncClient() as client: + result = await discover_authorization_server(client, resource_url) + + return cls( + server_metadata=result.authorization_server_metadata, + client_id=client_id, + subject_token=subject_token, + subject_token_type=subject_token_type, + scope=scope, + resource=result.resource_metadata.resource, + ) + + async def get_token(self) -> TokenResponse: + """Exchange subject token for an access token. + + Caches the token for reuse. + + Returns: + TokenResponse with access token. + + Raises: + TokenError: If token exchange fails. + """ + if self._cached_token is not None: + return self._cached_token + + data = { + "grant_type": TOKEN_EXCHANGE_GRANT, + "client_id": self._client_id, + "subject_token": self._subject_token, + "subject_token_type": self._subject_token_type, + } + + if self._actor_token: + data["actor_token"] = self._actor_token + data["actor_token_type"] = self._actor_token_type + + if self._scope: + data["scope"] = self._scope + if self._resource: + data["resource"] = self._resource + + async with httpx.AsyncClient() as client: + response = await client.post( + self._server_metadata.token_endpoint, + data=data, + ) + + if response.status_code != 200: + try: + error_data = response.json() + error = error_data.get("error", "unknown_error") + description = error_data.get("error_description", "") + raise TokenError(f"{error}: {description}") + except TokenError: + raise + except Exception: + raise TokenError(f"Token exchange failed: {response.status_code}") + + token = TokenResponse.from_dict(response.json()) + self._cached_token = token + return token + + def sync_auth_flow( + self, request: httpx.Request + ) -> Generator[httpx.Request, httpx.Response, None]: + """Synchronous auth flow for httpx.Auth interface. + + Injects the Bearer token into the request. Token must be + pre-fetched via get_token() for sync usage. + """ + if self._cached_token is not None: + request.headers["Authorization"] = f"Bearer {self._cached_token.access_token}" + yield request + + async def async_auth_flow( + self, request: httpx.Request + ) -> Generator[httpx.Request, httpx.Response, None]: + """Async auth flow for httpx.Auth interface. + + Injects the Bearer token into the request. Token must be + pre-fetched via get_token() for proper operation. + """ + if self._cached_token is not None: + request.headers["Authorization"] = f"Bearer {self._cached_token.access_token}" + yield request diff --git a/src/dedalus_mcp/client/connection.py b/src/dedalus_mcp/client/connection.py index 1205991..687eb93 100644 --- a/src/dedalus_mcp/client/connection.py +++ b/src/dedalus_mcp/client/connection.py @@ -9,26 +9,20 @@ The helper deliberately keeps the surface tiny: callers choose a transport via ``transport=`` (defaulting to streamable HTTP) and receive an :class:`~dedalus_mcp.client.MCPClient` instance that already negotiated -capabilities. Power users can still reach the underlying transport by using +capabilities. Power users can still reach the underlying transport by using the lower-level helpers directly. """ from __future__ import annotations -from collections.abc import AsyncGenerator, Callable, Mapping +from collections.abc import AsyncGenerator, Mapping from contextlib import asynccontextmanager from datetime import timedelta import httpx -from mcp.client.streamable_http import ( - MCP_PROTOCOL_VERSION, - streamablehttp_client, -) -from mcp.shared._httpx_utils import ( - McpHttpClientFactory, - create_mcp_http_client, -) +from mcp.client.streamable_http import MCP_PROTOCOL_VERSION, streamable_http_client +from mcp.shared._httpx_utils import create_mcp_http_client from mcp.types import LATEST_PROTOCOL_VERSION, Implementation from .core import ClientCapabilitiesConfig, MCPClient @@ -39,6 +33,29 @@ LambdaHTTPNames = {"lambda-http", "lambda_http"} +def _build_http_client( + headers: Mapping[str, str] | None, + timeout: float | timedelta, + sse_read_timeout: float | timedelta, + auth: httpx.Auth | None, +) -> httpx.AsyncClient: + """Build an httpx.AsyncClient with MCP-appropriate settings.""" + # Build headers with MCP protocol version + base_headers: dict[str, str] = {MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION} + if headers: + base_headers.update(headers) + + # Convert timedelta to float if needed + timeout_sec = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout + sse_timeout_sec = sse_read_timeout.total_seconds() if isinstance(sse_read_timeout, timedelta) else sse_read_timeout + + return create_mcp_http_client( + headers=base_headers, + timeout=httpx.Timeout(timeout_sec, read=sse_timeout_sec), + auth=auth, + ) + + @asynccontextmanager async def open_connection( url: str, @@ -48,11 +65,9 @@ async def open_connection( timeout: float | timedelta = 30, sse_read_timeout: float | timedelta = 300, terminate_on_close: bool = True, - httpx_client_factory: McpHttpClientFactory | Callable[..., httpx.AsyncClient] = create_mcp_http_client, auth: httpx.Auth | None = None, capabilities: ClientCapabilitiesConfig | None = None, client_info: Implementation | None = None, - **transport_kwargs, ) -> AsyncGenerator[MCPClient, None]: """Open an MCP client connection. @@ -60,15 +75,13 @@ async def open_connection( url: Fully qualified MCP endpoint (for example, ``"http://127.0.0.1:8000/mcp"``). transport: Transport name. Defaults to ``"streamable-http"``; accepts aliases like ``"shttp"`` and ``"lambda-http"``. - headers: Optional HTTP headers to merge into the Streamable HTTP or Lambda HTTP transport. + headers: Optional HTTP headers to merge into the transport. timeout: Total request timeout passed to the underlying transport. sse_read_timeout: Streaming read timeout for Server-Sent Events. terminate_on_close: Whether to send a transport-level termination request when closing. - httpx_client_factory: Factory used to build the HTTPX client for Streamable HTTP variants. auth: Optional HTTPX authentication handler. capabilities: Optional client capability configuration advertised during initialization. client_info: Implementation metadata forwarded during the MCP handshake. - transport_kwargs: Transport-specific keyword arguments forwarded to the underlying helper. Yields: MCPClient: A negotiated MCP client ready for ``send_request`` and other operations. @@ -76,60 +89,45 @@ async def open_connection( selected = transport.lower() if selected in StreamableHTTPNames: - # The Streamable HTTP handshake requires ``MCP-Protocol-Version`` on - # every request. Ensure callers always send the latest version we - # support while still allowing custom headers to override it. - base_headers: dict[str, str] = {MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION} - if headers: - base_headers.update(headers) - - async with ( - streamablehttp_client( - url, - headers=base_headers, - timeout=timeout, - sse_read_timeout=sse_read_timeout, - terminate_on_close=terminate_on_close, - httpx_client_factory=httpx_client_factory, - auth=auth, - **transport_kwargs, - ) as (read_stream, write_stream, get_session_id), - MCPClient( - read_stream, - write_stream, - capabilities=capabilities, - client_info=client_info, - get_session_id=get_session_id, - ) as client, - ): - yield client + client = _build_http_client(headers, timeout, sse_read_timeout, auth) + + async with client: + async with ( + streamable_http_client( + url, + http_client=client, + terminate_on_close=terminate_on_close, + ) as (read_stream, write_stream, get_session_id), + MCPClient( + read_stream, + write_stream, + capabilities=capabilities, + client_info=client_info, + get_session_id=get_session_id, + ) as mcp_client, + ): + yield mcp_client return if selected in LambdaHTTPNames: - base_headers = {MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION} - if headers: - base_headers.update(headers) - - async with ( - lambda_http_client( - url, - headers=base_headers, - timeout=timeout, - sse_read_timeout=sse_read_timeout, - terminate_on_close=terminate_on_close, - httpx_client_factory=httpx_client_factory, - auth=auth, - **transport_kwargs, - ) as (read_stream, write_stream, get_session_id), - MCPClient( - read_stream, - write_stream, - capabilities=capabilities, - client_info=client_info, - get_session_id=get_session_id, - ) as client, - ): - yield client + client = _build_http_client(headers, timeout, sse_read_timeout, auth) + + async with client: + async with ( + lambda_http_client( + url, + http_client=client, + terminate_on_close=terminate_on_close, + ) as (read_stream, write_stream, get_session_id), + MCPClient( + read_stream, + write_stream, + capabilities=capabilities, + client_info=client_info, + get_session_id=get_session_id, + ) as mcp_client, + ): + yield mcp_client return raise ValueError(f"Unsupported transport '{transport}'") diff --git a/src/dedalus_mcp/client/core.py b/src/dedalus_mcp/client/core.py index d8860e8..d1009cb 100644 --- a/src/dedalus_mcp/client/core.py +++ b/src/dedalus_mcp/client/core.py @@ -39,6 +39,14 @@ from mcp.client.session import ClientSession +from .errors import MCPConnectionError, SessionExpiredError +from .error_handling import ( + extract_http_error, + extract_network_error, + http_error_to_mcp_error, + network_error_to_mcp_error, +) + from ..types.client.elicitation import ElicitRequestParams, ElicitResult from ..types.client.roots import ListRootsResult, Root from ..types.client.sampling import CreateMessageRequestParams, CreateMessageResult @@ -207,7 +215,8 @@ async def connect( return client # Real implementation: use transport helpers - from mcp.client.streamable_http import MCP_PROTOCOL_VERSION, streamablehttp_client + from mcp.client.streamable_http import MCP_PROTOCOL_VERSION, streamable_http_client + from mcp.shared._httpx_utils import create_mcp_http_client from mcp.types import LATEST_PROTOCOL_VERSION from .transports import lambda_http_client @@ -215,64 +224,97 @@ async def connect( exit_stack = AsyncExitStack() try: - base_headers: dict[str, str] = {MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION} - if headers: - base_headers.update(headers) - - transport_lower = transport.lower() - if transport_lower in {"streamable-http", "streamable_http", "shttp", "http"}: - read_stream, write_stream, get_session_id = await exit_stack.enter_async_context( - streamablehttp_client( - url, headers=base_headers, timeout=timeout, sse_read_timeout=sse_read_timeout, auth=auth - ) + try: + # Build httpx client with MCP-appropriate settings + base_headers: dict[str, str] = {MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION} + if headers: + base_headers.update(headers) + + http_client = create_mcp_http_client( + headers=base_headers, + timeout=httpx.Timeout(timeout, read=sse_read_timeout), + auth=auth, ) - elif transport_lower in {"lambda-http", "lambda_http"}: - read_stream, write_stream, get_session_id = await exit_stack.enter_async_context( - lambda_http_client( - url, headers=base_headers, timeout=timeout, sse_read_timeout=sse_read_timeout, auth=auth + await exit_stack.enter_async_context(http_client) + + transport_lower = transport.lower() + if transport_lower in {"streamable-http", "streamable_http", "shttp", "http"}: + read_stream, write_stream, get_session_id = await exit_stack.enter_async_context( + streamable_http_client(url, http_client=http_client) + ) + elif transport_lower in {"lambda-http", "lambda_http"}: + read_stream, write_stream, get_session_id = await exit_stack.enter_async_context( + lambda_http_client(url, http_client=http_client) ) + else: + raise ValueError(f"Unsupported transport: {transport}") + + # Create client with exit stack for cleanup + client = cls( + read_stream, + write_stream, + capabilities=capabilities, + client_info=client_info, + get_session_id=get_session_id, + _exit_stack=exit_stack, ) - else: - raise ValueError(f"Unsupported transport: {transport}") - - # Create client with exit stack for cleanup - client = cls( - read_stream, - write_stream, - capabilities=capabilities, - client_info=client_info, - get_session_id=get_session_id, - _exit_stack=exit_stack, - ) - - # Enter the session context - session = ClientSession( - read_stream, - write_stream, - sampling_callback=client._build_sampling_handler(), - elicitation_callback=client._build_elicitation_handler(), - list_roots_callback=client._build_roots_handler(), - logging_callback=client._build_logging_handler(), - client_info=client._client_info, - ) - client._session = await exit_stack.enter_async_context(session) - try: + + # Enter the session context + session = ClientSession( + read_stream, + write_stream, + sampling_callback=client._build_sampling_handler(), + elicitation_callback=client._build_elicitation_handler(), + list_roots_callback=client._build_roots_handler(), + logging_callback=client._build_logging_handler(), + client_info=client._client_info, + ) + client._session = await exit_stack.enter_async_context(session) client.initialize_result = await client._session.initialize() - except Exception as e: - err = str(e).lower() - if "session terminated" in err or "connection" in err: - msg = "Failed to connect to the MCP server" - raise ConnectionError(msg) from e - raise - - # Transfer ownership of exit_stack - don't close it here - exit_stack = None # type: ignore[assignment] - return client - finally: - # Only close if we didn't transfer ownership - if exit_stack is not None: - await exit_stack.aclose() + # Transfer ownership of exit_stack - don't close it here + exit_stack = None # type: ignore[assignment] + return client + + finally: + # Only close if we didn't transfer ownership + if exit_stack is not None: + await exit_stack.aclose() + + except (httpx.ConnectError, httpx.TimeoutException) as e: + raise network_error_to_mcp_error(e) from e + + except httpx.HTTPStatusError as e: + raise http_error_to_mcp_error(e) from e + + except BaseExceptionGroup as e: + # MCP SDK transport wraps errors in ExceptionGroup from anyio + http_error = extract_http_error(e) + if http_error is not None: + raise http_error_to_mcp_error(http_error) from e + + # Check for network errors in the group + network_error = extract_network_error(e) + if network_error is not None: + raise network_error_to_mcp_error(network_error) from e + + # Not an HTTP or network error - re-raise the group + raise + + except Exception as e: + # Handle MCP SDK errors (e.g., McpError for session terminated) + from mcp.shared.exceptions import McpError + + if isinstance(e, McpError): + err_msg = str(e).lower() + if "session" in err_msg and ("terminated" in err_msg or "expired" in err_msg): + raise SessionExpiredError(f"Session expired or terminated: {e}") from e + raise MCPConnectionError(f"MCP error: {e}") from e + + # Check for network errors that might not be caught above + if isinstance(e.__cause__, (httpx.ConnectError, httpx.TimeoutException)): + raise network_error_to_mcp_error(e.__cause__) from e + raise # --------------------------------------------------------------------- # Cleanup: close() diff --git a/src/dedalus_mcp/client/error_handling.py b/src/dedalus_mcp/client/error_handling.py new file mode 100644 index 0000000..f28c669 --- /dev/null +++ b/src/dedalus_mcp/client/error_handling.py @@ -0,0 +1,187 @@ +# Copyright (c) 2025 Dedalus Labs, Inc. and its contributors +# SPDX-License-Identifier: MIT + +"""Error extraction and conversion utilities for MCPClient. + +This module handles the conversion of low-level transport errors (httpx, anyio) +into user-friendly MCPConnectionError subclasses with actionable messages. +""" + +from __future__ import annotations + +import httpx + +from .errors import ( + MCPConnectionError, + AuthRequiredError, + BadRequestError, + ForbiddenError, + ServerError, + SessionExpiredError, + TransportError, +) + + +def extract_http_error(exc: BaseException) -> httpx.HTTPStatusError | None: + """Extract HTTPStatusError from an exception or exception group. + + The MCP SDK transport layer raises errors wrapped in ExceptionGroup from anyio. + This helper extracts the underlying HTTPStatusError if present. + """ + if isinstance(exc, httpx.HTTPStatusError): + return exc + + if isinstance(exc, BaseExceptionGroup): + for sub_exc in exc.exceptions: + result = extract_http_error(sub_exc) + if result is not None: + return result + + return None + + +def extract_network_error(exc: BaseException) -> httpx.ConnectError | httpx.TimeoutException | None: + """Extract ConnectError or TimeoutException from an exception group.""" + if isinstance(exc, (httpx.ConnectError, httpx.TimeoutException)): + return exc + + if isinstance(exc, BaseExceptionGroup): + for sub_exc in exc.exceptions: + result = extract_network_error(sub_exc) + if result is not None: + return result + + return None + + +def http_error_to_mcp_error(error: httpx.HTTPStatusError) -> MCPConnectionError: + """Convert an HTTPStatusError to the appropriate MCPConnectionError subclass.""" + status = error.response.status_code + headers = error.response.headers + + error_msg = _extract_error_message(error) + www_auth = headers.get("WWW-Authenticate", "") + + if status == 400: + return _handle_400(error_msg) + elif status == 401: + return _handle_401(error_msg, www_auth) + elif status == 403: + return _handle_403(error_msg, www_auth) + elif status == 404: + return _handle_404(error_msg) + elif status == 405: + return _handle_405(headers) + elif status == 415: + return _handle_415(error_msg) + elif status == 422: + return _handle_422(error_msg) + elif 500 <= status < 600: + return _handle_5xx(status, error_msg, headers) + + # Fallback for other status codes + msg = f"HTTP error {status}: {error_msg}" if error_msg else f"HTTP error {status}" + return MCPConnectionError(msg, status_code=status) + + +def network_error_to_mcp_error(error: Exception) -> MCPConnectionError: + """Convert network-level errors to MCPConnectionError.""" + err_str = str(error).lower() + + if isinstance(error, httpx.TimeoutException): + return MCPConnectionError(f"Connection timed out: {error}") + + if isinstance(error, httpx.ConnectError): + if "refused" in err_str: + return MCPConnectionError(f"Connection refused - server may be down: {error}") + if "dns" in err_str or "resolve" in err_str or "name" in err_str: + return MCPConnectionError(f"DNS resolution failed - check the server URL: {error}") + return MCPConnectionError(f"Failed to connect: {error}") + + return MCPConnectionError(f"Connection error: {error}") + + +# --------------------------------------------------------------------------- +# Internal helpers for specific status codes +# --------------------------------------------------------------------------- + + +def _extract_error_message(error: httpx.HTTPStatusError) -> str: + """Extract error message from response body, handling streaming responses.""" + try: + body = error.response.json() + return body.get("error_description") or body.get("message") or body.get("error", "") + except httpx.ResponseNotRead: + return "" + except Exception: + try: + return error.response.text[:200] if error.response.text else "" + except httpx.ResponseNotRead: + return "" + + +def _handle_400(error_msg: str) -> BadRequestError: + msg = f"Bad request: {error_msg}" if error_msg else "Bad request to MCP server" + if "version" in error_msg.lower() or "protocol" in error_msg.lower(): + msg = f"Invalid protocol version: {error_msg}" + return BadRequestError(msg, status_code=400) + + +def _handle_401(error_msg: str, www_auth: str) -> AuthRequiredError: + if "invalid_token" in www_auth.lower() or "expired" in error_msg.lower(): + msg = f"Token invalid or expired: {error_msg}" if error_msg else "Token invalid or expired" + else: + msg = "Authentication required - provide valid credentials" + return AuthRequiredError(msg, status_code=401, www_authenticate=www_auth or None) + + +def _handle_403(error_msg: str, www_auth: str) -> ForbiddenError: + if "scope" in www_auth.lower() or "scope" in error_msg.lower(): + msg = f"Insufficient scope or permissions: {error_msg}" if error_msg else "Insufficient scope" + else: + msg = f"Forbidden: {error_msg}" if error_msg else "Access forbidden - insufficient permissions" + return ForbiddenError(msg, status_code=403) + + +def _handle_404(error_msg: str) -> MCPConnectionError: + if "session" in error_msg.lower(): + return SessionExpiredError(f"Session expired or terminated: {error_msg}", status_code=404) + msg = f"Endpoint not found (404): {error_msg}" if error_msg else "MCP endpoint not found" + return MCPConnectionError(msg, status_code=404) + + +def _handle_405(headers: httpx.Headers) -> TransportError: + allow = headers.get("Allow", "") + msg = f"Method not allowed (405). Server accepts: {allow}" if allow else "HTTP method not allowed (405)" + return TransportError(msg, status_code=405) + + +def _handle_415(error_msg: str) -> TransportError: + msg = f"Unsupported content type (415): {error_msg}" if error_msg else "Unsupported media type" + return TransportError(msg, status_code=415) + + +def _handle_422(error_msg: str) -> BadRequestError: + msg = f"Invalid request format (422): {error_msg}" if error_msg else "Unprocessable request" + return BadRequestError(msg, status_code=422) + + +def _handle_5xx(status: int, error_msg: str, headers: httpx.Headers) -> ServerError: + retry_after = headers.get("Retry-After") + status_messages = { + 500: "Internal server error", + 502: "Bad gateway - upstream server error", + 503: "Service unavailable", + 504: "Gateway timeout - server did not respond in time", + } + base_msg = status_messages.get(status, f"Server error ({status})") + msg = f"{base_msg}: {error_msg}" if error_msg else base_msg + return ServerError(msg, status_code=status, retry_after=retry_after) + + +__all__ = [ + "extract_http_error", + "extract_network_error", + "http_error_to_mcp_error", + "network_error_to_mcp_error", +] diff --git a/src/dedalus_mcp/client/errors.py b/src/dedalus_mcp/client/errors.py new file mode 100644 index 0000000..c6a0dce --- /dev/null +++ b/src/dedalus_mcp/client/errors.py @@ -0,0 +1,170 @@ +# Copyright (c) 2025 Dedalus Labs, Inc. and its contributors +# SPDX-License-Identifier: MIT + +"""Connection error types for MCPClient. + +These exceptions provide specific, actionable error messages for HTTP status +codes encountered during MCP connection per the spec: +- RFC 9728 (OAuth Protected Resource Metadata) +- MCP Transport Specification +- MCP Authorization Specification +""" + +from __future__ import annotations + + +class MCPConnectionError(Exception): + """Base class for MCP connection errors. + + All HTTP status code errors during connection inherit from this class. + + Attributes: + status_code: HTTP status code that triggered the error. + message: Human-readable error description. + """ + + def __init__( + self, + message: str, + *, + status_code: int | None = None, + ) -> None: + super().__init__(message) + self.status_code = status_code + self.message = message + + def __str__(self) -> str: + if self.status_code: + return f"[{self.status_code}] {self.message}" + return self.message + + +class BadRequestError(MCPConnectionError): + """400 Bad Request - Invalid input or protocol version. + + Raised when: + - Invalid MCP-Protocol-Version header + - Malformed JSON-RPC request + - Invalid request parameters (422) + """ + + def __init__( + self, + message: str = "Bad request to MCP server", + *, + status_code: int = 400, + ) -> None: + super().__init__(message, status_code=status_code) + + +class AuthRequiredError(MCPConnectionError): + """401 Unauthorized - Authentication required or token invalid. + + Raised when: + - No credentials provided to protected resource + - Access token expired or invalid + - Token signature verification failed + + Attributes: + www_authenticate: The WWW-Authenticate header value, if present. + """ + + def __init__( + self, + message: str = "Authentication required", + *, + status_code: int = 401, + www_authenticate: str | None = None, + ) -> None: + super().__init__(message, status_code=status_code) + self.www_authenticate = www_authenticate + + +class ForbiddenError(MCPConnectionError): + """403 Forbidden - Insufficient scopes or permissions. + + Raised when: + - Token lacks required scopes + - User/client lacks permission for the requested operation + """ + + def __init__( + self, + message: str = "Access forbidden - insufficient permissions", + *, + status_code: int = 403, + ) -> None: + super().__init__(message, status_code=status_code) + + +class SessionExpiredError(MCPConnectionError): + """404 Not Found - Session terminated or expired. + + Per MCP Transport spec, 404 during an active session indicates + the session has been terminated by the server. + + Raised when: + - Session ID is no longer valid + - Server terminated the session + """ + + def __init__( + self, + message: str = "Session expired or terminated", + *, + status_code: int = 404, + ) -> None: + super().__init__(message, status_code=status_code) + + +class TransportError(MCPConnectionError): + """405/415 - Transport or protocol mismatch. + + Raised when: + - HTTP method not allowed (405) + - Wrong Content-Type (415) + - Transport type incompatibility + """ + + def __init__( + self, + message: str = "Transport error - protocol mismatch", + *, + status_code: int | None = None, + ) -> None: + super().__init__(message, status_code=status_code) + + +class ServerError(MCPConnectionError): + """5xx Server Error - Server-side failure. + + Raised when: + - 500 Internal Server Error + - 502 Bad Gateway + - 503 Service Unavailable + - 504 Gateway Timeout + + Attributes: + retry_after: The Retry-After header value, if present. + """ + + def __init__( + self, + message: str = "Server error", + *, + status_code: int = 500, + retry_after: str | None = None, + ) -> None: + super().__init__(message, status_code=status_code) + self.retry_after = retry_after + + +__all__ = [ + "MCPConnectionError", + "BadRequestError", + "AuthRequiredError", + "ForbiddenError", + "SessionExpiredError", + "TransportError", + "ServerError", +] diff --git a/src/dedalus_mcp/client/transports.py b/src/dedalus_mcp/client/transports.py index 35c537d..7d91997 100644 --- a/src/dedalus_mcp/client/transports.py +++ b/src/dedalus_mcp/client/transports.py @@ -3,42 +3,35 @@ """HTTP transport helpers for :mod:`dedalus_mcp.client`. -TODO: Check this. This module provides variants of the streamable HTTP transport described in the -Model Context Protocol specification (see -the MCP specification). ``lambda_http_client`` mirrors +Model Context Protocol specification. ``lambda_http_client`` mirrors the reference SDK implementation but deliberately avoids registering a server-push GET stream so that it works with stateless environments such as AWS -Lambda. The behavior aligns with the "POST-only" pattern noted in the spec's -server guidance and our notes in ``docs/dedalus_mcp/transports.md``. +Lambda. The behavior aligns with the "POST-only" pattern noted in the spec's +server guidance. """ from __future__ import annotations -from collections.abc import AsyncGenerator, Callable +import contextlib +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager -from datetime import timedelta import anyio -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream import httpx +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from mcp.client.streamable_http import GetSessionIdCallback, StreamableHTTPTransport -from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client +from mcp.shared._httpx_utils import create_mcp_http_client from mcp.shared.message import SessionMessage -# TODO: Do we even need this anymore? @asynccontextmanager async def lambda_http_client( url: str, *, - headers: dict[str, str] | None = None, - timeout: float | timedelta = 30, - sse_read_timeout: float | timedelta = 300, + http_client: httpx.AsyncClient | None = None, terminate_on_close: bool = True, - httpx_client_factory: McpHttpClientFactory | Callable[..., httpx.AsyncClient] = create_mcp_http_client, - auth: httpx.Auth | None = None, ) -> AsyncGenerator[ tuple[ MemoryObjectReceiveStream[SessionMessage | Exception], @@ -50,30 +43,40 @@ async def lambda_http_client( """Create a streamable HTTP transport without the persistent GET stream. The Model Context Protocol allows streamable HTTP transports to keep a - server-push channel open (see https://modelcontextprotocol.io/specification/2024-11-05/basic/transports), - but serverless hosts like AWS Lambda cannot maintain such long-lived - connections. ``lambda_http_client`` mirrors the reference SDK's - ``streamablehttp_client`` implementation while replacing the - ``start_get_stream`` callback with a no-op. This keeps each JSON-RPC request - self-contained (``initialize`` -> operation -> optional ``session/close``) and - matches the stateless guidance in ``docs/dedalus_mcp/transports.md``. + server-push channel open, but serverless hosts like AWS Lambda cannot + maintain such long-lived connections. ``lambda_http_client`` mirrors the + reference SDK's ``streamable_http_client`` implementation while replacing + the ``start_get_stream`` callback with a no-op. This keeps each JSON-RPC + request self-contained. + + Args: + url: The MCP server endpoint URL. + http_client: Optional pre-configured httpx.AsyncClient. If None, a default + client with recommended MCP timeouts will be created. To configure headers, + authentication, or other HTTP settings, create an httpx.AsyncClient + and pass it here. + terminate_on_close: If True, send a DELETE request to terminate the session + when the context exits. Yields: Tuple of ``(read_stream, write_stream, get_session_id)`` compatible with :class:`mcp.client.session.ClientSession`. """ - transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout, auth) - read_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0) write_stream, write_reader = anyio.create_memory_object_stream[SessionMessage](0) + # Determine if we need to create and manage the client + client_provided = http_client is not None + client = http_client if client_provided else create_mcp_http_client() + + transport = StreamableHTTPTransport(url) + async with anyio.create_task_group() as tg: try: - async with httpx_client_factory( - headers=transport.request_headers, - timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout), - auth=transport.auth, - ) as client: + async with contextlib.AsyncExitStack() as stack: + # Only manage client lifecycle if we created it + if not client_provided: + await stack.enter_async_context(client) def _noop_start_get_stream() -> None: """Lambda-safe placeholder that intentionally avoids SSE.""" diff --git a/tests/client/__init__.py b/tests/client/__init__.py new file mode 100644 index 0000000..e87003c --- /dev/null +++ b/tests/client/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) 2025 Dedalus Labs, Inc. and its contributors +# SPDX-License-Identifier: MIT diff --git a/tests/client/auth/__init__.py b/tests/client/auth/__init__.py new file mode 100644 index 0000000..e87003c --- /dev/null +++ b/tests/client/auth/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) 2025 Dedalus Labs, Inc. and its contributors +# SPDX-License-Identifier: MIT diff --git a/tests/client/auth/test_client_credentials.py b/tests/client/auth/test_client_credentials.py new file mode 100644 index 0000000..0b109ce --- /dev/null +++ b/tests/client/auth/test_client_credentials.py @@ -0,0 +1,459 @@ +# Copyright (c) 2025 Dedalus Labs, Inc. and its contributors +# SPDX-License-Identifier: MIT + +"""Tests for OAuth 2.0 Client Credentials Auth (RFC 6749 Section 4.4). + +ClientCredentialsAuth is the primary auth mechanism for M2M (machine-to-machine) +communication, CI/CD pipelines, and backend services. +""" + +from __future__ import annotations + +import pytest +import httpx +import respx + + +# ============================================================================= +# ClientCredentialsAuth Construction Tests +# ============================================================================= + + +class TestClientCredentialsAuthConstruction: + """Tests for ClientCredentialsAuth initialization.""" + + def test_construction_with_server_metadata(self): + """ClientCredentialsAuth can be constructed with AS metadata.""" + from dedalus_mcp.client.auth.client_credentials import ClientCredentialsAuth + from dedalus_mcp.client.auth.models import AuthorizationServerMetadata + + server_metadata = AuthorizationServerMetadata( + issuer="https://as.example.com", + token_endpoint="https://as.example.com/oauth2/token", + grant_types_supported=["client_credentials"], + ) + + auth = ClientCredentialsAuth( + server_metadata=server_metadata, + client_id="m2m", + client_secret="secret123", + ) + + assert auth.client_id == "m2m" + assert auth.token_endpoint == "https://as.example.com/oauth2/token" + + def test_construction_validates_grant_type_support(self): + """ClientCredentialsAuth raises if AS doesn't support client_credentials.""" + from dedalus_mcp.client.auth.client_credentials import ( + ClientCredentialsAuth, + AuthConfigError, + ) + from dedalus_mcp.client.auth.models import AuthorizationServerMetadata + + server_metadata = AuthorizationServerMetadata( + issuer="https://as.example.com", + token_endpoint="https://as.example.com/oauth2/token", + grant_types_supported=["authorization_code"], # No client_credentials + ) + + with pytest.raises(AuthConfigError, match="client_credentials"): + ClientCredentialsAuth( + server_metadata=server_metadata, + client_id="m2m", + client_secret="secret123", + ) + + def test_construction_with_scope(self): + """ClientCredentialsAuth accepts optional scope parameter.""" + from dedalus_mcp.client.auth.client_credentials import ClientCredentialsAuth + from dedalus_mcp.client.auth.models import AuthorizationServerMetadata + + server_metadata = AuthorizationServerMetadata( + issuer="https://as.example.com", + token_endpoint="https://as.example.com/oauth2/token", + grant_types_supported=["client_credentials"], + ) + + auth = ClientCredentialsAuth( + server_metadata=server_metadata, + client_id="m2m", + client_secret="secret123", + scope="openid mcp:read", + ) + + assert auth.scope == "openid mcp:read" + + +# ============================================================================= +# Factory Method Tests +# ============================================================================= + + +class TestClientCredentialsAuthFromResource: + """Tests for ClientCredentialsAuth.from_resource factory method.""" + + @respx.mock + @pytest.mark.anyio + async def test_from_resource_full_discovery(self): + """from_resource performs full discovery and returns configured auth.""" + from dedalus_mcp.client.auth.client_credentials import ClientCredentialsAuth + + # Mock initial 401 response + respx.get("https://mcp.example.com/mcp").mock( + return_value=httpx.Response( + 401, + headers={ + "WWW-Authenticate": 'Bearer resource_metadata="/.well-known/oauth-protected-resource"' + }, + ) + ) + + # Mock PRM endpoint + respx.get("https://mcp.example.com/.well-known/oauth-protected-resource").mock( + return_value=httpx.Response( + 200, + json={ + "resource": "https://mcp.example.com", + "authorization_servers": ["https://as.example.com"], + }, + ) + ) + + # Mock AS metadata endpoint + respx.get("https://as.example.com/.well-known/oauth-authorization-server").mock( + return_value=httpx.Response( + 200, + json={ + "issuer": "https://as.example.com", + "token_endpoint": "https://as.example.com/oauth2/token", + "grant_types_supported": ["client_credentials"], + }, + ) + ) + + auth = await ClientCredentialsAuth.from_resource( + resource_url="https://mcp.example.com/mcp", + client_id="m2m", + client_secret="secret123", + ) + + assert auth.client_id == "m2m" + assert auth.token_endpoint == "https://as.example.com/oauth2/token" + + @respx.mock + @pytest.mark.anyio + async def test_from_resource_unprotected_raises(self): + """from_resource raises if resource is not protected (no 401).""" + from dedalus_mcp.client.auth.client_credentials import ClientCredentialsAuth + from dedalus_mcp.client.auth.discovery import DiscoveryError + + respx.get("https://mcp.example.com/mcp").mock(return_value=httpx.Response(200)) + + with pytest.raises(DiscoveryError, match="not protected"): + await ClientCredentialsAuth.from_resource( + resource_url="https://mcp.example.com/mcp", + client_id="m2m", + client_secret="secret123", + ) + + +# ============================================================================= +# Token Acquisition Tests +# ============================================================================= + + +class TestClientCredentialsAuthTokenAcquisition: + """Tests for token acquisition via client credentials grant.""" + + @respx.mock + @pytest.mark.anyio + async def test_get_token_success(self): + """get_token acquires token from token endpoint.""" + from dedalus_mcp.client.auth.client_credentials import ClientCredentialsAuth + from dedalus_mcp.client.auth.models import AuthorizationServerMetadata + + server_metadata = AuthorizationServerMetadata( + issuer="https://as.example.com", + token_endpoint="https://as.example.com/oauth2/token", + grant_types_supported=["client_credentials"], + ) + + respx.post("https://as.example.com/oauth2/token").mock( + return_value=httpx.Response( + 200, + json={ + "access_token": "eyJhbGciOiJFUzI1NiIs...", + "token_type": "Bearer", + "expires_in": 3600, + }, + ) + ) + + auth = ClientCredentialsAuth( + server_metadata=server_metadata, + client_id="m2m", + client_secret="secret123", + ) + + token = await auth.get_token() + + assert token.access_token == "eyJhbGciOiJFUzI1NiIs..." + assert token.token_type == "Bearer" + assert token.expires_in == 3600 + + @respx.mock + @pytest.mark.anyio + async def test_get_token_with_scope(self): + """get_token sends scope in token request.""" + from dedalus_mcp.client.auth.client_credentials import ClientCredentialsAuth + from dedalus_mcp.client.auth.models import AuthorizationServerMetadata + + server_metadata = AuthorizationServerMetadata( + issuer="https://as.example.com", + token_endpoint="https://as.example.com/oauth2/token", + grant_types_supported=["client_credentials"], + ) + + route = respx.post("https://as.example.com/oauth2/token").mock( + return_value=httpx.Response( + 200, + json={ + "access_token": "token", + "token_type": "Bearer", + }, + ) + ) + + auth = ClientCredentialsAuth( + server_metadata=server_metadata, + client_id="m2m", + client_secret="secret123", + scope="openid mcp:read", + ) + + await auth.get_token() + + # Verify scope was sent in request + request = route.calls.last.request + body = request.content.decode() + assert "scope=openid" in body or "scope=openid+mcp%3Aread" in body or "openid" in body + + @respx.mock + @pytest.mark.anyio + async def test_get_token_uses_basic_auth(self): + """get_token uses HTTP Basic Auth for client authentication.""" + from dedalus_mcp.client.auth.client_credentials import ClientCredentialsAuth + from dedalus_mcp.client.auth.models import AuthorizationServerMetadata + import base64 + + server_metadata = AuthorizationServerMetadata( + issuer="https://as.example.com", + token_endpoint="https://as.example.com/oauth2/token", + grant_types_supported=["client_credentials"], + ) + + route = respx.post("https://as.example.com/oauth2/token").mock( + return_value=httpx.Response( + 200, + json={ + "access_token": "token", + "token_type": "Bearer", + }, + ) + ) + + auth = ClientCredentialsAuth( + server_metadata=server_metadata, + client_id="m2m", + client_secret="secret123", + ) + + await auth.get_token() + + # Verify Basic Auth header + request = route.calls.last.request + auth_header = request.headers.get("Authorization") + assert auth_header is not None + assert auth_header.startswith("Basic ") + + # Decode and verify credentials + encoded = auth_header.split(" ")[1] + decoded = base64.b64decode(encoded).decode() + assert decoded == "m2m:secret123" + + @respx.mock + @pytest.mark.anyio + async def test_get_token_error_response(self): + """get_token raises on error response from token endpoint.""" + from dedalus_mcp.client.auth.client_credentials import ( + ClientCredentialsAuth, + TokenError, + ) + from dedalus_mcp.client.auth.models import AuthorizationServerMetadata + + server_metadata = AuthorizationServerMetadata( + issuer="https://as.example.com", + token_endpoint="https://as.example.com/oauth2/token", + grant_types_supported=["client_credentials"], + ) + + respx.post("https://as.example.com/oauth2/token").mock( + return_value=httpx.Response( + 400, + json={ + "error": "invalid_client", + "error_description": "Client authentication failed", + }, + ) + ) + + auth = ClientCredentialsAuth( + server_metadata=server_metadata, + client_id="m2m", + client_secret="wrong_secret", + ) + + with pytest.raises(TokenError, match="invalid_client"): + await auth.get_token() + + +# ============================================================================= +# httpx.Auth Interface Tests +# ============================================================================= + + +class TestClientCredentialsAuthHttpxInterface: + """Tests for ClientCredentialsAuth as httpx.Auth implementation.""" + + @respx.mock + @pytest.mark.anyio + async def test_auth_flow_injects_bearer_token(self): + """ClientCredentialsAuth injects Bearer token into requests.""" + from dedalus_mcp.client.auth.client_credentials import ClientCredentialsAuth + from dedalus_mcp.client.auth.models import AuthorizationServerMetadata + + server_metadata = AuthorizationServerMetadata( + issuer="https://as.example.com", + token_endpoint="https://as.example.com/oauth2/token", + grant_types_supported=["client_credentials"], + ) + + # Mock token endpoint + respx.post("https://as.example.com/oauth2/token").mock( + return_value=httpx.Response( + 200, + json={ + "access_token": "the_access_token", + "token_type": "Bearer", + "expires_in": 3600, + }, + ) + ) + + # Mock protected resource + protected_route = respx.get("https://mcp.example.com/api").mock( + return_value=httpx.Response(200, json={"result": "success"}) + ) + + auth = ClientCredentialsAuth( + server_metadata=server_metadata, + client_id="m2m", + client_secret="secret123", + ) + + # Pre-fetch token + await auth.get_token() + + # Make request with auth + async with httpx.AsyncClient() as client: + response = await client.get("https://mcp.example.com/api", auth=auth) + + assert response.status_code == 200 + + # Verify Bearer token was injected + request = protected_route.calls.last.request + assert request.headers.get("Authorization") == "Bearer the_access_token" + + @respx.mock + @pytest.mark.anyio + async def test_token_caching(self): + """ClientCredentialsAuth caches token and reuses it.""" + from dedalus_mcp.client.auth.client_credentials import ClientCredentialsAuth + from dedalus_mcp.client.auth.models import AuthorizationServerMetadata + + server_metadata = AuthorizationServerMetadata( + issuer="https://as.example.com", + token_endpoint="https://as.example.com/oauth2/token", + grant_types_supported=["client_credentials"], + ) + + token_route = respx.post("https://as.example.com/oauth2/token").mock( + return_value=httpx.Response( + 200, + json={ + "access_token": "cached_token", + "token_type": "Bearer", + "expires_in": 3600, + }, + ) + ) + + auth = ClientCredentialsAuth( + server_metadata=server_metadata, + client_id="m2m", + client_secret="secret123", + ) + + # Get token twice + token1 = await auth.get_token() + token2 = await auth.get_token() + + # Should only hit token endpoint once + assert len(token_route.calls) == 1 + assert token1.access_token == token2.access_token + + +# ============================================================================= +# Resource Indicator Tests (RFC 8707) +# ============================================================================= + + +class TestClientCredentialsAuthResourceIndicator: + """Tests for resource indicator support (RFC 8707).""" + + @respx.mock + @pytest.mark.anyio + async def test_get_token_with_resource_indicator(self): + """get_token can include resource indicator in token request.""" + from dedalus_mcp.client.auth.client_credentials import ClientCredentialsAuth + from dedalus_mcp.client.auth.models import AuthorizationServerMetadata + + server_metadata = AuthorizationServerMetadata( + issuer="https://as.example.com", + token_endpoint="https://as.example.com/oauth2/token", + grant_types_supported=["client_credentials"], + ) + + route = respx.post("https://as.example.com/oauth2/token").mock( + return_value=httpx.Response( + 200, + json={ + "access_token": "token", + "token_type": "Bearer", + }, + ) + ) + + auth = ClientCredentialsAuth( + server_metadata=server_metadata, + client_id="m2m", + client_secret="secret123", + resource="https://mcp.example.com", + ) + + await auth.get_token() + + # Verify resource was sent in request + request = route.calls.last.request + body = request.content.decode() + assert "resource=" in body diff --git a/tests/client/auth/test_discovery.py b/tests/client/auth/test_discovery.py new file mode 100644 index 0000000..2081db6 --- /dev/null +++ b/tests/client/auth/test_discovery.py @@ -0,0 +1,308 @@ +# Copyright (c) 2025 Dedalus Labs, Inc. and its contributors +# SPDX-License-Identifier: MIT + +"""Tests for OAuth discovery (RFC 9728, RFC 8414). + +MCP clients MUST: +- Parse WWW-Authenticate headers and respond to 401 responses +- Use OAuth 2.0 Protected Resource Metadata for AS discovery +- Use OAuth 2.0 Authorization Server Metadata +""" + +from __future__ import annotations + +import pytest +import httpx +import respx + + +# ============================================================================= +# Resource Metadata Discovery Tests (RFC 9728) +# ============================================================================= + + +class TestFetchResourceMetadata: + """Tests for fetching Protected Resource Metadata.""" + + @respx.mock + @pytest.mark.anyio + async def test_fetch_resource_metadata_success(self): + """fetch_resource_metadata fetches and parses PRM.""" + from dedalus_mcp.client.auth.discovery import fetch_resource_metadata + + respx.get("https://mcp.example.com/.well-known/oauth-protected-resource").mock( + return_value=httpx.Response( + 200, + json={ + "resource": "https://mcp.example.com", + "authorization_servers": ["https://as.example.com"], + "scopes_supported": ["openid"], + }, + ) + ) + + async with httpx.AsyncClient() as client: + meta = await fetch_resource_metadata( + client, "https://mcp.example.com/.well-known/oauth-protected-resource" + ) + + assert meta.resource == "https://mcp.example.com" + assert meta.authorization_servers == ["https://as.example.com"] + + @respx.mock + @pytest.mark.anyio + async def test_fetch_resource_metadata_not_found(self): + """fetch_resource_metadata raises on 404.""" + from dedalus_mcp.client.auth.discovery import fetch_resource_metadata, DiscoveryError + + respx.get("https://mcp.example.com/.well-known/oauth-protected-resource").mock( + return_value=httpx.Response(404) + ) + + async with httpx.AsyncClient() as client: + with pytest.raises(DiscoveryError, match="404"): + await fetch_resource_metadata( + client, "https://mcp.example.com/.well-known/oauth-protected-resource" + ) + + @respx.mock + @pytest.mark.anyio + async def test_fetch_resource_metadata_invalid_json(self): + """fetch_resource_metadata raises on invalid JSON.""" + from dedalus_mcp.client.auth.discovery import fetch_resource_metadata, DiscoveryError + + respx.get("https://mcp.example.com/.well-known/oauth-protected-resource").mock( + return_value=httpx.Response(200, content=b"not json") + ) + + async with httpx.AsyncClient() as client: + with pytest.raises(DiscoveryError, match="JSON"): + await fetch_resource_metadata( + client, "https://mcp.example.com/.well-known/oauth-protected-resource" + ) + + +# ============================================================================= +# Authorization Server Metadata Discovery Tests (RFC 8414) +# ============================================================================= + + +class TestFetchASMetadata: + """Tests for fetching Authorization Server Metadata.""" + + @respx.mock + @pytest.mark.anyio + async def test_fetch_authorization_server_metadata_success(self): + """fetch_authorization_server_metadata fetches and parses AS metadata.""" + from dedalus_mcp.client.auth.discovery import fetch_authorization_server_metadata + + respx.get("https://as.example.com/.well-known/oauth-authorization-server").mock( + return_value=httpx.Response( + 200, + json={ + "issuer": "https://as.example.com", + "token_endpoint": "https://as.example.com/oauth2/token", + "grant_types_supported": ["client_credentials", "authorization_code"], + }, + ) + ) + + async with httpx.AsyncClient() as client: + meta = await fetch_authorization_server_metadata(client, "https://as.example.com") + + assert meta.issuer == "https://as.example.com" + assert meta.token_endpoint == "https://as.example.com/oauth2/token" + assert "client_credentials" in meta.grant_types_supported + + @respx.mock + @pytest.mark.anyio + async def test_fetch_authorization_server_metadata_constructs_url(self): + """fetch_authorization_server_metadata constructs the well-known URL correctly.""" + from dedalus_mcp.client.auth.discovery import fetch_authorization_server_metadata + + # AS URL with trailing slash + route = respx.get("https://as.example.com/.well-known/oauth-authorization-server").mock( + return_value=httpx.Response( + 200, + json={ + "issuer": "https://as.example.com", + "token_endpoint": "https://as.example.com/oauth2/token", + }, + ) + ) + + async with httpx.AsyncClient() as client: + await fetch_authorization_server_metadata(client, "https://as.example.com/") + + assert route.called + + @respx.mock + @pytest.mark.anyio + async def test_fetch_authorization_server_metadata_not_found(self): + """fetch_authorization_server_metadata raises on 404.""" + from dedalus_mcp.client.auth.discovery import fetch_authorization_server_metadata, DiscoveryError + + respx.get("https://as.example.com/.well-known/oauth-authorization-server").mock( + return_value=httpx.Response(404) + ) + + async with httpx.AsyncClient() as client: + with pytest.raises(DiscoveryError, match="404"): + await fetch_authorization_server_metadata(client, "https://as.example.com") + + +# ============================================================================= +# Full Discovery Flow Tests +# ============================================================================= + + +class TestDiscoverAuthorizationServer: + """Tests for the complete discovery flow.""" + + @respx.mock + @pytest.mark.anyio + async def test_discover_from_401_response(self): + """discover_authorization_server handles 401 → PRM → AS metadata flow.""" + from dedalus_mcp.client.auth.discovery import discover_authorization_server + + # Mock 401 response with WWW-Authenticate header + respx.get("https://mcp.example.com/mcp").mock( + return_value=httpx.Response( + 401, + headers={ + "WWW-Authenticate": 'Bearer error="invalid_token", resource_metadata="/.well-known/oauth-protected-resource"' + }, + ) + ) + + # Mock PRM endpoint + respx.get("https://mcp.example.com/.well-known/oauth-protected-resource").mock( + return_value=httpx.Response( + 200, + json={ + "resource": "https://mcp.example.com", + "authorization_servers": ["https://as.example.com"], + }, + ) + ) + + # Mock AS metadata endpoint + respx.get("https://as.example.com/.well-known/oauth-authorization-server").mock( + return_value=httpx.Response( + 200, + json={ + "issuer": "https://as.example.com", + "token_endpoint": "https://as.example.com/oauth2/token", + "grant_types_supported": ["client_credentials"], + }, + ) + ) + + async with httpx.AsyncClient() as client: + result = await discover_authorization_server(client, "https://mcp.example.com/mcp") + + assert result.resource_metadata.resource == "https://mcp.example.com" + assert result.authorization_server_metadata.issuer == "https://as.example.com" + assert result.authorization_server_metadata.token_endpoint == "https://as.example.com/oauth2/token" + + @respx.mock + @pytest.mark.anyio + async def test_discover_no_401_raises(self): + """discover_authorization_server raises if no 401 received.""" + from dedalus_mcp.client.auth.discovery import discover_authorization_server, DiscoveryError + + # Server returns 200 (not protected) + respx.get("https://mcp.example.com/mcp").mock(return_value=httpx.Response(200)) + + async with httpx.AsyncClient() as client: + with pytest.raises(DiscoveryError, match="not protected"): + await discover_authorization_server(client, "https://mcp.example.com/mcp") + + @respx.mock + @pytest.mark.anyio + async def test_discover_missing_www_authenticate(self): + """discover_authorization_server raises if 401 lacks WWW-Authenticate.""" + from dedalus_mcp.client.auth.discovery import discover_authorization_server, DiscoveryError + + respx.get("https://mcp.example.com/mcp").mock(return_value=httpx.Response(401)) + + async with httpx.AsyncClient() as client: + with pytest.raises(DiscoveryError, match="WWW-Authenticate"): + await discover_authorization_server(client, "https://mcp.example.com/mcp") + + @respx.mock + @pytest.mark.anyio + async def test_discover_missing_resource_metadata_param(self): + """discover_authorization_server raises if WWW-Authenticate lacks resource_metadata.""" + from dedalus_mcp.client.auth.discovery import discover_authorization_server, DiscoveryError + + respx.get("https://mcp.example.com/mcp").mock( + return_value=httpx.Response( + 401, + headers={"WWW-Authenticate": 'Bearer error="invalid_token"'}, + ) + ) + + async with httpx.AsyncClient() as client: + with pytest.raises(DiscoveryError, match="resource_metadata"): + await discover_authorization_server(client, "https://mcp.example.com/mcp") + + +# ============================================================================= +# URL Construction Tests +# ============================================================================= + + +class TestBuildMetadataUrl: + """Tests for metadata URL construction helpers.""" + + def test_build_resource_metadata_url_absolute(self): + """build_resource_metadata_url handles absolute paths.""" + from dedalus_mcp.client.auth.discovery import build_resource_metadata_url + + url = build_resource_metadata_url( + "https://mcp.example.com/mcp", "/.well-known/oauth-protected-resource" + ) + assert url == "https://mcp.example.com/.well-known/oauth-protected-resource" + + def test_build_resource_metadata_url_relative(self): + """build_resource_metadata_url handles relative paths.""" + from dedalus_mcp.client.auth.discovery import build_resource_metadata_url + + url = build_resource_metadata_url( + "https://mcp.example.com/api/mcp", ".well-known/oauth-protected-resource" + ) + # Relative to the path + assert "mcp.example.com" in url + assert "oauth-protected-resource" in url + + def test_build_resource_metadata_url_full_url(self): + """build_resource_metadata_url handles full URLs.""" + from dedalus_mcp.client.auth.discovery import build_resource_metadata_url + + url = build_resource_metadata_url( + "https://mcp.example.com/mcp", "https://other.example.com/.well-known/prm" + ) + assert url == "https://other.example.com/.well-known/prm" + + def test_build_authorization_server_metadata_url(self): + """build_authorization_server_metadata_url constructs well-known URL.""" + from dedalus_mcp.client.auth.discovery import build_authorization_server_metadata_url + + url = build_authorization_server_metadata_url("https://as.example.com") + assert url == "https://as.example.com/.well-known/oauth-authorization-server" + + def test_build_authorization_server_metadata_url_strips_trailing_slash(self): + """build_authorization_server_metadata_url strips trailing slash.""" + from dedalus_mcp.client.auth.discovery import build_authorization_server_metadata_url + + url = build_authorization_server_metadata_url("https://as.example.com/") + assert url == "https://as.example.com/.well-known/oauth-authorization-server" + + def test_build_authorization_server_metadata_url_with_path(self): + """build_authorization_server_metadata_url handles AS URL with path.""" + from dedalus_mcp.client.auth.discovery import build_authorization_server_metadata_url + + # Per RFC 8414, if issuer has path, insert .well-known between origin and path + url = build_authorization_server_metadata_url("https://as.example.com/tenant1") + assert url == "https://as.example.com/.well-known/oauth-authorization-server/tenant1" diff --git a/tests/client/auth/test_models.py b/tests/client/auth/test_models.py new file mode 100644 index 0000000..b1a091d --- /dev/null +++ b/tests/client/auth/test_models.py @@ -0,0 +1,326 @@ +# Copyright (c) 2025 Dedalus Labs, Inc. and its contributors +# SPDX-License-Identifier: MIT + +"""Tests for OAuth metadata models (RFC 9728, RFC 8414).""" + +from __future__ import annotations + +import pytest + + +# ============================================================================= +# ResourceMetadata Tests (RFC 9728) +# ============================================================================= + + +class TestResourceMetadata: + """Tests for OAuth 2.0 Protected Resource Metadata (RFC 9728).""" + + def test_construction_minimal(self): + """ResourceMetadata can be constructed with minimal required fields.""" + from dedalus_mcp.client.auth.models import ResourceMetadata + + meta = ResourceMetadata( + resource="https://mcp.example.com", + authorization_servers=["https://as.example.com"], + ) + assert meta.resource == "https://mcp.example.com" + assert meta.authorization_servers == ["https://as.example.com"] + + def test_construction_full(self): + """ResourceMetadata can be constructed with all optional fields.""" + from dedalus_mcp.client.auth.models import ResourceMetadata + + meta = ResourceMetadata( + resource="https://mcp.example.com", + authorization_servers=["https://as.example.com", "https://as2.example.com"], + scopes_supported=["openid", "mcp:read", "mcp:write"], + bearer_methods_supported=["header"], + resource_signing_alg_values_supported=["RS256", "ES256"], + ) + assert meta.resource == "https://mcp.example.com" + assert len(meta.authorization_servers) == 2 + assert meta.scopes_supported == ["openid", "mcp:read", "mcp:write"] + assert meta.bearer_methods_supported == ["header"] + assert meta.resource_signing_alg_values_supported == ["RS256", "ES256"] + + def test_from_dict(self): + """ResourceMetadata can be created from a dictionary.""" + from dedalus_mcp.client.auth.models import ResourceMetadata + + data = { + "resource": "https://mcp.example.com", + "authorization_servers": ["https://as.example.com"], + "scopes_supported": ["openid"], + } + meta = ResourceMetadata.from_dict(data) + assert meta.resource == "https://mcp.example.com" + assert meta.authorization_servers == ["https://as.example.com"] + assert meta.scopes_supported == ["openid"] + + def test_from_dict_ignores_unknown_fields(self): + """ResourceMetadata.from_dict ignores unknown fields.""" + from dedalus_mcp.client.auth.models import ResourceMetadata + + data = { + "resource": "https://mcp.example.com", + "authorization_servers": ["https://as.example.com"], + "unknown_field": "should be ignored", + } + meta = ResourceMetadata.from_dict(data) + assert meta.resource == "https://mcp.example.com" + assert not hasattr(meta, "unknown_field") + + def test_from_dict_missing_required_field_raises(self): + """ResourceMetadata.from_dict raises on missing required fields.""" + from dedalus_mcp.client.auth.models import ResourceMetadata + + with pytest.raises(ValueError, match="resource"): + ResourceMetadata.from_dict({"authorization_servers": ["https://as.example.com"]}) + + with pytest.raises(ValueError, match="authorization_servers"): + ResourceMetadata.from_dict({"resource": "https://mcp.example.com"}) + + def test_primary_authorization_server(self): + """primary_authorization_server returns first AS.""" + from dedalus_mcp.client.auth.models import ResourceMetadata + + meta = ResourceMetadata( + resource="https://mcp.example.com", + authorization_servers=["https://as1.example.com", "https://as2.example.com"], + ) + assert meta.primary_authorization_server == "https://as1.example.com" + + +# ============================================================================= +# AuthorizationServerMetadata Tests (RFC 8414) +# ============================================================================= + + +class TestAuthorizationServerMetadata: + """Tests for OAuth 2.0 Authorization Server Metadata (RFC 8414).""" + + def test_construction_minimal(self): + """ASMetadata can be constructed with minimal required fields.""" + from dedalus_mcp.client.auth.models import AuthorizationServerMetadata + + meta = AuthorizationServerMetadata( + issuer="https://as.example.com", + token_endpoint="https://as.example.com/oauth2/token", + ) + assert meta.issuer == "https://as.example.com" + assert meta.token_endpoint == "https://as.example.com/oauth2/token" + + def test_construction_full(self): + """ASMetadata can be constructed with all common fields.""" + from dedalus_mcp.client.auth.models import AuthorizationServerMetadata + + meta = AuthorizationServerMetadata( + issuer="https://as.example.com", + authorization_endpoint="https://as.example.com/oauth2/auth", + token_endpoint="https://as.example.com/oauth2/token", + registration_endpoint="https://as.example.com/register", + jwks_uri="https://as.example.com/.well-known/jwks.json", + scopes_supported=["openid", "offline_access"], + response_types_supported=["code"], + grant_types_supported=["authorization_code", "client_credentials", "refresh_token"], + token_endpoint_auth_methods_supported=["client_secret_basic", "client_secret_post"], + code_challenge_methods_supported=["S256"], + ) + assert meta.issuer == "https://as.example.com" + assert meta.authorization_endpoint == "https://as.example.com/oauth2/auth" + assert "client_credentials" in meta.grant_types_supported + assert "S256" in meta.code_challenge_methods_supported + + def test_from_dict(self): + """ASMetadata can be created from a dictionary.""" + from dedalus_mcp.client.auth.models import AuthorizationServerMetadata + + data = { + "issuer": "https://as.example.com", + "token_endpoint": "https://as.example.com/oauth2/token", + "grant_types_supported": ["client_credentials"], + } + meta = AuthorizationServerMetadata.from_dict(data) + assert meta.issuer == "https://as.example.com" + assert meta.token_endpoint == "https://as.example.com/oauth2/token" + assert meta.grant_types_supported == ["client_credentials"] + + def test_from_dict_ignores_unknown_fields(self): + """ASMetadata.from_dict ignores unknown fields.""" + from dedalus_mcp.client.auth.models import AuthorizationServerMetadata + + data = { + "issuer": "https://as.example.com", + "token_endpoint": "https://as.example.com/oauth2/token", + "custom_extension": "value", + } + meta = AuthorizationServerMetadata.from_dict(data) + assert meta.issuer == "https://as.example.com" + assert not hasattr(meta, "custom_extension") + + def test_from_dict_missing_required_field_raises(self): + """ASMetadata.from_dict raises on missing required fields.""" + from dedalus_mcp.client.auth.models import AuthorizationServerMetadata + + with pytest.raises(ValueError, match="issuer"): + AuthorizationServerMetadata.from_dict({"token_endpoint": "https://as.example.com/token"}) + + with pytest.raises(ValueError, match="token_endpoint"): + AuthorizationServerMetadata.from_dict({"issuer": "https://as.example.com"}) + + def test_supports_grant_type(self): + """supports_grant_type checks grant_types_supported list.""" + from dedalus_mcp.client.auth.models import AuthorizationServerMetadata + + meta = AuthorizationServerMetadata( + issuer="https://as.example.com", + token_endpoint="https://as.example.com/oauth2/token", + grant_types_supported=["authorization_code", "client_credentials"], + ) + assert meta.supports_grant_type("client_credentials") is True + assert meta.supports_grant_type("authorization_code") is True + assert meta.supports_grant_type("refresh_token") is False + + def test_supports_grant_type_default_none(self): + """supports_grant_type returns False when grant_types_supported is None.""" + from dedalus_mcp.client.auth.models import AuthorizationServerMetadata + + meta = AuthorizationServerMetadata( + issuer="https://as.example.com", + token_endpoint="https://as.example.com/oauth2/token", + ) + # Per RFC 8414, if not present, default is ["authorization_code", "implicit"] + # but we don't assume - just return False for safety + assert meta.supports_grant_type("client_credentials") is False + + +# ============================================================================= +# TokenResponse Tests +# ============================================================================= + + +class TestTokenResponse: + """Tests for OAuth token response model.""" + + def test_construction(self): + """TokenResponse can be constructed with all fields.""" + from dedalus_mcp.client.auth.models import TokenResponse + + token = TokenResponse( + access_token="eyJhbGciOiJFUzI1NiIs...", + token_type="Bearer", + expires_in=3600, + refresh_token="refresh_token_value", + scope="openid mcp:read", + ) + assert token.access_token == "eyJhbGciOiJFUzI1NiIs..." + assert token.token_type == "Bearer" + assert token.expires_in == 3600 + assert token.refresh_token == "refresh_token_value" + assert token.scope == "openid mcp:read" + + def test_construction_minimal(self): + """TokenResponse can be constructed with minimal fields.""" + from dedalus_mcp.client.auth.models import TokenResponse + + token = TokenResponse( + access_token="eyJhbGciOiJFUzI1NiIs...", + token_type="Bearer", + ) + assert token.access_token == "eyJhbGciOiJFUzI1NiIs..." + assert token.token_type == "Bearer" + assert token.expires_in is None + assert token.refresh_token is None + + def test_from_dict(self): + """TokenResponse can be created from a dictionary.""" + from dedalus_mcp.client.auth.models import TokenResponse + + data = { + "access_token": "token123", + "token_type": "Bearer", + "expires_in": 7200, + } + token = TokenResponse.from_dict(data) + assert token.access_token == "token123" + assert token.token_type == "Bearer" + assert token.expires_in == 7200 + + def test_from_dict_missing_required_raises(self): + """TokenResponse.from_dict raises on missing required fields.""" + from dedalus_mcp.client.auth.models import TokenResponse + + with pytest.raises(ValueError, match="access_token"): + TokenResponse.from_dict({"token_type": "Bearer"}) + + with pytest.raises(ValueError, match="token_type"): + TokenResponse.from_dict({"access_token": "token"}) + + +# ============================================================================= +# WWWAuthenticate Parsing Tests +# ============================================================================= + + +class TestWWWAuthenticateParsing: + """Tests for WWW-Authenticate header parsing.""" + + def test_parse_bearer_with_resource_metadata(self): + """Parse WWW-Authenticate header with resource_metadata parameter.""" + from dedalus_mcp.client.auth.models import parse_www_authenticate + + header = 'Bearer error="invalid_token", resource_metadata="/.well-known/oauth-protected-resource"' + result = parse_www_authenticate(header) + assert result.scheme == "Bearer" + assert result.resource_metadata == "/.well-known/oauth-protected-resource" + + def test_parse_dpop_scheme(self): + """Parse WWW-Authenticate header with DPoP scheme.""" + from dedalus_mcp.client.auth.models import parse_www_authenticate + + header = 'DPoP error="invalid_token", resource_metadata="/prm"' + result = parse_www_authenticate(header) + assert result.scheme == "DPoP" + assert result.resource_metadata == "/prm" + + def test_parse_with_error_description(self): + """Parse WWW-Authenticate header with error_description.""" + from dedalus_mcp.client.auth.models import parse_www_authenticate + + header = 'Bearer error="invalid_token", error_description="Token expired", resource_metadata="/prm"' + result = parse_www_authenticate(header) + assert result.error == "invalid_token" + assert result.error_description == "Token expired" + assert result.resource_metadata == "/prm" + + def test_parse_missing_resource_metadata(self): + """parse_www_authenticate returns None for resource_metadata if not present.""" + from dedalus_mcp.client.auth.models import parse_www_authenticate + + header = 'Bearer error="invalid_token"' + result = parse_www_authenticate(header) + assert result.scheme == "Bearer" + assert result.resource_metadata is None + + def test_parse_case_insensitive_scheme(self): + """parse_www_authenticate handles case-insensitive scheme.""" + from dedalus_mcp.client.auth.models import parse_www_authenticate + + header = 'BEARER error="invalid_token", resource_metadata="/prm"' + result = parse_www_authenticate(header) + assert result.scheme.upper() == "BEARER" + + def test_parse_empty_raises(self): + """parse_www_authenticate raises on empty header.""" + from dedalus_mcp.client.auth.models import parse_www_authenticate + + with pytest.raises(ValueError): + parse_www_authenticate("") + + def test_parse_malformed_raises(self): + """parse_www_authenticate raises on malformed header.""" + from dedalus_mcp.client.auth.models import parse_www_authenticate + + with pytest.raises(ValueError): + parse_www_authenticate("not-a-valid-header") diff --git a/tests/client/auth/test_token_exchange.py b/tests/client/auth/test_token_exchange.py new file mode 100644 index 0000000..111a0c0 --- /dev/null +++ b/tests/client/auth/test_token_exchange.py @@ -0,0 +1,472 @@ +# Copyright (c) 2025 Dedalus Labs, Inc. and its contributors +# SPDX-License-Identifier: MIT + +"""Tests for OAuth 2.0 Token Exchange Auth (RFC 8693). + +TokenExchangeAuth exchanges an existing token (e.g., from Clerk, Auth0) +for an MCP-scoped access token. Used for user delegation flows. +""" + +from __future__ import annotations + +import pytest +import httpx +import respx + + +# ============================================================================= +# TokenExchangeAuth Construction Tests +# ============================================================================= + + +class TestTokenExchangeAuthConstruction: + """Tests for TokenExchangeAuth initialization.""" + + def test_construction_with_server_metadata(self): + """TokenExchangeAuth can be constructed with AS metadata.""" + from dedalus_mcp.client.auth.token_exchange import TokenExchangeAuth + from dedalus_mcp.client.auth.models import AuthorizationServerMetadata + + server_metadata = AuthorizationServerMetadata( + issuer="https://as.example.com", + token_endpoint="https://as.example.com/oauth2/token", + grant_types_supported=["urn:ietf:params:oauth:grant-type:token-exchange"], + ) + + auth = TokenExchangeAuth( + server_metadata=server_metadata, + client_id="dedalus-sdk", + subject_token="eyJhbGciOiJSUzI1NiIs...", + ) + + assert auth.client_id == "dedalus-sdk" + assert auth.token_endpoint == "https://as.example.com/oauth2/token" + + def test_construction_validates_grant_type_support(self): + """TokenExchangeAuth raises if AS doesn't support token-exchange.""" + from dedalus_mcp.client.auth.token_exchange import ( + TokenExchangeAuth, + AuthConfigError, + ) + from dedalus_mcp.client.auth.models import AuthorizationServerMetadata + + server_metadata = AuthorizationServerMetadata( + issuer="https://as.example.com", + token_endpoint="https://as.example.com/oauth2/token", + grant_types_supported=["client_credentials"], # No token-exchange + ) + + with pytest.raises(AuthConfigError, match="token-exchange"): + TokenExchangeAuth( + server_metadata=server_metadata, + client_id="dedalus-sdk", + subject_token="token", + ) + + def test_construction_with_subject_token_type(self): + """TokenExchangeAuth accepts subject_token_type parameter.""" + from dedalus_mcp.client.auth.token_exchange import TokenExchangeAuth + from dedalus_mcp.client.auth.models import AuthorizationServerMetadata + + server_metadata = AuthorizationServerMetadata( + issuer="https://as.example.com", + token_endpoint="https://as.example.com/oauth2/token", + grant_types_supported=["urn:ietf:params:oauth:grant-type:token-exchange"], + ) + + auth = TokenExchangeAuth( + server_metadata=server_metadata, + client_id="dedalus-sdk", + subject_token="token", + subject_token_type="urn:ietf:params:oauth:token-type:id_token", + ) + + assert auth.subject_token_type == "urn:ietf:params:oauth:token-type:id_token" + + def test_construction_default_subject_token_type(self): + """TokenExchangeAuth defaults to access_token type.""" + from dedalus_mcp.client.auth.token_exchange import TokenExchangeAuth + from dedalus_mcp.client.auth.models import AuthorizationServerMetadata + + server_metadata = AuthorizationServerMetadata( + issuer="https://as.example.com", + token_endpoint="https://as.example.com/oauth2/token", + grant_types_supported=["urn:ietf:params:oauth:grant-type:token-exchange"], + ) + + auth = TokenExchangeAuth( + server_metadata=server_metadata, + client_id="dedalus-sdk", + subject_token="token", + ) + + assert auth.subject_token_type == "urn:ietf:params:oauth:token-type:access_token" + + +# ============================================================================= +# Factory Method Tests +# ============================================================================= + + +class TestTokenExchangeAuthFromResource: + """Tests for TokenExchangeAuth.from_resource factory method.""" + + @respx.mock + @pytest.mark.anyio + async def test_from_resource_full_discovery(self): + """from_resource performs full discovery and returns configured auth.""" + from dedalus_mcp.client.auth.token_exchange import TokenExchangeAuth + + # Mock initial 401 response + respx.get("https://mcp.example.com/mcp").mock( + return_value=httpx.Response( + 401, + headers={ + "WWW-Authenticate": 'Bearer resource_metadata="/.well-known/oauth-protected-resource"' + }, + ) + ) + + # Mock PRM endpoint + respx.get("https://mcp.example.com/.well-known/oauth-protected-resource").mock( + return_value=httpx.Response( + 200, + json={ + "resource": "https://mcp.example.com", + "authorization_servers": ["https://as.example.com"], + }, + ) + ) + + # Mock AS metadata endpoint + respx.get("https://as.example.com/.well-known/oauth-authorization-server").mock( + return_value=httpx.Response( + 200, + json={ + "issuer": "https://as.example.com", + "token_endpoint": "https://as.example.com/oauth2/token", + "grant_types_supported": ["urn:ietf:params:oauth:grant-type:token-exchange"], + }, + ) + ) + + auth = await TokenExchangeAuth.from_resource( + resource_url="https://mcp.example.com/mcp", + client_id="dedalus-sdk", + subject_token="user_token_from_clerk", + ) + + assert auth.client_id == "dedalus-sdk" + assert auth.token_endpoint == "https://as.example.com/oauth2/token" + + +# ============================================================================= +# Token Exchange Tests +# ============================================================================= + + +class TestTokenExchangeAuthTokenAcquisition: + """Tests for token acquisition via token exchange grant.""" + + @respx.mock + @pytest.mark.anyio + async def test_get_token_success(self): + """get_token exchanges subject token for access token.""" + from dedalus_mcp.client.auth.token_exchange import TokenExchangeAuth + from dedalus_mcp.client.auth.models import AuthorizationServerMetadata + + server_metadata = AuthorizationServerMetadata( + issuer="https://as.example.com", + token_endpoint="https://as.example.com/oauth2/token", + grant_types_supported=["urn:ietf:params:oauth:grant-type:token-exchange"], + ) + + respx.post("https://as.example.com/oauth2/token").mock( + return_value=httpx.Response( + 200, + json={ + "access_token": "exchanged_access_token", + "token_type": "Bearer", + "expires_in": 3600, + "issued_token_type": "urn:ietf:params:oauth:token-type:access_token", + }, + ) + ) + + auth = TokenExchangeAuth( + server_metadata=server_metadata, + client_id="dedalus-sdk", + subject_token="user_id_token", + ) + + token = await auth.get_token() + + assert token.access_token == "exchanged_access_token" + assert token.token_type == "Bearer" + + @respx.mock + @pytest.mark.anyio + async def test_get_token_sends_correct_params(self): + """get_token sends RFC 8693 compliant parameters.""" + from dedalus_mcp.client.auth.token_exchange import TokenExchangeAuth + from dedalus_mcp.client.auth.models import AuthorizationServerMetadata + + server_metadata = AuthorizationServerMetadata( + issuer="https://as.example.com", + token_endpoint="https://as.example.com/oauth2/token", + grant_types_supported=["urn:ietf:params:oauth:grant-type:token-exchange"], + ) + + route = respx.post("https://as.example.com/oauth2/token").mock( + return_value=httpx.Response( + 200, + json={ + "access_token": "token", + "token_type": "Bearer", + }, + ) + ) + + auth = TokenExchangeAuth( + server_metadata=server_metadata, + client_id="dedalus-sdk", + subject_token="the_subject_token", + subject_token_type="urn:ietf:params:oauth:token-type:id_token", + ) + + await auth.get_token() + + # Verify RFC 8693 parameters + request = route.calls.last.request + body = request.content.decode() + + assert "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange" in body + assert "subject_token=the_subject_token" in body + assert "subject_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aid_token" in body + + @respx.mock + @pytest.mark.anyio + async def test_get_token_with_resource_indicator(self): + """get_token can include resource indicator.""" + from dedalus_mcp.client.auth.token_exchange import TokenExchangeAuth + from dedalus_mcp.client.auth.models import AuthorizationServerMetadata + + server_metadata = AuthorizationServerMetadata( + issuer="https://as.example.com", + token_endpoint="https://as.example.com/oauth2/token", + grant_types_supported=["urn:ietf:params:oauth:grant-type:token-exchange"], + ) + + route = respx.post("https://as.example.com/oauth2/token").mock( + return_value=httpx.Response( + 200, + json={ + "access_token": "token", + "token_type": "Bearer", + }, + ) + ) + + auth = TokenExchangeAuth( + server_metadata=server_metadata, + client_id="dedalus-sdk", + subject_token="token", + resource="https://mcp.example.com", + ) + + await auth.get_token() + + request = route.calls.last.request + body = request.content.decode() + assert "resource=" in body + + @respx.mock + @pytest.mark.anyio + async def test_get_token_with_scope(self): + """get_token can include requested scope.""" + from dedalus_mcp.client.auth.token_exchange import TokenExchangeAuth + from dedalus_mcp.client.auth.models import AuthorizationServerMetadata + + server_metadata = AuthorizationServerMetadata( + issuer="https://as.example.com", + token_endpoint="https://as.example.com/oauth2/token", + grant_types_supported=["urn:ietf:params:oauth:grant-type:token-exchange"], + ) + + route = respx.post("https://as.example.com/oauth2/token").mock( + return_value=httpx.Response( + 200, + json={ + "access_token": "token", + "token_type": "Bearer", + }, + ) + ) + + auth = TokenExchangeAuth( + server_metadata=server_metadata, + client_id="dedalus-sdk", + subject_token="token", + scope="openid mcp:read", + ) + + await auth.get_token() + + request = route.calls.last.request + body = request.content.decode() + assert "scope=" in body + + @respx.mock + @pytest.mark.anyio + async def test_get_token_error_response(self): + """get_token raises on error response.""" + from dedalus_mcp.client.auth.token_exchange import TokenExchangeAuth, TokenError + from dedalus_mcp.client.auth.models import AuthorizationServerMetadata + + server_metadata = AuthorizationServerMetadata( + issuer="https://as.example.com", + token_endpoint="https://as.example.com/oauth2/token", + grant_types_supported=["urn:ietf:params:oauth:grant-type:token-exchange"], + ) + + respx.post("https://as.example.com/oauth2/token").mock( + return_value=httpx.Response( + 400, + json={ + "error": "invalid_grant", + "error_description": "Subject token is expired", + }, + ) + ) + + auth = TokenExchangeAuth( + server_metadata=server_metadata, + client_id="dedalus-sdk", + subject_token="expired_token", + ) + + with pytest.raises(TokenError, match="invalid_grant"): + await auth.get_token() + + +# ============================================================================= +# httpx.Auth Interface Tests +# ============================================================================= + + +class TestTokenExchangeAuthHttpxInterface: + """Tests for TokenExchangeAuth as httpx.Auth implementation.""" + + @respx.mock + @pytest.mark.anyio + async def test_auth_flow_injects_bearer_token(self): + """TokenExchangeAuth injects Bearer token into requests.""" + from dedalus_mcp.client.auth.token_exchange import TokenExchangeAuth + from dedalus_mcp.client.auth.models import AuthorizationServerMetadata + + server_metadata = AuthorizationServerMetadata( + issuer="https://as.example.com", + token_endpoint="https://as.example.com/oauth2/token", + grant_types_supported=["urn:ietf:params:oauth:grant-type:token-exchange"], + ) + + # Mock token endpoint + respx.post("https://as.example.com/oauth2/token").mock( + return_value=httpx.Response( + 200, + json={ + "access_token": "exchanged_token", + "token_type": "Bearer", + }, + ) + ) + + # Mock protected resource + protected_route = respx.get("https://mcp.example.com/api").mock( + return_value=httpx.Response(200, json={"result": "success"}) + ) + + auth = TokenExchangeAuth( + server_metadata=server_metadata, + client_id="dedalus-sdk", + subject_token="user_token", + ) + + # Pre-fetch token + await auth.get_token() + + # Make request with auth + async with httpx.AsyncClient() as client: + response = await client.get("https://mcp.example.com/api", auth=auth) + + assert response.status_code == 200 + + # Verify Bearer token was injected + request = protected_route.calls.last.request + assert request.headers.get("Authorization") == "Bearer exchanged_token" + + +# ============================================================================= +# Actor Token Tests (RFC 8693 Section 2.1) +# ============================================================================= + + +class TestTokenExchangeAuthActorToken: + """Tests for actor token support (delegation scenarios).""" + + def test_construction_with_actor_token(self): + """TokenExchangeAuth accepts actor_token for delegation.""" + from dedalus_mcp.client.auth.token_exchange import TokenExchangeAuth + from dedalus_mcp.client.auth.models import AuthorizationServerMetadata + + server_metadata = AuthorizationServerMetadata( + issuer="https://as.example.com", + token_endpoint="https://as.example.com/oauth2/token", + grant_types_supported=["urn:ietf:params:oauth:grant-type:token-exchange"], + ) + + auth = TokenExchangeAuth( + server_metadata=server_metadata, + client_id="dedalus-sdk", + subject_token="user_token", + actor_token="service_token", + actor_token_type="urn:ietf:params:oauth:token-type:access_token", + ) + + assert auth.actor_token == "service_token" + + @respx.mock + @pytest.mark.anyio + async def test_get_token_with_actor_token(self): + """get_token includes actor_token in request when provided.""" + from dedalus_mcp.client.auth.token_exchange import TokenExchangeAuth + from dedalus_mcp.client.auth.models import AuthorizationServerMetadata + + server_metadata = AuthorizationServerMetadata( + issuer="https://as.example.com", + token_endpoint="https://as.example.com/oauth2/token", + grant_types_supported=["urn:ietf:params:oauth:grant-type:token-exchange"], + ) + + route = respx.post("https://as.example.com/oauth2/token").mock( + return_value=httpx.Response( + 200, + json={ + "access_token": "token", + "token_type": "Bearer", + }, + ) + ) + + auth = TokenExchangeAuth( + server_metadata=server_metadata, + client_id="dedalus-sdk", + subject_token="user_token", + actor_token="service_token", + ) + + await auth.get_token() + + request = route.calls.last.request + body = request.content.decode() + assert "actor_token=service_token" in body diff --git a/tests/client/test_connection_errors.py b/tests/client/test_connection_errors.py new file mode 100644 index 0000000..0ac9736 --- /dev/null +++ b/tests/client/test_connection_errors.py @@ -0,0 +1,650 @@ +# Copyright (c) 2025 Dedalus Labs, Inc. and its contributors +# SPDX-License-Identifier: MIT + +"""TDD tests for MCPClient connection error handling. + +These tests define the expected behavior for HTTP status codes during +MCP connection per the spec: +- RFC 9728 (OAuth Protected Resource Metadata) +- MCP Transport Specification (2025-11-25) +- MCP Authorization Specification + +Each error should produce a specific, actionable error message. +""" + +from __future__ import annotations + +import pytest +import httpx +import respx + + +# ============================================================================= +# Expected Exception Types (TDD: Define interface first) +# ============================================================================= + + +class TestConnectionErrorTypes: + """Verify that connection error types exist and have expected hierarchy.""" + + def test_mcp_connection_error_exists(self): + """MCPConnectionError should be the base for all connection errors.""" + from dedalus_mcp.client.errors import MCPConnectionError + + assert issubclass(MCPConnectionError, Exception) + + def test_auth_required_error_exists(self): + """AuthRequiredError for 401 responses.""" + from dedalus_mcp.client.errors import AuthRequiredError, MCPConnectionError + + assert issubclass(AuthRequiredError, MCPConnectionError) + + def test_forbidden_error_exists(self): + """ForbiddenError for 403 responses.""" + from dedalus_mcp.client.errors import ForbiddenError, MCPConnectionError + + assert issubclass(ForbiddenError, MCPConnectionError) + + def test_session_expired_error_exists(self): + """SessionExpiredError for 404 responses (session terminated).""" + from dedalus_mcp.client.errors import SessionExpiredError, MCPConnectionError + + assert issubclass(SessionExpiredError, MCPConnectionError) + + def test_transport_error_exists(self): + """TransportError for 405/415 responses (protocol mismatch).""" + from dedalus_mcp.client.errors import TransportError, MCPConnectionError + + assert issubclass(TransportError, MCPConnectionError) + + def test_bad_request_error_exists(self): + """BadRequestError for 400 responses.""" + from dedalus_mcp.client.errors import BadRequestError, MCPConnectionError + + assert issubclass(BadRequestError, MCPConnectionError) + + def test_server_error_exists(self): + """ServerError for 5xx responses.""" + from dedalus_mcp.client.errors import ServerError, MCPConnectionError + + assert issubclass(ServerError, MCPConnectionError) + + +# ============================================================================= +# 400 Bad Request Tests (MCP Transport Spec) +# ============================================================================= + + +class TestBadRequestErrors: + """Tests for 400 Bad Request handling. + + Per MCP spec, 400 indicates: + - Invalid input (malformed JSON-RPC) + - Invalid MCP-Protocol-Version header + - Malformed authorization request + """ + + @respx.mock + @pytest.mark.anyio + async def test_400_invalid_protocol_version(self): + """400 with version error produces BadRequestError. + + Note: When the response body isn't readable (streaming), we fall back + to a generic message. The key behavior is raising BadRequestError. + """ + from dedalus_mcp.client import MCPClient + from dedalus_mcp.client.errors import BadRequestError + + respx.post("https://mcp.example.com/mcp").mock( + return_value=httpx.Response( + 400, + json={ + "error": "invalid_request", + "error_description": "Unsupported MCP-Protocol-Version: 2023-01-01", + }, + ) + ) + + with pytest.raises(BadRequestError) as exc_info: + await MCPClient.connect("https://mcp.example.com/mcp") + + # Should be BadRequestError with status code 400 + assert exc_info.value.status_code == 400 + + @respx.mock + @pytest.mark.anyio + async def test_400_malformed_json_rpc(self): + """400 with parse error produces actionable BadRequestError.""" + from dedalus_mcp.client import MCPClient + from dedalus_mcp.client.errors import BadRequestError + + respx.post("https://mcp.example.com/mcp").mock( + return_value=httpx.Response( + 400, + json={ + "jsonrpc": "2.0", + "error": {"code": -32700, "message": "Parse error"}, + "id": None, + }, + ) + ) + + with pytest.raises(BadRequestError) as exc_info: + await MCPClient.connect("https://mcp.example.com/mcp") + + assert "400" in str(exc_info.value) or "request" in str(exc_info.value).lower() + + @respx.mock + @pytest.mark.anyio + async def test_400_generic_bad_request(self): + """400 without specific error info still produces BadRequestError.""" + from dedalus_mcp.client import MCPClient + from dedalus_mcp.client.errors import BadRequestError + + respx.post("https://mcp.example.com/mcp").mock( + return_value=httpx.Response(400, text="Bad Request") + ) + + with pytest.raises(BadRequestError): + await MCPClient.connect("https://mcp.example.com/mcp") + + +# ============================================================================= +# 401 Unauthorized Tests (MCP Authorization Spec) +# ============================================================================= + + +class TestAuthRequiredErrors: + """Tests for 401 Unauthorized handling. + + Per MCP Authorization spec, 401 indicates: + - Authorization required + - Token invalid or expired + """ + + @respx.mock + @pytest.mark.anyio + async def test_401_no_credentials(self): + """401 without credentials produces AuthRequiredError.""" + from dedalus_mcp.client import MCPClient + from dedalus_mcp.client.errors import AuthRequiredError + + respx.post("https://mcp.example.com/mcp").mock( + return_value=httpx.Response( + 401, + headers={ + "WWW-Authenticate": 'Bearer resource_metadata="/.well-known/oauth-protected-resource"' + }, + ) + ) + + with pytest.raises(AuthRequiredError) as exc_info: + await MCPClient.connect("https://mcp.example.com/mcp") + + # Should mention auth/credentials + err = str(exc_info.value).lower() + assert "auth" in err or "credential" in err or "unauthorized" in err + + @respx.mock + @pytest.mark.anyio + async def test_401_invalid_token(self): + """401 with invalid_token error produces AuthRequiredError.""" + from dedalus_mcp.client import MCPClient + from dedalus_mcp.client.errors import AuthRequiredError + + respx.post("https://mcp.example.com/mcp").mock( + return_value=httpx.Response( + 401, + headers={ + "WWW-Authenticate": 'Bearer error="invalid_token", error_description="Token has expired"' + }, + ) + ) + + with pytest.raises(AuthRequiredError) as exc_info: + await MCPClient.connect("https://mcp.example.com/mcp") + + # Should mention token expiration or invalidity + err = str(exc_info.value).lower() + assert "token" in err or "expired" in err or "invalid" in err + + @respx.mock + @pytest.mark.anyio + async def test_401_includes_www_authenticate_info(self): + """AuthRequiredError should include WWW-Authenticate details when available.""" + from dedalus_mcp.client import MCPClient + from dedalus_mcp.client.errors import AuthRequiredError + + respx.post("https://mcp.example.com/mcp").mock( + return_value=httpx.Response( + 401, + headers={ + "WWW-Authenticate": 'Bearer realm="mcp", error="invalid_token"' + }, + ) + ) + + with pytest.raises(AuthRequiredError) as exc_info: + await MCPClient.connect("https://mcp.example.com/mcp") + + # The error should preserve useful auth context + assert exc_info.value.www_authenticate is not None or "Bearer" in str(exc_info.value) + + +# ============================================================================= +# 403 Forbidden Tests (MCP Authorization Spec) +# ============================================================================= + + +class TestForbiddenErrors: + """Tests for 403 Forbidden handling. + + Per MCP Authorization spec, 403 indicates: + - Invalid scopes + - Insufficient permissions + """ + + @respx.mock + @pytest.mark.anyio + async def test_403_insufficient_scope(self): + """403 with insufficient_scope error produces ForbiddenError.""" + from dedalus_mcp.client import MCPClient + from dedalus_mcp.client.errors import ForbiddenError + + respx.post("https://mcp.example.com/mcp").mock( + return_value=httpx.Response( + 403, + headers={ + "WWW-Authenticate": 'Bearer error="insufficient_scope", scope="mcp:admin"' + }, + ) + ) + + with pytest.raises(ForbiddenError) as exc_info: + await MCPClient.connect("https://mcp.example.com/mcp") + + # Should mention scope/permission + err = str(exc_info.value).lower() + assert "scope" in err or "permission" in err or "forbidden" in err + + @respx.mock + @pytest.mark.anyio + async def test_403_generic_forbidden(self): + """403 without specific error still produces ForbiddenError.""" + from dedalus_mcp.client import MCPClient + from dedalus_mcp.client.errors import ForbiddenError + + respx.post("https://mcp.example.com/mcp").mock( + return_value=httpx.Response(403, text="Forbidden") + ) + + with pytest.raises(ForbiddenError) as exc_info: + await MCPClient.connect("https://mcp.example.com/mcp") + + assert "403" in str(exc_info.value) or "forbidden" in str(exc_info.value).lower() + + +# ============================================================================= +# 404 Not Found Tests (MCP Transport Spec) +# ============================================================================= + + +class TestSessionExpiredErrors: + """Tests for 404 Not Found handling. + + Per MCP Transport spec, 404 during a session indicates: + - Session has been terminated + - Session ID is expired or invalid + """ + + @respx.mock + @pytest.mark.anyio + async def test_404_session_terminated(self): + """404 with session context produces SessionExpiredError.""" + from dedalus_mcp.client import MCPClient + from dedalus_mcp.client.errors import SessionExpiredError + + respx.post("https://mcp.example.com/mcp").mock( + return_value=httpx.Response( + 404, + json={"error": "session_not_found", "message": "Session has been terminated"}, + ) + ) + + with pytest.raises(SessionExpiredError) as exc_info: + await MCPClient.connect("https://mcp.example.com/mcp") + + # Should mention session + err = str(exc_info.value).lower() + assert "session" in err or "terminated" in err or "expired" in err + + @respx.mock + @pytest.mark.anyio + async def test_404_endpoint_not_found(self): + """404 for endpoint not found produces MCPConnectionError.""" + from dedalus_mcp.client import MCPClient + from dedalus_mcp.client.errors import MCPConnectionError + + respx.post("https://mcp.example.com/mcp").mock( + return_value=httpx.Response(404, text="Not Found") + ) + + with pytest.raises(MCPConnectionError) as exc_info: + await MCPClient.connect("https://mcp.example.com/mcp") + + # Should provide helpful message about endpoint + err = str(exc_info.value).lower() + assert "404" in str(exc_info.value) or "not found" in err or "endpoint" in err + + +# ============================================================================= +# 405 Method Not Allowed Tests (MCP Transport Spec) +# ============================================================================= + + +class TestMethodNotAllowedErrors: + """Tests for 405 Method Not Allowed handling. + + Per MCP Transport spec, 405 indicates: + - Server doesn't support GET (for SSE) + - Wrong HTTP method for the endpoint + """ + + @respx.mock + @pytest.mark.anyio + async def test_405_method_not_allowed(self): + """405 produces TransportError with method suggestion.""" + from dedalus_mcp.client import MCPClient + from dedalus_mcp.client.errors import TransportError + + respx.post("https://mcp.example.com/mcp").mock( + return_value=httpx.Response( + 405, + headers={"Allow": "GET"}, + text="Method Not Allowed", + ) + ) + + with pytest.raises(TransportError) as exc_info: + await MCPClient.connect("https://mcp.example.com/mcp") + + # Should mention method/transport + err = str(exc_info.value).lower() + assert "method" in err or "405" in str(exc_info.value) or "transport" in err + + +# ============================================================================= +# 415 Unsupported Media Type Tests +# ============================================================================= + + +class TestUnsupportedMediaTypeErrors: + """Tests for 415 Unsupported Media Type handling. + + 415 indicates wrong Content-Type header for the request. + """ + + @respx.mock + @pytest.mark.anyio + async def test_415_wrong_content_type(self): + """415 produces TransportError with content-type info.""" + from dedalus_mcp.client import MCPClient + from dedalus_mcp.client.errors import TransportError + + respx.post("https://mcp.example.com/mcp").mock( + return_value=httpx.Response( + 415, + json={"error": "Expected application/json"}, + ) + ) + + with pytest.raises(TransportError) as exc_info: + await MCPClient.connect("https://mcp.example.com/mcp") + + # Should mention content-type or media type + err = str(exc_info.value).lower() + assert "content" in err or "media" in err or "415" in str(exc_info.value) + + +# ============================================================================= +# 422 Unprocessable Entity Tests +# ============================================================================= + + +class TestUnprocessableEntityErrors: + """Tests for 422 Unprocessable Entity handling. + + 422 indicates semantic errors in the request: + - Invalid JSON-RPC structure + - Missing required fields + """ + + @respx.mock + @pytest.mark.anyio + async def test_422_invalid_jsonrpc(self): + """422 produces BadRequestError with validation info.""" + from dedalus_mcp.client import MCPClient + from dedalus_mcp.client.errors import BadRequestError + + respx.post("https://mcp.example.com/mcp").mock( + return_value=httpx.Response( + 422, + json={ + "jsonrpc": "2.0", + "error": {"code": -32600, "message": "Invalid Request"}, + "id": None, + }, + ) + ) + + with pytest.raises(BadRequestError) as exc_info: + await MCPClient.connect("https://mcp.example.com/mcp") + + err = str(exc_info.value).lower() + assert "invalid" in err or "422" in str(exc_info.value) or "request" in err + + +# ============================================================================= +# 5xx Server Error Tests +# ============================================================================= + + +class TestServerErrors: + """Tests for 5xx server error handling.""" + + @respx.mock + @pytest.mark.anyio + async def test_500_internal_server_error(self): + """500 produces ServerError.""" + from dedalus_mcp.client import MCPClient + from dedalus_mcp.client.errors import ServerError + + respx.post("https://mcp.example.com/mcp").mock( + return_value=httpx.Response(500, text="Internal Server Error") + ) + + with pytest.raises(ServerError) as exc_info: + await MCPClient.connect("https://mcp.example.com/mcp") + + err = str(exc_info.value).lower() + assert "server" in err or "500" in str(exc_info.value) + + @respx.mock + @pytest.mark.anyio + async def test_502_bad_gateway(self): + """502 produces ServerError with gateway context.""" + from dedalus_mcp.client import MCPClient + from dedalus_mcp.client.errors import ServerError + + respx.post("https://mcp.example.com/mcp").mock( + return_value=httpx.Response(502, text="Bad Gateway") + ) + + with pytest.raises(ServerError) as exc_info: + await MCPClient.connect("https://mcp.example.com/mcp") + + err = str(exc_info.value).lower() + assert "gateway" in err or "502" in str(exc_info.value) or "server" in err + + @respx.mock + @pytest.mark.anyio + async def test_503_service_unavailable(self): + """503 produces ServerError suggesting retry.""" + from dedalus_mcp.client import MCPClient + from dedalus_mcp.client.errors import ServerError + + respx.post("https://mcp.example.com/mcp").mock( + return_value=httpx.Response( + 503, + headers={"Retry-After": "30"}, + text="Service Unavailable", + ) + ) + + with pytest.raises(ServerError) as exc_info: + await MCPClient.connect("https://mcp.example.com/mcp") + + # Should mention unavailable or retry + err = str(exc_info.value).lower() + assert "unavailable" in err or "503" in str(exc_info.value) or "retry" in err + + @respx.mock + @pytest.mark.anyio + async def test_504_gateway_timeout(self): + """504 produces ServerError with timeout context.""" + from dedalus_mcp.client import MCPClient + from dedalus_mcp.client.errors import ServerError + + respx.post("https://mcp.example.com/mcp").mock( + return_value=httpx.Response(504, text="Gateway Timeout") + ) + + with pytest.raises(ServerError) as exc_info: + await MCPClient.connect("https://mcp.example.com/mcp") + + err = str(exc_info.value).lower() + assert "timeout" in err or "504" in str(exc_info.value) or "gateway" in err + + +# ============================================================================= +# Error Attribute Tests +# ============================================================================= + + +class TestErrorAttributes: + """Tests for error objects having useful attributes.""" + + @respx.mock + @pytest.mark.anyio + async def test_connection_error_has_status_code(self): + """MCPConnectionError should expose the HTTP status code.""" + from dedalus_mcp.client import MCPClient + from dedalus_mcp.client.errors import MCPConnectionError + + respx.post("https://mcp.example.com/mcp").mock( + return_value=httpx.Response(418, text="I'm a teapot") + ) + + with pytest.raises(MCPConnectionError) as exc_info: + await MCPClient.connect("https://mcp.example.com/mcp") + + assert exc_info.value.status_code == 418 + + @respx.mock + @pytest.mark.anyio + async def test_auth_error_has_www_authenticate(self): + """AuthRequiredError should expose WWW-Authenticate header.""" + from dedalus_mcp.client import MCPClient + from dedalus_mcp.client.errors import AuthRequiredError + + respx.post("https://mcp.example.com/mcp").mock( + return_value=httpx.Response( + 401, + headers={"WWW-Authenticate": 'Bearer realm="mcp"'}, + ) + ) + + with pytest.raises(AuthRequiredError) as exc_info: + await MCPClient.connect("https://mcp.example.com/mcp") + + assert exc_info.value.www_authenticate == 'Bearer realm="mcp"' + + @respx.mock + @pytest.mark.anyio + async def test_server_error_has_retry_after(self): + """ServerError should expose Retry-After header when present.""" + from dedalus_mcp.client import MCPClient + from dedalus_mcp.client.errors import ServerError + + respx.post("https://mcp.example.com/mcp").mock( + return_value=httpx.Response( + 503, + headers={"Retry-After": "60"}, + text="Service Unavailable", + ) + ) + + with pytest.raises(ServerError) as exc_info: + await MCPClient.connect("https://mcp.example.com/mcp") + + assert exc_info.value.retry_after == "60" + + +# ============================================================================= +# Network-Level Error Tests +# ============================================================================= + + +class TestNetworkErrors: + """Tests for network-level failures (not HTTP status codes).""" + + @respx.mock + @pytest.mark.anyio + async def test_connection_refused(self): + """Connection refused produces MCPConnectionError.""" + from dedalus_mcp.client import MCPClient + from dedalus_mcp.client.errors import MCPConnectionError + + respx.post("https://mcp.example.com/mcp").mock( + side_effect=httpx.ConnectError("Connection refused") + ) + + with pytest.raises(MCPConnectionError) as exc_info: + await MCPClient.connect("https://mcp.example.com/mcp") + + err = str(exc_info.value).lower() + assert "connect" in err or "refused" in err + + @respx.mock + @pytest.mark.anyio + async def test_dns_resolution_failure(self): + """DNS failure produces MCPConnectionError.""" + from dedalus_mcp.client import MCPClient + from dedalus_mcp.client.errors import MCPConnectionError + + respx.post("https://nonexistent.invalid/mcp").mock( + side_effect=httpx.ConnectError("Name or service not known") + ) + + with pytest.raises(MCPConnectionError) as exc_info: + await MCPClient.connect("https://nonexistent.invalid/mcp") + + err = str(exc_info.value).lower() + assert "connect" in err or "dns" in err or "resolve" in err + + @respx.mock + @pytest.mark.anyio + async def test_timeout(self): + """Request timeout produces MCPConnectionError.""" + from dedalus_mcp.client import MCPClient + from dedalus_mcp.client.errors import MCPConnectionError + + respx.post("https://mcp.example.com/mcp").mock( + side_effect=httpx.TimeoutException("Request timed out") + ) + + with pytest.raises(MCPConnectionError) as exc_info: + await MCPClient.connect("https://mcp.example.com/mcp") + + err = str(exc_info.value).lower() + assert "timeout" in err or "timed out" in err diff --git a/tests/test_client_transports.py b/tests/test_client_transports.py index 1850025..16806c5 100644 --- a/tests/test_client_transports.py +++ b/tests/test_client_transports.py @@ -24,7 +24,7 @@ from dedalus_mcp.client.transports import lambda_http_client -from mcp.client.streamable_http import streamablehttp_client +from mcp.client.streamable_http import streamable_http_client class SentinelError(RuntimeError): @@ -32,16 +32,15 @@ class SentinelError(RuntimeError): class FakeTransport: - """Test double mirroring the SDK transport interface without real I/O.""" + """Test double mirroring the SDK transport interface without real I/O. - def __init__( - self, url: str, headers: dict[str, str] | None, timeout: float, sse_read_timeout: float, auth: Any - ) -> None: + Note: As of MCP SDK 1.24.0, StreamableHTTPTransport only takes url in its + constructor. HTTP configuration (headers, timeout, auth) is now handled + by passing a pre-configured httpx.AsyncClient. + """ + + def __init__(self, url: str) -> None: self.url = url - self.request_headers = headers or {} - self.timeout = timeout - self.sse_read_timeout = sse_read_timeout - self.auth = auth self.session_id = "fake-session" self.start_get_stream = None self.terminated = False @@ -75,10 +74,8 @@ async def test_lambda_http_client_injects_noop_get_stream(monkeypatch: pytest.Mo """Our wrapper must *not* invoke the GET/SSE starter the SDK normally uses.""" transport_instances: list[FakeTransport] = [] - def transport_factory( - url: str, headers: dict[str, str] | None, timeout: float, sse_timeout: float, auth: Any - ) -> FakeTransport: - inst = FakeTransport(url, headers, timeout, sse_timeout, auth) + def transport_factory(url: str) -> FakeTransport: + inst = FakeTransport(url) transport_instances.append(inst) return inst @@ -106,7 +103,7 @@ async def fake_client_factory(**_: Any): @pytest.mark.anyio -async def test_streamablehttp_client_raises_when_get_stream_starts(monkeypatch: pytest.MonkeyPatch) -> None: +async def test_streamable_http_client_raises_when_get_stream_starts(monkeypatch: pytest.MonkeyPatch) -> None: """Regression guard: stock streamable client still tries to attach SSE.""" class RaisingTransport(FakeTransport): @@ -131,10 +128,8 @@ async def post_writer( except BaseException: pass - def transport_factory( - url: str, headers: dict[str, str] | None, timeout: float, sse_timeout: float, auth: Any - ) -> RaisingTransport: - return RaisingTransport(url, headers, timeout, sse_timeout, auth) + def transport_factory(url: str) -> RaisingTransport: + return RaisingTransport(url) @asynccontextmanager async def fake_client_factory(**_: Any): @@ -144,7 +139,7 @@ async def fake_client_factory(**_: Any): monkeypatch.setattr("mcp.client.streamable_http.create_mcp_http_client", fake_client_factory) with pytest.raises(BaseExceptionGroup) as excinfo: - async with streamablehttp_client(url="https://lambda.example.test/mcp"): + async with streamable_http_client(url="https://lambda.example.test/mcp"): pass def _flatten(exc: BaseException | BaseExceptionGroup): diff --git a/tests/test_dispatch_http.py b/tests/test_dispatch_http.py index c0506b6..08b1fa6 100644 --- a/tests/test_dispatch_http.py +++ b/tests/test_dispatch_http.py @@ -215,10 +215,11 @@ def test_error_codes_exist(self): """DispatchErrorCode should define infrastructure errors.""" from dedalus_mcp.dispatch import DispatchErrorCode - assert DispatchErrorCode.CONNECTION_NOT_FOUND == "connection_not_found" - assert DispatchErrorCode.CONNECTION_REVOKED == "connection_revoked" - assert DispatchErrorCode.DOWNSTREAM_TIMEOUT == "downstream_timeout" - assert DispatchErrorCode.DOWNSTREAM_UNREACHABLE == "downstream_unreachable" + # NOTE: Wire format uses SCREAMING_CASE + assert DispatchErrorCode.CONNECTION_NOT_FOUND == "CONNECTION_NOT_FOUND" + assert DispatchErrorCode.CONNECTION_REVOKED == "CONNECTION_REVOKED" + assert DispatchErrorCode.DOWNSTREAM_TIMEOUT == "DOWNSTREAM_TIMEOUT" + assert DispatchErrorCode.DOWNSTREAM_UNREACHABLE == "DOWNSTREAM_UNREACHABLE" def test_dispatch_error_construction(self): """DispatchError should hold code, message, retryable.""" diff --git a/uv.lock b/uv.lock index 9481db1..c5462bb 100644 --- a/uv.lock +++ b/uv.lock @@ -567,7 +567,7 @@ test = [ [package.metadata] requires-dist = [ - { name = "mcp", specifier = ">=1.20.0" }, + { name = "mcp", specifier = ">=1.24.0" }, { name = "pydantic", specifier = ">=2.12.0" }, { name = "pyjwt", extras = ["crypto"], specifier = ">=2.10.1" }, { name = "typing-extensions", specifier = ">=4.0.0" }, @@ -1171,7 +1171,7 @@ wheels = [ [[package]] name = "mcp" -version = "1.23.2" +version = "1.24.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, @@ -1189,9 +1189,9 @@ dependencies = [ { name = "typing-inspection" }, { name = "uvicorn", marker = "sys_platform != 'emscripten'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/39/a9/0e95530946408747ae200e86553ceda0dbd851d4ae9bbe0d02a69cbd6ad5/mcp-1.23.2.tar.gz", hash = "sha256:df4e4b7273dca2aaf428f9cf7a25bbac0c9007528a65004854b246aef3d157bc", size = 599953, upload-time = "2025-12-08T15:51:02.432Z" } +sdist = { url = "https://files.pythonhosted.org/packages/d6/2c/db9ae5ab1fcdd9cd2bcc7ca3b7361b712e30590b64d5151a31563af8f82d/mcp-1.24.0.tar.gz", hash = "sha256:aeaad134664ce56f2721d1abf300666a1e8348563f4d3baff361c3b652448efc", size = 604375, upload-time = "2025-12-12T14:19:38.205Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ad/6a/1a726905cf41a69d00989e8dfd9de7bd9b4a9f3c8723dac3077b0ba1a7b9/mcp-1.23.2-py3-none-any.whl", hash = "sha256:d8e4c6af0317ad954ea0a53dfb5e229dddea2d0a54568c080e82e8fae4a8264e", size = 231897, upload-time = "2025-12-08T15:51:01.023Z" }, + { url = "https://files.pythonhosted.org/packages/61/0d/5cf14e177c8ae655a2fd9324a6ef657ca4cafd3fc2201c87716055e29641/mcp-1.24.0-py3-none-any.whl", hash = "sha256:db130e103cc50ddc3dffc928382f33ba3eaef0b711f7a87c05e7ded65b1ca062", size = 232896, upload-time = "2025-12-12T14:19:36.14Z" }, ] [[package]]