Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ func NewMCPServer(provider *provider.ApiProvider, logger *zap.Logger, enabledToo
version.Version,
server.WithLogging(),
server.WithRecovery(),
server.WithToolHandlerMiddleware(buildErrorRecoveryMiddleware(logger)),
server.WithToolHandlerMiddleware(buildLoggerMiddleware(logger)),
server.WithToolHandlerMiddleware(auth.BuildMiddleware(provider.ServerTransport(), logger)),
)
Expand Down Expand Up @@ -402,6 +403,25 @@ func (s *MCPServer) ServeStdio() error {
return err
}

// buildErrorRecoveryMiddleware converts tool handler errors into MCP tool results
// with isError=true, allowing LLMs to see the error and retry with different parameters.
// Without this, errors become JSON-RPC -32603 protocol errors that crash MCP clients.
func buildErrorRecoveryMiddleware(logger *zap.Logger) server.ToolHandlerMiddleware {
return func(next server.ToolHandlerFunc) server.ToolHandlerFunc {
return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
res, err := next(ctx, req)
if err != nil {
logger.Warn("Tool call returned error, converting to isError tool result",
zap.String("tool", req.Params.Name),
zap.Error(err),
)
return mcp.NewToolResultError(err.Error()), nil
}
return res, nil
}
}
}

func buildLoggerMiddleware(logger *zap.Logger) server.ToolHandlerMiddleware {
return func(next server.ToolHandlerFunc) server.ToolHandlerFunc {
return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
Expand Down
109 changes: 109 additions & 0 deletions pkg/server/server_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
package server

import (
"bytes"
"context"
"fmt"
"io"
"os"
"testing"

"github.com/mark3labs/mcp-go/client"
"github.com/mark3labs/mcp-go/client/transport"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
)

func TestShouldAddTool_ReadOnly_EmptyEnabledTools(t *testing.T) {
Expand Down Expand Up @@ -273,6 +283,105 @@ func TestShouldAddTool_WriteTool_Attachment(t *testing.T) {
})
}

// setupMCPClientServer creates an MCP server with the given options and tool handler,
// wires up a client via stdio pipes, and returns the connected client.
func setupMCPClientServer(t *testing.T, opts []server.ServerOption, toolHandler server.ToolHandlerFunc) *client.Client {
t.Helper()

mcpSrv := server.NewMCPServer("test", "1.0.0", opts...)
mcpSrv.AddTool(mcp.NewTool("test_tool",
mcp.WithDescription("A test tool"),
), toolHandler)

serverReader, clientWriter := io.Pipe()
clientReader, serverWriter := io.Pipe()

ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)

stdioSrv := server.NewStdioServer(mcpSrv)
go func() {
_ = stdioSrv.Listen(ctx, serverReader, serverWriter)
}()

var logBuf bytes.Buffer
tr := transport.NewIO(clientReader, clientWriter, io.NopCloser(&logBuf))
err := tr.Start(ctx)
require.NoError(t, err)
t.Cleanup(func() { tr.Close() })

c := client.NewClient(tr)

var initReq mcp.InitializeRequest
initReq.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
_, err = c.Initialize(ctx, initReq)
require.NoError(t, err)

return c
}

func TestIntegrationErrorRecoveryMiddleware(t *testing.T) {
logger := zap.NewNop()

t.Run("handler error is converted to isError tool result", func(t *testing.T) {
c := setupMCPClientServer(t,
[]server.ServerOption{server.WithToolHandlerMiddleware(buildErrorRecoveryMiddleware(logger))},
func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return nil, fmt.Errorf("simulated tool error: invalid channel ID")
},
)

var callReq mcp.CallToolRequest
callReq.Params.Name = "test_tool"
result, err := c.CallTool(context.Background(), callReq)

require.NoError(t, err, "should not return a JSON-RPC error")
require.NotNil(t, result)
assert.True(t, result.IsError, "result should have isError=true")
require.Len(t, result.Content, 1)
textContent, ok := result.Content[0].(mcp.TextContent)
require.True(t, ok, "content should be TextContent")
assert.Contains(t, textContent.Text, "simulated tool error: invalid channel ID")
})

t.Run("without middleware handler error becomes JSON-RPC error", func(t *testing.T) {
c := setupMCPClientServer(t,
nil, // no error recovery middleware
func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return nil, fmt.Errorf("simulated tool error: invalid channel ID")
},
)

var callReq mcp.CallToolRequest
callReq.Params.Name = "test_tool"
result, err := c.CallTool(context.Background(), callReq)

assert.Error(t, err, "should return a JSON-RPC error without middleware")
assert.Nil(t, result)
})

t.Run("successful tool call passes through unchanged", func(t *testing.T) {
c := setupMCPClientServer(t,
[]server.ServerOption{server.WithToolHandlerMiddleware(buildErrorRecoveryMiddleware(logger))},
func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return mcp.NewToolResultText("all good"), nil
},
)

var callReq mcp.CallToolRequest
callReq.Params.Name = "test_tool"
result, err := c.CallTool(context.Background(), callReq)

require.NoError(t, err)
require.NotNil(t, result)
assert.False(t, result.IsError, "successful result should not have isError=true")
require.Len(t, result.Content, 1)
textContent, ok := result.Content[0].(mcp.TextContent)
require.True(t, ok)
assert.Equal(t, "all good", textContent.Text)
})
}

func TestShouldAddTool_Matrix(t *testing.T) {
// Test the complete matrix from the plan:
// | ENABLED_TOOLS | TOOL_ENV_VAR | Result |
Expand Down