Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 34 additions & 7 deletions internal/mcp/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,29 @@ func (c *Connection) ServerInfo() (name, version string) {
return initResult.ServerInfo.Name, initResult.ServerInfo.Version
}

// BackendHasPromptsCapability reports whether the backend declared prompt support
// in its MCP initialize response. Only SDK-based connections (streamable, SSE, stdio)
// expose this information; plain JSON-RPC connections always return false because
// their initialize response is not parsed into a typed capability struct.
//
// In the MCP specification, a nil Capabilities.Prompts field means the server did not
// include a "prompts" entry in its capabilities object — i.e. it explicitly does not
// support prompts. A non-nil value (even an empty struct) signals prompt support.
//
// Callers should skip prompts/list on backends where this returns false to avoid
// issuing an unsupported request that could corrupt an SDK session via an EOF response.
func (c *Connection) BackendHasPromptsCapability() bool {
sess := c.getSDKSession()
if sess == nil {
return false
}
initResult := sess.InitializeResult()
if initResult == nil || initResult.Capabilities == nil {
return false
}
return initResult.Capabilities.Prompts != nil
}

// withReconnectLock acquires the session write lock, logs the reconnect attempt, runs
// reconnect, logs the result, and wraps any error with a consistent message.
// transportName is included in the debug log to identify which transport is reconnecting
Expand Down Expand Up @@ -401,8 +424,10 @@ func (c *Connection) reconnectSDKTransport() error {

// callSDKMethodWithReconnect calls the SDK method and, if the session has expired,
// reconnects and retries exactly once before propagating the error.
func (c *Connection) callSDKMethodWithReconnect(method string, params interface{}) (*Response, error) {
result, err := c.callSDKMethod(method, params)
// ctx is the per-request context (e.g. carrying a tool-timeout deadline) and is
// forwarded to callSDKMethod so that cancellations and deadlines are respected.
func (c *Connection) callSDKMethodWithReconnect(ctx context.Context, method string, params interface{}) (*Response, error) {
result, err := c.callSDKMethod(ctx, method, params)
if err != nil && isSessionNotFoundError(err) {
logConn.Printf("Session not found error from SDK (serverID=%s), attempting reconnect", c.serverID)
if reconnErr := c.reconnectSDKTransport(); reconnErr != nil {
Expand All @@ -411,7 +436,7 @@ func (c *Connection) callSDKMethodWithReconnect(method string, params interface{
// Return the original session-not-found error so the caller sees a meaningful message.
return result, err
}
result, err = c.callSDKMethod(method, params)
result, err = c.callSDKMethod(ctx, method, params)
}
return result, err
}
Expand Down Expand Up @@ -449,12 +474,14 @@ func (c *Connection) SendRequestWithServerID(ctx context.Context, method string,
if c.httpTransportType == HTTPTransportPlainJSON {
result, err = c.sendHTTPRequest(ctx, method, params)
} else {
// For streamable and SSE transports, use SDK session methods
result, err = c.callSDKMethodWithReconnect(method, params)
// For streamable and SSE transports, use SDK session methods.
// ctx is forwarded so per-request deadlines (e.g. tool timeout) are enforced.
result, err = c.callSDKMethodWithReconnect(ctx, method, params)
}
} else {
// Handle stdio connections using SDK client
result, err = c.callSDKMethod(method, params)
// Handle stdio connections using SDK client.
// ctx is forwarded so per-request deadlines (e.g. tool timeout) are enforced.
result, err = c.callSDKMethod(ctx, method, params)
}

return logInboundRPCResponseFromResult(serverID, result, err, loggingSnapshot)
Expand Down
32 changes: 19 additions & 13 deletions internal/mcp/connection_methods.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package mcp

import (
"context"
"fmt"

sdk "github.com/modelcontextprotocol/go-sdk/mcp"
Expand All @@ -9,17 +10,22 @@ import (
// callSDKMethod calls the appropriate SDK method based on the method name.
// This centralizes the method dispatch logic used by both HTTP SDK transports and stdio.
//
// ctx is the per-request context (e.g. carrying a tool-timeout deadline) and is
// forwarded directly to every SDK session call so that cancellations and deadlines
// set by the caller are respected. c.ctx (the long-lived connection context) is NOT
// used here; that context only controls the lifetime of the connection itself.
//
// The tools/list, resources/list, and prompts/list cases share the same structure
// (paginated SDK list call via listSDKItems). They cannot be merged into a single
// method because Go does not support method-level type parameters; the pagination
// loop itself is already unified in listSDKItems (see pagination.go).
func (c *Connection) callSDKMethod(method string, params interface{}) (*Response, error) {
func (c *Connection) callSDKMethod(ctx context.Context, method string, params interface{}) (*Response, error) {
logConn.Printf("Dispatching SDK method: %s, serverID=%s", method, c.serverID)
switch method {
case "tools/list":
return listSDKItems(c, "tools",
func(cursor string) (*sdk.ListToolsResult, error) {
return c.getSDKSession().ListTools(c.ctx, &sdk.ListToolsParams{Cursor: cursor})
return c.getSDKSession().ListTools(ctx, &sdk.ListToolsParams{Cursor: cursor})
},
func(result *sdk.ListToolsResult) paginatedPage[*sdk.Tool] {
return paginatedPage[*sdk.Tool]{Items: result.Tools, NextCursor: result.NextCursor}
Expand All @@ -29,11 +35,11 @@ func (c *Connection) callSDKMethod(method string, params interface{}) (*Response
},
)
case "tools/call":
return c.callTool(params)
return c.callTool(ctx, params)
case "resources/list":
return listSDKItems(c, "resources",
func(cursor string) (*sdk.ListResourcesResult, error) {
return c.getSDKSession().ListResources(c.ctx, &sdk.ListResourcesParams{Cursor: cursor})
return c.getSDKSession().ListResources(ctx, &sdk.ListResourcesParams{Cursor: cursor})
},
func(result *sdk.ListResourcesResult) paginatedPage[*sdk.Resource] {
return paginatedPage[*sdk.Resource]{Items: result.Resources, NextCursor: result.NextCursor}
Expand All @@ -43,11 +49,11 @@ func (c *Connection) callSDKMethod(method string, params interface{}) (*Response
},
)
case "resources/read":
return c.readResource(params)
return c.readResource(ctx, params)
case "prompts/list":
return listSDKItems(c, "prompts",
func(cursor string) (*sdk.ListPromptsResult, error) {
return c.getSDKSession().ListPrompts(c.ctx, &sdk.ListPromptsParams{Cursor: cursor})
return c.getSDKSession().ListPrompts(ctx, &sdk.ListPromptsParams{Cursor: cursor})
},
func(result *sdk.ListPromptsResult) paginatedPage[*sdk.Prompt] {
return paginatedPage[*sdk.Prompt]{Items: result.Prompts, NextCursor: result.NextCursor}
Expand All @@ -57,48 +63,48 @@ func (c *Connection) callSDKMethod(method string, params interface{}) (*Response
},
)
case "prompts/get":
return c.getPrompt(params)
return c.getPrompt(ctx, params)
default:
logConn.Printf("Unsupported method: %s", method)
return nil, fmt.Errorf("unsupported method: %s", method)
}
}

func (c *Connection) callTool(params interface{}) (*Response, error) {
func (c *Connection) callTool(ctx context.Context, params interface{}) (*Response, error) {
return callParamMethod(c, params, func(p CallToolParams) (interface{}, error) {
// Ensure arguments is never nil - default to empty map
// This is required by the MCP protocol which expects arguments to always be present
if p.Arguments == nil {
p.Arguments = make(map[string]interface{})
}
logConn.Printf("callTool: parsed name=%s, argumentCount=%d", p.Name, len(p.Arguments))
return c.getSDKSession().CallTool(c.ctx, &sdk.CallToolParams{
return c.getSDKSession().CallTool(ctx, &sdk.CallToolParams{
Name: p.Name,
Arguments: p.Arguments,
})
})
}

func (c *Connection) readResource(params interface{}) (*Response, error) {
func (c *Connection) readResource(ctx context.Context, params interface{}) (*Response, error) {
type readResourceParams struct {
URI string `json:"uri"`
}
return callParamMethod(c, params, func(p readResourceParams) (interface{}, error) {
logConn.Printf("readResource: reading resource uri=%s from serverID=%s", p.URI, c.serverID)
return c.getSDKSession().ReadResource(c.ctx, &sdk.ReadResourceParams{
return c.getSDKSession().ReadResource(ctx, &sdk.ReadResourceParams{
URI: p.URI,
})
})
}

func (c *Connection) getPrompt(params interface{}) (*Response, error) {
func (c *Connection) getPrompt(ctx context.Context, params interface{}) (*Response, error) {
type getPromptParams struct {
Name string `json:"name"`
Arguments map[string]string `json:"arguments"`
}
return callParamMethod(c, params, func(p getPromptParams) (interface{}, error) {
logConn.Printf("getPrompt: getting prompt name=%s from serverID=%s", p.Name, c.serverID)
return c.getSDKSession().GetPrompt(c.ctx, &sdk.GetPromptParams{
return c.getSDKSession().GetPrompt(ctx, &sdk.GetPromptParams{
Name: p.Name,
Arguments: p.Arguments,
})
Expand Down
20 changes: 10 additions & 10 deletions internal/mcp/reconnect_sdk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ func TestCallSDKMethodWithReconnect_NoReconnectOnNonSessionError(t *testing.T) {
require.Equal(t, HTTPTransportStreamable, conn.httpTransportType)
require.Equal(t, int32(1), initCount.Load(), "expected one initialize during initial connect")

result, err := conn.callSDKMethodWithReconnect("tools/list", nil)
result, err := conn.callSDKMethodWithReconnect(context.Background(), "tools/list", nil)

require.Error(t, err)
assert.Nil(t, result)
Expand Down Expand Up @@ -228,7 +228,7 @@ func TestCallSDKMethodWithReconnect_SessionNotFound_ReconnectFails(t *testing.T)
// returns an error immediately without making network calls.
conn.httpTransportType = HTTPTransportType("none")

result, err := conn.callSDKMethodWithReconnect("tools/list", nil)
result, err := conn.callSDKMethodWithReconnect(context.Background(), "tools/list", nil)

require.Error(t, err)
assert.Nil(t, result)
Expand Down Expand Up @@ -290,7 +290,7 @@ func TestCallSDKMethodWithReconnect_SessionNotFound_ReconnectSucceeds(t *testing
require.Equal(t, HTTPTransportStreamable, conn.httpTransportType)
require.Equal(t, int32(1), initCount.Load(), "should have had exactly one initialize during NewHTTPConnection")

result, err := conn.callSDKMethodWithReconnect("tools/list", nil)
result, err := conn.callSDKMethodWithReconnect(context.Background(), "tools/list", nil)

require.NoError(t, err)
require.NotNil(t, result)
Expand Down Expand Up @@ -333,7 +333,7 @@ func TestCallSDKMethodWithReconnect_Success_NoReconnect(t *testing.T) {
conn := newStreamableConn(t, srv.URL)
require.Equal(t, int32(1), initCount.Load(), "expected one initialize during initial connect")

result, err := conn.callSDKMethodWithReconnect("tools/list", nil)
result, err := conn.callSDKMethodWithReconnect(context.Background(), "tools/list", nil)

require.NoError(t, err)
require.NotNil(t, result)
Expand Down Expand Up @@ -378,7 +378,7 @@ func TestListResources_WithSession(t *testing.T) {

conn := newStreamableConn(t, srv.URL)

result, err := conn.callSDKMethod("resources/list", nil)
result, err := conn.callSDKMethod(context.Background(), "resources/list", nil)

require.NoError(t, err)
require.NotNil(t, result)
Expand Down Expand Up @@ -426,7 +426,7 @@ func TestListPrompts_WithSession(t *testing.T) {

conn := newStreamableConn(t, srv.URL)

result, err := conn.callSDKMethod("prompts/list", nil)
result, err := conn.callSDKMethod(context.Background(), "prompts/list", nil)

require.NoError(t, err)
require.NotNil(t, result)
Expand Down Expand Up @@ -466,7 +466,7 @@ func TestReadResource_WithSession(t *testing.T) {

conn := newStreamableConn(t, srv.URL)

result, err := conn.callSDKMethod("resources/read", map[string]interface{}{"uri": "file:///test.txt"})
result, err := conn.callSDKMethod(context.Background(), "resources/read", map[string]interface{}{"uri": "file:///test.txt"})

require.Error(t, err)
assert.Nil(t, result)
Expand Down Expand Up @@ -497,7 +497,7 @@ func TestGetPrompt_WithSession(t *testing.T) {

conn := newStreamableConn(t, srv.URL)

result, err := conn.callSDKMethod("prompts/get", map[string]interface{}{
result, err := conn.callSDKMethod(context.Background(), "prompts/get", map[string]interface{}{
"name": "summarise",
"arguments": map[string]string{},
})
Expand Down Expand Up @@ -583,7 +583,7 @@ func TestListResources_Empty(t *testing.T) {
defer srv.Close()

conn := newStreamableConn(t, srv.URL)
result, err := conn.callSDKMethod("resources/list", nil)
result, err := conn.callSDKMethod(context.Background(), "resources/list", nil)

require.NoError(t, err)
require.NotNil(t, result)
Expand Down Expand Up @@ -617,7 +617,7 @@ func TestListPrompts_Empty(t *testing.T) {
defer srv.Close()

conn := newStreamableConn(t, srv.URL)
result, err := conn.callSDKMethod("prompts/list", nil)
result, err := conn.callSDKMethod(context.Background(), "prompts/list", nil)

require.NoError(t, err)
require.NotNil(t, result)
Expand Down
14 changes: 7 additions & 7 deletions internal/mcp/sdk_method_dispatch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func TestCallSDKMethod_UnsupportedMethod(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
conn := newTestConnection(t)
result, err := conn.callSDKMethod(tt.method, nil)
result, err := conn.callSDKMethod(context.Background(), tt.method, nil)
require.Error(t, err)
assert.Nil(t, result)
assert.ErrorContains(t, err, "unsupported method")
Expand All @@ -110,7 +110,7 @@ func TestCallSDKMethod_UnsupportedMethod(t *testing.T) {
// switch case is reached and returns a session error when no session is configured.
func TestCallSDKMethod_ResourcesList_NilSession(t *testing.T) {
conn := newTestConnection(t)
result, err := conn.callSDKMethod("resources/list", nil)
result, err := conn.callSDKMethod(context.Background(), "resources/list", nil)
require.Error(t, err)
assert.Nil(t, result)
assert.ErrorContains(t, err, "SDK session not available")
Expand All @@ -120,7 +120,7 @@ func TestCallSDKMethod_ResourcesList_NilSession(t *testing.T) {
// and returns a session error when no session is configured.
func TestCallSDKMethod_ResourcesRead_NilSession(t *testing.T) {
conn := newTestConnection(t)
result, err := conn.callSDKMethod("resources/read", map[string]interface{}{"uri": "file:///test"})
result, err := conn.callSDKMethod(context.Background(), "resources/read", map[string]interface{}{"uri": "file:///test"})
require.Error(t, err)
assert.Nil(t, result)
assert.ErrorContains(t, err, "SDK session not available")
Expand All @@ -130,7 +130,7 @@ func TestCallSDKMethod_ResourcesRead_NilSession(t *testing.T) {
// and returns a session error when no session is configured.
func TestCallSDKMethod_PromptsList_NilSession(t *testing.T) {
conn := newTestConnection(t)
result, err := conn.callSDKMethod("prompts/list", nil)
result, err := conn.callSDKMethod(context.Background(), "prompts/list", nil)
require.Error(t, err)
assert.Nil(t, result)
assert.ErrorContains(t, err, "SDK session not available")
Expand All @@ -140,7 +140,7 @@ func TestCallSDKMethod_PromptsList_NilSession(t *testing.T) {
// and returns a session error when no session is configured.
func TestCallSDKMethod_PromptsGet_NilSession(t *testing.T) {
conn := newTestConnection(t)
result, err := conn.callSDKMethod("prompts/get", map[string]interface{}{"name": "my-prompt"})
result, err := conn.callSDKMethod(context.Background(), "prompts/get", map[string]interface{}{"name": "my-prompt"})
require.Error(t, err)
assert.Nil(t, result)
assert.ErrorContains(t, err, "SDK session not available")
Expand All @@ -151,7 +151,7 @@ func TestCallSDKMethod_PromptsGet_NilSession(t *testing.T) {
// existing HTTP plain-JSON tests that never reach callSDKMethod).
func TestCallSDKMethod_ToolsList_NilSession(t *testing.T) {
conn := newTestConnection(t)
result, err := conn.callSDKMethod("tools/list", nil)
result, err := conn.callSDKMethod(context.Background(), "tools/list", nil)
require.Error(t, err)
assert.Nil(t, result)
assert.ErrorContains(t, err, "SDK session not available")
Expand All @@ -161,7 +161,7 @@ func TestCallSDKMethod_ToolsList_NilSession(t *testing.T) {
// and returns a session error when no session is configured.
func TestCallSDKMethod_ToolsCall_NilSession(t *testing.T) {
conn := newTestConnection(t)
result, err := conn.callSDKMethod("tools/call", map[string]interface{}{"name": "my-tool"})
result, err := conn.callSDKMethod(context.Background(), "tools/call", map[string]interface{}{"name": "my-tool"})
require.Error(t, err)
assert.Nil(t, result)
assert.ErrorContains(t, err, "SDK session not available")
Expand Down
Loading
Loading