From 55dbbcb296521a65e550c3ef7204a864655bbec6 Mon Sep 17 00:00:00 2001 From: jaxxjj Date: Mon, 12 Jan 2026 12:49:49 +0800 Subject: [PATCH 01/10] fix: race conditions in parallel agents --- agent/workflowagents/parallelagent/agent.go | 36 +++-- .../parallelagent/agent_test.go | 142 ++++++++++++++++++ ...elAgentWithToolsBackpressure_agent1.httprr | 116 ++++++++++++++ ...elAgentWithToolsBackpressure_agent2.httprr | 116 ++++++++++++++ internal/sessioninternal/mutablesession.go | 31 ++++ internal/telemetry/telemetry.go | 16 +- 6 files changed, 444 insertions(+), 13 deletions(-) create mode 100644 agent/workflowagents/parallelagent/testdata/TestParallelAgentWithToolsBackpressure_agent1.httprr create mode 100644 agent/workflowagents/parallelagent/testdata/TestParallelAgentWithToolsBackpressure_agent2.httprr diff --git a/agent/workflowagents/parallelagent/agent.go b/agent/workflowagents/parallelagent/agent.go index d8afeffd5..fff0afdb5 100644 --- a/agent/workflowagents/parallelagent/agent.go +++ b/agent/workflowagents/parallelagent/agent.go @@ -107,7 +107,14 @@ func run(ctx agent.InvocationContext) iter.Seq2[*session.Event, error] { defer close(doneChan) for res := range resultsChan { - if !yield(res.event, res.err) { + shouldContinue := yield(res.event, res.err) + + // Signal sub-agent that event processing (including session append) is complete + if res.ackChan != nil { + close(res.ackChan) + } + + if !shouldContinue { break } } @@ -116,30 +123,37 @@ func run(ctx agent.InvocationContext) iter.Seq2[*session.Event, error] { func runSubAgent(ctx agent.InvocationContext, agent agent.Agent, results chan<- result, done <-chan bool) error { for event, err := range agent.Run(ctx) { + ackChan := make(chan struct{}) + select { case <-done: return nil case <-ctx.Done(): - select { - case <-done: - case results <- result{ - err: ctx.Err(), - }: - } return ctx.Err() case results <- result{ - event: event, - err: err, + event: event, + err: err, + ackChan: ackChan, }: if err != nil { return err } + + // Wait for runner to finish processing before continuing to next iteration + select { + case <-ackChan: + case <-done: + return nil + case <-ctx.Done(): + return ctx.Err() + } } } return nil } type result struct { - event *session.Event - err error + event *session.Event + err error + ackChan chan struct{} } diff --git a/agent/workflowagents/parallelagent/agent_test.go b/agent/workflowagents/parallelagent/agent_test.go index e41ead433..dc8ecd210 100644 --- a/agent/workflowagents/parallelagent/agent_test.go +++ b/agent/workflowagents/parallelagent/agent_test.go @@ -19,7 +19,10 @@ import ( "fmt" "iter" rand "math/rand/v2" + "net/http" + "path/filepath" "slices" + "strings" "testing" "time" @@ -27,13 +30,21 @@ import ( "google.golang.org/genai" "google.golang.org/adk/agent" + "google.golang.org/adk/agent/llmagent" "google.golang.org/adk/agent/workflowagents/loopagent" "google.golang.org/adk/agent/workflowagents/parallelagent" + "google.golang.org/adk/internal/httprr" + "google.golang.org/adk/internal/testutil" "google.golang.org/adk/model" + "google.golang.org/adk/model/gemini" "google.golang.org/adk/runner" "google.golang.org/adk/session" + "google.golang.org/adk/tool" + "google.golang.org/adk/tool/functiontool" ) +const modelName = "gemini-2.0-flash-exp" + func TestNewParallelAgent(t *testing.T) { tests := []struct { name string @@ -224,3 +235,134 @@ func customRun(id int, agentErr error) func(agent.InvocationContext) iter.Seq2[* } } } + +func TestParallelAgentWithTools(t *testing.T) { + agent1 := createAgentWithGemini(t, "agent1") + agent2 := createAgentWithGemini(t, "agent2") + + parallelAgent, err := parallelagent.New(parallelagent.Config{ + AgentConfig: agent.Config{ + Name: "parallel_test", + SubAgents: []agent.Agent{agent1, agent2}, + }, + }) + if err != nil { + t.Fatalf("Failed to create parallel agent: %v", err) + } + + runner := testutil.NewTestAgentRunner(t, parallelAgent) + stream := runner.Run(t, "test_session", "Search for AI news") + + events, err := testutil.CollectEvents(stream) + if err != nil { + t.Fatalf("Agent run failed: %v", err) + } + + if len(events) < 2 { + t.Errorf("Expected at least 2 events from parallel agents, got %d", len(events)) + } + + // Count FunctionCall and FunctionResponse events per branch + branchCalls := make(map[string]int) + branchResponses := make(map[string]int) + + for _, ev := range events { + branch := ev.Branch + if ev.LLMResponse.Content != nil { + for _, part := range ev.LLMResponse.Content.Parts { + if part.FunctionCall != nil { + branchCalls[branch]++ + } + if part.FunctionResponse != nil { + branchResponses[branch]++ + } + } + } + } + + for branch, calls := range branchCalls { + responses := branchResponses[branch] + if calls > responses { + t.Errorf("Branch %s: session has %d FunctionCalls but only %d FunctionResponses. "+ + "This indicates race condition: agent read session before FunctionResponse was appended.", + branch, calls, responses) + } + } +} + +func createAgentWithGemini(t *testing.T, name string) agent.Agent { + t.Helper() + + searchTool, err := functiontool.New( + functiontool.Config{ + Name: fmt.Sprintf("search_tool_%s", name), + Description: "Search for information on the web", + }, + func(ctx tool.Context, args struct{ Query string }) (string, error) { + return fmt.Sprintf("search result for '%s' from %s", args.Query, name), nil + }, + ) + if err != nil { + t.Fatalf("Failed to create search tool: %v", err) + } + + analyzeTool, err := functiontool.New( + functiontool.Config{ + Name: fmt.Sprintf("analyze_tool_%s", name), + Description: "Analyze data and return insights", + }, + func(ctx tool.Context, args struct{ Data string }) (string, error) { + return fmt.Sprintf("analysis result for '%s' from %s", args.Data, name), nil + }, + ) + if err != nil { + t.Fatalf("Failed to create analyze tool: %v", err) + } + + model := newGeminiModelForTest(t, modelName, name) + + a, err := llmagent.New(llmagent.Config{ + Name: name, + Description: fmt.Sprintf("Test agent %s that searches for information", name), + Model: model, + Tools: []tool.Tool{searchTool, analyzeTool}, + Instruction: "Use the search tool to find information, then provide a brief response.", + }) + if err != nil { + t.Fatalf("Failed to create agent %s: %v", name, err) + } + + return a +} + +func newGeminiModelForTest(t *testing.T, modelName string, agentName string) model.LLM { + t.Helper() + + trace := filepath.Join("testdata", fmt.Sprintf("%s_%s.httprr", + strings.ReplaceAll(t.Name(), "/", "_"), agentName)) + + apiKey := "fakeKey" + transport, recording := newGeminiTestTransport(t, trace) + if recording { + apiKey = "" + } + + model, err := gemini.NewModel(t.Context(), modelName, &genai.ClientConfig{ + HTTPClient: &http.Client{Transport: transport}, + APIKey: apiKey, + }) + if err != nil { + t.Fatalf("Failed to create Gemini model: %v", err) + } + return model +} + +func newGeminiTestTransport(t *testing.T, rrfile string) (http.RoundTripper, bool) { + t.Helper() + rr, err := testutil.NewGeminiTransport(rrfile) + if err != nil { + t.Fatal(err) + } + recording, _ := httprr.Recording(rrfile) + return rr, recording +} diff --git a/agent/workflowagents/parallelagent/testdata/TestParallelAgentWithToolsBackpressure_agent1.httprr b/agent/workflowagents/parallelagent/testdata/TestParallelAgentWithToolsBackpressure_agent1.httprr new file mode 100644 index 000000000..adb1d6a02 --- /dev/null +++ b/agent/workflowagents/parallelagent/testdata/TestParallelAgentWithToolsBackpressure_agent1.httprr @@ -0,0 +1,116 @@ +httprr trace v1 +994 1148 +POST https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash-exp:generateContent HTTP/1.1 +Host: generativelanguage.googleapis.com +User-Agent: Go-http-client/1.1 +Content-Length: 758 +Content-Type: application/json + +{"contents":[{"parts":[{"text":"Search for AI news"}],"role":"user"}],"generationConfig":{},"systemInstruction":{"parts":[{"text":"Use the search tool to find information, then provide a brief response."}],"role":"user"},"tools":[{"functionDeclarations":[{"description":"Search for information on the web","name":"search_tool_agent1","parametersJsonSchema":{"additionalProperties":false,"properties":{"Query":{"type":"string"}},"required":["Query"],"type":"object"},"responseJsonSchema":{"type":"string"}},{"description":"Analyze data and return insights","name":"analyze_tool_agent1","parametersJsonSchema":{"additionalProperties":false,"properties":{"Data":{"type":"string"}},"required":["Data"],"type":"object"},"responseJsonSchema":{"type":"string"}}]}]}HTTP/2.0 200 OK +Alt-Svc: h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 +Content-Type: application/json; charset=UTF-8 +Date: Mon, 12 Jan 2026 03:32:51 GMT +Server: scaffolding on HTTPServer2 +Server-Timing: gfet4t7; dur=661 +Vary: Origin +Vary: X-Origin +Vary: Referer +X-Content-Type-Options: nosniff +X-Frame-Options: SAMEORIGIN +X-Xss-Protection: 0 + +{ + "candidates": [ + { + "content": { + "parts": [ + { + "functionCall": { + "name": "search_tool_agent1", + "args": { + "Query": "AI news" + } + } + } + ], + "role": "model" + }, + "finishReason": "STOP", + "avgLogprobs": -1.7506775394495991e-05 + } + ], + "usageMetadata": { + "promptTokenCount": 41, + "candidatesTokenCount": 9, + "totalTokenCount": 50, + "promptTokensDetails": [ + { + "modality": "TEXT", + "tokenCount": 41 + } + ], + "candidatesTokensDetails": [ + { + "modality": "TEXT", + "tokenCount": 9 + } + ] + }, + "modelVersion": "gemini-2.0-flash-exp", + "responseId": "Ymtkad2GMvz_2roPv6WLmQs" +} +1237 1033 +POST https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash-exp:generateContent HTTP/1.1 +Host: generativelanguage.googleapis.com +User-Agent: Go-http-client/1.1 +Content-Length: 1000 +Content-Type: application/json + +{"contents":[{"parts":[{"text":"Search for AI news"}],"role":"user"},{"parts":[{"functionCall":{"args":{"Query":"AI news"},"name":"search_tool_agent1"}}],"role":"model"},{"parts":[{"functionResponse":{"name":"search_tool_agent1","response":{"result":"search result for 'AI news' from agent1"}}}],"role":"user"}],"generationConfig":{},"systemInstruction":{"parts":[{"text":"Use the search tool to find information, then provide a brief response."}],"role":"user"},"tools":[{"functionDeclarations":[{"description":"Search for information on the web","name":"search_tool_agent1","parametersJsonSchema":{"additionalProperties":false,"properties":{"Query":{"type":"string"}},"required":["Query"],"type":"object"},"responseJsonSchema":{"type":"string"}},{"description":"Analyze data and return insights","name":"analyze_tool_agent1","parametersJsonSchema":{"additionalProperties":false,"properties":{"Data":{"type":"string"}},"required":["Data"],"type":"object"},"responseJsonSchema":{"type":"string"}}]}]}HTTP/2.0 200 OK +Alt-Svc: h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 +Content-Type: application/json; charset=UTF-8 +Date: Mon, 12 Jan 2026 03:32:52 GMT +Server: scaffolding on HTTPServer2 +Server-Timing: gfet4t7; dur=656 +Vary: Origin +Vary: X-Origin +Vary: Referer +X-Content-Type-Options: nosniff +X-Frame-Options: SAMEORIGIN +X-Xss-Protection: 0 + +{ + "candidates": [ + { + "content": { + "parts": [ + { + "text": "I have searched for AI news." + } + ], + "role": "model" + }, + "finishReason": "STOP", + "avgLogprobs": -0.03738286665507725 + } + ], + "usageMetadata": { + "promptTokenCount": 67, + "candidatesTokenCount": 7, + "totalTokenCount": 74, + "promptTokensDetails": [ + { + "modality": "TEXT", + "tokenCount": 67 + } + ], + "candidatesTokensDetails": [ + { + "modality": "TEXT", + "tokenCount": 7 + } + ] + }, + "modelVersion": "gemini-2.0-flash-exp", + "responseId": "Y2tkaaH-NJKd0-kP9puPwAc" +} diff --git a/agent/workflowagents/parallelagent/testdata/TestParallelAgentWithToolsBackpressure_agent2.httprr b/agent/workflowagents/parallelagent/testdata/TestParallelAgentWithToolsBackpressure_agent2.httprr new file mode 100644 index 000000000..3ee793c9b --- /dev/null +++ b/agent/workflowagents/parallelagent/testdata/TestParallelAgentWithToolsBackpressure_agent2.httprr @@ -0,0 +1,116 @@ +httprr trace v1 +994 1148 +POST https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash-exp:generateContent HTTP/1.1 +Host: generativelanguage.googleapis.com +User-Agent: Go-http-client/1.1 +Content-Length: 758 +Content-Type: application/json + +{"contents":[{"parts":[{"text":"Search for AI news"}],"role":"user"}],"generationConfig":{},"systemInstruction":{"parts":[{"text":"Use the search tool to find information, then provide a brief response."}],"role":"user"},"tools":[{"functionDeclarations":[{"description":"Search for information on the web","name":"search_tool_agent2","parametersJsonSchema":{"additionalProperties":false,"properties":{"Query":{"type":"string"}},"required":["Query"],"type":"object"},"responseJsonSchema":{"type":"string"}},{"description":"Analyze data and return insights","name":"analyze_tool_agent2","parametersJsonSchema":{"additionalProperties":false,"properties":{"Data":{"type":"string"}},"required":["Data"],"type":"object"},"responseJsonSchema":{"type":"string"}}]}]}HTTP/2.0 200 OK +Alt-Svc: h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 +Content-Type: application/json; charset=UTF-8 +Date: Mon, 12 Jan 2026 03:32:51 GMT +Server: scaffolding on HTTPServer2 +Server-Timing: gfet4t7; dur=922 +Vary: Origin +Vary: X-Origin +Vary: Referer +X-Content-Type-Options: nosniff +X-Frame-Options: SAMEORIGIN +X-Xss-Protection: 0 + +{ + "candidates": [ + { + "content": { + "parts": [ + { + "functionCall": { + "name": "search_tool_agent2", + "args": { + "Query": "AI news" + } + } + } + ], + "role": "model" + }, + "finishReason": "STOP", + "avgLogprobs": -1.1056118334333101e-05 + } + ], + "usageMetadata": { + "promptTokenCount": 41, + "candidatesTokenCount": 9, + "totalTokenCount": 50, + "promptTokensDetails": [ + { + "modality": "TEXT", + "tokenCount": 41 + } + ], + "candidatesTokensDetails": [ + { + "modality": "TEXT", + "tokenCount": 9 + } + ] + }, + "modelVersion": "gemini-2.0-flash-exp", + "responseId": "YmtkafzJPLKx2roPsNnkiA4" +} +1237 1034 +POST https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash-exp:generateContent HTTP/1.1 +Host: generativelanguage.googleapis.com +User-Agent: Go-http-client/1.1 +Content-Length: 1000 +Content-Type: application/json + +{"contents":[{"parts":[{"text":"Search for AI news"}],"role":"user"},{"parts":[{"functionCall":{"args":{"Query":"AI news"},"name":"search_tool_agent2"}}],"role":"model"},{"parts":[{"functionResponse":{"name":"search_tool_agent2","response":{"result":"search result for 'AI news' from agent2"}}}],"role":"user"}],"generationConfig":{},"systemInstruction":{"parts":[{"text":"Use the search tool to find information, then provide a brief response."}],"role":"user"},"tools":[{"functionDeclarations":[{"description":"Search for information on the web","name":"search_tool_agent2","parametersJsonSchema":{"additionalProperties":false,"properties":{"Query":{"type":"string"}},"required":["Query"],"type":"object"},"responseJsonSchema":{"type":"string"}},{"description":"Analyze data and return insights","name":"analyze_tool_agent2","parametersJsonSchema":{"additionalProperties":false,"properties":{"Data":{"type":"string"}},"required":["Data"],"type":"object"},"responseJsonSchema":{"type":"string"}}]}]}HTTP/2.0 200 OK +Alt-Svc: h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 +Content-Type: application/json; charset=UTF-8 +Date: Mon, 12 Jan 2026 03:32:52 GMT +Server: scaffolding on HTTPServer2 +Server-Timing: gfet4t7; dur=740 +Vary: Origin +Vary: X-Origin +Vary: Referer +X-Content-Type-Options: nosniff +X-Frame-Options: SAMEORIGIN +X-Xss-Protection: 0 + +{ + "candidates": [ + { + "content": { + "parts": [ + { + "text": "I have searched for AI news." + } + ], + "role": "model" + }, + "finishReason": "STOP", + "avgLogprobs": -0.010005326143332891 + } + ], + "usageMetadata": { + "promptTokenCount": 67, + "candidatesTokenCount": 7, + "totalTokenCount": 74, + "promptTokensDetails": [ + { + "modality": "TEXT", + "tokenCount": 67 + } + ], + "candidatesTokensDetails": [ + { + "modality": "TEXT", + "tokenCount": 7 + } + ] + }, + "modelVersion": "gemini-2.0-flash-exp", + "responseId": "ZGtkadaWAtnG2roP5uuq4Qc" +} diff --git a/internal/sessioninternal/mutablesession.go b/internal/sessioninternal/mutablesession.go index 1c68f8ef7..c4881b255 100644 --- a/internal/sessioninternal/mutablesession.go +++ b/internal/sessioninternal/mutablesession.go @@ -15,8 +15,10 @@ package sessioninternal import ( + "context" "fmt" "iter" + "sync" "time" "google.golang.org/adk/session" @@ -26,6 +28,7 @@ import ( type MutableSession struct { service session.Service storedSession session.Session + mu sync.RWMutex } // NewMutableSession creates and returns session.Session implementation. @@ -41,26 +44,50 @@ func (s *MutableSession) State() session.State { } func (s *MutableSession) AppName() string { + s.mu.RLock() + defer s.mu.RUnlock() return s.storedSession.AppName() } func (s *MutableSession) UserID() string { + s.mu.RLock() + defer s.mu.RUnlock() return s.storedSession.UserID() } func (s *MutableSession) ID() string { + s.mu.RLock() + defer s.mu.RUnlock() return s.storedSession.ID() } func (s *MutableSession) Events() session.Events { + s.mu.Lock() + defer s.mu.Unlock() + + ctx := context.Background() + resp, err := s.service.Get(ctx, &session.GetRequest{ + AppName: s.storedSession.AppName(), + UserID: s.storedSession.UserID(), + SessionID: s.storedSession.ID(), + }) + if err != nil { + return s.storedSession.Events() + } + + s.storedSession = resp.Session return s.storedSession.Events() } func (s *MutableSession) LastUpdateTime() time.Time { + s.mu.RLock() + defer s.mu.RUnlock() return s.storedSession.LastUpdateTime() } func (s *MutableSession) Get(key string) (any, error) { + s.mu.RLock() + defer s.mu.RUnlock() value, err := s.storedSession.State().Get(key) if err != nil { return nil, fmt.Errorf("failed to get key %q from state: %w", key, err) @@ -69,10 +96,14 @@ func (s *MutableSession) Get(key string) (any, error) { } func (s *MutableSession) All() iter.Seq2[string, any] { + s.mu.RLock() + defer s.mu.RUnlock() return s.storedSession.State().All() } func (s *MutableSession) Set(key string, value any) error { + s.mu.Lock() + defer s.mu.Unlock() mutableState, ok := s.storedSession.State().(MutableState) if !ok { return fmt.Errorf("this session state is not mutable") diff --git a/internal/telemetry/telemetry.go b/internal/telemetry/telemetry.go index 640ab775c..087ac6785 100644 --- a/internal/telemetry/telemetry.go +++ b/internal/telemetry/telemetry.go @@ -47,6 +47,7 @@ type tracerProviderConfig struct { var ( once sync.Once localTracer tracerProviderHolder + localTracerMu sync.RWMutex localTracerConfig = tracerProviderConfig{ spanProcessors: []sdktrace.SpanProcessor{}, mu: &sync.RWMutex{}, @@ -101,7 +102,10 @@ func RegisterTelemetry() { for _, processor := range spanProcessors { traceProvider.RegisterSpanProcessor(processor) } + + localTracerMu.Lock() localTracer = tracerProviderHolder{tp: traceProvider} + localTracerMu.Unlock() }) } @@ -109,11 +113,19 @@ func RegisterTelemetry() { // That means that the spans are NOT recording/exporting // If the local tracer is not set, we'll set up tracer with all registered span processors. func getTracers() []trace.Tracer { - if localTracer.tp == nil { + localTracerMu.RLock() + tp := localTracer.tp + localTracerMu.RUnlock() + + if tp == nil { RegisterTelemetry() + localTracerMu.RLock() + tp = localTracer.tp + localTracerMu.RUnlock() } + return []trace.Tracer{ - localTracer.tp.Tracer(systemName), + tp.Tracer(systemName), otel.GetTracerProvider().Tracer(systemName), } } From 750f444dfaf4bcba6943b81f14994e2c7c527429 Mon Sep 17 00:00:00 2001 From: jaxxjj Date: Mon, 12 Jan 2026 12:58:13 +0800 Subject: [PATCH 02/10] fix --- internal/sessioninternal/mutablesession.go | 3 +++ internal/telemetry/telemetry.go | 9 ++------- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/internal/sessioninternal/mutablesession.go b/internal/sessioninternal/mutablesession.go index c4881b255..c9d276771 100644 --- a/internal/sessioninternal/mutablesession.go +++ b/internal/sessioninternal/mutablesession.go @@ -18,6 +18,7 @@ import ( "context" "fmt" "iter" + "log" "sync" "time" @@ -72,6 +73,8 @@ func (s *MutableSession) Events() session.Events { SessionID: s.storedSession.ID(), }) if err != nil { + log.Printf("MutableSession: failed to fetch latest session (app=%s, user=%s, session=%s), using cached version: %v", + s.storedSession.AppName(), s.storedSession.UserID(), s.storedSession.ID(), err) return s.storedSession.Events() } diff --git a/internal/telemetry/telemetry.go b/internal/telemetry/telemetry.go index 087ac6785..ab75c04fc 100644 --- a/internal/telemetry/telemetry.go +++ b/internal/telemetry/telemetry.go @@ -113,17 +113,12 @@ func RegisterTelemetry() { // That means that the spans are NOT recording/exporting // If the local tracer is not set, we'll set up tracer with all registered span processors. func getTracers() []trace.Tracer { + RegisterTelemetry() + localTracerMu.RLock() tp := localTracer.tp localTracerMu.RUnlock() - if tp == nil { - RegisterTelemetry() - localTracerMu.RLock() - tp = localTracer.tp - localTracerMu.RUnlock() - } - return []trace.Tracer{ tp.Tracer(systemName), otel.GetTracerProvider().Tracer(systemName), From 7e0d0d07781210aa5eb6d63161c2748cc4b084d3 Mon Sep 17 00:00:00 2001 From: jaxxjj Date: Mon, 12 Jan 2026 13:16:58 +0800 Subject: [PATCH 03/10] file name --- ...ure_agent1.httprr => TestParallelAgentWithTools_agent1.httprr} | 0 ...ure_agent2.httprr => TestParallelAgentWithTools_agent2.httprr} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename agent/workflowagents/parallelagent/testdata/{TestParallelAgentWithToolsBackpressure_agent1.httprr => TestParallelAgentWithTools_agent1.httprr} (100%) rename agent/workflowagents/parallelagent/testdata/{TestParallelAgentWithToolsBackpressure_agent2.httprr => TestParallelAgentWithTools_agent2.httprr} (100%) diff --git a/agent/workflowagents/parallelagent/testdata/TestParallelAgentWithToolsBackpressure_agent1.httprr b/agent/workflowagents/parallelagent/testdata/TestParallelAgentWithTools_agent1.httprr similarity index 100% rename from agent/workflowagents/parallelagent/testdata/TestParallelAgentWithToolsBackpressure_agent1.httprr rename to agent/workflowagents/parallelagent/testdata/TestParallelAgentWithTools_agent1.httprr diff --git a/agent/workflowagents/parallelagent/testdata/TestParallelAgentWithToolsBackpressure_agent2.httprr b/agent/workflowagents/parallelagent/testdata/TestParallelAgentWithTools_agent2.httprr similarity index 100% rename from agent/workflowagents/parallelagent/testdata/TestParallelAgentWithToolsBackpressure_agent2.httprr rename to agent/workflowagents/parallelagent/testdata/TestParallelAgentWithTools_agent2.httprr From bf9a189bbfef1b17e881889baf0e7d6b89155aef Mon Sep 17 00:00:00 2001 From: westerberg Date: Tue, 3 Feb 2026 14:49:06 +0000 Subject: [PATCH 04/10] refactor: centralize session concurrency management by removing mutexes from `MutableSession` and adding them to `InMemorySession`'s event and state updates, along with a new deadlock test. --- .../parallelagent/agent_test.go | 2 +- internal/sessioninternal/mutablesession.go | 34 ------------------ session/inmemory.go | 11 +++--- session/inmemory_test.go | 36 +++++++++++++++++++ 4 files changed, 43 insertions(+), 40 deletions(-) diff --git a/agent/workflowagents/parallelagent/agent_test.go b/agent/workflowagents/parallelagent/agent_test.go index dc8ecd210..2a589a7bf 100644 --- a/agent/workflowagents/parallelagent/agent_test.go +++ b/agent/workflowagents/parallelagent/agent_test.go @@ -335,7 +335,7 @@ func createAgentWithGemini(t *testing.T, name string) agent.Agent { return a } -func newGeminiModelForTest(t *testing.T, modelName string, agentName string) model.LLM { +func newGeminiModelForTest(t *testing.T, modelName, agentName string) model.LLM { t.Helper() trace := filepath.Join("testdata", fmt.Sprintf("%s_%s.httprr", diff --git a/internal/sessioninternal/mutablesession.go b/internal/sessioninternal/mutablesession.go index c9d276771..1c68f8ef7 100644 --- a/internal/sessioninternal/mutablesession.go +++ b/internal/sessioninternal/mutablesession.go @@ -15,11 +15,8 @@ package sessioninternal import ( - "context" "fmt" "iter" - "log" - "sync" "time" "google.golang.org/adk/session" @@ -29,7 +26,6 @@ import ( type MutableSession struct { service session.Service storedSession session.Session - mu sync.RWMutex } // NewMutableSession creates and returns session.Session implementation. @@ -45,52 +41,26 @@ func (s *MutableSession) State() session.State { } func (s *MutableSession) AppName() string { - s.mu.RLock() - defer s.mu.RUnlock() return s.storedSession.AppName() } func (s *MutableSession) UserID() string { - s.mu.RLock() - defer s.mu.RUnlock() return s.storedSession.UserID() } func (s *MutableSession) ID() string { - s.mu.RLock() - defer s.mu.RUnlock() return s.storedSession.ID() } func (s *MutableSession) Events() session.Events { - s.mu.Lock() - defer s.mu.Unlock() - - ctx := context.Background() - resp, err := s.service.Get(ctx, &session.GetRequest{ - AppName: s.storedSession.AppName(), - UserID: s.storedSession.UserID(), - SessionID: s.storedSession.ID(), - }) - if err != nil { - log.Printf("MutableSession: failed to fetch latest session (app=%s, user=%s, session=%s), using cached version: %v", - s.storedSession.AppName(), s.storedSession.UserID(), s.storedSession.ID(), err) - return s.storedSession.Events() - } - - s.storedSession = resp.Session return s.storedSession.Events() } func (s *MutableSession) LastUpdateTime() time.Time { - s.mu.RLock() - defer s.mu.RUnlock() return s.storedSession.LastUpdateTime() } func (s *MutableSession) Get(key string) (any, error) { - s.mu.RLock() - defer s.mu.RUnlock() value, err := s.storedSession.State().Get(key) if err != nil { return nil, fmt.Errorf("failed to get key %q from state: %w", key, err) @@ -99,14 +69,10 @@ func (s *MutableSession) Get(key string) (any, error) { } func (s *MutableSession) All() iter.Seq2[string, any] { - s.mu.RLock() - defer s.mu.RUnlock() return s.storedSession.State().All() } func (s *MutableSession) Set(key string, value any) error { - s.mu.Lock() - defer s.mu.Unlock() mutableState, ok := s.storedSession.State().(MutableState) if !ok { return fmt.Errorf("this session state is not mutable") diff --git a/session/inmemory.go b/session/inmemory.go index ea28af2ac..e5b113d11 100644 --- a/session/inmemory.go +++ b/session/inmemory.go @@ -314,6 +314,8 @@ func (s *session) State() State { } func (s *session) Events() Events { + s.mu.RLock() + defer s.mu.RUnlock() return events(s.events) } @@ -329,6 +331,9 @@ func (s *session) appendEvent(event *Event) error { return nil } + s.mu.Lock() + defer s.mu.Unlock() + processedEvent := trimTempDeltaState(event) if err := updateSessionState(s, processedEvent); err != nil { return fmt.Errorf("error on appendEvent: %w", err) @@ -434,15 +439,11 @@ func updateSessionState(session *session, event *Event) error { session.state = make(map[string]any) } - state := session.State() for key, value := range event.Actions.StateDelta { if strings.HasPrefix(key, KeyPrefixTemp) { continue } - err := state.Set(key, value) - if err != nil { - return fmt.Errorf("error on updateSessionState state: %w", err) - } + session.state[key] = value } return nil } diff --git a/session/inmemory_test.go b/session/inmemory_test.go index d68e2fdd6..92a6fa30a 100644 --- a/session/inmemory_test.go +++ b/session/inmemory_test.go @@ -1049,3 +1049,39 @@ func Test_inMemoryService_CreateConcurrentAccess(t *testing.T) { t.Errorf("expected %d 'already exists' errors, but got %d", expectedErrors, errorCount.Load()) } } + +func TestInMemorySession_AppendEvent_Deadlock(t *testing.T) { + ctx := t.Context() + service := InMemoryService() + + // Create a session + createReq := &CreateRequest{ + AppName: "testapp", + UserID: "testuser", + } + createResp, err := service.Create(ctx, createReq) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + sess := createResp.Session + + // Event with StateDelta to trigger updateSessionState + event := &Event{ + ID: "event1", + Timestamp: time.Now(), + Actions: EventActions{ + StateDelta: map[string]any{ + "test_key": "test_value", + }, + }, + } + + // This call should hang if the deadlock is present + err = service.AppendEvent(ctx, sess, event) + if err != nil { + t.Fatalf("AppendEvent failed: %v", err) + } + + // If it doesn't hang, the test passes (meaning no deadlock) + t.Log("AppendEvent did not deadlock") +} From d8e529f39fe37e63d3474aea00870e577d2bfdac Mon Sep 17 00:00:00 2001 From: westerberg Date: Fri, 6 Feb 2026 13:27:25 +0000 Subject: [PATCH 05/10] fix yield errgroup.Wait() err --- agent/workflowagents/parallelagent/agent.go | 16 ++-- .../parallelagent/agent_test.go | 94 +++++++++++++++++++ 2 files changed, 104 insertions(+), 6 deletions(-) diff --git a/agent/workflowagents/parallelagent/agent.go b/agent/workflowagents/parallelagent/agent.go index fff0afdb5..63ee7a59c 100644 --- a/agent/workflowagents/parallelagent/agent.go +++ b/agent/workflowagents/parallelagent/agent.go @@ -99,7 +99,12 @@ func run(ctx agent.InvocationContext) iter.Seq2[*session.Event, error] { } go func() { - _ = errGroup.Wait() // this error is already sent to the user via iterator + if err := errGroup.Wait(); err != nil { + select { + case resultsChan <- result{err: err}: + case <-doneChan: + } + } close(resultsChan) }() @@ -123,6 +128,10 @@ func run(ctx agent.InvocationContext) iter.Seq2[*session.Event, error] { func runSubAgent(ctx agent.InvocationContext, agent agent.Agent, results chan<- result, done <-chan bool) error { for event, err := range agent.Run(ctx) { + if err != nil { + return err + } + ackChan := make(chan struct{}) select { @@ -132,13 +141,8 @@ func runSubAgent(ctx agent.InvocationContext, agent agent.Agent, results chan<- return ctx.Err() case results <- result{ event: event, - err: err, ackChan: ackChan, }: - if err != nil { - return err - } - // Wait for runner to finish processing before continuing to next iteration select { case <-ackChan: diff --git a/agent/workflowagents/parallelagent/agent_test.go b/agent/workflowagents/parallelagent/agent_test.go index 2a589a7bf..018ae2ee2 100644 --- a/agent/workflowagents/parallelagent/agent_test.go +++ b/agent/workflowagents/parallelagent/agent_test.go @@ -366,3 +366,97 @@ func newGeminiTestTransport(t *testing.T, rrfile string) (http.RoundTripper, boo recording, _ := httprr.Recording(rrfile) return rr, recording } + +// TestParallelAgent_PropagatesContextError verifies that if the context is canceled, +// the iterator yields the error from errgroup.Wait(). +func TestParallelAgent_PropagatesContextError(t *testing.T) { + t.Parallel() + + // Create a sub-agent that yields an event and then waits. + // We want to trigger runSubAgent returning ctx.Err(). + subAgent := must(agent.New(agent.Config{ + Name: "yielder", + Run: func(ctx agent.InvocationContext) iter.Seq2[*session.Event, error] { + return func(yield func(*session.Event, error) bool) { + // Yield one event so we engage runSubAgent logic + if !yield(&session.Event{ + LLMResponse: model.LLMResponse{ + Content: genai.NewContentFromText("hello", genai.RoleModel), + }, + }, nil) { + return + } + + // Wait for context cancellation + <-ctx.Done() + } + }, + })) + + parallelAgent, err := parallelagent.New(parallelagent.Config{ + AgentConfig: agent.Config{ + Name: "parallel_agent", + SubAgents: []agent.Agent{subAgent}, + }, + }) + if err != nil { + t.Fatal(err) + } + + spy := &spyAgent{Agent: parallelAgent} + + ctx, cancel := context.WithCancel(t.Context()) + + sessionService := session.InMemoryService() + _, _ = sessionService.Create(ctx, &session.CreateRequest{ + AppName: "test_app", + UserID: "user_id", + SessionID: "session_id", + }) + + r, err := runner.New(runner.Config{ + AppName: "test_app", + Agent: spy, + SessionService: sessionService, + }) + if err != nil { + t.Fatal(err) + } + + go func() { + // Wait a tiny bit to ensure we started + time.Sleep(10 * time.Millisecond) + cancel() + }() + + for _, _ = range r.Run(ctx, "user_id", "session_id", genai.NewContentFromText("hi", genai.RoleUser), agent.RunConfig{}) { + // Simulate processing delay so that ackChan takes time, + // increasing chance runSubAgent is blocked on ackChan when cancel happens? + time.Sleep(100 * time.Millisecond) + } + + if spy.yieldedError == nil { + t.Fatal("Expected parallelAgent to yield an error (e.g. context canceled), but it yielded nil") + } + + t.Logf("Yielded error: %v", spy.yieldedError) +} + +type spyAgent struct { + agent.Agent + yieldedError error +} + +func (s *spyAgent) Run(ctx agent.InvocationContext) iter.Seq2[*session.Event, error] { + next := s.Agent.Run(ctx) + return func(yield func(*session.Event, error) bool) { + for event, err := range next { + if err != nil { + s.yieldedError = err + } + if !yield(event, err) { + return + } + } + } +} From 965b3b196144aa7a77fb26b866b6e4f80c4c34fb Mon Sep 17 00:00:00 2001 From: westerberg Date: Fri, 6 Feb 2026 13:40:00 +0000 Subject: [PATCH 06/10] lint fix --- agent/workflowagents/parallelagent/agent_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agent/workflowagents/parallelagent/agent_test.go b/agent/workflowagents/parallelagent/agent_test.go index 018ae2ee2..4167c0988 100644 --- a/agent/workflowagents/parallelagent/agent_test.go +++ b/agent/workflowagents/parallelagent/agent_test.go @@ -429,7 +429,7 @@ func TestParallelAgent_PropagatesContextError(t *testing.T) { cancel() }() - for _, _ = range r.Run(ctx, "user_id", "session_id", genai.NewContentFromText("hi", genai.RoleUser), agent.RunConfig{}) { + for range r.Run(ctx, "user_id", "session_id", genai.NewContentFromText("hi", genai.RoleUser), agent.RunConfig{}) { // Simulate processing delay so that ackChan takes time, // increasing chance runSubAgent is blocked on ackChan when cancel happens? time.Sleep(100 * time.Millisecond) From 03c513149cabf3059d43c63ef859c3b79ff32ba7 Mon Sep 17 00:00:00 2001 From: westerberg Date: Thu, 19 Feb 2026 16:27:26 +0000 Subject: [PATCH 07/10] Add clone to state All and store a copy of the event to stored_session --- session/database/session.go | 13 ++++++------- session/inmemory.go | 35 ++++++++++++++++++++++++++--------- 2 files changed, 32 insertions(+), 16 deletions(-) diff --git a/session/database/session.go b/session/database/session.go index d444ec4af..e925844bd 100644 --- a/session/database/session.go +++ b/session/database/session.go @@ -126,18 +126,17 @@ func (s *state) Get(key string) (any, error) { } func (s *state) All() iter.Seq2[string, any] { - return func(yield func(key string, val any) bool) { - s.mu.RLock() + s.mu.RLock() + // Create a copy of the state to iterate over it without holding the lock. + stateCopy := maps.Clone(s.state) + s.mu.RUnlock() - for k, v := range s.state { - s.mu.RUnlock() + return func(yield func(key string, val any) bool) { + for k, v := range stateCopy { if !yield(k, v) { return } - s.mu.RLock() } - - s.mu.RUnlock() } } diff --git a/session/inmemory.go b/session/inmemory.go index bd92d31f9..b1f7ccc7f 100644 --- a/session/inmemory.go +++ b/session/inmemory.go @@ -218,13 +218,31 @@ func (s *inMemoryService) AppendEvent(ctx context.Context, curSession Session, e return fmt.Errorf("session not found, cannot apply event") } + eventCopy := &Event{ + ID: event.ID, + InvocationID: event.InvocationID, + Timestamp: event.Timestamp, + Author: event.Author, + Branch: event.Branch, + Actions: EventActions{ + StateDelta: maps.Clone(event.Actions.StateDelta), + ArtifactDelta: maps.Clone(event.Actions.ArtifactDelta), + RequestedToolConfirmations: maps.Clone(event.Actions.RequestedToolConfirmations), + TransferToAgent: event.Actions.TransferToAgent, + Escalate: event.Actions.Escalate, + SkipSummarization: event.Actions.SkipSummarization, + }, + LongRunningToolIDs: slices.Clone(event.LongRunningToolIDs), + LLMResponse: event.LLMResponse, + } + // update the in-memory session if err := sess.appendEvent(event); err != nil { return fmt.Errorf("fail to set state on appendEvent: %w", err) } // update the in-memory session service - stored_session.events = append(stored_session.events, event) + stored_session.events = append(stored_session.events, eventCopy) stored_session.updatedAt = event.Timestamp if len(event.Actions.StateDelta) > 0 { appDelta, userDelta, sessionDelta := sessionutils.ExtractStateDeltas(event.Actions.StateDelta) @@ -385,18 +403,17 @@ func (s *state) Get(key string) (any, error) { } func (s *state) All() iter.Seq2[string, any] { - return func(yield func(key string, val any) bool) { - s.mu.RLock() + s.mu.RLock() + // Create a copy of the state to iterate over it without holding the lock. + stateCopy := maps.Clone(s.state) + s.mu.RUnlock() - for k, v := range s.state { - s.mu.RUnlock() + return func(yield func(key string, val any) bool) { + for k, v := range stateCopy { if !yield(k, v) { return } - s.mu.RLock() } - - s.mu.RUnlock() } } @@ -454,4 +471,4 @@ func copySessionWithoutStateAndEvents(sess *session) *session { } } -var _ Service = (*inMemoryService)(nil) +var _ Service = (*inMemoryService)(nil) \ No newline at end of file From b01a75f8af9dc93f952f9ec58b86f2c4768b0276 Mon Sep 17 00:00:00 2001 From: westerberg Date: Thu, 19 Feb 2026 16:53:35 +0000 Subject: [PATCH 08/10] fix temp event state delta clear --- session/inmemory.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/session/inmemory.go b/session/inmemory.go index b1f7ccc7f..de4082f9f 100644 --- a/session/inmemory.go +++ b/session/inmemory.go @@ -218,6 +218,11 @@ func (s *inMemoryService) AppendEvent(ctx context.Context, curSession Session, e return fmt.Errorf("session not found, cannot apply event") } + // update the in-memory session + if err := sess.appendEvent(event); err != nil { + return fmt.Errorf("fail to set state on appendEvent: %w", err) + } + eventCopy := &Event{ ID: event.ID, InvocationID: event.InvocationID, @@ -236,11 +241,6 @@ func (s *inMemoryService) AppendEvent(ctx context.Context, curSession Session, e LLMResponse: event.LLMResponse, } - // update the in-memory session - if err := sess.appendEvent(event); err != nil { - return fmt.Errorf("fail to set state on appendEvent: %w", err) - } - // update the in-memory session service stored_session.events = append(stored_session.events, eventCopy) stored_session.updatedAt = event.Timestamp @@ -471,4 +471,4 @@ func copySessionWithoutStateAndEvents(sess *session) *session { } } -var _ Service = (*inMemoryService)(nil) \ No newline at end of file +var _ Service = (*inMemoryService)(nil) From 081d3ea08872a25e952621767ffcf52466c6c557 Mon Sep 17 00:00:00 2001 From: westerberg Date: Tue, 3 Mar 2026 13:18:30 +0000 Subject: [PATCH 09/10] Add TestParallelAgent_StateSync --- .../parallelagent/agent_test.go | 75 +++++++++++++++++++ 1 file changed, 75 insertions(+) diff --git a/agent/workflowagents/parallelagent/agent_test.go b/agent/workflowagents/parallelagent/agent_test.go index 4167c0988..bbaf22d31 100644 --- a/agent/workflowagents/parallelagent/agent_test.go +++ b/agent/workflowagents/parallelagent/agent_test.go @@ -460,3 +460,78 @@ func (s *spyAgent) Run(ctx agent.InvocationContext) iter.Seq2[*session.Event, er } } } + +func TestParallelAgent_StateSync(t *testing.T) { + ctx := t.Context() + + var gotValue any + var gotErr error + + subAgent, err := agent.New(agent.Config{ + Name: "test_subagent", + Run: func(agent.InvocationContext) iter.Seq2[*session.Event, error] { + return func(yield func(*session.Event, error) bool) { + event := &session.Event{ + LLMResponse: model.LLMResponse{ + Content: genai.NewContentFromText("hello", genai.RoleModel), + }, + Actions: session.EventActions{ + StateDelta: map[string]any{"test_key": "test_value"}, + }, + } + yield(event, nil) + } + }, + AfterAgentCallbacks: []agent.AfterAgentCallback{ + func(c agent.CallbackContext) (*genai.Content, error) { + gotValue, gotErr = c.State().Get("test_key") + return nil, nil + }, + }, + }) + if err != nil { + t.Fatal(err) + } + + parallelAgent, err := parallelagent.New(parallelagent.Config{ + AgentConfig: agent.Config{ + Name: "test_parallel_agent", + SubAgents: []agent.Agent{subAgent}, + }, + }) + if err != nil { + t.Fatal(err) + } + + sessionService := session.InMemoryService() + agentRunner, err := runner.New(runner.Config{ + AppName: "test_app", + Agent: parallelAgent, + SessionService: sessionService, + }) + if err != nil { + t.Fatal(err) + } + + _, err = sessionService.Create(ctx, &session.CreateRequest{ + AppName: "test_app", + UserID: "user_id", + SessionID: "session_id", + }) + if err != nil { + t.Fatal(err) + } + + for _, err := range agentRunner.Run(ctx, "user_id", "session_id", genai.NewContentFromText("user input", genai.RoleUser), agent.RunConfig{}) { + if err != nil { + t.Fatal(err) + } + } + + if gotErr != nil { + t.Fatalf("expected to get value from state, got error: %v", gotErr) + } + if gotValue != "test_value" { + t.Fatalf("expected state value 'test_value', got %v", gotValue) + } +} \ No newline at end of file From ab054adfbacfbeb1111ebefcde1bf5fcbac0727f Mon Sep 17 00:00:00 2001 From: westerberg Date: Tue, 3 Mar 2026 13:30:34 +0000 Subject: [PATCH 10/10] lint fix --- agent/workflowagents/parallelagent/agent_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agent/workflowagents/parallelagent/agent_test.go b/agent/workflowagents/parallelagent/agent_test.go index bbaf22d31..90466c385 100644 --- a/agent/workflowagents/parallelagent/agent_test.go +++ b/agent/workflowagents/parallelagent/agent_test.go @@ -534,4 +534,4 @@ func TestParallelAgent_StateSync(t *testing.T) { if gotValue != "test_value" { t.Fatalf("expected state value 'test_value', got %v", gotValue) } -} \ No newline at end of file +}