Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions examples/mcp/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,11 @@ func main() {
if err != nil {
log.Fatalf("Failed to create MCP tool set: %v", err)
}
defer func() {
if err := mcpToolSet.Close(); err != nil {
log.Printf("Failed to close MCP tool set: %v", err)
}
}()

// Create LLMAgent with MCP tool set
a, err := llmagent.New(llmagent.Config{
Expand Down
9 changes: 8 additions & 1 deletion tool/mcptoolset/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ import (
// Model: model,
// Description: "...",
// Instruction: "...",
// Toolsets: []tool.Set{
// Toolsets: []tool.Toolset{
// mcptoolset.New(mcptoolset.Config{
// Transport: &mcp.CommandTransport{Command: exec.Command("myserver")}
// }),
Expand Down Expand Up @@ -128,6 +128,13 @@ func (s *set) Tools(ctx agent.ReadonlyContext) ([]tool.Tool, error) {
return adkTools, nil
}

// Close closes the underlying MCP client.
func (s *set) Close() error {
if c, ok := s.mcpClient.(interface{ Close() error }); ok {
return c.Close()
}
return nil
}
// ConfirmationProvider defines a function that dynamically determines whether
// a specific tool execution requires user confirmation.
//
Expand Down
20 changes: 20 additions & 0 deletions tool/mcptoolset/set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ func TestMCPToolSet(t *testing.T) {
if err != nil {
t.Fatalf("Failed to create MCP tool set: %v", err)
}
defer func() {
if err := ts.Close(); err != nil {
t.Errorf("ts.Close() failed: %v", err)
}
}()

agent, err := llmagent.New(llmagent.Config{
Name: "weather_time_agent",
Expand Down Expand Up @@ -228,6 +233,11 @@ func TestToolFilter(t *testing.T) {
if err != nil {
t.Fatalf("Failed to create MCP tool set: %v", err)
}
defer func() {
if err := ts.Close(); err != nil {
t.Errorf("ts.Close() failed: %v", err)
}
}()

tools, err := ts.Tools(icontext.NewReadonlyContext(
icontext.NewInvocationContext(
Expand Down Expand Up @@ -263,6 +273,11 @@ func TestListToolsReconnection(t *testing.T) {
if err != nil {
t.Fatalf("Failed to create MCP tool set: %v", err)
}
defer func() {
if err := ts.Close(); err != nil {
t.Errorf("ts.Close() failed: %v", err)
}
}()

ctx := icontext.NewReadonlyContext(icontext.NewInvocationContext(t.Context(), icontext.InvocationContextParams{}))

Expand Down Expand Up @@ -302,6 +317,11 @@ func TestCallToolReconnection(t *testing.T) {
if err != nil {
t.Fatalf("Failed to create MCP tool set: %v", err)
}
defer func() {
if err := ts.Close(); err != nil {
t.Errorf("ts.Close() failed: %v", err)
}
}()

invCtx := icontext.NewInvocationContext(t.Context(), icontext.InvocationContextParams{})
ctx := icontext.NewReadonlyContext(invCtx)
Expand Down
8 changes: 8 additions & 0 deletions tool/tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ type Toolset interface {
// ReadonlyContext can be used to dynamically determine which tools
// to return based on the current invocation state.
Tools(ctx agent.ReadonlyContext) ([]Tool, error)

// Close performs cleanup and releases resources held by the toolset.
// It should be called when the toolset is no longer needed.
Close() error
}

// Predicate is a function which decides whether a tool should be exposed to LLM.
Expand Down Expand Up @@ -153,3 +157,7 @@ func (f *filteredToolset) Tools(ctx agent.ReadonlyContext) ([]Tool, error) {
}
return filtered, nil
}

func (f *filteredToolset) Close() error {
return f.toolset.Close()
}
23 changes: 23 additions & 0 deletions tool/tool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package tool_test
import (
"testing"

"google.golang.org/adk/agent"
"google.golang.org/adk/internal/toolinternal"
"google.golang.org/adk/tool"
"google.golang.org/adk/tool/agenttool"
Expand Down Expand Up @@ -103,3 +104,25 @@ func TestTypes(t *testing.T) {
})
}
}

func TestFilterToolset_Close(t *testing.T) {
mock := &mockToolset{}
ts := tool.FilterToolset(mock, func(ctx agent.ReadonlyContext, _ tool.Tool) bool { return true })
if err := ts.Close(); err != nil {
t.Errorf("Close() error = %v", err)
}
if !mock.closeCalled {
t.Error("Close() was not called on underlying toolset")
}
}

type mockToolset struct {
closeCalled bool
}

func (m *mockToolset) Name() string { return "mock" }
func (m *mockToolset) Tools(ctx agent.ReadonlyContext) ([]tool.Tool, error) { return nil, nil }
func (m *mockToolset) Close() error {
m.closeCalled = true
return nil
}