diff --git a/src/dedalus_mcp/context.py b/src/dedalus_mcp/context.py index 6e449ca..725cca5 100644 --- a/src/dedalus_mcp/context.py +++ b/src/dedalus_mcp/context.py @@ -418,7 +418,7 @@ def _get_connections(self, runtime: Mapping[str, Any]) -> dict[str, str]: connections = claims.get("ddls:connections") if not isinstance(connections, dict): - raise RuntimeError("Missing ddls:connections claim") + raise RuntimeError("Missing required JWT claims for connection resolution") return dict(connections) diff --git a/src/dedalus_mcp/server/core.py b/src/dedalus_mcp/server/core.py index fdd8140..7b218cd 100644 --- a/src/dedalus_mcp/server/core.py +++ b/src/dedalus_mcp/server/core.py @@ -163,6 +163,7 @@ def __init__( notification_sink: NotificationSink | None = None, http_security: TransportSecuritySettings | None = None, authorization: AuthorizationConfig | None = None, + authorization_server: str = "https://as.dedaluslabs.ai", streamable_http_stateless: bool = False, allow_dynamic_tools: bool = False, resource_uri: str | None = None, @@ -244,15 +245,27 @@ def __init__( # Auto-enable authorization when connections are defined (they require JWT claims) if authorization is not None: auth_config = authorization + auto_configure_jwt = False elif connections: # Connections require auth to resolve name → handle from JWT auth_config = AuthorizationConfig(enabled=True) + auto_configure_jwt = True else: auth_config = AuthorizationConfig() + auto_configure_jwt = False self._authorization_manager: AuthorizationManager | None = None if auth_config.enabled: self._authorization_manager = AuthorizationManager(auth_config) + # Auto-configure JWT validator when connections trigger auto-enable + if auto_configure_jwt: + from .services.jwt_validator import JWTValidator, JWTValidatorConfig + as_url = authorization_server.rstrip("/") + jwt_config = JWTValidatorConfig( + jwks_uri=f"{as_url}/.well-known/jwks.json", + issuer=as_url, + ) + self._authorization_manager.set_provider(JWTValidator(jwt_config)) self._transport_factories: dict[str, TransportFactory] = {} self.register_transport("stdio", lambda server: StdioTransport(server)) diff --git a/tests/test_context_dispatch.py b/tests/test_context_dispatch.py index 4bfe68e..e5cf794 100644 --- a/tests/test_context_dispatch.py +++ b/tests/test_context_dispatch.py @@ -50,11 +50,17 @@ def mock_resolver(handle: str) -> tuple[str, str, str]: @pytest.fixture def auth_context(self): - """Auth context with org reference (gateway validates connections).""" + """Auth context with connections MAP (required for dispatch).""" return AuthorizationContext( subject='user123', scopes=['mcp:tools:call'], - claims={'ddls:org': 'org_123'}, + claims={ + 'ddls:org': 'org_123', + 'ddls:connections': { + 'github': 'ddls:conn:019b2464-d1c1-7751-a409-ed273f51da82', + 'invalid': 'not-a-valid-handle', # For invalid handle test + }, + }, ) @pytest.fixture @@ -225,3 +231,103 @@ async def test_dispatch_no_auth_context_raises_error(self, backend): # Without auth context, dispatch fails (can't look up connections from JWT) with pytest.raises(RuntimeError, match='Authorization context is None'): await ctx.dispatch('github', request) + + @pytest.mark.asyncio + async def test_dispatch_with_jwt_connections_claim(self, backend): + """Full flow: JWT with ddls:connections claim → dispatch resolves name → handle.""" + # Simulate JWT claims with connection MAP format + jwt_claims = { + "sub": "user_123", + "aud": "https://mcp.example.com", + "ddls:connections": { + "github": "ddls:conn:019b2464-d1c1-7751-a409-ed273f51da82", + "supabase": "ddls:conn:019b2464-d1c1-7751-a409-ed273f51da83", + }, + } + auth_context = AuthorizationContext( + subject="user_123", + scopes=["mcp:tools:call"], + claims=jwt_claims, + ) + + mock_request_ctx = MockRequestContext( + lifespan_context={'dedalus_mcp.runtime': {'dispatch_backend': backend}} + ) + mock_request = MagicMock() + # Simulate auth middleware having set the auth context in scope + mock_request.scope = {"dedalus_mcp.auth": auth_context} + mock_request_ctx.request = mock_request + + ctx = Context( + _request_context=mock_request_ctx, + runtime={'dispatch_backend': backend} + ) + request = HttpRequest(method=HttpMethod.GET, path="/user") + + # Dispatch by connection NAME - should resolve to handle from JWT claims + result = await ctx.dispatch('github', request) + assert result.success is True + + @pytest.mark.asyncio + async def test_dispatch_connection_not_in_jwt_claims(self, backend): + """Dispatch fails if connection name not in JWT ddls:connections.""" + jwt_claims = { + "sub": "user_123", + "ddls:connections": { + "github": "ddls:conn:019b2464-d1c1-7751-a409-ed273f51da82", + }, + } + auth_context = AuthorizationContext( + subject="user_123", + scopes=[], + claims=jwt_claims, + ) + + mock_request_ctx = MockRequestContext( + lifespan_context={'dedalus_mcp.runtime': {'dispatch_backend': backend}} + ) + mock_request = MagicMock() + mock_request.scope = {"dedalus_mcp.auth": auth_context} + mock_request_ctx.request = mock_request + + ctx = Context( + _request_context=mock_request_ctx, + runtime={'dispatch_backend': backend} + ) + request = HttpRequest(method=HttpMethod.GET, path="/query") + + # "supabase" not in JWT claims - should fail + with pytest.raises(ValueError, match="Connection 'supabase' not found"): + await ctx.dispatch('supabase', request) + + @pytest.mark.asyncio + async def test_dispatch_invalid_jwt_connections_format(self, backend): + """Dispatch fails if ddls:connections is not a dict (old list format).""" + # Old LIST format - should fail + jwt_claims = { + "sub": "user_123", + "ddls:connections": [ + {"handle": "ddls:conn:123", "provider": "github"}, + ], + } + auth_context = AuthorizationContext( + subject="user_123", + scopes=[], + claims=jwt_claims, + ) + + mock_request_ctx = MockRequestContext( + lifespan_context={'dedalus_mcp.runtime': {'dispatch_backend': backend}} + ) + mock_request = MagicMock() + mock_request.scope = {"dedalus_mcp.auth": auth_context} + mock_request_ctx.request = mock_request + + ctx = Context( + _request_context=mock_request_ctx, + runtime={'dispatch_backend': backend} + ) + request = HttpRequest(method=HttpMethod.GET, path="/user") + + with pytest.raises(RuntimeError, match="Missing required JWT claims"): + await ctx.dispatch('github', request)