diff --git a/tool/mcptoolset/set.go b/tool/mcptoolset/set.go index e27ecdb69..f35cb9785 100644 --- a/tool/mcptoolset/set.go +++ b/tool/mcptoolset/set.go @@ -27,6 +27,16 @@ import ( "google.golang.org/adk/tool" ) +// MetadataProvider is a callback function that extracts metadata from the tool context +// to be forwarded to MCP tool calls. The returned map[string]any will be set as the +// Meta field on mcp.CallToolParams. +// +// This allows forwarding request-scoped metadata (e.g., from A2A requests) to downstream +// MCP servers for tracing, authentication, or other purposes. +// +// If the provider returns nil, no metadata is attached to the MCP call. +type MetadataProvider func(ctx tool.Context) map[string]any + // New returns MCP ToolSet. // MCP ToolSet connects to a MCP Server, retrieves MCP Tools into ADK Tools and // passes them to the LLM. @@ -55,9 +65,10 @@ func New(cfg Config) (tool.Toolset, error) { client = mcp.NewClient(&mcp.Implementation{Name: "adk-mcp-client", Version: version.Version}, nil) } return &set{ - client: client, - transport: cfg.Transport, - toolFilter: cfg.ToolFilter, + client: client, + transport: cfg.Transport, + toolFilter: cfg.ToolFilter, + metadataProvider: cfg.MetadataProvider, }, nil } @@ -71,12 +82,16 @@ type Config struct { // If ToolFilter is nil, then all tools are returned. // tool.StringPredicate can be convenient if there's a known fixed list of tool names. ToolFilter tool.Predicate + // MetadataProvider is an optional callback that provides metadata to forward + // to MCP tool calls. If nil, no metadata is forwarded. + MetadataProvider MetadataProvider } type set struct { - client *mcp.Client - transport mcp.Transport - toolFilter tool.Predicate + client *mcp.Client + transport mcp.Transport + toolFilter tool.Predicate + metadataProvider MetadataProvider mu sync.Mutex session *mcp.ClientSession @@ -113,7 +128,7 @@ func (s *set) Tools(ctx agent.ReadonlyContext) ([]tool.Tool, error) { } for _, mcpTool := range resp.Tools { - t, err := convertTool(mcpTool, s.getSession) + t, err := convertTool(mcpTool, s.getSession, s.metadataProvider) if err != nil { return nil, fmt.Errorf("failed to convert MCP tool %q to adk tool: %w", mcpTool.Name, err) } diff --git a/tool/mcptoolset/set_test.go b/tool/mcptoolset/set_test.go index 87f14b7c7..94f418fd1 100644 --- a/tool/mcptoolset/set_test.go +++ b/tool/mcptoolset/set_test.go @@ -34,6 +34,7 @@ import ( icontext "google.golang.org/adk/internal/context" "google.golang.org/adk/internal/httprr" "google.golang.org/adk/internal/testutil" + "google.golang.org/adk/internal/toolinternal" "google.golang.org/adk/model" "google.golang.org/adk/model/gemini" "google.golang.org/adk/runner" @@ -307,3 +308,170 @@ func TestToolFilter(t *testing.T) { t.Errorf("tools mismatch (-want +got):\n%s", diff) } } + +func TestMetadataProvider(t *testing.T) { + var receivedMeta map[string]any + + echoToolFunc := func(ctx context.Context, req *mcp.CallToolRequest, input struct{}) (*mcp.CallToolResult, struct{ Message string }, error) { + receivedMeta = req.Params.Meta + return nil, struct{ Message string }{Message: "ok"}, nil + } + + clientTransport, serverTransport := mcp.NewInMemoryTransports() + + server := mcp.NewServer(&mcp.Implementation{Name: "test_server", Version: "v1.0.0"}, nil) + mcp.AddTool(server, &mcp.Tool{Name: "echo_tool", Description: "echoes input"}, echoToolFunc) + _, err := server.Connect(t.Context(), serverTransport, nil) + if err != nil { + t.Fatal(err) + } + + testMetadata := map[string]any{ + "request_id": "test-123", + "user_id": "user-456", + "nested_data": map[string]any{"key": "value"}, + } + metadataProvider := func(ctx tool.Context) map[string]any { + return testMetadata + } + + ts, err := mcptoolset.New(mcptoolset.Config{ + Transport: clientTransport, + MetadataProvider: metadataProvider, + }) + if err != nil { + t.Fatalf("Failed to create MCP tool set: %v", err) + } + + invCtx := icontext.NewInvocationContext(t.Context(), icontext.InvocationContextParams{}) + readonlyCtx := icontext.NewReadonlyContext(invCtx) + tools, err := ts.Tools(readonlyCtx) + if err != nil { + t.Fatalf("Failed to get tools: %v", err) + } + + if len(tools) != 1 { + t.Fatalf("Expected 1 tool, got %d", len(tools)) + } + + fnTool, ok := tools[0].(toolinternal.FunctionTool) + if !ok { + t.Fatal("Tool does not implement FunctionTool interface") + } + + toolCtx := toolinternal.NewToolContext(invCtx, "", nil) + result, err := fnTool.Run(toolCtx, map[string]any{}) + if err != nil { + t.Fatalf("Failed to run tool: %v", err) + } + + if result == nil { + t.Fatal("Expected non-nil result") + } + + if diff := cmp.Diff(testMetadata, receivedMeta); diff != "" { + t.Errorf("metadata mismatch (-want +got):\n%s", diff) + } +} + +func TestMetadataProviderNil(t *testing.T) { + echoToolFunc := func(ctx context.Context, req *mcp.CallToolRequest, input struct{}) (*mcp.CallToolResult, struct{ Message string }, error) { + if req.Params.Meta != nil { + t.Errorf("Expected nil metadata, got %v", req.Params.Meta) + } + return nil, struct{ Message string }{Message: "ok"}, nil + } + + clientTransport, serverTransport := mcp.NewInMemoryTransports() + + server := mcp.NewServer(&mcp.Implementation{Name: "test_server", Version: "v1.0.0"}, nil) + mcp.AddTool(server, &mcp.Tool{Name: "echo_tool", Description: "echoes input"}, echoToolFunc) + _, err := server.Connect(t.Context(), serverTransport, nil) + if err != nil { + t.Fatal(err) + } + + ts, err := mcptoolset.New(mcptoolset.Config{ + Transport: clientTransport, + // MetadataProvider is nil + }) + if err != nil { + t.Fatalf("Failed to create MCP tool set: %v", err) + } + + invCtx := icontext.NewInvocationContext(t.Context(), icontext.InvocationContextParams{}) + readonlyCtx := icontext.NewReadonlyContext(invCtx) + tools, err := ts.Tools(readonlyCtx) + if err != nil { + t.Fatalf("Failed to get tools: %v", err) + } + + fnTool, ok := tools[0].(toolinternal.FunctionTool) + if !ok { + t.Fatal("Tool does not implement FunctionTool interface") + } + + toolCtx := toolinternal.NewToolContext(invCtx, "", nil) + _, err = fnTool.Run(toolCtx, map[string]any{}) + if err != nil { + t.Fatalf("Failed to run tool: %v", err) + } +} + +func TestMetadataProviderReturnsNil(t *testing.T) { + var receivedMeta map[string]any + var metaCalled bool + + echoToolFunc := func(ctx context.Context, req *mcp.CallToolRequest, input struct{}) (*mcp.CallToolResult, struct{ Message string }, error) { + metaCalled = true + receivedMeta = req.Params.Meta + return nil, struct{ Message string }{Message: "ok"}, nil + } + + clientTransport, serverTransport := mcp.NewInMemoryTransports() + + server := mcp.NewServer(&mcp.Implementation{Name: "test_server", Version: "v1.0.0"}, nil) + mcp.AddTool(server, &mcp.Tool{Name: "echo_tool", Description: "echoes input"}, echoToolFunc) + _, err := server.Connect(t.Context(), serverTransport, nil) + if err != nil { + t.Fatal(err) + } + + metadataProvider := func(ctx tool.Context) map[string]any { + return nil + } + + ts, err := mcptoolset.New(mcptoolset.Config{ + Transport: clientTransport, + MetadataProvider: metadataProvider, + }) + if err != nil { + t.Fatalf("Failed to create MCP tool set: %v", err) + } + + invCtx := icontext.NewInvocationContext(t.Context(), icontext.InvocationContextParams{}) + readonlyCtx := icontext.NewReadonlyContext(invCtx) + tools, err := ts.Tools(readonlyCtx) + if err != nil { + t.Fatalf("Failed to get tools: %v", err) + } + + fnTool, ok := tools[0].(toolinternal.FunctionTool) + if !ok { + t.Fatal("Tool does not implement FunctionTool interface") + } + + toolCtx := toolinternal.NewToolContext(invCtx, "", nil) + _, err = fnTool.Run(toolCtx, map[string]any{}) + if err != nil { + t.Fatalf("Failed to run tool: %v", err) + } + + if !metaCalled { + t.Fatal("Tool was not called") + } + + if receivedMeta != nil { + t.Errorf("Expected nil metadata when provider returns nil, got %v", receivedMeta) + } +} diff --git a/tool/mcptoolset/tool.go b/tool/mcptoolset/tool.go index ee2354f09..eaca2e6aa 100644 --- a/tool/mcptoolset/tool.go +++ b/tool/mcptoolset/tool.go @@ -31,7 +31,7 @@ import ( type getSessionFunc func(ctx context.Context) (*mcp.ClientSession, error) -func convertTool(t *mcp.Tool, getSessionFunc getSessionFunc) (tool.Tool, error) { +func convertTool(t *mcp.Tool, getSessionFunc getSessionFunc, metadataProvider MetadataProvider) (tool.Tool, error) { mcp := &mcpTool{ name: t.Name, description: t.Description, @@ -39,7 +39,8 @@ func convertTool(t *mcp.Tool, getSessionFunc getSessionFunc) (tool.Tool, error) Name: t.Name, Description: t.Description, }, - getSessionFunc: getSessionFunc, + getSessionFunc: getSessionFunc, + metadataProvider: metadataProvider, } // Since t.InputSchema and t.OutputSchema are pointers (*jsonschema.Schema) and the destination ResponseJsonSchema @@ -61,7 +62,8 @@ type mcpTool struct { description string funcDeclaration *genai.FunctionDeclaration - getSessionFunc getSessionFunc + getSessionFunc getSessionFunc + metadataProvider MetadataProvider } // Name implements the tool.Tool. @@ -93,11 +95,19 @@ func (t *mcpTool) Run(ctx tool.Context, args any) (map[string]any, error) { return nil, fmt.Errorf("failed to get session: %w", err) } - // TODO: add auth - res, err := session.CallTool(ctx, &mcp.CallToolParams{ + params := &mcp.CallToolParams{ Name: t.name, Arguments: args, - }) + } + + if t.metadataProvider != nil { + if meta := t.metadataProvider(ctx); meta != nil { + params.Meta = mcp.Meta(meta) + } + } + + // TODO: add auth + res, err := session.CallTool(ctx, params) if err != nil { return nil, fmt.Errorf("failed to call MCP tool %q with err: %w", t.name, err) }