diff --git a/tool/mcptoolset/set.go b/tool/mcptoolset/set.go index 53e424f55..c4304ced5 100644 --- a/tool/mcptoolset/set.go +++ b/tool/mcptoolset/set.go @@ -24,6 +24,16 @@ import ( "google.golang.org/adk/tool" ) +// MetadataProvider is a callback function that extracts metadata from the tool context +// to be forwarded to MCP tool calls. The returned map[string]any will be set as the +// Meta field on mcp.CallToolParams. +// +// This allows forwarding request-scoped metadata (e.g., from A2A requests) to downstream +// MCP servers for tracing, authentication, or other purposes. +// +// If the provider returns nil, no metadata is attached to the MCP call. +type MetadataProvider func(ctx tool.Context) map[string]any + // New returns MCP ToolSet. // MCP ToolSet connects to a MCP Server, retrieves MCP Tools into ADK Tools and // passes them to the LLM. @@ -52,6 +62,7 @@ func New(cfg Config) (tool.Toolset, error) { toolFilter: cfg.ToolFilter, requireConfirmation: cfg.RequireConfirmation, requireConfirmationProvider: cfg.RequireConfirmationProvider, + metadataProvider: cfg.MetadataProvider, }, nil } @@ -66,6 +77,9 @@ type Config struct { // If ToolFilter is nil, then all tools are returned. // tool.StringPredicate can be convenient if there's a known fixed list of tool names. ToolFilter tool.Predicate + // MetadataProvider is an optional callback that provides metadata to forward + // to MCP tool calls. If nil, no metadata is forwarded. + MetadataProvider MetadataProvider // RequireConfirmation flags whether the tools from this toolset must always ask for user confirmation // before execution. If set to true, the ADK framework will automatically initiate @@ -90,6 +104,7 @@ type set struct { toolFilter tool.Predicate requireConfirmation bool requireConfirmationProvider ConfirmationProvider + metadataProvider MetadataProvider } func (*set) Name() string { @@ -113,7 +128,7 @@ func (s *set) Tools(ctx agent.ReadonlyContext) ([]tool.Tool, error) { var adkTools []tool.Tool for _, mcpTool := range mcpTools { - t, err := convertTool(mcpTool, s.mcpClient, s.requireConfirmation, s.requireConfirmationProvider) + t, err := convertTool(mcpTool, s.mcpClient, s.requireConfirmation, s.requireConfirmationProvider, s.metadataProvider) if err != nil { return nil, fmt.Errorf("failed to convert MCP tool %q to adk tool: %w", mcpTool.Name, err) } diff --git a/tool/mcptoolset/set_test.go b/tool/mcptoolset/set_test.go index 0b4c1cffd..56821b043 100644 --- a/tool/mcptoolset/set_test.go +++ b/tool/mcptoolset/set_test.go @@ -790,3 +790,100 @@ func TestNewToolSet_RequireConfirmationProvider_Validation(t *testing.T) { }) } } + +func TestMetadataProvider(t *testing.T) { + testMetadata := map[string]any{ + "request_id": "test-123", + "user_id": "user-456", + "nested_data": map[string]any{"key": "value"}, + } + + testCases := []struct { + name string + provider mcptoolset.MetadataProvider + wantMetadata map[string]any + }{ + { + name: "provider returns metadata", + provider: func(ctx tool.Context) map[string]any { + return testMetadata + }, + wantMetadata: testMetadata, + }, + { + name: "provider is nil", + provider: nil, + wantMetadata: nil, + }, + { + name: "provider returns nil", + provider: func(ctx tool.Context) map[string]any { + return nil + }, + wantMetadata: nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var receivedMeta map[string]any + var toolCalled bool + echoToolFunc := func(ctx context.Context, req *mcp.CallToolRequest, input struct{}) (*mcp.CallToolResult, struct{ Message string }, error) { + toolCalled = true + receivedMeta = req.Params.Meta + return nil, struct{ Message string }{Message: "ok"}, nil + } + + runMetadataTest(t, tc.provider, echoToolFunc) + if !toolCalled { + t.Fatal("Tool was not called") + } + + if diff := cmp.Diff(tc.wantMetadata, receivedMeta); diff != "" { + t.Errorf("metadata mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func runMetadataTest[In, Out any](t *testing.T, provider mcptoolset.MetadataProvider, toolFunc mcp.ToolHandlerFor[In, Out]) { + t.Helper() + + clientTransport, serverTransport := mcp.NewInMemoryTransports() + server := mcp.NewServer(&mcp.Implementation{Name: "test_server", Version: "v1.0.0"}, nil) + mcp.AddTool(server, &mcp.Tool{Name: "echo_tool", Description: "echoes input"}, toolFunc) + _, err := server.Connect(t.Context(), serverTransport, nil) + if err != nil { + t.Fatal(err) + } + + ts, err := mcptoolset.New(mcptoolset.Config{ + Transport: clientTransport, + MetadataProvider: provider, + }) + if err != nil { + t.Fatalf("Failed to create MCP tool set: %v", err) + } + + invCtx := icontext.NewInvocationContext(t.Context(), icontext.InvocationContextParams{}) + readonlyCtx := icontext.NewReadonlyContext(invCtx) + tools, err := ts.Tools(readonlyCtx) + if err != nil { + t.Fatalf("Failed to get tools: %v", err) + } + + if len(tools) != 1 { + t.Fatalf("Expected 1 tool, got %d", len(tools)) + } + + fnTool, ok := tools[0].(toolinternal.FunctionTool) + if !ok { + t.Fatal("Tool does not implement FunctionTool interface") + } + + toolCtx := toolinternal.NewToolContext(invCtx, "", nil, nil) + _, err = fnTool.Run(toolCtx, map[string]any{}) + if err != nil { + t.Fatalf("Failed to run tool: %v", err) + } +} diff --git a/tool/mcptoolset/tool.go b/tool/mcptoolset/tool.go index 640c0cfc3..d54e73c70 100644 --- a/tool/mcptoolset/tool.go +++ b/tool/mcptoolset/tool.go @@ -17,6 +17,7 @@ package mcptoolset import ( "errors" "fmt" + "runtime/debug" "strings" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -28,7 +29,7 @@ import ( "google.golang.org/adk/tool" ) -func convertTool(t *mcp.Tool, client MCPClient, requireConfirmation bool, requireConfirmationProvider ConfirmationProvider) (tool.Tool, error) { +func convertTool(t *mcp.Tool, client MCPClient, requireConfirmation bool, requireConfirmationProvider ConfirmationProvider, metadataProvider MetadataProvider) (tool.Tool, error) { mcp := &mcpTool{ name: t.Name, description: t.Description, @@ -39,6 +40,7 @@ func convertTool(t *mcp.Tool, client MCPClient, requireConfirmation bool, requir mcpClient: client, requireConfirmation: requireConfirmation, requireConfirmationProvider: requireConfirmationProvider, + metadataProvider: metadataProvider, } // Since t.InputSchema and t.OutputSchema are pointers (*jsonschema.Schema) and the destination ResponseJsonSchema @@ -65,6 +67,8 @@ type mcpTool struct { requireConfirmation bool requireConfirmationProvider ConfirmationProvider + + metadataProvider MetadataProvider } // Name implements the tool.Tool. @@ -90,7 +94,13 @@ func (t *mcpTool) Declaration() *genai.FunctionDeclaration { return t.funcDeclaration } -func (t *mcpTool) Run(ctx tool.Context, args any) (map[string]any, error) { +func (t *mcpTool) Run(ctx tool.Context, args any) (result map[string]any, err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("panic in tool %q: %v\nstack: %s", t.Name(), r, debug.Stack()) + } + }() + if confirmation := ctx.ToolConfirmation(); confirmation != nil { if !confirmation.Confirmed { return nil, fmt.Errorf("error tool %q call is rejected", t.Name()) @@ -115,12 +125,16 @@ func (t *mcpTool) Run(ctx tool.Context, args any) (map[string]any, error) { return nil, fmt.Errorf("error tool %q requires confirmation, please approve or reject", t.Name()) } } - - // TODO: add auth - res, err := t.mcpClient.CallTool(ctx, &mcp.CallToolParams{ + params := &mcp.CallToolParams{ Name: t.name, Arguments: args, - }) + } + + if t.metadataProvider != nil { + params.Meta = t.metadataProvider(ctx) + } + // TODO: add auth + res, err := t.mcpClient.CallTool(ctx, params) if err != nil { return nil, fmt.Errorf("failed to call MCP tool %q with err: %w", t.name, err) }