Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion tool/mcptoolset/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,16 @@ import (
"google.golang.org/adk/tool"
)

// MetadataProvider is a callback function that extracts metadata from the tool context
// to be forwarded to MCP tool calls. The returned map[string]any will be set as the
// Meta field on mcp.CallToolParams.
//
// This allows forwarding request-scoped metadata (e.g., from A2A requests) to downstream
// MCP servers for tracing, authentication, or other purposes.
//
// If the provider returns nil, no metadata is attached to the MCP call.
type MetadataProvider func(ctx tool.Context) map[string]any

// New returns MCP ToolSet.
// MCP ToolSet connects to a MCP Server, retrieves MCP Tools into ADK Tools and
// passes them to the LLM.
Expand Down Expand Up @@ -52,6 +62,7 @@ func New(cfg Config) (tool.Toolset, error) {
toolFilter: cfg.ToolFilter,
requireConfirmation: cfg.RequireConfirmation,
requireConfirmationProvider: cfg.RequireConfirmationProvider,
metadataProvider: cfg.MetadataProvider,
}, nil
}

Expand All @@ -66,6 +77,9 @@ type Config struct {
// If ToolFilter is nil, then all tools are returned.
// tool.StringPredicate can be convenient if there's a known fixed list of tool names.
ToolFilter tool.Predicate
// MetadataProvider is an optional callback that provides metadata to forward
// to MCP tool calls. If nil, no metadata is forwarded.
MetadataProvider MetadataProvider

// RequireConfirmation flags whether the tools from this toolset must always ask for user confirmation
// before execution. If set to true, the ADK framework will automatically initiate
Expand All @@ -90,6 +104,7 @@ type set struct {
toolFilter tool.Predicate
requireConfirmation bool
requireConfirmationProvider ConfirmationProvider
metadataProvider MetadataProvider
}

func (*set) Name() string {
Expand All @@ -113,7 +128,7 @@ func (s *set) Tools(ctx agent.ReadonlyContext) ([]tool.Tool, error) {

var adkTools []tool.Tool
for _, mcpTool := range mcpTools {
t, err := convertTool(mcpTool, s.mcpClient, s.requireConfirmation, s.requireConfirmationProvider)
t, err := convertTool(mcpTool, s.mcpClient, s.requireConfirmation, s.requireConfirmationProvider, s.metadataProvider)
if err != nil {
return nil, fmt.Errorf("failed to convert MCP tool %q to adk tool: %w", mcpTool.Name, err)
}
Expand Down
97 changes: 97 additions & 0 deletions tool/mcptoolset/set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -790,3 +790,100 @@ func TestNewToolSet_RequireConfirmationProvider_Validation(t *testing.T) {
})
}
}

func TestMetadataProvider(t *testing.T) {
testMetadata := map[string]any{
"request_id": "test-123",
"user_id": "user-456",
"nested_data": map[string]any{"key": "value"},
}

testCases := []struct {
name string
provider mcptoolset.MetadataProvider
wantMetadata map[string]any
}{
{
name: "provider returns metadata",
provider: func(ctx tool.Context) map[string]any {
return testMetadata
},
wantMetadata: testMetadata,
},
{
name: "provider is nil",
provider: nil,
wantMetadata: nil,
},
{
name: "provider returns nil",
provider: func(ctx tool.Context) map[string]any {
return nil
},
wantMetadata: nil,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var receivedMeta map[string]any
var toolCalled bool
echoToolFunc := func(ctx context.Context, req *mcp.CallToolRequest, input struct{}) (*mcp.CallToolResult, struct{ Message string }, error) {
toolCalled = true
receivedMeta = req.Params.Meta
return nil, struct{ Message string }{Message: "ok"}, nil
}

runMetadataTest(t, tc.provider, echoToolFunc)
if !toolCalled {
t.Fatal("Tool was not called")
}

if diff := cmp.Diff(tc.wantMetadata, receivedMeta); diff != "" {
t.Errorf("metadata mismatch (-want +got):\n%s", diff)
}
})
}
}

