Skip to content
Open
29 changes: 22 additions & 7 deletions tool/mcptoolset/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,16 @@ import (
"google.golang.org/adk/tool"
)

// MetadataProvider is a callback function that extracts metadata from the tool context
// to be forwarded to MCP tool calls. The returned map[string]any will be set as the
// Meta field on mcp.CallToolParams.
//
// This allows forwarding request-scoped metadata (e.g., from A2A requests) to downstream
// MCP servers for tracing, authentication, or other purposes.
//
// If the provider returns nil, no metadata is attached to the MCP call.
type MetadataProvider func(ctx tool.Context) map[string]any

// New returns MCP ToolSet.
// MCP ToolSet connects to a MCP Server, retrieves MCP Tools into ADK Tools and
// passes them to the LLM.
Expand Down Expand Up @@ -55,9 +65,10 @@ func New(cfg Config) (tool.Toolset, error) {
client = mcp.NewClient(&mcp.Implementation{Name: "adk-mcp-client", Version: version.Version}, nil)
}
return &set{
client: client,
transport: cfg.Transport,
toolFilter: cfg.ToolFilter,
client: client,
transport: cfg.Transport,
toolFilter: cfg.ToolFilter,
metadataProvider: cfg.MetadataProvider,
}, nil
}

Expand All @@ -71,12 +82,16 @@ type Config struct {
// If ToolFilter is nil, then all tools are returned.
// tool.StringPredicate can be convenient if there's a known fixed list of tool names.
ToolFilter tool.Predicate
// MetadataProvider is an optional callback that provides metadata to forward
// to MCP tool calls. If nil, no metadata is forwarded.
MetadataProvider MetadataProvider
}

