diff --git a/pkg/server/server.go b/pkg/server/server.go index e69a2ab7..17ab1588 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -91,6 +91,7 @@ func NewMCPServer(provider *provider.ApiProvider, logger *zap.Logger, enabledToo version.Version, server.WithLogging(), server.WithRecovery(), + server.WithToolHandlerMiddleware(buildErrorRecoveryMiddleware(logger)), server.WithToolHandlerMiddleware(buildLoggerMiddleware(logger)), server.WithToolHandlerMiddleware(auth.BuildMiddleware(provider.ServerTransport(), logger)), ) @@ -402,6 +403,25 @@ func (s *MCPServer) ServeStdio() error { return err } +// buildErrorRecoveryMiddleware converts tool handler errors into MCP tool results +// with isError=true, allowing LLMs to see the error and retry with different parameters. +// Without this, errors become JSON-RPC -32603 protocol errors that crash MCP clients. +func buildErrorRecoveryMiddleware(logger *zap.Logger) server.ToolHandlerMiddleware { + return func(next server.ToolHandlerFunc) server.ToolHandlerFunc { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + res, err := next(ctx, req) + if err != nil { + logger.Warn("Tool call returned error, converting to isError tool result", + zap.String("tool", req.Params.Name), + zap.Error(err), + ) + return mcp.NewToolResultError(err.Error()), nil + } + return res, nil + } + } +} + func buildLoggerMiddleware(logger *zap.Logger) server.ToolHandlerMiddleware { return func(next server.ToolHandlerFunc) server.ToolHandlerFunc { return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go index e0511a83..1e72a2b9 100644 --- a/pkg/server/server_test.go +++ b/pkg/server/server_test.go @@ -1,10 +1,20 @@ package server import ( + "bytes" + "context" + "fmt" + "io" "os" "testing" + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap" ) func TestShouldAddTool_ReadOnly_EmptyEnabledTools(t *testing.T) { @@ -273,6 +283,105 @@ func TestShouldAddTool_WriteTool_Attachment(t *testing.T) { }) } +// setupMCPClientServer creates an MCP server with the given options and tool handler, +// wires up a client via stdio pipes, and returns the connected client. +func setupMCPClientServer(t *testing.T, opts []server.ServerOption, toolHandler server.ToolHandlerFunc) *client.Client { + t.Helper() + + mcpSrv := server.NewMCPServer("test", "1.0.0", opts...) + mcpSrv.AddTool(mcp.NewTool("test_tool", + mcp.WithDescription("A test tool"), + ), toolHandler) + + serverReader, clientWriter := io.Pipe() + clientReader, serverWriter := io.Pipe() + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + stdioSrv := server.NewStdioServer(mcpSrv) + go func() { + _ = stdioSrv.Listen(ctx, serverReader, serverWriter) + }() + + var logBuf bytes.Buffer + tr := transport.NewIO(clientReader, clientWriter, io.NopCloser(&logBuf)) + err := tr.Start(ctx) + require.NoError(t, err) + t.Cleanup(func() { tr.Close() }) + + c := client.NewClient(tr) + + var initReq mcp.InitializeRequest + initReq.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + _, err = c.Initialize(ctx, initReq) + require.NoError(t, err) + + return c +} + +func TestIntegrationErrorRecoveryMiddleware(t *testing.T) { + logger := zap.NewNop() + + t.Run("handler error is converted to isError tool result", func(t *testing.T) { + c := setupMCPClientServer(t, + []server.ServerOption{server.WithToolHandlerMiddleware(buildErrorRecoveryMiddleware(logger))}, + func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return nil, fmt.Errorf("simulated tool error: invalid channel ID") + }, + ) + + var callReq mcp.CallToolRequest + callReq.Params.Name = "test_tool" + result, err := c.CallTool(context.Background(), callReq) + + require.NoError(t, err, "should not return a JSON-RPC error") + require.NotNil(t, result) + assert.True(t, result.IsError, "result should have isError=true") + require.Len(t, result.Content, 1) + textContent, ok := result.Content[0].(mcp.TextContent) + require.True(t, ok, "content should be TextContent") + assert.Contains(t, textContent.Text, "simulated tool error: invalid channel ID") + }) + + t.Run("without middleware handler error becomes JSON-RPC error", func(t *testing.T) { + c := setupMCPClientServer(t, + nil, // no error recovery middleware + func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return nil, fmt.Errorf("simulated tool error: invalid channel ID") + }, + ) + + var callReq mcp.CallToolRequest + callReq.Params.Name = "test_tool" + result, err := c.CallTool(context.Background(), callReq) + + assert.Error(t, err, "should return a JSON-RPC error without middleware") + assert.Nil(t, result) + }) + + t.Run("successful tool call passes through unchanged", func(t *testing.T) { + c := setupMCPClientServer(t, + []server.ServerOption{server.WithToolHandlerMiddleware(buildErrorRecoveryMiddleware(logger))}, + func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return mcp.NewToolResultText("all good"), nil + }, + ) + + var callReq mcp.CallToolRequest + callReq.Params.Name = "test_tool" + result, err := c.CallTool(context.Background(), callReq) + + require.NoError(t, err) + require.NotNil(t, result) + assert.False(t, result.IsError, "successful result should not have isError=true") + require.Len(t, result.Content, 1) + textContent, ok := result.Content[0].(mcp.TextContent) + require.True(t, ok) + assert.Equal(t, "all good", textContent.Text) + }) +} + func TestShouldAddTool_Matrix(t *testing.T) { // Test the complete matrix from the plan: // | ENABLED_TOOLS | TOOL_ENV_VAR | Result |