diff --git a/internal/mcp/connection.go b/internal/mcp/connection.go index 5ee0a3867..9715eb856 100644 --- a/internal/mcp/connection.go +++ b/internal/mcp/connection.go @@ -315,6 +315,29 @@ func (c *Connection) ServerInfo() (name, version string) { return initResult.ServerInfo.Name, initResult.ServerInfo.Version } +// BackendHasPromptsCapability reports whether the backend declared prompt support +// in its MCP initialize response. Only SDK-based connections (streamable, SSE, stdio) +// expose this information; plain JSON-RPC connections always return false because +// their initialize response is not parsed into a typed capability struct. +// +// In the MCP specification, a nil Capabilities.Prompts field means the server did not +// include a "prompts" entry in its capabilities object — i.e. it explicitly does not +// support prompts. A non-nil value (even an empty struct) signals prompt support. +// +// Callers should skip prompts/list on backends where this returns false to avoid +// issuing an unsupported request that could corrupt an SDK session via an EOF response. +func (c *Connection) BackendHasPromptsCapability() bool { + sess := c.getSDKSession() + if sess == nil { + return false + } + initResult := sess.InitializeResult() + if initResult == nil || initResult.Capabilities == nil { + return false + } + return initResult.Capabilities.Prompts != nil +} + // withReconnectLock acquires the session write lock, logs the reconnect attempt, runs // reconnect, logs the result, and wraps any error with a consistent message. // transportName is included in the debug log to identify which transport is reconnecting @@ -401,8 +424,10 @@ func (c *Connection) reconnectSDKTransport() error { // callSDKMethodWithReconnect calls the SDK method and, if the session has expired, // reconnects and retries exactly once before propagating the error. -func (c *Connection) callSDKMethodWithReconnect(method string, params interface{}) (*Response, error) { - result, err := c.callSDKMethod(method, params) +// ctx is the per-request context (e.g. carrying a tool-timeout deadline) and is +// forwarded to callSDKMethod so that cancellations and deadlines are respected. +func (c *Connection) callSDKMethodWithReconnect(ctx context.Context, method string, params interface{}) (*Response, error) { + result, err := c.callSDKMethod(ctx, method, params) if err != nil && isSessionNotFoundError(err) { logConn.Printf("Session not found error from SDK (serverID=%s), attempting reconnect", c.serverID) if reconnErr := c.reconnectSDKTransport(); reconnErr != nil { @@ -411,7 +436,7 @@ func (c *Connection) callSDKMethodWithReconnect(method string, params interface{ // Return the original session-not-found error so the caller sees a meaningful message. return result, err } - result, err = c.callSDKMethod(method, params) + result, err = c.callSDKMethod(ctx, method, params) } return result, err } @@ -449,12 +474,14 @@ func (c *Connection) SendRequestWithServerID(ctx context.Context, method string, if c.httpTransportType == HTTPTransportPlainJSON { result, err = c.sendHTTPRequest(ctx, method, params) } else { - // For streamable and SSE transports, use SDK session methods - result, err = c.callSDKMethodWithReconnect(method, params) + // For streamable and SSE transports, use SDK session methods. + // ctx is forwarded so per-request deadlines (e.g. tool timeout) are enforced. + result, err = c.callSDKMethodWithReconnect(ctx, method, params) } } else { - // Handle stdio connections using SDK client - result, err = c.callSDKMethod(method, params) + // Handle stdio connections using SDK client. + // ctx is forwarded so per-request deadlines (e.g. tool timeout) are enforced. + result, err = c.callSDKMethod(ctx, method, params) } return logInboundRPCResponseFromResult(serverID, result, err, loggingSnapshot) diff --git a/internal/mcp/connection_methods.go b/internal/mcp/connection_methods.go index c1ef6b7d8..24be44dbc 100644 --- a/internal/mcp/connection_methods.go +++ b/internal/mcp/connection_methods.go @@ -1,6 +1,7 @@ package mcp import ( + "context" "fmt" sdk "github.com/modelcontextprotocol/go-sdk/mcp" @@ -9,17 +10,22 @@ import ( // callSDKMethod calls the appropriate SDK method based on the method name. // This centralizes the method dispatch logic used by both HTTP SDK transports and stdio. // +// ctx is the per-request context (e.g. carrying a tool-timeout deadline) and is +// forwarded directly to every SDK session call so that cancellations and deadlines +// set by the caller are respected. c.ctx (the long-lived connection context) is NOT +// used here; that context only controls the lifetime of the connection itself. +// // The tools/list, resources/list, and prompts/list cases share the same structure // (paginated SDK list call via listSDKItems). They cannot be merged into a single // method because Go does not support method-level type parameters; the pagination // loop itself is already unified in listSDKItems (see pagination.go). -func (c *Connection) callSDKMethod(method string, params interface{}) (*Response, error) { +func (c *Connection) callSDKMethod(ctx context.Context, method string, params interface{}) (*Response, error) { logConn.Printf("Dispatching SDK method: %s, serverID=%s", method, c.serverID) switch method { case "tools/list": return listSDKItems(c, "tools", func(cursor string) (*sdk.ListToolsResult, error) { - return c.getSDKSession().ListTools(c.ctx, &sdk.ListToolsParams{Cursor: cursor}) + return c.getSDKSession().ListTools(ctx, &sdk.ListToolsParams{Cursor: cursor}) }, func(result *sdk.ListToolsResult) paginatedPage[*sdk.Tool] { return paginatedPage[*sdk.Tool]{Items: result.Tools, NextCursor: result.NextCursor} @@ -29,11 +35,11 @@ func (c *Connection) callSDKMethod(method string, params interface{}) (*Response }, ) case "tools/call": - return c.callTool(params) + return c.callTool(ctx, params) case "resources/list": return listSDKItems(c, "resources", func(cursor string) (*sdk.ListResourcesResult, error) { - return c.getSDKSession().ListResources(c.ctx, &sdk.ListResourcesParams{Cursor: cursor}) + return c.getSDKSession().ListResources(ctx, &sdk.ListResourcesParams{Cursor: cursor}) }, func(result *sdk.ListResourcesResult) paginatedPage[*sdk.Resource] { return paginatedPage[*sdk.Resource]{Items: result.Resources, NextCursor: result.NextCursor} @@ -43,11 +49,11 @@ func (c *Connection) callSDKMethod(method string, params interface{}) (*Response }, ) case "resources/read": - return c.readResource(params) + return c.readResource(ctx, params) case "prompts/list": return listSDKItems(c, "prompts", func(cursor string) (*sdk.ListPromptsResult, error) { - return c.getSDKSession().ListPrompts(c.ctx, &sdk.ListPromptsParams{Cursor: cursor}) + return c.getSDKSession().ListPrompts(ctx, &sdk.ListPromptsParams{Cursor: cursor}) }, func(result *sdk.ListPromptsResult) paginatedPage[*sdk.Prompt] { return paginatedPage[*sdk.Prompt]{Items: result.Prompts, NextCursor: result.NextCursor} @@ -57,14 +63,14 @@ func (c *Connection) callSDKMethod(method string, params interface{}) (*Response }, ) case "prompts/get": - return c.getPrompt(params) + return c.getPrompt(ctx, params) default: logConn.Printf("Unsupported method: %s", method) return nil, fmt.Errorf("unsupported method: %s", method) } } -func (c *Connection) callTool(params interface{}) (*Response, error) { +func (c *Connection) callTool(ctx context.Context, params interface{}) (*Response, error) { return callParamMethod(c, params, func(p CallToolParams) (interface{}, error) { // Ensure arguments is never nil - default to empty map // This is required by the MCP protocol which expects arguments to always be present @@ -72,33 +78,33 @@ func (c *Connection) callTool(params interface{}) (*Response, error) { p.Arguments = make(map[string]interface{}) } logConn.Printf("callTool: parsed name=%s, argumentCount=%d", p.Name, len(p.Arguments)) - return c.getSDKSession().CallTool(c.ctx, &sdk.CallToolParams{ + return c.getSDKSession().CallTool(ctx, &sdk.CallToolParams{ Name: p.Name, Arguments: p.Arguments, }) }) } -func (c *Connection) readResource(params interface{}) (*Response, error) { +func (c *Connection) readResource(ctx context.Context, params interface{}) (*Response, error) { type readResourceParams struct { URI string `json:"uri"` } return callParamMethod(c, params, func(p readResourceParams) (interface{}, error) { logConn.Printf("readResource: reading resource uri=%s from serverID=%s", p.URI, c.serverID) - return c.getSDKSession().ReadResource(c.ctx, &sdk.ReadResourceParams{ + return c.getSDKSession().ReadResource(ctx, &sdk.ReadResourceParams{ URI: p.URI, }) }) } -func (c *Connection) getPrompt(params interface{}) (*Response, error) { +func (c *Connection) getPrompt(ctx context.Context, params interface{}) (*Response, error) { type getPromptParams struct { Name string `json:"name"` Arguments map[string]string `json:"arguments"` } return callParamMethod(c, params, func(p getPromptParams) (interface{}, error) { logConn.Printf("getPrompt: getting prompt name=%s from serverID=%s", p.Name, c.serverID) - return c.getSDKSession().GetPrompt(c.ctx, &sdk.GetPromptParams{ + return c.getSDKSession().GetPrompt(ctx, &sdk.GetPromptParams{ Name: p.Name, Arguments: p.Arguments, }) diff --git a/internal/mcp/reconnect_sdk_test.go b/internal/mcp/reconnect_sdk_test.go index bde1ef553..d810c6222 100644 --- a/internal/mcp/reconnect_sdk_test.go +++ b/internal/mcp/reconnect_sdk_test.go @@ -188,7 +188,7 @@ func TestCallSDKMethodWithReconnect_NoReconnectOnNonSessionError(t *testing.T) { require.Equal(t, HTTPTransportStreamable, conn.httpTransportType) require.Equal(t, int32(1), initCount.Load(), "expected one initialize during initial connect") - result, err := conn.callSDKMethodWithReconnect("tools/list", nil) + result, err := conn.callSDKMethodWithReconnect(context.Background(), "tools/list", nil) require.Error(t, err) assert.Nil(t, result) @@ -228,7 +228,7 @@ func TestCallSDKMethodWithReconnect_SessionNotFound_ReconnectFails(t *testing.T) // returns an error immediately without making network calls. conn.httpTransportType = HTTPTransportType("none") - result, err := conn.callSDKMethodWithReconnect("tools/list", nil) + result, err := conn.callSDKMethodWithReconnect(context.Background(), "tools/list", nil) require.Error(t, err) assert.Nil(t, result) @@ -290,7 +290,7 @@ func TestCallSDKMethodWithReconnect_SessionNotFound_ReconnectSucceeds(t *testing require.Equal(t, HTTPTransportStreamable, conn.httpTransportType) require.Equal(t, int32(1), initCount.Load(), "should have had exactly one initialize during NewHTTPConnection") - result, err := conn.callSDKMethodWithReconnect("tools/list", nil) + result, err := conn.callSDKMethodWithReconnect(context.Background(), "tools/list", nil) require.NoError(t, err) require.NotNil(t, result) @@ -333,7 +333,7 @@ func TestCallSDKMethodWithReconnect_Success_NoReconnect(t *testing.T) { conn := newStreamableConn(t, srv.URL) require.Equal(t, int32(1), initCount.Load(), "expected one initialize during initial connect") - result, err := conn.callSDKMethodWithReconnect("tools/list", nil) + result, err := conn.callSDKMethodWithReconnect(context.Background(), "tools/list", nil) require.NoError(t, err) require.NotNil(t, result) @@ -378,7 +378,7 @@ func TestListResources_WithSession(t *testing.T) { conn := newStreamableConn(t, srv.URL) - result, err := conn.callSDKMethod("resources/list", nil) + result, err := conn.callSDKMethod(context.Background(), "resources/list", nil) require.NoError(t, err) require.NotNil(t, result) @@ -426,7 +426,7 @@ func TestListPrompts_WithSession(t *testing.T) { conn := newStreamableConn(t, srv.URL) - result, err := conn.callSDKMethod("prompts/list", nil) + result, err := conn.callSDKMethod(context.Background(), "prompts/list", nil) require.NoError(t, err) require.NotNil(t, result) @@ -466,7 +466,7 @@ func TestReadResource_WithSession(t *testing.T) { conn := newStreamableConn(t, srv.URL) - result, err := conn.callSDKMethod("resources/read", map[string]interface{}{"uri": "file:///test.txt"}) + result, err := conn.callSDKMethod(context.Background(), "resources/read", map[string]interface{}{"uri": "file:///test.txt"}) require.Error(t, err) assert.Nil(t, result) @@ -497,7 +497,7 @@ func TestGetPrompt_WithSession(t *testing.T) { conn := newStreamableConn(t, srv.URL) - result, err := conn.callSDKMethod("prompts/get", map[string]interface{}{ + result, err := conn.callSDKMethod(context.Background(), "prompts/get", map[string]interface{}{ "name": "summarise", "arguments": map[string]string{}, }) @@ -583,7 +583,7 @@ func TestListResources_Empty(t *testing.T) { defer srv.Close() conn := newStreamableConn(t, srv.URL) - result, err := conn.callSDKMethod("resources/list", nil) + result, err := conn.callSDKMethod(context.Background(), "resources/list", nil) require.NoError(t, err) require.NotNil(t, result) @@ -617,7 +617,7 @@ func TestListPrompts_Empty(t *testing.T) { defer srv.Close() conn := newStreamableConn(t, srv.URL) - result, err := conn.callSDKMethod("prompts/list", nil) + result, err := conn.callSDKMethod(context.Background(), "prompts/list", nil) require.NoError(t, err) require.NotNil(t, result) diff --git a/internal/mcp/sdk_method_dispatch_test.go b/internal/mcp/sdk_method_dispatch_test.go index e93cfad45..ab202dd03 100644 --- a/internal/mcp/sdk_method_dispatch_test.go +++ b/internal/mcp/sdk_method_dispatch_test.go @@ -97,7 +97,7 @@ func TestCallSDKMethod_UnsupportedMethod(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { conn := newTestConnection(t) - result, err := conn.callSDKMethod(tt.method, nil) + result, err := conn.callSDKMethod(context.Background(), tt.method, nil) require.Error(t, err) assert.Nil(t, result) assert.ErrorContains(t, err, "unsupported method") @@ -110,7 +110,7 @@ func TestCallSDKMethod_UnsupportedMethod(t *testing.T) { // switch case is reached and returns a session error when no session is configured. func TestCallSDKMethod_ResourcesList_NilSession(t *testing.T) { conn := newTestConnection(t) - result, err := conn.callSDKMethod("resources/list", nil) + result, err := conn.callSDKMethod(context.Background(), "resources/list", nil) require.Error(t, err) assert.Nil(t, result) assert.ErrorContains(t, err, "SDK session not available") @@ -120,7 +120,7 @@ func TestCallSDKMethod_ResourcesList_NilSession(t *testing.T) { // and returns a session error when no session is configured. func TestCallSDKMethod_ResourcesRead_NilSession(t *testing.T) { conn := newTestConnection(t) - result, err := conn.callSDKMethod("resources/read", map[string]interface{}{"uri": "file:///test"}) + result, err := conn.callSDKMethod(context.Background(), "resources/read", map[string]interface{}{"uri": "file:///test"}) require.Error(t, err) assert.Nil(t, result) assert.ErrorContains(t, err, "SDK session not available") @@ -130,7 +130,7 @@ func TestCallSDKMethod_ResourcesRead_NilSession(t *testing.T) { // and returns a session error when no session is configured. func TestCallSDKMethod_PromptsList_NilSession(t *testing.T) { conn := newTestConnection(t) - result, err := conn.callSDKMethod("prompts/list", nil) + result, err := conn.callSDKMethod(context.Background(), "prompts/list", nil) require.Error(t, err) assert.Nil(t, result) assert.ErrorContains(t, err, "SDK session not available") @@ -140,7 +140,7 @@ func TestCallSDKMethod_PromptsList_NilSession(t *testing.T) { // and returns a session error when no session is configured. func TestCallSDKMethod_PromptsGet_NilSession(t *testing.T) { conn := newTestConnection(t) - result, err := conn.callSDKMethod("prompts/get", map[string]interface{}{"name": "my-prompt"}) + result, err := conn.callSDKMethod(context.Background(), "prompts/get", map[string]interface{}{"name": "my-prompt"}) require.Error(t, err) assert.Nil(t, result) assert.ErrorContains(t, err, "SDK session not available") @@ -151,7 +151,7 @@ func TestCallSDKMethod_PromptsGet_NilSession(t *testing.T) { // existing HTTP plain-JSON tests that never reach callSDKMethod). func TestCallSDKMethod_ToolsList_NilSession(t *testing.T) { conn := newTestConnection(t) - result, err := conn.callSDKMethod("tools/list", nil) + result, err := conn.callSDKMethod(context.Background(), "tools/list", nil) require.Error(t, err) assert.Nil(t, result) assert.ErrorContains(t, err, "SDK session not available") @@ -161,7 +161,7 @@ func TestCallSDKMethod_ToolsList_NilSession(t *testing.T) { // and returns a session error when no session is configured. func TestCallSDKMethod_ToolsCall_NilSession(t *testing.T) { conn := newTestConnection(t) - result, err := conn.callSDKMethod("tools/call", map[string]interface{}{"name": "my-tool"}) + result, err := conn.callSDKMethod(context.Background(), "tools/call", map[string]interface{}{"name": "my-tool"}) require.Error(t, err) assert.Nil(t, result) assert.ErrorContains(t, err, "SDK session not available") diff --git a/internal/server/tool_registry.go b/internal/server/tool_registry.go index d6a1fd0bd..d1a87d3ba 100644 --- a/internal/server/tool_registry.go +++ b/internal/server/tool_registry.go @@ -340,13 +340,111 @@ func (us *UnifiedServer) registerToolsFromBackend(serverID string) error { } logUnified.Printf("Registered %d tools from %s: %s", len(listResult.Tools), serverID, strings.Join(toolNames, ", ")) + + // Register prompts from this backend. Prompt support is optional; failures are + // logged but do not cause tool registration to fail. + if err := us.registerPromptsFromBackend(context.Background(), serverID, conn); err != nil { + logger.LogWarn("backend", "Failed to register prompts from %s (non-fatal): %v", serverID, err) + } + return nil } -// registerSysTool is a helper function that registers a sys tool by storing its metadata -// in the internal tools map. Sys tools are deprecated: agent labels are set when a guard -// is initialized via label_agent, so sys tools no longer need to be exposed to agents. -// The handler implementations are kept for potential future use. +// registerPromptsFromBackend registers prompts from a specific backend with ___ +// naming, mirroring the tool registration convention. Prompt capability is optional in the MCP +// spec; backends that do not support prompts/list will return an error that is treated as a +// graceful skip rather than a hard failure. +func (us *UnifiedServer) registerPromptsFromBackend(ctx context.Context, serverID string, conn *mcp.Connection) error { + // Only call prompts/list on backends that explicitly declared prompt support in + // their initialize response. For SDK-based connections (streamable, SSE, stdio), + // an unsupported prompts/list request can return EOF which the SDK interprets as + // a session close, breaking subsequent tool calls on the same connection. + // Plain JSON-RPC connections return false here too; their initialize response is + // not parsed into typed capabilities, so we cannot safely detect support. + if !conn.BackendHasPromptsCapability() { + logUnified.Printf("Backend %s does not declare prompts capability (skipping)", serverID) + return nil + } + + // List prompts from backend + result, err := conn.SendRequestWithServerID(ctx, "prompts/list", nil, serverID) + if err != nil { + // Many backends do not implement prompts — treat as a graceful skip. + logUnified.Printf("Backend %s does not support prompts/list (skipping): %v", serverID, err) + return nil + } + + // A JSON-RPC error from the backend also means prompts are unavailable. + if result.Error != nil { + logUnified.Printf("Backend %s returned error for prompts/list (skipping): code=%d, message=%s", + serverID, result.Error.Code, result.Error.Message) + return nil + } + + // Parse the prompt list. + var listResult struct { + Prompts []*sdk.Prompt `json:"prompts"` + } + if err := json.Unmarshal(result.Result, &listResult); err != nil { + return fmt.Errorf("failed to parse prompts from %s: %w", serverID, err) + } + + if len(listResult.Prompts) == 0 { + logUnified.Printf("Backend %s has no prompts to register", serverID) + return nil + } + + // Register each prompt with a prefixed name so front-end clients can + // distinguish prompts from different backends. + promptNames := make([]string, 0, len(listResult.Prompts)) + for _, prompt := range listResult.Prompts { + prefixedName := fmt.Sprintf("%s___%s", serverID, prompt.Name) + promptDesc := fmt.Sprintf("[%s] %s", serverID, prompt.Description) + promptNames = append(promptNames, prompt.Name) + + serverIDCopy := serverID + promptNameCopy := prompt.Name + + sdkPrompt := &sdk.Prompt{ + Name: prefixedName, + Description: promptDesc, + Arguments: prompt.Arguments, + } + + us.server.AddPrompt(sdkPrompt, func(ctx context.Context, req *sdk.GetPromptRequest) (*sdk.GetPromptResult, error) { + sessionID := us.getSessionID(ctx) + backendConn, connErr := launcher.GetOrLaunchForSession(us.launcher, serverIDCopy, sessionID) + if connErr != nil { + return nil, fmt.Errorf("failed to connect to backend %s: %w", serverIDCopy, connErr) + } + + params := map[string]interface{}{ + "name": promptNameCopy, + "arguments": req.Params.Arguments, + } + + response, reqErr := backendConn.SendRequestWithServerID(ctx, "prompts/get", params, serverIDCopy) + if reqErr != nil { + return nil, reqErr + } + if response.Error != nil { + return nil, fmt.Errorf("backend error for prompts/get %s/%s: code=%d, message=%s", + serverIDCopy, promptNameCopy, response.Error.Code, response.Error.Message) + } + + var promptResult sdk.GetPromptResult + if unmarshErr := json.Unmarshal(response.Result, &promptResult); unmarshErr != nil { + return nil, fmt.Errorf("failed to parse prompt result: %w", unmarshErr) + } + return &promptResult, nil + }) + + logUnified.Printf("Registered prompt: %s___%s", serverID, promptNameCopy) + } + + logUnified.Printf("Registered %d prompts from %s: %s", len(listResult.Prompts), serverID, strings.Join(promptNames, ", ")) + return nil +} func (us *UnifiedServer) registerSysTool(name, description string, inputSchema map[string]interface{}, handler func(context.Context, *sdk.CallToolRequest, interface{}) (*sdk.CallToolResult, interface{}, error)) { // Store tool info internally only -- sys tools are intentionally NOT registered // with the MCP SDK server and therefore never appear in tools/list.