func runMetadataTest[In, Out any](t *testing.T, provider mcptoolset.MetadataProvider, toolFunc mcp.ToolHandlerFor[In, Out]) {
t.Helper()

clientTransport, serverTransport := mcp.NewInMemoryTransports()
server := mcp.NewServer(&mcp.Implementation{Name: "test_server", Version: "v1.0.0"}, nil)
mcp.AddTool(server, &mcp.Tool{Name: "echo_tool", Description: "echoes input"}, toolFunc)
_, err := server.Connect(t.Context(), serverTransport, nil)
if err != nil {
t.Fatal(err)
}

ts, err := mcptoolset.New(mcptoolset.Config{
Transport: clientTransport,
MetadataProvider: provider,
})
if err != nil {
t.Fatalf("Failed to create MCP tool set: %v", err)
}

invCtx := icontext.NewInvocationContext(t.Context(), icontext.InvocationContextParams{})
readonlyCtx := icontext.NewReadonlyContext(invCtx)
tools, err := ts.Tools(readonlyCtx)
if err != nil {
t.Fatalf("Failed to get tools: %v", err)
}

if len(tools) != 1 {
t.Fatalf("Expected 1 tool, got %d", len(tools))
}

fnTool, ok := tools[0].(toolinternal.FunctionTool)
if !ok {
t.Fatal("Tool does not implement FunctionTool interface")
}

toolCtx := toolinternal.NewToolContext(invCtx, "", nil, nil)
_, err = fnTool.Run(toolCtx, map[string]any{})
if err != nil {
t.Fatalf("Failed to run tool: %v", err)
}
}
26 changes: 20 additions & 6 deletions tool/mcptoolset/tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package mcptoolset
import (
"errors"
"fmt"
"runtime/debug"
"strings"

"github.com/modelcontextprotocol/go-sdk/mcp"
Expand All @@ -28,7 +29,7 @@ import (
"google.golang.org/adk/tool"
)

func convertTool(t *mcp.Tool, client MCPClient, requireConfirmation bool, requireConfirmationProvider ConfirmationProvider) (tool.Tool, error) {
func convertTool(t *mcp.Tool, client MCPClient, requireConfirmation bool, requireConfirmationProvider ConfirmationProvider, metadataProvider MetadataProvider) (tool.Tool, error) {
mcp := &mcpTool{
name: t.Name,
description: t.Description,
Expand All @@ -39,6 +40,7 @@ func convertTool(t *mcp.Tool, client MCPClient, requireConfirmation bool, requir
mcpClient: client,
requireConfirmation: requireConfirmation,
requireConfirmationProvider: requireConfirmationProvider,
metadataProvider: metadataProvider,
}

// Since t.InputSchema and t.OutputSchema are pointers (*jsonschema.Schema) and the destination ResponseJsonSchema
Expand All @@ -65,6 +67,8 @@ type mcpTool struct {
requireConfirmation bool

requireConfirmationProvider ConfirmationProvider

metadataProvider MetadataProvider
}

// Name implements the tool.Tool.
Expand All @@ -90,7 +94,13 @@ func (t *mcpTool) Declaration() *genai.FunctionDeclaration {
return t.funcDeclaration
}

func (t *mcpTool) Run(ctx tool.Context, args any) (map[string]any, error) {
func (t *mcpTool) Run(ctx tool.Context, args any) (result map[string]any, err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic in tool %q: %v\nstack: %s", t.Name(), r, debug.Stack())
}
}()
Comment on lines +98 to +102
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The panic recovery is a great addition for robustness. To make these panic errors more programmatically accessible and inspectable by callers, consider wrapping the panic information in a dedicated structured error type instead of a formatted string. This allows consumers of the toolset to type-assert the error and handle tool panics differently from other errors, for example, by logging the stack trace separately or extracting panic details for monitoring.

For example, you could define a custom error type:

// ToolPanicError represents an error caused by a panic within a tool's execution.
type ToolPanicError struct {
	ToolName   string
	PanicValue any
	Stack      []byte
}

func (e *ToolPanicError) Error() string {
	return fmt.Sprintf("panic in tool %q: %v", e.ToolName, e.PanicValue)
}

And then use it in the defer block:

ddefer func() {
	if r := recover(); r != nil {
		err = &ToolPanicError{
			ToolName:   t.Name(),
			PanicValue: r,
			Stack:      debug.Stack(),
		}
	}
}()

This makes the error much more useful for upstream callers.


if confirmation := ctx.ToolConfirmation(); confirmation != nil {
if !confirmation.Confirmed {
return nil, fmt.Errorf("error tool %q call is rejected", t.Name())
Expand All @@ -115,12 +125,16 @@ func (t *mcpTool) Run(ctx tool.Context, args any) (map[string]any, error) {
return nil, fmt.Errorf("error tool %q requires confirmation, please approve or reject", t.Name())
}
}

// TODO: add auth
res, err := t.mcpClient.CallTool(ctx, &mcp.CallToolParams{
params := &mcp.CallToolParams{
Name: t.name,
Arguments: args,
})
}

if t.metadataProvider != nil {
params.Meta = t.metadataProvider(ctx)
}
// TODO: add auth
res, err := t.mcpClient.CallTool(ctx, params)
if err != nil {
return nil, fmt.Errorf("failed to call MCP tool %q with err: %w", t.name, err)
}
Expand Down