type set struct {
client *mcp.Client
transport mcp.Transport
toolFilter tool.Predicate
client *mcp.Client
transport mcp.Transport
toolFilter tool.Predicate
metadataProvider MetadataProvider

mu sync.Mutex
session *mcp.ClientSession
Expand Down Expand Up @@ -113,7 +128,7 @@ func (s *set) Tools(ctx agent.ReadonlyContext) ([]tool.Tool, error) {
}

for _, mcpTool := range resp.Tools {
t, err := convertTool(mcpTool, s.getSession)
t, err := convertTool(mcpTool, s.getSession, s.metadataProvider)
if err != nil {
return nil, fmt.Errorf("failed to convert MCP tool %q to adk tool: %w", mcpTool.Name, err)
}
Expand Down
168 changes: 168 additions & 0 deletions tool/mcptoolset/set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
icontext "google.golang.org/adk/internal/context"
"google.golang.org/adk/internal/httprr"
"google.golang.org/adk/internal/testutil"
"google.golang.org/adk/internal/toolinternal"
"google.golang.org/adk/model"
"google.golang.org/adk/model/gemini"
"google.golang.org/adk/runner"
Expand Down Expand Up @@ -307,3 +308,170 @@ func TestToolFilter(t *testing.T) {
t.Errorf("tools mismatch (-want +got):\n%s", diff)
}
}

func TestMetadataProvider(t *testing.T) {
var receivedMeta map[string]any

echoToolFunc := func(ctx context.Context, req *mcp.CallToolRequest, input struct{}) (*mcp.CallToolResult, struct{ Message string }, error) {
receivedMeta = req.Params.Meta
return nil, struct{ Message string }{Message: "ok"}, nil
}

clientTransport, serverTransport := mcp.NewInMemoryTransports()

server := mcp.NewServer(&mcp.Implementation{Name: "test_server", Version: "v1.0.0"}, nil)
mcp.AddTool(server, &mcp.Tool{Name: "echo_tool", Description: "echoes input"}, echoToolFunc)
_, err := server.Connect(t.Context(), serverTransport, nil)
if err != nil {
t.Fatal(err)
}

testMetadata := map[string]any{
"request_id": "test-123",
"user_id": "user-456",
"nested_data": map[string]any{"key": "value"},
}
metadataProvider := func(ctx tool.Context) map[string]any {
return testMetadata
}

ts, err := mcptoolset.New(mcptoolset.Config{
Transport: clientTransport,
MetadataProvider: metadataProvider,
})
if err != nil {
t.Fatalf("Failed to create MCP tool set: %v", err)
}

invCtx := icontext.NewInvocationContext(t.Context(), icontext.InvocationContextParams{})
readonlyCtx := icontext.NewReadonlyContext(invCtx)
tools, err := ts.Tools(readonlyCtx)
if err != nil {
t.Fatalf("Failed to get tools: %v", err)
}

if len(tools) != 1 {
t.Fatalf("Expected 1 tool, got %d", len(tools))
}

fnTool, ok := tools[0].(toolinternal.FunctionTool)
if !ok {
t.Fatal("Tool does not implement FunctionTool interface")
}

toolCtx := toolinternal.NewToolContext(invCtx, "", nil)
result, err := fnTool.Run(toolCtx, map[string]any{})
if err != nil {
t.Fatalf("Failed to run tool: %v", err)
}

if result == nil {
t.Fatal("Expected non-nil result")
}

if diff := cmp.Diff(testMetadata, receivedMeta); diff != "" {
t.Errorf("metadata mismatch (-want +got):\n%s", diff)
}
}

func TestMetadataProviderNil(t *testing.T) {
echoToolFunc := func(ctx context.Context, req *mcp.CallToolRequest, input struct{}) (*mcp.CallToolResult, struct{ Message string }, error) {
if req.Params.Meta != nil {
t.Errorf("Expected nil metadata, got %v", req.Params.Meta)
}
return nil, struct{ Message string }{Message: "ok"}, nil
}

clientTransport, serverTransport := mcp.NewInMemoryTransports()

server := mcp.NewServer(&mcp.Implementation{Name: "test_server", Version: "v1.0.0"}, nil)
mcp.AddTool(server, &mcp.Tool{Name: "echo_tool", Description: "echoes input"}, echoToolFunc)
_, err := server.Connect(t.Context(), serverTransport, nil)
if err != nil {
t.Fatal(err)
}

ts, err := mcptoolset.New(mcptoolset.Config{
Transport: clientTransport,
// MetadataProvider is nil
})
if err != nil {
t.Fatalf("Failed to create MCP tool set: %v", err)
}

invCtx := icontext.NewInvocationContext(t.Context(), icontext.InvocationContextParams{})
readonlyCtx := icontext.NewReadonlyContext(invCtx)
tools, err := ts.Tools(readonlyCtx)
if err != nil {
t.Fatalf("Failed to get tools: %v", err)
}

fnTool, ok := tools[0].(toolinternal.FunctionTool)
if !ok {
t.Fatal("Tool does not implement FunctionTool interface")
}

toolCtx := toolinternal.NewToolContext(invCtx, "", nil)
_, err = fnTool.Run(toolCtx, map[string]any{})
if err != nil {
t.Fatalf("Failed to run tool: %v", err)
}
}

func TestMetadataProviderReturnsNil(t *testing.T) {
var receivedMeta map[string]any
var metaCalled bool

echoToolFunc := func(ctx context.Context, req *mcp.CallToolRequest, input struct{}) (*mcp.CallToolResult, struct{ Message string }, error) {
metaCalled = true
receivedMeta = req.Params.Meta
return nil, struct{ Message string }{Message: "ok"}, nil
}

clientTransport, serverTransport := mcp.NewInMemoryTransports()

server := mcp.NewServer(&mcp.Implementation{Name: "test_server", Version: "v1.0.0"}, nil)
mcp.AddTool(server, &mcp.Tool{Name: "echo_tool", Description: "echoes input"}, echoToolFunc)
_, err := server.Connect(t.Context(), serverTransport, nil)
if err != nil {
t.Fatal(err)
}

metadataProvider := func(ctx tool.Context) map[string]any {
return nil
}

ts, err := mcptoolset.New(mcptoolset.Config{
Transport: clientTransport,
MetadataProvider: metadataProvider,
})
if err != nil {
t.Fatalf("Failed to create MCP tool set: %v", err)
}

invCtx := icontext.NewInvocationContext(t.Context(), icontext.InvocationContextParams{})
readonlyCtx := icontext.NewReadonlyContext(invCtx)
tools, err := ts.Tools(readonlyCtx)
if err != nil {
t.Fatalf("Failed to get tools: %v", err)
}

fnTool, ok := tools[0].(toolinternal.FunctionTool)
if !ok {
t.Fatal("Tool does not implement FunctionTool interface")
}

toolCtx := toolinternal.NewToolContext(invCtx, "", nil)
_, err = fnTool.Run(toolCtx, map[string]any{})
if err != nil {
t.Fatalf("Failed to run tool: %v", err)
}

if !metaCalled {
t.Fatal("Tool was not called")
}

if receivedMeta != nil {
t.Errorf("Expected nil metadata when provider returns nil, got %v", receivedMeta)
}
}
Comment on lines +312 to +477
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The three new tests TestMetadataProvider, TestMetadataProviderNil, and TestMetadataProviderReturnsNil contain a significant amount of duplicated setup and execution logic. This can make the tests harder to maintain.

Consider refactoring the common code into a test helper function. This helper could handle setting up the MCP server, creating the toolset, and running the tool. The individual test cases would then just provide the specific metadataProvider and echoToolFunc.

For example, you could have a helper like this:

func runMetadataTest(t *testing.T, provider mcptoolset.MetadataProvider, toolFunc any) {
	t.Helper()

	clientTransport, serverTransport := mcp.NewInMemoryTransports()
	server := mcp.NewServer(&mcp.Implementation{Name: "test_server", Version: "v1.0.0"}, nil)
	mcp.AddTool(server, &mcp.Tool{Name: "echo_tool", Description: "echoes input"}, toolFunc)
	_, err := server.Connect(t.Context(), serverTransport, nil)
	if err != nil {
		t.Fatal(err)
	}

	ts, err := mcptoolset.New(mcptoolset.Config{
		Transport:        clientTransport,
		MetadataProvider: provider,
	})
	if err != nil {
		t.Fatalf("Failed to create MCP tool set: %v", err)
	}

	invCtx := icontext.NewInvocationContext(t.Context(), icontext.InvocationContextParams{})
	readonlyCtx := icontext.NewReadonlyContext(invCtx)
	tools, err := ts.Tools(readonlyCtx)
	if err != nil {
		t.Fatalf("Failed to get tools: %v", err)
	}

	if len(tools) != 1 {
		t.Fatalf("Expected 1 tool, got %d", len(tools))
	}

	fnTool, ok := tools[0].(toolinternal.FunctionTool)
	if !ok {
		t.Fatal("Tool does not implement FunctionTool interface")
	}

	toolCtx := toolinternal.NewToolContext(invCtx, "", nil)
	_, err = fnTool.Run(toolCtx, map[string]any{})
	if err != nil {
		t.Fatalf("Failed to run tool: %v", err)
	}
}

Then your tests would become much simpler. For example, TestMetadataProviderNil would look like:

func TestMetadataProviderNil(t *testing.T) {
	echoToolFunc := func(ctx context.Context, req *mcp.CallToolRequest, input struct{}) (*mcp.CallToolResult, struct{ Message string }, error) {
		if req.Params.Meta != nil {
			t.Errorf("Expected nil metadata, got %v", req.Params.Meta)
		}
		return nil, struct{ Message string }{Message: "ok"}, nil
	}
	runMetadataTest(t, nil, echoToolFunc)
}

This would improve readability and maintainability.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mcp.AddTool(server, &mcp.Tool{Name: "echo_tool", Description: "echoes input"}, toolFunc) is not valid code since In cannot be inferred. Also this way the actual asserts are outside of the tests. I think in this case some duplication helps readability and cognitive load when looking at the various tests.

22 changes: 16 additions & 6 deletions tool/mcptoolset/tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,16 @@ import (

type getSessionFunc func(ctx context.Context) (*mcp.ClientSession, error)

func convertTool(t *mcp.Tool, getSessionFunc getSessionFunc) (tool.Tool, error) {
func convertTool(t *mcp.Tool, getSessionFunc getSessionFunc, metadataProvider MetadataProvider) (tool.Tool, error) {
mcp := &mcpTool{
name: t.Name,
description: t.Description,
funcDeclaration: &genai.FunctionDeclaration{
Name: t.Name,
Description: t.Description,
},
getSessionFunc: getSessionFunc,
getSessionFunc: getSessionFunc,
metadataProvider: metadataProvider,
}

// Since t.InputSchema and t.OutputSchema are pointers (*jsonschema.Schema) and the destination ResponseJsonSchema
Expand All @@ -61,7 +62,8 @@ type mcpTool struct {
description string
funcDeclaration *genai.FunctionDeclaration

getSessionFunc getSessionFunc
getSessionFunc getSessionFunc
metadataProvider MetadataProvider
}

// Name implements the tool.Tool.
Expand Down Expand Up @@ -93,11 +95,19 @@ func (t *mcpTool) Run(ctx tool.Context, args any) (map[string]any, error) {
return nil, fmt.Errorf("failed to get session: %w", err)
}

// TODO: add auth
res, err := session.CallTool(ctx, &mcp.CallToolParams{
params := &mcp.CallToolParams{
Name: t.name,
Arguments: args,
})
}

if t.metadataProvider != nil {
if meta := t.metadataProvider(ctx); meta != nil {
params.Meta = mcp.Meta(meta)
}
}

// TODO: add auth
res, err := session.CallTool(ctx, params)
if err != nil {
return nil, fmt.Errorf("failed to call MCP tool %q with err: %w", t.name, err)
}
Expand Down