Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions agent/workflowagents/loopagent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,16 @@ func (a *loopAgent) Run(ctx agent.InvocationContext) iter.Seq2[*session.Event, e
for _, subAgent := range ctx.Agent().SubAgents() {
for event, err := range subAgent.Run(ctx) {
// TODO: ensure consistency -- if there's an error, return and close iterator, verify everywhere in ADK.
if event != nil && event.Actions.ExitLoop {
shouldExit = true
// Consume the ExitLoop flag so parent agents
// (e.g. a SequentialAgent wrapping this loop)
// don't also react to it and exit prematurely.
event.Actions.ExitLoop = false
}
if !yield(event, err) {
return
}

if event != nil && event.Actions.Escalate {
shouldExit = true
}
}
if shouldExit {
return
Expand Down
26 changes: 11 additions & 15 deletions agent/workflowagents/loopagent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func TestNewLoopAgent(t *testing.T) {
},
},
{
name: "loop with escalate function returns sumarization",
name: "loop with exit_loop function returns summarization",
args: args{
maxIterations: 2,
subAgents: []agent.Agent{newLmmAgentWithFunctionCall(t, 0, false), newCustomAgent(t, 1)},
Expand All @@ -135,9 +135,6 @@ func TestNewLoopAgent(t *testing.T) {
LLMResponse: model.LLMResponse{
Content: genai.NewContentFromFunctionResponse("exampleFunction", make(map[string]any), genai.RoleUser),
},
Actions: session.EventActions{
Escalate: true,
},
},
{
Author: "custom_agent_0",
Expand All @@ -153,7 +150,7 @@ func TestNewLoopAgent(t *testing.T) {
},
},
{
name: "loop with escalate function returns sumarization",
name: "loop with exit_loop function skips summarization",
args: args{
maxIterations: 2,
subAgents: []agent.Agent{newLmmAgentWithFunctionCall(t, 0, true), newCustomAgent(t, 1)},
Expand All @@ -171,7 +168,6 @@ func TestNewLoopAgent(t *testing.T) {
Content: genai.NewContentFromFunctionResponse("exampleFunction", make(map[string]any), genai.RoleUser),
},
Actions: session.EventActions{
Escalate: true,
SkipSummarization: true,
},
},
Expand Down Expand Up @@ -287,29 +283,29 @@ func (a *customAgent) Run(agent.InvocationContext) iter.Seq2[*session.Event, err

type EmptyArgs struct{}

func exampleFunctionThatEscalates(ctx tool.Context, myArgs EmptyArgs) (map[string]string, error) {
ctx.Actions().Escalate = true
func exampleFunctionThatExitsLoop(ctx tool.Context, myArgs EmptyArgs) (map[string]string, error) {
ctx.Actions().ExitLoop = true
ctx.Actions().SkipSummarization = false
return map[string]string{}, nil
}

func exampleFunctionThatEscalatesAndSkips(ctx tool.Context, myArgs EmptyArgs) (map[string]string, error) {
ctx.Actions().Escalate = true
func exampleFunctionThatExitsLoopAndSkips(ctx tool.Context, myArgs EmptyArgs) (map[string]string, error) {
ctx.Actions().ExitLoop = true
ctx.Actions().SkipSummarization = true
return map[string]string{}, nil
}

func newLmmAgentWithFunctionCall(t *testing.T, id int, skipSummarization bool) agent.Agent {
t.Helper()

exampleFunction := exampleFunctionThatEscalates
exampleFunction := exampleFunctionThatExitsLoop
if skipSummarization {
exampleFunction = exampleFunctionThatEscalatesAndSkips
exampleFunction = exampleFunctionThatExitsLoopAndSkips
}

exampleFunctionThatEscalatesTool, err := functiontool.New(functiontool.Config{
exampleFunctionThatExitsLoopTool, err := functiontool.New(functiontool.Config{
Name: "exampleFunction",
Description: "Call this function to escalate\n",
Description: "Call this function to exit the loop\n",
}, exampleFunction)
if err != nil {
t.Fatalf("error creating exampleFunction tool: %s", err)
Expand All @@ -318,7 +314,7 @@ func newLmmAgentWithFunctionCall(t *testing.T, id int, skipSummarization bool) a
customAgent, err := llmagent.New(llmagent.Config{
Name: fmt.Sprintf("custom_agent_%v", id),
Model: &FakeLLM{id: id, callCounter: 0, skipSummarization: skipSummarization},
Tools: []tool.Tool{exampleFunctionThatEscalatesTool},
Tools: []tool.Tool{exampleFunctionThatExitsLoopTool},
})
if err != nil {
t.Fatal(err)
Expand Down
122 changes: 122 additions & 0 deletions agent/workflowagents/sequentialagent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (

"google.golang.org/adk/agent"
"google.golang.org/adk/agent/llmagent"
"google.golang.org/adk/agent/workflowagents/loopagent"
"google.golang.org/adk/agent/workflowagents/sequentialagent"
"google.golang.org/adk/model"
"google.golang.org/adk/runner"
Expand Down Expand Up @@ -312,3 +313,124 @@ func (f *FakeLLM) GenerateContent(ctx context.Context, req *model.LLMRequest, st
}, nil)
}
}

// exitLoopAgent is an agent whose Run yields an event with ExitLoop=true,
// simulating what happens when the exit_loop tool is called.
type exitLoopAgent struct {
id int
}

func newExitLoopAgent(t *testing.T, id int) agent.Agent {
t.Helper()
ea := &exitLoopAgent{id: id}
a, err := agent.New(agent.Config{
Name: fmt.Sprintf("exit_loop_agent_%v", id),
Run: ea.Run,
})
if err != nil {
t.Fatal(err)
}
return a
}

func (a *exitLoopAgent) Run(agent.InvocationContext) iter.Seq2[*session.Event, error] {
return func(yield func(*session.Event, error) bool) {
yield(&session.Event{
LLMResponse: model.LLMResponse{
Content: genai.NewContentFromText(fmt.Sprintf("exiting loop %v", a.id), genai.RoleModel),
},
Actions: session.EventActions{
ExitLoop: true,
SkipSummarization: true,
},
}, nil)
}
}

func TestSequentialAgentWithNestedLoop(t *testing.T) {
tests := []struct {
name string
subAgents []agent.Agent
wantAuthors []string
}{
{
name: "exit_loop does not stop sequential pipeline",
subAgents: func() []agent.Agent {
innerLoop, err := loopagent.New(loopagent.Config{
AgentConfig: agent.Config{
Name: "inner_loop",
SubAgents: []agent.Agent{newCustomAgent(t, 0), newExitLoopAgent(t, 1)},
},
MaxIterations: 3,
})
if err != nil {
t.Fatal(err)
}
return []agent.Agent{innerLoop, newCustomAgent(t, 2)}
}(),
wantAuthors: []string{"custom_agent_0", "exit_loop_agent_1", "custom_agent_2"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := t.Context()

pipeline, err := sequentialagent.New(sequentialagent.Config{
AgentConfig: agent.Config{
Name: "pipeline",
SubAgents: tt.subAgents,
},
})
if err != nil {
t.Fatal(err)
}

sessionService := session.InMemoryService()
agentRunner, err := runner.New(runner.Config{
AppName: "test_app",
Agent: pipeline,
SessionService: sessionService,
})
if err != nil {
t.Fatal(err)
}

_, err = sessionService.Create(ctx, &session.CreateRequest{
AppName: "test_app",
UserID: "user_id",
SessionID: "session_id",
})
if err != nil {
t.Fatal(err)
}

var gotEvents []*session.Event
for event, err := range agentRunner.Run(ctx, "user_id", "session_id", genai.NewContentFromText("user input", genai.RoleUser), agent.RunConfig{}) {
if err != nil {
t.Fatalf("got unexpected error: %v", err)
}
gotEvents = append(gotEvents, event)
}

if len(gotEvents) != len(tt.wantAuthors) {
var gotAuthors []string
for _, e := range gotEvents {
gotAuthors = append(gotAuthors, e.Author)
}
t.Fatalf("expected %d events %v, got %d events %v", len(tt.wantAuthors), tt.wantAuthors, len(gotEvents), gotAuthors)
}
for i, want := range tt.wantAuthors {
if gotEvents[i].Author != want {
t.Errorf("event[%d].Author = %q, want %q", i, gotEvents[i].Author, want)
}
}

// Verify the ExitLoop flag was consumed by the loop agent.
for i, event := range gotEvents {
if event.Actions.ExitLoop {
t.Errorf("event[%d] has ExitLoop=true, expected it to be consumed by the loop agent", i)
}
}
})
}
}
3 changes: 3 additions & 0 deletions internal/llminternal/base_flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -714,6 +714,9 @@ func mergeEventActions(base, other *session.EventActions) *session.EventActions
if other.Escalate {
base.Escalate = true
}
if other.ExitLoop {
base.ExitLoop = true
}
if other.StateDelta != nil {
base.StateDelta = deepMergeMap(base.StateDelta, other.StateDelta)
}
Expand Down
4 changes: 4 additions & 0 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,10 @@ type EventActions struct {
TransferToAgent string
// The agent is escalating to a higher level agent.
Escalate bool
// If true, signals that the current loop should exit.
// Used by the exit_loop tool to break out of a LoopAgent without
// propagating to parent agents (unlike Escalate).
ExitLoop bool
}

// Prefixes for defining session's state scopes
Expand Down
2 changes: 1 addition & 1 deletion tool/exitlooptool/tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import (
type EmptyArgs struct{}

func exitLoop(ctx tool.Context, myArgs EmptyArgs) (map[string]string, error) {
ctx.Actions().Escalate = true
ctx.Actions().ExitLoop = true
ctx.Actions().SkipSummarization = true
return map[string]string{}, nil
}
Expand Down